From c4c8ef8e53d7a35f07b2665570a0f353ee710e2d Mon Sep 17 00:00:00 2001 From: Caleb Doxsey Date: Tue, 18 Aug 2020 10:17:28 -0600 Subject: [PATCH] azure: support deriving credentials from client id, client secret and provider url (#1300) --- cache/cache.go | 2 ++ config/options.go | 5 +-- internal/directory/azure/azure.go | 48 +++++++++++++++++++++++++- internal/directory/azure/azure_test.go | 38 ++++++++++++++++++++ internal/directory/provider.go | 14 +++----- pkg/grpc/directory/directory.go | 10 ++++++ 6 files changed, 104 insertions(+), 13 deletions(-) diff --git a/cache/cache.go b/cache/cache.go index 9140fe1f4..d8bf8e336 100644 --- a/cache/cache.go +++ b/cache/cache.go @@ -129,6 +129,8 @@ func (c *Cache) update(cfg *config.Config) error { Provider: cfg.Options.Provider, ProviderURL: cfg.Options.ProviderURL, QPS: cfg.Options.QPS, + ClientID: cfg.Options.ClientID, + ClientSecret: cfg.Options.ClientSecret, }) dataBrokerClient := databroker.NewDataBrokerServiceClient(c.localGRPCConnection) diff --git a/config/options.go b/config/options.go index 3ecea38af..aaeb4fa6d 100644 --- a/config/options.go +++ b/config/options.go @@ -632,8 +632,9 @@ func (o *Options) Validate() error { } // if no service account was defined, there should not be any policies that - // assert group membership - if o.ServiceAccount == "" { + // assert group membership (except for azure which can be derived from the client + // id, secret and provider url) + if o.ServiceAccount == "" && o.Provider != "azure" { for _, p := range o.Policies { if len(p.AllowedGroups) != 0 { return fmt.Errorf("config: `allowed_groups` requires `idp_service_account`") diff --git a/internal/directory/azure/azure.go b/internal/directory/azure/azure.go index 7e6f86125..2db6364d7 100644 --- a/internal/directory/azure/azure.go +++ b/internal/directory/azure/azure.go @@ -282,7 +282,39 @@ type ServiceAccount struct { } // ParseServiceAccount parses the service account in the config options. -func ParseServiceAccount(rawServiceAccount string) (*ServiceAccount, error) { +func ParseServiceAccount(options directory.Options) (*ServiceAccount, error) { + if options.ServiceAccount != "" { + return parseServiceAccountFromString(options.ServiceAccount) + } + return parseServiceAccountFromOptions(options.ClientID, options.ClientSecret, options.ProviderURL) +} + +func parseServiceAccountFromOptions(clientID, clientSecret, providerURL string) (*ServiceAccount, error) { + serviceAccount := ServiceAccount{ + ClientID: clientID, + ClientSecret: clientSecret, + } + + var err error + serviceAccount.DirectoryID, err = parseDirectoryIDFromURL(providerURL) + if err != nil { + return nil, err + } + + if serviceAccount.ClientID == "" { + return nil, fmt.Errorf("client_id is required") + } + if serviceAccount.ClientSecret == "" { + return nil, fmt.Errorf("client_secret is required") + } + if serviceAccount.DirectoryID == "" { + return nil, fmt.Errorf("directory_id is required") + } + + return &serviceAccount, nil +} + +func parseServiceAccountFromString(rawServiceAccount string) (*ServiceAccount, error) { bs, err := base64.StdEncoding.DecodeString(rawServiceAccount) if err != nil { return nil, err @@ -306,3 +338,17 @@ func ParseServiceAccount(rawServiceAccount string) (*ServiceAccount, error) { return &serviceAccount, nil } + +func parseDirectoryIDFromURL(providerURL string) (string, error) { + u, err := url.Parse(providerURL) + if err != nil { + return "", err + } + + pathParts := strings.SplitN(u.Path, "/", 3) + if len(pathParts) != 3 { + return "", fmt.Errorf("no directory id found in path") + } + + return pathParts[1], nil +} diff --git a/internal/directory/azure/azure_test.go b/internal/directory/azure/azure_test.go index 2dd9cfcfc..424514dbc 100644 --- a/internal/directory/azure/azure_test.go +++ b/internal/directory/azure/azure_test.go @@ -2,6 +2,7 @@ package azure import ( "context" + "encoding/base64" "encoding/json" "net/http" "net/http/httptest" @@ -107,6 +108,43 @@ func Test(t *testing.T) { }, groups) } +func TestParseServiceAccount(t *testing.T) { + t.Run("by options", func(t *testing.T) { + serviceAccount, err := ParseServiceAccount(directory.Options{ + ProviderURL: "https://login.microsoftonline.com/0303f438-3c5c-4190-9854-08d3eb31bd9f/v2.0", + ClientID: "CLIENT_ID", + ClientSecret: "CLIENT_SECRET", + }) + if !assert.NoError(t, err) { + return + } + + assert.Equal(t, &ServiceAccount{ + ClientID: "CLIENT_ID", + ClientSecret: "CLIENT_SECRET", + DirectoryID: "0303f438-3c5c-4190-9854-08d3eb31bd9f", + }, serviceAccount) + }) + t.Run("by service account", func(t *testing.T) { + serviceAccount, err := ParseServiceAccount(directory.Options{ + ServiceAccount: base64.StdEncoding.EncodeToString([]byte(`{ + "client_id": "CLIENT_ID", + "client_secret": "CLIENT_SECRET", + "directory_id": "0303f438-3c5c-4190-9854-08d3eb31bd9f" + }`)), + }) + if !assert.NoError(t, err) { + return + } + + assert.Equal(t, &ServiceAccount{ + ClientID: "CLIENT_ID", + ClientSecret: "CLIENT_SECRET", + DirectoryID: "0303f438-3c5c-4190-9854-08d3eb31bd9f", + }, serviceAccount) + }) +} + func mustParseURL(rawurl string) *url.URL { u, err := url.Parse(rawurl) if err != nil { diff --git a/internal/directory/provider.go b/internal/directory/provider.go index 258e4134b..8cd3b73e3 100644 --- a/internal/directory/provider.go +++ b/internal/directory/provider.go @@ -24,19 +24,14 @@ type Group = directory.Group // A User is a directory User. type User = directory.User +// Options are the options specific to the provider. +type Options = directory.Options + // A Provider provides user group directory information. type Provider interface { UserGroups(ctx context.Context) ([]*Group, []*User, error) } -// Options are the options specific to the provider. -type Options struct { - ServiceAccount string - Provider string - ProviderURL string - QPS float64 -} - var globalProvider = struct { sync.Mutex provider Provider @@ -59,11 +54,10 @@ func GetProvider(options Options) (provider Provider) { switch options.Provider { case azure.Name: - serviceAccount, err := azure.ParseServiceAccount(options.ServiceAccount) + serviceAccount, err := azure.ParseServiceAccount(options) if err == nil { return azure.New(azure.WithServiceAccount(serviceAccount)) } - log.Warn(). Str("service", "directory"). Str("provider", options.Provider). diff --git a/pkg/grpc/directory/directory.go b/pkg/grpc/directory/directory.go index 1bf0a1db1..ba25c33e7 100644 --- a/pkg/grpc/directory/directory.go +++ b/pkg/grpc/directory/directory.go @@ -48,3 +48,13 @@ func GetUser(ctx context.Context, client databroker.DataBrokerServiceClient, use } return &u, nil } + +// Options are directory provider options. +type Options struct { + ServiceAccount string + Provider string + ProviderURL string + ClientID string + ClientSecret string + QPS float64 +}