mirror of
https://github.com/pomerium/pomerium.git
synced 2025-06-06 04:42:56 +02:00
internal/sessions: refactor how sessions loading (#351)
These chagnes standardize how session loading is done for session cookie, auth bearer token, and query params. - Bearer token previously combined with session cookie. - rearranged cookie-store to put exported methods above unexported - added header store that implements session loader interface - added query param store that implements session loader interface Signed-off-by: Bobby DeSimone <bobbydesimone@gmail.com>
This commit is contained in:
parent
7aa4621b1b
commit
badd8d69af
13 changed files with 322 additions and 234 deletions
|
@ -10,47 +10,39 @@ import (
|
||||||
"github.com/pomerium/pomerium/internal/cryptutil"
|
"github.com/pomerium/pomerium/internal/cryptutil"
|
||||||
)
|
)
|
||||||
|
|
||||||
// ChunkedCanaryByte is the byte value used as a canary prefix to distinguish if
|
const (
|
||||||
// the cookie is multi-part or not. This constant *should not* be valid
|
// ChunkedCanaryByte is the byte value used as a canary prefix to distinguish if
|
||||||
// base64. It's important this byte is ASCII to avoid UTF-8 variable sized runes.
|
// the cookie is multi-part or not. This constant *should not* be valid
|
||||||
// https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Set-Cookie#Directives
|
// base64. It's important this byte is ASCII to avoid UTF-8 variable sized runes.
|
||||||
const ChunkedCanaryByte byte = '%'
|
// https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Set-Cookie#Directives
|
||||||
|
ChunkedCanaryByte byte = '%'
|
||||||
|
// MaxChunkSize sets the upper bound on a cookie chunks payload value.
|
||||||
|
// Note, this should be lower than the actual cookie's max size (4096 bytes)
|
||||||
|
// which includes metadata.
|
||||||
|
MaxChunkSize = 3800
|
||||||
|
// MaxNumChunks limits the number of chunks to iterate through. Conservatively
|
||||||
|
// set to prevent any abuse.
|
||||||
|
MaxNumChunks = 5
|
||||||
|
)
|
||||||
|
|
||||||
// DefaultBearerTokenHeader is default header name for the authorization bearer
|
// CookieStore implements the session store interface for session cookies.
|
||||||
// token header as defined in rfc2617
|
|
||||||
// https://tools.ietf.org/html/rfc6750#section-2.1
|
|
||||||
const DefaultBearerTokenHeader = "Authorization"
|
|
||||||
|
|
||||||
// MaxChunkSize sets the upper bound on a cookie chunks payload value.
|
|
||||||
// Note, this should be lower than the actual cookie's max size (4096 bytes)
|
|
||||||
// which includes metadata.
|
|
||||||
const MaxChunkSize = 3800
|
|
||||||
|
|
||||||
// MaxNumChunks limits the number of chunks to iterate through. Conservatively
|
|
||||||
// set to prevent any abuse.
|
|
||||||
const MaxNumChunks = 5
|
|
||||||
|
|
||||||
// CookieStore represents all the cookie related configurations
|
|
||||||
type CookieStore struct {
|
type CookieStore struct {
|
||||||
Name string
|
Name string
|
||||||
Encoder cryptutil.SecureEncoder
|
CookieDomain string
|
||||||
CookieExpire time.Duration
|
CookieExpire time.Duration
|
||||||
CookieRefresh time.Duration
|
CookieHTTPOnly bool
|
||||||
CookieSecure bool
|
CookieSecure bool
|
||||||
CookieHTTPOnly bool
|
Encoder cryptutil.SecureEncoder
|
||||||
CookieDomain string
|
|
||||||
BearerTokenHeader string
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// CookieStoreOptions holds options for CookieStore
|
// CookieStoreOptions holds options for CookieStore
|
||||||
type CookieStoreOptions struct {
|
type CookieStoreOptions struct {
|
||||||
Name string
|
Name string
|
||||||
CookieSecure bool
|
CookieDomain string
|
||||||
CookieHTTPOnly bool
|
CookieExpire time.Duration
|
||||||
CookieDomain string
|
CookieHTTPOnly bool
|
||||||
BearerTokenHeader string
|
CookieSecure bool
|
||||||
CookieExpire time.Duration
|
Encoder cryptutil.SecureEncoder
|
||||||
Encoder cryptutil.SecureEncoder
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewCookieStore returns a new session with ciphers for each of the cookie secrets
|
// NewCookieStore returns a new session with ciphers for each of the cookie secrets
|
||||||
|
@ -61,18 +53,14 @@ func NewCookieStore(opts *CookieStoreOptions) (*CookieStore, error) {
|
||||||
if opts.Encoder == nil {
|
if opts.Encoder == nil {
|
||||||
return nil, fmt.Errorf("internal/sessions: cipher cannot be nil")
|
return nil, fmt.Errorf("internal/sessions: cipher cannot be nil")
|
||||||
}
|
}
|
||||||
if opts.BearerTokenHeader == "" {
|
|
||||||
opts.BearerTokenHeader = DefaultBearerTokenHeader
|
|
||||||
}
|
|
||||||
|
|
||||||
return &CookieStore{
|
return &CookieStore{
|
||||||
Name: opts.Name,
|
Name: opts.Name,
|
||||||
CookieSecure: opts.CookieSecure,
|
CookieSecure: opts.CookieSecure,
|
||||||
CookieHTTPOnly: opts.CookieHTTPOnly,
|
CookieHTTPOnly: opts.CookieHTTPOnly,
|
||||||
CookieDomain: opts.CookieDomain,
|
CookieDomain: opts.CookieDomain,
|
||||||
CookieExpire: opts.CookieExpire,
|
CookieExpire: opts.CookieExpire,
|
||||||
Encoder: opts.Encoder,
|
Encoder: opts.Encoder,
|
||||||
BearerTokenHeader: opts.BearerTokenHeader,
|
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -103,11 +91,43 @@ func (cs *CookieStore) makeCookie(req *http.Request, name string, value string,
|
||||||
return c
|
return c
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ClearSession clears the session cookie from a request
|
||||||
|
func (cs *CookieStore) ClearSession(w http.ResponseWriter, req *http.Request) {
|
||||||
|
http.SetCookie(w, cs.makeCookie(req, cs.Name, "", time.Hour*-1, time.Now()))
|
||||||
|
}
|
||||||
|
|
||||||
|
// LoadSession returns a State from the cookie in the request.
|
||||||
|
func (cs *CookieStore) LoadSession(req *http.Request) (*State, error) {
|
||||||
|
cipherText := loadChunkedCookie(req, cs.Name)
|
||||||
|
if cipherText == "" {
|
||||||
|
return nil, ErrNoSessionFound
|
||||||
|
}
|
||||||
|
session, err := UnmarshalSession(cipherText, cs.Encoder)
|
||||||
|
if err != nil {
|
||||||
|
return nil, ErrMalformed
|
||||||
|
}
|
||||||
|
return session, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// SaveSession saves a session state to a request sessions.
|
||||||
|
func (cs *CookieStore) SaveSession(w http.ResponseWriter, req *http.Request, s *State) error {
|
||||||
|
value, err := MarshalSession(s, cs.Encoder)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
cs.setSessionCookie(w, req, value)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// makeSessionCookie constructs a session cookie given the request, an expiration time and the current time.
|
// makeSessionCookie constructs a session cookie given the request, an expiration time and the current time.
|
||||||
func (cs *CookieStore) makeSessionCookie(req *http.Request, value string, expiration time.Duration, now time.Time) *http.Cookie {
|
func (cs *CookieStore) makeSessionCookie(req *http.Request, value string, expiration time.Duration, now time.Time) *http.Cookie {
|
||||||
return cs.makeCookie(req, cs.Name, value, expiration, now)
|
return cs.makeCookie(req, cs.Name, value, expiration, now)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (cs *CookieStore) setSessionCookie(w http.ResponseWriter, req *http.Request, val string) {
|
||||||
|
cs.setCookie(w, cs.makeSessionCookie(req, val, cs.CookieExpire, time.Now()))
|
||||||
|
}
|
||||||
|
|
||||||
func (cs *CookieStore) setCookie(w http.ResponseWriter, cookie *http.Cookie) {
|
func (cs *CookieStore) setCookie(w http.ResponseWriter, cookie *http.Cookie) {
|
||||||
if len(cookie.String()) <= MaxChunkSize {
|
if len(cookie.String()) <= MaxChunkSize {
|
||||||
http.SetCookie(w, cookie)
|
http.SetCookie(w, cookie)
|
||||||
|
@ -128,35 +148,6 @@ func (cs *CookieStore) setCookie(w http.ResponseWriter, cookie *http.Cookie) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func chunk(s string, size int) []string {
|
|
||||||
ss := make([]string, 0, len(s)/size+1)
|
|
||||||
for len(s) > 0 {
|
|
||||||
if len(s) < size {
|
|
||||||
size = len(s)
|
|
||||||
}
|
|
||||||
ss, s = append(ss, s[:size]), s[size:]
|
|
||||||
}
|
|
||||||
return ss
|
|
||||||
}
|
|
||||||
|
|
||||||
// ClearSession clears the session cookie from a request
|
|
||||||
func (cs *CookieStore) ClearSession(w http.ResponseWriter, req *http.Request) {
|
|
||||||
http.SetCookie(w, cs.makeCookie(req, cs.Name, "", time.Hour*-1, time.Now()))
|
|
||||||
}
|
|
||||||
|
|
||||||
func (cs *CookieStore) setSessionCookie(w http.ResponseWriter, req *http.Request, val string) {
|
|
||||||
cs.setCookie(w, cs.makeSessionCookie(req, val, cs.CookieExpire, time.Now()))
|
|
||||||
}
|
|
||||||
|
|
||||||
func loadBearerToken(r *http.Request, headerKey string) string {
|
|
||||||
authHeader := r.Header.Get(headerKey)
|
|
||||||
split := strings.Split(authHeader, "Bearer")
|
|
||||||
if authHeader == "" || len(split) != 2 {
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
return strings.TrimSpace(split[1])
|
|
||||||
}
|
|
||||||
|
|
||||||
func loadChunkedCookie(r *http.Request, cookieName string) string {
|
func loadChunkedCookie(r *http.Request, cookieName string) string {
|
||||||
c, err := r.Cookie(cookieName)
|
c, err := r.Cookie(cookieName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -179,37 +170,13 @@ func loadChunkedCookie(r *http.Request, cookieName string) string {
|
||||||
return cipherText
|
return cipherText
|
||||||
}
|
}
|
||||||
|
|
||||||
// LoadSession returns a State from the cookie in the request.
|
func chunk(s string, size int) []string {
|
||||||
func (cs *CookieStore) LoadSession(req *http.Request) (*State, error) {
|
ss := make([]string, 0, len(s)/size+1)
|
||||||
cipherText := loadChunkedCookie(req, cs.Name)
|
for len(s) > 0 {
|
||||||
if cipherText == "" {
|
if len(s) < size {
|
||||||
cipherText = loadBearerToken(req, cs.BearerTokenHeader)
|
size = len(s)
|
||||||
|
}
|
||||||
|
ss, s = append(ss, s[:size]), s[size:]
|
||||||
}
|
}
|
||||||
if cipherText == "" {
|
return ss
|
||||||
return nil, ErrEmptySession
|
|
||||||
}
|
|
||||||
session, err := UnmarshalSession(cipherText, cs.Encoder)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return session, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// SaveSession saves a session state to a request sessions.
|
|
||||||
func (cs *CookieStore) SaveSession(w http.ResponseWriter, req *http.Request, s *State) error {
|
|
||||||
value, err := MarshalSession(s, cs.Encoder)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
cs.setSessionCookie(w, req, value)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// ParentSubdomain returns the parent subdomain.
|
|
||||||
func ParentSubdomain(s string) string {
|
|
||||||
if strings.Count(s, ".") < 2 {
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
split := strings.SplitN(s, ".", 2)
|
|
||||||
return split[1]
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -38,33 +38,30 @@ func TestNewCookieStore(t *testing.T) {
|
||||||
}{
|
}{
|
||||||
{"good",
|
{"good",
|
||||||
&CookieStoreOptions{
|
&CookieStoreOptions{
|
||||||
Name: "_cookie",
|
Name: "_cookie",
|
||||||
CookieSecure: true,
|
CookieSecure: true,
|
||||||
CookieHTTPOnly: true,
|
CookieHTTPOnly: true,
|
||||||
CookieDomain: "pomerium.io",
|
CookieDomain: "pomerium.io",
|
||||||
CookieExpire: 10 * time.Second,
|
CookieExpire: 10 * time.Second,
|
||||||
Encoder: encoder,
|
Encoder: encoder,
|
||||||
BearerTokenHeader: "Authorization",
|
|
||||||
},
|
},
|
||||||
&CookieStore{
|
&CookieStore{
|
||||||
Name: "_cookie",
|
Name: "_cookie",
|
||||||
CookieSecure: true,
|
CookieSecure: true,
|
||||||
CookieHTTPOnly: true,
|
CookieHTTPOnly: true,
|
||||||
CookieDomain: "pomerium.io",
|
CookieDomain: "pomerium.io",
|
||||||
CookieExpire: 10 * time.Second,
|
CookieExpire: 10 * time.Second,
|
||||||
Encoder: encoder,
|
Encoder: encoder,
|
||||||
BearerTokenHeader: "Authorization",
|
|
||||||
},
|
},
|
||||||
false},
|
false},
|
||||||
{"missing name",
|
{"missing name",
|
||||||
&CookieStoreOptions{
|
&CookieStoreOptions{
|
||||||
Name: "",
|
Name: "",
|
||||||
CookieSecure: true,
|
CookieSecure: true,
|
||||||
CookieHTTPOnly: true,
|
CookieHTTPOnly: true,
|
||||||
CookieDomain: "pomerium.io",
|
CookieDomain: "pomerium.io",
|
||||||
CookieExpire: 10 * time.Second,
|
CookieExpire: 10 * time.Second,
|
||||||
Encoder: encoder,
|
Encoder: encoder,
|
||||||
BearerTokenHeader: "Authorization",
|
|
||||||
},
|
},
|
||||||
nil,
|
nil,
|
||||||
true},
|
true},
|
||||||
|
@ -250,23 +247,3 @@ func TestMockSessionStore(t *testing.T) {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func Test_ParentSubdomain(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
tests := []struct {
|
|
||||||
s string
|
|
||||||
want string
|
|
||||||
}{
|
|
||||||
{"httpbin.corp.example.com", "corp.example.com"},
|
|
||||||
{"some.httpbin.corp.example.com", "httpbin.corp.example.com"},
|
|
||||||
{"example.com", ""},
|
|
||||||
{"", ""},
|
|
||||||
}
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.s, func(t *testing.T) {
|
|
||||||
if got := ParentSubdomain(tt.s); got != tt.want {
|
|
||||||
t.Errorf("ParentSubdomain() = %v, want %v", got, tt.want)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
61
internal/sessions/header_store.go
Normal file
61
internal/sessions/header_store.go
Normal file
|
@ -0,0 +1,61 @@
|
||||||
|
package sessions // import "github.com/pomerium/pomerium/internal/sessions"
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/pomerium/pomerium/internal/cryptutil"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
// defaultAuthHeader and defaultAuthType are default header name for the
|
||||||
|
// authorization bearer token header as defined in rfc2617
|
||||||
|
// https://tools.ietf.org/html/rfc6750#section-2.1
|
||||||
|
defaultAuthHeader = "Authorization"
|
||||||
|
defaultAuthType = "Bearer"
|
||||||
|
)
|
||||||
|
|
||||||
|
// HeaderStore implements the load session store interface using http
|
||||||
|
// authorization headers.
|
||||||
|
type HeaderStore struct {
|
||||||
|
authHeader string
|
||||||
|
authType string
|
||||||
|
encoder cryptutil.SecureEncoder
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewHeaderStore returns a new header store for loading sessions from
|
||||||
|
// authorization headers.
|
||||||
|
func NewHeaderStore(enc cryptutil.SecureEncoder) *HeaderStore {
|
||||||
|
return &HeaderStore{
|
||||||
|
authHeader: defaultAuthHeader,
|
||||||
|
authType: defaultAuthType,
|
||||||
|
encoder: enc,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// LoadSession tries to retrieve the token string from the Authorization header.
|
||||||
|
//
|
||||||
|
// NOTA BENE: While most servers do not log Authorization headers by default,
|
||||||
|
// you should ensure no other services are logging or leaking your auth headers.
|
||||||
|
func (as *HeaderStore) LoadSession(r *http.Request) (*State, error) {
|
||||||
|
cipherText := as.tokenFromHeader(r)
|
||||||
|
if cipherText == "" {
|
||||||
|
return nil, ErrNoSessionFound
|
||||||
|
}
|
||||||
|
session, err := UnmarshalSession(cipherText, as.encoder)
|
||||||
|
if err != nil {
|
||||||
|
return nil, ErrMalformed
|
||||||
|
}
|
||||||
|
return session, nil
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
// retrieve the value of the authorization header
|
||||||
|
func (as *HeaderStore) tokenFromHeader(r *http.Request) string {
|
||||||
|
bearer := r.Header.Get(as.authHeader)
|
||||||
|
atSize := len(as.authType)
|
||||||
|
if len(bearer) > atSize && strings.EqualFold(bearer[0:atSize], as.authType) {
|
||||||
|
return bearer[atSize+1:]
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
|
@ -4,7 +4,6 @@ import (
|
||||||
"context"
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// Context keys
|
// Context keys
|
||||||
|
@ -13,65 +12,56 @@ var (
|
||||||
ErrorCtxKey = &contextKey{"Error"}
|
ErrorCtxKey = &contextKey{"Error"}
|
||||||
)
|
)
|
||||||
|
|
||||||
// Library errors
|
|
||||||
var (
|
|
||||||
ErrExpired = errors.New("internal/sessions: session is expired")
|
|
||||||
ErrNoSessionFound = errors.New("internal/sessions: session is not found")
|
|
||||||
ErrMalformed = errors.New("internal/sessions: session is malformed")
|
|
||||||
)
|
|
||||||
|
|
||||||
// RetrieveSession http middleware handler will verify a auth session from a http request.
|
|
||||||
//
|
|
||||||
// RetrieveSession will search for a auth session in a http request, in the order:
|
// RetrieveSession will search for a auth session in a http request, in the order:
|
||||||
// 1. `pomerium_session` URI query parameter
|
// 1. `pomerium_session` URI query parameter
|
||||||
// 2. `Authorization: BEARER` request header
|
// 2. `Authorization: BEARER` request header
|
||||||
// 3. Cookie `_pomerium` value
|
// 3. Cookie `_pomerium` value
|
||||||
func RetrieveSession(s SessionStore) func(http.Handler) http.Handler {
|
func RetrieveSession(s ...SessionLoader) func(http.Handler) http.Handler {
|
||||||
return func(next http.Handler) http.Handler {
|
return func(next http.Handler) http.Handler {
|
||||||
return retrieve(s, TokenFromQuery, TokenFromHeader, TokenFromCookie)(next)
|
return retrieve(s...)(next)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func retrieve(s SessionStore, findTokenFns ...func(r *http.Request) string) func(http.Handler) http.Handler {
|
func retrieve(s ...SessionLoader) func(http.Handler) http.Handler {
|
||||||
return func(next http.Handler) http.Handler {
|
return func(next http.Handler) http.Handler {
|
||||||
hfn := func(w http.ResponseWriter, r *http.Request) {
|
hfn := func(w http.ResponseWriter, r *http.Request) {
|
||||||
ctx := r.Context()
|
ctx := r.Context()
|
||||||
token, err := retrieveFromRequest(s, r, findTokenFns...)
|
state, err := retrieveFromRequest(r, s...)
|
||||||
ctx = NewContext(ctx, token, err)
|
ctx = NewContext(ctx, state, err)
|
||||||
next.ServeHTTP(w, r.WithContext(ctx))
|
next.ServeHTTP(w, r.WithContext(ctx))
|
||||||
}
|
}
|
||||||
return http.HandlerFunc(hfn)
|
return http.HandlerFunc(hfn)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func retrieveFromRequest(s SessionStore, r *http.Request, findTokenFns ...func(r *http.Request) string) (*State, error) {
|
func retrieveFromRequest(r *http.Request, sessions ...SessionLoader) (*State, error) {
|
||||||
var tokenStr string
|
state := new(State)
|
||||||
var err error
|
var err error
|
||||||
|
|
||||||
// Extract token string from the request by calling token find functions in
|
// Extract sessions state from the request by calling token find functions in
|
||||||
// the order they where provided. Further extraction stops if a function
|
// the order they where provided. Further extraction stops if a function
|
||||||
// returns a non-empty string.
|
// returns a non-empty string.
|
||||||
for _, fn := range findTokenFns {
|
for _, s := range sessions {
|
||||||
tokenStr = fn(r)
|
state, err = s.LoadSession(r)
|
||||||
if tokenStr != "" {
|
if err != nil && !errors.Is(err, ErrNoSessionFound) {
|
||||||
|
// unexpected error
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
// break, we found a session state
|
||||||
|
if state != nil {
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if tokenStr == "" {
|
// no session found if state is still empty
|
||||||
|
if state == nil {
|
||||||
return nil, ErrNoSessionFound
|
return nil, ErrNoSessionFound
|
||||||
}
|
}
|
||||||
|
|
||||||
state, err := s.LoadSession(r)
|
if err = state.Valid(); err != nil {
|
||||||
if err != nil {
|
|
||||||
return nil, ErrMalformed
|
|
||||||
}
|
|
||||||
err = state.Valid()
|
|
||||||
if err != nil {
|
|
||||||
// a little unusual but we want to return the expired state too
|
// a little unusual but we want to return the expired state too
|
||||||
return state, err
|
return state, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Valid!
|
|
||||||
return state, nil
|
return state, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -89,35 +79,6 @@ func FromContext(ctx context.Context) (*State, error) {
|
||||||
return state, err
|
return state, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// TokenFromCookie tries to retrieve the token string from a cookie named
|
|
||||||
// "_pomerium".
|
|
||||||
func TokenFromCookie(r *http.Request) string {
|
|
||||||
cookie, err := r.Cookie("_pomerium")
|
|
||||||
if err != nil {
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
return cookie.Value
|
|
||||||
}
|
|
||||||
|
|
||||||
// TokenFromHeader tries to retrieve the token string from the
|
|
||||||
// "Authorization" request header: "Authorization: BEARER T".
|
|
||||||
func TokenFromHeader(r *http.Request) string {
|
|
||||||
// Get token from authorization header.
|
|
||||||
bearer := r.Header.Get("Authorization")
|
|
||||||
if len(bearer) > 7 && strings.EqualFold(bearer[0:6], "BEARER") {
|
|
||||||
return bearer[7:]
|
|
||||||
}
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
|
|
||||||
// TokenFromQuery tries to retrieve the token string from the "pomerium_session" URI
|
|
||||||
// query parameter.
|
|
||||||
// todo(bdd) : document setting session code as queryparam
|
|
||||||
func TokenFromQuery(r *http.Request) string {
|
|
||||||
// Get token from query param named "pomerium_session".
|
|
||||||
return r.URL.Query().Get("pomerium_session")
|
|
||||||
}
|
|
||||||
|
|
||||||
// contextKey is a value for use with context.WithValue. It's used as
|
// contextKey is a value for use with context.WithValue. It's used as
|
||||||
// a pointer so it fits in an interface{} without allocation. This technique
|
// a pointer so it fits in an interface{} without allocation. This technique
|
||||||
// for defining context keys was copied from Go 1.7's new use of context in net/http.
|
// for defining context keys was copied from Go 1.7's new use of context in net/http.
|
||||||
|
@ -126,5 +87,5 @@ type contextKey struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (k *contextKey) String() string {
|
func (k *contextKey) String() string {
|
||||||
return "SessionStore context value " + k.name
|
return "context value " + k.name
|
||||||
}
|
}
|
||||||
|
|
|
@ -9,9 +9,8 @@ import (
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/pomerium/pomerium/internal/cryptutil"
|
|
||||||
|
|
||||||
"github.com/google/go-cmp/cmp"
|
"github.com/google/go-cmp/cmp"
|
||||||
|
"github.com/pomerium/pomerium/internal/cryptutil"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestNewContext(t *testing.T) {
|
func TestNewContext(t *testing.T) {
|
||||||
|
@ -75,8 +74,8 @@ func TestVerifier(t *testing.T) {
|
||||||
{"expired auth header", State{Email: "user@pomerium.io", RefreshDeadline: time.Now().Add(-10 * time.Second)}, false, true, false, "internal/sessions: session is expired\n", http.StatusUnauthorized},
|
{"expired auth header", State{Email: "user@pomerium.io", RefreshDeadline: time.Now().Add(-10 * time.Second)}, false, true, false, "internal/sessions: session is expired\n", http.StatusUnauthorized},
|
||||||
{"malformed auth header", State{Email: "user@pomerium.io", RefreshDeadline: time.Now().Add(-10 * time.Second)}, false, true, false, "internal/sessions: session is malformed\n", http.StatusUnauthorized},
|
{"malformed auth header", State{Email: "user@pomerium.io", RefreshDeadline: time.Now().Add(-10 * time.Second)}, false, true, false, "internal/sessions: session is malformed\n", http.StatusUnauthorized},
|
||||||
{"good auth query param session", State{Email: "user@pomerium.io", RefreshDeadline: time.Now().Add(10 * time.Second)}, false, true, true, http.StatusText(http.StatusOK), http.StatusOK},
|
{"good auth query param session", State{Email: "user@pomerium.io", RefreshDeadline: time.Now().Add(10 * time.Second)}, false, true, true, http.StatusText(http.StatusOK), http.StatusOK},
|
||||||
{"expired auth query param", State{Email: "user@pomerium.io", RefreshDeadline: time.Now().Add(-10 * time.Second)}, false, true, true, "internal/sessions: session is expired\n", http.StatusUnauthorized},
|
{"expired auth query param", State{Email: "user@pomerium.io", RefreshDeadline: time.Now().Add(-10 * time.Second)}, false, false, true, "internal/sessions: session is expired\n", http.StatusUnauthorized},
|
||||||
{"malformed auth query param", State{Email: "user@pomerium.io", RefreshDeadline: time.Now().Add(-10 * time.Second)}, false, true, true, "internal/sessions: session is malformed\n", http.StatusUnauthorized},
|
{"malformed auth query param", State{Email: "user@pomerium.io", RefreshDeadline: time.Now().Add(-10 * time.Second)}, false, false, true, "internal/sessions: session is malformed\n", http.StatusUnauthorized},
|
||||||
{"no session", State{Email: "user@pomerium.io", RefreshDeadline: time.Now().Add(-10 * time.Second)}, false, false, false, "internal/sessions: session is not found\n", http.StatusUnauthorized},
|
{"no session", State{Email: "user@pomerium.io", RefreshDeadline: time.Now().Add(-10 * time.Second)}, false, false, false, "internal/sessions: session is not found\n", http.StatusUnauthorized},
|
||||||
}
|
}
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
|
@ -94,7 +93,6 @@ func TestVerifier(t *testing.T) {
|
||||||
if strings.Contains(tt.name, "malformed") {
|
if strings.Contains(tt.name, "malformed") {
|
||||||
// add some garbage to the end of the string
|
// add some garbage to the end of the string
|
||||||
encSession += cryptutil.NewBase64Key()
|
encSession += cryptutil.NewBase64Key()
|
||||||
fmt.Println(encSession)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
cs, err := NewCookieStore(&CookieStoreOptions{
|
cs, err := NewCookieStore(&CookieStoreOptions{
|
||||||
|
@ -104,6 +102,9 @@ func TestVerifier(t *testing.T) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
as := NewHeaderStore(encoder)
|
||||||
|
|
||||||
|
qp := NewQueryParamStore(encoder)
|
||||||
|
|
||||||
r := httptest.NewRequest(http.MethodGet, "/", nil)
|
r := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||||
r.Header.Set("Accept", "application/json")
|
r.Header.Set("Accept", "application/json")
|
||||||
|
@ -114,11 +115,12 @@ func TestVerifier(t *testing.T) {
|
||||||
r.Header.Set("Authorization", "Bearer "+encSession)
|
r.Header.Set("Authorization", "Bearer "+encSession)
|
||||||
} else if tt.param {
|
} else if tt.param {
|
||||||
q := r.URL.Query()
|
q := r.URL.Query()
|
||||||
q.Add("pomerium_session", encSession)
|
|
||||||
|
q.Set("pomerium_session", encSession)
|
||||||
r.URL.RawQuery = q.Encode()
|
r.URL.RawQuery = q.Encode()
|
||||||
}
|
}
|
||||||
|
|
||||||
got := RetrieveSession(cs)(testAuthorizer((fnh)))
|
got := RetrieveSession(cs, as, qp)(testAuthorizer((fnh)))
|
||||||
got.ServeHTTP(w, r)
|
got.ServeHTTP(w, r)
|
||||||
|
|
||||||
gotBody := w.Body.String()
|
gotBody := w.Body.String()
|
||||||
|
@ -133,3 +135,23 @@ func TestVerifier(t *testing.T) {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func Test_contextKey_String(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
keyName string
|
||||||
|
want string
|
||||||
|
}{
|
||||||
|
{"simple example", "test", "context value test"},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
k := &contextKey{
|
||||||
|
name: tt.keyName,
|
||||||
|
}
|
||||||
|
if got := k.String(); got != tt.want {
|
||||||
|
t.Errorf("contextKey.String() = %v, want %v", got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
44
internal/sessions/query_store.go
Normal file
44
internal/sessions/query_store.go
Normal file
|
@ -0,0 +1,44 @@
|
||||||
|
package sessions // import "github.com/pomerium/pomerium/internal/sessions"
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
"github.com/pomerium/pomerium/internal/cryptutil"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
defaultQueryParamKey = "pomerium_session"
|
||||||
|
)
|
||||||
|
|
||||||
|
// QueryParamStore implements the load session store interface using http
|
||||||
|
// query strings / query parameters.
|
||||||
|
type QueryParamStore struct {
|
||||||
|
queryParamKey string
|
||||||
|
encoder cryptutil.SecureEncoder
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewQueryParamStore returns a new query param store for loading sessions from
|
||||||
|
// query strings / query parameters.
|
||||||
|
func NewQueryParamStore(enc cryptutil.SecureEncoder) *QueryParamStore {
|
||||||
|
return &QueryParamStore{
|
||||||
|
queryParamKey: defaultQueryParamKey,
|
||||||
|
encoder: enc,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// LoadSession tries to retrieve the token string from URL query parameters.
|
||||||
|
//
|
||||||
|
// NOTA BENE: By default, most servers _DO_ log query params, the leaking or
|
||||||
|
// accidental logging of which should be considered a security issue.
|
||||||
|
func (qp *QueryParamStore) LoadSession(r *http.Request) (*State, error) {
|
||||||
|
cipherText := r.URL.Query().Get(qp.queryParamKey)
|
||||||
|
if cipherText == "" {
|
||||||
|
return nil, ErrNoSessionFound
|
||||||
|
}
|
||||||
|
session, err := UnmarshalSession(cipherText, qp.encoder)
|
||||||
|
if err != nil {
|
||||||
|
return nil, ErrMalformed
|
||||||
|
}
|
||||||
|
return session, nil
|
||||||
|
|
||||||
|
}
|
|
@ -5,8 +5,14 @@ import (
|
||||||
"net/http"
|
"net/http"
|
||||||
)
|
)
|
||||||
|
|
||||||
// ErrEmptySession is an error for an empty sessions.
|
var (
|
||||||
var ErrEmptySession = errors.New("internal/sessions: empty session")
|
// ErrExpired is the error for an expired session.
|
||||||
|
ErrExpired = errors.New("internal/sessions: session is expired")
|
||||||
|
// ErrNoSessionFound is the error for when no session is found.
|
||||||
|
ErrNoSessionFound = errors.New("internal/sessions: session is not found")
|
||||||
|
// ErrMalformed is the error for when a session is found but is malformed.
|
||||||
|
ErrMalformed = errors.New("internal/sessions: session is malformed")
|
||||||
|
)
|
||||||
|
|
||||||
// SessionStore has the functions for setting, getting, and clearing the Session cookie
|
// SessionStore has the functions for setting, getting, and clearing the Session cookie
|
||||||
type SessionStore interface {
|
type SessionStore interface {
|
||||||
|
@ -14,3 +20,9 @@ type SessionStore interface {
|
||||||
LoadSession(*http.Request) (*State, error)
|
LoadSession(*http.Request) (*State, error)
|
||||||
SaveSession(http.ResponseWriter, *http.Request, *State) error
|
SaveSession(http.ResponseWriter, *http.Request, *State) error
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SessionLoader is implemented by any struct that loads a pomerium session
|
||||||
|
// given a request, and returns a user state.
|
||||||
|
type SessionLoader interface {
|
||||||
|
LoadSession(*http.Request) (*State, error)
|
||||||
|
}
|
||||||
|
|
12
internal/sessions/util.go
Normal file
12
internal/sessions/util.go
Normal file
|
@ -0,0 +1,12 @@
|
||||||
|
package sessions // import "github.com/pomerium/pomerium/internal/sessions"
|
||||||
|
|
||||||
|
import "strings"
|
||||||
|
|
||||||
|
// ParentSubdomain returns the parent subdomain.
|
||||||
|
func ParentSubdomain(s string) string {
|
||||||
|
if strings.Count(s, ".") < 2 {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
split := strings.SplitN(s, ".", 2)
|
||||||
|
return split[1]
|
||||||
|
}
|
23
internal/sessions/util_test.go
Normal file
23
internal/sessions/util_test.go
Normal file
|
@ -0,0 +1,23 @@
|
||||||
|
package sessions
|
||||||
|
|
||||||
|
import "testing"
|
||||||
|
|
||||||
|
func Test_ParentSubdomain(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
tests := []struct {
|
||||||
|
s string
|
||||||
|
want string
|
||||||
|
}{
|
||||||
|
{"httpbin.corp.example.com", "corp.example.com"},
|
||||||
|
{"some.httpbin.corp.example.com", "httpbin.corp.example.com"},
|
||||||
|
{"example.com", ""},
|
||||||
|
{"", ""},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.s, func(t *testing.T) {
|
||||||
|
if got := ParentSubdomain(tt.s); got != tt.want {
|
||||||
|
t.Errorf("ParentSubdomain() = %v, want %v", got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
|
@ -184,8 +184,8 @@ func (p *Proxy) Verify(w http.ResponseWriter, r *http.Request) {
|
||||||
}
|
}
|
||||||
// check the queryparams to see if this check immediately followed
|
// check the queryparams to see if this check immediately followed
|
||||||
// authentication. If so, redirect back to the originally requested hostname.
|
// authentication. If so, redirect back to the originally requested hostname.
|
||||||
if isCallback := r.URL.Query().Get(disableCallback); isCallback == "true" {
|
if isCallback := r.URL.Query().Get("pomerium-auth-callback"); isCallback == "true" {
|
||||||
http.Redirect(w, r, "http://"+hostname, http.StatusFound)
|
http.Redirect(w, r, hostname, http.StatusFound)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -302,13 +302,13 @@ func TestProxy_VerifyWithMiddleware(t *testing.T) {
|
||||||
authorizer clients.Authorizer
|
authorizer clients.Authorizer
|
||||||
wantStatus int
|
wantStatus int
|
||||||
}{
|
}{
|
||||||
{"good", opts, nil, http.MethodGet, "false", "/.pomerium/verify/some.domain.name", &cryptutil.MockEncoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", RefreshDeadline: time.Now().Add(10 * time.Second)}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusOK},
|
{"good", opts, nil, http.MethodGet, "", "/.pomerium/verify/some.domain.name", &cryptutil.MockEncoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", RefreshDeadline: time.Now().Add(10 * time.Second)}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusOK},
|
||||||
{"good post auth redirect", opts, nil, http.MethodGet, "true", "/.pomerium/verify/some.domain.name", &cryptutil.MockEncoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", RefreshDeadline: time.Now().Add(10 * time.Second)}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusFound},
|
{"good post auth redirect", opts, nil, http.MethodGet, "pomerium-auth-callback", "/.pomerium/verify/some.domain.name", &cryptutil.MockEncoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", RefreshDeadline: time.Now().Add(10 * time.Second)}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusFound},
|
||||||
{"not authorized", opts, nil, http.MethodGet, "false", "/.pomerium/verify/some.domain.name", &cryptutil.MockEncoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", RefreshDeadline: time.Now().Add(10 * time.Second)}}, clients.MockAuthorize{AuthorizeResponse: false}, http.StatusForbidden},
|
{"not authorized", opts, nil, http.MethodGet, "", "/.pomerium/verify/some.domain.name", &cryptutil.MockEncoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", RefreshDeadline: time.Now().Add(10 * time.Second)}}, clients.MockAuthorize{AuthorizeResponse: false}, http.StatusForbidden},
|
||||||
{"not authorized expired, redirect to auth", opts, nil, http.MethodGet, "false", "/.pomerium/verify/some.domain.name", &cryptutil.MockEncoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", RefreshDeadline: time.Now().Add(-10 * time.Second)}}, clients.MockAuthorize{AuthorizeResponse: false}, http.StatusFound},
|
{"not authorized expired, redirect to auth", opts, nil, http.MethodGet, "", "/.pomerium/verify/some.domain.name", &cryptutil.MockEncoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", RefreshDeadline: time.Now().Add(-10 * time.Second)}}, clients.MockAuthorize{AuthorizeResponse: false}, http.StatusFound},
|
||||||
{"not authorized expired, don't redirect!", opts, nil, http.MethodGet, "true", "/.pomerium/verify/some.domain.name?no_redirect=true", &cryptutil.MockEncoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", RefreshDeadline: time.Now().Add(-10 * time.Second)}}, clients.MockAuthorize{AuthorizeResponse: false}, http.StatusUnauthorized},
|
{"not authorized expired, don't redirect!", opts, nil, http.MethodGet, HeaderNoAuthRedirect, "/.pomerium/verify/some.domain.name?no_redirect=true", &cryptutil.MockEncoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", RefreshDeadline: time.Now().Add(-10 * time.Second)}}, clients.MockAuthorize{AuthorizeResponse: false}, http.StatusUnauthorized},
|
||||||
{"not authorized because of error", opts, nil, http.MethodGet, "false", "/.pomerium/verify/some.domain.name", &cryptutil.MockEncoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", RefreshDeadline: time.Now().Add(10 * time.Second)}}, clients.MockAuthorize{AuthorizeError: errors.New("authz error")}, http.StatusInternalServerError},
|
{"not authorized because of error", opts, nil, http.MethodGet, "", "/.pomerium/verify/some.domain.name", &cryptutil.MockEncoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", RefreshDeadline: time.Now().Add(10 * time.Second)}}, clients.MockAuthorize{AuthorizeError: errors.New("authz error")}, http.StatusInternalServerError},
|
||||||
{"bad context retrieval error", opts, errors.New("oh no"), http.MethodGet, "false", "/.pomerium/verify/some.domain.name", &cryptutil.MockEncoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", RefreshDeadline: time.Now().Add(10 * time.Second)}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusOK},
|
{"bad context retrieval error", opts, errors.New("oh no"), http.MethodGet, "", "/.pomerium/verify/some.domain.name", &cryptutil.MockEncoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", RefreshDeadline: time.Now().Add(10 * time.Second)}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusOK},
|
||||||
}
|
}
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
@ -322,8 +322,8 @@ func TestProxy_VerifyWithMiddleware(t *testing.T) {
|
||||||
p.UpdateOptions(tt.options)
|
p.UpdateOptions(tt.options)
|
||||||
uri := &url.URL{Path: tt.path}
|
uri := &url.URL{Path: tt.path}
|
||||||
queryString := uri.Query()
|
queryString := uri.Query()
|
||||||
if tt.qp == "true" {
|
if tt.qp != "" {
|
||||||
queryString.Set("pomerium-auth-callback", tt.qp)
|
queryString.Set(tt.qp, "true")
|
||||||
}
|
}
|
||||||
uri.RawQuery = queryString.Encode()
|
uri.RawQuery = queryString.Encode()
|
||||||
|
|
||||||
|
|
|
@ -23,7 +23,9 @@ const (
|
||||||
// HeaderGroups is the header key containing the user's groups.
|
// HeaderGroups is the header key containing the user's groups.
|
||||||
HeaderGroups = "x-pomerium-authenticated-user-groups"
|
HeaderGroups = "x-pomerium-authenticated-user-groups"
|
||||||
|
|
||||||
disableCallback = "pomerium-auth-callback"
|
// HeaderNoAuthRedirect is the header / query param key used to disable
|
||||||
|
// redirecting unauthenticated request by default but instead return a 401.
|
||||||
|
HeaderNoAuthRedirect = "x-pomerium-no-auth-redirect"
|
||||||
)
|
)
|
||||||
|
|
||||||
// AuthenticateSession is middleware to enforce a valid authentication
|
// AuthenticateSession is middleware to enforce a valid authentication
|
||||||
|
@ -33,8 +35,8 @@ func (p *Proxy) AuthenticateSession(next http.Handler) http.Handler {
|
||||||
ctx, span := trace.StartSpan(r.Context(), "middleware.AuthenticateSession")
|
ctx, span := trace.StartSpan(r.Context(), "middleware.AuthenticateSession")
|
||||||
defer span.End()
|
defer span.End()
|
||||||
s, err := sessions.FromContext(r.Context())
|
s, err := sessions.FromContext(r.Context())
|
||||||
if err != nil {
|
if err != nil || s == nil {
|
||||||
log.Debug().Str("cause", err.Error()).Msg("proxy: re-authenticating due to session state error")
|
log.Debug().Msg("proxy: re-authenticating due to session state error")
|
||||||
p.reqNeedsAuthentication(w, r)
|
p.reqNeedsAuthentication(w, r)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -58,7 +60,7 @@ func (p *Proxy) AuthorizeSession(next http.Handler) http.Handler {
|
||||||
ctx, span := trace.StartSpan(r.Context(), "middleware.AuthorizeSession")
|
ctx, span := trace.StartSpan(r.Context(), "middleware.AuthorizeSession")
|
||||||
defer span.End()
|
defer span.End()
|
||||||
s, err := sessions.FromContext(r.Context())
|
s, err := sessions.FromContext(r.Context())
|
||||||
if err != nil {
|
if err != nil || s == nil {
|
||||||
httputil.ErrorResponse(w, r.WithContext(ctx), httputil.Error("", http.StatusForbidden, err))
|
httputil.ErrorResponse(w, r.WithContext(ctx), httputil.Error("", http.StatusForbidden, err))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -105,9 +107,11 @@ func (p *Proxy) reqNeedsAuthentication(w http.ResponseWriter, r *http.Request) {
|
||||||
// some proxies like nginx won't follow redirects, and treat any
|
// some proxies like nginx won't follow redirects, and treat any
|
||||||
// non 2xx or 4xx status as an internal service error.
|
// non 2xx or 4xx status as an internal service error.
|
||||||
// https://nginx.org/en/docs/http/ngx_http_auth_request_module.html
|
// https://nginx.org/en/docs/http/ngx_http_auth_request_module.html
|
||||||
if _, ok := r.URL.Query()[disableCallback]; ok {
|
redirectHeader := r.Header.Get(HeaderNoAuthRedirect)
|
||||||
|
if _, ok := r.URL.Query()[HeaderNoAuthRedirect]; ok || redirectHeader == "true" {
|
||||||
http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized)
|
http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized)
|
||||||
}
|
}
|
||||||
|
r.Header.Get(HeaderNoAuthRedirect)
|
||||||
uri := urlutil.SignedRedirectURL(p.SharedKey, p.authenticateSigninURL, urlutil.GetAbsoluteURL(r))
|
uri := urlutil.SignedRedirectURL(p.SharedKey, p.authenticateSigninURL, urlutil.GetAbsoluteURL(r))
|
||||||
http.Redirect(w, r, uri.String(), http.StatusFound)
|
http.Redirect(w, r, uri.String(), http.StatusFound)
|
||||||
}
|
}
|
||||||
|
|
|
@ -78,6 +78,7 @@ type Proxy struct {
|
||||||
refreshCooldown time.Duration
|
refreshCooldown time.Duration
|
||||||
Handler http.Handler
|
Handler http.Handler
|
||||||
sessionStore sessions.SessionStore
|
sessionStore sessions.SessionStore
|
||||||
|
sessionLoaders []sessions.SessionLoader
|
||||||
signingKey string
|
signingKey string
|
||||||
templates *template.Template
|
templates *template.Template
|
||||||
}
|
}
|
||||||
|
@ -122,8 +123,12 @@ func New(opts config.Options) (*Proxy, error) {
|
||||||
defaultUpstreamTimeout: opts.DefaultUpstreamTimeout,
|
defaultUpstreamTimeout: opts.DefaultUpstreamTimeout,
|
||||||
refreshCooldown: opts.RefreshCooldown,
|
refreshCooldown: opts.RefreshCooldown,
|
||||||
sessionStore: cookieStore,
|
sessionStore: cookieStore,
|
||||||
signingKey: opts.SigningKey,
|
sessionLoaders: []sessions.SessionLoader{
|
||||||
templates: templates.New(),
|
cookieStore,
|
||||||
|
sessions.NewHeaderStore(encoder),
|
||||||
|
sessions.NewQueryParamStore(encoder)},
|
||||||
|
signingKey: opts.SigningKey,
|
||||||
|
templates: templates.New(),
|
||||||
}
|
}
|
||||||
// errors checked in ValidateOptions
|
// errors checked in ValidateOptions
|
||||||
p.authorizeURL, _ = urlutil.DeepCopy(opts.AuthorizeURL)
|
p.authorizeURL, _ = urlutil.DeepCopy(opts.AuthorizeURL)
|
||||||
|
@ -227,7 +232,7 @@ func (p *Proxy) reverseProxyHandler(r *mux.Router, policy *config.Policy) (*mux.
|
||||||
}
|
}
|
||||||
|
|
||||||
// 4. Retrieve the user session and add it to the request context
|
// 4. Retrieve the user session and add it to the request context
|
||||||
rp.Use(sessions.RetrieveSession(p.sessionStore))
|
rp.Use(sessions.RetrieveSession(p.sessionLoaders...))
|
||||||
// 5. Strip the user session cookie from the downstream request
|
// 5. Strip the user session cookie from the downstream request
|
||||||
rp.Use(middleware.StripCookie(p.cookieName))
|
rp.Use(middleware.StripCookie(p.cookieName))
|
||||||
// 6. AuthN - Verify the user is authenticated. Set email, group, & id headers
|
// 6. AuthN - Verify the user is authenticated. Set email, group, & id headers
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue