From f7760c413e7732e018f15e0c3f7a53e4c077b0b6 Mon Sep 17 00:00:00 2001 From: Caleb Doxsey Date: Mon, 22 Jun 2020 07:42:57 -0600 Subject: [PATCH] directory: generate user/directory.User ID in a consistent way (#944) --- authenticate/authenticate.go | 6 ++++-- authenticate/handlers.go | 3 ++- internal/directory/azure/azure.go | 11 +++++++++-- internal/directory/azure/azure_test.go | 6 +++--- internal/directory/gitlab/gitlab.go | 6 +++++- internal/directory/gitlab/gitlab_test.go | 6 +++--- internal/directory/google/google.go | 6 +++++- internal/directory/okta/okta.go | 6 +++++- internal/directory/okta/okta_test.go | 6 +++--- internal/directory/onelogin/onelogin.go | 6 +++++- internal/directory/onelogin/onelogin_test.go | 6 +++--- internal/directory/provider.go | 10 +++++----- internal/grpc/databroker/databroker.go | 7 +++++++ 13 files changed, 59 insertions(+), 26 deletions(-) create mode 100644 internal/grpc/databroker/databroker.go diff --git a/authenticate/authenticate.go b/authenticate/authenticate.go index 02050dd2e..4c2b46b01 100644 --- a/authenticate/authenticate.go +++ b/authenticate/authenticate.go @@ -93,7 +93,8 @@ type Authenticate struct { sessionLoaders []sessions.SessionLoader // provider is the interface to interacting with the identity provider (IdP) - provider identity.Authenticator + provider identity.Authenticator + providerName string // dataBrokerClient is used to retrieve sessions dataBrokerClient databroker.DataBrokerServiceClient @@ -193,7 +194,8 @@ func New(opts config.Options) (*Authenticate, error) { encryptedEncoder: encryptedEncoder, sessionLoaders: []sessions.SessionLoader{qpStore, headerStore, cookieStore}, // IdP - provider: provider, + provider: provider, + providerName: opts.Provider, // grpc client for cache dataBrokerClient: dataBrokerClient, sessionClient: sessionClient, diff --git a/authenticate/handlers.go b/authenticate/handlers.go index b53805bad..ddda715f4 100644 --- a/authenticate/handlers.go +++ b/authenticate/handlers.go @@ -19,6 +19,7 @@ import ( "golang.org/x/oauth2" "github.com/pomerium/pomerium/internal/cryptutil" + "github.com/pomerium/pomerium/internal/grpc/databroker" "github.com/pomerium/pomerium/internal/grpc/directory" "github.com/pomerium/pomerium/internal/grpc/session" "github.com/pomerium/pomerium/internal/grpc/user" @@ -496,7 +497,7 @@ func (a *Authenticate) saveSessionToDataBroker(ctx context.Context, sessionState s := &session.Session{ Id: sessionState.ID, - UserId: sessionState.Issuer + "/" + sessionState.Subject, + UserId: databroker.GetUserID(a.providerName, sessionState.Subject), ExpiresAt: sessionExpiry, IdToken: &session.IDToken{ Issuer: sessionState.Issuer, diff --git a/internal/directory/azure/azure.go b/internal/directory/azure/azure.go index 922794931..bea8b7959 100644 --- a/internal/directory/azure/azure.go +++ b/internal/directory/azure/azure.go @@ -15,10 +15,14 @@ import ( "golang.org/x/oauth2" + "github.com/pomerium/pomerium/internal/grpc/databroker" "github.com/pomerium/pomerium/internal/grpc/directory" ) -var ( +// Name is the provider name. +const Name = "azure" + +const ( defaultGraphHost = "graph.microsoft.com" defaultLoginHost = "login.microsoftonline.com" @@ -122,7 +126,10 @@ func (p *Provider) UserGroups(ctx context.Context) ([]*directory.User, error) { var users []*directory.User for userID, groupIDs := range userIDToGroupIDs { sort.Strings(groupIDs) - users = append(users, &directory.User{Id: userID, Groups: groupIDs}) + users = append(users, &directory.User{ + Id: databroker.GetUserID(Name, userID), + Groups: groupIDs, + }) } sort.Slice(users, func(i, j int) bool { return users[i].GetId() < users[j].GetId() diff --git a/internal/directory/azure/azure_test.go b/internal/directory/azure/azure_test.go index a4f0e13d9..f3ad9b414 100644 --- a/internal/directory/azure/azure_test.go +++ b/internal/directory/azure/azure_test.go @@ -89,15 +89,15 @@ func Test(t *testing.T) { assert.NoError(t, err) assert.Equal(t, []*directory.User{ { - Id: "user-1", + Id: "azure/user-1", Groups: []string{"admin"}, }, { - Id: "user-2", + Id: "azure/user-2", Groups: []string{"test"}, }, { - Id: "user-3", + Id: "azure/user-3", Groups: []string{"test"}, }, }, users) diff --git a/internal/directory/gitlab/gitlab.go b/internal/directory/gitlab/gitlab.go index 6e8460042..5900919ee 100644 --- a/internal/directory/gitlab/gitlab.go +++ b/internal/directory/gitlab/gitlab.go @@ -13,10 +13,14 @@ import ( "github.com/rs/zerolog" "github.com/tomnomnom/linkheader" + "github.com/pomerium/pomerium/internal/grpc/databroker" "github.com/pomerium/pomerium/internal/grpc/directory" "github.com/pomerium/pomerium/internal/log" ) +// Name is the provider name. +const Name = "gitlab" + var ( defaultURL = &url.URL{ Scheme: "https", @@ -106,7 +110,7 @@ func (p *Provider) UserGroups(ctx context.Context) ([]*directory.User, error) { var users []*directory.User for userID, groupIDs := range userIDToGroupIDs { user := &directory.User{ - Id: fmt.Sprint(userID), + Id: databroker.GetUserID(Name, fmt.Sprint(userID)), } for _, groupID := range groupIDs { user.Groups = append(user.Groups, fmt.Sprint(groupID)) diff --git a/internal/directory/gitlab/gitlab_test.go b/internal/directory/gitlab/gitlab_test.go index 16c217451..ffd1991ad 100644 --- a/internal/directory/gitlab/gitlab_test.go +++ b/internal/directory/gitlab/gitlab_test.go @@ -69,9 +69,9 @@ func Test(t *testing.T) { users, err := p.UserGroups(context.Background()) assert.NoError(t, err) testutil.AssertProtoJSONEqual(t, `[ - { "id": "11", "groups": ["1"] }, - { "id": "12", "groups": ["2"] }, - { "id": "13", "groups": ["2"] } + { "id": "gitlab/11", "groups": ["1"] }, + { "id": "gitlab/12", "groups": ["2"] }, + { "id": "gitlab/13", "groups": ["2"] } ]`, users) } diff --git a/internal/directory/google/google.go b/internal/directory/google/google.go index 05374a138..cbe1b5529 100644 --- a/internal/directory/google/google.go +++ b/internal/directory/google/google.go @@ -14,10 +14,14 @@ import ( admin "google.golang.org/api/admin/directory/v1" "google.golang.org/api/option" + "github.com/pomerium/pomerium/internal/grpc/databroker" "github.com/pomerium/pomerium/internal/grpc/directory" "github.com/pomerium/pomerium/internal/log" ) +// Name is the provider name. +const Name = "google" + const ( defaultProviderURL = "https://accounts.google.com" ) @@ -118,7 +122,7 @@ func (p *Provider) UserGroups(ctx context.Context) ([]*directory.User, error) { for userEmail, groups := range userEmailToGroups { sort.Strings(groups) users = append(users, &directory.User{ - Id: userEmail, + Id: databroker.GetUserID(Name, userEmail), Groups: groups, }) } diff --git a/internal/directory/okta/okta.go b/internal/directory/okta/okta.go index 09b480431..cd82afd72 100644 --- a/internal/directory/okta/okta.go +++ b/internal/directory/okta/okta.go @@ -13,10 +13,14 @@ import ( "github.com/rs/zerolog" "github.com/tomnomnom/linkheader" + "github.com/pomerium/pomerium/internal/grpc/databroker" "github.com/pomerium/pomerium/internal/grpc/directory" "github.com/pomerium/pomerium/internal/log" ) +// Name is the provider name. +const Name = "okta" + type config struct { batchSize int httpClient *http.Client @@ -112,7 +116,7 @@ func (p *Provider) UserGroups(ctx context.Context) ([]*directory.User, error) { for userEmail, groups := range userEmailToGroups { sort.Strings(groups) users = append(users, &directory.User{ - Id: userEmail, + Id: databroker.GetUserID(Name, userEmail), Groups: groups, }) } diff --git a/internal/directory/okta/okta_test.go b/internal/directory/okta/okta_test.go index 99dcc7e4a..8407dc854 100644 --- a/internal/directory/okta/okta_test.go +++ b/internal/directory/okta/okta_test.go @@ -119,15 +119,15 @@ func TestProvider_UserGroups(t *testing.T) { assert.NoError(t, err) assert.Equal(t, []*directory.User{ { - Id: "a@example.com", + Id: "okta/a@example.com", Groups: []string{"admin", "user"}, }, { - Id: "b@example.com", + Id: "okta/b@example.com", Groups: []string{"test", "user"}, }, { - Id: "c@example.com", + Id: "okta/c@example.com", Groups: []string{"user"}, }, }, users) diff --git a/internal/directory/onelogin/onelogin.go b/internal/directory/onelogin/onelogin.go index a60504f28..5091d9b11 100644 --- a/internal/directory/onelogin/onelogin.go +++ b/internal/directory/onelogin/onelogin.go @@ -15,10 +15,14 @@ import ( "github.com/rs/zerolog" "golang.org/x/oauth2" + "github.com/pomerium/pomerium/internal/grpc/databroker" "github.com/pomerium/pomerium/internal/grpc/directory" "github.com/pomerium/pomerium/internal/log" ) +// Name is the provider name. +const Name = "onelogin" + type config struct { apiURL *url.URL batchSize int @@ -127,7 +131,7 @@ func (p *Provider) UserGroups(ctx context.Context) ([]*directory.User, error) { for userEmail, groups := range userEmailToGroupNames { sort.Strings(groups) users = append(users, &directory.User{ - Id: userEmail, + Id: databroker.GetUserID(Name, userEmail), Groups: groups, }) } diff --git a/internal/directory/onelogin/onelogin_test.go b/internal/directory/onelogin/onelogin_test.go index 5a5e46e40..b26c1e4ff 100644 --- a/internal/directory/onelogin/onelogin_test.go +++ b/internal/directory/onelogin/onelogin_test.go @@ -151,15 +151,15 @@ func TestProvider_UserGroups(t *testing.T) { assert.NoError(t, err) assert.Equal(t, []*directory.User{ { - Id: "a@example.com", + Id: "onelogin/a@example.com", Groups: []string{"admin"}, }, { - Id: "b@example.com", + Id: "onelogin/b@example.com", Groups: []string{"test"}, }, { - Id: "c@example.com", + Id: "onelogin/c@example.com", Groups: []string{"user"}, }, }, users) diff --git a/internal/directory/provider.go b/internal/directory/provider.go index b3c9d2055..763ca53d5 100644 --- a/internal/directory/provider.go +++ b/internal/directory/provider.go @@ -26,7 +26,7 @@ type Provider interface { // GetProvider gets the provider for the given options. func GetProvider(options *config.Options) Provider { switch options.Provider { - case "azure": + case azure.Name: serviceAccount, err := azure.ParseServiceAccount(options.ServiceAccount) if err == nil { return azure.New(azure.WithServiceAccount(serviceAccount)) @@ -37,7 +37,7 @@ func GetProvider(options *config.Options) Provider { Str("provider", options.Provider). Err(err). Msg("invalid service account for azure directory provider") - case "gitlab": + case gitlab.Name: serviceAccount, err := gitlab.ParseServiceAccount(options.ServiceAccount) if err == nil { return gitlab.New(gitlab.WithServiceAccount(serviceAccount)) @@ -47,11 +47,11 @@ func GetProvider(options *config.Options) Provider { Str("provider", options.Provider). Err(err). Msg("invalid service account for gitlab directory provider") - case "google": + case google.Name: if options.ServiceAccount != "" { return google.New(google.WithServiceAccount(options.ServiceAccount)) } - case "okta": + case okta.Name: providerURL, _ := url.Parse(options.ProviderURL) serviceAccount, err := okta.ParseServiceAccount(options.ServiceAccount) if err == nil { @@ -64,7 +64,7 @@ func GetProvider(options *config.Options) Provider { Str("provider", options.Provider). Err(err). Msg("invalid service account for okta directory provider") - case "onelogin": + case onelogin.Name: serviceAccount, err := onelogin.ParseServiceAccount(options.ServiceAccount) if err == nil { return onelogin.New(onelogin.WithServiceAccount(serviceAccount)) diff --git a/internal/grpc/databroker/databroker.go b/internal/grpc/databroker/databroker.go new file mode 100644 index 000000000..538289e0a --- /dev/null +++ b/internal/grpc/databroker/databroker.go @@ -0,0 +1,7 @@ +// Package databroker contains databroker protobuf definitions. +package databroker + +// GetUserID gets the databroker user id from a provider user id. +func GetUserID(provider, providerUserID string) string { + return provider + "/" + providerUserID +}