diff --git a/internal/directory/onelogin/onelogin.go b/internal/directory/onelogin/onelogin.go index 5091d9b11..5f1950d28 100644 --- a/internal/directory/onelogin/onelogin.go +++ b/internal/directory/onelogin/onelogin.go @@ -9,6 +9,7 @@ import ( "net/http" "net/url" "sort" + "strconv" "strings" "sync" @@ -111,30 +112,31 @@ func (p *Provider) UserGroups(ctx context.Context) ([]*directory.User, error) { return nil, err } - userEmailToGroupIDs, err := p.getUserEmailToGroupIDs(ctx, token) + userIDToGroupIDs, err := p.getUserIDToGroupIDs(ctx, token) if err != nil { return nil, err } - userEmailToGroupNames := map[string][]string{} - for email, groupIDs := range userEmailToGroupIDs { + userIDToGroupNames := map[int][]string{} + for userID, groupIDs := range userIDToGroupIDs { for _, groupID := range groupIDs { if groupName, ok := groupIDToName[groupID]; ok { - userEmailToGroupNames[email] = append(userEmailToGroupNames[email], groupName) + userIDToGroupNames[userID] = append(userIDToGroupNames[userID], groupName) } else { - userEmailToGroupNames[email] = append(userEmailToGroupNames[email], "NOGROUP") + userIDToGroupNames[userID] = append(userIDToGroupNames[userID], "NOGROUP") } } } var users []*directory.User - for userEmail, groups := range userEmailToGroupNames { + for userID, groups := range userIDToGroupNames { sort.Strings(groups) users = append(users, &directory.User{ - Id: databroker.GetUserID(Name, userEmail), + Id: databroker.GetUserID(Name, strconv.Itoa(userID)), Groups: groups, }) } + sort.Slice(users, func(i, j int) bool { return users[i].Id < users[j].Id }) @@ -168,8 +170,8 @@ func (p *Provider) getGroupIDToName(ctx context.Context, token *oauth2.Token) (m return groupIDToName, nil } -func (p *Provider) getUserEmailToGroupIDs(ctx context.Context, token *oauth2.Token) (map[string][]int, error) { - userEmailToGroupIDs := map[string][]int{} +func (p *Provider) getUserIDToGroupIDs(ctx context.Context, token *oauth2.Token) (map[int][]int, error) { + userIDToGroupIDs := map[int][]int{} apiURL := p.cfg.apiURL.ResolveReference(&url.URL{ Path: "/api/1/users", @@ -177,8 +179,8 @@ func (p *Provider) getUserEmailToGroupIDs(ctx context.Context, token *oauth2.Tok }).String() for apiURL != "" { var result []struct { - Email string `json:"email"` - GroupID *int `json:"group_id"` + ID int `json:"id"` + GroupID *int `json:"group_id"` } nextLink, err := p.apiGet(ctx, token, apiURL, &result) if err != nil { @@ -190,13 +192,13 @@ func (p *Provider) getUserEmailToGroupIDs(ctx context.Context, token *oauth2.Tok if r.GroupID != nil { groupID = *r.GroupID } - userEmailToGroupIDs[r.Email] = append(userEmailToGroupIDs[r.Email], groupID) + userIDToGroupIDs[r.ID] = append(userIDToGroupIDs[r.ID], groupID) } apiURL = nextLink } - return userEmailToGroupIDs, nil + return userIDToGroupIDs, nil } func (p *Provider) apiGet(ctx context.Context, token *oauth2.Token, uri string, out interface{}) (nextLink string, err error) { diff --git a/internal/directory/onelogin/onelogin_test.go b/internal/directory/onelogin/onelogin_test.go index b26c1e4ff..8f84c8ebf 100644 --- a/internal/directory/onelogin/onelogin_test.go +++ b/internal/directory/onelogin/onelogin_test.go @@ -20,9 +20,9 @@ import ( type M = map[string]interface{} -func newMockAPI(srv *httptest.Server, userEmailToGroupName map[string]string) http.Handler { +func newMockAPI(srv *httptest.Server, userIDToGroupName map[int]string) http.Handler { lookup := map[string]struct{}{} - for _, group := range userEmailToGroupName { + for _, group := range userIDToGroupName { lookup[group] = struct{}{} } var allGroups []string @@ -31,11 +31,11 @@ func newMockAPI(srv *httptest.Server, userEmailToGroupName map[string]string) ht } sort.Strings(allGroups) - var allEmails []string - for email := range userEmailToGroupName { - allEmails = append(allEmails, email) + var allUserIDs []int + for userID := range userIDToGroupName { + allUserIDs = append(allUserIDs, userID) } - sort.Strings(allEmails) + sort.Ints(allUserIDs) r := chi.NewRouter() r.Use(middleware.Logger) @@ -103,21 +103,21 @@ func newMockAPI(srv *httptest.Server, userEmailToGroupName map[string]string) ht _ = json.NewEncoder(w).Encode(result) }) r.Get("/users", func(w http.ResponseWriter, r *http.Request) { - userEmailToGroupID := map[string]int{} - for email, groupName := range userEmailToGroupName { + userIDToGroupID := map[int]int{} + for userID, groupName := range userIDToGroupName { for id, n := range allGroups { if groupName == n { - userEmailToGroupID[email] = id + userIDToGroupID[userID] = id } } } var result []M - for i, email := range allEmails { + for _, userID := range allUserIDs { result = append(result, M{ - "id": i, - "email": email, - "group_id": userEmailToGroupID[email], + "id": userID, + "email": userIDToGroupName[userID] + "@example.com", + "group_id": userIDToGroupID[userID], }) } _ = json.NewEncoder(w).Encode(M{ @@ -134,10 +134,10 @@ func TestProvider_UserGroups(t *testing.T) { mockAPI.ServeHTTP(w, r) })) defer srv.Close() - mockAPI = newMockAPI(srv, map[string]string{ - "a@example.com": "admin", - "b@example.com": "test", - "c@example.com": "user", + mockAPI = newMockAPI(srv, map[int]string{ + 111: "admin", + 222: "test", + 333: "user", }) p := New( @@ -151,15 +151,15 @@ func TestProvider_UserGroups(t *testing.T) { assert.NoError(t, err) assert.Equal(t, []*directory.User{ { - Id: "onelogin/a@example.com", + Id: "onelogin/111", Groups: []string{"admin"}, }, { - Id: "onelogin/b@example.com", + Id: "onelogin/222", Groups: []string{"test"}, }, { - Id: "onelogin/c@example.com", + Id: "onelogin/333", Groups: []string{"user"}, }, }, users)