mirror of
https://github.com/pomerium/pomerium.git
synced 2025-04-29 10:26:29 +02:00
identity: preserve session refresh schedule (#4633)
The databroker identity manager is responsible for refreshing session records, to account for overall session expiration as well as OAuth2 access token expiration. Refresh events are scheduled subject to a coolOffDuration (10 seconds, by default) relative to a lastRefresh timestamp. Currently, any update to a session record will reset the associated lastRefresh value and reschedule any pending refresh event for that session. If an update occurs close before a scheduled refresh event, this will push back the scheduled refresh event to 10 seconds from that time. This means that if a session is updated frequently enough (e.g. if there is a steady stream of requests that cause constant updates via the AccessTracker), the access token may expire before a refresh ever runs. To avoid this problem, do not update the lastRefresh time upon every session record update, but only if it hasn't yet been set. Instead, update the lastRefresh during the refresh attempt itself. Add unit tests to exercise these changes. There is a now() function as part of the manager configuration (to allow unit tests to set a fake time); update the Manager to use this function throughout.
This commit is contained in:
parent
1996550c54
commit
fa7dc469a3
3 changed files with 205 additions and 18 deletions
|
@ -54,6 +54,9 @@ func (u *User) UnmarshalJSON(data []byte) error {
|
|||
// A Session is a session managed by the Manager.
|
||||
type Session struct {
|
||||
*session.Session
|
||||
// lastRefresh is the time of the last refresh attempt (which may or may
|
||||
// not have succeeded), or else the time the Manager first became aware of
|
||||
// the session (if it has not yet attempted to refresh this session).
|
||||
lastRefresh time.Time
|
||||
// gracePeriod is the amount of time before expiration to attempt a refresh.
|
||||
gracePeriod time.Duration
|
||||
|
|
|
@ -107,6 +107,10 @@ func (mgr *Manager) GetDataBrokerServiceClient() databroker.DataBrokerServiceCli
|
|||
return mgr.cfg.Load().dataBrokerClient
|
||||
}
|
||||
|
||||
func (mgr *Manager) now() time.Time {
|
||||
return mgr.cfg.Load().now()
|
||||
}
|
||||
|
||||
func (mgr *Manager) refreshLoop(ctx context.Context, update <-chan updateRecordsMessage, clear <-chan struct{}) error {
|
||||
// wait for initial sync
|
||||
select {
|
||||
|
@ -145,7 +149,7 @@ func (mgr *Manager) refreshLoop(ctx context.Context, update <-chan updateRecords
|
|||
case <-timer.C:
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
now := mgr.now()
|
||||
nextTime = now.Add(maxWait)
|
||||
|
||||
// refresh sessions
|
||||
|
@ -182,6 +186,15 @@ func (mgr *Manager) refreshLoop(ctx context.Context, update <-chan updateRecords
|
|||
}
|
||||
}
|
||||
|
||||
// refreshSession handles two distinct session lifecycle events:
|
||||
//
|
||||
// 1. If the session itself has expired, delete the session.
|
||||
// 2. If the session's underlying OAuth2 access token is nearing expiration
|
||||
// (but the session itself is still valid), refresh the access token.
|
||||
//
|
||||
// After a successful access token refresh, this method will also trigger a
|
||||
// user info refresh. If an access token refresh or a user info refresh fails
|
||||
// with a permanent error, the session will be deleted.
|
||||
func (mgr *Manager) refreshSession(ctx context.Context, userID, sessionID string) {
|
||||
log.Info(ctx).
|
||||
Str("user_id", userID).
|
||||
|
@ -208,7 +221,7 @@ func (mgr *Manager) refreshSession(ctx context.Context, userID, sessionID string
|
|||
}
|
||||
|
||||
expiry := s.GetExpiresAt().AsTime()
|
||||
if !expiry.After(time.Now()) {
|
||||
if !expiry.After(mgr.now()) {
|
||||
log.Info(ctx).
|
||||
Str("user_id", userID).
|
||||
Str("session_id", sessionID).
|
||||
|
@ -262,8 +275,7 @@ func (mgr *Manager) refreshSession(ctx context.Context, userID, sessionID string
|
|||
return
|
||||
}
|
||||
|
||||
res, err := session.Put(ctx, mgr.cfg.Load().dataBrokerClient, s.Session)
|
||||
if err != nil {
|
||||
if _, err := session.Put(ctx, mgr.cfg.Load().dataBrokerClient, s.Session); err != nil {
|
||||
log.Error(ctx).Err(err).
|
||||
Str("user_id", s.GetUserId()).
|
||||
Str("session_id", s.GetId()).
|
||||
|
@ -271,7 +283,9 @@ func (mgr *Manager) refreshSession(ctx context.Context, userID, sessionID string
|
|||
return
|
||||
}
|
||||
|
||||
mgr.onUpdateSession(ctx, res.GetRecord(), s.Session)
|
||||
s.lastRefresh = mgr.now()
|
||||
mgr.sessions.ReplaceOrInsert(s)
|
||||
mgr.sessionScheduler.Add(s.NextRefresh(), toSessionSchedulerKey(userID, sessionID))
|
||||
}
|
||||
|
||||
func (mgr *Manager) refreshUser(ctx context.Context, userID string) {
|
||||
|
@ -291,7 +305,7 @@ func (mgr *Manager) refreshUser(ctx context.Context, userID string) {
|
|||
Msg("no user found for refresh")
|
||||
return
|
||||
}
|
||||
u.lastRefresh = time.Now()
|
||||
u.lastRefresh = mgr.now()
|
||||
mgr.userScheduler.Add(u.NextRefresh(), u.GetId())
|
||||
|
||||
for _, s := range mgr.sessions.GetSessionsForUser(userID) {
|
||||
|
@ -343,7 +357,7 @@ func (mgr *Manager) onUpdateRecords(ctx context.Context, msg updateRecordsMessag
|
|||
log.Warn(ctx).Msgf("error unmarshaling session: %s", err)
|
||||
continue
|
||||
}
|
||||
mgr.onUpdateSession(ctx, record, &pbSession)
|
||||
mgr.onUpdateSession(record, &pbSession)
|
||||
case grpcutil.GetTypeURL(new(user.User)):
|
||||
var pbUser user.User
|
||||
err := record.GetData().UnmarshalTo(&pbUser)
|
||||
|
@ -356,7 +370,7 @@ func (mgr *Manager) onUpdateRecords(ctx context.Context, msg updateRecordsMessag
|
|||
}
|
||||
}
|
||||
|
||||
func (mgr *Manager) onUpdateSession(_ context.Context, record *databroker.Record, session *session.Session) {
|
||||
func (mgr *Manager) onUpdateSession(record *databroker.Record, session *session.Session) {
|
||||
mgr.sessionScheduler.Remove(toSessionSchedulerKey(session.GetUserId(), session.GetId()))
|
||||
|
||||
if record.GetDeletedAt() != nil {
|
||||
|
@ -366,7 +380,9 @@ func (mgr *Manager) onUpdateSession(_ context.Context, record *databroker.Record
|
|||
|
||||
// update session
|
||||
s, _ := mgr.sessions.Get(session.GetUserId(), session.GetId())
|
||||
s.lastRefresh = time.Now()
|
||||
if s.lastRefresh.IsZero() {
|
||||
s.lastRefresh = mgr.now()
|
||||
}
|
||||
s.gracePeriod = mgr.cfg.Load().sessionRefreshGracePeriod
|
||||
s.coolOffDuration = mgr.cfg.Load().sessionRefreshCoolOffDuration
|
||||
s.Session = session
|
||||
|
|
|
@ -3,6 +3,7 @@ package manager
|
|||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
|
@ -24,18 +25,23 @@ import (
|
|||
"github.com/pomerium/pomerium/pkg/protoutil"
|
||||
)
|
||||
|
||||
type mockAuthenticator struct{}
|
||||
|
||||
func (mock mockAuthenticator) Refresh(_ context.Context, _ *oauth2.Token, _ identity.State) (*oauth2.Token, error) {
|
||||
return nil, errors.New("update session")
|
||||
type mockAuthenticator struct {
|
||||
refreshResult *oauth2.Token
|
||||
refreshError error
|
||||
revokeError error
|
||||
updateUserInfoError error
|
||||
}
|
||||
|
||||
func (mock mockAuthenticator) Revoke(_ context.Context, _ *oauth2.Token) error {
|
||||
return errors.New("not implemented")
|
||||
func (mock *mockAuthenticator) Refresh(_ context.Context, _ *oauth2.Token, _ identity.State) (*oauth2.Token, error) {
|
||||
return mock.refreshResult, mock.refreshError
|
||||
}
|
||||
|
||||
func (mock mockAuthenticator) UpdateUserInfo(_ context.Context, _ *oauth2.Token, _ any) error {
|
||||
return errors.New("update user info")
|
||||
func (mock *mockAuthenticator) Revoke(_ context.Context, _ *oauth2.Token) error {
|
||||
return mock.revokeError
|
||||
}
|
||||
|
||||
func (mock *mockAuthenticator) UpdateUserInfo(_ context.Context, _ *oauth2.Token, _ any) error {
|
||||
return mock.updateUserInfoError
|
||||
}
|
||||
|
||||
func TestManager_refresh(t *testing.T) {
|
||||
|
@ -86,6 +92,9 @@ func TestManager_onUpdateRecords(t *testing.T) {
|
|||
})
|
||||
|
||||
if _, ok := mgr.sessions.Get("user1", "session1"); assert.True(t, ok) {
|
||||
tm, id := mgr.sessionScheduler.Next()
|
||||
assert.Equal(t, now.Add(10*time.Second), tm)
|
||||
assert.Equal(t, "user1\037session1", id)
|
||||
}
|
||||
if _, ok := mgr.users.Get("user1"); assert.True(t, ok) {
|
||||
tm, id := mgr.userScheduler.Next()
|
||||
|
@ -94,6 +103,147 @@ func TestManager_onUpdateRecords(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestManager_onUpdateSession(t *testing.T) {
|
||||
startTime := time.Date(2023, 10, 19, 12, 0, 0, 0, time.UTC)
|
||||
|
||||
s := &session.Session{
|
||||
Id: "session-id",
|
||||
UserId: "user-id",
|
||||
OauthToken: &session.OAuthToken{
|
||||
AccessToken: "access-token",
|
||||
ExpiresAt: timestamppb.New(startTime.Add(5 * time.Minute)),
|
||||
},
|
||||
IssuedAt: timestamppb.New(startTime),
|
||||
ExpiresAt: timestamppb.New(startTime.Add(24 * time.Hour)),
|
||||
}
|
||||
|
||||
assertNextScheduled := func(t *testing.T, mgr *Manager, expectedTime time.Time) {
|
||||
t.Helper()
|
||||
tm, key := mgr.sessionScheduler.Next()
|
||||
assert.Equal(t, expectedTime, tm)
|
||||
assert.Equal(t, "user-id\037session-id", key)
|
||||
}
|
||||
|
||||
t.Run("initial refresh event when not expiring soon", func(t *testing.T) {
|
||||
now := startTime
|
||||
mgr := New(WithNow(func() time.Time { return now }))
|
||||
|
||||
// When the Manager first becomes aware of a session it should schedule
|
||||
// a refresh event for one minute before access token expiration.
|
||||
mgr.onUpdateSession(mkRecord(s), s)
|
||||
assertNextScheduled(t, mgr, startTime.Add(4*time.Minute))
|
||||
})
|
||||
t.Run("initial refresh event when expiring soon", func(t *testing.T) {
|
||||
now := startTime
|
||||
mgr := New(WithNow(func() time.Time { return now }))
|
||||
|
||||
// When the Manager first becomes aware of a session, if that session
|
||||
// is expiring within the gracePeriod (1 minute), it should schedule a
|
||||
// refresh event for as soon as possible, subject to the
|
||||
// coolOffDuration (10 seconds).
|
||||
now = now.Add(4*time.Minute + 30*time.Second) // 30 s before expiration
|
||||
mgr.onUpdateSession(mkRecord(s), s)
|
||||
assertNextScheduled(t, mgr, now.Add(10*time.Second))
|
||||
})
|
||||
t.Run("update near scheduled refresh", func(t *testing.T) {
|
||||
now := startTime
|
||||
mgr := New(WithNow(func() time.Time { return now }))
|
||||
|
||||
mgr.onUpdateSession(mkRecord(s), s)
|
||||
assertNextScheduled(t, mgr, startTime.Add(4*time.Minute))
|
||||
|
||||
// If a session is updated close to the time when it is scheduled to be
|
||||
// refreshed, the scheduled refresh event should not be pushed back.
|
||||
now = now.Add(3*time.Minute + 55*time.Second) // 5 s before refresh
|
||||
mgr.onUpdateSession(mkRecord(s), s)
|
||||
assertNextScheduled(t, mgr, now.Add(5*time.Second))
|
||||
|
||||
// However, if an update changes the access token validity, the refresh
|
||||
// event should be rescheduled accordingly. (This should be uncommon,
|
||||
// as only the refresh loop itself should modify the access token.)
|
||||
s2 := proto.Clone(s).(*session.Session)
|
||||
s2.OauthToken.ExpiresAt = timestamppb.New(now.Add(5 * time.Minute))
|
||||
mgr.onUpdateSession(mkRecord(s2), s2)
|
||||
assertNextScheduled(t, mgr, now.Add(4*time.Minute))
|
||||
})
|
||||
t.Run("session record deleted", func(t *testing.T) {
|
||||
now := startTime
|
||||
mgr := New(WithNow(func() time.Time { return now }))
|
||||
|
||||
mgr.onUpdateSession(mkRecord(s), s)
|
||||
assertNextScheduled(t, mgr, startTime.Add(4*time.Minute))
|
||||
|
||||
// If a session is deleted, any scheduled refresh event should be canceled.
|
||||
record := mkRecord(s)
|
||||
record.DeletedAt = timestamppb.New(now)
|
||||
mgr.onUpdateSession(record, s)
|
||||
_, key := mgr.sessionScheduler.Next()
|
||||
assert.Empty(t, key)
|
||||
})
|
||||
}
|
||||
|
||||
func TestManager_refreshSession(t *testing.T) {
|
||||
startTime := time.Date(2023, 10, 19, 12, 0, 0, 0, time.UTC)
|
||||
|
||||
var auth mockAuthenticator
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
client := mock_databroker.NewMockDataBrokerServiceClient(ctrl)
|
||||
|
||||
now := startTime
|
||||
mgr := New(
|
||||
WithDataBrokerClient(client),
|
||||
WithNow(func() time.Time { return now }),
|
||||
WithAuthenticator(&auth),
|
||||
)
|
||||
|
||||
// Initialize the Manager with a new session.
|
||||
s := &session.Session{
|
||||
Id: "session-id",
|
||||
UserId: "user-id",
|
||||
OauthToken: &session.OAuthToken{
|
||||
AccessToken: "access-token",
|
||||
ExpiresAt: timestamppb.New(startTime.Add(5 * time.Minute)),
|
||||
RefreshToken: "refresh-token",
|
||||
},
|
||||
IssuedAt: timestamppb.New(startTime),
|
||||
ExpiresAt: timestamppb.New(startTime.Add(24 * time.Hour)),
|
||||
}
|
||||
mgr.sessions.ReplaceOrInsert(Session{
|
||||
Session: s,
|
||||
lastRefresh: startTime,
|
||||
gracePeriod: time.Minute,
|
||||
coolOffDuration: 10 * time.Second,
|
||||
})
|
||||
|
||||
// After a success token refresh, the manager should schedule another
|
||||
// refresh event.
|
||||
now = now.Add(4 * time.Minute)
|
||||
auth.refreshResult, auth.refreshError = &oauth2.Token{
|
||||
AccessToken: "new-access-token",
|
||||
RefreshToken: "new-refresh-token",
|
||||
Expiry: now.Add(5 * time.Minute),
|
||||
}, nil
|
||||
expectedSession := proto.Clone(s).(*session.Session)
|
||||
expectedSession.OauthToken = &session.OAuthToken{
|
||||
AccessToken: "new-access-token",
|
||||
ExpiresAt: timestamppb.New(now.Add(5 * time.Minute)),
|
||||
RefreshToken: "new-refresh-token",
|
||||
}
|
||||
client.EXPECT().Put(gomock.Any(),
|
||||
objectsAreEqualMatcher{&databroker.PutRequest{Records: []*databroker.Record{{
|
||||
Type: "type.googleapis.com/session.Session",
|
||||
Id: "session-id",
|
||||
Data: protoutil.NewAny(expectedSession),
|
||||
}}}}).
|
||||
Return(nil /* this result is currently unused */, nil)
|
||||
mgr.refreshSession(context.Background(), "user-id", "session-id")
|
||||
|
||||
tm, key := mgr.sessionScheduler.Next()
|
||||
assert.Equal(t, now.Add(4*time.Minute), tm)
|
||||
assert.Equal(t, "user-id\037session-id", key)
|
||||
}
|
||||
|
||||
func TestManager_reportErrors(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
|
||||
|
@ -135,7 +285,10 @@ func TestManager_reportErrors(t *testing.T) {
|
|||
mgr := New(
|
||||
WithEventManager(evtMgr),
|
||||
WithDataBrokerClient(client),
|
||||
WithAuthenticator(mockAuthenticator{}),
|
||||
WithAuthenticator(&mockAuthenticator{
|
||||
refreshError: errors.New("update session"),
|
||||
updateUserInfoError: errors.New("update user info"),
|
||||
}),
|
||||
)
|
||||
|
||||
mgr.onUpdateRecords(ctx, updateRecordsMessage{
|
||||
|
@ -172,3 +325,18 @@ type recordable interface {
|
|||
proto.Message
|
||||
GetId() string
|
||||
}
|
||||
|
||||
// objectsAreEqualMatcher implements gomock.Matcher using ObjectsAreEqual. This
|
||||
// is especially helpful when working with pointers, as it will compare the
|
||||
// underlying values rather than the pointers themselves.
|
||||
type objectsAreEqualMatcher struct {
|
||||
expected interface{}
|
||||
}
|
||||
|
||||
func (m objectsAreEqualMatcher) Matches(x interface{}) bool {
|
||||
return assert.ObjectsAreEqual(m.expected, x)
|
||||
}
|
||||
|
||||
func (m objectsAreEqualMatcher) String() string {
|
||||
return fmt.Sprintf("is equal to %v (%T)", m.expected, m.expected)
|
||||
}
|
||||
|
|
Loading…
Add table
Reference in a new issue