diff --git a/authorize/grpc.go b/authorize/grpc.go index 51aec0100..df78c3a00 100644 --- a/authorize/grpc.go +++ b/authorize/grpc.go @@ -48,7 +48,7 @@ func (a *Authorize) Check(ctx context.Context, in *envoy_service_auth_v3.CheckRe requestID := requestid.FromHTTPHeader(hreq.Header) ctx = requestid.WithValue(ctx, requestID) - sessionState, _ := state.sessionStore.LoadSessionState(hreq) + sessionState, _ := state.sessionStore.LoadSessionStateAndCheckIDP(hreq) var s sessionOrServiceAccount var u *user.User diff --git a/config/session.go b/config/session.go index f0268d611..95f42bdf0 100644 --- a/config/session.go +++ b/config/session.go @@ -15,11 +15,14 @@ import ( // A SessionStore saves and loads sessions based on the options. type SessionStore struct { + store sessions.SessionStore + loader sessions.SessionLoader options *Options encoder encoding.MarshalUnmarshaler - loader sessions.SessionLoader } +var _ sessions.SessionStore = (*SessionStore)(nil) + // NewSessionStore creates a new SessionStore from the Options. func NewSessionStore(options *Options) (*SessionStore, error) { store := &SessionStore{ @@ -36,7 +39,7 @@ func NewSessionStore(options *Options) (*SessionStore, error) { return nil, fmt.Errorf("config/sessions: invalid session encoder: %w", err) } - cookieStore, err := cookie.NewStore(func() cookie.Options { + store.store, err = cookie.NewStore(func() cookie.Options { return cookie.Options{ Name: options.CookieName, Domain: options.CookieDomain, @@ -51,11 +54,21 @@ func NewSessionStore(options *Options) (*SessionStore, error) { } headerStore := header.NewStore(store.encoder) queryParamStore := queryparam.NewStore(store.encoder, urlutil.QuerySession) - store.loader = sessions.MultiSessionLoader(cookieStore, headerStore, queryParamStore) + store.loader = sessions.MultiSessionLoader(store.store, headerStore, queryParamStore) return store, nil } +// ClearSession clears the session. +func (store *SessionStore) ClearSession(w http.ResponseWriter, r *http.Request) { + store.store.ClearSession(w, r) +} + +// LoadSession loads the session. +func (store *SessionStore) LoadSession(r *http.Request) (string, error) { + return store.loader.LoadSession(r) +} + // LoadSessionState loads the session state from a request. func (store *SessionStore) LoadSessionState(r *http.Request) (*sessions.State, error) { rawJWT, err := store.loader.LoadSession(r) @@ -69,6 +82,16 @@ func (store *SessionStore) LoadSessionState(r *http.Request) (*sessions.State, e return nil, err } + return &state, nil +} + +// LoadSessionStateAndCheckIDP loads the session state from a request and checks that the idp id matches. +func (store *SessionStore) LoadSessionStateAndCheckIDP(r *http.Request) (*sessions.State, error) { + state, err := store.LoadSessionState(r) + if err != nil { + return nil, err + } + // confirm that the identity provider id matches the state if state.IdentityProviderID != "" { idp, err := store.options.GetIdentityProviderForRequestURL(urlutil.GetAbsoluteURL(r).String()) @@ -82,5 +105,10 @@ func (store *SessionStore) LoadSessionState(r *http.Request) (*sessions.State, e } } - return &state, nil + return state, nil +} + +// SaveSession saves the session. +func (store *SessionStore) SaveSession(w http.ResponseWriter, r *http.Request, v any) error { + return store.store.SaveSession(w, r, v) } diff --git a/config/session_test.go b/config/session_test.go index 1bd62d7c2..936b45c7c 100644 --- a/config/session_test.go +++ b/config/session_test.go @@ -70,7 +70,7 @@ func TestSessionStore_LoadSessionState(t *testing.T) { t.Run("mssing", func(t *testing.T) { r, err := http.NewRequest(http.MethodGet, "https://p1.example.com", nil) require.NoError(t, err) - s, err := store.LoadSessionState(r) + s, err := store.LoadSessionStateAndCheckIDP(r) assert.ErrorIs(t, err, sessions.ErrNoSessionFound) assert.Nil(t, s) }) @@ -85,7 +85,7 @@ func TestSessionStore_LoadSessionState(t *testing.T) { urlutil.QuerySession: {rawJWS}, }.Encode(), nil) require.NoError(t, err) - s, err := store.LoadSessionState(r) + s, err := store.LoadSessionStateAndCheckIDP(r) assert.NoError(t, err) assert.Empty(t, cmp.Diff(&sessions.State{ Issuer: "authenticate.example.com", @@ -103,7 +103,7 @@ func TestSessionStore_LoadSessionState(t *testing.T) { r, err := http.NewRequest(http.MethodGet, "https://p2.example.com", nil) require.NoError(t, err) r.Header.Set(httputil.HeaderPomeriumAuthorization, rawJWS) - s, err := store.LoadSessionState(r) + s, err := store.LoadSessionStateAndCheckIDP(r) assert.NoError(t, err) assert.Empty(t, cmp.Diff(&sessions.State{ Issuer: "authenticate.example.com", @@ -121,7 +121,7 @@ func TestSessionStore_LoadSessionState(t *testing.T) { r, err := http.NewRequest(http.MethodGet, "https://p2.example.com", nil) require.NoError(t, err) r.Header.Set(httputil.HeaderPomeriumAuthorization, rawJWS) - s, err := store.LoadSessionState(r) + s, err := store.LoadSessionStateAndCheckIDP(r) assert.Error(t, err) assert.Nil(t, s) }) @@ -134,7 +134,7 @@ func TestSessionStore_LoadSessionState(t *testing.T) { r, err := http.NewRequest(http.MethodGet, "https://p2.example.com", nil) require.NoError(t, err) r.Header.Set(httputil.HeaderPomeriumAuthorization, rawJWS) - s, err := store.LoadSessionState(r) + s, err := store.LoadSessionStateAndCheckIDP(r) assert.NoError(t, err) assert.Empty(t, cmp.Diff(&sessions.State{ Issuer: "authenticate.example.com", diff --git a/proxy/data.go b/proxy/data.go index ce13ce53e..288647a58 100644 --- a/proxy/data.go +++ b/proxy/data.go @@ -6,11 +6,8 @@ import ( "github.com/pomerium/csrf" "github.com/pomerium/datasource/pkg/directory" - "github.com/pomerium/pomerium/internal/encoding/jws" "github.com/pomerium/pomerium/internal/handlers" "github.com/pomerium/pomerium/internal/handlers/webauthn" - "github.com/pomerium/pomerium/internal/httputil" - "github.com/pomerium/pomerium/internal/sessions" "github.com/pomerium/pomerium/internal/urlutil" "github.com/pomerium/pomerium/pkg/grpc/databroker" "github.com/pomerium/pomerium/pkg/grpc/session" @@ -31,33 +28,12 @@ func (p *Proxy) getSession(ctx context.Context, sessionID string) (s *session.Se return s, isImpersonated, err } -func (p *Proxy) getSessionState(r *http.Request) (sessions.State, error) { - state := p.state.Load() - - rawJWT, err := state.sessionStore.LoadSession(r) - if err != nil { - return sessions.State{}, err - } - - encoder, err := jws.NewHS256Signer(state.sharedKey) - if err != nil { - return sessions.State{}, err - } - - var sessionState sessions.State - if err := encoder.Unmarshal([]byte(rawJWT), &sessionState); err != nil { - return sessions.State{}, httputil.NewError(http.StatusBadRequest, err) - } - - return sessionState, nil -} - func (p *Proxy) getUser(ctx context.Context, userID string) (*user.User, error) { client := p.state.Load().dataBrokerClient return user.Get(ctx, client, userID) } -func (p *Proxy) getUserInfoData(r *http.Request) (handlers.UserInfoData, error) { +func (p *Proxy) getUserInfoData(r *http.Request) handlers.UserInfoData { options := p.currentOptions.Load() state := p.state.Load() @@ -66,7 +42,7 @@ func (p *Proxy) getUserInfoData(r *http.Request) (handlers.UserInfoData, error) BrandingOptions: options.BrandingOptions, } - ss, err := p.getSessionState(r) + ss, err := p.state.Load().sessionStore.LoadSessionState(r) if err == nil { data.Session, data.IsImpersonated, err = p.getSession(r.Context(), ss.ID) if err != nil { @@ -82,7 +58,7 @@ func (p *Proxy) getUserInfoData(r *http.Request) (handlers.UserInfoData, error) data.WebAuthnCreationOptions, data.WebAuthnRequestOptions, _ = p.webauthn.GetOptions(r) data.WebAuthnURL = urlutil.WebAuthnURL(r, urlutil.GetAbsoluteURL(r), state.sharedKey, r.URL.Query()) p.fillEnterpriseUserInfoData(r.Context(), &data) - return data, nil + return data } func (p *Proxy) fillEnterpriseUserInfoData(ctx context.Context, data *handlers.UserInfoData) { @@ -109,7 +85,7 @@ func (p *Proxy) getWebauthnState(r *http.Request) (*webauthn.State, error) { options := p.currentOptions.Load() state := p.state.Load() - ss, err := p.getSessionState(r) + ss, err := p.state.Load().sessionStore.LoadSessionState(r) if err != nil { return nil, err } @@ -135,7 +111,7 @@ func (p *Proxy) getWebauthnState(r *http.Request) (*webauthn.State, error) { SharedKey: state.sharedKey, Client: state.dataBrokerClient, Session: s, - SessionState: &ss, + SessionState: ss, SessionStore: state.sessionStore, RelyingParty: webauthnutil.GetRelyingParty(r, state.dataBrokerClient), BrandingOptions: options.BrandingOptions, diff --git a/proxy/handlers.go b/proxy/handlers.go index 2038496a8..42fa01953 100644 --- a/proxy/handlers.go +++ b/proxy/handlers.go @@ -80,19 +80,13 @@ func (p *Proxy) SignOut(w http.ResponseWriter, r *http.Request) error { } func (p *Proxy) userInfo(w http.ResponseWriter, r *http.Request) error { - data, err := p.getUserInfoData(r) - if err != nil { - return err - } + data := p.getUserInfoData(r) handlers.UserInfo(data).ServeHTTP(w, r) return nil } func (p *Proxy) deviceEnrolled(w http.ResponseWriter, r *http.Request) error { - data, err := p.getUserInfoData(r) - if err != nil { - return err - } + data := p.getUserInfoData(r) handlers.DeviceEnrolled(data).ServeHTTP(w, r) return nil } diff --git a/proxy/handlers_test.go b/proxy/handlers_test.go index 596c628e4..49b1f5287 100644 --- a/proxy/handlers_test.go +++ b/proxy/handlers_test.go @@ -15,7 +15,9 @@ import ( "github.com/pomerium/pomerium/config" "github.com/pomerium/pomerium/internal/atomicutil" + "github.com/pomerium/pomerium/internal/encoding/jws" "github.com/pomerium/pomerium/internal/httputil" + "github.com/pomerium/pomerium/internal/sessions" "github.com/pomerium/pomerium/internal/urlutil" ) @@ -260,3 +262,78 @@ func TestProxy_registerDashboardHandlers_jwtEndpoint(t *testing.T) { assert.Equal(t, rawJWT, string(b)) }) } + +func TestLoadSessionState(t *testing.T) { + t.Parallel() + + t.Run("no session", func(t *testing.T) { + t.Parallel() + + opts := testOptions(t) + proxy, err := New(&config.Config{Options: opts}) + require.NoError(t, err) + + r := httptest.NewRequest(http.MethodGet, "/.pomerium/", nil) + w := httptest.NewRecorder() + proxy.ServeHTTP(w, r) + + assert.Equal(t, http.StatusOK, w.Code) + assert.Contains(t, w.Body.String(), "window.POMERIUM_DATA") + assert.NotContains(t, w.Body.String(), "___SESSION_ID___") + }) + t.Run("cookie session", func(t *testing.T) { + t.Parallel() + + opts := testOptions(t) + proxy, err := New(&config.Config{Options: opts}) + require.NoError(t, err) + + session := encodeSession(t, opts, &sessions.State{ + ID: "___SESSION_ID___", + }) + + r := httptest.NewRequest(http.MethodGet, "/.pomerium/", nil) + r.AddCookie(&http.Cookie{ + Name: opts.CookieName, + Domain: opts.CookieDomain, + Value: session, + }) + w := httptest.NewRecorder() + proxy.ServeHTTP(w, r) + + assert.Equal(t, http.StatusOK, w.Code) + assert.Contains(t, w.Body.String(), "___SESSION_ID___") + }) + t.Run("header session", func(t *testing.T) { + t.Parallel() + + opts := testOptions(t) + proxy, err := New(&config.Config{Options: opts}) + require.NoError(t, err) + + session := encodeSession(t, opts, &sessions.State{ + ID: "___SESSION_ID___", + }) + + r := httptest.NewRequest(http.MethodGet, "/.pomerium/", nil) + r.Header.Set("Authorization", "Bearer Pomerium-"+session) + w := httptest.NewRecorder() + proxy.ServeHTTP(w, r) + + assert.Equal(t, http.StatusOK, w.Code) + assert.Contains(t, w.Body.String(), "___SESSION_ID___") + }) +} + +func encodeSession(t *testing.T, opts *config.Options, state *sessions.State) string { + sharedKey, err := opts.GetSharedKey() + require.NoError(t, err) + + encoder, err := jws.NewHS256Signer(sharedKey) + require.NoError(t, err) + + sessionBS, err := encoder.Marshal(state) + require.NoError(t, err) + + return string(sessionBS) +} diff --git a/proxy/proxy.go b/proxy/proxy.go index dfcedadcc..ad8967d8a 100644 --- a/proxy/proxy.go +++ b/proxy/proxy.go @@ -71,6 +71,7 @@ func New(cfg *config.Config) (*Proxy, error) { currentOptions: config.NewAtomicOptions(), currentRouter: atomicutil.NewValue(httputil.NewRouter()), } + p.OnConfigChange(context.Background(), cfg) p.webauthn = webauthn.New(p.getWebauthnState) metrics.AddPolicyCountCallback("pomerium-proxy", func() int64 { diff --git a/proxy/state.go b/proxy/state.go index 55842d97b..cd5bc22a3 100644 --- a/proxy/state.go +++ b/proxy/state.go @@ -2,17 +2,11 @@ package proxy import ( "context" - "crypto/cipher" "net/http" "net/url" "github.com/pomerium/pomerium/config" "github.com/pomerium/pomerium/internal/authenticateflow" - "github.com/pomerium/pomerium/internal/encoding" - "github.com/pomerium/pomerium/internal/encoding/jws" - "github.com/pomerium/pomerium/internal/sessions" - "github.com/pomerium/pomerium/internal/sessions/cookie" - "github.com/pomerium/pomerium/pkg/cryptutil" "github.com/pomerium/pomerium/pkg/grpc" "github.com/pomerium/pomerium/pkg/grpc/databroker" ) @@ -25,24 +19,16 @@ type authenticateFlow interface { } type proxyState struct { - sharedKey []byte - sharedCipher cipher.AEAD - authenticateURL *url.URL authenticateDashboardURL *url.URL authenticateSigninURL *url.URL authenticateRefreshURL *url.URL - encoder encoding.MarshalUnmarshaler - cookieSecret []byte - sessionStore sessions.SessionStore - jwtClaimHeaders config.JWTClaimHeaders - - dataBrokerClient databroker.DataBrokerServiceClient - + sharedKey []byte + sessionStore *config.SessionStore + dataBrokerClient databroker.DataBrokerServiceClient programmaticRedirectDomainWhitelist []string - - authenticateFlow authenticateFlow + authenticateFlow authenticateFlow } func newProxyStateFromConfig(cfg *config.Config) (*proxyState, error) { @@ -53,49 +39,20 @@ func newProxyStateFromConfig(cfg *config.Config) (*proxyState, error) { state := new(proxyState) + state.authenticateURL, err = cfg.Options.GetAuthenticateURL() + if err != nil { + return nil, err + } + state.authenticateDashboardURL = state.authenticateURL.ResolveReference(&url.URL{Path: "/.pomerium/"}) + state.authenticateSigninURL = state.authenticateURL.ResolveReference(&url.URL{Path: signinURL}) + state.authenticateRefreshURL = state.authenticateURL.ResolveReference(&url.URL{Path: refreshURL}) + state.sharedKey, err = cfg.Options.GetSharedKey() if err != nil { return nil, err } - state.sharedCipher, err = cryptutil.NewAEADCipher(state.sharedKey) - if err != nil { - return nil, err - } - - state.cookieSecret, err = cfg.Options.GetCookieSecret() - if err != nil { - return nil, err - } - - // used to load and verify JWT tokens signed by the authenticate service - state.encoder, err = jws.NewHS256Signer(state.sharedKey) - if err != nil { - return nil, err - } - - state.jwtClaimHeaders = cfg.Options.JWTClaimsHeaders - - // errors checked in ValidateOptions - state.authenticateURL, err = cfg.Options.GetAuthenticateURL() - if err != nil { - return nil, err - } - - state.authenticateDashboardURL = state.authenticateURL.ResolveReference(&url.URL{Path: "/.pomerium/"}) - state.authenticateSigninURL = state.authenticateURL.ResolveReference(&url.URL{Path: signinURL}) - state.authenticateRefreshURL = state.authenticateURL.ResolveReference(&url.URL{Path: refreshURL}) - - state.sessionStore, err = cookie.NewStore(func() cookie.Options { - return cookie.Options{ - Name: cfg.Options.CookieName, - Domain: cfg.Options.CookieDomain, - Secure: true, - HTTPOnly: cfg.Options.CookieHTTPOnly, - Expire: cfg.Options.CookieExpire, - SameSite: cfg.Options.GetCookieSameSite(), - } - }, state.encoder) + state.sessionStore, err = config.NewSessionStore(cfg.Options) if err != nil { return nil, err } @@ -109,7 +66,6 @@ func newProxyStateFromConfig(cfg *config.Config) (*proxyState, error) { if err != nil { return nil, err } - state.dataBrokerClient = databroker.NewDataBrokerServiceClient(dataBrokerConn) state.programmaticRedirectDomainWhitelist = cfg.Options.ProgrammaticRedirectDomainWhitelist