mirror of
https://github.com/pomerium/pomerium.git
synced 2025-05-03 12:26:03 +02:00
Currently, if a temporary error occurs while attempting to refresh an OAuth2 token, the identity manager won't schedule another attempt. Instead, update the session refresh logic so that it will retry after temporary errors. Extract the bulk of this logic into a separate method that returns a boolean indicating whether to schedule another refresh. Update the unit test to simulate a temporary error during OAuth2 token refresh.
353 lines
11 KiB
Go
353 lines
11 KiB
Go
package manager
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"fmt"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/golang/mock/gomock"
|
|
"github.com/stretchr/testify/assert"
|
|
"golang.org/x/oauth2"
|
|
"google.golang.org/grpc/codes"
|
|
"google.golang.org/grpc/status"
|
|
"google.golang.org/protobuf/proto"
|
|
"google.golang.org/protobuf/types/known/timestamppb"
|
|
|
|
"github.com/pomerium/pomerium/internal/events"
|
|
"github.com/pomerium/pomerium/internal/identity/identity"
|
|
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
|
"github.com/pomerium/pomerium/pkg/grpc/databroker/mock_databroker"
|
|
"github.com/pomerium/pomerium/pkg/grpc/session"
|
|
"github.com/pomerium/pomerium/pkg/grpc/user"
|
|
metrics_ids "github.com/pomerium/pomerium/pkg/metrics"
|
|
"github.com/pomerium/pomerium/pkg/protoutil"
|
|
)
|
|
|
|
type mockAuthenticator struct {
|
|
refreshResult *oauth2.Token
|
|
refreshError error
|
|
revokeError error
|
|
updateUserInfoError error
|
|
}
|
|
|
|
func (mock *mockAuthenticator) Refresh(_ context.Context, _ *oauth2.Token, _ identity.State) (*oauth2.Token, error) {
|
|
return mock.refreshResult, mock.refreshError
|
|
}
|
|
|
|
func (mock *mockAuthenticator) Revoke(_ context.Context, _ *oauth2.Token) error {
|
|
return mock.revokeError
|
|
}
|
|
|
|
func (mock *mockAuthenticator) UpdateUserInfo(_ context.Context, _ *oauth2.Token, _ any) error {
|
|
return mock.updateUserInfoError
|
|
}
|
|
|
|
func TestManager_refresh(t *testing.T) {
|
|
ctrl := gomock.NewController(t)
|
|
ctx, clearTimeout := context.WithTimeout(context.Background(), time.Second*10)
|
|
t.Cleanup(clearTimeout)
|
|
|
|
client := mock_databroker.NewMockDataBrokerServiceClient(ctrl)
|
|
mgr := New(WithDataBrokerClient(client))
|
|
mgr.onUpdateRecords(ctx, updateRecordsMessage{
|
|
records: []*databroker.Record{
|
|
databroker.NewRecord(&session.Session{
|
|
Id: "s1",
|
|
UserId: "u1",
|
|
OauthToken: &session.OAuthToken{},
|
|
ExpiresAt: timestamppb.New(time.Now().Add(time.Second * 10)),
|
|
}),
|
|
databroker.NewRecord(&user.User{
|
|
Id: "u1",
|
|
}),
|
|
},
|
|
})
|
|
client.EXPECT().Get(gomock.Any(), gomock.Any()).Return(nil, status.Error(codes.NotFound, "not found"))
|
|
mgr.refreshSession(ctx, "u1", "s1")
|
|
mgr.refreshUser(ctx, "u1")
|
|
}
|
|
|
|
func TestManager_onUpdateRecords(t *testing.T) {
|
|
ctrl := gomock.NewController(t)
|
|
|
|
ctx, clearTimeout := context.WithTimeout(context.Background(), time.Second*10)
|
|
defer clearTimeout()
|
|
|
|
now := time.Now()
|
|
|
|
mgr := New(
|
|
WithDataBrokerClient(mock_databroker.NewMockDataBrokerServiceClient(ctrl)),
|
|
WithNow(func() time.Time {
|
|
return now
|
|
}),
|
|
)
|
|
|
|
mgr.onUpdateRecords(ctx, updateRecordsMessage{
|
|
records: []*databroker.Record{
|
|
mkRecord(&session.Session{Id: "session1", UserId: "user1"}),
|
|
mkRecord(&user.User{Id: "user1", Name: "user 1", Email: "user1@example.com"}),
|
|
},
|
|
})
|
|
|
|
if _, ok := mgr.sessions.Get("user1", "session1"); assert.True(t, ok) {
|
|
tm, id := mgr.sessionScheduler.Next()
|
|
assert.Equal(t, now.Add(10*time.Second), tm)
|
|
assert.Equal(t, "user1\037session1", id)
|
|
}
|
|
if _, ok := mgr.users.Get("user1"); assert.True(t, ok) {
|
|
tm, id := mgr.userScheduler.Next()
|
|
assert.Equal(t, now.Add(userRefreshInterval), tm)
|
|
assert.Equal(t, "user1", id)
|
|
}
|
|
}
|
|
|
|
func TestManager_onUpdateSession(t *testing.T) {
|
|
startTime := time.Date(2023, 10, 19, 12, 0, 0, 0, time.UTC)
|
|
|
|
s := &session.Session{
|
|
Id: "session-id",
|
|
UserId: "user-id",
|
|
OauthToken: &session.OAuthToken{
|
|
AccessToken: "access-token",
|
|
ExpiresAt: timestamppb.New(startTime.Add(5 * time.Minute)),
|
|
},
|
|
IssuedAt: timestamppb.New(startTime),
|
|
ExpiresAt: timestamppb.New(startTime.Add(24 * time.Hour)),
|
|
}
|
|
|
|
assertNextScheduled := func(t *testing.T, mgr *Manager, expectedTime time.Time) {
|
|
t.Helper()
|
|
tm, key := mgr.sessionScheduler.Next()
|
|
assert.Equal(t, expectedTime, tm)
|
|
assert.Equal(t, "user-id\037session-id", key)
|
|
}
|
|
|
|
t.Run("initial refresh event when not expiring soon", func(t *testing.T) {
|
|
now := startTime
|
|
mgr := New(WithNow(func() time.Time { return now }))
|
|
|
|
// When the Manager first becomes aware of a session it should schedule
|
|
// a refresh event for one minute before access token expiration.
|
|
mgr.onUpdateSession(mkRecord(s), s)
|
|
assertNextScheduled(t, mgr, startTime.Add(4*time.Minute))
|
|
})
|
|
t.Run("initial refresh event when expiring soon", func(t *testing.T) {
|
|
now := startTime
|
|
mgr := New(WithNow(func() time.Time { return now }))
|
|
|
|
// When the Manager first becomes aware of a session, if that session
|
|
// is expiring within the gracePeriod (1 minute), it should schedule a
|
|
// refresh event for as soon as possible, subject to the
|
|
// coolOffDuration (10 seconds).
|
|
now = now.Add(4*time.Minute + 30*time.Second) // 30 s before expiration
|
|
mgr.onUpdateSession(mkRecord(s), s)
|
|
assertNextScheduled(t, mgr, now.Add(10*time.Second))
|
|
})
|
|
t.Run("update near scheduled refresh", func(t *testing.T) {
|
|
now := startTime
|
|
mgr := New(WithNow(func() time.Time { return now }))
|
|
|
|
mgr.onUpdateSession(mkRecord(s), s)
|
|
assertNextScheduled(t, mgr, startTime.Add(4*time.Minute))
|
|
|
|
// If a session is updated close to the time when it is scheduled to be
|
|
// refreshed, the scheduled refresh event should not be pushed back.
|
|
now = now.Add(3*time.Minute + 55*time.Second) // 5 s before refresh
|
|
mgr.onUpdateSession(mkRecord(s), s)
|
|
assertNextScheduled(t, mgr, now.Add(5*time.Second))
|
|
|
|
// However, if an update changes the access token validity, the refresh
|
|
// event should be rescheduled accordingly. (This should be uncommon,
|
|
// as only the refresh loop itself should modify the access token.)
|
|
s2 := proto.Clone(s).(*session.Session)
|
|
s2.OauthToken.ExpiresAt = timestamppb.New(now.Add(5 * time.Minute))
|
|
mgr.onUpdateSession(mkRecord(s2), s2)
|
|
assertNextScheduled(t, mgr, now.Add(4*time.Minute))
|
|
})
|
|
t.Run("session record deleted", func(t *testing.T) {
|
|
now := startTime
|
|
mgr := New(WithNow(func() time.Time { return now }))
|
|
|
|
mgr.onUpdateSession(mkRecord(s), s)
|
|
assertNextScheduled(t, mgr, startTime.Add(4*time.Minute))
|
|
|
|
// If a session is deleted, any scheduled refresh event should be canceled.
|
|
record := mkRecord(s)
|
|
record.DeletedAt = timestamppb.New(now)
|
|
mgr.onUpdateSession(record, s)
|
|
_, key := mgr.sessionScheduler.Next()
|
|
assert.Empty(t, key)
|
|
})
|
|
}
|
|
|
|
func TestManager_refreshSession(t *testing.T) {
|
|
startTime := time.Date(2023, 10, 19, 12, 0, 0, 0, time.UTC)
|
|
|
|
var auth mockAuthenticator
|
|
|
|
ctrl := gomock.NewController(t)
|
|
client := mock_databroker.NewMockDataBrokerServiceClient(ctrl)
|
|
|
|
now := startTime
|
|
mgr := New(
|
|
WithDataBrokerClient(client),
|
|
WithNow(func() time.Time { return now }),
|
|
WithAuthenticator(&auth),
|
|
)
|
|
|
|
// Initialize the Manager with a new session.
|
|
s := &session.Session{
|
|
Id: "session-id",
|
|
UserId: "user-id",
|
|
OauthToken: &session.OAuthToken{
|
|
AccessToken: "access-token",
|
|
ExpiresAt: timestamppb.New(startTime.Add(5 * time.Minute)),
|
|
RefreshToken: "refresh-token",
|
|
},
|
|
IssuedAt: timestamppb.New(startTime),
|
|
ExpiresAt: timestamppb.New(startTime.Add(24 * time.Hour)),
|
|
}
|
|
mgr.sessions.ReplaceOrInsert(Session{
|
|
Session: s,
|
|
lastRefresh: startTime,
|
|
gracePeriod: time.Minute,
|
|
coolOffDuration: 10 * time.Second,
|
|
})
|
|
|
|
// If OAuth2 token refresh fails with a temporary error, the manager should
|
|
// still reschedule another refresh attempt.
|
|
now = now.Add(4 * time.Minute)
|
|
auth.refreshError = context.DeadlineExceeded
|
|
mgr.refreshSession(context.Background(), "user-id", "session-id")
|
|
|
|
tm, key := mgr.sessionScheduler.Next()
|
|
assert.Equal(t, now.Add(10*time.Second), tm)
|
|
assert.Equal(t, "user-id\037session-id", key)
|
|
|
|
// Simulate a successful token refresh on the second attempt. The manager
|
|
// should store the updated session in the databroker and schedule another
|
|
// refresh event.
|
|
now = now.Add(10 * time.Second)
|
|
auth.refreshResult, auth.refreshError = &oauth2.Token{
|
|
AccessToken: "new-access-token",
|
|
RefreshToken: "new-refresh-token",
|
|
Expiry: now.Add(5 * time.Minute),
|
|
}, nil
|
|
expectedSession := proto.Clone(s).(*session.Session)
|
|
expectedSession.OauthToken = &session.OAuthToken{
|
|
AccessToken: "new-access-token",
|
|
ExpiresAt: timestamppb.New(now.Add(5 * time.Minute)),
|
|
RefreshToken: "new-refresh-token",
|
|
}
|
|
client.EXPECT().Put(gomock.Any(),
|
|
objectsAreEqualMatcher{&databroker.PutRequest{Records: []*databroker.Record{{
|
|
Type: "type.googleapis.com/session.Session",
|
|
Id: "session-id",
|
|
Data: protoutil.NewAny(expectedSession),
|
|
}}}}).
|
|
Return(nil /* this result is currently unused */, nil)
|
|
mgr.refreshSession(context.Background(), "user-id", "session-id")
|
|
|
|
tm, key = mgr.sessionScheduler.Next()
|
|
assert.Equal(t, now.Add(4*time.Minute), tm)
|
|
assert.Equal(t, "user-id\037session-id", key)
|
|
}
|
|
|
|
func TestManager_reportErrors(t *testing.T) {
|
|
ctrl := gomock.NewController(t)
|
|
|
|
ctx, clearTimeout := context.WithTimeout(context.Background(), time.Second*10)
|
|
defer clearTimeout()
|
|
|
|
evtMgr := events.New()
|
|
received := make(chan events.Event, 1)
|
|
handle := evtMgr.Register(func(evt events.Event) {
|
|
received <- evt
|
|
})
|
|
defer evtMgr.Unregister(handle)
|
|
|
|
expectMsg := func(id, msg string) {
|
|
t.Helper()
|
|
assert.Eventually(t, func() bool {
|
|
select {
|
|
case evt := <-received:
|
|
lastErr := evt.(*events.LastError)
|
|
return msg == lastErr.Message && id == lastErr.Id
|
|
default:
|
|
return false
|
|
}
|
|
}, time.Second, time.Millisecond*20, msg)
|
|
}
|
|
|
|
s := &session.Session{
|
|
Id: "session1",
|
|
UserId: "user1",
|
|
OauthToken: &session.OAuthToken{
|
|
ExpiresAt: timestamppb.New(time.Now().Add(time.Hour)),
|
|
},
|
|
ExpiresAt: timestamppb.New(time.Now().Add(time.Hour)),
|
|
}
|
|
|
|
client := mock_databroker.NewMockDataBrokerServiceClient(ctrl)
|
|
client.EXPECT().Get(gomock.Any(), gomock.Any()).AnyTimes().Return(&databroker.GetResponse{Record: databroker.NewRecord(s)}, nil)
|
|
client.EXPECT().Put(gomock.Any(), gomock.Any()).AnyTimes()
|
|
mgr := New(
|
|
WithEventManager(evtMgr),
|
|
WithDataBrokerClient(client),
|
|
WithAuthenticator(&mockAuthenticator{
|
|
refreshError: errors.New("update session"),
|
|
updateUserInfoError: errors.New("update user info"),
|
|
}),
|
|
)
|
|
|
|
mgr.onUpdateRecords(ctx, updateRecordsMessage{
|
|
records: []*databroker.Record{
|
|
mkRecord(s),
|
|
mkRecord(&user.User{Id: "user1", Name: "user 1", Email: "user1@example.com"}),
|
|
},
|
|
})
|
|
|
|
mgr.refreshUser(ctx, "user1")
|
|
expectMsg(metrics_ids.IdentityManagerLastUserRefreshError, "update user info")
|
|
|
|
mgr.onUpdateRecords(ctx, updateRecordsMessage{
|
|
records: []*databroker.Record{
|
|
mkRecord(s),
|
|
mkRecord(&user.User{Id: "user1", Name: "user 1", Email: "user1@example.com"}),
|
|
},
|
|
})
|
|
|
|
mgr.refreshSession(ctx, "user1", "session1")
|
|
expectMsg(metrics_ids.IdentityManagerLastSessionRefreshError, "update session")
|
|
}
|
|
|
|
func mkRecord(msg recordable) *databroker.Record {
|
|
data := protoutil.NewAny(msg)
|
|
return &databroker.Record{
|
|
Type: data.GetTypeUrl(),
|
|
Id: msg.GetId(),
|
|
Data: data,
|
|
}
|
|
}
|
|
|
|
type recordable interface {
|
|
proto.Message
|
|
GetId() string
|
|
}
|
|
|
|
// objectsAreEqualMatcher implements gomock.Matcher using ObjectsAreEqual. This
// is especially helpful when working with pointers, as it will compare the
// underlying values rather than the pointers themselves.
type objectsAreEqualMatcher struct {
	// expected is the value that a call argument is compared against.
	expected any
}
|
|
|
|
func (m objectsAreEqualMatcher) Matches(x interface{}) bool {
|
|
return assert.ObjectsAreEqual(m.expected, x)
|
|
}
|
|
|
|
func (m objectsAreEqualMatcher) String() string {
|
|
return fmt.Sprintf("is equal to %v (%T)", m.expected, m.expected)
|
|
}
|