mirror of
https://github.com/pomerium/pomerium.git
synced 2025-05-12 08:37:38 +02:00
authenticate/proxy: add backend refresh (#438)
This commit is contained in:
parent
9a330613aa
commit
ec029c679b
35 changed files with 1226 additions and 445 deletions
|
@ -198,11 +198,6 @@ issues:
|
||||||
linters:
|
linters:
|
||||||
- staticcheck
|
- staticcheck
|
||||||
|
|
||||||
# todo(bdd): replace in go 1.13
|
|
||||||
- path: proxy/proxy.go
|
|
||||||
text: "copylocks: assignment copies lock value to transport"
|
|
||||||
linters:
|
|
||||||
- govet
|
|
||||||
# Independently from option `exclude` we use default exclude patterns,
|
# Independently from option `exclude` we use default exclude patterns,
|
||||||
# it can be disabled by this option. To list all
|
# it can be disabled by this option. To list all
|
||||||
# excluded by default patterns execute `golangci-lint run --help`.
|
# excluded by default patterns execute `golangci-lint run --help`.
|
||||||
|
|
|
@ -16,6 +16,10 @@ import (
|
||||||
"github.com/pomerium/pomerium/internal/frontend"
|
"github.com/pomerium/pomerium/internal/frontend"
|
||||||
"github.com/pomerium/pomerium/internal/identity"
|
"github.com/pomerium/pomerium/internal/identity"
|
||||||
"github.com/pomerium/pomerium/internal/sessions"
|
"github.com/pomerium/pomerium/internal/sessions"
|
||||||
|
"github.com/pomerium/pomerium/internal/sessions/cache"
|
||||||
|
"github.com/pomerium/pomerium/internal/sessions/cookie"
|
||||||
|
"github.com/pomerium/pomerium/internal/sessions/header"
|
||||||
|
"github.com/pomerium/pomerium/internal/sessions/queryparam"
|
||||||
"github.com/pomerium/pomerium/internal/urlutil"
|
"github.com/pomerium/pomerium/internal/urlutil"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -49,6 +53,8 @@ type Authenticate struct {
|
||||||
// authentication flow
|
// authentication flow
|
||||||
RedirectURL *url.URL
|
RedirectURL *url.URL
|
||||||
|
|
||||||
|
// values related to cross service communication
|
||||||
|
//
|
||||||
// sharedKey is used to encrypt and authenticate data between services
|
// sharedKey is used to encrypt and authenticate data between services
|
||||||
sharedKey string
|
sharedKey string
|
||||||
// sharedCipher is used to encrypt data for use between services
|
// sharedCipher is used to encrypt data for use between services
|
||||||
|
@ -57,15 +63,20 @@ type Authenticate struct {
|
||||||
// by other services
|
// by other services
|
||||||
sharedEncoder encoding.MarshalUnmarshaler
|
sharedEncoder encoding.MarshalUnmarshaler
|
||||||
|
|
||||||
// data related to this service only
|
// values related to user sessions
|
||||||
cookieOptions *sessions.CookieOptions
|
//
|
||||||
// cookieSecret is the secret to encrypt and authenticate data for this service
|
// cookieSecret is the secret to encrypt and authenticate session data
|
||||||
cookieSecret []byte
|
cookieSecret []byte
|
||||||
// is the cipher to use to encrypt data for this service
|
// cookieCipher is the cipher to use to encrypt/decrypt session data
|
||||||
cookieCipher cipher.AEAD
|
cookieCipher cipher.AEAD
|
||||||
sessionStore sessions.SessionStore
|
// encryptedEncoder is the encoder used to marshal and unmarshal session data
|
||||||
encryptedEncoder encoding.MarshalUnmarshaler
|
encryptedEncoder encoding.MarshalUnmarshaler
|
||||||
sessionStores []sessions.SessionStore
|
// sessionStore is the session store used to persist a user's session
|
||||||
|
sessionStore sessions.SessionStore
|
||||||
|
cookieOptions *cookie.Options
|
||||||
|
|
||||||
|
// sessionLoaders are a collection of session loaders to attempt to pull
|
||||||
|
// a user's session state from
|
||||||
sessionLoaders []sessions.SessionLoader
|
sessionLoaders []sessions.SessionLoader
|
||||||
|
|
||||||
// provider is the interface to interacting with the identity provider (IdP)
|
// provider is the interface to interacting with the identity provider (IdP)
|
||||||
|
@ -92,7 +103,7 @@ func New(opts config.Options) (*Authenticate, error) {
|
||||||
cookieCipher, _ := cryptutil.NewAEADCipher(decodedCookieSecret)
|
cookieCipher, _ := cryptutil.NewAEADCipher(decodedCookieSecret)
|
||||||
encryptedEncoder := ecjson.New(cookieCipher)
|
encryptedEncoder := ecjson.New(cookieCipher)
|
||||||
|
|
||||||
cookieOptions := &sessions.CookieOptions{
|
cookieOptions := &cookie.Options{
|
||||||
Name: opts.CookieName,
|
Name: opts.CookieName,
|
||||||
Domain: opts.CookieDomain,
|
Domain: opts.CookieDomain,
|
||||||
Secure: opts.CookieSecure,
|
Secure: opts.CookieSecure,
|
||||||
|
@ -100,12 +111,13 @@ func New(opts config.Options) (*Authenticate, error) {
|
||||||
Expire: opts.CookieExpire,
|
Expire: opts.CookieExpire,
|
||||||
}
|
}
|
||||||
|
|
||||||
cookieStore, err := sessions.NewCookieStore(cookieOptions, encryptedEncoder)
|
cookieStore, err := cookie.NewStore(cookieOptions, encryptedEncoder)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
qpStore := sessions.NewQueryParamStore(encryptedEncoder, "pomerium_programmatic_token")
|
cacheStore := cache.NewStore(encryptedEncoder, cookieStore, opts.CookieName)
|
||||||
headerStore := sessions.NewHeaderStore(encryptedEncoder, "Pomerium")
|
qpStore := queryparam.NewStore(encryptedEncoder, "pomerium_programmatic_token")
|
||||||
|
headerStore := header.NewStore(encryptedEncoder, "Pomerium")
|
||||||
|
|
||||||
redirectURL, _ := urlutil.DeepCopy(opts.AuthenticateURL)
|
redirectURL, _ := urlutil.DeepCopy(opts.AuthenticateURL)
|
||||||
redirectURL.Path = callbackPath
|
redirectURL.Path = callbackPath
|
||||||
|
@ -135,10 +147,9 @@ func New(opts config.Options) (*Authenticate, error) {
|
||||||
cookieSecret: decodedCookieSecret,
|
cookieSecret: decodedCookieSecret,
|
||||||
cookieCipher: cookieCipher,
|
cookieCipher: cookieCipher,
|
||||||
cookieOptions: cookieOptions,
|
cookieOptions: cookieOptions,
|
||||||
sessionStore: cookieStore,
|
sessionStore: cacheStore,
|
||||||
encryptedEncoder: encryptedEncoder,
|
encryptedEncoder: encryptedEncoder,
|
||||||
sessionLoaders: []sessions.SessionLoader{qpStore, headerStore, cookieStore},
|
sessionLoaders: []sessions.SessionLoader{cacheStore, qpStore, headerStore, cookieStore},
|
||||||
sessionStores: []sessions.SessionStore{cookieStore, qpStore},
|
|
||||||
// IdP
|
// IdP
|
||||||
provider: provider,
|
provider: provider,
|
||||||
|
|
||||||
|
|
|
@ -72,15 +72,18 @@ func TestOptions_Validate(t *testing.T) {
|
||||||
|
|
||||||
func TestNew(t *testing.T) {
|
func TestNew(t *testing.T) {
|
||||||
good := newTestOptions(t)
|
good := newTestOptions(t)
|
||||||
|
good.CookieName = "A"
|
||||||
|
|
||||||
badRedirectURL := newTestOptions(t)
|
badRedirectURL := newTestOptions(t)
|
||||||
badRedirectURL.AuthenticateURL = nil
|
badRedirectURL.AuthenticateURL = nil
|
||||||
|
badRedirectURL.CookieName = "B"
|
||||||
|
|
||||||
badCookieName := newTestOptions(t)
|
badCookieName := newTestOptions(t)
|
||||||
badCookieName.CookieName = ""
|
badCookieName.CookieName = ""
|
||||||
|
|
||||||
badProvider := newTestOptions(t)
|
badProvider := newTestOptions(t)
|
||||||
badProvider.Provider = ""
|
badProvider.Provider = ""
|
||||||
|
badProvider.CookieName = "C"
|
||||||
|
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
package authenticate // import "github.com/pomerium/pomerium/authenticate"
|
package authenticate // import "github.com/pomerium/pomerium/authenticate"
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
|
@ -18,6 +19,7 @@ import (
|
||||||
"github.com/pomerium/pomerium/internal/log"
|
"github.com/pomerium/pomerium/internal/log"
|
||||||
"github.com/pomerium/pomerium/internal/middleware"
|
"github.com/pomerium/pomerium/internal/middleware"
|
||||||
"github.com/pomerium/pomerium/internal/sessions"
|
"github.com/pomerium/pomerium/internal/sessions"
|
||||||
|
"github.com/pomerium/pomerium/internal/telemetry/trace"
|
||||||
"github.com/pomerium/pomerium/internal/urlutil"
|
"github.com/pomerium/pomerium/internal/urlutil"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -58,6 +60,7 @@ func (a *Authenticate) Handler() http.Handler {
|
||||||
v.Use(a.VerifySession)
|
v.Use(a.VerifySession)
|
||||||
v.Path("/sign_in").Handler(httputil.HandlerFunc(a.SignIn))
|
v.Path("/sign_in").Handler(httputil.HandlerFunc(a.SignIn))
|
||||||
v.Path("/sign_out").Handler(httputil.HandlerFunc(a.SignOut))
|
v.Path("/sign_out").Handler(httputil.HandlerFunc(a.SignOut))
|
||||||
|
v.Path("/refresh").Handler(httputil.HandlerFunc(a.Refresh)).Methods(http.MethodGet)
|
||||||
|
|
||||||
// programmatic access api endpoint
|
// programmatic access api endpoint
|
||||||
api := r.PathPrefix("/api").Subrouter()
|
api := r.PathPrefix("/api").Subrouter()
|
||||||
|
@ -73,12 +76,12 @@ func (a *Authenticate) VerifySession(next http.Handler) http.Handler {
|
||||||
return httputil.HandlerFunc(func(w http.ResponseWriter, r *http.Request) error {
|
return httputil.HandlerFunc(func(w http.ResponseWriter, r *http.Request) error {
|
||||||
state, err := sessions.FromContext(r.Context())
|
state, err := sessions.FromContext(r.Context())
|
||||||
if errors.Is(err, sessions.ErrExpired) {
|
if errors.Is(err, sessions.ErrExpired) {
|
||||||
if err := a.refresh(w, r, state); err != nil {
|
ctx, err := a.refresh(w, r, state)
|
||||||
|
if err != nil {
|
||||||
log.FromRequest(r).Info().Err(err).Msg("authenticate: verify session, refresh")
|
log.FromRequest(r).Info().Err(err).Msg("authenticate: verify session, refresh")
|
||||||
return a.reauthenticateOrFail(w, r, err)
|
return a.reauthenticateOrFail(w, r, err)
|
||||||
}
|
}
|
||||||
// redirect to restart middleware-chain following refresh
|
next.ServeHTTP(w, r.WithContext(ctx))
|
||||||
httputil.Redirect(w, r, urlutil.GetAbsoluteURL(r).String(), http.StatusFound)
|
|
||||||
return nil
|
return nil
|
||||||
} else if err != nil {
|
} else if err != nil {
|
||||||
log.FromRequest(r).Info().Err(err).Msg("authenticate: verify session")
|
log.FromRequest(r).Info().Err(err).Msg("authenticate: verify session")
|
||||||
|
@ -89,15 +92,18 @@ func (a *Authenticate) VerifySession(next http.Handler) http.Handler {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Authenticate) refresh(w http.ResponseWriter, r *http.Request, s *sessions.State) error {
|
func (a *Authenticate) refresh(w http.ResponseWriter, r *http.Request, s *sessions.State) (context.Context, error) {
|
||||||
newSession, err := a.provider.Refresh(r.Context(), s)
|
ctx, span := trace.StartSpan(r.Context(), "authenticate.VerifySession/refresh")
|
||||||
|
defer span.End()
|
||||||
|
newSession, err := a.provider.Refresh(ctx, s)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("authenticate: refresh failed: %w", err)
|
return nil, fmt.Errorf("authenticate: refresh failed: %w", err)
|
||||||
}
|
}
|
||||||
if err := a.sessionStore.SaveSession(w, r, newSession); err != nil {
|
if err := a.sessionStore.SaveSession(w, r, newSession); err != nil {
|
||||||
return fmt.Errorf("authenticate: refresh save failed: %w", err)
|
return nil, fmt.Errorf("authenticate: refresh save failed: %w", err)
|
||||||
}
|
}
|
||||||
return nil
|
// return the new session and add it to the current request context
|
||||||
|
return sessions.NewContext(ctx, newSession, err), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// RobotsTxt handles the /robots.txt route.
|
// RobotsTxt handles the /robots.txt route.
|
||||||
|
@ -158,7 +164,6 @@ func (a *Authenticate) SignIn(w http.ResponseWriter, r *http.Request) error {
|
||||||
encSession, err := a.encryptedEncoder.Marshal(newSession)
|
encSession, err := a.encryptedEncoder.Marshal(newSession)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return httputil.NewError(http.StatusBadRequest, err)
|
return httputil.NewError(http.StatusBadRequest, err)
|
||||||
|
|
||||||
}
|
}
|
||||||
callbackParams.Set(urlutil.QueryRefreshToken, string(encSession))
|
callbackParams.Set(urlutil.QueryRefreshToken, string(encSession))
|
||||||
callbackParams.Set(urlutil.QueryIsProgrammatic, "true")
|
callbackParams.Set(urlutil.QueryIsProgrammatic, "true")
|
||||||
|
@ -345,3 +350,27 @@ func (a *Authenticate) RefreshAPI(w http.ResponseWriter, r *http.Request) error
|
||||||
w.Write(jsonResponse)
|
w.Write(jsonResponse)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Refresh is called by the proxy service to handle backend session refresh.
|
||||||
|
//
|
||||||
|
// NOTE: The actual refresh is actually handled as part of the "VerifySession"
|
||||||
|
// middleware. This handler is responsible for creating a new route scoped
|
||||||
|
// session and returning it.
|
||||||
|
func (a *Authenticate) Refresh(w http.ResponseWriter, r *http.Request) error {
|
||||||
|
s, err := sessions.FromContext(r.Context())
|
||||||
|
if err != nil {
|
||||||
|
return httputil.NewError(http.StatusBadRequest, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
routeSession := s.NewSession(r.Host, []string{r.Host, r.FormValue("aud")})
|
||||||
|
routeSession.AccessTokenID = s.AccessTokenID
|
||||||
|
|
||||||
|
signedJWT, err := a.sharedEncoder.Marshal(routeSession.RouteSession())
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
w.Header().Set("Content-Type", "application/jwt") // RFC 7519 : 10.3.1
|
||||||
|
w.Write(signedJWT)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
|
@ -11,17 +11,18 @@ import (
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/pomerium/pomerium/internal/httputil"
|
|
||||||
|
|
||||||
"github.com/pomerium/pomerium/internal/cryptutil"
|
"github.com/pomerium/pomerium/internal/cryptutil"
|
||||||
"github.com/pomerium/pomerium/internal/encoding"
|
"github.com/pomerium/pomerium/internal/encoding"
|
||||||
"github.com/pomerium/pomerium/internal/encoding/mock"
|
"github.com/pomerium/pomerium/internal/encoding/mock"
|
||||||
"github.com/pomerium/pomerium/internal/frontend"
|
"github.com/pomerium/pomerium/internal/frontend"
|
||||||
|
"github.com/pomerium/pomerium/internal/httputil"
|
||||||
"github.com/pomerium/pomerium/internal/identity"
|
"github.com/pomerium/pomerium/internal/identity"
|
||||||
"github.com/pomerium/pomerium/internal/sessions"
|
"github.com/pomerium/pomerium/internal/sessions"
|
||||||
"github.com/pomerium/pomerium/internal/urlutil"
|
"github.com/pomerium/pomerium/internal/sessions/cookie"
|
||||||
|
mstore "github.com/pomerium/pomerium/internal/sessions/mock"
|
||||||
|
|
||||||
"github.com/google/go-cmp/cmp"
|
"github.com/google/go-cmp/cmp"
|
||||||
|
"github.com/pomerium/pomerium/internal/urlutil"
|
||||||
"golang.org/x/crypto/chacha20poly1305"
|
"golang.org/x/crypto/chacha20poly1305"
|
||||||
"golang.org/x/oauth2"
|
"golang.org/x/oauth2"
|
||||||
"gopkg.in/square/go-jose.v2/jwt"
|
"gopkg.in/square/go-jose.v2/jwt"
|
||||||
|
@ -32,7 +33,7 @@ func testAuthenticate() *Authenticate {
|
||||||
auth.RedirectURL, _ = url.Parse("https://auth.example.com/oauth/callback")
|
auth.RedirectURL, _ = url.Parse("https://auth.example.com/oauth/callback")
|
||||||
auth.sharedKey = cryptutil.NewBase64Key()
|
auth.sharedKey = cryptutil.NewBase64Key()
|
||||||
auth.cookieSecret = cryptutil.NewKey()
|
auth.cookieSecret = cryptutil.NewKey()
|
||||||
auth.cookieOptions = &sessions.CookieOptions{Name: "name"}
|
auth.cookieOptions = &cookie.Options{Name: "name"}
|
||||||
auth.templates = template.Must(frontend.NewTemplates())
|
auth.templates = template.Must(frontend.NewTemplates())
|
||||||
return &auth
|
return &auth
|
||||||
}
|
}
|
||||||
|
@ -112,19 +113,19 @@ func TestAuthenticate_SignIn(t *testing.T) {
|
||||||
encoder encoding.MarshalUnmarshaler
|
encoder encoding.MarshalUnmarshaler
|
||||||
wantCode int
|
wantCode int
|
||||||
}{
|
}{
|
||||||
{"good", "https", "corp.example.example", map[string]string{urlutil.QueryRedirectURI: "https://dst.some.example/"}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusFound},
|
{"good", "https", "corp.example.example", map[string]string{urlutil.QueryRedirectURI: "https://dst.some.example/"}, &mstore.Store{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusFound},
|
||||||
{"session not valid", "https", "corp.example.example", map[string]string{urlutil.QueryRedirectURI: "https://dst.some.example/"}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(-10 * time.Second)}}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusFound},
|
{"session not valid", "https", "corp.example.example", map[string]string{urlutil.QueryRedirectURI: "https://dst.some.example/"}, &mstore.Store{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(-10 * time.Second)}}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusFound},
|
||||||
{"bad redirect uri query", "", "corp.example.example", map[string]string{urlutil.QueryRedirectURI: "^^^"}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusBadRequest},
|
{"bad redirect uri query", "", "corp.example.example", map[string]string{urlutil.QueryRedirectURI: "^^^"}, &mstore.Store{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusBadRequest},
|
||||||
{"bad marshal", "https", "corp.example.example", map[string]string{urlutil.QueryRedirectURI: "https://dst.some.example/"}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, identity.MockProvider{}, &mock.Encoder{MarshalError: errors.New("error")}, http.StatusBadRequest},
|
{"bad marshal", "https", "corp.example.example", map[string]string{urlutil.QueryRedirectURI: "https://dst.some.example/"}, &mstore.Store{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, identity.MockProvider{}, &mock.Encoder{MarshalError: errors.New("error")}, http.StatusBadRequest},
|
||||||
{"session error", "https", "corp.example.example", map[string]string{urlutil.QueryRedirectURI: "https://dst.some.example/"}, &sessions.MockSessionStore{LoadError: errors.New("error")}, identity.MockProvider{}, &mock.Encoder{}, http.StatusBadRequest},
|
{"session error", "https", "corp.example.example", map[string]string{urlutil.QueryRedirectURI: "https://dst.some.example/"}, &mstore.Store{LoadError: errors.New("error")}, identity.MockProvider{}, &mock.Encoder{}, http.StatusBadRequest},
|
||||||
{"good with different programmatic redirect", "https", "corp.example.example", map[string]string{urlutil.QueryRedirectURI: "https://dst.some.example/", urlutil.QueryCallbackURI: "https://some.example"}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusFound},
|
{"good with different programmatic redirect", "https", "corp.example.example", map[string]string{urlutil.QueryRedirectURI: "https://dst.some.example/", urlutil.QueryCallbackURI: "https://some.example"}, &mstore.Store{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusFound},
|
||||||
{"encrypted encoder error", "https", "corp.example.example", map[string]string{urlutil.QueryRedirectURI: "https://dst.some.example/", urlutil.QueryCallbackURI: "https://some.example"}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, identity.MockProvider{}, &mock.Encoder{MarshalError: errors.New("error")}, http.StatusBadRequest},
|
{"encrypted encoder error", "https", "corp.example.example", map[string]string{urlutil.QueryRedirectURI: "https://dst.some.example/", urlutil.QueryCallbackURI: "https://some.example"}, &mstore.Store{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, identity.MockProvider{}, &mock.Encoder{MarshalError: errors.New("error")}, http.StatusBadRequest},
|
||||||
{"good with callback uri set", "https", "corp.example.example", map[string]string{urlutil.QueryCallbackURI: "https://some.example/", urlutil.QueryRedirectURI: "https://dst.some.example/"}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusFound},
|
{"good with callback uri set", "https", "corp.example.example", map[string]string{urlutil.QueryCallbackURI: "https://some.example/", urlutil.QueryRedirectURI: "https://dst.some.example/"}, &mstore.Store{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusFound},
|
||||||
{"bad callback uri set", "https", "corp.example.example", map[string]string{urlutil.QueryCallbackURI: "^", urlutil.QueryRedirectURI: "https://dst.some.example/"}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusBadRequest},
|
{"bad callback uri set", "https", "corp.example.example", map[string]string{urlutil.QueryCallbackURI: "^", urlutil.QueryRedirectURI: "https://dst.some.example/"}, &mstore.Store{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusBadRequest},
|
||||||
{"good programmatic request", "https", "corp.example.example", map[string]string{urlutil.QueryIsProgrammatic: "true", urlutil.QueryRedirectURI: "https://dst.some.example/"}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusFound},
|
{"good programmatic request", "https", "corp.example.example", map[string]string{urlutil.QueryIsProgrammatic: "true", urlutil.QueryRedirectURI: "https://dst.some.example/"}, &mstore.Store{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusFound},
|
||||||
{"good additional audience", "https", "corp.example.example", map[string]string{urlutil.QueryForwardAuth: "x.y.z", urlutil.QueryRedirectURI: "https://dst.some.example/"}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusFound},
|
{"good additional audience", "https", "corp.example.example", map[string]string{urlutil.QueryForwardAuth: "x.y.z", urlutil.QueryRedirectURI: "https://dst.some.example/"}, &mstore.Store{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusFound},
|
||||||
{"good user impersonate", "https", "corp.example.example", map[string]string{urlutil.QueryImpersonateAction: "set", urlutil.QueryRedirectURI: "https://dst.some.example/"}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusFound},
|
{"good user impersonate", "https", "corp.example.example", map[string]string{urlutil.QueryImpersonateAction: "set", urlutil.QueryRedirectURI: "https://dst.some.example/"}, &mstore.Store{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusFound},
|
||||||
{"bad user impersonate save failure", "https", "corp.example.example", map[string]string{urlutil.QueryImpersonateAction: "set", urlutil.QueryRedirectURI: "https://dst.some.example/"}, &sessions.MockSessionStore{SaveError: errors.New("err"), Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusBadRequest},
|
{"bad user impersonate save failure", "https", "corp.example.example", map[string]string{urlutil.QueryImpersonateAction: "set", urlutil.QueryRedirectURI: "https://dst.some.example/"}, &mstore.Store{SaveError: errors.New("err"), Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusBadRequest},
|
||||||
}
|
}
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
@ -136,7 +137,7 @@ func TestAuthenticate_SignIn(t *testing.T) {
|
||||||
sharedEncoder: tt.encoder,
|
sharedEncoder: tt.encoder,
|
||||||
encryptedEncoder: tt.encoder,
|
encryptedEncoder: tt.encoder,
|
||||||
sharedCipher: aead,
|
sharedCipher: aead,
|
||||||
cookieOptions: &sessions.CookieOptions{
|
cookieOptions: &cookie.Options{
|
||||||
Name: "cookie",
|
Name: "cookie",
|
||||||
Domain: "foo",
|
Domain: "foo",
|
||||||
},
|
},
|
||||||
|
@ -186,10 +187,10 @@ func TestAuthenticate_SignOut(t *testing.T) {
|
||||||
wantCode int
|
wantCode int
|
||||||
wantBody string
|
wantBody string
|
||||||
}{
|
}{
|
||||||
{"good post", http.MethodPost, nil, "https://corp.pomerium.io/", "sig", "ts", identity.MockProvider{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, http.StatusFound, ""},
|
{"good post", http.MethodPost, nil, "https://corp.pomerium.io/", "sig", "ts", identity.MockProvider{}, &mstore.Store{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, http.StatusFound, ""},
|
||||||
{"failed revoke", http.MethodPost, nil, "https://corp.pomerium.io/", "sig", "ts", identity.MockProvider{RevokeError: errors.New("OH NO")}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, http.StatusBadRequest, "{\"Status\":400,\"Error\":\"Bad Request: OH NO\"}\n"},
|
{"failed revoke", http.MethodPost, nil, "https://corp.pomerium.io/", "sig", "ts", identity.MockProvider{RevokeError: errors.New("OH NO")}, &mstore.Store{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, http.StatusBadRequest, "{\"Status\":400,\"Error\":\"Bad Request: OH NO\"}\n"},
|
||||||
{"load session error", http.MethodPost, errors.New("error"), "https://corp.pomerium.io/", "sig", "ts", identity.MockProvider{RevokeError: errors.New("OH NO")}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, http.StatusBadRequest, "{\"Status\":400,\"Error\":\"Bad Request: error\"}\n"},
|
{"load session error", http.MethodPost, errors.New("error"), "https://corp.pomerium.io/", "sig", "ts", identity.MockProvider{RevokeError: errors.New("OH NO")}, &mstore.Store{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, http.StatusBadRequest, "{\"Status\":400,\"Error\":\"Bad Request: error\"}\n"},
|
||||||
{"bad redirect uri", http.MethodPost, nil, "corp.pomerium.io/", "sig", "ts", identity.MockProvider{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, http.StatusBadRequest, "{\"Status\":400,\"Error\":\"Bad Request: corp.pomerium.io/ url does contain a valid scheme\"}\n"},
|
{"bad redirect uri", http.MethodPost, nil, "corp.pomerium.io/", "sig", "ts", identity.MockProvider{}, &mstore.Store{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, http.StatusBadRequest, "{\"Status\":400,\"Error\":\"Bad Request: corp.pomerium.io/ url does contain a valid scheme\"}\n"},
|
||||||
}
|
}
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
@ -247,19 +248,19 @@ func TestAuthenticate_OAuthCallback(t *testing.T) {
|
||||||
want string
|
want string
|
||||||
wantCode int
|
wantCode int
|
||||||
}{
|
}{
|
||||||
{"good", http.MethodGet, time.Now().Unix(), "", "", "", "", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, "https://corp.pomerium.io", http.StatusFound},
|
{"good", http.MethodGet, time.Now().Unix(), "", "", "", "", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &mstore.Store{}, identity.MockProvider{AuthenticateResponse: sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, "https://corp.pomerium.io", http.StatusFound},
|
||||||
{"failed authenticate", http.MethodGet, time.Now().Unix(), "", "", "", "", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateError: errors.New("error")}, "", http.StatusInternalServerError},
|
{"failed authenticate", http.MethodGet, time.Now().Unix(), "", "", "", "", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &mstore.Store{}, identity.MockProvider{AuthenticateError: errors.New("error")}, "", http.StatusInternalServerError},
|
||||||
{"failed save session", http.MethodGet, time.Now().Unix(), "", "", "", "", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &sessions.MockSessionStore{SaveError: errors.New("error")}, identity.MockProvider{AuthenticateResponse: sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, "", http.StatusInternalServerError},
|
{"failed save session", http.MethodGet, time.Now().Unix(), "", "", "", "", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &mstore.Store{SaveError: errors.New("error")}, identity.MockProvider{AuthenticateResponse: sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, "", http.StatusInternalServerError},
|
||||||
{"provider returned error", http.MethodGet, time.Now().Unix(), "", "", "", "idp error", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, "", http.StatusBadRequest},
|
{"provider returned error", http.MethodGet, time.Now().Unix(), "", "", "", "idp error", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &mstore.Store{}, identity.MockProvider{AuthenticateResponse: sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, "", http.StatusBadRequest},
|
||||||
{"empty code", http.MethodGet, time.Now().Unix(), "", "", "", "", "", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, "", http.StatusBadRequest},
|
{"empty code", http.MethodGet, time.Now().Unix(), "", "", "", "", "", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &mstore.Store{}, identity.MockProvider{AuthenticateResponse: sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, "", http.StatusBadRequest},
|
||||||
{"invalid redirect uri", http.MethodGet, time.Now().Unix(), "", "", "", "", "code", "corp.pomerium.io", "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, "", http.StatusBadRequest},
|
{"invalid redirect uri", http.MethodGet, time.Now().Unix(), "", "", "", "", "code", "corp.pomerium.io", "https://authenticate.pomerium.io", &mstore.Store{}, identity.MockProvider{AuthenticateResponse: sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, "", http.StatusBadRequest},
|
||||||
{"bad redirect uri", http.MethodGet, time.Now().Unix(), "", "", "", "", "code", "http://^^^", "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, "https://corp.pomerium.io", http.StatusBadRequest},
|
{"bad redirect uri", http.MethodGet, time.Now().Unix(), "", "", "", "", "code", "http://^^^", "https://authenticate.pomerium.io", &mstore.Store{}, identity.MockProvider{AuthenticateResponse: sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, "https://corp.pomerium.io", http.StatusBadRequest},
|
||||||
{"bad timing - too soon", http.MethodGet, time.Now().Add(1 * time.Hour).Unix(), "", "", "", "", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, "https://corp.pomerium.io", http.StatusBadRequest},
|
{"bad timing - too soon", http.MethodGet, time.Now().Add(1 * time.Hour).Unix(), "", "", "", "", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &mstore.Store{}, identity.MockProvider{AuthenticateResponse: sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, "https://corp.pomerium.io", http.StatusBadRequest},
|
||||||
{"bad timing - expired", http.MethodGet, time.Now().Add(-1 * time.Hour).Unix(), "", "", "", "", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, "https://corp.pomerium.io", http.StatusBadRequest},
|
{"bad timing - expired", http.MethodGet, time.Now().Add(-1 * time.Hour).Unix(), "", "", "", "", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &mstore.Store{}, identity.MockProvider{AuthenticateResponse: sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, "https://corp.pomerium.io", http.StatusBadRequest},
|
||||||
{"bad base64", http.MethodGet, time.Now().Unix(), "", "", "^", "", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, "https://corp.pomerium.io", http.StatusBadRequest},
|
{"bad base64", http.MethodGet, time.Now().Unix(), "", "", "^", "", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &mstore.Store{}, identity.MockProvider{AuthenticateResponse: sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, "https://corp.pomerium.io", http.StatusBadRequest},
|
||||||
{"too many seperators", http.MethodGet, time.Now().Unix(), "", "", "|ok|now|what", "", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, "https://corp.pomerium.io", http.StatusBadRequest},
|
{"too many seperators", http.MethodGet, time.Now().Unix(), "", "", "|ok|now|what", "", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &mstore.Store{}, identity.MockProvider{AuthenticateResponse: sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, "https://corp.pomerium.io", http.StatusBadRequest},
|
||||||
{"bad hmac", http.MethodGet, time.Now().Unix(), "", "NOTMAC", "", "", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, "https://corp.pomerium.io", http.StatusBadRequest},
|
{"bad hmac", http.MethodGet, time.Now().Unix(), "", "NOTMAC", "", "", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &mstore.Store{}, identity.MockProvider{AuthenticateResponse: sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, "https://corp.pomerium.io", http.StatusBadRequest},
|
||||||
{"bad hmac", http.MethodGet, time.Now().Unix(), base64.URLEncoding.EncodeToString([]byte("malformed_state")), "", "", "", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &sessions.MockSessionStore{}, identity.MockProvider{AuthenticateResponse: sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, "https://corp.pomerium.io", http.StatusBadRequest},
|
{"bad hmac", http.MethodGet, time.Now().Unix(), base64.URLEncoding.EncodeToString([]byte("malformed_state")), "", "", "", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &mstore.Store{}, identity.MockProvider{AuthenticateResponse: sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, "https://corp.pomerium.io", http.StatusBadRequest},
|
||||||
}
|
}
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
@ -326,12 +327,12 @@ func TestAuthenticate_SessionValidatorMiddleware(t *testing.T) {
|
||||||
|
|
||||||
wantStatus int
|
wantStatus int
|
||||||
}{
|
}{
|
||||||
{"good", nil, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, nil, identity.MockProvider{RefreshResponse: sessions.State{AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Minute)}}}, http.StatusOK},
|
{"good", nil, &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, nil, identity.MockProvider{RefreshResponse: sessions.State{AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Minute)}}}, http.StatusOK},
|
||||||
{"invalid session", nil, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, errors.New("hi"), identity.MockProvider{}, http.StatusFound},
|
{"invalid session", nil, &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, errors.New("hi"), identity.MockProvider{}, http.StatusFound},
|
||||||
{"good refresh expired", nil, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}}, sessions.ErrExpired, identity.MockProvider{RefreshResponse: sessions.State{AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Minute)}}}, http.StatusFound},
|
{"good refresh expired", nil, &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}}, sessions.ErrExpired, identity.MockProvider{RefreshResponse: sessions.State{AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Minute)}}}, http.StatusOK},
|
||||||
{"expired,refresh error", nil, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}}, sessions.ErrExpired, identity.MockProvider{RefreshError: errors.New("error")}, http.StatusFound},
|
{"expired,refresh error", nil, &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}}, sessions.ErrExpired, identity.MockProvider{RefreshError: errors.New("error")}, http.StatusFound},
|
||||||
{"expired,save error", nil, &sessions.MockSessionStore{SaveError: errors.New("error"), Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}}, sessions.ErrExpired, identity.MockProvider{RefreshResponse: sessions.State{AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Minute)}}}, http.StatusFound},
|
{"expired,save error", nil, &mstore.Store{SaveError: errors.New("error"), Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}}, sessions.ErrExpired, identity.MockProvider{RefreshResponse: sessions.State{AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Minute)}}}, http.StatusFound},
|
||||||
{"expired XHR,refresh error", map[string]string{"X-Requested-With": "XmlHttpRequest"}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}}, sessions.ErrExpired, identity.MockProvider{RefreshError: errors.New("error")}, http.StatusUnauthorized},
|
{"expired XHR,refresh error", map[string]string{"X-Requested-With": "XmlHttpRequest"}, &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}}, sessions.ErrExpired, identity.MockProvider{RefreshError: errors.New("error")}, http.StatusUnauthorized},
|
||||||
}
|
}
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
@ -384,11 +385,11 @@ func TestAuthenticate_RefreshAPI(t *testing.T) {
|
||||||
|
|
||||||
wantStatus int
|
wantStatus int
|
||||||
}{
|
}{
|
||||||
{"good", &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, nil, identity.MockProvider{RefreshResponse: sessions.State{AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Minute)}}}, mock.Encoder{MarshalResponse: []byte("ok")}, mock.Encoder{MarshalResponse: []byte("ok")}, http.StatusOK},
|
{"good", &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, nil, identity.MockProvider{RefreshResponse: sessions.State{AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Minute)}}}, mock.Encoder{MarshalResponse: []byte("ok")}, mock.Encoder{MarshalResponse: []byte("ok")}, http.StatusOK},
|
||||||
{"refresh error", &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, nil, identity.MockProvider{RefreshError: errors.New("error")}, mock.Encoder{MarshalResponse: []byte("ok")}, mock.Encoder{MarshalResponse: []byte("ok")}, http.StatusInternalServerError},
|
{"refresh error", &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, nil, identity.MockProvider{RefreshError: errors.New("error")}, mock.Encoder{MarshalResponse: []byte("ok")}, mock.Encoder{MarshalResponse: []byte("ok")}, http.StatusInternalServerError},
|
||||||
{"session is not refreshable error", &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, errors.New("session error"), identity.MockProvider{RefreshResponse: sessions.State{AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Minute)}}}, mock.Encoder{MarshalResponse: []byte("ok")}, mock.Encoder{MarshalResponse: []byte("ok")}, http.StatusBadRequest},
|
{"session is not refreshable error", &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, errors.New("session error"), identity.MockProvider{RefreshResponse: sessions.State{AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Minute)}}}, mock.Encoder{MarshalResponse: []byte("ok")}, mock.Encoder{MarshalResponse: []byte("ok")}, http.StatusBadRequest},
|
||||||
{"secret encoder failed", &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, nil, identity.MockProvider{RefreshResponse: sessions.State{AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Minute)}}}, mock.Encoder{MarshalError: errors.New("error")}, mock.Encoder{MarshalResponse: []byte("ok")}, http.StatusInternalServerError},
|
{"secret encoder failed", &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, nil, identity.MockProvider{RefreshResponse: sessions.State{AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Minute)}}}, mock.Encoder{MarshalError: errors.New("error")}, mock.Encoder{MarshalResponse: []byte("ok")}, http.StatusInternalServerError},
|
||||||
{"shared encoder failed", &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, nil, identity.MockProvider{RefreshResponse: sessions.State{AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Minute)}}}, mock.Encoder{MarshalResponse: []byte("ok")}, mock.Encoder{MarshalError: errors.New("error")}, http.StatusInternalServerError},
|
{"shared encoder failed", &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, nil, identity.MockProvider{RefreshResponse: sessions.State{AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Minute)}}}, mock.Encoder{MarshalResponse: []byte("ok")}, mock.Encoder{MarshalError: errors.New("error")}, http.StatusInternalServerError},
|
||||||
}
|
}
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
@ -423,3 +424,54 @@ func TestAuthenticate_RefreshAPI(t *testing.T) {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
func TestAuthenticate_Refresh(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
|
||||||
|
session sessions.SessionStore
|
||||||
|
ctxError error
|
||||||
|
|
||||||
|
provider identity.Authenticator
|
||||||
|
secretEncoder encoding.MarshalUnmarshaler
|
||||||
|
sharedEncoder encoding.MarshalUnmarshaler
|
||||||
|
|
||||||
|
wantStatus int
|
||||||
|
}{
|
||||||
|
{"good", &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, nil, identity.MockProvider{RefreshResponse: sessions.State{AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Minute)}}}, mock.Encoder{MarshalResponse: []byte("ok")}, mock.Encoder{MarshalResponse: []byte("ok")}, http.StatusOK},
|
||||||
|
{"bad session", &mstore.Store{}, errors.New("err"), identity.MockProvider{RefreshResponse: sessions.State{AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Minute)}}}, mock.Encoder{MarshalResponse: []byte("ok")}, mock.Encoder{MarshalResponse: []byte("ok")}, http.StatusBadRequest},
|
||||||
|
{"encoder error", &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, nil, identity.MockProvider{RefreshResponse: sessions.State{AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Minute)}}}, mock.Encoder{MarshalResponse: []byte("ok")}, mock.Encoder{MarshalError: errors.New("err")}, http.StatusInternalServerError},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
aead, err := chacha20poly1305.NewX(cryptutil.NewKey())
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
a := Authenticate{
|
||||||
|
sharedKey: cryptutil.NewBase64Key(),
|
||||||
|
cookieSecret: cryptutil.NewKey(),
|
||||||
|
RedirectURL: uriParseHelper("https://authenticate.corp.beyondperimeter.com"),
|
||||||
|
encryptedEncoder: tt.secretEncoder,
|
||||||
|
sharedEncoder: tt.sharedEncoder,
|
||||||
|
sessionStore: tt.session,
|
||||||
|
provider: tt.provider,
|
||||||
|
cookieCipher: aead,
|
||||||
|
}
|
||||||
|
r := httptest.NewRequest("GET", "/", nil)
|
||||||
|
state, _ := tt.session.LoadSession(r)
|
||||||
|
ctx := r.Context()
|
||||||
|
ctx = sessions.NewContext(ctx, state, tt.ctxError)
|
||||||
|
r = r.WithContext(ctx)
|
||||||
|
|
||||||
|
r.Header.Set("Accept", "application/json")
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
httputil.HandlerFunc(a.Refresh).ServeHTTP(w, r)
|
||||||
|
if status := w.Code; status != tt.wantStatus {
|
||||||
|
t.Errorf("VerifySession() error = %v, wantErr %v\n%v", w.Result().StatusCode, tt.wantStatus, w.Body.String())
|
||||||
|
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -11,12 +11,13 @@ import (
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/fsnotify/fsnotify"
|
|
||||||
"github.com/pomerium/pomerium/internal/cryptutil"
|
"github.com/pomerium/pomerium/internal/cryptutil"
|
||||||
"github.com/pomerium/pomerium/internal/log"
|
"github.com/pomerium/pomerium/internal/log"
|
||||||
"github.com/pomerium/pomerium/internal/telemetry/metrics"
|
"github.com/pomerium/pomerium/internal/telemetry/metrics"
|
||||||
"github.com/pomerium/pomerium/internal/urlutil"
|
"github.com/pomerium/pomerium/internal/urlutil"
|
||||||
|
|
||||||
|
"github.com/cespare/xxhash/v2"
|
||||||
|
"github.com/fsnotify/fsnotify"
|
||||||
"github.com/mitchellh/hashstructure"
|
"github.com/mitchellh/hashstructure"
|
||||||
"github.com/spf13/viper"
|
"github.com/spf13/viper"
|
||||||
"gopkg.in/yaml.v2"
|
"gopkg.in/yaml.v2"
|
||||||
|
@ -477,7 +478,7 @@ type OptionsUpdater interface {
|
||||||
|
|
||||||
// Checksum returns the checksum of the current options struct
|
// Checksum returns the checksum of the current options struct
|
||||||
func (o *Options) Checksum() string {
|
func (o *Options) Checksum() string {
|
||||||
hash, err := hashstructure.Hash(o, nil)
|
hash, err := hashstructure.Hash(o, &hashstructure.HashOptions{Hasher: xxhash.New()})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Warn().Err(err).Msg("config: checksum failure")
|
log.Warn().Err(err).Msg("config: checksum failure")
|
||||||
return "no checksum available"
|
return "no checksum available"
|
||||||
|
|
4
go.mod
4
go.mod
|
@ -6,9 +6,9 @@ require (
|
||||||
cloud.google.com/go v0.49.0 // indirect
|
cloud.google.com/go v0.49.0 // indirect
|
||||||
contrib.go.opencensus.io/exporter/jaeger v0.2.0
|
contrib.go.opencensus.io/exporter/jaeger v0.2.0
|
||||||
contrib.go.opencensus.io/exporter/prometheus v0.1.0
|
contrib.go.opencensus.io/exporter/prometheus v0.1.0
|
||||||
github.com/cespare/xxhash/v2 v2.1.1 // indirect
|
github.com/cespare/xxhash/v2 v2.1.1
|
||||||
github.com/fsnotify/fsnotify v1.4.7
|
github.com/fsnotify/fsnotify v1.4.7
|
||||||
github.com/golang/groupcache v0.0.0-20191027212112-611e8accdfc9 // indirect
|
github.com/golang/groupcache v0.0.0-20191227052852-215e87163ea7
|
||||||
github.com/golang/mock v1.3.1
|
github.com/golang/mock v1.3.1
|
||||||
github.com/golang/protobuf v1.3.2
|
github.com/golang/protobuf v1.3.2
|
||||||
github.com/google/go-cmp v0.3.1
|
github.com/google/go-cmp v0.3.1
|
||||||
|
|
5
go.sum
5
go.sum
|
@ -71,6 +71,8 @@ github.com/golang/groupcache v0.0.0-20190702054246-869f871628b6 h1:ZgQEtGgCBiWRM
|
||||||
github.com/golang/groupcache v0.0.0-20190702054246-869f871628b6/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc=
|
github.com/golang/groupcache v0.0.0-20190702054246-869f871628b6/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc=
|
||||||
github.com/golang/groupcache v0.0.0-20191027212112-611e8accdfc9 h1:uHTyIjqVhYRhLbJ8nIiOJHkEZZ+5YoOsAbD3sk82NiE=
|
github.com/golang/groupcache v0.0.0-20191027212112-611e8accdfc9 h1:uHTyIjqVhYRhLbJ8nIiOJHkEZZ+5YoOsAbD3sk82NiE=
|
||||||
github.com/golang/groupcache v0.0.0-20191027212112-611e8accdfc9/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc=
|
github.com/golang/groupcache v0.0.0-20191027212112-611e8accdfc9/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc=
|
||||||
|
github.com/golang/groupcache v0.0.0-20191227052852-215e87163ea7 h1:5ZkaAPbicIKTF2I64qf5Fh8Aa83Q/dnOafMYV0OMwjA=
|
||||||
|
github.com/golang/groupcache v0.0.0-20191227052852-215e87163ea7/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc=
|
||||||
github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A=
|
github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A=
|
||||||
github.com/golang/mock v1.2.0 h1:28o5sBqPkBsMGnC6b4MvE2TzSr5/AT4c/1fLqVGIwlk=
|
github.com/golang/mock v1.2.0 h1:28o5sBqPkBsMGnC6b4MvE2TzSr5/AT4c/1fLqVGIwlk=
|
||||||
github.com/golang/mock v1.2.0/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A=
|
github.com/golang/mock v1.2.0/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A=
|
||||||
|
@ -162,6 +164,7 @@ github.com/prometheus/client_golang v0.9.3/go.mod h1:/TN21ttK/J9q6uSwhBd54HahCDf
|
||||||
github.com/prometheus/client_golang v1.0.0/go.mod h1:db9x61etRT2tGnBNRi70OPL5FsnadC4Ky3P0J6CfImo=
|
github.com/prometheus/client_golang v1.0.0/go.mod h1:db9x61etRT2tGnBNRi70OPL5FsnadC4Ky3P0J6CfImo=
|
||||||
github.com/prometheus/client_golang v1.2.1 h1:JnMpQc6ppsNgw9QPAGF6Dod479itz7lvlsMzzNayLOI=
|
github.com/prometheus/client_golang v1.2.1 h1:JnMpQc6ppsNgw9QPAGF6Dod479itz7lvlsMzzNayLOI=
|
||||||
github.com/prometheus/client_golang v1.2.1/go.mod h1:XMU6Z2MjaRKVu/dC1qupJI9SiNkDYzz3xecMgSW/F+U=
|
github.com/prometheus/client_golang v1.2.1/go.mod h1:XMU6Z2MjaRKVu/dC1qupJI9SiNkDYzz3xecMgSW/F+U=
|
||||||
|
github.com/prometheus/client_golang v1.3.0 h1:miYCvYqFXtl/J9FIy8eNpBfYthAEFg+Ys0XyUVEcDsc=
|
||||||
github.com/prometheus/client_model v0.0.0-20180712105110-5c3871d89910/go.mod h1:MbSGuTsp3dbXC40dX6PRTWyKYBIrTGTE9sqQNg2J8bo=
|
github.com/prometheus/client_model v0.0.0-20180712105110-5c3871d89910/go.mod h1:MbSGuTsp3dbXC40dX6PRTWyKYBIrTGTE9sqQNg2J8bo=
|
||||||
github.com/prometheus/client_model v0.0.0-20190129233127-fd36f4220a90 h1:S/YWwWx/RA8rT8tKFRuGUZhuA90OyIBpPCXkcbwU8DE=
|
github.com/prometheus/client_model v0.0.0-20190129233127-fd36f4220a90 h1:S/YWwWx/RA8rT8tKFRuGUZhuA90OyIBpPCXkcbwU8DE=
|
||||||
github.com/prometheus/client_model v0.0.0-20190129233127-fd36f4220a90/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA=
|
github.com/prometheus/client_model v0.0.0-20190129233127-fd36f4220a90/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA=
|
||||||
|
@ -214,6 +217,7 @@ github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA=
|
||||||
github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
|
github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
|
||||||
github.com/spf13/viper v1.5.0 h1:GpsTwfsQ27oS/Aha/6d1oD7tpKIqWnOA6tgOX9HHkt4=
|
github.com/spf13/viper v1.5.0 h1:GpsTwfsQ27oS/Aha/6d1oD7tpKIqWnOA6tgOX9HHkt4=
|
||||||
github.com/spf13/viper v1.5.0/go.mod h1:AkYRkVJF8TkSG/xet6PzXX+l39KhhXa2pdqVSxnTcn4=
|
github.com/spf13/viper v1.5.0/go.mod h1:AkYRkVJF8TkSG/xet6PzXX+l39KhhXa2pdqVSxnTcn4=
|
||||||
|
github.com/spf13/viper v1.6.1 h1:VPZzIkznI1YhVMRi6vNFLHSwhnhReBfgTxIPccpfdZk=
|
||||||
github.com/spf13/viper v1.6.1/go.mod h1:t3iDnF5Jlj76alVNuyFBk5oUMCvsrkbvZK0WQdfDi5k=
|
github.com/spf13/viper v1.6.1/go.mod h1:t3iDnF5Jlj76alVNuyFBk5oUMCvsrkbvZK0WQdfDi5k=
|
||||||
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
||||||
github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
||||||
|
@ -384,6 +388,7 @@ gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8
|
||||||
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33q108Sa+fhmuc+sWQYwY=
|
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33q108Sa+fhmuc+sWQYwY=
|
||||||
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||||
gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI=
|
gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI=
|
||||||
|
gopkg.in/ini.v1 v1.51.0 h1:AQvPpx3LzTDM0AjnIRlVFwFFGC+npRopjZxLJj6gdno=
|
||||||
gopkg.in/ini.v1 v1.51.0/go.mod h1:pNLf8WUiyNEtQjuu5G5vTm06TEv9tsIgeAvK8hOrP4k=
|
gopkg.in/ini.v1 v1.51.0/go.mod h1:pNLf8WUiyNEtQjuu5G5vTm06TEv9tsIgeAvK8hOrP4k=
|
||||||
gopkg.in/resty.v1 v1.12.0/go.mod h1:mDo4pnntr5jdWRML875a/NmxYqAlA73dVijT2AXvQQo=
|
gopkg.in/resty.v1 v1.12.0/go.mod h1:mDo4pnntr5jdWRML875a/NmxYqAlA73dVijT2AXvQQo=
|
||||||
gopkg.in/square/go-jose.v2 v2.4.0 h1:0kXPskUMGAXXWJlP05ktEMOV0vmzFQUWw6d+aZJQU8A=
|
gopkg.in/square/go-jose.v2 v2.4.0 h1:0kXPskUMGAXXWJlP05ktEMOV0vmzFQUWw6d+aZJQU8A=
|
||||||
|
|
|
@ -1,5 +1,13 @@
|
||||||
package mock // import "github.com/pomerium/pomerium/internal/encoding/mock"
|
package mock // import "github.com/pomerium/pomerium/internal/encoding/mock"
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/pomerium/pomerium/internal/encoding"
|
||||||
|
)
|
||||||
|
|
||||||
|
var _ encoding.MarshalUnmarshaler = &Encoder{}
|
||||||
|
var _ encoding.Marshaler = &Encoder{}
|
||||||
|
var _ encoding.Unmarshaler = &Encoder{}
|
||||||
|
|
||||||
// Encoder MockCSRFStore is a mock implementation of Cipher.
|
// Encoder MockCSRFStore is a mock implementation of Cipher.
|
||||||
type Encoder struct {
|
type Encoder struct {
|
||||||
MarshalResponse []byte
|
MarshalResponse []byte
|
||||||
|
|
|
@ -8,23 +8,21 @@ import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
"net"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"go.opencensus.io/plugin/ochttp"
|
||||||
)
|
)
|
||||||
|
|
||||||
// ErrTokenRevoked signifies a token revokation or expiration error
|
// ErrTokenRevoked signifies a token revokation or expiration error
|
||||||
var ErrTokenRevoked = errors.New("token expired or revoked")
|
var ErrTokenRevoked = errors.New("token expired or revoked")
|
||||||
|
|
||||||
var httpClient = &http.Client{
|
// DefaultClient avoids leaks by setting an upper limit for timeouts.
|
||||||
Timeout: time.Second * 5,
|
var DefaultClient = &http.Client{
|
||||||
Transport: &http.Transport{
|
Timeout: 1 * time.Minute,
|
||||||
Dial: (&net.Dialer{
|
//todo(bdd): incorporate metrics.HTTPMetricsRoundTripper
|
||||||
Timeout: 2 * time.Second,
|
Transport: &ochttp.Transport{},
|
||||||
}).Dial,
|
|
||||||
TLSHandshakeTimeout: 2 * time.Second,
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Client provides a simple helper interface to make HTTP requests
|
// Client provides a simple helper interface to make HTTP requests
|
||||||
|
@ -36,9 +34,11 @@ func Client(ctx context.Context, method, endpoint, userAgent string, headers map
|
||||||
case http.MethodGet:
|
case http.MethodGet:
|
||||||
// error checking skipped because we are just parsing in
|
// error checking skipped because we are just parsing in
|
||||||
// order to make a copy of an existing URL
|
// order to make a copy of an existing URL
|
||||||
|
if params != nil {
|
||||||
u, _ := url.Parse(endpoint)
|
u, _ := url.Parse(endpoint)
|
||||||
u.RawQuery = params.Encode()
|
u.RawQuery = params.Encode()
|
||||||
endpoint = u.String()
|
endpoint = u.String()
|
||||||
|
}
|
||||||
default:
|
default:
|
||||||
return fmt.Errorf(http.StatusText(http.StatusBadRequest))
|
return fmt.Errorf(http.StatusText(http.StatusBadRequest))
|
||||||
}
|
}
|
||||||
|
@ -52,7 +52,7 @@ func Client(ctx context.Context, method, endpoint, userAgent string, headers map
|
||||||
req.Header.Set(k, v)
|
req.Header.Set(k, v)
|
||||||
}
|
}
|
||||||
|
|
||||||
resp, err := httpClient.Do(req)
|
resp, err := DefaultClient.Do(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -79,7 +79,6 @@ func Client(ctx context.Context, method, endpoint, userAgent string, headers map
|
||||||
return fmt.Errorf(http.StatusText(resp.StatusCode))
|
return fmt.Errorf(http.StatusText(resp.StatusCode))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if response != nil {
|
if response != nil {
|
||||||
err := json.Unmarshal(respBody, &response)
|
err := json.Unmarshal(respBody, &response)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
131
internal/sessions/cache/cache_store.go
vendored
Normal file
131
internal/sessions/cache/cache_store.go
vendored
Normal file
|
@ -0,0 +1,131 @@
|
||||||
|
package cache // import "github.com/pomerium/pomerium/internal/sessions/cache"
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
"github.com/golang/groupcache"
|
||||||
|
|
||||||
|
"github.com/pomerium/pomerium/internal/encoding"
|
||||||
|
"github.com/pomerium/pomerium/internal/log"
|
||||||
|
"github.com/pomerium/pomerium/internal/sessions"
|
||||||
|
)
|
||||||
|
|
||||||
|
var _ sessions.SessionStore = &Store{}
|
||||||
|
var _ sessions.SessionLoader = &Store{}
|
||||||
|
|
||||||
|
const (
|
||||||
|
defaultQueryParamKey = "ati"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Store implements the session store interface using a distributed cache.
|
||||||
|
type Store struct {
|
||||||
|
name string
|
||||||
|
encoder encoding.Marshaler
|
||||||
|
decoder encoding.Unmarshaler
|
||||||
|
|
||||||
|
cache *groupcache.Group
|
||||||
|
wrappedStore sessions.SessionStore
|
||||||
|
}
|
||||||
|
|
||||||
|
// defaultCacheSize is ~10MB
|
||||||
|
var defaultCacheSize int64 = 10 << 20
|
||||||
|
|
||||||
|
// NewStore creates a new session store built on the distributed caching library
|
||||||
|
// groupcache. On a cache miss, the cache store attempts to fallback to another
|
||||||
|
// SessionStore implementation.
|
||||||
|
func NewStore(enc encoding.MarshalUnmarshaler, wrappedStore sessions.SessionStore, name string) *Store {
|
||||||
|
store := &Store{
|
||||||
|
name: name,
|
||||||
|
encoder: enc,
|
||||||
|
decoder: enc,
|
||||||
|
wrappedStore: wrappedStore,
|
||||||
|
}
|
||||||
|
|
||||||
|
store.cache = groupcache.NewGroup(name, defaultCacheSize, groupcache.GetterFunc(
|
||||||
|
func(ctx context.Context, id string, dest groupcache.Sink) error {
|
||||||
|
// fill the cache with session set as part of the request
|
||||||
|
// context set previously as part of SaveSession.
|
||||||
|
b := fromContext(ctx)
|
||||||
|
if len(b) == 0 {
|
||||||
|
return fmt.Errorf("sessions/cache: cannot fill key %s from ctx", id)
|
||||||
|
}
|
||||||
|
if err := dest.SetBytes(b); err != nil {
|
||||||
|
return fmt.Errorf("sessions/cache: sink error %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
))
|
||||||
|
|
||||||
|
return store
|
||||||
|
}
|
||||||
|
|
||||||
|
// LoadSession implements SessionLoaders's LoadSession method for cache store.
|
||||||
|
func (s *Store) LoadSession(r *http.Request) (*sessions.State, error) {
|
||||||
|
// look for our cache's key in the default query param
|
||||||
|
sessionID := r.URL.Query().Get(defaultQueryParamKey)
|
||||||
|
if sessionID == "" {
|
||||||
|
// if unset, fallback to default cache store
|
||||||
|
log.FromRequest(r).Debug().Msg("sessions/cache: no query param, trying wrapped loader")
|
||||||
|
return s.wrappedStore.LoadSession(r)
|
||||||
|
}
|
||||||
|
|
||||||
|
var b []byte
|
||||||
|
if err := s.cache.Get(r.Context(), sessionID, groupcache.AllocatingByteSliceSink(&b)); err != nil {
|
||||||
|
log.FromRequest(r).Debug().Err(err).Msg("sessions/cache: miss, trying wrapped loader")
|
||||||
|
return s.wrappedStore.LoadSession(r)
|
||||||
|
}
|
||||||
|
var session sessions.State
|
||||||
|
if err := s.decoder.Unmarshal(b, &session); err != nil {
|
||||||
|
log.FromRequest(r).Error().Err(err).Msg("sessions/cache: unmarshal")
|
||||||
|
return nil, sessions.ErrMalformed
|
||||||
|
}
|
||||||
|
return &session, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ClearSession implements SessionStore's ClearSession for the cache store.
|
||||||
|
// Since group cache has no explicit eviction, we just call the wrapped
|
||||||
|
// store's ClearSession method here.
|
||||||
|
func (s *Store) ClearSession(w http.ResponseWriter, r *http.Request) {
|
||||||
|
s.wrappedStore.ClearSession(w, r)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SaveSession implements SessionStore's SaveSession method for cache store.
|
||||||
|
func (s *Store) SaveSession(w http.ResponseWriter, r *http.Request, x interface{}) error {
|
||||||
|
err := s.wrappedStore.SaveSession(w, r, x)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("sessions/cache: wrapped store save error %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
state, ok := x.(*sessions.State)
|
||||||
|
if !ok {
|
||||||
|
return errors.New("internal/sessions: cannot cache non state type")
|
||||||
|
}
|
||||||
|
|
||||||
|
data, err := s.encoder.Marshal(&state)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("sessions/cache: marshal %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx := newContext(r.Context(), data)
|
||||||
|
var b []byte
|
||||||
|
return s.cache.Get(ctx, state.AccessTokenID, groupcache.AllocatingByteSliceSink(&b))
|
||||||
|
}
|
||||||
|
|
||||||
|
var sessionCtxKey = &contextKey{"PomeriumCachedSessionBytes"}
|
||||||
|
|
||||||
|
type contextKey struct {
|
||||||
|
name string
|
||||||
|
}
|
||||||
|
|
||||||
|
func newContext(ctx context.Context, b []byte) context.Context {
|
||||||
|
ctx = context.WithValue(ctx, sessionCtxKey, b)
|
||||||
|
return ctx
|
||||||
|
}
|
||||||
|
|
||||||
|
func fromContext(ctx context.Context) []byte {
|
||||||
|
b, _ := ctx.Value(sessionCtxKey).([]byte)
|
||||||
|
return b
|
||||||
|
}
|
133
internal/sessions/cache/cache_store_test.go
vendored
Normal file
133
internal/sessions/cache/cache_store_test.go
vendored
Normal file
|
@ -0,0 +1,133 @@
|
||||||
|
package cache
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/google/go-cmp/cmp"
|
||||||
|
"github.com/pomerium/pomerium/internal/cryptutil"
|
||||||
|
"github.com/pomerium/pomerium/internal/encoding/ecjson"
|
||||||
|
"github.com/pomerium/pomerium/internal/sessions"
|
||||||
|
"github.com/pomerium/pomerium/internal/sessions/cookie"
|
||||||
|
"gopkg.in/square/go-jose.v2/jwt"
|
||||||
|
)
|
||||||
|
|
||||||
|
func testAuthorizer(next http.Handler) http.Handler {
|
||||||
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
_, err := sessions.FromContext(r.Context())
|
||||||
|
if err != nil {
|
||||||
|
http.Error(w, err.Error(), http.StatusUnauthorized)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
next.ServeHTTP(w, r)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestVerifier(t *testing.T) {
|
||||||
|
fnh := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Header().Set("Content-Type", "text/plain; charset=utf-8")
|
||||||
|
fmt.Fprint(w, http.StatusText(http.StatusOK))
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
})
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
skipSave bool
|
||||||
|
cacheSize int64
|
||||||
|
state sessions.State
|
||||||
|
|
||||||
|
wantBody string
|
||||||
|
wantStatus int
|
||||||
|
}{
|
||||||
|
{"good", false, 1 << 10, sessions.State{AccessTokenID: cryptutil.NewBase64Key(), Email: "user@pomerium.io", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}, http.StatusText(http.StatusOK), http.StatusOK},
|
||||||
|
{"expired", false, 1 << 10, sessions.State{AccessTokenID: cryptutil.NewBase64Key(), Email: "user@pomerium.io", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}, "internal/sessions: validation failed, token is expired (exp)\n", http.StatusUnauthorized},
|
||||||
|
{"empty", false, 1 << 10, sessions.State{AccessTokenID: "", Email: "user@pomerium.io", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}, "internal/sessions: session is not found\n", http.StatusUnauthorized},
|
||||||
|
{"miss", true, 1 << 10, sessions.State{AccessTokenID: cryptutil.NewBase64Key(), Email: "user@pomerium.io", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}, "internal/sessions: session is not found\n", http.StatusUnauthorized},
|
||||||
|
{"cache eviction", false, 1, sessions.State{AccessTokenID: cryptutil.NewBase64Key(), Email: "user@pomerium.io", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}, "internal/sessions: session is not found\n", http.StatusUnauthorized},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
defaultCacheSize = tt.cacheSize
|
||||||
|
cipher, err := cryptutil.NewAEADCipherFromBase64(cryptutil.NewBase64Key())
|
||||||
|
encoder := ecjson.New(cipher)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
cs, err := cookie.NewStore(&cookie.Options{Name: t.Name()}, encoder)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
cacheStore := NewStore(encoder, cs, t.Name())
|
||||||
|
|
||||||
|
r := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||||
|
q := r.URL.Query()
|
||||||
|
|
||||||
|
q.Set(defaultQueryParamKey, tt.state.AccessTokenID)
|
||||||
|
r.URL.RawQuery = q.Encode()
|
||||||
|
r.Header.Set("Accept", "application/json")
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
got := sessions.RetrieveSession(cacheStore)(testAuthorizer((fnh)))
|
||||||
|
|
||||||
|
if !tt.skipSave {
|
||||||
|
cacheStore.SaveSession(w, r, &tt.state)
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := 1; i <= 10; i++ {
|
||||||
|
s := tt.state
|
||||||
|
s.AccessTokenID = cryptutil.NewBase64Key()
|
||||||
|
cacheStore.SaveSession(w, r, s)
|
||||||
|
}
|
||||||
|
|
||||||
|
got.ServeHTTP(w, r)
|
||||||
|
|
||||||
|
gotBody := w.Body.String()
|
||||||
|
gotStatus := w.Result().StatusCode
|
||||||
|
|
||||||
|
if diff := cmp.Diff(gotBody, tt.wantBody); diff != "" {
|
||||||
|
t.Errorf("RetrieveSession() = %v", diff)
|
||||||
|
}
|
||||||
|
if diff := cmp.Diff(gotStatus, tt.wantStatus); diff != "" {
|
||||||
|
t.Errorf("RetrieveSession() = %v", diff)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStore_SaveSession(t *testing.T) {
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
x interface{}
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{"good", &sessions.State{AccessTokenID: cryptutil.NewBase64Key(), Email: "user@pomerium.io", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}, false},
|
||||||
|
{"bad type", "bad type!", true},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
cipher, err := cryptutil.NewAEADCipherFromBase64(cryptutil.NewBase64Key())
|
||||||
|
encoder := ecjson.New(cipher)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
cs, err := cookie.NewStore(&cookie.Options{
|
||||||
|
Name: "_pomerium",
|
||||||
|
}, encoder)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
cacheStore := NewStore(encoder, cs, t.Name())
|
||||||
|
r := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||||
|
r.Header.Set("Accept", "application/json")
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
if err := cacheStore.SaveSession(w, r, tt.x); (err != nil) != tt.wantErr {
|
||||||
|
t.Errorf("Store.SaveSession() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
|
@ -1,4 +1,4 @@
|
||||||
package sessions // import "github.com/pomerium/pomerium/internal/sessions"
|
package cookie // import "github.com/pomerium/pomerium/internal/sessions/cookie"
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
|
@ -8,8 +8,15 @@ import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/pomerium/pomerium/internal/encoding"
|
"github.com/pomerium/pomerium/internal/encoding"
|
||||||
|
"github.com/pomerium/pomerium/internal/sessions"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
var _ sessions.SessionStore = &Store{}
|
||||||
|
var _ sessions.SessionLoader = &Store{}
|
||||||
|
|
||||||
|
// timeNow is time.Now but pulled out as a variable for tests.
|
||||||
|
var timeNow = time.Now
|
||||||
|
|
||||||
const (
|
const (
|
||||||
// ChunkedCanaryByte is the byte value used as a canary prefix to distinguish if
|
// ChunkedCanaryByte is the byte value used as a canary prefix to distinguish if
|
||||||
// the cookie is multi-part or not. This constant *should not* be valid
|
// the cookie is multi-part or not. This constant *should not* be valid
|
||||||
|
@ -25,8 +32,8 @@ const (
|
||||||
MaxNumChunks = 5
|
MaxNumChunks = 5
|
||||||
)
|
)
|
||||||
|
|
||||||
// CookieStore implements the session store interface for session cookies.
|
// Store implements the session store interface for session cookies.
|
||||||
type CookieStore struct {
|
type Store struct {
|
||||||
Name string
|
Name string
|
||||||
Domain string
|
Domain string
|
||||||
Expire time.Duration
|
Expire time.Duration
|
||||||
|
@ -37,8 +44,8 @@ type CookieStore struct {
|
||||||
decoder encoding.Unmarshaler
|
decoder encoding.Unmarshaler
|
||||||
}
|
}
|
||||||
|
|
||||||
// CookieOptions holds options for CookieStore
|
// Options holds options for Store
|
||||||
type CookieOptions struct {
|
type Options struct {
|
||||||
Name string
|
Name string
|
||||||
Domain string
|
Domain string
|
||||||
Expire time.Duration
|
Expire time.Duration
|
||||||
|
@ -46,8 +53,9 @@ type CookieOptions struct {
|
||||||
Secure bool
|
Secure bool
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewCookieStore returns a new session with ciphers for each of the cookie secrets
|
// NewStore returns a new store that implements the SessionStore interface
|
||||||
func NewCookieStore(opts *CookieOptions, encoder encoding.MarshalUnmarshaler) (*CookieStore, error) {
|
// using http cookies.
|
||||||
|
func NewStore(opts *Options, encoder encoding.MarshalUnmarshaler) (sessions.SessionStore, error) {
|
||||||
cs, err := NewCookieLoader(opts, encoder)
|
cs, err := NewCookieLoader(opts, encoder)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -56,12 +64,13 @@ func NewCookieStore(opts *CookieOptions, encoder encoding.MarshalUnmarshaler) (*
|
||||||
return cs, nil
|
return cs, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewCookieLoader returns a new session with ciphers for each of the cookie secrets
|
// NewCookieLoader returns a new store that implements the SessionLoader
|
||||||
func NewCookieLoader(opts *CookieOptions, dencoder encoding.Unmarshaler) (*CookieStore, error) {
|
// interface using http cookies.
|
||||||
|
func NewCookieLoader(opts *Options, dencoder encoding.Unmarshaler) (*Store, error) {
|
||||||
if dencoder == nil {
|
if dencoder == nil {
|
||||||
return nil, fmt.Errorf("internal/sessions: dencoder cannot be nil")
|
return nil, fmt.Errorf("internal/sessions: dencoder cannot be nil")
|
||||||
}
|
}
|
||||||
cs, err := newCookieStore(opts)
|
cs, err := newStore(opts)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -69,12 +78,12 @@ func NewCookieLoader(opts *CookieOptions, dencoder encoding.Unmarshaler) (*Cooki
|
||||||
return cs, nil
|
return cs, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func newCookieStore(opts *CookieOptions) (*CookieStore, error) {
|
func newStore(opts *Options) (*Store, error) {
|
||||||
if opts.Name == "" {
|
if opts.Name == "" {
|
||||||
return nil, fmt.Errorf("internal/sessions: cookie name cannot be empty")
|
return nil, fmt.Errorf("internal/sessions: cookie name cannot be empty")
|
||||||
}
|
}
|
||||||
|
|
||||||
return &CookieStore{
|
return &Store{
|
||||||
Name: opts.Name,
|
Name: opts.Name,
|
||||||
Secure: opts.Secure,
|
Secure: opts.Secure,
|
||||||
HTTPOnly: opts.HTTPOnly,
|
HTTPOnly: opts.HTTPOnly,
|
||||||
|
@ -83,7 +92,7 @@ func newCookieStore(opts *CookieOptions) (*CookieStore, error) {
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (cs *CookieStore) makeCookie(value string) *http.Cookie {
|
func (cs *Store) makeCookie(value string) *http.Cookie {
|
||||||
return &http.Cookie{
|
return &http.Cookie{
|
||||||
Name: cs.Name,
|
Name: cs.Name,
|
||||||
Value: value,
|
Value: value,
|
||||||
|
@ -96,7 +105,7 @@ func (cs *CookieStore) makeCookie(value string) *http.Cookie {
|
||||||
}
|
}
|
||||||
|
|
||||||
// ClearSession clears the session cookie from a request
|
// ClearSession clears the session cookie from a request
|
||||||
func (cs *CookieStore) ClearSession(w http.ResponseWriter, r *http.Request) {
|
func (cs *Store) ClearSession(w http.ResponseWriter, r *http.Request) {
|
||||||
c := cs.makeCookie("")
|
c := cs.makeCookie("")
|
||||||
c.MaxAge = -1
|
c.MaxAge = -1
|
||||||
c.Expires = timeNow().Add(-time.Hour)
|
c.Expires = timeNow().Add(-time.Hour)
|
||||||
|
@ -115,51 +124,51 @@ func getCookies(r *http.Request, name string) []*http.Cookie {
|
||||||
}
|
}
|
||||||
|
|
||||||
// LoadSession returns a State from the cookie in the request.
|
// LoadSession returns a State from the cookie in the request.
|
||||||
func (cs *CookieStore) LoadSession(r *http.Request) (*State, error) {
|
func (cs *Store) LoadSession(r *http.Request) (*sessions.State, error) {
|
||||||
cookies := getCookies(r, cs.Name)
|
cookies := getCookies(r, cs.Name)
|
||||||
if len(cookies) == 0 {
|
if len(cookies) == 0 {
|
||||||
return nil, ErrNoSessionFound
|
return nil, sessions.ErrNoSessionFound
|
||||||
}
|
}
|
||||||
for _, cookie := range cookies {
|
for _, cookie := range cookies {
|
||||||
data := loadChunkedCookie(r, cookie)
|
data := loadChunkedCookie(r, cookie)
|
||||||
|
|
||||||
session := &State{}
|
session := &sessions.State{}
|
||||||
err := cs.decoder.Unmarshal([]byte(data), session)
|
err := cs.decoder.Unmarshal([]byte(data), session)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
return session, nil
|
return session, nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return nil, ErrMalformed
|
return nil, sessions.ErrMalformed
|
||||||
}
|
}
|
||||||
|
|
||||||
// SaveSession saves a session state to a request's cookie store.
|
// SaveSession saves a session state to a request's cookie store.
|
||||||
func (cs *CookieStore) SaveSession(w http.ResponseWriter, _ *http.Request, x interface{}) error {
|
func (cs *Store) SaveSession(w http.ResponseWriter, _ *http.Request, x interface{}) error {
|
||||||
var value string
|
var value string
|
||||||
if cs.encoder != nil {
|
|
||||||
data, err := cs.encoder.Marshal(x)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
value = string(data)
|
|
||||||
} else {
|
|
||||||
switch v := x.(type) {
|
switch v := x.(type) {
|
||||||
case []byte:
|
case []byte:
|
||||||
value = string(v)
|
value = string(v)
|
||||||
case string:
|
case string:
|
||||||
value = v
|
value = v
|
||||||
default:
|
default:
|
||||||
|
if cs.encoder == nil {
|
||||||
return errors.New("internal/sessions: cannot save non-string type")
|
return errors.New("internal/sessions: cannot save non-string type")
|
||||||
}
|
}
|
||||||
|
data, err := cs.encoder.Marshal(x)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
|
value = string(data)
|
||||||
|
}
|
||||||
|
|
||||||
cs.setSessionCookie(w, value)
|
cs.setSessionCookie(w, value)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (cs *CookieStore) setSessionCookie(w http.ResponseWriter, val string) {
|
func (cs *Store) setSessionCookie(w http.ResponseWriter, val string) {
|
||||||
cs.setCookie(w, cs.makeCookie(val))
|
cs.setCookie(w, cs.makeCookie(val))
|
||||||
}
|
}
|
||||||
|
|
||||||
func (cs *CookieStore) setCookie(w http.ResponseWriter, cookie *http.Cookie) {
|
func (cs *Store) setCookie(w http.ResponseWriter, cookie *http.Cookie) {
|
||||||
if len(cookie.String()) <= MaxChunkSize {
|
if len(cookie.String()) <= MaxChunkSize {
|
||||||
http.SetCookie(w, cookie)
|
http.SetCookie(w, cookie)
|
||||||
return
|
return
|
||||||
|
@ -180,9 +189,15 @@ func (cs *CookieStore) setCookie(w http.ResponseWriter, cookie *http.Cookie) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func loadChunkedCookie(r *http.Request, c *http.Cookie) string {
|
func loadChunkedCookie(r *http.Request, c *http.Cookie) string {
|
||||||
data := c.Value
|
if len(c.Value) == 0 {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
// if the first byte is our canary byte, we need to handle the multipart bit
|
// if the first byte is our canary byte, we need to handle the multipart bit
|
||||||
if []byte(c.Value)[0] == ChunkedCanaryByte {
|
if []byte(c.Value)[0] != ChunkedCanaryByte {
|
||||||
|
return c.Value
|
||||||
|
}
|
||||||
|
|
||||||
|
data := c.Value
|
||||||
var b strings.Builder
|
var b strings.Builder
|
||||||
fmt.Fprintf(&b, "%s", data[1:])
|
fmt.Fprintf(&b, "%s", data[1:])
|
||||||
for i := 1; i <= MaxNumChunks; i++ {
|
for i := 1; i <= MaxNumChunks; i++ {
|
||||||
|
@ -193,7 +208,7 @@ func loadChunkedCookie(r *http.Request, c *http.Cookie) string {
|
||||||
fmt.Fprintf(&b, "%s", next.Value)
|
fmt.Fprintf(&b, "%s", next.Value)
|
||||||
}
|
}
|
||||||
data = b.String()
|
data = b.String()
|
||||||
}
|
|
||||||
return data
|
return data
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
package sessions // import "github.com/pomerium/pomerium/internal/sessions"
|
package cookie // import "github.com/pomerium/pomerium/internal/sessions/cookie"
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"crypto/rand"
|
"crypto/rand"
|
||||||
|
@ -13,12 +13,13 @@ import (
|
||||||
"github.com/pomerium/pomerium/internal/encoding"
|
"github.com/pomerium/pomerium/internal/encoding"
|
||||||
"github.com/pomerium/pomerium/internal/encoding/ecjson"
|
"github.com/pomerium/pomerium/internal/encoding/ecjson"
|
||||||
"github.com/pomerium/pomerium/internal/encoding/mock"
|
"github.com/pomerium/pomerium/internal/encoding/mock"
|
||||||
|
"github.com/pomerium/pomerium/internal/sessions"
|
||||||
|
|
||||||
"github.com/google/go-cmp/cmp"
|
"github.com/google/go-cmp/cmp"
|
||||||
"github.com/google/go-cmp/cmp/cmpopts"
|
"github.com/google/go-cmp/cmp/cmpopts"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestNewCookieStore(t *testing.T) {
|
func TestNewStore(t *testing.T) {
|
||||||
cipher, err := cryptutil.NewAEADCipher(cryptutil.NewKey())
|
cipher, err := cryptutil.NewAEADCipher(cryptutil.NewKey())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
|
@ -26,28 +27,28 @@ func TestNewCookieStore(t *testing.T) {
|
||||||
encoder := ecjson.New(cipher)
|
encoder := ecjson.New(cipher)
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
opts *CookieOptions
|
opts *Options
|
||||||
encoder encoding.MarshalUnmarshaler
|
encoder encoding.MarshalUnmarshaler
|
||||||
want *CookieStore
|
want sessions.SessionStore
|
||||||
wantErr bool
|
wantErr bool
|
||||||
}{
|
}{
|
||||||
{"good", &CookieOptions{Name: "_cookie", Secure: true, HTTPOnly: true, Domain: "pomerium.io", Expire: 10 * time.Second}, encoder, &CookieStore{Name: "_cookie", Secure: true, HTTPOnly: true, Domain: "pomerium.io", Expire: 10 * time.Second}, false},
|
{"good", &Options{Name: "_cookie", Secure: true, HTTPOnly: true, Domain: "pomerium.io", Expire: 10 * time.Second}, encoder, &Store{Name: "_cookie", Secure: true, HTTPOnly: true, Domain: "pomerium.io", Expire: 10 * time.Second}, false},
|
||||||
{"missing name", &CookieOptions{Name: "", Secure: true, HTTPOnly: true, Domain: "pomerium.io", Expire: 10 * time.Second}, encoder, nil, true},
|
{"missing name", &Options{Name: "", Secure: true, HTTPOnly: true, Domain: "pomerium.io", Expire: 10 * time.Second}, encoder, nil, true},
|
||||||
{"missing encoder", &CookieOptions{Name: "_cookie", Secure: true, HTTPOnly: true, Domain: "pomerium.io", Expire: 10 * time.Second}, nil, nil, true},
|
{"missing encoder", &Options{Name: "_cookie", Secure: true, HTTPOnly: true, Domain: "pomerium.io", Expire: 10 * time.Second}, nil, nil, true},
|
||||||
}
|
}
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
got, err := NewCookieStore(tt.opts, tt.encoder)
|
got, err := NewStore(tt.opts, tt.encoder)
|
||||||
if (err != nil) != tt.wantErr {
|
if (err != nil) != tt.wantErr {
|
||||||
t.Errorf("NewCookieStore() error = %v, wantErr %v", err, tt.wantErr)
|
t.Errorf("NewStore() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
cmpOpts := []cmp.Option{
|
cmpOpts := []cmp.Option{
|
||||||
cmpopts.IgnoreUnexported(CookieStore{}),
|
cmpopts.IgnoreUnexported(Store{}),
|
||||||
}
|
}
|
||||||
|
|
||||||
if diff := cmp.Diff(got, tt.want, cmpOpts...); diff != "" {
|
if diff := cmp.Diff(got, tt.want, cmpOpts...); diff != "" {
|
||||||
t.Errorf("NewCookieStore() = %s", diff)
|
t.Errorf("NewStore() = %s", diff)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
@ -60,14 +61,14 @@ func TestNewCookieLoader(t *testing.T) {
|
||||||
encoder := ecjson.New(cipher)
|
encoder := ecjson.New(cipher)
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
opts *CookieOptions
|
opts *Options
|
||||||
encoder encoding.MarshalUnmarshaler
|
encoder encoding.MarshalUnmarshaler
|
||||||
want *CookieStore
|
want *Store
|
||||||
wantErr bool
|
wantErr bool
|
||||||
}{
|
}{
|
||||||
{"good", &CookieOptions{Name: "_cookie", Secure: true, HTTPOnly: true, Domain: "pomerium.io", Expire: 10 * time.Second}, encoder, &CookieStore{Name: "_cookie", Secure: true, HTTPOnly: true, Domain: "pomerium.io", Expire: 10 * time.Second}, false},
|
{"good", &Options{Name: "_cookie", Secure: true, HTTPOnly: true, Domain: "pomerium.io", Expire: 10 * time.Second}, encoder, &Store{Name: "_cookie", Secure: true, HTTPOnly: true, Domain: "pomerium.io", Expire: 10 * time.Second}, false},
|
||||||
{"missing name", &CookieOptions{Name: "", Secure: true, HTTPOnly: true, Domain: "pomerium.io", Expire: 10 * time.Second}, encoder, nil, true},
|
{"missing name", &Options{Name: "", Secure: true, HTTPOnly: true, Domain: "pomerium.io", Expire: 10 * time.Second}, encoder, nil, true},
|
||||||
{"missing encoder", &CookieOptions{Name: "_cookie", Secure: true, HTTPOnly: true, Domain: "pomerium.io", Expire: 10 * time.Second}, nil, nil, true},
|
{"missing encoder", &Options{Name: "_cookie", Secure: true, HTTPOnly: true, Domain: "pomerium.io", Expire: 10 * time.Second}, nil, nil, true},
|
||||||
}
|
}
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
@ -77,7 +78,7 @@ func TestNewCookieLoader(t *testing.T) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
cmpOpts := []cmp.Option{
|
cmpOpts := []cmp.Option{
|
||||||
cmpopts.IgnoreUnexported(CookieStore{}),
|
cmpopts.IgnoreUnexported(Store{}),
|
||||||
}
|
}
|
||||||
|
|
||||||
if diff := cmp.Diff(got, tt.want, cmpOpts...); diff != "" {
|
if diff := cmp.Diff(got, tt.want, cmpOpts...); diff != "" {
|
||||||
|
@ -87,7 +88,7 @@ func TestNewCookieLoader(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestCookieStore_SaveSession(t *testing.T) {
|
func TestStore_SaveSession(t *testing.T) {
|
||||||
c, err := cryptutil.NewAEADCipher(cryptutil.NewKey())
|
c, err := cryptutil.NewAEADCipher(cryptutil.NewKey())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
|
@ -106,17 +107,17 @@ func TestCookieStore_SaveSession(t *testing.T) {
|
||||||
wantErr bool
|
wantErr bool
|
||||||
wantLoadErr bool
|
wantLoadErr bool
|
||||||
}{
|
}{
|
||||||
{"good", &State{Email: "user@domain.com", User: "user"}, ecjson.New(c), ecjson.New(c), false, false},
|
{"good", &sessions.State{Email: "user@domain.com", User: "user"}, ecjson.New(c), ecjson.New(c), false, false},
|
||||||
{"bad cipher", &State{Email: "user@domain.com", User: "user"}, nil, nil, true, true},
|
{"bad cipher", &sessions.State{Email: "user@domain.com", User: "user"}, nil, nil, true, true},
|
||||||
{"huge cookie", &State{Subject: fmt.Sprintf("%x", hugeString), Email: "user@domain.com", User: "user"}, ecjson.New(c), ecjson.New(c), false, false},
|
{"huge cookie", &sessions.State{Subject: fmt.Sprintf("%x", hugeString), Email: "user@domain.com", User: "user"}, ecjson.New(c), ecjson.New(c), false, false},
|
||||||
{"marshal error", &State{Email: "user@domain.com", User: "user"}, mock.Encoder{MarshalError: errors.New("error")}, ecjson.New(c), true, true},
|
{"marshal error", &sessions.State{Email: "user@domain.com", User: "user"}, mock.Encoder{MarshalError: errors.New("error")}, ecjson.New(c), true, true},
|
||||||
{"nil encoder cannot save non string type", &State{Email: "user@domain.com", User: "user"}, nil, ecjson.New(c), true, true},
|
{"nil encoder cannot save non string type", &sessions.State{Email: "user@domain.com", User: "user"}, nil, ecjson.New(c), true, true},
|
||||||
{"good marshal string directly", cryptutil.NewBase64Key(), nil, ecjson.New(c), false, true},
|
{"good marshal string directly", cryptutil.NewBase64Key(), nil, ecjson.New(c), false, true},
|
||||||
{"good marshal bytes directly", cryptutil.NewKey(), nil, ecjson.New(c), false, true},
|
{"good marshal bytes directly", cryptutil.NewKey(), nil, ecjson.New(c), false, true},
|
||||||
}
|
}
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
s := &CookieStore{
|
s := &Store{
|
||||||
Name: "_pomerium",
|
Name: "_pomerium",
|
||||||
Secure: true,
|
Secure: true,
|
||||||
HTTPOnly: true,
|
HTTPOnly: true,
|
||||||
|
@ -130,7 +131,7 @@ func TestCookieStore_SaveSession(t *testing.T) {
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
if err := s.SaveSession(w, r, tt.State); (err != nil) != tt.wantErr {
|
if err := s.SaveSession(w, r, tt.State); (err != nil) != tt.wantErr {
|
||||||
t.Errorf("CookieStore.SaveSession() error = %v, wantErr %v", err, tt.wantErr)
|
t.Errorf("Store.SaveSession() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
}
|
}
|
||||||
r = httptest.NewRequest("GET", "/", nil)
|
r = httptest.NewRequest("GET", "/", nil)
|
||||||
for _, cookie := range w.Result().Cookies() {
|
for _, cookie := range w.Result().Cookies() {
|
||||||
|
@ -143,11 +144,11 @@ func TestCookieStore_SaveSession(t *testing.T) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
cmpOpts := []cmp.Option{
|
cmpOpts := []cmp.Option{
|
||||||
cmpopts.IgnoreUnexported(State{}),
|
cmpopts.IgnoreUnexported(sessions.State{}),
|
||||||
}
|
}
|
||||||
if err == nil {
|
if err == nil {
|
||||||
if diff := cmp.Diff(state, tt.State, cmpOpts...); diff != "" {
|
if diff := cmp.Diff(state, tt.State, cmpOpts...); diff != "" {
|
||||||
t.Errorf("CookieStore.LoadSession() got = %s", diff)
|
t.Errorf("Store.LoadSession() got = %s", diff)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
w = httptest.NewRecorder()
|
w = httptest.NewRecorder()
|
90
internal/sessions/cookie/middleware_test.go
Normal file
90
internal/sessions/cookie/middleware_test.go
Normal file
|
@ -0,0 +1,90 @@
|
||||||
|
package cookie // import "github.com/pomerium/pomerium/internal/sessions/cookie"
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/pomerium/pomerium/internal/sessions"
|
||||||
|
|
||||||
|
"github.com/google/go-cmp/cmp"
|
||||||
|
"github.com/pomerium/pomerium/internal/cryptutil"
|
||||||
|
"github.com/pomerium/pomerium/internal/encoding/ecjson"
|
||||||
|
"gopkg.in/square/go-jose.v2/jwt"
|
||||||
|
)
|
||||||
|
|
||||||
|
func testAuthorizer(next http.Handler) http.Handler {
|
||||||
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
_, err := sessions.FromContext(r.Context())
|
||||||
|
if err != nil {
|
||||||
|
http.Error(w, err.Error(), http.StatusUnauthorized)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
next.ServeHTTP(w, r)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestVerifier(t *testing.T) {
|
||||||
|
fnh := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Header().Set("Content-Type", "text/plain; charset=utf-8")
|
||||||
|
fmt.Fprint(w, http.StatusText(http.StatusOK))
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
})
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
state sessions.State
|
||||||
|
|
||||||
|
wantBody string
|
||||||
|
wantStatus int
|
||||||
|
}{
|
||||||
|
{"good cookie session", sessions.State{Email: "user@pomerium.io", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}, http.StatusText(http.StatusOK), http.StatusOK},
|
||||||
|
{"expired cookie", sessions.State{Email: "user@pomerium.io", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}, "internal/sessions: validation failed, token is expired (exp)\n", http.StatusUnauthorized},
|
||||||
|
{"malformed cookie", sessions.State{Email: "user@pomerium.io", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}, "internal/sessions: session is malformed\n", http.StatusUnauthorized},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
cipher, err := cryptutil.NewAEADCipherFromBase64(cryptutil.NewBase64Key())
|
||||||
|
encoder := ecjson.New(cipher)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
encSession, err := encoder.Marshal(&tt.state)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if strings.Contains(tt.name, "malformed") {
|
||||||
|
// add some garbage to the end of the string
|
||||||
|
encSession = append(encSession, cryptutil.NewKey()...)
|
||||||
|
}
|
||||||
|
|
||||||
|
cs, err := NewStore(&Options{
|
||||||
|
Name: "_pomerium",
|
||||||
|
}, encoder)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
r := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||||
|
r.Header.Set("Accept", "application/json")
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
r.AddCookie(&http.Cookie{Name: "_pomerium", Value: string(encSession)})
|
||||||
|
|
||||||
|
got := sessions.RetrieveSession(cs)(testAuthorizer((fnh)))
|
||||||
|
got.ServeHTTP(w, r)
|
||||||
|
|
||||||
|
gotBody := w.Body.String()
|
||||||
|
gotStatus := w.Result().StatusCode
|
||||||
|
|
||||||
|
if diff := cmp.Diff(gotBody, tt.wantBody); diff != "" {
|
||||||
|
t.Errorf("RetrieveSession() = %v", diff)
|
||||||
|
}
|
||||||
|
if diff := cmp.Diff(gotStatus, tt.wantStatus); diff != "" {
|
||||||
|
t.Errorf("RetrieveSession() = %v", diff)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
28
internal/sessions/errors.go
Normal file
28
internal/sessions/errors.go
Normal file
|
@ -0,0 +1,28 @@
|
||||||
|
package sessions // import "github.com/pomerium/pomerium/internal/sessions"
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
// ErrNoSessionFound is the error for when no session is found.
|
||||||
|
ErrNoSessionFound = errors.New("internal/sessions: session is not found")
|
||||||
|
|
||||||
|
// ErrMalformed is the error for when a session is found but is malformed.
|
||||||
|
ErrMalformed = errors.New("internal/sessions: session is malformed")
|
||||||
|
|
||||||
|
// ErrNotValidYet indicates that token is used before time indicated in nbf claim.
|
||||||
|
ErrNotValidYet = errors.New("internal/sessions: validation failed, token not valid yet (nbf)")
|
||||||
|
|
||||||
|
// ErrExpired indicates that token is used after expiry time indicated in exp claim.
|
||||||
|
ErrExpired = errors.New("internal/sessions: validation failed, token is expired (exp)")
|
||||||
|
|
||||||
|
// ErrExpiryRequired indicates that the token does not contain a valid expiry (exp) claim.
|
||||||
|
ErrExpiryRequired = errors.New("internal/sessions: validation failed, token expiry (exp) is required")
|
||||||
|
|
||||||
|
// ErrIssuedInTheFuture indicates that the iat field is in the future.
|
||||||
|
ErrIssuedInTheFuture = errors.New("internal/sessions: validation field, token issued in the future (iat)")
|
||||||
|
|
||||||
|
// ErrInvalidAudience indicated invalid aud claim.
|
||||||
|
ErrInvalidAudience = errors.New("internal/sessions: validation failed, invalid audience claim (aud)")
|
||||||
|
)
|
|
@ -1,35 +1,38 @@
|
||||||
package sessions // import "github.com/pomerium/pomerium/internal/sessions"
|
package header // import "github.com/pomerium/pomerium/internal/sessions/header"
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/pomerium/pomerium/internal/encoding"
|
"github.com/pomerium/pomerium/internal/encoding"
|
||||||
|
"github.com/pomerium/pomerium/internal/sessions"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
var _ sessions.SessionLoader = &Store{}
|
||||||
|
|
||||||
const (
|
const (
|
||||||
defaultAuthHeader = "Authorization"
|
defaultAuthHeader = "Authorization"
|
||||||
defaultAuthType = "Bearer"
|
defaultAuthType = "Bearer"
|
||||||
)
|
)
|
||||||
|
|
||||||
// HeaderStore implements the load session store interface using http
|
// Store implements the load session store interface using http
|
||||||
// authorization headers.
|
// authorization headers.
|
||||||
type HeaderStore struct {
|
type Store struct {
|
||||||
authHeader string
|
authHeader string
|
||||||
authType string
|
authType string
|
||||||
encoder encoding.Unmarshaler
|
encoder encoding.Unmarshaler
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewHeaderStore returns a new header store for loading sessions from
|
// NewStore returns a new header store for loading sessions from
|
||||||
// authorization header as defined in as defined in rfc2617
|
// authorization header as defined in as defined in rfc2617
|
||||||
//
|
//
|
||||||
// NOTA BENE: While most servers do not log Authorization headers by default,
|
// NOTA BENE: While most servers do not log Authorization headers by default,
|
||||||
// you should ensure no other services are logging or leaking your auth headers.
|
// you should ensure no other services are logging or leaking your auth headers.
|
||||||
func NewHeaderStore(enc encoding.Unmarshaler, headerType string) *HeaderStore {
|
func NewStore(enc encoding.Unmarshaler, headerType string) *Store {
|
||||||
if headerType == "" {
|
if headerType == "" {
|
||||||
headerType = defaultAuthType
|
headerType = defaultAuthType
|
||||||
}
|
}
|
||||||
return &HeaderStore{
|
return &Store{
|
||||||
authHeader: defaultAuthHeader,
|
authHeader: defaultAuthHeader,
|
||||||
authType: headerType,
|
authType: headerType,
|
||||||
encoder: enc,
|
encoder: enc,
|
||||||
|
@ -37,14 +40,14 @@ func NewHeaderStore(enc encoding.Unmarshaler, headerType string) *HeaderStore {
|
||||||
}
|
}
|
||||||
|
|
||||||
// LoadSession tries to retrieve the token string from the Authorization header.
|
// LoadSession tries to retrieve the token string from the Authorization header.
|
||||||
func (as *HeaderStore) LoadSession(r *http.Request) (*State, error) {
|
func (as *Store) LoadSession(r *http.Request) (*sessions.State, error) {
|
||||||
cipherText := TokenFromHeader(r, as.authHeader, as.authType)
|
cipherText := TokenFromHeader(r, as.authHeader, as.authType)
|
||||||
if cipherText == "" {
|
if cipherText == "" {
|
||||||
return nil, ErrNoSessionFound
|
return nil, sessions.ErrNoSessionFound
|
||||||
}
|
}
|
||||||
var session State
|
var session sessions.State
|
||||||
if err := as.encoder.Unmarshal([]byte(cipherText), &session); err != nil {
|
if err := as.encoder.Unmarshal([]byte(cipherText), &session); err != nil {
|
||||||
return nil, ErrMalformed
|
return nil, sessions.ErrMalformed
|
||||||
}
|
}
|
||||||
return &session, nil
|
return &session, nil
|
||||||
}
|
}
|
90
internal/sessions/header/middleware_test.go
Normal file
90
internal/sessions/header/middleware_test.go
Normal file
|
@ -0,0 +1,90 @@
|
||||||
|
package header // import "github.com/pomerium/pomerium/internal/sessions/header"
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/pomerium/pomerium/internal/cryptutil"
|
||||||
|
"github.com/pomerium/pomerium/internal/encoding/ecjson"
|
||||||
|
"github.com/pomerium/pomerium/internal/sessions"
|
||||||
|
|
||||||
|
"github.com/google/go-cmp/cmp"
|
||||||
|
"gopkg.in/square/go-jose.v2/jwt"
|
||||||
|
)
|
||||||
|
|
||||||
|
func testAuthorizer(next http.Handler) http.Handler {
|
||||||
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
_, err := sessions.FromContext(r.Context())
|
||||||
|
if err != nil {
|
||||||
|
http.Error(w, err.Error(), http.StatusUnauthorized)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
next.ServeHTTP(w, r)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestVerifier(t *testing.T) {
|
||||||
|
fnh := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Header().Set("Content-Type", "text/plain; charset=utf-8")
|
||||||
|
fmt.Fprint(w, http.StatusText(http.StatusOK))
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
})
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
authType string
|
||||||
|
state sessions.State
|
||||||
|
wantBody string
|
||||||
|
wantStatus int
|
||||||
|
}{
|
||||||
|
{"good auth header session", "Bearer ", sessions.State{Email: "user@pomerium.io", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}, http.StatusText(http.StatusOK), http.StatusOK},
|
||||||
|
{"expired auth header", "Bearer ", sessions.State{Email: "user@pomerium.io", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}, "internal/sessions: validation failed, token is expired (exp)\n", http.StatusUnauthorized},
|
||||||
|
{"malformed auth header", "Bearer ", sessions.State{Email: "user@pomerium.io", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}, "internal/sessions: session is malformed\n", http.StatusUnauthorized},
|
||||||
|
{"empty auth header", "Bearer ", sessions.State{Email: "user@pomerium.io", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}, "internal/sessions: session is not found\n", http.StatusUnauthorized},
|
||||||
|
{"bad auth type", "bees ", sessions.State{Email: "user@pomerium.io", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}, "internal/sessions: session is not found\n", http.StatusUnauthorized},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
cipher, err := cryptutil.NewAEADCipherFromBase64(cryptutil.NewBase64Key())
|
||||||
|
encoder := ecjson.New(cipher)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
encSession, err := encoder.Marshal(&tt.state)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if strings.Contains(tt.name, "malformed") {
|
||||||
|
// add some garbage to the end of the string
|
||||||
|
encSession = append(encSession, cryptutil.NewKey()...)
|
||||||
|
}
|
||||||
|
s := NewStore(encoder, "")
|
||||||
|
|
||||||
|
r := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||||
|
r.Header.Set("Accept", "application/json")
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
if strings.Contains(tt.name, "empty") {
|
||||||
|
encSession = []byte("")
|
||||||
|
}
|
||||||
|
r.Header.Set("Authorization", tt.authType+string(encSession))
|
||||||
|
|
||||||
|
got := sessions.RetrieveSession(s)(testAuthorizer((fnh)))
|
||||||
|
got.ServeHTTP(w, r)
|
||||||
|
|
||||||
|
gotBody := w.Body.String()
|
||||||
|
gotStatus := w.Result().StatusCode
|
||||||
|
|
||||||
|
if diff := cmp.Diff(gotBody, tt.wantBody); diff != "" {
|
||||||
|
t.Errorf("RetrieveSession() = %v", diff)
|
||||||
|
}
|
||||||
|
if diff := cmp.Diff(gotStatus, tt.wantStatus); diff != "" {
|
||||||
|
t.Errorf("RetrieveSession() = %v", diff)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
|
@ -44,7 +44,7 @@ func retrieveFromRequest(r *http.Request, sessions ...SessionLoader) (*State, er
|
||||||
}
|
}
|
||||||
if state != nil {
|
if state != nil {
|
||||||
err := state.Verify(urlutil.StripPort(r.Host))
|
err := state.Verify(urlutil.StripPort(r.Host))
|
||||||
return state, err // N.B.: state is _not_ nil_
|
return state, err // N.B.: state is _not_ nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -2,16 +2,14 @@ package sessions
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
"strings"
|
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/google/go-cmp/cmp"
|
"github.com/google/go-cmp/cmp"
|
||||||
"github.com/pomerium/pomerium/internal/cryptutil"
|
|
||||||
"github.com/pomerium/pomerium/internal/encoding/ecjson"
|
|
||||||
"gopkg.in/square/go-jose.v2/jwt"
|
"gopkg.in/square/go-jose.v2/jwt"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -39,103 +37,6 @@ func TestNewContext(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func testAuthorizer(next http.Handler) http.Handler {
|
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
_, err := FromContext(r.Context())
|
|
||||||
if err != nil {
|
|
||||||
http.Error(w, err.Error(), http.StatusUnauthorized)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
next.ServeHTTP(w, r)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestVerifier(t *testing.T) {
|
|
||||||
fnh := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
w.Header().Set("Content-Type", "text/plain; charset=utf-8")
|
|
||||||
fmt.Fprint(w, http.StatusText(http.StatusOK))
|
|
||||||
w.WriteHeader(http.StatusOK)
|
|
||||||
})
|
|
||||||
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
// s SessionStore
|
|
||||||
state State
|
|
||||||
|
|
||||||
cookie bool
|
|
||||||
header bool
|
|
||||||
param bool
|
|
||||||
|
|
||||||
wantBody string
|
|
||||||
wantStatus int
|
|
||||||
}{
|
|
||||||
{"good cookie session", State{Email: "user@pomerium.io", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}, true, false, false, http.StatusText(http.StatusOK), http.StatusOK},
|
|
||||||
{"expired cookie", State{Email: "user@pomerium.io", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}, true, false, false, "internal/sessions: validation failed, token is expired (exp)\n", http.StatusUnauthorized},
|
|
||||||
{"malformed cookie", State{Email: "user@pomerium.io", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}, true, false, false, "internal/sessions: session is malformed\n", http.StatusUnauthorized},
|
|
||||||
{"good auth header session", State{Email: "user@pomerium.io", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}, false, true, false, http.StatusText(http.StatusOK), http.StatusOK},
|
|
||||||
{"expired auth header", State{Email: "user@pomerium.io", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}, false, true, false, "internal/sessions: validation failed, token is expired (exp)\n", http.StatusUnauthorized},
|
|
||||||
{"malformed auth header", State{Email: "user@pomerium.io", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}, false, true, false, "internal/sessions: session is malformed\n", http.StatusUnauthorized},
|
|
||||||
{"good auth query param session", State{Email: "user@pomerium.io", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}, false, true, true, http.StatusText(http.StatusOK), http.StatusOK},
|
|
||||||
{"expired auth query param", State{Email: "user@pomerium.io", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}, false, false, true, "internal/sessions: validation failed, token is expired (exp)\n", http.StatusUnauthorized},
|
|
||||||
{"malformed auth query param", State{Email: "user@pomerium.io", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}, false, false, true, "internal/sessions: session is malformed\n", http.StatusUnauthorized},
|
|
||||||
{"no session", State{Email: "user@pomerium.io", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}, false, false, false, "internal/sessions: session is not found\n", http.StatusUnauthorized},
|
|
||||||
}
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
cipher, err := cryptutil.NewAEADCipherFromBase64(cryptutil.NewBase64Key())
|
|
||||||
encoder := ecjson.New(cipher)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
encSession, err := encoder.Marshal(&tt.state)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
if strings.Contains(tt.name, "malformed") {
|
|
||||||
// add some garbage to the end of the string
|
|
||||||
encSession = append(encSession, cryptutil.NewKey()...)
|
|
||||||
}
|
|
||||||
|
|
||||||
cs, err := NewCookieStore(&CookieOptions{
|
|
||||||
Name: "_pomerium",
|
|
||||||
}, encoder)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
as := NewHeaderStore(encoder, "")
|
|
||||||
|
|
||||||
qp := NewQueryParamStore(encoder, "")
|
|
||||||
|
|
||||||
r := httptest.NewRequest(http.MethodGet, "/", nil)
|
|
||||||
r.Header.Set("Accept", "application/json")
|
|
||||||
w := httptest.NewRecorder()
|
|
||||||
if tt.cookie {
|
|
||||||
r.AddCookie(&http.Cookie{Name: "_pomerium", Value: string(encSession)})
|
|
||||||
} else if tt.header {
|
|
||||||
r.Header.Set("Authorization", "Bearer "+string(encSession))
|
|
||||||
} else if tt.param {
|
|
||||||
q := r.URL.Query()
|
|
||||||
|
|
||||||
q.Set("pomerium_session", string(encSession))
|
|
||||||
r.URL.RawQuery = q.Encode()
|
|
||||||
}
|
|
||||||
|
|
||||||
got := RetrieveSession(cs, as, qp)(testAuthorizer((fnh)))
|
|
||||||
got.ServeHTTP(w, r)
|
|
||||||
|
|
||||||
gotBody := w.Body.String()
|
|
||||||
gotStatus := w.Result().StatusCode
|
|
||||||
|
|
||||||
if diff := cmp.Diff(gotBody, tt.wantBody); diff != "" {
|
|
||||||
t.Errorf("RetrieveSession() = %v", diff)
|
|
||||||
}
|
|
||||||
if diff := cmp.Diff(gotStatus, tt.wantStatus); diff != "" {
|
|
||||||
t.Errorf("RetrieveSession() = %v", diff)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func Test_contextKey_String(t *testing.T) {
|
func Test_contextKey_String(t *testing.T) {
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
|
@ -155,3 +56,80 @@ func Test_contextKey_String(t *testing.T) {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func testAuthorizer(next http.Handler) http.Handler {
|
||||||
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
_, err := FromContext(r.Context())
|
||||||
|
if err != nil {
|
||||||
|
http.Error(w, err.Error(), http.StatusUnauthorized)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
next.ServeHTTP(w, r)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
var _ SessionStore = &store{}
|
||||||
|
|
||||||
|
// Store is a mock implementation of the SessionStore interface
|
||||||
|
type store struct {
|
||||||
|
ResponseSession string
|
||||||
|
Session *State
|
||||||
|
SaveError error
|
||||||
|
LoadError error
|
||||||
|
}
|
||||||
|
|
||||||
|
// ClearSession clears the ResponseSession
|
||||||
|
func (ms *store) ClearSession(http.ResponseWriter, *http.Request) {
|
||||||
|
ms.ResponseSession = ""
|
||||||
|
}
|
||||||
|
|
||||||
|
// LoadSession returns the session and a error
|
||||||
|
func (ms store) LoadSession(*http.Request) (*State, error) {
|
||||||
|
return ms.Session, ms.LoadError
|
||||||
|
}
|
||||||
|
|
||||||
|
// SaveSession returns a save error.
|
||||||
|
func (ms store) SaveSession(http.ResponseWriter, *http.Request, interface{}) error {
|
||||||
|
return ms.SaveError
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestVerifier(t *testing.T) {
|
||||||
|
fnh := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Header().Set("Content-Type", "text/plain; charset=utf-8")
|
||||||
|
fmt.Fprint(w, http.StatusText(http.StatusOK))
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
})
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
store store
|
||||||
|
state State
|
||||||
|
wantBody string
|
||||||
|
wantStatus int
|
||||||
|
}{
|
||||||
|
{"empty session", store{}, State{}, "internal/sessions: session is not found\n", 401},
|
||||||
|
{"simple good load", store{Session: &State{Subject: "hi", Expiry: jwt.NewNumericDate(time.Now().Add(time.Second))}}, State{}, "OK", 200},
|
||||||
|
{"empty session", store{LoadError: errors.New("err")}, State{}, "err\n", 401},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
|
||||||
|
r := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||||
|
r.Header.Set("Accept", "application/json")
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
got := RetrieveSession(tt.store)(testAuthorizer((fnh)))
|
||||||
|
got.ServeHTTP(w, r)
|
||||||
|
|
||||||
|
gotBody := w.Body.String()
|
||||||
|
gotStatus := w.Result().StatusCode
|
||||||
|
|
||||||
|
if diff := cmp.Diff(gotBody, tt.wantBody); diff != "" {
|
||||||
|
t.Errorf("RetrieveSession() = %v", diff)
|
||||||
|
}
|
||||||
|
if diff := cmp.Diff(gotStatus, tt.wantStatus); diff != "" {
|
||||||
|
t.Errorf("RetrieveSession() = %v", diff)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
33
internal/sessions/mock/mock_store.go
Normal file
33
internal/sessions/mock/mock_store.go
Normal file
|
@ -0,0 +1,33 @@
|
||||||
|
package mock // import "github.com/pomerium/pomerium/internal/sessions/mock"
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
"github.com/pomerium/pomerium/internal/sessions"
|
||||||
|
)
|
||||||
|
|
||||||
|
var _ sessions.SessionStore = &Store{}
|
||||||
|
var _ sessions.SessionLoader = &Store{}
|
||||||
|
|
||||||
|
// Store is a mock implementation of the SessionStore interface
|
||||||
|
type Store struct {
|
||||||
|
ResponseSession string
|
||||||
|
Session *sessions.State
|
||||||
|
SaveError error
|
||||||
|
LoadError error
|
||||||
|
}
|
||||||
|
|
||||||
|
// ClearSession clears the ResponseSession
|
||||||
|
func (ms *Store) ClearSession(http.ResponseWriter, *http.Request) {
|
||||||
|
ms.ResponseSession = ""
|
||||||
|
}
|
||||||
|
|
||||||
|
// LoadSession returns the session and a error
|
||||||
|
func (ms Store) LoadSession(*http.Request) (*sessions.State, error) {
|
||||||
|
return ms.Session, ms.LoadError
|
||||||
|
}
|
||||||
|
|
||||||
|
// SaveSession returns a save error.
|
||||||
|
func (ms Store) SaveSession(http.ResponseWriter, *http.Request, interface{}) error {
|
||||||
|
return ms.SaveError
|
||||||
|
}
|
|
@ -1,26 +1,28 @@
|
||||||
package sessions // import "github.com/pomerium/pomerium/internal/sessions"
|
package mock // import "github.com/pomerium/pomerium/internal/sessions/mock"
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"reflect"
|
"reflect"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"github.com/pomerium/pomerium/internal/sessions"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestMockSessionStore(t *testing.T) {
|
func TestStore(t *testing.T) {
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
mockCSRF *MockSessionStore
|
mockCSRF *Store
|
||||||
saveSession *State
|
saveSession *sessions.State
|
||||||
wantLoadErr bool
|
wantLoadErr bool
|
||||||
wantSaveErr bool
|
wantSaveErr bool
|
||||||
}{
|
}{
|
||||||
{"basic",
|
{"basic",
|
||||||
&MockSessionStore{
|
&Store{
|
||||||
ResponseSession: "test",
|
ResponseSession: "test",
|
||||||
Session: &State{Subject: "0101"},
|
Session: &sessions.State{Subject: "0101"},
|
||||||
SaveError: nil,
|
SaveError: nil,
|
||||||
LoadError: nil,
|
LoadError: nil,
|
||||||
},
|
},
|
||||||
&State{Subject: "0101"},
|
&sessions.State{Subject: "0101"},
|
||||||
false,
|
false,
|
||||||
false},
|
false},
|
||||||
}
|
}
|
|
@ -1,28 +0,0 @@
|
||||||
package sessions // import "github.com/pomerium/pomerium/internal/sessions"
|
|
||||||
|
|
||||||
import (
|
|
||||||
"net/http"
|
|
||||||
)
|
|
||||||
|
|
||||||
// MockSessionStore is a mock implementation of the SessionStore interface
|
|
||||||
type MockSessionStore struct {
|
|
||||||
ResponseSession string
|
|
||||||
Session *State
|
|
||||||
SaveError error
|
|
||||||
LoadError error
|
|
||||||
}
|
|
||||||
|
|
||||||
// ClearSession clears the ResponseSession
|
|
||||||
func (ms *MockSessionStore) ClearSession(http.ResponseWriter, *http.Request) {
|
|
||||||
ms.ResponseSession = ""
|
|
||||||
}
|
|
||||||
|
|
||||||
// LoadSession returns the session and a error
|
|
||||||
func (ms MockSessionStore) LoadSession(*http.Request) (*State, error) {
|
|
||||||
return ms.Session, ms.LoadError
|
|
||||||
}
|
|
||||||
|
|
||||||
// SaveSession returns a save error.
|
|
||||||
func (ms MockSessionStore) SaveSession(http.ResponseWriter, *http.Request, interface{}) error {
|
|
||||||
return ms.SaveError
|
|
||||||
}
|
|
92
internal/sessions/queryparam/middleware_test.go
Normal file
92
internal/sessions/queryparam/middleware_test.go
Normal file
|
@ -0,0 +1,92 @@
|
||||||
|
package queryparam // import "github.com/pomerium/pomerium/internal/sessions/queryparam"
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/pomerium/pomerium/internal/cryptutil"
|
||||||
|
"github.com/pomerium/pomerium/internal/encoding/ecjson"
|
||||||
|
"github.com/pomerium/pomerium/internal/sessions"
|
||||||
|
|
||||||
|
"github.com/google/go-cmp/cmp"
|
||||||
|
"gopkg.in/square/go-jose.v2/jwt"
|
||||||
|
)
|
||||||
|
|
||||||
|
func testAuthorizer(next http.Handler) http.Handler {
|
||||||
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
_, err := sessions.FromContext(r.Context())
|
||||||
|
if err != nil {
|
||||||
|
http.Error(w, err.Error(), http.StatusUnauthorized)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
next.ServeHTTP(w, r)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestVerifier(t *testing.T) {
|
||||||
|
fnh := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Header().Set("Content-Type", "text/plain; charset=utf-8")
|
||||||
|
fmt.Fprint(w, http.StatusText(http.StatusOK))
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
})
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
state sessions.State
|
||||||
|
|
||||||
|
wantBody string
|
||||||
|
wantStatus int
|
||||||
|
}{
|
||||||
|
{"good auth query param session", sessions.State{Email: "user@pomerium.io", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}, http.StatusText(http.StatusOK), http.StatusOK},
|
||||||
|
{"expired auth query param", sessions.State{Email: "user@pomerium.io", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}, "internal/sessions: validation failed, token is expired (exp)\n", http.StatusUnauthorized},
|
||||||
|
{"malformed auth query param", sessions.State{Email: "user@pomerium.io", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}, "internal/sessions: session is malformed\n", http.StatusUnauthorized},
|
||||||
|
{"empty auth query param", sessions.State{Email: "user@pomerium.io", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}, "internal/sessions: session is not found\n", http.StatusUnauthorized},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
cipher, err := cryptutil.NewAEADCipherFromBase64(cryptutil.NewBase64Key())
|
||||||
|
encoder := ecjson.New(cipher)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
encSession, err := encoder.Marshal(&tt.state)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if strings.Contains(tt.name, "malformed") {
|
||||||
|
// add some garbage to the end of the string
|
||||||
|
encSession = append(encSession, cryptutil.NewKey()...)
|
||||||
|
}
|
||||||
|
|
||||||
|
s := NewStore(encoder, "")
|
||||||
|
|
||||||
|
r := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||||
|
r.Header.Set("Accept", "application/json")
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
q := r.URL.Query()
|
||||||
|
if strings.Contains(tt.name, "empty") {
|
||||||
|
encSession = []byte("")
|
||||||
|
}
|
||||||
|
q.Set("pomerium_session", string(encSession))
|
||||||
|
r.URL.RawQuery = q.Encode()
|
||||||
|
|
||||||
|
got := sessions.RetrieveSession(s)(testAuthorizer((fnh)))
|
||||||
|
got.ServeHTTP(w, r)
|
||||||
|
|
||||||
|
gotBody := w.Body.String()
|
||||||
|
gotStatus := w.Result().StatusCode
|
||||||
|
|
||||||
|
if diff := cmp.Diff(gotBody, tt.wantBody); diff != "" {
|
||||||
|
t.Errorf("RetrieveSession() = %v", diff)
|
||||||
|
}
|
||||||
|
if diff := cmp.Diff(gotStatus, tt.wantStatus); diff != "" {
|
||||||
|
t.Errorf("RetrieveSession() = %v", diff)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
|
@ -1,33 +1,37 @@
|
||||||
package sessions // import "github.com/pomerium/pomerium/internal/sessions"
|
package queryparam // import "github.com/pomerium/pomerium/internal/sessions/queryparam"
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
"github.com/pomerium/pomerium/internal/encoding"
|
"github.com/pomerium/pomerium/internal/encoding"
|
||||||
|
"github.com/pomerium/pomerium/internal/sessions"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
var _ sessions.SessionStore = &Store{}
|
||||||
|
var _ sessions.SessionLoader = &Store{}
|
||||||
|
|
||||||
const (
|
const (
|
||||||
defaultQueryParamKey = "pomerium_session"
|
defaultQueryParamKey = "pomerium_session"
|
||||||
)
|
)
|
||||||
|
|
||||||
// QueryParamStore implements the load session store interface using http
|
// Store implements the load session store interface using http
|
||||||
// query strings / query parameters.
|
// query strings / query parameters.
|
||||||
type QueryParamStore struct {
|
type Store struct {
|
||||||
queryParamKey string
|
queryParamKey string
|
||||||
encoder encoding.Marshaler
|
encoder encoding.Marshaler
|
||||||
decoder encoding.Unmarshaler
|
decoder encoding.Unmarshaler
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewQueryParamStore returns a new query param store for loading sessions from
|
// NewStore returns a new query param store for loading sessions from
|
||||||
// query strings / query parameters.
|
// query strings / query parameters.
|
||||||
//
|
//
|
||||||
// NOTA BENE: By default, most servers _DO_ log query params, the leaking or
|
// NOTA BENE: By default, most servers _DO_ log query params, the leaking or
|
||||||
// accidental logging of which should be considered a security issue.
|
// accidental logging of which should be considered a security issue.
|
||||||
func NewQueryParamStore(enc encoding.MarshalUnmarshaler, qp string) *QueryParamStore {
|
func NewStore(enc encoding.MarshalUnmarshaler, qp string) *Store {
|
||||||
if qp == "" {
|
if qp == "" {
|
||||||
qp = defaultQueryParamKey
|
qp = defaultQueryParamKey
|
||||||
}
|
}
|
||||||
return &QueryParamStore{
|
return &Store{
|
||||||
queryParamKey: qp,
|
queryParamKey: qp,
|
||||||
encoder: enc,
|
encoder: enc,
|
||||||
decoder: enc,
|
decoder: enc,
|
||||||
|
@ -35,27 +39,27 @@ func NewQueryParamStore(enc encoding.MarshalUnmarshaler, qp string) *QueryParamS
|
||||||
}
|
}
|
||||||
|
|
||||||
// LoadSession tries to retrieve the token string from URL query parameters.
|
// LoadSession tries to retrieve the token string from URL query parameters.
|
||||||
func (qp *QueryParamStore) LoadSession(r *http.Request) (*State, error) {
|
func (qp *Store) LoadSession(r *http.Request) (*sessions.State, error) {
|
||||||
cipherText := r.URL.Query().Get(qp.queryParamKey)
|
cipherText := r.URL.Query().Get(qp.queryParamKey)
|
||||||
if cipherText == "" {
|
if cipherText == "" {
|
||||||
return nil, ErrNoSessionFound
|
return nil, sessions.ErrNoSessionFound
|
||||||
}
|
}
|
||||||
var session State
|
var session sessions.State
|
||||||
if err := qp.decoder.Unmarshal([]byte(cipherText), &session); err != nil {
|
if err := qp.decoder.Unmarshal([]byte(cipherText), &session); err != nil {
|
||||||
return nil, ErrMalformed
|
return nil, sessions.ErrMalformed
|
||||||
}
|
}
|
||||||
return &session, nil
|
return &session, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// ClearSession clears the session cookie from a request's query param key `pomerium_session`.
|
// ClearSession clears the session cookie from a request's query param key `pomerium_session`.
|
||||||
func (qp *QueryParamStore) ClearSession(w http.ResponseWriter, r *http.Request) {
|
func (qp *Store) ClearSession(w http.ResponseWriter, r *http.Request) {
|
||||||
params := r.URL.Query()
|
params := r.URL.Query()
|
||||||
params.Del(qp.queryParamKey)
|
params.Del(qp.queryParamKey)
|
||||||
r.URL.RawQuery = params.Encode()
|
r.URL.RawQuery = params.Encode()
|
||||||
}
|
}
|
||||||
|
|
||||||
// SaveSession sets a session to a request's query param key `pomerium_session`
|
// SaveSession sets a session to a request's query param key `pomerium_session`
|
||||||
func (qp *QueryParamStore) SaveSession(w http.ResponseWriter, r *http.Request, x interface{}) error {
|
func (qp *Store) SaveSession(w http.ResponseWriter, r *http.Request, x interface{}) error {
|
||||||
data, err := qp.encoder.Marshal(x)
|
data, err := qp.encoder.Marshal(x)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
|
@ -1,4 +1,4 @@
|
||||||
package sessions
|
package queryparam // import "github.com/pomerium/pomerium/internal/sessions/queryparam"
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
|
@ -9,39 +9,40 @@ import (
|
||||||
"github.com/google/go-cmp/cmp"
|
"github.com/google/go-cmp/cmp"
|
||||||
"github.com/pomerium/pomerium/internal/encoding"
|
"github.com/pomerium/pomerium/internal/encoding"
|
||||||
"github.com/pomerium/pomerium/internal/encoding/mock"
|
"github.com/pomerium/pomerium/internal/encoding/mock"
|
||||||
|
"github.com/pomerium/pomerium/internal/sessions"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestNewQueryParamStore(t *testing.T) {
|
func TestNewQueryParamStore(t *testing.T) {
|
||||||
|
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
State *State
|
State *sessions.State
|
||||||
|
|
||||||
enc encoding.MarshalUnmarshaler
|
enc encoding.MarshalUnmarshaler
|
||||||
qp string
|
qp string
|
||||||
wantErr bool
|
wantErr bool
|
||||||
wantURL *url.URL
|
wantURL *url.URL
|
||||||
}{
|
}{
|
||||||
{"simple good", &State{Email: "user@domain.com", User: "user"}, mock.Encoder{MarshalResponse: []byte("ok")}, "", false, &url.URL{Path: "/", RawQuery: "pomerium_session=ok"}},
|
{"simple good", &sessions.State{Email: "user@domain.com", User: "user"}, mock.Encoder{MarshalResponse: []byte("ok")}, "", false, &url.URL{Path: "/", RawQuery: "pomerium_session=ok"}},
|
||||||
{"marshall error", &State{Email: "user@domain.com", User: "user"}, mock.Encoder{MarshalError: errors.New("error")}, "", true, &url.URL{Path: "/"}},
|
{"marshall error", &sessions.State{Email: "user@domain.com", User: "user"}, mock.Encoder{MarshalError: errors.New("error")}, "", true, &url.URL{Path: "/"}},
|
||||||
}
|
}
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
got := NewQueryParamStore(tt.enc, tt.qp)
|
got := NewStore(tt.enc, tt.qp)
|
||||||
|
|
||||||
r := httptest.NewRequest("GET", "/", nil)
|
r := httptest.NewRequest("GET", "/", nil)
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
if err := got.SaveSession(w, r, tt.State); (err != nil) != tt.wantErr {
|
if err := got.SaveSession(w, r, tt.State); (err != nil) != tt.wantErr {
|
||||||
t.Errorf("NewQueryParamStore.SaveSession() error = %v, wantErr %v", err, tt.wantErr)
|
t.Errorf("NewStore.SaveSession() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
}
|
}
|
||||||
|
|
||||||
if diff := cmp.Diff(r.URL, tt.wantURL); diff != "" {
|
if diff := cmp.Diff(r.URL, tt.wantURL); diff != "" {
|
||||||
t.Errorf("NewQueryParamStore() = %v", diff)
|
t.Errorf("NewStore() = %v", diff)
|
||||||
}
|
}
|
||||||
got.ClearSession(w, r)
|
got.ClearSession(w, r)
|
||||||
if diff := cmp.Diff(r.URL, &url.URL{Path: "/"}); diff != "" {
|
if diff := cmp.Diff(r.URL, &url.URL{Path: "/"}); diff != "" {
|
||||||
t.Errorf("NewQueryParamStore() = %v", diff)
|
t.Errorf("NewStore() = %v", diff)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
|
@ -6,6 +6,8 @@ import (
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/cespare/xxhash/v2"
|
||||||
|
"github.com/mitchellh/hashstructure"
|
||||||
oidc "github.com/pomerium/go-oidc"
|
oidc "github.com/pomerium/go-oidc"
|
||||||
"golang.org/x/oauth2"
|
"golang.org/x/oauth2"
|
||||||
"gopkg.in/square/go-jose.v2/jwt"
|
"gopkg.in/square/go-jose.v2/jwt"
|
||||||
|
@ -51,7 +53,8 @@ type State struct {
|
||||||
// programatic access.
|
// programatic access.
|
||||||
Programmatic bool `json:"programatic"`
|
Programmatic bool `json:"programatic"`
|
||||||
|
|
||||||
AccessToken *oauth2.Token `json:"access_token,omitempty"`
|
AccessToken *oauth2.Token `json:"act,omitempty"`
|
||||||
|
AccessTokenID string `json:"ati,omitempty"`
|
||||||
|
|
||||||
idToken *oidc.IDToken
|
idToken *oidc.IDToken
|
||||||
}
|
}
|
||||||
|
@ -73,7 +76,7 @@ func NewStateFromTokens(idToken *oidc.IDToken, accessToken *oauth2.Token, audien
|
||||||
s.Audience = []string{audience}
|
s.Audience = []string{audience}
|
||||||
s.idToken = idToken
|
s.idToken = idToken
|
||||||
s.AccessToken = accessToken
|
s.AccessToken = accessToken
|
||||||
|
s.AccessTokenID = s.accessTokenHash()
|
||||||
return s, nil
|
return s, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -95,6 +98,7 @@ func (s *State) UpdateState(idToken *oidc.IDToken, accessToken *oauth2.Token) er
|
||||||
}
|
}
|
||||||
s.Audience = audience
|
s.Audience = audience
|
||||||
s.Expiry = jwt.NewNumericDate(accessToken.Expiry)
|
s.Expiry = jwt.NewNumericDate(accessToken.Expiry)
|
||||||
|
s.AccessTokenID = s.accessTokenHash()
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -173,3 +177,13 @@ func (s *State) SetImpersonation(email, groups string) {
|
||||||
s.ImpersonateGroups = strings.Split(groups, ",")
|
s.ImpersonateGroups = strings.Split(groups, ",")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *State) accessTokenHash() string {
|
||||||
|
hash, err := hashstructure.Hash(
|
||||||
|
s.AccessToken,
|
||||||
|
&hashstructure.HashOptions{Hasher: xxhash.New()})
|
||||||
|
if err != nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return fmt.Sprintf("%x", hash)
|
||||||
|
}
|
||||||
|
|
|
@ -124,3 +124,26 @@ func TestState_RouteSession(t *testing.T) {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestState_accessTokenHash(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
state State
|
||||||
|
want string
|
||||||
|
}{
|
||||||
|
{"empty access token", State{}, "34c96acdcadb1bbb"},
|
||||||
|
{"no change to access token", State{Subject: "test"}, "34c96acdcadb1bbb"},
|
||||||
|
{"empty oauth2 token", State{AccessToken: &oauth2.Token{}}, "bbd82197d215198f"},
|
||||||
|
{"refresh token a", State{AccessToken: &oauth2.Token{RefreshToken: "a"}}, "76316ac79b301bd6"},
|
||||||
|
{"refresh token b", State{AccessToken: &oauth2.Token{RefreshToken: "b"}}, "fab7cb29e50161f1"},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
s := &tt.state
|
||||||
|
if got := s.accessTokenHash(); got != tt.want {
|
||||||
|
t.Errorf("State.accessTokenHash() = %v, want %v", got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -1,39 +1,17 @@
|
||||||
package sessions // import "github.com/pomerium/pomerium/internal/sessions"
|
package sessions // import "github.com/pomerium/pomerium/internal/sessions"
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
// SessionStore defines an interface for loading, saving, and clearing a session.
|
||||||
// ErrNoSessionFound is the error for when no session is found.
|
|
||||||
ErrNoSessionFound = errors.New("internal/sessions: session is not found")
|
|
||||||
|
|
||||||
// ErrMalformed is the error for when a session is found but is malformed.
|
|
||||||
ErrMalformed = errors.New("internal/sessions: session is malformed")
|
|
||||||
|
|
||||||
// ErrNotValidYet indicates that token is used before time indicated in nbf claim.
|
|
||||||
ErrNotValidYet = errors.New("internal/sessions: validation failed, token not valid yet (nbf)")
|
|
||||||
|
|
||||||
// ErrExpired indicates that token is used after expiry time indicated in exp claim.
|
|
||||||
ErrExpired = errors.New("internal/sessions: validation failed, token is expired (exp)")
|
|
||||||
|
|
||||||
// ErrIssuedInTheFuture indicates that the iat field is in the future.
|
|
||||||
ErrIssuedInTheFuture = errors.New("internal/sessions: validation field, token issued in the future (iat)")
|
|
||||||
|
|
||||||
// ErrInvalidAudience indicated invalid aud claim.
|
|
||||||
ErrInvalidAudience = errors.New("internal/sessions: validation failed, invalid audience claim (aud)")
|
|
||||||
)
|
|
||||||
|
|
||||||
// SessionStore has the functions for setting, getting, and clearing the Session cookie
|
|
||||||
type SessionStore interface {
|
type SessionStore interface {
|
||||||
ClearSession(http.ResponseWriter, *http.Request)
|
|
||||||
SessionLoader
|
SessionLoader
|
||||||
|
ClearSession(http.ResponseWriter, *http.Request)
|
||||||
SaveSession(http.ResponseWriter, *http.Request, interface{}) error
|
SaveSession(http.ResponseWriter, *http.Request, interface{}) error
|
||||||
}
|
}
|
||||||
|
|
||||||
// SessionLoader is implemented by any struct that loads a pomerium session
|
// SessionLoader defines an interface for loading a session.
|
||||||
// given a request, and returns a user state.
|
|
||||||
type SessionLoader interface {
|
type SessionLoader interface {
|
||||||
LoadSession(*http.Request) (*State, error)
|
LoadSession(*http.Request) (*State, error)
|
||||||
}
|
}
|
||||||
|
|
|
@ -34,10 +34,10 @@ var (
|
||||||
// DefaultViews are a set of default views to view HTTP and GRPC metrics.
|
// DefaultViews are a set of default views to view HTTP and GRPC metrics.
|
||||||
var (
|
var (
|
||||||
DefaultViews = [][]*view.View{
|
DefaultViews = [][]*view.View{
|
||||||
GRPCServerViews,
|
|
||||||
HTTPServerViews,
|
|
||||||
GRPCClientViews,
|
GRPCClientViews,
|
||||||
GRPCServerViews,
|
GRPCServerViews,
|
||||||
|
HTTPClientViews,
|
||||||
|
HTTPServerViews,
|
||||||
InfoViews,
|
InfoViews,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
|
@ -9,14 +9,16 @@ import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/google/go-cmp/cmp"
|
"github.com/google/go-cmp/cmp"
|
||||||
|
"gopkg.in/square/go-jose.v2/jwt"
|
||||||
|
|
||||||
"github.com/pomerium/pomerium/config"
|
"github.com/pomerium/pomerium/config"
|
||||||
"github.com/pomerium/pomerium/internal/encoding"
|
"github.com/pomerium/pomerium/internal/encoding"
|
||||||
"github.com/pomerium/pomerium/internal/encoding/mock"
|
"github.com/pomerium/pomerium/internal/encoding/mock"
|
||||||
"github.com/pomerium/pomerium/internal/httputil"
|
"github.com/pomerium/pomerium/internal/httputil"
|
||||||
"github.com/pomerium/pomerium/internal/sessions"
|
"github.com/pomerium/pomerium/internal/sessions"
|
||||||
|
mstore "github.com/pomerium/pomerium/internal/sessions/mock"
|
||||||
"github.com/pomerium/pomerium/internal/urlutil"
|
"github.com/pomerium/pomerium/internal/urlutil"
|
||||||
"github.com/pomerium/pomerium/proxy/clients"
|
"github.com/pomerium/pomerium/proxy/clients"
|
||||||
"gopkg.in/square/go-jose.v2/jwt"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestProxy_ForwardAuth(t *testing.T) {
|
func TestProxy_ForwardAuth(t *testing.T) {
|
||||||
|
@ -40,29 +42,29 @@ func TestProxy_ForwardAuth(t *testing.T) {
|
||||||
wantStatus int
|
wantStatus int
|
||||||
wantBody string
|
wantBody string
|
||||||
}{
|
}{
|
||||||
{"good redirect not required", opts, nil, http.MethodGet, nil, nil, "https://some.domain.example/", "https://some.domain.example", &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusOK, "Access to some.domain.example is allowed."},
|
{"good redirect not required", opts, nil, http.MethodGet, nil, nil, "https://some.domain.example/", "https://some.domain.example", &mock.Encoder{}, &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusOK, "Access to some.domain.example is allowed."},
|
||||||
{"good verify only, no redirect", opts, nil, http.MethodGet, nil, nil, "https://some.domain.example/verify", "https://some.domain.example", &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusOK, ""},
|
{"good verify only, no redirect", opts, nil, http.MethodGet, nil, nil, "https://some.domain.example/verify", "https://some.domain.example", &mock.Encoder{}, &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusOK, ""},
|
||||||
{"good redirect not required", opts, nil, http.MethodGet, nil, nil, "/", "https://some.domain.example", &mock.Encoder{}, &sessions.MockSessionStore{LoadError: sessions.ErrInvalidAudience}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusUnauthorized, "{\"Status\":401,\"Error\":\"Unauthorized: internal/sessions: validation failed, invalid audience claim (aud)\"}\n"},
|
{"good redirect not required", opts, nil, http.MethodGet, nil, nil, "/", "https://some.domain.example", &mock.Encoder{}, &mstore.Store{LoadError: sessions.ErrInvalidAudience}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusUnauthorized, "{\"Status\":401,\"Error\":\"Unauthorized: internal/sessions: validation failed, invalid audience claim (aud)\"}\n"},
|
||||||
{"bad naked domain uri", opts, nil, http.MethodGet, nil, nil, "https://some.domain.example/", "a.naked.domain", &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusBadRequest, "{\"Status\":400,\"Error\":\"Bad Request: a.naked.domain url does contain a valid scheme\"}\n"},
|
{"bad naked domain uri", opts, nil, http.MethodGet, nil, nil, "https://some.domain.example/", "a.naked.domain", &mock.Encoder{}, &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusBadRequest, "{\"Status\":400,\"Error\":\"Bad Request: a.naked.domain url does contain a valid scheme\"}\n"},
|
||||||
{"bad naked domain uri verify only", opts, nil, http.MethodGet, nil, nil, "https://some.domain.example/verify", "a.naked.domain", &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusBadRequest, "{\"Status\":400,\"Error\":\"Bad Request: a.naked.domain url does contain a valid scheme\"}\n"},
|
{"bad naked domain uri verify only", opts, nil, http.MethodGet, nil, nil, "https://some.domain.example/verify", "a.naked.domain", &mock.Encoder{}, &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusBadRequest, "{\"Status\":400,\"Error\":\"Bad Request: a.naked.domain url does contain a valid scheme\"}\n"},
|
||||||
{"bad empty verification uri", opts, nil, http.MethodGet, nil, nil, "https://some.domain.example/", " ", &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusBadRequest, "{\"Status\":400,\"Error\":\"Bad Request: %20 url does contain a valid scheme\"}\n"},
|
{"bad empty verification uri", opts, nil, http.MethodGet, nil, nil, "https://some.domain.example/", " ", &mock.Encoder{}, &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusBadRequest, "{\"Status\":400,\"Error\":\"Bad Request: %20 url does contain a valid scheme\"}\n"},
|
||||||
{"bad empty verification uri verify only", opts, nil, http.MethodGet, nil, nil, "https://some.domain.example/verify", " ", &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusBadRequest, "{\"Status\":400,\"Error\":\"Bad Request: %20 url does contain a valid scheme\"}\n"},
|
{"bad empty verification uri verify only", opts, nil, http.MethodGet, nil, nil, "https://some.domain.example/verify", " ", &mock.Encoder{}, &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusBadRequest, "{\"Status\":400,\"Error\":\"Bad Request: %20 url does contain a valid scheme\"}\n"},
|
||||||
{"not authorized", opts, nil, http.MethodGet, nil, nil, "https://some.domain.example/", "https://some.domain.example", &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: false}, http.StatusUnauthorized, "{\"Status\":401,\"Error\":\"Unauthorized: user@test.example is not authorized for some.domain.example\"}\n"},
|
{"not authorized", opts, nil, http.MethodGet, nil, nil, "https://some.domain.example/", "https://some.domain.example", &mock.Encoder{}, &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: false}, http.StatusUnauthorized, "{\"Status\":401,\"Error\":\"Unauthorized: user@test.example is not authorized for some.domain.example\"}\n"},
|
||||||
{"not authorized verify endpoint", opts, nil, http.MethodGet, nil, nil, "https://some.domain.example/verify", "https://some.domain.example", &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: false}, http.StatusUnauthorized, "{\"Status\":401,\"Error\":\"Unauthorized: user@test.example is not authorized for some.domain.example\"}\n"},
|
{"not authorized verify endpoint", opts, nil, http.MethodGet, nil, nil, "https://some.domain.example/verify", "https://some.domain.example", &mock.Encoder{}, &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: false}, http.StatusUnauthorized, "{\"Status\":401,\"Error\":\"Unauthorized: user@test.example is not authorized for some.domain.example\"}\n"},
|
||||||
{"not authorized expired, redirect to auth", opts, sessions.ErrExpired, http.MethodGet, nil, nil, "https://some.domain.example/", "https://some.domain.example", &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: false}, http.StatusFound, ""},
|
{"not authorized expired, redirect to auth", opts, sessions.ErrExpired, http.MethodGet, nil, nil, "https://some.domain.example/", "https://some.domain.example", &mock.Encoder{}, &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: false}, http.StatusFound, ""},
|
||||||
{"not authorized expired, don't redirect!", opts, sessions.ErrExpired, http.MethodGet, nil, nil, "https://some.domain.example/verify", "https://some.domain.example", &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: false}, http.StatusUnauthorized, "{\"Status\":401,\"Error\":\"Unauthorized: internal/sessions: validation failed, token is expired (exp)\"}\n"},
|
{"not authorized expired, don't redirect!", opts, sessions.ErrExpired, http.MethodGet, nil, nil, "https://some.domain.example/verify", "https://some.domain.example", &mock.Encoder{}, &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: false}, http.StatusUnauthorized, "{\"Status\":401,\"Error\":\"Unauthorized: internal/sessions: validation failed, token is expired (exp)\"}\n"},
|
||||||
{"not authorized because of error", opts, nil, http.MethodGet, nil, nil, "https://some.domain.example/", "https://some.domain.example", &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeError: errors.New("authz error")}, http.StatusInternalServerError, "{\"Status\":500,\"Error\":\"Internal Server Error: authz error\"}\n"},
|
{"not authorized because of error", opts, nil, http.MethodGet, nil, nil, "https://some.domain.example/", "https://some.domain.example", &mock.Encoder{}, &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeError: errors.New("authz error")}, http.StatusInternalServerError, "{\"Status\":500,\"Error\":\"Internal Server Error: authz error\"}\n"},
|
||||||
{"not authorized expired, do not redirect to auth", opts, nil, http.MethodGet, nil, nil, "https://some.domain.example/verify", "https://some.domain.example", &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: false}, http.StatusUnauthorized, "{\"Status\":401,\"Error\":\"Unauthorized: internal/sessions: validation failed, token is expired (exp)\"}\n"},
|
{"not authorized expired, do not redirect to auth", opts, nil, http.MethodGet, nil, nil, "https://some.domain.example/verify", "https://some.domain.example", &mock.Encoder{}, &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: false}, http.StatusUnauthorized, "{\"Status\":401,\"Error\":\"Unauthorized: internal/sessions: validation failed, token is expired (exp)\"}\n"},
|
||||||
{"not authorized, bad audience request uri", opts, nil, http.MethodGet, nil, nil, "https://some.domain.example/", "https://some.domain.example", &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Audience: []string{"not.domain.example"}, Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusUnauthorized, "{\"Status\":401,\"Error\":\"Unauthorized: internal/sessions: validation failed, invalid audience claim (aud)\"}\n"},
|
{"not authorized, bad audience request uri", opts, nil, http.MethodGet, nil, nil, "https://some.domain.example/", "https://some.domain.example", &mock.Encoder{}, &mstore.Store{Session: &sessions.State{Audience: []string{"not.domain.example"}, Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusUnauthorized, "{\"Status\":401,\"Error\":\"Unauthorized: internal/sessions: validation failed, invalid audience claim (aud)\"}\n"},
|
||||||
{"not authorized, bad audience verify uri", opts, nil, http.MethodGet, nil, nil, "https://some.domain.example/", "https://fwdauth.domain.example", &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Audience: []string{"some.domain.example"}, Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusUnauthorized, "{\"Status\":401,\"Error\":\"Unauthorized: internal/sessions: validation failed, invalid audience claim (aud)\"}\n"},
|
{"not authorized, bad audience verify uri", opts, nil, http.MethodGet, nil, nil, "https://some.domain.example/", "https://fwdauth.domain.example", &mock.Encoder{}, &mstore.Store{Session: &sessions.State{Audience: []string{"some.domain.example"}, Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusUnauthorized, "{\"Status\":401,\"Error\":\"Unauthorized: internal/sessions: validation failed, invalid audience claim (aud)\"}\n"},
|
||||||
// traefik
|
// traefik
|
||||||
{"good traefik callback", opts, nil, http.MethodGet, map[string]string{httputil.HeaderForwardedURI: "https://some.domain.example?" + urlutil.QuerySessionEncrypted + "=" + goodEncryptionString}, nil, "https://some.domain.example/", "https://some.domain.example", &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusFound, ""},
|
{"good traefik callback", opts, nil, http.MethodGet, map[string]string{httputil.HeaderForwardedURI: "https://some.domain.example?" + urlutil.QuerySessionEncrypted + "=" + goodEncryptionString}, nil, "https://some.domain.example/", "https://some.domain.example", &mock.Encoder{}, &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusFound, ""},
|
||||||
{"bad traefik callback bad session", opts, nil, http.MethodGet, map[string]string{httputil.HeaderForwardedURI: "https://some.domain.example?" + urlutil.QuerySessionEncrypted + "=" + goodEncryptionString + "garbage"}, nil, "https://some.domain.example/", "https://some.domain.example", &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusBadRequest, ""},
|
{"bad traefik callback bad session", opts, nil, http.MethodGet, map[string]string{httputil.HeaderForwardedURI: "https://some.domain.example?" + urlutil.QuerySessionEncrypted + "=" + goodEncryptionString + "garbage"}, nil, "https://some.domain.example/", "https://some.domain.example", &mock.Encoder{}, &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusBadRequest, ""},
|
||||||
{"bad traefik callback bad url", opts, nil, http.MethodGet, map[string]string{httputil.HeaderForwardedURI: urlutil.QuerySessionEncrypted + ""}, nil, "https://some.domain.example/", "https://some.domain.example", &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusBadRequest, ""},
|
{"bad traefik callback bad url", opts, nil, http.MethodGet, map[string]string{httputil.HeaderForwardedURI: urlutil.QuerySessionEncrypted + ""}, nil, "https://some.domain.example/", "https://some.domain.example", &mock.Encoder{}, &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusBadRequest, ""},
|
||||||
// nginx
|
// nginx
|
||||||
{"good nginx callback redirect", opts, nil, http.MethodGet, nil, map[string]string{urlutil.QueryRedirectURI: "https://some.domain.example/", urlutil.QuerySessionEncrypted: goodEncryptionString}, "https://some.domain.example/", "https://some.domain.example", &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusFound, ""},
|
{"good nginx callback redirect", opts, nil, http.MethodGet, nil, map[string]string{urlutil.QueryRedirectURI: "https://some.domain.example/", urlutil.QuerySessionEncrypted: goodEncryptionString}, "https://some.domain.example/", "https://some.domain.example", &mock.Encoder{}, &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusFound, ""},
|
||||||
{"good nginx callback set session okay but return unauthorized", opts, nil, http.MethodGet, nil, map[string]string{urlutil.QueryRedirectURI: "https://some.domain.example/", urlutil.QuerySessionEncrypted: goodEncryptionString}, "https://some.domain.example/verify", "https://some.domain.example", &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusUnauthorized, ""},
|
{"good nginx callback set session okay but return unauthorized", opts, nil, http.MethodGet, nil, map[string]string{urlutil.QueryRedirectURI: "https://some.domain.example/", urlutil.QuerySessionEncrypted: goodEncryptionString}, "https://some.domain.example/verify", "https://some.domain.example", &mock.Encoder{}, &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusUnauthorized, ""},
|
||||||
{"bad nginx callback failed to set sesion", opts, nil, http.MethodGet, nil, map[string]string{urlutil.QueryRedirectURI: "https://some.domain.example/", urlutil.QuerySessionEncrypted: goodEncryptionString + "nope"}, "https://some.domain.example/verify", "https://some.domain.example", &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusBadRequest, ""},
|
{"bad nginx callback failed to set sesion", opts, nil, http.MethodGet, nil, map[string]string{urlutil.QueryRedirectURI: "https://some.domain.example/", urlutil.QuerySessionEncrypted: goodEncryptionString + "nope"}, "https://some.domain.example/verify", "https://some.domain.example", &mock.Encoder{}, &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusBadRequest, ""},
|
||||||
}
|
}
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
|
|
@ -12,6 +12,7 @@ import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/pomerium/pomerium/internal/cryptutil"
|
"github.com/pomerium/pomerium/internal/cryptutil"
|
||||||
|
mstore "github.com/pomerium/pomerium/internal/sessions/mock"
|
||||||
|
|
||||||
"github.com/pomerium/pomerium/config"
|
"github.com/pomerium/pomerium/config"
|
||||||
"github.com/pomerium/pomerium/internal/encoding"
|
"github.com/pomerium/pomerium/internal/encoding"
|
||||||
|
@ -78,10 +79,10 @@ func TestProxy_UserDashboard(t *testing.T) {
|
||||||
wantAdminForm bool
|
wantAdminForm bool
|
||||||
wantStatus int
|
wantStatus int
|
||||||
}{
|
}{
|
||||||
{"good", nil, opts, http.MethodGet, &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{}, false, http.StatusOK},
|
{"good", nil, opts, http.MethodGet, &mock.Encoder{}, &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{}, false, http.StatusOK},
|
||||||
{"session context error", errors.New("error"), opts, http.MethodGet, &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{}, false, http.StatusInternalServerError},
|
{"session context error", errors.New("error"), opts, http.MethodGet, &mock.Encoder{}, &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{}, false, http.StatusInternalServerError},
|
||||||
{"want admin form good admin authorization", nil, opts, http.MethodGet, &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{IsAdminResponse: true}, true, http.StatusOK},
|
{"want admin form good admin authorization", nil, opts, http.MethodGet, &mock.Encoder{}, &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{IsAdminResponse: true}, true, http.StatusOK},
|
||||||
{"is admin but authorization fails", nil, opts, http.MethodGet, &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{IsAdminError: errors.New("err")}, false, http.StatusInternalServerError},
|
{"is admin but authorization fails", nil, opts, http.MethodGet, &mock.Encoder{}, &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{IsAdminError: errors.New("err")}, false, http.StatusInternalServerError},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
|
@ -135,12 +136,12 @@ func TestProxy_Impersonate(t *testing.T) {
|
||||||
authorizer clients.Authorizer
|
authorizer clients.Authorizer
|
||||||
wantStatus int
|
wantStatus int
|
||||||
}{
|
}{
|
||||||
{"good", false, opts, nil, http.MethodPost, "user@blah.com", "", "", &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example"}}, clients.MockAuthorize{IsAdminResponse: true}, http.StatusFound},
|
{"good", false, opts, nil, http.MethodPost, "user@blah.com", "", "", &mock.Encoder{}, &mstore.Store{Session: &sessions.State{Email: "user@test.example"}}, clients.MockAuthorize{IsAdminResponse: true}, http.StatusFound},
|
||||||
{"good", false, opts, errors.New("error"), http.MethodPost, "user@blah.com", "", "", &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example"}}, clients.MockAuthorize{IsAdminResponse: true}, http.StatusInternalServerError},
|
{"good", false, opts, errors.New("error"), http.MethodPost, "user@blah.com", "", "", &mock.Encoder{}, &mstore.Store{Session: &sessions.State{Email: "user@test.example"}}, clients.MockAuthorize{IsAdminResponse: true}, http.StatusInternalServerError},
|
||||||
{"session load error", false, opts, nil, http.MethodPost, "user@blah.com", "", "", &mock.Encoder{}, &sessions.MockSessionStore{LoadError: errors.New("err"), Session: &sessions.State{Email: "user@test.example"}}, clients.MockAuthorize{IsAdminResponse: true}, http.StatusFound},
|
{"session load error", false, opts, nil, http.MethodPost, "user@blah.com", "", "", &mock.Encoder{}, &mstore.Store{LoadError: errors.New("err"), Session: &sessions.State{Email: "user@test.example"}}, clients.MockAuthorize{IsAdminResponse: true}, http.StatusFound},
|
||||||
{"non admin users rejected", false, opts, nil, http.MethodPost, "user@blah.com", "", "", &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute)), Email: "user@test.example"}}, clients.MockAuthorize{IsAdminResponse: false}, http.StatusForbidden},
|
{"non admin users rejected", false, opts, nil, http.MethodPost, "user@blah.com", "", "", &mock.Encoder{}, &mstore.Store{Session: &sessions.State{Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute)), Email: "user@test.example"}}, clients.MockAuthorize{IsAdminResponse: false}, http.StatusForbidden},
|
||||||
{"non admin users rejected on error", false, opts, nil, http.MethodPost, "user@blah.com", "", "", &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute)), Email: "user@test.example"}}, clients.MockAuthorize{IsAdminResponse: true, IsAdminError: errors.New("err")}, http.StatusInternalServerError},
|
{"non admin users rejected on error", false, opts, nil, http.MethodPost, "user@blah.com", "", "", &mock.Encoder{}, &mstore.Store{Session: &sessions.State{Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute)), Email: "user@test.example"}}, clients.MockAuthorize{IsAdminResponse: true, IsAdminError: errors.New("err")}, http.StatusInternalServerError},
|
||||||
{"groups", false, opts, nil, http.MethodPost, "user@blah.com", "group1,group2", "", &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute)), Email: "user@test.example"}}, clients.MockAuthorize{IsAdminResponse: true}, http.StatusFound},
|
{"groups", false, opts, nil, http.MethodPost, "user@blah.com", "group1,group2", "", &mock.Encoder{}, &mstore.Store{Session: &sessions.State{Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute)), Email: "user@test.example"}}, clients.MockAuthorize{IsAdminResponse: true}, http.StatusFound},
|
||||||
}
|
}
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
@ -245,12 +246,12 @@ func TestProxy_Callback(t *testing.T) {
|
||||||
wantStatus int
|
wantStatus int
|
||||||
wantBody string
|
wantBody string
|
||||||
}{
|
}{
|
||||||
{"good", opts, http.MethodGet, "http", "example.com", "/", nil, map[string]string{urlutil.QueryCallbackURI: "ok", urlutil.QuerySessionEncrypted: goodEncryptionString}, &mock.Encoder{MarshalResponse: []byte("x")}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusFound, ""},
|
{"good", opts, http.MethodGet, "http", "example.com", "/", nil, map[string]string{urlutil.QueryCallbackURI: "ok", urlutil.QuerySessionEncrypted: goodEncryptionString}, &mock.Encoder{MarshalResponse: []byte("x")}, &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusFound, ""},
|
||||||
{"good programmatic", opts, http.MethodGet, "http", "example.com", "/", nil, map[string]string{urlutil.QueryIsProgrammatic: "true", urlutil.QueryCallbackURI: "ok", urlutil.QuerySessionEncrypted: goodEncryptionString}, &mock.Encoder{MarshalResponse: []byte("x")}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusFound, ""},
|
{"good programmatic", opts, http.MethodGet, "http", "example.com", "/", nil, map[string]string{urlutil.QueryIsProgrammatic: "true", urlutil.QueryCallbackURI: "ok", urlutil.QuerySessionEncrypted: goodEncryptionString}, &mock.Encoder{MarshalResponse: []byte("x")}, &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusFound, ""},
|
||||||
{"bad decrypt", opts, http.MethodGet, "http", "example.com", "/", nil, map[string]string{urlutil.QuerySessionEncrypted: "KBEjQ9rnCxaAX-GOqexGw9ivEQURqts3zZ2mNGy0wnVa3SbtM399KlBq2nZ-9wM21FfsZX52er4jlmC7kPEKM3P7uZ41zR0zeys1-_74a5tQp-vsf1WXZfRsgVOuBcWPkMiWEoc379JFHxGDudp5VhU8B-dcQt4f3_PtLTHARkuH54io1Va2gNMq4Hiy8sQ1MPGCQeltH_JMzzdDpXdmdusWrXUvCGkba24muvAV06D8XRVJj6Iu9eK94qFnqcHc7wzziEbb8ADBues9dwbtb6jl8vMWz5rN6XvXqA5YpZv_MQZlsrO4oXFFQDevdgB84cX1tVbVu6qZvK_yQBZqzpOjWA9uIaoSENMytoXuWAlFO_sXjswfX8JTNdGwzB7qQRNPqxVG_sM_tzY3QhPm8zqwEzsXG5DokxZfVt2I5WJRUEovFDb4BnK9KFnnkEzLEdMudixVnXeGmTtycgJvoTeTCQRPfDYkcgJ7oKf4tGea-W7z5UAVa2RduJM9ZoM6YtJX7jgDm__PvvqcE0knJUF87XHBzdcOjoDF-CUze9xDJgNBlvPbJqVshKrwoqSYpePSDH9GUCNKxGequW3Ma8GvlFfhwd0rK6IZG-XWkyk0XSWQIGkDSjAvhB1wsOusCCguDjbpVZpaW5MMyTkmx68pl6qlIKT5UCcrVPl4ix5ZEj91mUDF0O1t04haD7VZuLVFXVGmqtFrBKI76sdYN-zkokaa1_chPRTyqMQFlqu_8LD6-RiK3UccGM-dEmnX72i91NP9F9OK0WJr9Cheup1C_P0mjqAO4Cb8oIHm0Oxz_mRqv5QbTGJtb3xwPLPuVjVCiE4gGBcuU2ixpSVf5HUF7y1KicVMCKiX9ATCBtg8sTdQZQnPEtHcHHAvdsnDVwev1LGfqA-Gdvg="}, &mock.Encoder{MarshalResponse: []byte("x")}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusBadRequest, ""},
|
{"bad decrypt", opts, http.MethodGet, "http", "example.com", "/", nil, map[string]string{urlutil.QuerySessionEncrypted: "KBEjQ9rnCxaAX-GOqexGw9ivEQURqts3zZ2mNGy0wnVa3SbtM399KlBq2nZ-9wM21FfsZX52er4jlmC7kPEKM3P7uZ41zR0zeys1-_74a5tQp-vsf1WXZfRsgVOuBcWPkMiWEoc379JFHxGDudp5VhU8B-dcQt4f3_PtLTHARkuH54io1Va2gNMq4Hiy8sQ1MPGCQeltH_JMzzdDpXdmdusWrXUvCGkba24muvAV06D8XRVJj6Iu9eK94qFnqcHc7wzziEbb8ADBues9dwbtb6jl8vMWz5rN6XvXqA5YpZv_MQZlsrO4oXFFQDevdgB84cX1tVbVu6qZvK_yQBZqzpOjWA9uIaoSENMytoXuWAlFO_sXjswfX8JTNdGwzB7qQRNPqxVG_sM_tzY3QhPm8zqwEzsXG5DokxZfVt2I5WJRUEovFDb4BnK9KFnnkEzLEdMudixVnXeGmTtycgJvoTeTCQRPfDYkcgJ7oKf4tGea-W7z5UAVa2RduJM9ZoM6YtJX7jgDm__PvvqcE0knJUF87XHBzdcOjoDF-CUze9xDJgNBlvPbJqVshKrwoqSYpePSDH9GUCNKxGequW3Ma8GvlFfhwd0rK6IZG-XWkyk0XSWQIGkDSjAvhB1wsOusCCguDjbpVZpaW5MMyTkmx68pl6qlIKT5UCcrVPl4ix5ZEj91mUDF0O1t04haD7VZuLVFXVGmqtFrBKI76sdYN-zkokaa1_chPRTyqMQFlqu_8LD6-RiK3UccGM-dEmnX72i91NP9F9OK0WJr9Cheup1C_P0mjqAO4Cb8oIHm0Oxz_mRqv5QbTGJtb3xwPLPuVjVCiE4gGBcuU2ixpSVf5HUF7y1KicVMCKiX9ATCBtg8sTdQZQnPEtHcHHAvdsnDVwev1LGfqA-Gdvg="}, &mock.Encoder{MarshalResponse: []byte("x")}, &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusBadRequest, ""},
|
||||||
{"bad save session", opts, http.MethodGet, "http", "example.com", "/", nil, map[string]string{urlutil.QuerySessionEncrypted: goodEncryptionString}, &mock.Encoder{MarshalResponse: []byte("x")}, &sessions.MockSessionStore{SaveError: errors.New("hi")}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusBadRequest, ""},
|
{"bad save session", opts, http.MethodGet, "http", "example.com", "/", nil, map[string]string{urlutil.QuerySessionEncrypted: goodEncryptionString}, &mock.Encoder{MarshalResponse: []byte("x")}, &mstore.Store{SaveError: errors.New("hi")}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusBadRequest, ""},
|
||||||
{"bad base64", opts, http.MethodGet, "http", "example.com", "/", nil, map[string]string{urlutil.QuerySessionEncrypted: "^"}, &mock.Encoder{MarshalResponse: []byte("x")}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusBadRequest, ""},
|
{"bad base64", opts, http.MethodGet, "http", "example.com", "/", nil, map[string]string{urlutil.QuerySessionEncrypted: "^"}, &mock.Encoder{MarshalResponse: []byte("x")}, &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusBadRequest, ""},
|
||||||
{"malformed redirect", opts, http.MethodGet, "http", "example.com", "/", nil, nil, &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusBadRequest, ""},
|
{"malformed redirect", opts, http.MethodGet, "http", "example.com", "/", nil, nil, &mock.Encoder{}, &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusBadRequest, ""},
|
||||||
}
|
}
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
@ -386,12 +387,12 @@ func TestProxy_ProgrammaticCallback(t *testing.T) {
|
||||||
wantStatus int
|
wantStatus int
|
||||||
wantBody string
|
wantBody string
|
||||||
}{
|
}{
|
||||||
{"good", opts, http.MethodGet, "http://pomerium.io/", nil, map[string]string{urlutil.QueryCallbackURI: "ok", urlutil.QuerySessionEncrypted: goodEncryptionString}, &mock.Encoder{MarshalResponse: []byte("x")}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusFound, ""},
|
{"good", opts, http.MethodGet, "http://pomerium.io/", nil, map[string]string{urlutil.QueryCallbackURI: "ok", urlutil.QuerySessionEncrypted: goodEncryptionString}, &mock.Encoder{MarshalResponse: []byte("x")}, &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusFound, ""},
|
||||||
{"good programmatic", opts, http.MethodGet, "http://pomerium.io/", nil, map[string]string{urlutil.QueryIsProgrammatic: "true", urlutil.QueryCallbackURI: "ok", urlutil.QuerySessionEncrypted: goodEncryptionString}, &mock.Encoder{MarshalResponse: []byte("x")}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusFound, ""},
|
{"good programmatic", opts, http.MethodGet, "http://pomerium.io/", nil, map[string]string{urlutil.QueryIsProgrammatic: "true", urlutil.QueryCallbackURI: "ok", urlutil.QuerySessionEncrypted: goodEncryptionString}, &mock.Encoder{MarshalResponse: []byte("x")}, &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusFound, ""},
|
||||||
{"bad decrypt", opts, http.MethodGet, "http://pomerium.io/", nil, map[string]string{urlutil.QuerySessionEncrypted: goodEncryptionString + cryptutil.NewBase64Key()}, &mock.Encoder{MarshalResponse: []byte("x")}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusBadRequest, ""},
|
{"bad decrypt", opts, http.MethodGet, "http://pomerium.io/", nil, map[string]string{urlutil.QuerySessionEncrypted: goodEncryptionString + cryptutil.NewBase64Key()}, &mock.Encoder{MarshalResponse: []byte("x")}, &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusBadRequest, ""},
|
||||||
{"bad save session", opts, http.MethodGet, "http://pomerium.io/", nil, map[string]string{urlutil.QuerySessionEncrypted: goodEncryptionString}, &mock.Encoder{MarshalResponse: []byte("x")}, &sessions.MockSessionStore{SaveError: errors.New("hi")}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusBadRequest, ""},
|
{"bad save session", opts, http.MethodGet, "http://pomerium.io/", nil, map[string]string{urlutil.QuerySessionEncrypted: goodEncryptionString}, &mock.Encoder{MarshalResponse: []byte("x")}, &mstore.Store{SaveError: errors.New("hi")}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusBadRequest, ""},
|
||||||
{"bad base64", opts, http.MethodGet, "http://pomerium.io/", nil, map[string]string{urlutil.QuerySessionEncrypted: "^"}, &mock.Encoder{MarshalResponse: []byte("x")}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusBadRequest, ""},
|
{"bad base64", opts, http.MethodGet, "http://pomerium.io/", nil, map[string]string{urlutil.QuerySessionEncrypted: "^"}, &mock.Encoder{MarshalResponse: []byte("x")}, &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusBadRequest, ""},
|
||||||
{"malformed redirect", opts, http.MethodGet, "http://pomerium.io/", nil, nil, &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusBadRequest, ""},
|
{"malformed redirect", opts, http.MethodGet, "http://pomerium.io/", nil, nil, &mock.Encoder{}, &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusBadRequest, ""},
|
||||||
}
|
}
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
|
|
@ -1,7 +1,11 @@
|
||||||
package proxy // import "github.com/pomerium/pomerium/proxy"
|
package proxy // import "github.com/pomerium/pomerium/proxy"
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"io/ioutil"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
"github.com/pomerium/pomerium/internal/encoding"
|
"github.com/pomerium/pomerium/internal/encoding"
|
||||||
|
@ -30,23 +34,82 @@ func (p *Proxy) AuthenticateSession(next http.Handler) http.Handler {
|
||||||
ctx, span := trace.StartSpan(r.Context(), "proxy.AuthenticateSession")
|
ctx, span := trace.StartSpan(r.Context(), "proxy.AuthenticateSession")
|
||||||
defer span.End()
|
defer span.End()
|
||||||
|
|
||||||
if s, err := sessions.FromContext(ctx); err != nil {
|
_, err := sessions.FromContext(ctx)
|
||||||
log.FromRequest(r).Debug().Err(err).Msg("proxy: authenticate session")
|
if errors.Is(err, sessions.ErrExpired) {
|
||||||
p.sessionStore.ClearSession(w, r)
|
ctx, err = p.refresh(ctx, w, r)
|
||||||
if s != nil && s.Programmatic {
|
if err != nil {
|
||||||
return httputil.NewError(http.StatusUnauthorized, err)
|
log.FromRequest(r).Warn().Err(err).Msg("proxy: refresh failed")
|
||||||
|
return p.redirectToSignin(w, r)
|
||||||
}
|
}
|
||||||
signinURL := *p.authenticateSigninURL
|
log.FromRequest(r).Info().Msg("proxy: refresh success")
|
||||||
q := signinURL.Query()
|
} else if err != nil {
|
||||||
q.Set(urlutil.QueryRedirectURI, urlutil.GetAbsoluteURL(r).String())
|
log.FromRequest(r).Debug().Err(err).Msg("proxy: session state")
|
||||||
signinURL.RawQuery = q.Encode()
|
return p.redirectToSignin(w, r)
|
||||||
httputil.Redirect(w, r, urlutil.NewSignedURL(p.SharedKey, &signinURL).String(), http.StatusFound)
|
|
||||||
}
|
}
|
||||||
p.addPomeriumHeaders(w, r)
|
p.addPomeriumHeaders(w, r)
|
||||||
next.ServeHTTP(w, r.WithContext(ctx))
|
next.ServeHTTP(w, r.WithContext(ctx))
|
||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *Proxy) refresh(ctx context.Context, w http.ResponseWriter, r *http.Request) (context.Context, error) {
|
||||||
|
ctx, span := trace.StartSpan(ctx, "proxy.AuthenticateSession/refresh")
|
||||||
|
defer span.End()
|
||||||
|
s, err := sessions.FromContext(ctx)
|
||||||
|
if !errors.Is(err, sessions.ErrExpired) || s == nil {
|
||||||
|
return nil, errors.New("proxy: unexpected session state for refresh")
|
||||||
|
}
|
||||||
|
// 1 - build a signed url to call refresh on authenticate service
|
||||||
|
refreshURI := *p.authenticateRefreshURL
|
||||||
|
q := refreshURI.Query()
|
||||||
|
q.Set("ati", s.AccessTokenID) // hash value points to parent token
|
||||||
|
q.Set("aud", urlutil.StripPort(r.Host)) // request's audience, this route
|
||||||
|
refreshURI.RawQuery = q.Encode()
|
||||||
|
signedRefreshURL := urlutil.NewSignedURL(p.SharedKey, &refreshURI).String()
|
||||||
|
|
||||||
|
// 2 - http call to authenticate service
|
||||||
|
req, err := http.NewRequestWithContext(ctx, http.MethodGet, signedRefreshURL, nil)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("proxy: backend refresh: new request: %v", err)
|
||||||
|
}
|
||||||
|
res, err := httputil.DefaultClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("proxy: fetch %v: %w", signedRefreshURL, err)
|
||||||
|
}
|
||||||
|
defer res.Body.Close()
|
||||||
|
jwtBytes, err := ioutil.ReadAll(io.LimitReader(res.Body, 4<<10))
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// 3 - save refreshed session to the client's session store
|
||||||
|
if err = p.sessionStore.SaveSession(w, r, jwtBytes); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
// 4 - add refreshed session to the current request context
|
||||||
|
var state sessions.State
|
||||||
|
if err := p.encoder.Unmarshal(jwtBytes, &state); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if err := state.Verify(urlutil.StripPort(r.Host)); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return sessions.NewContext(r.Context(), &state, err), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *Proxy) redirectToSignin(w http.ResponseWriter, r *http.Request) error {
|
||||||
|
s, err := sessions.FromContext(r.Context())
|
||||||
|
if s != nil && err != nil && s.Programmatic {
|
||||||
|
return httputil.NewError(http.StatusUnauthorized, err)
|
||||||
|
}
|
||||||
|
p.sessionStore.ClearSession(w, r)
|
||||||
|
signinURL := *p.authenticateSigninURL
|
||||||
|
q := signinURL.Query()
|
||||||
|
q.Set(urlutil.QueryRedirectURI, urlutil.GetAbsoluteURL(r).String())
|
||||||
|
signinURL.RawQuery = q.Encode()
|
||||||
|
log.FromRequest(r).Debug().Str("url", signinURL.String()).Msg("proxy: redirectToSignin")
|
||||||
|
httputil.Redirect(w, r, urlutil.NewSignedURL(p.SharedKey, &signinURL).String(), http.StatusFound)
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *Proxy) addPomeriumHeaders(w http.ResponseWriter, r *http.Request) {
|
func (p *Proxy) addPomeriumHeaders(w http.ResponseWriter, r *http.Request) {
|
||||||
|
@ -61,8 +124,8 @@ func (p *Proxy) addPomeriumHeaders(w http.ResponseWriter, r *http.Request) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// AuthorizeSession is middleware to enforce a user is authorized for a request
|
// AuthorizeSession is middleware to enforce a user is authorized for a request.
|
||||||
// session state is retrieved from the users's request context.
|
// Session state is retrieved from the users's request context.
|
||||||
func (p *Proxy) AuthorizeSession(next http.Handler) http.Handler {
|
func (p *Proxy) AuthorizeSession(next http.Handler) http.Handler {
|
||||||
return httputil.HandlerFunc(func(w http.ResponseWriter, r *http.Request) error {
|
return httputil.HandlerFunc(func(w http.ResponseWriter, r *http.Request) error {
|
||||||
ctx, span := trace.StartSpan(r.Context(), "proxy.AuthorizeSession")
|
ctx, span := trace.StartSpan(r.Context(), "proxy.AuthorizeSession")
|
||||||
|
|
|
@ -10,10 +10,14 @@ import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/google/go-cmp/cmp"
|
"github.com/google/go-cmp/cmp"
|
||||||
"github.com/pomerium/pomerium/internal/identity"
|
|
||||||
"github.com/pomerium/pomerium/internal/sessions"
|
|
||||||
"github.com/pomerium/pomerium/proxy/clients"
|
"github.com/pomerium/pomerium/proxy/clients"
|
||||||
"gopkg.in/square/go-jose.v2/jwt"
|
"gopkg.in/square/go-jose.v2/jwt"
|
||||||
|
|
||||||
|
"github.com/pomerium/pomerium/internal/encoding"
|
||||||
|
"github.com/pomerium/pomerium/internal/encoding/mock"
|
||||||
|
"github.com/pomerium/pomerium/internal/identity"
|
||||||
|
"github.com/pomerium/pomerium/internal/sessions"
|
||||||
|
mstore "github.com/pomerium/pomerium/internal/sessions/mock"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestProxy_AuthenticateSession(t *testing.T) {
|
func TestProxy_AuthenticateSession(t *testing.T) {
|
||||||
|
@ -30,24 +34,39 @@ func TestProxy_AuthenticateSession(t *testing.T) {
|
||||||
session sessions.SessionStore
|
session sessions.SessionStore
|
||||||
ctxError error
|
ctxError error
|
||||||
provider identity.Authenticator
|
provider identity.Authenticator
|
||||||
|
encoder encoding.MarshalUnmarshaler
|
||||||
|
refreshURL string
|
||||||
|
|
||||||
wantStatus int
|
wantStatus int
|
||||||
}{
|
}{
|
||||||
{"good", false, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Second))}}, nil, identity.MockProvider{}, http.StatusOK},
|
{"good", false, &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Second))}}, nil, identity.MockProvider{}, &mock.Encoder{}, "", http.StatusOK},
|
||||||
{"invalid session", false, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Second))}}, errors.New("hi"), identity.MockProvider{}, http.StatusFound},
|
{"invalid session", false, &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Second))}}, errors.New("hi"), identity.MockProvider{}, &mock.Encoder{}, "", http.StatusFound},
|
||||||
{"expired", false, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}}, sessions.ErrExpired, identity.MockProvider{}, http.StatusFound},
|
{"expired", false, &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}}, sessions.ErrExpired, identity.MockProvider{}, &mock.Encoder{}, "", http.StatusOK},
|
||||||
{"expired and programmatic", false, &sessions.MockSessionStore{Session: &sessions.State{Programmatic: true, Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}}, sessions.ErrExpired, identity.MockProvider{}, http.StatusUnauthorized},
|
{"expired and programmatic", false, &mstore.Store{Session: &sessions.State{Programmatic: true, Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}}, sessions.ErrExpired, identity.MockProvider{}, &mock.Encoder{}, "", http.StatusOK},
|
||||||
{"invalid session and programmatic", false, &sessions.MockSessionStore{Session: &sessions.State{Programmatic: true, Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Second))}}, errors.New("hi"), identity.MockProvider{}, http.StatusUnauthorized},
|
{"invalid session and programmatic", false, &mstore.Store{Session: &sessions.State{Programmatic: true, Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Second))}}, errors.New("hi"), identity.MockProvider{}, &mock.Encoder{}, "", http.StatusUnauthorized},
|
||||||
|
{"expired and refreshed ok", false, &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}}, sessions.ErrExpired, identity.MockProvider{}, &mock.Encoder{}, "", http.StatusOK},
|
||||||
|
{"expired and save failed", false, &mstore.Store{SaveError: errors.New("err"), Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}}, sessions.ErrExpired, identity.MockProvider{}, &mock.Encoder{}, "", http.StatusFound},
|
||||||
|
{"expired and unmarshal failed", false, &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}}, sessions.ErrExpired, identity.MockProvider{}, &mock.Encoder{UnmarshalError: errors.New("err")}, "", http.StatusFound},
|
||||||
}
|
}
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
fmt.Fprintln(w, "REFRESH GOOD")
|
||||||
|
}))
|
||||||
|
defer ts.Close()
|
||||||
|
rURL := ts.URL
|
||||||
|
if tt.refreshURL != "" {
|
||||||
|
rURL = tt.refreshURL
|
||||||
|
}
|
||||||
|
|
||||||
a := Proxy{
|
a := Proxy{
|
||||||
SharedKey: "80ldlrU2d7w+wVpKNfevk6fmb8otEx6CqOfshj2LwhQ=",
|
SharedKey: "80ldlrU2d7w+wVpKNfevk6fmb8otEx6CqOfshj2LwhQ=",
|
||||||
cookieSecret: []byte("80ldlrU2d7w+wVpKNfevk6fmb8otEx6CqOfshj2LwhQ="),
|
cookieSecret: []byte("80ldlrU2d7w+wVpKNfevk6fmb8otEx6CqOfshj2LwhQ="),
|
||||||
authenticateURL: uriParseHelper("https://authenticate.corp.example"),
|
authenticateURL: uriParseHelper("https://authenticate.corp.example"),
|
||||||
authenticateSigninURL: uriParseHelper("https://authenticate.corp.example/sign_in"),
|
authenticateSigninURL: uriParseHelper("https://authenticate.corp.example/sign_in"),
|
||||||
|
authenticateRefreshURL: uriParseHelper(rURL),
|
||||||
sessionStore: tt.session,
|
sessionStore: tt.session,
|
||||||
|
encoder: tt.encoder,
|
||||||
}
|
}
|
||||||
r := httptest.NewRequest(http.MethodGet, "/", nil)
|
r := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||||
state, _ := tt.session.LoadSession(r)
|
state, _ := tt.session.LoadSession(r)
|
||||||
|
@ -82,10 +101,10 @@ func TestProxy_AuthorizeSession(t *testing.T) {
|
||||||
|
|
||||||
wantStatus int
|
wantStatus int
|
||||||
}{
|
}{
|
||||||
{"user is authorized", &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Second))}}, clients.MockAuthorize{AuthorizeResponse: true}, nil, identity.MockProvider{}, http.StatusOK},
|
{"user is authorized", &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Second))}}, clients.MockAuthorize{AuthorizeResponse: true}, nil, identity.MockProvider{}, http.StatusOK},
|
||||||
{"user is not authorized", &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Second))}}, clients.MockAuthorize{AuthorizeResponse: false}, nil, identity.MockProvider{}, http.StatusUnauthorized},
|
{"user is not authorized", &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Second))}}, clients.MockAuthorize{AuthorizeResponse: false}, nil, identity.MockProvider{}, http.StatusUnauthorized},
|
||||||
{"invalid session", &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Second))}}, clients.MockAuthorize{AuthorizeResponse: true}, errors.New("hi"), identity.MockProvider{}, http.StatusUnauthorized},
|
{"invalid session", &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Second))}}, clients.MockAuthorize{AuthorizeResponse: true}, errors.New("hi"), identity.MockProvider{}, http.StatusUnauthorized},
|
||||||
{"authz client error", &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Second))}}, clients.MockAuthorize{AuthorizeError: errors.New("err")}, nil, identity.MockProvider{}, http.StatusInternalServerError},
|
{"authz client error", &mstore.Store{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Second))}}, clients.MockAuthorize{AuthorizeError: errors.New("err")}, nil, identity.MockProvider{}, http.StatusInternalServerError},
|
||||||
}
|
}
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
@ -143,9 +162,9 @@ func TestProxy_SignRequest(t *testing.T) {
|
||||||
wantStatus int
|
wantStatus int
|
||||||
wantHeaders string
|
wantHeaders string
|
||||||
}{
|
}{
|
||||||
{"good", &sessions.MockSessionStore{Session: &sessions.State{Email: "test"}}, nil, nil, http.StatusOK, "ok"},
|
{"good", &mstore.Store{Session: &sessions.State{Email: "test"}}, nil, nil, http.StatusOK, "ok"},
|
||||||
{"invalid session", &sessions.MockSessionStore{Session: &sessions.State{Email: "test"}}, nil, errors.New("err"), http.StatusForbidden, ""},
|
{"invalid session", &mstore.Store{Session: &sessions.State{Email: "test"}}, nil, errors.New("err"), http.StatusForbidden, ""},
|
||||||
{"signature failure, warn but ok", &sessions.MockSessionStore{Session: &sessions.State{Email: "test"}}, errors.New("err"), nil, http.StatusOK, ""},
|
{"signature failure, warn but ok", &mstore.Store{Session: &sessions.State{Email: "test"}}, errors.New("err"), nil, http.StatusOK, ""},
|
||||||
}
|
}
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
|
|
@ -21,6 +21,9 @@ import (
|
||||||
"github.com/pomerium/pomerium/internal/log"
|
"github.com/pomerium/pomerium/internal/log"
|
||||||
"github.com/pomerium/pomerium/internal/middleware"
|
"github.com/pomerium/pomerium/internal/middleware"
|
||||||
"github.com/pomerium/pomerium/internal/sessions"
|
"github.com/pomerium/pomerium/internal/sessions"
|
||||||
|
"github.com/pomerium/pomerium/internal/sessions/cookie"
|
||||||
|
"github.com/pomerium/pomerium/internal/sessions/header"
|
||||||
|
"github.com/pomerium/pomerium/internal/sessions/queryparam"
|
||||||
"github.com/pomerium/pomerium/internal/telemetry/metrics"
|
"github.com/pomerium/pomerium/internal/telemetry/metrics"
|
||||||
"github.com/pomerium/pomerium/internal/tripper"
|
"github.com/pomerium/pomerium/internal/tripper"
|
||||||
"github.com/pomerium/pomerium/internal/urlutil"
|
"github.com/pomerium/pomerium/internal/urlutil"
|
||||||
|
@ -28,12 +31,11 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
// dashboardURL is the path to authenticate's sign in endpoint
|
// authenticate urls
|
||||||
dashboardURL = "/.pomerium"
|
dashboardURL = "/.pomerium"
|
||||||
// signinURL is the path to authenticate's sign in endpoint
|
|
||||||
signinURL = "/.pomerium/sign_in"
|
signinURL = "/.pomerium/sign_in"
|
||||||
// signoutURL is the path to authenticate's sign out endpoint
|
|
||||||
signoutURL = "/.pomerium/sign_out"
|
signoutURL = "/.pomerium/sign_out"
|
||||||
|
refreshURL = "/.pomerium/refresh"
|
||||||
)
|
)
|
||||||
|
|
||||||
// ValidateOptions checks that proper configuration settings are set to create
|
// ValidateOptions checks that proper configuration settings are set to create
|
||||||
|
@ -72,12 +74,14 @@ type Proxy struct {
|
||||||
authenticateURL *url.URL
|
authenticateURL *url.URL
|
||||||
authenticateSigninURL *url.URL
|
authenticateSigninURL *url.URL
|
||||||
authenticateSignoutURL *url.URL
|
authenticateSignoutURL *url.URL
|
||||||
|
authenticateRefreshURL *url.URL
|
||||||
|
|
||||||
authorizeURL *url.URL
|
authorizeURL *url.URL
|
||||||
|
|
||||||
AuthorizeClient clients.Authorizer
|
AuthorizeClient clients.Authorizer
|
||||||
|
|
||||||
encoder encoding.Unmarshaler
|
encoder encoding.Unmarshaler
|
||||||
cookieOptions *sessions.CookieOptions
|
cookieOptions *cookie.Options
|
||||||
cookieSecret []byte
|
cookieSecret []byte
|
||||||
defaultUpstreamTimeout time.Duration
|
defaultUpstreamTimeout time.Duration
|
||||||
refreshCooldown time.Duration
|
refreshCooldown time.Duration
|
||||||
|
@ -104,7 +108,7 @@ func New(opts config.Options) (*Proxy, error) {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
cookieOptions := &sessions.CookieOptions{
|
cookieOptions := &cookie.Options{
|
||||||
Name: opts.CookieName,
|
Name: opts.CookieName,
|
||||||
Domain: opts.CookieDomain,
|
Domain: opts.CookieDomain,
|
||||||
Secure: opts.CookieSecure,
|
Secure: opts.CookieSecure,
|
||||||
|
@ -112,7 +116,7 @@ func New(opts config.Options) (*Proxy, error) {
|
||||||
Expire: opts.CookieExpire,
|
Expire: opts.CookieExpire,
|
||||||
}
|
}
|
||||||
|
|
||||||
cookieStore, err := sessions.NewCookieLoader(cookieOptions, encoder)
|
cookieStore, err := cookie.NewStore(cookieOptions, encoder)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -129,8 +133,8 @@ func New(opts config.Options) (*Proxy, error) {
|
||||||
sessionStore: cookieStore,
|
sessionStore: cookieStore,
|
||||||
sessionLoaders: []sessions.SessionLoader{
|
sessionLoaders: []sessions.SessionLoader{
|
||||||
cookieStore,
|
cookieStore,
|
||||||
sessions.NewHeaderStore(encoder, "Pomerium"),
|
header.NewStore(encoder, "Pomerium"),
|
||||||
sessions.NewQueryParamStore(encoder, "pomerium_session")},
|
queryparam.NewStore(encoder, "pomerium_session")},
|
||||||
signingKey: opts.SigningKey,
|
signingKey: opts.SigningKey,
|
||||||
templates: template.Must(frontend.NewTemplates()),
|
templates: template.Must(frontend.NewTemplates()),
|
||||||
}
|
}
|
||||||
|
@ -139,6 +143,7 @@ func New(opts config.Options) (*Proxy, error) {
|
||||||
p.authenticateURL, _ = urlutil.DeepCopy(opts.AuthenticateURL)
|
p.authenticateURL, _ = urlutil.DeepCopy(opts.AuthenticateURL)
|
||||||
p.authenticateSigninURL = p.authenticateURL.ResolveReference(&url.URL{Path: signinURL})
|
p.authenticateSigninURL = p.authenticateURL.ResolveReference(&url.URL{Path: signinURL})
|
||||||
p.authenticateSignoutURL = p.authenticateURL.ResolveReference(&url.URL{Path: signoutURL})
|
p.authenticateSignoutURL = p.authenticateURL.ResolveReference(&url.URL{Path: signoutURL})
|
||||||
|
p.authenticateRefreshURL = p.authenticateURL.ResolveReference(&url.URL{Path: refreshURL})
|
||||||
|
|
||||||
if err := p.UpdatePolicies(&opts); err != nil {
|
if err := p.UpdatePolicies(&opts); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue