authenticateflow: add stateful flow (#4822)

Add a new Stateful type implementing the stateful authentication flow
from Pomerium v0.20 and earlier.

This consists mainly of logic from authenticate/handlers.go prior to
commits 57217af and 539fd51.

One significant change is to set the default IdP ID when an IdP ID is
not provided in the request URL (e.g. when signing in directly at the
authenticate service domain). Otherwise, if session state is stored with
an empty IdP ID, it won't be valid for any route.
This commit is contained in:
Kenneth Jenkins 2023-12-07 09:54:42 -08:00 committed by GitHub
parent 0e9a07eac9
commit c01d0e045d
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 637 additions and 0 deletions

View file

@ -0,0 +1,271 @@
package authenticateflow
import (
"encoding/base64"
"errors"
"net/http"
"net/http/httptest"
"net/url"
"strings"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/pomerium/pomerium/config"
"github.com/pomerium/pomerium/internal/encoding"
"github.com/pomerium/pomerium/internal/encoding/mock"
"github.com/pomerium/pomerium/internal/sessions"
mstore "github.com/pomerium/pomerium/internal/sessions/mock"
"github.com/pomerium/pomerium/internal/urlutil"
"github.com/pomerium/pomerium/pkg/cryptutil"
)
func TestStatefulSignIn(t *testing.T) {
opts := config.NewDefaultOptions()
tests := []struct {
name string
host string
qp map[string]string
validSignature bool
session *sessions.State
encoder encoding.MarshalUnmarshaler
saveError error
wantErrorMsg string
wantRedirectBaseURL string
}{
{"good", "corp.example.example", map[string]string{urlutil.QueryRedirectURI: "https://dst.some.example/"}, true, &sessions.State{}, &mock.Encoder{}, nil, "", "https://dst.some.example/.pomerium/callback/"},
{"good alternate port", "corp.example.example:8443", map[string]string{urlutil.QueryRedirectURI: "https://dst.some.example/"}, true, &sessions.State{}, &mock.Encoder{}, nil, "", "https://dst.some.example/.pomerium/callback/"},
{"invalid signature", "corp.example.example", map[string]string{urlutil.QueryRedirectURI: "https://dst.some.example/"}, false, &sessions.State{}, &mock.Encoder{}, nil, "Bad Request:", ""},
{"bad redirect uri query", "corp.example.example", map[string]string{urlutil.QueryRedirectURI: "^^^"}, true, &sessions.State{}, &mock.Encoder{}, nil, "Bad Request:", ""},
{"bad marshal", "corp.example.example", map[string]string{urlutil.QueryRedirectURI: "https://dst.some.example/"}, true, &sessions.State{}, &mock.Encoder{MarshalError: errors.New("error")}, nil, "Bad Request: error", ""},
{"good with different programmatic redirect", "corp.example.example", map[string]string{urlutil.QueryRedirectURI: "https://dst.some.example/", urlutil.QueryCallbackURI: "https://some.example"}, true, &sessions.State{}, &mock.Encoder{}, nil, "", "https://some.example"},
{"encrypted encoder error", "corp.example.example", map[string]string{urlutil.QueryRedirectURI: "https://dst.some.example/", urlutil.QueryCallbackURI: "https://some.example"}, true, &sessions.State{}, &mock.Encoder{MarshalError: errors.New("error")}, nil, "Bad Request: error", ""},
{"good with callback uri set", "corp.example.example", map[string]string{urlutil.QueryCallbackURI: "https://some.example/", urlutil.QueryRedirectURI: "https://dst.some.example/"}, true, &sessions.State{}, &mock.Encoder{}, nil, "", "https://some.example/"},
{"bad callback uri set", "corp.example.example", map[string]string{urlutil.QueryCallbackURI: "^", urlutil.QueryRedirectURI: "https://dst.some.example/"}, true, &sessions.State{}, &mock.Encoder{}, nil, "Bad Request:", ""},
{"good programmatic request", "corp.example.example", map[string]string{urlutil.QueryIsProgrammatic: "true", urlutil.QueryRedirectURI: "https://dst.some.example/"}, true, &sessions.State{}, &mock.Encoder{}, nil, "", "https://dst.some.example/.pomerium/callback/"},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
sessionStore := &mstore.Store{SaveError: tt.saveError}
flow, err := NewStateful(&config.Config{Options: opts}, sessionStore)
if err != nil {
t.Fatal(err)
}
flow.sharedEncoder = tt.encoder
uri := &url.URL{Scheme: "https", Host: tt.host}
queryString := uri.Query()
for k, v := range tt.qp {
queryString.Set(k, v)
}
uri.RawQuery = queryString.Encode()
if tt.validSignature {
sharedKey, _ := opts.GetSharedKey()
uri = urlutil.NewSignedURL(sharedKey, uri).Sign()
}
r := httptest.NewRequest(http.MethodGet, uri.String(), nil)
r.Header.Set("Accept", "application/json")
w := httptest.NewRecorder()
err = flow.SignIn(w, r, tt.session)
result := w.Result()
if tt.wantErrorMsg == "" {
if err != nil {
t.Errorf("unexpected error: %v", err)
}
expectedStatus := "302 Found"
if result.Status != expectedStatus {
t.Errorf("wrong status code: got %v, want %v", result.Status, expectedStatus)
}
loc, err := url.Parse(result.Header.Get("Location"))
if err != nil {
t.Fatalf("couldn't parse redirect URL: %v", err)
}
loc.RawQuery = "" // ignore the query parameters
if loc.String() != tt.wantRedirectBaseURL {
t.Errorf("wrong redirect base URL: got %q, want %q",
loc.String(), tt.wantRedirectBaseURL)
}
} else {
if err == nil || !strings.Contains(err.Error(), tt.wantErrorMsg) {
t.Errorf("expected error containing %q; got %v", tt.wantErrorMsg, err)
}
}
})
}
}
func TestStatefulAuthenticateSignInURL(t *testing.T) {
opts := config.NewDefaultOptions()
opts.AuthenticateURLString = "https://authenticate.example.com"
key := cryptutil.NewKey()
opts.SharedKey = base64.StdEncoding.EncodeToString(key)
flow, err := NewStateful(&config.Config{Options: opts}, nil)
require.NoError(t, err)
t.Run("NilQueryParams", func(t *testing.T) {
redirectURL := &url.URL{Scheme: "https", Host: "example.com"}
u, err := flow.AuthenticateSignInURL(nil, nil, redirectURL, "fake-idp-id")
assert.NoError(t, err)
parsed, _ := url.Parse(u)
assert.NoError(t, urlutil.NewSignedURL(key, parsed).Validate())
assert.Equal(t, "https", parsed.Scheme)
assert.Equal(t, "authenticate.example.com", parsed.Host)
assert.Equal(t, "/.pomerium/sign_in", parsed.Path)
q := parsed.Query()
assert.Equal(t, "https://example.com", parsed.Query().Get("pomerium_redirect_uri"))
assert.Equal(t, "fake-idp-id", q.Get("pomerium_idp_id"))
})
t.Run("ExtraQueryParams", func(t *testing.T) {
redirectURL := &url.URL{Scheme: "https", Host: "example.com"}
q := url.Values{}
q.Set("foo", "bar")
u, err := flow.AuthenticateSignInURL(nil, q, redirectURL, "fake-idp-id")
assert.NoError(t, err)
parsed, _ := url.Parse(u)
assert.NoError(t, urlutil.NewSignedURL(key, parsed).Validate())
assert.Equal(t, "https", parsed.Scheme)
assert.Equal(t, "authenticate.example.com", parsed.Host)
assert.Equal(t, "/.pomerium/sign_in", parsed.Path)
q = parsed.Query()
assert.Equal(t, "https://example.com", q.Get("pomerium_redirect_uri"))
assert.Equal(t, "fake-idp-id", q.Get("pomerium_idp_id"))
assert.Equal(t, "bar", q.Get("foo"))
})
}
func TestStatefulGetIdentityProviderIDForURLValues(t *testing.T) {
flow := Stateful{defaultIdentityProviderID: "default-id"}
assert.Equal(t, "default-id", flow.GetIdentityProviderIDForURLValues(nil))
q := url.Values{"pomerium_idp_id": []string{"idp-id"}}
assert.Equal(t, "idp-id", flow.GetIdentityProviderIDForURLValues(q))
}
const goodEncryptionString = "KBEjQ9rnCxaAX-GOqetGw9ivEQURqts3zZ2mNGy0wnVa3SbtM399KlBq2nZ-9wM21FfsZX52er4jlmC7kPEKM3P7uZ41zR0zeys1-_74a5tQp-vsf1WXZfRsgVOuBcWPkMiWEoc379JFHxGDudp5VhU8B-dcQt4f3_PtLTHARkuH54io1Va2gNMq4Hiy8sQ1MPGCQeltH_JMzzdDpXdmdusWrXUvCGkba24muvAV06D8XRVJj6Iu9eK94qFnqcHc7wzziEbb8ADBues9dwbtb6jl8vMWz5rN6XvXqA5YpZv_MQZlsrO4oXFFQDevdgB84cX1tVbVu6qZvK_yQBZqzpOjWA9uIaoSENMytoXuWAlFO_sXjswfX8JTNdGwzB7qQRNPqxVG_sM_tzY3QhPm8zqwEzsXG5DokxZfVt2I5WJRUEovFDb4BnK9KFnnkEzLEdMudixVnXeGmTtycgJvoTeTCQRPfDYkcgJ7oKf4tGea-W7z5UAVa2RduJM9ZoM6YtJX7jgDm__PvvqcE0knJUF87XHBzdcOjoDF-CUze9xDJgNBlvPbJqVshKrwoqSYpePSDH9GUCNKxGequW3Ma8GvlFfhwd0rK6IZG-XWkyk0XSWQIGkDSjAvhB1wsOusCCguDjbpVZpaW5MMyTkmx68pl6qlIKT5UCcrVPl4ix5ZEj91mUDF0O1t04haD7VZuLVFXVGmqtFrBKI76sdYN-zkokaa1_chPRTyqMQFlqu_8LD6-RiK3UccGM-dEmnX72i91NP9F9OK0WJr9Cheup1C_P0mjqAO4Cb8oIHm0Oxz_mRqv5QbTGJtb3xwPLPuVjVCiE4gGBcuU2ixpSVf5HUF7y1KicVMCKiX9ATCBtg8sTdQZQnPEtHcHHAvdsnDVwev1LGfqA-Gdvg="
func TestStatefulCallback(t *testing.T) {
opts := config.NewDefaultOptions()
opts.SharedKey = "80ldlrU2d7w+wVpKNfevk6fmb8otEx6CqOfshj2LwhQ="
tests := []struct {
name string
qp map[string]string
validSignature bool
cipher encoding.MarshalUnmarshaler
sessionStore sessions.SessionStore
wantErrorMsg string
}{
{
"good",
map[string]string{urlutil.QueryCallbackURI: "ok", urlutil.QuerySessionEncrypted: goodEncryptionString},
true,
&mock.Encoder{MarshalResponse: []byte("x")},
&mstore.Store{Session: &sessions.State{}},
"",
},
{
"good programmatic",
map[string]string{urlutil.QueryIsProgrammatic: "true", urlutil.QueryCallbackURI: "ok", urlutil.QuerySessionEncrypted: goodEncryptionString},
true,
&mock.Encoder{MarshalResponse: []byte("x")},
&mstore.Store{Session: &sessions.State{}},
"",
},
{
"invalid signature",
map[string]string{urlutil.QueryCallbackURI: "ok", urlutil.QuerySessionEncrypted: goodEncryptionString},
false,
&mock.Encoder{MarshalResponse: []byte("x")},
&mstore.Store{Session: &sessions.State{}},
"Bad Request:",
},
{
"bad decrypt",
map[string]string{urlutil.QuerySessionEncrypted: "KBEjQ9rnCxaAX-GOqexGw9ivEQURqts3zZ2mNGy0wnVa3SbtM399KlBq2nZ-9wM21FfsZX52er4jlmC7kPEKM3P7uZ41zR0zeys1-_74a5tQp-vsf1WXZfRsgVOuBcWPkMiWEoc379JFHxGDudp5VhU8B-dcQt4f3_PtLTHARkuH54io1Va2gNMq4Hiy8sQ1MPGCQeltH_JMzzdDpXdmdusWrXUvCGkba24muvAV06D8XRVJj6Iu9eK94qFnqcHc7wzziEbb8ADBues9dwbtb6jl8vMWz5rN6XvXqA5YpZv_MQZlsrO4oXFFQDevdgB84cX1tVbVu6qZvK_yQBZqzpOjWA9uIaoSENMytoXuWAlFO_sXjswfX8JTNdGwzB7qQRNPqxVG_sM_tzY3QhPm8zqwEzsXG5DokxZfVt2I5WJRUEovFDb4BnK9KFnnkEzLEdMudixVnXeGmTtycgJvoTeTCQRPfDYkcgJ7oKf4tGea-W7z5UAVa2RduJM9ZoM6YtJX7jgDm__PvvqcE0knJUF87XHBzdcOjoDF-CUze9xDJgNBlvPbJqVshKrwoqSYpePSDH9GUCNKxGequW3Ma8GvlFfhwd0rK6IZG-XWkyk0XSWQIGkDSjAvhB1wsOusCCguDjbpVZpaW5MMyTkmx68pl6qlIKT5UCcrVPl4ix5ZEj91mUDF0O1t04haD7VZuLVFXVGmqtFrBKI76sdYN-zkokaa1_chPRTyqMQFlqu_8LD6-RiK3UccGM-dEmnX72i91NP9F9OK0WJr9Cheup1C_P0mjqAO4Cb8oIHm0Oxz_mRqv5QbTGJtb3xwPLPuVjVCiE4gGBcuU2ixpSVf5HUF7y1KicVMCKiX9ATCBtg8sTdQZQnPEtHcHHAvdsnDVwev1LGfqA-Gdvg="},
true,
&mock.Encoder{MarshalResponse: []byte("x")},
&mstore.Store{Session: &sessions.State{}},
"proxy: callback token decrypt error:",
},
{
"bad save session",
map[string]string{urlutil.QuerySessionEncrypted: goodEncryptionString},
true,
&mock.Encoder{MarshalResponse: []byte("x")},
&mstore.Store{SaveError: errors.New("hi")},
"Internal Server Error: proxy: error saving session state:",
},
{
"bad base64",
map[string]string{urlutil.QuerySessionEncrypted: "^"},
true,
&mock.Encoder{MarshalResponse: []byte("x")},
&mstore.Store{Session: &sessions.State{}},
"proxy: malfromed callback token:",
},
{
"malformed redirect",
nil,
true,
&mock.Encoder{},
&mstore.Store{Session: &sessions.State{}},
"Bad Request:",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
flow, err := NewStateful(&config.Config{Options: opts}, tt.sessionStore)
if err != nil {
t.Fatal(err)
}
flow.sharedEncoder = tt.cipher
redirectURI := &url.URL{Scheme: "http", Host: "example.com", Path: "/"}
queryString := redirectURI.Query()
for k, v := range tt.qp {
queryString.Set(k, v)
}
redirectURI.RawQuery = queryString.Encode()
uri := &url.URL{Scheme: "https", Host: "example.com", Path: "/"}
if tt.qp != nil {
qu := uri.Query()
for k, v := range tt.qp {
qu.Set(k, v)
}
qu.Set(urlutil.QueryRedirectURI, redirectURI.String())
uri.RawQuery = qu.Encode()
}
if tt.validSignature {
sharedKey, _ := opts.GetSharedKey()
uri = urlutil.NewSignedURL(sharedKey, uri).Sign()
}
r := httptest.NewRequest(http.MethodGet, uri.String(), nil)
//fmt.Println(uri.String())
r.Host = r.URL.Host
r.Header.Set("Accept", "application/json")
w := httptest.NewRecorder()
err = flow.Callback(w, r)
if tt.wantErrorMsg == "" {
if err != nil {
t.Errorf("unexpected error: %v", err)
}
} else {
if err == nil || !strings.Contains(err.Error(), tt.wantErrorMsg) {
t.Errorf("expected error containing %q; got %v", tt.wantErrorMsg, err)
}
}
// XXX: assert redirect URL
})
}
}