mirror of
https://github.com/pomerium/pomerium.git
synced 2025-05-19 12:07:18 +02:00
directory: add explicit RefreshUser endpoint for faster sync (#1460)
* directory: add explicit RefreshUser endpoint for faster sync * add test * implement azure * update api call * add test for azure User * implement github * implement AccessToken, gitlab * implement okta * implement onelogin * fix test * fix inconsistent test * implement auth0
This commit is contained in:
parent
9b39deabd8
commit
aa731ae068
23 changed files with 1405 additions and 179 deletions
|
@ -2,6 +2,7 @@
|
|||
package github
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
|
@ -83,6 +84,47 @@ func New(options ...Option) *Provider {
|
|||
}
|
||||
}
|
||||
|
||||
// User returns the user record for the given id.
|
||||
func (p *Provider) User(ctx context.Context, userID, accessToken string) (*directory.User, error) {
|
||||
if p.cfg.serviceAccount == nil {
|
||||
return nil, fmt.Errorf("github: service account not defined")
|
||||
}
|
||||
|
||||
_, providerUserID := databroker.FromUserID(userID)
|
||||
du := &directory.User{
|
||||
Id: userID,
|
||||
}
|
||||
|
||||
au, err := p.getUser(ctx, providerUserID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
du.DisplayName = au.Name
|
||||
du.Email = au.Email
|
||||
|
||||
teamIDLookup := map[int]struct{}{}
|
||||
orgSlugs, err := p.listOrgs(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for _, orgSlug := range orgSlugs {
|
||||
teamIDs, err := p.listUserOrganizationTeams(ctx, userID, orgSlug)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for _, teamID := range teamIDs {
|
||||
teamIDLookup[teamID] = struct{}{}
|
||||
}
|
||||
}
|
||||
|
||||
for teamID := range teamIDLookup {
|
||||
du.GroupIds = append(du.GroupIds, strconv.Itoa(teamID))
|
||||
}
|
||||
sort.Strings(du.GroupIds)
|
||||
|
||||
return du, nil
|
||||
}
|
||||
|
||||
// UserGroups gets the directory user groups for github.
|
||||
func (p *Provider) UserGroups(ctx context.Context) ([]*directory.Group, []*directory.User, error) {
|
||||
if p.cfg.serviceAccount == nil {
|
||||
|
@ -230,6 +272,77 @@ func (p *Provider) getUser(ctx context.Context, userLogin string) (*apiUserObjec
|
|||
return &res, nil
|
||||
}
|
||||
|
||||
func (p *Provider) listUserOrganizationTeams(ctx context.Context, userSlug string, orgSlug string) ([]int, error) {
|
||||
// GitHub's Rest API doesn't have an easy way of querying this data, so we use the GraphQL API.
|
||||
|
||||
enc := func(obj interface{}) string {
|
||||
bs, _ := json.Marshal(obj)
|
||||
return string(bs)
|
||||
}
|
||||
const pageCount = 100
|
||||
|
||||
var teamIDs []int
|
||||
var cursor *string
|
||||
for {
|
||||
var res struct {
|
||||
Data struct {
|
||||
Organization struct {
|
||||
Teams struct {
|
||||
TotalCount int `json:"totalCount"`
|
||||
PageInfo struct {
|
||||
EndCursor string `json:"endCursor"`
|
||||
} `json:"pageInfo"`
|
||||
Edges []struct {
|
||||
Node struct {
|
||||
ID int `json:"id"`
|
||||
} `json:"node"`
|
||||
} `json:"edges"`
|
||||
} `json:"teams"`
|
||||
} `json:"organization"`
|
||||
} `json:"data"`
|
||||
}
|
||||
cursorStr := ""
|
||||
if cursor != nil {
|
||||
cursorStr = fmt.Sprintf(",%s", enc(*cursor))
|
||||
}
|
||||
q := fmt.Sprintf(`query {
|
||||
organization(login:%s) {
|
||||
teams(first:%s, userLogins:[%s] %s) {
|
||||
totalCount
|
||||
pageInfo {
|
||||
endCursor
|
||||
}
|
||||
edges {
|
||||
node {
|
||||
id
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}`, enc(orgSlug), enc(pageCount), enc(userSlug), cursorStr)
|
||||
_, err := p.graphql(ctx, q, &res)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if len(res.Data.Organization.Teams.Edges) == 0 {
|
||||
break
|
||||
}
|
||||
|
||||
for _, edge := range res.Data.Organization.Teams.Edges {
|
||||
teamIDs = append(teamIDs, edge.Node.ID)
|
||||
}
|
||||
|
||||
if len(teamIDs) >= res.Data.Organization.Teams.TotalCount {
|
||||
break
|
||||
}
|
||||
|
||||
cursor = &res.Data.Organization.Teams.PageInfo.EndCursor
|
||||
}
|
||||
|
||||
return teamIDs, nil
|
||||
}
|
||||
|
||||
func (p *Provider) api(ctx context.Context, apiURL string, out interface{}) (http.Header, error) {
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", apiURL, nil)
|
||||
if err != nil {
|
||||
|
@ -257,6 +370,41 @@ func (p *Provider) api(ctx context.Context, apiURL string, out interface{}) (htt
|
|||
return res.Header, nil
|
||||
}
|
||||
|
||||
func (p *Provider) graphql(ctx context.Context, query string, out interface{}) (http.Header, error) {
|
||||
apiURL := p.cfg.url.ResolveReference(&url.URL{
|
||||
Path: "/graphql",
|
||||
}).String()
|
||||
|
||||
bs, _ := json.Marshal(struct {
|
||||
Query string `json:"query"`
|
||||
}{query})
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", apiURL, bytes.NewReader(bs))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("github: failed to create http request: %w", err)
|
||||
}
|
||||
req.SetBasicAuth(p.cfg.serviceAccount.Username, p.cfg.serviceAccount.PersonalAccessToken)
|
||||
|
||||
res, err := p.cfg.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("github: failed to make http request: %w", err)
|
||||
}
|
||||
defer res.Body.Close()
|
||||
|
||||
if res.StatusCode/100 != 2 {
|
||||
return nil, fmt.Errorf("github: error from API: %s", res.Status)
|
||||
}
|
||||
|
||||
if out != nil {
|
||||
err := json.NewDecoder(res.Body).Decode(out)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("github: failed to decode json body: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return res.Header, nil
|
||||
}
|
||||
|
||||
func getNextLink(hdrs http.Header) string {
|
||||
for _, link := range linkheader.ParseMultiple(hdrs.Values("Link")) {
|
||||
if link.Rel == "next" {
|
||||
|
|
|
@ -29,6 +29,33 @@ func newMockAPI(t *testing.T, srv *httptest.Server) http.Handler {
|
|||
next.ServeHTTP(w, r)
|
||||
})
|
||||
})
|
||||
r.Post("/graphql", func(w http.ResponseWriter, r *http.Request) {
|
||||
var body struct {
|
||||
Query string `json:"query"`
|
||||
}
|
||||
json.NewDecoder(r.Body).Decode(&body)
|
||||
|
||||
_ = json.NewEncoder(w).Encode(M{
|
||||
"data": M{
|
||||
"organization": M{
|
||||
"teams": M{
|
||||
"totalCount": 3,
|
||||
"edges": []M{
|
||||
{"node": M{
|
||||
"id": 1,
|
||||
}},
|
||||
{"node": M{
|
||||
"id": 2,
|
||||
}},
|
||||
{"node": M{
|
||||
"id": 3,
|
||||
}},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
})
|
||||
})
|
||||
r.Get("/user/orgs", func(w http.ResponseWriter, r *http.Request) {
|
||||
json.NewEncoder(w).Encode([]M{
|
||||
{"login": "org1"},
|
||||
|
@ -88,7 +115,34 @@ func newMockAPI(t *testing.T, srv *httptest.Server) http.Handler {
|
|||
return r
|
||||
}
|
||||
|
||||
func Test(t *testing.T) {
|
||||
func TestProvider_User(t *testing.T) {
|
||||
var mockAPI http.Handler
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
mockAPI.ServeHTTP(w, r)
|
||||
}))
|
||||
defer srv.Close()
|
||||
mockAPI = newMockAPI(t, srv)
|
||||
|
||||
p := New(
|
||||
WithURL(mustParseURL(srv.URL)),
|
||||
WithServiceAccount(&ServiceAccount{
|
||||
Username: "abc",
|
||||
PersonalAccessToken: "xyz",
|
||||
}),
|
||||
)
|
||||
du, err := p.User(context.Background(), "github/user1", "")
|
||||
if !assert.NoError(t, err) {
|
||||
return
|
||||
}
|
||||
testutil.AssertProtoJSONEqual(t, `{
|
||||
"id": "github/user1",
|
||||
"groupIds": ["1", "2", "3"],
|
||||
"displayName": "User 1",
|
||||
"email": "user1@example.com"
|
||||
}`, du)
|
||||
}
|
||||
|
||||
func TestProvider_UserGroups(t *testing.T) {
|
||||
var mockAPI http.Handler
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
mockAPI.ServeHTTP(w, r)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue