mirror of
https://github.com/pomerium/pomerium.git
synced 2025-05-10 15:47:36 +02:00
ping: identity and directory providers (#1975)
* ping: add identity provider * ping: implement directory provider * ping, not onelogin * ping, not onelogin * escape path params
This commit is contained in:
parent
00a1cb7456
commit
fd97561ab1
7 changed files with 738 additions and 0 deletions
174
internal/directory/ping/api.go
Normal file
174
internal/directory/ping/api.go
Normal file
|
@ -0,0 +1,174 @@
|
||||||
|
package ping
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
var errNotFound = errors.New("ping: user not found")
|
||||||
|
|
||||||
|
type (
|
||||||
|
apiGroup struct {
|
||||||
|
ID string `json:"id"`
|
||||||
|
Name string `json:"name"`
|
||||||
|
}
|
||||||
|
apiUser struct {
|
||||||
|
ID string `json:"id"`
|
||||||
|
Email string `json:"email"`
|
||||||
|
Name apiUserName `json:"name"`
|
||||||
|
Username string `json:"username"`
|
||||||
|
MemberOfGroupIDs []string `json:"memberOfGroupIDs"`
|
||||||
|
}
|
||||||
|
apiUserName struct {
|
||||||
|
Given string `json:"given"`
|
||||||
|
Middle string `json:"middle"`
|
||||||
|
Family string `json:"family"`
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
func (au apiUser) getDisplayName() string {
|
||||||
|
var parts []string
|
||||||
|
if au.Name.Given != "" {
|
||||||
|
parts = append(parts, au.Name.Given)
|
||||||
|
}
|
||||||
|
if au.Name.Middle != "" {
|
||||||
|
parts = append(parts, au.Name.Middle)
|
||||||
|
}
|
||||||
|
if au.Name.Family != "" {
|
||||||
|
parts = append(parts, au.Name.Family)
|
||||||
|
}
|
||||||
|
if len(parts) == 0 {
|
||||||
|
parts = append(parts, au.Username)
|
||||||
|
}
|
||||||
|
return strings.Join(parts, " ")
|
||||||
|
}
|
||||||
|
|
||||||
|
func getAllGroups(ctx context.Context, client *http.Client, apiURL *url.URL, envID string) ([]apiGroup, error) {
|
||||||
|
nextURL := apiURL.ResolveReference(&url.URL{
|
||||||
|
Path: fmt.Sprintf("/v1/environments/%s/groups", url.PathEscape(envID)),
|
||||||
|
}).String()
|
||||||
|
|
||||||
|
var apiGroups []apiGroup
|
||||||
|
err := batchAPIRequest(ctx, client, nextURL, func(body []byte) error {
|
||||||
|
var apiResponse struct {
|
||||||
|
Embedded struct {
|
||||||
|
Groups []apiGroup `json:"groups"`
|
||||||
|
} `json:"_embedded"`
|
||||||
|
}
|
||||||
|
err := json.Unmarshal(body, &apiResponse)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("ping: error decoding API response: %w", err)
|
||||||
|
}
|
||||||
|
apiGroups = append(apiGroups, apiResponse.Embedded.Groups...)
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
return apiGroups, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func getGroupUsers(ctx context.Context, client *http.Client, apiURL *url.URL, envID, groupID string) ([]apiUser, error) {
|
||||||
|
nextURL := apiURL.ResolveReference(&url.URL{
|
||||||
|
Path: fmt.Sprintf("/v1/environments/%s/users", url.PathEscape(envID)),
|
||||||
|
RawQuery: (&url.Values{
|
||||||
|
"filter": {fmt.Sprintf(`memberOfGroups[id eq "%s"]`, groupID)},
|
||||||
|
}).Encode(),
|
||||||
|
}).String()
|
||||||
|
|
||||||
|
var apiUsers []apiUser
|
||||||
|
err := batchAPIRequest(ctx, client, nextURL, func(body []byte) error {
|
||||||
|
var apiResponse struct {
|
||||||
|
Embedded struct {
|
||||||
|
Users []apiUser `json:"users"`
|
||||||
|
} `json:"_embedded"`
|
||||||
|
}
|
||||||
|
err := json.Unmarshal(body, &apiResponse)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("ping: error decoding API response: %w", err)
|
||||||
|
}
|
||||||
|
apiUsers = append(apiUsers, apiResponse.Embedded.Users...)
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
return apiUsers, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func getUser(ctx context.Context, client *http.Client, apiURL *url.URL, envID, userID string) (*apiUser, error) {
|
||||||
|
nextURL := apiURL.ResolveReference(&url.URL{
|
||||||
|
Path: fmt.Sprintf("/v1/environments/%s/users/%s", url.PathEscape(envID), url.PathEscape(userID)),
|
||||||
|
RawQuery: (&url.Values{
|
||||||
|
"include": {"memberOfGroupIDs"},
|
||||||
|
}).Encode(),
|
||||||
|
}).String()
|
||||||
|
|
||||||
|
req, err := http.NewRequestWithContext(ctx, "GET", nextURL, nil)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("ping: error building API request: %w", err)
|
||||||
|
}
|
||||||
|
res, err := client.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("ping: error making API request: %w", err)
|
||||||
|
}
|
||||||
|
body, err := io.ReadAll(res.Body)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("ping: error reading API response: %w", err)
|
||||||
|
}
|
||||||
|
_ = res.Body.Close()
|
||||||
|
if res.StatusCode == http.StatusNotFound {
|
||||||
|
return nil, errNotFound
|
||||||
|
} else if res.StatusCode/100 != 2 {
|
||||||
|
return nil, fmt.Errorf("ping: unexpected status code: %d", res.StatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
var u apiUser
|
||||||
|
err = json.Unmarshal(body, &u)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("ping: error decoding API response: %w", err)
|
||||||
|
}
|
||||||
|
return &u, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func batchAPIRequest(ctx context.Context, client *http.Client, nextURL string, callback func(body []byte) error) error {
|
||||||
|
for nextURL != "" {
|
||||||
|
req, err := http.NewRequestWithContext(ctx, "GET", nextURL, nil)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("ping: error building API request: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
res, err := client.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("ping: error making API request: %w", err)
|
||||||
|
}
|
||||||
|
bs, err := io.ReadAll(res.Body)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("ping: error reading API response: %w", err)
|
||||||
|
}
|
||||||
|
_ = res.Body.Close()
|
||||||
|
if res.StatusCode/100 != 2 {
|
||||||
|
return fmt.Errorf("ping: unexpected status code: %d", res.StatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
var apiResponse struct {
|
||||||
|
Links struct {
|
||||||
|
Next struct {
|
||||||
|
HREF string `json:"href"`
|
||||||
|
} `json:"next"`
|
||||||
|
} `json:"_links"`
|
||||||
|
}
|
||||||
|
err = json.Unmarshal(bs, &apiResponse)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("ping: error decoding API response: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = callback(bs)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
nextURL = apiResponse.Links.Next.HREF
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
111
internal/directory/ping/config.go
Normal file
111
internal/directory/ping/config.go
Normal file
|
@ -0,0 +1,111 @@
|
||||||
|
package ping
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/pomerium/pomerium/internal/encoding"
|
||||||
|
)
|
||||||
|
|
||||||
|
type config struct {
|
||||||
|
authURL *url.URL
|
||||||
|
apiURL *url.URL
|
||||||
|
serviceAccount *ServiceAccount
|
||||||
|
httpClient *http.Client
|
||||||
|
environmentID string
|
||||||
|
}
|
||||||
|
|
||||||
|
// An Option updates the Ping configuration.
|
||||||
|
type Option func(*config)
|
||||||
|
|
||||||
|
// WithAPIURL sets the api url in the config.
|
||||||
|
func WithAPIURL(apiURL *url.URL) Option {
|
||||||
|
return func(cfg *config) {
|
||||||
|
cfg.apiURL = apiURL
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithAuthURL sets the auth url in the config.
|
||||||
|
func WithAuthURL(authURL *url.URL) Option {
|
||||||
|
return func(cfg *config) {
|
||||||
|
cfg.authURL = authURL
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithEnvironmentID sets the environment ID in the config.
|
||||||
|
func WithEnvironmentID(environmentID string) Option {
|
||||||
|
return func(cfg *config) {
|
||||||
|
cfg.environmentID = environmentID
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithHTTPClient sets the http client option.
|
||||||
|
func WithHTTPClient(httpClient *http.Client) Option {
|
||||||
|
return func(cfg *config) {
|
||||||
|
cfg.httpClient = httpClient
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithProviderURL sets the environment ID from the provider URL set in the config.
|
||||||
|
func WithProviderURL(providerURL *url.URL) Option {
|
||||||
|
// provider URL will be https://auth.pingone.com/{ENVIRONMENT_ID}/as
|
||||||
|
if providerURL == nil {
|
||||||
|
return func(cfg *config) {}
|
||||||
|
}
|
||||||
|
parts := strings.Split(providerURL.Path, "/")
|
||||||
|
if len(parts) < 1 {
|
||||||
|
return func(cfg *config) {}
|
||||||
|
}
|
||||||
|
return WithEnvironmentID(parts[1])
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithServiceAccount sets the service account in the config.
|
||||||
|
func WithServiceAccount(serviceAccount *ServiceAccount) Option {
|
||||||
|
return func(cfg *config) {
|
||||||
|
cfg.serviceAccount = serviceAccount
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func getConfig(options ...Option) *config {
|
||||||
|
cfg := new(config)
|
||||||
|
WithHTTPClient(http.DefaultClient)(cfg)
|
||||||
|
WithAuthURL(&url.URL{
|
||||||
|
Scheme: "https",
|
||||||
|
Host: "auth.pingone.com",
|
||||||
|
})(cfg)
|
||||||
|
WithAPIURL(&url.URL{
|
||||||
|
Scheme: "https",
|
||||||
|
Host: "api.pingone.com",
|
||||||
|
})(cfg)
|
||||||
|
for _, option := range options {
|
||||||
|
option(cfg)
|
||||||
|
}
|
||||||
|
return cfg
|
||||||
|
}
|
||||||
|
|
||||||
|
// A ServiceAccount is used by the Ping provider to query the API.
|
||||||
|
type ServiceAccount struct {
|
||||||
|
ClientID string `json:"client_id"`
|
||||||
|
ClientSecret string `json:"client_secret"`
|
||||||
|
EnvironmentID string `json:"environment_id"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ParseServiceAccount parses the service account in the config options.
|
||||||
|
func ParseServiceAccount(rawServiceAccount string) (*ServiceAccount, error) {
|
||||||
|
var serviceAccount ServiceAccount
|
||||||
|
err := encoding.DecodeBase64OrJSON(rawServiceAccount, &serviceAccount)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if serviceAccount.ClientID == "" {
|
||||||
|
return nil, fmt.Errorf("client_id is required")
|
||||||
|
}
|
||||||
|
if serviceAccount.ClientSecret == "" {
|
||||||
|
return nil, fmt.Errorf("client_secret is required")
|
||||||
|
}
|
||||||
|
|
||||||
|
return &serviceAccount, nil
|
||||||
|
}
|
167
internal/directory/ping/provider.go
Normal file
167
internal/directory/ping/provider.go
Normal file
|
@ -0,0 +1,167 @@
|
||||||
|
// Package ping implements a directory provider for Ping.
|
||||||
|
package ping
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
"sort"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
"golang.org/x/oauth2"
|
||||||
|
"golang.org/x/oauth2/clientcredentials"
|
||||||
|
|
||||||
|
"github.com/pomerium/pomerium/pkg/grpc/directory"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Name is the name of the Ping provider.
|
||||||
|
const Name = "ping"
|
||||||
|
|
||||||
|
// Provider implements a directory provider using the Ping API.
|
||||||
|
type Provider struct {
|
||||||
|
cfg *config
|
||||||
|
mu sync.RWMutex
|
||||||
|
token *oauth2.Token
|
||||||
|
}
|
||||||
|
|
||||||
|
// New creates a new Ping Provider.
|
||||||
|
func New(options ...Option) *Provider {
|
||||||
|
cfg := getConfig(options...)
|
||||||
|
return &Provider{
|
||||||
|
cfg: cfg,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// User returns a user's directory information.
|
||||||
|
func (p *Provider) User(ctx context.Context, userID, accessToken string) (*directory.User, error) {
|
||||||
|
client, err := p.getClient(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
au, err := getUser(ctx, client, p.cfg.apiURL, p.cfg.environmentID, userID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return &directory.User{
|
||||||
|
Id: au.ID,
|
||||||
|
DisplayName: au.getDisplayName(),
|
||||||
|
Email: au.Email,
|
||||||
|
GroupIds: au.MemberOfGroupIDs,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// UserGroups returns all the users and groups in the directory.
|
||||||
|
func (p *Provider) UserGroups(ctx context.Context) ([]*directory.Group, []*directory.User, error) {
|
||||||
|
client, err := p.getClient(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
apiGroups, err := getAllGroups(ctx, client, p.cfg.apiURL, p.cfg.environmentID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
directoryUserLookup := map[string]*directory.User{}
|
||||||
|
directoryGroups := make([]*directory.Group, len(apiGroups))
|
||||||
|
for i, ag := range apiGroups {
|
||||||
|
dg := &directory.Group{
|
||||||
|
Id: ag.ID,
|
||||||
|
Name: ag.Name,
|
||||||
|
}
|
||||||
|
|
||||||
|
apiUsers, err := getGroupUsers(ctx, client, p.cfg.apiURL, p.cfg.environmentID, ag.ID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
for _, au := range apiUsers {
|
||||||
|
du, ok := directoryUserLookup[au.ID]
|
||||||
|
if !ok {
|
||||||
|
du = &directory.User{
|
||||||
|
Id: au.ID,
|
||||||
|
DisplayName: au.getDisplayName(),
|
||||||
|
Email: au.Email,
|
||||||
|
}
|
||||||
|
directoryUserLookup[au.ID] = du
|
||||||
|
}
|
||||||
|
du.GroupIds = append(du.GroupIds, ag.ID)
|
||||||
|
}
|
||||||
|
|
||||||
|
directoryGroups[i] = dg
|
||||||
|
}
|
||||||
|
sort.Slice(directoryGroups, func(i, j int) bool {
|
||||||
|
return directoryGroups[i].Id < directoryGroups[j].Id
|
||||||
|
})
|
||||||
|
|
||||||
|
directoryUsers := make([]*directory.User, 0, len(directoryUserLookup))
|
||||||
|
for _, du := range directoryUserLookup {
|
||||||
|
directoryUsers = append(directoryUsers, du)
|
||||||
|
}
|
||||||
|
sort.Slice(directoryUsers, func(i, j int) bool {
|
||||||
|
return directoryUsers[i].Id < directoryUsers[j].Id
|
||||||
|
})
|
||||||
|
|
||||||
|
return directoryGroups, directoryUsers, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *Provider) getClient(ctx context.Context) (*http.Client, error) {
|
||||||
|
token, err := p.getToken(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
client := new(http.Client)
|
||||||
|
*client = *p.cfg.httpClient
|
||||||
|
client.Transport = &oauth2.Transport{
|
||||||
|
Source: oauth2.StaticTokenSource(token),
|
||||||
|
Base: p.cfg.httpClient.Transport,
|
||||||
|
}
|
||||||
|
return client, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *Provider) getToken(ctx context.Context) (*oauth2.Token, error) {
|
||||||
|
if p.cfg.serviceAccount == nil {
|
||||||
|
return nil, fmt.Errorf("ping: service account is required")
|
||||||
|
}
|
||||||
|
environmentID := p.cfg.serviceAccount.EnvironmentID
|
||||||
|
if environmentID == "" {
|
||||||
|
environmentID = p.cfg.environmentID
|
||||||
|
}
|
||||||
|
if environmentID == "" {
|
||||||
|
return nil, fmt.Errorf("ping: environment ID is required")
|
||||||
|
}
|
||||||
|
|
||||||
|
p.mu.RLock()
|
||||||
|
token := p.token
|
||||||
|
p.mu.RUnlock()
|
||||||
|
|
||||||
|
if token != nil && token.Valid() {
|
||||||
|
return token, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
p.mu.Lock()
|
||||||
|
defer p.mu.Unlock()
|
||||||
|
|
||||||
|
token = p.token
|
||||||
|
if token != nil && token.Valid() {
|
||||||
|
return token, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
ocfg := &clientcredentials.Config{
|
||||||
|
ClientID: p.cfg.serviceAccount.ClientID,
|
||||||
|
ClientSecret: p.cfg.serviceAccount.ClientSecret,
|
||||||
|
TokenURL: p.cfg.authURL.ResolveReference(&url.URL{
|
||||||
|
Path: fmt.Sprintf("/%s/as/token", environmentID),
|
||||||
|
}).String(),
|
||||||
|
}
|
||||||
|
var err error
|
||||||
|
p.token, err = ocfg.Token(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return p.token, nil
|
||||||
|
}
|
230
internal/directory/ping/provider_test.go
Normal file
230
internal/directory/ping/provider_test.go
Normal file
|
@ -0,0 +1,230 @@
|
||||||
|
package ping
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"net/url"
|
||||||
|
"sort"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/go-chi/chi"
|
||||||
|
"github.com/go-chi/chi/middleware"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
|
"github.com/pomerium/pomerium/internal/testutil"
|
||||||
|
)
|
||||||
|
|
||||||
|
type M = map[string]interface{}
|
||||||
|
|
||||||
|
func newMockAPI(userIDToGroupIDs map[string][]string) http.Handler {
|
||||||
|
lookup := map[string]struct{}{}
|
||||||
|
for _, groups := range userIDToGroupIDs {
|
||||||
|
for _, group := range groups {
|
||||||
|
lookup[group] = struct{}{}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
var allGroups []string
|
||||||
|
for groupID := range lookup {
|
||||||
|
allGroups = append(allGroups, groupID)
|
||||||
|
}
|
||||||
|
sort.Strings(allGroups)
|
||||||
|
|
||||||
|
var allUserIDs []string
|
||||||
|
for userID := range userIDToGroupIDs {
|
||||||
|
allUserIDs = append(allUserIDs, userID)
|
||||||
|
}
|
||||||
|
sort.Strings(allUserIDs)
|
||||||
|
|
||||||
|
filterToUserIDs := map[string][]string{}
|
||||||
|
for userID, groupIDs := range userIDToGroupIDs {
|
||||||
|
for _, groupID := range groupIDs {
|
||||||
|
filter := fmt.Sprintf(`memberOfGroups[id eq "%s"]`, groupID)
|
||||||
|
filterToUserIDs[filter] = append(filterToUserIDs[filter], userID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
r := chi.NewRouter()
|
||||||
|
r.Use(middleware.Logger)
|
||||||
|
r.Post("/ENVIRONMENTID/as/token", func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
u, p, _ := r.BasicAuth()
|
||||||
|
if u != "CLIENTID" || p != "CLIENTSECRET" {
|
||||||
|
http.Error(w, "forbidden", http.StatusForbidden)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
grantType := r.FormValue("grant_type")
|
||||||
|
if grantType != "client_credentials" {
|
||||||
|
http.Error(w, "invalid grant_type", http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
_ = json.NewEncoder(w).Encode(M{
|
||||||
|
"access_token": "ACCESSTOKEN",
|
||||||
|
"created_at": time.Now().Format(time.RFC3339),
|
||||||
|
"expires_in": 360000,
|
||||||
|
"refresh_token": "REFRESHTOKEN",
|
||||||
|
"token_type": "bearer",
|
||||||
|
})
|
||||||
|
})
|
||||||
|
r.Route("/v1/environments/ENVIRONMENTID", func(r chi.Router) {
|
||||||
|
r.Get("/groups", func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
var apiGroups []apiGroup
|
||||||
|
for _, id := range allGroups {
|
||||||
|
apiGroups = append(apiGroups, apiGroup{
|
||||||
|
ID: id,
|
||||||
|
Name: "Group " + id,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
_ = json.NewEncoder(w).Encode(M{
|
||||||
|
"_embedded": M{
|
||||||
|
"groups": apiGroups,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
})
|
||||||
|
r.Route("/users", func(r chi.Router) {
|
||||||
|
r.Get("/{user_id}", func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
userID := chi.URLParam(r, "user_id")
|
||||||
|
groupIDs, ok := userIDToGroupIDs[userID]
|
||||||
|
if !ok {
|
||||||
|
http.NotFound(w, r)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
au := apiUser{
|
||||||
|
ID: userID,
|
||||||
|
Email: userID + "@example.com",
|
||||||
|
Name: apiUserName{
|
||||||
|
Given: "Given-" + userID,
|
||||||
|
Middle: "Middle-" + userID,
|
||||||
|
Family: "Family-" + userID,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
if r.URL.Query().Get("include") == "memberOfGroupIDs" {
|
||||||
|
au.MemberOfGroupIDs = groupIDs
|
||||||
|
}
|
||||||
|
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
_ = json.NewEncoder(w).Encode(au)
|
||||||
|
})
|
||||||
|
r.Get("/", func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
filter := r.URL.Query().Get("filter")
|
||||||
|
userIDs, ok := filterToUserIDs[filter]
|
||||||
|
if !ok {
|
||||||
|
http.Error(w, "expected filter", http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var apiUsers []apiUser
|
||||||
|
for _, id := range userIDs {
|
||||||
|
apiUsers = append(apiUsers, apiUser{
|
||||||
|
ID: id,
|
||||||
|
Email: id + "@example.com",
|
||||||
|
Name: apiUserName{
|
||||||
|
Given: "Given-" + id,
|
||||||
|
Middle: "Middle-" + id,
|
||||||
|
Family: "Family-" + id,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
_ = json.NewEncoder(w).Encode(M{
|
||||||
|
"_embedded": M{
|
||||||
|
"users": apiUsers,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
})
|
||||||
|
})
|
||||||
|
})
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestProvider_User(t *testing.T) {
|
||||||
|
ctx, clearTimeout := context.WithTimeout(context.Background(), time.Second*10)
|
||||||
|
defer clearTimeout()
|
||||||
|
|
||||||
|
srv := httptest.NewServer(newMockAPI(map[string][]string{
|
||||||
|
"user1": {"group1", "group2"},
|
||||||
|
"user2": {"group1", "group3"},
|
||||||
|
"user3": {"group3"},
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
u, err := url.Parse(srv.URL)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
p := New(
|
||||||
|
WithAPIURL(u),
|
||||||
|
WithAuthURL(u),
|
||||||
|
WithEnvironmentID("ENVIRONMENTID"),
|
||||||
|
WithServiceAccount(&ServiceAccount{
|
||||||
|
ClientID: "CLIENTID",
|
||||||
|
ClientSecret: "CLIENTSECRET",
|
||||||
|
}))
|
||||||
|
du, err := p.User(ctx, "user1", "")
|
||||||
|
require.NoError(t, err)
|
||||||
|
testutil.AssertProtoJSONEqual(t, `{
|
||||||
|
"id": "user1",
|
||||||
|
"email": "user1@example.com",
|
||||||
|
"displayName": "Given-user1 Middle-user1 Family-user1",
|
||||||
|
"groupIds": ["group1", "group2"]
|
||||||
|
}`, du)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestProvider_UserGroups(t *testing.T) {
|
||||||
|
ctx, clearTimeout := context.WithTimeout(context.Background(), time.Second*10)
|
||||||
|
defer clearTimeout()
|
||||||
|
|
||||||
|
srv := httptest.NewServer(newMockAPI(map[string][]string{
|
||||||
|
"user1": {"group1", "group2"},
|
||||||
|
"user2": {"group1", "group3"},
|
||||||
|
"user3": {"group3"},
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
u, err := url.Parse(srv.URL)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
p := New(
|
||||||
|
WithAPIURL(u),
|
||||||
|
WithAuthURL(u),
|
||||||
|
WithEnvironmentID("ENVIRONMENTID"),
|
||||||
|
WithServiceAccount(&ServiceAccount{
|
||||||
|
ClientID: "CLIENTID",
|
||||||
|
ClientSecret: "CLIENTSECRET",
|
||||||
|
}))
|
||||||
|
dgs, dus, err := p.UserGroups(ctx)
|
||||||
|
require.NoError(t, err)
|
||||||
|
testutil.AssertProtoJSONEqual(t, `[
|
||||||
|
{ "id": "group1", "name": "Group group1" },
|
||||||
|
{ "id": "group2", "name": "Group group2" },
|
||||||
|
{ "id": "group3", "name": "Group group3" }
|
||||||
|
]`, dgs)
|
||||||
|
testutil.AssertProtoJSONEqual(t, `[
|
||||||
|
{
|
||||||
|
"id": "user1",
|
||||||
|
"email": "user1@example.com",
|
||||||
|
"displayName": "Given-user1 Middle-user1 Family-user1",
|
||||||
|
"groupIds": ["group1", "group2"]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "user2",
|
||||||
|
"email": "user2@example.com",
|
||||||
|
"displayName": "Given-user2 Middle-user2 Family-user2",
|
||||||
|
"groupIds": ["group1", "group3"]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "user3",
|
||||||
|
"email": "user3@example.com",
|
||||||
|
"displayName": "Given-user3 Middle-user3 Family-user3",
|
||||||
|
"groupIds": ["group3"]
|
||||||
|
}
|
||||||
|
]`, dus)
|
||||||
|
}
|
|
@ -8,6 +8,8 @@ import (
|
||||||
|
|
||||||
"github.com/google/go-cmp/cmp"
|
"github.com/google/go-cmp/cmp"
|
||||||
|
|
||||||
|
"github.com/pomerium/pomerium/internal/directory/ping"
|
||||||
|
|
||||||
"github.com/pomerium/pomerium/internal/directory/auth0"
|
"github.com/pomerium/pomerium/internal/directory/auth0"
|
||||||
"github.com/pomerium/pomerium/internal/directory/azure"
|
"github.com/pomerium/pomerium/internal/directory/azure"
|
||||||
"github.com/pomerium/pomerium/internal/directory/github"
|
"github.com/pomerium/pomerium/internal/directory/github"
|
||||||
|
@ -138,6 +140,18 @@ func GetProvider(options Options) (provider Provider) {
|
||||||
Str("provider", options.Provider).
|
Str("provider", options.Provider).
|
||||||
Err(err).
|
Err(err).
|
||||||
Msg("invalid service account for onelogin directory provider")
|
Msg("invalid service account for onelogin directory provider")
|
||||||
|
case ping.Name:
|
||||||
|
serviceAccount, err := ping.ParseServiceAccount(options.ServiceAccount)
|
||||||
|
if err == nil {
|
||||||
|
return ping.New(
|
||||||
|
ping.WithProviderURL(providerURL),
|
||||||
|
ping.WithServiceAccount(serviceAccount))
|
||||||
|
}
|
||||||
|
log.Warn().
|
||||||
|
Str("service", "directory").
|
||||||
|
Str("provider", options.Provider).
|
||||||
|
Err(err).
|
||||||
|
Msg("invalid service account for ping directory provider")
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Warn().
|
log.Warn().
|
||||||
|
|
39
internal/identity/oidc/ping/ping.go
Normal file
39
internal/identity/oidc/ping/ping.go
Normal file
|
@ -0,0 +1,39 @@
|
||||||
|
// Package ping implements OpenID Connect for Ping
|
||||||
|
//
|
||||||
|
// https://www.pomerium.io/docs/identity-providers/ping.html
|
||||||
|
package ping
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"github.com/pomerium/pomerium/internal/identity/oauth"
|
||||||
|
pom_oidc "github.com/pomerium/pomerium/internal/identity/oidc"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
// Name identifies the Ping identity provider.
|
||||||
|
Name = "ping"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Provider is a Ping implementation of the Authenticator interface.
|
||||||
|
type Provider struct {
|
||||||
|
*pom_oidc.Provider
|
||||||
|
}
|
||||||
|
|
||||||
|
// New instantiates an OpenID Connect (OIDC) provider for Ping.
|
||||||
|
func New(ctx context.Context, o *oauth.Options) (*Provider, error) {
|
||||||
|
var p Provider
|
||||||
|
var err error
|
||||||
|
genericOidc, err := pom_oidc.New(ctx, o)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("%s: failed creating oidc provider: %w", Name, err)
|
||||||
|
}
|
||||||
|
p.Provider = genericOidc
|
||||||
|
return &p, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Name returns the provider name.
|
||||||
|
func (p *Provider) Name() string {
|
||||||
|
return Name
|
||||||
|
}
|
|
@ -20,6 +20,7 @@ import (
|
||||||
"github.com/pomerium/pomerium/internal/identity/oidc/google"
|
"github.com/pomerium/pomerium/internal/identity/oidc/google"
|
||||||
"github.com/pomerium/pomerium/internal/identity/oidc/okta"
|
"github.com/pomerium/pomerium/internal/identity/oidc/okta"
|
||||||
"github.com/pomerium/pomerium/internal/identity/oidc/onelogin"
|
"github.com/pomerium/pomerium/internal/identity/oidc/onelogin"
|
||||||
|
"github.com/pomerium/pomerium/internal/identity/oidc/ping"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Authenticator is an interface representing the ability to authenticate with an identity provider.
|
// Authenticator is an interface representing the ability to authenticate with an identity provider.
|
||||||
|
@ -53,6 +54,8 @@ func NewAuthenticator(o oauth.Options) (a Authenticator, err error) {
|
||||||
a, err = okta.New(ctx, &o)
|
a, err = okta.New(ctx, &o)
|
||||||
case onelogin.Name:
|
case onelogin.Name:
|
||||||
a, err = onelogin.New(ctx, &o)
|
a, err = onelogin.New(ctx, &o)
|
||||||
|
case ping.Name:
|
||||||
|
a, err = ping.New(ctx, &o)
|
||||||
default:
|
default:
|
||||||
return nil, fmt.Errorf("identity: unknown provider: %s", o.ProviderName)
|
return nil, fmt.Errorf("identity: unknown provider: %s", o.ProviderName)
|
||||||
}
|
}
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue