diff --git a/authenticate/authenticate.go b/authenticate/authenticate.go index 7f57fdee6..dc6173645 100644 --- a/authenticate/authenticate.go +++ b/authenticate/authenticate.go @@ -95,10 +95,6 @@ type Authenticate struct { // a user's session state from sessionLoaders []sessions.SessionLoader - // provider is the interface to interacting with the identity provider (IdP) - provider identity.Authenticator - providerName string - // dataBrokerClient is used to retrieve sessions dataBrokerClient databroker.DataBrokerServiceClient @@ -111,45 +107,46 @@ type Authenticate struct { templates *template.Template - options *config.AtomicOptions + options *config.AtomicOptions + provider *identity.AtomicAuthenticator } // New validates and creates a new authenticate service from a set of Options. -func New(opts *config.Options) (*Authenticate, error) { - if err := ValidateOptions(opts); err != nil { +func New(cfg *config.Config) (*Authenticate, error) { + if err := ValidateOptions(cfg.Options); err != nil { return nil, err } // shared state encoder setup - sharedCipher, _ := cryptutil.NewAEADCipherFromBase64(opts.SharedKey) - sharedEncoder, err := jws.NewHS256Signer([]byte(opts.SharedKey), opts.GetAuthenticateURL().Host) + sharedCipher, _ := cryptutil.NewAEADCipherFromBase64(cfg.Options.SharedKey) + sharedEncoder, err := jws.NewHS256Signer([]byte(cfg.Options.SharedKey), cfg.Options.GetAuthenticateURL().Host) if err != nil { return nil, err } // private state encoder setup, used to encrypt oauth2 tokens - decodedCookieSecret, _ := base64.StdEncoding.DecodeString(opts.CookieSecret) + decodedCookieSecret, _ := base64.StdEncoding.DecodeString(cfg.Options.CookieSecret) cookieCipher, _ := cryptutil.NewAEADCipher(decodedCookieSecret) encryptedEncoder := ecjson.New(cookieCipher) cookieOptions := &cookie.Options{ - Name: opts.CookieName, - Domain: opts.CookieDomain, - Secure: opts.CookieSecure, - HTTPOnly: opts.CookieHTTPOnly, - Expire: opts.CookieExpire, + Name: cfg.Options.CookieName, + Domain: cfg.Options.CookieDomain, + Secure: cfg.Options.CookieSecure, + HTTPOnly: cfg.Options.CookieHTTPOnly, + Expire: cfg.Options.CookieExpire, } dataBrokerConn, err := grpc.NewGRPCClientConn( &grpc.Options{ - Addr: opts.DataBrokerURL, - OverrideCertificateName: opts.OverrideCertificateName, - CA: opts.CA, - CAFile: opts.CAFile, - RequestTimeout: opts.GRPCClientTimeout, - ClientDNSRoundRobin: opts.GRPCClientDNSRoundRobin, - WithInsecure: opts.GRPCInsecure, - ServiceName: opts.Services, + Addr: cfg.Options.DataBrokerURL, + OverrideCertificateName: cfg.Options.OverrideCertificateName, + CA: cfg.Options.CA, + CAFile: cfg.Options.CAFile, + RequestTimeout: cfg.Options.GRPCClientTimeout, + ClientDNSRoundRobin: cfg.Options.GRPCClientDNSRoundRobin, + WithInsecure: cfg.Options.GRPCInsecure, + ServiceName: cfg.Options.Services, }) if err != nil { return nil, err @@ -160,20 +157,8 @@ func New(opts *config.Options) (*Authenticate, error) { qpStore := queryparam.NewStore(encryptedEncoder, urlutil.QueryProgrammaticToken) headerStore := header.NewStore(encryptedEncoder, httputil.AuthorizationTypePomerium) - redirectURL, _ := urlutil.DeepCopy(opts.AuthenticateURL) - redirectURL.Path = opts.AuthenticateCallbackPath - // configure our identity provider - provider, err := identity.NewAuthenticator( - oauth.Options{ - RedirectURL: redirectURL, - ProviderName: opts.Provider, - ProviderURL: opts.ProviderURL, - ClientID: opts.ClientID, - ClientSecret: opts.ClientSecret, - Scopes: opts.Scopes, - ServiceAccount: opts.ServiceAccount, - AuthCodeOptions: opts.RequestParams, - }) + redirectURL, _ := urlutil.DeepCopy(cfg.Options.AuthenticateURL) + redirectURL.Path = cfg.Options.AuthenticateCallbackPath if err != nil { return nil, err @@ -182,7 +167,7 @@ func New(opts *config.Options) (*Authenticate, error) { a := &Authenticate{ RedirectURL: redirectURL, // shared state - sharedKey: opts.SharedKey, + sharedKey: cfg.Options.SharedKey, sharedCipher: sharedCipher, sharedEncoder: sharedEncoder, // private state @@ -190,14 +175,17 @@ func New(opts *config.Options) (*Authenticate, error) { cookieCipher: cookieCipher, cookieOptions: cookieOptions, encryptedEncoder: encryptedEncoder, - // IdP - provider: provider, - providerName: opts.Provider, // grpc client for cache dataBrokerClient: dataBrokerClient, jwk: &jose.JSONWebKeySet{}, templates: template.Must(frontend.NewTemplates()), options: config.NewAtomicOptions(), + provider: identity.NewAtomicAuthenticator(), + } + + err = a.updateProvider(cfg) + if err != nil { + return nil, err } cookieStore, err := cookie.NewStore(func() cookie.Options { @@ -217,8 +205,8 @@ func New(opts *config.Options) (*Authenticate, error) { a.sessionStore = cookieStore a.sessionLoaders = []sessions.SessionLoader{qpStore, headerStore, cookieStore} - if opts.SigningKey != "" { - decodedCert, err := base64.StdEncoding.DecodeString(opts.SigningKey) + if cfg.Options.SigningKey != "" { + decodedCert, err := base64.StdEncoding.DecodeString(cfg.Options.SigningKey) if err != nil { return nil, fmt.Errorf("authenticate: failed to decode signing key: %w", err) } @@ -251,4 +239,31 @@ func (a *Authenticate) OnConfigChange(cfg *config.Config) { log.Info().Str("checksum", fmt.Sprintf("%x", cfg.Options.Checksum())).Msg("authenticate: updating options") a.options.Store(cfg.Options) a.setAdminUsers(cfg.Options) + if err := a.updateProvider(cfg); err != nil { + log.Error().Err(err).Msg("authenticate: failed to update identity provider") + } +} + +func (a *Authenticate) updateProvider(cfg *config.Config) error { + redirectURL, _ := urlutil.DeepCopy(cfg.Options.AuthenticateURL) + redirectURL.Path = cfg.Options.AuthenticateCallbackPath + + // configure our identity provider + provider, err := identity.NewAuthenticator( + oauth.Options{ + RedirectURL: redirectURL, + ProviderName: cfg.Options.Provider, + ProviderURL: cfg.Options.ProviderURL, + ClientID: cfg.Options.ClientID, + ClientSecret: cfg.Options.ClientSecret, + Scopes: cfg.Options.Scopes, + ServiceAccount: cfg.Options.ServiceAccount, + AuthCodeOptions: cfg.Options.RequestParams, + }) + if err != nil { + return err + } + a.provider.Store(provider) + + return nil } diff --git a/authenticate/authenticate_test.go b/authenticate/authenticate_test.go index 1e77a4d34..419bacc44 100644 --- a/authenticate/authenticate_test.go +++ b/authenticate/authenticate_test.go @@ -124,7 +124,7 @@ func TestNew(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - _, err := New(tt.opts) + _, err := New(&config.Config{Options: tt.opts}) if (err != nil) != tt.wantErr { t.Errorf("New() error = %v, wantErr %v", err, tt.wantErr) return @@ -151,7 +151,7 @@ func TestIsAdmin(t *testing.T) { t.Parallel() opts := newTestOptions(t) opts.Administrators = tc.admins - a, err := New(opts) + a, err := New(&config.Config{Options: opts}) a.OnConfigChange(&config.Config{Options: opts}) require.NoError(t, err) assert.True(t, a.isAdmin(tc.user) == tc.isAdmin) diff --git a/authenticate/handlers.go b/authenticate/handlers.go index 3cb84f004..955c2b72c 100644 --- a/authenticate/handlers.go +++ b/authenticate/handlers.go @@ -256,7 +256,7 @@ func (a *Authenticate) SignOut(w http.ResponseWriter, r *http.Request) error { sessionState, err := a.getSessionFromCtx(ctx) if err == nil { if s, _ := session.Get(ctx, a.dataBrokerClient, sessionState.ID); s != nil && s.OauthToken != nil { - if err := a.provider.Revoke(ctx, manager.FromOAuthToken(s.OauthToken)); err != nil { + if err := a.provider.Load().Revoke(ctx, manager.FromOAuthToken(s.OauthToken)); err != nil { log.Warn().Err(err).Msg("failed to revoke access token") } } @@ -269,7 +269,7 @@ func (a *Authenticate) SignOut(w http.ResponseWriter, r *http.Request) error { // no matter what happens, we want to clear the session store a.sessionStore.ClearSession(w, r) redirectString := r.FormValue(urlutil.QueryRedirectURI) - endSessionURL, err := a.provider.LogOut() + endSessionURL, err := a.provider.Load().LogOut() if err == nil { params := url.Values{} params.Add("post_logout_redirect_uri", redirectString) @@ -331,7 +331,7 @@ func (a *Authenticate) reauthenticateOrFail(w http.ResponseWriter, r *http.Reque enc := cryptutil.Encrypt(a.cookieCipher, []byte(redirectURL.String()), b) b = append(b, enc...) encodedState := base64.URLEncoding.EncodeToString(b) - httputil.Redirect(w, r, a.provider.GetSignInURL(encodedState), http.StatusFound) + httputil.Redirect(w, r, a.provider.Load().GetSignInURL(encodedState), http.StatusFound) return nil } @@ -377,7 +377,7 @@ func (a *Authenticate) getOAuthCallback(w http.ResponseWriter, r *http.Request) // // Exchange the supplied Authorization Code for a valid user session. s := sessions.State{ID: uuid.New().String()} - accessToken, err := a.provider.Authenticate(ctx, code, &s) + accessToken, err := a.provider.Load().Authenticate(ctx, code, &s) if err != nil { return nil, fmt.Errorf("error redeeming authenticate code: %w", err) } @@ -538,7 +538,7 @@ func (a *Authenticate) saveSessionToDataBroker(ctx context.Context, sessionState s := &session.Session{ Id: sessionState.ID, - UserId: sessionState.UserID(a.providerName), + UserId: sessionState.UserID(a.provider.Load().Name()), ExpiresAt: sessionExpiry, IdToken: &session.IDToken{ Issuer: sessionState.Issuer, @@ -557,7 +557,7 @@ func (a *Authenticate) saveSessionToDataBroker(ctx context.Context, sessionState Id: s.GetUserId(), }, } - err := a.provider.UpdateUserInfo(ctx, accessToken, &mu) + err := a.provider.Load().UpdateUserInfo(ctx, accessToken, &mu) if err != nil { return fmt.Errorf("authenticate: error retrieving user info: %w", err) } diff --git a/authenticate/handlers_test.go b/authenticate/handlers_test.go index 46bb0e973..54304fe11 100644 --- a/authenticate/handlers_test.go +++ b/authenticate/handlers_test.go @@ -16,6 +16,7 @@ import ( "google.golang.org/grpc" "google.golang.org/protobuf/types/known/emptypb" + "github.com/pomerium/pomerium/config" "github.com/pomerium/pomerium/internal/encoding" "github.com/pomerium/pomerium/internal/encoding/jws" "github.com/pomerium/pomerium/internal/encoding/mock" @@ -147,7 +148,6 @@ func TestAuthenticate_SignIn(t *testing.T) { a := &Authenticate{ sessionStore: tt.session, - provider: tt.provider, RedirectURL: uriParseHelper("https://some.example"), sharedKey: "secret", sharedEncoder: tt.encoder, @@ -176,7 +176,9 @@ func TestAuthenticate_SignIn(t *testing.T) { }, nil }, }, + provider: identity.NewAtomicAuthenticator(), } + a.provider.Store(tt.provider) uri := &url.URL{Scheme: tt.scheme, Host: tt.host} queryString := uri.Query() @@ -234,7 +236,6 @@ func TestAuthenticate_SignOut(t *testing.T) { defer ctrl.Finish() a := &Authenticate{ sessionStore: tt.sessionStore, - provider: tt.provider, encryptedEncoder: mock.Encoder{}, templates: template.Must(frontend.NewTemplates()), sharedEncoder: mock.Encoder{}, @@ -260,7 +261,9 @@ func TestAuthenticate_SignOut(t *testing.T) { }, nil }, }, + provider: identity.NewAtomicAuthenticator(), } + a.provider.Store(tt.provider) u, _ := url.Parse("/sign_out") params, _ := url.ParseQuery(u.RawQuery) params.Add("sig", tt.sig) @@ -344,10 +347,11 @@ func TestAuthenticate_OAuthCallback(t *testing.T) { a := &Authenticate{ RedirectURL: authURL, sessionStore: tt.session, - provider: tt.provider, cookieCipher: aead, encryptedEncoder: signer, + provider: identity.NewAtomicAuthenticator(), } + a.provider.Store(tt.provider) u, _ := url.Parse("/oauthGet") params, _ := url.ParseQuery(u.RawQuery) params.Add("error", tt.paramErr) @@ -466,7 +470,6 @@ func TestAuthenticate_SessionValidatorMiddleware(t *testing.T) { cookieSecret: cryptutil.NewKey(), RedirectURL: uriParseHelper("https://authenticate.corp.beyondperimeter.com"), sessionStore: tt.session, - provider: tt.provider, cookieCipher: aead, encryptedEncoder: signer, sharedEncoder: signer, @@ -489,7 +492,9 @@ func TestAuthenticate_SessionValidatorMiddleware(t *testing.T) { }, nil }, }, + provider: identity.NewAtomicAuthenticator(), } + a.provider.Store(tt.provider) r := httptest.NewRequest("GET", "/", nil) state, err := tt.session.LoadSession(r) if err != nil { @@ -535,7 +540,7 @@ func TestWellKnownEndpoint(t *testing.T) { func TestJwksEndpoint(t *testing.T) { o := newTestOptions(t) o.SigningKey = "LS0tLS1CRUdJTiBFQyBQUklWQVRFIEtFWS0tLS0tCk1IY0NBUUVFSUpCMFZkbko1VjEvbVlpYUlIWHhnd2Q0Yzd5YWRTeXMxb3Y0bzA1b0F3ekdvQW9HQ0NxR1NNNDkKQXdFSG9VUURRZ0FFVUc1eENQMEpUVDFINklvbDhqS3VUSVBWTE0wNENnVzlQbEV5cE5SbVdsb29LRVhSOUhUMwpPYnp6aktZaWN6YjArMUt3VjJmTVRFMTh1dy82MXJVQ0JBPT0KLS0tLS1FTkQgRUMgUFJJVkFURSBLRVktLS0tLQo=" - auth, err := New(o) + auth, err := New(&config.Config{Options: o}) if err != nil { t.Fatal(err) } diff --git a/internal/cmd/pomerium/pomerium.go b/internal/cmd/pomerium/pomerium.go index 1a5f5d12d..d74a7dcbb 100644 --- a/internal/cmd/pomerium/pomerium.go +++ b/internal/cmd/pomerium/pomerium.go @@ -135,7 +135,7 @@ func setupAuthenticate(src config.Source, cfg *config.Config, controlPlane *cont return nil } - svc, err := authenticate.New(cfg.Options) + svc, err := authenticate.New(cfg) if err != nil { return fmt.Errorf("error creating authenticate service: %w", err) } diff --git a/internal/identity/mock_provider.go b/internal/identity/mock_provider.go index 73fc8cf42..c330204d4 100644 --- a/internal/identity/mock_provider.go +++ b/internal/identity/mock_provider.go @@ -45,3 +45,8 @@ func (mp MockProvider) LogOut() (*url.URL, error) { return &mp.LogOutResponse, m func (mp MockProvider) UpdateUserInfo(ctx context.Context, t *oauth2.Token, v interface{}) error { return mp.UpdateUserInfoError } + +// Name returns the provider name. +func (mp MockProvider) Name() string { + return "mock" +} diff --git a/internal/identity/oauth/github/github.go b/internal/identity/oauth/github/github.go index 2752b08d9..ee05cd6f6 100644 --- a/internal/identity/oauth/github/github.go +++ b/internal/identity/oauth/github/github.go @@ -238,3 +238,8 @@ func (p *Provider) GetSignInURL(state string) string { func (p *Provider) LogOut() (*url.URL, error) { return nil, oidc.ErrSignoutNotImplemented } + +// Name returns the provider name. +func (p *Provider) Name() string { + return Name +} diff --git a/internal/identity/oidc/azure/microsoft.go b/internal/identity/oidc/azure/microsoft.go index 9c9eb7889..3cbb7d540 100644 --- a/internal/identity/oidc/azure/microsoft.go +++ b/internal/identity/oidc/azure/microsoft.go @@ -47,3 +47,8 @@ func New(ctx context.Context, o *oauth.Options) (*Provider, error) { return &p, nil } + +// Name returns the provider name. +func (p *Provider) Name() string { + return Name +} diff --git a/internal/identity/oidc/gitlab/gitlab.go b/internal/identity/oidc/gitlab/gitlab.go index 560b3f226..e1e53cda0 100644 --- a/internal/identity/oidc/gitlab/gitlab.go +++ b/internal/identity/oidc/gitlab/gitlab.go @@ -45,3 +45,8 @@ func New(ctx context.Context, o *oauth.Options) (*Provider, error) { return &p, nil } + +// Name returns the provider name. +func (p *Provider) Name() string { + return Name +} diff --git a/internal/identity/oidc/google/google.go b/internal/identity/oidc/google/google.go index 20b71db48..4ade21f42 100644 --- a/internal/identity/oidc/google/google.go +++ b/internal/identity/oidc/google/google.go @@ -53,3 +53,8 @@ func New(ctx context.Context, o *oauth.Options) (*Provider, error) { } return &p, nil } + +// Name returns the provider name. +func (p *Provider) Name() string { + return Name +} diff --git a/internal/identity/oidc/oidc.go b/internal/identity/oidc/oidc.go index 3df0bc48b..2cea249f4 100644 --- a/internal/identity/oidc/oidc.go +++ b/internal/identity/oidc/oidc.go @@ -239,3 +239,8 @@ func (p *Provider) GetSubject(v interface{}) (string, error) { } return s.Subject, nil } + +// Name returns the provider name. +func (p *Provider) Name() string { + return Name +} diff --git a/internal/identity/oidc/okta/okta.go b/internal/identity/oidc/okta/okta.go index 7786812ae..d31f25a8c 100644 --- a/internal/identity/oidc/okta/okta.go +++ b/internal/identity/oidc/okta/okta.go @@ -33,3 +33,8 @@ func New(ctx context.Context, o *oauth.Options) (*Provider, error) { return &p, nil } + +// Name returns the provider name. +func (p *Provider) Name() string { + return Name +} diff --git a/internal/identity/oidc/onelogin/onelogin.go b/internal/identity/oidc/onelogin/onelogin.go index 632a2fd5d..f6b2461c2 100644 --- a/internal/identity/oidc/onelogin/onelogin.go +++ b/internal/identity/oidc/onelogin/onelogin.go @@ -44,3 +44,8 @@ func New(ctx context.Context, o *oauth.Options) (*Provider, error) { p.Provider = genericOidc return &p, nil } + +// Name returns the provider name. +func (p *Provider) Name() string { + return Name +} diff --git a/internal/identity/providers.go b/internal/identity/providers.go index 22d74cae1..862b5ba28 100644 --- a/internal/identity/providers.go +++ b/internal/identity/providers.go @@ -6,6 +6,7 @@ import ( "context" "fmt" "net/url" + "sync/atomic" "golang.org/x/oauth2" @@ -19,24 +20,13 @@ import ( "github.com/pomerium/pomerium/internal/identity/oidc/onelogin" ) -var ( - // compile time assertions that providers are satisfying the interface - _ Authenticator = &azure.Provider{} - _ Authenticator = &github.Provider{} - _ Authenticator = &gitlab.Provider{} - _ Authenticator = &google.Provider{} - _ Authenticator = &MockProvider{} - _ Authenticator = &oidc.Provider{} - _ Authenticator = &okta.Provider{} - _ Authenticator = &onelogin.Provider{} -) - // Authenticator is an interface representing the ability to authenticate with an identity provider. type Authenticator interface { Authenticate(context.Context, string, interface{}) (*oauth2.Token, error) Refresh(context.Context, *oauth2.Token, interface{}) (*oauth2.Token, error) Revoke(context.Context, *oauth2.Token) error GetSignInURL(state string) string + Name() string LogOut() (*url.URL, error) UpdateUserInfo(ctx context.Context, t *oauth2.Token, v interface{}) error } @@ -67,3 +57,30 @@ func NewAuthenticator(o oauth.Options) (a Authenticator, err error) { } return a, nil } + +// wrap the Authenticator for the AtomicAuthenticator to support a nil default value. +type authenticatorValue struct { + Authenticator +} + +// An AtomicAuthenticator is a strongly-typed atomic.Value for storing an authenticator. +type AtomicAuthenticator struct { + current atomic.Value +} + +// NewAtomicAuthenticator creates a new AtomicAuthenticator. +func NewAtomicAuthenticator() *AtomicAuthenticator { + a := &AtomicAuthenticator{} + a.current.Store(authenticatorValue{}) + return a +} + +// Load loads the current authenticator. +func (a *AtomicAuthenticator) Load() Authenticator { + return a.current.Load().(authenticatorValue) +} + +// Store stores the authenticator. +func (a *AtomicAuthenticator) Store(value Authenticator) { + a.current.Store(authenticatorValue{value}) +}