move directory providers (#3633)

* remove directory providers and support for groups

* idp: remove directory providers

* better error messages

* fix errors

* restore postgres

* fix test
This commit is contained in:
Caleb Doxsey 2022-11-03 11:33:56 -06:00 committed by GitHub
parent bb5c80bae9
commit c178819875
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
78 changed files with 723 additions and 8703 deletions

View file

@ -30,7 +30,6 @@ import (
"github.com/pomerium/pomerium/internal/urlutil"
"github.com/pomerium/pomerium/pkg/cryptutil"
"github.com/pomerium/pomerium/pkg/grpc/databroker"
"github.com/pomerium/pomerium/pkg/grpc/directory"
"github.com/pomerium/pomerium/pkg/grpc/session"
"github.com/pomerium/pomerium/pkg/grpc/user"
)
@ -544,31 +543,10 @@ func (a *Authenticate) getUserInfoData(r *http.Request) (handlers.UserInfoData,
Id: pbSession.GetUserId(),
}
}
pbDirectoryUser, err := a.getDirectoryUser(r.Context(), pbSession.GetUserId())
if err != nil {
pbDirectoryUser = &directory.User{
Id: pbSession.GetUserId(),
}
}
var groups []*directory.Group
for _, groupID := range pbDirectoryUser.GetGroupIds() {
pbDirectoryGroup, err := directory.GetGroup(r.Context(), state.dataBrokerClient, groupID)
if err != nil {
pbDirectoryGroup = &directory.Group{
Id: groupID,
Name: groupID,
Email: groupID,
}
}
groups = append(groups, pbDirectoryGroup)
}
creationOptions, requestOptions, _ := a.webauthn.GetOptions(r.Context())
return handlers.UserInfoData{
CSRFToken: csrf.Token(r),
DirectoryGroups: groups,
DirectoryUser: pbDirectoryUser,
IsImpersonated: isImpersonated,
Session: pbSession,
User: pbUser,
@ -645,14 +623,6 @@ func (a *Authenticate) saveSessionToDataBroker(
sessionState.DatabrokerServerVersion = res.GetServerVersion()
sessionState.DatabrokerRecordVersion = res.GetRecord().GetVersion()
_, err = state.directoryClient.RefreshUser(ctx, &directory.RefreshUserRequest{
UserId: s.UserId,
AccessToken: accessToken.AccessToken,
})
if err != nil {
log.Error(ctx).Err(err).Msg("directory: failed to refresh user data")
}
return nil
}
@ -718,11 +688,6 @@ func (a *Authenticate) getUser(ctx context.Context, userID string) (*user.User,
return user.Get(ctx, client, userID)
}
func (a *Authenticate) getDirectoryUser(ctx context.Context, userID string) (*directory.User, error) {
client := a.state.Load().dataBrokerClient
return directory.GetUser(ctx, client, userID)
}
func (a *Authenticate) getWebauthnState(ctx context.Context) (*webauthn.State, error) {
state := a.state.Load()

View file

@ -6,7 +6,6 @@ import (
"google.golang.org/protobuf/encoding/protojson"
"github.com/pomerium/pomerium/internal/directory"
"github.com/pomerium/pomerium/internal/httputil"
"github.com/pomerium/pomerium/pkg/grpc/session"
"github.com/pomerium/pomerium/pkg/grpc/user"
@ -17,8 +16,6 @@ import (
// UserInfoData is the data for the UserInfo page.
type UserInfoData struct {
CSRFToken string
DirectoryGroups []*directory.Group
DirectoryUser *directory.User
IsImpersonated bool
Session *session.Session
User *user.User
@ -34,16 +31,6 @@ type UserInfoData struct {
func (data UserInfoData) ToJSON() map[string]any {
m := map[string]any{}
m["csrfToken"] = data.CSRFToken
var directoryGroups []json.RawMessage
for _, directoryGroup := range data.DirectoryGroups {
if bs, err := protojson.Marshal(directoryGroup); err == nil {
directoryGroups = append(directoryGroups, json.RawMessage(bs))
}
}
m["directoryGroups"] = directoryGroups
if bs, err := protojson.Marshal(data.DirectoryUser); err == nil {
m["directoryUser"] = json.RawMessage(bs)
}
m["isImpersonated"] = data.IsImpersonated
if bs, err := protojson.Marshal(data.Session); err == nil {
m["session"] = json.RawMessage(bs)

View file

@ -14,14 +14,11 @@ import (
"github.com/go-jose/go-jose/v3/jwt"
"github.com/golang/mock/gomock"
"github.com/golang/protobuf/ptypes/empty"
"github.com/google/go-cmp/cmp"
"github.com/stretchr/testify/assert"
"golang.org/x/crypto/chacha20poly1305"
"golang.org/x/oauth2"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"google.golang.org/protobuf/types/known/timestamppb"
"github.com/pomerium/pomerium/authenticate/handlers/webauthn"
@ -38,7 +35,6 @@ import (
"github.com/pomerium/pomerium/internal/urlutil"
"github.com/pomerium/pomerium/pkg/cryptutil"
"github.com/pomerium/pomerium/pkg/grpc/databroker"
"github.com/pomerium/pomerium/pkg/grpc/directory"
"github.com/pomerium/pomerium/pkg/grpc/session"
)
@ -165,7 +161,6 @@ func TestAuthenticate_SignIn(t *testing.T) {
}, nil
},
},
directoryClient: new(mockDirectoryServiceClient),
}),
options: config.NewAtomicOptions(),
@ -321,7 +316,6 @@ func TestAuthenticate_SignOut(t *testing.T) {
return nil, nil
},
},
directoryClient: new(mockDirectoryServiceClient),
}),
options: config.NewAtomicOptions(),
}
@ -423,7 +417,6 @@ func TestAuthenticate_OAuthCallback(t *testing.T) {
return nil, nil
},
},
directoryClient: new(mockDirectoryServiceClient),
redirectURL: authURL,
sessionStore: tt.session,
cookieCipher: aead,
@ -565,7 +558,6 @@ func TestAuthenticate_SessionValidatorMiddleware(t *testing.T) {
}, nil
},
},
directoryClient: new(mockDirectoryServiceClient),
}),
options: config.NewAtomicOptions(),
}
@ -681,7 +673,6 @@ func TestAuthenticate_userInfo(t *testing.T) {
}, nil
},
},
directoryClient: new(mockDirectoryServiceClient),
}),
}
a.webauthn = webauthn.New(a.getWebauthnState)
@ -723,19 +714,6 @@ func (m mockDataBrokerServiceClient) Put(ctx context.Context, in *databroker.Put
return m.put(ctx, in, opts...)
}
type mockDirectoryServiceClient struct {
directory.DirectoryServiceClient
refreshUser func(ctx context.Context, in *directory.RefreshUserRequest, opts ...grpc.CallOption) (*empty.Empty, error)
}
func (m mockDirectoryServiceClient) RefreshUser(ctx context.Context, in *directory.RefreshUserRequest, opts ...grpc.CallOption) (*empty.Empty, error) {
if m.refreshUser != nil {
return m.refreshUser(ctx, in, opts...)
}
return nil, status.Error(codes.Unimplemented, "")
}
func mustParseURL(rawurl string) *url.URL {
u, err := url.Parse(rawurl)
if err != nil {

View file

@ -30,7 +30,6 @@ func defaultGetIdentityProvider(options *config.Options, idpID string) (identity
ClientID: idp.GetClientId(),
ClientSecret: idp.GetClientSecret(),
Scopes: idp.GetScopes(),
ServiceAccount: idp.GetServiceAccount(),
AuthCodeOptions: idp.GetRequestParams(),
})
}

View file

