diff --git a/internal/identity/manager/data.go b/internal/identity/manager/data.go index a397d5a45..2c8637102 100644 --- a/internal/identity/manager/data.go +++ b/internal/identity/manager/data.go @@ -54,6 +54,9 @@ func (u *User) UnmarshalJSON(data []byte) error { // A Session is a session managed by the Manager. type Session struct { *session.Session + // lastRefresh is the time of the last refresh attempt (which may or may + // not have succeeded), or else the time the Manager first became aware of + // the session (if it has not yet attempted to refresh this session). lastRefresh time.Time // gracePeriod is the amount of time before expiration to attempt a refresh. gracePeriod time.Duration diff --git a/internal/identity/manager/manager.go b/internal/identity/manager/manager.go index 02ddb635a..b9a2c45a9 100644 --- a/internal/identity/manager/manager.go +++ b/internal/identity/manager/manager.go @@ -107,6 +107,10 @@ func (mgr *Manager) GetDataBrokerServiceClient() databroker.DataBrokerServiceCli return mgr.cfg.Load().dataBrokerClient } +func (mgr *Manager) now() time.Time { + return mgr.cfg.Load().now() +} + func (mgr *Manager) refreshLoop(ctx context.Context, update <-chan updateRecordsMessage, clear <-chan struct{}) error { // wait for initial sync select { @@ -145,7 +149,7 @@ func (mgr *Manager) refreshLoop(ctx context.Context, update <-chan updateRecords case <-timer.C: } - now := time.Now() + now := mgr.now() nextTime = now.Add(maxWait) // refresh sessions @@ -182,6 +186,15 @@ func (mgr *Manager) refreshLoop(ctx context.Context, update <-chan updateRecords } } +// refreshSession handles two distinct session lifecycle events: +// +// 1. If the session itself has expired, delete the session. +// 2. If the session's underlying OAuth2 access token is nearing expiration +// (but the session itself is still valid), refresh the access token. +// +// After a successful access token refresh, this method will also trigger a +// user info refresh. If an access token refresh or a user info refresh fails +// with a permanent error, the session will be deleted. func (mgr *Manager) refreshSession(ctx context.Context, userID, sessionID string) { log.Info(ctx). Str("user_id", userID). @@ -208,7 +221,7 @@ func (mgr *Manager) refreshSession(ctx context.Context, userID, sessionID string } expiry := s.GetExpiresAt().AsTime() - if !expiry.After(time.Now()) { + if !expiry.After(mgr.now()) { log.Info(ctx). Str("user_id", userID). Str("session_id", sessionID). @@ -262,8 +275,7 @@ func (mgr *Manager) refreshSession(ctx context.Context, userID, sessionID string return } - res, err := session.Put(ctx, mgr.cfg.Load().dataBrokerClient, s.Session) - if err != nil { + if _, err := session.Put(ctx, mgr.cfg.Load().dataBrokerClient, s.Session); err != nil { log.Error(ctx).Err(err). Str("user_id", s.GetUserId()). Str("session_id", s.GetId()). @@ -271,7 +283,9 @@ func (mgr *Manager) refreshSession(ctx context.Context, userID, sessionID string return } - mgr.onUpdateSession(ctx, res.GetRecord(), s.Session) + s.lastRefresh = mgr.now() + mgr.sessions.ReplaceOrInsert(s) + mgr.sessionScheduler.Add(s.NextRefresh(), toSessionSchedulerKey(userID, sessionID)) } func (mgr *Manager) refreshUser(ctx context.Context, userID string) { @@ -291,7 +305,7 @@ func (mgr *Manager) refreshUser(ctx context.Context, userID string) { Msg("no user found for refresh") return } - u.lastRefresh = time.Now() + u.lastRefresh = mgr.now() mgr.userScheduler.Add(u.NextRefresh(), u.GetId()) for _, s := range mgr.sessions.GetSessionsForUser(userID) { @@ -343,7 +357,7 @@ func (mgr *Manager) onUpdateRecords(ctx context.Context, msg updateRecordsMessag log.Warn(ctx).Msgf("error unmarshaling session: %s", err) continue } - mgr.onUpdateSession(ctx, record, &pbSession) + mgr.onUpdateSession(record, &pbSession) case grpcutil.GetTypeURL(new(user.User)): var pbUser user.User err := record.GetData().UnmarshalTo(&pbUser) @@ -356,7 +370,7 @@ func (mgr *Manager) onUpdateRecords(ctx context.Context, msg updateRecordsMessag } } -func (mgr *Manager) onUpdateSession(_ context.Context, record *databroker.Record, session *session.Session) { +func (mgr *Manager) onUpdateSession(record *databroker.Record, session *session.Session) { mgr.sessionScheduler.Remove(toSessionSchedulerKey(session.GetUserId(), session.GetId())) if record.GetDeletedAt() != nil { @@ -366,7 +380,9 @@ func (mgr *Manager) onUpdateSession(_ context.Context, record *databroker.Record // update session s, _ := mgr.sessions.Get(session.GetUserId(), session.GetId()) - s.lastRefresh = time.Now() + if s.lastRefresh.IsZero() { + s.lastRefresh = mgr.now() + } s.gracePeriod = mgr.cfg.Load().sessionRefreshGracePeriod s.coolOffDuration = mgr.cfg.Load().sessionRefreshCoolOffDuration s.Session = session diff --git a/internal/identity/manager/manager_test.go b/internal/identity/manager/manager_test.go index 1a4ed386b..97f37fa87 100644 --- a/internal/identity/manager/manager_test.go +++ b/internal/identity/manager/manager_test.go @@ -3,6 +3,7 @@ package manager import ( "context" "errors" + "fmt" "testing" "time" @@ -24,18 +25,23 @@ import ( "github.com/pomerium/pomerium/pkg/protoutil" ) -type mockAuthenticator struct{} - -func (mock mockAuthenticator) Refresh(_ context.Context, _ *oauth2.Token, _ identity.State) (*oauth2.Token, error) { - return nil, errors.New("update session") +type mockAuthenticator struct { + refreshResult *oauth2.Token + refreshError error + revokeError error + updateUserInfoError error } -func (mock mockAuthenticator) Revoke(_ context.Context, _ *oauth2.Token) error { - return errors.New("not implemented") +func (mock *mockAuthenticator) Refresh(_ context.Context, _ *oauth2.Token, _ identity.State) (*oauth2.Token, error) { + return mock.refreshResult, mock.refreshError } -func (mock mockAuthenticator) UpdateUserInfo(_ context.Context, _ *oauth2.Token, _ any) error { - return errors.New("update user info") +func (mock *mockAuthenticator) Revoke(_ context.Context, _ *oauth2.Token) error { + return mock.revokeError +} + +func (mock *mockAuthenticator) UpdateUserInfo(_ context.Context, _ *oauth2.Token, _ any) error { + return mock.updateUserInfoError } func TestManager_refresh(t *testing.T) { @@ -86,6 +92,9 @@ func TestManager_onUpdateRecords(t *testing.T) { }) if _, ok := mgr.sessions.Get("user1", "session1"); assert.True(t, ok) { + tm, id := mgr.sessionScheduler.Next() + assert.Equal(t, now.Add(10*time.Second), tm) + assert.Equal(t, "user1\037session1", id) } if _, ok := mgr.users.Get("user1"); assert.True(t, ok) { tm, id := mgr.userScheduler.Next() @@ -94,6 +103,159 @@ func TestManager_onUpdateRecords(t *testing.T) { } } +func TestManager_onUpdateSession(t *testing.T) { + startTime := time.Date(2023, 10, 19, 12, 0, 0, 0, time.UTC) + + s := &session.Session{ + Id: "session-id", + UserId: "user-id", + OauthToken: &session.OAuthToken{ + AccessToken: "access-token", + ExpiresAt: timestamppb.New(startTime.Add(5 * time.Minute)), + }, + IssuedAt: timestamppb.New(startTime), + ExpiresAt: timestamppb.New(startTime.Add(24 * time.Hour)), + } + + assertNextScheduled := func(t *testing.T, mgr *Manager, expectedTime time.Time) { + t.Helper() + tm, key := mgr.sessionScheduler.Next() + assert.Equal(t, expectedTime, tm) + assert.Equal(t, "user-id\037session-id", key) + } + + t.Run("initial refresh event when not expiring soon", func(t *testing.T) { + now := startTime + mgr := New(WithNow(func() time.Time { return now })) + + // When the Manager first becomes aware of a session it should schedule + // a refresh event for one minute before access token expiration. + mgr.onUpdateSession(mkRecord(s), s) + assertNextScheduled(t, mgr, startTime.Add(4*time.Minute)) + }) + t.Run("initial refresh event when expiring soon", func(t *testing.T) { + now := startTime + mgr := New(WithNow(func() time.Time { return now })) + + // When the Manager first becomes aware of a session, if that session + // is expiring within the gracePeriod (1 minute), it should schedule a + // refresh event for as soon as possible, subject to the + // coolOffDuration (10 seconds). + now = now.Add(4*time.Minute + 30*time.Second) // 30 s before expiration + mgr.onUpdateSession(mkRecord(s), s) + assertNextScheduled(t, mgr, now.Add(10*time.Second)) + }) + t.Run("update near scheduled refresh", func(t *testing.T) { + now := startTime + mgr := New(WithNow(func() time.Time { return now })) + + mgr.onUpdateSession(mkRecord(s), s) + assertNextScheduled(t, mgr, startTime.Add(4*time.Minute)) + + // If a session is updated close to the time when it is scheduled to be + // refreshed, the scheduled refresh event should not be pushed back. + now = now.Add(3*time.Minute + 55*time.Second) // 5 s before refresh + mgr.onUpdateSession(mkRecord(s), s) + assertNextScheduled(t, mgr, now.Add(5*time.Second)) + + // However, if an update changes the access token validity, the refresh + // event should be rescheduled accordingly. (This should be uncommon, + // as only the refresh loop itself should modify the access token.) + s2 := proto.Clone(s).(*session.Session) + s2.OauthToken.ExpiresAt = timestamppb.New(now.Add(5 * time.Minute)) + mgr.onUpdateSession(mkRecord(s2), s2) + assertNextScheduled(t, mgr, now.Add(4*time.Minute)) + }) + t.Run("session record deleted", func(t *testing.T) { + now := startTime + mgr := New(WithNow(func() time.Time { return now })) + + mgr.onUpdateSession(mkRecord(s), s) + assertNextScheduled(t, mgr, startTime.Add(4*time.Minute)) + + // If a session is deleted, any scheduled refresh event should be canceled. + record := mkRecord(s) + record.DeletedAt = timestamppb.New(now) + mgr.onUpdateSession(record, s) + _, key := mgr.sessionScheduler.Next() + assert.Empty(t, key) + }) +} + +func TestManager_refreshSession(t *testing.T) { + startTime := time.Date(2023, 10, 19, 12, 0, 0, 0, time.UTC) + + var auth mockAuthenticator + + ctrl := gomock.NewController(t) + client := mock_databroker.NewMockDataBrokerServiceClient(ctrl) + + now := startTime + mgr := New( + WithDataBrokerClient(client), + WithNow(func() time.Time { return now }), + WithAuthenticator(&auth), + ) + + // Initialize the Manager with a new session. + s := &session.Session{ + Id: "session-id", + UserId: "user-id", + OauthToken: &session.OAuthToken{ + AccessToken: "access-token", + ExpiresAt: timestamppb.New(startTime.Add(5 * time.Minute)), + RefreshToken: "refresh-token", + }, + IssuedAt: timestamppb.New(startTime), + ExpiresAt: timestamppb.New(startTime.Add(24 * time.Hour)), + } + mgr.sessions.ReplaceOrInsert(Session{ + Session: s, + lastRefresh: startTime, + gracePeriod: time.Minute, + coolOffDuration: 10 * time.Second, + }) + + // After a success token refresh, the manager should schedule another + // refresh event. + now = now.Add(4 * time.Minute) + auth.refreshResult, auth.refreshError = &oauth2.Token{ + AccessToken: "new-access-token", + RefreshToken: "new-refresh-token", + Expiry: now.Add(5 * time.Minute), + }, nil + expectedSession := proto.Clone(s).(*session.Session) + expectedSession.OauthToken = &session.OAuthToken{ + AccessToken: "new-access-token", + ExpiresAt: timestamppb.New(now.Add(5 * time.Minute)), + RefreshToken: "new-refresh-token", + } + client.EXPECT().Put(gomock.Any(), + objectsAreEqualMatcher{&databroker.PutRequest{Records: []*databroker.Record{{ + Type: "type.googleapis.com/session.Session", + Id: "session-id", + Data: protoutil.NewAny(expectedSession), + }}}}). + Return(nil /* this result is currently unused */, nil) + mgr.refreshSession(context.Background(), "user-id", "session-id") + + tm, key := mgr.sessionScheduler.Next() + assert.Equal(t, now.Add(4*time.Minute), tm) + assert.Equal(t, "user-id\037session-id", key) +} + +type objectsAreEqualMatcher struct { + expected interface{} +} + +func (m objectsAreEqualMatcher) Matches(x interface{}) bool { + return assert.ObjectsAreEqual(m.expected, x) +} + +func (m objectsAreEqualMatcher) String() string { + return fmt.Sprintf("is equal to %v (%T)", m.expected, m.expected) +} + func TestManager_reportErrors(t *testing.T) { ctrl := gomock.NewController(t) @@ -135,7 +297,10 @@ func TestManager_reportErrors(t *testing.T) { mgr := New( WithEventManager(evtMgr), WithDataBrokerClient(client), - WithAuthenticator(mockAuthenticator{}), + WithAuthenticator(&mockAuthenticator{ + refreshError: errors.New("update session"), + updateUserInfoError: errors.New("update user info"), + }), ) mgr.onUpdateRecords(ctx, updateRecordsMessage{