mirror of
https://github.com/pomerium/pomerium.git
synced 2025-06-01 02:12:50 +02:00
azure: incremental sync (#1471)
* azure: incremental sync * identity manager: fix directory sync timing * on unauthorized, reset token * support querying the user api * update name * pull out constants
This commit is contained in:
parent
3e86d2f9bf
commit
697be04c6f
3 changed files with 309 additions and 132 deletions
|
@ -9,13 +9,11 @@ import (
|
|||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"sort"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"golang.org/x/oauth2"
|
||||
|
||||
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/directory"
|
||||
)
|
||||
|
||||
|
@ -88,6 +86,7 @@ func getConfig(options ...Option) *config {
|
|||
// A Provider is a directory implementation using azure active directory.
|
||||
type Provider struct {
|
||||
cfg *config
|
||||
dc *deltaCollection
|
||||
|
||||
mu sync.RWMutex
|
||||
token *oauth2.Token
|
||||
|
@ -95,9 +94,11 @@ type Provider struct {
|
|||
|
||||
// New creates a new Provider.
|
||||
func New(options ...Option) *Provider {
|
||||
return &Provider{
|
||||
p := &Provider{
|
||||
cfg: getConfig(options...),
|
||||
}
|
||||
p.dc = newDeltaCollection(p)
|
||||
return p
|
||||
}
|
||||
|
||||
// UserGroups returns the directory users in azure active directory.
|
||||
|
@ -106,100 +107,15 @@ func (p *Provider) UserGroups(ctx context.Context) ([]*directory.Group, []*direc
|
|||
return nil, nil, fmt.Errorf("azure: service account not defined")
|
||||
}
|
||||
|
||||
groups, err := p.listGroups(ctx)
|
||||
err := p.dc.Sync(ctx)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
userLookup := map[string]apiDirectoryObject{}
|
||||
groupLookup := newGroupLookup()
|
||||
for _, group := range groups {
|
||||
groupIDs, users, err := p.listGroupMembers(ctx, group.Id)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
userIDs := make([]string, 0, len(users))
|
||||
for _, u := range users {
|
||||
userIDs = append(userIDs, u.ID)
|
||||
userLookup[u.ID] = u
|
||||
}
|
||||
groupLookup.addGroup(group.Id, groupIDs, userIDs)
|
||||
}
|
||||
|
||||
users := make([]*directory.User, 0, len(userLookup))
|
||||
for _, u := range userLookup {
|
||||
users = append(users, &directory.User{
|
||||
Id: databroker.GetUserID(Name, u.ID),
|
||||
GroupIds: groupLookup.getGroupIDsForUser(u.ID),
|
||||
DisplayName: u.DisplayName,
|
||||
Email: u.getEmail(),
|
||||
})
|
||||
}
|
||||
sort.Slice(users, func(i, j int) bool {
|
||||
return users[i].GetId() < users[j].GetId()
|
||||
})
|
||||
groups, users := p.dc.CurrentUserGroups()
|
||||
return groups, users, nil
|
||||
}
|
||||
|
||||
// listGroups returns a map, with key is group ID, element is group name.
|
||||
func (p *Provider) listGroups(ctx context.Context) ([]*directory.Group, error) {
|
||||
nextURL := p.cfg.graphURL.ResolveReference(&url.URL{
|
||||
Path: "/v1.0/groups",
|
||||
}).String()
|
||||
|
||||
var groups []*directory.Group
|
||||
for nextURL != "" {
|
||||
var result struct {
|
||||
Value []struct {
|
||||
ID string `json:"id"`
|
||||
DisplayName string `json:"displayName"`
|
||||
} `json:"value"`
|
||||
NextLink string `json:"@odata.nextLink"`
|
||||
}
|
||||
err := p.api(ctx, "GET", nextURL, nil, &result)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for _, v := range result.Value {
|
||||
groups = append(groups, &directory.Group{
|
||||
Id: v.ID,
|
||||
Name: v.DisplayName,
|
||||
})
|
||||
}
|
||||
nextURL = result.NextLink
|
||||
}
|
||||
|
||||
return groups, nil
|
||||
}
|
||||
|
||||
func (p *Provider) listGroupMembers(ctx context.Context, groupID string) (groupIDs []string, users []apiDirectoryObject, err error) {
|
||||
nextURL := p.cfg.graphURL.ResolveReference(&url.URL{
|
||||
Path: fmt.Sprintf("/v1.0/groups/%s/members", groupID),
|
||||
}).String()
|
||||
|
||||
for nextURL != "" {
|
||||
var result struct {
|
||||
Value []apiDirectoryObject `json:"value"`
|
||||
NextLink string `json:"@odata.nextLink"`
|
||||
}
|
||||
err := p.api(ctx, "GET", nextURL, nil, &result)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
for _, v := range result.Value {
|
||||
switch v.Type {
|
||||
case "#microsoft.graph.group":
|
||||
groupIDs = append(groupIDs, v.ID)
|
||||
case "#microsoft.graph.user":
|
||||
users = append(users, v)
|
||||
}
|
||||
}
|
||||
nextURL = result.NextLink
|
||||
}
|
||||
|
||||
return groupIDs, users, nil
|
||||
}
|
||||
|
||||
func (p *Provider) api(ctx context.Context, method, url string, body io.Reader, out interface{}) error {
|
||||
token, err := p.getToken(ctx)
|
||||
if err != nil {
|
||||
|
@ -219,6 +135,13 @@ func (p *Provider) api(ctx context.Context, method, url string, body io.Reader,
|
|||
}
|
||||
defer res.Body.Close()
|
||||
|
||||
// if we get unauthorized, invalidate the token
|
||||
if res.StatusCode == http.StatusUnauthorized {
|
||||
p.mu.Lock()
|
||||
p.token = nil
|
||||
p.mu.Unlock()
|
||||
}
|
||||
|
||||
if res.StatusCode/100 != 2 {
|
||||
return fmt.Errorf("azure: error querying api: %s", res.Status)
|
||||
}
|
||||
|
@ -359,31 +282,3 @@ func parseDirectoryIDFromURL(providerURL string) (string, error) {
|
|||
|
||||
return pathParts[1], nil
|
||||
}
|
||||
|
||||
type apiDirectoryObject struct {
|
||||
Type string `json:"@odata.type"`
|
||||
ID string `json:"id"`
|
||||
Mail string `json:"mail"`
|
||||
DisplayName string `json:"displayName"`
|
||||
UserPrincipalName string `json:"userPrincipalName"`
|
||||
}
|
||||
|
||||
func (obj apiDirectoryObject) getEmail() string {
|
||||
if obj.Mail != "" {
|
||||
return obj.Mail
|
||||
}
|
||||
|
||||
// AD often doesn't have the email address returned, but we can parse it from the UPN
|
||||
|
||||
// UPN looks like:
|
||||
// cdoxsey_pomerium.com#EXT#@cdoxseypomerium.onmicrosoft.com
|
||||
email := obj.UserPrincipalName
|
||||
if idx := strings.Index(email, "#EXT"); idx > 0 {
|
||||
email = email[:idx]
|
||||
}
|
||||
// find the last _ and replace it with @
|
||||
if idx := strings.LastIndex(email, "_"); idx > 0 {
|
||||
email = email[:idx] + "@" + email[idx+1:]
|
||||
}
|
||||
return email
|
||||
}
|
||||
|
|
|
@ -44,26 +44,34 @@ func newMockAPI(t *testing.T, srv *httptest.Server) http.Handler {
|
|||
next.ServeHTTP(w, r)
|
||||
})
|
||||
})
|
||||
r.Get("/groups", func(w http.ResponseWriter, r *http.Request) {
|
||||
r.Get("/groups/delta", func(w http.ResponseWriter, r *http.Request) {
|
||||
_ = json.NewEncoder(w).Encode(M{
|
||||
"value": []M{
|
||||
{"id": "admin", "displayName": "Admin Group"},
|
||||
{"id": "test", "displayName": "Test Group"},
|
||||
{
|
||||
"id": "admin",
|
||||
"displayName": "Admin Group",
|
||||
"members@delta": []M{
|
||||
{"@odata.type": "#microsoft.graph.user", "id": "user-1"},
|
||||
},
|
||||
},
|
||||
{
|
||||
"id": "test",
|
||||
"displayName": "Test Group",
|
||||
"members@delta": []M{
|
||||
{"@odata.type": "#microsoft.graph.user", "id": "user-2"},
|
||||
{"@odata.type": "#microsoft.graph.user", "id": "user-3"},
|
||||
},
|
||||
},
|
||||
},
|
||||
})
|
||||
})
|
||||
r.Get("/groups/{group_name}/members", func(w http.ResponseWriter, r *http.Request) {
|
||||
members := map[string][]M{
|
||||
"admin": {
|
||||
{"@odata.type": "#microsoft.graph.user", "id": "user-1", "displayName": "User 1", "mail": "user1@example.com"},
|
||||
},
|
||||
"test": {
|
||||
{"@odata.type": "#microsoft.graph.user", "id": "user-2", "displayName": "User 2", "mail": "user2@example.com"},
|
||||
{"@odata.type": "#microsoft.graph.user", "id": "user-3", "displayName": "User 3", "userPrincipalName": "user3_example.com#EXT#@user3example.onmicrosoft.com"},
|
||||
},
|
||||
}
|
||||
r.Get("/users/delta", func(w http.ResponseWriter, r *http.Request) {
|
||||
_ = json.NewEncoder(w).Encode(M{
|
||||
"value": members[chi.URLParam(r, "group_name")],
|
||||
"value": []M{
|
||||
{"id": "user-1", "displayName": "User 1", "mail": "user1@example.com"},
|
||||
{"id": "user-2", "displayName": "User 2", "mail": "user2@example.com"},
|
||||
{"id": "user-3", "displayName": "User 3", "userPrincipalName": "user3_example.com#EXT#@user3example.onmicrosoft.com"},
|
||||
},
|
||||
})
|
||||
})
|
||||
})
|
||||
|
|
274
internal/directory/azure/delta.go
Normal file
274
internal/directory/azure/delta.go
Normal file
|
@ -0,0 +1,274 @@
|
|||
package azure
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/url"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/directory"
|
||||
)
|
||||
|
||||
const (
|
||||
groupsDeltaPath = "/v1.0/groups/delta"
|
||||
usersDeltaPath = "/v1.0/users/delta"
|
||||
)
|
||||
|
||||
type (
|
||||
deltaCollection struct {
|
||||
provider *Provider
|
||||
groups map[string]deltaGroup
|
||||
groupDeltaLink string
|
||||
users map[string]deltaUser
|
||||
userDeltaLink string
|
||||
}
|
||||
deltaGroup struct {
|
||||
id string
|
||||
displayName string
|
||||
members map[string]deltaGroupMember
|
||||
}
|
||||
deltaGroupMember struct {
|
||||
memberType string
|
||||
id string
|
||||
}
|
||||
deltaUser struct {
|
||||
id string
|
||||
displayName string
|
||||
email string
|
||||
}
|
||||
)
|
||||
|
||||
func newDeltaCollection(p *Provider) *deltaCollection {
|
||||
return &deltaCollection{
|
||||
provider: p,
|
||||
groups: make(map[string]deltaGroup),
|
||||
users: make(map[string]deltaUser),
|
||||
}
|
||||
}
|
||||
|
||||
// Sync syncs the latest changes from the microsoft graph API.
|
||||
//
|
||||
// Synchronization is based on https://docs.microsoft.com/en-us/graph/delta-query-groups
|
||||
//
|
||||
// It involves 4 steps:
|
||||
//
|
||||
// 1. an initial request to /v1.0/groups/delta
|
||||
// 2. one or more requests to /v1.0/groups/delta?$skiptoken=..., which comes from the @odata.nextLink
|
||||
// 3. a final response with @odata.deltaLink
|
||||
// 4. on the next call to sync, starting at @odata.deltaLink
|
||||
//
|
||||
// Only the changed groups/members are returned. Removed groups/members have an @removed property.
|
||||
func (dc *deltaCollection) Sync(ctx context.Context) error {
|
||||
if err := dc.syncGroups(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := dc.syncUsers(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (dc *deltaCollection) syncGroups(ctx context.Context) error {
|
||||
apiURL := dc.groupDeltaLink
|
||||
|
||||
// if no delta link is set yet, start the initial fill
|
||||
if apiURL == "" {
|
||||
apiURL = dc.provider.cfg.graphURL.ResolveReference(&url.URL{
|
||||
Path: groupsDeltaPath,
|
||||
RawQuery: url.Values{
|
||||
"$select": {"displayName,members"},
|
||||
}.Encode(),
|
||||
}).String()
|
||||
}
|
||||
|
||||
for {
|
||||
var res groupsDeltaResponse
|
||||
err := dc.provider.api(ctx, "GET", apiURL, nil, &res)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, g := range res.Value {
|
||||
// if removed exists, the group was deleted
|
||||
if g.Removed != nil {
|
||||
delete(dc.groups, g.ID)
|
||||
continue
|
||||
}
|
||||
|
||||
gdg := dc.groups[g.ID]
|
||||
gdg.id = g.ID
|
||||
gdg.displayName = g.DisplayName
|
||||
if gdg.members == nil {
|
||||
gdg.members = make(map[string]deltaGroupMember)
|
||||
}
|
||||
for _, m := range g.Members {
|
||||
// if removed exists, the member was deleted
|
||||
if m.Removed != nil {
|
||||
delete(gdg.members, m.ID)
|
||||
continue
|
||||
}
|
||||
|
||||
gdg.members[m.ID] = deltaGroupMember{
|
||||
memberType: m.Type,
|
||||
id: m.ID,
|
||||
}
|
||||
}
|
||||
dc.groups[g.ID] = gdg
|
||||
}
|
||||
|
||||
switch {
|
||||
case res.NextLink != "":
|
||||
// when there's a next link we will query again
|
||||
apiURL = res.NextLink
|
||||
default:
|
||||
// once no next link is set anymore, we save the delta link and return
|
||||
dc.groupDeltaLink = res.DeltaLink
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (dc *deltaCollection) syncUsers(ctx context.Context) error {
|
||||
apiURL := dc.userDeltaLink
|
||||
|
||||
// if no delta link is set yet, start the initial fill
|
||||
if apiURL == "" {
|
||||
apiURL = dc.provider.cfg.graphURL.ResolveReference(&url.URL{
|
||||
Path: usersDeltaPath,
|
||||
RawQuery: url.Values{
|
||||
"$select": {"displayName,mail,userPrincipalName"},
|
||||
}.Encode(),
|
||||
}).String()
|
||||
}
|
||||
|
||||
for {
|
||||
var res usersDeltaResponse
|
||||
err := dc.provider.api(ctx, "GET", apiURL, nil, &res)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, u := range res.Value {
|
||||
// if removed exists, the user was deleted
|
||||
if u.Removed != nil {
|
||||
delete(dc.users, u.ID)
|
||||
continue
|
||||
}
|
||||
dc.users[u.ID] = deltaUser{
|
||||
id: u.ID,
|
||||
displayName: u.DisplayName,
|
||||
email: u.getEmail(),
|
||||
}
|
||||
}
|
||||
|
||||
switch {
|
||||
case res.NextLink != "":
|
||||
// when there's a next link we will query again
|
||||
apiURL = res.NextLink
|
||||
default:
|
||||
// once no next link is set anymore, we save the delta link and return
|
||||
dc.userDeltaLink = res.DeltaLink
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// CurrentUserGroups returns the directory groups and users based on the current state.
|
||||
func (dc *deltaCollection) CurrentUserGroups() ([]*directory.Group, []*directory.User) {
|
||||
var groups []*directory.Group
|
||||
|
||||
groupLookup := newGroupLookup()
|
||||
for _, g := range dc.groups {
|
||||
groups = append(groups, &directory.Group{
|
||||
Id: g.id,
|
||||
Name: g.displayName,
|
||||
})
|
||||
var groupIDs, userIDs []string
|
||||
for _, m := range g.members {
|
||||
switch m.memberType {
|
||||
case "#microsoft.graph.group":
|
||||
groupIDs = append(groupIDs, m.id)
|
||||
case "#microsoft.graph.user":
|
||||
userIDs = append(userIDs, m.id)
|
||||
}
|
||||
}
|
||||
groupLookup.addGroup(g.id, groupIDs, userIDs)
|
||||
}
|
||||
|
||||
var users []*directory.User
|
||||
for _, u := range dc.users {
|
||||
users = append(users, &directory.User{
|
||||
Id: databroker.GetUserID(Name, u.id),
|
||||
GroupIds: groupLookup.getGroupIDsForUser(u.id),
|
||||
DisplayName: u.displayName,
|
||||
Email: u.email,
|
||||
})
|
||||
}
|
||||
sort.Slice(users, func(i, j int) bool {
|
||||
return users[i].GetId() < users[j].GetId()
|
||||
})
|
||||
|
||||
return groups, users
|
||||
}
|
||||
|
||||
// API types for the microsoft graph API.
|
||||
type (
|
||||
deltaResponseRemoved struct {
|
||||
Reason string `json:"reason"`
|
||||
}
|
||||
|
||||
groupsDeltaResponse struct {
|
||||
Context string `json:"@odata.context"`
|
||||
NextLink string `json:"@odata.nextLink,omitempty"`
|
||||
DeltaLink string `json:"@odata.deltaLink,omitempty"`
|
||||
Value []groupsDeltaResponseGroup `json:"value"`
|
||||
}
|
||||
groupsDeltaResponseGroup struct {
|
||||
ID string `json:"id"`
|
||||
DisplayName string `json:"displayName"`
|
||||
Members []groupsDeltaResponseGroupMember `json:"members@delta"`
|
||||
Removed *deltaResponseRemoved `json:"@removed,omitempty"`
|
||||
}
|
||||
groupsDeltaResponseGroupMember struct {
|
||||
Type string `json:"@odata.type"`
|
||||
ID string `json:"id"`
|
||||
Removed *deltaResponseRemoved `json:"@removed,omitempty"`
|
||||
}
|
||||
|
||||
usersDeltaResponse struct {
|
||||
Context string `json:"@odata.context"`
|
||||
NextLink string `json:"@odata.nextLink,omitempty"`
|
||||
DeltaLink string `json:"@odata.deltaLink,omitempty"`
|
||||
Value []usersDeltaResponseUser `json:"value"`
|
||||
}
|
||||
usersDeltaResponseUser struct {
|
||||
ID string `json:"id"`
|
||||
DisplayName string `json:"displayName"`
|
||||
Mail string `json:"mail"`
|
||||
UserPrincipalName string `json:"userPrincipalName"`
|
||||
Removed *deltaResponseRemoved `json:"@removed,omitempty"`
|
||||
}
|
||||
)
|
||||
|
||||
func (obj usersDeltaResponseUser) getEmail() string {
|
||||
if obj.Mail != "" {
|
||||
return obj.Mail
|
||||
}
|
||||
|
||||
// AD often doesn't have the email address returned, but we can parse it from the UPN
|
||||
|
||||
// UPN looks like:
|
||||
// cdoxsey_pomerium.com#EXT#@cdoxseypomerium.onmicrosoft.com
|
||||
email := obj.UserPrincipalName
|
||||
if idx := strings.Index(email, "#EXT"); idx > 0 {
|
||||
email = email[:idx]
|
||||
}
|
||||
// find the last _ and replace it with @
|
||||
if idx := strings.LastIndex(email, "_"); idx > 0 {
|
||||
email = email[:idx] + "@" + email[idx+1:]
|
||||
}
|
||||
return email
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue