mirror of
https://github.com/pomerium/pomerium.git
synced 2025-04-29 02:16:28 +02:00
directory: generate user/directory.User ID in a consistent way (#944)
This commit is contained in:
parent
84dde097c7
commit
f7760c413e
13 changed files with 59 additions and 26 deletions
|
@ -93,7 +93,8 @@ type Authenticate struct {
|
|||
sessionLoaders []sessions.SessionLoader
|
||||
|
||||
// provider is the interface to interacting with the identity provider (IdP)
|
||||
provider identity.Authenticator
|
||||
provider identity.Authenticator
|
||||
providerName string
|
||||
|
||||
// dataBrokerClient is used to retrieve sessions
|
||||
dataBrokerClient databroker.DataBrokerServiceClient
|
||||
|
@ -193,7 +194,8 @@ func New(opts config.Options) (*Authenticate, error) {
|
|||
encryptedEncoder: encryptedEncoder,
|
||||
sessionLoaders: []sessions.SessionLoader{qpStore, headerStore, cookieStore},
|
||||
// IdP
|
||||
provider: provider,
|
||||
provider: provider,
|
||||
providerName: opts.Provider,
|
||||
// grpc client for cache
|
||||
dataBrokerClient: dataBrokerClient,
|
||||
sessionClient: sessionClient,
|
||||
|
|
|
@ -19,6 +19,7 @@ import (
|
|||
"golang.org/x/oauth2"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/cryptutil"
|
||||
"github.com/pomerium/pomerium/internal/grpc/databroker"
|
||||
"github.com/pomerium/pomerium/internal/grpc/directory"
|
||||
"github.com/pomerium/pomerium/internal/grpc/session"
|
||||
"github.com/pomerium/pomerium/internal/grpc/user"
|
||||
|
@ -496,7 +497,7 @@ func (a *Authenticate) saveSessionToDataBroker(ctx context.Context, sessionState
|
|||
|
||||
s := &session.Session{
|
||||
Id: sessionState.ID,
|
||||
UserId: sessionState.Issuer + "/" + sessionState.Subject,
|
||||
UserId: databroker.GetUserID(a.providerName, sessionState.Subject),
|
||||
ExpiresAt: sessionExpiry,
|
||||
IdToken: &session.IDToken{
|
||||
Issuer: sessionState.Issuer,
|
||||
|
|
|
@ -15,10 +15,14 @@ import (
|
|||
|
||||
"golang.org/x/oauth2"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/grpc/databroker"
|
||||
"github.com/pomerium/pomerium/internal/grpc/directory"
|
||||
)
|
||||
|
||||
var (
|
||||
// Name is the provider name.
|
||||
const Name = "azure"
|
||||
|
||||
const (
|
||||
defaultGraphHost = "graph.microsoft.com"
|
||||
|
||||
defaultLoginHost = "login.microsoftonline.com"
|
||||
|
@ -122,7 +126,10 @@ func (p *Provider) UserGroups(ctx context.Context) ([]*directory.User, error) {
|
|||
var users []*directory.User
|
||||
for userID, groupIDs := range userIDToGroupIDs {
|
||||
sort.Strings(groupIDs)
|
||||
users = append(users, &directory.User{Id: userID, Groups: groupIDs})
|
||||
users = append(users, &directory.User{
|
||||
Id: databroker.GetUserID(Name, userID),
|
||||
Groups: groupIDs,
|
||||
})
|
||||
}
|
||||
sort.Slice(users, func(i, j int) bool {
|
||||
return users[i].GetId() < users[j].GetId()
|
||||
|
|
|
@ -89,15 +89,15 @@ func Test(t *testing.T) {
|
|||
assert.NoError(t, err)
|
||||
assert.Equal(t, []*directory.User{
|
||||
{
|
||||
Id: "user-1",
|
||||
Id: "azure/user-1",
|
||||
Groups: []string{"admin"},
|
||||
},
|
||||
{
|
||||
Id: "user-2",
|
||||
Id: "azure/user-2",
|
||||
Groups: []string{"test"},
|
||||
},
|
||||
{
|
||||
Id: "user-3",
|
||||
Id: "azure/user-3",
|
||||
Groups: []string{"test"},
|
||||
},
|
||||
}, users)
|
||||
|
|
|
@ -13,10 +13,14 @@ import (
|
|||
"github.com/rs/zerolog"
|
||||
"github.com/tomnomnom/linkheader"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/grpc/databroker"
|
||||
"github.com/pomerium/pomerium/internal/grpc/directory"
|
||||
"github.com/pomerium/pomerium/internal/log"
|
||||
)
|
||||
|
||||
// Name is the provider name.
|
||||
const Name = "gitlab"
|
||||
|
||||
var (
|
||||
defaultURL = &url.URL{
|
||||
Scheme: "https",
|
||||
|
@ -106,7 +110,7 @@ func (p *Provider) UserGroups(ctx context.Context) ([]*directory.User, error) {
|
|||
var users []*directory.User
|
||||
for userID, groupIDs := range userIDToGroupIDs {
|
||||
user := &directory.User{
|
||||
Id: fmt.Sprint(userID),
|
||||
Id: databroker.GetUserID(Name, fmt.Sprint(userID)),
|
||||
}
|
||||
for _, groupID := range groupIDs {
|
||||
user.Groups = append(user.Groups, fmt.Sprint(groupID))
|
||||
|
|
|
@ -69,9 +69,9 @@ func Test(t *testing.T) {
|
|||
users, err := p.UserGroups(context.Background())
|
||||
assert.NoError(t, err)
|
||||
testutil.AssertProtoJSONEqual(t, `[
|
||||
{ "id": "11", "groups": ["1"] },
|
||||
{ "id": "12", "groups": ["2"] },
|
||||
{ "id": "13", "groups": ["2"] }
|
||||
{ "id": "gitlab/11", "groups": ["1"] },
|
||||
{ "id": "gitlab/12", "groups": ["2"] },
|
||||
{ "id": "gitlab/13", "groups": ["2"] }
|
||||
]`, users)
|
||||
}
|
||||
|
||||
|
|
|
@ -14,10 +14,14 @@ import (
|
|||
admin "google.golang.org/api/admin/directory/v1"
|
||||
"google.golang.org/api/option"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/grpc/databroker"
|
||||
"github.com/pomerium/pomerium/internal/grpc/directory"
|
||||
"github.com/pomerium/pomerium/internal/log"
|
||||
)
|
||||
|
||||
// Name is the provider name.
|
||||
const Name = "google"
|
||||
|
||||
const (
|
||||
defaultProviderURL = "https://accounts.google.com"
|
||||
)
|
||||
|
@ -118,7 +122,7 @@ func (p *Provider) UserGroups(ctx context.Context) ([]*directory.User, error) {
|
|||
for userEmail, groups := range userEmailToGroups {
|
||||
sort.Strings(groups)
|
||||
users = append(users, &directory.User{
|
||||
Id: userEmail,
|
||||
Id: databroker.GetUserID(Name, userEmail),
|
||||
Groups: groups,
|
||||
})
|
||||
}
|
||||
|
|
|
@ -13,10 +13,14 @@ import (
|
|||
"github.com/rs/zerolog"
|
||||
"github.com/tomnomnom/linkheader"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/grpc/databroker"
|
||||
"github.com/pomerium/pomerium/internal/grpc/directory"
|
||||
"github.com/pomerium/pomerium/internal/log"
|
||||
)
|
||||
|
||||
// Name is the provider name.
|
||||
const Name = "okta"
|
||||
|
||||
type config struct {
|
||||
batchSize int
|
||||
httpClient *http.Client
|
||||
|
@ -112,7 +116,7 @@ func (p *Provider) UserGroups(ctx context.Context) ([]*directory.User, error) {
|
|||
for userEmail, groups := range userEmailToGroups {
|
||||
sort.Strings(groups)
|
||||
users = append(users, &directory.User{
|
||||
Id: userEmail,
|
||||
Id: databroker.GetUserID(Name, userEmail),
|
||||
Groups: groups,
|
||||
})
|
||||
}
|
||||
|
|
|
@ -119,15 +119,15 @@ func TestProvider_UserGroups(t *testing.T) {
|
|||
assert.NoError(t, err)
|
||||
assert.Equal(t, []*directory.User{
|
||||
{
|
||||
Id: "a@example.com",
|
||||
Id: "okta/a@example.com",
|
||||
Groups: []string{"admin", "user"},
|
||||
},
|
||||
{
|
||||
Id: "b@example.com",
|
||||
Id: "okta/b@example.com",
|
||||
Groups: []string{"test", "user"},
|
||||
},
|
||||
{
|
||||
Id: "c@example.com",
|
||||
Id: "okta/c@example.com",
|
||||
Groups: []string{"user"},
|
||||
},
|
||||
}, users)
|
||||
|
|
|
@ -15,10 +15,14 @@ import (
|
|||
"github.com/rs/zerolog"
|
||||
"golang.org/x/oauth2"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/grpc/databroker"
|
||||
"github.com/pomerium/pomerium/internal/grpc/directory"
|
||||
"github.com/pomerium/pomerium/internal/log"
|
||||
)
|
||||
|
||||
// Name is the provider name.
|
||||
const Name = "onelogin"
|
||||
|
||||
type config struct {
|
||||
apiURL *url.URL
|
||||
batchSize int
|
||||
|
@ -127,7 +131,7 @@ func (p *Provider) UserGroups(ctx context.Context) ([]*directory.User, error) {
|
|||
for userEmail, groups := range userEmailToGroupNames {
|
||||
sort.Strings(groups)
|
||||
users = append(users, &directory.User{
|
||||
Id: userEmail,
|
||||
Id: databroker.GetUserID(Name, userEmail),
|
||||
Groups: groups,
|
||||
})
|
||||
}
|
||||
|
|
|
@ -151,15 +151,15 @@ func TestProvider_UserGroups(t *testing.T) {
|
|||
assert.NoError(t, err)
|
||||
assert.Equal(t, []*directory.User{
|
||||
{
|
||||
Id: "a@example.com",
|
||||
Id: "onelogin/a@example.com",
|
||||
Groups: []string{"admin"},
|
||||
},
|
||||
{
|
||||
Id: "b@example.com",
|
||||
Id: "onelogin/b@example.com",
|
||||
Groups: []string{"test"},
|
||||
},
|
||||
{
|
||||
Id: "c@example.com",
|
||||
Id: "onelogin/c@example.com",
|
||||
Groups: []string{"user"},
|
||||
},
|
||||
}, users)
|
||||
|
|
|
@ -26,7 +26,7 @@ type Provider interface {
|
|||
// GetProvider gets the provider for the given options.
|
||||
func GetProvider(options *config.Options) Provider {
|
||||
switch options.Provider {
|
||||
case "azure":
|
||||
case azure.Name:
|
||||
serviceAccount, err := azure.ParseServiceAccount(options.ServiceAccount)
|
||||
if err == nil {
|
||||
return azure.New(azure.WithServiceAccount(serviceAccount))
|
||||
|
@ -37,7 +37,7 @@ func GetProvider(options *config.Options) Provider {
|
|||
Str("provider", options.Provider).
|
||||
Err(err).
|
||||
Msg("invalid service account for azure directory provider")
|
||||
case "gitlab":
|
||||
case gitlab.Name:
|
||||
serviceAccount, err := gitlab.ParseServiceAccount(options.ServiceAccount)
|
||||
if err == nil {
|
||||
return gitlab.New(gitlab.WithServiceAccount(serviceAccount))
|
||||
|
@ -47,11 +47,11 @@ func GetProvider(options *config.Options) Provider {
|
|||
Str("provider", options.Provider).
|
||||
Err(err).
|
||||
Msg("invalid service account for gitlab directory provider")
|
||||
case "google":
|
||||
case google.Name:
|
||||
if options.ServiceAccount != "" {
|
||||
return google.New(google.WithServiceAccount(options.ServiceAccount))
|
||||
}
|
||||
case "okta":
|
||||
case okta.Name:
|
||||
providerURL, _ := url.Parse(options.ProviderURL)
|
||||
serviceAccount, err := okta.ParseServiceAccount(options.ServiceAccount)
|
||||
if err == nil {
|
||||
|
@ -64,7 +64,7 @@ func GetProvider(options *config.Options) Provider {
|
|||
Str("provider", options.Provider).
|
||||
Err(err).
|
||||
Msg("invalid service account for okta directory provider")
|
||||
case "onelogin":
|
||||
case onelogin.Name:
|
||||
serviceAccount, err := onelogin.ParseServiceAccount(options.ServiceAccount)
|
||||
if err == nil {
|
||||
return onelogin.New(onelogin.WithServiceAccount(serviceAccount))
|
||||
|
|
7
internal/grpc/databroker/databroker.go
Normal file
7
internal/grpc/databroker/databroker.go
Normal file
|
@ -0,0 +1,7 @@
|
|||
// Package databroker contains databroker protobuf definitions.
|
||||
package databroker
|
||||
|
||||
// GetUserID gets the databroker user id from a provider user id.
|
||||
func GetUserID(provider, providerUserID string) string {
|
||||
return provider + "/" + providerUserID
|
||||
}
|
Loading…
Add table
Reference in a new issue