diff --git a/internal/identity/manager/manager.go b/internal/identity/manager/manager.go index b9a2c45a9..51056b70c 100644 --- a/internal/identity/manager/manager.go +++ b/internal/identity/manager/manager.go @@ -201,16 +201,6 @@ func (mgr *Manager) refreshSession(ctx context.Context, userID, sessionID string Str("session_id", sessionID). Msg("refreshing session") - authenticator := mgr.cfg.Load().authenticator - if authenticator == nil { - log.Info(ctx). - Str("user_id", userID). - Str("session_id", sessionID). - Msg("no authenticator defined, deleting session") - mgr.deleteSession(ctx, userID, sessionID) - return - } - s, ok := mgr.sessions.Get(userID, sessionID) if !ok { log.Warn(ctx). @@ -220,6 +210,29 @@ func (mgr *Manager) refreshSession(ctx context.Context, userID, sessionID string return } + s.lastRefresh = mgr.now() + + if mgr.refreshSessionInternal(ctx, userID, sessionID, &s) { + mgr.sessions.ReplaceOrInsert(s) + mgr.sessionScheduler.Add(s.NextRefresh(), toSessionSchedulerKey(userID, sessionID)) + } +} + +// refreshSessionInternal performs the core refresh logic and returns true if +// the session should be scheduled for refresh again, or false if not. +func (mgr *Manager) refreshSessionInternal( + ctx context.Context, userID, sessionID string, s *Session, +) bool { + authenticator := mgr.cfg.Load().authenticator + if authenticator == nil { + log.Info(ctx). + Str("user_id", userID). + Str("session_id", sessionID). + Msg("no authenticator defined, deleting session") + mgr.deleteSession(ctx, userID, sessionID) + return false + } + expiry := s.GetExpiresAt().AsTime() if !expiry.After(mgr.now()) { log.Info(ctx). @@ -227,7 +240,7 @@ func (mgr *Manager) refreshSession(ctx context.Context, userID, sessionID string Str("session_id", sessionID). Msg("deleting expired session") mgr.deleteSession(ctx, userID, sessionID) - return + return false } if s.Session == nil || s.Session.OauthToken == nil { @@ -235,10 +248,10 @@ func (mgr *Manager) refreshSession(ctx context.Context, userID, sessionID string Str("user_id", userID). Str("session_id", sessionID). Msg("no session oauth2 token found for refresh") - return + return false } - newToken, err := authenticator.Refresh(ctx, FromOAuthToken(s.OauthToken), &s) + newToken, err := authenticator.Refresh(ctx, FromOAuthToken(s.OauthToken), s) metrics.RecordIdentityManagerSessionRefresh(ctx, err) mgr.recordLastError(metrics_ids.IdentityManagerLastSessionRefreshError, err) if isTemporaryError(err) { @@ -246,18 +259,18 @@ func (mgr *Manager) refreshSession(ctx context.Context, userID, sessionID string Str("user_id", s.GetUserId()). Str("session_id", s.GetId()). Msg("failed to refresh oauth2 token") - return + return true } else if err != nil { log.Error(ctx).Err(err). Str("user_id", s.GetUserId()). Str("session_id", s.GetId()). Msg("failed to refresh oauth2 token, deleting session") mgr.deleteSession(ctx, userID, sessionID) - return + return false } s.OauthToken = ToOAuthToken(newToken) - err = authenticator.UpdateUserInfo(ctx, FromOAuthToken(s.OauthToken), &s) + err = authenticator.UpdateUserInfo(ctx, FromOAuthToken(s.OauthToken), s) metrics.RecordIdentityManagerUserRefresh(ctx, err) mgr.recordLastError(metrics_ids.IdentityManagerLastUserRefreshError, err) if isTemporaryError(err) { @@ -265,14 +278,14 @@ func (mgr *Manager) refreshSession(ctx context.Context, userID, sessionID string Str("user_id", s.GetUserId()). Str("session_id", s.GetId()). Msg("failed to update user info") - return + return true } else if err != nil { log.Error(ctx).Err(err). Str("user_id", s.GetUserId()). Str("session_id", s.GetId()). Msg("failed to update user info, deleting session") mgr.deleteSession(ctx, userID, sessionID) - return + return false } if _, err := session.Put(ctx, mgr.cfg.Load().dataBrokerClient, s.Session); err != nil { @@ -280,12 +293,8 @@ func (mgr *Manager) refreshSession(ctx context.Context, userID, sessionID string Str("user_id", s.GetUserId()). Str("session_id", s.GetId()). Msg("failed to update session") - return } - - s.lastRefresh = mgr.now() - mgr.sessions.ReplaceOrInsert(s) - mgr.sessionScheduler.Add(s.NextRefresh(), toSessionSchedulerKey(userID, sessionID)) + return true } func (mgr *Manager) refreshUser(ctx context.Context, userID string) { diff --git a/internal/identity/manager/manager_test.go b/internal/identity/manager/manager_test.go index 437a72058..29e49280b 100644 --- a/internal/identity/manager/manager_test.go +++ b/internal/identity/manager/manager_test.go @@ -216,9 +216,20 @@ func TestManager_refreshSession(t *testing.T) { coolOffDuration: 10 * time.Second, }) - // After a success token refresh, the manager should schedule another - // refresh event. + // If OAuth2 token refresh fails with a temporary error, the manager should + // still reschedule another refresh attempt. now = now.Add(4 * time.Minute) + auth.refreshError = context.DeadlineExceeded + mgr.refreshSession(context.Background(), "user-id", "session-id") + + tm, key := mgr.sessionScheduler.Next() + assert.Equal(t, now.Add(10*time.Second), tm) + assert.Equal(t, "user-id\037session-id", key) + + // Simulate a successful token refresh on the second attempt. The manager + // should store the updated session in the databroker and schedule another + // refresh event. + now = now.Add(10 * time.Second) auth.refreshResult, auth.refreshError = &oauth2.Token{ AccessToken: "new-access-token", RefreshToken: "new-refresh-token", @@ -239,7 +250,7 @@ func TestManager_refreshSession(t *testing.T) { Return(nil /* this result is currently unused */, nil) mgr.refreshSession(context.Background(), "user-id", "session-id") - tm, key := mgr.sessionScheduler.Next() + tm, key = mgr.sessionScheduler.Next() assert.Equal(t, now.Add(4*time.Minute), tm) assert.Equal(t, "user-id\037session-id", key) }