mirror of
https://github.com/pomerium/pomerium.git
synced 2025-05-22 13:37:19 +02:00
directory: add explicit RefreshUser endpoint for faster sync (#1460)
* directory: add explicit RefreshUser endpoint for faster sync * add test * implement azure * update api call * add test for azure User * implement github * implement AccessToken, gitlab * implement okta * implement onelogin * fix test * fix inconsistent test * implement auth0
This commit is contained in:
parent
9b39deabd8
commit
aa731ae068
23 changed files with 1405 additions and 179 deletions
|
@ -83,6 +83,32 @@ func New(options ...Option) *Provider {
|
|||
}
|
||||
}
|
||||
|
||||
// User returns the user record for the given id.
|
||||
func (p *Provider) User(ctx context.Context, userID, accessToken string) (*directory.User, error) {
|
||||
_, providerUserID := databroker.FromUserID(userID)
|
||||
du := &directory.User{
|
||||
Id: userID,
|
||||
}
|
||||
|
||||
au, err := p.getUser(ctx, providerUserID, accessToken)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
du.DisplayName = au.Name
|
||||
du.Email = au.Email
|
||||
|
||||
groups, err := p.listGroups(ctx, accessToken)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for _, g := range groups {
|
||||
du.GroupIds = append(du.GroupIds, g.Id)
|
||||
}
|
||||
sort.Strings(du.GroupIds)
|
||||
|
||||
return du, nil
|
||||
}
|
||||
|
||||
// UserGroups gets the directory user groups for gitlab.
|
||||
func (p *Provider) UserGroups(ctx context.Context) ([]*directory.Group, []*directory.User, error) {
|
||||
if p.cfg.serviceAccount == nil {
|
||||
|
@ -91,7 +117,7 @@ func (p *Provider) UserGroups(ctx context.Context) ([]*directory.Group, []*direc
|
|||
|
||||
p.log.Info().Msg("getting user groups")
|
||||
|
||||
groups, err := p.listGroups(ctx)
|
||||
groups, err := p.listGroups(ctx, "")
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
@ -129,8 +155,20 @@ func (p *Provider) UserGroups(ctx context.Context) ([]*directory.Group, []*direc
|
|||
return groups, users, nil
|
||||
}
|
||||
|
||||
func (p *Provider) getUser(ctx context.Context, userID string, accessToken string) (*apiUserObject, error) {
|
||||
apiURL := p.cfg.url.ResolveReference(&url.URL{
|
||||
Path: fmt.Sprintf("/api/v4/users/%s", userID),
|
||||
}).String()
|
||||
var result apiUserObject
|
||||
_, err := p.api(ctx, accessToken, apiURL, &result)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("gitlab: error querying user: %w", err)
|
||||
}
|
||||
return &result, nil
|
||||
}
|
||||
|
||||
// listGroups returns a map, with key is group ID, element is group name.
|
||||
func (p *Provider) listGroups(ctx context.Context) ([]*directory.Group, error) {
|
||||
func (p *Provider) listGroups(ctx context.Context, accessToken string) ([]*directory.Group, error) {
|
||||
nextURL := p.cfg.url.ResolveReference(&url.URL{
|
||||
Path: "/api/v4/groups",
|
||||
}).String()
|
||||
|
@ -140,7 +178,7 @@ func (p *Provider) listGroups(ctx context.Context) ([]*directory.Group, error) {
|
|||
ID int `json:"id"`
|
||||
Name string `json:"name"`
|
||||
}
|
||||
hdrs, err := p.apiGet(ctx, nextURL, &result)
|
||||
hdrs, err := p.api(ctx, accessToken, nextURL, &result)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("gitlab: error querying groups: %w", err)
|
||||
}
|
||||
|
@ -163,7 +201,7 @@ func (p *Provider) listGroupMembers(ctx context.Context, groupID string) (users
|
|||
}).String()
|
||||
for nextURL != "" {
|
||||
var result []apiUserObject
|
||||
hdrs, err := p.apiGet(ctx, nextURL, &result)
|
||||
hdrs, err := p.api(ctx, "", nextURL, &result)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("gitlab: error querying group members: %w", err)
|
||||
}
|
||||
|
@ -174,14 +212,18 @@ func (p *Provider) listGroupMembers(ctx context.Context, groupID string) (users
|
|||
return users, nil
|
||||
}
|
||||
|
||||
func (p *Provider) apiGet(ctx context.Context, uri string, out interface{}) (http.Header, error) {
|
||||
func (p *Provider) api(ctx context.Context, accessToken string, uri string, out interface{}) (http.Header, error) {
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", uri, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("gitlab: failed to create HTTP request: %w", err)
|
||||
}
|
||||
req.Header.Set("Accept", "application/json")
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("PRIVATE-TOKEN", p.cfg.serviceAccount.PrivateToken)
|
||||
if accessToken != "" {
|
||||
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", accessToken))
|
||||
} else {
|
||||
req.Header.Set("PRIVATE-TOKEN", p.cfg.serviceAccount.PrivateToken)
|
||||
}
|
||||
|
||||
res, err := p.cfg.httpClient.Do(req)
|
||||
if err != nil {
|
||||
|
@ -190,7 +232,7 @@ func (p *Provider) apiGet(ctx context.Context, uri string, out interface{}) (htt
|
|||
defer res.Body.Close()
|
||||
|
||||
if res.StatusCode/100 != 2 {
|
||||
return nil, fmt.Errorf("gitlab: error query api status_code=%d: %s", res.StatusCode, res.Status)
|
||||
return nil, fmt.Errorf("gitlab: error querying api url=%s status_code=%d: %s", uri, res.StatusCode, res.Status)
|
||||
}
|
||||
|
||||
err = json.NewDecoder(res.Body).Decode(out)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue