mirror of
https://github.com/pomerium/pomerium.git
synced 2025-08-04 01:09:36 +02:00
all: refactor handler logic
- all: prefer `FormValues` to `ParseForm` with subsequent `Form.Get`s - all: refactor authentication stack to be checked by middleware, and accessible via request context. - all: replace http.ServeMux with gorilla/mux’s router - all: replace custom CSRF checks with gorilla/csrf middleware - authenticate: extract callback path as constant. - internal/config: implement stringer interface for policy - internal/cryptutil: add helper func `NewBase64Key` - internal/cryptutil: rename `GenerateKey` to `NewKey` - internal/cryptutil: rename `GenerateRandomString` to `NewRandomStringN` - internal/middleware: removed alice in favor of gorilla/mux - internal/sessions: remove unused `ValidateRedirectURI` and `ValidateClientSecret` - internal/sessions: replace custom CSRF with gorilla/csrf fork that supports custom handler protection - internal/urlutil: add `SignedRedirectURL` to create hmac'd URLs - internal/urlutil: add `ValidateURL` helper to parse URL options - internal/urlutil: add `GetAbsoluteURL` which takes a request and returns its absolute URL. - proxy: remove holdover state verification checks; we no longer are setting sessions in any proxy routes so we don’t need them. - proxy: replace un-named http.ServeMux with named domain routes. Signed-off-by: Bobby DeSimone <bobbydesimone@gmail.com>
This commit is contained in:
parent
a793249386
commit
dc12947241
37 changed files with 1132 additions and 1384 deletions
|
@ -79,12 +79,10 @@ func NewCookieStore(opts *CookieStoreOptions) (*CookieStore, error) {
|
|||
func (cs *CookieStore) makeCookie(req *http.Request, name string, value string, expiration time.Duration, now time.Time) *http.Cookie {
|
||||
domain := req.Host
|
||||
|
||||
if name == cs.csrfName() {
|
||||
domain = req.Host
|
||||
} else if cs.CookieDomain != "" {
|
||||
if cs.CookieDomain != "" {
|
||||
domain = cs.CookieDomain
|
||||
} else {
|
||||
domain = splitDomain(domain)
|
||||
domain = ParentSubdomain(domain)
|
||||
}
|
||||
|
||||
if h, _, err := net.SplitHostPort(domain); err == nil {
|
||||
|
@ -105,19 +103,11 @@ func (cs *CookieStore) makeCookie(req *http.Request, name string, value string,
|
|||
return c
|
||||
}
|
||||
|
||||
func (cs *CookieStore) csrfName() string {
|
||||
return fmt.Sprintf("%s_csrf", cs.Name)
|
||||
}
|
||||
|
||||
// makeSessionCookie constructs a session cookie given the request, an expiration time and the current time.
|
||||
func (cs *CookieStore) makeSessionCookie(req *http.Request, value string, expiration time.Duration, now time.Time) *http.Cookie {
|
||||
return cs.makeCookie(req, cs.Name, value, expiration, now)
|
||||
}
|
||||
|
||||
func (cs *CookieStore) makeCSRFCookie(req *http.Request, value string, expiration time.Duration, now time.Time) *http.Cookie {
|
||||
return cs.makeCookie(req, cs.csrfName(), value, expiration, now)
|
||||
}
|
||||
|
||||
func (cs *CookieStore) setCookie(w http.ResponseWriter, cookie *http.Cookie) {
|
||||
if len(cookie.String()) <= MaxChunkSize {
|
||||
http.SetCookie(w, cookie)
|
||||
|
@ -134,7 +124,6 @@ func (cs *CookieStore) setCookie(w http.ResponseWriter, cookie *http.Cookie) {
|
|||
nc.Name = fmt.Sprintf("%s_%d", cookie.Name, i)
|
||||
nc.Value = c
|
||||
}
|
||||
fmt.Println(i)
|
||||
http.SetCookie(w, &nc)
|
||||
}
|
||||
}
|
||||
|
@ -150,25 +139,6 @@ func chunk(s string, size int) []string {
|
|||
return ss
|
||||
}
|
||||
|
||||
// ClearCSRF clears the CSRF cookie from the request
|
||||
func (cs *CookieStore) ClearCSRF(w http.ResponseWriter, req *http.Request) {
|
||||
http.SetCookie(w, cs.makeCSRFCookie(req, "", time.Hour*-1, time.Now()))
|
||||
}
|
||||
|
||||
// SetCSRF sets the CSRFCookie creates a CSRF cookie in a given request
|
||||
func (cs *CookieStore) SetCSRF(w http.ResponseWriter, req *http.Request, val string) {
|
||||
http.SetCookie(w, cs.makeCSRFCookie(req, val, cs.CookieExpire, time.Now()))
|
||||
}
|
||||
|
||||
// GetCSRF gets the CSRFCookie creates a CSRF cookie in a given request
|
||||
func (cs *CookieStore) GetCSRF(req *http.Request) (*http.Cookie, error) {
|
||||
c, err := req.Cookie(cs.csrfName())
|
||||
if err != nil {
|
||||
return nil, ErrEmptyCSRF // ErrNoCookie is confusing in this context
|
||||
}
|
||||
return c, nil
|
||||
}
|
||||
|
||||
// ClearSession clears the session cookie from a request
|
||||
func (cs *CookieStore) ClearSession(w http.ResponseWriter, req *http.Request) {
|
||||
http.SetCookie(w, cs.makeCookie(req, cs.Name, "", time.Hour*-1, time.Now()))
|
||||
|
@ -235,7 +205,8 @@ func (cs *CookieStore) SaveSession(w http.ResponseWriter, req *http.Request, s *
|
|||
return nil
|
||||
}
|
||||
|
||||
func splitDomain(s string) string {
|
||||
// ParentSubdomain returns the parent subdomain.
|
||||
func ParentSubdomain(s string) string {
|
||||
if strings.Count(s, ".") < 2 {
|
||||
return ""
|
||||
}
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
package sessions
|
||||
package sessions // import "github.com/pomerium/pomerium/internal/sessions"
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
|
@ -38,7 +38,7 @@ func (a mockCipher) Unmarshal(s string, i interface{}) error {
|
|||
return nil
|
||||
}
|
||||
func TestNewCookieStore(t *testing.T) {
|
||||
cipher, err := cryptutil.NewCipher(cryptutil.GenerateKey())
|
||||
cipher, err := cryptutil.NewCipher(cryptutil.NewKey())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
@ -111,7 +111,7 @@ func TestNewCookieStore(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestCookieStore_makeCookie(t *testing.T) {
|
||||
cipher, err := cryptutil.NewCipher(cryptutil.GenerateKey())
|
||||
cipher, err := cryptutil.NewCipher(cryptutil.NewKey())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
@ -155,62 +155,13 @@ func TestCookieStore_makeCookie(t *testing.T) {
|
|||
if diff := cmp.Diff(s.makeSessionCookie(r, tt.value, tt.expiration, now), tt.want); diff != "" {
|
||||
t.Errorf("CookieStore.makeSessionCookie() = \n%s", diff)
|
||||
}
|
||||
got := s.makeCSRFCookie(r, tt.value, tt.expiration, now)
|
||||
tt.wantCSRF.Name = "_pomerium_csrf"
|
||||
if !reflect.DeepEqual(got, tt.wantCSRF) {
|
||||
t.Errorf("CookieStore.makeCookie() = \n%#v, \nwant\n%#v", got, tt.wantCSRF)
|
||||
}
|
||||
w := httptest.NewRecorder()
|
||||
want := "new-csrf"
|
||||
s.SetCSRF(w, r, want)
|
||||
found := false
|
||||
for _, cookie := range w.Result().Cookies() {
|
||||
if cookie.Name == s.Name+"_csrf" && cookie.Value == want {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
t.Error("SetCSRF failed")
|
||||
}
|
||||
|
||||
w = httptest.NewRecorder()
|
||||
s.ClearCSRF(w, r)
|
||||
for _, cookie := range w.Result().Cookies() {
|
||||
if cookie.Name == s.Name+"_csrf" && cookie.Value == want {
|
||||
t.Error("clear csrf failed")
|
||||
break
|
||||
|
||||
}
|
||||
}
|
||||
w = httptest.NewRecorder()
|
||||
want = "new-session"
|
||||
s.setSessionCookie(w, r, want)
|
||||
found = false
|
||||
for _, cookie := range w.Result().Cookies() {
|
||||
if cookie.Name == s.Name && cookie.Value == want {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
t.Error("SetCSRF failed")
|
||||
}
|
||||
w = httptest.NewRecorder()
|
||||
s.ClearSession(w, r)
|
||||
for _, cookie := range w.Result().Cookies() {
|
||||
if cookie.Name == s.Name && cookie.Value == want {
|
||||
t.Error("clear csrf failed")
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCookieStore_SaveSession(t *testing.T) {
|
||||
cipher, err := cryptutil.NewCipher(cryptutil.GenerateKey())
|
||||
cipher, err := cryptutil.NewCipher(cryptutil.NewKey())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
@ -265,38 +216,6 @@ func TestCookieStore_SaveSession(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestMockCSRFStore(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
mockCSRF *MockCSRFStore
|
||||
newCSRFValue string
|
||||
wantErr bool
|
||||
}{
|
||||
{"basic",
|
||||
&MockCSRFStore{
|
||||
ResponseCSRF: "ok",
|
||||
Cookie: &http.Cookie{Name: "hi"}},
|
||||
"newcsrf",
|
||||
false},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
ms := tt.mockCSRF
|
||||
ms.SetCSRF(nil, nil, tt.newCSRFValue)
|
||||
ms.ClearCSRF(nil, nil)
|
||||
got, err := ms.GetCSRF(nil)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("MockCSRFStore.GetCSRF() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if !reflect.DeepEqual(got, tt.mockCSRF.Cookie) {
|
||||
t.Errorf("MockCSRFStore.GetCSRF() = %v, want %v", got, tt.mockCSRF.Cookie)
|
||||
}
|
||||
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestMockSessionStore(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
|
@ -341,7 +260,7 @@ func TestMockSessionStore(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func Test_splitDomain(t *testing.T) {
|
||||
func Test_ParentSubdomain(t *testing.T) {
|
||||
t.Parallel()
|
||||
tests := []struct {
|
||||
s string
|
||||
|
@ -354,8 +273,8 @@ func Test_splitDomain(t *testing.T) {
|
|||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.s, func(t *testing.T) {
|
||||
if got := splitDomain(tt.s); got != tt.want {
|
||||
t.Errorf("splitDomain() = %v, want %v", got, tt.want)
|
||||
if got := ParentSubdomain(tt.s); got != tt.want {
|
||||
t.Errorf("ParentSubdomain() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
|
130
internal/sessions/middleware.go
Normal file
130
internal/sessions/middleware.go
Normal file
|
@ -0,0 +1,130 @@
|
|||
package sessions // import "github.com/pomerium/pomerium/internal/sessions"
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net/http"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// Context keys
|
||||
var (
|
||||
SessionCtxKey = &contextKey{"Session"}
|
||||
ErrorCtxKey = &contextKey{"Error"}
|
||||
)
|
||||
|
||||
// Library errors
|
||||
var (
|
||||
ErrExpired = errors.New("internal/sessions: session is expired")
|
||||
ErrNoSessionFound = errors.New("internal/sessions: session is not found")
|
||||
ErrMalformed = errors.New("internal/sessions: session is malformed")
|
||||
)
|
||||
|
||||
// RetrieveSession http middleware handler will verify a auth session from a http request.
|
||||
//
|
||||
// RetrieveSession will search for a auth session in a http request, in the order:
|
||||
// 1. `pomerium_session` URI query parameter
|
||||
// 2. `Authorization: BEARER` request header
|
||||
// 3. Cookie `_pomerium` value
|
||||
func RetrieveSession(s SessionStore) func(http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return retrieve(s, TokenFromQuery, TokenFromHeader, TokenFromCookie)(next)
|
||||
}
|
||||
}
|
||||
|
||||
func retrieve(s SessionStore, findTokenFns ...func(r *http.Request) string) func(http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
hfn := func(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
token, err := retrieveFromRequest(s, r, findTokenFns...)
|
||||
ctx = NewContext(ctx, token, err)
|
||||
next.ServeHTTP(w, r.WithContext(ctx))
|
||||
}
|
||||
return http.HandlerFunc(hfn)
|
||||
}
|
||||
}
|
||||
|
||||
func retrieveFromRequest(s SessionStore, r *http.Request, findTokenFns ...func(r *http.Request) string) (*State, error) {
|
||||
var tokenStr string
|
||||
var err error
|
||||
|
||||
// Extract token string from the request by calling token find functions in
|
||||
// the order they where provided. Further extraction stops if a function
|
||||
// returns a non-empty string.
|
||||
for _, fn := range findTokenFns {
|
||||
tokenStr = fn(r)
|
||||
if tokenStr != "" {
|
||||
break
|
||||
}
|
||||
}
|
||||
if tokenStr == "" {
|
||||
return nil, ErrNoSessionFound
|
||||
}
|
||||
|
||||
state, err := s.LoadSession(r)
|
||||
if err != nil {
|
||||
return nil, ErrMalformed
|
||||
}
|
||||
err = state.Valid()
|
||||
if err != nil {
|
||||
// a little unusual but we want to return the expired state too
|
||||
return state, err
|
||||
}
|
||||
|
||||
// Valid!
|
||||
return state, nil
|
||||
}
|
||||
|
||||
// NewContext sets context values for the user session state and error.
|
||||
func NewContext(ctx context.Context, t *State, err error) context.Context {
|
||||
ctx = context.WithValue(ctx, SessionCtxKey, t)
|
||||
ctx = context.WithValue(ctx, ErrorCtxKey, err)
|
||||
return ctx
|
||||
}
|
||||
|
||||
// FromContext retrieves context values for the user session state and error.
|
||||
func FromContext(ctx context.Context) (*State, error) {
|
||||
state, _ := ctx.Value(SessionCtxKey).(*State)
|
||||
err, _ := ctx.Value(ErrorCtxKey).(error)
|
||||
return state, err
|
||||
}
|
||||
|
||||
// TokenFromCookie tries to retrieve the token string from a cookie named
|
||||
// "_pomerium".
|
||||
func TokenFromCookie(r *http.Request) string {
|
||||
cookie, err := r.Cookie("_pomerium")
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
return cookie.Value
|
||||
}
|
||||
|
||||
// TokenFromHeader tries to retrieve the token string from the
|
||||
// "Authorization" request header: "Authorization: BEARER T".
|
||||
func TokenFromHeader(r *http.Request) string {
|
||||
// Get token from authorization header.
|
||||
bearer := r.Header.Get("Authorization")
|
||||
if len(bearer) > 7 && strings.EqualFold(bearer[0:6], "BEARER") {
|
||||
return bearer[7:]
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// TokenFromQuery tries to retrieve the token string from the "pomerium_session" URI
|
||||
// query parameter.
|
||||
// todo(bdd) : document setting session code as queryparam
|
||||
func TokenFromQuery(r *http.Request) string {
|
||||
// Get token from query param named "pomerium_session".
|
||||
return r.URL.Query().Get("pomerium_session")
|
||||
}
|
||||
|
||||
// contextKey is a value for use with context.WithValue. It's used as
|
||||
// a pointer so it fits in an interface{} without allocation. This technique
|
||||
// for defining context keys was copied from Go 1.7's new use of context in net/http.
|
||||
type contextKey struct {
|
||||
name string
|
||||
}
|
||||
|
||||
func (k *contextKey) String() string {
|
||||
return "SessionStore context value " + k.name
|
||||
}
|
133
internal/sessions/middleware_test.go
Normal file
133
internal/sessions/middleware_test.go
Normal file
|
@ -0,0 +1,133 @@
|
|||
package sessions
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/cryptutil"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
)
|
||||
|
||||
func TestNewContext(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
ctx context.Context
|
||||
t *State
|
||||
err error
|
||||
want context.Context
|
||||
}{
|
||||
{"simple", context.Background(), &State{Email: "bdd@pomerium.io"}, nil, nil},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
ctxOut := NewContext(tt.ctx, tt.t, tt.err)
|
||||
stateOut, errOut := FromContext(ctxOut)
|
||||
if diff := cmp.Diff(tt.t, stateOut); diff != "" {
|
||||
t.Errorf("NewContext() = %s", diff)
|
||||
}
|
||||
if diff := cmp.Diff(tt.err, errOut); diff != "" {
|
||||
t.Errorf("NewContext() = %s", diff)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func testAuthorizer(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
_, err := FromContext(r.Context())
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
|
||||
func TestVerifier(t *testing.T) {
|
||||
fnh := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "text/plain; charset=utf-8")
|
||||
fmt.Fprint(w, http.StatusText(http.StatusOK))
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
// s SessionStore
|
||||
state State
|
||||
|
||||
cookie bool
|
||||
header bool
|
||||
param bool
|
||||
|
||||
wantBody string
|
||||
wantStatus int
|
||||
}{
|
||||
{"good cookie session", State{Email: "user@pomerium.io", RefreshDeadline: time.Now().Add(10 * time.Second)}, true, false, false, http.StatusText(http.StatusOK), http.StatusOK},
|
||||
{"expired cookie", State{Email: "user@pomerium.io", RefreshDeadline: time.Now().Add(-10 * time.Second)}, true, false, false, "internal/sessions: session is expired\n", http.StatusUnauthorized},
|
||||
{"malformed cookie", State{Email: "user@pomerium.io", RefreshDeadline: time.Now().Add(-10 * time.Second)}, true, false, false, "internal/sessions: session is malformed\n", http.StatusUnauthorized},
|
||||
{"good auth header session", State{Email: "user@pomerium.io", RefreshDeadline: time.Now().Add(10 * time.Second)}, false, true, false, http.StatusText(http.StatusOK), http.StatusOK},
|
||||
{"expired auth header", State{Email: "user@pomerium.io", RefreshDeadline: time.Now().Add(-10 * time.Second)}, false, true, false, "internal/sessions: session is expired\n", http.StatusUnauthorized},
|
||||
{"malformed auth header", State{Email: "user@pomerium.io", RefreshDeadline: time.Now().Add(-10 * time.Second)}, false, true, false, "internal/sessions: session is malformed\n", http.StatusUnauthorized},
|
||||
{"good auth query param session", State{Email: "user@pomerium.io", RefreshDeadline: time.Now().Add(10 * time.Second)}, false, true, true, http.StatusText(http.StatusOK), http.StatusOK},
|
||||
{"expired auth query param", State{Email: "user@pomerium.io", RefreshDeadline: time.Now().Add(-10 * time.Second)}, false, true, true, "internal/sessions: session is expired\n", http.StatusUnauthorized},
|
||||
{"malformed auth query param", State{Email: "user@pomerium.io", RefreshDeadline: time.Now().Add(-10 * time.Second)}, false, true, true, "internal/sessions: session is malformed\n", http.StatusUnauthorized},
|
||||
{"no session", State{Email: "user@pomerium.io", RefreshDeadline: time.Now().Add(-10 * time.Second)}, false, false, false, "internal/sessions: session is not found\n", http.StatusUnauthorized},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
cipher, err := cryptutil.NewCipherFromBase64(cryptutil.NewBase64Key())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
encSession, err := MarshalSession(&tt.state, cipher)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if strings.Contains(tt.name, "malformed") {
|
||||
// add some garbage to the end of the string
|
||||
encSession += cryptutil.NewBase64Key()
|
||||
fmt.Println(encSession)
|
||||
}
|
||||
|
||||
cs, err := NewCookieStore(&CookieStoreOptions{
|
||||
Name: "_pomerium",
|
||||
CookieCipher: cipher,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
r := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
r.Header.Set("Accept", "application/json")
|
||||
w := httptest.NewRecorder()
|
||||
if tt.cookie {
|
||||
r.AddCookie(&http.Cookie{Name: "_pomerium", Value: encSession})
|
||||
} else if tt.header {
|
||||
r.Header.Set("Authorization", "Bearer "+encSession)
|
||||
} else if tt.param {
|
||||
q := r.URL.Query()
|
||||
q.Add("pomerium_session", encSession)
|
||||
r.URL.RawQuery = q.Encode()
|
||||
}
|
||||
|
||||
got := RetrieveSession(cs)(testAuthorizer((fnh)))
|
||||
got.ServeHTTP(w, r)
|
||||
|
||||
gotBody := w.Body.String()
|
||||
gotStatus := w.Result().StatusCode
|
||||
|
||||
if diff := cmp.Diff(gotBody, tt.wantBody); diff != "" {
|
||||
t.Errorf("RetrieveSession() = %v", diff)
|
||||
}
|
||||
if diff := cmp.Diff(gotStatus, tt.wantStatus); diff != "" {
|
||||
t.Errorf("RetrieveSession() = %v", diff)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
|
@ -4,28 +4,6 @@ import (
|
|||
"net/http"
|
||||
)
|
||||
|
||||
// MockCSRFStore is a mock implementation of the CSRF store interface
|
||||
type MockCSRFStore struct {
|
||||
ResponseCSRF string
|
||||
Cookie *http.Cookie
|
||||
GetError error
|
||||
}
|
||||
|
||||
// SetCSRF sets the ResponseCSRF string to a val
|
||||
func (ms MockCSRFStore) SetCSRF(rw http.ResponseWriter, req *http.Request, val string) {
|
||||
ms.ResponseCSRF = val
|
||||
}
|
||||
|
||||
// ClearCSRF clears the ResponseCSRF string
|
||||
func (ms MockCSRFStore) ClearCSRF(http.ResponseWriter, *http.Request) {
|
||||
ms.ResponseCSRF = ""
|
||||
}
|
||||
|
||||
// GetCSRF returns the cookie and error
|
||||
func (ms MockCSRFStore) GetCSRF(*http.Request) (*http.Cookie, error) {
|
||||
return ms.Cookie, ms.GetError
|
||||
}
|
||||
|
||||
// MockSessionStore is a mock implementation of the SessionStore interface
|
||||
type MockSessionStore struct {
|
||||
ResponseSession string
|
||||
|
|
|
@ -10,9 +10,6 @@ import (
|
|||
"github.com/pomerium/pomerium/internal/cryptutil"
|
||||
)
|
||||
|
||||
// ErrExpired is an error for a expired sessions.
|
||||
var ErrExpired = fmt.Errorf("internal/sessions: expired session")
|
||||
|
||||
// State is our object that keeps track of a user's session state
|
||||
type State struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
|
|
|
@ -12,7 +12,7 @@ import (
|
|||
)
|
||||
|
||||
func TestStateSerialization(t *testing.T) {
|
||||
secret := cryptutil.GenerateKey()
|
||||
secret := cryptutil.NewKey()
|
||||
c, err := cryptutil.NewCipher(secret)
|
||||
if err != nil {
|
||||
t.Fatalf("expected to be able to create cipher: %v", err)
|
||||
|
@ -123,7 +123,7 @@ func TestState_Impersonating(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestMarshalSession(t *testing.T) {
|
||||
secret := cryptutil.GenerateKey()
|
||||
secret := cryptutil.NewKey()
|
||||
c, err := cryptutil.NewCipher(secret)
|
||||
if err != nil {
|
||||
t.Fatalf("expected to be able to create cipher: %v", err)
|
||||
|
|
|
@ -8,16 +8,6 @@ import (
|
|||
// ErrEmptySession is an error for an empty sessions.
|
||||
var ErrEmptySession = errors.New("internal/sessions: empty session")
|
||||
|
||||
// ErrEmptyCSRF is an error for an empty sessions.
|
||||
var ErrEmptyCSRF = errors.New("internal/sessions: empty csrf")
|
||||
|
||||
// CSRFStore has the functions for setting, getting, and clearing the CSRF cookie
|
||||
type CSRFStore interface {
|
||||
SetCSRF(http.ResponseWriter, *http.Request, string)
|
||||
GetCSRF(*http.Request) (*http.Cookie, error)
|
||||
ClearCSRF(http.ResponseWriter, *http.Request)
|
||||
}
|
||||
|
||||
// SessionStore has the functions for setting, getting, and clearing the Session cookie
|
||||
type SessionStore interface {
|
||||
ClearSession(http.ResponseWriter, *http.Request)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue