From 39a477c5106baf4bac53a1d2261b544c772b5250 Mon Sep 17 00:00:00 2001 From: Kenneth Jenkins <51246568+kenjenkins@users.noreply.github.com> Date: Mon, 23 Oct 2023 08:20:04 -0700 Subject: [PATCH 1/4] identity: override TokenSource expiry behavior (#4632) The current session refresh loop attempts to refresh access tokens when they are due to expire in less than one minute. However, the code to perform the refresh relies on a TokenSource from the x/oauth2 package, which has its own internal 'expiryDelta' threshold, with a default of 10 seconds. As a result, the first four or five attempts to refresh a particular access token will not actually refresh the token. The refresh will happen only when the access token is within 10 seconds of expiring. Instead, before we obtain a new TokenSource, first clear any existing access token. This causes the TokenSource to consider the token invalid, triggering a refresh. This should give the refresh loop more control over when refreshes happen. Consolidate this logic in a new Refresh() method in the oidc package. Add unit tests for this new method. --- internal/identity/oauth/apple/apple.go | 11 +--- internal/identity/oidc/oidc.go | 11 +--- internal/identity/oidc/refresh.go | 29 +++++++++++ internal/identity/oidc/refresh_test.go | 71 ++++++++++++++++++++++++++ 4 files changed, 104 insertions(+), 18 deletions(-) create mode 100644 internal/identity/oidc/refresh.go create mode 100644 internal/identity/oidc/refresh_test.go diff --git a/internal/identity/oauth/apple/apple.go b/internal/identity/oauth/apple/apple.go index 53878606f..bed70fa94 100644 --- a/internal/identity/oauth/apple/apple.go +++ b/internal/identity/oauth/apple/apple.go @@ -130,16 +130,9 @@ func (p *Provider) Authenticate(ctx context.Context, code string, v identity.Sta // Refresh renews a user's session. func (p *Provider) Refresh(ctx context.Context, t *oauth2.Token, v identity.State) (*oauth2.Token, error) { - if t == nil { - return nil, oidc.ErrMissingAccessToken - } - if t.RefreshToken == "" { - return nil, oidc.ErrMissingRefreshToken - } - - newToken, err := p.oauth.TokenSource(ctx, t).Token() + newToken, err := oidc.Refresh(ctx, p.oauth, t) if err != nil { - return nil, fmt.Errorf("identity/apple: refresh failed: %w", err) + return nil, err } if rawIDToken, ok := newToken.Extra("id_token").(string); ok { diff --git a/internal/identity/oidc/oidc.go b/internal/identity/oidc/oidc.go index 832a62a85..d13e3634f 100644 --- a/internal/identity/oidc/oidc.go +++ b/internal/identity/oidc/oidc.go @@ -213,16 +213,9 @@ func (p *Provider) Refresh(ctx context.Context, t *oauth2.Token, v identity.Stat return nil, err } - if t == nil { - return nil, ErrMissingAccessToken - } - if t.RefreshToken == "" { - return nil, ErrMissingRefreshToken - } - - newToken, err := oa.TokenSource(ctx, t).Token() + newToken, err := Refresh(ctx, oa, t) if err != nil { - return nil, fmt.Errorf("identity/oidc: refresh failed: %w", err) + return nil, err } // Many identity providers _will not_ return `id_token` on refresh diff --git a/internal/identity/oidc/refresh.go b/internal/identity/oidc/refresh.go new file mode 100644 index 000000000..a2c1a216c --- /dev/null +++ b/internal/identity/oidc/refresh.go @@ -0,0 +1,29 @@ +package oidc + +import ( + "context" + "fmt" + + "golang.org/x/oauth2" +) + +// Refresh requests a new oauth2.Token based on an existing Token and the +// provided Config. The existing Token must contain a refresh token. +func Refresh(ctx context.Context, cfg *oauth2.Config, t *oauth2.Token) (*oauth2.Token, error) { + if t == nil || t.RefreshToken == "" { + return nil, ErrMissingRefreshToken + } + + // Note: the TokenSource returned by oauth2.Config has its own threshold + // for determining when to attempt a refresh. In order to force a refresh + // we can remove the current AccessToken. + t = &oauth2.Token{ + TokenType: t.TokenType, + RefreshToken: t.RefreshToken, + } + newToken, err := cfg.TokenSource(ctx, t).Token() + if err != nil { + return nil, fmt.Errorf("identity/oidc: refresh failed: %w", err) + } + return newToken, nil +} diff --git a/internal/identity/oidc/refresh_test.go b/internal/identity/oidc/refresh_test.go new file mode 100644 index 000000000..6ac7f067b --- /dev/null +++ b/internal/identity/oidc/refresh_test.go @@ -0,0 +1,71 @@ +package oidc + +import ( + "context" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "golang.org/x/oauth2" +) + +func TestRefresh(t *testing.T) { + t.Parallel() + + ctx, clearTimeout := context.WithTimeout(context.Background(), 10*time.Second) + t.Cleanup(clearTimeout) + + s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.Write([]byte(`{ + "access_token": "NEW_TOKEN", + "refresh_token": "NEW_REFRESH_TOKEN", + "expires_in": 3600 + }`)) + })) + t.Cleanup(s.Close) + + cfg := &oauth2.Config{Endpoint: oauth2.Endpoint{TokenURL: s.URL}} + + token := &oauth2.Token{ + AccessToken: "OLD_TOKEN", + RefreshToken: "OLD_REFRESH_TOKEN", + + // Even if a token is not expiring soon, Refresh() should still perform + // the refresh. + Expiry: time.Now().Add(time.Hour), + } + require.True(t, token.Valid()) + + newToken, err := Refresh(ctx, cfg, token) + require.NoError(t, err) + assert.Equal(t, "NEW_TOKEN", newToken.AccessToken) + assert.Equal(t, "NEW_REFRESH_TOKEN", newToken.RefreshToken) +} + +func TestRefresh_errors(t *testing.T) { + t.Parallel() + + ctx, clearTimeout := context.WithTimeout(context.Background(), 10*time.Second) + t.Cleanup(clearTimeout) + + s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("{}")) + })) + t.Cleanup(s.Close) + + cfg := &oauth2.Config{Endpoint: oauth2.Endpoint{TokenURL: s.URL}} + + _, err := Refresh(ctx, cfg, nil) + assert.Equal(t, ErrMissingRefreshToken, err) + + _, err = Refresh(ctx, cfg, &oauth2.Token{}) + assert.Equal(t, ErrMissingRefreshToken, err) + + _, err = Refresh(ctx, cfg, &oauth2.Token{RefreshToken: "REFRESH_TOKEN"}) + assert.Equal(t, "identity/oidc: refresh failed: oauth2: server response missing access_token", + err.Error()) +} From 1996550c541e0c260210aaa4b35e0bfa4ebb664d Mon Sep 17 00:00:00 2001 From: Kenneth Jenkins <51246568+kenjenkins@users.noreply.github.com> Date: Tue, 24 Oct 2023 08:39:10 -0700 Subject: [PATCH 2/4] upgrade envoy to v1.28.0 (#4635) --- scripts/get-envoy.bash | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/get-envoy.bash b/scripts/get-envoy.bash index a1cbc7ddb..6b73735f0 100755 --- a/scripts/get-envoy.bash +++ b/scripts/get-envoy.bash @@ -5,7 +5,7 @@ PATH="$PATH:$(go env GOPATH)/bin" export PATH _project_root="$(cd "$(dirname "${BASH_SOURCE[0]}")" >/dev/null 2>&1 && pwd)/.." -_envoy_version=1.27.1 +_envoy_version=1.28.0 _dir="$_project_root/pkg/envoy/files" for _target in darwin-amd64 darwin-arm64 linux-amd64 linux-arm64; do From fa7dc469a360646c1385e3db6704ed1ffac1f466 Mon Sep 17 00:00:00 2001 From: Kenneth Jenkins <51246568+kenjenkins@users.noreply.github.com> Date: Tue, 24 Oct 2023 14:46:33 -0700 Subject: [PATCH 3/4] identity: preserve session refresh schedule (#4633) The databroker identity manager is responsible for refreshing session records, to account for overall session expiration as well as OAuth2 access token expiration. Refresh events are scheduled subject to a coolOffDuration (10 seconds, by default) relative to a lastRefresh timestamp. Currently, any update to a session record will reset the associated lastRefresh value and reschedule any pending refresh event for that session. If an update occurs close before a scheduled refresh event, this will push back the scheduled refresh event to 10 seconds from that time. This means that if a session is updated frequently enough (e.g. if there is a steady stream of requests that cause constant updates via the AccessTracker), the access token may expire before a refresh ever runs. To avoid this problem, do not update the lastRefresh time upon every session record update, but only if it hasn't yet been set. Instead, update the lastRefresh during the refresh attempt itself. Add unit tests to exercise these changes. There is a now() function as part of the manager configuration (to allow unit tests to set a fake time); update the Manager to use this function throughout. --- internal/identity/manager/data.go | 3 + internal/identity/manager/manager.go | 34 ++-- internal/identity/manager/manager_test.go | 186 ++++++++++++++++++++-- 3 files changed, 205 insertions(+), 18 deletions(-) diff --git a/internal/identity/manager/data.go b/internal/identity/manager/data.go index a397d5a45..2c8637102 100644 --- a/internal/identity/manager/data.go +++ b/internal/identity/manager/data.go @@ -54,6 +54,9 @@ func (u *User) UnmarshalJSON(data []byte) error { // A Session is a session managed by the Manager. type Session struct { *session.Session + // lastRefresh is the time of the last refresh attempt (which may or may + // not have succeeded), or else the time the Manager first became aware of + // the session (if it has not yet attempted to refresh this session). lastRefresh time.Time // gracePeriod is the amount of time before expiration to attempt a refresh. gracePeriod time.Duration diff --git a/internal/identity/manager/manager.go b/internal/identity/manager/manager.go index 02ddb635a..b9a2c45a9 100644 --- a/internal/identity/manager/manager.go +++ b/internal/identity/manager/manager.go @@ -107,6 +107,10 @@ func (mgr *Manager) GetDataBrokerServiceClient() databroker.DataBrokerServiceCli return mgr.cfg.Load().dataBrokerClient } +func (mgr *Manager) now() time.Time { + return mgr.cfg.Load().now() +} + func (mgr *Manager) refreshLoop(ctx context.Context, update <-chan updateRecordsMessage, clear <-chan struct{}) error { // wait for initial sync select { @@ -145,7 +149,7 @@ func (mgr *Manager) refreshLoop(ctx context.Context, update <-chan updateRecords case <-timer.C: } - now := time.Now() + now := mgr.now() nextTime = now.Add(maxWait) // refresh sessions @@ -182,6 +186,15 @@ func (mgr *Manager) refreshLoop(ctx context.Context, update <-chan updateRecords } } +// refreshSession handles two distinct session lifecycle events: +// +// 1. If the session itself has expired, delete the session. +// 2. If the session's underlying OAuth2 access token is nearing expiration +// (but the session itself is still valid), refresh the access token. +// +// After a successful access token refresh, this method will also trigger a +// user info refresh. If an access token refresh or a user info refresh fails +// with a permanent error, the session will be deleted. func (mgr *Manager) refreshSession(ctx context.Context, userID, sessionID string) { log.Info(ctx). Str("user_id", userID). @@ -208,7 +221,7 @@ func (mgr *Manager) refreshSession(ctx context.Context, userID, sessionID string } expiry := s.GetExpiresAt().AsTime() - if !expiry.After(time.Now()) { + if !expiry.After(mgr.now()) { log.Info(ctx). Str("user_id", userID). Str("session_id", sessionID). @@ -262,8 +275,7 @@ func (mgr *Manager) refreshSession(ctx context.Context, userID, sessionID string return } - res, err := session.Put(ctx, mgr.cfg.Load().dataBrokerClient, s.Session) - if err != nil { + if _, err := session.Put(ctx, mgr.cfg.Load().dataBrokerClient, s.Session); err != nil { log.Error(ctx).Err(err). Str("user_id", s.GetUserId()). Str("session_id", s.GetId()). @@ -271,7 +283,9 @@ func (mgr *Manager) refreshSession(ctx context.Context, userID, sessionID string return } - mgr.onUpdateSession(ctx, res.GetRecord(), s.Session) + s.lastRefresh = mgr.now() + mgr.sessions.ReplaceOrInsert(s) + mgr.sessionScheduler.Add(s.NextRefresh(), toSessionSchedulerKey(userID, sessionID)) } func (mgr *Manager) refreshUser(ctx context.Context, userID string) { @@ -291,7 +305,7 @@ func (mgr *Manager) refreshUser(ctx context.Context, userID string) { Msg("no user found for refresh") return } - u.lastRefresh = time.Now() + u.lastRefresh = mgr.now() mgr.userScheduler.Add(u.NextRefresh(), u.GetId()) for _, s := range mgr.sessions.GetSessionsForUser(userID) { @@ -343,7 +357,7 @@ func (mgr *Manager) onUpdateRecords(ctx context.Context, msg updateRecordsMessag log.Warn(ctx).Msgf("error unmarshaling session: %s", err) continue } - mgr.onUpdateSession(ctx, record, &pbSession) + mgr.onUpdateSession(record, &pbSession) case grpcutil.GetTypeURL(new(user.User)): var pbUser user.User err := record.GetData().UnmarshalTo(&pbUser) @@ -356,7 +370,7 @@ func (mgr *Manager) onUpdateRecords(ctx context.Context, msg updateRecordsMessag } } -func (mgr *Manager) onUpdateSession(_ context.Context, record *databroker.Record, session *session.Session) { +func (mgr *Manager) onUpdateSession(record *databroker.Record, session *session.Session) { mgr.sessionScheduler.Remove(toSessionSchedulerKey(session.GetUserId(), session.GetId())) if record.GetDeletedAt() != nil { @@ -366,7 +380,9 @@ func (mgr *Manager) onUpdateSession(_ context.Context, record *databroker.Record // update session s, _ := mgr.sessions.Get(session.GetUserId(), session.GetId()) - s.lastRefresh = time.Now() + if s.lastRefresh.IsZero() { + s.lastRefresh = mgr.now() + } s.gracePeriod = mgr.cfg.Load().sessionRefreshGracePeriod s.coolOffDuration = mgr.cfg.Load().sessionRefreshCoolOffDuration s.Session = session diff --git a/internal/identity/manager/manager_test.go b/internal/identity/manager/manager_test.go index 1a4ed386b..437a72058 100644 --- a/internal/identity/manager/manager_test.go +++ b/internal/identity/manager/manager_test.go @@ -3,6 +3,7 @@ package manager import ( "context" "errors" + "fmt" "testing" "time" @@ -24,18 +25,23 @@ import ( "github.com/pomerium/pomerium/pkg/protoutil" ) -type mockAuthenticator struct{} - -func (mock mockAuthenticator) Refresh(_ context.Context, _ *oauth2.Token, _ identity.State) (*oauth2.Token, error) { - return nil, errors.New("update session") +type mockAuthenticator struct { + refreshResult *oauth2.Token + refreshError error + revokeError error + updateUserInfoError error } -func (mock mockAuthenticator) Revoke(_ context.Context, _ *oauth2.Token) error { - return errors.New("not implemented") +func (mock *mockAuthenticator) Refresh(_ context.Context, _ *oauth2.Token, _ identity.State) (*oauth2.Token, error) { + return mock.refreshResult, mock.refreshError } -func (mock mockAuthenticator) UpdateUserInfo(_ context.Context, _ *oauth2.Token, _ any) error { - return errors.New("update user info") +func (mock *mockAuthenticator) Revoke(_ context.Context, _ *oauth2.Token) error { + return mock.revokeError +} + +func (mock *mockAuthenticator) UpdateUserInfo(_ context.Context, _ *oauth2.Token, _ any) error { + return mock.updateUserInfoError } func TestManager_refresh(t *testing.T) { @@ -86,6 +92,9 @@ func TestManager_onUpdateRecords(t *testing.T) { }) if _, ok := mgr.sessions.Get("user1", "session1"); assert.True(t, ok) { + tm, id := mgr.sessionScheduler.Next() + assert.Equal(t, now.Add(10*time.Second), tm) + assert.Equal(t, "user1\037session1", id) } if _, ok := mgr.users.Get("user1"); assert.True(t, ok) { tm, id := mgr.userScheduler.Next() @@ -94,6 +103,147 @@ func TestManager_onUpdateRecords(t *testing.T) { } } +func TestManager_onUpdateSession(t *testing.T) { + startTime := time.Date(2023, 10, 19, 12, 0, 0, 0, time.UTC) + + s := &session.Session{ + Id: "session-id", + UserId: "user-id", + OauthToken: &session.OAuthToken{ + AccessToken: "access-token", + ExpiresAt: timestamppb.New(startTime.Add(5 * time.Minute)), + }, + IssuedAt: timestamppb.New(startTime), + ExpiresAt: timestamppb.New(startTime.Add(24 * time.Hour)), + } + + assertNextScheduled := func(t *testing.T, mgr *Manager, expectedTime time.Time) { + t.Helper() + tm, key := mgr.sessionScheduler.Next() + assert.Equal(t, expectedTime, tm) + assert.Equal(t, "user-id\037session-id", key) + } + + t.Run("initial refresh event when not expiring soon", func(t *testing.T) { + now := startTime + mgr := New(WithNow(func() time.Time { return now })) + + // When the Manager first becomes aware of a session it should schedule + // a refresh event for one minute before access token expiration. + mgr.onUpdateSession(mkRecord(s), s) + assertNextScheduled(t, mgr, startTime.Add(4*time.Minute)) + }) + t.Run("initial refresh event when expiring soon", func(t *testing.T) { + now := startTime + mgr := New(WithNow(func() time.Time { return now })) + + // When the Manager first becomes aware of a session, if that session + // is expiring within the gracePeriod (1 minute), it should schedule a + // refresh event for as soon as possible, subject to the + // coolOffDuration (10 seconds). + now = now.Add(4*time.Minute + 30*time.Second) // 30 s before expiration + mgr.onUpdateSession(mkRecord(s), s) + assertNextScheduled(t, mgr, now.Add(10*time.Second)) + }) + t.Run("update near scheduled refresh", func(t *testing.T) { + now := startTime + mgr := New(WithNow(func() time.Time { return now })) + + mgr.onUpdateSession(mkRecord(s), s) + assertNextScheduled(t, mgr, startTime.Add(4*time.Minute)) + + // If a session is updated close to the time when it is scheduled to be + // refreshed, the scheduled refresh event should not be pushed back. + now = now.Add(3*time.Minute + 55*time.Second) // 5 s before refresh + mgr.onUpdateSession(mkRecord(s), s) + assertNextScheduled(t, mgr, now.Add(5*time.Second)) + + // However, if an update changes the access token validity, the refresh + // event should be rescheduled accordingly. (This should be uncommon, + // as only the refresh loop itself should modify the access token.) + s2 := proto.Clone(s).(*session.Session) + s2.OauthToken.ExpiresAt = timestamppb.New(now.Add(5 * time.Minute)) + mgr.onUpdateSession(mkRecord(s2), s2) + assertNextScheduled(t, mgr, now.Add(4*time.Minute)) + }) + t.Run("session record deleted", func(t *testing.T) { + now := startTime + mgr := New(WithNow(func() time.Time { return now })) + + mgr.onUpdateSession(mkRecord(s), s) + assertNextScheduled(t, mgr, startTime.Add(4*time.Minute)) + + // If a session is deleted, any scheduled refresh event should be canceled. + record := mkRecord(s) + record.DeletedAt = timestamppb.New(now) + mgr.onUpdateSession(record, s) + _, key := mgr.sessionScheduler.Next() + assert.Empty(t, key) + }) +} + +func TestManager_refreshSession(t *testing.T) { + startTime := time.Date(2023, 10, 19, 12, 0, 0, 0, time.UTC) + + var auth mockAuthenticator + + ctrl := gomock.NewController(t) + client := mock_databroker.NewMockDataBrokerServiceClient(ctrl) + + now := startTime + mgr := New( + WithDataBrokerClient(client), + WithNow(func() time.Time { return now }), + WithAuthenticator(&auth), + ) + + // Initialize the Manager with a new session. + s := &session.Session{ + Id: "session-id", + UserId: "user-id", + OauthToken: &session.OAuthToken{ + AccessToken: "access-token", + ExpiresAt: timestamppb.New(startTime.Add(5 * time.Minute)), + RefreshToken: "refresh-token", + }, + IssuedAt: timestamppb.New(startTime), + ExpiresAt: timestamppb.New(startTime.Add(24 * time.Hour)), + } + mgr.sessions.ReplaceOrInsert(Session{ + Session: s, + lastRefresh: startTime, + gracePeriod: time.Minute, + coolOffDuration: 10 * time.Second, + }) + + // After a success token refresh, the manager should schedule another + // refresh event. + now = now.Add(4 * time.Minute) + auth.refreshResult, auth.refreshError = &oauth2.Token{ + AccessToken: "new-access-token", + RefreshToken: "new-refresh-token", + Expiry: now.Add(5 * time.Minute), + }, nil + expectedSession := proto.Clone(s).(*session.Session) + expectedSession.OauthToken = &session.OAuthToken{ + AccessToken: "new-access-token", + ExpiresAt: timestamppb.New(now.Add(5 * time.Minute)), + RefreshToken: "new-refresh-token", + } + client.EXPECT().Put(gomock.Any(), + objectsAreEqualMatcher{&databroker.PutRequest{Records: []*databroker.Record{{ + Type: "type.googleapis.com/session.Session", + Id: "session-id", + Data: protoutil.NewAny(expectedSession), + }}}}). + Return(nil /* this result is currently unused */, nil) + mgr.refreshSession(context.Background(), "user-id", "session-id") + + tm, key := mgr.sessionScheduler.Next() + assert.Equal(t, now.Add(4*time.Minute), tm) + assert.Equal(t, "user-id\037session-id", key) +} + func TestManager_reportErrors(t *testing.T) { ctrl := gomock.NewController(t) @@ -135,7 +285,10 @@ func TestManager_reportErrors(t *testing.T) { mgr := New( WithEventManager(evtMgr), WithDataBrokerClient(client), - WithAuthenticator(mockAuthenticator{}), + WithAuthenticator(&mockAuthenticator{ + refreshError: errors.New("update session"), + updateUserInfoError: errors.New("update user info"), + }), ) mgr.onUpdateRecords(ctx, updateRecordsMessage{ @@ -172,3 +325,18 @@ type recordable interface { proto.Message GetId() string } + +// objectsAreEqualMatcher implements gomock.Matcher using ObjectsAreEqual. This +// is especially helpful when working with pointers, as it will compare the +// underlying values rather than the pointers themselves. +type objectsAreEqualMatcher struct { + expected interface{} +} + +func (m objectsAreEqualMatcher) Matches(x interface{}) bool { + return assert.ObjectsAreEqual(m.expected, x) +} + +func (m objectsAreEqualMatcher) String() string { + return fmt.Sprintf("is equal to %v (%T)", m.expected, m.expected) +} From 1d2c525b1a672a8ab028cab40c8c5ef41c42aa3f Mon Sep 17 00:00:00 2001 From: Kenneth Jenkins <51246568+kenjenkins@users.noreply.github.com> Date: Tue, 24 Oct 2023 15:44:51 -0700 Subject: [PATCH 4/4] identity: rework session refresh error handling (#4638) Currently, if a temporary error occurs while attempting to refresh an OAuth2 token, the identity manager won't schedule another attempt. Instead, update the session refresh logic so that it will retry after temporary errors. Extract the bulk of this logic into a separate method that returns a boolean indicating whether to schedule another refresh. Update the unit test to simulate a temporary error during OAuth2 token refresh. --- internal/identity/manager/manager.go | 55 +++++++++++++---------- internal/identity/manager/manager_test.go | 17 +++++-- 2 files changed, 46 insertions(+), 26 deletions(-) diff --git a/internal/identity/manager/manager.go b/internal/identity/manager/manager.go index b9a2c45a9..51056b70c 100644 --- a/internal/identity/manager/manager.go +++ b/internal/identity/manager/manager.go @@ -201,16 +201,6 @@ func (mgr *Manager) refreshSession(ctx context.Context, userID, sessionID string Str("session_id", sessionID). Msg("refreshing session") - authenticator := mgr.cfg.Load().authenticator - if authenticator == nil { - log.Info(ctx). - Str("user_id", userID). - Str("session_id", sessionID). - Msg("no authenticator defined, deleting session") - mgr.deleteSession(ctx, userID, sessionID) - return - } - s, ok := mgr.sessions.Get(userID, sessionID) if !ok { log.Warn(ctx). @@ -220,6 +210,29 @@ func (mgr *Manager) refreshSession(ctx context.Context, userID, sessionID string return } + s.lastRefresh = mgr.now() + + if mgr.refreshSessionInternal(ctx, userID, sessionID, &s) { + mgr.sessions.ReplaceOrInsert(s) + mgr.sessionScheduler.Add(s.NextRefresh(), toSessionSchedulerKey(userID, sessionID)) + } +} + +// refreshSessionInternal performs the core refresh logic and returns true if +// the session should be scheduled for refresh again, or false if not. +func (mgr *Manager) refreshSessionInternal( + ctx context.Context, userID, sessionID string, s *Session, +) bool { + authenticator := mgr.cfg.Load().authenticator + if authenticator == nil { + log.Info(ctx). + Str("user_id", userID). + Str("session_id", sessionID). + Msg("no authenticator defined, deleting session") + mgr.deleteSession(ctx, userID, sessionID) + return false + } + expiry := s.GetExpiresAt().AsTime() if !expiry.After(mgr.now()) { log.Info(ctx). @@ -227,7 +240,7 @@ func (mgr *Manager) refreshSession(ctx context.Context, userID, sessionID string Str("session_id", sessionID). Msg("deleting expired session") mgr.deleteSession(ctx, userID, sessionID) - return + return false } if s.Session == nil || s.Session.OauthToken == nil { @@ -235,10 +248,10 @@ func (mgr *Manager) refreshSession(ctx context.Context, userID, sessionID string Str("user_id", userID). Str("session_id", sessionID). Msg("no session oauth2 token found for refresh") - return + return false } - newToken, err := authenticator.Refresh(ctx, FromOAuthToken(s.OauthToken), &s) + newToken, err := authenticator.Refresh(ctx, FromOAuthToken(s.OauthToken), s) metrics.RecordIdentityManagerSessionRefresh(ctx, err) mgr.recordLastError(metrics_ids.IdentityManagerLastSessionRefreshError, err) if isTemporaryError(err) { @@ -246,18 +259,18 @@ func (mgr *Manager) refreshSession(ctx context.Context, userID, sessionID string Str("user_id", s.GetUserId()). Str("session_id", s.GetId()). Msg("failed to refresh oauth2 token") - return + return true } else if err != nil { log.Error(ctx).Err(err). Str("user_id", s.GetUserId()). Str("session_id", s.GetId()). Msg("failed to refresh oauth2 token, deleting session") mgr.deleteSession(ctx, userID, sessionID) - return + return false } s.OauthToken = ToOAuthToken(newToken) - err = authenticator.UpdateUserInfo(ctx, FromOAuthToken(s.OauthToken), &s) + err = authenticator.UpdateUserInfo(ctx, FromOAuthToken(s.OauthToken), s) metrics.RecordIdentityManagerUserRefresh(ctx, err) mgr.recordLastError(metrics_ids.IdentityManagerLastUserRefreshError, err) if isTemporaryError(err) { @@ -265,14 +278,14 @@ func (mgr *Manager) refreshSession(ctx context.Context, userID, sessionID string Str("user_id", s.GetUserId()). Str("session_id", s.GetId()). Msg("failed to update user info") - return + return true } else if err != nil { log.Error(ctx).Err(err). Str("user_id", s.GetUserId()). Str("session_id", s.GetId()). Msg("failed to update user info, deleting session") mgr.deleteSession(ctx, userID, sessionID) - return + return false } if _, err := session.Put(ctx, mgr.cfg.Load().dataBrokerClient, s.Session); err != nil { @@ -280,12 +293,8 @@ func (mgr *Manager) refreshSession(ctx context.Context, userID, sessionID string Str("user_id", s.GetUserId()). Str("session_id", s.GetId()). Msg("failed to update session") - return } - - s.lastRefresh = mgr.now() - mgr.sessions.ReplaceOrInsert(s) - mgr.sessionScheduler.Add(s.NextRefresh(), toSessionSchedulerKey(userID, sessionID)) + return true } func (mgr *Manager) refreshUser(ctx context.Context, userID string) { diff --git a/internal/identity/manager/manager_test.go b/internal/identity/manager/manager_test.go index 437a72058..29e49280b 100644 --- a/internal/identity/manager/manager_test.go +++ b/internal/identity/manager/manager_test.go @@ -216,9 +216,20 @@ func TestManager_refreshSession(t *testing.T) { coolOffDuration: 10 * time.Second, }) - // After a success token refresh, the manager should schedule another - // refresh event. + // If OAuth2 token refresh fails with a temporary error, the manager should + // still reschedule another refresh attempt. now = now.Add(4 * time.Minute) + auth.refreshError = context.DeadlineExceeded + mgr.refreshSession(context.Background(), "user-id", "session-id") + + tm, key := mgr.sessionScheduler.Next() + assert.Equal(t, now.Add(10*time.Second), tm) + assert.Equal(t, "user-id\037session-id", key) + + // Simulate a successful token refresh on the second attempt. The manager + // should store the updated session in the databroker and schedule another + // refresh event. + now = now.Add(10 * time.Second) auth.refreshResult, auth.refreshError = &oauth2.Token{ AccessToken: "new-access-token", RefreshToken: "new-refresh-token", @@ -239,7 +250,7 @@ func TestManager_refreshSession(t *testing.T) { Return(nil /* this result is currently unused */, nil) mgr.refreshSession(context.Background(), "user-id", "session-id") - tm, key := mgr.sessionScheduler.Next() + tm, key = mgr.sessionScheduler.Next() assert.Equal(t, now.Add(4*time.Minute), tm) assert.Equal(t, "user-id\037session-id", key) }