pomerium/internal/sessions/rest_store_test.go
2019-07-13 18:28:51 -07:00

135 lines
3.5 KiB
Go

package sessions
import (
"errors"
"fmt"
"io/ioutil"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/google/go-cmp/cmp"
"github.com/pomerium/pomerium/internal/cryptutil"
)
// TestRestStore_SaveSession exercises RestStore.SaveSession against a
// recorded HTTP response. Each case builds a store through NewRestStore,
// saves a session, and then compares both the returned error and the bytes
// written to the response with the expectations in the table.
func TestRestStore_SaveSession(t *testing.T) {
	// Fixed timestamp so the expected "Expiry" value in the response is stable.
	now := time.Date(2008, 1, 8, 17, 5, 5, 0, time.UTC)
	tests := []struct {
		name             string
		optionsName      string
		optionsCipher    cryptutil.Cipher
		sessionState     *SessionState
		wantErr          bool
		wantSaveResponse string
	}{
		{
			name:             "good",
			optionsName:      "Authenticate",
			optionsCipher:    &cryptutil.MockCipher{MarshalResponse: "test"},
			sessionState:     &SessionState{RefreshDeadline: now},
			wantErr:          false,
			wantSaveResponse: `{"Token":"test","Expiry":"2008-01-08T17:05:05Z"}`,
		},
		{
			name:             "bad session marshal",
			optionsName:      "Authenticate",
			optionsCipher:    &cryptutil.MockCipher{MarshalError: errors.New("error")},
			sessionState:     &SessionState{RefreshDeadline: now},
			wantErr:          true,
			wantSaveResponse: "",
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			s, err := NewRestStore(
				&RestStoreOptions{
					Name:   tt.optionsName,
					Cipher: tt.optionsCipher,
				})
			if err != nil {
				t.Fatalf("NewRestStore err %v", err)
			}
			r := httptest.NewRequest(http.MethodGet, "/", nil)
			w := httptest.NewRecorder()
			if err := s.SaveSession(w, r, tt.sessionState); (err != nil) != tt.wantErr {
				t.Errorf("RestStore.SaveSession() error = %v, wantErr %v", err, tt.wantErr)
			}
			resp := w.Result()
			// Always release the response body, as with any http.Response.
			defer resp.Body.Close()
			// Surface a read failure instead of silently comparing a
			// truncated or empty body against the expectation.
			body, err := ioutil.ReadAll(resp.Body)
			if err != nil {
				t.Fatalf("reading response body: %v", err)
			}
			if diff := cmp.Diff(string(body), tt.wantSaveResponse); diff != "" {
				t.Errorf("RestStore.SaveSession() got / want diff \n%s\n", diff)
			}
		})
	}
}
// TestNewRestStore checks which option combinations NewRestStore accepts.
// Per the table, a nil cipher is expected to be rejected, while an empty
// name is expected to be accepted.
func TestNewRestStore(t *testing.T) {
	cases := []struct {
		name          string
		optionsName   string
		optionsCipher cryptutil.Cipher
		wantErr       bool
	}{
		{
			name:          "good",
			optionsName:   "Authenticate",
			optionsCipher: &cryptutil.MockCipher{},
			wantErr:       false,
		},
		{
			name:          "good default to authenticate",
			optionsName:   "",
			optionsCipher: &cryptutil.MockCipher{},
			wantErr:       false,
		},
		{
			name:          "empty cipher",
			optionsName:   "Authenticate",
			optionsCipher: nil,
			wantErr:       true,
		},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			opts := &RestStoreOptions{
				Name:   tc.optionsName,
				Cipher: tc.optionsCipher,
			}
			// Only the error outcome is under test; the store itself is discarded.
			_, err := NewRestStore(opts)
			gotErr := err != nil
			if gotErr != tc.wantErr {
				t.Errorf("NewRestStore() error = %v, wantErr %v", err, tc.wantErr)
			}
		})
	}
}
// TestRestStore_ClearSession calls RestStore.ClearSession against a recorded
// response and asserts the HTTP status code that ends up on that response.
func TestRestStore_ClearSession(t *testing.T) {
	tests := []struct {
		name           string
		expectedStatus int
	}{
		{name: "always returns reset!", expectedStatus: http.StatusUnauthorized},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			s := &RestStore{Name: "Authenticate", Cipher: &cryptutil.MockCipher{}}
			r := httptest.NewRequest(http.MethodGet, "/", nil)
			w := httptest.NewRecorder()
			s.ClearSession(w, r)
			resp := w.Result()
			// Always release the response body, as with any http.Response.
			defer resp.Body.Close()
			if diff := cmp.Diff(resp.StatusCode, tt.expectedStatus); diff != "" {
				t.Errorf("RestStore.ClearSession() got / want diff \n%s\n", diff)
			}
		})
	}
}
// TestRestStore_LoadSession checks the error outcome of RestStore.LoadSession.
// When a case names a header, the request carries "Bearer <token>" under that
// header; when the name is empty, the request is sent with no such header.
func TestRestStore_LoadSession(t *testing.T) {
	cases := []struct {
		name          string
		optionsName   string
		optionsCipher cryptutil.Cipher
		token         string
		wantErr       bool
	}{
		{
			name:          "good",
			optionsName:   "Authorization",
			optionsCipher: &cryptutil.MockCipher{},
			token:         "test",
			wantErr:       false,
		},
		{
			name:          "empty auth header",
			optionsName:   "",
			optionsCipher: &cryptutil.MockCipher{},
			token:         "",
			wantErr:       true,
		},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			store := &RestStore{
				Name:   tc.optionsName,
				Cipher: tc.optionsCipher,
			}
			req := httptest.NewRequest(http.MethodGet, "/", nil)
			if tc.optionsName != "" {
				bearer := fmt.Sprintf("Bearer %s", tc.token)
				req.Header.Set(tc.optionsName, bearer)
			}
			// Only the error outcome is under test; the loaded session is discarded.
			_, err := store.LoadSession(req)
			gotErr := err != nil
			if gotErr != tc.wantErr {
				t.Errorf("RestStore.LoadSession() error = %v, wantErr %v", err, tc.wantErr)
			}
		})
	}
}