internal/sessions: refactor how sessions loading (#351)

These chagnes standardize how session loading is done for session
cookie, auth bearer token, and query params.

- Bearer token previously combined with session cookie.
- rearranged cookie-store to put exported methods above unexported
- added header store that implements session loader interface
- added query param store that implements session loader interface

Signed-off-by: Bobby DeSimone <bobbydesimone@gmail.com>
This commit is contained in:
Bobby DeSimone 2019-10-06 10:47:53 -07:00 committed by GitHub
parent 7aa4621b1b
commit badd8d69af
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
13 changed files with 322 additions and 234 deletions

View file

@ -10,47 +10,39 @@ import (
"github.com/pomerium/pomerium/internal/cryptutil" "github.com/pomerium/pomerium/internal/cryptutil"
) )
// ChunkedCanaryByte is the byte value used as a canary prefix to distinguish if const (
// the cookie is multi-part or not. This constant *should not* be valid // ChunkedCanaryByte is the byte value used as a canary prefix to distinguish if
// base64. It's important this byte is ASCII to avoid UTF-8 variable sized runes. // the cookie is multi-part or not. This constant *should not* be valid
// https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Set-Cookie#Directives // base64. It's important this byte is ASCII to avoid UTF-8 variable sized runes.
const ChunkedCanaryByte byte = '%' // https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Set-Cookie#Directives
ChunkedCanaryByte byte = '%'
// MaxChunkSize sets the upper bound on a cookie chunks payload value.
// Note, this should be lower than the actual cookie's max size (4096 bytes)
// which includes metadata.
MaxChunkSize = 3800
// MaxNumChunks limits the number of chunks to iterate through. Conservatively
// set to prevent any abuse.
MaxNumChunks = 5
)
// DefaultBearerTokenHeader is default header name for the authorization bearer // CookieStore implements the session store interface for session cookies.
// token header as defined in rfc2617
// https://tools.ietf.org/html/rfc6750#section-2.1
const DefaultBearerTokenHeader = "Authorization"
// MaxChunkSize sets the upper bound on a cookie chunks payload value.
// Note, this should be lower than the actual cookie's max size (4096 bytes)
// which includes metadata.
const MaxChunkSize = 3800
// MaxNumChunks limits the number of chunks to iterate through. Conservatively
// set to prevent any abuse.
const MaxNumChunks = 5
// CookieStore represents all the cookie related configurations
type CookieStore struct { type CookieStore struct {
Name string Name string
Encoder cryptutil.SecureEncoder CookieDomain string
CookieExpire time.Duration CookieExpire time.Duration
CookieRefresh time.Duration CookieHTTPOnly bool
CookieSecure bool CookieSecure bool
CookieHTTPOnly bool Encoder cryptutil.SecureEncoder
CookieDomain string
BearerTokenHeader string
} }
// CookieStoreOptions holds options for CookieStore // CookieStoreOptions holds options for CookieStore
type CookieStoreOptions struct { type CookieStoreOptions struct {
Name string Name string
CookieSecure bool CookieDomain string
CookieHTTPOnly bool CookieExpire time.Duration
CookieDomain string CookieHTTPOnly bool
BearerTokenHeader string CookieSecure bool
CookieExpire time.Duration Encoder cryptutil.SecureEncoder
Encoder cryptutil.SecureEncoder
} }
// NewCookieStore returns a new session with ciphers for each of the cookie secrets // NewCookieStore returns a new session with ciphers for each of the cookie secrets
@ -61,18 +53,14 @@ func NewCookieStore(opts *CookieStoreOptions) (*CookieStore, error) {
if opts.Encoder == nil { if opts.Encoder == nil {
return nil, fmt.Errorf("internal/sessions: cipher cannot be nil") return nil, fmt.Errorf("internal/sessions: cipher cannot be nil")
} }
if opts.BearerTokenHeader == "" {
opts.BearerTokenHeader = DefaultBearerTokenHeader
}
return &CookieStore{ return &CookieStore{
Name: opts.Name, Name: opts.Name,
CookieSecure: opts.CookieSecure, CookieSecure: opts.CookieSecure,
CookieHTTPOnly: opts.CookieHTTPOnly, CookieHTTPOnly: opts.CookieHTTPOnly,
CookieDomain: opts.CookieDomain, CookieDomain: opts.CookieDomain,
CookieExpire: opts.CookieExpire, CookieExpire: opts.CookieExpire,
Encoder: opts.Encoder, Encoder: opts.Encoder,
BearerTokenHeader: opts.BearerTokenHeader,
}, nil }, nil
} }
@ -103,11 +91,43 @@ func (cs *CookieStore) makeCookie(req *http.Request, name string, value string,
return c return c
} }
// ClearSession clears the session cookie from a request
func (cs *CookieStore) ClearSession(w http.ResponseWriter, req *http.Request) {
http.SetCookie(w, cs.makeCookie(req, cs.Name, "", time.Hour*-1, time.Now()))
}
// LoadSession returns a State from the cookie in the request.
func (cs *CookieStore) LoadSession(req *http.Request) (*State, error) {
cipherText := loadChunkedCookie(req, cs.Name)
if cipherText == "" {
return nil, ErrNoSessionFound
}
session, err := UnmarshalSession(cipherText, cs.Encoder)
if err != nil {
return nil, ErrMalformed
}
return session, nil
}
// SaveSession saves a session state to a request sessions.
func (cs *CookieStore) SaveSession(w http.ResponseWriter, req *http.Request, s *State) error {
value, err := MarshalSession(s, cs.Encoder)
if err != nil {
return err
}
cs.setSessionCookie(w, req, value)
return nil
}
// makeSessionCookie constructs a session cookie given the request, an expiration time and the current time. // makeSessionCookie constructs a session cookie given the request, an expiration time and the current time.
func (cs *CookieStore) makeSessionCookie(req *http.Request, value string, expiration time.Duration, now time.Time) *http.Cookie { func (cs *CookieStore) makeSessionCookie(req *http.Request, value string, expiration time.Duration, now time.Time) *http.Cookie {
return cs.makeCookie(req, cs.Name, value, expiration, now) return cs.makeCookie(req, cs.Name, value, expiration, now)
} }
func (cs *CookieStore) setSessionCookie(w http.ResponseWriter, req *http.Request, val string) {
cs.setCookie(w, cs.makeSessionCookie(req, val, cs.CookieExpire, time.Now()))
}
func (cs *CookieStore) setCookie(w http.ResponseWriter, cookie *http.Cookie) { func (cs *CookieStore) setCookie(w http.ResponseWriter, cookie *http.Cookie) {
if len(cookie.String()) <= MaxChunkSize { if len(cookie.String()) <= MaxChunkSize {
http.SetCookie(w, cookie) http.SetCookie(w, cookie)
@ -128,35 +148,6 @@ func (cs *CookieStore) setCookie(w http.ResponseWriter, cookie *http.Cookie) {
} }
} }
func chunk(s string, size int) []string {
ss := make([]string, 0, len(s)/size+1)
for len(s) > 0 {
if len(s) < size {
size = len(s)
}
ss, s = append(ss, s[:size]), s[size:]
}
return ss
}
// ClearSession clears the session cookie from a request
func (cs *CookieStore) ClearSession(w http.ResponseWriter, req *http.Request) {
http.SetCookie(w, cs.makeCookie(req, cs.Name, "", time.Hour*-1, time.Now()))
}
func (cs *CookieStore) setSessionCookie(w http.ResponseWriter, req *http.Request, val string) {
cs.setCookie(w, cs.makeSessionCookie(req, val, cs.CookieExpire, time.Now()))
}
func loadBearerToken(r *http.Request, headerKey string) string {
authHeader := r.Header.Get(headerKey)
split := strings.Split(authHeader, "Bearer")
if authHeader == "" || len(split) != 2 {
return ""
}
return strings.TrimSpace(split[1])
}
func loadChunkedCookie(r *http.Request, cookieName string) string { func loadChunkedCookie(r *http.Request, cookieName string) string {
c, err := r.Cookie(cookieName) c, err := r.Cookie(cookieName)
if err != nil { if err != nil {
@ -179,37 +170,13 @@ func loadChunkedCookie(r *http.Request, cookieName string) string {
return cipherText return cipherText
} }
// LoadSession returns a State from the cookie in the request. func chunk(s string, size int) []string {
func (cs *CookieStore) LoadSession(req *http.Request) (*State, error) { ss := make([]string, 0, len(s)/size+1)
cipherText := loadChunkedCookie(req, cs.Name) for len(s) > 0 {
if cipherText == "" { if len(s) < size {
cipherText = loadBearerToken(req, cs.BearerTokenHeader) size = len(s)
}
ss, s = append(ss, s[:size]), s[size:]
} }
if cipherText == "" { return ss
return nil, ErrEmptySession
}
session, err := UnmarshalSession(cipherText, cs.Encoder)
if err != nil {
return nil, err
}
return session, nil
}
// SaveSession saves a session state to a request sessions.
func (cs *CookieStore) SaveSession(w http.ResponseWriter, req *http.Request, s *State) error {
value, err := MarshalSession(s, cs.Encoder)
if err != nil {
return err
}
cs.setSessionCookie(w, req, value)
return nil
}
// ParentSubdomain returns the parent subdomain.
func ParentSubdomain(s string) string {
if strings.Count(s, ".") < 2 {
return ""
}
split := strings.SplitN(s, ".", 2)
return split[1]
} }

View file

@ -38,33 +38,30 @@ func TestNewCookieStore(t *testing.T) {
}{ }{
{"good", {"good",
&CookieStoreOptions{ &CookieStoreOptions{
Name: "_cookie", Name: "_cookie",
CookieSecure: true, CookieSecure: true,
CookieHTTPOnly: true, CookieHTTPOnly: true,
CookieDomain: "pomerium.io", CookieDomain: "pomerium.io",
CookieExpire: 10 * time.Second, CookieExpire: 10 * time.Second,
Encoder: encoder, Encoder: encoder,
BearerTokenHeader: "Authorization",
}, },
&CookieStore{ &CookieStore{
Name: "_cookie", Name: "_cookie",
CookieSecure: true, CookieSecure: true,
CookieHTTPOnly: true, CookieHTTPOnly: true,
CookieDomain: "pomerium.io", CookieDomain: "pomerium.io",
CookieExpire: 10 * time.Second, CookieExpire: 10 * time.Second,
Encoder: encoder, Encoder: encoder,
BearerTokenHeader: "Authorization",
}, },
false}, false},
{"missing name", {"missing name",
&CookieStoreOptions{ &CookieStoreOptions{
Name: "", Name: "",
CookieSecure: true, CookieSecure: true,
CookieHTTPOnly: true, CookieHTTPOnly: true,
CookieDomain: "pomerium.io", CookieDomain: "pomerium.io",
CookieExpire: 10 * time.Second, CookieExpire: 10 * time.Second,
Encoder: encoder, Encoder: encoder,
BearerTokenHeader: "Authorization",
}, },
nil, nil,
true}, true},
@ -250,23 +247,3 @@ func TestMockSessionStore(t *testing.T) {
}) })
} }
} }
func Test_ParentSubdomain(t *testing.T) {
t.Parallel()
tests := []struct {
s string
want string
}{
{"httpbin.corp.example.com", "corp.example.com"},
{"some.httpbin.corp.example.com", "httpbin.corp.example.com"},
{"example.com", ""},
{"", ""},
}
for _, tt := range tests {
t.Run(tt.s, func(t *testing.T) {
if got := ParentSubdomain(tt.s); got != tt.want {
t.Errorf("ParentSubdomain() = %v, want %v", got, tt.want)
}
})
}
}

View file

@ -0,0 +1,61 @@
package sessions // import "github.com/pomerium/pomerium/internal/sessions"
import (
"net/http"
"strings"
"github.com/pomerium/pomerium/internal/cryptutil"
)
const (
// defaultAuthHeader and defaultAuthType are default header name for the
// authorization bearer token header as defined in rfc2617
// https://tools.ietf.org/html/rfc6750#section-2.1
defaultAuthHeader = "Authorization"
defaultAuthType = "Bearer"
)
// HeaderStore implements the load session store interface using http
// authorization headers.
type HeaderStore struct {
authHeader string
authType string
encoder cryptutil.SecureEncoder
}
// NewHeaderStore returns a new header store for loading sessions from
// authorization headers.
func NewHeaderStore(enc cryptutil.SecureEncoder) *HeaderStore {
return &HeaderStore{
authHeader: defaultAuthHeader,
authType: defaultAuthType,
encoder: enc,
}
}
// LoadSession tries to retrieve the token string from the Authorization header.
//
// NOTA BENE: While most servers do not log Authorization headers by default,
// you should ensure no other services are logging or leaking your auth headers.
func (as *HeaderStore) LoadSession(r *http.Request) (*State, error) {
cipherText := as.tokenFromHeader(r)
if cipherText == "" {
return nil, ErrNoSessionFound
}
session, err := UnmarshalSession(cipherText, as.encoder)
if err != nil {
return nil, ErrMalformed
}
return session, nil
}
// retrieve the value of the authorization header
func (as *HeaderStore) tokenFromHeader(r *http.Request) string {
bearer := r.Header.Get(as.authHeader)
atSize := len(as.authType)
if len(bearer) > atSize && strings.EqualFold(bearer[0:atSize], as.authType) {
return bearer[atSize+1:]
}
return ""
}

View file

@ -4,7 +4,6 @@ import (
"context" "context"
"errors" "errors"
"net/http" "net/http"
"strings"
) )
// Context keys // Context keys
@ -13,65 +12,56 @@ var (
ErrorCtxKey = &contextKey{"Error"} ErrorCtxKey = &contextKey{"Error"}
) )
// Library errors
var (
ErrExpired = errors.New("internal/sessions: session is expired")
ErrNoSessionFound = errors.New("internal/sessions: session is not found")
ErrMalformed = errors.New("internal/sessions: session is malformed")
)
// RetrieveSession http middleware handler will verify a auth session from a http request.
//
// RetrieveSession will search for a auth session in a http request, in the order: // RetrieveSession will search for a auth session in a http request, in the order:
// 1. `pomerium_session` URI query parameter // 1. `pomerium_session` URI query parameter
// 2. `Authorization: BEARER` request header // 2. `Authorization: BEARER` request header
// 3. Cookie `_pomerium` value // 3. Cookie `_pomerium` value
func RetrieveSession(s SessionStore) func(http.Handler) http.Handler { func RetrieveSession(s ...SessionLoader) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler { return func(next http.Handler) http.Handler {
return retrieve(s, TokenFromQuery, TokenFromHeader, TokenFromCookie)(next) return retrieve(s...)(next)
} }
} }
func retrieve(s SessionStore, findTokenFns ...func(r *http.Request) string) func(http.Handler) http.Handler { func retrieve(s ...SessionLoader) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler { return func(next http.Handler) http.Handler {
hfn := func(w http.ResponseWriter, r *http.Request) { hfn := func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context() ctx := r.Context()
token, err := retrieveFromRequest(s, r, findTokenFns...) state, err := retrieveFromRequest(r, s...)
ctx = NewContext(ctx, token, err) ctx = NewContext(ctx, state, err)
next.ServeHTTP(w, r.WithContext(ctx)) next.ServeHTTP(w, r.WithContext(ctx))
} }
return http.HandlerFunc(hfn) return http.HandlerFunc(hfn)
} }
} }
func retrieveFromRequest(s SessionStore, r *http.Request, findTokenFns ...func(r *http.Request) string) (*State, error) { func retrieveFromRequest(r *http.Request, sessions ...SessionLoader) (*State, error) {
var tokenStr string state := new(State)
var err error var err error
// Extract token string from the request by calling token find functions in // Extract sessions state from the request by calling token find functions in
// the order they where provided. Further extraction stops if a function // the order they where provided. Further extraction stops if a function
// returns a non-empty string. // returns a non-empty string.
for _, fn := range findTokenFns { for _, s := range sessions {
tokenStr = fn(r) state, err = s.LoadSession(r)
if tokenStr != "" { if err != nil && !errors.Is(err, ErrNoSessionFound) {
// unexpected error
return nil, err
}
// break, we found a session state
if state != nil {
break break
} }
} }
if tokenStr == "" { // no session found if state is still empty
if state == nil {
return nil, ErrNoSessionFound return nil, ErrNoSessionFound
} }
state, err := s.LoadSession(r) if err = state.Valid(); err != nil {
if err != nil {
return nil, ErrMalformed
}
err = state.Valid()
if err != nil {
// a little unusual but we want to return the expired state too // a little unusual but we want to return the expired state too
return state, err return state, err
} }
// Valid!
return state, nil return state, nil
} }
@ -89,35 +79,6 @@ func FromContext(ctx context.Context) (*State, error) {
return state, err return state, err
} }
// TokenFromCookie tries to retrieve the token string from a cookie named
// "_pomerium".
func TokenFromCookie(r *http.Request) string {
cookie, err := r.Cookie("_pomerium")
if err != nil {
return ""
}
return cookie.Value
}
// TokenFromHeader tries to retrieve the token string from the
// "Authorization" request header: "Authorization: BEARER T".
func TokenFromHeader(r *http.Request) string {
// Get token from authorization header.
bearer := r.Header.Get("Authorization")
if len(bearer) > 7 && strings.EqualFold(bearer[0:6], "BEARER") {
return bearer[7:]
}
return ""
}
// TokenFromQuery tries to retrieve the token string from the "pomerium_session" URI
// query parameter.
// todo(bdd) : document setting session code as queryparam
func TokenFromQuery(r *http.Request) string {
// Get token from query param named "pomerium_session".
return r.URL.Query().Get("pomerium_session")
}
// contextKey is a value for use with context.WithValue. It's used as // contextKey is a value for use with context.WithValue. It's used as
// a pointer so it fits in an interface{} without allocation. This technique // a pointer so it fits in an interface{} without allocation. This technique
// for defining context keys was copied from Go 1.7's new use of context in net/http. // for defining context keys was copied from Go 1.7's new use of context in net/http.
@ -126,5 +87,5 @@ type contextKey struct {
} }
func (k *contextKey) String() string { func (k *contextKey) String() string {
return "SessionStore context value " + k.name return "context value " + k.name
} }

View file

@ -9,9 +9,8 @@ import (
"testing" "testing"
"time" "time"
"github.com/pomerium/pomerium/internal/cryptutil"
"github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp"
"github.com/pomerium/pomerium/internal/cryptutil"
) )
func TestNewContext(t *testing.T) { func TestNewContext(t *testing.T) {
@ -75,8 +74,8 @@ func TestVerifier(t *testing.T) {
{"expired auth header", State{Email: "user@pomerium.io", RefreshDeadline: time.Now().Add(-10 * time.Second)}, false, true, false, "internal/sessions: session is expired\n", http.StatusUnauthorized}, {"expired auth header", State{Email: "user@pomerium.io", RefreshDeadline: time.Now().Add(-10 * time.Second)}, false, true, false, "internal/sessions: session is expired\n", http.StatusUnauthorized},
{"malformed auth header", State{Email: "user@pomerium.io", RefreshDeadline: time.Now().Add(-10 * time.Second)}, false, true, false, "internal/sessions: session is malformed\n", http.StatusUnauthorized}, {"malformed auth header", State{Email: "user@pomerium.io", RefreshDeadline: time.Now().Add(-10 * time.Second)}, false, true, false, "internal/sessions: session is malformed\n", http.StatusUnauthorized},
{"good auth query param session", State{Email: "user@pomerium.io", RefreshDeadline: time.Now().Add(10 * time.Second)}, false, true, true, http.StatusText(http.StatusOK), http.StatusOK}, {"good auth query param session", State{Email: "user@pomerium.io", RefreshDeadline: time.Now().Add(10 * time.Second)}, false, true, true, http.StatusText(http.StatusOK), http.StatusOK},
{"expired auth query param", State{Email: "user@pomerium.io", RefreshDeadline: time.Now().Add(-10 * time.Second)}, false, true, true, "internal/sessions: session is expired\n", http.StatusUnauthorized}, {"expired auth query param", State{Email: "user@pomerium.io", RefreshDeadline: time.Now().Add(-10 * time.Second)}, false, false, true, "internal/sessions: session is expired\n", http.StatusUnauthorized},
{"malformed auth query param", State{Email: "user@pomerium.io", RefreshDeadline: time.Now().Add(-10 * time.Second)}, false, true, true, "internal/sessions: session is malformed\n", http.StatusUnauthorized}, {"malformed auth query param", State{Email: "user@pomerium.io", RefreshDeadline: time.Now().Add(-10 * time.Second)}, false, false, true, "internal/sessions: session is malformed\n", http.StatusUnauthorized},
{"no session", State{Email: "user@pomerium.io", RefreshDeadline: time.Now().Add(-10 * time.Second)}, false, false, false, "internal/sessions: session is not found\n", http.StatusUnauthorized}, {"no session", State{Email: "user@pomerium.io", RefreshDeadline: time.Now().Add(-10 * time.Second)}, false, false, false, "internal/sessions: session is not found\n", http.StatusUnauthorized},
} }
for _, tt := range tests { for _, tt := range tests {
@ -94,7 +93,6 @@ func TestVerifier(t *testing.T) {
if strings.Contains(tt.name, "malformed") { if strings.Contains(tt.name, "malformed") {
// add some garbage to the end of the string // add some garbage to the end of the string
encSession += cryptutil.NewBase64Key() encSession += cryptutil.NewBase64Key()
fmt.Println(encSession)
} }
cs, err := NewCookieStore(&CookieStoreOptions{ cs, err := NewCookieStore(&CookieStoreOptions{
@ -104,6 +102,9 @@ func TestVerifier(t *testing.T) {
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
as := NewHeaderStore(encoder)
qp := NewQueryParamStore(encoder)
r := httptest.NewRequest(http.MethodGet, "/", nil) r := httptest.NewRequest(http.MethodGet, "/", nil)
r.Header.Set("Accept", "application/json") r.Header.Set("Accept", "application/json")
@ -114,11 +115,12 @@ func TestVerifier(t *testing.T) {
r.Header.Set("Authorization", "Bearer "+encSession) r.Header.Set("Authorization", "Bearer "+encSession)
} else if tt.param { } else if tt.param {
q := r.URL.Query() q := r.URL.Query()
q.Add("pomerium_session", encSession)
q.Set("pomerium_session", encSession)
r.URL.RawQuery = q.Encode() r.URL.RawQuery = q.Encode()
} }
got := RetrieveSession(cs)(testAuthorizer((fnh))) got := RetrieveSession(cs, as, qp)(testAuthorizer((fnh)))
got.ServeHTTP(w, r) got.ServeHTTP(w, r)
gotBody := w.Body.String() gotBody := w.Body.String()
@ -133,3 +135,23 @@ func TestVerifier(t *testing.T) {
}) })
} }
} }
func Test_contextKey_String(t *testing.T) {
tests := []struct {
name string
keyName string
want string
}{
{"simple example", "test", "context value test"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
k := &contextKey{
name: tt.keyName,
}
if got := k.String(); got != tt.want {
t.Errorf("contextKey.String() = %v, want %v", got, tt.want)
}
})
}
}

View file

@ -0,0 +1,44 @@
package sessions // import "github.com/pomerium/pomerium/internal/sessions"
import (
"net/http"
"github.com/pomerium/pomerium/internal/cryptutil"
)
const (
defaultQueryParamKey = "pomerium_session"
)
// QueryParamStore implements the load session store interface using http
// query strings / query parameters.
type QueryParamStore struct {
queryParamKey string
encoder cryptutil.SecureEncoder
}
// NewQueryParamStore returns a new query param store for loading sessions from
// query strings / query parameters.
func NewQueryParamStore(enc cryptutil.SecureEncoder) *QueryParamStore {
return &QueryParamStore{
queryParamKey: defaultQueryParamKey,
encoder: enc,
}
}
// LoadSession tries to retrieve the token string from URL query parameters.
//
// NOTA BENE: By default, most servers _DO_ log query params, the leaking or
// accidental logging of which should be considered a security issue.
func (qp *QueryParamStore) LoadSession(r *http.Request) (*State, error) {
cipherText := r.URL.Query().Get(qp.queryParamKey)
if cipherText == "" {
return nil, ErrNoSessionFound
}
session, err := UnmarshalSession(cipherText, qp.encoder)
if err != nil {
return nil, ErrMalformed
}
return session, nil
}

View file

@ -5,8 +5,14 @@ import (
"net/http" "net/http"
) )
// ErrEmptySession is an error for an empty sessions. var (
var ErrEmptySession = errors.New("internal/sessions: empty session") // ErrExpired is the error for an expired session.
ErrExpired = errors.New("internal/sessions: session is expired")
// ErrNoSessionFound is the error for when no session is found.
ErrNoSessionFound = errors.New("internal/sessions: session is not found")
// ErrMalformed is the error for when a session is found but is malformed.
ErrMalformed = errors.New("internal/sessions: session is malformed")
)
// SessionStore has the functions for setting, getting, and clearing the Session cookie // SessionStore has the functions for setting, getting, and clearing the Session cookie
type SessionStore interface { type SessionStore interface {
@ -14,3 +20,9 @@ type SessionStore interface {
LoadSession(*http.Request) (*State, error) LoadSession(*http.Request) (*State, error)
SaveSession(http.ResponseWriter, *http.Request, *State) error SaveSession(http.ResponseWriter, *http.Request, *State) error
} }
// SessionLoader is implemented by any struct that loads a pomerium session
// given a request, and returns a user state.
type SessionLoader interface {
LoadSession(*http.Request) (*State, error)
}

12
internal/sessions/util.go Normal file
View file

@ -0,0 +1,12 @@
package sessions // import "github.com/pomerium/pomerium/internal/sessions"
import "strings"
// ParentSubdomain returns the parent subdomain.
func ParentSubdomain(s string) string {
if strings.Count(s, ".") < 2 {
return ""
}
split := strings.SplitN(s, ".", 2)
return split[1]
}

View file

@ -0,0 +1,23 @@
package sessions
import "testing"
func Test_ParentSubdomain(t *testing.T) {
t.Parallel()
tests := []struct {
s string
want string
}{
{"httpbin.corp.example.com", "corp.example.com"},
{"some.httpbin.corp.example.com", "httpbin.corp.example.com"},
{"example.com", ""},
{"", ""},
}
for _, tt := range tests {
t.Run(tt.s, func(t *testing.T) {
if got := ParentSubdomain(tt.s); got != tt.want {
t.Errorf("ParentSubdomain() = %v, want %v", got, tt.want)
}
})
}
}

View file

@ -184,8 +184,8 @@ func (p *Proxy) Verify(w http.ResponseWriter, r *http.Request) {
} }
// check the queryparams to see if this check immediately followed // check the queryparams to see if this check immediately followed
// authentication. If so, redirect back to the originally requested hostname. // authentication. If so, redirect back to the originally requested hostname.
if isCallback := r.URL.Query().Get(disableCallback); isCallback == "true" { if isCallback := r.URL.Query().Get("pomerium-auth-callback"); isCallback == "true" {
http.Redirect(w, r, "http://"+hostname, http.StatusFound) http.Redirect(w, r, hostname, http.StatusFound)
return return
} }

View file

@ -302,13 +302,13 @@ func TestProxy_VerifyWithMiddleware(t *testing.T) {
authorizer clients.Authorizer authorizer clients.Authorizer
wantStatus int wantStatus int
}{ }{
{"good", opts, nil, http.MethodGet, "false", "/.pomerium/verify/some.domain.name", &cryptutil.MockEncoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", RefreshDeadline: time.Now().Add(10 * time.Second)}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusOK}, {"good", opts, nil, http.MethodGet, "", "/.pomerium/verify/some.domain.name", &cryptutil.MockEncoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", RefreshDeadline: time.Now().Add(10 * time.Second)}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusOK},
{"good post auth redirect", opts, nil, http.MethodGet, "true", "/.pomerium/verify/some.domain.name", &cryptutil.MockEncoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", RefreshDeadline: time.Now().Add(10 * time.Second)}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusFound}, {"good post auth redirect", opts, nil, http.MethodGet, "pomerium-auth-callback", "/.pomerium/verify/some.domain.name", &cryptutil.MockEncoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", RefreshDeadline: time.Now().Add(10 * time.Second)}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusFound},
{"not authorized", opts, nil, http.MethodGet, "false", "/.pomerium/verify/some.domain.name", &cryptutil.MockEncoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", RefreshDeadline: time.Now().Add(10 * time.Second)}}, clients.MockAuthorize{AuthorizeResponse: false}, http.StatusForbidden}, {"not authorized", opts, nil, http.MethodGet, "", "/.pomerium/verify/some.domain.name", &cryptutil.MockEncoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", RefreshDeadline: time.Now().Add(10 * time.Second)}}, clients.MockAuthorize{AuthorizeResponse: false}, http.StatusForbidden},
{"not authorized expired, redirect to auth", opts, nil, http.MethodGet, "false", "/.pomerium/verify/some.domain.name", &cryptutil.MockEncoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", RefreshDeadline: time.Now().Add(-10 * time.Second)}}, clients.MockAuthorize{AuthorizeResponse: false}, http.StatusFound}, {"not authorized expired, redirect to auth", opts, nil, http.MethodGet, "", "/.pomerium/verify/some.domain.name", &cryptutil.MockEncoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", RefreshDeadline: time.Now().Add(-10 * time.Second)}}, clients.MockAuthorize{AuthorizeResponse: false}, http.StatusFound},
{"not authorized expired, don't redirect!", opts, nil, http.MethodGet, "true", "/.pomerium/verify/some.domain.name?no_redirect=true", &cryptutil.MockEncoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", RefreshDeadline: time.Now().Add(-10 * time.Second)}}, clients.MockAuthorize{AuthorizeResponse: false}, http.StatusUnauthorized}, {"not authorized expired, don't redirect!", opts, nil, http.MethodGet, HeaderNoAuthRedirect, "/.pomerium/verify/some.domain.name?no_redirect=true", &cryptutil.MockEncoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", RefreshDeadline: time.Now().Add(-10 * time.Second)}}, clients.MockAuthorize{AuthorizeResponse: false}, http.StatusUnauthorized},
{"not authorized because of error", opts, nil, http.MethodGet, "false", "/.pomerium/verify/some.domain.name", &cryptutil.MockEncoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", RefreshDeadline: time.Now().Add(10 * time.Second)}}, clients.MockAuthorize{AuthorizeError: errors.New("authz error")}, http.StatusInternalServerError}, {"not authorized because of error", opts, nil, http.MethodGet, "", "/.pomerium/verify/some.domain.name", &cryptutil.MockEncoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", RefreshDeadline: time.Now().Add(10 * time.Second)}}, clients.MockAuthorize{AuthorizeError: errors.New("authz error")}, http.StatusInternalServerError},
{"bad context retrieval error", opts, errors.New("oh no"), http.MethodGet, "false", "/.pomerium/verify/some.domain.name", &cryptutil.MockEncoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", RefreshDeadline: time.Now().Add(10 * time.Second)}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusOK}, {"bad context retrieval error", opts, errors.New("oh no"), http.MethodGet, "", "/.pomerium/verify/some.domain.name", &cryptutil.MockEncoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", RefreshDeadline: time.Now().Add(10 * time.Second)}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusOK},
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
@ -322,8 +322,8 @@ func TestProxy_VerifyWithMiddleware(t *testing.T) {
p.UpdateOptions(tt.options) p.UpdateOptions(tt.options)
uri := &url.URL{Path: tt.path} uri := &url.URL{Path: tt.path}
queryString := uri.Query() queryString := uri.Query()
if tt.qp == "true" { if tt.qp != "" {
queryString.Set("pomerium-auth-callback", tt.qp) queryString.Set(tt.qp, "true")
} }
uri.RawQuery = queryString.Encode() uri.RawQuery = queryString.Encode()

View file

@ -23,7 +23,9 @@ const (
// HeaderGroups is the header key containing the user's groups. // HeaderGroups is the header key containing the user's groups.
HeaderGroups = "x-pomerium-authenticated-user-groups" HeaderGroups = "x-pomerium-authenticated-user-groups"
disableCallback = "pomerium-auth-callback" // HeaderNoAuthRedirect is the header / query param key used to disable
// redirecting unauthenticated request by default but instead return a 401.
HeaderNoAuthRedirect = "x-pomerium-no-auth-redirect"
) )
// AuthenticateSession is middleware to enforce a valid authentication // AuthenticateSession is middleware to enforce a valid authentication
@ -33,8 +35,8 @@ func (p *Proxy) AuthenticateSession(next http.Handler) http.Handler {
ctx, span := trace.StartSpan(r.Context(), "middleware.AuthenticateSession") ctx, span := trace.StartSpan(r.Context(), "middleware.AuthenticateSession")
defer span.End() defer span.End()
s, err := sessions.FromContext(r.Context()) s, err := sessions.FromContext(r.Context())
if err != nil { if err != nil || s == nil {
log.Debug().Str("cause", err.Error()).Msg("proxy: re-authenticating due to session state error") log.Debug().Msg("proxy: re-authenticating due to session state error")
p.reqNeedsAuthentication(w, r) p.reqNeedsAuthentication(w, r)
return return
} }
@ -58,7 +60,7 @@ func (p *Proxy) AuthorizeSession(next http.Handler) http.Handler {
ctx, span := trace.StartSpan(r.Context(), "middleware.AuthorizeSession") ctx, span := trace.StartSpan(r.Context(), "middleware.AuthorizeSession")
defer span.End() defer span.End()
s, err := sessions.FromContext(r.Context()) s, err := sessions.FromContext(r.Context())
if err != nil { if err != nil || s == nil {
httputil.ErrorResponse(w, r.WithContext(ctx), httputil.Error("", http.StatusForbidden, err)) httputil.ErrorResponse(w, r.WithContext(ctx), httputil.Error("", http.StatusForbidden, err))
return return
} }
@ -105,9 +107,11 @@ func (p *Proxy) reqNeedsAuthentication(w http.ResponseWriter, r *http.Request) {
// some proxies like nginx won't follow redirects, and treat any // some proxies like nginx won't follow redirects, and treat any
// non 2xx or 4xx status as an internal service error. // non 2xx or 4xx status as an internal service error.
// https://nginx.org/en/docs/http/ngx_http_auth_request_module.html // https://nginx.org/en/docs/http/ngx_http_auth_request_module.html
if _, ok := r.URL.Query()[disableCallback]; ok { redirectHeader := r.Header.Get(HeaderNoAuthRedirect)
if _, ok := r.URL.Query()[HeaderNoAuthRedirect]; ok || redirectHeader == "true" {
http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized) http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized)
} }
r.Header.Get(HeaderNoAuthRedirect)
uri := urlutil.SignedRedirectURL(p.SharedKey, p.authenticateSigninURL, urlutil.GetAbsoluteURL(r)) uri := urlutil.SignedRedirectURL(p.SharedKey, p.authenticateSigninURL, urlutil.GetAbsoluteURL(r))
http.Redirect(w, r, uri.String(), http.StatusFound) http.Redirect(w, r, uri.String(), http.StatusFound)
} }

View file

@ -78,6 +78,7 @@ type Proxy struct {
refreshCooldown time.Duration refreshCooldown time.Duration
Handler http.Handler Handler http.Handler
sessionStore sessions.SessionStore sessionStore sessions.SessionStore
sessionLoaders []sessions.SessionLoader
signingKey string signingKey string
templates *template.Template templates *template.Template
} }
@ -122,8 +123,12 @@ func New(opts config.Options) (*Proxy, error) {
defaultUpstreamTimeout: opts.DefaultUpstreamTimeout, defaultUpstreamTimeout: opts.DefaultUpstreamTimeout,
refreshCooldown: opts.RefreshCooldown, refreshCooldown: opts.RefreshCooldown,
sessionStore: cookieStore, sessionStore: cookieStore,
signingKey: opts.SigningKey, sessionLoaders: []sessions.SessionLoader{
templates: templates.New(), cookieStore,
sessions.NewHeaderStore(encoder),
sessions.NewQueryParamStore(encoder)},
signingKey: opts.SigningKey,
templates: templates.New(),
} }
// errors checked in ValidateOptions // errors checked in ValidateOptions
p.authorizeURL, _ = urlutil.DeepCopy(opts.AuthorizeURL) p.authorizeURL, _ = urlutil.DeepCopy(opts.AuthorizeURL)
@ -227,7 +232,7 @@ func (p *Proxy) reverseProxyHandler(r *mux.Router, policy *config.Policy) (*mux.
} }
// 4. Retrieve the user session and add it to the request context // 4. Retrieve the user session and add it to the request context
rp.Use(sessions.RetrieveSession(p.sessionStore)) rp.Use(sessions.RetrieveSession(p.sessionLoaders...))
// 5. Strip the user session cookie from the downstream request // 5. Strip the user session cookie from the downstream request
rp.Use(middleware.StripCookie(p.cookieName)) rp.Use(middleware.StripCookie(p.cookieName))
// 6. AuthN - Verify the user is authenticated. Set email, group, & id headers // 6. AuthN - Verify the user is authenticated. Set email, group, & id headers