mirror of
https://github.com/pomerium/pomerium.git
synced 2025-06-06 04:42:56 +02:00
internal/sessions: refactor how sessions loading (#351)
These chagnes standardize how session loading is done for session cookie, auth bearer token, and query params. - Bearer token previously combined with session cookie. - rearranged cookie-store to put exported methods above unexported - added header store that implements session loader interface - added query param store that implements session loader interface Signed-off-by: Bobby DeSimone <bobbydesimone@gmail.com>
This commit is contained in:
parent
7aa4621b1b
commit
badd8d69af
13 changed files with 322 additions and 234 deletions
|
@ -10,47 +10,39 @@ import (
|
|||
"github.com/pomerium/pomerium/internal/cryptutil"
|
||||
)
|
||||
|
||||
// ChunkedCanaryByte is the byte value used as a canary prefix to distinguish if
|
||||
// the cookie is multi-part or not. This constant *should not* be valid
|
||||
// base64. It's important this byte is ASCII to avoid UTF-8 variable sized runes.
|
||||
// https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Set-Cookie#Directives
|
||||
const ChunkedCanaryByte byte = '%'
|
||||
const (
|
||||
// ChunkedCanaryByte is the byte value used as a canary prefix to distinguish if
|
||||
// the cookie is multi-part or not. This constant *should not* be valid
|
||||
// base64. It's important this byte is ASCII to avoid UTF-8 variable sized runes.
|
||||
// https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Set-Cookie#Directives
|
||||
ChunkedCanaryByte byte = '%'
|
||||
// MaxChunkSize sets the upper bound on a cookie chunks payload value.
|
||||
// Note, this should be lower than the actual cookie's max size (4096 bytes)
|
||||
// which includes metadata.
|
||||
MaxChunkSize = 3800
|
||||
// MaxNumChunks limits the number of chunks to iterate through. Conservatively
|
||||
// set to prevent any abuse.
|
||||
MaxNumChunks = 5
|
||||
)
|
||||
|
||||
// DefaultBearerTokenHeader is default header name for the authorization bearer
|
||||
// token header as defined in rfc2617
|
||||
// https://tools.ietf.org/html/rfc6750#section-2.1
|
||||
const DefaultBearerTokenHeader = "Authorization"
|
||||
|
||||
// MaxChunkSize sets the upper bound on a cookie chunks payload value.
|
||||
// Note, this should be lower than the actual cookie's max size (4096 bytes)
|
||||
// which includes metadata.
|
||||
const MaxChunkSize = 3800
|
||||
|
||||
// MaxNumChunks limits the number of chunks to iterate through. Conservatively
|
||||
// set to prevent any abuse.
|
||||
const MaxNumChunks = 5
|
||||
|
||||
// CookieStore represents all the cookie related configurations
|
||||
// CookieStore implements the session store interface for session cookies.
|
||||
type CookieStore struct {
|
||||
Name string
|
||||
Encoder cryptutil.SecureEncoder
|
||||
CookieExpire time.Duration
|
||||
CookieRefresh time.Duration
|
||||
CookieSecure bool
|
||||
CookieHTTPOnly bool
|
||||
CookieDomain string
|
||||
BearerTokenHeader string
|
||||
Name string
|
||||
CookieDomain string
|
||||
CookieExpire time.Duration
|
||||
CookieHTTPOnly bool
|
||||
CookieSecure bool
|
||||
Encoder cryptutil.SecureEncoder
|
||||
}
|
||||
|
||||
// CookieStoreOptions holds options for CookieStore
|
||||
type CookieStoreOptions struct {
|
||||
Name string
|
||||
CookieSecure bool
|
||||
CookieHTTPOnly bool
|
||||
CookieDomain string
|
||||
BearerTokenHeader string
|
||||
CookieExpire time.Duration
|
||||
Encoder cryptutil.SecureEncoder
|
||||
Name string
|
||||
CookieDomain string
|
||||
CookieExpire time.Duration
|
||||
CookieHTTPOnly bool
|
||||
CookieSecure bool
|
||||
Encoder cryptutil.SecureEncoder
|
||||
}
|
||||
|
||||
// NewCookieStore returns a new session with ciphers for each of the cookie secrets
|
||||
|
@ -61,18 +53,14 @@ func NewCookieStore(opts *CookieStoreOptions) (*CookieStore, error) {
|
|||
if opts.Encoder == nil {
|
||||
return nil, fmt.Errorf("internal/sessions: cipher cannot be nil")
|
||||
}
|
||||
if opts.BearerTokenHeader == "" {
|
||||
opts.BearerTokenHeader = DefaultBearerTokenHeader
|
||||
}
|
||||
|
||||
return &CookieStore{
|
||||
Name: opts.Name,
|
||||
CookieSecure: opts.CookieSecure,
|
||||
CookieHTTPOnly: opts.CookieHTTPOnly,
|
||||
CookieDomain: opts.CookieDomain,
|
||||
CookieExpire: opts.CookieExpire,
|
||||
Encoder: opts.Encoder,
|
||||
BearerTokenHeader: opts.BearerTokenHeader,
|
||||
Name: opts.Name,
|
||||
CookieSecure: opts.CookieSecure,
|
||||
CookieHTTPOnly: opts.CookieHTTPOnly,
|
||||
CookieDomain: opts.CookieDomain,
|
||||
CookieExpire: opts.CookieExpire,
|
||||
Encoder: opts.Encoder,
|
||||
}, nil
|
||||
}
|
||||
|
||||
|
@ -103,11 +91,43 @@ func (cs *CookieStore) makeCookie(req *http.Request, name string, value string,
|
|||
return c
|
||||
}
|
||||
|
||||
// ClearSession clears the session cookie from a request
|
||||
func (cs *CookieStore) ClearSession(w http.ResponseWriter, req *http.Request) {
|
||||
http.SetCookie(w, cs.makeCookie(req, cs.Name, "", time.Hour*-1, time.Now()))
|
||||
}
|
||||
|
||||
// LoadSession returns a State from the cookie in the request.
|
||||
func (cs *CookieStore) LoadSession(req *http.Request) (*State, error) {
|
||||
cipherText := loadChunkedCookie(req, cs.Name)
|
||||
if cipherText == "" {
|
||||
return nil, ErrNoSessionFound
|
||||
}
|
||||
session, err := UnmarshalSession(cipherText, cs.Encoder)
|
||||
if err != nil {
|
||||
return nil, ErrMalformed
|
||||
}
|
||||
return session, nil
|
||||
}
|
||||
|
||||
// SaveSession saves a session state to a request sessions.
|
||||
func (cs *CookieStore) SaveSession(w http.ResponseWriter, req *http.Request, s *State) error {
|
||||
value, err := MarshalSession(s, cs.Encoder)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
cs.setSessionCookie(w, req, value)
|
||||
return nil
|
||||
}
|
||||
|
||||
// makeSessionCookie constructs a session cookie given the request, an expiration time and the current time.
|
||||
func (cs *CookieStore) makeSessionCookie(req *http.Request, value string, expiration time.Duration, now time.Time) *http.Cookie {
|
||||
return cs.makeCookie(req, cs.Name, value, expiration, now)
|
||||
}
|
||||
|
||||
func (cs *CookieStore) setSessionCookie(w http.ResponseWriter, req *http.Request, val string) {
|
||||
cs.setCookie(w, cs.makeSessionCookie(req, val, cs.CookieExpire, time.Now()))
|
||||
}
|
||||
|
||||
func (cs *CookieStore) setCookie(w http.ResponseWriter, cookie *http.Cookie) {
|
||||
if len(cookie.String()) <= MaxChunkSize {
|
||||
http.SetCookie(w, cookie)
|
||||
|
@ -128,35 +148,6 @@ func (cs *CookieStore) setCookie(w http.ResponseWriter, cookie *http.Cookie) {
|
|||
}
|
||||
}
|
||||
|
||||
func chunk(s string, size int) []string {
|
||||
ss := make([]string, 0, len(s)/size+1)
|
||||
for len(s) > 0 {
|
||||
if len(s) < size {
|
||||
size = len(s)
|
||||
}
|
||||
ss, s = append(ss, s[:size]), s[size:]
|
||||
}
|
||||
return ss
|
||||
}
|
||||
|
||||
// ClearSession clears the session cookie from a request
|
||||
func (cs *CookieStore) ClearSession(w http.ResponseWriter, req *http.Request) {
|
||||
http.SetCookie(w, cs.makeCookie(req, cs.Name, "", time.Hour*-1, time.Now()))
|
||||
}
|
||||
|
||||
func (cs *CookieStore) setSessionCookie(w http.ResponseWriter, req *http.Request, val string) {
|
||||
cs.setCookie(w, cs.makeSessionCookie(req, val, cs.CookieExpire, time.Now()))
|
||||
}
|
||||
|
||||
func loadBearerToken(r *http.Request, headerKey string) string {
|
||||
authHeader := r.Header.Get(headerKey)
|
||||
split := strings.Split(authHeader, "Bearer")
|
||||
if authHeader == "" || len(split) != 2 {
|
||||
return ""
|
||||
}
|
||||
return strings.TrimSpace(split[1])
|
||||
}
|
||||
|
||||
func loadChunkedCookie(r *http.Request, cookieName string) string {
|
||||
c, err := r.Cookie(cookieName)
|
||||
if err != nil {
|
||||
|
@ -179,37 +170,13 @@ func loadChunkedCookie(r *http.Request, cookieName string) string {
|
|||
return cipherText
|
||||
}
|
||||
|
||||
// LoadSession returns a State from the cookie in the request.
|
||||
func (cs *CookieStore) LoadSession(req *http.Request) (*State, error) {
|
||||
cipherText := loadChunkedCookie(req, cs.Name)
|
||||
if cipherText == "" {
|
||||
cipherText = loadBearerToken(req, cs.BearerTokenHeader)
|
||||
func chunk(s string, size int) []string {
|
||||
ss := make([]string, 0, len(s)/size+1)
|
||||
for len(s) > 0 {
|
||||
if len(s) < size {
|
||||
size = len(s)
|
||||
}
|
||||
ss, s = append(ss, s[:size]), s[size:]
|
||||
}
|
||||
if cipherText == "" {
|
||||
return nil, ErrEmptySession
|
||||
}
|
||||
session, err := UnmarshalSession(cipherText, cs.Encoder)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return session, nil
|
||||
}
|
||||
|
||||
// SaveSession saves a session state to a request sessions.
|
||||
func (cs *CookieStore) SaveSession(w http.ResponseWriter, req *http.Request, s *State) error {
|
||||
value, err := MarshalSession(s, cs.Encoder)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
cs.setSessionCookie(w, req, value)
|
||||
return nil
|
||||
}
|
||||
|
||||
// ParentSubdomain returns the parent subdomain.
|
||||
func ParentSubdomain(s string) string {
|
||||
if strings.Count(s, ".") < 2 {
|
||||
return ""
|
||||
}
|
||||
split := strings.SplitN(s, ".", 2)
|
||||
return split[1]
|
||||
return ss
|
||||
}
|
||||
|
|
|
@ -38,33 +38,30 @@ func TestNewCookieStore(t *testing.T) {
|
|||
}{
|
||||
{"good",
|
||||
&CookieStoreOptions{
|
||||
Name: "_cookie",
|
||||
CookieSecure: true,
|
||||
CookieHTTPOnly: true,
|
||||
CookieDomain: "pomerium.io",
|
||||
CookieExpire: 10 * time.Second,
|
||||
Encoder: encoder,
|
||||
BearerTokenHeader: "Authorization",
|
||||
Name: "_cookie",
|
||||
CookieSecure: true,
|
||||
CookieHTTPOnly: true,
|
||||
CookieDomain: "pomerium.io",
|
||||
CookieExpire: 10 * time.Second,
|
||||
Encoder: encoder,
|
||||
},
|
||||
&CookieStore{
|
||||
Name: "_cookie",
|
||||
CookieSecure: true,
|
||||
CookieHTTPOnly: true,
|
||||
CookieDomain: "pomerium.io",
|
||||
CookieExpire: 10 * time.Second,
|
||||
Encoder: encoder,
|
||||
BearerTokenHeader: "Authorization",
|
||||
Name: "_cookie",
|
||||
CookieSecure: true,
|
||||
CookieHTTPOnly: true,
|
||||
CookieDomain: "pomerium.io",
|
||||
CookieExpire: 10 * time.Second,
|
||||
Encoder: encoder,
|
||||
},
|
||||
false},
|
||||
{"missing name",
|
||||
&CookieStoreOptions{
|
||||
Name: "",
|
||||
CookieSecure: true,
|
||||
CookieHTTPOnly: true,
|
||||
CookieDomain: "pomerium.io",
|
||||
CookieExpire: 10 * time.Second,
|
||||
Encoder: encoder,
|
||||
BearerTokenHeader: "Authorization",
|
||||
Name: "",
|
||||
CookieSecure: true,
|
||||
CookieHTTPOnly: true,
|
||||
CookieDomain: "pomerium.io",
|
||||
CookieExpire: 10 * time.Second,
|
||||
Encoder: encoder,
|
||||
},
|
||||
nil,
|
||||
true},
|
||||
|
@ -250,23 +247,3 @@ func TestMockSessionStore(t *testing.T) {
|
|||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_ParentSubdomain(t *testing.T) {
|
||||
t.Parallel()
|
||||
tests := []struct {
|
||||
s string
|
||||
want string
|
||||
}{
|
||||
{"httpbin.corp.example.com", "corp.example.com"},
|
||||
{"some.httpbin.corp.example.com", "httpbin.corp.example.com"},
|
||||
{"example.com", ""},
|
||||
{"", ""},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.s, func(t *testing.T) {
|
||||
if got := ParentSubdomain(tt.s); got != tt.want {
|
||||
t.Errorf("ParentSubdomain() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
61
internal/sessions/header_store.go
Normal file
61
internal/sessions/header_store.go
Normal file
|
@ -0,0 +1,61 @@
|
|||
package sessions // import "github.com/pomerium/pomerium/internal/sessions"
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/cryptutil"
|
||||
)
|
||||
|
||||
const (
|
||||
// defaultAuthHeader and defaultAuthType are default header name for the
|
||||
// authorization bearer token header as defined in rfc2617
|
||||
// https://tools.ietf.org/html/rfc6750#section-2.1
|
||||
defaultAuthHeader = "Authorization"
|
||||
defaultAuthType = "Bearer"
|
||||
)
|
||||
|
||||
// HeaderStore implements the load session store interface using http
|
||||
// authorization headers.
|
||||
type HeaderStore struct {
|
||||
authHeader string
|
||||
authType string
|
||||
encoder cryptutil.SecureEncoder
|
||||
}
|
||||
|
||||
// NewHeaderStore returns a new header store for loading sessions from
|
||||
// authorization headers.
|
||||
func NewHeaderStore(enc cryptutil.SecureEncoder) *HeaderStore {
|
||||
return &HeaderStore{
|
||||
authHeader: defaultAuthHeader,
|
||||
authType: defaultAuthType,
|
||||
encoder: enc,
|
||||
}
|
||||
}
|
||||
|
||||
// LoadSession tries to retrieve the token string from the Authorization header.
|
||||
//
|
||||
// NOTA BENE: While most servers do not log Authorization headers by default,
|
||||
// you should ensure no other services are logging or leaking your auth headers.
|
||||
func (as *HeaderStore) LoadSession(r *http.Request) (*State, error) {
|
||||
cipherText := as.tokenFromHeader(r)
|
||||
if cipherText == "" {
|
||||
return nil, ErrNoSessionFound
|
||||
}
|
||||
session, err := UnmarshalSession(cipherText, as.encoder)
|
||||
if err != nil {
|
||||
return nil, ErrMalformed
|
||||
}
|
||||
return session, nil
|
||||
|
||||
}
|
||||
|
||||
// retrieve the value of the authorization header
|
||||
func (as *HeaderStore) tokenFromHeader(r *http.Request) string {
|
||||
bearer := r.Header.Get(as.authHeader)
|
||||
atSize := len(as.authType)
|
||||
if len(bearer) > atSize && strings.EqualFold(bearer[0:atSize], as.authType) {
|
||||
return bearer[atSize+1:]
|
||||
}
|
||||
return ""
|
||||
}
|
|
@ -4,7 +4,6 @@ import (
|
|||
"context"
|
||||
"errors"
|
||||
"net/http"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// Context keys
|
||||
|
@ -13,65 +12,56 @@ var (
|
|||
ErrorCtxKey = &contextKey{"Error"}
|
||||
)
|
||||
|
||||
// Library errors
|
||||
var (
|
||||
ErrExpired = errors.New("internal/sessions: session is expired")
|
||||
ErrNoSessionFound = errors.New("internal/sessions: session is not found")
|
||||
ErrMalformed = errors.New("internal/sessions: session is malformed")
|
||||
)
|
||||
|
||||
// RetrieveSession http middleware handler will verify a auth session from a http request.
|
||||
//
|
||||
// RetrieveSession will search for a auth session in a http request, in the order:
|
||||
// 1. `pomerium_session` URI query parameter
|
||||
// 2. `Authorization: BEARER` request header
|
||||
// 3. Cookie `_pomerium` value
|
||||
func RetrieveSession(s SessionStore) func(http.Handler) http.Handler {
|
||||
func RetrieveSession(s ...SessionLoader) func(http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return retrieve(s, TokenFromQuery, TokenFromHeader, TokenFromCookie)(next)
|
||||
return retrieve(s...)(next)
|
||||
}
|
||||
}
|
||||
|
||||
func retrieve(s SessionStore, findTokenFns ...func(r *http.Request) string) func(http.Handler) http.Handler {
|
||||
func retrieve(s ...SessionLoader) func(http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
hfn := func(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
token, err := retrieveFromRequest(s, r, findTokenFns...)
|
||||
ctx = NewContext(ctx, token, err)
|
||||
state, err := retrieveFromRequest(r, s...)
|
||||
ctx = NewContext(ctx, state, err)
|
||||
next.ServeHTTP(w, r.WithContext(ctx))
|
||||
}
|
||||
return http.HandlerFunc(hfn)
|
||||
}
|
||||
}
|
||||
|
||||
func retrieveFromRequest(s SessionStore, r *http.Request, findTokenFns ...func(r *http.Request) string) (*State, error) {
|
||||
var tokenStr string
|
||||
func retrieveFromRequest(r *http.Request, sessions ...SessionLoader) (*State, error) {
|
||||
state := new(State)
|
||||
var err error
|
||||
|
||||
// Extract token string from the request by calling token find functions in
|
||||
// Extract sessions state from the request by calling token find functions in
|
||||
// the order they where provided. Further extraction stops if a function
|
||||
// returns a non-empty string.
|
||||
for _, fn := range findTokenFns {
|
||||
tokenStr = fn(r)
|
||||
if tokenStr != "" {
|
||||
for _, s := range sessions {
|
||||
state, err = s.LoadSession(r)
|
||||
if err != nil && !errors.Is(err, ErrNoSessionFound) {
|
||||
// unexpected error
|
||||
return nil, err
|
||||
}
|
||||
// break, we found a session state
|
||||
if state != nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
if tokenStr == "" {
|
||||
// no session found if state is still empty
|
||||
if state == nil {
|
||||
return nil, ErrNoSessionFound
|
||||
}
|
||||
|
||||
state, err := s.LoadSession(r)
|
||||
if err != nil {
|
||||
return nil, ErrMalformed
|
||||
}
|
||||
err = state.Valid()
|
||||
if err != nil {
|
||||
if err = state.Valid(); err != nil {
|
||||
// a little unusual but we want to return the expired state too
|
||||
return state, err
|
||||
}
|
||||
|
||||
// Valid!
|
||||
return state, nil
|
||||
}
|
||||
|
||||
|
@ -89,35 +79,6 @@ func FromContext(ctx context.Context) (*State, error) {
|
|||
return state, err
|
||||
}
|
||||
|
||||
// TokenFromCookie tries to retrieve the token string from a cookie named
|
||||
// "_pomerium".
|
||||
func TokenFromCookie(r *http.Request) string {
|
||||
cookie, err := r.Cookie("_pomerium")
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
return cookie.Value
|
||||
}
|
||||
|
||||
// TokenFromHeader tries to retrieve the token string from the
|
||||
// "Authorization" request header: "Authorization: BEARER T".
|
||||
func TokenFromHeader(r *http.Request) string {
|
||||
// Get token from authorization header.
|
||||
bearer := r.Header.Get("Authorization")
|
||||
if len(bearer) > 7 && strings.EqualFold(bearer[0:6], "BEARER") {
|
||||
return bearer[7:]
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// TokenFromQuery tries to retrieve the token string from the "pomerium_session" URI
|
||||
// query parameter.
|
||||
// todo(bdd) : document setting session code as queryparam
|
||||
func TokenFromQuery(r *http.Request) string {
|
||||
// Get token from query param named "pomerium_session".
|
||||
return r.URL.Query().Get("pomerium_session")
|
||||
}
|
||||
|
||||
// contextKey is a value for use with context.WithValue. It's used as
|
||||
// a pointer so it fits in an interface{} without allocation. This technique
|
||||
// for defining context keys was copied from Go 1.7's new use of context in net/http.
|
||||
|
@ -126,5 +87,5 @@ type contextKey struct {
|
|||
}
|
||||
|
||||
func (k *contextKey) String() string {
|
||||
return "SessionStore context value " + k.name
|
||||
return "context value " + k.name
|
||||
}
|
||||
|
|
|
@ -9,9 +9,8 @@ import (
|
|||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/cryptutil"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/pomerium/pomerium/internal/cryptutil"
|
||||
)
|
||||
|
||||
func TestNewContext(t *testing.T) {
|
||||
|
@ -75,8 +74,8 @@ func TestVerifier(t *testing.T) {
|
|||
{"expired auth header", State{Email: "user@pomerium.io", RefreshDeadline: time.Now().Add(-10 * time.Second)}, false, true, false, "internal/sessions: session is expired\n", http.StatusUnauthorized},
|
||||
{"malformed auth header", State{Email: "user@pomerium.io", RefreshDeadline: time.Now().Add(-10 * time.Second)}, false, true, false, "internal/sessions: session is malformed\n", http.StatusUnauthorized},
|
||||
{"good auth query param session", State{Email: "user@pomerium.io", RefreshDeadline: time.Now().Add(10 * time.Second)}, false, true, true, http.StatusText(http.StatusOK), http.StatusOK},
|
||||
{"expired auth query param", State{Email: "user@pomerium.io", RefreshDeadline: time.Now().Add(-10 * time.Second)}, false, true, true, "internal/sessions: session is expired\n", http.StatusUnauthorized},
|
||||
{"malformed auth query param", State{Email: "user@pomerium.io", RefreshDeadline: time.Now().Add(-10 * time.Second)}, false, true, true, "internal/sessions: session is malformed\n", http.StatusUnauthorized},
|
||||
{"expired auth query param", State{Email: "user@pomerium.io", RefreshDeadline: time.Now().Add(-10 * time.Second)}, false, false, true, "internal/sessions: session is expired\n", http.StatusUnauthorized},
|
||||
{"malformed auth query param", State{Email: "user@pomerium.io", RefreshDeadline: time.Now().Add(-10 * time.Second)}, false, false, true, "internal/sessions: session is malformed\n", http.StatusUnauthorized},
|
||||
{"no session", State{Email: "user@pomerium.io", RefreshDeadline: time.Now().Add(-10 * time.Second)}, false, false, false, "internal/sessions: session is not found\n", http.StatusUnauthorized},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
|
@ -94,7 +93,6 @@ func TestVerifier(t *testing.T) {
|
|||
if strings.Contains(tt.name, "malformed") {
|
||||
// add some garbage to the end of the string
|
||||
encSession += cryptutil.NewBase64Key()
|
||||
fmt.Println(encSession)
|
||||
}
|
||||
|
||||
cs, err := NewCookieStore(&CookieStoreOptions{
|
||||
|
@ -104,6 +102,9 @@ func TestVerifier(t *testing.T) {
|
|||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
as := NewHeaderStore(encoder)
|
||||
|
||||
qp := NewQueryParamStore(encoder)
|
||||
|
||||
r := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
r.Header.Set("Accept", "application/json")
|
||||
|
@ -114,11 +115,12 @@ func TestVerifier(t *testing.T) {
|
|||
r.Header.Set("Authorization", "Bearer "+encSession)
|
||||
} else if tt.param {
|
||||
q := r.URL.Query()
|
||||
q.Add("pomerium_session", encSession)
|
||||
|
||||
q.Set("pomerium_session", encSession)
|
||||
r.URL.RawQuery = q.Encode()
|
||||
}
|
||||
|
||||
got := RetrieveSession(cs)(testAuthorizer((fnh)))
|
||||
got := RetrieveSession(cs, as, qp)(testAuthorizer((fnh)))
|
||||
got.ServeHTTP(w, r)
|
||||
|
||||
gotBody := w.Body.String()
|
||||
|
@ -133,3 +135,23 @@ func TestVerifier(t *testing.T) {
|
|||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_contextKey_String(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
keyName string
|
||||
want string
|
||||
}{
|
||||
{"simple example", "test", "context value test"},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
k := &contextKey{
|
||||
name: tt.keyName,
|
||||
}
|
||||
if got := k.String(); got != tt.want {
|
||||
t.Errorf("contextKey.String() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
44
internal/sessions/query_store.go
Normal file
44
internal/sessions/query_store.go
Normal file
|
@ -0,0 +1,44 @@
|
|||
package sessions // import "github.com/pomerium/pomerium/internal/sessions"
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/cryptutil"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultQueryParamKey = "pomerium_session"
|
||||
)
|
||||
|
||||
// QueryParamStore implements the load session store interface using http
|
||||
// query strings / query parameters.
|
||||
type QueryParamStore struct {
|
||||
queryParamKey string
|
||||
encoder cryptutil.SecureEncoder
|
||||
}
|
||||
|
||||
// NewQueryParamStore returns a new query param store for loading sessions from
|
||||
// query strings / query parameters.
|
||||
func NewQueryParamStore(enc cryptutil.SecureEncoder) *QueryParamStore {
|
||||
return &QueryParamStore{
|
||||
queryParamKey: defaultQueryParamKey,
|
||||
encoder: enc,
|
||||
}
|
||||
}
|
||||
|
||||
// LoadSession tries to retrieve the token string from URL query parameters.
|
||||
//
|
||||
// NOTA BENE: By default, most servers _DO_ log query params, the leaking or
|
||||
// accidental logging of which should be considered a security issue.
|
||||
func (qp *QueryParamStore) LoadSession(r *http.Request) (*State, error) {
|
||||
cipherText := r.URL.Query().Get(qp.queryParamKey)
|
||||
if cipherText == "" {
|
||||
return nil, ErrNoSessionFound
|
||||
}
|
||||
session, err := UnmarshalSession(cipherText, qp.encoder)
|
||||
if err != nil {
|
||||
return nil, ErrMalformed
|
||||
}
|
||||
return session, nil
|
||||
|
||||
}
|
|
@ -5,8 +5,14 @@ import (
|
|||
"net/http"
|
||||
)
|
||||
|
||||
// ErrEmptySession is an error for an empty sessions.
|
||||
var ErrEmptySession = errors.New("internal/sessions: empty session")
|
||||
var (
|
||||
// ErrExpired is the error for an expired session.
|
||||
ErrExpired = errors.New("internal/sessions: session is expired")
|
||||
// ErrNoSessionFound is the error for when no session is found.
|
||||
ErrNoSessionFound = errors.New("internal/sessions: session is not found")
|
||||
// ErrMalformed is the error for when a session is found but is malformed.
|
||||
ErrMalformed = errors.New("internal/sessions: session is malformed")
|
||||
)
|
||||
|
||||
// SessionStore has the functions for setting, getting, and clearing the Session cookie
|
||||
type SessionStore interface {
|
||||
|
@ -14,3 +20,9 @@ type SessionStore interface {
|
|||
LoadSession(*http.Request) (*State, error)
|
||||
SaveSession(http.ResponseWriter, *http.Request, *State) error
|
||||
}
|
||||
|
||||
// SessionLoader is implemented by any struct that loads a pomerium session
|
||||
// given a request, and returns a user state.
|
||||
type SessionLoader interface {
|
||||
LoadSession(*http.Request) (*State, error)
|
||||
}
|
||||
|
|
12
internal/sessions/util.go
Normal file
12
internal/sessions/util.go
Normal file
|
@ -0,0 +1,12 @@
|
|||
package sessions // import "github.com/pomerium/pomerium/internal/sessions"
|
||||
|
||||
import "strings"
|
||||
|
||||
// ParentSubdomain returns the parent subdomain.
|
||||
func ParentSubdomain(s string) string {
|
||||
if strings.Count(s, ".") < 2 {
|
||||
return ""
|
||||
}
|
||||
split := strings.SplitN(s, ".", 2)
|
||||
return split[1]
|
||||
}
|
23
internal/sessions/util_test.go
Normal file
23
internal/sessions/util_test.go
Normal file
|
@ -0,0 +1,23 @@
|
|||
package sessions
|
||||
|
||||
import "testing"
|
||||
|
||||
func Test_ParentSubdomain(t *testing.T) {
|
||||
t.Parallel()
|
||||
tests := []struct {
|
||||
s string
|
||||
want string
|
||||
}{
|
||||
{"httpbin.corp.example.com", "corp.example.com"},
|
||||
{"some.httpbin.corp.example.com", "httpbin.corp.example.com"},
|
||||
{"example.com", ""},
|
||||
{"", ""},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.s, func(t *testing.T) {
|
||||
if got := ParentSubdomain(tt.s); got != tt.want {
|
||||
t.Errorf("ParentSubdomain() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
|
@ -184,8 +184,8 @@ func (p *Proxy) Verify(w http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
// check the queryparams to see if this check immediately followed
|
||||
// authentication. If so, redirect back to the originally requested hostname.
|
||||
if isCallback := r.URL.Query().Get(disableCallback); isCallback == "true" {
|
||||
http.Redirect(w, r, "http://"+hostname, http.StatusFound)
|
||||
if isCallback := r.URL.Query().Get("pomerium-auth-callback"); isCallback == "true" {
|
||||
http.Redirect(w, r, hostname, http.StatusFound)
|
||||
return
|
||||
}
|
||||
|
||||
|
|
|
@ -302,13 +302,13 @@ func TestProxy_VerifyWithMiddleware(t *testing.T) {
|
|||
authorizer clients.Authorizer
|
||||
wantStatus int
|
||||
}{
|
||||
{"good", opts, nil, http.MethodGet, "false", "/.pomerium/verify/some.domain.name", &cryptutil.MockEncoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", RefreshDeadline: time.Now().Add(10 * time.Second)}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusOK},
|
||||
{"good post auth redirect", opts, nil, http.MethodGet, "true", "/.pomerium/verify/some.domain.name", &cryptutil.MockEncoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", RefreshDeadline: time.Now().Add(10 * time.Second)}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusFound},
|
||||
{"not authorized", opts, nil, http.MethodGet, "false", "/.pomerium/verify/some.domain.name", &cryptutil.MockEncoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", RefreshDeadline: time.Now().Add(10 * time.Second)}}, clients.MockAuthorize{AuthorizeResponse: false}, http.StatusForbidden},
|
||||
{"not authorized expired, redirect to auth", opts, nil, http.MethodGet, "false", "/.pomerium/verify/some.domain.name", &cryptutil.MockEncoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", RefreshDeadline: time.Now().Add(-10 * time.Second)}}, clients.MockAuthorize{AuthorizeResponse: false}, http.StatusFound},
|
||||
{"not authorized expired, don't redirect!", opts, nil, http.MethodGet, "true", "/.pomerium/verify/some.domain.name?no_redirect=true", &cryptutil.MockEncoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", RefreshDeadline: time.Now().Add(-10 * time.Second)}}, clients.MockAuthorize{AuthorizeResponse: false}, http.StatusUnauthorized},
|
||||
{"not authorized because of error", opts, nil, http.MethodGet, "false", "/.pomerium/verify/some.domain.name", &cryptutil.MockEncoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", RefreshDeadline: time.Now().Add(10 * time.Second)}}, clients.MockAuthorize{AuthorizeError: errors.New("authz error")}, http.StatusInternalServerError},
|
||||
{"bad context retrieval error", opts, errors.New("oh no"), http.MethodGet, "false", "/.pomerium/verify/some.domain.name", &cryptutil.MockEncoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", RefreshDeadline: time.Now().Add(10 * time.Second)}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusOK},
|
||||
{"good", opts, nil, http.MethodGet, "", "/.pomerium/verify/some.domain.name", &cryptutil.MockEncoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", RefreshDeadline: time.Now().Add(10 * time.Second)}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusOK},
|
||||
{"good post auth redirect", opts, nil, http.MethodGet, "pomerium-auth-callback", "/.pomerium/verify/some.domain.name", &cryptutil.MockEncoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", RefreshDeadline: time.Now().Add(10 * time.Second)}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusFound},
|
||||
{"not authorized", opts, nil, http.MethodGet, "", "/.pomerium/verify/some.domain.name", &cryptutil.MockEncoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", RefreshDeadline: time.Now().Add(10 * time.Second)}}, clients.MockAuthorize{AuthorizeResponse: false}, http.StatusForbidden},
|
||||
{"not authorized expired, redirect to auth", opts, nil, http.MethodGet, "", "/.pomerium/verify/some.domain.name", &cryptutil.MockEncoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", RefreshDeadline: time.Now().Add(-10 * time.Second)}}, clients.MockAuthorize{AuthorizeResponse: false}, http.StatusFound},
|
||||
{"not authorized expired, don't redirect!", opts, nil, http.MethodGet, HeaderNoAuthRedirect, "/.pomerium/verify/some.domain.name?no_redirect=true", &cryptutil.MockEncoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", RefreshDeadline: time.Now().Add(-10 * time.Second)}}, clients.MockAuthorize{AuthorizeResponse: false}, http.StatusUnauthorized},
|
||||
{"not authorized because of error", opts, nil, http.MethodGet, "", "/.pomerium/verify/some.domain.name", &cryptutil.MockEncoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", RefreshDeadline: time.Now().Add(10 * time.Second)}}, clients.MockAuthorize{AuthorizeError: errors.New("authz error")}, http.StatusInternalServerError},
|
||||
{"bad context retrieval error", opts, errors.New("oh no"), http.MethodGet, "", "/.pomerium/verify/some.domain.name", &cryptutil.MockEncoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", RefreshDeadline: time.Now().Add(10 * time.Second)}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusOK},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
|
@ -322,8 +322,8 @@ func TestProxy_VerifyWithMiddleware(t *testing.T) {
|
|||
p.UpdateOptions(tt.options)
|
||||
uri := &url.URL{Path: tt.path}
|
||||
queryString := uri.Query()
|
||||
if tt.qp == "true" {
|
||||
queryString.Set("pomerium-auth-callback", tt.qp)
|
||||
if tt.qp != "" {
|
||||
queryString.Set(tt.qp, "true")
|
||||
}
|
||||
uri.RawQuery = queryString.Encode()
|
||||
|
||||
|
|
|
@ -23,7 +23,9 @@ const (
|
|||
// HeaderGroups is the header key containing the user's groups.
|
||||
HeaderGroups = "x-pomerium-authenticated-user-groups"
|
||||
|
||||
disableCallback = "pomerium-auth-callback"
|
||||
// HeaderNoAuthRedirect is the header / query param key used to disable
|
||||
// redirecting unauthenticated request by default but instead return a 401.
|
||||
HeaderNoAuthRedirect = "x-pomerium-no-auth-redirect"
|
||||
)
|
||||
|
||||
// AuthenticateSession is middleware to enforce a valid authentication
|
||||
|
@ -33,8 +35,8 @@ func (p *Proxy) AuthenticateSession(next http.Handler) http.Handler {
|
|||
ctx, span := trace.StartSpan(r.Context(), "middleware.AuthenticateSession")
|
||||
defer span.End()
|
||||
s, err := sessions.FromContext(r.Context())
|
||||
if err != nil {
|
||||
log.Debug().Str("cause", err.Error()).Msg("proxy: re-authenticating due to session state error")
|
||||
if err != nil || s == nil {
|
||||
log.Debug().Msg("proxy: re-authenticating due to session state error")
|
||||
p.reqNeedsAuthentication(w, r)
|
||||
return
|
||||
}
|
||||
|
@ -58,7 +60,7 @@ func (p *Proxy) AuthorizeSession(next http.Handler) http.Handler {
|
|||
ctx, span := trace.StartSpan(r.Context(), "middleware.AuthorizeSession")
|
||||
defer span.End()
|
||||
s, err := sessions.FromContext(r.Context())
|
||||
if err != nil {
|
||||
if err != nil || s == nil {
|
||||
httputil.ErrorResponse(w, r.WithContext(ctx), httputil.Error("", http.StatusForbidden, err))
|
||||
return
|
||||
}
|
||||
|
@ -105,9 +107,11 @@ func (p *Proxy) reqNeedsAuthentication(w http.ResponseWriter, r *http.Request) {
|
|||
// some proxies like nginx won't follow redirects, and treat any
|
||||
// non 2xx or 4xx status as an internal service error.
|
||||
// https://nginx.org/en/docs/http/ngx_http_auth_request_module.html
|
||||
if _, ok := r.URL.Query()[disableCallback]; ok {
|
||||
redirectHeader := r.Header.Get(HeaderNoAuthRedirect)
|
||||
if _, ok := r.URL.Query()[HeaderNoAuthRedirect]; ok || redirectHeader == "true" {
|
||||
http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized)
|
||||
}
|
||||
r.Header.Get(HeaderNoAuthRedirect)
|
||||
uri := urlutil.SignedRedirectURL(p.SharedKey, p.authenticateSigninURL, urlutil.GetAbsoluteURL(r))
|
||||
http.Redirect(w, r, uri.String(), http.StatusFound)
|
||||
}
|
||||
|
|
|
@ -78,6 +78,7 @@ type Proxy struct {
|
|||
refreshCooldown time.Duration
|
||||
Handler http.Handler
|
||||
sessionStore sessions.SessionStore
|
||||
sessionLoaders []sessions.SessionLoader
|
||||
signingKey string
|
||||
templates *template.Template
|
||||
}
|
||||
|
@ -122,8 +123,12 @@ func New(opts config.Options) (*Proxy, error) {
|
|||
defaultUpstreamTimeout: opts.DefaultUpstreamTimeout,
|
||||
refreshCooldown: opts.RefreshCooldown,
|
||||
sessionStore: cookieStore,
|
||||
signingKey: opts.SigningKey,
|
||||
templates: templates.New(),
|
||||
sessionLoaders: []sessions.SessionLoader{
|
||||
cookieStore,
|
||||
sessions.NewHeaderStore(encoder),
|
||||
sessions.NewQueryParamStore(encoder)},
|
||||
signingKey: opts.SigningKey,
|
||||
templates: templates.New(),
|
||||
}
|
||||
// errors checked in ValidateOptions
|
||||
p.authorizeURL, _ = urlutil.DeepCopy(opts.AuthorizeURL)
|
||||
|
@ -227,7 +232,7 @@ func (p *Proxy) reverseProxyHandler(r *mux.Router, policy *config.Policy) (*mux.
|
|||
}
|
||||
|
||||
// 4. Retrieve the user session and add it to the request context
|
||||
rp.Use(sessions.RetrieveSession(p.sessionStore))
|
||||
rp.Use(sessions.RetrieveSession(p.sessionLoaders...))
|
||||
// 5. Strip the user session cookie from the downstream request
|
||||
rp.Use(middleware.StripCookie(p.cookieName))
|
||||
// 6. AuthN - Verify the user is authenticated. Set email, group, & id headers
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue