From 6d947ebd26bd2a7e14fe63fc064ec823d4fe7f15 Mon Sep 17 00:00:00 2001 From: Joe Kralicky Date: Thu, 16 May 2024 16:47:02 -0400 Subject: [PATCH] Prototype device authorization flow (core) --- authenticate/handlers.go | 103 +++++++++++++++++++++++++ internal/authenticateflow/stateful.go | 11 +++ internal/authenticateflow/stateless.go | 11 +++ internal/urlutil/query_params.go | 34 ++++---- pkg/identity/mock_provider.go | 32 ++++++-- pkg/identity/oauth/apple/apple.go | 8 ++ pkg/identity/oauth/github/github.go | 8 ++ pkg/identity/oidc/device_auth.go | 55 +++++++++++++ pkg/identity/oidc/errors.go | 4 + pkg/identity/oidc/oidc.go | 56 ++++++++++++++ pkg/identity/providers.go | 4 + proxy/handlers.go | 28 +++++++ proxy/state.go | 1 + 13 files changed, 331 insertions(+), 24 deletions(-) create mode 100644 pkg/identity/oidc/device_auth.go diff --git a/authenticate/handlers.go b/authenticate/handlers.go index ceaa721f5..54e3e9b73 100644 --- a/authenticate/handlers.go +++ b/authenticate/handlers.go @@ -3,6 +3,7 @@ package authenticate import ( "context" "encoding/base64" + "encoding/json" "errors" "fmt" "net/http" @@ -90,6 +91,9 @@ func (a *Authenticate) mountDashboard(r *mux.Router) { // routes that don't need a session: sr.Path("/sign_out").Handler(httputil.HandlerFunc(a.SignOut)) sr.Path("/signed_out").Handler(httputil.HandlerFunc(a.signedOut)).Methods(http.MethodGet) + sr.Path("/device_auth").Handler(httputil.HandlerFunc(a.DeviceAuthLogin)). + Queries(urlutil.QueryDeviceAuthRouteURI, ""). + Methods(http.MethodGet, http.MethodPost) // routes that need a session: sr = sr.NewRoute().Subrouter() @@ -548,3 +552,102 @@ func (a *Authenticate) getIdentityProviderIDForRequest(r *http.Request) string { } return a.state.Load().flow.GetIdentityProviderIDForURLValues(r.Form) } + +func (a *Authenticate) getRetryTokenForRequest(r *http.Request) []byte { + if err := r.ParseForm(); err != nil { + return nil + } + dec, _ := base64.URLEncoding.DecodeString(r.Form.Get(urlutil.QueryDeviceAuthRetryToken)) + return dec +} + +func (a *Authenticate) DeviceAuthLogin(w http.ResponseWriter, r *http.Request) error { + state := a.state.Load() + options := a.options.Load() + idpID := a.getIdentityProviderIDForRequest(r) + + routeUri := r.FormValue(urlutil.QueryDeviceAuthRouteURI) + ad := []byte(fmt.Sprintf("%s|%s|", routeUri, idpID)) + authenticator, err := a.cfg.getIdentityProvider(options, idpID) + if err != nil { + return err + } + + // check if the request includes a retry token + if encRetryToken := a.getRetryTokenForRequest(r); len(encRetryToken) > 0 { + retryTokenJwt, err := cryptutil.Decrypt(state.cookieCipher, []byte(encRetryToken), ad) + if err != nil { + return httputil.NewError(http.StatusUnauthorized, fmt.Errorf("bad retry token: %w", err)) + } + var retryToken oidc.RetryToken + if err := state.sharedEncoder.Unmarshal(retryTokenJwt, &retryToken); err != nil { + return httputil.NewError(http.StatusUnauthorized, fmt.Errorf("bad retry token: %w", err)) + } + now := time.Now() + if now.After(time.Unix(0, retryToken.NotAfter)) { + return httputil.NewError(http.StatusUnauthorized, fmt.Errorf("retry token expired")) + } else if now.Before(time.Unix(0, retryToken.NotBefore)) { + w.Header().Set("Retry-After", time.Until(time.Unix(0, retryToken.NotBefore)).String()) + return httputil.NewError(http.StatusTooManyRequests, fmt.Errorf("retry token not yet valid")) + } + + var claims identity.SessionClaims + accessToken, err := authenticator.DeviceAccessToken(r.Context(), retryToken.AsDeviceAuthResponse(), &claims) + if err != nil { + return httputil.NewError(http.StatusInternalServerError, fmt.Errorf("failed to get device access token: %w", err)) + } + + // + // TODO: code copied from getOAuthCallback + // + s := sessions.NewState(idpID) + err = claims.Claims.Claims(&s) + if err != nil { + return fmt.Errorf("error unmarshaling session state: %w", err) + } + + newState := s.WithNewIssuer(state.redirectURL.Hostname(), []string{state.redirectURL.Hostname()}) + + // save the session and access token to the databroker/cookie store + if err := state.flow.PersistSession(r.Context(), w, &newState, claims, accessToken); err != nil { + return fmt.Errorf("failed saving new session: %w", err) + } + + // ... and the user state to local storage. + if err := state.sessionStore.SaveSession(w, r, &newState); err != nil { + return fmt.Errorf("failed saving new session: %w", err) + } + // + // end + // + + tokenJwt, err := state.sharedEncoder.Marshal(newState) + if err != nil { + return httputil.NewError(http.StatusInternalServerError, fmt.Errorf("failed to marshal session: %w", err)) + } + w.WriteHeader(http.StatusOK) + w.Header().Set("Content-Type", "application/json") + fmt.Fprintf(w, `{"token": "%s"}`, string(tokenJwt)) + return nil + } else { + authResp, err := authenticator.DeviceAuth(w, r) + if err != nil { + return httputil.NewError(http.StatusInternalServerError, + fmt.Errorf("failed to get device code: %w", err)) + } + // construct a retry token + retryToken := oidc.NewRetryToken(authResp) + // encode + retryTokenJwt, err := state.sharedEncoder.Marshal(retryToken) + if err != nil { + return httputil.NewError(http.StatusInternalServerError, + fmt.Errorf("failed to marshal retry token: %w", err)) + } + + // write the user-facing part of the auth response plus the encrypted retry token + userResp := oidc.NewUserDeviceAuthResponse(authResp, cryptutil.Encrypt(state.cookieCipher, retryTokenJwt, ad)) + w.WriteHeader(http.StatusOK) + w.Header().Set("Content-Type", "application/json") + return json.NewEncoder(w).Encode(userResp) + } +} diff --git a/internal/authenticateflow/stateful.go b/internal/authenticateflow/stateful.go index a7b27762d..8cffcbb80 100644 --- a/internal/authenticateflow/stateful.go +++ b/internal/authenticateflow/stateful.go @@ -339,6 +339,17 @@ func (s *Stateful) AuthenticateSignInURL( return redirectTo, nil } +func (s *Stateful) AuthenticateDeviceCode(w http.ResponseWriter, r *http.Request, params url.Values) error { + deviceAuthURL := s.authenticateURL.ResolveReference(&url.URL{ + Path: "/.pomerium/device_auth", + RawQuery: params.Encode(), + }) + + signedURL := urlutil.NewSignedURL(s.sharedKey, deviceAuthURL) + httputil.Redirect(w, r, signedURL.String(), http.StatusFound) + return nil +} + // GetIdentityProviderIDForURLValues returns the identity provider ID // associated with the given URL values. func (s *Stateful) GetIdentityProviderIDForURLValues(vs url.Values) string { diff --git a/internal/authenticateflow/stateless.go b/internal/authenticateflow/stateless.go index 363962937..cb07bab13 100644 --- a/internal/authenticateflow/stateless.go +++ b/internal/authenticateflow/stateless.go @@ -365,6 +365,17 @@ func (s *Stateless) AuthenticateSignInURL( ) } +func (s *Stateless) AuthenticateDeviceCode(w http.ResponseWriter, r *http.Request, params url.Values) error { + signinURL := s.authenticateURL.ResolveReference(&url.URL{ + Path: "/.pomerium/device_auth", + RawQuery: params.Encode(), + }) + + signedURL := urlutil.NewSignedURL(s.sharedKey, signinURL) + httputil.Redirect(w, r, signedURL.String(), http.StatusFound) + return nil +} + // Callback handles a redirect to a route domain once signed in. func (s *Stateless) Callback(w http.ResponseWriter, r *http.Request) error { if err := r.ParseForm(); err != nil { diff --git a/internal/urlutil/query_params.go b/internal/urlutil/query_params.go index ac5a127bc..d068aa905 100644 --- a/internal/urlutil/query_params.go +++ b/internal/urlutil/query_params.go @@ -4,22 +4,24 @@ package urlutil // services over HTTP calls and redirects. They are typically used in // conjunction with a HMAC to ensure authenticity. const ( - QueryCallbackURI = "pomerium_callback_uri" - QueryDeviceCredentialID = "pomerium_device_credential_id" - QueryDeviceType = "pomerium_device_type" - QueryEnrollmentToken = "pomerium_enrollment_token" //nolint - QueryExpiry = "pomerium_expiry" - QueryIdentityProfile = "pomerium_identity_profile" - QueryIdentityProviderID = "pomerium_idp_id" - QueryIsProgrammatic = "pomerium_programmatic" - QueryIssued = "pomerium_issued" - QueryPomeriumJWT = "pomerium_jwt" - QueryRedirectURI = "pomerium_redirect_uri" - QuerySession = "pomerium_session" - QuerySessionEncrypted = "pomerium_session_encrypted" - QuerySessionState = "pomerium_session_state" - QueryVersion = "pomerium_version" - QueryRequestUUID = "pomerium_request_uuid" + QueryCallbackURI = "pomerium_callback_uri" + QueryDeviceCredentialID = "pomerium_device_credential_id" + QueryDeviceType = "pomerium_device_type" + QueryEnrollmentToken = "pomerium_enrollment_token" //nolint + QueryExpiry = "pomerium_expiry" + QueryIdentityProfile = "pomerium_identity_profile" + QueryIdentityProviderID = "pomerium_idp_id" + QueryIsProgrammatic = "pomerium_programmatic" + QueryIssued = "pomerium_issued" + QueryPomeriumJWT = "pomerium_jwt" + QueryRedirectURI = "pomerium_redirect_uri" + QuerySession = "pomerium_session" + QuerySessionEncrypted = "pomerium_session_encrypted" + QuerySessionState = "pomerium_session_state" + QueryVersion = "pomerium_version" + QueryRequestUUID = "pomerium_request_uuid" + QueryDeviceAuthRetryToken = "pomerium_device_auth_retry_token" + QueryDeviceAuthRouteURI = "pomerium_device_auth_route_uri" ) // URL signature based query params used for verifying the authenticity of a URL. diff --git a/pkg/identity/mock_provider.go b/pkg/identity/mock_provider.go index 5d376157e..9dda2aee1 100644 --- a/pkg/identity/mock_provider.go +++ b/pkg/identity/mock_provider.go @@ -11,16 +11,22 @@ import ( // MockProvider provides a mocked implementation of the providers interface. type MockProvider struct { - AuthenticateResponse oauth2.Token - AuthenticateError error - RefreshResponse oauth2.Token - RefreshError error - RevokeError error - UpdateUserInfoError error - SignInError error - SignOutError error + AuthenticateResponse oauth2.Token + AuthenticateError error + RefreshResponse oauth2.Token + RefreshError error + RevokeError error + UpdateUserInfoError error + SignInError error + SignOutError error + DeviceAuthResponse oauth2.DeviceAuthResponse + DeviceAuthError error + DeviceAccessTokenResponse oauth2.Token + DeviceAccessTokenError error } +var _ Authenticator = MockProvider{} + // Authenticate is a mocked providers function. func (mp MockProvider) Authenticate(context.Context, string, identity.State) (*oauth2.Token, error) { return &mp.AuthenticateResponse, mp.AuthenticateError @@ -55,3 +61,13 @@ func (mp MockProvider) SignOut(_ http.ResponseWriter, _ *http.Request, _, _, _ s func (mp MockProvider) SignIn(_ http.ResponseWriter, _ *http.Request, _ string) error { return mp.SignInError } + +// DeviceAccessToken implements Authenticator. +func (mp MockProvider) DeviceAccessToken(ctx context.Context, r *oauth2.DeviceAuthResponse, state identity.State) (*oauth2.Token, error) { + return &mp.DeviceAccessTokenResponse, mp.DeviceAccessTokenError +} + +// DeviceAuth implements Authenticator. +func (mp MockProvider) DeviceAuth(w http.ResponseWriter, r *http.Request) (*oauth2.DeviceAuthResponse, error) { + return &mp.DeviceAuthResponse, mp.DeviceAuthError +} diff --git a/pkg/identity/oauth/apple/apple.go b/pkg/identity/oauth/apple/apple.go index ea0807637..da3f708bf 100644 --- a/pkg/identity/oauth/apple/apple.go +++ b/pkg/identity/oauth/apple/apple.go @@ -182,3 +182,11 @@ func (p *Provider) SignIn(w http.ResponseWriter, r *http.Request, state string) func (p *Provider) SignOut(_ http.ResponseWriter, _ *http.Request, _, _, _ string) error { return oidc.ErrSignoutNotImplemented } + +func (p *Provider) DeviceAuth(_ http.ResponseWriter, _ *http.Request) (*oauth2.DeviceAuthResponse, error) { + return nil, oidc.ErrDeviceAuthNotImplemented +} + +func (p *Provider) DeviceAccessToken(_ context.Context, _ *oauth2.DeviceAuthResponse, _ identity.State) (*oauth2.Token, error) { + return nil, oidc.ErrDeviceAuthNotImplemented +} diff --git a/pkg/identity/oauth/github/github.go b/pkg/identity/oauth/github/github.go index 7ae0f5c79..0bcf5c4ea 100644 --- a/pkg/identity/oauth/github/github.go +++ b/pkg/identity/oauth/github/github.go @@ -256,3 +256,11 @@ func (p *Provider) SignIn(w http.ResponseWriter, r *http.Request, state string) func (p *Provider) SignOut(_ http.ResponseWriter, _ *http.Request, _, _, _ string) error { return oidc.ErrSignoutNotImplemented } + +func (p *Provider) DeviceAuth(_ http.ResponseWriter, _ *http.Request) (*oauth2.DeviceAuthResponse, error) { + return nil, oidc.ErrDeviceAuthNotImplemented +} + +func (p *Provider) DeviceAccessToken(_ context.Context, _ *oauth2.DeviceAuthResponse, _ identity.State) (*oauth2.Token, error) { + return nil, oidc.ErrDeviceAuthNotImplemented +} diff --git a/pkg/identity/oidc/device_auth.go b/pkg/identity/oidc/device_auth.go new file mode 100644 index 000000000..5dd0f78ac --- /dev/null +++ b/pkg/identity/oidc/device_auth.go @@ -0,0 +1,55 @@ +package oidc + +import ( + "time" + + "golang.org/x/oauth2" +) + +type UserDeviceAuthResponse struct { + // UserCode is the code the user should enter at the verification uri + UserCode string `json:"user_code"` + // VerificationURI is where user should enter the user code + VerificationURI string `json:"verification_uri"` + // VerificationURIComplete (if populated) includes the user code in the verification URI. This is typically shown to the user in non-textual form, such as a QR code. + VerificationURIComplete string `json:"verification_uri_complete,omitempty"` + + // InitialRetryDelay is the duration in seconds the client must wait before + // attempting to retry the request, after completing their sign-in. + // This gives the server time to poll the identity provider for the results. + InitialRetryDelay int64 `json:"initial_retry_delay,omitempty"` + + // RetryToken should be sent on subsequent retries of the original request. + RetryToken []byte `json:"retry_token,omitempty"` +} + +type RetryToken struct { + DeviceCode string `json:"device_code"` + NotBefore int64 `json:"not_before"` + NotAfter int64 `json:"not_after"` +} + +func (rt RetryToken) AsDeviceAuthResponse() *oauth2.DeviceAuthResponse { + return &oauth2.DeviceAuthResponse{ + DeviceCode: rt.DeviceCode, + Expiry: time.Unix(0, rt.NotAfter), + } +} + +func NewRetryToken(authResp *oauth2.DeviceAuthResponse) RetryToken { + return RetryToken{ + DeviceCode: authResp.DeviceCode, + NotBefore: time.Now().Add(time.Duration(authResp.Interval) * time.Second).UnixNano(), + NotAfter: authResp.Expiry.UnixNano(), + } +} + +func NewUserDeviceAuthResponse(authResp *oauth2.DeviceAuthResponse, retryTokenCiphertext []byte) UserDeviceAuthResponse { + return UserDeviceAuthResponse{ + UserCode: authResp.UserCode, + VerificationURI: authResp.VerificationURI, + VerificationURIComplete: authResp.VerificationURIComplete, + InitialRetryDelay: authResp.Interval, + RetryToken: retryTokenCiphertext, + } +} diff --git a/pkg/identity/oidc/errors.go b/pkg/identity/oidc/errors.go index be4ee9743..2b5f7d930 100644 --- a/pkg/identity/oidc/errors.go +++ b/pkg/identity/oidc/errors.go @@ -13,6 +13,10 @@ var ErrRevokeNotImplemented = errors.New("identity/oidc: revoke not implemented" // https://openid.net/specs/openid-connect-frontchannel-1_0.html#RPInitiated var ErrSignoutNotImplemented = errors.New("identity/oidc: end session not implemented") +// ErrDeviceAuthNotImplemented is returned when device auth is not implemented +// by an identity provider. +var ErrDeviceAuthNotImplemented = errors.New("identity/oidc: device auth not implemented") + // ErrMissingProviderURL is returned when an identity provider requires a provider url // does not receive one. var ErrMissingProviderURL = errors.New("identity/oidc: missing provider url") diff --git a/pkg/identity/oidc/oidc.go b/pkg/identity/oidc/oidc.go index 3960e7017..3f930283b 100644 --- a/pkg/identity/oidc/oidc.go +++ b/pkg/identity/oidc/oidc.go @@ -118,6 +118,62 @@ func (p *Provider) SignIn(w http.ResponseWriter, r *http.Request, state string) return nil } +func (p *Provider) DeviceAuth(w http.ResponseWriter, r *http.Request) (*oauth2.DeviceAuthResponse, error) { + oa, err := p.GetOauthConfig() + if err != nil { + return nil, err + } + + opts := defaultAuthCodeOptions + for k, v := range p.AuthCodeOptions { + opts = append(opts, oauth2.SetAuthURLParam(k, v)) + } + + resp, err := oa.DeviceAuth(r.Context(), opts...) + if err != nil { + return nil, err + } + + return resp, nil +} + +func (p *Provider) DeviceAccessToken(ctx context.Context, da *oauth2.DeviceAuthResponse, v identity.State) (*oauth2.Token, error) { + oa, err := p.GetOauthConfig() + if err != nil { + return nil, err + } + + oauth2Token, err := oa.DeviceAccessToken(ctx, da) + if err != nil { + return nil, err + } + + // + // TODO: the rest of this function is copied from Authenticate + // + + idToken, err := p.getIDToken(ctx, oauth2Token) + if err != nil { + return nil, fmt.Errorf("identity/oidc: failed getting id_token: %w", err) + } + + if rawIDToken, ok := oauth2Token.Extra("id_token").(string); ok { + v.SetRawIDToken(rawIDToken) + } + + // hydrate `v` using claims inside the returned `id_token` + // https://openid.net/specs/openid-connect-core-1_0.html#TokenEndpoint + if err := idToken.Claims(v); err != nil { + return nil, fmt.Errorf("identity/oidc: couldn't unmarshal extra claims %w", err) + } + + if err := p.UpdateUserInfo(ctx, oauth2Token, v); err != nil { + return nil, fmt.Errorf("identity/oidc: couldn't update user info %w", err) + } + + return oauth2Token, nil +} + // Authenticate converts an authorization code returned from the identity // provider into a token which is then converted into a user session. func (p *Provider) Authenticate(ctx context.Context, code string, v identity.State) (*oauth2.Token, error) { diff --git a/pkg/identity/providers.go b/pkg/identity/providers.go index 34b569cc1..75e481e63 100644 --- a/pkg/identity/providers.go +++ b/pkg/identity/providers.go @@ -37,6 +37,10 @@ type Authenticator interface { SignIn(w http.ResponseWriter, r *http.Request, state string) error SignOut(w http.ResponseWriter, r *http.Request, idTokenHint, authenticateSignedOutURL, redirectToURL string) error + + // alternatives for these methods? + DeviceAuth(w http.ResponseWriter, r *http.Request) (*oauth2.DeviceAuthResponse, error) + DeviceAccessToken(ctx context.Context, r *oauth2.DeviceAuthResponse, state State) (*oauth2.Token, error) } // AuthenticatorConstructor makes an Authenticator from the given options. diff --git a/proxy/handlers.go b/proxy/handlers.go index 682e0d100..37fe5d0c6 100644 --- a/proxy/handlers.go +++ b/proxy/handlers.go @@ -39,6 +39,10 @@ func (p *Proxy) registerDashboardHandlers(r *mux.Router) *mux.Router { Queries(urlutil.QueryRedirectURI, ""). Methods(http.MethodGet) + a.Path("/v1/device_auth").Handler(httputil.HandlerFunc(p.DeviceAuthLogin)). + Queries(urlutil.QueryDeviceAuthRouteURI, ""). + Methods(http.MethodGet, http.MethodPost) + return r } @@ -136,6 +140,30 @@ func (p *Proxy) ProgrammaticLogin(w http.ResponseWriter, r *http.Request) error return nil } +func (p *Proxy) DeviceAuthLogin(w http.ResponseWriter, r *http.Request) error { + state := p.state.Load() + options := p.currentOptions.Load() + + params := url.Values{} + routeUri, err := urlutil.ParseAndValidateURL(r.FormValue(urlutil.QueryDeviceAuthRouteURI)) + if err != nil { + return httputil.NewError(http.StatusBadRequest, err) + } + params.Set(urlutil.QueryDeviceAuthRouteURI, routeUri.String()) + + idp, err := options.GetIdentityProviderForRequestURL(routeUri.String()) + if err != nil { + return httputil.NewError(http.StatusInternalServerError, err) + } + params.Set(urlutil.QueryIdentityProviderID, idp.Id) + + if retryToken := r.FormValue(urlutil.QueryDeviceAuthRetryToken); retryToken != "" { + params.Set(urlutil.QueryDeviceAuthRetryToken, retryToken) + } + + return state.authenticateFlow.AuthenticateDeviceCode(w, r, params) +} + // jwtAssertion returns the current request's JWT assertion (rfc7519#section-10.3.1). func (p *Proxy) jwtAssertion(w http.ResponseWriter, r *http.Request) error { rawAssertionJWT := r.Header.Get(httputil.HeaderPomeriumJWTAssertion) diff --git a/proxy/state.go b/proxy/state.go index 55842d97b..72c1517e4 100644 --- a/proxy/state.go +++ b/proxy/state.go @@ -21,6 +21,7 @@ var outboundGRPCConnection = new(grpc.CachedOutboundGRPClientConn) type authenticateFlow interface { AuthenticateSignInURL(ctx context.Context, queryParams url.Values, redirectURL *url.URL, idpID string) (string, error) + AuthenticateDeviceCode(w http.ResponseWriter, r *http.Request, queryParams url.Values) error Callback(w http.ResponseWriter, r *http.Request) error }