mirror of
https://github.com/pomerium/pomerium.git
synced 2025-04-30 19:06:33 +02:00
Add a new Stateful type implementing the stateful authentication flow from Pomerium v0.20 and earlier. This consists mainly of logic from authenticate/handlers.go prior to commits 57217af and 539fd51. One significant change is to set the default IdP ID when an IdP ID is not provided in the request URL (e.g. when signing in directly at the authenticate service domain). Otherwise, if session state is stored with an empty IdP ID, it won't be valid for any route.
271 lines
12 KiB
Go
271 lines
12 KiB
Go
package authenticateflow
|
|
|
|
import (
|
|
"encoding/base64"
|
|
"errors"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"net/url"
|
|
"strings"
|
|
"testing"
|
|
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/require"
|
|
|
|
"github.com/pomerium/pomerium/config"
|
|
"github.com/pomerium/pomerium/internal/encoding"
|
|
"github.com/pomerium/pomerium/internal/encoding/mock"
|
|
"github.com/pomerium/pomerium/internal/sessions"
|
|
mstore "github.com/pomerium/pomerium/internal/sessions/mock"
|
|
"github.com/pomerium/pomerium/internal/urlutil"
|
|
"github.com/pomerium/pomerium/pkg/cryptutil"
|
|
)
|
|
|
|
func TestStatefulSignIn(t *testing.T) {
|
|
opts := config.NewDefaultOptions()
|
|
tests := []struct {
|
|
name string
|
|
|
|
host string
|
|
qp map[string]string
|
|
validSignature bool
|
|
|
|
session *sessions.State
|
|
encoder encoding.MarshalUnmarshaler
|
|
saveError error
|
|
|
|
wantErrorMsg string
|
|
wantRedirectBaseURL string
|
|
}{
|
|
{"good", "corp.example.example", map[string]string{urlutil.QueryRedirectURI: "https://dst.some.example/"}, true, &sessions.State{}, &mock.Encoder{}, nil, "", "https://dst.some.example/.pomerium/callback/"},
|
|
{"good alternate port", "corp.example.example:8443", map[string]string{urlutil.QueryRedirectURI: "https://dst.some.example/"}, true, &sessions.State{}, &mock.Encoder{}, nil, "", "https://dst.some.example/.pomerium/callback/"},
|
|
{"invalid signature", "corp.example.example", map[string]string{urlutil.QueryRedirectURI: "https://dst.some.example/"}, false, &sessions.State{}, &mock.Encoder{}, nil, "Bad Request:", ""},
|
|
{"bad redirect uri query", "corp.example.example", map[string]string{urlutil.QueryRedirectURI: "^^^"}, true, &sessions.State{}, &mock.Encoder{}, nil, "Bad Request:", ""},
|
|
{"bad marshal", "corp.example.example", map[string]string{urlutil.QueryRedirectURI: "https://dst.some.example/"}, true, &sessions.State{}, &mock.Encoder{MarshalError: errors.New("error")}, nil, "Bad Request: error", ""},
|
|
{"good with different programmatic redirect", "corp.example.example", map[string]string{urlutil.QueryRedirectURI: "https://dst.some.example/", urlutil.QueryCallbackURI: "https://some.example"}, true, &sessions.State{}, &mock.Encoder{}, nil, "", "https://some.example"},
|
|
{"encrypted encoder error", "corp.example.example", map[string]string{urlutil.QueryRedirectURI: "https://dst.some.example/", urlutil.QueryCallbackURI: "https://some.example"}, true, &sessions.State{}, &mock.Encoder{MarshalError: errors.New("error")}, nil, "Bad Request: error", ""},
|
|
{"good with callback uri set", "corp.example.example", map[string]string{urlutil.QueryCallbackURI: "https://some.example/", urlutil.QueryRedirectURI: "https://dst.some.example/"}, true, &sessions.State{}, &mock.Encoder{}, nil, "", "https://some.example/"},
|
|
{"bad callback uri set", "corp.example.example", map[string]string{urlutil.QueryCallbackURI: "^", urlutil.QueryRedirectURI: "https://dst.some.example/"}, true, &sessions.State{}, &mock.Encoder{}, nil, "Bad Request:", ""},
|
|
{"good programmatic request", "corp.example.example", map[string]string{urlutil.QueryIsProgrammatic: "true", urlutil.QueryRedirectURI: "https://dst.some.example/"}, true, &sessions.State{}, &mock.Encoder{}, nil, "", "https://dst.some.example/.pomerium/callback/"},
|
|
}
|
|
for _, tt := range tests {
|
|
tt := tt
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
sessionStore := &mstore.Store{SaveError: tt.saveError}
|
|
flow, err := NewStateful(&config.Config{Options: opts}, sessionStore)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
flow.sharedEncoder = tt.encoder
|
|
|
|
uri := &url.URL{Scheme: "https", Host: tt.host}
|
|
queryString := uri.Query()
|
|
for k, v := range tt.qp {
|
|
queryString.Set(k, v)
|
|
}
|
|
uri.RawQuery = queryString.Encode()
|
|
if tt.validSignature {
|
|
sharedKey, _ := opts.GetSharedKey()
|
|
uri = urlutil.NewSignedURL(sharedKey, uri).Sign()
|
|
}
|
|
|
|
r := httptest.NewRequest(http.MethodGet, uri.String(), nil)
|
|
r.Header.Set("Accept", "application/json")
|
|
|
|
w := httptest.NewRecorder()
|
|
err = flow.SignIn(w, r, tt.session)
|
|
result := w.Result()
|
|
if tt.wantErrorMsg == "" {
|
|
if err != nil {
|
|
t.Errorf("unexpected error: %v", err)
|
|
}
|
|
expectedStatus := "302 Found"
|
|
if result.Status != expectedStatus {
|
|
t.Errorf("wrong status code: got %v, want %v", result.Status, expectedStatus)
|
|
}
|
|
loc, err := url.Parse(result.Header.Get("Location"))
|
|
if err != nil {
|
|
t.Fatalf("couldn't parse redirect URL: %v", err)
|
|
}
|
|
loc.RawQuery = "" // ignore the query parameters
|
|
if loc.String() != tt.wantRedirectBaseURL {
|
|
t.Errorf("wrong redirect base URL: got %q, want %q",
|
|
loc.String(), tt.wantRedirectBaseURL)
|
|
}
|
|
} else {
|
|
if err == nil || !strings.Contains(err.Error(), tt.wantErrorMsg) {
|
|
t.Errorf("expected error containing %q; got %v", tt.wantErrorMsg, err)
|
|
}
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestStatefulAuthenticateSignInURL(t *testing.T) {
|
|
opts := config.NewDefaultOptions()
|
|
opts.AuthenticateURLString = "https://authenticate.example.com"
|
|
key := cryptutil.NewKey()
|
|
opts.SharedKey = base64.StdEncoding.EncodeToString(key)
|
|
flow, err := NewStateful(&config.Config{Options: opts}, nil)
|
|
require.NoError(t, err)
|
|
|
|
t.Run("NilQueryParams", func(t *testing.T) {
|
|
redirectURL := &url.URL{Scheme: "https", Host: "example.com"}
|
|
u, err := flow.AuthenticateSignInURL(nil, nil, redirectURL, "fake-idp-id")
|
|
assert.NoError(t, err)
|
|
parsed, _ := url.Parse(u)
|
|
assert.NoError(t, urlutil.NewSignedURL(key, parsed).Validate())
|
|
assert.Equal(t, "https", parsed.Scheme)
|
|
assert.Equal(t, "authenticate.example.com", parsed.Host)
|
|
assert.Equal(t, "/.pomerium/sign_in", parsed.Path)
|
|
q := parsed.Query()
|
|
assert.Equal(t, "https://example.com", parsed.Query().Get("pomerium_redirect_uri"))
|
|
assert.Equal(t, "fake-idp-id", q.Get("pomerium_idp_id"))
|
|
})
|
|
t.Run("ExtraQueryParams", func(t *testing.T) {
|
|
redirectURL := &url.URL{Scheme: "https", Host: "example.com"}
|
|
q := url.Values{}
|
|
q.Set("foo", "bar")
|
|
u, err := flow.AuthenticateSignInURL(nil, q, redirectURL, "fake-idp-id")
|
|
assert.NoError(t, err)
|
|
parsed, _ := url.Parse(u)
|
|
assert.NoError(t, urlutil.NewSignedURL(key, parsed).Validate())
|
|
assert.Equal(t, "https", parsed.Scheme)
|
|
assert.Equal(t, "authenticate.example.com", parsed.Host)
|
|
assert.Equal(t, "/.pomerium/sign_in", parsed.Path)
|
|
q = parsed.Query()
|
|
assert.Equal(t, "https://example.com", q.Get("pomerium_redirect_uri"))
|
|
assert.Equal(t, "fake-idp-id", q.Get("pomerium_idp_id"))
|
|
assert.Equal(t, "bar", q.Get("foo"))
|
|
})
|
|
}
|
|
|
|
func TestStatefulGetIdentityProviderIDForURLValues(t *testing.T) {
|
|
flow := Stateful{defaultIdentityProviderID: "default-id"}
|
|
assert.Equal(t, "default-id", flow.GetIdentityProviderIDForURLValues(nil))
|
|
q := url.Values{"pomerium_idp_id": []string{"idp-id"}}
|
|
assert.Equal(t, "idp-id", flow.GetIdentityProviderIDForURLValues(q))
|
|
}
|
|
|
|
// goodEncryptionString is an encrypted callback token passed as
// urlutil.QuerySessionEncrypted by the TestStatefulCallback cases. The "good"
// cases expect Callback to succeed with it, so it must decrypt under the
// SharedKey configured in that test. Do not edit: any change to this literal
// (the "bad decrypt" case alters a single character) makes decryption fail.
const goodEncryptionString = "KBEjQ9rnCxaAX-GOqetGw9ivEQURqts3zZ2mNGy0wnVa3SbtM399KlBq2nZ-9wM21FfsZX52er4jlmC7kPEKM3P7uZ41zR0zeys1-_74a5tQp-vsf1WXZfRsgVOuBcWPkMiWEoc379JFHxGDudp5VhU8B-dcQt4f3_PtLTHARkuH54io1Va2gNMq4Hiy8sQ1MPGCQeltH_JMzzdDpXdmdusWrXUvCGkba24muvAV06D8XRVJj6Iu9eK94qFnqcHc7wzziEbb8ADBues9dwbtb6jl8vMWz5rN6XvXqA5YpZv_MQZlsrO4oXFFQDevdgB84cX1tVbVu6qZvK_yQBZqzpOjWA9uIaoSENMytoXuWAlFO_sXjswfX8JTNdGwzB7qQRNPqxVG_sM_tzY3QhPm8zqwEzsXG5DokxZfVt2I5WJRUEovFDb4BnK9KFnnkEzLEdMudixVnXeGmTtycgJvoTeTCQRPfDYkcgJ7oKf4tGea-W7z5UAVa2RduJM9ZoM6YtJX7jgDm__PvvqcE0knJUF87XHBzdcOjoDF-CUze9xDJgNBlvPbJqVshKrwoqSYpePSDH9GUCNKxGequW3Ma8GvlFfhwd0rK6IZG-XWkyk0XSWQIGkDSjAvhB1wsOusCCguDjbpVZpaW5MMyTkmx68pl6qlIKT5UCcrVPl4ix5ZEj91mUDF0O1t04haD7VZuLVFXVGmqtFrBKI76sdYN-zkokaa1_chPRTyqMQFlqu_8LD6-RiK3UccGM-dEmnX72i91NP9F9OK0WJr9Cheup1C_P0mjqAO4Cb8oIHm0Oxz_mRqv5QbTGJtb3xwPLPuVjVCiE4gGBcuU2ixpSVf5HUF7y1KicVMCKiX9ATCBtg8sTdQZQnPEtHcHHAvdsnDVwev1LGfqA-Gdvg="
|
|
|
|
func TestStatefulCallback(t *testing.T) {
|
|
opts := config.NewDefaultOptions()
|
|
opts.SharedKey = "80ldlrU2d7w+wVpKNfevk6fmb8otEx6CqOfshj2LwhQ="
|
|
tests := []struct {
|
|
name string
|
|
|
|
qp map[string]string
|
|
validSignature bool
|
|
cipher encoding.MarshalUnmarshaler
|
|
sessionStore sessions.SessionStore
|
|
|
|
wantErrorMsg string
|
|
}{
|
|
{
|
|
"good",
|
|
map[string]string{urlutil.QueryCallbackURI: "ok", urlutil.QuerySessionEncrypted: goodEncryptionString},
|
|
true,
|
|
&mock.Encoder{MarshalResponse: []byte("x")},
|
|
&mstore.Store{Session: &sessions.State{}},
|
|
"",
|
|
},
|
|
{
|
|
"good programmatic",
|
|
map[string]string{urlutil.QueryIsProgrammatic: "true", urlutil.QueryCallbackURI: "ok", urlutil.QuerySessionEncrypted: goodEncryptionString},
|
|
true,
|
|
&mock.Encoder{MarshalResponse: []byte("x")},
|
|
&mstore.Store{Session: &sessions.State{}},
|
|
"",
|
|
},
|
|
{
|
|
"invalid signature",
|
|
map[string]string{urlutil.QueryCallbackURI: "ok", urlutil.QuerySessionEncrypted: goodEncryptionString},
|
|
false,
|
|
&mock.Encoder{MarshalResponse: []byte("x")},
|
|
&mstore.Store{Session: &sessions.State{}},
|
|
"Bad Request:",
|
|
},
|
|
{
|
|
"bad decrypt",
|
|
map[string]string{urlutil.QuerySessionEncrypted: "KBEjQ9rnCxaAX-GOqexGw9ivEQURqts3zZ2mNGy0wnVa3SbtM399KlBq2nZ-9wM21FfsZX52er4jlmC7kPEKM3P7uZ41zR0zeys1-_74a5tQp-vsf1WXZfRsgVOuBcWPkMiWEoc379JFHxGDudp5VhU8B-dcQt4f3_PtLTHARkuH54io1Va2gNMq4Hiy8sQ1MPGCQeltH_JMzzdDpXdmdusWrXUvCGkba24muvAV06D8XRVJj6Iu9eK94qFnqcHc7wzziEbb8ADBues9dwbtb6jl8vMWz5rN6XvXqA5YpZv_MQZlsrO4oXFFQDevdgB84cX1tVbVu6qZvK_yQBZqzpOjWA9uIaoSENMytoXuWAlFO_sXjswfX8JTNdGwzB7qQRNPqxVG_sM_tzY3QhPm8zqwEzsXG5DokxZfVt2I5WJRUEovFDb4BnK9KFnnkEzLEdMudixVnXeGmTtycgJvoTeTCQRPfDYkcgJ7oKf4tGea-W7z5UAVa2RduJM9ZoM6YtJX7jgDm__PvvqcE0knJUF87XHBzdcOjoDF-CUze9xDJgNBlvPbJqVshKrwoqSYpePSDH9GUCNKxGequW3Ma8GvlFfhwd0rK6IZG-XWkyk0XSWQIGkDSjAvhB1wsOusCCguDjbpVZpaW5MMyTkmx68pl6qlIKT5UCcrVPl4ix5ZEj91mUDF0O1t04haD7VZuLVFXVGmqtFrBKI76sdYN-zkokaa1_chPRTyqMQFlqu_8LD6-RiK3UccGM-dEmnX72i91NP9F9OK0WJr9Cheup1C_P0mjqAO4Cb8oIHm0Oxz_mRqv5QbTGJtb3xwPLPuVjVCiE4gGBcuU2ixpSVf5HUF7y1KicVMCKiX9ATCBtg8sTdQZQnPEtHcHHAvdsnDVwev1LGfqA-Gdvg="},
|
|
true,
|
|
&mock.Encoder{MarshalResponse: []byte("x")},
|
|
&mstore.Store{Session: &sessions.State{}},
|
|
"proxy: callback token decrypt error:",
|
|
},
|
|
{
|
|
"bad save session",
|
|
map[string]string{urlutil.QuerySessionEncrypted: goodEncryptionString},
|
|
true,
|
|
&mock.Encoder{MarshalResponse: []byte("x")},
|
|
&mstore.Store{SaveError: errors.New("hi")},
|
|
"Internal Server Error: proxy: error saving session state:",
|
|
},
|
|
{
|
|
"bad base64",
|
|
map[string]string{urlutil.QuerySessionEncrypted: "^"},
|
|
true,
|
|
&mock.Encoder{MarshalResponse: []byte("x")},
|
|
&mstore.Store{Session: &sessions.State{}},
|
|
"proxy: malfromed callback token:",
|
|
},
|
|
{
|
|
"malformed redirect",
|
|
nil,
|
|
true,
|
|
&mock.Encoder{},
|
|
&mstore.Store{Session: &sessions.State{}},
|
|
"Bad Request:",
|
|
},
|
|
}
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
flow, err := NewStateful(&config.Config{Options: opts}, tt.sessionStore)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
flow.sharedEncoder = tt.cipher
|
|
redirectURI := &url.URL{Scheme: "http", Host: "example.com", Path: "/"}
|
|
queryString := redirectURI.Query()
|
|
for k, v := range tt.qp {
|
|
queryString.Set(k, v)
|
|
}
|
|
redirectURI.RawQuery = queryString.Encode()
|
|
|
|
uri := &url.URL{Scheme: "https", Host: "example.com", Path: "/"}
|
|
if tt.qp != nil {
|
|
qu := uri.Query()
|
|
for k, v := range tt.qp {
|
|
qu.Set(k, v)
|
|
}
|
|
qu.Set(urlutil.QueryRedirectURI, redirectURI.String())
|
|
uri.RawQuery = qu.Encode()
|
|
}
|
|
if tt.validSignature {
|
|
sharedKey, _ := opts.GetSharedKey()
|
|
uri = urlutil.NewSignedURL(sharedKey, uri).Sign()
|
|
}
|
|
|
|
r := httptest.NewRequest(http.MethodGet, uri.String(), nil)
|
|
//fmt.Println(uri.String())
|
|
r.Host = r.URL.Host
|
|
|
|
r.Header.Set("Accept", "application/json")
|
|
|
|
w := httptest.NewRecorder()
|
|
err = flow.Callback(w, r)
|
|
if tt.wantErrorMsg == "" {
|
|
if err != nil {
|
|
t.Errorf("unexpected error: %v", err)
|
|
}
|
|
} else {
|
|
if err == nil || !strings.Contains(err.Error(), tt.wantErrorMsg) {
|
|
t.Errorf("expected error containing %q; got %v", tt.wantErrorMsg, err)
|
|
}
|
|
}
|
|
|
|
// XXX: assert redirect URL
|
|
})
|
|
}
|
|
}
|