@ -18,7 +18,6 @@ import (
"github.com/pomerium/pomerium/pkg/cryptutil"
"github.com/pomerium/pomerium/pkg/grpc"
"github.com/pomerium/pomerium/pkg/grpc/databroker"
"github.com/pomerium/pomerium/pkg/grpc/directory"
"github.com/pomerium/pomerium/pkg/webauthnutil"
"github.com/pomerium/webauthn"
)
@ -47,7 +46,6 @@ type authenticateState struct {
jwk *jose.JSONWebKeySet
dataBrokerClient databroker.DataBrokerServiceClient
directoryClient directory.DirectoryServiceClient
webauthnRelyingParty *webauthn.RelyingParty
}
@ -154,7 +152,6 @@ func newAuthenticateStateFromConfig(cfg *config.Config) (*authenticateState, err
}
state.dataBrokerClient = databroker.NewDataBrokerServiceClient(dataBrokerConn)
state.directoryClient = directory.NewDirectoryServiceClient(dataBrokerConn)
state.webauthnRelyingParty = webauthn.NewRelyingParty(
authenticateURL.String(),

View file

@ -14,7 +14,6 @@ import (
"github.com/pomerium/pomerium/config"
"github.com/pomerium/pomerium/internal/httputil"
"github.com/pomerium/pomerium/pkg/cryptutil"
"github.com/pomerium/pomerium/pkg/grpc/directory"
"github.com/pomerium/pomerium/pkg/grpc/session"
"github.com/pomerium/pomerium/pkg/grpc/user"
"github.com/pomerium/pomerium/pkg/policy/criteria"
@ -77,10 +76,6 @@ func TestEvaluator(t *testing.T) {
To: config.WeightedURLs{{URL: *mustParseURL("https://to7.example.com")}},
AllowedDomains: []string{"example.com"},
},
{
To: config.WeightedURLs{{URL: *mustParseURL("https://to8.example.com")}},
AllowedGroups: []string{"group1@example.com"},
},
{
To: config.WeightedURLs{{URL: *mustParseURL("https://to9.example.com")}},
AllowAnyAuthenticatedUser: true,
@ -375,39 +370,6 @@ func TestEvaluator(t *testing.T) {
require.NoError(t, err)
assert.True(t, res.Allow.Value)
})
t.Run("groups", func(t *testing.T) {
res, err := eval(t, options, []proto.Message{
&session.Session{
Id: "session1",
UserId: "user1",
},
&user.User{
Id: "user1",
Email: "a@example.com",
},
&directory.User{
Id: "user1",
GroupIds: []string{"group1"},
},
&directory.Group{
Id: "group1",
Name: "group1name",
Email: "group1@example.com",
},
}, &Request{
Policy: &policies[7],
Session: RequestSession{
ID: "session1",
},
HTTP: RequestHTTP{
Method: "GET",
URL: "https://from.example.com",
ClientCertificate: testValidCert,
},
})
require.NoError(t, err)
assert.True(t, res.Allow.Value)
})
t.Run("any authenticated user", func(t *testing.T) {
res, err := eval(t, options, []proto.Message{
&session.Session{
@ -473,7 +435,7 @@ func TestEvaluator(t *testing.T) {
})
t.Run("http method", func(t *testing.T) {
res, err := eval(t, options, []proto.Message{}, &Request{
Policy: &policies[9],
Policy: &policies[8],
HTTP: NewRequestHTTP(
"GET",
*mustParseURL("https://from.example.com/"),
@ -487,7 +449,7 @@ func TestEvaluator(t *testing.T) {
})
t.Run("http path", func(t *testing.T) {
res, err := eval(t, options, []proto.Message{}, &Request{
Policy: &policies[10],
Policy: &policies[9],
HTTP: NewRequestHTTP(
"POST",
*mustParseURL("https://from.example.com/test"),

View file

@ -15,9 +15,7 @@ import (
"github.com/pomerium/pomerium/authorize/internal/store"
"github.com/pomerium/pomerium/config"
"github.com/pomerium/pomerium/pkg/cryptutil"
"github.com/pomerium/pomerium/pkg/grpc/directory"
"github.com/pomerium/pomerium/pkg/grpc/session"
"github.com/pomerium/pomerium/pkg/grpc/user"
"github.com/pomerium/pomerium/pkg/storage"
)
@ -63,25 +61,6 @@ func TestHeadersEvaluator(t *testing.T) {
return e.Evaluate(ctx, input)
}
t.Run("groups", func(t *testing.T) {
output, err := eval(t,
[]proto.Message{
&session.Session{Id: "s1", UserId: "u1"},
&user.User{Id: "u1"},
&directory.User{Id: "u1", GroupIds: []string{"g1", "g2", "g3"}},
},
&HeadersRequest{
FromAudience: "from.example.com",
ToAudience: "to.example.com",
Session: RequestSession{
ID: "s1",
},
})
require.NoError(t, err)
assert.Equal(t, "g1,g2,g3", output.Headers.Get("X-Pomerium-Claim-Groups"))
})
t.Run("jwt", func(t *testing.T) {
output, err := eval(t,
[]proto.Message{

View file

@ -59,7 +59,7 @@ user = u {
}
directory_user = du {
du = get_databroker_record("type.googleapis.com/directory.User", session.user_id)
du = get_databroker_record("pomerium.io/DirectoryUser", session.user_id)
du != null
} else = {} {
true
@ -273,11 +273,11 @@ identity_headers := {key: values |
}
get_databroker_group_names(ids) = gs {
gs := [name | id := ids[i]; group := get_databroker_record("type.googleapis.com/directory.Group", id); name := group.name]
gs := [name | id := ids[i]; group := get_databroker_record("pomerium.io/DirectoryGroup", id); name := group.name]
}
get_databroker_group_emails(ids) = gs {
gs := [email | id := ids[i]; group := get_databroker_record("type.googleapis.com/directory.Group", id); email := group.email]
gs := [email | id := ids[i]; group := get_databroker_record("pomerium.io/DirectoryGroup", id); email := group.email]
}
get_header_string_value(obj) = s {

View file

@ -35,7 +35,6 @@ func (o *Options) GetIdentityProviderForPolicy(policy *Policy) (*identity.Provid
ClientSecret: clientSecret,
Type: o.Provider,
Scopes: o.Scopes,
ServiceAccount: o.ServiceAccount,
Url: o.ProviderURL,
RequestParams: o.RequestParams,
}

View file

@ -20,12 +20,6 @@ import (
"github.com/volatiletech/null/v9"
"github.com/pomerium/pomerium/internal/atomicutil"
"github.com/pomerium/pomerium/internal/directory/azure"
"github.com/pomerium/pomerium/internal/directory/github"
"github.com/pomerium/pomerium/internal/directory/gitlab"
"github.com/pomerium/pomerium/internal/directory/google"
"github.com/pomerium/pomerium/internal/directory/okta"
"github.com/pomerium/pomerium/internal/directory/onelogin"
"github.com/pomerium/pomerium/internal/hashutil"
"github.com/pomerium/pomerium/internal/httputil"
"github.com/pomerium/pomerium/internal/identity/oauth"
@ -41,11 +35,6 @@ import (
// DisableHeaderKey is the key used to check whether to disable setting header
const DisableHeaderKey = "disable"
const (
idpCustomScopesDocLink = "https://www.pomerium.com/docs/reference/identity-provider-scopes"
idpCustomScopesWarnMsg = "config: using custom scopes may result in undefined behavior, see: " + idpCustomScopesDocLink
)
// DefaultAlternativeAddr is the address used is two services are competing over
// the same listener. Typically this is invisible to the end user (e.g. localhost)
// gRPC server, or is used for healthchecks (authorize only service)
@ -151,11 +140,6 @@ type Options struct {
Provider string `mapstructure:"idp_provider" yaml:"idp_provider,omitempty"`
ProviderURL string `mapstructure:"idp_provider_url" yaml:"idp_provider_url,omitempty"`
Scopes []string `mapstructure:"idp_scopes" yaml:"idp_scopes,omitempty"`
ServiceAccount string `mapstructure:"idp_service_account" yaml:"idp_service_account,omitempty"`
// Identity provider refresh directory interval/timeout settings.
RefreshDirectoryTimeout time.Duration `mapstructure:"idp_refresh_directory_timeout" yaml:"idp_refresh_directory_timeout,omitempty"`
RefreshDirectoryInterval time.Duration `mapstructure:"idp_refresh_directory_interval" yaml:"idp_refresh_directory_interval,omitempty"`
QPS float64 `mapstructure:"idp_qps" yaml:"idp_qps"`
// RequestParams are custom request params added to the signin request as
// part of an Oauth2 code flow.
@ -334,9 +318,6 @@ var defaultOptions = Options{
GRPCClientDNSRoundRobin: true,
AuthenticateCallbackPath: "/oauth2/callback",
TracingSampleRate: 0.0001,
RefreshDirectoryInterval: 10 * time.Minute,
RefreshDirectoryTimeout: 1 * time.Minute,
QPS: 1.0,
AutocertOptions: AutocertOptions{
Folder: dataDir(),
@ -698,17 +679,6 @@ func (o *Options) Validate() error {
}
}
// if no service account was defined, there should not be any policies that
// assert group membership (except for azure which can be derived from the client
// id, secret and provider url)
if o.ServiceAccount == "" && o.Provider != "azure" {
for _, p := range o.GetAllPolicies() {
if len(p.AllowedGroups) != 0 {
return fmt.Errorf("config: `allowed_groups` requires `idp_service_account`")
}
}
}
// strip quotes from redirect address (#811)
o.HTTPRedirectAddr = strings.Trim(o.HTTPRedirectAddr, `"'`)
@ -717,14 +687,6 @@ func (o *Options) Validate() error {
"`insecure_server` or manually provided certificates were provided, server will be using a self-signed certificate")
}
switch o.Provider {
case azure.Name, github.Name, gitlab.Name, google.Name, okta.Name, onelogin.Name:
if len(o.Scopes) > 0 {
log.Warn(ctx).Msg(idpCustomScopesWarnMsg)
}
default:
}
if err := ValidateDNSLookupFamily(o.DNSLookupFamily); err != nil {
return fmt.Errorf("config: %w", err)
}
@ -918,7 +880,6 @@ func (o *Options) GetOauthOptions() (oauth.Options, error) {
ClientID: o.ClientID,
ClientSecret: clientSecret,
Scopes: o.Scopes,
ServiceAccount: o.ServiceAccount,
}, nil
}
@ -1029,9 +990,6 @@ func (o *Options) GetSharedKey() ([]byte, error) {
// GetGoogleCloudServerlessAuthenticationServiceAccount gets the GoogleCloudServerlessAuthenticationServiceAccount.
func (o *Options) GetGoogleCloudServerlessAuthenticationServiceAccount() string {
if o.GoogleCloudServerlessAuthenticationServiceAccount == "" && o.Provider == "google" {
return o.ServiceAccount
}
return o.GoogleCloudServerlessAuthenticationServiceAccount
}
@ -1043,14 +1001,6 @@ func (o *Options) GetSetResponseHeaders() map[string]string {
return o.SetResponseHeaders
}
// GetQPS gets the QPS.
func (o *Options) GetQPS() float64 {
if o.QPS < 1 {
return 1
}
return o.QPS
}
// GetCodecType gets a codec type.
func (o *Options) GetCodecType() CodecType {
if o.CodecType == CodecTypeUnset {
@ -1393,15 +1343,6 @@ func (o *Options) ApplySettings(ctx context.Context, settings *config.Settings)
if len(settings.Scopes) > 0 {
o.Scopes = settings.Scopes
}
if settings.IdpServiceAccount != nil {
o.ServiceAccount = settings.GetIdpServiceAccount()
}
if settings.IdpRefreshDirectoryTimeout != nil {
o.RefreshDirectoryTimeout = settings.GetIdpRefreshDirectoryTimeout().AsDuration()
}
if settings.IdpRefreshDirectoryInterval != nil {
o.RefreshDirectoryInterval = settings.GetIdpRefreshDirectoryInterval().AsDuration()
}
if settings.RequestParams != nil && len(settings.RequestParams) > 0 {
o.RequestParams = settings.RequestParams
}

View file

@ -310,9 +310,6 @@ func TestOptionsFromViper(t *testing.T) {
"X-Frame-Options": "SAMEORIGIN",
"X-XSS-Protection": "1; mode=block",
},
RefreshDirectoryTimeout: 1 * time.Minute,
RefreshDirectoryInterval: 10 * time.Minute,
QPS: 1.0,
DataBrokerStorageType: "memory",
EnvoyAdminAccessLogPath: os.DevNull,
EnvoyAdminProfilePath: os.DevNull,
@ -330,9 +327,6 @@ func TestOptionsFromViper(t *testing.T) {
CookieHTTPOnly: true,
InsecureServer: true,
SetResponseHeaders: map[string]string{"disable": "true"},
RefreshDirectoryTimeout: 1 * time.Minute,
RefreshDirectoryInterval: 10 * time.Minute,
QPS: 1.0,
DataBrokerStorageType: "memory",
EnvoyAdminAccessLogPath: os.DevNull,
EnvoyAdminProfilePath: os.DevNull,
@ -342,7 +336,6 @@ func TestOptionsFromViper(t *testing.T) {
{"bad url", []byte(`{"policy":[{"from": "https://","to":"https://to.example"}]}`), nil, true},
{"bad policy", []byte(`{"policy":[{"allow_public_unauthenticated_access": "dog","to":"https://to.example"}]}`), nil, true},
{"bad file", []byte(`{''''}`), nil, true},
{"allowed_groups without idp_service_account should fail", []byte(`{"autocert_dir":"","insecure_server":true,"policy":[{"from": "https://from.example","to":"https://to.example","allowed_groups": "['group1']"}]}`), nil, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {

View file

@ -38,7 +38,6 @@ type Policy struct {
// Identity related policy
AllowedUsers []string `mapstructure:"allowed_users" yaml:"allowed_users,omitempty" json:"allowed_users,omitempty"`
AllowedGroups []string `mapstructure:"allowed_groups" yaml:"allowed_groups,omitempty" json:"allowed_groups,omitempty"`
AllowedDomains []string `mapstructure:"allowed_domains" yaml:"allowed_domains,omitempty" json:"allowed_domains,omitempty"`
AllowedIDPClaims identity.FlattenedClaims `mapstructure:"allowed_idp_claims" yaml:"allowed_idp_claims,omitempty" json:"allowed_idp_claims,omitempty"`
@ -192,7 +191,6 @@ type SubPolicy struct {
ID string `mapstructure:"id" yaml:"id" json:"id"`
Name string `mapstructure:"name" yaml:"name" json:"name"`
AllowedUsers []string `mapstructure:"allowed_users" yaml:"allowed_users,omitempty" json:"allowed_users,omitempty"`
AllowedGroups []string `mapstructure:"allowed_groups" yaml:"allowed_groups,omitempty" json:"allowed_groups,omitempty"`
AllowedDomains []string `mapstructure:"allowed_domains" yaml:"allowed_domains,omitempty" json:"allowed_domains,omitempty"`
AllowedIDPClaims identity.FlattenedClaims `mapstructure:"allowed_idp_claims" yaml:"allowed_idp_claims,omitempty" json:"allowed_idp_claims,omitempty"`
Rego []string `mapstructure:"rego" yaml:"rego" json:"rego,omitempty"`
@ -231,7 +229,6 @@ func NewPolicyFromProto(pb *configpb.Route) (*Policy, error) {
p := &Policy{
From: pb.GetFrom(),
AllowedUsers: pb.GetAllowedUsers(),
AllowedGroups: pb.GetAllowedGroups(),
AllowedDomains: pb.GetAllowedDomains(),
AllowedIDPClaims: identity.NewFlattenedClaimsFromPB(pb.GetAllowedIdpClaims()),
Prefix: pb.GetPrefix(),
@ -317,7 +314,6 @@ func NewPolicyFromProto(pb *configpb.Route) (*Policy, error) {
ID: sp.GetId(),
Name: sp.GetName(),
AllowedUsers: sp.GetAllowedUsers(),
AllowedGroups: sp.GetAllowedGroups(),
AllowedDomains: sp.GetAllowedDomains(),
AllowedIDPClaims: identity.NewFlattenedClaimsFromPB(sp.GetAllowedIdpClaims()),
Rego: sp.GetRego(),
@ -347,7 +343,6 @@ func (p *Policy) ToProto() (*configpb.Route, error) {
Id: sp.ID,
Name: sp.Name,
AllowedUsers: sp.AllowedUsers,
AllowedGroups: sp.AllowedGroups,
AllowedDomains: sp.AllowedDomains,
AllowedIdpClaims: sp.AllowedIDPClaims.ToPB(),
Rego: sp.Rego,
@ -358,7 +353,6 @@ func (p *Policy) ToProto() (*configpb.Route, error) {
Name: fmt.Sprint(p.RouteID()),
From: p.From,
AllowedUsers: p.AllowedUsers,
AllowedGroups: p.AllowedGroups,
AllowedDomains: p.AllowedDomains,
AllowedIdpClaims: p.AllowedIDPClaims.ToPB(),
Prefix: p.Prefix,
@ -466,12 +460,12 @@ func (p *Policy) Validate() error {
}
// Only allow public access if no other whitelists are in place
if p.AllowPublicUnauthenticatedAccess && (p.AllowAnyAuthenticatedUser || p.AllowedDomains != nil || p.AllowedGroups != nil || p.AllowedUsers != nil) {
if p.AllowPublicUnauthenticatedAccess && (p.AllowAnyAuthenticatedUser || p.AllowedDomains != nil || p.AllowedUsers != nil) {
return fmt.Errorf("config: policy route marked as public but contains whitelists")
}
// Only allow any authenticated user if no other whitelists are in place
if p.AllowAnyAuthenticatedUser && (p.AllowedDomains != nil || p.AllowedGroups != nil || p.AllowedUsers != nil) {
if p.AllowAnyAuthenticatedUser && (p.AllowedDomains != nil || p.AllowedUsers != nil) {
return fmt.Errorf("config: policy route marked accessible for any authenticated user but contains whitelists")
}
@ -642,16 +636,6 @@ func (p *Policy) AllAllowedDomains() []string {
return ads
}
// AllAllowedGroups returns all the allowed groups.
func (p *Policy) AllAllowedGroups() []string {
var ags []string
ags = append(ags, p.AllowedGroups...)
for _, sp := range p.SubPolicies {
ags = append(ags, sp.AllowedGroups...)
}
return ags
}
// AllAllowedIDPClaims returns all the allowed IDP claims.
func (p *Policy) AllAllowedIDPClaims() []identity.FlattenedClaims {
var aics []identity.FlattenedClaims

View file

@ -47,15 +47,6 @@ func (p *Policy) ToPPL() *parser.Policy {
},
})
}
for _, ag := range p.AllAllowedGroups() {
allowRule.Or = append(allowRule.Or,
parser.Criterion{
Name: "groups",
Data: parser.Object{
"has": parser.String(ag),
},
})
}
for _, aic := range p.AllAllowedIDPClaims() {
var ks []string
for k := range aic {

View file

@ -16,7 +16,6 @@ func TestPolicy_ToPPL(t *testing.T) {
CORSAllowPreflight: true,
AllowAnyAuthenticatedUser: true,
AllowedDomains: []string{"a.example.com", "b.example.com"},
AllowedGroups: []string{"group1", "group2"},
AllowedUsers: []string{"user1", "user2"},
AllowedIDPClaims: map[string][]interface{}{
"family_name": {"Smith", "Jones"},
@ -24,7 +23,6 @@ func TestPolicy_ToPPL(t *testing.T) {
SubPolicies: []SubPolicy{
{
AllowedDomains: []string{"c.example.com", "d.example.com"},
AllowedGroups: []string{"group3", "group4"},
AllowedUsers: []string{"user3", "user4"},
AllowedIDPClaims: map[string][]interface{}{
"given_name": {"John"},
@ -32,7 +30,6 @@ func TestPolicy_ToPPL(t *testing.T) {
},
{
AllowedDomains: []string{"e.example.com"},
AllowedGroups: []string{"group5"},
AllowedUsers: []string{"user5"},
AllowedIDPClaims: map[string][]interface{}{
"timezone": {"EST"},
@ -175,161 +172,6 @@ else = [false, {"user-unauthenticated"}] {
true
}
groups_0 = [true, {"groups-ok"}] {
session := get_session(input.session.id)
directory_user := get_directory_user(session)
group_ids := get_group_ids(session, directory_user)
group_names := [directory_group.name |
some i
group_id := group_ids[i]
directory_group := get_directory_group(group_id)
directory_group != null
directory_group.name != null
]
group_emails := [directory_group.email |
some i
group_id := group_ids[i]
directory_group := get_directory_group(group_id)
directory_group != null
directory_group.email != null
]
groups = array.concat(group_ids, array.concat(group_names, group_emails))
count([true | some v; v = groups[_0]; v == "group1"]) > 0
}
else = [false, {"groups-unauthorized"}] {
session := get_session(input.session.id)
session.id != ""
}
else = [false, {"user-unauthenticated"}] {
true
}
groups_1 = [true, {"groups-ok"}] {
session := get_session(input.session.id)
directory_user := get_directory_user(session)
group_ids := get_group_ids(session, directory_user)
group_names := [directory_group.name |
some i
group_id := group_ids[i]
directory_group := get_directory_group(group_id)
directory_group != null
directory_group.name != null
]
group_emails := [directory_group.email |
some i
group_id := group_ids[i]
directory_group := get_directory_group(group_id)
directory_group != null
directory_group.email != null
]
groups = array.concat(group_ids, array.concat(group_names, group_emails))
count([true | some v; v = groups[_0]; v == "group2"]) > 0
}
else = [false, {"groups-unauthorized"}] {
session := get_session(input.session.id)
session.id != ""
}
else = [false, {"user-unauthenticated"}] {
true
}
groups_2 = [true, {"groups-ok"}] {
session := get_session(input.session.id)
directory_user := get_directory_user(session)
group_ids := get_group_ids(session, directory_user)
group_names := [directory_group.name |
some i
group_id := group_ids[i]
directory_group := get_directory_group(group_id)
directory_group != null
directory_group.name != null
]
group_emails := [directory_group.email |
some i
group_id := group_ids[i]
directory_group := get_directory_group(group_id)
directory_group != null
directory_group.email != null
]
groups = array.concat(group_ids, array.concat(group_names, group_emails))
count([true | some v; v = groups[_0]; v == "group3"]) > 0
}
else = [false, {"groups-unauthorized"}] {
session := get_session(input.session.id)
session.id != ""
}
else = [false, {"user-unauthenticated"}] {
true
}
groups_3 = [true, {"groups-ok"}] {
session := get_session(input.session.id)
directory_user := get_directory_user(session)
group_ids := get_group_ids(session, directory_user)
group_names := [directory_group.name |
some i
group_id := group_ids[i]
directory_group := get_directory_group(group_id)
directory_group != null
directory_group.name != null
]
group_emails := [directory_group.email |
some i
group_id := group_ids[i]
directory_group := get_directory_group(group_id)
directory_group != null
directory_group.email != null
]
groups = array.concat(group_ids, array.concat(group_names, group_emails))
count([true | some v; v = groups[_0]; v == "group4"]) > 0
}
else = [false, {"groups-unauthorized"}] {
session := get_session(input.session.id)
session.id != ""
}
else = [false, {"user-unauthenticated"}] {
true
}
groups_4 = [true, {"groups-ok"}] {
session := get_session(input.session.id)
directory_user := get_directory_user(session)
group_ids := get_group_ids(session, directory_user)
group_names := [directory_group.name |
some i
group_id := group_ids[i]
directory_group := get_directory_group(group_id)
directory_group != null
directory_group.name != null
]
group_emails := [directory_group.email |
some i
group_id := group_ids[i]
directory_group := get_directory_group(group_id)
directory_group != null
directory_group.email != null
]
groups = array.concat(group_ids, array.concat(group_names, group_emails))
count([true | some v; v = groups[_0]; v == "group5"]) > 0
}
else = [false, {"groups-unauthorized"}] {
session := get_session(input.session.id)
session.id != ""
}
else = [false, {"user-unauthenticated"}] {
true
}
claim_0 = [true, {"claim-ok"}] {
rule_data := "Smith"
rule_path := "family_name"
@ -570,7 +412,7 @@ else = [false, {"user-unauthenticated"}] {
}
or_0 = v {
results := [pomerium_routes_0, accept_0, cors_preflight_0, authenticated_user_0, domain_0, domain_1, domain_2, domain_3, domain_4, groups_0, groups_1, groups_2, groups_3, groups_4, claim_0, claim_1, claim_2, claim_3, user_0, email_0, user_1, email_1, user_2, email_2, user_3, email_3, user_4, email_4]
results := [pomerium_routes_0, accept_0, cors_preflight_0, authenticated_user_0, domain_0, domain_1, domain_2, domain_3, domain_4, claim_0, claim_1, claim_2, claim_3, user_0, email_0, user_1, email_1, user_2, email_2, user_3, email_3, user_4, email_4]
normalized := [normalize_criterion_result(x) | x := results[i]]
v := merge_with_or(normalized)
}
@ -715,24 +557,6 @@ else = {} {
true
}
get_directory_user(session) = v {
v = get_databroker_record("type.googleapis.com/directory.User", session.user_id)
v != null
}
else = "" {
true
}
get_directory_group(id) = v {
v = get_databroker_record("type.googleapis.com/directory.Group", id)
v != null
}
else = {} {
true
}
get_user_email(session, user) = v {
v = user.email
}
@ -741,15 +565,6 @@ else = "" {
true
}
get_group_ids(session, directory_user) = v {
v = directory_user.group_ids
v != null
}
else = [] {
true
}
object_get(obj, key, def) = value {
undefined := "10a0fd35-0f1a-4e5b-97ce-631e89e1bafa"
value = object.get(obj, key, undefined)

View file

@ -106,7 +106,7 @@ func Test_PolicyRouteID(t *testing.T) {
{
"same",
&Policy{From: "https://pomerium.io", To: mustParseWeightedURLs(t, "http://localhost"), AllowedUsers: []string{"foo@bar.com"}},
&Policy{From: "https://pomerium.io", To: mustParseWeightedURLs(t, "http://localhost"), AllowedGroups: []string{"allusers"}},
&Policy{From: "https://pomerium.io", To: mustParseWeightedURLs(t, "http://localhost")},
true,
},
{

View file

@ -7,7 +7,6 @@ import (
"context"
"fmt"
"net"
"sync"
"github.com/rs/zerolog"
"golang.org/x/sync/errgroup"
@ -16,7 +15,6 @@ import (
"github.com/pomerium/pomerium/config"
"github.com/pomerium/pomerium/internal/atomicutil"
"github.com/pomerium/pomerium/internal/directory"
"github.com/pomerium/pomerium/internal/events"
"github.com/pomerium/pomerium/internal/identity"
"github.com/pomerium/pomerium/internal/identity/manager"
@ -41,9 +39,6 @@ type DataBroker struct {
localGRPCServer *grpc.Server
localGRPCConnection *grpc.ClientConn
sharedKey *atomicutil.Value[[]byte]
mu sync.Mutex
directoryProvider directory.Provider
}
// New creates a new databroker service.
@ -126,7 +121,6 @@ func (c *DataBroker) OnConfigChange(ctx context.Context, cfg *config.Config) {
// Register registers all the gRPC services with the given server.
func (c *DataBroker) Register(grpcServer *grpc.Server) {
databroker.RegisterDataBrokerServiceServer(grpcServer, c.dataBrokerServer)
directory.RegisterDirectoryServiceServer(grpcServer, c)
registry.RegisterRegistryServer(grpcServer, c.dataBrokerServer)
}
@ -163,30 +157,10 @@ func (c *DataBroker) update(ctx context.Context, cfg *config.Config) error {
return fmt.Errorf("databroker: invalid oauth options: %w", err)
}
clientSecret, err := cfg.Options.GetClientSecret()
if err != nil {
return fmt.Errorf("databroker: error retrieving IPD client secret: %w", err)
}
directoryProvider := directory.GetProvider(directory.Options{
ServiceAccount: cfg.Options.ServiceAccount,
Provider: cfg.Options.Provider,
ProviderURL: cfg.Options.ProviderURL,
QPS: cfg.Options.GetQPS(),
ClientID: cfg.Options.ClientID,
ClientSecret: clientSecret,
})
c.mu.Lock()
c.directoryProvider = directoryProvider
c.mu.Unlock()
dataBrokerClient := databroker.NewDataBrokerServiceClient(c.localGRPCConnection)
options := []manager.Option{
manager.WithDirectoryProvider(directoryProvider),
manager.WithDataBrokerClient(dataBrokerClient),
manager.WithGroupRefreshInterval(cfg.Options.RefreshDirectoryInterval),
manager.WithGroupRefreshTimeout(cfg.Options.RefreshDirectoryTimeout),
manager.WithEventManager(c.eventsMgr),
}

View file

@ -6,6 +6,7 @@ import (
"strconv"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
@ -89,7 +90,7 @@ func TestServerSync(t *testing.T) {
require.NoError(t, err)
_, err = client.Recv()
require.Error(t, err)
require.Equal(t, codes.Aborted, status.Code(err))
assert.Equal(t, codes.Aborted.String(), status.Code(err).String())
})
}

View file

@ -1,59 +0,0 @@
package databroker
import (
"context"
"errors"
"fmt"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"google.golang.org/protobuf/types/known/emptypb"
"github.com/pomerium/pomerium/internal/directory/directoryerrors"
"github.com/pomerium/pomerium/pkg/grpc/databroker"
"github.com/pomerium/pomerium/pkg/grpc/directory"
"github.com/pomerium/pomerium/pkg/protoutil"
)
// RefreshUser refreshes a user's directory information.
func (c *DataBroker) RefreshUser(ctx context.Context, req *directory.RefreshUserRequest) (*emptypb.Empty, error) {
c.mu.Lock()
dp := c.directoryProvider
c.mu.Unlock()
if dp == nil {
return nil, errors.New("no directory provider is available for refresh")
}
u, err := dp.User(ctx, req.GetUserId(), req.GetAccessToken())
// if the returned error signals we should prefer existing information
if errors.Is(err, directoryerrors.ErrPreferExistingInformation) {
_, err = c.dataBrokerServer.Get(ctx, &databroker.GetRequest{
Type: protoutil.GetTypeURL(new(directory.User)),
Id: req.GetUserId(),
})
switch status.Code(err) {
case codes.OK:
return new(emptypb.Empty), nil
case codes.NotFound: // go ahead and save the user that was returned
default:
return nil, fmt.Errorf("databroker: error retrieving existing user record for refresh: %w", err)
}
} else if err != nil {
return nil, err
}
any := protoutil.NewAny(u)
_, err = c.dataBrokerServer.Put(ctx, &databroker.PutRequest{
Records: []*databroker.Record{{
Type: any.GetTypeUrl(),
Id: u.GetId(),
Data: any,
}},
})
if err != nil {
return nil, err
}
return new(emptypb.Empty), nil
}

View file

@ -1,6 +1,5 @@
# See detailed configuration settings : https://www.pomerium.com/docs/reference/
# this is the domain the identity provider will callback after a user authenticates
authenticate_service_url: https://authenticate.localhost.pomerium.io
@ -20,7 +19,6 @@ certificate_key_file: /pomerium/privkey.pem
idp_provider: google
idp_client_id: REPLACE_ME
idp_client_secret: REPLACE_ME
#idp_service_account: REPLACE_ME # Required by some identity providers for directory sync
# Generate 256 bit random keys e.g. `head -c32 /dev/urandom | base64`
cookie_secret: V2JBZk0zWGtsL29UcFUvWjVDWWQ2UHExNXJ0b2VhcDI=

View file

@ -1,7 +1,6 @@
#!/bin/bash
# Main configuration flags : https://www.pomerium.com/docs/reference/
# Main configuration flags
# export ADDRESS=":8443" # optional, default is 443
# export POMERIUM_DEBUG=true # optional, default is false
@ -37,7 +36,6 @@ export COOKIE_SECRET="$(head -c32 /dev/urandom | base64)"
# export IDP_CLIENT_ID="REPLACEME" # from the application the users login to
# export IDP_CLIENT_SECRET="REPLACEME" # from the application the users login to
# the following is optional and only needed if you want role (Auth0 calls groups roles) data
# export IDP_SERVICE_ACCOUNT="REPLACEME" # built from the machine-to-machine application which talks to the Auth0 Management API
# Azure
# export IDP_PROVIDER="azure"
@ -64,8 +62,3 @@ export IDP_PROVIDER="google"
# directly as a base64 encoded yaml/json file, or as the policy key in the configuration
# file
export POLICY="$(base64 ./docs/configuration/examples/config/policy.example.yaml)"
# For Group data you must set an IDP_SERVICE_ACCOUNT
# https://www.pomerium.com/configuration/#identity-provider-service-account
# export IDP_SERVICE_ACCOUNT=$( echo YOUR_SERVICE_ACCOUNT | base64)
# For Google manually edit the service account to add the impersonate_user field before base64

View file

@ -40,8 +40,6 @@ authenticate_service_url: https://authenticate.localhost.pomerium.io
# idp_provider_url: "https://REPLACEME.us.auth0.com"
# idp_client_id: "REPLACEME" # from the application the users login to
# idp_client_secret: "REPLACEME" # from the application the users login to
# the following is optional and only needed if you want role (Auth0 calls groups roles) data
# idp_service_account: "REPLACEME" # built from the machine-to-machine application which talks to the Auth0 Management API
# Azure
# idp_provider: "azure"
@ -54,10 +52,6 @@ authenticate_service_url: https://authenticate.localhost.pomerium.io
# idp_client_id: "REPLACEME
# idp_client_secret: "REPLACEME
# IF GSUITE and you want to get user groups you will need to set a service account
# see identity provider docs for gooogle for more info :
# idp_service_account: $(echo '{"impersonate_user": "user@example.com"}' | base64)
# OKTA
# idp_provider: "okta"
# idp_client_id: "REPLACEME"
@ -70,9 +64,6 @@ authenticate_service_url: https://authenticate.localhost.pomerium.io
# idp_client_secret: "REPLACEME"
# idp_provider_url: "https://openid-connect.onelogin.com/oidc" #optional, defaults to `https://openid-connect.onelogin.com/oidc`
# For Group data you must set an IDP_SERVICE_ACCOUNT
# idp_service_account: YOUR_SERVICE_ACCOUNT
# Proxied routes and per-route policies are defined in a routes block
routes:
- from: https://verify.localhost.pomerium.io

View file

@ -22,7 +22,6 @@ services:
# - IDP_PROVIDER_URL=https://beyondperimeter.okta.com
# - IDP_CLIENT_ID=REPLACE_ME
# - IDP_CLIENT_SECRET=REPLACE_ME
# - IDP_SERVICE_ACCOUNT=REPLACE_ME
# NOTE! Generate new secret keys! e.g. `head -c32 /dev/urandom | base64`
# Generated secret keys must match between services
- SHARED_SECRET=aDducXQzK2tPY3R4TmdqTGhaYS80eGYxcTUvWWJDb2M=

View file

@ -3,7 +3,7 @@
# curl https://raw.githubusercontent.com/kubernetes/helm/master/scripts/get | bash
# NOTE! This will create real resources on Google's cloud. Make sure you clean up any unused
# resources to avoid being billed. For reference, this tutorial cost me <10 cents for a couple of hours.
# NOTE! You must change the identity provider client secret setting, and service account setting!
# NOTE! You must change the identity provider client secret setting!
# NOTE! If you are using gsuite, you should also set `authenticate.idp.serviceAccount`, see docs !
echo "=> [GCE] creating cluster"

View file

@ -14,9 +14,6 @@ override_certificate_name: "*.localhost.pomerium.io"
idp_provider: google
idp_client_id: REPLACE_ME.apps.googleusercontent.com
idp_client_secret: "REPLACE_ME"
# Required for group data
# https://www.pomerium.com/configuration/#identity-provider-service-account
idp_service_account: YOUR_SERVICE_ACCOUNT
routes:
- from: https://verify.localhost.pomerium.io

View file

@ -7,7 +7,6 @@ authenticate:
provider: "google"
clientID: YOUR_CLIENT_ID
clientSecret: YOUR_SECRET
serviceAccount: YOUR_SERVICE_ACCOUNT
proxied: false
proxy:

View file

@ -3,9 +3,6 @@ authenticate:
provider: "google"
clientID: YOUR_CLIENT_ID
clientSecret: YOUR_SECRET
# Required for group data
# https://www.pomerium.com/configuration/#identity-provider-service-account
serviceAccount: YOUR_SERVICE_ACCOUNT
service:
type: NodePort
annotations:

5
go.mod
View file

@ -57,8 +57,6 @@ require (
github.com/spf13/viper v1.13.0
github.com/stretchr/testify v1.8.1
github.com/tniswong/go.rfcx v0.0.0-20181019234604-07783c52761f
github.com/tomnomnom/linkheader v0.0.0-20180905144013-02ca5825eb80
github.com/vektah/gqlparser v1.3.1
github.com/volatiletech/null/v9 v9.0.0
github.com/yuin/gopher-lua v0.0.0-20200816102855-ee81675732da
go.opencensus.io v0.23.0
@ -72,7 +70,6 @@ require (
google.golang.org/genproto v0.0.0-20221018160656-63c7b68cfc55
google.golang.org/grpc v1.50.1
google.golang.org/protobuf v1.28.1
gopkg.in/auth0.v5 v5.21.1
gopkg.in/yaml.v3 v3.0.1
namespacelabs.dev/go-filenotify v0.0.0-20220511192020-53ea11be7eaa
sigs.k8s.io/yaml v1.3.0
@ -94,7 +91,6 @@ require (
github.com/Nvveen/Gotty v0.0.0-20120604004816-cd527374f1e5 // indirect
github.com/OneOfOne/xxhash v1.2.8 // indirect
github.com/OpenPeeDeeP/depguard v1.1.1 // indirect
github.com/PuerkitoBio/rehttp v1.0.0 // indirect
github.com/agnivade/levenshtein v1.1.1 // indirect
github.com/alexkohler/prealloc v1.0.0 // indirect
github.com/alingse/asasalint v0.0.11 // indirect
@ -160,7 +156,6 @@ require (
github.com/google/go-tpm v0.3.2 // indirect
github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510 // indirect
github.com/googleapis/enterprise-certificate-proxy v0.2.0 // indirect
github.com/googleapis/gax-go/v2 v2.6.0 // indirect
github.com/gordonklaus/ineffassign v0.0.0-20210914165742-4cc7213b9bc8 // indirect
github.com/gorilla/securecookie v1.1.1 // indirect
github.com/gostaticanalysis/analysisutil v0.7.1 // indirect

22
go.sum
View file

@ -102,13 +102,10 @@ github.com/OneOfOne/xxhash v1.2.8 h1:31czK/TI9sNkxIKfaUfGlU47BAxQ0ztGgd9vPyqimf8
github.com/OneOfOne/xxhash v1.2.8/go.mod h1:eZbhyaAYD41SGSSsnmcpxVoRiQ/MPUTjUdIIOT9Um7Q=
github.com/OpenPeeDeeP/depguard v1.1.1 h1:TSUznLjvp/4IUP+OQ0t/4jF4QUyxIcVX8YnghZdunyA=
github.com/OpenPeeDeeP/depguard v1.1.1/go.mod h1:JtAMzWkmFEzDPyAd+W0NHl1lvpQKTvT9jnRVsohBKpc=
github.com/PuerkitoBio/rehttp v1.0.0 h1:aJ7A7YI2lIvOxcJVeUZY4P6R7kKZtLeONjgyKGwOIu8=
github.com/PuerkitoBio/rehttp v1.0.0/go.mod h1:ItsOiHl4XeMOV3rzbZqQRjLc3QQxbE6391/9iNG7rE8=
github.com/Shopify/sarama v1.19.0/go.mod h1:FVkBWblsNy7DGZRfXLU0O9RCGt5g3g3yEuWXgklEdEo=
github.com/Shopify/toxiproxy v2.1.4+incompatible/go.mod h1:OXgGpZ6Cli1/URJOF1DMxUHB2q5Ap20/P/eIdh4G0pI=
github.com/VictoriaMetrics/fastcache v1.12.0 h1:vnVi/y9yKDcD9akmc4NqAoqgQhJrOwUF+j9LTgn4QDE=
github.com/VictoriaMetrics/fastcache v1.12.0/go.mod h1:tjiYeEfYXCqacuvYw/7UoDIeJaNxq6132xHICNP77w8=
github.com/agnivade/levenshtein v1.0.1/go.mod h1:CURSv5d9Uaml+FovSIICkLbAUZ9S4RqaHDIsdSBg7lM=
github.com/agnivade/levenshtein v1.1.1 h1:QY8M92nrzkmr798gCo3kmMyqXFzdQVpxLlGPRBij0P8=
github.com/agnivade/levenshtein v1.1.1/go.mod h1:veldBMzWxcCG2ZvUTKD2kJNRdCk5hVbJomOvKkmgYbo=
github.com/alecthomas/template v0.0.0-20160405071501-a0175ee3bccc/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc=
@ -123,8 +120,6 @@ github.com/alingse/asasalint v0.0.11 h1:SFwnQXJ49Kx/1GghOFz1XGqHYKp21Kq1nHad/0WQ
github.com/alingse/asasalint v0.0.11/go.mod h1:nCaoMhw7a9kSJObvQyVzNTPBDbNpdocqrSP7t/cW5+I=
github.com/allegro/bigcache v1.2.1-0.20190218064605-e24eb225f156 h1:eMwmnE/GDgah4HI848JfFxHt+iPb26b4zyfspmqY0/8=
github.com/allegro/bigcache v1.2.1-0.20190218064605-e24eb225f156/go.mod h1:Cb/ax3seSYIx7SuZdm2G2xzfwmv3TPSk2ucNfQESPXM=
github.com/andreyvit/diff v0.0.0-20170406064948-c7f18ee00883 h1:bvNMNQO63//z+xNgfBlViaCIJKLlCJ6/fmUseuG0wVQ=
github.com/andreyvit/diff v0.0.0-20170406064948-c7f18ee00883/go.mod h1:rCTlJbsFo29Kk6CurOXKm700vrz8f0KW0JNfpkRJY/8=
github.com/andybalholm/brotli v1.0.4 h1:V7DdXeJtZscaqfNuAdSRuRFzuiKlHSC/Zh3zl9qY3JY=
github.com/andybalholm/brotli v1.0.4/go.mod h1:fO7iG3H7G2nSZ7m0zPUDn85XEX2GTukHGRSepvi9Eig=
github.com/antihax/optional v1.0.0/go.mod h1:uupD/76wgC+ih3iEmQUL+0Ugr19nfwCT1kdvxnR2qWY=
@ -135,9 +130,6 @@ github.com/ashanbrown/forbidigo v1.3.0 h1:VkYIwb/xxdireGAdJNZoo24O4lmnEWkactplBl
github.com/ashanbrown/forbidigo v1.3.0/go.mod h1:vVW7PEdqEFqapJe95xHkTfB1+XvZXBFg8t0sG2FIxmI=
github.com/ashanbrown/makezero v1.1.1 h1:iCQ87C0V0vSyO+M9E/FZYbu65auqH0lnsOkf5FcB28s=
github.com/ashanbrown/makezero v1.1.1/go.mod h1:i1bJLCRSCHOcOa9Y6MyF2FTfMZMFdHvxKHxgO5Z1axI=
github.com/aybabtme/iocontrol v0.0.0-20150809002002-ad15bcfc95a0 h1:0NmehRCgyk5rljDQLKUO+cRJCnduDyn11+zGZIc9Z48=
github.com/aybabtme/iocontrol v0.0.0-20150809002002-ad15bcfc95a0/go.mod h1:6L7zgvqo0idzI7IO8de6ZC051AfXb5ipkIJ7bIA2tGA=
github.com/benbjohnson/clock v1.0.3/go.mod h1:bGMdMPoPVvcYyt1gHDf4J2KE153Yf9BuiUKYMaxlTDM=
github.com/benbjohnson/clock v1.1.0 h1:Q92kusRqC1XV2MjkWETPvjJVqKetz1OzxZB7mHJLju8=
github.com/benbjohnson/clock v1.1.0/go.mod h1:J11/hYXuz8f4ySSvYwY0FKfm+ezbsZBKZxNJlLklBHA=
github.com/beorn7/perks v0.0.0-20180321164747-3a771d992973/go.mod h1:Dwedo/Wpr24TaqPxmxbtue+5NUziq4I4S80YR8gNf3Q=
@ -245,7 +237,6 @@ github.com/docker/go-connections v0.4.0/go.mod h1:Gbd7IOopHjR8Iph03tsViu4nIes5Xh
github.com/docker/go-units v0.4.0 h1:3uh0PgVws3nIA0Q+MwDC8yjEPf9zjRfZZWXZYDct3Tw=
github.com/docker/go-units v0.4.0/go.mod h1:fgPhTUdO+D/Jk86RDLlptpiXQzgHJF7gydDDbaIK4Dk=
github.com/dustin/go-humanize v1.0.0 h1:VSnTsYCnlFHaM2/igO1h6X3HA71jcobQuxemgkq4zYo=
github.com/dustin/go-humanize v1.0.0/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25KnS6fMYU6eOk=
github.com/eapache/go-resiliency v1.1.0/go.mod h1:kFI+JgMyC7bLPUVY133qvEBtVayf5mFgVsvEsIPBvNs=
github.com/eapache/go-xerial-snappy v0.0.0-20180814174437-776d5712da21/go.mod h1:+020luEh2TKB4/GOp8oxxtq0Daoen/Cii55CzbTV6DU=
github.com/eapache/queue v1.1.0/go.mod h1:6eCeP0CKFpHLu8blIFXhExK/dRa7WDZfr6jVFPTqq+I=
@ -482,8 +473,6 @@ github.com/googleapis/gax-go/v2 v2.1.1/go.mod h1:hddJymUZASv3XPyGkUpKj8pPO47Rmb0
github.com/googleapis/gax-go/v2 v2.2.0/go.mod h1:as02EH8zWkzwUoLbBaFeQ+arQaj/OthfcblKl4IGNaM=
github.com/googleapis/gax-go/v2 v2.3.0/go.mod h1:b8LNqSzNabLiUpXKkY7HAR5jr6bIT99EXz9pXxye9YM=
github.com/googleapis/gax-go/v2 v2.4.0/go.mod h1:XOTVJ59hdnfJLIP/dh8n5CGryZR2LxK9wbMD5+iXC6c=
github.com/googleapis/gax-go/v2 v2.6.0 h1:SXk3ABtQYDT/OH8jAyvEOQ58mgawq5C4o/4/89qN2ZU=
github.com/googleapis/gax-go/v2 v2.6.0/go.mod h1:1mjbznJAPHFpesgE5ucqfYEscaz5kMdcIDwU/6+DDoY=
github.com/googleapis/go-type-adapters v1.0.0/go.mod h1:zHW75FOG2aur7gAO2B+MLby+cLsWGBF62rFAi7WjWO4=
github.com/googleapis/google-cloud-go-testing v0.0.0-20200911160855-bcd43fbb19e8/go.mod h1:dvDLG8qkwmyD9a/MJJN3XJcT3xFxOKAvTZGvuZmac9g=
github.com/gordonklaus/ineffassign v0.0.0-20210914165742-4cc7213b9bc8 h1:PVRE9d4AQKmbelZ7emNig1+NT27DUmKZn5qXxfio54U=
@ -600,8 +589,6 @@ github.com/jingyugao/rowserrcheck v1.1.1/go.mod h1:4yvlZSDb3IyDTUZJUmpZfm2Hwok+D
github.com/jirfag/go-printf-func-name v0.0.0-20200119135958-7558a9eaa5af h1:KA9BjwUk7KlCh6S9EAGWBt1oExIUv9WyNCiRz5amv48=
github.com/jirfag/go-printf-func-name v0.0.0-20200119135958-7558a9eaa5af/go.mod h1:HEWGJkRDzjJY2sqdDwxccsGicWEf9BQOZsq2tV+xzM0=
github.com/jmoiron/sqlx v1.2.0/go.mod h1:1FEQNm3xlJgrMD+FBdI9+xvCksHtbpVBBw5dYhBSsks=
github.com/joho/godotenv v1.3.0 h1:Zjp+RcGpHhGlrMbJzXTrZZPrWj+1vfm90La1wgB6Bhc=
github.com/joho/godotenv v1.3.0/go.mod h1:7hK45KPybAkOC6peb+G5yklZfMxEjkZhHbwpqxOKXbg=
github.com/jonboulle/clockwork v0.1.0/go.mod h1:Ii8DK3G1RaLaWxj9trq07+26W01tbo22gdxWY5EU2bo=
github.com/jpillora/backoff v1.0.0/go.mod h1:J/6gKK9jxlEcS3zixgDgUAsiuZ7yrSoa/FX5e0EB2j4=
github.com/json-iterator/go v1.1.6/go.mod h1:+SdeFBvtyEkXs7REEP0seUULqWtbJapLOCVDaaPEHmU=
@ -886,7 +873,6 @@ github.com/satori/go.uuid v1.2.0/go.mod h1:dA0hQrYB0VpLJoorglMZABFdXlWrHn1NEOzdh
github.com/seccomp/libseccomp-golang v0.9.2-0.20210429002308-3879420cc921/go.mod h1:JA8cRccbGaA1s33RQf7Y1+q9gHmZX1yB/z9WDN1C6fg=
github.com/securego/gosec/v2 v2.13.1 h1:7mU32qn2dyC81MH9L2kefnQyRMUarfDER3iQyMHcjYM=
github.com/securego/gosec/v2 v2.13.1/go.mod h1:EO1sImBMBWFjOTFzMWfTRrZW6M15gm60ljzrmy/wtHo=
github.com/sergi/go-diff v1.0.0/go.mod h1:0CfEIISq7TuYL3j771MWULgwwjU+GofnZX9QAmXWZgo=
github.com/sergi/go-diff v1.1.0 h1:we8PVUC3FE2uYfodKH/nBHMSetSfHDR6scGdBi+erh0=
github.com/sergi/go-diff v1.1.0/go.mod h1:STckp+ISIX8hZLjrqAeVduY0gWCT9IjLuqbuNXdaHfM=
github.com/shazow/go-diff v0.0.0-20160112020656-b6b7b6733b8c h1:W65qqJCIOVP4jpqPQ0YvHYKwcMEMVWIzWC5iNQQfBTU=
@ -994,8 +980,6 @@ github.com/tomarrell/wrapcheck/v2 v2.7.0 h1:J/F8DbSKJC83bAvC6FoZaRjZiZ/iKoueSdrE
github.com/tomarrell/wrapcheck/v2 v2.7.0/go.mod h1:ao7l5p0aOlUNJKI0qVwB4Yjlqutd0IvAB9Rdwyilxvg=
github.com/tommy-muehle/go-mnd/v2 v2.5.1 h1:NowYhSdyE/1zwK9QCLeRb6USWdoif80Ie+v+yU8u1Zw=
github.com/tommy-muehle/go-mnd/v2 v2.5.1/go.mod h1:WsUAkMJMYww6l/ufffCD3m+P7LEvr8TnZn9lwVDlgzw=
github.com/tomnomnom/linkheader v0.0.0-20180905144013-02ca5825eb80 h1:nrZ3ySNYwJbSpD6ce9duiP+QkD3JuLCcWkdaehUS/3Y=
github.com/tomnomnom/linkheader v0.0.0-20180905144013-02ca5825eb80/go.mod h1:iFyPdL66DjUD96XmzVL3ZntbzcflLnznH0fr99w5VqE=
github.com/uber/jaeger-client-go v2.25.0+incompatible h1:IxcNZ7WRY1Y3G4poYlx24szfsn/3LvK9QHCq9oQw8+U=
github.com/uber/jaeger-client-go v2.25.0+incompatible/go.mod h1:WVhlPFC8FDjOFMMWRy2pZqQJSXxYSwNYOkTr/Z6d3Kk=
github.com/ugorji/go v1.1.4/go.mod h1:uQMGLiO92mf5W77hV/PUCpI3pbzQx3CRekS0kk+RGrc=
@ -1009,8 +993,6 @@ github.com/uudashr/gocognit v1.0.6 h1:2Cgi6MweCsdB6kpcVQp7EW4U23iBFQWfTXiWlyp842
github.com/uudashr/gocognit v1.0.6/go.mod h1:nAIUuVBnYU7pcninia3BHOvQkpQCeO76Uscky5BOwcY=
github.com/valyala/gozstd v1.11.0 h1:VV6qQFt+4sBBj9OJ7eKVvsFAMy59Urcs9Lgd+o5FOw0=
github.com/valyala/gozstd v1.11.0/go.mod h1:y5Ew47GLlP37EkTB+B4s7r6A5rdaeB7ftbl9zoYiIPQ=
github.com/vektah/gqlparser v1.3.1 h1:8b0IcD3qZKWJQHSzynbDlrtP3IxVydZ2DZepCGofqfU=
github.com/vektah/gqlparser v1.3.1/go.mod h1:bkVf0FX+Stjg/MHnm8mEyubuaArhNEqfQhF+OTiAL74=
github.com/vishvananda/netlink v1.1.0/go.mod h1:cTgwzPIzzgDAYoQrMm0EdrjRUBkTqKYppBueQtXaqoE=
github.com/vishvananda/netns v0.0.0-20191106174202-0a2b9b5464df/go.mod h1:JP3t17pCcGlemwknint6hfoeCVQrEMVwxRLRjXpq+BU=
github.com/volatiletech/null/v9 v9.0.0 h1:JCdlHEiSRVxOi7/MABiEfdsqmuj9oTV20Ao7VvZ0JkE=
@ -1367,7 +1349,6 @@ golang.org/x/tools v0.0.0-20180221164845-07fd8470d635/go.mod h1:n7NCudcB/nEzxVGm
golang.org/x/tools v0.0.0-20180525024113-a5b4c53f6e8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20190125232054-d66bd3c5d5a6/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20190226205152-f727befe758c/go.mod h1:9Yl7xja0Znq3iFh3HoIrodX9oNMXvdceNzlUR8zjMvY=
golang.org/x/tools v0.0.0-20190307163923-6a08e3108db3/go.mod h1:25r3+/G6/xytQM8iWZKq3Hn0kr0rgFKPUNVEL/dr3z4=
golang.org/x/tools v0.0.0-20190311212946-11955173bddd/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs=
@ -1467,7 +1448,6 @@ golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8T
golang.org/x/xerrors v0.0.0-20220411194840-2f41105eb62f/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20220517211312-f3a8303e98df/go.mod h1:K8+ghG5WaK9qNqU5K3HdILfMLy1f3aNYFI/wnl100a8=
golang.org/x/xerrors v0.0.0-20220609144429-65e65417b02f/go.mod h1:K8+ghG5WaK9qNqU5K3HdILfMLy1f3aNYFI/wnl100a8=
golang.org/x/xerrors v0.0.0-20220907171357-04be3eba64a2 h1:H2TDz8ibqkAF6YGhCdN3jS9O0/s90v0rJh3X/OLHEUk=
golang.org/x/xerrors v0.0.0-20220907171357-04be3eba64a2/go.mod h1:K8+ghG5WaK9qNqU5K3HdILfMLy1f3aNYFI/wnl100a8=
google.golang.org/api v0.4.0/go.mod h1:8k5glujaEP+g9n7WNsDg8QP6cUVNI86fCNMcbazEtwE=
google.golang.org/api v0.7.0/go.mod h1:WtwebWUNSVBH/HAw79HIFXZNqEvBhG+Ra+ax0hx3E3M=
@ -1657,8 +1637,6 @@ google.golang.org/protobuf v1.28.1/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqw
gopkg.in/DataDog/dd-trace-go.v1 v1.22.0 h1:gpWsqqkwUldNZXGJqT69NU9MdEDhLboK1C4nMgR0MWw=
gopkg.in/DataDog/dd-trace-go.v1 v1.22.0/go.mod h1:DVp8HmDh8PuTu2Z0fVVlBsyWaC++fzwVCaGWylTe3tg=
gopkg.in/alecthomas/kingpin.v2 v2.2.6/go.mod h1:FMv+mEhP44yOT+4EoQTLFTRgOQ1FBLkstjWtayDeSgw=
gopkg.in/auth0.v5 v5.21.1 h1:aIqHBmnqaDv4eK2WSpTRsv2dEpT1jdHJPl+iwyDJNoo=
gopkg.in/auth0.v5 v5.21.1/go.mod h1:k1eJq1+II4rwUlecBabE7u4igEuzKUCEZAMa11PUfQk=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=

View file

@ -363,7 +363,7 @@ func (srv *Server) SyncLatest(req *databroker.SyncLatestRequest, stream databrok
}
}
if recordStream.Err() != nil {
return recordStream.Err()
return err
}
// always send the server version last in case there are no records

View file

@ -1,292 +0,0 @@
// Package auth0 contains the Auth0 directory provider.
// Note that Auth0 refers to groups as roles.
package auth0
import (
"context"
"errors"
"fmt"
"sort"
"github.com/rs/zerolog"
"gopkg.in/auth0.v5/management"
"github.com/pomerium/pomerium/internal/encoding"
"github.com/pomerium/pomerium/internal/log"
"github.com/pomerium/pomerium/pkg/grpc/directory"
)
// Name is the provider name.
const Name = "auth0"
type (
// RoleManager defines what is needed to get role info from Auth0.
RoleManager interface {
List(opts ...management.RequestOption) (r *management.RoleList, err error)
Users(id string, opts ...management.RequestOption) (u *management.UserList, err error)
}
// UserManager defines what is needed to get user info from Auth0.
UserManager interface {
Read(id string, opts ...management.RequestOption) (*management.User, error)
Roles(id string, opts ...management.RequestOption) (r *management.RoleList, err error)
}
)
type newManagersFunc = func(ctx context.Context, domain string, serviceAccount *ServiceAccount) (RoleManager, UserManager, error)
func defaultNewManagersFunc(ctx context.Context, domain string, serviceAccount *ServiceAccount) (RoleManager, UserManager, error) {
// override the domain for the management api if supplied
if serviceAccount.Domain != "" {
domain = serviceAccount.Domain
}
m, err := management.New(domain,
management.WithClientCredentials(serviceAccount.ClientID, serviceAccount.Secret),
management.WithContext(ctx))
if err != nil {
return nil, nil, fmt.Errorf("auth0: could not build management: %w", err)
}
return m.Role, m.User, nil
}
type config struct {
domain string
serviceAccount *ServiceAccount
newManagers newManagersFunc
}
// Option provides config for the Auth0 Provider.
type Option func(cfg *config)
// WithServiceAccount sets the service account option.
func WithServiceAccount(serviceAccount *ServiceAccount) Option {
return func(cfg *config) {
cfg.serviceAccount = serviceAccount
}
}
// WithDomain sets the provider domain option.
func WithDomain(domain string) Option {
return func(cfg *config) {
cfg.domain = domain
}
}
func getConfig(options ...Option) *config {
cfg := &config{
newManagers: defaultNewManagersFunc,
}
for _, option := range options {
option(cfg)
}
return cfg
}
// Provider is an Auth0 user group directory provider.
type Provider struct {
cfg *config
log zerolog.Logger
}
// New creates a new Provider.
func New(options ...Option) *Provider {
return &Provider{
cfg: getConfig(options...),
log: log.With().Str("service", "directory").Str("provider", "auth0").Logger(),
}
}
func (p *Provider) getManagers(ctx context.Context) (RoleManager, UserManager, error) {
return p.cfg.newManagers(ctx, p.cfg.domain, p.cfg.serviceAccount)
}
// User returns the user record for the given id.
func (p *Provider) User(ctx context.Context, userID, accessToken string) (*directory.User, error) {
_, um, err := p.getManagers(ctx)
if err != nil {
return nil, fmt.Errorf("auth0: could not get the role manager: %w", err)
}
du := &directory.User{
Id: userID,
}
u, err := um.Read(userID)
if err != nil {
return nil, fmt.Errorf("auth0: error getting user info: %w", err)
}
du.DisplayName = u.GetName()
du.Email = u.GetEmail()
for page, hasNext := 0, true; hasNext; page++ {
rl, err := um.Roles(userID, management.IncludeTotals(true), management.Page(page))
if err != nil {
return nil, fmt.Errorf("auth0: error getting user roles: %w", err)
}
for _, role := range rl.Roles {
du.GroupIds = append(du.GroupIds, role.GetID())
}
hasNext = rl.HasNext()
}
sort.Strings(du.GroupIds)
return du, nil
}
// UserGroups fetches a slice of groups and users.
func (p *Provider) UserGroups(ctx context.Context) ([]*directory.Group, []*directory.User, error) {
rm, _, err := p.getManagers(ctx)
if err != nil {
return nil, nil, fmt.Errorf("auth0: could not get the role manager: %w", err)
}
roles, err := getRoles(rm)
if err != nil {
return nil, nil, fmt.Errorf("auth0: %w", err)
}
userIDToGroups := map[string][]string{}
for _, role := range roles {
ids, err := getRoleUserIDs(rm, role.Id)
if err != nil {
return nil, nil, fmt.Errorf("auth0: %w", err)
}
for _, id := range ids {
userIDToGroups[id] = append(userIDToGroups[id], role.Id)
}
}
var users []*directory.User
for userID, groups := range userIDToGroups {
sort.Strings(groups)
users = append(users, &directory.User{
Id: userID,
GroupIds: groups,
})
}
sort.Slice(users, func(i, j int) bool {
return users[i].Id < users[j].Id
})
return roles, users, nil
}
func getRoles(rm RoleManager) ([]*directory.Group, error) {
roles := []*directory.Group{}
shouldContinue := true
page := 0
for shouldContinue {
listRes, err := rm.List(management.IncludeTotals(true), management.Page(page))
if err != nil {
return nil, fmt.Errorf("could not list roles: %w", err)
}
for _, role := range listRes.Roles {
roles = append(roles, &directory.Group{
Id: *role.ID,
Name: *role.Name,
})
}
page++
shouldContinue = listRes.HasNext()
}
sort.Slice(roles, func(i, j int) bool {
return roles[i].GetId() < roles[j].GetId()
})
return roles, nil
}
func getRoleUserIDs(rm RoleManager, roleID string) ([]string, error) {
var ids []string
shouldContinue := true
page := 0
for shouldContinue {
usersRes, err := rm.Users(roleID, management.IncludeTotals(true), management.Page(page))
if err != nil {
return nil, fmt.Errorf("could not get users for role %q: %w", roleID, err)
}
for _, user := range usersRes.Users {
ids = append(ids, *user.ID)
}
page++
shouldContinue = usersRes.HasNext()
}
sort.Strings(ids)
return ids, nil
}
// A ServiceAccount is used by the Auth0 provider to query the API.
type ServiceAccount struct {
Domain string `json:"domain"`
ClientID string `json:"client_id"`
ClientSecret string `json:"client_secret"`
Secret string `json:"secret"`
}
// ParseServiceAccount parses the service account in the config options.
func ParseServiceAccount(options directory.Options) (*ServiceAccount, error) {
if options.ServiceAccount != "" {
return parseServiceAccountFromString(options.ServiceAccount)
}
return parseServiceAccountFromOptions(options.ClientID, options.ClientSecret)
}
func parseServiceAccountFromOptions(clientID, clientSecret string) (*ServiceAccount, error) {
serviceAccount := ServiceAccount{
ClientID: clientID,
Secret: clientSecret,
}
if serviceAccount.ClientID == "" {
return nil, fmt.Errorf("auth0: client_id is required")
}
// for backwards compatibility we support secret and client_secret
if serviceAccount.Secret == "" {
serviceAccount.Secret = serviceAccount.ClientSecret
}
if serviceAccount.ClientSecret == "" {
serviceAccount.ClientSecret = serviceAccount.Secret
}
if serviceAccount.Secret == "" {
return nil, fmt.Errorf("auth0: client_secret is required")
}
return &serviceAccount, nil
}
func parseServiceAccountFromString(rawServiceAccount string) (*ServiceAccount, error) {
var serviceAccount ServiceAccount
if err := encoding.DecodeBase64OrJSON(rawServiceAccount, &serviceAccount); err != nil {
return nil, fmt.Errorf("auth0: could not unmarshal json: %w", err)
}
if serviceAccount.ClientID == "" {
return nil, errors.New("auth0: client_id is required")
}
// for backwards compatibility we support secret and client_secret
if serviceAccount.Secret == "" {
serviceAccount.Secret = serviceAccount.ClientSecret
}
if serviceAccount.ClientSecret == "" {
serviceAccount.ClientSecret = serviceAccount.Secret
}
if serviceAccount.Secret == "" {
return nil, errors.New("auth0: secret is required")
}
return &serviceAccount, nil
}

View file

@ -1,563 +0,0 @@
package auth0
import (
"context"
"crypto/tls"
"encoding/json"
"errors"
"fmt"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/go-chi/chi/v5"
"github.com/go-chi/chi/v5/middleware"
"github.com/golang/mock/gomock"
"github.com/stretchr/testify/assert"
"gopkg.in/auth0.v5/management"
"github.com/pomerium/pomerium/internal/directory/auth0/mock_auth0"
"github.com/pomerium/pomerium/internal/testutil"
"github.com/pomerium/pomerium/pkg/grpc/directory"
)
type M = map[string]interface{}
func newMockAPI(t *testing.T, srv *httptest.Server) http.Handler {
r := chi.NewRouter()
r.Use(middleware.Logger)
r.Post("/oauth/token", func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
_ = json.NewEncoder(w).Encode(M{
"access_token": "ACCESSTOKEN",
"token_type": "Bearer",
"refresh_token": "REFRESHTOKEN",
})
})
r.Route("/api/v2", func(r chi.Router) {
r.Route("/users", func(r chi.Router) {
r.Route("/{user_id}", func(r chi.Router) {
r.Get("/roles", func(w http.ResponseWriter, r *http.Request) {
switch chi.URLParam(r, "user_id") {
case "user1":
w.Header().Set("Content-Type", "application/json")
_ = json.NewEncoder(w).Encode(M{
"total": 2,
"limit": 2,
"roles": []M{
{"id": "role1"},
{"id": "role2"},
},
})
default:
http.Error(w, "not found", http.StatusNotFound)
}
})
r.Get("/", func(w http.ResponseWriter, r *http.Request) {
switch chi.URLParam(r, "user_id") {
case "user1":
w.Header().Set("Content-Type", "application/json")
_ = json.NewEncoder(w).Encode(M{
"user_id": "user1",
"email": "user1@example.com",
"name": "User 1",
})
default:
http.Error(w, "not found", http.StatusNotFound)
}
})
})
})
})
return r
}
func TestProvider_User(t *testing.T) {
ctx, clearTimeout := context.WithTimeout(context.Background(), time.Minute)
defer clearTimeout()
orig := http.DefaultTransport
defer func() {
http.DefaultTransport = orig
}()
http.DefaultTransport = &http.Transport{
TLSClientConfig: &tls.Config{
InsecureSkipVerify: true,
},
}
var mockAPI http.Handler
srv := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
mockAPI.ServeHTTP(w, r)
}))
srv.StartTLS()
defer srv.Close()
mockAPI = newMockAPI(t, srv)
p := New(
WithDomain(srv.URL),
WithServiceAccount(&ServiceAccount{ClientID: "CLIENT_ID", Secret: "SECRET"}),
)
du, err := p.User(ctx, "user1", "")
if !assert.NoError(t, err) {
return
}
testutil.AssertProtoJSONEqual(t, `{
"id": "user1",
"displayName": "User 1",
"email": "user1@example.com",
"groupIds": ["role1", "role2"]
}`, du)
}
type mockNewRoleManagerFunc struct {
CalledWithContext context.Context
CalledWithDomain string
CalledWithServiceAccount *ServiceAccount
ReturnRoleManager RoleManager
ReturnUserManager UserManager
ReturnError error
}
func (m *mockNewRoleManagerFunc) f(ctx context.Context, domain string, serviceAccount *ServiceAccount) (RoleManager, UserManager, error) {
m.CalledWithContext = ctx
m.CalledWithDomain = domain
m.CalledWithServiceAccount = serviceAccount
return m.ReturnRoleManager, m.ReturnUserManager, m.ReturnError
}
type listOptionMatcher struct {
expected management.RequestOption
}
func buildValues(opt management.RequestOption) map[string][]string {
req, err := (&management.Management{}).NewRequest("GET", "example.com", nil, opt)
if err != nil {
panic(err)
}
return req.URL.Query()
}
func (lom listOptionMatcher) Matches(actual interface{}) bool {
return gomock.Eq(buildValues(lom.expected)).Matches(buildValues(actual.(management.RequestOption)))
}
func (lom listOptionMatcher) String() string {
return fmt.Sprintf("is equal to %v", buildValues(lom.expected))
}
func stringPtr(in string) *string {
return &in
}
func TestProvider_UserGroups(t *testing.T) {
expectedDomain := "login.example.com"
expectedServiceAccount := &ServiceAccount{Domain: "login-example.auth0.com", ClientID: "c_id", Secret: "secret"}
tests := []struct {
name string
setupRoleManagerExpectations func(*mock_auth0.MockRoleManager)
newRoleManagerError error
expectedGroups []*directory.Group
expectedUsers []*directory.User
expectedError error
}{
{
name: "errors if getting the role manager errors",
newRoleManagerError: errors.New("new role manager error"),
expectedError: errors.New("auth0: could not get the role manager: new role manager error"),
},
{
name: "errors if listing roles errors",
setupRoleManagerExpectations: func(mrm *mock_auth0.MockRoleManager) {
mrm.EXPECT().List(
listOptionMatcher{expected: management.IncludeTotals(true)},
listOptionMatcher{expected: management.Page(0)},
).Return(nil, errors.New("list error"))
},
expectedError: errors.New("auth0: could not list roles: list error"),
},
{
name: "errors if getting user ids errors",
setupRoleManagerExpectations: func(mrm *mock_auth0.MockRoleManager) {
mrm.EXPECT().List(
listOptionMatcher{expected: management.IncludeTotals(true)},
listOptionMatcher{expected: management.Page(0)},
).Return(&management.RoleList{
Roles: []*management.Role{
{
ID: stringPtr("i-am-role-id"),
Name: stringPtr("i-am-role-name"),
},
},
}, nil)
mrm.EXPECT().Users(
"i-am-role-id",
listOptionMatcher{expected: management.IncludeTotals(true)},
listOptionMatcher{expected: management.Page(0)},
).Return(nil, errors.New("users error"))
},
expectedError: errors.New("auth0: could not get users for role \"i-am-role-id\": users error"),
},
{
name: "handles multiple pages of roles",
setupRoleManagerExpectations: func(mrm *mock_auth0.MockRoleManager) {
mrm.EXPECT().List(
listOptionMatcher{expected: management.IncludeTotals(true)},
listOptionMatcher{expected: management.Page(0)},
).Return(&management.RoleList{
List: management.List{
Total: 3,
Start: 0,
Limit: 1,
},
Roles: []*management.Role{
{
ID: stringPtr("i-am-role-id-1"),
Name: stringPtr("i-am-role-name-1"),
},
},
}, nil)
mrm.EXPECT().List(
listOptionMatcher{expected: management.IncludeTotals(true)},
listOptionMatcher{expected: management.Page(1)},
).Return(&management.RoleList{
List: management.List{
Total: 3,
Start: 1,
Limit: 1,
},
Roles: []*management.Role{
{
ID: stringPtr("i-am-role-id-2"),
Name: stringPtr("i-am-role-name-2"),
},
},
}, nil)
mrm.EXPECT().List(
listOptionMatcher{expected: management.IncludeTotals(true)},
listOptionMatcher{expected: management.Page(2)},
).Return(&management.RoleList{
List: management.List{
Total: 3,
Start: 2,
Limit: 1,
},
Roles: []*management.Role{
{
ID: stringPtr("i-am-role-id-3"),
Name: stringPtr("i-am-role-name-3"),
},
},
}, nil)
mrm.EXPECT().Users(
"i-am-role-id-1",
listOptionMatcher{expected: management.IncludeTotals(true)},
listOptionMatcher{expected: management.Page(0)},
).Return(&management.UserList{}, nil)
mrm.EXPECT().Users(
"i-am-role-id-2",
listOptionMatcher{expected: management.IncludeTotals(true)},
listOptionMatcher{expected: management.Page(0)},
).Return(&management.UserList{}, nil)
mrm.EXPECT().Users(
"i-am-role-id-3",
listOptionMatcher{expected: management.IncludeTotals(true)},
listOptionMatcher{expected: management.Page(0)},
).Return(&management.UserList{}, nil)
},
expectedGroups: []*directory.Group{
{
Id: "i-am-role-id-1",
Name: "i-am-role-name-1",
},
{
Id: "i-am-role-id-2",
Name: "i-am-role-name-2",
},
{
Id: "i-am-role-id-3",
Name: "i-am-role-name-3",
},
},
},
{
name: "handles multiple pages of users",
setupRoleManagerExpectations: func(mrm *mock_auth0.MockRoleManager) {
mrm.EXPECT().List(
listOptionMatcher{expected: management.IncludeTotals(true)},
listOptionMatcher{expected: management.Page(0)},
).Return(&management.RoleList{
Roles: []*management.Role{
{
ID: stringPtr("i-am-role-id-1"),
Name: stringPtr("i-am-role-name-1"),
},
},
}, nil)
mrm.EXPECT().Users(
"i-am-role-id-1",
listOptionMatcher{expected: management.IncludeTotals(true)},
listOptionMatcher{expected: management.Page(0)},
).Return(&management.UserList{
List: management.List{
Total: 3,
Start: 0,
Limit: 1,
},
Users: []*management.User{
{ID: stringPtr("i-am-user-id-1")},
},
}, nil)
mrm.EXPECT().Users(
"i-am-role-id-1",
listOptionMatcher{expected: management.IncludeTotals(true)},
listOptionMatcher{expected: management.Page(1)},
).Return(&management.UserList{
List: management.List{
Total: 3,
Start: 1,
Limit: 1,
},
Users: []*management.User{
{ID: stringPtr("i-am-user-id-2")},
},
}, nil)
mrm.EXPECT().Users(
"i-am-role-id-1",
listOptionMatcher{expected: management.IncludeTotals(true)},
listOptionMatcher{expected: management.Page(2)},
).Return(&management.UserList{
List: management.List{
Total: 3,
Start: 2,
Limit: 1,
},
Users: []*management.User{
{ID: stringPtr("i-am-user-id-3")},
},
}, nil)
},
expectedGroups: []*directory.Group{
{
Id: "i-am-role-id-1",
Name: "i-am-role-name-1",
},
},
expectedUsers: []*directory.User{
{
Id: "i-am-user-id-1",
GroupIds: []string{"i-am-role-id-1"},
},
{
Id: "i-am-user-id-2",
GroupIds: []string{"i-am-role-id-1"},
},
{
Id: "i-am-user-id-3",
GroupIds: []string{"i-am-role-id-1"},
},
},
},
{
name: "correctly builds out groups and users",
setupRoleManagerExpectations: func(mrm *mock_auth0.MockRoleManager) {
mrm.EXPECT().List(
listOptionMatcher{expected: management.IncludeTotals(true)},
listOptionMatcher{expected: management.Page(0)},
).Return(&management.RoleList{
List: management.List{
Total: 2,
Start: 0,
Limit: 1,
},
Roles: []*management.Role{
{
ID: stringPtr("i-am-role-id-1"),
Name: stringPtr("i-am-role-name-1"),
},
},
}, nil)
mrm.EXPECT().List(
listOptionMatcher{expected: management.IncludeTotals(true)},
listOptionMatcher{expected: management.Page(1)},
).Return(&management.RoleList{
List: management.List{
Total: 2,
Start: 1,
Limit: 1,
},
Roles: []*management.Role{
{
ID: stringPtr("i-am-role-id-2"),
Name: stringPtr("i-am-role-name-2"),
},
},
}, nil)
mrm.EXPECT().Users(
"i-am-role-id-1",
listOptionMatcher{expected: management.IncludeTotals(true)},
listOptionMatcher{expected: management.Page(0)},
).Return(&management.UserList{
Users: []*management.User{
{ID: stringPtr("i-am-user-id-4")},
{ID: stringPtr("i-am-user-id-3")},
{ID: stringPtr("i-am-user-id-2")},
{ID: stringPtr("i-am-user-id-1")},
},
}, nil)
mrm.EXPECT().Users(
"i-am-role-id-2",
listOptionMatcher{expected: management.IncludeTotals(true)},
listOptionMatcher{expected: management.Page(0)},
).Return(&management.UserList{
Users: []*management.User{
{ID: stringPtr("i-am-user-id-1")},
{ID: stringPtr("i-am-user-id-4")},
{ID: stringPtr("i-am-user-id-5")},
},
}, nil)
},
expectedGroups: []*directory.Group{
{
Id: "i-am-role-id-1",
Name: "i-am-role-name-1",
},
{
Id: "i-am-role-id-2",
Name: "i-am-role-name-2",
},
},
expectedUsers: []*directory.User{
{
Id: "i-am-user-id-1",
GroupIds: []string{"i-am-role-id-1", "i-am-role-id-2"},
},
{
Id: "i-am-user-id-2",
GroupIds: []string{"i-am-role-id-1"},
},
{
Id: "i-am-user-id-3",
GroupIds: []string{"i-am-role-id-1"},
},
{
Id: "i-am-user-id-4",
GroupIds: []string{"i-am-role-id-1", "i-am-role-id-2"},
},
{
Id: "i-am-user-id-5",
GroupIds: []string{"i-am-role-id-2"},
},
},
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
mRoleManager := mock_auth0.NewMockRoleManager(ctrl)
mNewManagersFunc := mockNewRoleManagerFunc{
ReturnRoleManager: mRoleManager,
ReturnError: tc.newRoleManagerError,
}
if tc.setupRoleManagerExpectations != nil {
tc.setupRoleManagerExpectations(mRoleManager)
}
p := New(
WithDomain(expectedDomain),
WithServiceAccount(expectedServiceAccount),
)
p.cfg.newManagers = mNewManagersFunc.f
actualGroups, actualUsers, err := p.UserGroups(context.Background())
if tc.expectedError != nil {
assert.EqualError(t, err, tc.expectedError.Error())
} else {
assert.NoError(t, err)
}
assert.Equal(t, tc.expectedGroups, actualGroups)
assert.Equal(t, tc.expectedUsers, actualUsers)
assert.Equal(t, expectedDomain, mNewManagersFunc.CalledWithDomain)
assert.Equal(t, expectedServiceAccount, mNewManagersFunc.CalledWithServiceAccount)
})
}
}
func TestParseServiceAccount(t *testing.T) {
tests := []struct {
name string
rawServiceAccount string
expectedServiceAccount *ServiceAccount
expectedError error
}{
{
"valid", "eyJjbGllbnRfaWQiOiJpLWFtLWNsaWVudC1pZCIsInNlY3JldCI6ImktYW0tc2VjcmV0In0K",
&ServiceAccount{
ClientID: "i-am-client-id",
ClientSecret: "i-am-secret",
Secret: "i-am-secret",
},
nil,
},
{
"json", `{"client_id": "i-am-client-id", "client_secret": "i-am-secret"}`,
&ServiceAccount{
ClientID: "i-am-client-id",
ClientSecret: "i-am-secret",
Secret: "i-am-secret",
},
nil,
},
{
"domain", `{"domain": "example.auth0.com", "client_id": "i-am-client-id", "client_secret": "i-am-secret"}`,
&ServiceAccount{
ClientID: "i-am-client-id",
ClientSecret: "i-am-secret",
Domain: "example.auth0.com",
Secret: "i-am-secret",
},
nil,
},
{"base64 err", "!!!!", nil, errors.New("auth0: could not unmarshal json: illegal base64 data at input byte 0")},
{"json err", "PAo=", nil, errors.New("auth0: could not unmarshal json: invalid character '<' looking for beginning of value")},
{"no client_id", "eyJzZWNyZXQiOiJpLWFtLXNlY3JldCJ9Cg==", nil, errors.New("auth0: client_id is required")},
{"no secret", "eyJjbGllbnRfaWQiOiJpLWFtLWNsaWVudC1pZCJ9Cg==", nil, errors.New("auth0: secret is required")},
}
for _, tc := range tests {
tc := tc
t.Run(tc.name, func(t *testing.T) {
actualServiceAccount, err := ParseServiceAccount(directory.Options{ServiceAccount: tc.rawServiceAccount})
if tc.expectedError != nil {
assert.EqualError(t, err, tc.expectedError.Error())
} else {
assert.NoError(t, err)
}
assert.Equal(t, tc.expectedServiceAccount, actualServiceAccount)
})
}
}

View file

@ -1,73 +0,0 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: github.com/pomerium/pomerium/internal/directory/auth0 (interfaces: RoleManager)
// Package mock_auth0 is a generated GoMock package.
package mock_auth0
import (
gomock "github.com/golang/mock/gomock"
management "gopkg.in/auth0.v5/management"
reflect "reflect"
)
// MockRoleManager is a mock of RoleManager interface
type MockRoleManager struct {
ctrl *gomock.Controller
recorder *MockRoleManagerMockRecorder
}
// MockRoleManagerMockRecorder is the mock recorder for MockRoleManager
type MockRoleManagerMockRecorder struct {
mock *MockRoleManager
}
// NewMockRoleManager creates a new mock instance
func NewMockRoleManager(ctrl *gomock.Controller) *MockRoleManager {
mock := &MockRoleManager{ctrl: ctrl}
mock.recorder = &MockRoleManagerMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use
func (m *MockRoleManager) EXPECT() *MockRoleManagerMockRecorder {
return m.recorder
}
// List mocks base method
func (m *MockRoleManager) List(arg0 ...management.RequestOption) (*management.RoleList, error) {
m.ctrl.T.Helper()
varargs := []interface{}{}
for _, a := range arg0 {
varargs = append(varargs, a)
}
ret := m.ctrl.Call(m, "List", varargs...)
ret0, _ := ret[0].(*management.RoleList)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// List indicates an expected call of List
func (mr *MockRoleManagerMockRecorder) List(arg0 ...interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "List", reflect.TypeOf((*MockRoleManager)(nil).List), arg0...)
}
// Users mocks base method
func (m *MockRoleManager) Users(arg0 string, arg1 ...management.RequestOption) (*management.UserList, error) {
m.ctrl.T.Helper()
varargs := []interface{}{arg0}
for _, a := range arg1 {
varargs = append(varargs, a)
}
ret := m.ctrl.Call(m, "Users", varargs...)
ret0, _ := ret[0].(*management.UserList)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// Users indicates an expected call of Users
func (mr *MockRoleManagerMockRecorder) Users(arg0 interface{}, arg1 ...interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
varargs := append([]interface{}{arg0}, arg1...)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Users", reflect.TypeOf((*MockRoleManager)(nil).Users), varargs...)
}

View file

@ -1,47 +0,0 @@
package azure
import "strings"
type (
apiGetUserResponse struct {
apiUser
}
apiGetUserMembersResponse struct {
Context string `json:"@odata.context"`
NextLink string `json:"@odata.nextLink,omitempty"`
Value []apiGroup `json:"value"`
}
apiGroup struct {
ID string `json:"id"`
DisplayName string `json:"displayName"`
}
apiUser struct {
ID string `json:"id"`
DisplayName string `json:"displayName"`
Mail string `json:"mail"`
UserPrincipalName string `json:"userPrincipalName"`
}
)
func (obj apiUser) getEmail() string {
if obj.Mail != "" {
return obj.Mail
}
// AD often doesn't have the email address returned, but we can parse it from the UPN
// UPN looks like either:
// cdoxsey_pomerium.com#EXT#@cdoxseypomerium.onmicrosoft.com
// cdoxsey@pomerium.com
email := obj.UserPrincipalName
if idx := strings.Index(email, "#EXT"); idx > 0 {
email = email[:idx]
// find the last _ and replace it with @
if idx := strings.LastIndex(email, "_"); idx > 0 {
email = email[:idx] + "@" + email[idx+1:]
}
}
return email
}

View file

@ -1,334 +0,0 @@
// Package azure contains an azure active directory directory provider.
package azure
import (
"context"
"encoding/json"
"fmt"
"net/http"
"net/url"
"sort"
"strings"
"sync"
"github.com/rs/zerolog"
"golang.org/x/oauth2"
"github.com/pomerium/pomerium/internal/encoding"
"github.com/pomerium/pomerium/internal/httputil"
"github.com/pomerium/pomerium/pkg/grpc/directory"
)
// Name is the provider name.
const Name = "azure"
const (
defaultGraphHost = "graph.microsoft.com"
defaultLoginHost = "login.microsoftonline.com"
defaultLoginScope = "https://graph.microsoft.com/.default"
defaultLoginGrantType = "client_credentials"
)
type config struct {
graphURL *url.URL
httpClient *http.Client
loginURL *url.URL
serviceAccount *ServiceAccount
}
// An Option updates the provider configuration.
type Option func(*config)
// WithGraphURL sets the graph URL for the configuration.
func WithGraphURL(graphURL *url.URL) Option {
return func(cfg *config) {
cfg.graphURL = graphURL
}
}
// WithLoginURL sets the login URL for the configuration.
func WithLoginURL(loginURL *url.URL) Option {
return func(cfg *config) {
cfg.loginURL = loginURL
}
}
// WithHTTPClient sets the http client to use for requests to the Azure APIs.
func WithHTTPClient(httpClient *http.Client) Option {
return func(cfg *config) {
cfg.httpClient = httputil.NewLoggingClient(httpClient, "azure_idp_client",
func(evt *zerolog.Event) *zerolog.Event {
return evt.Str("provider", "azure")
})
}
}
// WithServiceAccount sets the service account to use to access Azure.
func WithServiceAccount(serviceAccount *ServiceAccount) Option {
return func(cfg *config) {
cfg.serviceAccount = serviceAccount
}
}
func getConfig(options ...Option) *config {
cfg := new(config)
WithGraphURL(&url.URL{
Scheme: "https",
Host: defaultGraphHost,
})(cfg)
WithHTTPClient(http.DefaultClient)(cfg)
WithLoginURL(&url.URL{
Scheme: "https",
Host: defaultLoginHost,
})(cfg)
for _, option := range options {
option(cfg)
}
return cfg
}
// A Provider is a directory implementation using azure active directory.
type Provider struct {
cfg *config
dc *deltaCollection
mu sync.RWMutex
token *oauth2.Token
}
// New creates a new Provider.
func New(options ...Option) *Provider {
p := &Provider{
cfg: getConfig(options...),
}
p.dc = newDeltaCollection(p)
return p
}
// User returns the user record for the given id.
func (p *Provider) User(ctx context.Context, userID, accessToken string) (*directory.User, error) {
if p.cfg.serviceAccount == nil {
return nil, fmt.Errorf("azure: service account not defined")
}
du := &directory.User{
Id: userID,
}
userURL := p.cfg.graphURL.ResolveReference(&url.URL{
Path: fmt.Sprintf("/v1.0/users/%s", userID),
}).String()
var u apiGetUserResponse
err := p.api(ctx, userURL, &u)
if err != nil {
return nil, err
}
du.DisplayName = u.DisplayName
du.Email = u.getEmail()
du.GroupIds, err = p.transitiveMemberOf(ctx, userID)
if err != nil {
return nil, err
}
return du, nil
}
// UserGroups returns the directory users in azure active directory.
func (p *Provider) UserGroups(ctx context.Context) ([]*directory.Group, []*directory.User, error) {
if p.cfg.serviceAccount == nil {
return nil, nil, fmt.Errorf("azure: service account not defined")
}
err := p.dc.Sync(ctx)
if err != nil {
return nil, nil, err
}
groups, users := p.dc.CurrentUserGroups()
return groups, users, nil
}
func (p *Provider) api(ctx context.Context, url string, out interface{}) error {
token, err := p.getToken(ctx)
if err != nil {
return err
}
req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
if err != nil {
return fmt.Errorf("azure: error creating HTTP request: %w", err)
}
req.Header.Set("Accept", "application/json")
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token.AccessToken))
res, err := p.cfg.httpClient.Do(req)
if err != nil {
return fmt.Errorf("azure: error making HTTP request: %w", err)
}
defer res.Body.Close()
// if we get unauthorized, invalidate the token
if res.StatusCode == http.StatusUnauthorized {
p.mu.Lock()
p.token = nil
p.mu.Unlock()
}
if res.StatusCode/100 != 2 {
return fmt.Errorf("azure: error querying api (%s): %s", url, res.Status)
}
err = json.NewDecoder(res.Body).Decode(out)
if err != nil {
return fmt.Errorf("azure: error decoding api response: %w", err)
}
return nil
}
func (p *Provider) getToken(ctx context.Context) (*oauth2.Token, error) {
p.mu.RLock()
token := p.token
p.mu.RUnlock()
if token != nil && token.Valid() {
return token, nil
}
p.mu.Lock()
defer p.mu.Unlock()
token = p.token
if token != nil && token.Valid() {
return token, nil
}
tokenURL := p.cfg.loginURL.ResolveReference(&url.URL{
Path: fmt.Sprintf("/%s/oauth2/v2.0/token", p.cfg.serviceAccount.DirectoryID),
})
req, err := http.NewRequestWithContext(ctx, "POST", tokenURL.String(), strings.NewReader(url.Values{
"client_id": {p.cfg.serviceAccount.ClientID},
"client_secret": {p.cfg.serviceAccount.ClientSecret},
"scope": {defaultLoginScope},
"grant_type": {defaultLoginGrantType},
}.Encode()))
if err != nil {
return nil, fmt.Errorf("azure: error creating HTTP request: %w", err)
}
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
res, err := p.cfg.httpClient.Do(req)
if err != nil {
return nil, err
}
defer res.Body.Close()
if res.StatusCode/100 != 2 {
return nil, fmt.Errorf("azure: error querying oauth2 token: %s", res.Status)
}
err = json.NewDecoder(res.Body).Decode(&token)
if err != nil {
return nil, fmt.Errorf("azure: error decoding oauth2 token: %w", err)
}
p.token = token
return p.token, nil
}
func (p *Provider) transitiveMemberOf(ctx context.Context, userID string) (groupIDs []string, err error) {
apiURL := p.cfg.graphURL.ResolveReference(&url.URL{
Path: fmt.Sprintf("/v1.0/users/%s/transitiveMemberOf", userID),
}).String()
for {
var res apiGetUserMembersResponse
err := p.api(ctx, apiURL, &res)
if err != nil {
return nil, err
}
for _, g := range res.Value {
groupIDs = append(groupIDs, g.ID)
}
if res.NextLink == "" {
break
}
apiURL = res.NextLink
}
sort.Strings(groupIDs)
return groupIDs, nil
}
// A ServiceAccount is used by the Azure provider to query the Microsoft Graph API.
type ServiceAccount struct {
ClientID string `json:"client_id"`
ClientSecret string `json:"client_secret"`
DirectoryID string `json:"directory_id"`
}
// ParseServiceAccount parses the service account in the config options.
func ParseServiceAccount(options directory.Options) (*ServiceAccount, error) {
if options.ServiceAccount != "" {
return parseServiceAccountFromString(options.ServiceAccount)
}
return parseServiceAccountFromOptions(options.ClientID, options.ClientSecret, options.ProviderURL)
}
func parseServiceAccountFromOptions(clientID, clientSecret, providerURL string) (*ServiceAccount, error) {
serviceAccount := ServiceAccount{
ClientID: clientID,
ClientSecret: clientSecret,
}
var err error
serviceAccount.DirectoryID, err = parseDirectoryIDFromURL(providerURL)
if err != nil {
return nil, err
}
if serviceAccount.ClientID == "" {
return nil, fmt.Errorf("client_id is required")
}
if serviceAccount.ClientSecret == "" {
return nil, fmt.Errorf("client_secret is required")
}
if serviceAccount.DirectoryID == "" {
return nil, fmt.Errorf("directory_id is required")
}
return &serviceAccount, nil
}
func parseServiceAccountFromString(rawServiceAccount string) (*ServiceAccount, error) {
var serviceAccount ServiceAccount
if err := encoding.DecodeBase64OrJSON(rawServiceAccount, &serviceAccount); err != nil {
return nil, err
}
if serviceAccount.ClientID == "" {
return nil, fmt.Errorf("client_id is required")
}
if serviceAccount.ClientSecret == "" {
return nil, fmt.Errorf("client_secret is required")
}
if serviceAccount.DirectoryID == "" {
return nil, fmt.Errorf("directory_id is required")
}
return &serviceAccount, nil
}
func parseDirectoryIDFromURL(providerURL string) (string, error) {
u, err := url.Parse(providerURL)
if err != nil {
return "", err
}
pathParts := strings.SplitN(u.Path, "/", 3)
if len(pathParts) != 3 {
return "", fmt.Errorf("no directory id found in path")
}
return pathParts[1], nil
}

View file

@ -1,282 +0,0 @@
package azure
import (
"context"
"encoding/base64"
"encoding/json"
"net/http"
"net/http/httptest"
"net/url"
"strconv"
"testing"
"github.com/go-chi/chi/v5"
"github.com/go-chi/chi/v5/middleware"
"github.com/stretchr/testify/assert"
"github.com/pomerium/pomerium/internal/testutil"
"github.com/pomerium/pomerium/pkg/grpc/directory"
)
type M = map[string]interface{}
func newMockAPI(t *testing.T, srv *httptest.Server) http.Handler {
r := chi.NewRouter()
r.Use(middleware.Logger)
r.Post("/DIRECTORY_ID/oauth2/v2.0/token", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "CLIENT_ID", r.FormValue("client_id"))
assert.Equal(t, "CLIENT_SECRET", r.FormValue("client_secret"))
assert.Equal(t, defaultLoginScope, r.FormValue("scope"))
assert.Equal(t, defaultLoginGrantType, r.FormValue("grant_type"))
_ = json.NewEncoder(w).Encode(M{
"access_token": "ACCESSTOKEN",
"token_type": "Bearer",
"refresh_token": "REFRESHTOKEN",
})
})
r.Route("/v1.0", func(r chi.Router) {
r.Use(func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Header.Get("Authorization") != "Bearer ACCESSTOKEN" {
http.Error(w, "forbidden", http.StatusForbidden)
return
}
next.ServeHTTP(w, r)
})
})
r.Get("/groups/delta", func(w http.ResponseWriter, r *http.Request) {
_ = json.NewEncoder(w).Encode(M{
"value": []M{
{
"id": "admin",
"displayName": "Admin Group",
"members@delta": []M{
{"@odata.type": "#microsoft.graph.user", "id": "user-1"},
},
},
{
"id": "test",
"displayName": "Test Group",
"members@delta": []M{
{"@odata.type": "#microsoft.graph.user", "id": "user-2"},
{"@odata.type": "#microsoft.graph.user", "id": "user-3"},
{"@odata.type": "#microsoft.graph.user", "id": "user-4"},
},
},
},
})
})
r.Get("/users/delta", func(w http.ResponseWriter, r *http.Request) {
_ = json.NewEncoder(w).Encode(M{
"value": []M{
{"id": "user-1", "displayName": "User 1", "mail": "user1@example.com"},
{"id": "user-2", "displayName": "User 2", "mail": "user2@example.com"},
{"id": "user-3", "displayName": "User 3", "userPrincipalName": "user3_example.com#EXT#@user3example.onmicrosoft.com"},
{"id": "user-4", "displayName": "User 4", "userPrincipalName": "user4@example.com"},
},
})
})
r.Get("/users/{user_id}", func(w http.ResponseWriter, r *http.Request) {
switch chi.URLParam(r, "user_id") {
case "user-1":
_ = json.NewEncoder(w).Encode(M{"id": "user-1", "displayName": "User 1", "mail": "user1@example.com"})
default:
http.Error(w, "not found", http.StatusNotFound)
}
})
r.Get("/users/{user_id}/transitiveMemberOf", func(w http.ResponseWriter, r *http.Request) {
switch chi.URLParam(r, "user_id") {
case "user-1":
switch r.URL.Query().Get("page") {
case "":
_ = json.NewEncoder(w).Encode(M{
"value": []M{
{"id": "admin"},
},
"@odata.nextLink": getPageURL(r, 1),
})
case "1":
_ = json.NewEncoder(w).Encode(M{
"value": []M{
{"id": "group1"},
},
"@odata.nextLink": getPageURL(r, 2),
})
case "2":
_ = json.NewEncoder(w).Encode(M{
"value": []M{
{"id": "group2"},
},
})
}
default:
http.Error(w, "not found", http.StatusNotFound)
}
})
})
return r
}
func TestProvider_User(t *testing.T) {
var mockAPI http.Handler
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
mockAPI.ServeHTTP(w, r)
}))
defer srv.Close()
mockAPI = newMockAPI(t, srv)
p := New(
WithGraphURL(mustParseURL(srv.URL)),
WithLoginURL(mustParseURL(srv.URL)),
WithServiceAccount(&ServiceAccount{
ClientID: "CLIENT_ID",
ClientSecret: "CLIENT_SECRET",
DirectoryID: "DIRECTORY_ID",
}),
)
du, err := p.User(context.Background(), "user-1", "")
if !assert.NoError(t, err) {
return
}
testutil.AssertProtoJSONEqual(t, `{
"id": "user-1",
"displayName": "User 1",
"email": "user1@example.com",
"groupIds": ["admin", "group1", "group2"]
}`, du)
}
func TestProvider_UserGroups(t *testing.T) {
var mockAPI http.Handler
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
mockAPI.ServeHTTP(w, r)
}))
defer srv.Close()
mockAPI = newMockAPI(t, srv)
p := New(
WithGraphURL(mustParseURL(srv.URL)),
WithLoginURL(mustParseURL(srv.URL)),
WithServiceAccount(&ServiceAccount{
ClientID: "CLIENT_ID",
ClientSecret: "CLIENT_SECRET",
DirectoryID: "DIRECTORY_ID",
}),
)
groups, users, err := p.UserGroups(context.Background())
assert.NoError(t, err)
assert.Equal(t, []*directory.User{
{
Id: "user-1",
GroupIds: []string{"admin"},
DisplayName: "User 1",
Email: "user1@example.com",
},
{
Id: "user-2",
GroupIds: []string{"test"},
DisplayName: "User 2",
Email: "user2@example.com",
},
{
Id: "user-3",
GroupIds: []string{"test"},
DisplayName: "User 3",
Email: "user3@example.com",
},
{
Id: "user-4",
GroupIds: []string{"test"},
DisplayName: "User 4",
Email: "user4@example.com",
},
}, users)
testutil.AssertProtoJSONEqual(t, `[
{ "id": "admin", "name": "Admin Group" },
{ "id": "test", "name": "Test Group"}
]`, groups)
}
func TestParseServiceAccount(t *testing.T) {
t.Run("by options", func(t *testing.T) {
serviceAccount, err := ParseServiceAccount(directory.Options{
ProviderURL: "https://login.microsoftonline.com/0303f438-3c5c-4190-9854-08d3eb31bd9f/v2.0",
ClientID: "CLIENT_ID",
ClientSecret: "CLIENT_SECRET",
})
if !assert.NoError(t, err) {
return
}
assert.Equal(t, &ServiceAccount{
ClientID: "CLIENT_ID",
ClientSecret: "CLIENT_SECRET",
DirectoryID: "0303f438-3c5c-4190-9854-08d3eb31bd9f",
}, serviceAccount)
})
t.Run("by service account base64", func(t *testing.T) {
serviceAccount, err := ParseServiceAccount(directory.Options{
ServiceAccount: base64.StdEncoding.EncodeToString([]byte(`{
"client_id": "CLIENT_ID",
"client_secret": "CLIENT_SECRET",
"directory_id": "0303f438-3c5c-4190-9854-08d3eb31bd9f"
}`)),
})
if !assert.NoError(t, err) {
return
}
assert.Equal(t, &ServiceAccount{
ClientID: "CLIENT_ID",
ClientSecret: "CLIENT_SECRET",
DirectoryID: "0303f438-3c5c-4190-9854-08d3eb31bd9f",
}, serviceAccount)
})
t.Run("by service account json", func(t *testing.T) {
serviceAccount, err := ParseServiceAccount(directory.Options{
ServiceAccount: `{
"client_id": "CLIENT_ID",
"client_secret": "CLIENT_SECRET",
"directory_id": "0303f438-3c5c-4190-9854-08d3eb31bd9f"
}`,
})
if !assert.NoError(t, err) {
return
}
assert.Equal(t, &ServiceAccount{
ClientID: "CLIENT_ID",
ClientSecret: "CLIENT_SECRET",
DirectoryID: "0303f438-3c5c-4190-9854-08d3eb31bd9f",
}, serviceAccount)
})
}
func mustParseURL(rawurl string) *url.URL {
u, err := url.Parse(rawurl)
if err != nil {
panic(err)
}
return u
}
func getPageURL(r *http.Request, page int) string {
var u url.URL
u = *r.URL
if r.TLS == nil {
u.Scheme = "http"
} else {
u.Scheme = "https"
}
if u.Host == "" {
u.Host = r.Host
}
q := u.Query()
q.Set("page", strconv.Itoa(page))
u.RawQuery = q.Encode()
return u.String()
}

View file

@ -1,248 +0,0 @@
package azure
import (
"context"
"net/url"
"sort"
"github.com/pomerium/pomerium/pkg/grpc/directory"
)
const (
groupsDeltaPath = "/v1.0/groups/delta"
usersDeltaPath = "/v1.0/users/delta"
)
type (
deltaCollection struct {
provider *Provider
groups map[string]deltaGroup
groupDeltaLink string
users map[string]deltaUser
userDeltaLink string
}
deltaGroup struct {
id string
displayName string
members map[string]deltaGroupMember
}
deltaGroupMember struct {
memberType string
id string
}
deltaUser struct {
id string
displayName string
email string
}
)
func newDeltaCollection(p *Provider) *deltaCollection {
return &deltaCollection{
provider: p,
groups: make(map[string]deltaGroup),
users: make(map[string]deltaUser),
}
}
// Sync syncs the latest changes from the microsoft graph API.
//
// Synchronization is based on https://docs.microsoft.com/en-us/graph/delta-query-groups
//
// It involves 4 steps:
//
// 1. an initial request to /v1.0/groups/delta
// 2. one or more requests to /v1.0/groups/delta?$skiptoken=..., which comes from the @odata.nextLink
// 3. a final response with @odata.deltaLink
// 4. on the next call to sync, starting at @odata.deltaLink
//
// Only the changed groups/members are returned. Removed groups/members have an @removed property.
func (dc *deltaCollection) Sync(ctx context.Context) error {
err := dc.syncGroups(ctx)
if err != nil {
return err
}
return dc.syncUsers(ctx)
}
func (dc *deltaCollection) syncGroups(ctx context.Context) error {
apiURL := dc.groupDeltaLink
// if no delta link is set yet, start the initial fill
if apiURL == "" {
apiURL = dc.provider.cfg.graphURL.ResolveReference(&url.URL{
Path: groupsDeltaPath,
RawQuery: url.Values{
"$select": {"displayName,members"},
}.Encode(),
}).String()
}
for {
var res groupsDeltaResponse
err := dc.provider.api(ctx, apiURL, &res)
if err != nil {
return err
}
for _, g := range res.Value {
// if removed exists, the group was deleted
if g.Removed != nil {
delete(dc.groups, g.ID)
continue
}
gdg := dc.groups[g.ID]
gdg.id = g.ID
gdg.displayName = g.DisplayName
if gdg.members == nil {
gdg.members = make(map[string]deltaGroupMember)
}
for _, m := range g.Members {
// if removed exists, the member was deleted
if m.Removed != nil {
delete(gdg.members, m.ID)
continue
}
gdg.members[m.ID] = deltaGroupMember{
memberType: m.Type,
id: m.ID,
}
}
dc.groups[g.ID] = gdg
}
switch {
case res.NextLink != "":
// when there's a next link we will query again
apiURL = res.NextLink
default:
// once no next link is set anymore, we save the delta link and return
dc.groupDeltaLink = res.DeltaLink
return nil
}
}
}
func (dc *deltaCollection) syncUsers(ctx context.Context) error {
apiURL := dc.userDeltaLink
// if no delta link is set yet, start the initial fill
if apiURL == "" {
apiURL = dc.provider.cfg.graphURL.ResolveReference(&url.URL{
Path: usersDeltaPath,
RawQuery: url.Values{
"$select": {"displayName,mail,userPrincipalName"},
}.Encode(),
}).String()
}
for {
var res usersDeltaResponse
err := dc.provider.api(ctx, apiURL, &res)
if err != nil {
return err
}
for _, u := range res.Value {
// if removed exists, the user was deleted
if u.Removed != nil {
delete(dc.users, u.ID)
continue
}
dc.users[u.ID] = deltaUser{
id: u.ID,
displayName: u.DisplayName,
email: u.getEmail(),
}
}
switch {
case res.NextLink != "":
// when there's a next link we will query again
apiURL = res.NextLink
default:
// once no next link is set anymore, we save the delta link and return
dc.userDeltaLink = res.DeltaLink
return nil
}
}
}
// CurrentUserGroups returns the directory groups and users based on the current state.
func (dc *deltaCollection) CurrentUserGroups() ([]*directory.Group, []*directory.User) {
var groups []*directory.Group
groupLookup := newGroupLookup()
for _, g := range dc.groups {
groups = append(groups, &directory.Group{
Id: g.id,
Name: g.displayName,
})
var groupIDs, userIDs []string
for _, m := range g.members {
switch m.memberType {
case "#microsoft.graph.group":
groupIDs = append(groupIDs, m.id)
case "#microsoft.graph.user":
userIDs = append(userIDs, m.id)
}
}
groupLookup.addGroup(g.id, groupIDs, userIDs)
}
sort.Slice(groups, func(i, j int) bool {
return groups[i].GetId() < groups[j].GetId()
})
var users []*directory.User
for _, u := range dc.users {
users = append(users, &directory.User{
Id: u.id,
GroupIds: groupLookup.getGroupIDsForUser(u.id),
DisplayName: u.displayName,
Email: u.email,
})
}
sort.Slice(users, func(i, j int) bool {
return users[i].GetId() < users[j].GetId()
})
return groups, users
}
// API types for the microsoft graph API.
type (
deltaResponseRemoved struct {
Reason string `json:"reason"`
}
groupsDeltaResponse struct {
Context string `json:"@odata.context"`
NextLink string `json:"@odata.nextLink,omitempty"`
DeltaLink string `json:"@odata.deltaLink,omitempty"`
Value []groupsDeltaResponseGroup `json:"value"`
}
groupsDeltaResponseGroup struct {
apiGroup
Members []groupsDeltaResponseGroupMember `json:"members@delta"`
Removed *deltaResponseRemoved `json:"@removed,omitempty"`
}
groupsDeltaResponseGroupMember struct {
Type string `json:"@odata.type"`
ID string `json:"id"`
Removed *deltaResponseRemoved `json:"@removed,omitempty"`
}
usersDeltaResponse struct {
Context string `json:"@odata.context"`
NextLink string `json:"@odata.nextLink,omitempty"`
DeltaLink string `json:"@odata.deltaLink,omitempty"`
Value []usersDeltaResponseUser `json:"value"`
}
usersDeltaResponseUser struct {
apiUser
Removed *deltaResponseRemoved `json:"@removed,omitempty"`
}
)

View file

@ -1,107 +0,0 @@
package azure
import "sort"
type stringSet map[string]struct{}
func newStringSet() stringSet {
return make(stringSet)
}
func (ss stringSet) add(value string) {
ss[value] = struct{}{}
}
func (ss stringSet) has(value string) bool {
if ss == nil {
return false
}
_, ok := ss[value]
return ok
}
func (ss stringSet) sorted() []string {
if ss == nil {
return nil
}
s := make([]string, 0, len(ss))
for v := range ss {
s = append(s, v)
}
sort.Strings(s)
return s
}
type stringSetSet map[string]stringSet
func newStringSetSet() stringSetSet {
return make(stringSetSet)
}
func (sss stringSetSet) add(v1, v2 string) {
ss, ok := sss[v1]
if !ok {
ss = newStringSet()
sss[v1] = ss
}
ss.add(v2)
}
func (sss stringSetSet) get(v1 string) stringSet {
return sss[v1]
}
type groupLookup struct {
childUserIDToParentGroupID stringSetSet
childGroupIDToParentGroupID stringSetSet
}
func newGroupLookup() *groupLookup {
return &groupLookup{
childUserIDToParentGroupID: newStringSetSet(),
childGroupIDToParentGroupID: newStringSetSet(),
}
}
func (l *groupLookup) addGroup(parentGroupID string, childGroupIDs, childUserIDs []string) {
for _, childGroupID := range childGroupIDs {
l.childGroupIDToParentGroupID.add(childGroupID, parentGroupID)
}
for _, childUserID := range childUserIDs {
l.childUserIDToParentGroupID.add(childUserID, parentGroupID)
}
}
func (l *groupLookup) getUserIDs() []string {
s := make([]string, 0, len(l.childUserIDToParentGroupID))
for userID := range l.childUserIDToParentGroupID {
s = append(s, userID)
}
sort.Strings(s)
return s
}
func (l *groupLookup) getGroupIDsForUser(userID string) []string {
groupIDs := newStringSet()
var todo []string
for groupID := range l.childUserIDToParentGroupID.get(userID) {
todo = append(todo, groupID)
}
for len(todo) > 0 {
groupID := todo[len(todo)-1]
todo = todo[:len(todo)-1]
if groupIDs.has(groupID) {
continue
}
groupIDs.add(groupID)
for parentGroupID := range l.childGroupIDToParentGroupID.get(groupID) {
todo = append(todo, parentGroupID)
}
}
return groupIDs.sorted()
}

View file

@ -1,25 +0,0 @@
package azure
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestGroupLookup(t *testing.T) {
gl := newGroupLookup()
gl.addGroup("g1", []string{"g11", "g12", "g13"}, []string{"u1"})
gl.addGroup("g11", []string{"g111"}, nil)
gl.addGroup("g111", nil, []string{"u2"})
assert.Equal(t, []string{"u1", "u2"}, gl.getUserIDs())
assert.Equal(t, []string{"g1", "g11", "g111"}, gl.getGroupIDsForUser("u2"))
t.Run("cycle protection", func(t *testing.T) {
gl.addGroup("g12", []string{"g1"}, nil)
assert.Equal(t, []string{"u1", "u2"}, gl.getUserIDs())
assert.Equal(t, []string{"g1", "g11", "g111", "g12"}, gl.getGroupIDsForUser("u2"))
})
}

View file

@ -1,8 +0,0 @@
// Package directoryerrors contains errors used by directory providers.
package directoryerrors
import "errors"
// ErrPreferExistingInformation indicates that the information returned by the provider should
// only be used if a record is brand new, otherwise the existing information should be kept as is.
var ErrPreferExistingInformation = errors.New("user ignored")

View file

@ -1,330 +0,0 @@
// Package github contains a directory provider for github.
package github
import (
"bytes"
"context"
"encoding/json"
"fmt"
"net/http"
"net/url"
"sort"
"github.com/rs/zerolog"
"github.com/tomnomnom/linkheader"
"github.com/pomerium/pomerium/internal/encoding"
"github.com/pomerium/pomerium/internal/httputil"
"github.com/pomerium/pomerium/internal/log"
"github.com/pomerium/pomerium/pkg/grpc/directory"
)
// Name is the provider name.
const Name = "github"
var defaultURL = &url.URL{
Scheme: "https",
Host: "api.github.com",
}
type config struct {
httpClient *http.Client
serviceAccount *ServiceAccount
url *url.URL
}
// An Option updates the github configuration.
type Option func(cfg *config)
// WithServiceAccount sets the service account in the config.
func WithServiceAccount(serviceAccount *ServiceAccount) Option {
return func(cfg *config) {
cfg.serviceAccount = serviceAccount
}
}
// WithHTTPClient sets the http client option.
func WithHTTPClient(httpClient *http.Client) Option {
return func(cfg *config) {
cfg.httpClient = httputil.NewLoggingClient(httpClient, "github_idp_client",
func(evt *zerolog.Event) *zerolog.Event {
return evt.Str("provider", "github")
})
}
}
// WithURL sets the api url in the config.
func WithURL(u *url.URL) Option {
return func(cfg *config) {
cfg.url = u
}
}
func getConfig(options ...Option) *config {
cfg := new(config)
WithHTTPClient(http.DefaultClient)(cfg)
WithURL(defaultURL)(cfg)
for _, option := range options {
option(cfg)
}
return cfg
}
// The Provider retrieves users and groups from github.
type Provider struct {
cfg *config
log zerolog.Logger
}
// New creates a new Provider.
func New(options ...Option) *Provider {
return &Provider{
cfg: getConfig(options...),
log: log.With().Str("service", "directory").Str("provider", "github").Logger(),
}
}
// User returns the user record for the given id.
func (p *Provider) User(ctx context.Context, userID, accessToken string) (*directory.User, error) {
if p.cfg.serviceAccount == nil {
return nil, fmt.Errorf("github: service account not defined")
}
du := &directory.User{
Id: userID,
}
au, err := p.getUser(ctx, userID)
if err != nil {
return nil, err
}
du.DisplayName = au.Name
du.Email = au.Email
teamIDLookup := map[string]struct{}{}
orgSlugs, err := p.listOrgs(ctx)
if err != nil {
return nil, err
}
for _, orgSlug := range orgSlugs {
teamIDs, err := p.listUserOrganizationTeams(ctx, userID, orgSlug)
if err != nil {
return nil, err
}
for _, teamID := range teamIDs {
teamIDLookup[teamID] = struct{}{}
}
}
for teamID := range teamIDLookup {
du.GroupIds = append(du.GroupIds, teamID)
}
sort.Strings(du.GroupIds)
return du, nil
}
// UserGroups gets the directory user groups for github.
func (p *Provider) UserGroups(ctx context.Context) ([]*directory.Group, []*directory.User, error) {
if p.cfg.serviceAccount == nil {
return nil, nil, fmt.Errorf("github: service account not defined")
}
orgSlugs, err := p.listOrgs(ctx)
if err != nil {
return nil, nil, err
}
userLoginToGroups := map[string][]string{}
var allGroups []*directory.Group
for _, orgSlug := range orgSlugs {
teams, err := p.listOrganizationTeamsWithMemberIDs(ctx, orgSlug)
if err != nil {
return nil, nil, err
}
for _, team := range teams {
allGroups = append(allGroups, &directory.Group{
Id: team.Slug,
Name: team.Slug,
})
for _, memberID := range team.MemberIDs {
userLoginToGroups[memberID] = append(userLoginToGroups[memberID], team.Slug)
}
}
}
sort.Slice(allGroups, func(i, j int) bool {
return allGroups[i].Id < allGroups[j].Id
})
var allUsers []*directory.User
for _, orgSlug := range orgSlugs {
members, err := p.listOrganizationMembers(ctx, orgSlug)
if err != nil {
return nil, nil, err
}
for _, member := range members {
du := &directory.User{
Id: member.Login,
GroupIds: userLoginToGroups[member.ID],
DisplayName: member.Name,
Email: member.Email,
}
sort.Strings(du.GroupIds)
allUsers = append(allUsers, du)
}
}
sort.Slice(allUsers, func(i, j int) bool {
return allUsers[i].Id < allUsers[j].Id
})
return allGroups, allUsers, nil
}
func (p *Provider) listOrgs(ctx context.Context) (orgSlugs []string, err error) {
nextURL := p.cfg.url.ResolveReference(&url.URL{
Path: "/user/orgs",
}).String()
for nextURL != "" {
var results []struct {
Login string `json:"login"`
}
hdrs, err := p.api(ctx, nextURL, &results)
if err != nil {
return nil, err
}
for _, result := range results {
orgSlugs = append(orgSlugs, result.Login)
}
nextURL = getNextLink(hdrs)
}
return orgSlugs, nil
}
func (p *Provider) getUser(ctx context.Context, userLogin string) (*apiUserObject, error) {
apiURL := p.cfg.url.ResolveReference(&url.URL{
Path: fmt.Sprintf("/users/%s", userLogin),
}).String()
var res apiUserObject
_, err := p.api(ctx, apiURL, &res)
if err != nil {
return nil, err
}
return &res, nil
}
func (p *Provider) api(ctx context.Context, apiURL string, out interface{}) (http.Header, error) {
req, err := http.NewRequestWithContext(ctx, "GET", apiURL, nil)
if err != nil {
return nil, fmt.Errorf("github: failed to create http request: %w", err)
}
req.SetBasicAuth(p.cfg.serviceAccount.Username, p.cfg.serviceAccount.PersonalAccessToken)
res, err := p.cfg.httpClient.Do(req)
if err != nil {
return nil, fmt.Errorf("github: failed to make http request: %w", err)
}
defer res.Body.Close()
if res.StatusCode/100 != 2 {
return nil, fmt.Errorf("github: error from API: %s", res.Status)
}
if out != nil {
err := json.NewDecoder(res.Body).Decode(out)
if err != nil {
return nil, fmt.Errorf("github: failed to decode json body: %w", err)
}
}
return res.Header, nil
}
func (p *Provider) graphql(ctx context.Context, query string, out interface{}) (http.Header, error) {
apiURL := p.cfg.url.ResolveReference(&url.URL{
Path: "/graphql",
}).String()
bs, _ := json.Marshal(struct {
Query string `json:"query"`
}{query})
req, err := http.NewRequestWithContext(ctx, "POST", apiURL, bytes.NewReader(bs))
if err != nil {
return nil, fmt.Errorf("github: failed to create http request: %w", err)
}
req.SetBasicAuth(p.cfg.serviceAccount.Username, p.cfg.serviceAccount.PersonalAccessToken)
res, err := p.cfg.httpClient.Do(req)
if err != nil {
return nil, fmt.Errorf("github: failed to make http request: %w", err)
}
defer res.Body.Close()
if res.StatusCode/100 != 2 {
return nil, fmt.Errorf("github: error from API: %s", res.Status)
}
if out != nil {
err := json.NewDecoder(res.Body).Decode(out)
if err != nil {
return nil, fmt.Errorf("github: failed to decode json body: %w", err)
}
}
return res.Header, nil
}
func getNextLink(hdrs http.Header) string {
for _, link := range linkheader.ParseMultiple(hdrs.Values("Link")) {
if link.Rel == "next" {
return link.URL
}
}
return ""
}
// A ServiceAccount is used by the GitHub provider to query the GitHub API.
type ServiceAccount struct {
Username string `json:"username"`
PersonalAccessToken string `json:"personal_access_token"`
}
// ParseServiceAccount parses the service account in the config options.
func ParseServiceAccount(rawServiceAccount string) (*ServiceAccount, error) {
var serviceAccount ServiceAccount
if err := encoding.DecodeBase64OrJSON(rawServiceAccount, &serviceAccount); err != nil {
return nil, err
}
if serviceAccount.Username == "" {
return nil, fmt.Errorf("username is required")
}
if serviceAccount.PersonalAccessToken == "" {
return nil, fmt.Errorf("personal_access_token is required")
}
return &serviceAccount, nil
}
// see: https://docs.github.com/en/free-pro-team@latest/rest/reference/users#get-a-user
type apiUserObject struct {
Login string `json:"login"`
Name string `json:"name"`
Email string `json:"email"`
}
type teamWithMemberIDs struct {
ID string
Slug string
Name string
MemberIDs []string
}

View file

@ -1,436 +0,0 @@
package github
import (
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"net/url"
"testing"
"github.com/go-chi/chi/v5"
"github.com/go-chi/chi/v5/middleware"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/vektah/gqlparser/ast"
"github.com/vektah/gqlparser/parser"
"github.com/pomerium/pomerium/internal/testutil"
)
type M = map[string]interface{}
func newMockAPI(t *testing.T, srv *httptest.Server) http.Handler {
r := chi.NewRouter()
r.Use(middleware.Logger)
r.Use(func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if !assert.Equal(t, "Basic YWJjOnh5eg==", r.Header.Get("Authorization")) {
http.Error(w, "forbidden", http.StatusForbidden)
return
}
next.ServeHTTP(w, r)
})
})
r.Post("/graphql", func(w http.ResponseWriter, r *http.Request) {
var body struct {
Query string `json:"query"`
}
json.NewDecoder(r.Body).Decode(&body)
q, err := parser.ParseQuery(&ast.Source{
Input: body.Query,
})
if err != nil {
panic(err)
}
result := qlResult{
Data: &qlData{
Organization: &qlOrganization{},
},
}
handleMembersWithRole := func(orgSlug string, field *ast.Field) {
membersWithRole := &qlMembersWithRoleConnection{}
var cursor string
for _, arg := range field.Arguments {
if arg.Name == "after" {
cursor = arg.Value.Raw
}
}
switch cursor {
case `null`:
switch orgSlug {
case "org1":
membersWithRole.PageInfo = qlPageInfo{EndCursor: "TOKEN1", HasNextPage: true}
membersWithRole.Nodes = []qlUser{
{ID: "user1", Login: "user1", Name: "User 1", Email: "user1@example.com"},
{ID: "user2", Login: "user2", Name: "User 2", Email: "user2@example.com"},
}
case "org2":
membersWithRole.PageInfo = qlPageInfo{HasNextPage: false}
membersWithRole.Nodes = []qlUser{
{ID: "user4", Login: "user4", Name: "User 4", Email: "user4@example.com"},
}
default:
t.Errorf("unexpected org slug: %s", orgSlug)
}
case `TOKEN1`:
membersWithRole.PageInfo = qlPageInfo{HasNextPage: false}
membersWithRole.Nodes = []qlUser{
{ID: "user3", Login: "user3", Name: "User 3", Email: "user3@example.com"},
}
default:
t.Errorf("unexpected cursor: %s", cursor)
}
result.Data.Organization.MembersWithRole = membersWithRole
}
handleTeamMembers := func(orgSlug, teamSlug string, field *ast.Field) {
result.Data.Organization.Team.Members = &qlTeamMemberConnection{
PageInfo: qlPageInfo{HasNextPage: false},
}
switch teamSlug {
case "team3":
result.Data.Organization.Team.Members.Edges = []qlTeamMemberEdge{
{Node: qlUser{ID: "user3"}},
}
}
}
handleTeam := func(orgSlug string, field *ast.Field) {
result.Data.Organization.Team = &qlTeam{}
var teamSlug string
for _, arg := range field.Arguments {
if arg.Name == "slug" {
teamSlug = arg.Value.Raw
}
}
for _, selection := range field.SelectionSet {
subField, ok := selection.(*ast.Field)
if !ok {
continue
}
switch subField.Name {
case "members":
handleTeamMembers(orgSlug, teamSlug, subField)
}
}
}
renderNodeField := func(field *ast.Field, path []string, value string) string {
outer:
for _, segment := range path {
for _, selection := range field.SelectionSet {
subField, ok := selection.(*ast.Field)
if !ok {
continue
}
if subField.Name != segment {
continue
}
field = subField
continue outer
}
return ""
}
return value
}
handleTeams := func(orgSlug string, field *ast.Field) {
teams := &qlTeamConnection{}
var cursor string
var userLogin string
for _, arg := range field.Arguments {
if arg.Name == "after" {
cursor = arg.Value.Raw
}
if arg.Name == "userLogins" {
userLogin = arg.Value.Children[0].Value.Raw
}
}
switch cursor {
case `null`:
switch orgSlug {
case "org1":
teams.PageInfo = qlPageInfo{HasNextPage: true, EndCursor: "TOKEN1"}
teams.Edges = []qlTeamEdge{
{Node: qlTeam{
ID: renderNodeField(field, []string{"edges", "node", "id"}, "MDQ6VGVhbTE="),
Slug: renderNodeField(field, []string{"edges", "node", "slug"}, "team1"),
Name: renderNodeField(field, []string{"edges", "node", "name"}, "Team 1"),
Members: &qlTeamMemberConnection{
PageInfo: qlPageInfo{HasNextPage: false},
Edges: []qlTeamMemberEdge{
{Node: qlUser{ID: "user1"}},
{Node: qlUser{ID: "user2"}},
},
}}},
}
case "org2":
teams.PageInfo = qlPageInfo{HasNextPage: false}
teams.Edges = []qlTeamEdge{
{Node: qlTeam{
ID: renderNodeField(field, []string{"edges", "node", "id"}, "MDQ6VGVhbTM="),
Slug: renderNodeField(field, []string{"edges", "node", "slug"}, "team3"),
Name: renderNodeField(field, []string{"edges", "node", "name"}, "Team 3"),
Members: &qlTeamMemberConnection{
PageInfo: qlPageInfo{HasNextPage: true, EndCursor: "TOKEN1"},
Edges: []qlTeamMemberEdge{
{Node: qlUser{ID: "user1"}},
{Node: qlUser{ID: "user2"}},
},
}}},
}
if userLogin == "" || userLogin == "user4" {
teams.Edges = append(teams.Edges, qlTeamEdge{
Node: qlTeam{
ID: renderNodeField(field, []string{"edges", "node", "id"}, "MDQ6VGVhbTQ="),
Slug: renderNodeField(field, []string{"edges", "node", "slug"}, "team4"),
Name: renderNodeField(field, []string{"edges", "node", "name"}, "Team 4"),
Members: &qlTeamMemberConnection{
PageInfo: qlPageInfo{HasNextPage: false},
Edges: []qlTeamMemberEdge{
{Node: qlUser{ID: "user4"}},
},
}},
})
}
default:
t.Errorf("unexpected org slug: %s", orgSlug)
}
case "TOKEN1":
teams.PageInfo = qlPageInfo{HasNextPage: false}
teams.Edges = []qlTeamEdge{
{Node: qlTeam{
ID: renderNodeField(field, []string{"edges", "node", "id"}, "MDQ6VGVhbTI="),
Slug: renderNodeField(field, []string{"edges", "node", "slug"}, "team2"),
Name: renderNodeField(field, []string{"edges", "node", "name"}, "Team 2"),
Members: &qlTeamMemberConnection{
PageInfo: qlPageInfo{HasNextPage: false},
Edges: []qlTeamMemberEdge{
{Node: qlUser{ID: "user1"}},
},
}}},
}
default:
t.Errorf("unexpected cursor: %s", cursor)
}
result.Data.Organization.Teams = teams
}
handleOrganization := func(field *ast.Field) {
var orgSlug string
for _, arg := range field.Arguments {
if arg.Name == "login" {
orgSlug = arg.Value.Raw
}
}
for _, orgSelection := range field.SelectionSet {
orgField, ok := orgSelection.(*ast.Field)
if !ok {
continue
}
switch orgField.Name {
case "teams":
handleTeams(orgSlug, orgField)
case "team":
handleTeam(orgSlug, orgField)
case "membersWithRole":
handleMembersWithRole(orgSlug, orgField)
}
}
}
for _, operation := range q.Operations {
for _, selection := range operation.SelectionSet {
field, ok := selection.(*ast.Field)
if !ok {
continue
}
if field.Name != "organization" {
continue
}
handleOrganization(field)
}
}
_ = json.NewEncoder(w).Encode(result)
})
r.Get("/user/orgs", func(w http.ResponseWriter, r *http.Request) {
json.NewEncoder(w).Encode([]M{
{"login": "org1"},
{"login": "org2"},
})
})
r.Get("/orgs/{org_id}/teams", func(w http.ResponseWriter, r *http.Request) {
teams := map[string][]M{
"org1": {
{"slug": "team1", "id": 1},
{"slug": "team2", "id": 2},
},
"org2": {
{"slug": "team3", "id": 3},
{"slug": "team4", "id": 4},
},
}
orgID := chi.URLParam(r, "org_id")
json.NewEncoder(w).Encode(teams[orgID])
})
r.Get("/orgs/{org_id}/teams/{team_id}/members", func(w http.ResponseWriter, r *http.Request) {
members := map[string]map[string][]M{
"org1": {
"team1": {
{"login": "user1"},
{"login": "user2"},
},
"team2": {
{"login": "user1"},
},
},
"org2": {
"team3": {
{"login": "user1"},
{"login": "user2"},
{"login": "user3"},
},
"team4": {
{"login": "user4"},
},
},
}
orgID := chi.URLParam(r, "org_id")
teamID := chi.URLParam(r, "team_id")
json.NewEncoder(w).Encode(members[orgID][teamID])
})
r.Get("/users/{user_id}", func(w http.ResponseWriter, r *http.Request) {
users := map[string]apiUserObject{
"user1": {Login: "user1", Name: "User 1", Email: "user1@example.com"},
"user2": {Login: "user2", Name: "User 2", Email: "user2@example.com"},
"user3": {Login: "user3", Name: "User 3", Email: "user3@example.com"},
"user4": {Login: "user4", Name: "User 4", Email: "user4@example.com"},
}
userID := chi.URLParam(r, "user_id")
json.NewEncoder(w).Encode(users[userID])
})
return r
}
func TestProvider_User(t *testing.T) {
var mockAPI http.Handler
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
mockAPI.ServeHTTP(w, r)
}))
defer srv.Close()
mockAPI = newMockAPI(t, srv)
p := New(
WithURL(mustParseURL(srv.URL)),
WithServiceAccount(&ServiceAccount{
Username: "abc",
PersonalAccessToken: "xyz",
}),
)
du, err := p.User(context.Background(), "user1", "")
if !assert.NoError(t, err) {
return
}
testutil.AssertProtoJSONEqual(t, `{
"id": "user1",
"groupIds": ["team1", "team2", "team3"],
"displayName": "User 1",
"email": "user1@example.com"
}`, du)
}
func TestProvider_UserGroups(t *testing.T) {
var mockAPI http.Handler
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
mockAPI.ServeHTTP(w, r)
}))
defer srv.Close()
mockAPI = newMockAPI(t, srv)
p := New(
WithURL(mustParseURL(srv.URL)),
WithServiceAccount(&ServiceAccount{
Username: "abc",
PersonalAccessToken: "xyz",
}),
)
groups, users, err := p.UserGroups(context.Background())
assert.NoError(t, err)
testutil.AssertProtoJSONEqual(t, `[
{ "id": "user1", "groupIds": ["team1", "team2", "team3"], "displayName": "User 1", "email": "user1@example.com" },
{ "id": "user2", "groupIds": ["team1", "team3"], "displayName": "User 2", "email": "user2@example.com" },
{ "id": "user3", "groupIds": ["team3"], "displayName": "User 3", "email": "user3@example.com" },
{ "id": "user4", "groupIds": ["team4"], "displayName": "User 4", "email": "user4@example.com" }
]`, users)
testutil.AssertProtoJSONEqual(t, `[
{ "id": "team1", "name": "team1" },
{ "id": "team2", "name": "team2" },
{ "id": "team3", "name": "team3" },
{ "id": "team4", "name": "team4" }
]`, groups)
}
func TestParseServiceAccount(t *testing.T) {
tests := []struct {
name string
rawServiceAccount string
serviceAccount *ServiceAccount
wantErr bool
}{
{
"json",
`{"username": "USERNAME", "personal_access_token": "PERSONAL_ACCESS_TOKEN"}`,
&ServiceAccount{Username: "USERNAME", PersonalAccessToken: "PERSONAL_ACCESS_TOKEN"},
false,
},
{
"base64 json",
`eyJ1c2VybmFtZSI6ICJVU0VSTkFNRSIsICJwZXJzb25hbF9hY2Nlc3NfdG9rZW4iOiAiUEVSU09OQUxfQUNDRVNTX1RPS0VOIn0=`,
&ServiceAccount{Username: "USERNAME", PersonalAccessToken: "PERSONAL_ACCESS_TOKEN"},
false,
},
{
"empty",
"",
nil,
true,
},
{
"invalid",
"Zm9v---",
nil,
true,
},
}
for _, tc := range tests {
tc := tc
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
got, err := ParseServiceAccount(tc.rawServiceAccount)
require.True(t, (err != nil) == tc.wantErr)
assert.Equal(t, tc.serviceAccount, got)
})
}
}
func mustParseURL(rawurl string) *url.URL {
u, err := url.Parse(rawurl)
if err != nil {
panic(err)
}
return u
}

View file

@ -1,246 +0,0 @@
package github
import (
"context"
"encoding/json"
"fmt"
)
const maxPageCount = 100
type (
qlData struct {
Organization *qlOrganization `json:"organization"`
}
qlMembersWithRoleConnection struct {
Nodes []qlUser `json:"nodes"`
PageInfo qlPageInfo `json:"pageInfo"`
}
qlOrganization struct {
MembersWithRole *qlMembersWithRoleConnection `json:"membersWithRole"`
Team *qlTeam `json:"team"`
Teams *qlTeamConnection `json:"teams"`
}
qlPageInfo struct {
EndCursor string `json:"endCursor"`
HasNextPage bool `json:"hasNextPage"`
}
qlResult struct {
Data *qlData `json:"data"`
}
qlTeam struct {
ID string `json:"id"`
Name string `json:"name"`
Slug string `json:"slug"`
Members *qlTeamMemberConnection `json:"members"`
}
qlTeamConnection struct {
Edges []qlTeamEdge `json:"edges"`
PageInfo qlPageInfo `json:"pageInfo"`
}
qlTeamEdge struct {
Node qlTeam `json:"node"`
}
qlTeamMemberConnection struct {
Edges []qlTeamMemberEdge `json:"edges"`
PageInfo qlPageInfo `json:"pageInfo"`
}
qlTeamMemberEdge struct {
Node qlUser `json:"node"`
}
qlUser struct {
ID string `json:"id"`
Login string `json:"login"`
Name string `json:"name"`
Email string `json:"email"`
}
)
func (p *Provider) listOrganizationMembers(ctx context.Context, orgSlug string) ([]qlUser, error) {
var results []qlUser
var cursor *string
for {
var res qlResult
q := fmt.Sprintf(`query {
organization(login:%s) {
membersWithRole(first:%d, after:%s) {
pageInfo {
endCursor
hasNextPage
}
nodes {
id
login
name
email
}
}
}
}`, encode(orgSlug), maxPageCount, encode(cursor))
_, err := p.graphql(ctx, q, &res)
if err != nil {
return nil, err
}
results = append(results, res.Data.Organization.MembersWithRole.Nodes...)
if !res.Data.Organization.MembersWithRole.PageInfo.HasNextPage {
break
}
cursor = &res.Data.Organization.MembersWithRole.PageInfo.EndCursor
}
return results, nil
}
func (p *Provider) listOrganizationTeamsWithMemberIDs(ctx context.Context, orgSlug string) ([]teamWithMemberIDs, error) {
var results []teamWithMemberIDs
var pageInfos []qlPageInfo
// first query all the teams with their members
var cursor *string
for {
var res qlResult
q := fmt.Sprintf(`query {
organization(login:%s) {
teams(first:%d, after:%s) {
pageInfo {
endCursor
hasNextPage
}
edges {
node {
id
name
slug
members(first:%d) {
pageInfo {
endCursor
hasNextPage
}
edges {
node {
id
}
}
}
}
}
}
}
}`, encode(orgSlug), maxPageCount, encode(cursor), maxPageCount)
_, err := p.graphql(ctx, q, &res)
if err != nil {
return nil, err
}
for _, teamEdge := range res.Data.Organization.Teams.Edges {
var memberIDs []string
for _, memberEdge := range teamEdge.Node.Members.Edges {
memberIDs = append(memberIDs, memberEdge.Node.ID)
}
results = append(results, teamWithMemberIDs{
ID: teamEdge.Node.ID,
Slug: teamEdge.Node.Slug,
Name: teamEdge.Node.Name,
MemberIDs: memberIDs,
})
pageInfos = append(pageInfos, teamEdge.Node.Members.PageInfo)
}
if !res.Data.Organization.Teams.PageInfo.HasNextPage {
break
}
cursor = &res.Data.Organization.Teams.PageInfo.EndCursor
}
// it's possible we didn't get all the members if the initial query, so go through each team and
// check the member pageInfo. If there are still remaining members, query those.
for i, pageInfo := range pageInfos {
if !pageInfo.HasNextPage {
continue
}
cursor = &pageInfo.EndCursor
for {
var res qlResult
q := fmt.Sprintf(`query {
organization(login:%s) {
team(slug:%s) {
members(first:%d, after:%s) {
pageInfo {
endCursor
hasNextPage
}
edges {
node {
id
}
}
}
}
}
}`, encode(orgSlug), encode(results[i].Slug), maxPageCount, encode(cursor))
_, err := p.graphql(ctx, q, &res)
if err != nil {
return nil, err
}
for _, memberEdge := range res.Data.Organization.Team.Members.Edges {
results[i].MemberIDs = append(results[i].MemberIDs, memberEdge.Node.ID)
}
if !res.Data.Organization.Team.Members.PageInfo.HasNextPage {
break
}
cursor = &res.Data.Organization.Team.Members.PageInfo.EndCursor
}
}
return results, nil
}
func (p *Provider) listUserOrganizationTeams(ctx context.Context, userSlug string, orgSlug string) ([]string, error) {
// GitHub's Rest API doesn't have an easy way of querying this data, so we use the GraphQL API.
var teamSlugs []string
var cursor *string
for {
var res qlResult
q := fmt.Sprintf(`query {
organization(login:%s) {
teams(first:%d, userLogins:[%s], after:%s) {
pageInfo {
endCursor
hasNextPage
}
edges {
node {
id
slug
}
}
}
}
}`, encode(orgSlug), maxPageCount, encode(userSlug), encode(cursor))
_, err := p.graphql(ctx, q, &res)
if err != nil {
return nil, err
}
for _, edge := range res.Data.Organization.Teams.Edges {
teamSlugs = append(teamSlugs, edge.Node.Slug)
}
if !res.Data.Organization.Teams.PageInfo.HasNextPage {
break
}
cursor = &res.Data.Organization.Teams.PageInfo.EndCursor
}
return teamSlugs, nil
}
func encode(obj interface{}) string {
bs, _ := json.Marshal(obj)
return string(bs)
}

View file

@ -1,286 +0,0 @@
// Package gitlab contains a directory provider for gitlab.
package gitlab
import (
"context"
"encoding/json"
"fmt"
"net/http"
"net/url"
"sort"
"strconv"
"github.com/rs/zerolog"
"github.com/tomnomnom/linkheader"
"github.com/pomerium/pomerium/internal/encoding"
"github.com/pomerium/pomerium/internal/httputil"
"github.com/pomerium/pomerium/internal/log"
"github.com/pomerium/pomerium/pkg/grpc/directory"
)
// Name is the provider name.
const Name = "gitlab"
var defaultURL = &url.URL{
Scheme: "https",
Host: "gitlab.com",
}
type config struct {
httpClient *http.Client
serviceAccount *ServiceAccount
url *url.URL
}
// An Option updates the gitlab configuration.
type Option func(cfg *config)
// WithServiceAccount sets the service account in the config.
func WithServiceAccount(serviceAccount *ServiceAccount) Option {
return func(cfg *config) {
cfg.serviceAccount = serviceAccount
}
}
// WithHTTPClient sets the http client option.
func WithHTTPClient(httpClient *http.Client) Option {
return func(cfg *config) {
cfg.httpClient = httputil.NewLoggingClient(httpClient, "gitlab_idp_client",
func(evt *zerolog.Event) *zerolog.Event {
return evt.Str("provider", "azure")
})
}
}
// WithURL sets the api url in the config.
func WithURL(u *url.URL) Option {
return func(cfg *config) {
cfg.url = u
}
}
func getConfig(options ...Option) *config {
cfg := new(config)
WithHTTPClient(http.DefaultClient)(cfg)
WithURL(defaultURL)(cfg)
for _, option := range options {
option(cfg)
}
return cfg
}
// The Provider retrieves users and groups from gitlab.
type Provider struct {
cfg *config
}
// New creates a new Provider.
func New(options ...Option) *Provider {
return &Provider{
cfg: getConfig(options...),
}
}
func withLog(ctx context.Context) context.Context {
return log.WithContext(ctx, func(c zerolog.Context) zerolog.Context {
return c.Str("service", "directory").Str("provider", "gitlab")
})
}
// User returns the user record for the given id.
func (p *Provider) User(ctx context.Context, userID, accessToken string) (*directory.User, error) {
ctx = withLog(ctx)
du := &directory.User{
Id: userID,
}
au, err := p.getUser(ctx, userID, accessToken)
if err != nil {
return nil, err
}
du.DisplayName = au.Name
du.Email = au.Email
groups, err := p.listGroups(ctx, accessToken)
if err != nil {
return nil, err
}
for _, g := range groups {
du.GroupIds = append(du.GroupIds, g.Id)
}
sort.Strings(du.GroupIds)
return du, nil
}
// UserGroups gets the directory user groups for gitlab.
func (p *Provider) UserGroups(ctx context.Context) ([]*directory.Group, []*directory.User, error) {
ctx = withLog(ctx)
if p.cfg.serviceAccount == nil {
return nil, nil, fmt.Errorf("gitlab: service account not defined")
}
log.Info(ctx).Msg("getting user groups")
groups, err := p.listGroups(ctx, "")
if err != nil {
return nil, nil, err
}
userLookup := map[int]apiUserObject{}
userIDToGroupIDs := map[int][]string{}
for _, group := range groups {
users, err := p.listGroupMembers(ctx, group.Id)
if err != nil {
return nil, nil, err
}
for _, u := range users {
userIDToGroupIDs[u.ID] = append(userIDToGroupIDs[u.ID], group.Id)
userLookup[u.ID] = u
}
}
var users []*directory.User
for _, u := range userLookup {
user := &directory.User{
Id: fmt.Sprint(u.ID),
DisplayName: u.Name,
Email: u.Email,
}
user.GroupIds = append(user.GroupIds, userIDToGroupIDs[u.ID]...)
sort.Strings(user.GroupIds)
users = append(users, user)
}
sort.Slice(users, func(i, j int) bool {
return users[i].GetId() < users[j].GetId()
})
return groups, users, nil
}
func (p *Provider) getUser(ctx context.Context, userID string, accessToken string) (*apiUserObject, error) {
apiURL := p.cfg.url.ResolveReference(&url.URL{
Path: fmt.Sprintf("/api/v4/users/%s", userID),
}).String()
var result apiUserObject
_, err := p.api(ctx, accessToken, apiURL, &result)
if err != nil {
return nil, fmt.Errorf("gitlab: error querying user: %w", err)
}
return &result, nil
}
// listGroups returns a map, with key is group ID, element is group name.
func (p *Provider) listGroups(ctx context.Context, accessToken string) ([]*directory.Group, error) {
nextURL := p.cfg.url.ResolveReference(&url.URL{
Path: "/api/v4/groups",
}).String()
var groups []*directory.Group
for nextURL != "" {
var result []struct {
ID int `json:"id"`
Name string `json:"name"`
}
hdrs, err := p.api(ctx, accessToken, nextURL, &result)
if err != nil {
return nil, fmt.Errorf("gitlab: error querying groups: %w", err)
}
for _, r := range result {
groups = append(groups, &directory.Group{
Id: strconv.Itoa(r.ID),
Name: r.Name,
})
}
nextURL = getNextLink(hdrs)
}
return groups, nil
}
func (p *Provider) listGroupMembers(ctx context.Context, groupID string) (users []apiUserObject, err error) {
nextURL := p.cfg.url.ResolveReference(&url.URL{
Path: fmt.Sprintf("/api/v4/groups/%s/members", groupID),
}).String()
for nextURL != "" {
var result []apiUserObject
hdrs, err := p.api(ctx, "", nextURL, &result)
if err != nil {
return nil, fmt.Errorf("gitlab: error querying group members: %w", err)
}
users = append(users, result...)
nextURL = getNextLink(hdrs)
}
return users, nil
}
func (p *Provider) api(ctx context.Context, accessToken string, uri string, out interface{}) (http.Header, error) {
req, err := http.NewRequestWithContext(ctx, "GET", uri, nil)
if err != nil {
return nil, fmt.Errorf("gitlab: failed to create HTTP request: %w", err)
}
req.Header.Set("Accept", "application/json")
req.Header.Set("Content-Type", "application/json")
if accessToken != "" {
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", accessToken))
} else {
req.Header.Set("PRIVATE-TOKEN", p.cfg.serviceAccount.PrivateToken)
}
res, err := p.cfg.httpClient.Do(req)
if err != nil {
return nil, err
}
defer res.Body.Close()
if res.StatusCode/100 != 2 {
return nil, fmt.Errorf("gitlab: error querying api url=%s status_code=%d: %s", uri, res.StatusCode, res.Status)
}
err = json.NewDecoder(res.Body).Decode(out)
if err != nil {
return nil, err
}
return res.Header, nil
}
func getNextLink(hdrs http.Header) string {
for _, link := range linkheader.ParseMultiple(hdrs.Values("Link")) {
if link.Rel == "next" {
return link.URL
}
}
return ""
}
// A ServiceAccount is used by the Gitlab provider to query the Gitlab API.
type ServiceAccount struct {
PrivateToken string `json:"private_token"`
}
// ParseServiceAccount parses the service account in the config options.
func ParseServiceAccount(rawServiceAccount string) (*ServiceAccount, error) {
var serviceAccount ServiceAccount
if err := encoding.DecodeBase64OrJSON(rawServiceAccount, &serviceAccount); err != nil {
return nil, err
}
if serviceAccount.PrivateToken == "" {
return nil, fmt.Errorf("private_token is required")
}
return &serviceAccount, nil
}
type apiUserObject struct {
ID int `json:"id"`
Name string `json:"name"`
Email string `json:"email"`
}

View file

@ -1,132 +0,0 @@
package gitlab
import (
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"net/url"
"testing"
"github.com/go-chi/chi/v5"
"github.com/go-chi/chi/v5/middleware"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/pomerium/pomerium/internal/testutil"
)
type M = map[string]interface{}
func newMockAPI(t *testing.T, srv *httptest.Server) http.Handler {
r := chi.NewRouter()
r.Use(middleware.Logger)
r.Route("/api/v4", func(r chi.Router) {
r.Use(func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Header.Get("Private-Token") != "PRIVATE_TOKEN" {
http.Error(w, "forbidden", http.StatusForbidden)
return
}
next.ServeHTTP(w, r)
})
})
r.Get("/groups", func(w http.ResponseWriter, r *http.Request) {
_ = json.NewEncoder(w).Encode([]M{
{"id": 1, "name": "Group 1"},
{"id": 2, "name": "Group 2"},
})
})
r.Get("/groups/{group_name}/members", func(w http.ResponseWriter, r *http.Request) {
members := map[string][]M{
"1": {
{"id": 11, "name": "User 1", "email": "user1@example.com"},
},
"2": {
{"id": 12, "name": "User 2", "email": "user2@example.com"},
{"id": 13, "name": "User 3", "email": "user3@example.com"},
},
}
_ = json.NewEncoder(w).Encode(members[chi.URLParam(r, "group_name")])
})
})
return r
}
func Test(t *testing.T) {
var mockAPI http.Handler
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
mockAPI.ServeHTTP(w, r)
}))
defer srv.Close()
mockAPI = newMockAPI(t, srv)
p := New(
WithURL(mustParseURL(srv.URL)),
WithServiceAccount(&ServiceAccount{
PrivateToken: "PRIVATE_TOKEN",
}),
)
groups, users, err := p.UserGroups(context.Background())
assert.NoError(t, err)
testutil.AssertProtoJSONEqual(t, `[
{ "id": "11", "groupIds": ["1"], "displayName": "User 1", "email": "user1@example.com" },
{ "id": "12", "groupIds": ["2"], "displayName": "User 2", "email": "user2@example.com" },
{ "id": "13", "groupIds": ["2"], "displayName": "User 3", "email": "user3@example.com" }
]`, users)
testutil.AssertProtoJSONEqual(t, `[
{ "id": "1", "name": "Group 1" },
{ "id": "2", "name": "Group 2" }
]`, groups)
}
func TestParseServiceAccount(t *testing.T) {
tests := []struct {
name string
rawServiceAccount string
serviceAccount *ServiceAccount
wantErr bool
}{
{
"json",
`{"private_token":"PRIVATE_TOKEN"}`,
&ServiceAccount{PrivateToken: "PRIVATE_TOKEN"},
false,
},
{
"base64 json",
`eyJwcml2YXRlX3Rva2VuIjoiUFJJVkFURV9UT0tFTiJ9`,
&ServiceAccount{PrivateToken: "PRIVATE_TOKEN"},
false,
},
{
"empty",
"",
nil,
true,
},
{
"invalid",
"Zm9v---",
nil,
true,
},
}
for _, tc := range tests {
tc := tc
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
got, err := ParseServiceAccount(tc.rawServiceAccount)
require.True(t, (err != nil) == tc.wantErr)
assert.Equal(t, tc.serviceAccount, got)
})
}
}
func mustParseURL(rawurl string) *url.URL {
u, err := url.Parse(rawurl)
if err != nil {
panic(err)
}
return u
}

View file

@ -1,323 +0,0 @@
// Package google contains the Google directory provider.
package google
import (
"context"
"encoding/json"
"errors"
"fmt"
"net/http"
"sort"
"sync"
"github.com/rs/zerolog"
"golang.org/x/oauth2/google"
admin "google.golang.org/api/admin/directory/v1"
"google.golang.org/api/googleapi"
"google.golang.org/api/option"
"github.com/pomerium/pomerium/internal/directory/directoryerrors"
"github.com/pomerium/pomerium/internal/encoding"
"github.com/pomerium/pomerium/internal/log"
"github.com/pomerium/pomerium/pkg/grpc/directory"
)
const (
// Name is the provider name.
Name = "google"
currentAccountCustomerID = "my_customer"
)
const (
defaultProviderURL = "https://www.googleapis.com/"
)
type config struct {
serviceAccount *ServiceAccount
url string
}
// An Option changes the configuration for the Google directory provider.
type Option func(cfg *config)
// WithServiceAccount sets the service account in the Google configuration.
func WithServiceAccount(serviceAccount *ServiceAccount) Option {
return func(cfg *config) {
cfg.serviceAccount = serviceAccount
}
}
// WithURL sets the provider url to use.
func WithURL(url string) Option {
return func(cfg *config) {
cfg.url = url
}
}
func getConfig(opts ...Option) *config {
cfg := new(config)
WithURL(defaultProviderURL)(cfg)
for _, opt := range opts {
opt(cfg)
}
return cfg
}
// Required scopes for groups api
// https://developers.google.com/admin-sdk/directory/v1/reference/groups/list
var apiScopes = []string{admin.AdminDirectoryUserReadonlyScope, admin.AdminDirectoryGroupReadonlyScope}
// A Provider is a Google directory provider.
type Provider struct {
cfg *config
log zerolog.Logger
mu sync.RWMutex
apiClient *admin.Service
}
// New creates a new Google directory provider.
func New(options ...Option) *Provider {
return &Provider{
cfg: getConfig(options...),
log: log.With().Str("service", "directory").Str("provider", "google").Logger(),
}
}
// User returns the user record for the given id.
func (p *Provider) User(ctx context.Context, userID, accessToken string) (*directory.User, error) {
apiClient, err := p.getAPIClient(ctx)
if err != nil {
return nil, fmt.Errorf("google: error getting API client: %w", err)
}
du := &directory.User{
Id: userID,
}
au, err := apiClient.Users.Get(userID).
Context(ctx).
Do()
if isAccessDenied(err) {
// ignore forbidden errors as a user may login from another gsuite domain
return du, directoryerrors.ErrPreferExistingInformation
} else if err != nil {
return nil, fmt.Errorf("google: error getting user: %w", err)
} else {
if au.Name != nil {
du.DisplayName = au.Name.FullName
}
du.Email = au.PrimaryEmail
}
err = apiClient.Groups.List().
Context(ctx).
UserKey(userID).
Pages(ctx, func(res *admin.Groups) error {
for _, g := range res.Groups {
du.GroupIds = append(du.GroupIds, g.Id)
}
return nil
})
if err != nil {
return nil, fmt.Errorf("google: error getting groups for user: %w", err)
}
sort.Strings(du.GroupIds)
return du, nil
}
// UserGroups returns a slice of group names a given user is in
// NOTE: groups via Directory API is limited to 1 QPS!
// https://developers.google.com/admin-sdk/directory/v1/reference/groups/list
// https://developers.google.com/admin-sdk/directory/v1/limits
func (p *Provider) UserGroups(ctx context.Context) ([]*directory.Group, []*directory.User, error) {
apiClient, err := p.getAPIClient(ctx)
if err != nil {
return nil, nil, fmt.Errorf("google: error getting API client: %w", err)
}
// query all the groups
var groups []*directory.Group
err = apiClient.Groups.List().
Context(ctx).
Customer(currentAccountCustomerID).
Pages(ctx, func(res *admin.Groups) error {
for _, g := range res.Groups {
// Skip group without member.
if g.DirectMembersCount == 0 {
continue
}
groups = append(groups, &directory.Group{
Id: g.Id,
Name: g.Email,
Email: g.Email,
})
}
return nil
})
if err != nil {
return nil, nil, fmt.Errorf("google: error getting groups: %w", err)
}
// query all the user members for each group
// - create a lookup table for the user (storing id and name)
// (this includes users who aren't necessarily members of the same organization)
// - create a lookup table for the user's groups
userLookup := map[string]apiUserObject{}
userIDToGroups := map[string][]string{}
for _, group := range groups {
group := group
err = apiClient.Members.List(group.Id).
Context(ctx).
Pages(ctx, func(res *admin.Members) error {
for _, member := range res.Members {
// only include user objects
if member.Type != "USER" {
continue
}
userLookup[member.Id] = apiUserObject{
ID: member.Id,
Email: member.Email,
}
userIDToGroups[member.Id] = append(userIDToGroups[member.Id], group.Id)
}
return nil
})
if err != nil {
return nil, nil, fmt.Errorf("google: error getting group members: %w", err)
}
}
// query all the users in the organization
err = apiClient.Users.List().
Context(ctx).
Customer(currentAccountCustomerID).
Pages(ctx, func(res *admin.Users) error {
for _, u := range res.Users {
auo := apiUserObject{
ID: u.Id,
Email: u.PrimaryEmail,
}
if u.Name != nil {
auo.DisplayName = u.Name.FullName
}
userLookup[u.Id] = auo
}
return nil
})
if err != nil {
return nil, nil, fmt.Errorf("google: error getting users: %w", err)
}
var users []*directory.User
for _, u := range userLookup {
groups := userIDToGroups[u.ID]
sort.Strings(groups)
users = append(users, &directory.User{
Id: u.ID,
GroupIds: groups,
DisplayName: u.DisplayName,
Email: u.Email,
})
}
sort.Slice(users, func(i, j int) bool {
return users[i].Id < users[j].Id
})
return groups, users, nil
}
func (p *Provider) getAPIClient(ctx context.Context) (*admin.Service, error) {
p.mu.RLock()
apiClient := p.apiClient
p.mu.RUnlock()
if apiClient != nil {
return apiClient, nil
}
p.mu.Lock()
defer p.mu.Unlock()
if p.apiClient != nil {
return p.apiClient, nil
}
apiCreds, err := json.Marshal(p.cfg.serviceAccount)
if err != nil {
return nil, fmt.Errorf("google: could not marshal service account json %w", err)
}
config, err := google.JWTConfigFromJSON(apiCreds, apiScopes...)
if err != nil {
return nil, fmt.Errorf("google: error reading jwt config: %w", err)
}
config.Subject = p.cfg.serviceAccount.ImpersonateUser
ts := config.TokenSource(ctx)
p.apiClient, err = admin.NewService(ctx,
option.WithTokenSource(ts),
option.WithEndpoint(p.cfg.url))
if err != nil {
return nil, fmt.Errorf("google: failed creating admin service %w", err)
}
return p.apiClient, nil
}
// A ServiceAccount is used to authenticate with the Google APIs.
//
// Google oauth fields are from https://github.com/golang/oauth2/blob/master/google/google.go#L99
type ServiceAccount struct {
Type string `json:"type"` // serviceAccountKey or userCredentialsKey
// Service Account fields
ClientEmail string `json:"client_email"`
PrivateKeyID string `json:"private_key_id"`
PrivateKey string `json:"private_key"`
TokenURL string `json:"token_uri"`
ProjectID string `json:"project_id"`
// User Credential fields
// (These typically come from gcloud auth.)
ClientSecret string `json:"client_secret"`
ClientID string `json:"client_id"`
RefreshToken string `json:"refresh_token"`
// The User to use for Admin Directory API calls
ImpersonateUser string `json:"impersonate_user"`
}
// ParseServiceAccount parses the service account in the config options.
func ParseServiceAccount(rawServiceAccount string) (*ServiceAccount, error) {
var serviceAccount ServiceAccount
if err := encoding.DecodeBase64OrJSON(rawServiceAccount, &serviceAccount); err != nil {
return nil, err
}
if serviceAccount.ImpersonateUser == "" {
return nil, fmt.Errorf("impersonate_user is required")
}
return &serviceAccount, nil
}
type apiUserObject struct {
ID string
DisplayName string
Email string
}
func isAccessDenied(err error) bool {
if err == nil {
return false
}
gerr := new(googleapi.Error)
if errors.As(err, &gerr) {
return gerr.Code == http.StatusForbidden
}
return false
}

View file

@ -1,260 +0,0 @@
package google
import (
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/go-chi/chi/v5"
"github.com/go-chi/chi/v5/middleware"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/pomerium/pomerium/internal/directory/directoryerrors"
"github.com/pomerium/pomerium/internal/testutil"
"github.com/pomerium/pomerium/pkg/grpc/directory"
)
var privateKey = `
-----BEGIN RSA PRIVATE KEY-----
MIIG4wIBAAKCAYEAnetGqPqS6dqYnV9S5S8gL34t7RRUMsf4prxIR+1PMv+bEqVH
pzXBbzgGKIbQbu+njoRzhQ95RbcEzZiVFivNggijkFoUNkFIjy42O/xXdTRTX3/u
4pxu1ctccqYnZwnry6E8ekQTVX7kgmVqgzIrY1Y6K3PVlhkgGDK/TStu+RIPoA73
vJUpJTFTw+6tgUSBmCkzctQsGFGUiGBRpqlxogEkImJYrcMUJkWhTopLy79+3OB5
eGGKWwA6P2ZhBB4RKHBWyWipsYxr889QNG6P3o+Be6lIzvcdiIIBYK8qWLmO45hR
xUGWSRK8sveAJO54t+wGE0dSPVKfoS4oNqGYdGhBzoTwsfZRvyvcWeQob4DNCqQa
n41XAuYGOG3X1PexdwGSwwqrq2tuG9d2AJ2NjG8nC9hjuuKDfBGTigTwzrwkrn+F
3o94NoglQgsXfZeWoBXR5HDaDTdqexSRK0OpSPbvzkn8QDUymdaw7nRS7kU2O7fa
W8kxiV8AVt2v/jYjAgMBAAECggGAS+6NE0Mo0Pki2Mi0+y4ls7BgNNbJhYFRthpi
RvN8WXE+B0EhquzWDbxKecIZBr6FOqnFQf2muja+QH1VckutjRDKVOZ7QXsygGYf
/cff5aM7U3gYTS4avQIDeb0axRioIElu4vtIsJtLFMfe5yaAZktXvPz9fiamn/wG
r/xqZ6ifir6nsC2okxGczWE+XCGsjpWA/321lhvj548os5JV6SfTUBUpvqNGVQC2
ByXIPDffsCTfQ1rjQ85gM4vuqiQqKn/KXRrMR1kRIOrglMJ6dllitsadkfWdPkVg
fHjM1KAnw/uob5kFhvqeEll9sESfCXttrb4XneKAsEck958ChSpU4csaBAfLFVYP
5xyIfoaQ+/CUjWI0u1lQbg6jZO59rYfdd5OlH+MyhHybuHR1a0G1izNfuG9WPOWI
aprNayH2Wxy9/ZvlrE5yTAeW9tof28hO6O7wBNOcJTrzztsN+V8pSAo0IE2r4D83
h978LneAwhC/8mVvzhd/y2t99vcBAoHBAMumCoHmHRAsYBApg2SHxPCTDi4fS1hC
IbcuzyvJLv214bHOsdgg2I72a1Q+bbrZAHgENVgSP9Sx1k8aXlvnGOEbpL8w3jRL
G4/qXzGrMBp3LCzupi3dI3yrrIMWb2C0goyHeAejzrfaM+uDYTGW4iqhA39zBj4o
zoydz3v0i8Yag7Df9MIwr34WD9Ng0oXh8XRCAYJmS1e43jnM+XcFdSfGVhKn9h1B
Cbv/hqUSv6baNloWLlPBffLII5bx633MMwKBwQDGg87fKEPz73pbp1kAOuXDpUnm
+OUFuf6Uqr/OqCorh8SOwxb2q3ScqyCWGVLOaVBcSV+zxIdcQzX2kjf1mj7PcQ6Q
2xfDIS1/xT+RiX9LO0kbkVDYcwcGeKVtmUwWyjauo96OB2r+SchTsNJpYOT3a/7r
JUKdbHFwsFwAx5q7r9mOh0BOybuXM6N7lUDBf4SgrhjnKRh1pME3R0JbJj9m8tZg
SsWlHcj04yAXJ7NGemiiYgeDZ4unsAfx7/sS/lECgcEAsL0Yj2XTQU8Ry9ULaDsA
az1k6BhWvnEea6lfOQPwGVY5WqQk6oqPB3vK6CEKAEgGRSJ53UZxSTlR4fLjg2UL
zYm9MATMQ5wPfpYMKcIFDGLy3sf7RwCNpMwk+tuEq+vdBPMo85BxflQMDVBHEM9+
1zpIG9sKxvWJVLY89LnmeHZYZi/nboTsOUQSVgPIkVLmx1vljXMT3jzd+FHxCx+c
bnmOB8DnMrpYJWV9SFP+KmNlGkf3ys65bPPPF1g7ZUDLAoHAaKHqtQa1Imr0NED1
kUB6AHArjrlbhXQuck+5f4R1jbIm8RR1Exj2AunT6Cl60t8Bg1MNRWRt8DxgwhD5
u9NMDezKP6GrWacwIytlQSGW3aFm/EfQs/WVG10V3LmzOEPnJI+s63GPfG6JT0tg
7DgtFxhuKaTfAri45iueoq6SqSCb7Brv01dTL/QA1E+r7RF4Z3S8HYM0qDVpvegq
Wn7DZlDSm7htioUzeZgJPwsm3BwC8Kv4x9MY8g6/cU8LKEyxAoHAWCaDpLIuQ51r
PeL+u/1cfNdi6OOrtZ6S95tu3Vv+mYzpCPnOpgPHFp3l+RGmLg56t7uvHFaFxOvB
EjPm4bVhnPSA7pl7ZHQXhinG9+4UgcejoCAJzfg05BI1tMbwFZ+C0tG/PNzBlaX+
IwkGO8VP/54N6wL1UqfZ8AKJFZW8G7W7KVkjqye1FS4oeDlJ197t/X+PMn5sFAc7
UVsDaSelBqpsfmetXSH8KC3XkbgCtHvgAnJDkGkp84VmJvMr5ukv
-----END RSA PRIVATE KEY-----
`
type M = map[string]interface{}
func newMockAPI(t *testing.T, srv *httptest.Server) http.Handler {
r := chi.NewRouter()
r.Use(middleware.Logger)
r.Post("/token", func(w http.ResponseWriter, r *http.Request) {
_ = json.NewEncoder(w).Encode(M{
"access_token": "ACCESSTOKEN",
"token_type": "Bearer",
"refresh_token": "REFRESHTOKEN",
})
})
r.Route("/admin/directory/v1", func(r chi.Router) {
r.Route("/groups", func(r chi.Router) {
r.Get("/", func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Query().Get("userKey") {
case "user1":
_ = json.NewEncoder(w).Encode(M{
"kind": "admin#directory#groups",
"groups": []M{
{"id": "group1"},
{"id": "group2"},
},
})
default:
_ = json.NewEncoder(w).Encode(M{
"kind": "admin#directory#groups",
"groups": []M{
{"id": "group1", "directMembersCount": "2"},
{"id": "group2"},
},
})
}
})
r.Get("/{groupKey}/members", func(w http.ResponseWriter, r *http.Request) {
switch chi.URLParam(r, "groupKey") {
case "group1":
_ = json.NewEncoder(w).Encode(M{
"members": []M{
{
"kind": "admin#directory#member",
"id": "inside-user1",
"email": "user1@inside.test",
"type": "USER",
},
{
"kind": "admin#directory#member",
"id": "outside-user1",
"email": "user1@outside.test",
"type": "USER",
},
},
})
}
})
})
r.Route("/users", func(r chi.Router) {
r.Get("/", func(w http.ResponseWriter, r *http.Request) {
_ = json.NewEncoder(w).Encode(M{
"kind": "admin#directory#users",
"users": []M{
{
"kind": "admin#directory#user",
"id": "inside-user1",
"primaryEmail": "user1@inside.test",
},
},
})
})
r.Get("/{user_id}", func(w http.ResponseWriter, r *http.Request) {
switch chi.URLParam(r, "user_id") {
case "inside-user1":
_ = json.NewEncoder(w).Encode(M{
"kind": "admin#directory#user",
"id": "inside-user1",
"name": M{
"fullName": "User 1",
},
"primaryEmail": "user1@inside.test",
})
case "outside-user1":
http.Error(w, "forbidden", http.StatusForbidden)
default:
http.Error(w, "not found", http.StatusNotFound)
}
})
})
})
return r
}
func TestProvider_User(t *testing.T) {
ctx, clearTimeout := context.WithTimeout(context.Background(), time.Second*30)
defer clearTimeout()
var mockAPI http.Handler
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
mockAPI.ServeHTTP(w, r)
}))
defer srv.Close()
mockAPI = newMockAPI(t, srv)
p := New(WithServiceAccount(&ServiceAccount{
Type: "service_account",
PrivateKey: privateKey,
TokenURL: srv.URL + "/token",
}), WithURL(srv.URL))
du, err := p.User(ctx, "inside-user1", "")
if !assert.NoError(t, err) {
return
}
assert.Equal(t, "inside-user1", du.Id)
assert.Equal(t, "user1@inside.test", du.Email)
assert.Equal(t, "User 1", du.DisplayName)
assert.Equal(t, []string{"group1", "group2"}, du.GroupIds)
du, err = p.User(ctx, "outside-user1", "")
if assert.ErrorIs(t, err, directoryerrors.ErrPreferExistingInformation) {
return
}
assert.Equal(t, "outside-user1", du.Id)
}
func TestProvider_UserGroups(t *testing.T) {
ctx, clearTimeout := context.WithTimeout(context.Background(), time.Second*30)
defer clearTimeout()
var mockAPI http.Handler
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
mockAPI.ServeHTTP(w, r)
}))
defer srv.Close()
mockAPI = newMockAPI(t, srv)
p := New(WithServiceAccount(&ServiceAccount{
Type: "service_account",
PrivateKey: privateKey,
TokenURL: srv.URL + "/token",
}), WithURL(srv.URL))
dgs, dus, err := p.UserGroups(ctx)
if !assert.NoError(t, err) {
return
}
assert.Equal(t, []*directory.Group{
{Id: "group1"},
}, dgs)
testutil.AssertProtoJSONEqual(t, `[
{ "id": "inside-user1", "email": "user1@inside.test", "groupIds": ["group1"] },
{ "id": "outside-user1", "email": "user1@outside.test", "groupIds": ["group1"] }
]`, dus)
}
func TestParseServiceAccount(t *testing.T) {
tests := []struct {
name string
rawServiceAccount string
serviceAccount *ServiceAccount
wantErr bool
}{
{
"json",
`{"impersonate_user":"IMPERSONATE_USER"}`,
&ServiceAccount{ImpersonateUser: "IMPERSONATE_USER"},
false,
},
{
"base64 json",
`eyJpbXBlcnNvbmF0ZV91c2VyIjoiSU1QRVJTT05BVEVfVVNFUiJ9`,
&ServiceAccount{ImpersonateUser: "IMPERSONATE_USER"},
false,
},
{
"empty",
"",
nil,
true,
},
{
"invalid",
"Zm9v---",
nil,
true,
},
}
for _, tc := range tests {
tc := tc
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
got, err := ParseServiceAccount(tc.rawServiceAccount)
require.True(t, (err != nil) == tc.wantErr)
assert.Equal(t, tc.serviceAccount, got)
})
}
}

View file

@ -1,421 +0,0 @@
// Package okta contains the Okta directory provider.
package okta
import (
"context"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"net/url"
"sort"
"strconv"
"time"
"github.com/rs/zerolog"
"github.com/tomnomnom/linkheader"
"github.com/pomerium/pomerium/internal/encoding"
"github.com/pomerium/pomerium/internal/httputil"
"github.com/pomerium/pomerium/internal/log"
"github.com/pomerium/pomerium/pkg/grpc/directory"
)
// Name is the provider name.
const Name = "okta"
const (
// Okta use ISO-8601, see https://developer.okta.com/docs/reference/api-overview/#media-types
filterDateFormat = "2006-01-02T15:04:05.999Z"
batchSize = 200
readLimit = 100 * 1024
httpSuccessClass = 2
)
// Errors.
var (
ErrAPIKeyRequired = errors.New("okta: api_key is required")
ErrServiceAccountNotDefined = errors.New("okta: service account not defined")
ErrProviderURLNotDefined = errors.New("okta: provider url not defined")
)
type config struct {
batchSize int
httpClient *http.Client
providerURL *url.URL
serviceAccount *ServiceAccount
}
// An Option configures the Okta Provider.
type Option func(cfg *config)
// WithBatchSize sets the batch size option.
func WithBatchSize(batchSize int) Option {
return func(cfg *config) {
cfg.batchSize = batchSize
}
}
// WithHTTPClient sets the http client option.
func WithHTTPClient(httpClient *http.Client) Option {
return func(cfg *config) {
cfg.httpClient = httputil.NewLoggingClient(httpClient, "okta_idp_client",
func(evt *zerolog.Event) *zerolog.Event {
return evt.Str("provider", "okta")
})
}
}
// WithProviderURL sets the provider URL option.
func WithProviderURL(uri *url.URL) Option {
return func(cfg *config) {
cfg.providerURL = uri
}
}
// WithServiceAccount sets the service account option.
func WithServiceAccount(serviceAccount *ServiceAccount) Option {
return func(cfg *config) {
cfg.serviceAccount = serviceAccount
}
}
func getConfig(options ...Option) *config {
cfg := new(config)
WithBatchSize(batchSize)(cfg)
WithHTTPClient(http.DefaultClient)(cfg)
for _, option := range options {
option(cfg)
}
return cfg
}
// A Provider is an Okta user group directory provider.
type Provider struct {
cfg *config
lastUpdated *time.Time
groups map[string]*directory.Group
}
// New creates a new Provider.
func New(options ...Option) *Provider {
return &Provider{
cfg: getConfig(options...),
groups: make(map[string]*directory.Group),
}
}
func withLog(ctx context.Context) context.Context {
return log.WithContext(ctx, func(c zerolog.Context) zerolog.Context {
return c.Str("service", "directory").Str("provider", "okta")
})
}
// User returns the user record for the given id.
func (p *Provider) User(ctx context.Context, userID, accessToken string) (*directory.User, error) {
ctx = withLog(ctx)
if p.cfg.serviceAccount == nil {
return nil, ErrServiceAccountNotDefined
}
du := &directory.User{
Id: userID,
}
au, err := p.getUser(ctx, userID)
if err != nil {
return nil, err
}
du.DisplayName = au.getDisplayName()
du.Email = au.Profile.Email
groups, err := p.listUserGroups(ctx, userID)
if err != nil {
return nil, err
}
for _, g := range groups {
du.GroupIds = append(du.GroupIds, g.ID)
}
sort.Strings(du.GroupIds)
return du, nil
}
// UserGroups fetches the groups of which the user is a member
// https://developer.okta.com/docs/reference/api/users/#get-user-s-groups
func (p *Provider) UserGroups(ctx context.Context) ([]*directory.Group, []*directory.User, error) {
ctx = withLog(ctx)
if p.cfg.serviceAccount == nil {
return nil, nil, ErrServiceAccountNotDefined
}
log.Info(ctx).Msg("getting user groups")
if p.cfg.providerURL == nil {
return nil, nil, ErrProviderURLNotDefined
}
groups, err := p.getGroups(ctx)
if err != nil {
return nil, nil, err
}
userLookup := map[string]apiUserObject{}
userIDToGroups := map[string][]string{}
for i := 0; i < len(groups); i++ {
group := groups[i]
users, err := p.getGroupMembers(ctx, group.Id)
// if we get a 404 on the member query, it means the group doesn't exist, so we should remove it from
// the cached lookup and the local groups list
var apiErr *APIError
if errors.As(err, &apiErr) && apiErr.HTTPStatusCode == http.StatusNotFound {
log.Debug(ctx).Str("group", group.Id).Msg("okta: group was removed")
delete(p.groups, group.Id)
groups = append(groups[:i], groups[i+1:]...)
i--
continue
}
if err != nil {
return nil, nil, err
}
for _, u := range users {
userIDToGroups[u.ID] = append(userIDToGroups[u.ID], group.Id)
userLookup[u.ID] = u
}
}
var users []*directory.User
for _, u := range userLookup {
groups := userIDToGroups[u.ID]
sort.Strings(groups)
users = append(users, &directory.User{
Id: u.ID,
GroupIds: groups,
DisplayName: u.getDisplayName(),
Email: u.Profile.Email,
})
}
sort.Slice(users, func(i, j int) bool {
return users[i].Id < users[j].Id
})
return groups, users, nil
}
func (p *Provider) getGroups(ctx context.Context) ([]*directory.Group, error) {
u := &url.URL{Path: "/api/v1/groups"}
q := u.Query()
q.Set("limit", strconv.Itoa(p.cfg.batchSize))
if p.lastUpdated != nil {
q.Set("filter", fmt.Sprintf(`lastUpdated gt "%[1]s" or lastMembershipUpdated gt "%[1]s"`, p.lastUpdated.UTC().Format(filterDateFormat)))
} else {
now := time.Now()
p.lastUpdated = &now
}
u.RawQuery = q.Encode()
groupURL := p.cfg.providerURL.ResolveReference(u).String()
for groupURL != "" {
var out []apiGroupObject
hdrs, err := p.apiGet(ctx, groupURL, &out)
if err != nil {
return nil, fmt.Errorf("okta: error querying for groups: %w", err)
}
for _, el := range out {
lu, _ := time.Parse(el.LastUpdated, filterDateFormat)
lmu, _ := time.Parse(el.LastMembershipUpdated, filterDateFormat)
if lu.After(*p.lastUpdated) {
p.lastUpdated = &lu
}
if lmu.After(*p.lastUpdated) {
p.lastUpdated = &lmu
}
p.groups[el.ID] = &directory.Group{
Id: el.ID,
Name: el.Profile.Name,
}
}
groupURL = getNextLink(hdrs)
}
groups := make([]*directory.Group, 0, len(p.groups))
for _, dg := range p.groups {
groups = append(groups, dg)
}
return groups, nil
}
func (p *Provider) getGroupMembers(ctx context.Context, groupID string) (users []apiUserObject, err error) {
usersURL := p.cfg.providerURL.ResolveReference(&url.URL{
Path: fmt.Sprintf("/api/v1/groups/%s/users", groupID),
RawQuery: fmt.Sprintf("limit=%d", p.cfg.batchSize),
}).String()
for usersURL != "" {
var out []apiUserObject
hdrs, err := p.apiGet(ctx, usersURL, &out)
if err != nil {
return nil, fmt.Errorf("okta: error querying for groups: %w", err)
}
users = append(users, out...)
usersURL = getNextLink(hdrs)
}
return users, nil
}
func (p *Provider) getUser(ctx context.Context, userID string) (*apiUserObject, error) {
apiURL := p.cfg.providerURL.ResolveReference(&url.URL{
Path: fmt.Sprintf("/api/v1/users/%s", userID),
}).String()
var out apiUserObject
_, err := p.apiGet(ctx, apiURL, &out)
if err != nil {
return nil, fmt.Errorf("okta: error querying for user: %w", err)
}
return &out, nil
}
func (p *Provider) listUserGroups(ctx context.Context, userID string) (groups []apiGroupObject, err error) {
apiURL := p.cfg.providerURL.ResolveReference(&url.URL{
Path: fmt.Sprintf("/api/v1/users/%s/groups", userID),
}).String()
for apiURL != "" {
var out []apiGroupObject
hdrs, err := p.apiGet(ctx, apiURL, &out)
if err != nil {
return nil, fmt.Errorf("okta: error querying for user groups: %w", err)
}
groups = append(groups, out...)
apiURL = getNextLink(hdrs)
}
return groups, nil
}
func (p *Provider) apiGet(ctx context.Context, uri string, out interface{}) (http.Header, error) {
req, err := http.NewRequestWithContext(ctx, "GET", uri, nil)
if err != nil {
return nil, fmt.Errorf("okta: failed to create HTTP request: %w", err)
}
req.Header.Set("Accept", "application/json")
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "SSWS "+p.cfg.serviceAccount.APIKey)
for {
res, err := p.cfg.httpClient.Do(req)
if err != nil {
return nil, err
}
defer res.Body.Close()
if res.StatusCode == http.StatusTooManyRequests {
limitReset, err := strconv.ParseInt(res.Header.Get("X-Rate-Limit-Reset"), 10, 64)
if err == nil {
time.Sleep(time.Until(time.Unix(limitReset, 0)))
}
continue
}
if res.StatusCode/100 != httpSuccessClass {
return nil, newAPIError(res)
}
if err := json.NewDecoder(res.Body).Decode(out); err != nil {
return nil, err
}
return res.Header, nil
}
}
func getNextLink(hdrs http.Header) string {
for _, link := range linkheader.ParseMultiple(hdrs.Values("Link")) {
if link.Rel == "next" {
return link.URL
}
}
return ""
}
// A ServiceAccount is used by the Okta provider to query the API.
type ServiceAccount struct {
APIKey string `json:"api_key"`
}
// ParseServiceAccount parses the service account in the config options.
func ParseServiceAccount(rawServiceAccount string) (*ServiceAccount, error) {
var serviceAccount ServiceAccount
err := encoding.DecodeBase64OrJSON(rawServiceAccount, &serviceAccount)
if err != nil {
bs, err := base64.StdEncoding.DecodeString(rawServiceAccount)
if err != nil {
return nil, err
}
serviceAccount.APIKey = string(bs)
}
if serviceAccount.APIKey == "" {
return nil, ErrAPIKeyRequired
}
return &serviceAccount, nil
}
// An APIError is an error from the okta API.
type APIError struct {
HTTPStatusCode int
Body string
ErrorCode string `json:"errorCode"`
ErrorSummary string `json:"errorSummary"`
ErrorLink string `json:"errorLink"`
ErrorID string `json:"errorId"`
ErrorCauses []string `json:"errorCauses"`
}
func newAPIError(res *http.Response) error {
if res == nil {
return nil
}
buf, _ := io.ReadAll(io.LimitReader(res.Body, readLimit)) // limit to 100kb
err := &APIError{
HTTPStatusCode: res.StatusCode,
Body: string(buf),
}
_ = json.Unmarshal(buf, err)
return err
}
func (err *APIError) Error() string {
return fmt.Sprintf("okta: error querying API, status_code=%d: %s", err.HTTPStatusCode, err.Body)
}
type (
apiGroupObject struct {
ID string `json:"id"`
Profile struct {
Name string `json:"name"`
} `json:"profile"`
LastUpdated string `json:"lastUpdated"`
LastMembershipUpdated string `json:"lastMembershipUpdated"`
}
apiUserObject struct {
ID string `json:"id"`
Profile struct {
FirstName string `json:"firstName"`
LastName string `json:"lastName"`
Email string `json:"email"`
} `json:"profile"`
}
)
func (obj *apiUserObject) getDisplayName() string {
return obj.Profile.FirstName + " " + obj.Profile.LastName
}

View file

@ -1,361 +0,0 @@
package okta
import (
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"net/url"
"sort"
"strings"
"testing"
"github.com/go-chi/chi/v5"
"github.com/go-chi/chi/v5/middleware"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/tomnomnom/linkheader"
"github.com/pomerium/pomerium/internal/testutil"
"github.com/pomerium/pomerium/pkg/grpc/directory"
)
type M = map[string]interface{}
func newMockOkta(srv *httptest.Server, userEmailToGroups map[string][]string) http.Handler {
getAllGroups := func() map[string]struct{} {
allGroups := map[string]struct{}{}
for _, groups := range userEmailToGroups {
for _, group := range groups {
allGroups[group] = struct{}{}
}
}
return allGroups
}
r := chi.NewRouter()
r.Use(middleware.Logger)
r.Use(func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Header.Get("Authorization") != "SSWS APITOKEN" {
http.Error(w, "forbidden", http.StatusForbidden)
return
}
next.ServeHTTP(w, r)
})
})
r.Route("/api/v1", func(r chi.Router) {
r.Route("/groups", func(r chi.Router) {
r.Get("/", func(w http.ResponseWriter, r *http.Request) {
lastUpdated := strings.Contains(r.URL.Query().Get("filter"), "lastUpdated ")
var groups []string
for group := range getAllGroups() {
if lastUpdated && group != "user-updated" {
continue
}
if !lastUpdated && group == "user-updated" {
continue
}
groups = append(groups, group)
}
sort.Strings(groups)
var result []M
found := r.URL.Query().Get("after") == ""
for i := range groups {
if found {
result = append(result, M{
"id": groups[i],
"profile": M{
"name": groups[i] + "-name",
},
})
break
}
found = r.URL.Query().Get("after") == groups[i]
}
if len(result) > 0 {
nextURL := mustParseURL(srv.URL).ResolveReference(r.URL)
q := nextURL.Query()
q.Set("after", result[0]["id"].(string))
nextURL.RawQuery = q.Encode()
w.Header().Set("Link", linkheader.Link{
URL: nextURL.String(),
Rel: "next",
}.String())
}
_ = json.NewEncoder(w).Encode(result)
})
r.Get("/{group}/users", func(w http.ResponseWriter, r *http.Request) {
group := chi.URLParam(r, "group")
if _, ok := getAllGroups()[group]; !ok {
w.WriteHeader(http.StatusNotFound)
w.Write([]byte(`{
"errorCode": "E0000007",
"errorSummary": "Not found: {0}",
"errorLink": E0000007,
"errorId": "sampleE7p0NECLNnSN5z_xLNT",
"errorCauses": []
}`))
return
}
var result []M
for email, groups := range userEmailToGroups {
for _, g := range groups {
if group == g {
result = append(result, M{
"id": email,
"profile": M{
"email": email,
"firstName": "first",
"lastName": "last",
},
})
}
}
}
sort.Slice(result, func(i, j int) bool {
return result[i]["id"].(string) < result[j]["id"].(string)
})
_ = json.NewEncoder(w).Encode(result)
})
})
r.Route("/users", func(r chi.Router) {
r.Get("/{user_id}/groups", func(w http.ResponseWriter, r *http.Request) {
var groups []apiGroupObject
for _, nm := range userEmailToGroups[chi.URLParam(r, "user_id")] {
obj := apiGroupObject{
ID: nm,
}
obj.Profile.Name = nm
groups = append(groups, obj)
}
_ = json.NewEncoder(w).Encode(groups)
})
r.Get("/{user_id}", func(w http.ResponseWriter, r *http.Request) {
user := apiUserObject{
ID: chi.URLParam(r, "user_id"),
}
user.Profile.Email = chi.URLParam(r, "user_id")
user.Profile.FirstName = "first"
user.Profile.LastName = "last"
_ = json.NewEncoder(w).Encode(user)
})
})
})
return r
}
func TestProvider_User(t *testing.T) {
var mockOkta http.Handler
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
mockOkta.ServeHTTP(w, r)
}))
defer srv.Close()
mockOkta = newMockOkta(srv, map[string][]string{
"a@example.com": {"user", "admin"},
"b@example.com": {"user", "test"},
"c@example.com": {"user"},
})
p := New(
WithServiceAccount(&ServiceAccount{APIKey: "APITOKEN"}),
WithProviderURL(mustParseURL(srv.URL)),
)
user, err := p.User(context.Background(), "a@example.com", "")
if !assert.NoError(t, err) {
return
}
testutil.AssertProtoJSONEqual(t, `{
"id": "a@example.com",
"groupIds": ["admin","user"],
"displayName": "first last",
"email": "a@example.com"
}`, user)
}
func TestProvider_UserGroups(t *testing.T) {
var mockOkta http.Handler
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
mockOkta.ServeHTTP(w, r)
}))
defer srv.Close()
mockOkta = newMockOkta(srv, map[string][]string{
"a@example.com": {"user", "admin"},
"b@example.com": {"user", "test"},
"c@example.com": {"user"},
})
p := New(
WithServiceAccount(&ServiceAccount{APIKey: "APITOKEN"}),
WithProviderURL(mustParseURL(srv.URL)),
)
groups, users, err := p.UserGroups(context.Background())
assert.NoError(t, err)
assert.Equal(t, []*directory.User{
{
Id: "a@example.com",
GroupIds: []string{"admin", "user"},
DisplayName: "first last",
Email: "a@example.com",
},
{
Id: "b@example.com",
GroupIds: []string{"test", "user"},
DisplayName: "first last",
Email: "b@example.com",
},
{
Id: "c@example.com",
GroupIds: []string{"user"},
DisplayName: "first last",
Email: "c@example.com",
},
}, users)
assert.Len(t, groups, 3)
}
func TestProvider_UserGroupsQueryUpdated(t *testing.T) {
var mockOkta http.Handler
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
mockOkta.ServeHTTP(w, r)
}))
defer srv.Close()
userEmailToGroups := map[string][]string{
"a@example.com": {"user", "admin"},
"b@example.com": {"user", "test"},
"c@example.com": {"user"},
"updated@example.com": {"user-updated"},
}
mockOkta = newMockOkta(srv, userEmailToGroups)
p := New(
WithServiceAccount(&ServiceAccount{APIKey: "APITOKEN"}),
WithProviderURL(mustParseURL(srv.URL)),
)
groups, users, err := p.UserGroups(context.Background())
assert.NoError(t, err)
assert.Equal(t, []*directory.User{
{
Id: "a@example.com",
GroupIds: []string{"admin", "user"},
DisplayName: "first last",
Email: "a@example.com",
},
{
Id: "b@example.com",
GroupIds: []string{"test", "user"},
DisplayName: "first last",
Email: "b@example.com",
},
{
Id: "c@example.com",
GroupIds: []string{"user"},
DisplayName: "first last",
Email: "c@example.com",
},
}, users)
assert.Len(t, groups, 3)
groups, users, err = p.UserGroups(context.Background())
assert.NoError(t, err)
assert.Equal(t, []*directory.User{
{
Id: "a@example.com",
GroupIds: []string{"admin", "user"},
DisplayName: "first last",
Email: "a@example.com",
},
{
Id: "b@example.com",
GroupIds: []string{"test", "user"},
DisplayName: "first last",
Email: "b@example.com",
},
{
Id: "c@example.com",
GroupIds: []string{"user"},
DisplayName: "first last",
Email: "c@example.com",
},
{
Id: "updated@example.com",
GroupIds: []string{"user-updated"},
DisplayName: "first last",
Email: "updated@example.com",
},
}, users)
assert.Len(t, groups, 4)
userEmailToGroups["b@example.com"] = []string{"user"}
groups, users, err = p.UserGroups(context.Background())
assert.NoError(t, err)
assert.Equal(t, []*directory.User{
{
Id: "a@example.com",
GroupIds: []string{"admin", "user"},
DisplayName: "first last",
Email: "a@example.com",
},
{
Id: "b@example.com",
GroupIds: []string{"user"},
DisplayName: "first last",
Email: "b@example.com",
},
{
Id: "c@example.com",
GroupIds: []string{"user"},
DisplayName: "first last",
Email: "c@example.com",
},
{
Id: "updated@example.com",
GroupIds: []string{"user-updated"},
DisplayName: "first last",
Email: "updated@example.com",
},
}, users)
assert.Len(t, groups, 3)
}
func mustParseURL(rawurl string) *url.URL {
u, err := url.Parse(rawurl)
if err != nil {
panic(err)
}
return u
}
func TestParseServiceAccount(t *testing.T) {
tests := []struct {
name string
rawServiceAccount string
apiKey string
wantErr bool
}{
{"json", `{"api_key": "foo"}`, "foo", false},
{"base64 json", "ewogICAgImFwaV9rZXkiOiAiZm9vIgp9Cg==", "foo", false},
{"base64 value", "Zm9v", "foo", false},
{"empty", "", "", true},
{"invalid", "Zm9v---", "", true},
}
for _, tc := range tests {
tc := tc
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
got, err := ParseServiceAccount(tc.rawServiceAccount)
require.True(t, (err != nil) == tc.wantErr)
if tc.apiKey != "" {
assert.Equal(t, tc.apiKey, got.APIKey)
}
})
}
}

