authenticate/proxy: add backend refresh (#438)

This commit is contained in:
Bobby DeSimone 2019-12-30 10:47:54 -08:00 committed by GitHub
parent 9a330613aa
commit ec029c679b
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
35 changed files with 1226 additions and 445 deletions

View file

@ -198,11 +198,6 @@ issues:
linters:
- staticcheck
# todo(bdd): replace in go 1.13
- path: proxy/proxy.go
text: "copylocks: assignment copies lock value to transport"
linters:
- govet
# Independently from option `exclude` we use default exclude patterns,
# it can be disabled by this option. To list all
# excluded by default patterns execute `golangci-lint run --help`.

View file

@ -16,6 +16,10 @@ import (
"github.com/pomerium/pomerium/internal/frontend"
"github.com/pomerium/pomerium/internal/identity"
"github.com/pomerium/pomerium/internal/sessions"
"github.com/pomerium/pomerium/internal/sessions/cache"
"github.com/pomerium/pomerium/internal/sessions/cookie"
"github.com/pomerium/pomerium/internal/sessions/header"
"github.com/pomerium/pomerium/internal/sessions/queryparam"
"github.com/pomerium/pomerium/internal/urlutil"
)
@ -49,6 +53,8 @@ type Authenticate struct {
// authentication flow
RedirectURL *url.URL
// values related to cross service communication
//
// sharedKey is used to encrypt and authenticate data between services
sharedKey string
// sharedCipher is used to encrypt data for use between services
@ -57,16 +63,21 @@ type Authenticate struct {
// by other services
sharedEncoder encoding.MarshalUnmarshaler
// data related to this service only
cookieOptions *sessions.CookieOptions
// cookieSecret is the secret to encrypt and authenticate data for this service
// values related to user sessions
//
// cookieSecret is the secret to encrypt and authenticate session data
cookieSecret []byte
// is the cipher to use to encrypt data for this service
cookieCipher cipher.AEAD
sessionStore sessions.SessionStore
// cookieCipher is the cipher to use to encrypt/decrypt session data
cookieCipher cipher.AEAD
// encryptedEncoder is the encoder used to marshal and unmarshal session data
encryptedEncoder encoding.MarshalUnmarshaler
sessionStores []sessions.SessionStore
sessionLoaders []sessions.SessionLoader
// sessionStore is the session store used to persist a user's session
sessionStore sessions.SessionStore
cookieOptions *cookie.Options
// sessionLoaders are a collection of session loaders to attempt to pull
// a user's session state from
sessionLoaders []sessions.SessionLoader
// provider is the interface to interacting with the identity provider (IdP)
provider identity.Authenticator
@ -92,7 +103,7 @@ func New(opts config.Options) (*Authenticate, error) {
cookieCipher, _ := cryptutil.NewAEADCipher(decodedCookieSecret)
encryptedEncoder := ecjson.New(cookieCipher)
cookieOptions := &sessions.CookieOptions{
cookieOptions := &cookie.Options{
Name: opts.CookieName,
Domain: opts.CookieDomain,
Secure: opts.CookieSecure,
@ -100,12 +111,13 @@ func New(opts config.Options) (*Authenticate, error) {
Expire: opts.CookieExpire,
}
cookieStore, err := sessions.NewCookieStore(cookieOptions, encryptedEncoder)
cookieStore, err := cookie.NewStore(cookieOptions, encryptedEncoder)
if err != nil {
return nil, err
}
qpStore := sessions.NewQueryParamStore(encryptedEncoder, "pomerium_programmatic_token")
headerStore := sessions.NewHeaderStore(encryptedEncoder, "Pomerium")
cacheStore := cache.NewStore(encryptedEncoder, cookieStore, opts.CookieName)
qpStore := queryparam.NewStore(encryptedEncoder, "pomerium_programmatic_token")
headerStore := header.NewStore(encryptedEncoder, "Pomerium")
redirectURL, _ := urlutil.DeepCopy(opts.AuthenticateURL)
redirectURL.Path = callbackPath
@ -135,10 +147,9 @@ func New(opts config.Options) (*Authenticate, error) {
cookieSecret: decodedCookieSecret,
cookieCipher: cookieCipher,
cookieOptions: cookieOptions,
sessionStore: cookieStore,
sessionStore: cacheStore,
encryptedEncoder: encryptedEncoder,
sessionLoaders: []sessions.SessionLoader{qpStore, headerStore, cookieStore},
sessionStores: []sessions.SessionStore{cookieStore, qpStore},
sessionLoaders: []sessions.SessionLoader{cacheStore, qpStore, headerStore, cookieStore},
// IdP
provider: provider,

View file

@ -72,15 +72,18 @@ func TestOptions_Validate(t *testing.T) {
func TestNew(t *testing.T) {
good := newTestOptions(t)
good.CookieName = "A"
badRedirectURL := newTestOptions(t)
badRedirectURL.AuthenticateURL = nil
badRedirectURL.CookieName = "B"
badCookieName := newTestOptions(t)
badCookieName.CookieName = ""
badProvider := newTestOptions(t)
badProvider.Provider = ""
badProvider.CookieName = "C"
tests := []struct {
name string

View file

@ -1,6 +1,7 @@
package authenticate // import "github.com/pomerium/pomerium/authenticate"
import (
"context"
"encoding/base64"
"encoding/json"
"errors"
@ -18,6 +19,7 @@ import (
"github.com/pomerium/pomerium/internal/log"
"github.com/pomerium/pomerium/internal/middleware"
"github.com/pomerium/pomerium/internal/sessions"
"github.com/pomerium/pomerium/internal/telemetry/trace"
"github.com/pomerium/pomerium/internal/urlutil"
)
@ -58,6 +60,7 @@ func (a *Authenticate) Handler() http.Handler {
v.Use(a.VerifySession)
v.Path("/sign_in").Handler(httputil.HandlerFunc(a.SignIn))
v.Path("/sign_out").Handler(httputil.HandlerFunc(a.SignOut))
v.Path("/refresh").Handler(httputil.HandlerFunc(a.Refresh)).Methods(http.MethodGet)
// programmatic access api endpoint
api := r.PathPrefix("/api").Subrouter()
@ -73,12 +76,12 @@ func (a *Authenticate) VerifySession(next http.Handler) http.Handler {
return httputil.HandlerFunc(func(w http.ResponseWriter, r *http.Request) error {
state, err := sessions.FromContext(r.Context())
if errors.Is(err, sessions.ErrExpired) {
if err := a.refresh(w, r, state); err != nil {
ctx, err := a.refresh(w, r, state)
if err != nil {
log.FromRequest(r).Info().Err(err).Msg("authenticate: verify session, refresh")
return a.reauthenticateOrFail(w, r, err)
}
// redirect to restart middleware-chain following refresh
httputil.Redirect(w, r, urlutil.GetAbsoluteURL(r).String(), http.StatusFound)
next.ServeHTTP(w, r.WithContext(ctx))
return nil
} else if err != nil {
log.FromRequest(r).Info().Err(err).Msg("authenticate: verify session")
@ -89,15 +92,18 @@ func (a *Authenticate) VerifySession(next http.Handler) http.Handler {
})
}
func (a *Authenticate) refresh(w http.ResponseWriter, r *http.Request, s *sessions.State) error {
newSession, err := a.provider.Refresh(r.Context(), s)
func (a *Authenticate) refresh(w http.ResponseWriter, r *http.Request, s *sessions.State) (context.Context, error) {
ctx, span := trace.StartSpan(r.Context(), "authenticate.VerifySession/refresh")
defer span.End()
newSession, err := a.provider.Refresh(ctx, s)
if err != nil {
return fmt.Errorf("authenticate: refresh failed: %w", err)
return nil, fmt.Errorf("authenticate: refresh failed: %w", err)
}
if err := a.sessionStore.SaveSession(w, r, newSession); err != nil {
return fmt.Errorf("authenticate: refresh save failed: %w", err)
return nil, fmt.Errorf("authenticate: refresh save failed: %w", err)
}
return nil
// return the new session and add it to the current request context
return sessions.NewContext(ctx, newSession, err), nil
}
// RobotsTxt handles the /robots.txt route.
@ -158,7 +164,6 @@ func (a *Authenticate) SignIn(w http.ResponseWriter, r *http.Request) error {
encSession, err := a.encryptedEncoder.Marshal(newSession)
if err != nil {
return httputil.NewError(http.StatusBadRequest, err)
}
callbackParams.Set(urlutil.QueryRefreshToken, string(encSession))
callbackParams.Set(urlutil.QueryIsProgrammatic, "true")
@ -345,3 +350,27 @@ func (a *Authenticate) RefreshAPI(w http.ResponseWriter, r *http.Request) error
w.Write(jsonResponse)
return nil
}
// Refresh is called by the proxy service to handle backend session refresh.
//
// NOTE: The actual refresh is actually handled as part of the "VerifySession"
// middleware. This handler is responsible for creating a new route scoped
// session and returning it.
func (a *Authenticate) Refresh(w http.ResponseWriter, r *http.Request) error {
s, err := sessions.FromContext(r.Context())
if err != nil {
return httputil.NewError(http.StatusBadRequest, err)
}
routeSession := s.NewSession(r.Host, []string{r.Host, r.FormValue("aud")})
routeSession.AccessTokenID = s.AccessTokenID
signedJWT, err := a.sharedEncoder.Marshal(routeSession.RouteSession())
if err != nil {
return err
}
w.Header().Set("Content-Type", "application/jwt") // RFC 7519 : 10.3.1
w.Write(signedJWT)
return nil
}

View file

@ -11,17 +11,18 @@ import (
"testing"
"time"
"github.com/pomerium/pomerium/internal/httputil"
"github.com/pomerium/pomerium/internal/cryptutil"
"github.com/pomerium/pomerium/internal/encoding"
"github.com/pomerium/pomerium/internal/encoding/mock"
"github.com/pomerium/pomerium/internal/frontend"
"github.com/pomerium/pomerium/internal/httputil"
"github.com/pomerium/pomerium/internal/identity"
"github.com/pomerium/pomerium/internal/sessions"
"github.com/pomerium/pomerium/internal/urlutil"
"github.com/pomerium/pomerium/internal/sessions/cookie"
mstore "github.com/pomerium/pomerium/internal/sessions/mock"
"github.com/google/go-cmp/cmp"
"github.com/pomerium/pomerium/internal/urlutil"
"golang.org/x/crypto/chacha20poly1305"
"golang.org/x/oauth2"
"gopkg.in/square/go-jose.v2/jwt"
@ -32,7 +33,7 @@ func testAuthenticate() *Authenticate {
auth.RedirectURL, _ = url.Parse("https://auth.example.com/oauth/callback")
auth.sharedKey = cryptutil.NewBase64Key()
auth.cookieSecret = cryptutil.NewKey()
auth.cookieOptions = &sessions.CookieOptions{Name: "name"}
auth.cookieOptions = &cookie.Options{Name: "name"}
auth.templates = template.Must(frontend.NewTemplates())
return &auth
}
@ -112,19 +113,19 @@ func TestAuthenticate_SignIn(t *testing.T) {
encoder encoding.MarshalUnmarshaler
wantCode int
}{
{"good", "https", "corp.example.example", map[string]string{urlutil.QueryRedirectURI: "https://dst.some.example/"}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusFound},
{"session not valid", "https", "corp.example.example", map[string]string{urlutil.QueryRedirectURI: "https://dst.some.example/"}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(-10 * time.Second)}}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusFound},
{"bad redirect uri query", "", "corp.example.example", map[string]string{urlutil.QueryRedirectURI: "^^^"}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusBadRequest},
{"bad marshal", "https", "corp.example.example", map[string]string{urlutil.QueryRedirectURI: "https://dst.some.example/"}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, identity.MockProvider{}, &mock.Encoder{MarshalError: errors.New("error")}, http.StatusBadRequest},
{"session error", "https", "corp.example.example", map[string]string{urlutil.QueryRedirectURI: "https://dst.some.example/"}, &sessions.MockSessionStore{LoadError: errors.New("error")}, identity.MockProvider{}, &mock.Encoder{}, http.StatusBadRequest},
{"good with different programmatic redirect", "https", "corp.example.example", map[string]string{urlutil.QueryRedirectURI: "https://dst.some.example/", urlutil.QueryCallbackURI: "https://some.example"}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusFound},
{"encrypted encoder error", "https", "corp.example.example", map[string]string{urlutil.QueryRedirectURI: "https://dst.some.example/", urlutil.QueryCallbackURI: "https://some.example"}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, identity.MockProvider{}, &mock.Encoder{MarshalError: errors.New("error")}, http.StatusBadRequest},
{"good with callback uri set", "https", "corp.example.example", map[string]string{urlutil.QueryCallbackURI: "https://some.example/", urlutil.QueryRedirectURI: "https://dst.some.example/"}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusFound},
{"bad callback uri set", "https", "corp.example.example", map[string]string{urlutil.QueryCallbackURI: "^", urlutil.QueryRedirectURI: "https://dst.some.example/"}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusBadRequest},
{"good programmatic request", "https", "corp.example.example", map[string]string{urlutil.QueryIsProgrammatic: "true", urlutil.QueryRedirectURI: "https://dst.some.example/"}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusFound},
{"good additional audience", "https", "corp.example.example", map[string]string{urlutil.QueryForwardAuth: "x.y.z", urlutil.QueryRedirectURI: "https://dst.some.example/"}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusFound},
{"good user impersonate", "https", "corp.example.example", map[string]string{urlutil.QueryImpersonateAction: "set", urlutil.QueryRedirectURI: "https://dst.some.example/"}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusFound},
{"bad user impersonate save failure", "https", "corp.example.example", map[string]string{urlutil.QueryImpersonateAction: "set", urlutil.QueryRedirectURI: "https://dst.some.example/"}, &sessions.MockSessionStore{SaveError: errors.New("err"), Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusBadRequest},
{"good", "https", "corp.example.example", map[string]string{urlutil.QueryRedirectURI: "https://dst.some.example/"}, &mstore.Store{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusFound},
{"session not valid", "https", "corp.example.example", map[string]string{urlutil.QueryRedirectURI: "https://dst.some.example/"}, &mstore.Store{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(-10 * time.Second)}}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusFound},
{"bad redirect uri query", "", "corp.example.example", map[string]string{urlutil.QueryRedirectURI: "^^^"}, &mstore.Store{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusBadRequest},
{"bad marshal", "https", "corp.example.example", map[string]string{urlutil.QueryRedirectURI: "https://dst.some.example/"}, &mstore.Store{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, identity.MockProvider{}, &mock.Encoder{MarshalError: errors.New("error")}, http.StatusBadRequest},
{"session error", "https", "corp.example.example", map[string]string{urlutil.QueryRedirectURI: "https://dst.some.example/"}, &mstore.Store{LoadError: errors.New("error")}, identity.MockProvider{}, &mock.Encoder{}, http.StatusBadRequest},
{"good with different programmatic redirect", "https", "corp.example.example", map[string]string{urlutil.QueryRedirectURI: "https://dst.some.example/", urlutil.QueryCallbackURI: "https://some.example"}, &mstore.Store{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusFound},
{"encrypted encoder error", "https", "corp.example.example", map[string]string{urlutil.QueryRedirectURI: "https://dst.some.example/", urlutil.QueryCallbackURI: "https://some.example"}, &mstore.Store{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, identity.MockProvider{}, &mock.Encoder{MarshalError: errors.New("error")}, http.StatusBadRequest},
{"good with callback uri set", "https", "corp.example.example", map[string]string{urlutil.QueryCallbackURI: "https://some.example/", urlutil.QueryRedirectURI: "https://dst.some.example/"}, &mstore.Store{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusFound},
{"bad callback uri set", "https", "corp.example.example", map[string]string{urlutil.QueryCallbackURI: "^", urlutil.QueryRedirectURI: "https://dst.some.example/"}, &mstore.Store{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusBadRequest},
{"good programmatic request", "https", "corp.example.example", map[string]string{urlutil.QueryIsProgrammatic: "true", urlutil.QueryRedirectURI: "https://dst.some.example/"}, &mstore.Store{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusFound},
{"good additional audience", "https", "corp.example.example", map[string]string{urlutil.QueryForwardAuth: "x.y.z", urlutil.QueryRedirectURI: "https://dst.some.example/"}, &mstore.Store{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusFound},
{"good user impersonate", "https", "corp.example.example", map[string]string{urlutil.QueryImpersonateAction: "set", urlutil.QueryRedirectURI: "https://dst.some.example/"}, &mstore.Store{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusFound},
{"bad user impersonate save failure", "https", "corp.example.example", map[string]string{urlutil.QueryImpersonateAction: "set", urlutil.QueryRedirectURI: "https://dst.some.example/"}, &mstore.Store{SaveError: errors.New("err"), Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusBadRequest},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
@ -136,7 +137,7 @@ func TestAuthenticate_SignIn(t *testing.T) {
sharedEncoder: tt.encoder,
encryptedEncoder: tt.encoder,
sharedCipher: aead,
cookieOptions: &sessions.CookieOptions{
cookieOptions: &cookie.Options{
Name: "cookie",
Domain: "foo",
},
@ -186,10 +187,10 @@ func TestAuthenticate_SignOut(t *testing.T) {
wantCode int
wantBody string
}{
{"good post", http.MethodPost, nil, "https://corp.pomerium.io/", "sig", "ts", identity.MockProvider{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, http.StatusFound, ""},
{"failed revoke", http.MethodPost, nil, "https://corp.pomerium.io/", "sig", "ts", identity.MockProvider{RevokeError: errors.New("OH NO")}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, http.StatusBadRequest, "{\"Status\":400,\"Error\":\"Bad Request: OH NO\"}\n"},
{"load session error", http.MethodPost, errors.New("error"), "https://corp.pomerium.io/", "sig", "ts", identity.MockProvider{RevokeError: errors.New("OH NO")}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, http.StatusBadRequest, "{\"Status\":400,\"Error\":\"Bad Request: error\"}\n"},
{"bad redirect uri", http.MethodPost, nil, "corp.pomerium.io/", "sig", "ts", identity.MockProvider{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, http.StatusBadRequest, "{\"Status\":400,\"Error\":\"Bad Request: corp.pomerium.io/ url does contain a valid scheme\"}\n"},
{"good post", http.MethodPost, nil, "https://corp.pomerium.io/", "sig", "ts", identity.MockProvider{}, &mstore.Store{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, http.StatusFound, ""},
{"failed revoke", http.MethodPost, nil, "https://corp.pomerium.io/", "sig", "ts", identity.MockProvider{RevokeError: errors.New("OH NO")}, &mstore.Store{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, http.StatusBadRequest, "{\"Status\":400,\"Error\":\"Bad Request: OH NO\"}\n"},
{"load session error", http.MethodPost, errors.New("error"), "https://corp.pomerium.io/", "sig", "ts", identity.MockProvider{RevokeError: errors.New("OH NO")}, &mstore.Store{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, http.StatusBadRequest, "{\"Status\":400,\"Error\":\"Bad Request: error\"}\n"},
{"bad redirect uri", http.MethodPost, nil, "corp.pomerium.io/", "sig", "ts", identity.MockProvider{}, &mstore.Store{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, http.StatusBadRequest, "{\"Status\":400,\"Error\":\"Bad Request: corp.pomerium.io/ url does contain a valid scheme\"}\n"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
@ -247,19 +248,19 @@ func TestAuthenticate_OAuthCallback(t *testing.T) {
want string
wantCode int
}{
{"good", http.MethodGet, time.Now().Unix(), "", "", "", "", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, "https://corp.pomerium.io", http.StatusFound},
{"failed authenticate", http.MethodGet, time.Now().Unix(), "", "", "", "", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateError: errors.New("error")}, "", http.StatusInternalServerError},
{"failed save session", http.MethodGet, time.Now().Unix(), "", "", "", "", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &sessions.MockSessionStore{SaveError: errors.New("error")}, identity.MockProvider{AuthenticateResponse: sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, "", http.StatusInternalServerError},
{"provider returned error", http.MethodGet, time.Now().Unix(), "", "", "", "idp error", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, "", http.StatusBadRequest},
{"empty code", http.MethodGet, time.Now().Unix(), "", "", "", "", "", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, "", http.StatusBadRequest},
{"invalid redirect uri", http.MethodGet, time.Now().Unix(), "", "", "", "", "code", "corp.pomerium.io", "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, "", http.StatusBadRequest},
{"bad redirect uri", http.MethodGet, time.Now().Unix(), "", "", "", "", "code", "http://^^^", "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, "https://corp.pomerium.io", http.StatusBadRequest},
{"bad timing - too soon", http.MethodGet, time.Now().Add(1 * time.Hour).Unix(), "", "", "", "", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, "https://corp.pomerium.io", http.StatusBadRequest},
{"bad timing - expired", http.MethodGet, time.Now().Add(-1 * time.Hour).Unix(), "", "", "", "", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, "https://corp.pomerium.io", http.StatusBadRequest},
{"bad base64", http.MethodGet, time.Now().Unix(), "", "", "^", "", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, "https://corp.pomerium.io", http.StatusBadRequest},
{"too many seperators", http.MethodGet, time.Now().Unix(), "", "", "|ok|now|what", "", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, "https://corp.pomerium.io", http.StatusBadRequest},
{"bad hmac", http.MethodGet, time.Now().Unix(), "", "NOTMAC", "", "", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, "https://corp.pomerium.io", http.StatusBadRequest},
{"bad hmac", http.MethodGet, time.Now().Unix(), base64.URLEncoding.EncodeToString([]byte("malformed_state")), "", "", "", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, "https://corp.pomerium.io", http.StatusBadRequest},
{"good", http.MethodGet, time.Now().Unix(), "", "", "", "", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &mstore.Store{}, identity.MockProvider{AuthenticateResponse: sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, "https://corp.pomerium.io", http.StatusFound},
{"failed authenticate", http.MethodGet, time.Now().Unix(), "", "", "", "", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &mstore.Store{}, identity.MockProvider{AuthenticateError: errors.New("error")}, "", http.StatusInternalServerError},
{"failed save session", http.MethodGet, time.Now().Unix(), "", "", "", "", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &mstore.Store{SaveError: errors.New("error")}, identity.MockProvider{AuthenticateResponse: sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, "", http.StatusInternalServerError},
{"provider returned error", http.MethodGet, time.Now().Unix(), "", "", "", "idp error", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &mstore.Store{}, identity.MockProvider{AuthenticateResponse: sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, "", http.StatusBadRequest},
{"empty code", http.MethodGet, time.Now().Unix(), "", "", "", "", "", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &mstore.Store{}, identity.MockProvider{AuthenticateResponse: sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, "", http.StatusBadRequest},
{"invalid redirect uri", http.MethodGet, time.Now().Unix(), "", "", "", "", "code", "corp.pomerium.io", "https://authenticate.pomerium.io", &mstore.Store{}, identity.MockProvider{AuthenticateResponse: sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, "", http.StatusBadRequest},
{"bad redirect uri", http.MethodGet, time.Now().Unix(), "", "", "", "", "code", "http://^^^", "https://authenticate.pomerium.io", &mstore.Store{}, identity.MockProvider{AuthenticateResponse: sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, "https://corp.pomerium.io", http.StatusBadRequest},
{"bad timing - too soon", http.MethodGet, time.Now().Add(1 * time.Hour).Unix(), "", "", "", "", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &mstore.Store{}, identity.MockProvider{AuthenticateResponse: sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, "https://corp.pomerium.io", http.StatusBadRequest},
{"bad timing - expired", http.MethodGet, time.Now().Add(-1 * time.Hour).Unix(), "", "", "", "", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &mstore.Store{}, identity.MockProvider{AuthenticateResponse: sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, "https://corp.pomerium.io", http.StatusBadRequest},
{"bad base64", http.MethodGet, time.Now().Unix(), "", "", "^", "", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &mstore.Store{}, identity.MockProvider{AuthenticateResponse: sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, "https://corp.pomerium.io", http.StatusBadRequest},
{"too many seperators", http.MethodGet, time.Now().Unix(), "", "", "|ok|now|what", "", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &mstore.Store{}, identity.MockProvider{AuthenticateResponse: sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, "https://corp.pomerium.io", http.StatusBadRequest},
{"bad hmac", http.MethodGet, time.Now().Unix(), "", "NOTMAC", "", "", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &mstore.Store{}, identity.MockProvider{AuthenticateResponse: sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, "https://corp.pomerium.io", http.StatusBadRequest},
{"bad hmac", http.MethodGet, time.Now().Unix(), base64.URLEncoding.EncodeToString([]byte("malformed_state")), "", "", "", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &mstore.Store{}, identity.MockProvider{AuthenticateResponse: sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, "https://corp.pomerium.io", http.StatusBadRequest},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
@ -326,12 +327,12 @@ func TestAuthenticate_SessionValidatorMiddleware(t *testing.T) {
wantStatus int
}{
{"good", nil, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, nil, identity.MockProvider{RefreshResponse: sessions.State{AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Minute)}}}, http.StatusOK},
{"invalid session", nil, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, errors.New("hi"), identity.MockProvider{}, http.StatusFound},
{"good refresh expired", nil, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}}, sessions.ErrExpired, identity.MockProvider{RefreshResponse: sessions.State{AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Minute)}}}, http.StatusFound},
{"expired,refresh error", nil, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}}, sessions.ErrExpired, identity.MockProvider{RefreshError: errors.New("error")}, http.StatusFound},
{"expired,save error", nil, &sessions.MockSessionStore{SaveError: errors.New("error"), Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}}, sessions.ErrExpired, identity.MockProvider{RefreshResponse: sessions.State{AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Minute)}}}, http.StatusFound},
{"expired XHR,refresh error", map[string]string{"X-Requested-With": "XmlHttpRequest"}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}}, sessions.ErrExpired, identity.MockProvider{RefreshError: errors.New("error")}, http.StatusUnauthorized},
{"good", nil, &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, nil, identity.MockProvider{RefreshResponse: sessions.State{AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Minute)}}}, http.StatusOK},
{"invalid session", nil, &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, errors.New("hi"), identity.MockProvider{}, http.StatusFound},
{"good refresh expired", nil, &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}}, sessions.ErrExpired, identity.MockProvider{RefreshResponse: sessions.State{AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Minute)}}}, http.StatusOK},
{"expired,refresh error", nil, &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}}, sessions.ErrExpired, identity.MockProvider{RefreshError: errors.New("error")}, http.StatusFound},
{"expired,save error", nil, &mstore.Store{SaveError: errors.New("error"), Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}}, sessions.ErrExpired, identity.MockProvider{RefreshResponse: sessions.State{AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Minute)}}}, http.StatusFound},
{"expired XHR,refresh error", map[string]string{"X-Requested-With": "XmlHttpRequest"}, &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}}, sessions.ErrExpired, identity.MockProvider{RefreshError: errors.New("error")}, http.StatusUnauthorized},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
@ -384,11 +385,11 @@ func TestAuthenticate_RefreshAPI(t *testing.T) {
wantStatus int
}{
{"good", &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, nil, identity.MockProvider{RefreshResponse: sessions.State{AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Minute)}}}, mock.Encoder{MarshalResponse: []byte("ok")}, mock.Encoder{MarshalResponse: []byte("ok")}, http.StatusOK},
{"refresh error", &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, nil, identity.MockProvider{RefreshError: errors.New("error")}, mock.Encoder{MarshalResponse: []byte("ok")}, mock.Encoder{MarshalResponse: []byte("ok")}, http.StatusInternalServerError},
{"session is not refreshable error", &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, errors.New("session error"), identity.MockProvider{RefreshResponse: sessions.State{AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Minute)}}}, mock.Encoder{MarshalResponse: []byte("ok")}, mock.Encoder{MarshalResponse: []byte("ok")}, http.StatusBadRequest},
{"secret encoder failed", &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, nil, identity.MockProvider{RefreshResponse: sessions.State{AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Minute)}}}, mock.Encoder{MarshalError: errors.New("error")}, mock.Encoder{MarshalResponse: []byte("ok")}, http.StatusInternalServerError},
{"shared encoder failed", &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, nil, identity.MockProvider{RefreshResponse: sessions.State{AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Minute)}}}, mock.Encoder{MarshalResponse: []byte("ok")}, mock.Encoder{MarshalError: errors.New("error")}, http.StatusInternalServerError},
{"good", &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, nil, identity.MockProvider{RefreshResponse: sessions.State{AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Minute)}}}, mock.Encoder{MarshalResponse: []byte("ok")}, mock.Encoder{MarshalResponse: []byte("ok")}, http.StatusOK},
{"refresh error", &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, nil, identity.MockProvider{RefreshError: errors.New("error")}, mock.Encoder{MarshalResponse: []byte("ok")}, mock.Encoder{MarshalResponse: []byte("ok")}, http.StatusInternalServerError},
{"session is not refreshable error", &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, errors.New("session error"), identity.MockProvider{RefreshResponse: sessions.State{AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Minute)}}}, mock.Encoder{MarshalResponse: []byte("ok")}, mock.Encoder{MarshalResponse: []byte("ok")}, http.StatusBadRequest},
{"secret encoder failed", &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, nil, identity.MockProvider{RefreshResponse: sessions.State{AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Minute)}}}, mock.Encoder{MarshalError: errors.New("error")}, mock.Encoder{MarshalResponse: []byte("ok")}, http.StatusInternalServerError},
{"shared encoder failed", &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, nil, identity.MockProvider{RefreshResponse: sessions.State{AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Minute)}}}, mock.Encoder{MarshalResponse: []byte("ok")}, mock.Encoder{MarshalError: errors.New("error")}, http.StatusInternalServerError},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
@ -423,3 +424,54 @@ func TestAuthenticate_RefreshAPI(t *testing.T) {
})
}
}
func TestAuthenticate_Refresh(t *testing.T) {
t.Parallel()
tests := []struct {
name string
session sessions.SessionStore
ctxError error
provider identity.Authenticator
secretEncoder encoding.MarshalUnmarshaler
sharedEncoder encoding.MarshalUnmarshaler
wantStatus int
}{
{"good", &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, nil, identity.MockProvider{RefreshResponse: sessions.State{AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Minute)}}}, mock.Encoder{MarshalResponse: []byte("ok")}, mock.Encoder{MarshalResponse: []byte("ok")}, http.StatusOK},
{"bad session", &mstore.Store{}, errors.New("err"), identity.MockProvider{RefreshResponse: sessions.State{AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Minute)}}}, mock.Encoder{MarshalResponse: []byte("ok")}, mock.Encoder{MarshalResponse: []byte("ok")}, http.StatusBadRequest},
{"encoder error", &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, nil, identity.MockProvider{RefreshResponse: sessions.State{AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Minute)}}}, mock.Encoder{MarshalResponse: []byte("ok")}, mock.Encoder{MarshalError: errors.New("err")}, http.StatusInternalServerError},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
aead, err := chacha20poly1305.NewX(cryptutil.NewKey())
if err != nil {
t.Fatal(err)
}
a := Authenticate{
sharedKey: cryptutil.NewBase64Key(),
cookieSecret: cryptutil.NewKey(),
RedirectURL: uriParseHelper("https://authenticate.corp.beyondperimeter.com"),
encryptedEncoder: tt.secretEncoder,
sharedEncoder: tt.sharedEncoder,
sessionStore: tt.session,
provider: tt.provider,
cookieCipher: aead,
}
r := httptest.NewRequest("GET", "/", nil)
state, _ := tt.session.LoadSession(r)
ctx := r.Context()
ctx = sessions.NewContext(ctx, state, tt.ctxError)
r = r.WithContext(ctx)
r.Header.Set("Accept", "application/json")
w := httptest.NewRecorder()
httputil.HandlerFunc(a.Refresh).ServeHTTP(w, r)
if status := w.Code; status != tt.wantStatus {
t.Errorf("VerifySession() error = %v, wantErr %v\n%v", w.Result().StatusCode, tt.wantStatus, w.Body.String())
}
})
}
}

View file

@ -11,12 +11,13 @@ import (
"strings"
"time"
"github.com/fsnotify/fsnotify"
"github.com/pomerium/pomerium/internal/cryptutil"
"github.com/pomerium/pomerium/internal/log"
"github.com/pomerium/pomerium/internal/telemetry/metrics"
"github.com/pomerium/pomerium/internal/urlutil"
"github.com/cespare/xxhash/v2"
"github.com/fsnotify/fsnotify"
"github.com/mitchellh/hashstructure"
"github.com/spf13/viper"
"gopkg.in/yaml.v2"
@ -477,7 +478,7 @@ type OptionsUpdater interface {
// Checksum returns the checksum of the current options struct
func (o *Options) Checksum() string {
hash, err := hashstructure.Hash(o, nil)
hash, err := hashstructure.Hash(o, &hashstructure.HashOptions{Hasher: xxhash.New()})
if err != nil {
log.Warn().Err(err).Msg("config: checksum failure")
return "no checksum available"

4
go.mod
View file

@ -6,9 +6,9 @@ require (
cloud.google.com/go v0.49.0 // indirect
contrib.go.opencensus.io/exporter/jaeger v0.2.0
contrib.go.opencensus.io/exporter/prometheus v0.1.0
github.com/cespare/xxhash/v2 v2.1.1 // indirect
github.com/cespare/xxhash/v2 v2.1.1
github.com/fsnotify/fsnotify v1.4.7
github.com/golang/groupcache v0.0.0-20191027212112-611e8accdfc9 // indirect
github.com/golang/groupcache v0.0.0-20191227052852-215e87163ea7
github.com/golang/mock v1.3.1
github.com/golang/protobuf v1.3.2
github.com/google/go-cmp v0.3.1

5
go.sum
View file

@ -71,6 +71,8 @@ github.com/golang/groupcache v0.0.0-20190702054246-869f871628b6 h1:ZgQEtGgCBiWRM
github.com/golang/groupcache v0.0.0-20190702054246-869f871628b6/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc=
github.com/golang/groupcache v0.0.0-20191027212112-611e8accdfc9 h1:uHTyIjqVhYRhLbJ8nIiOJHkEZZ+5YoOsAbD3sk82NiE=
github.com/golang/groupcache v0.0.0-20191027212112-611e8accdfc9/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc=
github.com/golang/groupcache v0.0.0-20191227052852-215e87163ea7 h1:5ZkaAPbicIKTF2I64qf5Fh8Aa83Q/dnOafMYV0OMwjA=
github.com/golang/groupcache v0.0.0-20191227052852-215e87163ea7/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc=
github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A=
github.com/golang/mock v1.2.0 h1:28o5sBqPkBsMGnC6b4MvE2TzSr5/AT4c/1fLqVGIwlk=
github.com/golang/mock v1.2.0/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A=
@ -162,6 +164,7 @@ github.com/prometheus/client_golang v0.9.3/go.mod h1:/TN21ttK/J9q6uSwhBd54HahCDf
github.com/prometheus/client_golang v1.0.0/go.mod h1:db9x61etRT2tGnBNRi70OPL5FsnadC4Ky3P0J6CfImo=
github.com/prometheus/client_golang v1.2.1 h1:JnMpQc6ppsNgw9QPAGF6Dod479itz7lvlsMzzNayLOI=
github.com/prometheus/client_golang v1.2.1/go.mod h1:XMU6Z2MjaRKVu/dC1qupJI9SiNkDYzz3xecMgSW/F+U=
github.com/prometheus/client_golang v1.3.0 h1:miYCvYqFXtl/J9FIy8eNpBfYthAEFg+Ys0XyUVEcDsc=
github.com/prometheus/client_model v0.0.0-20180712105110-5c3871d89910/go.mod h1:MbSGuTsp3dbXC40dX6PRTWyKYBIrTGTE9sqQNg2J8bo=
github.com/prometheus/client_model v0.0.0-20190129233127-fd36f4220a90 h1:S/YWwWx/RA8rT8tKFRuGUZhuA90OyIBpPCXkcbwU8DE=
github.com/prometheus/client_model v0.0.0-20190129233127-fd36f4220a90/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA=
@ -214,6 +217,7 @@ github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA=
github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
github.com/spf13/viper v1.5.0 h1:GpsTwfsQ27oS/Aha/6d1oD7tpKIqWnOA6tgOX9HHkt4=
github.com/spf13/viper v1.5.0/go.mod h1:AkYRkVJF8TkSG/xet6PzXX+l39KhhXa2pdqVSxnTcn4=
github.com/spf13/viper v1.6.1 h1:VPZzIkznI1YhVMRi6vNFLHSwhnhReBfgTxIPccpfdZk=
github.com/spf13/viper v1.6.1/go.mod h1:t3iDnF5Jlj76alVNuyFBk5oUMCvsrkbvZK0WQdfDi5k=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
@ -384,6 +388,7 @@ gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33q108Sa+fhmuc+sWQYwY=
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI=
gopkg.in/ini.v1 v1.51.0 h1:AQvPpx3LzTDM0AjnIRlVFwFFGC+npRopjZxLJj6gdno=
gopkg.in/ini.v1 v1.51.0/go.mod h1:pNLf8WUiyNEtQjuu5G5vTm06TEv9tsIgeAvK8hOrP4k=
gopkg.in/resty.v1 v1.12.0/go.mod h1:mDo4pnntr5jdWRML875a/NmxYqAlA73dVijT2AXvQQo=
gopkg.in/square/go-jose.v2 v2.4.0 h1:0kXPskUMGAXXWJlP05ktEMOV0vmzFQUWw6d+aZJQU8A=

View file

@ -1,5 +1,13 @@
package mock // import "github.com/pomerium/pomerium/internal/encoding/mock"
import (
"github.com/pomerium/pomerium/internal/encoding"
)
var _ encoding.MarshalUnmarshaler = &Encoder{}
var _ encoding.Marshaler = &Encoder{}
var _ encoding.Unmarshaler = &Encoder{}
// Encoder MockCSRFStore is a mock implementation of Cipher.
type Encoder struct {
MarshalResponse []byte

View file

@ -8,23 +8,21 @@ import (
"fmt"
"io"
"io/ioutil"
"net"
"net/http"
"net/url"
"time"
"go.opencensus.io/plugin/ochttp"
)
// ErrTokenRevoked signifies a token revokation or expiration error
var ErrTokenRevoked = errors.New("token expired or revoked")
var httpClient = &http.Client{
Timeout: time.Second * 5,
Transport: &http.Transport{
Dial: (&net.Dialer{
Timeout: 2 * time.Second,
}).Dial,
TLSHandshakeTimeout: 2 * time.Second,
},
// DefaultClient avoids leaks by setting an upper limit for timeouts.
var DefaultClient = &http.Client{
Timeout: 1 * time.Minute,
//todo(bdd): incorporate metrics.HTTPMetricsRoundTripper
Transport: &ochttp.Transport{},
}
// Client provides a simple helper interface to make HTTP requests
@ -36,9 +34,11 @@ func Client(ctx context.Context, method, endpoint, userAgent string, headers map
case http.MethodGet:
// error checking skipped because we are just parsing in
// order to make a copy of an existing URL
u, _ := url.Parse(endpoint)
u.RawQuery = params.Encode()
endpoint = u.String()
if params != nil {
u, _ := url.Parse(endpoint)
u.RawQuery = params.Encode()
endpoint = u.String()
}
default:
return fmt.Errorf(http.StatusText(http.StatusBadRequest))
}
@ -52,7 +52,7 @@ func Client(ctx context.Context, method, endpoint, userAgent string, headers map
req.Header.Set(k, v)
}
resp, err := httpClient.Do(req)
resp, err := DefaultClient.Do(req)
if err != nil {
return err
}
@ -79,7 +79,6 @@ func Client(ctx context.Context, method, endpoint, userAgent string, headers map
return fmt.Errorf(http.StatusText(resp.StatusCode))
}
}
if response != nil {
err := json.Unmarshal(respBody, &response)
if err != nil {

131
internal/sessions/cache/cache_store.go vendored Normal file
View file

@ -0,0 +1,131 @@
package cache // import "github.com/pomerium/pomerium/internal/sessions/cache"
import (
"context"
"errors"
"fmt"
"net/http"
"github.com/golang/groupcache"
"github.com/pomerium/pomerium/internal/encoding"
"github.com/pomerium/pomerium/internal/log"
"github.com/pomerium/pomerium/internal/sessions"
)
var _ sessions.SessionStore = &Store{}
var _ sessions.SessionLoader = &Store{}
const (
defaultQueryParamKey = "ati"
)
// Store implements the session store interface using a distributed cache.
type Store struct {
name string
encoder encoding.Marshaler
decoder encoding.Unmarshaler
cache *groupcache.Group
wrappedStore sessions.SessionStore
}
// defaultCacheSize is ~10MB
var defaultCacheSize int64 = 10 << 20
// NewStore creates a new session store built on the distributed caching library
// groupcache. On a cache miss, the cache store attempts to fallback to another
// SessionStore implementation.
func NewStore(enc encoding.MarshalUnmarshaler, wrappedStore sessions.SessionStore, name string) *Store {
store := &Store{
name: name,
encoder: enc,
decoder: enc,
wrappedStore: wrappedStore,
}
store.cache = groupcache.NewGroup(name, defaultCacheSize, groupcache.GetterFunc(
func(ctx context.Context, id string, dest groupcache.Sink) error {
// fill the cache with session set as part of the request
// context set previously as part of SaveSession.
b := fromContext(ctx)
if len(b) == 0 {
return fmt.Errorf("sessions/cache: cannot fill key %s from ctx", id)
}
if err := dest.SetBytes(b); err != nil {
return fmt.Errorf("sessions/cache: sink error %w", err)
}
return nil
},
))
return store
}
// LoadSession implements SessionLoaders's LoadSession method for cache store.
func (s *Store) LoadSession(r *http.Request) (*sessions.State, error) {
// look for our cache's key in the default query param
sessionID := r.URL.Query().Get(defaultQueryParamKey)
if sessionID == "" {
// if unset, fallback to default cache store
log.FromRequest(r).Debug().Msg("sessions/cache: no query param, trying wrapped loader")
return s.wrappedStore.LoadSession(r)
}
var b []byte
if err := s.cache.Get(r.Context(), sessionID, groupcache.AllocatingByteSliceSink(&b)); err != nil {
log.FromRequest(r).Debug().Err(err).Msg("sessions/cache: miss, trying wrapped loader")
return s.wrappedStore.LoadSession(r)
}
var session sessions.State
if err := s.decoder.Unmarshal(b, &session); err != nil {
log.FromRequest(r).Error().Err(err).Msg("sessions/cache: unmarshal")
return nil, sessions.ErrMalformed
}
return &session, nil
}
// ClearSession implements SessionStore's ClearSession for the cache store.
// Since group cache has no explicit eviction, we just call the wrapped
// store's ClearSession method here.
func (s *Store) ClearSession(w http.ResponseWriter, r *http.Request) {
s.wrappedStore.ClearSession(w, r)
}
// SaveSession implements SessionStore's SaveSession method for cache store.
func (s *Store) SaveSession(w http.ResponseWriter, r *http.Request, x interface{}) error {
err := s.wrappedStore.SaveSession(w, r, x)
if err != nil {
return fmt.Errorf("sessions/cache: wrapped store save error %w", err)
}
state, ok := x.(*sessions.State)
if !ok {
return errors.New("internal/sessions: cannot cache non state type")
}
data, err := s.encoder.Marshal(&state)
if err != nil {
return fmt.Errorf("sessions/cache: marshal %w", err)
}
ctx := newContext(r.Context(), data)
var b []byte
return s.cache.Get(ctx, state.AccessTokenID, groupcache.AllocatingByteSliceSink(&b))
}
var sessionCtxKey = &contextKey{"PomeriumCachedSessionBytes"}
type contextKey struct {
name string
}
func newContext(ctx context.Context, b []byte) context.Context {
ctx = context.WithValue(ctx, sessionCtxKey, b)
return ctx
}
func fromContext(ctx context.Context) []byte {
b, _ := ctx.Value(sessionCtxKey).([]byte)
return b
}

View file

@ -0,0 +1,133 @@
package cache
import (
"fmt"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/google/go-cmp/cmp"
"github.com/pomerium/pomerium/internal/cryptutil"
"github.com/pomerium/pomerium/internal/encoding/ecjson"
"github.com/pomerium/pomerium/internal/sessions"
"github.com/pomerium/pomerium/internal/sessions/cookie"
"gopkg.in/square/go-jose.v2/jwt"
)
func testAuthorizer(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, err := sessions.FromContext(r.Context())
if err != nil {
http.Error(w, err.Error(), http.StatusUnauthorized)
return
}
next.ServeHTTP(w, r)
})
}
func TestVerifier(t *testing.T) {
fnh := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/plain; charset=utf-8")
fmt.Fprint(w, http.StatusText(http.StatusOK))
w.WriteHeader(http.StatusOK)
})
tests := []struct {
name string
skipSave bool
cacheSize int64
state sessions.State
wantBody string
wantStatus int
}{
{"good", false, 1 << 10, sessions.State{AccessTokenID: cryptutil.NewBase64Key(), Email: "user@pomerium.io", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}, http.StatusText(http.StatusOK), http.StatusOK},
{"expired", false, 1 << 10, sessions.State{AccessTokenID: cryptutil.NewBase64Key(), Email: "user@pomerium.io", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}, "internal/sessions: validation failed, token is expired (exp)\n", http.StatusUnauthorized},
{"empty", false, 1 << 10, sessions.State{AccessTokenID: "", Email: "user@pomerium.io", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}, "internal/sessions: session is not found\n", http.StatusUnauthorized},
{"miss", true, 1 << 10, sessions.State{AccessTokenID: cryptutil.NewBase64Key(), Email: "user@pomerium.io", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}, "internal/sessions: session is not found\n", http.StatusUnauthorized},
{"cache eviction", false, 1, sessions.State{AccessTokenID: cryptutil.NewBase64Key(), Email: "user@pomerium.io", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}, "internal/sessions: session is not found\n", http.StatusUnauthorized},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
defaultCacheSize = tt.cacheSize
cipher, err := cryptutil.NewAEADCipherFromBase64(cryptutil.NewBase64Key())
encoder := ecjson.New(cipher)
if err != nil {
t.Fatal(err)
}
cs, err := cookie.NewStore(&cookie.Options{Name: t.Name()}, encoder)
if err != nil {
t.Fatal(err)
}
cacheStore := NewStore(encoder, cs, t.Name())
r := httptest.NewRequest(http.MethodGet, "/", nil)
q := r.URL.Query()
q.Set(defaultQueryParamKey, tt.state.AccessTokenID)
r.URL.RawQuery = q.Encode()
r.Header.Set("Accept", "application/json")
w := httptest.NewRecorder()
got := sessions.RetrieveSession(cacheStore)(testAuthorizer((fnh)))
if !tt.skipSave {
cacheStore.SaveSession(w, r, &tt.state)
}
for i := 1; i <= 10; i++ {
s := tt.state
s.AccessTokenID = cryptutil.NewBase64Key()
cacheStore.SaveSession(w, r, s)
}
got.ServeHTTP(w, r)
gotBody := w.Body.String()
gotStatus := w.Result().StatusCode
if diff := cmp.Diff(gotBody, tt.wantBody); diff != "" {
t.Errorf("RetrieveSession() = %v", diff)
}
if diff := cmp.Diff(gotStatus, tt.wantStatus); diff != "" {
t.Errorf("RetrieveSession() = %v", diff)
}
})
}
}
func TestStore_SaveSession(t *testing.T) {
tests := []struct {
name string
x interface{}
wantErr bool
}{
{"good", &sessions.State{AccessTokenID: cryptutil.NewBase64Key(), Email: "user@pomerium.io", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}, false},
{"bad type", "bad type!", true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
cipher, err := cryptutil.NewAEADCipherFromBase64(cryptutil.NewBase64Key())
encoder := ecjson.New(cipher)
if err != nil {
t.Fatal(err)
}
cs, err := cookie.NewStore(&cookie.Options{
Name: "_pomerium",
}, encoder)
if err != nil {
t.Fatal(err)
}
cacheStore := NewStore(encoder, cs, t.Name())
r := httptest.NewRequest(http.MethodGet, "/", nil)
r.Header.Set("Accept", "application/json")
w := httptest.NewRecorder()
if err := cacheStore.SaveSession(w, r, tt.x); (err != nil) != tt.wantErr {
t.Errorf("Store.SaveSession() error = %v, wantErr %v", err, tt.wantErr)
}
})
}
}

View file

@ -1,4 +1,4 @@
package sessions // import "github.com/pomerium/pomerium/internal/sessions"
package cookie // import "github.com/pomerium/pomerium/internal/sessions/cookie"
import (
"errors"
@ -8,8 +8,15 @@ import (
"time"
"github.com/pomerium/pomerium/internal/encoding"
"github.com/pomerium/pomerium/internal/sessions"
)
var _ sessions.SessionStore = &Store{}
var _ sessions.SessionLoader = &Store{}
// timeNow is time.Now but pulled out as a variable for tests.
var timeNow = time.Now
const (
// ChunkedCanaryByte is the byte value used as a canary prefix to distinguish if
// the cookie is multi-part or not. This constant *should not* be valid
@ -25,8 +32,8 @@ const (
MaxNumChunks = 5
)
// CookieStore implements the session store interface for session cookies.
type CookieStore struct {
// Store implements the session store interface for session cookies.
type Store struct {
Name string
Domain string
Expire time.Duration
@ -37,8 +44,8 @@ type CookieStore struct {
decoder encoding.Unmarshaler
}
// CookieOptions holds options for CookieStore
type CookieOptions struct {
// Options holds options for Store
type Options struct {
Name string
Domain string
Expire time.Duration
@ -46,8 +53,9 @@ type CookieOptions struct {
Secure bool
}
// NewCookieStore returns a new session with ciphers for each of the cookie secrets
func NewCookieStore(opts *CookieOptions, encoder encoding.MarshalUnmarshaler) (*CookieStore, error) {
// NewStore returns a new store that implements the SessionStore interface
// using http cookies.
func NewStore(opts *Options, encoder encoding.MarshalUnmarshaler) (sessions.SessionStore, error) {
cs, err := NewCookieLoader(opts, encoder)
if err != nil {
return nil, err
@ -56,12 +64,13 @@ func NewCookieStore(opts *CookieOptions, encoder encoding.MarshalUnmarshaler) (*
return cs, nil
}
// NewCookieLoader returns a new session with ciphers for each of the cookie secrets
func NewCookieLoader(opts *CookieOptions, dencoder encoding.Unmarshaler) (*CookieStore, error) {
// NewCookieLoader returns a new store that implements the SessionLoader
// interface using http cookies.
func NewCookieLoader(opts *Options, dencoder encoding.Unmarshaler) (*Store, error) {
if dencoder == nil {
return nil, fmt.Errorf("internal/sessions: dencoder cannot be nil")
}
cs, err := newCookieStore(opts)
cs, err := newStore(opts)
if err != nil {
return nil, err
}
@ -69,12 +78,12 @@ func NewCookieLoader(opts *CookieOptions, dencoder encoding.Unmarshaler) (*Cooki
return cs, nil
}
func newCookieStore(opts *CookieOptions) (*CookieStore, error) {
func newStore(opts *Options) (*Store, error) {
if opts.Name == "" {
return nil, fmt.Errorf("internal/sessions: cookie name cannot be empty")
}
return &CookieStore{
return &Store{
Name: opts.Name,
Secure: opts.Secure,
HTTPOnly: opts.HTTPOnly,
@ -83,7 +92,7 @@ func newCookieStore(opts *CookieOptions) (*CookieStore, error) {
}, nil
}
func (cs *CookieStore) makeCookie(value string) *http.Cookie {
func (cs *Store) makeCookie(value string) *http.Cookie {
return &http.Cookie{
Name: cs.Name,
Value: value,
@ -96,7 +105,7 @@ func (cs *CookieStore) makeCookie(value string) *http.Cookie {
}
// ClearSession clears the session cookie from a request
func (cs *CookieStore) ClearSession(w http.ResponseWriter, r *http.Request) {
func (cs *Store) ClearSession(w http.ResponseWriter, r *http.Request) {
c := cs.makeCookie("")
c.MaxAge = -1
c.Expires = timeNow().Add(-time.Hour)
@ -115,51 +124,51 @@ func getCookies(r *http.Request, name string) []*http.Cookie {
}
// LoadSession returns a State from the cookie in the request.
func (cs *CookieStore) LoadSession(r *http.Request) (*State, error) {
func (cs *Store) LoadSession(r *http.Request) (*sessions.State, error) {
cookies := getCookies(r, cs.Name)
if len(cookies) == 0 {
return nil, ErrNoSessionFound
return nil, sessions.ErrNoSessionFound
}
for _, cookie := range cookies {
data := loadChunkedCookie(r, cookie)
session := &State{}
session := &sessions.State{}
err := cs.decoder.Unmarshal([]byte(data), session)
if err == nil {
return session, nil
}
}
return nil, ErrMalformed
return nil, sessions.ErrMalformed
}
// SaveSession saves a session state to a request's cookie store.
func (cs *CookieStore) SaveSession(w http.ResponseWriter, _ *http.Request, x interface{}) error {
func (cs *Store) SaveSession(w http.ResponseWriter, _ *http.Request, x interface{}) error {
var value string
if cs.encoder != nil {
switch v := x.(type) {
case []byte:
value = string(v)
case string:
value = v
default:
if cs.encoder == nil {
return errors.New("internal/sessions: cannot save non-string type")
}
data, err := cs.encoder.Marshal(x)
if err != nil {
return err
}
value = string(data)
} else {
switch v := x.(type) {
case []byte:
value = string(v)
case string:
value = v
default:
return errors.New("internal/sessions: cannot save non-string type")
}
}
cs.setSessionCookie(w, value)
return nil
}
func (cs *CookieStore) setSessionCookie(w http.ResponseWriter, val string) {
func (cs *Store) setSessionCookie(w http.ResponseWriter, val string) {
cs.setCookie(w, cs.makeCookie(val))
}
func (cs *CookieStore) setCookie(w http.ResponseWriter, cookie *http.Cookie) {
func (cs *Store) setCookie(w http.ResponseWriter, cookie *http.Cookie) {
if len(cookie.String()) <= MaxChunkSize {
http.SetCookie(w, cookie)
return
@ -180,20 +189,26 @@ func (cs *CookieStore) setCookie(w http.ResponseWriter, cookie *http.Cookie) {
}
func loadChunkedCookie(r *http.Request, c *http.Cookie) string {
data := c.Value
// if the first byte is our canary byte, we need to handle the multipart bit
if []byte(c.Value)[0] == ChunkedCanaryByte {
var b strings.Builder
fmt.Fprintf(&b, "%s", data[1:])
for i := 1; i <= MaxNumChunks; i++ {
next, err := r.Cookie(fmt.Sprintf("%s_%d", c.Name, i))
if err != nil {
break // break if we can't find the next cookie
}
fmt.Fprintf(&b, "%s", next.Value)
}
data = b.String()
if len(c.Value) == 0 {
return ""
}
// if the first byte is our canary byte, we need to handle the multipart bit
if []byte(c.Value)[0] != ChunkedCanaryByte {
return c.Value
}
data := c.Value
var b strings.Builder
fmt.Fprintf(&b, "%s", data[1:])
for i := 1; i <= MaxNumChunks; i++ {
next, err := r.Cookie(fmt.Sprintf("%s_%d", c.Name, i))
if err != nil {
break // break if we can't find the next cookie
}
fmt.Fprintf(&b, "%s", next.Value)
}
data = b.String()
return data
}

View file

@ -1,4 +1,4 @@
package sessions // import "github.com/pomerium/pomerium/internal/sessions"
package cookie // import "github.com/pomerium/pomerium/internal/sessions/cookie"
import (
"crypto/rand"
@ -13,12 +13,13 @@ import (
"github.com/pomerium/pomerium/internal/encoding"
"github.com/pomerium/pomerium/internal/encoding/ecjson"
"github.com/pomerium/pomerium/internal/encoding/mock"
"github.com/pomerium/pomerium/internal/sessions"
"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
)
func TestNewCookieStore(t *testing.T) {
func TestNewStore(t *testing.T) {
cipher, err := cryptutil.NewAEADCipher(cryptutil.NewKey())
if err != nil {
t.Fatal(err)
@ -26,28 +27,28 @@ func TestNewCookieStore(t *testing.T) {
encoder := ecjson.New(cipher)
tests := []struct {
name string
opts *CookieOptions
opts *Options
encoder encoding.MarshalUnmarshaler
want *CookieStore
want sessions.SessionStore
wantErr bool
}{
{"good", &CookieOptions{Name: "_cookie", Secure: true, HTTPOnly: true, Domain: "pomerium.io", Expire: 10 * time.Second}, encoder, &CookieStore{Name: "_cookie", Secure: true, HTTPOnly: true, Domain: "pomerium.io", Expire: 10 * time.Second}, false},
{"missing name", &CookieOptions{Name: "", Secure: true, HTTPOnly: true, Domain: "pomerium.io", Expire: 10 * time.Second}, encoder, nil, true},
{"missing encoder", &CookieOptions{Name: "_cookie", Secure: true, HTTPOnly: true, Domain: "pomerium.io", Expire: 10 * time.Second}, nil, nil, true},
{"good", &Options{Name: "_cookie", Secure: true, HTTPOnly: true, Domain: "pomerium.io", Expire: 10 * time.Second}, encoder, &Store{Name: "_cookie", Secure: true, HTTPOnly: true, Domain: "pomerium.io", Expire: 10 * time.Second}, false},
{"missing name", &Options{Name: "", Secure: true, HTTPOnly: true, Domain: "pomerium.io", Expire: 10 * time.Second}, encoder, nil, true},
{"missing encoder", &Options{Name: "_cookie", Secure: true, HTTPOnly: true, Domain: "pomerium.io", Expire: 10 * time.Second}, nil, nil, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := NewCookieStore(tt.opts, tt.encoder)
got, err := NewStore(tt.opts, tt.encoder)
if (err != nil) != tt.wantErr {
t.Errorf("NewCookieStore() error = %v, wantErr %v", err, tt.wantErr)
t.Errorf("NewStore() error = %v, wantErr %v", err, tt.wantErr)
return
}
cmpOpts := []cmp.Option{
cmpopts.IgnoreUnexported(CookieStore{}),
cmpopts.IgnoreUnexported(Store{}),
}
if diff := cmp.Diff(got, tt.want, cmpOpts...); diff != "" {
t.Errorf("NewCookieStore() = %s", diff)
t.Errorf("NewStore() = %s", diff)
}
})
}
@ -60,14 +61,14 @@ func TestNewCookieLoader(t *testing.T) {
encoder := ecjson.New(cipher)
tests := []struct {
name string
opts *CookieOptions
opts *Options
encoder encoding.MarshalUnmarshaler
want *CookieStore
want *Store
wantErr bool
}{
{"good", &CookieOptions{Name: "_cookie", Secure: true, HTTPOnly: true, Domain: "pomerium.io", Expire: 10 * time.Second}, encoder, &CookieStore{Name: "_cookie", Secure: true, HTTPOnly: true, Domain: "pomerium.io", Expire: 10 * time.Second}, false},
{"missing name", &CookieOptions{Name: "", Secure: true, HTTPOnly: true, Domain: "pomerium.io", Expire: 10 * time.Second}, encoder, nil, true},
{"missing encoder", &CookieOptions{Name: "_cookie", Secure: true, HTTPOnly: true, Domain: "pomerium.io", Expire: 10 * time.Second}, nil, nil, true},
{"good", &Options{Name: "_cookie", Secure: true, HTTPOnly: true, Domain: "pomerium.io", Expire: 10 * time.Second}, encoder, &Store{Name: "_cookie", Secure: true, HTTPOnly: true, Domain: "pomerium.io", Expire: 10 * time.Second}, false},
{"missing name", &Options{Name: "", Secure: true, HTTPOnly: true, Domain: "pomerium.io", Expire: 10 * time.Second}, encoder, nil, true},
{"missing encoder", &Options{Name: "_cookie", Secure: true, HTTPOnly: true, Domain: "pomerium.io", Expire: 10 * time.Second}, nil, nil, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
@ -77,7 +78,7 @@ func TestNewCookieLoader(t *testing.T) {
return
}
cmpOpts := []cmp.Option{
cmpopts.IgnoreUnexported(CookieStore{}),
cmpopts.IgnoreUnexported(Store{}),
}
if diff := cmp.Diff(got, tt.want, cmpOpts...); diff != "" {
@ -87,7 +88,7 @@ func TestNewCookieLoader(t *testing.T) {
}
}
func TestCookieStore_SaveSession(t *testing.T) {
func TestStore_SaveSession(t *testing.T) {
c, err := cryptutil.NewAEADCipher(cryptutil.NewKey())
if err != nil {
t.Fatal(err)
@ -106,17 +107,17 @@ func TestCookieStore_SaveSession(t *testing.T) {
wantErr bool
wantLoadErr bool
}{
{"good", &State{Email: "user@domain.com", User: "user"}, ecjson.New(c), ecjson.New(c), false, false},
{"bad cipher", &State{Email: "user@domain.com", User: "user"}, nil, nil, true, true},
{"huge cookie", &State{Subject: fmt.Sprintf("%x", hugeString), Email: "user@domain.com", User: "user"}, ecjson.New(c), ecjson.New(c), false, false},
{"marshal error", &State{Email: "user@domain.com", User: "user"}, mock.Encoder{MarshalError: errors.New("error")}, ecjson.New(c), true, true},
{"nil encoder cannot save non string type", &State{Email: "user@domain.com", User: "user"}, nil, ecjson.New(c), true, true},
{"good", &sessions.State{Email: "user@domain.com", User: "user"}, ecjson.New(c), ecjson.New(c), false, false},
{"bad cipher", &sessions.State{Email: "user@domain.com", User: "user"}, nil, nil, true, true},
{"huge cookie", &sessions.State{Subject: fmt.Sprintf("%x", hugeString), Email: "user@domain.com", User: "user"}, ecjson.New(c), ecjson.New(c), false, false},
{"marshal error", &sessions.State{Email: "user@domain.com", User: "user"}, mock.Encoder{MarshalError: errors.New("error")}, ecjson.New(c), true, true},
{"nil encoder cannot save non string type", &sessions.State{Email: "user@domain.com", User: "user"}, nil, ecjson.New(c), true, true},
{"good marshal string directly", cryptutil.NewBase64Key(), nil, ecjson.New(c), false, true},
{"good marshal bytes directly", cryptutil.NewKey(), nil, ecjson.New(c), false, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
s := &CookieStore{
s := &Store{
Name: "_pomerium",
Secure: true,
HTTPOnly: true,
@ -130,7 +131,7 @@ func TestCookieStore_SaveSession(t *testing.T) {
w := httptest.NewRecorder()
if err := s.SaveSession(w, r, tt.State); (err != nil) != tt.wantErr {
t.Errorf("CookieStore.SaveSession() error = %v, wantErr %v", err, tt.wantErr)
t.Errorf("Store.SaveSession() error = %v, wantErr %v", err, tt.wantErr)
}
r = httptest.NewRequest("GET", "/", nil)
for _, cookie := range w.Result().Cookies() {
@ -143,11 +144,11 @@ func TestCookieStore_SaveSession(t *testing.T) {
return
}
cmpOpts := []cmp.Option{
cmpopts.IgnoreUnexported(State{}),
cmpopts.IgnoreUnexported(sessions.State{}),
}
if err == nil {
if diff := cmp.Diff(state, tt.State, cmpOpts...); diff != "" {
t.Errorf("CookieStore.LoadSession() got = %s", diff)
t.Errorf("Store.LoadSession() got = %s", diff)
}
}
w = httptest.NewRecorder()

View file

@ -0,0 +1,90 @@
package cookie // import "github.com/pomerium/pomerium/internal/sessions/cookie"
import (
"fmt"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
"github.com/pomerium/pomerium/internal/sessions"
"github.com/google/go-cmp/cmp"
"github.com/pomerium/pomerium/internal/cryptutil"
"github.com/pomerium/pomerium/internal/encoding/ecjson"
"gopkg.in/square/go-jose.v2/jwt"
)
func testAuthorizer(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, err := sessions.FromContext(r.Context())
if err != nil {
http.Error(w, err.Error(), http.StatusUnauthorized)
return
}
next.ServeHTTP(w, r)
})
}
func TestVerifier(t *testing.T) {
fnh := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/plain; charset=utf-8")
fmt.Fprint(w, http.StatusText(http.StatusOK))
w.WriteHeader(http.StatusOK)
})
tests := []struct {
name string
state sessions.State
wantBody string
wantStatus int
}{
{"good cookie session", sessions.State{Email: "user@pomerium.io", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}, http.StatusText(http.StatusOK), http.StatusOK},
{"expired cookie", sessions.State{Email: "user@pomerium.io", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}, "internal/sessions: validation failed, token is expired (exp)\n", http.StatusUnauthorized},
{"malformed cookie", sessions.State{Email: "user@pomerium.io", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}, "internal/sessions: session is malformed\n", http.StatusUnauthorized},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
cipher, err := cryptutil.NewAEADCipherFromBase64(cryptutil.NewBase64Key())
encoder := ecjson.New(cipher)
if err != nil {
t.Fatal(err)
}
encSession, err := encoder.Marshal(&tt.state)
if err != nil {
t.Fatal(err)
}
if strings.Contains(tt.name, "malformed") {
// add some garbage to the end of the string
encSession = append(encSession, cryptutil.NewKey()...)
}
cs, err := NewStore(&Options{
Name: "_pomerium",
}, encoder)
if err != nil {
t.Fatal(err)
}
r := httptest.NewRequest(http.MethodGet, "/", nil)
r.Header.Set("Accept", "application/json")
w := httptest.NewRecorder()
r.AddCookie(&http.Cookie{Name: "_pomerium", Value: string(encSession)})
got := sessions.RetrieveSession(cs)(testAuthorizer((fnh)))
got.ServeHTTP(w, r)
gotBody := w.Body.String()
gotStatus := w.Result().StatusCode
if diff := cmp.Diff(gotBody, tt.wantBody); diff != "" {
t.Errorf("RetrieveSession() = %v", diff)
}
if diff := cmp.Diff(gotStatus, tt.wantStatus); diff != "" {
t.Errorf("RetrieveSession() = %v", diff)
}
})
}
}

View file

@ -0,0 +1,28 @@
package sessions // import "github.com/pomerium/pomerium/internal/sessions"
import (
"errors"
)
var (
// ErrNoSessionFound is the error for when no session is found.
ErrNoSessionFound = errors.New("internal/sessions: session is not found")
// ErrMalformed is the error for when a session is found but is malformed.
ErrMalformed = errors.New("internal/sessions: session is malformed")
// ErrNotValidYet indicates that token is used before time indicated in nbf claim.
ErrNotValidYet = errors.New("internal/sessions: validation failed, token not valid yet (nbf)")
// ErrExpired indicates that token is used after expiry time indicated in exp claim.
ErrExpired = errors.New("internal/sessions: validation failed, token is expired (exp)")
// ErrExpiryRequired indicates that the token does not contain a valid expiry (exp) claim.
ErrExpiryRequired = errors.New("internal/sessions: validation failed, token expiry (exp) is required")
// ErrIssuedInTheFuture indicates that the iat field is in the future.
ErrIssuedInTheFuture = errors.New("internal/sessions: validation field, token issued in the future (iat)")
// ErrInvalidAudience indicated invalid aud claim.
ErrInvalidAudience = errors.New("internal/sessions: validation failed, invalid audience claim (aud)")
)

View file

@ -1,35 +1,38 @@
package sessions // import "github.com/pomerium/pomerium/internal/sessions"
package header // import "github.com/pomerium/pomerium/internal/sessions/header"
import (
"net/http"
"strings"
"github.com/pomerium/pomerium/internal/encoding"
"github.com/pomerium/pomerium/internal/sessions"
)
var _ sessions.SessionLoader = &Store{}
const (
defaultAuthHeader = "Authorization"
defaultAuthType = "Bearer"
)
// HeaderStore implements the load session store interface using http
// Store implements the load session store interface using http
// authorization headers.
type HeaderStore struct {
type Store struct {
authHeader string
authType string
encoder encoding.Unmarshaler
}
// NewHeaderStore returns a new header store for loading sessions from
// NewStore returns a new header store for loading sessions from
// authorization header as defined in as defined in rfc2617
//
// NOTA BENE: While most servers do not log Authorization headers by default,
// you should ensure no other services are logging or leaking your auth headers.
func NewHeaderStore(enc encoding.Unmarshaler, headerType string) *HeaderStore {
func NewStore(enc encoding.Unmarshaler, headerType string) *Store {
if headerType == "" {
headerType = defaultAuthType
}
return &HeaderStore{
return &Store{
authHeader: defaultAuthHeader,
authType: headerType,
encoder: enc,
@ -37,14 +40,14 @@ func NewHeaderStore(enc encoding.Unmarshaler, headerType string) *HeaderStore {
}
// LoadSession tries to retrieve the token string from the Authorization header.
func (as *HeaderStore) LoadSession(r *http.Request) (*State, error) {
func (as *Store) LoadSession(r *http.Request) (*sessions.State, error) {
cipherText := TokenFromHeader(r, as.authHeader, as.authType)
if cipherText == "" {
return nil, ErrNoSessionFound
return nil, sessions.ErrNoSessionFound
}
var session State
var session sessions.State
if err := as.encoder.Unmarshal([]byte(cipherText), &session); err != nil {
return nil, ErrMalformed
return nil, sessions.ErrMalformed
}
return &session, nil
}

View file

@ -0,0 +1,90 @@
package header // import "github.com/pomerium/pomerium/internal/sessions/header"
import (
"fmt"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
"github.com/pomerium/pomerium/internal/cryptutil"
"github.com/pomerium/pomerium/internal/encoding/ecjson"
"github.com/pomerium/pomerium/internal/sessions"
"github.com/google/go-cmp/cmp"
"gopkg.in/square/go-jose.v2/jwt"
)
func testAuthorizer(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, err := sessions.FromContext(r.Context())
if err != nil {
http.Error(w, err.Error(), http.StatusUnauthorized)
return
}
next.ServeHTTP(w, r)
})
}
func TestVerifier(t *testing.T) {
fnh := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/plain; charset=utf-8")
fmt.Fprint(w, http.StatusText(http.StatusOK))
w.WriteHeader(http.StatusOK)
})
tests := []struct {
name string
authType string
state sessions.State
wantBody string
wantStatus int
}{
{"good auth header session", "Bearer ", sessions.State{Email: "user@pomerium.io", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}, http.StatusText(http.StatusOK), http.StatusOK},
{"expired auth header", "Bearer ", sessions.State{Email: "user@pomerium.io", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}, "internal/sessions: validation failed, token is expired (exp)\n", http.StatusUnauthorized},
{"malformed auth header", "Bearer ", sessions.State{Email: "user@pomerium.io", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}, "internal/sessions: session is malformed\n", http.StatusUnauthorized},
{"empty auth header", "Bearer ", sessions.State{Email: "user@pomerium.io", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}, "internal/sessions: session is not found\n", http.StatusUnauthorized},
{"bad auth type", "bees ", sessions.State{Email: "user@pomerium.io", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}, "internal/sessions: session is not found\n", http.StatusUnauthorized},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
cipher, err := cryptutil.NewAEADCipherFromBase64(cryptutil.NewBase64Key())
encoder := ecjson.New(cipher)
if err != nil {
t.Fatal(err)
}
encSession, err := encoder.Marshal(&tt.state)
if err != nil {
t.Fatal(err)
}
if strings.Contains(tt.name, "malformed") {
// add some garbage to the end of the string
encSession = append(encSession, cryptutil.NewKey()...)
}
s := NewStore(encoder, "")
r := httptest.NewRequest(http.MethodGet, "/", nil)
r.Header.Set("Accept", "application/json")
w := httptest.NewRecorder()
if strings.Contains(tt.name, "empty") {
encSession = []byte("")
}
r.Header.Set("Authorization", tt.authType+string(encSession))
got := sessions.RetrieveSession(s)(testAuthorizer((fnh)))
got.ServeHTTP(w, r)
gotBody := w.Body.String()
gotStatus := w.Result().StatusCode
if diff := cmp.Diff(gotBody, tt.wantBody); diff != "" {
t.Errorf("RetrieveSession() = %v", diff)
}
if diff := cmp.Diff(gotStatus, tt.wantStatus); diff != "" {
t.Errorf("RetrieveSession() = %v", diff)
}
})
}
}

View file

@ -44,7 +44,7 @@ func retrieveFromRequest(r *http.Request, sessions ...SessionLoader) (*State, er
}
if state != nil {
err := state.Verify(urlutil.StripPort(r.Host))
return state, err // N.B.: state is _not_ nil_
return state, err // N.B.: state is _not_ nil
}
}

View file

@ -2,16 +2,14 @@ package sessions
import (
"context"
"errors"
"fmt"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
"github.com/google/go-cmp/cmp"
"github.com/pomerium/pomerium/internal/cryptutil"
"github.com/pomerium/pomerium/internal/encoding/ecjson"
"gopkg.in/square/go-jose.v2/jwt"
)
@ -39,103 +37,6 @@ func TestNewContext(t *testing.T) {
}
}
func testAuthorizer(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, err := FromContext(r.Context())
if err != nil {
http.Error(w, err.Error(), http.StatusUnauthorized)
return
}
next.ServeHTTP(w, r)
})
}
func TestVerifier(t *testing.T) {
fnh := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/plain; charset=utf-8")
fmt.Fprint(w, http.StatusText(http.StatusOK))
w.WriteHeader(http.StatusOK)
})
tests := []struct {
name string
// s SessionStore
state State
cookie bool
header bool
param bool
wantBody string
wantStatus int
}{
{"good cookie session", State{Email: "user@pomerium.io", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}, true, false, false, http.StatusText(http.StatusOK), http.StatusOK},
{"expired cookie", State{Email: "user@pomerium.io", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}, true, false, false, "internal/sessions: validation failed, token is expired (exp)\n", http.StatusUnauthorized},
{"malformed cookie", State{Email: "user@pomerium.io", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}, true, false, false, "internal/sessions: session is malformed\n", http.StatusUnauthorized},
{"good auth header session", State{Email: "user@pomerium.io", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}, false, true, false, http.StatusText(http.StatusOK), http.StatusOK},
{"expired auth header", State{Email: "user@pomerium.io", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}, false, true, false, "internal/sessions: validation failed, token is expired (exp)\n", http.StatusUnauthorized},
{"malformed auth header", State{Email: "user@pomerium.io", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}, false, true, false, "internal/sessions: session is malformed\n", http.StatusUnauthorized},
{"good auth query param session", State{Email: "user@pomerium.io", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}, false, true, true, http.StatusText(http.StatusOK), http.StatusOK},
{"expired auth query param", State{Email: "user@pomerium.io", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}, false, false, true, "internal/sessions: validation failed, token is expired (exp)\n", http.StatusUnauthorized},
{"malformed auth query param", State{Email: "user@pomerium.io", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}, false, false, true, "internal/sessions: session is malformed\n", http.StatusUnauthorized},
{"no session", State{Email: "user@pomerium.io", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}, false, false, false, "internal/sessions: session is not found\n", http.StatusUnauthorized},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
cipher, err := cryptutil.NewAEADCipherFromBase64(cryptutil.NewBase64Key())
encoder := ecjson.New(cipher)
if err != nil {
t.Fatal(err)
}
encSession, err := encoder.Marshal(&tt.state)
if err != nil {
t.Fatal(err)
}
if strings.Contains(tt.name, "malformed") {
// add some garbage to the end of the string
encSession = append(encSession, cryptutil.NewKey()...)
}
cs, err := NewCookieStore(&CookieOptions{
Name: "_pomerium",
}, encoder)
if err != nil {
t.Fatal(err)
}
as := NewHeaderStore(encoder, "")
qp := NewQueryParamStore(encoder, "")
r := httptest.NewRequest(http.MethodGet, "/", nil)
r.Header.Set("Accept", "application/json")
w := httptest.NewRecorder()
if tt.cookie {
r.AddCookie(&http.Cookie{Name: "_pomerium", Value: string(encSession)})
} else if tt.header {
r.Header.Set("Authorization", "Bearer "+string(encSession))
} else if tt.param {
q := r.URL.Query()
q.Set("pomerium_session", string(encSession))
r.URL.RawQuery = q.Encode()
}
got := RetrieveSession(cs, as, qp)(testAuthorizer((fnh)))
got.ServeHTTP(w, r)
gotBody := w.Body.String()
gotStatus := w.Result().StatusCode
if diff := cmp.Diff(gotBody, tt.wantBody); diff != "" {
t.Errorf("RetrieveSession() = %v", diff)
}
if diff := cmp.Diff(gotStatus, tt.wantStatus); diff != "" {
t.Errorf("RetrieveSession() = %v", diff)
}
})
}
}
func Test_contextKey_String(t *testing.T) {
tests := []struct {
name string
@ -155,3 +56,80 @@ func Test_contextKey_String(t *testing.T) {
})
}
}
func testAuthorizer(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, err := FromContext(r.Context())
if err != nil {
http.Error(w, err.Error(), http.StatusUnauthorized)
return
}
next.ServeHTTP(w, r)
})
}
var _ SessionStore = &store{}
// Store is a mock implementation of the SessionStore interface
type store struct {
ResponseSession string
Session *State
SaveError error
LoadError error
}
// ClearSession clears the ResponseSession
func (ms *store) ClearSession(http.ResponseWriter, *http.Request) {
ms.ResponseSession = ""
}
// LoadSession returns the session and a error
func (ms store) LoadSession(*http.Request) (*State, error) {
return ms.Session, ms.LoadError
}
// SaveSession returns a save error.
func (ms store) SaveSession(http.ResponseWriter, *http.Request, interface{}) error {
return ms.SaveError
}
func TestVerifier(t *testing.T) {
fnh := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/plain; charset=utf-8")
fmt.Fprint(w, http.StatusText(http.StatusOK))
w.WriteHeader(http.StatusOK)
})
tests := []struct {
name string
store store
state State
wantBody string
wantStatus int
}{
{"empty session", store{}, State{}, "internal/sessions: session is not found\n", 401},
{"simple good load", store{Session: &State{Subject: "hi", Expiry: jwt.NewNumericDate(time.Now().Add(time.Second))}}, State{}, "OK", 200},
{"empty session", store{LoadError: errors.New("err")}, State{}, "err\n", 401},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
r := httptest.NewRequest(http.MethodGet, "/", nil)
r.Header.Set("Accept", "application/json")
w := httptest.NewRecorder()
got := RetrieveSession(tt.store)(testAuthorizer((fnh)))
got.ServeHTTP(w, r)
gotBody := w.Body.String()
gotStatus := w.Result().StatusCode
if diff := cmp.Diff(gotBody, tt.wantBody); diff != "" {
t.Errorf("RetrieveSession() = %v", diff)
}
if diff := cmp.Diff(gotStatus, tt.wantStatus); diff != "" {
t.Errorf("RetrieveSession() = %v", diff)
}
})
}
}

View file

@ -0,0 +1,33 @@
package mock // import "github.com/pomerium/pomerium/internal/sessions/mock"
import (
"net/http"
"github.com/pomerium/pomerium/internal/sessions"
)
var _ sessions.SessionStore = &Store{}
var _ sessions.SessionLoader = &Store{}
// Store is a mock implementation of the SessionStore interface
type Store struct {
ResponseSession string
Session *sessions.State
SaveError error
LoadError error
}
// ClearSession clears the ResponseSession
func (ms *Store) ClearSession(http.ResponseWriter, *http.Request) {
ms.ResponseSession = ""
}
// LoadSession returns the session and a error
func (ms Store) LoadSession(*http.Request) (*sessions.State, error) {
return ms.Session, ms.LoadError
}
// SaveSession returns a save error.
func (ms Store) SaveSession(http.ResponseWriter, *http.Request, interface{}) error {
return ms.SaveError
}

View file

@ -1,26 +1,28 @@
package sessions // import "github.com/pomerium/pomerium/internal/sessions"
package mock // import "github.com/pomerium/pomerium/internal/sessions/mock"
import (
"reflect"
"testing"
"github.com/pomerium/pomerium/internal/sessions"
)
func TestMockSessionStore(t *testing.T) {
func TestStore(t *testing.T) {
tests := []struct {
name string
mockCSRF *MockSessionStore
saveSession *State
mockCSRF *Store
saveSession *sessions.State
wantLoadErr bool
wantSaveErr bool
}{
{"basic",
&MockSessionStore{
&Store{
ResponseSession: "test",
Session: &State{Subject: "0101"},
Session: &sessions.State{Subject: "0101"},
SaveError: nil,
LoadError: nil,
},
&State{Subject: "0101"},
&sessions.State{Subject: "0101"},
false,
false},
}

View file

@ -1,28 +0,0 @@
package sessions // import "github.com/pomerium/pomerium/internal/sessions"
import (
"net/http"
)
// MockSessionStore is a mock implementation of the SessionStore interface
type MockSessionStore struct {
ResponseSession string
Session *State
SaveError error
LoadError error
}
// ClearSession clears the ResponseSession
func (ms *MockSessionStore) ClearSession(http.ResponseWriter, *http.Request) {
ms.ResponseSession = ""
}
// LoadSession returns the session and a error
func (ms MockSessionStore) LoadSession(*http.Request) (*State, error) {
return ms.Session, ms.LoadError
}
// SaveSession returns a save error.
func (ms MockSessionStore) SaveSession(http.ResponseWriter, *http.Request, interface{}) error {
return ms.SaveError
}

View file

@ -0,0 +1,92 @@
package queryparam // import "github.com/pomerium/pomerium/internal/sessions/queryparam"
import (
"fmt"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
"github.com/pomerium/pomerium/internal/cryptutil"
"github.com/pomerium/pomerium/internal/encoding/ecjson"
"github.com/pomerium/pomerium/internal/sessions"
"github.com/google/go-cmp/cmp"
"gopkg.in/square/go-jose.v2/jwt"
)
func testAuthorizer(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, err := sessions.FromContext(r.Context())
if err != nil {
http.Error(w, err.Error(), http.StatusUnauthorized)
return
}
next.ServeHTTP(w, r)
})
}
func TestVerifier(t *testing.T) {
fnh := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/plain; charset=utf-8")
fmt.Fprint(w, http.StatusText(http.StatusOK))
w.WriteHeader(http.StatusOK)
})
tests := []struct {
name string
state sessions.State
wantBody string
wantStatus int
}{
{"good auth query param session", sessions.State{Email: "user@pomerium.io", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}, http.StatusText(http.StatusOK), http.StatusOK},
{"expired auth query param", sessions.State{Email: "user@pomerium.io", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}, "internal/sessions: validation failed, token is expired (exp)\n", http.StatusUnauthorized},
{"malformed auth query param", sessions.State{Email: "user@pomerium.io", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}, "internal/sessions: session is malformed\n", http.StatusUnauthorized},
{"empty auth query param", sessions.State{Email: "user@pomerium.io", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}, "internal/sessions: session is not found\n", http.StatusUnauthorized},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
cipher, err := cryptutil.NewAEADCipherFromBase64(cryptutil.NewBase64Key())
encoder := ecjson.New(cipher)
if err != nil {
t.Fatal(err)
}
encSession, err := encoder.Marshal(&tt.state)
if err != nil {
t.Fatal(err)
}
if strings.Contains(tt.name, "malformed") {
// add some garbage to the end of the string
encSession = append(encSession, cryptutil.NewKey()...)
}
s := NewStore(encoder, "")
r := httptest.NewRequest(http.MethodGet, "/", nil)
r.Header.Set("Accept", "application/json")
w := httptest.NewRecorder()
q := r.URL.Query()
if strings.Contains(tt.name, "empty") {
encSession = []byte("")
}
q.Set("pomerium_session", string(encSession))
r.URL.RawQuery = q.Encode()
got := sessions.RetrieveSession(s)(testAuthorizer((fnh)))
got.ServeHTTP(w, r)
gotBody := w.Body.String()
gotStatus := w.Result().StatusCode
if diff := cmp.Diff(gotBody, tt.wantBody); diff != "" {
t.Errorf("RetrieveSession() = %v", diff)
}
if diff := cmp.Diff(gotStatus, tt.wantStatus); diff != "" {
t.Errorf("RetrieveSession() = %v", diff)
}
})
}
}

View file

@ -1,33 +1,37 @@
package sessions // import "github.com/pomerium/pomerium/internal/sessions"
package queryparam // import "github.com/pomerium/pomerium/internal/sessions/queryparam"
import (
"net/http"
"github.com/pomerium/pomerium/internal/encoding"
"github.com/pomerium/pomerium/internal/sessions"
)
var _ sessions.SessionStore = &Store{}
var _ sessions.SessionLoader = &Store{}
const (
defaultQueryParamKey = "pomerium_session"
)
// QueryParamStore implements the load session store interface using http
// Store implements the load session store interface using http
// query strings / query parameters.
type QueryParamStore struct {
type Store struct {
queryParamKey string
encoder encoding.Marshaler
decoder encoding.Unmarshaler
}
// NewQueryParamStore returns a new query param store for loading sessions from
// NewStore returns a new query param store for loading sessions from
// query strings / query parameters.
//
// NOTA BENE: By default, most servers _DO_ log query params, the leaking or
// accidental logging of which should be considered a security issue.
func NewQueryParamStore(enc encoding.MarshalUnmarshaler, qp string) *QueryParamStore {
func NewStore(enc encoding.MarshalUnmarshaler, qp string) *Store {
if qp == "" {
qp = defaultQueryParamKey
}
return &QueryParamStore{
return &Store{
queryParamKey: qp,
encoder: enc,
decoder: enc,
@ -35,27 +39,27 @@ func NewQueryParamStore(enc encoding.MarshalUnmarshaler, qp string) *QueryParamS
}
// LoadSession tries to retrieve the token string from URL query parameters.
func (qp *QueryParamStore) LoadSession(r *http.Request) (*State, error) {
func (qp *Store) LoadSession(r *http.Request) (*sessions.State, error) {
cipherText := r.URL.Query().Get(qp.queryParamKey)
if cipherText == "" {
return nil, ErrNoSessionFound
return nil, sessions.ErrNoSessionFound
}
var session State
var session sessions.State
if err := qp.decoder.Unmarshal([]byte(cipherText), &session); err != nil {
return nil, ErrMalformed
return nil, sessions.ErrMalformed
}
return &session, nil
}
// ClearSession clears the session cookie from a request's query param key `pomerium_session`.
func (qp *QueryParamStore) ClearSession(w http.ResponseWriter, r *http.Request) {
func (qp *Store) ClearSession(w http.ResponseWriter, r *http.Request) {
params := r.URL.Query()
params.Del(qp.queryParamKey)
r.URL.RawQuery = params.Encode()
}
// SaveSession sets a session to a request's query param key `pomerium_session`
func (qp *QueryParamStore) SaveSession(w http.ResponseWriter, r *http.Request, x interface{}) error {
func (qp *Store) SaveSession(w http.ResponseWriter, r *http.Request, x interface{}) error {
data, err := qp.encoder.Marshal(x)
if err != nil {
return err

View file

@ -1,4 +1,4 @@
package sessions
package queryparam // import "github.com/pomerium/pomerium/internal/sessions/queryparam"
import (
"errors"
@ -9,39 +9,40 @@ import (
"github.com/google/go-cmp/cmp"
"github.com/pomerium/pomerium/internal/encoding"
"github.com/pomerium/pomerium/internal/encoding/mock"
"github.com/pomerium/pomerium/internal/sessions"
)
func TestNewQueryParamStore(t *testing.T) {
tests := []struct {
name string
State *State
State *sessions.State
enc encoding.MarshalUnmarshaler
qp string
wantErr bool
wantURL *url.URL
}{
{"simple good", &State{Email: "user@domain.com", User: "user"}, mock.Encoder{MarshalResponse: []byte("ok")}, "", false, &url.URL{Path: "/", RawQuery: "pomerium_session=ok"}},
{"marshall error", &State{Email: "user@domain.com", User: "user"}, mock.Encoder{MarshalError: errors.New("error")}, "", true, &url.URL{Path: "/"}},
{"simple good", &sessions.State{Email: "user@domain.com", User: "user"}, mock.Encoder{MarshalResponse: []byte("ok")}, "", false, &url.URL{Path: "/", RawQuery: "pomerium_session=ok"}},
{"marshall error", &sessions.State{Email: "user@domain.com", User: "user"}, mock.Encoder{MarshalError: errors.New("error")}, "", true, &url.URL{Path: "/"}},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := NewQueryParamStore(tt.enc, tt.qp)
got := NewStore(tt.enc, tt.qp)
r := httptest.NewRequest("GET", "/", nil)
w := httptest.NewRecorder()
if err := got.SaveSession(w, r, tt.State); (err != nil) != tt.wantErr {
t.Errorf("NewQueryParamStore.SaveSession() error = %v, wantErr %v", err, tt.wantErr)
t.Errorf("NewStore.SaveSession() error = %v, wantErr %v", err, tt.wantErr)
}
if diff := cmp.Diff(r.URL, tt.wantURL); diff != "" {
t.Errorf("NewQueryParamStore() = %v", diff)
t.Errorf("NewStore() = %v", diff)
}
got.ClearSession(w, r)
if diff := cmp.Diff(r.URL, &url.URL{Path: "/"}); diff != "" {
t.Errorf("NewQueryParamStore() = %v", diff)
t.Errorf("NewStore() = %v", diff)
}
})
}

View file

@ -6,6 +6,8 @@ import (
"strings"
"time"
"github.com/cespare/xxhash/v2"
"github.com/mitchellh/hashstructure"
oidc "github.com/pomerium/go-oidc"
"golang.org/x/oauth2"
"gopkg.in/square/go-jose.v2/jwt"
@ -51,7 +53,8 @@ type State struct {
// programatic access.
Programmatic bool `json:"programatic"`
AccessToken *oauth2.Token `json:"access_token,omitempty"`
AccessToken *oauth2.Token `json:"act,omitempty"`
AccessTokenID string `json:"ati,omitempty"`
idToken *oidc.IDToken
}
@ -73,7 +76,7 @@ func NewStateFromTokens(idToken *oidc.IDToken, accessToken *oauth2.Token, audien
s.Audience = []string{audience}
s.idToken = idToken
s.AccessToken = accessToken
s.AccessTokenID = s.accessTokenHash()
return s, nil
}
@ -95,6 +98,7 @@ func (s *State) UpdateState(idToken *oidc.IDToken, accessToken *oauth2.Token) er
}
s.Audience = audience
s.Expiry = jwt.NewNumericDate(accessToken.Expiry)
s.AccessTokenID = s.accessTokenHash()
return nil
}
@ -173,3 +177,13 @@ func (s *State) SetImpersonation(email, groups string) {
s.ImpersonateGroups = strings.Split(groups, ",")
}
}
func (s *State) accessTokenHash() string {
hash, err := hashstructure.Hash(
s.AccessToken,
&hashstructure.HashOptions{Hasher: xxhash.New()})
if err != nil {
return ""
}
return fmt.Sprintf("%x", hash)
}

View file

@ -124,3 +124,26 @@ func TestState_RouteSession(t *testing.T) {
})
}
}
func TestState_accessTokenHash(t *testing.T) {
t.Parallel()
tests := []struct {
name string
state State
want string
}{
{"empty access token", State{}, "34c96acdcadb1bbb"},
{"no change to access token", State{Subject: "test"}, "34c96acdcadb1bbb"},
{"empty oauth2 token", State{AccessToken: &oauth2.Token{}}, "bbd82197d215198f"},
{"refresh token a", State{AccessToken: &oauth2.Token{RefreshToken: "a"}}, "76316ac79b301bd6"},
{"refresh token b", State{AccessToken: &oauth2.Token{RefreshToken: "b"}}, "fab7cb29e50161f1"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
s := &tt.state
if got := s.accessTokenHash(); got != tt.want {
t.Errorf("State.accessTokenHash() = %v, want %v", got, tt.want)
}
})
}
}

View file

@ -1,39 +1,17 @@
package sessions // import "github.com/pomerium/pomerium/internal/sessions"
import (
"errors"
"net/http"
)
var (
// ErrNoSessionFound is the error for when no session is found.
ErrNoSessionFound = errors.New("internal/sessions: session is not found")
// ErrMalformed is the error for when a session is found but is malformed.
ErrMalformed = errors.New("internal/sessions: session is malformed")
// ErrNotValidYet indicates that token is used before time indicated in nbf claim.
ErrNotValidYet = errors.New("internal/sessions: validation failed, token not valid yet (nbf)")
// ErrExpired indicates that token is used after expiry time indicated in exp claim.
ErrExpired = errors.New("internal/sessions: validation failed, token is expired (exp)")
// ErrIssuedInTheFuture indicates that the iat field is in the future.
ErrIssuedInTheFuture = errors.New("internal/sessions: validation field, token issued in the future (iat)")
// ErrInvalidAudience indicated invalid aud claim.
ErrInvalidAudience = errors.New("internal/sessions: validation failed, invalid audience claim (aud)")
)
// SessionStore has the functions for setting, getting, and clearing the Session cookie
// SessionStore defines an interface for loading, saving, and clearing a session.
type SessionStore interface {
ClearSession(http.ResponseWriter, *http.Request)
SessionLoader
ClearSession(http.ResponseWriter, *http.Request)
SaveSession(http.ResponseWriter, *http.Request, interface{}) error
}
// SessionLoader is implemented by any struct that loads a pomerium session
// given a request, and returns a user state.
// SessionLoader defines an interface for loading a session.
type SessionLoader interface {
LoadSession(*http.Request) (*State, error)
}

View file

@ -34,10 +34,10 @@ var (
// DefaultViews are a set of default views to view HTTP and GRPC metrics.
var (
DefaultViews = [][]*view.View{
GRPCServerViews,
HTTPServerViews,
GRPCClientViews,
GRPCServerViews,
HTTPClientViews,
HTTPServerViews,
InfoViews,
}
)

View file

@ -9,14 +9,16 @@ import (
"time"
"github.com/google/go-cmp/cmp"
"gopkg.in/square/go-jose.v2/jwt"
"github.com/pomerium/pomerium/config"
"github.com/pomerium/pomerium/internal/encoding"
"github.com/pomerium/pomerium/internal/encoding/mock"
"github.com/pomerium/pomerium/internal/httputil"
"github.com/pomerium/pomerium/internal/sessions"
mstore "github.com/pomerium/pomerium/internal/sessions/mock"
"github.com/pomerium/pomerium/internal/urlutil"
"github.com/pomerium/pomerium/proxy/clients"
"gopkg.in/square/go-jose.v2/jwt"
)
func TestProxy_ForwardAuth(t *testing.T) {
@ -40,29 +42,29 @@ func TestProxy_ForwardAuth(t *testing.T) {
wantStatus int
wantBody string
}{
{"good redirect not required", opts, nil, http.MethodGet, nil, nil, "https://some.domain.example/", "https://some.domain.example", &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusOK, "Access to some.domain.example is allowed."},
{"good verify only, no redirect", opts, nil, http.MethodGet, nil, nil, "https://some.domain.example/verify", "https://some.domain.example", &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusOK, ""},
{"good redirect not required", opts, nil, http.MethodGet, nil, nil, "/", "https://some.domain.example", &mock.Encoder{}, &sessions.MockSessionStore{LoadError: sessions.ErrInvalidAudience}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusUnauthorized, "{\"Status\":401,\"Error\":\"Unauthorized: internal/sessions: validation failed, invalid audience claim (aud)\"}\n"},
{"bad naked domain uri", opts, nil, http.MethodGet, nil, nil, "https://some.domain.example/", "a.naked.domain", &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusBadRequest, "{\"Status\":400,\"Error\":\"Bad Request: a.naked.domain url does contain a valid scheme\"}\n"},
{"bad naked domain uri verify only", opts, nil, http.MethodGet, nil, nil, "https://some.domain.example/verify", "a.naked.domain", &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusBadRequest, "{\"Status\":400,\"Error\":\"Bad Request: a.naked.domain url does contain a valid scheme\"}\n"},
{"bad empty verification uri", opts, nil, http.MethodGet, nil, nil, "https://some.domain.example/", " ", &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusBadRequest, "{\"Status\":400,\"Error\":\"Bad Request: %20 url does contain a valid scheme\"}\n"},
{"bad empty verification uri verify only", opts, nil, http.MethodGet, nil, nil, "https://some.domain.example/verify", " ", &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusBadRequest, "{\"Status\":400,\"Error\":\"Bad Request: %20 url does contain a valid scheme\"}\n"},
{"not authorized", opts, nil, http.MethodGet, nil, nil, "https://some.domain.example/", "https://some.domain.example", &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: false}, http.StatusUnauthorized, "{\"Status\":401,\"Error\":\"Unauthorized: user@test.example is not authorized for some.domain.example\"}\n"},
{"not authorized verify endpoint", opts, nil, http.MethodGet, nil, nil, "https://some.domain.example/verify", "https://some.domain.example", &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: false}, http.StatusUnauthorized, "{\"Status\":401,\"Error\":\"Unauthorized: user@test.example is not authorized for some.domain.example\"}\n"},
{"not authorized expired, redirect to auth", opts, sessions.ErrExpired, http.MethodGet, nil, nil, "https://some.domain.example/", "https://some.domain.example", &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: false}, http.StatusFound, ""},
{"not authorized expired, don't redirect!", opts, sessions.ErrExpired, http.MethodGet, nil, nil, "https://some.domain.example/verify", "https://some.domain.example", &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: false}, http.StatusUnauthorized, "{\"Status\":401,\"Error\":\"Unauthorized: internal/sessions: validation failed, token is expired (exp)\"}\n"},
{"not authorized because of error", opts, nil, http.MethodGet, nil, nil, "https://some.domain.example/", "https://some.domain.example", &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeError: errors.New("authz error")}, http.StatusInternalServerError, "{\"Status\":500,\"Error\":\"Internal Server Error: authz error\"}\n"},
{"not authorized expired, do not redirect to auth", opts, nil, http.MethodGet, nil, nil, "https://some.domain.example/verify", "https://some.domain.example", &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: false}, http.StatusUnauthorized, "{\"Status\":401,\"Error\":\"Unauthorized: internal/sessions: validation failed, token is expired (exp)\"}\n"},
{"not authorized, bad audience request uri", opts, nil, http.MethodGet, nil, nil, "https://some.domain.example/", "https://some.domain.example", &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Audience: []string{"not.domain.example"}, Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusUnauthorized, "{\"Status\":401,\"Error\":\"Unauthorized: internal/sessions: validation failed, invalid audience claim (aud)\"}\n"},
{"not authorized, bad audience verify uri", opts, nil, http.MethodGet, nil, nil, "https://some.domain.example/", "https://fwdauth.domain.example", &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Audience: []string{"some.domain.example"}, Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusUnauthorized, "{\"Status\":401,\"Error\":\"Unauthorized: internal/sessions: validation failed, invalid audience claim (aud)\"}\n"},
{"good redirect not required", opts, nil, http.MethodGet, nil, nil, "https://some.domain.example/", "https://some.domain.example", &mock.Encoder{}, &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusOK, "Access to some.domain.example is allowed."},
{"good verify only, no redirect", opts, nil, http.MethodGet, nil, nil, "https://some.domain.example/verify", "https://some.domain.example", &mock.Encoder{}, &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusOK, ""},
{"good redirect not required", opts, nil, http.MethodGet, nil, nil, "/", "https://some.domain.example", &mock.Encoder{}, &mstore.Store{LoadError: sessions.ErrInvalidAudience}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusUnauthorized, "{\"Status\":401,\"Error\":\"Unauthorized: internal/sessions: validation failed, invalid audience claim (aud)\"}\n"},
{"bad naked domain uri", opts, nil, http.MethodGet, nil, nil, "https://some.domain.example/", "a.naked.domain", &mock.Encoder{}, &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusBadRequest, "{\"Status\":400,\"Error\":\"Bad Request: a.naked.domain url does contain a valid scheme\"}\n"},
{"bad naked domain uri verify only", opts, nil, http.MethodGet, nil, nil, "https://some.domain.example/verify", "a.naked.domain", &mock.Encoder{}, &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusBadRequest, "{\"Status\":400,\"Error\":\"Bad Request: a.naked.domain url does contain a valid scheme\"}\n"},
{"bad empty verification uri", opts, nil, http.MethodGet, nil, nil, "https://some.domain.example/", " ", &mock.Encoder{}, &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusBadRequest, "{\"Status\":400,\"Error\":\"Bad Request: %20 url does contain a valid scheme\"}\n"},
{"bad empty verification uri verify only", opts, nil, http.MethodGet, nil, nil, "https://some.domain.example/verify", " ", &mock.Encoder{}, &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusBadRequest, "{\"Status\":400,\"Error\":\"Bad Request: %20 url does contain a valid scheme\"}\n"},
{"not authorized", opts, nil, http.MethodGet, nil, nil, "https://some.domain.example/", "https://some.domain.example", &mock.Encoder{}, &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: false}, http.StatusUnauthorized, "{\"Status\":401,\"Error\":\"Unauthorized: user@test.example is not authorized for some.domain.example\"}\n"},
{"not authorized verify endpoint", opts, nil, http.MethodGet, nil, nil, "https://some.domain.example/verify", "https://some.domain.example", &mock.Encoder{}, &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: false}, http.StatusUnauthorized, "{\"Status\":401,\"Error\":\"Unauthorized: user@test.example is not authorized for some.domain.example\"}\n"},
{"not authorized expired, redirect to auth", opts, sessions.ErrExpired, http.MethodGet, nil, nil, "https://some.domain.example/", "https://some.domain.example", &mock.Encoder{}, &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: false}, http.StatusFound, ""},
{"not authorized expired, don't redirect!", opts, sessions.ErrExpired, http.MethodGet, nil, nil, "https://some.domain.example/verify", "https://some.domain.example", &mock.Encoder{}, &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: false}, http.StatusUnauthorized, "{\"Status\":401,\"Error\":\"Unauthorized: internal/sessions: validation failed, token is expired (exp)\"}\n"},
{"not authorized because of error", opts, nil, http.MethodGet, nil, nil, "https://some.domain.example/", "https://some.domain.example", &mock.Encoder{}, &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeError: errors.New("authz error")}, http.StatusInternalServerError, "{\"Status\":500,\"Error\":\"Internal Server Error: authz error\"}\n"},
{"not authorized expired, do not redirect to auth", opts, nil, http.MethodGet, nil, nil, "https://some.domain.example/verify", "https://some.domain.example", &mock.Encoder{}, &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: false}, http.StatusUnauthorized, "{\"Status\":401,\"Error\":\"Unauthorized: internal/sessions: validation failed, token is expired (exp)\"}\n"},
{"not authorized, bad audience request uri", opts, nil, http.MethodGet, nil, nil, "https://some.domain.example/", "https://some.domain.example", &mock.Encoder{}, &mstore.Store{Session: &sessions.State{Audience: []string{"not.domain.example"}, Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusUnauthorized, "{\"Status\":401,\"Error\":\"Unauthorized: internal/sessions: validation failed, invalid audience claim (aud)\"}\n"},
{"not authorized, bad audience verify uri", opts, nil, http.MethodGet, nil, nil, "https://some.domain.example/", "https://fwdauth.domain.example", &mock.Encoder{}, &mstore.Store{Session: &sessions.State{Audience: []string{"some.domain.example"}, Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusUnauthorized, "{\"Status\":401,\"Error\":\"Unauthorized: internal/sessions: validation failed, invalid audience claim (aud)\"}\n"},
// traefik
{"good traefik callback", opts, nil, http.MethodGet, map[string]string{httputil.HeaderForwardedURI: "https://some.domain.example?" + urlutil.QuerySessionEncrypted + "=" + goodEncryptionString}, nil, "https://some.domain.example/", "https://some.domain.example", &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusFound, ""},
{"bad traefik callback bad session", opts, nil, http.MethodGet, map[string]string{httputil.HeaderForwardedURI: "https://some.domain.example?" + urlutil.QuerySessionEncrypted + "=" + goodEncryptionString + "garbage"}, nil, "https://some.domain.example/", "https://some.domain.example", &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusBadRequest, ""},
{"bad traefik callback bad url", opts, nil, http.MethodGet, map[string]string{httputil.HeaderForwardedURI: urlutil.QuerySessionEncrypted + ""}, nil, "https://some.domain.example/", "https://some.domain.example", &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusBadRequest, ""},
{"good traefik callback", opts, nil, http.MethodGet, map[string]string{httputil.HeaderForwardedURI: "https://some.domain.example?" + urlutil.QuerySessionEncrypted + "=" + goodEncryptionString}, nil, "https://some.domain.example/", "https://some.domain.example", &mock.Encoder{}, &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusFound, ""},
{"bad traefik callback bad session", opts, nil, http.MethodGet, map[string]string{httputil.HeaderForwardedURI: "https://some.domain.example?" + urlutil.QuerySessionEncrypted + "=" + goodEncryptionString + "garbage"}, nil, "https://some.domain.example/", "https://some.domain.example", &mock.Encoder{}, &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusBadRequest, ""},
{"bad traefik callback bad url", opts, nil, http.MethodGet, map[string]string{httputil.HeaderForwardedURI: urlutil.QuerySessionEncrypted + ""}, nil, "https://some.domain.example/", "https://some.domain.example", &mock.Encoder{}, &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusBadRequest, ""},
// nginx
{"good nginx callback redirect", opts, nil, http.MethodGet, nil, map[string]string{urlutil.QueryRedirectURI: "https://some.domain.example/", urlutil.QuerySessionEncrypted: goodEncryptionString}, "https://some.domain.example/", "https://some.domain.example", &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusFound, ""},
{"good nginx callback set session okay but return unauthorized", opts, nil, http.MethodGet, nil, map[string]string{urlutil.QueryRedirectURI: "https://some.domain.example/", urlutil.QuerySessionEncrypted: goodEncryptionString}, "https://some.domain.example/verify", "https://some.domain.example", &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusUnauthorized, ""},
{"bad nginx callback failed to set sesion", opts, nil, http.MethodGet, nil, map[string]string{urlutil.QueryRedirectURI: "https://some.domain.example/", urlutil.QuerySessionEncrypted: goodEncryptionString + "nope"}, "https://some.domain.example/verify", "https://some.domain.example", &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusBadRequest, ""},
{"good nginx callback redirect", opts, nil, http.MethodGet, nil, map[string]string{urlutil.QueryRedirectURI: "https://some.domain.example/", urlutil.QuerySessionEncrypted: goodEncryptionString}, "https://some.domain.example/", "https://some.domain.example", &mock.Encoder{}, &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusFound, ""},
{"good nginx callback set session okay but return unauthorized", opts, nil, http.MethodGet, nil, map[string]string{urlutil.QueryRedirectURI: "https://some.domain.example/", urlutil.QuerySessionEncrypted: goodEncryptionString}, "https://some.domain.example/verify", "https://some.domain.example", &mock.Encoder{}, &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusUnauthorized, ""},
{"bad nginx callback failed to set sesion", opts, nil, http.MethodGet, nil, map[string]string{urlutil.QueryRedirectURI: "https://some.domain.example/", urlutil.QuerySessionEncrypted: goodEncryptionString + "nope"}, "https://some.domain.example/verify", "https://some.domain.example", &mock.Encoder{}, &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusBadRequest, ""},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {

View file

@ -12,6 +12,7 @@ import (
"time"
"github.com/pomerium/pomerium/internal/cryptutil"
mstore "github.com/pomerium/pomerium/internal/sessions/mock"
"github.com/pomerium/pomerium/config"
"github.com/pomerium/pomerium/internal/encoding"
@ -78,10 +79,10 @@ func TestProxy_UserDashboard(t *testing.T) {
wantAdminForm bool
wantStatus int
}{
{"good", nil, opts, http.MethodGet, &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{}, false, http.StatusOK},
{"session context error", errors.New("error"), opts, http.MethodGet, &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{}, false, http.StatusInternalServerError},
{"want admin form good admin authorization", nil, opts, http.MethodGet, &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{IsAdminResponse: true}, true, http.StatusOK},
{"is admin but authorization fails", nil, opts, http.MethodGet, &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{IsAdminError: errors.New("err")}, false, http.StatusInternalServerError},
{"good", nil, opts, http.MethodGet, &mock.Encoder{}, &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{}, false, http.StatusOK},
{"session context error", errors.New("error"), opts, http.MethodGet, &mock.Encoder{}, &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{}, false, http.StatusInternalServerError},
{"want admin form good admin authorization", nil, opts, http.MethodGet, &mock.Encoder{}, &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{IsAdminResponse: true}, true, http.StatusOK},
{"is admin but authorization fails", nil, opts, http.MethodGet, &mock.Encoder{}, &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{IsAdminError: errors.New("err")}, false, http.StatusInternalServerError},
}
for _, tt := range tests {
@ -135,12 +136,12 @@ func TestProxy_Impersonate(t *testing.T) {
authorizer clients.Authorizer
wantStatus int
}{
{"good", false, opts, nil, http.MethodPost, "user@blah.com", "", "", &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example"}}, clients.MockAuthorize{IsAdminResponse: true}, http.StatusFound},
{"good", false, opts, errors.New("error"), http.MethodPost, "user@blah.com", "", "", &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example"}}, clients.MockAuthorize{IsAdminResponse: true}, http.StatusInternalServerError},
{"session load error", false, opts, nil, http.MethodPost, "user@blah.com", "", "", &mock.Encoder{}, &sessions.MockSessionStore{LoadError: errors.New("err"), Session: &sessions.State{Email: "user@test.example"}}, clients.MockAuthorize{IsAdminResponse: true}, http.StatusFound},
{"non admin users rejected", false, opts, nil, http.MethodPost, "user@blah.com", "", "", &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute)), Email: "user@test.example"}}, clients.MockAuthorize{IsAdminResponse: false}, http.StatusForbidden},
{"non admin users rejected on error", false, opts, nil, http.MethodPost, "user@blah.com", "", "", &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute)), Email: "user@test.example"}}, clients.MockAuthorize{IsAdminResponse: true, IsAdminError: errors.New("err")}, http.StatusInternalServerError},
{"groups", false, opts, nil, http.MethodPost, "user@blah.com", "group1,group2", "", &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute)), Email: "user@test.example"}}, clients.MockAuthorize{IsAdminResponse: true}, http.StatusFound},
{"good", false, opts, nil, http.MethodPost, "user@blah.com", "", "", &mock.Encoder{}, &mstore.Store{Session: &sessions.State{Email: "user@test.example"}}, clients.MockAuthorize{IsAdminResponse: true}, http.StatusFound},
{"good", false, opts, errors.New("error"), http.MethodPost, "user@blah.com", "", "", &mock.Encoder{}, &mstore.Store{Session: &sessions.State{Email: "user@test.example"}}, clients.MockAuthorize{IsAdminResponse: true}, http.StatusInternalServerError},
{"session load error", false, opts, nil, http.MethodPost, "user@blah.com", "", "", &mock.Encoder{}, &mstore.Store{LoadError: errors.New("err"), Session: &sessions.State{Email: "user@test.example"}}, clients.MockAuthorize{IsAdminResponse: true}, http.StatusFound},
{"non admin users rejected", false, opts, nil, http.MethodPost, "user@blah.com", "", "", &mock.Encoder{}, &mstore.Store{Session: &sessions.State{Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute)), Email: "user@test.example"}}, clients.MockAuthorize{IsAdminResponse: false}, http.StatusForbidden},
{"non admin users rejected on error", false, opts, nil, http.MethodPost, "user@blah.com", "", "", &mock.Encoder{}, &mstore.Store{Session: &sessions.State{Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute)), Email: "user@test.example"}}, clients.MockAuthorize{IsAdminResponse: true, IsAdminError: errors.New("err")}, http.StatusInternalServerError},
{"groups", false, opts, nil, http.MethodPost, "user@blah.com", "group1,group2", "", &mock.Encoder{}, &mstore.Store{Session: &sessions.State{Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute)), Email: "user@test.example"}}, clients.MockAuthorize{IsAdminResponse: true}, http.StatusFound},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
@ -245,12 +246,12 @@ func TestProxy_Callback(t *testing.T) {
wantStatus int
wantBody string
}{
{"good", opts, http.MethodGet, "http", "example.com", "/", nil, map[string]string{urlutil.QueryCallbackURI: "ok", urlutil.QuerySessionEncrypted: goodEncryptionString}, &mock.Encoder{MarshalResponse: []byte("x")}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusFound, ""},
{"good programmatic", opts, http.MethodGet, "http", "example.com", "/", nil, map[string]string{urlutil.QueryIsProgrammatic: "true", urlutil.QueryCallbackURI: "ok", urlutil.QuerySessionEncrypted: goodEncryptionString}, &mock.Encoder{MarshalResponse: []byte("x")}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusFound, ""},
{"bad decrypt", opts, http.MethodGet, "http", "example.com", "/", nil, map[string]string{urlutil.QuerySessionEncrypted: "KBEjQ9rnCxaAX-GOqexGw9ivEQURqts3zZ2mNGy0wnVa3SbtM399KlBq2nZ-9wM21FfsZX52er4jlmC7kPEKM3P7uZ41zR0zeys1-_74a5tQp-vsf1WXZfRsgVOuBcWPkMiWEoc379JFHxGDudp5VhU8B-dcQt4f3_PtLTHARkuH54io1Va2gNMq4Hiy8sQ1MPGCQeltH_JMzzdDpXdmdusWrXUvCGkba24muvAV06D8XRVJj6Iu9eK94qFnqcHc7wzziEbb8ADBues9dwbtb6jl8vMWz5rN6XvXqA5YpZv_MQZlsrO4oXFFQDevdgB84cX1tVbVu6qZvK_yQBZqzpOjWA9uIaoSENMytoXuWAlFO_sXjswfX8JTNdGwzB7qQRNPqxVG_sM_tzY3QhPm8zqwEzsXG5DokxZfVt2I5WJRUEovFDb4BnK9KFnnkEzLEdMudixVnXeGmTtycgJvoTeTCQRPfDYkcgJ7oKf4tGea-W7z5UAVa2RduJM9ZoM6YtJX7jgDm__PvvqcE0knJUF87XHBzdcOjoDF-CUze9xDJgNBlvPbJqVshKrwoqSYpePSDH9GUCNKxGequW3Ma8GvlFfhwd0rK6IZG-XWkyk0XSWQIGkDSjAvhB1wsOusCCguDjbpVZpaW5MMyTkmx68pl6qlIKT5UCcrVPl4ix5ZEj91mUDF0O1t04haD7VZuLVFXVGmqtFrBKI76sdYN-zkokaa1_chPRTyqMQFlqu_8LD6-RiK3UccGM-dEmnX72i91NP9F9OK0WJr9Cheup1C_P0mjqAO4Cb8oIHm0Oxz_mRqv5QbTGJtb3xwPLPuVjVCiE4gGBcuU2ixpSVf5HUF7y1KicVMCKiX9ATCBtg8sTdQZQnPEtHcHHAvdsnDVwev1LGfqA-Gdvg="}, &mock.Encoder{MarshalResponse: []byte("x")}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusBadRequest, ""},
{"bad save session", opts, http.MethodGet, "http", "example.com", "/", nil, map[string]string{urlutil.QuerySessionEncrypted: goodEncryptionString}, &mock.Encoder{MarshalResponse: []byte("x")}, &sessions.MockSessionStore{SaveError: errors.New("hi")}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusBadRequest, ""},
{"bad base64", opts, http.MethodGet, "http", "example.com", "/", nil, map[string]string{urlutil.QuerySessionEncrypted: "^"}, &mock.Encoder{MarshalResponse: []byte("x")}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusBadRequest, ""},
{"malformed redirect", opts, http.MethodGet, "http", "example.com", "/", nil, nil, &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusBadRequest, ""},
{"good", opts, http.MethodGet, "http", "example.com", "/", nil, map[string]string{urlutil.QueryCallbackURI: "ok", urlutil.QuerySessionEncrypted: goodEncryptionString}, &mock.Encoder{MarshalResponse: []byte("x")}, &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusFound, ""},
{"good programmatic", opts, http.MethodGet, "http", "example.com", "/", nil, map[string]string{urlutil.QueryIsProgrammatic: "true", urlutil.QueryCallbackURI: "ok", urlutil.QuerySessionEncrypted: goodEncryptionString}, &mock.Encoder{MarshalResponse: []byte("x")}, &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusFound, ""},
{"bad decrypt", opts, http.MethodGet, "http", "example.com", "/", nil, map[string]string{urlutil.QuerySessionEncrypted: "KBEjQ9rnCxaAX-GOqexGw9ivEQURqts3zZ2mNGy0wnVa3SbtM399KlBq2nZ-9wM21FfsZX52er4jlmC7kPEKM3P7uZ41zR0zeys1-_74a5tQp-vsf1WXZfRsgVOuBcWPkMiWEoc379JFHxGDudp5VhU8B-dcQt4f3_PtLTHARkuH54io1Va2gNMq4Hiy8sQ1MPGCQeltH_JMzzdDpXdmdusWrXUvCGkba24muvAV06D8XRVJj6Iu9eK94qFnqcHc7wzziEbb8ADBues9dwbtb6jl8vMWz5rN6XvXqA5YpZv_MQZlsrO4oXFFQDevdgB84cX1tVbVu6qZvK_yQBZqzpOjWA9uIaoSENMytoXuWAlFO_sXjswfX8JTNdGwzB7qQRNPqxVG_sM_tzY3QhPm8zqwEzsXG5DokxZfVt2I5WJRUEovFDb4BnK9KFnnkEzLEdMudixVnXeGmTtycgJvoTeTCQRPfDYkcgJ7oKf4tGea-W7z5UAVa2RduJM9ZoM6YtJX7jgDm__PvvqcE0knJUF87XHBzdcOjoDF-CUze9xDJgNBlvPbJqVshKrwoqSYpePSDH9GUCNKxGequW3Ma8GvlFfhwd0rK6IZG-XWkyk0XSWQIGkDSjAvhB1wsOusCCguDjbpVZpaW5MMyTkmx68pl6qlIKT5UCcrVPl4ix5ZEj91mUDF0O1t04haD7VZuLVFXVGmqtFrBKI76sdYN-zkokaa1_chPRTyqMQFlqu_8LD6-RiK3UccGM-dEmnX72i91NP9F9OK0WJr9Cheup1C_P0mjqAO4Cb8oIHm0Oxz_mRqv5QbTGJtb3xwPLPuVjVCiE4gGBcuU2ixpSVf5HUF7y1KicVMCKiX9ATCBtg8sTdQZQnPEtHcHHAvdsnDVwev1LGfqA-Gdvg="}, &mock.Encoder{MarshalResponse: []byte("x")}, &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusBadRequest, ""},
{"bad save session", opts, http.MethodGet, "http", "example.com", "/", nil, map[string]string{urlutil.QuerySessionEncrypted: goodEncryptionString}, &mock.Encoder{MarshalResponse: []byte("x")}, &mstore.Store{SaveError: errors.New("hi")}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusBadRequest, ""},
{"bad base64", opts, http.MethodGet, "http", "example.com", "/", nil, map[string]string{urlutil.QuerySessionEncrypted: "^"}, &mock.Encoder{MarshalResponse: []byte("x")}, &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusBadRequest, ""},
{"malformed redirect", opts, http.MethodGet, "http", "example.com", "/", nil, nil, &mock.Encoder{}, &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusBadRequest, ""},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
@ -386,12 +387,12 @@ func TestProxy_ProgrammaticCallback(t *testing.T) {
wantStatus int
wantBody string
}{
{"good", opts, http.MethodGet, "http://pomerium.io/", nil, map[string]string{urlutil.QueryCallbackURI: "ok", urlutil.QuerySessionEncrypted: goodEncryptionString}, &mock.Encoder{MarshalResponse: []byte("x")}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusFound, ""},
{"good programmatic", opts, http.MethodGet, "http://pomerium.io/", nil, map[string]string{urlutil.QueryIsProgrammatic: "true", urlutil.QueryCallbackURI: "ok", urlutil.QuerySessionEncrypted: goodEncryptionString}, &mock.Encoder{MarshalResponse: []byte("x")}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusFound, ""},
{"bad decrypt", opts, http.MethodGet, "http://pomerium.io/", nil, map[string]string{urlutil.QuerySessionEncrypted: goodEncryptionString + cryptutil.NewBase64Key()}, &mock.Encoder{MarshalResponse: []byte("x")}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusBadRequest, ""},
{"bad save session", opts, http.MethodGet, "http://pomerium.io/", nil, map[string]string{urlutil.QuerySessionEncrypted: goodEncryptionString}, &mock.Encoder{MarshalResponse: []byte("x")}, &sessions.MockSessionStore{SaveError: errors.New("hi")}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusBadRequest, ""},
{"bad base64", opts, http.MethodGet, "http://pomerium.io/", nil, map[string]string{urlutil.QuerySessionEncrypted: "^"}, &mock.Encoder{MarshalResponse: []byte("x")}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusBadRequest, ""},
{"malformed redirect", opts, http.MethodGet, "http://pomerium.io/", nil, nil, &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusBadRequest, ""},
{"good", opts, http.MethodGet, "http://pomerium.io/", nil, map[string]string{urlutil.QueryCallbackURI: "ok", urlutil.QuerySessionEncrypted: goodEncryptionString}, &mock.Encoder{MarshalResponse: []byte("x")}, &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusFound, ""},
{"good programmatic", opts, http.MethodGet, "http://pomerium.io/", nil, map[string]string{urlutil.QueryIsProgrammatic: "true", urlutil.QueryCallbackURI: "ok", urlutil.QuerySessionEncrypted: goodEncryptionString}, &mock.Encoder{MarshalResponse: []byte("x")}, &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusFound, ""},
{"bad decrypt", opts, http.MethodGet, "http://pomerium.io/", nil, map[string]string{urlutil.QuerySessionEncrypted: goodEncryptionString + cryptutil.NewBase64Key()}, &mock.Encoder{MarshalResponse: []byte("x")}, &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusBadRequest, ""},
{"bad save session", opts, http.MethodGet, "http://pomerium.io/", nil, map[string]string{urlutil.QuerySessionEncrypted: goodEncryptionString}, &mock.Encoder{MarshalResponse: []byte("x")}, &mstore.Store{SaveError: errors.New("hi")}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusBadRequest, ""},
{"bad base64", opts, http.MethodGet, "http://pomerium.io/", nil, map[string]string{urlutil.QuerySessionEncrypted: "^"}, &mock.Encoder{MarshalResponse: []byte("x")}, &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusBadRequest, ""},
{"malformed redirect", opts, http.MethodGet, "http://pomerium.io/", nil, nil, &mock.Encoder{}, &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusBadRequest, ""},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {

View file

@ -1,7 +1,11 @@
package proxy // import "github.com/pomerium/pomerium/proxy"
import (
"context"
"errors"
"fmt"
"io"
"io/ioutil"
"net/http"
"github.com/pomerium/pomerium/internal/encoding"
@ -30,23 +34,82 @@ func (p *Proxy) AuthenticateSession(next http.Handler) http.Handler {
ctx, span := trace.StartSpan(r.Context(), "proxy.AuthenticateSession")
defer span.End()
if s, err := sessions.FromContext(ctx); err != nil {
log.FromRequest(r).Debug().Err(err).Msg("proxy: authenticate session")
p.sessionStore.ClearSession(w, r)
if s != nil && s.Programmatic {
return httputil.NewError(http.StatusUnauthorized, err)
_, err := sessions.FromContext(ctx)
if errors.Is(err, sessions.ErrExpired) {
ctx, err = p.refresh(ctx, w, r)
if err != nil {
log.FromRequest(r).Warn().Err(err).Msg("proxy: refresh failed")
return p.redirectToSignin(w, r)
}
signinURL := *p.authenticateSigninURL
q := signinURL.Query()
q.Set(urlutil.QueryRedirectURI, urlutil.GetAbsoluteURL(r).String())
signinURL.RawQuery = q.Encode()
httputil.Redirect(w, r, urlutil.NewSignedURL(p.SharedKey, &signinURL).String(), http.StatusFound)
log.FromRequest(r).Info().Msg("proxy: refresh success")
} else if err != nil {
log.FromRequest(r).Debug().Err(err).Msg("proxy: session state")
return p.redirectToSignin(w, r)
}
p.addPomeriumHeaders(w, r)
next.ServeHTTP(w, r.WithContext(ctx))
return nil
})
}
func (p *Proxy) refresh(ctx context.Context, w http.ResponseWriter, r *http.Request) (context.Context, error) {
ctx, span := trace.StartSpan(ctx, "proxy.AuthenticateSession/refresh")
defer span.End()
s, err := sessions.FromContext(ctx)
if !errors.Is(err, sessions.ErrExpired) || s == nil {
return nil, errors.New("proxy: unexpected session state for refresh")
}
// 1 - build a signed url to call refresh on authenticate service
refreshURI := *p.authenticateRefreshURL
q := refreshURI.Query()
q.Set("ati", s.AccessTokenID) // hash value points to parent token
q.Set("aud", urlutil.StripPort(r.Host)) // request's audience, this route
refreshURI.RawQuery = q.Encode()
signedRefreshURL := urlutil.NewSignedURL(p.SharedKey, &refreshURI).String()
// 2 - http call to authenticate service
req, err := http.NewRequestWithContext(ctx, http.MethodGet, signedRefreshURL, nil)
if err != nil {
return nil, fmt.Errorf("proxy: backend refresh: new request: %v", err)
}
res, err := httputil.DefaultClient.Do(req)
if err != nil {
return nil, fmt.Errorf("proxy: fetch %v: %w", signedRefreshURL, err)
}
defer res.Body.Close()
jwtBytes, err := ioutil.ReadAll(io.LimitReader(res.Body, 4<<10))
if err != nil {
return nil, err
}
// 3 - save refreshed session to the client's session store
if err = p.sessionStore.SaveSession(w, r, jwtBytes); err != nil {
return nil, err
}
// 4 - add refreshed session to the current request context
var state sessions.State
if err := p.encoder.Unmarshal(jwtBytes, &state); err != nil {
return nil, err
}
if err := state.Verify(urlutil.StripPort(r.Host)); err != nil {
return nil, err
}
return sessions.NewContext(r.Context(), &state, err), nil
}
func (p *Proxy) redirectToSignin(w http.ResponseWriter, r *http.Request) error {
s, err := sessions.FromContext(r.Context())
if s != nil && err != nil && s.Programmatic {
return httputil.NewError(http.StatusUnauthorized, err)
}
p.sessionStore.ClearSession(w, r)
signinURL := *p.authenticateSigninURL
q := signinURL.Query()
q.Set(urlutil.QueryRedirectURI, urlutil.GetAbsoluteURL(r).String())
signinURL.RawQuery = q.Encode()
log.FromRequest(r).Debug().Str("url", signinURL.String()).Msg("proxy: redirectToSignin")
httputil.Redirect(w, r, urlutil.NewSignedURL(p.SharedKey, &signinURL).String(), http.StatusFound)
return nil
}
func (p *Proxy) addPomeriumHeaders(w http.ResponseWriter, r *http.Request) {
@ -61,8 +124,8 @@ func (p *Proxy) addPomeriumHeaders(w http.ResponseWriter, r *http.Request) {
}
}
// AuthorizeSession is middleware to enforce a user is authorized for a request
// session state is retrieved from the users's request context.
// AuthorizeSession is middleware to enforce a user is authorized for a request.
// Session state is retrieved from the users's request context.
func (p *Proxy) AuthorizeSession(next http.Handler) http.Handler {
return httputil.HandlerFunc(func(w http.ResponseWriter, r *http.Request) error {
ctx, span := trace.StartSpan(r.Context(), "proxy.AuthorizeSession")

View file

@ -10,10 +10,14 @@ import (
"time"
"github.com/google/go-cmp/cmp"
"github.com/pomerium/pomerium/internal/identity"
"github.com/pomerium/pomerium/internal/sessions"
"github.com/pomerium/pomerium/proxy/clients"
"gopkg.in/square/go-jose.v2/jwt"
"github.com/pomerium/pomerium/internal/encoding"
"github.com/pomerium/pomerium/internal/encoding/mock"
"github.com/pomerium/pomerium/internal/identity"
"github.com/pomerium/pomerium/internal/sessions"
mstore "github.com/pomerium/pomerium/internal/sessions/mock"
)
func TestProxy_AuthenticateSession(t *testing.T) {
@ -30,24 +34,39 @@ func TestProxy_AuthenticateSession(t *testing.T) {
session sessions.SessionStore
ctxError error
provider identity.Authenticator
encoder encoding.MarshalUnmarshaler
refreshURL string
wantStatus int
}{
{"good", false, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Second))}}, nil, identity.MockProvider{}, http.StatusOK},
{"invalid session", false, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Second))}}, errors.New("hi"), identity.MockProvider{}, http.StatusFound},
{"expired", false, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}}, sessions.ErrExpired, identity.MockProvider{}, http.StatusFound},
{"expired and programmatic", false, &sessions.MockSessionStore{Session: &sessions.State{Programmatic: true, Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}}, sessions.ErrExpired, identity.MockProvider{}, http.StatusUnauthorized},
{"invalid session and programmatic", false, &sessions.MockSessionStore{Session: &sessions.State{Programmatic: true, Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Second))}}, errors.New("hi"), identity.MockProvider{}, http.StatusUnauthorized},
{"good", false, &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Second))}}, nil, identity.MockProvider{}, &mock.Encoder{}, "", http.StatusOK},
{"invalid session", false, &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Second))}}, errors.New("hi"), identity.MockProvider{}, &mock.Encoder{}, "", http.StatusFound},
{"expired", false, &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}}, sessions.ErrExpired, identity.MockProvider{}, &mock.Encoder{}, "", http.StatusOK},
{"expired and programmatic", false, &mstore.Store{Session: &sessions.State{Programmatic: true, Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}}, sessions.ErrExpired, identity.MockProvider{}, &mock.Encoder{}, "", http.StatusOK},
{"invalid session and programmatic", false, &mstore.Store{Session: &sessions.State{Programmatic: true, Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Second))}}, errors.New("hi"), identity.MockProvider{}, &mock.Encoder{}, "", http.StatusUnauthorized},
{"expired and refreshed ok", false, &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}}, sessions.ErrExpired, identity.MockProvider{}, &mock.Encoder{}, "", http.StatusOK},
{"expired and save failed", false, &mstore.Store{SaveError: errors.New("err"), Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}}, sessions.ErrExpired, identity.MockProvider{}, &mock.Encoder{}, "", http.StatusFound},
{"expired and unmarshal failed", false, &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}}, sessions.ErrExpired, identity.MockProvider{}, &mock.Encoder{UnmarshalError: errors.New("err")}, "", http.StatusFound},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintln(w, "REFRESH GOOD")
}))
defer ts.Close()
rURL := ts.URL
if tt.refreshURL != "" {
rURL = tt.refreshURL
}
a := Proxy{
SharedKey: "80ldlrU2d7w+wVpKNfevk6fmb8otEx6CqOfshj2LwhQ=",
cookieSecret: []byte("80ldlrU2d7w+wVpKNfevk6fmb8otEx6CqOfshj2LwhQ="),
authenticateURL: uriParseHelper("https://authenticate.corp.example"),
authenticateSigninURL: uriParseHelper("https://authenticate.corp.example/sign_in"),
sessionStore: tt.session,
SharedKey: "80ldlrU2d7w+wVpKNfevk6fmb8otEx6CqOfshj2LwhQ=",
cookieSecret: []byte("80ldlrU2d7w+wVpKNfevk6fmb8otEx6CqOfshj2LwhQ="),
authenticateURL: uriParseHelper("https://authenticate.corp.example"),
authenticateSigninURL: uriParseHelper("https://authenticate.corp.example/sign_in"),
authenticateRefreshURL: uriParseHelper(rURL),
sessionStore: tt.session,
encoder: tt.encoder,
}
r := httptest.NewRequest(http.MethodGet, "/", nil)
state, _ := tt.session.LoadSession(r)
@ -82,10 +101,10 @@ func TestProxy_AuthorizeSession(t *testing.T) {
wantStatus int
}{
{"user is authorized", &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Second))}}, clients.MockAuthorize{AuthorizeResponse: true}, nil, identity.MockProvider{}, http.StatusOK},
{"user is not authorized", &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Second))}}, clients.MockAuthorize{AuthorizeResponse: false}, nil, identity.MockProvider{}, http.StatusUnauthorized},
{"invalid session", &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Second))}}, clients.MockAuthorize{AuthorizeResponse: true}, errors.New("hi"), identity.MockProvider{}, http.StatusUnauthorized},
{"authz client error", &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Second))}}, clients.MockAuthorize{AuthorizeError: errors.New("err")}, nil, identity.MockProvider{}, http.StatusInternalServerError},
{"user is authorized", &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Second))}}, clients.MockAuthorize{AuthorizeResponse: true}, nil, identity.MockProvider{}, http.StatusOK},
{"user is not authorized", &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Second))}}, clients.MockAuthorize{AuthorizeResponse: false}, nil, identity.MockProvider{}, http.StatusUnauthorized},
{"invalid session", &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Second))}}, clients.MockAuthorize{AuthorizeResponse: true}, errors.New("hi"), identity.MockProvider{}, http.StatusUnauthorized},
{"authz client error", &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Second))}}, clients.MockAuthorize{AuthorizeError: errors.New("err")}, nil, identity.MockProvider{}, http.StatusInternalServerError},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
@ -143,9 +162,9 @@ func TestProxy_SignRequest(t *testing.T) {
wantStatus int
wantHeaders string
}{
{"good", &sessions.MockSessionStore{Session: &sessions.State{Email: "test"}}, nil, nil, http.StatusOK, "ok"},
{"invalid session", &sessions.MockSessionStore{Session: &sessions.State{Email: "test"}}, nil, errors.New("err"), http.StatusForbidden, ""},
{"signature failure, warn but ok", &sessions.MockSessionStore{Session: &sessions.State{Email: "test"}}, errors.New("err"), nil, http.StatusOK, ""},
{"good", &mstore.Store{Session: &sessions.State{Email: "test"}}, nil, nil, http.StatusOK, "ok"},
{"invalid session", &mstore.Store{Session: &sessions.State{Email: "test"}}, nil, errors.New("err"), http.StatusForbidden, ""},
{"signature failure, warn but ok", &mstore.Store{Session: &sessions.State{Email: "test"}}, errors.New("err"), nil, http.StatusOK, ""},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {

View file

@ -21,6 +21,9 @@ import (
"github.com/pomerium/pomerium/internal/log"
"github.com/pomerium/pomerium/internal/middleware"
"github.com/pomerium/pomerium/internal/sessions"
"github.com/pomerium/pomerium/internal/sessions/cookie"
"github.com/pomerium/pomerium/internal/sessions/header"
"github.com/pomerium/pomerium/internal/sessions/queryparam"
"github.com/pomerium/pomerium/internal/telemetry/metrics"
"github.com/pomerium/pomerium/internal/tripper"
"github.com/pomerium/pomerium/internal/urlutil"
@ -28,12 +31,11 @@ import (
)
const (
// dashboardURL is the path to authenticate's sign in endpoint
// authenticate urls
dashboardURL = "/.pomerium"
// signinURL is the path to authenticate's sign in endpoint
signinURL = "/.pomerium/sign_in"
// signoutURL is the path to authenticate's sign out endpoint
signoutURL = "/.pomerium/sign_out"
signinURL = "/.pomerium/sign_in"
signoutURL = "/.pomerium/sign_out"
refreshURL = "/.pomerium/refresh"
)
// ValidateOptions checks that proper configuration settings are set to create
@ -72,12 +74,14 @@ type Proxy struct {
authenticateURL *url.URL
authenticateSigninURL *url.URL
authenticateSignoutURL *url.URL
authorizeURL *url.URL
authenticateRefreshURL *url.URL
authorizeURL *url.URL
AuthorizeClient clients.Authorizer
encoder encoding.Unmarshaler
cookieOptions *sessions.CookieOptions
cookieOptions *cookie.Options
cookieSecret []byte
defaultUpstreamTimeout time.Duration
refreshCooldown time.Duration
@ -104,7 +108,7 @@ func New(opts config.Options) (*Proxy, error) {
return nil, err
}
cookieOptions := &sessions.CookieOptions{
cookieOptions := &cookie.Options{
Name: opts.CookieName,
Domain: opts.CookieDomain,
Secure: opts.CookieSecure,
@ -112,7 +116,7 @@ func New(opts config.Options) (*Proxy, error) {
Expire: opts.CookieExpire,
}
cookieStore, err := sessions.NewCookieLoader(cookieOptions, encoder)
cookieStore, err := cookie.NewStore(cookieOptions, encoder)
if err != nil {
return nil, err
}
@ -129,8 +133,8 @@ func New(opts config.Options) (*Proxy, error) {
sessionStore: cookieStore,
sessionLoaders: []sessions.SessionLoader{
cookieStore,
sessions.NewHeaderStore(encoder, "Pomerium"),
sessions.NewQueryParamStore(encoder, "pomerium_session")},
header.NewStore(encoder, "Pomerium"),
queryparam.NewStore(encoder, "pomerium_session")},
signingKey: opts.SigningKey,
templates: template.Must(frontend.NewTemplates()),
}
@ -139,6 +143,7 @@ func New(opts config.Options) (*Proxy, error) {
p.authenticateURL, _ = urlutil.DeepCopy(opts.AuthenticateURL)
p.authenticateSigninURL = p.authenticateURL.ResolveReference(&url.URL{Path: signinURL})
p.authenticateSignoutURL = p.authenticateURL.ResolveReference(&url.URL{Path: signoutURL})
p.authenticateRefreshURL = p.authenticateURL.ResolveReference(&url.URL{Path: refreshURL})
if err := p.UpdatePolicies(&opts); err != nil {
return nil, err