directory: generate user/directory.User ID in a consistent way (#944)

This commit is contained in:
Caleb Doxsey 2020-06-22 07:42:57 -06:00 committed by GitHub
parent 84dde097c7
commit f7760c413e
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
13 changed files with 59 additions and 26 deletions

View file

@ -93,7 +93,8 @@ type Authenticate struct {
sessionLoaders []sessions.SessionLoader sessionLoaders []sessions.SessionLoader
// provider is the interface to interacting with the identity provider (IdP) // provider is the interface to interacting with the identity provider (IdP)
provider identity.Authenticator provider identity.Authenticator
providerName string
// dataBrokerClient is used to retrieve sessions // dataBrokerClient is used to retrieve sessions
dataBrokerClient databroker.DataBrokerServiceClient dataBrokerClient databroker.DataBrokerServiceClient
@ -193,7 +194,8 @@ func New(opts config.Options) (*Authenticate, error) {
encryptedEncoder: encryptedEncoder, encryptedEncoder: encryptedEncoder,
sessionLoaders: []sessions.SessionLoader{qpStore, headerStore, cookieStore}, sessionLoaders: []sessions.SessionLoader{qpStore, headerStore, cookieStore},
// IdP // IdP
provider: provider, provider: provider,
providerName: opts.Provider,
// grpc client for cache // grpc client for cache
dataBrokerClient: dataBrokerClient, dataBrokerClient: dataBrokerClient,
sessionClient: sessionClient, sessionClient: sessionClient,

View file

@ -19,6 +19,7 @@ import (
"golang.org/x/oauth2" "golang.org/x/oauth2"
"github.com/pomerium/pomerium/internal/cryptutil" "github.com/pomerium/pomerium/internal/cryptutil"
"github.com/pomerium/pomerium/internal/grpc/databroker"
"github.com/pomerium/pomerium/internal/grpc/directory" "github.com/pomerium/pomerium/internal/grpc/directory"
"github.com/pomerium/pomerium/internal/grpc/session" "github.com/pomerium/pomerium/internal/grpc/session"
"github.com/pomerium/pomerium/internal/grpc/user" "github.com/pomerium/pomerium/internal/grpc/user"
@ -496,7 +497,7 @@ func (a *Authenticate) saveSessionToDataBroker(ctx context.Context, sessionState
s := &session.Session{ s := &session.Session{
Id: sessionState.ID, Id: sessionState.ID,
UserId: sessionState.Issuer + "/" + sessionState.Subject, UserId: databroker.GetUserID(a.providerName, sessionState.Subject),
ExpiresAt: sessionExpiry, ExpiresAt: sessionExpiry,
IdToken: &session.IDToken{ IdToken: &session.IDToken{
Issuer: sessionState.Issuer, Issuer: sessionState.Issuer,

View file

@ -15,10 +15,14 @@ import (
"golang.org/x/oauth2" "golang.org/x/oauth2"
"github.com/pomerium/pomerium/internal/grpc/databroker"
"github.com/pomerium/pomerium/internal/grpc/directory" "github.com/pomerium/pomerium/internal/grpc/directory"
) )
var ( // Name is the provider name.
const Name = "azure"
const (
defaultGraphHost = "graph.microsoft.com" defaultGraphHost = "graph.microsoft.com"
defaultLoginHost = "login.microsoftonline.com" defaultLoginHost = "login.microsoftonline.com"
@ -122,7 +126,10 @@ func (p *Provider) UserGroups(ctx context.Context) ([]*directory.User, error) {
var users []*directory.User var users []*directory.User
for userID, groupIDs := range userIDToGroupIDs { for userID, groupIDs := range userIDToGroupIDs {
sort.Strings(groupIDs) sort.Strings(groupIDs)
users = append(users, &directory.User{Id: userID, Groups: groupIDs}) users = append(users, &directory.User{
Id: databroker.GetUserID(Name, userID),
Groups: groupIDs,
})
} }
sort.Slice(users, func(i, j int) bool { sort.Slice(users, func(i, j int) bool {
return users[i].GetId() < users[j].GetId() return users[i].GetId() < users[j].GetId()

View file

@ -89,15 +89,15 @@ func Test(t *testing.T) {
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, []*directory.User{ assert.Equal(t, []*directory.User{
{ {
Id: "user-1", Id: "azure/user-1",
Groups: []string{"admin"}, Groups: []string{"admin"},
}, },
{ {
Id: "user-2", Id: "azure/user-2",
Groups: []string{"test"}, Groups: []string{"test"},
}, },
{ {
Id: "user-3", Id: "azure/user-3",
Groups: []string{"test"}, Groups: []string{"test"},
}, },
}, users) }, users)

View file

@ -13,10 +13,14 @@ import (
"github.com/rs/zerolog" "github.com/rs/zerolog"
"github.com/tomnomnom/linkheader" "github.com/tomnomnom/linkheader"
"github.com/pomerium/pomerium/internal/grpc/databroker"
"github.com/pomerium/pomerium/internal/grpc/directory" "github.com/pomerium/pomerium/internal/grpc/directory"
"github.com/pomerium/pomerium/internal/log" "github.com/pomerium/pomerium/internal/log"
) )
// Name is the provider name.
const Name = "gitlab"
var ( var (
defaultURL = &url.URL{ defaultURL = &url.URL{
Scheme: "https", Scheme: "https",
@ -106,7 +110,7 @@ func (p *Provider) UserGroups(ctx context.Context) ([]*directory.User, error) {
var users []*directory.User var users []*directory.User
for userID, groupIDs := range userIDToGroupIDs { for userID, groupIDs := range userIDToGroupIDs {
user := &directory.User{ user := &directory.User{
Id: fmt.Sprint(userID), Id: databroker.GetUserID(Name, fmt.Sprint(userID)),
} }
for _, groupID := range groupIDs { for _, groupID := range groupIDs {
user.Groups = append(user.Groups, fmt.Sprint(groupID)) user.Groups = append(user.Groups, fmt.Sprint(groupID))

View file

@ -69,9 +69,9 @@ func Test(t *testing.T) {
users, err := p.UserGroups(context.Background()) users, err := p.UserGroups(context.Background())
assert.NoError(t, err) assert.NoError(t, err)
testutil.AssertProtoJSONEqual(t, `[ testutil.AssertProtoJSONEqual(t, `[
{ "id": "11", "groups": ["1"] }, { "id": "gitlab/11", "groups": ["1"] },
{ "id": "12", "groups": ["2"] }, { "id": "gitlab/12", "groups": ["2"] },
{ "id": "13", "groups": ["2"] } { "id": "gitlab/13", "groups": ["2"] }
]`, users) ]`, users)
} }

View file

@ -14,10 +14,14 @@ import (
admin "google.golang.org/api/admin/directory/v1" admin "google.golang.org/api/admin/directory/v1"
"google.golang.org/api/option" "google.golang.org/api/option"
"github.com/pomerium/pomerium/internal/grpc/databroker"
"github.com/pomerium/pomerium/internal/grpc/directory" "github.com/pomerium/pomerium/internal/grpc/directory"
"github.com/pomerium/pomerium/internal/log" "github.com/pomerium/pomerium/internal/log"
) )
// Name is the provider name.
const Name = "google"
const ( const (
defaultProviderURL = "https://accounts.google.com" defaultProviderURL = "https://accounts.google.com"
) )
@ -118,7 +122,7 @@ func (p *Provider) UserGroups(ctx context.Context) ([]*directory.User, error) {
for userEmail, groups := range userEmailToGroups { for userEmail, groups := range userEmailToGroups {
sort.Strings(groups) sort.Strings(groups)
users = append(users, &directory.User{ users = append(users, &directory.User{
Id: userEmail, Id: databroker.GetUserID(Name, userEmail),
Groups: groups, Groups: groups,
}) })
} }

View file

@ -13,10 +13,14 @@ import (
"github.com/rs/zerolog" "github.com/rs/zerolog"
"github.com/tomnomnom/linkheader" "github.com/tomnomnom/linkheader"
"github.com/pomerium/pomerium/internal/grpc/databroker"
"github.com/pomerium/pomerium/internal/grpc/directory" "github.com/pomerium/pomerium/internal/grpc/directory"
"github.com/pomerium/pomerium/internal/log" "github.com/pomerium/pomerium/internal/log"
) )
// Name is the provider name.
const Name = "okta"
type config struct { type config struct {
batchSize int batchSize int
httpClient *http.Client httpClient *http.Client
@ -112,7 +116,7 @@ func (p *Provider) UserGroups(ctx context.Context) ([]*directory.User, error) {
for userEmail, groups := range userEmailToGroups { for userEmail, groups := range userEmailToGroups {
sort.Strings(groups) sort.Strings(groups)
users = append(users, &directory.User{ users = append(users, &directory.User{
Id: userEmail, Id: databroker.GetUserID(Name, userEmail),
Groups: groups, Groups: groups,
}) })
} }

View file

@ -119,15 +119,15 @@ func TestProvider_UserGroups(t *testing.T) {
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, []*directory.User{ assert.Equal(t, []*directory.User{
{ {
Id: "a@example.com", Id: "okta/a@example.com",
Groups: []string{"admin", "user"}, Groups: []string{"admin", "user"},
}, },
{ {
Id: "b@example.com", Id: "okta/b@example.com",
Groups: []string{"test", "user"}, Groups: []string{"test", "user"},
}, },
{ {
Id: "c@example.com", Id: "okta/c@example.com",
Groups: []string{"user"}, Groups: []string{"user"},
}, },
}, users) }, users)

View file

@ -15,10 +15,14 @@ import (
"github.com/rs/zerolog" "github.com/rs/zerolog"
"golang.org/x/oauth2" "golang.org/x/oauth2"
"github.com/pomerium/pomerium/internal/grpc/databroker"
"github.com/pomerium/pomerium/internal/grpc/directory" "github.com/pomerium/pomerium/internal/grpc/directory"
"github.com/pomerium/pomerium/internal/log" "github.com/pomerium/pomerium/internal/log"
) )
// Name is the provider name.
const Name = "onelogin"
type config struct { type config struct {
apiURL *url.URL apiURL *url.URL
batchSize int batchSize int
@ -127,7 +131,7 @@ func (p *Provider) UserGroups(ctx context.Context) ([]*directory.User, error) {
for userEmail, groups := range userEmailToGroupNames { for userEmail, groups := range userEmailToGroupNames {
sort.Strings(groups) sort.Strings(groups)
users = append(users, &directory.User{ users = append(users, &directory.User{
Id: userEmail, Id: databroker.GetUserID(Name, userEmail),
Groups: groups, Groups: groups,
}) })
} }

View file

@ -151,15 +151,15 @@ func TestProvider_UserGroups(t *testing.T) {
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, []*directory.User{ assert.Equal(t, []*directory.User{
{ {
Id: "a@example.com", Id: "onelogin/a@example.com",
Groups: []string{"admin"}, Groups: []string{"admin"},
}, },
{ {
Id: "b@example.com", Id: "onelogin/b@example.com",
Groups: []string{"test"}, Groups: []string{"test"},
}, },
{ {
Id: "c@example.com", Id: "onelogin/c@example.com",
Groups: []string{"user"}, Groups: []string{"user"},
}, },
}, users) }, users)

View file

@ -26,7 +26,7 @@ type Provider interface {
// GetProvider gets the provider for the given options. // GetProvider gets the provider for the given options.
func GetProvider(options *config.Options) Provider { func GetProvider(options *config.Options) Provider {
switch options.Provider { switch options.Provider {
case "azure": case azure.Name:
serviceAccount, err := azure.ParseServiceAccount(options.ServiceAccount) serviceAccount, err := azure.ParseServiceAccount(options.ServiceAccount)
if err == nil { if err == nil {
return azure.New(azure.WithServiceAccount(serviceAccount)) return azure.New(azure.WithServiceAccount(serviceAccount))
@ -37,7 +37,7 @@ func GetProvider(options *config.Options) Provider {
Str("provider", options.Provider). Str("provider", options.Provider).
Err(err). Err(err).
Msg("invalid service account for azure directory provider") Msg("invalid service account for azure directory provider")
case "gitlab": case gitlab.Name:
serviceAccount, err := gitlab.ParseServiceAccount(options.ServiceAccount) serviceAccount, err := gitlab.ParseServiceAccount(options.ServiceAccount)
if err == nil { if err == nil {
return gitlab.New(gitlab.WithServiceAccount(serviceAccount)) return gitlab.New(gitlab.WithServiceAccount(serviceAccount))
@ -47,11 +47,11 @@ func GetProvider(options *config.Options) Provider {
Str("provider", options.Provider). Str("provider", options.Provider).
Err(err). Err(err).
Msg("invalid service account for gitlab directory provider") Msg("invalid service account for gitlab directory provider")
case "google": case google.Name:
if options.ServiceAccount != "" { if options.ServiceAccount != "" {
return google.New(google.WithServiceAccount(options.ServiceAccount)) return google.New(google.WithServiceAccount(options.ServiceAccount))
} }
case "okta": case okta.Name:
providerURL, _ := url.Parse(options.ProviderURL) providerURL, _ := url.Parse(options.ProviderURL)
serviceAccount, err := okta.ParseServiceAccount(options.ServiceAccount) serviceAccount, err := okta.ParseServiceAccount(options.ServiceAccount)
if err == nil { if err == nil {
@ -64,7 +64,7 @@ func GetProvider(options *config.Options) Provider {
Str("provider", options.Provider). Str("provider", options.Provider).
Err(err). Err(err).
Msg("invalid service account for okta directory provider") Msg("invalid service account for okta directory provider")
case "onelogin": case onelogin.Name:
serviceAccount, err := onelogin.ParseServiceAccount(options.ServiceAccount) serviceAccount, err := onelogin.ParseServiceAccount(options.ServiceAccount)
if err == nil { if err == nil {
return onelogin.New(onelogin.WithServiceAccount(serviceAccount)) return onelogin.New(onelogin.WithServiceAccount(serviceAccount))

View file

@ -0,0 +1,7 @@
// Package databroker contains databroker protobuf definitions.
package databroker
// GetUserID gets the databroker user id from a provider user id.
func GetUserID(provider, providerUserID string) string {
return provider + "/" + providerUserID
}