View file

@ -1,360 +0,0 @@
// Package onelogin contains the onelogin directory provider.
package onelogin
import (
"context"
"encoding/json"
"fmt"
"net/http"
"net/url"
"sort"
"strconv"
"strings"
"sync"
"github.com/rs/zerolog"
"golang.org/x/oauth2"
"github.com/pomerium/pomerium/internal/encoding"
"github.com/pomerium/pomerium/internal/httputil"
"github.com/pomerium/pomerium/internal/log"
"github.com/pomerium/pomerium/pkg/grpc/directory"
)
// Name is the provider name.
const Name = "onelogin"
type config struct {
apiURL *url.URL
batchSize int
serviceAccount *ServiceAccount
httpClient *http.Client
}
// An Option updates the onelogin configuration.
type Option func(*config)
// WithBatchSize sets the batch size option.
func WithBatchSize(batchSize int) Option {
return func(cfg *config) {
cfg.batchSize = batchSize
}
}
// WithHTTPClient sets the http client option.
func WithHTTPClient(httpClient *http.Client) Option {
return func(cfg *config) {
cfg.httpClient = httputil.NewLoggingClient(httpClient, "onelogin_idp_client",
func(evt *zerolog.Event) *zerolog.Event {
return evt.Str("provider", "onelogin")
})
}
}
// WithServiceAccount sets the service account in the config.
func WithServiceAccount(serviceAccount *ServiceAccount) Option {
return func(cfg *config) {
cfg.serviceAccount = serviceAccount
}
}
// WithURL sets the api url in the config.
func WithURL(apiURL *url.URL) Option {
return func(cfg *config) {
cfg.apiURL = apiURL
}
}
func getConfig(options ...Option) *config {
cfg := new(config)
WithBatchSize(20)(cfg)
WithHTTPClient(http.DefaultClient)(cfg)
WithURL(&url.URL{
Scheme: "https",
Host: "api.us.onelogin.com",
})(cfg)
for _, option := range options {
option(cfg)
}
return cfg
}
// The Provider retrieves users and groups from onelogin.
type Provider struct {
cfg *config
mu sync.RWMutex
token *oauth2.Token
}
// New creates a new Provider.
func New(options ...Option) *Provider {
cfg := getConfig(options...)
return &Provider{
cfg: cfg,
}
}
func withLog(ctx context.Context) context.Context {
return log.WithContext(ctx, func(c zerolog.Context) zerolog.Context {
return c.Str("service", "directory").Str("provider", "onelogin")
})
}
// User returns the user record for the given id.
func (p *Provider) User(ctx context.Context, userID, accessToken string) (*directory.User, error) {
if p.cfg.serviceAccount == nil {
return nil, fmt.Errorf("onelogin: service account not defined")
}
du := &directory.User{
Id: userID,
}
ctx = withLog(ctx)
token, err := p.getToken(ctx)
if err != nil {
return nil, err
}
au, err := p.getUser(ctx, token.AccessToken, userID)
if err != nil {
return nil, err
}
du.DisplayName = au.getDisplayName()
du.Email = au.Email
du.GroupIds = []string{strconv.Itoa(au.GroupID)}
return du, nil
}
// UserGroups gets the directory user groups for onelogin.
func (p *Provider) UserGroups(ctx context.Context) ([]*directory.Group, []*directory.User, error) {
if p.cfg.serviceAccount == nil {
return nil, nil, fmt.Errorf("onelogin: service account not defined")
}
ctx = withLog(ctx)
log.Info(ctx).Msg("getting user groups")
token, err := p.getToken(ctx)
if err != nil {
return nil, nil, err
}
groups, err := p.listGroups(ctx, token.AccessToken)
if err != nil {
return nil, nil, err
}
apiUsers, err := p.listUsers(ctx, token.AccessToken)
if err != nil {
return nil, nil, err
}
var users []*directory.User
for _, u := range apiUsers {
users = append(users, &directory.User{
Id: strconv.Itoa(u.ID),
GroupIds: []string{strconv.Itoa(u.GroupID)},
DisplayName: u.FirstName + " " + u.LastName,
Email: u.Email,
})
}
sort.Slice(users, func(i, j int) bool {
return users[i].Id < users[j].Id
})
return groups, users, nil
}
func (p *Provider) listGroups(ctx context.Context, accessToken string) ([]*directory.Group, error) {
var groups []*directory.Group
apiURL := p.cfg.apiURL.ResolveReference(&url.URL{
Path: "/api/1/groups",
RawQuery: fmt.Sprintf("limit=%d", p.cfg.batchSize),
}).String()
for apiURL != "" {
var result []struct {
ID int `json:"id"`
Name string `json:"name"`
}
nextLink, err := p.apiGet(ctx, accessToken, apiURL, &result)
if err != nil {
return nil, fmt.Errorf("onelogin: listing groups: %w", err)
}
for _, r := range result {
groups = append(groups, &directory.Group{
Id: strconv.Itoa(r.ID),
Name: r.Name,
})
}
apiURL = nextLink
}
return groups, nil
}
func (p *Provider) getUser(ctx context.Context, accessToken string, userID string) (*apiUserObject, error) {
apiURL := p.cfg.apiURL.ResolveReference(&url.URL{
Path: fmt.Sprintf("/api/1/users/%s", userID),
}).String()
var out []apiUserObject
_, err := p.apiGet(ctx, accessToken, apiURL, &out)
if err != nil {
return nil, fmt.Errorf("onelogin: error getting user: %w", err)
}
if len(out) == 0 {
return nil, fmt.Errorf("onelogin: user not found")
}
return &out[0], nil
}
func (p *Provider) listUsers(ctx context.Context, accessToken string) ([]apiUserObject, error) {
var users []apiUserObject
apiURL := p.cfg.apiURL.ResolveReference(&url.URL{
Path: "/api/1/users",
RawQuery: fmt.Sprintf("limit=%d", p.cfg.batchSize),
}).String()
for apiURL != "" {
var result []apiUserObject
nextLink, err := p.apiGet(ctx, accessToken, apiURL, &result)
if err != nil {
return nil, fmt.Errorf("onelogin: listing users: %w", err)
}
users = append(users, result...)
apiURL = nextLink
}
return users, nil
}
func (p *Provider) apiGet(ctx context.Context, accessToken string, uri string, out interface{}) (nextLink string, err error) {
req, err := http.NewRequestWithContext(ctx, "GET", uri, nil)
if err != nil {
return "", err
}
req.Header.Set("Authorization", fmt.Sprintf("bearer:%s", accessToken))
req.Header.Set("Content-Type", "application/json")
res, err := p.cfg.httpClient.Do(req)
if err != nil {
return "", err
}
defer res.Body.Close()
if res.StatusCode/100 != 2 {
return "", fmt.Errorf("onelogin: error querying api: %s", res.Status)
}
var result struct {
Pagination struct {
NextLink string `json:"next_link"`
}
Data json.RawMessage `json:"data"`
}
err = json.NewDecoder(res.Body).Decode(&result)
if err != nil {
return "", err
}
log.Info(ctx).
Str("url", uri).
Interface("result", result).
Msg("api request")
err = json.Unmarshal(result.Data, out)
if err != nil {
return "", err
}
return result.Pagination.NextLink, nil
}
func (p *Provider) getToken(ctx context.Context) (*oauth2.Token, error) {
p.mu.RLock()
token := p.token
p.mu.RUnlock()
if token != nil && token.Valid() {
return token, nil
}
p.mu.Lock()
defer p.mu.Unlock()
token = p.token
if token != nil && token.Valid() {
return token, nil
}
apiURL := p.cfg.apiURL.ResolveReference(&url.URL{
Path: "/auth/oauth2/v2/token",
})
req, err := http.NewRequestWithContext(ctx, "POST", apiURL.String(), strings.NewReader(`{ "grant_type": "client_credentials" }`))
if err != nil {
return nil, err
}
req.Header.Set("Authorization", fmt.Sprintf("client_id:%s, client_secret:%s",
p.cfg.serviceAccount.ClientID, p.cfg.serviceAccount.ClientSecret))
req.Header.Set("Content-Type", "application/json")
res, err := p.cfg.httpClient.Do(req)
if err != nil {
return nil, err
}
defer res.Body.Close()
if res.StatusCode/100 != 2 {
return nil, fmt.Errorf("onelogin: error querying oauth2 token: %s", res.Status)
}
err = json.NewDecoder(res.Body).Decode(&token)
if err != nil {
return nil, err
}
p.token = token
return p.token, nil
}
// A ServiceAccount is used by the OneLogin provider to query the API.
type ServiceAccount struct {
ClientID string `json:"client_id"`
ClientSecret string `json:"client_secret"`
}
// ParseServiceAccount parses the service account in the config options.
func ParseServiceAccount(rawServiceAccount string) (*ServiceAccount, error) {
var serviceAccount ServiceAccount
if err := encoding.DecodeBase64OrJSON(rawServiceAccount, &serviceAccount); err != nil {
return nil, err
}
if serviceAccount.ClientID == "" {
return nil, fmt.Errorf("client_id is required")
}
if serviceAccount.ClientSecret == "" {
return nil, fmt.Errorf("client_secret is required")
}
return &serviceAccount, nil
}
type apiUserObject struct {
ID int `json:"id"`
GroupID int `json:"group_id"`
Email string `json:"email"`
FirstName string `json:"firstname"`
LastName string `json:"lastname"`
}
func (obj *apiUserObject) getDisplayName() string {
return obj.FirstName + " " + obj.LastName
}

View file

@ -1,270 +0,0 @@
package onelogin
import (
"context"
"encoding/json"
"fmt"
"net/http"
"net/http/httptest"
"net/url"
"sort"
"strconv"
"testing"
"time"
"github.com/go-chi/chi/v5"
"github.com/go-chi/chi/v5/middleware"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/pomerium/pomerium/internal/testutil"
)
type M = map[string]interface{}
func newMockAPI(srv *httptest.Server, userIDToGroupName map[int]string) http.Handler {
lookup := map[string]struct{}{}
for _, group := range userIDToGroupName {
lookup[group] = struct{}{}
}
var allGroups []string
for groupName := range lookup {
allGroups = append(allGroups, groupName)
}
sort.Strings(allGroups)
var allUserIDs []int
for userID := range userIDToGroupName {
allUserIDs = append(allUserIDs, userID)
}
sort.Ints(allUserIDs)
r := chi.NewRouter()
r.Use(middleware.Logger)
r.Post("/auth/oauth2/v2/token", func(w http.ResponseWriter, r *http.Request) {
if r.Header.Get("Authorization") != "client_id:CLIENTID, client_secret:CLIENTSECRET" {
http.Error(w, "forbidden", http.StatusForbidden)
return
}
var request struct {
GrantType string `json:"grant_type"`
}
_ = json.NewDecoder(r.Body).Decode(&request)
if request.GrantType != "client_credentials" {
http.Error(w, "invalid grant_type", http.StatusBadRequest)
return
}
_ = json.NewEncoder(w).Encode(M{
"access_token": "ACCESSTOKEN",
"created_at": time.Now().Format(time.RFC3339),
"expires_in": 360000,
"refresh_token": "REFRESHTOKEN",
"token_type": "bearer",
})
})
r.Route("/api/1", func(r chi.Router) {
r.Use(func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Header.Get("Authorization") != "bearer:ACCESSTOKEN" {
http.Error(w, "forbidden", http.StatusForbidden)
return
}
next.ServeHTTP(w, r)
})
})
r.Get("/groups", func(w http.ResponseWriter, r *http.Request) {
var result struct {
Pagination struct {
NextLink string `json:"next_link"`
} `json:"pagination"`
Data []M `json:"data"`
}
found := r.URL.Query().Get("after") == ""
for i := range allGroups {
if found {
result.Data = append(result.Data, M{
"id": i,
"name": allGroups[i],
})
break
}
found = r.URL.Query().Get("after") == fmt.Sprint(i)
}
if len(result.Data) > 0 {
nextURL := mustParseURL(srv.URL).ResolveReference(r.URL)
q := nextURL.Query()
q.Set("after", fmt.Sprint(result.Data[0]["id"]))
nextURL.RawQuery = q.Encode()
result.Pagination.NextLink = nextURL.String()
}
_ = json.NewEncoder(w).Encode(result)
})
r.Get("/users/{user_id}", func(w http.ResponseWriter, r *http.Request) {
userIDToGroupID := map[int]int{}
for userID, groupName := range userIDToGroupName {
for id, n := range allGroups {
if groupName == n {
userIDToGroupID[userID] = id
}
}
}
userID, _ := strconv.Atoi(chi.URLParam(r, "user_id"))
_ = json.NewEncoder(w).Encode(M{
"data": []M{{
"id": userID,
"email": userIDToGroupName[userID] + "@example.com",
"group_id": userIDToGroupID[userID],
"firstname": "User",
"lastname": fmt.Sprint(userID),
}},
})
})
r.Get("/users", func(w http.ResponseWriter, r *http.Request) {
userIDToGroupID := map[int]int{}
for userID, groupName := range userIDToGroupName {
for id, n := range allGroups {
if groupName == n {
userIDToGroupID[userID] = id
}
}
}
var result []M
for _, userID := range allUserIDs {
result = append(result, M{
"id": userID,
"email": userIDToGroupName[userID] + "@example.com",
"group_id": userIDToGroupID[userID],
"firstname": "User",
"lastname": fmt.Sprint(userID),
})
}
_ = json.NewEncoder(w).Encode(M{
"data": result,
})
})
})
return r
}
func TestProvider_User(t *testing.T) {
var mockAPI http.Handler
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
mockAPI.ServeHTTP(w, r)
}))
defer srv.Close()
mockAPI = newMockAPI(srv, map[int]string{
111: "admin",
222: "test",
333: "user",
})
p := New(
WithServiceAccount(&ServiceAccount{
ClientID: "CLIENTID",
ClientSecret: "CLIENTSECRET",
}),
WithURL(mustParseURL(srv.URL)),
)
user, err := p.User(context.Background(), "111", "ACCESSTOKEN")
if !assert.NoError(t, err) {
return
}
testutil.AssertProtoJSONEqual(t, `{
"id": "111",
"groupIds": ["0"],
"displayName": "User 111",
"email": "admin@example.com"
}`, user)
}
func TestProvider_UserGroups(t *testing.T) {
var mockAPI http.Handler
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
mockAPI.ServeHTTP(w, r)
}))
defer srv.Close()
mockAPI = newMockAPI(srv, map[int]string{
111: "admin",
222: "test",
333: "user",
})
p := New(
WithServiceAccount(&ServiceAccount{
ClientID: "CLIENTID",
ClientSecret: "CLIENTSECRET",
}),
WithURL(mustParseURL(srv.URL)),
)
groups, users, err := p.UserGroups(context.Background())
assert.NoError(t, err)
testutil.AssertProtoJSONEqual(t, `[
{ "id": "111", "groupIds": ["0"], "displayName": "User 111", "email": "admin@example.com" },
{ "id": "222", "groupIds": ["1"], "displayName": "User 222", "email": "test@example.com" },
{ "id": "333", "groupIds": ["2"], "displayName": "User 333", "email": "user@example.com" }
]`, users)
testutil.AssertProtoJSONEqual(t, `[
{ "id": "0", "name": "admin" },
{ "id": "1", "name": "test" },
{ "id": "2", "name": "user" }
]`, groups)
}
func TestParseServiceAccount(t *testing.T) {
tests := []struct {
name string
rawServiceAccount string
serviceAccount *ServiceAccount
wantErr bool
}{
{
"json",
`{"client_id":"CLIENT_ID","client_secret":"CLIENT_SECRET"}`,
&ServiceAccount{ClientID: "CLIENT_ID", ClientSecret: "CLIENT_SECRET"},
false,
},
{
"base64 json",
`eyJjbGllbnRfaWQiOiJDTElFTlRfSUQiLCJjbGllbnRfc2VjcmV0IjoiQ0xJRU5UX1NFQ1JFVCJ9`,
&ServiceAccount{ClientID: "CLIENT_ID", ClientSecret: "CLIENT_SECRET"},
false,
},
{
"empty",
"",
nil,
true,
},
{
"invalid",
"Zm9v---",
nil,
true,
},
}
for _, tc := range tests {
tc := tc
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
got, err := ParseServiceAccount(tc.rawServiceAccount)
require.True(t, (err != nil) == tc.wantErr)
assert.Equal(t, tc.serviceAccount, got)
})
}
}
func mustParseURL(rawurl string) *url.URL {
u, err := url.Parse(rawurl)
if err != nil {
panic(err)
}
return u
}

View file

@ -1,174 +0,0 @@
package ping
import (
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"net/url"
"strings"
)
var errNotFound = errors.New("ping: user not found")
type (
apiGroup struct {
ID string `json:"id"`
Name string `json:"name"`
}
apiUser struct {
ID string `json:"id"`
Email string `json:"email"`
Name apiUserName `json:"name"`
Username string `json:"username"`
MemberOfGroupIDs []string `json:"memberOfGroupIDs"`
}
apiUserName struct {
Given string `json:"given"`
Middle string `json:"middle"`
Family string `json:"family"`
}
)
func (au apiUser) getDisplayName() string {
var parts []string
if au.Name.Given != "" {
parts = append(parts, au.Name.Given)
}
if au.Name.Middle != "" {
parts = append(parts, au.Name.Middle)
}
if au.Name.Family != "" {
parts = append(parts, au.Name.Family)
}
if len(parts) == 0 {
parts = append(parts, au.Username)
}
return strings.Join(parts, " ")
}
func getAllGroups(ctx context.Context, client *http.Client, apiURL *url.URL, envID string) ([]apiGroup, error) {
nextURL := apiURL.ResolveReference(&url.URL{
Path: fmt.Sprintf("/v1/environments/%s/groups", url.PathEscape(envID)),
}).String()
var apiGroups []apiGroup
err := batchAPIRequest(ctx, client, nextURL, func(body []byte) error {
var apiResponse struct {
Embedded struct {
Groups []apiGroup `json:"groups"`
} `json:"_embedded"`
}
err := json.Unmarshal(body, &apiResponse)
if err != nil {
return fmt.Errorf("ping: error decoding API response: %w", err)
}
apiGroups = append(apiGroups, apiResponse.Embedded.Groups...)
return nil
})
return apiGroups, err
}
func getGroupUsers(ctx context.Context, client *http.Client, apiURL *url.URL, envID, groupID string) ([]apiUser, error) {
nextURL := apiURL.ResolveReference(&url.URL{
Path: fmt.Sprintf("/v1/environments/%s/users", url.PathEscape(envID)),
RawQuery: (&url.Values{
"filter": {fmt.Sprintf(`memberOfGroups[id eq "%s"]`, groupID)},
}).Encode(),
}).String()
var apiUsers []apiUser
err := batchAPIRequest(ctx, client, nextURL, func(body []byte) error {
var apiResponse struct {
Embedded struct {
Users []apiUser `json:"users"`
} `json:"_embedded"`
}
err := json.Unmarshal(body, &apiResponse)
if err != nil {
return fmt.Errorf("ping: error decoding API response: %w", err)
}
apiUsers = append(apiUsers, apiResponse.Embedded.Users...)
return nil
})
return apiUsers, err
}
func getUser(ctx context.Context, client *http.Client, apiURL *url.URL, envID, userID string) (*apiUser, error) {
nextURL := apiURL.ResolveReference(&url.URL{
Path: fmt.Sprintf("/v1/environments/%s/users/%s", url.PathEscape(envID), url.PathEscape(userID)),
RawQuery: (&url.Values{
"include": {"memberOfGroupIDs"},
}).Encode(),
}).String()
req, err := http.NewRequestWithContext(ctx, "GET", nextURL, nil)
if err != nil {
return nil, fmt.Errorf("ping: error building API request: %w", err)
}
res, err := client.Do(req)
if err != nil {
return nil, fmt.Errorf("ping: error making API request: %w", err)
}
body, err := io.ReadAll(res.Body)
if err != nil {
return nil, fmt.Errorf("ping: error reading API response: %w", err)
}
_ = res.Body.Close()
if res.StatusCode == http.StatusNotFound {
return nil, errNotFound
} else if res.StatusCode/100 != 2 {
return nil, fmt.Errorf("ping: unexpected status code: %d", res.StatusCode)
}
var u apiUser
err = json.Unmarshal(body, &u)
if err != nil {
return nil, fmt.Errorf("ping: error decoding API response: %w", err)
}
return &u, nil
}
func batchAPIRequest(ctx context.Context, client *http.Client, nextURL string, callback func(body []byte) error) error {
for nextURL != "" {
req, err := http.NewRequestWithContext(ctx, "GET", nextURL, nil)
if err != nil {
return fmt.Errorf("ping: error building API request: %w", err)
}
res, err := client.Do(req)
if err != nil {
return fmt.Errorf("ping: error making API request: %w", err)
}
bs, err := io.ReadAll(res.Body)
if err != nil {
return fmt.Errorf("ping: error reading API response: %w", err)
}
_ = res.Body.Close()
if res.StatusCode/100 != 2 {
return fmt.Errorf("ping: unexpected status code: %d", res.StatusCode)
}
var apiResponse struct {
Links struct {
Next struct {
HREF string `json:"href"`
} `json:"next"`
} `json:"_links"`
}
err = json.Unmarshal(bs, &apiResponse)
if err != nil {
return fmt.Errorf("ping: error decoding API response: %w", err)
}
err = callback(bs)
if err != nil {
return err
}
nextURL = apiResponse.Links.Next.HREF
}
return nil
}

View file

@ -1,116 +0,0 @@
package ping
import (
"fmt"
"net/http"
"net/url"
"strings"
"github.com/rs/zerolog"
"github.com/pomerium/pomerium/internal/encoding"
"github.com/pomerium/pomerium/internal/httputil"
)
type config struct {
authURL *url.URL
apiURL *url.URL
serviceAccount *ServiceAccount
httpClient *http.Client
environmentID string
}
// An Option updates the Ping configuration.
type Option func(*config)
// WithAPIURL sets the api url in the config.
func WithAPIURL(apiURL *url.URL) Option {
return func(cfg *config) {
cfg.apiURL = apiURL
}
}
// WithAuthURL sets the auth url in the config.
func WithAuthURL(authURL *url.URL) Option {
return func(cfg *config) {
cfg.authURL = authURL
}
}
// WithEnvironmentID sets the environment ID in the config.
func WithEnvironmentID(environmentID string) Option {
return func(cfg *config) {
cfg.environmentID = environmentID
}
}
// WithHTTPClient sets the http client option.
func WithHTTPClient(httpClient *http.Client) Option {
return func(cfg *config) {
cfg.httpClient = httputil.NewLoggingClient(httpClient, "ping_idp_client",
func(evt *zerolog.Event) *zerolog.Event {
return evt.Str("provider", "ping")
})
}
}
// WithProviderURL sets the environment ID from the provider URL set in the config.
func WithProviderURL(providerURL *url.URL) Option {
// provider URL will be https://auth.pingone.com/{ENVIRONMENT_ID}/as
if providerURL == nil {
return func(cfg *config) {}
}
parts := strings.Split(providerURL.Path, "/")
if len(parts) < 1 {
return func(cfg *config) {}
}
return WithEnvironmentID(parts[1])
}
// WithServiceAccount sets the service account in the config.
func WithServiceAccount(serviceAccount *ServiceAccount) Option {
return func(cfg *config) {
cfg.serviceAccount = serviceAccount
}
}
func getConfig(options ...Option) *config {
cfg := new(config)
WithHTTPClient(http.DefaultClient)(cfg)
WithAuthURL(&url.URL{
Scheme: "https",
Host: "auth.pingone.com",
})(cfg)
WithAPIURL(&url.URL{
Scheme: "https",
Host: "api.pingone.com",
})(cfg)
for _, option := range options {
option(cfg)
}
return cfg
}
// A ServiceAccount is used by the Ping provider to query the API.
type ServiceAccount struct {
ClientID string `json:"client_id"`
ClientSecret string `json:"client_secret"`
EnvironmentID string `json:"environment_id"`
}
// ParseServiceAccount parses the service account in the config options.
func ParseServiceAccount(rawServiceAccount string) (*ServiceAccount, error) {
var serviceAccount ServiceAccount
if err := encoding.DecodeBase64OrJSON(rawServiceAccount, &serviceAccount); err != nil {
return nil, err
}
if serviceAccount.ClientID == "" {
return nil, fmt.Errorf("client_id is required")
}
if serviceAccount.ClientSecret == "" {
return nil, fmt.Errorf("client_secret is required")
}
return &serviceAccount, nil
}

