mirror of
https://github.com/pomerium/pomerium.git
synced 2025-05-11 08:07:38 +02:00
authenticate/proxy: add backend refresh (#438)
This commit is contained in:
parent
9a330613aa
commit
ec029c679b
35 changed files with 1226 additions and 445 deletions
|
@ -198,11 +198,6 @@ issues:
|
|||
linters:
|
||||
- staticcheck
|
||||
|
||||
# todo(bdd): replace in go 1.13
|
||||
- path: proxy/proxy.go
|
||||
text: "copylocks: assignment copies lock value to transport"
|
||||
linters:
|
||||
- govet
|
||||
# Independently from option `exclude` we use default exclude patterns,
|
||||
# it can be disabled by this option. To list all
|
||||
# excluded by default patterns execute `golangci-lint run --help`.
|
||||
|
|
|
@ -16,6 +16,10 @@ import (
|
|||
"github.com/pomerium/pomerium/internal/frontend"
|
||||
"github.com/pomerium/pomerium/internal/identity"
|
||||
"github.com/pomerium/pomerium/internal/sessions"
|
||||
"github.com/pomerium/pomerium/internal/sessions/cache"
|
||||
"github.com/pomerium/pomerium/internal/sessions/cookie"
|
||||
"github.com/pomerium/pomerium/internal/sessions/header"
|
||||
"github.com/pomerium/pomerium/internal/sessions/queryparam"
|
||||
"github.com/pomerium/pomerium/internal/urlutil"
|
||||
)
|
||||
|
||||
|
@ -49,6 +53,8 @@ type Authenticate struct {
|
|||
// authentication flow
|
||||
RedirectURL *url.URL
|
||||
|
||||
// values related to cross service communication
|
||||
//
|
||||
// sharedKey is used to encrypt and authenticate data between services
|
||||
sharedKey string
|
||||
// sharedCipher is used to encrypt data for use between services
|
||||
|
@ -57,16 +63,21 @@ type Authenticate struct {
|
|||
// by other services
|
||||
sharedEncoder encoding.MarshalUnmarshaler
|
||||
|
||||
// data related to this service only
|
||||
cookieOptions *sessions.CookieOptions
|
||||
// cookieSecret is the secret to encrypt and authenticate data for this service
|
||||
// values related to user sessions
|
||||
//
|
||||
// cookieSecret is the secret to encrypt and authenticate session data
|
||||
cookieSecret []byte
|
||||
// is the cipher to use to encrypt data for this service
|
||||
cookieCipher cipher.AEAD
|
||||
sessionStore sessions.SessionStore
|
||||
// cookieCipher is the cipher to use to encrypt/decrypt session data
|
||||
cookieCipher cipher.AEAD
|
||||
// encryptedEncoder is the encoder used to marshal and unmarshal session data
|
||||
encryptedEncoder encoding.MarshalUnmarshaler
|
||||
sessionStores []sessions.SessionStore
|
||||
sessionLoaders []sessions.SessionLoader
|
||||
// sessionStore is the session store used to persist a user's session
|
||||
sessionStore sessions.SessionStore
|
||||
cookieOptions *cookie.Options
|
||||
|
||||
// sessionLoaders are a collection of session loaders to attempt to pull
|
||||
// a user's session state from
|
||||
sessionLoaders []sessions.SessionLoader
|
||||
|
||||
// provider is the interface to interacting with the identity provider (IdP)
|
||||
provider identity.Authenticator
|
||||
|
@ -92,7 +103,7 @@ func New(opts config.Options) (*Authenticate, error) {
|
|||
cookieCipher, _ := cryptutil.NewAEADCipher(decodedCookieSecret)
|
||||
encryptedEncoder := ecjson.New(cookieCipher)
|
||||
|
||||
cookieOptions := &sessions.CookieOptions{
|
||||
cookieOptions := &cookie.Options{
|
||||
Name: opts.CookieName,
|
||||
Domain: opts.CookieDomain,
|
||||
Secure: opts.CookieSecure,
|
||||
|
@ -100,12 +111,13 @@ func New(opts config.Options) (*Authenticate, error) {
|
|||
Expire: opts.CookieExpire,
|
||||
}
|
||||
|
||||
cookieStore, err := sessions.NewCookieStore(cookieOptions, encryptedEncoder)
|
||||
cookieStore, err := cookie.NewStore(cookieOptions, encryptedEncoder)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
qpStore := sessions.NewQueryParamStore(encryptedEncoder, "pomerium_programmatic_token")
|
||||
headerStore := sessions.NewHeaderStore(encryptedEncoder, "Pomerium")
|
||||
cacheStore := cache.NewStore(encryptedEncoder, cookieStore, opts.CookieName)
|
||||
qpStore := queryparam.NewStore(encryptedEncoder, "pomerium_programmatic_token")
|
||||
headerStore := header.NewStore(encryptedEncoder, "Pomerium")
|
||||
|
||||
redirectURL, _ := urlutil.DeepCopy(opts.AuthenticateURL)
|
||||
redirectURL.Path = callbackPath
|
||||
|
@ -135,10 +147,9 @@ func New(opts config.Options) (*Authenticate, error) {
|
|||
cookieSecret: decodedCookieSecret,
|
||||
cookieCipher: cookieCipher,
|
||||
cookieOptions: cookieOptions,
|
||||
sessionStore: cookieStore,
|
||||
sessionStore: cacheStore,
|
||||
encryptedEncoder: encryptedEncoder,
|
||||
sessionLoaders: []sessions.SessionLoader{qpStore, headerStore, cookieStore},
|
||||
sessionStores: []sessions.SessionStore{cookieStore, qpStore},
|
||||
sessionLoaders: []sessions.SessionLoader{cacheStore, qpStore, headerStore, cookieStore},
|
||||
// IdP
|
||||
provider: provider,
|
||||
|
||||
|
|
|
@ -72,15 +72,18 @@ func TestOptions_Validate(t *testing.T) {
|
|||
|
||||
func TestNew(t *testing.T) {
|
||||
good := newTestOptions(t)
|
||||
good.CookieName = "A"
|
||||
|
||||
badRedirectURL := newTestOptions(t)
|
||||
badRedirectURL.AuthenticateURL = nil
|
||||
badRedirectURL.CookieName = "B"
|
||||
|
||||
badCookieName := newTestOptions(t)
|
||||
badCookieName.CookieName = ""
|
||||
|
||||
badProvider := newTestOptions(t)
|
||||
badProvider.Provider = ""
|
||||
badProvider.CookieName = "C"
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package authenticate // import "github.com/pomerium/pomerium/authenticate"
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
|
@ -18,6 +19,7 @@ import (
|
|||
"github.com/pomerium/pomerium/internal/log"
|
||||
"github.com/pomerium/pomerium/internal/middleware"
|
||||
"github.com/pomerium/pomerium/internal/sessions"
|
||||
"github.com/pomerium/pomerium/internal/telemetry/trace"
|
||||
"github.com/pomerium/pomerium/internal/urlutil"
|
||||
)
|
||||
|
||||
|
@ -58,6 +60,7 @@ func (a *Authenticate) Handler() http.Handler {
|
|||
v.Use(a.VerifySession)
|
||||
v.Path("/sign_in").Handler(httputil.HandlerFunc(a.SignIn))
|
||||
v.Path("/sign_out").Handler(httputil.HandlerFunc(a.SignOut))
|
||||
v.Path("/refresh").Handler(httputil.HandlerFunc(a.Refresh)).Methods(http.MethodGet)
|
||||
|
||||
// programmatic access api endpoint
|
||||
api := r.PathPrefix("/api").Subrouter()
|
||||
|
@ -73,12 +76,12 @@ func (a *Authenticate) VerifySession(next http.Handler) http.Handler {
|
|||
return httputil.HandlerFunc(func(w http.ResponseWriter, r *http.Request) error {
|
||||
state, err := sessions.FromContext(r.Context())
|
||||
if errors.Is(err, sessions.ErrExpired) {
|
||||
if err := a.refresh(w, r, state); err != nil {
|
||||
ctx, err := a.refresh(w, r, state)
|
||||
if err != nil {
|
||||
log.FromRequest(r).Info().Err(err).Msg("authenticate: verify session, refresh")
|
||||
return a.reauthenticateOrFail(w, r, err)
|
||||
}
|
||||
// redirect to restart middleware-chain following refresh
|
||||
httputil.Redirect(w, r, urlutil.GetAbsoluteURL(r).String(), http.StatusFound)
|
||||
next.ServeHTTP(w, r.WithContext(ctx))
|
||||
return nil
|
||||
} else if err != nil {
|
||||
log.FromRequest(r).Info().Err(err).Msg("authenticate: verify session")
|
||||
|
@ -89,15 +92,18 @@ func (a *Authenticate) VerifySession(next http.Handler) http.Handler {
|
|||
})
|
||||
}
|
||||
|
||||
func (a *Authenticate) refresh(w http.ResponseWriter, r *http.Request, s *sessions.State) error {
|
||||
newSession, err := a.provider.Refresh(r.Context(), s)
|
||||
func (a *Authenticate) refresh(w http.ResponseWriter, r *http.Request, s *sessions.State) (context.Context, error) {
|
||||
ctx, span := trace.StartSpan(r.Context(), "authenticate.VerifySession/refresh")
|
||||
defer span.End()
|
||||
newSession, err := a.provider.Refresh(ctx, s)
|
||||
if err != nil {
|
||||
return fmt.Errorf("authenticate: refresh failed: %w", err)
|
||||
return nil, fmt.Errorf("authenticate: refresh failed: %w", err)
|
||||
}
|
||||
if err := a.sessionStore.SaveSession(w, r, newSession); err != nil {
|
||||
return fmt.Errorf("authenticate: refresh save failed: %w", err)
|
||||
return nil, fmt.Errorf("authenticate: refresh save failed: %w", err)
|
||||
}
|
||||
return nil
|
||||
// return the new session and add it to the current request context
|
||||
return sessions.NewContext(ctx, newSession, err), nil
|
||||
}
|
||||
|
||||
// RobotsTxt handles the /robots.txt route.
|
||||
|
@ -158,7 +164,6 @@ func (a *Authenticate) SignIn(w http.ResponseWriter, r *http.Request) error {
|
|||
encSession, err := a.encryptedEncoder.Marshal(newSession)
|
||||
if err != nil {
|
||||
return httputil.NewError(http.StatusBadRequest, err)
|
||||
|
||||
}
|
||||
callbackParams.Set(urlutil.QueryRefreshToken, string(encSession))
|
||||
callbackParams.Set(urlutil.QueryIsProgrammatic, "true")
|
||||
|
@ -345,3 +350,27 @@ func (a *Authenticate) RefreshAPI(w http.ResponseWriter, r *http.Request) error
|
|||
w.Write(jsonResponse)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Refresh is called by the proxy service to handle backend session refresh.
|
||||
//
|
||||
// NOTE: The actual refresh is actually handled as part of the "VerifySession"
|
||||
// middleware. This handler is responsible for creating a new route scoped
|
||||
// session and returning it.
|
||||
func (a *Authenticate) Refresh(w http.ResponseWriter, r *http.Request) error {
|
||||
s, err := sessions.FromContext(r.Context())
|
||||
if err != nil {
|
||||
return httputil.NewError(http.StatusBadRequest, err)
|
||||
}
|
||||
|
||||
routeSession := s.NewSession(r.Host, []string{r.Host, r.FormValue("aud")})
|
||||
routeSession.AccessTokenID = s.AccessTokenID
|
||||
|
||||
signedJWT, err := a.sharedEncoder.Marshal(routeSession.RouteSession())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/jwt") // RFC 7519 : 10.3.1
|
||||
w.Write(signedJWT)
|
||||
return nil
|
||||
}
|
||||
|
|
|
@ -11,17 +11,18 @@ import (
|
|||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/httputil"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/cryptutil"
|
||||
"github.com/pomerium/pomerium/internal/encoding"
|
||||
"github.com/pomerium/pomerium/internal/encoding/mock"
|
||||
"github.com/pomerium/pomerium/internal/frontend"
|
||||
"github.com/pomerium/pomerium/internal/httputil"
|
||||
"github.com/pomerium/pomerium/internal/identity"
|
||||
"github.com/pomerium/pomerium/internal/sessions"
|
||||
"github.com/pomerium/pomerium/internal/urlutil"
|
||||
"github.com/pomerium/pomerium/internal/sessions/cookie"
|
||||
mstore "github.com/pomerium/pomerium/internal/sessions/mock"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/pomerium/pomerium/internal/urlutil"
|
||||
"golang.org/x/crypto/chacha20poly1305"
|
||||
"golang.org/x/oauth2"
|
||||
"gopkg.in/square/go-jose.v2/jwt"
|
||||
|
@ -32,7 +33,7 @@ func testAuthenticate() *Authenticate {
|
|||
auth.RedirectURL, _ = url.Parse("https://auth.example.com/oauth/callback")
|
||||
auth.sharedKey = cryptutil.NewBase64Key()
|
||||
auth.cookieSecret = cryptutil.NewKey()
|
||||
auth.cookieOptions = &sessions.CookieOptions{Name: "name"}
|
||||
auth.cookieOptions = &cookie.Options{Name: "name"}
|
||||
auth.templates = template.Must(frontend.NewTemplates())
|
||||
return &auth
|
||||
}
|
||||
|
@ -112,19 +113,19 @@ func TestAuthenticate_SignIn(t *testing.T) {
|
|||
encoder encoding.MarshalUnmarshaler
|
||||
wantCode int
|
||||
}{
|
||||
{"good", "https", "corp.example.example", map[string]string{urlutil.QueryRedirectURI: "https://dst.some.example/"}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusFound},
|
||||
{"session not valid", "https", "corp.example.example", map[string]string{urlutil.QueryRedirectURI: "https://dst.some.example/"}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(-10 * time.Second)}}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusFound},
|
||||
{"bad redirect uri query", "", "corp.example.example", map[string]string{urlutil.QueryRedirectURI: "^^^"}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusBadRequest},
|
||||
{"bad marshal", "https", "corp.example.example", map[string]string{urlutil.QueryRedirectURI: "https://dst.some.example/"}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, identity.MockProvider{}, &mock.Encoder{MarshalError: errors.New("error")}, http.StatusBadRequest},
|
||||
{"session error", "https", "corp.example.example", map[string]string{urlutil.QueryRedirectURI: "https://dst.some.example/"}, &sessions.MockSessionStore{LoadError: errors.New("error")}, identity.MockProvider{}, &mock.Encoder{}, http.StatusBadRequest},
|
||||
{"good with different programmatic redirect", "https", "corp.example.example", map[string]string{urlutil.QueryRedirectURI: "https://dst.some.example/", urlutil.QueryCallbackURI: "https://some.example"}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusFound},
|
||||
{"encrypted encoder error", "https", "corp.example.example", map[string]string{urlutil.QueryRedirectURI: "https://dst.some.example/", urlutil.QueryCallbackURI: "https://some.example"}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, identity.MockProvider{}, &mock.Encoder{MarshalError: errors.New("error")}, http.StatusBadRequest},
|
||||
{"good with callback uri set", "https", "corp.example.example", map[string]string{urlutil.QueryCallbackURI: "https://some.example/", urlutil.QueryRedirectURI: "https://dst.some.example/"}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusFound},
|
||||
{"bad callback uri set", "https", "corp.example.example", map[string]string{urlutil.QueryCallbackURI: "^", urlutil.QueryRedirectURI: "https://dst.some.example/"}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusBadRequest},
|
||||
{"good programmatic request", "https", "corp.example.example", map[string]string{urlutil.QueryIsProgrammatic: "true", urlutil.QueryRedirectURI: "https://dst.some.example/"}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusFound},
|
||||
{"good additional audience", "https", "corp.example.example", map[string]string{urlutil.QueryForwardAuth: "x.y.z", urlutil.QueryRedirectURI: "https://dst.some.example/"}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusFound},
|
||||
{"good user impersonate", "https", "corp.example.example", map[string]string{urlutil.QueryImpersonateAction: "set", urlutil.QueryRedirectURI: "https://dst.some.example/"}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusFound},
|
||||
{"bad user impersonate save failure", "https", "corp.example.example", map[string]string{urlutil.QueryImpersonateAction: "set", urlutil.QueryRedirectURI: "https://dst.some.example/"}, &sessions.MockSessionStore{SaveError: errors.New("err"), Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusBadRequest},
|
||||
{"good", "https", "corp.example.example", map[string]string{urlutil.QueryRedirectURI: "https://dst.some.example/"}, &mstore.Store{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusFound},
|
||||
{"session not valid", "https", "corp.example.example", map[string]string{urlutil.QueryRedirectURI: "https://dst.some.example/"}, &mstore.Store{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(-10 * time.Second)}}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusFound},
|
||||
{"bad redirect uri query", "", "corp.example.example", map[string]string{urlutil.QueryRedirectURI: "^^^"}, &mstore.Store{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusBadRequest},
|
||||
{"bad marshal", "https", "corp.example.example", map[string]string{urlutil.QueryRedirectURI: "https://dst.some.example/"}, &mstore.Store{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, identity.MockProvider{}, &mock.Encoder{MarshalError: errors.New("error")}, http.StatusBadRequest},
|
||||
{"session error", "https", "corp.example.example", map[string]string{urlutil.QueryRedirectURI: "https://dst.some.example/"}, &mstore.Store{LoadError: errors.New("error")}, identity.MockProvider{}, &mock.Encoder{}, http.StatusBadRequest},
|
||||
{"good with different programmatic redirect", "https", "corp.example.example", map[string]string{urlutil.QueryRedirectURI: "https://dst.some.example/", urlutil.QueryCallbackURI: "https://some.example"}, &mstore.Store{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusFound},
|
||||
{"encrypted encoder error", "https", "corp.example.example", map[string]string{urlutil.QueryRedirectURI: "https://dst.some.example/", urlutil.QueryCallbackURI: "https://some.example"}, &mstore.Store{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, identity.MockProvider{}, &mock.Encoder{MarshalError: errors.New("error")}, http.StatusBadRequest},
|
||||
{"good with callback uri set", "https", "corp.example.example", map[string]string{urlutil.QueryCallbackURI: "https://some.example/", urlutil.QueryRedirectURI: "https://dst.some.example/"}, &mstore.Store{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusFound},
|
||||
{"bad callback uri set", "https", "corp.example.example", map[string]string{urlutil.QueryCallbackURI: "^", urlutil.QueryRedirectURI: "https://dst.some.example/"}, &mstore.Store{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusBadRequest},
|
||||
{"good programmatic request", "https", "corp.example.example", map[string]string{urlutil.QueryIsProgrammatic: "true", urlutil.QueryRedirectURI: "https://dst.some.example/"}, &mstore.Store{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusFound},
|
||||
{"good additional audience", "https", "corp.example.example", map[string]string{urlutil.QueryForwardAuth: "x.y.z", urlutil.QueryRedirectURI: "https://dst.some.example/"}, &mstore.Store{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusFound},
|
||||
{"good user impersonate", "https", "corp.example.example", map[string]string{urlutil.QueryImpersonateAction: "set", urlutil.QueryRedirectURI: "https://dst.some.example/"}, &mstore.Store{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusFound},
|
||||
{"bad user impersonate save failure", "https", "corp.example.example", map[string]string{urlutil.QueryImpersonateAction: "set", urlutil.QueryRedirectURI: "https://dst.some.example/"}, &mstore.Store{SaveError: errors.New("err"), Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusBadRequest},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
|
@ -136,7 +137,7 @@ func TestAuthenticate_SignIn(t *testing.T) {
|
|||
sharedEncoder: tt.encoder,
|
||||
encryptedEncoder: tt.encoder,
|
||||
sharedCipher: aead,
|
||||
cookieOptions: &sessions.CookieOptions{
|
||||
cookieOptions: &cookie.Options{
|
||||
Name: "cookie",
|
||||
Domain: "foo",
|
||||
},
|
||||
|
@ -186,10 +187,10 @@ func TestAuthenticate_SignOut(t *testing.T) {
|
|||
wantCode int
|
||||
wantBody string
|
||||
}{
|
||||
{"good post", http.MethodPost, nil, "https://corp.pomerium.io/", "sig", "ts", identity.MockProvider{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, http.StatusFound, ""},
|
||||
{"failed revoke", http.MethodPost, nil, "https://corp.pomerium.io/", "sig", "ts", identity.MockProvider{RevokeError: errors.New("OH NO")}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, http.StatusBadRequest, "{\"Status\":400,\"Error\":\"Bad Request: OH NO\"}\n"},
|
||||
{"load session error", http.MethodPost, errors.New("error"), "https://corp.pomerium.io/", "sig", "ts", identity.MockProvider{RevokeError: errors.New("OH NO")}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, http.StatusBadRequest, "{\"Status\":400,\"Error\":\"Bad Request: error\"}\n"},
|
||||
{"bad redirect uri", http.MethodPost, nil, "corp.pomerium.io/", "sig", "ts", identity.MockProvider{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, http.StatusBadRequest, "{\"Status\":400,\"Error\":\"Bad Request: corp.pomerium.io/ url does contain a valid scheme\"}\n"},
|
||||
{"good post", http.MethodPost, nil, "https://corp.pomerium.io/", "sig", "ts", identity.MockProvider{}, &mstore.Store{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, http.StatusFound, ""},
|
||||
{"failed revoke", http.MethodPost, nil, "https://corp.pomerium.io/", "sig", "ts", identity.MockProvider{RevokeError: errors.New("OH NO")}, &mstore.Store{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, http.StatusBadRequest, "{\"Status\":400,\"Error\":\"Bad Request: OH NO\"}\n"},
|
||||
{"load session error", http.MethodPost, errors.New("error"), "https://corp.pomerium.io/", "sig", "ts", identity.MockProvider{RevokeError: errors.New("OH NO")}, &mstore.Store{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, http.StatusBadRequest, "{\"Status\":400,\"Error\":\"Bad Request: error\"}\n"},
|
||||
{"bad redirect uri", http.MethodPost, nil, "corp.pomerium.io/", "sig", "ts", identity.MockProvider{}, &mstore.Store{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, http.StatusBadRequest, "{\"Status\":400,\"Error\":\"Bad Request: corp.pomerium.io/ url does contain a valid scheme\"}\n"},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
|
@ -247,19 +248,19 @@ func TestAuthenticate_OAuthCallback(t *testing.T) {
|
|||
want string
|
||||
wantCode int
|
||||
}{
|
||||
{"good", http.MethodGet, time.Now().Unix(), "", "", "", "", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, "https://corp.pomerium.io", http.StatusFound},
|
||||
{"failed authenticate", http.MethodGet, time.Now().Unix(), "", "", "", "", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateError: errors.New("error")}, "", http.StatusInternalServerError},
|
||||
{"failed save session", http.MethodGet, time.Now().Unix(), "", "", "", "", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &sessions.MockSessionStore{SaveError: errors.New("error")}, identity.MockProvider{AuthenticateResponse: sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, "", http.StatusInternalServerError},
|
||||
{"provider returned error", http.MethodGet, time.Now().Unix(), "", "", "", "idp error", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, "", http.StatusBadRequest},
|
||||
{"empty code", http.MethodGet, time.Now().Unix(), "", "", "", "", "", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, "", http.StatusBadRequest},
|
||||
{"invalid redirect uri", http.MethodGet, time.Now().Unix(), "", "", "", "", "code", "corp.pomerium.io", "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, "", http.StatusBadRequest},
|
||||
{"bad redirect uri", http.MethodGet, time.Now().Unix(), "", "", "", "", "code", "http://^^^", "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, "https://corp.pomerium.io", http.StatusBadRequest},
|
||||
{"bad timing - too soon", http.MethodGet, time.Now().Add(1 * time.Hour).Unix(), "", "", "", "", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, "https://corp.pomerium.io", http.StatusBadRequest},
|
||||
{"bad timing - expired", http.MethodGet, time.Now().Add(-1 * time.Hour).Unix(), "", "", "", "", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, "https://corp.pomerium.io", http.StatusBadRequest},
|
||||
{"bad base64", http.MethodGet, time.Now().Unix(), "", "", "^", "", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, "https://corp.pomerium.io", http.StatusBadRequest},
|
||||
{"too many seperators", http.MethodGet, time.Now().Unix(), "", "", "|ok|now|what", "", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, "https://corp.pomerium.io", http.StatusBadRequest},
|
||||
{"bad hmac", http.MethodGet, time.Now().Unix(), "", "NOTMAC", "", "", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, "https://corp.pomerium.io", http.StatusBadRequest},
|
||||
{"bad hmac", http.MethodGet, time.Now().Unix(), base64.URLEncoding.EncodeToString([]byte("malformed_state")), "", "", "", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, "https://corp.pomerium.io", http.StatusBadRequest},
|
||||
{"good", http.MethodGet, time.Now().Unix(), "", "", "", "", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &mstore.Store{}, identity.MockProvider{AuthenticateResponse: sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, "https://corp.pomerium.io", http.StatusFound},
|
||||
{"failed authenticate", http.MethodGet, time.Now().Unix(), "", "", "", "", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &mstore.Store{}, identity.MockProvider{AuthenticateError: errors.New("error")}, "", http.StatusInternalServerError},
|
||||
{"failed save session", http.MethodGet, time.Now().Unix(), "", "", "", "", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &mstore.Store{SaveError: errors.New("error")}, identity.MockProvider{AuthenticateResponse: sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, "", http.StatusInternalServerError},
|
||||
{"provider returned error", http.MethodGet, time.Now().Unix(), "", "", "", "idp error", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &mstore.Store{}, identity.MockProvider{AuthenticateResponse: sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, "", http.StatusBadRequest},
|
||||
{"empty code", http.MethodGet, time.Now().Unix(), "", "", "", "", "", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &mstore.Store{}, identity.MockProvider{AuthenticateResponse: sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, "", http.StatusBadRequest},
|
||||
{"invalid redirect uri", http.MethodGet, time.Now().Unix(), "", "", "", "", "code", "corp.pomerium.io", "https://authenticate.pomerium.io", &mstore.Store{}, identity.MockProvider{AuthenticateResponse: sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, "", http.StatusBadRequest},
|
||||
{"bad redirect uri", http.MethodGet, time.Now().Unix(), "", "", "", "", "code", "http://^^^", "https://authenticate.pomerium.io", &mstore.Store{}, identity.MockProvider{AuthenticateResponse: sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, "https://corp.pomerium.io", http.StatusBadRequest},
|
||||
{"bad timing - too soon", http.MethodGet, time.Now().Add(1 * time.Hour).Unix(), "", "", "", "", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &mstore.Store{}, identity.MockProvider{AuthenticateResponse: sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, "https://corp.pomerium.io", http.StatusBadRequest},
|
||||
{"bad timing - expired", http.MethodGet, time.Now().Add(-1 * time.Hour).Unix(), "", "", "", "", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &mstore.Store{}, identity.MockProvider{AuthenticateResponse: sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, "https://corp.pomerium.io", http.StatusBadRequest},
|
||||
{"bad base64", http.MethodGet, time.Now().Unix(), "", "", "^", "", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &mstore.Store{}, identity.MockProvider{AuthenticateResponse: sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, "https://corp.pomerium.io", http.StatusBadRequest},
|
||||
{"too many seperators", http.MethodGet, time.Now().Unix(), "", "", "|ok|now|what", "", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &mstore.Store{}, identity.MockProvider{AuthenticateResponse: sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, "https://corp.pomerium.io", http.StatusBadRequest},
|
||||
{"bad hmac", http.MethodGet, time.Now().Unix(), "", "NOTMAC", "", "", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &mstore.Store{}, identity.MockProvider{AuthenticateResponse: sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, "https://corp.pomerium.io", http.StatusBadRequest},
|
||||
{"bad hmac", http.MethodGet, time.Now().Unix(), base64.URLEncoding.EncodeToString([]byte("malformed_state")), "", "", "", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &mstore.Store{}, identity.MockProvider{AuthenticateResponse: sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, "https://corp.pomerium.io", http.StatusBadRequest},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
|
@ -326,12 +327,12 @@ func TestAuthenticate_SessionValidatorMiddleware(t *testing.T) {
|
|||
|
||||
wantStatus int
|
||||
}{
|
||||
{"good", nil, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, nil, identity.MockProvider{RefreshResponse: sessions.State{AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Minute)}}}, http.StatusOK},
|
||||
{"invalid session", nil, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, errors.New("hi"), identity.MockProvider{}, http.StatusFound},
|
||||
{"good refresh expired", nil, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}}, sessions.ErrExpired, identity.MockProvider{RefreshResponse: sessions.State{AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Minute)}}}, http.StatusFound},
|
||||
{"expired,refresh error", nil, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}}, sessions.ErrExpired, identity.MockProvider{RefreshError: errors.New("error")}, http.StatusFound},
|
||||
{"expired,save error", nil, &sessions.MockSessionStore{SaveError: errors.New("error"), Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}}, sessions.ErrExpired, identity.MockProvider{RefreshResponse: sessions.State{AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Minute)}}}, http.StatusFound},
|
||||
{"expired XHR,refresh error", map[string]string{"X-Requested-With": "XmlHttpRequest"}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}}, sessions.ErrExpired, identity.MockProvider{RefreshError: errors.New("error")}, http.StatusUnauthorized},
|
||||
{"good", nil, &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, nil, identity.MockProvider{RefreshResponse: sessions.State{AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Minute)}}}, http.StatusOK},
|
||||
{"invalid session", nil, &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, errors.New("hi"), identity.MockProvider{}, http.StatusFound},
|
||||
{"good refresh expired", nil, &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}}, sessions.ErrExpired, identity.MockProvider{RefreshResponse: sessions.State{AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Minute)}}}, http.StatusOK},
|
||||
{"expired,refresh error", nil, &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}}, sessions.ErrExpired, identity.MockProvider{RefreshError: errors.New("error")}, http.StatusFound},
|
||||
{"expired,save error", nil, &mstore.Store{SaveError: errors.New("error"), Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}}, sessions.ErrExpired, identity.MockProvider{RefreshResponse: sessions.State{AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Minute)}}}, http.StatusFound},
|
||||
{"expired XHR,refresh error", map[string]string{"X-Requested-With": "XmlHttpRequest"}, &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}}, sessions.ErrExpired, identity.MockProvider{RefreshError: errors.New("error")}, http.StatusUnauthorized},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
|
@ -384,11 +385,11 @@ func TestAuthenticate_RefreshAPI(t *testing.T) {
|
|||
|
||||
wantStatus int
|
||||
}{
|
||||
{"good", &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, nil, identity.MockProvider{RefreshResponse: sessions.State{AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Minute)}}}, mock.Encoder{MarshalResponse: []byte("ok")}, mock.Encoder{MarshalResponse: []byte("ok")}, http.StatusOK},
|
||||
{"refresh error", &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, nil, identity.MockProvider{RefreshError: errors.New("error")}, mock.Encoder{MarshalResponse: []byte("ok")}, mock.Encoder{MarshalResponse: []byte("ok")}, http.StatusInternalServerError},
|
||||
{"session is not refreshable error", &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, errors.New("session error"), identity.MockProvider{RefreshResponse: sessions.State{AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Minute)}}}, mock.Encoder{MarshalResponse: []byte("ok")}, mock.Encoder{MarshalResponse: []byte("ok")}, http.StatusBadRequest},
|
||||
{"secret encoder failed", &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, nil, identity.MockProvider{RefreshResponse: sessions.State{AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Minute)}}}, mock.Encoder{MarshalError: errors.New("error")}, mock.Encoder{MarshalResponse: []byte("ok")}, http.StatusInternalServerError},
|
||||
{"shared encoder failed", &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, nil, identity.MockProvider{RefreshResponse: sessions.State{AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Minute)}}}, mock.Encoder{MarshalResponse: []byte("ok")}, mock.Encoder{MarshalError: errors.New("error")}, http.StatusInternalServerError},
|
||||
{"good", &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, nil, identity.MockProvider{RefreshResponse: sessions.State{AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Minute)}}}, mock.Encoder{MarshalResponse: []byte("ok")}, mock.Encoder{MarshalResponse: []byte("ok")}, http.StatusOK},
|
||||
{"refresh error", &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, nil, identity.MockProvider{RefreshError: errors.New("error")}, mock.Encoder{MarshalResponse: []byte("ok")}, mock.Encoder{MarshalResponse: []byte("ok")}, http.StatusInternalServerError},
|
||||
{"session is not refreshable error", &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, errors.New("session error"), identity.MockProvider{RefreshResponse: sessions.State{AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Minute)}}}, mock.Encoder{MarshalResponse: []byte("ok")}, mock.Encoder{MarshalResponse: []byte("ok")}, http.StatusBadRequest},
|
||||
{"secret encoder failed", &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, nil, identity.MockProvider{RefreshResponse: sessions.State{AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Minute)}}}, mock.Encoder{MarshalError: errors.New("error")}, mock.Encoder{MarshalResponse: []byte("ok")}, http.StatusInternalServerError},
|
||||
{"shared encoder failed", &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, nil, identity.MockProvider{RefreshResponse: sessions.State{AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Minute)}}}, mock.Encoder{MarshalResponse: []byte("ok")}, mock.Encoder{MarshalError: errors.New("error")}, http.StatusInternalServerError},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
|
@ -423,3 +424,54 @@ func TestAuthenticate_RefreshAPI(t *testing.T) {
|
|||
})
|
||||
}
|
||||
}
|
||||
func TestAuthenticate_Refresh(t *testing.T) {
|
||||
t.Parallel()
|
||||
tests := []struct {
|
||||
name string
|
||||
|
||||
session sessions.SessionStore
|
||||
ctxError error
|
||||
|
||||
provider identity.Authenticator
|
||||
secretEncoder encoding.MarshalUnmarshaler
|
||||
sharedEncoder encoding.MarshalUnmarshaler
|
||||
|
||||
wantStatus int
|
||||
}{
|
||||
{"good", &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, nil, identity.MockProvider{RefreshResponse: sessions.State{AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Minute)}}}, mock.Encoder{MarshalResponse: []byte("ok")}, mock.Encoder{MarshalResponse: []byte("ok")}, http.StatusOK},
|
||||
{"bad session", &mstore.Store{}, errors.New("err"), identity.MockProvider{RefreshResponse: sessions.State{AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Minute)}}}, mock.Encoder{MarshalResponse: []byte("ok")}, mock.Encoder{MarshalResponse: []byte("ok")}, http.StatusBadRequest},
|
||||
{"encoder error", &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, nil, identity.MockProvider{RefreshResponse: sessions.State{AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Minute)}}}, mock.Encoder{MarshalResponse: []byte("ok")}, mock.Encoder{MarshalError: errors.New("err")}, http.StatusInternalServerError},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
aead, err := chacha20poly1305.NewX(cryptutil.NewKey())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
a := Authenticate{
|
||||
sharedKey: cryptutil.NewBase64Key(),
|
||||
cookieSecret: cryptutil.NewKey(),
|
||||
RedirectURL: uriParseHelper("https://authenticate.corp.beyondperimeter.com"),
|
||||
encryptedEncoder: tt.secretEncoder,
|
||||
sharedEncoder: tt.sharedEncoder,
|
||||
sessionStore: tt.session,
|
||||
provider: tt.provider,
|
||||
cookieCipher: aead,
|
||||
}
|
||||
r := httptest.NewRequest("GET", "/", nil)
|
||||
state, _ := tt.session.LoadSession(r)
|
||||
ctx := r.Context()
|
||||
ctx = sessions.NewContext(ctx, state, tt.ctxError)
|
||||
r = r.WithContext(ctx)
|
||||
|
||||
r.Header.Set("Accept", "application/json")
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
httputil.HandlerFunc(a.Refresh).ServeHTTP(w, r)
|
||||
if status := w.Code; status != tt.wantStatus {
|
||||
t.Errorf("VerifySession() error = %v, wantErr %v\n%v", w.Result().StatusCode, tt.wantStatus, w.Body.String())
|
||||
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
|
@ -11,12 +11,13 @@ import (
|
|||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/fsnotify/fsnotify"
|
||||
"github.com/pomerium/pomerium/internal/cryptutil"
|
||||
"github.com/pomerium/pomerium/internal/log"
|
||||
"github.com/pomerium/pomerium/internal/telemetry/metrics"
|
||||
"github.com/pomerium/pomerium/internal/urlutil"
|
||||
|
||||
"github.com/cespare/xxhash/v2"
|
||||
"github.com/fsnotify/fsnotify"
|
||||
"github.com/mitchellh/hashstructure"
|
||||
"github.com/spf13/viper"
|
||||
"gopkg.in/yaml.v2"
|
||||
|
@ -477,7 +478,7 @@ type OptionsUpdater interface {
|
|||
|
||||
// Checksum returns the checksum of the current options struct
|
||||
func (o *Options) Checksum() string {
|
||||
hash, err := hashstructure.Hash(o, nil)
|
||||
hash, err := hashstructure.Hash(o, &hashstructure.HashOptions{Hasher: xxhash.New()})
|
||||
if err != nil {
|
||||
log.Warn().Err(err).Msg("config: checksum failure")
|
||||
return "no checksum available"
|
||||
|
|
4
go.mod
4
go.mod
|
@ -6,9 +6,9 @@ require (
|
|||
cloud.google.com/go v0.49.0 // indirect
|
||||
contrib.go.opencensus.io/exporter/jaeger v0.2.0
|
||||
contrib.go.opencensus.io/exporter/prometheus v0.1.0
|
||||
github.com/cespare/xxhash/v2 v2.1.1 // indirect
|
||||
github.com/cespare/xxhash/v2 v2.1.1
|
||||
github.com/fsnotify/fsnotify v1.4.7
|
||||
github.com/golang/groupcache v0.0.0-20191027212112-611e8accdfc9 // indirect
|
||||
github.com/golang/groupcache v0.0.0-20191227052852-215e87163ea7
|
||||
github.com/golang/mock v1.3.1
|
||||
github.com/golang/protobuf v1.3.2
|
||||
github.com/google/go-cmp v0.3.1
|
||||
|
|
5
go.sum
5
go.sum
|
@ -71,6 +71,8 @@ github.com/golang/groupcache v0.0.0-20190702054246-869f871628b6 h1:ZgQEtGgCBiWRM
|
|||
github.com/golang/groupcache v0.0.0-20190702054246-869f871628b6/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc=
|
||||
github.com/golang/groupcache v0.0.0-20191027212112-611e8accdfc9 h1:uHTyIjqVhYRhLbJ8nIiOJHkEZZ+5YoOsAbD3sk82NiE=
|
||||
github.com/golang/groupcache v0.0.0-20191027212112-611e8accdfc9/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc=
|
||||
github.com/golang/groupcache v0.0.0-20191227052852-215e87163ea7 h1:5ZkaAPbicIKTF2I64qf5Fh8Aa83Q/dnOafMYV0OMwjA=
|
||||
github.com/golang/groupcache v0.0.0-20191227052852-215e87163ea7/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc=
|
||||
github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A=
|
||||
github.com/golang/mock v1.2.0 h1:28o5sBqPkBsMGnC6b4MvE2TzSr5/AT4c/1fLqVGIwlk=
|
||||
github.com/golang/mock v1.2.0/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A=
|
||||
|
@ -162,6 +164,7 @@ github.com/prometheus/client_golang v0.9.3/go.mod h1:/TN21ttK/J9q6uSwhBd54HahCDf
|
|||
github.com/prometheus/client_golang v1.0.0/go.mod h1:db9x61etRT2tGnBNRi70OPL5FsnadC4Ky3P0J6CfImo=
|
||||
github.com/prometheus/client_golang v1.2.1 h1:JnMpQc6ppsNgw9QPAGF6Dod479itz7lvlsMzzNayLOI=
|
||||
github.com/prometheus/client_golang v1.2.1/go.mod h1:XMU6Z2MjaRKVu/dC1qupJI9SiNkDYzz3xecMgSW/F+U=
|
||||
github.com/prometheus/client_golang v1.3.0 h1:miYCvYqFXtl/J9FIy8eNpBfYthAEFg+Ys0XyUVEcDsc=
|
||||
github.com/prometheus/client_model v0.0.0-20180712105110-5c3871d89910/go.mod h1:MbSGuTsp3dbXC40dX6PRTWyKYBIrTGTE9sqQNg2J8bo=
|
||||
github.com/prometheus/client_model v0.0.0-20190129233127-fd36f4220a90 h1:S/YWwWx/RA8rT8tKFRuGUZhuA90OyIBpPCXkcbwU8DE=
|
||||
github.com/prometheus/client_model v0.0.0-20190129233127-fd36f4220a90/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA=
|
||||
|
@ -214,6 +217,7 @@ github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA=
|
|||
github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
|
||||
github.com/spf13/viper v1.5.0 h1:GpsTwfsQ27oS/Aha/6d1oD7tpKIqWnOA6tgOX9HHkt4=
|
||||
github.com/spf13/viper v1.5.0/go.mod h1:AkYRkVJF8TkSG/xet6PzXX+l39KhhXa2pdqVSxnTcn4=
|
||||
github.com/spf13/viper v1.6.1 h1:VPZzIkznI1YhVMRi6vNFLHSwhnhReBfgTxIPccpfdZk=
|
||||
github.com/spf13/viper v1.6.1/go.mod h1:t3iDnF5Jlj76alVNuyFBk5oUMCvsrkbvZK0WQdfDi5k=
|
||||
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
||||
github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
||||
|
@ -384,6 +388,7 @@ gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8
|
|||
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33q108Sa+fhmuc+sWQYwY=
|
||||
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI=
|
||||
gopkg.in/ini.v1 v1.51.0 h1:AQvPpx3LzTDM0AjnIRlVFwFFGC+npRopjZxLJj6gdno=
|
||||
gopkg.in/ini.v1 v1.51.0/go.mod h1:pNLf8WUiyNEtQjuu5G5vTm06TEv9tsIgeAvK8hOrP4k=
|
||||
gopkg.in/resty.v1 v1.12.0/go.mod h1:mDo4pnntr5jdWRML875a/NmxYqAlA73dVijT2AXvQQo=
|
||||
gopkg.in/square/go-jose.v2 v2.4.0 h1:0kXPskUMGAXXWJlP05ktEMOV0vmzFQUWw6d+aZJQU8A=
|
||||
|
|
|
@ -1,5 +1,13 @@
|
|||
package mock // import "github.com/pomerium/pomerium/internal/encoding/mock"
|
||||
|
||||
import (
|
||||
"github.com/pomerium/pomerium/internal/encoding"
|
||||
)
|
||||
|
||||
var _ encoding.MarshalUnmarshaler = &Encoder{}
|
||||
var _ encoding.Marshaler = &Encoder{}
|
||||
var _ encoding.Unmarshaler = &Encoder{}
|
||||
|
||||
// Encoder MockCSRFStore is a mock implementation of Cipher.
|
||||
type Encoder struct {
|
||||
MarshalResponse []byte
|
||||
|
|
|
@ -8,23 +8,21 @@ import (
|
|||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"time"
|
||||
|
||||
"go.opencensus.io/plugin/ochttp"
|
||||
)
|
||||
|
||||
// ErrTokenRevoked signifies a token revokation or expiration error
|
||||
var ErrTokenRevoked = errors.New("token expired or revoked")
|
||||
|
||||
var httpClient = &http.Client{
|
||||
Timeout: time.Second * 5,
|
||||
Transport: &http.Transport{
|
||||
Dial: (&net.Dialer{
|
||||
Timeout: 2 * time.Second,
|
||||
}).Dial,
|
||||
TLSHandshakeTimeout: 2 * time.Second,
|
||||
},
|
||||
// DefaultClient avoids leaks by setting an upper limit for timeouts.
|
||||
var DefaultClient = &http.Client{
|
||||
Timeout: 1 * time.Minute,
|
||||
//todo(bdd): incorporate metrics.HTTPMetricsRoundTripper
|
||||
Transport: &ochttp.Transport{},
|
||||
}
|
||||
|
||||
// Client provides a simple helper interface to make HTTP requests
|
||||
|
@ -36,9 +34,11 @@ func Client(ctx context.Context, method, endpoint, userAgent string, headers map
|
|||
case http.MethodGet:
|
||||
// error checking skipped because we are just parsing in
|
||||
// order to make a copy of an existing URL
|
||||
u, _ := url.Parse(endpoint)
|
||||
u.RawQuery = params.Encode()
|
||||
endpoint = u.String()
|
||||
if params != nil {
|
||||
u, _ := url.Parse(endpoint)
|
||||
u.RawQuery = params.Encode()
|
||||
endpoint = u.String()
|
||||
}
|
||||
default:
|
||||
return fmt.Errorf(http.StatusText(http.StatusBadRequest))
|
||||
}
|
||||
|
@ -52,7 +52,7 @@ func Client(ctx context.Context, method, endpoint, userAgent string, headers map
|
|||
req.Header.Set(k, v)
|
||||
}
|
||||
|
||||
resp, err := httpClient.Do(req)
|
||||
resp, err := DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -79,7 +79,6 @@ func Client(ctx context.Context, method, endpoint, userAgent string, headers map
|
|||
return fmt.Errorf(http.StatusText(resp.StatusCode))
|
||||
}
|
||||
}
|
||||
|
||||
if response != nil {
|
||||
err := json.Unmarshal(respBody, &response)
|
||||
if err != nil {
|
||||
|
|
131
internal/sessions/cache/cache_store.go
vendored
Normal file
131
internal/sessions/cache/cache_store.go
vendored
Normal file
|
@ -0,0 +1,131 @@
|
|||
package cache // import "github.com/pomerium/pomerium/internal/sessions/cache"
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"github.com/golang/groupcache"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/encoding"
|
||||
"github.com/pomerium/pomerium/internal/log"
|
||||
"github.com/pomerium/pomerium/internal/sessions"
|
||||
)
|
||||
|
||||
var _ sessions.SessionStore = &Store{}
|
||||
var _ sessions.SessionLoader = &Store{}
|
||||
|
||||
const (
|
||||
defaultQueryParamKey = "ati"
|
||||
)
|
||||
|
||||
// Store implements the session store interface using a distributed cache.
|
||||
type Store struct {
|
||||
name string
|
||||
encoder encoding.Marshaler
|
||||
decoder encoding.Unmarshaler
|
||||
|
||||
cache *groupcache.Group
|
||||
wrappedStore sessions.SessionStore
|
||||
}
|
||||
|
||||
// defaultCacheSize is ~10MB
|
||||
var defaultCacheSize int64 = 10 << 20
|
||||
|
||||
// NewStore creates a new session store built on the distributed caching library
|
||||
// groupcache. On a cache miss, the cache store attempts to fallback to another
|
||||
// SessionStore implementation.
|
||||
func NewStore(enc encoding.MarshalUnmarshaler, wrappedStore sessions.SessionStore, name string) *Store {
|
||||
store := &Store{
|
||||
name: name,
|
||||
encoder: enc,
|
||||
decoder: enc,
|
||||
wrappedStore: wrappedStore,
|
||||
}
|
||||
|
||||
store.cache = groupcache.NewGroup(name, defaultCacheSize, groupcache.GetterFunc(
|
||||
func(ctx context.Context, id string, dest groupcache.Sink) error {
|
||||
// fill the cache with session set as part of the request
|
||||
// context set previously as part of SaveSession.
|
||||
b := fromContext(ctx)
|
||||
if len(b) == 0 {
|
||||
return fmt.Errorf("sessions/cache: cannot fill key %s from ctx", id)
|
||||
}
|
||||
if err := dest.SetBytes(b); err != nil {
|
||||
return fmt.Errorf("sessions/cache: sink error %w", err)
|
||||
}
|
||||
return nil
|
||||
},
|
||||
))
|
||||
|
||||
return store
|
||||
}
|
||||
|
||||
// LoadSession implements SessionLoaders's LoadSession method for cache store.
|
||||
func (s *Store) LoadSession(r *http.Request) (*sessions.State, error) {
|
||||
// look for our cache's key in the default query param
|
||||
sessionID := r.URL.Query().Get(defaultQueryParamKey)
|
||||
if sessionID == "" {
|
||||
// if unset, fallback to default cache store
|
||||
log.FromRequest(r).Debug().Msg("sessions/cache: no query param, trying wrapped loader")
|
||||
return s.wrappedStore.LoadSession(r)
|
||||
}
|
||||
|
||||
var b []byte
|
||||
if err := s.cache.Get(r.Context(), sessionID, groupcache.AllocatingByteSliceSink(&b)); err != nil {
|
||||
log.FromRequest(r).Debug().Err(err).Msg("sessions/cache: miss, trying wrapped loader")
|
||||
return s.wrappedStore.LoadSession(r)
|
||||
}
|
||||
var session sessions.State
|
||||
if err := s.decoder.Unmarshal(b, &session); err != nil {
|
||||
log.FromRequest(r).Error().Err(err).Msg("sessions/cache: unmarshal")
|
||||
return nil, sessions.ErrMalformed
|
||||
}
|
||||
return &session, nil
|
||||
}
|
||||
|
||||
// ClearSession implements SessionStore's ClearSession for the cache store.
|
||||
// Since group cache has no explicit eviction, we just call the wrapped
|
||||
// store's ClearSession method here.
|
||||
func (s *Store) ClearSession(w http.ResponseWriter, r *http.Request) {
|
||||
s.wrappedStore.ClearSession(w, r)
|
||||
}
|
||||
|
||||
// SaveSession implements SessionStore's SaveSession method for cache store.
|
||||
func (s *Store) SaveSession(w http.ResponseWriter, r *http.Request, x interface{}) error {
|
||||
err := s.wrappedStore.SaveSession(w, r, x)
|
||||
if err != nil {
|
||||
return fmt.Errorf("sessions/cache: wrapped store save error %w", err)
|
||||
}
|
||||
|
||||
state, ok := x.(*sessions.State)
|
||||
if !ok {
|
||||
return errors.New("internal/sessions: cannot cache non state type")
|
||||
}
|
||||
|
||||
data, err := s.encoder.Marshal(&state)
|
||||
if err != nil {
|
||||
return fmt.Errorf("sessions/cache: marshal %w", err)
|
||||
}
|
||||
|
||||
ctx := newContext(r.Context(), data)
|
||||
var b []byte
|
||||
return s.cache.Get(ctx, state.AccessTokenID, groupcache.AllocatingByteSliceSink(&b))
|
||||
}
|
||||
|
||||
var sessionCtxKey = &contextKey{"PomeriumCachedSessionBytes"}
|
||||
|
||||
type contextKey struct {
|
||||
name string
|
||||
}
|
||||
|
||||
func newContext(ctx context.Context, b []byte) context.Context {
|
||||
ctx = context.WithValue(ctx, sessionCtxKey, b)
|
||||
return ctx
|
||||
}
|
||||
|
||||
func fromContext(ctx context.Context) []byte {
|
||||
b, _ := ctx.Value(sessionCtxKey).([]byte)
|
||||
return b
|
||||
}
|
133
internal/sessions/cache/cache_store_test.go
vendored
Normal file
133
internal/sessions/cache/cache_store_test.go
vendored
Normal file
|
@ -0,0 +1,133 @@
|
|||
package cache
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/pomerium/pomerium/internal/cryptutil"
|
||||
"github.com/pomerium/pomerium/internal/encoding/ecjson"
|
||||
"github.com/pomerium/pomerium/internal/sessions"
|
||||
"github.com/pomerium/pomerium/internal/sessions/cookie"
|
||||
"gopkg.in/square/go-jose.v2/jwt"
|
||||
)
|
||||
|
||||
func testAuthorizer(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
_, err := sessions.FromContext(r.Context())
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
|
||||
func TestVerifier(t *testing.T) {
|
||||
fnh := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "text/plain; charset=utf-8")
|
||||
fmt.Fprint(w, http.StatusText(http.StatusOK))
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
skipSave bool
|
||||
cacheSize int64
|
||||
state sessions.State
|
||||
|
||||
wantBody string
|
||||
wantStatus int
|
||||
}{
|
||||
{"good", false, 1 << 10, sessions.State{AccessTokenID: cryptutil.NewBase64Key(), Email: "user@pomerium.io", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}, http.StatusText(http.StatusOK), http.StatusOK},
|
||||
{"expired", false, 1 << 10, sessions.State{AccessTokenID: cryptutil.NewBase64Key(), Email: "user@pomerium.io", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}, "internal/sessions: validation failed, token is expired (exp)\n", http.StatusUnauthorized},
|
||||
{"empty", false, 1 << 10, sessions.State{AccessTokenID: "", Email: "user@pomerium.io", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}, "internal/sessions: session is not found\n", http.StatusUnauthorized},
|
||||
{"miss", true, 1 << 10, sessions.State{AccessTokenID: cryptutil.NewBase64Key(), Email: "user@pomerium.io", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}, "internal/sessions: session is not found\n", http.StatusUnauthorized},
|
||||
{"cache eviction", false, 1, sessions.State{AccessTokenID: cryptutil.NewBase64Key(), Email: "user@pomerium.io", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}, "internal/sessions: session is not found\n", http.StatusUnauthorized},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
defaultCacheSize = tt.cacheSize
|
||||
cipher, err := cryptutil.NewAEADCipherFromBase64(cryptutil.NewBase64Key())
|
||||
encoder := ecjson.New(cipher)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
cs, err := cookie.NewStore(&cookie.Options{Name: t.Name()}, encoder)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
cacheStore := NewStore(encoder, cs, t.Name())
|
||||
|
||||
r := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
q := r.URL.Query()
|
||||
|
||||
q.Set(defaultQueryParamKey, tt.state.AccessTokenID)
|
||||
r.URL.RawQuery = q.Encode()
|
||||
r.Header.Set("Accept", "application/json")
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
got := sessions.RetrieveSession(cacheStore)(testAuthorizer((fnh)))
|
||||
|
||||
if !tt.skipSave {
|
||||
cacheStore.SaveSession(w, r, &tt.state)
|
||||
}
|
||||
|
||||
for i := 1; i <= 10; i++ {
|
||||
s := tt.state
|
||||
s.AccessTokenID = cryptutil.NewBase64Key()
|
||||
cacheStore.SaveSession(w, r, s)
|
||||
}
|
||||
|
||||
got.ServeHTTP(w, r)
|
||||
|
||||
gotBody := w.Body.String()
|
||||
gotStatus := w.Result().StatusCode
|
||||
|
||||
if diff := cmp.Diff(gotBody, tt.wantBody); diff != "" {
|
||||
t.Errorf("RetrieveSession() = %v", diff)
|
||||
}
|
||||
if diff := cmp.Diff(gotStatus, tt.wantStatus); diff != "" {
|
||||
t.Errorf("RetrieveSession() = %v", diff)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestStore_SaveSession(t *testing.T) {
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
x interface{}
|
||||
wantErr bool
|
||||
}{
|
||||
{"good", &sessions.State{AccessTokenID: cryptutil.NewBase64Key(), Email: "user@pomerium.io", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}, false},
|
||||
{"bad type", "bad type!", true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
cipher, err := cryptutil.NewAEADCipherFromBase64(cryptutil.NewBase64Key())
|
||||
encoder := ecjson.New(cipher)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
cs, err := cookie.NewStore(&cookie.Options{
|
||||
Name: "_pomerium",
|
||||
}, encoder)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
cacheStore := NewStore(encoder, cs, t.Name())
|
||||
r := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
r.Header.Set("Accept", "application/json")
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
if err := cacheStore.SaveSession(w, r, tt.x); (err != nil) != tt.wantErr {
|
||||
t.Errorf("Store.SaveSession() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
|
@ -1,4 +1,4 @@
|
|||
package sessions // import "github.com/pomerium/pomerium/internal/sessions"
|
||||
package cookie // import "github.com/pomerium/pomerium/internal/sessions/cookie"
|
||||
|
||||
import (
|
||||
"errors"
|
||||
|
@ -8,8 +8,15 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/encoding"
|
||||
"github.com/pomerium/pomerium/internal/sessions"
|
||||
)
|
||||
|
||||
var _ sessions.SessionStore = &Store{}
|
||||
var _ sessions.SessionLoader = &Store{}
|
||||
|
||||
// timeNow is time.Now but pulled out as a variable for tests.
|
||||
var timeNow = time.Now
|
||||
|
||||
const (
|
||||
// ChunkedCanaryByte is the byte value used as a canary prefix to distinguish if
|
||||
// the cookie is multi-part or not. This constant *should not* be valid
|
||||
|
@ -25,8 +32,8 @@ const (
|
|||
MaxNumChunks = 5
|
||||
)
|
||||
|
||||
// CookieStore implements the session store interface for session cookies.
|
||||
type CookieStore struct {
|
||||
// Store implements the session store interface for session cookies.
|
||||
type Store struct {
|
||||
Name string
|
||||
Domain string
|
||||
Expire time.Duration
|
||||
|
@ -37,8 +44,8 @@ type CookieStore struct {
|
|||
decoder encoding.Unmarshaler
|
||||
}
|
||||
|
||||
// CookieOptions holds options for CookieStore
|
||||
type CookieOptions struct {
|
||||
// Options holds options for Store
|
||||
type Options struct {
|
||||
Name string
|
||||
Domain string
|
||||
Expire time.Duration
|
||||
|
@ -46,8 +53,9 @@ type CookieOptions struct {
|
|||
Secure bool
|
||||
}
|
||||
|
||||
// NewCookieStore returns a new session with ciphers for each of the cookie secrets
|
||||
func NewCookieStore(opts *CookieOptions, encoder encoding.MarshalUnmarshaler) (*CookieStore, error) {
|
||||
// NewStore returns a new store that implements the SessionStore interface
|
||||
// using http cookies.
|
||||
func NewStore(opts *Options, encoder encoding.MarshalUnmarshaler) (sessions.SessionStore, error) {
|
||||
cs, err := NewCookieLoader(opts, encoder)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -56,12 +64,13 @@ func NewCookieStore(opts *CookieOptions, encoder encoding.MarshalUnmarshaler) (*
|
|||
return cs, nil
|
||||
}
|
||||
|
||||
// NewCookieLoader returns a new session with ciphers for each of the cookie secrets
|
||||
func NewCookieLoader(opts *CookieOptions, dencoder encoding.Unmarshaler) (*CookieStore, error) {
|
||||
// NewCookieLoader returns a new store that implements the SessionLoader
|
||||
// interface using http cookies.
|
||||
func NewCookieLoader(opts *Options, dencoder encoding.Unmarshaler) (*Store, error) {
|
||||
if dencoder == nil {
|
||||
return nil, fmt.Errorf("internal/sessions: dencoder cannot be nil")
|
||||
}
|
||||
cs, err := newCookieStore(opts)
|
||||
cs, err := newStore(opts)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -69,12 +78,12 @@ func NewCookieLoader(opts *CookieOptions, dencoder encoding.Unmarshaler) (*Cooki
|
|||
return cs, nil
|
||||
}
|
||||
|
||||
func newCookieStore(opts *CookieOptions) (*CookieStore, error) {
|
||||
func newStore(opts *Options) (*Store, error) {
|
||||
if opts.Name == "" {
|
||||
return nil, fmt.Errorf("internal/sessions: cookie name cannot be empty")
|
||||
}
|
||||
|
||||
return &CookieStore{
|
||||
return &Store{
|
||||
Name: opts.Name,
|
||||
Secure: opts.Secure,
|
||||
HTTPOnly: opts.HTTPOnly,
|
||||
|
@ -83,7 +92,7 @@ func newCookieStore(opts *CookieOptions) (*CookieStore, error) {
|
|||
}, nil
|
||||
}
|
||||
|
||||
func (cs *CookieStore) makeCookie(value string) *http.Cookie {
|
||||
func (cs *Store) makeCookie(value string) *http.Cookie {
|
||||
return &http.Cookie{
|
||||
Name: cs.Name,
|
||||
Value: value,
|
||||
|
@ -96,7 +105,7 @@ func (cs *CookieStore) makeCookie(value string) *http.Cookie {
|
|||
}
|
||||
|
||||
// ClearSession clears the session cookie from a request
|
||||
func (cs *CookieStore) ClearSession(w http.ResponseWriter, r *http.Request) {
|
||||
func (cs *Store) ClearSession(w http.ResponseWriter, r *http.Request) {
|
||||
c := cs.makeCookie("")
|
||||
c.MaxAge = -1
|
||||
c.Expires = timeNow().Add(-time.Hour)
|
||||
|
@ -115,51 +124,51 @@ func getCookies(r *http.Request, name string) []*http.Cookie {
|
|||
}
|
||||
|
||||
// LoadSession returns a State from the cookie in the request.
|
||||
func (cs *CookieStore) LoadSession(r *http.Request) (*State, error) {
|
||||
func (cs *Store) LoadSession(r *http.Request) (*sessions.State, error) {
|
||||
cookies := getCookies(r, cs.Name)
|
||||
if len(cookies) == 0 {
|
||||
return nil, ErrNoSessionFound
|
||||
return nil, sessions.ErrNoSessionFound
|
||||
}
|
||||
for _, cookie := range cookies {
|
||||
data := loadChunkedCookie(r, cookie)
|
||||
|
||||
session := &State{}
|
||||
session := &sessions.State{}
|
||||
err := cs.decoder.Unmarshal([]byte(data), session)
|
||||
if err == nil {
|
||||
return session, nil
|
||||
}
|
||||
}
|
||||
return nil, ErrMalformed
|
||||
return nil, sessions.ErrMalformed
|
||||
}
|
||||
|
||||
// SaveSession saves a session state to a request's cookie store.
|
||||
func (cs *CookieStore) SaveSession(w http.ResponseWriter, _ *http.Request, x interface{}) error {
|
||||
func (cs *Store) SaveSession(w http.ResponseWriter, _ *http.Request, x interface{}) error {
|
||||
var value string
|
||||
if cs.encoder != nil {
|
||||
switch v := x.(type) {
|
||||
case []byte:
|
||||
value = string(v)
|
||||
case string:
|
||||
value = v
|
||||
default:
|
||||
if cs.encoder == nil {
|
||||
return errors.New("internal/sessions: cannot save non-string type")
|
||||
}
|
||||
data, err := cs.encoder.Marshal(x)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
value = string(data)
|
||||
} else {
|
||||
switch v := x.(type) {
|
||||
case []byte:
|
||||
value = string(v)
|
||||
case string:
|
||||
value = v
|
||||
default:
|
||||
return errors.New("internal/sessions: cannot save non-string type")
|
||||
}
|
||||
}
|
||||
|
||||
cs.setSessionCookie(w, value)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (cs *CookieStore) setSessionCookie(w http.ResponseWriter, val string) {
|
||||
func (cs *Store) setSessionCookie(w http.ResponseWriter, val string) {
|
||||
cs.setCookie(w, cs.makeCookie(val))
|
||||
}
|
||||
|
||||
func (cs *CookieStore) setCookie(w http.ResponseWriter, cookie *http.Cookie) {
|
||||
func (cs *Store) setCookie(w http.ResponseWriter, cookie *http.Cookie) {
|
||||
if len(cookie.String()) <= MaxChunkSize {
|
||||
http.SetCookie(w, cookie)
|
||||
return
|
||||
|
@ -180,20 +189,26 @@ func (cs *CookieStore) setCookie(w http.ResponseWriter, cookie *http.Cookie) {
|
|||
}
|
||||
|
||||
func loadChunkedCookie(r *http.Request, c *http.Cookie) string {
|
||||
data := c.Value
|
||||
// if the first byte is our canary byte, we need to handle the multipart bit
|
||||
if []byte(c.Value)[0] == ChunkedCanaryByte {
|
||||
var b strings.Builder
|
||||
fmt.Fprintf(&b, "%s", data[1:])
|
||||
for i := 1; i <= MaxNumChunks; i++ {
|
||||
next, err := r.Cookie(fmt.Sprintf("%s_%d", c.Name, i))
|
||||
if err != nil {
|
||||
break // break if we can't find the next cookie
|
||||
}
|
||||
fmt.Fprintf(&b, "%s", next.Value)
|
||||
}
|
||||
data = b.String()
|
||||
if len(c.Value) == 0 {
|
||||
return ""
|
||||
}
|
||||
// if the first byte is our canary byte, we need to handle the multipart bit
|
||||
if []byte(c.Value)[0] != ChunkedCanaryByte {
|
||||
return c.Value
|
||||
}
|
||||
|
||||
data := c.Value
|
||||
var b strings.Builder
|
||||
fmt.Fprintf(&b, "%s", data[1:])
|
||||
for i := 1; i <= MaxNumChunks; i++ {
|
||||
next, err := r.Cookie(fmt.Sprintf("%s_%d", c.Name, i))
|
||||
if err != nil {
|
||||
break // break if we can't find the next cookie
|
||||
}
|
||||
fmt.Fprintf(&b, "%s", next.Value)
|
||||
}
|
||||
data = b.String()
|
||||
|
||||
return data
|
||||
}
|
||||
|
|
@ -1,4 +1,4 @@
|
|||
package sessions // import "github.com/pomerium/pomerium/internal/sessions"
|
||||
package cookie // import "github.com/pomerium/pomerium/internal/sessions/cookie"
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
|
@ -13,12 +13,13 @@ import (
|
|||
"github.com/pomerium/pomerium/internal/encoding"
|
||||
"github.com/pomerium/pomerium/internal/encoding/ecjson"
|
||||
"github.com/pomerium/pomerium/internal/encoding/mock"
|
||||
"github.com/pomerium/pomerium/internal/sessions"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/google/go-cmp/cmp/cmpopts"
|
||||
)
|
||||
|
||||
func TestNewCookieStore(t *testing.T) {
|
||||
func TestNewStore(t *testing.T) {
|
||||
cipher, err := cryptutil.NewAEADCipher(cryptutil.NewKey())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
|
@ -26,28 +27,28 @@ func TestNewCookieStore(t *testing.T) {
|
|||
encoder := ecjson.New(cipher)
|
||||
tests := []struct {
|
||||
name string
|
||||
opts *CookieOptions
|
||||
opts *Options
|
||||
encoder encoding.MarshalUnmarshaler
|
||||
want *CookieStore
|
||||
want sessions.SessionStore
|
||||
wantErr bool
|
||||
}{
|
||||
{"good", &CookieOptions{Name: "_cookie", Secure: true, HTTPOnly: true, Domain: "pomerium.io", Expire: 10 * time.Second}, encoder, &CookieStore{Name: "_cookie", Secure: true, HTTPOnly: true, Domain: "pomerium.io", Expire: 10 * time.Second}, false},
|
||||
{"missing name", &CookieOptions{Name: "", Secure: true, HTTPOnly: true, Domain: "pomerium.io", Expire: 10 * time.Second}, encoder, nil, true},
|
||||
{"missing encoder", &CookieOptions{Name: "_cookie", Secure: true, HTTPOnly: true, Domain: "pomerium.io", Expire: 10 * time.Second}, nil, nil, true},
|
||||
{"good", &Options{Name: "_cookie", Secure: true, HTTPOnly: true, Domain: "pomerium.io", Expire: 10 * time.Second}, encoder, &Store{Name: "_cookie", Secure: true, HTTPOnly: true, Domain: "pomerium.io", Expire: 10 * time.Second}, false},
|
||||
{"missing name", &Options{Name: "", Secure: true, HTTPOnly: true, Domain: "pomerium.io", Expire: 10 * time.Second}, encoder, nil, true},
|
||||
{"missing encoder", &Options{Name: "_cookie", Secure: true, HTTPOnly: true, Domain: "pomerium.io", Expire: 10 * time.Second}, nil, nil, true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := NewCookieStore(tt.opts, tt.encoder)
|
||||
got, err := NewStore(tt.opts, tt.encoder)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("NewCookieStore() error = %v, wantErr %v", err, tt.wantErr)
|
||||
t.Errorf("NewStore() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
cmpOpts := []cmp.Option{
|
||||
cmpopts.IgnoreUnexported(CookieStore{}),
|
||||
cmpopts.IgnoreUnexported(Store{}),
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(got, tt.want, cmpOpts...); diff != "" {
|
||||
t.Errorf("NewCookieStore() = %s", diff)
|
||||
t.Errorf("NewStore() = %s", diff)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
@ -60,14 +61,14 @@ func TestNewCookieLoader(t *testing.T) {
|
|||
encoder := ecjson.New(cipher)
|
||||
tests := []struct {
|
||||
name string
|
||||
opts *CookieOptions
|
||||
opts *Options
|
||||
encoder encoding.MarshalUnmarshaler
|
||||
want *CookieStore
|
||||
want *Store
|
||||
wantErr bool
|
||||
}{
|
||||
{"good", &CookieOptions{Name: "_cookie", Secure: true, HTTPOnly: true, Domain: "pomerium.io", Expire: 10 * time.Second}, encoder, &CookieStore{Name: "_cookie", Secure: true, HTTPOnly: true, Domain: "pomerium.io", Expire: 10 * time.Second}, false},
|
||||
{"missing name", &CookieOptions{Name: "", Secure: true, HTTPOnly: true, Domain: "pomerium.io", Expire: 10 * time.Second}, encoder, nil, true},
|
||||
{"missing encoder", &CookieOptions{Name: "_cookie", Secure: true, HTTPOnly: true, Domain: "pomerium.io", Expire: 10 * time.Second}, nil, nil, true},
|
||||
{"good", &Options{Name: "_cookie", Secure: true, HTTPOnly: true, Domain: "pomerium.io", Expire: 10 * time.Second}, encoder, &Store{Name: "_cookie", Secure: true, HTTPOnly: true, Domain: "pomerium.io", Expire: 10 * time.Second}, false},
|
||||
{"missing name", &Options{Name: "", Secure: true, HTTPOnly: true, Domain: "pomerium.io", Expire: 10 * time.Second}, encoder, nil, true},
|
||||
{"missing encoder", &Options{Name: "_cookie", Secure: true, HTTPOnly: true, Domain: "pomerium.io", Expire: 10 * time.Second}, nil, nil, true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
|
@ -77,7 +78,7 @@ func TestNewCookieLoader(t *testing.T) {
|
|||
return
|
||||
}
|
||||
cmpOpts := []cmp.Option{
|
||||
cmpopts.IgnoreUnexported(CookieStore{}),
|
||||
cmpopts.IgnoreUnexported(Store{}),
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(got, tt.want, cmpOpts...); diff != "" {
|
||||
|
@ -87,7 +88,7 @@ func TestNewCookieLoader(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestCookieStore_SaveSession(t *testing.T) {
|
||||
func TestStore_SaveSession(t *testing.T) {
|
||||
c, err := cryptutil.NewAEADCipher(cryptutil.NewKey())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
|
@ -106,17 +107,17 @@ func TestCookieStore_SaveSession(t *testing.T) {
|
|||
wantErr bool
|
||||
wantLoadErr bool
|
||||
}{
|
||||
{"good", &State{Email: "user@domain.com", User: "user"}, ecjson.New(c), ecjson.New(c), false, false},
|
||||
{"bad cipher", &State{Email: "user@domain.com", User: "user"}, nil, nil, true, true},
|
||||
{"huge cookie", &State{Subject: fmt.Sprintf("%x", hugeString), Email: "user@domain.com", User: "user"}, ecjson.New(c), ecjson.New(c), false, false},
|
||||
{"marshal error", &State{Email: "user@domain.com", User: "user"}, mock.Encoder{MarshalError: errors.New("error")}, ecjson.New(c), true, true},
|
||||
{"nil encoder cannot save non string type", &State{Email: "user@domain.com", User: "user"}, nil, ecjson.New(c), true, true},
|
||||
{"good", &sessions.State{Email: "user@domain.com", User: "user"}, ecjson.New(c), ecjson.New(c), false, false},
|
||||
{"bad cipher", &sessions.State{Email: "user@domain.com", User: "user"}, nil, nil, true, true},
|
||||
{"huge cookie", &sessions.State{Subject: fmt.Sprintf("%x", hugeString), Email: "user@domain.com", User: "user"}, ecjson.New(c), ecjson.New(c), false, false},
|
||||
{"marshal error", &sessions.State{Email: "user@domain.com", User: "user"}, mock.Encoder{MarshalError: errors.New("error")}, ecjson.New(c), true, true},
|
||||
{"nil encoder cannot save non string type", &sessions.State{Email: "user@domain.com", User: "user"}, nil, ecjson.New(c), true, true},
|
||||
{"good marshal string directly", cryptutil.NewBase64Key(), nil, ecjson.New(c), false, true},
|
||||
{"good marshal bytes directly", cryptutil.NewKey(), nil, ecjson.New(c), false, true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
s := &CookieStore{
|
||||
s := &Store{
|
||||
Name: "_pomerium",
|
||||
Secure: true,
|
||||
HTTPOnly: true,
|
||||
|
@ -130,7 +131,7 @@ func TestCookieStore_SaveSession(t *testing.T) {
|
|||
w := httptest.NewRecorder()
|
||||
|
||||
if err := s.SaveSession(w, r, tt.State); (err != nil) != tt.wantErr {
|
||||
t.Errorf("CookieStore.SaveSession() error = %v, wantErr %v", err, tt.wantErr)
|
||||
t.Errorf("Store.SaveSession() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
r = httptest.NewRequest("GET", "/", nil)
|
||||
for _, cookie := range w.Result().Cookies() {
|
||||
|
@ -143,11 +144,11 @@ func TestCookieStore_SaveSession(t *testing.T) {
|
|||
return
|
||||
}
|
||||
cmpOpts := []cmp.Option{
|
||||
cmpopts.IgnoreUnexported(State{}),
|
||||
cmpopts.IgnoreUnexported(sessions.State{}),
|
||||
}
|
||||
if err == nil {
|
||||
if diff := cmp.Diff(state, tt.State, cmpOpts...); diff != "" {
|
||||
t.Errorf("CookieStore.LoadSession() got = %s", diff)
|
||||
t.Errorf("Store.LoadSession() got = %s", diff)
|
||||
}
|
||||
}
|
||||
w = httptest.NewRecorder()
|
90
internal/sessions/cookie/middleware_test.go
Normal file
90
internal/sessions/cookie/middleware_test.go
Normal file
|
@ -0,0 +1,90 @@
|
|||
package cookie // import "github.com/pomerium/pomerium/internal/sessions/cookie"
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/sessions"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/pomerium/pomerium/internal/cryptutil"
|
||||
"github.com/pomerium/pomerium/internal/encoding/ecjson"
|
||||
"gopkg.in/square/go-jose.v2/jwt"
|
||||
)
|
||||
|
||||
func testAuthorizer(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
_, err := sessions.FromContext(r.Context())
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
|
||||
func TestVerifier(t *testing.T) {
|
||||
fnh := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "text/plain; charset=utf-8")
|
||||
fmt.Fprint(w, http.StatusText(http.StatusOK))
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
state sessions.State
|
||||
|
||||
wantBody string
|
||||
wantStatus int
|
||||
}{
|
||||
{"good cookie session", sessions.State{Email: "user@pomerium.io", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}, http.StatusText(http.StatusOK), http.StatusOK},
|
||||
{"expired cookie", sessions.State{Email: "user@pomerium.io", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}, "internal/sessions: validation failed, token is expired (exp)\n", http.StatusUnauthorized},
|
||||
{"malformed cookie", sessions.State{Email: "user@pomerium.io", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}, "internal/sessions: session is malformed\n", http.StatusUnauthorized},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
cipher, err := cryptutil.NewAEADCipherFromBase64(cryptutil.NewBase64Key())
|
||||
encoder := ecjson.New(cipher)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
encSession, err := encoder.Marshal(&tt.state)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if strings.Contains(tt.name, "malformed") {
|
||||
// add some garbage to the end of the string
|
||||
encSession = append(encSession, cryptutil.NewKey()...)
|
||||
}
|
||||
|
||||
cs, err := NewStore(&Options{
|
||||
Name: "_pomerium",
|
||||
}, encoder)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
r := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
r.Header.Set("Accept", "application/json")
|
||||
w := httptest.NewRecorder()
|
||||
r.AddCookie(&http.Cookie{Name: "_pomerium", Value: string(encSession)})
|
||||
|
||||
got := sessions.RetrieveSession(cs)(testAuthorizer((fnh)))
|
||||
got.ServeHTTP(w, r)
|
||||
|
||||
gotBody := w.Body.String()
|
||||
gotStatus := w.Result().StatusCode
|
||||
|
||||
if diff := cmp.Diff(gotBody, tt.wantBody); diff != "" {
|
||||
t.Errorf("RetrieveSession() = %v", diff)
|
||||
}
|
||||
if diff := cmp.Diff(gotStatus, tt.wantStatus); diff != "" {
|
||||
t.Errorf("RetrieveSession() = %v", diff)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
28
internal/sessions/errors.go
Normal file
28
internal/sessions/errors.go
Normal file
|
@ -0,0 +1,28 @@
|
|||
package sessions // import "github.com/pomerium/pomerium/internal/sessions"
|
||||
|
||||
import (
|
||||
"errors"
|
||||
)
|
||||
|
||||
var (
|
||||
// ErrNoSessionFound is the error for when no session is found.
|
||||
ErrNoSessionFound = errors.New("internal/sessions: session is not found")
|
||||
|
||||
// ErrMalformed is the error for when a session is found but is malformed.
|
||||
ErrMalformed = errors.New("internal/sessions: session is malformed")
|
||||
|
||||
// ErrNotValidYet indicates that token is used before time indicated in nbf claim.
|
||||
ErrNotValidYet = errors.New("internal/sessions: validation failed, token not valid yet (nbf)")
|
||||
|
||||
// ErrExpired indicates that token is used after expiry time indicated in exp claim.
|
||||
ErrExpired = errors.New("internal/sessions: validation failed, token is expired (exp)")
|
||||
|
||||
// ErrExpiryRequired indicates that the token does not contain a valid expiry (exp) claim.
|
||||
ErrExpiryRequired = errors.New("internal/sessions: validation failed, token expiry (exp) is required")
|
||||
|
||||
// ErrIssuedInTheFuture indicates that the iat field is in the future.
|
||||
ErrIssuedInTheFuture = errors.New("internal/sessions: validation field, token issued in the future (iat)")
|
||||
|
||||
// ErrInvalidAudience indicated invalid aud claim.
|
||||
ErrInvalidAudience = errors.New("internal/sessions: validation failed, invalid audience claim (aud)")
|
||||
)
|
|
@ -1,35 +1,38 @@
|
|||
package sessions // import "github.com/pomerium/pomerium/internal/sessions"
|
||||
package header // import "github.com/pomerium/pomerium/internal/sessions/header"
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/encoding"
|
||||
"github.com/pomerium/pomerium/internal/sessions"
|
||||
)
|
||||
|
||||
var _ sessions.SessionLoader = &Store{}
|
||||
|
||||
const (
|
||||
defaultAuthHeader = "Authorization"
|
||||
defaultAuthType = "Bearer"
|
||||
)
|
||||
|
||||
// HeaderStore implements the load session store interface using http
|
||||
// Store implements the load session store interface using http
|
||||
// authorization headers.
|
||||
type HeaderStore struct {
|
||||
type Store struct {
|
||||
authHeader string
|
||||
authType string
|
||||
encoder encoding.Unmarshaler
|
||||
}
|
||||
|
||||
// NewHeaderStore returns a new header store for loading sessions from
|
||||
// NewStore returns a new header store for loading sessions from
|
||||
// authorization header as defined in as defined in rfc2617
|
||||
//
|
||||
// NOTA BENE: While most servers do not log Authorization headers by default,
|
||||
// you should ensure no other services are logging or leaking your auth headers.
|
||||
func NewHeaderStore(enc encoding.Unmarshaler, headerType string) *HeaderStore {
|
||||
func NewStore(enc encoding.Unmarshaler, headerType string) *Store {
|
||||
if headerType == "" {
|
||||
headerType = defaultAuthType
|
||||
}
|
||||
return &HeaderStore{
|
||||
return &Store{
|
||||
authHeader: defaultAuthHeader,
|
||||
authType: headerType,
|
||||
encoder: enc,
|
||||
|
@ -37,14 +40,14 @@ func NewHeaderStore(enc encoding.Unmarshaler, headerType string) *HeaderStore {
|
|||
}
|
||||
|
||||
// LoadSession tries to retrieve the token string from the Authorization header.
|
||||
func (as *HeaderStore) LoadSession(r *http.Request) (*State, error) {
|
||||
func (as *Store) LoadSession(r *http.Request) (*sessions.State, error) {
|
||||
cipherText := TokenFromHeader(r, as.authHeader, as.authType)
|
||||
if cipherText == "" {
|
||||
return nil, ErrNoSessionFound
|
||||
return nil, sessions.ErrNoSessionFound
|
||||
}
|
||||
var session State
|
||||
var session sessions.State
|
||||
if err := as.encoder.Unmarshal([]byte(cipherText), &session); err != nil {
|
||||
return nil, ErrMalformed
|
||||
return nil, sessions.ErrMalformed
|
||||
}
|
||||
return &session, nil
|
||||
}
|
90
internal/sessions/header/middleware_test.go
Normal file
90
internal/sessions/header/middleware_test.go
Normal file
|
@ -0,0 +1,90 @@
|
|||
package header // import "github.com/pomerium/pomerium/internal/sessions/header"
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/cryptutil"
|
||||
"github.com/pomerium/pomerium/internal/encoding/ecjson"
|
||||
"github.com/pomerium/pomerium/internal/sessions"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"gopkg.in/square/go-jose.v2/jwt"
|
||||
)
|
||||
|
||||
func testAuthorizer(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
_, err := sessions.FromContext(r.Context())
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
|
||||
func TestVerifier(t *testing.T) {
|
||||
fnh := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "text/plain; charset=utf-8")
|
||||
fmt.Fprint(w, http.StatusText(http.StatusOK))
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
authType string
|
||||
state sessions.State
|
||||
wantBody string
|
||||
wantStatus int
|
||||
}{
|
||||
{"good auth header session", "Bearer ", sessions.State{Email: "user@pomerium.io", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}, http.StatusText(http.StatusOK), http.StatusOK},
|
||||
{"expired auth header", "Bearer ", sessions.State{Email: "user@pomerium.io", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}, "internal/sessions: validation failed, token is expired (exp)\n", http.StatusUnauthorized},
|
||||
{"malformed auth header", "Bearer ", sessions.State{Email: "user@pomerium.io", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}, "internal/sessions: session is malformed\n", http.StatusUnauthorized},
|
||||
{"empty auth header", "Bearer ", sessions.State{Email: "user@pomerium.io", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}, "internal/sessions: session is not found\n", http.StatusUnauthorized},
|
||||
{"bad auth type", "bees ", sessions.State{Email: "user@pomerium.io", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}, "internal/sessions: session is not found\n", http.StatusUnauthorized},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
cipher, err := cryptutil.NewAEADCipherFromBase64(cryptutil.NewBase64Key())
|
||||
encoder := ecjson.New(cipher)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
encSession, err := encoder.Marshal(&tt.state)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if strings.Contains(tt.name, "malformed") {
|
||||
// add some garbage to the end of the string
|
||||
encSession = append(encSession, cryptutil.NewKey()...)
|
||||
}
|
||||
s := NewStore(encoder, "")
|
||||
|
||||
r := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
r.Header.Set("Accept", "application/json")
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
if strings.Contains(tt.name, "empty") {
|
||||
encSession = []byte("")
|
||||
}
|
||||
r.Header.Set("Authorization", tt.authType+string(encSession))
|
||||
|
||||
got := sessions.RetrieveSession(s)(testAuthorizer((fnh)))
|
||||
got.ServeHTTP(w, r)
|
||||
|
||||
gotBody := w.Body.String()
|
||||
gotStatus := w.Result().StatusCode
|
||||
|
||||
if diff := cmp.Diff(gotBody, tt.wantBody); diff != "" {
|
||||
t.Errorf("RetrieveSession() = %v", diff)
|
||||
}
|
||||
if diff := cmp.Diff(gotStatus, tt.wantStatus); diff != "" {
|
||||
t.Errorf("RetrieveSession() = %v", diff)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
|
@ -44,7 +44,7 @@ func retrieveFromRequest(r *http.Request, sessions ...SessionLoader) (*State, er
|
|||
}
|
||||
if state != nil {
|
||||
err := state.Verify(urlutil.StripPort(r.Host))
|
||||
return state, err // N.B.: state is _not_ nil_
|
||||
return state, err // N.B.: state is _not_ nil
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -2,16 +2,14 @@ package sessions
|
|||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/pomerium/pomerium/internal/cryptutil"
|
||||
"github.com/pomerium/pomerium/internal/encoding/ecjson"
|
||||
"gopkg.in/square/go-jose.v2/jwt"
|
||||
)
|
||||
|
||||
|
@ -39,103 +37,6 @@ func TestNewContext(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func testAuthorizer(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
_, err := FromContext(r.Context())
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
|
||||
func TestVerifier(t *testing.T) {
|
||||
fnh := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "text/plain; charset=utf-8")
|
||||
fmt.Fprint(w, http.StatusText(http.StatusOK))
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
// s SessionStore
|
||||
state State
|
||||
|
||||
cookie bool
|
||||
header bool
|
||||
param bool
|
||||
|
||||
wantBody string
|
||||
wantStatus int
|
||||
}{
|
||||
{"good cookie session", State{Email: "user@pomerium.io", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}, true, false, false, http.StatusText(http.StatusOK), http.StatusOK},
|
||||
{"expired cookie", State{Email: "user@pomerium.io", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}, true, false, false, "internal/sessions: validation failed, token is expired (exp)\n", http.StatusUnauthorized},
|
||||
{"malformed cookie", State{Email: "user@pomerium.io", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}, true, false, false, "internal/sessions: session is malformed\n", http.StatusUnauthorized},
|
||||
{"good auth header session", State{Email: "user@pomerium.io", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}, false, true, false, http.StatusText(http.StatusOK), http.StatusOK},
|
||||
{"expired auth header", State{Email: "user@pomerium.io", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}, false, true, false, "internal/sessions: validation failed, token is expired (exp)\n", http.StatusUnauthorized},
|
||||
{"malformed auth header", State{Email: "user@pomerium.io", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}, false, true, false, "internal/sessions: session is malformed\n", http.StatusUnauthorized},
|
||||
{"good auth query param session", State{Email: "user@pomerium.io", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}, false, true, true, http.StatusText(http.StatusOK), http.StatusOK},
|
||||
{"expired auth query param", State{Email: "user@pomerium.io", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}, false, false, true, "internal/sessions: validation failed, token is expired (exp)\n", http.StatusUnauthorized},
|
||||
{"malformed auth query param", State{Email: "user@pomerium.io", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}, false, false, true, "internal/sessions: session is malformed\n", http.StatusUnauthorized},
|
||||
{"no session", State{Email: "user@pomerium.io", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}, false, false, false, "internal/sessions: session is not found\n", http.StatusUnauthorized},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
cipher, err := cryptutil.NewAEADCipherFromBase64(cryptutil.NewBase64Key())
|
||||
encoder := ecjson.New(cipher)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
encSession, err := encoder.Marshal(&tt.state)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if strings.Contains(tt.name, "malformed") {
|
||||
// add some garbage to the end of the string
|
||||
encSession = append(encSession, cryptutil.NewKey()...)
|
||||
}
|
||||
|
||||
cs, err := NewCookieStore(&CookieOptions{
|
||||
Name: "_pomerium",
|
||||
}, encoder)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
as := NewHeaderStore(encoder, "")
|
||||
|
||||
qp := NewQueryParamStore(encoder, "")
|
||||
|
||||
r := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
r.Header.Set("Accept", "application/json")
|
||||
w := httptest.NewRecorder()
|
||||
if tt.cookie {
|
||||
r.AddCookie(&http.Cookie{Name: "_pomerium", Value: string(encSession)})
|
||||
} else if tt.header {
|
||||
r.Header.Set("Authorization", "Bearer "+string(encSession))
|
||||
} else if tt.param {
|
||||
q := r.URL.Query()
|
||||
|
||||
q.Set("pomerium_session", string(encSession))
|
||||
r.URL.RawQuery = q.Encode()
|
||||
}
|
||||
|
||||
got := RetrieveSession(cs, as, qp)(testAuthorizer((fnh)))
|
||||
got.ServeHTTP(w, r)
|
||||
|
||||
gotBody := w.Body.String()
|
||||
gotStatus := w.Result().StatusCode
|
||||
|
||||
if diff := cmp.Diff(gotBody, tt.wantBody); diff != "" {
|
||||
t.Errorf("RetrieveSession() = %v", diff)
|
||||
}
|
||||
if diff := cmp.Diff(gotStatus, tt.wantStatus); diff != "" {
|
||||
t.Errorf("RetrieveSession() = %v", diff)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_contextKey_String(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
|
@ -155,3 +56,80 @@ func Test_contextKey_String(t *testing.T) {
|
|||
})
|
||||
}
|
||||
}
|
||||
|
||||
func testAuthorizer(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
_, err := FromContext(r.Context())
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
|
||||
var _ SessionStore = &store{}
|
||||
|
||||
// Store is a mock implementation of the SessionStore interface
|
||||
type store struct {
|
||||
ResponseSession string
|
||||
Session *State
|
||||
SaveError error
|
||||
LoadError error
|
||||
}
|
||||
|
||||
// ClearSession clears the ResponseSession
|
||||
func (ms *store) ClearSession(http.ResponseWriter, *http.Request) {
|
||||
ms.ResponseSession = ""
|
||||
}
|
||||
|
||||
// LoadSession returns the session and a error
|
||||
func (ms store) LoadSession(*http.Request) (*State, error) {
|
||||
return ms.Session, ms.LoadError
|
||||
}
|
||||
|
||||
// SaveSession returns a save error.
|
||||
func (ms store) SaveSession(http.ResponseWriter, *http.Request, interface{}) error {
|
||||
return ms.SaveError
|
||||
}
|
||||
|
||||
func TestVerifier(t *testing.T) {
|
||||
fnh := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "text/plain; charset=utf-8")
|
||||
fmt.Fprint(w, http.StatusText(http.StatusOK))
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
store store
|
||||
state State
|
||||
wantBody string
|
||||
wantStatus int
|
||||
}{
|
||||
{"empty session", store{}, State{}, "internal/sessions: session is not found\n", 401},
|
||||
{"simple good load", store{Session: &State{Subject: "hi", Expiry: jwt.NewNumericDate(time.Now().Add(time.Second))}}, State{}, "OK", 200},
|
||||
{"empty session", store{LoadError: errors.New("err")}, State{}, "err\n", 401},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
|
||||
r := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
r.Header.Set("Accept", "application/json")
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
got := RetrieveSession(tt.store)(testAuthorizer((fnh)))
|
||||
got.ServeHTTP(w, r)
|
||||
|
||||
gotBody := w.Body.String()
|
||||
gotStatus := w.Result().StatusCode
|
||||
|
||||
if diff := cmp.Diff(gotBody, tt.wantBody); diff != "" {
|
||||
t.Errorf("RetrieveSession() = %v", diff)
|
||||
}
|
||||
if diff := cmp.Diff(gotStatus, tt.wantStatus); diff != "" {
|
||||
t.Errorf("RetrieveSession() = %v", diff)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
33
internal/sessions/mock/mock_store.go
Normal file
33
internal/sessions/mock/mock_store.go
Normal file
|
@ -0,0 +1,33 @@
|
|||
package mock // import "github.com/pomerium/pomerium/internal/sessions/mock"
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/sessions"
|
||||
)
|
||||
|
||||
var _ sessions.SessionStore = &Store{}
|
||||
var _ sessions.SessionLoader = &Store{}
|
||||
|
||||
// Store is a mock implementation of the SessionStore interface
|
||||
type Store struct {
|
||||
ResponseSession string
|
||||
Session *sessions.State
|
||||
SaveError error
|
||||
LoadError error
|
||||
}
|
||||
|
||||
// ClearSession clears the ResponseSession
|
||||
func (ms *Store) ClearSession(http.ResponseWriter, *http.Request) {
|
||||
ms.ResponseSession = ""
|
||||
}
|
||||
|
||||
// LoadSession returns the session and a error
|
||||
func (ms Store) LoadSession(*http.Request) (*sessions.State, error) {
|
||||
return ms.Session, ms.LoadError
|
||||
}
|
||||
|
||||
// SaveSession returns a save error.
|
||||
func (ms Store) SaveSession(http.ResponseWriter, *http.Request, interface{}) error {
|
||||
return ms.SaveError
|
||||
}
|
|
@ -1,26 +1,28 @@
|
|||
package sessions // import "github.com/pomerium/pomerium/internal/sessions"
|
||||
package mock // import "github.com/pomerium/pomerium/internal/sessions/mock"
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/sessions"
|
||||
)
|
||||
|
||||
func TestMockSessionStore(t *testing.T) {
|
||||
func TestStore(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
mockCSRF *MockSessionStore
|
||||
saveSession *State
|
||||
mockCSRF *Store
|
||||
saveSession *sessions.State
|
||||
wantLoadErr bool
|
||||
wantSaveErr bool
|
||||
}{
|
||||
{"basic",
|
||||
&MockSessionStore{
|
||||
&Store{
|
||||
ResponseSession: "test",
|
||||
Session: &State{Subject: "0101"},
|
||||
Session: &sessions.State{Subject: "0101"},
|
||||
SaveError: nil,
|
||||
LoadError: nil,
|
||||
},
|
||||
&State{Subject: "0101"},
|
||||
&sessions.State{Subject: "0101"},
|
||||
false,
|
||||
false},
|
||||
}
|
|
@ -1,28 +0,0 @@
|
|||
package sessions // import "github.com/pomerium/pomerium/internal/sessions"
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
)
|
||||
|
||||
// MockSessionStore is a mock implementation of the SessionStore interface
|
||||
type MockSessionStore struct {
|
||||
ResponseSession string
|
||||
Session *State
|
||||
SaveError error
|
||||
LoadError error
|
||||
}
|
||||
|
||||
// ClearSession clears the ResponseSession
|
||||
func (ms *MockSessionStore) ClearSession(http.ResponseWriter, *http.Request) {
|
||||
ms.ResponseSession = ""
|
||||
}
|
||||
|
||||
// LoadSession returns the session and a error
|
||||
func (ms MockSessionStore) LoadSession(*http.Request) (*State, error) {
|
||||
return ms.Session, ms.LoadError
|
||||
}
|
||||
|
||||
// SaveSession returns a save error.
|
||||
func (ms MockSessionStore) SaveSession(http.ResponseWriter, *http.Request, interface{}) error {
|
||||
return ms.SaveError
|
||||
}
|
92
internal/sessions/queryparam/middleware_test.go
Normal file
92
internal/sessions/queryparam/middleware_test.go
Normal file
|
@ -0,0 +1,92 @@
|
|||
package queryparam // import "github.com/pomerium/pomerium/internal/sessions/queryparam"
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/cryptutil"
|
||||
"github.com/pomerium/pomerium/internal/encoding/ecjson"
|
||||
"github.com/pomerium/pomerium/internal/sessions"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"gopkg.in/square/go-jose.v2/jwt"
|
||||
)
|
||||
|
||||
func testAuthorizer(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
_, err := sessions.FromContext(r.Context())
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
|
||||
func TestVerifier(t *testing.T) {
|
||||
fnh := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "text/plain; charset=utf-8")
|
||||
fmt.Fprint(w, http.StatusText(http.StatusOK))
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
state sessions.State
|
||||
|
||||
wantBody string
|
||||
wantStatus int
|
||||
}{
|
||||
{"good auth query param session", sessions.State{Email: "user@pomerium.io", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}, http.StatusText(http.StatusOK), http.StatusOK},
|
||||
{"expired auth query param", sessions.State{Email: "user@pomerium.io", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}, "internal/sessions: validation failed, token is expired (exp)\n", http.StatusUnauthorized},
|
||||
{"malformed auth query param", sessions.State{Email: "user@pomerium.io", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}, "internal/sessions: session is malformed\n", http.StatusUnauthorized},
|
||||
{"empty auth query param", sessions.State{Email: "user@pomerium.io", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}, "internal/sessions: session is not found\n", http.StatusUnauthorized},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
cipher, err := cryptutil.NewAEADCipherFromBase64(cryptutil.NewBase64Key())
|
||||
encoder := ecjson.New(cipher)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
encSession, err := encoder.Marshal(&tt.state)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if strings.Contains(tt.name, "malformed") {
|
||||
// add some garbage to the end of the string
|
||||
encSession = append(encSession, cryptutil.NewKey()...)
|
||||
}
|
||||
|
||||
s := NewStore(encoder, "")
|
||||
|
||||
r := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
r.Header.Set("Accept", "application/json")
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
q := r.URL.Query()
|
||||
if strings.Contains(tt.name, "empty") {
|
||||
encSession = []byte("")
|
||||
}
|
||||
q.Set("pomerium_session", string(encSession))
|
||||
r.URL.RawQuery = q.Encode()
|
||||
|
||||
got := sessions.RetrieveSession(s)(testAuthorizer((fnh)))
|
||||
got.ServeHTTP(w, r)
|
||||
|
||||
gotBody := w.Body.String()
|
||||
gotStatus := w.Result().StatusCode
|
||||
|
||||
if diff := cmp.Diff(gotBody, tt.wantBody); diff != "" {
|
||||
t.Errorf("RetrieveSession() = %v", diff)
|
||||
}
|
||||
if diff := cmp.Diff(gotStatus, tt.wantStatus); diff != "" {
|
||||
t.Errorf("RetrieveSession() = %v", diff)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
|
@ -1,33 +1,37 @@
|
|||
package sessions // import "github.com/pomerium/pomerium/internal/sessions"
|
||||
package queryparam // import "github.com/pomerium/pomerium/internal/sessions/queryparam"
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/encoding"
|
||||
"github.com/pomerium/pomerium/internal/sessions"
|
||||
)
|
||||
|
||||
var _ sessions.SessionStore = &Store{}
|
||||
var _ sessions.SessionLoader = &Store{}
|
||||
|
||||
const (
|
||||
defaultQueryParamKey = "pomerium_session"
|
||||
)
|
||||
|
||||
// QueryParamStore implements the load session store interface using http
|
||||
// Store implements the load session store interface using http
|
||||
// query strings / query parameters.
|
||||
type QueryParamStore struct {
|
||||
type Store struct {
|
||||
queryParamKey string
|
||||
encoder encoding.Marshaler
|
||||
decoder encoding.Unmarshaler
|
||||
}
|
||||
|
||||
// NewQueryParamStore returns a new query param store for loading sessions from
|
||||
// NewStore returns a new query param store for loading sessions from
|
||||
// query strings / query parameters.
|
||||
//
|
||||
// NOTA BENE: By default, most servers _DO_ log query params, the leaking or
|
||||
// accidental logging of which should be considered a security issue.
|
||||
func NewQueryParamStore(enc encoding.MarshalUnmarshaler, qp string) *QueryParamStore {
|
||||
func NewStore(enc encoding.MarshalUnmarshaler, qp string) *Store {
|
||||
if qp == "" {
|
||||
qp = defaultQueryParamKey
|
||||
}
|
||||
return &QueryParamStore{
|
||||
return &Store{
|
||||
queryParamKey: qp,
|
||||
encoder: enc,
|
||||
decoder: enc,
|
||||
|
@ -35,27 +39,27 @@ func NewQueryParamStore(enc encoding.MarshalUnmarshaler, qp string) *QueryParamS
|
|||
}
|
||||
|
||||
// LoadSession tries to retrieve the token string from URL query parameters.
|
||||
func (qp *QueryParamStore) LoadSession(r *http.Request) (*State, error) {
|
||||
func (qp *Store) LoadSession(r *http.Request) (*sessions.State, error) {
|
||||
cipherText := r.URL.Query().Get(qp.queryParamKey)
|
||||
if cipherText == "" {
|
||||
return nil, ErrNoSessionFound
|
||||
return nil, sessions.ErrNoSessionFound
|
||||
}
|
||||
var session State
|
||||
var session sessions.State
|
||||
if err := qp.decoder.Unmarshal([]byte(cipherText), &session); err != nil {
|
||||
return nil, ErrMalformed
|
||||
return nil, sessions.ErrMalformed
|
||||
}
|
||||
return &session, nil
|
||||
}
|
||||
|
||||
// ClearSession clears the session cookie from a request's query param key `pomerium_session`.
|
||||
func (qp *QueryParamStore) ClearSession(w http.ResponseWriter, r *http.Request) {
|
||||
func (qp *Store) ClearSession(w http.ResponseWriter, r *http.Request) {
|
||||
params := r.URL.Query()
|
||||
params.Del(qp.queryParamKey)
|
||||
r.URL.RawQuery = params.Encode()
|
||||
}
|
||||
|
||||
// SaveSession sets a session to a request's query param key `pomerium_session`
|
||||
func (qp *QueryParamStore) SaveSession(w http.ResponseWriter, r *http.Request, x interface{}) error {
|
||||
func (qp *Store) SaveSession(w http.ResponseWriter, r *http.Request, x interface{}) error {
|
||||
data, err := qp.encoder.Marshal(x)
|
||||
if err != nil {
|
||||
return err
|
|
@ -1,4 +1,4 @@
|
|||
package sessions
|
||||
package queryparam // import "github.com/pomerium/pomerium/internal/sessions/queryparam"
|
||||
|
||||
import (
|
||||
"errors"
|
||||
|
@ -9,39 +9,40 @@ import (
|
|||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/pomerium/pomerium/internal/encoding"
|
||||
"github.com/pomerium/pomerium/internal/encoding/mock"
|
||||
"github.com/pomerium/pomerium/internal/sessions"
|
||||
)
|
||||
|
||||
func TestNewQueryParamStore(t *testing.T) {
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
State *State
|
||||
State *sessions.State
|
||||
|
||||
enc encoding.MarshalUnmarshaler
|
||||
qp string
|
||||
wantErr bool
|
||||
wantURL *url.URL
|
||||
}{
|
||||
{"simple good", &State{Email: "user@domain.com", User: "user"}, mock.Encoder{MarshalResponse: []byte("ok")}, "", false, &url.URL{Path: "/", RawQuery: "pomerium_session=ok"}},
|
||||
{"marshall error", &State{Email: "user@domain.com", User: "user"}, mock.Encoder{MarshalError: errors.New("error")}, "", true, &url.URL{Path: "/"}},
|
||||
{"simple good", &sessions.State{Email: "user@domain.com", User: "user"}, mock.Encoder{MarshalResponse: []byte("ok")}, "", false, &url.URL{Path: "/", RawQuery: "pomerium_session=ok"}},
|
||||
{"marshall error", &sessions.State{Email: "user@domain.com", User: "user"}, mock.Encoder{MarshalError: errors.New("error")}, "", true, &url.URL{Path: "/"}},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := NewQueryParamStore(tt.enc, tt.qp)
|
||||
got := NewStore(tt.enc, tt.qp)
|
||||
|
||||
r := httptest.NewRequest("GET", "/", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
if err := got.SaveSession(w, r, tt.State); (err != nil) != tt.wantErr {
|
||||
t.Errorf("NewQueryParamStore.SaveSession() error = %v, wantErr %v", err, tt.wantErr)
|
||||
t.Errorf("NewStore.SaveSession() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(r.URL, tt.wantURL); diff != "" {
|
||||
t.Errorf("NewQueryParamStore() = %v", diff)
|
||||
t.Errorf("NewStore() = %v", diff)
|
||||
}
|
||||
got.ClearSession(w, r)
|
||||
if diff := cmp.Diff(r.URL, &url.URL{Path: "/"}); diff != "" {
|
||||
t.Errorf("NewQueryParamStore() = %v", diff)
|
||||
t.Errorf("NewStore() = %v", diff)
|
||||
}
|
||||
})
|
||||
}
|
|
@ -6,6 +6,8 @@ import (
|
|||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/cespare/xxhash/v2"
|
||||
"github.com/mitchellh/hashstructure"
|
||||
oidc "github.com/pomerium/go-oidc"
|
||||
"golang.org/x/oauth2"
|
||||
"gopkg.in/square/go-jose.v2/jwt"
|
||||
|
@ -51,7 +53,8 @@ type State struct {
|
|||
// programatic access.
|
||||
Programmatic bool `json:"programatic"`
|
||||
|
||||
AccessToken *oauth2.Token `json:"access_token,omitempty"`
|
||||
AccessToken *oauth2.Token `json:"act,omitempty"`
|
||||
AccessTokenID string `json:"ati,omitempty"`
|
||||
|
||||
idToken *oidc.IDToken
|
||||
}
|
||||
|
@ -73,7 +76,7 @@ func NewStateFromTokens(idToken *oidc.IDToken, accessToken *oauth2.Token, audien
|
|||
s.Audience = []string{audience}
|
||||
s.idToken = idToken
|
||||
s.AccessToken = accessToken
|
||||
|
||||
s.AccessTokenID = s.accessTokenHash()
|
||||
return s, nil
|
||||
}
|
||||
|
||||
|
@ -95,6 +98,7 @@ func (s *State) UpdateState(idToken *oidc.IDToken, accessToken *oauth2.Token) er
|
|||
}
|
||||
s.Audience = audience
|
||||
s.Expiry = jwt.NewNumericDate(accessToken.Expiry)
|
||||
s.AccessTokenID = s.accessTokenHash()
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -173,3 +177,13 @@ func (s *State) SetImpersonation(email, groups string) {
|
|||
s.ImpersonateGroups = strings.Split(groups, ",")
|
||||
}
|
||||
}
|
||||
|
||||
func (s *State) accessTokenHash() string {
|
||||
hash, err := hashstructure.Hash(
|
||||
s.AccessToken,
|
||||
&hashstructure.HashOptions{Hasher: xxhash.New()})
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
return fmt.Sprintf("%x", hash)
|
||||
}
|
||||
|
|
|
@ -124,3 +124,26 @@ func TestState_RouteSession(t *testing.T) {
|
|||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestState_accessTokenHash(t *testing.T) {
|
||||
t.Parallel()
|
||||
tests := []struct {
|
||||
name string
|
||||
state State
|
||||
want string
|
||||
}{
|
||||
{"empty access token", State{}, "34c96acdcadb1bbb"},
|
||||
{"no change to access token", State{Subject: "test"}, "34c96acdcadb1bbb"},
|
||||
{"empty oauth2 token", State{AccessToken: &oauth2.Token{}}, "bbd82197d215198f"},
|
||||
{"refresh token a", State{AccessToken: &oauth2.Token{RefreshToken: "a"}}, "76316ac79b301bd6"},
|
||||
{"refresh token b", State{AccessToken: &oauth2.Token{RefreshToken: "b"}}, "fab7cb29e50161f1"},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
s := &tt.state
|
||||
if got := s.accessTokenHash(); got != tt.want {
|
||||
t.Errorf("State.accessTokenHash() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,39 +1,17 @@
|
|||
package sessions // import "github.com/pomerium/pomerium/internal/sessions"
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
var (
|
||||
// ErrNoSessionFound is the error for when no session is found.
|
||||
ErrNoSessionFound = errors.New("internal/sessions: session is not found")
|
||||
|
||||
// ErrMalformed is the error for when a session is found but is malformed.
|
||||
ErrMalformed = errors.New("internal/sessions: session is malformed")
|
||||
|
||||
// ErrNotValidYet indicates that token is used before time indicated in nbf claim.
|
||||
ErrNotValidYet = errors.New("internal/sessions: validation failed, token not valid yet (nbf)")
|
||||
|
||||
// ErrExpired indicates that token is used after expiry time indicated in exp claim.
|
||||
ErrExpired = errors.New("internal/sessions: validation failed, token is expired (exp)")
|
||||
|
||||
// ErrIssuedInTheFuture indicates that the iat field is in the future.
|
||||
ErrIssuedInTheFuture = errors.New("internal/sessions: validation field, token issued in the future (iat)")
|
||||
|
||||
// ErrInvalidAudience indicated invalid aud claim.
|
||||
ErrInvalidAudience = errors.New("internal/sessions: validation failed, invalid audience claim (aud)")
|
||||
)
|
||||
|
||||
// SessionStore has the functions for setting, getting, and clearing the Session cookie
|
||||
// SessionStore defines an interface for loading, saving, and clearing a session.
|
||||
type SessionStore interface {
|
||||
ClearSession(http.ResponseWriter, *http.Request)
|
||||
SessionLoader
|
||||
ClearSession(http.ResponseWriter, *http.Request)
|
||||
SaveSession(http.ResponseWriter, *http.Request, interface{}) error
|
||||
}
|
||||
|
||||
// SessionLoader is implemented by any struct that loads a pomerium session
|
||||
// given a request, and returns a user state.
|
||||
// SessionLoader defines an interface for loading a session.
|
||||
type SessionLoader interface {
|
||||
LoadSession(*http.Request) (*State, error)
|
||||
}
|
||||
|
|
|
@ -34,10 +34,10 @@ var (
|
|||
// DefaultViews are a set of default views to view HTTP and GRPC metrics.
|
||||
var (
|
||||
DefaultViews = [][]*view.View{
|
||||
GRPCServerViews,
|
||||
HTTPServerViews,
|
||||
GRPCClientViews,
|
||||
GRPCServerViews,
|
||||
HTTPClientViews,
|
||||
HTTPServerViews,
|
||||
InfoViews,
|
||||
}
|
||||
)
|
||||
|
|
|
@ -9,14 +9,16 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"gopkg.in/square/go-jose.v2/jwt"
|
||||
|
||||
"github.com/pomerium/pomerium/config"
|
||||
"github.com/pomerium/pomerium/internal/encoding"
|
||||
"github.com/pomerium/pomerium/internal/encoding/mock"
|
||||
"github.com/pomerium/pomerium/internal/httputil"
|
||||
"github.com/pomerium/pomerium/internal/sessions"
|
||||
mstore "github.com/pomerium/pomerium/internal/sessions/mock"
|
||||
"github.com/pomerium/pomerium/internal/urlutil"
|
||||
"github.com/pomerium/pomerium/proxy/clients"
|
||||
"gopkg.in/square/go-jose.v2/jwt"
|
||||
)
|
||||
|
||||
func TestProxy_ForwardAuth(t *testing.T) {
|
||||
|
@ -40,29 +42,29 @@ func TestProxy_ForwardAuth(t *testing.T) {
|
|||
wantStatus int
|
||||
wantBody string
|
||||
}{
|
||||
{"good redirect not required", opts, nil, http.MethodGet, nil, nil, "https://some.domain.example/", "https://some.domain.example", &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusOK, "Access to some.domain.example is allowed."},
|
||||
{"good verify only, no redirect", opts, nil, http.MethodGet, nil, nil, "https://some.domain.example/verify", "https://some.domain.example", &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusOK, ""},
|
||||
{"good redirect not required", opts, nil, http.MethodGet, nil, nil, "/", "https://some.domain.example", &mock.Encoder{}, &sessions.MockSessionStore{LoadError: sessions.ErrInvalidAudience}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusUnauthorized, "{\"Status\":401,\"Error\":\"Unauthorized: internal/sessions: validation failed, invalid audience claim (aud)\"}\n"},
|
||||
{"bad naked domain uri", opts, nil, http.MethodGet, nil, nil, "https://some.domain.example/", "a.naked.domain", &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusBadRequest, "{\"Status\":400,\"Error\":\"Bad Request: a.naked.domain url does contain a valid scheme\"}\n"},
|
||||
{"bad naked domain uri verify only", opts, nil, http.MethodGet, nil, nil, "https://some.domain.example/verify", "a.naked.domain", &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusBadRequest, "{\"Status\":400,\"Error\":\"Bad Request: a.naked.domain url does contain a valid scheme\"}\n"},
|
||||
{"bad empty verification uri", opts, nil, http.MethodGet, nil, nil, "https://some.domain.example/", " ", &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusBadRequest, "{\"Status\":400,\"Error\":\"Bad Request: %20 url does contain a valid scheme\"}\n"},
|
||||
{"bad empty verification uri verify only", opts, nil, http.MethodGet, nil, nil, "https://some.domain.example/verify", " ", &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusBadRequest, "{\"Status\":400,\"Error\":\"Bad Request: %20 url does contain a valid scheme\"}\n"},
|
||||
{"not authorized", opts, nil, http.MethodGet, nil, nil, "https://some.domain.example/", "https://some.domain.example", &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: false}, http.StatusUnauthorized, "{\"Status\":401,\"Error\":\"Unauthorized: user@test.example is not authorized for some.domain.example\"}\n"},
|
||||
{"not authorized verify endpoint", opts, nil, http.MethodGet, nil, nil, "https://some.domain.example/verify", "https://some.domain.example", &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: false}, http.StatusUnauthorized, "{\"Status\":401,\"Error\":\"Unauthorized: user@test.example is not authorized for some.domain.example\"}\n"},
|
||||
{"not authorized expired, redirect to auth", opts, sessions.ErrExpired, http.MethodGet, nil, nil, "https://some.domain.example/", "https://some.domain.example", &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: false}, http.StatusFound, ""},
|
||||
{"not authorized expired, don't redirect!", opts, sessions.ErrExpired, http.MethodGet, nil, nil, "https://some.domain.example/verify", "https://some.domain.example", &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: false}, http.StatusUnauthorized, "{\"Status\":401,\"Error\":\"Unauthorized: internal/sessions: validation failed, token is expired (exp)\"}\n"},
|
||||
{"not authorized because of error", opts, nil, http.MethodGet, nil, nil, "https://some.domain.example/", "https://some.domain.example", &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeError: errors.New("authz error")}, http.StatusInternalServerError, "{\"Status\":500,\"Error\":\"Internal Server Error: authz error\"}\n"},
|
||||
{"not authorized expired, do not redirect to auth", opts, nil, http.MethodGet, nil, nil, "https://some.domain.example/verify", "https://some.domain.example", &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: false}, http.StatusUnauthorized, "{\"Status\":401,\"Error\":\"Unauthorized: internal/sessions: validation failed, token is expired (exp)\"}\n"},
|
||||
{"not authorized, bad audience request uri", opts, nil, http.MethodGet, nil, nil, "https://some.domain.example/", "https://some.domain.example", &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Audience: []string{"not.domain.example"}, Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusUnauthorized, "{\"Status\":401,\"Error\":\"Unauthorized: internal/sessions: validation failed, invalid audience claim (aud)\"}\n"},
|
||||
{"not authorized, bad audience verify uri", opts, nil, http.MethodGet, nil, nil, "https://some.domain.example/", "https://fwdauth.domain.example", &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Audience: []string{"some.domain.example"}, Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusUnauthorized, "{\"Status\":401,\"Error\":\"Unauthorized: internal/sessions: validation failed, invalid audience claim (aud)\"}\n"},
|
||||
{"good redirect not required", opts, nil, http.MethodGet, nil, nil, "https://some.domain.example/", "https://some.domain.example", &mock.Encoder{}, &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusOK, "Access to some.domain.example is allowed."},
|
||||
{"good verify only, no redirect", opts, nil, http.MethodGet, nil, nil, "https://some.domain.example/verify", "https://some.domain.example", &mock.Encoder{}, &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusOK, ""},
|
||||
{"good redirect not required", opts, nil, http.MethodGet, nil, nil, "/", "https://some.domain.example", &mock.Encoder{}, &mstore.Store{LoadError: sessions.ErrInvalidAudience}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusUnauthorized, "{\"Status\":401,\"Error\":\"Unauthorized: internal/sessions: validation failed, invalid audience claim (aud)\"}\n"},
|
||||
{"bad naked domain uri", opts, nil, http.MethodGet, nil, nil, "https://some.domain.example/", "a.naked.domain", &mock.Encoder{}, &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusBadRequest, "{\"Status\":400,\"Error\":\"Bad Request: a.naked.domain url does contain a valid scheme\"}\n"},
|
||||
{"bad naked domain uri verify only", opts, nil, http.MethodGet, nil, nil, "https://some.domain.example/verify", "a.naked.domain", &mock.Encoder{}, &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusBadRequest, "{\"Status\":400,\"Error\":\"Bad Request: a.naked.domain url does contain a valid scheme\"}\n"},
|
||||
{"bad empty verification uri", opts, nil, http.MethodGet, nil, nil, "https://some.domain.example/", " ", &mock.Encoder{}, &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusBadRequest, "{\"Status\":400,\"Error\":\"Bad Request: %20 url does contain a valid scheme\"}\n"},
|
||||
{"bad empty verification uri verify only", opts, nil, http.MethodGet, nil, nil, "https://some.domain.example/verify", " ", &mock.Encoder{}, &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusBadRequest, "{\"Status\":400,\"Error\":\"Bad Request: %20 url does contain a valid scheme\"}\n"},
|
||||
{"not authorized", opts, nil, http.MethodGet, nil, nil, "https://some.domain.example/", "https://some.domain.example", &mock.Encoder{}, &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: false}, http.StatusUnauthorized, "{\"Status\":401,\"Error\":\"Unauthorized: user@test.example is not authorized for some.domain.example\"}\n"},
|
||||
{"not authorized verify endpoint", opts, nil, http.MethodGet, nil, nil, "https://some.domain.example/verify", "https://some.domain.example", &mock.Encoder{}, &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: false}, http.StatusUnauthorized, "{\"Status\":401,\"Error\":\"Unauthorized: user@test.example is not authorized for some.domain.example\"}\n"},
|
||||
{"not authorized expired, redirect to auth", opts, sessions.ErrExpired, http.MethodGet, nil, nil, "https://some.domain.example/", "https://some.domain.example", &mock.Encoder{}, &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: false}, http.StatusFound, ""},
|
||||
{"not authorized expired, don't redirect!", opts, sessions.ErrExpired, http.MethodGet, nil, nil, "https://some.domain.example/verify", "https://some.domain.example", &mock.Encoder{}, &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: false}, http.StatusUnauthorized, "{\"Status\":401,\"Error\":\"Unauthorized: internal/sessions: validation failed, token is expired (exp)\"}\n"},
|
||||
{"not authorized because of error", opts, nil, http.MethodGet, nil, nil, "https://some.domain.example/", "https://some.domain.example", &mock.Encoder{}, &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeError: errors.New("authz error")}, http.StatusInternalServerError, "{\"Status\":500,\"Error\":\"Internal Server Error: authz error\"}\n"},
|
||||
{"not authorized expired, do not redirect to auth", opts, nil, http.MethodGet, nil, nil, "https://some.domain.example/verify", "https://some.domain.example", &mock.Encoder{}, &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: false}, http.StatusUnauthorized, "{\"Status\":401,\"Error\":\"Unauthorized: internal/sessions: validation failed, token is expired (exp)\"}\n"},
|
||||
{"not authorized, bad audience request uri", opts, nil, http.MethodGet, nil, nil, "https://some.domain.example/", "https://some.domain.example", &mock.Encoder{}, &mstore.Store{Session: &sessions.State{Audience: []string{"not.domain.example"}, Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusUnauthorized, "{\"Status\":401,\"Error\":\"Unauthorized: internal/sessions: validation failed, invalid audience claim (aud)\"}\n"},
|
||||
{"not authorized, bad audience verify uri", opts, nil, http.MethodGet, nil, nil, "https://some.domain.example/", "https://fwdauth.domain.example", &mock.Encoder{}, &mstore.Store{Session: &sessions.State{Audience: []string{"some.domain.example"}, Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusUnauthorized, "{\"Status\":401,\"Error\":\"Unauthorized: internal/sessions: validation failed, invalid audience claim (aud)\"}\n"},
|
||||
// traefik
|
||||
{"good traefik callback", opts, nil, http.MethodGet, map[string]string{httputil.HeaderForwardedURI: "https://some.domain.example?" + urlutil.QuerySessionEncrypted + "=" + goodEncryptionString}, nil, "https://some.domain.example/", "https://some.domain.example", &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusFound, ""},
|
||||
{"bad traefik callback bad session", opts, nil, http.MethodGet, map[string]string{httputil.HeaderForwardedURI: "https://some.domain.example?" + urlutil.QuerySessionEncrypted + "=" + goodEncryptionString + "garbage"}, nil, "https://some.domain.example/", "https://some.domain.example", &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusBadRequest, ""},
|
||||
{"bad traefik callback bad url", opts, nil, http.MethodGet, map[string]string{httputil.HeaderForwardedURI: urlutil.QuerySessionEncrypted + ""}, nil, "https://some.domain.example/", "https://some.domain.example", &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusBadRequest, ""},
|
||||
{"good traefik callback", opts, nil, http.MethodGet, map[string]string{httputil.HeaderForwardedURI: "https://some.domain.example?" + urlutil.QuerySessionEncrypted + "=" + goodEncryptionString}, nil, "https://some.domain.example/", "https://some.domain.example", &mock.Encoder{}, &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusFound, ""},
|
||||
{"bad traefik callback bad session", opts, nil, http.MethodGet, map[string]string{httputil.HeaderForwardedURI: "https://some.domain.example?" + urlutil.QuerySessionEncrypted + "=" + goodEncryptionString + "garbage"}, nil, "https://some.domain.example/", "https://some.domain.example", &mock.Encoder{}, &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusBadRequest, ""},
|
||||
{"bad traefik callback bad url", opts, nil, http.MethodGet, map[string]string{httputil.HeaderForwardedURI: urlutil.QuerySessionEncrypted + ""}, nil, "https://some.domain.example/", "https://some.domain.example", &mock.Encoder{}, &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusBadRequest, ""},
|
||||
// nginx
|
||||
{"good nginx callback redirect", opts, nil, http.MethodGet, nil, map[string]string{urlutil.QueryRedirectURI: "https://some.domain.example/", urlutil.QuerySessionEncrypted: goodEncryptionString}, "https://some.domain.example/", "https://some.domain.example", &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusFound, ""},
|
||||
{"good nginx callback set session okay but return unauthorized", opts, nil, http.MethodGet, nil, map[string]string{urlutil.QueryRedirectURI: "https://some.domain.example/", urlutil.QuerySessionEncrypted: goodEncryptionString}, "https://some.domain.example/verify", "https://some.domain.example", &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusUnauthorized, ""},
|
||||
{"bad nginx callback failed to set sesion", opts, nil, http.MethodGet, nil, map[string]string{urlutil.QueryRedirectURI: "https://some.domain.example/", urlutil.QuerySessionEncrypted: goodEncryptionString + "nope"}, "https://some.domain.example/verify", "https://some.domain.example", &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusBadRequest, ""},
|
||||
{"good nginx callback redirect", opts, nil, http.MethodGet, nil, map[string]string{urlutil.QueryRedirectURI: "https://some.domain.example/", urlutil.QuerySessionEncrypted: goodEncryptionString}, "https://some.domain.example/", "https://some.domain.example", &mock.Encoder{}, &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusFound, ""},
|
||||
{"good nginx callback set session okay but return unauthorized", opts, nil, http.MethodGet, nil, map[string]string{urlutil.QueryRedirectURI: "https://some.domain.example/", urlutil.QuerySessionEncrypted: goodEncryptionString}, "https://some.domain.example/verify", "https://some.domain.example", &mock.Encoder{}, &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusUnauthorized, ""},
|
||||
{"bad nginx callback failed to set sesion", opts, nil, http.MethodGet, nil, map[string]string{urlutil.QueryRedirectURI: "https://some.domain.example/", urlutil.QuerySessionEncrypted: goodEncryptionString + "nope"}, "https://some.domain.example/verify", "https://some.domain.example", &mock.Encoder{}, &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusBadRequest, ""},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
|
|
|
@ -12,6 +12,7 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/cryptutil"
|
||||
mstore "github.com/pomerium/pomerium/internal/sessions/mock"
|
||||
|
||||
"github.com/pomerium/pomerium/config"
|
||||
"github.com/pomerium/pomerium/internal/encoding"
|
||||
|
@ -78,10 +79,10 @@ func TestProxy_UserDashboard(t *testing.T) {
|
|||
wantAdminForm bool
|
||||
wantStatus int
|
||||
}{
|
||||
{"good", nil, opts, http.MethodGet, &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{}, false, http.StatusOK},
|
||||
{"session context error", errors.New("error"), opts, http.MethodGet, &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{}, false, http.StatusInternalServerError},
|
||||
{"want admin form good admin authorization", nil, opts, http.MethodGet, &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{IsAdminResponse: true}, true, http.StatusOK},
|
||||
{"is admin but authorization fails", nil, opts, http.MethodGet, &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{IsAdminError: errors.New("err")}, false, http.StatusInternalServerError},
|
||||
{"good", nil, opts, http.MethodGet, &mock.Encoder{}, &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{}, false, http.StatusOK},
|
||||
{"session context error", errors.New("error"), opts, http.MethodGet, &mock.Encoder{}, &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{}, false, http.StatusInternalServerError},
|
||||
{"want admin form good admin authorization", nil, opts, http.MethodGet, &mock.Encoder{}, &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{IsAdminResponse: true}, true, http.StatusOK},
|
||||
{"is admin but authorization fails", nil, opts, http.MethodGet, &mock.Encoder{}, &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{IsAdminError: errors.New("err")}, false, http.StatusInternalServerError},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
|
@ -135,12 +136,12 @@ func TestProxy_Impersonate(t *testing.T) {
|
|||
authorizer clients.Authorizer
|
||||
wantStatus int
|
||||
}{
|
||||
{"good", false, opts, nil, http.MethodPost, "user@blah.com", "", "", &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example"}}, clients.MockAuthorize{IsAdminResponse: true}, http.StatusFound},
|
||||
{"good", false, opts, errors.New("error"), http.MethodPost, "user@blah.com", "", "", &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example"}}, clients.MockAuthorize{IsAdminResponse: true}, http.StatusInternalServerError},
|
||||
{"session load error", false, opts, nil, http.MethodPost, "user@blah.com", "", "", &mock.Encoder{}, &sessions.MockSessionStore{LoadError: errors.New("err"), Session: &sessions.State{Email: "user@test.example"}}, clients.MockAuthorize{IsAdminResponse: true}, http.StatusFound},
|
||||
{"non admin users rejected", false, opts, nil, http.MethodPost, "user@blah.com", "", "", &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute)), Email: "user@test.example"}}, clients.MockAuthorize{IsAdminResponse: false}, http.StatusForbidden},
|
||||
{"non admin users rejected on error", false, opts, nil, http.MethodPost, "user@blah.com", "", "", &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute)), Email: "user@test.example"}}, clients.MockAuthorize{IsAdminResponse: true, IsAdminError: errors.New("err")}, http.StatusInternalServerError},
|
||||
{"groups", false, opts, nil, http.MethodPost, "user@blah.com", "group1,group2", "", &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute)), Email: "user@test.example"}}, clients.MockAuthorize{IsAdminResponse: true}, http.StatusFound},
|
||||
{"good", false, opts, nil, http.MethodPost, "user@blah.com", "", "", &mock.Encoder{}, &mstore.Store{Session: &sessions.State{Email: "user@test.example"}}, clients.MockAuthorize{IsAdminResponse: true}, http.StatusFound},
|
||||
{"good", false, opts, errors.New("error"), http.MethodPost, "user@blah.com", "", "", &mock.Encoder{}, &mstore.Store{Session: &sessions.State{Email: "user@test.example"}}, clients.MockAuthorize{IsAdminResponse: true}, http.StatusInternalServerError},
|
||||
{"session load error", false, opts, nil, http.MethodPost, "user@blah.com", "", "", &mock.Encoder{}, &mstore.Store{LoadError: errors.New("err"), Session: &sessions.State{Email: "user@test.example"}}, clients.MockAuthorize{IsAdminResponse: true}, http.StatusFound},
|
||||
{"non admin users rejected", false, opts, nil, http.MethodPost, "user@blah.com", "", "", &mock.Encoder{}, &mstore.Store{Session: &sessions.State{Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute)), Email: "user@test.example"}}, clients.MockAuthorize{IsAdminResponse: false}, http.StatusForbidden},
|
||||
{"non admin users rejected on error", false, opts, nil, http.MethodPost, "user@blah.com", "", "", &mock.Encoder{}, &mstore.Store{Session: &sessions.State{Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute)), Email: "user@test.example"}}, clients.MockAuthorize{IsAdminResponse: true, IsAdminError: errors.New("err")}, http.StatusInternalServerError},
|
||||
{"groups", false, opts, nil, http.MethodPost, "user@blah.com", "group1,group2", "", &mock.Encoder{}, &mstore.Store{Session: &sessions.State{Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute)), Email: "user@test.example"}}, clients.MockAuthorize{IsAdminResponse: true}, http.StatusFound},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
|
@ -245,12 +246,12 @@ func TestProxy_Callback(t *testing.T) {
|
|||
wantStatus int
|
||||
wantBody string
|
||||
}{
|
||||
{"good", opts, http.MethodGet, "http", "example.com", "/", nil, map[string]string{urlutil.QueryCallbackURI: "ok", urlutil.QuerySessionEncrypted: goodEncryptionString}, &mock.Encoder{MarshalResponse: []byte("x")}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusFound, ""},
|
||||
{"good programmatic", opts, http.MethodGet, "http", "example.com", "/", nil, map[string]string{urlutil.QueryIsProgrammatic: "true", urlutil.QueryCallbackURI: "ok", urlutil.QuerySessionEncrypted: goodEncryptionString}, &mock.Encoder{MarshalResponse: []byte("x")}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusFound, ""},
|
||||
{"bad decrypt", opts, http.MethodGet, "http", "example.com", "/", nil, map[string]string{urlutil.QuerySessionEncrypted: "KBEjQ9rnCxaAX-GOqexGw9ivEQURqts3zZ2mNGy0wnVa3SbtM399KlBq2nZ-9wM21FfsZX52er4jlmC7kPEKM3P7uZ41zR0zeys1-_74a5tQp-vsf1WXZfRsgVOuBcWPkMiWEoc379JFHxGDudp5VhU8B-dcQt4f3_PtLTHARkuH54io1Va2gNMq4Hiy8sQ1MPGCQeltH_JMzzdDpXdmdusWrXUvCGkba24muvAV06D8XRVJj6Iu9eK94qFnqcHc7wzziEbb8ADBues9dwbtb6jl8vMWz5rN6XvXqA5YpZv_MQZlsrO4oXFFQDevdgB84cX1tVbVu6qZvK_yQBZqzpOjWA9uIaoSENMytoXuWAlFO_sXjswfX8JTNdGwzB7qQRNPqxVG_sM_tzY3QhPm8zqwEzsXG5DokxZfVt2I5WJRUEovFDb4BnK9KFnnkEzLEdMudixVnXeGmTtycgJvoTeTCQRPfDYkcgJ7oKf4tGea-W7z5UAVa2RduJM9ZoM6YtJX7jgDm__PvvqcE0knJUF87XHBzdcOjoDF-CUze9xDJgNBlvPbJqVshKrwoqSYpePSDH9GUCNKxGequW3Ma8GvlFfhwd0rK6IZG-XWkyk0XSWQIGkDSjAvhB1wsOusCCguDjbpVZpaW5MMyTkmx68pl6qlIKT5UCcrVPl4ix5ZEj91mUDF0O1t04haD7VZuLVFXVGmqtFrBKI76sdYN-zkokaa1_chPRTyqMQFlqu_8LD6-RiK3UccGM-dEmnX72i91NP9F9OK0WJr9Cheup1C_P0mjqAO4Cb8oIHm0Oxz_mRqv5QbTGJtb3xwPLPuVjVCiE4gGBcuU2ixpSVf5HUF7y1KicVMCKiX9ATCBtg8sTdQZQnPEtHcHHAvdsnDVwev1LGfqA-Gdvg="}, &mock.Encoder{MarshalResponse: []byte("x")}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusBadRequest, ""},
|
||||
{"bad save session", opts, http.MethodGet, "http", "example.com", "/", nil, map[string]string{urlutil.QuerySessionEncrypted: goodEncryptionString}, &mock.Encoder{MarshalResponse: []byte("x")}, &sessions.MockSessionStore{SaveError: errors.New("hi")}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusBadRequest, ""},
|
||||
{"bad base64", opts, http.MethodGet, "http", "example.com", "/", nil, map[string]string{urlutil.QuerySessionEncrypted: "^"}, &mock.Encoder{MarshalResponse: []byte("x")}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusBadRequest, ""},
|
||||
{"malformed redirect", opts, http.MethodGet, "http", "example.com", "/", nil, nil, &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusBadRequest, ""},
|
||||
{"good", opts, http.MethodGet, "http", "example.com", "/", nil, map[string]string{urlutil.QueryCallbackURI: "ok", urlutil.QuerySessionEncrypted: goodEncryptionString}, &mock.Encoder{MarshalResponse: []byte("x")}, &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusFound, ""},
|
||||
{"good programmatic", opts, http.MethodGet, "http", "example.com", "/", nil, map[string]string{urlutil.QueryIsProgrammatic: "true", urlutil.QueryCallbackURI: "ok", urlutil.QuerySessionEncrypted: goodEncryptionString}, &mock.Encoder{MarshalResponse: []byte("x")}, &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusFound, ""},
|
||||
{"bad decrypt", opts, http.MethodGet, "http", "example.com", "/", nil, map[string]string{urlutil.QuerySessionEncrypted: "KBEjQ9rnCxaAX-GOqexGw9ivEQURqts3zZ2mNGy0wnVa3SbtM399KlBq2nZ-9wM21FfsZX52er4jlmC7kPEKM3P7uZ41zR0zeys1-_74a5tQp-vsf1WXZfRsgVOuBcWPkMiWEoc379JFHxGDudp5VhU8B-dcQt4f3_PtLTHARkuH54io1Va2gNMq4Hiy8sQ1MPGCQeltH_JMzzdDpXdmdusWrXUvCGkba24muvAV06D8XRVJj6Iu9eK94qFnqcHc7wzziEbb8ADBues9dwbtb6jl8vMWz5rN6XvXqA5YpZv_MQZlsrO4oXFFQDevdgB84cX1tVbVu6qZvK_yQBZqzpOjWA9uIaoSENMytoXuWAlFO_sXjswfX8JTNdGwzB7qQRNPqxVG_sM_tzY3QhPm8zqwEzsXG5DokxZfVt2I5WJRUEovFDb4BnK9KFnnkEzLEdMudixVnXeGmTtycgJvoTeTCQRPfDYkcgJ7oKf4tGea-W7z5UAVa2RduJM9ZoM6YtJX7jgDm__PvvqcE0knJUF87XHBzdcOjoDF-CUze9xDJgNBlvPbJqVshKrwoqSYpePSDH9GUCNKxGequW3Ma8GvlFfhwd0rK6IZG-XWkyk0XSWQIGkDSjAvhB1wsOusCCguDjbpVZpaW5MMyTkmx68pl6qlIKT5UCcrVPl4ix5ZEj91mUDF0O1t04haD7VZuLVFXVGmqtFrBKI76sdYN-zkokaa1_chPRTyqMQFlqu_8LD6-RiK3UccGM-dEmnX72i91NP9F9OK0WJr9Cheup1C_P0mjqAO4Cb8oIHm0Oxz_mRqv5QbTGJtb3xwPLPuVjVCiE4gGBcuU2ixpSVf5HUF7y1KicVMCKiX9ATCBtg8sTdQZQnPEtHcHHAvdsnDVwev1LGfqA-Gdvg="}, &mock.Encoder{MarshalResponse: []byte("x")}, &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusBadRequest, ""},
|
||||
{"bad save session", opts, http.MethodGet, "http", "example.com", "/", nil, map[string]string{urlutil.QuerySessionEncrypted: goodEncryptionString}, &mock.Encoder{MarshalResponse: []byte("x")}, &mstore.Store{SaveError: errors.New("hi")}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusBadRequest, ""},
|
||||
{"bad base64", opts, http.MethodGet, "http", "example.com", "/", nil, map[string]string{urlutil.QuerySessionEncrypted: "^"}, &mock.Encoder{MarshalResponse: []byte("x")}, &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusBadRequest, ""},
|
||||
{"malformed redirect", opts, http.MethodGet, "http", "example.com", "/", nil, nil, &mock.Encoder{}, &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusBadRequest, ""},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
|
@ -386,12 +387,12 @@ func TestProxy_ProgrammaticCallback(t *testing.T) {
|
|||
wantStatus int
|
||||
wantBody string
|
||||
}{
|
||||
{"good", opts, http.MethodGet, "http://pomerium.io/", nil, map[string]string{urlutil.QueryCallbackURI: "ok", urlutil.QuerySessionEncrypted: goodEncryptionString}, &mock.Encoder{MarshalResponse: []byte("x")}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusFound, ""},
|
||||
{"good programmatic", opts, http.MethodGet, "http://pomerium.io/", nil, map[string]string{urlutil.QueryIsProgrammatic: "true", urlutil.QueryCallbackURI: "ok", urlutil.QuerySessionEncrypted: goodEncryptionString}, &mock.Encoder{MarshalResponse: []byte("x")}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusFound, ""},
|
||||
{"bad decrypt", opts, http.MethodGet, "http://pomerium.io/", nil, map[string]string{urlutil.QuerySessionEncrypted: goodEncryptionString + cryptutil.NewBase64Key()}, &mock.Encoder{MarshalResponse: []byte("x")}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusBadRequest, ""},
|
||||
{"bad save session", opts, http.MethodGet, "http://pomerium.io/", nil, map[string]string{urlutil.QuerySessionEncrypted: goodEncryptionString}, &mock.Encoder{MarshalResponse: []byte("x")}, &sessions.MockSessionStore{SaveError: errors.New("hi")}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusBadRequest, ""},
|
||||
{"bad base64", opts, http.MethodGet, "http://pomerium.io/", nil, map[string]string{urlutil.QuerySessionEncrypted: "^"}, &mock.Encoder{MarshalResponse: []byte("x")}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusBadRequest, ""},
|
||||
{"malformed redirect", opts, http.MethodGet, "http://pomerium.io/", nil, nil, &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusBadRequest, ""},
|
||||
{"good", opts, http.MethodGet, "http://pomerium.io/", nil, map[string]string{urlutil.QueryCallbackURI: "ok", urlutil.QuerySessionEncrypted: goodEncryptionString}, &mock.Encoder{MarshalResponse: []byte("x")}, &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusFound, ""},
|
||||
{"good programmatic", opts, http.MethodGet, "http://pomerium.io/", nil, map[string]string{urlutil.QueryIsProgrammatic: "true", urlutil.QueryCallbackURI: "ok", urlutil.QuerySessionEncrypted: goodEncryptionString}, &mock.Encoder{MarshalResponse: []byte("x")}, &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusFound, ""},
|
||||
{"bad decrypt", opts, http.MethodGet, "http://pomerium.io/", nil, map[string]string{urlutil.QuerySessionEncrypted: goodEncryptionString + cryptutil.NewBase64Key()}, &mock.Encoder{MarshalResponse: []byte("x")}, &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusBadRequest, ""},
|
||||
{"bad save session", opts, http.MethodGet, "http://pomerium.io/", nil, map[string]string{urlutil.QuerySessionEncrypted: goodEncryptionString}, &mock.Encoder{MarshalResponse: []byte("x")}, &mstore.Store{SaveError: errors.New("hi")}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusBadRequest, ""},
|
||||
{"bad base64", opts, http.MethodGet, "http://pomerium.io/", nil, map[string]string{urlutil.QuerySessionEncrypted: "^"}, &mock.Encoder{MarshalResponse: []byte("x")}, &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusBadRequest, ""},
|
||||
{"malformed redirect", opts, http.MethodGet, "http://pomerium.io/", nil, nil, &mock.Encoder{}, &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusBadRequest, ""},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
|
|
|
@ -1,7 +1,11 @@
|
|||
package proxy // import "github.com/pomerium/pomerium/proxy"
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/encoding"
|
||||
|
@ -30,23 +34,82 @@ func (p *Proxy) AuthenticateSession(next http.Handler) http.Handler {
|
|||
ctx, span := trace.StartSpan(r.Context(), "proxy.AuthenticateSession")
|
||||
defer span.End()
|
||||
|
||||
if s, err := sessions.FromContext(ctx); err != nil {
|
||||
log.FromRequest(r).Debug().Err(err).Msg("proxy: authenticate session")
|
||||
p.sessionStore.ClearSession(w, r)
|
||||
if s != nil && s.Programmatic {
|
||||
return httputil.NewError(http.StatusUnauthorized, err)
|
||||
_, err := sessions.FromContext(ctx)
|
||||
if errors.Is(err, sessions.ErrExpired) {
|
||||
ctx, err = p.refresh(ctx, w, r)
|
||||
if err != nil {
|
||||
log.FromRequest(r).Warn().Err(err).Msg("proxy: refresh failed")
|
||||
return p.redirectToSignin(w, r)
|
||||
}
|
||||
signinURL := *p.authenticateSigninURL
|
||||
q := signinURL.Query()
|
||||
q.Set(urlutil.QueryRedirectURI, urlutil.GetAbsoluteURL(r).String())
|
||||
signinURL.RawQuery = q.Encode()
|
||||
httputil.Redirect(w, r, urlutil.NewSignedURL(p.SharedKey, &signinURL).String(), http.StatusFound)
|
||||
log.FromRequest(r).Info().Msg("proxy: refresh success")
|
||||
} else if err != nil {
|
||||
log.FromRequest(r).Debug().Err(err).Msg("proxy: session state")
|
||||
return p.redirectToSignin(w, r)
|
||||
}
|
||||
p.addPomeriumHeaders(w, r)
|
||||
next.ServeHTTP(w, r.WithContext(ctx))
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func (p *Proxy) refresh(ctx context.Context, w http.ResponseWriter, r *http.Request) (context.Context, error) {
|
||||
ctx, span := trace.StartSpan(ctx, "proxy.AuthenticateSession/refresh")
|
||||
defer span.End()
|
||||
s, err := sessions.FromContext(ctx)
|
||||
if !errors.Is(err, sessions.ErrExpired) || s == nil {
|
||||
return nil, errors.New("proxy: unexpected session state for refresh")
|
||||
}
|
||||
// 1 - build a signed url to call refresh on authenticate service
|
||||
refreshURI := *p.authenticateRefreshURL
|
||||
q := refreshURI.Query()
|
||||
q.Set("ati", s.AccessTokenID) // hash value points to parent token
|
||||
q.Set("aud", urlutil.StripPort(r.Host)) // request's audience, this route
|
||||
refreshURI.RawQuery = q.Encode()
|
||||
signedRefreshURL := urlutil.NewSignedURL(p.SharedKey, &refreshURI).String()
|
||||
|
||||
// 2 - http call to authenticate service
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, signedRefreshURL, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("proxy: backend refresh: new request: %v", err)
|
||||
}
|
||||
res, err := httputil.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("proxy: fetch %v: %w", signedRefreshURL, err)
|
||||
}
|
||||
defer res.Body.Close()
|
||||
jwtBytes, err := ioutil.ReadAll(io.LimitReader(res.Body, 4<<10))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 3 - save refreshed session to the client's session store
|
||||
if err = p.sessionStore.SaveSession(w, r, jwtBytes); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// 4 - add refreshed session to the current request context
|
||||
var state sessions.State
|
||||
if err := p.encoder.Unmarshal(jwtBytes, &state); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := state.Verify(urlutil.StripPort(r.Host)); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return sessions.NewContext(r.Context(), &state, err), nil
|
||||
}
|
||||
|
||||
func (p *Proxy) redirectToSignin(w http.ResponseWriter, r *http.Request) error {
|
||||
s, err := sessions.FromContext(r.Context())
|
||||
if s != nil && err != nil && s.Programmatic {
|
||||
return httputil.NewError(http.StatusUnauthorized, err)
|
||||
}
|
||||
p.sessionStore.ClearSession(w, r)
|
||||
signinURL := *p.authenticateSigninURL
|
||||
q := signinURL.Query()
|
||||
q.Set(urlutil.QueryRedirectURI, urlutil.GetAbsoluteURL(r).String())
|
||||
signinURL.RawQuery = q.Encode()
|
||||
log.FromRequest(r).Debug().Str("url", signinURL.String()).Msg("proxy: redirectToSignin")
|
||||
httputil.Redirect(w, r, urlutil.NewSignedURL(p.SharedKey, &signinURL).String(), http.StatusFound)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *Proxy) addPomeriumHeaders(w http.ResponseWriter, r *http.Request) {
|
||||
|
@ -61,8 +124,8 @@ func (p *Proxy) addPomeriumHeaders(w http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
}
|
||||
|
||||
// AuthorizeSession is middleware to enforce a user is authorized for a request
|
||||
// session state is retrieved from the users's request context.
|
||||
// AuthorizeSession is middleware to enforce a user is authorized for a request.
|
||||
// Session state is retrieved from the users's request context.
|
||||
func (p *Proxy) AuthorizeSession(next http.Handler) http.Handler {
|
||||
return httputil.HandlerFunc(func(w http.ResponseWriter, r *http.Request) error {
|
||||
ctx, span := trace.StartSpan(r.Context(), "proxy.AuthorizeSession")
|
||||
|
|
|
@ -10,10 +10,14 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/pomerium/pomerium/internal/identity"
|
||||
"github.com/pomerium/pomerium/internal/sessions"
|
||||
"github.com/pomerium/pomerium/proxy/clients"
|
||||
"gopkg.in/square/go-jose.v2/jwt"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/encoding"
|
||||
"github.com/pomerium/pomerium/internal/encoding/mock"
|
||||
"github.com/pomerium/pomerium/internal/identity"
|
||||
"github.com/pomerium/pomerium/internal/sessions"
|
||||
mstore "github.com/pomerium/pomerium/internal/sessions/mock"
|
||||
)
|
||||
|
||||
func TestProxy_AuthenticateSession(t *testing.T) {
|
||||
|
@ -30,24 +34,39 @@ func TestProxy_AuthenticateSession(t *testing.T) {
|
|||
session sessions.SessionStore
|
||||
ctxError error
|
||||
provider identity.Authenticator
|
||||
encoder encoding.MarshalUnmarshaler
|
||||
refreshURL string
|
||||
|
||||
wantStatus int
|
||||
}{
|
||||
{"good", false, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Second))}}, nil, identity.MockProvider{}, http.StatusOK},
|
||||
{"invalid session", false, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Second))}}, errors.New("hi"), identity.MockProvider{}, http.StatusFound},
|
||||
{"expired", false, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}}, sessions.ErrExpired, identity.MockProvider{}, http.StatusFound},
|
||||
{"expired and programmatic", false, &sessions.MockSessionStore{Session: &sessions.State{Programmatic: true, Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}}, sessions.ErrExpired, identity.MockProvider{}, http.StatusUnauthorized},
|
||||
{"invalid session and programmatic", false, &sessions.MockSessionStore{Session: &sessions.State{Programmatic: true, Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Second))}}, errors.New("hi"), identity.MockProvider{}, http.StatusUnauthorized},
|
||||
{"good", false, &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Second))}}, nil, identity.MockProvider{}, &mock.Encoder{}, "", http.StatusOK},
|
||||
{"invalid session", false, &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Second))}}, errors.New("hi"), identity.MockProvider{}, &mock.Encoder{}, "", http.StatusFound},
|
||||
{"expired", false, &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}}, sessions.ErrExpired, identity.MockProvider{}, &mock.Encoder{}, "", http.StatusOK},
|
||||
{"expired and programmatic", false, &mstore.Store{Session: &sessions.State{Programmatic: true, Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}}, sessions.ErrExpired, identity.MockProvider{}, &mock.Encoder{}, "", http.StatusOK},
|
||||
{"invalid session and programmatic", false, &mstore.Store{Session: &sessions.State{Programmatic: true, Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Second))}}, errors.New("hi"), identity.MockProvider{}, &mock.Encoder{}, "", http.StatusUnauthorized},
|
||||
{"expired and refreshed ok", false, &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}}, sessions.ErrExpired, identity.MockProvider{}, &mock.Encoder{}, "", http.StatusOK},
|
||||
{"expired and save failed", false, &mstore.Store{SaveError: errors.New("err"), Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}}, sessions.ErrExpired, identity.MockProvider{}, &mock.Encoder{}, "", http.StatusFound},
|
||||
{"expired and unmarshal failed", false, &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}}, sessions.ErrExpired, identity.MockProvider{}, &mock.Encoder{UnmarshalError: errors.New("err")}, "", http.StatusFound},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
fmt.Fprintln(w, "REFRESH GOOD")
|
||||
}))
|
||||
defer ts.Close()
|
||||
rURL := ts.URL
|
||||
if tt.refreshURL != "" {
|
||||
rURL = tt.refreshURL
|
||||
}
|
||||
|
||||
a := Proxy{
|
||||
SharedKey: "80ldlrU2d7w+wVpKNfevk6fmb8otEx6CqOfshj2LwhQ=",
|
||||
cookieSecret: []byte("80ldlrU2d7w+wVpKNfevk6fmb8otEx6CqOfshj2LwhQ="),
|
||||
authenticateURL: uriParseHelper("https://authenticate.corp.example"),
|
||||
authenticateSigninURL: uriParseHelper("https://authenticate.corp.example/sign_in"),
|
||||
sessionStore: tt.session,
|
||||
SharedKey: "80ldlrU2d7w+wVpKNfevk6fmb8otEx6CqOfshj2LwhQ=",
|
||||
cookieSecret: []byte("80ldlrU2d7w+wVpKNfevk6fmb8otEx6CqOfshj2LwhQ="),
|
||||
authenticateURL: uriParseHelper("https://authenticate.corp.example"),
|
||||
authenticateSigninURL: uriParseHelper("https://authenticate.corp.example/sign_in"),
|
||||
authenticateRefreshURL: uriParseHelper(rURL),
|
||||
sessionStore: tt.session,
|
||||
encoder: tt.encoder,
|
||||
}
|
||||
r := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
state, _ := tt.session.LoadSession(r)
|
||||
|
@ -82,10 +101,10 @@ func TestProxy_AuthorizeSession(t *testing.T) {
|
|||
|
||||
wantStatus int
|
||||
}{
|
||||
{"user is authorized", &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Second))}}, clients.MockAuthorize{AuthorizeResponse: true}, nil, identity.MockProvider{}, http.StatusOK},
|
||||
{"user is not authorized", &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Second))}}, clients.MockAuthorize{AuthorizeResponse: false}, nil, identity.MockProvider{}, http.StatusUnauthorized},
|
||||
{"invalid session", &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Second))}}, clients.MockAuthorize{AuthorizeResponse: true}, errors.New("hi"), identity.MockProvider{}, http.StatusUnauthorized},
|
||||
{"authz client error", &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Second))}}, clients.MockAuthorize{AuthorizeError: errors.New("err")}, nil, identity.MockProvider{}, http.StatusInternalServerError},
|
||||
{"user is authorized", &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Second))}}, clients.MockAuthorize{AuthorizeResponse: true}, nil, identity.MockProvider{}, http.StatusOK},
|
||||
{"user is not authorized", &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Second))}}, clients.MockAuthorize{AuthorizeResponse: false}, nil, identity.MockProvider{}, http.StatusUnauthorized},
|
||||
{"invalid session", &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Second))}}, clients.MockAuthorize{AuthorizeResponse: true}, errors.New("hi"), identity.MockProvider{}, http.StatusUnauthorized},
|
||||
{"authz client error", &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Second))}}, clients.MockAuthorize{AuthorizeError: errors.New("err")}, nil, identity.MockProvider{}, http.StatusInternalServerError},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
|
@ -143,9 +162,9 @@ func TestProxy_SignRequest(t *testing.T) {
|
|||
wantStatus int
|
||||
wantHeaders string
|
||||
}{
|
||||
{"good", &sessions.MockSessionStore{Session: &sessions.State{Email: "test"}}, nil, nil, http.StatusOK, "ok"},
|
||||
{"invalid session", &sessions.MockSessionStore{Session: &sessions.State{Email: "test"}}, nil, errors.New("err"), http.StatusForbidden, ""},
|
||||
{"signature failure, warn but ok", &sessions.MockSessionStore{Session: &sessions.State{Email: "test"}}, errors.New("err"), nil, http.StatusOK, ""},
|
||||
{"good", &mstore.Store{Session: &sessions.State{Email: "test"}}, nil, nil, http.StatusOK, "ok"},
|
||||
{"invalid session", &mstore.Store{Session: &sessions.State{Email: "test"}}, nil, errors.New("err"), http.StatusForbidden, ""},
|
||||
{"signature failure, warn but ok", &mstore.Store{Session: &sessions.State{Email: "test"}}, errors.New("err"), nil, http.StatusOK, ""},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
|
|
|
@ -21,6 +21,9 @@ import (
|
|||
"github.com/pomerium/pomerium/internal/log"
|
||||
"github.com/pomerium/pomerium/internal/middleware"
|
||||
"github.com/pomerium/pomerium/internal/sessions"
|
||||
"github.com/pomerium/pomerium/internal/sessions/cookie"
|
||||
"github.com/pomerium/pomerium/internal/sessions/header"
|
||||
"github.com/pomerium/pomerium/internal/sessions/queryparam"
|
||||
"github.com/pomerium/pomerium/internal/telemetry/metrics"
|
||||
"github.com/pomerium/pomerium/internal/tripper"
|
||||
"github.com/pomerium/pomerium/internal/urlutil"
|
||||
|
@ -28,12 +31,11 @@ import (
|
|||
)
|
||||
|
||||
const (
|
||||
// dashboardURL is the path to authenticate's sign in endpoint
|
||||
// authenticate urls
|
||||
dashboardURL = "/.pomerium"
|
||||
// signinURL is the path to authenticate's sign in endpoint
|
||||
signinURL = "/.pomerium/sign_in"
|
||||
// signoutURL is the path to authenticate's sign out endpoint
|
||||
signoutURL = "/.pomerium/sign_out"
|
||||
signinURL = "/.pomerium/sign_in"
|
||||
signoutURL = "/.pomerium/sign_out"
|
||||
refreshURL = "/.pomerium/refresh"
|
||||
)
|
||||
|
||||
// ValidateOptions checks that proper configuration settings are set to create
|
||||
|
@ -72,12 +74,14 @@ type Proxy struct {
|
|||
authenticateURL *url.URL
|
||||
authenticateSigninURL *url.URL
|
||||
authenticateSignoutURL *url.URL
|
||||
authorizeURL *url.URL
|
||||
authenticateRefreshURL *url.URL
|
||||
|
||||
authorizeURL *url.URL
|
||||
|
||||
AuthorizeClient clients.Authorizer
|
||||
|
||||
encoder encoding.Unmarshaler
|
||||
cookieOptions *sessions.CookieOptions
|
||||
cookieOptions *cookie.Options
|
||||
cookieSecret []byte
|
||||
defaultUpstreamTimeout time.Duration
|
||||
refreshCooldown time.Duration
|
||||
|
@ -104,7 +108,7 @@ func New(opts config.Options) (*Proxy, error) {
|
|||
return nil, err
|
||||
}
|
||||
|
||||
cookieOptions := &sessions.CookieOptions{
|
||||
cookieOptions := &cookie.Options{
|
||||
Name: opts.CookieName,
|
||||
Domain: opts.CookieDomain,
|
||||
Secure: opts.CookieSecure,
|
||||
|
@ -112,7 +116,7 @@ func New(opts config.Options) (*Proxy, error) {
|
|||
Expire: opts.CookieExpire,
|
||||
}
|
||||
|
||||
cookieStore, err := sessions.NewCookieLoader(cookieOptions, encoder)
|
||||
cookieStore, err := cookie.NewStore(cookieOptions, encoder)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -129,8 +133,8 @@ func New(opts config.Options) (*Proxy, error) {
|
|||
sessionStore: cookieStore,
|
||||
sessionLoaders: []sessions.SessionLoader{
|
||||
cookieStore,
|
||||
sessions.NewHeaderStore(encoder, "Pomerium"),
|
||||
sessions.NewQueryParamStore(encoder, "pomerium_session")},
|
||||
header.NewStore(encoder, "Pomerium"),
|
||||
queryparam.NewStore(encoder, "pomerium_session")},
|
||||
signingKey: opts.SigningKey,
|
||||
templates: template.Must(frontend.NewTemplates()),
|
||||
}
|
||||
|
@ -139,6 +143,7 @@ func New(opts config.Options) (*Proxy, error) {
|
|||
p.authenticateURL, _ = urlutil.DeepCopy(opts.AuthenticateURL)
|
||||
p.authenticateSigninURL = p.authenticateURL.ResolveReference(&url.URL{Path: signinURL})
|
||||
p.authenticateSignoutURL = p.authenticateURL.ResolveReference(&url.URL{Path: signoutURL})
|
||||
p.authenticateRefreshURL = p.authenticateURL.ResolveReference(&url.URL{Path: refreshURL})
|
||||
|
||||
if err := p.UpdatePolicies(&opts); err != nil {
|
||||
return nil, err
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue