ping: identity and directory providers (#1975)

* ping: add identity provider

* ping: implement directory provider

* ping, not onelogin

* ping, not onelogin

* escape path params
This commit is contained in:
Caleb Doxsey 2021-03-10 16:25:49 -07:00 committed by GitHub
parent 00a1cb7456
commit fd97561ab1
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
7 changed files with 738 additions and 0 deletions

View file

@ -0,0 +1,174 @@
package ping
import (
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"net/url"
"strings"
)
var errNotFound = errors.New("ping: user not found")
type (
apiGroup struct {
ID string `json:"id"`
Name string `json:"name"`
}
apiUser struct {
ID string `json:"id"`
Email string `json:"email"`
Name apiUserName `json:"name"`
Username string `json:"username"`
MemberOfGroupIDs []string `json:"memberOfGroupIDs"`
}
apiUserName struct {
Given string `json:"given"`
Middle string `json:"middle"`
Family string `json:"family"`
}
)
func (au apiUser) getDisplayName() string {
var parts []string
if au.Name.Given != "" {
parts = append(parts, au.Name.Given)
}
if au.Name.Middle != "" {
parts = append(parts, au.Name.Middle)
}
if au.Name.Family != "" {
parts = append(parts, au.Name.Family)
}
if len(parts) == 0 {
parts = append(parts, au.Username)
}
return strings.Join(parts, " ")
}
func getAllGroups(ctx context.Context, client *http.Client, apiURL *url.URL, envID string) ([]apiGroup, error) {
nextURL := apiURL.ResolveReference(&url.URL{
Path: fmt.Sprintf("/v1/environments/%s/groups", url.PathEscape(envID)),
}).String()
var apiGroups []apiGroup
err := batchAPIRequest(ctx, client, nextURL, func(body []byte) error {
var apiResponse struct {
Embedded struct {
Groups []apiGroup `json:"groups"`
} `json:"_embedded"`
}
err := json.Unmarshal(body, &apiResponse)
if err != nil {
return fmt.Errorf("ping: error decoding API response: %w", err)
}
apiGroups = append(apiGroups, apiResponse.Embedded.Groups...)
return nil
})
return apiGroups, err
}
func getGroupUsers(ctx context.Context, client *http.Client, apiURL *url.URL, envID, groupID string) ([]apiUser, error) {
nextURL := apiURL.ResolveReference(&url.URL{
Path: fmt.Sprintf("/v1/environments/%s/users", url.PathEscape(envID)),
RawQuery: (&url.Values{
"filter": {fmt.Sprintf(`memberOfGroups[id eq "%s"]`, groupID)},
}).Encode(),
}).String()
var apiUsers []apiUser
err := batchAPIRequest(ctx, client, nextURL, func(body []byte) error {
var apiResponse struct {
Embedded struct {
Users []apiUser `json:"users"`
} `json:"_embedded"`
}
err := json.Unmarshal(body, &apiResponse)
if err != nil {
return fmt.Errorf("ping: error decoding API response: %w", err)
}
apiUsers = append(apiUsers, apiResponse.Embedded.Users...)
return nil
})
return apiUsers, err
}
func getUser(ctx context.Context, client *http.Client, apiURL *url.URL, envID, userID string) (*apiUser, error) {
nextURL := apiURL.ResolveReference(&url.URL{
Path: fmt.Sprintf("/v1/environments/%s/users/%s", url.PathEscape(envID), url.PathEscape(userID)),
RawQuery: (&url.Values{
"include": {"memberOfGroupIDs"},
}).Encode(),
}).String()
req, err := http.NewRequestWithContext(ctx, "GET", nextURL, nil)
if err != nil {
return nil, fmt.Errorf("ping: error building API request: %w", err)
}
res, err := client.Do(req)
if err != nil {
return nil, fmt.Errorf("ping: error making API request: %w", err)
}
body, err := io.ReadAll(res.Body)
if err != nil {
return nil, fmt.Errorf("ping: error reading API response: %w", err)
}
_ = res.Body.Close()
if res.StatusCode == http.StatusNotFound {
return nil, errNotFound
} else if res.StatusCode/100 != 2 {
return nil, fmt.Errorf("ping: unexpected status code: %d", res.StatusCode)
}
var u apiUser
err = json.Unmarshal(body, &u)
if err != nil {
return nil, fmt.Errorf("ping: error decoding API response: %w", err)
}
return &u, nil
}
func batchAPIRequest(ctx context.Context, client *http.Client, nextURL string, callback func(body []byte) error) error {
for nextURL != "" {
req, err := http.NewRequestWithContext(ctx, "GET", nextURL, nil)
if err != nil {
return fmt.Errorf("ping: error building API request: %w", err)
}
res, err := client.Do(req)
if err != nil {
return fmt.Errorf("ping: error making API request: %w", err)
}
bs, err := io.ReadAll(res.Body)
if err != nil {
return fmt.Errorf("ping: error reading API response: %w", err)
}
_ = res.Body.Close()
if res.StatusCode/100 != 2 {
return fmt.Errorf("ping: unexpected status code: %d", res.StatusCode)
}
var apiResponse struct {
Links struct {
Next struct {
HREF string `json:"href"`
} `json:"next"`
} `json:"_links"`
}
err = json.Unmarshal(bs, &apiResponse)
if err != nil {
return fmt.Errorf("ping: error decoding API response: %w", err)
}
err = callback(bs)
if err != nil {
return err
}
nextURL = apiResponse.Links.Next.HREF
}
return nil
}

View file

@ -0,0 +1,111 @@
package ping
import (
"fmt"
"net/http"
"net/url"
"strings"
"github.com/pomerium/pomerium/internal/encoding"
)
type config struct {
authURL *url.URL
apiURL *url.URL
serviceAccount *ServiceAccount
httpClient *http.Client
environmentID string
}
// An Option updates the Ping configuration.
type Option func(*config)
// WithAPIURL sets the api url in the config.
func WithAPIURL(apiURL *url.URL) Option {
return func(cfg *config) {
cfg.apiURL = apiURL
}
}
// WithAuthURL sets the auth url in the config.
func WithAuthURL(authURL *url.URL) Option {
return func(cfg *config) {
cfg.authURL = authURL
}
}
// WithEnvironmentID sets the environment ID in the config.
func WithEnvironmentID(environmentID string) Option {
return func(cfg *config) {
cfg.environmentID = environmentID
}
}
// WithHTTPClient sets the http client option.
func WithHTTPClient(httpClient *http.Client) Option {
return func(cfg *config) {
cfg.httpClient = httpClient
}
}
// WithProviderURL sets the environment ID from the provider URL set in the config.
func WithProviderURL(providerURL *url.URL) Option {
// provider URL will be https://auth.pingone.com/{ENVIRONMENT_ID}/as
if providerURL == nil {
return func(cfg *config) {}
}
parts := strings.Split(providerURL.Path, "/")
if len(parts) < 1 {
return func(cfg *config) {}
}
return WithEnvironmentID(parts[1])
}
// WithServiceAccount sets the service account in the config.
func WithServiceAccount(serviceAccount *ServiceAccount) Option {
return func(cfg *config) {
cfg.serviceAccount = serviceAccount
}
}
func getConfig(options ...Option) *config {
cfg := new(config)
WithHTTPClient(http.DefaultClient)(cfg)
WithAuthURL(&url.URL{
Scheme: "https",
Host: "auth.pingone.com",
})(cfg)
WithAPIURL(&url.URL{
Scheme: "https",
Host: "api.pingone.com",
})(cfg)
for _, option := range options {
option(cfg)
}
return cfg
}
// A ServiceAccount is used by the Ping provider to query the API.
type ServiceAccount struct {
ClientID string `json:"client_id"`
ClientSecret string `json:"client_secret"`
EnvironmentID string `json:"environment_id"`
}
// ParseServiceAccount parses the service account in the config options.
func ParseServiceAccount(rawServiceAccount string) (*ServiceAccount, error) {
var serviceAccount ServiceAccount
err := encoding.DecodeBase64OrJSON(rawServiceAccount, &serviceAccount)
if err != nil {
return nil, err
}
if serviceAccount.ClientID == "" {
return nil, fmt.Errorf("client_id is required")
}
if serviceAccount.ClientSecret == "" {
return nil, fmt.Errorf("client_secret is required")
}
return &serviceAccount, nil
}

View file

@ -0,0 +1,167 @@
// Package ping implements a directory provider for Ping.
package ping
import (
"context"
"fmt"
"net/http"
"net/url"
"sort"
"sync"
"golang.org/x/oauth2"
"golang.org/x/oauth2/clientcredentials"
"github.com/pomerium/pomerium/pkg/grpc/directory"
)
// Name is the name of the Ping provider.
const Name = "ping"
// Provider implements a directory provider using the Ping API.
type Provider struct {
cfg *config
mu sync.RWMutex
token *oauth2.Token
}
// New creates a new Ping Provider.
func New(options ...Option) *Provider {
cfg := getConfig(options...)
return &Provider{
cfg: cfg,
}
}
// User returns a user's directory information.
func (p *Provider) User(ctx context.Context, userID, accessToken string) (*directory.User, error) {
client, err := p.getClient(ctx)
if err != nil {
return nil, err
}
au, err := getUser(ctx, client, p.cfg.apiURL, p.cfg.environmentID, userID)
if err != nil {
return nil, err
}
return &directory.User{
Id: au.ID,
DisplayName: au.getDisplayName(),
Email: au.Email,
GroupIds: au.MemberOfGroupIDs,
}, nil
}
// UserGroups returns all the users and groups in the directory.
func (p *Provider) UserGroups(ctx context.Context) ([]*directory.Group, []*directory.User, error) {
client, err := p.getClient(ctx)
if err != nil {
return nil, nil, err
}
apiGroups, err := getAllGroups(ctx, client, p.cfg.apiURL, p.cfg.environmentID)
if err != nil {
return nil, nil, err
}
directoryUserLookup := map[string]*directory.User{}
directoryGroups := make([]*directory.Group, len(apiGroups))
for i, ag := range apiGroups {
dg := &directory.Group{
Id: ag.ID,
Name: ag.Name,
}
apiUsers, err := getGroupUsers(ctx, client, p.cfg.apiURL, p.cfg.environmentID, ag.ID)
if err != nil {
return nil, nil, err
}
for _, au := range apiUsers {
du, ok := directoryUserLookup[au.ID]
if !ok {
du = &directory.User{
Id: au.ID,
DisplayName: au.getDisplayName(),
Email: au.Email,
}
directoryUserLookup[au.ID] = du
}
du.GroupIds = append(du.GroupIds, ag.ID)
}
directoryGroups[i] = dg
}
sort.Slice(directoryGroups, func(i, j int) bool {
return directoryGroups[i].Id < directoryGroups[j].Id
})
directoryUsers := make([]*directory.User, 0, len(directoryUserLookup))
for _, du := range directoryUserLookup {
directoryUsers = append(directoryUsers, du)
}
sort.Slice(directoryUsers, func(i, j int) bool {
return directoryUsers[i].Id < directoryUsers[j].Id
})
return directoryGroups, directoryUsers, nil
}
func (p *Provider) getClient(ctx context.Context) (*http.Client, error) {
token, err := p.getToken(ctx)
if err != nil {
return nil, err
}
client := new(http.Client)
*client = *p.cfg.httpClient
client.Transport = &oauth2.Transport{
Source: oauth2.StaticTokenSource(token),
Base: p.cfg.httpClient.Transport,
}
return client, nil
}
func (p *Provider) getToken(ctx context.Context) (*oauth2.Token, error) {
if p.cfg.serviceAccount == nil {
return nil, fmt.Errorf("ping: service account is required")
}
environmentID := p.cfg.serviceAccount.EnvironmentID
if environmentID == "" {
environmentID = p.cfg.environmentID
}
if environmentID == "" {
return nil, fmt.Errorf("ping: environment ID is required")
}
p.mu.RLock()
token := p.token
p.mu.RUnlock()
if token != nil && token.Valid() {
return token, nil
}
p.mu.Lock()
defer p.mu.Unlock()
token = p.token
if token != nil && token.Valid() {
return token, nil
}
ocfg := &clientcredentials.Config{
ClientID: p.cfg.serviceAccount.ClientID,
ClientSecret: p.cfg.serviceAccount.ClientSecret,
TokenURL: p.cfg.authURL.ResolveReference(&url.URL{
Path: fmt.Sprintf("/%s/as/token", environmentID),
}).String(),
}
var err error
p.token, err = ocfg.Token(ctx)
if err != nil {
return nil, err
}
return p.token, nil
}

View file

@ -0,0 +1,230 @@
package ping
import (
"context"
"encoding/json"
"fmt"
"net/http"
"net/http/httptest"
"net/url"
"sort"
"testing"
"time"
"github.com/go-chi/chi"
"github.com/go-chi/chi/middleware"
"github.com/stretchr/testify/require"
"github.com/pomerium/pomerium/internal/testutil"
)
type M = map[string]interface{}
func newMockAPI(userIDToGroupIDs map[string][]string) http.Handler {
lookup := map[string]struct{}{}
for _, groups := range userIDToGroupIDs {
for _, group := range groups {
lookup[group] = struct{}{}
}
}
var allGroups []string
for groupID := range lookup {
allGroups = append(allGroups, groupID)
}
sort.Strings(allGroups)
var allUserIDs []string
for userID := range userIDToGroupIDs {
allUserIDs = append(allUserIDs, userID)
}
sort.Strings(allUserIDs)
filterToUserIDs := map[string][]string{}
for userID, groupIDs := range userIDToGroupIDs {
for _, groupID := range groupIDs {
filter := fmt.Sprintf(`memberOfGroups[id eq "%s"]`, groupID)
filterToUserIDs[filter] = append(filterToUserIDs[filter], userID)
}
}
r := chi.NewRouter()
r.Use(middleware.Logger)
r.Post("/ENVIRONMENTID/as/token", func(w http.ResponseWriter, r *http.Request) {
u, p, _ := r.BasicAuth()
if u != "CLIENTID" || p != "CLIENTSECRET" {
http.Error(w, "forbidden", http.StatusForbidden)
return
}
grantType := r.FormValue("grant_type")
if grantType != "client_credentials" {
http.Error(w, "invalid grant_type", http.StatusBadRequest)
return
}
w.Header().Set("Content-Type", "application/json")
_ = json.NewEncoder(w).Encode(M{
"access_token": "ACCESSTOKEN",
"created_at": time.Now().Format(time.RFC3339),
"expires_in": 360000,
"refresh_token": "REFRESHTOKEN",
"token_type": "bearer",
})
})
r.Route("/v1/environments/ENVIRONMENTID", func(r chi.Router) {
r.Get("/groups", func(w http.ResponseWriter, r *http.Request) {
var apiGroups []apiGroup
for _, id := range allGroups {
apiGroups = append(apiGroups, apiGroup{
ID: id,
Name: "Group " + id,
})
}
w.Header().Set("Content-Type", "application/json")
_ = json.NewEncoder(w).Encode(M{
"_embedded": M{
"groups": apiGroups,
},
})
})
r.Route("/users", func(r chi.Router) {
r.Get("/{user_id}", func(w http.ResponseWriter, r *http.Request) {
userID := chi.URLParam(r, "user_id")
groupIDs, ok := userIDToGroupIDs[userID]
if !ok {
http.NotFound(w, r)
return
}
au := apiUser{
ID: userID,
Email: userID + "@example.com",
Name: apiUserName{
Given: "Given-" + userID,
Middle: "Middle-" + userID,
Family: "Family-" + userID,
},
}
if r.URL.Query().Get("include") == "memberOfGroupIDs" {
au.MemberOfGroupIDs = groupIDs
}
w.Header().Set("Content-Type", "application/json")
_ = json.NewEncoder(w).Encode(au)
})
r.Get("/", func(w http.ResponseWriter, r *http.Request) {
filter := r.URL.Query().Get("filter")
userIDs, ok := filterToUserIDs[filter]
if !ok {
http.Error(w, "expected filter", http.StatusBadRequest)
return
}
var apiUsers []apiUser
for _, id := range userIDs {
apiUsers = append(apiUsers, apiUser{
ID: id,
Email: id + "@example.com",
Name: apiUserName{
Given: "Given-" + id,
Middle: "Middle-" + id,
Family: "Family-" + id,
},
})
}
w.Header().Set("Content-Type", "application/json")
_ = json.NewEncoder(w).Encode(M{
"_embedded": M{
"users": apiUsers,
},
})
})
})
})
return r
}
func TestProvider_User(t *testing.T) {
ctx, clearTimeout := context.WithTimeout(context.Background(), time.Second*10)
defer clearTimeout()
srv := httptest.NewServer(newMockAPI(map[string][]string{
"user1": {"group1", "group2"},
"user2": {"group1", "group3"},
"user3": {"group3"},
}))
defer srv.Close()
u, err := url.Parse(srv.URL)
require.NoError(t, err)
p := New(
WithAPIURL(u),
WithAuthURL(u),
WithEnvironmentID("ENVIRONMENTID"),
WithServiceAccount(&ServiceAccount{
ClientID: "CLIENTID",
ClientSecret: "CLIENTSECRET",
}))
du, err := p.User(ctx, "user1", "")
require.NoError(t, err)
testutil.AssertProtoJSONEqual(t, `{
"id": "user1",
"email": "user1@example.com",
"displayName": "Given-user1 Middle-user1 Family-user1",
"groupIds": ["group1", "group2"]
}`, du)
}
func TestProvider_UserGroups(t *testing.T) {
ctx, clearTimeout := context.WithTimeout(context.Background(), time.Second*10)
defer clearTimeout()
srv := httptest.NewServer(newMockAPI(map[string][]string{
"user1": {"group1", "group2"},
"user2": {"group1", "group3"},
"user3": {"group3"},
}))
defer srv.Close()
u, err := url.Parse(srv.URL)
require.NoError(t, err)
p := New(
WithAPIURL(u),
WithAuthURL(u),
WithEnvironmentID("ENVIRONMENTID"),
WithServiceAccount(&ServiceAccount{
ClientID: "CLIENTID",
ClientSecret: "CLIENTSECRET",
}))
dgs, dus, err := p.UserGroups(ctx)
require.NoError(t, err)
testutil.AssertProtoJSONEqual(t, `[
{ "id": "group1", "name": "Group group1" },
{ "id": "group2", "name": "Group group2" },
{ "id": "group3", "name": "Group group3" }
]`, dgs)
testutil.AssertProtoJSONEqual(t, `[
{
"id": "user1",
"email": "user1@example.com",
"displayName": "Given-user1 Middle-user1 Family-user1",
"groupIds": ["group1", "group2"]
},
{
"id": "user2",
"email": "user2@example.com",
"displayName": "Given-user2 Middle-user2 Family-user2",
"groupIds": ["group1", "group3"]
},
{
"id": "user3",
"email": "user3@example.com",
"displayName": "Given-user3 Middle-user3 Family-user3",
"groupIds": ["group3"]
}
]`, dus)
}

View file

@ -8,6 +8,8 @@ import (
"github.com/google/go-cmp/cmp"
"github.com/pomerium/pomerium/internal/directory/ping"
"github.com/pomerium/pomerium/internal/directory/auth0"
"github.com/pomerium/pomerium/internal/directory/azure"
"github.com/pomerium/pomerium/internal/directory/github"
@ -138,6 +140,18 @@ func GetProvider(options Options) (provider Provider) {
Str("provider", options.Provider).
Err(err).
Msg("invalid service account for onelogin directory provider")
case ping.Name:
serviceAccount, err := ping.ParseServiceAccount(options.ServiceAccount)
if err == nil {
return ping.New(
ping.WithProviderURL(providerURL),
ping.WithServiceAccount(serviceAccount))
}
log.Warn().
Str("service", "directory").
Str("provider", options.Provider).
Err(err).
Msg("invalid service account for ping directory provider")
}
log.Warn().

View file

@ -0,0 +1,39 @@
// Package ping implements OpenID Connect for Ping
//
// https://www.pomerium.io/docs/identity-providers/ping.html
package ping
import (
"context"
"fmt"
"github.com/pomerium/pomerium/internal/identity/oauth"
pom_oidc "github.com/pomerium/pomerium/internal/identity/oidc"
)
const (
// Name identifies the Ping identity provider.
Name = "ping"
)
// Provider is a Ping implementation of the Authenticator interface.
type Provider struct {
*pom_oidc.Provider
}
// New instantiates an OpenID Connect (OIDC) provider for Ping.
func New(ctx context.Context, o *oauth.Options) (*Provider, error) {
var p Provider
var err error
genericOidc, err := pom_oidc.New(ctx, o)
if err != nil {
return nil, fmt.Errorf("%s: failed creating oidc provider: %w", Name, err)
}
p.Provider = genericOidc
return &p, nil
}
// Name returns the provider name.
func (p *Provider) Name() string {
return Name
}

View file

@ -20,6 +20,7 @@ import (
"github.com/pomerium/pomerium/internal/identity/oidc/google"
"github.com/pomerium/pomerium/internal/identity/oidc/okta"
"github.com/pomerium/pomerium/internal/identity/oidc/onelogin"
"github.com/pomerium/pomerium/internal/identity/oidc/ping"
)
// Authenticator is an interface representing the ability to authenticate with an identity provider.
@ -53,6 +54,8 @@ func NewAuthenticator(o oauth.Options) (a Authenticator, err error) {
a, err = okta.New(ctx, &o)
case onelogin.Name:
a, err = onelogin.New(ctx, &o)
case ping.Name:
a, err = ping.New(ctx, &o)
default:
return nil, fmt.Errorf("identity: unknown provider: %s", o.ProviderName)
}