identity: rework session refresh error handling (#4639)

identity: rework session refresh error handling (#4638)

Currently, if a temporary error occurs while attempting to refresh an
OAuth2 token, the identity manager won't schedule another attempt.

Instead, update the session refresh logic so that it will retry after
temporary errors. Extract the bulk of this logic into a separate method
that returns a boolean indicating whether to schedule another refresh.

Update the unit test to simulate a temporary error during OAuth2 token
refresh.

Co-authored-by: Kenneth Jenkins <51246568+kenjenkins@users.noreply.github.com>
This commit is contained in:
backport-actions-token[bot] 2023-10-24 15:59:52 -07:00 committed by GitHub
parent 51456671cf
commit 70d77b283b
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 46 additions and 26 deletions

View file

@ -201,16 +201,6 @@ func (mgr *Manager) refreshSession(ctx context.Context, userID, sessionID string
Str("session_id", sessionID). Str("session_id", sessionID).
Msg("refreshing session") Msg("refreshing session")
authenticator := mgr.cfg.Load().authenticator
if authenticator == nil {
log.Info(ctx).
Str("user_id", userID).
Str("session_id", sessionID).
Msg("no authenticator defined, deleting session")
mgr.deleteSession(ctx, userID, sessionID)
return
}
s, ok := mgr.sessions.Get(userID, sessionID) s, ok := mgr.sessions.Get(userID, sessionID)
if !ok { if !ok {
log.Warn(ctx). log.Warn(ctx).
@ -220,6 +210,29 @@ func (mgr *Manager) refreshSession(ctx context.Context, userID, sessionID string
return return
} }
s.lastRefresh = mgr.now()
if mgr.refreshSessionInternal(ctx, userID, sessionID, &s) {
mgr.sessions.ReplaceOrInsert(s)
mgr.sessionScheduler.Add(s.NextRefresh(), toSessionSchedulerKey(userID, sessionID))
}
}
// refreshSessionInternal performs the core refresh logic and returns true if
// the session should be scheduled for refresh again, or false if not.
func (mgr *Manager) refreshSessionInternal(
ctx context.Context, userID, sessionID string, s *Session,
) bool {
authenticator := mgr.cfg.Load().authenticator
if authenticator == nil {
log.Info(ctx).
Str("user_id", userID).
Str("session_id", sessionID).
Msg("no authenticator defined, deleting session")
mgr.deleteSession(ctx, userID, sessionID)
return false
}
expiry := s.GetExpiresAt().AsTime() expiry := s.GetExpiresAt().AsTime()
if !expiry.After(mgr.now()) { if !expiry.After(mgr.now()) {
log.Info(ctx). log.Info(ctx).
@ -227,7 +240,7 @@ func (mgr *Manager) refreshSession(ctx context.Context, userID, sessionID string
Str("session_id", sessionID). Str("session_id", sessionID).
Msg("deleting expired session") Msg("deleting expired session")
mgr.deleteSession(ctx, userID, sessionID) mgr.deleteSession(ctx, userID, sessionID)
return return false
} }
if s.Session == nil || s.Session.OauthToken == nil { if s.Session == nil || s.Session.OauthToken == nil {
@ -235,10 +248,10 @@ func (mgr *Manager) refreshSession(ctx context.Context, userID, sessionID string
Str("user_id", userID). Str("user_id", userID).
Str("session_id", sessionID). Str("session_id", sessionID).
Msg("no session oauth2 token found for refresh") Msg("no session oauth2 token found for refresh")
return return false
} }
newToken, err := authenticator.Refresh(ctx, FromOAuthToken(s.OauthToken), &s) newToken, err := authenticator.Refresh(ctx, FromOAuthToken(s.OauthToken), s)
metrics.RecordIdentityManagerSessionRefresh(ctx, err) metrics.RecordIdentityManagerSessionRefresh(ctx, err)
mgr.recordLastError(metrics_ids.IdentityManagerLastSessionRefreshError, err) mgr.recordLastError(metrics_ids.IdentityManagerLastSessionRefreshError, err)
if isTemporaryError(err) { if isTemporaryError(err) {
@ -246,18 +259,18 @@ func (mgr *Manager) refreshSession(ctx context.Context, userID, sessionID string
Str("user_id", s.GetUserId()). Str("user_id", s.GetUserId()).
Str("session_id", s.GetId()). Str("session_id", s.GetId()).
Msg("failed to refresh oauth2 token") Msg("failed to refresh oauth2 token")
return return true
} else if err != nil { } else if err != nil {
log.Error(ctx).Err(err). log.Error(ctx).Err(err).
Str("user_id", s.GetUserId()). Str("user_id", s.GetUserId()).
Str("session_id", s.GetId()). Str("session_id", s.GetId()).
Msg("failed to refresh oauth2 token, deleting session") Msg("failed to refresh oauth2 token, deleting session")
mgr.deleteSession(ctx, userID, sessionID) mgr.deleteSession(ctx, userID, sessionID)
return return false
} }
s.OauthToken = ToOAuthToken(newToken) s.OauthToken = ToOAuthToken(newToken)
err = authenticator.UpdateUserInfo(ctx, FromOAuthToken(s.OauthToken), &s) err = authenticator.UpdateUserInfo(ctx, FromOAuthToken(s.OauthToken), s)
metrics.RecordIdentityManagerUserRefresh(ctx, err) metrics.RecordIdentityManagerUserRefresh(ctx, err)
mgr.recordLastError(metrics_ids.IdentityManagerLastUserRefreshError, err) mgr.recordLastError(metrics_ids.IdentityManagerLastUserRefreshError, err)
if isTemporaryError(err) { if isTemporaryError(err) {
@ -265,14 +278,14 @@ func (mgr *Manager) refreshSession(ctx context.Context, userID, sessionID string
Str("user_id", s.GetUserId()). Str("user_id", s.GetUserId()).
Str("session_id", s.GetId()). Str("session_id", s.GetId()).
Msg("failed to update user info") Msg("failed to update user info")
return return true
} else if err != nil { } else if err != nil {
log.Error(ctx).Err(err). log.Error(ctx).Err(err).
Str("user_id", s.GetUserId()). Str("user_id", s.GetUserId()).
Str("session_id", s.GetId()). Str("session_id", s.GetId()).
Msg("failed to update user info, deleting session") Msg("failed to update user info, deleting session")
mgr.deleteSession(ctx, userID, sessionID) mgr.deleteSession(ctx, userID, sessionID)
return return false
} }
if _, err := session.Put(ctx, mgr.cfg.Load().dataBrokerClient, s.Session); err != nil { if _, err := session.Put(ctx, mgr.cfg.Load().dataBrokerClient, s.Session); err != nil {
@ -280,12 +293,8 @@ func (mgr *Manager) refreshSession(ctx context.Context, userID, sessionID string
Str("user_id", s.GetUserId()). Str("user_id", s.GetUserId()).
Str("session_id", s.GetId()). Str("session_id", s.GetId()).
Msg("failed to update session") Msg("failed to update session")
return
} }
return true
s.lastRefresh = mgr.now()
mgr.sessions.ReplaceOrInsert(s)
mgr.sessionScheduler.Add(s.NextRefresh(), toSessionSchedulerKey(userID, sessionID))
} }
func (mgr *Manager) refreshUser(ctx context.Context, userID string) { func (mgr *Manager) refreshUser(ctx context.Context, userID string) {

View file

@ -216,9 +216,20 @@ func TestManager_refreshSession(t *testing.T) {
coolOffDuration: 10 * time.Second, coolOffDuration: 10 * time.Second,
}) })
// After a success token refresh, the manager should schedule another // If OAuth2 token refresh fails with a temporary error, the manager should
// refresh event. // still reschedule another refresh attempt.
now = now.Add(4 * time.Minute) now = now.Add(4 * time.Minute)
auth.refreshError = context.DeadlineExceeded
mgr.refreshSession(context.Background(), "user-id", "session-id")
tm, key := mgr.sessionScheduler.Next()
assert.Equal(t, now.Add(10*time.Second), tm)
assert.Equal(t, "user-id\037session-id", key)
// Simulate a successful token refresh on the second attempt. The manager
// should store the updated session in the databroker and schedule another
// refresh event.
now = now.Add(10 * time.Second)
auth.refreshResult, auth.refreshError = &oauth2.Token{ auth.refreshResult, auth.refreshError = &oauth2.Token{
AccessToken: "new-access-token", AccessToken: "new-access-token",
RefreshToken: "new-refresh-token", RefreshToken: "new-refresh-token",
@ -239,7 +250,7 @@ func TestManager_refreshSession(t *testing.T) {
Return(nil /* this result is currently unused */, nil) Return(nil /* this result is currently unused */, nil)
mgr.refreshSession(context.Background(), "user-id", "session-id") mgr.refreshSession(context.Background(), "user-id", "session-id")
tm, key := mgr.sessionScheduler.Next() tm, key = mgr.sessionScheduler.Next()
assert.Equal(t, now.Add(4*time.Minute), tm) assert.Equal(t, now.Add(4*time.Minute), tm)
assert.Equal(t, "user-id\037session-id", key) assert.Equal(t, "user-id\037session-id", key)
} }