diff --git a/internal/identity/manager/data.go b/internal/identity/manager/data.go index a397d5a45..2c8637102 100644 --- a/internal/identity/manager/data.go +++ b/internal/identity/manager/data.go @@ -54,6 +54,9 @@ func (u *User) UnmarshalJSON(data []byte) error { // A Session is a session managed by the Manager. type Session struct { *session.Session + // lastRefresh is the time of the last refresh attempt (which may or may + // not have succeeded), or else the time the Manager first became aware of + // the session (if it has not yet attempted to refresh this session). lastRefresh time.Time // gracePeriod is the amount of time before expiration to attempt a refresh. gracePeriod time.Duration diff --git a/internal/identity/manager/manager.go b/internal/identity/manager/manager.go index 02ddb635a..51056b70c 100644 --- a/internal/identity/manager/manager.go +++ b/internal/identity/manager/manager.go @@ -107,6 +107,10 @@ func (mgr *Manager) GetDataBrokerServiceClient() databroker.DataBrokerServiceCli return mgr.cfg.Load().dataBrokerClient } +func (mgr *Manager) now() time.Time { + return mgr.cfg.Load().now() +} + func (mgr *Manager) refreshLoop(ctx context.Context, update <-chan updateRecordsMessage, clear <-chan struct{}) error { // wait for initial sync select { @@ -145,7 +149,7 @@ func (mgr *Manager) refreshLoop(ctx context.Context, update <-chan updateRecords case <-timer.C: } - now := time.Now() + now := mgr.now() nextTime = now.Add(maxWait) // refresh sessions @@ -182,22 +186,21 @@ func (mgr *Manager) refreshLoop(ctx context.Context, update <-chan updateRecords } } +// refreshSession handles two distinct session lifecycle events: +// +// 1. If the session itself has expired, delete the session. +// 2. If the session's underlying OAuth2 access token is nearing expiration +// (but the session itself is still valid), refresh the access token. +// +// After a successful access token refresh, this method will also trigger a +// user info refresh. If an access token refresh or a user info refresh fails +// with a permanent error, the session will be deleted. func (mgr *Manager) refreshSession(ctx context.Context, userID, sessionID string) { log.Info(ctx). Str("user_id", userID). Str("session_id", sessionID). Msg("refreshing session") - authenticator := mgr.cfg.Load().authenticator - if authenticator == nil { - log.Info(ctx). - Str("user_id", userID). - Str("session_id", sessionID). - Msg("no authenticator defined, deleting session") - mgr.deleteSession(ctx, userID, sessionID) - return - } - s, ok := mgr.sessions.Get(userID, sessionID) if !ok { log.Warn(ctx). @@ -207,14 +210,37 @@ func (mgr *Manager) refreshSession(ctx context.Context, userID, sessionID string return } + s.lastRefresh = mgr.now() + + if mgr.refreshSessionInternal(ctx, userID, sessionID, &s) { + mgr.sessions.ReplaceOrInsert(s) + mgr.sessionScheduler.Add(s.NextRefresh(), toSessionSchedulerKey(userID, sessionID)) + } +} + +// refreshSessionInternal performs the core refresh logic and returns true if +// the session should be scheduled for refresh again, or false if not. +func (mgr *Manager) refreshSessionInternal( + ctx context.Context, userID, sessionID string, s *Session, +) bool { + authenticator := mgr.cfg.Load().authenticator + if authenticator == nil { + log.Info(ctx). + Str("user_id", userID). + Str("session_id", sessionID). + Msg("no authenticator defined, deleting session") + mgr.deleteSession(ctx, userID, sessionID) + return false + } + expiry := s.GetExpiresAt().AsTime() - if !expiry.After(time.Now()) { + if !expiry.After(mgr.now()) { log.Info(ctx). Str("user_id", userID). Str("session_id", sessionID). Msg("deleting expired session") mgr.deleteSession(ctx, userID, sessionID) - return + return false } if s.Session == nil || s.Session.OauthToken == nil { @@ -222,10 +248,10 @@ func (mgr *Manager) refreshSession(ctx context.Context, userID, sessionID string Str("user_id", userID). Str("session_id", sessionID). Msg("no session oauth2 token found for refresh") - return + return false } - newToken, err := authenticator.Refresh(ctx, FromOAuthToken(s.OauthToken), &s) + newToken, err := authenticator.Refresh(ctx, FromOAuthToken(s.OauthToken), s) metrics.RecordIdentityManagerSessionRefresh(ctx, err) mgr.recordLastError(metrics_ids.IdentityManagerLastSessionRefreshError, err) if isTemporaryError(err) { @@ -233,18 +259,18 @@ func (mgr *Manager) refreshSession(ctx context.Context, userID, sessionID string Str("user_id", s.GetUserId()). Str("session_id", s.GetId()). Msg("failed to refresh oauth2 token") - return + return true } else if err != nil { log.Error(ctx).Err(err). Str("user_id", s.GetUserId()). Str("session_id", s.GetId()). Msg("failed to refresh oauth2 token, deleting session") mgr.deleteSession(ctx, userID, sessionID) - return + return false } s.OauthToken = ToOAuthToken(newToken) - err = authenticator.UpdateUserInfo(ctx, FromOAuthToken(s.OauthToken), &s) + err = authenticator.UpdateUserInfo(ctx, FromOAuthToken(s.OauthToken), s) metrics.RecordIdentityManagerUserRefresh(ctx, err) mgr.recordLastError(metrics_ids.IdentityManagerLastUserRefreshError, err) if isTemporaryError(err) { @@ -252,26 +278,23 @@ func (mgr *Manager) refreshSession(ctx context.Context, userID, sessionID string Str("user_id", s.GetUserId()). Str("session_id", s.GetId()). Msg("failed to update user info") - return + return true } else if err != nil { log.Error(ctx).Err(err). Str("user_id", s.GetUserId()). Str("session_id", s.GetId()). Msg("failed to update user info, deleting session") mgr.deleteSession(ctx, userID, sessionID) - return + return false } - res, err := session.Put(ctx, mgr.cfg.Load().dataBrokerClient, s.Session) - if err != nil { + if _, err := session.Put(ctx, mgr.cfg.Load().dataBrokerClient, s.Session); err != nil { log.Error(ctx).Err(err). Str("user_id", s.GetUserId()). Str("session_id", s.GetId()). Msg("failed to update session") - return } - - mgr.onUpdateSession(ctx, res.GetRecord(), s.Session) + return true } func (mgr *Manager) refreshUser(ctx context.Context, userID string) { @@ -291,7 +314,7 @@ func (mgr *Manager) refreshUser(ctx context.Context, userID string) { Msg("no user found for refresh") return } - u.lastRefresh = time.Now() + u.lastRefresh = mgr.now() mgr.userScheduler.Add(u.NextRefresh(), u.GetId()) for _, s := range mgr.sessions.GetSessionsForUser(userID) { @@ -343,7 +366,7 @@ func (mgr *Manager) onUpdateRecords(ctx context.Context, msg updateRecordsMessag log.Warn(ctx).Msgf("error unmarshaling session: %s", err) continue } - mgr.onUpdateSession(ctx, record, &pbSession) + mgr.onUpdateSession(record, &pbSession) case grpcutil.GetTypeURL(new(user.User)): var pbUser user.User err := record.GetData().UnmarshalTo(&pbUser) @@ -356,7 +379,7 @@ func (mgr *Manager) onUpdateRecords(ctx context.Context, msg updateRecordsMessag } } -func (mgr *Manager) onUpdateSession(_ context.Context, record *databroker.Record, session *session.Session) { +func (mgr *Manager) onUpdateSession(record *databroker.Record, session *session.Session) { mgr.sessionScheduler.Remove(toSessionSchedulerKey(session.GetUserId(), session.GetId())) if record.GetDeletedAt() != nil { @@ -366,7 +389,9 @@ func (mgr *Manager) onUpdateSession(_ context.Context, record *databroker.Record // update session s, _ := mgr.sessions.Get(session.GetUserId(), session.GetId()) - s.lastRefresh = time.Now() + if s.lastRefresh.IsZero() { + s.lastRefresh = mgr.now() + } s.gracePeriod = mgr.cfg.Load().sessionRefreshGracePeriod s.coolOffDuration = mgr.cfg.Load().sessionRefreshCoolOffDuration s.Session = session diff --git a/internal/identity/manager/manager_test.go b/internal/identity/manager/manager_test.go index 1a4ed386b..29e49280b 100644 --- a/internal/identity/manager/manager_test.go +++ b/internal/identity/manager/manager_test.go @@ -3,6 +3,7 @@ package manager import ( "context" "errors" + "fmt" "testing" "time" @@ -24,18 +25,23 @@ import ( "github.com/pomerium/pomerium/pkg/protoutil" ) -type mockAuthenticator struct{} - -func (mock mockAuthenticator) Refresh(_ context.Context, _ *oauth2.Token, _ identity.State) (*oauth2.Token, error) { - return nil, errors.New("update session") +type mockAuthenticator struct { + refreshResult *oauth2.Token + refreshError error + revokeError error + updateUserInfoError error } -func (mock mockAuthenticator) Revoke(_ context.Context, _ *oauth2.Token) error { - return errors.New("not implemented") +func (mock *mockAuthenticator) Refresh(_ context.Context, _ *oauth2.Token, _ identity.State) (*oauth2.Token, error) { + return mock.refreshResult, mock.refreshError } -func (mock mockAuthenticator) UpdateUserInfo(_ context.Context, _ *oauth2.Token, _ any) error { - return errors.New("update user info") +func (mock *mockAuthenticator) Revoke(_ context.Context, _ *oauth2.Token) error { + return mock.revokeError +} + +func (mock *mockAuthenticator) UpdateUserInfo(_ context.Context, _ *oauth2.Token, _ any) error { + return mock.updateUserInfoError } func TestManager_refresh(t *testing.T) { @@ -86,6 +92,9 @@ func TestManager_onUpdateRecords(t *testing.T) { }) if _, ok := mgr.sessions.Get("user1", "session1"); assert.True(t, ok) { + tm, id := mgr.sessionScheduler.Next() + assert.Equal(t, now.Add(10*time.Second), tm) + assert.Equal(t, "user1\037session1", id) } if _, ok := mgr.users.Get("user1"); assert.True(t, ok) { tm, id := mgr.userScheduler.Next() @@ -94,6 +103,158 @@ func TestManager_onUpdateRecords(t *testing.T) { } } +func TestManager_onUpdateSession(t *testing.T) { + startTime := time.Date(2023, 10, 19, 12, 0, 0, 0, time.UTC) + + s := &session.Session{ + Id: "session-id", + UserId: "user-id", + OauthToken: &session.OAuthToken{ + AccessToken: "access-token", + ExpiresAt: timestamppb.New(startTime.Add(5 * time.Minute)), + }, + IssuedAt: timestamppb.New(startTime), + ExpiresAt: timestamppb.New(startTime.Add(24 * time.Hour)), + } + + assertNextScheduled := func(t *testing.T, mgr *Manager, expectedTime time.Time) { + t.Helper() + tm, key := mgr.sessionScheduler.Next() + assert.Equal(t, expectedTime, tm) + assert.Equal(t, "user-id\037session-id", key) + } + + t.Run("initial refresh event when not expiring soon", func(t *testing.T) { + now := startTime + mgr := New(WithNow(func() time.Time { return now })) + + // When the Manager first becomes aware of a session it should schedule + // a refresh event for one minute before access token expiration. + mgr.onUpdateSession(mkRecord(s), s) + assertNextScheduled(t, mgr, startTime.Add(4*time.Minute)) + }) + t.Run("initial refresh event when expiring soon", func(t *testing.T) { + now := startTime + mgr := New(WithNow(func() time.Time { return now })) + + // When the Manager first becomes aware of a session, if that session + // is expiring within the gracePeriod (1 minute), it should schedule a + // refresh event for as soon as possible, subject to the + // coolOffDuration (10 seconds). + now = now.Add(4*time.Minute + 30*time.Second) // 30 s before expiration + mgr.onUpdateSession(mkRecord(s), s) + assertNextScheduled(t, mgr, now.Add(10*time.Second)) + }) + t.Run("update near scheduled refresh", func(t *testing.T) { + now := startTime + mgr := New(WithNow(func() time.Time { return now })) + + mgr.onUpdateSession(mkRecord(s), s) + assertNextScheduled(t, mgr, startTime.Add(4*time.Minute)) + + // If a session is updated close to the time when it is scheduled to be + // refreshed, the scheduled refresh event should not be pushed back. + now = now.Add(3*time.Minute + 55*time.Second) // 5 s before refresh + mgr.onUpdateSession(mkRecord(s), s) + assertNextScheduled(t, mgr, now.Add(5*time.Second)) + + // However, if an update changes the access token validity, the refresh + // event should be rescheduled accordingly. (This should be uncommon, + // as only the refresh loop itself should modify the access token.) + s2 := proto.Clone(s).(*session.Session) + s2.OauthToken.ExpiresAt = timestamppb.New(now.Add(5 * time.Minute)) + mgr.onUpdateSession(mkRecord(s2), s2) + assertNextScheduled(t, mgr, now.Add(4*time.Minute)) + }) + t.Run("session record deleted", func(t *testing.T) { + now := startTime + mgr := New(WithNow(func() time.Time { return now })) + + mgr.onUpdateSession(mkRecord(s), s) + assertNextScheduled(t, mgr, startTime.Add(4*time.Minute)) + + // If a session is deleted, any scheduled refresh event should be canceled. + record := mkRecord(s) + record.DeletedAt = timestamppb.New(now) + mgr.onUpdateSession(record, s) + _, key := mgr.sessionScheduler.Next() + assert.Empty(t, key) + }) +} + +func TestManager_refreshSession(t *testing.T) { + startTime := time.Date(2023, 10, 19, 12, 0, 0, 0, time.UTC) + + var auth mockAuthenticator + + ctrl := gomock.NewController(t) + client := mock_databroker.NewMockDataBrokerServiceClient(ctrl) + + now := startTime + mgr := New( + WithDataBrokerClient(client), + WithNow(func() time.Time { return now }), + WithAuthenticator(&auth), + ) + + // Initialize the Manager with a new session. + s := &session.Session{ + Id: "session-id", + UserId: "user-id", + OauthToken: &session.OAuthToken{ + AccessToken: "access-token", + ExpiresAt: timestamppb.New(startTime.Add(5 * time.Minute)), + RefreshToken: "refresh-token", + }, + IssuedAt: timestamppb.New(startTime), + ExpiresAt: timestamppb.New(startTime.Add(24 * time.Hour)), + } + mgr.sessions.ReplaceOrInsert(Session{ + Session: s, + lastRefresh: startTime, + gracePeriod: time.Minute, + coolOffDuration: 10 * time.Second, + }) + + // If OAuth2 token refresh fails with a temporary error, the manager should + // still reschedule another refresh attempt. + now = now.Add(4 * time.Minute) + auth.refreshError = context.DeadlineExceeded + mgr.refreshSession(context.Background(), "user-id", "session-id") + + tm, key := mgr.sessionScheduler.Next() + assert.Equal(t, now.Add(10*time.Second), tm) + assert.Equal(t, "user-id\037session-id", key) + + // Simulate a successful token refresh on the second attempt. The manager + // should store the updated session in the databroker and schedule another + // refresh event. + now = now.Add(10 * time.Second) + auth.refreshResult, auth.refreshError = &oauth2.Token{ + AccessToken: "new-access-token", + RefreshToken: "new-refresh-token", + Expiry: now.Add(5 * time.Minute), + }, nil + expectedSession := proto.Clone(s).(*session.Session) + expectedSession.OauthToken = &session.OAuthToken{ + AccessToken: "new-access-token", + ExpiresAt: timestamppb.New(now.Add(5 * time.Minute)), + RefreshToken: "new-refresh-token", + } + client.EXPECT().Put(gomock.Any(), + objectsAreEqualMatcher{&databroker.PutRequest{Records: []*databroker.Record{{ + Type: "type.googleapis.com/session.Session", + Id: "session-id", + Data: protoutil.NewAny(expectedSession), + }}}}). + Return(nil /* this result is currently unused */, nil) + mgr.refreshSession(context.Background(), "user-id", "session-id") + + tm, key = mgr.sessionScheduler.Next() + assert.Equal(t, now.Add(4*time.Minute), tm) + assert.Equal(t, "user-id\037session-id", key) +} + func TestManager_reportErrors(t *testing.T) { ctrl := gomock.NewController(t) @@ -135,7 +296,10 @@ func TestManager_reportErrors(t *testing.T) { mgr := New( WithEventManager(evtMgr), WithDataBrokerClient(client), - WithAuthenticator(mockAuthenticator{}), + WithAuthenticator(&mockAuthenticator{ + refreshError: errors.New("update session"), + updateUserInfoError: errors.New("update user info"), + }), ) mgr.onUpdateRecords(ctx, updateRecordsMessage{ @@ -172,3 +336,18 @@ type recordable interface { proto.Message GetId() string } + +// objectsAreEqualMatcher implements gomock.Matcher using ObjectsAreEqual. This +// is especially helpful when working with pointers, as it will compare the +// underlying values rather than the pointers themselves. +type objectsAreEqualMatcher struct { + expected interface{} +} + +func (m objectsAreEqualMatcher) Matches(x interface{}) bool { + return assert.ObjectsAreEqual(m.expected, x) +} + +func (m objectsAreEqualMatcher) String() string { + return fmt.Sprintf("is equal to %v (%T)", m.expected, m.expected) +} diff --git a/internal/identity/oauth/apple/apple.go b/internal/identity/oauth/apple/apple.go index 53878606f..bed70fa94 100644 --- a/internal/identity/oauth/apple/apple.go +++ b/internal/identity/oauth/apple/apple.go @@ -130,16 +130,9 @@ func (p *Provider) Authenticate(ctx context.Context, code string, v identity.Sta // Refresh renews a user's session. func (p *Provider) Refresh(ctx context.Context, t *oauth2.Token, v identity.State) (*oauth2.Token, error) { - if t == nil { - return nil, oidc.ErrMissingAccessToken - } - if t.RefreshToken == "" { - return nil, oidc.ErrMissingRefreshToken - } - - newToken, err := p.oauth.TokenSource(ctx, t).Token() + newToken, err := oidc.Refresh(ctx, p.oauth, t) if err != nil { - return nil, fmt.Errorf("identity/apple: refresh failed: %w", err) + return nil, err } if rawIDToken, ok := newToken.Extra("id_token").(string); ok { diff --git a/internal/identity/oidc/oidc.go b/internal/identity/oidc/oidc.go index 832a62a85..d13e3634f 100644 --- a/internal/identity/oidc/oidc.go +++ b/internal/identity/oidc/oidc.go @@ -213,16 +213,9 @@ func (p *Provider) Refresh(ctx context.Context, t *oauth2.Token, v identity.Stat return nil, err } - if t == nil { - return nil, ErrMissingAccessToken - } - if t.RefreshToken == "" { - return nil, ErrMissingRefreshToken - } - - newToken, err := oa.TokenSource(ctx, t).Token() + newToken, err := Refresh(ctx, oa, t) if err != nil { - return nil, fmt.Errorf("identity/oidc: refresh failed: %w", err) + return nil, err } // Many identity providers _will not_ return `id_token` on refresh diff --git a/internal/identity/oidc/refresh.go b/internal/identity/oidc/refresh.go new file mode 100644 index 000000000..a2c1a216c --- /dev/null +++ b/internal/identity/oidc/refresh.go @@ -0,0 +1,29 @@ +package oidc + +import ( + "context" + "fmt" + + "golang.org/x/oauth2" +) + +// Refresh requests a new oauth2.Token based on an existing Token and the +// provided Config. The existing Token must contain a refresh token. +func Refresh(ctx context.Context, cfg *oauth2.Config, t *oauth2.Token) (*oauth2.Token, error) { + if t == nil || t.RefreshToken == "" { + return nil, ErrMissingRefreshToken + } + + // Note: the TokenSource returned by oauth2.Config has its own threshold + // for determining when to attempt a refresh. In order to force a refresh + // we can remove the current AccessToken. + t = &oauth2.Token{ + TokenType: t.TokenType, + RefreshToken: t.RefreshToken, + } + newToken, err := cfg.TokenSource(ctx, t).Token() + if err != nil { + return nil, fmt.Errorf("identity/oidc: refresh failed: %w", err) + } + return newToken, nil +} diff --git a/internal/identity/oidc/refresh_test.go b/internal/identity/oidc/refresh_test.go new file mode 100644 index 000000000..6ac7f067b --- /dev/null +++ b/internal/identity/oidc/refresh_test.go @@ -0,0 +1,71 @@ +package oidc + +import ( + "context" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "golang.org/x/oauth2" +) + +func TestRefresh(t *testing.T) { + t.Parallel() + + ctx, clearTimeout := context.WithTimeout(context.Background(), 10*time.Second) + t.Cleanup(clearTimeout) + + s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.Write([]byte(`{ + "access_token": "NEW_TOKEN", + "refresh_token": "NEW_REFRESH_TOKEN", + "expires_in": 3600 + }`)) + })) + t.Cleanup(s.Close) + + cfg := &oauth2.Config{Endpoint: oauth2.Endpoint{TokenURL: s.URL}} + + token := &oauth2.Token{ + AccessToken: "OLD_TOKEN", + RefreshToken: "OLD_REFRESH_TOKEN", + + // Even if a token is not expiring soon, Refresh() should still perform + // the refresh. + Expiry: time.Now().Add(time.Hour), + } + require.True(t, token.Valid()) + + newToken, err := Refresh(ctx, cfg, token) + require.NoError(t, err) + assert.Equal(t, "NEW_TOKEN", newToken.AccessToken) + assert.Equal(t, "NEW_REFRESH_TOKEN", newToken.RefreshToken) +} + +func TestRefresh_errors(t *testing.T) { + t.Parallel() + + ctx, clearTimeout := context.WithTimeout(context.Background(), 10*time.Second) + t.Cleanup(clearTimeout) + + s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("{}")) + })) + t.Cleanup(s.Close) + + cfg := &oauth2.Config{Endpoint: oauth2.Endpoint{TokenURL: s.URL}} + + _, err := Refresh(ctx, cfg, nil) + assert.Equal(t, ErrMissingRefreshToken, err) + + _, err = Refresh(ctx, cfg, &oauth2.Token{}) + assert.Equal(t, ErrMissingRefreshToken, err) + + _, err = Refresh(ctx, cfg, &oauth2.Token{RefreshToken: "REFRESH_TOKEN"}) + assert.Equal(t, "identity/oidc: refresh failed: oauth2: server response missing access_token", + err.Error()) +} diff --git a/scripts/get-envoy.bash b/scripts/get-envoy.bash index a1cbc7ddb..6b73735f0 100755 --- a/scripts/get-envoy.bash +++ b/scripts/get-envoy.bash @@ -5,7 +5,7 @@ PATH="$PATH:$(go env GOPATH)/bin" export PATH _project_root="$(cd "$(dirname "${BASH_SOURCE[0]}")" >/dev/null 2>&1 && pwd)/.." -_envoy_version=1.27.1 +_envoy_version=1.28.0 _dir="$_project_root/pkg/envoy/files" for _target in darwin-amd64 darwin-arm64 linux-amd64 linux-arm64; do