azure: incremental sync (#1471)

* azure: incremental sync

* identity manager: fix directory sync timing

* on unauthorized, reset token

* support querying the user api

* update name

* pull out constants
This commit is contained in:
Caleb Doxsey 2020-09-30 08:18:04 -06:00 committed by GitHub
parent 3e86d2f9bf
commit 697be04c6f
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 309 additions and 132 deletions

View file

@ -9,13 +9,11 @@ import (
"io"
"net/http"
"net/url"
"sort"
"strings"
"sync"
"golang.org/x/oauth2"
"github.com/pomerium/pomerium/pkg/grpc/databroker"
"github.com/pomerium/pomerium/pkg/grpc/directory"
)
@ -88,6 +86,7 @@ func getConfig(options ...Option) *config {
// A Provider is a directory implementation using azure active directory.
type Provider struct {
cfg *config
dc *deltaCollection
mu sync.RWMutex
token *oauth2.Token
@ -95,9 +94,11 @@ type Provider struct {
// New creates a new Provider.
func New(options ...Option) *Provider {
return &Provider{
p := &Provider{
cfg: getConfig(options...),
}
p.dc = newDeltaCollection(p)
return p
}
// UserGroups returns the directory users in azure active directory.
@ -106,100 +107,15 @@ func (p *Provider) UserGroups(ctx context.Context) ([]*directory.Group, []*direc
return nil, nil, fmt.Errorf("azure: service account not defined")
}
groups, err := p.listGroups(ctx)
err := p.dc.Sync(ctx)
if err != nil {
return nil, nil, err
}
userLookup := map[string]apiDirectoryObject{}
groupLookup := newGroupLookup()
for _, group := range groups {
groupIDs, users, err := p.listGroupMembers(ctx, group.Id)
if err != nil {
return nil, nil, err
}
userIDs := make([]string, 0, len(users))
for _, u := range users {
userIDs = append(userIDs, u.ID)
userLookup[u.ID] = u
}
groupLookup.addGroup(group.Id, groupIDs, userIDs)
}
users := make([]*directory.User, 0, len(userLookup))
for _, u := range userLookup {
users = append(users, &directory.User{
Id: databroker.GetUserID(Name, u.ID),
GroupIds: groupLookup.getGroupIDsForUser(u.ID),
DisplayName: u.DisplayName,
Email: u.getEmail(),
})
}
sort.Slice(users, func(i, j int) bool {
return users[i].GetId() < users[j].GetId()
})
groups, users := p.dc.CurrentUserGroups()
return groups, users, nil
}
// listGroups returns a map, with key is group ID, element is group name.
func (p *Provider) listGroups(ctx context.Context) ([]*directory.Group, error) {
nextURL := p.cfg.graphURL.ResolveReference(&url.URL{
Path: "/v1.0/groups",
}).String()
var groups []*directory.Group
for nextURL != "" {
var result struct {
Value []struct {
ID string `json:"id"`
DisplayName string `json:"displayName"`
} `json:"value"`
NextLink string `json:"@odata.nextLink"`
}
err := p.api(ctx, "GET", nextURL, nil, &result)
if err != nil {
return nil, err
}
for _, v := range result.Value {
groups = append(groups, &directory.Group{
Id: v.ID,
Name: v.DisplayName,
})
}
nextURL = result.NextLink
}
return groups, nil
}
func (p *Provider) listGroupMembers(ctx context.Context, groupID string) (groupIDs []string, users []apiDirectoryObject, err error) {
nextURL := p.cfg.graphURL.ResolveReference(&url.URL{
Path: fmt.Sprintf("/v1.0/groups/%s/members", groupID),
}).String()
for nextURL != "" {
var result struct {
Value []apiDirectoryObject `json:"value"`
NextLink string `json:"@odata.nextLink"`
}
err := p.api(ctx, "GET", nextURL, nil, &result)
if err != nil {
return nil, nil, err
}
for _, v := range result.Value {
switch v.Type {
case "#microsoft.graph.group":
groupIDs = append(groupIDs, v.ID)
case "#microsoft.graph.user":
users = append(users, v)
}
}
nextURL = result.NextLink
}
return groupIDs, users, nil
}
func (p *Provider) api(ctx context.Context, method, url string, body io.Reader, out interface{}) error {
token, err := p.getToken(ctx)
if err != nil {
@ -219,6 +135,13 @@ func (p *Provider) api(ctx context.Context, method, url string, body io.Reader,
}
defer res.Body.Close()
// if we get unauthorized, invalidate the token
if res.StatusCode == http.StatusUnauthorized {
p.mu.Lock()
p.token = nil
p.mu.Unlock()
}
if res.StatusCode/100 != 2 {
return fmt.Errorf("azure: error querying api: %s", res.Status)
}
@ -359,31 +282,3 @@ func parseDirectoryIDFromURL(providerURL string) (string, error) {
return pathParts[1], nil
}
type apiDirectoryObject struct {
Type string `json:"@odata.type"`
ID string `json:"id"`
Mail string `json:"mail"`
DisplayName string `json:"displayName"`
UserPrincipalName string `json:"userPrincipalName"`
}
func (obj apiDirectoryObject) getEmail() string {
if obj.Mail != "" {
return obj.Mail
}
// AD often doesn't have the email address returned, but we can parse it from the UPN
// UPN looks like:
// cdoxsey_pomerium.com#EXT#@cdoxseypomerium.onmicrosoft.com
email := obj.UserPrincipalName
if idx := strings.Index(email, "#EXT"); idx > 0 {
email = email[:idx]
}
// find the last _ and replace it with @
if idx := strings.LastIndex(email, "_"); idx > 0 {
email = email[:idx] + "@" + email[idx+1:]
}
return email
}