mirror of
https://github.com/pomerium/pomerium.git
synced 2025-05-12 08:37:38 +02:00
directory: add explicit RefreshUser endpoint for faster sync (#1460)
* directory: add explicit RefreshUser endpoint for faster sync * add test * implement azure * update api call * add test for azure User * implement github * implement AccessToken, gitlab * implement okta * implement onelogin * fix test * fix inconsistent test * implement auth0
This commit is contained in:
parent
9b39deabd8
commit
aa731ae068
23 changed files with 1405 additions and 179 deletions
|
@ -2,6 +2,7 @@
|
|||
package github
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
|
@ -83,6 +84,47 @@ func New(options ...Option) *Provider {
|
|||
}
|
||||
}
|
||||
|
||||
// User returns the user record for the given id.
|
||||
func (p *Provider) User(ctx context.Context, userID, accessToken string) (*directory.User, error) {
|
||||
if p.cfg.serviceAccount == nil {
|
||||
return nil, fmt.Errorf("github: service account not defined")
|
||||
}
|
||||
|
||||
_, providerUserID := databroker.FromUserID(userID)
|
||||
du := &directory.User{
|
||||
Id: userID,
|
||||
}
|
||||
|
||||
au, err := p.getUser(ctx, providerUserID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
du.DisplayName = au.Name
|
||||
du.Email = au.Email
|
||||
|
||||
teamIDLookup := map[int]struct{}{}
|
||||
orgSlugs, err := p.listOrgs(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for _, orgSlug := range orgSlugs {
|
||||
teamIDs, err := p.listUserOrganizationTeams(ctx, userID, orgSlug)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for _, teamID := range teamIDs {
|
||||
teamIDLookup[teamID] = struct{}{}
|
||||
}
|
||||
}
|
||||
|
||||
for teamID := range teamIDLookup {
|
||||
du.GroupIds = append(du.GroupIds, strconv.Itoa(teamID))
|
||||
}
|
||||
sort.Strings(du.GroupIds)
|
||||
|
||||
return du, nil
|
||||
}
|
||||
|
||||
// UserGroups gets the directory user groups for github.
|
||||
func (p *Provider) UserGroups(ctx context.Context) ([]*directory.Group, []*directory.User, error) {
|
||||
if p.cfg.serviceAccount == nil {
|
||||
|
@ -230,6 +272,77 @@ func (p *Provider) getUser(ctx context.Context, userLogin string) (*apiUserObjec
|
|||
return &res, nil
|
||||
}
|
||||
|
||||
func (p *Provider) listUserOrganizationTeams(ctx context.Context, userSlug string, orgSlug string) ([]int, error) {
|
||||
// GitHub's Rest API doesn't have an easy way of querying this data, so we use the GraphQL API.
|
||||
|
||||
enc := func(obj interface{}) string {
|
||||
bs, _ := json.Marshal(obj)
|
||||
return string(bs)
|
||||
}
|
||||
const pageCount = 100
|
||||
|
||||
var teamIDs []int
|
||||
var cursor *string
|
||||
for {
|
||||
var res struct {
|
||||
Data struct {
|
||||
Organization struct {
|
||||
Teams struct {
|
||||
TotalCount int `json:"totalCount"`
|
||||
PageInfo struct {
|
||||
EndCursor string `json:"endCursor"`
|
||||
} `json:"pageInfo"`
|
||||
Edges []struct {
|
||||
Node struct {
|
||||
ID int `json:"id"`
|
||||
} `json:"node"`
|
||||
} `json:"edges"`
|
||||
} `json:"teams"`
|
||||
} `json:"organization"`
|
||||
} `json:"data"`
|
||||
}
|
||||
cursorStr := ""
|
||||
if cursor != nil {
|
||||
cursorStr = fmt.Sprintf(",%s", enc(*cursor))
|
||||
}
|
||||
q := fmt.Sprintf(`query {
|
||||
organization(login:%s) {
|
||||
teams(first:%s, userLogins:[%s] %s) {
|
||||
totalCount
|
||||
pageInfo {
|
||||
endCursor
|
||||
}
|
||||
edges {
|
||||
node {
|
||||
id
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}`, enc(orgSlug), enc(pageCount), enc(userSlug), cursorStr)
|
||||
_, err := p.graphql(ctx, q, &res)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if len(res.Data.Organization.Teams.Edges) == 0 {
|
||||
break
|
||||
}
|
||||
|
||||
for _, edge := range res.Data.Organization.Teams.Edges {
|
||||
teamIDs = append(teamIDs, edge.Node.ID)
|
||||
}
|
||||
|
||||
if len(teamIDs) >= res.Data.Organization.Teams.TotalCount {
|
||||
break
|
||||
}
|
||||
|
||||
cursor = &res.Data.Organization.Teams.PageInfo.EndCursor
|
||||
}
|
||||
|
||||
return teamIDs, nil
|
||||
}
|
||||
|
||||
func (p *Provider) api(ctx context.Context, apiURL string, out interface{}) (http.Header, error) {
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", apiURL, nil)
|
||||
if err != nil {
|
||||
|
@ -257,6 +370,41 @@ func (p *Provider) api(ctx context.Context, apiURL string, out interface{}) (htt
|
|||
return res.Header, nil
|
||||
}
|
||||
|
||||
func (p *Provider) graphql(ctx context.Context, query string, out interface{}) (http.Header, error) {
|
||||
apiURL := p.cfg.url.ResolveReference(&url.URL{
|
||||
Path: "/graphql",
|
||||
}).String()
|
||||
|
||||
bs, _ := json.Marshal(struct {
|
||||
Query string `json:"query"`
|
||||
}{query})
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", apiURL, bytes.NewReader(bs))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("github: failed to create http request: %w", err)
|
||||
}
|
||||
req.SetBasicAuth(p.cfg.serviceAccount.Username, p.cfg.serviceAccount.PersonalAccessToken)
|
||||
|
||||
res, err := p.cfg.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("github: failed to make http request: %w", err)
|
||||
}
|
||||
defer res.Body.Close()
|
||||
|
||||
if res.StatusCode/100 != 2 {
|
||||
return nil, fmt.Errorf("github: error from API: %s", res.Status)
|
||||
}
|
||||
|
||||
if out != nil {
|
||||
err := json.NewDecoder(res.Body).Decode(out)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("github: failed to decode json body: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return res.Header, nil
|
||||
}
|
||||
|
||||
func getNextLink(hdrs http.Header) string {
|
||||
for _, link := range linkheader.ParseMultiple(hdrs.Values("Link")) {
|
||||
if link.Rel == "next" {
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue