diff --git a/authenticate/authenticate.go b/authenticate/authenticate.go index a2d09ee76..b1804229f 100644 --- a/authenticate/authenticate.go +++ b/authenticate/authenticate.go @@ -45,15 +45,16 @@ type Authenticate struct { // New validates and creates a new authenticate service from a set of Options. func New(cfg *config.Config, options ...Option) (*Authenticate, error) { + authenticateConfig := getAuthenticateConfig(options...) a := &Authenticate{ - cfg: getAuthenticateConfig(options...), + cfg: authenticateConfig, options: config.NewAtomicOptions(), state: atomicutil.NewValue(newAuthenticateState()), } a.options.Store(cfg.Options) - state, err := newAuthenticateStateFromConfig(cfg) + state, err := newAuthenticateStateFromConfig(cfg, authenticateConfig) if err != nil { return nil, err } @@ -69,7 +70,7 @@ func (a *Authenticate) OnConfigChange(ctx context.Context, cfg *config.Config) { } a.options.Store(cfg.Options) - if state, err := newAuthenticateStateFromConfig(cfg); err != nil { + if state, err := newAuthenticateStateFromConfig(cfg, a.cfg); err != nil { log.Error(ctx).Err(err).Msg("authenticate: failed to update state") } else { a.state.Store(state) diff --git a/authenticate/config.go b/authenticate/config.go index e3eedbeaf..a1f9a8ecd 100644 --- a/authenticate/config.go +++ b/authenticate/config.go @@ -2,6 +2,7 @@ package authenticate import ( "github.com/pomerium/pomerium/config" + "github.com/pomerium/pomerium/internal/authenticateflow" "github.com/pomerium/pomerium/internal/identity" identitypb "github.com/pomerium/pomerium/pkg/grpc/identity" ) @@ -9,7 +10,7 @@ import ( type authenticateConfig struct { getIdentityProvider func(options *config.Options, idpID string) (identity.Authenticator, error) profileTrimFn func(*identitypb.Profile) - authEventFn AuthEventFn + authEventFn authenticateflow.AuthEventFn } // An Option customizes the Authenticate config. @@ -39,7 +40,7 @@ func WithProfileTrimFn(profileTrimFn func(*identitypb.Profile)) Option { } // WithOnAuthenticationEventHook sets the authEventFn function in the config -func WithOnAuthenticationEventHook(fn AuthEventFn) Option { +func WithOnAuthenticationEventHook(fn authenticateflow.AuthEventFn) Option { return func(cfg *authenticateConfig) { cfg.authEventFn = fn } diff --git a/authenticate/handlers.go b/authenticate/handlers.go index ef4a1bd24..351ea5192 100644 --- a/authenticate/handlers.go +++ b/authenticate/handlers.go @@ -3,7 +3,6 @@ package authenticate import ( "context" "encoding/base64" - "encoding/json" "errors" "fmt" "net/http" @@ -14,9 +13,9 @@ import ( "github.com/google/uuid" "github.com/gorilla/mux" "github.com/rs/cors" - "golang.org/x/oauth2" "github.com/pomerium/csrf" + "github.com/pomerium/pomerium/internal/authenticateflow" "github.com/pomerium/pomerium/internal/handlers" "github.com/pomerium/pomerium/internal/httputil" "github.com/pomerium/pomerium/internal/identity" @@ -27,7 +26,6 @@ import ( "github.com/pomerium/pomerium/internal/telemetry/trace" "github.com/pomerium/pomerium/internal/urlutil" "github.com/pomerium/pomerium/pkg/cryptutil" - "github.com/pomerium/pomerium/pkg/hpke" ) // Handler returns the authenticate service's handler chain. @@ -76,7 +74,7 @@ func (a *Authenticate) mountDashboard(r *mux.Router) { c := cors.New(cors.Options{ AllowOriginRequestFunc: func(r *http.Request, _ string) bool { state := a.state.Load() - err := middleware.ValidateRequestURL(a.getExternalRequest(r), state.sharedKey) + err := state.flow.VerifyAuthenticateSignature(r) if err != nil { log.FromRequest(r).Info().Err(err).Msg("authenticate: origin blocked") } @@ -139,21 +137,11 @@ func (a *Authenticate) VerifySession(next http.Handler) http.Handler { return a.reauthenticateOrFail(w, r, err) } - profile, err := a.loadIdentityProfile(r, state.cookieCipher) - if err != nil { + if err := state.flow.VerifySession(ctx, r, sessionState); err != nil { log.FromRequest(r).Info(). Err(err). Str("idp_id", idpID). - Msg("authenticate: identity profile load error") - return a.reauthenticateOrFail(w, r, err) - } - - err = a.validateIdentityProfile(ctx, profile) - if err != nil { - log.FromRequest(r).Info(). - Err(err). - Str("idp_id", idpID). - Msg("authenticate: invalid identity profile") + Msg("authenticate: couldn't verify session") return a.reauthenticateOrFail(w, r, err) } @@ -176,62 +164,20 @@ func (a *Authenticate) SignIn(w http.ResponseWriter, r *http.Request) error { state := a.state.Load() - if err := r.ParseForm(); err != nil { - return httputil.NewError(http.StatusBadRequest, err) - } - proxyPublicKey, requestParams, err := hpke.DecryptURLValues(state.hpkePrivateKey, r.Form) - if err != nil { - return err - } - - idpID := requestParams.Get(urlutil.QueryIdentityProviderID) - s, err := a.getSessionFromCtx(ctx) if err != nil { state.sessionStore.ClearSession(w, r) return err } - // start over if this is a different identity provider - if s == nil || s.IdentityProviderID != idpID { - s = sessions.NewState(idpID) - } - - // re-persist the session, useful when session was evicted from session - if err := state.sessionStore.SaveSession(w, r, s); err != nil { - return httputil.NewError(http.StatusBadRequest, err) - } - - profile, err := a.loadIdentityProfile(r, state.cookieCipher) - if err != nil { - return httputil.NewError(http.StatusBadRequest, err) - } - - if a.cfg.profileTrimFn != nil { - a.cfg.profileTrimFn(profile) - } - - a.logAuthenticateEvent(r, profile) - - encryptURLValues := hpke.EncryptURLValuesV1 - if hpke.IsEncryptedURLV2(r.Form) { - encryptURLValues = hpke.EncryptURLValuesV2 - } - - redirectTo, err := urlutil.CallbackURL(state.hpkePrivateKey, proxyPublicKey, requestParams, profile, encryptURLValues) - if err != nil { - return httputil.NewError(http.StatusInternalServerError, err) - } - - httputil.Redirect(w, r, redirectTo, http.StatusFound) - return nil + return state.flow.SignIn(w, r, s) } // SignOut signs the user out and attempts to revoke the user's identity session // Handles both GET and POST. func (a *Authenticate) SignOut(w http.ResponseWriter, r *http.Request) error { // check for an HMAC'd URL. If none is found, show a confirmation page. - err := middleware.ValidateRequestURL(a.getExternalRequest(r), a.state.Load().sharedKey) + err := a.state.Load().flow.VerifyAuthenticateSignature(r) if err != nil { authenticateURL, err := a.options.Load().GetAuthenticateURL() if err != nil { @@ -319,7 +265,7 @@ func (a *Authenticate) reauthenticateOrFail(w http.ResponseWriter, r *http.Reque return err } - a.logAuthenticateEvent(r, nil) + state.flow.LogAuthenticateEvent(r) state.sessionStore.ClearSession(w, r) redirectURL := state.redirectURL.ResolveReference(r.URL) @@ -418,7 +364,7 @@ Or contact your administrator. `, redirectURL.String(), redirectURL.String())) } - idpID := a.getIdentityProviderIDForURLValues(redirectURL.Query()) + idpID := state.flow.GetIdentityProviderIDForURLValues(redirectURL.Query()) authenticator, err := a.cfg.getIdentityProvider(options, idpID) if err != nil { @@ -445,13 +391,9 @@ Or contact your administrator. newState.Audience = append(newState.Audience, nextRedirectURL.Hostname()) } - // save the session and access token to the databroker - profile, err := a.buildIdentityProfile(r, claims, accessToken) - if err != nil { - return nil, httputil.NewError(http.StatusInternalServerError, err) - } - if err := a.storeIdentityProfile(w, state.cookieCipher, profile); err != nil { - log.Error(r.Context()).Err(err).Msg("failed to store identity profile") + // save the session and access token to the databroker/cookie store + if err := state.flow.PersistSession(ctx, w, &newState, claims, accessToken); err != nil { + return nil, fmt.Errorf("failed saving new session: %w", err) } // ... and the user state to local storage. @@ -478,11 +420,12 @@ func (a *Authenticate) getSessionFromCtx(ctx context.Context) (*sessions.State, func (a *Authenticate) userInfo(w http.ResponseWriter, r *http.Request) error { ctx, span := trace.StartSpan(r.Context(), "authenticate.userInfo") defer span.End() - r = r.WithContext(ctx) - r = a.getExternalRequest(r) options := a.options.Load() + r = r.WithContext(ctx) + r = authenticateflow.GetExternalAuthenticateRequest(r, options) + // if we came in with a redirect URI, save it to a cookie so it doesn't expire with the HMAC if redirectURI := r.FormValue(urlutil.QueryRedirectURI); redirectURI != "" { u := urlutil.GetAbsoluteURL(r) @@ -509,14 +452,9 @@ func (a *Authenticate) getUserInfoData(r *http.Request) handlers.UserInfoData { s.ID = uuid.New().String() } - profile, _ := a.loadIdentityProfile(r, state.cookieCipher) - - data := handlers.UserInfoData{ - CSRFToken: csrf.Token(r), - Profile: profile, - - BrandingOptions: a.options.Load().BrandingOptions, - } + data := state.flow.GetUserInfoData(r, s) + data.CSRFToken = csrf.Token(r) + data.BrandingOptions = a.options.Load().BrandingOptions return data } @@ -537,18 +475,7 @@ func (a *Authenticate) revokeSession(ctx context.Context, w http.ResponseWriter, return "" } - profile, err := a.loadIdentityProfile(r, a.state.Load().cookieCipher) - if err != nil { - return "" - } - - oauthToken := new(oauth2.Token) - _ = json.Unmarshal(profile.GetOauthToken(), oauthToken) - if err := authenticator.Revoke(ctx, oauthToken); err != nil { - log.Ctx(ctx).Warn().Err(err).Msg("authenticate: failed to revoke access token") - } - - return string(profile.GetIdToken()) + return state.flow.RevokeSession(ctx, r, authenticator, nil) } // Callback handles the result of a successful call to the authenticate service @@ -603,19 +530,5 @@ func (a *Authenticate) getIdentityProviderIDForRequest(r *http.Request) string { if err := r.ParseForm(); err != nil { return "" } - return a.getIdentityProviderIDForURLValues(r.Form) -} - -func (a *Authenticate) getIdentityProviderIDForURLValues(vs url.Values) string { - state := a.state.Load() - idpID := "" - if _, requestParams, err := hpke.DecryptURLValues(state.hpkePrivateKey, vs); err == nil { - if idpID == "" { - idpID = requestParams.Get(urlutil.QueryIdentityProviderID) - } - } - if idpID == "" { - idpID = vs.Get(urlutil.QueryIdentityProviderID) - } - return idpID + return a.state.Load().flow.GetIdentityProviderIDForURLValues(r.Form) } diff --git a/authenticate/handlers_test.go b/authenticate/handlers_test.go index 943977314..d4cf4c5f2 100644 --- a/authenticate/handlers_test.go +++ b/authenticate/handlers_test.go @@ -8,7 +8,6 @@ import ( "net/http" "net/http/httptest" "net/url" - "strings" "testing" "time" @@ -24,6 +23,7 @@ import ( "github.com/pomerium/pomerium/internal/atomicutil" "github.com/pomerium/pomerium/internal/encoding/jws" "github.com/pomerium/pomerium/internal/encoding/mock" + "github.com/pomerium/pomerium/internal/handlers" "github.com/pomerium/pomerium/internal/httputil" "github.com/pomerium/pomerium/internal/identity" "github.com/pomerium/pomerium/internal/identity/oidc" @@ -40,6 +40,7 @@ func testAuthenticate() *Authenticate { auth.state = atomicutil.NewValue(&authenticateState{ redirectURL: redirectURL, cookieSecret: cryptutil.NewKey(), + flow: new(stubFlow), }) auth.options = config.NewAtomicOptions() auth.options.Store(&config.Options{ @@ -205,6 +206,7 @@ func TestAuthenticate_SignOut(t *testing.T) { state: atomicutil.NewValue(&authenticateState{ sessionStore: tt.sessionStore, sharedEncoder: mock.Encoder{}, + flow: new(stubFlow), }), options: config.NewAtomicOptions(), } @@ -301,6 +303,7 @@ func TestAuthenticate_OAuthCallback(t *testing.T) { redirectURL: authURL, sessionStore: tt.session, cookieCipher: aead, + flow: new(stubFlow), }), options: config.NewAtomicOptions(), } @@ -414,6 +417,7 @@ func TestAuthenticate_SessionValidatorMiddleware(t *testing.T) { sessionStore: tt.session, cookieCipher: aead, sharedEncoder: signer, + flow: new(stubFlow), }), options: config.NewAtomicOptions(), } @@ -452,6 +456,7 @@ func TestAuthenticate_userInfo(t *testing.T) { var a Authenticate a.state = atomicutil.NewValue(&authenticateState{ cookieSecret: cryptutil.NewKey(), + flow: new(stubFlow), }) a.options = config.NewAtomicOptions() a.options.Store(&config.Options{ @@ -467,36 +472,32 @@ func TestAuthenticate_userInfo(t *testing.T) { now := time.Now() tests := []struct { - name string - url *url.URL - method string - sessionStore sessions.SessionStore - wantCode int - wantBody string + name string + url string + validSignature bool + sessionStore sessions.SessionStore + wantCode int }{ { - "good", - mustParseURL("/"), - http.MethodGet, + "not a redirect", + "/", + true, &mstore.Store{Encrypted: true, Session: &sessions.State{ID: "SESSION_ID", IssuedAt: jwt.NewNumericDate(now)}}, http.StatusOK, - "", }, { - "missing signature", - mustParseURL("/?pomerium_redirect_uri=http://example.com"), - http.MethodGet, + "signed redirect", + "/?pomerium_redirect_uri=http://example.com", + true, &mstore.Store{Encrypted: true, Session: &sessions.State{ID: "SESSION_ID", IssuedAt: jwt.NewNumericDate(now)}}, - http.StatusBadRequest, - "", + http.StatusFound, }, { - "bad signature", - urlutil.NewSignedURL([]byte("BAD KEY"), mustParseURL("/?pomerium_redirect_uri=http://example.com")).Sign(), - http.MethodGet, + "invalid redirect", + "/?pomerium_redirect_uri=http://example.com", + false, &mstore.Store{Encrypted: true, Session: &sessions.State{ID: "SESSION_ID", IssuedAt: jwt.NewNumericDate(now)}}, http.StatusBadRequest, - "", }, } for _, tt := range tests { @@ -513,14 +514,19 @@ func TestAuthenticate_userInfo(t *testing.T) { AuthenticateURLString: "https://authenticate.localhost.pomerium.io", SharedKey: "SHARED KEY", }) + f := new(stubFlow) + if !tt.validSignature { + f.verifySignatureErr = errors.New("bad signature") + } a := &Authenticate{ options: o, state: atomicutil.NewValue(&authenticateState{ sessionStore: tt.sessionStore, sharedEncoder: signer, + flow: f, }), } - r := httptest.NewRequest(tt.method, tt.url.String(), nil) + r := httptest.NewRequest(http.MethodGet, tt.url, nil) state, err := tt.sessionStore.LoadSession(r) if err != nil { t.Fatal(err) @@ -535,10 +541,6 @@ func TestAuthenticate_userInfo(t *testing.T) { if status := w.Code; status != tt.wantCode { t.Errorf("handler returned wrong status code: got %v want %v", status, tt.wantCode) } - body := w.Body.String() - if !strings.Contains(body, tt.wantBody) { - t.Errorf("Unexpected body, contains: %s, got: %s", tt.wantBody, body) - } }) } } @@ -565,3 +567,42 @@ func mustParseURL(rawurl string) *url.URL { } return u } + +// stubFlow is a stub implementation of the flow interface. +type stubFlow struct { + verifySignatureErr error +} + +func (f *stubFlow) VerifyAuthenticateSignature(*http.Request) error { + return f.verifySignatureErr +} + +func (*stubFlow) SignIn(http.ResponseWriter, *http.Request, *sessions.State) error { + return nil +} + +func (*stubFlow) PersistSession( + context.Context, http.ResponseWriter, *sessions.State, identity.SessionClaims, *oauth2.Token, +) error { + return nil +} + +func (*stubFlow) VerifySession(context.Context, *http.Request, *sessions.State) error { + return nil +} + +func (*stubFlow) RevokeSession( + context.Context, *http.Request, identity.Authenticator, *sessions.State, +) string { + return "" +} + +func (*stubFlow) GetUserInfoData(*http.Request, *sessions.State) handlers.UserInfoData { + return handlers.UserInfoData{} +} + +func (*stubFlow) LogAuthenticateEvent(*http.Request) {} + +func (*stubFlow) GetIdentityProviderIDForURLValues(url.Values) string { + return "" +} diff --git a/authenticate/middleware.go b/authenticate/middleware.go index 86cc3cbdb..abe949212 100644 --- a/authenticate/middleware.go +++ b/authenticate/middleware.go @@ -4,7 +4,6 @@ import ( "net/http" "github.com/pomerium/pomerium/internal/httputil" - "github.com/pomerium/pomerium/internal/middleware" "github.com/pomerium/pomerium/internal/urlutil" ) @@ -13,7 +12,7 @@ import ( func (a *Authenticate) requireValidSignatureOnRedirect(next httputil.HandlerFunc) http.Handler { return httputil.HandlerFunc(func(w http.ResponseWriter, r *http.Request) error { if r.FormValue(urlutil.QueryRedirectURI) != "" || r.FormValue(urlutil.QueryHmacSignature) != "" { - err := middleware.ValidateRequestURL(a.getExternalRequest(r), a.state.Load().sharedKey) + err := a.state.Load().flow.VerifyAuthenticateSignature(r) if err != nil { return httputil.NewError(http.StatusBadRequest, err) } @@ -25,26 +24,10 @@ func (a *Authenticate) requireValidSignatureOnRedirect(next httputil.HandlerFunc // requireValidSignature validates the pomerium_signature. func (a *Authenticate) requireValidSignature(next httputil.HandlerFunc) http.Handler { return httputil.HandlerFunc(func(w http.ResponseWriter, r *http.Request) error { - err := middleware.ValidateRequestURL(a.getExternalRequest(r), a.state.Load().sharedKey) + err := a.state.Load().flow.VerifyAuthenticateSignature(r) if err != nil { return err } return next(w, r) }) } - -func (a *Authenticate) getExternalRequest(r *http.Request) *http.Request { - options := a.options.Load() - - externalURL, err := options.GetAuthenticateURL() - if err != nil { - return r - } - - internalURL, err := options.GetInternalAuthenticateURL() - if err != nil { - return r - } - - return urlutil.GetExternalRequest(internalURL, externalURL, r) -} diff --git a/authenticate/state.go b/authenticate/state.go index dbce739b4..213060e11 100644 --- a/authenticate/state.go +++ b/authenticate/state.go @@ -1,23 +1,41 @@ package authenticate import ( + "context" "crypto/cipher" "fmt" + "net/http" "net/url" "github.com/go-jose/go-jose/v3" + "golang.org/x/oauth2" "github.com/pomerium/pomerium/config" + "github.com/pomerium/pomerium/internal/authenticateflow" "github.com/pomerium/pomerium/internal/encoding" "github.com/pomerium/pomerium/internal/encoding/jws" + "github.com/pomerium/pomerium/internal/handlers" + "github.com/pomerium/pomerium/internal/identity" "github.com/pomerium/pomerium/internal/sessions" "github.com/pomerium/pomerium/internal/sessions/cookie" "github.com/pomerium/pomerium/internal/urlutil" "github.com/pomerium/pomerium/pkg/cryptutil" - "github.com/pomerium/pomerium/pkg/hpke" ) +type flow interface { + VerifyAuthenticateSignature(r *http.Request) error + SignIn(w http.ResponseWriter, r *http.Request, sessionState *sessions.State) error + PersistSession(ctx context.Context, w http.ResponseWriter, sessionState *sessions.State, claims identity.SessionClaims, accessToken *oauth2.Token) error + VerifySession(ctx context.Context, r *http.Request, sessionState *sessions.State) error + RevokeSession(ctx context.Context, r *http.Request, authenticator identity.Authenticator, sessionState *sessions.State) string + GetUserInfoData(r *http.Request, sessionState *sessions.State) handlers.UserInfoData + LogAuthenticateEvent(r *http.Request) + GetIdentityProviderIDForURLValues(url.Values) string +} + type authenticateState struct { + flow flow + redirectURL *url.URL // sharedEncoder is the encoder to use to serialize data to be consumed // by other services @@ -34,8 +52,7 @@ type authenticateState struct { sessionStore sessions.SessionStore // sessionLoaders are a collection of session loaders to attempt to pull // a user's session state from - sessionLoader sessions.SessionLoader - hpkePrivateKey *hpke.PrivateKey + sessionLoader sessions.SessionLoader jwk *jose.JSONWebKeySet } @@ -46,7 +63,9 @@ func newAuthenticateState() *authenticateState { } } -func newAuthenticateStateFromConfig(cfg *config.Config) (*authenticateState, error) { +func newAuthenticateStateFromConfig( + cfg *config.Config, authenticateConfig *authenticateConfig, +) (*authenticateState, error) { err := ValidateOptions(cfg.Options) if err != nil { return nil, err @@ -125,12 +144,16 @@ func newAuthenticateStateFromConfig(cfg *config.Config) (*authenticateState, err } } - sharedKey, err := cfg.Options.GetSharedKey() + state.flow, err = authenticateflow.NewStateless( + cfg, + cookieStore, + authenticateConfig.getIdentityProvider, + authenticateConfig.profileTrimFn, + authenticateConfig.authEventFn, + ) if err != nil { return nil, err } - state.hpkePrivateKey = hpke.DerivePrivateKey(sharedKey) - return state, nil } diff --git a/authorize/authorize_test.go b/authorize/authorize_test.go index bc2a9b6c8..e545a8acf 100644 --- a/authorize/authorize_test.go +++ b/authorize/authorize_test.go @@ -26,6 +26,7 @@ func TestNew(t *testing.T) { config.Options{ AuthenticateURLString: "https://authN.example.com", DataBrokerURLString: "https://databroker.example.com", + CookieSecret: "15WXae6fvK9Hal0RGZ600JlCaflYHtNy9bAyOLTlvmc=", SharedKey: "2p/Wi2Q6bYDfzmoSEbKqYKtg+DUoLWTEHHs7vOhvL7w=", Policies: policies, }, @@ -36,6 +37,7 @@ func TestNew(t *testing.T) { config.Options{ AuthenticateURLString: "https://authN.example.com", DataBrokerURLString: "https://databroker.example.com", + CookieSecret: "15WXae6fvK9Hal0RGZ600JlCaflYHtNy9bAyOLTlvmc=", SharedKey: "AZA85podM73CjLCjViDNz1EUvvejKpWp7Hysr0knXA==", Policies: policies, }, @@ -46,6 +48,7 @@ func TestNew(t *testing.T) { config.Options{ AuthenticateURLString: "https://authN.example.com", DataBrokerURLString: "https://databroker.example.com", + CookieSecret: "15WXae6fvK9Hal0RGZ600JlCaflYHtNy9bAyOLTlvmc=", SharedKey: "sup", Policies: policies, }, @@ -56,6 +59,7 @@ func TestNew(t *testing.T) { config.Options{ AuthenticateURLString: "https://authN.example.com", DataBrokerURLString: "https://databroker.example.com", + CookieSecret: "15WXae6fvK9Hal0RGZ600JlCaflYHtNy9bAyOLTlvmc=", SharedKey: "AZA85podM73CjLCjViDNz1EUvvejKpWp7Hysr0knXA==", Policies: policies, }, @@ -67,6 +71,7 @@ func TestNew(t *testing.T) { config.Options{ AuthenticateURLString: "https://authN.example.com", DataBrokerURLString: "BAD", + CookieSecret: "15WXae6fvK9Hal0RGZ600JlCaflYHtNy9bAyOLTlvmc=", SharedKey: "AZA85podM73CjLCjViDNz1EUvvejKpWp7Hysr0knXA==", Policies: policies, }, @@ -105,6 +110,7 @@ func TestAuthorize_OnConfigChange(t *testing.T) { o := &config.Options{ AuthenticateURLString: "https://authN.example.com", DataBrokerURLString: "https://databroker.example.com", + CookieSecret: "15WXae6fvK9Hal0RGZ600JlCaflYHtNy9bAyOLTlvmc=", SharedKey: tc.SharedKey, Policies: tc.Policies, } diff --git a/authorize/check_response.go b/authorize/check_response.go index ecad76de5..2e931a04b 100644 --- a/authorize/check_response.go +++ b/authorize/check_response.go @@ -192,32 +192,17 @@ func (a *Authorize) requireLoginResponse( return a.deniedResponse(ctx, in, http.StatusUnauthorized, http.StatusText(http.StatusUnauthorized), nil) } - authenticateURL, err := options.GetAuthenticateURL() - if err != nil { - return nil, err - } - idp, err := options.GetIdentityProviderForPolicy(request.Policy) if err != nil { return nil, err } - authenticateHPKEPublicKey, err := state.authenticateKeyFetcher.FetchPublicKey(ctx) - if err != nil { - return nil, err - } - // always assume https scheme checkRequestURL := getCheckRequestURL(in) checkRequestURL.Scheme = "https" - redirectTo, err := urlutil.SignInURL( - state.hpkePrivateKey, - authenticateHPKEPublicKey, - authenticateURL, - &checkRequestURL, - idp.GetId(), - ) + redirectTo, err := state.authenticateFlow.AuthenticateSignInURL( + ctx, nil, &checkRequestURL, idp.GetId()) if err != nil { return nil, err } diff --git a/authorize/state.go b/authorize/state.go index 1f421ad15..bb2d9e277 100644 --- a/authorize/state.go +++ b/authorize/state.go @@ -3,20 +3,25 @@ package authorize import ( "context" "fmt" + "net/url" googlegrpc "google.golang.org/grpc" "github.com/pomerium/pomerium/authorize/evaluator" "github.com/pomerium/pomerium/authorize/internal/store" "github.com/pomerium/pomerium/config" + "github.com/pomerium/pomerium/internal/authenticateflow" "github.com/pomerium/pomerium/pkg/grpc" "github.com/pomerium/pomerium/pkg/grpc/databroker" - "github.com/pomerium/pomerium/pkg/hpke" "github.com/pomerium/pomerium/pkg/protoutil" ) var outboundGRPCConnection = new(grpc.CachedOutboundGRPClientConn) +type authenticateFlow interface { + AuthenticateSignInURL(ctx context.Context, queryParams url.Values, redirectURL *url.URL, idpID string) (string, error) +} + type authorizeState struct { sharedKey []byte evaluator *evaluator.Evaluator @@ -24,8 +29,7 @@ type authorizeState struct { dataBrokerClient databroker.DataBrokerServiceClient auditEncryptor *protoutil.Encryptor sessionStore *config.SessionStore - hpkePrivateKey *hpke.PrivateKey - authenticateKeyFetcher hpke.KeyFetcher + authenticateFlow authenticateFlow } func newAuthorizeStateFromConfig( @@ -79,10 +83,9 @@ func newAuthorizeStateFromConfig( return nil, fmt.Errorf("authorize: invalid session store: %w", err) } - state.hpkePrivateKey = hpke.DerivePrivateKey(sharedKey) - state.authenticateKeyFetcher, err = cfg.GetAuthenticateKeyFetcher() + state.authenticateFlow, err = authenticateflow.NewStateless(cfg, nil, nil, nil, nil) if err != nil { - return nil, fmt.Errorf("authorize: get authenticate JWKS key fetcher: %w", err) + return nil, err } return state, nil diff --git a/internal/authenticateflow/authenticateflow.go b/internal/authenticateflow/authenticateflow.go new file mode 100644 index 000000000..67b78e972 --- /dev/null +++ b/internal/authenticateflow/authenticateflow.go @@ -0,0 +1,31 @@ +// Package authenticateflow implements the core authentication flow. This +// includes creating and parsing sign-in redirect URLs, storing and retrieving +// session data, and handling authentication callback URLs. +package authenticateflow + +import ( + "fmt" + + "google.golang.org/protobuf/types/known/structpb" + + "github.com/pomerium/pomerium/internal/identity" + "github.com/pomerium/pomerium/pkg/grpc" + "github.com/pomerium/pomerium/pkg/grpc/user" +) + +var outboundGRPCConnection = new(grpc.CachedOutboundGRPClientConn) + +func populateUserFromClaims(u *user.User, claims map[string]interface{}) { + if v, ok := claims["name"]; ok { + u.Name = fmt.Sprint(v) + } + if v, ok := claims["email"]; ok { + u.Email = fmt.Sprint(v) + } + if u.Claims == nil { + u.Claims = make(map[string]*structpb.ListValue) + } + for k, vs := range identity.Claims(claims).Flatten().ToPB() { + u.Claims[k] = vs + } +} diff --git a/authenticate/events.go b/internal/authenticateflow/events.go similarity index 77% rename from authenticate/events.go rename to internal/authenticateflow/events.go index cdb251ed1..9a9f1cd4f 100644 --- a/authenticate/events.go +++ b/internal/authenticateflow/events.go @@ -1,4 +1,4 @@ -package authenticate +package authenticateflow import ( "context" @@ -8,7 +8,7 @@ import ( "github.com/pomerium/pomerium/internal/httputil" "github.com/pomerium/pomerium/internal/log" "github.com/pomerium/pomerium/internal/urlutil" - "github.com/pomerium/pomerium/pkg/grpc/identity" + identitypb "github.com/pomerium/pomerium/pkg/grpc/identity" "github.com/pomerium/pomerium/pkg/hpke" ) @@ -45,14 +45,15 @@ type AuthEvent struct { // AuthEventFn is a function that handles an authentication event type AuthEventFn func(context.Context, AuthEvent) -func (a *Authenticate) logAuthenticateEvent(r *http.Request, profile *identity.Profile) { - if a.cfg.authEventFn == nil { +// TODO: move into stateless.go; this is here for now just so that Git will +// track the file history as a rename from authenticate/events.go. +func (s *Stateless) logAuthenticateEvent(r *http.Request, profile *identitypb.Profile) { + if s.authEventFn == nil { return } - state := a.state.Load() ctx := r.Context() - pub, params, err := hpke.DecryptURLValues(state.hpkePrivateKey, r.Form) + pub, params, err := hpke.DecryptURLValues(s.hpkePrivateKey, r.Form) if err != nil { log.Warn(ctx).Err(err).Msg("log authenticate event: failed to decrypt request params") } @@ -82,20 +83,5 @@ func (a *Authenticate) logAuthenticateEvent(r *http.Request, profile *identity.P evt.Domain = &domain } - a.cfg.authEventFn(ctx, evt) -} - -func getUserClaim(profile *identity.Profile, field string) *string { - if profile == nil { - return nil - } - if profile.Claims == nil { - return nil - } - val, ok := profile.Claims.Fields[field] - if !ok || val == nil { - return nil - } - txt := val.GetStringValue() - return &txt + s.authEventFn(ctx, evt) } diff --git a/authenticate/identity_profile.go b/internal/authenticateflow/identityprofile.go similarity index 50% rename from authenticate/identity_profile.go rename to internal/authenticateflow/identityprofile.go index 29fed27b3..d22c8c60e 100644 --- a/authenticate/identity_profile.go +++ b/internal/authenticateflow/identityprofile.go @@ -1,4 +1,4 @@ -package authenticate +package authenticateflow import ( "context" @@ -7,27 +7,36 @@ import ( "encoding/json" "fmt" "net/http" + "time" "golang.org/x/oauth2" "google.golang.org/protobuf/encoding/protojson" "google.golang.org/protobuf/types/known/structpb" + "google.golang.org/protobuf/types/known/timestamppb" "github.com/pomerium/pomerium/internal/httputil" "github.com/pomerium/pomerium/internal/identity" + "github.com/pomerium/pomerium/internal/identity/manager" + "github.com/pomerium/pomerium/internal/sessions" "github.com/pomerium/pomerium/internal/urlutil" "github.com/pomerium/pomerium/pkg/cryptutil" identitypb "github.com/pomerium/pomerium/pkg/grpc/identity" + "github.com/pomerium/pomerium/pkg/grpc/session" ) +// An "identity profile" is an alternative to a session, used in the stateless +// authenticate flow. An identity profile contains an IdP ID (to distinguish +// between different IdP's or between different clients of the same IdP), a +// user ID token, and an OAuth2 token. + var cookieChunker = httputil.NewCookieChunker() -func (a *Authenticate) buildIdentityProfile( - r *http.Request, +// buildIdentityProfile populates an identity profile. +func buildIdentityProfile( + idpID string, claims identity.SessionClaims, oauthToken *oauth2.Token, ) (*identitypb.Profile, error) { - idpID := r.FormValue(urlutil.QueryIdentityProviderID) - rawIDToken := []byte(claims.RawIDToken) rawOAuthToken, err := json.Marshal(oauthToken) if err != nil { @@ -46,7 +55,8 @@ func (a *Authenticate) buildIdentityProfile( }, nil } -func (a *Authenticate) loadIdentityProfile(r *http.Request, aead cipher.AEAD) (*identitypb.Profile, error) { +// loadIdentityProfile loads an identity profile from a chunked set of cookies. +func loadIdentityProfile(r *http.Request, aead cipher.AEAD) (*identitypb.Profile, error) { cookie, err := cookieChunker.LoadCookie(r, urlutil.QueryIdentityProfile) if err != nil { return nil, fmt.Errorf("authenticate: error loading identity profile cookie: %w", err) @@ -70,30 +80,35 @@ func (a *Authenticate) loadIdentityProfile(r *http.Request, aead cipher.AEAD) (* return &profile, nil } -func (a *Authenticate) storeIdentityProfile(w http.ResponseWriter, aead cipher.AEAD, profile *identitypb.Profile) error { - options := a.options.Load() - +// storeIdentityProfile writes the identity profile to a chunked set of cookies. +func storeIdentityProfile( + w http.ResponseWriter, + cookie *http.Cookie, + aead cipher.AEAD, + profile *identitypb.Profile, +) error { decrypted, err := protojson.Marshal(profile) if err != nil { // this shouldn't happen panic(fmt.Errorf("error marshaling message: %w", err)) } encrypted := cryptutil.Encrypt(aead, decrypted, nil) - cookie := options.NewCookie() cookie.Name = urlutil.QueryIdentityProfile cookie.Value = base64.RawURLEncoding.EncodeToString(encrypted) cookie.Path = "/" return cookieChunker.SetCookie(w, cookie) } -func (a *Authenticate) validateIdentityProfile(ctx context.Context, profile *identitypb.Profile) error { - authenticator, err := a.cfg.getIdentityProvider(a.options.Load(), profile.GetProviderId()) - if err != nil { - return err - } - +// validateIdentityProfile checks expirations timestamps for the ID token and +// OAuth2 token, and makes a user info request to the IdP in order to determine +// whether the OAuth2 token is still valid. +func validateIdentityProfile( + ctx context.Context, + authenticator identity.Authenticator, + profile *identitypb.Profile, +) error { oauthToken := new(oauth2.Token) - err = json.Unmarshal(profile.GetOauthToken(), oauthToken) + err := json.Unmarshal(profile.GetOauthToken(), oauthToken) if err != nil { return fmt.Errorf("invalid oauth token in profile: %w", err) } @@ -110,3 +125,48 @@ func (a *Authenticate) validateIdentityProfile(ctx context.Context, profile *ide return nil } + +func newSessionStateFromProfile(p *identitypb.Profile) *sessions.State { + claims := p.GetClaims().AsMap() + + ss := sessions.NewState(p.GetProviderId()) + + // set the subject + if v, ok := claims["sub"]; ok { + ss.Subject = fmt.Sprint(v) + } else if v, ok := claims["user"]; ok { + ss.Subject = fmt.Sprint(v) + } + + // set the oid + if v, ok := claims["oid"]; ok { + ss.OID = fmt.Sprint(v) + } + + return ss +} + +func populateSessionFromProfile(s *session.Session, p *identitypb.Profile, ss *sessions.State, cookieExpire time.Duration) { + claims := p.GetClaims().AsMap() + oauthToken := new(oauth2.Token) + _ = json.Unmarshal(p.GetOauthToken(), oauthToken) + + s.UserId = ss.UserID() + s.IssuedAt = timestamppb.Now() + s.AccessedAt = timestamppb.Now() + s.ExpiresAt = timestamppb.New(time.Now().Add(cookieExpire)) + s.IdToken = &session.IDToken{ + Issuer: ss.Issuer, + Subject: ss.Subject, + ExpiresAt: timestamppb.New(time.Now().Add(cookieExpire)), + IssuedAt: timestamppb.Now(), + Raw: string(p.GetIdToken()), + } + s.OauthToken = manager.ToOAuthToken(oauthToken) + if s.Claims == nil { + s.Claims = make(map[string]*structpb.ListValue) + } + for k, vs := range identity.Claims(claims).Flatten().ToPB() { + s.Claims[k] = vs + } +} diff --git a/internal/authenticateflow/request.go b/internal/authenticateflow/request.go new file mode 100644 index 000000000..09eb8f154 --- /dev/null +++ b/internal/authenticateflow/request.go @@ -0,0 +1,36 @@ +package authenticateflow + +import ( + "net/http" + + "github.com/pomerium/pomerium/config" + "github.com/pomerium/pomerium/internal/middleware" + "github.com/pomerium/pomerium/internal/urlutil" +) + +type signatureVerifier struct { + options *config.Options + sharedKey []byte +} + +// VerifyAuthenticateSignature checks that the provided request has a valid +// signature (for the authenticate service). +func (v signatureVerifier) VerifyAuthenticateSignature(r *http.Request) error { + return middleware.ValidateRequestURL(GetExternalAuthenticateRequest(r, v.options), v.sharedKey) +} + +// GetExternalAuthenticateRequest canonicalizes an authenticate request URL +// based on the provided configuration options. +func GetExternalAuthenticateRequest(r *http.Request, options *config.Options) *http.Request { + externalURL, err := options.GetAuthenticateURL() + if err != nil { + return r + } + + internalURL, err := options.GetInternalAuthenticateURL() + if err != nil { + return r + } + + return urlutil.GetExternalRequest(internalURL, externalURL, r) +} diff --git a/internal/authenticateflow/request_test.go b/internal/authenticateflow/request_test.go new file mode 100644 index 000000000..3f689bd03 --- /dev/null +++ b/internal/authenticateflow/request_test.go @@ -0,0 +1,58 @@ +package authenticateflow + +import ( + "net/http" + "net/url" + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/pomerium/pomerium/config" + "github.com/pomerium/pomerium/internal/urlutil" +) + +func TestVerifyAuthenticateSignature(t *testing.T) { + options := &config.Options{ + AuthenticateURLString: "https://authenticate.example.com", + AuthenticateInternalURLString: "https://authenticate.internal", + } + key := []byte("SHARED KEY--(must be 32 bytes)--") + v := signatureVerifier{options, key} + + t.Run("Valid", func(t *testing.T) { + u := mustParseURL("https://example.com/") + r := &http.Request{Host: "example.com", URL: urlutil.NewSignedURL(key, u).Sign()} + err := v.VerifyAuthenticateSignature(r) + assert.NoError(t, err) + }) + t.Run("NoSignature", func(t *testing.T) { + r := &http.Request{Host: "example.com", URL: mustParseURL("https://example.com/")} + err := v.VerifyAuthenticateSignature(r) + assert.Error(t, err) + }) + t.Run("DifferentKey", func(t *testing.T) { + zeros := make([]byte, 32) + u := mustParseURL("https://example.com/") + r := &http.Request{Host: "example.com", URL: urlutil.NewSignedURL(zeros, u).Sign()} + err := v.VerifyAuthenticateSignature(r) + assert.Error(t, err) + }) + t.Run("InternalDomain", func(t *testing.T) { + // A request with the internal authenticate service URL should first be + // canonicalized to use the external authenticate service URL before + // validating the request signature. + u := urlutil.NewSignedURL(key, mustParseURL("https://authenticate.example.com/")).Sign() + u.Host = "authenticate.internal" + r := &http.Request{Host: "authenticate.internal", URL: u} + err := v.VerifyAuthenticateSignature(r) + assert.NoError(t, err) + }) +} + +func mustParseURL(rawurl string) *url.URL { + u, err := url.Parse(rawurl) + if err != nil { + panic(err) + } + return u +} diff --git a/internal/authenticateflow/stateless.go b/internal/authenticateflow/stateless.go new file mode 100644 index 000000000..48e567eb5 --- /dev/null +++ b/internal/authenticateflow/stateless.go @@ -0,0 +1,449 @@ +package authenticateflow + +import ( + "context" + "crypto/cipher" + "encoding/json" + "fmt" + "net/http" + "net/url" + + "github.com/go-jose/go-jose/v3" + "golang.org/x/oauth2" + "google.golang.org/protobuf/encoding/protojson" + + "github.com/pomerium/pomerium/config" + "github.com/pomerium/pomerium/internal/encoding" + "github.com/pomerium/pomerium/internal/encoding/jws" + "github.com/pomerium/pomerium/internal/handlers" + "github.com/pomerium/pomerium/internal/httputil" + "github.com/pomerium/pomerium/internal/identity" + "github.com/pomerium/pomerium/internal/log" + "github.com/pomerium/pomerium/internal/sessions" + "github.com/pomerium/pomerium/internal/urlutil" + "github.com/pomerium/pomerium/pkg/cryptutil" + "github.com/pomerium/pomerium/pkg/grpc" + "github.com/pomerium/pomerium/pkg/grpc/databroker" + identitypb "github.com/pomerium/pomerium/pkg/grpc/identity" + "github.com/pomerium/pomerium/pkg/grpc/session" + "github.com/pomerium/pomerium/pkg/grpc/user" + "github.com/pomerium/pomerium/pkg/hpke" +) + +// Stateless implements the stateless authentication flow. In this flow, the +// authenticate service has no direct access to the databroker and instead +// stores profile information in a cookie. +type Stateless struct { + signatureVerifier + + // sharedEncoder is the encoder to use to serialize data to be consumed + // by other services + sharedEncoder encoding.MarshalUnmarshaler + // cookieCipher is the cipher to use to encrypt/decrypt session data + cookieCipher cipher.AEAD + + sessionStore sessions.SessionStore + + hpkePrivateKey *hpke.PrivateKey + authenticateKeyFetcher hpke.KeyFetcher + + jwk *jose.JSONWebKeySet + + authenticateURL *url.URL + + options *config.Options + + dataBrokerClient databroker.DataBrokerServiceClient + + getIdentityProvider func(options *config.Options, idpID string) (identity.Authenticator, error) + profileTrimFn func(*identitypb.Profile) + authEventFn AuthEventFn +} + +// NewStateless initializes the authentication flow for the given +// configuration, session store, and additional options. +func NewStateless( + cfg *config.Config, + sessionStore sessions.SessionStore, + getIdentityProvider func(options *config.Options, idpID string) (identity.Authenticator, error), + profileTrimFn func(*identitypb.Profile), + authEventFn AuthEventFn, +) (*Stateless, error) { + s := &Stateless{ + options: cfg.Options, + sessionStore: sessionStore, + getIdentityProvider: getIdentityProvider, + profileTrimFn: profileTrimFn, + authEventFn: authEventFn, + } + + var err error + s.authenticateURL, err = cfg.Options.GetAuthenticateURL() + if err != nil { + return nil, err + } + + // shared cipher to encrypt data before passing data between services + sharedKey, err := cfg.Options.GetSharedKey() + if err != nil { + return nil, err + } + + // shared state encoder setup + s.sharedEncoder, err = jws.NewHS256Signer(sharedKey) + if err != nil { + return nil, err + } + + // private state encoder setup, used to encrypt oauth2 tokens + cookieSecret, err := cfg.Options.GetCookieSecret() + if err != nil { + return nil, err + } + + s.cookieCipher, err = cryptutil.NewAEADCipher(cookieSecret) + if err != nil { + return nil, err + } + + s.jwk = new(jose.JSONWebKeySet) + signingKey, err := cfg.Options.GetSigningKey() + if err != nil { + return nil, err + } + if len(signingKey) > 0 { + ks, err := cryptutil.PublicJWKsFromBytes(signingKey) + if err != nil { + return nil, fmt.Errorf("authenticate: failed to convert jwks: %w", err) + } + for _, k := range ks { + s.jwk.Keys = append(s.jwk.Keys, *k) + } + } + + s.signatureVerifier = signatureVerifier{cfg.Options, sharedKey} + + s.hpkePrivateKey = hpke.DerivePrivateKey(sharedKey) + + s.authenticateKeyFetcher, err = cfg.GetAuthenticateKeyFetcher() + if err != nil { + return nil, fmt.Errorf("authorize: get authenticate JWKS key fetcher: %w", err) + } + + dataBrokerConn, err := outboundGRPCConnection.Get(context.Background(), &grpc.OutboundOptions{ + OutboundPort: cfg.OutboundPort, + InstallationID: cfg.Options.InstallationID, + ServiceName: cfg.Options.Services, + SignedJWTKey: sharedKey, + }) + if err != nil { + return nil, err + } + + s.dataBrokerClient = databroker.NewDataBrokerServiceClient(dataBrokerConn) + + return s, nil +} + +// VerifySession checks that an existing session is still valid. +func (s *Stateless) VerifySession(ctx context.Context, r *http.Request, _ *sessions.State) error { + profile, err := loadIdentityProfile(r, s.cookieCipher) + if err != nil { + return fmt.Errorf("identity profile load error: %w", err) + } + + authenticator, err := s.getIdentityProvider(s.options, profile.GetProviderId()) + if err != nil { + return fmt.Errorf("couldn't get identity provider: %w", err) + } + + if err := validateIdentityProfile(ctx, authenticator, profile); err != nil { + return fmt.Errorf("invalid identity profile: %w", err) + } + + return nil +} + +// SignIn redirects to a route callback URL, if the provided request and +// session state are valid. +func (s *Stateless) SignIn( + w http.ResponseWriter, + r *http.Request, + sessionState *sessions.State, +) error { + if err := r.ParseForm(); err != nil { + return httputil.NewError(http.StatusBadRequest, err) + } + proxyPublicKey, requestParams, err := hpke.DecryptURLValues(s.hpkePrivateKey, r.Form) + if err != nil { + return err + } + + idpID := requestParams.Get(urlutil.QueryIdentityProviderID) + + // start over if this is a different identity provider + if sessionState == nil || sessionState.IdentityProviderID != idpID { + sessionState = sessions.NewState(idpID) + } + + // re-persist the session, useful when session was evicted from session store + if err := s.sessionStore.SaveSession(w, r, sessionState); err != nil { + return httputil.NewError(http.StatusBadRequest, err) + } + + profile, err := loadIdentityProfile(r, s.cookieCipher) + if err != nil { + return httputil.NewError(http.StatusBadRequest, err) + } + + if s.profileTrimFn != nil { + s.profileTrimFn(profile) + } + + s.logAuthenticateEvent(r, profile) + + encryptURLValues := hpke.EncryptURLValuesV1 + if hpke.IsEncryptedURLV2(r.Form) { + encryptURLValues = hpke.EncryptURLValuesV2 + } + + redirectTo, err := urlutil.CallbackURL(s.hpkePrivateKey, proxyPublicKey, requestParams, profile, encryptURLValues) + if err != nil { + return httputil.NewError(http.StatusInternalServerError, err) + } + + httputil.Redirect(w, r, redirectTo, http.StatusFound) + return nil +} + +// PersistSession stores session data in a cookie. +func (s *Stateless) PersistSession( + ctx context.Context, + w http.ResponseWriter, + sessionState *sessions.State, + claims identity.SessionClaims, + accessToken *oauth2.Token, +) error { + idpID := sessionState.IdentityProviderID + profile, err := buildIdentityProfile(idpID, claims, accessToken) + if err != nil { + return err + } + err = storeIdentityProfile(w, s.options.NewCookie(), s.cookieCipher, profile) + if err != nil { + log.Error(ctx).Err(err).Msg("failed to store identity profile") + } + return nil +} + +// GetUserInfoData returns user info data associated with the given request (if +// any). +func (s *Stateless) GetUserInfoData(r *http.Request, _ *sessions.State) handlers.UserInfoData { + profile, _ := loadIdentityProfile(r, s.cookieCipher) + return handlers.UserInfoData{ + Profile: profile, + } +} + +// RevokeSession revokes the session associated with the provided request, +// returning the ID token from the revoked session. +func (s *Stateless) RevokeSession( + ctx context.Context, r *http.Request, authenticator identity.Authenticator, _ *sessions.State, +) string { + profile, err := loadIdentityProfile(r, s.cookieCipher) + if err != nil { + return "" + } + + oauthToken := new(oauth2.Token) + _ = json.Unmarshal(profile.GetOauthToken(), oauthToken) + if err := authenticator.Revoke(ctx, oauthToken); err != nil { + log.Ctx(ctx).Warn().Err(err).Msg("authenticate: failed to revoke access token") + } + + return string(profile.GetIdToken()) +} + +// GetIdentityProviderIDForURLValues returns the identity provider ID +// associated with the given URL values. +func (s *Stateless) GetIdentityProviderIDForURLValues(vs url.Values) string { + idpID := "" + if _, requestParams, err := hpke.DecryptURLValues(s.hpkePrivateKey, vs); err == nil { + if idpID == "" { + idpID = requestParams.Get(urlutil.QueryIdentityProviderID) + } + } + if idpID == "" { + idpID = vs.Get(urlutil.QueryIdentityProviderID) + } + return idpID +} + +// LogAuthenticateEvent logs an authenticate service event. +func (s *Stateless) LogAuthenticateEvent(r *http.Request) { + s.logAuthenticateEvent(r, nil) +} + +func getUserClaim(profile *identitypb.Profile, field string) *string { + if profile == nil { + return nil + } + if profile.Claims == nil { + return nil + } + val, ok := profile.Claims.Fields[field] + if !ok || val == nil { + return nil + } + txt := val.GetStringValue() + return &txt +} + +// AuthenticateSignInURL returns a URL to redirect the user to the authenticate +// domain. +func (s *Stateless) AuthenticateSignInURL( + ctx context.Context, queryParams url.Values, redirectURL *url.URL, idpID string, +) (string, error) { + authenticateHPKEPublicKey, err := s.authenticateKeyFetcher.FetchPublicKey(ctx) + if err != nil { + return "", err + } + + authenticateURLWithParams := *s.authenticateURL + q := authenticateURLWithParams.Query() + for k, v := range queryParams { + q[k] = v + } + authenticateURLWithParams.RawQuery = q.Encode() + + return urlutil.SignInURL( + s.hpkePrivateKey, + authenticateHPKEPublicKey, + &authenticateURLWithParams, + redirectURL, + idpID, + ) +} + +// Callback handles a redirect to a route domain once signed in. +func (s *Stateless) Callback(w http.ResponseWriter, r *http.Request) error { + if err := r.ParseForm(); err != nil { + return httputil.NewError(http.StatusBadRequest, err) + } + + // decrypt the URL values + senderPublicKey, values, err := hpke.DecryptURLValues(s.hpkePrivateKey, r.Form) + if err != nil { + return httputil.NewError(http.StatusBadRequest, fmt.Errorf("invalid encrypted query string: %w", err)) + } + + // confirm this request came from the authenticate service + err = s.validateSenderPublicKey(r.Context(), senderPublicKey) + if err != nil { + return err + } + + // validate that the request has not expired + err = urlutil.ValidateTimeParameters(values) + if err != nil { + return httputil.NewError(http.StatusBadRequest, err) + } + + profile, err := getProfileFromValues(values) + if err != nil { + return err + } + + ss := newSessionStateFromProfile(profile) + sess, err := session.Get(r.Context(), s.dataBrokerClient, ss.ID) + if err != nil { + sess = &session.Session{Id: ss.ID} + } + populateSessionFromProfile(sess, profile, ss, s.options.CookieExpire) + u, err := user.Get(r.Context(), s.dataBrokerClient, ss.UserID()) + if err != nil { + u = &user.User{Id: ss.UserID()} + } + populateUserFromClaims(u, profile.GetClaims().AsMap()) + + redirectURI, err := getRedirectURIFromValues(values) + if err != nil { + return err + } + + // save the records + res, err := s.dataBrokerClient.Put(r.Context(), &databroker.PutRequest{ + Records: []*databroker.Record{ + databroker.NewRecord(sess), + databroker.NewRecord(u), + }, + }) + if err != nil { + return httputil.NewError(http.StatusInternalServerError, fmt.Errorf("proxy: error saving databroker records: %w", err)) + } + ss.DatabrokerServerVersion = res.GetServerVersion() + for _, record := range res.GetRecords() { + if record.GetVersion() > ss.DatabrokerRecordVersion { + ss.DatabrokerRecordVersion = record.GetVersion() + } + } + + // save the session state + rawJWT, err := s.sharedEncoder.Marshal(ss) + if err != nil { + return httputil.NewError(http.StatusInternalServerError, fmt.Errorf("proxy: error marshaling session state: %w", err)) + } + if err = s.sessionStore.SaveSession(w, r, rawJWT); err != nil { + return httputil.NewError(http.StatusInternalServerError, fmt.Errorf("proxy: error saving session state: %w", err)) + } + + // if programmatic, encode the session jwt as a query param + if isProgrammatic := values.Get(urlutil.QueryIsProgrammatic); isProgrammatic == "true" { + q := redirectURI.Query() + q.Set(urlutil.QueryPomeriumJWT, string(rawJWT)) + redirectURI.RawQuery = q.Encode() + } + + // redirect + httputil.Redirect(w, r, redirectURI.String(), http.StatusFound) + return nil +} + +func (s *Stateless) validateSenderPublicKey(ctx context.Context, senderPublicKey *hpke.PublicKey) error { + authenticatePublicKey, err := s.authenticateKeyFetcher.FetchPublicKey(ctx) + if err != nil { + return httputil.NewError(http.StatusInternalServerError, fmt.Errorf("hpke: error retrieving authenticate service public key: %w", err)) + } + + if !authenticatePublicKey.Equals(senderPublicKey) { + return httputil.NewError(http.StatusBadRequest, fmt.Errorf("hpke: invalid authenticate service public key")) + } + + return nil +} + +func getProfileFromValues(values url.Values) (*identitypb.Profile, error) { + rawProfile := values.Get(urlutil.QueryIdentityProfile) + if rawProfile == "" { + return nil, httputil.NewError(http.StatusBadRequest, fmt.Errorf("missing %s", urlutil.QueryIdentityProfile)) + } + + var profile identitypb.Profile + err := protojson.Unmarshal([]byte(rawProfile), &profile) + if err != nil { + return nil, httputil.NewError(http.StatusBadRequest, fmt.Errorf("invalid %s: %w", urlutil.QueryIdentityProfile, err)) + } + return &profile, nil +} + +func getRedirectURIFromValues(values url.Values) (*url.URL, error) { + rawRedirectURI := values.Get(urlutil.QueryRedirectURI) + if rawRedirectURI == "" { + return nil, httputil.NewError(http.StatusBadRequest, fmt.Errorf("missing %s", urlutil.QueryRedirectURI)) + } + redirectURI, err := urlutil.ParseAndValidateURL(rawRedirectURI) + if err != nil { + return nil, httputil.NewError(http.StatusBadRequest, fmt.Errorf("invalid %s: %w", urlutil.QueryRedirectURI, err)) + } + return redirectURI, nil +} diff --git a/proxy/handlers.go b/proxy/handlers.go index 132114b06..81f1889ad 100644 --- a/proxy/handlers.go +++ b/proxy/handlers.go @@ -1,7 +1,6 @@ package proxy import ( - "context" "errors" "fmt" "io" @@ -10,17 +9,11 @@ import ( "github.com/go-jose/go-jose/v3/jwt" "github.com/gorilla/mux" - "google.golang.org/protobuf/encoding/protojson" "github.com/pomerium/pomerium/internal/handlers" "github.com/pomerium/pomerium/internal/httputil" "github.com/pomerium/pomerium/internal/middleware" "github.com/pomerium/pomerium/internal/urlutil" - "github.com/pomerium/pomerium/pkg/grpc/databroker" - "github.com/pomerium/pomerium/pkg/grpc/identity" - "github.com/pomerium/pomerium/pkg/grpc/session" - "github.com/pomerium/pomerium/pkg/grpc/user" - "github.com/pomerium/pomerium/pkg/hpke" ) // registerDashboardHandlers returns the proxy service's ServeMux @@ -110,89 +103,7 @@ func (p *Proxy) deviceEnrolled(w http.ResponseWriter, r *http.Request) error { // Callback handles the result of a successful call to the authenticate service // and is responsible setting per-route sessions. func (p *Proxy) Callback(w http.ResponseWriter, r *http.Request) error { - state := p.state.Load() - options := p.currentOptions.Load() - - if err := r.ParseForm(); err != nil { - return httputil.NewError(http.StatusBadRequest, err) - } - - // decrypt the URL values - senderPublicKey, values, err := hpke.DecryptURLValues(state.hpkePrivateKey, r.Form) - if err != nil { - return httputil.NewError(http.StatusBadRequest, fmt.Errorf("invalid encrypted query string: %w", err)) - } - - // confirm this request came from the authenticate service - err = p.validateSenderPublicKey(r.Context(), senderPublicKey) - if err != nil { - return err - } - - // validate that the request has not expired - err = urlutil.ValidateTimeParameters(values) - if err != nil { - return httputil.NewError(http.StatusBadRequest, err) - } - - profile, err := getProfileFromValues(values) - if err != nil { - return err - } - - ss := newSessionStateFromProfile(profile) - s, err := session.Get(r.Context(), state.dataBrokerClient, ss.ID) - if err != nil { - s = &session.Session{Id: ss.ID} - } - populateSessionFromProfile(s, profile, ss, options.CookieExpire) - u, err := user.Get(r.Context(), state.dataBrokerClient, ss.UserID()) - if err != nil { - u = &user.User{Id: ss.UserID()} - } - populateUserFromProfile(u, profile, ss) - - redirectURI, err := getRedirectURIFromValues(values) - if err != nil { - return err - } - - // save the records - res, err := state.dataBrokerClient.Put(r.Context(), &databroker.PutRequest{ - Records: []*databroker.Record{ - databroker.NewRecord(s), - databroker.NewRecord(u), - }, - }) - if err != nil { - return httputil.NewError(http.StatusInternalServerError, fmt.Errorf("proxy: error saving databroker records: %w", err)) - } - ss.DatabrokerServerVersion = res.GetServerVersion() - for _, record := range res.GetRecords() { - if record.GetVersion() > ss.DatabrokerRecordVersion { - ss.DatabrokerRecordVersion = record.GetVersion() - } - } - - // save the session state - rawJWT, err := state.encoder.Marshal(ss) - if err != nil { - return httputil.NewError(http.StatusInternalServerError, fmt.Errorf("proxy: error marshaling session state: %w", err)) - } - if err = state.sessionStore.SaveSession(w, r, rawJWT); err != nil { - return httputil.NewError(http.StatusInternalServerError, fmt.Errorf("proxy: error saving session state: %w", err)) - } - - // if programmatic, encode the session jwt as a query param - if isProgrammatic := values.Get(urlutil.QueryIsProgrammatic); isProgrammatic == "true" { - q := redirectURI.Query() - q.Set(urlutil.QueryPomeriumJWT, string(rawJWT)) - redirectURI.RawQuery = q.Encode() - } - - // redirect - httputil.Redirect(w, r, redirectURI.String(), http.StatusFound) - return nil + return p.state.Load().authenticateFlow.Callback(w, r) } // ProgrammaticLogin returns a signed url that can be used to login @@ -215,20 +126,14 @@ func (p *Proxy) ProgrammaticLogin(w http.ResponseWriter, r *http.Request) error return httputil.NewError(http.StatusInternalServerError, err) } - hpkeAuthenticateKey, err := state.authenticateKeyFetcher.FetchPublicKey(r.Context()) - if err != nil { - return httputil.NewError(http.StatusInternalServerError, err) - } - - signinURL := *state.authenticateSigninURL callbackURI := urlutil.GetAbsoluteURL(r) callbackURI.Path = dashboardPath + "/callback/" - q := signinURL.Query() + q := url.Values{} q.Set(urlutil.QueryCallbackURI, callbackURI.String()) q.Set(urlutil.QueryIsProgrammatic, "true") - signinURL.RawQuery = q.Encode() - rawURL, err := urlutil.SignInURL(state.hpkePrivateKey, hpkeAuthenticateKey, &signinURL, redirectURI, idp.GetId()) + rawURL, err := state.authenticateFlow.AuthenticateSignInURL( + r.Context(), q, redirectURI, idp.GetId()) if err != nil { return httputil.NewError(http.StatusInternalServerError, err) } @@ -263,44 +168,3 @@ func (p *Proxy) jwtAssertion(w http.ResponseWriter, r *http.Request) error { _, _ = io.WriteString(w, rawAssertionJWT) return nil } - -func (p *Proxy) validateSenderPublicKey(ctx context.Context, senderPublicKey *hpke.PublicKey) error { - state := p.state.Load() - - authenticatePublicKey, err := state.authenticateKeyFetcher.FetchPublicKey(ctx) - if err != nil { - return httputil.NewError(http.StatusInternalServerError, fmt.Errorf("hpke: error retrieving authenticate service public key: %w", err)) - } - - if !authenticatePublicKey.Equals(senderPublicKey) { - return httputil.NewError(http.StatusBadRequest, fmt.Errorf("hpke: invalid authenticate service public key")) - } - - return nil -} - -func getProfileFromValues(values url.Values) (*identity.Profile, error) { - rawProfile := values.Get(urlutil.QueryIdentityProfile) - if rawProfile == "" { - return nil, httputil.NewError(http.StatusBadRequest, fmt.Errorf("missing %s", urlutil.QueryIdentityProfile)) - } - - var profile identity.Profile - err := protojson.Unmarshal([]byte(rawProfile), &profile) - if err != nil { - return nil, httputil.NewError(http.StatusBadRequest, fmt.Errorf("invalid %s: %w", urlutil.QueryIdentityProfile, err)) - } - return &profile, nil -} - -func getRedirectURIFromValues(values url.Values) (*url.URL, error) { - rawRedirectURI := values.Get(urlutil.QueryRedirectURI) - if rawRedirectURI == "" { - return nil, httputil.NewError(http.StatusBadRequest, fmt.Errorf("missing %s", urlutil.QueryRedirectURI)) - } - redirectURI, err := urlutil.ParseAndValidateURL(rawRedirectURI) - if err != nil { - return nil, httputil.NewError(http.StatusBadRequest, fmt.Errorf("invalid %s: %w", urlutil.QueryRedirectURI, err)) - } - return redirectURI, nil -} diff --git a/proxy/identity_profile.go b/proxy/identity_profile.go deleted file mode 100644 index cdc55d134..000000000 --- a/proxy/identity_profile.go +++ /dev/null @@ -1,79 +0,0 @@ -package proxy - -import ( - "encoding/json" - "fmt" - "time" - - "golang.org/x/oauth2" - "google.golang.org/protobuf/types/known/structpb" - "google.golang.org/protobuf/types/known/timestamppb" - - "github.com/pomerium/pomerium/internal/identity" - "github.com/pomerium/pomerium/internal/identity/manager" - "github.com/pomerium/pomerium/internal/sessions" - identitypb "github.com/pomerium/pomerium/pkg/grpc/identity" - "github.com/pomerium/pomerium/pkg/grpc/session" - "github.com/pomerium/pomerium/pkg/grpc/user" -) - -func newSessionStateFromProfile(p *identitypb.Profile) *sessions.State { - claims := p.GetClaims().AsMap() - - ss := sessions.NewState(p.GetProviderId()) - - // set the subject - if v, ok := claims["sub"]; ok { - ss.Subject = fmt.Sprint(v) - } else if v, ok := claims["user"]; ok { - ss.Subject = fmt.Sprint(v) - } - - // set the oid - if v, ok := claims["oid"]; ok { - ss.OID = fmt.Sprint(v) - } - - return ss -} - -func populateSessionFromProfile(s *session.Session, p *identitypb.Profile, ss *sessions.State, cookieExpire time.Duration) { - claims := p.GetClaims().AsMap() - oauthToken := new(oauth2.Token) - _ = json.Unmarshal(p.GetOauthToken(), oauthToken) - - s.UserId = ss.UserID() - s.IssuedAt = timestamppb.Now() - s.AccessedAt = timestamppb.Now() - s.ExpiresAt = timestamppb.New(time.Now().Add(cookieExpire)) - s.IdToken = &session.IDToken{ - Issuer: ss.Issuer, - Subject: ss.Subject, - ExpiresAt: timestamppb.New(time.Now().Add(cookieExpire)), - IssuedAt: timestamppb.Now(), - Raw: string(p.GetIdToken()), - } - s.OauthToken = manager.ToOAuthToken(oauthToken) - if s.Claims == nil { - s.Claims = make(map[string]*structpb.ListValue) - } - for k, vs := range identity.Claims(claims).Flatten().ToPB() { - s.Claims[k] = vs - } -} - -func populateUserFromProfile(u *user.User, p *identitypb.Profile, _ *sessions.State) { - claims := p.GetClaims().AsMap() - if v, ok := claims["name"]; ok { - u.Name = fmt.Sprint(v) - } - if v, ok := claims["email"]; ok { - u.Email = fmt.Sprint(v) - } - if u.Claims == nil { - u.Claims = make(map[string]*structpb.ListValue) - } - for k, vs := range identity.Claims(claims).Flatten().ToPB() { - u.Claims[k] = vs - } -} diff --git a/proxy/state.go b/proxy/state.go index ff58daaf7..bdb4b1a64 100644 --- a/proxy/state.go +++ b/proxy/state.go @@ -3,10 +3,11 @@ package proxy import ( "context" "crypto/cipher" - "fmt" + "net/http" "net/url" "github.com/pomerium/pomerium/config" + "github.com/pomerium/pomerium/internal/authenticateflow" "github.com/pomerium/pomerium/internal/encoding" "github.com/pomerium/pomerium/internal/encoding/jws" "github.com/pomerium/pomerium/internal/sessions" @@ -14,11 +15,15 @@ import ( "github.com/pomerium/pomerium/pkg/cryptutil" "github.com/pomerium/pomerium/pkg/grpc" "github.com/pomerium/pomerium/pkg/grpc/databroker" - "github.com/pomerium/pomerium/pkg/hpke" ) var outboundGRPCConnection = new(grpc.CachedOutboundGRPClientConn) +type authenticateFlow interface { + AuthenticateSignInURL(ctx context.Context, queryParams url.Values, redirectURL *url.URL, idpID string) (string, error) + Callback(w http.ResponseWriter, r *http.Request) error +} + type proxyState struct { sharedKey []byte sharedCipher cipher.AEAD @@ -28,16 +33,16 @@ type proxyState struct { authenticateSigninURL *url.URL authenticateRefreshURL *url.URL - encoder encoding.MarshalUnmarshaler - cookieSecret []byte - sessionStore sessions.SessionStore - jwtClaimHeaders config.JWTClaimHeaders - hpkePrivateKey *hpke.PrivateKey - authenticateKeyFetcher hpke.KeyFetcher + encoder encoding.MarshalUnmarshaler + cookieSecret []byte + sessionStore sessions.SessionStore + jwtClaimHeaders config.JWTClaimHeaders dataBrokerClient databroker.DataBrokerServiceClient programmaticRedirectDomainWhitelist []string + + authenticateFlow authenticateFlow } func newProxyStateFromConfig(cfg *config.Config) (*proxyState, error) { @@ -53,16 +58,6 @@ func newProxyStateFromConfig(cfg *config.Config) (*proxyState, error) { return nil, err } - state.hpkePrivateKey, err = cfg.Options.GetHPKEPrivateKey() - if err != nil { - return nil, err - } - - state.authenticateKeyFetcher, err = cfg.GetAuthenticateKeyFetcher() - if err != nil { - return nil, fmt.Errorf("authorize: get authenticate JWKS key fetcher: %w", err) - } - state.sharedCipher, err = cryptutil.NewAEADCipher(state.sharedKey) if err != nil { return nil, err @@ -119,5 +114,11 @@ func newProxyStateFromConfig(cfg *config.Config) (*proxyState, error) { state.programmaticRedirectDomainWhitelist = cfg.Options.ProgrammaticRedirectDomainWhitelist + state.authenticateFlow, err = authenticateflow.NewStateless( + cfg, state.sessionStore, nil, nil, nil) + if err != nil { + return nil, err + } + return state, nil }