mirror of
https://github.com/pomerium/pomerium.git
synced 2025-04-29 10:26:29 +02:00
This also replaces instances where we manually write "return ctx.Err()" with "return context.Cause(ctx)" which is functionally identical, but will also correctly propagate cause errors if present.
511 lines
20 KiB
Go
511 lines
20 KiB
Go
package authenticateflow
|
|
|
|
import (
|
|
"context"
|
|
"encoding/base64"
|
|
"errors"
|
|
"fmt"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"net/url"
|
|
"strings"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/go-jose/go-jose/v3/jwt"
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/require"
|
|
"go.uber.org/mock/gomock"
|
|
"golang.org/x/oauth2"
|
|
"google.golang.org/grpc"
|
|
"google.golang.org/protobuf/proto"
|
|
"google.golang.org/protobuf/types/known/structpb"
|
|
"google.golang.org/protobuf/types/known/timestamppb"
|
|
|
|
"github.com/pomerium/pomerium/config"
|
|
"github.com/pomerium/pomerium/internal/encoding"
|
|
"github.com/pomerium/pomerium/internal/encoding/mock"
|
|
"github.com/pomerium/pomerium/internal/sessions"
|
|
mstore "github.com/pomerium/pomerium/internal/sessions/mock"
|
|
"github.com/pomerium/pomerium/internal/testutil"
|
|
"github.com/pomerium/pomerium/internal/urlutil"
|
|
"github.com/pomerium/pomerium/pkg/cryptutil"
|
|
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
|
"github.com/pomerium/pomerium/pkg/grpc/databroker/mock_databroker"
|
|
"github.com/pomerium/pomerium/pkg/grpc/session"
|
|
"github.com/pomerium/pomerium/pkg/grpc/user"
|
|
"github.com/pomerium/pomerium/pkg/identity"
|
|
"github.com/pomerium/pomerium/pkg/protoutil"
|
|
)
|
|
|
|
func TestStatefulSignIn(t *testing.T) {
|
|
opts := config.NewDefaultOptions()
|
|
tests := []struct {
|
|
name string
|
|
|
|
host string
|
|
qp map[string]string
|
|
validSignature bool
|
|
|
|
session *sessions.State
|
|
encoder encoding.MarshalUnmarshaler
|
|
saveError error
|
|
|
|
wantErrorMsg string
|
|
wantRedirectBaseURL string
|
|
}{
|
|
{"good", "corp.example.example", map[string]string{urlutil.QueryRedirectURI: "https://dst.some.example/"}, true, &sessions.State{}, &mock.Encoder{}, nil, "", "https://dst.some.example/.pomerium/callback/"},
|
|
{"good alternate port", "corp.example.example:8443", map[string]string{urlutil.QueryRedirectURI: "https://dst.some.example/"}, true, &sessions.State{}, &mock.Encoder{}, nil, "", "https://dst.some.example/.pomerium/callback/"},
|
|
{"invalid signature", "corp.example.example", map[string]string{urlutil.QueryRedirectURI: "https://dst.some.example/"}, false, &sessions.State{}, &mock.Encoder{}, nil, "Bad Request:", ""},
|
|
{"bad redirect uri query", "corp.example.example", map[string]string{urlutil.QueryRedirectURI: "^^^"}, true, &sessions.State{}, &mock.Encoder{}, nil, "Bad Request:", ""},
|
|
{"bad marshal", "corp.example.example", map[string]string{urlutil.QueryRedirectURI: "https://dst.some.example/"}, true, &sessions.State{}, &mock.Encoder{MarshalError: errors.New("error")}, nil, "Bad Request: error", ""},
|
|
{"good with different programmatic redirect", "corp.example.example", map[string]string{urlutil.QueryRedirectURI: "https://dst.some.example/", urlutil.QueryCallbackURI: "https://some.example"}, true, &sessions.State{}, &mock.Encoder{}, nil, "", "https://some.example"},
|
|
{"encrypted encoder error", "corp.example.example", map[string]string{urlutil.QueryRedirectURI: "https://dst.some.example/", urlutil.QueryCallbackURI: "https://some.example"}, true, &sessions.State{}, &mock.Encoder{MarshalError: errors.New("error")}, nil, "Bad Request: error", ""},
|
|
{"good with callback uri set", "corp.example.example", map[string]string{urlutil.QueryCallbackURI: "https://some.example/", urlutil.QueryRedirectURI: "https://dst.some.example/"}, true, &sessions.State{}, &mock.Encoder{}, nil, "", "https://some.example/"},
|
|
{"bad callback uri set", "corp.example.example", map[string]string{urlutil.QueryCallbackURI: "^", urlutil.QueryRedirectURI: "https://dst.some.example/"}, true, &sessions.State{}, &mock.Encoder{}, nil, "Bad Request:", ""},
|
|
{"good programmatic request", "corp.example.example", map[string]string{urlutil.QueryIsProgrammatic: "true", urlutil.QueryRedirectURI: "https://dst.some.example/"}, true, &sessions.State{}, &mock.Encoder{}, nil, "", "https://dst.some.example/.pomerium/callback/"},
|
|
}
|
|
for _, tt := range tests {
|
|
tt := tt
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
sessionStore := &mstore.Store{SaveError: tt.saveError}
|
|
flow, err := NewStateful(context.Background(), &config.Config{Options: opts}, sessionStore)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
flow.sharedEncoder = tt.encoder
|
|
|
|
uri := &url.URL{Scheme: "https", Host: tt.host}
|
|
queryString := uri.Query()
|
|
for k, v := range tt.qp {
|
|
queryString.Set(k, v)
|
|
}
|
|
uri.RawQuery = queryString.Encode()
|
|
if tt.validSignature {
|
|
sharedKey, _ := opts.GetSharedKey()
|
|
uri = urlutil.NewSignedURL(sharedKey, uri).Sign()
|
|
}
|
|
|
|
r := httptest.NewRequest(http.MethodGet, uri.String(), nil)
|
|
r.Header.Set("Accept", "application/json")
|
|
|
|
w := httptest.NewRecorder()
|
|
err = flow.SignIn(w, r, tt.session)
|
|
result := w.Result()
|
|
if tt.wantErrorMsg == "" {
|
|
if err != nil {
|
|
t.Errorf("unexpected error: %v", err)
|
|
}
|
|
expectedStatus := "302 Found"
|
|
if result.Status != expectedStatus {
|
|
t.Errorf("wrong status code: got %v, want %v", result.Status, expectedStatus)
|
|
}
|
|
loc, err := url.Parse(result.Header.Get("Location"))
|
|
if err != nil {
|
|
t.Fatalf("couldn't parse redirect URL: %v", err)
|
|
}
|
|
loc.RawQuery = "" // ignore the query parameters
|
|
if loc.String() != tt.wantRedirectBaseURL {
|
|
t.Errorf("wrong redirect base URL: got %q, want %q",
|
|
loc.String(), tt.wantRedirectBaseURL)
|
|
}
|
|
} else {
|
|
if err == nil || !strings.Contains(err.Error(), tt.wantErrorMsg) {
|
|
t.Errorf("expected error containing %q; got %v", tt.wantErrorMsg, err)
|
|
}
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestStatefulAuthenticateSignInURL(t *testing.T) {
|
|
opts := config.NewDefaultOptions()
|
|
opts.AuthenticateURLString = "https://authenticate.example.com"
|
|
key := cryptutil.NewKey()
|
|
opts.SharedKey = base64.StdEncoding.EncodeToString(key)
|
|
flow, err := NewStateful(context.Background(), &config.Config{Options: opts}, nil)
|
|
require.NoError(t, err)
|
|
|
|
t.Run("NilQueryParams", func(t *testing.T) {
|
|
redirectURL := &url.URL{Scheme: "https", Host: "example.com"}
|
|
u, err := flow.AuthenticateSignInURL(nil, nil, redirectURL, "fake-idp-id")
|
|
assert.NoError(t, err)
|
|
parsed, _ := url.Parse(u)
|
|
assert.NoError(t, urlutil.NewSignedURL(key, parsed).Validate())
|
|
assert.Equal(t, "https", parsed.Scheme)
|
|
assert.Equal(t, "authenticate.example.com", parsed.Host)
|
|
assert.Equal(t, "/.pomerium/sign_in", parsed.Path)
|
|
q := parsed.Query()
|
|
assert.Equal(t, "https://example.com", parsed.Query().Get("pomerium_redirect_uri"))
|
|
assert.Equal(t, "fake-idp-id", q.Get("pomerium_idp_id"))
|
|
})
|
|
t.Run("ExtraQueryParams", func(t *testing.T) {
|
|
redirectURL := &url.URL{Scheme: "https", Host: "example.com"}
|
|
q := url.Values{}
|
|
q.Set("foo", "bar")
|
|
u, err := flow.AuthenticateSignInURL(nil, q, redirectURL, "fake-idp-id")
|
|
assert.NoError(t, err)
|
|
parsed, _ := url.Parse(u)
|
|
assert.NoError(t, urlutil.NewSignedURL(key, parsed).Validate())
|
|
assert.Equal(t, "https", parsed.Scheme)
|
|
assert.Equal(t, "authenticate.example.com", parsed.Host)
|
|
assert.Equal(t, "/.pomerium/sign_in", parsed.Path)
|
|
q = parsed.Query()
|
|
assert.Equal(t, "https://example.com", q.Get("pomerium_redirect_uri"))
|
|
assert.Equal(t, "fake-idp-id", q.Get("pomerium_idp_id"))
|
|
assert.Equal(t, "bar", q.Get("foo"))
|
|
})
|
|
}
|
|
|
|
func TestStatefulGetIdentityProviderIDForURLValues(t *testing.T) {
|
|
flow := Stateful{defaultIdentityProviderID: "default-id"}
|
|
assert.Equal(t, "default-id", flow.GetIdentityProviderIDForURLValues(nil))
|
|
q := url.Values{"pomerium_idp_id": []string{"idp-id"}}
|
|
assert.Equal(t, "idp-id", flow.GetIdentityProviderIDForURLValues(q))
|
|
}
|
|
|
|
// goodEncryptionString is the urlutil.QuerySessionEncrypted query value used by
// the passing cases of TestStatefulCallback. The "bad decrypt" case uses a
// slightly altered copy of it. NOTE(review): presumably encrypted under the
// SharedKey that TestStatefulCallback sets — confirm before changing either.
const goodEncryptionString = "KBEjQ9rnCxaAX-GOqetGw9ivEQURqts3zZ2mNGy0wnVa3SbtM399KlBq2nZ-9wM21FfsZX52er4jlmC7kPEKM3P7uZ41zR0zeys1-_74a5tQp-vsf1WXZfRsgVOuBcWPkMiWEoc379JFHxGDudp5VhU8B-dcQt4f3_PtLTHARkuH54io1Va2gNMq4Hiy8sQ1MPGCQeltH_JMzzdDpXdmdusWrXUvCGkba24muvAV06D8XRVJj6Iu9eK94qFnqcHc7wzziEbb8ADBues9dwbtb6jl8vMWz5rN6XvXqA5YpZv_MQZlsrO4oXFFQDevdgB84cX1tVbVu6qZvK_yQBZqzpOjWA9uIaoSENMytoXuWAlFO_sXjswfX8JTNdGwzB7qQRNPqxVG_sM_tzY3QhPm8zqwEzsXG5DokxZfVt2I5WJRUEovFDb4BnK9KFnnkEzLEdMudixVnXeGmTtycgJvoTeTCQRPfDYkcgJ7oKf4tGea-W7z5UAVa2RduJM9ZoM6YtJX7jgDm__PvvqcE0knJUF87XHBzdcOjoDF-CUze9xDJgNBlvPbJqVshKrwoqSYpePSDH9GUCNKxGequW3Ma8GvlFfhwd0rK6IZG-XWkyk0XSWQIGkDSjAvhB1wsOusCCguDjbpVZpaW5MMyTkmx68pl6qlIKT5UCcrVPl4ix5ZEj91mUDF0O1t04haD7VZuLVFXVGmqtFrBKI76sdYN-zkokaa1_chPRTyqMQFlqu_8LD6-RiK3UccGM-dEmnX72i91NP9F9OK0WJr9Cheup1C_P0mjqAO4Cb8oIHm0Oxz_mRqv5QbTGJtb3xwPLPuVjVCiE4gGBcuU2ixpSVf5HUF7y1KicVMCKiX9ATCBtg8sTdQZQnPEtHcHHAvdsnDVwev1LGfqA-Gdvg="
|
|
|
|
func TestStatefulCallback(t *testing.T) {
|
|
opts := config.NewDefaultOptions()
|
|
opts.SharedKey = "80ldlrU2d7w+wVpKNfevk6fmb8otEx6CqOfshj2LwhQ="
|
|
tests := []struct {
|
|
name string
|
|
|
|
qp map[string]string
|
|
validSignature bool
|
|
cipher encoding.MarshalUnmarshaler
|
|
sessionStore sessions.SessionStore
|
|
|
|
wantErrorMsg string
|
|
}{
|
|
{
|
|
"good",
|
|
map[string]string{urlutil.QueryCallbackURI: "ok", urlutil.QuerySessionEncrypted: goodEncryptionString},
|
|
true,
|
|
&mock.Encoder{MarshalResponse: []byte("x")},
|
|
&mstore.Store{Session: &sessions.State{}},
|
|
"",
|
|
},
|
|
{
|
|
"good programmatic",
|
|
map[string]string{urlutil.QueryIsProgrammatic: "true", urlutil.QueryCallbackURI: "ok", urlutil.QuerySessionEncrypted: goodEncryptionString},
|
|
true,
|
|
&mock.Encoder{MarshalResponse: []byte("x")},
|
|
&mstore.Store{Session: &sessions.State{}},
|
|
"",
|
|
},
|
|
{
|
|
"invalid signature",
|
|
map[string]string{urlutil.QueryCallbackURI: "ok", urlutil.QuerySessionEncrypted: goodEncryptionString},
|
|
false,
|
|
&mock.Encoder{MarshalResponse: []byte("x")},
|
|
&mstore.Store{Session: &sessions.State{}},
|
|
"Bad Request:",
|
|
},
|
|
{
|
|
"bad decrypt",
|
|
map[string]string{urlutil.QuerySessionEncrypted: "KBEjQ9rnCxaAX-GOqexGw9ivEQURqts3zZ2mNGy0wnVa3SbtM399KlBq2nZ-9wM21FfsZX52er4jlmC7kPEKM3P7uZ41zR0zeys1-_74a5tQp-vsf1WXZfRsgVOuBcWPkMiWEoc379JFHxGDudp5VhU8B-dcQt4f3_PtLTHARkuH54io1Va2gNMq4Hiy8sQ1MPGCQeltH_JMzzdDpXdmdusWrXUvCGkba24muvAV06D8XRVJj6Iu9eK94qFnqcHc7wzziEbb8ADBues9dwbtb6jl8vMWz5rN6XvXqA5YpZv_MQZlsrO4oXFFQDevdgB84cX1tVbVu6qZvK_yQBZqzpOjWA9uIaoSENMytoXuWAlFO_sXjswfX8JTNdGwzB7qQRNPqxVG_sM_tzY3QhPm8zqwEzsXG5DokxZfVt2I5WJRUEovFDb4BnK9KFnnkEzLEdMudixVnXeGmTtycgJvoTeTCQRPfDYkcgJ7oKf4tGea-W7z5UAVa2RduJM9ZoM6YtJX7jgDm__PvvqcE0knJUF87XHBzdcOjoDF-CUze9xDJgNBlvPbJqVshKrwoqSYpePSDH9GUCNKxGequW3Ma8GvlFfhwd0rK6IZG-XWkyk0XSWQIGkDSjAvhB1wsOusCCguDjbpVZpaW5MMyTkmx68pl6qlIKT5UCcrVPl4ix5ZEj91mUDF0O1t04haD7VZuLVFXVGmqtFrBKI76sdYN-zkokaa1_chPRTyqMQFlqu_8LD6-RiK3UccGM-dEmnX72i91NP9F9OK0WJr9Cheup1C_P0mjqAO4Cb8oIHm0Oxz_mRqv5QbTGJtb3xwPLPuVjVCiE4gGBcuU2ixpSVf5HUF7y1KicVMCKiX9ATCBtg8sTdQZQnPEtHcHHAvdsnDVwev1LGfqA-Gdvg="},
|
|
true,
|
|
&mock.Encoder{MarshalResponse: []byte("x")},
|
|
&mstore.Store{Session: &sessions.State{}},
|
|
"proxy: callback token decrypt error:",
|
|
},
|
|
{
|
|
"bad save session",
|
|
map[string]string{urlutil.QuerySessionEncrypted: goodEncryptionString},
|
|
true,
|
|
&mock.Encoder{MarshalResponse: []byte("x")},
|
|
&mstore.Store{SaveError: errors.New("hi")},
|
|
"Internal Server Error: proxy: error saving session state:",
|
|
},
|
|
{
|
|
"bad base64",
|
|
map[string]string{urlutil.QuerySessionEncrypted: "^"},
|
|
true,
|
|
&mock.Encoder{MarshalResponse: []byte("x")},
|
|
&mstore.Store{Session: &sessions.State{}},
|
|
"proxy: malfromed callback token:",
|
|
},
|
|
{
|
|
"malformed redirect",
|
|
nil,
|
|
true,
|
|
&mock.Encoder{},
|
|
&mstore.Store{Session: &sessions.State{}},
|
|
"Bad Request:",
|
|
},
|
|
}
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
flow, err := NewStateful(context.Background(), &config.Config{Options: opts}, tt.sessionStore)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
flow.sharedEncoder = tt.cipher
|
|
redirectURI := &url.URL{Scheme: "http", Host: "example.com", Path: "/"}
|
|
queryString := redirectURI.Query()
|
|
for k, v := range tt.qp {
|
|
queryString.Set(k, v)
|
|
}
|
|
redirectURI.RawQuery = queryString.Encode()
|
|
|
|
uri := &url.URL{Scheme: "https", Host: "example.com", Path: "/"}
|
|
if tt.qp != nil {
|
|
qu := uri.Query()
|
|
for k, v := range tt.qp {
|
|
qu.Set(k, v)
|
|
}
|
|
qu.Set(urlutil.QueryRedirectURI, redirectURI.String())
|
|
uri.RawQuery = qu.Encode()
|
|
}
|
|
if tt.validSignature {
|
|
sharedKey, _ := opts.GetSharedKey()
|
|
uri = urlutil.NewSignedURL(sharedKey, uri).Sign()
|
|
}
|
|
|
|
r := httptest.NewRequest(http.MethodGet, uri.String(), nil)
|
|
r.Host = r.URL.Host
|
|
|
|
r.Header.Set("Accept", "application/json")
|
|
|
|
w := httptest.NewRecorder()
|
|
err = flow.Callback(w, r)
|
|
if tt.wantErrorMsg == "" {
|
|
if err != nil {
|
|
t.Errorf("unexpected error: %v", err)
|
|
}
|
|
location, _ := url.Parse(w.Result().Header.Get("Location"))
|
|
assert.Equal(t, "example.com", location.Host)
|
|
assert.Equal(t, "ok", location.Query().Get("pomerium_callback_uri"))
|
|
} else {
|
|
if err == nil || !strings.Contains(err.Error(), tt.wantErrorMsg) {
|
|
t.Errorf("expected error containing %q; got %v", tt.wantErrorMsg, err)
|
|
}
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestStatefulRevokeSession(t *testing.T) {
|
|
opts := config.NewDefaultOptions()
|
|
flow, err := NewStateful(context.Background(), &config.Config{Options: opts}, nil)
|
|
require.NoError(t, err)
|
|
|
|
ctrl := gomock.NewController(t)
|
|
client := mock_databroker.NewMockDataBrokerServiceClient(ctrl)
|
|
flow.dataBrokerClient = client
|
|
|
|
// Exercise the happy path (no errors): calling RevokeSession() should
|
|
// fetch and delete a session record from the databroker and make a request
|
|
// to the identity provider to revoke the corresponding OAuth2 token.
|
|
|
|
ctx := context.Background()
|
|
authenticator := &mockAuthenticator{}
|
|
sessionState := &sessions.State{ID: "session-id"}
|
|
tokenExpiry := time.Date(2024, 1, 1, 12, 0, 0, 0, time.UTC)
|
|
|
|
client.EXPECT().Get(ctx, protoEqualMatcher{
|
|
&databroker.GetRequest{
|
|
Type: "type.googleapis.com/session.Session",
|
|
Id: "session-id",
|
|
},
|
|
}).Return(&databroker.GetResponse{
|
|
Record: &databroker.Record{
|
|
Version: 123456,
|
|
Type: "type.googleapis.com/session.Session",
|
|
Id: "session-id",
|
|
Data: protoutil.NewAny(&session.Session{
|
|
Id: "session-id",
|
|
UserId: "user-id",
|
|
IdToken: &session.IDToken{
|
|
Raw: "[raw-id-token]",
|
|
},
|
|
OauthToken: &session.OAuthToken{
|
|
AccessToken: "[oauth-access-token]",
|
|
TokenType: "Bearer",
|
|
RefreshToken: "[oauth-refresh-token]",
|
|
ExpiresAt: timestamppb.New(tokenExpiry),
|
|
},
|
|
}),
|
|
},
|
|
}, nil)
|
|
|
|
client.EXPECT().Put(ctx, gomock.Any()).DoAndReturn(
|
|
func(_ context.Context, r *databroker.PutRequest, _ ...grpc.CallOption) (*databroker.PutResponse, error) {
|
|
require.Len(t, r.Records, 1)
|
|
record := r.GetRecord()
|
|
assert.Equal(t, "type.googleapis.com/session.Session", record.Type)
|
|
assert.Equal(t, "session-id", record.Id)
|
|
assert.Equal(t, uint64(123456), record.Version)
|
|
|
|
// The session record received in this PutRequest should have a
|
|
// DeletedAt timestamp, as well as the same session ID and user ID
|
|
// as was returned in the previous GetResponse.
|
|
assert.NotNil(t, record.DeletedAt)
|
|
var s session.Session
|
|
record.GetData().UnmarshalTo(&s)
|
|
assert.Equal(t, "session-id", s.Id)
|
|
assert.Equal(t, "user-id", s.UserId)
|
|
return nil, nil
|
|
})
|
|
|
|
idToken := flow.RevokeSession(ctx, nil, authenticator, sessionState)
|
|
|
|
assert.Equal(t, "[raw-id-token]", idToken)
|
|
assert.Equal(t, &oauth2.Token{
|
|
AccessToken: "[oauth-access-token]",
|
|
TokenType: "Bearer",
|
|
RefreshToken: "[oauth-refresh-token]",
|
|
Expiry: tokenExpiry,
|
|
}, authenticator.revokedToken)
|
|
}
|
|
|
|
func TestPersistSession(t *testing.T) {
|
|
timeNow = func() time.Time { return time.Unix(1721965100, 0) }
|
|
t.Cleanup(func() { timeNow = time.Now })
|
|
|
|
opts := config.NewDefaultOptions()
|
|
opts.CookieExpire = 4 * time.Hour
|
|
flow, err := NewStateful(context.Background(), &config.Config{Options: opts}, nil)
|
|
require.NoError(t, err)
|
|
|
|
ctrl := gomock.NewController(t)
|
|
client := mock_databroker.NewMockDataBrokerServiceClient(ctrl)
|
|
flow.dataBrokerClient = client
|
|
|
|
ctx := context.Background()
|
|
|
|
client.EXPECT().Get(ctx, protoEqualMatcher{
|
|
&databroker.GetRequest{
|
|
Type: "type.googleapis.com/user.User",
|
|
Id: "user-id",
|
|
},
|
|
}).Return(&databroker.GetResponse{}, nil)
|
|
|
|
// PersistSession should copy data from the sessions.State,
|
|
// identity.SessionClaims, and oauth2.Token into a Session and User record.
|
|
sessionState := &sessions.State{
|
|
ID: "session-id",
|
|
Subject: "user-id",
|
|
Audience: jwt.Audience{"route.example.com"},
|
|
}
|
|
claims := identity.SessionClaims{
|
|
Claims: map[string]any{
|
|
"name": "John Doe",
|
|
"email": "john.doe@example.com",
|
|
},
|
|
RawIDToken: "e30." + base64.RawURLEncoding.EncodeToString([]byte(`{
|
|
"iss": "https://issuer.example.com",
|
|
"sub": "id-token-user-id",
|
|
"iat": 1721965070,
|
|
"exp": 1721965670
|
|
}`)) + ".fake-signature",
|
|
}
|
|
accessToken := &oauth2.Token{
|
|
AccessToken: "access-token",
|
|
RefreshToken: "refresh-token",
|
|
Expiry: time.Unix(1721965190, 0),
|
|
}
|
|
|
|
expectedClaims := map[string]*structpb.ListValue{
|
|
"name": {Values: []*structpb.Value{structpb.NewStringValue("John Doe")}},
|
|
"email": {Values: []*structpb.Value{structpb.NewStringValue("john.doe@example.com")}},
|
|
}
|
|
|
|
client.EXPECT().Put(ctx, gomock.Any()).DoAndReturn(
|
|
func(_ context.Context, r *databroker.PutRequest, _ ...grpc.CallOption) (*databroker.PutResponse, error) {
|
|
require.Len(t, r.Records, 1)
|
|
record := r.GetRecord()
|
|
assert.Equal(t, "type.googleapis.com/user.User", record.Type)
|
|
assert.Equal(t, "user-id", record.Id)
|
|
assert.Nil(t, record.DeletedAt)
|
|
|
|
// Verify that claims data is populated into the User record.
|
|
var u user.User
|
|
record.GetData().UnmarshalTo(&u)
|
|
assert.Equal(t, "user-id", u.Id)
|
|
assert.Equal(t, expectedClaims, u.Claims)
|
|
|
|
// A real response would include the record, but here we can skip it as it isn't used.
|
|
return &databroker.PutResponse{}, nil
|
|
})
|
|
|
|
client.EXPECT().Put(ctx, gomock.Any()).DoAndReturn(
|
|
func(_ context.Context, r *databroker.PutRequest, _ ...grpc.CallOption) (*databroker.PutResponse, error) {
|
|
require.Len(t, r.Records, 1)
|
|
record := r.GetRecord()
|
|
assert.Equal(t, "type.googleapis.com/session.Session", record.Type)
|
|
assert.Equal(t, "session-id", record.Id)
|
|
assert.Nil(t, record.DeletedAt)
|
|
|
|
var s session.Session
|
|
record.GetData().UnmarshalTo(&s)
|
|
testutil.AssertProtoEqual(t, &session.Session{
|
|
Id: "session-id",
|
|
UserId: "user-id",
|
|
IssuedAt: timestamppb.New(time.Unix(1721965100, 0)),
|
|
AccessedAt: timestamppb.New(time.Unix(1721965100, 0)),
|
|
ExpiresAt: timestamppb.New(time.Unix(1721979500, 0)),
|
|
Audience: []string{"route.example.com"},
|
|
Claims: expectedClaims,
|
|
IdToken: &session.IDToken{
|
|
Issuer: "https://issuer.example.com",
|
|
Subject: "id-token-user-id",
|
|
IssuedAt: ×tamppb.Timestamp{Seconds: 1721965070},
|
|
ExpiresAt: ×tamppb.Timestamp{Seconds: 1721965670},
|
|
Raw: claims.RawIDToken,
|
|
},
|
|
OauthToken: &session.OAuthToken{
|
|
AccessToken: "access-token",
|
|
RefreshToken: "refresh-token",
|
|
ExpiresAt: ×tamppb.Timestamp{Seconds: 1721965190},
|
|
},
|
|
}, &s)
|
|
|
|
return &databroker.PutResponse{
|
|
ServerVersion: 2222,
|
|
Records: []*databroker.Record{{
|
|
Version: 1111,
|
|
Type: "type.googleapis.com/session.Session",
|
|
Id: "session-id",
|
|
Data: protoutil.NewAny(&s),
|
|
}},
|
|
}, nil
|
|
})
|
|
|
|
err = flow.PersistSession(ctx, nil, sessionState, claims, accessToken)
|
|
assert.NoError(t, err)
|
|
assert.Equal(t, uint64(1111), sessionState.DatabrokerRecordVersion)
|
|
assert.Equal(t, uint64(2222), sessionState.DatabrokerServerVersion)
|
|
}
|
|
|
|
// protoEqualMatcher implements gomock.Matcher using proto.Equal.
|
|
// TODO: move this to a testutil package?
|
|
type protoEqualMatcher struct {
|
|
expected proto.Message
|
|
}
|
|
|
|
func (m protoEqualMatcher) Matches(x any) bool {
|
|
p, ok := x.(proto.Message)
|
|
if !ok {
|
|
return false
|
|
}
|
|
return proto.Equal(m.expected, p)
|
|
}
|
|
|
|
func (m protoEqualMatcher) String() string {
|
|
return fmt.Sprintf("is equal to %v (%T)", m.expected, m.expected)
|
|
}
|
|
|
|
type mockAuthenticator struct {
|
|
identity.Authenticator
|
|
|
|
revokedToken *oauth2.Token
|
|
revokeError error
|
|
}
|
|
|
|
func (a *mockAuthenticator) Revoke(_ context.Context, token *oauth2.Token) error {
|
|
a.revokedToken = token
|
|
return a.revokeError
|
|
}
|