mirror of
https://github.com/pomerium/pomerium.git
synced 2025-06-04 03:42:49 +02:00
internal/sessions: refactor how sessions loading (#351)
These chagnes standardize how session loading is done for session cookie, auth bearer token, and query params. - Bearer token previously combined with session cookie. - rearranged cookie-store to put exported methods above unexported - added header store that implements session loader interface - added query param store that implements session loader interface Signed-off-by: Bobby DeSimone <bobbydesimone@gmail.com>
This commit is contained in:
parent
7aa4621b1b
commit
badd8d69af
13 changed files with 322 additions and 234 deletions
|
@ -9,9 +9,8 @@ import (
|
|||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/cryptutil"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/pomerium/pomerium/internal/cryptutil"
|
||||
)
|
||||
|
||||
func TestNewContext(t *testing.T) {
|
||||
|
@ -75,8 +74,8 @@ func TestVerifier(t *testing.T) {
|
|||
{"expired auth header", State{Email: "user@pomerium.io", RefreshDeadline: time.Now().Add(-10 * time.Second)}, false, true, false, "internal/sessions: session is expired\n", http.StatusUnauthorized},
|
||||
{"malformed auth header", State{Email: "user@pomerium.io", RefreshDeadline: time.Now().Add(-10 * time.Second)}, false, true, false, "internal/sessions: session is malformed\n", http.StatusUnauthorized},
|
||||
{"good auth query param session", State{Email: "user@pomerium.io", RefreshDeadline: time.Now().Add(10 * time.Second)}, false, true, true, http.StatusText(http.StatusOK), http.StatusOK},
|
||||
{"expired auth query param", State{Email: "user@pomerium.io", RefreshDeadline: time.Now().Add(-10 * time.Second)}, false, true, true, "internal/sessions: session is expired\n", http.StatusUnauthorized},
|
||||
{"malformed auth query param", State{Email: "user@pomerium.io", RefreshDeadline: time.Now().Add(-10 * time.Second)}, false, true, true, "internal/sessions: session is malformed\n", http.StatusUnauthorized},
|
||||
{"expired auth query param", State{Email: "user@pomerium.io", RefreshDeadline: time.Now().Add(-10 * time.Second)}, false, false, true, "internal/sessions: session is expired\n", http.StatusUnauthorized},
|
||||
{"malformed auth query param", State{Email: "user@pomerium.io", RefreshDeadline: time.Now().Add(-10 * time.Second)}, false, false, true, "internal/sessions: session is malformed\n", http.StatusUnauthorized},
|
||||
{"no session", State{Email: "user@pomerium.io", RefreshDeadline: time.Now().Add(-10 * time.Second)}, false, false, false, "internal/sessions: session is not found\n", http.StatusUnauthorized},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
|
@ -94,7 +93,6 @@ func TestVerifier(t *testing.T) {
|
|||
if strings.Contains(tt.name, "malformed") {
|
||||
// add some garbage to the end of the string
|
||||
encSession += cryptutil.NewBase64Key()
|
||||
fmt.Println(encSession)
|
||||
}
|
||||
|
||||
cs, err := NewCookieStore(&CookieStoreOptions{
|
||||
|
@ -104,6 +102,9 @@ func TestVerifier(t *testing.T) {
|
|||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
as := NewHeaderStore(encoder)
|
||||
|
||||
qp := NewQueryParamStore(encoder)
|
||||
|
||||
r := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
r.Header.Set("Accept", "application/json")
|
||||
|
@ -114,11 +115,12 @@ func TestVerifier(t *testing.T) {
|
|||
r.Header.Set("Authorization", "Bearer "+encSession)
|
||||
} else if tt.param {
|
||||
q := r.URL.Query()
|
||||
q.Add("pomerium_session", encSession)
|
||||
|
||||
q.Set("pomerium_session", encSession)
|
||||
r.URL.RawQuery = q.Encode()
|
||||
}
|
||||
|
||||
got := RetrieveSession(cs)(testAuthorizer((fnh)))
|
||||
got := RetrieveSession(cs, as, qp)(testAuthorizer((fnh)))
|
||||
got.ServeHTTP(w, r)
|
||||
|
||||
gotBody := w.Body.String()
|
||||
|
@ -133,3 +135,23 @@ func TestVerifier(t *testing.T) {
|
|||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_contextKey_String(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
keyName string
|
||||
want string
|
||||
}{
|
||||
{"simple example", "test", "context value test"},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
k := &contextKey{
|
||||
name: tt.keyName,
|
||||
}
|
||||
if got := k.String(); got != tt.want {
|
||||
t.Errorf("contextKey.String() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue