mirror of
https://github.com/pomerium/pomerium.git
synced 2025-05-10 15:47:36 +02:00
authenticate: delay evaluation of OIDC provider (#1802)
* authenticate: delay evaluation of OIDC provider * add additional error message * address comments
This commit is contained in:
parent
a14b65ec3f
commit
5e3aa91f23
7 changed files with 176 additions and 54 deletions
|
@ -293,7 +293,12 @@ func (a *Authenticate) reauthenticateOrFail(w http.ResponseWriter, r *http.Reque
|
|||
enc := cryptutil.Encrypt(state.cookieCipher, []byte(redirectURL.String()), b)
|
||||
b = append(b, enc...)
|
||||
encodedState := base64.URLEncoding.EncodeToString(b)
|
||||
httputil.Redirect(w, r, a.provider.Load().GetSignInURL(encodedState), http.StatusFound)
|
||||
signinURL, err := a.provider.Load().GetSignInURL(encodedState)
|
||||
if err != nil {
|
||||
return httputil.NewError(http.StatusInternalServerError,
|
||||
fmt.Errorf("failed to get sign in url: %w", err))
|
||||
}
|
||||
httputil.Redirect(w, r, signinURL, http.StatusFound)
|
||||
return nil
|
||||
}
|
||||
|
||||
|
|
|
@ -38,7 +38,7 @@ func (mp MockProvider) Revoke(ctx context.Context, s *oauth2.Token) error {
|
|||
}
|
||||
|
||||
// GetSignInURL is a mocked providers function.
|
||||
func (mp MockProvider) GetSignInURL(s string) string { return mp.GetSignInURLResponse }
|
||||
func (mp MockProvider) GetSignInURL(s string) (string, error) { return mp.GetSignInURLResponse, nil }
|
||||
|
||||
// LogOut is a mocked providers function.
|
||||
func (mp MockProvider) LogOut() (*url.URL, error) { return &mp.LogOutResponse, mp.LogOutError }
|
||||
|
|
|
@ -231,8 +231,8 @@ func (p *Provider) Revoke(ctx context.Context, token *oauth2.Token) error {
|
|||
|
||||
// GetSignInURL returns a URL to OAuth 2.0 provider's consent page
|
||||
// that asks for permissions for the required scopes explicitly.
|
||||
func (p *Provider) GetSignInURL(state string) string {
|
||||
return p.Oauth.AuthCodeURL(state, oauth2.AccessTypeOffline)
|
||||
func (p *Provider) GetSignInURL(state string) (string, error) {
|
||||
return p.Oauth.AuthCodeURL(state, oauth2.AccessTypeOffline), nil
|
||||
}
|
||||
|
||||
// LogOut is not implemented by github.
|
||||
|
|
|
@ -46,19 +46,20 @@ func New(ctx context.Context, o *oauth.Options) (*Provider, error) {
|
|||
if o.ProviderURL == "" {
|
||||
o.ProviderURL = defaultProviderURL
|
||||
}
|
||||
genericOidc, err := newProvider(ctx, o)
|
||||
genericOidc, err := newProvider(ctx, o,
|
||||
pom_oidc.WithGetVerifier(func(provider *go_oidc.Provider) *go_oidc.IDTokenVerifier {
|
||||
return provider.Verifier(&go_oidc.Config{
|
||||
ClientID: o.ClientID,
|
||||
// If using the common endpoint, the verification provider URI will not match.
|
||||
// https://github.com/pomerium/pomerium/issues/1605
|
||||
SkipIssuerCheck: o.ProviderURL == defaultProviderURL,
|
||||
})
|
||||
}))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("%s: failed creating oidc provider: %w", Name, err)
|
||||
}
|
||||
p.Provider = genericOidc
|
||||
|
||||
genericOidc.Verifier = genericOidc.Provider.Verifier(&go_oidc.Config{
|
||||
ClientID: o.ClientID,
|
||||
// If using the common endpoint, the verification provider URI will not match.
|
||||
// https://github.com/pomerium/pomerium/issues/1605
|
||||
SkipIssuerCheck: o.ProviderURL == defaultProviderURL,
|
||||
})
|
||||
|
||||
p.AuthCodeOptions = defaultAuthCodeOptions
|
||||
if len(o.AuthCodeOptions) != 0 {
|
||||
p.AuthCodeOptions = o.AuthCodeOptions
|
||||
|
@ -79,7 +80,7 @@ func (p *Provider) Name() string {
|
|||
// If {tenantid} is in the issuer string, we force the issuer to match the defaultURL.
|
||||
//
|
||||
// https://github.com/pomerium/pomerium/issues/1605
|
||||
func newProvider(ctx context.Context, o *oauth.Options) (*pom_oidc.Provider, error) {
|
||||
func newProvider(ctx context.Context, o *oauth.Options, options ...pom_oidc.Option) (*pom_oidc.Provider, error) {
|
||||
originalClient := http.DefaultClient
|
||||
if c, ok := ctx.Value(oauth2.HTTPClient).(*http.Client); ok {
|
||||
originalClient = c
|
||||
|
@ -90,7 +91,7 @@ func newProvider(ctx context.Context, o *oauth.Options) (*pom_oidc.Provider, err
|
|||
client.Transport = &wellKnownConfiguration{underlying: client.Transport}
|
||||
|
||||
ctx = context.WithValue(ctx, oauth2.HTTPClient, client)
|
||||
return pom_oidc.New(ctx, o)
|
||||
return pom_oidc.New(ctx, o, options...)
|
||||
}
|
||||
|
||||
type wellKnownConfiguration struct {
|
||||
|
|
44
internal/identity/oidc/config.go
Normal file
44
internal/identity/oidc/config.go
Normal file
|
@ -0,0 +1,44 @@
|
|||
package oidc
|
||||
|
||||
import (
|
||||
"github.com/coreos/go-oidc/v3/oidc"
|
||||
"golang.org/x/oauth2"
|
||||
)
|
||||
|
||||
type config struct {
|
||||
getProvider func() (*oidc.Provider, error)
|
||||
getVerifier func(provider *oidc.Provider) *oidc.IDTokenVerifier
|
||||
getOauthConfig func(provider *oidc.Provider) *oauth2.Config
|
||||
}
|
||||
|
||||
// An Option customizes the config.
|
||||
type Option func(*config)
|
||||
|
||||
func getConfig(options ...Option) *config {
|
||||
cfg := &config{}
|
||||
for _, option := range options {
|
||||
option(cfg)
|
||||
}
|
||||
return cfg
|
||||
}
|
||||
|
||||
// WithGetOauthConfig sets the getOauthConfig function in the config.
|
||||
func WithGetOauthConfig(f func(provider *oidc.Provider) *oauth2.Config) Option {
|
||||
return func(cfg *config) {
|
||||
cfg.getOauthConfig = f
|
||||
}
|
||||
}
|
||||
|
||||
// WithGetProvider sets the getProvider function in the config.
|
||||
func WithGetProvider(f func() (*oidc.Provider, error)) Option {
|
||||
return func(cfg *config) {
|
||||
cfg.getProvider = f
|
||||
}
|
||||
}
|
||||
|
||||
// WithGetVerifier sets the getVerifier function in the config.
|
||||
func WithGetVerifier(f func(*oidc.Provider) *oidc.IDTokenVerifier) Option {
|
||||
return func(cfg *config) {
|
||||
cfg.getVerifier = f
|
||||
}
|
||||
}
|
|
@ -10,6 +10,7 @@ import (
|
|||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"sync"
|
||||
|
||||
go_oidc "github.com/coreos/go-oidc/v3/oidc"
|
||||
"golang.org/x/oauth2"
|
||||
|
@ -32,13 +33,7 @@ var defaultAuthCodeOptions = []oauth2.AuthCodeOption{oauth2.AccessTypeOffline}
|
|||
// of an authorization identity provider.
|
||||
// https://openid.net/specs/openid-connect-core-1_0.html
|
||||
type Provider struct {
|
||||
// Provider represents an OpenID Connect server's configuration.
|
||||
Provider *go_oidc.Provider
|
||||
// Verifier provides verification for ID Tokens.
|
||||
Verifier *go_oidc.IDTokenVerifier
|
||||
// Oauth describes a typical 3-legged OAuth2 flow, with both the
|
||||
// client application information and the server's endpoint URLs.
|
||||
Oauth *oauth2.Config
|
||||
cfg *config
|
||||
|
||||
// RevocationURL is the location of the OAuth 2.0 token revocation endpoint.
|
||||
// https://tools.ietf.org/html/rfc7009
|
||||
|
@ -52,41 +47,53 @@ type Provider struct {
|
|||
// AuthCodeOptions specifies additional key value pairs query params to add
|
||||
// to the request flow signin url.
|
||||
AuthCodeOptions map[string]string
|
||||
|
||||
mu sync.Mutex
|
||||
provider *go_oidc.Provider
|
||||
}
|
||||
|
||||
// New creates a new instance of a generic OpenID Connect provider.
|
||||
func New(ctx context.Context, o *oauth.Options) (*Provider, error) {
|
||||
var err error
|
||||
var p Provider
|
||||
func New(ctx context.Context, o *oauth.Options, options ...Option) (*Provider, error) {
|
||||
if o.ProviderURL == "" {
|
||||
return nil, ErrMissingProviderURL
|
||||
}
|
||||
|
||||
p := new(Provider)
|
||||
if len(o.Scopes) == 0 {
|
||||
o.Scopes = defaultScopes
|
||||
}
|
||||
p.Provider, err = go_oidc.NewProvider(ctx, o.ProviderURL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("identity/oidc: could not connect to %s: %w", o.ProviderName, err)
|
||||
}
|
||||
|
||||
p.Verifier = p.Provider.Verifier(&go_oidc.Config{ClientID: o.ClientID})
|
||||
p.Oauth = &oauth2.Config{
|
||||
ClientID: o.ClientID,
|
||||
ClientSecret: o.ClientSecret,
|
||||
Scopes: o.Scopes,
|
||||
Endpoint: p.Provider.Endpoint(),
|
||||
RedirectURL: o.RedirectURL.String(),
|
||||
}
|
||||
|
||||
if len(o.AuthCodeOptions) != 0 {
|
||||
p.AuthCodeOptions = o.AuthCodeOptions
|
||||
}
|
||||
|
||||
// add non-standard claims like end-session, revoke, and user info
|
||||
if err := p.Provider.Claims(&p); err != nil {
|
||||
return nil, fmt.Errorf("identity/oidc: could not retrieve additional claims: %w", err)
|
||||
}
|
||||
return &p, nil
|
||||
p.cfg = getConfig(append([]Option{
|
||||
WithGetOauthConfig(func(provider *go_oidc.Provider) *oauth2.Config {
|
||||
return &oauth2.Config{
|
||||
ClientID: o.ClientID,
|
||||
ClientSecret: o.ClientSecret,
|
||||
Scopes: o.Scopes,
|
||||
Endpoint: provider.Endpoint(),
|
||||
RedirectURL: o.RedirectURL.String(),
|
||||
}
|
||||
}),
|
||||
WithGetProvider(func() (*go_oidc.Provider, error) {
|
||||
pp, err := go_oidc.NewProvider(ctx, o.ProviderURL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("identity/oidc: could not connect to %s: %w", o.ProviderName, err)
|
||||
}
|
||||
|
||||
// add non-standard claims like end-session, revoke, and user info
|
||||
if err := pp.Claims(&p); err != nil {
|
||||
return nil, fmt.Errorf("identity/oidc: could not retrieve additional claims: %w", err)
|
||||
}
|
||||
|
||||
return pp, nil
|
||||
}),
|
||||
WithGetVerifier(func(provider *go_oidc.Provider) *go_oidc.IDTokenVerifier {
|
||||
return provider.Verifier(&go_oidc.Config{ClientID: o.ClientID})
|
||||
}),
|
||||
}, options...)...)
|
||||
return p, nil
|
||||
}
|
||||
|
||||
// GetSignInURL returns the url of the provider's OAuth 2.0 consent page
|
||||
|
@ -96,19 +103,29 @@ func New(ctx context.Context, o *oauth.Options) (*Provider, error) {
|
|||
// always provide a non-empty string and validate that it matches the
|
||||
// the state query parameter on your redirect callback.
|
||||
// See http://tools.ietf.org/html/rfc6749#section-10.12 for more info.
|
||||
func (p *Provider) GetSignInURL(state string) string {
|
||||
func (p *Provider) GetSignInURL(state string) (string, error) {
|
||||
oa, err := p.GetOauthConfig()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
opts := defaultAuthCodeOptions
|
||||
for k, v := range p.AuthCodeOptions {
|
||||
opts = append(opts, oauth2.SetAuthURLParam(k, v))
|
||||
}
|
||||
return p.Oauth.AuthCodeURL(state, opts...)
|
||||
return oa.AuthCodeURL(state, opts...), nil
|
||||
}
|
||||
|
||||
// Authenticate converts an authorization code returned from the identity
|
||||
// provider into a token which is then converted into a user session.
|
||||
func (p *Provider) Authenticate(ctx context.Context, code string, v identity.State) (*oauth2.Token, error) {
|
||||
oa, err := p.GetOauthConfig()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Exchange converts an authorization code into a token.
|
||||
oauth2Token, err := p.Oauth.Exchange(ctx, code)
|
||||
oauth2Token, err := oa.Exchange(ctx, code)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("identity/oidc: token exchange failed: %w", err)
|
||||
}
|
||||
|
@ -140,7 +157,12 @@ func (p *Provider) Authenticate(ctx context.Context, code string, v identity.Sta
|
|||
//
|
||||
// https://openid.net/specs/openid-connect-core-1_0.html#UserInfo
|
||||
func (p *Provider) UpdateUserInfo(ctx context.Context, t *oauth2.Token, v interface{}) error {
|
||||
userInfo, err := getUserInfo(ctx, p.Provider, oauth2.StaticTokenSource(t))
|
||||
pp, err := p.GetProvider()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
userInfo, err := getUserInfo(ctx, pp, oauth2.StaticTokenSource(t))
|
||||
if err != nil {
|
||||
return fmt.Errorf("identity/oidc: user info endpoint: %w", err)
|
||||
}
|
||||
|
@ -160,8 +182,13 @@ func (p *Provider) Refresh(ctx context.Context, t *oauth2.Token, v identity.Stat
|
|||
if t.RefreshToken == "" {
|
||||
return nil, ErrMissingRefreshToken
|
||||
}
|
||||
var err error
|
||||
newToken, err := p.Oauth.TokenSource(ctx, t).Token()
|
||||
|
||||
oa, err := p.GetOauthConfig()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
newToken, err := oa.TokenSource(ctx, t).Token()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("identity/oidc: refresh failed: %w", err)
|
||||
}
|
||||
|
@ -186,11 +213,16 @@ func (p *Provider) Refresh(ctx context.Context, t *oauth2.Token, v identity.Stat
|
|||
//
|
||||
// https://openid.net/specs/openid-connect-core-1_0.html#TokenResponse
|
||||
func (p *Provider) getIDToken(ctx context.Context, t *oauth2.Token) (*go_oidc.IDToken, error) {
|
||||
v, err := p.GetVerifier()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
rawIDToken, ok := t.Extra("id_token").(string)
|
||||
if !ok {
|
||||
return nil, ErrMissingIDToken
|
||||
}
|
||||
return p.Verifier.Verify(ctx, rawIDToken)
|
||||
return v.Verify(ctx, rawIDToken)
|
||||
}
|
||||
|
||||
// Revoke enables a user to revoke her token. If the identity provider does not
|
||||
|
@ -205,16 +237,21 @@ func (p *Provider) Revoke(ctx context.Context, t *oauth2.Token) error {
|
|||
return ErrMissingAccessToken
|
||||
}
|
||||
|
||||
oa, err := p.GetOauthConfig()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
params := url.Values{}
|
||||
params.Add("token", t.AccessToken)
|
||||
params.Add("token_type_hint", "access_token")
|
||||
// Some providers like okta / onelogin require "client authentication"
|
||||
// https://developer.okta.com/docs/reference/api/oidc/#client-secret
|
||||
// https://developers.onelogin.com/openid-connect/api/revoke-session
|
||||
params.Add("client_id", p.Oauth.ClientID)
|
||||
params.Add("client_secret", p.Oauth.ClientSecret)
|
||||
params.Add("client_id", oa.ClientID)
|
||||
params.Add("client_secret", oa.ClientSecret)
|
||||
|
||||
err := httputil.Client(ctx, http.MethodPost, p.RevocationURL, version.UserAgent(), nil, params, nil)
|
||||
err = httputil.Client(ctx, http.MethodPost, p.RevocationURL, version.UserAgent(), nil, params, nil)
|
||||
if err != nil && errors.Is(err, httputil.ErrTokenRevoked) {
|
||||
return fmt.Errorf("internal/oidc: unexpected revoke error: %w", err)
|
||||
}
|
||||
|
@ -253,3 +290,38 @@ func (p *Provider) GetSubject(v interface{}) (string, error) {
|
|||
func (p *Provider) Name() string {
|
||||
return Name
|
||||
}
|
||||
|
||||
// GetProvider gets the underlying oidc Provider.
|
||||
func (p *Provider) GetProvider() (*go_oidc.Provider, error) {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
|
||||
if p.provider != nil {
|
||||
return p.provider, nil
|
||||
}
|
||||
|
||||
pp, err := p.cfg.getProvider()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
p.provider = pp
|
||||
return pp, nil
|
||||
}
|
||||
|
||||
// GetVerifier gets the verifier.
|
||||
func (p *Provider) GetVerifier() (*go_oidc.IDTokenVerifier, error) {
|
||||
pp, err := p.GetProvider()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return p.cfg.getVerifier(pp), nil
|
||||
}
|
||||
|
||||
// GetOauthConfig gets the oauth.
|
||||
func (p *Provider) GetOauthConfig() (*oauth2.Config, error) {
|
||||
pp, err := p.GetProvider()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return p.cfg.getOauthConfig(pp), nil
|
||||
}
|
||||
|
|
|
@ -27,7 +27,7 @@ type Authenticator interface {
|
|||
Authenticate(context.Context, string, identity.State) (*oauth2.Token, error)
|
||||
Refresh(context.Context, *oauth2.Token, identity.State) (*oauth2.Token, error)
|
||||
Revoke(context.Context, *oauth2.Token) error
|
||||
GetSignInURL(state string) string
|
||||
GetSignInURL(state string) (string, error)
|
||||
Name() string
|
||||
LogOut() (*url.URL, error)
|
||||
UpdateUserInfo(ctx context.Context, t *oauth2.Token, v interface{}) error
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue