internal/sessions: refactor how sessions loading (#351)

These chagnes standardize how session loading is done for session
cookie, auth bearer token, and query params.

- Bearer token previously combined with session cookie.
- rearranged cookie-store to put exported methods above unexported
- added header store that implements session loader interface
- added query param store that implements session loader interface

Signed-off-by: Bobby DeSimone <bobbydesimone@gmail.com>
This commit is contained in:
Bobby DeSimone 2019-10-06 10:47:53 -07:00 committed by GitHub
parent 7aa4621b1b
commit badd8d69af
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
13 changed files with 322 additions and 234 deletions

View file

@ -10,47 +10,39 @@ import (
"github.com/pomerium/pomerium/internal/cryptutil"
)
// ChunkedCanaryByte is the byte value used as a canary prefix to distinguish if
// the cookie is multi-part or not. This constant *should not* be valid
// base64. It's important this byte is ASCII to avoid UTF-8 variable sized runes.
// https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Set-Cookie#Directives
const ChunkedCanaryByte byte = '%'
const (
// ChunkedCanaryByte is the byte value used as a canary prefix to distinguish if
// the cookie is multi-part or not. This constant *should not* be valid
// base64. It's important this byte is ASCII to avoid UTF-8 variable sized runes.
// https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Set-Cookie#Directives
ChunkedCanaryByte byte = '%'
// MaxChunkSize sets the upper bound on a cookie chunks payload value.
// Note, this should be lower than the actual cookie's max size (4096 bytes)
// which includes metadata.
MaxChunkSize = 3800
// MaxNumChunks limits the number of chunks to iterate through. Conservatively
// set to prevent any abuse.
MaxNumChunks = 5
)
// DefaultBearerTokenHeader is default header name for the authorization bearer
// token header as defined in rfc2617
// https://tools.ietf.org/html/rfc6750#section-2.1
const DefaultBearerTokenHeader = "Authorization"
// MaxChunkSize sets the upper bound on a cookie chunks payload value.
// Note, this should be lower than the actual cookie's max size (4096 bytes)
// which includes metadata.
const MaxChunkSize = 3800
// MaxNumChunks limits the number of chunks to iterate through. Conservatively
// set to prevent any abuse.
const MaxNumChunks = 5
// CookieStore represents all the cookie related configurations
// CookieStore implements the session store interface for session cookies.
type CookieStore struct {
Name string
Encoder cryptutil.SecureEncoder
CookieExpire time.Duration
CookieRefresh time.Duration
CookieSecure bool
CookieHTTPOnly bool
CookieDomain string
BearerTokenHeader string
Name string
CookieDomain string
CookieExpire time.Duration
CookieHTTPOnly bool
CookieSecure bool
Encoder cryptutil.SecureEncoder
}
// CookieStoreOptions holds options for CookieStore
type CookieStoreOptions struct {
Name string
CookieSecure bool
CookieHTTPOnly bool
CookieDomain string
BearerTokenHeader string
CookieExpire time.Duration
Encoder cryptutil.SecureEncoder
Name string
CookieDomain string
CookieExpire time.Duration
CookieHTTPOnly bool
CookieSecure bool
Encoder cryptutil.SecureEncoder
}
// NewCookieStore returns a new session with ciphers for each of the cookie secrets
@ -61,18 +53,14 @@ func NewCookieStore(opts *CookieStoreOptions) (*CookieStore, error) {
if opts.Encoder == nil {
return nil, fmt.Errorf("internal/sessions: cipher cannot be nil")
}
if opts.BearerTokenHeader == "" {
opts.BearerTokenHeader = DefaultBearerTokenHeader
}
return &CookieStore{
Name: opts.Name,
CookieSecure: opts.CookieSecure,
CookieHTTPOnly: opts.CookieHTTPOnly,
CookieDomain: opts.CookieDomain,
CookieExpire: opts.CookieExpire,
Encoder: opts.Encoder,
BearerTokenHeader: opts.BearerTokenHeader,
Name: opts.Name,
CookieSecure: opts.CookieSecure,
CookieHTTPOnly: opts.CookieHTTPOnly,
CookieDomain: opts.CookieDomain,
CookieExpire: opts.CookieExpire,
Encoder: opts.Encoder,
}, nil
}
@ -103,11 +91,43 @@ func (cs *CookieStore) makeCookie(req *http.Request, name string, value string,
return c
}
// ClearSession clears the session cookie from a request
func (cs *CookieStore) ClearSession(w http.ResponseWriter, req *http.Request) {
http.SetCookie(w, cs.makeCookie(req, cs.Name, "", time.Hour*-1, time.Now()))
}
// LoadSession returns a State from the cookie in the request.
func (cs *CookieStore) LoadSession(req *http.Request) (*State, error) {
cipherText := loadChunkedCookie(req, cs.Name)
if cipherText == "" {
return nil, ErrNoSessionFound
}
session, err := UnmarshalSession(cipherText, cs.Encoder)
if err != nil {
return nil, ErrMalformed
}
return session, nil
}
// SaveSession saves a session state to a request sessions.
func (cs *CookieStore) SaveSession(w http.ResponseWriter, req *http.Request, s *State) error {
value, err := MarshalSession(s, cs.Encoder)
if err != nil {
return err
}
cs.setSessionCookie(w, req, value)
return nil
}
// makeSessionCookie constructs a session cookie given the request, an expiration time and the current time.
func (cs *CookieStore) makeSessionCookie(req *http.Request, value string, expiration time.Duration, now time.Time) *http.Cookie {
return cs.makeCookie(req, cs.Name, value, expiration, now)
}
func (cs *CookieStore) setSessionCookie(w http.ResponseWriter, req *http.Request, val string) {
cs.setCookie(w, cs.makeSessionCookie(req, val, cs.CookieExpire, time.Now()))
}
func (cs *CookieStore) setCookie(w http.ResponseWriter, cookie *http.Cookie) {
if len(cookie.String()) <= MaxChunkSize {
http.SetCookie(w, cookie)
@ -128,35 +148,6 @@ func (cs *CookieStore) setCookie(w http.ResponseWriter, cookie *http.Cookie) {
}
}
func chunk(s string, size int) []string {
ss := make([]string, 0, len(s)/size+1)
for len(s) > 0 {
if len(s) < size {
size = len(s)
}
ss, s = append(ss, s[:size]), s[size:]
}
return ss
}
// ClearSession clears the session cookie from a request
func (cs *CookieStore) ClearSession(w http.ResponseWriter, req *http.Request) {
http.SetCookie(w, cs.makeCookie(req, cs.Name, "", time.Hour*-1, time.Now()))
}
func (cs *CookieStore) setSessionCookie(w http.ResponseWriter, req *http.Request, val string) {
cs.setCookie(w, cs.makeSessionCookie(req, val, cs.CookieExpire, time.Now()))
}
func loadBearerToken(r *http.Request, headerKey string) string {
authHeader := r.Header.Get(headerKey)
split := strings.Split(authHeader, "Bearer")
if authHeader == "" || len(split) != 2 {
return ""
}
return strings.TrimSpace(split[1])
}
func loadChunkedCookie(r *http.Request, cookieName string) string {
c, err := r.Cookie(cookieName)
if err != nil {
@ -179,37 +170,13 @@ func loadChunkedCookie(r *http.Request, cookieName string) string {
return cipherText
}
// LoadSession returns a State from the cookie in the request.
func (cs *CookieStore) LoadSession(req *http.Request) (*State, error) {
cipherText := loadChunkedCookie(req, cs.Name)
if cipherText == "" {
cipherText = loadBearerToken(req, cs.BearerTokenHeader)
func chunk(s string, size int) []string {
ss := make([]string, 0, len(s)/size+1)
for len(s) > 0 {
if len(s) < size {
size = len(s)
}
ss, s = append(ss, s[:size]), s[size:]
}
if cipherText == "" {
return nil, ErrEmptySession
}
session, err := UnmarshalSession(cipherText, cs.Encoder)
if err != nil {
return nil, err
}
return session, nil
}
// SaveSession saves a session state to a request sessions.
func (cs *CookieStore) SaveSession(w http.ResponseWriter, req *http.Request, s *State) error {
value, err := MarshalSession(s, cs.Encoder)
if err != nil {
return err
}
cs.setSessionCookie(w, req, value)
return nil
}
// ParentSubdomain returns the parent subdomain.
func ParentSubdomain(s string) string {
if strings.Count(s, ".") < 2 {
return ""
}
split := strings.SplitN(s, ".", 2)
return split[1]
return ss
}

View file

@ -38,33 +38,30 @@ func TestNewCookieStore(t *testing.T) {
}{
{"good",
&CookieStoreOptions{
Name: "_cookie",
CookieSecure: true,
CookieHTTPOnly: true,
CookieDomain: "pomerium.io",
CookieExpire: 10 * time.Second,
Encoder: encoder,
BearerTokenHeader: "Authorization",
Name: "_cookie",
CookieSecure: true,
CookieHTTPOnly: true,
CookieDomain: "pomerium.io",
CookieExpire: 10 * time.Second,
Encoder: encoder,
},
&CookieStore{
Name: "_cookie",
CookieSecure: true,
CookieHTTPOnly: true,
CookieDomain: "pomerium.io",
CookieExpire: 10 * time.Second,
Encoder: encoder,
BearerTokenHeader: "Authorization",
Name: "_cookie",
CookieSecure: true,
CookieHTTPOnly: true,
CookieDomain: "pomerium.io",
CookieExpire: 10 * time.Second,
Encoder: encoder,
},
false},
{"missing name",
&CookieStoreOptions{
Name: "",
CookieSecure: true,
CookieHTTPOnly: true,
CookieDomain: "pomerium.io",
CookieExpire: 10 * time.Second,
Encoder: encoder,
BearerTokenHeader: "Authorization",
Name: "",
CookieSecure: true,
CookieHTTPOnly: true,
CookieDomain: "pomerium.io",
CookieExpire: 10 * time.Second,
Encoder: encoder,
},
nil,
true},
@ -250,23 +247,3 @@ func TestMockSessionStore(t *testing.T) {
})
}
}
func Test_ParentSubdomain(t *testing.T) {
t.Parallel()
tests := []struct {
s string
want string
}{
{"httpbin.corp.example.com", "corp.example.com"},
{"some.httpbin.corp.example.com", "httpbin.corp.example.com"},
{"example.com", ""},
{"", ""},
}
for _, tt := range tests {
t.Run(tt.s, func(t *testing.T) {
if got := ParentSubdomain(tt.s); got != tt.want {
t.Errorf("ParentSubdomain() = %v, want %v", got, tt.want)
}
})
}
}

View file

@ -0,0 +1,61 @@
package sessions // import "github.com/pomerium/pomerium/internal/sessions"
import (
"net/http"
"strings"
"github.com/pomerium/pomerium/internal/cryptutil"
)
const (
// defaultAuthHeader and defaultAuthType are default header name for the
// authorization bearer token header as defined in rfc2617
// https://tools.ietf.org/html/rfc6750#section-2.1
defaultAuthHeader = "Authorization"
defaultAuthType = "Bearer"
)
// HeaderStore implements the load session store interface using http
// authorization headers.
type HeaderStore struct {
authHeader string
authType string
encoder cryptutil.SecureEncoder
}
// NewHeaderStore returns a new header store for loading sessions from
// authorization headers.
func NewHeaderStore(enc cryptutil.SecureEncoder) *HeaderStore {
return &HeaderStore{
authHeader: defaultAuthHeader,
authType: defaultAuthType,
encoder: enc,
}
}
// LoadSession tries to retrieve the token string from the Authorization header.
//
// NOTA BENE: While most servers do not log Authorization headers by default,
// you should ensure no other services are logging or leaking your auth headers.
func (as *HeaderStore) LoadSession(r *http.Request) (*State, error) {
cipherText := as.tokenFromHeader(r)
if cipherText == "" {
return nil, ErrNoSessionFound
}
session, err := UnmarshalSession(cipherText, as.encoder)
if err != nil {
return nil, ErrMalformed
}
return session, nil
}
// retrieve the value of the authorization header
func (as *HeaderStore) tokenFromHeader(r *http.Request) string {
bearer := r.Header.Get(as.authHeader)
atSize := len(as.authType)
if len(bearer) > atSize && strings.EqualFold(bearer[0:atSize], as.authType) {
return bearer[atSize+1:]
}
return ""
}

View file

@ -4,7 +4,6 @@ import (
"context"
"errors"
"net/http"
"strings"
)
// Context keys
@ -13,65 +12,56 @@ var (
ErrorCtxKey = &contextKey{"Error"}
)
// Library errors
var (
ErrExpired = errors.New("internal/sessions: session is expired")
ErrNoSessionFound = errors.New("internal/sessions: session is not found")
ErrMalformed = errors.New("internal/sessions: session is malformed")
)
// RetrieveSession http middleware handler will verify a auth session from a http request.
//
// RetrieveSession will search for a auth session in a http request, in the order:
// 1. `pomerium_session` URI query parameter
// 2. `Authorization: BEARER` request header
// 3. Cookie `_pomerium` value
func RetrieveSession(s SessionStore) func(http.Handler) http.Handler {
func RetrieveSession(s ...SessionLoader) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return retrieve(s, TokenFromQuery, TokenFromHeader, TokenFromCookie)(next)
return retrieve(s...)(next)
}
}
func retrieve(s SessionStore, findTokenFns ...func(r *http.Request) string) func(http.Handler) http.Handler {
func retrieve(s ...SessionLoader) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
hfn := func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
token, err := retrieveFromRequest(s, r, findTokenFns...)
ctx = NewContext(ctx, token, err)
state, err := retrieveFromRequest(r, s...)
ctx = NewContext(ctx, state, err)
next.ServeHTTP(w, r.WithContext(ctx))
}
return http.HandlerFunc(hfn)
}
}
func retrieveFromRequest(s SessionStore, r *http.Request, findTokenFns ...func(r *http.Request) string) (*State, error) {
var tokenStr string
func retrieveFromRequest(r *http.Request, sessions ...SessionLoader) (*State, error) {
state := new(State)
var err error
// Extract token string from the request by calling token find functions in
// Extract sessions state from the request by calling token find functions in
// the order they where provided. Further extraction stops if a function
// returns a non-empty string.
for _, fn := range findTokenFns {
tokenStr = fn(r)
if tokenStr != "" {
for _, s := range sessions {
state, err = s.LoadSession(r)
if err != nil && !errors.Is(err, ErrNoSessionFound) {
// unexpected error
return nil, err
}
// break, we found a session state
if state != nil {
break
}
}
if tokenStr == "" {
// no session found if state is still empty
if state == nil {
return nil, ErrNoSessionFound
}
state, err := s.LoadSession(r)
if err != nil {
return nil, ErrMalformed
}
err = state.Valid()
if err != nil {
if err = state.Valid(); err != nil {
// a little unusual but we want to return the expired state too
return state, err
}
// Valid!
return state, nil
}
@ -89,35 +79,6 @@ func FromContext(ctx context.Context) (*State, error) {
return state, err
}
// TokenFromCookie tries to retrieve the token string from a cookie named
// "_pomerium".
func TokenFromCookie(r *http.Request) string {
cookie, err := r.Cookie("_pomerium")
if err != nil {
return ""
}
return cookie.Value
}
// TokenFromHeader tries to retrieve the token string from the
// "Authorization" request header: "Authorization: BEARER T".
func TokenFromHeader(r *http.Request) string {
// Get token from authorization header.
bearer := r.Header.Get("Authorization")
if len(bearer) > 7 && strings.EqualFold(bearer[0:6], "BEARER") {
return bearer[7:]
}
return ""
}
// TokenFromQuery tries to retrieve the token string from the "pomerium_session" URI
// query parameter.
// todo(bdd) : document setting session code as queryparam
func TokenFromQuery(r *http.Request) string {
// Get token from query param named "pomerium_session".
return r.URL.Query().Get("pomerium_session")
}
// contextKey is a value for use with context.WithValue. It's used as
// a pointer so it fits in an interface{} without allocation. This technique
// for defining context keys was copied from Go 1.7's new use of context in net/http.
@ -126,5 +87,5 @@ type contextKey struct {
}
func (k *contextKey) String() string {
return "SessionStore context value " + k.name
return "context value " + k.name
}

View file

@ -9,9 +9,8 @@ import (
"testing"
"time"
"github.com/pomerium/pomerium/internal/cryptutil"
"github.com/google/go-cmp/cmp"
"github.com/pomerium/pomerium/internal/cryptutil"
)
func TestNewContext(t *testing.T) {
@ -75,8 +74,8 @@ func TestVerifier(t *testing.T) {
{"expired auth header", State{Email: "user@pomerium.io", RefreshDeadline: time.Now().Add(-10 * time.Second)}, false, true, false, "internal/sessions: session is expired\n", http.StatusUnauthorized},
{"malformed auth header", State{Email: "user@pomerium.io", RefreshDeadline: time.Now().Add(-10 * time.Second)}, false, true, false, "internal/sessions: session is malformed\n", http.StatusUnauthorized},
{"good auth query param session", State{Email: "user@pomerium.io", RefreshDeadline: time.Now().Add(10 * time.Second)}, false, true, true, http.StatusText(http.StatusOK), http.StatusOK},
{"expired auth query param", State{Email: "user@pomerium.io", RefreshDeadline: time.Now().Add(-10 * time.Second)}, false, true, true, "internal/sessions: session is expired\n", http.StatusUnauthorized},
{"malformed auth query param", State{Email: "user@pomerium.io", RefreshDeadline: time.Now().Add(-10 * time.Second)}, false, true, true, "internal/sessions: session is malformed\n", http.StatusUnauthorized},
{"expired auth query param", State{Email: "user@pomerium.io", RefreshDeadline: time.Now().Add(-10 * time.Second)}, false, false, true, "internal/sessions: session is expired\n", http.StatusUnauthorized},
{"malformed auth query param", State{Email: "user@pomerium.io", RefreshDeadline: time.Now().Add(-10 * time.Second)}, false, false, true, "internal/sessions: session is malformed\n", http.StatusUnauthorized},
{"no session", State{Email: "user@pomerium.io", RefreshDeadline: time.Now().Add(-10 * time.Second)}, false, false, false, "internal/sessions: session is not found\n", http.StatusUnauthorized},
}
for _, tt := range tests {
@ -94,7 +93,6 @@ func TestVerifier(t *testing.T) {
if strings.Contains(tt.name, "malformed") {
// add some garbage to the end of the string
encSession += cryptutil.NewBase64Key()
fmt.Println(encSession)
}
cs, err := NewCookieStore(&CookieStoreOptions{
@ -104,6 +102,9 @@ func TestVerifier(t *testing.T) {
if err != nil {
t.Fatal(err)
}
as := NewHeaderStore(encoder)
qp := NewQueryParamStore(encoder)
r := httptest.NewRequest(http.MethodGet, "/", nil)
r.Header.Set("Accept", "application/json")
@ -114,11 +115,12 @@ func TestVerifier(t *testing.T) {
r.Header.Set("Authorization", "Bearer "+encSession)
} else if tt.param {
q := r.URL.Query()
q.Add("pomerium_session", encSession)
q.Set("pomerium_session", encSession)
r.URL.RawQuery = q.Encode()
}
got := RetrieveSession(cs)(testAuthorizer((fnh)))
got := RetrieveSession(cs, as, qp)(testAuthorizer((fnh)))
got.ServeHTTP(w, r)
gotBody := w.Body.String()
@ -133,3 +135,23 @@ func TestVerifier(t *testing.T) {
})
}
}
func Test_contextKey_String(t *testing.T) {
tests := []struct {
name string
keyName string
want string
}{
{"simple example", "test", "context value test"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
k := &contextKey{
name: tt.keyName,
}
if got := k.String(); got != tt.want {
t.Errorf("contextKey.String() = %v, want %v", got, tt.want)
}
})
}
}

View file

@ -0,0 +1,44 @@
package sessions // import "github.com/pomerium/pomerium/internal/sessions"
import (
"net/http"
"github.com/pomerium/pomerium/internal/cryptutil"
)
const (
defaultQueryParamKey = "pomerium_session"
)
// QueryParamStore implements the load session store interface using http
// query strings / query parameters.
type QueryParamStore struct {
queryParamKey string
encoder cryptutil.SecureEncoder
}
// NewQueryParamStore returns a new query param store for loading sessions from
// query strings / query parameters.
func NewQueryParamStore(enc cryptutil.SecureEncoder) *QueryParamStore {
return &QueryParamStore{
queryParamKey: defaultQueryParamKey,
encoder: enc,
}
}
// LoadSession tries to retrieve the token string from URL query parameters.
//
// NOTA BENE: By default, most servers _DO_ log query params, the leaking or
// accidental logging of which should be considered a security issue.
func (qp *QueryParamStore) LoadSession(r *http.Request) (*State, error) {
cipherText := r.URL.Query().Get(qp.queryParamKey)
if cipherText == "" {
return nil, ErrNoSessionFound
}
session, err := UnmarshalSession(cipherText, qp.encoder)
if err != nil {
return nil, ErrMalformed
}
return session, nil
}

View file

@ -5,8 +5,14 @@ import (
"net/http"
)
// ErrEmptySession is an error for an empty sessions.
var ErrEmptySession = errors.New("internal/sessions: empty session")
var (
// ErrExpired is the error for an expired session.
ErrExpired = errors.New("internal/sessions: session is expired")
// ErrNoSessionFound is the error for when no session is found.
ErrNoSessionFound = errors.New("internal/sessions: session is not found")
// ErrMalformed is the error for when a session is found but is malformed.
ErrMalformed = errors.New("internal/sessions: session is malformed")
)
// SessionStore has the functions for setting, getting, and clearing the Session cookie
type SessionStore interface {
@ -14,3 +20,9 @@ type SessionStore interface {
LoadSession(*http.Request) (*State, error)
SaveSession(http.ResponseWriter, *http.Request, *State) error
}
// SessionLoader is implemented by any struct that loads a pomerium session
// given a request, and returns a user state.
type SessionLoader interface {
LoadSession(*http.Request) (*State, error)
}

12
internal/sessions/util.go Normal file
View file

@ -0,0 +1,12 @@
package sessions // import "github.com/pomerium/pomerium/internal/sessions"
import "strings"
// ParentSubdomain returns the parent subdomain.
func ParentSubdomain(s string) string {
if strings.Count(s, ".") < 2 {
return ""
}
split := strings.SplitN(s, ".", 2)
return split[1]
}

View file

@ -0,0 +1,23 @@
package sessions
import "testing"
func Test_ParentSubdomain(t *testing.T) {
t.Parallel()
tests := []struct {
s string
want string
}{
{"httpbin.corp.example.com", "corp.example.com"},
{"some.httpbin.corp.example.com", "httpbin.corp.example.com"},
{"example.com", ""},
{"", ""},
}
for _, tt := range tests {
t.Run(tt.s, func(t *testing.T) {
if got := ParentSubdomain(tt.s); got != tt.want {
t.Errorf("ParentSubdomain() = %v, want %v", got, tt.want)
}
})
}
}

View file

@ -184,8 +184,8 @@ func (p *Proxy) Verify(w http.ResponseWriter, r *http.Request) {
}
// check the queryparams to see if this check immediately followed
// authentication. If so, redirect back to the originally requested hostname.
if isCallback := r.URL.Query().Get(disableCallback); isCallback == "true" {
http.Redirect(w, r, "http://"+hostname, http.StatusFound)
if isCallback := r.URL.Query().Get("pomerium-auth-callback"); isCallback == "true" {
http.Redirect(w, r, hostname, http.StatusFound)
return
}

View file

@ -302,13 +302,13 @@ func TestProxy_VerifyWithMiddleware(t *testing.T) {
authorizer clients.Authorizer
wantStatus int
}{
{"good", opts, nil, http.MethodGet, "false", "/.pomerium/verify/some.domain.name", &cryptutil.MockEncoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", RefreshDeadline: time.Now().Add(10 * time.Second)}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusOK},
{"good post auth redirect", opts, nil, http.MethodGet, "true", "/.pomerium/verify/some.domain.name", &cryptutil.MockEncoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", RefreshDeadline: time.Now().Add(10 * time.Second)}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusFound},
{"not authorized", opts, nil, http.MethodGet, "false", "/.pomerium/verify/some.domain.name", &cryptutil.MockEncoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", RefreshDeadline: time.Now().Add(10 * time.Second)}}, clients.MockAuthorize{AuthorizeResponse: false}, http.StatusForbidden},
{"not authorized expired, redirect to auth", opts, nil, http.MethodGet, "false", "/.pomerium/verify/some.domain.name", &cryptutil.MockEncoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", RefreshDeadline: time.Now().Add(-10 * time.Second)}}, clients.MockAuthorize{AuthorizeResponse: false}, http.StatusFound},
{"not authorized expired, don't redirect!", opts, nil, http.MethodGet, "true", "/.pomerium/verify/some.domain.name?no_redirect=true", &cryptutil.MockEncoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", RefreshDeadline: time.Now().Add(-10 * time.Second)}}, clients.MockAuthorize{AuthorizeResponse: false}, http.StatusUnauthorized},
{"not authorized because of error", opts, nil, http.MethodGet, "false", "/.pomerium/verify/some.domain.name", &cryptutil.MockEncoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", RefreshDeadline: time.Now().Add(10 * time.Second)}}, clients.MockAuthorize{AuthorizeError: errors.New("authz error")}, http.StatusInternalServerError},
{"bad context retrieval error", opts, errors.New("oh no"), http.MethodGet, "false", "/.pomerium/verify/some.domain.name", &cryptutil.MockEncoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", RefreshDeadline: time.Now().Add(10 * time.Second)}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusOK},
{"good", opts, nil, http.MethodGet, "", "/.pomerium/verify/some.domain.name", &cryptutil.MockEncoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", RefreshDeadline: time.Now().Add(10 * time.Second)}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusOK},
{"good post auth redirect", opts, nil, http.MethodGet, "pomerium-auth-callback", "/.pomerium/verify/some.domain.name", &cryptutil.MockEncoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", RefreshDeadline: time.Now().Add(10 * time.Second)}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusFound},
{"not authorized", opts, nil, http.MethodGet, "", "/.pomerium/verify/some.domain.name", &cryptutil.MockEncoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", RefreshDeadline: time.Now().Add(10 * time.Second)}}, clients.MockAuthorize{AuthorizeResponse: false}, http.StatusForbidden},
{"not authorized expired, redirect to auth", opts, nil, http.MethodGet, "", "/.pomerium/verify/some.domain.name", &cryptutil.MockEncoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", RefreshDeadline: time.Now().Add(-10 * time.Second)}}, clients.MockAuthorize{AuthorizeResponse: false}, http.StatusFound},
{"not authorized expired, don't redirect!", opts, nil, http.MethodGet, HeaderNoAuthRedirect, "/.pomerium/verify/some.domain.name?no_redirect=true", &cryptutil.MockEncoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", RefreshDeadline: time.Now().Add(-10 * time.Second)}}, clients.MockAuthorize{AuthorizeResponse: false}, http.StatusUnauthorized},
{"not authorized because of error", opts, nil, http.MethodGet, "", "/.pomerium/verify/some.domain.name", &cryptutil.MockEncoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", RefreshDeadline: time.Now().Add(10 * time.Second)}}, clients.MockAuthorize{AuthorizeError: errors.New("authz error")}, http.StatusInternalServerError},
{"bad context retrieval error", opts, errors.New("oh no"), http.MethodGet, "", "/.pomerium/verify/some.domain.name", &cryptutil.MockEncoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", RefreshDeadline: time.Now().Add(10 * time.Second)}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusOK},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
@ -322,8 +322,8 @@ func TestProxy_VerifyWithMiddleware(t *testing.T) {
p.UpdateOptions(tt.options)
uri := &url.URL{Path: tt.path}
queryString := uri.Query()
if tt.qp == "true" {
queryString.Set("pomerium-auth-callback", tt.qp)
if tt.qp != "" {
queryString.Set(tt.qp, "true")
}
uri.RawQuery = queryString.Encode()

View file

@ -23,7 +23,9 @@ const (
// HeaderGroups is the header key containing the user's groups.
HeaderGroups = "x-pomerium-authenticated-user-groups"
disableCallback = "pomerium-auth-callback"
// HeaderNoAuthRedirect is the header / query param key used to disable
// redirecting unauthenticated request by default but instead return a 401.
HeaderNoAuthRedirect = "x-pomerium-no-auth-redirect"
)
// AuthenticateSession is middleware to enforce a valid authentication
@ -33,8 +35,8 @@ func (p *Proxy) AuthenticateSession(next http.Handler) http.Handler {
ctx, span := trace.StartSpan(r.Context(), "middleware.AuthenticateSession")
defer span.End()
s, err := sessions.FromContext(r.Context())
if err != nil {
log.Debug().Str("cause", err.Error()).Msg("proxy: re-authenticating due to session state error")
if err != nil || s == nil {
log.Debug().Msg("proxy: re-authenticating due to session state error")
p.reqNeedsAuthentication(w, r)
return
}
@ -58,7 +60,7 @@ func (p *Proxy) AuthorizeSession(next http.Handler) http.Handler {
ctx, span := trace.StartSpan(r.Context(), "middleware.AuthorizeSession")
defer span.End()
s, err := sessions.FromContext(r.Context())
if err != nil {
if err != nil || s == nil {
httputil.ErrorResponse(w, r.WithContext(ctx), httputil.Error("", http.StatusForbidden, err))
return
}
@ -105,9 +107,11 @@ func (p *Proxy) reqNeedsAuthentication(w http.ResponseWriter, r *http.Request) {
// some proxies like nginx won't follow redirects, and treat any
// non 2xx or 4xx status as an internal service error.
// https://nginx.org/en/docs/http/ngx_http_auth_request_module.html
if _, ok := r.URL.Query()[disableCallback]; ok {
redirectHeader := r.Header.Get(HeaderNoAuthRedirect)
if _, ok := r.URL.Query()[HeaderNoAuthRedirect]; ok || redirectHeader == "true" {
http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized)
}
r.Header.Get(HeaderNoAuthRedirect)
uri := urlutil.SignedRedirectURL(p.SharedKey, p.authenticateSigninURL, urlutil.GetAbsoluteURL(r))
http.Redirect(w, r, uri.String(), http.StatusFound)
}

View file

@ -78,6 +78,7 @@ type Proxy struct {
refreshCooldown time.Duration
Handler http.Handler
sessionStore sessions.SessionStore
sessionLoaders []sessions.SessionLoader
signingKey string
templates *template.Template
}
@ -122,8 +123,12 @@ func New(opts config.Options) (*Proxy, error) {
defaultUpstreamTimeout: opts.DefaultUpstreamTimeout,
refreshCooldown: opts.RefreshCooldown,
sessionStore: cookieStore,
signingKey: opts.SigningKey,
templates: templates.New(),
sessionLoaders: []sessions.SessionLoader{
cookieStore,
sessions.NewHeaderStore(encoder),
sessions.NewQueryParamStore(encoder)},
signingKey: opts.SigningKey,
templates: templates.New(),
}
// errors checked in ValidateOptions
p.authorizeURL, _ = urlutil.DeepCopy(opts.AuthorizeURL)
@ -227,7 +232,7 @@ func (p *Proxy) reverseProxyHandler(r *mux.Router, policy *config.Policy) (*mux.
}
// 4. Retrieve the user session and add it to the request context
rp.Use(sessions.RetrieveSession(p.sessionStore))
rp.Use(sessions.RetrieveSession(p.sessionLoaders...))
// 5. Strip the user session cookie from the downstream request
rp.Use(middleware.StripCookie(p.cookieName))
// 6. AuthN - Verify the user is authenticated. Set email, group, & id headers