mirror of
https://github.com/pomerium/pomerium.git
synced 2025-06-03 11:22:45 +02:00
azure: add support for nested groups (#1408)
* azure: add support for nested groups * fix test
This commit is contained in:
parent
79a01bcfbb
commit
665f0f9a74
4 changed files with 151 additions and 17 deletions
|
@ -111,24 +111,20 @@ func (p *Provider) UserGroups(ctx context.Context) ([]*directory.Group, []*direc
|
|||
return nil, nil, err
|
||||
}
|
||||
|
||||
userIDToGroupIDs := map[string][]string{}
|
||||
groupLookup := newGroupLookup()
|
||||
for _, group := range groups {
|
||||
userIDs, err := p.listGroupMembers(ctx, group.Id)
|
||||
groupIDs, userIDs, err := p.listGroupMembers(ctx, group.Id)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
for _, userID := range userIDs {
|
||||
userIDToGroupIDs[userID] = append(userIDToGroupIDs[userID], group.Id)
|
||||
}
|
||||
groupLookup.addGroup(group.Id, groupIDs, userIDs)
|
||||
}
|
||||
|
||||
var users []*directory.User
|
||||
for userID, groupIDs := range userIDToGroupIDs {
|
||||
sort.Strings(groupIDs)
|
||||
for _, userID := range groupLookup.getUserIDs() {
|
||||
users = append(users, &directory.User{
|
||||
Id: databroker.GetUserID(Name, userID),
|
||||
GroupIds: groupIDs,
|
||||
GroupIds: groupLookup.getGroupIDsForUser(userID),
|
||||
})
|
||||
}
|
||||
sort.Slice(users, func(i, j int) bool {
|
||||
|
@ -168,7 +164,7 @@ func (p *Provider) listGroups(ctx context.Context) ([]*directory.Group, error) {
|
|||
return groups, nil
|
||||
}
|
||||
|
||||
func (p *Provider) listGroupMembers(ctx context.Context, groupID string) (userIDs []string, err error) {
|
||||
func (p *Provider) listGroupMembers(ctx context.Context, groupID string) (groupIDs, userIDs []string, err error) {
|
||||
nextURL := p.cfg.graphURL.ResolveReference(&url.URL{
|
||||
Path: fmt.Sprintf("/v1.0/groups/%s/members", groupID),
|
||||
}).String()
|
||||
|
@ -176,21 +172,27 @@ func (p *Provider) listGroupMembers(ctx context.Context, groupID string) (userID
|
|||
for nextURL != "" {
|
||||
var result struct {
|
||||
Value []struct {
|
||||
Type string `json:"@odata.type"`
|
||||
ID string `json:"id"`
|
||||
} `json:"value"`
|
||||
NextLink string `json:"@odata.nextLink"`
|
||||
}
|
||||
err := p.api(ctx, "GET", nextURL, nil, &result)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, nil, err
|
||||
}
|
||||
for _, v := range result.Value {
|
||||
switch v.Type {
|
||||
case "#microsoft.graph.group":
|
||||
groupIDs = append(groupIDs, v.ID)
|
||||
case "#microsoft.graph.user":
|
||||
userIDs = append(userIDs, v.ID)
|
||||
}
|
||||
}
|
||||
nextURL = result.NextLink
|
||||
}
|
||||
|
||||
return userIDs, nil
|
||||
return groupIDs, userIDs, nil
|
||||
}
|
||||
|
||||
func (p *Provider) api(ctx context.Context, method, url string, body io.Reader, out interface{}) error {
|
||||
|
|
|
@ -54,11 +54,11 @@ func newMockAPI(t *testing.T, srv *httptest.Server) http.Handler {
|
|||
r.Get("/groups/{group_name}/members", func(w http.ResponseWriter, r *http.Request) {
|
||||
members := map[string][]M{
|
||||
"admin": {
|
||||
{"id": "user-1"},
|
||||
{"@odata.type": "#microsoft.graph.user", "id": "user-1"},
|
||||
},
|
||||
"test": {
|
||||
{"id": "user-2"},
|
||||
{"id": "user-3"},
|
||||
{"@odata.type": "#microsoft.graph.user", "id": "user-2"},
|
||||
{"@odata.type": "#microsoft.graph.user", "id": "user-3"},
|
||||
},
|
||||
}
|
||||
_ = json.NewEncoder(w).Encode(M{
|
||||
|
|
107
internal/directory/azure/group.go
Normal file
107
internal/directory/azure/group.go
Normal file
|
@ -0,0 +1,107 @@
|
|||
package azure
|
||||
|
||||
import "sort"
|
||||
|
||||
type stringSet map[string]struct{}
|
||||
|
||||
func newStringSet() stringSet {
|
||||
return make(stringSet)
|
||||
}
|
||||
|
||||
func (ss stringSet) add(value string) {
|
||||
ss[value] = struct{}{}
|
||||
}
|
||||
|
||||
func (ss stringSet) has(value string) bool {
|
||||
if ss == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
_, ok := ss[value]
|
||||
return ok
|
||||
}
|
||||
|
||||
func (ss stringSet) sorted() []string {
|
||||
if ss == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
s := make([]string, 0, len(ss))
|
||||
for v := range ss {
|
||||
s = append(s, v)
|
||||
}
|
||||
sort.Strings(s)
|
||||
return s
|
||||
}
|
||||
|
||||
type stringSetSet map[string]stringSet
|
||||
|
||||
func newStringSetSet() stringSetSet {
|
||||
return make(stringSetSet)
|
||||
}
|
||||
|
||||
func (sss stringSetSet) add(v1, v2 string) {
|
||||
ss, ok := sss[v1]
|
||||
if !ok {
|
||||
ss = newStringSet()
|
||||
sss[v1] = ss
|
||||
}
|
||||
ss.add(v2)
|
||||
}
|
||||
|
||||
func (sss stringSetSet) get(v1 string) stringSet {
|
||||
return sss[v1]
|
||||
}
|
||||
|
||||
type groupLookup struct {
|
||||
childUserIDToParentGroupID stringSetSet
|
||||
childGroupIDToParentGroupID stringSetSet
|
||||
}
|
||||
|
||||
func newGroupLookup() *groupLookup {
|
||||
return &groupLookup{
|
||||
childUserIDToParentGroupID: newStringSetSet(),
|
||||
childGroupIDToParentGroupID: newStringSetSet(),
|
||||
}
|
||||
}
|
||||
|
||||
func (l *groupLookup) addGroup(parentGroupID string, childGroupIDs, childUserIDs []string) {
|
||||
for _, childGroupID := range childGroupIDs {
|
||||
l.childGroupIDToParentGroupID.add(childGroupID, parentGroupID)
|
||||
}
|
||||
for _, childUserID := range childUserIDs {
|
||||
l.childUserIDToParentGroupID.add(childUserID, parentGroupID)
|
||||
}
|
||||
}
|
||||
|
||||
func (l *groupLookup) getUserIDs() []string {
|
||||
s := make([]string, 0, len(l.childUserIDToParentGroupID))
|
||||
for userID := range l.childUserIDToParentGroupID {
|
||||
s = append(s, userID)
|
||||
}
|
||||
sort.Strings(s)
|
||||
return s
|
||||
}
|
||||
|
||||
func (l *groupLookup) getGroupIDsForUser(userID string) []string {
|
||||
groupIDs := newStringSet()
|
||||
var todo []string
|
||||
for groupID := range l.childUserIDToParentGroupID.get(userID) {
|
||||
todo = append(todo, groupID)
|
||||
}
|
||||
|
||||
for len(todo) > 0 {
|
||||
groupID := todo[len(todo)-1]
|
||||
todo = todo[:len(todo)-1]
|
||||
if groupIDs.has(groupID) {
|
||||
continue
|
||||
}
|
||||
|
||||
groupIDs.add(groupID)
|
||||
for parentGroupID := range l.childGroupIDToParentGroupID.get(groupID) {
|
||||
todo = append(todo, parentGroupID)
|
||||
}
|
||||
}
|
||||
|
||||
return groupIDs.sorted()
|
||||
}
|
25
internal/directory/azure/group_test.go
Normal file
25
internal/directory/azure/group_test.go
Normal file
|
@ -0,0 +1,25 @@
|
|||
package azure
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestGroupLookup(t *testing.T) {
|
||||
gl := newGroupLookup()
|
||||
|
||||
gl.addGroup("g1", []string{"g11", "g12", "g13"}, []string{"u1"})
|
||||
gl.addGroup("g11", []string{"g111"}, nil)
|
||||
gl.addGroup("g111", nil, []string{"u2"})
|
||||
|
||||
assert.Equal(t, []string{"u1", "u2"}, gl.getUserIDs())
|
||||
assert.Equal(t, []string{"g1", "g11", "g111"}, gl.getGroupIDsForUser("u2"))
|
||||
|
||||
t.Run("cycle protection", func(t *testing.T) {
|
||||
gl.addGroup("g12", []string{"g1"}, nil)
|
||||
|
||||
assert.Equal(t, []string{"u1", "u2"}, gl.getUserIDs())
|
||||
assert.Equal(t, []string{"g1", "g11", "g111", "g12"}, gl.getGroupIDsForUser("u2"))
|
||||
})
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue