mirror of
https://github.com/pomerium/pomerium.git
synced 2025-08-03 16:59:22 +02:00
directory: additional user info (#1467)
* directory: support additional user information * implement github * implement gitlab * implement onelogin * implement okta * rename to display name * implement google * fill in properties * fix azure email parsing * fix tests, lint * fix onelogin tests * fix gitlab/github tests
This commit is contained in:
parent
88580cf2fb
commit
3e86d2f9bf
13 changed files with 339 additions and 165 deletions
|
@ -26,8 +26,21 @@ import (
|
|||
// Name is the provider name.
|
||||
const Name = "okta"
|
||||
|
||||
// Okta use ISO-8601, see https://developer.okta.com/docs/reference/api-overview/#media-types
|
||||
const filterDateFormat = "2006-01-02T15:04:05.999Z"
|
||||
const (
|
||||
// Okta use ISO-8601, see https://developer.okta.com/docs/reference/api-overview/#media-types
|
||||
filterDateFormat = "2006-01-02T15:04:05.999Z"
|
||||
|
||||
batchSize = 200
|
||||
readLimit = 100 * 1024
|
||||
httpSuccessClass = 2
|
||||
)
|
||||
|
||||
// Errors.
|
||||
var (
|
||||
ErrAPIKeyRequired = errors.New("okta: api_key is required")
|
||||
ErrServiceAccountNotDefined = errors.New("okta: service account not defined")
|
||||
ErrProviderURLNotDefined = errors.New("okta: provider url not defined")
|
||||
)
|
||||
|
||||
type config struct {
|
||||
batchSize int
|
||||
|
@ -69,11 +82,12 @@ func WithServiceAccount(serviceAccount *ServiceAccount) Option {
|
|||
|
||||
func getConfig(options ...Option) *config {
|
||||
cfg := new(config)
|
||||
WithBatchSize(200)(cfg)
|
||||
WithBatchSize(batchSize)(cfg)
|
||||
WithHTTPClient(http.DefaultClient)(cfg)
|
||||
for _, option := range options {
|
||||
option(cfg)
|
||||
}
|
||||
|
||||
return cfg
|
||||
}
|
||||
|
||||
|
@ -98,13 +112,13 @@ func New(options ...Option) *Provider {
|
|||
// https://developer.okta.com/docs/reference/api/users/#get-user-s-groups
|
||||
func (p *Provider) UserGroups(ctx context.Context) ([]*directory.Group, []*directory.User, error) {
|
||||
if p.cfg.serviceAccount == nil {
|
||||
return nil, nil, fmt.Errorf("okta: service account not defined")
|
||||
return nil, nil, ErrServiceAccountNotDefined
|
||||
}
|
||||
|
||||
p.log.Info().Msg("getting user groups")
|
||||
|
||||
if p.cfg.providerURL == nil {
|
||||
return nil, nil, fmt.Errorf("okta: provider url not defined")
|
||||
return nil, nil, ErrProviderURLNotDefined
|
||||
}
|
||||
|
||||
groups, err := p.getGroups(ctx)
|
||||
|
@ -112,10 +126,11 @@ func (p *Provider) UserGroups(ctx context.Context) ([]*directory.Group, []*direc
|
|||
return nil, nil, err
|
||||
}
|
||||
|
||||
userLookup := map[string]apiUserObject{}
|
||||
userIDToGroups := map[string][]string{}
|
||||
for i := 0; i < len(groups); i++ {
|
||||
group := groups[i]
|
||||
ids, err := p.getGroupMemberIDs(ctx, group.Id)
|
||||
users, err := p.getGroupMembers(ctx, group.Id)
|
||||
|
||||
// if we get a 404 on the member query, it means the group doesn't exist, so we should remove it from
|
||||
// the cached lookup and the local groups list
|
||||
|
@ -131,17 +146,21 @@ func (p *Provider) UserGroups(ctx context.Context) ([]*directory.Group, []*direc
|
|||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
for _, id := range ids {
|
||||
userIDToGroups[id] = append(userIDToGroups[id], group.Id)
|
||||
for _, u := range users {
|
||||
userIDToGroups[u.ID] = append(userIDToGroups[u.ID], group.Id)
|
||||
userLookup[u.ID] = u
|
||||
}
|
||||
}
|
||||
|
||||
var users []*directory.User
|
||||
for userID, groups := range userIDToGroups {
|
||||
for _, u := range userLookup {
|
||||
groups := userIDToGroups[u.ID]
|
||||
sort.Strings(groups)
|
||||
users = append(users, &directory.User{
|
||||
Id: databroker.GetUserID(Name, userID),
|
||||
GroupIds: groups,
|
||||
Id: databroker.GetUserID(Name, u.ID),
|
||||
GroupIds: groups,
|
||||
DisplayName: u.Profile.FirstName + " " + u.Profile.LastName,
|
||||
Email: u.Profile.Email,
|
||||
})
|
||||
}
|
||||
sort.Slice(users, func(i, j int) bool {
|
||||
|
@ -201,30 +220,23 @@ func (p *Provider) getGroups(ctx context.Context) ([]*directory.Group, error) {
|
|||
return groups, nil
|
||||
}
|
||||
|
||||
func (p *Provider) getGroupMemberIDs(ctx context.Context, groupID string) ([]string, error) {
|
||||
var emails []string
|
||||
|
||||
func (p *Provider) getGroupMembers(ctx context.Context, groupID string) (users []apiUserObject, err error) {
|
||||
usersURL := p.cfg.providerURL.ResolveReference(&url.URL{
|
||||
Path: fmt.Sprintf("/api/v1/groups/%s/users", groupID),
|
||||
RawQuery: fmt.Sprintf("limit=%d", p.cfg.batchSize),
|
||||
}).String()
|
||||
for usersURL != "" {
|
||||
var out []struct {
|
||||
ID string `json:"id"`
|
||||
}
|
||||
var out []apiUserObject
|
||||
hdrs, err := p.apiGet(ctx, usersURL, &out)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("okta: error querying for groups: %w", err)
|
||||
}
|
||||
|
||||
for _, el := range out {
|
||||
emails = append(emails, el.ID)
|
||||
}
|
||||
|
||||
users = append(users, out...)
|
||||
usersURL = getNextLink(hdrs)
|
||||
}
|
||||
|
||||
return emails, nil
|
||||
return users, nil
|
||||
}
|
||||
|
||||
func (p *Provider) apiGet(ctx context.Context, uri string, out interface{}) (http.Header, error) {
|
||||
|
@ -250,7 +262,7 @@ func (p *Provider) apiGet(ctx context.Context, uri string, out interface{}) (htt
|
|||
}
|
||||
continue
|
||||
}
|
||||
if res.StatusCode/100 != 2 {
|
||||
if res.StatusCode/100 != httpSuccessClass {
|
||||
return nil, newAPIError(res)
|
||||
}
|
||||
if err := json.NewDecoder(res.Body).Decode(out); err != nil {
|
||||
|
@ -287,7 +299,7 @@ func ParseServiceAccount(rawServiceAccount string) (*ServiceAccount, error) {
|
|||
}
|
||||
|
||||
if serviceAccount.APIKey == "" {
|
||||
return nil, fmt.Errorf("api_key is required")
|
||||
return nil, ErrAPIKeyRequired
|
||||
}
|
||||
|
||||
return &serviceAccount, nil
|
||||
|
@ -308,7 +320,7 @@ func newAPIError(res *http.Response) error {
|
|||
if res == nil {
|
||||
return nil
|
||||
}
|
||||
buf, _ := ioutil.ReadAll(io.LimitReader(res.Body, 100*1024)) // limit to 100kb
|
||||
buf, _ := ioutil.ReadAll(io.LimitReader(res.Body, readLimit)) // limit to 100kb
|
||||
|
||||
err := &APIError{
|
||||
HTTPStatusCode: res.StatusCode,
|
||||
|
@ -321,3 +333,12 @@ func newAPIError(res *http.Response) error {
|
|||
func (err *APIError) Error() string {
|
||||
return fmt.Sprintf("okta: error querying API, status_code=%d: %s", err.HTTPStatusCode, err.Body)
|
||||
}
|
||||
|
||||
type apiUserObject struct {
|
||||
ID string `json:"id"`
|
||||
Profile struct {
|
||||
FirstName string `json:"firstName"`
|
||||
LastName string `json:"lastName"`
|
||||
Email string `json:"email"`
|
||||
} `json:"profile"`
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue