mirror of
https://github.com/pomerium/pomerium.git
synced 2025-06-05 04:13:11 +02:00
azure: add support for nested groups (#1408)
* azure: add support for nested groups * fix test
This commit is contained in:
parent
79a01bcfbb
commit
665f0f9a74
4 changed files with 151 additions and 17 deletions
|
@ -111,24 +111,20 @@ func (p *Provider) UserGroups(ctx context.Context) ([]*directory.Group, []*direc
|
||||||
return nil, nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
userIDToGroupIDs := map[string][]string{}
|
groupLookup := newGroupLookup()
|
||||||
for _, group := range groups {
|
for _, group := range groups {
|
||||||
userIDs, err := p.listGroupMembers(ctx, group.Id)
|
groupIDs, userIDs, err := p.listGroupMembers(ctx, group.Id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
|
groupLookup.addGroup(group.Id, groupIDs, userIDs)
|
||||||
for _, userID := range userIDs {
|
|
||||||
userIDToGroupIDs[userID] = append(userIDToGroupIDs[userID], group.Id)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
var users []*directory.User
|
var users []*directory.User
|
||||||
for userID, groupIDs := range userIDToGroupIDs {
|
for _, userID := range groupLookup.getUserIDs() {
|
||||||
sort.Strings(groupIDs)
|
|
||||||
users = append(users, &directory.User{
|
users = append(users, &directory.User{
|
||||||
Id: databroker.GetUserID(Name, userID),
|
Id: databroker.GetUserID(Name, userID),
|
||||||
GroupIds: groupIDs,
|
GroupIds: groupLookup.getGroupIDsForUser(userID),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
sort.Slice(users, func(i, j int) bool {
|
sort.Slice(users, func(i, j int) bool {
|
||||||
|
@ -168,7 +164,7 @@ func (p *Provider) listGroups(ctx context.Context) ([]*directory.Group, error) {
|
||||||
return groups, nil
|
return groups, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *Provider) listGroupMembers(ctx context.Context, groupID string) (userIDs []string, err error) {
|
func (p *Provider) listGroupMembers(ctx context.Context, groupID string) (groupIDs, userIDs []string, err error) {
|
||||||
nextURL := p.cfg.graphURL.ResolveReference(&url.URL{
|
nextURL := p.cfg.graphURL.ResolveReference(&url.URL{
|
||||||
Path: fmt.Sprintf("/v1.0/groups/%s/members", groupID),
|
Path: fmt.Sprintf("/v1.0/groups/%s/members", groupID),
|
||||||
}).String()
|
}).String()
|
||||||
|
@ -176,21 +172,27 @@ func (p *Provider) listGroupMembers(ctx context.Context, groupID string) (userID
|
||||||
for nextURL != "" {
|
for nextURL != "" {
|
||||||
var result struct {
|
var result struct {
|
||||||
Value []struct {
|
Value []struct {
|
||||||
|
Type string `json:"@odata.type"`
|
||||||
ID string `json:"id"`
|
ID string `json:"id"`
|
||||||
} `json:"value"`
|
} `json:"value"`
|
||||||
NextLink string `json:"@odata.nextLink"`
|
NextLink string `json:"@odata.nextLink"`
|
||||||
}
|
}
|
||||||
err := p.api(ctx, "GET", nextURL, nil, &result)
|
err := p.api(ctx, "GET", nextURL, nil, &result)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
for _, v := range result.Value {
|
for _, v := range result.Value {
|
||||||
|
switch v.Type {
|
||||||
|
case "#microsoft.graph.group":
|
||||||
|
groupIDs = append(groupIDs, v.ID)
|
||||||
|
case "#microsoft.graph.user":
|
||||||
userIDs = append(userIDs, v.ID)
|
userIDs = append(userIDs, v.ID)
|
||||||
}
|
}
|
||||||
|
}
|
||||||
nextURL = result.NextLink
|
nextURL = result.NextLink
|
||||||
}
|
}
|
||||||
|
|
||||||
return userIDs, nil
|
return groupIDs, userIDs, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *Provider) api(ctx context.Context, method, url string, body io.Reader, out interface{}) error {
|
func (p *Provider) api(ctx context.Context, method, url string, body io.Reader, out interface{}) error {
|
||||||
|
|
|
@ -54,11 +54,11 @@ func newMockAPI(t *testing.T, srv *httptest.Server) http.Handler {
|
||||||
r.Get("/groups/{group_name}/members", func(w http.ResponseWriter, r *http.Request) {
|
r.Get("/groups/{group_name}/members", func(w http.ResponseWriter, r *http.Request) {
|
||||||
members := map[string][]M{
|
members := map[string][]M{
|
||||||
"admin": {
|
"admin": {
|
||||||
{"id": "user-1"},
|
{"@odata.type": "#microsoft.graph.user", "id": "user-1"},
|
||||||
},
|
},
|
||||||
"test": {
|
"test": {
|
||||||
{"id": "user-2"},
|
{"@odata.type": "#microsoft.graph.user", "id": "user-2"},
|
||||||
{"id": "user-3"},
|
{"@odata.type": "#microsoft.graph.user", "id": "user-3"},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
_ = json.NewEncoder(w).Encode(M{
|
_ = json.NewEncoder(w).Encode(M{
|
||||||
|
|
107
internal/directory/azure/group.go
Normal file
107
internal/directory/azure/group.go
Normal file
|
@ -0,0 +1,107 @@
|
||||||
|
package azure
|
||||||
|
|
||||||
|
import "sort"
|
||||||
|
|
||||||
|
type stringSet map[string]struct{}
|
||||||
|
|
||||||
|
func newStringSet() stringSet {
|
||||||
|
return make(stringSet)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ss stringSet) add(value string) {
|
||||||
|
ss[value] = struct{}{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ss stringSet) has(value string) bool {
|
||||||
|
if ss == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
_, ok := ss[value]
|
||||||
|
return ok
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ss stringSet) sorted() []string {
|
||||||
|
if ss == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
s := make([]string, 0, len(ss))
|
||||||
|
for v := range ss {
|
||||||
|
s = append(s, v)
|
||||||
|
}
|
||||||
|
sort.Strings(s)
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
|
||||||
|
type stringSetSet map[string]stringSet
|
||||||
|
|
||||||
|
func newStringSetSet() stringSetSet {
|
||||||
|
return make(stringSetSet)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (sss stringSetSet) add(v1, v2 string) {
|
||||||
|
ss, ok := sss[v1]
|
||||||
|
if !ok {
|
||||||
|
ss = newStringSet()
|
||||||
|
sss[v1] = ss
|
||||||
|
}
|
||||||
|
ss.add(v2)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (sss stringSetSet) get(v1 string) stringSet {
|
||||||
|
return sss[v1]
|
||||||
|
}
|
||||||
|
|
||||||
|
type groupLookup struct {
|
||||||
|
childUserIDToParentGroupID stringSetSet
|
||||||
|
childGroupIDToParentGroupID stringSetSet
|
||||||
|
}
|
||||||
|
|
||||||
|
func newGroupLookup() *groupLookup {
|
||||||
|
return &groupLookup{
|
||||||
|
childUserIDToParentGroupID: newStringSetSet(),
|
||||||
|
childGroupIDToParentGroupID: newStringSetSet(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *groupLookup) addGroup(parentGroupID string, childGroupIDs, childUserIDs []string) {
|
||||||
|
for _, childGroupID := range childGroupIDs {
|
||||||
|
l.childGroupIDToParentGroupID.add(childGroupID, parentGroupID)
|
||||||
|
}
|
||||||
|
for _, childUserID := range childUserIDs {
|
||||||
|
l.childUserIDToParentGroupID.add(childUserID, parentGroupID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *groupLookup) getUserIDs() []string {
|
||||||
|
s := make([]string, 0, len(l.childUserIDToParentGroupID))
|
||||||
|
for userID := range l.childUserIDToParentGroupID {
|
||||||
|
s = append(s, userID)
|
||||||
|
}
|
||||||
|
sort.Strings(s)
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *groupLookup) getGroupIDsForUser(userID string) []string {
|
||||||
|
groupIDs := newStringSet()
|
||||||
|
var todo []string
|
||||||
|
for groupID := range l.childUserIDToParentGroupID.get(userID) {
|
||||||
|
todo = append(todo, groupID)
|
||||||
|
}
|
||||||
|
|
||||||
|
for len(todo) > 0 {
|
||||||
|
groupID := todo[len(todo)-1]
|
||||||
|
todo = todo[:len(todo)-1]
|
||||||
|
if groupIDs.has(groupID) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
groupIDs.add(groupID)
|
||||||
|
for parentGroupID := range l.childGroupIDToParentGroupID.get(groupID) {
|
||||||
|
todo = append(todo, parentGroupID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return groupIDs.sorted()
|
||||||
|
}
|
25
internal/directory/azure/group_test.go
Normal file
25
internal/directory/azure/group_test.go
Normal file
|
@ -0,0 +1,25 @@
|
||||||
|
package azure
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestGroupLookup(t *testing.T) {
|
||||||
|
gl := newGroupLookup()
|
||||||
|
|
||||||
|
gl.addGroup("g1", []string{"g11", "g12", "g13"}, []string{"u1"})
|
||||||
|
gl.addGroup("g11", []string{"g111"}, nil)
|
||||||
|
gl.addGroup("g111", nil, []string{"u2"})
|
||||||
|
|
||||||
|
assert.Equal(t, []string{"u1", "u2"}, gl.getUserIDs())
|
||||||
|
assert.Equal(t, []string{"g1", "g11", "g111"}, gl.getGroupIDsForUser("u2"))
|
||||||
|
|
||||||
|
t.Run("cycle protection", func(t *testing.T) {
|
||||||
|
gl.addGroup("g12", []string{"g1"}, nil)
|
||||||
|
|
||||||
|
assert.Equal(t, []string{"u1", "u2"}, gl.getUserIDs())
|
||||||
|
assert.Equal(t, []string{"g1", "g11", "g111", "g12"}, gl.getGroupIDsForUser("u2"))
|
||||||
|
})
|
||||||
|
}
|
Loading…
Add table
Add a link
Reference in a new issue