diff --git a/internal/directory/azure/azure.go b/internal/directory/azure/azure.go index 2db6364d7..655bded56 100644 --- a/internal/directory/azure/azure.go +++ b/internal/directory/azure/azure.go @@ -111,24 +111,20 @@ func (p *Provider) UserGroups(ctx context.Context) ([]*directory.Group, []*direc return nil, nil, err } - userIDToGroupIDs := map[string][]string{} + groupLookup := newGroupLookup() for _, group := range groups { - userIDs, err := p.listGroupMembers(ctx, group.Id) + groupIDs, userIDs, err := p.listGroupMembers(ctx, group.Id) if err != nil { return nil, nil, err } - - for _, userID := range userIDs { - userIDToGroupIDs[userID] = append(userIDToGroupIDs[userID], group.Id) - } + groupLookup.addGroup(group.Id, groupIDs, userIDs) } var users []*directory.User - for userID, groupIDs := range userIDToGroupIDs { - sort.Strings(groupIDs) + for _, userID := range groupLookup.getUserIDs() { users = append(users, &directory.User{ Id: databroker.GetUserID(Name, userID), - GroupIds: groupIDs, + GroupIds: groupLookup.getGroupIDsForUser(userID), }) } sort.Slice(users, func(i, j int) bool { @@ -168,7 +164,7 @@ func (p *Provider) listGroups(ctx context.Context) ([]*directory.Group, error) { return groups, nil } -func (p *Provider) listGroupMembers(ctx context.Context, groupID string) (userIDs []string, err error) { +func (p *Provider) listGroupMembers(ctx context.Context, groupID string) (groupIDs, userIDs []string, err error) { nextURL := p.cfg.graphURL.ResolveReference(&url.URL{ Path: fmt.Sprintf("/v1.0/groups/%s/members", groupID), }).String() @@ -176,21 +172,27 @@ func (p *Provider) listGroupMembers(ctx context.Context, groupID string) (userID for nextURL != "" { var result struct { Value []struct { - ID string `json:"id"` + Type string `json:"@odata.type"` + ID string `json:"id"` } `json:"value"` NextLink string `json:"@odata.nextLink"` } err := p.api(ctx, "GET", nextURL, nil, &result) if err != nil { - return nil, err + return nil, nil, err } for _, v := range result.Value { - userIDs = append(userIDs, v.ID) + switch v.Type { + case "#microsoft.graph.group": + groupIDs = append(groupIDs, v.ID) + case "#microsoft.graph.user": + userIDs = append(userIDs, v.ID) + } } nextURL = result.NextLink } - return userIDs, nil + return groupIDs, userIDs, nil } func (p *Provider) api(ctx context.Context, method, url string, body io.Reader, out interface{}) error { diff --git a/internal/directory/azure/azure_test.go b/internal/directory/azure/azure_test.go index 424514dbc..068608fbb 100644 --- a/internal/directory/azure/azure_test.go +++ b/internal/directory/azure/azure_test.go @@ -54,11 +54,11 @@ func newMockAPI(t *testing.T, srv *httptest.Server) http.Handler { r.Get("/groups/{group_name}/members", func(w http.ResponseWriter, r *http.Request) { members := map[string][]M{ "admin": { - {"id": "user-1"}, + {"@odata.type": "#microsoft.graph.user", "id": "user-1"}, }, "test": { - {"id": "user-2"}, - {"id": "user-3"}, + {"@odata.type": "#microsoft.graph.user", "id": "user-2"}, + {"@odata.type": "#microsoft.graph.user", "id": "user-3"}, }, } _ = json.NewEncoder(w).Encode(M{ diff --git a/internal/directory/azure/group.go b/internal/directory/azure/group.go new file mode 100644 index 000000000..c5e5827b1 --- /dev/null +++ b/internal/directory/azure/group.go @@ -0,0 +1,107 @@ +package azure + +import "sort" + +type stringSet map[string]struct{} + +func newStringSet() stringSet { + return make(stringSet) +} + +func (ss stringSet) add(value string) { + ss[value] = struct{}{} +} + +func (ss stringSet) has(value string) bool { + if ss == nil { + return false + } + + _, ok := ss[value] + return ok +} + +func (ss stringSet) sorted() []string { + if ss == nil { + return nil + } + + s := make([]string, 0, len(ss)) + for v := range ss { + s = append(s, v) + } + sort.Strings(s) + return s +} + +type stringSetSet map[string]stringSet + +func newStringSetSet() stringSetSet { + return make(stringSetSet) +} + +func (sss stringSetSet) add(v1, v2 string) { + ss, ok := sss[v1] + if !ok { + ss = newStringSet() + sss[v1] = ss + } + ss.add(v2) +} + +func (sss stringSetSet) get(v1 string) stringSet { + return sss[v1] +} + +type groupLookup struct { + childUserIDToParentGroupID stringSetSet + childGroupIDToParentGroupID stringSetSet +} + +func newGroupLookup() *groupLookup { + return &groupLookup{ + childUserIDToParentGroupID: newStringSetSet(), + childGroupIDToParentGroupID: newStringSetSet(), + } +} + +func (l *groupLookup) addGroup(parentGroupID string, childGroupIDs, childUserIDs []string) { + for _, childGroupID := range childGroupIDs { + l.childGroupIDToParentGroupID.add(childGroupID, parentGroupID) + } + for _, childUserID := range childUserIDs { + l.childUserIDToParentGroupID.add(childUserID, parentGroupID) + } +} + +func (l *groupLookup) getUserIDs() []string { + s := make([]string, 0, len(l.childUserIDToParentGroupID)) + for userID := range l.childUserIDToParentGroupID { + s = append(s, userID) + } + sort.Strings(s) + return s +} + +func (l *groupLookup) getGroupIDsForUser(userID string) []string { + groupIDs := newStringSet() + var todo []string + for groupID := range l.childUserIDToParentGroupID.get(userID) { + todo = append(todo, groupID) + } + + for len(todo) > 0 { + groupID := todo[len(todo)-1] + todo = todo[:len(todo)-1] + if groupIDs.has(groupID) { + continue + } + + groupIDs.add(groupID) + for parentGroupID := range l.childGroupIDToParentGroupID.get(groupID) { + todo = append(todo, parentGroupID) + } + } + + return groupIDs.sorted() +} diff --git a/internal/directory/azure/group_test.go b/internal/directory/azure/group_test.go new file mode 100644 index 000000000..ae649c3c1 --- /dev/null +++ b/internal/directory/azure/group_test.go @@ -0,0 +1,25 @@ +package azure + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestGroupLookup(t *testing.T) { + gl := newGroupLookup() + + gl.addGroup("g1", []string{"g11", "g12", "g13"}, []string{"u1"}) + gl.addGroup("g11", []string{"g111"}, nil) + gl.addGroup("g111", nil, []string{"u2"}) + + assert.Equal(t, []string{"u1", "u2"}, gl.getUserIDs()) + assert.Equal(t, []string{"g1", "g11", "g111"}, gl.getGroupIDsForUser("u2")) + + t.Run("cycle protection", func(t *testing.T) { + gl.addGroup("g12", []string{"g1"}, nil) + + assert.Equal(t, []string{"u1", "u2"}, gl.getUserIDs()) + assert.Equal(t, []string{"g1", "g11", "g111", "g12"}, gl.getGroupIDsForUser("u2")) + }) +}