mirror of
https://github.com/pomerium/pomerium.git
synced 2025-04-29 18:36:30 +02:00
352 lines
11 KiB
Go
package sessions
|
|
|
|
import (
|
|
"crypto/rand"
|
|
"errors"
|
|
"fmt"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"reflect"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/google/go-cmp/cmp"
|
|
"github.com/pomerium/pomerium/internal/cryptutil"
|
|
)
|
|
|
|
// mockCipher is a test double for the cookie cipher. It returns canned
// results: Encrypt/Decrypt succeed with "OK" unless the input is exactly
// "error", Marshal always fails, and Unmarshal fails for the inputs
// "unmarshal error" and "error".
type mockCipher struct{}

// Encrypt returns "OK", or an error when the plaintext is the literal "error".
func (mockCipher) Encrypt(s []byte) ([]byte, error) {
	if string(s) == "error" {
		return []byte(""), errors.New("error encrypting")
	}
	return []byte("OK"), nil
}

// Decrypt returns "OK", or an error when the ciphertext is the literal "error".
func (mockCipher) Decrypt(s []byte) ([]byte, error) {
	if string(s) == "error" {
		// Report this as a decryption failure; it previously said "encrypting",
		// which made failing decrypt paths look like encrypt failures.
		return []byte(""), errors.New("error decrypting")
	}
	return []byte("OK"), nil
}

// Marshal always fails, so callers' marshal-error paths can be exercised.
func (mockCipher) Marshal(s interface{}) (string, error) { return "", errors.New("error") }

