mirror of
https://github.com/pomerium/pomerium.git
synced 2025-08-06 10:21:05 +02:00
This commit is contained in:
parent
95f9e94bea
commit
26c05e5436
5 changed files with 108 additions and 5 deletions
|
@ -623,8 +623,9 @@ func (o *Options) Validate() error {
|
||||||
}
|
}
|
||||||
|
|
||||||
// if no service account was defined, there should not be any policies that
|
// if no service account was defined, there should not be any policies that
|
||||||
// assert group membership
|
// assert group membership (except for azure which can be derived from the client
|
||||||
if o.ServiceAccount == "" {
|
// id, secret and provider url)
|
||||||
|
if o.ServiceAccount == "" && o.Provider != "azure" {
|
||||||
for _, p := range o.Policies {
|
for _, p := range o.Policies {
|
||||||
if len(p.AllowedGroups) != 0 {
|
if len(p.AllowedGroups) != 0 {
|
||||||
return fmt.Errorf("config: `allowed_groups` requires `idp_service_account`")
|
return fmt.Errorf("config: `allowed_groups` requires `idp_service_account`")
|
||||||
|
|
|
@ -282,7 +282,39 @@ type ServiceAccount struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
// ParseServiceAccount parses the service account in the config options.
|
// ParseServiceAccount parses the service account in the config options.
|
||||||
func ParseServiceAccount(rawServiceAccount string) (*ServiceAccount, error) {
|
func ParseServiceAccount(options directory.Options) (*ServiceAccount, error) {
|
||||||
|
if options.ServiceAccount != "" {
|
||||||
|
return parseServiceAccountFromString(options.ServiceAccount)
|
||||||
|
}
|
||||||
|
return parseServiceAccountFromOptions(options.ClientID, options.ClientSecret, options.ProviderURL)
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseServiceAccountFromOptions(clientID, clientSecret, providerURL string) (*ServiceAccount, error) {
|
||||||
|
serviceAccount := ServiceAccount{
|
||||||
|
ClientID: clientID,
|
||||||
|
ClientSecret: clientSecret,
|
||||||
|
}
|
||||||
|
|
||||||
|
var err error
|
||||||
|
serviceAccount.DirectoryID, err = parseDirectoryIDFromURL(providerURL)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if serviceAccount.ClientID == "" {
|
||||||
|
return nil, fmt.Errorf("client_id is required")
|
||||||
|
}
|
||||||
|
if serviceAccount.ClientSecret == "" {
|
||||||
|
return nil, fmt.Errorf("client_secret is required")
|
||||||
|
}
|
||||||
|
if serviceAccount.DirectoryID == "" {
|
||||||
|
return nil, fmt.Errorf("directory_id is required")
|
||||||
|
}
|
||||||
|
|
||||||
|
return &serviceAccount, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseServiceAccountFromString(rawServiceAccount string) (*ServiceAccount, error) {
|
||||||
bs, err := base64.StdEncoding.DecodeString(rawServiceAccount)
|
bs, err := base64.StdEncoding.DecodeString(rawServiceAccount)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -306,3 +338,17 @@ func ParseServiceAccount(rawServiceAccount string) (*ServiceAccount, error) {
|
||||||
|
|
||||||
return &serviceAccount, nil
|
return &serviceAccount, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func parseDirectoryIDFromURL(providerURL string) (string, error) {
|
||||||
|
u, err := url.Parse(providerURL)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
pathParts := strings.SplitN(u.Path, "/", 3)
|
||||||
|
if len(pathParts) != 3 {
|
||||||
|
return "", fmt.Errorf("no directory id found in path")
|
||||||
|
}
|
||||||
|
|
||||||
|
return pathParts[1], nil
|
||||||
|
}
|
||||||
|
|
|
@ -2,6 +2,7 @@ package azure
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"encoding/base64"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
|
@ -107,6 +108,43 @@ func Test(t *testing.T) {
|
||||||
}, groups)
|
}, groups)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestParseServiceAccount(t *testing.T) {
|
||||||
|
t.Run("by options", func(t *testing.T) {
|
||||||
|
serviceAccount, err := ParseServiceAccount(directory.Options{
|
||||||
|
ProviderURL: "https://login.microsoftonline.com/0303f438-3c5c-4190-9854-08d3eb31bd9f/v2.0",
|
||||||
|
ClientID: "CLIENT_ID",
|
||||||
|
ClientSecret: "CLIENT_SECRET",
|
||||||
|
})
|
||||||
|
if !assert.NoError(t, err) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.Equal(t, &ServiceAccount{
|
||||||
|
ClientID: "CLIENT_ID",
|
||||||
|
ClientSecret: "CLIENT_SECRET",
|
||||||
|
DirectoryID: "0303f438-3c5c-4190-9854-08d3eb31bd9f",
|
||||||
|
}, serviceAccount)
|
||||||
|
})
|
||||||
|
t.Run("by service account", func(t *testing.T) {
|
||||||
|
serviceAccount, err := ParseServiceAccount(directory.Options{
|
||||||
|
ServiceAccount: base64.StdEncoding.EncodeToString([]byte(`{
|
||||||
|
"client_id": "CLIENT_ID",
|
||||||
|
"client_secret": "CLIENT_SECRET",
|
||||||
|
"directory_id": "0303f438-3c5c-4190-9854-08d3eb31bd9f"
|
||||||
|
}`)),
|
||||||
|
})
|
||||||
|
if !assert.NoError(t, err) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.Equal(t, &ServiceAccount{
|
||||||
|
ClientID: "CLIENT_ID",
|
||||||
|
ClientSecret: "CLIENT_SECRET",
|
||||||
|
DirectoryID: "0303f438-3c5c-4190-9854-08d3eb31bd9f",
|
||||||
|
}, serviceAccount)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
func mustParseURL(rawurl string) *url.URL {
|
func mustParseURL(rawurl string) *url.URL {
|
||||||
u, err := url.Parse(rawurl)
|
u, err := url.Parse(rawurl)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
|
@ -22,6 +22,9 @@ type Group = directory.Group
|
||||||
// A User is a directory User.
|
// A User is a directory User.
|
||||||
type User = directory.User
|
type User = directory.User
|
||||||
|
|
||||||
|
// Options are the options specific to the provider.
|
||||||
|
type Options = directory.Options
|
||||||
|
|
||||||
// A Provider provides user group directory information.
|
// A Provider provides user group directory information.
|
||||||
type Provider interface {
|
type Provider interface {
|
||||||
UserGroups(ctx context.Context) ([]*Group, []*User, error)
|
UserGroups(ctx context.Context) ([]*Group, []*User, error)
|
||||||
|
@ -31,11 +34,16 @@ type Provider interface {
|
||||||
func GetProvider(options *config.Options) Provider {
|
func GetProvider(options *config.Options) Provider {
|
||||||
switch options.Provider {
|
switch options.Provider {
|
||||||
case azure.Name:
|
case azure.Name:
|
||||||
serviceAccount, err := azure.ParseServiceAccount(options.ServiceAccount)
|
serviceAccount, err := azure.ParseServiceAccount(directory.Options{
|
||||||
|
ServiceAccount: options.ServiceAccount,
|
||||||
|
Provider: options.Provider,
|
||||||
|
ProviderURL: options.ProviderURL,
|
||||||
|
ClientID: options.ClientID,
|
||||||
|
ClientSecret: options.ClientSecret,
|
||||||
|
})
|
||||||
if err == nil {
|
if err == nil {
|
||||||
return azure.New(azure.WithServiceAccount(serviceAccount))
|
return azure.New(azure.WithServiceAccount(serviceAccount))
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Warn().
|
log.Warn().
|
||||||
Str("service", "directory").
|
Str("service", "directory").
|
||||||
Str("provider", options.Provider).
|
Str("provider", options.Provider).
|
||||||
|
|
|
@ -48,3 +48,13 @@ func GetUser(ctx context.Context, client databroker.DataBrokerServiceClient, use
|
||||||
}
|
}
|
||||||
return &u, nil
|
return &u, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Options are directory provider options.
|
||||||
|
type Options struct {
|
||||||
|
ServiceAccount string
|
||||||
|
Provider string
|
||||||
|
ProviderURL string
|
||||||
|
ClientID string
|
||||||
|
ClientSecret string
|
||||||
|
QPS float64
|
||||||
|
}
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue