authenticate/proxy: add backend refresh (#438)

This commit is contained in:
Bobby DeSimone 2019-12-30 10:47:54 -08:00 committed by GitHub
parent 9a330613aa
commit ec029c679b
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
35 changed files with 1226 additions and 445 deletions

View file

@ -198,11 +198,6 @@ issues:
linters: linters:
- staticcheck - staticcheck
# todo(bdd): replace in go 1.13
- path: proxy/proxy.go
text: "copylocks: assignment copies lock value to transport"
linters:
- govet
# Independently from option `exclude` we use default exclude patterns, # Independently from option `exclude` we use default exclude patterns,
# it can be disabled by this option. To list all # it can be disabled by this option. To list all
# excluded by default patterns execute `golangci-lint run --help`. # excluded by default patterns execute `golangci-lint run --help`.

View file

@ -16,6 +16,10 @@ import (
"github.com/pomerium/pomerium/internal/frontend" "github.com/pomerium/pomerium/internal/frontend"
"github.com/pomerium/pomerium/internal/identity" "github.com/pomerium/pomerium/internal/identity"
"github.com/pomerium/pomerium/internal/sessions" "github.com/pomerium/pomerium/internal/sessions"
"github.com/pomerium/pomerium/internal/sessions/cache"
"github.com/pomerium/pomerium/internal/sessions/cookie"
"github.com/pomerium/pomerium/internal/sessions/header"
"github.com/pomerium/pomerium/internal/sessions/queryparam"
"github.com/pomerium/pomerium/internal/urlutil" "github.com/pomerium/pomerium/internal/urlutil"
) )
@ -49,6 +53,8 @@ type Authenticate struct {
// authentication flow // authentication flow
RedirectURL *url.URL RedirectURL *url.URL
// values related to cross service communication
//
// sharedKey is used to encrypt and authenticate data between services // sharedKey is used to encrypt and authenticate data between services
sharedKey string sharedKey string
// sharedCipher is used to encrypt data for use between services // sharedCipher is used to encrypt data for use between services
@ -57,15 +63,20 @@ type Authenticate struct {
// by other services // by other services
sharedEncoder encoding.MarshalUnmarshaler sharedEncoder encoding.MarshalUnmarshaler
// data related to this service only // values related to user sessions
cookieOptions *sessions.CookieOptions //
// cookieSecret is the secret to encrypt and authenticate data for this service // cookieSecret is the secret to encrypt and authenticate session data
cookieSecret []byte cookieSecret []byte
// is the cipher to use to encrypt data for this service // cookieCipher is the cipher to use to encrypt/decrypt session data
cookieCipher cipher.AEAD cookieCipher cipher.AEAD
sessionStore sessions.SessionStore // encryptedEncoder is the encoder used to marshal and unmarshal session data
encryptedEncoder encoding.MarshalUnmarshaler encryptedEncoder encoding.MarshalUnmarshaler
sessionStores []sessions.SessionStore // sessionStore is the session store used to persist a user's session
sessionStore sessions.SessionStore
cookieOptions *cookie.Options
// sessionLoaders are a collection of session loaders to attempt to pull
// a user's session state from
sessionLoaders []sessions.SessionLoader sessionLoaders []sessions.SessionLoader
// provider is the interface to interacting with the identity provider (IdP) // provider is the interface to interacting with the identity provider (IdP)
@ -92,7 +103,7 @@ func New(opts config.Options) (*Authenticate, error) {
cookieCipher, _ := cryptutil.NewAEADCipher(decodedCookieSecret) cookieCipher, _ := cryptutil.NewAEADCipher(decodedCookieSecret)
encryptedEncoder := ecjson.New(cookieCipher) encryptedEncoder := ecjson.New(cookieCipher)
cookieOptions := &sessions.CookieOptions{ cookieOptions := &cookie.Options{
Name: opts.CookieName, Name: opts.CookieName,
Domain: opts.CookieDomain, Domain: opts.CookieDomain,
Secure: opts.CookieSecure, Secure: opts.CookieSecure,
@ -100,12 +111,13 @@ func New(opts config.Options) (*Authenticate, error) {
Expire: opts.CookieExpire, Expire: opts.CookieExpire,
} }
cookieStore, err := sessions.NewCookieStore(cookieOptions, encryptedEncoder) cookieStore, err := cookie.NewStore(cookieOptions, encryptedEncoder)
if err != nil { if err != nil {
return nil, err return nil, err
} }
qpStore := sessions.NewQueryParamStore(encryptedEncoder, "pomerium_programmatic_token") cacheStore := cache.NewStore(encryptedEncoder, cookieStore, opts.CookieName)
headerStore := sessions.NewHeaderStore(encryptedEncoder, "Pomerium") qpStore := queryparam.NewStore(encryptedEncoder, "pomerium_programmatic_token")
headerStore := header.NewStore(encryptedEncoder, "Pomerium")
redirectURL, _ := urlutil.DeepCopy(opts.AuthenticateURL) redirectURL, _ := urlutil.DeepCopy(opts.AuthenticateURL)
redirectURL.Path = callbackPath redirectURL.Path = callbackPath
@ -135,10 +147,9 @@ func New(opts config.Options) (*Authenticate, error) {
cookieSecret: decodedCookieSecret, cookieSecret: decodedCookieSecret,
cookieCipher: cookieCipher, cookieCipher: cookieCipher,
cookieOptions: cookieOptions, cookieOptions: cookieOptions,
sessionStore: cookieStore, sessionStore: cacheStore,
encryptedEncoder: encryptedEncoder, encryptedEncoder: encryptedEncoder,
sessionLoaders: []sessions.SessionLoader{qpStore, headerStore, cookieStore}, sessionLoaders: []sessions.SessionLoader{cacheStore, qpStore, headerStore, cookieStore},
sessionStores: []sessions.SessionStore{cookieStore, qpStore},
// IdP // IdP
provider: provider, provider: provider,

View file

@ -72,15 +72,18 @@ func TestOptions_Validate(t *testing.T) {
func TestNew(t *testing.T) { func TestNew(t *testing.T) {
good := newTestOptions(t) good := newTestOptions(t)
good.CookieName = "A"
badRedirectURL := newTestOptions(t) badRedirectURL := newTestOptions(t)
badRedirectURL.AuthenticateURL = nil badRedirectURL.AuthenticateURL = nil
badRedirectURL.CookieName = "B"
badCookieName := newTestOptions(t) badCookieName := newTestOptions(t)
badCookieName.CookieName = "" badCookieName.CookieName = ""
badProvider := newTestOptions(t) badProvider := newTestOptions(t)
badProvider.Provider = "" badProvider.Provider = ""
badProvider.CookieName = "C"
tests := []struct { tests := []struct {
name string name string

View file

@ -1,6 +1,7 @@
package authenticate // import "github.com/pomerium/pomerium/authenticate" package authenticate // import "github.com/pomerium/pomerium/authenticate"
import ( import (
"context"
"encoding/base64" "encoding/base64"
"encoding/json" "encoding/json"
"errors" "errors"
@ -18,6 +19,7 @@ import (
"github.com/pomerium/pomerium/internal/log" "github.com/pomerium/pomerium/internal/log"
"github.com/pomerium/pomerium/internal/middleware" "github.com/pomerium/pomerium/internal/middleware"
"github.com/pomerium/pomerium/internal/sessions" "github.com/pomerium/pomerium/internal/sessions"
"github.com/pomerium/pomerium/internal/telemetry/trace"
"github.com/pomerium/pomerium/internal/urlutil" "github.com/pomerium/pomerium/internal/urlutil"
) )
@ -58,6 +60,7 @@ func (a *Authenticate) Handler() http.Handler {
v.Use(a.VerifySession) v.Use(a.VerifySession)
v.Path("/sign_in").Handler(httputil.HandlerFunc(a.SignIn)) v.Path("/sign_in").Handler(httputil.HandlerFunc(a.SignIn))
v.Path("/sign_out").Handler(httputil.HandlerFunc(a.SignOut)) v.Path("/sign_out").Handler(httputil.HandlerFunc(a.SignOut))
v.Path("/refresh").Handler(httputil.HandlerFunc(a.Refresh)).Methods(http.MethodGet)
// programmatic access api endpoint // programmatic access api endpoint
api := r.PathPrefix("/api").Subrouter() api := r.PathPrefix("/api").Subrouter()
@ -73,12 +76,12 @@ func (a *Authenticate) VerifySession(next http.Handler) http.Handler {
return httputil.HandlerFunc(func(w http.ResponseWriter, r *http.Request) error { return httputil.HandlerFunc(func(w http.ResponseWriter, r *http.Request) error {
state, err := sessions.FromContext(r.Context()) state, err := sessions.FromContext(r.Context())
if errors.Is(err, sessions.ErrExpired) { if errors.Is(err, sessions.ErrExpired) {
if err := a.refresh(w, r, state); err != nil { ctx, err := a.refresh(w, r, state)
if err != nil {
log.FromRequest(r).Info().Err(err).Msg("authenticate: verify session, refresh") log.FromRequest(r).Info().Err(err).Msg("authenticate: verify session, refresh")
return a.reauthenticateOrFail(w, r, err) return a.reauthenticateOrFail(w, r, err)
} }
// redirect to restart middleware-chain following refresh next.ServeHTTP(w, r.WithContext(ctx))
httputil.Redirect(w, r, urlutil.GetAbsoluteURL(r).String(), http.StatusFound)
return nil return nil
} else if err != nil { } else if err != nil {
log.FromRequest(r).Info().Err(err).Msg("authenticate: verify session") log.FromRequest(r).Info().Err(err).Msg("authenticate: verify session")
@ -89,15 +92,18 @@ func (a *Authenticate) VerifySession(next http.Handler) http.Handler {
}) })
} }
func (a *Authenticate) refresh(w http.ResponseWriter, r *http.Request, s *sessions.State) error { func (a *Authenticate) refresh(w http.ResponseWriter, r *http.Request, s *sessions.State) (context.Context, error) {
newSession, err := a.provider.Refresh(r.Context(), s) ctx, span := trace.StartSpan(r.Context(), "authenticate.VerifySession/refresh")
defer span.End()
newSession, err := a.provider.Refresh(ctx, s)
if err != nil { if err != nil {
return fmt.Errorf("authenticate: refresh failed: %w", err) return nil, fmt.Errorf("authenticate: refresh failed: %w", err)
} }
if err := a.sessionStore.SaveSession(w, r, newSession); err != nil { if err := a.sessionStore.SaveSession(w, r, newSession); err != nil {
return fmt.Errorf("authenticate: refresh save failed: %w", err) return nil, fmt.Errorf("authenticate: refresh save failed: %w", err)
} }
return nil // return the new session and add it to the current request context
return sessions.NewContext(ctx, newSession, err), nil
} }
// RobotsTxt handles the /robots.txt route. // RobotsTxt handles the /robots.txt route.
@ -158,7 +164,6 @@ func (a *Authenticate) SignIn(w http.ResponseWriter, r *http.Request) error {
encSession, err := a.encryptedEncoder.Marshal(newSession) encSession, err := a.encryptedEncoder.Marshal(newSession)
if err != nil { if err != nil {
return httputil.NewError(http.StatusBadRequest, err) return httputil.NewError(http.StatusBadRequest, err)
} }
callbackParams.Set(urlutil.QueryRefreshToken, string(encSession)) callbackParams.Set(urlutil.QueryRefreshToken, string(encSession))
callbackParams.Set(urlutil.QueryIsProgrammatic, "true") callbackParams.Set(urlutil.QueryIsProgrammatic, "true")
@ -345,3 +350,27 @@ func (a *Authenticate) RefreshAPI(w http.ResponseWriter, r *http.Request) error
w.Write(jsonResponse) w.Write(jsonResponse)
return nil return nil
} }
// Refresh is called by the proxy service to handle backend session refresh.
//
// NOTE: The actual refresh is actually handled as part of the "VerifySession"
// middleware. This handler is responsible for creating a new route scoped
// session and returning it.
func (a *Authenticate) Refresh(w http.ResponseWriter, r *http.Request) error {
s, err := sessions.FromContext(r.Context())
if err != nil {
return httputil.NewError(http.StatusBadRequest, err)
}
routeSession := s.NewSession(r.Host, []string{r.Host, r.FormValue("aud")})
routeSession.AccessTokenID = s.AccessTokenID
signedJWT, err := a.sharedEncoder.Marshal(routeSession.RouteSession())
if err != nil {
return err
}
w.Header().Set("Content-Type", "application/jwt") // RFC 7519 : 10.3.1
w.Write(signedJWT)
return nil
}

View file

@ -11,17 +11,18 @@ import (
"testing" "testing"
"time" "time"
"github.com/pomerium/pomerium/internal/httputil"
"github.com/pomerium/pomerium/internal/cryptutil" "github.com/pomerium/pomerium/internal/cryptutil"
"github.com/pomerium/pomerium/internal/encoding" "github.com/pomerium/pomerium/internal/encoding"
"github.com/pomerium/pomerium/internal/encoding/mock" "github.com/pomerium/pomerium/internal/encoding/mock"
"github.com/pomerium/pomerium/internal/frontend" "github.com/pomerium/pomerium/internal/frontend"
"github.com/pomerium/pomerium/internal/httputil"
"github.com/pomerium/pomerium/internal/identity" "github.com/pomerium/pomerium/internal/identity"
"github.com/pomerium/pomerium/internal/sessions" "github.com/pomerium/pomerium/internal/sessions"
"github.com/pomerium/pomerium/internal/urlutil" "github.com/pomerium/pomerium/internal/sessions/cookie"
mstore "github.com/pomerium/pomerium/internal/sessions/mock"
"github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp"
"github.com/pomerium/pomerium/internal/urlutil"
"golang.org/x/crypto/chacha20poly1305" "golang.org/x/crypto/chacha20poly1305"
"golang.org/x/oauth2" "golang.org/x/oauth2"
"gopkg.in/square/go-jose.v2/jwt" "gopkg.in/square/go-jose.v2/jwt"
@ -32,7 +33,7 @@ func testAuthenticate() *Authenticate {
auth.RedirectURL, _ = url.Parse("https://auth.example.com/oauth/callback") auth.RedirectURL, _ = url.Parse("https://auth.example.com/oauth/callback")
auth.sharedKey = cryptutil.NewBase64Key() auth.sharedKey = cryptutil.NewBase64Key()
auth.cookieSecret = cryptutil.NewKey() auth.cookieSecret = cryptutil.NewKey()
auth.cookieOptions = &sessions.CookieOptions{Name: "name"} auth.cookieOptions = &cookie.Options{Name: "name"}
auth.templates = template.Must(frontend.NewTemplates()) auth.templates = template.Must(frontend.NewTemplates())
return &auth return &auth
} }
@ -112,19 +113,19 @@ func TestAuthenticate_SignIn(t *testing.T) {
encoder encoding.MarshalUnmarshaler encoder encoding.MarshalUnmarshaler
wantCode int wantCode int
}{ }{
{"good", "https", "corp.example.example", map[string]string{urlutil.QueryRedirectURI: "https://dst.some.example/"}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusFound}, {"good", "https", "corp.example.example", map[string]string{urlutil.QueryRedirectURI: "https://dst.some.example/"}, &mstore.Store{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusFound},
{"session not valid", "https", "corp.example.example", map[string]string{urlutil.QueryRedirectURI: "https://dst.some.example/"}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(-10 * time.Second)}}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusFound}, {"session not valid", "https", "corp.example.example", map[string]string{urlutil.QueryRedirectURI: "https://dst.some.example/"}, &mstore.Store{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(-10 * time.Second)}}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusFound},
{"bad redirect uri query", "", "corp.example.example", map[string]string{urlutil.QueryRedirectURI: "^^^"}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusBadRequest}, {"bad redirect uri query", "", "corp.example.example", map[string]string{urlutil.QueryRedirectURI: "^^^"}, &mstore.Store{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusBadRequest},
{"bad marshal", "https", "corp.example.example", map[string]string{urlutil.QueryRedirectURI: "https://dst.some.example/"}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, identity.MockProvider{}, &mock.Encoder{MarshalError: errors.New("error")}, http.StatusBadRequest}, {"bad marshal", "https", "corp.example.example", map[string]string{urlutil.QueryRedirectURI: "https://dst.some.example/"}, &mstore.Store{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, identity.MockProvider{}, &mock.Encoder{MarshalError: errors.New("error")}, http.StatusBadRequest},
{"session error", "https", "corp.example.example", map[string]string{urlutil.QueryRedirectURI: "https://dst.some.example/"}, &sessions.MockSessionStore{LoadError: errors.New("error")}, identity.MockProvider{}, &mock.Encoder{}, http.StatusBadRequest}, {"session error", "https", "corp.example.example", map[string]string{urlutil.QueryRedirectURI: "https://dst.some.example/"}, &mstore.Store{LoadError: errors.New("error")}, identity.MockProvider{}, &mock.Encoder{}, http.StatusBadRequest},
{"good with different programmatic redirect", "https", "corp.example.example", map[string]string{urlutil.QueryRedirectURI: "https://dst.some.example/", urlutil.QueryCallbackURI: "https://some.example"}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusFound}, {"good with different programmatic redirect", "https", "corp.example.example", map[string]string{urlutil.QueryRedirectURI: "https://dst.some.example/", urlutil.QueryCallbackURI: "https://some.example"}, &mstore.Store{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusFound},
{"encrypted encoder error", "https", "corp.example.example", map[string]string{urlutil.QueryRedirectURI: "https://dst.some.example/", urlutil.QueryCallbackURI: "https://some.example"}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, identity.MockProvider{}, &mock.Encoder{MarshalError: errors.New("error")}, http.StatusBadRequest}, {"encrypted encoder error", "https", "corp.example.example", map[string]string{urlutil.QueryRedirectURI: "https://dst.some.example/", urlutil.QueryCallbackURI: "https://some.example"}, &mstore.Store{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, identity.MockProvider{}, &mock.Encoder{MarshalError: errors.New("error")}, http.StatusBadRequest},
{"good with callback uri set", "https", "corp.example.example", map[string]string{urlutil.QueryCallbackURI: "https://some.example/", urlutil.QueryRedirectURI: "https://dst.some.example/"}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusFound}, {"good with callback uri set", "https", "corp.example.example", map[string]string{urlutil.QueryCallbackURI: "https://some.example/", urlutil.QueryRedirectURI: "https://dst.some.example/"}, &mstore.Store{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusFound},
{"bad callback uri set", "https", "corp.example.example", map[string]string{urlutil.QueryCallbackURI: "^", urlutil.QueryRedirectURI: "https://dst.some.example/"}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusBadRequest}, {"bad callback uri set", "https", "corp.example.example", map[string]string{urlutil.QueryCallbackURI: "^", urlutil.QueryRedirectURI: "https://dst.some.example/"}, &mstore.Store{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusBadRequest},
{"good programmatic request", "https", "corp.example.example", map[string]string{urlutil.QueryIsProgrammatic: "true", urlutil.QueryRedirectURI: "https://dst.some.example/"}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusFound}, {"good programmatic request", "https", "corp.example.example", map[string]string{urlutil.QueryIsProgrammatic: "true", urlutil.QueryRedirectURI: "https://dst.some.example/"}, &mstore.Store{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusFound},
{"good additional audience", "https", "corp.example.example", map[string]string{urlutil.QueryForwardAuth: "x.y.z", urlutil.QueryRedirectURI: "https://dst.some.example/"}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusFound}, {"good additional audience", "https", "corp.example.example", map[string]string{urlutil.QueryForwardAuth: "x.y.z", urlutil.QueryRedirectURI: "https://dst.some.example/"}, &mstore.Store{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusFound},
{"good user impersonate", "https", "corp.example.example", map[string]string{urlutil.QueryImpersonateAction: "set", urlutil.QueryRedirectURI: "https://dst.some.example/"}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusFound}, {"good user impersonate", "https", "corp.example.example", map[string]string{urlutil.QueryImpersonateAction: "set", urlutil.QueryRedirectURI: "https://dst.some.example/"}, &mstore.Store{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusFound},
{"bad user impersonate save failure", "https", "corp.example.example", map[string]string{urlutil.QueryImpersonateAction: "set", urlutil.QueryRedirectURI: "https://dst.some.example/"}, &sessions.MockSessionStore{SaveError: errors.New("err"), Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusBadRequest}, {"bad user impersonate save failure", "https", "corp.example.example", map[string]string{urlutil.QueryImpersonateAction: "set", urlutil.QueryRedirectURI: "https://dst.some.example/"}, &mstore.Store{SaveError: errors.New("err"), Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusBadRequest},
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
@ -136,7 +137,7 @@ func TestAuthenticate_SignIn(t *testing.T) {
sharedEncoder: tt.encoder, sharedEncoder: tt.encoder,
encryptedEncoder: tt.encoder, encryptedEncoder: tt.encoder,
sharedCipher: aead, sharedCipher: aead,
cookieOptions: &sessions.CookieOptions{ cookieOptions: &cookie.Options{
Name: "cookie", Name: "cookie",
Domain: "foo", Domain: "foo",
}, },
@ -186,10 +187,10 @@ func TestAuthenticate_SignOut(t *testing.T) {
wantCode int wantCode int
wantBody string wantBody string
}{ }{
{"good post", http.MethodPost, nil, "https://corp.pomerium.io/", "sig", "ts", identity.MockProvider{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, http.StatusFound, ""}, {"good post", http.MethodPost, nil, "https://corp.pomerium.io/", "sig", "ts", identity.MockProvider{}, &mstore.Store{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, http.StatusFound, ""},
{"failed revoke", http.MethodPost, nil, "https://corp.pomerium.io/", "sig", "ts", identity.MockProvider{RevokeError: errors.New("OH NO")}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, http.StatusBadRequest, "{\"Status\":400,\"Error\":\"Bad Request: OH NO\"}\n"}, {"failed revoke", http.MethodPost, nil, "https://corp.pomerium.io/", "sig", "ts", identity.MockProvider{RevokeError: errors.New("OH NO")}, &mstore.Store{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, http.StatusBadRequest, "{\"Status\":400,\"Error\":\"Bad Request: OH NO\"}\n"},
{"load session error", http.MethodPost, errors.New("error"), "https://corp.pomerium.io/", "sig", "ts", identity.MockProvider{RevokeError: errors.New("OH NO")}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, http.StatusBadRequest, "{\"Status\":400,\"Error\":\"Bad Request: error\"}\n"}, {"load session error", http.MethodPost, errors.New("error"), "https://corp.pomerium.io/", "sig", "ts", identity.MockProvider{RevokeError: errors.New("OH NO")}, &mstore.Store{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, http.StatusBadRequest, "{\"Status\":400,\"Error\":\"Bad Request: error\"}\n"},
{"bad redirect uri", http.MethodPost, nil, "corp.pomerium.io/", "sig", "ts", identity.MockProvider{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, http.StatusBadRequest, "{\"Status\":400,\"Error\":\"Bad Request: corp.pomerium.io/ url does contain a valid scheme\"}\n"}, {"bad redirect uri", http.MethodPost, nil, "corp.pomerium.io/", "sig", "ts", identity.MockProvider{}, &mstore.Store{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, http.StatusBadRequest, "{\"Status\":400,\"Error\":\"Bad Request: corp.pomerium.io/ url does contain a valid scheme\"}\n"},
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
@ -247,19 +248,19 @@ func TestAuthenticate_OAuthCallback(t *testing.T) {
want string want string
wantCode int wantCode int
}{ }{
{"good", http.MethodGet, time.Now().Unix(), "", "", "", "", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, "https://corp.pomerium.io", http.StatusFound}, {"good", http.MethodGet, time.Now().Unix(), "", "", "", "", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &mstore.Store{}, identity.MockProvider{AuthenticateResponse: sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, "https://corp.pomerium.io", http.StatusFound},
{"failed authenticate", http.MethodGet, time.Now().Unix(), "", "", "", "", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateError: errors.New("error")}, "", http.StatusInternalServerError}, {"failed authenticate", http.MethodGet, time.Now().Unix(), "", "", "", "", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &mstore.Store{}, identity.MockProvider{AuthenticateError: errors.New("error")}, "", http.StatusInternalServerError},
{"failed save session", http.MethodGet, time.Now().Unix(), "", "", "", "", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &sessions.MockSessionStore{SaveError: errors.New("error")}, identity.MockProvider{AuthenticateResponse: sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, "", http.StatusInternalServerError}, {"failed save session", http.MethodGet, time.Now().Unix(), "", "", "", "", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &mstore.Store{SaveError: errors.New("error")}, identity.MockProvider{AuthenticateResponse: sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, "", http.StatusInternalServerError},
{"provider returned error", http.MethodGet, time.Now().Unix(), "", "", "", "idp error", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, "", http.StatusBadRequest}, {"provider returned error", http.MethodGet, time.Now().Unix(), "", "", "", "idp error", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &mstore.Store{}, identity.MockProvider{AuthenticateResponse: sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, "", http.StatusBadRequest},
{"empty code", http.MethodGet, time.Now().Unix(), "", "", "", "", "", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, "", http.StatusBadRequest}, {"empty code", http.MethodGet, time.Now().Unix(), "", "", "", "", "", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &mstore.Store{}, identity.MockProvider{AuthenticateResponse: sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, "", http.StatusBadRequest},
{"invalid redirect uri", http.MethodGet, time.Now().Unix(), "", "", "", "", "code", "corp.pomerium.io", "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, "", http.StatusBadRequest}, {"invalid redirect uri", http.MethodGet, time.Now().Unix(), "", "", "", "", "code", "corp.pomerium.io", "https://authenticate.pomerium.io", &mstore.Store{}, identity.MockProvider{AuthenticateResponse: sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, "", http.StatusBadRequest},
{"bad redirect uri", http.MethodGet, time.Now().Unix(), "", "", "", "", "code", "http://^^^", "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, "https://corp.pomerium.io", http.StatusBadRequest}, {"bad redirect uri", http.MethodGet, time.Now().Unix(), "", "", "", "", "code", "http://^^^", "https://authenticate.pomerium.io", &mstore.Store{}, identity.MockProvider{AuthenticateResponse: sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, "https://corp.pomerium.io", http.StatusBadRequest},
{"bad timing - too soon", http.MethodGet, time.Now().Add(1 * time.Hour).Unix(), "", "", "", "", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, "https://corp.pomerium.io", http.StatusBadRequest}, {"bad timing - too soon", http.MethodGet, time.Now().Add(1 * time.Hour).Unix(), "", "", "", "", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &mstore.Store{}, identity.MockProvider{AuthenticateResponse: sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, "https://corp.pomerium.io", http.StatusBadRequest},
{"bad timing - expired", http.MethodGet, time.Now().Add(-1 * time.Hour).Unix(), "", "", "", "", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, "https://corp.pomerium.io", http.StatusBadRequest}, {"bad timing - expired", http.MethodGet, time.Now().Add(-1 * time.Hour).Unix(), "", "", "", "", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &mstore.Store{}, identity.MockProvider{AuthenticateResponse: sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, "https://corp.pomerium.io", http.StatusBadRequest},
{"bad base64", http.MethodGet, time.Now().Unix(), "", "", "^", "", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, "https://corp.pomerium.io", http.StatusBadRequest}, {"bad base64", http.MethodGet, time.Now().Unix(), "", "", "^", "", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &mstore.Store{}, identity.MockProvider{AuthenticateResponse: sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, "https://corp.pomerium.io", http.StatusBadRequest},
{"too many seperators", http.MethodGet, time.Now().Unix(), "", "", "|ok|now|what", "", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, "https://corp.pomerium.io", http.StatusBadRequest}, {"too many seperators", http.MethodGet, time.Now().Unix(), "", "", "|ok|now|what", "", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &mstore.Store{}, identity.MockProvider{AuthenticateResponse: sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, "https://corp.pomerium.io", http.StatusBadRequest},
{"bad hmac", http.MethodGet, time.Now().Unix(), "", "NOTMAC", "", "", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, "https://corp.pomerium.io", http.StatusBadRequest}, {"bad hmac", http.MethodGet, time.Now().Unix(), "", "NOTMAC", "", "", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &mstore.Store{}, identity.MockProvider{AuthenticateResponse: sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, "https://corp.pomerium.io", http.StatusBadRequest},
{"bad hmac", http.MethodGet, time.Now().Unix(), base64.URLEncoding.EncodeToString([]byte("malformed_state")), "", "", "", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, "https://corp.pomerium.io", http.StatusBadRequest}, {"bad hmac", http.MethodGet, time.Now().Unix(), base64.URLEncoding.EncodeToString([]byte("malformed_state")), "", "", "", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &mstore.Store{}, identity.MockProvider{AuthenticateResponse: sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, "https://corp.pomerium.io", http.StatusBadRequest},
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
@ -326,12 +327,12 @@ func TestAuthenticate_SessionValidatorMiddleware(t *testing.T) {
wantStatus int wantStatus int
}{ }{
{"good", nil, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, nil, identity.MockProvider{RefreshResponse: sessions.State{AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Minute)}}}, http.StatusOK}, {"good", nil, &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, nil, identity.MockProvider{RefreshResponse: sessions.State{AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Minute)}}}, http.StatusOK},
{"invalid session", nil, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, errors.New("hi"), identity.MockProvider{}, http.StatusFound}, {"invalid session", nil, &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, errors.New("hi"), identity.MockProvider{}, http.StatusFound},
{"good refresh expired", nil, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}}, sessions.ErrExpired, identity.MockProvider{RefreshResponse: sessions.State{AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Minute)}}}, http.StatusFound}, {"good refresh expired", nil, &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}}, sessions.ErrExpired, identity.MockProvider{RefreshResponse: sessions.State{AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Minute)}}}, http.StatusOK},
{"expired,refresh error", nil, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}}, sessions.ErrExpired, identity.MockProvider{RefreshError: errors.New("error")}, http.StatusFound}, {"expired,refresh error", nil, &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}}, sessions.ErrExpired, identity.MockProvider{RefreshError: errors.New("error")}, http.StatusFound},
{"expired,save error", nil, &sessions.MockSessionStore{SaveError: errors.New("error"), Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}}, sessions.ErrExpired, identity.MockProvider{RefreshResponse: sessions.State{AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Minute)}}}, http.StatusFound}, {"expired,save error", nil, &mstore.Store{SaveError: errors.New("error"), Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}}, sessions.ErrExpired, identity.MockProvider{RefreshResponse: sessions.State{AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Minute)}}}, http.StatusFound},
{"expired XHR,refresh error", map[string]string{"X-Requested-With": "XmlHttpRequest"}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}}, sessions.ErrExpired, identity.MockProvider{RefreshError: errors.New("error")}, http.StatusUnauthorized}, {"expired XHR,refresh error", map[string]string{"X-Requested-With": "XmlHttpRequest"}, &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}}, sessions.ErrExpired, identity.MockProvider{RefreshError: errors.New("error")}, http.StatusUnauthorized},
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
@ -384,11 +385,11 @@ func TestAuthenticate_RefreshAPI(t *testing.T) {
wantStatus int wantStatus int
}{ }{
{"good", &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, nil, identity.MockProvider{RefreshResponse: sessions.State{AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Minute)}}}, mock.Encoder{MarshalResponse: []byte("ok")}, mock.Encoder{MarshalResponse: []byte("ok")}, http.StatusOK}, {"good", &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, nil, identity.MockProvider{RefreshResponse: sessions.State{AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Minute)}}}, mock.Encoder{MarshalResponse: []byte("ok")}, mock.Encoder{MarshalResponse: []byte("ok")}, http.StatusOK},
{"refresh error", &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, nil, identity.MockProvider{RefreshError: errors.New("error")}, mock.Encoder{MarshalResponse: []byte("ok")}, mock.Encoder{MarshalResponse: []byte("ok")}, http.StatusInternalServerError}, {"refresh error", &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, nil, identity.MockProvider{RefreshError: errors.New("error")}, mock.Encoder{MarshalResponse: []byte("ok")}, mock.Encoder{MarshalResponse: []byte("ok")}, http.StatusInternalServerError},
{"session is not refreshable error", &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, errors.New("session error"), identity.MockProvider{RefreshResponse: sessions.State{AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Minute)}}}, mock.Encoder{MarshalResponse: []byte("ok")}, mock.Encoder{MarshalResponse: []byte("ok")}, http.StatusBadRequest}, {"session is not refreshable error", &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, errors.New("session error"), identity.MockProvider{RefreshResponse: sessions.State{AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Minute)}}}, mock.Encoder{MarshalResponse: []byte("ok")}, mock.Encoder{MarshalResponse: []byte("ok")}, http.StatusBadRequest},
{"secret encoder failed", &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, nil, identity.MockProvider{RefreshResponse: sessions.State{AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Minute)}}}, mock.Encoder{MarshalError: errors.New("error")}, mock.Encoder{MarshalResponse: []byte("ok")}, http.StatusInternalServerError}, {"secret encoder failed", &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, nil, identity.MockProvider{RefreshResponse: sessions.State{AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Minute)}}}, mock.Encoder{MarshalError: errors.New("error")}, mock.Encoder{MarshalResponse: []byte("ok")}, http.StatusInternalServerError},
{"shared encoder failed", &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, nil, identity.MockProvider{RefreshResponse: sessions.State{AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Minute)}}}, mock.Encoder{MarshalResponse: []byte("ok")}, mock.Encoder{MarshalError: errors.New("error")}, http.StatusInternalServerError}, {"shared encoder failed", &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, nil, identity.MockProvider{RefreshResponse: sessions.State{AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Minute)}}}, mock.Encoder{MarshalResponse: []byte("ok")}, mock.Encoder{MarshalError: errors.New("error")}, http.StatusInternalServerError},
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
@ -423,3 +424,54 @@ func TestAuthenticate_RefreshAPI(t *testing.T) {
}) })
} }
} }
func TestAuthenticate_Refresh(t *testing.T) {
t.Parallel()
tests := []struct {
name string
session sessions.SessionStore
ctxError error
provider identity.Authenticator
secretEncoder encoding.MarshalUnmarshaler
sharedEncoder encoding.MarshalUnmarshaler
wantStatus int
}{
{"good", &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, nil, identity.MockProvider{RefreshResponse: sessions.State{AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Minute)}}}, mock.Encoder{MarshalResponse: []byte("ok")}, mock.Encoder{MarshalResponse: []byte("ok")}, http.StatusOK},
{"bad session", &mstore.Store{}, errors.New("err"), identity.MockProvider{RefreshResponse: sessions.State{AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Minute)}}}, mock.Encoder{MarshalResponse: []byte("ok")}, mock.Encoder{MarshalResponse: []byte("ok")}, http.StatusBadRequest},
{"encoder error", &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, nil, identity.MockProvider{RefreshResponse: sessions.State{AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Minute)}}}, mock.Encoder{MarshalResponse: []byte("ok")}, mock.Encoder{MarshalError: errors.New("err")}, http.StatusInternalServerError},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
aead, err := chacha20poly1305.NewX(cryptutil.NewKey())
if err != nil {
t.Fatal(err)
}
a := Authenticate{
sharedKey: cryptutil.NewBase64Key(),
cookieSecret: cryptutil.NewKey(),
RedirectURL: uriParseHelper("https://authenticate.corp.beyondperimeter.com"),
encryptedEncoder: tt.secretEncoder,
sharedEncoder: tt.sharedEncoder,
sessionStore: tt.session,
provider: tt.provider,
cookieCipher: aead,
}
r := httptest.NewRequest("GET", "/", nil)
state, _ := tt.session.LoadSession(r)
ctx := r.Context()
ctx = sessions.NewContext(ctx, state, tt.ctxError)
r = r.WithContext(ctx)
r.Header.Set("Accept", "application/json")
w := httptest.NewRecorder()
httputil.HandlerFunc(a.Refresh).ServeHTTP(w, r)
if status := w.Code; status != tt.wantStatus {
t.Errorf("VerifySession() error = %v, wantErr %v\n%v", w.Result().StatusCode, tt.wantStatus, w.Body.String())
}
})
}
}

View file

@ -11,12 +11,13 @@ import (
"strings" "strings"
"time" "time"
"github.com/fsnotify/fsnotify"
"github.com/pomerium/pomerium/internal/cryptutil" "github.com/pomerium/pomerium/internal/cryptutil"
"github.com/pomerium/pomerium/internal/log" "github.com/pomerium/pomerium/internal/log"
"github.com/pomerium/pomerium/internal/telemetry/metrics" "github.com/pomerium/pomerium/internal/telemetry/metrics"
"github.com/pomerium/pomerium/internal/urlutil" "github.com/pomerium/pomerium/internal/urlutil"
"github.com/cespare/xxhash/v2"
"github.com/fsnotify/fsnotify"
"github.com/mitchellh/hashstructure" "github.com/mitchellh/hashstructure"
"github.com/spf13/viper" "github.com/spf13/viper"
"gopkg.in/yaml.v2" "gopkg.in/yaml.v2"
@ -477,7 +478,7 @@ type OptionsUpdater interface {
// Checksum returns the checksum of the current options struct // Checksum returns the checksum of the current options struct
func (o *Options) Checksum() string { func (o *Options) Checksum() string {
hash, err := hashstructure.Hash(o, nil) hash, err := hashstructure.Hash(o, &hashstructure.HashOptions{Hasher: xxhash.New()})
if err != nil { if err != nil {
log.Warn().Err(err).Msg("config: checksum failure") log.Warn().Err(err).Msg("config: checksum failure")
return "no checksum available" return "no checksum available"

4
go.mod
View file

@ -6,9 +6,9 @@ require (
cloud.google.com/go v0.49.0 // indirect cloud.google.com/go v0.49.0 // indirect
contrib.go.opencensus.io/exporter/jaeger v0.2.0 contrib.go.opencensus.io/exporter/jaeger v0.2.0
contrib.go.opencensus.io/exporter/prometheus v0.1.0 contrib.go.opencensus.io/exporter/prometheus v0.1.0
github.com/cespare/xxhash/v2 v2.1.1 // indirect github.com/cespare/xxhash/v2 v2.1.1
github.com/fsnotify/fsnotify v1.4.7 github.com/fsnotify/fsnotify v1.4.7
github.com/golang/groupcache v0.0.0-20191027212112-611e8accdfc9 // indirect github.com/golang/groupcache v0.0.0-20191227052852-215e87163ea7
github.com/golang/mock v1.3.1 github.com/golang/mock v1.3.1
github.com/golang/protobuf v1.3.2 github.com/golang/protobuf v1.3.2
github.com/google/go-cmp v0.3.1 github.com/google/go-cmp v0.3.1

5
go.sum
View file

@ -71,6 +71,8 @@ github.com/golang/groupcache v0.0.0-20190702054246-869f871628b6 h1:ZgQEtGgCBiWRM
github.com/golang/groupcache v0.0.0-20190702054246-869f871628b6/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= github.com/golang/groupcache v0.0.0-20190702054246-869f871628b6/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc=
github.com/golang/groupcache v0.0.0-20191027212112-611e8accdfc9 h1:uHTyIjqVhYRhLbJ8nIiOJHkEZZ+5YoOsAbD3sk82NiE= github.com/golang/groupcache v0.0.0-20191027212112-611e8accdfc9 h1:uHTyIjqVhYRhLbJ8nIiOJHkEZZ+5YoOsAbD3sk82NiE=
github.com/golang/groupcache v0.0.0-20191027212112-611e8accdfc9/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= github.com/golang/groupcache v0.0.0-20191027212112-611e8accdfc9/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc=
github.com/golang/groupcache v0.0.0-20191227052852-215e87163ea7 h1:5ZkaAPbicIKTF2I64qf5Fh8Aa83Q/dnOafMYV0OMwjA=
github.com/golang/groupcache v0.0.0-20191227052852-215e87163ea7/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc=
github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A= github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A=
github.com/golang/mock v1.2.0 h1:28o5sBqPkBsMGnC6b4MvE2TzSr5/AT4c/1fLqVGIwlk= github.com/golang/mock v1.2.0 h1:28o5sBqPkBsMGnC6b4MvE2TzSr5/AT4c/1fLqVGIwlk=
github.com/golang/mock v1.2.0/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A= github.com/golang/mock v1.2.0/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A=
@ -162,6 +164,7 @@ github.com/prometheus/client_golang v0.9.3/go.mod h1:/TN21ttK/J9q6uSwhBd54HahCDf
github.com/prometheus/client_golang v1.0.0/go.mod h1:db9x61etRT2tGnBNRi70OPL5FsnadC4Ky3P0J6CfImo= github.com/prometheus/client_golang v1.0.0/go.mod h1:db9x61etRT2tGnBNRi70OPL5FsnadC4Ky3P0J6CfImo=
github.com/prometheus/client_golang v1.2.1 h1:JnMpQc6ppsNgw9QPAGF6Dod479itz7lvlsMzzNayLOI= github.com/prometheus/client_golang v1.2.1 h1:JnMpQc6ppsNgw9QPAGF6Dod479itz7lvlsMzzNayLOI=
github.com/prometheus/client_golang v1.2.1/go.mod h1:XMU6Z2MjaRKVu/dC1qupJI9SiNkDYzz3xecMgSW/F+U= github.com/prometheus/client_golang v1.2.1/go.mod h1:XMU6Z2MjaRKVu/dC1qupJI9SiNkDYzz3xecMgSW/F+U=
github.com/prometheus/client_golang v1.3.0 h1:miYCvYqFXtl/J9FIy8eNpBfYthAEFg+Ys0XyUVEcDsc=
github.com/prometheus/client_model v0.0.0-20180712105110-5c3871d89910/go.mod h1:MbSGuTsp3dbXC40dX6PRTWyKYBIrTGTE9sqQNg2J8bo= github.com/prometheus/client_model v0.0.0-20180712105110-5c3871d89910/go.mod h1:MbSGuTsp3dbXC40dX6PRTWyKYBIrTGTE9sqQNg2J8bo=
github.com/prometheus/client_model v0.0.0-20190129233127-fd36f4220a90 h1:S/YWwWx/RA8rT8tKFRuGUZhuA90OyIBpPCXkcbwU8DE= github.com/prometheus/client_model v0.0.0-20190129233127-fd36f4220a90 h1:S/YWwWx/RA8rT8tKFRuGUZhuA90OyIBpPCXkcbwU8DE=
github.com/prometheus/client_model v0.0.0-20190129233127-fd36f4220a90/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= github.com/prometheus/client_model v0.0.0-20190129233127-fd36f4220a90/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA=
@ -214,6 +217,7 @@ github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA=
github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
github.com/spf13/viper v1.5.0 h1:GpsTwfsQ27oS/Aha/6d1oD7tpKIqWnOA6tgOX9HHkt4= github.com/spf13/viper v1.5.0 h1:GpsTwfsQ27oS/Aha/6d1oD7tpKIqWnOA6tgOX9HHkt4=
github.com/spf13/viper v1.5.0/go.mod h1:AkYRkVJF8TkSG/xet6PzXX+l39KhhXa2pdqVSxnTcn4= github.com/spf13/viper v1.5.0/go.mod h1:AkYRkVJF8TkSG/xet6PzXX+l39KhhXa2pdqVSxnTcn4=
github.com/spf13/viper v1.6.1 h1:VPZzIkznI1YhVMRi6vNFLHSwhnhReBfgTxIPccpfdZk=
github.com/spf13/viper v1.6.1/go.mod h1:t3iDnF5Jlj76alVNuyFBk5oUMCvsrkbvZK0WQdfDi5k= github.com/spf13/viper v1.6.1/go.mod h1:t3iDnF5Jlj76alVNuyFBk5oUMCvsrkbvZK0WQdfDi5k=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
@ -384,6 +388,7 @@ gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33q108Sa+fhmuc+sWQYwY= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33q108Sa+fhmuc+sWQYwY=
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI= gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI=
gopkg.in/ini.v1 v1.51.0 h1:AQvPpx3LzTDM0AjnIRlVFwFFGC+npRopjZxLJj6gdno=
gopkg.in/ini.v1 v1.51.0/go.mod h1:pNLf8WUiyNEtQjuu5G5vTm06TEv9tsIgeAvK8hOrP4k= gopkg.in/ini.v1 v1.51.0/go.mod h1:pNLf8WUiyNEtQjuu5G5vTm06TEv9tsIgeAvK8hOrP4k=
gopkg.in/resty.v1 v1.12.0/go.mod h1:mDo4pnntr5jdWRML875a/NmxYqAlA73dVijT2AXvQQo= gopkg.in/resty.v1 v1.12.0/go.mod h1:mDo4pnntr5jdWRML875a/NmxYqAlA73dVijT2AXvQQo=
gopkg.in/square/go-jose.v2 v2.4.0 h1:0kXPskUMGAXXWJlP05ktEMOV0vmzFQUWw6d+aZJQU8A= gopkg.in/square/go-jose.v2 v2.4.0 h1:0kXPskUMGAXXWJlP05ktEMOV0vmzFQUWw6d+aZJQU8A=

View file

@ -1,5 +1,13 @@
package mock // import "github.com/pomerium/pomerium/internal/encoding/mock" package mock // import "github.com/pomerium/pomerium/internal/encoding/mock"
import (
"github.com/pomerium/pomerium/internal/encoding"
)
var _ encoding.MarshalUnmarshaler = &Encoder{}
var _ encoding.Marshaler = &Encoder{}
var _ encoding.Unmarshaler = &Encoder{}
// Encoder MockCSRFStore is a mock implementation of Cipher. // Encoder MockCSRFStore is a mock implementation of Cipher.
type Encoder struct { type Encoder struct {
MarshalResponse []byte MarshalResponse []byte

View file

@ -8,23 +8,21 @@ import (
"fmt" "fmt"
"io" "io"
"io/ioutil" "io/ioutil"
"net"
"net/http" "net/http"
"net/url" "net/url"
"time" "time"
"go.opencensus.io/plugin/ochttp"
) )
// ErrTokenRevoked signifies a token revokation or expiration error // ErrTokenRevoked signifies a token revokation or expiration error
var ErrTokenRevoked = errors.New("token expired or revoked") var ErrTokenRevoked = errors.New("token expired or revoked")
var httpClient = &http.Client{ // DefaultClient avoids leaks by setting an upper limit for timeouts.
Timeout: time.Second * 5, var DefaultClient = &http.Client{
Transport: &http.Transport{ Timeout: 1 * time.Minute,
Dial: (&net.Dialer{ //todo(bdd): incorporate metrics.HTTPMetricsRoundTripper
Timeout: 2 * time.Second, Transport: &ochttp.Transport{},
}).Dial,
TLSHandshakeTimeout: 2 * time.Second,
},
} }
// Client provides a simple helper interface to make HTTP requests // Client provides a simple helper interface to make HTTP requests
@ -36,9 +34,11 @@ func Client(ctx context.Context, method, endpoint, userAgent string, headers map
case http.MethodGet: case http.MethodGet:
// error checking skipped because we are just parsing in // error checking skipped because we are just parsing in
// order to make a copy of an existing URL // order to make a copy of an existing URL
if params != nil {
u, _ := url.Parse(endpoint) u, _ := url.Parse(endpoint)
u.RawQuery = params.Encode() u.RawQuery = params.Encode()
endpoint = u.String() endpoint = u.String()
}
default: default:
return fmt.Errorf(http.StatusText(http.StatusBadRequest)) return fmt.Errorf(http.StatusText(http.StatusBadRequest))
} }
@ -52,7 +52,7 @@ func Client(ctx context.Context, method, endpoint, userAgent string, headers map
req.Header.Set(k, v) req.Header.Set(k, v)
} }
resp, err := httpClient.Do(req) resp, err := DefaultClient.Do(req)
if err != nil { if err != nil {
return err return err
} }
@ -79,7 +79,6 @@ func Client(ctx context.Context, method, endpoint, userAgent string, headers map
return fmt.Errorf(http.StatusText(resp.StatusCode)) return fmt.Errorf(http.StatusText(resp.StatusCode))
} }
} }
if response != nil { if response != nil {
err := json.Unmarshal(respBody, &response) err := json.Unmarshal(respBody, &response)
if err != nil { if err != nil {

131
internal/sessions/cache/cache_store.go vendored Normal file
View file

@ -0,0 +1,131 @@
package cache // import "github.com/pomerium/pomerium/internal/sessions/cache"
import (
"context"
"errors"
"fmt"
"net/http"
"github.com/golang/groupcache"
"github.com/pomerium/pomerium/internal/encoding"
"github.com/pomerium/pomerium/internal/log"
"github.com/pomerium/pomerium/internal/sessions"
)
var _ sessions.SessionStore = &Store{}
var _ sessions.SessionLoader = &Store{}
const (
defaultQueryParamKey = "ati"
)
// Store implements the session store interface using a distributed cache.
type Store struct {
name string
encoder encoding.Marshaler
decoder encoding.Unmarshaler
cache *groupcache.Group
wrappedStore sessions.SessionStore
}
// defaultCacheSize is ~10MB
var defaultCacheSize int64 = 10 << 20
// NewStore creates a new session store built on the distributed caching library
// groupcache. On a cache miss, the cache store attempts to fallback to another
// SessionStore implementation.
func NewStore(enc encoding.MarshalUnmarshaler, wrappedStore sessions.SessionStore, name string) *Store {
store := &Store{
name: name,
encoder: enc,
decoder: enc,
wrappedStore: wrappedStore,
}
store.cache = groupcache.NewGroup(name, defaultCacheSize, groupcache.GetterFunc(
func(ctx context.Context, id string, dest groupcache.Sink) error {
// fill the cache with session set as part of the request
// context set previously as part of SaveSession.
b := fromContext(ctx)
if len(b) == 0 {
return fmt.Errorf("sessions/cache: cannot fill key %s from ctx", id)
}
if err := dest.SetBytes(b); err != nil {
return fmt.Errorf("sessions/cache: sink error %w", err)
}
return nil
},
))
return store
}
// LoadSession implements SessionLoaders's LoadSession method for cache store.
func (s *Store) LoadSession(r *http.Request) (*sessions.State, error) {
// look for our cache's key in the default query param
sessionID := r.URL.Query().Get(defaultQueryParamKey)
if sessionID == "" {
// if unset, fallback to default cache store
log.FromRequest(r).Debug().Msg("sessions/cache: no query param, trying wrapped loader")
return s.wrappedStore.LoadSession(r)
}
var b []byte
if err := s.cache.Get(r.Context(), sessionID, groupcache.AllocatingByteSliceSink(&b)); err != nil {
log.FromRequest(r).Debug().Err(err).Msg("sessions/cache: miss, trying wrapped loader")
return s.wrappedStore.LoadSession(r)
}
var session sessions.State
if err := s.decoder.Unmarshal(b, &session); err != nil {
log.FromRequest(r).Error().Err(err).Msg("sessions/cache: unmarshal")
return nil, sessions.ErrMalformed
}
return &session, nil
}
// ClearSession implements SessionStore's ClearSession for the cache store.
// Since group cache has no explicit eviction, we just call the wrapped
// store's ClearSession method here.
func (s *Store) ClearSession(w http.ResponseWriter, r *http.Request) {
s.wrappedStore.ClearSession(w, r)
}
// SaveSession implements SessionStore's SaveSession method for cache store.
func (s *Store) SaveSession(w http.ResponseWriter, r *http.Request, x interface{}) error {
err := s.wrappedStore.SaveSession(w, r, x)
if err != nil {
return fmt.Errorf("sessions/cache: wrapped store save error %w", err)
}
state, ok := x.(*sessions.State)
if !ok {
return errors.New("internal/sessions: cannot cache non state type")
}
data, err := s.encoder.Marshal(&state)
if err != nil {
return fmt.Errorf("sessions/cache: marshal %w", err)
}
ctx := newContext(r.Context(), data)
var b []byte
return s.cache.Get(ctx, state.AccessTokenID, groupcache.AllocatingByteSliceSink(&b))
}
var sessionCtxKey = &contextKey{"PomeriumCachedSessionBytes"}
type contextKey struct {
name string
}
func newContext(ctx context.Context, b []byte) context.Context {
ctx = context.WithValue(ctx, sessionCtxKey, b)
return ctx
}
func fromContext(ctx context.Context) []byte {
b, _ := ctx.Value(sessionCtxKey).([]byte)
return b
}

View file

@ -0,0 +1,133 @@
package cache
import (
"fmt"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/google/go-cmp/cmp"
"github.com/pomerium/pomerium/internal/cryptutil"
"github.com/pomerium/pomerium/internal/encoding/ecjson"
"github.com/pomerium/pomerium/internal/sessions"
"github.com/pomerium/pomerium/internal/sessions/cookie"
"gopkg.in/square/go-jose.v2/jwt"
)
func testAuthorizer(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, err := sessions.FromContext(r.Context())
if err != nil {
http.Error(w, err.Error(), http.StatusUnauthorized)
return
}
next.ServeHTTP(w, r)
})
}
func TestVerifier(t *testing.T) {
fnh := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/plain; charset=utf-8")
fmt.Fprint(w, http.StatusText(http.StatusOK))
w.WriteHeader(http.StatusOK)
})
tests := []struct {
name string
skipSave bool
cacheSize int64
state sessions.State
wantBody string
wantStatus int
}{
{"good", false, 1 << 10, sessions.State{AccessTokenID: cryptutil.NewBase64Key(), Email: "user@pomerium.io", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}, http.StatusText(http.StatusOK), http.StatusOK},
{"expired", false, 1 << 10, sessions.State{AccessTokenID: cryptutil.NewBase64Key(), Email: "user@pomerium.io", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}, "internal/sessions: validation failed, token is expired (exp)\n", http.StatusUnauthorized},
{"empty", false, 1 << 10, sessions.State{AccessTokenID: "", Email: "user@pomerium.io", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}, "internal/sessions: session is not found\n", http.StatusUnauthorized},
{"miss", true, 1 << 10, sessions.State{AccessTokenID: cryptutil.NewBase64Key(), Email: "user@pomerium.io", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}, "internal/sessions: session is not found\n", http.StatusUnauthorized},
{"cache eviction", false, 1, sessions.State{AccessTokenID: cryptutil.NewBase64Key(), Email: "user@pomerium.io", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}, "internal/sessions: session is not found\n", http.StatusUnauthorized},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
defaultCacheSize = tt.cacheSize
cipher, err := cryptutil.NewAEADCipherFromBase64(cryptutil.NewBase64Key())
encoder := ecjson.New(cipher)
if err != nil {
t.Fatal(err)
}
cs, err := cookie.NewStore(&cookie.Options{Name: t.Name()}, encoder)
if err != nil {
t.Fatal(err)
}
cacheStore := NewStore(encoder, cs, t.Name())
r := httptest.NewRequest(http.MethodGet, "/", nil)
q := r.URL.Query()
q.Set(defaultQueryParamKey, tt.state.AccessTokenID)
r.URL.RawQuery = q.Encode()
r.Header.Set("Accept", "application/json")
w := httptest.NewRecorder()
got := sessions.RetrieveSession(cacheStore)(testAuthorizer((fnh)))
if !tt.skipSave {
cacheStore.SaveSession(w, r, &tt.state)
}
for i := 1; i <= 10; i++ {
s := tt.state
s.AccessTokenID = cryptutil.NewBase64Key()
cacheStore.SaveSession(w, r, s)
}
got.ServeHTTP(w, r)
gotBody := w.Body.String()
gotStatus := w.Result().StatusCode
if diff := cmp.Diff(gotBody, tt.wantBody); diff != "" {
t.Errorf("RetrieveSession() = %v", diff)
}
if diff := cmp.Diff(gotStatus, tt.wantStatus); diff != "" {
t.Errorf("RetrieveSession() = %v", diff)
}
})
}
}
func TestStore_SaveSession(t *testing.T) {
tests := []struct {
name string
x interface{}
wantErr bool
}{
{"good", &sessions.State{AccessTokenID: cryptutil.NewBase64Key(), Email: "user@pomerium.io", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}, false},
{"bad type", "bad type!", true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
cipher, err := cryptutil.NewAEADCipherFromBase64(cryptutil.NewBase64Key())
encoder := ecjson.New(cipher)
if err != nil {
t.Fatal(err)
}
cs, err := cookie.NewStore(&cookie.Options{
Name: "_pomerium",
}, encoder)
if err != nil {
t.Fatal(err)
}
cacheStore := NewStore(encoder, cs, t.Name())
r := httptest.NewRequest(http.MethodGet, "/", nil)
r.Header.Set("Accept", "application/json")
w := httptest.NewRecorder()
if err := cacheStore.SaveSession(w, r, tt.x); (err != nil) != tt.wantErr {
t.Errorf("Store.SaveSession() error = %v, wantErr %v", err, tt.wantErr)
}
})
}
}

View file

@ -1,4 +1,4 @@
package sessions // import "github.com/pomerium/pomerium/internal/sessions" package cookie // import "github.com/pomerium/pomerium/internal/sessions/cookie"
import ( import (
"errors" "errors"
@ -8,8 +8,15 @@ import (
"time" "time"
"github.com/pomerium/pomerium/internal/encoding" "github.com/pomerium/pomerium/internal/encoding"
"github.com/pomerium/pomerium/internal/sessions"
) )
var _ sessions.SessionStore = &Store{}
var _ sessions.SessionLoader = &Store{}
// timeNow is time.Now but pulled out as a variable for tests.
var timeNow = time.Now
const ( const (
// ChunkedCanaryByte is the byte value used as a canary prefix to distinguish if // ChunkedCanaryByte is the byte value used as a canary prefix to distinguish if
// the cookie is multi-part or not. This constant *should not* be valid // the cookie is multi-part or not. This constant *should not* be valid
@ -25,8 +32,8 @@ const (
MaxNumChunks = 5 MaxNumChunks = 5
) )
// CookieStore implements the session store interface for session cookies. // Store implements the session store interface for session cookies.
type CookieStore struct { type Store struct {
Name string Name string
Domain string Domain string
Expire time.Duration Expire time.Duration
@ -37,8 +44,8 @@ type CookieStore struct {
decoder encoding.Unmarshaler decoder encoding.Unmarshaler
} }
// CookieOptions holds options for CookieStore // Options holds options for Store
type CookieOptions struct { type Options struct {
Name string Name string
Domain string Domain string
Expire time.Duration Expire time.Duration
@ -46,8 +53,9 @@ type CookieOptions struct {
Secure bool Secure bool
} }
// NewCookieStore returns a new session with ciphers for each of the cookie secrets // NewStore returns a new store that implements the SessionStore interface
func NewCookieStore(opts *CookieOptions, encoder encoding.MarshalUnmarshaler) (*CookieStore, error) { // using http cookies.
func NewStore(opts *Options, encoder encoding.MarshalUnmarshaler) (sessions.SessionStore, error) {
cs, err := NewCookieLoader(opts, encoder) cs, err := NewCookieLoader(opts, encoder)
if err != nil { if err != nil {
return nil, err return nil, err
@ -56,12 +64,13 @@ func NewCookieStore(opts *CookieOptions, encoder encoding.MarshalUnmarshaler) (*
return cs, nil return cs, nil
} }
// NewCookieLoader returns a new session with ciphers for each of the cookie secrets // NewCookieLoader returns a new store that implements the SessionLoader
func NewCookieLoader(opts *CookieOptions, dencoder encoding.Unmarshaler) (*CookieStore, error) { // interface using http cookies.
func NewCookieLoader(opts *Options, dencoder encoding.Unmarshaler) (*Store, error) {
if dencoder == nil { if dencoder == nil {
return nil, fmt.Errorf("internal/sessions: dencoder cannot be nil") return nil, fmt.Errorf("internal/sessions: dencoder cannot be nil")
} }
cs, err := newCookieStore(opts) cs, err := newStore(opts)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -69,12 +78,12 @@ func NewCookieLoader(opts *CookieOptions, dencoder encoding.Unmarshaler) (*Cooki
return cs, nil return cs, nil
} }
func newCookieStore(opts *CookieOptions) (*CookieStore, error) { func newStore(opts *Options) (*Store, error) {
if opts.Name == "" { if opts.Name == "" {
return nil, fmt.Errorf("internal/sessions: cookie name cannot be empty") return nil, fmt.Errorf("internal/sessions: cookie name cannot be empty")
} }
return &CookieStore{ return &Store{
Name: opts.Name, Name: opts.Name,
Secure: opts.Secure, Secure: opts.Secure,
HTTPOnly: opts.HTTPOnly, HTTPOnly: opts.HTTPOnly,
@ -83,7 +92,7 @@ func newCookieStore(opts *CookieOptions) (*CookieStore, error) {
}, nil }, nil
} }
func (cs *CookieStore) makeCookie(value string) *http.Cookie { func (cs *Store) makeCookie(value string) *http.Cookie {
return &http.Cookie{ return &http.Cookie{
Name: cs.Name, Name: cs.Name,
Value: value, Value: value,
@ -96,7 +105,7 @@ func (cs *CookieStore) makeCookie(value string) *http.Cookie {
} }
// ClearSession clears the session cookie from a request // ClearSession clears the session cookie from a request
func (cs *CookieStore) ClearSession(w http.ResponseWriter, r *http.Request) { func (cs *Store) ClearSession(w http.ResponseWriter, r *http.Request) {
c := cs.makeCookie("") c := cs.makeCookie("")
c.MaxAge = -1 c.MaxAge = -1
c.Expires = timeNow().Add(-time.Hour) c.Expires = timeNow().Add(-time.Hour)
@ -115,51 +124,51 @@ func getCookies(r *http.Request, name string) []*http.Cookie {
} }
// LoadSession returns a State from the cookie in the request. // LoadSession returns a State from the cookie in the request.
func (cs *CookieStore) LoadSession(r *http.Request) (*State, error) { func (cs *Store) LoadSession(r *http.Request) (*sessions.State, error) {
cookies := getCookies(r, cs.Name) cookies := getCookies(r, cs.Name)
if len(cookies) == 0 { if len(cookies) == 0 {
return nil, ErrNoSessionFound return nil, sessions.ErrNoSessionFound
} }
for _, cookie := range cookies { for _, cookie := range cookies {
data := loadChunkedCookie(r, cookie) data := loadChunkedCookie(r, cookie)
session := &State{} session := &sessions.State{}
err := cs.decoder.Unmarshal([]byte(data), session) err := cs.decoder.Unmarshal([]byte(data), session)
if err == nil { if err == nil {
return session, nil return session, nil
} }
} }
return nil, ErrMalformed return nil, sessions.ErrMalformed
} }
// SaveSession saves a session state to a request's cookie store. // SaveSession saves a session state to a request's cookie store.
func (cs *CookieStore) SaveSession(w http.ResponseWriter, _ *http.Request, x interface{}) error { func (cs *Store) SaveSession(w http.ResponseWriter, _ *http.Request, x interface{}) error {
var value string var value string
if cs.encoder != nil {
data, err := cs.encoder.Marshal(x)
if err != nil {
return err
}
value = string(data)
} else {
switch v := x.(type) { switch v := x.(type) {
case []byte: case []byte:
value = string(v) value = string(v)
case string: case string:
value = v value = v
default: default:
if cs.encoder == nil {
return errors.New("internal/sessions: cannot save non-string type") return errors.New("internal/sessions: cannot save non-string type")
} }
data, err := cs.encoder.Marshal(x)
if err != nil {
return err
} }
value = string(data)
}
cs.setSessionCookie(w, value) cs.setSessionCookie(w, value)
return nil return nil
} }
func (cs *CookieStore) setSessionCookie(w http.ResponseWriter, val string) { func (cs *Store) setSessionCookie(w http.ResponseWriter, val string) {
cs.setCookie(w, cs.makeCookie(val)) cs.setCookie(w, cs.makeCookie(val))
} }
func (cs *CookieStore) setCookie(w http.ResponseWriter, cookie *http.Cookie) { func (cs *Store) setCookie(w http.ResponseWriter, cookie *http.Cookie) {
if len(cookie.String()) <= MaxChunkSize { if len(cookie.String()) <= MaxChunkSize {
http.SetCookie(w, cookie) http.SetCookie(w, cookie)
return return
@ -180,9 +189,15 @@ func (cs *CookieStore) setCookie(w http.ResponseWriter, cookie *http.Cookie) {
} }
func loadChunkedCookie(r *http.Request, c *http.Cookie) string { func loadChunkedCookie(r *http.Request, c *http.Cookie) string {
data := c.Value if len(c.Value) == 0 {
return ""
}
// if the first byte is our canary byte, we need to handle the multipart bit // if the first byte is our canary byte, we need to handle the multipart bit
if []byte(c.Value)[0] == ChunkedCanaryByte { if []byte(c.Value)[0] != ChunkedCanaryByte {
return c.Value
}
data := c.Value
var b strings.Builder var b strings.Builder
fmt.Fprintf(&b, "%s", data[1:]) fmt.Fprintf(&b, "%s", data[1:])
for i := 1; i <= MaxNumChunks; i++ { for i := 1; i <= MaxNumChunks; i++ {
@ -193,7 +208,7 @@ func loadChunkedCookie(r *http.Request, c *http.Cookie) string {
fmt.Fprintf(&b, "%s", next.Value) fmt.Fprintf(&b, "%s", next.Value)
} }
data = b.String() data = b.String()
}
return data return data
} }

View file

@ -1,4 +1,4 @@
package sessions // import "github.com/pomerium/pomerium/internal/sessions" package cookie // import "github.com/pomerium/pomerium/internal/sessions/cookie"
import ( import (
"crypto/rand" "crypto/rand"
@ -13,12 +13,13 @@ import (
"github.com/pomerium/pomerium/internal/encoding" "github.com/pomerium/pomerium/internal/encoding"
"github.com/pomerium/pomerium/internal/encoding/ecjson" "github.com/pomerium/pomerium/internal/encoding/ecjson"
"github.com/pomerium/pomerium/internal/encoding/mock" "github.com/pomerium/pomerium/internal/encoding/mock"
"github.com/pomerium/pomerium/internal/sessions"
"github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts" "github.com/google/go-cmp/cmp/cmpopts"
) )
func TestNewCookieStore(t *testing.T) { func TestNewStore(t *testing.T) {
cipher, err := cryptutil.NewAEADCipher(cryptutil.NewKey()) cipher, err := cryptutil.NewAEADCipher(cryptutil.NewKey())
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
@ -26,28 +27,28 @@ func TestNewCookieStore(t *testing.T) {
encoder := ecjson.New(cipher) encoder := ecjson.New(cipher)
tests := []struct { tests := []struct {
name string name string
opts *CookieOptions opts *Options
encoder encoding.MarshalUnmarshaler encoder encoding.MarshalUnmarshaler
want *CookieStore want sessions.SessionStore
wantErr bool wantErr bool
}{ }{
{"good", &CookieOptions{Name: "_cookie", Secure: true, HTTPOnly: true, Domain: "pomerium.io", Expire: 10 * time.Second}, encoder, &CookieStore{Name: "_cookie", Secure: true, HTTPOnly: true, Domain: "pomerium.io", Expire: 10 * time.Second}, false}, {"good", &Options{Name: "_cookie", Secure: true, HTTPOnly: true, Domain: "pomerium.io", Expire: 10 * time.Second}, encoder, &Store{Name: "_cookie", Secure: true, HTTPOnly: true, Domain: "pomerium.io", Expire: 10 * time.Second}, false},
{"missing name", &CookieOptions{Name: "", Secure: true, HTTPOnly: true, Domain: "pomerium.io", Expire: 10 * time.Second}, encoder, nil, true}, {"missing name", &Options{Name: "", Secure: true, HTTPOnly: true, Domain: "pomerium.io", Expire: 10 * time.Second}, encoder, nil, true},
{"missing encoder", &CookieOptions{Name: "_cookie", Secure: true, HTTPOnly: true, Domain: "pomerium.io", Expire: 10 * time.Second}, nil, nil, true}, {"missing encoder", &Options{Name: "_cookie", Secure: true, HTTPOnly: true, Domain: "pomerium.io", Expire: 10 * time.Second}, nil, nil, true},
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
got, err := NewCookieStore(tt.opts, tt.encoder) got, err := NewStore(tt.opts, tt.encoder)
if (err != nil) != tt.wantErr { if (err != nil) != tt.wantErr {
t.Errorf("NewCookieStore() error = %v, wantErr %v", err, tt.wantErr) t.Errorf("NewStore() error = %v, wantErr %v", err, tt.wantErr)
return return
} }
cmpOpts := []cmp.Option{ cmpOpts := []cmp.Option{
cmpopts.IgnoreUnexported(CookieStore{}), cmpopts.IgnoreUnexported(Store{}),
} }
if diff := cmp.Diff(got, tt.want, cmpOpts...); diff != "" { if diff := cmp.Diff(got, tt.want, cmpOpts...); diff != "" {
t.Errorf("NewCookieStore() = %s", diff) t.Errorf("NewStore() = %s", diff)
} }
}) })
} }
@ -60,14 +61,14 @@ func TestNewCookieLoader(t *testing.T) {
encoder := ecjson.New(cipher) encoder := ecjson.New(cipher)
tests := []struct { tests := []struct {
name string name string
opts *CookieOptions opts *Options
encoder encoding.MarshalUnmarshaler encoder encoding.MarshalUnmarshaler
want *CookieStore want *Store
wantErr bool wantErr bool
}{ }{
{"good", &CookieOptions{Name: "_cookie", Secure: true, HTTPOnly: true, Domain: "pomerium.io", Expire: 10 * time.Second}, encoder, &CookieStore{Name: "_cookie", Secure: true, HTTPOnly: true, Domain: "pomerium.io", Expire: 10 * time.Second}, false}, {"good", &Options{Name: "_cookie", Secure: true, HTTPOnly: true, Domain: "pomerium.io", Expire: 10 * time.Second}, encoder, &Store{Name: "_cookie", Secure: true, HTTPOnly: true, Domain: "pomerium.io", Expire: 10 * time.Second}, false},
{"missing name", &CookieOptions{Name: "", Secure: true, HTTPOnly: true, Domain: "pomerium.io", Expire: 10 * time.Second}, encoder, nil, true}, {"missing name", &Options{Name: "", Secure: true, HTTPOnly: true, Domain: "pomerium.io", Expire: 10 * time.Second}, encoder, nil, true},
{"missing encoder", &CookieOptions{Name: "_cookie", Secure: true, HTTPOnly: true, Domain: "pomerium.io", Expire: 10 * time.Second}, nil, nil, true}, {"missing encoder", &Options{Name: "_cookie", Secure: true, HTTPOnly: true, Domain: "pomerium.io", Expire: 10 * time.Second}, nil, nil, true},
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
@ -77,7 +78,7 @@ func TestNewCookieLoader(t *testing.T) {
return return
} }
cmpOpts := []cmp.Option{ cmpOpts := []cmp.Option{
cmpopts.IgnoreUnexported(CookieStore{}), cmpopts.IgnoreUnexported(Store{}),
} }
if diff := cmp.Diff(got, tt.want, cmpOpts...); diff != "" { if diff := cmp.Diff(got, tt.want, cmpOpts...); diff != "" {
@ -87,7 +88,7 @@ func TestNewCookieLoader(t *testing.T) {
} }
} }
func TestCookieStore_SaveSession(t *testing.T) { func TestStore_SaveSession(t *testing.T) {
c, err := cryptutil.NewAEADCipher(cryptutil.NewKey()) c, err := cryptutil.NewAEADCipher(cryptutil.NewKey())
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
@ -106,17 +107,17 @@ func TestCookieStore_SaveSession(t *testing.T) {
wantErr bool wantErr bool
wantLoadErr bool wantLoadErr bool
}{ }{
{"good", &State{Email: "user@domain.com", User: "user"}, ecjson.New(c), ecjson.New(c), false, false}, {"good", &sessions.State{Email: "user@domain.com", User: "user"}, ecjson.New(c), ecjson.New(c), false, false},
{"bad cipher", &State{Email: "user@domain.com", User: "user"}, nil, nil, true, true}, {"bad cipher", &sessions.State{Email: "user@domain.com", User: "user"}, nil, nil, true, true},
{"huge cookie", &State{Subject: fmt.Sprintf("%x", hugeString), Email: "user@domain.com", User: "user"}, ecjson.New(c), ecjson.New(c), false, false}, {"huge cookie", &sessions.State{Subject: fmt.Sprintf("%x", hugeString), Email: "user@domain.com", User: "user"}, ecjson.New(c), ecjson.New(c), false, false},
{"marshal error", &State{Email: "user@domain.com", User: "user"}, mock.Encoder{MarshalError: errors.New("error")}, ecjson.New(c), true, true}, {"marshal error", &sessions.State{Email: "user@domain.com", User: "user"}, mock.Encoder{MarshalError: errors.New("error")}, ecjson.New(c), true, true},
{"nil encoder cannot save non string type", &State{Email: "user@domain.com", User: "user"}, nil, ecjson.New(c), true, true}, {"nil encoder cannot save non string type", &sessions.State{Email: "user@domain.com", User: "user"}, nil, ecjson.New(c), true, true},
{"good marshal string directly", cryptutil.NewBase64Key(), nil, ecjson.New(c), false, true}, {"good marshal string directly", cryptutil.NewBase64Key(), nil, ecjson.New(c), false, true},
{"good marshal bytes directly", cryptutil.NewKey(), nil, ecjson.New(c), false, true}, {"good marshal bytes directly", cryptutil.NewKey(), nil, ecjson.New(c), false, true},
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
s := &CookieStore{ s := &Store{
Name: "_pomerium", Name: "_pomerium",
Secure: true, Secure: true,
HTTPOnly: true, HTTPOnly: true,
@ -130,7 +131,7 @@ func TestCookieStore_SaveSession(t *testing.T) {
w := httptest.NewRecorder() w := httptest.NewRecorder()
if err := s.SaveSession(w, r, tt.State); (err != nil) != tt.wantErr { if err := s.SaveSession(w, r, tt.State); (err != nil) != tt.wantErr {
t.Errorf("CookieStore.SaveSession() error = %v, wantErr %v", err, tt.wantErr) t.Errorf("Store.SaveSession() error = %v, wantErr %v", err, tt.wantErr)
} }
r = httptest.NewRequest("GET", "/", nil) r = httptest.NewRequest("GET", "/", nil)
for _, cookie := range w.Result().Cookies() { for _, cookie := range w.Result().Cookies() {
@ -143,11 +144,11 @@ func TestCookieStore_SaveSession(t *testing.T) {
return return
} }
cmpOpts := []cmp.Option{ cmpOpts := []cmp.Option{
cmpopts.IgnoreUnexported(State{}), cmpopts.IgnoreUnexported(sessions.State{}),
} }
if err == nil { if err == nil {
if diff := cmp.Diff(state, tt.State, cmpOpts...); diff != "" { if diff := cmp.Diff(state, tt.State, cmpOpts...); diff != "" {
t.Errorf("CookieStore.LoadSession() got = %s", diff) t.Errorf("Store.LoadSession() got = %s", diff)
} }
} }
w = httptest.NewRecorder() w = httptest.NewRecorder()

View file

@ -0,0 +1,90 @@
package cookie // import "github.com/pomerium/pomerium/internal/sessions/cookie"
import (
"fmt"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
"github.com/pomerium/pomerium/internal/sessions"
"github.com/google/go-cmp/cmp"
"github.com/pomerium/pomerium/internal/cryptutil"
"github.com/pomerium/pomerium/internal/encoding/ecjson"
"gopkg.in/square/go-jose.v2/jwt"
)
func testAuthorizer(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, err := sessions.FromContext(r.Context())
if err != nil {
http.Error(w, err.Error(), http.StatusUnauthorized)
return
}
next.ServeHTTP(w, r)
})
}
func TestVerifier(t *testing.T) {
fnh := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/plain; charset=utf-8")
fmt.Fprint(w, http.StatusText(http.StatusOK))
w.WriteHeader(http.StatusOK)
})
tests := []struct {
name string
state sessions.State
wantBody string
wantStatus int
}{
{"good cookie session", sessions.State{Email: "user@pomerium.io", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}, http.StatusText(http.StatusOK), http.StatusOK},
{"expired cookie", sessions.State{Email: "user@pomerium.io", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}, "internal/sessions: validation failed, token is expired (exp)\n", http.StatusUnauthorized},
{"malformed cookie", sessions.State{Email: "user@pomerium.io", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}, "internal/sessions: session is malformed\n", http.StatusUnauthorized},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
cipher, err := cryptutil.NewAEADCipherFromBase64(cryptutil.NewBase64Key())
encoder := ecjson.New(cipher)
if err != nil {
t.Fatal(err)
}
encSession, err := encoder.Marshal(&tt.state)
if err != nil {
t.Fatal(err)
}
if strings.Contains(tt.name, "malformed") {
// add some garbage to the end of the string
encSession = append(encSession, cryptutil.NewKey()...)
}
cs, err := NewStore(&Options{
Name: "_pomerium",
}, encoder)
if err != nil {
t.Fatal(err)
}
r := httptest.NewRequest(http.MethodGet, "/", nil)
r.Header.Set("Accept", "application/json")
w := httptest.NewRecorder()
r.AddCookie(&http.Cookie{Name: "_pomerium", Value: string(encSession)})
got := sessions.RetrieveSession(cs)(testAuthorizer((fnh)))
got.ServeHTTP(w, r)
gotBody := w.Body.String()
gotStatus := w.Result().StatusCode
if diff := cmp.Diff(gotBody, tt.wantBody); diff != "" {
t.Errorf("RetrieveSession() = %v", diff)
}
if diff := cmp.Diff(gotStatus, tt.wantStatus); diff != "" {
t.Errorf("RetrieveSession() = %v", diff)
}
})
}
}

View file

@ -0,0 +1,28 @@
package sessions // import "github.com/pomerium/pomerium/internal/sessions"
import (
"errors"
)
var (
// ErrNoSessionFound is the error for when no session is found.
ErrNoSessionFound = errors.New("internal/sessions: session is not found")
// ErrMalformed is the error for when a session is found but is malformed.
ErrMalformed = errors.New("internal/sessions: session is malformed")
// ErrNotValidYet indicates that token is used before time indicated in nbf claim.
ErrNotValidYet = errors.New("internal/sessions: validation failed, token not valid yet (nbf)")
// ErrExpired indicates that token is used after expiry time indicated in exp claim.
ErrExpired = errors.New("internal/sessions: validation failed, token is expired (exp)")
// ErrExpiryRequired indicates that the token does not contain a valid expiry (exp) claim.
ErrExpiryRequired = errors.New("internal/sessions: validation failed, token expiry (exp) is required")
// ErrIssuedInTheFuture indicates that the iat field is in the future.
ErrIssuedInTheFuture = errors.New("internal/sessions: validation field, token issued in the future (iat)")
// ErrInvalidAudience indicated invalid aud claim.
ErrInvalidAudience = errors.New("internal/sessions: validation failed, invalid audience claim (aud)")
)

View file

@ -1,35 +1,38 @@
package sessions // import "github.com/pomerium/pomerium/internal/sessions" package header // import "github.com/pomerium/pomerium/internal/sessions/header"
import ( import (
"net/http" "net/http"
"strings" "strings"
"github.com/pomerium/pomerium/internal/encoding" "github.com/pomerium/pomerium/internal/encoding"
"github.com/pomerium/pomerium/internal/sessions"
) )
var _ sessions.SessionLoader = &Store{}
const ( const (
defaultAuthHeader = "Authorization" defaultAuthHeader = "Authorization"
defaultAuthType = "Bearer" defaultAuthType = "Bearer"
) )
// HeaderStore implements the load session store interface using http // Store implements the load session store interface using http
// authorization headers. // authorization headers.
type HeaderStore struct { type Store struct {
authHeader string authHeader string
authType string authType string
encoder encoding.Unmarshaler encoder encoding.Unmarshaler
} }
// NewHeaderStore returns a new header store for loading sessions from // NewStore returns a new header store for loading sessions from
// authorization header as defined in as defined in rfc2617 // authorization header as defined in as defined in rfc2617
// //
// NOTA BENE: While most servers do not log Authorization headers by default, // NOTA BENE: While most servers do not log Authorization headers by default,
// you should ensure no other services are logging or leaking your auth headers. // you should ensure no other services are logging or leaking your auth headers.
func NewHeaderStore(enc encoding.Unmarshaler, headerType string) *HeaderStore { func NewStore(enc encoding.Unmarshaler, headerType string) *Store {
if headerType == "" { if headerType == "" {
headerType = defaultAuthType headerType = defaultAuthType
} }
return &HeaderStore{ return &Store{
authHeader: defaultAuthHeader, authHeader: defaultAuthHeader,
authType: headerType, authType: headerType,
encoder: enc, encoder: enc,
@ -37,14 +40,14 @@ func NewHeaderStore(enc encoding.Unmarshaler, headerType string) *HeaderStore {
} }
// LoadSession tries to retrieve the token string from the Authorization header. // LoadSession tries to retrieve the token string from the Authorization header.
func (as *HeaderStore) LoadSession(r *http.Request) (*State, error) { func (as *Store) LoadSession(r *http.Request) (*sessions.State, error) {
cipherText := TokenFromHeader(r, as.authHeader, as.authType) cipherText := TokenFromHeader(r, as.authHeader, as.authType)
if cipherText == "" { if cipherText == "" {
return nil, ErrNoSessionFound return nil, sessions.ErrNoSessionFound
} }
var session State var session sessions.State
if err := as.encoder.Unmarshal([]byte(cipherText), &session); err != nil { if err := as.encoder.Unmarshal([]byte(cipherText), &session); err != nil {
return nil, ErrMalformed return nil, sessions.ErrMalformed
} }
return &session, nil return &session, nil
} }

View file

@ -0,0 +1,90 @@
package header // import "github.com/pomerium/pomerium/internal/sessions/header"
import (
"fmt"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
"github.com/pomerium/pomerium/internal/cryptutil"
"github.com/pomerium/pomerium/internal/encoding/ecjson"
"github.com/pomerium/pomerium/internal/sessions"
"github.com/google/go-cmp/cmp"
"gopkg.in/square/go-jose.v2/jwt"
)
func testAuthorizer(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, err := sessions.FromContext(r.Context())
if err != nil {
http.Error(w, err.Error(), http.StatusUnauthorized)
return
}
next.ServeHTTP(w, r)
})
}
func TestVerifier(t *testing.T) {
fnh := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/plain; charset=utf-8")
fmt.Fprint(w, http.StatusText(http.StatusOK))
w.WriteHeader(http.StatusOK)
})
tests := []struct {
name string
authType string
state sessions.State
wantBody string
wantStatus int
}{
{"good auth header session", "Bearer ", sessions.State{Email: "user@pomerium.io", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}, http.StatusText(http.StatusOK), http.StatusOK},
{"expired auth header", "Bearer ", sessions.State{Email: "user@pomerium.io", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}, "internal/sessions: validation failed, token is expired (exp)\n", http.StatusUnauthorized},
{"malformed auth header", "Bearer ", sessions.State{Email: "user@pomerium.io", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}, "internal/sessions: session is malformed\n", http.StatusUnauthorized},
{"empty auth header", "Bearer ", sessions.State{Email: "user@pomerium.io", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}, "internal/sessions: session is not found\n", http.StatusUnauthorized},
{"bad auth type", "bees ", sessions.State{Email: "user@pomerium.io", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}, "internal/sessions: session is not found\n", http.StatusUnauthorized},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
cipher, err := cryptutil.NewAEADCipherFromBase64(cryptutil.NewBase64Key())
encoder := ecjson.New(cipher)
if err != nil {
t.Fatal(err)
}
encSession, err := encoder.Marshal(&tt.state)
if err != nil {
t.Fatal(err)
}
if strings.Contains(tt.name, "malformed") {
// add some garbage to the end of the string
encSession = append(encSession, cryptutil.NewKey()...)
}
s := NewStore(encoder, "")
r := httptest.NewRequest(http.MethodGet, "/", nil)
r.Header.Set("Accept", "application/json")
w := httptest.NewRecorder()
if strings.Contains(tt.name, "empty") {
encSession = []byte("")
}
r.Header.Set("Authorization", tt.authType+string(encSession))
got := sessions.RetrieveSession(s)(testAuthorizer((fnh)))
got.ServeHTTP(w, r)
gotBody := w.Body.String()
gotStatus := w.Result().StatusCode
if diff := cmp.Diff(gotBody, tt.wantBody); diff != "" {
t.Errorf("RetrieveSession() = %v", diff)
}
if diff := cmp.Diff(gotStatus, tt.wantStatus); diff != "" {
t.Errorf("RetrieveSession() = %v", diff)
}
})
}
}

View file

@ -44,7 +44,7 @@ func retrieveFromRequest(r *http.Request, sessions ...SessionLoader) (*State, er
} }
if state != nil { if state != nil {
err := state.Verify(urlutil.StripPort(r.Host)) err := state.Verify(urlutil.StripPort(r.Host))
return state, err // N.B.: state is _not_ nil_ return state, err // N.B.: state is _not_ nil
} }
} }

View file

@ -2,16 +2,14 @@ package sessions
import ( import (
"context" "context"
"errors"
"fmt" "fmt"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"strings"
"testing" "testing"
"time" "time"
"github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp"
"github.com/pomerium/pomerium/internal/cryptutil"
"github.com/pomerium/pomerium/internal/encoding/ecjson"
"gopkg.in/square/go-jose.v2/jwt" "gopkg.in/square/go-jose.v2/jwt"
) )
@ -39,103 +37,6 @@ func TestNewContext(t *testing.T) {
} }
} }
func testAuthorizer(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, err := FromContext(r.Context())
if err != nil {
http.Error(w, err.Error(), http.StatusUnauthorized)
return
}
next.ServeHTTP(w, r)
})
}
func TestVerifier(t *testing.T) {
fnh := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/plain; charset=utf-8")
fmt.Fprint(w, http.StatusText(http.StatusOK))
w.WriteHeader(http.StatusOK)
})
tests := []struct {
name string
// s SessionStore
state State
cookie bool
header bool
param bool
wantBody string
wantStatus int
}{
{"good cookie session", State{Email: "user@pomerium.io", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}, true, false, false, http.StatusText(http.StatusOK), http.StatusOK},
{"expired cookie", State{Email: "user@pomerium.io", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}, true, false, false, "internal/sessions: validation failed, token is expired (exp)\n", http.StatusUnauthorized},
{"malformed cookie", State{Email: "user@pomerium.io", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}, true, false, false, "internal/sessions: session is malformed\n", http.StatusUnauthorized},
{"good auth header session", State{Email: "user@pomerium.io", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}, false, true, false, http.StatusText(http.StatusOK), http.StatusOK},
{"expired auth header", State{Email: "user@pomerium.io", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}, false, true, false, "internal/sessions: validation failed, token is expired (exp)\n", http.StatusUnauthorized},
{"malformed auth header", State{Email: "user@pomerium.io", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}, false, true, false, "internal/sessions: session is malformed\n", http.StatusUnauthorized},
{"good auth query param session", State{Email: "user@pomerium.io", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}, false, true, true, http.StatusText(http.StatusOK), http.StatusOK},
{"expired auth query param", State{Email: "user@pomerium.io", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}, false, false, true, "internal/sessions: validation failed, token is expired (exp)\n", http.StatusUnauthorized},
{"malformed auth query param", State{Email: "user@pomerium.io", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}, false, false, true, "internal/sessions: session is malformed\n", http.StatusUnauthorized},
{"no session", State{Email: "user@pomerium.io", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}, false, false, false, "internal/sessions: session is not found\n", http.StatusUnauthorized},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
cipher, err := cryptutil.NewAEADCipherFromBase64(cryptutil.NewBase64Key())
encoder := ecjson.New(cipher)
if err != nil {
t.Fatal(err)
}
encSession, err := encoder.Marshal(&tt.state)
if err != nil {
t.Fatal(err)
}
if strings.Contains(tt.name, "malformed") {
// add some garbage to the end of the string
encSession = append(encSession, cryptutil.NewKey()...)
}
cs, err := NewCookieStore(&CookieOptions{
Name: "_pomerium",
}, encoder)
if err != nil {
t.Fatal(err)
}
as := NewHeaderStore(encoder, "")
qp := NewQueryParamStore(encoder, "")
r := httptest.NewRequest(http.MethodGet, "/", nil)
r.Header.Set("Accept", "application/json")
w := httptest.NewRecorder()
if tt.cookie {
r.AddCookie(&http.Cookie{Name: "_pomerium", Value: string(encSession)})
} else if tt.header {
r.Header.Set("Authorization", "Bearer "+string(encSession))
} else if tt.param {
q := r.URL.Query()
q.Set("pomerium_session", string(encSession))
r.URL.RawQuery = q.Encode()
}
got := RetrieveSession(cs, as, qp)(testAuthorizer((fnh)))
got.ServeHTTP(w, r)
gotBody := w.Body.String()
gotStatus := w.Result().StatusCode
if diff := cmp.Diff(gotBody, tt.wantBody); diff != "" {
t.Errorf("RetrieveSession() = %v", diff)
}
if diff := cmp.Diff(gotStatus, tt.wantStatus); diff != "" {
t.Errorf("RetrieveSession() = %v", diff)
}
})
}
}
func Test_contextKey_String(t *testing.T) { func Test_contextKey_String(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
@ -155,3 +56,80 @@ func Test_contextKey_String(t *testing.T) {
}) })
} }
} }
func testAuthorizer(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, err := FromContext(r.Context())
if err != nil {
http.Error(w, err.Error(), http.StatusUnauthorized)
return
}
next.ServeHTTP(w, r)
})
}
var _ SessionStore = &store{}
// Store is a mock implementation of the SessionStore interface
type store struct {
ResponseSession string
Session *State
SaveError error
LoadError error
}
// ClearSession clears the ResponseSession
func (ms *store) ClearSession(http.ResponseWriter, *http.Request) {
ms.ResponseSession = ""
}
// LoadSession returns the session and a error
func (ms store) LoadSession(*http.Request) (*State, error) {
return ms.Session, ms.LoadError
}
// SaveSession returns a save error.
func (ms store) SaveSession(http.ResponseWriter, *http.Request, interface{}) error {
return ms.SaveError
}
func TestVerifier(t *testing.T) {
fnh := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/plain; charset=utf-8")
fmt.Fprint(w, http.StatusText(http.StatusOK))
w.WriteHeader(http.StatusOK)
})
tests := []struct {
name string
store store
state State
wantBody string
wantStatus int
}{
{"empty session", store{}, State{}, "internal/sessions: session is not found\n", 401},
{"simple good load", store{Session: &State{Subject: "hi", Expiry: jwt.NewNumericDate(time.Now().Add(time.Second))}}, State{}, "OK", 200},
{"empty session", store{LoadError: errors.New("err")}, State{}, "err\n", 401},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
r := httptest.NewRequest(http.MethodGet, "/", nil)
r.Header.Set("Accept", "application/json")
w := httptest.NewRecorder()
got := RetrieveSession(tt.store)(testAuthorizer((fnh)))
got.ServeHTTP(w, r)
gotBody := w.Body.String()
gotStatus := w.Result().StatusCode
if diff := cmp.Diff(gotBody, tt.wantBody); diff != "" {
t.Errorf("RetrieveSession() = %v", diff)
}
if diff := cmp.Diff(gotStatus, tt.wantStatus); diff != "" {
t.Errorf("RetrieveSession() = %v", diff)
}
})
}
}

View file

@ -0,0 +1,33 @@
package mock // import "github.com/pomerium/pomerium/internal/sessions/mock"
import (
"net/http"
"github.com/pomerium/pomerium/internal/sessions"
)
var _ sessions.SessionStore = &Store{}
var _ sessions.SessionLoader = &Store{}
// Store is a mock implementation of the SessionStore interface
type Store struct {
ResponseSession string
Session *sessions.State
SaveError error
LoadError error
}
// ClearSession clears the ResponseSession
func (ms *Store) ClearSession(http.ResponseWriter, *http.Request) {
ms.ResponseSession = ""
}
// LoadSession returns the session and a error
func (ms Store) LoadSession(*http.Request) (*sessions.State, error) {
return ms.Session, ms.LoadError
}
// SaveSession returns a save error.
func (ms Store) SaveSession(http.ResponseWriter, *http.Request, interface{}) error {
return ms.SaveError
}

View file

@ -1,26 +1,28 @@
package sessions // import "github.com/pomerium/pomerium/internal/sessions" package mock // import "github.com/pomerium/pomerium/internal/sessions/mock"
import ( import (
"reflect" "reflect"
"testing" "testing"
"github.com/pomerium/pomerium/internal/sessions"
) )
func TestMockSessionStore(t *testing.T) { func TestStore(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
mockCSRF *MockSessionStore mockCSRF *Store
saveSession *State saveSession *sessions.State
wantLoadErr bool wantLoadErr bool
wantSaveErr bool wantSaveErr bool
}{ }{
{"basic", {"basic",
&MockSessionStore{ &Store{
ResponseSession: "test", ResponseSession: "test",
Session: &State{Subject: "0101"}, Session: &sessions.State{Subject: "0101"},
SaveError: nil, SaveError: nil,
LoadError: nil, LoadError: nil,
}, },
&State{Subject: "0101"}, &sessions.State{Subject: "0101"},
false, false,
false}, false},
} }

View file

@ -1,28 +0,0 @@
package sessions // import "github.com/pomerium/pomerium/internal/sessions"
import (
"net/http"
)
// MockSessionStore is a mock implementation of the SessionStore interface
type MockSessionStore struct {
ResponseSession string
Session *State
SaveError error
LoadError error
}
// ClearSession clears the ResponseSession
func (ms *MockSessionStore) ClearSession(http.ResponseWriter, *http.Request) {
ms.ResponseSession = ""
}
// LoadSession returns the session and a error
func (ms MockSessionStore) LoadSession(*http.Request) (*State, error) {
return ms.Session, ms.LoadError
}
// SaveSession returns a save error.
func (ms MockSessionStore) SaveSession(http.ResponseWriter, *http.Request, interface{}) error {
return ms.SaveError
}

View file

@ -0,0 +1,92 @@
package queryparam // import "github.com/pomerium/pomerium/internal/sessions/queryparam"
import (
"fmt"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
"github.com/pomerium/pomerium/internal/cryptutil"
"github.com/pomerium/pomerium/internal/encoding/ecjson"
"github.com/pomerium/pomerium/internal/sessions"
"github.com/google/go-cmp/cmp"
"gopkg.in/square/go-jose.v2/jwt"
)
func testAuthorizer(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, err := sessions.FromContext(r.Context())
if err != nil {
http.Error(w, err.Error(), http.StatusUnauthorized)
return
}
next.ServeHTTP(w, r)
})
}
func TestVerifier(t *testing.T) {
fnh := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/plain; charset=utf-8")
fmt.Fprint(w, http.StatusText(http.StatusOK))
w.WriteHeader(http.StatusOK)
})
tests := []struct {
name string
state sessions.State
wantBody string
wantStatus int
}{
{"good auth query param session", sessions.State{Email: "user@pomerium.io", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}, http.StatusText(http.StatusOK), http.StatusOK},
{"expired auth query param", sessions.State{Email: "user@pomerium.io", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}, "internal/sessions: validation failed, token is expired (exp)\n", http.StatusUnauthorized},
{"malformed auth query param", sessions.State{Email: "user@pomerium.io", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}, "internal/sessions: session is malformed\n", http.StatusUnauthorized},
{"empty auth query param", sessions.State{Email: "user@pomerium.io", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}, "internal/sessions: session is not found\n", http.StatusUnauthorized},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
cipher, err := cryptutil.NewAEADCipherFromBase64(cryptutil.NewBase64Key())
encoder := ecjson.New(cipher)
if err != nil {
t.Fatal(err)
}
encSession, err := encoder.Marshal(&tt.state)
if err != nil {
t.Fatal(err)
}
if strings.Contains(tt.name, "malformed") {
// add some garbage to the end of the string
encSession = append(encSession, cryptutil.NewKey()...)
}
s := NewStore(encoder, "")
r := httptest.NewRequest(http.MethodGet, "/", nil)
r.Header.Set("Accept", "application/json")
w := httptest.NewRecorder()
q := r.URL.Query()
if strings.Contains(tt.name, "empty") {
encSession = []byte("")
}
q.Set("pomerium_session", string(encSession))
r.URL.RawQuery = q.Encode()
got := sessions.RetrieveSession(s)(testAuthorizer((fnh)))
got.ServeHTTP(w, r)
gotBody := w.Body.String()
gotStatus := w.Result().StatusCode
if diff := cmp.Diff(gotBody, tt.wantBody); diff != "" {
t.Errorf("RetrieveSession() = %v", diff)
}
if diff := cmp.Diff(gotStatus, tt.wantStatus); diff != "" {
t.Errorf("RetrieveSession() = %v", diff)
}
})
}
}

View file

@ -1,33 +1,37 @@
package sessions // import "github.com/pomerium/pomerium/internal/sessions" package queryparam // import "github.com/pomerium/pomerium/internal/sessions/queryparam"
import ( import (
"net/http" "net/http"
"github.com/pomerium/pomerium/internal/encoding" "github.com/pomerium/pomerium/internal/encoding"
"github.com/pomerium/pomerium/internal/sessions"
) )
var _ sessions.SessionStore = &Store{}
var _ sessions.SessionLoader = &Store{}
const ( const (
defaultQueryParamKey = "pomerium_session" defaultQueryParamKey = "pomerium_session"
) )
// QueryParamStore implements the load session store interface using http // Store implements the load session store interface using http
// query strings / query parameters. // query strings / query parameters.
type QueryParamStore struct { type Store struct {
queryParamKey string queryParamKey string
encoder encoding.Marshaler encoder encoding.Marshaler
decoder encoding.Unmarshaler decoder encoding.Unmarshaler
} }
// NewQueryParamStore returns a new query param store for loading sessions from // NewStore returns a new query param store for loading sessions from
// query strings / query parameters. // query strings / query parameters.
// //
// NOTA BENE: By default, most servers _DO_ log query params, the leaking or // NOTA BENE: By default, most servers _DO_ log query params, the leaking or
// accidental logging of which should be considered a security issue. // accidental logging of which should be considered a security issue.
func NewQueryParamStore(enc encoding.MarshalUnmarshaler, qp string) *QueryParamStore { func NewStore(enc encoding.MarshalUnmarshaler, qp string) *Store {
if qp == "" { if qp == "" {
qp = defaultQueryParamKey qp = defaultQueryParamKey
} }
return &QueryParamStore{ return &Store{
queryParamKey: qp, queryParamKey: qp,
encoder: enc, encoder: enc,
decoder: enc, decoder: enc,
@ -35,27 +39,27 @@ func NewQueryParamStore(enc encoding.MarshalUnmarshaler, qp string) *QueryParamS
} }
// LoadSession tries to retrieve the token string from URL query parameters. // LoadSession tries to retrieve the token string from URL query parameters.
func (qp *QueryParamStore) LoadSession(r *http.Request) (*State, error) { func (qp *Store) LoadSession(r *http.Request) (*sessions.State, error) {
cipherText := r.URL.Query().Get(qp.queryParamKey) cipherText := r.URL.Query().Get(qp.queryParamKey)
if cipherText == "" { if cipherText == "" {
return nil, ErrNoSessionFound return nil, sessions.ErrNoSessionFound
} }
var session State var session sessions.State
if err := qp.decoder.Unmarshal([]byte(cipherText), &session); err != nil { if err := qp.decoder.Unmarshal([]byte(cipherText), &session); err != nil {
return nil, ErrMalformed return nil, sessions.ErrMalformed
} }
return &session, nil return &session, nil
} }
// ClearSession clears the session cookie from a request's query param key `pomerium_session`. // ClearSession clears the session cookie from a request's query param key `pomerium_session`.
func (qp *QueryParamStore) ClearSession(w http.ResponseWriter, r *http.Request) { func (qp *Store) ClearSession(w http.ResponseWriter, r *http.Request) {
params := r.URL.Query() params := r.URL.Query()
params.Del(qp.queryParamKey) params.Del(qp.queryParamKey)
r.URL.RawQuery = params.Encode() r.URL.RawQuery = params.Encode()
} }
// SaveSession sets a session to a request's query param key `pomerium_session` // SaveSession sets a session to a request's query param key `pomerium_session`
func (qp *QueryParamStore) SaveSession(w http.ResponseWriter, r *http.Request, x interface{}) error { func (qp *Store) SaveSession(w http.ResponseWriter, r *http.Request, x interface{}) error {
data, err := qp.encoder.Marshal(x) data, err := qp.encoder.Marshal(x)
if err != nil { if err != nil {
return err return err

View file

@ -1,4 +1,4 @@
package sessions package queryparam // import "github.com/pomerium/pomerium/internal/sessions/queryparam"
import ( import (
"errors" "errors"
@ -9,39 +9,40 @@ import (
"github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp"
"github.com/pomerium/pomerium/internal/encoding" "github.com/pomerium/pomerium/internal/encoding"
"github.com/pomerium/pomerium/internal/encoding/mock" "github.com/pomerium/pomerium/internal/encoding/mock"
"github.com/pomerium/pomerium/internal/sessions"
) )
func TestNewQueryParamStore(t *testing.T) { func TestNewQueryParamStore(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
State *State State *sessions.State
enc encoding.MarshalUnmarshaler enc encoding.MarshalUnmarshaler
qp string qp string
wantErr bool wantErr bool
wantURL *url.URL wantURL *url.URL
}{ }{
{"simple good", &State{Email: "user@domain.com", User: "user"}, mock.Encoder{MarshalResponse: []byte("ok")}, "", false, &url.URL{Path: "/", RawQuery: "pomerium_session=ok"}}, {"simple good", &sessions.State{Email: "user@domain.com", User: "user"}, mock.Encoder{MarshalResponse: []byte("ok")}, "", false, &url.URL{Path: "/", RawQuery: "pomerium_session=ok"}},
{"marshall error", &State{Email: "user@domain.com", User: "user"}, mock.Encoder{MarshalError: errors.New("error")}, "", true, &url.URL{Path: "/"}}, {"marshall error", &sessions.State{Email: "user@domain.com", User: "user"}, mock.Encoder{MarshalError: errors.New("error")}, "", true, &url.URL{Path: "/"}},
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
got := NewQueryParamStore(tt.enc, tt.qp) got := NewStore(tt.enc, tt.qp)
r := httptest.NewRequest("GET", "/", nil) r := httptest.NewRequest("GET", "/", nil)
w := httptest.NewRecorder() w := httptest.NewRecorder()
if err := got.SaveSession(w, r, tt.State); (err != nil) != tt.wantErr { if err := got.SaveSession(w, r, tt.State); (err != nil) != tt.wantErr {
t.Errorf("NewQueryParamStore.SaveSession() error = %v, wantErr %v", err, tt.wantErr) t.Errorf("NewStore.SaveSession() error = %v, wantErr %v", err, tt.wantErr)
} }
if diff := cmp.Diff(r.URL, tt.wantURL); diff != "" { if diff := cmp.Diff(r.URL, tt.wantURL); diff != "" {
t.Errorf("NewQueryParamStore() = %v", diff) t.Errorf("NewStore() = %v", diff)
} }
got.ClearSession(w, r) got.ClearSession(w, r)
if diff := cmp.Diff(r.URL, &url.URL{Path: "/"}); diff != "" { if diff := cmp.Diff(r.URL, &url.URL{Path: "/"}); diff != "" {
t.Errorf("NewQueryParamStore() = %v", diff) t.Errorf("NewStore() = %v", diff)
} }
}) })
} }

View file

@ -6,6 +6,8 @@ import (
"strings" "strings"
"time" "time"
"github.com/cespare/xxhash/v2"
"github.com/mitchellh/hashstructure"
oidc "github.com/pomerium/go-oidc" oidc "github.com/pomerium/go-oidc"
"golang.org/x/oauth2" "golang.org/x/oauth2"
"gopkg.in/square/go-jose.v2/jwt" "gopkg.in/square/go-jose.v2/jwt"
@ -51,7 +53,8 @@ type State struct {
// programatic access. // programatic access.
Programmatic bool `json:"programatic"` Programmatic bool `json:"programatic"`
AccessToken *oauth2.Token `json:"access_token,omitempty"` AccessToken *oauth2.Token `json:"act,omitempty"`
AccessTokenID string `json:"ati,omitempty"`
idToken *oidc.IDToken idToken *oidc.IDToken
} }
@ -73,7 +76,7 @@ func NewStateFromTokens(idToken *oidc.IDToken, accessToken *oauth2.Token, audien
s.Audience = []string{audience} s.Audience = []string{audience}
s.idToken = idToken s.idToken = idToken
s.AccessToken = accessToken s.AccessToken = accessToken
s.AccessTokenID = s.accessTokenHash()
return s, nil return s, nil
} }
@ -95,6 +98,7 @@ func (s *State) UpdateState(idToken *oidc.IDToken, accessToken *oauth2.Token) er
} }
s.Audience = audience s.Audience = audience
s.Expiry = jwt.NewNumericDate(accessToken.Expiry) s.Expiry = jwt.NewNumericDate(accessToken.Expiry)
s.AccessTokenID = s.accessTokenHash()
return nil return nil
} }
@ -173,3 +177,13 @@ func (s *State) SetImpersonation(email, groups string) {
s.ImpersonateGroups = strings.Split(groups, ",") s.ImpersonateGroups = strings.Split(groups, ",")
} }
} }
func (s *State) accessTokenHash() string {
hash, err := hashstructure.Hash(
s.AccessToken,
&hashstructure.HashOptions{Hasher: xxhash.New()})
if err != nil {
return ""
}
return fmt.Sprintf("%x", hash)
}

View file

@ -124,3 +124,26 @@ func TestState_RouteSession(t *testing.T) {
}) })
} }
} }
func TestState_accessTokenHash(t *testing.T) {
t.Parallel()
tests := []struct {
name string
state State
want string
}{
{"empty access token", State{}, "34c96acdcadb1bbb"},
{"no change to access token", State{Subject: "test"}, "34c96acdcadb1bbb"},
{"empty oauth2 token", State{AccessToken: &oauth2.Token{}}, "bbd82197d215198f"},
{"refresh token a", State{AccessToken: &oauth2.Token{RefreshToken: "a"}}, "76316ac79b301bd6"},
{"refresh token b", State{AccessToken: &oauth2.Token{RefreshToken: "b"}}, "fab7cb29e50161f1"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
s := &tt.state
if got := s.accessTokenHash(); got != tt.want {
t.Errorf("State.accessTokenHash() = %v, want %v", got, tt.want)
}
})
}
}

View file

@ -1,39 +1,17 @@
package sessions // import "github.com/pomerium/pomerium/internal/sessions" package sessions // import "github.com/pomerium/pomerium/internal/sessions"
import ( import (
"errors"
"net/http" "net/http"
) )
var ( // SessionStore defines an interface for loading, saving, and clearing a session.
// ErrNoSessionFound is the error for when no session is found.
ErrNoSessionFound = errors.New("internal/sessions: session is not found")
// ErrMalformed is the error for when a session is found but is malformed.
ErrMalformed = errors.New("internal/sessions: session is malformed")
// ErrNotValidYet indicates that token is used before time indicated in nbf claim.
ErrNotValidYet = errors.New("internal/sessions: validation failed, token not valid yet (nbf)")
// ErrExpired indicates that token is used after expiry time indicated in exp claim.
ErrExpired = errors.New("internal/sessions: validation failed, token is expired (exp)")
// ErrIssuedInTheFuture indicates that the iat field is in the future.
ErrIssuedInTheFuture = errors.New("internal/sessions: validation field, token issued in the future (iat)")
// ErrInvalidAudience indicated invalid aud claim.
ErrInvalidAudience = errors.New("internal/sessions: validation failed, invalid audience claim (aud)")
)
// SessionStore has the functions for setting, getting, and clearing the Session cookie
type SessionStore interface { type SessionStore interface {
ClearSession(http.ResponseWriter, *http.Request)
SessionLoader SessionLoader
ClearSession(http.ResponseWriter, *http.Request)
SaveSession(http.ResponseWriter, *http.Request, interface{}) error SaveSession(http.ResponseWriter, *http.Request, interface{}) error
} }
// SessionLoader is implemented by any struct that loads a pomerium session // SessionLoader defines an interface for loading a session.
// given a request, and returns a user state.
type SessionLoader interface { type SessionLoader interface {
LoadSession(*http.Request) (*State, error) LoadSession(*http.Request) (*State, error)
} }

View file

@ -34,10 +34,10 @@ var (
// DefaultViews are a set of default views to view HTTP and GRPC metrics. // DefaultViews are a set of default views to view HTTP and GRPC metrics.
var ( var (
DefaultViews = [][]*view.View{ DefaultViews = [][]*view.View{
GRPCServerViews,
HTTPServerViews,
GRPCClientViews, GRPCClientViews,
GRPCServerViews, GRPCServerViews,
HTTPClientViews,
HTTPServerViews,
InfoViews, InfoViews,
} }
) )

View file

@ -9,14 +9,16 @@ import (
"time" "time"
"github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp"
"gopkg.in/square/go-jose.v2/jwt"
"github.com/pomerium/pomerium/config" "github.com/pomerium/pomerium/config"
"github.com/pomerium/pomerium/internal/encoding" "github.com/pomerium/pomerium/internal/encoding"
"github.com/pomerium/pomerium/internal/encoding/mock" "github.com/pomerium/pomerium/internal/encoding/mock"
"github.com/pomerium/pomerium/internal/httputil" "github.com/pomerium/pomerium/internal/httputil"
"github.com/pomerium/pomerium/internal/sessions" "github.com/pomerium/pomerium/internal/sessions"
mstore "github.com/pomerium/pomerium/internal/sessions/mock"
"github.com/pomerium/pomerium/internal/urlutil" "github.com/pomerium/pomerium/internal/urlutil"
"github.com/pomerium/pomerium/proxy/clients" "github.com/pomerium/pomerium/proxy/clients"
"gopkg.in/square/go-jose.v2/jwt"
) )
func TestProxy_ForwardAuth(t *testing.T) { func TestProxy_ForwardAuth(t *testing.T) {
@ -40,29 +42,29 @@ func TestProxy_ForwardAuth(t *testing.T) {
wantStatus int wantStatus int
wantBody string wantBody string
}{ }{
{"good redirect not required", opts, nil, http.MethodGet, nil, nil, "https://some.domain.example/", "https://some.domain.example", &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusOK, "Access to some.domain.example is allowed."}, {"good redirect not required", opts, nil, http.MethodGet, nil, nil, "https://some.domain.example/", "https://some.domain.example", &mock.Encoder{}, &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusOK, "Access to some.domain.example is allowed."},
{"good verify only, no redirect", opts, nil, http.MethodGet, nil, nil, "https://some.domain.example/verify", "https://some.domain.example", &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusOK, ""}, {"good verify only, no redirect", opts, nil, http.MethodGet, nil, nil, "https://some.domain.example/verify", "https://some.domain.example", &mock.Encoder{}, &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusOK, ""},
{"good redirect not required", opts, nil, http.MethodGet, nil, nil, "/", "https://some.domain.example", &mock.Encoder{}, &sessions.MockSessionStore{LoadError: sessions.ErrInvalidAudience}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusUnauthorized, "{\"Status\":401,\"Error\":\"Unauthorized: internal/sessions: validation failed, invalid audience claim (aud)\"}\n"}, {"good redirect not required", opts, nil, http.MethodGet, nil, nil, "/", "https://some.domain.example", &mock.Encoder{}, &mstore.Store{LoadError: sessions.ErrInvalidAudience}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusUnauthorized, "{\"Status\":401,\"Error\":\"Unauthorized: internal/sessions: validation failed, invalid audience claim (aud)\"}\n"},
{"bad naked domain uri", opts, nil, http.MethodGet, nil, nil, "https://some.domain.example/", "a.naked.domain", &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusBadRequest, "{\"Status\":400,\"Error\":\"Bad Request: a.naked.domain url does contain a valid scheme\"}\n"}, {"bad naked domain uri", opts, nil, http.MethodGet, nil, nil, "https://some.domain.example/", "a.naked.domain", &mock.Encoder{}, &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusBadRequest, "{\"Status\":400,\"Error\":\"Bad Request: a.naked.domain url does contain a valid scheme\"}\n"},
{"bad naked domain uri verify only", opts, nil, http.MethodGet, nil, nil, "https://some.domain.example/verify", "a.naked.domain", &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusBadRequest, "{\"Status\":400,\"Error\":\"Bad Request: a.naked.domain url does contain a valid scheme\"}\n"}, {"bad naked domain uri verify only", opts, nil, http.MethodGet, nil, nil, "https://some.domain.example/verify", "a.naked.domain", &mock.Encoder{}, &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusBadRequest, "{\"Status\":400,\"Error\":\"Bad Request: a.naked.domain url does contain a valid scheme\"}\n"},
{"bad empty verification uri", opts, nil, http.MethodGet, nil, nil, "https://some.domain.example/", " ", &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusBadRequest, "{\"Status\":400,\"Error\":\"Bad Request: %20 url does contain a valid scheme\"}\n"}, {"bad empty verification uri", opts, nil, http.MethodGet, nil, nil, "https://some.domain.example/", " ", &mock.Encoder{}, &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusBadRequest, "{\"Status\":400,\"Error\":\"Bad Request: %20 url does contain a valid scheme\"}\n"},
{"bad empty verification uri verify only", opts, nil, http.MethodGet, nil, nil, "https://some.domain.example/verify", " ", &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusBadRequest, "{\"Status\":400,\"Error\":\"Bad Request: %20 url does contain a valid scheme\"}\n"}, {"bad empty verification uri verify only", opts, nil, http.MethodGet, nil, nil, "https://some.domain.example/verify", " ", &mock.Encoder{}, &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusBadRequest, "{\"Status\":400,\"Error\":\"Bad Request: %20 url does contain a valid scheme\"}\n"},
{"not authorized", opts, nil, http.MethodGet, nil, nil, "https://some.domain.example/", "https://some.domain.example", &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: false}, http.StatusUnauthorized, "{\"Status\":401,\"Error\":\"Unauthorized: user@test.example is not authorized for some.domain.example\"}\n"}, {"not authorized", opts, nil, http.MethodGet, nil, nil, "https://some.domain.example/", "https://some.domain.example", &mock.Encoder{}, &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: false}, http.StatusUnauthorized, "{\"Status\":401,\"Error\":\"Unauthorized: user@test.example is not authorized for some.domain.example\"}\n"},
{"not authorized verify endpoint", opts, nil, http.MethodGet, nil, nil, "https://some.domain.example/verify", "https://some.domain.example", &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: false}, http.StatusUnauthorized, "{\"Status\":401,\"Error\":\"Unauthorized: user@test.example is not authorized for some.domain.example\"}\n"}, {"not authorized verify endpoint", opts, nil, http.MethodGet, nil, nil, "https://some.domain.example/verify", "https://some.domain.example", &mock.Encoder{}, &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: false}, http.StatusUnauthorized, "{\"Status\":401,\"Error\":\"Unauthorized: user@test.example is not authorized for some.domain.example\"}\n"},
{"not authorized expired, redirect to auth", opts, sessions.ErrExpired, http.MethodGet, nil, nil, "https://some.domain.example/", "https://some.domain.example", &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: false}, http.StatusFound, ""}, {"not authorized expired, redirect to auth", opts, sessions.ErrExpired, http.MethodGet, nil, nil, "https://some.domain.example/", "https://some.domain.example", &mock.Encoder{}, &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: false}, http.StatusFound, ""},
{"not authorized expired, don't redirect!", opts, sessions.ErrExpired, http.MethodGet, nil, nil, "https://some.domain.example/verify", "https://some.domain.example", &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: false}, http.StatusUnauthorized, "{\"Status\":401,\"Error\":\"Unauthorized: internal/sessions: validation failed, token is expired (exp)\"}\n"}, {"not authorized expired, don't redirect!", opts, sessions.ErrExpired, http.MethodGet, nil, nil, "https://some.domain.example/verify", "https://some.domain.example", &mock.Encoder{}, &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: false}, http.StatusUnauthorized, "{\"Status\":401,\"Error\":\"Unauthorized: internal/sessions: validation failed, token is expired (exp)\"}\n"},
{"not authorized because of error", opts, nil, http.MethodGet, nil, nil, "https://some.domain.example/", "https://some.domain.example", &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeError: errors.New("authz error")}, http.StatusInternalServerError, "{\"Status\":500,\"Error\":\"Internal Server Error: authz error\"}\n"}, {"not authorized because of error", opts, nil, http.MethodGet, nil, nil, "https://some.domain.example/", "https://some.domain.example", &mock.Encoder{}, &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeError: errors.New("authz error")}, http.StatusInternalServerError, "{\"Status\":500,\"Error\":\"Internal Server Error: authz error\"}\n"},
{"not authorized expired, do not redirect to auth", opts, nil, http.MethodGet, nil, nil, "https://some.domain.example/verify", "https://some.domain.example", &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: false}, http.StatusUnauthorized, "{\"Status\":401,\"Error\":\"Unauthorized: internal/sessions: validation failed, token is expired (exp)\"}\n"}, {"not authorized expired, do not redirect to auth", opts, nil, http.MethodGet, nil, nil, "https://some.domain.example/verify", "https://some.domain.example", &mock.Encoder{}, &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: false}, http.StatusUnauthorized, "{\"Status\":401,\"Error\":\"Unauthorized: internal/sessions: validation failed, token is expired (exp)\"}\n"},
{"not authorized, bad audience request uri", opts, nil, http.MethodGet, nil, nil, "https://some.domain.example/", "https://some.domain.example", &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Audience: []string{"not.domain.example"}, Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusUnauthorized, "{\"Status\":401,\"Error\":\"Unauthorized: internal/sessions: validation failed, invalid audience claim (aud)\"}\n"}, {"not authorized, bad audience request uri", opts, nil, http.MethodGet, nil, nil, "https://some.domain.example/", "https://some.domain.example", &mock.Encoder{}, &mstore.Store{Session: &sessions.State{Audience: []string{"not.domain.example"}, Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusUnauthorized, "{\"Status\":401,\"Error\":\"Unauthorized: internal/sessions: validation failed, invalid audience claim (aud)\"}\n"},
{"not authorized, bad audience verify uri", opts, nil, http.MethodGet, nil, nil, "https://some.domain.example/", "https://fwdauth.domain.example", &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Audience: []string{"some.domain.example"}, Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusUnauthorized, "{\"Status\":401,\"Error\":\"Unauthorized: internal/sessions: validation failed, invalid audience claim (aud)\"}\n"}, {"not authorized, bad audience verify uri", opts, nil, http.MethodGet, nil, nil, "https://some.domain.example/", "https://fwdauth.domain.example", &mock.Encoder{}, &mstore.Store{Session: &sessions.State{Audience: []string{"some.domain.example"}, Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusUnauthorized, "{\"Status\":401,\"Error\":\"Unauthorized: internal/sessions: validation failed, invalid audience claim (aud)\"}\n"},
// traefik // traefik
{"good traefik callback", opts, nil, http.MethodGet, map[string]string{httputil.HeaderForwardedURI: "https://some.domain.example?" + urlutil.QuerySessionEncrypted + "=" + goodEncryptionString}, nil, "https://some.domain.example/", "https://some.domain.example", &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusFound, ""}, {"good traefik callback", opts, nil, http.MethodGet, map[string]string{httputil.HeaderForwardedURI: "https://some.domain.example?" + urlutil.QuerySessionEncrypted + "=" + goodEncryptionString}, nil, "https://some.domain.example/", "https://some.domain.example", &mock.Encoder{}, &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusFound, ""},
{"bad traefik callback bad session", opts, nil, http.MethodGet, map[string]string{httputil.HeaderForwardedURI: "https://some.domain.example?" + urlutil.QuerySessionEncrypted + "=" + goodEncryptionString + "garbage"}, nil, "https://some.domain.example/", "https://some.domain.example", &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusBadRequest, ""}, {"bad traefik callback bad session", opts, nil, http.MethodGet, map[string]string{httputil.HeaderForwardedURI: "https://some.domain.example?" + urlutil.QuerySessionEncrypted + "=" + goodEncryptionString + "garbage"}, nil, "https://some.domain.example/", "https://some.domain.example", &mock.Encoder{}, &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusBadRequest, ""},
{"bad traefik callback bad url", opts, nil, http.MethodGet, map[string]string{httputil.HeaderForwardedURI: urlutil.QuerySessionEncrypted + ""}, nil, "https://some.domain.example/", "https://some.domain.example", &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusBadRequest, ""}, {"bad traefik callback bad url", opts, nil, http.MethodGet, map[string]string{httputil.HeaderForwardedURI: urlutil.QuerySessionEncrypted + ""}, nil, "https://some.domain.example/", "https://some.domain.example", &mock.Encoder{}, &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusBadRequest, ""},
// nginx // nginx
{"good nginx callback redirect", opts, nil, http.MethodGet, nil, map[string]string{urlutil.QueryRedirectURI: "https://some.domain.example/", urlutil.QuerySessionEncrypted: goodEncryptionString}, "https://some.domain.example/", "https://some.domain.example", &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusFound, ""}, {"good nginx callback redirect", opts, nil, http.MethodGet, nil, map[string]string{urlutil.QueryRedirectURI: "https://some.domain.example/", urlutil.QuerySessionEncrypted: goodEncryptionString}, "https://some.domain.example/", "https://some.domain.example", &mock.Encoder{}, &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusFound, ""},
{"good nginx callback set session okay but return unauthorized", opts, nil, http.MethodGet, nil, map[string]string{urlutil.QueryRedirectURI: "https://some.domain.example/", urlutil.QuerySessionEncrypted: goodEncryptionString}, "https://some.domain.example/verify", "https://some.domain.example", &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusUnauthorized, ""}, {"good nginx callback set session okay but return unauthorized", opts, nil, http.MethodGet, nil, map[string]string{urlutil.QueryRedirectURI: "https://some.domain.example/", urlutil.QuerySessionEncrypted: goodEncryptionString}, "https://some.domain.example/verify", "https://some.domain.example", &mock.Encoder{}, &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusUnauthorized, ""},
{"bad nginx callback failed to set sesion", opts, nil, http.MethodGet, nil, map[string]string{urlutil.QueryRedirectURI: "https://some.domain.example/", urlutil.QuerySessionEncrypted: goodEncryptionString + "nope"}, "https://some.domain.example/verify", "https://some.domain.example", &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusBadRequest, ""}, {"bad nginx callback failed to set sesion", opts, nil, http.MethodGet, nil, map[string]string{urlutil.QueryRedirectURI: "https://some.domain.example/", urlutil.QuerySessionEncrypted: goodEncryptionString + "nope"}, "https://some.domain.example/verify", "https://some.domain.example", &mock.Encoder{}, &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusBadRequest, ""},
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {

View file

@ -12,6 +12,7 @@ import (
"time" "time"
"github.com/pomerium/pomerium/internal/cryptutil" "github.com/pomerium/pomerium/internal/cryptutil"
mstore "github.com/pomerium/pomerium/internal/sessions/mock"
"github.com/pomerium/pomerium/config" "github.com/pomerium/pomerium/config"
"github.com/pomerium/pomerium/internal/encoding" "github.com/pomerium/pomerium/internal/encoding"
@ -78,10 +79,10 @@ func TestProxy_UserDashboard(t *testing.T) {
wantAdminForm bool wantAdminForm bool
wantStatus int wantStatus int
}{ }{
{"good", nil, opts, http.MethodGet, &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{}, false, http.StatusOK}, {"good", nil, opts, http.MethodGet, &mock.Encoder{}, &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{}, false, http.StatusOK},
{"session context error", errors.New("error"), opts, http.MethodGet, &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{}, false, http.StatusInternalServerError}, {"session context error", errors.New("error"), opts, http.MethodGet, &mock.Encoder{}, &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{}, false, http.StatusInternalServerError},
{"want admin form good admin authorization", nil, opts, http.MethodGet, &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{IsAdminResponse: true}, true, http.StatusOK}, {"want admin form good admin authorization", nil, opts, http.MethodGet, &mock.Encoder{}, &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{IsAdminResponse: true}, true, http.StatusOK},
{"is admin but authorization fails", nil, opts, http.MethodGet, &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{IsAdminError: errors.New("err")}, false, http.StatusInternalServerError}, {"is admin but authorization fails", nil, opts, http.MethodGet, &mock.Encoder{}, &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{IsAdminError: errors.New("err")}, false, http.StatusInternalServerError},
} }
for _, tt := range tests { for _, tt := range tests {
@ -135,12 +136,12 @@ func TestProxy_Impersonate(t *testing.T) {
authorizer clients.Authorizer authorizer clients.Authorizer
wantStatus int wantStatus int
}{ }{
{"good", false, opts, nil, http.MethodPost, "user@blah.com", "", "", &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example"}}, clients.MockAuthorize{IsAdminResponse: true}, http.StatusFound}, {"good", false, opts, nil, http.MethodPost, "user@blah.com", "", "", &mock.Encoder{}, &mstore.Store{Session: &sessions.State{Email: "user@test.example"}}, clients.MockAuthorize{IsAdminResponse: true}, http.StatusFound},
{"good", false, opts, errors.New("error"), http.MethodPost, "user@blah.com", "", "", &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example"}}, clients.MockAuthorize{IsAdminResponse: true}, http.StatusInternalServerError}, {"good", false, opts, errors.New("error"), http.MethodPost, "user@blah.com", "", "", &mock.Encoder{}, &mstore.Store{Session: &sessions.State{Email: "user@test.example"}}, clients.MockAuthorize{IsAdminResponse: true}, http.StatusInternalServerError},
{"session load error", false, opts, nil, http.MethodPost, "user@blah.com", "", "", &mock.Encoder{}, &sessions.MockSessionStore{LoadError: errors.New("err"), Session: &sessions.State{Email: "user@test.example"}}, clients.MockAuthorize{IsAdminResponse: true}, http.StatusFound}, {"session load error", false, opts, nil, http.MethodPost, "user@blah.com", "", "", &mock.Encoder{}, &mstore.Store{LoadError: errors.New("err"), Session: &sessions.State{Email: "user@test.example"}}, clients.MockAuthorize{IsAdminResponse: true}, http.StatusFound},
{"non admin users rejected", false, opts, nil, http.MethodPost, "user@blah.com", "", "", &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute)), Email: "user@test.example"}}, clients.MockAuthorize{IsAdminResponse: false}, http.StatusForbidden}, {"non admin users rejected", false, opts, nil, http.MethodPost, "user@blah.com", "", "", &mock.Encoder{}, &mstore.Store{Session: &sessions.State{Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute)), Email: "user@test.example"}}, clients.MockAuthorize{IsAdminResponse: false}, http.StatusForbidden},
{"non admin users rejected on error", false, opts, nil, http.MethodPost, "user@blah.com", "", "", &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute)), Email: "user@test.example"}}, clients.MockAuthorize{IsAdminResponse: true, IsAdminError: errors.New("err")}, http.StatusInternalServerError}, {"non admin users rejected on error", false, opts, nil, http.MethodPost, "user@blah.com", "", "", &mock.Encoder{}, &mstore.Store{Session: &sessions.State{Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute)), Email: "user@test.example"}}, clients.MockAuthorize{IsAdminResponse: true, IsAdminError: errors.New("err")}, http.StatusInternalServerError},
{"groups", false, opts, nil, http.MethodPost, "user@blah.com", "group1,group2", "", &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute)), Email: "user@test.example"}}, clients.MockAuthorize{IsAdminResponse: true}, http.StatusFound}, {"groups", false, opts, nil, http.MethodPost, "user@blah.com", "group1,group2", "", &mock.Encoder{}, &mstore.Store{Session: &sessions.State{Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute)), Email: "user@test.example"}}, clients.MockAuthorize{IsAdminResponse: true}, http.StatusFound},
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
@ -245,12 +246,12 @@ func TestProxy_Callback(t *testing.T) {
wantStatus int wantStatus int
wantBody string wantBody string
}{ }{
{"good", opts, http.MethodGet, "http", "example.com", "/", nil, map[string]string{urlutil.QueryCallbackURI: "ok", urlutil.QuerySessionEncrypted: goodEncryptionString}, &mock.Encoder{MarshalResponse: []byte("x")}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusFound, ""}, {"good", opts, http.MethodGet, "http", "example.com", "/", nil, map[string]string{urlutil.QueryCallbackURI: "ok", urlutil.QuerySessionEncrypted: goodEncryptionString}, &mock.Encoder{MarshalResponse: []byte("x")}, &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusFound, ""},
{"good programmatic", opts, http.MethodGet, "http", "example.com", "/", nil, map[string]string{urlutil.QueryIsProgrammatic: "true", urlutil.QueryCallbackURI: "ok", urlutil.QuerySessionEncrypted: goodEncryptionString}, &mock.Encoder{MarshalResponse: []byte("x")}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusFound, ""}, {"good programmatic", opts, http.MethodGet, "http", "example.com", "/", nil, map[string]string{urlutil.QueryIsProgrammatic: "true", urlutil.QueryCallbackURI: "ok", urlutil.QuerySessionEncrypted: goodEncryptionString}, &mock.Encoder{MarshalResponse: []byte("x")}, &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusFound, ""},
{"bad decrypt", opts, http.MethodGet, "http", "example.com", "/", nil, map[string]string{urlutil.QuerySessionEncrypted: "KBEjQ9rnCxaAX-GOqexGw9ivEQURqts3zZ2mNGy0wnVa3SbtM399KlBq2nZ-9wM21FfsZX52er4jlmC7kPEKM3P7uZ41zR0zeys1-_74a5tQp-vsf1WXZfRsgVOuBcWPkMiWEoc379JFHxGDudp5VhU8B-dcQt4f3_PtLTHARkuH54io1Va2gNMq4Hiy8sQ1MPGCQeltH_JMzzdDpXdmdusWrXUvCGkba24muvAV06D8XRVJj6Iu9eK94qFnqcHc7wzziEbb8ADBues9dwbtb6jl8vMWz5rN6XvXqA5YpZv_MQZlsrO4oXFFQDevdgB84cX1tVbVu6qZvK_yQBZqzpOjWA9uIaoSENMytoXuWAlFO_sXjswfX8JTNdGwzB7qQRNPqxVG_sM_tzY3QhPm8zqwEzsXG5DokxZfVt2I5WJRUEovFDb4BnK9KFnnkEzLEdMudixVnXeGmTtycgJvoTeTCQRPfDYkcgJ7oKf4tGea-W7z5UAVa2RduJM9ZoM6YtJX7jgDm__PvvqcE0knJUF87XHBzdcOjoDF-CUze9xDJgNBlvPbJqVshKrwoqSYpePSDH9GUCNKxGequW3Ma8GvlFfhwd0rK6IZG-XWkyk0XSWQIGkDSjAvhB1wsOusCCguDjbpVZpaW5MMyTkmx68pl6qlIKT5UCcrVPl4ix5ZEj91mUDF0O1t04haD7VZuLVFXVGmqtFrBKI76sdYN-zkokaa1_chPRTyqMQFlqu_8LD6-RiK3UccGM-dEmnX72i91NP9F9OK0WJr9Cheup1C_P0mjqAO4Cb8oIHm0Oxz_mRqv5QbTGJtb3xwPLPuVjVCiE4gGBcuU2ixpSVf5HUF7y1KicVMCKiX9ATCBtg8sTdQZQnPEtHcHHAvdsnDVwev1LGfqA-Gdvg="}, &mock.Encoder{MarshalResponse: []byte("x")}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusBadRequest, ""}, {"bad decrypt", opts, http.MethodGet, "http", "example.com", "/", nil, map[string]string{urlutil.QuerySessionEncrypted: "KBEjQ9rnCxaAX-GOqexGw9ivEQURqts3zZ2mNGy0wnVa3SbtM399KlBq2nZ-9wM21FfsZX52er4jlmC7kPEKM3P7uZ41zR0zeys1-_74a5tQp-vsf1WXZfRsgVOuBcWPkMiWEoc379JFHxGDudp5VhU8B-dcQt4f3_PtLTHARkuH54io1Va2gNMq4Hiy8sQ1MPGCQeltH_JMzzdDpXdmdusWrXUvCGkba24muvAV06D8XRVJj6Iu9eK94qFnqcHc7wzziEbb8ADBues9dwbtb6jl8vMWz5rN6XvXqA5YpZv_MQZlsrO4oXFFQDevdgB84cX1tVbVu6qZvK_yQBZqzpOjWA9uIaoSENMytoXuWAlFO_sXjswfX8JTNdGwzB7qQRNPqxVG_sM_tzY3QhPm8zqwEzsXG5DokxZfVt2I5WJRUEovFDb4BnK9KFnnkEzLEdMudixVnXeGmTtycgJvoTeTCQRPfDYkcgJ7oKf4tGea-W7z5UAVa2RduJM9ZoM6YtJX7jgDm__PvvqcE0knJUF87XHBzdcOjoDF-CUze9xDJgNBlvPbJqVshKrwoqSYpePSDH9GUCNKxGequW3Ma8GvlFfhwd0rK6IZG-XWkyk0XSWQIGkDSjAvhB1wsOusCCguDjbpVZpaW5MMyTkmx68pl6qlIKT5UCcrVPl4ix5ZEj91mUDF0O1t04haD7VZuLVFXVGmqtFrBKI76sdYN-zkokaa1_chPRTyqMQFlqu_8LD6-RiK3UccGM-dEmnX72i91NP9F9OK0WJr9Cheup1C_P0mjqAO4Cb8oIHm0Oxz_mRqv5QbTGJtb3xwPLPuVjVCiE4gGBcuU2ixpSVf5HUF7y1KicVMCKiX9ATCBtg8sTdQZQnPEtHcHHAvdsnDVwev1LGfqA-Gdvg="}, &mock.Encoder{MarshalResponse: []byte("x")}, &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusBadRequest, ""},
{"bad save session", opts, http.MethodGet, "http", "example.com", "/", nil, map[string]string{urlutil.QuerySessionEncrypted: goodEncryptionString}, &mock.Encoder{MarshalResponse: []byte("x")}, &sessions.MockSessionStore{SaveError: errors.New("hi")}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusBadRequest, ""}, {"bad save session", opts, http.MethodGet, "http", "example.com", "/", nil, map[string]string{urlutil.QuerySessionEncrypted: goodEncryptionString}, &mock.Encoder{MarshalResponse: []byte("x")}, &mstore.Store{SaveError: errors.New("hi")}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusBadRequest, ""},
{"bad base64", opts, http.MethodGet, "http", "example.com", "/", nil, map[string]string{urlutil.QuerySessionEncrypted: "^"}, &mock.Encoder{MarshalResponse: []byte("x")}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusBadRequest, ""}, {"bad base64", opts, http.MethodGet, "http", "example.com", "/", nil, map[string]string{urlutil.QuerySessionEncrypted: "^"}, &mock.Encoder{MarshalResponse: []byte("x")}, &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusBadRequest, ""},
{"malformed redirect", opts, http.MethodGet, "http", "example.com", "/", nil, nil, &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusBadRequest, ""}, {"malformed redirect", opts, http.MethodGet, "http", "example.com", "/", nil, nil, &mock.Encoder{}, &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusBadRequest, ""},
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
@ -386,12 +387,12 @@ func TestProxy_ProgrammaticCallback(t *testing.T) {
wantStatus int wantStatus int
wantBody string wantBody string
}{ }{
{"good", opts, http.MethodGet, "http://pomerium.io/", nil, map[string]string{urlutil.QueryCallbackURI: "ok", urlutil.QuerySessionEncrypted: goodEncryptionString}, &mock.Encoder{MarshalResponse: []byte("x")}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusFound, ""}, {"good", opts, http.MethodGet, "http://pomerium.io/", nil, map[string]string{urlutil.QueryCallbackURI: "ok", urlutil.QuerySessionEncrypted: goodEncryptionString}, &mock.Encoder{MarshalResponse: []byte("x")}, &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusFound, ""},
{"good programmatic", opts, http.MethodGet, "http://pomerium.io/", nil, map[string]string{urlutil.QueryIsProgrammatic: "true", urlutil.QueryCallbackURI: "ok", urlutil.QuerySessionEncrypted: goodEncryptionString}, &mock.Encoder{MarshalResponse: []byte("x")}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusFound, ""}, {"good programmatic", opts, http.MethodGet, "http://pomerium.io/", nil, map[string]string{urlutil.QueryIsProgrammatic: "true", urlutil.QueryCallbackURI: "ok", urlutil.QuerySessionEncrypted: goodEncryptionString}, &mock.Encoder{MarshalResponse: []byte("x")}, &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusFound, ""},
{"bad decrypt", opts, http.MethodGet, "http://pomerium.io/", nil, map[string]string{urlutil.QuerySessionEncrypted: goodEncryptionString + cryptutil.NewBase64Key()}, &mock.Encoder{MarshalResponse: []byte("x")}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusBadRequest, ""}, {"bad decrypt", opts, http.MethodGet, "http://pomerium.io/", nil, map[string]string{urlutil.QuerySessionEncrypted: goodEncryptionString + cryptutil.NewBase64Key()}, &mock.Encoder{MarshalResponse: []byte("x")}, &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusBadRequest, ""},
{"bad save session", opts, http.MethodGet, "http://pomerium.io/", nil, map[string]string{urlutil.QuerySessionEncrypted: goodEncryptionString}, &mock.Encoder{MarshalResponse: []byte("x")}, &sessions.MockSessionStore{SaveError: errors.New("hi")}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusBadRequest, ""}, {"bad save session", opts, http.MethodGet, "http://pomerium.io/", nil, map[string]string{urlutil.QuerySessionEncrypted: goodEncryptionString}, &mock.Encoder{MarshalResponse: []byte("x")}, &mstore.Store{SaveError: errors.New("hi")}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusBadRequest, ""},
{"bad base64", opts, http.MethodGet, "http://pomerium.io/", nil, map[string]string{urlutil.QuerySessionEncrypted: "^"}, &mock.Encoder{MarshalResponse: []byte("x")}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusBadRequest, ""}, {"bad base64", opts, http.MethodGet, "http://pomerium.io/", nil, map[string]string{urlutil.QuerySessionEncrypted: "^"}, &mock.Encoder{MarshalResponse: []byte("x")}, &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusBadRequest, ""},
{"malformed redirect", opts, http.MethodGet, "http://pomerium.io/", nil, nil, &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusBadRequest, ""}, {"malformed redirect", opts, http.MethodGet, "http://pomerium.io/", nil, nil, &mock.Encoder{}, &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusBadRequest, ""},
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {

View file

@ -1,7 +1,11 @@
package proxy // import "github.com/pomerium/pomerium/proxy" package proxy // import "github.com/pomerium/pomerium/proxy"
import ( import (
"context"
"errors"
"fmt" "fmt"
"io"
"io/ioutil"
"net/http" "net/http"
"github.com/pomerium/pomerium/internal/encoding" "github.com/pomerium/pomerium/internal/encoding"
@ -30,23 +34,82 @@ func (p *Proxy) AuthenticateSession(next http.Handler) http.Handler {
ctx, span := trace.StartSpan(r.Context(), "proxy.AuthenticateSession") ctx, span := trace.StartSpan(r.Context(), "proxy.AuthenticateSession")
defer span.End() defer span.End()
if s, err := sessions.FromContext(ctx); err != nil { _, err := sessions.FromContext(ctx)
log.FromRequest(r).Debug().Err(err).Msg("proxy: authenticate session") if errors.Is(err, sessions.ErrExpired) {
p.sessionStore.ClearSession(w, r) ctx, err = p.refresh(ctx, w, r)
if s != nil && s.Programmatic { if err != nil {
return httputil.NewError(http.StatusUnauthorized, err) log.FromRequest(r).Warn().Err(err).Msg("proxy: refresh failed")
return p.redirectToSignin(w, r)
} }
signinURL := *p.authenticateSigninURL log.FromRequest(r).Info().Msg("proxy: refresh success")
q := signinURL.Query() } else if err != nil {
q.Set(urlutil.QueryRedirectURI, urlutil.GetAbsoluteURL(r).String()) log.FromRequest(r).Debug().Err(err).Msg("proxy: session state")
signinURL.RawQuery = q.Encode() return p.redirectToSignin(w, r)
httputil.Redirect(w, r, urlutil.NewSignedURL(p.SharedKey, &signinURL).String(), http.StatusFound)
} }
p.addPomeriumHeaders(w, r) p.addPomeriumHeaders(w, r)
next.ServeHTTP(w, r.WithContext(ctx)) next.ServeHTTP(w, r.WithContext(ctx))
return nil return nil
}) })
}
func (p *Proxy) refresh(ctx context.Context, w http.ResponseWriter, r *http.Request) (context.Context, error) {
ctx, span := trace.StartSpan(ctx, "proxy.AuthenticateSession/refresh")
defer span.End()
s, err := sessions.FromContext(ctx)
if !errors.Is(err, sessions.ErrExpired) || s == nil {
return nil, errors.New("proxy: unexpected session state for refresh")
}
// 1 - build a signed url to call refresh on authenticate service
refreshURI := *p.authenticateRefreshURL
q := refreshURI.Query()
q.Set("ati", s.AccessTokenID) // hash value points to parent token
q.Set("aud", urlutil.StripPort(r.Host)) // request's audience, this route
refreshURI.RawQuery = q.Encode()
signedRefreshURL := urlutil.NewSignedURL(p.SharedKey, &refreshURI).String()
// 2 - http call to authenticate service
req, err := http.NewRequestWithContext(ctx, http.MethodGet, signedRefreshURL, nil)
if err != nil {
return nil, fmt.Errorf("proxy: backend refresh: new request: %v", err)
}
res, err := httputil.DefaultClient.Do(req)
if err != nil {
return nil, fmt.Errorf("proxy: fetch %v: %w", signedRefreshURL, err)
}
defer res.Body.Close()
jwtBytes, err := ioutil.ReadAll(io.LimitReader(res.Body, 4<<10))
if err != nil {
return nil, err
}
// 3 - save refreshed session to the client's session store
if err = p.sessionStore.SaveSession(w, r, jwtBytes); err != nil {
return nil, err
}
// 4 - add refreshed session to the current request context
var state sessions.State
if err := p.encoder.Unmarshal(jwtBytes, &state); err != nil {
return nil, err
}
if err := state.Verify(urlutil.StripPort(r.Host)); err != nil {
return nil, err
}
return sessions.NewContext(r.Context(), &state, err), nil
}
func (p *Proxy) redirectToSignin(w http.ResponseWriter, r *http.Request) error {
s, err := sessions.FromContext(r.Context())
if s != nil && err != nil && s.Programmatic {
return httputil.NewError(http.StatusUnauthorized, err)
}
p.sessionStore.ClearSession(w, r)
signinURL := *p.authenticateSigninURL
q := signinURL.Query()
q.Set(urlutil.QueryRedirectURI, urlutil.GetAbsoluteURL(r).String())
signinURL.RawQuery = q.Encode()
log.FromRequest(r).Debug().Str("url", signinURL.String()).Msg("proxy: redirectToSignin")
httputil.Redirect(w, r, urlutil.NewSignedURL(p.SharedKey, &signinURL).String(), http.StatusFound)
return nil
} }
func (p *Proxy) addPomeriumHeaders(w http.ResponseWriter, r *http.Request) { func (p *Proxy) addPomeriumHeaders(w http.ResponseWriter, r *http.Request) {
@ -61,8 +124,8 @@ func (p *Proxy) addPomeriumHeaders(w http.ResponseWriter, r *http.Request) {
} }
} }
// AuthorizeSession is middleware to enforce a user is authorized for a request // AuthorizeSession is middleware to enforce a user is authorized for a request.
// session state is retrieved from the users's request context. // Session state is retrieved from the users's request context.
func (p *Proxy) AuthorizeSession(next http.Handler) http.Handler { func (p *Proxy) AuthorizeSession(next http.Handler) http.Handler {
return httputil.HandlerFunc(func(w http.ResponseWriter, r *http.Request) error { return httputil.HandlerFunc(func(w http.ResponseWriter, r *http.Request) error {
ctx, span := trace.StartSpan(r.Context(), "proxy.AuthorizeSession") ctx, span := trace.StartSpan(r.Context(), "proxy.AuthorizeSession")

View file

@ -10,10 +10,14 @@ import (
"time" "time"
"github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp"
"github.com/pomerium/pomerium/internal/identity"
"github.com/pomerium/pomerium/internal/sessions"
"github.com/pomerium/pomerium/proxy/clients" "github.com/pomerium/pomerium/proxy/clients"
"gopkg.in/square/go-jose.v2/jwt" "gopkg.in/square/go-jose.v2/jwt"
"github.com/pomerium/pomerium/internal/encoding"
"github.com/pomerium/pomerium/internal/encoding/mock"
"github.com/pomerium/pomerium/internal/identity"
"github.com/pomerium/pomerium/internal/sessions"
mstore "github.com/pomerium/pomerium/internal/sessions/mock"
) )
func TestProxy_AuthenticateSession(t *testing.T) { func TestProxy_AuthenticateSession(t *testing.T) {
@ -30,24 +34,39 @@ func TestProxy_AuthenticateSession(t *testing.T) {
session sessions.SessionStore session sessions.SessionStore
ctxError error ctxError error
provider identity.Authenticator provider identity.Authenticator
encoder encoding.MarshalUnmarshaler
refreshURL string
wantStatus int wantStatus int
}{ }{
{"good", false, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Second))}}, nil, identity.MockProvider{}, http.StatusOK}, {"good", false, &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Second))}}, nil, identity.MockProvider{}, &mock.Encoder{}, "", http.StatusOK},
{"invalid session", false, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Second))}}, errors.New("hi"), identity.MockProvider{}, http.StatusFound}, {"invalid session", false, &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Second))}}, errors.New("hi"), identity.MockProvider{}, &mock.Encoder{}, "", http.StatusFound},
{"expired", false, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}}, sessions.ErrExpired, identity.MockProvider{}, http.StatusFound}, {"expired", false, &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}}, sessions.ErrExpired, identity.MockProvider{}, &mock.Encoder{}, "", http.StatusOK},
{"expired and programmatic", false, &sessions.MockSessionStore{Session: &sessions.State{Programmatic: true, Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}}, sessions.ErrExpired, identity.MockProvider{}, http.StatusUnauthorized}, {"expired and programmatic", false, &mstore.Store{Session: &sessions.State{Programmatic: true, Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}}, sessions.ErrExpired, identity.MockProvider{}, &mock.Encoder{}, "", http.StatusOK},
{"invalid session and programmatic", false, &sessions.MockSessionStore{Session: &sessions.State{Programmatic: true, Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Second))}}, errors.New("hi"), identity.MockProvider{}, http.StatusUnauthorized}, {"invalid session and programmatic", false, &mstore.Store{Session: &sessions.State{Programmatic: true, Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Second))}}, errors.New("hi"), identity.MockProvider{}, &mock.Encoder{}, "", http.StatusUnauthorized},
{"expired and refreshed ok", false, &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}}, sessions.ErrExpired, identity.MockProvider{}, &mock.Encoder{}, "", http.StatusOK},
{"expired and save failed", false, &mstore.Store{SaveError: errors.New("err"), Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}}, sessions.ErrExpired, identity.MockProvider{}, &mock.Encoder{}, "", http.StatusFound},
{"expired and unmarshal failed", false, &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}}, sessions.ErrExpired, identity.MockProvider{}, &mock.Encoder{UnmarshalError: errors.New("err")}, "", http.StatusFound},
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintln(w, "REFRESH GOOD")
}))
defer ts.Close()
rURL := ts.URL
if tt.refreshURL != "" {
rURL = tt.refreshURL
}
a := Proxy{ a := Proxy{
SharedKey: "80ldlrU2d7w+wVpKNfevk6fmb8otEx6CqOfshj2LwhQ=", SharedKey: "80ldlrU2d7w+wVpKNfevk6fmb8otEx6CqOfshj2LwhQ=",
cookieSecret: []byte("80ldlrU2d7w+wVpKNfevk6fmb8otEx6CqOfshj2LwhQ="), cookieSecret: []byte("80ldlrU2d7w+wVpKNfevk6fmb8otEx6CqOfshj2LwhQ="),
authenticateURL: uriParseHelper("https://authenticate.corp.example"), authenticateURL: uriParseHelper("https://authenticate.corp.example"),
authenticateSigninURL: uriParseHelper("https://authenticate.corp.example/sign_in"), authenticateSigninURL: uriParseHelper("https://authenticate.corp.example/sign_in"),
authenticateRefreshURL: uriParseHelper(rURL),
sessionStore: tt.session, sessionStore: tt.session,
encoder: tt.encoder,
} }
r := httptest.NewRequest(http.MethodGet, "/", nil) r := httptest.NewRequest(http.MethodGet, "/", nil)
state, _ := tt.session.LoadSession(r) state, _ := tt.session.LoadSession(r)
@ -82,10 +101,10 @@ func TestProxy_AuthorizeSession(t *testing.T) {
wantStatus int wantStatus int
}{ }{
{"user is authorized", &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Second))}}, clients.MockAuthorize{AuthorizeResponse: true}, nil, identity.MockProvider{}, http.StatusOK}, {"user is authorized", &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Second))}}, clients.MockAuthorize{AuthorizeResponse: true}, nil, identity.MockProvider{}, http.StatusOK},
{"user is not authorized", &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Second))}}, clients.MockAuthorize{AuthorizeResponse: false}, nil, identity.MockProvider{}, http.StatusUnauthorized}, {"user is not authorized", &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Second))}}, clients.MockAuthorize{AuthorizeResponse: false}, nil, identity.MockProvider{}, http.StatusUnauthorized},
{"invalid session", &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Second))}}, clients.MockAuthorize{AuthorizeResponse: true}, errors.New("hi"), identity.MockProvider{}, http.StatusUnauthorized}, {"invalid session", &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Second))}}, clients.MockAuthorize{AuthorizeResponse: true}, errors.New("hi"), identity.MockProvider{}, http.StatusUnauthorized},
{"authz client error", &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Second))}}, clients.MockAuthorize{AuthorizeError: errors.New("err")}, nil, identity.MockProvider{}, http.StatusInternalServerError}, {"authz client error", &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Second))}}, clients.MockAuthorize{AuthorizeError: errors.New("err")}, nil, identity.MockProvider{}, http.StatusInternalServerError},
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
@ -143,9 +162,9 @@ func TestProxy_SignRequest(t *testing.T) {
wantStatus int wantStatus int
wantHeaders string wantHeaders string
}{ }{
{"good", &sessions.MockSessionStore{Session: &sessions.State{Email: "test"}}, nil, nil, http.StatusOK, "ok"}, {"good", &mstore.Store{Session: &sessions.State{Email: "test"}}, nil, nil, http.StatusOK, "ok"},
{"invalid session", &sessions.MockSessionStore{Session: &sessions.State{Email: "test"}}, nil, errors.New("err"), http.StatusForbidden, ""}, {"invalid session", &mstore.Store{Session: &sessions.State{Email: "test"}}, nil, errors.New("err"), http.StatusForbidden, ""},
{"signature failure, warn but ok", &sessions.MockSessionStore{Session: &sessions.State{Email: "test"}}, errors.New("err"), nil, http.StatusOK, ""}, {"signature failure, warn but ok", &mstore.Store{Session: &sessions.State{Email: "test"}}, errors.New("err"), nil, http.StatusOK, ""},
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {

View file

@ -21,6 +21,9 @@ import (
"github.com/pomerium/pomerium/internal/log" "github.com/pomerium/pomerium/internal/log"
"github.com/pomerium/pomerium/internal/middleware" "github.com/pomerium/pomerium/internal/middleware"
"github.com/pomerium/pomerium/internal/sessions" "github.com/pomerium/pomerium/internal/sessions"
"github.com/pomerium/pomerium/internal/sessions/cookie"
"github.com/pomerium/pomerium/internal/sessions/header"
"github.com/pomerium/pomerium/internal/sessions/queryparam"
"github.com/pomerium/pomerium/internal/telemetry/metrics" "github.com/pomerium/pomerium/internal/telemetry/metrics"
"github.com/pomerium/pomerium/internal/tripper" "github.com/pomerium/pomerium/internal/tripper"
"github.com/pomerium/pomerium/internal/urlutil" "github.com/pomerium/pomerium/internal/urlutil"
@ -28,12 +31,11 @@ import (
) )
const ( const (
// dashboardURL is the path to authenticate's sign in endpoint // authenticate urls
dashboardURL = "/.pomerium" dashboardURL = "/.pomerium"
// signinURL is the path to authenticate's sign in endpoint
signinURL = "/.pomerium/sign_in" signinURL = "/.pomerium/sign_in"
// signoutURL is the path to authenticate's sign out endpoint
signoutURL = "/.pomerium/sign_out" signoutURL = "/.pomerium/sign_out"
refreshURL = "/.pomerium/refresh"
) )
// ValidateOptions checks that proper configuration settings are set to create // ValidateOptions checks that proper configuration settings are set to create
@ -72,12 +74,14 @@ type Proxy struct {
authenticateURL *url.URL authenticateURL *url.URL
authenticateSigninURL *url.URL authenticateSigninURL *url.URL
authenticateSignoutURL *url.URL authenticateSignoutURL *url.URL
authenticateRefreshURL *url.URL
authorizeURL *url.URL authorizeURL *url.URL
AuthorizeClient clients.Authorizer AuthorizeClient clients.Authorizer
encoder encoding.Unmarshaler encoder encoding.Unmarshaler
cookieOptions *sessions.CookieOptions cookieOptions *cookie.Options
cookieSecret []byte cookieSecret []byte
defaultUpstreamTimeout time.Duration defaultUpstreamTimeout time.Duration
refreshCooldown time.Duration refreshCooldown time.Duration
@ -104,7 +108,7 @@ func New(opts config.Options) (*Proxy, error) {
return nil, err return nil, err
} }
cookieOptions := &sessions.CookieOptions{ cookieOptions := &cookie.Options{
Name: opts.CookieName, Name: opts.CookieName,
Domain: opts.CookieDomain, Domain: opts.CookieDomain,
Secure: opts.CookieSecure, Secure: opts.CookieSecure,
@ -112,7 +116,7 @@ func New(opts config.Options) (*Proxy, error) {
Expire: opts.CookieExpire, Expire: opts.CookieExpire,
} }
cookieStore, err := sessions.NewCookieLoader(cookieOptions, encoder) cookieStore, err := cookie.NewStore(cookieOptions, encoder)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -129,8 +133,8 @@ func New(opts config.Options) (*Proxy, error) {
sessionStore: cookieStore, sessionStore: cookieStore,
sessionLoaders: []sessions.SessionLoader{ sessionLoaders: []sessions.SessionLoader{
cookieStore, cookieStore,
sessions.NewHeaderStore(encoder, "Pomerium"), header.NewStore(encoder, "Pomerium"),
sessions.NewQueryParamStore(encoder, "pomerium_session")}, queryparam.NewStore(encoder, "pomerium_session")},
signingKey: opts.SigningKey, signingKey: opts.SigningKey,
templates: template.Must(frontend.NewTemplates()), templates: template.Must(frontend.NewTemplates()),
} }
@ -139,6 +143,7 @@ func New(opts config.Options) (*Proxy, error) {
p.authenticateURL, _ = urlutil.DeepCopy(opts.AuthenticateURL) p.authenticateURL, _ = urlutil.DeepCopy(opts.AuthenticateURL)
p.authenticateSigninURL = p.authenticateURL.ResolveReference(&url.URL{Path: signinURL}) p.authenticateSigninURL = p.authenticateURL.ResolveReference(&url.URL{Path: signinURL})
p.authenticateSignoutURL = p.authenticateURL.ResolveReference(&url.URL{Path: signoutURL}) p.authenticateSignoutURL = p.authenticateURL.ResolveReference(&url.URL{Path: signoutURL})
p.authenticateRefreshURL = p.authenticateURL.ResolveReference(&url.URL{Path: refreshURL})
if err := p.UpdatePolicies(&opts); err != nil { if err := p.UpdatePolicies(&opts); err != nil {
return nil, err return nil, err