// Unmarshal fails for "unmarshal error" and "error" and otherwise succeeds
// without touching i.
func (mockCipher) Unmarshal(s string, i interface{}) error {
	if s == "unmarshal error" || s == "error" {
		return errors.New("error")
	}
	return nil
}
|
|
func TestNewCookieStore(t *testing.T) {
|
|
cipher, err := cryptutil.NewCipher(cryptutil.GenerateKey())
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
tests := []struct {
|
|
name string
|
|
opts *CookieStoreOptions
|
|
want *CookieStore
|
|
wantErr bool
|
|
}{
|
|
{"good",
|
|
&CookieStoreOptions{
|
|
Name: "_cookie",
|
|
CookieSecure: true,
|
|
CookieHTTPOnly: true,
|
|
CookieDomain: "pomerium.io",
|
|
CookieExpire: 10 * time.Second,
|
|
CookieCipher: cipher,
|
|
},
|
|
&CookieStore{
|
|
Name: "_cookie",
|
|
CookieSecure: true,
|
|
CookieHTTPOnly: true,
|
|
CookieDomain: "pomerium.io",
|
|
CookieExpire: 10 * time.Second,
|
|
CookieCipher: cipher,
|
|
},
|
|
false},
|
|
{"missing name",
|
|
&CookieStoreOptions{
|
|
Name: "",
|
|
CookieSecure: true,
|
|
CookieHTTPOnly: true,
|
|
CookieDomain: "pomerium.io",
|
|
CookieExpire: 10 * time.Second,
|
|
CookieCipher: cipher,
|
|
},
|
|
nil,
|
|
true},
|
|
{"missing cipher",
|
|
&CookieStoreOptions{
|
|
Name: "_pomerium",
|
|
CookieSecure: true,
|
|
CookieHTTPOnly: true,
|
|
CookieDomain: "pomerium.io",
|
|
CookieExpire: 10 * time.Second,
|
|
CookieCipher: nil,
|
|
},
|
|
nil,
|
|
true},
|
|
}
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
got, err := NewCookieStore(tt.opts)
|
|
if (err != nil) != tt.wantErr {
|
|
t.Errorf("NewCookieStore() error = %v, wantErr %v", err, tt.wantErr)
|
|
return
|
|
}
|
|
if !reflect.DeepEqual(got, tt.want) {
|
|
t.Errorf("NewCookieStore() = %#v, want %#v", got, tt.want)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestCookieStore_makeCookie(t *testing.T) {
|
|
cipher, err := cryptutil.NewCipher(cryptutil.GenerateKey())
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
now := time.Now()
|
|
tests := []struct {
|
|
name string
|
|
domain string
|
|
|
|
cookieDomain string
|
|
cookieName string
|
|
value string
|
|
expiration time.Duration
|
|
want *http.Cookie
|
|
wantCSRF *http.Cookie
|
|
}{
|
|
{"good", "http://httpbin.corp.pomerium.io", "", "_pomerium", "value", 0, &http.Cookie{Name: "_pomerium", Value: "value", Path: "/", Domain: "corp.pomerium.io", Secure: true, HttpOnly: true}, &http.Cookie{Name: "_pomerium_csrf", Value: "value", Path: "/", Domain: "httpbin.corp.pomerium.io", Secure: true, HttpOnly: true}},
|
|
{"domains with https", "https://httpbin.corp.pomerium.io", "", "_pomerium", "value", 0, &http.Cookie{Name: "_pomerium", Value: "value", Path: "/", Domain: "corp.pomerium.io", Secure: true, HttpOnly: true}, &http.Cookie{Name: "_pomerium_csrf", Value: "value", Path: "/", Domain: "httpbin.corp.pomerium.io", Secure: true, HttpOnly: true}},
|
|
{"domain with port", "http://httpbin.corp.pomerium.io:443", "", "_pomerium", "value", 0, &http.Cookie{Name: "_pomerium", Value: "value", Path: "/", Domain: "corp.pomerium.io", Secure: true, HttpOnly: true}, &http.Cookie{Name: "_pomerium_csrf", Value: "value", Path: "/", Domain: "httpbin.corp.pomerium.io", Secure: true, HttpOnly: true}},
|
|
{"expiration set", "http://httpbin.corp.pomerium.io:443", "", "_pomerium", "value", 10 * time.Second, &http.Cookie{Expires: now.Add(10 * time.Second), Name: "_pomerium", Value: "value", Path: "/", Domain: "corp.pomerium.io", Secure: true, HttpOnly: true}, &http.Cookie{Expires: now.Add(10 * time.Second), Name: "_pomerium_csrf", Value: "value", Path: "/", Domain: "httpbin.corp.pomerium.io", Secure: true, HttpOnly: true}},
|
|
{"good", "http://httpbin.corp.pomerium.io", "pomerium.io", "_pomerium", "value", 0, &http.Cookie{Name: "_pomerium", Value: "value", Path: "/", Domain: "pomerium.io", Secure: true, HttpOnly: true}, &http.Cookie{Name: "_pomerium_csrf", Value: "value", Path: "/", Domain: "httpbin.corp.pomerium.io", Secure: true, HttpOnly: true}},
|
|
}
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
r := httptest.NewRequest("GET", tt.domain, nil)
|
|
|
|
s, err := NewCookieStore(
|
|
&CookieStoreOptions{
|
|
Name: "_pomerium",
|
|
CookieSecure: true,
|
|
CookieHTTPOnly: true,
|
|
CookieDomain: tt.cookieDomain,
|
|
CookieExpire: 10 * time.Second,
|
|
CookieCipher: cipher})
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if diff := cmp.Diff(s.makeCookie(r, tt.cookieName, tt.value, tt.expiration, now), tt.want); diff != "" {
|
|
t.Errorf("CookieStore.makeCookie() = \n%s", diff)
|
|
}
|
|
if diff := cmp.Diff(s.makeSessionCookie(r, tt.value, tt.expiration, now), tt.want); diff != "" {
|
|
t.Errorf("CookieStore.makeSessionCookie() = \n%s", diff)
|
|
}
|
|
got := s.makeCSRFCookie(r, tt.value, tt.expiration, now)
|
|
tt.wantCSRF.Name = "_pomerium_csrf"
|
|
if !reflect.DeepEqual(got, tt.wantCSRF) {
|
|
t.Errorf("CookieStore.makeCookie() = \n%#v, \nwant\n%#v", got, tt.wantCSRF)
|
|
}
|
|
w := httptest.NewRecorder()
|
|
want := "new-csrf"
|
|
s.SetCSRF(w, r, want)
|
|
found := false
|
|
for _, cookie := range w.Result().Cookies() {
|
|
if cookie.Name == s.Name+"_csrf" && cookie.Value == want {
|
|
found = true
|
|
break
|
|
}
|
|
}
|
|
if !found {
|
|
t.Error("SetCSRF failed")
|
|
}
|
|
|
|
w = httptest.NewRecorder()
|
|
s.ClearCSRF(w, r)
|
|
for _, cookie := range w.Result().Cookies() {
|
|
if cookie.Name == s.Name+"_csrf" && cookie.Value == want {
|
|
t.Error("clear csrf failed")
|
|
break
|
|
|
|
}
|
|
}
|
|
w = httptest.NewRecorder()
|
|
want = "new-session"
|
|
s.setSessionCookie(w, r, want)
|
|
found = false
|
|
for _, cookie := range w.Result().Cookies() {
|
|
if cookie.Name == s.Name && cookie.Value == want {
|
|
found = true
|
|
break
|
|
}
|
|
}
|
|
if !found {
|
|
t.Error("SetCSRF failed")
|
|
}
|
|
w = httptest.NewRecorder()
|
|
s.ClearSession(w, r)
|
|
for _, cookie := range w.Result().Cookies() {
|
|
if cookie.Name == s.Name && cookie.Value == want {
|
|
t.Error("clear csrf failed")
|
|
break
|
|
}
|
|
}
|
|
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestCookieStore_SaveSession(t *testing.T) {
|
|
cipher, err := cryptutil.NewCipher(cryptutil.GenerateKey())
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
hugeString := make([]byte, 4097)
|
|
if _, err := rand.Read(hugeString); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
tests := []struct {
|
|
name string
|
|
sessionState *SessionState
|
|
cipher cryptutil.Cipher
|
|
wantErr bool
|
|
wantLoadErr bool
|
|
}{
|
|
{"good", &SessionState{AccessToken: "token1234", RefreshToken: "refresh4321", RefreshDeadline: time.Now().Add(1 * time.Hour).Truncate(time.Second).UTC(), Email: "user@domain.com", User: "user"}, cipher, false, false},
|
|
{"bad cipher", &SessionState{AccessToken: "token1234", RefreshToken: "refresh4321", RefreshDeadline: time.Now().Add(1 * time.Hour).Truncate(time.Second).UTC(), Email: "user@domain.com", User: "user"}, mockCipher{}, true, true},
|
|
{"huge cookie", &SessionState{AccessToken: fmt.Sprintf("%x", hugeString), RefreshToken: "refresh4321", RefreshDeadline: time.Now().Add(1 * time.Hour).Truncate(time.Second).UTC(), Email: "user@domain.com", User: "user"}, cipher, false, false},
|
|
}
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
s := &CookieStore{
|
|
Name: "_pomerium",
|
|
CookieSecure: true,
|
|
CookieHTTPOnly: true,
|
|
CookieDomain: "pomerium.io",
|
|
CookieExpire: 10 * time.Second,
|
|
CookieCipher: tt.cipher}
|
|
|
|
r := httptest.NewRequest("GET", "/", nil)
|
|
w := httptest.NewRecorder()
|
|
|
|
if err := s.SaveSession(w, r, tt.sessionState); (err != nil) != tt.wantErr {
|
|
t.Errorf("CookieStore.SaveSession() error = %v, wantErr %v", err, tt.wantErr)
|
|
}
|
|
r = httptest.NewRequest("GET", "/", nil)
|
|
for _, cookie := range w.Result().Cookies() {
|
|
t.Log(cookie)
|
|
r.AddCookie(cookie)
|
|
}
|
|
|
|
state, err := s.LoadSession(r)
|
|
if (err != nil) != tt.wantLoadErr {
|
|
t.Errorf("LoadSession() error = %v, wantErr %v", err, tt.wantLoadErr)
|
|
return
|
|
}
|
|
if err == nil && !reflect.DeepEqual(state, tt.sessionState) {
|
|
t.Errorf("CookieStore.LoadSession() got = \n%v, want \n%v", state, tt.sessionState)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestMockCSRFStore(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
mockCSRF *MockCSRFStore
|
|
newCSRFValue string
|
|
wantErr bool
|
|
}{
|
|
{"basic",
|
|
&MockCSRFStore{
|
|
ResponseCSRF: "ok",
|
|
Cookie: &http.Cookie{Name: "hi"}},
|
|
"newcsrf",
|
|
false},
|
|
}
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
ms := tt.mockCSRF
|
|
ms.SetCSRF(nil, nil, tt.newCSRFValue)
|
|
ms.ClearCSRF(nil, nil)
|
|
got, err := ms.GetCSRF(nil)
|
|
if (err != nil) != tt.wantErr {
|
|
t.Errorf("MockCSRFStore.GetCSRF() error = %v, wantErr %v", err, tt.wantErr)
|
|
return
|
|
}
|
|
if !reflect.DeepEqual(got, tt.mockCSRF.Cookie) {
|
|
t.Errorf("MockCSRFStore.GetCSRF() = %v, want %v", got, tt.mockCSRF.Cookie)
|
|
}
|
|
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestMockSessionStore(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
mockCSRF *MockSessionStore
|
|
saveSession *SessionState
|
|
wantLoadErr bool
|
|
wantSaveErr bool
|
|
}{
|
|
{"basic",
|
|
&MockSessionStore{
|
|
ResponseSession: "test",
|
|
Session: &SessionState{AccessToken: "AccessToken"},
|
|
SaveError: nil,
|
|
LoadError: nil,
|
|
},
|
|
&SessionState{AccessToken: "AccessToken"},
|
|
false,
|
|
false},
|
|
}
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
ms := tt.mockCSRF
|
|
|
|
err := ms.SaveSession(nil, nil, tt.saveSession)
|
|
if (err != nil) != tt.wantSaveErr {
|
|
t.Errorf("MockCSRFStore.GetCSRF() error = %v, wantSaveErr %v", err, tt.wantSaveErr)
|
|
return
|
|
}
|
|
got, err := ms.LoadSession(nil)
|
|
if (err != nil) != tt.wantLoadErr {
|
|
t.Errorf("MockCSRFStore.GetCSRF() error = %v, wantLoadErr %v", err, tt.wantLoadErr)
|
|
return
|
|
}
|
|
if !reflect.DeepEqual(got, tt.mockCSRF.Session) {
|
|
t.Errorf("MockCSRFStore.GetCSRF() = %v, want %v", got, tt.mockCSRF.Session)
|
|
}
|
|
ms.ClearSession(nil, nil)
|
|
if ms.ResponseSession != "" {
|
|
t.Errorf("ResponseSession not empty! %s", ms.ResponseSession)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func Test_splitDomain(t *testing.T) {
|
|
t.Parallel()
|
|
tests := []struct {
|
|
s string
|
|
want string
|
|
}{
|
|
{"httpbin.corp.example.com", "corp.example.com"},
|
|
{"some.httpbin.corp.example.com", "httpbin.corp.example.com"},
|
|
{"example.com", ""},
|
|
{"", ""},
|
|
}
|
|
for _, tt := range tests {
|
|
t.Run(tt.s, func(t *testing.T) {
|
|
if got := splitDomain(tt.s); got != tt.want {
|
|
t.Errorf("splitDomain() = %v, want %v", got, tt.want)
|
|
}
|
|
})
|
|
}
|
|
}
|