diff --git a/authenticate/handlers.go b/authenticate/handlers.go index f4c215366..cdad2ad5b 100644 --- a/authenticate/handlers.go +++ b/authenticate/handlers.go @@ -293,7 +293,12 @@ func (a *Authenticate) reauthenticateOrFail(w http.ResponseWriter, r *http.Reque enc := cryptutil.Encrypt(state.cookieCipher, []byte(redirectURL.String()), b) b = append(b, enc...) encodedState := base64.URLEncoding.EncodeToString(b) - httputil.Redirect(w, r, a.provider.Load().GetSignInURL(encodedState), http.StatusFound) + signinURL, err := a.provider.Load().GetSignInURL(encodedState) + if err != nil { + return httputil.NewError(http.StatusInternalServerError, + fmt.Errorf("failed to get sign in url: %w", err)) + } + httputil.Redirect(w, r, signinURL, http.StatusFound) return nil } diff --git a/internal/identity/mock_provider.go b/internal/identity/mock_provider.go index b208359a1..cdb298416 100644 --- a/internal/identity/mock_provider.go +++ b/internal/identity/mock_provider.go @@ -38,7 +38,7 @@ func (mp MockProvider) Revoke(ctx context.Context, s *oauth2.Token) error { } // GetSignInURL is a mocked providers function. -func (mp MockProvider) GetSignInURL(s string) string { return mp.GetSignInURLResponse } +func (mp MockProvider) GetSignInURL(s string) (string, error) { return mp.GetSignInURLResponse, nil } // LogOut is a mocked providers function. func (mp MockProvider) LogOut() (*url.URL, error) { return &mp.LogOutResponse, mp.LogOutError } diff --git a/internal/identity/oauth/github/github.go b/internal/identity/oauth/github/github.go index 8d6389553..5077ad4e3 100644 --- a/internal/identity/oauth/github/github.go +++ b/internal/identity/oauth/github/github.go @@ -231,8 +231,8 @@ func (p *Provider) Revoke(ctx context.Context, token *oauth2.Token) error { // GetSignInURL returns a URL to OAuth 2.0 provider's consent page // that asks for permissions for the required scopes explicitly. -func (p *Provider) GetSignInURL(state string) string { - return p.Oauth.AuthCodeURL(state, oauth2.AccessTypeOffline) +func (p *Provider) GetSignInURL(state string) (string, error) { + return p.Oauth.AuthCodeURL(state, oauth2.AccessTypeOffline), nil } // LogOut is not implemented by github. diff --git a/internal/identity/oidc/azure/microsoft.go b/internal/identity/oidc/azure/microsoft.go index 61d8936df..d84a45f38 100644 --- a/internal/identity/oidc/azure/microsoft.go +++ b/internal/identity/oidc/azure/microsoft.go @@ -46,19 +46,20 @@ func New(ctx context.Context, o *oauth.Options) (*Provider, error) { if o.ProviderURL == "" { o.ProviderURL = defaultProviderURL } - genericOidc, err := newProvider(ctx, o) + genericOidc, err := newProvider(ctx, o, + pom_oidc.WithGetVerifier(func(provider *go_oidc.Provider) *go_oidc.IDTokenVerifier { + return provider.Verifier(&go_oidc.Config{ + ClientID: o.ClientID, + // If using the common endpoint, the verification provider URI will not match. + // https://github.com/pomerium/pomerium/issues/1605 + SkipIssuerCheck: o.ProviderURL == defaultProviderURL, + }) + })) if err != nil { return nil, fmt.Errorf("%s: failed creating oidc provider: %w", Name, err) } p.Provider = genericOidc - genericOidc.Verifier = genericOidc.Provider.Verifier(&go_oidc.Config{ - ClientID: o.ClientID, - // If using the common endpoint, the verification provider URI will not match. - // https://github.com/pomerium/pomerium/issues/1605 - SkipIssuerCheck: o.ProviderURL == defaultProviderURL, - }) - p.AuthCodeOptions = defaultAuthCodeOptions if len(o.AuthCodeOptions) != 0 { p.AuthCodeOptions = o.AuthCodeOptions @@ -79,7 +80,7 @@ func (p *Provider) Name() string { // If {tenantid} is in the issuer string, we force the issuer to match the defaultURL. // // https://github.com/pomerium/pomerium/issues/1605 -func newProvider(ctx context.Context, o *oauth.Options) (*pom_oidc.Provider, error) { +func newProvider(ctx context.Context, o *oauth.Options, options ...pom_oidc.Option) (*pom_oidc.Provider, error) { originalClient := http.DefaultClient if c, ok := ctx.Value(oauth2.HTTPClient).(*http.Client); ok { originalClient = c @@ -90,7 +91,7 @@ func newProvider(ctx context.Context, o *oauth.Options) (*pom_oidc.Provider, err client.Transport = &wellKnownConfiguration{underlying: client.Transport} ctx = context.WithValue(ctx, oauth2.HTTPClient, client) - return pom_oidc.New(ctx, o) + return pom_oidc.New(ctx, o, options...) } type wellKnownConfiguration struct { diff --git a/internal/identity/oidc/config.go b/internal/identity/oidc/config.go new file mode 100644 index 000000000..0d317c359 --- /dev/null +++ b/internal/identity/oidc/config.go @@ -0,0 +1,44 @@ +package oidc + +import ( + "github.com/coreos/go-oidc/v3/oidc" + "golang.org/x/oauth2" +) + +type config struct { + getProvider func() (*oidc.Provider, error) + getVerifier func(provider *oidc.Provider) *oidc.IDTokenVerifier + getOauthConfig func(provider *oidc.Provider) *oauth2.Config +} + +// An Option customizes the config. +type Option func(*config) + +func getConfig(options ...Option) *config { + cfg := &config{} + for _, option := range options { + option(cfg) + } + return cfg +} + +// WithGetOauthConfig sets the getOauthConfig function in the config. +func WithGetOauthConfig(f func(provider *oidc.Provider) *oauth2.Config) Option { + return func(cfg *config) { + cfg.getOauthConfig = f + } +} + +// WithGetProvider sets the getProvider function in the config. +func WithGetProvider(f func() (*oidc.Provider, error)) Option { + return func(cfg *config) { + cfg.getProvider = f + } +} + +// WithGetVerifier sets the getVerifier function in the config. +func WithGetVerifier(f func(*oidc.Provider) *oidc.IDTokenVerifier) Option { + return func(cfg *config) { + cfg.getVerifier = f + } +} diff --git a/internal/identity/oidc/oidc.go b/internal/identity/oidc/oidc.go index d81ebe39a..9ffd7cb31 100644 --- a/internal/identity/oidc/oidc.go +++ b/internal/identity/oidc/oidc.go @@ -10,6 +10,7 @@ import ( "fmt" "net/http" "net/url" + "sync" go_oidc "github.com/coreos/go-oidc/v3/oidc" "golang.org/x/oauth2" @@ -32,13 +33,7 @@ var defaultAuthCodeOptions = []oauth2.AuthCodeOption{oauth2.AccessTypeOffline} // of an authorization identity provider. // https://openid.net/specs/openid-connect-core-1_0.html type Provider struct { - // Provider represents an OpenID Connect server's configuration. - Provider *go_oidc.Provider - // Verifier provides verification for ID Tokens. - Verifier *go_oidc.IDTokenVerifier - // Oauth describes a typical 3-legged OAuth2 flow, with both the - // client application information and the server's endpoint URLs. - Oauth *oauth2.Config + cfg *config // RevocationURL is the location of the OAuth 2.0 token revocation endpoint. // https://tools.ietf.org/html/rfc7009 @@ -52,41 +47,53 @@ type Provider struct { // AuthCodeOptions specifies additional key value pairs query params to add // to the request flow signin url. AuthCodeOptions map[string]string + + mu sync.Mutex + provider *go_oidc.Provider } // New creates a new instance of a generic OpenID Connect provider. -func New(ctx context.Context, o *oauth.Options) (*Provider, error) { - var err error - var p Provider +func New(ctx context.Context, o *oauth.Options, options ...Option) (*Provider, error) { if o.ProviderURL == "" { return nil, ErrMissingProviderURL } + + p := new(Provider) if len(o.Scopes) == 0 { o.Scopes = defaultScopes } - p.Provider, err = go_oidc.NewProvider(ctx, o.ProviderURL) - if err != nil { - return nil, fmt.Errorf("identity/oidc: could not connect to %s: %w", o.ProviderName, err) - } - - p.Verifier = p.Provider.Verifier(&go_oidc.Config{ClientID: o.ClientID}) - p.Oauth = &oauth2.Config{ - ClientID: o.ClientID, - ClientSecret: o.ClientSecret, - Scopes: o.Scopes, - Endpoint: p.Provider.Endpoint(), - RedirectURL: o.RedirectURL.String(), - } - if len(o.AuthCodeOptions) != 0 { p.AuthCodeOptions = o.AuthCodeOptions } - // add non-standard claims like end-session, revoke, and user info - if err := p.Provider.Claims(&p); err != nil { - return nil, fmt.Errorf("identity/oidc: could not retrieve additional claims: %w", err) - } - return &p, nil + p.cfg = getConfig(append([]Option{ + WithGetOauthConfig(func(provider *go_oidc.Provider) *oauth2.Config { + return &oauth2.Config{ + ClientID: o.ClientID, + ClientSecret: o.ClientSecret, + Scopes: o.Scopes, + Endpoint: provider.Endpoint(), + RedirectURL: o.RedirectURL.String(), + } + }), + WithGetProvider(func() (*go_oidc.Provider, error) { + pp, err := go_oidc.NewProvider(ctx, o.ProviderURL) + if err != nil { + return nil, fmt.Errorf("identity/oidc: could not connect to %s: %w", o.ProviderName, err) + } + + // add non-standard claims like end-session, revoke, and user info + if err := pp.Claims(&p); err != nil { + return nil, fmt.Errorf("identity/oidc: could not retrieve additional claims: %w", err) + } + + return pp, nil + }), + WithGetVerifier(func(provider *go_oidc.Provider) *go_oidc.IDTokenVerifier { + return provider.Verifier(&go_oidc.Config{ClientID: o.ClientID}) + }), + }, options...)...) + return p, nil } // GetSignInURL returns the url of the provider's OAuth 2.0 consent page @@ -96,19 +103,29 @@ func New(ctx context.Context, o *oauth.Options) (*Provider, error) { // always provide a non-empty string and validate that it matches the // the state query parameter on your redirect callback. // See http://tools.ietf.org/html/rfc6749#section-10.12 for more info. -func (p *Provider) GetSignInURL(state string) string { +func (p *Provider) GetSignInURL(state string) (string, error) { + oa, err := p.GetOauthConfig() + if err != nil { + return "", err + } + opts := defaultAuthCodeOptions for k, v := range p.AuthCodeOptions { opts = append(opts, oauth2.SetAuthURLParam(k, v)) } - return p.Oauth.AuthCodeURL(state, opts...) + return oa.AuthCodeURL(state, opts...), nil } // Authenticate converts an authorization code returned from the identity // provider into a token which is then converted into a user session. func (p *Provider) Authenticate(ctx context.Context, code string, v identity.State) (*oauth2.Token, error) { + oa, err := p.GetOauthConfig() + if err != nil { + return nil, err + } + // Exchange converts an authorization code into a token. - oauth2Token, err := p.Oauth.Exchange(ctx, code) + oauth2Token, err := oa.Exchange(ctx, code) if err != nil { return nil, fmt.Errorf("identity/oidc: token exchange failed: %w", err) } @@ -140,7 +157,12 @@ func (p *Provider) Authenticate(ctx context.Context, code string, v identity.Sta // // https://openid.net/specs/openid-connect-core-1_0.html#UserInfo func (p *Provider) UpdateUserInfo(ctx context.Context, t *oauth2.Token, v interface{}) error { - userInfo, err := getUserInfo(ctx, p.Provider, oauth2.StaticTokenSource(t)) + pp, err := p.GetProvider() + if err != nil { + return err + } + + userInfo, err := getUserInfo(ctx, pp, oauth2.StaticTokenSource(t)) if err != nil { return fmt.Errorf("identity/oidc: user info endpoint: %w", err) } @@ -160,8 +182,13 @@ func (p *Provider) Refresh(ctx context.Context, t *oauth2.Token, v identity.Stat if t.RefreshToken == "" { return nil, ErrMissingRefreshToken } - var err error - newToken, err := p.Oauth.TokenSource(ctx, t).Token() + + oa, err := p.GetOauthConfig() + if err != nil { + return nil, err + } + + newToken, err := oa.TokenSource(ctx, t).Token() if err != nil { return nil, fmt.Errorf("identity/oidc: refresh failed: %w", err) } @@ -186,11 +213,16 @@ func (p *Provider) Refresh(ctx context.Context, t *oauth2.Token, v identity.Stat // // https://openid.net/specs/openid-connect-core-1_0.html#TokenResponse func (p *Provider) getIDToken(ctx context.Context, t *oauth2.Token) (*go_oidc.IDToken, error) { + v, err := p.GetVerifier() + if err != nil { + return nil, err + } + rawIDToken, ok := t.Extra("id_token").(string) if !ok { return nil, ErrMissingIDToken } - return p.Verifier.Verify(ctx, rawIDToken) + return v.Verify(ctx, rawIDToken) } // Revoke enables a user to revoke her token. If the identity provider does not @@ -205,16 +237,21 @@ func (p *Provider) Revoke(ctx context.Context, t *oauth2.Token) error { return ErrMissingAccessToken } + oa, err := p.GetOauthConfig() + if err != nil { + return err + } + params := url.Values{} params.Add("token", t.AccessToken) params.Add("token_type_hint", "access_token") // Some providers like okta / onelogin require "client authentication" // https://developer.okta.com/docs/reference/api/oidc/#client-secret // https://developers.onelogin.com/openid-connect/api/revoke-session - params.Add("client_id", p.Oauth.ClientID) - params.Add("client_secret", p.Oauth.ClientSecret) + params.Add("client_id", oa.ClientID) + params.Add("client_secret", oa.ClientSecret) - err := httputil.Client(ctx, http.MethodPost, p.RevocationURL, version.UserAgent(), nil, params, nil) + err = httputil.Client(ctx, http.MethodPost, p.RevocationURL, version.UserAgent(), nil, params, nil) if err != nil && errors.Is(err, httputil.ErrTokenRevoked) { return fmt.Errorf("internal/oidc: unexpected revoke error: %w", err) } @@ -253,3 +290,38 @@ func (p *Provider) GetSubject(v interface{}) (string, error) { func (p *Provider) Name() string { return Name } + +// GetProvider gets the underlying oidc Provider. +func (p *Provider) GetProvider() (*go_oidc.Provider, error) { + p.mu.Lock() + defer p.mu.Unlock() + + if p.provider != nil { + return p.provider, nil + } + + pp, err := p.cfg.getProvider() + if err != nil { + return nil, err + } + p.provider = pp + return pp, nil +} + +// GetVerifier gets the verifier. +func (p *Provider) GetVerifier() (*go_oidc.IDTokenVerifier, error) { + pp, err := p.GetProvider() + if err != nil { + return nil, err + } + return p.cfg.getVerifier(pp), nil +} + +// GetOauthConfig gets the oauth. +func (p *Provider) GetOauthConfig() (*oauth2.Config, error) { + pp, err := p.GetProvider() + if err != nil { + return nil, err + } + return p.cfg.getOauthConfig(pp), nil +} diff --git a/internal/identity/providers.go b/internal/identity/providers.go index 3382cc2c1..9c09b6b97 100644 --- a/internal/identity/providers.go +++ b/internal/identity/providers.go @@ -27,7 +27,7 @@ type Authenticator interface { Authenticate(context.Context, string, identity.State) (*oauth2.Token, error) Refresh(context.Context, *oauth2.Token, identity.State) (*oauth2.Token, error) Revoke(context.Context, *oauth2.Token) error - GetSignInURL(state string) string + GetSignInURL(state string) (string, error) Name() string LogOut() (*url.URL, error) UpdateUserInfo(ctx context.Context, t *oauth2.Token, v interface{}) error