all: refactor handler logic

- all: prefer `FormValues` to `ParseForm` with subsequent `Form.Get`s
- all: refactor authentication stack to be checked by middleware, and accessible via request context.
- all: replace http.ServeMux with gorilla/mux’s router
- all: replace custom CSRF checks with gorilla/csrf middleware
- authenticate: extract callback path as constant.
- internal/config: implement stringer interface for policy
- internal/cryptutil: add helper func `NewBase64Key`
- internal/cryptutil: rename `GenerateKey` to `NewKey`
- internal/cryptutil: rename `GenerateRandomString` to `NewRandomStringN`
- internal/middleware: removed alice in favor of gorilla/mux
- internal/sessions: remove unused `ValidateRedirectURI` and `ValidateClientSecret`
- internal/sessions: replace custom CSRF with gorilla/csrf fork that supports custom handler protection
- internal/urlutil: add `SignedRedirectURL` to create hmac'd URLs
- internal/urlutil: add `ValidateURL` helper to parse URL options
- internal/urlutil: add `GetAbsoluteURL` which takes a request and returns its absolute URL.
- proxy: remove holdover state verification checks; we no longer are setting sessions in any proxy routes so we don’t need them.
- proxy: replace un-named http.ServeMux with named domain routes.

Signed-off-by: Bobby DeSimone <bobbydesimone@gmail.com>
This commit is contained in:
Bobby DeSimone 2019-09-12 13:54:30 -07:00
parent a793249386
commit dc12947241
No known key found for this signature in database
GPG key ID: AEE4CF12FE86D07E
37 changed files with 1132 additions and 1384 deletions

View file

@ -79,12 +79,10 @@ func NewCookieStore(opts *CookieStoreOptions) (*CookieStore, error) {
func (cs *CookieStore) makeCookie(req *http.Request, name string, value string, expiration time.Duration, now time.Time) *http.Cookie {
domain := req.Host
if name == cs.csrfName() {
domain = req.Host
} else if cs.CookieDomain != "" {
if cs.CookieDomain != "" {
domain = cs.CookieDomain
} else {
domain = splitDomain(domain)
domain = ParentSubdomain(domain)
}
if h, _, err := net.SplitHostPort(domain); err == nil {
@ -105,19 +103,11 @@ func (cs *CookieStore) makeCookie(req *http.Request, name string, value string,
return c
}
func (cs *CookieStore) csrfName() string {
return fmt.Sprintf("%s_csrf", cs.Name)
}
// makeSessionCookie constructs a session cookie given the request, an expiration time and the current time.
func (cs *CookieStore) makeSessionCookie(req *http.Request, value string, expiration time.Duration, now time.Time) *http.Cookie {
return cs.makeCookie(req, cs.Name, value, expiration, now)
}
func (cs *CookieStore) makeCSRFCookie(req *http.Request, value string, expiration time.Duration, now time.Time) *http.Cookie {
return cs.makeCookie(req, cs.csrfName(), value, expiration, now)
}
func (cs *CookieStore) setCookie(w http.ResponseWriter, cookie *http.Cookie) {
if len(cookie.String()) <= MaxChunkSize {
http.SetCookie(w, cookie)
@ -134,7 +124,6 @@ func (cs *CookieStore) setCookie(w http.ResponseWriter, cookie *http.Cookie) {
nc.Name = fmt.Sprintf("%s_%d", cookie.Name, i)
nc.Value = c
}
fmt.Println(i)
http.SetCookie(w, &nc)
}
}
@ -150,25 +139,6 @@ func chunk(s string, size int) []string {
return ss
}
// ClearCSRF clears the CSRF cookie from the request
func (cs *CookieStore) ClearCSRF(w http.ResponseWriter, req *http.Request) {
http.SetCookie(w, cs.makeCSRFCookie(req, "", time.Hour*-1, time.Now()))
}
// SetCSRF sets the CSRFCookie creates a CSRF cookie in a given request
func (cs *CookieStore) SetCSRF(w http.ResponseWriter, req *http.Request, val string) {
http.SetCookie(w, cs.makeCSRFCookie(req, val, cs.CookieExpire, time.Now()))
}
// GetCSRF gets the CSRFCookie creates a CSRF cookie in a given request
func (cs *CookieStore) GetCSRF(req *http.Request) (*http.Cookie, error) {
c, err := req.Cookie(cs.csrfName())
if err != nil {
return nil, ErrEmptyCSRF // ErrNoCookie is confusing in this context
}
return c, nil
}
// ClearSession clears the session cookie from a request
func (cs *CookieStore) ClearSession(w http.ResponseWriter, req *http.Request) {
http.SetCookie(w, cs.makeCookie(req, cs.Name, "", time.Hour*-1, time.Now()))
@ -235,7 +205,8 @@ func (cs *CookieStore) SaveSession(w http.ResponseWriter, req *http.Request, s *
return nil
}
func splitDomain(s string) string {
// ParentSubdomain returns the parent subdomain.
func ParentSubdomain(s string) string {
if strings.Count(s, ".") < 2 {
return ""
}

View file

@ -1,4 +1,4 @@
package sessions
package sessions // import "github.com/pomerium/pomerium/internal/sessions"
import (
"crypto/rand"
@ -38,7 +38,7 @@ func (a mockCipher) Unmarshal(s string, i interface{}) error {
return nil
}
func TestNewCookieStore(t *testing.T) {
cipher, err := cryptutil.NewCipher(cryptutil.GenerateKey())
cipher, err := cryptutil.NewCipher(cryptutil.NewKey())
if err != nil {
t.Fatal(err)
}
@ -111,7 +111,7 @@ func TestNewCookieStore(t *testing.T) {
}
func TestCookieStore_makeCookie(t *testing.T) {
cipher, err := cryptutil.NewCipher(cryptutil.GenerateKey())
cipher, err := cryptutil.NewCipher(cryptutil.NewKey())
if err != nil {
t.Fatal(err)
}
@ -155,62 +155,13 @@ func TestCookieStore_makeCookie(t *testing.T) {
if diff := cmp.Diff(s.makeSessionCookie(r, tt.value, tt.expiration, now), tt.want); diff != "" {
t.Errorf("CookieStore.makeSessionCookie() = \n%s", diff)
}
got := s.makeCSRFCookie(r, tt.value, tt.expiration, now)
tt.wantCSRF.Name = "_pomerium_csrf"
if !reflect.DeepEqual(got, tt.wantCSRF) {
t.Errorf("CookieStore.makeCookie() = \n%#v, \nwant\n%#v", got, tt.wantCSRF)
}
w := httptest.NewRecorder()
want := "new-csrf"
s.SetCSRF(w, r, want)
found := false
for _, cookie := range w.Result().Cookies() {
if cookie.Name == s.Name+"_csrf" && cookie.Value == want {
found = true
break
}
}
if !found {
t.Error("SetCSRF failed")
}
w = httptest.NewRecorder()
s.ClearCSRF(w, r)
for _, cookie := range w.Result().Cookies() {
if cookie.Name == s.Name+"_csrf" && cookie.Value == want {
t.Error("clear csrf failed")
break
}
}
w = httptest.NewRecorder()
want = "new-session"
s.setSessionCookie(w, r, want)
found = false
for _, cookie := range w.Result().Cookies() {
if cookie.Name == s.Name && cookie.Value == want {
found = true
break
}
}
if !found {
t.Error("SetCSRF failed")
}
w = httptest.NewRecorder()
s.ClearSession(w, r)
for _, cookie := range w.Result().Cookies() {
if cookie.Name == s.Name && cookie.Value == want {
t.Error("clear csrf failed")
break
}
}
})
}
}
func TestCookieStore_SaveSession(t *testing.T) {
cipher, err := cryptutil.NewCipher(cryptutil.GenerateKey())
cipher, err := cryptutil.NewCipher(cryptutil.NewKey())
if err != nil {
t.Fatal(err)
}
@ -265,38 +216,6 @@ func TestCookieStore_SaveSession(t *testing.T) {
}
}
func TestMockCSRFStore(t *testing.T) {
tests := []struct {
name string
mockCSRF *MockCSRFStore
newCSRFValue string
wantErr bool
}{
{"basic",
&MockCSRFStore{
ResponseCSRF: "ok",
Cookie: &http.Cookie{Name: "hi"}},
"newcsrf",
false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ms := tt.mockCSRF
ms.SetCSRF(nil, nil, tt.newCSRFValue)
ms.ClearCSRF(nil, nil)
got, err := ms.GetCSRF(nil)
if (err != nil) != tt.wantErr {
t.Errorf("MockCSRFStore.GetCSRF() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !reflect.DeepEqual(got, tt.mockCSRF.Cookie) {
t.Errorf("MockCSRFStore.GetCSRF() = %v, want %v", got, tt.mockCSRF.Cookie)
}
})
}
}
func TestMockSessionStore(t *testing.T) {
tests := []struct {
name string
@ -341,7 +260,7 @@ func TestMockSessionStore(t *testing.T) {
}
}
func Test_splitDomain(t *testing.T) {
func Test_ParentSubdomain(t *testing.T) {
t.Parallel()
tests := []struct {
s string
@ -354,8 +273,8 @@ func Test_splitDomain(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.s, func(t *testing.T) {
if got := splitDomain(tt.s); got != tt.want {
t.Errorf("splitDomain() = %v, want %v", got, tt.want)
if got := ParentSubdomain(tt.s); got != tt.want {
t.Errorf("ParentSubdomain() = %v, want %v", got, tt.want)
}
})
}

View file

@ -0,0 +1,130 @@
package sessions // import "github.com/pomerium/pomerium/internal/sessions"
import (
"context"
"errors"
"net/http"
"strings"
)
// Context keys
var (
SessionCtxKey = &contextKey{"Session"}
ErrorCtxKey = &contextKey{"Error"}
)
// Library errors
var (
ErrExpired = errors.New("internal/sessions: session is expired")
ErrNoSessionFound = errors.New("internal/sessions: session is not found")
ErrMalformed = errors.New("internal/sessions: session is malformed")
)
// RetrieveSession http middleware handler will verify a auth session from a http request.
//
// RetrieveSession will search for a auth session in a http request, in the order:
// 1. `pomerium_session` URI query parameter
// 2. `Authorization: BEARER` request header
// 3. Cookie `_pomerium` value
func RetrieveSession(s SessionStore) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return retrieve(s, TokenFromQuery, TokenFromHeader, TokenFromCookie)(next)
}
}
func retrieve(s SessionStore, findTokenFns ...func(r *http.Request) string) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
hfn := func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
token, err := retrieveFromRequest(s, r, findTokenFns...)
ctx = NewContext(ctx, token, err)
next.ServeHTTP(w, r.WithContext(ctx))
}
return http.HandlerFunc(hfn)
}
}
func retrieveFromRequest(s SessionStore, r *http.Request, findTokenFns ...func(r *http.Request) string) (*State, error) {
var tokenStr string
var err error
// Extract token string from the request by calling token find functions in
// the order they where provided. Further extraction stops if a function
// returns a non-empty string.
for _, fn := range findTokenFns {
tokenStr = fn(r)
if tokenStr != "" {
break
}
}
if tokenStr == "" {
return nil, ErrNoSessionFound
}
state, err := s.LoadSession(r)
if err != nil {
return nil, ErrMalformed
}
err = state.Valid()
if err != nil {
// a little unusual but we want to return the expired state too
return state, err
}
// Valid!
return state, nil
}
// NewContext sets context values for the user session state and error.
func NewContext(ctx context.Context, t *State, err error) context.Context {
ctx = context.WithValue(ctx, SessionCtxKey, t)
ctx = context.WithValue(ctx, ErrorCtxKey, err)
return ctx
}
// FromContext retrieves context values for the user session state and error.
func FromContext(ctx context.Context) (*State, error) {
state, _ := ctx.Value(SessionCtxKey).(*State)
err, _ := ctx.Value(ErrorCtxKey).(error)
return state, err
}
// TokenFromCookie tries to retrieve the token string from a cookie named
// "_pomerium".
func TokenFromCookie(r *http.Request) string {
cookie, err := r.Cookie("_pomerium")
if err != nil {
return ""
}
return cookie.Value
}
// TokenFromHeader tries to retrieve the token string from the
// "Authorization" request header: "Authorization: BEARER T".
func TokenFromHeader(r *http.Request) string {
// Get token from authorization header.
bearer := r.Header.Get("Authorization")
if len(bearer) > 7 && strings.EqualFold(bearer[0:6], "BEARER") {
return bearer[7:]
}
return ""
}
// TokenFromQuery tries to retrieve the token string from the "pomerium_session" URI
// query parameter.
// todo(bdd) : document setting session code as queryparam
func TokenFromQuery(r *http.Request) string {
// Get token from query param named "pomerium_session".
return r.URL.Query().Get("pomerium_session")
}
// contextKey is a value for use with context.WithValue. It's used as
// a pointer so it fits in an interface{} without allocation. This technique
// for defining context keys was copied from Go 1.7's new use of context in net/http.
type contextKey struct {
name string
}
func (k *contextKey) String() string {
return "SessionStore context value " + k.name
}

View file

@ -0,0 +1,133 @@
package sessions
import (
"context"
"fmt"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
"github.com/pomerium/pomerium/internal/cryptutil"
"github.com/google/go-cmp/cmp"
)
func TestNewContext(t *testing.T) {
tests := []struct {
name string
ctx context.Context
t *State
err error
want context.Context
}{
{"simple", context.Background(), &State{Email: "bdd@pomerium.io"}, nil, nil},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ctxOut := NewContext(tt.ctx, tt.t, tt.err)
stateOut, errOut := FromContext(ctxOut)
if diff := cmp.Diff(tt.t, stateOut); diff != "" {
t.Errorf("NewContext() = %s", diff)
}
if diff := cmp.Diff(tt.err, errOut); diff != "" {
t.Errorf("NewContext() = %s", diff)
}
})
}
}
func testAuthorizer(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, err := FromContext(r.Context())
if err != nil {
http.Error(w, err.Error(), http.StatusUnauthorized)
return
}
next.ServeHTTP(w, r)
})
}
func TestVerifier(t *testing.T) {
fnh := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/plain; charset=utf-8")
fmt.Fprint(w, http.StatusText(http.StatusOK))
w.WriteHeader(http.StatusOK)
})
tests := []struct {
name string
// s SessionStore
state State
cookie bool
header bool
param bool
wantBody string
wantStatus int
}{
{"good cookie session", State{Email: "user@pomerium.io", RefreshDeadline: time.Now().Add(10 * time.Second)}, true, false, false, http.StatusText(http.StatusOK), http.StatusOK},
{"expired cookie", State{Email: "user@pomerium.io", RefreshDeadline: time.Now().Add(-10 * time.Second)}, true, false, false, "internal/sessions: session is expired\n", http.StatusUnauthorized},
{"malformed cookie", State{Email: "user@pomerium.io", RefreshDeadline: time.Now().Add(-10 * time.Second)}, true, false, false, "internal/sessions: session is malformed\n", http.StatusUnauthorized},
{"good auth header session", State{Email: "user@pomerium.io", RefreshDeadline: time.Now().Add(10 * time.Second)}, false, true, false, http.StatusText(http.StatusOK), http.StatusOK},
{"expired auth header", State{Email: "user@pomerium.io", RefreshDeadline: time.Now().Add(-10 * time.Second)}, false, true, false, "internal/sessions: session is expired\n", http.StatusUnauthorized},
{"malformed auth header", State{Email: "user@pomerium.io", RefreshDeadline: time.Now().Add(-10 * time.Second)}, false, true, false, "internal/sessions: session is malformed\n", http.StatusUnauthorized},
{"good auth query param session", State{Email: "user@pomerium.io", RefreshDeadline: time.Now().Add(10 * time.Second)}, false, true, true, http.StatusText(http.StatusOK), http.StatusOK},
{"expired auth query param", State{Email: "user@pomerium.io", RefreshDeadline: time.Now().Add(-10 * time.Second)}, false, true, true, "internal/sessions: session is expired\n", http.StatusUnauthorized},
{"malformed auth query param", State{Email: "user@pomerium.io", RefreshDeadline: time.Now().Add(-10 * time.Second)}, false, true, true, "internal/sessions: session is malformed\n", http.StatusUnauthorized},
{"no session", State{Email: "user@pomerium.io", RefreshDeadline: time.Now().Add(-10 * time.Second)}, false, false, false, "internal/sessions: session is not found\n", http.StatusUnauthorized},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
cipher, err := cryptutil.NewCipherFromBase64(cryptutil.NewBase64Key())
if err != nil {
t.Fatal(err)
}
encSession, err := MarshalSession(&tt.state, cipher)
if err != nil {
t.Fatal(err)
}
if strings.Contains(tt.name, "malformed") {
// add some garbage to the end of the string
encSession += cryptutil.NewBase64Key()
fmt.Println(encSession)
}
cs, err := NewCookieStore(&CookieStoreOptions{
Name: "_pomerium",
CookieCipher: cipher,
})
if err != nil {
t.Fatal(err)
}
r := httptest.NewRequest(http.MethodGet, "/", nil)
r.Header.Set("Accept", "application/json")
w := httptest.NewRecorder()
if tt.cookie {
r.AddCookie(&http.Cookie{Name: "_pomerium", Value: encSession})
} else if tt.header {
r.Header.Set("Authorization", "Bearer "+encSession)
} else if tt.param {
q := r.URL.Query()
q.Add("pomerium_session", encSession)
r.URL.RawQuery = q.Encode()
}
got := RetrieveSession(cs)(testAuthorizer((fnh)))
got.ServeHTTP(w, r)
gotBody := w.Body.String()
gotStatus := w.Result().StatusCode
if diff := cmp.Diff(gotBody, tt.wantBody); diff != "" {
t.Errorf("RetrieveSession() = %v", diff)
}
if diff := cmp.Diff(gotStatus, tt.wantStatus); diff != "" {
t.Errorf("RetrieveSession() = %v", diff)
}
})
}
}

View file

@ -4,28 +4,6 @@ import (
"net/http"
)
// MockCSRFStore is a mock implementation of the CSRF store interface
type MockCSRFStore struct {
ResponseCSRF string
Cookie *http.Cookie
GetError error
}
// SetCSRF sets the ResponseCSRF string to a val
func (ms MockCSRFStore) SetCSRF(rw http.ResponseWriter, req *http.Request, val string) {
ms.ResponseCSRF = val
}
// ClearCSRF clears the ResponseCSRF string
func (ms MockCSRFStore) ClearCSRF(http.ResponseWriter, *http.Request) {
ms.ResponseCSRF = ""
}
// GetCSRF returns the cookie and error
func (ms MockCSRFStore) GetCSRF(*http.Request) (*http.Cookie, error) {
return ms.Cookie, ms.GetError
}
// MockSessionStore is a mock implementation of the SessionStore interface
type MockSessionStore struct {
ResponseSession string

View file

@ -10,9 +10,6 @@ import (
"github.com/pomerium/pomerium/internal/cryptutil"
)
// ErrExpired is an error for a expired sessions.
var ErrExpired = fmt.Errorf("internal/sessions: expired session")
// State is our object that keeps track of a user's session state
type State struct {
AccessToken string `json:"access_token"`

View file

@ -12,7 +12,7 @@ import (
)
func TestStateSerialization(t *testing.T) {
secret := cryptutil.GenerateKey()
secret := cryptutil.NewKey()
c, err := cryptutil.NewCipher(secret)
if err != nil {
t.Fatalf("expected to be able to create cipher: %v", err)
@ -123,7 +123,7 @@ func TestState_Impersonating(t *testing.T) {
}
func TestMarshalSession(t *testing.T) {
secret := cryptutil.GenerateKey()
secret := cryptutil.NewKey()
c, err := cryptutil.NewCipher(secret)
if err != nil {
t.Fatalf("expected to be able to create cipher: %v", err)

View file

@ -8,16 +8,6 @@ import (
// ErrEmptySession is an error for an empty sessions.
var ErrEmptySession = errors.New("internal/sessions: empty session")
// ErrEmptyCSRF is an error for an empty sessions.
var ErrEmptyCSRF = errors.New("internal/sessions: empty csrf")
// CSRFStore has the functions for setting, getting, and clearing the CSRF cookie
type CSRFStore interface {
SetCSRF(http.ResponseWriter, *http.Request, string)
GetCSRF(*http.Request) (*http.Cookie, error)
ClearCSRF(http.ResponseWriter, *http.Request)
}
// SessionStore has the functions for setting, getting, and clearing the Session cookie
type SessionStore interface {
ClearSession(http.ResponseWriter, *http.Request)