diff --git a/authenticate/handlers.go b/authenticate/handlers.go index d8515edf9..dbf60a173 100644 --- a/authenticate/handlers.go +++ b/authenticate/handlers.go @@ -596,5 +596,13 @@ func (a *Authenticate) saveSessionToDataBroker(ctx context.Context, sessionState } sessionState.Version = sessions.Version(res.GetServerVersion()) + _, err = state.directoryClient.RefreshUser(ctx, &directory.RefreshUserRequest{ + UserId: s.UserId, + AccessToken: accessToken.AccessToken, + }) + if err != nil { + log.Error().Err(err).Msg("directory: failed to refresh user data") + } + return nil } diff --git a/authenticate/handlers_test.go b/authenticate/handlers_test.go index 7a5634a83..08c43fa78 100644 --- a/authenticate/handlers_test.go +++ b/authenticate/handlers_test.go @@ -14,6 +14,8 @@ import ( "time" "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" "google.golang.org/protobuf/types/known/emptypb" "github.com/pomerium/pomerium/config" @@ -29,10 +31,12 @@ import ( "github.com/pomerium/pomerium/internal/urlutil" "github.com/pomerium/pomerium/pkg/cryptutil" "github.com/pomerium/pomerium/pkg/grpc/databroker" + "github.com/pomerium/pomerium/pkg/grpc/directory" "github.com/pomerium/pomerium/pkg/grpc/session" "github.com/golang/mock/gomock" "github.com/golang/protobuf/ptypes" + "github.com/golang/protobuf/ptypes/empty" "github.com/google/go-cmp/cmp" "github.com/stretchr/testify/assert" "golang.org/x/crypto/chacha20poly1305" @@ -171,6 +175,7 @@ func TestAuthenticate_SignIn(t *testing.T) { }, nil }, }, + directoryClient: new(mockDirectoryServiceClient), }), options: config.NewAtomicOptions(), @@ -262,6 +267,7 @@ func TestAuthenticate_SignOut(t *testing.T) { }, nil }, }, + directoryClient: new(mockDirectoryServiceClient), }), templates: template.Must(frontend.NewTemplates()), options: config.NewAtomicOptions(), @@ -366,6 +372,7 @@ func TestAuthenticate_OAuthCallback(t *testing.T) { return &databroker.SetResponse{Record: &databroker.Record{Data: in.Data}}, nil }, }, + directoryClient: new(mockDirectoryServiceClient), redirectURL: authURL, sessionStore: tt.session, cookieCipher: aead, @@ -515,6 +522,7 @@ func TestAuthenticate_SessionValidatorMiddleware(t *testing.T) { }, nil }, }, + directoryClient: new(mockDirectoryServiceClient), }), options: config.NewAtomicOptions(), provider: identity.NewAtomicAuthenticator(), @@ -633,6 +641,7 @@ func TestAuthenticate_Dashboard(t *testing.T) { }, nil }, }, + directoryClient: new(mockDirectoryServiceClient), }), templates: template.Must(frontend.NewTemplates()), } @@ -679,3 +688,16 @@ func (m mockDataBrokerServiceClient) Get(ctx context.Context, in *databroker.Get func (m mockDataBrokerServiceClient) Set(ctx context.Context, in *databroker.SetRequest, opts ...grpc.CallOption) (*databroker.SetResponse, error) { return m.set(ctx, in, opts...) } + +type mockDirectoryServiceClient struct { + directory.DirectoryServiceClient + + refreshUser func(ctx context.Context, in *directory.RefreshUserRequest, opts ...grpc.CallOption) (*empty.Empty, error) +} + +func (m mockDirectoryServiceClient) RefreshUser(ctx context.Context, in *directory.RefreshUserRequest, opts ...grpc.CallOption) (*empty.Empty, error) { + if m.refreshUser != nil { + return m.refreshUser(ctx, in, opts...) + } + return nil, status.Error(codes.Unimplemented, "") +} diff --git a/authenticate/state.go b/authenticate/state.go index 00eae4d43..990eee294 100644 --- a/authenticate/state.go +++ b/authenticate/state.go @@ -22,6 +22,7 @@ import ( "github.com/pomerium/pomerium/pkg/cryptutil" "github.com/pomerium/pomerium/pkg/grpc" "github.com/pomerium/pomerium/pkg/grpc/databroker" + "github.com/pomerium/pomerium/pkg/grpc/directory" ) type authenticateState struct { @@ -46,6 +47,7 @@ type authenticateState struct { jwk *jose.JSONWebKeySet dataBrokerClient databroker.DataBrokerServiceClient + directoryClient directory.DirectoryServiceClient } func newAuthenticateState() *authenticateState { @@ -129,6 +131,7 @@ func newAuthenticateStateFromConfig(cfg *config.Config) (*authenticateState, err } state.dataBrokerClient = databroker.NewDataBrokerServiceClient(dataBrokerConn) + state.directoryClient = directory.NewDirectoryServiceClient(dataBrokerConn) return state, nil } diff --git a/cache/cache.go b/cache/cache.go index d8bf8e336..aa23f31b4 100644 --- a/cache/cache.go +++ b/cache/cache.go @@ -7,6 +7,7 @@ import ( "context" "fmt" "net" + "sync" "google.golang.org/grpc" "gopkg.in/tomb.v2" @@ -33,6 +34,9 @@ type Cache struct { localGRPCConnection *grpc.ClientConn dataBrokerStorageType string //TODO remove in v0.11 deprecatedCacheClusterDomain string //TODO: remove in v0.11 + + mu sync.Mutex + directoryProvider directory.Provider } // New creates a new cache service. @@ -90,6 +94,7 @@ func (c *Cache) OnConfigChange(cfg *config.Config) { // Register registers all the gRPC services with the given server. func (c *Cache) Register(grpcServer *grpc.Server) { databroker.RegisterDataBrokerServiceServer(grpcServer, c.dataBrokerServer) + directory.RegisterDirectoryServiceServer(grpcServer, c) } // Run runs the cache components. @@ -132,6 +137,9 @@ func (c *Cache) update(cfg *config.Config) error { ClientID: cfg.Options.ClientID, ClientSecret: cfg.Options.ClientSecret, }) + c.mu.Lock() + c.directoryProvider = directoryProvider + c.mu.Unlock() dataBrokerClient := databroker.NewDataBrokerServiceClient(c.localGRPCConnection) diff --git a/cache/directory.go b/cache/directory.go new file mode 100644 index 000000000..ee5561316 --- /dev/null +++ b/cache/directory.go @@ -0,0 +1,44 @@ +package cache + +import ( + "context" + "errors" + + "google.golang.org/protobuf/types/known/anypb" + "google.golang.org/protobuf/types/known/emptypb" + + "github.com/pomerium/pomerium/pkg/grpc/databroker" + "github.com/pomerium/pomerium/pkg/grpc/directory" +) + +// RefreshUser refreshes a user's directory information. +func (c *Cache) RefreshUser(ctx context.Context, req *directory.RefreshUserRequest) (*emptypb.Empty, error) { + c.mu.Lock() + dp := c.directoryProvider + c.mu.Unlock() + + if dp == nil { + return nil, errors.New("no directory provider is available for refresh") + } + + u, err := dp.User(ctx, req.GetUserId(), req.GetAccessToken()) + if err != nil { + return nil, err + } + + any, err := anypb.New(u) + if err != nil { + return nil, err + } + + _, err = c.dataBrokerServer.Set(ctx, &databroker.SetRequest{ + Type: any.GetTypeUrl(), + Id: u.GetId(), + Data: any, + }) + if err != nil { + return nil, err + } + + return new(emptypb.Empty), nil +} diff --git a/internal/directory/auth0/auth0.go b/internal/directory/auth0/auth0.go index 6f11ec745..65b1d8e6a 100644 --- a/internal/directory/auth0/auth0.go +++ b/internal/directory/auth0/auth0.go @@ -21,16 +21,33 @@ import ( // Name is the provider name. const Name = "auth0" -// RoleManager defines what is needed to get role info from Auth0. -type RoleManager interface { - List(opts ...management.ListOption) (r *management.RoleList, err error) - Users(id string, opts ...management.ListOption) (u *management.UserList, err error) +type ( + // RoleManager defines what is needed to get role info from Auth0. + RoleManager interface { + List(opts ...management.ListOption) (r *management.RoleList, err error) + Users(id string, opts ...management.ListOption) (u *management.UserList, err error) + } + // UserManager defines what is needed to get user info from Auth0. + UserManager interface { + Read(id string) (*management.User, error) + Roles(id string, opts ...management.ListOption) (r *management.RoleList, err error) + } +) + +type newManagersFunc = func(ctx context.Context, domain string, serviceAccount *ServiceAccount) (RoleManager, UserManager, error) + +func defaultNewManagersFunc(ctx context.Context, domain string, serviceAccount *ServiceAccount) (RoleManager, UserManager, error) { + m, err := management.New(domain, serviceAccount.ClientID, serviceAccount.Secret, management.WithContext(ctx)) + if err != nil { + return nil, nil, fmt.Errorf("auth0: could not build management: %w", err) + } + return m.Role, m.User, nil } type config struct { domain string serviceAccount *ServiceAccount - newRoleManager func(ctx context.Context, domain string, serviceAccount *ServiceAccount) (RoleManager, error) + newManagers newManagersFunc } // Option provides config for the Auth0 Provider. @@ -50,17 +67,9 @@ func WithDomain(domain string) Option { } } -func defaultNewRoleManagerFunc(ctx context.Context, domain string, serviceAccount *ServiceAccount) (RoleManager, error) { - m, err := management.New(domain, serviceAccount.ClientID, serviceAccount.Secret, management.WithContext(ctx)) - if err != nil { - return nil, fmt.Errorf("auth0: could not build management") - } - return m.Role, nil -} - func getConfig(options ...Option) *config { cfg := &config{ - newRoleManager: defaultNewRoleManagerFunc, + newManagers: defaultNewManagersFunc, } for _, option := range options { option(cfg) @@ -82,13 +91,49 @@ func New(options ...Option) *Provider { } } -func (p *Provider) getRoleManager(ctx context.Context) (RoleManager, error) { - return p.cfg.newRoleManager(ctx, p.cfg.domain, p.cfg.serviceAccount) +func (p *Provider) getManagers(ctx context.Context) (RoleManager, UserManager, error) { + return p.cfg.newManagers(ctx, p.cfg.domain, p.cfg.serviceAccount) +} + +// User returns the user record for the given id. +func (p *Provider) User(ctx context.Context, userID, accessToken string) (*directory.User, error) { + _, um, err := p.getManagers(ctx) + if err != nil { + return nil, fmt.Errorf("auth0: could not get the role manager: %w", err) + } + + _, providerUserID := databroker.FromUserID(userID) + du := &directory.User{ + Id: userID, + } + + u, err := um.Read(providerUserID) + if err != nil { + return nil, fmt.Errorf("auth0: error getting user info: %w", err) + } + du.DisplayName = u.GetName() + du.Email = u.GetEmail() + + for page, hasNext := 0, true; hasNext; page++ { + rl, err := um.Roles(providerUserID, management.IncludeTotals(true), management.Page(page)) + if err != nil { + return nil, fmt.Errorf("auth0: error getting user roles: %w", err) + } + + for _, role := range rl.Roles { + du.GroupIds = append(du.GroupIds, role.GetID()) + } + + hasNext = rl.HasNext() + } + + sort.Strings(du.GroupIds) + return du, nil } // UserGroups fetches a slice of groups and users. func (p *Provider) UserGroups(ctx context.Context) ([]*directory.Group, []*directory.User, error) { - rm, err := p.getRoleManager(ctx) + rm, _, err := p.getManagers(ctx) if err != nil { return nil, nil, fmt.Errorf("auth0: could not get the role manager: %w", err) } @@ -147,6 +192,9 @@ func getRoles(rm RoleManager) ([]*directory.Group, error) { shouldContinue = listRes.HasNext() } + sort.Slice(roles, func(i, j int) bool { + return roles[i].GetId() < roles[j].GetId() + }) return roles, nil } @@ -170,6 +218,7 @@ func getRoleUserIDs(rm RoleManager, roleID string) ([]string, error) { shouldContinue = usersRes.HasNext() } + sort.Strings(ids) return ids, nil } @@ -180,7 +229,30 @@ type ServiceAccount struct { } // ParseServiceAccount parses the service account in the config options. -func ParseServiceAccount(rawServiceAccount string) (*ServiceAccount, error) { +func ParseServiceAccount(options directory.Options) (*ServiceAccount, error) { + if options.ServiceAccount != "" { + return parseServiceAccountFromString(options.ServiceAccount) + } + return parseServiceAccountFromOptions(options.ClientID, options.ClientSecret) +} + +func parseServiceAccountFromOptions(clientID, clientSecret string) (*ServiceAccount, error) { + serviceAccount := ServiceAccount{ + ClientID: clientID, + Secret: clientSecret, + } + + if serviceAccount.ClientID == "" { + return nil, fmt.Errorf("auth0: client_id is required") + } + if serviceAccount.Secret == "" { + return nil, fmt.Errorf("auth0: client_secret is required") + } + + return &serviceAccount, nil +} + +func parseServiceAccountFromString(rawServiceAccount string) (*ServiceAccount, error) { bs, err := base64.StdEncoding.DecodeString(rawServiceAccount) if err != nil { return nil, fmt.Errorf("auth0: could not decode base64: %w", err) diff --git a/internal/directory/auth0/auth0_test.go b/internal/directory/auth0/auth0_test.go index d462ce85e..451d48ff1 100644 --- a/internal/directory/auth0/auth0_test.go +++ b/internal/directory/auth0/auth0_test.go @@ -2,34 +2,133 @@ package auth0 import ( "context" + "crypto/tls" + "encoding/json" "errors" "fmt" + "net/http" + "net/http/httptest" "net/url" "testing" + "time" + "github.com/go-chi/chi" + "github.com/go-chi/chi/middleware" "github.com/golang/mock/gomock" "github.com/stretchr/testify/assert" "gopkg.in/auth0.v4/management" "github.com/pomerium/pomerium/internal/directory/auth0/mock_auth0" + "github.com/pomerium/pomerium/internal/testutil" "github.com/pomerium/pomerium/pkg/grpc/directory" ) +type M = map[string]interface{} + +func newMockAPI(t *testing.T, srv *httptest.Server) http.Handler { + r := chi.NewRouter() + r.Use(middleware.Logger) + r.Post("/oauth/token", func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(M{ + "access_token": "ACCESSTOKEN", + "token_type": "Bearer", + "refresh_token": "REFRESHTOKEN", + }) + }) + r.Route("/api/v2", func(r chi.Router) { + r.Route("/users", func(r chi.Router) { + r.Route("/{user_id}", func(r chi.Router) { + r.Get("/roles", func(w http.ResponseWriter, r *http.Request) { + switch chi.URLParam(r, "user_id") { + case "user1": + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(M{ + "total": 2, + "limit": 2, + "roles": []M{ + {"id": "role1"}, + {"id": "role2"}, + }, + }) + default: + http.Error(w, "not found", http.StatusNotFound) + } + }) + r.Get("/", func(w http.ResponseWriter, r *http.Request) { + switch chi.URLParam(r, "user_id") { + case "user1": + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(M{ + "user_id": "user1", + "email": "user1@example.com", + "name": "User 1", + }) + default: + http.Error(w, "not found", http.StatusNotFound) + } + }) + }) + }) + }) + + return r +} + +func TestProvider_User(t *testing.T) { + ctx, clearTimeout := context.WithTimeout(context.Background(), time.Minute) + defer clearTimeout() + + orig := http.DefaultTransport + defer func() { + http.DefaultTransport = orig + }() + http.DefaultTransport = &http.Transport{ + TLSClientConfig: &tls.Config{ + InsecureSkipVerify: true, + }, + } + + var mockAPI http.Handler + srv := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + mockAPI.ServeHTTP(w, r) + })) + srv.StartTLS() + defer srv.Close() + mockAPI = newMockAPI(t, srv) + + p := New( + WithDomain(srv.URL), + WithServiceAccount(&ServiceAccount{ClientID: "CLIENT_ID", Secret: "SECRET"}), + ) + du, err := p.User(ctx, "auth0/user1", "") + if !assert.NoError(t, err) { + return + } + testutil.AssertProtoJSONEqual(t, `{ + "id": "auth0/user1", + "displayName": "User 1", + "email": "user1@example.com", + "groupIds": ["role1", "role2"] + }`, du) +} + type mockNewRoleManagerFunc struct { CalledWithContext context.Context CalledWithDomain string CalledWithServiceAccount *ServiceAccount ReturnRoleManager RoleManager + ReturnUserManager UserManager ReturnError error } -func (m *mockNewRoleManagerFunc) f(ctx context.Context, domain string, serviceAccount *ServiceAccount) (RoleManager, error) { +func (m *mockNewRoleManagerFunc) f(ctx context.Context, domain string, serviceAccount *ServiceAccount) (RoleManager, UserManager, error) { m.CalledWithContext = ctx m.CalledWithDomain = domain m.CalledWithServiceAccount = serviceAccount - return m.ReturnRoleManager, m.ReturnError + return m.ReturnRoleManager, m.ReturnUserManager, m.ReturnError } type listOptionMatcher struct { @@ -379,7 +478,7 @@ func TestProvider_UserGroups(t *testing.T) { mRoleManager := mock_auth0.NewMockRoleManager(ctrl) - mNewRoleManagerFunc := mockNewRoleManagerFunc{ + mNewManagersFunc := mockNewRoleManagerFunc{ ReturnRoleManager: mRoleManager, ReturnError: tc.newRoleManagerError, } @@ -392,7 +491,7 @@ func TestProvider_UserGroups(t *testing.T) { WithDomain(expectedDomain), WithServiceAccount(expectedServiceAccount), ) - p.cfg.newRoleManager = mNewRoleManagerFunc.f + p.cfg.newManagers = mNewManagersFunc.f actualGroups, actualUsers, err := p.UserGroups(context.Background()) if tc.expectedError != nil { @@ -404,8 +503,8 @@ func TestProvider_UserGroups(t *testing.T) { assert.Equal(t, tc.expectedGroups, actualGroups) assert.Equal(t, tc.expectedUsers, actualUsers) - assert.Equal(t, expectedDomain, mNewRoleManagerFunc.CalledWithDomain) - assert.Equal(t, expectedServiceAccount, mNewRoleManagerFunc.CalledWithServiceAccount) + assert.Equal(t, expectedDomain, mNewManagersFunc.CalledWithDomain) + assert.Equal(t, expectedServiceAccount, mNewManagersFunc.CalledWithServiceAccount) }) } @@ -433,7 +532,7 @@ func TestParseServiceAccount(t *testing.T) { for _, tc := range tests { tc := tc t.Run(tc.name, func(t *testing.T) { - actualServiceAccount, err := ParseServiceAccount(tc.rawServiceAccount) + actualServiceAccount, err := ParseServiceAccount(directory.Options{ServiceAccount: tc.rawServiceAccount}) if tc.expectedError != nil { assert.EqualError(t, err, tc.expectedError.Error()) } else { diff --git a/internal/directory/azure/azure.go b/internal/directory/azure/azure.go index 17bd2e3d3..0f39533c8 100644 --- a/internal/directory/azure/azure.go +++ b/internal/directory/azure/azure.go @@ -6,14 +6,15 @@ import ( "encoding/base64" "encoding/json" "fmt" - "io" "net/http" "net/url" + "sort" "strings" "sync" "golang.org/x/oauth2" + "github.com/pomerium/pomerium/pkg/grpc/databroker" "github.com/pomerium/pomerium/pkg/grpc/directory" ) @@ -101,6 +102,50 @@ func New(options ...Option) *Provider { return p } +// User returns the user record for the given id. +func (p *Provider) User(ctx context.Context, userID, accessToken string) (*directory.User, error) { + if p.cfg.serviceAccount == nil { + return nil, fmt.Errorf("azure: service account not defined") + } + + _, providerUserID := databroker.FromUserID(userID) + + du := &directory.User{ + Id: userID, + } + + userURL := p.cfg.graphURL.ResolveReference(&url.URL{ + Path: fmt.Sprintf("/v1.0/users/%s", providerUserID), + }).String() + + var u usersDeltaResponseUser + err := p.api(ctx, userURL, &u) + if err != nil { + return nil, err + } + du.DisplayName = u.DisplayName + du.Email = u.getEmail() + + groupURL := p.cfg.graphURL.ResolveReference(&url.URL{ + Path: fmt.Sprintf("/v1.0/users/%s/transitiveMemberOf", providerUserID), + }).String() + + var res struct { + Value []usersDeltaResponseUser `json:"value"` + } + err = p.api(ctx, groupURL, &res) + if err != nil { + return nil, err + } + for _, g := range res.Value { + du.GroupIds = append(du.GroupIds, g.ID) + } + + sort.Strings(du.GroupIds) + + return du, nil +} + // UserGroups returns the directory users in azure active directory. func (p *Provider) UserGroups(ctx context.Context) ([]*directory.Group, []*directory.User, error) { if p.cfg.serviceAccount == nil { @@ -116,13 +161,13 @@ func (p *Provider) UserGroups(ctx context.Context) ([]*directory.Group, []*direc return groups, users, nil } -func (p *Provider) api(ctx context.Context, method, url string, body io.Reader, out interface{}) error { +func (p *Provider) api(ctx context.Context, url string, out interface{}) error { token, err := p.getToken(ctx) if err != nil { return err } - req, err := http.NewRequestWithContext(ctx, method, url, body) + req, err := http.NewRequestWithContext(ctx, "GET", url, nil) if err != nil { return fmt.Errorf("azure: error creating HTTP request: %w", err) } @@ -143,7 +188,7 @@ func (p *Provider) api(ctx context.Context, method, url string, body io.Reader, } if res.StatusCode/100 != 2 { - return fmt.Errorf("azure: error querying api: %s", res.Status) + return fmt.Errorf("azure: error querying api (%s): %s", url, res.Status) } err = json.NewDecoder(res.Body).Decode(out) diff --git a/internal/directory/azure/azure_test.go b/internal/directory/azure/azure_test.go index 6f8cb0251..a256b45a2 100644 --- a/internal/directory/azure/azure_test.go +++ b/internal/directory/azure/azure_test.go @@ -13,6 +13,7 @@ import ( "github.com/go-chi/chi/middleware" "github.com/stretchr/testify/assert" + "github.com/pomerium/pomerium/internal/testutil" "github.com/pomerium/pomerium/pkg/grpc/directory" ) @@ -74,12 +75,62 @@ func newMockAPI(t *testing.T, srv *httptest.Server) http.Handler { }, }) }) + r.Get("/users/{user_id}", func(w http.ResponseWriter, r *http.Request) { + switch chi.URLParam(r, "user_id") { + case "user-1": + _ = json.NewEncoder(w).Encode(M{"id": "user-1", "displayName": "User 1", "mail": "user1@example.com"}) + default: + http.Error(w, "not found", http.StatusNotFound) + } + }) + r.Get("/users/{user_id}/transitiveMemberOf", func(w http.ResponseWriter, r *http.Request) { + switch chi.URLParam(r, "user_id") { + case "user-1": + _ = json.NewEncoder(w).Encode(M{ + "value": []M{ + {"id": "admin"}, + }, + }) + default: + http.Error(w, "not found", http.StatusNotFound) + } + }) }) return r } -func Test(t *testing.T) { +func TestProvider_User(t *testing.T) { + var mockAPI http.Handler + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + mockAPI.ServeHTTP(w, r) + })) + defer srv.Close() + mockAPI = newMockAPI(t, srv) + + p := New( + WithGraphURL(mustParseURL(srv.URL)), + WithLoginURL(mustParseURL(srv.URL)), + WithServiceAccount(&ServiceAccount{ + ClientID: "CLIENT_ID", + ClientSecret: "CLIENT_SECRET", + DirectoryID: "DIRECTORY_ID", + }), + ) + + du, err := p.User(context.Background(), "azure/user-1", "") + if !assert.NoError(t, err) { + return + } + testutil.AssertProtoJSONEqual(t, `{ + "id": "azure/user-1", + "displayName": "User 1", + "email": "user1@example.com", + "groupIds": ["admin"] + }`, du) +} + +func TestProvider_UserGroups(t *testing.T) { var mockAPI http.Handler srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { mockAPI.ServeHTTP(w, r) @@ -118,10 +169,10 @@ func Test(t *testing.T) { Email: "user3@example.com", }, }, users) - assert.Equal(t, []*directory.Group{ - {Id: "admin", Name: "Admin Group"}, - {Id: "test", Name: "Test Group"}, - }, groups) + testutil.AssertProtoJSONEqual(t, `[ + { "id": "admin", "name": "Admin Group" }, + { "id": "test", "name": "Test Group"} + ]`, groups) } func TestParseServiceAccount(t *testing.T) { diff --git a/internal/directory/azure/delta.go b/internal/directory/azure/delta.go index a6d3ed667..cfb21ec27 100644 --- a/internal/directory/azure/delta.go +++ b/internal/directory/azure/delta.go @@ -86,7 +86,7 @@ func (dc *deltaCollection) syncGroups(ctx context.Context) error { for { var res groupsDeltaResponse - err := dc.provider.api(ctx, "GET", apiURL, nil, &res) + err := dc.provider.api(ctx, apiURL, &res) if err != nil { return err } @@ -146,7 +146,7 @@ func (dc *deltaCollection) syncUsers(ctx context.Context) error { for { var res usersDeltaResponse - err := dc.provider.api(ctx, "GET", apiURL, nil, &res) + err := dc.provider.api(ctx, apiURL, &res) if err != nil { return err } @@ -197,6 +197,9 @@ func (dc *deltaCollection) CurrentUserGroups() ([]*directory.Group, []*directory } groupLookup.addGroup(g.id, groupIDs, userIDs) } + sort.Slice(groups, func(i, j int) bool { + return groups[i].GetId() < groups[j].GetId() + }) var users []*directory.User for _, u := range dc.users { diff --git a/internal/directory/github/github.go b/internal/directory/github/github.go index 7db5e2978..193acbdcc 100644 --- a/internal/directory/github/github.go +++ b/internal/directory/github/github.go @@ -2,6 +2,7 @@ package github import ( + "bytes" "context" "encoding/base64" "encoding/json" @@ -83,6 +84,47 @@ func New(options ...Option) *Provider { } } +// User returns the user record for the given id. +func (p *Provider) User(ctx context.Context, userID, accessToken string) (*directory.User, error) { + if p.cfg.serviceAccount == nil { + return nil, fmt.Errorf("github: service account not defined") + } + + _, providerUserID := databroker.FromUserID(userID) + du := &directory.User{ + Id: userID, + } + + au, err := p.getUser(ctx, providerUserID) + if err != nil { + return nil, err + } + du.DisplayName = au.Name + du.Email = au.Email + + teamIDLookup := map[int]struct{}{} + orgSlugs, err := p.listOrgs(ctx) + if err != nil { + return nil, err + } + for _, orgSlug := range orgSlugs { + teamIDs, err := p.listUserOrganizationTeams(ctx, userID, orgSlug) + if err != nil { + return nil, err + } + for _, teamID := range teamIDs { + teamIDLookup[teamID] = struct{}{} + } + } + + for teamID := range teamIDLookup { + du.GroupIds = append(du.GroupIds, strconv.Itoa(teamID)) + } + sort.Strings(du.GroupIds) + + return du, nil +} + // UserGroups gets the directory user groups for github. func (p *Provider) UserGroups(ctx context.Context) ([]*directory.Group, []*directory.User, error) { if p.cfg.serviceAccount == nil { @@ -230,6 +272,77 @@ func (p *Provider) getUser(ctx context.Context, userLogin string) (*apiUserObjec return &res, nil } +func (p *Provider) listUserOrganizationTeams(ctx context.Context, userSlug string, orgSlug string) ([]int, error) { + // GitHub's Rest API doesn't have an easy way of querying this data, so we use the GraphQL API. + + enc := func(obj interface{}) string { + bs, _ := json.Marshal(obj) + return string(bs) + } + const pageCount = 100 + + var teamIDs []int + var cursor *string + for { + var res struct { + Data struct { + Organization struct { + Teams struct { + TotalCount int `json:"totalCount"` + PageInfo struct { + EndCursor string `json:"endCursor"` + } `json:"pageInfo"` + Edges []struct { + Node struct { + ID int `json:"id"` + } `json:"node"` + } `json:"edges"` + } `json:"teams"` + } `json:"organization"` + } `json:"data"` + } + cursorStr := "" + if cursor != nil { + cursorStr = fmt.Sprintf(",%s", enc(*cursor)) + } + q := fmt.Sprintf(`query { + organization(login:%s) { + teams(first:%s, userLogins:[%s] %s) { + totalCount + pageInfo { + endCursor + } + edges { + node { + id + } + } + } + } + }`, enc(orgSlug), enc(pageCount), enc(userSlug), cursorStr) + _, err := p.graphql(ctx, q, &res) + if err != nil { + return nil, err + } + + if len(res.Data.Organization.Teams.Edges) == 0 { + break + } + + for _, edge := range res.Data.Organization.Teams.Edges { + teamIDs = append(teamIDs, edge.Node.ID) + } + + if len(teamIDs) >= res.Data.Organization.Teams.TotalCount { + break + } + + cursor = &res.Data.Organization.Teams.PageInfo.EndCursor + } + + return teamIDs, nil +} + func (p *Provider) api(ctx context.Context, apiURL string, out interface{}) (http.Header, error) { req, err := http.NewRequestWithContext(ctx, "GET", apiURL, nil) if err != nil { @@ -257,6 +370,41 @@ func (p *Provider) api(ctx context.Context, apiURL string, out interface{}) (htt return res.Header, nil } +func (p *Provider) graphql(ctx context.Context, query string, out interface{}) (http.Header, error) { + apiURL := p.cfg.url.ResolveReference(&url.URL{ + Path: "/graphql", + }).String() + + bs, _ := json.Marshal(struct { + Query string `json:"query"` + }{query}) + + req, err := http.NewRequestWithContext(ctx, "POST", apiURL, bytes.NewReader(bs)) + if err != nil { + return nil, fmt.Errorf("github: failed to create http request: %w", err) + } + req.SetBasicAuth(p.cfg.serviceAccount.Username, p.cfg.serviceAccount.PersonalAccessToken) + + res, err := p.cfg.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("github: failed to make http request: %w", err) + } + defer res.Body.Close() + + if res.StatusCode/100 != 2 { + return nil, fmt.Errorf("github: error from API: %s", res.Status) + } + + if out != nil { + err := json.NewDecoder(res.Body).Decode(out) + if err != nil { + return nil, fmt.Errorf("github: failed to decode json body: %w", err) + } + } + + return res.Header, nil +} + func getNextLink(hdrs http.Header) string { for _, link := range linkheader.ParseMultiple(hdrs.Values("Link")) { if link.Rel == "next" { diff --git a/internal/directory/github/github_test.go b/internal/directory/github/github_test.go index a9e52f878..1f9e7cd72 100644 --- a/internal/directory/github/github_test.go +++ b/internal/directory/github/github_test.go @@ -29,6 +29,33 @@ func newMockAPI(t *testing.T, srv *httptest.Server) http.Handler { next.ServeHTTP(w, r) }) }) + r.Post("/graphql", func(w http.ResponseWriter, r *http.Request) { + var body struct { + Query string `json:"query"` + } + json.NewDecoder(r.Body).Decode(&body) + + _ = json.NewEncoder(w).Encode(M{ + "data": M{ + "organization": M{ + "teams": M{ + "totalCount": 3, + "edges": []M{ + {"node": M{ + "id": 1, + }}, + {"node": M{ + "id": 2, + }}, + {"node": M{ + "id": 3, + }}, + }, + }, + }, + }, + }) + }) r.Get("/user/orgs", func(w http.ResponseWriter, r *http.Request) { json.NewEncoder(w).Encode([]M{ {"login": "org1"}, @@ -88,7 +115,34 @@ func newMockAPI(t *testing.T, srv *httptest.Server) http.Handler { return r } -func Test(t *testing.T) { +func TestProvider_User(t *testing.T) { + var mockAPI http.Handler + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + mockAPI.ServeHTTP(w, r) + })) + defer srv.Close() + mockAPI = newMockAPI(t, srv) + + p := New( + WithURL(mustParseURL(srv.URL)), + WithServiceAccount(&ServiceAccount{ + Username: "abc", + PersonalAccessToken: "xyz", + }), + ) + du, err := p.User(context.Background(), "github/user1", "") + if !assert.NoError(t, err) { + return + } + testutil.AssertProtoJSONEqual(t, `{ + "id": "github/user1", + "groupIds": ["1", "2", "3"], + "displayName": "User 1", + "email": "user1@example.com" + }`, du) +} + +func TestProvider_UserGroups(t *testing.T) { var mockAPI http.Handler srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { mockAPI.ServeHTTP(w, r) diff --git a/internal/directory/gitlab/gitlab.go b/internal/directory/gitlab/gitlab.go index 6ba5a0cd2..fec3f243f 100644 --- a/internal/directory/gitlab/gitlab.go +++ b/internal/directory/gitlab/gitlab.go @@ -83,6 +83,32 @@ func New(options ...Option) *Provider { } } +// User returns the user record for the given id. +func (p *Provider) User(ctx context.Context, userID, accessToken string) (*directory.User, error) { + _, providerUserID := databroker.FromUserID(userID) + du := &directory.User{ + Id: userID, + } + + au, err := p.getUser(ctx, providerUserID, accessToken) + if err != nil { + return nil, err + } + du.DisplayName = au.Name + du.Email = au.Email + + groups, err := p.listGroups(ctx, accessToken) + if err != nil { + return nil, err + } + for _, g := range groups { + du.GroupIds = append(du.GroupIds, g.Id) + } + sort.Strings(du.GroupIds) + + return du, nil +} + // UserGroups gets the directory user groups for gitlab. func (p *Provider) UserGroups(ctx context.Context) ([]*directory.Group, []*directory.User, error) { if p.cfg.serviceAccount == nil { @@ -91,7 +117,7 @@ func (p *Provider) UserGroups(ctx context.Context) ([]*directory.Group, []*direc p.log.Info().Msg("getting user groups") - groups, err := p.listGroups(ctx) + groups, err := p.listGroups(ctx, "") if err != nil { return nil, nil, err } @@ -129,8 +155,20 @@ func (p *Provider) UserGroups(ctx context.Context) ([]*directory.Group, []*direc return groups, users, nil } +func (p *Provider) getUser(ctx context.Context, userID string, accessToken string) (*apiUserObject, error) { + apiURL := p.cfg.url.ResolveReference(&url.URL{ + Path: fmt.Sprintf("/api/v4/users/%s", userID), + }).String() + var result apiUserObject + _, err := p.api(ctx, accessToken, apiURL, &result) + if err != nil { + return nil, fmt.Errorf("gitlab: error querying user: %w", err) + } + return &result, nil +} + // listGroups returns a map, with key is group ID, element is group name. -func (p *Provider) listGroups(ctx context.Context) ([]*directory.Group, error) { +func (p *Provider) listGroups(ctx context.Context, accessToken string) ([]*directory.Group, error) { nextURL := p.cfg.url.ResolveReference(&url.URL{ Path: "/api/v4/groups", }).String() @@ -140,7 +178,7 @@ func (p *Provider) listGroups(ctx context.Context) ([]*directory.Group, error) { ID int `json:"id"` Name string `json:"name"` } - hdrs, err := p.apiGet(ctx, nextURL, &result) + hdrs, err := p.api(ctx, accessToken, nextURL, &result) if err != nil { return nil, fmt.Errorf("gitlab: error querying groups: %w", err) } @@ -163,7 +201,7 @@ func (p *Provider) listGroupMembers(ctx context.Context, groupID string) (users }).String() for nextURL != "" { var result []apiUserObject - hdrs, err := p.apiGet(ctx, nextURL, &result) + hdrs, err := p.api(ctx, "", nextURL, &result) if err != nil { return nil, fmt.Errorf("gitlab: error querying group members: %w", err) } @@ -174,14 +212,18 @@ func (p *Provider) listGroupMembers(ctx context.Context, groupID string) (users return users, nil } -func (p *Provider) apiGet(ctx context.Context, uri string, out interface{}) (http.Header, error) { +func (p *Provider) api(ctx context.Context, accessToken string, uri string, out interface{}) (http.Header, error) { req, err := http.NewRequestWithContext(ctx, "GET", uri, nil) if err != nil { return nil, fmt.Errorf("gitlab: failed to create HTTP request: %w", err) } req.Header.Set("Accept", "application/json") req.Header.Set("Content-Type", "application/json") - req.Header.Set("PRIVATE-TOKEN", p.cfg.serviceAccount.PrivateToken) + if accessToken != "" { + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", accessToken)) + } else { + req.Header.Set("PRIVATE-TOKEN", p.cfg.serviceAccount.PrivateToken) + } res, err := p.cfg.httpClient.Do(req) if err != nil { @@ -190,7 +232,7 @@ func (p *Provider) apiGet(ctx context.Context, uri string, out interface{}) (htt defer res.Body.Close() if res.StatusCode/100 != 2 { - return nil, fmt.Errorf("gitlab: error query api status_code=%d: %s", res.StatusCode, res.Status) + return nil, fmt.Errorf("gitlab: error querying api url=%s status_code=%d: %s", uri, res.StatusCode, res.Status) } err = json.NewDecoder(res.Body).Decode(out) diff --git a/internal/directory/google/google.go b/internal/directory/google/google.go index b47eba6f6..156605fce 100644 --- a/internal/directory/google/google.go +++ b/internal/directory/google/google.go @@ -5,13 +5,16 @@ import ( "context" "encoding/base64" "encoding/json" + "errors" "fmt" + "net/http" "sort" "sync" "github.com/rs/zerolog" "golang.org/x/oauth2/google" admin "google.golang.org/api/admin/directory/v1" + "google.golang.org/api/googleapi" "google.golang.org/api/option" "github.com/pomerium/pomerium/internal/log" @@ -19,11 +22,15 @@ import ( "github.com/pomerium/pomerium/pkg/grpc/directory" ) -// Name is the provider name. -const Name = "google" +const ( + // Name is the provider name. + Name = "google" + + currentAccountCustomerID = "my_customer" +) const ( - defaultProviderURL = "https://accounts.google.com" + defaultProviderURL = "https://www.googleapis.com/admin/directory/v1/" ) type config struct { @@ -78,6 +85,51 @@ func New(options ...Option) *Provider { } } +// User returns the user record for the given id. +func (p *Provider) User(ctx context.Context, userID, accessToken string) (*directory.User, error) { + apiClient, err := p.getAPIClient(ctx) + if err != nil { + return nil, fmt.Errorf("google: error getting API client: %w", err) + } + + _, providerUserID := databroker.FromUserID(userID) + du := &directory.User{ + Id: userID, + } + + au, err := apiClient.Users.Get(providerUserID). + Context(ctx). + Do() + if isAccessDenied(err) { + // ignore forbidden errors as a user may login from another gsuite domain + return du, nil + } else if err != nil { + return nil, fmt.Errorf("google: error getting user: %w", err) + } else { + if au.Name != nil { + du.DisplayName = au.Name.FullName + } + du.Email = au.PrimaryEmail + } + + err = apiClient.Groups.List(). + Context(ctx). + UserKey(providerUserID). + Pages(ctx, func(res *admin.Groups) error { + for _, g := range res.Groups { + du.GroupIds = append(du.GroupIds, g.Id) + } + return nil + }) + if err != nil { + return nil, fmt.Errorf("google: error getting groups for user: %w", err) + } + + sort.Strings(du.GroupIds) + + return du, nil +} + // UserGroups returns a slice of group names a given user is in // NOTE: groups via Directory API is limited to 1 QPS! // https://developers.google.com/admin-sdk/directory/v1/reference/groups/list @@ -91,7 +143,7 @@ func (p *Provider) UserGroups(ctx context.Context) ([]*directory.Group, []*direc var groups []*directory.Group err = apiClient.Groups.List(). Context(ctx). - Customer("my_customer"). + Customer(currentAccountCustomerID). Pages(ctx, func(res *admin.Groups) error { for _, g := range res.Groups { // Skip group without member. @@ -113,7 +165,7 @@ func (p *Provider) UserGroups(ctx context.Context) ([]*directory.Group, []*direc userLookup := map[string]apiUserObject{} err = apiClient.Users.List(). Context(ctx). - Customer("my_customer"). + Customer(currentAccountCustomerID). Pages(ctx, func(res *admin.Users) error { for _, u := range res.Users { userLookup[u.Id] = apiUserObject{ @@ -188,7 +240,7 @@ func (p *Provider) getAPIClient(ctx context.Context) (*admin.Service, error) { ts := config.TokenSource(ctx) - p.apiClient, err = admin.NewService(ctx, option.WithTokenSource(ts)) + p.apiClient, err = admin.NewService(ctx, option.WithTokenSource(ts), option.WithEndpoint(p.cfg.url)) if err != nil { return nil, fmt.Errorf("google: failed creating admin service %w", err) } @@ -243,3 +295,16 @@ type apiUserObject struct { DisplayName string Email string } + +func isAccessDenied(err error) bool { + if err == nil { + return false + } + + gerr := new(googleapi.Error) + if errors.As(err, &gerr) { + return gerr.Code == http.StatusForbidden + } + + return false +} diff --git a/internal/directory/google/google_test.go b/internal/directory/google/google_test.go new file mode 100644 index 000000000..be2a8a48d --- /dev/null +++ b/internal/directory/google/google_test.go @@ -0,0 +1,140 @@ +package google + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/go-chi/chi" + "github.com/go-chi/chi/middleware" + "github.com/stretchr/testify/assert" +) + +var privateKey = ` +-----BEGIN RSA PRIVATE KEY----- +MIIG4wIBAAKCAYEAnetGqPqS6dqYnV9S5S8gL34t7RRUMsf4prxIR+1PMv+bEqVH +pzXBbzgGKIbQbu+njoRzhQ95RbcEzZiVFivNggijkFoUNkFIjy42O/xXdTRTX3/u +4pxu1ctccqYnZwnry6E8ekQTVX7kgmVqgzIrY1Y6K3PVlhkgGDK/TStu+RIPoA73 +vJUpJTFTw+6tgUSBmCkzctQsGFGUiGBRpqlxogEkImJYrcMUJkWhTopLy79+3OB5 +eGGKWwA6P2ZhBB4RKHBWyWipsYxr889QNG6P3o+Be6lIzvcdiIIBYK8qWLmO45hR +xUGWSRK8sveAJO54t+wGE0dSPVKfoS4oNqGYdGhBzoTwsfZRvyvcWeQob4DNCqQa +n41XAuYGOG3X1PexdwGSwwqrq2tuG9d2AJ2NjG8nC9hjuuKDfBGTigTwzrwkrn+F +3o94NoglQgsXfZeWoBXR5HDaDTdqexSRK0OpSPbvzkn8QDUymdaw7nRS7kU2O7fa +W8kxiV8AVt2v/jYjAgMBAAECggGAS+6NE0Mo0Pki2Mi0+y4ls7BgNNbJhYFRthpi +RvN8WXE+B0EhquzWDbxKecIZBr6FOqnFQf2muja+QH1VckutjRDKVOZ7QXsygGYf +/cff5aM7U3gYTS4avQIDeb0axRioIElu4vtIsJtLFMfe5yaAZktXvPz9fiamn/wG +r/xqZ6ifir6nsC2okxGczWE+XCGsjpWA/321lhvj548os5JV6SfTUBUpvqNGVQC2 +ByXIPDffsCTfQ1rjQ85gM4vuqiQqKn/KXRrMR1kRIOrglMJ6dllitsadkfWdPkVg +fHjM1KAnw/uob5kFhvqeEll9sESfCXttrb4XneKAsEck958ChSpU4csaBAfLFVYP +5xyIfoaQ+/CUjWI0u1lQbg6jZO59rYfdd5OlH+MyhHybuHR1a0G1izNfuG9WPOWI +aprNayH2Wxy9/ZvlrE5yTAeW9tof28hO6O7wBNOcJTrzztsN+V8pSAo0IE2r4D83 +h978LneAwhC/8mVvzhd/y2t99vcBAoHBAMumCoHmHRAsYBApg2SHxPCTDi4fS1hC +IbcuzyvJLv214bHOsdgg2I72a1Q+bbrZAHgENVgSP9Sx1k8aXlvnGOEbpL8w3jRL +G4/qXzGrMBp3LCzupi3dI3yrrIMWb2C0goyHeAejzrfaM+uDYTGW4iqhA39zBj4o +zoydz3v0i8Yag7Df9MIwr34WD9Ng0oXh8XRCAYJmS1e43jnM+XcFdSfGVhKn9h1B +Cbv/hqUSv6baNloWLlPBffLII5bx633MMwKBwQDGg87fKEPz73pbp1kAOuXDpUnm ++OUFuf6Uqr/OqCorh8SOwxb2q3ScqyCWGVLOaVBcSV+zxIdcQzX2kjf1mj7PcQ6Q +2xfDIS1/xT+RiX9LO0kbkVDYcwcGeKVtmUwWyjauo96OB2r+SchTsNJpYOT3a/7r +JUKdbHFwsFwAx5q7r9mOh0BOybuXM6N7lUDBf4SgrhjnKRh1pME3R0JbJj9m8tZg +SsWlHcj04yAXJ7NGemiiYgeDZ4unsAfx7/sS/lECgcEAsL0Yj2XTQU8Ry9ULaDsA +az1k6BhWvnEea6lfOQPwGVY5WqQk6oqPB3vK6CEKAEgGRSJ53UZxSTlR4fLjg2UL +zYm9MATMQ5wPfpYMKcIFDGLy3sf7RwCNpMwk+tuEq+vdBPMo85BxflQMDVBHEM9+ +1zpIG9sKxvWJVLY89LnmeHZYZi/nboTsOUQSVgPIkVLmx1vljXMT3jzd+FHxCx+c +bnmOB8DnMrpYJWV9SFP+KmNlGkf3ys65bPPPF1g7ZUDLAoHAaKHqtQa1Imr0NED1 +kUB6AHArjrlbhXQuck+5f4R1jbIm8RR1Exj2AunT6Cl60t8Bg1MNRWRt8DxgwhD5 +u9NMDezKP6GrWacwIytlQSGW3aFm/EfQs/WVG10V3LmzOEPnJI+s63GPfG6JT0tg +7DgtFxhuKaTfAri45iueoq6SqSCb7Brv01dTL/QA1E+r7RF4Z3S8HYM0qDVpvegq +Wn7DZlDSm7htioUzeZgJPwsm3BwC8Kv4x9MY8g6/cU8LKEyxAoHAWCaDpLIuQ51r +PeL+u/1cfNdi6OOrtZ6S95tu3Vv+mYzpCPnOpgPHFp3l+RGmLg56t7uvHFaFxOvB +EjPm4bVhnPSA7pl7ZHQXhinG9+4UgcejoCAJzfg05BI1tMbwFZ+C0tG/PNzBlaX+ +IwkGO8VP/54N6wL1UqfZ8AKJFZW8G7W7KVkjqye1FS4oeDlJ197t/X+PMn5sFAc7 +UVsDaSelBqpsfmetXSH8KC3XkbgCtHvgAnJDkGkp84VmJvMr5ukv +-----END RSA PRIVATE KEY----- +` + +type M = map[string]interface{} + +func newMockAPI(t *testing.T, srv *httptest.Server) http.Handler { + r := chi.NewRouter() + r.Use(middleware.Logger) + r.Post("/token", func(w http.ResponseWriter, r *http.Request) { + _ = json.NewEncoder(w).Encode(M{ + "access_token": "ACCESSTOKEN", + "token_type": "Bearer", + "refresh_token": "REFRESHTOKEN", + }) + }) + r.Route("/groups", func(r chi.Router) { + r.Get("/", func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Query().Get("userKey") { + case "user1": + _ = json.NewEncoder(w).Encode(M{ + "kind": "admin#directory#groups", + "groups": []M{ + {"id": "group1"}, + {"id": "group2"}, + }, + }) + default: + http.Error(w, "not found", http.StatusNotFound) + } + }) + }) + r.Route("/users", func(r chi.Router) { + r.Get("/{user_id}", func(w http.ResponseWriter, r *http.Request) { + switch chi.URLParam(r, "user_id") { + case "user1": + _ = json.NewEncoder(w).Encode(M{ + "kind": "admin#directory#user", + "id": "1", + "name": M{ + "fullName": "User 1", + }, + "primaryEmail": "user1@example.com", + }) + case "user2": + http.Error(w, "forbidden", http.StatusForbidden) + default: + http.Error(w, "not found", http.StatusNotFound) + } + }) + }) + + return r +} + +func TestProvider_User(t *testing.T) { + ctx, clearTimeout := context.WithTimeout(context.Background(), time.Second*30) + defer clearTimeout() + + var mockAPI http.Handler + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + mockAPI.ServeHTTP(w, r) + })) + defer srv.Close() + mockAPI = newMockAPI(t, srv) + + p := New(WithServiceAccount(&ServiceAccount{ + Type: "service_account", + PrivateKey: privateKey, + TokenURL: srv.URL + "/token", + }), WithURL(srv.URL)) + + du, err := p.User(ctx, "user1", "") + if !assert.NoError(t, err) { + return + } + assert.Equal(t, "user1", du.Id) + assert.Equal(t, "user1@example.com", du.Email) + assert.Equal(t, "User 1", du.DisplayName) + assert.Equal(t, []string{"group1", "group2"}, du.GroupIds) + + du, err = p.User(ctx, "user2", "") + if !assert.NoError(t, err) { + return + } + assert.Equal(t, "user2", du.Id) +} diff --git a/internal/directory/okta/okta.go b/internal/directory/okta/okta.go index 641e9f3ab..e78f158fa 100644 --- a/internal/directory/okta/okta.go +++ b/internal/directory/okta/okta.go @@ -108,6 +108,36 @@ func New(options ...Option) *Provider { } } +// User returns the user record for the given id. +func (p *Provider) User(ctx context.Context, userID, accessToken string) (*directory.User, error) { + if p.cfg.serviceAccount == nil { + return nil, ErrServiceAccountNotDefined + } + + _, providerUserID := databroker.FromUserID(userID) + du := &directory.User{ + Id: userID, + } + + au, err := p.getUser(ctx, providerUserID) + if err != nil { + return nil, err + } + du.DisplayName = au.getDisplayName() + du.Email = au.Profile.Email + + groups, err := p.listUserGroups(ctx, providerUserID) + if err != nil { + return nil, err + } + for _, g := range groups { + du.GroupIds = append(du.GroupIds, g.ID) + } + sort.Strings(du.GroupIds) + + return du, nil +} + // UserGroups fetches the groups of which the user is a member // https://developer.okta.com/docs/reference/api/users/#get-user-s-groups func (p *Provider) UserGroups(ctx context.Context) ([]*directory.Group, []*directory.User, error) { @@ -159,7 +189,7 @@ func (p *Provider) UserGroups(ctx context.Context) ([]*directory.Group, []*direc users = append(users, &directory.User{ Id: databroker.GetUserID(Name, u.ID), GroupIds: groups, - DisplayName: u.Profile.FirstName + " " + u.Profile.LastName, + DisplayName: u.getDisplayName(), Email: u.Profile.Email, }) } @@ -183,14 +213,7 @@ func (p *Provider) getGroups(ctx context.Context) ([]*directory.Group, error) { groupURL := p.cfg.providerURL.ResolveReference(u).String() for groupURL != "" { - var out []struct { - ID string `json:"id"` - Profile struct { - Name string `json:"name"` - } `json:"profile"` - LastUpdated string `json:"lastUpdated"` - LastMembershipUpdated string `json:"lastMembershipUpdated"` - } + var out []apiGroupObject hdrs, err := p.apiGet(ctx, groupURL, &out) if err != nil { return nil, fmt.Errorf("okta: error querying for groups: %w", err) @@ -239,6 +262,36 @@ func (p *Provider) getGroupMembers(ctx context.Context, groupID string) (users [ return users, nil } +func (p *Provider) getUser(ctx context.Context, userID string) (*apiUserObject, error) { + apiURL := p.cfg.providerURL.ResolveReference(&url.URL{ + Path: fmt.Sprintf("/api/v1/users/%s", userID), + }).String() + + var out apiUserObject + _, err := p.apiGet(ctx, apiURL, &out) + if err != nil { + return nil, fmt.Errorf("okta: error querying for user: %w", err) + } + + return &out, nil +} + +func (p *Provider) listUserGroups(ctx context.Context, userID string) (groups []apiGroupObject, err error) { + apiURL := p.cfg.providerURL.ResolveReference(&url.URL{ + Path: fmt.Sprintf("/api/v1/users/%s/groups", userID), + }).String() + for apiURL != "" { + var out []apiGroupObject + hdrs, err := p.apiGet(ctx, apiURL, &out) + if err != nil { + return nil, fmt.Errorf("okta: error querying for user groups: %w", err) + } + groups = append(groups, out...) + apiURL = getNextLink(hdrs) + } + return groups, nil +} + func (p *Provider) apiGet(ctx context.Context, uri string, out interface{}) (http.Header, error) { req, err := http.NewRequestWithContext(ctx, "GET", uri, nil) if err != nil { @@ -334,11 +387,25 @@ func (err *APIError) Error() string { return fmt.Sprintf("okta: error querying API, status_code=%d: %s", err.HTTPStatusCode, err.Body) } -type apiUserObject struct { - ID string `json:"id"` - Profile struct { - FirstName string `json:"firstName"` - LastName string `json:"lastName"` - Email string `json:"email"` - } `json:"profile"` +type ( + apiGroupObject struct { + ID string `json:"id"` + Profile struct { + Name string `json:"name"` + } `json:"profile"` + LastUpdated string `json:"lastUpdated"` + LastMembershipUpdated string `json:"lastMembershipUpdated"` + } + apiUserObject struct { + ID string `json:"id"` + Profile struct { + FirstName string `json:"firstName"` + LastName string `json:"lastName"` + Email string `json:"email"` + } `json:"profile"` + } +) + +func (obj *apiUserObject) getDisplayName() string { + return obj.Profile.FirstName + " " + obj.Profile.LastName } diff --git a/internal/directory/okta/okta_test.go b/internal/directory/okta/okta_test.go index 3fa6f6351..fc945738b 100644 --- a/internal/directory/okta/okta_test.go +++ b/internal/directory/okta/okta_test.go @@ -16,6 +16,7 @@ import ( "github.com/stretchr/testify/require" "github.com/tomnomnom/linkheader" + "github.com/pomerium/pomerium/internal/testutil" "github.com/pomerium/pomerium/pkg/grpc/directory" ) @@ -43,88 +44,142 @@ func newMockOkta(srv *httptest.Server, userEmailToGroups map[string][]string) ht next.ServeHTTP(w, r) }) }) - r.Get("/api/v1/groups", func(w http.ResponseWriter, r *http.Request) { - lastUpdated := strings.Contains(r.URL.Query().Get("filter"), "lastUpdated ") - var groups []string - for group := range getAllGroups() { - if lastUpdated && group != "user-updated" { - continue - } - if !lastUpdated && group == "user-updated" { - continue - } - groups = append(groups, group) - } - sort.Strings(groups) - - var result []M - - found := r.URL.Query().Get("after") == "" - for i := range groups { - if found { - result = append(result, M{ - "id": groups[i], - "profile": M{ - "name": groups[i] + "-name", - }, - }) - break - } - found = r.URL.Query().Get("after") == groups[i] - } - - if len(result) > 0 { - nextURL := mustParseURL(srv.URL).ResolveReference(r.URL) - q := nextURL.Query() - q.Set("after", result[0]["id"].(string)) - nextURL.RawQuery = q.Encode() - w.Header().Set("Link", linkheader.Link{ - URL: nextURL.String(), - Rel: "next", - }.String()) - } - - _ = json.NewEncoder(w).Encode(result) - }) - r.Get("/api/v1/groups/{group}/users", func(w http.ResponseWriter, r *http.Request) { - group := chi.URLParam(r, "group") - - if _, ok := getAllGroups()[group]; !ok { - w.WriteHeader(http.StatusNotFound) - w.Write([]byte(`{ - "errorCode": "E0000007", - "errorSummary": "Not found: {0}", - "errorLink": E0000007, - "errorId": "sampleE7p0NECLNnSN5z_xLNT", - "errorCauses": [] - }`)) - return - } - - var result []M - for email, groups := range userEmailToGroups { - for _, g := range groups { - if group == g { - result = append(result, M{ - "id": email, - "profile": M{ - "email": email, - "firstName": "first", - "lastName": "last", - }, - }) + r.Route("/api/v1", func(r chi.Router) { + r.Route("/groups", func(r chi.Router) { + r.Get("/", func(w http.ResponseWriter, r *http.Request) { + lastUpdated := strings.Contains(r.URL.Query().Get("filter"), "lastUpdated ") + var groups []string + for group := range getAllGroups() { + if lastUpdated && group != "user-updated" { + continue + } + if !lastUpdated && group == "user-updated" { + continue + } + groups = append(groups, group) } - } - } - sort.Slice(result, func(i, j int) bool { - return result[i]["id"].(string) < result[j]["id"].(string) - }) + sort.Strings(groups) - _ = json.NewEncoder(w).Encode(result) + var result []M + + found := r.URL.Query().Get("after") == "" + for i := range groups { + if found { + result = append(result, M{ + "id": groups[i], + "profile": M{ + "name": groups[i] + "-name", + }, + }) + break + } + found = r.URL.Query().Get("after") == groups[i] + } + + if len(result) > 0 { + nextURL := mustParseURL(srv.URL).ResolveReference(r.URL) + q := nextURL.Query() + q.Set("after", result[0]["id"].(string)) + nextURL.RawQuery = q.Encode() + w.Header().Set("Link", linkheader.Link{ + URL: nextURL.String(), + Rel: "next", + }.String()) + } + + _ = json.NewEncoder(w).Encode(result) + }) + r.Get("/{group}/users", func(w http.ResponseWriter, r *http.Request) { + group := chi.URLParam(r, "group") + + if _, ok := getAllGroups()[group]; !ok { + w.WriteHeader(http.StatusNotFound) + w.Write([]byte(`{ + "errorCode": "E0000007", + "errorSummary": "Not found: {0}", + "errorLink": E0000007, + "errorId": "sampleE7p0NECLNnSN5z_xLNT", + "errorCauses": [] + }`)) + return + } + + var result []M + for email, groups := range userEmailToGroups { + for _, g := range groups { + if group == g { + result = append(result, M{ + "id": email, + "profile": M{ + "email": email, + "firstName": "first", + "lastName": "last", + }, + }) + } + } + } + sort.Slice(result, func(i, j int) bool { + return result[i]["id"].(string) < result[j]["id"].(string) + }) + + _ = json.NewEncoder(w).Encode(result) + }) + }) + r.Route("/users", func(r chi.Router) { + r.Get("/{user_id}/groups", func(w http.ResponseWriter, r *http.Request) { + var groups []apiGroupObject + for _, nm := range userEmailToGroups[chi.URLParam(r, "user_id")] { + obj := apiGroupObject{ + ID: nm, + } + obj.Profile.Name = nm + groups = append(groups, obj) + } + _ = json.NewEncoder(w).Encode(groups) + }) + r.Get("/{user_id}", func(w http.ResponseWriter, r *http.Request) { + user := apiUserObject{ + ID: chi.URLParam(r, "user_id"), + } + user.Profile.Email = chi.URLParam(r, "user_id") + user.Profile.FirstName = "first" + user.Profile.LastName = "last" + _ = json.NewEncoder(w).Encode(user) + }) + }) }) return r } +func TestProvider_User(t *testing.T) { + var mockOkta http.Handler + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + mockOkta.ServeHTTP(w, r) + })) + defer srv.Close() + mockOkta = newMockOkta(srv, map[string][]string{ + "a@example.com": {"user", "admin"}, + "b@example.com": {"user", "test"}, + "c@example.com": {"user"}, + }) + + p := New( + WithServiceAccount(&ServiceAccount{APIKey: "APITOKEN"}), + WithProviderURL(mustParseURL(srv.URL)), + ) + user, err := p.User(context.Background(), "okta/a@example.com", "") + if !assert.NoError(t, err) { + return + } + testutil.AssertProtoJSONEqual(t, `{ + "id": "okta/a@example.com", + "groupIds": ["admin","user"], + "displayName": "first last", + "email": "a@example.com" + }`, user) +} + func TestProvider_UserGroups(t *testing.T) { var mockOkta http.Handler srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { diff --git a/internal/directory/onelogin/onelogin.go b/internal/directory/onelogin/onelogin.go index f436f8fd7..9e058ee23 100644 --- a/internal/directory/onelogin/onelogin.go +++ b/internal/directory/onelogin/onelogin.go @@ -94,6 +94,32 @@ func New(options ...Option) *Provider { } } +// User returns the user record for the given id. +func (p *Provider) User(ctx context.Context, userID, accessToken string) (*directory.User, error) { + if p.cfg.serviceAccount == nil { + return nil, fmt.Errorf("onelogin: service account not defined") + } + _, providerUserID := databroker.FromUserID(userID) + du := &directory.User{ + Id: userID, + } + + token, err := p.getToken(ctx) + if err != nil { + return nil, err + } + + au, err := p.getUser(ctx, token.AccessToken, providerUserID) + if err != nil { + return nil, err + } + du.DisplayName = au.getDisplayName() + du.Email = au.Email + du.GroupIds = []string{strconv.Itoa(au.GroupID)} + + return du, nil +} + // UserGroups gets the directory user groups for onelogin. func (p *Provider) UserGroups(ctx context.Context) ([]*directory.Group, []*directory.User, error) { if p.cfg.serviceAccount == nil { @@ -107,12 +133,12 @@ func (p *Provider) UserGroups(ctx context.Context) ([]*directory.Group, []*direc return nil, nil, err } - groups, err := p.listGroups(ctx, token) + groups, err := p.listGroups(ctx, token.AccessToken) if err != nil { return nil, nil, err } - apiUsers, err := p.getUsers(ctx, token) + apiUsers, err := p.listUsers(ctx, token.AccessToken) if err != nil { return nil, nil, err } @@ -133,7 +159,7 @@ func (p *Provider) UserGroups(ctx context.Context) ([]*directory.Group, []*direc return groups, users, nil } -func (p *Provider) listGroups(ctx context.Context, token *oauth2.Token) ([]*directory.Group, error) { +func (p *Provider) listGroups(ctx context.Context, accessToken string) ([]*directory.Group, error) { var groups []*directory.Group apiURL := p.cfg.apiURL.ResolveReference(&url.URL{ Path: "/api/1/groups", @@ -144,9 +170,9 @@ func (p *Provider) listGroups(ctx context.Context, token *oauth2.Token) ([]*dire ID int `json:"id"` Name string `json:"name"` } - nextLink, err := p.apiGet(ctx, token, apiURL, &result) + nextLink, err := p.apiGet(ctx, accessToken, apiURL, &result) if err != nil { - return nil, fmt.Errorf("onelogin: error querying group api: %w", err) + return nil, fmt.Errorf("onelogin: listing groups: %w", err) } for _, r := range result { @@ -161,7 +187,24 @@ func (p *Provider) listGroups(ctx context.Context, token *oauth2.Token) ([]*dire return groups, nil } -func (p *Provider) getUsers(ctx context.Context, token *oauth2.Token) ([]apiUserObject, error) { +func (p *Provider) getUser(ctx context.Context, accessToken string, userID string) (*apiUserObject, error) { + apiURL := p.cfg.apiURL.ResolveReference(&url.URL{ + Path: fmt.Sprintf("/api/1/users/%s", userID), + }).String() + + var out []apiUserObject + _, err := p.apiGet(ctx, accessToken, apiURL, &out) + if err != nil { + return nil, fmt.Errorf("onelogin: error getting user: %w", err) + } + if len(out) == 0 { + return nil, fmt.Errorf("onelogin: user not found") + } + + return &out[0], nil +} + +func (p *Provider) listUsers(ctx context.Context, accessToken string) ([]apiUserObject, error) { var users []apiUserObject apiURL := p.cfg.apiURL.ResolveReference(&url.URL{ @@ -170,9 +213,9 @@ func (p *Provider) getUsers(ctx context.Context, token *oauth2.Token) ([]apiUser }).String() for apiURL != "" { var result []apiUserObject - nextLink, err := p.apiGet(ctx, token, apiURL, &result) + nextLink, err := p.apiGet(ctx, accessToken, apiURL, &result) if err != nil { - return nil, err + return nil, fmt.Errorf("onelogin: listing users: %w", err) } users = append(users, result...) @@ -182,12 +225,12 @@ func (p *Provider) getUsers(ctx context.Context, token *oauth2.Token) ([]apiUser return users, nil } -func (p *Provider) apiGet(ctx context.Context, token *oauth2.Token, uri string, out interface{}) (nextLink string, err error) { +func (p *Provider) apiGet(ctx context.Context, accessToken string, uri string, out interface{}) (nextLink string, err error) { req, err := http.NewRequestWithContext(ctx, "GET", uri, nil) if err != nil { return "", err } - req.Header.Set("Authorization", fmt.Sprintf("bearer:%s", token.AccessToken)) + req.Header.Set("Authorization", fmt.Sprintf("bearer:%s", accessToken)) req.Header.Set("Content-Type", "application/json") res, err := p.cfg.httpClient.Do(req) @@ -307,3 +350,7 @@ type apiUserObject struct { FirstName string `json:"firstname"` LastName string `json:"lastname"` } + +func (obj *apiUserObject) getDisplayName() string { + return obj.FirstName + " " + obj.LastName +} diff --git a/internal/directory/onelogin/onelogin_test.go b/internal/directory/onelogin/onelogin_test.go index cc21bb706..e128128b1 100644 --- a/internal/directory/onelogin/onelogin_test.go +++ b/internal/directory/onelogin/onelogin_test.go @@ -8,6 +8,7 @@ import ( "net/http/httptest" "net/url" "sort" + "strconv" "testing" "time" @@ -102,6 +103,28 @@ func newMockAPI(srv *httptest.Server, userIDToGroupName map[int]string) http.Han _ = json.NewEncoder(w).Encode(result) }) + r.Get("/users/{user_id}", func(w http.ResponseWriter, r *http.Request) { + userIDToGroupID := map[int]int{} + for userID, groupName := range userIDToGroupName { + for id, n := range allGroups { + if groupName == n { + userIDToGroupID[userID] = id + } + } + } + + userID, _ := strconv.Atoi(chi.URLParam(r, "user_id")) + + _ = json.NewEncoder(w).Encode(M{ + "data": []M{{ + "id": userID, + "email": userIDToGroupName[userID] + "@example.com", + "group_id": userIDToGroupID[userID], + "firstname": "User", + "lastname": fmt.Sprint(userID), + }}, + }) + }) r.Get("/users", func(w http.ResponseWriter, r *http.Request) { userIDToGroupID := map[int]int{} for userID, groupName := range userIDToGroupName { @@ -130,6 +153,37 @@ func newMockAPI(srv *httptest.Server, userIDToGroupName map[int]string) http.Han return r } +func TestProvider_User(t *testing.T) { + var mockAPI http.Handler + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + mockAPI.ServeHTTP(w, r) + })) + defer srv.Close() + mockAPI = newMockAPI(srv, map[int]string{ + 111: "admin", + 222: "test", + 333: "user", + }) + + p := New( + WithServiceAccount(&ServiceAccount{ + ClientID: "CLIENTID", + ClientSecret: "CLIENTSECRET", + }), + WithURL(mustParseURL(srv.URL)), + ) + user, err := p.User(context.Background(), "onelogin/111", "ACCESSTOKEN") + if !assert.NoError(t, err) { + return + } + testutil.AssertProtoJSONEqual(t, `{ + "id": "onelogin/111", + "groupIds": ["0"], + "displayName": "User 111", + "email": "admin@example.com" + }`, user) +} + func TestProvider_UserGroups(t *testing.T) { var mockAPI http.Handler srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { diff --git a/internal/directory/provider.go b/internal/directory/provider.go index 91f0a53b7..5a81a953e 100644 --- a/internal/directory/provider.go +++ b/internal/directory/provider.go @@ -28,8 +28,12 @@ type User = directory.User // Options are the options specific to the provider. type Options = directory.Options +// RegisterDirectoryServiceServer registers the directory gRPC service. +var RegisterDirectoryServiceServer = directory.RegisterDirectoryServiceServer + // A Provider provides user group directory information. type Provider interface { + User(ctx context.Context, userID, accessToken string) (*User, error) UserGroups(ctx context.Context) ([]*Group, []*User, error) } @@ -55,7 +59,7 @@ func GetProvider(options Options) (provider Provider) { switch options.Provider { case auth0.Name: - serviceAccount, err := auth0.ParseServiceAccount(options.ServiceAccount) + serviceAccount, err := auth0.ParseServiceAccount(options) if err == nil { return auth0.New( auth0.WithDomain(options.ProviderURL), @@ -139,6 +143,10 @@ func GetProvider(options Options) (provider Provider) { type nullProvider struct{} +func (nullProvider) User(ctx context.Context, userID, accessToken string) (*directory.User, error) { + return nil, nil +} + func (nullProvider) UserGroups(ctx context.Context) ([]*Group, []*User, error) { return nil, nil, nil } diff --git a/pkg/grpc/databroker/databroker.go b/pkg/grpc/databroker/databroker.go index c50e26006..a4b4f9fc0 100644 --- a/pkg/grpc/databroker/databroker.go +++ b/pkg/grpc/databroker/databroker.go @@ -1,11 +1,22 @@ // Package databroker contains databroker protobuf definitions. package databroker +import "strings" + // GetUserID gets the databroker user id from a provider user id. func GetUserID(provider, providerUserID string) string { return provider + "/" + providerUserID } +// FromUserID gets the provider and provider user id from a databroker user id. +func FromUserID(userID string) (provider, providerUserID string) { + ps := strings.SplitN(userID, "/", 2) + if len(ps) < 2 { + return "", userID + } + return ps[0], ps[1] +} + // ApplyOffsetAndLimit applies the offset and limit to the list of records. func ApplyOffsetAndLimit(all []*Record, offset, limit int) (records []*Record, totalCount int) { records = all diff --git a/pkg/grpc/directory/directory.pb.go b/pkg/grpc/directory/directory.pb.go index 8ff082dbb..f30f00fca 100644 --- a/pkg/grpc/directory/directory.pb.go +++ b/pkg/grpc/directory/directory.pb.go @@ -7,7 +7,12 @@ package directory import ( + context "context" proto "github.com/golang/protobuf/proto" + empty "github.com/golang/protobuf/ptypes/empty" + grpc "google.golang.org/grpc" + codes "google.golang.org/grpc/codes" + status "google.golang.org/grpc/status" protoreflect "google.golang.org/protobuf/reflect/protoreflect" protoimpl "google.golang.org/protobuf/runtime/protoimpl" reflect "reflect" @@ -175,29 +180,97 @@ func (x *Group) GetEmail() string { return "" } +type RefreshUserRequest struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + UserId string `protobuf:"bytes,1,opt,name=user_id,json=userId,proto3" json:"user_id,omitempty"` + AccessToken string `protobuf:"bytes,2,opt,name=access_token,json=accessToken,proto3" json:"access_token,omitempty"` +} + +func (x *RefreshUserRequest) Reset() { + *x = RefreshUserRequest{} + if protoimpl.UnsafeEnabled { + mi := &file_directory_proto_msgTypes[2] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *RefreshUserRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*RefreshUserRequest) ProtoMessage() {} + +func (x *RefreshUserRequest) ProtoReflect() protoreflect.Message { + mi := &file_directory_proto_msgTypes[2] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use RefreshUserRequest.ProtoReflect.Descriptor instead. +func (*RefreshUserRequest) Descriptor() ([]byte, []int) { + return file_directory_proto_rawDescGZIP(), []int{2} +} + +func (x *RefreshUserRequest) GetUserId() string { + if x != nil { + return x.UserId + } + return "" +} + +func (x *RefreshUserRequest) GetAccessToken() string { + if x != nil { + return x.AccessToken + } + return "" +} + var File_directory_proto protoreflect.FileDescriptor var file_directory_proto_rawDesc = []byte{ 0x0a, 0x0f, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, 0x6f, 0x72, 0x79, 0x2e, 0x70, 0x72, 0x6f, 0x74, - 0x6f, 0x12, 0x09, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, 0x6f, 0x72, 0x79, 0x22, 0x86, 0x01, 0x0a, - 0x04, 0x55, 0x73, 0x65, 0x72, 0x12, 0x18, 0x0a, 0x07, 0x76, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, - 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x76, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x12, - 0x0e, 0x0a, 0x02, 0x69, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x02, 0x69, 0x64, 0x12, - 0x1b, 0x0a, 0x09, 0x67, 0x72, 0x6f, 0x75, 0x70, 0x5f, 0x69, 0x64, 0x73, 0x18, 0x03, 0x20, 0x03, - 0x28, 0x09, 0x52, 0x08, 0x67, 0x72, 0x6f, 0x75, 0x70, 0x49, 0x64, 0x73, 0x12, 0x21, 0x0a, 0x0c, - 0x64, 0x69, 0x73, 0x70, 0x6c, 0x61, 0x79, 0x5f, 0x6e, 0x61, 0x6d, 0x65, 0x18, 0x04, 0x20, 0x01, - 0x28, 0x09, 0x52, 0x0b, 0x64, 0x69, 0x73, 0x70, 0x6c, 0x61, 0x79, 0x4e, 0x61, 0x6d, 0x65, 0x12, - 0x14, 0x0a, 0x05, 0x65, 0x6d, 0x61, 0x69, 0x6c, 0x18, 0x05, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, - 0x65, 0x6d, 0x61, 0x69, 0x6c, 0x22, 0x5b, 0x0a, 0x05, 0x47, 0x72, 0x6f, 0x75, 0x70, 0x12, 0x18, - 0x0a, 0x07, 0x76, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, - 0x07, 0x76, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x12, 0x0e, 0x0a, 0x02, 0x69, 0x64, 0x18, 0x02, - 0x20, 0x01, 0x28, 0x09, 0x52, 0x02, 0x69, 0x64, 0x12, 0x12, 0x0a, 0x04, 0x6e, 0x61, 0x6d, 0x65, - 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x6e, 0x61, 0x6d, 0x65, 0x12, 0x14, 0x0a, 0x05, - 0x65, 0x6d, 0x61, 0x69, 0x6c, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x65, 0x6d, 0x61, - 0x69, 0x6c, 0x42, 0x31, 0x5a, 0x2f, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, - 0x2f, 0x70, 0x6f, 0x6d, 0x65, 0x72, 0x69, 0x75, 0x6d, 0x2f, 0x70, 0x6f, 0x6d, 0x65, 0x72, 0x69, - 0x75, 0x6d, 0x2f, 0x70, 0x6b, 0x67, 0x2f, 0x67, 0x72, 0x70, 0x63, 0x2f, 0x64, 0x69, 0x72, 0x65, - 0x63, 0x74, 0x6f, 0x72, 0x79, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, + 0x6f, 0x12, 0x09, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, 0x6f, 0x72, 0x79, 0x1a, 0x1b, 0x67, 0x6f, + 0x6f, 0x67, 0x6c, 0x65, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2f, 0x65, 0x6d, + 0x70, 0x74, 0x79, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x22, 0x86, 0x01, 0x0a, 0x04, 0x55, 0x73, + 0x65, 0x72, 0x12, 0x18, 0x0a, 0x07, 0x76, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x18, 0x01, 0x20, + 0x01, 0x28, 0x09, 0x52, 0x07, 0x76, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x12, 0x0e, 0x0a, 0x02, + 0x69, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x02, 0x69, 0x64, 0x12, 0x1b, 0x0a, 0x09, + 0x67, 0x72, 0x6f, 0x75, 0x70, 0x5f, 0x69, 0x64, 0x73, 0x18, 0x03, 0x20, 0x03, 0x28, 0x09, 0x52, + 0x08, 0x67, 0x72, 0x6f, 0x75, 0x70, 0x49, 0x64, 0x73, 0x12, 0x21, 0x0a, 0x0c, 0x64, 0x69, 0x73, + 0x70, 0x6c, 0x61, 0x79, 0x5f, 0x6e, 0x61, 0x6d, 0x65, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, + 0x0b, 0x64, 0x69, 0x73, 0x70, 0x6c, 0x61, 0x79, 0x4e, 0x61, 0x6d, 0x65, 0x12, 0x14, 0x0a, 0x05, + 0x65, 0x6d, 0x61, 0x69, 0x6c, 0x18, 0x05, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x65, 0x6d, 0x61, + 0x69, 0x6c, 0x22, 0x5b, 0x0a, 0x05, 0x47, 0x72, 0x6f, 0x75, 0x70, 0x12, 0x18, 0x0a, 0x07, 0x76, + 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x76, 0x65, + 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x12, 0x0e, 0x0a, 0x02, 0x69, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, + 0x09, 0x52, 0x02, 0x69, 0x64, 0x12, 0x12, 0x0a, 0x04, 0x6e, 0x61, 0x6d, 0x65, 0x18, 0x03, 0x20, + 0x01, 0x28, 0x09, 0x52, 0x04, 0x6e, 0x61, 0x6d, 0x65, 0x12, 0x14, 0x0a, 0x05, 0x65, 0x6d, 0x61, + 0x69, 0x6c, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x65, 0x6d, 0x61, 0x69, 0x6c, 0x22, + 0x50, 0x0a, 0x12, 0x52, 0x65, 0x66, 0x72, 0x65, 0x73, 0x68, 0x55, 0x73, 0x65, 0x72, 0x52, 0x65, + 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x17, 0x0a, 0x07, 0x75, 0x73, 0x65, 0x72, 0x5f, 0x69, 0x64, + 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x75, 0x73, 0x65, 0x72, 0x49, 0x64, 0x12, 0x21, + 0x0a, 0x0c, 0x61, 0x63, 0x63, 0x65, 0x73, 0x73, 0x5f, 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x18, 0x02, + 0x20, 0x01, 0x28, 0x09, 0x52, 0x0b, 0x61, 0x63, 0x63, 0x65, 0x73, 0x73, 0x54, 0x6f, 0x6b, 0x65, + 0x6e, 0x32, 0x58, 0x0a, 0x10, 0x44, 0x69, 0x72, 0x65, 0x63, 0x74, 0x6f, 0x72, 0x79, 0x53, 0x65, + 0x72, 0x76, 0x69, 0x63, 0x65, 0x12, 0x44, 0x0a, 0x0b, 0x52, 0x65, 0x66, 0x72, 0x65, 0x73, 0x68, + 0x55, 0x73, 0x65, 0x72, 0x12, 0x1d, 0x2e, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, 0x6f, 0x72, 0x79, + 0x2e, 0x52, 0x65, 0x66, 0x72, 0x65, 0x73, 0x68, 0x55, 0x73, 0x65, 0x72, 0x52, 0x65, 0x71, 0x75, + 0x65, 0x73, 0x74, 0x1a, 0x16, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, + 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x42, 0x31, 0x5a, 0x2f, 0x67, + 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x70, 0x6f, 0x6d, 0x65, 0x72, 0x69, + 0x75, 0x6d, 0x2f, 0x70, 0x6f, 0x6d, 0x65, 0x72, 0x69, 0x75, 0x6d, 0x2f, 0x70, 0x6b, 0x67, 0x2f, + 0x67, 0x72, 0x70, 0x63, 0x2f, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, 0x6f, 0x72, 0x79, 0x62, 0x06, + 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, } var ( @@ -212,14 +285,18 @@ func file_directory_proto_rawDescGZIP() []byte { return file_directory_proto_rawDescData } -var file_directory_proto_msgTypes = make([]protoimpl.MessageInfo, 2) +var file_directory_proto_msgTypes = make([]protoimpl.MessageInfo, 3) var file_directory_proto_goTypes = []interface{}{ - (*User)(nil), // 0: directory.User - (*Group)(nil), // 1: directory.Group + (*User)(nil), // 0: directory.User + (*Group)(nil), // 1: directory.Group + (*RefreshUserRequest)(nil), // 2: directory.RefreshUserRequest + (*empty.Empty)(nil), // 3: google.protobuf.Empty } var file_directory_proto_depIdxs = []int32{ - 0, // [0:0] is the sub-list for method output_type - 0, // [0:0] is the sub-list for method input_type + 2, // 0: directory.DirectoryService.RefreshUser:input_type -> directory.RefreshUserRequest + 3, // 1: directory.DirectoryService.RefreshUser:output_type -> google.protobuf.Empty + 1, // [1:2] is the sub-list for method output_type + 0, // [0:1] is the sub-list for method input_type 0, // [0:0] is the sub-list for extension type_name 0, // [0:0] is the sub-list for extension extendee 0, // [0:0] is the sub-list for field type_name @@ -255,6 +332,18 @@ func file_directory_proto_init() { return nil } } + file_directory_proto_msgTypes[2].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*RefreshUserRequest); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } } type x struct{} out := protoimpl.TypeBuilder{ @@ -262,9 +351,9 @@ func file_directory_proto_init() { GoPackagePath: reflect.TypeOf(x{}).PkgPath(), RawDescriptor: file_directory_proto_rawDesc, NumEnums: 0, - NumMessages: 2, + NumMessages: 3, NumExtensions: 0, - NumServices: 0, + NumServices: 1, }, GoTypes: file_directory_proto_goTypes, DependencyIndexes: file_directory_proto_depIdxs, @@ -275,3 +364,83 @@ func file_directory_proto_init() { file_directory_proto_goTypes = nil file_directory_proto_depIdxs = nil } + +// Reference imports to suppress errors if they are not otherwise used. +var _ context.Context +var _ grpc.ClientConnInterface + +// This is a compile-time assertion to ensure that this generated file +// is compatible with the grpc package it is being compiled against. +const _ = grpc.SupportPackageIsVersion6 + +// DirectoryServiceClient is the client API for DirectoryService service. +// +// For semantics around ctx use and closing/ending streaming RPCs, please refer to https://godoc.org/google.golang.org/grpc#ClientConn.NewStream. +type DirectoryServiceClient interface { + RefreshUser(ctx context.Context, in *RefreshUserRequest, opts ...grpc.CallOption) (*empty.Empty, error) +} + +type directoryServiceClient struct { + cc grpc.ClientConnInterface +} + +func NewDirectoryServiceClient(cc grpc.ClientConnInterface) DirectoryServiceClient { + return &directoryServiceClient{cc} +} + +func (c *directoryServiceClient) RefreshUser(ctx context.Context, in *RefreshUserRequest, opts ...grpc.CallOption) (*empty.Empty, error) { + out := new(empty.Empty) + err := c.cc.Invoke(ctx, "/directory.DirectoryService/RefreshUser", in, out, opts...) + if err != nil { + return nil, err + } + return out, nil +} + +// DirectoryServiceServer is the server API for DirectoryService service. +type DirectoryServiceServer interface { + RefreshUser(context.Context, *RefreshUserRequest) (*empty.Empty, error) +} + +// UnimplementedDirectoryServiceServer can be embedded to have forward compatible implementations. +type UnimplementedDirectoryServiceServer struct { +} + +func (*UnimplementedDirectoryServiceServer) RefreshUser(context.Context, *RefreshUserRequest) (*empty.Empty, error) { + return nil, status.Errorf(codes.Unimplemented, "method RefreshUser not implemented") +} + +func RegisterDirectoryServiceServer(s *grpc.Server, srv DirectoryServiceServer) { + s.RegisterService(&_DirectoryService_serviceDesc, srv) +} + +func _DirectoryService_RefreshUser_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(RefreshUserRequest) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(DirectoryServiceServer).RefreshUser(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: "/directory.DirectoryService/RefreshUser", + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(DirectoryServiceServer).RefreshUser(ctx, req.(*RefreshUserRequest)) + } + return interceptor(ctx, in, info, handler) +} + +var _DirectoryService_serviceDesc = grpc.ServiceDesc{ + ServiceName: "directory.DirectoryService", + HandlerType: (*DirectoryServiceServer)(nil), + Methods: []grpc.MethodDesc{ + { + MethodName: "RefreshUser", + Handler: _DirectoryService_RefreshUser_Handler, + }, + }, + Streams: []grpc.StreamDesc{}, + Metadata: "directory.proto", +} diff --git a/pkg/grpc/directory/directory.proto b/pkg/grpc/directory/directory.proto index 51c182bf0..1de12a90d 100644 --- a/pkg/grpc/directory/directory.proto +++ b/pkg/grpc/directory/directory.proto @@ -3,6 +3,8 @@ syntax = "proto3"; package directory; option go_package = "github.com/pomerium/pomerium/pkg/grpc/directory"; +import "google/protobuf/empty.proto"; + message User { string version = 1; string id = 2; @@ -17,3 +19,12 @@ message Group { string name = 3; string email = 4; } + +message RefreshUserRequest { + string user_id = 1; + string access_token = 2; +} + +service DirectoryService { + rpc RefreshUser(RefreshUserRequest) returns (google.protobuf.Empty); +}