View file

@ -1,51 +0,0 @@
package ping
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestParseServiceAccount(t *testing.T) {
tests := []struct {
name string
rawServiceAccount string
serviceAccount *ServiceAccount
wantErr bool
}{
{
"json",
`{"client_id":"CLIENT_ID","client_secret":"CLIENT_SECRET","environment_id":"ENVIRONMENT_ID"}`,
&ServiceAccount{ClientID: "CLIENT_ID", ClientSecret: "CLIENT_SECRET", EnvironmentID: "ENVIRONMENT_ID"},
false,
},
{
"base64 json",
`eyJjbGllbnRfaWQiOiJDTElFTlRfSUQiLCJjbGllbnRfc2VjcmV0IjoiQ0xJRU5UX1NFQ1JFVCIsImVudmlyb25tZW50X2lkIjoiRU5WSVJPTk1FTlRfSUQifQ==`,
&ServiceAccount{ClientID: "CLIENT_ID", ClientSecret: "CLIENT_SECRET", EnvironmentID: "ENVIRONMENT_ID"},
false,
},
{
"empty",
"",
nil,
true,
},
{
"invalid",
"Zm9v---",
nil,
true,
},
}
for _, tc := range tests {
tc := tc
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
got, err := ParseServiceAccount(tc.rawServiceAccount)
require.True(t, (err != nil) == tc.wantErr)
assert.Equal(t, tc.serviceAccount, got)
})
}
}

View file

@ -1,167 +0,0 @@
// Package ping implements a directory provider for Ping.
package ping
import (
"context"
"fmt"
"net/http"
"net/url"
"sort"
"sync"
"golang.org/x/oauth2"
"golang.org/x/oauth2/clientcredentials"
"github.com/pomerium/pomerium/pkg/grpc/directory"
)
// Name is the name of the Ping provider.
const Name = "ping"
// Provider implements a directory provider using the Ping API.
type Provider struct {
cfg *config
mu sync.RWMutex
token *oauth2.Token
}
// New creates a new Ping Provider.
func New(options ...Option) *Provider {
cfg := getConfig(options...)
return &Provider{
cfg: cfg,
}
}
// User returns a user's directory information.
func (p *Provider) User(ctx context.Context, userID, accessToken string) (*directory.User, error) {
client, err := p.getClient(ctx)
if err != nil {
return nil, err
}
au, err := getUser(ctx, client, p.cfg.apiURL, p.cfg.environmentID, userID)
if err != nil {
return nil, err
}
return &directory.User{
Id: au.ID,
DisplayName: au.getDisplayName(),
Email: au.Email,
GroupIds: au.MemberOfGroupIDs,
}, nil
}
// UserGroups returns all the users and groups in the directory.
func (p *Provider) UserGroups(ctx context.Context) ([]*directory.Group, []*directory.User, error) {
client, err := p.getClient(ctx)
if err != nil {
return nil, nil, err
}
apiGroups, err := getAllGroups(ctx, client, p.cfg.apiURL, p.cfg.environmentID)
if err != nil {
return nil, nil, err
}
directoryUserLookup := map[string]*directory.User{}
directoryGroups := make([]*directory.Group, len(apiGroups))
for i, ag := range apiGroups {
dg := &directory.Group{
Id: ag.ID,
Name: ag.Name,
}
apiUsers, err := getGroupUsers(ctx, client, p.cfg.apiURL, p.cfg.environmentID, ag.ID)
if err != nil {
return nil, nil, err
}
for _, au := range apiUsers {
du, ok := directoryUserLookup[au.ID]
if !ok {
du = &directory.User{
Id: au.ID,
DisplayName: au.getDisplayName(),
Email: au.Email,
}
directoryUserLookup[au.ID] = du
}
du.GroupIds = append(du.GroupIds, ag.ID)
}
directoryGroups[i] = dg
}
sort.Slice(directoryGroups, func(i, j int) bool {
return directoryGroups[i].Id < directoryGroups[j].Id
})
directoryUsers := make([]*directory.User, 0, len(directoryUserLookup))
for _, du := range directoryUserLookup {
directoryUsers = append(directoryUsers, du)
}
sort.Slice(directoryUsers, func(i, j int) bool {
return directoryUsers[i].Id < directoryUsers[j].Id
})
return directoryGroups, directoryUsers, nil
}
func (p *Provider) getClient(ctx context.Context) (*http.Client, error) {
token, err := p.getToken(ctx)
if err != nil {
return nil, err
}
client := new(http.Client)
*client = *p.cfg.httpClient
client.Transport = &oauth2.Transport{
Source: oauth2.StaticTokenSource(token),
Base: p.cfg.httpClient.Transport,
}
return client, nil
}
func (p *Provider) getToken(ctx context.Context) (*oauth2.Token, error) {
if p.cfg.serviceAccount == nil {
return nil, fmt.Errorf("ping: service account is required")
}
environmentID := p.cfg.serviceAccount.EnvironmentID
if environmentID == "" {
environmentID = p.cfg.environmentID
}
if environmentID == "" {
return nil, fmt.Errorf("ping: environment ID is required")
}
p.mu.RLock()
token := p.token
p.mu.RUnlock()
if token != nil && token.Valid() {
return token, nil
}
p.mu.Lock()
defer p.mu.Unlock()
token = p.token
if token != nil && token.Valid() {
return token, nil
}
ocfg := &clientcredentials.Config{
ClientID: p.cfg.serviceAccount.ClientID,
ClientSecret: p.cfg.serviceAccount.ClientSecret,
TokenURL: p.cfg.authURL.ResolveReference(&url.URL{
Path: fmt.Sprintf("/%s/as/token", environmentID),
}).String(),
}
var err error
p.token, err = ocfg.Token(ctx)
if err != nil {
return nil, err
}
return p.token, nil
}

View file

@ -1,230 +0,0 @@
package ping
import (
"context"
"encoding/json"
"fmt"
"net/http"
"net/http/httptest"
"net/url"
"sort"
"testing"
"time"
"github.com/go-chi/chi/v5"
"github.com/go-chi/chi/v5/middleware"
"github.com/stretchr/testify/require"
"github.com/pomerium/pomerium/internal/testutil"
)
type M = map[string]interface{}
func newMockAPI(userIDToGroupIDs map[string][]string) http.Handler {
lookup := map[string]struct{}{}
for _, groups := range userIDToGroupIDs {
for _, group := range groups {
lookup[group] = struct{}{}
}
}
var allGroups []string
for groupID := range lookup {
allGroups = append(allGroups, groupID)
}
sort.Strings(allGroups)
var allUserIDs []string
for userID := range userIDToGroupIDs {
allUserIDs = append(allUserIDs, userID)
}
sort.Strings(allUserIDs)
filterToUserIDs := map[string][]string{}
for userID, groupIDs := range userIDToGroupIDs {
for _, groupID := range groupIDs {
filter := fmt.Sprintf(`memberOfGroups[id eq "%s"]`, groupID)
filterToUserIDs[filter] = append(filterToUserIDs[filter], userID)
}
}
r := chi.NewRouter()
r.Use(middleware.Logger)
r.Post("/ENVIRONMENTID/as/token", func(w http.ResponseWriter, r *http.Request) {
u, p, _ := r.BasicAuth()
if u != "CLIENTID" || p != "CLIENTSECRET" {
http.Error(w, "forbidden", http.StatusForbidden)
return
}
grantType := r.FormValue("grant_type")
if grantType != "client_credentials" {
http.Error(w, "invalid grant_type", http.StatusBadRequest)
return
}
w.Header().Set("Content-Type", "application/json")
_ = json.NewEncoder(w).Encode(M{
"access_token": "ACCESSTOKEN",
"created_at": time.Now().Format(time.RFC3339),
"expires_in": 360000,
"refresh_token": "REFRESHTOKEN",
"token_type": "bearer",
})
})
r.Route("/v1/environments/ENVIRONMENTID", func(r chi.Router) {
r.Get("/groups", func(w http.ResponseWriter, r *http.Request) {
var apiGroups []apiGroup
for _, id := range allGroups {
apiGroups = append(apiGroups, apiGroup{
ID: id,
Name: "Group " + id,
})
}
w.Header().Set("Content-Type", "application/json")
_ = json.NewEncoder(w).Encode(M{
"_embedded": M{
"groups": apiGroups,
},
})
})
r.Route("/users", func(r chi.Router) {
r.Get("/{user_id}", func(w http.ResponseWriter, r *http.Request) {
userID := chi.URLParam(r, "user_id")
groupIDs, ok := userIDToGroupIDs[userID]
if !ok {
http.NotFound(w, r)
return
}
au := apiUser{
ID: userID,
Email: userID + "@example.com",
Name: apiUserName{
Given: "Given-" + userID,
Middle: "Middle-" + userID,
Family: "Family-" + userID,
},
}
if r.URL.Query().Get("include") == "memberOfGroupIDs" {
au.MemberOfGroupIDs = groupIDs
}
w.Header().Set("Content-Type", "application/json")
_ = json.NewEncoder(w).Encode(au)
})
r.Get("/", func(w http.ResponseWriter, r *http.Request) {
filter := r.URL.Query().Get("filter")
userIDs, ok := filterToUserIDs[filter]
if !ok {
http.Error(w, "expected filter", http.StatusBadRequest)
return
}
var apiUsers []apiUser
for _, id := range userIDs {
apiUsers = append(apiUsers, apiUser{
ID: id,
Email: id + "@example.com",
Name: apiUserName{
Given: "Given-" + id,
Middle: "Middle-" + id,
Family: "Family-" + id,
},
})
}
w.Header().Set("Content-Type", "application/json")
_ = json.NewEncoder(w).Encode(M{
"_embedded": M{
"users": apiUsers,
},
})
})
})
})
return r
}
func TestProvider_User(t *testing.T) {
ctx, clearTimeout := context.WithTimeout(context.Background(), time.Second*10)
defer clearTimeout()
srv := httptest.NewServer(newMockAPI(map[string][]string{
"user1": {"group1", "group2"},
"user2": {"group1", "group3"},
"user3": {"group3"},
}))
defer srv.Close()
u, err := url.Parse(srv.URL)
require.NoError(t, err)
p := New(
WithAPIURL(u),
WithAuthURL(u),
WithEnvironmentID("ENVIRONMENTID"),
WithServiceAccount(&ServiceAccount{
ClientID: "CLIENTID",
ClientSecret: "CLIENTSECRET",
}))
du, err := p.User(ctx, "user1", "")
require.NoError(t, err)
testutil.AssertProtoJSONEqual(t, `{
"id": "user1",
"email": "user1@example.com",
"displayName": "Given-user1 Middle-user1 Family-user1",
"groupIds": ["group1", "group2"]
}`, du)
}
func TestProvider_UserGroups(t *testing.T) {
ctx, clearTimeout := context.WithTimeout(context.Background(), time.Second*10)
defer clearTimeout()
srv := httptest.NewServer(newMockAPI(map[string][]string{
"user1": {"group1", "group2"},
"user2": {"group1", "group3"},
"user3": {"group3"},
}))
defer srv.Close()
u, err := url.Parse(srv.URL)
require.NoError(t, err)
p := New(
WithAPIURL(u),
WithAuthURL(u),
WithEnvironmentID("ENVIRONMENTID"),
WithServiceAccount(&ServiceAccount{
ClientID: "CLIENTID",
ClientSecret: "CLIENTSECRET",
}))
dgs, dus, err := p.UserGroups(ctx)
require.NoError(t, err)
testutil.AssertProtoJSONEqual(t, `[
{ "id": "group1", "name": "Group group1" },
{ "id": "group2", "name": "Group group2" },
{ "id": "group3", "name": "Group group3" }
]`, dgs)
testutil.AssertProtoJSONEqual(t, `[
{
"id": "user1",
"email": "user1@example.com",
"displayName": "Given-user1 Middle-user1 Family-user1",
"groupIds": ["group1", "group2"]
},
{
"id": "user2",
"email": "user2@example.com",
"displayName": "Given-user2 Middle-user2 Family-user2",
"groupIds": ["group1", "group3"]
},
{
"id": "user3",
"email": "user3@example.com",
"displayName": "Given-user3 Middle-user3 Family-user3",
"groupIds": ["group3"]
}
]`, dus)
}

View file

@ -1,197 +0,0 @@
// Package directory implements the user group directory service.
package directory
import (
"context"
"fmt"
"net/url"
"sync"
"github.com/google/go-cmp/cmp"
"github.com/pomerium/pomerium/internal/directory/auth0"
"github.com/pomerium/pomerium/internal/directory/azure"
"github.com/pomerium/pomerium/internal/directory/github"
"github.com/pomerium/pomerium/internal/directory/gitlab"
"github.com/pomerium/pomerium/internal/directory/google"
"github.com/pomerium/pomerium/internal/directory/okta"
"github.com/pomerium/pomerium/internal/directory/onelogin"
"github.com/pomerium/pomerium/internal/directory/ping"
"github.com/pomerium/pomerium/internal/log"
"github.com/pomerium/pomerium/pkg/grpc/directory"
)
// A Group is a directory Group.
type Group = directory.Group
// A User is a directory User.
type User = directory.User
// Options are the options specific to the provider.
type Options = directory.Options
// RegisterDirectoryServiceServer registers the directory gRPC service.
var RegisterDirectoryServiceServer = directory.RegisterDirectoryServiceServer
// A Provider provides user group directory information.
type Provider interface {
User(ctx context.Context, userID, accessToken string) (*User, error)
UserGroups(ctx context.Context) ([]*Group, []*User, error)
}
var globalProvider = struct {
sync.Mutex
provider Provider
options Options
}{}
// GetProvider gets the provider for the given options.
func GetProvider(options Options) (provider Provider) {
globalProvider.Lock()
defer globalProvider.Unlock()
ctx := context.TODO()
if globalProvider.provider != nil && cmp.Equal(globalProvider.options, options) {
log.Debug(ctx).Str("provider", options.Provider).Msg("directory: no change detected, reusing existing directory provider")
return globalProvider.provider
}
defer func() {
globalProvider.provider = provider
globalProvider.options = options
}()
var providerURL *url.URL
// url.Parse will succeed even if we pass an empty string
if options.ProviderURL != "" {
providerURL, _ = url.Parse(options.ProviderURL)
}
var errSyncDisabled error
switch options.Provider {
case auth0.Name:
serviceAccount, err := auth0.ParseServiceAccount(options)
if err == nil {
return auth0.New(
auth0.WithDomain(options.ProviderURL),
auth0.WithServiceAccount(serviceAccount))
}
errSyncDisabled = fmt.Errorf("invalid auth0 service account: %w", err)
log.Warn(ctx).
Str("service", "directory").
Str("provider", options.Provider).
Err(err).
Msg("invalid service account for auth0 directory provider")
case azure.Name:
serviceAccount, err := azure.ParseServiceAccount(options)
if err == nil {
return azure.New(azure.WithServiceAccount(serviceAccount))
}
errSyncDisabled = fmt.Errorf("invalid Azure service account: %w", err)
log.Warn(ctx).
Str("service", "directory").
Str("provider", options.Provider).
Err(err).
Msg("invalid service account for azure directory provider")
case github.Name:
serviceAccount, err := github.ParseServiceAccount(options.ServiceAccount)
if err == nil {
return github.New(github.WithServiceAccount(serviceAccount))
}
errSyncDisabled = fmt.Errorf("invalid GitHub service account: %w", err)
log.Warn(ctx).
Str("service", "directory").
Str("provider", options.Provider).
Err(err).
Msg("invalid service account for github directory provider")
case gitlab.Name:
serviceAccount, err := gitlab.ParseServiceAccount(options.ServiceAccount)
if err == nil {
if providerURL == nil {
return gitlab.New(gitlab.WithServiceAccount(serviceAccount))
}
return gitlab.New(
gitlab.WithURL(providerURL),
gitlab.WithServiceAccount(serviceAccount))
}
errSyncDisabled = fmt.Errorf("invalid GitLab service account: %w", err)
log.Warn(ctx).
Str("service", "directory").
Str("provider", options.Provider).
Err(err).
Msg("invalid service account for gitlab directory provider")
case google.Name:
serviceAccount, err := google.ParseServiceAccount(options.ServiceAccount)
if err == nil {
googleOptions := []google.Option{
google.WithServiceAccount(serviceAccount),
}
if options.ProviderURL != "" {
googleOptions = append(googleOptions, google.WithURL(options.ProviderURL))
}
return google.New(googleOptions...)
}
errSyncDisabled = fmt.Errorf("invalid google service account: %w", err)
log.Warn(ctx).
Str("service", "directory").
Str("provider", options.Provider).
Err(err).
Msg("invalid service account for Google directory provider")
case okta.Name:
serviceAccount, err := okta.ParseServiceAccount(options.ServiceAccount)
if err == nil {
return okta.New(
okta.WithProviderURL(providerURL),
okta.WithServiceAccount(serviceAccount))
}
errSyncDisabled = fmt.Errorf("invalid Okta service account: %w", err)
log.Warn(ctx).
Str("service", "directory").
Str("provider", options.Provider).
Err(err).
Msg("invalid service account for okta directory provider")
case onelogin.Name:
serviceAccount, err := onelogin.ParseServiceAccount(options.ServiceAccount)
if err == nil {
return onelogin.New(onelogin.WithServiceAccount(serviceAccount))
}
errSyncDisabled = fmt.Errorf("invalid OneLogin service account: %w", err)
log.Warn(ctx).
Str("service", "directory").
Str("provider", options.Provider).
Err(err).
Msg("invalid service account for onelogin directory provider")
case ping.Name:
serviceAccount, err := ping.ParseServiceAccount(options.ServiceAccount)
if err == nil {
return ping.New(
ping.WithProviderURL(providerURL),
ping.WithServiceAccount(serviceAccount))
}
errSyncDisabled = fmt.Errorf("invalid Ping service account: %w", err)
log.Warn(ctx).
Str("service", "directory").
Str("provider", options.Provider).
Err(err).
Msg("invalid service account for ping directory provider")
case "":
errSyncDisabled = fmt.Errorf("no directory provider configured")
default:
errSyncDisabled = fmt.Errorf("unknown directory provider %s", options.Provider)
}
log.Warn(ctx).
Str("provider", options.Provider).
Msg(errSyncDisabled.Error())
return nullProvider{errSyncDisabled}
}
type nullProvider struct {
error
}
func (p nullProvider) User(ctx context.Context, userID, accessToken string) (*directory.User, error) {
return nil, p.error
}
func (p nullProvider) UserGroups(ctx context.Context) ([]*Group, []*User, error) {
return nil, nil, p.error
}

View file

