diff --git a/authenticate/handlers.go b/authenticate/handlers.go index 5b6c52586..dabb26c3d 100644 --- a/authenticate/handlers.go +++ b/authenticate/handlers.go @@ -539,23 +539,22 @@ func (a *Authenticate) saveSessionToDataBroker( s.SetRawIDToken(claims.RawIDToken) s.AddClaims(claims.Flatten()) - // if no user exists yet, create a new one - currentUser, _ := user.Get(ctx, state.dataBrokerClient, s.GetUserId()) - if currentUser == nil { - mu := manager.User{ - User: &user.User{ - Id: s.GetUserId(), - }, - } - err := a.provider.Load().UpdateUserInfo(ctx, accessToken, &mu) - if err != nil { - return fmt.Errorf("authenticate: error retrieving user info: %w", err) - } - _, err = user.Put(ctx, state.dataBrokerClient, mu.User) - if err != nil { - return fmt.Errorf("authenticate: error saving user: %w", err) + var managerUser manager.User + managerUser.User, _ = user.Get(ctx, state.dataBrokerClient, s.GetUserId()) + if managerUser.User == nil { + // if no user exists yet, create a new one + managerUser.User = &user.User{ + Id: s.GetUserId(), } } + err := a.provider.Load().UpdateUserInfo(ctx, accessToken, &managerUser) + if err != nil { + return fmt.Errorf("authenticate: error retrieving user info: %w", err) + } + _, err = user.Put(ctx, state.dataBrokerClient, managerUser.User) + if err != nil { + return fmt.Errorf("authenticate: error saving user: %w", err) + } res, err := session.Put(ctx, state.dataBrokerClient, s) if err != nil { diff --git a/internal/identity/manager/config.go b/internal/identity/manager/config.go index bb9d1b929..e9ef874d6 100644 --- a/internal/identity/manager/config.go +++ b/internal/identity/manager/config.go @@ -23,6 +23,7 @@ type config struct { groupRefreshTimeout time.Duration sessionRefreshGracePeriod time.Duration sessionRefreshCoolOffDuration time.Duration + now func() time.Time } func newConfig(options ...Option) *config { @@ -31,6 +32,7 @@ func newConfig(options ...Option) *config { WithGroupRefreshTimeout(defaultGroupRefreshTimeout)(cfg) WithSessionRefreshGracePeriod(defaultSessionRefreshGracePeriod)(cfg) WithSessionRefreshCoolOffDuration(defaultSessionRefreshCoolOffDuration)(cfg) + WithNow(time.Now)(cfg) for _, option := range options { option(cfg) } @@ -89,6 +91,13 @@ func WithSessionRefreshCoolOffDuration(dur time.Duration) Option { } } +// WithNow customizes the time.Now function used by the manager. +func WithNow(now func() time.Time) Option { + return func(cfg *config) { + cfg.now = now + } +} + type atomicConfig struct { value atomic.Value } diff --git a/internal/identity/manager/manager.go b/internal/identity/manager/manager.go index c701d38c3..7d8e6235e 100644 --- a/internal/identity/manager/manager.go +++ b/internal/identity/manager/manager.go @@ -77,6 +77,7 @@ func New( } mgr.directoryBackoff = backoff.NewExponentialBackOff() mgr.directoryBackoff.MaxElapsedTime = 0 + mgr.reset() mgr.UpdateConfig(options...) return mgr } @@ -128,10 +129,7 @@ func (mgr *Manager) refreshLoop(ctx context.Context, update <-chan updateRecords case <-ctx.Done(): return ctx.Err() case <-clear: - mgr.directoryGroups = make(map[string]*directory.Group) - mgr.directoryUsers = make(map[string]*directory.User) - mgr.sessions = sessionCollection{BTree: btree.New(8)} - mgr.users = userCollection{BTree: btree.New(8)} + mgr.reset() } select { case <-ctx.Done(): @@ -161,10 +159,7 @@ func (mgr *Manager) refreshLoop(ctx context.Context, update <-chan updateRecords case <-ctx.Done(): return ctx.Err() case <-clear: - mgr.directoryGroups = make(map[string]*directory.Group) - mgr.directoryUsers = make(map[string]*directory.User) - mgr.sessions = sessionCollection{BTree: btree.New(8)} - mgr.users = userCollection{BTree: btree.New(8)} + mgr.reset() case msg := <-update: mgr.onUpdateRecords(ctx, msg) case <-timer.C: @@ -571,7 +566,7 @@ func (mgr *Manager) onUpdateUser(_ context.Context, record *databroker.Record, u } u, _ := mgr.users.Get(user.GetId()) - u.lastRefresh = time.Now() + u.lastRefresh = mgr.cfg.Load().now() u.refreshInterval = mgr.cfg.Load().groupRefreshInterval u.User = user mgr.users.ReplaceOrInsert(u) @@ -595,6 +590,14 @@ func (mgr *Manager) deleteSession(ctx context.Context, pbSession *session.Sessio } } +// reset resets all the manager datastructures to their initial state +func (mgr *Manager) reset() { + mgr.directoryGroups = make(map[string]*directory.Group) + mgr.directoryUsers = make(map[string]*directory.User) + mgr.sessions = sessionCollection{BTree: btree.New(8)} + mgr.users = userCollection{BTree: btree.New(8)} +} + func isTemporaryError(err error) bool { if err == nil { return false diff --git a/internal/identity/manager/manager_test.go b/internal/identity/manager/manager_test.go index 4b5bfab67..41578a53d 100644 --- a/internal/identity/manager/manager_test.go +++ b/internal/identity/manager/manager_test.go @@ -7,8 +7,13 @@ import ( "time" "github.com/stretchr/testify/assert" + "google.golang.org/protobuf/proto" "github.com/pomerium/pomerium/internal/directory" + "github.com/pomerium/pomerium/pkg/grpc/databroker" + "github.com/pomerium/pomerium/pkg/grpc/session" + "github.com/pomerium/pomerium/pkg/grpc/user" + "github.com/pomerium/pomerium/pkg/protoutil" ) type mockProvider struct { @@ -24,6 +29,43 @@ func (mock mockProvider) UserGroups(ctx context.Context) ([]*directory.Group, [] return mock.userGroups(ctx) } +func TestManager_onUpdateRecords(t *testing.T) { + ctx, clearTimeout := context.WithTimeout(context.Background(), time.Second*10) + defer clearTimeout() + + now := time.Now() + + mgr := New( + WithDirectoryProvider(mockProvider{}), + WithGroupRefreshInterval(time.Hour), + WithNow(func() time.Time { + return now + }), + ) + mgr.directoryBackoff.RandomizationFactor = 0 // disable randomization for deterministic testing + + mgr.onUpdateRecords(ctx, updateRecordsMessage{ + records: []*databroker.Record{ + mkRecord(&directory.Group{Id: "group1", Name: "group 1", Email: "group1@example.com"}), + mkRecord(&directory.User{Id: "user1", DisplayName: "user 1", Email: "user1@example.com", GroupIds: []string{"group1s"}}), + mkRecord(&session.Session{Id: "session1", UserId: "user1"}), + mkRecord(&user.User{Id: "user1", Name: "user 1", Email: "user1@example.com"}), + }, + }) + + assert.NotNil(t, mgr.directoryGroups["group1"]) + assert.NotNil(t, mgr.directoryUsers["user1"]) + if _, ok := mgr.sessions.Get("user1", "session1"); assert.True(t, ok) { + + } + if _, ok := mgr.users.Get("user1"); assert.True(t, ok) { + tm, id := mgr.userScheduler.Next() + assert.Equal(t, now.Add(time.Hour), tm) + assert.Equal(t, "user1", id) + } + +} + func TestManager_refreshDirectoryUserGroups(t *testing.T) { ctx, clearTimeout := context.WithTimeout(context.Background(), time.Second*10) defer clearTimeout() @@ -56,3 +98,17 @@ func TestManager_refreshDirectoryUserGroups(t *testing.T) { assert.Equal(t, time.Hour, dur3) }) } + +func mkRecord(msg recordable) *databroker.Record { + any := protoutil.NewAny(msg) + return &databroker.Record{ + Type: any.GetTypeUrl(), + Id: msg.GetId(), + Data: any, + } +} + +type recordable interface { + proto.Message + GetId() string +}