diff --git a/.golangci.yml b/.golangci.yml index dea9f19e8..19dbb46d5 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -198,11 +198,6 @@ issues: linters: - staticcheck - # todo(bdd): replace in go 1.13 - - path: proxy/proxy.go - text: "copylocks: assignment copies lock value to transport" - linters: - - govet # Independently from option `exclude` we use default exclude patterns, # it can be disabled by this option. To list all # excluded by default patterns execute `golangci-lint run --help`. diff --git a/authenticate/authenticate.go b/authenticate/authenticate.go index d382fe3a3..e3cac9982 100644 --- a/authenticate/authenticate.go +++ b/authenticate/authenticate.go @@ -16,6 +16,10 @@ import ( "github.com/pomerium/pomerium/internal/frontend" "github.com/pomerium/pomerium/internal/identity" "github.com/pomerium/pomerium/internal/sessions" + "github.com/pomerium/pomerium/internal/sessions/cache" + "github.com/pomerium/pomerium/internal/sessions/cookie" + "github.com/pomerium/pomerium/internal/sessions/header" + "github.com/pomerium/pomerium/internal/sessions/queryparam" "github.com/pomerium/pomerium/internal/urlutil" ) @@ -49,6 +53,8 @@ type Authenticate struct { // authentication flow RedirectURL *url.URL + // values related to cross service communication + // // sharedKey is used to encrypt and authenticate data between services sharedKey string // sharedCipher is used to encrypt data for use between services @@ -57,16 +63,21 @@ type Authenticate struct { // by other services sharedEncoder encoding.MarshalUnmarshaler - // data related to this service only - cookieOptions *sessions.CookieOptions - // cookieSecret is the secret to encrypt and authenticate data for this service + // values related to user sessions + // + // cookieSecret is the secret to encrypt and authenticate session data cookieSecret []byte - // is the cipher to use to encrypt data for this service - cookieCipher cipher.AEAD - sessionStore sessions.SessionStore + // cookieCipher is the cipher to use to encrypt/decrypt session data + cookieCipher cipher.AEAD + // encryptedEncoder is the encoder used to marshal and unmarshal session data encryptedEncoder encoding.MarshalUnmarshaler - sessionStores []sessions.SessionStore - sessionLoaders []sessions.SessionLoader + // sessionStore is the session store used to persist a user's session + sessionStore sessions.SessionStore + cookieOptions *cookie.Options + + // sessionLoaders are a collection of session loaders to attempt to pull + // a user's session state from + sessionLoaders []sessions.SessionLoader // provider is the interface to interacting with the identity provider (IdP) provider identity.Authenticator @@ -92,7 +103,7 @@ func New(opts config.Options) (*Authenticate, error) { cookieCipher, _ := cryptutil.NewAEADCipher(decodedCookieSecret) encryptedEncoder := ecjson.New(cookieCipher) - cookieOptions := &sessions.CookieOptions{ + cookieOptions := &cookie.Options{ Name: opts.CookieName, Domain: opts.CookieDomain, Secure: opts.CookieSecure, @@ -100,12 +111,13 @@ func New(opts config.Options) (*Authenticate, error) { Expire: opts.CookieExpire, } - cookieStore, err := sessions.NewCookieStore(cookieOptions, encryptedEncoder) + cookieStore, err := cookie.NewStore(cookieOptions, encryptedEncoder) if err != nil { return nil, err } - qpStore := sessions.NewQueryParamStore(encryptedEncoder, "pomerium_programmatic_token") - headerStore := sessions.NewHeaderStore(encryptedEncoder, "Pomerium") + cacheStore := cache.NewStore(encryptedEncoder, cookieStore, opts.CookieName) + qpStore := queryparam.NewStore(encryptedEncoder, "pomerium_programmatic_token") + headerStore := header.NewStore(encryptedEncoder, "Pomerium") redirectURL, _ := urlutil.DeepCopy(opts.AuthenticateURL) redirectURL.Path = callbackPath @@ -135,10 +147,9 @@ func New(opts config.Options) (*Authenticate, error) { cookieSecret: decodedCookieSecret, cookieCipher: cookieCipher, cookieOptions: cookieOptions, - sessionStore: cookieStore, + sessionStore: cacheStore, encryptedEncoder: encryptedEncoder, - sessionLoaders: []sessions.SessionLoader{qpStore, headerStore, cookieStore}, - sessionStores: []sessions.SessionStore{cookieStore, qpStore}, + sessionLoaders: []sessions.SessionLoader{cacheStore, qpStore, headerStore, cookieStore}, // IdP provider: provider, diff --git a/authenticate/authenticate_test.go b/authenticate/authenticate_test.go index 26ed69f52..f1f0bed53 100644 --- a/authenticate/authenticate_test.go +++ b/authenticate/authenticate_test.go @@ -72,15 +72,18 @@ func TestOptions_Validate(t *testing.T) { func TestNew(t *testing.T) { good := newTestOptions(t) + good.CookieName = "A" badRedirectURL := newTestOptions(t) badRedirectURL.AuthenticateURL = nil + badRedirectURL.CookieName = "B" badCookieName := newTestOptions(t) badCookieName.CookieName = "" badProvider := newTestOptions(t) badProvider.Provider = "" + badProvider.CookieName = "C" tests := []struct { name string diff --git a/authenticate/handlers.go b/authenticate/handlers.go index c084d87cf..8b39e71f1 100644 --- a/authenticate/handlers.go +++ b/authenticate/handlers.go @@ -1,6 +1,7 @@ package authenticate // import "github.com/pomerium/pomerium/authenticate" import ( + "context" "encoding/base64" "encoding/json" "errors" @@ -18,6 +19,7 @@ import ( "github.com/pomerium/pomerium/internal/log" "github.com/pomerium/pomerium/internal/middleware" "github.com/pomerium/pomerium/internal/sessions" + "github.com/pomerium/pomerium/internal/telemetry/trace" "github.com/pomerium/pomerium/internal/urlutil" ) @@ -58,6 +60,7 @@ func (a *Authenticate) Handler() http.Handler { v.Use(a.VerifySession) v.Path("/sign_in").Handler(httputil.HandlerFunc(a.SignIn)) v.Path("/sign_out").Handler(httputil.HandlerFunc(a.SignOut)) + v.Path("/refresh").Handler(httputil.HandlerFunc(a.Refresh)).Methods(http.MethodGet) // programmatic access api endpoint api := r.PathPrefix("/api").Subrouter() @@ -73,12 +76,12 @@ func (a *Authenticate) VerifySession(next http.Handler) http.Handler { return httputil.HandlerFunc(func(w http.ResponseWriter, r *http.Request) error { state, err := sessions.FromContext(r.Context()) if errors.Is(err, sessions.ErrExpired) { - if err := a.refresh(w, r, state); err != nil { + ctx, err := a.refresh(w, r, state) + if err != nil { log.FromRequest(r).Info().Err(err).Msg("authenticate: verify session, refresh") return a.reauthenticateOrFail(w, r, err) } - // redirect to restart middleware-chain following refresh - httputil.Redirect(w, r, urlutil.GetAbsoluteURL(r).String(), http.StatusFound) + next.ServeHTTP(w, r.WithContext(ctx)) return nil } else if err != nil { log.FromRequest(r).Info().Err(err).Msg("authenticate: verify session") @@ -89,15 +92,18 @@ func (a *Authenticate) VerifySession(next http.Handler) http.Handler { }) } -func (a *Authenticate) refresh(w http.ResponseWriter, r *http.Request, s *sessions.State) error { - newSession, err := a.provider.Refresh(r.Context(), s) +func (a *Authenticate) refresh(w http.ResponseWriter, r *http.Request, s *sessions.State) (context.Context, error) { + ctx, span := trace.StartSpan(r.Context(), "authenticate.VerifySession/refresh") + defer span.End() + newSession, err := a.provider.Refresh(ctx, s) if err != nil { - return fmt.Errorf("authenticate: refresh failed: %w", err) + return nil, fmt.Errorf("authenticate: refresh failed: %w", err) } if err := a.sessionStore.SaveSession(w, r, newSession); err != nil { - return fmt.Errorf("authenticate: refresh save failed: %w", err) + return nil, fmt.Errorf("authenticate: refresh save failed: %w", err) } - return nil + // return the new session and add it to the current request context + return sessions.NewContext(ctx, newSession, err), nil } // RobotsTxt handles the /robots.txt route. @@ -158,7 +164,6 @@ func (a *Authenticate) SignIn(w http.ResponseWriter, r *http.Request) error { encSession, err := a.encryptedEncoder.Marshal(newSession) if err != nil { return httputil.NewError(http.StatusBadRequest, err) - } callbackParams.Set(urlutil.QueryRefreshToken, string(encSession)) callbackParams.Set(urlutil.QueryIsProgrammatic, "true") @@ -345,3 +350,27 @@ func (a *Authenticate) RefreshAPI(w http.ResponseWriter, r *http.Request) error w.Write(jsonResponse) return nil } + +// Refresh is called by the proxy service to handle backend session refresh. +// +// NOTE: The actual refresh is actually handled as part of the "VerifySession" +// middleware. This handler is responsible for creating a new route scoped +// session and returning it. +func (a *Authenticate) Refresh(w http.ResponseWriter, r *http.Request) error { + s, err := sessions.FromContext(r.Context()) + if err != nil { + return httputil.NewError(http.StatusBadRequest, err) + } + + routeSession := s.NewSession(r.Host, []string{r.Host, r.FormValue("aud")}) + routeSession.AccessTokenID = s.AccessTokenID + + signedJWT, err := a.sharedEncoder.Marshal(routeSession.RouteSession()) + if err != nil { + return err + } + + w.Header().Set("Content-Type", "application/jwt") // RFC 7519 : 10.3.1 + w.Write(signedJWT) + return nil +} diff --git a/authenticate/handlers_test.go b/authenticate/handlers_test.go index 508a132b1..5555373fe 100644 --- a/authenticate/handlers_test.go +++ b/authenticate/handlers_test.go @@ -11,17 +11,18 @@ import ( "testing" "time" - "github.com/pomerium/pomerium/internal/httputil" - "github.com/pomerium/pomerium/internal/cryptutil" "github.com/pomerium/pomerium/internal/encoding" "github.com/pomerium/pomerium/internal/encoding/mock" "github.com/pomerium/pomerium/internal/frontend" + "github.com/pomerium/pomerium/internal/httputil" "github.com/pomerium/pomerium/internal/identity" "github.com/pomerium/pomerium/internal/sessions" - "github.com/pomerium/pomerium/internal/urlutil" + "github.com/pomerium/pomerium/internal/sessions/cookie" + mstore "github.com/pomerium/pomerium/internal/sessions/mock" "github.com/google/go-cmp/cmp" + "github.com/pomerium/pomerium/internal/urlutil" "golang.org/x/crypto/chacha20poly1305" "golang.org/x/oauth2" "gopkg.in/square/go-jose.v2/jwt" @@ -32,7 +33,7 @@ func testAuthenticate() *Authenticate { auth.RedirectURL, _ = url.Parse("https://auth.example.com/oauth/callback") auth.sharedKey = cryptutil.NewBase64Key() auth.cookieSecret = cryptutil.NewKey() - auth.cookieOptions = &sessions.CookieOptions{Name: "name"} + auth.cookieOptions = &cookie.Options{Name: "name"} auth.templates = template.Must(frontend.NewTemplates()) return &auth } @@ -112,19 +113,19 @@ func TestAuthenticate_SignIn(t *testing.T) { encoder encoding.MarshalUnmarshaler wantCode int }{ - {"good", "https", "corp.example.example", map[string]string{urlutil.QueryRedirectURI: "https://dst.some.example/"}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusFound}, - {"session not valid", "https", "corp.example.example", map[string]string{urlutil.QueryRedirectURI: "https://dst.some.example/"}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(-10 * time.Second)}}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusFound}, - {"bad redirect uri query", "", "corp.example.example", map[string]string{urlutil.QueryRedirectURI: "^^^"}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusBadRequest}, - {"bad marshal", "https", "corp.example.example", map[string]string{urlutil.QueryRedirectURI: "https://dst.some.example/"}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, identity.MockProvider{}, &mock.Encoder{MarshalError: errors.New("error")}, http.StatusBadRequest}, - {"session error", "https", "corp.example.example", map[string]string{urlutil.QueryRedirectURI: "https://dst.some.example/"}, &sessions.MockSessionStore{LoadError: errors.New("error")}, identity.MockProvider{}, &mock.Encoder{}, http.StatusBadRequest}, - {"good with different programmatic redirect", "https", "corp.example.example", map[string]string{urlutil.QueryRedirectURI: "https://dst.some.example/", urlutil.QueryCallbackURI: "https://some.example"}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusFound}, - {"encrypted encoder error", "https", "corp.example.example", map[string]string{urlutil.QueryRedirectURI: "https://dst.some.example/", urlutil.QueryCallbackURI: "https://some.example"}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, identity.MockProvider{}, &mock.Encoder{MarshalError: errors.New("error")}, http.StatusBadRequest}, - {"good with callback uri set", "https", "corp.example.example", map[string]string{urlutil.QueryCallbackURI: "https://some.example/", urlutil.QueryRedirectURI: "https://dst.some.example/"}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusFound}, - {"bad callback uri set", "https", "corp.example.example", map[string]string{urlutil.QueryCallbackURI: "^", urlutil.QueryRedirectURI: "https://dst.some.example/"}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusBadRequest}, - {"good programmatic request", "https", "corp.example.example", map[string]string{urlutil.QueryIsProgrammatic: "true", urlutil.QueryRedirectURI: "https://dst.some.example/"}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusFound}, - {"good additional audience", "https", "corp.example.example", map[string]string{urlutil.QueryForwardAuth: "x.y.z", urlutil.QueryRedirectURI: "https://dst.some.example/"}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusFound}, - {"good user impersonate", "https", "corp.example.example", map[string]string{urlutil.QueryImpersonateAction: "set", urlutil.QueryRedirectURI: "https://dst.some.example/"}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusFound}, - {"bad user impersonate save failure", "https", "corp.example.example", map[string]string{urlutil.QueryImpersonateAction: "set", urlutil.QueryRedirectURI: "https://dst.some.example/"}, &sessions.MockSessionStore{SaveError: errors.New("err"), Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusBadRequest}, + {"good", "https", "corp.example.example", map[string]string{urlutil.QueryRedirectURI: "https://dst.some.example/"}, &mstore.Store{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusFound}, + {"session not valid", "https", "corp.example.example", map[string]string{urlutil.QueryRedirectURI: "https://dst.some.example/"}, &mstore.Store{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(-10 * time.Second)}}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusFound}, + {"bad redirect uri query", "", "corp.example.example", map[string]string{urlutil.QueryRedirectURI: "^^^"}, &mstore.Store{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusBadRequest}, + {"bad marshal", "https", "corp.example.example", map[string]string{urlutil.QueryRedirectURI: "https://dst.some.example/"}, &mstore.Store{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, identity.MockProvider{}, &mock.Encoder{MarshalError: errors.New("error")}, http.StatusBadRequest}, + {"session error", "https", "corp.example.example", map[string]string{urlutil.QueryRedirectURI: "https://dst.some.example/"}, &mstore.Store{LoadError: errors.New("error")}, identity.MockProvider{}, &mock.Encoder{}, http.StatusBadRequest}, + {"good with different programmatic redirect", "https", "corp.example.example", map[string]string{urlutil.QueryRedirectURI: "https://dst.some.example/", urlutil.QueryCallbackURI: "https://some.example"}, &mstore.Store{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusFound}, + {"encrypted encoder error", "https", "corp.example.example", map[string]string{urlutil.QueryRedirectURI: "https://dst.some.example/", urlutil.QueryCallbackURI: "https://some.example"}, &mstore.Store{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, identity.MockProvider{}, &mock.Encoder{MarshalError: errors.New("error")}, http.StatusBadRequest}, + {"good with callback uri set", "https", "corp.example.example", map[string]string{urlutil.QueryCallbackURI: "https://some.example/", urlutil.QueryRedirectURI: "https://dst.some.example/"}, &mstore.Store{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusFound}, + {"bad callback uri set", "https", "corp.example.example", map[string]string{urlutil.QueryCallbackURI: "^", urlutil.QueryRedirectURI: "https://dst.some.example/"}, &mstore.Store{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusBadRequest}, + {"good programmatic request", "https", "corp.example.example", map[string]string{urlutil.QueryIsProgrammatic: "true", urlutil.QueryRedirectURI: "https://dst.some.example/"}, &mstore.Store{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusFound}, + {"good additional audience", "https", "corp.example.example", map[string]string{urlutil.QueryForwardAuth: "x.y.z", urlutil.QueryRedirectURI: "https://dst.some.example/"}, &mstore.Store{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusFound}, + {"good user impersonate", "https", "corp.example.example", map[string]string{urlutil.QueryImpersonateAction: "set", urlutil.QueryRedirectURI: "https://dst.some.example/"}, &mstore.Store{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusFound}, + {"bad user impersonate save failure", "https", "corp.example.example", map[string]string{urlutil.QueryImpersonateAction: "set", urlutil.QueryRedirectURI: "https://dst.some.example/"}, &mstore.Store{SaveError: errors.New("err"), Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusBadRequest}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -136,7 +137,7 @@ func TestAuthenticate_SignIn(t *testing.T) { sharedEncoder: tt.encoder, encryptedEncoder: tt.encoder, sharedCipher: aead, - cookieOptions: &sessions.CookieOptions{ + cookieOptions: &cookie.Options{ Name: "cookie", Domain: "foo", }, @@ -186,10 +187,10 @@ func TestAuthenticate_SignOut(t *testing.T) { wantCode int wantBody string }{ - {"good post", http.MethodPost, nil, "https://corp.pomerium.io/", "sig", "ts", identity.MockProvider{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, http.StatusFound, ""}, - {"failed revoke", http.MethodPost, nil, "https://corp.pomerium.io/", "sig", "ts", identity.MockProvider{RevokeError: errors.New("OH NO")}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, http.StatusBadRequest, "{\"Status\":400,\"Error\":\"Bad Request: OH NO\"}\n"}, - {"load session error", http.MethodPost, errors.New("error"), "https://corp.pomerium.io/", "sig", "ts", identity.MockProvider{RevokeError: errors.New("OH NO")}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, http.StatusBadRequest, "{\"Status\":400,\"Error\":\"Bad Request: error\"}\n"}, - {"bad redirect uri", http.MethodPost, nil, "corp.pomerium.io/", "sig", "ts", identity.MockProvider{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, http.StatusBadRequest, "{\"Status\":400,\"Error\":\"Bad Request: corp.pomerium.io/ url does contain a valid scheme\"}\n"}, + {"good post", http.MethodPost, nil, "https://corp.pomerium.io/", "sig", "ts", identity.MockProvider{}, &mstore.Store{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, http.StatusFound, ""}, + {"failed revoke", http.MethodPost, nil, "https://corp.pomerium.io/", "sig", "ts", identity.MockProvider{RevokeError: errors.New("OH NO")}, &mstore.Store{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, http.StatusBadRequest, "{\"Status\":400,\"Error\":\"Bad Request: OH NO\"}\n"}, + {"load session error", http.MethodPost, errors.New("error"), "https://corp.pomerium.io/", "sig", "ts", identity.MockProvider{RevokeError: errors.New("OH NO")}, &mstore.Store{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, http.StatusBadRequest, "{\"Status\":400,\"Error\":\"Bad Request: error\"}\n"}, + {"bad redirect uri", http.MethodPost, nil, "corp.pomerium.io/", "sig", "ts", identity.MockProvider{}, &mstore.Store{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, http.StatusBadRequest, "{\"Status\":400,\"Error\":\"Bad Request: corp.pomerium.io/ url does contain a valid scheme\"}\n"}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -247,19 +248,19 @@ func TestAuthenticate_OAuthCallback(t *testing.T) { want string wantCode int }{ - {"good", http.MethodGet, time.Now().Unix(), "", "", "", "", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, "https://corp.pomerium.io", http.StatusFound}, - {"failed authenticate", http.MethodGet, time.Now().Unix(), "", "", "", "", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateError: errors.New("error")}, "", http.StatusInternalServerError}, - {"failed save session", http.MethodGet, time.Now().Unix(), "", "", "", "", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &sessions.MockSessionStore{SaveError: errors.New("error")}, identity.MockProvider{AuthenticateResponse: sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, "", http.StatusInternalServerError}, - {"provider returned error", http.MethodGet, time.Now().Unix(), "", "", "", "idp error", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, "", http.StatusBadRequest}, - {"empty code", http.MethodGet, time.Now().Unix(), "", "", "", "", "", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, "", http.StatusBadRequest}, - {"invalid redirect uri", http.MethodGet, time.Now().Unix(), "", "", "", "", "code", "corp.pomerium.io", "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, "", http.StatusBadRequest}, - {"bad redirect uri", http.MethodGet, time.Now().Unix(), "", "", "", "", "code", "http://^^^", "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, "https://corp.pomerium.io", http.StatusBadRequest}, - {"bad timing - too soon", http.MethodGet, time.Now().Add(1 * time.Hour).Unix(), "", "", "", "", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, "https://corp.pomerium.io", http.StatusBadRequest}, - {"bad timing - expired", http.MethodGet, time.Now().Add(-1 * time.Hour).Unix(), "", "", "", "", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, "https://corp.pomerium.io", http.StatusBadRequest}, - {"bad base64", http.MethodGet, time.Now().Unix(), "", "", "^", "", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, "https://corp.pomerium.io", http.StatusBadRequest}, - {"too many seperators", http.MethodGet, time.Now().Unix(), "", "", "|ok|now|what", "", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, "https://corp.pomerium.io", http.StatusBadRequest}, - {"bad hmac", http.MethodGet, time.Now().Unix(), "", "NOTMAC", "", "", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, "https://corp.pomerium.io", http.StatusBadRequest}, - {"bad hmac", http.MethodGet, time.Now().Unix(), base64.URLEncoding.EncodeToString([]byte("malformed_state")), "", "", "", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, "https://corp.pomerium.io", http.StatusBadRequest}, + {"good", http.MethodGet, time.Now().Unix(), "", "", "", "", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &mstore.Store{}, identity.MockProvider{AuthenticateResponse: sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, "https://corp.pomerium.io", http.StatusFound}, + {"failed authenticate", http.MethodGet, time.Now().Unix(), "", "", "", "", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &mstore.Store{}, identity.MockProvider{AuthenticateError: errors.New("error")}, "", http.StatusInternalServerError}, + {"failed save session", http.MethodGet, time.Now().Unix(), "", "", "", "", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &mstore.Store{SaveError: errors.New("error")}, identity.MockProvider{AuthenticateResponse: sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, "", http.StatusInternalServerError}, + {"provider returned error", http.MethodGet, time.Now().Unix(), "", "", "", "idp error", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &mstore.Store{}, identity.MockProvider{AuthenticateResponse: sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, "", http.StatusBadRequest}, + {"empty code", http.MethodGet, time.Now().Unix(), "", "", "", "", "", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &mstore.Store{}, identity.MockProvider{AuthenticateResponse: sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, "", http.StatusBadRequest}, + {"invalid redirect uri", http.MethodGet, time.Now().Unix(), "", "", "", "", "code", "corp.pomerium.io", "https://authenticate.pomerium.io", &mstore.Store{}, identity.MockProvider{AuthenticateResponse: sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, "", http.StatusBadRequest}, + {"bad redirect uri", http.MethodGet, time.Now().Unix(), "", "", "", "", "code", "http://^^^", "https://authenticate.pomerium.io", &mstore.Store{}, identity.MockProvider{AuthenticateResponse: sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, "https://corp.pomerium.io", http.StatusBadRequest}, + {"bad timing - too soon", http.MethodGet, time.Now().Add(1 * time.Hour).Unix(), "", "", "", "", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &mstore.Store{}, identity.MockProvider{AuthenticateResponse: sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, "https://corp.pomerium.io", http.StatusBadRequest}, + {"bad timing - expired", http.MethodGet, time.Now().Add(-1 * time.Hour).Unix(), "", "", "", "", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &mstore.Store{}, identity.MockProvider{AuthenticateResponse: sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, "https://corp.pomerium.io", http.StatusBadRequest}, + {"bad base64", http.MethodGet, time.Now().Unix(), "", "", "^", "", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &mstore.Store{}, identity.MockProvider{AuthenticateResponse: sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, "https://corp.pomerium.io", http.StatusBadRequest}, + {"too many seperators", http.MethodGet, time.Now().Unix(), "", "", "|ok|now|what", "", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &mstore.Store{}, identity.MockProvider{AuthenticateResponse: sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, "https://corp.pomerium.io", http.StatusBadRequest}, + {"bad hmac", http.MethodGet, time.Now().Unix(), "", "NOTMAC", "", "", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &mstore.Store{}, identity.MockProvider{AuthenticateResponse: sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, "https://corp.pomerium.io", http.StatusBadRequest}, + {"bad hmac", http.MethodGet, time.Now().Unix(), base64.URLEncoding.EncodeToString([]byte("malformed_state")), "", "", "", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &mstore.Store{}, identity.MockProvider{AuthenticateResponse: sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, "https://corp.pomerium.io", http.StatusBadRequest}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -326,12 +327,12 @@ func TestAuthenticate_SessionValidatorMiddleware(t *testing.T) { wantStatus int }{ - {"good", nil, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, nil, identity.MockProvider{RefreshResponse: sessions.State{AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Minute)}}}, http.StatusOK}, - {"invalid session", nil, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, errors.New("hi"), identity.MockProvider{}, http.StatusFound}, - {"good refresh expired", nil, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}}, sessions.ErrExpired, identity.MockProvider{RefreshResponse: sessions.State{AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Minute)}}}, http.StatusFound}, - {"expired,refresh error", nil, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}}, sessions.ErrExpired, identity.MockProvider{RefreshError: errors.New("error")}, http.StatusFound}, - {"expired,save error", nil, &sessions.MockSessionStore{SaveError: errors.New("error"), Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}}, sessions.ErrExpired, identity.MockProvider{RefreshResponse: sessions.State{AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Minute)}}}, http.StatusFound}, - {"expired XHR,refresh error", map[string]string{"X-Requested-With": "XmlHttpRequest"}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}}, sessions.ErrExpired, identity.MockProvider{RefreshError: errors.New("error")}, http.StatusUnauthorized}, + {"good", nil, &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, nil, identity.MockProvider{RefreshResponse: sessions.State{AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Minute)}}}, http.StatusOK}, + {"invalid session", nil, &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, errors.New("hi"), identity.MockProvider{}, http.StatusFound}, + {"good refresh expired", nil, &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}}, sessions.ErrExpired, identity.MockProvider{RefreshResponse: sessions.State{AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Minute)}}}, http.StatusOK}, + {"expired,refresh error", nil, &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}}, sessions.ErrExpired, identity.MockProvider{RefreshError: errors.New("error")}, http.StatusFound}, + {"expired,save error", nil, &mstore.Store{SaveError: errors.New("error"), Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}}, sessions.ErrExpired, identity.MockProvider{RefreshResponse: sessions.State{AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Minute)}}}, http.StatusFound}, + {"expired XHR,refresh error", map[string]string{"X-Requested-With": "XmlHttpRequest"}, &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}}, sessions.ErrExpired, identity.MockProvider{RefreshError: errors.New("error")}, http.StatusUnauthorized}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -384,11 +385,11 @@ func TestAuthenticate_RefreshAPI(t *testing.T) { wantStatus int }{ - {"good", &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, nil, identity.MockProvider{RefreshResponse: sessions.State{AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Minute)}}}, mock.Encoder{MarshalResponse: []byte("ok")}, mock.Encoder{MarshalResponse: []byte("ok")}, http.StatusOK}, - {"refresh error", &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, nil, identity.MockProvider{RefreshError: errors.New("error")}, mock.Encoder{MarshalResponse: []byte("ok")}, mock.Encoder{MarshalResponse: []byte("ok")}, http.StatusInternalServerError}, - {"session is not refreshable error", &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, errors.New("session error"), identity.MockProvider{RefreshResponse: sessions.State{AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Minute)}}}, mock.Encoder{MarshalResponse: []byte("ok")}, mock.Encoder{MarshalResponse: []byte("ok")}, http.StatusBadRequest}, - {"secret encoder failed", &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, nil, identity.MockProvider{RefreshResponse: sessions.State{AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Minute)}}}, mock.Encoder{MarshalError: errors.New("error")}, mock.Encoder{MarshalResponse: []byte("ok")}, http.StatusInternalServerError}, - {"shared encoder failed", &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, nil, identity.MockProvider{RefreshResponse: sessions.State{AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Minute)}}}, mock.Encoder{MarshalResponse: []byte("ok")}, mock.Encoder{MarshalError: errors.New("error")}, http.StatusInternalServerError}, + {"good", &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, nil, identity.MockProvider{RefreshResponse: sessions.State{AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Minute)}}}, mock.Encoder{MarshalResponse: []byte("ok")}, mock.Encoder{MarshalResponse: []byte("ok")}, http.StatusOK}, + {"refresh error", &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, nil, identity.MockProvider{RefreshError: errors.New("error")}, mock.Encoder{MarshalResponse: []byte("ok")}, mock.Encoder{MarshalResponse: []byte("ok")}, http.StatusInternalServerError}, + {"session is not refreshable error", &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, errors.New("session error"), identity.MockProvider{RefreshResponse: sessions.State{AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Minute)}}}, mock.Encoder{MarshalResponse: []byte("ok")}, mock.Encoder{MarshalResponse: []byte("ok")}, http.StatusBadRequest}, + {"secret encoder failed", &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, nil, identity.MockProvider{RefreshResponse: sessions.State{AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Minute)}}}, mock.Encoder{MarshalError: errors.New("error")}, mock.Encoder{MarshalResponse: []byte("ok")}, http.StatusInternalServerError}, + {"shared encoder failed", &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, nil, identity.MockProvider{RefreshResponse: sessions.State{AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Minute)}}}, mock.Encoder{MarshalResponse: []byte("ok")}, mock.Encoder{MarshalError: errors.New("error")}, http.StatusInternalServerError}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -423,3 +424,54 @@ func TestAuthenticate_RefreshAPI(t *testing.T) { }) } } +func TestAuthenticate_Refresh(t *testing.T) { + t.Parallel() + tests := []struct { + name string + + session sessions.SessionStore + ctxError error + + provider identity.Authenticator + secretEncoder encoding.MarshalUnmarshaler + sharedEncoder encoding.MarshalUnmarshaler + + wantStatus int + }{ + {"good", &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, nil, identity.MockProvider{RefreshResponse: sessions.State{AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Minute)}}}, mock.Encoder{MarshalResponse: []byte("ok")}, mock.Encoder{MarshalResponse: []byte("ok")}, http.StatusOK}, + {"bad session", &mstore.Store{}, errors.New("err"), identity.MockProvider{RefreshResponse: sessions.State{AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Minute)}}}, mock.Encoder{MarshalResponse: []byte("ok")}, mock.Encoder{MarshalResponse: []byte("ok")}, http.StatusBadRequest}, + {"encoder error", &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, nil, identity.MockProvider{RefreshResponse: sessions.State{AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Minute)}}}, mock.Encoder{MarshalResponse: []byte("ok")}, mock.Encoder{MarshalError: errors.New("err")}, http.StatusInternalServerError}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + aead, err := chacha20poly1305.NewX(cryptutil.NewKey()) + if err != nil { + t.Fatal(err) + } + a := Authenticate{ + sharedKey: cryptutil.NewBase64Key(), + cookieSecret: cryptutil.NewKey(), + RedirectURL: uriParseHelper("https://authenticate.corp.beyondperimeter.com"), + encryptedEncoder: tt.secretEncoder, + sharedEncoder: tt.sharedEncoder, + sessionStore: tt.session, + provider: tt.provider, + cookieCipher: aead, + } + r := httptest.NewRequest("GET", "/", nil) + state, _ := tt.session.LoadSession(r) + ctx := r.Context() + ctx = sessions.NewContext(ctx, state, tt.ctxError) + r = r.WithContext(ctx) + + r.Header.Set("Accept", "application/json") + + w := httptest.NewRecorder() + httputil.HandlerFunc(a.Refresh).ServeHTTP(w, r) + if status := w.Code; status != tt.wantStatus { + t.Errorf("VerifySession() error = %v, wantErr %v\n%v", w.Result().StatusCode, tt.wantStatus, w.Body.String()) + + } + }) + } +} diff --git a/config/options.go b/config/options.go index 85a27caf9..e884e80d2 100644 --- a/config/options.go +++ b/config/options.go @@ -11,12 +11,13 @@ import ( "strings" "time" - "github.com/fsnotify/fsnotify" "github.com/pomerium/pomerium/internal/cryptutil" "github.com/pomerium/pomerium/internal/log" "github.com/pomerium/pomerium/internal/telemetry/metrics" "github.com/pomerium/pomerium/internal/urlutil" + "github.com/cespare/xxhash/v2" + "github.com/fsnotify/fsnotify" "github.com/mitchellh/hashstructure" "github.com/spf13/viper" "gopkg.in/yaml.v2" @@ -477,7 +478,7 @@ type OptionsUpdater interface { // Checksum returns the checksum of the current options struct func (o *Options) Checksum() string { - hash, err := hashstructure.Hash(o, nil) + hash, err := hashstructure.Hash(o, &hashstructure.HashOptions{Hasher: xxhash.New()}) if err != nil { log.Warn().Err(err).Msg("config: checksum failure") return "no checksum available" diff --git a/go.mod b/go.mod index ff8f61d58..1eb7a34e9 100644 --- a/go.mod +++ b/go.mod @@ -6,9 +6,9 @@ require ( cloud.google.com/go v0.49.0 // indirect contrib.go.opencensus.io/exporter/jaeger v0.2.0 contrib.go.opencensus.io/exporter/prometheus v0.1.0 - github.com/cespare/xxhash/v2 v2.1.1 // indirect + github.com/cespare/xxhash/v2 v2.1.1 github.com/fsnotify/fsnotify v1.4.7 - github.com/golang/groupcache v0.0.0-20191027212112-611e8accdfc9 // indirect + github.com/golang/groupcache v0.0.0-20191227052852-215e87163ea7 github.com/golang/mock v1.3.1 github.com/golang/protobuf v1.3.2 github.com/google/go-cmp v0.3.1 diff --git a/go.sum b/go.sum index 192302142..b7a313eb2 100644 --- a/go.sum +++ b/go.sum @@ -71,6 +71,8 @@ github.com/golang/groupcache v0.0.0-20190702054246-869f871628b6 h1:ZgQEtGgCBiWRM github.com/golang/groupcache v0.0.0-20190702054246-869f871628b6/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= github.com/golang/groupcache v0.0.0-20191027212112-611e8accdfc9 h1:uHTyIjqVhYRhLbJ8nIiOJHkEZZ+5YoOsAbD3sk82NiE= github.com/golang/groupcache v0.0.0-20191027212112-611e8accdfc9/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= +github.com/golang/groupcache v0.0.0-20191227052852-215e87163ea7 h1:5ZkaAPbicIKTF2I64qf5Fh8Aa83Q/dnOafMYV0OMwjA= +github.com/golang/groupcache v0.0.0-20191227052852-215e87163ea7/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A= github.com/golang/mock v1.2.0 h1:28o5sBqPkBsMGnC6b4MvE2TzSr5/AT4c/1fLqVGIwlk= github.com/golang/mock v1.2.0/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A= @@ -162,6 +164,7 @@ github.com/prometheus/client_golang v0.9.3/go.mod h1:/TN21ttK/J9q6uSwhBd54HahCDf github.com/prometheus/client_golang v1.0.0/go.mod h1:db9x61etRT2tGnBNRi70OPL5FsnadC4Ky3P0J6CfImo= github.com/prometheus/client_golang v1.2.1 h1:JnMpQc6ppsNgw9QPAGF6Dod479itz7lvlsMzzNayLOI= github.com/prometheus/client_golang v1.2.1/go.mod h1:XMU6Z2MjaRKVu/dC1qupJI9SiNkDYzz3xecMgSW/F+U= +github.com/prometheus/client_golang v1.3.0 h1:miYCvYqFXtl/J9FIy8eNpBfYthAEFg+Ys0XyUVEcDsc= github.com/prometheus/client_model v0.0.0-20180712105110-5c3871d89910/go.mod h1:MbSGuTsp3dbXC40dX6PRTWyKYBIrTGTE9sqQNg2J8bo= github.com/prometheus/client_model v0.0.0-20190129233127-fd36f4220a90 h1:S/YWwWx/RA8rT8tKFRuGUZhuA90OyIBpPCXkcbwU8DE= github.com/prometheus/client_model v0.0.0-20190129233127-fd36f4220a90/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= @@ -214,6 +217,7 @@ github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA= github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= github.com/spf13/viper v1.5.0 h1:GpsTwfsQ27oS/Aha/6d1oD7tpKIqWnOA6tgOX9HHkt4= github.com/spf13/viper v1.5.0/go.mod h1:AkYRkVJF8TkSG/xet6PzXX+l39KhhXa2pdqVSxnTcn4= +github.com/spf13/viper v1.6.1 h1:VPZzIkznI1YhVMRi6vNFLHSwhnhReBfgTxIPccpfdZk= github.com/spf13/viper v1.6.1/go.mod h1:t3iDnF5Jlj76alVNuyFBk5oUMCvsrkbvZK0WQdfDi5k= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= @@ -384,6 +388,7 @@ gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8 gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33q108Sa+fhmuc+sWQYwY= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI= +gopkg.in/ini.v1 v1.51.0 h1:AQvPpx3LzTDM0AjnIRlVFwFFGC+npRopjZxLJj6gdno= gopkg.in/ini.v1 v1.51.0/go.mod h1:pNLf8WUiyNEtQjuu5G5vTm06TEv9tsIgeAvK8hOrP4k= gopkg.in/resty.v1 v1.12.0/go.mod h1:mDo4pnntr5jdWRML875a/NmxYqAlA73dVijT2AXvQQo= gopkg.in/square/go-jose.v2 v2.4.0 h1:0kXPskUMGAXXWJlP05ktEMOV0vmzFQUWw6d+aZJQU8A= diff --git a/internal/encoding/mock/mock_encoder.go b/internal/encoding/mock/mock_encoder.go index a28d0f861..f757fc521 100644 --- a/internal/encoding/mock/mock_encoder.go +++ b/internal/encoding/mock/mock_encoder.go @@ -1,5 +1,13 @@ package mock // import "github.com/pomerium/pomerium/internal/encoding/mock" +import ( + "github.com/pomerium/pomerium/internal/encoding" +) + +var _ encoding.MarshalUnmarshaler = &Encoder{} +var _ encoding.Marshaler = &Encoder{} +var _ encoding.Unmarshaler = &Encoder{} + // Encoder MockCSRFStore is a mock implementation of Cipher. type Encoder struct { MarshalResponse []byte diff --git a/internal/httputil/client.go b/internal/httputil/client.go index 535bd0d2f..175660d29 100644 --- a/internal/httputil/client.go +++ b/internal/httputil/client.go @@ -8,23 +8,21 @@ import ( "fmt" "io" "io/ioutil" - "net" "net/http" "net/url" "time" + + "go.opencensus.io/plugin/ochttp" ) // ErrTokenRevoked signifies a token revokation or expiration error var ErrTokenRevoked = errors.New("token expired or revoked") -var httpClient = &http.Client{ - Timeout: time.Second * 5, - Transport: &http.Transport{ - Dial: (&net.Dialer{ - Timeout: 2 * time.Second, - }).Dial, - TLSHandshakeTimeout: 2 * time.Second, - }, +// DefaultClient avoids leaks by setting an upper limit for timeouts. +var DefaultClient = &http.Client{ + Timeout: 1 * time.Minute, + //todo(bdd): incorporate metrics.HTTPMetricsRoundTripper + Transport: &ochttp.Transport{}, } // Client provides a simple helper interface to make HTTP requests @@ -36,9 +34,11 @@ func Client(ctx context.Context, method, endpoint, userAgent string, headers map case http.MethodGet: // error checking skipped because we are just parsing in // order to make a copy of an existing URL - u, _ := url.Parse(endpoint) - u.RawQuery = params.Encode() - endpoint = u.String() + if params != nil { + u, _ := url.Parse(endpoint) + u.RawQuery = params.Encode() + endpoint = u.String() + } default: return fmt.Errorf(http.StatusText(http.StatusBadRequest)) } @@ -52,7 +52,7 @@ func Client(ctx context.Context, method, endpoint, userAgent string, headers map req.Header.Set(k, v) } - resp, err := httpClient.Do(req) + resp, err := DefaultClient.Do(req) if err != nil { return err } @@ -79,7 +79,6 @@ func Client(ctx context.Context, method, endpoint, userAgent string, headers map return fmt.Errorf(http.StatusText(resp.StatusCode)) } } - if response != nil { err := json.Unmarshal(respBody, &response) if err != nil { diff --git a/internal/sessions/cache/cache_store.go b/internal/sessions/cache/cache_store.go new file mode 100644 index 000000000..2c9748be4 --- /dev/null +++ b/internal/sessions/cache/cache_store.go @@ -0,0 +1,131 @@ +package cache // import "github.com/pomerium/pomerium/internal/sessions/cache" + +import ( + "context" + "errors" + "fmt" + "net/http" + + "github.com/golang/groupcache" + + "github.com/pomerium/pomerium/internal/encoding" + "github.com/pomerium/pomerium/internal/log" + "github.com/pomerium/pomerium/internal/sessions" +) + +var _ sessions.SessionStore = &Store{} +var _ sessions.SessionLoader = &Store{} + +const ( + defaultQueryParamKey = "ati" +) + +// Store implements the session store interface using a distributed cache. +type Store struct { + name string + encoder encoding.Marshaler + decoder encoding.Unmarshaler + + cache *groupcache.Group + wrappedStore sessions.SessionStore +} + +// defaultCacheSize is ~10MB +var defaultCacheSize int64 = 10 << 20 + +// NewStore creates a new session store built on the distributed caching library +// groupcache. On a cache miss, the cache store attempts to fallback to another +// SessionStore implementation. +func NewStore(enc encoding.MarshalUnmarshaler, wrappedStore sessions.SessionStore, name string) *Store { + store := &Store{ + name: name, + encoder: enc, + decoder: enc, + wrappedStore: wrappedStore, + } + + store.cache = groupcache.NewGroup(name, defaultCacheSize, groupcache.GetterFunc( + func(ctx context.Context, id string, dest groupcache.Sink) error { + // fill the cache with session set as part of the request + // context set previously as part of SaveSession. + b := fromContext(ctx) + if len(b) == 0 { + return fmt.Errorf("sessions/cache: cannot fill key %s from ctx", id) + } + if err := dest.SetBytes(b); err != nil { + return fmt.Errorf("sessions/cache: sink error %w", err) + } + return nil + }, + )) + + return store +} + +// LoadSession implements SessionLoaders's LoadSession method for cache store. +func (s *Store) LoadSession(r *http.Request) (*sessions.State, error) { + // look for our cache's key in the default query param + sessionID := r.URL.Query().Get(defaultQueryParamKey) + if sessionID == "" { + // if unset, fallback to default cache store + log.FromRequest(r).Debug().Msg("sessions/cache: no query param, trying wrapped loader") + return s.wrappedStore.LoadSession(r) + } + + var b []byte + if err := s.cache.Get(r.Context(), sessionID, groupcache.AllocatingByteSliceSink(&b)); err != nil { + log.FromRequest(r).Debug().Err(err).Msg("sessions/cache: miss, trying wrapped loader") + return s.wrappedStore.LoadSession(r) + } + var session sessions.State + if err := s.decoder.Unmarshal(b, &session); err != nil { + log.FromRequest(r).Error().Err(err).Msg("sessions/cache: unmarshal") + return nil, sessions.ErrMalformed + } + return &session, nil +} + +// ClearSession implements SessionStore's ClearSession for the cache store. +// Since group cache has no explicit eviction, we just call the wrapped +// store's ClearSession method here. +func (s *Store) ClearSession(w http.ResponseWriter, r *http.Request) { + s.wrappedStore.ClearSession(w, r) +} + +// SaveSession implements SessionStore's SaveSession method for cache store. +func (s *Store) SaveSession(w http.ResponseWriter, r *http.Request, x interface{}) error { + err := s.wrappedStore.SaveSession(w, r, x) + if err != nil { + return fmt.Errorf("sessions/cache: wrapped store save error %w", err) + } + + state, ok := x.(*sessions.State) + if !ok { + return errors.New("internal/sessions: cannot cache non state type") + } + + data, err := s.encoder.Marshal(&state) + if err != nil { + return fmt.Errorf("sessions/cache: marshal %w", err) + } + + ctx := newContext(r.Context(), data) + var b []byte + return s.cache.Get(ctx, state.AccessTokenID, groupcache.AllocatingByteSliceSink(&b)) +} + +var sessionCtxKey = &contextKey{"PomeriumCachedSessionBytes"} + +type contextKey struct { + name string +} + +func newContext(ctx context.Context, b []byte) context.Context { + ctx = context.WithValue(ctx, sessionCtxKey, b) + return ctx +} + +func fromContext(ctx context.Context) []byte { + b, _ := ctx.Value(sessionCtxKey).([]byte) + return b +} diff --git a/internal/sessions/cache/cache_store_test.go b/internal/sessions/cache/cache_store_test.go new file mode 100644 index 000000000..6949338a0 --- /dev/null +++ b/internal/sessions/cache/cache_store_test.go @@ -0,0 +1,133 @@ +package cache + +import ( + "fmt" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/google/go-cmp/cmp" + "github.com/pomerium/pomerium/internal/cryptutil" + "github.com/pomerium/pomerium/internal/encoding/ecjson" + "github.com/pomerium/pomerium/internal/sessions" + "github.com/pomerium/pomerium/internal/sessions/cookie" + "gopkg.in/square/go-jose.v2/jwt" +) + +func testAuthorizer(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, err := sessions.FromContext(r.Context()) + if err != nil { + http.Error(w, err.Error(), http.StatusUnauthorized) + return + } + next.ServeHTTP(w, r) + }) +} + +func TestVerifier(t *testing.T) { + fnh := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/plain; charset=utf-8") + fmt.Fprint(w, http.StatusText(http.StatusOK)) + w.WriteHeader(http.StatusOK) + }) + + tests := []struct { + name string + skipSave bool + cacheSize int64 + state sessions.State + + wantBody string + wantStatus int + }{ + {"good", false, 1 << 10, sessions.State{AccessTokenID: cryptutil.NewBase64Key(), Email: "user@pomerium.io", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}, http.StatusText(http.StatusOK), http.StatusOK}, + {"expired", false, 1 << 10, sessions.State{AccessTokenID: cryptutil.NewBase64Key(), Email: "user@pomerium.io", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}, "internal/sessions: validation failed, token is expired (exp)\n", http.StatusUnauthorized}, + {"empty", false, 1 << 10, sessions.State{AccessTokenID: "", Email: "user@pomerium.io", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}, "internal/sessions: session is not found\n", http.StatusUnauthorized}, + {"miss", true, 1 << 10, sessions.State{AccessTokenID: cryptutil.NewBase64Key(), Email: "user@pomerium.io", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}, "internal/sessions: session is not found\n", http.StatusUnauthorized}, + {"cache eviction", false, 1, sessions.State{AccessTokenID: cryptutil.NewBase64Key(), Email: "user@pomerium.io", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}, "internal/sessions: session is not found\n", http.StatusUnauthorized}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + defaultCacheSize = tt.cacheSize + cipher, err := cryptutil.NewAEADCipherFromBase64(cryptutil.NewBase64Key()) + encoder := ecjson.New(cipher) + if err != nil { + t.Fatal(err) + } + cs, err := cookie.NewStore(&cookie.Options{Name: t.Name()}, encoder) + if err != nil { + t.Fatal(err) + } + cacheStore := NewStore(encoder, cs, t.Name()) + + r := httptest.NewRequest(http.MethodGet, "/", nil) + q := r.URL.Query() + + q.Set(defaultQueryParamKey, tt.state.AccessTokenID) + r.URL.RawQuery = q.Encode() + r.Header.Set("Accept", "application/json") + w := httptest.NewRecorder() + + got := sessions.RetrieveSession(cacheStore)(testAuthorizer((fnh))) + + if !tt.skipSave { + cacheStore.SaveSession(w, r, &tt.state) + } + + for i := 1; i <= 10; i++ { + s := tt.state + s.AccessTokenID = cryptutil.NewBase64Key() + cacheStore.SaveSession(w, r, s) + } + + got.ServeHTTP(w, r) + + gotBody := w.Body.String() + gotStatus := w.Result().StatusCode + + if diff := cmp.Diff(gotBody, tt.wantBody); diff != "" { + t.Errorf("RetrieveSession() = %v", diff) + } + if diff := cmp.Diff(gotStatus, tt.wantStatus); diff != "" { + t.Errorf("RetrieveSession() = %v", diff) + } + }) + } +} + +func TestStore_SaveSession(t *testing.T) { + + tests := []struct { + name string + x interface{} + wantErr bool + }{ + {"good", &sessions.State{AccessTokenID: cryptutil.NewBase64Key(), Email: "user@pomerium.io", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}, false}, + {"bad type", "bad type!", true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cipher, err := cryptutil.NewAEADCipherFromBase64(cryptutil.NewBase64Key()) + encoder := ecjson.New(cipher) + if err != nil { + t.Fatal(err) + } + cs, err := cookie.NewStore(&cookie.Options{ + Name: "_pomerium", + }, encoder) + if err != nil { + t.Fatal(err) + } + cacheStore := NewStore(encoder, cs, t.Name()) + r := httptest.NewRequest(http.MethodGet, "/", nil) + r.Header.Set("Accept", "application/json") + w := httptest.NewRecorder() + + if err := cacheStore.SaveSession(w, r, tt.x); (err != nil) != tt.wantErr { + t.Errorf("Store.SaveSession() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} diff --git a/internal/sessions/cookie_store.go b/internal/sessions/cookie/cookie_store.go similarity index 65% rename from internal/sessions/cookie_store.go rename to internal/sessions/cookie/cookie_store.go index 8b1b5d255..bf3d87769 100644 --- a/internal/sessions/cookie_store.go +++ b/internal/sessions/cookie/cookie_store.go @@ -1,4 +1,4 @@ -package sessions // import "github.com/pomerium/pomerium/internal/sessions" +package cookie // import "github.com/pomerium/pomerium/internal/sessions/cookie" import ( "errors" @@ -8,8 +8,15 @@ import ( "time" "github.com/pomerium/pomerium/internal/encoding" + "github.com/pomerium/pomerium/internal/sessions" ) +var _ sessions.SessionStore = &Store{} +var _ sessions.SessionLoader = &Store{} + +// timeNow is time.Now but pulled out as a variable for tests. +var timeNow = time.Now + const ( // ChunkedCanaryByte is the byte value used as a canary prefix to distinguish if // the cookie is multi-part or not. This constant *should not* be valid @@ -25,8 +32,8 @@ const ( MaxNumChunks = 5 ) -// CookieStore implements the session store interface for session cookies. -type CookieStore struct { +// Store implements the session store interface for session cookies. +type Store struct { Name string Domain string Expire time.Duration @@ -37,8 +44,8 @@ type CookieStore struct { decoder encoding.Unmarshaler } -// CookieOptions holds options for CookieStore -type CookieOptions struct { +// Options holds options for Store +type Options struct { Name string Domain string Expire time.Duration @@ -46,8 +53,9 @@ type CookieOptions struct { Secure bool } -// NewCookieStore returns a new session with ciphers for each of the cookie secrets -func NewCookieStore(opts *CookieOptions, encoder encoding.MarshalUnmarshaler) (*CookieStore, error) { +// NewStore returns a new store that implements the SessionStore interface +// using http cookies. +func NewStore(opts *Options, encoder encoding.MarshalUnmarshaler) (sessions.SessionStore, error) { cs, err := NewCookieLoader(opts, encoder) if err != nil { return nil, err @@ -56,12 +64,13 @@ func NewCookieStore(opts *CookieOptions, encoder encoding.MarshalUnmarshaler) (* return cs, nil } -// NewCookieLoader returns a new session with ciphers for each of the cookie secrets -func NewCookieLoader(opts *CookieOptions, dencoder encoding.Unmarshaler) (*CookieStore, error) { +// NewCookieLoader returns a new store that implements the SessionLoader +// interface using http cookies. +func NewCookieLoader(opts *Options, dencoder encoding.Unmarshaler) (*Store, error) { if dencoder == nil { return nil, fmt.Errorf("internal/sessions: dencoder cannot be nil") } - cs, err := newCookieStore(opts) + cs, err := newStore(opts) if err != nil { return nil, err } @@ -69,12 +78,12 @@ func NewCookieLoader(opts *CookieOptions, dencoder encoding.Unmarshaler) (*Cooki return cs, nil } -func newCookieStore(opts *CookieOptions) (*CookieStore, error) { +func newStore(opts *Options) (*Store, error) { if opts.Name == "" { return nil, fmt.Errorf("internal/sessions: cookie name cannot be empty") } - return &CookieStore{ + return &Store{ Name: opts.Name, Secure: opts.Secure, HTTPOnly: opts.HTTPOnly, @@ -83,7 +92,7 @@ func newCookieStore(opts *CookieOptions) (*CookieStore, error) { }, nil } -func (cs *CookieStore) makeCookie(value string) *http.Cookie { +func (cs *Store) makeCookie(value string) *http.Cookie { return &http.Cookie{ Name: cs.Name, Value: value, @@ -96,7 +105,7 @@ func (cs *CookieStore) makeCookie(value string) *http.Cookie { } // ClearSession clears the session cookie from a request -func (cs *CookieStore) ClearSession(w http.ResponseWriter, r *http.Request) { +func (cs *Store) ClearSession(w http.ResponseWriter, r *http.Request) { c := cs.makeCookie("") c.MaxAge = -1 c.Expires = timeNow().Add(-time.Hour) @@ -115,51 +124,51 @@ func getCookies(r *http.Request, name string) []*http.Cookie { } // LoadSession returns a State from the cookie in the request. -func (cs *CookieStore) LoadSession(r *http.Request) (*State, error) { +func (cs *Store) LoadSession(r *http.Request) (*sessions.State, error) { cookies := getCookies(r, cs.Name) if len(cookies) == 0 { - return nil, ErrNoSessionFound + return nil, sessions.ErrNoSessionFound } for _, cookie := range cookies { data := loadChunkedCookie(r, cookie) - session := &State{} + session := &sessions.State{} err := cs.decoder.Unmarshal([]byte(data), session) if err == nil { return session, nil } } - return nil, ErrMalformed + return nil, sessions.ErrMalformed } // SaveSession saves a session state to a request's cookie store. -func (cs *CookieStore) SaveSession(w http.ResponseWriter, _ *http.Request, x interface{}) error { +func (cs *Store) SaveSession(w http.ResponseWriter, _ *http.Request, x interface{}) error { var value string - if cs.encoder != nil { + switch v := x.(type) { + case []byte: + value = string(v) + case string: + value = v + default: + if cs.encoder == nil { + return errors.New("internal/sessions: cannot save non-string type") + } data, err := cs.encoder.Marshal(x) if err != nil { return err } value = string(data) - } else { - switch v := x.(type) { - case []byte: - value = string(v) - case string: - value = v - default: - return errors.New("internal/sessions: cannot save non-string type") - } } + cs.setSessionCookie(w, value) return nil } -func (cs *CookieStore) setSessionCookie(w http.ResponseWriter, val string) { +func (cs *Store) setSessionCookie(w http.ResponseWriter, val string) { cs.setCookie(w, cs.makeCookie(val)) } -func (cs *CookieStore) setCookie(w http.ResponseWriter, cookie *http.Cookie) { +func (cs *Store) setCookie(w http.ResponseWriter, cookie *http.Cookie) { if len(cookie.String()) <= MaxChunkSize { http.SetCookie(w, cookie) return @@ -180,20 +189,26 @@ func (cs *CookieStore) setCookie(w http.ResponseWriter, cookie *http.Cookie) { } func loadChunkedCookie(r *http.Request, c *http.Cookie) string { - data := c.Value - // if the first byte is our canary byte, we need to handle the multipart bit - if []byte(c.Value)[0] == ChunkedCanaryByte { - var b strings.Builder - fmt.Fprintf(&b, "%s", data[1:]) - for i := 1; i <= MaxNumChunks; i++ { - next, err := r.Cookie(fmt.Sprintf("%s_%d", c.Name, i)) - if err != nil { - break // break if we can't find the next cookie - } - fmt.Fprintf(&b, "%s", next.Value) - } - data = b.String() + if len(c.Value) == 0 { + return "" } + // if the first byte is our canary byte, we need to handle the multipart bit + if []byte(c.Value)[0] != ChunkedCanaryByte { + return c.Value + } + + data := c.Value + var b strings.Builder + fmt.Fprintf(&b, "%s", data[1:]) + for i := 1; i <= MaxNumChunks; i++ { + next, err := r.Cookie(fmt.Sprintf("%s_%d", c.Name, i)) + if err != nil { + break // break if we can't find the next cookie + } + fmt.Fprintf(&b, "%s", next.Value) + } + data = b.String() + return data } diff --git a/internal/sessions/cookie_store_test.go b/internal/sessions/cookie/cookie_store_test.go similarity index 55% rename from internal/sessions/cookie_store_test.go rename to internal/sessions/cookie/cookie_store_test.go index 50ba5209f..4a4475868 100644 --- a/internal/sessions/cookie_store_test.go +++ b/internal/sessions/cookie/cookie_store_test.go @@ -1,4 +1,4 @@ -package sessions // import "github.com/pomerium/pomerium/internal/sessions" +package cookie // import "github.com/pomerium/pomerium/internal/sessions/cookie" import ( "crypto/rand" @@ -13,12 +13,13 @@ import ( "github.com/pomerium/pomerium/internal/encoding" "github.com/pomerium/pomerium/internal/encoding/ecjson" "github.com/pomerium/pomerium/internal/encoding/mock" + "github.com/pomerium/pomerium/internal/sessions" "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" ) -func TestNewCookieStore(t *testing.T) { +func TestNewStore(t *testing.T) { cipher, err := cryptutil.NewAEADCipher(cryptutil.NewKey()) if err != nil { t.Fatal(err) @@ -26,28 +27,28 @@ func TestNewCookieStore(t *testing.T) { encoder := ecjson.New(cipher) tests := []struct { name string - opts *CookieOptions + opts *Options encoder encoding.MarshalUnmarshaler - want *CookieStore + want sessions.SessionStore wantErr bool }{ - {"good", &CookieOptions{Name: "_cookie", Secure: true, HTTPOnly: true, Domain: "pomerium.io", Expire: 10 * time.Second}, encoder, &CookieStore{Name: "_cookie", Secure: true, HTTPOnly: true, Domain: "pomerium.io", Expire: 10 * time.Second}, false}, - {"missing name", &CookieOptions{Name: "", Secure: true, HTTPOnly: true, Domain: "pomerium.io", Expire: 10 * time.Second}, encoder, nil, true}, - {"missing encoder", &CookieOptions{Name: "_cookie", Secure: true, HTTPOnly: true, Domain: "pomerium.io", Expire: 10 * time.Second}, nil, nil, true}, + {"good", &Options{Name: "_cookie", Secure: true, HTTPOnly: true, Domain: "pomerium.io", Expire: 10 * time.Second}, encoder, &Store{Name: "_cookie", Secure: true, HTTPOnly: true, Domain: "pomerium.io", Expire: 10 * time.Second}, false}, + {"missing name", &Options{Name: "", Secure: true, HTTPOnly: true, Domain: "pomerium.io", Expire: 10 * time.Second}, encoder, nil, true}, + {"missing encoder", &Options{Name: "_cookie", Secure: true, HTTPOnly: true, Domain: "pomerium.io", Expire: 10 * time.Second}, nil, nil, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got, err := NewCookieStore(tt.opts, tt.encoder) + got, err := NewStore(tt.opts, tt.encoder) if (err != nil) != tt.wantErr { - t.Errorf("NewCookieStore() error = %v, wantErr %v", err, tt.wantErr) + t.Errorf("NewStore() error = %v, wantErr %v", err, tt.wantErr) return } cmpOpts := []cmp.Option{ - cmpopts.IgnoreUnexported(CookieStore{}), + cmpopts.IgnoreUnexported(Store{}), } if diff := cmp.Diff(got, tt.want, cmpOpts...); diff != "" { - t.Errorf("NewCookieStore() = %s", diff) + t.Errorf("NewStore() = %s", diff) } }) } @@ -60,14 +61,14 @@ func TestNewCookieLoader(t *testing.T) { encoder := ecjson.New(cipher) tests := []struct { name string - opts *CookieOptions + opts *Options encoder encoding.MarshalUnmarshaler - want *CookieStore + want *Store wantErr bool }{ - {"good", &CookieOptions{Name: "_cookie", Secure: true, HTTPOnly: true, Domain: "pomerium.io", Expire: 10 * time.Second}, encoder, &CookieStore{Name: "_cookie", Secure: true, HTTPOnly: true, Domain: "pomerium.io", Expire: 10 * time.Second}, false}, - {"missing name", &CookieOptions{Name: "", Secure: true, HTTPOnly: true, Domain: "pomerium.io", Expire: 10 * time.Second}, encoder, nil, true}, - {"missing encoder", &CookieOptions{Name: "_cookie", Secure: true, HTTPOnly: true, Domain: "pomerium.io", Expire: 10 * time.Second}, nil, nil, true}, + {"good", &Options{Name: "_cookie", Secure: true, HTTPOnly: true, Domain: "pomerium.io", Expire: 10 * time.Second}, encoder, &Store{Name: "_cookie", Secure: true, HTTPOnly: true, Domain: "pomerium.io", Expire: 10 * time.Second}, false}, + {"missing name", &Options{Name: "", Secure: true, HTTPOnly: true, Domain: "pomerium.io", Expire: 10 * time.Second}, encoder, nil, true}, + {"missing encoder", &Options{Name: "_cookie", Secure: true, HTTPOnly: true, Domain: "pomerium.io", Expire: 10 * time.Second}, nil, nil, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -77,7 +78,7 @@ func TestNewCookieLoader(t *testing.T) { return } cmpOpts := []cmp.Option{ - cmpopts.IgnoreUnexported(CookieStore{}), + cmpopts.IgnoreUnexported(Store{}), } if diff := cmp.Diff(got, tt.want, cmpOpts...); diff != "" { @@ -87,7 +88,7 @@ func TestNewCookieLoader(t *testing.T) { } } -func TestCookieStore_SaveSession(t *testing.T) { +func TestStore_SaveSession(t *testing.T) { c, err := cryptutil.NewAEADCipher(cryptutil.NewKey()) if err != nil { t.Fatal(err) @@ -106,17 +107,17 @@ func TestCookieStore_SaveSession(t *testing.T) { wantErr bool wantLoadErr bool }{ - {"good", &State{Email: "user@domain.com", User: "user"}, ecjson.New(c), ecjson.New(c), false, false}, - {"bad cipher", &State{Email: "user@domain.com", User: "user"}, nil, nil, true, true}, - {"huge cookie", &State{Subject: fmt.Sprintf("%x", hugeString), Email: "user@domain.com", User: "user"}, ecjson.New(c), ecjson.New(c), false, false}, - {"marshal error", &State{Email: "user@domain.com", User: "user"}, mock.Encoder{MarshalError: errors.New("error")}, ecjson.New(c), true, true}, - {"nil encoder cannot save non string type", &State{Email: "user@domain.com", User: "user"}, nil, ecjson.New(c), true, true}, + {"good", &sessions.State{Email: "user@domain.com", User: "user"}, ecjson.New(c), ecjson.New(c), false, false}, + {"bad cipher", &sessions.State{Email: "user@domain.com", User: "user"}, nil, nil, true, true}, + {"huge cookie", &sessions.State{Subject: fmt.Sprintf("%x", hugeString), Email: "user@domain.com", User: "user"}, ecjson.New(c), ecjson.New(c), false, false}, + {"marshal error", &sessions.State{Email: "user@domain.com", User: "user"}, mock.Encoder{MarshalError: errors.New("error")}, ecjson.New(c), true, true}, + {"nil encoder cannot save non string type", &sessions.State{Email: "user@domain.com", User: "user"}, nil, ecjson.New(c), true, true}, {"good marshal string directly", cryptutil.NewBase64Key(), nil, ecjson.New(c), false, true}, {"good marshal bytes directly", cryptutil.NewKey(), nil, ecjson.New(c), false, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - s := &CookieStore{ + s := &Store{ Name: "_pomerium", Secure: true, HTTPOnly: true, @@ -130,7 +131,7 @@ func TestCookieStore_SaveSession(t *testing.T) { w := httptest.NewRecorder() if err := s.SaveSession(w, r, tt.State); (err != nil) != tt.wantErr { - t.Errorf("CookieStore.SaveSession() error = %v, wantErr %v", err, tt.wantErr) + t.Errorf("Store.SaveSession() error = %v, wantErr %v", err, tt.wantErr) } r = httptest.NewRequest("GET", "/", nil) for _, cookie := range w.Result().Cookies() { @@ -143,11 +144,11 @@ func TestCookieStore_SaveSession(t *testing.T) { return } cmpOpts := []cmp.Option{ - cmpopts.IgnoreUnexported(State{}), + cmpopts.IgnoreUnexported(sessions.State{}), } if err == nil { if diff := cmp.Diff(state, tt.State, cmpOpts...); diff != "" { - t.Errorf("CookieStore.LoadSession() got = %s", diff) + t.Errorf("Store.LoadSession() got = %s", diff) } } w = httptest.NewRecorder() diff --git a/internal/sessions/cookie/middleware_test.go b/internal/sessions/cookie/middleware_test.go new file mode 100644 index 000000000..a077330a9 --- /dev/null +++ b/internal/sessions/cookie/middleware_test.go @@ -0,0 +1,90 @@ +package cookie // import "github.com/pomerium/pomerium/internal/sessions/cookie" + +import ( + "fmt" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "github.com/pomerium/pomerium/internal/sessions" + + "github.com/google/go-cmp/cmp" + "github.com/pomerium/pomerium/internal/cryptutil" + "github.com/pomerium/pomerium/internal/encoding/ecjson" + "gopkg.in/square/go-jose.v2/jwt" +) + +func testAuthorizer(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, err := sessions.FromContext(r.Context()) + if err != nil { + http.Error(w, err.Error(), http.StatusUnauthorized) + return + } + next.ServeHTTP(w, r) + }) +} + +func TestVerifier(t *testing.T) { + fnh := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/plain; charset=utf-8") + fmt.Fprint(w, http.StatusText(http.StatusOK)) + w.WriteHeader(http.StatusOK) + }) + + tests := []struct { + name string + state sessions.State + + wantBody string + wantStatus int + }{ + {"good cookie session", sessions.State{Email: "user@pomerium.io", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}, http.StatusText(http.StatusOK), http.StatusOK}, + {"expired cookie", sessions.State{Email: "user@pomerium.io", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}, "internal/sessions: validation failed, token is expired (exp)\n", http.StatusUnauthorized}, + {"malformed cookie", sessions.State{Email: "user@pomerium.io", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}, "internal/sessions: session is malformed\n", http.StatusUnauthorized}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cipher, err := cryptutil.NewAEADCipherFromBase64(cryptutil.NewBase64Key()) + encoder := ecjson.New(cipher) + if err != nil { + t.Fatal(err) + } + encSession, err := encoder.Marshal(&tt.state) + if err != nil { + t.Fatal(err) + } + if strings.Contains(tt.name, "malformed") { + // add some garbage to the end of the string + encSession = append(encSession, cryptutil.NewKey()...) + } + + cs, err := NewStore(&Options{ + Name: "_pomerium", + }, encoder) + if err != nil { + t.Fatal(err) + } + + r := httptest.NewRequest(http.MethodGet, "/", nil) + r.Header.Set("Accept", "application/json") + w := httptest.NewRecorder() + r.AddCookie(&http.Cookie{Name: "_pomerium", Value: string(encSession)}) + + got := sessions.RetrieveSession(cs)(testAuthorizer((fnh))) + got.ServeHTTP(w, r) + + gotBody := w.Body.String() + gotStatus := w.Result().StatusCode + + if diff := cmp.Diff(gotBody, tt.wantBody); diff != "" { + t.Errorf("RetrieveSession() = %v", diff) + } + if diff := cmp.Diff(gotStatus, tt.wantStatus); diff != "" { + t.Errorf("RetrieveSession() = %v", diff) + } + }) + } +} diff --git a/internal/sessions/errors.go b/internal/sessions/errors.go new file mode 100644 index 000000000..fec29a680 --- /dev/null +++ b/internal/sessions/errors.go @@ -0,0 +1,28 @@ +package sessions // import "github.com/pomerium/pomerium/internal/sessions" + +import ( + "errors" +) + +var ( + // ErrNoSessionFound is the error for when no session is found. + ErrNoSessionFound = errors.New("internal/sessions: session is not found") + + // ErrMalformed is the error for when a session is found but is malformed. + ErrMalformed = errors.New("internal/sessions: session is malformed") + + // ErrNotValidYet indicates that token is used before time indicated in nbf claim. + ErrNotValidYet = errors.New("internal/sessions: validation failed, token not valid yet (nbf)") + + // ErrExpired indicates that token is used after expiry time indicated in exp claim. + ErrExpired = errors.New("internal/sessions: validation failed, token is expired (exp)") + + // ErrExpiryRequired indicates that the token does not contain a valid expiry (exp) claim. + ErrExpiryRequired = errors.New("internal/sessions: validation failed, token expiry (exp) is required") + + // ErrIssuedInTheFuture indicates that the iat field is in the future. + ErrIssuedInTheFuture = errors.New("internal/sessions: validation field, token issued in the future (iat)") + + // ErrInvalidAudience indicated invalid aud claim. + ErrInvalidAudience = errors.New("internal/sessions: validation failed, invalid audience claim (aud)") +) diff --git a/internal/sessions/header_store.go b/internal/sessions/header/header_store.go similarity index 68% rename from internal/sessions/header_store.go rename to internal/sessions/header/header_store.go index 58ec3c182..132cfdfd7 100644 --- a/internal/sessions/header_store.go +++ b/internal/sessions/header/header_store.go @@ -1,35 +1,38 @@ -package sessions // import "github.com/pomerium/pomerium/internal/sessions" +package header // import "github.com/pomerium/pomerium/internal/sessions/header" import ( "net/http" "strings" "github.com/pomerium/pomerium/internal/encoding" + "github.com/pomerium/pomerium/internal/sessions" ) +var _ sessions.SessionLoader = &Store{} + const ( defaultAuthHeader = "Authorization" defaultAuthType = "Bearer" ) -// HeaderStore implements the load session store interface using http +// Store implements the load session store interface using http // authorization headers. -type HeaderStore struct { +type Store struct { authHeader string authType string encoder encoding.Unmarshaler } -// NewHeaderStore returns a new header store for loading sessions from +// NewStore returns a new header store for loading sessions from // authorization header as defined in as defined in rfc2617 // // NOTA BENE: While most servers do not log Authorization headers by default, // you should ensure no other services are logging or leaking your auth headers. -func NewHeaderStore(enc encoding.Unmarshaler, headerType string) *HeaderStore { +func NewStore(enc encoding.Unmarshaler, headerType string) *Store { if headerType == "" { headerType = defaultAuthType } - return &HeaderStore{ + return &Store{ authHeader: defaultAuthHeader, authType: headerType, encoder: enc, @@ -37,14 +40,14 @@ func NewHeaderStore(enc encoding.Unmarshaler, headerType string) *HeaderStore { } // LoadSession tries to retrieve the token string from the Authorization header. -func (as *HeaderStore) LoadSession(r *http.Request) (*State, error) { +func (as *Store) LoadSession(r *http.Request) (*sessions.State, error) { cipherText := TokenFromHeader(r, as.authHeader, as.authType) if cipherText == "" { - return nil, ErrNoSessionFound + return nil, sessions.ErrNoSessionFound } - var session State + var session sessions.State if err := as.encoder.Unmarshal([]byte(cipherText), &session); err != nil { - return nil, ErrMalformed + return nil, sessions.ErrMalformed } return &session, nil } diff --git a/internal/sessions/header/middleware_test.go b/internal/sessions/header/middleware_test.go new file mode 100644 index 000000000..e1f45f47a --- /dev/null +++ b/internal/sessions/header/middleware_test.go @@ -0,0 +1,90 @@ +package header // import "github.com/pomerium/pomerium/internal/sessions/header" + +import ( + "fmt" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "github.com/pomerium/pomerium/internal/cryptutil" + "github.com/pomerium/pomerium/internal/encoding/ecjson" + "github.com/pomerium/pomerium/internal/sessions" + + "github.com/google/go-cmp/cmp" + "gopkg.in/square/go-jose.v2/jwt" +) + +func testAuthorizer(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, err := sessions.FromContext(r.Context()) + if err != nil { + http.Error(w, err.Error(), http.StatusUnauthorized) + return + } + next.ServeHTTP(w, r) + }) +} + +func TestVerifier(t *testing.T) { + fnh := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/plain; charset=utf-8") + fmt.Fprint(w, http.StatusText(http.StatusOK)) + w.WriteHeader(http.StatusOK) + }) + + tests := []struct { + name string + authType string + state sessions.State + wantBody string + wantStatus int + }{ + {"good auth header session", "Bearer ", sessions.State{Email: "user@pomerium.io", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}, http.StatusText(http.StatusOK), http.StatusOK}, + {"expired auth header", "Bearer ", sessions.State{Email: "user@pomerium.io", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}, "internal/sessions: validation failed, token is expired (exp)\n", http.StatusUnauthorized}, + {"malformed auth header", "Bearer ", sessions.State{Email: "user@pomerium.io", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}, "internal/sessions: session is malformed\n", http.StatusUnauthorized}, + {"empty auth header", "Bearer ", sessions.State{Email: "user@pomerium.io", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}, "internal/sessions: session is not found\n", http.StatusUnauthorized}, + {"bad auth type", "bees ", sessions.State{Email: "user@pomerium.io", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}, "internal/sessions: session is not found\n", http.StatusUnauthorized}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cipher, err := cryptutil.NewAEADCipherFromBase64(cryptutil.NewBase64Key()) + encoder := ecjson.New(cipher) + if err != nil { + t.Fatal(err) + } + encSession, err := encoder.Marshal(&tt.state) + if err != nil { + t.Fatal(err) + } + if strings.Contains(tt.name, "malformed") { + // add some garbage to the end of the string + encSession = append(encSession, cryptutil.NewKey()...) + } + s := NewStore(encoder, "") + + r := httptest.NewRequest(http.MethodGet, "/", nil) + r.Header.Set("Accept", "application/json") + w := httptest.NewRecorder() + + if strings.Contains(tt.name, "empty") { + encSession = []byte("") + } + r.Header.Set("Authorization", tt.authType+string(encSession)) + + got := sessions.RetrieveSession(s)(testAuthorizer((fnh))) + got.ServeHTTP(w, r) + + gotBody := w.Body.String() + gotStatus := w.Result().StatusCode + + if diff := cmp.Diff(gotBody, tt.wantBody); diff != "" { + t.Errorf("RetrieveSession() = %v", diff) + } + if diff := cmp.Diff(gotStatus, tt.wantStatus); diff != "" { + t.Errorf("RetrieveSession() = %v", diff) + } + }) + } +} diff --git a/internal/sessions/middleware.go b/internal/sessions/middleware.go index f7a619b90..738a048e2 100644 --- a/internal/sessions/middleware.go +++ b/internal/sessions/middleware.go @@ -44,7 +44,7 @@ func retrieveFromRequest(r *http.Request, sessions ...SessionLoader) (*State, er } if state != nil { err := state.Verify(urlutil.StripPort(r.Host)) - return state, err // N.B.: state is _not_ nil_ + return state, err // N.B.: state is _not_ nil } } diff --git a/internal/sessions/middleware_test.go b/internal/sessions/middleware_test.go index ca7baabbc..0435dbf9c 100644 --- a/internal/sessions/middleware_test.go +++ b/internal/sessions/middleware_test.go @@ -2,16 +2,14 @@ package sessions import ( "context" + "errors" "fmt" "net/http" "net/http/httptest" - "strings" "testing" "time" "github.com/google/go-cmp/cmp" - "github.com/pomerium/pomerium/internal/cryptutil" - "github.com/pomerium/pomerium/internal/encoding/ecjson" "gopkg.in/square/go-jose.v2/jwt" ) @@ -39,103 +37,6 @@ func TestNewContext(t *testing.T) { } } -func testAuthorizer(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - _, err := FromContext(r.Context()) - if err != nil { - http.Error(w, err.Error(), http.StatusUnauthorized) - return - } - next.ServeHTTP(w, r) - }) -} - -func TestVerifier(t *testing.T) { - fnh := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "text/plain; charset=utf-8") - fmt.Fprint(w, http.StatusText(http.StatusOK)) - w.WriteHeader(http.StatusOK) - }) - - tests := []struct { - name string - // s SessionStore - state State - - cookie bool - header bool - param bool - - wantBody string - wantStatus int - }{ - {"good cookie session", State{Email: "user@pomerium.io", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}, true, false, false, http.StatusText(http.StatusOK), http.StatusOK}, - {"expired cookie", State{Email: "user@pomerium.io", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}, true, false, false, "internal/sessions: validation failed, token is expired (exp)\n", http.StatusUnauthorized}, - {"malformed cookie", State{Email: "user@pomerium.io", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}, true, false, false, "internal/sessions: session is malformed\n", http.StatusUnauthorized}, - {"good auth header session", State{Email: "user@pomerium.io", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}, false, true, false, http.StatusText(http.StatusOK), http.StatusOK}, - {"expired auth header", State{Email: "user@pomerium.io", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}, false, true, false, "internal/sessions: validation failed, token is expired (exp)\n", http.StatusUnauthorized}, - {"malformed auth header", State{Email: "user@pomerium.io", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}, false, true, false, "internal/sessions: session is malformed\n", http.StatusUnauthorized}, - {"good auth query param session", State{Email: "user@pomerium.io", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}, false, true, true, http.StatusText(http.StatusOK), http.StatusOK}, - {"expired auth query param", State{Email: "user@pomerium.io", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}, false, false, true, "internal/sessions: validation failed, token is expired (exp)\n", http.StatusUnauthorized}, - {"malformed auth query param", State{Email: "user@pomerium.io", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}, false, false, true, "internal/sessions: session is malformed\n", http.StatusUnauthorized}, - {"no session", State{Email: "user@pomerium.io", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}, false, false, false, "internal/sessions: session is not found\n", http.StatusUnauthorized}, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - cipher, err := cryptutil.NewAEADCipherFromBase64(cryptutil.NewBase64Key()) - encoder := ecjson.New(cipher) - if err != nil { - t.Fatal(err) - } - encSession, err := encoder.Marshal(&tt.state) - if err != nil { - t.Fatal(err) - } - if strings.Contains(tt.name, "malformed") { - // add some garbage to the end of the string - encSession = append(encSession, cryptutil.NewKey()...) - } - - cs, err := NewCookieStore(&CookieOptions{ - Name: "_pomerium", - }, encoder) - if err != nil { - t.Fatal(err) - } - as := NewHeaderStore(encoder, "") - - qp := NewQueryParamStore(encoder, "") - - r := httptest.NewRequest(http.MethodGet, "/", nil) - r.Header.Set("Accept", "application/json") - w := httptest.NewRecorder() - if tt.cookie { - r.AddCookie(&http.Cookie{Name: "_pomerium", Value: string(encSession)}) - } else if tt.header { - r.Header.Set("Authorization", "Bearer "+string(encSession)) - } else if tt.param { - q := r.URL.Query() - - q.Set("pomerium_session", string(encSession)) - r.URL.RawQuery = q.Encode() - } - - got := RetrieveSession(cs, as, qp)(testAuthorizer((fnh))) - got.ServeHTTP(w, r) - - gotBody := w.Body.String() - gotStatus := w.Result().StatusCode - - if diff := cmp.Diff(gotBody, tt.wantBody); diff != "" { - t.Errorf("RetrieveSession() = %v", diff) - } - if diff := cmp.Diff(gotStatus, tt.wantStatus); diff != "" { - t.Errorf("RetrieveSession() = %v", diff) - } - }) - } -} - func Test_contextKey_String(t *testing.T) { tests := []struct { name string @@ -155,3 +56,80 @@ func Test_contextKey_String(t *testing.T) { }) } } + +func testAuthorizer(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, err := FromContext(r.Context()) + if err != nil { + http.Error(w, err.Error(), http.StatusUnauthorized) + return + } + next.ServeHTTP(w, r) + }) +} + +var _ SessionStore = &store{} + +// Store is a mock implementation of the SessionStore interface +type store struct { + ResponseSession string + Session *State + SaveError error + LoadError error +} + +// ClearSession clears the ResponseSession +func (ms *store) ClearSession(http.ResponseWriter, *http.Request) { + ms.ResponseSession = "" +} + +// LoadSession returns the session and a error +func (ms store) LoadSession(*http.Request) (*State, error) { + return ms.Session, ms.LoadError +} + +// SaveSession returns a save error. +func (ms store) SaveSession(http.ResponseWriter, *http.Request, interface{}) error { + return ms.SaveError +} + +func TestVerifier(t *testing.T) { + fnh := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/plain; charset=utf-8") + fmt.Fprint(w, http.StatusText(http.StatusOK)) + w.WriteHeader(http.StatusOK) + }) + + tests := []struct { + name string + store store + state State + wantBody string + wantStatus int + }{ + {"empty session", store{}, State{}, "internal/sessions: session is not found\n", 401}, + {"simple good load", store{Session: &State{Subject: "hi", Expiry: jwt.NewNumericDate(time.Now().Add(time.Second))}}, State{}, "OK", 200}, + {"empty session", store{LoadError: errors.New("err")}, State{}, "err\n", 401}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + + r := httptest.NewRequest(http.MethodGet, "/", nil) + r.Header.Set("Accept", "application/json") + w := httptest.NewRecorder() + + got := RetrieveSession(tt.store)(testAuthorizer((fnh))) + got.ServeHTTP(w, r) + + gotBody := w.Body.String() + gotStatus := w.Result().StatusCode + + if diff := cmp.Diff(gotBody, tt.wantBody); diff != "" { + t.Errorf("RetrieveSession() = %v", diff) + } + if diff := cmp.Diff(gotStatus, tt.wantStatus); diff != "" { + t.Errorf("RetrieveSession() = %v", diff) + } + }) + } +} diff --git a/internal/sessions/mock/mock_store.go b/internal/sessions/mock/mock_store.go new file mode 100644 index 000000000..c787109a8 --- /dev/null +++ b/internal/sessions/mock/mock_store.go @@ -0,0 +1,33 @@ +package mock // import "github.com/pomerium/pomerium/internal/sessions/mock" + +import ( + "net/http" + + "github.com/pomerium/pomerium/internal/sessions" +) + +var _ sessions.SessionStore = &Store{} +var _ sessions.SessionLoader = &Store{} + +// Store is a mock implementation of the SessionStore interface +type Store struct { + ResponseSession string + Session *sessions.State + SaveError error + LoadError error +} + +// ClearSession clears the ResponseSession +func (ms *Store) ClearSession(http.ResponseWriter, *http.Request) { + ms.ResponseSession = "" +} + +// LoadSession returns the session and a error +func (ms Store) LoadSession(*http.Request) (*sessions.State, error) { + return ms.Session, ms.LoadError +} + +// SaveSession returns a save error. +func (ms Store) SaveSession(http.ResponseWriter, *http.Request, interface{}) error { + return ms.SaveError +} diff --git a/internal/sessions/mock_store_test.go b/internal/sessions/mock/mock_store_test.go similarity index 75% rename from internal/sessions/mock_store_test.go rename to internal/sessions/mock/mock_store_test.go index d3d9d71a7..9f2ae4c3a 100644 --- a/internal/sessions/mock_store_test.go +++ b/internal/sessions/mock/mock_store_test.go @@ -1,26 +1,28 @@ -package sessions // import "github.com/pomerium/pomerium/internal/sessions" +package mock // import "github.com/pomerium/pomerium/internal/sessions/mock" import ( "reflect" "testing" + + "github.com/pomerium/pomerium/internal/sessions" ) -func TestMockSessionStore(t *testing.T) { +func TestStore(t *testing.T) { tests := []struct { name string - mockCSRF *MockSessionStore - saveSession *State + mockCSRF *Store + saveSession *sessions.State wantLoadErr bool wantSaveErr bool }{ {"basic", - &MockSessionStore{ + &Store{ ResponseSession: "test", - Session: &State{Subject: "0101"}, + Session: &sessions.State{Subject: "0101"}, SaveError: nil, LoadError: nil, }, - &State{Subject: "0101"}, + &sessions.State{Subject: "0101"}, false, false}, } diff --git a/internal/sessions/mock_store.go b/internal/sessions/mock_store.go deleted file mode 100644 index e3d3c7c91..000000000 --- a/internal/sessions/mock_store.go +++ /dev/null @@ -1,28 +0,0 @@ -package sessions // import "github.com/pomerium/pomerium/internal/sessions" - -import ( - "net/http" -) - -// MockSessionStore is a mock implementation of the SessionStore interface -type MockSessionStore struct { - ResponseSession string - Session *State - SaveError error - LoadError error -} - -// ClearSession clears the ResponseSession -func (ms *MockSessionStore) ClearSession(http.ResponseWriter, *http.Request) { - ms.ResponseSession = "" -} - -// LoadSession returns the session and a error -func (ms MockSessionStore) LoadSession(*http.Request) (*State, error) { - return ms.Session, ms.LoadError -} - -// SaveSession returns a save error. -func (ms MockSessionStore) SaveSession(http.ResponseWriter, *http.Request, interface{}) error { - return ms.SaveError -} diff --git a/internal/sessions/queryparam/middleware_test.go b/internal/sessions/queryparam/middleware_test.go new file mode 100644 index 000000000..c9aab9c59 --- /dev/null +++ b/internal/sessions/queryparam/middleware_test.go @@ -0,0 +1,92 @@ +package queryparam // import "github.com/pomerium/pomerium/internal/sessions/queryparam" + +import ( + "fmt" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "github.com/pomerium/pomerium/internal/cryptutil" + "github.com/pomerium/pomerium/internal/encoding/ecjson" + "github.com/pomerium/pomerium/internal/sessions" + + "github.com/google/go-cmp/cmp" + "gopkg.in/square/go-jose.v2/jwt" +) + +func testAuthorizer(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, err := sessions.FromContext(r.Context()) + if err != nil { + http.Error(w, err.Error(), http.StatusUnauthorized) + return + } + next.ServeHTTP(w, r) + }) +} + +func TestVerifier(t *testing.T) { + fnh := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/plain; charset=utf-8") + fmt.Fprint(w, http.StatusText(http.StatusOK)) + w.WriteHeader(http.StatusOK) + }) + + tests := []struct { + name string + state sessions.State + + wantBody string + wantStatus int + }{ + {"good auth query param session", sessions.State{Email: "user@pomerium.io", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}, http.StatusText(http.StatusOK), http.StatusOK}, + {"expired auth query param", sessions.State{Email: "user@pomerium.io", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}, "internal/sessions: validation failed, token is expired (exp)\n", http.StatusUnauthorized}, + {"malformed auth query param", sessions.State{Email: "user@pomerium.io", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}, "internal/sessions: session is malformed\n", http.StatusUnauthorized}, + {"empty auth query param", sessions.State{Email: "user@pomerium.io", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}, "internal/sessions: session is not found\n", http.StatusUnauthorized}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cipher, err := cryptutil.NewAEADCipherFromBase64(cryptutil.NewBase64Key()) + encoder := ecjson.New(cipher) + if err != nil { + t.Fatal(err) + } + encSession, err := encoder.Marshal(&tt.state) + if err != nil { + t.Fatal(err) + } + if strings.Contains(tt.name, "malformed") { + // add some garbage to the end of the string + encSession = append(encSession, cryptutil.NewKey()...) + } + + s := NewStore(encoder, "") + + r := httptest.NewRequest(http.MethodGet, "/", nil) + r.Header.Set("Accept", "application/json") + w := httptest.NewRecorder() + + q := r.URL.Query() + if strings.Contains(tt.name, "empty") { + encSession = []byte("") + } + q.Set("pomerium_session", string(encSession)) + r.URL.RawQuery = q.Encode() + + got := sessions.RetrieveSession(s)(testAuthorizer((fnh))) + got.ServeHTTP(w, r) + + gotBody := w.Body.String() + gotStatus := w.Result().StatusCode + + if diff := cmp.Diff(gotBody, tt.wantBody); diff != "" { + t.Errorf("RetrieveSession() = %v", diff) + } + if diff := cmp.Diff(gotStatus, tt.wantStatus); diff != "" { + t.Errorf("RetrieveSession() = %v", diff) + } + }) + } +} diff --git a/internal/sessions/query_store.go b/internal/sessions/queryparam/query_store.go similarity index 62% rename from internal/sessions/query_store.go rename to internal/sessions/queryparam/query_store.go index cb8e3a484..3ce7c25b9 100644 --- a/internal/sessions/query_store.go +++ b/internal/sessions/queryparam/query_store.go @@ -1,33 +1,37 @@ -package sessions // import "github.com/pomerium/pomerium/internal/sessions" +package queryparam // import "github.com/pomerium/pomerium/internal/sessions/queryparam" import ( "net/http" "github.com/pomerium/pomerium/internal/encoding" + "github.com/pomerium/pomerium/internal/sessions" ) +var _ sessions.SessionStore = &Store{} +var _ sessions.SessionLoader = &Store{} + const ( defaultQueryParamKey = "pomerium_session" ) -// QueryParamStore implements the load session store interface using http +// Store implements the load session store interface using http // query strings / query parameters. -type QueryParamStore struct { +type Store struct { queryParamKey string encoder encoding.Marshaler decoder encoding.Unmarshaler } -// NewQueryParamStore returns a new query param store for loading sessions from +// NewStore returns a new query param store for loading sessions from // query strings / query parameters. // // NOTA BENE: By default, most servers _DO_ log query params, the leaking or // accidental logging of which should be considered a security issue. -func NewQueryParamStore(enc encoding.MarshalUnmarshaler, qp string) *QueryParamStore { +func NewStore(enc encoding.MarshalUnmarshaler, qp string) *Store { if qp == "" { qp = defaultQueryParamKey } - return &QueryParamStore{ + return &Store{ queryParamKey: qp, encoder: enc, decoder: enc, @@ -35,27 +39,27 @@ func NewQueryParamStore(enc encoding.MarshalUnmarshaler, qp string) *QueryParamS } // LoadSession tries to retrieve the token string from URL query parameters. -func (qp *QueryParamStore) LoadSession(r *http.Request) (*State, error) { +func (qp *Store) LoadSession(r *http.Request) (*sessions.State, error) { cipherText := r.URL.Query().Get(qp.queryParamKey) if cipherText == "" { - return nil, ErrNoSessionFound + return nil, sessions.ErrNoSessionFound } - var session State + var session sessions.State if err := qp.decoder.Unmarshal([]byte(cipherText), &session); err != nil { - return nil, ErrMalformed + return nil, sessions.ErrMalformed } return &session, nil } // ClearSession clears the session cookie from a request's query param key `pomerium_session`. -func (qp *QueryParamStore) ClearSession(w http.ResponseWriter, r *http.Request) { +func (qp *Store) ClearSession(w http.ResponseWriter, r *http.Request) { params := r.URL.Query() params.Del(qp.queryParamKey) r.URL.RawQuery = params.Encode() } // SaveSession sets a session to a request's query param key `pomerium_session` -func (qp *QueryParamStore) SaveSession(w http.ResponseWriter, r *http.Request, x interface{}) error { +func (qp *Store) SaveSession(w http.ResponseWriter, r *http.Request, x interface{}) error { data, err := qp.encoder.Marshal(x) if err != nil { return err diff --git a/internal/sessions/query_store_test.go b/internal/sessions/queryparam/query_store_test.go similarity index 52% rename from internal/sessions/query_store_test.go rename to internal/sessions/queryparam/query_store_test.go index dc3d5ecf8..0137b5e93 100644 --- a/internal/sessions/query_store_test.go +++ b/internal/sessions/queryparam/query_store_test.go @@ -1,4 +1,4 @@ -package sessions +package queryparam // import "github.com/pomerium/pomerium/internal/sessions/queryparam" import ( "errors" @@ -9,39 +9,40 @@ import ( "github.com/google/go-cmp/cmp" "github.com/pomerium/pomerium/internal/encoding" "github.com/pomerium/pomerium/internal/encoding/mock" + "github.com/pomerium/pomerium/internal/sessions" ) func TestNewQueryParamStore(t *testing.T) { tests := []struct { name string - State *State + State *sessions.State enc encoding.MarshalUnmarshaler qp string wantErr bool wantURL *url.URL }{ - {"simple good", &State{Email: "user@domain.com", User: "user"}, mock.Encoder{MarshalResponse: []byte("ok")}, "", false, &url.URL{Path: "/", RawQuery: "pomerium_session=ok"}}, - {"marshall error", &State{Email: "user@domain.com", User: "user"}, mock.Encoder{MarshalError: errors.New("error")}, "", true, &url.URL{Path: "/"}}, + {"simple good", &sessions.State{Email: "user@domain.com", User: "user"}, mock.Encoder{MarshalResponse: []byte("ok")}, "", false, &url.URL{Path: "/", RawQuery: "pomerium_session=ok"}}, + {"marshall error", &sessions.State{Email: "user@domain.com", User: "user"}, mock.Encoder{MarshalError: errors.New("error")}, "", true, &url.URL{Path: "/"}}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got := NewQueryParamStore(tt.enc, tt.qp) + got := NewStore(tt.enc, tt.qp) r := httptest.NewRequest("GET", "/", nil) w := httptest.NewRecorder() if err := got.SaveSession(w, r, tt.State); (err != nil) != tt.wantErr { - t.Errorf("NewQueryParamStore.SaveSession() error = %v, wantErr %v", err, tt.wantErr) + t.Errorf("NewStore.SaveSession() error = %v, wantErr %v", err, tt.wantErr) } if diff := cmp.Diff(r.URL, tt.wantURL); diff != "" { - t.Errorf("NewQueryParamStore() = %v", diff) + t.Errorf("NewStore() = %v", diff) } got.ClearSession(w, r) if diff := cmp.Diff(r.URL, &url.URL{Path: "/"}); diff != "" { - t.Errorf("NewQueryParamStore() = %v", diff) + t.Errorf("NewStore() = %v", diff) } }) } diff --git a/internal/sessions/state.go b/internal/sessions/state.go index a266770a8..0bc2543e3 100644 --- a/internal/sessions/state.go +++ b/internal/sessions/state.go @@ -6,6 +6,8 @@ import ( "strings" "time" + "github.com/cespare/xxhash/v2" + "github.com/mitchellh/hashstructure" oidc "github.com/pomerium/go-oidc" "golang.org/x/oauth2" "gopkg.in/square/go-jose.v2/jwt" @@ -51,7 +53,8 @@ type State struct { // programatic access. Programmatic bool `json:"programatic"` - AccessToken *oauth2.Token `json:"access_token,omitempty"` + AccessToken *oauth2.Token `json:"act,omitempty"` + AccessTokenID string `json:"ati,omitempty"` idToken *oidc.IDToken } @@ -73,7 +76,7 @@ func NewStateFromTokens(idToken *oidc.IDToken, accessToken *oauth2.Token, audien s.Audience = []string{audience} s.idToken = idToken s.AccessToken = accessToken - + s.AccessTokenID = s.accessTokenHash() return s, nil } @@ -95,6 +98,7 @@ func (s *State) UpdateState(idToken *oidc.IDToken, accessToken *oauth2.Token) er } s.Audience = audience s.Expiry = jwt.NewNumericDate(accessToken.Expiry) + s.AccessTokenID = s.accessTokenHash() return nil } @@ -173,3 +177,13 @@ func (s *State) SetImpersonation(email, groups string) { s.ImpersonateGroups = strings.Split(groups, ",") } } + +func (s *State) accessTokenHash() string { + hash, err := hashstructure.Hash( + s.AccessToken, + &hashstructure.HashOptions{Hasher: xxhash.New()}) + if err != nil { + return "" + } + return fmt.Sprintf("%x", hash) +} diff --git a/internal/sessions/state_test.go b/internal/sessions/state_test.go index 22204cd48..54940f899 100644 --- a/internal/sessions/state_test.go +++ b/internal/sessions/state_test.go @@ -124,3 +124,26 @@ func TestState_RouteSession(t *testing.T) { }) } } + +func TestState_accessTokenHash(t *testing.T) { + t.Parallel() + tests := []struct { + name string + state State + want string + }{ + {"empty access token", State{}, "34c96acdcadb1bbb"}, + {"no change to access token", State{Subject: "test"}, "34c96acdcadb1bbb"}, + {"empty oauth2 token", State{AccessToken: &oauth2.Token{}}, "bbd82197d215198f"}, + {"refresh token a", State{AccessToken: &oauth2.Token{RefreshToken: "a"}}, "76316ac79b301bd6"}, + {"refresh token b", State{AccessToken: &oauth2.Token{RefreshToken: "b"}}, "fab7cb29e50161f1"}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + s := &tt.state + if got := s.accessTokenHash(); got != tt.want { + t.Errorf("State.accessTokenHash() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/internal/sessions/store.go b/internal/sessions/store.go index 9433771e2..66df7452b 100644 --- a/internal/sessions/store.go +++ b/internal/sessions/store.go @@ -1,39 +1,17 @@ package sessions // import "github.com/pomerium/pomerium/internal/sessions" import ( - "errors" "net/http" ) -var ( - // ErrNoSessionFound is the error for when no session is found. - ErrNoSessionFound = errors.New("internal/sessions: session is not found") - - // ErrMalformed is the error for when a session is found but is malformed. - ErrMalformed = errors.New("internal/sessions: session is malformed") - - // ErrNotValidYet indicates that token is used before time indicated in nbf claim. - ErrNotValidYet = errors.New("internal/sessions: validation failed, token not valid yet (nbf)") - - // ErrExpired indicates that token is used after expiry time indicated in exp claim. - ErrExpired = errors.New("internal/sessions: validation failed, token is expired (exp)") - - // ErrIssuedInTheFuture indicates that the iat field is in the future. - ErrIssuedInTheFuture = errors.New("internal/sessions: validation field, token issued in the future (iat)") - - // ErrInvalidAudience indicated invalid aud claim. - ErrInvalidAudience = errors.New("internal/sessions: validation failed, invalid audience claim (aud)") -) - -// SessionStore has the functions for setting, getting, and clearing the Session cookie +// SessionStore defines an interface for loading, saving, and clearing a session. type SessionStore interface { - ClearSession(http.ResponseWriter, *http.Request) SessionLoader + ClearSession(http.ResponseWriter, *http.Request) SaveSession(http.ResponseWriter, *http.Request, interface{}) error } -// SessionLoader is implemented by any struct that loads a pomerium session -// given a request, and returns a user state. +// SessionLoader defines an interface for loading a session. type SessionLoader interface { LoadSession(*http.Request) (*State, error) } diff --git a/internal/telemetry/metrics/const.go b/internal/telemetry/metrics/const.go index b55197af1..60fd8c558 100644 --- a/internal/telemetry/metrics/const.go +++ b/internal/telemetry/metrics/const.go @@ -34,10 +34,10 @@ var ( // DefaultViews are a set of default views to view HTTP and GRPC metrics. var ( DefaultViews = [][]*view.View{ - GRPCServerViews, - HTTPServerViews, GRPCClientViews, GRPCServerViews, + HTTPClientViews, + HTTPServerViews, InfoViews, } ) diff --git a/proxy/forward_auth_test.go b/proxy/forward_auth_test.go index 9afaef325..9aaf853b9 100644 --- a/proxy/forward_auth_test.go +++ b/proxy/forward_auth_test.go @@ -9,14 +9,16 @@ import ( "time" "github.com/google/go-cmp/cmp" + "gopkg.in/square/go-jose.v2/jwt" + "github.com/pomerium/pomerium/config" "github.com/pomerium/pomerium/internal/encoding" "github.com/pomerium/pomerium/internal/encoding/mock" "github.com/pomerium/pomerium/internal/httputil" "github.com/pomerium/pomerium/internal/sessions" + mstore "github.com/pomerium/pomerium/internal/sessions/mock" "github.com/pomerium/pomerium/internal/urlutil" "github.com/pomerium/pomerium/proxy/clients" - "gopkg.in/square/go-jose.v2/jwt" ) func TestProxy_ForwardAuth(t *testing.T) { @@ -40,29 +42,29 @@ func TestProxy_ForwardAuth(t *testing.T) { wantStatus int wantBody string }{ - {"good redirect not required", opts, nil, http.MethodGet, nil, nil, "https://some.domain.example/", "https://some.domain.example", &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusOK, "Access to some.domain.example is allowed."}, - {"good verify only, no redirect", opts, nil, http.MethodGet, nil, nil, "https://some.domain.example/verify", "https://some.domain.example", &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusOK, ""}, - {"good redirect not required", opts, nil, http.MethodGet, nil, nil, "/", "https://some.domain.example", &mock.Encoder{}, &sessions.MockSessionStore{LoadError: sessions.ErrInvalidAudience}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusUnauthorized, "{\"Status\":401,\"Error\":\"Unauthorized: internal/sessions: validation failed, invalid audience claim (aud)\"}\n"}, - {"bad naked domain uri", opts, nil, http.MethodGet, nil, nil, "https://some.domain.example/", "a.naked.domain", &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusBadRequest, "{\"Status\":400,\"Error\":\"Bad Request: a.naked.domain url does contain a valid scheme\"}\n"}, - {"bad naked domain uri verify only", opts, nil, http.MethodGet, nil, nil, "https://some.domain.example/verify", "a.naked.domain", &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusBadRequest, "{\"Status\":400,\"Error\":\"Bad Request: a.naked.domain url does contain a valid scheme\"}\n"}, - {"bad empty verification uri", opts, nil, http.MethodGet, nil, nil, "https://some.domain.example/", " ", &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusBadRequest, "{\"Status\":400,\"Error\":\"Bad Request: %20 url does contain a valid scheme\"}\n"}, - {"bad empty verification uri verify only", opts, nil, http.MethodGet, nil, nil, "https://some.domain.example/verify", " ", &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusBadRequest, "{\"Status\":400,\"Error\":\"Bad Request: %20 url does contain a valid scheme\"}\n"}, - {"not authorized", opts, nil, http.MethodGet, nil, nil, "https://some.domain.example/", "https://some.domain.example", &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: false}, http.StatusUnauthorized, "{\"Status\":401,\"Error\":\"Unauthorized: user@test.example is not authorized for some.domain.example\"}\n"}, - {"not authorized verify endpoint", opts, nil, http.MethodGet, nil, nil, "https://some.domain.example/verify", "https://some.domain.example", &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: false}, http.StatusUnauthorized, "{\"Status\":401,\"Error\":\"Unauthorized: user@test.example is not authorized for some.domain.example\"}\n"}, - {"not authorized expired, redirect to auth", opts, sessions.ErrExpired, http.MethodGet, nil, nil, "https://some.domain.example/", "https://some.domain.example", &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: false}, http.StatusFound, ""}, - {"not authorized expired, don't redirect!", opts, sessions.ErrExpired, http.MethodGet, nil, nil, "https://some.domain.example/verify", "https://some.domain.example", &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: false}, http.StatusUnauthorized, "{\"Status\":401,\"Error\":\"Unauthorized: internal/sessions: validation failed, token is expired (exp)\"}\n"}, - {"not authorized because of error", opts, nil, http.MethodGet, nil, nil, "https://some.domain.example/", "https://some.domain.example", &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeError: errors.New("authz error")}, http.StatusInternalServerError, "{\"Status\":500,\"Error\":\"Internal Server Error: authz error\"}\n"}, - {"not authorized expired, do not redirect to auth", opts, nil, http.MethodGet, nil, nil, "https://some.domain.example/verify", "https://some.domain.example", &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: false}, http.StatusUnauthorized, "{\"Status\":401,\"Error\":\"Unauthorized: internal/sessions: validation failed, token is expired (exp)\"}\n"}, - {"not authorized, bad audience request uri", opts, nil, http.MethodGet, nil, nil, "https://some.domain.example/", "https://some.domain.example", &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Audience: []string{"not.domain.example"}, Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusUnauthorized, "{\"Status\":401,\"Error\":\"Unauthorized: internal/sessions: validation failed, invalid audience claim (aud)\"}\n"}, - {"not authorized, bad audience verify uri", opts, nil, http.MethodGet, nil, nil, "https://some.domain.example/", "https://fwdauth.domain.example", &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Audience: []string{"some.domain.example"}, Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusUnauthorized, "{\"Status\":401,\"Error\":\"Unauthorized: internal/sessions: validation failed, invalid audience claim (aud)\"}\n"}, + {"good redirect not required", opts, nil, http.MethodGet, nil, nil, "https://some.domain.example/", "https://some.domain.example", &mock.Encoder{}, &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusOK, "Access to some.domain.example is allowed."}, + {"good verify only, no redirect", opts, nil, http.MethodGet, nil, nil, "https://some.domain.example/verify", "https://some.domain.example", &mock.Encoder{}, &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusOK, ""}, + {"good redirect not required", opts, nil, http.MethodGet, nil, nil, "/", "https://some.domain.example", &mock.Encoder{}, &mstore.Store{LoadError: sessions.ErrInvalidAudience}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusUnauthorized, "{\"Status\":401,\"Error\":\"Unauthorized: internal/sessions: validation failed, invalid audience claim (aud)\"}\n"}, + {"bad naked domain uri", opts, nil, http.MethodGet, nil, nil, "https://some.domain.example/", "a.naked.domain", &mock.Encoder{}, &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusBadRequest, "{\"Status\":400,\"Error\":\"Bad Request: a.naked.domain url does contain a valid scheme\"}\n"}, + {"bad naked domain uri verify only", opts, nil, http.MethodGet, nil, nil, "https://some.domain.example/verify", "a.naked.domain", &mock.Encoder{}, &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusBadRequest, "{\"Status\":400,\"Error\":\"Bad Request: a.naked.domain url does contain a valid scheme\"}\n"}, + {"bad empty verification uri", opts, nil, http.MethodGet, nil, nil, "https://some.domain.example/", " ", &mock.Encoder{}, &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusBadRequest, "{\"Status\":400,\"Error\":\"Bad Request: %20 url does contain a valid scheme\"}\n"}, + {"bad empty verification uri verify only", opts, nil, http.MethodGet, nil, nil, "https://some.domain.example/verify", " ", &mock.Encoder{}, &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusBadRequest, "{\"Status\":400,\"Error\":\"Bad Request: %20 url does contain a valid scheme\"}\n"}, + {"not authorized", opts, nil, http.MethodGet, nil, nil, "https://some.domain.example/", "https://some.domain.example", &mock.Encoder{}, &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: false}, http.StatusUnauthorized, "{\"Status\":401,\"Error\":\"Unauthorized: user@test.example is not authorized for some.domain.example\"}\n"}, + {"not authorized verify endpoint", opts, nil, http.MethodGet, nil, nil, "https://some.domain.example/verify", "https://some.domain.example", &mock.Encoder{}, &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: false}, http.StatusUnauthorized, "{\"Status\":401,\"Error\":\"Unauthorized: user@test.example is not authorized for some.domain.example\"}\n"}, + {"not authorized expired, redirect to auth", opts, sessions.ErrExpired, http.MethodGet, nil, nil, "https://some.domain.example/", "https://some.domain.example", &mock.Encoder{}, &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: false}, http.StatusFound, ""}, + {"not authorized expired, don't redirect!", opts, sessions.ErrExpired, http.MethodGet, nil, nil, "https://some.domain.example/verify", "https://some.domain.example", &mock.Encoder{}, &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: false}, http.StatusUnauthorized, "{\"Status\":401,\"Error\":\"Unauthorized: internal/sessions: validation failed, token is expired (exp)\"}\n"}, + {"not authorized because of error", opts, nil, http.MethodGet, nil, nil, "https://some.domain.example/", "https://some.domain.example", &mock.Encoder{}, &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeError: errors.New("authz error")}, http.StatusInternalServerError, "{\"Status\":500,\"Error\":\"Internal Server Error: authz error\"}\n"}, + {"not authorized expired, do not redirect to auth", opts, nil, http.MethodGet, nil, nil, "https://some.domain.example/verify", "https://some.domain.example", &mock.Encoder{}, &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: false}, http.StatusUnauthorized, "{\"Status\":401,\"Error\":\"Unauthorized: internal/sessions: validation failed, token is expired (exp)\"}\n"}, + {"not authorized, bad audience request uri", opts, nil, http.MethodGet, nil, nil, "https://some.domain.example/", "https://some.domain.example", &mock.Encoder{}, &mstore.Store{Session: &sessions.State{Audience: []string{"not.domain.example"}, Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusUnauthorized, "{\"Status\":401,\"Error\":\"Unauthorized: internal/sessions: validation failed, invalid audience claim (aud)\"}\n"}, + {"not authorized, bad audience verify uri", opts, nil, http.MethodGet, nil, nil, "https://some.domain.example/", "https://fwdauth.domain.example", &mock.Encoder{}, &mstore.Store{Session: &sessions.State{Audience: []string{"some.domain.example"}, Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusUnauthorized, "{\"Status\":401,\"Error\":\"Unauthorized: internal/sessions: validation failed, invalid audience claim (aud)\"}\n"}, // traefik - {"good traefik callback", opts, nil, http.MethodGet, map[string]string{httputil.HeaderForwardedURI: "https://some.domain.example?" + urlutil.QuerySessionEncrypted + "=" + goodEncryptionString}, nil, "https://some.domain.example/", "https://some.domain.example", &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusFound, ""}, - {"bad traefik callback bad session", opts, nil, http.MethodGet, map[string]string{httputil.HeaderForwardedURI: "https://some.domain.example?" + urlutil.QuerySessionEncrypted + "=" + goodEncryptionString + "garbage"}, nil, "https://some.domain.example/", "https://some.domain.example", &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusBadRequest, ""}, - {"bad traefik callback bad url", opts, nil, http.MethodGet, map[string]string{httputil.HeaderForwardedURI: urlutil.QuerySessionEncrypted + ""}, nil, "https://some.domain.example/", "https://some.domain.example", &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusBadRequest, ""}, + {"good traefik callback", opts, nil, http.MethodGet, map[string]string{httputil.HeaderForwardedURI: "https://some.domain.example?" + urlutil.QuerySessionEncrypted + "=" + goodEncryptionString}, nil, "https://some.domain.example/", "https://some.domain.example", &mock.Encoder{}, &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusFound, ""}, + {"bad traefik callback bad session", opts, nil, http.MethodGet, map[string]string{httputil.HeaderForwardedURI: "https://some.domain.example?" + urlutil.QuerySessionEncrypted + "=" + goodEncryptionString + "garbage"}, nil, "https://some.domain.example/", "https://some.domain.example", &mock.Encoder{}, &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusBadRequest, ""}, + {"bad traefik callback bad url", opts, nil, http.MethodGet, map[string]string{httputil.HeaderForwardedURI: urlutil.QuerySessionEncrypted + ""}, nil, "https://some.domain.example/", "https://some.domain.example", &mock.Encoder{}, &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusBadRequest, ""}, // nginx - {"good nginx callback redirect", opts, nil, http.MethodGet, nil, map[string]string{urlutil.QueryRedirectURI: "https://some.domain.example/", urlutil.QuerySessionEncrypted: goodEncryptionString}, "https://some.domain.example/", "https://some.domain.example", &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusFound, ""}, - {"good nginx callback set session okay but return unauthorized", opts, nil, http.MethodGet, nil, map[string]string{urlutil.QueryRedirectURI: "https://some.domain.example/", urlutil.QuerySessionEncrypted: goodEncryptionString}, "https://some.domain.example/verify", "https://some.domain.example", &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusUnauthorized, ""}, - {"bad nginx callback failed to set sesion", opts, nil, http.MethodGet, nil, map[string]string{urlutil.QueryRedirectURI: "https://some.domain.example/", urlutil.QuerySessionEncrypted: goodEncryptionString + "nope"}, "https://some.domain.example/verify", "https://some.domain.example", &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusBadRequest, ""}, + {"good nginx callback redirect", opts, nil, http.MethodGet, nil, map[string]string{urlutil.QueryRedirectURI: "https://some.domain.example/", urlutil.QuerySessionEncrypted: goodEncryptionString}, "https://some.domain.example/", "https://some.domain.example", &mock.Encoder{}, &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusFound, ""}, + {"good nginx callback set session okay but return unauthorized", opts, nil, http.MethodGet, nil, map[string]string{urlutil.QueryRedirectURI: "https://some.domain.example/", urlutil.QuerySessionEncrypted: goodEncryptionString}, "https://some.domain.example/verify", "https://some.domain.example", &mock.Encoder{}, &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusUnauthorized, ""}, + {"bad nginx callback failed to set sesion", opts, nil, http.MethodGet, nil, map[string]string{urlutil.QueryRedirectURI: "https://some.domain.example/", urlutil.QuerySessionEncrypted: goodEncryptionString + "nope"}, "https://some.domain.example/verify", "https://some.domain.example", &mock.Encoder{}, &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusBadRequest, ""}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { diff --git a/proxy/handlers_test.go b/proxy/handlers_test.go index b840abf52..f05c69b15 100644 --- a/proxy/handlers_test.go +++ b/proxy/handlers_test.go @@ -12,6 +12,7 @@ import ( "time" "github.com/pomerium/pomerium/internal/cryptutil" + mstore "github.com/pomerium/pomerium/internal/sessions/mock" "github.com/pomerium/pomerium/config" "github.com/pomerium/pomerium/internal/encoding" @@ -78,10 +79,10 @@ func TestProxy_UserDashboard(t *testing.T) { wantAdminForm bool wantStatus int }{ - {"good", nil, opts, http.MethodGet, &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{}, false, http.StatusOK}, - {"session context error", errors.New("error"), opts, http.MethodGet, &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{}, false, http.StatusInternalServerError}, - {"want admin form good admin authorization", nil, opts, http.MethodGet, &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{IsAdminResponse: true}, true, http.StatusOK}, - {"is admin but authorization fails", nil, opts, http.MethodGet, &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{IsAdminError: errors.New("err")}, false, http.StatusInternalServerError}, + {"good", nil, opts, http.MethodGet, &mock.Encoder{}, &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{}, false, http.StatusOK}, + {"session context error", errors.New("error"), opts, http.MethodGet, &mock.Encoder{}, &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{}, false, http.StatusInternalServerError}, + {"want admin form good admin authorization", nil, opts, http.MethodGet, &mock.Encoder{}, &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{IsAdminResponse: true}, true, http.StatusOK}, + {"is admin but authorization fails", nil, opts, http.MethodGet, &mock.Encoder{}, &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{IsAdminError: errors.New("err")}, false, http.StatusInternalServerError}, } for _, tt := range tests { @@ -135,12 +136,12 @@ func TestProxy_Impersonate(t *testing.T) { authorizer clients.Authorizer wantStatus int }{ - {"good", false, opts, nil, http.MethodPost, "user@blah.com", "", "", &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example"}}, clients.MockAuthorize{IsAdminResponse: true}, http.StatusFound}, - {"good", false, opts, errors.New("error"), http.MethodPost, "user@blah.com", "", "", &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example"}}, clients.MockAuthorize{IsAdminResponse: true}, http.StatusInternalServerError}, - {"session load error", false, opts, nil, http.MethodPost, "user@blah.com", "", "", &mock.Encoder{}, &sessions.MockSessionStore{LoadError: errors.New("err"), Session: &sessions.State{Email: "user@test.example"}}, clients.MockAuthorize{IsAdminResponse: true}, http.StatusFound}, - {"non admin users rejected", false, opts, nil, http.MethodPost, "user@blah.com", "", "", &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute)), Email: "user@test.example"}}, clients.MockAuthorize{IsAdminResponse: false}, http.StatusForbidden}, - {"non admin users rejected on error", false, opts, nil, http.MethodPost, "user@blah.com", "", "", &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute)), Email: "user@test.example"}}, clients.MockAuthorize{IsAdminResponse: true, IsAdminError: errors.New("err")}, http.StatusInternalServerError}, - {"groups", false, opts, nil, http.MethodPost, "user@blah.com", "group1,group2", "", &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute)), Email: "user@test.example"}}, clients.MockAuthorize{IsAdminResponse: true}, http.StatusFound}, + {"good", false, opts, nil, http.MethodPost, "user@blah.com", "", "", &mock.Encoder{}, &mstore.Store{Session: &sessions.State{Email: "user@test.example"}}, clients.MockAuthorize{IsAdminResponse: true}, http.StatusFound}, + {"good", false, opts, errors.New("error"), http.MethodPost, "user@blah.com", "", "", &mock.Encoder{}, &mstore.Store{Session: &sessions.State{Email: "user@test.example"}}, clients.MockAuthorize{IsAdminResponse: true}, http.StatusInternalServerError}, + {"session load error", false, opts, nil, http.MethodPost, "user@blah.com", "", "", &mock.Encoder{}, &mstore.Store{LoadError: errors.New("err"), Session: &sessions.State{Email: "user@test.example"}}, clients.MockAuthorize{IsAdminResponse: true}, http.StatusFound}, + {"non admin users rejected", false, opts, nil, http.MethodPost, "user@blah.com", "", "", &mock.Encoder{}, &mstore.Store{Session: &sessions.State{Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute)), Email: "user@test.example"}}, clients.MockAuthorize{IsAdminResponse: false}, http.StatusForbidden}, + {"non admin users rejected on error", false, opts, nil, http.MethodPost, "user@blah.com", "", "", &mock.Encoder{}, &mstore.Store{Session: &sessions.State{Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute)), Email: "user@test.example"}}, clients.MockAuthorize{IsAdminResponse: true, IsAdminError: errors.New("err")}, http.StatusInternalServerError}, + {"groups", false, opts, nil, http.MethodPost, "user@blah.com", "group1,group2", "", &mock.Encoder{}, &mstore.Store{Session: &sessions.State{Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute)), Email: "user@test.example"}}, clients.MockAuthorize{IsAdminResponse: true}, http.StatusFound}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -245,12 +246,12 @@ func TestProxy_Callback(t *testing.T) { wantStatus int wantBody string }{ - {"good", opts, http.MethodGet, "http", "example.com", "/", nil, map[string]string{urlutil.QueryCallbackURI: "ok", urlutil.QuerySessionEncrypted: goodEncryptionString}, &mock.Encoder{MarshalResponse: []byte("x")}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusFound, ""}, - {"good programmatic", opts, http.MethodGet, "http", "example.com", "/", nil, map[string]string{urlutil.QueryIsProgrammatic: "true", urlutil.QueryCallbackURI: "ok", urlutil.QuerySessionEncrypted: goodEncryptionString}, &mock.Encoder{MarshalResponse: []byte("x")}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusFound, ""}, - {"bad decrypt", opts, http.MethodGet, "http", "example.com", "/", nil, map[string]string{urlutil.QuerySessionEncrypted: "KBEjQ9rnCxaAX-GOqexGw9ivEQURqts3zZ2mNGy0wnVa3SbtM399KlBq2nZ-9wM21FfsZX52er4jlmC7kPEKM3P7uZ41zR0zeys1-_74a5tQp-vsf1WXZfRsgVOuBcWPkMiWEoc379JFHxGDudp5VhU8B-dcQt4f3_PtLTHARkuH54io1Va2gNMq4Hiy8sQ1MPGCQeltH_JMzzdDpXdmdusWrXUvCGkba24muvAV06D8XRVJj6Iu9eK94qFnqcHc7wzziEbb8ADBues9dwbtb6jl8vMWz5rN6XvXqA5YpZv_MQZlsrO4oXFFQDevdgB84cX1tVbVu6qZvK_yQBZqzpOjWA9uIaoSENMytoXuWAlFO_sXjswfX8JTNdGwzB7qQRNPqxVG_sM_tzY3QhPm8zqwEzsXG5DokxZfVt2I5WJRUEovFDb4BnK9KFnnkEzLEdMudixVnXeGmTtycgJvoTeTCQRPfDYkcgJ7oKf4tGea-W7z5UAVa2RduJM9ZoM6YtJX7jgDm__PvvqcE0knJUF87XHBzdcOjoDF-CUze9xDJgNBlvPbJqVshKrwoqSYpePSDH9GUCNKxGequW3Ma8GvlFfhwd0rK6IZG-XWkyk0XSWQIGkDSjAvhB1wsOusCCguDjbpVZpaW5MMyTkmx68pl6qlIKT5UCcrVPl4ix5ZEj91mUDF0O1t04haD7VZuLVFXVGmqtFrBKI76sdYN-zkokaa1_chPRTyqMQFlqu_8LD6-RiK3UccGM-dEmnX72i91NP9F9OK0WJr9Cheup1C_P0mjqAO4Cb8oIHm0Oxz_mRqv5QbTGJtb3xwPLPuVjVCiE4gGBcuU2ixpSVf5HUF7y1KicVMCKiX9ATCBtg8sTdQZQnPEtHcHHAvdsnDVwev1LGfqA-Gdvg="}, &mock.Encoder{MarshalResponse: []byte("x")}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusBadRequest, ""}, - {"bad save session", opts, http.MethodGet, "http", "example.com", "/", nil, map[string]string{urlutil.QuerySessionEncrypted: goodEncryptionString}, &mock.Encoder{MarshalResponse: []byte("x")}, &sessions.MockSessionStore{SaveError: errors.New("hi")}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusBadRequest, ""}, - {"bad base64", opts, http.MethodGet, "http", "example.com", "/", nil, map[string]string{urlutil.QuerySessionEncrypted: "^"}, &mock.Encoder{MarshalResponse: []byte("x")}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusBadRequest, ""}, - {"malformed redirect", opts, http.MethodGet, "http", "example.com", "/", nil, nil, &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusBadRequest, ""}, + {"good", opts, http.MethodGet, "http", "example.com", "/", nil, map[string]string{urlutil.QueryCallbackURI: "ok", urlutil.QuerySessionEncrypted: goodEncryptionString}, &mock.Encoder{MarshalResponse: []byte("x")}, &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusFound, ""}, + {"good programmatic", opts, http.MethodGet, "http", "example.com", "/", nil, map[string]string{urlutil.QueryIsProgrammatic: "true", urlutil.QueryCallbackURI: "ok", urlutil.QuerySessionEncrypted: goodEncryptionString}, &mock.Encoder{MarshalResponse: []byte("x")}, &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusFound, ""}, + {"bad decrypt", opts, http.MethodGet, "http", "example.com", "/", nil, map[string]string{urlutil.QuerySessionEncrypted: "KBEjQ9rnCxaAX-GOqexGw9ivEQURqts3zZ2mNGy0wnVa3SbtM399KlBq2nZ-9wM21FfsZX52er4jlmC7kPEKM3P7uZ41zR0zeys1-_74a5tQp-vsf1WXZfRsgVOuBcWPkMiWEoc379JFHxGDudp5VhU8B-dcQt4f3_PtLTHARkuH54io1Va2gNMq4Hiy8sQ1MPGCQeltH_JMzzdDpXdmdusWrXUvCGkba24muvAV06D8XRVJj6Iu9eK94qFnqcHc7wzziEbb8ADBues9dwbtb6jl8vMWz5rN6XvXqA5YpZv_MQZlsrO4oXFFQDevdgB84cX1tVbVu6qZvK_yQBZqzpOjWA9uIaoSENMytoXuWAlFO_sXjswfX8JTNdGwzB7qQRNPqxVG_sM_tzY3QhPm8zqwEzsXG5DokxZfVt2I5WJRUEovFDb4BnK9KFnnkEzLEdMudixVnXeGmTtycgJvoTeTCQRPfDYkcgJ7oKf4tGea-W7z5UAVa2RduJM9ZoM6YtJX7jgDm__PvvqcE0knJUF87XHBzdcOjoDF-CUze9xDJgNBlvPbJqVshKrwoqSYpePSDH9GUCNKxGequW3Ma8GvlFfhwd0rK6IZG-XWkyk0XSWQIGkDSjAvhB1wsOusCCguDjbpVZpaW5MMyTkmx68pl6qlIKT5UCcrVPl4ix5ZEj91mUDF0O1t04haD7VZuLVFXVGmqtFrBKI76sdYN-zkokaa1_chPRTyqMQFlqu_8LD6-RiK3UccGM-dEmnX72i91NP9F9OK0WJr9Cheup1C_P0mjqAO4Cb8oIHm0Oxz_mRqv5QbTGJtb3xwPLPuVjVCiE4gGBcuU2ixpSVf5HUF7y1KicVMCKiX9ATCBtg8sTdQZQnPEtHcHHAvdsnDVwev1LGfqA-Gdvg="}, &mock.Encoder{MarshalResponse: []byte("x")}, &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusBadRequest, ""}, + {"bad save session", opts, http.MethodGet, "http", "example.com", "/", nil, map[string]string{urlutil.QuerySessionEncrypted: goodEncryptionString}, &mock.Encoder{MarshalResponse: []byte("x")}, &mstore.Store{SaveError: errors.New("hi")}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusBadRequest, ""}, + {"bad base64", opts, http.MethodGet, "http", "example.com", "/", nil, map[string]string{urlutil.QuerySessionEncrypted: "^"}, &mock.Encoder{MarshalResponse: []byte("x")}, &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusBadRequest, ""}, + {"malformed redirect", opts, http.MethodGet, "http", "example.com", "/", nil, nil, &mock.Encoder{}, &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusBadRequest, ""}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -386,12 +387,12 @@ func TestProxy_ProgrammaticCallback(t *testing.T) { wantStatus int wantBody string }{ - {"good", opts, http.MethodGet, "http://pomerium.io/", nil, map[string]string{urlutil.QueryCallbackURI: "ok", urlutil.QuerySessionEncrypted: goodEncryptionString}, &mock.Encoder{MarshalResponse: []byte("x")}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusFound, ""}, - {"good programmatic", opts, http.MethodGet, "http://pomerium.io/", nil, map[string]string{urlutil.QueryIsProgrammatic: "true", urlutil.QueryCallbackURI: "ok", urlutil.QuerySessionEncrypted: goodEncryptionString}, &mock.Encoder{MarshalResponse: []byte("x")}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusFound, ""}, - {"bad decrypt", opts, http.MethodGet, "http://pomerium.io/", nil, map[string]string{urlutil.QuerySessionEncrypted: goodEncryptionString + cryptutil.NewBase64Key()}, &mock.Encoder{MarshalResponse: []byte("x")}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusBadRequest, ""}, - {"bad save session", opts, http.MethodGet, "http://pomerium.io/", nil, map[string]string{urlutil.QuerySessionEncrypted: goodEncryptionString}, &mock.Encoder{MarshalResponse: []byte("x")}, &sessions.MockSessionStore{SaveError: errors.New("hi")}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusBadRequest, ""}, - {"bad base64", opts, http.MethodGet, "http://pomerium.io/", nil, map[string]string{urlutil.QuerySessionEncrypted: "^"}, &mock.Encoder{MarshalResponse: []byte("x")}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusBadRequest, ""}, - {"malformed redirect", opts, http.MethodGet, "http://pomerium.io/", nil, nil, &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusBadRequest, ""}, + {"good", opts, http.MethodGet, "http://pomerium.io/", nil, map[string]string{urlutil.QueryCallbackURI: "ok", urlutil.QuerySessionEncrypted: goodEncryptionString}, &mock.Encoder{MarshalResponse: []byte("x")}, &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusFound, ""}, + {"good programmatic", opts, http.MethodGet, "http://pomerium.io/", nil, map[string]string{urlutil.QueryIsProgrammatic: "true", urlutil.QueryCallbackURI: "ok", urlutil.QuerySessionEncrypted: goodEncryptionString}, &mock.Encoder{MarshalResponse: []byte("x")}, &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusFound, ""}, + {"bad decrypt", opts, http.MethodGet, "http://pomerium.io/", nil, map[string]string{urlutil.QuerySessionEncrypted: goodEncryptionString + cryptutil.NewBase64Key()}, &mock.Encoder{MarshalResponse: []byte("x")}, &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusBadRequest, ""}, + {"bad save session", opts, http.MethodGet, "http://pomerium.io/", nil, map[string]string{urlutil.QuerySessionEncrypted: goodEncryptionString}, &mock.Encoder{MarshalResponse: []byte("x")}, &mstore.Store{SaveError: errors.New("hi")}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusBadRequest, ""}, + {"bad base64", opts, http.MethodGet, "http://pomerium.io/", nil, map[string]string{urlutil.QuerySessionEncrypted: "^"}, &mock.Encoder{MarshalResponse: []byte("x")}, &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusBadRequest, ""}, + {"malformed redirect", opts, http.MethodGet, "http://pomerium.io/", nil, nil, &mock.Encoder{}, &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusBadRequest, ""}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { diff --git a/proxy/middleware.go b/proxy/middleware.go index 29bfd5bd0..129df2c66 100644 --- a/proxy/middleware.go +++ b/proxy/middleware.go @@ -1,7 +1,11 @@ package proxy // import "github.com/pomerium/pomerium/proxy" import ( + "context" + "errors" "fmt" + "io" + "io/ioutil" "net/http" "github.com/pomerium/pomerium/internal/encoding" @@ -30,23 +34,82 @@ func (p *Proxy) AuthenticateSession(next http.Handler) http.Handler { ctx, span := trace.StartSpan(r.Context(), "proxy.AuthenticateSession") defer span.End() - if s, err := sessions.FromContext(ctx); err != nil { - log.FromRequest(r).Debug().Err(err).Msg("proxy: authenticate session") - p.sessionStore.ClearSession(w, r) - if s != nil && s.Programmatic { - return httputil.NewError(http.StatusUnauthorized, err) + _, err := sessions.FromContext(ctx) + if errors.Is(err, sessions.ErrExpired) { + ctx, err = p.refresh(ctx, w, r) + if err != nil { + log.FromRequest(r).Warn().Err(err).Msg("proxy: refresh failed") + return p.redirectToSignin(w, r) } - signinURL := *p.authenticateSigninURL - q := signinURL.Query() - q.Set(urlutil.QueryRedirectURI, urlutil.GetAbsoluteURL(r).String()) - signinURL.RawQuery = q.Encode() - httputil.Redirect(w, r, urlutil.NewSignedURL(p.SharedKey, &signinURL).String(), http.StatusFound) + log.FromRequest(r).Info().Msg("proxy: refresh success") + } else if err != nil { + log.FromRequest(r).Debug().Err(err).Msg("proxy: session state") + return p.redirectToSignin(w, r) } p.addPomeriumHeaders(w, r) next.ServeHTTP(w, r.WithContext(ctx)) return nil }) +} +func (p *Proxy) refresh(ctx context.Context, w http.ResponseWriter, r *http.Request) (context.Context, error) { + ctx, span := trace.StartSpan(ctx, "proxy.AuthenticateSession/refresh") + defer span.End() + s, err := sessions.FromContext(ctx) + if !errors.Is(err, sessions.ErrExpired) || s == nil { + return nil, errors.New("proxy: unexpected session state for refresh") + } + // 1 - build a signed url to call refresh on authenticate service + refreshURI := *p.authenticateRefreshURL + q := refreshURI.Query() + q.Set("ati", s.AccessTokenID) // hash value points to parent token + q.Set("aud", urlutil.StripPort(r.Host)) // request's audience, this route + refreshURI.RawQuery = q.Encode() + signedRefreshURL := urlutil.NewSignedURL(p.SharedKey, &refreshURI).String() + + // 2 - http call to authenticate service + req, err := http.NewRequestWithContext(ctx, http.MethodGet, signedRefreshURL, nil) + if err != nil { + return nil, fmt.Errorf("proxy: backend refresh: new request: %v", err) + } + res, err := httputil.DefaultClient.Do(req) + if err != nil { + return nil, fmt.Errorf("proxy: fetch %v: %w", signedRefreshURL, err) + } + defer res.Body.Close() + jwtBytes, err := ioutil.ReadAll(io.LimitReader(res.Body, 4<<10)) + if err != nil { + return nil, err + } + + // 3 - save refreshed session to the client's session store + if err = p.sessionStore.SaveSession(w, r, jwtBytes); err != nil { + return nil, err + } + // 4 - add refreshed session to the current request context + var state sessions.State + if err := p.encoder.Unmarshal(jwtBytes, &state); err != nil { + return nil, err + } + if err := state.Verify(urlutil.StripPort(r.Host)); err != nil { + return nil, err + } + return sessions.NewContext(r.Context(), &state, err), nil +} + +func (p *Proxy) redirectToSignin(w http.ResponseWriter, r *http.Request) error { + s, err := sessions.FromContext(r.Context()) + if s != nil && err != nil && s.Programmatic { + return httputil.NewError(http.StatusUnauthorized, err) + } + p.sessionStore.ClearSession(w, r) + signinURL := *p.authenticateSigninURL + q := signinURL.Query() + q.Set(urlutil.QueryRedirectURI, urlutil.GetAbsoluteURL(r).String()) + signinURL.RawQuery = q.Encode() + log.FromRequest(r).Debug().Str("url", signinURL.String()).Msg("proxy: redirectToSignin") + httputil.Redirect(w, r, urlutil.NewSignedURL(p.SharedKey, &signinURL).String(), http.StatusFound) + return nil } func (p *Proxy) addPomeriumHeaders(w http.ResponseWriter, r *http.Request) { @@ -61,8 +124,8 @@ func (p *Proxy) addPomeriumHeaders(w http.ResponseWriter, r *http.Request) { } } -// AuthorizeSession is middleware to enforce a user is authorized for a request -// session state is retrieved from the users's request context. +// AuthorizeSession is middleware to enforce a user is authorized for a request. +// Session state is retrieved from the users's request context. func (p *Proxy) AuthorizeSession(next http.Handler) http.Handler { return httputil.HandlerFunc(func(w http.ResponseWriter, r *http.Request) error { ctx, span := trace.StartSpan(r.Context(), "proxy.AuthorizeSession") diff --git a/proxy/middleware_test.go b/proxy/middleware_test.go index 5d72335ef..06f9220a7 100644 --- a/proxy/middleware_test.go +++ b/proxy/middleware_test.go @@ -10,10 +10,14 @@ import ( "time" "github.com/google/go-cmp/cmp" - "github.com/pomerium/pomerium/internal/identity" - "github.com/pomerium/pomerium/internal/sessions" "github.com/pomerium/pomerium/proxy/clients" "gopkg.in/square/go-jose.v2/jwt" + + "github.com/pomerium/pomerium/internal/encoding" + "github.com/pomerium/pomerium/internal/encoding/mock" + "github.com/pomerium/pomerium/internal/identity" + "github.com/pomerium/pomerium/internal/sessions" + mstore "github.com/pomerium/pomerium/internal/sessions/mock" ) func TestProxy_AuthenticateSession(t *testing.T) { @@ -30,24 +34,39 @@ func TestProxy_AuthenticateSession(t *testing.T) { session sessions.SessionStore ctxError error provider identity.Authenticator + encoder encoding.MarshalUnmarshaler + refreshURL string wantStatus int }{ - {"good", false, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Second))}}, nil, identity.MockProvider{}, http.StatusOK}, - {"invalid session", false, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Second))}}, errors.New("hi"), identity.MockProvider{}, http.StatusFound}, - {"expired", false, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}}, sessions.ErrExpired, identity.MockProvider{}, http.StatusFound}, - {"expired and programmatic", false, &sessions.MockSessionStore{Session: &sessions.State{Programmatic: true, Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}}, sessions.ErrExpired, identity.MockProvider{}, http.StatusUnauthorized}, - {"invalid session and programmatic", false, &sessions.MockSessionStore{Session: &sessions.State{Programmatic: true, Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Second))}}, errors.New("hi"), identity.MockProvider{}, http.StatusUnauthorized}, + {"good", false, &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Second))}}, nil, identity.MockProvider{}, &mock.Encoder{}, "", http.StatusOK}, + {"invalid session", false, &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Second))}}, errors.New("hi"), identity.MockProvider{}, &mock.Encoder{}, "", http.StatusFound}, + {"expired", false, &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}}, sessions.ErrExpired, identity.MockProvider{}, &mock.Encoder{}, "", http.StatusOK}, + {"expired and programmatic", false, &mstore.Store{Session: &sessions.State{Programmatic: true, Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}}, sessions.ErrExpired, identity.MockProvider{}, &mock.Encoder{}, "", http.StatusOK}, + {"invalid session and programmatic", false, &mstore.Store{Session: &sessions.State{Programmatic: true, Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Second))}}, errors.New("hi"), identity.MockProvider{}, &mock.Encoder{}, "", http.StatusUnauthorized}, + {"expired and refreshed ok", false, &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}}, sessions.ErrExpired, identity.MockProvider{}, &mock.Encoder{}, "", http.StatusOK}, + {"expired and save failed", false, &mstore.Store{SaveError: errors.New("err"), Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}}, sessions.ErrExpired, identity.MockProvider{}, &mock.Encoder{}, "", http.StatusFound}, + {"expired and unmarshal failed", false, &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}}, sessions.ErrExpired, identity.MockProvider{}, &mock.Encoder{UnmarshalError: errors.New("err")}, "", http.StatusFound}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fmt.Fprintln(w, "REFRESH GOOD") + })) + defer ts.Close() + rURL := ts.URL + if tt.refreshURL != "" { + rURL = tt.refreshURL + } a := Proxy{ - SharedKey: "80ldlrU2d7w+wVpKNfevk6fmb8otEx6CqOfshj2LwhQ=", - cookieSecret: []byte("80ldlrU2d7w+wVpKNfevk6fmb8otEx6CqOfshj2LwhQ="), - authenticateURL: uriParseHelper("https://authenticate.corp.example"), - authenticateSigninURL: uriParseHelper("https://authenticate.corp.example/sign_in"), - sessionStore: tt.session, + SharedKey: "80ldlrU2d7w+wVpKNfevk6fmb8otEx6CqOfshj2LwhQ=", + cookieSecret: []byte("80ldlrU2d7w+wVpKNfevk6fmb8otEx6CqOfshj2LwhQ="), + authenticateURL: uriParseHelper("https://authenticate.corp.example"), + authenticateSigninURL: uriParseHelper("https://authenticate.corp.example/sign_in"), + authenticateRefreshURL: uriParseHelper(rURL), + sessionStore: tt.session, + encoder: tt.encoder, } r := httptest.NewRequest(http.MethodGet, "/", nil) state, _ := tt.session.LoadSession(r) @@ -82,10 +101,10 @@ func TestProxy_AuthorizeSession(t *testing.T) { wantStatus int }{ - {"user is authorized", &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Second))}}, clients.MockAuthorize{AuthorizeResponse: true}, nil, identity.MockProvider{}, http.StatusOK}, - {"user is not authorized", &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Second))}}, clients.MockAuthorize{AuthorizeResponse: false}, nil, identity.MockProvider{}, http.StatusUnauthorized}, - {"invalid session", &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Second))}}, clients.MockAuthorize{AuthorizeResponse: true}, errors.New("hi"), identity.MockProvider{}, http.StatusUnauthorized}, - {"authz client error", &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Second))}}, clients.MockAuthorize{AuthorizeError: errors.New("err")}, nil, identity.MockProvider{}, http.StatusInternalServerError}, + {"user is authorized", &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Second))}}, clients.MockAuthorize{AuthorizeResponse: true}, nil, identity.MockProvider{}, http.StatusOK}, + {"user is not authorized", &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Second))}}, clients.MockAuthorize{AuthorizeResponse: false}, nil, identity.MockProvider{}, http.StatusUnauthorized}, + {"invalid session", &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Second))}}, clients.MockAuthorize{AuthorizeResponse: true}, errors.New("hi"), identity.MockProvider{}, http.StatusUnauthorized}, + {"authz client error", &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Second))}}, clients.MockAuthorize{AuthorizeError: errors.New("err")}, nil, identity.MockProvider{}, http.StatusInternalServerError}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -143,9 +162,9 @@ func TestProxy_SignRequest(t *testing.T) { wantStatus int wantHeaders string }{ - {"good", &sessions.MockSessionStore{Session: &sessions.State{Email: "test"}}, nil, nil, http.StatusOK, "ok"}, - {"invalid session", &sessions.MockSessionStore{Session: &sessions.State{Email: "test"}}, nil, errors.New("err"), http.StatusForbidden, ""}, - {"signature failure, warn but ok", &sessions.MockSessionStore{Session: &sessions.State{Email: "test"}}, errors.New("err"), nil, http.StatusOK, ""}, + {"good", &mstore.Store{Session: &sessions.State{Email: "test"}}, nil, nil, http.StatusOK, "ok"}, + {"invalid session", &mstore.Store{Session: &sessions.State{Email: "test"}}, nil, errors.New("err"), http.StatusForbidden, ""}, + {"signature failure, warn but ok", &mstore.Store{Session: &sessions.State{Email: "test"}}, errors.New("err"), nil, http.StatusOK, ""}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { diff --git a/proxy/proxy.go b/proxy/proxy.go index 93e985c9b..3a491c6f1 100755 --- a/proxy/proxy.go +++ b/proxy/proxy.go @@ -21,6 +21,9 @@ import ( "github.com/pomerium/pomerium/internal/log" "github.com/pomerium/pomerium/internal/middleware" "github.com/pomerium/pomerium/internal/sessions" + "github.com/pomerium/pomerium/internal/sessions/cookie" + "github.com/pomerium/pomerium/internal/sessions/header" + "github.com/pomerium/pomerium/internal/sessions/queryparam" "github.com/pomerium/pomerium/internal/telemetry/metrics" "github.com/pomerium/pomerium/internal/tripper" "github.com/pomerium/pomerium/internal/urlutil" @@ -28,12 +31,11 @@ import ( ) const ( - // dashboardURL is the path to authenticate's sign in endpoint + // authenticate urls dashboardURL = "/.pomerium" - // signinURL is the path to authenticate's sign in endpoint - signinURL = "/.pomerium/sign_in" - // signoutURL is the path to authenticate's sign out endpoint - signoutURL = "/.pomerium/sign_out" + signinURL = "/.pomerium/sign_in" + signoutURL = "/.pomerium/sign_out" + refreshURL = "/.pomerium/refresh" ) // ValidateOptions checks that proper configuration settings are set to create @@ -72,12 +74,14 @@ type Proxy struct { authenticateURL *url.URL authenticateSigninURL *url.URL authenticateSignoutURL *url.URL - authorizeURL *url.URL + authenticateRefreshURL *url.URL + + authorizeURL *url.URL AuthorizeClient clients.Authorizer encoder encoding.Unmarshaler - cookieOptions *sessions.CookieOptions + cookieOptions *cookie.Options cookieSecret []byte defaultUpstreamTimeout time.Duration refreshCooldown time.Duration @@ -104,7 +108,7 @@ func New(opts config.Options) (*Proxy, error) { return nil, err } - cookieOptions := &sessions.CookieOptions{ + cookieOptions := &cookie.Options{ Name: opts.CookieName, Domain: opts.CookieDomain, Secure: opts.CookieSecure, @@ -112,7 +116,7 @@ func New(opts config.Options) (*Proxy, error) { Expire: opts.CookieExpire, } - cookieStore, err := sessions.NewCookieLoader(cookieOptions, encoder) + cookieStore, err := cookie.NewStore(cookieOptions, encoder) if err != nil { return nil, err } @@ -129,8 +133,8 @@ func New(opts config.Options) (*Proxy, error) { sessionStore: cookieStore, sessionLoaders: []sessions.SessionLoader{ cookieStore, - sessions.NewHeaderStore(encoder, "Pomerium"), - sessions.NewQueryParamStore(encoder, "pomerium_session")}, + header.NewStore(encoder, "Pomerium"), + queryparam.NewStore(encoder, "pomerium_session")}, signingKey: opts.SigningKey, templates: template.Must(frontend.NewTemplates()), } @@ -139,6 +143,7 @@ func New(opts config.Options) (*Proxy, error) { p.authenticateURL, _ = urlutil.DeepCopy(opts.AuthenticateURL) p.authenticateSigninURL = p.authenticateURL.ResolveReference(&url.URL{Path: signinURL}) p.authenticateSignoutURL = p.authenticateURL.ResolveReference(&url.URL{Path: signoutURL}) + p.authenticateRefreshURL = p.authenticateURL.ResolveReference(&url.URL{Path: refreshURL}) if err := p.UpdatePolicies(&opts); err != nil { return nil, err