mirror of
https://github.com/pomerium/pomerium.git
synced 2025-05-17 02:57:11 +02:00
directory.Group entry for groups (#1118)
* store directory groups separate from directory users * fix group lookup, azure display name * remove fields restriction * fix test * also support email * use Email as name for google' * remove changed file * show groups on dashboard * fix test * re-add accidentally removed code
This commit is contained in:
parent
489cdd8b63
commit
1ad243dfd1
25 changed files with 525 additions and 209 deletions
|
@ -484,17 +484,30 @@ func (a *Authenticate) Dashboard(w http.ResponseWriter, r *http.Request) error {
|
|||
Id: pbSession.GetUserId(),
|
||||
}
|
||||
}
|
||||
pbDirectoryUser, err := directory.Get(r.Context(), a.dataBrokerClient, pbSession.GetUserId())
|
||||
pbDirectoryUser, err := directory.GetUser(r.Context(), a.dataBrokerClient, pbSession.GetUserId())
|
||||
if err != nil {
|
||||
pbDirectoryUser = &directory.User{
|
||||
Id: pbSession.GetUserId(),
|
||||
}
|
||||
}
|
||||
var groups []*directory.Group
|
||||
for _, groupID := range pbDirectoryUser.GetGroupIds() {
|
||||
pbDirectoryGroup, err := directory.GetGroup(r.Context(), a.dataBrokerClient, groupID)
|
||||
if err != nil {
|
||||
pbDirectoryGroup = &directory.Group{
|
||||
Id: groupID,
|
||||
Name: groupID,
|
||||
Email: groupID,
|
||||
}
|
||||
}
|
||||
groups = append(groups, pbDirectoryGroup)
|
||||
}
|
||||
|
||||
input := map[string]interface{}{
|
||||
"State": s,
|
||||
"Session": pbSession,
|
||||
"User": pbUser,
|
||||
"DirectoryGroups": groups,
|
||||
"DirectoryUser": pbDirectoryUser,
|
||||
"csrfField": csrf.TemplateField(r),
|
||||
"ImpersonateAction": urlutil.QueryImpersonateAction,
|
||||
|
|
|
@ -33,6 +33,7 @@ const (
|
|||
sessionTypeURL = "type.googleapis.com/session.Session"
|
||||
userTypeURL = "type.googleapis.com/user.User"
|
||||
directoryUserTypeURL = "type.googleapis.com/directory.User"
|
||||
directoryGroupTypeURL = "type.googleapis.com/directory.Group"
|
||||
)
|
||||
|
||||
// Evaluator specifies the interface for a policy engine.
|
||||
|
@ -217,7 +218,16 @@ func (e *Evaluator) JWTPayload(req *Request) map[string]interface{} {
|
|||
payload["email"] = u.GetEmail()
|
||||
}
|
||||
if du, ok := req.DataBrokerData.Get("type.googleapis.com/directory.User", s.GetUserId()).(*directory.User); ok {
|
||||
payload["groups"] = du.GetGroups()
|
||||
var groupNames []string
|
||||
for _, groupID := range du.GetGroupIds() {
|
||||
if dg, ok := req.DataBrokerData.Get("type.googleapis.com/directory.Group", groupID).(*directory.Group); ok {
|
||||
groupNames = append(groupNames, dg.Name)
|
||||
}
|
||||
}
|
||||
var groups []string
|
||||
groups = append(groups, du.GetGroupIds()...)
|
||||
groups = append(groups, groupNames...)
|
||||
payload["groups"] = groups
|
||||
}
|
||||
}
|
||||
return payload
|
||||
|
@ -257,7 +267,7 @@ type input struct {
|
|||
type dataBrokerDataInput struct {
|
||||
Session interface{} `json:"session,omitempty"`
|
||||
User interface{} `json:"user,omitempty"`
|
||||
DirectoryUser interface{} `json:"directory_user,omitempty"`
|
||||
Groups interface{} `json:"groups,omitempty"`
|
||||
}
|
||||
|
||||
func (e *Evaluator) newInput(req *Request, isValidClientCertificate bool) *input {
|
||||
|
@ -265,7 +275,23 @@ func (e *Evaluator) newInput(req *Request, isValidClientCertificate bool) *input
|
|||
i.DataBrokerData.Session = req.DataBrokerData.Get(sessionTypeURL, req.Session.ID)
|
||||
if obj, ok := i.DataBrokerData.Session.(interface{ GetUserId() string }); ok {
|
||||
i.DataBrokerData.User = req.DataBrokerData.Get(userTypeURL, obj.GetUserId())
|
||||
i.DataBrokerData.DirectoryUser = req.DataBrokerData.Get(directoryUserTypeURL, obj.GetUserId())
|
||||
|
||||
user, ok := req.DataBrokerData.Get(directoryUserTypeURL, obj.GetUserId()).(*directory.User)
|
||||
if ok {
|
||||
var groups []string
|
||||
for _, groupID := range user.GetGroupIds() {
|
||||
if dg, ok := req.DataBrokerData.Get(directoryGroupTypeURL, groupID).(*directory.Group); ok {
|
||||
if dg.Name != "" {
|
||||
groups = append(groups, dg.Name)
|
||||
}
|
||||
if dg.Email != "" {
|
||||
groups = append(groups, dg.Email)
|
||||
}
|
||||
}
|
||||
}
|
||||
groups = append(groups, user.GetGroupIds()...)
|
||||
i.DataBrokerData.Groups = groups
|
||||
}
|
||||
}
|
||||
i.HTTP = req.HTTP
|
||||
i.Session = req.Session
|
||||
|
|
|
@ -26,7 +26,7 @@ func TestJSONMarshal(t *testing.T) {
|
|||
"type.googleapis.com/directory.User": map[string]interface{}{
|
||||
"user1": &directory.User{
|
||||
Id: "user1",
|
||||
Groups: []string{"group1", "group2"},
|
||||
GroupIds: []string{"group1", "group2"},
|
||||
},
|
||||
},
|
||||
"type.googleapis.com/session.Session": map[string]interface{}{},
|
||||
|
|
|
@ -7,7 +7,7 @@ route_policy_idx := first_allowed_route_policy_idx(input.http.url)
|
|||
route_policy := data.route_policies[route_policy_idx]
|
||||
session := input.databroker_data.session
|
||||
user := input.databroker_data.user
|
||||
directory_user := input.databroker_data.directory_user
|
||||
groups := input.databroker_data.groups
|
||||
|
||||
all_allowed_domains := get_allowed_domains(route_policy)
|
||||
all_allowed_groups := get_allowed_groups(route_policy)
|
||||
|
@ -35,7 +35,7 @@ allow {
|
|||
# allow group
|
||||
allow {
|
||||
some group
|
||||
directory_user.groups[_] = group
|
||||
groups[_] = group
|
||||
all_allowed_groups[_] = group
|
||||
input.session.impersonate_groups == null
|
||||
}
|
||||
|
|
|
@ -67,9 +67,7 @@ test_group_allowed {
|
|||
"user": {
|
||||
"email": "x@example.com",
|
||||
},
|
||||
"directory_user": {
|
||||
"groups": ["1"]
|
||||
}
|
||||
} with
|
||||
input.http as { "url": "http://example.com" } with
|
||||
input.session as { "id": "session1", "impersonate_groups": null }
|
||||
|
|
File diff suppressed because one or more lines are too long
1
go.sum
1
go.sum
|
@ -208,7 +208,6 @@ github.com/golang/protobuf v1.4.2/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw
|
|||
github.com/golang/snappy v0.0.0-20180518054509-2e65f85255db/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q=
|
||||
github.com/gomodule/redigo v1.8.2 h1:H5XSIre1MB5NbPYFp+i1NBbb5qN1W8Y8YAQoAYbkm8k=
|
||||
github.com/gomodule/redigo v1.8.2/go.mod h1:P9dn9mFrCBvWhGE1wpxx6fgq7BAeLBk+UUUzlpkBYO0=
|
||||
github.com/gomodule/redigo/redis v0.0.0-do-not-use h1:J7XIp6Kau0WoyT4JtXHT3Ei0gA1KkSc6bc87j9v9WIo=
|
||||
github.com/google/btree v0.0.0-20180813153112-4030bb1f1f0c/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ=
|
||||
github.com/google/btree v1.0.0 h1:0udJVsspx3VBr5FwtLhQQtuAsVc79tTq0ocGIPAU6qo=
|
||||
github.com/google/btree v1.0.0/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ=
|
||||
|
|
|
@ -101,25 +101,25 @@ func New(options ...Option) *Provider {
|
|||
}
|
||||
|
||||
// UserGroups returns the directory users in azure active directory.
|
||||
func (p *Provider) UserGroups(ctx context.Context) ([]*directory.User, error) {
|
||||
func (p *Provider) UserGroups(ctx context.Context) ([]*directory.Group, []*directory.User, error) {
|
||||
if p.cfg.serviceAccount == nil {
|
||||
return nil, fmt.Errorf("azure: service account not defined")
|
||||
return nil, nil, fmt.Errorf("azure: service account not defined")
|
||||
}
|
||||
|
||||
groupIDs, err := p.listGroups(ctx)
|
||||
groups, err := p.listGroups(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
userIDToGroupIDs := map[string][]string{}
|
||||
for groupID, groupName := range groupIDs {
|
||||
userIDs, err := p.listGroupMembers(ctx, groupID)
|
||||
for _, group := range groups {
|
||||
userIDs, err := p.listGroupMembers(ctx, group.Id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
for _, userID := range userIDs {
|
||||
userIDToGroupIDs[userID] = append(userIDToGroupIDs[userID], groupID, groupName)
|
||||
userIDToGroupIDs[userID] = append(userIDToGroupIDs[userID], group.Id)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -128,27 +128,27 @@ func (p *Provider) UserGroups(ctx context.Context) ([]*directory.User, error) {
|
|||
sort.Strings(groupIDs)
|
||||
users = append(users, &directory.User{
|
||||
Id: databroker.GetUserID(Name, userID),
|
||||
Groups: groupIDs,
|
||||
GroupIds: groupIDs,
|
||||
})
|
||||
}
|
||||
sort.Slice(users, func(i, j int) bool {
|
||||
return users[i].GetId() < users[j].GetId()
|
||||
})
|
||||
return users, nil
|
||||
return groups, users, nil
|
||||
}
|
||||
|
||||
// listGroups returns a map, with key is group ID, element is group name.
|
||||
func (p *Provider) listGroups(ctx context.Context) (map[string]string, error) {
|
||||
func (p *Provider) listGroups(ctx context.Context) ([]*directory.Group, error) {
|
||||
nextURL := p.cfg.graphURL.ResolveReference(&url.URL{
|
||||
Path: "/v1.0/groups",
|
||||
}).String()
|
||||
|
||||
groups := make(map[string]string)
|
||||
var groups []*directory.Group
|
||||
for nextURL != "" {
|
||||
var result struct {
|
||||
Value []struct {
|
||||
ID string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
DisplayName string `json:"displayName"`
|
||||
} `json:"value"`
|
||||
NextLink string `json:"@odata.nextLink"`
|
||||
}
|
||||
|
@ -157,7 +157,10 @@ func (p *Provider) listGroups(ctx context.Context) (map[string]string, error) {
|
|||
return nil, err
|
||||
}
|
||||
for _, v := range result.Value {
|
||||
groups[v.ID] = v.Name
|
||||
groups = append(groups, &directory.Group{
|
||||
Id: v.ID,
|
||||
Name: v.DisplayName,
|
||||
})
|
||||
}
|
||||
nextURL = result.NextLink
|
||||
}
|
||||
|
|
|
@ -45,8 +45,8 @@ func newMockAPI(t *testing.T, srv *httptest.Server) http.Handler {
|
|||
r.Get("/groups", func(w http.ResponseWriter, r *http.Request) {
|
||||
_ = json.NewEncoder(w).Encode(M{
|
||||
"value": []M{
|
||||
{"id": "admin", "name": "Admin Group"},
|
||||
{"id": "test", "name": "Test Group"},
|
||||
{"id": "admin", "displayName": "Admin Group"},
|
||||
{"id": "test", "displayName": "Test Group"},
|
||||
},
|
||||
})
|
||||
})
|
||||
|
@ -85,22 +85,26 @@ func Test(t *testing.T) {
|
|||
DirectoryID: "DIRECTORY_ID",
|
||||
}),
|
||||
)
|
||||
users, err := p.UserGroups(context.Background())
|
||||
groups, users, err := p.UserGroups(context.Background())
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, []*directory.User{
|
||||
{
|
||||
Id: "azure/user-1",
|
||||
Groups: []string{"Admin Group", "admin"},
|
||||
GroupIds: []string{"admin"},
|
||||
},
|
||||
{
|
||||
Id: "azure/user-2",
|
||||
Groups: []string{"Test Group", "test"},
|
||||
GroupIds: []string{"test"},
|
||||
},
|
||||
{
|
||||
Id: "azure/user-3",
|
||||
Groups: []string{"Test Group", "test"},
|
||||
GroupIds: []string{"test"},
|
||||
},
|
||||
}, users)
|
||||
assert.Equal(t, []*directory.Group{
|
||||
{Id: "admin", Name: "Admin Group"},
|
||||
{Id: "test", Name: "Test Group"},
|
||||
}, groups)
|
||||
}
|
||||
|
||||
func mustParseURL(rawurl string) *url.URL {
|
||||
|
|
|
@ -86,49 +86,52 @@ func New(options ...Option) *Provider {
|
|||
}
|
||||
|
||||
// UserGroups gets the directory user groups for github.
|
||||
func (p *Provider) UserGroups(ctx context.Context) ([]*directory.User, error) {
|
||||
func (p *Provider) UserGroups(ctx context.Context) ([]*directory.Group, []*directory.User, error) {
|
||||
if p.cfg.serviceAccount == nil {
|
||||
return nil, fmt.Errorf("github: service account not defined")
|
||||
return nil, nil, fmt.Errorf("github: service account not defined")
|
||||
}
|
||||
|
||||
orgSlugs, err := p.listOrgs(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
userLoginToGroups := map[string][]string{}
|
||||
|
||||
var allGroups []*directory.Group
|
||||
for _, orgSlug := range orgSlugs {
|
||||
teamSlugs, err := p.listTeams(ctx, orgSlug)
|
||||
groups, err := p.listGroups(ctx, orgSlug)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
for teamSlug, teamID := range teamSlugs {
|
||||
userLogins, err := p.listTeamMembers(ctx, orgSlug, teamSlug)
|
||||
for _, group := range groups {
|
||||
userLogins, err := p.listTeamMembers(ctx, orgSlug, group.Name)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
for _, userLogin := range userLogins {
|
||||
userLoginToGroups[userLogin] = append(userLoginToGroups[userLogin], teamSlug, strconv.Itoa(teamID))
|
||||
userLoginToGroups[userLogin] = append(userLoginToGroups[userLogin], group.Id)
|
||||
}
|
||||
}
|
||||
|
||||
allGroups = append(allGroups, groups...)
|
||||
}
|
||||
|
||||
var users []*directory.User
|
||||
for userLogin, groups := range userLoginToGroups {
|
||||
user := &directory.User{
|
||||
Id: databroker.GetUserID(Name, userLogin),
|
||||
Groups: groups,
|
||||
GroupIds: groups,
|
||||
}
|
||||
sort.Strings(user.Groups)
|
||||
sort.Strings(user.GroupIds)
|
||||
users = append(users, user)
|
||||
}
|
||||
sort.Slice(users, func(i, j int) bool {
|
||||
return users[i].GetId() < users[j].GetId()
|
||||
})
|
||||
return users, nil
|
||||
return allGroups, users, nil
|
||||
}
|
||||
|
||||
func (p *Provider) listOrgs(ctx context.Context) (orgSlugs []string, err error) {
|
||||
|
@ -155,12 +158,12 @@ func (p *Provider) listOrgs(ctx context.Context) (orgSlugs []string, err error)
|
|||
return orgSlugs, nil
|
||||
}
|
||||
|
||||
func (p *Provider) listTeams(ctx context.Context, orgSlug string) (map[string]int, error) {
|
||||
func (p *Provider) listGroups(ctx context.Context, orgSlug string) ([]*directory.Group, error) {
|
||||
nextURL := p.cfg.url.ResolveReference(&url.URL{
|
||||
Path: fmt.Sprintf("/orgs/%s/teams", orgSlug),
|
||||
}).String()
|
||||
|
||||
teamSlugs := make(map[string]int)
|
||||
var groups []*directory.Group
|
||||
for nextURL != "" {
|
||||
var results []struct {
|
||||
ID int `json:"id"`
|
||||
|
@ -172,13 +175,16 @@ func (p *Provider) listTeams(ctx context.Context, orgSlug string) (map[string]in
|
|||
}
|
||||
|
||||
for _, result := range results {
|
||||
teamSlugs[result.Slug] = result.ID
|
||||
groups = append(groups, &directory.Group{
|
||||
Id: strconv.Itoa(result.ID),
|
||||
Name: result.Slug,
|
||||
})
|
||||
}
|
||||
|
||||
nextURL = getNextLink(hdrs)
|
||||
}
|
||||
|
||||
return teamSlugs, nil
|
||||
return groups, nil
|
||||
}
|
||||
|
||||
func (p *Provider) listTeamMembers(ctx context.Context, orgSlug, teamSlug string) (userLogins []string, err error) {
|
||||
|
|
|
@ -93,14 +93,20 @@ func Test(t *testing.T) {
|
|||
PersonalAccessToken: "xyz",
|
||||
}),
|
||||
)
|
||||
users, err := p.UserGroups(context.Background())
|
||||
groups, users, err := p.UserGroups(context.Background())
|
||||
assert.NoError(t, err)
|
||||
testutil.AssertProtoJSONEqual(t, `[
|
||||
{ "id": "github/user1", "groups": ["1", "2", "3", "team1", "team2", "team3"] },
|
||||
{ "id": "github/user2", "groups": ["1", "3", "team1", "team3"] },
|
||||
{ "id": "github/user3", "groups": ["3", "team3"] },
|
||||
{ "id": "github/user4", "groups": ["4", "team4"] }
|
||||
{ "id": "github/user1", "groupIds": ["1", "2", "3"] },
|
||||
{ "id": "github/user2", "groupIds": ["1", "3"] },
|
||||
{ "id": "github/user3", "groupIds": ["3"] },
|
||||
{ "id": "github/user4", "groupIds": ["4"] }
|
||||
]`, users)
|
||||
testutil.AssertProtoJSONEqual(t, `[
|
||||
{ "id": "1", "name": "team1" },
|
||||
{ "id": "2", "name": "team2" },
|
||||
{ "id": "3", "name": "team3" },
|
||||
{ "id": "4", "name": "team4" }
|
||||
]`, groups)
|
||||
}
|
||||
|
||||
func mustParseURL(rawurl string) *url.URL {
|
||||
|
|
|
@ -84,27 +84,27 @@ func New(options ...Option) *Provider {
|
|||
}
|
||||
|
||||
// UserGroups gets the directory user groups for gitlab.
|
||||
func (p *Provider) UserGroups(ctx context.Context) ([]*directory.User, error) {
|
||||
func (p *Provider) UserGroups(ctx context.Context) ([]*directory.Group, []*directory.User, error) {
|
||||
if p.cfg.serviceAccount == nil {
|
||||
return nil, fmt.Errorf("gitlab: service account not defined")
|
||||
return nil, nil, fmt.Errorf("gitlab: service account not defined")
|
||||
}
|
||||
|
||||
p.log.Info().Msg("getting user groups")
|
||||
|
||||
groups, err := p.listGroups(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
userIDToGroupIDs := map[int][]string{}
|
||||
for groupID, groupName := range groups {
|
||||
userIDs, err := p.listGroupMemberIDs(ctx, groupID)
|
||||
for _, group := range groups {
|
||||
userIDs, err := p.listGroupMemberIDs(ctx, group.Id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
for _, userID := range userIDs {
|
||||
userIDToGroupIDs[userID] = append(userIDToGroupIDs[userID], strconv.Itoa(groupID), groupName)
|
||||
userIDToGroupIDs[userID] = append(userIDToGroupIDs[userID], group.Id)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -114,23 +114,23 @@ func (p *Provider) UserGroups(ctx context.Context) ([]*directory.User, error) {
|
|||
Id: databroker.GetUserID(Name, fmt.Sprint(userID)),
|
||||
}
|
||||
|
||||
user.Groups = append(user.Groups, groups...)
|
||||
user.GroupIds = append(user.GroupIds, groups...)
|
||||
|
||||
sort.Strings(user.Groups)
|
||||
sort.Strings(user.GroupIds)
|
||||
users = append(users, user)
|
||||
}
|
||||
sort.Slice(users, func(i, j int) bool {
|
||||
return users[i].GetId() < users[j].GetId()
|
||||
})
|
||||
return users, nil
|
||||
return groups, users, nil
|
||||
}
|
||||
|
||||
// listGroups returns a map, with key is group ID, element is group name.
|
||||
func (p *Provider) listGroups(ctx context.Context) (map[int]string, error) {
|
||||
func (p *Provider) listGroups(ctx context.Context) ([]*directory.Group, error) {
|
||||
nextURL := p.cfg.url.ResolveReference(&url.URL{
|
||||
Path: "/api/v4/groups",
|
||||
}).String()
|
||||
groups := make(map[int]string)
|
||||
var groups []*directory.Group
|
||||
for nextURL != "" {
|
||||
var result []struct {
|
||||
ID int `json:"id"`
|
||||
|
@ -142,7 +142,10 @@ func (p *Provider) listGroups(ctx context.Context) (map[int]string, error) {
|
|||
}
|
||||
|
||||
for _, r := range result {
|
||||
groups[r.ID] = r.Name
|
||||
groups = append(groups, &directory.Group{
|
||||
Id: strconv.Itoa(r.ID),
|
||||
Name: r.Name,
|
||||
})
|
||||
}
|
||||
|
||||
nextURL = getNextLink(hdrs)
|
||||
|
@ -150,9 +153,9 @@ func (p *Provider) listGroups(ctx context.Context) (map[int]string, error) {
|
|||
return groups, nil
|
||||
}
|
||||
|
||||
func (p *Provider) listGroupMemberIDs(ctx context.Context, groupID int) (userIDs []int, err error) {
|
||||
func (p *Provider) listGroupMemberIDs(ctx context.Context, groupID string) (userIDs []int, err error) {
|
||||
nextURL := p.cfg.url.ResolveReference(&url.URL{
|
||||
Path: fmt.Sprintf("/api/v4/groups/%d/members", groupID),
|
||||
Path: fmt.Sprintf("/api/v4/groups/%s/members", groupID),
|
||||
}).String()
|
||||
for nextURL != "" {
|
||||
var result []struct {
|
||||
|
|
|
@ -66,13 +66,17 @@ func Test(t *testing.T) {
|
|||
PrivateToken: "PRIVATE_TOKEN",
|
||||
}),
|
||||
)
|
||||
users, err := p.UserGroups(context.Background())
|
||||
groups, users, err := p.UserGroups(context.Background())
|
||||
assert.NoError(t, err)
|
||||
testutil.AssertProtoJSONEqual(t, `[
|
||||
{ "id": "gitlab/11", "groups": ["1", "Group 1"] },
|
||||
{ "id": "gitlab/12", "groups": ["2", "Group 2"] },
|
||||
{ "id": "gitlab/13", "groups": ["2", "Group 2"] }
|
||||
{ "id": "gitlab/11", "groupIds": ["1"] },
|
||||
{ "id": "gitlab/12", "groupIds": ["2"] },
|
||||
{ "id": "gitlab/13", "groupIds": ["2"] }
|
||||
]`, users)
|
||||
testutil.AssertProtoJSONEqual(t, `[
|
||||
{ "id": "1", "name": "Group 1" },
|
||||
{ "id": "2", "name": "Group 2" }
|
||||
]`, groups)
|
||||
}
|
||||
|
||||
func mustParseURL(rawurl string) *url.URL {
|
||||
|
|
|
@ -82,17 +82,15 @@ func New(options ...Option) *Provider {
|
|||
// NOTE: groups via Directory API is limited to 1 QPS!
|
||||
// https://developers.google.com/admin-sdk/directory/v1/reference/groups/list
|
||||
// https://developers.google.com/admin-sdk/directory/v1/limits
|
||||
func (p *Provider) UserGroups(ctx context.Context) ([]*directory.User, error) {
|
||||
func (p *Provider) UserGroups(ctx context.Context) ([]*directory.Group, []*directory.User, error) {
|
||||
apiClient, err := p.getAPIClient(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("google: error getting API client: %w", err)
|
||||
return nil, nil, fmt.Errorf("google: error getting API client: %w", err)
|
||||
}
|
||||
|
||||
groupIDToEmails := map[string]string{}
|
||||
var groups []string
|
||||
var groups []*directory.Group
|
||||
err = apiClient.Groups.List().
|
||||
Context(ctx).
|
||||
Fields("id", "email", "directMembersCount").
|
||||
Customer("my_customer").
|
||||
Pages(ctx, func(res *admin.Groups) error {
|
||||
for _, g := range res.Groups {
|
||||
|
@ -100,29 +98,31 @@ func (p *Provider) UserGroups(ctx context.Context) ([]*directory.User, error) {
|
|||
if g.DirectMembersCount == 0 {
|
||||
continue
|
||||
}
|
||||
groups = append(groups, g.Id)
|
||||
groupIDToEmails[g.Id] = g.Email
|
||||
groups = append(groups, &directory.Group{
|
||||
Id: g.Id,
|
||||
Name: g.Email,
|
||||
Email: g.Email,
|
||||
})
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("google: error getting groups: %w", err)
|
||||
return nil, nil, fmt.Errorf("google: error getting groups: %w", err)
|
||||
}
|
||||
|
||||
userIDToGroups := map[string][]string{}
|
||||
for _, group := range groups {
|
||||
group := group
|
||||
err = apiClient.Members.List(group).
|
||||
err = apiClient.Members.List(group.Id).
|
||||
Context(ctx).
|
||||
Fields("id").
|
||||
Pages(ctx, func(res *admin.Members) error {
|
||||
for _, member := range res.Members {
|
||||
userIDToGroups[member.Id] = append(userIDToGroups[member.Id], group, groupIDToEmails[group])
|
||||
userIDToGroups[member.Id] = append(userIDToGroups[member.Id], group.Id)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("google: error getting group members: %w", err)
|
||||
return nil, nil, fmt.Errorf("google: error getting group members: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -131,13 +131,13 @@ func (p *Provider) UserGroups(ctx context.Context) ([]*directory.User, error) {
|
|||
sort.Strings(groups)
|
||||
users = append(users, &directory.User{
|
||||
Id: databroker.GetUserID(Name, userID),
|
||||
Groups: groups,
|
||||
GroupIds: groups,
|
||||
})
|
||||
}
|
||||
sort.Slice(users, func(i, j int) bool {
|
||||
return users[i].Id < users[j].Id
|
||||
})
|
||||
return users, nil
|
||||
return groups, users, nil
|
||||
}
|
||||
|
||||
func (p *Provider) getAPIClient(ctx context.Context) (*admin.Service, error) {
|
||||
|
|
|
@ -85,30 +85,30 @@ func New(options ...Option) *Provider {
|
|||
|
||||
// UserGroups fetches the groups of which the user is a member
|
||||
// https://developer.okta.com/docs/reference/api/users/#get-user-s-groups
|
||||
func (p *Provider) UserGroups(ctx context.Context) ([]*directory.User, error) {
|
||||
func (p *Provider) UserGroups(ctx context.Context) ([]*directory.Group, []*directory.User, error) {
|
||||
if p.cfg.serviceAccount == nil {
|
||||
return nil, fmt.Errorf("okta: service account not defined")
|
||||
return nil, nil, fmt.Errorf("okta: service account not defined")
|
||||
}
|
||||
|
||||
p.log.Info().Msg("getting user groups")
|
||||
|
||||
if p.cfg.providerURL == nil {
|
||||
return nil, fmt.Errorf("okta: provider url not defined")
|
||||
return nil, nil, fmt.Errorf("okta: provider url not defined")
|
||||
}
|
||||
|
||||
groupIDToName, err := p.getGroups(ctx)
|
||||
groups, err := p.getGroups(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
userIDToGroups := map[string][]string{}
|
||||
for groupID, groupName := range groupIDToName {
|
||||
ids, err := p.getGroupMemberIDs(ctx, groupID)
|
||||
for _, group := range groups {
|
||||
ids, err := p.getGroupMemberIDs(ctx, group.Id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, nil, err
|
||||
}
|
||||
for _, id := range ids {
|
||||
userIDToGroups[id] = append(userIDToGroups[id], groupID, groupName)
|
||||
userIDToGroups[id] = append(userIDToGroups[id], group.Id)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -117,18 +117,17 @@ func (p *Provider) UserGroups(ctx context.Context) ([]*directory.User, error) {
|
|||
sort.Strings(groups)
|
||||
users = append(users, &directory.User{
|
||||
Id: databroker.GetUserID(Name, userID),
|
||||
Groups: groups,
|
||||
GroupIds: groups,
|
||||
})
|
||||
}
|
||||
sort.Slice(users, func(i, j int) bool {
|
||||
return users[i].Id < users[j].Id
|
||||
})
|
||||
return users, nil
|
||||
return groups, users, nil
|
||||
}
|
||||
|
||||
func (p *Provider) getGroups(ctx context.Context) (map[string]string, error) {
|
||||
groups := map[string]string{}
|
||||
|
||||
func (p *Provider) getGroups(ctx context.Context) ([]*directory.Group, error) {
|
||||
var groups []*directory.Group
|
||||
groupURL := p.cfg.providerURL.ResolveReference(&url.URL{
|
||||
Path: "/api/v1/groups",
|
||||
RawQuery: fmt.Sprintf("limit=%d", p.cfg.batchSize),
|
||||
|
@ -146,12 +145,14 @@ func (p *Provider) getGroups(ctx context.Context) (map[string]string, error) {
|
|||
}
|
||||
|
||||
for _, el := range out {
|
||||
groups[el.ID] = el.Profile.Name
|
||||
groups = append(groups, &directory.Group{
|
||||
Id: el.ID,
|
||||
Name: el.Profile.Name,
|
||||
})
|
||||
}
|
||||
|
||||
groupURL = getNextLink(hdrs)
|
||||
}
|
||||
|
||||
return groups, nil
|
||||
}
|
||||
|
||||
|
|
|
@ -14,6 +14,7 @@ import (
|
|||
"github.com/stretchr/testify/assert"
|
||||
"github.com/tomnomnom/linkheader"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/testutil"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/directory"
|
||||
)
|
||||
|
||||
|
@ -115,22 +116,27 @@ func TestProvider_UserGroups(t *testing.T) {
|
|||
WithServiceAccount(&ServiceAccount{APIKey: "APITOKEN"}),
|
||||
WithProviderURL(mustParseURL(srv.URL)),
|
||||
)
|
||||
users, err := p.UserGroups(context.Background())
|
||||
groups, users, err := p.UserGroups(context.Background())
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, []*directory.User{
|
||||
{
|
||||
Id: "okta/a@example.com",
|
||||
Groups: []string{"admin", "admin-name", "user", "user-name"},
|
||||
GroupIds: []string{"admin", "user"},
|
||||
},
|
||||
{
|
||||
Id: "okta/b@example.com",
|
||||
Groups: []string{"test", "test-name", "user", "user-name"},
|
||||
GroupIds: []string{"test", "user"},
|
||||
},
|
||||
{
|
||||
Id: "okta/c@example.com",
|
||||
Groups: []string{"user", "user-name"},
|
||||
GroupIds: []string{"user"},
|
||||
},
|
||||
}, users)
|
||||
testutil.AssertProtoJSONEqual(t, `[
|
||||
{ "id": "admin", "name": "admin-name" },
|
||||
{ "id": "test", "name": "test-name" },
|
||||
{ "id": "user", "name": "user-name" }
|
||||
]`, groups)
|
||||
}
|
||||
|
||||
func mustParseURL(rawurl string) *url.URL {
|
||||
|
|
|
@ -95,57 +95,45 @@ func New(options ...Option) *Provider {
|
|||
}
|
||||
|
||||
// UserGroups gets the directory user groups for onelogin.
|
||||
func (p *Provider) UserGroups(ctx context.Context) ([]*directory.User, error) {
|
||||
func (p *Provider) UserGroups(ctx context.Context) ([]*directory.Group, []*directory.User, error) {
|
||||
if p.cfg.serviceAccount == nil {
|
||||
return nil, fmt.Errorf("onelogin: service account not defined")
|
||||
return nil, nil, fmt.Errorf("onelogin: service account not defined")
|
||||
}
|
||||
|
||||
p.log.Info().Msg("getting user groups")
|
||||
|
||||
token, err := p.getToken(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
groupIDToName, err := p.getGroupIDToName(ctx, token)
|
||||
groups, err := p.listGroups(ctx, token)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
userIDToGroupIDs, err := p.getUserIDToGroupIDs(ctx, token)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
userIDToGroupNames := map[int][]string{}
|
||||
for userID, groupIDs := range userIDToGroupIDs {
|
||||
for _, groupID := range groupIDs {
|
||||
if groupName, ok := groupIDToName[groupID]; ok {
|
||||
userIDToGroupNames[userID] = append(userIDToGroupNames[userID], strconv.Itoa(groupID), groupName)
|
||||
} else {
|
||||
userIDToGroupNames[userID] = append(userIDToGroupNames[userID], "NOGROUP")
|
||||
}
|
||||
}
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
var users []*directory.User
|
||||
for userID, groups := range userIDToGroupNames {
|
||||
sort.Strings(groups)
|
||||
for userID, groupIDs := range userIDToGroupIDs {
|
||||
sort.Strings(groupIDs)
|
||||
users = append(users, &directory.User{
|
||||
Id: databroker.GetUserID(Name, strconv.Itoa(userID)),
|
||||
Groups: groups,
|
||||
GroupIds: groupIDs,
|
||||
})
|
||||
}
|
||||
|
||||
sort.Slice(users, func(i, j int) bool {
|
||||
return users[i].Id < users[j].Id
|
||||
})
|
||||
return users, nil
|
||||
return groups, users, nil
|
||||
}
|
||||
|
||||
func (p *Provider) getGroupIDToName(ctx context.Context, token *oauth2.Token) (map[int]string, error) {
|
||||
groupIDToName := map[int]string{}
|
||||
|
||||
func (p *Provider) listGroups(ctx context.Context, token *oauth2.Token) ([]*directory.Group, error) {
|
||||
var groups []*directory.Group
|
||||
apiURL := p.cfg.apiURL.ResolveReference(&url.URL{
|
||||
Path: "/api/1/groups",
|
||||
RawQuery: fmt.Sprintf("limit=%d", p.cfg.batchSize),
|
||||
|
@ -161,17 +149,19 @@ func (p *Provider) getGroupIDToName(ctx context.Context, token *oauth2.Token) (m
|
|||
}
|
||||
|
||||
for _, r := range result {
|
||||
groupIDToName[r.ID] = r.Name
|
||||
groups = append(groups, &directory.Group{
|
||||
Id: strconv.Itoa(r.ID),
|
||||
Name: r.Name,
|
||||
})
|
||||
}
|
||||
|
||||
apiURL = nextLink
|
||||
}
|
||||
|
||||
return groupIDToName, nil
|
||||
return groups, nil
|
||||
}
|
||||
|
||||
func (p *Provider) getUserIDToGroupIDs(ctx context.Context, token *oauth2.Token) (map[int][]int, error) {
|
||||
userIDToGroupIDs := map[int][]int{}
|
||||
func (p *Provider) getUserIDToGroupIDs(ctx context.Context, token *oauth2.Token) (map[int][]string, error) {
|
||||
userIDToGroupIDs := map[int][]string{}
|
||||
|
||||
apiURL := p.cfg.apiURL.ResolveReference(&url.URL{
|
||||
Path: "/api/1/users",
|
||||
|
@ -192,7 +182,7 @@ func (p *Provider) getUserIDToGroupIDs(ctx context.Context, token *oauth2.Token)
|
|||
if r.GroupID != nil {
|
||||
groupID = *r.GroupID
|
||||
}
|
||||
userIDToGroupIDs[r.ID] = append(userIDToGroupIDs[r.ID], groupID)
|
||||
userIDToGroupIDs[r.ID] = append(userIDToGroupIDs[r.ID], strconv.Itoa(groupID))
|
||||
}
|
||||
|
||||
apiURL = nextLink
|
||||
|
|
|
@ -15,7 +15,7 @@ import (
|
|||
"github.com/go-chi/chi/middleware"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/pomerium/pomerium/pkg/grpc/directory"
|
||||
"github.com/pomerium/pomerium/internal/testutil"
|
||||
)
|
||||
|
||||
type M = map[string]interface{}
|
||||
|
@ -147,22 +147,18 @@ func TestProvider_UserGroups(t *testing.T) {
|
|||
}),
|
||||
WithURL(mustParseURL(srv.URL)),
|
||||
)
|
||||
users, err := p.UserGroups(context.Background())
|
||||
groups, users, err := p.UserGroups(context.Background())
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, []*directory.User{
|
||||
{
|
||||
Id: "onelogin/111",
|
||||
Groups: []string{"0", "admin"},
|
||||
},
|
||||
{
|
||||
Id: "onelogin/222",
|
||||
Groups: []string{"1", "test"},
|
||||
},
|
||||
{
|
||||
Id: "onelogin/333",
|
||||
Groups: []string{"2", "user"},
|
||||
},
|
||||
}, users)
|
||||
testutil.AssertProtoJSONEqual(t, `[
|
||||
{ "id": "onelogin/111", "groupIds": ["0"] },
|
||||
{ "id": "onelogin/222", "groupIds": ["1"] },
|
||||
{ "id": "onelogin/333", "groupIds": ["2"] }
|
||||
]`, users)
|
||||
testutil.AssertProtoJSONEqual(t, `[
|
||||
{ "id": "0", "name": "admin" },
|
||||
{ "id": "1", "name": "test" },
|
||||
{ "id": "2", "name": "user" }
|
||||
]`, groups)
|
||||
}
|
||||
|
||||
func mustParseURL(rawurl string) *url.URL {
|
||||
|
|
|
@ -16,12 +16,15 @@ import (
|
|||
"github.com/pomerium/pomerium/pkg/grpc/directory"
|
||||
)
|
||||
|
||||
// A Group is a directory Group.
|
||||
type Group = directory.Group
|
||||
|
||||
// A User is a directory User.
|
||||
type User = directory.User
|
||||
|
||||
// A Provider provides user group directory information.
|
||||
type Provider interface {
|
||||
UserGroups(ctx context.Context) ([]*User, error)
|
||||
UserGroups(ctx context.Context) ([]*Group, []*User, error)
|
||||
}
|
||||
|
||||
// GetProvider gets the provider for the given options.
|
||||
|
@ -101,6 +104,6 @@ func GetProvider(options *config.Options) Provider {
|
|||
|
||||
type nullProvider struct{}
|
||||
|
||||
func (nullProvider) UserGroups(ctx context.Context) ([]*User, error) {
|
||||
return nil, nil
|
||||
func (nullProvider) UserGroups(ctx context.Context) ([]*Group, []*User, error) {
|
||||
return nil, nil, nil
|
||||
}
|
||||
|
|
|
@ -87,7 +87,7 @@
|
|||
disabled
|
||||
/>
|
||||
</label>
|
||||
{{end}} {{range $i,$_:= .DirectoryUser.Groups}}
|
||||
{{end}} {{range $i,$_:= .DirectoryGroups}}
|
||||
<label>
|
||||
{{if eq $i 0}}
|
||||
<span>Group</span>
|
||||
|
@ -97,8 +97,8 @@
|
|||
<input
|
||||
type="text"
|
||||
class="field"
|
||||
value="{{.}}"
|
||||
title="{{.}}"
|
||||
value="{{.Name}}"
|
||||
title="{{.Id}}"
|
||||
disabled
|
||||
/>
|
||||
</label>
|
||||
|
|
File diff suppressed because one or more lines are too long
|
@ -48,7 +48,12 @@ type Manager struct {
|
|||
directoryUsers map[string]*directory.User
|
||||
directoryUsersServerVersion string
|
||||
directoryUsersRecordVersion string
|
||||
directoryUsersNextRefresh time.Time
|
||||
|
||||
directoryGroups map[string]*directory.Group
|
||||
directoryGroupsServerVersion string
|
||||
directoryGroupsRecordVersion string
|
||||
|
||||
directoryNextRefresh time.Time
|
||||
}
|
||||
|
||||
// New creates a new identity manager.
|
||||
|
@ -83,7 +88,12 @@ func New(
|
|||
|
||||
// Run runs the manager. This method blocks until an error occurs or the given context is canceled.
|
||||
func (mgr *Manager) Run(ctx context.Context) error {
|
||||
err := mgr.initDirectoryUsers(ctx)
|
||||
err := mgr.initDirectoryGroups(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to initialize directory groups: %w", err)
|
||||
}
|
||||
|
||||
err = mgr.initDirectoryUsers(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to initialize directory users: %w", err)
|
||||
}
|
||||
|
@ -100,13 +110,18 @@ func (mgr *Manager) Run(ctx context.Context) error {
|
|||
return mgr.syncUsers(ctx, updatedUser)
|
||||
})
|
||||
|
||||
updatedDirectoryGroup := make(chan *directory.Group, 1)
|
||||
t.Go(func() error {
|
||||
return mgr.syncDirectoryGroups(ctx, updatedDirectoryGroup)
|
||||
})
|
||||
|
||||
updatedDirectoryUser := make(chan *directory.User, 1)
|
||||
t.Go(func() error {
|
||||
return mgr.syncDirectoryUsers(ctx, updatedDirectoryUser)
|
||||
})
|
||||
|
||||
t.Go(func() error {
|
||||
return mgr.refreshLoop(ctx, updatedSession, updatedUser, updatedDirectoryUser)
|
||||
return mgr.refreshLoop(ctx, updatedSession, updatedUser, updatedDirectoryUser, updatedDirectoryGroup)
|
||||
})
|
||||
|
||||
return t.Wait()
|
||||
|
@ -117,11 +132,12 @@ func (mgr *Manager) refreshLoop(
|
|||
updatedSession <-chan *session.Session,
|
||||
updatedUser <-chan *user.User,
|
||||
updatedDirectoryUser <-chan *directory.User,
|
||||
updatedDirectoryGroup <-chan *directory.Group,
|
||||
) error {
|
||||
maxWait := time.Minute * 10
|
||||
nextTime := time.Now().Add(maxWait)
|
||||
if mgr.directoryUsersNextRefresh.Before(nextTime) {
|
||||
nextTime = mgr.directoryUsersNextRefresh
|
||||
if mgr.directoryNextRefresh.Before(nextTime) {
|
||||
nextTime = mgr.directoryNextRefresh
|
||||
}
|
||||
|
||||
timer := time.NewTimer(time.Until(nextTime))
|
||||
|
@ -137,6 +153,8 @@ func (mgr *Manager) refreshLoop(
|
|||
mgr.onUpdateUser(ctx, u)
|
||||
case du := <-updatedDirectoryUser:
|
||||
mgr.onUpdateDirectoryUser(ctx, du)
|
||||
case dg := <-updatedDirectoryGroup:
|
||||
mgr.onUpdateDirectoryGroup(ctx, dg)
|
||||
case <-timer.C:
|
||||
}
|
||||
|
||||
|
@ -144,11 +162,11 @@ func (mgr *Manager) refreshLoop(
|
|||
nextTime := now.Add(maxWait)
|
||||
|
||||
// refresh groups
|
||||
if mgr.directoryUsersNextRefresh.Before(now) {
|
||||
mgr.refreshDirectoryUsers(ctx)
|
||||
mgr.directoryUsersNextRefresh = now.Add(mgr.cfg.groupRefreshInterval)
|
||||
if mgr.directoryUsersNextRefresh.Before(nextTime) {
|
||||
nextTime = mgr.directoryUsersNextRefresh
|
||||
if mgr.directoryNextRefresh.Before(now) {
|
||||
mgr.refreshDirectoryUserGroups(ctx)
|
||||
mgr.directoryNextRefresh = now.Add(mgr.cfg.groupRefreshInterval)
|
||||
if mgr.directoryNextRefresh.Before(nextTime) {
|
||||
nextTime = mgr.directoryNextRefresh
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -185,18 +203,69 @@ func (mgr *Manager) refreshLoop(
|
|||
}
|
||||
}
|
||||
|
||||
func (mgr *Manager) refreshDirectoryUsers(ctx context.Context) {
|
||||
func (mgr *Manager) refreshDirectoryUserGroups(ctx context.Context) {
|
||||
mgr.log.Info().Msg("refreshing directory users")
|
||||
|
||||
ctx, clearTimeout := context.WithTimeout(ctx, mgr.cfg.groupRefreshTimeout)
|
||||
defer clearTimeout()
|
||||
|
||||
directoryUsers, err := mgr.directory.UserGroups(ctx)
|
||||
directoryGroups, directoryUsers, err := mgr.directory.UserGroups(ctx)
|
||||
if err != nil {
|
||||
mgr.log.Warn().Err(err).Msg("failed to refresh directory users and groups")
|
||||
return
|
||||
}
|
||||
|
||||
mgr.mergeGroups(ctx, directoryGroups)
|
||||
mgr.mergeUsers(ctx, directoryUsers)
|
||||
}
|
||||
|
||||
func (mgr *Manager) mergeGroups(ctx context.Context, directoryGroups []*directory.Group) {
|
||||
lookup := map[string]*directory.Group{}
|
||||
for _, dg := range directoryGroups {
|
||||
lookup[dg.GetId()] = dg
|
||||
}
|
||||
|
||||
for groupID, newDG := range lookup {
|
||||
curDG, ok := mgr.directoryGroups[groupID]
|
||||
if !ok || !proto.Equal(newDG, curDG) {
|
||||
any, err := ptypes.MarshalAny(newDG)
|
||||
if err != nil {
|
||||
mgr.log.Warn().Err(err).Msg("failed to marshal directory group")
|
||||
return
|
||||
}
|
||||
_, err = mgr.dataBrokerClient.Set(ctx, &databroker.SetRequest{
|
||||
Type: any.GetTypeUrl(),
|
||||
Id: newDG.GetId(),
|
||||
Data: any,
|
||||
})
|
||||
if err != nil {
|
||||
mgr.log.Warn().Err(err).Msg("failed to update directory group")
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for groupID, curDG := range mgr.directoryGroups {
|
||||
_, ok := lookup[groupID]
|
||||
if !ok {
|
||||
any, err := ptypes.MarshalAny(curDG)
|
||||
if err != nil {
|
||||
mgr.log.Warn().Err(err).Msg("failed to marshal directory group")
|
||||
return
|
||||
}
|
||||
_, err = mgr.dataBrokerClient.Delete(ctx, &databroker.DeleteRequest{
|
||||
Type: any.GetTypeUrl(),
|
||||
Id: curDG.GetId(),
|
||||
})
|
||||
if err != nil {
|
||||
mgr.log.Warn().Err(err).Msg("failed to delete directory group")
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (mgr *Manager) mergeUsers(ctx context.Context, directoryUsers []*directory.User) {
|
||||
lookup := map[string]*directory.User{}
|
||||
for _, du := range directoryUsers {
|
||||
lookup[du.GetId()] = du
|
||||
|
@ -497,6 +566,75 @@ func (mgr *Manager) syncDirectoryUsers(ctx context.Context, ch chan<- *directory
|
|||
}
|
||||
}
|
||||
|
||||
func (mgr *Manager) initDirectoryGroups(ctx context.Context) error {
|
||||
mgr.log.Info().Msg("initializing directory groups")
|
||||
|
||||
any, err := ptypes.MarshalAny(new(directory.Group))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
res, err := mgr.dataBrokerClient.GetAll(ctx, &databroker.GetAllRequest{
|
||||
Type: any.GetTypeUrl(),
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("error getting all directory groups: %w", err)
|
||||
}
|
||||
|
||||
mgr.directoryGroups = map[string]*directory.Group{}
|
||||
for _, record := range res.GetRecords() {
|
||||
var pbDirectoryGroup directory.Group
|
||||
err := ptypes.UnmarshalAny(record.GetData(), &pbDirectoryGroup)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error unmarshaling directory group: %w", err)
|
||||
}
|
||||
|
||||
mgr.directoryGroups[pbDirectoryGroup.GetId()] = &pbDirectoryGroup
|
||||
}
|
||||
mgr.directoryGroupsRecordVersion = res.GetRecordVersion()
|
||||
mgr.directoryGroupsServerVersion = res.GetServerVersion()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (mgr *Manager) syncDirectoryGroups(ctx context.Context, ch chan<- *directory.Group) error {
|
||||
mgr.log.Info().Msg("syncing directory groups")
|
||||
|
||||
any, err := ptypes.MarshalAny(new(directory.Group))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
client, err := mgr.dataBrokerClient.Sync(ctx, &databroker.SyncRequest{
|
||||
Type: any.GetTypeUrl(),
|
||||
ServerVersion: mgr.directoryGroupsServerVersion,
|
||||
RecordVersion: mgr.directoryGroupsRecordVersion,
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("error syncing directory groups: %w", err)
|
||||
}
|
||||
for {
|
||||
res, err := client.Recv()
|
||||
if err != nil {
|
||||
return fmt.Errorf("error receiving directory groups: %w", err)
|
||||
}
|
||||
|
||||
for _, record := range res.GetRecords() {
|
||||
var pbDirectoryGroup directory.Group
|
||||
err := ptypes.UnmarshalAny(record.GetData(), &pbDirectoryGroup)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error unmarshaling directory group: %w", err)
|
||||
}
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case ch <- &pbDirectoryGroup:
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (mgr *Manager) onUpdateSession(ctx context.Context, pbSession *session.Session) {
|
||||
mgr.sessionScheduler.Remove(toSessionSchedulerKey(pbSession.GetUserId(), pbSession.GetId()))
|
||||
|
||||
|
@ -543,6 +681,10 @@ func (mgr *Manager) onUpdateDirectoryUser(_ context.Context, pbDirectoryUser *di
|
|||
mgr.directoryUsers[pbDirectoryUser.GetId()] = pbDirectoryUser
|
||||
}
|
||||
|
||||
func (mgr *Manager) onUpdateDirectoryGroup(_ context.Context, pbDirectoryGroup *directory.Group) {
|
||||
mgr.directoryGroups[pbDirectoryGroup.GetId()] = pbDirectoryGroup
|
||||
}
|
||||
|
||||
func (mgr *Manager) createUser(ctx context.Context, pbSession *session.Session) {
|
||||
u := User{
|
||||
User: &user.User{
|
||||
|
|
|
@ -9,8 +9,28 @@ import (
|
|||
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
||||
)
|
||||
|
||||
// Get gets a directory user from the databroker.
|
||||
func Get(ctx context.Context, client databroker.DataBrokerServiceClient, userID string) (*User, error) {
|
||||
// GetGroup gets a directory group from the databroker.
|
||||
func GetGroup(ctx context.Context, client databroker.DataBrokerServiceClient, groupID string) (*Group, error) {
|
||||
any, _ := ptypes.MarshalAny(new(Group))
|
||||
|
||||
res, err := client.Get(ctx, &databroker.GetRequest{
|
||||
Type: any.GetTypeUrl(),
|
||||
Id: groupID,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var g Group
|
||||
err = ptypes.UnmarshalAny(res.GetRecord().GetData(), &g)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &g, nil
|
||||
}
|
||||
|
||||
// GetUser gets a directory user from the databroker.
|
||||
func GetUser(ctx context.Context, client databroker.DataBrokerServiceClient, userID string) (*User, error) {
|
||||
any, _ := ptypes.MarshalAny(new(User))
|
||||
|
||||
res, err := client.Get(ctx, &databroker.GetRequest{
|
||||
|
|
|
@ -32,7 +32,7 @@ type User struct {
|
|||
|
||||
Version string `protobuf:"bytes,1,opt,name=version,proto3" json:"version,omitempty"`
|
||||
Id string `protobuf:"bytes,2,opt,name=id,proto3" json:"id,omitempty"`
|
||||
Groups []string `protobuf:"bytes,3,rep,name=groups,proto3" json:"groups,omitempty"`
|
||||
GroupIds []string `protobuf:"bytes,3,rep,name=group_ids,json=groupIds,proto3" json:"group_ids,omitempty"`
|
||||
}
|
||||
|
||||
func (x *User) Reset() {
|
||||
|
@ -81,27 +81,104 @@ func (x *User) GetId() string {
|
|||
return ""
|
||||
}
|
||||
|
||||
func (x *User) GetGroups() []string {
|
||||
func (x *User) GetGroupIds() []string {
|
||||
if x != nil {
|
||||
return x.Groups
|
||||
return x.GroupIds
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type Group struct {
|
||||
state protoimpl.MessageState
|
||||
sizeCache protoimpl.SizeCache
|
||||
unknownFields protoimpl.UnknownFields
|
||||
|
||||
Version string `protobuf:"bytes,1,opt,name=version,proto3" json:"version,omitempty"`
|
||||
Id string `protobuf:"bytes,2,opt,name=id,proto3" json:"id,omitempty"`
|
||||
Name string `protobuf:"bytes,3,opt,name=name,proto3" json:"name,omitempty"`
|
||||
Email string `protobuf:"bytes,4,opt,name=email,proto3" json:"email,omitempty"`
|
||||
}
|
||||
|
||||
func (x *Group) Reset() {
|
||||
*x = Group{}
|
||||
if protoimpl.UnsafeEnabled {
|
||||
mi := &file_directory_proto_msgTypes[1]
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
}
|
||||
|
||||
func (x *Group) String() string {
|
||||
return protoimpl.X.MessageStringOf(x)
|
||||
}
|
||||
|
||||
func (*Group) ProtoMessage() {}
|
||||
|
||||
func (x *Group) ProtoReflect() protoreflect.Message {
|
||||
mi := &file_directory_proto_msgTypes[1]
|
||||
if protoimpl.UnsafeEnabled && x != nil {
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
if ms.LoadMessageInfo() == nil {
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
return ms
|
||||
}
|
||||
return mi.MessageOf(x)
|
||||
}
|
||||
|
||||
// Deprecated: Use Group.ProtoReflect.Descriptor instead.
|
||||
func (*Group) Descriptor() ([]byte, []int) {
|
||||
return file_directory_proto_rawDescGZIP(), []int{1}
|
||||
}
|
||||
|
||||
func (x *Group) GetVersion() string {
|
||||
if x != nil {
|
||||
return x.Version
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (x *Group) GetId() string {
|
||||
if x != nil {
|
||||
return x.Id
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (x *Group) GetName() string {
|
||||
if x != nil {
|
||||
return x.Name
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (x *Group) GetEmail() string {
|
||||
if x != nil {
|
||||
return x.Email
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
var File_directory_proto protoreflect.FileDescriptor
|
||||
|
||||
var file_directory_proto_rawDesc = []byte{
|
||||
0x0a, 0x0f, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, 0x6f, 0x72, 0x79, 0x2e, 0x70, 0x72, 0x6f, 0x74,
|
||||
0x6f, 0x12, 0x09, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, 0x6f, 0x72, 0x79, 0x22, 0x48, 0x0a, 0x04,
|
||||
0x6f, 0x12, 0x09, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, 0x6f, 0x72, 0x79, 0x22, 0x4d, 0x0a, 0x04,
|
||||
0x55, 0x73, 0x65, 0x72, 0x12, 0x18, 0x0a, 0x07, 0x76, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x18,
|
||||
0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x76, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x12, 0x0e,
|
||||
0x0a, 0x02, 0x69, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x02, 0x69, 0x64, 0x12, 0x16,
|
||||
0x0a, 0x06, 0x67, 0x72, 0x6f, 0x75, 0x70, 0x73, 0x18, 0x03, 0x20, 0x03, 0x28, 0x09, 0x52, 0x06,
|
||||
0x67, 0x72, 0x6f, 0x75, 0x70, 0x73, 0x42, 0x31, 0x5a, 0x2f, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62,
|
||||
0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x70, 0x6f, 0x6d, 0x65, 0x72, 0x69, 0x75, 0x6d, 0x2f, 0x70, 0x6f,
|
||||
0x6d, 0x65, 0x72, 0x69, 0x75, 0x6d, 0x2f, 0x70, 0x6b, 0x67, 0x2f, 0x67, 0x72, 0x70, 0x63, 0x2f,
|
||||
0x64, 0x69, 0x72, 0x65, 0x63, 0x74, 0x6f, 0x72, 0x79, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f,
|
||||
0x33,
|
||||
0x0a, 0x02, 0x69, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x02, 0x69, 0x64, 0x12, 0x1b,
|
||||
0x0a, 0x09, 0x67, 0x72, 0x6f, 0x75, 0x70, 0x5f, 0x69, 0x64, 0x73, 0x18, 0x03, 0x20, 0x03, 0x28,
|
||||
0x09, 0x52, 0x08, 0x67, 0x72, 0x6f, 0x75, 0x70, 0x49, 0x64, 0x73, 0x22, 0x5b, 0x0a, 0x05, 0x47,
|
||||
0x72, 0x6f, 0x75, 0x70, 0x12, 0x18, 0x0a, 0x07, 0x76, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x18,
|
||||
0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x76, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x12, 0x0e,
|
||||
0x0a, 0x02, 0x69, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x02, 0x69, 0x64, 0x12, 0x12,
|
||||
0x0a, 0x04, 0x6e, 0x61, 0x6d, 0x65, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x6e, 0x61,
|
||||
0x6d, 0x65, 0x12, 0x14, 0x0a, 0x05, 0x65, 0x6d, 0x61, 0x69, 0x6c, 0x18, 0x04, 0x20, 0x01, 0x28,
|
||||
0x09, 0x52, 0x05, 0x65, 0x6d, 0x61, 0x69, 0x6c, 0x42, 0x31, 0x5a, 0x2f, 0x67, 0x69, 0x74, 0x68,
|
||||
0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x70, 0x6f, 0x6d, 0x65, 0x72, 0x69, 0x75, 0x6d, 0x2f,
|
||||
0x70, 0x6f, 0x6d, 0x65, 0x72, 0x69, 0x75, 0x6d, 0x2f, 0x70, 0x6b, 0x67, 0x2f, 0x67, 0x72, 0x70,
|
||||
0x63, 0x2f, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, 0x6f, 0x72, 0x79, 0x62, 0x06, 0x70, 0x72, 0x6f,
|
||||
0x74, 0x6f, 0x33,
|
||||
}
|
||||
|
||||
var (
|
||||
|
@ -116,9 +193,10 @@ func file_directory_proto_rawDescGZIP() []byte {
|
|||
return file_directory_proto_rawDescData
|
||||
}
|
||||
|
||||
var file_directory_proto_msgTypes = make([]protoimpl.MessageInfo, 1)
|
||||
var file_directory_proto_msgTypes = make([]protoimpl.MessageInfo, 2)
|
||||
var file_directory_proto_goTypes = []interface{}{
|
||||
(*User)(nil), // 0: directory.User
|
||||
(*Group)(nil), // 1: directory.Group
|
||||
}
|
||||
var file_directory_proto_depIdxs = []int32{
|
||||
0, // [0:0] is the sub-list for method output_type
|
||||
|
@ -146,6 +224,18 @@ func file_directory_proto_init() {
|
|||
return nil
|
||||
}
|
||||
}
|
||||
file_directory_proto_msgTypes[1].Exporter = func(v interface{}, i int) interface{} {
|
||||
switch v := v.(*Group); i {
|
||||
case 0:
|
||||
return &v.state
|
||||
case 1:
|
||||
return &v.sizeCache
|
||||
case 2:
|
||||
return &v.unknownFields
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
type x struct{}
|
||||
out := protoimpl.TypeBuilder{
|
||||
|
@ -153,7 +243,7 @@ func file_directory_proto_init() {
|
|||
GoPackagePath: reflect.TypeOf(x{}).PkgPath(),
|
||||
RawDescriptor: file_directory_proto_rawDesc,
|
||||
NumEnums: 0,
|
||||
NumMessages: 1,
|
||||
NumMessages: 2,
|
||||
NumExtensions: 0,
|
||||
NumServices: 0,
|
||||
},
|
||||
|
|
|
@ -6,5 +6,12 @@ option go_package = "github.com/pomerium/pomerium/pkg/grpc/directory";
|
|||
message User {
|
||||
string version = 1;
|
||||
string id = 2;
|
||||
repeated string groups = 3;
|
||||
repeated string group_ids = 3;
|
||||
}
|
||||
|
||||
message Group {
|
||||
string version = 1;
|
||||
string id = 2;
|
||||
string name = 3;
|
||||
string email = 4;
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue