mirror of
https://github.com/pomerium/pomerium.git
synced 2025-05-10 07:37:33 +02:00
move directory providers (#3633)
* remove directory providers and support for groups * idp: remove directory providers * better error messages * fix errors * restore postgres * fix test
This commit is contained in:
parent
bb5c80bae9
commit
c178819875
78 changed files with 723 additions and 8703 deletions
|
@ -30,7 +30,6 @@ import (
|
|||
"github.com/pomerium/pomerium/internal/urlutil"
|
||||
"github.com/pomerium/pomerium/pkg/cryptutil"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/directory"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/session"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/user"
|
||||
)
|
||||
|
@ -544,34 +543,13 @@ func (a *Authenticate) getUserInfoData(r *http.Request) (handlers.UserInfoData,
|
|||
Id: pbSession.GetUserId(),
|
||||
}
|
||||
}
|
||||
pbDirectoryUser, err := a.getDirectoryUser(r.Context(), pbSession.GetUserId())
|
||||
if err != nil {
|
||||
pbDirectoryUser = &directory.User{
|
||||
Id: pbSession.GetUserId(),
|
||||
}
|
||||
}
|
||||
var groups []*directory.Group
|
||||
for _, groupID := range pbDirectoryUser.GetGroupIds() {
|
||||
pbDirectoryGroup, err := directory.GetGroup(r.Context(), state.dataBrokerClient, groupID)
|
||||
if err != nil {
|
||||
pbDirectoryGroup = &directory.Group{
|
||||
Id: groupID,
|
||||
Name: groupID,
|
||||
Email: groupID,
|
||||
}
|
||||
}
|
||||
groups = append(groups, pbDirectoryGroup)
|
||||
}
|
||||
|
||||
creationOptions, requestOptions, _ := a.webauthn.GetOptions(r.Context())
|
||||
|
||||
return handlers.UserInfoData{
|
||||
CSRFToken: csrf.Token(r),
|
||||
DirectoryGroups: groups,
|
||||
DirectoryUser: pbDirectoryUser,
|
||||
IsImpersonated: isImpersonated,
|
||||
Session: pbSession,
|
||||
User: pbUser,
|
||||
CSRFToken: csrf.Token(r),
|
||||
IsImpersonated: isImpersonated,
|
||||
Session: pbSession,
|
||||
User: pbUser,
|
||||
|
||||
WebAuthnCreationOptions: creationOptions,
|
||||
WebAuthnRequestOptions: requestOptions,
|
||||
|
@ -645,14 +623,6 @@ func (a *Authenticate) saveSessionToDataBroker(
|
|||
sessionState.DatabrokerServerVersion = res.GetServerVersion()
|
||||
sessionState.DatabrokerRecordVersion = res.GetRecord().GetVersion()
|
||||
|
||||
_, err = state.directoryClient.RefreshUser(ctx, &directory.RefreshUserRequest{
|
||||
UserId: s.UserId,
|
||||
AccessToken: accessToken.AccessToken,
|
||||
})
|
||||
if err != nil {
|
||||
log.Error(ctx).Err(err).Msg("directory: failed to refresh user data")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -718,11 +688,6 @@ func (a *Authenticate) getUser(ctx context.Context, userID string) (*user.User,
|
|||
return user.Get(ctx, client, userID)
|
||||
}
|
||||
|
||||
func (a *Authenticate) getDirectoryUser(ctx context.Context, userID string) (*directory.User, error) {
|
||||
client := a.state.Load().dataBrokerClient
|
||||
return directory.GetUser(ctx, client, userID)
|
||||
}
|
||||
|
||||
func (a *Authenticate) getWebauthnState(ctx context.Context) (*webauthn.State, error) {
|
||||
state := a.state.Load()
|
||||
|
||||
|
|
|
@ -6,7 +6,6 @@ import (
|
|||
|
||||
"google.golang.org/protobuf/encoding/protojson"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/directory"
|
||||
"github.com/pomerium/pomerium/internal/httputil"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/session"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/user"
|
||||
|
@ -16,12 +15,10 @@ import (
|
|||
|
||||
// UserInfoData is the data for the UserInfo page.
|
||||
type UserInfoData struct {
|
||||
CSRFToken string
|
||||
DirectoryGroups []*directory.Group
|
||||
DirectoryUser *directory.User
|
||||
IsImpersonated bool
|
||||
Session *session.Session
|
||||
User *user.User
|
||||
CSRFToken string
|
||||
IsImpersonated bool
|
||||
Session *session.Session
|
||||
User *user.User
|
||||
|
||||
WebAuthnCreationOptions *webauthn.PublicKeyCredentialCreationOptions
|
||||
WebAuthnRequestOptions *webauthn.PublicKeyCredentialRequestOptions
|
||||
|
@ -34,16 +31,6 @@ type UserInfoData struct {
|
|||
func (data UserInfoData) ToJSON() map[string]any {
|
||||
m := map[string]any{}
|
||||
m["csrfToken"] = data.CSRFToken
|
||||
var directoryGroups []json.RawMessage
|
||||
for _, directoryGroup := range data.DirectoryGroups {
|
||||
if bs, err := protojson.Marshal(directoryGroup); err == nil {
|
||||
directoryGroups = append(directoryGroups, json.RawMessage(bs))
|
||||
}
|
||||
}
|
||||
m["directoryGroups"] = directoryGroups
|
||||
if bs, err := protojson.Marshal(data.DirectoryUser); err == nil {
|
||||
m["directoryUser"] = json.RawMessage(bs)
|
||||
}
|
||||
m["isImpersonated"] = data.IsImpersonated
|
||||
if bs, err := protojson.Marshal(data.Session); err == nil {
|
||||
m["session"] = json.RawMessage(bs)
|
||||
|
|
|
@ -14,14 +14,11 @@ import (
|
|||
|
||||
"github.com/go-jose/go-jose/v3/jwt"
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/golang/protobuf/ptypes/empty"
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"golang.org/x/crypto/chacha20poly1305"
|
||||
"golang.org/x/oauth2"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
|
||||
"github.com/pomerium/pomerium/authenticate/handlers/webauthn"
|
||||
|
@ -38,7 +35,6 @@ import (
|
|||
"github.com/pomerium/pomerium/internal/urlutil"
|
||||
"github.com/pomerium/pomerium/pkg/cryptutil"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/directory"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/session"
|
||||
)
|
||||
|
||||
|
@ -165,7 +161,6 @@ func TestAuthenticate_SignIn(t *testing.T) {
|
|||
}, nil
|
||||
},
|
||||
},
|
||||
directoryClient: new(mockDirectoryServiceClient),
|
||||
}),
|
||||
|
||||
options: config.NewAtomicOptions(),
|
||||
|
@ -321,7 +316,6 @@ func TestAuthenticate_SignOut(t *testing.T) {
|
|||
return nil, nil
|
||||
},
|
||||
},
|
||||
directoryClient: new(mockDirectoryServiceClient),
|
||||
}),
|
||||
options: config.NewAtomicOptions(),
|
||||
}
|
||||
|
@ -423,10 +417,9 @@ func TestAuthenticate_OAuthCallback(t *testing.T) {
|
|||
return nil, nil
|
||||
},
|
||||
},
|
||||
directoryClient: new(mockDirectoryServiceClient),
|
||||
redirectURL: authURL,
|
||||
sessionStore: tt.session,
|
||||
cookieCipher: aead,
|
||||
redirectURL: authURL,
|
||||
sessionStore: tt.session,
|
||||
cookieCipher: aead,
|
||||
}),
|
||||
options: config.NewAtomicOptions(),
|
||||
}
|
||||
|
@ -565,7 +558,6 @@ func TestAuthenticate_SessionValidatorMiddleware(t *testing.T) {
|
|||
}, nil
|
||||
},
|
||||
},
|
||||
directoryClient: new(mockDirectoryServiceClient),
|
||||
}),
|
||||
options: config.NewAtomicOptions(),
|
||||
}
|
||||
|
@ -681,7 +673,6 @@ func TestAuthenticate_userInfo(t *testing.T) {
|
|||
}, nil
|
||||
},
|
||||
},
|
||||
directoryClient: new(mockDirectoryServiceClient),
|
||||
}),
|
||||
}
|
||||
a.webauthn = webauthn.New(a.getWebauthnState)
|
||||
|
@ -723,19 +714,6 @@ func (m mockDataBrokerServiceClient) Put(ctx context.Context, in *databroker.Put
|
|||
return m.put(ctx, in, opts...)
|
||||
}
|
||||
|
||||
type mockDirectoryServiceClient struct {
|
||||
directory.DirectoryServiceClient
|
||||
|
||||
refreshUser func(ctx context.Context, in *directory.RefreshUserRequest, opts ...grpc.CallOption) (*empty.Empty, error)
|
||||
}
|
||||
|
||||
func (m mockDirectoryServiceClient) RefreshUser(ctx context.Context, in *directory.RefreshUserRequest, opts ...grpc.CallOption) (*empty.Empty, error) {
|
||||
if m.refreshUser != nil {
|
||||
return m.refreshUser(ctx, in, opts...)
|
||||
}
|
||||
return nil, status.Error(codes.Unimplemented, "")
|
||||
}
|
||||
|
||||
func mustParseURL(rawurl string) *url.URL {
|
||||
u, err := url.Parse(rawurl)
|
||||
if err != nil {
|
||||
|
|
|
@ -30,7 +30,6 @@ func defaultGetIdentityProvider(options *config.Options, idpID string) (identity
|
|||
ClientID: idp.GetClientId(),
|
||||
ClientSecret: idp.GetClientSecret(),
|
||||
Scopes: idp.GetScopes(),
|
||||
ServiceAccount: idp.GetServiceAccount(),
|
||||
AuthCodeOptions: idp.GetRequestParams(),
|
||||
})
|
||||
}
|
||||
|
|
|
@ -18,7 +18,6 @@ import (
|
|||
"github.com/pomerium/pomerium/pkg/cryptutil"
|
||||
"github.com/pomerium/pomerium/pkg/grpc"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/directory"
|
||||
"github.com/pomerium/pomerium/pkg/webauthnutil"
|
||||
"github.com/pomerium/webauthn"
|
||||
)
|
||||
|
@ -47,7 +46,6 @@ type authenticateState struct {
|
|||
jwk *jose.JSONWebKeySet
|
||||
|
||||
dataBrokerClient databroker.DataBrokerServiceClient
|
||||
directoryClient directory.DirectoryServiceClient
|
||||
|
||||
webauthnRelyingParty *webauthn.RelyingParty
|
||||
}
|
||||
|
@ -154,7 +152,6 @@ func newAuthenticateStateFromConfig(cfg *config.Config) (*authenticateState, err
|
|||
}
|
||||
|
||||
state.dataBrokerClient = databroker.NewDataBrokerServiceClient(dataBrokerConn)
|
||||
state.directoryClient = directory.NewDirectoryServiceClient(dataBrokerConn)
|
||||
|
||||
state.webauthnRelyingParty = webauthn.NewRelyingParty(
|
||||
authenticateURL.String(),
|
||||
|
|
|
@ -14,7 +14,6 @@ import (
|
|||
"github.com/pomerium/pomerium/config"
|
||||
"github.com/pomerium/pomerium/internal/httputil"
|
||||
"github.com/pomerium/pomerium/pkg/cryptutil"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/directory"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/session"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/user"
|
||||
"github.com/pomerium/pomerium/pkg/policy/criteria"
|
||||
|
@ -77,10 +76,6 @@ func TestEvaluator(t *testing.T) {
|
|||
To: config.WeightedURLs{{URL: *mustParseURL("https://to7.example.com")}},
|
||||
AllowedDomains: []string{"example.com"},
|
||||
},
|
||||
{
|
||||
To: config.WeightedURLs{{URL: *mustParseURL("https://to8.example.com")}},
|
||||
AllowedGroups: []string{"group1@example.com"},
|
||||
},
|
||||
{
|
||||
To: config.WeightedURLs{{URL: *mustParseURL("https://to9.example.com")}},
|
||||
AllowAnyAuthenticatedUser: true,
|
||||
|
@ -375,39 +370,6 @@ func TestEvaluator(t *testing.T) {
|
|||
require.NoError(t, err)
|
||||
assert.True(t, res.Allow.Value)
|
||||
})
|
||||
t.Run("groups", func(t *testing.T) {
|
||||
res, err := eval(t, options, []proto.Message{
|
||||
&session.Session{
|
||||
Id: "session1",
|
||||
UserId: "user1",
|
||||
},
|
||||
&user.User{
|
||||
Id: "user1",
|
||||
Email: "a@example.com",
|
||||
},
|
||||
&directory.User{
|
||||
Id: "user1",
|
||||
GroupIds: []string{"group1"},
|
||||
},
|
||||
&directory.Group{
|
||||
Id: "group1",
|
||||
Name: "group1name",
|
||||
Email: "group1@example.com",
|
||||
},
|
||||
}, &Request{
|
||||
Policy: &policies[7],
|
||||
Session: RequestSession{
|
||||
ID: "session1",
|
||||
},
|
||||
HTTP: RequestHTTP{
|
||||
Method: "GET",
|
||||
URL: "https://from.example.com",
|
||||
ClientCertificate: testValidCert,
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.True(t, res.Allow.Value)
|
||||
})
|
||||
t.Run("any authenticated user", func(t *testing.T) {
|
||||
res, err := eval(t, options, []proto.Message{
|
||||
&session.Session{
|
||||
|
@ -473,7 +435,7 @@ func TestEvaluator(t *testing.T) {
|
|||
})
|
||||
t.Run("http method", func(t *testing.T) {
|
||||
res, err := eval(t, options, []proto.Message{}, &Request{
|
||||
Policy: &policies[9],
|
||||
Policy: &policies[8],
|
||||
HTTP: NewRequestHTTP(
|
||||
"GET",
|
||||
*mustParseURL("https://from.example.com/"),
|
||||
|
@ -487,7 +449,7 @@ func TestEvaluator(t *testing.T) {
|
|||
})
|
||||
t.Run("http path", func(t *testing.T) {
|
||||
res, err := eval(t, options, []proto.Message{}, &Request{
|
||||
Policy: &policies[10],
|
||||
Policy: &policies[9],
|
||||
HTTP: NewRequestHTTP(
|
||||
"POST",
|
||||
*mustParseURL("https://from.example.com/test"),
|
||||
|
|
|
@ -15,9 +15,7 @@ import (
|
|||
"github.com/pomerium/pomerium/authorize/internal/store"
|
||||
"github.com/pomerium/pomerium/config"
|
||||
"github.com/pomerium/pomerium/pkg/cryptutil"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/directory"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/session"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/user"
|
||||
"github.com/pomerium/pomerium/pkg/storage"
|
||||
)
|
||||
|
||||
|
@ -63,25 +61,6 @@ func TestHeadersEvaluator(t *testing.T) {
|
|||
return e.Evaluate(ctx, input)
|
||||
}
|
||||
|
||||
t.Run("groups", func(t *testing.T) {
|
||||
output, err := eval(t,
|
||||
[]proto.Message{
|
||||
&session.Session{Id: "s1", UserId: "u1"},
|
||||
&user.User{Id: "u1"},
|
||||
&directory.User{Id: "u1", GroupIds: []string{"g1", "g2", "g3"}},
|
||||
},
|
||||
&HeadersRequest{
|
||||
FromAudience: "from.example.com",
|
||||
ToAudience: "to.example.com",
|
||||
Session: RequestSession{
|
||||
ID: "s1",
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, "g1,g2,g3", output.Headers.Get("X-Pomerium-Claim-Groups"))
|
||||
})
|
||||
|
||||
t.Run("jwt", func(t *testing.T) {
|
||||
output, err := eval(t,
|
||||
[]proto.Message{
|
||||
|
|
|
@ -59,7 +59,7 @@ user = u {
|
|||
}
|
||||
|
||||
directory_user = du {
|
||||
du = get_databroker_record("type.googleapis.com/directory.User", session.user_id)
|
||||
du = get_databroker_record("pomerium.io/DirectoryUser", session.user_id)
|
||||
du != null
|
||||
} else = {} {
|
||||
true
|
||||
|
@ -273,11 +273,11 @@ identity_headers := {key: values |
|
|||
}
|
||||
|
||||
get_databroker_group_names(ids) = gs {
|
||||
gs := [name | id := ids[i]; group := get_databroker_record("type.googleapis.com/directory.Group", id); name := group.name]
|
||||
gs := [name | id := ids[i]; group := get_databroker_record("pomerium.io/DirectoryGroup", id); name := group.name]
|
||||
}
|
||||
|
||||
get_databroker_group_emails(ids) = gs {
|
||||
gs := [email | id := ids[i]; group := get_databroker_record("type.googleapis.com/directory.Group", id); email := group.email]
|
||||
gs := [email | id := ids[i]; group := get_databroker_record("pomerium.io/DirectoryGroup", id); email := group.email]
|
||||
}
|
||||
|
||||
get_header_string_value(obj) = s {
|
||||
|
|
|
@ -31,13 +31,12 @@ func (o *Options) GetIdentityProviderForPolicy(policy *Policy) (*identity.Provid
|
|||
}
|
||||
|
||||
idp := &identity.Provider{
|
||||
ClientId: o.ClientID,
|
||||
ClientSecret: clientSecret,
|
||||
Type: o.Provider,
|
||||
Scopes: o.Scopes,
|
||||
ServiceAccount: o.ServiceAccount,
|
||||
Url: o.ProviderURL,
|
||||
RequestParams: o.RequestParams,
|
||||
ClientId: o.ClientID,
|
||||
ClientSecret: clientSecret,
|
||||
Type: o.Provider,
|
||||
Scopes: o.Scopes,
|
||||
Url: o.ProviderURL,
|
||||
RequestParams: o.RequestParams,
|
||||
}
|
||||
if policy != nil {
|
||||
if policy.IDPClientID != "" {
|
||||
|
|
|
@ -20,12 +20,6 @@ import (
|
|||
"github.com/volatiletech/null/v9"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/atomicutil"
|
||||
"github.com/pomerium/pomerium/internal/directory/azure"
|
||||
"github.com/pomerium/pomerium/internal/directory/github"
|
||||
"github.com/pomerium/pomerium/internal/directory/gitlab"
|
||||
"github.com/pomerium/pomerium/internal/directory/google"
|
||||
"github.com/pomerium/pomerium/internal/directory/okta"
|
||||
"github.com/pomerium/pomerium/internal/directory/onelogin"
|
||||
"github.com/pomerium/pomerium/internal/hashutil"
|
||||
"github.com/pomerium/pomerium/internal/httputil"
|
||||
"github.com/pomerium/pomerium/internal/identity/oauth"
|
||||
|
@ -41,11 +35,6 @@ import (
|
|||
// DisableHeaderKey is the key used to check whether to disable setting header
|
||||
const DisableHeaderKey = "disable"
|
||||
|
||||
const (
|
||||
idpCustomScopesDocLink = "https://www.pomerium.com/docs/reference/identity-provider-scopes"
|
||||
idpCustomScopesWarnMsg = "config: using custom scopes may result in undefined behavior, see: " + idpCustomScopesDocLink
|
||||
)
|
||||
|
||||
// DefaultAlternativeAddr is the address used is two services are competing over
|
||||
// the same listener. Typically this is invisible to the end user (e.g. localhost)
|
||||
// gRPC server, or is used for healthchecks (authorize only service)
|
||||
|
@ -151,11 +140,6 @@ type Options struct {
|
|||
Provider string `mapstructure:"idp_provider" yaml:"idp_provider,omitempty"`
|
||||
ProviderURL string `mapstructure:"idp_provider_url" yaml:"idp_provider_url,omitempty"`
|
||||
Scopes []string `mapstructure:"idp_scopes" yaml:"idp_scopes,omitempty"`
|
||||
ServiceAccount string `mapstructure:"idp_service_account" yaml:"idp_service_account,omitempty"`
|
||||
// Identity provider refresh directory interval/timeout settings.
|
||||
RefreshDirectoryTimeout time.Duration `mapstructure:"idp_refresh_directory_timeout" yaml:"idp_refresh_directory_timeout,omitempty"`
|
||||
RefreshDirectoryInterval time.Duration `mapstructure:"idp_refresh_directory_interval" yaml:"idp_refresh_directory_interval,omitempty"`
|
||||
QPS float64 `mapstructure:"idp_qps" yaml:"idp_qps"`
|
||||
|
||||
// RequestParams are custom request params added to the signin request as
|
||||
// part of an Oauth2 code flow.
|
||||
|
@ -334,9 +318,6 @@ var defaultOptions = Options{
|
|||
GRPCClientDNSRoundRobin: true,
|
||||
AuthenticateCallbackPath: "/oauth2/callback",
|
||||
TracingSampleRate: 0.0001,
|
||||
RefreshDirectoryInterval: 10 * time.Minute,
|
||||
RefreshDirectoryTimeout: 1 * time.Minute,
|
||||
QPS: 1.0,
|
||||
|
||||
AutocertOptions: AutocertOptions{
|
||||
Folder: dataDir(),
|
||||
|
@ -698,17 +679,6 @@ func (o *Options) Validate() error {
|
|||
}
|
||||
}
|
||||
|
||||
// if no service account was defined, there should not be any policies that
|
||||
// assert group membership (except for azure which can be derived from the client
|
||||
// id, secret and provider url)
|
||||
if o.ServiceAccount == "" && o.Provider != "azure" {
|
||||
for _, p := range o.GetAllPolicies() {
|
||||
if len(p.AllowedGroups) != 0 {
|
||||
return fmt.Errorf("config: `allowed_groups` requires `idp_service_account`")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// strip quotes from redirect address (#811)
|
||||
o.HTTPRedirectAddr = strings.Trim(o.HTTPRedirectAddr, `"'`)
|
||||
|
||||
|
@ -717,14 +687,6 @@ func (o *Options) Validate() error {
|
|||
"`insecure_server` or manually provided certificates were provided, server will be using a self-signed certificate")
|
||||
}
|
||||
|
||||
switch o.Provider {
|
||||
case azure.Name, github.Name, gitlab.Name, google.Name, okta.Name, onelogin.Name:
|
||||
if len(o.Scopes) > 0 {
|
||||
log.Warn(ctx).Msg(idpCustomScopesWarnMsg)
|
||||
}
|
||||
default:
|
||||
}
|
||||
|
||||
if err := ValidateDNSLookupFamily(o.DNSLookupFamily); err != nil {
|
||||
return fmt.Errorf("config: %w", err)
|
||||
}
|
||||
|
@ -912,13 +874,12 @@ func (o *Options) GetOauthOptions() (oauth.Options, error) {
|
|||
return oauth.Options{}, err
|
||||
}
|
||||
return oauth.Options{
|
||||
RedirectURL: redirectURL,
|
||||
ProviderName: o.Provider,
|
||||
ProviderURL: o.ProviderURL,
|
||||
ClientID: o.ClientID,
|
||||
ClientSecret: clientSecret,
|
||||
Scopes: o.Scopes,
|
||||
ServiceAccount: o.ServiceAccount,
|
||||
RedirectURL: redirectURL,
|
||||
ProviderName: o.Provider,
|
||||
ProviderURL: o.ProviderURL,
|
||||
ClientID: o.ClientID,
|
||||
ClientSecret: clientSecret,
|
||||
Scopes: o.Scopes,
|
||||
}, nil
|
||||
}
|
||||
|
||||
|
@ -1029,9 +990,6 @@ func (o *Options) GetSharedKey() ([]byte, error) {
|
|||
|
||||
// GetGoogleCloudServerlessAuthenticationServiceAccount gets the GoogleCloudServerlessAuthenticationServiceAccount.
|
||||
func (o *Options) GetGoogleCloudServerlessAuthenticationServiceAccount() string {
|
||||
if o.GoogleCloudServerlessAuthenticationServiceAccount == "" && o.Provider == "google" {
|
||||
return o.ServiceAccount
|
||||
}
|
||||
return o.GoogleCloudServerlessAuthenticationServiceAccount
|
||||
}
|
||||
|
||||
|
@ -1043,14 +1001,6 @@ func (o *Options) GetSetResponseHeaders() map[string]string {
|
|||
return o.SetResponseHeaders
|
||||
}
|
||||
|
||||
// GetQPS gets the QPS.
|
||||
func (o *Options) GetQPS() float64 {
|
||||
if o.QPS < 1 {
|
||||
return 1
|
||||
}
|
||||
return o.QPS
|
||||
}
|
||||
|
||||
// GetCodecType gets a codec type.
|
||||
func (o *Options) GetCodecType() CodecType {
|
||||
if o.CodecType == CodecTypeUnset {
|
||||
|
@ -1393,15 +1343,6 @@ func (o *Options) ApplySettings(ctx context.Context, settings *config.Settings)
|
|||
if len(settings.Scopes) > 0 {
|
||||
o.Scopes = settings.Scopes
|
||||
}
|
||||
if settings.IdpServiceAccount != nil {
|
||||
o.ServiceAccount = settings.GetIdpServiceAccount()
|
||||
}
|
||||
if settings.IdpRefreshDirectoryTimeout != nil {
|
||||
o.RefreshDirectoryTimeout = settings.GetIdpRefreshDirectoryTimeout().AsDuration()
|
||||
}
|
||||
if settings.IdpRefreshDirectoryInterval != nil {
|
||||
o.RefreshDirectoryInterval = settings.GetIdpRefreshDirectoryInterval().AsDuration()
|
||||
}
|
||||
if settings.RequestParams != nil && len(settings.RequestParams) > 0 {
|
||||
o.RequestParams = settings.RequestParams
|
||||
}
|
||||
|
|
|
@ -310,12 +310,9 @@ func TestOptionsFromViper(t *testing.T) {
|
|||
"X-Frame-Options": "SAMEORIGIN",
|
||||
"X-XSS-Protection": "1; mode=block",
|
||||
},
|
||||
RefreshDirectoryTimeout: 1 * time.Minute,
|
||||
RefreshDirectoryInterval: 10 * time.Minute,
|
||||
QPS: 1.0,
|
||||
DataBrokerStorageType: "memory",
|
||||
EnvoyAdminAccessLogPath: os.DevNull,
|
||||
EnvoyAdminProfilePath: os.DevNull,
|
||||
DataBrokerStorageType: "memory",
|
||||
EnvoyAdminAccessLogPath: os.DevNull,
|
||||
EnvoyAdminProfilePath: os.DevNull,
|
||||
},
|
||||
false,
|
||||
},
|
||||
|
@ -330,9 +327,6 @@ func TestOptionsFromViper(t *testing.T) {
|
|||
CookieHTTPOnly: true,
|
||||
InsecureServer: true,
|
||||
SetResponseHeaders: map[string]string{"disable": "true"},
|
||||
RefreshDirectoryTimeout: 1 * time.Minute,
|
||||
RefreshDirectoryInterval: 10 * time.Minute,
|
||||
QPS: 1.0,
|
||||
DataBrokerStorageType: "memory",
|
||||
EnvoyAdminAccessLogPath: os.DevNull,
|
||||
EnvoyAdminProfilePath: os.DevNull,
|
||||
|
@ -342,7 +336,6 @@ func TestOptionsFromViper(t *testing.T) {
|
|||
{"bad url", []byte(`{"policy":[{"from": "https://","to":"https://to.example"}]}`), nil, true},
|
||||
{"bad policy", []byte(`{"policy":[{"allow_public_unauthenticated_access": "dog","to":"https://to.example"}]}`), nil, true},
|
||||
{"bad file", []byte(`{''''}`), nil, true},
|
||||
{"allowed_groups without idp_service_account should fail", []byte(`{"autocert_dir":"","insecure_server":true,"policy":[{"from": "https://from.example","to":"https://to.example","allowed_groups": "['group1']"}]}`), nil, true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
|
|
|
@ -38,7 +38,6 @@ type Policy struct {
|
|||
|
||||
// Identity related policy
|
||||
AllowedUsers []string `mapstructure:"allowed_users" yaml:"allowed_users,omitempty" json:"allowed_users,omitempty"`
|
||||
AllowedGroups []string `mapstructure:"allowed_groups" yaml:"allowed_groups,omitempty" json:"allowed_groups,omitempty"`
|
||||
AllowedDomains []string `mapstructure:"allowed_domains" yaml:"allowed_domains,omitempty" json:"allowed_domains,omitempty"`
|
||||
AllowedIDPClaims identity.FlattenedClaims `mapstructure:"allowed_idp_claims" yaml:"allowed_idp_claims,omitempty" json:"allowed_idp_claims,omitempty"`
|
||||
|
||||
|
@ -192,7 +191,6 @@ type SubPolicy struct {
|
|||
ID string `mapstructure:"id" yaml:"id" json:"id"`
|
||||
Name string `mapstructure:"name" yaml:"name" json:"name"`
|
||||
AllowedUsers []string `mapstructure:"allowed_users" yaml:"allowed_users,omitempty" json:"allowed_users,omitempty"`
|
||||
AllowedGroups []string `mapstructure:"allowed_groups" yaml:"allowed_groups,omitempty" json:"allowed_groups,omitempty"`
|
||||
AllowedDomains []string `mapstructure:"allowed_domains" yaml:"allowed_domains,omitempty" json:"allowed_domains,omitempty"`
|
||||
AllowedIDPClaims identity.FlattenedClaims `mapstructure:"allowed_idp_claims" yaml:"allowed_idp_claims,omitempty" json:"allowed_idp_claims,omitempty"`
|
||||
Rego []string `mapstructure:"rego" yaml:"rego" json:"rego,omitempty"`
|
||||
|
@ -231,7 +229,6 @@ func NewPolicyFromProto(pb *configpb.Route) (*Policy, error) {
|
|||
p := &Policy{
|
||||
From: pb.GetFrom(),
|
||||
AllowedUsers: pb.GetAllowedUsers(),
|
||||
AllowedGroups: pb.GetAllowedGroups(),
|
||||
AllowedDomains: pb.GetAllowedDomains(),
|
||||
AllowedIDPClaims: identity.NewFlattenedClaimsFromPB(pb.GetAllowedIdpClaims()),
|
||||
Prefix: pb.GetPrefix(),
|
||||
|
@ -317,7 +314,6 @@ func NewPolicyFromProto(pb *configpb.Route) (*Policy, error) {
|
|||
ID: sp.GetId(),
|
||||
Name: sp.GetName(),
|
||||
AllowedUsers: sp.GetAllowedUsers(),
|
||||
AllowedGroups: sp.GetAllowedGroups(),
|
||||
AllowedDomains: sp.GetAllowedDomains(),
|
||||
AllowedIDPClaims: identity.NewFlattenedClaimsFromPB(sp.GetAllowedIdpClaims()),
|
||||
Rego: sp.GetRego(),
|
||||
|
@ -347,7 +343,6 @@ func (p *Policy) ToProto() (*configpb.Route, error) {
|
|||
Id: sp.ID,
|
||||
Name: sp.Name,
|
||||
AllowedUsers: sp.AllowedUsers,
|
||||
AllowedGroups: sp.AllowedGroups,
|
||||
AllowedDomains: sp.AllowedDomains,
|
||||
AllowedIdpClaims: sp.AllowedIDPClaims.ToPB(),
|
||||
Rego: sp.Rego,
|
||||
|
@ -358,7 +353,6 @@ func (p *Policy) ToProto() (*configpb.Route, error) {
|
|||
Name: fmt.Sprint(p.RouteID()),
|
||||
From: p.From,
|
||||
AllowedUsers: p.AllowedUsers,
|
||||
AllowedGroups: p.AllowedGroups,
|
||||
AllowedDomains: p.AllowedDomains,
|
||||
AllowedIdpClaims: p.AllowedIDPClaims.ToPB(),
|
||||
Prefix: p.Prefix,
|
||||
|
@ -466,12 +460,12 @@ func (p *Policy) Validate() error {
|
|||
}
|
||||
|
||||
// Only allow public access if no other whitelists are in place
|
||||
if p.AllowPublicUnauthenticatedAccess && (p.AllowAnyAuthenticatedUser || p.AllowedDomains != nil || p.AllowedGroups != nil || p.AllowedUsers != nil) {
|
||||
if p.AllowPublicUnauthenticatedAccess && (p.AllowAnyAuthenticatedUser || p.AllowedDomains != nil || p.AllowedUsers != nil) {
|
||||
return fmt.Errorf("config: policy route marked as public but contains whitelists")
|
||||
}
|
||||
|
||||
// Only allow any authenticated user if no other whitelists are in place
|
||||
if p.AllowAnyAuthenticatedUser && (p.AllowedDomains != nil || p.AllowedGroups != nil || p.AllowedUsers != nil) {
|
||||
if p.AllowAnyAuthenticatedUser && (p.AllowedDomains != nil || p.AllowedUsers != nil) {
|
||||
return fmt.Errorf("config: policy route marked accessible for any authenticated user but contains whitelists")
|
||||
}
|
||||
|
||||
|
@ -642,16 +636,6 @@ func (p *Policy) AllAllowedDomains() []string {
|
|||
return ads
|
||||
}
|
||||
|
||||
// AllAllowedGroups returns all the allowed groups.
|
||||
func (p *Policy) AllAllowedGroups() []string {
|
||||
var ags []string
|
||||
ags = append(ags, p.AllowedGroups...)
|
||||
for _, sp := range p.SubPolicies {
|
||||
ags = append(ags, sp.AllowedGroups...)
|
||||
}
|
||||
return ags
|
||||
}
|
||||
|
||||
// AllAllowedIDPClaims returns all the allowed IDP claims.
|
||||
func (p *Policy) AllAllowedIDPClaims() []identity.FlattenedClaims {
|
||||
var aics []identity.FlattenedClaims
|
||||
|
|
|
@ -47,15 +47,6 @@ func (p *Policy) ToPPL() *parser.Policy {
|
|||
},
|
||||
})
|
||||
}
|
||||
for _, ag := range p.AllAllowedGroups() {
|
||||
allowRule.Or = append(allowRule.Or,
|
||||
parser.Criterion{
|
||||
Name: "groups",
|
||||
Data: parser.Object{
|
||||
"has": parser.String(ag),
|
||||
},
|
||||
})
|
||||
}
|
||||
for _, aic := range p.AllAllowedIDPClaims() {
|
||||
var ks []string
|
||||
for k := range aic {
|
||||
|
|
|
@ -16,7 +16,6 @@ func TestPolicy_ToPPL(t *testing.T) {
|
|||
CORSAllowPreflight: true,
|
||||
AllowAnyAuthenticatedUser: true,
|
||||
AllowedDomains: []string{"a.example.com", "b.example.com"},
|
||||
AllowedGroups: []string{"group1", "group2"},
|
||||
AllowedUsers: []string{"user1", "user2"},
|
||||
AllowedIDPClaims: map[string][]interface{}{
|
||||
"family_name": {"Smith", "Jones"},
|
||||
|
@ -24,7 +23,6 @@ func TestPolicy_ToPPL(t *testing.T) {
|
|||
SubPolicies: []SubPolicy{
|
||||
{
|
||||
AllowedDomains: []string{"c.example.com", "d.example.com"},
|
||||
AllowedGroups: []string{"group3", "group4"},
|
||||
AllowedUsers: []string{"user3", "user4"},
|
||||
AllowedIDPClaims: map[string][]interface{}{
|
||||
"given_name": {"John"},
|
||||
|
@ -32,7 +30,6 @@ func TestPolicy_ToPPL(t *testing.T) {
|
|||
},
|
||||
{
|
||||
AllowedDomains: []string{"e.example.com"},
|
||||
AllowedGroups: []string{"group5"},
|
||||
AllowedUsers: []string{"user5"},
|
||||
AllowedIDPClaims: map[string][]interface{}{
|
||||
"timezone": {"EST"},
|
||||
|
@ -175,161 +172,6 @@ else = [false, {"user-unauthenticated"}] {
|
|||
true
|
||||
}
|
||||
|
||||
groups_0 = [true, {"groups-ok"}] {
|
||||
session := get_session(input.session.id)
|
||||
directory_user := get_directory_user(session)
|
||||
group_ids := get_group_ids(session, directory_user)
|
||||
group_names := [directory_group.name |
|
||||
some i
|
||||
group_id := group_ids[i]
|
||||
directory_group := get_directory_group(group_id)
|
||||
directory_group != null
|
||||
directory_group.name != null
|
||||
]
|
||||
group_emails := [directory_group.email |
|
||||
some i
|
||||
group_id := group_ids[i]
|
||||
directory_group := get_directory_group(group_id)
|
||||
directory_group != null
|
||||
directory_group.email != null
|
||||
]
|
||||
groups = array.concat(group_ids, array.concat(group_names, group_emails))
|
||||
count([true | some v; v = groups[_0]; v == "group1"]) > 0
|
||||
}
|
||||
|
||||
else = [false, {"groups-unauthorized"}] {
|
||||
session := get_session(input.session.id)
|
||||
session.id != ""
|
||||
}
|
||||
|
||||
else = [false, {"user-unauthenticated"}] {
|
||||
true
|
||||
}
|
||||
|
||||
groups_1 = [true, {"groups-ok"}] {
|
||||
session := get_session(input.session.id)
|
||||
directory_user := get_directory_user(session)
|
||||
group_ids := get_group_ids(session, directory_user)
|
||||
group_names := [directory_group.name |
|
||||
some i
|
||||
group_id := group_ids[i]
|
||||
directory_group := get_directory_group(group_id)
|
||||
directory_group != null
|
||||
directory_group.name != null
|
||||
]
|
||||
group_emails := [directory_group.email |
|
||||
some i
|
||||
group_id := group_ids[i]
|
||||
directory_group := get_directory_group(group_id)
|
||||
directory_group != null
|
||||
directory_group.email != null
|
||||
]
|
||||
groups = array.concat(group_ids, array.concat(group_names, group_emails))
|
||||
count([true | some v; v = groups[_0]; v == "group2"]) > 0
|
||||
}
|
||||
|
||||
else = [false, {"groups-unauthorized"}] {
|
||||
session := get_session(input.session.id)
|
||||
session.id != ""
|
||||
}
|
||||
|
||||
else = [false, {"user-unauthenticated"}] {
|
||||
true
|
||||
}
|
||||
|
||||
groups_2 = [true, {"groups-ok"}] {
|
||||
session := get_session(input.session.id)
|
||||
directory_user := get_directory_user(session)
|
||||
group_ids := get_group_ids(session, directory_user)
|
||||
group_names := [directory_group.name |
|
||||
some i
|
||||
group_id := group_ids[i]
|
||||
directory_group := get_directory_group(group_id)
|
||||
directory_group != null
|
||||
directory_group.name != null
|
||||
]
|
||||
group_emails := [directory_group.email |
|
||||
some i
|
||||
group_id := group_ids[i]
|
||||
directory_group := get_directory_group(group_id)
|
||||
directory_group != null
|
||||
directory_group.email != null
|
||||
]
|
||||
groups = array.concat(group_ids, array.concat(group_names, group_emails))
|
||||
count([true | some v; v = groups[_0]; v == "group3"]) > 0
|
||||
}
|
||||
|
||||
else = [false, {"groups-unauthorized"}] {
|
||||
session := get_session(input.session.id)
|
||||
session.id != ""
|
||||
}
|
||||
|
||||
else = [false, {"user-unauthenticated"}] {
|
||||
true
|
||||
}
|
||||
|
||||
groups_3 = [true, {"groups-ok"}] {
|
||||
session := get_session(input.session.id)
|
||||
directory_user := get_directory_user(session)
|
||||
group_ids := get_group_ids(session, directory_user)
|
||||
group_names := [directory_group.name |
|
||||
some i
|
||||
group_id := group_ids[i]
|
||||
directory_group := get_directory_group(group_id)
|
||||
directory_group != null
|
||||
directory_group.name != null
|
||||
]
|
||||
group_emails := [directory_group.email |
|
||||
some i
|
||||
group_id := group_ids[i]
|
||||
directory_group := get_directory_group(group_id)
|
||||
directory_group != null
|
||||
directory_group.email != null
|
||||
]
|
||||
groups = array.concat(group_ids, array.concat(group_names, group_emails))
|
||||
count([true | some v; v = groups[_0]; v == "group4"]) > 0
|
||||
}
|
||||
|
||||
else = [false, {"groups-unauthorized"}] {
|
||||
session := get_session(input.session.id)
|
||||
session.id != ""
|
||||
}
|
||||
|
||||
else = [false, {"user-unauthenticated"}] {
|
||||
true
|
||||
}
|
||||
|
||||
groups_4 = [true, {"groups-ok"}] {
|
||||
session := get_session(input.session.id)
|
||||
directory_user := get_directory_user(session)
|
||||
group_ids := get_group_ids(session, directory_user)
|
||||
group_names := [directory_group.name |
|
||||
some i
|
||||
group_id := group_ids[i]
|
||||
directory_group := get_directory_group(group_id)
|
||||
directory_group != null
|
||||
directory_group.name != null
|
||||
]
|
||||
group_emails := [directory_group.email |
|
||||
some i
|
||||
group_id := group_ids[i]
|
||||
directory_group := get_directory_group(group_id)
|
||||
directory_group != null
|
||||
directory_group.email != null
|
||||
]
|
||||
groups = array.concat(group_ids, array.concat(group_names, group_emails))
|
||||
count([true | some v; v = groups[_0]; v == "group5"]) > 0
|
||||
}
|
||||
|
||||
else = [false, {"groups-unauthorized"}] {
|
||||
session := get_session(input.session.id)
|
||||
session.id != ""
|
||||
}
|
||||
|
||||
else = [false, {"user-unauthenticated"}] {
|
||||
true
|
||||
}
|
||||
|
||||
claim_0 = [true, {"claim-ok"}] {
|
||||
rule_data := "Smith"
|
||||
rule_path := "family_name"
|
||||
|
@ -570,7 +412,7 @@ else = [false, {"user-unauthenticated"}] {
|
|||
}
|
||||
|
||||
or_0 = v {
|
||||
results := [pomerium_routes_0, accept_0, cors_preflight_0, authenticated_user_0, domain_0, domain_1, domain_2, domain_3, domain_4, groups_0, groups_1, groups_2, groups_3, groups_4, claim_0, claim_1, claim_2, claim_3, user_0, email_0, user_1, email_1, user_2, email_2, user_3, email_3, user_4, email_4]
|
||||
results := [pomerium_routes_0, accept_0, cors_preflight_0, authenticated_user_0, domain_0, domain_1, domain_2, domain_3, domain_4, claim_0, claim_1, claim_2, claim_3, user_0, email_0, user_1, email_1, user_2, email_2, user_3, email_3, user_4, email_4]
|
||||
normalized := [normalize_criterion_result(x) | x := results[i]]
|
||||
v := merge_with_or(normalized)
|
||||
}
|
||||
|
@ -715,24 +557,6 @@ else = {} {
|
|||
true
|
||||
}
|
||||
|
||||
get_directory_user(session) = v {
|
||||
v = get_databroker_record("type.googleapis.com/directory.User", session.user_id)
|
||||
v != null
|
||||
}
|
||||
|
||||
else = "" {
|
||||
true
|
||||
}
|
||||
|
||||
get_directory_group(id) = v {
|
||||
v = get_databroker_record("type.googleapis.com/directory.Group", id)
|
||||
v != null
|
||||
}
|
||||
|
||||
else = {} {
|
||||
true
|
||||
}
|
||||
|
||||
get_user_email(session, user) = v {
|
||||
v = user.email
|
||||
}
|
||||
|
@ -741,15 +565,6 @@ else = "" {
|
|||
true
|
||||
}
|
||||
|
||||
get_group_ids(session, directory_user) = v {
|
||||
v = directory_user.group_ids
|
||||
v != null
|
||||
}
|
||||
|
||||
else = [] {
|
||||
true
|
||||
}
|
||||
|
||||
object_get(obj, key, def) = value {
|
||||
undefined := "10a0fd35-0f1a-4e5b-97ce-631e89e1bafa"
|
||||
value = object.get(obj, key, undefined)
|
||||
|
|
|
@ -106,7 +106,7 @@ func Test_PolicyRouteID(t *testing.T) {
|
|||
{
|
||||
"same",
|
||||
&Policy{From: "https://pomerium.io", To: mustParseWeightedURLs(t, "http://localhost"), AllowedUsers: []string{"foo@bar.com"}},
|
||||
&Policy{From: "https://pomerium.io", To: mustParseWeightedURLs(t, "http://localhost"), AllowedGroups: []string{"allusers"}},
|
||||
&Policy{From: "https://pomerium.io", To: mustParseWeightedURLs(t, "http://localhost")},
|
||||
true,
|
||||
},
|
||||
{
|
||||
|
|
|
@ -7,7 +7,6 @@ import (
|
|||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"sync"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
"golang.org/x/sync/errgroup"
|
||||
|
@ -16,7 +15,6 @@ import (
|
|||
|
||||
"github.com/pomerium/pomerium/config"
|
||||
"github.com/pomerium/pomerium/internal/atomicutil"
|
||||
"github.com/pomerium/pomerium/internal/directory"
|
||||
"github.com/pomerium/pomerium/internal/events"
|
||||
"github.com/pomerium/pomerium/internal/identity"
|
||||
"github.com/pomerium/pomerium/internal/identity/manager"
|
||||
|
@ -41,9 +39,6 @@ type DataBroker struct {
|
|||
localGRPCServer *grpc.Server
|
||||
localGRPCConnection *grpc.ClientConn
|
||||
sharedKey *atomicutil.Value[[]byte]
|
||||
|
||||
mu sync.Mutex
|
||||
directoryProvider directory.Provider
|
||||
}
|
||||
|
||||
// New creates a new databroker service.
|
||||
|
@ -126,7 +121,6 @@ func (c *DataBroker) OnConfigChange(ctx context.Context, cfg *config.Config) {
|
|||
// Register registers all the gRPC services with the given server.
|
||||
func (c *DataBroker) Register(grpcServer *grpc.Server) {
|
||||
databroker.RegisterDataBrokerServiceServer(grpcServer, c.dataBrokerServer)
|
||||
directory.RegisterDirectoryServiceServer(grpcServer, c)
|
||||
registry.RegisterRegistryServer(grpcServer, c.dataBrokerServer)
|
||||
}
|
||||
|
||||
|
@ -163,30 +157,10 @@ func (c *DataBroker) update(ctx context.Context, cfg *config.Config) error {
|
|||
return fmt.Errorf("databroker: invalid oauth options: %w", err)
|
||||
}
|
||||
|
||||
clientSecret, err := cfg.Options.GetClientSecret()
|
||||
if err != nil {
|
||||
return fmt.Errorf("databroker: error retrieving IPD client secret: %w", err)
|
||||
}
|
||||
|
||||
directoryProvider := directory.GetProvider(directory.Options{
|
||||
ServiceAccount: cfg.Options.ServiceAccount,
|
||||
Provider: cfg.Options.Provider,
|
||||
ProviderURL: cfg.Options.ProviderURL,
|
||||
QPS: cfg.Options.GetQPS(),
|
||||
ClientID: cfg.Options.ClientID,
|
||||
ClientSecret: clientSecret,
|
||||
})
|
||||
c.mu.Lock()
|
||||
c.directoryProvider = directoryProvider
|
||||
c.mu.Unlock()
|
||||
|
||||
dataBrokerClient := databroker.NewDataBrokerServiceClient(c.localGRPCConnection)
|
||||
|
||||
options := []manager.Option{
|
||||
manager.WithDirectoryProvider(directoryProvider),
|
||||
manager.WithDataBrokerClient(dataBrokerClient),
|
||||
manager.WithGroupRefreshInterval(cfg.Options.RefreshDirectoryInterval),
|
||||
manager.WithGroupRefreshTimeout(cfg.Options.RefreshDirectoryTimeout),
|
||||
manager.WithEventManager(c.eventsMgr),
|
||||
}
|
||||
|
||||
|
|
|
@ -6,6 +6,7 @@ import (
|
|||
"strconv"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/codes"
|
||||
|
@ -89,7 +90,7 @@ func TestServerSync(t *testing.T) {
|
|||
require.NoError(t, err)
|
||||
_, err = client.Recv()
|
||||
require.Error(t, err)
|
||||
require.Equal(t, codes.Aborted, status.Code(err))
|
||||
assert.Equal(t, codes.Aborted.String(), status.Code(err).String())
|
||||
})
|
||||
}
|
||||
|
||||
|
|
|
@ -1,59 +0,0 @@
|
|||
package databroker
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
"google.golang.org/protobuf/types/known/emptypb"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/directory/directoryerrors"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/directory"
|
||||
"github.com/pomerium/pomerium/pkg/protoutil"
|
||||
)
|
||||
|
||||
// RefreshUser refreshes a user's directory information.
|
||||
func (c *DataBroker) RefreshUser(ctx context.Context, req *directory.RefreshUserRequest) (*emptypb.Empty, error) {
|
||||
c.mu.Lock()
|
||||
dp := c.directoryProvider
|
||||
c.mu.Unlock()
|
||||
|
||||
if dp == nil {
|
||||
return nil, errors.New("no directory provider is available for refresh")
|
||||
}
|
||||
|
||||
u, err := dp.User(ctx, req.GetUserId(), req.GetAccessToken())
|
||||
// if the returned error signals we should prefer existing information
|
||||
if errors.Is(err, directoryerrors.ErrPreferExistingInformation) {
|
||||
_, err = c.dataBrokerServer.Get(ctx, &databroker.GetRequest{
|
||||
Type: protoutil.GetTypeURL(new(directory.User)),
|
||||
Id: req.GetUserId(),
|
||||
})
|
||||
switch status.Code(err) {
|
||||
case codes.OK:
|
||||
return new(emptypb.Empty), nil
|
||||
case codes.NotFound: // go ahead and save the user that was returned
|
||||
default:
|
||||
return nil, fmt.Errorf("databroker: error retrieving existing user record for refresh: %w", err)
|
||||
}
|
||||
} else if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
any := protoutil.NewAny(u)
|
||||
_, err = c.dataBrokerServer.Put(ctx, &databroker.PutRequest{
|
||||
Records: []*databroker.Record{{
|
||||
Type: any.GetTypeUrl(),
|
||||
Id: u.GetId(),
|
||||
Data: any,
|
||||
}},
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return new(emptypb.Empty), nil
|
||||
}
|
|
@ -1,6 +1,5 @@
|
|||
# See detailed configuration settings : https://www.pomerium.com/docs/reference/
|
||||
|
||||
|
||||
# this is the domain the identity provider will callback after a user authenticates
|
||||
authenticate_service_url: https://authenticate.localhost.pomerium.io
|
||||
|
||||
|
@ -20,7 +19,6 @@ certificate_key_file: /pomerium/privkey.pem
|
|||
idp_provider: google
|
||||
idp_client_id: REPLACE_ME
|
||||
idp_client_secret: REPLACE_ME
|
||||
#idp_service_account: REPLACE_ME # Required by some identity providers for directory sync
|
||||
|
||||
# Generate 256 bit random keys e.g. `head -c32 /dev/urandom | base64`
|
||||
cookie_secret: V2JBZk0zWGtsL29UcFUvWjVDWWQ2UHExNXJ0b2VhcDI=
|
||||
|
|
|
@ -1,7 +1,6 @@
|
|||
#!/bin/bash
|
||||
# Main configuration flags : https://www.pomerium.com/docs/reference/
|
||||
|
||||
|
||||
# Main configuration flags
|
||||
# export ADDRESS=":8443" # optional, default is 443
|
||||
# export POMERIUM_DEBUG=true # optional, default is false
|
||||
|
@ -37,7 +36,6 @@ export COOKIE_SECRET="$(head -c32 /dev/urandom | base64)"
|
|||
# export IDP_CLIENT_ID="REPLACEME" # from the application the users login to
|
||||
# export IDP_CLIENT_SECRET="REPLACEME" # from the application the users login to
|
||||
# the following is optional and only needed if you want role (Auth0 calls groups roles) data
|
||||
# export IDP_SERVICE_ACCOUNT="REPLACEME" # built from the machine-to-machine application which talks to the Auth0 Management API
|
||||
|
||||
# Azure
|
||||
# export IDP_PROVIDER="azure"
|
||||
|
@ -64,8 +62,3 @@ export IDP_PROVIDER="google"
|
|||
# directly as a base64 encoded yaml/json file, or as the policy key in the configuration
|
||||
# file
|
||||
export POLICY="$(base64 ./docs/configuration/examples/config/policy.example.yaml)"
|
||||
|
||||
# For Group data you must set an IDP_SERVICE_ACCOUNT
|
||||
# https://www.pomerium.com/configuration/#identity-provider-service-account
|
||||
# export IDP_SERVICE_ACCOUNT=$( echo YOUR_SERVICE_ACCOUNT | base64)
|
||||
# For Google manually edit the service account to add the impersonate_user field before base64
|
||||
|
|
|
@ -40,8 +40,6 @@ authenticate_service_url: https://authenticate.localhost.pomerium.io
|
|||
# idp_provider_url: "https://REPLACEME.us.auth0.com"
|
||||
# idp_client_id: "REPLACEME" # from the application the users login to
|
||||
# idp_client_secret: "REPLACEME" # from the application the users login to
|
||||
# the following is optional and only needed if you want role (Auth0 calls groups roles) data
|
||||
# idp_service_account: "REPLACEME" # built from the machine-to-machine application which talks to the Auth0 Management API
|
||||
|
||||
# Azure
|
||||
# idp_provider: "azure"
|
||||
|
@ -54,10 +52,6 @@ authenticate_service_url: https://authenticate.localhost.pomerium.io
|
|||
# idp_client_id: "REPLACEME
|
||||
# idp_client_secret: "REPLACEME
|
||||
|
||||
# IF GSUITE and you want to get user groups you will need to set a service account
|
||||
# see identity provider docs for gooogle for more info :
|
||||
# idp_service_account: $(echo '{"impersonate_user": "user@example.com"}' | base64)
|
||||
|
||||
# OKTA
|
||||
# idp_provider: "okta"
|
||||
# idp_client_id: "REPLACEME"
|
||||
|
@ -70,9 +64,6 @@ authenticate_service_url: https://authenticate.localhost.pomerium.io
|
|||
# idp_client_secret: "REPLACEME"
|
||||
# idp_provider_url: "https://openid-connect.onelogin.com/oidc" #optional, defaults to `https://openid-connect.onelogin.com/oidc`
|
||||
|
||||
# For Group data you must set an IDP_SERVICE_ACCOUNT
|
||||
# idp_service_account: YOUR_SERVICE_ACCOUNT
|
||||
|
||||
# Proxied routes and per-route policies are defined in a routes block
|
||||
routes:
|
||||
- from: https://verify.localhost.pomerium.io
|
||||
|
|
|
@ -22,7 +22,6 @@ services:
|
|||
# - IDP_PROVIDER_URL=https://beyondperimeter.okta.com
|
||||
# - IDP_CLIENT_ID=REPLACE_ME
|
||||
# - IDP_CLIENT_SECRET=REPLACE_ME
|
||||
# - IDP_SERVICE_ACCOUNT=REPLACE_ME
|
||||
# NOTE! Generate new secret keys! e.g. `head -c32 /dev/urandom | base64`
|
||||
# Generated secret keys must match between services
|
||||
- SHARED_SECRET=aDducXQzK2tPY3R4TmdqTGhaYS80eGYxcTUvWWJDb2M=
|
||||
|
|
|
@ -3,7 +3,7 @@
|
|||
# curl https://raw.githubusercontent.com/kubernetes/helm/master/scripts/get | bash
|
||||
# NOTE! This will create real resources on Google's cloud. Make sure you clean up any unused
|
||||
# resources to avoid being billed. For reference, this tutorial cost me <10 cents for a couple of hours.
|
||||
# NOTE! You must change the identity provider client secret setting, and service account setting!
|
||||
# NOTE! You must change the identity provider client secret setting!
|
||||
# NOTE! If you are using gsuite, you should also set `authenticate.idp.serviceAccount`, see docs !
|
||||
|
||||
echo "=> [GCE] creating cluster"
|
||||
|
|
|
@ -14,9 +14,6 @@ override_certificate_name: "*.localhost.pomerium.io"
|
|||
idp_provider: google
|
||||
idp_client_id: REPLACE_ME.apps.googleusercontent.com
|
||||
idp_client_secret: "REPLACE_ME"
|
||||
# Required for group data
|
||||
# https://www.pomerium.com/configuration/#identity-provider-service-account
|
||||
idp_service_account: YOUR_SERVICE_ACCOUNT
|
||||
|
||||
routes:
|
||||
- from: https://verify.localhost.pomerium.io
|
||||
|
|
|
@ -7,7 +7,6 @@ authenticate:
|
|||
provider: "google"
|
||||
clientID: YOUR_CLIENT_ID
|
||||
clientSecret: YOUR_SECRET
|
||||
serviceAccount: YOUR_SERVICE_ACCOUNT
|
||||
proxied: false
|
||||
|
||||
proxy:
|
||||
|
|
|
@ -3,9 +3,6 @@ authenticate:
|
|||
provider: "google"
|
||||
clientID: YOUR_CLIENT_ID
|
||||
clientSecret: YOUR_SECRET
|
||||
# Required for group data
|
||||
# https://www.pomerium.com/configuration/#identity-provider-service-account
|
||||
serviceAccount: YOUR_SERVICE_ACCOUNT
|
||||
service:
|
||||
type: NodePort
|
||||
annotations:
|
||||
|
|
5
go.mod
5
go.mod
|
@ -57,8 +57,6 @@ require (
|
|||
github.com/spf13/viper v1.13.0
|
||||
github.com/stretchr/testify v1.8.1
|
||||
github.com/tniswong/go.rfcx v0.0.0-20181019234604-07783c52761f
|
||||
github.com/tomnomnom/linkheader v0.0.0-20180905144013-02ca5825eb80
|
||||
github.com/vektah/gqlparser v1.3.1
|
||||
github.com/volatiletech/null/v9 v9.0.0
|
||||
github.com/yuin/gopher-lua v0.0.0-20200816102855-ee81675732da
|
||||
go.opencensus.io v0.23.0
|
||||
|
@ -72,7 +70,6 @@ require (
|
|||
google.golang.org/genproto v0.0.0-20221018160656-63c7b68cfc55
|
||||
google.golang.org/grpc v1.50.1
|
||||
google.golang.org/protobuf v1.28.1
|
||||
gopkg.in/auth0.v5 v5.21.1
|
||||
gopkg.in/yaml.v3 v3.0.1
|
||||
namespacelabs.dev/go-filenotify v0.0.0-20220511192020-53ea11be7eaa
|
||||
sigs.k8s.io/yaml v1.3.0
|
||||
|
@ -94,7 +91,6 @@ require (
|
|||
github.com/Nvveen/Gotty v0.0.0-20120604004816-cd527374f1e5 // indirect
|
||||
github.com/OneOfOne/xxhash v1.2.8 // indirect
|
||||
github.com/OpenPeeDeeP/depguard v1.1.1 // indirect
|
||||
github.com/PuerkitoBio/rehttp v1.0.0 // indirect
|
||||
github.com/agnivade/levenshtein v1.1.1 // indirect
|
||||
github.com/alexkohler/prealloc v1.0.0 // indirect
|
||||
github.com/alingse/asasalint v0.0.11 // indirect
|
||||
|
@ -160,7 +156,6 @@ require (
|
|||
github.com/google/go-tpm v0.3.2 // indirect
|
||||
github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510 // indirect
|
||||
github.com/googleapis/enterprise-certificate-proxy v0.2.0 // indirect
|
||||
github.com/googleapis/gax-go/v2 v2.6.0 // indirect
|
||||
github.com/gordonklaus/ineffassign v0.0.0-20210914165742-4cc7213b9bc8 // indirect
|
||||
github.com/gorilla/securecookie v1.1.1 // indirect
|
||||
github.com/gostaticanalysis/analysisutil v0.7.1 // indirect
|
||||
|
|
22
go.sum
22
go.sum
|
@ -102,13 +102,10 @@ github.com/OneOfOne/xxhash v1.2.8 h1:31czK/TI9sNkxIKfaUfGlU47BAxQ0ztGgd9vPyqimf8
|
|||
github.com/OneOfOne/xxhash v1.2.8/go.mod h1:eZbhyaAYD41SGSSsnmcpxVoRiQ/MPUTjUdIIOT9Um7Q=
|
||||
github.com/OpenPeeDeeP/depguard v1.1.1 h1:TSUznLjvp/4IUP+OQ0t/4jF4QUyxIcVX8YnghZdunyA=
|
||||
github.com/OpenPeeDeeP/depguard v1.1.1/go.mod h1:JtAMzWkmFEzDPyAd+W0NHl1lvpQKTvT9jnRVsohBKpc=
|
||||
github.com/PuerkitoBio/rehttp v1.0.0 h1:aJ7A7YI2lIvOxcJVeUZY4P6R7kKZtLeONjgyKGwOIu8=
|
||||
github.com/PuerkitoBio/rehttp v1.0.0/go.mod h1:ItsOiHl4XeMOV3rzbZqQRjLc3QQxbE6391/9iNG7rE8=
|
||||
github.com/Shopify/sarama v1.19.0/go.mod h1:FVkBWblsNy7DGZRfXLU0O9RCGt5g3g3yEuWXgklEdEo=
|
||||
github.com/Shopify/toxiproxy v2.1.4+incompatible/go.mod h1:OXgGpZ6Cli1/URJOF1DMxUHB2q5Ap20/P/eIdh4G0pI=
|
||||
github.com/VictoriaMetrics/fastcache v1.12.0 h1:vnVi/y9yKDcD9akmc4NqAoqgQhJrOwUF+j9LTgn4QDE=
|
||||
github.com/VictoriaMetrics/fastcache v1.12.0/go.mod h1:tjiYeEfYXCqacuvYw/7UoDIeJaNxq6132xHICNP77w8=
|
||||
github.com/agnivade/levenshtein v1.0.1/go.mod h1:CURSv5d9Uaml+FovSIICkLbAUZ9S4RqaHDIsdSBg7lM=
|
||||
github.com/agnivade/levenshtein v1.1.1 h1:QY8M92nrzkmr798gCo3kmMyqXFzdQVpxLlGPRBij0P8=
|
||||
github.com/agnivade/levenshtein v1.1.1/go.mod h1:veldBMzWxcCG2ZvUTKD2kJNRdCk5hVbJomOvKkmgYbo=
|
||||
github.com/alecthomas/template v0.0.0-20160405071501-a0175ee3bccc/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc=
|
||||
|
@ -123,8 +120,6 @@ github.com/alingse/asasalint v0.0.11 h1:SFwnQXJ49Kx/1GghOFz1XGqHYKp21Kq1nHad/0WQ
|
|||
github.com/alingse/asasalint v0.0.11/go.mod h1:nCaoMhw7a9kSJObvQyVzNTPBDbNpdocqrSP7t/cW5+I=
|
||||
github.com/allegro/bigcache v1.2.1-0.20190218064605-e24eb225f156 h1:eMwmnE/GDgah4HI848JfFxHt+iPb26b4zyfspmqY0/8=
|
||||
github.com/allegro/bigcache v1.2.1-0.20190218064605-e24eb225f156/go.mod h1:Cb/ax3seSYIx7SuZdm2G2xzfwmv3TPSk2ucNfQESPXM=
|
||||
github.com/andreyvit/diff v0.0.0-20170406064948-c7f18ee00883 h1:bvNMNQO63//z+xNgfBlViaCIJKLlCJ6/fmUseuG0wVQ=
|
||||
github.com/andreyvit/diff v0.0.0-20170406064948-c7f18ee00883/go.mod h1:rCTlJbsFo29Kk6CurOXKm700vrz8f0KW0JNfpkRJY/8=
|
||||
github.com/andybalholm/brotli v1.0.4 h1:V7DdXeJtZscaqfNuAdSRuRFzuiKlHSC/Zh3zl9qY3JY=
|
||||
github.com/andybalholm/brotli v1.0.4/go.mod h1:fO7iG3H7G2nSZ7m0zPUDn85XEX2GTukHGRSepvi9Eig=
|
||||
github.com/antihax/optional v1.0.0/go.mod h1:uupD/76wgC+ih3iEmQUL+0Ugr19nfwCT1kdvxnR2qWY=
|
||||
|
@ -135,9 +130,6 @@ github.com/ashanbrown/forbidigo v1.3.0 h1:VkYIwb/xxdireGAdJNZoo24O4lmnEWkactplBl
|
|||
github.com/ashanbrown/forbidigo v1.3.0/go.mod h1:vVW7PEdqEFqapJe95xHkTfB1+XvZXBFg8t0sG2FIxmI=
|
||||
github.com/ashanbrown/makezero v1.1.1 h1:iCQ87C0V0vSyO+M9E/FZYbu65auqH0lnsOkf5FcB28s=
|
||||
github.com/ashanbrown/makezero v1.1.1/go.mod h1:i1bJLCRSCHOcOa9Y6MyF2FTfMZMFdHvxKHxgO5Z1axI=
|
||||
github.com/aybabtme/iocontrol v0.0.0-20150809002002-ad15bcfc95a0 h1:0NmehRCgyk5rljDQLKUO+cRJCnduDyn11+zGZIc9Z48=
|
||||
github.com/aybabtme/iocontrol v0.0.0-20150809002002-ad15bcfc95a0/go.mod h1:6L7zgvqo0idzI7IO8de6ZC051AfXb5ipkIJ7bIA2tGA=
|
||||
github.com/benbjohnson/clock v1.0.3/go.mod h1:bGMdMPoPVvcYyt1gHDf4J2KE153Yf9BuiUKYMaxlTDM=
|
||||
github.com/benbjohnson/clock v1.1.0 h1:Q92kusRqC1XV2MjkWETPvjJVqKetz1OzxZB7mHJLju8=
|
||||
github.com/benbjohnson/clock v1.1.0/go.mod h1:J11/hYXuz8f4ySSvYwY0FKfm+ezbsZBKZxNJlLklBHA=
|
||||
github.com/beorn7/perks v0.0.0-20180321164747-3a771d992973/go.mod h1:Dwedo/Wpr24TaqPxmxbtue+5NUziq4I4S80YR8gNf3Q=
|
||||
|
@ -245,7 +237,6 @@ github.com/docker/go-connections v0.4.0/go.mod h1:Gbd7IOopHjR8Iph03tsViu4nIes5Xh
|
|||
github.com/docker/go-units v0.4.0 h1:3uh0PgVws3nIA0Q+MwDC8yjEPf9zjRfZZWXZYDct3Tw=
|
||||
github.com/docker/go-units v0.4.0/go.mod h1:fgPhTUdO+D/Jk86RDLlptpiXQzgHJF7gydDDbaIK4Dk=
|
||||
github.com/dustin/go-humanize v1.0.0 h1:VSnTsYCnlFHaM2/igO1h6X3HA71jcobQuxemgkq4zYo=
|
||||
github.com/dustin/go-humanize v1.0.0/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25KnS6fMYU6eOk=
|
||||
github.com/eapache/go-resiliency v1.1.0/go.mod h1:kFI+JgMyC7bLPUVY133qvEBtVayf5mFgVsvEsIPBvNs=
|
||||
github.com/eapache/go-xerial-snappy v0.0.0-20180814174437-776d5712da21/go.mod h1:+020luEh2TKB4/GOp8oxxtq0Daoen/Cii55CzbTV6DU=
|
||||
github.com/eapache/queue v1.1.0/go.mod h1:6eCeP0CKFpHLu8blIFXhExK/dRa7WDZfr6jVFPTqq+I=
|
||||
|
@ -482,8 +473,6 @@ github.com/googleapis/gax-go/v2 v2.1.1/go.mod h1:hddJymUZASv3XPyGkUpKj8pPO47Rmb0
|
|||
github.com/googleapis/gax-go/v2 v2.2.0/go.mod h1:as02EH8zWkzwUoLbBaFeQ+arQaj/OthfcblKl4IGNaM=
|
||||
github.com/googleapis/gax-go/v2 v2.3.0/go.mod h1:b8LNqSzNabLiUpXKkY7HAR5jr6bIT99EXz9pXxye9YM=
|
||||
github.com/googleapis/gax-go/v2 v2.4.0/go.mod h1:XOTVJ59hdnfJLIP/dh8n5CGryZR2LxK9wbMD5+iXC6c=
|
||||
github.com/googleapis/gax-go/v2 v2.6.0 h1:SXk3ABtQYDT/OH8jAyvEOQ58mgawq5C4o/4/89qN2ZU=
|
||||
github.com/googleapis/gax-go/v2 v2.6.0/go.mod h1:1mjbznJAPHFpesgE5ucqfYEscaz5kMdcIDwU/6+DDoY=
|
||||
github.com/googleapis/go-type-adapters v1.0.0/go.mod h1:zHW75FOG2aur7gAO2B+MLby+cLsWGBF62rFAi7WjWO4=
|
||||
github.com/googleapis/google-cloud-go-testing v0.0.0-20200911160855-bcd43fbb19e8/go.mod h1:dvDLG8qkwmyD9a/MJJN3XJcT3xFxOKAvTZGvuZmac9g=
|
||||
github.com/gordonklaus/ineffassign v0.0.0-20210914165742-4cc7213b9bc8 h1:PVRE9d4AQKmbelZ7emNig1+NT27DUmKZn5qXxfio54U=
|
||||
|
@ -600,8 +589,6 @@ github.com/jingyugao/rowserrcheck v1.1.1/go.mod h1:4yvlZSDb3IyDTUZJUmpZfm2Hwok+D
|
|||
github.com/jirfag/go-printf-func-name v0.0.0-20200119135958-7558a9eaa5af h1:KA9BjwUk7KlCh6S9EAGWBt1oExIUv9WyNCiRz5amv48=
|
||||
github.com/jirfag/go-printf-func-name v0.0.0-20200119135958-7558a9eaa5af/go.mod h1:HEWGJkRDzjJY2sqdDwxccsGicWEf9BQOZsq2tV+xzM0=
|
||||
github.com/jmoiron/sqlx v1.2.0/go.mod h1:1FEQNm3xlJgrMD+FBdI9+xvCksHtbpVBBw5dYhBSsks=
|
||||
github.com/joho/godotenv v1.3.0 h1:Zjp+RcGpHhGlrMbJzXTrZZPrWj+1vfm90La1wgB6Bhc=
|
||||
github.com/joho/godotenv v1.3.0/go.mod h1:7hK45KPybAkOC6peb+G5yklZfMxEjkZhHbwpqxOKXbg=
|
||||
github.com/jonboulle/clockwork v0.1.0/go.mod h1:Ii8DK3G1RaLaWxj9trq07+26W01tbo22gdxWY5EU2bo=
|
||||
github.com/jpillora/backoff v1.0.0/go.mod h1:J/6gKK9jxlEcS3zixgDgUAsiuZ7yrSoa/FX5e0EB2j4=
|
||||
github.com/json-iterator/go v1.1.6/go.mod h1:+SdeFBvtyEkXs7REEP0seUULqWtbJapLOCVDaaPEHmU=
|
||||
|
@ -886,7 +873,6 @@ github.com/satori/go.uuid v1.2.0/go.mod h1:dA0hQrYB0VpLJoorglMZABFdXlWrHn1NEOzdh
|
|||
github.com/seccomp/libseccomp-golang v0.9.2-0.20210429002308-3879420cc921/go.mod h1:JA8cRccbGaA1s33RQf7Y1+q9gHmZX1yB/z9WDN1C6fg=
|
||||
github.com/securego/gosec/v2 v2.13.1 h1:7mU32qn2dyC81MH9L2kefnQyRMUarfDER3iQyMHcjYM=
|
||||
github.com/securego/gosec/v2 v2.13.1/go.mod h1:EO1sImBMBWFjOTFzMWfTRrZW6M15gm60ljzrmy/wtHo=
|
||||
github.com/sergi/go-diff v1.0.0/go.mod h1:0CfEIISq7TuYL3j771MWULgwwjU+GofnZX9QAmXWZgo=
|
||||
github.com/sergi/go-diff v1.1.0 h1:we8PVUC3FE2uYfodKH/nBHMSetSfHDR6scGdBi+erh0=
|
||||
github.com/sergi/go-diff v1.1.0/go.mod h1:STckp+ISIX8hZLjrqAeVduY0gWCT9IjLuqbuNXdaHfM=
|
||||
github.com/shazow/go-diff v0.0.0-20160112020656-b6b7b6733b8c h1:W65qqJCIOVP4jpqPQ0YvHYKwcMEMVWIzWC5iNQQfBTU=
|
||||
|
@ -994,8 +980,6 @@ github.com/tomarrell/wrapcheck/v2 v2.7.0 h1:J/F8DbSKJC83bAvC6FoZaRjZiZ/iKoueSdrE
|
|||
github.com/tomarrell/wrapcheck/v2 v2.7.0/go.mod h1:ao7l5p0aOlUNJKI0qVwB4Yjlqutd0IvAB9Rdwyilxvg=
|
||||
github.com/tommy-muehle/go-mnd/v2 v2.5.1 h1:NowYhSdyE/1zwK9QCLeRb6USWdoif80Ie+v+yU8u1Zw=
|
||||
github.com/tommy-muehle/go-mnd/v2 v2.5.1/go.mod h1:WsUAkMJMYww6l/ufffCD3m+P7LEvr8TnZn9lwVDlgzw=
|
||||
github.com/tomnomnom/linkheader v0.0.0-20180905144013-02ca5825eb80 h1:nrZ3ySNYwJbSpD6ce9duiP+QkD3JuLCcWkdaehUS/3Y=
|
||||
github.com/tomnomnom/linkheader v0.0.0-20180905144013-02ca5825eb80/go.mod h1:iFyPdL66DjUD96XmzVL3ZntbzcflLnznH0fr99w5VqE=
|
||||
github.com/uber/jaeger-client-go v2.25.0+incompatible h1:IxcNZ7WRY1Y3G4poYlx24szfsn/3LvK9QHCq9oQw8+U=
|
||||
github.com/uber/jaeger-client-go v2.25.0+incompatible/go.mod h1:WVhlPFC8FDjOFMMWRy2pZqQJSXxYSwNYOkTr/Z6d3Kk=
|
||||
github.com/ugorji/go v1.1.4/go.mod h1:uQMGLiO92mf5W77hV/PUCpI3pbzQx3CRekS0kk+RGrc=
|
||||
|
@ -1009,8 +993,6 @@ github.com/uudashr/gocognit v1.0.6 h1:2Cgi6MweCsdB6kpcVQp7EW4U23iBFQWfTXiWlyp842
|
|||
github.com/uudashr/gocognit v1.0.6/go.mod h1:nAIUuVBnYU7pcninia3BHOvQkpQCeO76Uscky5BOwcY=
|
||||
github.com/valyala/gozstd v1.11.0 h1:VV6qQFt+4sBBj9OJ7eKVvsFAMy59Urcs9Lgd+o5FOw0=
|
||||
github.com/valyala/gozstd v1.11.0/go.mod h1:y5Ew47GLlP37EkTB+B4s7r6A5rdaeB7ftbl9zoYiIPQ=
|
||||
github.com/vektah/gqlparser v1.3.1 h1:8b0IcD3qZKWJQHSzynbDlrtP3IxVydZ2DZepCGofqfU=
|
||||
github.com/vektah/gqlparser v1.3.1/go.mod h1:bkVf0FX+Stjg/MHnm8mEyubuaArhNEqfQhF+OTiAL74=
|
||||
github.com/vishvananda/netlink v1.1.0/go.mod h1:cTgwzPIzzgDAYoQrMm0EdrjRUBkTqKYppBueQtXaqoE=
|
||||
github.com/vishvananda/netns v0.0.0-20191106174202-0a2b9b5464df/go.mod h1:JP3t17pCcGlemwknint6hfoeCVQrEMVwxRLRjXpq+BU=
|
||||
github.com/volatiletech/null/v9 v9.0.0 h1:JCdlHEiSRVxOi7/MABiEfdsqmuj9oTV20Ao7VvZ0JkE=
|
||||
|
@ -1367,7 +1349,6 @@ golang.org/x/tools v0.0.0-20180221164845-07fd8470d635/go.mod h1:n7NCudcB/nEzxVGm
|
|||
golang.org/x/tools v0.0.0-20180525024113-a5b4c53f6e8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
||||
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
||||
golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
||||
golang.org/x/tools v0.0.0-20190125232054-d66bd3c5d5a6/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
||||
golang.org/x/tools v0.0.0-20190226205152-f727befe758c/go.mod h1:9Yl7xja0Znq3iFh3HoIrodX9oNMXvdceNzlUR8zjMvY=
|
||||
golang.org/x/tools v0.0.0-20190307163923-6a08e3108db3/go.mod h1:25r3+/G6/xytQM8iWZKq3Hn0kr0rgFKPUNVEL/dr3z4=
|
||||
golang.org/x/tools v0.0.0-20190311212946-11955173bddd/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs=
|
||||
|
@ -1467,7 +1448,6 @@ golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8T
|
|||
golang.org/x/xerrors v0.0.0-20220411194840-2f41105eb62f/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
golang.org/x/xerrors v0.0.0-20220517211312-f3a8303e98df/go.mod h1:K8+ghG5WaK9qNqU5K3HdILfMLy1f3aNYFI/wnl100a8=
|
||||
golang.org/x/xerrors v0.0.0-20220609144429-65e65417b02f/go.mod h1:K8+ghG5WaK9qNqU5K3HdILfMLy1f3aNYFI/wnl100a8=
|
||||
golang.org/x/xerrors v0.0.0-20220907171357-04be3eba64a2 h1:H2TDz8ibqkAF6YGhCdN3jS9O0/s90v0rJh3X/OLHEUk=
|
||||
golang.org/x/xerrors v0.0.0-20220907171357-04be3eba64a2/go.mod h1:K8+ghG5WaK9qNqU5K3HdILfMLy1f3aNYFI/wnl100a8=
|
||||
google.golang.org/api v0.4.0/go.mod h1:8k5glujaEP+g9n7WNsDg8QP6cUVNI86fCNMcbazEtwE=
|
||||
google.golang.org/api v0.7.0/go.mod h1:WtwebWUNSVBH/HAw79HIFXZNqEvBhG+Ra+ax0hx3E3M=
|
||||
|
@ -1657,8 +1637,6 @@ google.golang.org/protobuf v1.28.1/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqw
|
|||
gopkg.in/DataDog/dd-trace-go.v1 v1.22.0 h1:gpWsqqkwUldNZXGJqT69NU9MdEDhLboK1C4nMgR0MWw=
|
||||
gopkg.in/DataDog/dd-trace-go.v1 v1.22.0/go.mod h1:DVp8HmDh8PuTu2Z0fVVlBsyWaC++fzwVCaGWylTe3tg=
|
||||
gopkg.in/alecthomas/kingpin.v2 v2.2.6/go.mod h1:FMv+mEhP44yOT+4EoQTLFTRgOQ1FBLkstjWtayDeSgw=
|
||||
gopkg.in/auth0.v5 v5.21.1 h1:aIqHBmnqaDv4eK2WSpTRsv2dEpT1jdHJPl+iwyDJNoo=
|
||||
gopkg.in/auth0.v5 v5.21.1/go.mod h1:k1eJq1+II4rwUlecBabE7u4igEuzKUCEZAMa11PUfQk=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
|
|
|
@ -363,7 +363,7 @@ func (srv *Server) SyncLatest(req *databroker.SyncLatestRequest, stream databrok
|
|||
}
|
||||
}
|
||||
if recordStream.Err() != nil {
|
||||
return recordStream.Err()
|
||||
return err
|
||||
}
|
||||
|
||||
// always send the server version last in case there are no records
|
||||
|
|
|
@ -1,292 +0,0 @@
|
|||
// Package auth0 contains the Auth0 directory provider.
|
||||
// Note that Auth0 refers to groups as roles.
|
||||
package auth0
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"sort"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
"gopkg.in/auth0.v5/management"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/encoding"
|
||||
"github.com/pomerium/pomerium/internal/log"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/directory"
|
||||
)
|
||||
|
||||
// Name is the provider name.
|
||||
const Name = "auth0"
|
||||
|
||||
type (
|
||||
// RoleManager defines what is needed to get role info from Auth0.
|
||||
RoleManager interface {
|
||||
List(opts ...management.RequestOption) (r *management.RoleList, err error)
|
||||
Users(id string, opts ...management.RequestOption) (u *management.UserList, err error)
|
||||
}
|
||||
// UserManager defines what is needed to get user info from Auth0.
|
||||
UserManager interface {
|
||||
Read(id string, opts ...management.RequestOption) (*management.User, error)
|
||||
Roles(id string, opts ...management.RequestOption) (r *management.RoleList, err error)
|
||||
}
|
||||
)
|
||||
|
||||
type newManagersFunc = func(ctx context.Context, domain string, serviceAccount *ServiceAccount) (RoleManager, UserManager, error)
|
||||
|
||||
func defaultNewManagersFunc(ctx context.Context, domain string, serviceAccount *ServiceAccount) (RoleManager, UserManager, error) {
|
||||
// override the domain for the management api if supplied
|
||||
if serviceAccount.Domain != "" {
|
||||
domain = serviceAccount.Domain
|
||||
}
|
||||
m, err := management.New(domain,
|
||||
management.WithClientCredentials(serviceAccount.ClientID, serviceAccount.Secret),
|
||||
management.WithContext(ctx))
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("auth0: could not build management: %w", err)
|
||||
}
|
||||
return m.Role, m.User, nil
|
||||
}
|
||||
|
||||
type config struct {
|
||||
domain string
|
||||
serviceAccount *ServiceAccount
|
||||
newManagers newManagersFunc
|
||||
}
|
||||
|
||||
// Option provides config for the Auth0 Provider.
|
||||
type Option func(cfg *config)
|
||||
|
||||
// WithServiceAccount sets the service account option.
|
||||
func WithServiceAccount(serviceAccount *ServiceAccount) Option {
|
||||
return func(cfg *config) {
|
||||
cfg.serviceAccount = serviceAccount
|
||||
}
|
||||
}
|
||||
|
||||
// WithDomain sets the provider domain option.
|
||||
func WithDomain(domain string) Option {
|
||||
return func(cfg *config) {
|
||||
cfg.domain = domain
|
||||
}
|
||||
}
|
||||
|
||||
func getConfig(options ...Option) *config {
|
||||
cfg := &config{
|
||||
newManagers: defaultNewManagersFunc,
|
||||
}
|
||||
for _, option := range options {
|
||||
option(cfg)
|
||||
}
|
||||
return cfg
|
||||
}
|
||||
|
||||
// Provider is an Auth0 user group directory provider.
|
||||
type Provider struct {
|
||||
cfg *config
|
||||
log zerolog.Logger
|
||||
}
|
||||
|
||||
// New creates a new Provider.
|
||||
func New(options ...Option) *Provider {
|
||||
return &Provider{
|
||||
cfg: getConfig(options...),
|
||||
log: log.With().Str("service", "directory").Str("provider", "auth0").Logger(),
|
||||
}
|
||||
}
|
||||
|
||||
func (p *Provider) getManagers(ctx context.Context) (RoleManager, UserManager, error) {
|
||||
return p.cfg.newManagers(ctx, p.cfg.domain, p.cfg.serviceAccount)
|
||||
}
|
||||
|
||||
// User returns the user record for the given id.
|
||||
func (p *Provider) User(ctx context.Context, userID, accessToken string) (*directory.User, error) {
|
||||
_, um, err := p.getManagers(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("auth0: could not get the role manager: %w", err)
|
||||
}
|
||||
|
||||
du := &directory.User{
|
||||
Id: userID,
|
||||
}
|
||||
|
||||
u, err := um.Read(userID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("auth0: error getting user info: %w", err)
|
||||
}
|
||||
du.DisplayName = u.GetName()
|
||||
du.Email = u.GetEmail()
|
||||
|
||||
for page, hasNext := 0, true; hasNext; page++ {
|
||||
rl, err := um.Roles(userID, management.IncludeTotals(true), management.Page(page))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("auth0: error getting user roles: %w", err)
|
||||
}
|
||||
|
||||
for _, role := range rl.Roles {
|
||||
du.GroupIds = append(du.GroupIds, role.GetID())
|
||||
}
|
||||
|
||||
hasNext = rl.HasNext()
|
||||
}
|
||||
|
||||
sort.Strings(du.GroupIds)
|
||||
return du, nil
|
||||
}
|
||||
|
||||
// UserGroups fetches a slice of groups and users.
|
||||
func (p *Provider) UserGroups(ctx context.Context) ([]*directory.Group, []*directory.User, error) {
|
||||
rm, _, err := p.getManagers(ctx)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("auth0: could not get the role manager: %w", err)
|
||||
}
|
||||
|
||||
roles, err := getRoles(rm)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("auth0: %w", err)
|
||||
}
|
||||
|
||||
userIDToGroups := map[string][]string{}
|
||||
for _, role := range roles {
|
||||
ids, err := getRoleUserIDs(rm, role.Id)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("auth0: %w", err)
|
||||
}
|
||||
|
||||
for _, id := range ids {
|
||||
userIDToGroups[id] = append(userIDToGroups[id], role.Id)
|
||||
}
|
||||
}
|
||||
|
||||
var users []*directory.User
|
||||
for userID, groups := range userIDToGroups {
|
||||
sort.Strings(groups)
|
||||
users = append(users, &directory.User{
|
||||
Id: userID,
|
||||
GroupIds: groups,
|
||||
})
|
||||
}
|
||||
sort.Slice(users, func(i, j int) bool {
|
||||
return users[i].Id < users[j].Id
|
||||
})
|
||||
return roles, users, nil
|
||||
}
|
||||
|
||||
func getRoles(rm RoleManager) ([]*directory.Group, error) {
|
||||
roles := []*directory.Group{}
|
||||
|
||||
shouldContinue := true
|
||||
page := 0
|
||||
|
||||
for shouldContinue {
|
||||
listRes, err := rm.List(management.IncludeTotals(true), management.Page(page))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("could not list roles: %w", err)
|
||||
}
|
||||
|
||||
for _, role := range listRes.Roles {
|
||||
roles = append(roles, &directory.Group{
|
||||
Id: *role.ID,
|
||||
Name: *role.Name,
|
||||
})
|
||||
}
|
||||
|
||||
page++
|
||||
shouldContinue = listRes.HasNext()
|
||||
}
|
||||
|
||||
sort.Slice(roles, func(i, j int) bool {
|
||||
return roles[i].GetId() < roles[j].GetId()
|
||||
})
|
||||
return roles, nil
|
||||
}
|
||||
|
||||
func getRoleUserIDs(rm RoleManager, roleID string) ([]string, error) {
|
||||
var ids []string
|
||||
|
||||
shouldContinue := true
|
||||
page := 0
|
||||
|
||||
for shouldContinue {
|
||||
usersRes, err := rm.Users(roleID, management.IncludeTotals(true), management.Page(page))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("could not get users for role %q: %w", roleID, err)
|
||||
}
|
||||
|
||||
for _, user := range usersRes.Users {
|
||||
ids = append(ids, *user.ID)
|
||||
}
|
||||
|
||||
page++
|
||||
shouldContinue = usersRes.HasNext()
|
||||
}
|
||||
|
||||
sort.Strings(ids)
|
||||
return ids, nil
|
||||
}
|
||||
|
||||
// A ServiceAccount is used by the Auth0 provider to query the API.
|
||||
type ServiceAccount struct {
|
||||
Domain string `json:"domain"`
|
||||
ClientID string `json:"client_id"`
|
||||
ClientSecret string `json:"client_secret"`
|
||||
Secret string `json:"secret"`
|
||||
}
|
||||
|
||||
// ParseServiceAccount parses the service account in the config options.
|
||||
func ParseServiceAccount(options directory.Options) (*ServiceAccount, error) {
|
||||
if options.ServiceAccount != "" {
|
||||
return parseServiceAccountFromString(options.ServiceAccount)
|
||||
}
|
||||
return parseServiceAccountFromOptions(options.ClientID, options.ClientSecret)
|
||||
}
|
||||
|
||||
func parseServiceAccountFromOptions(clientID, clientSecret string) (*ServiceAccount, error) {
|
||||
serviceAccount := ServiceAccount{
|
||||
ClientID: clientID,
|
||||
Secret: clientSecret,
|
||||
}
|
||||
|
||||
if serviceAccount.ClientID == "" {
|
||||
return nil, fmt.Errorf("auth0: client_id is required")
|
||||
}
|
||||
|
||||
// for backwards compatibility we support secret and client_secret
|
||||
if serviceAccount.Secret == "" {
|
||||
serviceAccount.Secret = serviceAccount.ClientSecret
|
||||
}
|
||||
if serviceAccount.ClientSecret == "" {
|
||||
serviceAccount.ClientSecret = serviceAccount.Secret
|
||||
}
|
||||
|
||||
if serviceAccount.Secret == "" {
|
||||
return nil, fmt.Errorf("auth0: client_secret is required")
|
||||
}
|
||||
|
||||
return &serviceAccount, nil
|
||||
}
|
||||
|
||||
func parseServiceAccountFromString(rawServiceAccount string) (*ServiceAccount, error) {
|
||||
var serviceAccount ServiceAccount
|
||||
if err := encoding.DecodeBase64OrJSON(rawServiceAccount, &serviceAccount); err != nil {
|
||||
return nil, fmt.Errorf("auth0: could not unmarshal json: %w", err)
|
||||
}
|
||||
|
||||
if serviceAccount.ClientID == "" {
|
||||
return nil, errors.New("auth0: client_id is required")
|
||||
}
|
||||
|
||||
// for backwards compatibility we support secret and client_secret
|
||||
if serviceAccount.Secret == "" {
|
||||
serviceAccount.Secret = serviceAccount.ClientSecret
|
||||
}
|
||||
if serviceAccount.ClientSecret == "" {
|
||||
serviceAccount.ClientSecret = serviceAccount.Secret
|
||||
}
|
||||
|
||||
if serviceAccount.Secret == "" {
|
||||
return nil, errors.New("auth0: secret is required")
|
||||
}
|
||||
|
||||
return &serviceAccount, nil
|
||||
}
|
|
@ -1,563 +0,0 @@
|
|||
package auth0
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/go-chi/chi/v5/middleware"
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"gopkg.in/auth0.v5/management"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/directory/auth0/mock_auth0"
|
||||
"github.com/pomerium/pomerium/internal/testutil"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/directory"
|
||||
)
|
||||
|
||||
type M = map[string]interface{}
|
||||
|
||||
func newMockAPI(t *testing.T, srv *httptest.Server) http.Handler {
|
||||
r := chi.NewRouter()
|
||||
r.Use(middleware.Logger)
|
||||
r.Post("/oauth/token", func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_ = json.NewEncoder(w).Encode(M{
|
||||
"access_token": "ACCESSTOKEN",
|
||||
"token_type": "Bearer",
|
||||
"refresh_token": "REFRESHTOKEN",
|
||||
})
|
||||
})
|
||||
r.Route("/api/v2", func(r chi.Router) {
|
||||
r.Route("/users", func(r chi.Router) {
|
||||
r.Route("/{user_id}", func(r chi.Router) {
|
||||
r.Get("/roles", func(w http.ResponseWriter, r *http.Request) {
|
||||
switch chi.URLParam(r, "user_id") {
|
||||
case "user1":
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_ = json.NewEncoder(w).Encode(M{
|
||||
"total": 2,
|
||||
"limit": 2,
|
||||
"roles": []M{
|
||||
{"id": "role1"},
|
||||
{"id": "role2"},
|
||||
},
|
||||
})
|
||||
default:
|
||||
http.Error(w, "not found", http.StatusNotFound)
|
||||
}
|
||||
})
|
||||
r.Get("/", func(w http.ResponseWriter, r *http.Request) {
|
||||
switch chi.URLParam(r, "user_id") {
|
||||
case "user1":
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_ = json.NewEncoder(w).Encode(M{
|
||||
"user_id": "user1",
|
||||
"email": "user1@example.com",
|
||||
"name": "User 1",
|
||||
})
|
||||
default:
|
||||
http.Error(w, "not found", http.StatusNotFound)
|
||||
}
|
||||
})
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
return r
|
||||
}
|
||||
|
||||
func TestProvider_User(t *testing.T) {
|
||||
ctx, clearTimeout := context.WithTimeout(context.Background(), time.Minute)
|
||||
defer clearTimeout()
|
||||
|
||||
orig := http.DefaultTransport
|
||||
defer func() {
|
||||
http.DefaultTransport = orig
|
||||
}()
|
||||
http.DefaultTransport = &http.Transport{
|
||||
TLSClientConfig: &tls.Config{
|
||||
InsecureSkipVerify: true,
|
||||
},
|
||||
}
|
||||
|
||||
var mockAPI http.Handler
|
||||
srv := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
mockAPI.ServeHTTP(w, r)
|
||||
}))
|
||||
srv.StartTLS()
|
||||
defer srv.Close()
|
||||
mockAPI = newMockAPI(t, srv)
|
||||
|
||||
p := New(
|
||||
WithDomain(srv.URL),
|
||||
WithServiceAccount(&ServiceAccount{ClientID: "CLIENT_ID", Secret: "SECRET"}),
|
||||
)
|
||||
du, err := p.User(ctx, "user1", "")
|
||||
if !assert.NoError(t, err) {
|
||||
return
|
||||
}
|
||||
testutil.AssertProtoJSONEqual(t, `{
|
||||
"id": "user1",
|
||||
"displayName": "User 1",
|
||||
"email": "user1@example.com",
|
||||
"groupIds": ["role1", "role2"]
|
||||
}`, du)
|
||||
}
|
||||
|
||||
type mockNewRoleManagerFunc struct {
|
||||
CalledWithContext context.Context
|
||||
CalledWithDomain string
|
||||
CalledWithServiceAccount *ServiceAccount
|
||||
|
||||
ReturnRoleManager RoleManager
|
||||
ReturnUserManager UserManager
|
||||
ReturnError error
|
||||
}
|
||||
|
||||
func (m *mockNewRoleManagerFunc) f(ctx context.Context, domain string, serviceAccount *ServiceAccount) (RoleManager, UserManager, error) {
|
||||
m.CalledWithContext = ctx
|
||||
m.CalledWithDomain = domain
|
||||
m.CalledWithServiceAccount = serviceAccount
|
||||
|
||||
return m.ReturnRoleManager, m.ReturnUserManager, m.ReturnError
|
||||
}
|
||||
|
||||
type listOptionMatcher struct {
|
||||
expected management.RequestOption
|
||||
}
|
||||
|
||||
func buildValues(opt management.RequestOption) map[string][]string {
|
||||
req, err := (&management.Management{}).NewRequest("GET", "example.com", nil, opt)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return req.URL.Query()
|
||||
}
|
||||
|
||||
func (lom listOptionMatcher) Matches(actual interface{}) bool {
|
||||
return gomock.Eq(buildValues(lom.expected)).Matches(buildValues(actual.(management.RequestOption)))
|
||||
}
|
||||
|
||||
func (lom listOptionMatcher) String() string {
|
||||
return fmt.Sprintf("is equal to %v", buildValues(lom.expected))
|
||||
}
|
||||
|
||||
func stringPtr(in string) *string {
|
||||
return &in
|
||||
}
|
||||
|
||||
func TestProvider_UserGroups(t *testing.T) {
|
||||
expectedDomain := "login.example.com"
|
||||
expectedServiceAccount := &ServiceAccount{Domain: "login-example.auth0.com", ClientID: "c_id", Secret: "secret"}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
setupRoleManagerExpectations func(*mock_auth0.MockRoleManager)
|
||||
newRoleManagerError error
|
||||
expectedGroups []*directory.Group
|
||||
expectedUsers []*directory.User
|
||||
expectedError error
|
||||
}{
|
||||
{
|
||||
name: "errors if getting the role manager errors",
|
||||
newRoleManagerError: errors.New("new role manager error"),
|
||||
expectedError: errors.New("auth0: could not get the role manager: new role manager error"),
|
||||
},
|
||||
{
|
||||
name: "errors if listing roles errors",
|
||||
setupRoleManagerExpectations: func(mrm *mock_auth0.MockRoleManager) {
|
||||
mrm.EXPECT().List(
|
||||
listOptionMatcher{expected: management.IncludeTotals(true)},
|
||||
listOptionMatcher{expected: management.Page(0)},
|
||||
).Return(nil, errors.New("list error"))
|
||||
},
|
||||
expectedError: errors.New("auth0: could not list roles: list error"),
|
||||
},
|
||||
{
|
||||
name: "errors if getting user ids errors",
|
||||
setupRoleManagerExpectations: func(mrm *mock_auth0.MockRoleManager) {
|
||||
mrm.EXPECT().List(
|
||||
listOptionMatcher{expected: management.IncludeTotals(true)},
|
||||
listOptionMatcher{expected: management.Page(0)},
|
||||
).Return(&management.RoleList{
|
||||
Roles: []*management.Role{
|
||||
{
|
||||
ID: stringPtr("i-am-role-id"),
|
||||
Name: stringPtr("i-am-role-name"),
|
||||
},
|
||||
},
|
||||
}, nil)
|
||||
|
||||
mrm.EXPECT().Users(
|
||||
"i-am-role-id",
|
||||
listOptionMatcher{expected: management.IncludeTotals(true)},
|
||||
listOptionMatcher{expected: management.Page(0)},
|
||||
).Return(nil, errors.New("users error"))
|
||||
},
|
||||
expectedError: errors.New("auth0: could not get users for role \"i-am-role-id\": users error"),
|
||||
},
|
||||
{
|
||||
name: "handles multiple pages of roles",
|
||||
setupRoleManagerExpectations: func(mrm *mock_auth0.MockRoleManager) {
|
||||
mrm.EXPECT().List(
|
||||
listOptionMatcher{expected: management.IncludeTotals(true)},
|
||||
listOptionMatcher{expected: management.Page(0)},
|
||||
).Return(&management.RoleList{
|
||||
List: management.List{
|
||||
Total: 3,
|
||||
Start: 0,
|
||||
Limit: 1,
|
||||
},
|
||||
Roles: []*management.Role{
|
||||
{
|
||||
ID: stringPtr("i-am-role-id-1"),
|
||||
Name: stringPtr("i-am-role-name-1"),
|
||||
},
|
||||
},
|
||||
}, nil)
|
||||
|
||||
mrm.EXPECT().List(
|
||||
listOptionMatcher{expected: management.IncludeTotals(true)},
|
||||
listOptionMatcher{expected: management.Page(1)},
|
||||
).Return(&management.RoleList{
|
||||
List: management.List{
|
||||
Total: 3,
|
||||
Start: 1,
|
||||
Limit: 1,
|
||||
},
|
||||
Roles: []*management.Role{
|
||||
{
|
||||
ID: stringPtr("i-am-role-id-2"),
|
||||
Name: stringPtr("i-am-role-name-2"),
|
||||
},
|
||||
},
|
||||
}, nil)
|
||||
|
||||
mrm.EXPECT().List(
|
||||
listOptionMatcher{expected: management.IncludeTotals(true)},
|
||||
listOptionMatcher{expected: management.Page(2)},
|
||||
).Return(&management.RoleList{
|
||||
List: management.List{
|
||||
Total: 3,
|
||||
Start: 2,
|
||||
Limit: 1,
|
||||
},
|
||||
Roles: []*management.Role{
|
||||
{
|
||||
ID: stringPtr("i-am-role-id-3"),
|
||||
Name: stringPtr("i-am-role-name-3"),
|
||||
},
|
||||
},
|
||||
}, nil)
|
||||
|
||||
mrm.EXPECT().Users(
|
||||
"i-am-role-id-1",
|
||||
listOptionMatcher{expected: management.IncludeTotals(true)},
|
||||
listOptionMatcher{expected: management.Page(0)},
|
||||
).Return(&management.UserList{}, nil)
|
||||
|
||||
mrm.EXPECT().Users(
|
||||
"i-am-role-id-2",
|
||||
listOptionMatcher{expected: management.IncludeTotals(true)},
|
||||
listOptionMatcher{expected: management.Page(0)},
|
||||
).Return(&management.UserList{}, nil)
|
||||
|
||||
mrm.EXPECT().Users(
|
||||
"i-am-role-id-3",
|
||||
listOptionMatcher{expected: management.IncludeTotals(true)},
|
||||
listOptionMatcher{expected: management.Page(0)},
|
||||
).Return(&management.UserList{}, nil)
|
||||
},
|
||||
expectedGroups: []*directory.Group{
|
||||
{
|
||||
Id: "i-am-role-id-1",
|
||||
Name: "i-am-role-name-1",
|
||||
},
|
||||
{
|
||||
Id: "i-am-role-id-2",
|
||||
Name: "i-am-role-name-2",
|
||||
},
|
||||
{
|
||||
Id: "i-am-role-id-3",
|
||||
Name: "i-am-role-name-3",
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "handles multiple pages of users",
|
||||
setupRoleManagerExpectations: func(mrm *mock_auth0.MockRoleManager) {
|
||||
mrm.EXPECT().List(
|
||||
listOptionMatcher{expected: management.IncludeTotals(true)},
|
||||
listOptionMatcher{expected: management.Page(0)},
|
||||
).Return(&management.RoleList{
|
||||
Roles: []*management.Role{
|
||||
{
|
||||
ID: stringPtr("i-am-role-id-1"),
|
||||
Name: stringPtr("i-am-role-name-1"),
|
||||
},
|
||||
},
|
||||
}, nil)
|
||||
|
||||
mrm.EXPECT().Users(
|
||||
"i-am-role-id-1",
|
||||
listOptionMatcher{expected: management.IncludeTotals(true)},
|
||||
listOptionMatcher{expected: management.Page(0)},
|
||||
).Return(&management.UserList{
|
||||
List: management.List{
|
||||
Total: 3,
|
||||
Start: 0,
|
||||
Limit: 1,
|
||||
},
|
||||
Users: []*management.User{
|
||||
{ID: stringPtr("i-am-user-id-1")},
|
||||
},
|
||||
}, nil)
|
||||
|
||||
mrm.EXPECT().Users(
|
||||
"i-am-role-id-1",
|
||||
listOptionMatcher{expected: management.IncludeTotals(true)},
|
||||
listOptionMatcher{expected: management.Page(1)},
|
||||
).Return(&management.UserList{
|
||||
List: management.List{
|
||||
Total: 3,
|
||||
Start: 1,
|
||||
Limit: 1,
|
||||
},
|
||||
Users: []*management.User{
|
||||
{ID: stringPtr("i-am-user-id-2")},
|
||||
},
|
||||
}, nil)
|
||||
|
||||
mrm.EXPECT().Users(
|
||||
"i-am-role-id-1",
|
||||
listOptionMatcher{expected: management.IncludeTotals(true)},
|
||||
listOptionMatcher{expected: management.Page(2)},
|
||||
).Return(&management.UserList{
|
||||
List: management.List{
|
||||
Total: 3,
|
||||
Start: 2,
|
||||
Limit: 1,
|
||||
},
|
||||
Users: []*management.User{
|
||||
{ID: stringPtr("i-am-user-id-3")},
|
||||
},
|
||||
}, nil)
|
||||
},
|
||||
expectedGroups: []*directory.Group{
|
||||
{
|
||||
Id: "i-am-role-id-1",
|
||||
Name: "i-am-role-name-1",
|
||||
},
|
||||
},
|
||||
expectedUsers: []*directory.User{
|
||||
{
|
||||
Id: "i-am-user-id-1",
|
||||
GroupIds: []string{"i-am-role-id-1"},
|
||||
},
|
||||
{
|
||||
Id: "i-am-user-id-2",
|
||||
GroupIds: []string{"i-am-role-id-1"},
|
||||
},
|
||||
{
|
||||
Id: "i-am-user-id-3",
|
||||
GroupIds: []string{"i-am-role-id-1"},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "correctly builds out groups and users",
|
||||
setupRoleManagerExpectations: func(mrm *mock_auth0.MockRoleManager) {
|
||||
mrm.EXPECT().List(
|
||||
listOptionMatcher{expected: management.IncludeTotals(true)},
|
||||
listOptionMatcher{expected: management.Page(0)},
|
||||
).Return(&management.RoleList{
|
||||
List: management.List{
|
||||
Total: 2,
|
||||
Start: 0,
|
||||
Limit: 1,
|
||||
},
|
||||
Roles: []*management.Role{
|
||||
{
|
||||
ID: stringPtr("i-am-role-id-1"),
|
||||
Name: stringPtr("i-am-role-name-1"),
|
||||
},
|
||||
},
|
||||
}, nil)
|
||||
|
||||
mrm.EXPECT().List(
|
||||
listOptionMatcher{expected: management.IncludeTotals(true)},
|
||||
listOptionMatcher{expected: management.Page(1)},
|
||||
).Return(&management.RoleList{
|
||||
List: management.List{
|
||||
Total: 2,
|
||||
Start: 1,
|
||||
Limit: 1,
|
||||
},
|
||||
Roles: []*management.Role{
|
||||
{
|
||||
ID: stringPtr("i-am-role-id-2"),
|
||||
Name: stringPtr("i-am-role-name-2"),
|
||||
},
|
||||
},
|
||||
}, nil)
|
||||
|
||||
mrm.EXPECT().Users(
|
||||
"i-am-role-id-1",
|
||||
listOptionMatcher{expected: management.IncludeTotals(true)},
|
||||
listOptionMatcher{expected: management.Page(0)},
|
||||
).Return(&management.UserList{
|
||||
Users: []*management.User{
|
||||
{ID: stringPtr("i-am-user-id-4")},
|
||||
{ID: stringPtr("i-am-user-id-3")},
|
||||
{ID: stringPtr("i-am-user-id-2")},
|
||||
{ID: stringPtr("i-am-user-id-1")},
|
||||
},
|
||||
}, nil)
|
||||
|
||||
mrm.EXPECT().Users(
|
||||
"i-am-role-id-2",
|
||||
listOptionMatcher{expected: management.IncludeTotals(true)},
|
||||
listOptionMatcher{expected: management.Page(0)},
|
||||
).Return(&management.UserList{
|
||||
Users: []*management.User{
|
||||
{ID: stringPtr("i-am-user-id-1")},
|
||||
{ID: stringPtr("i-am-user-id-4")},
|
||||
{ID: stringPtr("i-am-user-id-5")},
|
||||
},
|
||||
}, nil)
|
||||
},
|
||||
expectedGroups: []*directory.Group{
|
||||
{
|
||||
Id: "i-am-role-id-1",
|
||||
Name: "i-am-role-name-1",
|
||||
},
|
||||
{
|
||||
Id: "i-am-role-id-2",
|
||||
Name: "i-am-role-name-2",
|
||||
},
|
||||
},
|
||||
expectedUsers: []*directory.User{
|
||||
{
|
||||
Id: "i-am-user-id-1",
|
||||
GroupIds: []string{"i-am-role-id-1", "i-am-role-id-2"},
|
||||
},
|
||||
{
|
||||
Id: "i-am-user-id-2",
|
||||
GroupIds: []string{"i-am-role-id-1"},
|
||||
},
|
||||
{
|
||||
Id: "i-am-user-id-3",
|
||||
GroupIds: []string{"i-am-role-id-1"},
|
||||
},
|
||||
{
|
||||
Id: "i-am-user-id-4",
|
||||
GroupIds: []string{"i-am-role-id-1", "i-am-role-id-2"},
|
||||
},
|
||||
{
|
||||
Id: "i-am-user-id-5",
|
||||
GroupIds: []string{"i-am-role-id-2"},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
mRoleManager := mock_auth0.NewMockRoleManager(ctrl)
|
||||
|
||||
mNewManagersFunc := mockNewRoleManagerFunc{
|
||||
ReturnRoleManager: mRoleManager,
|
||||
ReturnError: tc.newRoleManagerError,
|
||||
}
|
||||
|
||||
if tc.setupRoleManagerExpectations != nil {
|
||||
tc.setupRoleManagerExpectations(mRoleManager)
|
||||
}
|
||||
|
||||
p := New(
|
||||
WithDomain(expectedDomain),
|
||||
WithServiceAccount(expectedServiceAccount),
|
||||
)
|
||||
p.cfg.newManagers = mNewManagersFunc.f
|
||||
|
||||
actualGroups, actualUsers, err := p.UserGroups(context.Background())
|
||||
if tc.expectedError != nil {
|
||||
assert.EqualError(t, err, tc.expectedError.Error())
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
assert.Equal(t, tc.expectedGroups, actualGroups)
|
||||
assert.Equal(t, tc.expectedUsers, actualUsers)
|
||||
|
||||
assert.Equal(t, expectedDomain, mNewManagersFunc.CalledWithDomain)
|
||||
assert.Equal(t, expectedServiceAccount, mNewManagersFunc.CalledWithServiceAccount)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseServiceAccount(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
rawServiceAccount string
|
||||
expectedServiceAccount *ServiceAccount
|
||||
expectedError error
|
||||
}{
|
||||
{
|
||||
"valid", "eyJjbGllbnRfaWQiOiJpLWFtLWNsaWVudC1pZCIsInNlY3JldCI6ImktYW0tc2VjcmV0In0K",
|
||||
&ServiceAccount{
|
||||
ClientID: "i-am-client-id",
|
||||
ClientSecret: "i-am-secret",
|
||||
Secret: "i-am-secret",
|
||||
},
|
||||
nil,
|
||||
},
|
||||
{
|
||||
"json", `{"client_id": "i-am-client-id", "client_secret": "i-am-secret"}`,
|
||||
&ServiceAccount{
|
||||
ClientID: "i-am-client-id",
|
||||
ClientSecret: "i-am-secret",
|
||||
Secret: "i-am-secret",
|
||||
},
|
||||
nil,
|
||||
},
|
||||
{
|
||||
"domain", `{"domain": "example.auth0.com", "client_id": "i-am-client-id", "client_secret": "i-am-secret"}`,
|
||||
&ServiceAccount{
|
||||
ClientID: "i-am-client-id",
|
||||
ClientSecret: "i-am-secret",
|
||||
Domain: "example.auth0.com",
|
||||
Secret: "i-am-secret",
|
||||
},
|
||||
nil,
|
||||
},
|
||||
{"base64 err", "!!!!", nil, errors.New("auth0: could not unmarshal json: illegal base64 data at input byte 0")},
|
||||
{"json err", "PAo=", nil, errors.New("auth0: could not unmarshal json: invalid character '<' looking for beginning of value")},
|
||||
{"no client_id", "eyJzZWNyZXQiOiJpLWFtLXNlY3JldCJ9Cg==", nil, errors.New("auth0: client_id is required")},
|
||||
{"no secret", "eyJjbGllbnRfaWQiOiJpLWFtLWNsaWVudC1pZCJ9Cg==", nil, errors.New("auth0: secret is required")},
|
||||
}
|
||||
for _, tc := range tests {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
actualServiceAccount, err := ParseServiceAccount(directory.Options{ServiceAccount: tc.rawServiceAccount})
|
||||
if tc.expectedError != nil {
|
||||
assert.EqualError(t, err, tc.expectedError.Error())
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
assert.Equal(t, tc.expectedServiceAccount, actualServiceAccount)
|
||||
})
|
||||
}
|
||||
}
|
|
@ -1,73 +0,0 @@
|
|||
// Code generated by MockGen. DO NOT EDIT.
|
||||
// Source: github.com/pomerium/pomerium/internal/directory/auth0 (interfaces: RoleManager)
|
||||
|
||||
// Package mock_auth0 is a generated GoMock package.
|
||||
package mock_auth0
|
||||
|
||||
import (
|
||||
gomock "github.com/golang/mock/gomock"
|
||||
management "gopkg.in/auth0.v5/management"
|
||||
reflect "reflect"
|
||||
)
|
||||
|
||||
// MockRoleManager is a mock of RoleManager interface
|
||||
type MockRoleManager struct {
|
||||
ctrl *gomock.Controller
|
||||
recorder *MockRoleManagerMockRecorder
|
||||
}
|
||||
|
||||
// MockRoleManagerMockRecorder is the mock recorder for MockRoleManager
|
||||
type MockRoleManagerMockRecorder struct {
|
||||
mock *MockRoleManager
|
||||
}
|
||||
|
||||
// NewMockRoleManager creates a new mock instance
|
||||
func NewMockRoleManager(ctrl *gomock.Controller) *MockRoleManager {
|
||||
mock := &MockRoleManager{ctrl: ctrl}
|
||||
mock.recorder = &MockRoleManagerMockRecorder{mock}
|
||||
return mock
|
||||
}
|
||||
|
||||
// EXPECT returns an object that allows the caller to indicate expected use
|
||||
func (m *MockRoleManager) EXPECT() *MockRoleManagerMockRecorder {
|
||||
return m.recorder
|
||||
}
|
||||
|
||||
// List mocks base method
|
||||
func (m *MockRoleManager) List(arg0 ...management.RequestOption) (*management.RoleList, error) {
|
||||
m.ctrl.T.Helper()
|
||||
varargs := []interface{}{}
|
||||
for _, a := range arg0 {
|
||||
varargs = append(varargs, a)
|
||||
}
|
||||
ret := m.ctrl.Call(m, "List", varargs...)
|
||||
ret0, _ := ret[0].(*management.RoleList)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// List indicates an expected call of List
|
||||
func (mr *MockRoleManagerMockRecorder) List(arg0 ...interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "List", reflect.TypeOf((*MockRoleManager)(nil).List), arg0...)
|
||||
}
|
||||
|
||||
// Users mocks base method
|
||||
func (m *MockRoleManager) Users(arg0 string, arg1 ...management.RequestOption) (*management.UserList, error) {
|
||||
m.ctrl.T.Helper()
|
||||
varargs := []interface{}{arg0}
|
||||
for _, a := range arg1 {
|
||||
varargs = append(varargs, a)
|
||||
}
|
||||
ret := m.ctrl.Call(m, "Users", varargs...)
|
||||
ret0, _ := ret[0].(*management.UserList)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// Users indicates an expected call of Users
|
||||
func (mr *MockRoleManagerMockRecorder) Users(arg0 interface{}, arg1 ...interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
varargs := append([]interface{}{arg0}, arg1...)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Users", reflect.TypeOf((*MockRoleManager)(nil).Users), varargs...)
|
||||
}
|
|
@ -1,47 +0,0 @@
|
|||
package azure
|
||||
|
||||
import "strings"
|
||||
|
||||
type (
|
||||
apiGetUserResponse struct {
|
||||
apiUser
|
||||
}
|
||||
apiGetUserMembersResponse struct {
|
||||
Context string `json:"@odata.context"`
|
||||
NextLink string `json:"@odata.nextLink,omitempty"`
|
||||
Value []apiGroup `json:"value"`
|
||||
}
|
||||
|
||||
apiGroup struct {
|
||||
ID string `json:"id"`
|
||||
DisplayName string `json:"displayName"`
|
||||
}
|
||||
apiUser struct {
|
||||
ID string `json:"id"`
|
||||
DisplayName string `json:"displayName"`
|
||||
Mail string `json:"mail"`
|
||||
UserPrincipalName string `json:"userPrincipalName"`
|
||||
}
|
||||
)
|
||||
|
||||
func (obj apiUser) getEmail() string {
|
||||
if obj.Mail != "" {
|
||||
return obj.Mail
|
||||
}
|
||||
|
||||
// AD often doesn't have the email address returned, but we can parse it from the UPN
|
||||
|
||||
// UPN looks like either:
|
||||
// cdoxsey_pomerium.com#EXT#@cdoxseypomerium.onmicrosoft.com
|
||||
// cdoxsey@pomerium.com
|
||||
email := obj.UserPrincipalName
|
||||
if idx := strings.Index(email, "#EXT"); idx > 0 {
|
||||
email = email[:idx]
|
||||
|
||||
// find the last _ and replace it with @
|
||||
if idx := strings.LastIndex(email, "_"); idx > 0 {
|
||||
email = email[:idx] + "@" + email[idx+1:]
|
||||
}
|
||||
}
|
||||
return email
|
||||
}
|
|
@ -1,334 +0,0 @@
|
|||
// Package azure contains an azure active directory directory provider.
|
||||
package azure
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"sort"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
"golang.org/x/oauth2"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/encoding"
|
||||
"github.com/pomerium/pomerium/internal/httputil"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/directory"
|
||||
)
|
||||
|
||||
// Name is the provider name.
|
||||
const Name = "azure"
|
||||
|
||||
const (
|
||||
defaultGraphHost = "graph.microsoft.com"
|
||||
|
||||
defaultLoginHost = "login.microsoftonline.com"
|
||||
defaultLoginScope = "https://graph.microsoft.com/.default"
|
||||
defaultLoginGrantType = "client_credentials"
|
||||
)
|
||||
|
||||
type config struct {
|
||||
graphURL *url.URL
|
||||
httpClient *http.Client
|
||||
loginURL *url.URL
|
||||
serviceAccount *ServiceAccount
|
||||
}
|
||||
|
||||
// An Option updates the provider configuration.
|
||||
type Option func(*config)
|
||||
|
||||
// WithGraphURL sets the graph URL for the configuration.
|
||||
func WithGraphURL(graphURL *url.URL) Option {
|
||||
return func(cfg *config) {
|
||||
cfg.graphURL = graphURL
|
||||
}
|
||||
}
|
||||
|
||||
// WithLoginURL sets the login URL for the configuration.
|
||||
func WithLoginURL(loginURL *url.URL) Option {
|
||||
return func(cfg *config) {
|
||||
cfg.loginURL = loginURL
|
||||
}
|
||||
}
|
||||
|
||||
// WithHTTPClient sets the http client to use for requests to the Azure APIs.
|
||||
func WithHTTPClient(httpClient *http.Client) Option {
|
||||
return func(cfg *config) {
|
||||
cfg.httpClient = httputil.NewLoggingClient(httpClient, "azure_idp_client",
|
||||
func(evt *zerolog.Event) *zerolog.Event {
|
||||
return evt.Str("provider", "azure")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// WithServiceAccount sets the service account to use to access Azure.
|
||||
func WithServiceAccount(serviceAccount *ServiceAccount) Option {
|
||||
return func(cfg *config) {
|
||||
cfg.serviceAccount = serviceAccount
|
||||
}
|
||||
}
|
||||
|
||||
func getConfig(options ...Option) *config {
|
||||
cfg := new(config)
|
||||
WithGraphURL(&url.URL{
|
||||
Scheme: "https",
|
||||
Host: defaultGraphHost,
|
||||
})(cfg)
|
||||
WithHTTPClient(http.DefaultClient)(cfg)
|
||||
WithLoginURL(&url.URL{
|
||||
Scheme: "https",
|
||||
Host: defaultLoginHost,
|
||||
})(cfg)
|
||||
for _, option := range options {
|
||||
option(cfg)
|
||||
}
|
||||
return cfg
|
||||
}
|
||||
|
||||
// A Provider is a directory implementation using azure active directory.
|
||||
type Provider struct {
|
||||
cfg *config
|
||||
dc *deltaCollection
|
||||
|
||||
mu sync.RWMutex
|
||||
token *oauth2.Token
|
||||
}
|
||||
|
||||
// New creates a new Provider.
|
||||
func New(options ...Option) *Provider {
|
||||
p := &Provider{
|
||||
cfg: getConfig(options...),
|
||||
}
|
||||
p.dc = newDeltaCollection(p)
|
||||
return p
|
||||
}
|
||||
|
||||
// User returns the user record for the given id.
|
||||
func (p *Provider) User(ctx context.Context, userID, accessToken string) (*directory.User, error) {
|
||||
if p.cfg.serviceAccount == nil {
|
||||
return nil, fmt.Errorf("azure: service account not defined")
|
||||
}
|
||||
|
||||
du := &directory.User{
|
||||
Id: userID,
|
||||
}
|
||||
|
||||
userURL := p.cfg.graphURL.ResolveReference(&url.URL{
|
||||
Path: fmt.Sprintf("/v1.0/users/%s", userID),
|
||||
}).String()
|
||||
|
||||
var u apiGetUserResponse
|
||||
err := p.api(ctx, userURL, &u)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
du.DisplayName = u.DisplayName
|
||||
du.Email = u.getEmail()
|
||||
du.GroupIds, err = p.transitiveMemberOf(ctx, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return du, nil
|
||||
}
|
||||
|
||||
// UserGroups returns the directory users in azure active directory.
|
||||
func (p *Provider) UserGroups(ctx context.Context) ([]*directory.Group, []*directory.User, error) {
|
||||
if p.cfg.serviceAccount == nil {
|
||||
return nil, nil, fmt.Errorf("azure: service account not defined")
|
||||
}
|
||||
|
||||
err := p.dc.Sync(ctx)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
groups, users := p.dc.CurrentUserGroups()
|
||||
return groups, users, nil
|
||||
}
|
||||
|
||||
func (p *Provider) api(ctx context.Context, url string, out interface{}) error {
|
||||
token, err := p.getToken(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("azure: error creating HTTP request: %w", err)
|
||||
}
|
||||
req.Header.Set("Accept", "application/json")
|
||||
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token.AccessToken))
|
||||
|
||||
res, err := p.cfg.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("azure: error making HTTP request: %w", err)
|
||||
}
|
||||
defer res.Body.Close()
|
||||
|
||||
// if we get unauthorized, invalidate the token
|
||||
if res.StatusCode == http.StatusUnauthorized {
|
||||
p.mu.Lock()
|
||||
p.token = nil
|
||||
p.mu.Unlock()
|
||||
}
|
||||
|
||||
if res.StatusCode/100 != 2 {
|
||||
return fmt.Errorf("azure: error querying api (%s): %s", url, res.Status)
|
||||
}
|
||||
|
||||
err = json.NewDecoder(res.Body).Decode(out)
|
||||
if err != nil {
|
||||
return fmt.Errorf("azure: error decoding api response: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *Provider) getToken(ctx context.Context) (*oauth2.Token, error) {
|
||||
p.mu.RLock()
|
||||
token := p.token
|
||||
p.mu.RUnlock()
|
||||
|
||||
if token != nil && token.Valid() {
|
||||
return token, nil
|
||||
}
|
||||
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
|
||||
token = p.token
|
||||
if token != nil && token.Valid() {
|
||||
return token, nil
|
||||
}
|
||||
|
||||
tokenURL := p.cfg.loginURL.ResolveReference(&url.URL{
|
||||
Path: fmt.Sprintf("/%s/oauth2/v2.0/token", p.cfg.serviceAccount.DirectoryID),
|
||||
})
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", tokenURL.String(), strings.NewReader(url.Values{
|
||||
"client_id": {p.cfg.serviceAccount.ClientID},
|
||||
"client_secret": {p.cfg.serviceAccount.ClientSecret},
|
||||
"scope": {defaultLoginScope},
|
||||
"grant_type": {defaultLoginGrantType},
|
||||
}.Encode()))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("azure: error creating HTTP request: %w", err)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
|
||||
res, err := p.cfg.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer res.Body.Close()
|
||||
|
||||
if res.StatusCode/100 != 2 {
|
||||
return nil, fmt.Errorf("azure: error querying oauth2 token: %s", res.Status)
|
||||
}
|
||||
err = json.NewDecoder(res.Body).Decode(&token)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("azure: error decoding oauth2 token: %w", err)
|
||||
}
|
||||
p.token = token
|
||||
|
||||
return p.token, nil
|
||||
}
|
||||
|
||||
func (p *Provider) transitiveMemberOf(ctx context.Context, userID string) (groupIDs []string, err error) {
|
||||
apiURL := p.cfg.graphURL.ResolveReference(&url.URL{
|
||||
Path: fmt.Sprintf("/v1.0/users/%s/transitiveMemberOf", userID),
|
||||
}).String()
|
||||
for {
|
||||
var res apiGetUserMembersResponse
|
||||
err := p.api(ctx, apiURL, &res)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for _, g := range res.Value {
|
||||
groupIDs = append(groupIDs, g.ID)
|
||||
}
|
||||
if res.NextLink == "" {
|
||||
break
|
||||
}
|
||||
apiURL = res.NextLink
|
||||
}
|
||||
sort.Strings(groupIDs)
|
||||
return groupIDs, nil
|
||||
}
|
||||
|
||||
// A ServiceAccount is used by the Azure provider to query the Microsoft Graph API.
|
||||
type ServiceAccount struct {
|
||||
ClientID string `json:"client_id"`
|
||||
ClientSecret string `json:"client_secret"`
|
||||
DirectoryID string `json:"directory_id"`
|
||||
}
|
||||
|
||||
// ParseServiceAccount parses the service account in the config options.
|
||||
func ParseServiceAccount(options directory.Options) (*ServiceAccount, error) {
|
||||
if options.ServiceAccount != "" {
|
||||
return parseServiceAccountFromString(options.ServiceAccount)
|
||||
}
|
||||
return parseServiceAccountFromOptions(options.ClientID, options.ClientSecret, options.ProviderURL)
|
||||
}
|
||||
|
||||
func parseServiceAccountFromOptions(clientID, clientSecret, providerURL string) (*ServiceAccount, error) {
|
||||
serviceAccount := ServiceAccount{
|
||||
ClientID: clientID,
|
||||
ClientSecret: clientSecret,
|
||||
}
|
||||
|
||||
var err error
|
||||
serviceAccount.DirectoryID, err = parseDirectoryIDFromURL(providerURL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if serviceAccount.ClientID == "" {
|
||||
return nil, fmt.Errorf("client_id is required")
|
||||
}
|
||||
if serviceAccount.ClientSecret == "" {
|
||||
return nil, fmt.Errorf("client_secret is required")
|
||||
}
|
||||
if serviceAccount.DirectoryID == "" {
|
||||
return nil, fmt.Errorf("directory_id is required")
|
||||
}
|
||||
|
||||
return &serviceAccount, nil
|
||||
}
|
||||
|
||||
func parseServiceAccountFromString(rawServiceAccount string) (*ServiceAccount, error) {
|
||||
var serviceAccount ServiceAccount
|
||||
if err := encoding.DecodeBase64OrJSON(rawServiceAccount, &serviceAccount); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if serviceAccount.ClientID == "" {
|
||||
return nil, fmt.Errorf("client_id is required")
|
||||
}
|
||||
if serviceAccount.ClientSecret == "" {
|
||||
return nil, fmt.Errorf("client_secret is required")
|
||||
}
|
||||
if serviceAccount.DirectoryID == "" {
|
||||
return nil, fmt.Errorf("directory_id is required")
|
||||
}
|
||||
|
||||
return &serviceAccount, nil
|
||||
}
|
||||
|
||||
func parseDirectoryIDFromURL(providerURL string) (string, error) {
|
||||
u, err := url.Parse(providerURL)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
pathParts := strings.SplitN(u.Path, "/", 3)
|
||||
if len(pathParts) != 3 {
|
||||
return "", fmt.Errorf("no directory id found in path")
|
||||
}
|
||||
|
||||
return pathParts[1], nil
|
||||
}
|
|
@ -1,282 +0,0 @@
|
|||
package azure
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"testing"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/go-chi/chi/v5/middleware"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/testutil"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/directory"
|
||||
)
|
||||
|
||||
type M = map[string]interface{}
|
||||
|
||||
func newMockAPI(t *testing.T, srv *httptest.Server) http.Handler {
|
||||
r := chi.NewRouter()
|
||||
r.Use(middleware.Logger)
|
||||
r.Post("/DIRECTORY_ID/oauth2/v2.0/token", func(w http.ResponseWriter, r *http.Request) {
|
||||
assert.Equal(t, "CLIENT_ID", r.FormValue("client_id"))
|
||||
assert.Equal(t, "CLIENT_SECRET", r.FormValue("client_secret"))
|
||||
assert.Equal(t, defaultLoginScope, r.FormValue("scope"))
|
||||
assert.Equal(t, defaultLoginGrantType, r.FormValue("grant_type"))
|
||||
|
||||
_ = json.NewEncoder(w).Encode(M{
|
||||
"access_token": "ACCESSTOKEN",
|
||||
"token_type": "Bearer",
|
||||
"refresh_token": "REFRESHTOKEN",
|
||||
})
|
||||
})
|
||||
r.Route("/v1.0", func(r chi.Router) {
|
||||
r.Use(func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Header.Get("Authorization") != "Bearer ACCESSTOKEN" {
|
||||
http.Error(w, "forbidden", http.StatusForbidden)
|
||||
|
||||
return
|
||||
}
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
})
|
||||
r.Get("/groups/delta", func(w http.ResponseWriter, r *http.Request) {
|
||||
_ = json.NewEncoder(w).Encode(M{
|
||||
"value": []M{
|
||||
{
|
||||
"id": "admin",
|
||||
"displayName": "Admin Group",
|
||||
"members@delta": []M{
|
||||
{"@odata.type": "#microsoft.graph.user", "id": "user-1"},
|
||||
},
|
||||
},
|
||||
{
|
||||
"id": "test",
|
||||
"displayName": "Test Group",
|
||||
"members@delta": []M{
|
||||
{"@odata.type": "#microsoft.graph.user", "id": "user-2"},
|
||||
{"@odata.type": "#microsoft.graph.user", "id": "user-3"},
|
||||
{"@odata.type": "#microsoft.graph.user", "id": "user-4"},
|
||||
},
|
||||
},
|
||||
},
|
||||
})
|
||||
})
|
||||
r.Get("/users/delta", func(w http.ResponseWriter, r *http.Request) {
|
||||
_ = json.NewEncoder(w).Encode(M{
|
||||
"value": []M{
|
||||
{"id": "user-1", "displayName": "User 1", "mail": "user1@example.com"},
|
||||
{"id": "user-2", "displayName": "User 2", "mail": "user2@example.com"},
|
||||
{"id": "user-3", "displayName": "User 3", "userPrincipalName": "user3_example.com#EXT#@user3example.onmicrosoft.com"},
|
||||
{"id": "user-4", "displayName": "User 4", "userPrincipalName": "user4@example.com"},
|
||||
},
|
||||
})
|
||||
})
|
||||
r.Get("/users/{user_id}", func(w http.ResponseWriter, r *http.Request) {
|
||||
switch chi.URLParam(r, "user_id") {
|
||||
case "user-1":
|
||||
_ = json.NewEncoder(w).Encode(M{"id": "user-1", "displayName": "User 1", "mail": "user1@example.com"})
|
||||
default:
|
||||
http.Error(w, "not found", http.StatusNotFound)
|
||||
}
|
||||
})
|
||||
r.Get("/users/{user_id}/transitiveMemberOf", func(w http.ResponseWriter, r *http.Request) {
|
||||
switch chi.URLParam(r, "user_id") {
|
||||
case "user-1":
|
||||
switch r.URL.Query().Get("page") {
|
||||
case "":
|
||||
_ = json.NewEncoder(w).Encode(M{
|
||||
"value": []M{
|
||||
{"id": "admin"},
|
||||
},
|
||||
"@odata.nextLink": getPageURL(r, 1),
|
||||
})
|
||||
case "1":
|
||||
_ = json.NewEncoder(w).Encode(M{
|
||||
"value": []M{
|
||||
{"id": "group1"},
|
||||
},
|
||||
"@odata.nextLink": getPageURL(r, 2),
|
||||
})
|
||||
case "2":
|
||||
_ = json.NewEncoder(w).Encode(M{
|
||||
"value": []M{
|
||||
{"id": "group2"},
|
||||
},
|
||||
})
|
||||
}
|
||||
default:
|
||||
http.Error(w, "not found", http.StatusNotFound)
|
||||
}
|
||||
})
|
||||
})
|
||||
|
||||
return r
|
||||
}
|
||||
|
||||
func TestProvider_User(t *testing.T) {
|
||||
var mockAPI http.Handler
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
mockAPI.ServeHTTP(w, r)
|
||||
}))
|
||||
defer srv.Close()
|
||||
mockAPI = newMockAPI(t, srv)
|
||||
|
||||
p := New(
|
||||
WithGraphURL(mustParseURL(srv.URL)),
|
||||
WithLoginURL(mustParseURL(srv.URL)),
|
||||
WithServiceAccount(&ServiceAccount{
|
||||
ClientID: "CLIENT_ID",
|
||||
ClientSecret: "CLIENT_SECRET",
|
||||
DirectoryID: "DIRECTORY_ID",
|
||||
}),
|
||||
)
|
||||
|
||||
du, err := p.User(context.Background(), "user-1", "")
|
||||
if !assert.NoError(t, err) {
|
||||
return
|
||||
}
|
||||
testutil.AssertProtoJSONEqual(t, `{
|
||||
"id": "user-1",
|
||||
"displayName": "User 1",
|
||||
"email": "user1@example.com",
|
||||
"groupIds": ["admin", "group1", "group2"]
|
||||
}`, du)
|
||||
}
|
||||
|
||||
func TestProvider_UserGroups(t *testing.T) {
|
||||
var mockAPI http.Handler
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
mockAPI.ServeHTTP(w, r)
|
||||
}))
|
||||
defer srv.Close()
|
||||
mockAPI = newMockAPI(t, srv)
|
||||
|
||||
p := New(
|
||||
WithGraphURL(mustParseURL(srv.URL)),
|
||||
WithLoginURL(mustParseURL(srv.URL)),
|
||||
WithServiceAccount(&ServiceAccount{
|
||||
ClientID: "CLIENT_ID",
|
||||
ClientSecret: "CLIENT_SECRET",
|
||||
DirectoryID: "DIRECTORY_ID",
|
||||
}),
|
||||
)
|
||||
groups, users, err := p.UserGroups(context.Background())
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, []*directory.User{
|
||||
{
|
||||
Id: "user-1",
|
||||
GroupIds: []string{"admin"},
|
||||
DisplayName: "User 1",
|
||||
Email: "user1@example.com",
|
||||
},
|
||||
{
|
||||
Id: "user-2",
|
||||
GroupIds: []string{"test"},
|
||||
DisplayName: "User 2",
|
||||
Email: "user2@example.com",
|
||||
},
|
||||
{
|
||||
Id: "user-3",
|
||||
GroupIds: []string{"test"},
|
||||
DisplayName: "User 3",
|
||||
Email: "user3@example.com",
|
||||
},
|
||||
{
|
||||
Id: "user-4",
|
||||
GroupIds: []string{"test"},
|
||||
DisplayName: "User 4",
|
||||
Email: "user4@example.com",
|
||||
},
|
||||
}, users)
|
||||
testutil.AssertProtoJSONEqual(t, `[
|
||||
{ "id": "admin", "name": "Admin Group" },
|
||||
{ "id": "test", "name": "Test Group"}
|
||||
]`, groups)
|
||||
}
|
||||
|
||||
func TestParseServiceAccount(t *testing.T) {
|
||||
t.Run("by options", func(t *testing.T) {
|
||||
serviceAccount, err := ParseServiceAccount(directory.Options{
|
||||
ProviderURL: "https://login.microsoftonline.com/0303f438-3c5c-4190-9854-08d3eb31bd9f/v2.0",
|
||||
ClientID: "CLIENT_ID",
|
||||
ClientSecret: "CLIENT_SECRET",
|
||||
})
|
||||
if !assert.NoError(t, err) {
|
||||
return
|
||||
}
|
||||
|
||||
assert.Equal(t, &ServiceAccount{
|
||||
ClientID: "CLIENT_ID",
|
||||
ClientSecret: "CLIENT_SECRET",
|
||||
DirectoryID: "0303f438-3c5c-4190-9854-08d3eb31bd9f",
|
||||
}, serviceAccount)
|
||||
})
|
||||
t.Run("by service account base64", func(t *testing.T) {
|
||||
serviceAccount, err := ParseServiceAccount(directory.Options{
|
||||
ServiceAccount: base64.StdEncoding.EncodeToString([]byte(`{
|
||||
"client_id": "CLIENT_ID",
|
||||
"client_secret": "CLIENT_SECRET",
|
||||
"directory_id": "0303f438-3c5c-4190-9854-08d3eb31bd9f"
|
||||
}`)),
|
||||
})
|
||||
if !assert.NoError(t, err) {
|
||||
return
|
||||
}
|
||||
|
||||
assert.Equal(t, &ServiceAccount{
|
||||
ClientID: "CLIENT_ID",
|
||||
ClientSecret: "CLIENT_SECRET",
|
||||
DirectoryID: "0303f438-3c5c-4190-9854-08d3eb31bd9f",
|
||||
}, serviceAccount)
|
||||
})
|
||||
t.Run("by service account json", func(t *testing.T) {
|
||||
serviceAccount, err := ParseServiceAccount(directory.Options{
|
||||
ServiceAccount: `{
|
||||
"client_id": "CLIENT_ID",
|
||||
"client_secret": "CLIENT_SECRET",
|
||||
"directory_id": "0303f438-3c5c-4190-9854-08d3eb31bd9f"
|
||||
}`,
|
||||
})
|
||||
if !assert.NoError(t, err) {
|
||||
return
|
||||
}
|
||||
|
||||
assert.Equal(t, &ServiceAccount{
|
||||
ClientID: "CLIENT_ID",
|
||||
ClientSecret: "CLIENT_SECRET",
|
||||
DirectoryID: "0303f438-3c5c-4190-9854-08d3eb31bd9f",
|
||||
}, serviceAccount)
|
||||
})
|
||||
}
|
||||
|
||||
func mustParseURL(rawurl string) *url.URL {
|
||||
u, err := url.Parse(rawurl)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return u
|
||||
}
|
||||
|
||||
func getPageURL(r *http.Request, page int) string {
|
||||
var u url.URL
|
||||
u = *r.URL
|
||||
if r.TLS == nil {
|
||||
u.Scheme = "http"
|
||||
} else {
|
||||
u.Scheme = "https"
|
||||
}
|
||||
if u.Host == "" {
|
||||
u.Host = r.Host
|
||||
}
|
||||
q := u.Query()
|
||||
q.Set("page", strconv.Itoa(page))
|
||||
u.RawQuery = q.Encode()
|
||||
return u.String()
|
||||
}
|
|
@ -1,248 +0,0 @@
|
|||
package azure
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/url"
|
||||
"sort"
|
||||
|
||||
"github.com/pomerium/pomerium/pkg/grpc/directory"
|
||||
)
|
||||
|
||||
const (
|
||||
groupsDeltaPath = "/v1.0/groups/delta"
|
||||
usersDeltaPath = "/v1.0/users/delta"
|
||||
)
|
||||
|
||||
type (
|
||||
deltaCollection struct {
|
||||
provider *Provider
|
||||
groups map[string]deltaGroup
|
||||
groupDeltaLink string
|
||||
users map[string]deltaUser
|
||||
userDeltaLink string
|
||||
}
|
||||
deltaGroup struct {
|
||||
id string
|
||||
displayName string
|
||||
members map[string]deltaGroupMember
|
||||
}
|
||||
deltaGroupMember struct {
|
||||
memberType string
|
||||
id string
|
||||
}
|
||||
deltaUser struct {
|
||||
id string
|
||||
displayName string
|
||||
email string
|
||||
}
|
||||
)
|
||||
|
||||
func newDeltaCollection(p *Provider) *deltaCollection {
|
||||
return &deltaCollection{
|
||||
provider: p,
|
||||
groups: make(map[string]deltaGroup),
|
||||
users: make(map[string]deltaUser),
|
||||
}
|
||||
}
|
||||
|
||||
// Sync syncs the latest changes from the microsoft graph API.
|
||||
//
|
||||
// Synchronization is based on https://docs.microsoft.com/en-us/graph/delta-query-groups
|
||||
//
|
||||
// It involves 4 steps:
|
||||
//
|
||||
// 1. an initial request to /v1.0/groups/delta
|
||||
// 2. one or more requests to /v1.0/groups/delta?$skiptoken=..., which comes from the @odata.nextLink
|
||||
// 3. a final response with @odata.deltaLink
|
||||
// 4. on the next call to sync, starting at @odata.deltaLink
|
||||
//
|
||||
// Only the changed groups/members are returned. Removed groups/members have an @removed property.
|
||||
func (dc *deltaCollection) Sync(ctx context.Context) error {
|
||||
err := dc.syncGroups(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return dc.syncUsers(ctx)
|
||||
}
|
||||
|
||||
func (dc *deltaCollection) syncGroups(ctx context.Context) error {
|
||||
apiURL := dc.groupDeltaLink
|
||||
|
||||
// if no delta link is set yet, start the initial fill
|
||||
if apiURL == "" {
|
||||
apiURL = dc.provider.cfg.graphURL.ResolveReference(&url.URL{
|
||||
Path: groupsDeltaPath,
|
||||
RawQuery: url.Values{
|
||||
"$select": {"displayName,members"},
|
||||
}.Encode(),
|
||||
}).String()
|
||||
}
|
||||
|
||||
for {
|
||||
var res groupsDeltaResponse
|
||||
err := dc.provider.api(ctx, apiURL, &res)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, g := range res.Value {
|
||||
// if removed exists, the group was deleted
|
||||
if g.Removed != nil {
|
||||
delete(dc.groups, g.ID)
|
||||
continue
|
||||
}
|
||||
|
||||
gdg := dc.groups[g.ID]
|
||||
gdg.id = g.ID
|
||||
gdg.displayName = g.DisplayName
|
||||
if gdg.members == nil {
|
||||
gdg.members = make(map[string]deltaGroupMember)
|
||||
}
|
||||
for _, m := range g.Members {
|
||||
// if removed exists, the member was deleted
|
||||
if m.Removed != nil {
|
||||
delete(gdg.members, m.ID)
|
||||
continue
|
||||
}
|
||||
|
||||
gdg.members[m.ID] = deltaGroupMember{
|
||||
memberType: m.Type,
|
||||
id: m.ID,
|
||||
}
|
||||
}
|
||||
dc.groups[g.ID] = gdg
|
||||
}
|
||||
|
||||
switch {
|
||||
case res.NextLink != "":
|
||||
// when there's a next link we will query again
|
||||
apiURL = res.NextLink
|
||||
default:
|
||||
// once no next link is set anymore, we save the delta link and return
|
||||
dc.groupDeltaLink = res.DeltaLink
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (dc *deltaCollection) syncUsers(ctx context.Context) error {
|
||||
apiURL := dc.userDeltaLink
|
||||
|
||||
// if no delta link is set yet, start the initial fill
|
||||
if apiURL == "" {
|
||||
apiURL = dc.provider.cfg.graphURL.ResolveReference(&url.URL{
|
||||
Path: usersDeltaPath,
|
||||
RawQuery: url.Values{
|
||||
"$select": {"displayName,mail,userPrincipalName"},
|
||||
}.Encode(),
|
||||
}).String()
|
||||
}
|
||||
|
||||
for {
|
||||
var res usersDeltaResponse
|
||||
err := dc.provider.api(ctx, apiURL, &res)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, u := range res.Value {
|
||||
// if removed exists, the user was deleted
|
||||
if u.Removed != nil {
|
||||
delete(dc.users, u.ID)
|
||||
continue
|
||||
}
|
||||
dc.users[u.ID] = deltaUser{
|
||||
id: u.ID,
|
||||
displayName: u.DisplayName,
|
||||
email: u.getEmail(),
|
||||
}
|
||||
}
|
||||
|
||||
switch {
|
||||
case res.NextLink != "":
|
||||
// when there's a next link we will query again
|
||||
apiURL = res.NextLink
|
||||
default:
|
||||
// once no next link is set anymore, we save the delta link and return
|
||||
dc.userDeltaLink = res.DeltaLink
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// CurrentUserGroups returns the directory groups and users based on the current state.
|
||||
func (dc *deltaCollection) CurrentUserGroups() ([]*directory.Group, []*directory.User) {
|
||||
var groups []*directory.Group
|
||||
|
||||
groupLookup := newGroupLookup()
|
||||
for _, g := range dc.groups {
|
||||
groups = append(groups, &directory.Group{
|
||||
Id: g.id,
|
||||
Name: g.displayName,
|
||||
})
|
||||
var groupIDs, userIDs []string
|
||||
for _, m := range g.members {
|
||||
switch m.memberType {
|
||||
case "#microsoft.graph.group":
|
||||
groupIDs = append(groupIDs, m.id)
|
||||
case "#microsoft.graph.user":
|
||||
userIDs = append(userIDs, m.id)
|
||||
}
|
||||
}
|
||||
groupLookup.addGroup(g.id, groupIDs, userIDs)
|
||||
}
|
||||
sort.Slice(groups, func(i, j int) bool {
|
||||
return groups[i].GetId() < groups[j].GetId()
|
||||
})
|
||||
|
||||
var users []*directory.User
|
||||
for _, u := range dc.users {
|
||||
users = append(users, &directory.User{
|
||||
Id: u.id,
|
||||
GroupIds: groupLookup.getGroupIDsForUser(u.id),
|
||||
DisplayName: u.displayName,
|
||||
Email: u.email,
|
||||
})
|
||||
}
|
||||
sort.Slice(users, func(i, j int) bool {
|
||||
return users[i].GetId() < users[j].GetId()
|
||||
})
|
||||
|
||||
return groups, users
|
||||
}
|
||||
|
||||
// API types for the microsoft graph API.
|
||||
type (
|
||||
deltaResponseRemoved struct {
|
||||
Reason string `json:"reason"`
|
||||
}
|
||||
|
||||
groupsDeltaResponse struct {
|
||||
Context string `json:"@odata.context"`
|
||||
NextLink string `json:"@odata.nextLink,omitempty"`
|
||||
DeltaLink string `json:"@odata.deltaLink,omitempty"`
|
||||
Value []groupsDeltaResponseGroup `json:"value"`
|
||||
}
|
||||
groupsDeltaResponseGroup struct {
|
||||
apiGroup
|
||||
Members []groupsDeltaResponseGroupMember `json:"members@delta"`
|
||||
Removed *deltaResponseRemoved `json:"@removed,omitempty"`
|
||||
}
|
||||
groupsDeltaResponseGroupMember struct {
|
||||
Type string `json:"@odata.type"`
|
||||
ID string `json:"id"`
|
||||
Removed *deltaResponseRemoved `json:"@removed,omitempty"`
|
||||
}
|
||||
|
||||
usersDeltaResponse struct {
|
||||
Context string `json:"@odata.context"`
|
||||
NextLink string `json:"@odata.nextLink,omitempty"`
|
||||
DeltaLink string `json:"@odata.deltaLink,omitempty"`
|
||||
Value []usersDeltaResponseUser `json:"value"`
|
||||
}
|
||||
usersDeltaResponseUser struct {
|
||||
apiUser
|
||||
Removed *deltaResponseRemoved `json:"@removed,omitempty"`
|
||||
}
|
||||
)
|
|
@ -1,107 +0,0 @@
|
|||
package azure
|
||||
|
||||
import "sort"
|
||||
|
||||
type stringSet map[string]struct{}
|
||||
|
||||
func newStringSet() stringSet {
|
||||
return make(stringSet)
|
||||
}
|
||||
|
||||
func (ss stringSet) add(value string) {
|
||||
ss[value] = struct{}{}
|
||||
}
|
||||
|
||||
func (ss stringSet) has(value string) bool {
|
||||
if ss == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
_, ok := ss[value]
|
||||
return ok
|
||||
}
|
||||
|
||||
func (ss stringSet) sorted() []string {
|
||||
if ss == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
s := make([]string, 0, len(ss))
|
||||
for v := range ss {
|
||||
s = append(s, v)
|
||||
}
|
||||
sort.Strings(s)
|
||||
return s
|
||||
}
|
||||
|
||||
type stringSetSet map[string]stringSet
|
||||
|
||||
func newStringSetSet() stringSetSet {
|
||||
return make(stringSetSet)
|
||||
}
|
||||
|
||||
func (sss stringSetSet) add(v1, v2 string) {
|
||||
ss, ok := sss[v1]
|
||||
if !ok {
|
||||
ss = newStringSet()
|
||||
sss[v1] = ss
|
||||
}
|
||||
ss.add(v2)
|
||||
}
|
||||
|
||||
func (sss stringSetSet) get(v1 string) stringSet {
|
||||
return sss[v1]
|
||||
}
|
||||
|
||||
type groupLookup struct {
|
||||
childUserIDToParentGroupID stringSetSet
|
||||
childGroupIDToParentGroupID stringSetSet
|
||||
}
|
||||
|
||||
func newGroupLookup() *groupLookup {
|
||||
return &groupLookup{
|
||||
childUserIDToParentGroupID: newStringSetSet(),
|
||||
childGroupIDToParentGroupID: newStringSetSet(),
|
||||
}
|
||||
}
|
||||
|
||||
func (l *groupLookup) addGroup(parentGroupID string, childGroupIDs, childUserIDs []string) {
|
||||
for _, childGroupID := range childGroupIDs {
|
||||
l.childGroupIDToParentGroupID.add(childGroupID, parentGroupID)
|
||||
}
|
||||
for _, childUserID := range childUserIDs {
|
||||
l.childUserIDToParentGroupID.add(childUserID, parentGroupID)
|
||||
}
|
||||
}
|
||||
|
||||
func (l *groupLookup) getUserIDs() []string {
|
||||
s := make([]string, 0, len(l.childUserIDToParentGroupID))
|
||||
for userID := range l.childUserIDToParentGroupID {
|
||||
s = append(s, userID)
|
||||
}
|
||||
sort.Strings(s)
|
||||
return s
|
||||
}
|
||||
|
||||
func (l *groupLookup) getGroupIDsForUser(userID string) []string {
|
||||
groupIDs := newStringSet()
|
||||
var todo []string
|
||||
for groupID := range l.childUserIDToParentGroupID.get(userID) {
|
||||
todo = append(todo, groupID)
|
||||
}
|
||||
|
||||
for len(todo) > 0 {
|
||||
groupID := todo[len(todo)-1]
|
||||
todo = todo[:len(todo)-1]
|
||||
if groupIDs.has(groupID) {
|
||||
continue
|
||||
}
|
||||
|
||||
groupIDs.add(groupID)
|
||||
for parentGroupID := range l.childGroupIDToParentGroupID.get(groupID) {
|
||||
todo = append(todo, parentGroupID)
|
||||
}
|
||||
}
|
||||
|
||||
return groupIDs.sorted()
|
||||
}
|
|
@ -1,25 +0,0 @@
|
|||
package azure
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestGroupLookup(t *testing.T) {
|
||||
gl := newGroupLookup()
|
||||
|
||||
gl.addGroup("g1", []string{"g11", "g12", "g13"}, []string{"u1"})
|
||||
gl.addGroup("g11", []string{"g111"}, nil)
|
||||
gl.addGroup("g111", nil, []string{"u2"})
|
||||
|
||||
assert.Equal(t, []string{"u1", "u2"}, gl.getUserIDs())
|
||||
assert.Equal(t, []string{"g1", "g11", "g111"}, gl.getGroupIDsForUser("u2"))
|
||||
|
||||
t.Run("cycle protection", func(t *testing.T) {
|
||||
gl.addGroup("g12", []string{"g1"}, nil)
|
||||
|
||||
assert.Equal(t, []string{"u1", "u2"}, gl.getUserIDs())
|
||||
assert.Equal(t, []string{"g1", "g11", "g111", "g12"}, gl.getGroupIDsForUser("u2"))
|
||||
})
|
||||
}
|
|
@ -1,8 +0,0 @@
|
|||
// Package directoryerrors contains errors used by directory providers.
|
||||
package directoryerrors
|
||||
|
||||
import "errors"
|
||||
|
||||
// ErrPreferExistingInformation indicates that the information returned by the provider should
|
||||
// only be used if a record is brand new, otherwise the existing information should be kept as is.
|
||||
var ErrPreferExistingInformation = errors.New("user ignored")
|
|
@ -1,330 +0,0 @@
|
|||
// Package github contains a directory provider for github.
|
||||
package github
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"sort"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
"github.com/tomnomnom/linkheader"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/encoding"
|
||||
"github.com/pomerium/pomerium/internal/httputil"
|
||||
"github.com/pomerium/pomerium/internal/log"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/directory"
|
||||
)
|
||||
|
||||
// Name is the provider name.
|
||||
const Name = "github"
|
||||
|
||||
var defaultURL = &url.URL{
|
||||
Scheme: "https",
|
||||
Host: "api.github.com",
|
||||
}
|
||||
|
||||
type config struct {
|
||||
httpClient *http.Client
|
||||
serviceAccount *ServiceAccount
|
||||
url *url.URL
|
||||
}
|
||||
|
||||
// An Option updates the github configuration.
|
||||
type Option func(cfg *config)
|
||||
|
||||
// WithServiceAccount sets the service account in the config.
|
||||
func WithServiceAccount(serviceAccount *ServiceAccount) Option {
|
||||
return func(cfg *config) {
|
||||
cfg.serviceAccount = serviceAccount
|
||||
}
|
||||
}
|
||||
|
||||
// WithHTTPClient sets the http client option.
|
||||
func WithHTTPClient(httpClient *http.Client) Option {
|
||||
return func(cfg *config) {
|
||||
cfg.httpClient = httputil.NewLoggingClient(httpClient, "github_idp_client",
|
||||
func(evt *zerolog.Event) *zerolog.Event {
|
||||
return evt.Str("provider", "github")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// WithURL sets the api url in the config.
|
||||
func WithURL(u *url.URL) Option {
|
||||
return func(cfg *config) {
|
||||
cfg.url = u
|
||||
}
|
||||
}
|
||||
|
||||
func getConfig(options ...Option) *config {
|
||||
cfg := new(config)
|
||||
WithHTTPClient(http.DefaultClient)(cfg)
|
||||
WithURL(defaultURL)(cfg)
|
||||
for _, option := range options {
|
||||
option(cfg)
|
||||
}
|
||||
return cfg
|
||||
}
|
||||
|
||||
// The Provider retrieves users and groups from github.
|
||||
type Provider struct {
|
||||
cfg *config
|
||||
log zerolog.Logger
|
||||
}
|
||||
|
||||
// New creates a new Provider.
|
||||
func New(options ...Option) *Provider {
|
||||
return &Provider{
|
||||
cfg: getConfig(options...),
|
||||
log: log.With().Str("service", "directory").Str("provider", "github").Logger(),
|
||||
}
|
||||
}
|
||||
|
||||
// User returns the user record for the given id.
|
||||
func (p *Provider) User(ctx context.Context, userID, accessToken string) (*directory.User, error) {
|
||||
if p.cfg.serviceAccount == nil {
|
||||
return nil, fmt.Errorf("github: service account not defined")
|
||||
}
|
||||
|
||||
du := &directory.User{
|
||||
Id: userID,
|
||||
}
|
||||
|
||||
au, err := p.getUser(ctx, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
du.DisplayName = au.Name
|
||||
du.Email = au.Email
|
||||
|
||||
teamIDLookup := map[string]struct{}{}
|
||||
orgSlugs, err := p.listOrgs(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for _, orgSlug := range orgSlugs {
|
||||
teamIDs, err := p.listUserOrganizationTeams(ctx, userID, orgSlug)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for _, teamID := range teamIDs {
|
||||
teamIDLookup[teamID] = struct{}{}
|
||||
}
|
||||
}
|
||||
|
||||
for teamID := range teamIDLookup {
|
||||
du.GroupIds = append(du.GroupIds, teamID)
|
||||
}
|
||||
sort.Strings(du.GroupIds)
|
||||
|
||||
return du, nil
|
||||
}
|
||||
|
||||
// UserGroups gets the directory user groups for github.
|
||||
func (p *Provider) UserGroups(ctx context.Context) ([]*directory.Group, []*directory.User, error) {
|
||||
if p.cfg.serviceAccount == nil {
|
||||
return nil, nil, fmt.Errorf("github: service account not defined")
|
||||
}
|
||||
|
||||
orgSlugs, err := p.listOrgs(ctx)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
userLoginToGroups := map[string][]string{}
|
||||
|
||||
var allGroups []*directory.Group
|
||||
for _, orgSlug := range orgSlugs {
|
||||
teams, err := p.listOrganizationTeamsWithMemberIDs(ctx, orgSlug)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
for _, team := range teams {
|
||||
allGroups = append(allGroups, &directory.Group{
|
||||
Id: team.Slug,
|
||||
Name: team.Slug,
|
||||
})
|
||||
for _, memberID := range team.MemberIDs {
|
||||
userLoginToGroups[memberID] = append(userLoginToGroups[memberID], team.Slug)
|
||||
}
|
||||
}
|
||||
}
|
||||
sort.Slice(allGroups, func(i, j int) bool {
|
||||
return allGroups[i].Id < allGroups[j].Id
|
||||
})
|
||||
|
||||
var allUsers []*directory.User
|
||||
for _, orgSlug := range orgSlugs {
|
||||
members, err := p.listOrganizationMembers(ctx, orgSlug)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
for _, member := range members {
|
||||
du := &directory.User{
|
||||
Id: member.Login,
|
||||
GroupIds: userLoginToGroups[member.ID],
|
||||
DisplayName: member.Name,
|
||||
Email: member.Email,
|
||||
}
|
||||
sort.Strings(du.GroupIds)
|
||||
allUsers = append(allUsers, du)
|
||||
}
|
||||
}
|
||||
sort.Slice(allUsers, func(i, j int) bool {
|
||||
return allUsers[i].Id < allUsers[j].Id
|
||||
})
|
||||
|
||||
return allGroups, allUsers, nil
|
||||
}
|
||||
|
||||
func (p *Provider) listOrgs(ctx context.Context) (orgSlugs []string, err error) {
|
||||
nextURL := p.cfg.url.ResolveReference(&url.URL{
|
||||
Path: "/user/orgs",
|
||||
}).String()
|
||||
|
||||
for nextURL != "" {
|
||||
var results []struct {
|
||||
Login string `json:"login"`
|
||||
}
|
||||
hdrs, err := p.api(ctx, nextURL, &results)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for _, result := range results {
|
||||
orgSlugs = append(orgSlugs, result.Login)
|
||||
}
|
||||
|
||||
nextURL = getNextLink(hdrs)
|
||||
}
|
||||
|
||||
return orgSlugs, nil
|
||||
}
|
||||
|
||||
func (p *Provider) getUser(ctx context.Context, userLogin string) (*apiUserObject, error) {
|
||||
apiURL := p.cfg.url.ResolveReference(&url.URL{
|
||||
Path: fmt.Sprintf("/users/%s", userLogin),
|
||||
}).String()
|
||||
|
||||
var res apiUserObject
|
||||
_, err := p.api(ctx, apiURL, &res)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &res, nil
|
||||
}
|
||||
|
||||
func (p *Provider) api(ctx context.Context, apiURL string, out interface{}) (http.Header, error) {
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", apiURL, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("github: failed to create http request: %w", err)
|
||||
}
|
||||
req.SetBasicAuth(p.cfg.serviceAccount.Username, p.cfg.serviceAccount.PersonalAccessToken)
|
||||
|
||||
res, err := p.cfg.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("github: failed to make http request: %w", err)
|
||||
}
|
||||
defer res.Body.Close()
|
||||
|
||||
if res.StatusCode/100 != 2 {
|
||||
return nil, fmt.Errorf("github: error from API: %s", res.Status)
|
||||
}
|
||||
|
||||
if out != nil {
|
||||
err := json.NewDecoder(res.Body).Decode(out)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("github: failed to decode json body: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return res.Header, nil
|
||||
}
|
||||
|
||||
func (p *Provider) graphql(ctx context.Context, query string, out interface{}) (http.Header, error) {
|
||||
apiURL := p.cfg.url.ResolveReference(&url.URL{
|
||||
Path: "/graphql",
|
||||
}).String()
|
||||
|
||||
bs, _ := json.Marshal(struct {
|
||||
Query string `json:"query"`
|
||||
}{query})
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", apiURL, bytes.NewReader(bs))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("github: failed to create http request: %w", err)
|
||||
}
|
||||
req.SetBasicAuth(p.cfg.serviceAccount.Username, p.cfg.serviceAccount.PersonalAccessToken)
|
||||
|
||||
res, err := p.cfg.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("github: failed to make http request: %w", err)
|
||||
}
|
||||
defer res.Body.Close()
|
||||
|
||||
if res.StatusCode/100 != 2 {
|
||||
return nil, fmt.Errorf("github: error from API: %s", res.Status)
|
||||
}
|
||||
|
||||
if out != nil {
|
||||
err := json.NewDecoder(res.Body).Decode(out)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("github: failed to decode json body: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return res.Header, nil
|
||||
}
|
||||
|
||||
func getNextLink(hdrs http.Header) string {
|
||||
for _, link := range linkheader.ParseMultiple(hdrs.Values("Link")) {
|
||||
if link.Rel == "next" {
|
||||
return link.URL
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// A ServiceAccount is used by the GitHub provider to query the GitHub API.
|
||||
type ServiceAccount struct {
|
||||
Username string `json:"username"`
|
||||
PersonalAccessToken string `json:"personal_access_token"`
|
||||
}
|
||||
|
||||
// ParseServiceAccount parses the service account in the config options.
|
||||
func ParseServiceAccount(rawServiceAccount string) (*ServiceAccount, error) {
|
||||
var serviceAccount ServiceAccount
|
||||
if err := encoding.DecodeBase64OrJSON(rawServiceAccount, &serviceAccount); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if serviceAccount.Username == "" {
|
||||
return nil, fmt.Errorf("username is required")
|
||||
}
|
||||
if serviceAccount.PersonalAccessToken == "" {
|
||||
return nil, fmt.Errorf("personal_access_token is required")
|
||||
}
|
||||
|
||||
return &serviceAccount, nil
|
||||
}
|
||||
|
||||
// see: https://docs.github.com/en/free-pro-team@latest/rest/reference/users#get-a-user
|
||||
type apiUserObject struct {
|
||||
Login string `json:"login"`
|
||||
Name string `json:"name"`
|
||||
Email string `json:"email"`
|
||||
}
|
||||
|
||||
type teamWithMemberIDs struct {
|
||||
ID string
|
||||
Slug string
|
||||
Name string
|
||||
MemberIDs []string
|
||||
}
|
|
@ -1,436 +0,0 @@
|
|||
package github
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"testing"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/go-chi/chi/v5/middleware"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/vektah/gqlparser/ast"
|
||||
"github.com/vektah/gqlparser/parser"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/testutil"
|
||||
)
|
||||
|
||||
type M = map[string]interface{}
|
||||
|
||||
func newMockAPI(t *testing.T, srv *httptest.Server) http.Handler {
|
||||
r := chi.NewRouter()
|
||||
r.Use(middleware.Logger)
|
||||
r.Use(func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if !assert.Equal(t, "Basic YWJjOnh5eg==", r.Header.Get("Authorization")) {
|
||||
http.Error(w, "forbidden", http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
})
|
||||
r.Post("/graphql", func(w http.ResponseWriter, r *http.Request) {
|
||||
var body struct {
|
||||
Query string `json:"query"`
|
||||
}
|
||||
json.NewDecoder(r.Body).Decode(&body)
|
||||
|
||||
q, err := parser.ParseQuery(&ast.Source{
|
||||
Input: body.Query,
|
||||
})
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
result := qlResult{
|
||||
Data: &qlData{
|
||||
Organization: &qlOrganization{},
|
||||
},
|
||||
}
|
||||
handleMembersWithRole := func(orgSlug string, field *ast.Field) {
|
||||
membersWithRole := &qlMembersWithRoleConnection{}
|
||||
|
||||
var cursor string
|
||||
for _, arg := range field.Arguments {
|
||||
if arg.Name == "after" {
|
||||
cursor = arg.Value.Raw
|
||||
}
|
||||
}
|
||||
|
||||
switch cursor {
|
||||
case `null`:
|
||||
switch orgSlug {
|
||||
case "org1":
|
||||
membersWithRole.PageInfo = qlPageInfo{EndCursor: "TOKEN1", HasNextPage: true}
|
||||
membersWithRole.Nodes = []qlUser{
|
||||
{ID: "user1", Login: "user1", Name: "User 1", Email: "user1@example.com"},
|
||||
{ID: "user2", Login: "user2", Name: "User 2", Email: "user2@example.com"},
|
||||
}
|
||||
case "org2":
|
||||
membersWithRole.PageInfo = qlPageInfo{HasNextPage: false}
|
||||
membersWithRole.Nodes = []qlUser{
|
||||
{ID: "user4", Login: "user4", Name: "User 4", Email: "user4@example.com"},
|
||||
}
|
||||
default:
|
||||
t.Errorf("unexpected org slug: %s", orgSlug)
|
||||
}
|
||||
case `TOKEN1`:
|
||||
membersWithRole.PageInfo = qlPageInfo{HasNextPage: false}
|
||||
membersWithRole.Nodes = []qlUser{
|
||||
{ID: "user3", Login: "user3", Name: "User 3", Email: "user3@example.com"},
|
||||
}
|
||||
default:
|
||||
t.Errorf("unexpected cursor: %s", cursor)
|
||||
}
|
||||
|
||||
result.Data.Organization.MembersWithRole = membersWithRole
|
||||
}
|
||||
handleTeamMembers := func(orgSlug, teamSlug string, field *ast.Field) {
|
||||
result.Data.Organization.Team.Members = &qlTeamMemberConnection{
|
||||
PageInfo: qlPageInfo{HasNextPage: false},
|
||||
}
|
||||
|
||||
switch teamSlug {
|
||||
case "team3":
|
||||
result.Data.Organization.Team.Members.Edges = []qlTeamMemberEdge{
|
||||
{Node: qlUser{ID: "user3"}},
|
||||
}
|
||||
}
|
||||
}
|
||||
handleTeam := func(orgSlug string, field *ast.Field) {
|
||||
result.Data.Organization.Team = &qlTeam{}
|
||||
|
||||
var teamSlug string
|
||||
for _, arg := range field.Arguments {
|
||||
if arg.Name == "slug" {
|
||||
teamSlug = arg.Value.Raw
|
||||
}
|
||||
}
|
||||
|
||||
for _, selection := range field.SelectionSet {
|
||||
subField, ok := selection.(*ast.Field)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
switch subField.Name {
|
||||
case "members":
|
||||
handleTeamMembers(orgSlug, teamSlug, subField)
|
||||
}
|
||||
}
|
||||
}
|
||||
renderNodeField := func(field *ast.Field, path []string, value string) string {
|
||||
outer:
|
||||
for _, segment := range path {
|
||||
for _, selection := range field.SelectionSet {
|
||||
subField, ok := selection.(*ast.Field)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
if subField.Name != segment {
|
||||
continue
|
||||
}
|
||||
|
||||
field = subField
|
||||
continue outer
|
||||
}
|
||||
return ""
|
||||
}
|
||||
return value
|
||||
}
|
||||
handleTeams := func(orgSlug string, field *ast.Field) {
|
||||
teams := &qlTeamConnection{}
|
||||
|
||||
var cursor string
|
||||
var userLogin string
|
||||
for _, arg := range field.Arguments {
|
||||
if arg.Name == "after" {
|
||||
cursor = arg.Value.Raw
|
||||
}
|
||||
if arg.Name == "userLogins" {
|
||||
userLogin = arg.Value.Children[0].Value.Raw
|
||||
}
|
||||
}
|
||||
|
||||
switch cursor {
|
||||
case `null`:
|
||||
switch orgSlug {
|
||||
case "org1":
|
||||
teams.PageInfo = qlPageInfo{HasNextPage: true, EndCursor: "TOKEN1"}
|
||||
teams.Edges = []qlTeamEdge{
|
||||
{Node: qlTeam{
|
||||
ID: renderNodeField(field, []string{"edges", "node", "id"}, "MDQ6VGVhbTE="),
|
||||
Slug: renderNodeField(field, []string{"edges", "node", "slug"}, "team1"),
|
||||
Name: renderNodeField(field, []string{"edges", "node", "name"}, "Team 1"),
|
||||
Members: &qlTeamMemberConnection{
|
||||
PageInfo: qlPageInfo{HasNextPage: false},
|
||||
Edges: []qlTeamMemberEdge{
|
||||
{Node: qlUser{ID: "user1"}},
|
||||
{Node: qlUser{ID: "user2"}},
|
||||
},
|
||||
}}},
|
||||
}
|
||||
case "org2":
|
||||
teams.PageInfo = qlPageInfo{HasNextPage: false}
|
||||
teams.Edges = []qlTeamEdge{
|
||||
{Node: qlTeam{
|
||||
ID: renderNodeField(field, []string{"edges", "node", "id"}, "MDQ6VGVhbTM="),
|
||||
Slug: renderNodeField(field, []string{"edges", "node", "slug"}, "team3"),
|
||||
Name: renderNodeField(field, []string{"edges", "node", "name"}, "Team 3"),
|
||||
Members: &qlTeamMemberConnection{
|
||||
PageInfo: qlPageInfo{HasNextPage: true, EndCursor: "TOKEN1"},
|
||||
Edges: []qlTeamMemberEdge{
|
||||
{Node: qlUser{ID: "user1"}},
|
||||
{Node: qlUser{ID: "user2"}},
|
||||
},
|
||||
}}},
|
||||
}
|
||||
if userLogin == "" || userLogin == "user4" {
|
||||
teams.Edges = append(teams.Edges, qlTeamEdge{
|
||||
Node: qlTeam{
|
||||
ID: renderNodeField(field, []string{"edges", "node", "id"}, "MDQ6VGVhbTQ="),
|
||||
Slug: renderNodeField(field, []string{"edges", "node", "slug"}, "team4"),
|
||||
Name: renderNodeField(field, []string{"edges", "node", "name"}, "Team 4"),
|
||||
Members: &qlTeamMemberConnection{
|
||||
PageInfo: qlPageInfo{HasNextPage: false},
|
||||
Edges: []qlTeamMemberEdge{
|
||||
{Node: qlUser{ID: "user4"}},
|
||||
},
|
||||
}},
|
||||
})
|
||||
}
|
||||
default:
|
||||
t.Errorf("unexpected org slug: %s", orgSlug)
|
||||
}
|
||||
case "TOKEN1":
|
||||
teams.PageInfo = qlPageInfo{HasNextPage: false}
|
||||
teams.Edges = []qlTeamEdge{
|
||||
{Node: qlTeam{
|
||||
ID: renderNodeField(field, []string{"edges", "node", "id"}, "MDQ6VGVhbTI="),
|
||||
Slug: renderNodeField(field, []string{"edges", "node", "slug"}, "team2"),
|
||||
Name: renderNodeField(field, []string{"edges", "node", "name"}, "Team 2"),
|
||||
Members: &qlTeamMemberConnection{
|
||||
PageInfo: qlPageInfo{HasNextPage: false},
|
||||
Edges: []qlTeamMemberEdge{
|
||||
{Node: qlUser{ID: "user1"}},
|
||||
},
|
||||
}}},
|
||||
}
|
||||
default:
|
||||
t.Errorf("unexpected cursor: %s", cursor)
|
||||
}
|
||||
|
||||
result.Data.Organization.Teams = teams
|
||||
}
|
||||
handleOrganization := func(field *ast.Field) {
|
||||
var orgSlug string
|
||||
for _, arg := range field.Arguments {
|
||||
if arg.Name == "login" {
|
||||
orgSlug = arg.Value.Raw
|
||||
}
|
||||
}
|
||||
for _, orgSelection := range field.SelectionSet {
|
||||
orgField, ok := orgSelection.(*ast.Field)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
switch orgField.Name {
|
||||
case "teams":
|
||||
handleTeams(orgSlug, orgField)
|
||||
case "team":
|
||||
handleTeam(orgSlug, orgField)
|
||||
case "membersWithRole":
|
||||
handleMembersWithRole(orgSlug, orgField)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for _, operation := range q.Operations {
|
||||
for _, selection := range operation.SelectionSet {
|
||||
field, ok := selection.(*ast.Field)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
if field.Name != "organization" {
|
||||
continue
|
||||
}
|
||||
|
||||
handleOrganization(field)
|
||||
}
|
||||
}
|
||||
|
||||
_ = json.NewEncoder(w).Encode(result)
|
||||
})
|
||||
r.Get("/user/orgs", func(w http.ResponseWriter, r *http.Request) {
|
||||
json.NewEncoder(w).Encode([]M{
|
||||
{"login": "org1"},
|
||||
{"login": "org2"},
|
||||
})
|
||||
})
|
||||
r.Get("/orgs/{org_id}/teams", func(w http.ResponseWriter, r *http.Request) {
|
||||
teams := map[string][]M{
|
||||
"org1": {
|
||||
{"slug": "team1", "id": 1},
|
||||
{"slug": "team2", "id": 2},
|
||||
},
|
||||
"org2": {
|
||||
{"slug": "team3", "id": 3},
|
||||
{"slug": "team4", "id": 4},
|
||||
},
|
||||
}
|
||||
orgID := chi.URLParam(r, "org_id")
|
||||
json.NewEncoder(w).Encode(teams[orgID])
|
||||
})
|
||||
r.Get("/orgs/{org_id}/teams/{team_id}/members", func(w http.ResponseWriter, r *http.Request) {
|
||||
members := map[string]map[string][]M{
|
||||
"org1": {
|
||||
"team1": {
|
||||
{"login": "user1"},
|
||||
{"login": "user2"},
|
||||
},
|
||||
"team2": {
|
||||
{"login": "user1"},
|
||||
},
|
||||
},
|
||||
"org2": {
|
||||
"team3": {
|
||||
{"login": "user1"},
|
||||
{"login": "user2"},
|
||||
{"login": "user3"},
|
||||
},
|
||||
"team4": {
|
||||
{"login": "user4"},
|
||||
},
|
||||
},
|
||||
}
|
||||
orgID := chi.URLParam(r, "org_id")
|
||||
teamID := chi.URLParam(r, "team_id")
|
||||
json.NewEncoder(w).Encode(members[orgID][teamID])
|
||||
})
|
||||
r.Get("/users/{user_id}", func(w http.ResponseWriter, r *http.Request) {
|
||||
users := map[string]apiUserObject{
|
||||
"user1": {Login: "user1", Name: "User 1", Email: "user1@example.com"},
|
||||
"user2": {Login: "user2", Name: "User 2", Email: "user2@example.com"},
|
||||
"user3": {Login: "user3", Name: "User 3", Email: "user3@example.com"},
|
||||
"user4": {Login: "user4", Name: "User 4", Email: "user4@example.com"},
|
||||
}
|
||||
userID := chi.URLParam(r, "user_id")
|
||||
json.NewEncoder(w).Encode(users[userID])
|
||||
})
|
||||
return r
|
||||
}
|
||||
|
||||
func TestProvider_User(t *testing.T) {
|
||||
var mockAPI http.Handler
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
mockAPI.ServeHTTP(w, r)
|
||||
}))
|
||||
defer srv.Close()
|
||||
mockAPI = newMockAPI(t, srv)
|
||||
|
||||
p := New(
|
||||
WithURL(mustParseURL(srv.URL)),
|
||||
WithServiceAccount(&ServiceAccount{
|
||||
Username: "abc",
|
||||
PersonalAccessToken: "xyz",
|
||||
}),
|
||||
)
|
||||
du, err := p.User(context.Background(), "user1", "")
|
||||
if !assert.NoError(t, err) {
|
||||
return
|
||||
}
|
||||
testutil.AssertProtoJSONEqual(t, `{
|
||||
"id": "user1",
|
||||
"groupIds": ["team1", "team2", "team3"],
|
||||
"displayName": "User 1",
|
||||
"email": "user1@example.com"
|
||||
}`, du)
|
||||
}
|
||||
|
||||
func TestProvider_UserGroups(t *testing.T) {
|
||||
var mockAPI http.Handler
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
mockAPI.ServeHTTP(w, r)
|
||||
}))
|
||||
defer srv.Close()
|
||||
mockAPI = newMockAPI(t, srv)
|
||||
|
||||
p := New(
|
||||
WithURL(mustParseURL(srv.URL)),
|
||||
WithServiceAccount(&ServiceAccount{
|
||||
Username: "abc",
|
||||
PersonalAccessToken: "xyz",
|
||||
}),
|
||||
)
|
||||
groups, users, err := p.UserGroups(context.Background())
|
||||
assert.NoError(t, err)
|
||||
testutil.AssertProtoJSONEqual(t, `[
|
||||
{ "id": "user1", "groupIds": ["team1", "team2", "team3"], "displayName": "User 1", "email": "user1@example.com" },
|
||||
{ "id": "user2", "groupIds": ["team1", "team3"], "displayName": "User 2", "email": "user2@example.com" },
|
||||
{ "id": "user3", "groupIds": ["team3"], "displayName": "User 3", "email": "user3@example.com" },
|
||||
{ "id": "user4", "groupIds": ["team4"], "displayName": "User 4", "email": "user4@example.com" }
|
||||
]`, users)
|
||||
testutil.AssertProtoJSONEqual(t, `[
|
||||
{ "id": "team1", "name": "team1" },
|
||||
{ "id": "team2", "name": "team2" },
|
||||
{ "id": "team3", "name": "team3" },
|
||||
{ "id": "team4", "name": "team4" }
|
||||
]`, groups)
|
||||
}
|
||||
|
||||
func TestParseServiceAccount(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
rawServiceAccount string
|
||||
serviceAccount *ServiceAccount
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
"json",
|
||||
`{"username": "USERNAME", "personal_access_token": "PERSONAL_ACCESS_TOKEN"}`,
|
||||
&ServiceAccount{Username: "USERNAME", PersonalAccessToken: "PERSONAL_ACCESS_TOKEN"},
|
||||
false,
|
||||
},
|
||||
{
|
||||
"base64 json",
|
||||
`eyJ1c2VybmFtZSI6ICJVU0VSTkFNRSIsICJwZXJzb25hbF9hY2Nlc3NfdG9rZW4iOiAiUEVSU09OQUxfQUNDRVNTX1RPS0VOIn0=`,
|
||||
&ServiceAccount{Username: "USERNAME", PersonalAccessToken: "PERSONAL_ACCESS_TOKEN"},
|
||||
false,
|
||||
},
|
||||
{
|
||||
"empty",
|
||||
"",
|
||||
nil,
|
||||
true,
|
||||
},
|
||||
{
|
||||
"invalid",
|
||||
"Zm9v---",
|
||||
nil,
|
||||
true,
|
||||
},
|
||||
}
|
||||
for _, tc := range tests {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
got, err := ParseServiceAccount(tc.rawServiceAccount)
|
||||
require.True(t, (err != nil) == tc.wantErr)
|
||||
assert.Equal(t, tc.serviceAccount, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func mustParseURL(rawurl string) *url.URL {
|
||||
u, err := url.Parse(rawurl)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return u
|
||||
}
|
|
@ -1,246 +0,0 @@
|
|||
package github
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
const maxPageCount = 100
|
||||
|
||||
type (
|
||||
qlData struct {
|
||||
Organization *qlOrganization `json:"organization"`
|
||||
}
|
||||
qlMembersWithRoleConnection struct {
|
||||
Nodes []qlUser `json:"nodes"`
|
||||
PageInfo qlPageInfo `json:"pageInfo"`
|
||||
}
|
||||
qlOrganization struct {
|
||||
MembersWithRole *qlMembersWithRoleConnection `json:"membersWithRole"`
|
||||
Team *qlTeam `json:"team"`
|
||||
Teams *qlTeamConnection `json:"teams"`
|
||||
}
|
||||
qlPageInfo struct {
|
||||
EndCursor string `json:"endCursor"`
|
||||
HasNextPage bool `json:"hasNextPage"`
|
||||
}
|
||||
qlResult struct {
|
||||
Data *qlData `json:"data"`
|
||||
}
|
||||
qlTeam struct {
|
||||
ID string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Slug string `json:"slug"`
|
||||
Members *qlTeamMemberConnection `json:"members"`
|
||||
}
|
||||
qlTeamConnection struct {
|
||||
Edges []qlTeamEdge `json:"edges"`
|
||||
PageInfo qlPageInfo `json:"pageInfo"`
|
||||
}
|
||||
qlTeamEdge struct {
|
||||
Node qlTeam `json:"node"`
|
||||
}
|
||||
qlTeamMemberConnection struct {
|
||||
Edges []qlTeamMemberEdge `json:"edges"`
|
||||
PageInfo qlPageInfo `json:"pageInfo"`
|
||||
}
|
||||
qlTeamMemberEdge struct {
|
||||
Node qlUser `json:"node"`
|
||||
}
|
||||
qlUser struct {
|
||||
ID string `json:"id"`
|
||||
Login string `json:"login"`
|
||||
Name string `json:"name"`
|
||||
Email string `json:"email"`
|
||||
}
|
||||
)
|
||||
|
||||
func (p *Provider) listOrganizationMembers(ctx context.Context, orgSlug string) ([]qlUser, error) {
|
||||
var results []qlUser
|
||||
var cursor *string
|
||||
for {
|
||||
var res qlResult
|
||||
q := fmt.Sprintf(`query {
|
||||
organization(login:%s) {
|
||||
membersWithRole(first:%d, after:%s) {
|
||||
pageInfo {
|
||||
endCursor
|
||||
hasNextPage
|
||||
}
|
||||
nodes {
|
||||
id
|
||||
login
|
||||
name
|
||||
email
|
||||
}
|
||||
}
|
||||
}
|
||||
}`, encode(orgSlug), maxPageCount, encode(cursor))
|
||||
_, err := p.graphql(ctx, q, &res)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
results = append(results, res.Data.Organization.MembersWithRole.Nodes...)
|
||||
|
||||
if !res.Data.Organization.MembersWithRole.PageInfo.HasNextPage {
|
||||
break
|
||||
}
|
||||
cursor = &res.Data.Organization.MembersWithRole.PageInfo.EndCursor
|
||||
}
|
||||
return results, nil
|
||||
}
|
||||
|
||||
func (p *Provider) listOrganizationTeamsWithMemberIDs(ctx context.Context, orgSlug string) ([]teamWithMemberIDs, error) {
|
||||
var results []teamWithMemberIDs
|
||||
var pageInfos []qlPageInfo
|
||||
|
||||
// first query all the teams with their members
|
||||
var cursor *string
|
||||
for {
|
||||
var res qlResult
|
||||
q := fmt.Sprintf(`query {
|
||||
organization(login:%s) {
|
||||
teams(first:%d, after:%s) {
|
||||
pageInfo {
|
||||
endCursor
|
||||
hasNextPage
|
||||
}
|
||||
edges {
|
||||
node {
|
||||
id
|
||||
name
|
||||
slug
|
||||
members(first:%d) {
|
||||
pageInfo {
|
||||
endCursor
|
||||
hasNextPage
|
||||
}
|
||||
edges {
|
||||
node {
|
||||
id
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}`, encode(orgSlug), maxPageCount, encode(cursor), maxPageCount)
|
||||
_, err := p.graphql(ctx, q, &res)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for _, teamEdge := range res.Data.Organization.Teams.Edges {
|
||||
var memberIDs []string
|
||||
for _, memberEdge := range teamEdge.Node.Members.Edges {
|
||||
memberIDs = append(memberIDs, memberEdge.Node.ID)
|
||||
}
|
||||
results = append(results, teamWithMemberIDs{
|
||||
ID: teamEdge.Node.ID,
|
||||
Slug: teamEdge.Node.Slug,
|
||||
Name: teamEdge.Node.Name,
|
||||
MemberIDs: memberIDs,
|
||||
})
|
||||
pageInfos = append(pageInfos, teamEdge.Node.Members.PageInfo)
|
||||
}
|
||||
|
||||
if !res.Data.Organization.Teams.PageInfo.HasNextPage {
|
||||
break
|
||||
}
|
||||
cursor = &res.Data.Organization.Teams.PageInfo.EndCursor
|
||||
}
|
||||
|
||||
// it's possible we didn't get all the members if the initial query, so go through each team and
|
||||
// check the member pageInfo. If there are still remaining members, query those.
|
||||
for i, pageInfo := range pageInfos {
|
||||
if !pageInfo.HasNextPage {
|
||||
continue
|
||||
}
|
||||
|
||||
cursor = &pageInfo.EndCursor
|
||||
for {
|
||||
var res qlResult
|
||||
q := fmt.Sprintf(`query {
|
||||
organization(login:%s) {
|
||||
team(slug:%s) {
|
||||
members(first:%d, after:%s) {
|
||||
pageInfo {
|
||||
endCursor
|
||||
hasNextPage
|
||||
}
|
||||
edges {
|
||||
node {
|
||||
id
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}`, encode(orgSlug), encode(results[i].Slug), maxPageCount, encode(cursor))
|
||||
_, err := p.graphql(ctx, q, &res)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for _, memberEdge := range res.Data.Organization.Team.Members.Edges {
|
||||
results[i].MemberIDs = append(results[i].MemberIDs, memberEdge.Node.ID)
|
||||
}
|
||||
|
||||
if !res.Data.Organization.Team.Members.PageInfo.HasNextPage {
|
||||
break
|
||||
}
|
||||
cursor = &res.Data.Organization.Team.Members.PageInfo.EndCursor
|
||||
}
|
||||
}
|
||||
|
||||
return results, nil
|
||||
}
|
||||
|
||||
func (p *Provider) listUserOrganizationTeams(ctx context.Context, userSlug string, orgSlug string) ([]string, error) {
|
||||
// GitHub's Rest API doesn't have an easy way of querying this data, so we use the GraphQL API.
|
||||
|
||||
var teamSlugs []string
|
||||
var cursor *string
|
||||
for {
|
||||
var res qlResult
|
||||
q := fmt.Sprintf(`query {
|
||||
organization(login:%s) {
|
||||
teams(first:%d, userLogins:[%s], after:%s) {
|
||||
pageInfo {
|
||||
endCursor
|
||||
hasNextPage
|
||||
}
|
||||
edges {
|
||||
node {
|
||||
id
|
||||
slug
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}`, encode(orgSlug), maxPageCount, encode(userSlug), encode(cursor))
|
||||
_, err := p.graphql(ctx, q, &res)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for _, edge := range res.Data.Organization.Teams.Edges {
|
||||
teamSlugs = append(teamSlugs, edge.Node.Slug)
|
||||
}
|
||||
|
||||
if !res.Data.Organization.Teams.PageInfo.HasNextPage {
|
||||
break
|
||||
}
|
||||
cursor = &res.Data.Organization.Teams.PageInfo.EndCursor
|
||||
}
|
||||
|
||||
return teamSlugs, nil
|
||||
}
|
||||
|
||||
func encode(obj interface{}) string {
|
||||
bs, _ := json.Marshal(obj)
|
||||
return string(bs)
|
||||
}
|
|
@ -1,286 +0,0 @@
|
|||
// Package gitlab contains a directory provider for gitlab.
|
||||
package gitlab
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"sort"
|
||||
"strconv"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
"github.com/tomnomnom/linkheader"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/encoding"
|
||||
"github.com/pomerium/pomerium/internal/httputil"
|
||||
"github.com/pomerium/pomerium/internal/log"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/directory"
|
||||
)
|
||||
|
||||
// Name is the provider name.
|
||||
const Name = "gitlab"
|
||||
|
||||
var defaultURL = &url.URL{
|
||||
Scheme: "https",
|
||||
Host: "gitlab.com",
|
||||
}
|
||||
|
||||
type config struct {
|
||||
httpClient *http.Client
|
||||
serviceAccount *ServiceAccount
|
||||
url *url.URL
|
||||
}
|
||||
|
||||
// An Option updates the gitlab configuration.
|
||||
type Option func(cfg *config)
|
||||
|
||||
// WithServiceAccount sets the service account in the config.
|
||||
func WithServiceAccount(serviceAccount *ServiceAccount) Option {
|
||||
return func(cfg *config) {
|
||||
cfg.serviceAccount = serviceAccount
|
||||
}
|
||||
}
|
||||
|
||||
// WithHTTPClient sets the http client option.
|
||||
func WithHTTPClient(httpClient *http.Client) Option {
|
||||
return func(cfg *config) {
|
||||
cfg.httpClient = httputil.NewLoggingClient(httpClient, "gitlab_idp_client",
|
||||
func(evt *zerolog.Event) *zerolog.Event {
|
||||
return evt.Str("provider", "azure")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// WithURL sets the api url in the config.
|
||||
func WithURL(u *url.URL) Option {
|
||||
return func(cfg *config) {
|
||||
cfg.url = u
|
||||
}
|
||||
}
|
||||
|
||||
func getConfig(options ...Option) *config {
|
||||
cfg := new(config)
|
||||
WithHTTPClient(http.DefaultClient)(cfg)
|
||||
WithURL(defaultURL)(cfg)
|
||||
for _, option := range options {
|
||||
option(cfg)
|
||||
}
|
||||
return cfg
|
||||
}
|
||||
|
||||
// The Provider retrieves users and groups from gitlab.
|
||||
type Provider struct {
|
||||
cfg *config
|
||||
}
|
||||
|
||||
// New creates a new Provider.
|
||||
func New(options ...Option) *Provider {
|
||||
return &Provider{
|
||||
cfg: getConfig(options...),
|
||||
}
|
||||
}
|
||||
|
||||
func withLog(ctx context.Context) context.Context {
|
||||
return log.WithContext(ctx, func(c zerolog.Context) zerolog.Context {
|
||||
return c.Str("service", "directory").Str("provider", "gitlab")
|
||||
})
|
||||
}
|
||||
|
||||
// User returns the user record for the given id.
|
||||
func (p *Provider) User(ctx context.Context, userID, accessToken string) (*directory.User, error) {
|
||||
ctx = withLog(ctx)
|
||||
|
||||
du := &directory.User{
|
||||
Id: userID,
|
||||
}
|
||||
|
||||
au, err := p.getUser(ctx, userID, accessToken)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
du.DisplayName = au.Name
|
||||
du.Email = au.Email
|
||||
|
||||
groups, err := p.listGroups(ctx, accessToken)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for _, g := range groups {
|
||||
du.GroupIds = append(du.GroupIds, g.Id)
|
||||
}
|
||||
sort.Strings(du.GroupIds)
|
||||
|
||||
return du, nil
|
||||
}
|
||||
|
||||
// UserGroups gets the directory user groups for gitlab.
|
||||
func (p *Provider) UserGroups(ctx context.Context) ([]*directory.Group, []*directory.User, error) {
|
||||
ctx = withLog(ctx)
|
||||
|
||||
if p.cfg.serviceAccount == nil {
|
||||
return nil, nil, fmt.Errorf("gitlab: service account not defined")
|
||||
}
|
||||
|
||||
log.Info(ctx).Msg("getting user groups")
|
||||
|
||||
groups, err := p.listGroups(ctx, "")
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
userLookup := map[int]apiUserObject{}
|
||||
userIDToGroupIDs := map[int][]string{}
|
||||
for _, group := range groups {
|
||||
users, err := p.listGroupMembers(ctx, group.Id)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
for _, u := range users {
|
||||
userIDToGroupIDs[u.ID] = append(userIDToGroupIDs[u.ID], group.Id)
|
||||
userLookup[u.ID] = u
|
||||
}
|
||||
}
|
||||
|
||||
var users []*directory.User
|
||||
for _, u := range userLookup {
|
||||
user := &directory.User{
|
||||
Id: fmt.Sprint(u.ID),
|
||||
DisplayName: u.Name,
|
||||
Email: u.Email,
|
||||
}
|
||||
|
||||
user.GroupIds = append(user.GroupIds, userIDToGroupIDs[u.ID]...)
|
||||
|
||||
sort.Strings(user.GroupIds)
|
||||
users = append(users, user)
|
||||
}
|
||||
sort.Slice(users, func(i, j int) bool {
|
||||
return users[i].GetId() < users[j].GetId()
|
||||
})
|
||||
return groups, users, nil
|
||||
}
|
||||
|
||||
func (p *Provider) getUser(ctx context.Context, userID string, accessToken string) (*apiUserObject, error) {
|
||||
apiURL := p.cfg.url.ResolveReference(&url.URL{
|
||||
Path: fmt.Sprintf("/api/v4/users/%s", userID),
|
||||
}).String()
|
||||
var result apiUserObject
|
||||
_, err := p.api(ctx, accessToken, apiURL, &result)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("gitlab: error querying user: %w", err)
|
||||
}
|
||||
return &result, nil
|
||||
}
|
||||
|
||||
// listGroups returns a map, with key is group ID, element is group name.
|
||||
func (p *Provider) listGroups(ctx context.Context, accessToken string) ([]*directory.Group, error) {
|
||||
nextURL := p.cfg.url.ResolveReference(&url.URL{
|
||||
Path: "/api/v4/groups",
|
||||
}).String()
|
||||
var groups []*directory.Group
|
||||
for nextURL != "" {
|
||||
var result []struct {
|
||||
ID int `json:"id"`
|
||||
Name string `json:"name"`
|
||||
}
|
||||
hdrs, err := p.api(ctx, accessToken, nextURL, &result)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("gitlab: error querying groups: %w", err)
|
||||
}
|
||||
|
||||
for _, r := range result {
|
||||
groups = append(groups, &directory.Group{
|
||||
Id: strconv.Itoa(r.ID),
|
||||
Name: r.Name,
|
||||
})
|
||||
}
|
||||
|
||||
nextURL = getNextLink(hdrs)
|
||||
}
|
||||
return groups, nil
|
||||
}
|
||||
|
||||
func (p *Provider) listGroupMembers(ctx context.Context, groupID string) (users []apiUserObject, err error) {
|
||||
nextURL := p.cfg.url.ResolveReference(&url.URL{
|
||||
Path: fmt.Sprintf("/api/v4/groups/%s/members", groupID),
|
||||
}).String()
|
||||
for nextURL != "" {
|
||||
var result []apiUserObject
|
||||
hdrs, err := p.api(ctx, "", nextURL, &result)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("gitlab: error querying group members: %w", err)
|
||||
}
|
||||
|
||||
users = append(users, result...)
|
||||
nextURL = getNextLink(hdrs)
|
||||
}
|
||||
return users, nil
|
||||
}
|
||||
|
||||
func (p *Provider) api(ctx context.Context, accessToken string, uri string, out interface{}) (http.Header, error) {
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", uri, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("gitlab: failed to create HTTP request: %w", err)
|
||||
}
|
||||
req.Header.Set("Accept", "application/json")
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
if accessToken != "" {
|
||||
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", accessToken))
|
||||
} else {
|
||||
req.Header.Set("PRIVATE-TOKEN", p.cfg.serviceAccount.PrivateToken)
|
||||
}
|
||||
|
||||
res, err := p.cfg.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer res.Body.Close()
|
||||
|
||||
if res.StatusCode/100 != 2 {
|
||||
return nil, fmt.Errorf("gitlab: error querying api url=%s status_code=%d: %s", uri, res.StatusCode, res.Status)
|
||||
}
|
||||
|
||||
err = json.NewDecoder(res.Body).Decode(out)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return res.Header, nil
|
||||
}
|
||||
|
||||
func getNextLink(hdrs http.Header) string {
|
||||
for _, link := range linkheader.ParseMultiple(hdrs.Values("Link")) {
|
||||
if link.Rel == "next" {
|
||||
return link.URL
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// A ServiceAccount is used by the Gitlab provider to query the Gitlab API.
|
||||
type ServiceAccount struct {
|
||||
PrivateToken string `json:"private_token"`
|
||||
}
|
||||
|
||||
// ParseServiceAccount parses the service account in the config options.
|
||||
func ParseServiceAccount(rawServiceAccount string) (*ServiceAccount, error) {
|
||||
var serviceAccount ServiceAccount
|
||||
if err := encoding.DecodeBase64OrJSON(rawServiceAccount, &serviceAccount); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if serviceAccount.PrivateToken == "" {
|
||||
return nil, fmt.Errorf("private_token is required")
|
||||
}
|
||||
|
||||
return &serviceAccount, nil
|
||||
}
|
||||
|
||||
type apiUserObject struct {
|
||||
ID int `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Email string `json:"email"`
|
||||
}
|
|
@ -1,132 +0,0 @@
|
|||
package gitlab
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"testing"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/go-chi/chi/v5/middleware"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/testutil"
|
||||
)
|
||||
|
||||
type M = map[string]interface{}
|
||||
|
||||
func newMockAPI(t *testing.T, srv *httptest.Server) http.Handler {
|
||||
r := chi.NewRouter()
|
||||
r.Use(middleware.Logger)
|
||||
r.Route("/api/v4", func(r chi.Router) {
|
||||
r.Use(func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Header.Get("Private-Token") != "PRIVATE_TOKEN" {
|
||||
http.Error(w, "forbidden", http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
})
|
||||
r.Get("/groups", func(w http.ResponseWriter, r *http.Request) {
|
||||
_ = json.NewEncoder(w).Encode([]M{
|
||||
{"id": 1, "name": "Group 1"},
|
||||
{"id": 2, "name": "Group 2"},
|
||||
})
|
||||
})
|
||||
r.Get("/groups/{group_name}/members", func(w http.ResponseWriter, r *http.Request) {
|
||||
members := map[string][]M{
|
||||
"1": {
|
||||
{"id": 11, "name": "User 1", "email": "user1@example.com"},
|
||||
},
|
||||
"2": {
|
||||
{"id": 12, "name": "User 2", "email": "user2@example.com"},
|
||||
{"id": 13, "name": "User 3", "email": "user3@example.com"},
|
||||
},
|
||||
}
|
||||
_ = json.NewEncoder(w).Encode(members[chi.URLParam(r, "group_name")])
|
||||
})
|
||||
})
|
||||
return r
|
||||
}
|
||||
|
||||
func Test(t *testing.T) {
|
||||
var mockAPI http.Handler
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
mockAPI.ServeHTTP(w, r)
|
||||
}))
|
||||
defer srv.Close()
|
||||
mockAPI = newMockAPI(t, srv)
|
||||
|
||||
p := New(
|
||||
WithURL(mustParseURL(srv.URL)),
|
||||
WithServiceAccount(&ServiceAccount{
|
||||
PrivateToken: "PRIVATE_TOKEN",
|
||||
}),
|
||||
)
|
||||
groups, users, err := p.UserGroups(context.Background())
|
||||
assert.NoError(t, err)
|
||||
testutil.AssertProtoJSONEqual(t, `[
|
||||
{ "id": "11", "groupIds": ["1"], "displayName": "User 1", "email": "user1@example.com" },
|
||||
{ "id": "12", "groupIds": ["2"], "displayName": "User 2", "email": "user2@example.com" },
|
||||
{ "id": "13", "groupIds": ["2"], "displayName": "User 3", "email": "user3@example.com" }
|
||||
]`, users)
|
||||
testutil.AssertProtoJSONEqual(t, `[
|
||||
{ "id": "1", "name": "Group 1" },
|
||||
{ "id": "2", "name": "Group 2" }
|
||||
]`, groups)
|
||||
}
|
||||
|
||||
func TestParseServiceAccount(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
rawServiceAccount string
|
||||
serviceAccount *ServiceAccount
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
"json",
|
||||
`{"private_token":"PRIVATE_TOKEN"}`,
|
||||
&ServiceAccount{PrivateToken: "PRIVATE_TOKEN"},
|
||||
false,
|
||||
},
|
||||
{
|
||||
"base64 json",
|
||||
`eyJwcml2YXRlX3Rva2VuIjoiUFJJVkFURV9UT0tFTiJ9`,
|
||||
&ServiceAccount{PrivateToken: "PRIVATE_TOKEN"},
|
||||
false,
|
||||
},
|
||||
{
|
||||
"empty",
|
||||
"",
|
||||
nil,
|
||||
true,
|
||||
},
|
||||
{
|
||||
"invalid",
|
||||
"Zm9v---",
|
||||
nil,
|
||||
true,
|
||||
},
|
||||
}
|
||||
for _, tc := range tests {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
got, err := ParseServiceAccount(tc.rawServiceAccount)
|
||||
require.True(t, (err != nil) == tc.wantErr)
|
||||
assert.Equal(t, tc.serviceAccount, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func mustParseURL(rawurl string) *url.URL {
|
||||
u, err := url.Parse(rawurl)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return u
|
||||
}
|
|
@ -1,323 +0,0 @@
|
|||
// Package google contains the Google directory provider.
|
||||
package google
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"sort"
|
||||
"sync"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
"golang.org/x/oauth2/google"
|
||||
admin "google.golang.org/api/admin/directory/v1"
|
||||
"google.golang.org/api/googleapi"
|
||||
"google.golang.org/api/option"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/directory/directoryerrors"
|
||||
"github.com/pomerium/pomerium/internal/encoding"
|
||||
"github.com/pomerium/pomerium/internal/log"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/directory"
|
||||
)
|
||||
|
||||
const (
|
||||
// Name is the provider name.
|
||||
Name = "google"
|
||||
|
||||
currentAccountCustomerID = "my_customer"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultProviderURL = "https://www.googleapis.com/"
|
||||
)
|
||||
|
||||
type config struct {
|
||||
serviceAccount *ServiceAccount
|
||||
url string
|
||||
}
|
||||
|
||||
// An Option changes the configuration for the Google directory provider.
|
||||
type Option func(cfg *config)
|
||||
|
||||
// WithServiceAccount sets the service account in the Google configuration.
|
||||
func WithServiceAccount(serviceAccount *ServiceAccount) Option {
|
||||
return func(cfg *config) {
|
||||
cfg.serviceAccount = serviceAccount
|
||||
}
|
||||
}
|
||||
|
||||
// WithURL sets the provider url to use.
|
||||
func WithURL(url string) Option {
|
||||
return func(cfg *config) {
|
||||
cfg.url = url
|
||||
}
|
||||
}
|
||||
|
||||
func getConfig(opts ...Option) *config {
|
||||
cfg := new(config)
|
||||
WithURL(defaultProviderURL)(cfg)
|
||||
for _, opt := range opts {
|
||||
opt(cfg)
|
||||
}
|
||||
return cfg
|
||||
}
|
||||
|
||||
// Required scopes for groups api
|
||||
// https://developers.google.com/admin-sdk/directory/v1/reference/groups/list
|
||||
var apiScopes = []string{admin.AdminDirectoryUserReadonlyScope, admin.AdminDirectoryGroupReadonlyScope}
|
||||
|
||||
// A Provider is a Google directory provider.
|
||||
type Provider struct {
|
||||
cfg *config
|
||||
log zerolog.Logger
|
||||
|
||||
mu sync.RWMutex
|
||||
apiClient *admin.Service
|
||||
}
|
||||
|
||||
// New creates a new Google directory provider.
|
||||
func New(options ...Option) *Provider {
|
||||
return &Provider{
|
||||
cfg: getConfig(options...),
|
||||
log: log.With().Str("service", "directory").Str("provider", "google").Logger(),
|
||||
}
|
||||
}
|
||||
|
||||
// User returns the user record for the given id.
|
||||
func (p *Provider) User(ctx context.Context, userID, accessToken string) (*directory.User, error) {
|
||||
apiClient, err := p.getAPIClient(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("google: error getting API client: %w", err)
|
||||
}
|
||||
|
||||
du := &directory.User{
|
||||
Id: userID,
|
||||
}
|
||||
|
||||
au, err := apiClient.Users.Get(userID).
|
||||
Context(ctx).
|
||||
Do()
|
||||
if isAccessDenied(err) {
|
||||
// ignore forbidden errors as a user may login from another gsuite domain
|
||||
return du, directoryerrors.ErrPreferExistingInformation
|
||||
} else if err != nil {
|
||||
return nil, fmt.Errorf("google: error getting user: %w", err)
|
||||
} else {
|
||||
if au.Name != nil {
|
||||
du.DisplayName = au.Name.FullName
|
||||
}
|
||||
du.Email = au.PrimaryEmail
|
||||
}
|
||||
|
||||
err = apiClient.Groups.List().
|
||||
Context(ctx).
|
||||
UserKey(userID).
|
||||
Pages(ctx, func(res *admin.Groups) error {
|
||||
for _, g := range res.Groups {
|
||||
du.GroupIds = append(du.GroupIds, g.Id)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("google: error getting groups for user: %w", err)
|
||||
}
|
||||
|
||||
sort.Strings(du.GroupIds)
|
||||
|
||||
return du, nil
|
||||
}
|
||||
|
||||
// UserGroups returns a slice of group names a given user is in
|
||||
// NOTE: groups via Directory API is limited to 1 QPS!
|
||||
// https://developers.google.com/admin-sdk/directory/v1/reference/groups/list
|
||||
// https://developers.google.com/admin-sdk/directory/v1/limits
|
||||
func (p *Provider) UserGroups(ctx context.Context) ([]*directory.Group, []*directory.User, error) {
|
||||
apiClient, err := p.getAPIClient(ctx)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("google: error getting API client: %w", err)
|
||||
}
|
||||
|
||||
// query all the groups
|
||||
var groups []*directory.Group
|
||||
err = apiClient.Groups.List().
|
||||
Context(ctx).
|
||||
Customer(currentAccountCustomerID).
|
||||
Pages(ctx, func(res *admin.Groups) error {
|
||||
for _, g := range res.Groups {
|
||||
// Skip group without member.
|
||||
if g.DirectMembersCount == 0 {
|
||||
continue
|
||||
}
|
||||
groups = append(groups, &directory.Group{
|
||||
Id: g.Id,
|
||||
Name: g.Email,
|
||||
Email: g.Email,
|
||||
})
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("google: error getting groups: %w", err)
|
||||
}
|
||||
|
||||
// query all the user members for each group
|
||||
// - create a lookup table for the user (storing id and name)
|
||||
// (this includes users who aren't necessarily members of the same organization)
|
||||
// - create a lookup table for the user's groups
|
||||
userLookup := map[string]apiUserObject{}
|
||||
userIDToGroups := map[string][]string{}
|
||||
for _, group := range groups {
|
||||
group := group
|
||||
err = apiClient.Members.List(group.Id).
|
||||
Context(ctx).
|
||||
Pages(ctx, func(res *admin.Members) error {
|
||||
for _, member := range res.Members {
|
||||
// only include user objects
|
||||
if member.Type != "USER" {
|
||||
continue
|
||||
}
|
||||
|
||||
userLookup[member.Id] = apiUserObject{
|
||||
ID: member.Id,
|
||||
Email: member.Email,
|
||||
}
|
||||
userIDToGroups[member.Id] = append(userIDToGroups[member.Id], group.Id)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("google: error getting group members: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// query all the users in the organization
|
||||
err = apiClient.Users.List().
|
||||
Context(ctx).
|
||||
Customer(currentAccountCustomerID).
|
||||
Pages(ctx, func(res *admin.Users) error {
|
||||
for _, u := range res.Users {
|
||||
auo := apiUserObject{
|
||||
ID: u.Id,
|
||||
Email: u.PrimaryEmail,
|
||||
}
|
||||
if u.Name != nil {
|
||||
auo.DisplayName = u.Name.FullName
|
||||
}
|
||||
userLookup[u.Id] = auo
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("google: error getting users: %w", err)
|
||||
}
|
||||
|
||||
var users []*directory.User
|
||||
for _, u := range userLookup {
|
||||
groups := userIDToGroups[u.ID]
|
||||
sort.Strings(groups)
|
||||
users = append(users, &directory.User{
|
||||
Id: u.ID,
|
||||
GroupIds: groups,
|
||||
DisplayName: u.DisplayName,
|
||||
Email: u.Email,
|
||||
})
|
||||
}
|
||||
sort.Slice(users, func(i, j int) bool {
|
||||
return users[i].Id < users[j].Id
|
||||
})
|
||||
return groups, users, nil
|
||||
}
|
||||
|
||||
func (p *Provider) getAPIClient(ctx context.Context) (*admin.Service, error) {
|
||||
p.mu.RLock()
|
||||
apiClient := p.apiClient
|
||||
p.mu.RUnlock()
|
||||
if apiClient != nil {
|
||||
return apiClient, nil
|
||||
}
|
||||
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
if p.apiClient != nil {
|
||||
return p.apiClient, nil
|
||||
}
|
||||
|
||||
apiCreds, err := json.Marshal(p.cfg.serviceAccount)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("google: could not marshal service account json %w", err)
|
||||
}
|
||||
|
||||
config, err := google.JWTConfigFromJSON(apiCreds, apiScopes...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("google: error reading jwt config: %w", err)
|
||||
}
|
||||
config.Subject = p.cfg.serviceAccount.ImpersonateUser
|
||||
|
||||
ts := config.TokenSource(ctx)
|
||||
|
||||
p.apiClient, err = admin.NewService(ctx,
|
||||
option.WithTokenSource(ts),
|
||||
option.WithEndpoint(p.cfg.url))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("google: failed creating admin service %w", err)
|
||||
}
|
||||
return p.apiClient, nil
|
||||
}
|
||||
|
||||
// A ServiceAccount is used to authenticate with the Google APIs.
|
||||
//
|
||||
// Google oauth fields are from https://github.com/golang/oauth2/blob/master/google/google.go#L99
|
||||
type ServiceAccount struct {
|
||||
Type string `json:"type"` // serviceAccountKey or userCredentialsKey
|
||||
|
||||
// Service Account fields
|
||||
ClientEmail string `json:"client_email"`
|
||||
PrivateKeyID string `json:"private_key_id"`
|
||||
PrivateKey string `json:"private_key"`
|
||||
TokenURL string `json:"token_uri"`
|
||||
ProjectID string `json:"project_id"`
|
||||
|
||||
// User Credential fields
|
||||
// (These typically come from gcloud auth.)
|
||||
ClientSecret string `json:"client_secret"`
|
||||
ClientID string `json:"client_id"`
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
|
||||
// The User to use for Admin Directory API calls
|
||||
ImpersonateUser string `json:"impersonate_user"`
|
||||
}
|
||||
|
||||
// ParseServiceAccount parses the service account in the config options.
|
||||
func ParseServiceAccount(rawServiceAccount string) (*ServiceAccount, error) {
|
||||
var serviceAccount ServiceAccount
|
||||
if err := encoding.DecodeBase64OrJSON(rawServiceAccount, &serviceAccount); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if serviceAccount.ImpersonateUser == "" {
|
||||
return nil, fmt.Errorf("impersonate_user is required")
|
||||
}
|
||||
|
||||
return &serviceAccount, nil
|
||||
}
|
||||
|
||||
type apiUserObject struct {
|
||||
ID string
|
||||
DisplayName string
|
||||
Email string
|
||||
}
|
||||
|
||||
func isAccessDenied(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
gerr := new(googleapi.Error)
|
||||
if errors.As(err, &gerr) {
|
||||
return gerr.Code == http.StatusForbidden
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
|
@ -1,260 +0,0 @@
|
|||
package google
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/go-chi/chi/v5/middleware"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/directory/directoryerrors"
|
||||
"github.com/pomerium/pomerium/internal/testutil"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/directory"
|
||||
)
|
||||
|
||||
var privateKey = `
|
||||
-----BEGIN RSA PRIVATE KEY-----
|
||||
MIIG4wIBAAKCAYEAnetGqPqS6dqYnV9S5S8gL34t7RRUMsf4prxIR+1PMv+bEqVH
|
||||
pzXBbzgGKIbQbu+njoRzhQ95RbcEzZiVFivNggijkFoUNkFIjy42O/xXdTRTX3/u
|
||||
4pxu1ctccqYnZwnry6E8ekQTVX7kgmVqgzIrY1Y6K3PVlhkgGDK/TStu+RIPoA73
|
||||
vJUpJTFTw+6tgUSBmCkzctQsGFGUiGBRpqlxogEkImJYrcMUJkWhTopLy79+3OB5
|
||||
eGGKWwA6P2ZhBB4RKHBWyWipsYxr889QNG6P3o+Be6lIzvcdiIIBYK8qWLmO45hR
|
||||
xUGWSRK8sveAJO54t+wGE0dSPVKfoS4oNqGYdGhBzoTwsfZRvyvcWeQob4DNCqQa
|
||||
n41XAuYGOG3X1PexdwGSwwqrq2tuG9d2AJ2NjG8nC9hjuuKDfBGTigTwzrwkrn+F
|
||||
3o94NoglQgsXfZeWoBXR5HDaDTdqexSRK0OpSPbvzkn8QDUymdaw7nRS7kU2O7fa
|
||||
W8kxiV8AVt2v/jYjAgMBAAECggGAS+6NE0Mo0Pki2Mi0+y4ls7BgNNbJhYFRthpi
|
||||
RvN8WXE+B0EhquzWDbxKecIZBr6FOqnFQf2muja+QH1VckutjRDKVOZ7QXsygGYf
|
||||
/cff5aM7U3gYTS4avQIDeb0axRioIElu4vtIsJtLFMfe5yaAZktXvPz9fiamn/wG
|
||||
r/xqZ6ifir6nsC2okxGczWE+XCGsjpWA/321lhvj548os5JV6SfTUBUpvqNGVQC2
|
||||
ByXIPDffsCTfQ1rjQ85gM4vuqiQqKn/KXRrMR1kRIOrglMJ6dllitsadkfWdPkVg
|
||||
fHjM1KAnw/uob5kFhvqeEll9sESfCXttrb4XneKAsEck958ChSpU4csaBAfLFVYP
|
||||
5xyIfoaQ+/CUjWI0u1lQbg6jZO59rYfdd5OlH+MyhHybuHR1a0G1izNfuG9WPOWI
|
||||
aprNayH2Wxy9/ZvlrE5yTAeW9tof28hO6O7wBNOcJTrzztsN+V8pSAo0IE2r4D83
|
||||
h978LneAwhC/8mVvzhd/y2t99vcBAoHBAMumCoHmHRAsYBApg2SHxPCTDi4fS1hC
|
||||
IbcuzyvJLv214bHOsdgg2I72a1Q+bbrZAHgENVgSP9Sx1k8aXlvnGOEbpL8w3jRL
|
||||
G4/qXzGrMBp3LCzupi3dI3yrrIMWb2C0goyHeAejzrfaM+uDYTGW4iqhA39zBj4o
|
||||
zoydz3v0i8Yag7Df9MIwr34WD9Ng0oXh8XRCAYJmS1e43jnM+XcFdSfGVhKn9h1B
|
||||
Cbv/hqUSv6baNloWLlPBffLII5bx633MMwKBwQDGg87fKEPz73pbp1kAOuXDpUnm
|
||||
+OUFuf6Uqr/OqCorh8SOwxb2q3ScqyCWGVLOaVBcSV+zxIdcQzX2kjf1mj7PcQ6Q
|
||||
2xfDIS1/xT+RiX9LO0kbkVDYcwcGeKVtmUwWyjauo96OB2r+SchTsNJpYOT3a/7r
|
||||
JUKdbHFwsFwAx5q7r9mOh0BOybuXM6N7lUDBf4SgrhjnKRh1pME3R0JbJj9m8tZg
|
||||
SsWlHcj04yAXJ7NGemiiYgeDZ4unsAfx7/sS/lECgcEAsL0Yj2XTQU8Ry9ULaDsA
|
||||
az1k6BhWvnEea6lfOQPwGVY5WqQk6oqPB3vK6CEKAEgGRSJ53UZxSTlR4fLjg2UL
|
||||
zYm9MATMQ5wPfpYMKcIFDGLy3sf7RwCNpMwk+tuEq+vdBPMo85BxflQMDVBHEM9+
|
||||
1zpIG9sKxvWJVLY89LnmeHZYZi/nboTsOUQSVgPIkVLmx1vljXMT3jzd+FHxCx+c
|
||||
bnmOB8DnMrpYJWV9SFP+KmNlGkf3ys65bPPPF1g7ZUDLAoHAaKHqtQa1Imr0NED1
|
||||
kUB6AHArjrlbhXQuck+5f4R1jbIm8RR1Exj2AunT6Cl60t8Bg1MNRWRt8DxgwhD5
|
||||
u9NMDezKP6GrWacwIytlQSGW3aFm/EfQs/WVG10V3LmzOEPnJI+s63GPfG6JT0tg
|
||||
7DgtFxhuKaTfAri45iueoq6SqSCb7Brv01dTL/QA1E+r7RF4Z3S8HYM0qDVpvegq
|
||||
Wn7DZlDSm7htioUzeZgJPwsm3BwC8Kv4x9MY8g6/cU8LKEyxAoHAWCaDpLIuQ51r
|
||||
PeL+u/1cfNdi6OOrtZ6S95tu3Vv+mYzpCPnOpgPHFp3l+RGmLg56t7uvHFaFxOvB
|
||||
EjPm4bVhnPSA7pl7ZHQXhinG9+4UgcejoCAJzfg05BI1tMbwFZ+C0tG/PNzBlaX+
|
||||
IwkGO8VP/54N6wL1UqfZ8AKJFZW8G7W7KVkjqye1FS4oeDlJ197t/X+PMn5sFAc7
|
||||
UVsDaSelBqpsfmetXSH8KC3XkbgCtHvgAnJDkGkp84VmJvMr5ukv
|
||||
-----END RSA PRIVATE KEY-----
|
||||
`
|
||||
|
||||
type M = map[string]interface{}
|
||||
|
||||
func newMockAPI(t *testing.T, srv *httptest.Server) http.Handler {
|
||||
r := chi.NewRouter()
|
||||
r.Use(middleware.Logger)
|
||||
r.Post("/token", func(w http.ResponseWriter, r *http.Request) {
|
||||
_ = json.NewEncoder(w).Encode(M{
|
||||
"access_token": "ACCESSTOKEN",
|
||||
"token_type": "Bearer",
|
||||
"refresh_token": "REFRESHTOKEN",
|
||||
})
|
||||
})
|
||||
r.Route("/admin/directory/v1", func(r chi.Router) {
|
||||
r.Route("/groups", func(r chi.Router) {
|
||||
r.Get("/", func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Query().Get("userKey") {
|
||||
case "user1":
|
||||
_ = json.NewEncoder(w).Encode(M{
|
||||
"kind": "admin#directory#groups",
|
||||
"groups": []M{
|
||||
{"id": "group1"},
|
||||
{"id": "group2"},
|
||||
},
|
||||
})
|
||||
default:
|
||||
_ = json.NewEncoder(w).Encode(M{
|
||||
"kind": "admin#directory#groups",
|
||||
"groups": []M{
|
||||
{"id": "group1", "directMembersCount": "2"},
|
||||
{"id": "group2"},
|
||||
},
|
||||
})
|
||||
}
|
||||
})
|
||||
r.Get("/{groupKey}/members", func(w http.ResponseWriter, r *http.Request) {
|
||||
switch chi.URLParam(r, "groupKey") {
|
||||
case "group1":
|
||||
_ = json.NewEncoder(w).Encode(M{
|
||||
"members": []M{
|
||||
{
|
||||
"kind": "admin#directory#member",
|
||||
"id": "inside-user1",
|
||||
"email": "user1@inside.test",
|
||||
"type": "USER",
|
||||
},
|
||||
{
|
||||
"kind": "admin#directory#member",
|
||||
"id": "outside-user1",
|
||||
"email": "user1@outside.test",
|
||||
"type": "USER",
|
||||
},
|
||||
},
|
||||
})
|
||||
}
|
||||
})
|
||||
})
|
||||
r.Route("/users", func(r chi.Router) {
|
||||
r.Get("/", func(w http.ResponseWriter, r *http.Request) {
|
||||
_ = json.NewEncoder(w).Encode(M{
|
||||
"kind": "admin#directory#users",
|
||||
"users": []M{
|
||||
{
|
||||
"kind": "admin#directory#user",
|
||||
"id": "inside-user1",
|
||||
"primaryEmail": "user1@inside.test",
|
||||
},
|
||||
},
|
||||
})
|
||||
})
|
||||
r.Get("/{user_id}", func(w http.ResponseWriter, r *http.Request) {
|
||||
switch chi.URLParam(r, "user_id") {
|
||||
case "inside-user1":
|
||||
_ = json.NewEncoder(w).Encode(M{
|
||||
"kind": "admin#directory#user",
|
||||
"id": "inside-user1",
|
||||
"name": M{
|
||||
"fullName": "User 1",
|
||||
},
|
||||
"primaryEmail": "user1@inside.test",
|
||||
})
|
||||
case "outside-user1":
|
||||
http.Error(w, "forbidden", http.StatusForbidden)
|
||||
default:
|
||||
http.Error(w, "not found", http.StatusNotFound)
|
||||
}
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
return r
|
||||
}
|
||||
|
||||
func TestProvider_User(t *testing.T) {
|
||||
ctx, clearTimeout := context.WithTimeout(context.Background(), time.Second*30)
|
||||
defer clearTimeout()
|
||||
|
||||
var mockAPI http.Handler
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
mockAPI.ServeHTTP(w, r)
|
||||
}))
|
||||
defer srv.Close()
|
||||
mockAPI = newMockAPI(t, srv)
|
||||
|
||||
p := New(WithServiceAccount(&ServiceAccount{
|
||||
Type: "service_account",
|
||||
PrivateKey: privateKey,
|
||||
TokenURL: srv.URL + "/token",
|
||||
}), WithURL(srv.URL))
|
||||
|
||||
du, err := p.User(ctx, "inside-user1", "")
|
||||
if !assert.NoError(t, err) {
|
||||
return
|
||||
}
|
||||
assert.Equal(t, "inside-user1", du.Id)
|
||||
assert.Equal(t, "user1@inside.test", du.Email)
|
||||
assert.Equal(t, "User 1", du.DisplayName)
|
||||
assert.Equal(t, []string{"group1", "group2"}, du.GroupIds)
|
||||
|
||||
du, err = p.User(ctx, "outside-user1", "")
|
||||
if assert.ErrorIs(t, err, directoryerrors.ErrPreferExistingInformation) {
|
||||
return
|
||||
}
|
||||
assert.Equal(t, "outside-user1", du.Id)
|
||||
}
|
||||
|
||||
func TestProvider_UserGroups(t *testing.T) {
|
||||
ctx, clearTimeout := context.WithTimeout(context.Background(), time.Second*30)
|
||||
defer clearTimeout()
|
||||
|
||||
var mockAPI http.Handler
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
mockAPI.ServeHTTP(w, r)
|
||||
}))
|
||||
defer srv.Close()
|
||||
mockAPI = newMockAPI(t, srv)
|
||||
|
||||
p := New(WithServiceAccount(&ServiceAccount{
|
||||
Type: "service_account",
|
||||
PrivateKey: privateKey,
|
||||
TokenURL: srv.URL + "/token",
|
||||
}), WithURL(srv.URL))
|
||||
|
||||
dgs, dus, err := p.UserGroups(ctx)
|
||||
if !assert.NoError(t, err) {
|
||||
return
|
||||
}
|
||||
|
||||
assert.Equal(t, []*directory.Group{
|
||||
{Id: "group1"},
|
||||
}, dgs)
|
||||
testutil.AssertProtoJSONEqual(t, `[
|
||||
{ "id": "inside-user1", "email": "user1@inside.test", "groupIds": ["group1"] },
|
||||
{ "id": "outside-user1", "email": "user1@outside.test", "groupIds": ["group1"] }
|
||||
]`, dus)
|
||||
}
|
||||
|
||||
func TestParseServiceAccount(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
rawServiceAccount string
|
||||
serviceAccount *ServiceAccount
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
"json",
|
||||
`{"impersonate_user":"IMPERSONATE_USER"}`,
|
||||
&ServiceAccount{ImpersonateUser: "IMPERSONATE_USER"},
|
||||
false,
|
||||
},
|
||||
{
|
||||
"base64 json",
|
||||
`eyJpbXBlcnNvbmF0ZV91c2VyIjoiSU1QRVJTT05BVEVfVVNFUiJ9`,
|
||||
&ServiceAccount{ImpersonateUser: "IMPERSONATE_USER"},
|
||||
false,
|
||||
},
|
||||
{
|
||||
"empty",
|
||||
"",
|
||||
nil,
|
||||
true,
|
||||
},
|
||||
{
|
||||
"invalid",
|
||||
"Zm9v---",
|
||||
nil,
|
||||
true,
|
||||
},
|
||||
}
|
||||
for _, tc := range tests {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
got, err := ParseServiceAccount(tc.rawServiceAccount)
|
||||
require.True(t, (err != nil) == tc.wantErr)
|
||||
assert.Equal(t, tc.serviceAccount, got)
|
||||
})
|
||||
}
|
||||
}
|
|
@ -1,421 +0,0 @@
|
|||
// Package okta contains the Okta directory provider.
|
||||
package okta
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"sort"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
"github.com/tomnomnom/linkheader"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/encoding"
|
||||
"github.com/pomerium/pomerium/internal/httputil"
|
||||
"github.com/pomerium/pomerium/internal/log"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/directory"
|
||||
)
|
||||
|
||||
// Name is the provider name.
|
||||
const Name = "okta"
|
||||
|
||||
const (
|
||||
// Okta use ISO-8601, see https://developer.okta.com/docs/reference/api-overview/#media-types
|
||||
filterDateFormat = "2006-01-02T15:04:05.999Z"
|
||||
|
||||
batchSize = 200
|
||||
readLimit = 100 * 1024
|
||||
httpSuccessClass = 2
|
||||
)
|
||||
|
||||
// Errors.
|
||||
var (
|
||||
ErrAPIKeyRequired = errors.New("okta: api_key is required")
|
||||
ErrServiceAccountNotDefined = errors.New("okta: service account not defined")
|
||||
ErrProviderURLNotDefined = errors.New("okta: provider url not defined")
|
||||
)
|
||||
|
||||
type config struct {
|
||||
batchSize int
|
||||
httpClient *http.Client
|
||||
providerURL *url.URL
|
||||
serviceAccount *ServiceAccount
|
||||
}
|
||||
|
||||
// An Option configures the Okta Provider.
|
||||
type Option func(cfg *config)
|
||||
|
||||
// WithBatchSize sets the batch size option.
|
||||
func WithBatchSize(batchSize int) Option {
|
||||
return func(cfg *config) {
|
||||
cfg.batchSize = batchSize
|
||||
}
|
||||
}
|
||||
|
||||
// WithHTTPClient sets the http client option.
|
||||
func WithHTTPClient(httpClient *http.Client) Option {
|
||||
return func(cfg *config) {
|
||||
cfg.httpClient = httputil.NewLoggingClient(httpClient, "okta_idp_client",
|
||||
func(evt *zerolog.Event) *zerolog.Event {
|
||||
return evt.Str("provider", "okta")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// WithProviderURL sets the provider URL option.
|
||||
func WithProviderURL(uri *url.URL) Option {
|
||||
return func(cfg *config) {
|
||||
cfg.providerURL = uri
|
||||
}
|
||||
}
|
||||
|
||||
// WithServiceAccount sets the service account option.
|
||||
func WithServiceAccount(serviceAccount *ServiceAccount) Option {
|
||||
return func(cfg *config) {
|
||||
cfg.serviceAccount = serviceAccount
|
||||
}
|
||||
}
|
||||
|
||||
func getConfig(options ...Option) *config {
|
||||
cfg := new(config)
|
||||
WithBatchSize(batchSize)(cfg)
|
||||
WithHTTPClient(http.DefaultClient)(cfg)
|
||||
for _, option := range options {
|
||||
option(cfg)
|
||||
}
|
||||
|
||||
return cfg
|
||||
}
|
||||
|
||||
// A Provider is an Okta user group directory provider.
|
||||
type Provider struct {
|
||||
cfg *config
|
||||
lastUpdated *time.Time
|
||||
groups map[string]*directory.Group
|
||||
}
|
||||
|
||||
// New creates a new Provider.
|
||||
func New(options ...Option) *Provider {
|
||||
return &Provider{
|
||||
cfg: getConfig(options...),
|
||||
groups: make(map[string]*directory.Group),
|
||||
}
|
||||
}
|
||||
|
||||
func withLog(ctx context.Context) context.Context {
|
||||
return log.WithContext(ctx, func(c zerolog.Context) zerolog.Context {
|
||||
return c.Str("service", "directory").Str("provider", "okta")
|
||||
})
|
||||
}
|
||||
|
||||
// User returns the user record for the given id.
|
||||
func (p *Provider) User(ctx context.Context, userID, accessToken string) (*directory.User, error) {
|
||||
ctx = withLog(ctx)
|
||||
|
||||
if p.cfg.serviceAccount == nil {
|
||||
return nil, ErrServiceAccountNotDefined
|
||||
}
|
||||
|
||||
du := &directory.User{
|
||||
Id: userID,
|
||||
}
|
||||
|
||||
au, err := p.getUser(ctx, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
du.DisplayName = au.getDisplayName()
|
||||
du.Email = au.Profile.Email
|
||||
|
||||
groups, err := p.listUserGroups(ctx, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for _, g := range groups {
|
||||
du.GroupIds = append(du.GroupIds, g.ID)
|
||||
}
|
||||
sort.Strings(du.GroupIds)
|
||||
|
||||
return du, nil
|
||||
}
|
||||
|
||||
// UserGroups fetches the groups of which the user is a member
|
||||
// https://developer.okta.com/docs/reference/api/users/#get-user-s-groups
|
||||
func (p *Provider) UserGroups(ctx context.Context) ([]*directory.Group, []*directory.User, error) {
|
||||
ctx = withLog(ctx)
|
||||
|
||||
if p.cfg.serviceAccount == nil {
|
||||
return nil, nil, ErrServiceAccountNotDefined
|
||||
}
|
||||
|
||||
log.Info(ctx).Msg("getting user groups")
|
||||
|
||||
if p.cfg.providerURL == nil {
|
||||
return nil, nil, ErrProviderURLNotDefined
|
||||
}
|
||||
|
||||
groups, err := p.getGroups(ctx)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
userLookup := map[string]apiUserObject{}
|
||||
userIDToGroups := map[string][]string{}
|
||||
for i := 0; i < len(groups); i++ {
|
||||
group := groups[i]
|
||||
users, err := p.getGroupMembers(ctx, group.Id)
|
||||
|
||||
// if we get a 404 on the member query, it means the group doesn't exist, so we should remove it from
|
||||
// the cached lookup and the local groups list
|
||||
var apiErr *APIError
|
||||
if errors.As(err, &apiErr) && apiErr.HTTPStatusCode == http.StatusNotFound {
|
||||
log.Debug(ctx).Str("group", group.Id).Msg("okta: group was removed")
|
||||
delete(p.groups, group.Id)
|
||||
groups = append(groups[:i], groups[i+1:]...)
|
||||
i--
|
||||
continue
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
for _, u := range users {
|
||||
userIDToGroups[u.ID] = append(userIDToGroups[u.ID], group.Id)
|
||||
userLookup[u.ID] = u
|
||||
}
|
||||
}
|
||||
|
||||
var users []*directory.User
|
||||
for _, u := range userLookup {
|
||||
groups := userIDToGroups[u.ID]
|
||||
sort.Strings(groups)
|
||||
users = append(users, &directory.User{
|
||||
Id: u.ID,
|
||||
GroupIds: groups,
|
||||
DisplayName: u.getDisplayName(),
|
||||
Email: u.Profile.Email,
|
||||
})
|
||||
}
|
||||
sort.Slice(users, func(i, j int) bool {
|
||||
return users[i].Id < users[j].Id
|
||||
})
|
||||
return groups, users, nil
|
||||
}
|
||||
|
||||
func (p *Provider) getGroups(ctx context.Context) ([]*directory.Group, error) {
|
||||
u := &url.URL{Path: "/api/v1/groups"}
|
||||
q := u.Query()
|
||||
q.Set("limit", strconv.Itoa(p.cfg.batchSize))
|
||||
if p.lastUpdated != nil {
|
||||
q.Set("filter", fmt.Sprintf(`lastUpdated gt "%[1]s" or lastMembershipUpdated gt "%[1]s"`, p.lastUpdated.UTC().Format(filterDateFormat)))
|
||||
} else {
|
||||
now := time.Now()
|
||||
p.lastUpdated = &now
|
||||
}
|
||||
u.RawQuery = q.Encode()
|
||||
|
||||
groupURL := p.cfg.providerURL.ResolveReference(u).String()
|
||||
for groupURL != "" {
|
||||
var out []apiGroupObject
|
||||
hdrs, err := p.apiGet(ctx, groupURL, &out)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("okta: error querying for groups: %w", err)
|
||||
}
|
||||
|
||||
for _, el := range out {
|
||||
lu, _ := time.Parse(el.LastUpdated, filterDateFormat)
|
||||
lmu, _ := time.Parse(el.LastMembershipUpdated, filterDateFormat)
|
||||
if lu.After(*p.lastUpdated) {
|
||||
p.lastUpdated = &lu
|
||||
}
|
||||
if lmu.After(*p.lastUpdated) {
|
||||
p.lastUpdated = &lmu
|
||||
}
|
||||
p.groups[el.ID] = &directory.Group{
|
||||
Id: el.ID,
|
||||
Name: el.Profile.Name,
|
||||
}
|
||||
}
|
||||
groupURL = getNextLink(hdrs)
|
||||
}
|
||||
|
||||
groups := make([]*directory.Group, 0, len(p.groups))
|
||||
for _, dg := range p.groups {
|
||||
groups = append(groups, dg)
|
||||
}
|
||||
return groups, nil
|
||||
}
|
||||
|
||||
func (p *Provider) getGroupMembers(ctx context.Context, groupID string) (users []apiUserObject, err error) {
|
||||
usersURL := p.cfg.providerURL.ResolveReference(&url.URL{
|
||||
Path: fmt.Sprintf("/api/v1/groups/%s/users", groupID),
|
||||
RawQuery: fmt.Sprintf("limit=%d", p.cfg.batchSize),
|
||||
}).String()
|
||||
for usersURL != "" {
|
||||
var out []apiUserObject
|
||||
hdrs, err := p.apiGet(ctx, usersURL, &out)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("okta: error querying for groups: %w", err)
|
||||
}
|
||||
|
||||
users = append(users, out...)
|
||||
usersURL = getNextLink(hdrs)
|
||||
}
|
||||
|
||||
return users, nil
|
||||
}
|
||||
|
||||
func (p *Provider) getUser(ctx context.Context, userID string) (*apiUserObject, error) {
|
||||
apiURL := p.cfg.providerURL.ResolveReference(&url.URL{
|
||||
Path: fmt.Sprintf("/api/v1/users/%s", userID),
|
||||
}).String()
|
||||
|
||||
var out apiUserObject
|
||||
_, err := p.apiGet(ctx, apiURL, &out)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("okta: error querying for user: %w", err)
|
||||
}
|
||||
|
||||
return &out, nil
|
||||
}
|
||||
|
||||
func (p *Provider) listUserGroups(ctx context.Context, userID string) (groups []apiGroupObject, err error) {
|
||||
apiURL := p.cfg.providerURL.ResolveReference(&url.URL{
|
||||
Path: fmt.Sprintf("/api/v1/users/%s/groups", userID),
|
||||
}).String()
|
||||
for apiURL != "" {
|
||||
var out []apiGroupObject
|
||||
hdrs, err := p.apiGet(ctx, apiURL, &out)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("okta: error querying for user groups: %w", err)
|
||||
}
|
||||
groups = append(groups, out...)
|
||||
apiURL = getNextLink(hdrs)
|
||||
}
|
||||
return groups, nil
|
||||
}
|
||||
|
||||
func (p *Provider) apiGet(ctx context.Context, uri string, out interface{}) (http.Header, error) {
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", uri, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("okta: failed to create HTTP request: %w", err)
|
||||
}
|
||||
req.Header.Set("Accept", "application/json")
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", "SSWS "+p.cfg.serviceAccount.APIKey)
|
||||
|
||||
for {
|
||||
res, err := p.cfg.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer res.Body.Close()
|
||||
|
||||
if res.StatusCode == http.StatusTooManyRequests {
|
||||
limitReset, err := strconv.ParseInt(res.Header.Get("X-Rate-Limit-Reset"), 10, 64)
|
||||
if err == nil {
|
||||
time.Sleep(time.Until(time.Unix(limitReset, 0)))
|
||||
}
|
||||
continue
|
||||
}
|
||||
if res.StatusCode/100 != httpSuccessClass {
|
||||
return nil, newAPIError(res)
|
||||
}
|
||||
if err := json.NewDecoder(res.Body).Decode(out); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return res.Header, nil
|
||||
}
|
||||
}
|
||||
|
||||
func getNextLink(hdrs http.Header) string {
|
||||
for _, link := range linkheader.ParseMultiple(hdrs.Values("Link")) {
|
||||
if link.Rel == "next" {
|
||||
return link.URL
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// A ServiceAccount is used by the Okta provider to query the API.
|
||||
type ServiceAccount struct {
|
||||
APIKey string `json:"api_key"`
|
||||
}
|
||||
|
||||
// ParseServiceAccount parses the service account in the config options.
|
||||
func ParseServiceAccount(rawServiceAccount string) (*ServiceAccount, error) {
|
||||
var serviceAccount ServiceAccount
|
||||
err := encoding.DecodeBase64OrJSON(rawServiceAccount, &serviceAccount)
|
||||
if err != nil {
|
||||
bs, err := base64.StdEncoding.DecodeString(rawServiceAccount)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
serviceAccount.APIKey = string(bs)
|
||||
}
|
||||
|
||||
if serviceAccount.APIKey == "" {
|
||||
return nil, ErrAPIKeyRequired
|
||||
}
|
||||
|
||||
return &serviceAccount, nil
|
||||
}
|
||||
|
||||
// An APIError is an error from the okta API.
|
||||
type APIError struct {
|
||||
HTTPStatusCode int
|
||||
Body string
|
||||
ErrorCode string `json:"errorCode"`
|
||||
ErrorSummary string `json:"errorSummary"`
|
||||
ErrorLink string `json:"errorLink"`
|
||||
ErrorID string `json:"errorId"`
|
||||
ErrorCauses []string `json:"errorCauses"`
|
||||
}
|
||||
|
||||
func newAPIError(res *http.Response) error {
|
||||
if res == nil {
|
||||
return nil
|
||||
}
|
||||
buf, _ := io.ReadAll(io.LimitReader(res.Body, readLimit)) // limit to 100kb
|
||||
|
||||
err := &APIError{
|
||||
HTTPStatusCode: res.StatusCode,
|
||||
Body: string(buf),
|
||||
}
|
||||
_ = json.Unmarshal(buf, err)
|
||||
return err
|
||||
}
|
||||
|
||||
func (err *APIError) Error() string {
|
||||
return fmt.Sprintf("okta: error querying API, status_code=%d: %s", err.HTTPStatusCode, err.Body)
|
||||
}
|
||||
|
||||
type (
|
||||
apiGroupObject struct {
|
||||
ID string `json:"id"`
|
||||
Profile struct {
|
||||
Name string `json:"name"`
|
||||
} `json:"profile"`
|
||||
LastUpdated string `json:"lastUpdated"`
|
||||
LastMembershipUpdated string `json:"lastMembershipUpdated"`
|
||||
}
|
||||
apiUserObject struct {
|
||||
ID string `json:"id"`
|
||||
Profile struct {
|
||||
FirstName string `json:"firstName"`
|
||||
LastName string `json:"lastName"`
|
||||
Email string `json:"email"`
|
||||
} `json:"profile"`
|
||||
}
|
||||
)
|
||||
|
||||
func (obj *apiUserObject) getDisplayName() string {
|
||||
return obj.Profile.FirstName + " " + obj.Profile.LastName
|
||||
}
|
|
@ -1,361 +0,0 @@
|
|||
package okta
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"sort"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/go-chi/chi/v5/middleware"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/tomnomnom/linkheader"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/testutil"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/directory"
|
||||
)
|
||||
|
||||
type M = map[string]interface{}
|
||||
|
||||
func newMockOkta(srv *httptest.Server, userEmailToGroups map[string][]string) http.Handler {
|
||||
getAllGroups := func() map[string]struct{} {
|
||||
allGroups := map[string]struct{}{}
|
||||
for _, groups := range userEmailToGroups {
|
||||
for _, group := range groups {
|
||||
allGroups[group] = struct{}{}
|
||||
}
|
||||
}
|
||||
return allGroups
|
||||
}
|
||||
|
||||
r := chi.NewRouter()
|
||||
r.Use(middleware.Logger)
|
||||
r.Use(func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Header.Get("Authorization") != "SSWS APITOKEN" {
|
||||
http.Error(w, "forbidden", http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
})
|
||||
r.Route("/api/v1", func(r chi.Router) {
|
||||
r.Route("/groups", func(r chi.Router) {
|
||||
r.Get("/", func(w http.ResponseWriter, r *http.Request) {
|
||||
lastUpdated := strings.Contains(r.URL.Query().Get("filter"), "lastUpdated ")
|
||||
var groups []string
|
||||
for group := range getAllGroups() {
|
||||
if lastUpdated && group != "user-updated" {
|
||||
continue
|
||||
}
|
||||
if !lastUpdated && group == "user-updated" {
|
||||
continue
|
||||
}
|
||||
groups = append(groups, group)
|
||||
}
|
||||
sort.Strings(groups)
|
||||
|
||||
var result []M
|
||||
|
||||
found := r.URL.Query().Get("after") == ""
|
||||
for i := range groups {
|
||||
if found {
|
||||
result = append(result, M{
|
||||
"id": groups[i],
|
||||
"profile": M{
|
||||
"name": groups[i] + "-name",
|
||||
},
|
||||
})
|
||||
break
|
||||
}
|
||||
found = r.URL.Query().Get("after") == groups[i]
|
||||
}
|
||||
|
||||
if len(result) > 0 {
|
||||
nextURL := mustParseURL(srv.URL).ResolveReference(r.URL)
|
||||
q := nextURL.Query()
|
||||
q.Set("after", result[0]["id"].(string))
|
||||
nextURL.RawQuery = q.Encode()
|
||||
w.Header().Set("Link", linkheader.Link{
|
||||
URL: nextURL.String(),
|
||||
Rel: "next",
|
||||
}.String())
|
||||
}
|
||||
|
||||
_ = json.NewEncoder(w).Encode(result)
|
||||
})
|
||||
r.Get("/{group}/users", func(w http.ResponseWriter, r *http.Request) {
|
||||
group := chi.URLParam(r, "group")
|
||||
|
||||
if _, ok := getAllGroups()[group]; !ok {
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
w.Write([]byte(`{
|
||||
"errorCode": "E0000007",
|
||||
"errorSummary": "Not found: {0}",
|
||||
"errorLink": E0000007,
|
||||
"errorId": "sampleE7p0NECLNnSN5z_xLNT",
|
||||
"errorCauses": []
|
||||
}`))
|
||||
return
|
||||
}
|
||||
|
||||
var result []M
|
||||
for email, groups := range userEmailToGroups {
|
||||
for _, g := range groups {
|
||||
if group == g {
|
||||
result = append(result, M{
|
||||
"id": email,
|
||||
"profile": M{
|
||||
"email": email,
|
||||
"firstName": "first",
|
||||
"lastName": "last",
|
||||
},
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
sort.Slice(result, func(i, j int) bool {
|
||||
return result[i]["id"].(string) < result[j]["id"].(string)
|
||||
})
|
||||
|
||||
_ = json.NewEncoder(w).Encode(result)
|
||||
})
|
||||
})
|
||||
r.Route("/users", func(r chi.Router) {
|
||||
r.Get("/{user_id}/groups", func(w http.ResponseWriter, r *http.Request) {
|
||||
var groups []apiGroupObject
|
||||
for _, nm := range userEmailToGroups[chi.URLParam(r, "user_id")] {
|
||||
obj := apiGroupObject{
|
||||
ID: nm,
|
||||
}
|
||||
obj.Profile.Name = nm
|
||||
groups = append(groups, obj)
|
||||
}
|
||||
_ = json.NewEncoder(w).Encode(groups)
|
||||
})
|
||||
r.Get("/{user_id}", func(w http.ResponseWriter, r *http.Request) {
|
||||
user := apiUserObject{
|
||||
ID: chi.URLParam(r, "user_id"),
|
||||
}
|
||||
user.Profile.Email = chi.URLParam(r, "user_id")
|
||||
user.Profile.FirstName = "first"
|
||||
user.Profile.LastName = "last"
|
||||
_ = json.NewEncoder(w).Encode(user)
|
||||
})
|
||||
})
|
||||
})
|
||||
return r
|
||||
}
|
||||
|
||||
func TestProvider_User(t *testing.T) {
|
||||
var mockOkta http.Handler
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
mockOkta.ServeHTTP(w, r)
|
||||
}))
|
||||
defer srv.Close()
|
||||
mockOkta = newMockOkta(srv, map[string][]string{
|
||||
"a@example.com": {"user", "admin"},
|
||||
"b@example.com": {"user", "test"},
|
||||
"c@example.com": {"user"},
|
||||
})
|
||||
|
||||
p := New(
|
||||
WithServiceAccount(&ServiceAccount{APIKey: "APITOKEN"}),
|
||||
WithProviderURL(mustParseURL(srv.URL)),
|
||||
)
|
||||
user, err := p.User(context.Background(), "a@example.com", "")
|
||||
if !assert.NoError(t, err) {
|
||||
return
|
||||
}
|
||||
testutil.AssertProtoJSONEqual(t, `{
|
||||
"id": "a@example.com",
|
||||
"groupIds": ["admin","user"],
|
||||
"displayName": "first last",
|
||||
"email": "a@example.com"
|
||||
}`, user)
|
||||
}
|
||||
|
||||
func TestProvider_UserGroups(t *testing.T) {
|
||||
var mockOkta http.Handler
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
mockOkta.ServeHTTP(w, r)
|
||||
}))
|
||||
defer srv.Close()
|
||||
mockOkta = newMockOkta(srv, map[string][]string{
|
||||
"a@example.com": {"user", "admin"},
|
||||
"b@example.com": {"user", "test"},
|
||||
"c@example.com": {"user"},
|
||||
})
|
||||
|
||||
p := New(
|
||||
WithServiceAccount(&ServiceAccount{APIKey: "APITOKEN"}),
|
||||
WithProviderURL(mustParseURL(srv.URL)),
|
||||
)
|
||||
groups, users, err := p.UserGroups(context.Background())
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, []*directory.User{
|
||||
{
|
||||
Id: "a@example.com",
|
||||
GroupIds: []string{"admin", "user"},
|
||||
DisplayName: "first last",
|
||||
Email: "a@example.com",
|
||||
},
|
||||
{
|
||||
Id: "b@example.com",
|
||||
GroupIds: []string{"test", "user"},
|
||||
DisplayName: "first last",
|
||||
Email: "b@example.com",
|
||||
},
|
||||
{
|
||||
Id: "c@example.com",
|
||||
GroupIds: []string{"user"},
|
||||
DisplayName: "first last",
|
||||
Email: "c@example.com",
|
||||
},
|
||||
}, users)
|
||||
assert.Len(t, groups, 3)
|
||||
}
|
||||
|
||||
func TestProvider_UserGroupsQueryUpdated(t *testing.T) {
|
||||
var mockOkta http.Handler
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
mockOkta.ServeHTTP(w, r)
|
||||
}))
|
||||
defer srv.Close()
|
||||
userEmailToGroups := map[string][]string{
|
||||
"a@example.com": {"user", "admin"},
|
||||
"b@example.com": {"user", "test"},
|
||||
"c@example.com": {"user"},
|
||||
"updated@example.com": {"user-updated"},
|
||||
}
|
||||
mockOkta = newMockOkta(srv, userEmailToGroups)
|
||||
|
||||
p := New(
|
||||
WithServiceAccount(&ServiceAccount{APIKey: "APITOKEN"}),
|
||||
WithProviderURL(mustParseURL(srv.URL)),
|
||||
)
|
||||
groups, users, err := p.UserGroups(context.Background())
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, []*directory.User{
|
||||
{
|
||||
Id: "a@example.com",
|
||||
GroupIds: []string{"admin", "user"},
|
||||
DisplayName: "first last",
|
||||
Email: "a@example.com",
|
||||
},
|
||||
{
|
||||
Id: "b@example.com",
|
||||
GroupIds: []string{"test", "user"},
|
||||
DisplayName: "first last",
|
||||
Email: "b@example.com",
|
||||
},
|
||||
{
|
||||
Id: "c@example.com",
|
||||
GroupIds: []string{"user"},
|
||||
DisplayName: "first last",
|
||||
Email: "c@example.com",
|
||||
},
|
||||
}, users)
|
||||
assert.Len(t, groups, 3)
|
||||
|
||||
groups, users, err = p.UserGroups(context.Background())
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, []*directory.User{
|
||||
{
|
||||
Id: "a@example.com",
|
||||
GroupIds: []string{"admin", "user"},
|
||||
DisplayName: "first last",
|
||||
Email: "a@example.com",
|
||||
},
|
||||
{
|
||||
Id: "b@example.com",
|
||||
GroupIds: []string{"test", "user"},
|
||||
DisplayName: "first last",
|
||||
Email: "b@example.com",
|
||||
},
|
||||
{
|
||||
Id: "c@example.com",
|
||||
GroupIds: []string{"user"},
|
||||
DisplayName: "first last",
|
||||
Email: "c@example.com",
|
||||
},
|
||||
{
|
||||
Id: "updated@example.com",
|
||||
GroupIds: []string{"user-updated"},
|
||||
DisplayName: "first last",
|
||||
Email: "updated@example.com",
|
||||
},
|
||||
}, users)
|
||||
assert.Len(t, groups, 4)
|
||||
|
||||
userEmailToGroups["b@example.com"] = []string{"user"}
|
||||
|
||||
groups, users, err = p.UserGroups(context.Background())
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, []*directory.User{
|
||||
{
|
||||
Id: "a@example.com",
|
||||
GroupIds: []string{"admin", "user"},
|
||||
DisplayName: "first last",
|
||||
Email: "a@example.com",
|
||||
},
|
||||
{
|
||||
Id: "b@example.com",
|
||||
GroupIds: []string{"user"},
|
||||
DisplayName: "first last",
|
||||
Email: "b@example.com",
|
||||
},
|
||||
{
|
||||
Id: "c@example.com",
|
||||
GroupIds: []string{"user"},
|
||||
DisplayName: "first last",
|
||||
Email: "c@example.com",
|
||||
},
|
||||
{
|
||||
Id: "updated@example.com",
|
||||
GroupIds: []string{"user-updated"},
|
||||
DisplayName: "first last",
|
||||
Email: "updated@example.com",
|
||||
},
|
||||
}, users)
|
||||
assert.Len(t, groups, 3)
|
||||
}
|
||||
|
||||
func mustParseURL(rawurl string) *url.URL {
|
||||
u, err := url.Parse(rawurl)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return u
|
||||
}
|
||||
|
||||
func TestParseServiceAccount(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
rawServiceAccount string
|
||||
apiKey string
|
||||
wantErr bool
|
||||
}{
|
||||
{"json", `{"api_key": "foo"}`, "foo", false},
|
||||
{"base64 json", "ewogICAgImFwaV9rZXkiOiAiZm9vIgp9Cg==", "foo", false},
|
||||
{"base64 value", "Zm9v", "foo", false},
|
||||
{"empty", "", "", true},
|
||||
{"invalid", "Zm9v---", "", true},
|
||||
}
|
||||
for _, tc := range tests {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
got, err := ParseServiceAccount(tc.rawServiceAccount)
|
||||
require.True(t, (err != nil) == tc.wantErr)
|
||||
if tc.apiKey != "" {
|
||||
assert.Equal(t, tc.apiKey, got.APIKey)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
|
@ -1,360 +0,0 @@
|
|||
// Package onelogin contains the onelogin directory provider.
|
||||
package onelogin
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
"golang.org/x/oauth2"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/encoding"
|
||||
"github.com/pomerium/pomerium/internal/httputil"
|
||||
"github.com/pomerium/pomerium/internal/log"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/directory"
|
||||
)
|
||||
|
||||
// Name is the provider name.
|
||||
const Name = "onelogin"
|
||||
|
||||
type config struct {
|
||||
apiURL *url.URL
|
||||
batchSize int
|
||||
serviceAccount *ServiceAccount
|
||||
httpClient *http.Client
|
||||
}
|
||||
|
||||
// An Option updates the onelogin configuration.
|
||||
type Option func(*config)
|
||||
|
||||
// WithBatchSize sets the batch size option.
|
||||
func WithBatchSize(batchSize int) Option {
|
||||
return func(cfg *config) {
|
||||
cfg.batchSize = batchSize
|
||||
}
|
||||
}
|
||||
|
||||
// WithHTTPClient sets the http client option.
|
||||
func WithHTTPClient(httpClient *http.Client) Option {
|
||||
return func(cfg *config) {
|
||||
cfg.httpClient = httputil.NewLoggingClient(httpClient, "onelogin_idp_client",
|
||||
func(evt *zerolog.Event) *zerolog.Event {
|
||||
return evt.Str("provider", "onelogin")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// WithServiceAccount sets the service account in the config.
|
||||
func WithServiceAccount(serviceAccount *ServiceAccount) Option {
|
||||
return func(cfg *config) {
|
||||
cfg.serviceAccount = serviceAccount
|
||||
}
|
||||
}
|
||||
|
||||
// WithURL sets the api url in the config.
|
||||
func WithURL(apiURL *url.URL) Option {
|
||||
return func(cfg *config) {
|
||||
cfg.apiURL = apiURL
|
||||
}
|
||||
}
|
||||
|
||||
func getConfig(options ...Option) *config {
|
||||
cfg := new(config)
|
||||
WithBatchSize(20)(cfg)
|
||||
WithHTTPClient(http.DefaultClient)(cfg)
|
||||
WithURL(&url.URL{
|
||||
Scheme: "https",
|
||||
Host: "api.us.onelogin.com",
|
||||
})(cfg)
|
||||
for _, option := range options {
|
||||
option(cfg)
|
||||
}
|
||||
return cfg
|
||||
}
|
||||
|
||||
// The Provider retrieves users and groups from onelogin.
|
||||
type Provider struct {
|
||||
cfg *config
|
||||
|
||||
mu sync.RWMutex
|
||||
token *oauth2.Token
|
||||
}
|
||||
|
||||
// New creates a new Provider.
|
||||
func New(options ...Option) *Provider {
|
||||
cfg := getConfig(options...)
|
||||
return &Provider{
|
||||
cfg: cfg,
|
||||
}
|
||||
}
|
||||
|
||||
func withLog(ctx context.Context) context.Context {
|
||||
return log.WithContext(ctx, func(c zerolog.Context) zerolog.Context {
|
||||
return c.Str("service", "directory").Str("provider", "onelogin")
|
||||
})
|
||||
}
|
||||
|
||||
// User returns the user record for the given id.
|
||||
func (p *Provider) User(ctx context.Context, userID, accessToken string) (*directory.User, error) {
|
||||
if p.cfg.serviceAccount == nil {
|
||||
return nil, fmt.Errorf("onelogin: service account not defined")
|
||||
}
|
||||
du := &directory.User{
|
||||
Id: userID,
|
||||
}
|
||||
|
||||
ctx = withLog(ctx)
|
||||
|
||||
token, err := p.getToken(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
au, err := p.getUser(ctx, token.AccessToken, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
du.DisplayName = au.getDisplayName()
|
||||
du.Email = au.Email
|
||||
du.GroupIds = []string{strconv.Itoa(au.GroupID)}
|
||||
|
||||
return du, nil
|
||||
}
|
||||
|
||||
// UserGroups gets the directory user groups for onelogin.
|
||||
func (p *Provider) UserGroups(ctx context.Context) ([]*directory.Group, []*directory.User, error) {
|
||||
if p.cfg.serviceAccount == nil {
|
||||
return nil, nil, fmt.Errorf("onelogin: service account not defined")
|
||||
}
|
||||
|
||||
ctx = withLog(ctx)
|
||||
|
||||
log.Info(ctx).Msg("getting user groups")
|
||||
|
||||
token, err := p.getToken(ctx)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
groups, err := p.listGroups(ctx, token.AccessToken)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
apiUsers, err := p.listUsers(ctx, token.AccessToken)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
var users []*directory.User
|
||||
for _, u := range apiUsers {
|
||||
users = append(users, &directory.User{
|
||||
Id: strconv.Itoa(u.ID),
|
||||
GroupIds: []string{strconv.Itoa(u.GroupID)},
|
||||
DisplayName: u.FirstName + " " + u.LastName,
|
||||
Email: u.Email,
|
||||
})
|
||||
}
|
||||
|
||||
sort.Slice(users, func(i, j int) bool {
|
||||
return users[i].Id < users[j].Id
|
||||
})
|
||||
return groups, users, nil
|
||||
}
|
||||
|
||||
func (p *Provider) listGroups(ctx context.Context, accessToken string) ([]*directory.Group, error) {
|
||||
var groups []*directory.Group
|
||||
apiURL := p.cfg.apiURL.ResolveReference(&url.URL{
|
||||
Path: "/api/1/groups",
|
||||
RawQuery: fmt.Sprintf("limit=%d", p.cfg.batchSize),
|
||||
}).String()
|
||||
for apiURL != "" {
|
||||
var result []struct {
|
||||
ID int `json:"id"`
|
||||
Name string `json:"name"`
|
||||
}
|
||||
nextLink, err := p.apiGet(ctx, accessToken, apiURL, &result)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("onelogin: listing groups: %w", err)
|
||||
}
|
||||
|
||||
for _, r := range result {
|
||||
groups = append(groups, &directory.Group{
|
||||
Id: strconv.Itoa(r.ID),
|
||||
Name: r.Name,
|
||||
})
|
||||
}
|
||||
|
||||
apiURL = nextLink
|
||||
}
|
||||
return groups, nil
|
||||
}
|
||||
|
||||
func (p *Provider) getUser(ctx context.Context, accessToken string, userID string) (*apiUserObject, error) {
|
||||
apiURL := p.cfg.apiURL.ResolveReference(&url.URL{
|
||||
Path: fmt.Sprintf("/api/1/users/%s", userID),
|
||||
}).String()
|
||||
|
||||
var out []apiUserObject
|
||||
_, err := p.apiGet(ctx, accessToken, apiURL, &out)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("onelogin: error getting user: %w", err)
|
||||
}
|
||||
if len(out) == 0 {
|
||||
return nil, fmt.Errorf("onelogin: user not found")
|
||||
}
|
||||
|
||||
return &out[0], nil
|
||||
}
|
||||
|
||||
func (p *Provider) listUsers(ctx context.Context, accessToken string) ([]apiUserObject, error) {
|
||||
var users []apiUserObject
|
||||
|
||||
apiURL := p.cfg.apiURL.ResolveReference(&url.URL{
|
||||
Path: "/api/1/users",
|
||||
RawQuery: fmt.Sprintf("limit=%d", p.cfg.batchSize),
|
||||
}).String()
|
||||
for apiURL != "" {
|
||||
var result []apiUserObject
|
||||
nextLink, err := p.apiGet(ctx, accessToken, apiURL, &result)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("onelogin: listing users: %w", err)
|
||||
}
|
||||
|
||||
users = append(users, result...)
|
||||
apiURL = nextLink
|
||||
}
|
||||
|
||||
return users, nil
|
||||
}
|
||||
|
||||
func (p *Provider) apiGet(ctx context.Context, accessToken string, uri string, out interface{}) (nextLink string, err error) {
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", uri, nil)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
req.Header.Set("Authorization", fmt.Sprintf("bearer:%s", accessToken))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
res, err := p.cfg.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer res.Body.Close()
|
||||
|
||||
if res.StatusCode/100 != 2 {
|
||||
return "", fmt.Errorf("onelogin: error querying api: %s", res.Status)
|
||||
}
|
||||
|
||||
var result struct {
|
||||
Pagination struct {
|
||||
NextLink string `json:"next_link"`
|
||||
}
|
||||
Data json.RawMessage `json:"data"`
|
||||
}
|
||||
err = json.NewDecoder(res.Body).Decode(&result)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
log.Info(ctx).
|
||||
Str("url", uri).
|
||||
Interface("result", result).
|
||||
Msg("api request")
|
||||
|
||||
err = json.Unmarshal(result.Data, out)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return result.Pagination.NextLink, nil
|
||||
}
|
||||
|
||||
func (p *Provider) getToken(ctx context.Context) (*oauth2.Token, error) {
|
||||
p.mu.RLock()
|
||||
token := p.token
|
||||
p.mu.RUnlock()
|
||||
|
||||
if token != nil && token.Valid() {
|
||||
return token, nil
|
||||
}
|
||||
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
|
||||
token = p.token
|
||||
if token != nil && token.Valid() {
|
||||
return token, nil
|
||||
}
|
||||
|
||||
apiURL := p.cfg.apiURL.ResolveReference(&url.URL{
|
||||
Path: "/auth/oauth2/v2/token",
|
||||
})
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", apiURL.String(), strings.NewReader(`{ "grant_type": "client_credentials" }`))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Authorization", fmt.Sprintf("client_id:%s, client_secret:%s",
|
||||
p.cfg.serviceAccount.ClientID, p.cfg.serviceAccount.ClientSecret))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
res, err := p.cfg.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer res.Body.Close()
|
||||
|
||||
if res.StatusCode/100 != 2 {
|
||||
return nil, fmt.Errorf("onelogin: error querying oauth2 token: %s", res.Status)
|
||||
}
|
||||
err = json.NewDecoder(res.Body).Decode(&token)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
p.token = token
|
||||
|
||||
return p.token, nil
|
||||
}
|
||||
|
||||
// A ServiceAccount is used by the OneLogin provider to query the API.
|
||||
type ServiceAccount struct {
|
||||
ClientID string `json:"client_id"`
|
||||
ClientSecret string `json:"client_secret"`
|
||||
}
|
||||
|
||||
// ParseServiceAccount parses the service account in the config options.
|
||||
func ParseServiceAccount(rawServiceAccount string) (*ServiceAccount, error) {
|
||||
var serviceAccount ServiceAccount
|
||||
if err := encoding.DecodeBase64OrJSON(rawServiceAccount, &serviceAccount); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if serviceAccount.ClientID == "" {
|
||||
return nil, fmt.Errorf("client_id is required")
|
||||
}
|
||||
if serviceAccount.ClientSecret == "" {
|
||||
return nil, fmt.Errorf("client_secret is required")
|
||||
}
|
||||
|
||||
return &serviceAccount, nil
|
||||
}
|
||||
|
||||
type apiUserObject struct {
|
||||
ID int `json:"id"`
|
||||
GroupID int `json:"group_id"`
|
||||
Email string `json:"email"`
|
||||
FirstName string `json:"firstname"`
|
||||
LastName string `json:"lastname"`
|
||||
}
|
||||
|
||||
func (obj *apiUserObject) getDisplayName() string {
|
||||
return obj.FirstName + " " + obj.LastName
|
||||
}
|
|
@ -1,270 +0,0 @@
|
|||
package onelogin
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"sort"
|
||||
"strconv"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/go-chi/chi/v5/middleware"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/testutil"
|
||||
)
|
||||
|
||||
type M = map[string]interface{}
|
||||
|
||||
func newMockAPI(srv *httptest.Server, userIDToGroupName map[int]string) http.Handler {
|
||||
lookup := map[string]struct{}{}
|
||||
for _, group := range userIDToGroupName {
|
||||
lookup[group] = struct{}{}
|
||||
}
|
||||
var allGroups []string
|
||||
for groupName := range lookup {
|
||||
allGroups = append(allGroups, groupName)
|
||||
}
|
||||
sort.Strings(allGroups)
|
||||
|
||||
var allUserIDs []int
|
||||
for userID := range userIDToGroupName {
|
||||
allUserIDs = append(allUserIDs, userID)
|
||||
}
|
||||
sort.Ints(allUserIDs)
|
||||
|
||||
r := chi.NewRouter()
|
||||
r.Use(middleware.Logger)
|
||||
r.Post("/auth/oauth2/v2/token", func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Header.Get("Authorization") != "client_id:CLIENTID, client_secret:CLIENTSECRET" {
|
||||
http.Error(w, "forbidden", http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
|
||||
var request struct {
|
||||
GrantType string `json:"grant_type"`
|
||||
}
|
||||
_ = json.NewDecoder(r.Body).Decode(&request)
|
||||
if request.GrantType != "client_credentials" {
|
||||
http.Error(w, "invalid grant_type", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
_ = json.NewEncoder(w).Encode(M{
|
||||
"access_token": "ACCESSTOKEN",
|
||||
"created_at": time.Now().Format(time.RFC3339),
|
||||
"expires_in": 360000,
|
||||
"refresh_token": "REFRESHTOKEN",
|
||||
"token_type": "bearer",
|
||||
})
|
||||
})
|
||||
r.Route("/api/1", func(r chi.Router) {
|
||||
r.Use(func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Header.Get("Authorization") != "bearer:ACCESSTOKEN" {
|
||||
http.Error(w, "forbidden", http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
})
|
||||
r.Get("/groups", func(w http.ResponseWriter, r *http.Request) {
|
||||
var result struct {
|
||||
Pagination struct {
|
||||
NextLink string `json:"next_link"`
|
||||
} `json:"pagination"`
|
||||
Data []M `json:"data"`
|
||||
}
|
||||
|
||||
found := r.URL.Query().Get("after") == ""
|
||||
for i := range allGroups {
|
||||
if found {
|
||||
result.Data = append(result.Data, M{
|
||||
"id": i,
|
||||
"name": allGroups[i],
|
||||
})
|
||||
break
|
||||
}
|
||||
found = r.URL.Query().Get("after") == fmt.Sprint(i)
|
||||
}
|
||||
|
||||
if len(result.Data) > 0 {
|
||||
nextURL := mustParseURL(srv.URL).ResolveReference(r.URL)
|
||||
q := nextURL.Query()
|
||||
q.Set("after", fmt.Sprint(result.Data[0]["id"]))
|
||||
nextURL.RawQuery = q.Encode()
|
||||
result.Pagination.NextLink = nextURL.String()
|
||||
}
|
||||
|
||||
_ = json.NewEncoder(w).Encode(result)
|
||||
})
|
||||
r.Get("/users/{user_id}", func(w http.ResponseWriter, r *http.Request) {
|
||||
userIDToGroupID := map[int]int{}
|
||||
for userID, groupName := range userIDToGroupName {
|
||||
for id, n := range allGroups {
|
||||
if groupName == n {
|
||||
userIDToGroupID[userID] = id
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
userID, _ := strconv.Atoi(chi.URLParam(r, "user_id"))
|
||||
|
||||
_ = json.NewEncoder(w).Encode(M{
|
||||
"data": []M{{
|
||||
"id": userID,
|
||||
"email": userIDToGroupName[userID] + "@example.com",
|
||||
"group_id": userIDToGroupID[userID],
|
||||
"firstname": "User",
|
||||
"lastname": fmt.Sprint(userID),
|
||||
}},
|
||||
})
|
||||
})
|
||||
r.Get("/users", func(w http.ResponseWriter, r *http.Request) {
|
||||
userIDToGroupID := map[int]int{}
|
||||
for userID, groupName := range userIDToGroupName {
|
||||
for id, n := range allGroups {
|
||||
if groupName == n {
|
||||
userIDToGroupID[userID] = id
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var result []M
|
||||
for _, userID := range allUserIDs {
|
||||
result = append(result, M{
|
||||
"id": userID,
|
||||
"email": userIDToGroupName[userID] + "@example.com",
|
||||
"group_id": userIDToGroupID[userID],
|
||||
"firstname": "User",
|
||||
"lastname": fmt.Sprint(userID),
|
||||
})
|
||||
}
|
||||
_ = json.NewEncoder(w).Encode(M{
|
||||
"data": result,
|
||||
})
|
||||
})
|
||||
})
|
||||
return r
|
||||
}
|
||||
|
||||
func TestProvider_User(t *testing.T) {
|
||||
var mockAPI http.Handler
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
mockAPI.ServeHTTP(w, r)
|
||||
}))
|
||||
defer srv.Close()
|
||||
mockAPI = newMockAPI(srv, map[int]string{
|
||||
111: "admin",
|
||||
222: "test",
|
||||
333: "user",
|
||||
})
|
||||
|
||||
p := New(
|
||||
WithServiceAccount(&ServiceAccount{
|
||||
ClientID: "CLIENTID",
|
||||
ClientSecret: "CLIENTSECRET",
|
||||
}),
|
||||
WithURL(mustParseURL(srv.URL)),
|
||||
)
|
||||
user, err := p.User(context.Background(), "111", "ACCESSTOKEN")
|
||||
if !assert.NoError(t, err) {
|
||||
return
|
||||
}
|
||||
testutil.AssertProtoJSONEqual(t, `{
|
||||
"id": "111",
|
||||
"groupIds": ["0"],
|
||||
"displayName": "User 111",
|
||||
"email": "admin@example.com"
|
||||
}`, user)
|
||||
}
|
||||
|
||||
func TestProvider_UserGroups(t *testing.T) {
|
||||
var mockAPI http.Handler
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
mockAPI.ServeHTTP(w, r)
|
||||
}))
|
||||
defer srv.Close()
|
||||
mockAPI = newMockAPI(srv, map[int]string{
|
||||
111: "admin",
|
||||
222: "test",
|
||||
333: "user",
|
||||
})
|
||||
|
||||
p := New(
|
||||
WithServiceAccount(&ServiceAccount{
|
||||
ClientID: "CLIENTID",
|
||||
ClientSecret: "CLIENTSECRET",
|
||||
}),
|
||||
WithURL(mustParseURL(srv.URL)),
|
||||
)
|
||||
groups, users, err := p.UserGroups(context.Background())
|
||||
assert.NoError(t, err)
|
||||
testutil.AssertProtoJSONEqual(t, `[
|
||||
{ "id": "111", "groupIds": ["0"], "displayName": "User 111", "email": "admin@example.com" },
|
||||
{ "id": "222", "groupIds": ["1"], "displayName": "User 222", "email": "test@example.com" },
|
||||
{ "id": "333", "groupIds": ["2"], "displayName": "User 333", "email": "user@example.com" }
|
||||
]`, users)
|
||||
testutil.AssertProtoJSONEqual(t, `[
|
||||
{ "id": "0", "name": "admin" },
|
||||
{ "id": "1", "name": "test" },
|
||||
{ "id": "2", "name": "user" }
|
||||
]`, groups)
|
||||
}
|
||||
|
||||
func TestParseServiceAccount(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
rawServiceAccount string
|
||||
serviceAccount *ServiceAccount
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
"json",
|
||||
`{"client_id":"CLIENT_ID","client_secret":"CLIENT_SECRET"}`,
|
||||
&ServiceAccount{ClientID: "CLIENT_ID", ClientSecret: "CLIENT_SECRET"},
|
||||
false,
|
||||
},
|
||||
{
|
||||
"base64 json",
|
||||
`eyJjbGllbnRfaWQiOiJDTElFTlRfSUQiLCJjbGllbnRfc2VjcmV0IjoiQ0xJRU5UX1NFQ1JFVCJ9`,
|
||||
&ServiceAccount{ClientID: "CLIENT_ID", ClientSecret: "CLIENT_SECRET"},
|
||||
false,
|
||||
},
|
||||
{
|
||||
"empty",
|
||||
"",
|
||||
nil,
|
||||
true,
|
||||
},
|
||||
{
|
||||
"invalid",
|
||||
"Zm9v---",
|
||||
nil,
|
||||
true,
|
||||
},
|
||||
}
|
||||
for _, tc := range tests {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
got, err := ParseServiceAccount(tc.rawServiceAccount)
|
||||
require.True(t, (err != nil) == tc.wantErr)
|
||||
assert.Equal(t, tc.serviceAccount, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func mustParseURL(rawurl string) *url.URL {
|
||||
u, err := url.Parse(rawurl)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return u
|
||||
}
|
|
@ -1,174 +0,0 @@
|
|||
package ping
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
)
|
||||
|
||||
var errNotFound = errors.New("ping: user not found")
|
||||
|
||||
type (
|
||||
apiGroup struct {
|
||||
ID string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
}
|
||||
apiUser struct {
|
||||
ID string `json:"id"`
|
||||
Email string `json:"email"`
|
||||
Name apiUserName `json:"name"`
|
||||
Username string `json:"username"`
|
||||
MemberOfGroupIDs []string `json:"memberOfGroupIDs"`
|
||||
}
|
||||
apiUserName struct {
|
||||
Given string `json:"given"`
|
||||
Middle string `json:"middle"`
|
||||
Family string `json:"family"`
|
||||
}
|
||||
)
|
||||
|
||||
func (au apiUser) getDisplayName() string {
|
||||
var parts []string
|
||||
if au.Name.Given != "" {
|
||||
parts = append(parts, au.Name.Given)
|
||||
}
|
||||
if au.Name.Middle != "" {
|
||||
parts = append(parts, au.Name.Middle)
|
||||
}
|
||||
if au.Name.Family != "" {
|
||||
parts = append(parts, au.Name.Family)
|
||||
}
|
||||
if len(parts) == 0 {
|
||||
parts = append(parts, au.Username)
|
||||
}
|
||||
return strings.Join(parts, " ")
|
||||
}
|
||||
|
||||
func getAllGroups(ctx context.Context, client *http.Client, apiURL *url.URL, envID string) ([]apiGroup, error) {
|
||||
nextURL := apiURL.ResolveReference(&url.URL{
|
||||
Path: fmt.Sprintf("/v1/environments/%s/groups", url.PathEscape(envID)),
|
||||
}).String()
|
||||
|
||||
var apiGroups []apiGroup
|
||||
err := batchAPIRequest(ctx, client, nextURL, func(body []byte) error {
|
||||
var apiResponse struct {
|
||||
Embedded struct {
|
||||
Groups []apiGroup `json:"groups"`
|
||||
} `json:"_embedded"`
|
||||
}
|
||||
err := json.Unmarshal(body, &apiResponse)
|
||||
if err != nil {
|
||||
return fmt.Errorf("ping: error decoding API response: %w", err)
|
||||
}
|
||||
apiGroups = append(apiGroups, apiResponse.Embedded.Groups...)
|
||||
return nil
|
||||
})
|
||||
return apiGroups, err
|
||||
}
|
||||
|
||||
func getGroupUsers(ctx context.Context, client *http.Client, apiURL *url.URL, envID, groupID string) ([]apiUser, error) {
|
||||
nextURL := apiURL.ResolveReference(&url.URL{
|
||||
Path: fmt.Sprintf("/v1/environments/%s/users", url.PathEscape(envID)),
|
||||
RawQuery: (&url.Values{
|
||||
"filter": {fmt.Sprintf(`memberOfGroups[id eq "%s"]`, groupID)},
|
||||
}).Encode(),
|
||||
}).String()
|
||||
|
||||
var apiUsers []apiUser
|
||||
err := batchAPIRequest(ctx, client, nextURL, func(body []byte) error {
|
||||
var apiResponse struct {
|
||||
Embedded struct {
|
||||
Users []apiUser `json:"users"`
|
||||
} `json:"_embedded"`
|
||||
}
|
||||
err := json.Unmarshal(body, &apiResponse)
|
||||
if err != nil {
|
||||
return fmt.Errorf("ping: error decoding API response: %w", err)
|
||||
}
|
||||
apiUsers = append(apiUsers, apiResponse.Embedded.Users...)
|
||||
return nil
|
||||
})
|
||||
return apiUsers, err
|
||||
}
|
||||
|
||||
func getUser(ctx context.Context, client *http.Client, apiURL *url.URL, envID, userID string) (*apiUser, error) {
|
||||
nextURL := apiURL.ResolveReference(&url.URL{
|
||||
Path: fmt.Sprintf("/v1/environments/%s/users/%s", url.PathEscape(envID), url.PathEscape(userID)),
|
||||
RawQuery: (&url.Values{
|
||||
"include": {"memberOfGroupIDs"},
|
||||
}).Encode(),
|
||||
}).String()
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", nextURL, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("ping: error building API request: %w", err)
|
||||
}
|
||||
res, err := client.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("ping: error making API request: %w", err)
|
||||
}
|
||||
body, err := io.ReadAll(res.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("ping: error reading API response: %w", err)
|
||||
}
|
||||
_ = res.Body.Close()
|
||||
if res.StatusCode == http.StatusNotFound {
|
||||
return nil, errNotFound
|
||||
} else if res.StatusCode/100 != 2 {
|
||||
return nil, fmt.Errorf("ping: unexpected status code: %d", res.StatusCode)
|
||||
}
|
||||
|
||||
var u apiUser
|
||||
err = json.Unmarshal(body, &u)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("ping: error decoding API response: %w", err)
|
||||
}
|
||||
return &u, nil
|
||||
}
|
||||
|
||||
func batchAPIRequest(ctx context.Context, client *http.Client, nextURL string, callback func(body []byte) error) error {
|
||||
for nextURL != "" {
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", nextURL, nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("ping: error building API request: %w", err)
|
||||
}
|
||||
|
||||
res, err := client.Do(req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("ping: error making API request: %w", err)
|
||||
}
|
||||
bs, err := io.ReadAll(res.Body)
|
||||
if err != nil {
|
||||
return fmt.Errorf("ping: error reading API response: %w", err)
|
||||
}
|
||||
_ = res.Body.Close()
|
||||
if res.StatusCode/100 != 2 {
|
||||
return fmt.Errorf("ping: unexpected status code: %d", res.StatusCode)
|
||||
}
|
||||
|
||||
var apiResponse struct {
|
||||
Links struct {
|
||||
Next struct {
|
||||
HREF string `json:"href"`
|
||||
} `json:"next"`
|
||||
} `json:"_links"`
|
||||
}
|
||||
err = json.Unmarshal(bs, &apiResponse)
|
||||
if err != nil {
|
||||
return fmt.Errorf("ping: error decoding API response: %w", err)
|
||||
}
|
||||
|
||||
err = callback(bs)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
nextURL = apiResponse.Links.Next.HREF
|
||||
}
|
||||
return nil
|
||||
}
|
|
@ -1,116 +0,0 @@
|
|||
package ping
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/encoding"
|
||||
"github.com/pomerium/pomerium/internal/httputil"
|
||||
)
|
||||
|
||||
type config struct {
|
||||
authURL *url.URL
|
||||
apiURL *url.URL
|
||||
serviceAccount *ServiceAccount
|
||||
httpClient *http.Client
|
||||
environmentID string
|
||||
}
|
||||
|
||||
// An Option updates the Ping configuration.
|
||||
type Option func(*config)
|
||||
|
||||
// WithAPIURL sets the api url in the config.
|
||||
func WithAPIURL(apiURL *url.URL) Option {
|
||||
return func(cfg *config) {
|
||||
cfg.apiURL = apiURL
|
||||
}
|
||||
}
|
||||
|
||||
// WithAuthURL sets the auth url in the config.
|
||||
func WithAuthURL(authURL *url.URL) Option {
|
||||
return func(cfg *config) {
|
||||
cfg.authURL = authURL
|
||||
}
|
||||
}
|
||||
|
||||
// WithEnvironmentID sets the environment ID in the config.
|
||||
func WithEnvironmentID(environmentID string) Option {
|
||||
return func(cfg *config) {
|
||||
cfg.environmentID = environmentID
|
||||
}
|
||||
}
|
||||
|
||||
// WithHTTPClient sets the http client option.
|
||||
func WithHTTPClient(httpClient *http.Client) Option {
|
||||
return func(cfg *config) {
|
||||
cfg.httpClient = httputil.NewLoggingClient(httpClient, "ping_idp_client",
|
||||
func(evt *zerolog.Event) *zerolog.Event {
|
||||
return evt.Str("provider", "ping")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// WithProviderURL sets the environment ID from the provider URL set in the config.
|
||||
func WithProviderURL(providerURL *url.URL) Option {
|
||||
// provider URL will be https://auth.pingone.com/{ENVIRONMENT_ID}/as
|
||||
if providerURL == nil {
|
||||
return func(cfg *config) {}
|
||||
}
|
||||
parts := strings.Split(providerURL.Path, "/")
|
||||
if len(parts) < 1 {
|
||||
return func(cfg *config) {}
|
||||
}
|
||||
return WithEnvironmentID(parts[1])
|
||||
}
|
||||
|
||||
// WithServiceAccount sets the service account in the config.
|
||||
func WithServiceAccount(serviceAccount *ServiceAccount) Option {
|
||||
return func(cfg *config) {
|
||||
cfg.serviceAccount = serviceAccount
|
||||
}
|
||||
}
|
||||
|
||||
func getConfig(options ...Option) *config {
|
||||
cfg := new(config)
|
||||
WithHTTPClient(http.DefaultClient)(cfg)
|
||||
WithAuthURL(&url.URL{
|
||||
Scheme: "https",
|
||||
Host: "auth.pingone.com",
|
||||
})(cfg)
|
||||
WithAPIURL(&url.URL{
|
||||
Scheme: "https",
|
||||
Host: "api.pingone.com",
|
||||
})(cfg)
|
||||
for _, option := range options {
|
||||
option(cfg)
|
||||
}
|
||||
return cfg
|
||||
}
|
||||
|
||||
// A ServiceAccount is used by the Ping provider to query the API.
|
||||
type ServiceAccount struct {
|
||||
ClientID string `json:"client_id"`
|
||||
ClientSecret string `json:"client_secret"`
|
||||
EnvironmentID string `json:"environment_id"`
|
||||
}
|
||||
|
||||
// ParseServiceAccount parses the service account in the config options.
|
||||
func ParseServiceAccount(rawServiceAccount string) (*ServiceAccount, error) {
|
||||
var serviceAccount ServiceAccount
|
||||
if err := encoding.DecodeBase64OrJSON(rawServiceAccount, &serviceAccount); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if serviceAccount.ClientID == "" {
|
||||
return nil, fmt.Errorf("client_id is required")
|
||||
}
|
||||
if serviceAccount.ClientSecret == "" {
|
||||
return nil, fmt.Errorf("client_secret is required")
|
||||
}
|
||||
|
||||
return &serviceAccount, nil
|
||||
}
|
|
@ -1,51 +0,0 @@
|
|||
package ping
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestParseServiceAccount(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
rawServiceAccount string
|
||||
serviceAccount *ServiceAccount
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
"json",
|
||||
`{"client_id":"CLIENT_ID","client_secret":"CLIENT_SECRET","environment_id":"ENVIRONMENT_ID"}`,
|
||||
&ServiceAccount{ClientID: "CLIENT_ID", ClientSecret: "CLIENT_SECRET", EnvironmentID: "ENVIRONMENT_ID"},
|
||||
false,
|
||||
},
|
||||
{
|
||||
"base64 json",
|
||||
`eyJjbGllbnRfaWQiOiJDTElFTlRfSUQiLCJjbGllbnRfc2VjcmV0IjoiQ0xJRU5UX1NFQ1JFVCIsImVudmlyb25tZW50X2lkIjoiRU5WSVJPTk1FTlRfSUQifQ==`,
|
||||
&ServiceAccount{ClientID: "CLIENT_ID", ClientSecret: "CLIENT_SECRET", EnvironmentID: "ENVIRONMENT_ID"},
|
||||
false,
|
||||
},
|
||||
{
|
||||
"empty",
|
||||
"",
|
||||
nil,
|
||||
true,
|
||||
},
|
||||
{
|
||||
"invalid",
|
||||
"Zm9v---",
|
||||
nil,
|
||||
true,
|
||||
},
|
||||
}
|
||||
for _, tc := range tests {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
got, err := ParseServiceAccount(tc.rawServiceAccount)
|
||||
require.True(t, (err != nil) == tc.wantErr)
|
||||
assert.Equal(t, tc.serviceAccount, got)
|
||||
})
|
||||
}
|
||||
}
|
|
@ -1,167 +0,0 @@
|
|||
// Package ping implements a directory provider for Ping.
|
||||
package ping
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"sort"
|
||||
"sync"
|
||||
|
||||
"golang.org/x/oauth2"
|
||||
"golang.org/x/oauth2/clientcredentials"
|
||||
|
||||
"github.com/pomerium/pomerium/pkg/grpc/directory"
|
||||
)
|
||||
|
||||
// Name is the name of the Ping provider.
|
||||
const Name = "ping"
|
||||
|
||||
// Provider implements a directory provider using the Ping API.
|
||||
type Provider struct {
|
||||
cfg *config
|
||||
mu sync.RWMutex
|
||||
token *oauth2.Token
|
||||
}
|
||||
|
||||
// New creates a new Ping Provider.
|
||||
func New(options ...Option) *Provider {
|
||||
cfg := getConfig(options...)
|
||||
return &Provider{
|
||||
cfg: cfg,
|
||||
}
|
||||
}
|
||||
|
||||
// User returns a user's directory information.
|
||||
func (p *Provider) User(ctx context.Context, userID, accessToken string) (*directory.User, error) {
|
||||
client, err := p.getClient(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
au, err := getUser(ctx, client, p.cfg.apiURL, p.cfg.environmentID, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &directory.User{
|
||||
Id: au.ID,
|
||||
DisplayName: au.getDisplayName(),
|
||||
Email: au.Email,
|
||||
GroupIds: au.MemberOfGroupIDs,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// UserGroups returns all the users and groups in the directory.
|
||||
func (p *Provider) UserGroups(ctx context.Context) ([]*directory.Group, []*directory.User, error) {
|
||||
client, err := p.getClient(ctx)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
apiGroups, err := getAllGroups(ctx, client, p.cfg.apiURL, p.cfg.environmentID)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
directoryUserLookup := map[string]*directory.User{}
|
||||
directoryGroups := make([]*directory.Group, len(apiGroups))
|
||||
for i, ag := range apiGroups {
|
||||
dg := &directory.Group{
|
||||
Id: ag.ID,
|
||||
Name: ag.Name,
|
||||
}
|
||||
|
||||
apiUsers, err := getGroupUsers(ctx, client, p.cfg.apiURL, p.cfg.environmentID, ag.ID)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
for _, au := range apiUsers {
|
||||
du, ok := directoryUserLookup[au.ID]
|
||||
if !ok {
|
||||
du = &directory.User{
|
||||
Id: au.ID,
|
||||
DisplayName: au.getDisplayName(),
|
||||
Email: au.Email,
|
||||
}
|
||||
directoryUserLookup[au.ID] = du
|
||||
}
|
||||
du.GroupIds = append(du.GroupIds, ag.ID)
|
||||
}
|
||||
|
||||
directoryGroups[i] = dg
|
||||
}
|
||||
sort.Slice(directoryGroups, func(i, j int) bool {
|
||||
return directoryGroups[i].Id < directoryGroups[j].Id
|
||||
})
|
||||
|
||||
directoryUsers := make([]*directory.User, 0, len(directoryUserLookup))
|
||||
for _, du := range directoryUserLookup {
|
||||
directoryUsers = append(directoryUsers, du)
|
||||
}
|
||||
sort.Slice(directoryUsers, func(i, j int) bool {
|
||||
return directoryUsers[i].Id < directoryUsers[j].Id
|
||||
})
|
||||
|
||||
return directoryGroups, directoryUsers, nil
|
||||
}
|
||||
|
||||
func (p *Provider) getClient(ctx context.Context) (*http.Client, error) {
|
||||
token, err := p.getToken(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
client := new(http.Client)
|
||||
*client = *p.cfg.httpClient
|
||||
client.Transport = &oauth2.Transport{
|
||||
Source: oauth2.StaticTokenSource(token),
|
||||
Base: p.cfg.httpClient.Transport,
|
||||
}
|
||||
return client, nil
|
||||
}
|
||||
|
||||
func (p *Provider) getToken(ctx context.Context) (*oauth2.Token, error) {
|
||||
if p.cfg.serviceAccount == nil {
|
||||
return nil, fmt.Errorf("ping: service account is required")
|
||||
}
|
||||
environmentID := p.cfg.serviceAccount.EnvironmentID
|
||||
if environmentID == "" {
|
||||
environmentID = p.cfg.environmentID
|
||||
}
|
||||
if environmentID == "" {
|
||||
return nil, fmt.Errorf("ping: environment ID is required")
|
||||
}
|
||||
|
||||
p.mu.RLock()
|
||||
token := p.token
|
||||
p.mu.RUnlock()
|
||||
|
||||
if token != nil && token.Valid() {
|
||||
return token, nil
|
||||
}
|
||||
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
|
||||
token = p.token
|
||||
if token != nil && token.Valid() {
|
||||
return token, nil
|
||||
}
|
||||
|
||||
ocfg := &clientcredentials.Config{
|
||||
ClientID: p.cfg.serviceAccount.ClientID,
|
||||
ClientSecret: p.cfg.serviceAccount.ClientSecret,
|
||||
TokenURL: p.cfg.authURL.ResolveReference(&url.URL{
|
||||
Path: fmt.Sprintf("/%s/as/token", environmentID),
|
||||
}).String(),
|
||||
}
|
||||
var err error
|
||||
p.token, err = ocfg.Token(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return p.token, nil
|
||||
}
|
|
@ -1,230 +0,0 @@
|
|||
package ping
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"sort"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/go-chi/chi/v5/middleware"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/testutil"
|
||||
)
|
||||
|
||||
type M = map[string]interface{}
|
||||
|
||||
func newMockAPI(userIDToGroupIDs map[string][]string) http.Handler {
|
||||
lookup := map[string]struct{}{}
|
||||
for _, groups := range userIDToGroupIDs {
|
||||
for _, group := range groups {
|
||||
lookup[group] = struct{}{}
|
||||
}
|
||||
}
|
||||
var allGroups []string
|
||||
for groupID := range lookup {
|
||||
allGroups = append(allGroups, groupID)
|
||||
}
|
||||
sort.Strings(allGroups)
|
||||
|
||||
var allUserIDs []string
|
||||
for userID := range userIDToGroupIDs {
|
||||
allUserIDs = append(allUserIDs, userID)
|
||||
}
|
||||
sort.Strings(allUserIDs)
|
||||
|
||||
filterToUserIDs := map[string][]string{}
|
||||
for userID, groupIDs := range userIDToGroupIDs {
|
||||
for _, groupID := range groupIDs {
|
||||
filter := fmt.Sprintf(`memberOfGroups[id eq "%s"]`, groupID)
|
||||
filterToUserIDs[filter] = append(filterToUserIDs[filter], userID)
|
||||
}
|
||||
}
|
||||
|
||||
r := chi.NewRouter()
|
||||
r.Use(middleware.Logger)
|
||||
r.Post("/ENVIRONMENTID/as/token", func(w http.ResponseWriter, r *http.Request) {
|
||||
u, p, _ := r.BasicAuth()
|
||||
if u != "CLIENTID" || p != "CLIENTSECRET" {
|
||||
http.Error(w, "forbidden", http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
|
||||
grantType := r.FormValue("grant_type")
|
||||
if grantType != "client_credentials" {
|
||||
http.Error(w, "invalid grant_type", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_ = json.NewEncoder(w).Encode(M{
|
||||
"access_token": "ACCESSTOKEN",
|
||||
"created_at": time.Now().Format(time.RFC3339),
|
||||
"expires_in": 360000,
|
||||
"refresh_token": "REFRESHTOKEN",
|
||||
"token_type": "bearer",
|
||||
})
|
||||
})
|
||||
r.Route("/v1/environments/ENVIRONMENTID", func(r chi.Router) {
|
||||
r.Get("/groups", func(w http.ResponseWriter, r *http.Request) {
|
||||
var apiGroups []apiGroup
|
||||
for _, id := range allGroups {
|
||||
apiGroups = append(apiGroups, apiGroup{
|
||||
ID: id,
|
||||
Name: "Group " + id,
|
||||
})
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_ = json.NewEncoder(w).Encode(M{
|
||||
"_embedded": M{
|
||||
"groups": apiGroups,
|
||||
},
|
||||
})
|
||||
})
|
||||
r.Route("/users", func(r chi.Router) {
|
||||
r.Get("/{user_id}", func(w http.ResponseWriter, r *http.Request) {
|
||||
userID := chi.URLParam(r, "user_id")
|
||||
groupIDs, ok := userIDToGroupIDs[userID]
|
||||
if !ok {
|
||||
http.NotFound(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
au := apiUser{
|
||||
ID: userID,
|
||||
Email: userID + "@example.com",
|
||||
Name: apiUserName{
|
||||
Given: "Given-" + userID,
|
||||
Middle: "Middle-" + userID,
|
||||
Family: "Family-" + userID,
|
||||
},
|
||||
}
|
||||
if r.URL.Query().Get("include") == "memberOfGroupIDs" {
|
||||
au.MemberOfGroupIDs = groupIDs
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_ = json.NewEncoder(w).Encode(au)
|
||||
})
|
||||
r.Get("/", func(w http.ResponseWriter, r *http.Request) {
|
||||
filter := r.URL.Query().Get("filter")
|
||||
userIDs, ok := filterToUserIDs[filter]
|
||||
if !ok {
|
||||
http.Error(w, "expected filter", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
var apiUsers []apiUser
|
||||
for _, id := range userIDs {
|
||||
apiUsers = append(apiUsers, apiUser{
|
||||
ID: id,
|
||||
Email: id + "@example.com",
|
||||
Name: apiUserName{
|
||||
Given: "Given-" + id,
|
||||
Middle: "Middle-" + id,
|
||||
Family: "Family-" + id,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_ = json.NewEncoder(w).Encode(M{
|
||||
"_embedded": M{
|
||||
"users": apiUsers,
|
||||
},
|
||||
})
|
||||
})
|
||||
})
|
||||
})
|
||||
return r
|
||||
}
|
||||
|
||||
func TestProvider_User(t *testing.T) {
|
||||
ctx, clearTimeout := context.WithTimeout(context.Background(), time.Second*10)
|
||||
defer clearTimeout()
|
||||
|
||||
srv := httptest.NewServer(newMockAPI(map[string][]string{
|
||||
"user1": {"group1", "group2"},
|
||||
"user2": {"group1", "group3"},
|
||||
"user3": {"group3"},
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
u, err := url.Parse(srv.URL)
|
||||
require.NoError(t, err)
|
||||
|
||||
p := New(
|
||||
WithAPIURL(u),
|
||||
WithAuthURL(u),
|
||||
WithEnvironmentID("ENVIRONMENTID"),
|
||||
WithServiceAccount(&ServiceAccount{
|
||||
ClientID: "CLIENTID",
|
||||
ClientSecret: "CLIENTSECRET",
|
||||
}))
|
||||
du, err := p.User(ctx, "user1", "")
|
||||
require.NoError(t, err)
|
||||
testutil.AssertProtoJSONEqual(t, `{
|
||||
"id": "user1",
|
||||
"email": "user1@example.com",
|
||||
"displayName": "Given-user1 Middle-user1 Family-user1",
|
||||
"groupIds": ["group1", "group2"]
|
||||
}`, du)
|
||||
}
|
||||
|
||||
func TestProvider_UserGroups(t *testing.T) {
|
||||
ctx, clearTimeout := context.WithTimeout(context.Background(), time.Second*10)
|
||||
defer clearTimeout()
|
||||
|
||||
srv := httptest.NewServer(newMockAPI(map[string][]string{
|
||||
"user1": {"group1", "group2"},
|
||||
"user2": {"group1", "group3"},
|
||||
"user3": {"group3"},
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
u, err := url.Parse(srv.URL)
|
||||
require.NoError(t, err)
|
||||
|
||||
p := New(
|
||||
WithAPIURL(u),
|
||||
WithAuthURL(u),
|
||||
WithEnvironmentID("ENVIRONMENTID"),
|
||||
WithServiceAccount(&ServiceAccount{
|
||||
ClientID: "CLIENTID",
|
||||
ClientSecret: "CLIENTSECRET",
|
||||
}))
|
||||
dgs, dus, err := p.UserGroups(ctx)
|
||||
require.NoError(t, err)
|
||||
testutil.AssertProtoJSONEqual(t, `[
|
||||
{ "id": "group1", "name": "Group group1" },
|
||||
{ "id": "group2", "name": "Group group2" },
|
||||
{ "id": "group3", "name": "Group group3" }
|
||||
]`, dgs)
|
||||
testutil.AssertProtoJSONEqual(t, `[
|
||||
{
|
||||
"id": "user1",
|
||||
"email": "user1@example.com",
|
||||
"displayName": "Given-user1 Middle-user1 Family-user1",
|
||||
"groupIds": ["group1", "group2"]
|
||||
},
|
||||
{
|
||||
"id": "user2",
|
||||
"email": "user2@example.com",
|
||||
"displayName": "Given-user2 Middle-user2 Family-user2",
|
||||
"groupIds": ["group1", "group3"]
|
||||
},
|
||||
{
|
||||
"id": "user3",
|
||||
"email": "user3@example.com",
|
||||
"displayName": "Given-user3 Middle-user3 Family-user3",
|
||||
"groupIds": ["group3"]
|
||||
}
|
||||
]`, dus)
|
||||
}
|
|
@ -1,197 +0,0 @@
|
|||
// Package directory implements the user group directory service.
|
||||
package directory
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"sync"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/directory/auth0"
|
||||
"github.com/pomerium/pomerium/internal/directory/azure"
|
||||
"github.com/pomerium/pomerium/internal/directory/github"
|
||||
"github.com/pomerium/pomerium/internal/directory/gitlab"
|
||||
"github.com/pomerium/pomerium/internal/directory/google"
|
||||
"github.com/pomerium/pomerium/internal/directory/okta"
|
||||
"github.com/pomerium/pomerium/internal/directory/onelogin"
|
||||
"github.com/pomerium/pomerium/internal/directory/ping"
|
||||
"github.com/pomerium/pomerium/internal/log"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/directory"
|
||||
)
|
||||
|
||||
// A Group is a directory Group.
|
||||
type Group = directory.Group
|
||||
|
||||
// A User is a directory User.
|
||||
type User = directory.User
|
||||
|
||||
// Options are the options specific to the provider.
|
||||
type Options = directory.Options
|
||||
|
||||
// RegisterDirectoryServiceServer registers the directory gRPC service.
|
||||
var RegisterDirectoryServiceServer = directory.RegisterDirectoryServiceServer
|
||||
|
||||
// A Provider provides user group directory information.
|
||||
type Provider interface {
|
||||
User(ctx context.Context, userID, accessToken string) (*User, error)
|
||||
UserGroups(ctx context.Context) ([]*Group, []*User, error)
|
||||
}
|
||||
|
||||
var globalProvider = struct {
|
||||
sync.Mutex
|
||||
provider Provider
|
||||
options Options
|
||||
}{}
|
||||
|
||||
// GetProvider gets the provider for the given options.
|
||||
func GetProvider(options Options) (provider Provider) {
|
||||
globalProvider.Lock()
|
||||
defer globalProvider.Unlock()
|
||||
|
||||
ctx := context.TODO()
|
||||
if globalProvider.provider != nil && cmp.Equal(globalProvider.options, options) {
|
||||
log.Debug(ctx).Str("provider", options.Provider).Msg("directory: no change detected, reusing existing directory provider")
|
||||
return globalProvider.provider
|
||||
}
|
||||
defer func() {
|
||||
globalProvider.provider = provider
|
||||
globalProvider.options = options
|
||||
}()
|
||||
|
||||
var providerURL *url.URL
|
||||
// url.Parse will succeed even if we pass an empty string
|
||||
if options.ProviderURL != "" {
|
||||
providerURL, _ = url.Parse(options.ProviderURL)
|
||||
}
|
||||
var errSyncDisabled error
|
||||
switch options.Provider {
|
||||
case auth0.Name:
|
||||
serviceAccount, err := auth0.ParseServiceAccount(options)
|
||||
if err == nil {
|
||||
return auth0.New(
|
||||
auth0.WithDomain(options.ProviderURL),
|
||||
auth0.WithServiceAccount(serviceAccount))
|
||||
}
|
||||
errSyncDisabled = fmt.Errorf("invalid auth0 service account: %w", err)
|
||||
log.Warn(ctx).
|
||||
Str("service", "directory").
|
||||
Str("provider", options.Provider).
|
||||
Err(err).
|
||||
Msg("invalid service account for auth0 directory provider")
|
||||
case azure.Name:
|
||||
serviceAccount, err := azure.ParseServiceAccount(options)
|
||||
if err == nil {
|
||||
return azure.New(azure.WithServiceAccount(serviceAccount))
|
||||
}
|
||||
errSyncDisabled = fmt.Errorf("invalid Azure service account: %w", err)
|
||||
log.Warn(ctx).
|
||||
Str("service", "directory").
|
||||
Str("provider", options.Provider).
|
||||
Err(err).
|
||||
Msg("invalid service account for azure directory provider")
|
||||
case github.Name:
|
||||
serviceAccount, err := github.ParseServiceAccount(options.ServiceAccount)
|
||||
if err == nil {
|
||||
return github.New(github.WithServiceAccount(serviceAccount))
|
||||
}
|
||||
errSyncDisabled = fmt.Errorf("invalid GitHub service account: %w", err)
|
||||
log.Warn(ctx).
|
||||
Str("service", "directory").
|
||||
Str("provider", options.Provider).
|
||||
Err(err).
|
||||
Msg("invalid service account for github directory provider")
|
||||
case gitlab.Name:
|
||||
serviceAccount, err := gitlab.ParseServiceAccount(options.ServiceAccount)
|
||||
if err == nil {
|
||||
if providerURL == nil {
|
||||
return gitlab.New(gitlab.WithServiceAccount(serviceAccount))
|
||||
}
|
||||
return gitlab.New(
|
||||
gitlab.WithURL(providerURL),
|
||||
gitlab.WithServiceAccount(serviceAccount))
|
||||
}
|
||||
errSyncDisabled = fmt.Errorf("invalid GitLab service account: %w", err)
|
||||
log.Warn(ctx).
|
||||
Str("service", "directory").
|
||||
Str("provider", options.Provider).
|
||||
Err(err).
|
||||
Msg("invalid service account for gitlab directory provider")
|
||||
case google.Name:
|
||||
serviceAccount, err := google.ParseServiceAccount(options.ServiceAccount)
|
||||
if err == nil {
|
||||
googleOptions := []google.Option{
|
||||
google.WithServiceAccount(serviceAccount),
|
||||
}
|
||||
if options.ProviderURL != "" {
|
||||
googleOptions = append(googleOptions, google.WithURL(options.ProviderURL))
|
||||
}
|
||||
return google.New(googleOptions...)
|
||||
}
|
||||
errSyncDisabled = fmt.Errorf("invalid google service account: %w", err)
|
||||
log.Warn(ctx).
|
||||
Str("service", "directory").
|
||||
Str("provider", options.Provider).
|
||||
Err(err).
|
||||
Msg("invalid service account for Google directory provider")
|
||||
case okta.Name:
|
||||
serviceAccount, err := okta.ParseServiceAccount(options.ServiceAccount)
|
||||
if err == nil {
|
||||
return okta.New(
|
||||
okta.WithProviderURL(providerURL),
|
||||
okta.WithServiceAccount(serviceAccount))
|
||||
}
|
||||
errSyncDisabled = fmt.Errorf("invalid Okta service account: %w", err)
|
||||
log.Warn(ctx).
|
||||
Str("service", "directory").
|
||||
Str("provider", options.Provider).
|
||||
Err(err).
|
||||
Msg("invalid service account for okta directory provider")
|
||||
case onelogin.Name:
|
||||
serviceAccount, err := onelogin.ParseServiceAccount(options.ServiceAccount)
|
||||
if err == nil {
|
||||
return onelogin.New(onelogin.WithServiceAccount(serviceAccount))
|
||||
}
|
||||
errSyncDisabled = fmt.Errorf("invalid OneLogin service account: %w", err)
|
||||
log.Warn(ctx).
|
||||
Str("service", "directory").
|
||||
Str("provider", options.Provider).
|
||||
Err(err).
|
||||
Msg("invalid service account for onelogin directory provider")
|
||||
case ping.Name:
|
||||
serviceAccount, err := ping.ParseServiceAccount(options.ServiceAccount)
|
||||
if err == nil {
|
||||
return ping.New(
|
||||
ping.WithProviderURL(providerURL),
|
||||
ping.WithServiceAccount(serviceAccount))
|
||||
}
|
||||
errSyncDisabled = fmt.Errorf("invalid Ping service account: %w", err)
|
||||
log.Warn(ctx).
|
||||
Str("service", "directory").
|
||||
Str("provider", options.Provider).
|
||||
Err(err).
|
||||
Msg("invalid service account for ping directory provider")
|
||||
case "":
|
||||
errSyncDisabled = fmt.Errorf("no directory provider configured")
|
||||
default:
|
||||
errSyncDisabled = fmt.Errorf("unknown directory provider %s", options.Provider)
|
||||
}
|
||||
|
||||
log.Warn(ctx).
|
||||
Str("provider", options.Provider).
|
||||
Msg(errSyncDisabled.Error())
|
||||
return nullProvider{errSyncDisabled}
|
||||
}
|
||||
|
||||
type nullProvider struct {
|
||||
error
|
||||
}
|
||||
|
||||
func (p nullProvider) User(ctx context.Context, userID, accessToken string) (*directory.User, error) {
|
||||
return nil, p.error
|
||||
}
|
||||
|
||||
func (p nullProvider) UserGroups(ctx context.Context) ([]*Group, []*User, error) {
|
||||
return nil, nil, p.error
|
||||
}
|
|
@ -3,24 +3,18 @@ package manager
|
|||
import (
|
||||
"time"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/directory"
|
||||
"github.com/pomerium/pomerium/internal/events"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
||||
)
|
||||
|
||||
var (
|
||||
defaultGroupRefreshInterval = 15 * time.Minute
|
||||
defaultGroupRefreshTimeout = 10 * time.Minute
|
||||
defaultSessionRefreshGracePeriod = 1 * time.Minute
|
||||
defaultSessionRefreshCoolOffDuration = 10 * time.Second
|
||||
)
|
||||
|
||||
type config struct {
|
||||
authenticator Authenticator
|
||||
directory directory.Provider
|
||||
dataBrokerClient databroker.DataBrokerServiceClient
|
||||
groupRefreshInterval time.Duration
|
||||
groupRefreshTimeout time.Duration
|
||||
sessionRefreshGracePeriod time.Duration
|
||||
sessionRefreshCoolOffDuration time.Duration
|
||||
now func() time.Time
|
||||
|
@ -29,8 +23,6 @@ type config struct {
|
|||
|
||||
func newConfig(options ...Option) *config {
|
||||
cfg := new(config)
|
||||
WithGroupRefreshInterval(defaultGroupRefreshInterval)(cfg)
|
||||
WithGroupRefreshTimeout(defaultGroupRefreshTimeout)(cfg)
|
||||
WithSessionRefreshGracePeriod(defaultSessionRefreshGracePeriod)(cfg)
|
||||
WithSessionRefreshCoolOffDuration(defaultSessionRefreshCoolOffDuration)(cfg)
|
||||
WithNow(time.Now)(cfg)
|
||||
|
@ -50,13 +42,6 @@ func WithAuthenticator(authenticator Authenticator) Option {
|
|||
}
|
||||
}
|
||||
|
||||
// WithDirectoryProvider sets the directory provider in the config.
|
||||
func WithDirectoryProvider(directoryProvider directory.Provider) Option {
|
||||
return func(cfg *config) {
|
||||
cfg.directory = directoryProvider
|
||||
}
|
||||
}
|
||||
|
||||
// WithDataBrokerClient sets the databroker client in the config.
|
||||
func WithDataBrokerClient(dataBrokerClient databroker.DataBrokerServiceClient) Option {
|
||||
return func(cfg *config) {
|
||||
|
@ -64,20 +49,6 @@ func WithDataBrokerClient(dataBrokerClient databroker.DataBrokerServiceClient) O
|
|||
}
|
||||
}
|
||||
|
||||
// WithGroupRefreshInterval sets the group refresh interval used by the manager.
|
||||
func WithGroupRefreshInterval(interval time.Duration) Option {
|
||||
return func(cfg *config) {
|
||||
cfg.groupRefreshInterval = interval
|
||||
}
|
||||
}
|
||||
|
||||
// WithGroupRefreshTimeout sets the group refresh timeout used by the manager.
|
||||
func WithGroupRefreshTimeout(timeout time.Duration) Option {
|
||||
return func(cfg *config) {
|
||||
cfg.groupRefreshTimeout = timeout
|
||||
}
|
||||
}
|
||||
|
||||
// WithSessionRefreshGracePeriod sets the session refresh grace period used by the manager.
|
||||
func WithSessionRefreshGracePeriod(dur time.Duration) Option {
|
||||
return func(cfg *config) {
|
||||
|
|
|
@ -12,16 +12,17 @@ import (
|
|||
"github.com/pomerium/pomerium/pkg/grpc/user"
|
||||
)
|
||||
|
||||
const userRefreshInterval = 10 * time.Minute
|
||||
|
||||
// A User is a user managed by the Manager.
|
||||
type User struct {
|
||||
*user.User
|
||||
lastRefresh time.Time
|
||||
refreshInterval time.Duration
|
||||
lastRefresh time.Time
|
||||
}
|
||||
|
||||
// NextRefresh returns the next time the user information needs to be refreshed.
|
||||
func (u User) NextRefresh() time.Time {
|
||||
return u.lastRefresh.Add(u.refreshInterval)
|
||||
return u.lastRefresh.Add(userRefreshInterval)
|
||||
}
|
||||
|
||||
// UnmarshalJSON unmarshals json data into the user object.
|
||||
|
|
|
@ -6,16 +6,13 @@ import (
|
|||
"errors"
|
||||
"time"
|
||||
|
||||
"github.com/cenkalti/backoff/v4"
|
||||
"github.com/google/btree"
|
||||
"github.com/rs/zerolog"
|
||||
"golang.org/x/oauth2"
|
||||
"golang.org/x/sync/errgroup"
|
||||
"google.golang.org/protobuf/proto"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/atomicutil"
|
||||
"github.com/pomerium/pomerium/internal/directory"
|
||||
"github.com/pomerium/pomerium/internal/events"
|
||||
"github.com/pomerium/pomerium/internal/identity/identity"
|
||||
"github.com/pomerium/pomerium/internal/log"
|
||||
|
@ -26,7 +23,6 @@ import (
|
|||
"github.com/pomerium/pomerium/pkg/grpc/user"
|
||||
"github.com/pomerium/pomerium/pkg/grpcutil"
|
||||
metrics_ids "github.com/pomerium/pomerium/pkg/metrics"
|
||||
"github.com/pomerium/pomerium/pkg/protoutil"
|
||||
)
|
||||
|
||||
// Authenticator is an identity.Provider with only the methods needed by the manager.
|
||||
|
@ -49,13 +45,8 @@ type Manager struct {
|
|||
sessionScheduler *scheduler.Scheduler
|
||||
userScheduler *scheduler.Scheduler
|
||||
|
||||
sessions sessionCollection
|
||||
users userCollection
|
||||
directoryUsers map[string]*directory.User
|
||||
directoryGroups map[string]*directory.Group
|
||||
|
||||
directoryBackoff *backoff.ExponentialBackOff
|
||||
directoryNextRefresh time.Time
|
||||
sessions sessionCollection
|
||||
users userCollection
|
||||
}
|
||||
|
||||
// New creates a new identity manager.
|
||||
|
@ -68,8 +59,6 @@ func New(
|
|||
sessionScheduler: scheduler.New(),
|
||||
userScheduler: scheduler.New(),
|
||||
}
|
||||
mgr.directoryBackoff = backoff.NewExponentialBackOff()
|
||||
mgr.directoryBackoff.MaxElapsedTime = 0
|
||||
mgr.reset()
|
||||
mgr.UpdateConfig(options...)
|
||||
return mgr
|
||||
|
@ -131,8 +120,6 @@ func (mgr *Manager) refreshLoop(ctx context.Context, update <-chan updateRecords
|
|||
}
|
||||
|
||||
log.Info(ctx).
|
||||
Int("directory_groups", len(mgr.directoryGroups)).
|
||||
Int("directory_users", len(mgr.directoryUsers)).
|
||||
Int("sessions", mgr.sessions.Len()).
|
||||
Int("users", mgr.users.Len()).
|
||||
Msg("initial sync complete")
|
||||
|
@ -140,9 +127,6 @@ func (mgr *Manager) refreshLoop(ctx context.Context, update <-chan updateRecords
|
|||
// start refreshing
|
||||
maxWait := time.Minute * 10
|
||||
nextTime := time.Now().Add(maxWait)
|
||||
if mgr.directoryNextRefresh.Before(nextTime) {
|
||||
nextTime = mgr.directoryNextRefresh
|
||||
}
|
||||
|
||||
timer := time.NewTimer(time.Until(nextTime))
|
||||
defer timer.Stop()
|
||||
|
@ -161,14 +145,6 @@ func (mgr *Manager) refreshLoop(ctx context.Context, update <-chan updateRecords
|
|||
now := time.Now()
|
||||
nextTime = now.Add(maxWait)
|
||||
|
||||
// refresh groups
|
||||
if mgr.directoryNextRefresh.Before(now) {
|
||||
mgr.directoryNextRefresh = now.Add(mgr.refreshDirectoryUserGroups(ctx))
|
||||
}
|
||||
if mgr.directoryNextRefresh.Before(nextTime) {
|
||||
nextTime = mgr.directoryNextRefresh
|
||||
}
|
||||
|
||||
// refresh sessions
|
||||
for {
|
||||
tm, key := mgr.sessionScheduler.Next()
|
||||
|
@ -203,128 +179,6 @@ func (mgr *Manager) refreshLoop(ctx context.Context, update <-chan updateRecords
|
|||
}
|
||||
}
|
||||
|
||||
func (mgr *Manager) refreshDirectoryUserGroups(ctx context.Context) (nextRefreshDelay time.Duration) {
|
||||
log.Info(ctx).Msg("refreshing directory users")
|
||||
|
||||
ctx, clearTimeout := context.WithTimeout(ctx, mgr.cfg.Load().groupRefreshTimeout)
|
||||
defer clearTimeout()
|
||||
|
||||
directoryGroups, directoryUsers, err := mgr.cfg.Load().directory.UserGroups(ctx)
|
||||
metrics.RecordIdentityManagerUserGroupRefresh(ctx, err)
|
||||
mgr.recordLastError(metrics_ids.IdentityManagerLastUserGroupRefreshError, err)
|
||||
if err != nil {
|
||||
msg := "failed to refresh directory users and groups"
|
||||
if ctx.Err() != nil {
|
||||
msg += ". You may need to increase the identity provider directory timeout setting"
|
||||
msg += "(https://www.pomerium.com/docs/reference/identity-provider-refresh-directory-settings)"
|
||||
}
|
||||
log.Warn(ctx).Err(err).Msg(msg)
|
||||
|
||||
return minDuration(
|
||||
mgr.cfg.Load().groupRefreshInterval, // never wait more than the refresh interval
|
||||
mgr.directoryBackoff.NextBackOff(),
|
||||
)
|
||||
}
|
||||
mgr.directoryBackoff.Reset() // success so reset the backoff
|
||||
|
||||
mgr.mergeGroups(ctx, directoryGroups)
|
||||
mgr.mergeUsers(ctx, directoryUsers)
|
||||
|
||||
return mgr.cfg.Load().groupRefreshInterval
|
||||
}
|
||||
|
||||
func (mgr *Manager) mergeGroups(ctx context.Context, directoryGroups []*directory.Group) {
|
||||
lookup := map[string]*directory.Group{}
|
||||
for _, dg := range directoryGroups {
|
||||
lookup[dg.GetId()] = dg
|
||||
}
|
||||
|
||||
var records []*databroker.Record
|
||||
|
||||
for groupID, newDG := range lookup {
|
||||
curDG, ok := mgr.directoryGroups[groupID]
|
||||
if !ok || !proto.Equal(newDG, curDG) {
|
||||
id := newDG.GetId()
|
||||
any := protoutil.NewAny(newDG)
|
||||
records = append(records, &databroker.Record{
|
||||
Type: any.GetTypeUrl(),
|
||||
Id: id,
|
||||
Data: any,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
for groupID, curDG := range mgr.directoryGroups {
|
||||
_, ok := lookup[groupID]
|
||||
if !ok {
|
||||
id := curDG.GetId()
|
||||
any := protoutil.NewAny(curDG)
|
||||
records = append(records, &databroker.Record{
|
||||
Type: any.GetTypeUrl(),
|
||||
Id: id,
|
||||
Data: any,
|
||||
DeletedAt: timestamppb.New(mgr.cfg.Load().now()),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
for i, batch := range databroker.OptimumPutRequestsFromRecords(records) {
|
||||
_, err := mgr.cfg.Load().dataBrokerClient.Put(ctx, batch)
|
||||
if err != nil {
|
||||
log.Warn(ctx).Err(err).
|
||||
Int("batch", i).
|
||||
Int("record-count", len(batch.GetRecords())).
|
||||
Msg("manager: failed to update groups")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (mgr *Manager) mergeUsers(ctx context.Context, directoryUsers []*directory.User) {
|
||||
lookup := map[string]*directory.User{}
|
||||
for _, du := range directoryUsers {
|
||||
lookup[du.GetId()] = du
|
||||
}
|
||||
|
||||
var records []*databroker.Record
|
||||
|
||||
for userID, newDU := range lookup {
|
||||
curDU, ok := mgr.directoryUsers[userID]
|
||||
if !ok || !proto.Equal(newDU, curDU) {
|
||||
id := newDU.GetId()
|
||||
any := protoutil.NewAny(newDU)
|
||||
records = append(records, &databroker.Record{
|
||||
Type: any.GetTypeUrl(),
|
||||
Id: id,
|
||||
Data: any,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
for userID, curDU := range mgr.directoryUsers {
|
||||
_, ok := lookup[userID]
|
||||
if !ok {
|
||||
id := curDU.GetId()
|
||||
any := protoutil.NewAny(curDU)
|
||||
records = append(records, &databroker.Record{
|
||||
Type: any.GetTypeUrl(),
|
||||
Id: id,
|
||||
Data: any,
|
||||
DeletedAt: timestamppb.New(mgr.cfg.Load().now()),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
for i, batch := range databroker.OptimumPutRequestsFromRecords(records) {
|
||||
_, err := mgr.cfg.Load().dataBrokerClient.Put(ctx, batch)
|
||||
if err != nil {
|
||||
log.Warn(ctx).Err(err).
|
||||
Int("batch", i).
|
||||
Int("record-count", len(batch.GetRecords())).
|
||||
Msg("manager: failed to update users")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (mgr *Manager) refreshSession(ctx context.Context, userID, sessionID string) {
|
||||
log.Info(ctx).
|
||||
Str("user_id", userID).
|
||||
|
@ -464,22 +318,6 @@ func (mgr *Manager) refreshUser(ctx context.Context, userID string) {
|
|||
func (mgr *Manager) onUpdateRecords(ctx context.Context, msg updateRecordsMessage) {
|
||||
for _, record := range msg.records {
|
||||
switch record.GetType() {
|
||||
case grpcutil.GetTypeURL(new(directory.Group)):
|
||||
var pbDirectoryGroup directory.Group
|
||||
err := record.GetData().UnmarshalTo(&pbDirectoryGroup)
|
||||
if err != nil {
|
||||
log.Warn(ctx).Msgf("error unmarshaling directory group: %s", err)
|
||||
continue
|
||||
}
|
||||
mgr.onUpdateDirectoryGroup(ctx, &pbDirectoryGroup)
|
||||
case grpcutil.GetTypeURL(new(directory.User)):
|
||||
var pbDirectoryUser directory.User
|
||||
err := record.GetData().UnmarshalTo(&pbDirectoryUser)
|
||||
if err != nil {
|
||||
log.Warn(ctx).Msgf("error unmarshaling directory user: %s", err)
|
||||
continue
|
||||
}
|
||||
mgr.onUpdateDirectoryUser(ctx, &pbDirectoryUser)
|
||||
case grpcutil.GetTypeURL(new(session.Session)):
|
||||
var pbSession session.Session
|
||||
err := record.GetData().UnmarshalTo(&pbSession)
|
||||
|
@ -528,20 +366,11 @@ func (mgr *Manager) onUpdateUser(_ context.Context, record *databroker.Record, u
|
|||
|
||||
u, _ := mgr.users.Get(user.GetId())
|
||||
u.lastRefresh = mgr.cfg.Load().now()
|
||||
u.refreshInterval = mgr.cfg.Load().groupRefreshInterval
|
||||
u.User = user
|
||||
mgr.users.ReplaceOrInsert(u)
|
||||
mgr.userScheduler.Add(u.NextRefresh(), u.GetId())
|
||||
}
|
||||
|
||||
func (mgr *Manager) onUpdateDirectoryUser(_ context.Context, pbDirectoryUser *directory.User) {
|
||||
mgr.directoryUsers[pbDirectoryUser.GetId()] = pbDirectoryUser
|
||||
}
|
||||
|
||||
func (mgr *Manager) onUpdateDirectoryGroup(_ context.Context, pbDirectoryGroup *directory.Group) {
|
||||
mgr.directoryGroups[pbDirectoryGroup.GetId()] = pbDirectoryGroup
|
||||
}
|
||||
|
||||
func (mgr *Manager) deleteSession(ctx context.Context, pbSession *session.Session) {
|
||||
err := session.Delete(ctx, mgr.cfg.Load().dataBrokerClient, pbSession.GetId())
|
||||
if err != nil {
|
||||
|
@ -553,8 +382,6 @@ func (mgr *Manager) deleteSession(ctx context.Context, pbSession *session.Sessio
|
|||
|
||||
// reset resets all the manager datastructures to their initial state
|
||||
func (mgr *Manager) reset() {
|
||||
mgr.directoryGroups = make(map[string]*directory.Group)
|
||||
mgr.directoryUsers = make(map[string]*directory.User)
|
||||
mgr.sessions = sessionCollection{BTree: btree.New(8)}
|
||||
mgr.users = userCollection{BTree: btree.New(8)}
|
||||
}
|
||||
|
@ -586,13 +413,3 @@ func isTemporaryError(err error) bool {
|
|||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func minDuration(d1 time.Duration, ds ...time.Duration) time.Duration {
|
||||
min := d1
|
||||
for _, d := range ds {
|
||||
if d < min {
|
||||
min = d
|
||||
}
|
||||
}
|
||||
return min
|
||||
}
|
||||
|
|
|
@ -3,7 +3,6 @@ package manager
|
|||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
|
@ -13,7 +12,6 @@ import (
|
|||
"google.golang.org/protobuf/proto"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/directory"
|
||||
"github.com/pomerium/pomerium/internal/events"
|
||||
"github.com/pomerium/pomerium/internal/identity/identity"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
||||
|
@ -24,19 +22,6 @@ import (
|
|||
"github.com/pomerium/pomerium/pkg/protoutil"
|
||||
)
|
||||
|
||||
type mockProvider struct {
|
||||
user func(ctx context.Context, userID, accessToken string) (*directory.User, error)
|
||||
userGroups func(ctx context.Context) ([]*directory.Group, []*directory.User, error)
|
||||
}
|
||||
|
||||
func (mock mockProvider) User(ctx context.Context, userID, accessToken string) (*directory.User, error) {
|
||||
return mock.user(ctx, userID, accessToken)
|
||||
}
|
||||
|
||||
func (mock mockProvider) UserGroups(ctx context.Context) ([]*directory.Group, []*directory.User, error) {
|
||||
return mock.userGroups(ctx)
|
||||
}
|
||||
|
||||
type mockAuthenticator struct{}
|
||||
|
||||
func (mock mockAuthenticator) Refresh(_ context.Context, _ *oauth2.Token, _ identity.State) (*oauth2.Token, error) {
|
||||
|
@ -61,72 +46,25 @@ func TestManager_onUpdateRecords(t *testing.T) {
|
|||
|
||||
mgr := New(
|
||||
WithDataBrokerClient(mock_databroker.NewMockDataBrokerServiceClient(ctrl)),
|
||||
WithDirectoryProvider(mockProvider{}),
|
||||
WithGroupRefreshInterval(time.Hour),
|
||||
WithNow(func() time.Time {
|
||||
return now
|
||||
}),
|
||||
)
|
||||
mgr.directoryBackoff.RandomizationFactor = 0 // disable randomization for deterministic testing
|
||||
|
||||
mgr.onUpdateRecords(ctx, updateRecordsMessage{
|
||||
records: []*databroker.Record{
|
||||
mkRecord(&directory.Group{Id: "group1", Name: "group 1", Email: "group1@example.com"}),
|
||||
mkRecord(&directory.User{Id: "user1", DisplayName: "user 1", Email: "user1@example.com", GroupIds: []string{"group1s"}}),
|
||||
mkRecord(&session.Session{Id: "session1", UserId: "user1"}),
|
||||
mkRecord(&user.User{Id: "user1", Name: "user 1", Email: "user1@example.com"}),
|
||||
},
|
||||
})
|
||||
|
||||
assert.NotNil(t, mgr.directoryGroups["group1"])
|
||||
assert.NotNil(t, mgr.directoryUsers["user1"])
|
||||
if _, ok := mgr.sessions.Get("user1", "session1"); assert.True(t, ok) {
|
||||
|
||||
}
|
||||
if _, ok := mgr.users.Get("user1"); assert.True(t, ok) {
|
||||
tm, id := mgr.userScheduler.Next()
|
||||
assert.Equal(t, now.Add(time.Hour), tm)
|
||||
assert.Equal(t, now.Add(userRefreshInterval), tm)
|
||||
assert.Equal(t, "user1", id)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func TestManager_refreshDirectoryUserGroups(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
|
||||
ctx, clearTimeout := context.WithTimeout(context.Background(), time.Second*10)
|
||||
defer clearTimeout()
|
||||
|
||||
t.Run("backoff", func(t *testing.T) {
|
||||
cnt := 0
|
||||
client := mock_databroker.NewMockDataBrokerServiceClient(ctrl)
|
||||
client.EXPECT().Put(gomock.Any(), gomock.Any()).AnyTimes()
|
||||
mgr := New(
|
||||
WithDataBrokerClient(client),
|
||||
WithDirectoryProvider(mockProvider{
|
||||
userGroups: func(ctx context.Context) ([]*directory.Group, []*directory.User, error) {
|
||||
cnt++
|
||||
switch cnt {
|
||||
case 1:
|
||||
return nil, nil, fmt.Errorf("error 1")
|
||||
case 2:
|
||||
return nil, nil, fmt.Errorf("error 2")
|
||||
}
|
||||
return nil, nil, nil
|
||||
},
|
||||
}),
|
||||
WithGroupRefreshInterval(time.Hour),
|
||||
)
|
||||
mgr.directoryBackoff.RandomizationFactor = 0 // disable randomization for deterministic testing
|
||||
|
||||
dur1 := mgr.refreshDirectoryUserGroups(ctx)
|
||||
dur2 := mgr.refreshDirectoryUserGroups(ctx)
|
||||
dur3 := mgr.refreshDirectoryUserGroups(ctx)
|
||||
|
||||
assert.Greater(t, dur2, dur1)
|
||||
assert.Greater(t, dur3, dur2)
|
||||
assert.Equal(t, time.Hour, dur3)
|
||||
})
|
||||
}
|
||||
|
||||
func TestManager_reportErrors(t *testing.T) {
|
||||
|
@ -161,22 +99,10 @@ func TestManager_reportErrors(t *testing.T) {
|
|||
WithEventManager(evtMgr),
|
||||
WithDataBrokerClient(client),
|
||||
WithAuthenticator(mockAuthenticator{}),
|
||||
WithDirectoryProvider(mockProvider{
|
||||
user: func(ctx context.Context, userID, accessToken string) (*directory.User, error) {
|
||||
return nil, fmt.Errorf("user")
|
||||
},
|
||||
userGroups: func(ctx context.Context) ([]*directory.Group, []*directory.User, error) {
|
||||
return nil, nil, fmt.Errorf("user groups")
|
||||
},
|
||||
}),
|
||||
WithGroupRefreshInterval(time.Second),
|
||||
)
|
||||
mgr.directoryBackoff.RandomizationFactor = 0 // disable randomization for deterministic testing
|
||||
|
||||
mgr.onUpdateRecords(ctx, updateRecordsMessage{
|
||||
records: []*databroker.Record{
|
||||
mkRecord(&directory.Group{Id: "group1", Name: "group 1", Email: "group1@example.com"}),
|
||||
mkRecord(&directory.User{Id: "user1", DisplayName: "user 1", Email: "user1@example.com", GroupIds: []string{"group1s"}}),
|
||||
mkRecord(&session.Session{Id: "session1", UserId: "user1", OauthToken: &session.OAuthToken{
|
||||
ExpiresAt: timestamppb.New(time.Now().Add(time.Hour)),
|
||||
}, ExpiresAt: timestamppb.New(time.Now().Add(time.Hour))}),
|
||||
|
@ -184,9 +110,6 @@ func TestManager_reportErrors(t *testing.T) {
|
|||
},
|
||||
})
|
||||
|
||||
_ = mgr.refreshDirectoryUserGroups(ctx)
|
||||
expectMsg(metrics_ids.IdentityManagerLastUserGroupRefreshError, "user groups")
|
||||
|
||||
mgr.refreshUser(ctx, "user1")
|
||||
expectMsg(metrics_ids.IdentityManagerLastUserRefreshError, "update user info")
|
||||
|
||||
|
|
|
@ -26,10 +26,6 @@ type Options struct {
|
|||
// Scope specifies optional requested permissions.
|
||||
Scopes []string
|
||||
|
||||
// ServiceAccount can be set for those providers that require additional
|
||||
// credentials or tokens to do follow up API calls (e.g. Google)
|
||||
ServiceAccount string
|
||||
|
||||
// AuthCodeOptions specifies additional key value pairs query params to add
|
||||
// to the request flow signin url.
|
||||
AuthCodeOptions map[string]string
|
||||
|
|
|
@ -27,7 +27,6 @@ var defaultScopes = []string{oidc.ScopeOpenID, "profile", "email"}
|
|||
// requires we set this on a custom uri param. Also, ` prompt` must be set to `consent`to ensure
|
||||
// that our application always receives a refresh token (ask google). And finally, we default to
|
||||
// having the user select which Google account they'd like to use.
|
||||
//
|
||||
// For more details, please see google's documentation:
|
||||
//
|
||||
// https://developers.google.com/identity/protocols/oauth2/web-server#offline
|
||||
|
|
|
@ -13,7 +13,6 @@ cookie_secret: YYYYY
|
|||
idp_provider: "google"
|
||||
idp_client_id: XXXX
|
||||
idp_client_secret: YYYY
|
||||
idp_service_account: XXXXXX
|
||||
|
||||
routes:
|
||||
- from: https://yoursite.localhost.pomerium.io
|
||||
|
|
|
@ -9,6 +9,7 @@ import (
|
|||
)
|
||||
|
||||
func TestMerge(t *testing.T) {
|
||||
type key string
|
||||
t.Run("value", func(t *testing.T) {
|
||||
type contextKey string
|
||||
k1 := contextKey("key1")
|
||||
|
|
File diff suppressed because it is too large
Load diff
|
@ -55,7 +55,7 @@ message Route {
|
|||
RouteRedirect redirect = 34;
|
||||
|
||||
repeated string allowed_users = 4 [ deprecated = true ];
|
||||
repeated string allowed_groups = 5 [ deprecated = true ];
|
||||
// repeated string allowed_groups = 5 [ deprecated = true ];
|
||||
repeated string allowed_domains = 6 [ deprecated = true ];
|
||||
map<string, google.protobuf.ListValue> allowed_idp_claims = 32
|
||||
[ deprecated = true ];
|
||||
|
@ -121,7 +121,7 @@ message Policy {
|
|||
string id = 1;
|
||||
string name = 2;
|
||||
repeated string allowed_users = 3;
|
||||
repeated string allowed_groups = 4;
|
||||
// repeated string allowed_groups = 4;
|
||||
repeated string allowed_domains = 5;
|
||||
map<string, google.protobuf.ListValue> allowed_idp_claims = 7;
|
||||
repeated string rego = 6;
|
||||
|
@ -166,9 +166,9 @@ message Settings {
|
|||
optional string idp_provider = 24;
|
||||
optional string idp_provider_url = 25;
|
||||
repeated string scopes = 26;
|
||||
optional string idp_service_account = 27;
|
||||
optional google.protobuf.Duration idp_refresh_directory_timeout = 28;
|
||||
optional google.protobuf.Duration idp_refresh_directory_interval = 29;
|
||||
// optional string idp_service_account = 27;
|
||||
// optional google.protobuf.Duration idp_refresh_directory_timeout = 28;
|
||||
// optional google.protobuf.Duration idp_refresh_directory_interval = 29;
|
||||
map<string, string> request_params = 30;
|
||||
repeated string authorize_service_urls = 32;
|
||||
optional string authorize_internal_service_url = 83;
|
||||
|
|
|
@ -84,7 +84,7 @@ loop:
|
|||
case err == io.EOF:
|
||||
break loop
|
||||
case err != nil:
|
||||
return nil, 0, 0, err
|
||||
return nil, 0, 0, fmt.Errorf("error receiving record: %w", err)
|
||||
}
|
||||
|
||||
switch res := res.GetResponse().(type) {
|
||||
|
|
|
@ -2,6 +2,7 @@ package databroker
|
|||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
backoff "github.com/cenkalti/backoff/v4"
|
||||
|
@ -132,8 +133,7 @@ func (syncer *Syncer) init(ctx context.Context) error {
|
|||
Type: syncer.cfg.typeURL,
|
||||
})
|
||||
if err != nil {
|
||||
log.Error(ctx).Err(err).Msg("error during initial sync")
|
||||
return err
|
||||
return fmt.Errorf("error during initial sync: %w", err)
|
||||
}
|
||||
syncer.backoff.Reset()
|
||||
|
||||
|
@ -154,8 +154,7 @@ func (syncer *Syncer) sync(ctx context.Context) error {
|
|||
Type: syncer.cfg.typeURL,
|
||||
})
|
||||
if err != nil {
|
||||
log.Error(ctx).Err(err).Msg("error during sync")
|
||||
return err
|
||||
return fmt.Errorf("error calling sync: %w", err)
|
||||
}
|
||||
|
||||
log.Info(ctx).Msg("listening for updates")
|
||||
|
@ -168,7 +167,7 @@ func (syncer *Syncer) sync(ctx context.Context) error {
|
|||
syncer.serverVersion = 0
|
||||
return nil
|
||||
} else if err != nil {
|
||||
return err
|
||||
return fmt.Errorf("error receiving sync record: %w", err)
|
||||
}
|
||||
|
||||
rec := res.GetRecord()
|
||||
|
|
|
@ -1,30 +0,0 @@
|
|||
// Package directory contains protobuf types for directory users.
|
||||
package directory
|
||||
|
||||
import (
|
||||
context "context"
|
||||
|
||||
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
||||
)
|
||||
|
||||
// GetGroup gets a directory group from the databroker.
|
||||
func GetGroup(ctx context.Context, client databroker.DataBrokerServiceClient, groupID string) (*Group, error) {
|
||||
g := Group{Id: groupID}
|
||||
return &g, databroker.Get(ctx, client, &g)
|
||||
}
|
||||
|
||||
// GetUser gets a directory user from the databroker.
|
||||
func GetUser(ctx context.Context, client databroker.DataBrokerServiceClient, userID string) (*User, error) {
|
||||
u := User{Id: userID}
|
||||
return &u, databroker.Get(ctx, client, &u)
|
||||
}
|
||||
|
||||
// Options are directory provider options.
|
||||
type Options struct {
|
||||
ServiceAccount string
|
||||
Provider string
|
||||
ProviderURL string
|
||||
ClientID string
|
||||
ClientSecret string
|
||||
QPS float64
|
||||
}
|
|
@ -1,441 +0,0 @@
|
|||
// Code generated by protoc-gen-go. DO NOT EDIT.
|
||||
// versions:
|
||||
// protoc-gen-go v1.28.1
|
||||
// protoc v3.21.7
|
||||
// source: directory.proto
|
||||
|
||||
package directory
|
||||
|
||||
import (
|
||||
context "context"
|
||||
grpc "google.golang.org/grpc"
|
||||
codes "google.golang.org/grpc/codes"
|
||||
status "google.golang.org/grpc/status"
|
||||
protoreflect "google.golang.org/protobuf/reflect/protoreflect"
|
||||
protoimpl "google.golang.org/protobuf/runtime/protoimpl"
|
||||
emptypb "google.golang.org/protobuf/types/known/emptypb"
|
||||
reflect "reflect"
|
||||
sync "sync"
|
||||
)
|
||||
|
||||
const (
|
||||
// Verify that this generated code is sufficiently up-to-date.
|
||||
_ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion)
|
||||
// Verify that runtime/protoimpl is sufficiently up-to-date.
|
||||
_ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20)
|
||||
)
|
||||
|
||||
type User struct {
|
||||
state protoimpl.MessageState
|
||||
sizeCache protoimpl.SizeCache
|
||||
unknownFields protoimpl.UnknownFields
|
||||
|
||||
Version string `protobuf:"bytes,1,opt,name=version,proto3" json:"version,omitempty"`
|
||||
Id string `protobuf:"bytes,2,opt,name=id,proto3" json:"id,omitempty"`
|
||||
GroupIds []string `protobuf:"bytes,3,rep,name=group_ids,json=groupIds,proto3" json:"group_ids,omitempty"`
|
||||
DisplayName string `protobuf:"bytes,4,opt,name=display_name,json=displayName,proto3" json:"display_name,omitempty"`
|
||||
Email string `protobuf:"bytes,5,opt,name=email,proto3" json:"email,omitempty"`
|
||||
}
|
||||
|
||||
func (x *User) Reset() {
|
||||
*x = User{}
|
||||
if protoimpl.UnsafeEnabled {
|
||||
mi := &file_directory_proto_msgTypes[0]
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
}
|
||||
|
||||
func (x *User) String() string {
|
||||
return protoimpl.X.MessageStringOf(x)
|
||||
}
|
||||
|
||||
func (*User) ProtoMessage() {}
|
||||
|
||||
func (x *User) ProtoReflect() protoreflect.Message {
|
||||
mi := &file_directory_proto_msgTypes[0]
|
||||
if protoimpl.UnsafeEnabled && x != nil {
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
if ms.LoadMessageInfo() == nil {
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
return ms
|
||||
}
|
||||
return mi.MessageOf(x)
|
||||
}
|
||||
|
||||
// Deprecated: Use User.ProtoReflect.Descriptor instead.
|
||||
func (*User) Descriptor() ([]byte, []int) {
|
||||
return file_directory_proto_rawDescGZIP(), []int{0}
|
||||
}
|
||||
|
||||
func (x *User) GetVersion() string {
|
||||
if x != nil {
|
||||
return x.Version
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (x *User) GetId() string {
|
||||
if x != nil {
|
||||
return x.Id
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (x *User) GetGroupIds() []string {
|
||||
if x != nil {
|
||||
return x.GroupIds
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (x *User) GetDisplayName() string {
|
||||
if x != nil {
|
||||
return x.DisplayName
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (x *User) GetEmail() string {
|
||||
if x != nil {
|
||||
return x.Email
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
type Group struct {
|
||||
state protoimpl.MessageState
|
||||
sizeCache protoimpl.SizeCache
|
||||
unknownFields protoimpl.UnknownFields
|
||||
|
||||
Version string `protobuf:"bytes,1,opt,name=version,proto3" json:"version,omitempty"`
|
||||
Id string `protobuf:"bytes,2,opt,name=id,proto3" json:"id,omitempty"`
|
||||
Name string `protobuf:"bytes,3,opt,name=name,proto3" json:"name,omitempty"`
|
||||
Email string `protobuf:"bytes,4,opt,name=email,proto3" json:"email,omitempty"`
|
||||
}
|
||||
|
||||
func (x *Group) Reset() {
|
||||
*x = Group{}
|
||||
if protoimpl.UnsafeEnabled {
|
||||
mi := &file_directory_proto_msgTypes[1]
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
}
|
||||
|
||||
func (x *Group) String() string {
|
||||
return protoimpl.X.MessageStringOf(x)
|
||||
}
|
||||
|
||||
func (*Group) ProtoMessage() {}
|
||||
|
||||
func (x *Group) ProtoReflect() protoreflect.Message {
|
||||
mi := &file_directory_proto_msgTypes[1]
|
||||
if protoimpl.UnsafeEnabled && x != nil {
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
if ms.LoadMessageInfo() == nil {
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
return ms
|
||||
}
|
||||
return mi.MessageOf(x)
|
||||
}
|
||||
|
||||
// Deprecated: Use Group.ProtoReflect.Descriptor instead.
|
||||
func (*Group) Descriptor() ([]byte, []int) {
|
||||
return file_directory_proto_rawDescGZIP(), []int{1}
|
||||
}
|
||||
|
||||
func (x *Group) GetVersion() string {
|
||||
if x != nil {
|
||||
return x.Version
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (x *Group) GetId() string {
|
||||
if x != nil {
|
||||
return x.Id
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (x *Group) GetName() string {
|
||||
if x != nil {
|
||||
return x.Name
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (x *Group) GetEmail() string {
|
||||
if x != nil {
|
||||
return x.Email
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
type RefreshUserRequest struct {
|
||||
state protoimpl.MessageState
|
||||
sizeCache protoimpl.SizeCache
|
||||
unknownFields protoimpl.UnknownFields
|
||||
|
||||
UserId string `protobuf:"bytes,1,opt,name=user_id,json=userId,proto3" json:"user_id,omitempty"`
|
||||
AccessToken string `protobuf:"bytes,2,opt,name=access_token,json=accessToken,proto3" json:"access_token,omitempty"`
|
||||
}
|
||||
|
||||
func (x *RefreshUserRequest) Reset() {
|
||||
*x = RefreshUserRequest{}
|
||||
if protoimpl.UnsafeEnabled {
|
||||
mi := &file_directory_proto_msgTypes[2]
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
}
|
||||
|
||||
func (x *RefreshUserRequest) String() string {
|
||||
return protoimpl.X.MessageStringOf(x)
|
||||
}
|
||||
|
||||
func (*RefreshUserRequest) ProtoMessage() {}
|
||||
|
||||
func (x *RefreshUserRequest) ProtoReflect() protoreflect.Message {
|
||||
mi := &file_directory_proto_msgTypes[2]
|
||||
if protoimpl.UnsafeEnabled && x != nil {
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
if ms.LoadMessageInfo() == nil {
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
return ms
|
||||
}
|
||||
return mi.MessageOf(x)
|
||||
}
|
||||
|
||||
// Deprecated: Use RefreshUserRequest.ProtoReflect.Descriptor instead.
|
||||
func (*RefreshUserRequest) Descriptor() ([]byte, []int) {
|
||||
return file_directory_proto_rawDescGZIP(), []int{2}
|
||||
}
|
||||
|
||||
func (x *RefreshUserRequest) GetUserId() string {
|
||||
if x != nil {
|
||||
return x.UserId
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (x *RefreshUserRequest) GetAccessToken() string {
|
||||
if x != nil {
|
||||
return x.AccessToken
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
var File_directory_proto protoreflect.FileDescriptor
|
||||
|
||||
var file_directory_proto_rawDesc = []byte{
|
||||
0x0a, 0x0f, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, 0x6f, 0x72, 0x79, 0x2e, 0x70, 0x72, 0x6f, 0x74,
|
||||
0x6f, 0x12, 0x09, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, 0x6f, 0x72, 0x79, 0x1a, 0x1b, 0x67, 0x6f,
|
||||
0x6f, 0x67, 0x6c, 0x65, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2f, 0x65, 0x6d,
|
||||
0x70, 0x74, 0x79, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x22, 0x86, 0x01, 0x0a, 0x04, 0x55, 0x73,
|
||||
0x65, 0x72, 0x12, 0x18, 0x0a, 0x07, 0x76, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x18, 0x01, 0x20,
|
||||
0x01, 0x28, 0x09, 0x52, 0x07, 0x76, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x12, 0x0e, 0x0a, 0x02,
|
||||
0x69, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x02, 0x69, 0x64, 0x12, 0x1b, 0x0a, 0x09,
|
||||
0x67, 0x72, 0x6f, 0x75, 0x70, 0x5f, 0x69, 0x64, 0x73, 0x18, 0x03, 0x20, 0x03, 0x28, 0x09, 0x52,
|
||||
0x08, 0x67, 0x72, 0x6f, 0x75, 0x70, 0x49, 0x64, 0x73, 0x12, 0x21, 0x0a, 0x0c, 0x64, 0x69, 0x73,
|
||||
0x70, 0x6c, 0x61, 0x79, 0x5f, 0x6e, 0x61, 0x6d, 0x65, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52,
|
||||
0x0b, 0x64, 0x69, 0x73, 0x70, 0x6c, 0x61, 0x79, 0x4e, 0x61, 0x6d, 0x65, 0x12, 0x14, 0x0a, 0x05,
|
||||
0x65, 0x6d, 0x61, 0x69, 0x6c, 0x18, 0x05, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x65, 0x6d, 0x61,
|
||||
0x69, 0x6c, 0x22, 0x5b, 0x0a, 0x05, 0x47, 0x72, 0x6f, 0x75, 0x70, 0x12, 0x18, 0x0a, 0x07, 0x76,
|
||||
0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x76, 0x65,
|
||||
0x72, 0x73, 0x69, 0x6f, 0x6e, 0x12, 0x0e, 0x0a, 0x02, 0x69, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28,
|
||||
0x09, 0x52, 0x02, 0x69, 0x64, 0x12, 0x12, 0x0a, 0x04, 0x6e, 0x61, 0x6d, 0x65, 0x18, 0x03, 0x20,
|
||||
0x01, 0x28, 0x09, 0x52, 0x04, 0x6e, 0x61, 0x6d, 0x65, 0x12, 0x14, 0x0a, 0x05, 0x65, 0x6d, 0x61,
|
||||
0x69, 0x6c, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x65, 0x6d, 0x61, 0x69, 0x6c, 0x22,
|
||||
0x50, 0x0a, 0x12, 0x52, 0x65, 0x66, 0x72, 0x65, 0x73, 0x68, 0x55, 0x73, 0x65, 0x72, 0x52, 0x65,
|
||||
0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x17, 0x0a, 0x07, 0x75, 0x73, 0x65, 0x72, 0x5f, 0x69, 0x64,
|
||||
0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x75, 0x73, 0x65, 0x72, 0x49, 0x64, 0x12, 0x21,
|
||||
0x0a, 0x0c, 0x61, 0x63, 0x63, 0x65, 0x73, 0x73, 0x5f, 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x18, 0x02,
|
||||
0x20, 0x01, 0x28, 0x09, 0x52, 0x0b, 0x61, 0x63, 0x63, 0x65, 0x73, 0x73, 0x54, 0x6f, 0x6b, 0x65,
|
||||
0x6e, 0x32, 0x58, 0x0a, 0x10, 0x44, 0x69, 0x72, 0x65, 0x63, 0x74, 0x6f, 0x72, 0x79, 0x53, 0x65,
|
||||
0x72, 0x76, 0x69, 0x63, 0x65, 0x12, 0x44, 0x0a, 0x0b, 0x52, 0x65, 0x66, 0x72, 0x65, 0x73, 0x68,
|
||||
0x55, 0x73, 0x65, 0x72, 0x12, 0x1d, 0x2e, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, 0x6f, 0x72, 0x79,
|
||||
0x2e, 0x52, 0x65, 0x66, 0x72, 0x65, 0x73, 0x68, 0x55, 0x73, 0x65, 0x72, 0x52, 0x65, 0x71, 0x75,
|
||||
0x65, 0x73, 0x74, 0x1a, 0x16, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f,
|
||||
0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x42, 0x31, 0x5a, 0x2f, 0x67,
|
||||
0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x70, 0x6f, 0x6d, 0x65, 0x72, 0x69,
|
||||
0x75, 0x6d, 0x2f, 0x70, 0x6f, 0x6d, 0x65, 0x72, 0x69, 0x75, 0x6d, 0x2f, 0x70, 0x6b, 0x67, 0x2f,
|
||||
0x67, 0x72, 0x70, 0x63, 0x2f, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, 0x6f, 0x72, 0x79, 0x62, 0x06,
|
||||
0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33,
|
||||
}
|
||||
|
||||
var (
|
||||
file_directory_proto_rawDescOnce sync.Once
|
||||
file_directory_proto_rawDescData = file_directory_proto_rawDesc
|
||||
)
|
||||
|
||||
func file_directory_proto_rawDescGZIP() []byte {
|
||||
file_directory_proto_rawDescOnce.Do(func() {
|
||||
file_directory_proto_rawDescData = protoimpl.X.CompressGZIP(file_directory_proto_rawDescData)
|
||||
})
|
||||
return file_directory_proto_rawDescData
|
||||
}
|
||||
|
||||
var file_directory_proto_msgTypes = make([]protoimpl.MessageInfo, 3)
|
||||
var file_directory_proto_goTypes = []interface{}{
|
||||
(*User)(nil), // 0: directory.User
|
||||
(*Group)(nil), // 1: directory.Group
|
||||
(*RefreshUserRequest)(nil), // 2: directory.RefreshUserRequest
|
||||
(*emptypb.Empty)(nil), // 3: google.protobuf.Empty
|
||||
}
|
||||
var file_directory_proto_depIdxs = []int32{
|
||||
2, // 0: directory.DirectoryService.RefreshUser:input_type -> directory.RefreshUserRequest
|
||||
3, // 1: directory.DirectoryService.RefreshUser:output_type -> google.protobuf.Empty
|
||||
1, // [1:2] is the sub-list for method output_type
|
||||
0, // [0:1] is the sub-list for method input_type
|
||||
0, // [0:0] is the sub-list for extension type_name
|
||||
0, // [0:0] is the sub-list for extension extendee
|
||||
0, // [0:0] is the sub-list for field type_name
|
||||
}
|
||||
|
||||
func init() { file_directory_proto_init() }
|
||||
func file_directory_proto_init() {
|
||||
if File_directory_proto != nil {
|
||||
return
|
||||
}
|
||||
if !protoimpl.UnsafeEnabled {
|
||||
file_directory_proto_msgTypes[0].Exporter = func(v interface{}, i int) interface{} {
|
||||
switch v := v.(*User); i {
|
||||
case 0:
|
||||
return &v.state
|
||||
case 1:
|
||||
return &v.sizeCache
|
||||
case 2:
|
||||
return &v.unknownFields
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
file_directory_proto_msgTypes[1].Exporter = func(v interface{}, i int) interface{} {
|
||||
switch v := v.(*Group); i {
|
||||
case 0:
|
||||
return &v.state
|
||||
case 1:
|
||||
return &v.sizeCache
|
||||
case 2:
|
||||
return &v.unknownFields
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
file_directory_proto_msgTypes[2].Exporter = func(v interface{}, i int) interface{} {
|
||||
switch v := v.(*RefreshUserRequest); i {
|
||||
case 0:
|
||||
return &v.state
|
||||
case 1:
|
||||
return &v.sizeCache
|
||||
case 2:
|
||||
return &v.unknownFields
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
type x struct{}
|
||||
out := protoimpl.TypeBuilder{
|
||||
File: protoimpl.DescBuilder{
|
||||
GoPackagePath: reflect.TypeOf(x{}).PkgPath(),
|
||||
RawDescriptor: file_directory_proto_rawDesc,
|
||||
NumEnums: 0,
|
||||
NumMessages: 3,
|
||||
NumExtensions: 0,
|
||||
NumServices: 1,
|
||||
},
|
||||
GoTypes: file_directory_proto_goTypes,
|
||||
DependencyIndexes: file_directory_proto_depIdxs,
|
||||
MessageInfos: file_directory_proto_msgTypes,
|
||||
}.Build()
|
||||
File_directory_proto = out.File
|
||||
file_directory_proto_rawDesc = nil
|
||||
file_directory_proto_goTypes = nil
|
||||
file_directory_proto_depIdxs = nil
|
||||
}
|
||||
|
||||
// Reference imports to suppress errors if they are not otherwise used.
|
||||
var _ context.Context
|
||||
var _ grpc.ClientConnInterface
|
||||
|
||||
// This is a compile-time assertion to ensure that this generated file
|
||||
// is compatible with the grpc package it is being compiled against.
|
||||
const _ = grpc.SupportPackageIsVersion6
|
||||
|
||||
// DirectoryServiceClient is the client API for DirectoryService service.
|
||||
//
|
||||
// For semantics around ctx use and closing/ending streaming RPCs, please refer to https://godoc.org/google.golang.org/grpc#ClientConn.NewStream.
|
||||
type DirectoryServiceClient interface {
|
||||
RefreshUser(ctx context.Context, in *RefreshUserRequest, opts ...grpc.CallOption) (*emptypb.Empty, error)
|
||||
}
|
||||
|
||||
type directoryServiceClient struct {
|
||||
cc grpc.ClientConnInterface
|
||||
}
|
||||
|
||||
func NewDirectoryServiceClient(cc grpc.ClientConnInterface) DirectoryServiceClient {
|
||||
return &directoryServiceClient{cc}
|
||||
}
|
||||
|
||||
func (c *directoryServiceClient) RefreshUser(ctx context.Context, in *RefreshUserRequest, opts ...grpc.CallOption) (*emptypb.Empty, error) {
|
||||
out := new(emptypb.Empty)
|
||||
err := c.cc.Invoke(ctx, "/directory.DirectoryService/RefreshUser", in, out, opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// DirectoryServiceServer is the server API for DirectoryService service.
|
||||
type DirectoryServiceServer interface {
|
||||
RefreshUser(context.Context, *RefreshUserRequest) (*emptypb.Empty, error)
|
||||
}
|
||||
|
||||
// UnimplementedDirectoryServiceServer can be embedded to have forward compatible implementations.
|
||||
type UnimplementedDirectoryServiceServer struct {
|
||||
}
|
||||
|
||||
func (*UnimplementedDirectoryServiceServer) RefreshUser(context.Context, *RefreshUserRequest) (*emptypb.Empty, error) {
|
||||
return nil, status.Errorf(codes.Unimplemented, "method RefreshUser not implemented")
|
||||
}
|
||||
|
||||
func RegisterDirectoryServiceServer(s *grpc.Server, srv DirectoryServiceServer) {
|
||||
s.RegisterService(&_DirectoryService_serviceDesc, srv)
|
||||
}
|
||||
|
||||
func _DirectoryService_RefreshUser_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
|
||||
in := new(RefreshUserRequest)
|
||||
if err := dec(in); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if interceptor == nil {
|
||||
return srv.(DirectoryServiceServer).RefreshUser(ctx, in)
|
||||
}
|
||||
info := &grpc.UnaryServerInfo{
|
||||
Server: srv,
|
||||
FullMethod: "/directory.DirectoryService/RefreshUser",
|
||||
}
|
||||
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||
return srv.(DirectoryServiceServer).RefreshUser(ctx, req.(*RefreshUserRequest))
|
||||
}
|
||||
return interceptor(ctx, in, info, handler)
|
||||
}
|
||||
|
||||
var _DirectoryService_serviceDesc = grpc.ServiceDesc{
|
||||
ServiceName: "directory.DirectoryService",
|
||||
HandlerType: (*DirectoryServiceServer)(nil),
|
||||
Methods: []grpc.MethodDesc{
|
||||
{
|
||||
MethodName: "RefreshUser",
|
||||
Handler: _DirectoryService_RefreshUser_Handler,
|
||||
},
|
||||
},
|
||||
Streams: []grpc.StreamDesc{},
|
||||
Metadata: "directory.proto",
|
||||
}
|
|
@ -1,30 +0,0 @@
|
|||
syntax = "proto3";
|
||||
|
||||
package directory;
|
||||
option go_package = "github.com/pomerium/pomerium/pkg/grpc/directory";
|
||||
|
||||
import "google/protobuf/empty.proto";
|
||||
|
||||
message User {
|
||||
string version = 1;
|
||||
string id = 2;
|
||||
repeated string group_ids = 3;
|
||||
string display_name = 4;
|
||||
string email = 5;
|
||||
}
|
||||
|
||||
message Group {
|
||||
string version = 1;
|
||||
string id = 2;
|
||||
string name = 3;
|
||||
string email = 4;
|
||||
}
|
||||
|
||||
message RefreshUserRequest {
|
||||
string user_id = 1;
|
||||
string access_token = 2;
|
||||
}
|
||||
|
||||
service DirectoryService {
|
||||
rpc RefreshUser(RefreshUserRequest) returns (google.protobuf.Empty);
|
||||
}
|
|
@ -25,14 +25,14 @@ type Provider struct {
|
|||
sizeCache protoimpl.SizeCache
|
||||
unknownFields protoimpl.UnknownFields
|
||||
|
||||
Id string `protobuf:"bytes,1,opt,name=id,proto3" json:"id,omitempty"`
|
||||
ClientId string `protobuf:"bytes,2,opt,name=client_id,json=clientId,proto3" json:"client_id,omitempty"`
|
||||
ClientSecret string `protobuf:"bytes,3,opt,name=client_secret,json=clientSecret,proto3" json:"client_secret,omitempty"`
|
||||
Type string `protobuf:"bytes,4,opt,name=type,proto3" json:"type,omitempty"`
|
||||
Scopes []string `protobuf:"bytes,5,rep,name=scopes,proto3" json:"scopes,omitempty"`
|
||||
ServiceAccount string `protobuf:"bytes,6,opt,name=service_account,json=serviceAccount,proto3" json:"service_account,omitempty"`
|
||||
Url string `protobuf:"bytes,7,opt,name=url,proto3" json:"url,omitempty"`
|
||||
RequestParams map[string]string `protobuf:"bytes,8,rep,name=request_params,json=requestParams,proto3" json:"request_params,omitempty" protobuf_key:"bytes,1,opt,name=key,proto3" protobuf_val:"bytes,2,opt,name=value,proto3"`
|
||||
Id string `protobuf:"bytes,1,opt,name=id,proto3" json:"id,omitempty"`
|
||||
ClientId string `protobuf:"bytes,2,opt,name=client_id,json=clientId,proto3" json:"client_id,omitempty"`
|
||||
ClientSecret string `protobuf:"bytes,3,opt,name=client_secret,json=clientSecret,proto3" json:"client_secret,omitempty"`
|
||||
Type string `protobuf:"bytes,4,opt,name=type,proto3" json:"type,omitempty"`
|
||||
Scopes []string `protobuf:"bytes,5,rep,name=scopes,proto3" json:"scopes,omitempty"`
|
||||
// string service_account = 6;
|
||||
Url string `protobuf:"bytes,7,opt,name=url,proto3" json:"url,omitempty"`
|
||||
RequestParams map[string]string `protobuf:"bytes,8,rep,name=request_params,json=requestParams,proto3" json:"request_params,omitempty" protobuf_key:"bytes,1,opt,name=key,proto3" protobuf_val:"bytes,2,opt,name=value,proto3"`
|
||||
}
|
||||
|
||||
func (x *Provider) Reset() {
|
||||
|
@ -102,13 +102,6 @@ func (x *Provider) GetScopes() []string {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (x *Provider) GetServiceAccount() string {
|
||||
if x != nil {
|
||||
return x.ServiceAccount
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (x *Provider) GetUrl() string {
|
||||
if x != nil {
|
||||
return x.Url
|
||||
|
@ -128,7 +121,7 @@ var File_identity_proto protoreflect.FileDescriptor
|
|||
var file_identity_proto_rawDesc = []byte{
|
||||
0x0a, 0x0e, 0x69, 0x64, 0x65, 0x6e, 0x74, 0x69, 0x74, 0x79, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f,
|
||||
0x12, 0x11, 0x70, 0x6f, 0x6d, 0x65, 0x72, 0x69, 0x75, 0x6d, 0x2e, 0x69, 0x64, 0x65, 0x6e, 0x74,
|
||||
0x69, 0x74, 0x79, 0x22, 0xdc, 0x02, 0x0a, 0x08, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72,
|
||||
0x69, 0x74, 0x79, 0x22, 0xb3, 0x02, 0x0a, 0x08, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72,
|
||||
0x12, 0x0e, 0x0a, 0x02, 0x69, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x02, 0x69, 0x64,
|
||||
0x12, 0x1b, 0x0a, 0x09, 0x63, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x5f, 0x69, 0x64, 0x18, 0x02, 0x20,
|
||||
0x01, 0x28, 0x09, 0x52, 0x08, 0x63, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x49, 0x64, 0x12, 0x23, 0x0a,
|
||||
|
@ -136,24 +129,22 @@ var file_identity_proto_rawDesc = []byte{
|
|||
0x20, 0x01, 0x28, 0x09, 0x52, 0x0c, 0x63, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x53, 0x65, 0x63, 0x72,
|
||||
0x65, 0x74, 0x12, 0x12, 0x0a, 0x04, 0x74, 0x79, 0x70, 0x65, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09,
|
||||
0x52, 0x04, 0x74, 0x79, 0x70, 0x65, 0x12, 0x16, 0x0a, 0x06, 0x73, 0x63, 0x6f, 0x70, 0x65, 0x73,
|
||||
0x18, 0x05, 0x20, 0x03, 0x28, 0x09, 0x52, 0x06, 0x73, 0x63, 0x6f, 0x70, 0x65, 0x73, 0x12, 0x27,
|
||||
0x0a, 0x0f, 0x73, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x5f, 0x61, 0x63, 0x63, 0x6f, 0x75, 0x6e,
|
||||
0x74, 0x18, 0x06, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0e, 0x73, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65,
|
||||
0x41, 0x63, 0x63, 0x6f, 0x75, 0x6e, 0x74, 0x12, 0x10, 0x0a, 0x03, 0x75, 0x72, 0x6c, 0x18, 0x07,
|
||||
0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x75, 0x72, 0x6c, 0x12, 0x55, 0x0a, 0x0e, 0x72, 0x65, 0x71,
|
||||
0x75, 0x65, 0x73, 0x74, 0x5f, 0x70, 0x61, 0x72, 0x61, 0x6d, 0x73, 0x18, 0x08, 0x20, 0x03, 0x28,
|
||||
0x0b, 0x32, 0x2e, 0x2e, 0x70, 0x6f, 0x6d, 0x65, 0x72, 0x69, 0x75, 0x6d, 0x2e, 0x69, 0x64, 0x65,
|
||||
0x6e, 0x74, 0x69, 0x74, 0x79, 0x2e, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x2e, 0x52,
|
||||
0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x50, 0x61, 0x72, 0x61, 0x6d, 0x73, 0x45, 0x6e, 0x74, 0x72,
|
||||
0x79, 0x52, 0x0d, 0x72, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x50, 0x61, 0x72, 0x61, 0x6d, 0x73,
|
||||
0x1a, 0x40, 0x0a, 0x12, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x50, 0x61, 0x72, 0x61, 0x6d,
|
||||
0x73, 0x45, 0x6e, 0x74, 0x72, 0x79, 0x12, 0x10, 0x0a, 0x03, 0x6b, 0x65, 0x79, 0x18, 0x01, 0x20,
|
||||
0x01, 0x28, 0x09, 0x52, 0x03, 0x6b, 0x65, 0x79, 0x12, 0x14, 0x0a, 0x05, 0x76, 0x61, 0x6c, 0x75,
|
||||
0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x3a, 0x02,
|
||||
0x38, 0x01, 0x42, 0x30, 0x5a, 0x2e, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d,
|
||||
0x2f, 0x70, 0x6f, 0x6d, 0x65, 0x72, 0x69, 0x75, 0x6d, 0x2f, 0x70, 0x6f, 0x6d, 0x65, 0x72, 0x69,
|
||||
0x75, 0x6d, 0x2f, 0x70, 0x6b, 0x67, 0x2f, 0x67, 0x72, 0x70, 0x63, 0x2f, 0x69, 0x64, 0x65, 0x6e,
|
||||
0x74, 0x69, 0x74, 0x79, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33,
|
||||
0x18, 0x05, 0x20, 0x03, 0x28, 0x09, 0x52, 0x06, 0x73, 0x63, 0x6f, 0x70, 0x65, 0x73, 0x12, 0x10,
|
||||
0x0a, 0x03, 0x75, 0x72, 0x6c, 0x18, 0x07, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x75, 0x72, 0x6c,
|
||||
0x12, 0x55, 0x0a, 0x0e, 0x72, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x5f, 0x70, 0x61, 0x72, 0x61,
|
||||
0x6d, 0x73, 0x18, 0x08, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x2e, 0x2e, 0x70, 0x6f, 0x6d, 0x65, 0x72,
|
||||
0x69, 0x75, 0x6d, 0x2e, 0x69, 0x64, 0x65, 0x6e, 0x74, 0x69, 0x74, 0x79, 0x2e, 0x50, 0x72, 0x6f,
|
||||
0x76, 0x69, 0x64, 0x65, 0x72, 0x2e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x50, 0x61, 0x72,
|
||||
0x61, 0x6d, 0x73, 0x45, 0x6e, 0x74, 0x72, 0x79, 0x52, 0x0d, 0x72, 0x65, 0x71, 0x75, 0x65, 0x73,
|
||||
0x74, 0x50, 0x61, 0x72, 0x61, 0x6d, 0x73, 0x1a, 0x40, 0x0a, 0x12, 0x52, 0x65, 0x71, 0x75, 0x65,
|
||||
0x73, 0x74, 0x50, 0x61, 0x72, 0x61, 0x6d, 0x73, 0x45, 0x6e, 0x74, 0x72, 0x79, 0x12, 0x10, 0x0a,
|
||||
0x03, 0x6b, 0x65, 0x79, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x6b, 0x65, 0x79, 0x12,
|
||||
0x14, 0x0a, 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05,
|
||||
0x76, 0x61, 0x6c, 0x75, 0x65, 0x3a, 0x02, 0x38, 0x01, 0x42, 0x30, 0x5a, 0x2e, 0x67, 0x69, 0x74,
|
||||
0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x70, 0x6f, 0x6d, 0x65, 0x72, 0x69, 0x75, 0x6d,
|
||||
0x2f, 0x70, 0x6f, 0x6d, 0x65, 0x72, 0x69, 0x75, 0x6d, 0x2f, 0x70, 0x6b, 0x67, 0x2f, 0x67, 0x72,
|
||||
0x70, 0x63, 0x2f, 0x69, 0x64, 0x65, 0x6e, 0x74, 0x69, 0x74, 0x79, 0x62, 0x06, 0x70, 0x72, 0x6f,
|
||||
0x74, 0x6f, 0x33,
|
||||
}
|
||||
|
||||
var (
|
||||
|
|
|
@ -9,7 +9,7 @@ message Provider {
|
|||
string client_secret = 3;
|
||||
string type = 4;
|
||||
repeated string scopes = 5;
|
||||
string service_account = 6;
|
||||
// string service_account = 6;
|
||||
string url = 7;
|
||||
map<string, string> request_params = 8;
|
||||
}
|
||||
|
|
|
@ -92,11 +92,6 @@ _import_paths=$(join_by , "${_imports[@]}")
|
|||
--go_out="$_import_paths,plugins=grpc,paths=source_relative:./device/." \
|
||||
./device/device.proto
|
||||
|
||||
../../scripts/protoc -I ./directory/ \
|
||||
--go_out="$_import_paths,plugins=grpc,paths=source_relative:./directory/." \
|
||||
./directory/directory.proto
|
||||
|
||||
|
||||
../../scripts/protoc -I ./identity/ \
|
||||
--go_out="$_import_paths,plugins=grpc,paths=source_relative:./identity/." \
|
||||
./identity/identity.proto
|
||||
|
|
|
@ -1,82 +0,0 @@
|
|||
package criteria
|
||||
|
||||
import (
|
||||
"github.com/open-policy-agent/opa/ast"
|
||||
|
||||
"github.com/pomerium/pomerium/pkg/policy/generator"
|
||||
"github.com/pomerium/pomerium/pkg/policy/parser"
|
||||
"github.com/pomerium/pomerium/pkg/policy/rules"
|
||||
)
|
||||
|
||||
var groupsBody = ast.Body{
|
||||
ast.MustParseExpr(`
|
||||
session := get_session(input.session.id)
|
||||
`),
|
||||
ast.MustParseExpr(`
|
||||
directory_user := get_directory_user(session)
|
||||
`),
|
||||
ast.MustParseExpr(`
|
||||
group_ids := get_group_ids(session, directory_user)
|
||||
`),
|
||||
ast.MustParseExpr(`
|
||||
group_names := [directory_group.name |
|
||||
some i
|
||||
group_id := group_ids[i]
|
||||
directory_group := get_directory_group(group_id)
|
||||
directory_group != null
|
||||
directory_group.name != null]
|
||||
`),
|
||||
ast.MustParseExpr(`
|
||||
group_emails := [directory_group.email |
|
||||
some i
|
||||
group_id := group_ids[i]
|
||||
directory_group := get_directory_group(group_id)
|
||||
directory_group != null
|
||||
directory_group.email != null]
|
||||
`),
|
||||
ast.MustParseExpr(`
|
||||
groups = array.concat(group_ids, array.concat(group_names, group_emails))
|
||||
`),
|
||||
}
|
||||
|
||||
type groupsCriterion struct {
|
||||
g *Generator
|
||||
}
|
||||
|
||||
func (groupsCriterion) DataType() generator.CriterionDataType {
|
||||
return CriterionDataTypeStringListMatcher
|
||||
}
|
||||
|
||||
func (groupsCriterion) Name() string {
|
||||
return "groups"
|
||||
}
|
||||
|
||||
func (c groupsCriterion) GenerateRule(_ string, data parser.Value) (*ast.Rule, []*ast.Rule, error) {
|
||||
var body ast.Body
|
||||
body = append(body, groupsBody...)
|
||||
|
||||
err := matchStringList(&body, ast.VarTerm("groups"), data)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
rule := NewCriterionSessionRule(c.g, c.Name(),
|
||||
ReasonGroupsOK, ReasonGroupsUnauthorized,
|
||||
body)
|
||||
|
||||
return rule, []*ast.Rule{
|
||||
rules.GetSession(),
|
||||
rules.GetDirectoryUser(),
|
||||
rules.GetDirectoryGroup(),
|
||||
rules.GetGroupIDs(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Groups returns a Criterion on a user's group ids, names or emails.
|
||||
func Groups(generator *Generator) Criterion {
|
||||
return groupsCriterion{g: generator}
|
||||
}
|
||||
|
||||
func init() {
|
||||
Register(Groups)
|
||||
}
|
|
@ -1,100 +0,0 @@
|
|||
package criteria
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/pomerium/pomerium/pkg/grpc/directory"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/session"
|
||||
)
|
||||
|
||||
func TestGroups(t *testing.T) {
|
||||
t.Run("no session", func(t *testing.T) {
|
||||
res, err := evaluate(t, `
|
||||
allow:
|
||||
and:
|
||||
- groups:
|
||||
has: group1
|
||||
- groups:
|
||||
has: group2
|
||||
`, []dataBrokerRecord{}, Input{Session: InputSession{ID: "SESSION_ID"}})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, A{false, A{ReasonUserUnauthenticated}, M{}}, res["allow"])
|
||||
require.Equal(t, A{false, A{}}, res["deny"])
|
||||
})
|
||||
t.Run("by id", func(t *testing.T) {
|
||||
res, err := evaluate(t, `
|
||||
allow:
|
||||
and:
|
||||
- groups:
|
||||
has: group1
|
||||
`,
|
||||
[]dataBrokerRecord{
|
||||
&session.Session{
|
||||
Id: "SESSION_ID",
|
||||
UserId: "USER_ID",
|
||||
},
|
||||
&directory.User{
|
||||
Id: "USER_ID",
|
||||
GroupIds: []string{"group1"},
|
||||
},
|
||||
},
|
||||
Input{Session: InputSession{ID: "SESSION_ID"}})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, A{true, A{ReasonGroupsOK}, M{}}, res["allow"])
|
||||
require.Equal(t, A{false, A{}}, res["deny"])
|
||||
})
|
||||
t.Run("by email", func(t *testing.T) {
|
||||
res, err := evaluate(t, `
|
||||
allow:
|
||||
and:
|
||||
- groups:
|
||||
has: "group1@example.com"
|
||||
`,
|
||||
[]dataBrokerRecord{
|
||||
&session.Session{
|
||||
Id: "SESSION_ID",
|
||||
UserId: "USER_ID",
|
||||
},
|
||||
&directory.User{
|
||||
Id: "USER_ID",
|
||||
GroupIds: []string{"group1"},
|
||||
},
|
||||
&directory.Group{
|
||||
Id: "group1",
|
||||
Email: "group1@example.com",
|
||||
},
|
||||
},
|
||||
Input{Session: InputSession{ID: "SESSION_ID"}})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, A{true, A{ReasonGroupsOK}, M{}}, res["allow"])
|
||||
require.Equal(t, A{false, A{}}, res["deny"])
|
||||
})
|
||||
t.Run("by name", func(t *testing.T) {
|
||||
res, err := evaluate(t, `
|
||||
allow:
|
||||
and:
|
||||
- groups:
|
||||
has: "Group 1"
|
||||
`,
|
||||
[]dataBrokerRecord{
|
||||
&session.Session{
|
||||
Id: "SESSION_ID",
|
||||
UserId: "USER_ID",
|
||||
},
|
||||
&directory.User{
|
||||
Id: "USER_ID",
|
||||
GroupIds: []string{"group1"},
|
||||
},
|
||||
&directory.Group{
|
||||
Id: "group1",
|
||||
Name: "Group 1",
|
||||
},
|
||||
},
|
||||
Input{Session: InputSession{ID: "SESSION_ID"}})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, A{true, A{ReasonGroupsOK}, M{}}, res["allow"])
|
||||
require.Equal(t, A{false, A{}}, res["deny"])
|
||||
})
|
||||
}
|
|
@ -18,8 +18,6 @@ const (
|
|||
ReasonDomainUnauthorized = "domain-unauthorized"
|
||||
ReasonEmailOK = "email-ok"
|
||||
ReasonEmailUnauthorized = "email-unauthorized"
|
||||
ReasonGroupsOK = "groups-ok"
|
||||
ReasonGroupsUnauthorized = "groups-unauthorized"
|
||||
ReasonHTTPMethodOK = "http-method-ok"
|
||||
ReasonHTTPMethodUnauthorized = "http-method-unauthorized"
|
||||
ReasonHTTPPathOK = "http-path-ok"
|
||||
|
|
|
@ -74,42 +74,6 @@ get_device_enrollment(device_credential) = v {
|
|||
`)
|
||||
}
|
||||
|
||||
// GetDirectoryUser returns the directory user for the given session.
|
||||
func GetDirectoryUser() *ast.Rule {
|
||||
return ast.MustParseRule(`
|
||||
get_directory_user(session) = v {
|
||||
v = get_databroker_record("type.googleapis.com/directory.User", session.user_id)
|
||||
v != null
|
||||
} else = "" {
|
||||
true
|
||||
}
|
||||
`)
|
||||
}
|
||||
|
||||
// GetDirectoryGroup returns the directory group for the given id.
|
||||
func GetDirectoryGroup() *ast.Rule {
|
||||
return ast.MustParseRule(`
|
||||
get_directory_group(id) = v {
|
||||
v = get_databroker_record("type.googleapis.com/directory.Group", id)
|
||||
v != null
|
||||
} else = {} {
|
||||
true
|
||||
}
|
||||
`)
|
||||
}
|
||||
|
||||
// GetGroupIDs returns the group ids for the given session or directory user.
|
||||
func GetGroupIDs() *ast.Rule {
|
||||
return ast.MustParseRule(`
|
||||
get_group_ids(session, directory_user) = v {
|
||||
v = directory_user.group_ids
|
||||
v != null
|
||||
} else = [] {
|
||||
true
|
||||
}
|
||||
`)
|
||||
}
|
||||
|
||||
// MergeWithAnd merges criterion results using `and`.
|
||||
func MergeWithAnd() *ast.Rule {
|
||||
return ast.MustParseRule(`
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue