diff --git a/go.sum b/go.sum
index 84a5898dd..0677ddf5c 100644
--- a/go.sum
+++ b/go.sum
@@ -356,8 +356,6 @@ github.com/onsi/gocleanup v0.0.0-20140331211545-c1a5478700b5/go.mod h1:tHaogb+iP
 github.com/onsi/gomega v1.4.3/go.mod h1:ex+gbHU/CVuBBDIJjb2X0qEXbFg53c61hWP/1CpauHY=
 github.com/onsi/gomega v1.8.1 h1:C5Dqfs/LeauYDX0jJXIe2SWmwCbGzx9yF8C8xy3Lh34=
 github.com/onsi/gomega v1.8.1/go.mod h1:Ho0h+IUsWyvy1OpqCwxlQ/21gkhVunqlU8fDGcoTdcA=
-github.com/open-policy-agent/opa v0.20.5 h1:1zEofrGa+a1Tb186yflIVMkkdAQujG3ySPbeTmR+py0=
-github.com/open-policy-agent/opa v0.20.5/go.mod h1:cZaTfhxsj7QdIiUI0U9aBtOLLTqVNe+XE60+9kZKLHw=
 github.com/open-policy-agent/opa v0.21.0 h1:0CVq4EEUP+fJEzjwd9yNLSTWZk4W5rM+QbjdLcT1nY0=
 github.com/open-policy-agent/opa v0.21.0/go.mod h1:cZaTfhxsj7QdIiUI0U9aBtOLLTqVNe+XE60+9kZKLHw=
 github.com/openzipkin/zipkin-go v0.1.6/go.mod h1:QgAqvLzwWbR/WpD4A3cGpPtJrZXNIiJc5AZX7/PBEpw=
diff --git a/internal/directory/github/github.go b/internal/directory/github/github.go
new file mode 100644
index 000000000..2352788cb
--- /dev/null
+++ b/internal/directory/github/github.go
@@ -0,0 +1,276 @@
+// Package github contains a directory provider for GitHub.
+package github
+
+import (
+	"bytes"
+	"context"
+	"encoding/base64"
+	"encoding/json"
+	"fmt"
+	"io"
+	"net/http"
+	"net/url"
+	"sort"
+
+	"github.com/rs/zerolog"
+	"github.com/tomnomnom/linkheader"
+
+	"github.com/pomerium/pomerium/internal/grpc/databroker"
+	"github.com/pomerium/pomerium/internal/grpc/directory"
+	"github.com/pomerium/pomerium/internal/log"
+)
+
+// Name is the provider name; UserGroups passes it to databroker.GetUserID.
+const Name = "github"
+
+var (
+	defaultURL = &url.URL{ // API base URL that getConfig installs before applying options
+		Scheme: "https",
+		Host:   "api.github.com",
+	}
+)
+
+type config struct {
+	httpClient     *http.Client    // client used to send every API request
+	serviceAccount *ServiceAccount // basic-auth credentials; UserGroups fails when nil
+	url            *url.URL        // base URL that the API paths are resolved against
+}
+
+// An Option updates one field of the provider configuration.
+type Option func(cfg *config)
+
+// WithServiceAccount sets the service account used to authenticate API requests.
+func WithServiceAccount(serviceAccount *ServiceAccount) Option {
+	return func(c *config) { c.serviceAccount = serviceAccount }
+}
+
+// WithHTTPClient is an Option that makes the provider send its requests
+// through httpClient instead of http.DefaultClient.
+func WithHTTPClient(httpClient *http.Client) Option {
+	return func(c *config) { c.httpClient = httpClient }
+}
+
+// WithURL is an Option that replaces the default API base URL with u.
+func WithURL(u *url.URL) Option {
+	return func(c *config) { c.url = u }
+}
+
+// getConfig starts from the default HTTP client and API URL and then applies
+// the options in the order given, so a later option overrides an earlier one.
+func getConfig(options ...Option) *config {
+	c := &config{
+		httpClient: http.DefaultClient,
+		url:        defaultURL,
+	}
+	for i := range options {
+		options[i](c)
+	}
+	return c
+}
+
+// A Provider builds directory users and their groups from GitHub teams.
+type Provider struct {
+	cfg *config
+	log zerolog.Logger
+}
+
+// New returns a Provider configured by the given options.
+func New(options ...Option) *Provider {
+	p := &Provider{cfg: getConfig(options...)}
+	p.log = log.With().
+		Str("service", "directory").
+		Str("provider", "github").
+		Logger()
+	return p
+}
+
+// UserGroups gets the directory user groups for github.
+func (p *Provider) UserGroups(ctx context.Context) ([]*directory.User, error) {
+	if p.cfg.serviceAccount == nil {
+		return nil, fmt.Errorf("github: service account not defined")
+	}
+
+	orgSlugs, err := p.listOrgs(ctx)
+	if err != nil {
+		return nil, err
+	}
+
+	// Walk every org -> team -> member and invert the result into a map from
+	// user login to the slugs of the teams that user belongs to.
+	userLoginToGroups := map[string][]string{}
+	for _, orgSlug := range orgSlugs {
+		teamSlugs, err := p.listTeams(ctx, orgSlug)
+		if err != nil {
+			return nil, err
+		}
+		for _, teamSlug := range teamSlugs {
+			userLogins, err := p.listTeamMembers(ctx, orgSlug, teamSlug)
+			if err != nil {
+				return nil, err
+			}
+			for _, userLogin := range userLogins {
+				userLoginToGroups[userLogin] = append(userLoginToGroups[userLogin], teamSlug)
+			}
+		}
+	}
+
+	// Map iteration order is random, so sort both the groups of each user and
+	// the users themselves to make the returned directory deterministic.
+	var users []*directory.User
+	for userLogin, groups := range userLoginToGroups {
+		user := &directory.User{
+			Id:     databroker.GetUserID(Name, userLogin),
+			Groups: groups,
+		}
+		sort.Strings(user.Groups)
+		users = append(users, user)
+	}
+	sort.Slice(users, func(i, j int) bool {
+		return users[i].GetId() < users[j].GetId()
+	})
+	return users, nil
+}
+
+// listOrgs returns the login of every organization returned by GET /user/orgs,
+// following the paginated responses until no next link is left.
+func (p *Provider) listOrgs(ctx context.Context) (orgSlugs []string, err error) {
+	nextURL := p.cfg.url.ResolveReference(&url.URL{Path: "/user/orgs"}).String()
+
+	for nextURL != "" {
+		var results []struct {
+			Login string `json:"login"`
+		}
+		hdrs, err := p.api(ctx, "GET", nextURL, nil, &results)
+		if err != nil {
+			return nil, err
+		}
+		for _, result := range results {
+			orgSlugs = append(orgSlugs, result.Login)
+		}
+		nextURL = getNextLink(hdrs)
+	}
+
+	return orgSlugs, nil
+}
+
+// listTeams returns the slug of every team of the given organization,
+// following the paginated responses until no next link is left.
+func (p *Provider) listTeams(ctx context.Context, orgSlug string) (teamSlugs []string, err error) {
+	path := fmt.Sprintf("/orgs/%s/teams", orgSlug)
+	nextURL := p.cfg.url.ResolveReference(&url.URL{Path: path}).String()
+	for nextURL != "" {
+		var results []struct {
+			Slug string `json:"slug"`
+		}
+		hdrs, err := p.api(ctx, "GET", nextURL, nil, &results)
+		if err != nil {
+			return nil, err
+		}
+		for _, result := range results {
+			teamSlugs = append(teamSlugs, result.Slug)
+		}
+		nextURL = getNextLink(hdrs)
+	}
+
+	return teamSlugs, nil
+}
+
+// listTeamMembers returns the login of every member of the given team,
+// following the paginated responses until no next link is left.
+func (p *Provider) listTeamMembers(ctx context.Context, orgSlug, teamSlug string) (userLogins []string, err error) {
+	path := fmt.Sprintf("/orgs/%s/teams/%s/members", orgSlug, teamSlug)
+	nextURL := p.cfg.url.ResolveReference(&url.URL{Path: path}).String()
+	for nextURL != "" {
+		var results []struct {
+			Login string `json:"login"`
+		}
+		hdrs, err := p.api(ctx, "GET", nextURL, nil, &results)
+		if err != nil {
+			return nil, err
+		}
+		for _, result := range results {
+			userLogins = append(userLogins, result.Login)
+		}
+		nextURL = getNextLink(hdrs)
+	}
+
+	return userLogins, nil // the named err is shadowed in the loop and never assigned
+}
+
+// api performs one GitHub API request using HTTP basic auth with the service
+// account. A non-nil in is sent as the JSON request body and, for a 2xx
+// response, the body is JSON-decoded into a non-nil out. The response headers
+// are returned so that callers can follow pagination links.
+func (p *Provider) api(ctx context.Context, method string, apiURL string, in, out interface{}) (http.Header, error) {
+	var body io.Reader
+	if in != nil {
+		bs, err := json.Marshal(in)
+		if err != nil {
+			return nil, fmt.Errorf("github: failed to marshal api input: %w", err)
+		}
+		body = bytes.NewReader(bs)
+	}
+	req, err := http.NewRequestWithContext(ctx, method, apiURL, body)
+	if err != nil {
+		return nil, fmt.Errorf("github: failed to create http request: %w", err)
+	}
+	req.SetBasicAuth(p.cfg.serviceAccount.Username, p.cfg.serviceAccount.PersonalAccessToken)
+
+	res, err := p.cfg.httpClient.Do(req)
+	if err != nil {
+		return nil, fmt.Errorf("github: failed to make http request: %w", err)
+	}
+	defer res.Body.Close()
+
+	if res.StatusCode/100 != 2 {
+		return nil, fmt.Errorf("github: error from API: %s", res.Status)
+	}
+
+	if out != nil {
+		if err := json.NewDecoder(res.Body).Decode(out); err != nil {
+			return nil, fmt.Errorf("github: failed to decode json body: %w", err)
+		}
+	}
+
+	return res.Header, nil
+}
+
+// getNextLink returns the URL of the rel="next" entry in the Link response
+// headers, or "" when there is none, which ends the pagination loops.
+func getNextLink(hdrs http.Header) string {
+	for _, link := range linkheader.ParseMultiple(hdrs.Values("Link")) {
+		if link.Rel == "next" {
+			return link.URL
+		}
+	}
+	return ""
+}
+
+// A ServiceAccount is used by the GitHub provider to query the GitHub API.
+type ServiceAccount struct {
+	Username            string `json:"username"`              // basic-auth user name sent with API requests
+	PersonalAccessToken string `json:"personal_access_token"` // basic-auth password sent with API requests
+}
+
+// ParseServiceAccount parses a base64-encoded JSON service account and requires both of its fields.
+func ParseServiceAccount(rawServiceAccount string) (*ServiceAccount, error) {
+	bs, err := base64.StdEncoding.DecodeString(rawServiceAccount)
+	if err != nil {
+		return nil, err
+	}
+
+	var serviceAccount ServiceAccount
+	err = json.Unmarshal(bs, &serviceAccount)
+	if err != nil {
+		return nil, err
+	}
+
+	if serviceAccount.Username == "" {
+		return nil, fmt.Errorf("username is required")
+	}
+	if serviceAccount.PersonalAccessToken == "" {
+		return nil, fmt.Errorf("personal_access_token is required")
+	}
+
+	return &serviceAccount, nil
+}
diff --git a/internal/directory/github/github_test.go b/internal/directory/github/github_test.go
new file mode 100644
index 000000000..c6026ce74
--- /dev/null
+++ b/internal/directory/github/github_test.go
@@ -0,0 +1,112 @@
+package github
+
+import (
+	"context"
+	"encoding/json"
+	"net/http"
+	"net/http/httptest"
+	"net/url"
+	"testing"
+
+	"github.com/go-chi/chi"
+	"github.com/go-chi/chi/middleware"
+	"github.com/stretchr/testify/assert"
+
+	"github.com/pomerium/pomerium/internal/testutil"
+)
+
+type M = map[string]interface{} // shorthand for the JSON objects served by the mock API
+
+func newMockAPI(t *testing.T, srv *httptest.Server) http.Handler { // NOTE(review): srv is not used in this function
+	r := chi.NewRouter()
+	r.Use(middleware.Logger)
+	r.Use(func(next http.Handler) http.Handler {
+		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+			if !assert.Equal(t, "Basic YWJjOnh5eg==", r.Header.Get("Authorization")) { // base64 of "abc:xyz", the credentials used by Test
+				http.Error(w, "forbidden", http.StatusForbidden)
+				return
+			}
+			next.ServeHTTP(w, r)
+		})
+	})
+	r.Get("/user/orgs", func(w http.ResponseWriter, r *http.Request) {
+		json.NewEncoder(w).Encode([]M{
+			{"login": "org1"},
+			{"login": "org2"},
+		})
+	})
+	r.Get("/orgs/{org_id}/teams", func(w http.ResponseWriter, r *http.Request) {
+		teams := map[string][]M{
+			"org1": {
+				{"slug": "team1"},
+				{"slug": "team2"},
+			},
+			"org2": {
+				{"slug": "team3"},
+				{"slug": "team4"},
+			},
+		}
+		orgID := chi.URLParam(r, "org_id")
+		json.NewEncoder(w).Encode(teams[orgID])
+	})
+	r.Get("/orgs/{org_id}/teams/{team_id}/members", func(w http.ResponseWriter, r *http.Request) {
+		members := map[string]map[string][]M{
+			"org1": {
+				"team1": {
+					{"login": "user1"},
+					{"login": "user2"},
+				},
+				"team2": {
+					{"login": "user1"},
+				},
+			},
+			"org2": {
+				"team3": {
+					{"login": "user1"},
+					{"login": "user2"},
+					{"login": "user3"},
+				},
+				"team4": {
+					{"login": "user4"},
+				},
+			},
+		}
+		orgID := chi.URLParam(r, "org_id")
+		teamID := chi.URLParam(r, "team_id")
+		json.NewEncoder(w).Encode(members[orgID][teamID])
+	})
+	return r
+}
+
+func Test(t *testing.T) {
+	var mockAPI http.Handler // assigned once srv exists; the server handler below forwards to it
+	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		mockAPI.ServeHTTP(w, r)
+	}))
+	defer srv.Close()
+	mockAPI = newMockAPI(t, srv)
+
+	p := New(
+		WithURL(mustParseURL(srv.URL)),
+		WithServiceAccount(&ServiceAccount{
+			Username:            "abc",
+			PersonalAccessToken: "xyz",
+		}),
+	)
+	users, err := p.UserGroups(context.Background())
+	assert.NoError(t, err)
+	testutil.AssertProtoJSONEqual(t, `[
+		{ "id": "github/user1", "groups": ["team1", "team2", "team3"] },
+		{ "id": "github/user2", "groups": ["team1", "team3"] },
+		{ "id": "github/user3", "groups": ["team3"] },
+		{ "id": "github/user4", "groups": ["team4"] }
+	]`, users)
+}
+
+func mustParseURL(rawurl string) *url.URL { // test helper: panics when rawurl does not parse
+	u, err := url.Parse(rawurl)
+	if err != nil {
+		panic(err)
+	}
+	return u
+}
diff --git a/internal/directory/provider.go b/internal/directory/provider.go
index 763ca53d5..d2ed42252 100644
--- a/internal/directory/provider.go
+++ b/internal/directory/provider.go
@@ -7,6 +7,7 @@ import (
 
 	"github.com/pomerium/pomerium/config"
 	"github.com/pomerium/pomerium/internal/directory/azure"
+	"github.com/pomerium/pomerium/internal/directory/github"
 	"github.com/pomerium/pomerium/internal/directory/gitlab"
 	"github.com/pomerium/pomerium/internal/directory/google"
 	"github.com/pomerium/pomerium/internal/directory/okta"
@@ -37,6 +38,16 @@ func GetProvider(options *config.Options) Provider {
 			Str("provider", options.Provider).
 			Err(err).
 			Msg("invalid service account for azure directory provider")
+	case github.Name:
+		serviceAccount, err := github.ParseServiceAccount(options.ServiceAccount)
+		if err == nil {
+			return github.New(github.WithServiceAccount(serviceAccount))
+		}
+		log.Warn().
+			Str("service", "directory").
+			Str("provider", options.Provider).
+			Err(err).
+			Msg("invalid service account for github directory provider")
 	case gitlab.Name:
 		serviceAccount, err := gitlab.ParseServiceAccount(options.ServiceAccount)
 		if err == nil {