diff --git a/internal/authenticateflow/stateful.go b/internal/authenticateflow/stateful.go index 51e0fa512..64030e4f7 100644 --- a/internal/authenticateflow/stateful.go +++ b/internal/authenticateflow/stateful.go @@ -27,6 +27,7 @@ import ( "github.com/pomerium/pomerium/pkg/grpc/databroker" "github.com/pomerium/pomerium/pkg/grpc/session" "github.com/pomerium/pomerium/pkg/grpc/user" + "github.com/pomerium/pomerium/pkg/grpcutil" ) // Stateful implements the stateful authentication flow. In this flow, the @@ -261,15 +262,44 @@ func (s *Stateful) RevokeSession( return "" } + // Note: session.Delete() cannot be used safely, because the identity + // manager expects to be able to read both session ID and user ID from + // deleted session records. Instead, we match the behavior used in the + // identity manager itself: fetch the existing databroker session record, + // explicitly set the DeletedAt timestamp, and Put() that record back. + + res, err := s.dataBrokerClient.Get(ctx, &databroker.GetRequest{ + Type: grpcutil.GetTypeURL(new(session.Session)), + Id: sessionState.ID, + }) + if err != nil { + err = fmt.Errorf("couldn't get session to be revoked: %w", err) + log.Ctx(ctx).Warn().Err(err).Msg("authenticate: failed to revoke access token") + return "" + } + + record := res.GetRecord() + + var sess session.Session + if err := record.GetData().UnmarshalTo(&sess); err != nil { + err = fmt.Errorf("couldn't unmarshal data of session to be revoked: %w", err) + log.Ctx(ctx).Warn().Err(err).Msg("authenticate: failed to revoke access token") + return "" + } + var rawIDToken string - sess, _ := session.Get(ctx, s.dataBrokerClient, sessionState.ID) - if sess != nil && sess.OauthToken != nil { + if sess.OauthToken != nil { rawIDToken = sess.GetIdToken().GetRaw() if err := authenticator.Revoke(ctx, manager.FromOAuthToken(sess.OauthToken)); err != nil { log.Ctx(ctx).Warn().Err(err).Msg("authenticate: failed to revoke access token") } } - if err := session.Delete(ctx, s.dataBrokerClient, sessionState.ID); err != nil { + + record.DeletedAt = timestamppb.Now() + _, err = s.dataBrokerClient.Put(ctx, &databroker.PutRequest{ + Records: []*databroker.Record{record}, + }) + if err != nil { log.Ctx(ctx).Warn().Err(err). Msg("authenticate: failed to delete session from session store") } diff --git a/internal/authenticateflow/stateful_test.go b/internal/authenticateflow/stateful_test.go index 51f1a62a3..a219677be 100644 --- a/internal/authenticateflow/stateful_test.go +++ b/internal/authenticateflow/stateful_test.go @@ -1,24 +1,37 @@ package authenticateflow import ( + "context" "encoding/base64" "errors" + "fmt" "net/http" "net/http/httptest" "net/url" "strings" "testing" + "time" + "github.com/golang/mock/gomock" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "golang.org/x/oauth2" + "google.golang.org/grpc" + "google.golang.org/protobuf/proto" + "google.golang.org/protobuf/types/known/timestamppb" "github.com/pomerium/pomerium/config" "github.com/pomerium/pomerium/internal/encoding" "github.com/pomerium/pomerium/internal/encoding/mock" + "github.com/pomerium/pomerium/internal/identity" "github.com/pomerium/pomerium/internal/sessions" mstore "github.com/pomerium/pomerium/internal/sessions/mock" "github.com/pomerium/pomerium/internal/urlutil" "github.com/pomerium/pomerium/pkg/cryptutil" + "github.com/pomerium/pomerium/pkg/grpc/databroker" + "github.com/pomerium/pomerium/pkg/grpc/databroker/mock_databroker" + "github.com/pomerium/pomerium/pkg/grpc/session" + "github.com/pomerium/pomerium/pkg/protoutil" ) func TestStatefulSignIn(t *testing.T) { @@ -269,3 +282,107 @@ func TestStatefulCallback(t *testing.T) { }) } } + +func TestStatefulRevokeSession(t *testing.T) { + opts := config.NewDefaultOptions() + flow, err := NewStateful(&config.Config{Options: opts}, nil) + require.NoError(t, err) + + ctrl := gomock.NewController(t) + client := mock_databroker.NewMockDataBrokerServiceClient(ctrl) + flow.dataBrokerClient = client + + // Exercise the happy path (no errors): calling RevokeSession() should + // fetch and delete a session record from the databroker and make a request + // to the identity provider to revoke the corresponding OAuth2 token. + + ctx := context.Background() + authenticator := &mockAuthenticator{} + sessionState := &sessions.State{ID: "session-id"} + tokenExpiry := time.Date(2024, 1, 1, 12, 0, 0, 0, time.UTC) + + client.EXPECT().Get(ctx, protoEqualMatcher{ + &databroker.GetRequest{ + Type: "type.googleapis.com/session.Session", + Id: "session-id", + }, + }).Return(&databroker.GetResponse{ + Record: &databroker.Record{ + Version: 123456, + Type: "type.googleapis.com/session.Session", + Id: "session-id", + Data: protoutil.NewAny(&session.Session{ + Id: "session-id", + UserId: "user-id", + IdToken: &session.IDToken{ + Raw: "[raw-id-token]", + }, + OauthToken: &session.OAuthToken{ + AccessToken: "[oauth-access-token]", + TokenType: "Bearer", + RefreshToken: "[oauth-refresh-token]", + ExpiresAt: timestamppb.New(tokenExpiry), + }, + }), + }, + }, nil) + + client.EXPECT().Put(ctx, gomock.Any()).DoAndReturn( + func(_ context.Context, r *databroker.PutRequest, _ ...grpc.CallOption) (*databroker.PutResponse, error) { + require.Len(t, r.Records, 1) + record := r.GetRecord() + assert.Equal(t, "type.googleapis.com/session.Session", record.Type) + assert.Equal(t, "session-id", record.Id) + assert.Equal(t, uint64(123456), record.Version) + + // The session record received in this PutRequest should have a + // DeletedAt timestamp, as well as the same session ID and user ID + // as was returned in the previous GetResponse. + assert.NotNil(t, record.DeletedAt) + var s session.Session + record.GetData().UnmarshalTo(&s) + assert.Equal(t, "session-id", s.Id) + assert.Equal(t, "user-id", s.UserId) + return nil, nil + }) + + idToken := flow.RevokeSession(ctx, nil, authenticator, sessionState) + + assert.Equal(t, "[raw-id-token]", idToken) + assert.Equal(t, &oauth2.Token{ + AccessToken: "[oauth-access-token]", + TokenType: "Bearer", + RefreshToken: "[oauth-refresh-token]", + Expiry: tokenExpiry, + }, authenticator.revokedToken) +} + +// protoEqualMatcher implements gomock.Matcher using proto.Equal. +// TODO: move this to a testutil package? +type protoEqualMatcher struct { + expected proto.Message +} + +func (m protoEqualMatcher) Matches(x interface{}) bool { + p, ok := x.(proto.Message) + if !ok { + return false + } + return proto.Equal(m.expected, p) +} + +func (m protoEqualMatcher) String() string { + return fmt.Sprintf("is equal to %v (%T)", m.expected, m.expected) +} + +type mockAuthenticator struct { + identity.Authenticator + + revokedToken *oauth2.Token + revokeError error +} + +func (a *mockAuthenticator) Revoke(_ context.Context, token *oauth2.Token) error { + a.revokedToken = token + return a.revokeError +}