azure: add support for nested groups (#1408)

* azure: add support for nested groups

* fix test
This commit is contained in:
Caleb Doxsey 2020-09-17 08:25:10 -06:00 committed by GitHub
parent 79a01bcfbb
commit 665f0f9a74
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 151 additions and 17 deletions

View file

@ -111,24 +111,20 @@ func (p *Provider) UserGroups(ctx context.Context) ([]*directory.Group, []*direc
return nil, nil, err
}
userIDToGroupIDs := map[string][]string{}
groupLookup := newGroupLookup()
for _, group := range groups {
userIDs, err := p.listGroupMembers(ctx, group.Id)
groupIDs, userIDs, err := p.listGroupMembers(ctx, group.Id)
if err != nil {
return nil, nil, err
}
for _, userID := range userIDs {
userIDToGroupIDs[userID] = append(userIDToGroupIDs[userID], group.Id)
}
groupLookup.addGroup(group.Id, groupIDs, userIDs)
}
var users []*directory.User
for userID, groupIDs := range userIDToGroupIDs {
sort.Strings(groupIDs)
for _, userID := range groupLookup.getUserIDs() {
users = append(users, &directory.User{
Id: databroker.GetUserID(Name, userID),
GroupIds: groupIDs,
GroupIds: groupLookup.getGroupIDsForUser(userID),
})
}
sort.Slice(users, func(i, j int) bool {
@ -168,7 +164,7 @@ func (p *Provider) listGroups(ctx context.Context) ([]*directory.Group, error) {
return groups, nil
}
func (p *Provider) listGroupMembers(ctx context.Context, groupID string) (userIDs []string, err error) {
func (p *Provider) listGroupMembers(ctx context.Context, groupID string) (groupIDs, userIDs []string, err error) {
nextURL := p.cfg.graphURL.ResolveReference(&url.URL{
Path: fmt.Sprintf("/v1.0/groups/%s/members", groupID),
}).String()
@ -176,21 +172,27 @@ func (p *Provider) listGroupMembers(ctx context.Context, groupID string) (userID
for nextURL != "" {
var result struct {
Value []struct {
Type string `json:"@odata.type"`
ID string `json:"id"`
} `json:"value"`
NextLink string `json:"@odata.nextLink"`
}
err := p.api(ctx, "GET", nextURL, nil, &result)
if err != nil {
return nil, err
return nil, nil, err
}
for _, v := range result.Value {
switch v.Type {
case "#microsoft.graph.group":
groupIDs = append(groupIDs, v.ID)
case "#microsoft.graph.user":
userIDs = append(userIDs, v.ID)
}
}
nextURL = result.NextLink
}
return userIDs, nil
return groupIDs, userIDs, nil
}
func (p *Provider) api(ctx context.Context, method, url string, body io.Reader, out interface{}) error {

View file

@ -54,11 +54,11 @@ func newMockAPI(t *testing.T, srv *httptest.Server) http.Handler {
r.Get("/groups/{group_name}/members", func(w http.ResponseWriter, r *http.Request) {
members := map[string][]M{
"admin": {
{"id": "user-1"},
{"@odata.type": "#microsoft.graph.user", "id": "user-1"},
},
"test": {
{"id": "user-2"},
{"id": "user-3"},
{"@odata.type": "#microsoft.graph.user", "id": "user-2"},
{"@odata.type": "#microsoft.graph.user", "id": "user-3"},
},
}
_ = json.NewEncoder(w).Encode(M{

View file

@ -0,0 +1,107 @@
package azure
import "sort"
type stringSet map[string]struct{}
func newStringSet() stringSet {
return make(stringSet)
}
func (ss stringSet) add(value string) {
ss[value] = struct{}{}
}
func (ss stringSet) has(value string) bool {
if ss == nil {
return false
}
_, ok := ss[value]
return ok
}
func (ss stringSet) sorted() []string {
if ss == nil {
return nil
}
s := make([]string, 0, len(ss))
for v := range ss {
s = append(s, v)
}
sort.Strings(s)
return s
}
type stringSetSet map[string]stringSet
func newStringSetSet() stringSetSet {
return make(stringSetSet)
}
func (sss stringSetSet) add(v1, v2 string) {
ss, ok := sss[v1]
if !ok {
ss = newStringSet()
sss[v1] = ss
}
ss.add(v2)
}
func (sss stringSetSet) get(v1 string) stringSet {
return sss[v1]
}
type groupLookup struct {
childUserIDToParentGroupID stringSetSet
childGroupIDToParentGroupID stringSetSet
}
func newGroupLookup() *groupLookup {
return &groupLookup{
childUserIDToParentGroupID: newStringSetSet(),
childGroupIDToParentGroupID: newStringSetSet(),
}
}
func (l *groupLookup) addGroup(parentGroupID string, childGroupIDs, childUserIDs []string) {
for _, childGroupID := range childGroupIDs {
l.childGroupIDToParentGroupID.add(childGroupID, parentGroupID)
}
for _, childUserID := range childUserIDs {
l.childUserIDToParentGroupID.add(childUserID, parentGroupID)
}
}
func (l *groupLookup) getUserIDs() []string {
s := make([]string, 0, len(l.childUserIDToParentGroupID))
for userID := range l.childUserIDToParentGroupID {
s = append(s, userID)
}
sort.Strings(s)
return s
}
func (l *groupLookup) getGroupIDsForUser(userID string) []string {
groupIDs := newStringSet()
var todo []string
for groupID := range l.childUserIDToParentGroupID.get(userID) {
todo = append(todo, groupID)
}
for len(todo) > 0 {
groupID := todo[len(todo)-1]
todo = todo[:len(todo)-1]
if groupIDs.has(groupID) {
continue
}
groupIDs.add(groupID)
for parentGroupID := range l.childGroupIDToParentGroupID.get(groupID) {
todo = append(todo, parentGroupID)
}
}
return groupIDs.sorted()
}

View file

@ -0,0 +1,25 @@
package azure
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestGroupLookup(t *testing.T) {
gl := newGroupLookup()
gl.addGroup("g1", []string{"g11", "g12", "g13"}, []string{"u1"})
gl.addGroup("g11", []string{"g111"}, nil)
gl.addGroup("g111", nil, []string{"u2"})
assert.Equal(t, []string{"u1", "u2"}, gl.getUserIDs())
assert.Equal(t, []string{"g1", "g11", "g111"}, gl.getGroupIDsForUser("u2"))
t.Run("cycle protection", func(t *testing.T) {
gl.addGroup("g12", []string{"g1"}, nil)
assert.Equal(t, []string{"u1", "u2"}, gl.getUserIDs())
assert.Equal(t, []string{"g1", "g11", "g111", "g12"}, gl.getGroupIDsForUser("u2"))
})
}