mirror of
https://github.com/pomerium/pomerium.git
synced 2025-04-30 19:06:33 +02:00
The identity manager expects to be able to read the session ID and user ID from any deleted databroker session record. The session.Delete() wrapper method is not compatible with this expectation, as it calls Put() with a record containing an empty session. The stateful authentication flow currently calls session.Delete() from its RevokeSession() method. As a result, the identity manager will not correctly track sessions deleted by the stateful authentication flow, and will still try to use them during session refresh and user info refresh. Instead, let's change the stateful authentication flow's RevokeSession() method to perform deletions in a way that is compatible with the current identity manager code: include the existing session data in the Put() call that deletes the revoked session.
388 lines
15 KiB
Go
388 lines
15 KiB
Go
package authenticateflow
|
|
|
|
import (
|
|
"context"
|
|
"encoding/base64"
|
|
"errors"
|
|
"fmt"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"net/url"
|
|
"strings"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/golang/mock/gomock"
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/require"
|
|
"golang.org/x/oauth2"
|
|
"google.golang.org/grpc"
|
|
"google.golang.org/protobuf/proto"
|
|
"google.golang.org/protobuf/types/known/timestamppb"
|
|
|
|
"github.com/pomerium/pomerium/config"
|
|
"github.com/pomerium/pomerium/internal/encoding"
|
|
"github.com/pomerium/pomerium/internal/encoding/mock"
|
|
"github.com/pomerium/pomerium/internal/identity"
|
|
"github.com/pomerium/pomerium/internal/sessions"
|
|
mstore "github.com/pomerium/pomerium/internal/sessions/mock"
|
|
"github.com/pomerium/pomerium/internal/urlutil"
|
|
"github.com/pomerium/pomerium/pkg/cryptutil"
|
|
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
|
"github.com/pomerium/pomerium/pkg/grpc/databroker/mock_databroker"
|
|
"github.com/pomerium/pomerium/pkg/grpc/session"
|
|
"github.com/pomerium/pomerium/pkg/protoutil"
|
|
)
|
|
|
|
func TestStatefulSignIn(t *testing.T) {
|
|
opts := config.NewDefaultOptions()
|
|
tests := []struct {
|
|
name string
|
|
|
|
host string
|
|
qp map[string]string
|
|
validSignature bool
|
|
|
|
session *sessions.State
|
|
encoder encoding.MarshalUnmarshaler
|
|
saveError error
|
|
|
|
wantErrorMsg string
|
|
wantRedirectBaseURL string
|
|
}{
|
|
{"good", "corp.example.example", map[string]string{urlutil.QueryRedirectURI: "https://dst.some.example/"}, true, &sessions.State{}, &mock.Encoder{}, nil, "", "https://dst.some.example/.pomerium/callback/"},
|
|
{"good alternate port", "corp.example.example:8443", map[string]string{urlutil.QueryRedirectURI: "https://dst.some.example/"}, true, &sessions.State{}, &mock.Encoder{}, nil, "", "https://dst.some.example/.pomerium/callback/"},
|
|
{"invalid signature", "corp.example.example", map[string]string{urlutil.QueryRedirectURI: "https://dst.some.example/"}, false, &sessions.State{}, &mock.Encoder{}, nil, "Bad Request:", ""},
|
|
{"bad redirect uri query", "corp.example.example", map[string]string{urlutil.QueryRedirectURI: "^^^"}, true, &sessions.State{}, &mock.Encoder{}, nil, "Bad Request:", ""},
|
|
{"bad marshal", "corp.example.example", map[string]string{urlutil.QueryRedirectURI: "https://dst.some.example/"}, true, &sessions.State{}, &mock.Encoder{MarshalError: errors.New("error")}, nil, "Bad Request: error", ""},
|
|
{"good with different programmatic redirect", "corp.example.example", map[string]string{urlutil.QueryRedirectURI: "https://dst.some.example/", urlutil.QueryCallbackURI: "https://some.example"}, true, &sessions.State{}, &mock.Encoder{}, nil, "", "https://some.example"},
|
|
{"encrypted encoder error", "corp.example.example", map[string]string{urlutil.QueryRedirectURI: "https://dst.some.example/", urlutil.QueryCallbackURI: "https://some.example"}, true, &sessions.State{}, &mock.Encoder{MarshalError: errors.New("error")}, nil, "Bad Request: error", ""},
|
|
{"good with callback uri set", "corp.example.example", map[string]string{urlutil.QueryCallbackURI: "https://some.example/", urlutil.QueryRedirectURI: "https://dst.some.example/"}, true, &sessions.State{}, &mock.Encoder{}, nil, "", "https://some.example/"},
|
|
{"bad callback uri set", "corp.example.example", map[string]string{urlutil.QueryCallbackURI: "^", urlutil.QueryRedirectURI: "https://dst.some.example/"}, true, &sessions.State{}, &mock.Encoder{}, nil, "Bad Request:", ""},
|
|
{"good programmatic request", "corp.example.example", map[string]string{urlutil.QueryIsProgrammatic: "true", urlutil.QueryRedirectURI: "https://dst.some.example/"}, true, &sessions.State{}, &mock.Encoder{}, nil, "", "https://dst.some.example/.pomerium/callback/"},
|
|
}
|
|
for _, tt := range tests {
|
|
tt := tt
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
sessionStore := &mstore.Store{SaveError: tt.saveError}
|
|
flow, err := NewStateful(&config.Config{Options: opts}, sessionStore)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
flow.sharedEncoder = tt.encoder
|
|
|
|
uri := &url.URL{Scheme: "https", Host: tt.host}
|
|
queryString := uri.Query()
|
|
for k, v := range tt.qp {
|
|
queryString.Set(k, v)
|
|
}
|
|
uri.RawQuery = queryString.Encode()
|
|
if tt.validSignature {
|
|
sharedKey, _ := opts.GetSharedKey()
|
|
uri = urlutil.NewSignedURL(sharedKey, uri).Sign()
|
|
}
|
|
|
|
r := httptest.NewRequest(http.MethodGet, uri.String(), nil)
|
|
r.Header.Set("Accept", "application/json")
|
|
|
|
w := httptest.NewRecorder()
|
|
err = flow.SignIn(w, r, tt.session)
|
|
result := w.Result()
|
|
if tt.wantErrorMsg == "" {
|
|
if err != nil {
|
|
t.Errorf("unexpected error: %v", err)
|
|
}
|
|
expectedStatus := "302 Found"
|
|
if result.Status != expectedStatus {
|
|
t.Errorf("wrong status code: got %v, want %v", result.Status, expectedStatus)
|
|
}
|
|
loc, err := url.Parse(result.Header.Get("Location"))
|
|
if err != nil {
|
|
t.Fatalf("couldn't parse redirect URL: %v", err)
|
|
}
|
|
loc.RawQuery = "" // ignore the query parameters
|
|
if loc.String() != tt.wantRedirectBaseURL {
|
|
t.Errorf("wrong redirect base URL: got %q, want %q",
|
|
loc.String(), tt.wantRedirectBaseURL)
|
|
}
|
|
} else {
|
|
if err == nil || !strings.Contains(err.Error(), tt.wantErrorMsg) {
|
|
t.Errorf("expected error containing %q; got %v", tt.wantErrorMsg, err)
|
|
}
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestStatefulAuthenticateSignInURL(t *testing.T) {
|
|
opts := config.NewDefaultOptions()
|
|
opts.AuthenticateURLString = "https://authenticate.example.com"
|
|
key := cryptutil.NewKey()
|
|
opts.SharedKey = base64.StdEncoding.EncodeToString(key)
|
|
flow, err := NewStateful(&config.Config{Options: opts}, nil)
|
|
require.NoError(t, err)
|
|
|
|
t.Run("NilQueryParams", func(t *testing.T) {
|
|
redirectURL := &url.URL{Scheme: "https", Host: "example.com"}
|
|
u, err := flow.AuthenticateSignInURL(nil, nil, redirectURL, "fake-idp-id")
|
|
assert.NoError(t, err)
|
|
parsed, _ := url.Parse(u)
|
|
assert.NoError(t, urlutil.NewSignedURL(key, parsed).Validate())
|
|
assert.Equal(t, "https", parsed.Scheme)
|
|
assert.Equal(t, "authenticate.example.com", parsed.Host)
|
|
assert.Equal(t, "/.pomerium/sign_in", parsed.Path)
|
|
q := parsed.Query()
|
|
assert.Equal(t, "https://example.com", parsed.Query().Get("pomerium_redirect_uri"))
|
|
assert.Equal(t, "fake-idp-id", q.Get("pomerium_idp_id"))
|
|
})
|
|
t.Run("ExtraQueryParams", func(t *testing.T) {
|
|
redirectURL := &url.URL{Scheme: "https", Host: "example.com"}
|
|
q := url.Values{}
|
|
q.Set("foo", "bar")
|
|
u, err := flow.AuthenticateSignInURL(nil, q, redirectURL, "fake-idp-id")
|
|
assert.NoError(t, err)
|
|
parsed, _ := url.Parse(u)
|
|
assert.NoError(t, urlutil.NewSignedURL(key, parsed).Validate())
|
|
assert.Equal(t, "https", parsed.Scheme)
|
|
assert.Equal(t, "authenticate.example.com", parsed.Host)
|
|
assert.Equal(t, "/.pomerium/sign_in", parsed.Path)
|
|
q = parsed.Query()
|
|
assert.Equal(t, "https://example.com", q.Get("pomerium_redirect_uri"))
|
|
assert.Equal(t, "fake-idp-id", q.Get("pomerium_idp_id"))
|
|
assert.Equal(t, "bar", q.Get("foo"))
|
|
})
|
|
}
|
|
|
|
func TestStatefulGetIdentityProviderIDForURLValues(t *testing.T) {
|
|
flow := Stateful{defaultIdentityProviderID: "default-id"}
|
|
assert.Equal(t, "default-id", flow.GetIdentityProviderIDForURLValues(nil))
|
|
q := url.Values{"pomerium_idp_id": []string{"idp-id"}}
|
|
assert.Equal(t, "idp-id", flow.GetIdentityProviderIDForURLValues(q))
|
|
}
|
|
|
|
// goodEncryptionString is an encrypted callback token passed as the
// urlutil.QuerySessionEncrypted query parameter in TestStatefulCallback. The
// "good" cases of that test expect it to be accepted, so it is presumably
// encrypted under the shared key configured there (confirm if that key ever
// changes). The "bad decrypt" case uses a slightly altered copy of this value.
const goodEncryptionString = "KBEjQ9rnCxaAX-GOqetGw9ivEQURqts3zZ2mNGy0wnVa3SbtM399KlBq2nZ-9wM21FfsZX52er4jlmC7kPEKM3P7uZ41zR0zeys1-_74a5tQp-vsf1WXZfRsgVOuBcWPkMiWEoc379JFHxGDudp5VhU8B-dcQt4f3_PtLTHARkuH54io1Va2gNMq4Hiy8sQ1MPGCQeltH_JMzzdDpXdmdusWrXUvCGkba24muvAV06D8XRVJj6Iu9eK94qFnqcHc7wzziEbb8ADBues9dwbtb6jl8vMWz5rN6XvXqA5YpZv_MQZlsrO4oXFFQDevdgB84cX1tVbVu6qZvK_yQBZqzpOjWA9uIaoSENMytoXuWAlFO_sXjswfX8JTNdGwzB7qQRNPqxVG_sM_tzY3QhPm8zqwEzsXG5DokxZfVt2I5WJRUEovFDb4BnK9KFnnkEzLEdMudixVnXeGmTtycgJvoTeTCQRPfDYkcgJ7oKf4tGea-W7z5UAVa2RduJM9ZoM6YtJX7jgDm__PvvqcE0knJUF87XHBzdcOjoDF-CUze9xDJgNBlvPbJqVshKrwoqSYpePSDH9GUCNKxGequW3Ma8GvlFfhwd0rK6IZG-XWkyk0XSWQIGkDSjAvhB1wsOusCCguDjbpVZpaW5MMyTkmx68pl6qlIKT5UCcrVPl4ix5ZEj91mUDF0O1t04haD7VZuLVFXVGmqtFrBKI76sdYN-zkokaa1_chPRTyqMQFlqu_8LD6-RiK3UccGM-dEmnX72i91NP9F9OK0WJr9Cheup1C_P0mjqAO4Cb8oIHm0Oxz_mRqv5QbTGJtb3xwPLPuVjVCiE4gGBcuU2ixpSVf5HUF7y1KicVMCKiX9ATCBtg8sTdQZQnPEtHcHHAvdsnDVwev1LGfqA-Gdvg="
|
|
|
|
func TestStatefulCallback(t *testing.T) {
|
|
opts := config.NewDefaultOptions()
|
|
opts.SharedKey = "80ldlrU2d7w+wVpKNfevk6fmb8otEx6CqOfshj2LwhQ="
|
|
tests := []struct {
|
|
name string
|
|
|
|
qp map[string]string
|
|
validSignature bool
|
|
cipher encoding.MarshalUnmarshaler
|
|
sessionStore sessions.SessionStore
|
|
|
|
wantErrorMsg string
|
|
}{
|
|
{
|
|
"good",
|
|
map[string]string{urlutil.QueryCallbackURI: "ok", urlutil.QuerySessionEncrypted: goodEncryptionString},
|
|
true,
|
|
&mock.Encoder{MarshalResponse: []byte("x")},
|
|
&mstore.Store{Session: &sessions.State{}},
|
|
"",
|
|
},
|
|
{
|
|
"good programmatic",
|
|
map[string]string{urlutil.QueryIsProgrammatic: "true", urlutil.QueryCallbackURI: "ok", urlutil.QuerySessionEncrypted: goodEncryptionString},
|
|
true,
|
|
&mock.Encoder{MarshalResponse: []byte("x")},
|
|
&mstore.Store{Session: &sessions.State{}},
|
|
"",
|
|
},
|
|
{
|
|
"invalid signature",
|
|
map[string]string{urlutil.QueryCallbackURI: "ok", urlutil.QuerySessionEncrypted: goodEncryptionString},
|
|
false,
|
|
&mock.Encoder{MarshalResponse: []byte("x")},
|
|
&mstore.Store{Session: &sessions.State{}},
|
|
"Bad Request:",
|
|
},
|
|
{
|
|
"bad decrypt",
|
|
map[string]string{urlutil.QuerySessionEncrypted: "KBEjQ9rnCxaAX-GOqexGw9ivEQURqts3zZ2mNGy0wnVa3SbtM399KlBq2nZ-9wM21FfsZX52er4jlmC7kPEKM3P7uZ41zR0zeys1-_74a5tQp-vsf1WXZfRsgVOuBcWPkMiWEoc379JFHxGDudp5VhU8B-dcQt4f3_PtLTHARkuH54io1Va2gNMq4Hiy8sQ1MPGCQeltH_JMzzdDpXdmdusWrXUvCGkba24muvAV06D8XRVJj6Iu9eK94qFnqcHc7wzziEbb8ADBues9dwbtb6jl8vMWz5rN6XvXqA5YpZv_MQZlsrO4oXFFQDevdgB84cX1tVbVu6qZvK_yQBZqzpOjWA9uIaoSENMytoXuWAlFO_sXjswfX8JTNdGwzB7qQRNPqxVG_sM_tzY3QhPm8zqwEzsXG5DokxZfVt2I5WJRUEovFDb4BnK9KFnnkEzLEdMudixVnXeGmTtycgJvoTeTCQRPfDYkcgJ7oKf4tGea-W7z5UAVa2RduJM9ZoM6YtJX7jgDm__PvvqcE0knJUF87XHBzdcOjoDF-CUze9xDJgNBlvPbJqVshKrwoqSYpePSDH9GUCNKxGequW3Ma8GvlFfhwd0rK6IZG-XWkyk0XSWQIGkDSjAvhB1wsOusCCguDjbpVZpaW5MMyTkmx68pl6qlIKT5UCcrVPl4ix5ZEj91mUDF0O1t04haD7VZuLVFXVGmqtFrBKI76sdYN-zkokaa1_chPRTyqMQFlqu_8LD6-RiK3UccGM-dEmnX72i91NP9F9OK0WJr9Cheup1C_P0mjqAO4Cb8oIHm0Oxz_mRqv5QbTGJtb3xwPLPuVjVCiE4gGBcuU2ixpSVf5HUF7y1KicVMCKiX9ATCBtg8sTdQZQnPEtHcHHAvdsnDVwev1LGfqA-Gdvg="},
|
|
true,
|
|
&mock.Encoder{MarshalResponse: []byte("x")},
|
|
&mstore.Store{Session: &sessions.State{}},
|
|
"proxy: callback token decrypt error:",
|
|
},
|
|
{
|
|
"bad save session",
|
|
map[string]string{urlutil.QuerySessionEncrypted: goodEncryptionString},
|
|
true,
|
|
&mock.Encoder{MarshalResponse: []byte("x")},
|
|
&mstore.Store{SaveError: errors.New("hi")},
|
|
"Internal Server Error: proxy: error saving session state:",
|
|
},
|
|
{
|
|
"bad base64",
|
|
map[string]string{urlutil.QuerySessionEncrypted: "^"},
|
|
true,
|
|
&mock.Encoder{MarshalResponse: []byte("x")},
|
|
&mstore.Store{Session: &sessions.State{}},
|
|
"proxy: malfromed callback token:",
|
|
},
|
|
{
|
|
"malformed redirect",
|
|
nil,
|
|
true,
|
|
&mock.Encoder{},
|
|
&mstore.Store{Session: &sessions.State{}},
|
|
"Bad Request:",
|
|
},
|
|
}
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
flow, err := NewStateful(&config.Config{Options: opts}, tt.sessionStore)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
flow.sharedEncoder = tt.cipher
|
|
redirectURI := &url.URL{Scheme: "http", Host: "example.com", Path: "/"}
|
|
queryString := redirectURI.Query()
|
|
for k, v := range tt.qp {
|
|
queryString.Set(k, v)
|
|
}
|
|
redirectURI.RawQuery = queryString.Encode()
|
|
|
|
uri := &url.URL{Scheme: "https", Host: "example.com", Path: "/"}
|
|
if tt.qp != nil {
|
|
qu := uri.Query()
|
|
for k, v := range tt.qp {
|
|
qu.Set(k, v)
|
|
}
|
|
qu.Set(urlutil.QueryRedirectURI, redirectURI.String())
|
|
uri.RawQuery = qu.Encode()
|
|
}
|
|
if tt.validSignature {
|
|
sharedKey, _ := opts.GetSharedKey()
|
|
uri = urlutil.NewSignedURL(sharedKey, uri).Sign()
|
|
}
|
|
|
|
r := httptest.NewRequest(http.MethodGet, uri.String(), nil)
|
|
// fmt.Println(uri.String())
|
|
r.Host = r.URL.Host
|
|
|
|
r.Header.Set("Accept", "application/json")
|
|
|
|
w := httptest.NewRecorder()
|
|
err = flow.Callback(w, r)
|
|
if tt.wantErrorMsg == "" {
|
|
if err != nil {
|
|
t.Errorf("unexpected error: %v", err)
|
|
}
|
|
} else {
|
|
if err == nil || !strings.Contains(err.Error(), tt.wantErrorMsg) {
|
|
t.Errorf("expected error containing %q; got %v", tt.wantErrorMsg, err)
|
|
}
|
|
}
|
|
|
|
// XXX: assert redirect URL
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestStatefulRevokeSession(t *testing.T) {
|
|
opts := config.NewDefaultOptions()
|
|
flow, err := NewStateful(&config.Config{Options: opts}, nil)
|
|
require.NoError(t, err)
|
|
|
|
ctrl := gomock.NewController(t)
|
|
client := mock_databroker.NewMockDataBrokerServiceClient(ctrl)
|
|
flow.dataBrokerClient = client
|
|
|
|
// Exercise the happy path (no errors): calling RevokeSession() should
|
|
// fetch and delete a session record from the databroker and make a request
|
|
// to the identity provider to revoke the corresponding OAuth2 token.
|
|
|
|
ctx := context.Background()
|
|
authenticator := &mockAuthenticator{}
|
|
sessionState := &sessions.State{ID: "session-id"}
|
|
tokenExpiry := time.Date(2024, 1, 1, 12, 0, 0, 0, time.UTC)
|
|
|
|
client.EXPECT().Get(ctx, protoEqualMatcher{
|
|
&databroker.GetRequest{
|
|
Type: "type.googleapis.com/session.Session",
|
|
Id: "session-id",
|
|
},
|
|
}).Return(&databroker.GetResponse{
|
|
Record: &databroker.Record{
|
|
Version: 123456,
|
|
Type: "type.googleapis.com/session.Session",
|
|
Id: "session-id",
|
|
Data: protoutil.NewAny(&session.Session{
|
|
Id: "session-id",
|
|
UserId: "user-id",
|
|
IdToken: &session.IDToken{
|
|
Raw: "[raw-id-token]",
|
|
},
|
|
OauthToken: &session.OAuthToken{
|
|
AccessToken: "[oauth-access-token]",
|
|
TokenType: "Bearer",
|
|
RefreshToken: "[oauth-refresh-token]",
|
|
ExpiresAt: timestamppb.New(tokenExpiry),
|
|
},
|
|
}),
|
|
},
|
|
}, nil)
|
|
|
|
client.EXPECT().Put(ctx, gomock.Any()).DoAndReturn(
|
|
func(_ context.Context, r *databroker.PutRequest, _ ...grpc.CallOption) (*databroker.PutResponse, error) {
|
|
require.Len(t, r.Records, 1)
|
|
record := r.GetRecord()
|
|
assert.Equal(t, "type.googleapis.com/session.Session", record.Type)
|
|
assert.Equal(t, "session-id", record.Id)
|
|
assert.Equal(t, uint64(123456), record.Version)
|
|
|
|
// The session record received in this PutRequest should have a
|
|
// DeletedAt timestamp, as well as the same session ID and user ID
|
|
// as was returned in the previous GetResponse.
|
|
assert.NotNil(t, record.DeletedAt)
|
|
var s session.Session
|
|
record.GetData().UnmarshalTo(&s)
|
|
assert.Equal(t, "session-id", s.Id)
|
|
assert.Equal(t, "user-id", s.UserId)
|
|
return nil, nil
|
|
})
|
|
|
|
idToken := flow.RevokeSession(ctx, nil, authenticator, sessionState)
|
|
|
|
assert.Equal(t, "[raw-id-token]", idToken)
|
|
assert.Equal(t, &oauth2.Token{
|
|
AccessToken: "[oauth-access-token]",
|
|
TokenType: "Bearer",
|
|
RefreshToken: "[oauth-refresh-token]",
|
|
Expiry: tokenExpiry,
|
|
}, authenticator.revokedToken)
|
|
}
|
|
|
|
// protoEqualMatcher implements gomock.Matcher using proto.Equal.
|
|
// TODO: move this to a testutil package?
|
|
type protoEqualMatcher struct {
|
|
expected proto.Message
|
|
}
|
|
|
|
func (m protoEqualMatcher) Matches(x interface{}) bool {
|
|
p, ok := x.(proto.Message)
|
|
if !ok {
|
|
return false
|
|
}
|
|
return proto.Equal(m.expected, p)
|
|
}
|
|
|
|
func (m protoEqualMatcher) String() string {
|
|
return fmt.Sprintf("is equal to %v (%T)", m.expected, m.expected)
|
|
}
|
|
|
|
type mockAuthenticator struct {
|
|
identity.Authenticator
|
|
|
|
revokedToken *oauth2.Token
|
|
revokeError error
|
|
}
|
|
|
|
func (a *mockAuthenticator) Revoke(_ context.Context, token *oauth2.Token) error {
|
|
a.revokedToken = token
|
|
return a.revokeError
|
|
}
|