directory: support non-base64 encoded service accounts (#3150)

This commit is contained in:
Caleb Doxsey 2022-03-14 14:38:41 -06:00 committed by GitHub
parent 925fc29ab8
commit f894205d08
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
14 changed files with 267 additions and 51 deletions

View file

@ -3,7 +3,6 @@ package azure
import ( import (
"context" "context"
"encoding/base64"
"encoding/json" "encoding/json"
"fmt" "fmt"
"net/http" "net/http"
@ -15,6 +14,7 @@ import (
"github.com/rs/zerolog" "github.com/rs/zerolog"
"golang.org/x/oauth2" "golang.org/x/oauth2"
"github.com/pomerium/pomerium/internal/encoding"
"github.com/pomerium/pomerium/internal/httputil" "github.com/pomerium/pomerium/internal/httputil"
"github.com/pomerium/pomerium/pkg/grpc/directory" "github.com/pomerium/pomerium/pkg/grpc/directory"
) )
@ -301,14 +301,8 @@ func parseServiceAccountFromOptions(clientID, clientSecret, providerURL string)
} }
func parseServiceAccountFromString(rawServiceAccount string) (*ServiceAccount, error) { func parseServiceAccountFromString(rawServiceAccount string) (*ServiceAccount, error) {
bs, err := base64.StdEncoding.DecodeString(rawServiceAccount)
if err != nil {
return nil, err
}
var serviceAccount ServiceAccount var serviceAccount ServiceAccount
err = json.Unmarshal(bs, &serviceAccount) if err := encoding.DecodeBase64OrJSON(rawServiceAccount, &serviceAccount); err != nil {
if err != nil {
return nil, err return nil, err
} }

View file

@ -218,7 +218,7 @@ func TestParseServiceAccount(t *testing.T) {
DirectoryID: "0303f438-3c5c-4190-9854-08d3eb31bd9f", DirectoryID: "0303f438-3c5c-4190-9854-08d3eb31bd9f",
}, serviceAccount) }, serviceAccount)
}) })
t.Run("by service account", func(t *testing.T) { t.Run("by service account base64", func(t *testing.T) {
serviceAccount, err := ParseServiceAccount(directory.Options{ serviceAccount, err := ParseServiceAccount(directory.Options{
ServiceAccount: base64.StdEncoding.EncodeToString([]byte(`{ ServiceAccount: base64.StdEncoding.EncodeToString([]byte(`{
"client_id": "CLIENT_ID", "client_id": "CLIENT_ID",
@ -230,6 +230,24 @@ func TestParseServiceAccount(t *testing.T) {
return return
} }
assert.Equal(t, &ServiceAccount{
ClientID: "CLIENT_ID",
ClientSecret: "CLIENT_SECRET",
DirectoryID: "0303f438-3c5c-4190-9854-08d3eb31bd9f",
}, serviceAccount)
})
t.Run("by service account json", func(t *testing.T) {
serviceAccount, err := ParseServiceAccount(directory.Options{
ServiceAccount: `{
"client_id": "CLIENT_ID",
"client_secret": "CLIENT_SECRET",
"directory_id": "0303f438-3c5c-4190-9854-08d3eb31bd9f"
}`,
})
if !assert.NoError(t, err) {
return
}
assert.Equal(t, &ServiceAccount{ assert.Equal(t, &ServiceAccount{
ClientID: "CLIENT_ID", ClientID: "CLIENT_ID",
ClientSecret: "CLIENT_SECRET", ClientSecret: "CLIENT_SECRET",

View file

@ -4,7 +4,6 @@ package github
import ( import (
"bytes" "bytes"
"context" "context"
"encoding/base64"
"encoding/json" "encoding/json"
"fmt" "fmt"
"net/http" "net/http"
@ -14,6 +13,7 @@ import (
"github.com/rs/zerolog" "github.com/rs/zerolog"
"github.com/tomnomnom/linkheader" "github.com/tomnomnom/linkheader"
"github.com/pomerium/pomerium/internal/encoding"
"github.com/pomerium/pomerium/internal/httputil" "github.com/pomerium/pomerium/internal/httputil"
"github.com/pomerium/pomerium/internal/log" "github.com/pomerium/pomerium/internal/log"
"github.com/pomerium/pomerium/pkg/grpc/directory" "github.com/pomerium/pomerium/pkg/grpc/directory"
@ -300,14 +300,8 @@ type ServiceAccount struct {
// ParseServiceAccount parses the service account in the config options. // ParseServiceAccount parses the service account in the config options.
func ParseServiceAccount(rawServiceAccount string) (*ServiceAccount, error) { func ParseServiceAccount(rawServiceAccount string) (*ServiceAccount, error) {
bs, err := base64.StdEncoding.DecodeString(rawServiceAccount)
if err != nil {
return nil, err
}
var serviceAccount ServiceAccount var serviceAccount ServiceAccount
err = json.Unmarshal(bs, &serviceAccount) if err := encoding.DecodeBase64OrJSON(rawServiceAccount, &serviceAccount); err != nil {
if err != nil {
return nil, err return nil, err
} }

View file

@ -11,6 +11,7 @@ import (
"github.com/go-chi/chi/v5" "github.com/go-chi/chi/v5"
"github.com/go-chi/chi/v5/middleware" "github.com/go-chi/chi/v5/middleware"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/vektah/gqlparser/ast" "github.com/vektah/gqlparser/ast"
"github.com/vektah/gqlparser/parser" "github.com/vektah/gqlparser/parser"
@ -347,6 +348,49 @@ func TestProvider_UserGroups(t *testing.T) {
]`, groups) ]`, groups)
} }
func TestParseServiceAccount(t *testing.T) {
tests := []struct {
name string
rawServiceAccount string
serviceAccount *ServiceAccount
wantErr bool
}{
{
"json",
`{"username": "USERNAME", "personal_access_token": "PERSONAL_ACCESS_TOKEN"}`,
&ServiceAccount{Username: "USERNAME", PersonalAccessToken: "PERSONAL_ACCESS_TOKEN"},
false,
},
{
"base64 json",
`eyJ1c2VybmFtZSI6ICJVU0VSTkFNRSIsICJwZXJzb25hbF9hY2Nlc3NfdG9rZW4iOiAiUEVSU09OQUxfQUNDRVNTX1RPS0VOIn0=`,
&ServiceAccount{Username: "USERNAME", PersonalAccessToken: "PERSONAL_ACCESS_TOKEN"},
false,
},
{
"empty",
"",
nil,
true,
},
{
"invalid",
"Zm9v---",
nil,
true,
},
}
for _, tc := range tests {
tc := tc
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
got, err := ParseServiceAccount(tc.rawServiceAccount)
require.True(t, (err != nil) == tc.wantErr)
assert.Equal(t, tc.serviceAccount, got)
})
}
}
func mustParseURL(rawurl string) *url.URL { func mustParseURL(rawurl string) *url.URL {
u, err := url.Parse(rawurl) u, err := url.Parse(rawurl)
if err != nil { if err != nil {

View file

@ -3,7 +3,6 @@ package gitlab
import ( import (
"context" "context"
"encoding/base64"
"encoding/json" "encoding/json"
"fmt" "fmt"
"net/http" "net/http"
@ -14,6 +13,7 @@ import (
"github.com/rs/zerolog" "github.com/rs/zerolog"
"github.com/tomnomnom/linkheader" "github.com/tomnomnom/linkheader"
"github.com/pomerium/pomerium/internal/encoding"
"github.com/pomerium/pomerium/internal/httputil" "github.com/pomerium/pomerium/internal/httputil"
"github.com/pomerium/pomerium/internal/log" "github.com/pomerium/pomerium/internal/log"
"github.com/pomerium/pomerium/pkg/grpc/directory" "github.com/pomerium/pomerium/pkg/grpc/directory"
@ -267,14 +267,8 @@ type ServiceAccount struct {
// ParseServiceAccount parses the service account in the config options. // ParseServiceAccount parses the service account in the config options.
func ParseServiceAccount(rawServiceAccount string) (*ServiceAccount, error) { func ParseServiceAccount(rawServiceAccount string) (*ServiceAccount, error) {
bs, err := base64.StdEncoding.DecodeString(rawServiceAccount)
if err != nil {
return nil, err
}
var serviceAccount ServiceAccount var serviceAccount ServiceAccount
err = json.Unmarshal(bs, &serviceAccount) if err := encoding.DecodeBase64OrJSON(rawServiceAccount, &serviceAccount); err != nil {
if err != nil {
return nil, err return nil, err
} }

View file

@ -11,6 +11,7 @@ import (
"github.com/go-chi/chi/v5" "github.com/go-chi/chi/v5"
"github.com/go-chi/chi/v5/middleware" "github.com/go-chi/chi/v5/middleware"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/pomerium/pomerium/internal/testutil" "github.com/pomerium/pomerium/internal/testutil"
) )
@ -79,6 +80,49 @@ func Test(t *testing.T) {
]`, groups) ]`, groups)
} }
func TestParseServiceAccount(t *testing.T) {
tests := []struct {
name string
rawServiceAccount string
serviceAccount *ServiceAccount
wantErr bool
}{
{
"json",
`{"private_token":"PRIVATE_TOKEN"}`,
&ServiceAccount{PrivateToken: "PRIVATE_TOKEN"},
false,
},
{
"base64 json",
`eyJwcml2YXRlX3Rva2VuIjoiUFJJVkFURV9UT0tFTiJ9`,
&ServiceAccount{PrivateToken: "PRIVATE_TOKEN"},
false,
},
{
"empty",
"",
nil,
true,
},
{
"invalid",
"Zm9v---",
nil,
true,
},
}
for _, tc := range tests {
tc := tc
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
got, err := ParseServiceAccount(tc.rawServiceAccount)
require.True(t, (err != nil) == tc.wantErr)
assert.Equal(t, tc.serviceAccount, got)
})
}
}
func mustParseURL(rawurl string) *url.URL { func mustParseURL(rawurl string) *url.URL {
u, err := url.Parse(rawurl) u, err := url.Parse(rawurl)
if err != nil { if err != nil {

View file

@ -3,7 +3,6 @@ package google
import ( import (
"context" "context"
"encoding/base64"
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
@ -18,6 +17,7 @@ import (
"google.golang.org/api/option" "google.golang.org/api/option"
"github.com/pomerium/pomerium/internal/directory/directoryerrors" "github.com/pomerium/pomerium/internal/directory/directoryerrors"
"github.com/pomerium/pomerium/internal/encoding"
"github.com/pomerium/pomerium/internal/log" "github.com/pomerium/pomerium/internal/log"
"github.com/pomerium/pomerium/pkg/grpc/directory" "github.com/pomerium/pomerium/pkg/grpc/directory"
) )
@ -291,14 +291,8 @@ type ServiceAccount struct {
// ParseServiceAccount parses the service account in the config options. // ParseServiceAccount parses the service account in the config options.
func ParseServiceAccount(rawServiceAccount string) (*ServiceAccount, error) { func ParseServiceAccount(rawServiceAccount string) (*ServiceAccount, error) {
bs, err := base64.StdEncoding.DecodeString(rawServiceAccount)
if err != nil {
return nil, err
}
var serviceAccount ServiceAccount var serviceAccount ServiceAccount
err = json.Unmarshal(bs, &serviceAccount) if err := encoding.DecodeBase64OrJSON(rawServiceAccount, &serviceAccount); err != nil {
if err != nil {
return nil, err return nil, err
} }

View file

@ -11,6 +11,7 @@ import (
"github.com/go-chi/chi/v5" "github.com/go-chi/chi/v5"
"github.com/go-chi/chi/v5/middleware" "github.com/go-chi/chi/v5/middleware"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/pomerium/pomerium/internal/directory/directoryerrors" "github.com/pomerium/pomerium/internal/directory/directoryerrors"
"github.com/pomerium/pomerium/internal/testutil" "github.com/pomerium/pomerium/internal/testutil"
@ -214,3 +215,46 @@ func TestProvider_UserGroups(t *testing.T) {
{ "id": "outside-user1", "email": "user1@outside.test", "groupIds": ["group1"] } { "id": "outside-user1", "email": "user1@outside.test", "groupIds": ["group1"] }
]`, dus) ]`, dus)
} }
func TestParseServiceAccount(t *testing.T) {
tests := []struct {
name string
rawServiceAccount string
serviceAccount *ServiceAccount
wantErr bool
}{
{
"json",
`{"impersonate_user":"IMPERSONATE_USER"}`,
&ServiceAccount{ImpersonateUser: "IMPERSONATE_USER"},
false,
},
{
"base64 json",
`eyJpbXBlcnNvbmF0ZV91c2VyIjoiSU1QRVJTT05BVEVfVVNFUiJ9`,
&ServiceAccount{ImpersonateUser: "IMPERSONATE_USER"},
false,
},
{
"empty",
"",
nil,
true,
},
{
"invalid",
"Zm9v---",
nil,
true,
},
}
for _, tc := range tests {
tc := tc
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
got, err := ParseServiceAccount(tc.rawServiceAccount)
require.True(t, (err != nil) == tc.wantErr)
assert.Equal(t, tc.serviceAccount, got)
})
}
}

View file

@ -17,6 +17,7 @@ import (
"github.com/rs/zerolog" "github.com/rs/zerolog"
"github.com/tomnomnom/linkheader" "github.com/tomnomnom/linkheader"
"github.com/pomerium/pomerium/internal/encoding"
"github.com/pomerium/pomerium/internal/httputil" "github.com/pomerium/pomerium/internal/httputil"
"github.com/pomerium/pomerium/internal/log" "github.com/pomerium/pomerium/internal/log"
"github.com/pomerium/pomerium/pkg/grpc/directory" "github.com/pomerium/pomerium/pkg/grpc/directory"
@ -350,13 +351,13 @@ type ServiceAccount struct {
// ParseServiceAccount parses the service account in the config options. // ParseServiceAccount parses the service account in the config options.
func ParseServiceAccount(rawServiceAccount string) (*ServiceAccount, error) { func ParseServiceAccount(rawServiceAccount string) (*ServiceAccount, error) {
bs, err := base64.StdEncoding.DecodeString(rawServiceAccount)
if err != nil {
return nil, err
}
var serviceAccount ServiceAccount var serviceAccount ServiceAccount
if err := json.Unmarshal(bs, &serviceAccount); err != nil { err := encoding.DecodeBase64OrJSON(rawServiceAccount, &serviceAccount)
if err != nil {
bs, err := base64.StdEncoding.DecodeString(rawServiceAccount)
if err != nil {
return nil, err
}
serviceAccount.APIKey = string(bs) serviceAccount.APIKey = string(bs)
} }

View file

@ -341,8 +341,9 @@ func TestParseServiceAccount(t *testing.T) {
apiKey string apiKey string
wantErr bool wantErr bool
}{ }{
{"json", "ewogICAgImFwaV9rZXkiOiAiZm9vIgp9Cg==", "foo", false}, {"json", `{"api_key": "foo"}`, "foo", false},
{"value", "Zm9v", "foo", false}, {"base64 json", "ewogICAgImFwaV9rZXkiOiAiZm9vIgp9Cg==", "foo", false},
{"base64 value", "Zm9v", "foo", false},
{"empty", "", "", true}, {"empty", "", "", true},
{"invalid", "Zm9v---", "", true}, {"invalid", "Zm9v---", "", true},
} }

View file

@ -3,7 +3,6 @@ package onelogin
import ( import (
"context" "context"
"encoding/base64"
"encoding/json" "encoding/json"
"fmt" "fmt"
"net/http" "net/http"
@ -16,6 +15,7 @@ import (
"github.com/rs/zerolog" "github.com/rs/zerolog"
"golang.org/x/oauth2" "golang.org/x/oauth2"
"github.com/pomerium/pomerium/internal/encoding"
"github.com/pomerium/pomerium/internal/httputil" "github.com/pomerium/pomerium/internal/httputil"
"github.com/pomerium/pomerium/internal/log" "github.com/pomerium/pomerium/internal/log"
"github.com/pomerium/pomerium/pkg/grpc/directory" "github.com/pomerium/pomerium/pkg/grpc/directory"
@ -332,14 +332,8 @@ type ServiceAccount struct {
// ParseServiceAccount parses the service account in the config options. // ParseServiceAccount parses the service account in the config options.
func ParseServiceAccount(rawServiceAccount string) (*ServiceAccount, error) { func ParseServiceAccount(rawServiceAccount string) (*ServiceAccount, error) {
bs, err := base64.StdEncoding.DecodeString(rawServiceAccount)
if err != nil {
return nil, err
}
var serviceAccount ServiceAccount var serviceAccount ServiceAccount
err = json.Unmarshal(bs, &serviceAccount) if err := encoding.DecodeBase64OrJSON(rawServiceAccount, &serviceAccount); err != nil {
if err != nil {
return nil, err return nil, err
} }

View file

@ -15,6 +15,7 @@ import (
"github.com/go-chi/chi/v5" "github.com/go-chi/chi/v5"
"github.com/go-chi/chi/v5/middleware" "github.com/go-chi/chi/v5/middleware"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/pomerium/pomerium/internal/testutil" "github.com/pomerium/pomerium/internal/testutil"
) )
@ -217,6 +218,49 @@ func TestProvider_UserGroups(t *testing.T) {
]`, groups) ]`, groups)
} }
func TestParseServiceAccount(t *testing.T) {
tests := []struct {
name string
rawServiceAccount string
serviceAccount *ServiceAccount
wantErr bool
}{
{
"json",
`{"client_id":"CLIENT_ID","client_secret":"CLIENT_SECRET"}`,
&ServiceAccount{ClientID: "CLIENT_ID", ClientSecret: "CLIENT_SECRET"},
false,
},
{
"base64 json",
`eyJjbGllbnRfaWQiOiJDTElFTlRfSUQiLCJjbGllbnRfc2VjcmV0IjoiQ0xJRU5UX1NFQ1JFVCJ9`,
&ServiceAccount{ClientID: "CLIENT_ID", ClientSecret: "CLIENT_SECRET"},
false,
},
{
"empty",
"",
nil,
true,
},
{
"invalid",
"Zm9v---",
nil,
true,
},
}
for _, tc := range tests {
tc := tc
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
got, err := ParseServiceAccount(tc.rawServiceAccount)
require.True(t, (err != nil) == tc.wantErr)
assert.Equal(t, tc.serviceAccount, got)
})
}
}
func mustParseURL(rawurl string) *url.URL { func mustParseURL(rawurl string) *url.URL {
u, err := url.Parse(rawurl) u, err := url.Parse(rawurl)
if err != nil { if err != nil {

View file

@ -101,8 +101,7 @@ type ServiceAccount struct {
// ParseServiceAccount parses the service account in the config options. // ParseServiceAccount parses the service account in the config options.
func ParseServiceAccount(rawServiceAccount string) (*ServiceAccount, error) { func ParseServiceAccount(rawServiceAccount string) (*ServiceAccount, error) {
var serviceAccount ServiceAccount var serviceAccount ServiceAccount
err := encoding.DecodeBase64OrJSON(rawServiceAccount, &serviceAccount) if err := encoding.DecodeBase64OrJSON(rawServiceAccount, &serviceAccount); err != nil {
if err != nil {
return nil, err return nil, err
} }

View file

@ -0,0 +1,51 @@
package ping
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestParseServiceAccount(t *testing.T) {
tests := []struct {
name string
rawServiceAccount string
serviceAccount *ServiceAccount
wantErr bool
}{
{
"json",
`{"client_id":"CLIENT_ID","client_secret":"CLIENT_SECRET","environment_id":"ENVIRONMENT_ID"}`,
&ServiceAccount{ClientID: "CLIENT_ID", ClientSecret: "CLIENT_SECRET", EnvironmentID: "ENVIRONMENT_ID"},
false,
},
{
"base64 json",
`eyJjbGllbnRfaWQiOiJDTElFTlRfSUQiLCJjbGllbnRfc2VjcmV0IjoiQ0xJRU5UX1NFQ1JFVCIsImVudmlyb25tZW50X2lkIjoiRU5WSVJPTk1FTlRfSUQifQ==`,
&ServiceAccount{ClientID: "CLIENT_ID", ClientSecret: "CLIENT_SECRET", EnvironmentID: "ENVIRONMENT_ID"},
false,
},
{
"empty",
"",
nil,
true,
},
{
"invalid",
"Zm9v---",
nil,
true,
},
}
for _, tc := range tests {
tc := tc
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
got, err := ParseServiceAccount(tc.rawServiceAccount)
require.True(t, (err != nil) == tc.wantErr)
assert.Equal(t, tc.serviceAccount, got)
})
}
}