From ed6c3e50871478c8fdade4ab23af56ab3048eae3 Mon Sep 17 00:00:00 2001 From: Caleb Doxsey Date: Fri, 21 Jan 2022 09:36:32 -0700 Subject: [PATCH] google: support groups for users outside of the organization (#2950) * google: support groups for users outside of the organization * wrap error --- databroker/directory.go | 19 +++++++- internal/directory/directoryerrors/errors.go | 8 ++++ internal/directory/google/google.go | 50 +++++++++++++------- internal/directory/google/google_test.go | 47 +++++++++++------- internal/directory/provider.go | 3 +- 5 files changed, 89 insertions(+), 38 deletions(-) create mode 100644 internal/directory/directoryerrors/errors.go diff --git a/databroker/directory.go b/databroker/directory.go index 9565115b1..12dfba109 100644 --- a/databroker/directory.go +++ b/databroker/directory.go @@ -3,9 +3,13 @@ package databroker import ( "context" "errors" + "fmt" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" "google.golang.org/protobuf/types/known/emptypb" + "github.com/pomerium/pomerium/internal/directory/directoryerrors" "github.com/pomerium/pomerium/pkg/grpc/databroker" "github.com/pomerium/pomerium/pkg/grpc/directory" "github.com/pomerium/pomerium/pkg/protoutil" @@ -22,7 +26,20 @@ func (c *DataBroker) RefreshUser(ctx context.Context, req *directory.RefreshUser } u, err := dp.User(ctx, req.GetUserId(), req.GetAccessToken()) - if err != nil { + // if the returned error signals we should prefer existing information + if errors.Is(err, directoryerrors.ErrPreferExistingInformation) { + _, err = c.dataBrokerServer.Get(ctx, &databroker.GetRequest{ + Type: protoutil.GetTypeURL(new(directory.User)), + Id: req.GetUserId(), + }) + switch status.Code(err) { + case codes.OK: + return new(emptypb.Empty), nil + case codes.NotFound: // go ahead and save the user that was returned + default: + return nil, fmt.Errorf("databroker: error retrieving existing user record for refresh: %w", err) + } + } else if err != nil { return nil, err } diff --git a/internal/directory/directoryerrors/errors.go b/internal/directory/directoryerrors/errors.go new file mode 100644 index 000000000..84a7275f7 --- /dev/null +++ b/internal/directory/directoryerrors/errors.go @@ -0,0 +1,8 @@ +// Package directoryerrors contains errors used by directory providers. +package directoryerrors + +import "errors" + +// ErrPreferExistingInformation indicates that the information returned by the provider should +// only be used if a record is brand new, otherwise the existing information should be kept as is. +var ErrPreferExistingInformation = errors.New("user ignored") diff --git a/internal/directory/google/google.go b/internal/directory/google/google.go index e414fde70..e1351490c 100644 --- a/internal/directory/google/google.go +++ b/internal/directory/google/google.go @@ -17,6 +17,7 @@ import ( "google.golang.org/api/googleapi" "google.golang.org/api/option" + "github.com/pomerium/pomerium/internal/directory/directoryerrors" "github.com/pomerium/pomerium/internal/log" "github.com/pomerium/pomerium/pkg/grpc/directory" ) @@ -100,7 +101,7 @@ func (p *Provider) User(ctx context.Context, userID, accessToken string) (*direc Do() if isAccessDenied(err) { // ignore forbidden errors as a user may login from another gsuite domain - return du, nil + return du, directoryerrors.ErrPreferExistingInformation } else if err != nil { return nil, fmt.Errorf("google: error getting user: %w", err) } else { @@ -138,6 +139,7 @@ func (p *Provider) UserGroups(ctx context.Context) ([]*directory.Group, []*direc return nil, nil, fmt.Errorf("google: error getting API client: %w", err) } + // query all the groups var groups []*directory.Group err = apiClient.Groups.List(). Context(ctx). @@ -160,7 +162,37 @@ func (p *Provider) UserGroups(ctx context.Context) ([]*directory.Group, []*direc return nil, nil, fmt.Errorf("google: error getting groups: %w", err) } + // query all the user members for each group + // - create a lookup table for the user (storing id and name) + // (this includes users who aren't necessarily members of the same organization) + // - create a lookup table for the user's groups userLookup := map[string]apiUserObject{} + userIDToGroups := map[string][]string{} + for _, group := range groups { + group := group + err = apiClient.Members.List(group.Id). + Context(ctx). + Pages(ctx, func(res *admin.Members) error { + for _, member := range res.Members { + // only include user objects + if member.Type != "USER" { + continue + } + + userLookup[member.Id] = apiUserObject{ + ID: member.Id, + Email: member.Email, + } + userIDToGroups[member.Id] = append(userIDToGroups[member.Id], group.Id) + } + return nil + }) + if err != nil { + return nil, nil, fmt.Errorf("google: error getting group members: %w", err) + } + } + + // query all the users in the organization err = apiClient.Users.List(). Context(ctx). Customer(currentAccountCustomerID). @@ -181,22 +213,6 @@ func (p *Provider) UserGroups(ctx context.Context) ([]*directory.Group, []*direc return nil, nil, fmt.Errorf("google: error getting users: %w", err) } - userIDToGroups := map[string][]string{} - for _, group := range groups { - group := group - err = apiClient.Members.List(group.Id). - Context(ctx). - Pages(ctx, func(res *admin.Members) error { - for _, member := range res.Members { - userIDToGroups[member.Id] = append(userIDToGroups[member.Id], group.Id) - } - return nil - }) - if err != nil { - return nil, nil, fmt.Errorf("google: error getting group members: %w", err) - } - } - var users []*directory.User for _, u := range userLookup { groups := userIDToGroups[u.ID] diff --git a/internal/directory/google/google_test.go b/internal/directory/google/google_test.go index 87b25a492..d31fdc778 100644 --- a/internal/directory/google/google_test.go +++ b/internal/directory/google/google_test.go @@ -12,6 +12,8 @@ import ( "github.com/go-chi/chi/v5/middleware" "github.com/stretchr/testify/assert" + "github.com/pomerium/pomerium/internal/directory/directoryerrors" + "github.com/pomerium/pomerium/internal/testutil" "github.com/pomerium/pomerium/pkg/grpc/directory" ) @@ -85,7 +87,7 @@ func newMockAPI(t *testing.T, srv *httptest.Server) http.Handler { _ = json.NewEncoder(w).Encode(M{ "kind": "admin#directory#groups", "groups": []M{ - {"id": "group1", "directMembersCount": "1"}, + {"id": "group1", "directMembersCount": "2"}, {"id": "group2"}, }, }) @@ -97,8 +99,16 @@ func newMockAPI(t *testing.T, srv *httptest.Server) http.Handler { _ = json.NewEncoder(w).Encode(M{ "members": []M{ { - "kind": "admin#directory#member", - "id": "user1", + "kind": "admin#directory#member", + "id": "inside-user1", + "email": "user1@inside.test", + "type": "USER", + }, + { + "kind": "admin#directory#member", + "id": "outside-user1", + "email": "user1@outside.test", + "type": "USER", }, }, }) @@ -112,24 +122,24 @@ func newMockAPI(t *testing.T, srv *httptest.Server) http.Handler { "users": []M{ { "kind": "admin#directory#user", - "id": "user1", - "primaryEmail": "user1@example.com", + "id": "inside-user1", + "primaryEmail": "user1@inside.test", }, }, }) }) r.Get("/{user_id}", func(w http.ResponseWriter, r *http.Request) { switch chi.URLParam(r, "user_id") { - case "user1": + case "inside-user1": _ = json.NewEncoder(w).Encode(M{ "kind": "admin#directory#user", - "id": "user1", + "id": "inside-user1", "name": M{ "fullName": "User 1", }, - "primaryEmail": "user1@example.com", + "primaryEmail": "user1@inside.test", }) - case "user2": + case "outside-user1": http.Error(w, "forbidden", http.StatusForbidden) default: http.Error(w, "not found", http.StatusNotFound) @@ -158,20 +168,20 @@ func TestProvider_User(t *testing.T) { TokenURL: srv.URL + "/token", }), WithURL(srv.URL)) - du, err := p.User(ctx, "user1", "") + du, err := p.User(ctx, "inside-user1", "") if !assert.NoError(t, err) { return } - assert.Equal(t, "user1", du.Id) - assert.Equal(t, "user1@example.com", du.Email) + assert.Equal(t, "inside-user1", du.Id) + assert.Equal(t, "user1@inside.test", du.Email) assert.Equal(t, "User 1", du.DisplayName) assert.Equal(t, []string{"group1", "group2"}, du.GroupIds) - du, err = p.User(ctx, "user2", "") - if !assert.NoError(t, err) { + du, err = p.User(ctx, "outside-user1", "") + if assert.ErrorIs(t, err, directoryerrors.ErrPreferExistingInformation) { return } - assert.Equal(t, "user2", du.Id) + assert.Equal(t, "outside-user1", du.Id) } func TestProvider_UserGroups(t *testing.T) { @@ -199,7 +209,8 @@ func TestProvider_UserGroups(t *testing.T) { assert.Equal(t, []*directory.Group{ {Id: "group1"}, }, dgs) - assert.Equal(t, []*directory.User{ - {Id: "user1", Email: "user1@example.com", GroupIds: []string{"group1"}}, - }, dus) + testutil.AssertProtoJSONEqual(t, `[ + { "id": "inside-user1", "email": "user1@inside.test", "groupIds": ["group1"] }, + { "id": "outside-user1", "email": "user1@outside.test", "groupIds": ["group1"] } + ]`, dus) } diff --git a/internal/directory/provider.go b/internal/directory/provider.go index f52e89a82..24c618464 100644 --- a/internal/directory/provider.go +++ b/internal/directory/provider.go @@ -9,8 +9,6 @@ import ( "github.com/google/go-cmp/cmp" - "github.com/pomerium/pomerium/internal/directory/ping" - "github.com/pomerium/pomerium/internal/directory/auth0" "github.com/pomerium/pomerium/internal/directory/azure" "github.com/pomerium/pomerium/internal/directory/github" @@ -18,6 +16,7 @@ import ( "github.com/pomerium/pomerium/internal/directory/google" "github.com/pomerium/pomerium/internal/directory/okta" "github.com/pomerium/pomerium/internal/directory/onelogin" + "github.com/pomerium/pomerium/internal/directory/ping" "github.com/pomerium/pomerium/internal/log" "github.com/pomerium/pomerium/pkg/grpc/directory" )