diff --git a/authenticate/authenticate.go b/authenticate/authenticate.go index 5183afca4..9732c3205 100644 --- a/authenticate/authenticate.go +++ b/authenticate/authenticate.go @@ -108,7 +108,7 @@ func (a *Authenticate) updateProvider(cfg *config.Config) error { // configure our identity provider provider, err := identity.NewAuthenticator( oauth.Options{ - RedirectURL: redirectURL, + RedirectURL: redirectURL.String(), ProviderName: cfg.Options.Provider, ProviderURL: cfg.Options.ProviderURL, ClientID: cfg.Options.ClientID, diff --git a/authenticate/handlers.go b/authenticate/handlers.go index ce116b470..c3d4d1f07 100644 --- a/authenticate/handlers.go +++ b/authenticate/handlers.go @@ -23,6 +23,7 @@ import ( "github.com/pomerium/pomerium/internal/httputil" "github.com/pomerium/pomerium/internal/identity" "github.com/pomerium/pomerium/internal/identity/manager" + "github.com/pomerium/pomerium/internal/identity/oauth" "github.com/pomerium/pomerium/internal/identity/oidc" "github.com/pomerium/pomerium/internal/log" "github.com/pomerium/pomerium/internal/middleware" @@ -56,7 +57,6 @@ func (a *Authenticate) Mount(r *mux.Router) { csrf.Path("/"), csrf.UnsafePaths( []string{ - "/oauth2/callback", // rfc6749#section-10.12 accepts GET "/.pomerium/sign_out", // https://openid.net/specs/openid-connect-frontchannel-1_0.html }), csrf.FormValueName("state"), // rfc6749#section-10.12 @@ -306,13 +306,16 @@ func (a *Authenticate) reauthenticateOrFail(w http.ResponseWriter, r *http.Reque } state.sessionStore.ClearSession(w, r) redirectURL := state.redirectURL.ResolveReference(r.URL) - nonce := csrf.Token(r) - now := time.Now().Unix() - b := []byte(fmt.Sprintf("%s|%d|", nonce, now)) - enc := cryptutil.Encrypt(state.cookieCipher, []byte(redirectURL.String()), b) - b = append(b, enc...) - encodedState := base64.URLEncoding.EncodeToString(b) - signinURL, err := a.provider.Load().GetSignInURL(encodedState) + + if rawOAuthRedirectURI := r.FormValue(urlutil.QueryOAuthRedirectURI); rawOAuthRedirectURI != "" { + redirectURL, err = urlutil.ParseAndValidateURL(rawOAuthRedirectURI) + if err != nil { + return httputil.NewError(http.StatusInternalServerError, err) + } + } + + oauthState := oauth.NewState(redirectURL.String()).Encode(state.cookieCipher) + signinURL, err := a.provider.Load().GetSignInURL(oauthState, redirectURL.String()) if err != nil { return httputil.NewError(http.StatusInternalServerError, fmt.Errorf("failed to get sign in url: %w", err)) @@ -370,34 +373,12 @@ func (a *Authenticate) getOAuthCallback(w http.ResponseWriter, r *http.Request) return nil, fmt.Errorf("error redeeming authenticate code: %w", err) } - // state includes a csrf nonce (validated by middleware) and redirect uri - bytes, err := base64.URLEncoding.DecodeString(r.FormValue("state")) - if err != nil { - return nil, httputil.NewError(http.StatusBadRequest, fmt.Errorf("bad bytes: %w", err)) - } - - // split state into concat'd components - // (nonce|timestamp|redirect_url|encrypted_data(redirect_url)+mac(nonce,ts)) - statePayload := strings.SplitN(string(bytes), "|", 3) - if len(statePayload) != 3 { - return nil, httputil.NewError(http.StatusBadRequest, fmt.Errorf("state malformed, size: %d", len(statePayload))) - } - - // verify that the returned timestamp is valid - if err := cryptutil.ValidTimestamp(statePayload[1]); err != nil { - return nil, httputil.NewError(http.StatusBadRequest, err) - } - - // Use our AEAD construct to enforce secrecy and authenticity: - // mac: to validate the nonce again, and above timestamp - // decrypt: to prevent leaking 'redirect_uri' to IdP or logs - b := []byte(fmt.Sprint(statePayload[0], "|", statePayload[1], "|")) - redirectString, err := cryptutil.Decrypt(state.cookieCipher, []byte(statePayload[2]), b) + oauthState, err := oauth.DecodeState(state.cookieCipher, r.FormValue("state")) if err != nil { return nil, httputil.NewError(http.StatusBadRequest, err) } - redirectURL, err := urlutil.ParseAndValidateURL(string(redirectString)) + redirectURL, err := urlutil.ParseAndValidateURL(oauthState.RedirectURL) if err != nil { return nil, httputil.NewError(http.StatusBadRequest, err) } diff --git a/authorize/check_response.go b/authorize/check_response.go index 45ae564e8..2f281e26a 100644 --- a/authorize/check_response.go +++ b/authorize/check_response.go @@ -38,6 +38,7 @@ func (a *Authorize) handleResultAllowed( func (a *Authorize) handleResultDenied( ctx context.Context, in *envoy_service_auth_v3.CheckRequest, + req *evaluator.Request, result *evaluator.Result, isForwardAuthVerify bool, reasons criteria.Reasons, @@ -49,7 +50,7 @@ func (a *Authorize) handleResultDenied( case reasons.Has(criteria.ReasonUserUnauthenticated): // when the user is unauthenticated it means they haven't // logged in yet, so redirect to authenticate - return a.requireLoginResponse(ctx, in, isForwardAuthVerify) + return a.requireLoginResponse(ctx, in, req, isForwardAuthVerify) case reasons.Has(criteria.ReasonDeviceUnauthenticated): // when the user's device is unauthenticated it means they haven't // registered a webauthn device yet, so redirect to the webauthn flow @@ -141,19 +142,20 @@ func (a *Authorize) deniedResponse( func (a *Authorize) requireLoginResponse( ctx context.Context, in *envoy_service_auth_v3.CheckRequest, + req *evaluator.Request, isForwardAuthVerify bool, ) (*envoy_service_auth_v3.CheckResponse, error) { opts := a.currentOptions.Load() state := a.state.Load() - authenticateURL, err := opts.GetAuthenticateURL() - if err != nil { - return nil, err - } if !a.shouldRedirect(in) || isForwardAuthVerify { return a.deniedResponse(ctx, in, http.StatusUnauthorized, http.StatusText(http.StatusUnauthorized), nil) } + authenticateURL, err := opts.GetAuthenticateURL() + if err != nil { + return nil, err + } signinURL := authenticateURL.ResolveReference(&url.URL{ Path: "/.pomerium/sign_in", }) @@ -164,6 +166,17 @@ func (a *Authorize) requireLoginResponse( checkRequestURL.Scheme = "https" q.Set(urlutil.QueryRedirectURI, checkRequestURL.String()) + + // If an OAuthRedirectURL is explicitly set, pass that on the query string to + // override the default authenticate redirect url. + if req.Policy != nil && req.Policy.OAuthRedirectURL != "" { + u, err := urlutil.ParseAndValidateURL(req.Policy.OAuthRedirectURL) + if err != nil { + return nil, err + } + q.Set(urlutil.QueryOAuthRedirectURI, u.String()) + } + signinURL.RawQuery = q.Encode() redirectTo := urlutil.NewSignedURL(state.sharedKey, signinURL).String() diff --git a/authorize/grpc.go b/authorize/grpc.go index 804d37956..1e3557235 100644 --- a/authorize/grpc.go +++ b/authorize/grpc.go @@ -76,7 +76,7 @@ func (a *Authorize) Check(ctx context.Context, in *envoy_service_auth_v3.CheckRe // if there's a deny, the result is denied using the deny reasons. if res.Deny.Value { - return a.handleResultDenied(ctx, in, res, isForwardAuthVerify, res.Deny.Reasons) + return a.handleResultDenied(ctx, in, req, res, isForwardAuthVerify, res.Deny.Reasons) } // if there's an allow, the result is allowed. @@ -85,7 +85,7 @@ func (a *Authorize) Check(ctx context.Context, in *envoy_service_auth_v3.CheckRe } // otherwise, the result is denied using the allow reasons. - return a.handleResultDenied(ctx, in, res, isForwardAuthVerify, res.Allow.Reasons) + return a.handleResultDenied(ctx, in, req, res, isForwardAuthVerify, res.Allow.Reasons) } func getForwardAuthURL(r *http.Request) *url.URL { diff --git a/config/envoyconfig/routes.go b/config/envoyconfig/routes.go index 347933723..7db2bf8a9 100644 --- a/config/envoyconfig/routes.go +++ b/config/envoyconfig/routes.go @@ -3,8 +3,10 @@ package envoyconfig import ( "encoding/json" "fmt" + "net" "net/url" "sort" + "strconv" envoy_config_core_v3 "github.com/envoyproxy/go-control-plane/envoy/config/core/v3" envoy_config_route_v3 "github.com/envoyproxy/go-control-plane/envoy/config/route/v3" @@ -17,6 +19,7 @@ import ( "github.com/pomerium/pomerium/config" "github.com/pomerium/pomerium/internal/httputil" + "github.com/pomerium/pomerium/internal/identity/oauth" "github.com/pomerium/pomerium/internal/urlutil" ) @@ -268,6 +271,14 @@ func (b *Builder) buildPolicyRoutes(options *config.Options, domain string) ([]* continue } + if policy.OAuthRedirectURL != "" { + oauthRedirectURLRoute, err := b.buildOAuthRedirectURLRoute(options, policy.OAuthRedirectURL) + if err != nil { + return nil, err + } + routes = append(routes, oauthRedirectURLRoute) + } + match := mkRouteMatch(&policy) envoyRoute := &envoy_config_route_v3.Route{ Name: fmt.Sprintf("policy-%d", i), @@ -446,6 +457,69 @@ func (b *Builder) buildPolicyRouteRouteAction(options *config.Options, policy *c return action, nil } +func (b *Builder) buildOAuthRedirectURLRoute(options *config.Options, rawOAuthRedirectURL string) (*envoy_config_route_v3.Route, error) { + oauthRedirectURL, err := urlutil.ParseAndValidateURL(rawOAuthRedirectURL) + if err != nil { + return nil, err + } + + envoyRedirect, err := b.buildAuthenticateCallbackRouteRedirectAction(options) + if err != nil { + return nil, err + } + + envoyRoute := &envoy_config_route_v3.Route{ + Action: &envoy_config_route_v3.Route_Redirect{ + Redirect: envoyRedirect, + }, + Match: &envoy_config_route_v3.RouteMatch{ + PathSpecifier: &envoy_config_route_v3.RouteMatch_Path{ + Path: oauthRedirectURL.Path, + }, + QueryParameters: []*envoy_config_route_v3.QueryParameterMatcher{ + {Name: "state", QueryParameterMatchSpecifier: &envoy_config_route_v3.QueryParameterMatcher_StringMatch{ + StringMatch: &envoy_type_matcher_v3.StringMatcher{ + MatchPattern: &envoy_type_matcher_v3.StringMatcher_Prefix{ + Prefix: oauth.StatePrefix, + }, + }, + }}, + }, + }, + TypedPerFilterConfig: map[string]*any.Any{ + "envoy.filters.http.ext_authz": disableExtAuthz, + }, + } + return envoyRoute, nil +} + +func (b *Builder) buildAuthenticateCallbackRouteRedirectAction(options *config.Options) (*envoy_config_route_v3.RedirectAction, error) { + authenticateURL, err := options.GetAuthenticateURL() + if err != nil { + return nil, err + } + + authenticateURL = authenticateURL.ResolveReference(&url.URL{ + Path: options.AuthenticateCallbackPath, + }) + + redirect := &envoy_config_route_v3.RedirectAction{} + if host, rawPort, err := net.SplitHostPort(authenticateURL.Host); err == nil { + if port, err := strconv.ParseUint(rawPort, 10, 32); err == nil { + redirect.HostRedirect = host + redirect.PortRedirect = uint32(port) + } else { + return nil, fmt.Errorf("invalid port in authenticate URL") + } + } + redirect.PathRewriteSpecifier = &envoy_config_route_v3.RedirectAction_PathRedirect{ + PathRedirect: authenticateURL.Path, + } + redirect.ResponseCode = envoy_config_route_v3.RedirectAction_FOUND + + return redirect, nil +} + func mkEnvoyHeader(k, v string) *envoy_config_core_v3.HeaderValueOption { return &envoy_config_core_v3.HeaderValueOption{ Header: &envoy_config_core_v3.HeaderValue{ diff --git a/config/options.go b/config/options.go index 9d8558441..ccfa59f93 100644 --- a/config/options.go +++ b/config/options.go @@ -899,7 +899,7 @@ func (o *Options) GetOauthOptions() (oauth.Options, error) { Path: o.AuthenticateCallbackPath, }) return oauth.Options{ - RedirectURL: redirectURL, + RedirectURL: redirectURL.String(), ProviderName: o.Provider, ProviderURL: o.ProviderURL, ClientID: o.ClientID, diff --git a/config/policy.go b/config/policy.go index da461e20e..851d9f73e 100644 --- a/config/policy.go +++ b/config/policy.go @@ -162,6 +162,9 @@ type Policy struct { // SetResponseHeaders sets response headers. SetResponseHeaders map[string]string `mapstructure:"set_response_headers" yaml:"set_response_headers,omitempty"` + // OAuthRedirectURL overrides the default authenticate redirect URL for this route. + OAuthRedirectURL string `mapstructure:"oauth_redirect_url" yaml:"oauth_redirect_url,omitempty"` + Policy *PPLPolicy `mapstructure:"policy" yaml:"policy,omitempty" json:"policy,omitempty"` } diff --git a/internal/identity/oauth/github/github.go b/internal/identity/oauth/github/github.go index 36c4c6cc2..46e0401be 100644 --- a/internal/identity/oauth/github/github.go +++ b/internal/identity/oauth/github/github.go @@ -78,7 +78,7 @@ func New(ctx context.Context, o *oauth.Options) (*Provider, error) { ClientID: o.ClientID, ClientSecret: o.ClientSecret, Scopes: o.Scopes, - RedirectURL: o.RedirectURL.String(), + RedirectURL: o.RedirectURL, Endpoint: oauth2.Endpoint{ AuthURL: urlutil.Join(o.ProviderURL, authURL), TokenURL: urlutil.Join(o.ProviderURL, tokenURL), @@ -241,8 +241,10 @@ func (p *Provider) Revoke(ctx context.Context, token *oauth2.Token) error { // GetSignInURL returns a URL to OAuth 2.0 provider's consent page // that asks for permissions for the required scopes explicitly. -func (p *Provider) GetSignInURL(state string) (string, error) { - return p.Oauth.AuthCodeURL(state, oauth2.AccessTypeOffline), nil +func (p *Provider) GetSignInURL(state, redirectURL string) (string, error) { + oa := *p.Oauth + oa.RedirectURL = redirectURL + return oa.AuthCodeURL(state, oauth2.AccessTypeOffline), nil } // LogOut is not implemented by github. diff --git a/internal/identity/oauth/options.go b/internal/identity/oauth/options.go index 2be2ddf80..018a8ccf6 100644 --- a/internal/identity/oauth/options.go +++ b/internal/identity/oauth/options.go @@ -3,8 +3,6 @@ // authorization with Bearer JWT. package oauth -import "net/url" - // Options contains the fields required for an OAuth 2.0 (inc. OIDC) auth flow. // // https://tools.ietf.org/html/rfc6749 @@ -22,7 +20,7 @@ type Options struct { ClientSecret string // RedirectURL is the URL to redirect users going through // the OAuth flow, after the resource owner's URLs. - RedirectURL *url.URL + RedirectURL string // Scope specifies optional requested permissions. Scopes []string diff --git a/internal/identity/oauth/state.go b/internal/identity/oauth/state.go new file mode 100644 index 000000000..f50f7f532 --- /dev/null +++ b/internal/identity/oauth/state.go @@ -0,0 +1,80 @@ +package oauth + +import ( + "bytes" + "crypto/cipher" + "encoding/base64" + "fmt" + "strconv" + "strings" + "time" + + "github.com/google/uuid" + + "github.com/pomerium/pomerium/pkg/cryptutil" +) + +// StatePrefix is the prefix used to indicate the state is via pomerium. +const StatePrefix = "POMERIUM-" + +// State is the state in the oauth query string. +type State struct { + Nonce string + Timestamp time.Time + RedirectURL string +} + +// NewState creates a new State. +func NewState(redirectURL string) *State { + return &State{ + Nonce: uuid.NewString(), + Timestamp: time.Now(), + RedirectURL: redirectURL, + } +} + +// DecodeState decodes state from a raw state string. +func DecodeState(aead cipher.AEAD, rawState string) (*State, error) { + withoutPrefix := strings.TrimPrefix(rawState, StatePrefix) + rawStateBytes, err := base64.RawURLEncoding.DecodeString(withoutPrefix) + if err != nil { + return nil, fmt.Errorf("invalid state encoding: %w", err) + } + + // split the state into its components + state := bytes.SplitN(rawStateBytes, []byte{'|'}, 3) + if len(state) != 3 { + return nil, fmt.Errorf("invalid state format") + } + + // verify that the returned timestamp is valid + if err := cryptutil.ValidTimestamp(string(state[1])); err != nil { + return nil, fmt.Errorf("invalid state timestamp: %w", err) + } + + nonce := string(state[0]) + timestamp, err := strconv.ParseInt(string(state[1]), 10, 64) + if err != nil { + return nil, fmt.Errorf("invalid state timestamp: %w", err) + } + + ad := []byte(fmt.Sprintf("%s|%d|", nonce, timestamp)) + decrypted, err := cryptutil.Decrypt(aead, state[2], ad) + if err != nil { + return nil, fmt.Errorf("invalid state redirect URL: %w", err) + } + + return &State{ + Nonce: nonce, + Timestamp: time.Unix(timestamp, 0), + RedirectURL: string(decrypted), + }, nil +} + +// Encode encodes the state. +func (state *State) Encode(aead cipher.AEAD) string { + timestamp := state.Timestamp.Unix() + ad := []byte(fmt.Sprintf("%s|%d|", state.Nonce, timestamp)) + encrypted := cryptutil.Encrypt(aead, []byte(state.RedirectURL), ad) + return StatePrefix + base64.RawURLEncoding.EncodeToString(append(ad, encrypted...)) +} diff --git a/internal/identity/oidc/oidc.go b/internal/identity/oidc/oidc.go index 14a565e00..332022b29 100644 --- a/internal/identity/oidc/oidc.go +++ b/internal/identity/oidc/oidc.go @@ -73,7 +73,7 @@ func New(ctx context.Context, o *oauth.Options, options ...Option) (*Provider, e ClientSecret: o.ClientSecret, Scopes: o.Scopes, Endpoint: provider.Endpoint(), - RedirectURL: o.RedirectURL.String(), + RedirectURL: o.RedirectURL, } }), WithGetProvider(func() (*go_oidc.Provider, error) { @@ -103,11 +103,12 @@ func New(ctx context.Context, o *oauth.Options, options ...Option) (*Provider, e // always provide a non-empty string and validate that it matches the // the state query parameter on your redirect callback. // See http://tools.ietf.org/html/rfc6749#section-10.12 for more info. -func (p *Provider) GetSignInURL(state string) (string, error) { +func (p *Provider) GetSignInURL(state, redirectURL string) (string, error) { oa, err := p.GetOauthConfig() if err != nil { return "", err } + oa.RedirectURL = redirectURL opts := defaultAuthCodeOptions for k, v := range p.AuthCodeOptions { diff --git a/internal/identity/providers.go b/internal/identity/providers.go index c1bc85808..20a6f4cd3 100644 --- a/internal/identity/providers.go +++ b/internal/identity/providers.go @@ -28,7 +28,7 @@ type Authenticator interface { Authenticate(context.Context, string, identity.State) (*oauth2.Token, error) Refresh(context.Context, *oauth2.Token, identity.State) (*oauth2.Token, error) Revoke(context.Context, *oauth2.Token) error - GetSignInURL(state string) (string, error) + GetSignInURL(state, redirectURL string) (string, error) Name() string LogOut() (*url.URL, error) UpdateUserInfo(ctx context.Context, t *oauth2.Token, v interface{}) error diff --git a/internal/urlutil/query_params.go b/internal/urlutil/query_params.go index 72aed1d6c..08553e4fa 100644 --- a/internal/urlutil/query_params.go +++ b/internal/urlutil/query_params.go @@ -10,6 +10,7 @@ const ( QueryEnrollmentToken = "pomerium_enrollment_token" //nolint QueryIsProgrammatic = "pomerium_programmatic" QueryForwardAuth = "pomerium_forward_auth" + QueryOAuthRedirectURI = "pomerium_oauth_redirect_uri" QueryPomeriumJWT = "pomerium_jwt" QuerySession = "pomerium_session" QuerySessionEncrypted = "pomerium_session_encrypted"