directory.Group entry for groups (#1118)

* store directory groups separate from directory users

* fix group lookup, azure display name

* remove fields restriction

* fix test

* also support email

* use Email as name for google'

* remove changed file

* show groups on dashboard

* fix test

* re-add accidentally removed code
This commit is contained in:
Caleb Doxsey 2020-07-22 11:28:53 -06:00 committed by GitHub
parent 489cdd8b63
commit 1ad243dfd1
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
25 changed files with 525 additions and 209 deletions

View file

@ -484,17 +484,30 @@ func (a *Authenticate) Dashboard(w http.ResponseWriter, r *http.Request) error {
Id: pbSession.GetUserId(), Id: pbSession.GetUserId(),
} }
} }
pbDirectoryUser, err := directory.Get(r.Context(), a.dataBrokerClient, pbSession.GetUserId()) pbDirectoryUser, err := directory.GetUser(r.Context(), a.dataBrokerClient, pbSession.GetUserId())
if err != nil { if err != nil {
pbDirectoryUser = &directory.User{ pbDirectoryUser = &directory.User{
Id: pbSession.GetUserId(), Id: pbSession.GetUserId(),
} }
} }
var groups []*directory.Group
for _, groupID := range pbDirectoryUser.GetGroupIds() {
pbDirectoryGroup, err := directory.GetGroup(r.Context(), a.dataBrokerClient, groupID)
if err != nil {
pbDirectoryGroup = &directory.Group{
Id: groupID,
Name: groupID,
Email: groupID,
}
}
groups = append(groups, pbDirectoryGroup)
}
input := map[string]interface{}{ input := map[string]interface{}{
"State": s, "State": s,
"Session": pbSession, "Session": pbSession,
"User": pbUser, "User": pbUser,
"DirectoryGroups": groups,
"DirectoryUser": pbDirectoryUser, "DirectoryUser": pbDirectoryUser,
"csrfField": csrf.TemplateField(r), "csrfField": csrf.TemplateField(r),
"ImpersonateAction": urlutil.QueryImpersonateAction, "ImpersonateAction": urlutil.QueryImpersonateAction,

View file

@ -30,9 +30,10 @@ import (
) )
const ( const (
sessionTypeURL = "type.googleapis.com/session.Session" sessionTypeURL = "type.googleapis.com/session.Session"
userTypeURL = "type.googleapis.com/user.User" userTypeURL = "type.googleapis.com/user.User"
directoryUserTypeURL = "type.googleapis.com/directory.User" directoryUserTypeURL = "type.googleapis.com/directory.User"
directoryGroupTypeURL = "type.googleapis.com/directory.Group"
) )
// Evaluator specifies the interface for a policy engine. // Evaluator specifies the interface for a policy engine.
@ -217,7 +218,16 @@ func (e *Evaluator) JWTPayload(req *Request) map[string]interface{} {
payload["email"] = u.GetEmail() payload["email"] = u.GetEmail()
} }
if du, ok := req.DataBrokerData.Get("type.googleapis.com/directory.User", s.GetUserId()).(*directory.User); ok { if du, ok := req.DataBrokerData.Get("type.googleapis.com/directory.User", s.GetUserId()).(*directory.User); ok {
payload["groups"] = du.GetGroups() var groupNames []string
for _, groupID := range du.GetGroupIds() {
if dg, ok := req.DataBrokerData.Get("type.googleapis.com/directory.Group", groupID).(*directory.Group); ok {
groupNames = append(groupNames, dg.Name)
}
}
var groups []string
groups = append(groups, du.GetGroupIds()...)
groups = append(groups, groupNames...)
payload["groups"] = groups
} }
} }
return payload return payload
@ -255,9 +265,9 @@ type input struct {
} }
type dataBrokerDataInput struct { type dataBrokerDataInput struct {
Session interface{} `json:"session,omitempty"` Session interface{} `json:"session,omitempty"`
User interface{} `json:"user,omitempty"` User interface{} `json:"user,omitempty"`
DirectoryUser interface{} `json:"directory_user,omitempty"` Groups interface{} `json:"groups,omitempty"`
} }
func (e *Evaluator) newInput(req *Request, isValidClientCertificate bool) *input { func (e *Evaluator) newInput(req *Request, isValidClientCertificate bool) *input {
@ -265,7 +275,23 @@ func (e *Evaluator) newInput(req *Request, isValidClientCertificate bool) *input
i.DataBrokerData.Session = req.DataBrokerData.Get(sessionTypeURL, req.Session.ID) i.DataBrokerData.Session = req.DataBrokerData.Get(sessionTypeURL, req.Session.ID)
if obj, ok := i.DataBrokerData.Session.(interface{ GetUserId() string }); ok { if obj, ok := i.DataBrokerData.Session.(interface{ GetUserId() string }); ok {
i.DataBrokerData.User = req.DataBrokerData.Get(userTypeURL, obj.GetUserId()) i.DataBrokerData.User = req.DataBrokerData.Get(userTypeURL, obj.GetUserId())
i.DataBrokerData.DirectoryUser = req.DataBrokerData.Get(directoryUserTypeURL, obj.GetUserId())
user, ok := req.DataBrokerData.Get(directoryUserTypeURL, obj.GetUserId()).(*directory.User)
if ok {
var groups []string
for _, groupID := range user.GetGroupIds() {
if dg, ok := req.DataBrokerData.Get(directoryGroupTypeURL, groupID).(*directory.Group); ok {
if dg.Name != "" {
groups = append(groups, dg.Name)
}
if dg.Email != "" {
groups = append(groups, dg.Email)
}
}
}
groups = append(groups, user.GetGroupIds()...)
i.DataBrokerData.Groups = groups
}
} }
i.HTTP = req.HTTP i.HTTP = req.HTTP
i.Session = req.Session i.Session = req.Session

View file

@ -25,8 +25,8 @@ func TestJSONMarshal(t *testing.T) {
dbd := DataBrokerData{ dbd := DataBrokerData{
"type.googleapis.com/directory.User": map[string]interface{}{ "type.googleapis.com/directory.User": map[string]interface{}{
"user1": &directory.User{ "user1": &directory.User{
Id: "user1", Id: "user1",
Groups: []string{"group1", "group2"}, GroupIds: []string{"group1", "group2"},
}, },
}, },
"type.googleapis.com/session.Session": map[string]interface{}{}, "type.googleapis.com/session.Session": map[string]interface{}{},

View file

@ -7,7 +7,7 @@ route_policy_idx := first_allowed_route_policy_idx(input.http.url)
route_policy := data.route_policies[route_policy_idx] route_policy := data.route_policies[route_policy_idx]
session := input.databroker_data.session session := input.databroker_data.session
user := input.databroker_data.user user := input.databroker_data.user
directory_user := input.databroker_data.directory_user groups := input.databroker_data.groups
all_allowed_domains := get_allowed_domains(route_policy) all_allowed_domains := get_allowed_domains(route_policy)
all_allowed_groups := get_allowed_groups(route_policy) all_allowed_groups := get_allowed_groups(route_policy)
@ -35,7 +35,7 @@ allow {
# allow group # allow group
allow { allow {
some group some group
directory_user.groups[_] = group groups[_] = group
all_allowed_groups[_] = group all_allowed_groups[_] = group
input.session.impersonate_groups == null input.session.impersonate_groups == null
} }

View file

@ -67,9 +67,7 @@ test_group_allowed {
"user": { "user": {
"email": "x@example.com", "email": "x@example.com",
}, },
"directory_user": { "groups": ["1"]
"groups": ["1"]
}
} with } with
input.http as { "url": "http://example.com" } with input.http as { "url": "http://example.com" } with
input.session as { "id": "session1", "impersonate_groups": null } input.session as { "id": "session1", "impersonate_groups": null }

File diff suppressed because one or more lines are too long

1
go.sum
View file

@ -208,7 +208,6 @@ github.com/golang/protobuf v1.4.2/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw
github.com/golang/snappy v0.0.0-20180518054509-2e65f85255db/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= github.com/golang/snappy v0.0.0-20180518054509-2e65f85255db/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q=
github.com/gomodule/redigo v1.8.2 h1:H5XSIre1MB5NbPYFp+i1NBbb5qN1W8Y8YAQoAYbkm8k= github.com/gomodule/redigo v1.8.2 h1:H5XSIre1MB5NbPYFp+i1NBbb5qN1W8Y8YAQoAYbkm8k=
github.com/gomodule/redigo v1.8.2/go.mod h1:P9dn9mFrCBvWhGE1wpxx6fgq7BAeLBk+UUUzlpkBYO0= github.com/gomodule/redigo v1.8.2/go.mod h1:P9dn9mFrCBvWhGE1wpxx6fgq7BAeLBk+UUUzlpkBYO0=
github.com/gomodule/redigo/redis v0.0.0-do-not-use h1:J7XIp6Kau0WoyT4JtXHT3Ei0gA1KkSc6bc87j9v9WIo=
github.com/google/btree v0.0.0-20180813153112-4030bb1f1f0c/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ= github.com/google/btree v0.0.0-20180813153112-4030bb1f1f0c/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ=
github.com/google/btree v1.0.0 h1:0udJVsspx3VBr5FwtLhQQtuAsVc79tTq0ocGIPAU6qo= github.com/google/btree v1.0.0 h1:0udJVsspx3VBr5FwtLhQQtuAsVc79tTq0ocGIPAU6qo=
github.com/google/btree v1.0.0/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ= github.com/google/btree v1.0.0/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ=

View file

@ -101,25 +101,25 @@ func New(options ...Option) *Provider {
} }
// UserGroups returns the directory users in azure active directory. // UserGroups returns the directory users in azure active directory.
func (p *Provider) UserGroups(ctx context.Context) ([]*directory.User, error) { func (p *Provider) UserGroups(ctx context.Context) ([]*directory.Group, []*directory.User, error) {
if p.cfg.serviceAccount == nil { if p.cfg.serviceAccount == nil {
return nil, fmt.Errorf("azure: service account not defined") return nil, nil, fmt.Errorf("azure: service account not defined")
} }
groupIDs, err := p.listGroups(ctx) groups, err := p.listGroups(ctx)
if err != nil { if err != nil {
return nil, err return nil, nil, err
} }
userIDToGroupIDs := map[string][]string{} userIDToGroupIDs := map[string][]string{}
for groupID, groupName := range groupIDs { for _, group := range groups {
userIDs, err := p.listGroupMembers(ctx, groupID) userIDs, err := p.listGroupMembers(ctx, group.Id)
if err != nil { if err != nil {
return nil, err return nil, nil, err
} }
for _, userID := range userIDs { for _, userID := range userIDs {
userIDToGroupIDs[userID] = append(userIDToGroupIDs[userID], groupID, groupName) userIDToGroupIDs[userID] = append(userIDToGroupIDs[userID], group.Id)
} }
} }
@ -127,28 +127,28 @@ func (p *Provider) UserGroups(ctx context.Context) ([]*directory.User, error) {
for userID, groupIDs := range userIDToGroupIDs { for userID, groupIDs := range userIDToGroupIDs {
sort.Strings(groupIDs) sort.Strings(groupIDs)
users = append(users, &directory.User{ users = append(users, &directory.User{
Id: databroker.GetUserID(Name, userID), Id: databroker.GetUserID(Name, userID),
Groups: groupIDs, GroupIds: groupIDs,
}) })
} }
sort.Slice(users, func(i, j int) bool { sort.Slice(users, func(i, j int) bool {
return users[i].GetId() < users[j].GetId() return users[i].GetId() < users[j].GetId()
}) })
return users, nil return groups, users, nil
} }
// listGroups returns a map, with key is group ID, element is group name. // listGroups returns a map, with key is group ID, element is group name.
func (p *Provider) listGroups(ctx context.Context) (map[string]string, error) { func (p *Provider) listGroups(ctx context.Context) ([]*directory.Group, error) {
nextURL := p.cfg.graphURL.ResolveReference(&url.URL{ nextURL := p.cfg.graphURL.ResolveReference(&url.URL{
Path: "/v1.0/groups", Path: "/v1.0/groups",
}).String() }).String()
groups := make(map[string]string) var groups []*directory.Group
for nextURL != "" { for nextURL != "" {
var result struct { var result struct {
Value []struct { Value []struct {
ID string `json:"id"` ID string `json:"id"`
Name string `json:"name"` DisplayName string `json:"displayName"`
} `json:"value"` } `json:"value"`
NextLink string `json:"@odata.nextLink"` NextLink string `json:"@odata.nextLink"`
} }
@ -157,7 +157,10 @@ func (p *Provider) listGroups(ctx context.Context) (map[string]string, error) {
return nil, err return nil, err
} }
for _, v := range result.Value { for _, v := range result.Value {
groups[v.ID] = v.Name groups = append(groups, &directory.Group{
Id: v.ID,
Name: v.DisplayName,
})
} }
nextURL = result.NextLink nextURL = result.NextLink
} }

View file

@ -45,8 +45,8 @@ func newMockAPI(t *testing.T, srv *httptest.Server) http.Handler {
r.Get("/groups", func(w http.ResponseWriter, r *http.Request) { r.Get("/groups", func(w http.ResponseWriter, r *http.Request) {
_ = json.NewEncoder(w).Encode(M{ _ = json.NewEncoder(w).Encode(M{
"value": []M{ "value": []M{
{"id": "admin", "name": "Admin Group"}, {"id": "admin", "displayName": "Admin Group"},
{"id": "test", "name": "Test Group"}, {"id": "test", "displayName": "Test Group"},
}, },
}) })
}) })
@ -85,22 +85,26 @@ func Test(t *testing.T) {
DirectoryID: "DIRECTORY_ID", DirectoryID: "DIRECTORY_ID",
}), }),
) )
users, err := p.UserGroups(context.Background()) groups, users, err := p.UserGroups(context.Background())
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, []*directory.User{ assert.Equal(t, []*directory.User{
{ {
Id: "azure/user-1", Id: "azure/user-1",
Groups: []string{"Admin Group", "admin"}, GroupIds: []string{"admin"},
}, },
{ {
Id: "azure/user-2", Id: "azure/user-2",
Groups: []string{"Test Group", "test"}, GroupIds: []string{"test"},
}, },
{ {
Id: "azure/user-3", Id: "azure/user-3",
Groups: []string{"Test Group", "test"}, GroupIds: []string{"test"},
}, },
}, users) }, users)
assert.Equal(t, []*directory.Group{
{Id: "admin", Name: "Admin Group"},
{Id: "test", Name: "Test Group"},
}, groups)
} }
func mustParseURL(rawurl string) *url.URL { func mustParseURL(rawurl string) *url.URL {

View file

@ -86,49 +86,52 @@ func New(options ...Option) *Provider {
} }
// UserGroups gets the directory user groups for github. // UserGroups gets the directory user groups for github.
func (p *Provider) UserGroups(ctx context.Context) ([]*directory.User, error) { func (p *Provider) UserGroups(ctx context.Context) ([]*directory.Group, []*directory.User, error) {
if p.cfg.serviceAccount == nil { if p.cfg.serviceAccount == nil {
return nil, fmt.Errorf("github: service account not defined") return nil, nil, fmt.Errorf("github: service account not defined")
} }
orgSlugs, err := p.listOrgs(ctx) orgSlugs, err := p.listOrgs(ctx)
if err != nil { if err != nil {
return nil, err return nil, nil, err
} }
userLoginToGroups := map[string][]string{} userLoginToGroups := map[string][]string{}
var allGroups []*directory.Group
for _, orgSlug := range orgSlugs { for _, orgSlug := range orgSlugs {
teamSlugs, err := p.listTeams(ctx, orgSlug) groups, err := p.listGroups(ctx, orgSlug)
if err != nil { if err != nil {
return nil, err return nil, nil, err
} }
for teamSlug, teamID := range teamSlugs { for _, group := range groups {
userLogins, err := p.listTeamMembers(ctx, orgSlug, teamSlug) userLogins, err := p.listTeamMembers(ctx, orgSlug, group.Name)
if err != nil { if err != nil {
return nil, err return nil, nil, err
} }
for _, userLogin := range userLogins { for _, userLogin := range userLogins {
userLoginToGroups[userLogin] = append(userLoginToGroups[userLogin], teamSlug, strconv.Itoa(teamID)) userLoginToGroups[userLogin] = append(userLoginToGroups[userLogin], group.Id)
} }
} }
allGroups = append(allGroups, groups...)
} }
var users []*directory.User var users []*directory.User
for userLogin, groups := range userLoginToGroups { for userLogin, groups := range userLoginToGroups {
user := &directory.User{ user := &directory.User{
Id: databroker.GetUserID(Name, userLogin), Id: databroker.GetUserID(Name, userLogin),
Groups: groups, GroupIds: groups,
} }
sort.Strings(user.Groups) sort.Strings(user.GroupIds)
users = append(users, user) users = append(users, user)
} }
sort.Slice(users, func(i, j int) bool { sort.Slice(users, func(i, j int) bool {
return users[i].GetId() < users[j].GetId() return users[i].GetId() < users[j].GetId()
}) })
return users, nil return allGroups, users, nil
} }
func (p *Provider) listOrgs(ctx context.Context) (orgSlugs []string, err error) { func (p *Provider) listOrgs(ctx context.Context) (orgSlugs []string, err error) {
@ -155,12 +158,12 @@ func (p *Provider) listOrgs(ctx context.Context) (orgSlugs []string, err error)
return orgSlugs, nil return orgSlugs, nil
} }
func (p *Provider) listTeams(ctx context.Context, orgSlug string) (map[string]int, error) { func (p *Provider) listGroups(ctx context.Context, orgSlug string) ([]*directory.Group, error) {
nextURL := p.cfg.url.ResolveReference(&url.URL{ nextURL := p.cfg.url.ResolveReference(&url.URL{
Path: fmt.Sprintf("/orgs/%s/teams", orgSlug), Path: fmt.Sprintf("/orgs/%s/teams", orgSlug),
}).String() }).String()
teamSlugs := make(map[string]int) var groups []*directory.Group
for nextURL != "" { for nextURL != "" {
var results []struct { var results []struct {
ID int `json:"id"` ID int `json:"id"`
@ -172,13 +175,16 @@ func (p *Provider) listTeams(ctx context.Context, orgSlug string) (map[string]in
} }
for _, result := range results { for _, result := range results {
teamSlugs[result.Slug] = result.ID groups = append(groups, &directory.Group{
Id: strconv.Itoa(result.ID),
Name: result.Slug,
})
} }
nextURL = getNextLink(hdrs) nextURL = getNextLink(hdrs)
} }
return teamSlugs, nil return groups, nil
} }
func (p *Provider) listTeamMembers(ctx context.Context, orgSlug, teamSlug string) (userLogins []string, err error) { func (p *Provider) listTeamMembers(ctx context.Context, orgSlug, teamSlug string) (userLogins []string, err error) {

View file

@ -93,14 +93,20 @@ func Test(t *testing.T) {
PersonalAccessToken: "xyz", PersonalAccessToken: "xyz",
}), }),
) )
users, err := p.UserGroups(context.Background()) groups, users, err := p.UserGroups(context.Background())
assert.NoError(t, err) assert.NoError(t, err)
testutil.AssertProtoJSONEqual(t, `[ testutil.AssertProtoJSONEqual(t, `[
{ "id": "github/user1", "groups": ["1", "2", "3", "team1", "team2", "team3"] }, { "id": "github/user1", "groupIds": ["1", "2", "3"] },
{ "id": "github/user2", "groups": ["1", "3", "team1", "team3"] }, { "id": "github/user2", "groupIds": ["1", "3"] },
{ "id": "github/user3", "groups": ["3", "team3"] }, { "id": "github/user3", "groupIds": ["3"] },
{ "id": "github/user4", "groups": ["4", "team4"] } { "id": "github/user4", "groupIds": ["4"] }
]`, users) ]`, users)
testutil.AssertProtoJSONEqual(t, `[
{ "id": "1", "name": "team1" },
{ "id": "2", "name": "team2" },
{ "id": "3", "name": "team3" },
{ "id": "4", "name": "team4" }
]`, groups)
} }
func mustParseURL(rawurl string) *url.URL { func mustParseURL(rawurl string) *url.URL {

View file

@ -84,27 +84,27 @@ func New(options ...Option) *Provider {
} }
// UserGroups gets the directory user groups for gitlab. // UserGroups gets the directory user groups for gitlab.
func (p *Provider) UserGroups(ctx context.Context) ([]*directory.User, error) { func (p *Provider) UserGroups(ctx context.Context) ([]*directory.Group, []*directory.User, error) {
if p.cfg.serviceAccount == nil { if p.cfg.serviceAccount == nil {
return nil, fmt.Errorf("gitlab: service account not defined") return nil, nil, fmt.Errorf("gitlab: service account not defined")
} }
p.log.Info().Msg("getting user groups") p.log.Info().Msg("getting user groups")
groups, err := p.listGroups(ctx) groups, err := p.listGroups(ctx)
if err != nil { if err != nil {
return nil, err return nil, nil, err
} }
userIDToGroupIDs := map[int][]string{} userIDToGroupIDs := map[int][]string{}
for groupID, groupName := range groups { for _, group := range groups {
userIDs, err := p.listGroupMemberIDs(ctx, groupID) userIDs, err := p.listGroupMemberIDs(ctx, group.Id)
if err != nil { if err != nil {
return nil, err return nil, nil, err
} }
for _, userID := range userIDs { for _, userID := range userIDs {
userIDToGroupIDs[userID] = append(userIDToGroupIDs[userID], strconv.Itoa(groupID), groupName) userIDToGroupIDs[userID] = append(userIDToGroupIDs[userID], group.Id)
} }
} }
@ -114,23 +114,23 @@ func (p *Provider) UserGroups(ctx context.Context) ([]*directory.User, error) {
Id: databroker.GetUserID(Name, fmt.Sprint(userID)), Id: databroker.GetUserID(Name, fmt.Sprint(userID)),
} }
user.Groups = append(user.Groups, groups...) user.GroupIds = append(user.GroupIds, groups...)
sort.Strings(user.Groups) sort.Strings(user.GroupIds)
users = append(users, user) users = append(users, user)
} }
sort.Slice(users, func(i, j int) bool { sort.Slice(users, func(i, j int) bool {
return users[i].GetId() < users[j].GetId() return users[i].GetId() < users[j].GetId()
}) })
return users, nil return groups, users, nil
} }
// listGroups returns a map, with key is group ID, element is group name. // listGroups returns a map, with key is group ID, element is group name.
func (p *Provider) listGroups(ctx context.Context) (map[int]string, error) { func (p *Provider) listGroups(ctx context.Context) ([]*directory.Group, error) {
nextURL := p.cfg.url.ResolveReference(&url.URL{ nextURL := p.cfg.url.ResolveReference(&url.URL{
Path: "/api/v4/groups", Path: "/api/v4/groups",
}).String() }).String()
groups := make(map[int]string) var groups []*directory.Group
for nextURL != "" { for nextURL != "" {
var result []struct { var result []struct {
ID int `json:"id"` ID int `json:"id"`
@ -142,7 +142,10 @@ func (p *Provider) listGroups(ctx context.Context) (map[int]string, error) {
} }
for _, r := range result { for _, r := range result {
groups[r.ID] = r.Name groups = append(groups, &directory.Group{
Id: strconv.Itoa(r.ID),
Name: r.Name,
})
} }
nextURL = getNextLink(hdrs) nextURL = getNextLink(hdrs)
@ -150,9 +153,9 @@ func (p *Provider) listGroups(ctx context.Context) (map[int]string, error) {
return groups, nil return groups, nil
} }
func (p *Provider) listGroupMemberIDs(ctx context.Context, groupID int) (userIDs []int, err error) { func (p *Provider) listGroupMemberIDs(ctx context.Context, groupID string) (userIDs []int, err error) {
nextURL := p.cfg.url.ResolveReference(&url.URL{ nextURL := p.cfg.url.ResolveReference(&url.URL{
Path: fmt.Sprintf("/api/v4/groups/%d/members", groupID), Path: fmt.Sprintf("/api/v4/groups/%s/members", groupID),
}).String() }).String()
for nextURL != "" { for nextURL != "" {
var result []struct { var result []struct {

View file

@ -66,13 +66,17 @@ func Test(t *testing.T) {
PrivateToken: "PRIVATE_TOKEN", PrivateToken: "PRIVATE_TOKEN",
}), }),
) )
users, err := p.UserGroups(context.Background()) groups, users, err := p.UserGroups(context.Background())
assert.NoError(t, err) assert.NoError(t, err)
testutil.AssertProtoJSONEqual(t, `[ testutil.AssertProtoJSONEqual(t, `[
{ "id": "gitlab/11", "groups": ["1", "Group 1"] }, { "id": "gitlab/11", "groupIds": ["1"] },
{ "id": "gitlab/12", "groups": ["2", "Group 2"] }, { "id": "gitlab/12", "groupIds": ["2"] },
{ "id": "gitlab/13", "groups": ["2", "Group 2"] } { "id": "gitlab/13", "groupIds": ["2"] }
]`, users) ]`, users)
testutil.AssertProtoJSONEqual(t, `[
{ "id": "1", "name": "Group 1" },
{ "id": "2", "name": "Group 2" }
]`, groups)
} }
func mustParseURL(rawurl string) *url.URL { func mustParseURL(rawurl string) *url.URL {

View file

@ -82,17 +82,15 @@ func New(options ...Option) *Provider {
// NOTE: groups via Directory API is limited to 1 QPS! // NOTE: groups via Directory API is limited to 1 QPS!
// https://developers.google.com/admin-sdk/directory/v1/reference/groups/list // https://developers.google.com/admin-sdk/directory/v1/reference/groups/list
// https://developers.google.com/admin-sdk/directory/v1/limits // https://developers.google.com/admin-sdk/directory/v1/limits
func (p *Provider) UserGroups(ctx context.Context) ([]*directory.User, error) { func (p *Provider) UserGroups(ctx context.Context) ([]*directory.Group, []*directory.User, error) {
apiClient, err := p.getAPIClient(ctx) apiClient, err := p.getAPIClient(ctx)
if err != nil { if err != nil {
return nil, fmt.Errorf("google: error getting API client: %w", err) return nil, nil, fmt.Errorf("google: error getting API client: %w", err)
} }
groupIDToEmails := map[string]string{} var groups []*directory.Group
var groups []string
err = apiClient.Groups.List(). err = apiClient.Groups.List().
Context(ctx). Context(ctx).
Fields("id", "email", "directMembersCount").
Customer("my_customer"). Customer("my_customer").
Pages(ctx, func(res *admin.Groups) error { Pages(ctx, func(res *admin.Groups) error {
for _, g := range res.Groups { for _, g := range res.Groups {
@ -100,29 +98,31 @@ func (p *Provider) UserGroups(ctx context.Context) ([]*directory.User, error) {
if g.DirectMembersCount == 0 { if g.DirectMembersCount == 0 {
continue continue
} }
groups = append(groups, g.Id) groups = append(groups, &directory.Group{
groupIDToEmails[g.Id] = g.Email Id: g.Id,
Name: g.Email,
Email: g.Email,
})
} }
return nil return nil
}) })
if err != nil { if err != nil {
return nil, fmt.Errorf("google: error getting groups: %w", err) return nil, nil, fmt.Errorf("google: error getting groups: %w", err)
} }
userIDToGroups := map[string][]string{} userIDToGroups := map[string][]string{}
for _, group := range groups { for _, group := range groups {
group := group group := group
err = apiClient.Members.List(group). err = apiClient.Members.List(group.Id).
Context(ctx). Context(ctx).
Fields("id").
Pages(ctx, func(res *admin.Members) error { Pages(ctx, func(res *admin.Members) error {
for _, member := range res.Members { for _, member := range res.Members {
userIDToGroups[member.Id] = append(userIDToGroups[member.Id], group, groupIDToEmails[group]) userIDToGroups[member.Id] = append(userIDToGroups[member.Id], group.Id)
} }
return nil return nil
}) })
if err != nil { if err != nil {
return nil, fmt.Errorf("google: error getting group members: %w", err) return nil, nil, fmt.Errorf("google: error getting group members: %w", err)
} }
} }
@ -130,14 +130,14 @@ func (p *Provider) UserGroups(ctx context.Context) ([]*directory.User, error) {
for userID, groups := range userIDToGroups { for userID, groups := range userIDToGroups {
sort.Strings(groups) sort.Strings(groups)
users = append(users, &directory.User{ users = append(users, &directory.User{
Id: databroker.GetUserID(Name, userID), Id: databroker.GetUserID(Name, userID),
Groups: groups, GroupIds: groups,
}) })
} }
sort.Slice(users, func(i, j int) bool { sort.Slice(users, func(i, j int) bool {
return users[i].Id < users[j].Id return users[i].Id < users[j].Id
}) })
return users, nil return groups, users, nil
} }
func (p *Provider) getAPIClient(ctx context.Context) (*admin.Service, error) { func (p *Provider) getAPIClient(ctx context.Context) (*admin.Service, error) {

View file

@ -85,30 +85,30 @@ func New(options ...Option) *Provider {
// UserGroups fetches the groups of which the user is a member // UserGroups fetches the groups of which the user is a member
// https://developer.okta.com/docs/reference/api/users/#get-user-s-groups // https://developer.okta.com/docs/reference/api/users/#get-user-s-groups
func (p *Provider) UserGroups(ctx context.Context) ([]*directory.User, error) { func (p *Provider) UserGroups(ctx context.Context) ([]*directory.Group, []*directory.User, error) {
if p.cfg.serviceAccount == nil { if p.cfg.serviceAccount == nil {
return nil, fmt.Errorf("okta: service account not defined") return nil, nil, fmt.Errorf("okta: service account not defined")
} }
p.log.Info().Msg("getting user groups") p.log.Info().Msg("getting user groups")
if p.cfg.providerURL == nil { if p.cfg.providerURL == nil {
return nil, fmt.Errorf("okta: provider url not defined") return nil, nil, fmt.Errorf("okta: provider url not defined")
} }
groupIDToName, err := p.getGroups(ctx) groups, err := p.getGroups(ctx)
if err != nil { if err != nil {
return nil, err return nil, nil, err
} }
userIDToGroups := map[string][]string{} userIDToGroups := map[string][]string{}
for groupID, groupName := range groupIDToName { for _, group := range groups {
ids, err := p.getGroupMemberIDs(ctx, groupID) ids, err := p.getGroupMemberIDs(ctx, group.Id)
if err != nil { if err != nil {
return nil, err return nil, nil, err
} }
for _, id := range ids { for _, id := range ids {
userIDToGroups[id] = append(userIDToGroups[id], groupID, groupName) userIDToGroups[id] = append(userIDToGroups[id], group.Id)
} }
} }
@ -116,19 +116,18 @@ func (p *Provider) UserGroups(ctx context.Context) ([]*directory.User, error) {
for userID, groups := range userIDToGroups { for userID, groups := range userIDToGroups {
sort.Strings(groups) sort.Strings(groups)
users = append(users, &directory.User{ users = append(users, &directory.User{
Id: databroker.GetUserID(Name, userID), Id: databroker.GetUserID(Name, userID),
Groups: groups, GroupIds: groups,
}) })
} }
sort.Slice(users, func(i, j int) bool { sort.Slice(users, func(i, j int) bool {
return users[i].Id < users[j].Id return users[i].Id < users[j].Id
}) })
return users, nil return groups, users, nil
} }
func (p *Provider) getGroups(ctx context.Context) (map[string]string, error) { func (p *Provider) getGroups(ctx context.Context) ([]*directory.Group, error) {
groups := map[string]string{} var groups []*directory.Group
groupURL := p.cfg.providerURL.ResolveReference(&url.URL{ groupURL := p.cfg.providerURL.ResolveReference(&url.URL{
Path: "/api/v1/groups", Path: "/api/v1/groups",
RawQuery: fmt.Sprintf("limit=%d", p.cfg.batchSize), RawQuery: fmt.Sprintf("limit=%d", p.cfg.batchSize),
@ -146,12 +145,14 @@ func (p *Provider) getGroups(ctx context.Context) (map[string]string, error) {
} }
for _, el := range out { for _, el := range out {
groups[el.ID] = el.Profile.Name groups = append(groups, &directory.Group{
Id: el.ID,
Name: el.Profile.Name,
})
} }
groupURL = getNextLink(hdrs) groupURL = getNextLink(hdrs)
} }
return groups, nil return groups, nil
} }

View file

@ -14,6 +14,7 @@ import (
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/tomnomnom/linkheader" "github.com/tomnomnom/linkheader"
"github.com/pomerium/pomerium/internal/testutil"
"github.com/pomerium/pomerium/pkg/grpc/directory" "github.com/pomerium/pomerium/pkg/grpc/directory"
) )
@ -115,22 +116,27 @@ func TestProvider_UserGroups(t *testing.T) {
WithServiceAccount(&ServiceAccount{APIKey: "APITOKEN"}), WithServiceAccount(&ServiceAccount{APIKey: "APITOKEN"}),
WithProviderURL(mustParseURL(srv.URL)), WithProviderURL(mustParseURL(srv.URL)),
) )
users, err := p.UserGroups(context.Background()) groups, users, err := p.UserGroups(context.Background())
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, []*directory.User{ assert.Equal(t, []*directory.User{
{ {
Id: "okta/a@example.com", Id: "okta/a@example.com",
Groups: []string{"admin", "admin-name", "user", "user-name"}, GroupIds: []string{"admin", "user"},
}, },
{ {
Id: "okta/b@example.com", Id: "okta/b@example.com",
Groups: []string{"test", "test-name", "user", "user-name"}, GroupIds: []string{"test", "user"},
}, },
{ {
Id: "okta/c@example.com", Id: "okta/c@example.com",
Groups: []string{"user", "user-name"}, GroupIds: []string{"user"},
}, },
}, users) }, users)
testutil.AssertProtoJSONEqual(t, `[
{ "id": "admin", "name": "admin-name" },
{ "id": "test", "name": "test-name" },
{ "id": "user", "name": "user-name" }
]`, groups)
} }
func mustParseURL(rawurl string) *url.URL { func mustParseURL(rawurl string) *url.URL {

View file

@ -95,57 +95,45 @@ func New(options ...Option) *Provider {
} }
// UserGroups gets the directory user groups for onelogin. // UserGroups gets the directory user groups for onelogin.
func (p *Provider) UserGroups(ctx context.Context) ([]*directory.User, error) { func (p *Provider) UserGroups(ctx context.Context) ([]*directory.Group, []*directory.User, error) {
if p.cfg.serviceAccount == nil { if p.cfg.serviceAccount == nil {
return nil, fmt.Errorf("onelogin: service account not defined") return nil, nil, fmt.Errorf("onelogin: service account not defined")
} }
p.log.Info().Msg("getting user groups") p.log.Info().Msg("getting user groups")
token, err := p.getToken(ctx) token, err := p.getToken(ctx)
if err != nil { if err != nil {
return nil, err return nil, nil, err
} }
groupIDToName, err := p.getGroupIDToName(ctx, token) groups, err := p.listGroups(ctx, token)
if err != nil { if err != nil {
return nil, err return nil, nil, err
} }
userIDToGroupIDs, err := p.getUserIDToGroupIDs(ctx, token) userIDToGroupIDs, err := p.getUserIDToGroupIDs(ctx, token)
if err != nil { if err != nil {
return nil, err return nil, nil, err
}
userIDToGroupNames := map[int][]string{}
for userID, groupIDs := range userIDToGroupIDs {
for _, groupID := range groupIDs {
if groupName, ok := groupIDToName[groupID]; ok {
userIDToGroupNames[userID] = append(userIDToGroupNames[userID], strconv.Itoa(groupID), groupName)
} else {
userIDToGroupNames[userID] = append(userIDToGroupNames[userID], "NOGROUP")
}
}
} }
var users []*directory.User var users []*directory.User
for userID, groups := range userIDToGroupNames { for userID, groupIDs := range userIDToGroupIDs {
sort.Strings(groups) sort.Strings(groupIDs)
users = append(users, &directory.User{ users = append(users, &directory.User{
Id: databroker.GetUserID(Name, strconv.Itoa(userID)), Id: databroker.GetUserID(Name, strconv.Itoa(userID)),
Groups: groups, GroupIds: groupIDs,
}) })
} }
sort.Slice(users, func(i, j int) bool { sort.Slice(users, func(i, j int) bool {
return users[i].Id < users[j].Id return users[i].Id < users[j].Id
}) })
return users, nil return groups, users, nil
} }
func (p *Provider) getGroupIDToName(ctx context.Context, token *oauth2.Token) (map[int]string, error) { func (p *Provider) listGroups(ctx context.Context, token *oauth2.Token) ([]*directory.Group, error) {
groupIDToName := map[int]string{} var groups []*directory.Group
apiURL := p.cfg.apiURL.ResolveReference(&url.URL{ apiURL := p.cfg.apiURL.ResolveReference(&url.URL{
Path: "/api/1/groups", Path: "/api/1/groups",
RawQuery: fmt.Sprintf("limit=%d", p.cfg.batchSize), RawQuery: fmt.Sprintf("limit=%d", p.cfg.batchSize),
@ -161,17 +149,19 @@ func (p *Provider) getGroupIDToName(ctx context.Context, token *oauth2.Token) (m
} }
for _, r := range result { for _, r := range result {
groupIDToName[r.ID] = r.Name groups = append(groups, &directory.Group{
Id: strconv.Itoa(r.ID),
Name: r.Name,
})
} }
apiURL = nextLink apiURL = nextLink
} }
return groups, nil
return groupIDToName, nil
} }
func (p *Provider) getUserIDToGroupIDs(ctx context.Context, token *oauth2.Token) (map[int][]int, error) { func (p *Provider) getUserIDToGroupIDs(ctx context.Context, token *oauth2.Token) (map[int][]string, error) {
userIDToGroupIDs := map[int][]int{} userIDToGroupIDs := map[int][]string{}
apiURL := p.cfg.apiURL.ResolveReference(&url.URL{ apiURL := p.cfg.apiURL.ResolveReference(&url.URL{
Path: "/api/1/users", Path: "/api/1/users",
@ -192,7 +182,7 @@ func (p *Provider) getUserIDToGroupIDs(ctx context.Context, token *oauth2.Token)
if r.GroupID != nil { if r.GroupID != nil {
groupID = *r.GroupID groupID = *r.GroupID
} }
userIDToGroupIDs[r.ID] = append(userIDToGroupIDs[r.ID], groupID) userIDToGroupIDs[r.ID] = append(userIDToGroupIDs[r.ID], strconv.Itoa(groupID))
} }
apiURL = nextLink apiURL = nextLink

View file

@ -15,7 +15,7 @@ import (
"github.com/go-chi/chi/middleware" "github.com/go-chi/chi/middleware"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/pomerium/pomerium/pkg/grpc/directory" "github.com/pomerium/pomerium/internal/testutil"
) )
type M = map[string]interface{} type M = map[string]interface{}
@ -147,22 +147,18 @@ func TestProvider_UserGroups(t *testing.T) {
}), }),
WithURL(mustParseURL(srv.URL)), WithURL(mustParseURL(srv.URL)),
) )
users, err := p.UserGroups(context.Background()) groups, users, err := p.UserGroups(context.Background())
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, []*directory.User{ testutil.AssertProtoJSONEqual(t, `[
{ { "id": "onelogin/111", "groupIds": ["0"] },
Id: "onelogin/111", { "id": "onelogin/222", "groupIds": ["1"] },
Groups: []string{"0", "admin"}, { "id": "onelogin/333", "groupIds": ["2"] }
}, ]`, users)
{ testutil.AssertProtoJSONEqual(t, `[
Id: "onelogin/222", { "id": "0", "name": "admin" },
Groups: []string{"1", "test"}, { "id": "1", "name": "test" },
}, { "id": "2", "name": "user" }
{ ]`, groups)
Id: "onelogin/333",
Groups: []string{"2", "user"},
},
}, users)
} }
func mustParseURL(rawurl string) *url.URL { func mustParseURL(rawurl string) *url.URL {

View file

@ -16,12 +16,15 @@ import (
"github.com/pomerium/pomerium/pkg/grpc/directory" "github.com/pomerium/pomerium/pkg/grpc/directory"
) )
// A Group is a directory Group.
type Group = directory.Group
// A User is a directory User. // A User is a directory User.
type User = directory.User type User = directory.User
// A Provider provides user group directory information. // A Provider provides user group directory information.
type Provider interface { type Provider interface {
UserGroups(ctx context.Context) ([]*User, error) UserGroups(ctx context.Context) ([]*Group, []*User, error)
} }
// GetProvider gets the provider for the given options. // GetProvider gets the provider for the given options.
@ -101,6 +104,6 @@ func GetProvider(options *config.Options) Provider {
type nullProvider struct{} type nullProvider struct{}
func (nullProvider) UserGroups(ctx context.Context) ([]*User, error) { func (nullProvider) UserGroups(ctx context.Context) ([]*Group, []*User, error) {
return nil, nil return nil, nil, nil
} }

View file

@ -87,7 +87,7 @@
disabled disabled
/> />
</label> </label>
{{end}} {{range $i,$_:= .DirectoryUser.Groups}} {{end}} {{range $i,$_:= .DirectoryGroups}}
<label> <label>
{{if eq $i 0}} {{if eq $i 0}}
<span>Group</span> <span>Group</span>
@ -97,8 +97,8 @@
<input <input
type="text" type="text"
class="field" class="field"
value="{{.}}" value="{{.Name}}"
title="{{.}}" title="{{.Id}}"
disabled disabled
/> />
</label> </label>

File diff suppressed because one or more lines are too long

View file

@ -48,7 +48,12 @@ type Manager struct {
directoryUsers map[string]*directory.User directoryUsers map[string]*directory.User
directoryUsersServerVersion string directoryUsersServerVersion string
directoryUsersRecordVersion string directoryUsersRecordVersion string
directoryUsersNextRefresh time.Time
directoryGroups map[string]*directory.Group
directoryGroupsServerVersion string
directoryGroupsRecordVersion string
directoryNextRefresh time.Time
} }
// New creates a new identity manager. // New creates a new identity manager.
@ -83,7 +88,12 @@ func New(
// Run runs the manager. This method blocks until an error occurs or the given context is canceled. // Run runs the manager. This method blocks until an error occurs or the given context is canceled.
func (mgr *Manager) Run(ctx context.Context) error { func (mgr *Manager) Run(ctx context.Context) error {
err := mgr.initDirectoryUsers(ctx) err := mgr.initDirectoryGroups(ctx)
if err != nil {
return fmt.Errorf("failed to initialize directory groups: %w", err)
}
err = mgr.initDirectoryUsers(ctx)
if err != nil { if err != nil {
return fmt.Errorf("failed to initialize directory users: %w", err) return fmt.Errorf("failed to initialize directory users: %w", err)
} }
@ -100,13 +110,18 @@ func (mgr *Manager) Run(ctx context.Context) error {
return mgr.syncUsers(ctx, updatedUser) return mgr.syncUsers(ctx, updatedUser)
}) })
updatedDirectoryGroup := make(chan *directory.Group, 1)
t.Go(func() error {
return mgr.syncDirectoryGroups(ctx, updatedDirectoryGroup)
})
updatedDirectoryUser := make(chan *directory.User, 1) updatedDirectoryUser := make(chan *directory.User, 1)
t.Go(func() error { t.Go(func() error {
return mgr.syncDirectoryUsers(ctx, updatedDirectoryUser) return mgr.syncDirectoryUsers(ctx, updatedDirectoryUser)
}) })
t.Go(func() error { t.Go(func() error {
return mgr.refreshLoop(ctx, updatedSession, updatedUser, updatedDirectoryUser) return mgr.refreshLoop(ctx, updatedSession, updatedUser, updatedDirectoryUser, updatedDirectoryGroup)
}) })
return t.Wait() return t.Wait()
@ -117,11 +132,12 @@ func (mgr *Manager) refreshLoop(
updatedSession <-chan *session.Session, updatedSession <-chan *session.Session,
updatedUser <-chan *user.User, updatedUser <-chan *user.User,
updatedDirectoryUser <-chan *directory.User, updatedDirectoryUser <-chan *directory.User,
updatedDirectoryGroup <-chan *directory.Group,
) error { ) error {
maxWait := time.Minute * 10 maxWait := time.Minute * 10
nextTime := time.Now().Add(maxWait) nextTime := time.Now().Add(maxWait)
if mgr.directoryUsersNextRefresh.Before(nextTime) { if mgr.directoryNextRefresh.Before(nextTime) {
nextTime = mgr.directoryUsersNextRefresh nextTime = mgr.directoryNextRefresh
} }
timer := time.NewTimer(time.Until(nextTime)) timer := time.NewTimer(time.Until(nextTime))
@ -137,6 +153,8 @@ func (mgr *Manager) refreshLoop(
mgr.onUpdateUser(ctx, u) mgr.onUpdateUser(ctx, u)
case du := <-updatedDirectoryUser: case du := <-updatedDirectoryUser:
mgr.onUpdateDirectoryUser(ctx, du) mgr.onUpdateDirectoryUser(ctx, du)
case dg := <-updatedDirectoryGroup:
mgr.onUpdateDirectoryGroup(ctx, dg)
case <-timer.C: case <-timer.C:
} }
@ -144,11 +162,11 @@ func (mgr *Manager) refreshLoop(
nextTime := now.Add(maxWait) nextTime := now.Add(maxWait)
// refresh groups // refresh groups
if mgr.directoryUsersNextRefresh.Before(now) { if mgr.directoryNextRefresh.Before(now) {
mgr.refreshDirectoryUsers(ctx) mgr.refreshDirectoryUserGroups(ctx)
mgr.directoryUsersNextRefresh = now.Add(mgr.cfg.groupRefreshInterval) mgr.directoryNextRefresh = now.Add(mgr.cfg.groupRefreshInterval)
if mgr.directoryUsersNextRefresh.Before(nextTime) { if mgr.directoryNextRefresh.Before(nextTime) {
nextTime = mgr.directoryUsersNextRefresh nextTime = mgr.directoryNextRefresh
} }
} }
@ -185,18 +203,69 @@ func (mgr *Manager) refreshLoop(
} }
} }
func (mgr *Manager) refreshDirectoryUsers(ctx context.Context) { func (mgr *Manager) refreshDirectoryUserGroups(ctx context.Context) {
mgr.log.Info().Msg("refreshing directory users") mgr.log.Info().Msg("refreshing directory users")
ctx, clearTimeout := context.WithTimeout(ctx, mgr.cfg.groupRefreshTimeout) ctx, clearTimeout := context.WithTimeout(ctx, mgr.cfg.groupRefreshTimeout)
defer clearTimeout() defer clearTimeout()
directoryUsers, err := mgr.directory.UserGroups(ctx) directoryGroups, directoryUsers, err := mgr.directory.UserGroups(ctx)
if err != nil { if err != nil {
mgr.log.Warn().Err(err).Msg("failed to refresh directory users and groups") mgr.log.Warn().Err(err).Msg("failed to refresh directory users and groups")
return return
} }
mgr.mergeGroups(ctx, directoryGroups)
mgr.mergeUsers(ctx, directoryUsers)
}
func (mgr *Manager) mergeGroups(ctx context.Context, directoryGroups []*directory.Group) {
lookup := map[string]*directory.Group{}
for _, dg := range directoryGroups {
lookup[dg.GetId()] = dg
}
for groupID, newDG := range lookup {
curDG, ok := mgr.directoryGroups[groupID]
if !ok || !proto.Equal(newDG, curDG) {
any, err := ptypes.MarshalAny(newDG)
if err != nil {
mgr.log.Warn().Err(err).Msg("failed to marshal directory group")
return
}
_, err = mgr.dataBrokerClient.Set(ctx, &databroker.SetRequest{
Type: any.GetTypeUrl(),
Id: newDG.GetId(),
Data: any,
})
if err != nil {
mgr.log.Warn().Err(err).Msg("failed to update directory group")
return
}
}
}
for groupID, curDG := range mgr.directoryGroups {
_, ok := lookup[groupID]
if !ok {
any, err := ptypes.MarshalAny(curDG)
if err != nil {
mgr.log.Warn().Err(err).Msg("failed to marshal directory group")
return
}
_, err = mgr.dataBrokerClient.Delete(ctx, &databroker.DeleteRequest{
Type: any.GetTypeUrl(),
Id: curDG.GetId(),
})
if err != nil {
mgr.log.Warn().Err(err).Msg("failed to delete directory group")
return
}
}
}
}
func (mgr *Manager) mergeUsers(ctx context.Context, directoryUsers []*directory.User) {
lookup := map[string]*directory.User{} lookup := map[string]*directory.User{}
for _, du := range directoryUsers { for _, du := range directoryUsers {
lookup[du.GetId()] = du lookup[du.GetId()] = du
@ -497,6 +566,75 @@ func (mgr *Manager) syncDirectoryUsers(ctx context.Context, ch chan<- *directory
} }
} }
func (mgr *Manager) initDirectoryGroups(ctx context.Context) error {
mgr.log.Info().Msg("initializing directory groups")
any, err := ptypes.MarshalAny(new(directory.Group))
if err != nil {
return err
}
res, err := mgr.dataBrokerClient.GetAll(ctx, &databroker.GetAllRequest{
Type: any.GetTypeUrl(),
})
if err != nil {
return fmt.Errorf("error getting all directory groups: %w", err)
}
mgr.directoryGroups = map[string]*directory.Group{}
for _, record := range res.GetRecords() {
var pbDirectoryGroup directory.Group
err := ptypes.UnmarshalAny(record.GetData(), &pbDirectoryGroup)
if err != nil {
return fmt.Errorf("error unmarshaling directory group: %w", err)
}
mgr.directoryGroups[pbDirectoryGroup.GetId()] = &pbDirectoryGroup
}
mgr.directoryGroupsRecordVersion = res.GetRecordVersion()
mgr.directoryGroupsServerVersion = res.GetServerVersion()
return nil
}
func (mgr *Manager) syncDirectoryGroups(ctx context.Context, ch chan<- *directory.Group) error {
mgr.log.Info().Msg("syncing directory groups")
any, err := ptypes.MarshalAny(new(directory.Group))
if err != nil {
return err
}
client, err := mgr.dataBrokerClient.Sync(ctx, &databroker.SyncRequest{
Type: any.GetTypeUrl(),
ServerVersion: mgr.directoryGroupsServerVersion,
RecordVersion: mgr.directoryGroupsRecordVersion,
})
if err != nil {
return fmt.Errorf("error syncing directory groups: %w", err)
}
for {
res, err := client.Recv()
if err != nil {
return fmt.Errorf("error receiving directory groups: %w", err)
}
for _, record := range res.GetRecords() {
var pbDirectoryGroup directory.Group
err := ptypes.UnmarshalAny(record.GetData(), &pbDirectoryGroup)
if err != nil {
return fmt.Errorf("error unmarshaling directory group: %w", err)
}
select {
case <-ctx.Done():
return ctx.Err()
case ch <- &pbDirectoryGroup:
}
}
}
}
func (mgr *Manager) onUpdateSession(ctx context.Context, pbSession *session.Session) { func (mgr *Manager) onUpdateSession(ctx context.Context, pbSession *session.Session) {
mgr.sessionScheduler.Remove(toSessionSchedulerKey(pbSession.GetUserId(), pbSession.GetId())) mgr.sessionScheduler.Remove(toSessionSchedulerKey(pbSession.GetUserId(), pbSession.GetId()))
@ -543,6 +681,10 @@ func (mgr *Manager) onUpdateDirectoryUser(_ context.Context, pbDirectoryUser *di
mgr.directoryUsers[pbDirectoryUser.GetId()] = pbDirectoryUser mgr.directoryUsers[pbDirectoryUser.GetId()] = pbDirectoryUser
} }
func (mgr *Manager) onUpdateDirectoryGroup(_ context.Context, pbDirectoryGroup *directory.Group) {
mgr.directoryGroups[pbDirectoryGroup.GetId()] = pbDirectoryGroup
}
func (mgr *Manager) createUser(ctx context.Context, pbSession *session.Session) { func (mgr *Manager) createUser(ctx context.Context, pbSession *session.Session) {
u := User{ u := User{
User: &user.User{ User: &user.User{

View file

@ -9,8 +9,28 @@ import (
"github.com/pomerium/pomerium/pkg/grpc/databroker" "github.com/pomerium/pomerium/pkg/grpc/databroker"
) )
// Get gets a directory user from the databroker. // GetGroup gets a directory group from the databroker.
func Get(ctx context.Context, client databroker.DataBrokerServiceClient, userID string) (*User, error) { func GetGroup(ctx context.Context, client databroker.DataBrokerServiceClient, groupID string) (*Group, error) {
any, _ := ptypes.MarshalAny(new(Group))
res, err := client.Get(ctx, &databroker.GetRequest{
Type: any.GetTypeUrl(),
Id: groupID,
})
if err != nil {
return nil, err
}
var g Group
err = ptypes.UnmarshalAny(res.GetRecord().GetData(), &g)
if err != nil {
return nil, err
}
return &g, nil
}
// GetUser gets a directory user from the databroker.
func GetUser(ctx context.Context, client databroker.DataBrokerServiceClient, userID string) (*User, error) {
any, _ := ptypes.MarshalAny(new(User)) any, _ := ptypes.MarshalAny(new(User))
res, err := client.Get(ctx, &databroker.GetRequest{ res, err := client.Get(ctx, &databroker.GetRequest{

View file

@ -30,9 +30,9 @@ type User struct {
sizeCache protoimpl.SizeCache sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields unknownFields protoimpl.UnknownFields
Version string `protobuf:"bytes,1,opt,name=version,proto3" json:"version,omitempty"` Version string `protobuf:"bytes,1,opt,name=version,proto3" json:"version,omitempty"`
Id string `protobuf:"bytes,2,opt,name=id,proto3" json:"id,omitempty"` Id string `protobuf:"bytes,2,opt,name=id,proto3" json:"id,omitempty"`
Groups []string `protobuf:"bytes,3,rep,name=groups,proto3" json:"groups,omitempty"` GroupIds []string `protobuf:"bytes,3,rep,name=group_ids,json=groupIds,proto3" json:"group_ids,omitempty"`
} }
func (x *User) Reset() { func (x *User) Reset() {
@ -81,27 +81,104 @@ func (x *User) GetId() string {
return "" return ""
} }
func (x *User) GetGroups() []string { func (x *User) GetGroupIds() []string {
if x != nil { if x != nil {
return x.Groups return x.GroupIds
} }
return nil return nil
} }
type Group struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
Version string `protobuf:"bytes,1,opt,name=version,proto3" json:"version,omitempty"`
Id string `protobuf:"bytes,2,opt,name=id,proto3" json:"id,omitempty"`
Name string `protobuf:"bytes,3,opt,name=name,proto3" json:"name,omitempty"`
Email string `protobuf:"bytes,4,opt,name=email,proto3" json:"email,omitempty"`
}
func (x *Group) Reset() {
*x = Group{}
if protoimpl.UnsafeEnabled {
mi := &file_directory_proto_msgTypes[1]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
}
func (x *Group) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*Group) ProtoMessage() {}
func (x *Group) ProtoReflect() protoreflect.Message {
mi := &file_directory_proto_msgTypes[1]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use Group.ProtoReflect.Descriptor instead.
func (*Group) Descriptor() ([]byte, []int) {
return file_directory_proto_rawDescGZIP(), []int{1}
}
func (x *Group) GetVersion() string {
if x != nil {
return x.Version
}
return ""
}
func (x *Group) GetId() string {
if x != nil {
return x.Id
}
return ""
}
func (x *Group) GetName() string {
if x != nil {
return x.Name
}
return ""
}
func (x *Group) GetEmail() string {
if x != nil {
return x.Email
}
return ""
}
var File_directory_proto protoreflect.FileDescriptor var File_directory_proto protoreflect.FileDescriptor
var file_directory_proto_rawDesc = []byte{ var file_directory_proto_rawDesc = []byte{
0x0a, 0x0f, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, 0x6f, 0x72, 0x79, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x0a, 0x0f, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, 0x6f, 0x72, 0x79, 0x2e, 0x70, 0x72, 0x6f, 0x74,
0x6f, 0x12, 0x09, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, 0x6f, 0x72, 0x79, 0x22, 0x48, 0x0a, 0x04, 0x6f, 0x12, 0x09, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, 0x6f, 0x72, 0x79, 0x22, 0x4d, 0x0a, 0x04,
0x55, 0x73, 0x65, 0x72, 0x12, 0x18, 0x0a, 0x07, 0x76, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x18, 0x55, 0x73, 0x65, 0x72, 0x12, 0x18, 0x0a, 0x07, 0x76, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x18,
0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x76, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x12, 0x0e, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x76, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x12, 0x0e,
0x0a, 0x02, 0x69, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x02, 0x69, 0x64, 0x12, 0x16, 0x0a, 0x02, 0x69, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x02, 0x69, 0x64, 0x12, 0x1b,
0x0a, 0x06, 0x67, 0x72, 0x6f, 0x75, 0x70, 0x73, 0x18, 0x03, 0x20, 0x03, 0x28, 0x09, 0x52, 0x06, 0x0a, 0x09, 0x67, 0x72, 0x6f, 0x75, 0x70, 0x5f, 0x69, 0x64, 0x73, 0x18, 0x03, 0x20, 0x03, 0x28,
0x67, 0x72, 0x6f, 0x75, 0x70, 0x73, 0x42, 0x31, 0x5a, 0x2f, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, 0x09, 0x52, 0x08, 0x67, 0x72, 0x6f, 0x75, 0x70, 0x49, 0x64, 0x73, 0x22, 0x5b, 0x0a, 0x05, 0x47,
0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x70, 0x6f, 0x6d, 0x65, 0x72, 0x69, 0x75, 0x6d, 0x2f, 0x70, 0x6f, 0x72, 0x6f, 0x75, 0x70, 0x12, 0x18, 0x0a, 0x07, 0x76, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x18,
0x6d, 0x65, 0x72, 0x69, 0x75, 0x6d, 0x2f, 0x70, 0x6b, 0x67, 0x2f, 0x67, 0x72, 0x70, 0x63, 0x2f, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x76, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x12, 0x0e,
0x64, 0x69, 0x72, 0x65, 0x63, 0x74, 0x6f, 0x72, 0x79, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x0a, 0x02, 0x69, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x02, 0x69, 0x64, 0x12, 0x12,
0x33, 0x0a, 0x04, 0x6e, 0x61, 0x6d, 0x65, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x6e, 0x61,
0x6d, 0x65, 0x12, 0x14, 0x0a, 0x05, 0x65, 0x6d, 0x61, 0x69, 0x6c, 0x18, 0x04, 0x20, 0x01, 0x28,
0x09, 0x52, 0x05, 0x65, 0x6d, 0x61, 0x69, 0x6c, 0x42, 0x31, 0x5a, 0x2f, 0x67, 0x69, 0x74, 0x68,
0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x70, 0x6f, 0x6d, 0x65, 0x72, 0x69, 0x75, 0x6d, 0x2f,
0x70, 0x6f, 0x6d, 0x65, 0x72, 0x69, 0x75, 0x6d, 0x2f, 0x70, 0x6b, 0x67, 0x2f, 0x67, 0x72, 0x70,
0x63, 0x2f, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, 0x6f, 0x72, 0x79, 0x62, 0x06, 0x70, 0x72, 0x6f,
0x74, 0x6f, 0x33,
} }
var ( var (
@ -116,9 +193,10 @@ func file_directory_proto_rawDescGZIP() []byte {
return file_directory_proto_rawDescData return file_directory_proto_rawDescData
} }
var file_directory_proto_msgTypes = make([]protoimpl.MessageInfo, 1) var file_directory_proto_msgTypes = make([]protoimpl.MessageInfo, 2)
var file_directory_proto_goTypes = []interface{}{ var file_directory_proto_goTypes = []interface{}{
(*User)(nil), // 0: directory.User (*User)(nil), // 0: directory.User
(*Group)(nil), // 1: directory.Group
} }
var file_directory_proto_depIdxs = []int32{ var file_directory_proto_depIdxs = []int32{
0, // [0:0] is the sub-list for method output_type 0, // [0:0] is the sub-list for method output_type
@ -146,6 +224,18 @@ func file_directory_proto_init() {
return nil return nil
} }
} }
file_directory_proto_msgTypes[1].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*Group); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
} }
type x struct{} type x struct{}
out := protoimpl.TypeBuilder{ out := protoimpl.TypeBuilder{
@ -153,7 +243,7 @@ func file_directory_proto_init() {
GoPackagePath: reflect.TypeOf(x{}).PkgPath(), GoPackagePath: reflect.TypeOf(x{}).PkgPath(),
RawDescriptor: file_directory_proto_rawDesc, RawDescriptor: file_directory_proto_rawDesc,
NumEnums: 0, NumEnums: 0,
NumMessages: 1, NumMessages: 2,
NumExtensions: 0, NumExtensions: 0,
NumServices: 0, NumServices: 0,
}, },

View file

@ -6,5 +6,12 @@ option go_package = "github.com/pomerium/pomerium/pkg/grpc/directory";
message User { message User {
string version = 1; string version = 1;
string id = 2; string id = 2;
repeated string groups = 3; repeated string group_ids = 3;
}
message Group {
string version = 1;
string id = 2;
string name = 3;
string email = 4;
} }