pomerium/internal/sessions/cookie_store_test.go
2019-07-13 18:28:51 -07:00

352 lines
11 KiB
Go

package sessions
import (
"crypto/rand"
"errors"
"fmt"
"net/http"
"net/http/httptest"
"reflect"
"testing"
"time"
"github.com/google/go-cmp/cmp"
"github.com/pomerium/pomerium/internal/cryptutil"
)
type mockCipher struct{}
func (a mockCipher) Encrypt(s []byte) ([]byte, error) {
if string(s) == "error" {
return []byte(""), errors.New("error encrypting")
}
return []byte("OK"), nil
}
func (a mockCipher) Decrypt(s []byte) ([]byte, error) {
if string(s) == "error" {
return []byte(""), errors.New("error encrypting")
}
return []byte("OK"), nil
}
func (a mockCipher) Marshal(s interface{}) (string, error) { return "", errors.New("error") }
func (a mockCipher) Unmarshal(s string, i interface{}) error {
if s == "unmarshal error" || s == "error" {
return errors.New("error")
}
return nil
}
func TestNewCookieStore(t *testing.T) {
cipher, err := cryptutil.NewCipher(cryptutil.GenerateKey())
if err != nil {
t.Fatal(err)
}
tests := []struct {
name string
opts *CookieStoreOptions
want *CookieStore
wantErr bool
}{
{"good",
&CookieStoreOptions{
Name: "_cookie",
CookieSecure: true,
CookieHTTPOnly: true,
CookieDomain: "pomerium.io",
CookieExpire: 10 * time.Second,
CookieCipher: cipher,
},
&CookieStore{
Name: "_cookie",
CookieSecure: true,
CookieHTTPOnly: true,
CookieDomain: "pomerium.io",
CookieExpire: 10 * time.Second,
CookieCipher: cipher,
},
false},
{"missing name",
&CookieStoreOptions{
Name: "",
CookieSecure: true,
CookieHTTPOnly: true,
CookieDomain: "pomerium.io",
CookieExpire: 10 * time.Second,
CookieCipher: cipher,
},
nil,
true},
{"missing cipher",
&CookieStoreOptions{
Name: "_pomerium",
CookieSecure: true,
CookieHTTPOnly: true,
CookieDomain: "pomerium.io",
CookieExpire: 10 * time.Second,
CookieCipher: nil,
},
nil,
true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := NewCookieStore(tt.opts)
if (err != nil) != tt.wantErr {
t.Errorf("NewCookieStore() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("NewCookieStore() = %#v, want %#v", got, tt.want)
}
})
}
}
func TestCookieStore_makeCookie(t *testing.T) {
cipher, err := cryptutil.NewCipher(cryptutil.GenerateKey())
if err != nil {
t.Fatal(err)
}
now := time.Now()
tests := []struct {
name string
domain string
cookieDomain string
cookieName string
value string
expiration time.Duration
want *http.Cookie
wantCSRF *http.Cookie
}{
{"good", "http://httpbin.corp.pomerium.io", "", "_pomerium", "value", 0, &http.Cookie{Name: "_pomerium", Value: "value", Path: "/", Domain: "corp.pomerium.io", Secure: true, HttpOnly: true}, &http.Cookie{Name: "_pomerium_csrf", Value: "value", Path: "/", Domain: "httpbin.corp.pomerium.io", Secure: true, HttpOnly: true}},
{"domains with https", "https://httpbin.corp.pomerium.io", "", "_pomerium", "value", 0, &http.Cookie{Name: "_pomerium", Value: "value", Path: "/", Domain: "corp.pomerium.io", Secure: true, HttpOnly: true}, &http.Cookie{Name: "_pomerium_csrf", Value: "value", Path: "/", Domain: "httpbin.corp.pomerium.io", Secure: true, HttpOnly: true}},
{"domain with port", "http://httpbin.corp.pomerium.io:443", "", "_pomerium", "value", 0, &http.Cookie{Name: "_pomerium", Value: "value", Path: "/", Domain: "corp.pomerium.io", Secure: true, HttpOnly: true}, &http.Cookie{Name: "_pomerium_csrf", Value: "value", Path: "/", Domain: "httpbin.corp.pomerium.io", Secure: true, HttpOnly: true}},
{"expiration set", "http://httpbin.corp.pomerium.io:443", "", "_pomerium", "value", 10 * time.Second, &http.Cookie{Expires: now.Add(10 * time.Second), Name: "_pomerium", Value: "value", Path: "/", Domain: "corp.pomerium.io", Secure: true, HttpOnly: true}, &http.Cookie{Expires: now.Add(10 * time.Second), Name: "_pomerium_csrf", Value: "value", Path: "/", Domain: "httpbin.corp.pomerium.io", Secure: true, HttpOnly: true}},
{"good", "http://httpbin.corp.pomerium.io", "pomerium.io", "_pomerium", "value", 0, &http.Cookie{Name: "_pomerium", Value: "value", Path: "/", Domain: "pomerium.io", Secure: true, HttpOnly: true}, &http.Cookie{Name: "_pomerium_csrf", Value: "value", Path: "/", Domain: "httpbin.corp.pomerium.io", Secure: true, HttpOnly: true}},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
r := httptest.NewRequest("GET", tt.domain, nil)
s, err := NewCookieStore(
&CookieStoreOptions{
Name: "_pomerium",
CookieSecure: true,
CookieHTTPOnly: true,
CookieDomain: tt.cookieDomain,
CookieExpire: 10 * time.Second,
CookieCipher: cipher})
if err != nil {
t.Fatal(err)
}
if diff := cmp.Diff(s.makeCookie(r, tt.cookieName, tt.value, tt.expiration, now), tt.want); diff != "" {
t.Errorf("CookieStore.makeCookie() = \n%s", diff)
}
if diff := cmp.Diff(s.makeSessionCookie(r, tt.value, tt.expiration, now), tt.want); diff != "" {
t.Errorf("CookieStore.makeSessionCookie() = \n%s", diff)
}
got := s.makeCSRFCookie(r, tt.value, tt.expiration, now)
tt.wantCSRF.Name = "_pomerium_csrf"
if !reflect.DeepEqual(got, tt.wantCSRF) {
t.Errorf("CookieStore.makeCookie() = \n%#v, \nwant\n%#v", got, tt.wantCSRF)
}
w := httptest.NewRecorder()
want := "new-csrf"
s.SetCSRF(w, r, want)
found := false
for _, cookie := range w.Result().Cookies() {
if cookie.Name == s.Name+"_csrf" && cookie.Value == want {
found = true
break
}
}
if !found {
t.Error("SetCSRF failed")
}
w = httptest.NewRecorder()
s.ClearCSRF(w, r)
for _, cookie := range w.Result().Cookies() {
if cookie.Name == s.Name+"_csrf" && cookie.Value == want {
t.Error("clear csrf failed")
break
}
}
w = httptest.NewRecorder()
want = "new-session"
s.setSessionCookie(w, r, want)
found = false
for _, cookie := range w.Result().Cookies() {
if cookie.Name == s.Name && cookie.Value == want {
found = true
break
}
}
if !found {
t.Error("SetCSRF failed")
}
w = httptest.NewRecorder()
s.ClearSession(w, r)
for _, cookie := range w.Result().Cookies() {
if cookie.Name == s.Name && cookie.Value == want {
t.Error("clear csrf failed")
break
}
}
})
}
}
func TestCookieStore_SaveSession(t *testing.T) {
cipher, err := cryptutil.NewCipher(cryptutil.GenerateKey())
if err != nil {
t.Fatal(err)
}
hugeString := make([]byte, 4097)
if _, err := rand.Read(hugeString); err != nil {
t.Fatal(err)
}
tests := []struct {
name string
sessionState *SessionState
cipher cryptutil.Cipher
wantErr bool
wantLoadErr bool
}{
{"good", &SessionState{AccessToken: "token1234", RefreshToken: "refresh4321", RefreshDeadline: time.Now().Add(1 * time.Hour).Truncate(time.Second).UTC(), Email: "user@domain.com", User: "user"}, cipher, false, false},
{"bad cipher", &SessionState{AccessToken: "token1234", RefreshToken: "refresh4321", RefreshDeadline: time.Now().Add(1 * time.Hour).Truncate(time.Second).UTC(), Email: "user@domain.com", User: "user"}, mockCipher{}, true, true},
{"huge cookie", &SessionState{AccessToken: fmt.Sprintf("%x", hugeString), RefreshToken: "refresh4321", RefreshDeadline: time.Now().Add(1 * time.Hour).Truncate(time.Second).UTC(), Email: "user@domain.com", User: "user"}, cipher, false, false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
s := &CookieStore{
Name: "_pomerium",
CookieSecure: true,
CookieHTTPOnly: true,
CookieDomain: "pomerium.io",
CookieExpire: 10 * time.Second,
CookieCipher: tt.cipher}
r := httptest.NewRequest("GET", "/", nil)
w := httptest.NewRecorder()
if err := s.SaveSession(w, r, tt.sessionState); (err != nil) != tt.wantErr {
t.Errorf("CookieStore.SaveSession() error = %v, wantErr %v", err, tt.wantErr)
}
r = httptest.NewRequest("GET", "/", nil)
for _, cookie := range w.Result().Cookies() {
t.Log(cookie)
r.AddCookie(cookie)
}
state, err := s.LoadSession(r)
if (err != nil) != tt.wantLoadErr {
t.Errorf("LoadSession() error = %v, wantErr %v", err, tt.wantLoadErr)
return
}
if err == nil && !reflect.DeepEqual(state, tt.sessionState) {
t.Errorf("CookieStore.LoadSession() got = \n%v, want \n%v", state, tt.sessionState)
}
})
}
}
func TestMockCSRFStore(t *testing.T) {
tests := []struct {
name string
mockCSRF *MockCSRFStore
newCSRFValue string
wantErr bool
}{
{"basic",
&MockCSRFStore{
ResponseCSRF: "ok",
Cookie: &http.Cookie{Name: "hi"}},
"newcsrf",
false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ms := tt.mockCSRF
ms.SetCSRF(nil, nil, tt.newCSRFValue)
ms.ClearCSRF(nil, nil)
got, err := ms.GetCSRF(nil)
if (err != nil) != tt.wantErr {
t.Errorf("MockCSRFStore.GetCSRF() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !reflect.DeepEqual(got, tt.mockCSRF.Cookie) {
t.Errorf("MockCSRFStore.GetCSRF() = %v, want %v", got, tt.mockCSRF.Cookie)
}
})
}
}
func TestMockSessionStore(t *testing.T) {
tests := []struct {
name string
mockCSRF *MockSessionStore
saveSession *SessionState
wantLoadErr bool
wantSaveErr bool
}{
{"basic",
&MockSessionStore{
ResponseSession: "test",
Session: &SessionState{AccessToken: "AccessToken"},
SaveError: nil,
LoadError: nil,
},
&SessionState{AccessToken: "AccessToken"},
false,
false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ms := tt.mockCSRF
err := ms.SaveSession(nil, nil, tt.saveSession)
if (err != nil) != tt.wantSaveErr {
t.Errorf("MockCSRFStore.GetCSRF() error = %v, wantSaveErr %v", err, tt.wantSaveErr)
return
}
got, err := ms.LoadSession(nil)
if (err != nil) != tt.wantLoadErr {
t.Errorf("MockCSRFStore.GetCSRF() error = %v, wantLoadErr %v", err, tt.wantLoadErr)
return
}
if !reflect.DeepEqual(got, tt.mockCSRF.Session) {
t.Errorf("MockCSRFStore.GetCSRF() = %v, want %v", got, tt.mockCSRF.Session)
}
ms.ClearSession(nil, nil)
if ms.ResponseSession != "" {
t.Errorf("ResponseSession not empty! %s", ms.ResponseSession)
}
})
}
}
func Test_splitDomain(t *testing.T) {
t.Parallel()
tests := []struct {
s string
want string
}{
{"httpbin.corp.example.com", "corp.example.com"},
{"some.httpbin.corp.example.com", "httpbin.corp.example.com"},
{"example.com", ""},
{"", ""},
}
for _, tt := range tests {
t.Run(tt.s, func(t *testing.T) {
if got := splitDomain(tt.s); got != tt.want {
t.Errorf("splitDomain() = %v, want %v", got, tt.want)
}
})
}
}