@ -3,24 +3,18 @@ package manager
import (
"time"
"github.com/pomerium/pomerium/internal/directory"
"github.com/pomerium/pomerium/internal/events"
"github.com/pomerium/pomerium/pkg/grpc/databroker"
)
var (
defaultGroupRefreshInterval = 15 * time.Minute
defaultGroupRefreshTimeout = 10 * time.Minute
defaultSessionRefreshGracePeriod = 1 * time.Minute
defaultSessionRefreshCoolOffDuration = 10 * time.Second
)
type config struct {
authenticator Authenticator
directory directory.Provider
dataBrokerClient databroker.DataBrokerServiceClient
groupRefreshInterval time.Duration
groupRefreshTimeout time.Duration
sessionRefreshGracePeriod time.Duration
sessionRefreshCoolOffDuration time.Duration
now func() time.Time
@ -29,8 +23,6 @@ type config struct {
func newConfig(options ...Option) *config {
cfg := new(config)
WithGroupRefreshInterval(defaultGroupRefreshInterval)(cfg)
WithGroupRefreshTimeout(defaultGroupRefreshTimeout)(cfg)
WithSessionRefreshGracePeriod(defaultSessionRefreshGracePeriod)(cfg)
WithSessionRefreshCoolOffDuration(defaultSessionRefreshCoolOffDuration)(cfg)
WithNow(time.Now)(cfg)
@ -50,13 +42,6 @@ func WithAuthenticator(authenticator Authenticator) Option {
}
}
// WithDirectoryProvider sets the directory provider in the config.
func WithDirectoryProvider(directoryProvider directory.Provider) Option {
return func(cfg *config) {
cfg.directory = directoryProvider
}
}
// WithDataBrokerClient sets the databroker client in the config.
func WithDataBrokerClient(dataBrokerClient databroker.DataBrokerServiceClient) Option {
return func(cfg *config) {
@ -64,20 +49,6 @@ func WithDataBrokerClient(dataBrokerClient databroker.DataBrokerServiceClient) O
}
}
// WithGroupRefreshInterval sets the group refresh interval used by the manager.
func WithGroupRefreshInterval(interval time.Duration) Option {
return func(cfg *config) {
cfg.groupRefreshInterval = interval
}
}
// WithGroupRefreshTimeout sets the group refresh timeout used by the manager.
func WithGroupRefreshTimeout(timeout time.Duration) Option {
return func(cfg *config) {
cfg.groupRefreshTimeout = timeout
}
}
// WithSessionRefreshGracePeriod sets the session refresh grace period used by the manager.
func WithSessionRefreshGracePeriod(dur time.Duration) Option {
return func(cfg *config) {

View file

@ -12,16 +12,17 @@ import (
"github.com/pomerium/pomerium/pkg/grpc/user"
)
const userRefreshInterval = 10 * time.Minute
// A User is a user managed by the Manager.
type User struct {
*user.User
lastRefresh time.Time
refreshInterval time.Duration
}
// NextRefresh returns the next time the user information needs to be refreshed.
func (u User) NextRefresh() time.Time {
return u.lastRefresh.Add(u.refreshInterval)
return u.lastRefresh.Add(userRefreshInterval)
}
// UnmarshalJSON unmarshals json data into the user object.

View file

@ -6,16 +6,13 @@ import (
"errors"
"time"
"github.com/cenkalti/backoff/v4"
"github.com/google/btree"
"github.com/rs/zerolog"
"golang.org/x/oauth2"
"golang.org/x/sync/errgroup"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/types/known/timestamppb"
"github.com/pomerium/pomerium/internal/atomicutil"
"github.com/pomerium/pomerium/internal/directory"
"github.com/pomerium/pomerium/internal/events"
"github.com/pomerium/pomerium/internal/identity/identity"
"github.com/pomerium/pomerium/internal/log"
@ -26,7 +23,6 @@ import (
"github.com/pomerium/pomerium/pkg/grpc/user"
"github.com/pomerium/pomerium/pkg/grpcutil"
metrics_ids "github.com/pomerium/pomerium/pkg/metrics"
"github.com/pomerium/pomerium/pkg/protoutil"
)
// Authenticator is an identity.Provider with only the methods needed by the manager.
@ -51,11 +47,6 @@ type Manager struct {
sessions sessionCollection
users userCollection
directoryUsers map[string]*directory.User
directoryGroups map[string]*directory.Group
directoryBackoff *backoff.ExponentialBackOff
directoryNextRefresh time.Time
}
// New creates a new identity manager.
@ -68,8 +59,6 @@ func New(
sessionScheduler: scheduler.New(),
userScheduler: scheduler.New(),
}
mgr.directoryBackoff = backoff.NewExponentialBackOff()
mgr.directoryBackoff.MaxElapsedTime = 0
mgr.reset()
mgr.UpdateConfig(options...)
return mgr
@ -131,8 +120,6 @@ func (mgr *Manager) refreshLoop(ctx context.Context, update <-chan updateRecords
}
log.Info(ctx).
Int("directory_groups", len(mgr.directoryGroups)).
Int("directory_users", len(mgr.directoryUsers)).
Int("sessions", mgr.sessions.Len()).
Int("users", mgr.users.Len()).
Msg("initial sync complete")
@ -140,9 +127,6 @@ func (mgr *Manager) refreshLoop(ctx context.Context, update <-chan updateRecords
// start refreshing
maxWait := time.Minute * 10
nextTime := time.Now().Add(maxWait)
if mgr.directoryNextRefresh.Before(nextTime) {
nextTime = mgr.directoryNextRefresh
}
timer := time.NewTimer(time.Until(nextTime))
defer timer.Stop()
@ -161,14 +145,6 @@ func (mgr *Manager) refreshLoop(ctx context.Context, update <-chan updateRecords
now := time.Now()
nextTime = now.Add(maxWait)
// refresh groups
if mgr.directoryNextRefresh.Before(now) {
mgr.directoryNextRefresh = now.Add(mgr.refreshDirectoryUserGroups(ctx))
}
if mgr.directoryNextRefresh.Before(nextTime) {
nextTime = mgr.directoryNextRefresh
}
// refresh sessions
for {
tm, key := mgr.sessionScheduler.Next()
@ -203,128 +179,6 @@ func (mgr *Manager) refreshLoop(ctx context.Context, update <-chan updateRecords
}
}
func (mgr *Manager) refreshDirectoryUserGroups(ctx context.Context) (nextRefreshDelay time.Duration) {
log.Info(ctx).Msg("refreshing directory users")
ctx, clearTimeout := context.WithTimeout(ctx, mgr.cfg.Load().groupRefreshTimeout)
defer clearTimeout()
directoryGroups, directoryUsers, err := mgr.cfg.Load().directory.UserGroups(ctx)
metrics.RecordIdentityManagerUserGroupRefresh(ctx, err)
mgr.recordLastError(metrics_ids.IdentityManagerLastUserGroupRefreshError, err)
if err != nil {
msg := "failed to refresh directory users and groups"
if ctx.Err() != nil {
msg += ". You may need to increase the identity provider directory timeout setting"
msg += "(https://www.pomerium.com/docs/reference/identity-provider-refresh-directory-settings)"
}
log.Warn(ctx).Err(err).Msg(msg)
return minDuration(
mgr.cfg.Load().groupRefreshInterval, // never wait more than the refresh interval
mgr.directoryBackoff.NextBackOff(),
)
}
mgr.directoryBackoff.Reset() // success so reset the backoff
mgr.mergeGroups(ctx, directoryGroups)
mgr.mergeUsers(ctx, directoryUsers)
return mgr.cfg.Load().groupRefreshInterval
}
func (mgr *Manager) mergeGroups(ctx context.Context, directoryGroups []*directory.Group) {
lookup := map[string]*directory.Group{}
for _, dg := range directoryGroups {
lookup[dg.GetId()] = dg
}
var records []*databroker.Record
for groupID, newDG := range lookup {
curDG, ok := mgr.directoryGroups[groupID]
if !ok || !proto.Equal(newDG, curDG) {
id := newDG.GetId()
any := protoutil.NewAny(newDG)
records = append(records, &databroker.Record{
Type: any.GetTypeUrl(),
Id: id,
Data: any,
})
}
}
for groupID, curDG := range mgr.directoryGroups {
_, ok := lookup[groupID]
if !ok {
id := curDG.GetId()
any := protoutil.NewAny(curDG)
records = append(records, &databroker.Record{
Type: any.GetTypeUrl(),
Id: id,
Data: any,
DeletedAt: timestamppb.New(mgr.cfg.Load().now()),
})
}
}
for i, batch := range databroker.OptimumPutRequestsFromRecords(records) {
_, err := mgr.cfg.Load().dataBrokerClient.Put(ctx, batch)
if err != nil {
log.Warn(ctx).Err(err).
Int("batch", i).
Int("record-count", len(batch.GetRecords())).
Msg("manager: failed to update groups")
}
}
}
func (mgr *Manager) mergeUsers(ctx context.Context, directoryUsers []*directory.User) {
lookup := map[string]*directory.User{}
for _, du := range directoryUsers {
lookup[du.GetId()] = du
}
var records []*databroker.Record
for userID, newDU := range lookup {
curDU, ok := mgr.directoryUsers[userID]
if !ok || !proto.Equal(newDU, curDU) {
id := newDU.GetId()
any := protoutil.NewAny(newDU)
records = append(records, &databroker.Record{
Type: any.GetTypeUrl(),
Id: id,
Data: any,
})
}
}
for userID, curDU := range mgr.directoryUsers {
_, ok := lookup[userID]
if !ok {
id := curDU.GetId()
any := protoutil.NewAny(curDU)
records = append(records, &databroker.Record{
Type: any.GetTypeUrl(),
Id: id,
Data: any,
DeletedAt: timestamppb.New(mgr.cfg.Load().now()),
})
}
}
for i, batch := range databroker.OptimumPutRequestsFromRecords(records) {
_, err := mgr.cfg.Load().dataBrokerClient.Put(ctx, batch)
if err != nil {
log.Warn(ctx).Err(err).
Int("batch", i).
Int("record-count", len(batch.GetRecords())).
Msg("manager: failed to update users")
}
}
}
func (mgr *Manager) refreshSession(ctx context.Context, userID, sessionID string) {
log.Info(ctx).
Str("user_id", userID).
@ -464,22 +318,6 @@ func (mgr *Manager) refreshUser(ctx context.Context, userID string) {
func (mgr *Manager) onUpdateRecords(ctx context.Context, msg updateRecordsMessage) {
for _, record := range msg.records {
switch record.GetType() {
case grpcutil.GetTypeURL(new(directory.Group)):
var pbDirectoryGroup directory.Group
err := record.GetData().UnmarshalTo(&pbDirectoryGroup)
if err != nil {
log.Warn(ctx).Msgf("error unmarshaling directory group: %s", err)
continue
}
mgr.onUpdateDirectoryGroup(ctx, &pbDirectoryGroup)
case grpcutil.GetTypeURL(new(directory.User)):
var pbDirectoryUser directory.User
err := record.GetData().UnmarshalTo(&pbDirectoryUser)
if err != nil {
log.Warn(ctx).Msgf("error unmarshaling directory user: %s", err)
continue
}
mgr.onUpdateDirectoryUser(ctx, &pbDirectoryUser)
case grpcutil.GetTypeURL(new(session.Session)):
var pbSession session.Session
err := record.GetData().UnmarshalTo(&pbSession)
@ -528,20 +366,11 @@ func (mgr *Manager) onUpdateUser(_ context.Context, record *databroker.Record, u
u, _ := mgr.users.Get(user.GetId())
u.lastRefresh = mgr.cfg.Load().now()
u.refreshInterval = mgr.cfg.Load().groupRefreshInterval
u.User = user
mgr.users.ReplaceOrInsert(u)
mgr.userScheduler.Add(u.NextRefresh(), u.GetId())
}
func (mgr *Manager) onUpdateDirectoryUser(_ context.Context, pbDirectoryUser *directory.User) {
mgr.directoryUsers[pbDirectoryUser.GetId()] = pbDirectoryUser
}
func (mgr *Manager) onUpdateDirectoryGroup(_ context.Context, pbDirectoryGroup *directory.Group) {
mgr.directoryGroups[pbDirectoryGroup.GetId()] = pbDirectoryGroup
}
func (mgr *Manager) deleteSession(ctx context.Context, pbSession *session.Session) {
err := session.Delete(ctx, mgr.cfg.Load().dataBrokerClient, pbSession.GetId())
if err != nil {
@ -553,8 +382,6 @@ func (mgr *Manager) deleteSession(ctx context.Context, pbSession *session.Sessio
// reset resets all the manager datastructures to their initial state
func (mgr *Manager) reset() {
mgr.directoryGroups = make(map[string]*directory.Group)
mgr.directoryUsers = make(map[string]*directory.User)
mgr.sessions = sessionCollection{BTree: btree.New(8)}
mgr.users = userCollection{BTree: btree.New(8)}
}
@ -586,13 +413,3 @@ func isTemporaryError(err error) bool {
}
return false
}
func minDuration(d1 time.Duration, ds ...time.Duration) time.Duration {
min := d1
for _, d := range ds {
if d < min {
min = d
}
}
return min
}

View file

@ -3,7 +3,6 @@ package manager
import (
"context"
"errors"
"fmt"
"testing"
"time"
@ -13,7 +12,6 @@ import (
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/types/known/timestamppb"
"github.com/pomerium/pomerium/internal/directory"
"github.com/pomerium/pomerium/internal/events"
"github.com/pomerium/pomerium/internal/identity/identity"
"github.com/pomerium/pomerium/pkg/grpc/databroker"
@ -24,19 +22,6 @@ import (
"github.com/pomerium/pomerium/pkg/protoutil"
)
type mockProvider struct {
user func(ctx context.Context, userID, accessToken string) (*directory.User, error)
userGroups func(ctx context.Context) ([]*directory.Group, []*directory.User, error)
}
func (mock mockProvider) User(ctx context.Context, userID, accessToken string) (*directory.User, error) {
return mock.user(ctx, userID, accessToken)
}
func (mock mockProvider) UserGroups(ctx context.Context) ([]*directory.Group, []*directory.User, error) {
return mock.userGroups(ctx)
}
type mockAuthenticator struct{}
func (mock mockAuthenticator) Refresh(_ context.Context, _ *oauth2.Token, _ identity.State) (*oauth2.Token, error) {
@ -61,72 +46,25 @@ func TestManager_onUpdateRecords(t *testing.T) {
mgr := New(
WithDataBrokerClient(mock_databroker.NewMockDataBrokerServiceClient(ctrl)),
WithDirectoryProvider(mockProvider{}),
WithGroupRefreshInterval(time.Hour),
WithNow(func() time.Time {
return now
}),
)
mgr.directoryBackoff.RandomizationFactor = 0 // disable randomization for deterministic testing
mgr.onUpdateRecords(ctx, updateRecordsMessage{
records: []*databroker.Record{
mkRecord(&directory.Group{Id: "group1", Name: "group 1", Email: "group1@example.com"}),
mkRecord(&directory.User{Id: "user1", DisplayName: "user 1", Email: "user1@example.com", GroupIds: []string{"group1s"}}),
mkRecord(&session.Session{Id: "session1", UserId: "user1"}),
mkRecord(&user.User{Id: "user1", Name: "user 1", Email: "user1@example.com"}),
},
})
assert.NotNil(t, mgr.directoryGroups["group1"])
assert.NotNil(t, mgr.directoryUsers["user1"])
if _, ok := mgr.sessions.Get("user1", "session1"); assert.True(t, ok) {
}
if _, ok := mgr.users.Get("user1"); assert.True(t, ok) {
tm, id := mgr.userScheduler.Next()
assert.Equal(t, now.Add(time.Hour), tm)
assert.Equal(t, now.Add(userRefreshInterval), tm)
assert.Equal(t, "user1", id)
}
}
func TestManager_refreshDirectoryUserGroups(t *testing.T) {
ctrl := gomock.NewController(t)
ctx, clearTimeout := context.WithTimeout(context.Background(), time.Second*10)
defer clearTimeout()
t.Run("backoff", func(t *testing.T) {
cnt := 0
client := mock_databroker.NewMockDataBrokerServiceClient(ctrl)
client.EXPECT().Put(gomock.Any(), gomock.Any()).AnyTimes()
mgr := New(
WithDataBrokerClient(client),
WithDirectoryProvider(mockProvider{
userGroups: func(ctx context.Context) ([]*directory.Group, []*directory.User, error) {
cnt++
switch cnt {
case 1:
return nil, nil, fmt.Errorf("error 1")
case 2:
return nil, nil, fmt.Errorf("error 2")
}
return nil, nil, nil
},
}),
WithGroupRefreshInterval(time.Hour),
)
mgr.directoryBackoff.RandomizationFactor = 0 // disable randomization for deterministic testing
dur1 := mgr.refreshDirectoryUserGroups(ctx)
dur2 := mgr.refreshDirectoryUserGroups(ctx)
dur3 := mgr.refreshDirectoryUserGroups(ctx)
assert.Greater(t, dur2, dur1)
assert.Greater(t, dur3, dur2)
assert.Equal(t, time.Hour, dur3)
})
}
func TestManager_reportErrors(t *testing.T) {
@ -161,22 +99,10 @@ func TestManager_reportErrors(t *testing.T) {
WithEventManager(evtMgr),
WithDataBrokerClient(client),
WithAuthenticator(mockAuthenticator{}),
WithDirectoryProvider(mockProvider{
user: func(ctx context.Context, userID, accessToken string) (*directory.User, error) {
return nil, fmt.Errorf("user")
},
userGroups: func(ctx context.Context) ([]*directory.Group, []*directory.User, error) {
return nil, nil, fmt.Errorf("user groups")
},
}),
WithGroupRefreshInterval(time.Second),
)
mgr.directoryBackoff.RandomizationFactor = 0 // disable randomization for deterministic testing
mgr.onUpdateRecords(ctx, updateRecordsMessage{
records: []*databroker.Record{
mkRecord(&directory.Group{Id: "group1", Name: "group 1", Email: "group1@example.com"}),
mkRecord(&directory.User{Id: "user1", DisplayName: "user 1", Email: "user1@example.com", GroupIds: []string{"group1s"}}),
mkRecord(&session.Session{Id: "session1", UserId: "user1", OauthToken: &session.OAuthToken{
ExpiresAt: timestamppb.New(time.Now().Add(time.Hour)),
}, ExpiresAt: timestamppb.New(time.Now().Add(time.Hour))}),
@ -184,9 +110,6 @@ func TestManager_reportErrors(t *testing.T) {
},
})
_ = mgr.refreshDirectoryUserGroups(ctx)
expectMsg(metrics_ids.IdentityManagerLastUserGroupRefreshError, "user groups")
mgr.refreshUser(ctx, "user1")
expectMsg(metrics_ids.IdentityManagerLastUserRefreshError, "update user info")

View file

@ -26,10 +26,6 @@ type Options struct {
// Scope specifies optional requested permissions.
Scopes []string
// ServiceAccount can be set for those providers that require additional
// credentials or tokens to do follow up API calls (e.g. Google)
ServiceAccount string
// AuthCodeOptions specifies additional key value pairs query params to add
// to the request flow signin url.
AuthCodeOptions map[string]string

View file

@ -27,7 +27,6 @@ var defaultScopes = []string{oidc.ScopeOpenID, "profile", "email"}
// requires we set this on a custom uri param. Also, ` prompt` must be set to `consent`to ensure
// that our application always receives a refresh token (ask google). And finally, we default to
// having the user select which Google account they'd like to use.
//
// For more details, please see google's documentation:
//
// https://developers.google.com/identity/protocols/oauth2/web-server#offline

View file

@ -13,7 +13,6 @@ cookie_secret: YYYYY
idp_provider: "google"
idp_client_id: XXXX
idp_client_secret: YYYY
idp_service_account: XXXXXX
routes:
- from: https://yoursite.localhost.pomerium.io

View file

@ -9,6 +9,7 @@ import (
)
func TestMerge(t *testing.T) {
type key string
t.Run("value", func(t *testing.T) {
type contextKey string
k1 := contextKey("key1")

File diff suppressed because it is too large Load diff

View file

@ -55,7 +55,7 @@ message Route {
RouteRedirect redirect = 34;
repeated string allowed_users = 4 [ deprecated = true ];
repeated string allowed_groups = 5 [ deprecated = true ];
// repeated string allowed_groups = 5 [ deprecated = true ];
repeated string allowed_domains = 6 [ deprecated = true ];
map<string, google.protobuf.ListValue> allowed_idp_claims = 32
[ deprecated = true ];
@ -121,7 +121,7 @@ message Policy {
string id = 1;
string name = 2;
repeated string allowed_users = 3;
repeated string allowed_groups = 4;
// repeated string allowed_groups = 4;
repeated string allowed_domains = 5;
map<string, google.protobuf.ListValue> allowed_idp_claims = 7;
repeated string rego = 6;
@ -166,9 +166,9 @@ message Settings {
optional string idp_provider = 24;
optional string idp_provider_url = 25;
repeated string scopes = 26;
optional string idp_service_account = 27;
optional google.protobuf.Duration idp_refresh_directory_timeout = 28;
optional google.protobuf.Duration idp_refresh_directory_interval = 29;
// optional string idp_service_account = 27;
// optional google.protobuf.Duration idp_refresh_directory_timeout = 28;
// optional google.protobuf.Duration idp_refresh_directory_interval = 29;
map<string, string> request_params = 30;
repeated string authorize_service_urls = 32;
optional string authorize_internal_service_url = 83;

View file

@ -84,7 +84,7 @@ loop:
case err == io.EOF:
break loop
case err != nil:
return nil, 0, 0, err
return nil, 0, 0, fmt.Errorf("error receiving record: %w", err)
}
switch res := res.GetResponse().(type) {

View file

@ -2,6 +2,7 @@ package databroker
import (
"context"
"fmt"
"time"
backoff "github.com/cenkalti/backoff/v4"
@ -132,8 +133,7 @@ func (syncer *Syncer) init(ctx context.Context) error {
Type: syncer.cfg.typeURL,
})
if err != nil {
log.Error(ctx).Err(err).Msg("error during initial sync")
return err
return fmt.Errorf("error during initial sync: %w", err)
}
syncer.backoff.Reset()
@ -154,8 +154,7 @@ func (syncer *Syncer) sync(ctx context.Context) error {
Type: syncer.cfg.typeURL,
})
if err != nil {
log.Error(ctx).Err(err).Msg("error during sync")
return err
return fmt.Errorf("error calling sync: %w", err)
}
log.Info(ctx).Msg("listening for updates")
@ -168,7 +167,7 @@ func (syncer *Syncer) sync(ctx context.Context) error {
syncer.serverVersion = 0
return nil
} else if err != nil {
return err
return fmt.Errorf("error receiving sync record: %w", err)
}
rec := res.GetRecord()

View file

@ -1,30 +0,0 @@
// Package directory contains protobuf types for directory users.
package directory
import (
context "context"
"github.com/pomerium/pomerium/pkg/grpc/databroker"
)
// GetGroup gets a directory group from the databroker.
func GetGroup(ctx context.Context, client databroker.DataBrokerServiceClient, groupID string) (*Group, error) {
g := Group{Id: groupID}
return &g, databroker.Get(ctx, client, &g)
}
// GetUser gets a directory user from the databroker.
func GetUser(ctx context.Context, client databroker.DataBrokerServiceClient, userID string) (*User, error) {
u := User{Id: userID}
return &u, databroker.Get(ctx, client, &u)
}
// Options are directory provider options.
type Options struct {
ServiceAccount string
Provider string
ProviderURL string
ClientID string
ClientSecret string
QPS float64
}

View file

@ -1,441 +0,0 @@
// Code generated by protoc-gen-go. DO NOT EDIT.
// versions:
// protoc-gen-go v1.28.1
// protoc v3.21.7
// source: directory.proto
package directory
import (
context "context"
grpc "google.golang.org/grpc"
codes "google.golang.org/grpc/codes"
status "google.golang.org/grpc/status"
protoreflect "google.golang.org/protobuf/reflect/protoreflect"
protoimpl "google.golang.org/protobuf/runtime/protoimpl"
emptypb "google.golang.org/protobuf/types/known/emptypb"
reflect "reflect"
sync "sync"
)
const (
// Verify that this generated code is sufficiently up-to-date.
_ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion)
// Verify that runtime/protoimpl is sufficiently up-to-date.
_ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20)
)
type User struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
Version string `protobuf:"bytes,1,opt,name=version,proto3" json:"version,omitempty"`
Id string `protobuf:"bytes,2,opt,name=id,proto3" json:"id,omitempty"`
GroupIds []string `protobuf:"bytes,3,rep,name=group_ids,json=groupIds,proto3" json:"group_ids,omitempty"`
DisplayName string `protobuf:"bytes,4,opt,name=display_name,json=displayName,proto3" json:"display_name,omitempty"`
Email string `protobuf:"bytes,5,opt,name=email,proto3" json:"email,omitempty"`
}
func (x *User) Reset() {
*x = User{}
if protoimpl.UnsafeEnabled {
mi := &file_directory_proto_msgTypes[0]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
}
func (x *User) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*User) ProtoMessage() {}
func (x *User) ProtoReflect() protoreflect.Message {
mi := &file_directory_proto_msgTypes[0]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use User.ProtoReflect.Descriptor instead.
func (*User) Descriptor() ([]byte, []int) {
return file_directory_proto_rawDescGZIP(), []int{0}
}
func (x *User) GetVersion() string {
if x != nil {
return x.Version
}
return ""
}
func (x *User) GetId() string {
if x != nil {
return x.Id
}
return ""
}
func (x *User) GetGroupIds() []string {
if x != nil {
return x.GroupIds
}
return nil
}
func (x *User) GetDisplayName() string {
if x != nil {
return x.DisplayName
}
return ""
}
func (x *User) GetEmail() string {
if x != nil {
return x.Email
}
return ""
}
type Group struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
Version string `protobuf:"bytes,1,opt,name=version,proto3" json:"version,omitempty"`
Id string `protobuf:"bytes,2,opt,name=id,proto3" json:"id,omitempty"`
Name string `protobuf:"bytes,3,opt,name=name,proto3" json:"name,omitempty"`
Email string `protobuf:"bytes,4,opt,name=email,proto3" json:"email,omitempty"`
}
func (x *Group) Reset() {
*x = Group{}
if protoimpl.UnsafeEnabled {
mi := &file_directory_proto_msgTypes[1]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
}
func (x *Group) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*Group) ProtoMessage() {}
func (x *Group) ProtoReflect() protoreflect.Message {
mi := &file_directory_proto_msgTypes[1]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use Group.ProtoReflect.Descriptor instead.
func (*Group) Descriptor() ([]byte, []int) {
return file_directory_proto_rawDescGZIP(), []int{1}
}
func (x *Group) GetVersion() string {
if x != nil {
return x.Version
}
return ""
}
func (x *Group) GetId() string {
if x != nil {
return x.Id
}
return ""
}
func (x *Group) GetName() string {
if x != nil {
return x.Name
}
return ""
}
func (x *Group) GetEmail() string {
if x != nil {
return x.Email
}
return ""
}
type RefreshUserRequest struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
UserId string `protobuf:"bytes,1,opt,name=user_id,json=userId,proto3" json:"user_id,omitempty"`
AccessToken string `protobuf:"bytes,2,opt,name=access_token,json=accessToken,proto3" json:"access_token,omitempty"`
}
func (x *RefreshUserRequest) Reset() {
*x = RefreshUserRequest{}
if protoimpl.UnsafeEnabled {
mi := &file_directory_proto_msgTypes[2]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
}
func (x *RefreshUserRequest) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*RefreshUserRequest) ProtoMessage() {}
func (x *RefreshUserRequest) ProtoReflect() protoreflect.Message {
mi := &file_directory_proto_msgTypes[2]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use RefreshUserRequest.ProtoReflect.Descriptor instead.
func (*RefreshUserRequest) Descriptor() ([]byte, []int) {
return file_directory_proto_rawDescGZIP(), []int{2}
}
func (x *RefreshUserRequest) GetUserId() string {
if x != nil {
return x.UserId
}
return ""
}
func (x *RefreshUserRequest) GetAccessToken() string {
if x != nil {
return x.AccessToken
}
return ""
}
var File_directory_proto protoreflect.FileDescriptor
var file_directory_proto_rawDesc = []byte{
0x0a, 0x0f, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, 0x6f, 0x72, 0x79, 0x2e, 0x70, 0x72, 0x6f, 0x74,
0x6f, 0x12, 0x09, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, 0x6f, 0x72, 0x79, 0x1a, 0x1b, 0x67, 0x6f,
0x6f, 0x67, 0x6c, 0x65, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2f, 0x65, 0x6d,
0x70, 0x74, 0x79, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x22, 0x86, 0x01, 0x0a, 0x04, 0x55, 0x73,
0x65, 0x72, 0x12, 0x18, 0x0a, 0x07, 0x76, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x18, 0x01, 0x20,
0x01, 0x28, 0x09, 0x52, 0x07, 0x76, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x12, 0x0e, 0x0a, 0x02,
0x69, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x02, 0x69, 0x64, 0x12, 0x1b, 0x0a, 0x09,
0x67, 0x72, 0x6f, 0x75, 0x70, 0x5f, 0x69, 0x64, 0x73, 0x18, 0x03, 0x20, 0x03, 0x28, 0x09, 0x52,
0x08, 0x67, 0x72, 0x6f, 0x75, 0x70, 0x49, 0x64, 0x73, 0x12, 0x21, 0x0a, 0x0c, 0x64, 0x69, 0x73,
0x70, 0x6c, 0x61, 0x79, 0x5f, 0x6e, 0x61, 0x6d, 0x65, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52,
0x0b, 0x64, 0x69, 0x73, 0x70, 0x6c, 0x61, 0x79, 0x4e, 0x61, 0x6d, 0x65, 0x12, 0x14, 0x0a, 0x05,
0x65, 0x6d, 0x61, 0x69, 0x6c, 0x18, 0x05, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x65, 0x6d, 0x61,
0x69, 0x6c, 0x22, 0x5b, 0x0a, 0x05, 0x47, 0x72, 0x6f, 0x75, 0x70, 0x12, 0x18, 0x0a, 0x07, 0x76,
0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x76, 0x65,
0x72, 0x73, 0x69, 0x6f, 0x6e, 0x12, 0x0e, 0x0a, 0x02, 0x69, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28,
0x09, 0x52, 0x02, 0x69, 0x64, 0x12, 0x12, 0x0a, 0x04, 0x6e, 0x61, 0x6d, 0x65, 0x18, 0x03, 0x20,
0x01, 0x28, 0x09, 0x52, 0x04, 0x6e, 0x61, 0x6d, 0x65, 0x12, 0x14, 0x0a, 0x05, 0x65, 0x6d, 0x61,
0x69, 0x6c, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x65, 0x6d, 0x61, 0x69, 0x6c, 0x22,
0x50, 0x0a, 0x12, 0x52, 0x65, 0x66, 0x72, 0x65, 0x73, 0x68, 0x55, 0x73, 0x65, 0x72, 0x52, 0x65,
0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x17, 0x0a, 0x07, 0x75, 0x73, 0x65, 0x72, 0x5f, 0x69, 0x64,
0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x75, 0x73, 0x65, 0x72, 0x49, 0x64, 0x12, 0x21,
0x0a, 0x0c, 0x61, 0x63, 0x63, 0x65, 0x73, 0x73, 0x5f, 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x18, 0x02,
0x20, 0x01, 0x28, 0x09, 0x52, 0x0b, 0x61, 0x63, 0x63, 0x65, 0x73, 0x73, 0x54, 0x6f, 0x6b, 0x65,
0x6e, 0x32, 0x58, 0x0a, 0x10, 0x44, 0x69, 0x72, 0x65, 0x63, 0x74, 0x6f, 0x72, 0x79, 0x53, 0x65,
0x72, 0x76, 0x69, 0x63, 0x65, 0x12, 0x44, 0x0a, 0x0b, 0x52, 0x65, 0x66, 0x72, 0x65, 0x73, 0x68,
0x55, 0x73, 0x65, 0x72, 0x12, 0x1d, 0x2e, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, 0x6f, 0x72, 0x79,
0x2e, 0x52, 0x65, 0x66, 0x72, 0x65, 0x73, 0x68, 0x55, 0x73, 0x65, 0x72, 0x52, 0x65, 0x71, 0x75,
0x65, 0x73, 0x74, 0x1a, 0x16, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f,
0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x42, 0x31, 0x5a, 0x2f, 0x67,
0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x70, 0x6f, 0x6d, 0x65, 0x72, 0x69,
0x75, 0x6d, 0x2f, 0x70, 0x6f, 0x6d, 0x65, 0x72, 0x69, 0x75, 0x6d, 0x2f, 0x70, 0x6b, 0x67, 0x2f,
0x67, 0x72, 0x70, 0x63, 0x2f, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, 0x6f, 0x72, 0x79, 0x62, 0x06,
0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33,
}
var (
file_directory_proto_rawDescOnce sync.Once
file_directory_proto_rawDescData = file_directory_proto_rawDesc
)
func file_directory_proto_rawDescGZIP() []byte {
file_directory_proto_rawDescOnce.Do(func() {
file_directory_proto_rawDescData = protoimpl.X.CompressGZIP(file_directory_proto_rawDescData)
})
return file_directory_proto_rawDescData
}
var file_directory_proto_msgTypes = make([]protoimpl.MessageInfo, 3)
var file_directory_proto_goTypes = []interface{}{
(*User)(nil), // 0: directory.User
(*Group)(nil), // 1: directory.Group
(*RefreshUserRequest)(nil), // 2: directory.RefreshUserRequest
(*emptypb.Empty)(nil), // 3: google.protobuf.Empty
}
var file_directory_proto_depIdxs = []int32{
2, // 0: directory.DirectoryService.RefreshUser:input_type -> directory.RefreshUserRequest
3, // 1: directory.DirectoryService.RefreshUser:output_type -> google.protobuf.Empty
1, // [1:2] is the sub-list for method output_type
0, // [0:1] is the sub-list for method input_type
0, // [0:0] is the sub-list for extension type_name
0, // [0:0] is the sub-list for extension extendee
0, // [0:0] is the sub-list for field type_name
}
func init() { file_directory_proto_init() }
func file_directory_proto_init() {
if File_directory_proto != nil {
return
}
if !protoimpl.UnsafeEnabled {
file_directory_proto_msgTypes[0].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*User); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
file_directory_proto_msgTypes[1].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*Group); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
file_directory_proto_msgTypes[2].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*RefreshUserRequest); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
}
type x struct{}
out := protoimpl.TypeBuilder{
File: protoimpl.DescBuilder{
GoPackagePath: reflect.TypeOf(x{}).PkgPath(),
RawDescriptor: file_directory_proto_rawDesc,
NumEnums: 0,
NumMessages: 3,
NumExtensions: 0,
NumServices: 1,
},
GoTypes: file_directory_proto_goTypes,
DependencyIndexes: file_directory_proto_depIdxs,
MessageInfos: file_directory_proto_msgTypes,
}.Build()
File_directory_proto = out.File
file_directory_proto_rawDesc = nil
file_directory_proto_goTypes = nil
file_directory_proto_depIdxs = nil
}
// Reference imports to suppress errors if they are not otherwise used.
var _ context.Context
var _ grpc.ClientConnInterface
// This is a compile-time assertion to ensure that this generated file
// is compatible with the grpc package it is being compiled against.
const _ = grpc.SupportPackageIsVersion6
// DirectoryServiceClient is the client API for DirectoryService service.
//
// For semantics around ctx use and closing/ending streaming RPCs, please refer to https://godoc.org/google.golang.org/grpc#ClientConn.NewStream.
type DirectoryServiceClient interface {
RefreshUser(ctx context.Context, in *RefreshUserRequest, opts ...grpc.CallOption) (*emptypb.Empty, error)
}
type directoryServiceClient struct {
cc grpc.ClientConnInterface
}
func NewDirectoryServiceClient(cc grpc.ClientConnInterface) DirectoryServiceClient {
return &directoryServiceClient{cc}
}
func (c *directoryServiceClient) RefreshUser(ctx context.Context, in *RefreshUserRequest, opts ...grpc.CallOption) (*emptypb.Empty, error) {
out := new(emptypb.Empty)
err := c.cc.Invoke(ctx, "/directory.DirectoryService/RefreshUser", in, out, opts...)
if err != nil {
return nil, err
}
return out, nil
}
// DirectoryServiceServer is the server API for DirectoryService service.
type DirectoryServiceServer interface {
RefreshUser(context.Context, *RefreshUserRequest) (*emptypb.Empty, error)
}
// UnimplementedDirectoryServiceServer can be embedded to have forward compatible implementations.
type UnimplementedDirectoryServiceServer struct {
}
func (*UnimplementedDirectoryServiceServer) RefreshUser(context.Context, *RefreshUserRequest) (*emptypb.Empty, error) {
return nil, status.Errorf(codes.Unimplemented, "method RefreshUser not implemented")
}
func RegisterDirectoryServiceServer(s *grpc.Server, srv DirectoryServiceServer) {
s.RegisterService(&_DirectoryService_serviceDesc, srv)
}
func _DirectoryService_RefreshUser_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(RefreshUserRequest)
if err := dec(in); err != nil {
return nil, err
}
if interceptor == nil {
return srv.(DirectoryServiceServer).RefreshUser(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: "/directory.DirectoryService/RefreshUser",
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(DirectoryServiceServer).RefreshUser(ctx, req.(*RefreshUserRequest))
}
return interceptor(ctx, in, info, handler)
}
var _DirectoryService_serviceDesc = grpc.ServiceDesc{
ServiceName: "directory.DirectoryService",
HandlerType: (*DirectoryServiceServer)(nil),
Methods: []grpc.MethodDesc{
{
MethodName: "RefreshUser",
Handler: _DirectoryService_RefreshUser_Handler,
},
},
Streams: []grpc.StreamDesc{},
Metadata: "directory.proto",
}

View file

@ -1,30 +0,0 @@
syntax = "proto3";
package directory;
option go_package = "github.com/pomerium/pomerium/pkg/grpc/directory";
import "google/protobuf/empty.proto";
message User {
string version = 1;
string id = 2;
repeated string group_ids = 3;
string display_name = 4;
string email = 5;
}
message Group {
string version = 1;
string id = 2;
string name = 3;
string email = 4;
}
message RefreshUserRequest {
string user_id = 1;
string access_token = 2;
}
service DirectoryService {
rpc RefreshUser(RefreshUserRequest) returns (google.protobuf.Empty);
}

View file

@ -30,7 +30,7 @@ type Provider struct {
ClientSecret string `protobuf:"bytes,3,opt,name=client_secret,json=clientSecret,proto3" json:"client_secret,omitempty"`
Type string `protobuf:"bytes,4,opt,name=type,proto3" json:"type,omitempty"`
Scopes []string `protobuf:"bytes,5,rep,name=scopes,proto3" json:"scopes,omitempty"`
ServiceAccount string `protobuf:"bytes,6,opt,name=service_account,json=serviceAccount,proto3" json:"service_account,omitempty"`
// string service_account = 6;
Url string `protobuf:"bytes,7,opt,name=url,proto3" json:"url,omitempty"`
RequestParams map[string]string `protobuf:"bytes,8,rep,name=request_params,json=requestParams,proto3" json:"request_params,omitempty" protobuf_key:"bytes,1,opt,name=key,proto3" protobuf_val:"bytes,2,opt,name=value,proto3"`
}
@ -102,13 +102,6 @@ func (x *Provider) GetScopes() []string {
return nil
}
func (x *Provider) GetServiceAccount() string {
if x != nil {
return x.ServiceAccount
}
return ""
}
func (x *Provider) GetUrl() string {
if x != nil {
return x.Url
@ -128,7 +121,7 @@ var File_identity_proto protoreflect.FileDescriptor
var file_identity_proto_rawDesc = []byte{
0x0a, 0x0e, 0x69, 0x64, 0x65, 0x6e, 0x74, 0x69, 0x74, 0x79, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f,
0x12, 0x11, 0x70, 0x6f, 0x6d, 0x65, 0x72, 0x69, 0x75, 0x6d, 0x2e, 0x69, 0x64, 0x65, 0x6e, 0x74,
0x69, 0x74, 0x79, 0x22, 0xdc, 0x02, 0x0a, 0x08, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72,
0x69, 0x74, 0x79, 0x22, 0xb3, 0x02, 0x0a, 0x08, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72,
0x12, 0x0e, 0x0a, 0x02, 0x69, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x02, 0x69, 0x64,
0x12, 0x1b, 0x0a, 0x09, 0x63, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x5f, 0x69, 0x64, 0x18, 0x02, 0x20,
0x01, 0x28, 0x09, 0x52, 0x08, 0x63, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x49, 0x64, 0x12, 0x23, 0x0a,
@ -136,24 +129,22 @@ var file_identity_proto_rawDesc = []byte{
0x20, 0x01, 0x28, 0x09, 0x52, 0x0c, 0x63, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x53, 0x65, 0x63, 0x72,
0x65, 0x74, 0x12, 0x12, 0x0a, 0x04, 0x74, 0x79, 0x70, 0x65, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09,
0x52, 0x04, 0x74, 0x79, 0x70, 0x65, 0x12, 0x16, 0x0a, 0x06, 0x73, 0x63, 0x6f, 0x70, 0x65, 0x73,
0x18, 0x05, 0x20, 0x03, 0x28, 0x09, 0x52, 0x06, 0x73, 0x63, 0x6f, 0x70, 0x65, 0x73, 0x12, 0x27,
0x0a, 0x0f, 0x73, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x5f, 0x61, 0x63, 0x63, 0x6f, 0x75, 0x6e,
0x74, 0x18, 0x06, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0e, 0x73, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65,
0x41, 0x63, 0x63, 0x6f, 0x75, 0x6e, 0x74, 0x12, 0x10, 0x0a, 0x03, 0x75, 0x72, 0x6c, 0x18, 0x07,
0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x75, 0x72, 0x6c, 0x12, 0x55, 0x0a, 0x0e, 0x72, 0x65, 0x71,
0x75, 0x65, 0x73, 0x74, 0x5f, 0x70, 0x61, 0x72, 0x61, 0x6d, 0x73, 0x18, 0x08, 0x20, 0x03, 0x28,
0x0b, 0x32, 0x2e, 0x2e, 0x70, 0x6f, 0x6d, 0x65, 0x72, 0x69, 0x75, 0x6d, 0x2e, 0x69, 0x64, 0x65,
0x6e, 0x74, 0x69, 0x74, 0x79, 0x2e, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x2e, 0x52,
0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x50, 0x61, 0x72, 0x61, 0x6d, 0x73, 0x45, 0x6e, 0x74, 0x72,
0x79, 0x52, 0x0d, 0x72, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x50, 0x61, 0x72, 0x61, 0x6d, 0x73,
0x1a, 0x40, 0x0a, 0x12, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x50, 0x61, 0x72, 0x61, 0x6d,
0x73, 0x45, 0x6e, 0x74, 0x72, 0x79, 0x12, 0x10, 0x0a, 0x03, 0x6b, 0x65, 0x79, 0x18, 0x01, 0x20,
0x01, 0x28, 0x09, 0x52, 0x03, 0x6b, 0x65, 0x79, 0x12, 0x14, 0x0a, 0x05, 0x76, 0x61, 0x6c, 0x75,
0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x3a, 0x02,
0x38, 0x01, 0x42, 0x30, 0x5a, 0x2e, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d,
0x2f, 0x70, 0x6f, 0x6d, 0x65, 0x72, 0x69, 0x75, 0x6d, 0x2f, 0x70, 0x6f, 0x6d, 0x65, 0x72, 0x69,
0x75, 0x6d, 0x2f, 0x70, 0x6b, 0x67, 0x2f, 0x67, 0x72, 0x70, 0x63, 0x2f, 0x69, 0x64, 0x65, 0x6e,
0x74, 0x69, 0x74, 0x79, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33,
0x18, 0x05, 0x20, 0x03, 0x28, 0x09, 0x52, 0x06, 0x73, 0x63, 0x6f, 0x70, 0x65, 0x73, 0x12, 0x10,
0x0a, 0x03, 0x75, 0x72, 0x6c, 0x18, 0x07, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x75, 0x72, 0x6c,
0x12, 0x55, 0x0a, 0x0e, 0x72, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x5f, 0x70, 0x61, 0x72, 0x61,
0x6d, 0x73, 0x18, 0x08, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x2e, 0x2e, 0x70, 0x6f, 0x6d, 0x65, 0x72,
0x69, 0x75, 0x6d, 0x2e, 0x69, 0x64, 0x65, 0x6e, 0x74, 0x69, 0x74, 0x79, 0x2e, 0x50, 0x72, 0x6f,
0x76, 0x69, 0x64, 0x65, 0x72, 0x2e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x50, 0x61, 0x72,
0x61, 0x6d, 0x73, 0x45, 0x6e, 0x74, 0x72, 0x79, 0x52, 0x0d, 0x72, 0x65, 0x71, 0x75, 0x65, 0x73,
0x74, 0x50, 0x61, 0x72, 0x61, 0x6d, 0x73, 0x1a, 0x40, 0x0a, 0x12, 0x52, 0x65, 0x71, 0x75, 0x65,
0x73, 0x74, 0x50, 0x61, 0x72, 0x61, 0x6d, 0x73, 0x45, 0x6e, 0x74, 0x72, 0x79, 0x12, 0x10, 0x0a,
0x03, 0x6b, 0x65, 0x79, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x6b, 0x65, 0x79, 0x12,
0x14, 0x0a, 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05,
0x76, 0x61, 0x6c, 0x75, 0x65, 0x3a, 0x02, 0x38, 0x01, 0x42, 0x30, 0x5a, 0x2e, 0x67, 0x69, 0x74,
0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x70, 0x6f, 0x6d, 0x65, 0x72, 0x69, 0x75, 0x6d,
0x2f, 0x70, 0x6f, 0x6d, 0x65, 0x72, 0x69, 0x75, 0x6d, 0x2f, 0x70, 0x6b, 0x67, 0x2f, 0x67, 0x72,
0x70, 0x63, 0x2f, 0x69, 0x64, 0x65, 0x6e, 0x74, 0x69, 0x74, 0x79, 0x62, 0x06, 0x70, 0x72, 0x6f,
0x74, 0x6f, 0x33,
}
var (

View file

@ -9,7 +9,7 @@ message Provider {
string client_secret = 3;
string type = 4;
repeated string scopes = 5;
string service_account = 6;
// string service_account = 6;
string url = 7;
map<string, string> request_params = 8;
}

View file

@ -92,11 +92,6 @@ _import_paths=$(join_by , "${_imports[@]}")
--go_out="$_import_paths,plugins=grpc,paths=source_relative:./device/." \
./device/device.proto
../../scripts/protoc -I ./directory/ \
--go_out="$_import_paths,plugins=grpc,paths=source_relative:./directory/." \
./directory/directory.proto
../../scripts/protoc -I ./identity/ \
--go_out="$_import_paths,plugins=grpc,paths=source_relative:./identity/." \
./identity/identity.proto

View file

@ -1,82 +0,0 @@
package criteria
import (
"github.com/open-policy-agent/opa/ast"
"github.com/pomerium/pomerium/pkg/policy/generator"
"github.com/pomerium/pomerium/pkg/policy/parser"
"github.com/pomerium/pomerium/pkg/policy/rules"
)
var groupsBody = ast.Body{
ast.MustParseExpr(`
session := get_session(input.session.id)
`),
ast.MustParseExpr(`
directory_user := get_directory_user(session)
`),
ast.MustParseExpr(`
group_ids := get_group_ids(session, directory_user)
`),
ast.MustParseExpr(`
group_names := [directory_group.name |
some i
group_id := group_ids[i]
directory_group := get_directory_group(group_id)
directory_group != null
directory_group.name != null]
`),
ast.MustParseExpr(`
group_emails := [directory_group.email |
some i
group_id := group_ids[i]
directory_group := get_directory_group(group_id)
directory_group != null
directory_group.email != null]
`),
ast.MustParseExpr(`
groups = array.concat(group_ids, array.concat(group_names, group_emails))
`),
}
type groupsCriterion struct {
g *Generator
}
func (groupsCriterion) DataType() generator.CriterionDataType {
return CriterionDataTypeStringListMatcher
}
func (groupsCriterion) Name() string {
return "groups"
}
func (c groupsCriterion) GenerateRule(_ string, data parser.Value) (*ast.Rule, []*ast.Rule, error) {
var body ast.Body
body = append(body, groupsBody...)
err := matchStringList(&body, ast.VarTerm("groups"), data)
if err != nil {
return nil, nil, err
}
rule := NewCriterionSessionRule(c.g, c.Name(),
ReasonGroupsOK, ReasonGroupsUnauthorized,
body)
return rule, []*ast.Rule{
rules.GetSession(),
rules.GetDirectoryUser(),
rules.GetDirectoryGroup(),
rules.GetGroupIDs(),
}, nil
}
// Groups returns a Criterion on a user's group ids, names or emails.
func Groups(generator *Generator) Criterion {
return groupsCriterion{g: generator}
}
func init() {
Register(Groups)
}

View file

@ -1,100 +0,0 @@
package criteria
import (
"testing"
"github.com/stretchr/testify/require"
"github.com/pomerium/pomerium/pkg/grpc/directory"
"github.com/pomerium/pomerium/pkg/grpc/session"
)
func TestGroups(t *testing.T) {
t.Run("no session", func(t *testing.T) {
res, err := evaluate(t, `
allow:
and:
- groups:
has: group1
- groups:
has: group2
`, []dataBrokerRecord{}, Input{Session: InputSession{ID: "SESSION_ID"}})
require.NoError(t, err)
require.Equal(t, A{false, A{ReasonUserUnauthenticated}, M{}}, res["allow"])
require.Equal(t, A{false, A{}}, res["deny"])
})
t.Run("by id", func(t *testing.T) {
res, err := evaluate(t, `
allow:
and:
- groups:
has: group1
`,
[]dataBrokerRecord{
&session.Session{
Id: "SESSION_ID",
UserId: "USER_ID",
},
&directory.User{
Id: "USER_ID",
GroupIds: []string{"group1"},
},
},
Input{Session: InputSession{ID: "SESSION_ID"}})
require.NoError(t, err)
require.Equal(t, A{true, A{ReasonGroupsOK}, M{}}, res["allow"])
require.Equal(t, A{false, A{}}, res["deny"])
})
t.Run("by email", func(t *testing.T) {
res, err := evaluate(t, `
allow:
and:
- groups:
has: "group1@example.com"
`,
[]dataBrokerRecord{
&session.Session{
Id: "SESSION_ID",
UserId: "USER_ID",
},
&directory.User{
Id: "USER_ID",
GroupIds: []string{"group1"},
},
&directory.Group{
Id: "group1",
Email: "group1@example.com",
},
},
Input{Session: InputSession{ID: "SESSION_ID"}})
require.NoError(t, err)
require.Equal(t, A{true, A{ReasonGroupsOK}, M{}}, res["allow"])
require.Equal(t, A{false, A{}}, res["deny"])
})
t.Run("by name", func(t *testing.T) {
res, err := evaluate(t, `
allow:
and:
- groups:
has: "Group 1"
`,
[]dataBrokerRecord{
&session.Session{
Id: "SESSION_ID",
UserId: "USER_ID",
},
&directory.User{
Id: "USER_ID",
GroupIds: []string{"group1"},
},
&directory.Group{
Id: "group1",
Name: "Group 1",
},
},
Input{Session: InputSession{ID: "SESSION_ID"}})
require.NoError(t, err)
require.Equal(t, A{true, A{ReasonGroupsOK}, M{}}, res["allow"])
require.Equal(t, A{false, A{}}, res["deny"])
})
}

View file

@ -18,8 +18,6 @@ const (
ReasonDomainUnauthorized = "domain-unauthorized"
ReasonEmailOK = "email-ok"
ReasonEmailUnauthorized = "email-unauthorized"
ReasonGroupsOK = "groups-ok"
ReasonGroupsUnauthorized = "groups-unauthorized"
ReasonHTTPMethodOK = "http-method-ok"
ReasonHTTPMethodUnauthorized = "http-method-unauthorized"
ReasonHTTPPathOK = "http-path-ok"

View file

@ -74,42 +74,6 @@ get_device_enrollment(device_credential) = v {
`)
}
// GetDirectoryUser returns the directory user for the given session.
func GetDirectoryUser() *ast.Rule {
return ast.MustParseRule(`
get_directory_user(session) = v {
v = get_databroker_record("type.googleapis.com/directory.User", session.user_id)
v != null
} else = "" {
true
}
`)
}
// GetDirectoryGroup returns the directory group for the given id.
func GetDirectoryGroup() *ast.Rule {
return ast.MustParseRule(`
get_directory_group(id) = v {
v = get_databroker_record("type.googleapis.com/directory.Group", id)
v != null
} else = {} {
true
}
`)
}
// GetGroupIDs returns the group ids for the given session or directory user.
func GetGroupIDs() *ast.Rule {
return ast.MustParseRule(`
get_group_ids(session, directory_user) = v {
v = directory_user.group_ids
v != null
} else = [] {
true
}
`)
}
// MergeWithAnd merges criterion results using `and`.
func MergeWithAnd() *ast.Rule {
return ast.MustParseRule(`