diff --git a/internal/directory/azure/azure.go b/internal/directory/azure/azure.go index 655bded56..5568e2a04 100644 --- a/internal/directory/azure/azure.go +++ b/internal/directory/azure/azure.go @@ -111,20 +111,28 @@ func (p *Provider) UserGroups(ctx context.Context) ([]*directory.Group, []*direc return nil, nil, err } + userLookup := map[string]apiDirectoryObject{} groupLookup := newGroupLookup() for _, group := range groups { - groupIDs, userIDs, err := p.listGroupMembers(ctx, group.Id) + groupIDs, users, err := p.listGroupMembers(ctx, group.Id) if err != nil { return nil, nil, err } + userIDs := make([]string, 0, len(users)) + for _, u := range users { + userIDs = append(userIDs, u.ID) + userLookup[u.ID] = u + } groupLookup.addGroup(group.Id, groupIDs, userIDs) } - var users []*directory.User - for _, userID := range groupLookup.getUserIDs() { + users := make([]*directory.User, 0, len(userLookup)) + for _, u := range userLookup { users = append(users, &directory.User{ - Id: databroker.GetUserID(Name, userID), - GroupIds: groupLookup.getGroupIDsForUser(userID), + Id: databroker.GetUserID(Name, u.ID), + GroupIds: groupLookup.getGroupIDsForUser(u.ID), + DisplayName: u.DisplayName, + Email: u.getEmail(), }) } sort.Slice(users, func(i, j int) bool { @@ -164,18 +172,15 @@ func (p *Provider) listGroups(ctx context.Context) ([]*directory.Group, error) { return groups, nil } -func (p *Provider) listGroupMembers(ctx context.Context, groupID string) (groupIDs, userIDs []string, err error) { +func (p *Provider) listGroupMembers(ctx context.Context, groupID string) (groupIDs []string, users []apiDirectoryObject, err error) { nextURL := p.cfg.graphURL.ResolveReference(&url.URL{ Path: fmt.Sprintf("/v1.0/groups/%s/members", groupID), }).String() for nextURL != "" { var result struct { - Value []struct { - Type string `json:"@odata.type"` - ID string `json:"id"` - } `json:"value"` - NextLink string `json:"@odata.nextLink"` + Value []apiDirectoryObject `json:"value"` + NextLink string `json:"@odata.nextLink"` } err := p.api(ctx, "GET", nextURL, nil, &result) if err != nil { @@ -186,13 +191,13 @@ func (p *Provider) listGroupMembers(ctx context.Context, groupID string) (groupI case "#microsoft.graph.group": groupIDs = append(groupIDs, v.ID) case "#microsoft.graph.user": - userIDs = append(userIDs, v.ID) + users = append(users, v) } } nextURL = result.NextLink } - return groupIDs, userIDs, nil + return groupIDs, users, nil } func (p *Provider) api(ctx context.Context, method, url string, body io.Reader, out interface{}) error { @@ -354,3 +359,31 @@ func parseDirectoryIDFromURL(providerURL string) (string, error) { return pathParts[1], nil } + +type apiDirectoryObject struct { + Type string `json:"@odata.type"` + ID string `json:"id"` + Mail string `json:"mail"` + DisplayName string `json:"displayName"` + UserPrincipalName string `json:"userPrincipalName"` +} + +func (obj apiDirectoryObject) getEmail() string { + if obj.Mail != "" { + return obj.Mail + } + + // AD often doesn't have the email address returned, but we can parse it from the UPN + + // UPN looks like: + // cdoxsey_pomerium.com#EXT#@cdoxseypomerium.onmicrosoft.com + email := obj.UserPrincipalName + if idx := strings.Index(email, "#EXT"); idx > 0 { + email = email[:idx] + } + // find the last _ and replace it with @ + if idx := strings.LastIndex(email, "_"); idx > 0 { + email = email[:idx] + "@" + email[idx+1:] + } + return email +} diff --git a/internal/directory/azure/azure_test.go b/internal/directory/azure/azure_test.go index 068608fbb..b2bec3097 100644 --- a/internal/directory/azure/azure_test.go +++ b/internal/directory/azure/azure_test.go @@ -38,6 +38,7 @@ func newMockAPI(t *testing.T, srv *httptest.Server) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.Header.Get("Authorization") != "Bearer ACCESSTOKEN" { http.Error(w, "forbidden", http.StatusForbidden) + return } next.ServeHTTP(w, r) @@ -54,11 +55,11 @@ func newMockAPI(t *testing.T, srv *httptest.Server) http.Handler { r.Get("/groups/{group_name}/members", func(w http.ResponseWriter, r *http.Request) { members := map[string][]M{ "admin": { - {"@odata.type": "#microsoft.graph.user", "id": "user-1"}, + {"@odata.type": "#microsoft.graph.user", "id": "user-1", "displayName": "User 1", "mail": "user1@example.com"}, }, "test": { - {"@odata.type": "#microsoft.graph.user", "id": "user-2"}, - {"@odata.type": "#microsoft.graph.user", "id": "user-3"}, + {"@odata.type": "#microsoft.graph.user", "id": "user-2", "displayName": "User 2", "mail": "user2@example.com"}, + {"@odata.type": "#microsoft.graph.user", "id": "user-3", "displayName": "User 3", "userPrincipalName": "user3_example.com#EXT#@user3example.onmicrosoft.com"}, }, } _ = json.NewEncoder(w).Encode(M{ @@ -66,6 +67,7 @@ func newMockAPI(t *testing.T, srv *httptest.Server) http.Handler { }) }) }) + return r } @@ -90,16 +92,22 @@ func Test(t *testing.T) { assert.NoError(t, err) assert.Equal(t, []*directory.User{ { - Id: "azure/user-1", - GroupIds: []string{"admin"}, + Id: "azure/user-1", + GroupIds: []string{"admin"}, + DisplayName: "User 1", + Email: "user1@example.com", }, { - Id: "azure/user-2", - GroupIds: []string{"test"}, + Id: "azure/user-2", + GroupIds: []string{"test"}, + DisplayName: "User 2", + Email: "user2@example.com", }, { - Id: "azure/user-3", - GroupIds: []string{"test"}, + Id: "azure/user-3", + GroupIds: []string{"test"}, + DisplayName: "User 3", + Email: "user3@example.com", }, }, users) assert.Equal(t, []*directory.Group{ diff --git a/internal/directory/github/github.go b/internal/directory/github/github.go index 93bcffb7f..7db5e2978 100644 --- a/internal/directory/github/github.go +++ b/internal/directory/github/github.go @@ -2,12 +2,10 @@ package github import ( - "bytes" "context" "encoding/base64" "encoding/json" "fmt" - "io" "net/http" "net/url" "sort" @@ -121,9 +119,16 @@ func (p *Provider) UserGroups(ctx context.Context) ([]*directory.Group, []*direc var users []*directory.User for userLogin, groups := range userLoginToGroups { + u, err := p.getUser(ctx, userLogin) + if err != nil { + return nil, nil, err + } + user := &directory.User{ - Id: databroker.GetUserID(Name, userLogin), - GroupIds: groups, + Id: databroker.GetUserID(Name, userLogin), + GroupIds: groups, + DisplayName: u.Name, + Email: u.Email, } sort.Strings(user.GroupIds) users = append(users, user) @@ -143,7 +148,7 @@ func (p *Provider) listOrgs(ctx context.Context) (orgSlugs []string, err error) var results []struct { Login string `json:"login"` } - hdrs, err := p.api(ctx, "GET", nextURL, nil, &results) + hdrs, err := p.api(ctx, nextURL, &results) if err != nil { return nil, err } @@ -169,7 +174,7 @@ func (p *Provider) listGroups(ctx context.Context, orgSlug string) ([]*directory ID int `json:"id"` Slug string `json:"slug"` } - hdrs, err := p.api(ctx, "GET", nextURL, nil, &results) + hdrs, err := p.api(ctx, nextURL, &results) if err != nil { return nil, err } @@ -196,7 +201,7 @@ func (p *Provider) listTeamMembers(ctx context.Context, orgSlug, teamSlug string var results []struct { Login string `json:"login"` } - hdrs, err := p.api(ctx, "GET", nextURL, nil, &results) + hdrs, err := p.api(ctx, nextURL, &results) if err != nil { return nil, err } @@ -211,16 +216,22 @@ func (p *Provider) listTeamMembers(ctx context.Context, orgSlug, teamSlug string return userLogins, err } -func (p *Provider) api(ctx context.Context, method string, apiURL string, in, out interface{}) (http.Header, error) { - var body io.Reader - if in != nil { - bs, err := json.Marshal(in) - if err != nil { - return nil, fmt.Errorf("github: failed to marshal api input") - } - body = bytes.NewReader(bs) +func (p *Provider) getUser(ctx context.Context, userLogin string) (*apiUserObject, error) { + apiURL := p.cfg.url.ResolveReference(&url.URL{ + Path: fmt.Sprintf("/users/%s", userLogin), + }).String() + + var res apiUserObject + _, err := p.api(ctx, apiURL, &res) + if err != nil { + return nil, err } - req, err := http.NewRequestWithContext(ctx, method, apiURL, body) + + return &res, nil +} + +func (p *Provider) api(ctx context.Context, apiURL string, out interface{}) (http.Header, error) { + req, err := http.NewRequestWithContext(ctx, "GET", apiURL, nil) if err != nil { return nil, fmt.Errorf("github: failed to create http request: %w", err) } @@ -283,3 +294,10 @@ func ParseServiceAccount(rawServiceAccount string) (*ServiceAccount, error) { return &serviceAccount, nil } + +// see: https://docs.github.com/en/free-pro-team@latest/rest/reference/users#get-a-user +type apiUserObject struct { + Login string `json:"login"` + Name string `json:"name"` + Email string `json:"email"` +} diff --git a/internal/directory/github/github_test.go b/internal/directory/github/github_test.go index b7a125f4d..a9e52f878 100644 --- a/internal/directory/github/github_test.go +++ b/internal/directory/github/github_test.go @@ -75,6 +75,16 @@ func newMockAPI(t *testing.T, srv *httptest.Server) http.Handler { teamID := chi.URLParam(r, "team_id") json.NewEncoder(w).Encode(members[orgID][teamID]) }) + r.Get("/users/{user_id}", func(w http.ResponseWriter, r *http.Request) { + users := map[string]apiUserObject{ + "user1": {Login: "user1", Name: "User 1", Email: "user1@example.com"}, + "user2": {Login: "user2", Name: "User 2", Email: "user2@example.com"}, + "user3": {Login: "user3", Name: "User 3", Email: "user3@example.com"}, + "user4": {Login: "user4", Name: "User 4", Email: "user4@example.com"}, + } + userID := chi.URLParam(r, "user_id") + json.NewEncoder(w).Encode(users[userID]) + }) return r } @@ -96,10 +106,10 @@ func Test(t *testing.T) { groups, users, err := p.UserGroups(context.Background()) assert.NoError(t, err) testutil.AssertProtoJSONEqual(t, `[ - { "id": "github/user1", "groupIds": ["1", "2", "3"] }, - { "id": "github/user2", "groupIds": ["1", "3"] }, - { "id": "github/user3", "groupIds": ["3"] }, - { "id": "github/user4", "groupIds": ["4"] } + { "id": "github/user1", "groupIds": ["1", "2", "3"], "displayName": "User 1", "email": "user1@example.com" }, + { "id": "github/user2", "groupIds": ["1", "3"], "displayName": "User 2", "email": "user2@example.com" }, + { "id": "github/user3", "groupIds": ["3"], "displayName": "User 3", "email": "user3@example.com" }, + { "id": "github/user4", "groupIds": ["4"], "displayName": "User 4", "email": "user4@example.com" } ]`, users) testutil.AssertProtoJSONEqual(t, `[ { "id": "1", "name": "team1" }, diff --git a/internal/directory/gitlab/gitlab.go b/internal/directory/gitlab/gitlab.go index a922061c6..6ba5a0cd2 100644 --- a/internal/directory/gitlab/gitlab.go +++ b/internal/directory/gitlab/gitlab.go @@ -96,25 +96,29 @@ func (p *Provider) UserGroups(ctx context.Context) ([]*directory.Group, []*direc return nil, nil, err } + userLookup := map[int]apiUserObject{} userIDToGroupIDs := map[int][]string{} for _, group := range groups { - userIDs, err := p.listGroupMemberIDs(ctx, group.Id) + users, err := p.listGroupMembers(ctx, group.Id) if err != nil { return nil, nil, err } - for _, userID := range userIDs { - userIDToGroupIDs[userID] = append(userIDToGroupIDs[userID], group.Id) + for _, u := range users { + userIDToGroupIDs[u.ID] = append(userIDToGroupIDs[u.ID], group.Id) + userLookup[u.ID] = u } } var users []*directory.User - for userID, groups := range userIDToGroupIDs { + for _, u := range userLookup { user := &directory.User{ - Id: databroker.GetUserID(Name, fmt.Sprint(userID)), + Id: databroker.GetUserID(Name, fmt.Sprint(u.ID)), + DisplayName: u.Name, + Email: u.Email, } - user.GroupIds = append(user.GroupIds, groups...) + user.GroupIds = append(user.GroupIds, userIDToGroupIDs[u.ID]...) sort.Strings(user.GroupIds) users = append(users, user) @@ -153,26 +157,21 @@ func (p *Provider) listGroups(ctx context.Context) ([]*directory.Group, error) { return groups, nil } -func (p *Provider) listGroupMemberIDs(ctx context.Context, groupID string) (userIDs []int, err error) { +func (p *Provider) listGroupMembers(ctx context.Context, groupID string) (users []apiUserObject, err error) { nextURL := p.cfg.url.ResolveReference(&url.URL{ Path: fmt.Sprintf("/api/v4/groups/%s/members", groupID), }).String() for nextURL != "" { - var result []struct { - ID int `json:"id"` - } + var result []apiUserObject hdrs, err := p.apiGet(ctx, nextURL, &result) if err != nil { return nil, fmt.Errorf("gitlab: error querying group members: %w", err) } - for _, r := range result { - userIDs = append(userIDs, r.ID) - } - + users = append(users, result...) nextURL = getNextLink(hdrs) } - return userIDs, nil + return users, nil } func (p *Provider) apiGet(ctx context.Context, uri string, out interface{}) (http.Header, error) { @@ -235,3 +234,9 @@ func ParseServiceAccount(rawServiceAccount string) (*ServiceAccount, error) { return &serviceAccount, nil } + +type apiUserObject struct { + ID int `json:"id"` + Name string `json:"name"` + Email string `json:"email"` +} diff --git a/internal/directory/gitlab/gitlab_test.go b/internal/directory/gitlab/gitlab_test.go index bb064aa93..5a41bea7f 100644 --- a/internal/directory/gitlab/gitlab_test.go +++ b/internal/directory/gitlab/gitlab_test.go @@ -39,11 +39,11 @@ func newMockAPI(t *testing.T, srv *httptest.Server) http.Handler { r.Get("/groups/{group_name}/members", func(w http.ResponseWriter, r *http.Request) { members := map[string][]M{ "1": { - {"id": 11}, + {"id": 11, "name": "User 1", "email": "user1@example.com"}, }, "2": { - {"id": 12}, - {"id": 13}, + {"id": 12, "name": "User 2", "email": "user2@example.com"}, + {"id": 13, "name": "User 3", "email": "user3@example.com"}, }, } _ = json.NewEncoder(w).Encode(members[chi.URLParam(r, "group_name")]) @@ -69,9 +69,9 @@ func Test(t *testing.T) { groups, users, err := p.UserGroups(context.Background()) assert.NoError(t, err) testutil.AssertProtoJSONEqual(t, `[ - { "id": "gitlab/11", "groupIds": ["1"] }, - { "id": "gitlab/12", "groupIds": ["2"] }, - { "id": "gitlab/13", "groupIds": ["2"] } + { "id": "gitlab/11", "groupIds": ["1"], "displayName": "User 1", "email": "user1@example.com" }, + { "id": "gitlab/12", "groupIds": ["2"], "displayName": "User 2", "email": "user2@example.com" }, + { "id": "gitlab/13", "groupIds": ["2"], "displayName": "User 3", "email": "user3@example.com" } ]`, users) testutil.AssertProtoJSONEqual(t, `[ { "id": "1", "name": "Group 1" }, diff --git a/internal/directory/google/google.go b/internal/directory/google/google.go index b33b831d9..b47eba6f6 100644 --- a/internal/directory/google/google.go +++ b/internal/directory/google/google.go @@ -110,6 +110,24 @@ func (p *Provider) UserGroups(ctx context.Context) ([]*directory.Group, []*direc return nil, nil, fmt.Errorf("google: error getting groups: %w", err) } + userLookup := map[string]apiUserObject{} + err = apiClient.Users.List(). + Context(ctx). + Customer("my_customer"). + Pages(ctx, func(res *admin.Users) error { + for _, u := range res.Users { + userLookup[u.Id] = apiUserObject{ + ID: u.Id, + DisplayName: u.Name.FullName, + Email: u.PrimaryEmail, + } + } + return nil + }) + if err != nil { + return nil, nil, fmt.Errorf("google: error getting users: %w", err) + } + userIDToGroups := map[string][]string{} for _, group := range groups { group := group @@ -127,11 +145,14 @@ func (p *Provider) UserGroups(ctx context.Context) ([]*directory.Group, []*direc } var users []*directory.User - for userID, groups := range userIDToGroups { + for _, u := range userLookup { + groups := userIDToGroups[u.ID] sort.Strings(groups) users = append(users, &directory.User{ - Id: databroker.GetUserID(Name, userID), - GroupIds: groups, + Id: databroker.GetUserID(Name, u.ID), + GroupIds: groups, + DisplayName: u.DisplayName, + Email: u.Email, }) } sort.Slice(users, func(i, j int) bool { @@ -216,3 +237,9 @@ func ParseServiceAccount(rawServiceAccount string) (*ServiceAccount, error) { return &serviceAccount, nil } + +type apiUserObject struct { + ID string + DisplayName string + Email string +} diff --git a/internal/directory/okta/okta.go b/internal/directory/okta/okta.go index 5ae05ef70..641e9f3ab 100644 --- a/internal/directory/okta/okta.go +++ b/internal/directory/okta/okta.go @@ -26,8 +26,21 @@ import ( // Name is the provider name. const Name = "okta" -// Okta use ISO-8601, see https://developer.okta.com/docs/reference/api-overview/#media-types -const filterDateFormat = "2006-01-02T15:04:05.999Z" +const ( + // Okta use ISO-8601, see https://developer.okta.com/docs/reference/api-overview/#media-types + filterDateFormat = "2006-01-02T15:04:05.999Z" + + batchSize = 200 + readLimit = 100 * 1024 + httpSuccessClass = 2 +) + +// Errors. +var ( + ErrAPIKeyRequired = errors.New("okta: api_key is required") + ErrServiceAccountNotDefined = errors.New("okta: service account not defined") + ErrProviderURLNotDefined = errors.New("okta: provider url not defined") +) type config struct { batchSize int @@ -69,11 +82,12 @@ func WithServiceAccount(serviceAccount *ServiceAccount) Option { func getConfig(options ...Option) *config { cfg := new(config) - WithBatchSize(200)(cfg) + WithBatchSize(batchSize)(cfg) WithHTTPClient(http.DefaultClient)(cfg) for _, option := range options { option(cfg) } + return cfg } @@ -98,13 +112,13 @@ func New(options ...Option) *Provider { // https://developer.okta.com/docs/reference/api/users/#get-user-s-groups func (p *Provider) UserGroups(ctx context.Context) ([]*directory.Group, []*directory.User, error) { if p.cfg.serviceAccount == nil { - return nil, nil, fmt.Errorf("okta: service account not defined") + return nil, nil, ErrServiceAccountNotDefined } p.log.Info().Msg("getting user groups") if p.cfg.providerURL == nil { - return nil, nil, fmt.Errorf("okta: provider url not defined") + return nil, nil, ErrProviderURLNotDefined } groups, err := p.getGroups(ctx) @@ -112,10 +126,11 @@ func (p *Provider) UserGroups(ctx context.Context) ([]*directory.Group, []*direc return nil, nil, err } + userLookup := map[string]apiUserObject{} userIDToGroups := map[string][]string{} for i := 0; i < len(groups); i++ { group := groups[i] - ids, err := p.getGroupMemberIDs(ctx, group.Id) + users, err := p.getGroupMembers(ctx, group.Id) // if we get a 404 on the member query, it means the group doesn't exist, so we should remove it from // the cached lookup and the local groups list @@ -131,17 +146,21 @@ func (p *Provider) UserGroups(ctx context.Context) ([]*directory.Group, []*direc if err != nil { return nil, nil, err } - for _, id := range ids { - userIDToGroups[id] = append(userIDToGroups[id], group.Id) + for _, u := range users { + userIDToGroups[u.ID] = append(userIDToGroups[u.ID], group.Id) + userLookup[u.ID] = u } } var users []*directory.User - for userID, groups := range userIDToGroups { + for _, u := range userLookup { + groups := userIDToGroups[u.ID] sort.Strings(groups) users = append(users, &directory.User{ - Id: databroker.GetUserID(Name, userID), - GroupIds: groups, + Id: databroker.GetUserID(Name, u.ID), + GroupIds: groups, + DisplayName: u.Profile.FirstName + " " + u.Profile.LastName, + Email: u.Profile.Email, }) } sort.Slice(users, func(i, j int) bool { @@ -201,30 +220,23 @@ func (p *Provider) getGroups(ctx context.Context) ([]*directory.Group, error) { return groups, nil } -func (p *Provider) getGroupMemberIDs(ctx context.Context, groupID string) ([]string, error) { - var emails []string - +func (p *Provider) getGroupMembers(ctx context.Context, groupID string) (users []apiUserObject, err error) { usersURL := p.cfg.providerURL.ResolveReference(&url.URL{ Path: fmt.Sprintf("/api/v1/groups/%s/users", groupID), RawQuery: fmt.Sprintf("limit=%d", p.cfg.batchSize), }).String() for usersURL != "" { - var out []struct { - ID string `json:"id"` - } + var out []apiUserObject hdrs, err := p.apiGet(ctx, usersURL, &out) if err != nil { return nil, fmt.Errorf("okta: error querying for groups: %w", err) } - for _, el := range out { - emails = append(emails, el.ID) - } - + users = append(users, out...) usersURL = getNextLink(hdrs) } - return emails, nil + return users, nil } func (p *Provider) apiGet(ctx context.Context, uri string, out interface{}) (http.Header, error) { @@ -250,7 +262,7 @@ func (p *Provider) apiGet(ctx context.Context, uri string, out interface{}) (htt } continue } - if res.StatusCode/100 != 2 { + if res.StatusCode/100 != httpSuccessClass { return nil, newAPIError(res) } if err := json.NewDecoder(res.Body).Decode(out); err != nil { @@ -287,7 +299,7 @@ func ParseServiceAccount(rawServiceAccount string) (*ServiceAccount, error) { } if serviceAccount.APIKey == "" { - return nil, fmt.Errorf("api_key is required") + return nil, ErrAPIKeyRequired } return &serviceAccount, nil @@ -308,7 +320,7 @@ func newAPIError(res *http.Response) error { if res == nil { return nil } - buf, _ := ioutil.ReadAll(io.LimitReader(res.Body, 100*1024)) // limit to 100kb + buf, _ := ioutil.ReadAll(io.LimitReader(res.Body, readLimit)) // limit to 100kb err := &APIError{ HTTPStatusCode: res.StatusCode, @@ -321,3 +333,12 @@ func newAPIError(res *http.Response) error { func (err *APIError) Error() string { return fmt.Sprintf("okta: error querying API, status_code=%d: %s", err.HTTPStatusCode, err.Body) } + +type apiUserObject struct { + ID string `json:"id"` + Profile struct { + FirstName string `json:"firstName"` + LastName string `json:"lastName"` + Email string `json:"email"` + } `json:"profile"` +} diff --git a/internal/directory/okta/okta_test.go b/internal/directory/okta/okta_test.go index f2e4cc229..3fa6f6351 100644 --- a/internal/directory/okta/okta_test.go +++ b/internal/directory/okta/okta_test.go @@ -108,7 +108,9 @@ func newMockOkta(srv *httptest.Server, userEmailToGroups map[string][]string) ht result = append(result, M{ "id": email, "profile": M{ - "email": email, + "email": email, + "firstName": "first", + "lastName": "last", }, }) } @@ -143,16 +145,22 @@ func TestProvider_UserGroups(t *testing.T) { assert.NoError(t, err) assert.Equal(t, []*directory.User{ { - Id: "okta/a@example.com", - GroupIds: []string{"admin", "user"}, + Id: "okta/a@example.com", + GroupIds: []string{"admin", "user"}, + DisplayName: "first last", + Email: "a@example.com", }, { - Id: "okta/b@example.com", - GroupIds: []string{"test", "user"}, + Id: "okta/b@example.com", + GroupIds: []string{"test", "user"}, + DisplayName: "first last", + Email: "b@example.com", }, { - Id: "okta/c@example.com", - GroupIds: []string{"user"}, + Id: "okta/c@example.com", + GroupIds: []string{"user"}, + DisplayName: "first last", + Email: "c@example.com", }, }, users) assert.Len(t, groups, 3) @@ -180,16 +188,22 @@ func TestProvider_UserGroupsQueryUpdated(t *testing.T) { assert.NoError(t, err) assert.Equal(t, []*directory.User{ { - Id: "okta/a@example.com", - GroupIds: []string{"admin", "user"}, + Id: "okta/a@example.com", + GroupIds: []string{"admin", "user"}, + DisplayName: "first last", + Email: "a@example.com", }, { - Id: "okta/b@example.com", - GroupIds: []string{"test", "user"}, + Id: "okta/b@example.com", + GroupIds: []string{"test", "user"}, + DisplayName: "first last", + Email: "b@example.com", }, { - Id: "okta/c@example.com", - GroupIds: []string{"user"}, + Id: "okta/c@example.com", + GroupIds: []string{"user"}, + DisplayName: "first last", + Email: "c@example.com", }, }, users) assert.Len(t, groups, 3) @@ -198,20 +212,28 @@ func TestProvider_UserGroupsQueryUpdated(t *testing.T) { assert.NoError(t, err) assert.Equal(t, []*directory.User{ { - Id: "okta/a@example.com", - GroupIds: []string{"admin", "user"}, + Id: "okta/a@example.com", + GroupIds: []string{"admin", "user"}, + DisplayName: "first last", + Email: "a@example.com", }, { - Id: "okta/b@example.com", - GroupIds: []string{"test", "user"}, + Id: "okta/b@example.com", + GroupIds: []string{"test", "user"}, + DisplayName: "first last", + Email: "b@example.com", }, { - Id: "okta/c@example.com", - GroupIds: []string{"user"}, + Id: "okta/c@example.com", + GroupIds: []string{"user"}, + DisplayName: "first last", + Email: "c@example.com", }, { - Id: "okta/updated@example.com", - GroupIds: []string{"user-updated"}, + Id: "okta/updated@example.com", + GroupIds: []string{"user-updated"}, + DisplayName: "first last", + Email: "updated@example.com", }, }, users) assert.Len(t, groups, 4) @@ -222,20 +244,28 @@ func TestProvider_UserGroupsQueryUpdated(t *testing.T) { assert.NoError(t, err) assert.Equal(t, []*directory.User{ { - Id: "okta/a@example.com", - GroupIds: []string{"admin", "user"}, + Id: "okta/a@example.com", + GroupIds: []string{"admin", "user"}, + DisplayName: "first last", + Email: "a@example.com", }, { - Id: "okta/b@example.com", - GroupIds: []string{"user"}, + Id: "okta/b@example.com", + GroupIds: []string{"user"}, + DisplayName: "first last", + Email: "b@example.com", }, { - Id: "okta/c@example.com", - GroupIds: []string{"user"}, + Id: "okta/c@example.com", + GroupIds: []string{"user"}, + DisplayName: "first last", + Email: "c@example.com", }, { - Id: "okta/updated@example.com", - GroupIds: []string{"user-updated"}, + Id: "okta/updated@example.com", + GroupIds: []string{"user-updated"}, + DisplayName: "first last", + Email: "updated@example.com", }, }, users) assert.Len(t, groups, 3) diff --git a/internal/directory/onelogin/onelogin.go b/internal/directory/onelogin/onelogin.go index b202f0d11..f436f8fd7 100644 --- a/internal/directory/onelogin/onelogin.go +++ b/internal/directory/onelogin/onelogin.go @@ -112,17 +112,18 @@ func (p *Provider) UserGroups(ctx context.Context) ([]*directory.Group, []*direc return nil, nil, err } - userIDToGroupIDs, err := p.getUserIDToGroupIDs(ctx, token) + apiUsers, err := p.getUsers(ctx, token) if err != nil { return nil, nil, err } var users []*directory.User - for userID, groupIDs := range userIDToGroupIDs { - sort.Strings(groupIDs) + for _, u := range apiUsers { users = append(users, &directory.User{ - Id: databroker.GetUserID(Name, strconv.Itoa(userID)), - GroupIds: groupIDs, + Id: databroker.GetUserID(Name, strconv.Itoa(u.ID)), + GroupIds: []string{strconv.Itoa(u.GroupID)}, + DisplayName: u.FirstName + " " + u.LastName, + Email: u.Email, }) } @@ -160,35 +161,25 @@ func (p *Provider) listGroups(ctx context.Context, token *oauth2.Token) ([]*dire return groups, nil } -func (p *Provider) getUserIDToGroupIDs(ctx context.Context, token *oauth2.Token) (map[int][]string, error) { - userIDToGroupIDs := map[int][]string{} +func (p *Provider) getUsers(ctx context.Context, token *oauth2.Token) ([]apiUserObject, error) { + var users []apiUserObject apiURL := p.cfg.apiURL.ResolveReference(&url.URL{ Path: "/api/1/users", RawQuery: fmt.Sprintf("limit=%d", p.cfg.batchSize), }).String() for apiURL != "" { - var result []struct { - ID int `json:"id"` - GroupID *int `json:"group_id"` - } + var result []apiUserObject nextLink, err := p.apiGet(ctx, token, apiURL, &result) if err != nil { return nil, err } - for _, r := range result { - groupID := 0 - if r.GroupID != nil { - groupID = *r.GroupID - } - userIDToGroupIDs[r.ID] = append(userIDToGroupIDs[r.ID], strconv.Itoa(groupID)) - } - + users = append(users, result...) apiURL = nextLink } - return userIDToGroupIDs, nil + return users, nil } func (p *Provider) apiGet(ctx context.Context, token *oauth2.Token, uri string, out interface{}) (nextLink string, err error) { @@ -308,3 +299,11 @@ func ParseServiceAccount(rawServiceAccount string) (*ServiceAccount, error) { return &serviceAccount, nil } + +type apiUserObject struct { + ID int `json:"id"` + GroupID int `json:"group_id"` + Email string `json:"email"` + FirstName string `json:"firstname"` + LastName string `json:"lastname"` +} diff --git a/internal/directory/onelogin/onelogin_test.go b/internal/directory/onelogin/onelogin_test.go index efe1a4530..cc21bb706 100644 --- a/internal/directory/onelogin/onelogin_test.go +++ b/internal/directory/onelogin/onelogin_test.go @@ -115,9 +115,11 @@ func newMockAPI(srv *httptest.Server, userIDToGroupName map[int]string) http.Han var result []M for _, userID := range allUserIDs { result = append(result, M{ - "id": userID, - "email": userIDToGroupName[userID] + "@example.com", - "group_id": userIDToGroupID[userID], + "id": userID, + "email": userIDToGroupName[userID] + "@example.com", + "group_id": userIDToGroupID[userID], + "firstname": "User", + "lastname": fmt.Sprint(userID), }) } _ = json.NewEncoder(w).Encode(M{ @@ -150,9 +152,9 @@ func TestProvider_UserGroups(t *testing.T) { groups, users, err := p.UserGroups(context.Background()) assert.NoError(t, err) testutil.AssertProtoJSONEqual(t, `[ - { "id": "onelogin/111", "groupIds": ["0"] }, - { "id": "onelogin/222", "groupIds": ["1"] }, - { "id": "onelogin/333", "groupIds": ["2"] } + { "id": "onelogin/111", "groupIds": ["0"], "displayName": "User 111", "email": "admin@example.com" }, + { "id": "onelogin/222", "groupIds": ["1"], "displayName": "User 222", "email": "test@example.com" }, + { "id": "onelogin/333", "groupIds": ["2"], "displayName": "User 333", "email": "user@example.com" } ]`, users) testutil.AssertProtoJSONEqual(t, `[ { "id": "0", "name": "admin" }, diff --git a/pkg/grpc/directory/directory.pb.go b/pkg/grpc/directory/directory.pb.go index 3997493dd..8ff082dbb 100644 --- a/pkg/grpc/directory/directory.pb.go +++ b/pkg/grpc/directory/directory.pb.go @@ -30,9 +30,11 @@ type User struct { sizeCache protoimpl.SizeCache unknownFields protoimpl.UnknownFields - Version string `protobuf:"bytes,1,opt,name=version,proto3" json:"version,omitempty"` - Id string `protobuf:"bytes,2,opt,name=id,proto3" json:"id,omitempty"` - GroupIds []string `protobuf:"bytes,3,rep,name=group_ids,json=groupIds,proto3" json:"group_ids,omitempty"` + Version string `protobuf:"bytes,1,opt,name=version,proto3" json:"version,omitempty"` + Id string `protobuf:"bytes,2,opt,name=id,proto3" json:"id,omitempty"` + GroupIds []string `protobuf:"bytes,3,rep,name=group_ids,json=groupIds,proto3" json:"group_ids,omitempty"` + DisplayName string `protobuf:"bytes,4,opt,name=display_name,json=displayName,proto3" json:"display_name,omitempty"` + Email string `protobuf:"bytes,5,opt,name=email,proto3" json:"email,omitempty"` } func (x *User) Reset() { @@ -88,6 +90,20 @@ func (x *User) GetGroupIds() []string { return nil } +func (x *User) GetDisplayName() string { + if x != nil { + return x.DisplayName + } + return "" +} + +func (x *User) GetEmail() string { + if x != nil { + return x.Email + } + return "" +} + type Group struct { state protoimpl.MessageState sizeCache protoimpl.SizeCache @@ -163,22 +179,25 @@ var File_directory_proto protoreflect.FileDescriptor var file_directory_proto_rawDesc = []byte{ 0x0a, 0x0f, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, 0x6f, 0x72, 0x79, 0x2e, 0x70, 0x72, 0x6f, 0x74, - 0x6f, 0x12, 0x09, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, 0x6f, 0x72, 0x79, 0x22, 0x4d, 0x0a, 0x04, - 0x55, 0x73, 0x65, 0x72, 0x12, 0x18, 0x0a, 0x07, 0x76, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x18, - 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x76, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x12, 0x0e, - 0x0a, 0x02, 0x69, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x02, 0x69, 0x64, 0x12, 0x1b, - 0x0a, 0x09, 0x67, 0x72, 0x6f, 0x75, 0x70, 0x5f, 0x69, 0x64, 0x73, 0x18, 0x03, 0x20, 0x03, 0x28, - 0x09, 0x52, 0x08, 0x67, 0x72, 0x6f, 0x75, 0x70, 0x49, 0x64, 0x73, 0x22, 0x5b, 0x0a, 0x05, 0x47, - 0x72, 0x6f, 0x75, 0x70, 0x12, 0x18, 0x0a, 0x07, 0x76, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x18, - 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x76, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x12, 0x0e, - 0x0a, 0x02, 0x69, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x02, 0x69, 0x64, 0x12, 0x12, - 0x0a, 0x04, 0x6e, 0x61, 0x6d, 0x65, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x6e, 0x61, - 0x6d, 0x65, 0x12, 0x14, 0x0a, 0x05, 0x65, 0x6d, 0x61, 0x69, 0x6c, 0x18, 0x04, 0x20, 0x01, 0x28, - 0x09, 0x52, 0x05, 0x65, 0x6d, 0x61, 0x69, 0x6c, 0x42, 0x31, 0x5a, 0x2f, 0x67, 0x69, 0x74, 0x68, - 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x70, 0x6f, 0x6d, 0x65, 0x72, 0x69, 0x75, 0x6d, 0x2f, - 0x70, 0x6f, 0x6d, 0x65, 0x72, 0x69, 0x75, 0x6d, 0x2f, 0x70, 0x6b, 0x67, 0x2f, 0x67, 0x72, 0x70, - 0x63, 0x2f, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, 0x6f, 0x72, 0x79, 0x62, 0x06, 0x70, 0x72, 0x6f, - 0x74, 0x6f, 0x33, + 0x6f, 0x12, 0x09, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, 0x6f, 0x72, 0x79, 0x22, 0x86, 0x01, 0x0a, + 0x04, 0x55, 0x73, 0x65, 0x72, 0x12, 0x18, 0x0a, 0x07, 0x76, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, + 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x76, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x12, + 0x0e, 0x0a, 0x02, 0x69, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x02, 0x69, 0x64, 0x12, + 0x1b, 0x0a, 0x09, 0x67, 0x72, 0x6f, 0x75, 0x70, 0x5f, 0x69, 0x64, 0x73, 0x18, 0x03, 0x20, 0x03, + 0x28, 0x09, 0x52, 0x08, 0x67, 0x72, 0x6f, 0x75, 0x70, 0x49, 0x64, 0x73, 0x12, 0x21, 0x0a, 0x0c, + 0x64, 0x69, 0x73, 0x70, 0x6c, 0x61, 0x79, 0x5f, 0x6e, 0x61, 0x6d, 0x65, 0x18, 0x04, 0x20, 0x01, + 0x28, 0x09, 0x52, 0x0b, 0x64, 0x69, 0x73, 0x70, 0x6c, 0x61, 0x79, 0x4e, 0x61, 0x6d, 0x65, 0x12, + 0x14, 0x0a, 0x05, 0x65, 0x6d, 0x61, 0x69, 0x6c, 0x18, 0x05, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, + 0x65, 0x6d, 0x61, 0x69, 0x6c, 0x22, 0x5b, 0x0a, 0x05, 0x47, 0x72, 0x6f, 0x75, 0x70, 0x12, 0x18, + 0x0a, 0x07, 0x76, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, + 0x07, 0x76, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x12, 0x0e, 0x0a, 0x02, 0x69, 0x64, 0x18, 0x02, + 0x20, 0x01, 0x28, 0x09, 0x52, 0x02, 0x69, 0x64, 0x12, 0x12, 0x0a, 0x04, 0x6e, 0x61, 0x6d, 0x65, + 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x6e, 0x61, 0x6d, 0x65, 0x12, 0x14, 0x0a, 0x05, + 0x65, 0x6d, 0x61, 0x69, 0x6c, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x65, 0x6d, 0x61, + 0x69, 0x6c, 0x42, 0x31, 0x5a, 0x2f, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, + 0x2f, 0x70, 0x6f, 0x6d, 0x65, 0x72, 0x69, 0x75, 0x6d, 0x2f, 0x70, 0x6f, 0x6d, 0x65, 0x72, 0x69, + 0x75, 0x6d, 0x2f, 0x70, 0x6b, 0x67, 0x2f, 0x67, 0x72, 0x70, 0x63, 0x2f, 0x64, 0x69, 0x72, 0x65, + 0x63, 0x74, 0x6f, 0x72, 0x79, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, } var ( diff --git a/pkg/grpc/directory/directory.proto b/pkg/grpc/directory/directory.proto index 27328cc8d..51c182bf0 100644 --- a/pkg/grpc/directory/directory.proto +++ b/pkg/grpc/directory/directory.proto @@ -7,6 +7,8 @@ message User { string version = 1; string id = 2; repeated string group_ids = 3; + string display_name = 4; + string email = 5; } message Group {