mirror of
https://github.com/pomerium/pomerium.git
synced 2025-05-10 15:47:36 +02:00
ping: identity and directory providers (#1975)
* ping: add identity provider * ping: implement directory provider * ping, not onelogin * ping, not onelogin * escape path params
This commit is contained in:
parent
00a1cb7456
commit
fd97561ab1
7 changed files with 738 additions and 0 deletions
174
internal/directory/ping/api.go
Normal file
174
internal/directory/ping/api.go
Normal file
|
@ -0,0 +1,174 @@
|
|||
package ping
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
)
|
||||
|
||||
var errNotFound = errors.New("ping: user not found")
|
||||
|
||||
type (
|
||||
apiGroup struct {
|
||||
ID string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
}
|
||||
apiUser struct {
|
||||
ID string `json:"id"`
|
||||
Email string `json:"email"`
|
||||
Name apiUserName `json:"name"`
|
||||
Username string `json:"username"`
|
||||
MemberOfGroupIDs []string `json:"memberOfGroupIDs"`
|
||||
}
|
||||
apiUserName struct {
|
||||
Given string `json:"given"`
|
||||
Middle string `json:"middle"`
|
||||
Family string `json:"family"`
|
||||
}
|
||||
)
|
||||
|
||||
func (au apiUser) getDisplayName() string {
|
||||
var parts []string
|
||||
if au.Name.Given != "" {
|
||||
parts = append(parts, au.Name.Given)
|
||||
}
|
||||
if au.Name.Middle != "" {
|
||||
parts = append(parts, au.Name.Middle)
|
||||
}
|
||||
if au.Name.Family != "" {
|
||||
parts = append(parts, au.Name.Family)
|
||||
}
|
||||
if len(parts) == 0 {
|
||||
parts = append(parts, au.Username)
|
||||
}
|
||||
return strings.Join(parts, " ")
|
||||
}
|
||||
|
||||
func getAllGroups(ctx context.Context, client *http.Client, apiURL *url.URL, envID string) ([]apiGroup, error) {
|
||||
nextURL := apiURL.ResolveReference(&url.URL{
|
||||
Path: fmt.Sprintf("/v1/environments/%s/groups", url.PathEscape(envID)),
|
||||
}).String()
|
||||
|
||||
var apiGroups []apiGroup
|
||||
err := batchAPIRequest(ctx, client, nextURL, func(body []byte) error {
|
||||
var apiResponse struct {
|
||||
Embedded struct {
|
||||
Groups []apiGroup `json:"groups"`
|
||||
} `json:"_embedded"`
|
||||
}
|
||||
err := json.Unmarshal(body, &apiResponse)
|
||||
if err != nil {
|
||||
return fmt.Errorf("ping: error decoding API response: %w", err)
|
||||
}
|
||||
apiGroups = append(apiGroups, apiResponse.Embedded.Groups...)
|
||||
return nil
|
||||
})
|
||||
return apiGroups, err
|
||||
}
|
||||
|
||||
func getGroupUsers(ctx context.Context, client *http.Client, apiURL *url.URL, envID, groupID string) ([]apiUser, error) {
|
||||
nextURL := apiURL.ResolveReference(&url.URL{
|
||||
Path: fmt.Sprintf("/v1/environments/%s/users", url.PathEscape(envID)),
|
||||
RawQuery: (&url.Values{
|
||||
"filter": {fmt.Sprintf(`memberOfGroups[id eq "%s"]`, groupID)},
|
||||
}).Encode(),
|
||||
}).String()
|
||||
|
||||
var apiUsers []apiUser
|
||||
err := batchAPIRequest(ctx, client, nextURL, func(body []byte) error {
|
||||
var apiResponse struct {
|
||||
Embedded struct {
|
||||
Users []apiUser `json:"users"`
|
||||
} `json:"_embedded"`
|
||||
}
|
||||
err := json.Unmarshal(body, &apiResponse)
|
||||
if err != nil {
|
||||
return fmt.Errorf("ping: error decoding API response: %w", err)
|
||||
}
|
||||
apiUsers = append(apiUsers, apiResponse.Embedded.Users...)
|
||||
return nil
|
||||
})
|
||||
return apiUsers, err
|
||||
}
|
||||
|
||||
func getUser(ctx context.Context, client *http.Client, apiURL *url.URL, envID, userID string) (*apiUser, error) {
|
||||
nextURL := apiURL.ResolveReference(&url.URL{
|
||||
Path: fmt.Sprintf("/v1/environments/%s/users/%s", url.PathEscape(envID), url.PathEscape(userID)),
|
||||
RawQuery: (&url.Values{
|
||||
"include": {"memberOfGroupIDs"},
|
||||
}).Encode(),
|
||||
}).String()
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", nextURL, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("ping: error building API request: %w", err)
|
||||
}
|
||||
res, err := client.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("ping: error making API request: %w", err)
|
||||
}
|
||||
body, err := io.ReadAll(res.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("ping: error reading API response: %w", err)
|
||||
}
|
||||
_ = res.Body.Close()
|
||||
if res.StatusCode == http.StatusNotFound {
|
||||
return nil, errNotFound
|
||||
} else if res.StatusCode/100 != 2 {
|
||||
return nil, fmt.Errorf("ping: unexpected status code: %d", res.StatusCode)
|
||||
}
|
||||
|
||||
var u apiUser
|
||||
err = json.Unmarshal(body, &u)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("ping: error decoding API response: %w", err)
|
||||
}
|
||||
return &u, nil
|
||||
}
|
||||
|
||||
func batchAPIRequest(ctx context.Context, client *http.Client, nextURL string, callback func(body []byte) error) error {
|
||||
for nextURL != "" {
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", nextURL, nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("ping: error building API request: %w", err)
|
||||
}
|
||||
|
||||
res, err := client.Do(req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("ping: error making API request: %w", err)
|
||||
}
|
||||
bs, err := io.ReadAll(res.Body)
|
||||
if err != nil {
|
||||
return fmt.Errorf("ping: error reading API response: %w", err)
|
||||
}
|
||||
_ = res.Body.Close()
|
||||
if res.StatusCode/100 != 2 {
|
||||
return fmt.Errorf("ping: unexpected status code: %d", res.StatusCode)
|
||||
}
|
||||
|
||||
var apiResponse struct {
|
||||
Links struct {
|
||||
Next struct {
|
||||
HREF string `json:"href"`
|
||||
} `json:"next"`
|
||||
} `json:"_links"`
|
||||
}
|
||||
err = json.Unmarshal(bs, &apiResponse)
|
||||
if err != nil {
|
||||
return fmt.Errorf("ping: error decoding API response: %w", err)
|
||||
}
|
||||
|
||||
err = callback(bs)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
nextURL = apiResponse.Links.Next.HREF
|
||||
}
|
||||
return nil
|
||||
}
|
111
internal/directory/ping/config.go
Normal file
111
internal/directory/ping/config.go
Normal file
|
@ -0,0 +1,111 @@
|
|||
package ping
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/encoding"
|
||||
)
|
||||
|
||||
type config struct {
|
||||
authURL *url.URL
|
||||
apiURL *url.URL
|
||||
serviceAccount *ServiceAccount
|
||||
httpClient *http.Client
|
||||
environmentID string
|
||||
}
|
||||
|
||||
// An Option updates the Ping configuration.
|
||||
type Option func(*config)
|
||||
|
||||
// WithAPIURL sets the api url in the config.
|
||||
func WithAPIURL(apiURL *url.URL) Option {
|
||||
return func(cfg *config) {
|
||||
cfg.apiURL = apiURL
|
||||
}
|
||||
}
|
||||
|
||||
// WithAuthURL sets the auth url in the config.
|
||||
func WithAuthURL(authURL *url.URL) Option {
|
||||
return func(cfg *config) {
|
||||
cfg.authURL = authURL
|
||||
}
|
||||
}
|
||||
|
||||
// WithEnvironmentID sets the environment ID in the config.
|
||||
func WithEnvironmentID(environmentID string) Option {
|
||||
return func(cfg *config) {
|
||||
cfg.environmentID = environmentID
|
||||
}
|
||||
}
|
||||
|
||||
// WithHTTPClient sets the http client option.
|
||||
func WithHTTPClient(httpClient *http.Client) Option {
|
||||
return func(cfg *config) {
|
||||
cfg.httpClient = httpClient
|
||||
}
|
||||
}
|
||||
|
||||
// WithProviderURL sets the environment ID from the provider URL set in the config.
|
||||
func WithProviderURL(providerURL *url.URL) Option {
|
||||
// provider URL will be https://auth.pingone.com/{ENVIRONMENT_ID}/as
|
||||
if providerURL == nil {
|
||||
return func(cfg *config) {}
|
||||
}
|
||||
parts := strings.Split(providerURL.Path, "/")
|
||||
if len(parts) < 1 {
|
||||
return func(cfg *config) {}
|
||||
}
|
||||
return WithEnvironmentID(parts[1])
|
||||
}
|
||||
|
||||
// WithServiceAccount sets the service account in the config.
|
||||
func WithServiceAccount(serviceAccount *ServiceAccount) Option {
|
||||
return func(cfg *config) {
|
||||
cfg.serviceAccount = serviceAccount
|
||||
}
|
||||
}
|
||||
|
||||
func getConfig(options ...Option) *config {
|
||||
cfg := new(config)
|
||||
WithHTTPClient(http.DefaultClient)(cfg)
|
||||
WithAuthURL(&url.URL{
|
||||
Scheme: "https",
|
||||
Host: "auth.pingone.com",
|
||||
})(cfg)
|
||||
WithAPIURL(&url.URL{
|
||||
Scheme: "https",
|
||||
Host: "api.pingone.com",
|
||||
})(cfg)
|
||||
for _, option := range options {
|
||||
option(cfg)
|
||||
}
|
||||
return cfg
|
||||
}
|
||||
|
||||
// A ServiceAccount is used by the Ping provider to query the API.
|
||||
type ServiceAccount struct {
|
||||
ClientID string `json:"client_id"`
|
||||
ClientSecret string `json:"client_secret"`
|
||||
EnvironmentID string `json:"environment_id"`
|
||||
}
|
||||
|
||||
// ParseServiceAccount parses the service account in the config options.
|
||||
func ParseServiceAccount(rawServiceAccount string) (*ServiceAccount, error) {
|
||||
var serviceAccount ServiceAccount
|
||||
err := encoding.DecodeBase64OrJSON(rawServiceAccount, &serviceAccount)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if serviceAccount.ClientID == "" {
|
||||
return nil, fmt.Errorf("client_id is required")
|
||||
}
|
||||
if serviceAccount.ClientSecret == "" {
|
||||
return nil, fmt.Errorf("client_secret is required")
|
||||
}
|
||||
|
||||
return &serviceAccount, nil
|
||||
}
|
167
internal/directory/ping/provider.go
Normal file
167
internal/directory/ping/provider.go
Normal file
|
@ -0,0 +1,167 @@
|
|||
// Package ping implements a directory provider for Ping.
|
||||
package ping
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"sort"
|
||||
"sync"
|
||||
|
||||
"golang.org/x/oauth2"
|
||||
"golang.org/x/oauth2/clientcredentials"
|
||||
|
||||
"github.com/pomerium/pomerium/pkg/grpc/directory"
|
||||
)
|
||||
|
||||
// Name is the name of the Ping provider.
|
||||
const Name = "ping"
|
||||
|
||||
// Provider implements a directory provider using the Ping API.
|
||||
type Provider struct {
|
||||
cfg *config
|
||||
mu sync.RWMutex
|
||||
token *oauth2.Token
|
||||
}
|
||||
|
||||
// New creates a new Ping Provider.
|
||||
func New(options ...Option) *Provider {
|
||||
cfg := getConfig(options...)
|
||||
return &Provider{
|
||||
cfg: cfg,
|
||||
}
|
||||
}
|
||||
|
||||
// User returns a user's directory information.
|
||||
func (p *Provider) User(ctx context.Context, userID, accessToken string) (*directory.User, error) {
|
||||
client, err := p.getClient(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
au, err := getUser(ctx, client, p.cfg.apiURL, p.cfg.environmentID, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &directory.User{
|
||||
Id: au.ID,
|
||||
DisplayName: au.getDisplayName(),
|
||||
Email: au.Email,
|
||||
GroupIds: au.MemberOfGroupIDs,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// UserGroups returns all the users and groups in the directory.
|
||||
func (p *Provider) UserGroups(ctx context.Context) ([]*directory.Group, []*directory.User, error) {
|
||||
client, err := p.getClient(ctx)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
apiGroups, err := getAllGroups(ctx, client, p.cfg.apiURL, p.cfg.environmentID)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
directoryUserLookup := map[string]*directory.User{}
|
||||
directoryGroups := make([]*directory.Group, len(apiGroups))
|
||||
for i, ag := range apiGroups {
|
||||
dg := &directory.Group{
|
||||
Id: ag.ID,
|
||||
Name: ag.Name,
|
||||
}
|
||||
|
||||
apiUsers, err := getGroupUsers(ctx, client, p.cfg.apiURL, p.cfg.environmentID, ag.ID)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
for _, au := range apiUsers {
|
||||
du, ok := directoryUserLookup[au.ID]
|
||||
if !ok {
|
||||
du = &directory.User{
|
||||
Id: au.ID,
|
||||
DisplayName: au.getDisplayName(),
|
||||
Email: au.Email,
|
||||
}
|
||||
directoryUserLookup[au.ID] = du
|
||||
}
|
||||
du.GroupIds = append(du.GroupIds, ag.ID)
|
||||
}
|
||||
|
||||
directoryGroups[i] = dg
|
||||
}
|
||||
sort.Slice(directoryGroups, func(i, j int) bool {
|
||||
return directoryGroups[i].Id < directoryGroups[j].Id
|
||||
})
|
||||
|
||||
directoryUsers := make([]*directory.User, 0, len(directoryUserLookup))
|
||||
for _, du := range directoryUserLookup {
|
||||
directoryUsers = append(directoryUsers, du)
|
||||
}
|
||||
sort.Slice(directoryUsers, func(i, j int) bool {
|
||||
return directoryUsers[i].Id < directoryUsers[j].Id
|
||||
})
|
||||
|
||||
return directoryGroups, directoryUsers, nil
|
||||
}
|
||||
|
||||
func (p *Provider) getClient(ctx context.Context) (*http.Client, error) {
|
||||
token, err := p.getToken(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
client := new(http.Client)
|
||||
*client = *p.cfg.httpClient
|
||||
client.Transport = &oauth2.Transport{
|
||||
Source: oauth2.StaticTokenSource(token),
|
||||
Base: p.cfg.httpClient.Transport,
|
||||
}
|
||||
return client, nil
|
||||
}
|
||||
|
||||
func (p *Provider) getToken(ctx context.Context) (*oauth2.Token, error) {
|
||||
if p.cfg.serviceAccount == nil {
|
||||
return nil, fmt.Errorf("ping: service account is required")
|
||||
}
|
||||
environmentID := p.cfg.serviceAccount.EnvironmentID
|
||||
if environmentID == "" {
|
||||
environmentID = p.cfg.environmentID
|
||||
}
|
||||
if environmentID == "" {
|
||||
return nil, fmt.Errorf("ping: environment ID is required")
|
||||
}
|
||||
|
||||
p.mu.RLock()
|
||||
token := p.token
|
||||
p.mu.RUnlock()
|
||||
|
||||
if token != nil && token.Valid() {
|
||||
return token, nil
|
||||
}
|
||||
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
|
||||
token = p.token
|
||||
if token != nil && token.Valid() {
|
||||
return token, nil
|
||||
}
|
||||
|
||||
ocfg := &clientcredentials.Config{
|
||||
ClientID: p.cfg.serviceAccount.ClientID,
|
||||
ClientSecret: p.cfg.serviceAccount.ClientSecret,
|
||||
TokenURL: p.cfg.authURL.ResolveReference(&url.URL{
|
||||
Path: fmt.Sprintf("/%s/as/token", environmentID),
|
||||
}).String(),
|
||||
}
|
||||
var err error
|
||||
p.token, err = ocfg.Token(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return p.token, nil
|
||||
}
|
230
internal/directory/ping/provider_test.go
Normal file
230
internal/directory/ping/provider_test.go
Normal file
|
@ -0,0 +1,230 @@
|
|||
package ping
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"sort"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/go-chi/chi"
|
||||
"github.com/go-chi/chi/middleware"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/testutil"
|
||||
)
|
||||
|
||||
type M = map[string]interface{}
|
||||
|
||||
func newMockAPI(userIDToGroupIDs map[string][]string) http.Handler {
|
||||
lookup := map[string]struct{}{}
|
||||
for _, groups := range userIDToGroupIDs {
|
||||
for _, group := range groups {
|
||||
lookup[group] = struct{}{}
|
||||
}
|
||||
}
|
||||
var allGroups []string
|
||||
for groupID := range lookup {
|
||||
allGroups = append(allGroups, groupID)
|
||||
}
|
||||
sort.Strings(allGroups)
|
||||
|
||||
var allUserIDs []string
|
||||
for userID := range userIDToGroupIDs {
|
||||
allUserIDs = append(allUserIDs, userID)
|
||||
}
|
||||
sort.Strings(allUserIDs)
|
||||
|
||||
filterToUserIDs := map[string][]string{}
|
||||
for userID, groupIDs := range userIDToGroupIDs {
|
||||
for _, groupID := range groupIDs {
|
||||
filter := fmt.Sprintf(`memberOfGroups[id eq "%s"]`, groupID)
|
||||
filterToUserIDs[filter] = append(filterToUserIDs[filter], userID)
|
||||
}
|
||||
}
|
||||
|
||||
r := chi.NewRouter()
|
||||
r.Use(middleware.Logger)
|
||||
r.Post("/ENVIRONMENTID/as/token", func(w http.ResponseWriter, r *http.Request) {
|
||||
u, p, _ := r.BasicAuth()
|
||||
if u != "CLIENTID" || p != "CLIENTSECRET" {
|
||||
http.Error(w, "forbidden", http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
|
||||
grantType := r.FormValue("grant_type")
|
||||
if grantType != "client_credentials" {
|
||||
http.Error(w, "invalid grant_type", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_ = json.NewEncoder(w).Encode(M{
|
||||
"access_token": "ACCESSTOKEN",
|
||||
"created_at": time.Now().Format(time.RFC3339),
|
||||
"expires_in": 360000,
|
||||
"refresh_token": "REFRESHTOKEN",
|
||||
"token_type": "bearer",
|
||||
})
|
||||
})
|
||||
r.Route("/v1/environments/ENVIRONMENTID", func(r chi.Router) {
|
||||
r.Get("/groups", func(w http.ResponseWriter, r *http.Request) {
|
||||
var apiGroups []apiGroup
|
||||
for _, id := range allGroups {
|
||||
apiGroups = append(apiGroups, apiGroup{
|
||||
ID: id,
|
||||
Name: "Group " + id,
|
||||
})
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_ = json.NewEncoder(w).Encode(M{
|
||||
"_embedded": M{
|
||||
"groups": apiGroups,
|
||||
},
|
||||
})
|
||||
})
|
||||
r.Route("/users", func(r chi.Router) {
|
||||
r.Get("/{user_id}", func(w http.ResponseWriter, r *http.Request) {
|
||||
userID := chi.URLParam(r, "user_id")
|
||||
groupIDs, ok := userIDToGroupIDs[userID]
|
||||
if !ok {
|
||||
http.NotFound(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
au := apiUser{
|
||||
ID: userID,
|
||||
Email: userID + "@example.com",
|
||||
Name: apiUserName{
|
||||
Given: "Given-" + userID,
|
||||
Middle: "Middle-" + userID,
|
||||
Family: "Family-" + userID,
|
||||
},
|
||||
}
|
||||
if r.URL.Query().Get("include") == "memberOfGroupIDs" {
|
||||
au.MemberOfGroupIDs = groupIDs
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_ = json.NewEncoder(w).Encode(au)
|
||||
})
|
||||
r.Get("/", func(w http.ResponseWriter, r *http.Request) {
|
||||
filter := r.URL.Query().Get("filter")
|
||||
userIDs, ok := filterToUserIDs[filter]
|
||||
if !ok {
|
||||
http.Error(w, "expected filter", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
var apiUsers []apiUser
|
||||
for _, id := range userIDs {
|
||||
apiUsers = append(apiUsers, apiUser{
|
||||
ID: id,
|
||||
Email: id + "@example.com",
|
||||
Name: apiUserName{
|
||||
Given: "Given-" + id,
|
||||
Middle: "Middle-" + id,
|
||||
Family: "Family-" + id,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_ = json.NewEncoder(w).Encode(M{
|
||||
"_embedded": M{
|
||||
"users": apiUsers,
|
||||
},
|
||||
})
|
||||
})
|
||||
})
|
||||
})
|
||||
return r
|
||||
}
|
||||
|
||||
func TestProvider_User(t *testing.T) {
|
||||
ctx, clearTimeout := context.WithTimeout(context.Background(), time.Second*10)
|
||||
defer clearTimeout()
|
||||
|
||||
srv := httptest.NewServer(newMockAPI(map[string][]string{
|
||||
"user1": {"group1", "group2"},
|
||||
"user2": {"group1", "group3"},
|
||||
"user3": {"group3"},
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
u, err := url.Parse(srv.URL)
|
||||
require.NoError(t, err)
|
||||
|
||||
p := New(
|
||||
WithAPIURL(u),
|
||||
WithAuthURL(u),
|
||||
WithEnvironmentID("ENVIRONMENTID"),
|
||||
WithServiceAccount(&ServiceAccount{
|
||||
ClientID: "CLIENTID",
|
||||
ClientSecret: "CLIENTSECRET",
|
||||
}))
|
||||
du, err := p.User(ctx, "user1", "")
|
||||
require.NoError(t, err)
|
||||
testutil.AssertProtoJSONEqual(t, `{
|
||||
"id": "user1",
|
||||
"email": "user1@example.com",
|
||||
"displayName": "Given-user1 Middle-user1 Family-user1",
|
||||
"groupIds": ["group1", "group2"]
|
||||
}`, du)
|
||||
}
|
||||
|
||||
func TestProvider_UserGroups(t *testing.T) {
|
||||
ctx, clearTimeout := context.WithTimeout(context.Background(), time.Second*10)
|
||||
defer clearTimeout()
|
||||
|
||||
srv := httptest.NewServer(newMockAPI(map[string][]string{
|
||||
"user1": {"group1", "group2"},
|
||||
"user2": {"group1", "group3"},
|
||||
"user3": {"group3"},
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
u, err := url.Parse(srv.URL)
|
||||
require.NoError(t, err)
|
||||
|
||||
p := New(
|
||||
WithAPIURL(u),
|
||||
WithAuthURL(u),
|
||||
WithEnvironmentID("ENVIRONMENTID"),
|
||||
WithServiceAccount(&ServiceAccount{
|
||||
ClientID: "CLIENTID",
|
||||
ClientSecret: "CLIENTSECRET",
|
||||
}))
|
||||
dgs, dus, err := p.UserGroups(ctx)
|
||||
require.NoError(t, err)
|
||||
testutil.AssertProtoJSONEqual(t, `[
|
||||
{ "id": "group1", "name": "Group group1" },
|
||||
{ "id": "group2", "name": "Group group2" },
|
||||
{ "id": "group3", "name": "Group group3" }
|
||||
]`, dgs)
|
||||
testutil.AssertProtoJSONEqual(t, `[
|
||||
{
|
||||
"id": "user1",
|
||||
"email": "user1@example.com",
|
||||
"displayName": "Given-user1 Middle-user1 Family-user1",
|
||||
"groupIds": ["group1", "group2"]
|
||||
},
|
||||
{
|
||||
"id": "user2",
|
||||
"email": "user2@example.com",
|
||||
"displayName": "Given-user2 Middle-user2 Family-user2",
|
||||
"groupIds": ["group1", "group3"]
|
||||
},
|
||||
{
|
||||
"id": "user3",
|
||||
"email": "user3@example.com",
|
||||
"displayName": "Given-user3 Middle-user3 Family-user3",
|
||||
"groupIds": ["group3"]
|
||||
}
|
||||
]`, dus)
|
||||
}
|
|
@ -8,6 +8,8 @@ import (
|
|||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/directory/ping"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/directory/auth0"
|
||||
"github.com/pomerium/pomerium/internal/directory/azure"
|
||||
"github.com/pomerium/pomerium/internal/directory/github"
|
||||
|
@ -138,6 +140,18 @@ func GetProvider(options Options) (provider Provider) {
|
|||
Str("provider", options.Provider).
|
||||
Err(err).
|
||||
Msg("invalid service account for onelogin directory provider")
|
||||
case ping.Name:
|
||||
serviceAccount, err := ping.ParseServiceAccount(options.ServiceAccount)
|
||||
if err == nil {
|
||||
return ping.New(
|
||||
ping.WithProviderURL(providerURL),
|
||||
ping.WithServiceAccount(serviceAccount))
|
||||
}
|
||||
log.Warn().
|
||||
Str("service", "directory").
|
||||
Str("provider", options.Provider).
|
||||
Err(err).
|
||||
Msg("invalid service account for ping directory provider")
|
||||
}
|
||||
|
||||
log.Warn().
|
||||
|
|
39
internal/identity/oidc/ping/ping.go
Normal file
39
internal/identity/oidc/ping/ping.go
Normal file
|
@ -0,0 +1,39 @@
|
|||
// Package ping implements OpenID Connect for Ping
|
||||
//
|
||||
// https://www.pomerium.io/docs/identity-providers/ping.html
|
||||
package ping
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/identity/oauth"
|
||||
pom_oidc "github.com/pomerium/pomerium/internal/identity/oidc"
|
||||
)
|
||||
|
||||
const (
|
||||
// Name identifies the Ping identity provider.
|
||||
Name = "ping"
|
||||
)
|
||||
|
||||
// Provider is a Ping implementation of the Authenticator interface.
|
||||
type Provider struct {
|
||||
*pom_oidc.Provider
|
||||
}
|
||||
|
||||
// New instantiates an OpenID Connect (OIDC) provider for Ping.
|
||||
func New(ctx context.Context, o *oauth.Options) (*Provider, error) {
|
||||
var p Provider
|
||||
var err error
|
||||
genericOidc, err := pom_oidc.New(ctx, o)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("%s: failed creating oidc provider: %w", Name, err)
|
||||
}
|
||||
p.Provider = genericOidc
|
||||
return &p, nil
|
||||
}
|
||||
|
||||
// Name returns the provider name.
|
||||
func (p *Provider) Name() string {
|
||||
return Name
|
||||
}
|
|
@ -20,6 +20,7 @@ import (
|
|||
"github.com/pomerium/pomerium/internal/identity/oidc/google"
|
||||
"github.com/pomerium/pomerium/internal/identity/oidc/okta"
|
||||
"github.com/pomerium/pomerium/internal/identity/oidc/onelogin"
|
||||
"github.com/pomerium/pomerium/internal/identity/oidc/ping"
|
||||
)
|
||||
|
||||
// Authenticator is an interface representing the ability to authenticate with an identity provider.
|
||||
|
@ -53,6 +54,8 @@ func NewAuthenticator(o oauth.Options) (a Authenticator, err error) {
|
|||
a, err = okta.New(ctx, &o)
|
||||
case onelogin.Name:
|
||||
a, err = onelogin.New(ctx, &o)
|
||||
case ping.Name:
|
||||
a, err = ping.New(ctx, &o)
|
||||
default:
|
||||
return nil, fmt.Errorf("identity: unknown provider: %s", o.ProviderName)
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue