diff --git a/internal/directory/ping/api.go b/internal/directory/ping/api.go new file mode 100644 index 000000000..0e0c4a872 --- /dev/null +++ b/internal/directory/ping/api.go @@ -0,0 +1,174 @@ +package ping + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "io" + "net/http" + "net/url" + "strings" +) + +var errNotFound = errors.New("ping: user not found") + +type ( + apiGroup struct { + ID string `json:"id"` + Name string `json:"name"` + } + apiUser struct { + ID string `json:"id"` + Email string `json:"email"` + Name apiUserName `json:"name"` + Username string `json:"username"` + MemberOfGroupIDs []string `json:"memberOfGroupIDs"` + } + apiUserName struct { + Given string `json:"given"` + Middle string `json:"middle"` + Family string `json:"family"` + } +) + +func (au apiUser) getDisplayName() string { + var parts []string + if au.Name.Given != "" { + parts = append(parts, au.Name.Given) + } + if au.Name.Middle != "" { + parts = append(parts, au.Name.Middle) + } + if au.Name.Family != "" { + parts = append(parts, au.Name.Family) + } + if len(parts) == 0 { + parts = append(parts, au.Username) + } + return strings.Join(parts, " ") +} + +func getAllGroups(ctx context.Context, client *http.Client, apiURL *url.URL, envID string) ([]apiGroup, error) { + nextURL := apiURL.ResolveReference(&url.URL{ + Path: fmt.Sprintf("/v1/environments/%s/groups", url.PathEscape(envID)), + }).String() + + var apiGroups []apiGroup + err := batchAPIRequest(ctx, client, nextURL, func(body []byte) error { + var apiResponse struct { + Embedded struct { + Groups []apiGroup `json:"groups"` + } `json:"_embedded"` + } + err := json.Unmarshal(body, &apiResponse) + if err != nil { + return fmt.Errorf("ping: error decoding API response: %w", err) + } + apiGroups = append(apiGroups, apiResponse.Embedded.Groups...) + return nil + }) + return apiGroups, err +} + +func getGroupUsers(ctx context.Context, client *http.Client, apiURL *url.URL, envID, groupID string) ([]apiUser, error) { + nextURL := apiURL.ResolveReference(&url.URL{ + Path: fmt.Sprintf("/v1/environments/%s/users", url.PathEscape(envID)), + RawQuery: (&url.Values{ + "filter": {fmt.Sprintf(`memberOfGroups[id eq "%s"]`, groupID)}, + }).Encode(), + }).String() + + var apiUsers []apiUser + err := batchAPIRequest(ctx, client, nextURL, func(body []byte) error { + var apiResponse struct { + Embedded struct { + Users []apiUser `json:"users"` + } `json:"_embedded"` + } + err := json.Unmarshal(body, &apiResponse) + if err != nil { + return fmt.Errorf("ping: error decoding API response: %w", err) + } + apiUsers = append(apiUsers, apiResponse.Embedded.Users...) + return nil + }) + return apiUsers, err +} + +func getUser(ctx context.Context, client *http.Client, apiURL *url.URL, envID, userID string) (*apiUser, error) { + nextURL := apiURL.ResolveReference(&url.URL{ + Path: fmt.Sprintf("/v1/environments/%s/users/%s", url.PathEscape(envID), url.PathEscape(userID)), + RawQuery: (&url.Values{ + "include": {"memberOfGroupIDs"}, + }).Encode(), + }).String() + + req, err := http.NewRequestWithContext(ctx, "GET", nextURL, nil) + if err != nil { + return nil, fmt.Errorf("ping: error building API request: %w", err) + } + res, err := client.Do(req) + if err != nil { + return nil, fmt.Errorf("ping: error making API request: %w", err) + } + body, err := io.ReadAll(res.Body) + if err != nil { + return nil, fmt.Errorf("ping: error reading API response: %w", err) + } + _ = res.Body.Close() + if res.StatusCode == http.StatusNotFound { + return nil, errNotFound + } else if res.StatusCode/100 != 2 { + return nil, fmt.Errorf("ping: unexpected status code: %d", res.StatusCode) + } + + var u apiUser + err = json.Unmarshal(body, &u) + if err != nil { + return nil, fmt.Errorf("ping: error decoding API response: %w", err) + } + return &u, nil +} + +func batchAPIRequest(ctx context.Context, client *http.Client, nextURL string, callback func(body []byte) error) error { + for nextURL != "" { + req, err := http.NewRequestWithContext(ctx, "GET", nextURL, nil) + if err != nil { + return fmt.Errorf("ping: error building API request: %w", err) + } + + res, err := client.Do(req) + if err != nil { + return fmt.Errorf("ping: error making API request: %w", err) + } + bs, err := io.ReadAll(res.Body) + if err != nil { + return fmt.Errorf("ping: error reading API response: %w", err) + } + _ = res.Body.Close() + if res.StatusCode/100 != 2 { + return fmt.Errorf("ping: unexpected status code: %d", res.StatusCode) + } + + var apiResponse struct { + Links struct { + Next struct { + HREF string `json:"href"` + } `json:"next"` + } `json:"_links"` + } + err = json.Unmarshal(bs, &apiResponse) + if err != nil { + return fmt.Errorf("ping: error decoding API response: %w", err) + } + + err = callback(bs) + if err != nil { + return err + } + + nextURL = apiResponse.Links.Next.HREF + } + return nil +} diff --git a/internal/directory/ping/config.go b/internal/directory/ping/config.go new file mode 100644 index 000000000..b1be48450 --- /dev/null +++ b/internal/directory/ping/config.go @@ -0,0 +1,111 @@ +package ping + +import ( + "fmt" + "net/http" + "net/url" + "strings" + + "github.com/pomerium/pomerium/internal/encoding" +) + +type config struct { + authURL *url.URL + apiURL *url.URL + serviceAccount *ServiceAccount + httpClient *http.Client + environmentID string +} + +// An Option updates the Ping configuration. +type Option func(*config) + +// WithAPIURL sets the api url in the config. +func WithAPIURL(apiURL *url.URL) Option { + return func(cfg *config) { + cfg.apiURL = apiURL + } +} + +// WithAuthURL sets the auth url in the config. +func WithAuthURL(authURL *url.URL) Option { + return func(cfg *config) { + cfg.authURL = authURL + } +} + +// WithEnvironmentID sets the environment ID in the config. +func WithEnvironmentID(environmentID string) Option { + return func(cfg *config) { + cfg.environmentID = environmentID + } +} + +// WithHTTPClient sets the http client option. +func WithHTTPClient(httpClient *http.Client) Option { + return func(cfg *config) { + cfg.httpClient = httpClient + } +} + +// WithProviderURL sets the environment ID from the provider URL set in the config. +func WithProviderURL(providerURL *url.URL) Option { + // provider URL will be https://auth.pingone.com/{ENVIRONMENT_ID}/as + if providerURL == nil { + return func(cfg *config) {} + } + parts := strings.Split(providerURL.Path, "/") + if len(parts) < 1 { + return func(cfg *config) {} + } + return WithEnvironmentID(parts[1]) +} + +// WithServiceAccount sets the service account in the config. +func WithServiceAccount(serviceAccount *ServiceAccount) Option { + return func(cfg *config) { + cfg.serviceAccount = serviceAccount + } +} + +func getConfig(options ...Option) *config { + cfg := new(config) + WithHTTPClient(http.DefaultClient)(cfg) + WithAuthURL(&url.URL{ + Scheme: "https", + Host: "auth.pingone.com", + })(cfg) + WithAPIURL(&url.URL{ + Scheme: "https", + Host: "api.pingone.com", + })(cfg) + for _, option := range options { + option(cfg) + } + return cfg +} + +// A ServiceAccount is used by the Ping provider to query the API. +type ServiceAccount struct { + ClientID string `json:"client_id"` + ClientSecret string `json:"client_secret"` + EnvironmentID string `json:"environment_id"` +} + +// ParseServiceAccount parses the service account in the config options. +func ParseServiceAccount(rawServiceAccount string) (*ServiceAccount, error) { + var serviceAccount ServiceAccount + err := encoding.DecodeBase64OrJSON(rawServiceAccount, &serviceAccount) + if err != nil { + return nil, err + } + + if serviceAccount.ClientID == "" { + return nil, fmt.Errorf("client_id is required") + } + if serviceAccount.ClientSecret == "" { + return nil, fmt.Errorf("client_secret is required") + } + + return &serviceAccount, nil +} diff --git a/internal/directory/ping/provider.go b/internal/directory/ping/provider.go new file mode 100644 index 000000000..b0aba6c56 --- /dev/null +++ b/internal/directory/ping/provider.go @@ -0,0 +1,167 @@ +// Package ping implements a directory provider for Ping. +package ping + +import ( + "context" + "fmt" + "net/http" + "net/url" + "sort" + "sync" + + "golang.org/x/oauth2" + "golang.org/x/oauth2/clientcredentials" + + "github.com/pomerium/pomerium/pkg/grpc/directory" +) + +// Name is the name of the Ping provider. +const Name = "ping" + +// Provider implements a directory provider using the Ping API. +type Provider struct { + cfg *config + mu sync.RWMutex + token *oauth2.Token +} + +// New creates a new Ping Provider. +func New(options ...Option) *Provider { + cfg := getConfig(options...) + return &Provider{ + cfg: cfg, + } +} + +// User returns a user's directory information. +func (p *Provider) User(ctx context.Context, userID, accessToken string) (*directory.User, error) { + client, err := p.getClient(ctx) + if err != nil { + return nil, err + } + + au, err := getUser(ctx, client, p.cfg.apiURL, p.cfg.environmentID, userID) + if err != nil { + return nil, err + } + + return &directory.User{ + Id: au.ID, + DisplayName: au.getDisplayName(), + Email: au.Email, + GroupIds: au.MemberOfGroupIDs, + }, nil +} + +// UserGroups returns all the users and groups in the directory. +func (p *Provider) UserGroups(ctx context.Context) ([]*directory.Group, []*directory.User, error) { + client, err := p.getClient(ctx) + if err != nil { + return nil, nil, err + } + + apiGroups, err := getAllGroups(ctx, client, p.cfg.apiURL, p.cfg.environmentID) + if err != nil { + return nil, nil, err + } + + directoryUserLookup := map[string]*directory.User{} + directoryGroups := make([]*directory.Group, len(apiGroups)) + for i, ag := range apiGroups { + dg := &directory.Group{ + Id: ag.ID, + Name: ag.Name, + } + + apiUsers, err := getGroupUsers(ctx, client, p.cfg.apiURL, p.cfg.environmentID, ag.ID) + if err != nil { + return nil, nil, err + } + for _, au := range apiUsers { + du, ok := directoryUserLookup[au.ID] + if !ok { + du = &directory.User{ + Id: au.ID, + DisplayName: au.getDisplayName(), + Email: au.Email, + } + directoryUserLookup[au.ID] = du + } + du.GroupIds = append(du.GroupIds, ag.ID) + } + + directoryGroups[i] = dg + } + sort.Slice(directoryGroups, func(i, j int) bool { + return directoryGroups[i].Id < directoryGroups[j].Id + }) + + directoryUsers := make([]*directory.User, 0, len(directoryUserLookup)) + for _, du := range directoryUserLookup { + directoryUsers = append(directoryUsers, du) + } + sort.Slice(directoryUsers, func(i, j int) bool { + return directoryUsers[i].Id < directoryUsers[j].Id + }) + + return directoryGroups, directoryUsers, nil +} + +func (p *Provider) getClient(ctx context.Context) (*http.Client, error) { + token, err := p.getToken(ctx) + if err != nil { + return nil, err + } + + client := new(http.Client) + *client = *p.cfg.httpClient + client.Transport = &oauth2.Transport{ + Source: oauth2.StaticTokenSource(token), + Base: p.cfg.httpClient.Transport, + } + return client, nil +} + +func (p *Provider) getToken(ctx context.Context) (*oauth2.Token, error) { + if p.cfg.serviceAccount == nil { + return nil, fmt.Errorf("ping: service account is required") + } + environmentID := p.cfg.serviceAccount.EnvironmentID + if environmentID == "" { + environmentID = p.cfg.environmentID + } + if environmentID == "" { + return nil, fmt.Errorf("ping: environment ID is required") + } + + p.mu.RLock() + token := p.token + p.mu.RUnlock() + + if token != nil && token.Valid() { + return token, nil + } + + p.mu.Lock() + defer p.mu.Unlock() + + token = p.token + if token != nil && token.Valid() { + return token, nil + } + + ocfg := &clientcredentials.Config{ + ClientID: p.cfg.serviceAccount.ClientID, + ClientSecret: p.cfg.serviceAccount.ClientSecret, + TokenURL: p.cfg.authURL.ResolveReference(&url.URL{ + Path: fmt.Sprintf("/%s/as/token", environmentID), + }).String(), + } + var err error + p.token, err = ocfg.Token(ctx) + if err != nil { + return nil, err + } + + return p.token, nil +} diff --git a/internal/directory/ping/provider_test.go b/internal/directory/ping/provider_test.go new file mode 100644 index 000000000..4775037dd --- /dev/null +++ b/internal/directory/ping/provider_test.go @@ -0,0 +1,230 @@ +package ping + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "net/http/httptest" + "net/url" + "sort" + "testing" + "time" + + "github.com/go-chi/chi" + "github.com/go-chi/chi/middleware" + "github.com/stretchr/testify/require" + + "github.com/pomerium/pomerium/internal/testutil" +) + +type M = map[string]interface{} + +func newMockAPI(userIDToGroupIDs map[string][]string) http.Handler { + lookup := map[string]struct{}{} + for _, groups := range userIDToGroupIDs { + for _, group := range groups { + lookup[group] = struct{}{} + } + } + var allGroups []string + for groupID := range lookup { + allGroups = append(allGroups, groupID) + } + sort.Strings(allGroups) + + var allUserIDs []string + for userID := range userIDToGroupIDs { + allUserIDs = append(allUserIDs, userID) + } + sort.Strings(allUserIDs) + + filterToUserIDs := map[string][]string{} + for userID, groupIDs := range userIDToGroupIDs { + for _, groupID := range groupIDs { + filter := fmt.Sprintf(`memberOfGroups[id eq "%s"]`, groupID) + filterToUserIDs[filter] = append(filterToUserIDs[filter], userID) + } + } + + r := chi.NewRouter() + r.Use(middleware.Logger) + r.Post("/ENVIRONMENTID/as/token", func(w http.ResponseWriter, r *http.Request) { + u, p, _ := r.BasicAuth() + if u != "CLIENTID" || p != "CLIENTSECRET" { + http.Error(w, "forbidden", http.StatusForbidden) + return + } + + grantType := r.FormValue("grant_type") + if grantType != "client_credentials" { + http.Error(w, "invalid grant_type", http.StatusBadRequest) + return + } + + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(M{ + "access_token": "ACCESSTOKEN", + "created_at": time.Now().Format(time.RFC3339), + "expires_in": 360000, + "refresh_token": "REFRESHTOKEN", + "token_type": "bearer", + }) + }) + r.Route("/v1/environments/ENVIRONMENTID", func(r chi.Router) { + r.Get("/groups", func(w http.ResponseWriter, r *http.Request) { + var apiGroups []apiGroup + for _, id := range allGroups { + apiGroups = append(apiGroups, apiGroup{ + ID: id, + Name: "Group " + id, + }) + } + + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(M{ + "_embedded": M{ + "groups": apiGroups, + }, + }) + }) + r.Route("/users", func(r chi.Router) { + r.Get("/{user_id}", func(w http.ResponseWriter, r *http.Request) { + userID := chi.URLParam(r, "user_id") + groupIDs, ok := userIDToGroupIDs[userID] + if !ok { + http.NotFound(w, r) + return + } + + au := apiUser{ + ID: userID, + Email: userID + "@example.com", + Name: apiUserName{ + Given: "Given-" + userID, + Middle: "Middle-" + userID, + Family: "Family-" + userID, + }, + } + if r.URL.Query().Get("include") == "memberOfGroupIDs" { + au.MemberOfGroupIDs = groupIDs + } + + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(au) + }) + r.Get("/", func(w http.ResponseWriter, r *http.Request) { + filter := r.URL.Query().Get("filter") + userIDs, ok := filterToUserIDs[filter] + if !ok { + http.Error(w, "expected filter", http.StatusBadRequest) + return + } + + var apiUsers []apiUser + for _, id := range userIDs { + apiUsers = append(apiUsers, apiUser{ + ID: id, + Email: id + "@example.com", + Name: apiUserName{ + Given: "Given-" + id, + Middle: "Middle-" + id, + Family: "Family-" + id, + }, + }) + } + + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(M{ + "_embedded": M{ + "users": apiUsers, + }, + }) + }) + }) + }) + return r +} + +func TestProvider_User(t *testing.T) { + ctx, clearTimeout := context.WithTimeout(context.Background(), time.Second*10) + defer clearTimeout() + + srv := httptest.NewServer(newMockAPI(map[string][]string{ + "user1": {"group1", "group2"}, + "user2": {"group1", "group3"}, + "user3": {"group3"}, + })) + defer srv.Close() + + u, err := url.Parse(srv.URL) + require.NoError(t, err) + + p := New( + WithAPIURL(u), + WithAuthURL(u), + WithEnvironmentID("ENVIRONMENTID"), + WithServiceAccount(&ServiceAccount{ + ClientID: "CLIENTID", + ClientSecret: "CLIENTSECRET", + })) + du, err := p.User(ctx, "user1", "") + require.NoError(t, err) + testutil.AssertProtoJSONEqual(t, `{ + "id": "user1", + "email": "user1@example.com", + "displayName": "Given-user1 Middle-user1 Family-user1", + "groupIds": ["group1", "group2"] + }`, du) +} + +func TestProvider_UserGroups(t *testing.T) { + ctx, clearTimeout := context.WithTimeout(context.Background(), time.Second*10) + defer clearTimeout() + + srv := httptest.NewServer(newMockAPI(map[string][]string{ + "user1": {"group1", "group2"}, + "user2": {"group1", "group3"}, + "user3": {"group3"}, + })) + defer srv.Close() + + u, err := url.Parse(srv.URL) + require.NoError(t, err) + + p := New( + WithAPIURL(u), + WithAuthURL(u), + WithEnvironmentID("ENVIRONMENTID"), + WithServiceAccount(&ServiceAccount{ + ClientID: "CLIENTID", + ClientSecret: "CLIENTSECRET", + })) + dgs, dus, err := p.UserGroups(ctx) + require.NoError(t, err) + testutil.AssertProtoJSONEqual(t, `[ + { "id": "group1", "name": "Group group1" }, + { "id": "group2", "name": "Group group2" }, + { "id": "group3", "name": "Group group3" } + ]`, dgs) + testutil.AssertProtoJSONEqual(t, `[ + { + "id": "user1", + "email": "user1@example.com", + "displayName": "Given-user1 Middle-user1 Family-user1", + "groupIds": ["group1", "group2"] + }, + { + "id": "user2", + "email": "user2@example.com", + "displayName": "Given-user2 Middle-user2 Family-user2", + "groupIds": ["group1", "group3"] + }, + { + "id": "user3", + "email": "user3@example.com", + "displayName": "Given-user3 Middle-user3 Family-user3", + "groupIds": ["group3"] + } + ]`, dus) +} diff --git a/internal/directory/provider.go b/internal/directory/provider.go index aebcc7abe..11b5dab26 100644 --- a/internal/directory/provider.go +++ b/internal/directory/provider.go @@ -8,6 +8,8 @@ import ( "github.com/google/go-cmp/cmp" + "github.com/pomerium/pomerium/internal/directory/ping" + "github.com/pomerium/pomerium/internal/directory/auth0" "github.com/pomerium/pomerium/internal/directory/azure" "github.com/pomerium/pomerium/internal/directory/github" @@ -138,6 +140,18 @@ func GetProvider(options Options) (provider Provider) { Str("provider", options.Provider). Err(err). Msg("invalid service account for onelogin directory provider") + case ping.Name: + serviceAccount, err := ping.ParseServiceAccount(options.ServiceAccount) + if err == nil { + return ping.New( + ping.WithProviderURL(providerURL), + ping.WithServiceAccount(serviceAccount)) + } + log.Warn(). + Str("service", "directory"). + Str("provider", options.Provider). + Err(err). + Msg("invalid service account for ping directory provider") } log.Warn(). diff --git a/internal/identity/oidc/ping/ping.go b/internal/identity/oidc/ping/ping.go new file mode 100644 index 000000000..79d1273bf --- /dev/null +++ b/internal/identity/oidc/ping/ping.go @@ -0,0 +1,39 @@ +// Package ping implements OpenID Connect for Ping +// +// https://www.pomerium.io/docs/identity-providers/ping.html +package ping + +import ( + "context" + "fmt" + + "github.com/pomerium/pomerium/internal/identity/oauth" + pom_oidc "github.com/pomerium/pomerium/internal/identity/oidc" +) + +const ( + // Name identifies the Ping identity provider. + Name = "ping" +) + +// Provider is a Ping implementation of the Authenticator interface. +type Provider struct { + *pom_oidc.Provider +} + +// New instantiates an OpenID Connect (OIDC) provider for Ping. +func New(ctx context.Context, o *oauth.Options) (*Provider, error) { + var p Provider + var err error + genericOidc, err := pom_oidc.New(ctx, o) + if err != nil { + return nil, fmt.Errorf("%s: failed creating oidc provider: %w", Name, err) + } + p.Provider = genericOidc + return &p, nil +} + +// Name returns the provider name. +func (p *Provider) Name() string { + return Name +} diff --git a/internal/identity/providers.go b/internal/identity/providers.go index 9c09b6b97..c1bc85808 100644 --- a/internal/identity/providers.go +++ b/internal/identity/providers.go @@ -20,6 +20,7 @@ import ( "github.com/pomerium/pomerium/internal/identity/oidc/google" "github.com/pomerium/pomerium/internal/identity/oidc/okta" "github.com/pomerium/pomerium/internal/identity/oidc/onelogin" + "github.com/pomerium/pomerium/internal/identity/oidc/ping" ) // Authenticator is an interface representing the ability to authenticate with an identity provider. @@ -53,6 +54,8 @@ func NewAuthenticator(o oauth.Options) (a Authenticator, err error) { a, err = okta.New(ctx, &o) case onelogin.Name: a, err = onelogin.New(ctx, &o) + case ping.Name: + a, err = ping.New(ctx, &o) default: return nil, fmt.Errorf("identity: unknown provider: %s", o.ProviderName) }