mirror of
https://github.com/pomerium/pomerium.git
synced 2025-06-06 04:42:56 +02:00
directory: support non-base64 encoded service accounts (#3150)
This commit is contained in:
parent
925fc29ab8
commit
f894205d08
14 changed files with 267 additions and 51 deletions
|
@ -3,7 +3,6 @@ package azure
|
|||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
@ -15,6 +14,7 @@ import (
|
|||
"github.com/rs/zerolog"
|
||||
"golang.org/x/oauth2"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/encoding"
|
||||
"github.com/pomerium/pomerium/internal/httputil"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/directory"
|
||||
)
|
||||
|
@ -301,14 +301,8 @@ func parseServiceAccountFromOptions(clientID, clientSecret, providerURL string)
|
|||
}
|
||||
|
||||
func parseServiceAccountFromString(rawServiceAccount string) (*ServiceAccount, error) {
|
||||
bs, err := base64.StdEncoding.DecodeString(rawServiceAccount)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var serviceAccount ServiceAccount
|
||||
err = json.Unmarshal(bs, &serviceAccount)
|
||||
if err != nil {
|
||||
if err := encoding.DecodeBase64OrJSON(rawServiceAccount, &serviceAccount); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
|
|
|
@ -218,7 +218,7 @@ func TestParseServiceAccount(t *testing.T) {
|
|||
DirectoryID: "0303f438-3c5c-4190-9854-08d3eb31bd9f",
|
||||
}, serviceAccount)
|
||||
})
|
||||
t.Run("by service account", func(t *testing.T) {
|
||||
t.Run("by service account base64", func(t *testing.T) {
|
||||
serviceAccount, err := ParseServiceAccount(directory.Options{
|
||||
ServiceAccount: base64.StdEncoding.EncodeToString([]byte(`{
|
||||
"client_id": "CLIENT_ID",
|
||||
|
@ -230,6 +230,24 @@ func TestParseServiceAccount(t *testing.T) {
|
|||
return
|
||||
}
|
||||
|
||||
assert.Equal(t, &ServiceAccount{
|
||||
ClientID: "CLIENT_ID",
|
||||
ClientSecret: "CLIENT_SECRET",
|
||||
DirectoryID: "0303f438-3c5c-4190-9854-08d3eb31bd9f",
|
||||
}, serviceAccount)
|
||||
})
|
||||
t.Run("by service account json", func(t *testing.T) {
|
||||
serviceAccount, err := ParseServiceAccount(directory.Options{
|
||||
ServiceAccount: `{
|
||||
"client_id": "CLIENT_ID",
|
||||
"client_secret": "CLIENT_SECRET",
|
||||
"directory_id": "0303f438-3c5c-4190-9854-08d3eb31bd9f"
|
||||
}`,
|
||||
})
|
||||
if !assert.NoError(t, err) {
|
||||
return
|
||||
}
|
||||
|
||||
assert.Equal(t, &ServiceAccount{
|
||||
ClientID: "CLIENT_ID",
|
||||
ClientSecret: "CLIENT_SECRET",
|
||||
|
|
|
@ -4,7 +4,6 @@ package github
|
|||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
@ -14,6 +13,7 @@ import (
|
|||
"github.com/rs/zerolog"
|
||||
"github.com/tomnomnom/linkheader"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/encoding"
|
||||
"github.com/pomerium/pomerium/internal/httputil"
|
||||
"github.com/pomerium/pomerium/internal/log"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/directory"
|
||||
|
@ -300,14 +300,8 @@ type ServiceAccount struct {
|
|||
|
||||
// ParseServiceAccount parses the service account in the config options.
|
||||
func ParseServiceAccount(rawServiceAccount string) (*ServiceAccount, error) {
|
||||
bs, err := base64.StdEncoding.DecodeString(rawServiceAccount)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var serviceAccount ServiceAccount
|
||||
err = json.Unmarshal(bs, &serviceAccount)
|
||||
if err != nil {
|
||||
if err := encoding.DecodeBase64OrJSON(rawServiceAccount, &serviceAccount); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
|
|
|
@ -11,6 +11,7 @@ import (
|
|||
"github.com/go-chi/chi/v5"
|
||||
"github.com/go-chi/chi/v5/middleware"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/vektah/gqlparser/ast"
|
||||
"github.com/vektah/gqlparser/parser"
|
||||
|
||||
|
@ -347,6 +348,49 @@ func TestProvider_UserGroups(t *testing.T) {
|
|||
]`, groups)
|
||||
}
|
||||
|
||||
func TestParseServiceAccount(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
rawServiceAccount string
|
||||
serviceAccount *ServiceAccount
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
"json",
|
||||
`{"username": "USERNAME", "personal_access_token": "PERSONAL_ACCESS_TOKEN"}`,
|
||||
&ServiceAccount{Username: "USERNAME", PersonalAccessToken: "PERSONAL_ACCESS_TOKEN"},
|
||||
false,
|
||||
},
|
||||
{
|
||||
"base64 json",
|
||||
`eyJ1c2VybmFtZSI6ICJVU0VSTkFNRSIsICJwZXJzb25hbF9hY2Nlc3NfdG9rZW4iOiAiUEVSU09OQUxfQUNDRVNTX1RPS0VOIn0=`,
|
||||
&ServiceAccount{Username: "USERNAME", PersonalAccessToken: "PERSONAL_ACCESS_TOKEN"},
|
||||
false,
|
||||
},
|
||||
{
|
||||
"empty",
|
||||
"",
|
||||
nil,
|
||||
true,
|
||||
},
|
||||
{
|
||||
"invalid",
|
||||
"Zm9v---",
|
||||
nil,
|
||||
true,
|
||||
},
|
||||
}
|
||||
for _, tc := range tests {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
got, err := ParseServiceAccount(tc.rawServiceAccount)
|
||||
require.True(t, (err != nil) == tc.wantErr)
|
||||
assert.Equal(t, tc.serviceAccount, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func mustParseURL(rawurl string) *url.URL {
|
||||
u, err := url.Parse(rawurl)
|
||||
if err != nil {
|
||||
|
|
|
@ -3,7 +3,6 @@ package gitlab
|
|||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
@ -14,6 +13,7 @@ import (
|
|||
"github.com/rs/zerolog"
|
||||
"github.com/tomnomnom/linkheader"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/encoding"
|
||||
"github.com/pomerium/pomerium/internal/httputil"
|
||||
"github.com/pomerium/pomerium/internal/log"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/directory"
|
||||
|
@ -267,14 +267,8 @@ type ServiceAccount struct {
|
|||
|
||||
// ParseServiceAccount parses the service account in the config options.
|
||||
func ParseServiceAccount(rawServiceAccount string) (*ServiceAccount, error) {
|
||||
bs, err := base64.StdEncoding.DecodeString(rawServiceAccount)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var serviceAccount ServiceAccount
|
||||
err = json.Unmarshal(bs, &serviceAccount)
|
||||
if err != nil {
|
||||
if err := encoding.DecodeBase64OrJSON(rawServiceAccount, &serviceAccount); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
|
|
|
@ -11,6 +11,7 @@ import (
|
|||
"github.com/go-chi/chi/v5"
|
||||
"github.com/go-chi/chi/v5/middleware"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/testutil"
|
||||
)
|
||||
|
@ -79,6 +80,49 @@ func Test(t *testing.T) {
|
|||
]`, groups)
|
||||
}
|
||||
|
||||
func TestParseServiceAccount(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
rawServiceAccount string
|
||||
serviceAccount *ServiceAccount
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
"json",
|
||||
`{"private_token":"PRIVATE_TOKEN"}`,
|
||||
&ServiceAccount{PrivateToken: "PRIVATE_TOKEN"},
|
||||
false,
|
||||
},
|
||||
{
|
||||
"base64 json",
|
||||
`eyJwcml2YXRlX3Rva2VuIjoiUFJJVkFURV9UT0tFTiJ9`,
|
||||
&ServiceAccount{PrivateToken: "PRIVATE_TOKEN"},
|
||||
false,
|
||||
},
|
||||
{
|
||||
"empty",
|
||||
"",
|
||||
nil,
|
||||
true,
|
||||
},
|
||||
{
|
||||
"invalid",
|
||||
"Zm9v---",
|
||||
nil,
|
||||
true,
|
||||
},
|
||||
}
|
||||
for _, tc := range tests {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
got, err := ParseServiceAccount(tc.rawServiceAccount)
|
||||
require.True(t, (err != nil) == tc.wantErr)
|
||||
assert.Equal(t, tc.serviceAccount, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func mustParseURL(rawurl string) *url.URL {
|
||||
u, err := url.Parse(rawurl)
|
||||
if err != nil {
|
||||
|
|
|
@ -3,7 +3,6 @@ package google
|
|||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
@ -18,6 +17,7 @@ import (
|
|||
"google.golang.org/api/option"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/directory/directoryerrors"
|
||||
"github.com/pomerium/pomerium/internal/encoding"
|
||||
"github.com/pomerium/pomerium/internal/log"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/directory"
|
||||
)
|
||||
|
@ -291,14 +291,8 @@ type ServiceAccount struct {
|
|||
|
||||
// ParseServiceAccount parses the service account in the config options.
|
||||
func ParseServiceAccount(rawServiceAccount string) (*ServiceAccount, error) {
|
||||
bs, err := base64.StdEncoding.DecodeString(rawServiceAccount)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var serviceAccount ServiceAccount
|
||||
err = json.Unmarshal(bs, &serviceAccount)
|
||||
if err != nil {
|
||||
if err := encoding.DecodeBase64OrJSON(rawServiceAccount, &serviceAccount); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
|
|
|
@ -11,6 +11,7 @@ import (
|
|||
"github.com/go-chi/chi/v5"
|
||||
"github.com/go-chi/chi/v5/middleware"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/directory/directoryerrors"
|
||||
"github.com/pomerium/pomerium/internal/testutil"
|
||||
|
@ -214,3 +215,46 @@ func TestProvider_UserGroups(t *testing.T) {
|
|||
{ "id": "outside-user1", "email": "user1@outside.test", "groupIds": ["group1"] }
|
||||
]`, dus)
|
||||
}
|
||||
|
||||
func TestParseServiceAccount(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
rawServiceAccount string
|
||||
serviceAccount *ServiceAccount
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
"json",
|
||||
`{"impersonate_user":"IMPERSONATE_USER"}`,
|
||||
&ServiceAccount{ImpersonateUser: "IMPERSONATE_USER"},
|
||||
false,
|
||||
},
|
||||
{
|
||||
"base64 json",
|
||||
`eyJpbXBlcnNvbmF0ZV91c2VyIjoiSU1QRVJTT05BVEVfVVNFUiJ9`,
|
||||
&ServiceAccount{ImpersonateUser: "IMPERSONATE_USER"},
|
||||
false,
|
||||
},
|
||||
{
|
||||
"empty",
|
||||
"",
|
||||
nil,
|
||||
true,
|
||||
},
|
||||
{
|
||||
"invalid",
|
||||
"Zm9v---",
|
||||
nil,
|
||||
true,
|
||||
},
|
||||
}
|
||||
for _, tc := range tests {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
got, err := ParseServiceAccount(tc.rawServiceAccount)
|
||||
require.True(t, (err != nil) == tc.wantErr)
|
||||
assert.Equal(t, tc.serviceAccount, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
|
@ -17,6 +17,7 @@ import (
|
|||
"github.com/rs/zerolog"
|
||||
"github.com/tomnomnom/linkheader"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/encoding"
|
||||
"github.com/pomerium/pomerium/internal/httputil"
|
||||
"github.com/pomerium/pomerium/internal/log"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/directory"
|
||||
|
@ -350,13 +351,13 @@ type ServiceAccount struct {
|
|||
|
||||
// ParseServiceAccount parses the service account in the config options.
|
||||
func ParseServiceAccount(rawServiceAccount string) (*ServiceAccount, error) {
|
||||
bs, err := base64.StdEncoding.DecodeString(rawServiceAccount)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var serviceAccount ServiceAccount
|
||||
if err := json.Unmarshal(bs, &serviceAccount); err != nil {
|
||||
err := encoding.DecodeBase64OrJSON(rawServiceAccount, &serviceAccount)
|
||||
if err != nil {
|
||||
bs, err := base64.StdEncoding.DecodeString(rawServiceAccount)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
serviceAccount.APIKey = string(bs)
|
||||
}
|
||||
|
||||
|
|
|
@ -341,8 +341,9 @@ func TestParseServiceAccount(t *testing.T) {
|
|||
apiKey string
|
||||
wantErr bool
|
||||
}{
|
||||
{"json", "ewogICAgImFwaV9rZXkiOiAiZm9vIgp9Cg==", "foo", false},
|
||||
{"value", "Zm9v", "foo", false},
|
||||
{"json", `{"api_key": "foo"}`, "foo", false},
|
||||
{"base64 json", "ewogICAgImFwaV9rZXkiOiAiZm9vIgp9Cg==", "foo", false},
|
||||
{"base64 value", "Zm9v", "foo", false},
|
||||
{"empty", "", "", true},
|
||||
{"invalid", "Zm9v---", "", true},
|
||||
}
|
||||
|
|
|
@ -3,7 +3,6 @@ package onelogin
|
|||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
@ -16,6 +15,7 @@ import (
|
|||
"github.com/rs/zerolog"
|
||||
"golang.org/x/oauth2"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/encoding"
|
||||
"github.com/pomerium/pomerium/internal/httputil"
|
||||
"github.com/pomerium/pomerium/internal/log"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/directory"
|
||||
|
@ -332,14 +332,8 @@ type ServiceAccount struct {
|
|||
|
||||
// ParseServiceAccount parses the service account in the config options.
|
||||
func ParseServiceAccount(rawServiceAccount string) (*ServiceAccount, error) {
|
||||
bs, err := base64.StdEncoding.DecodeString(rawServiceAccount)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var serviceAccount ServiceAccount
|
||||
err = json.Unmarshal(bs, &serviceAccount)
|
||||
if err != nil {
|
||||
if err := encoding.DecodeBase64OrJSON(rawServiceAccount, &serviceAccount); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
|
|
|
@ -15,6 +15,7 @@ import (
|
|||
"github.com/go-chi/chi/v5"
|
||||
"github.com/go-chi/chi/v5/middleware"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/testutil"
|
||||
)
|
||||
|
@ -217,6 +218,49 @@ func TestProvider_UserGroups(t *testing.T) {
|
|||
]`, groups)
|
||||
}
|
||||
|
||||
func TestParseServiceAccount(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
rawServiceAccount string
|
||||
serviceAccount *ServiceAccount
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
"json",
|
||||
`{"client_id":"CLIENT_ID","client_secret":"CLIENT_SECRET"}`,
|
||||
&ServiceAccount{ClientID: "CLIENT_ID", ClientSecret: "CLIENT_SECRET"},
|
||||
false,
|
||||
},
|
||||
{
|
||||
"base64 json",
|
||||
`eyJjbGllbnRfaWQiOiJDTElFTlRfSUQiLCJjbGllbnRfc2VjcmV0IjoiQ0xJRU5UX1NFQ1JFVCJ9`,
|
||||
&ServiceAccount{ClientID: "CLIENT_ID", ClientSecret: "CLIENT_SECRET"},
|
||||
false,
|
||||
},
|
||||
{
|
||||
"empty",
|
||||
"",
|
||||
nil,
|
||||
true,
|
||||
},
|
||||
{
|
||||
"invalid",
|
||||
"Zm9v---",
|
||||
nil,
|
||||
true,
|
||||
},
|
||||
}
|
||||
for _, tc := range tests {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
got, err := ParseServiceAccount(tc.rawServiceAccount)
|
||||
require.True(t, (err != nil) == tc.wantErr)
|
||||
assert.Equal(t, tc.serviceAccount, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func mustParseURL(rawurl string) *url.URL {
|
||||
u, err := url.Parse(rawurl)
|
||||
if err != nil {
|
||||
|
|
|
@ -101,8 +101,7 @@ type ServiceAccount struct {
|
|||
// ParseServiceAccount parses the service account in the config options.
|
||||
func ParseServiceAccount(rawServiceAccount string) (*ServiceAccount, error) {
|
||||
var serviceAccount ServiceAccount
|
||||
err := encoding.DecodeBase64OrJSON(rawServiceAccount, &serviceAccount)
|
||||
if err != nil {
|
||||
if err := encoding.DecodeBase64OrJSON(rawServiceAccount, &serviceAccount); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
|
|
51
internal/directory/ping/config_test.go
Normal file
51
internal/directory/ping/config_test.go
Normal file
|
@ -0,0 +1,51 @@
|
|||
package ping
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestParseServiceAccount(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
rawServiceAccount string
|
||||
serviceAccount *ServiceAccount
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
"json",
|
||||
`{"client_id":"CLIENT_ID","client_secret":"CLIENT_SECRET","environment_id":"ENVIRONMENT_ID"}`,
|
||||
&ServiceAccount{ClientID: "CLIENT_ID", ClientSecret: "CLIENT_SECRET", EnvironmentID: "ENVIRONMENT_ID"},
|
||||
false,
|
||||
},
|
||||
{
|
||||
"base64 json",
|
||||
`eyJjbGllbnRfaWQiOiJDTElFTlRfSUQiLCJjbGllbnRfc2VjcmV0IjoiQ0xJRU5UX1NFQ1JFVCIsImVudmlyb25tZW50X2lkIjoiRU5WSVJPTk1FTlRfSUQifQ==`,
|
||||
&ServiceAccount{ClientID: "CLIENT_ID", ClientSecret: "CLIENT_SECRET", EnvironmentID: "ENVIRONMENT_ID"},
|
||||
false,
|
||||
},
|
||||
{
|
||||
"empty",
|
||||
"",
|
||||
nil,
|
||||
true,
|
||||
},
|
||||
{
|
||||
"invalid",
|
||||
"Zm9v---",
|
||||
nil,
|
||||
true,
|
||||
},
|
||||
}
|
||||
for _, tc := range tests {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
got, err := ParseServiceAccount(tc.rawServiceAccount)
|
||||
require.True(t, (err != nil) == tc.wantErr)
|
||||
assert.Equal(t, tc.serviceAccount, got)
|
||||
})
|
||||
}
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue