diff --git a/internal/enabler/enabler.go b/internal/enabler/enabler.go new file mode 100644 index 000000000..32e5aeec0 --- /dev/null +++ b/internal/enabler/enabler.go @@ -0,0 +1,108 @@ +// package enabler contains a component that can be enabled and disabled dynamically +package enabler + +import ( + "context" + "errors" + "sync" + + "github.com/pomerium/pomerium/internal/log" +) + +var errCauseEnabler = errors.New("enabler") + +// A Handler is a component with a RunEnabled function. +type Handler interface { + RunEnabled(ctx context.Context) error +} + +// HandlerFunc is a function run by the enabler. +type HandlerFunc func(ctx context.Context) error + +func (f HandlerFunc) RunEnabled(ctx context.Context) error { + return f(ctx) +} + +// An Enabler enables or disables a component dynamically. +// When the Enabler is enabled, the Handler's RunEnabled will be called. +// If the Enabler is subsequently disabled the context passed to RunEnabled will be canceled. +// If the Enabler is subseqently enabled again, RunEnabled will be called again. +// Handlers should obey the context lifetime and be tolerant of RunEnabled +// being called multiple times. (not concurrently) +type Enabler interface { + Run(ctx context.Context) error + Enable() + Disable() +} + +type enabler struct { + name string + handler Handler + + mu sync.Mutex + cancel context.CancelCauseFunc + enabled bool +} + +// New creates a new Enabler. +func New(name string, handler Handler, enabled bool) Enabler { + d := &enabler{ + name: name, + handler: handler, + enabled: enabled, + cancel: func(_ error) {}, + } + return d +} + +// Run calls RunEnabled if enabled, otherwise it waits until enabled. +func (d *enabler) Run(ctx context.Context) error { + for { + err := d.runOrWaitForEnabled(ctx) + // if we received any error but our own, exit with that error + if !errors.Is(err, errCauseEnabler) { + return err + } + } +} + +func (d *enabler) runOrWaitForEnabled(ctx context.Context) error { + d.mu.Lock() + enabled := d.enabled + ctx, d.cancel = context.WithCancelCause(ctx) + d.mu.Unlock() + + // we're enabled so call RunEnabled. If Disabled is called it will cancel ctx. + if enabled { + log.Ctx(ctx).Info().Msgf("enabled %s", d.name) + err := d.handler.RunEnabled(ctx) + // if RunEnabled stopped because we canceled the context + if errors.Is(err, context.Canceled) && errors.Is(context.Cause(ctx), errCauseEnabler) { + log.Ctx(ctx).Info().Msgf("disabled %s", d.name) + return errCauseEnabler + } + return err + } + + // wait until Enabled is called + <-ctx.Done() + return context.Cause(ctx) +} + +func (d *enabler) Enable() { + d.mu.Lock() + if !d.enabled { + d.enabled = true + d.cancel(errCauseEnabler) + } + d.mu.Unlock() +} + +func (d *enabler) Disable() { + d.mu.Lock() + if d.enabled { + d.enabled = false + d.cancel(errCauseEnabler) + } + d.mu.Unlock() +} diff --git a/internal/enabler/enabler_test.go b/internal/enabler/enabler_test.go new file mode 100644 index 000000000..34c74dc8c --- /dev/null +++ b/internal/enabler/enabler_test.go @@ -0,0 +1,61 @@ +package enabler_test + +import ( + "context" + "errors" + "sync/atomic" + "testing" + "time" + + "github.com/stretchr/testify/assert" + + "github.com/pomerium/pomerium/internal/enabler" +) + +func TestEnabler(t *testing.T) { + t.Parallel() + + t.Run("enabled immediately", func(t *testing.T) { + t.Parallel() + + e := enabler.New("test", enabler.HandlerFunc(func(ctx context.Context) error { + return errors.New("ERROR") + }), true) + err := e.Run(context.Background()) + assert.Error(t, err) + }) + t.Run("enabled delayed", func(t *testing.T) { + t.Parallel() + + e := enabler.New("test", enabler.HandlerFunc(func(ctx context.Context) error { + return errors.New("ERROR") + }), false) + time.AfterFunc(time.Millisecond*10, e.Enable) + err := e.Run(context.Background()) + assert.Error(t, err) + }) + t.Run("disabled", func(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithCancel(context.Background()) + t.Cleanup(cancel) + + var started, stopped atomic.Int64 + e := enabler.New("test", enabler.HandlerFunc(func(ctx context.Context) error { + started.Add(1) + <-ctx.Done() + stopped.Add(1) + return ctx.Err() + }), true) + time.AfterFunc(time.Millisecond*10, e.Disable) + go e.Run(ctx) + + assert.Eventually(t, func() bool { return stopped.Load() == 1 }, time.Second, time.Millisecond*100, + "should stop RunEnabled") + + e.Enable() + + assert.Eventually(t, func() bool { return started.Load() == 2 }, time.Second, time.Millisecond*100, + "should re-start RunEnabled") + }) +} diff --git a/internal/identity/legacymanager/config.go b/internal/identity/legacymanager/config.go new file mode 100644 index 000000000..49682fbd2 --- /dev/null +++ b/internal/identity/legacymanager/config.go @@ -0,0 +1,87 @@ +package legacymanager + +import ( + "time" + + "github.com/pomerium/pomerium/internal/events" + "github.com/pomerium/pomerium/pkg/grpc/databroker" +) + +var ( + defaultSessionRefreshGracePeriod = 1 * time.Minute + defaultSessionRefreshCoolOffDuration = 10 * time.Second +) + +type config struct { + authenticator Authenticator + dataBrokerClient databroker.DataBrokerServiceClient + sessionRefreshGracePeriod time.Duration + sessionRefreshCoolOffDuration time.Duration + now func() time.Time + eventMgr *events.Manager + enabled bool +} + +func newConfig(options ...Option) *config { + cfg := new(config) + WithSessionRefreshGracePeriod(defaultSessionRefreshGracePeriod)(cfg) + WithSessionRefreshCoolOffDuration(defaultSessionRefreshCoolOffDuration)(cfg) + WithNow(time.Now)(cfg) + WithEnabled(true)(cfg) + for _, option := range options { + option(cfg) + } + return cfg +} + +// An Option customizes the configuration used for the identity manager. +type Option func(*config) + +// WithAuthenticator sets the authenticator in the config. +func WithAuthenticator(authenticator Authenticator) Option { + return func(cfg *config) { + cfg.authenticator = authenticator + } +} + +// WithDataBrokerClient sets the databroker client in the config. +func WithDataBrokerClient(dataBrokerClient databroker.DataBrokerServiceClient) Option { + return func(cfg *config) { + cfg.dataBrokerClient = dataBrokerClient + } +} + +// WithSessionRefreshGracePeriod sets the session refresh grace period used by the manager. +func WithSessionRefreshGracePeriod(dur time.Duration) Option { + return func(cfg *config) { + cfg.sessionRefreshGracePeriod = dur + } +} + +// WithSessionRefreshCoolOffDuration sets the session refresh cool-off duration used by the manager. +func WithSessionRefreshCoolOffDuration(dur time.Duration) Option { + return func(cfg *config) { + cfg.sessionRefreshCoolOffDuration = dur + } +} + +// WithNow customizes the time.Now function used by the manager. +func WithNow(now func() time.Time) Option { + return func(cfg *config) { + cfg.now = now + } +} + +// WithEventManager passes an event manager to record events +func WithEventManager(mgr *events.Manager) Option { + return func(cfg *config) { + cfg.eventMgr = mgr + } +} + +// WithEnabled sets the enabled option in the config. +func WithEnabled(enabled bool) Option { + return func(cfg *config) { + cfg.enabled = enabled + } +} diff --git a/internal/identity/legacymanager/data.go b/internal/identity/legacymanager/data.go new file mode 100644 index 000000000..97033b671 --- /dev/null +++ b/internal/identity/legacymanager/data.go @@ -0,0 +1,266 @@ +package legacymanager + +import ( + "encoding/json" + "time" + + "github.com/google/btree" + "google.golang.org/protobuf/types/known/timestamppb" + + "github.com/pomerium/pomerium/internal/identity" + "github.com/pomerium/pomerium/pkg/grpc/session" + "github.com/pomerium/pomerium/pkg/grpc/user" +) + +const userRefreshInterval = 10 * time.Minute + +// A User is a user managed by the Manager. +type User struct { + *user.User + lastRefresh time.Time +} + +// NextRefresh returns the next time the user information needs to be refreshed. +func (u User) NextRefresh() time.Time { + return u.lastRefresh.Add(userRefreshInterval) +} + +// UnmarshalJSON unmarshals json data into the user object. +func (u *User) UnmarshalJSON(data []byte) error { + if u.User == nil { + u.User = new(user.User) + } + + var raw map[string]json.RawMessage + err := json.Unmarshal(data, &raw) + if err != nil { + return err + } + + if name, ok := raw["name"]; ok { + _ = json.Unmarshal(name, &u.User.Name) + delete(raw, "name") + } + if email, ok := raw["email"]; ok { + _ = json.Unmarshal(email, &u.User.Email) + delete(raw, "email") + } + + u.AddClaims(identity.NewClaimsFromRaw(raw).Flatten()) + + return nil +} + +// A Session is a session managed by the Manager. +type Session struct { + *session.Session + // lastRefresh is the time of the last refresh attempt (which may or may + // not have succeeded), or else the time the Manager first became aware of + // the session (if it has not yet attempted to refresh this session). + lastRefresh time.Time + // gracePeriod is the amount of time before expiration to attempt a refresh. + gracePeriod time.Duration + // coolOffDuration is the amount of time to wait before attempting another refresh. + coolOffDuration time.Duration +} + +// NextRefresh returns the next time the session needs to be refreshed. +func (s Session) NextRefresh() time.Time { + var tm time.Time + + if s.GetOauthToken().GetExpiresAt() != nil { + expiry := s.GetOauthToken().GetExpiresAt().AsTime() + if s.GetOauthToken().GetExpiresAt().IsValid() && !expiry.IsZero() { + expiry = expiry.Add(-s.gracePeriod) + if tm.IsZero() || expiry.Before(tm) { + tm = expiry + } + } + } + + if s.GetExpiresAt() != nil { + expiry := s.GetExpiresAt().AsTime() + if s.GetExpiresAt().IsValid() && !expiry.IsZero() { + if tm.IsZero() || expiry.Before(tm) { + tm = expiry + } + } + } + + // don't refresh any quicker than the cool-off duration + min := s.lastRefresh.Add(s.coolOffDuration) + if tm.Before(min) { + tm = min + } + + return tm +} + +// UnmarshalJSON unmarshals json data into the session object. +func (s *Session) UnmarshalJSON(data []byte) error { + if s.Session == nil { + s.Session = new(session.Session) + } + + var raw map[string]json.RawMessage + err := json.Unmarshal(data, &raw) + if err != nil { + return err + } + + if s.Session.IdToken == nil { + s.Session.IdToken = new(session.IDToken) + } + + if iss, ok := raw["iss"]; ok { + _ = json.Unmarshal(iss, &s.Session.IdToken.Issuer) + delete(raw, "iss") + } + if sub, ok := raw["sub"]; ok { + _ = json.Unmarshal(sub, &s.Session.IdToken.Subject) + delete(raw, "sub") + } + if exp, ok := raw["exp"]; ok { + var secs int64 + if err := json.Unmarshal(exp, &secs); err == nil { + s.Session.IdToken.ExpiresAt = timestamppb.New(time.Unix(secs, 0)) + } + delete(raw, "exp") + } + if iat, ok := raw["iat"]; ok { + var secs int64 + if err := json.Unmarshal(iat, &secs); err == nil { + s.Session.IdToken.IssuedAt = timestamppb.New(time.Unix(secs, 0)) + } + delete(raw, "iat") + } + + s.AddClaims(identity.NewClaimsFromRaw(raw).Flatten()) + + return nil +} + +type sessionCollectionItem struct { + Session +} + +func (item sessionCollectionItem) Less(than btree.Item) bool { + xUserID, yUserID := item.GetUserId(), than.(sessionCollectionItem).GetUserId() + switch { + case xUserID < yUserID: + return true + case yUserID < xUserID: + return false + } + + xID, yID := item.GetId(), than.(sessionCollectionItem).GetId() + switch { + case xID < yID: + return true + case yID < xID: + return false + } + return false +} + +type sessionCollection struct { + *btree.BTree +} + +func (c *sessionCollection) Delete(userID, sessionID string) { + c.BTree.Delete(sessionCollectionItem{ + Session: Session{ + Session: &session.Session{ + UserId: userID, + Id: sessionID, + }, + }, + }) +} + +func (c *sessionCollection) Get(userID, sessionID string) (Session, bool) { + item := c.BTree.Get(sessionCollectionItem{ + Session: Session{ + Session: &session.Session{ + UserId: userID, + Id: sessionID, + }, + }, + }) + if item == nil { + return Session{}, false + } + return item.(sessionCollectionItem).Session, true +} + +// GetSessionsForUser gets all the sessions for the given user. +func (c *sessionCollection) GetSessionsForUser(userID string) []Session { + var sessions []Session + c.AscendGreaterOrEqual(sessionCollectionItem{ + Session: Session{ + Session: &session.Session{ + UserId: userID, + }, + }, + }, func(item btree.Item) bool { + s := item.(sessionCollectionItem).Session + if s.UserId != userID { + return false + } + + sessions = append(sessions, s) + return true + }) + return sessions +} + +func (c *sessionCollection) ReplaceOrInsert(s Session) { + c.BTree.ReplaceOrInsert(sessionCollectionItem{Session: s}) +} + +type userCollectionItem struct { + User +} + +func (item userCollectionItem) Less(than btree.Item) bool { + xID, yID := item.GetId(), than.(userCollectionItem).GetId() + switch { + case xID < yID: + return true + case yID < xID: + return false + } + return false +} + +type userCollection struct { + *btree.BTree +} + +func (c *userCollection) Delete(userID string) { + c.BTree.Delete(userCollectionItem{ + User: User{ + User: &user.User{ + Id: userID, + }, + }, + }) +} + +func (c *userCollection) Get(userID string) (User, bool) { + item := c.BTree.Get(userCollectionItem{ + User: User{ + User: &user.User{ + Id: userID, + }, + }, + }) + if item == nil { + return User{}, false + } + return item.(userCollectionItem).User, true +} + +func (c *userCollection) ReplaceOrInsert(u User) { + c.BTree.ReplaceOrInsert(userCollectionItem{User: u}) +} diff --git a/internal/identity/legacymanager/data_test.go b/internal/identity/legacymanager/data_test.go new file mode 100644 index 000000000..14bb18e84 --- /dev/null +++ b/internal/identity/legacymanager/data_test.go @@ -0,0 +1,74 @@ +package legacymanager + +import ( + "encoding/json" + "fmt" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "google.golang.org/protobuf/types/known/structpb" + "google.golang.org/protobuf/types/known/timestamppb" + + "github.com/pomerium/pomerium/pkg/grpc/session" + "github.com/pomerium/pomerium/pkg/protoutil" +) + +func TestUser_UnmarshalJSON(t *testing.T) { + var u User + err := json.Unmarshal([]byte(`{ + "name": "joe", + "email": "joe@test.com", + "some-other-claim": "xyz" + }`), &u) + assert.NoError(t, err) + assert.NotNil(t, u.User) + assert.Equal(t, "joe", u.User.Name) + assert.Equal(t, "joe@test.com", u.User.Email) + assert.Equal(t, map[string]*structpb.ListValue{ + "some-other-claim": {Values: []*structpb.Value{protoutil.ToStruct("xyz")}}, + }, u.Claims) +} + +func TestSession_NextRefresh(t *testing.T) { + tm1 := time.Date(2020, 6, 5, 12, 0, 0, 0, time.UTC) + s := Session{ + Session: &session.Session{}, + lastRefresh: tm1, + gracePeriod: time.Second * 10, + coolOffDuration: time.Minute, + } + assert.Equal(t, tm1.Add(time.Minute), s.NextRefresh()) + + tm2 := time.Date(2020, 6, 5, 13, 0, 0, 0, time.UTC) + s.OauthToken = &session.OAuthToken{ + ExpiresAt: timestamppb.New(tm2), + } + assert.Equal(t, tm2.Add(-time.Second*10), s.NextRefresh()) + + tm3 := time.Date(2020, 6, 5, 12, 15, 0, 0, time.UTC) + s.ExpiresAt = timestamppb.New(tm3) + assert.Equal(t, tm3, s.NextRefresh()) +} + +func TestSession_UnmarshalJSON(t *testing.T) { + tm := time.Date(2020, 6, 5, 12, 0, 0, 0, time.UTC) + var s Session + err := json.Unmarshal([]byte(`{ + "iss": "https://some.issuer.com", + "sub": "subject", + "exp": `+fmt.Sprint(tm.Unix())+`, + "iat": `+fmt.Sprint(tm.Unix())+`, + "some-other-claim": "xyz" + }`), &s) + assert.NoError(t, err) + assert.NotNil(t, s.Session) + assert.NotNil(t, s.Session.IdToken) + assert.Equal(t, "https://some.issuer.com", s.Session.IdToken.Issuer) + assert.Equal(t, "subject", s.Session.IdToken.Subject) + assert.Equal(t, timestamppb.New(tm), s.Session.IdToken.ExpiresAt) + assert.Equal(t, timestamppb.New(tm), s.Session.IdToken.IssuedAt) + assert.Equal(t, map[string]*structpb.ListValue{ + "some-other-claim": {Values: []*structpb.Value{protoutil.ToStruct("xyz")}}, + }, s.Claims) +} diff --git a/internal/identity/legacymanager/manager.go b/internal/identity/legacymanager/manager.go new file mode 100644 index 000000000..296f414ff --- /dev/null +++ b/internal/identity/legacymanager/manager.go @@ -0,0 +1,497 @@ +// Package legacymanager contains an identity manager responsible for refreshing sessions and creating users. +package legacymanager + +import ( + "context" + "errors" + "time" + + "github.com/google/btree" + "github.com/rs/zerolog" + "golang.org/x/oauth2" + "golang.org/x/sync/errgroup" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" + "google.golang.org/protobuf/types/known/fieldmaskpb" + "google.golang.org/protobuf/types/known/timestamppb" + + "github.com/pomerium/pomerium/internal/atomicutil" + "github.com/pomerium/pomerium/internal/enabler" + "github.com/pomerium/pomerium/internal/events" + "github.com/pomerium/pomerium/internal/identity/identity" + "github.com/pomerium/pomerium/internal/log" + "github.com/pomerium/pomerium/internal/scheduler" + "github.com/pomerium/pomerium/internal/telemetry/metrics" + "github.com/pomerium/pomerium/pkg/grpc/databroker" + "github.com/pomerium/pomerium/pkg/grpc/session" + "github.com/pomerium/pomerium/pkg/grpc/user" + "github.com/pomerium/pomerium/pkg/grpcutil" + metrics_ids "github.com/pomerium/pomerium/pkg/metrics" +) + +// Authenticator is an identity.Provider with only the methods needed by the manager. +type Authenticator interface { + Refresh(context.Context, *oauth2.Token, identity.State) (*oauth2.Token, error) + Revoke(context.Context, *oauth2.Token) error + UpdateUserInfo(context.Context, *oauth2.Token, interface{}) error +} + +type ( + updateRecordsMessage struct { + records []*databroker.Record + } +) + +// A Manager refreshes identity information using session and user data. +type Manager struct { + enabler.Enabler + cfg *atomicutil.Value[*config] + + sessionScheduler *scheduler.Scheduler + userScheduler *scheduler.Scheduler + + sessions sessionCollection + users userCollection +} + +// New creates a new identity manager. +func New( + options ...Option, +) *Manager { + mgr := &Manager{ + cfg: atomicutil.NewValue(newConfig()), + + sessionScheduler: scheduler.New(), + userScheduler: scheduler.New(), + } + mgr.Enabler = enabler.New("identity_manager", mgr, true) + mgr.reset() + mgr.UpdateConfig(options...) + return mgr +} + +func withLog(ctx context.Context) context.Context { + return log.WithContext(ctx, func(c zerolog.Context) zerolog.Context { + return c.Str("service", "identity_manager") + }) +} + +// UpdateConfig updates the manager with the new options. +func (mgr *Manager) UpdateConfig(options ...Option) { + mgr.cfg.Store(newConfig(options...)) + if mgr.cfg.Load().enabled { + mgr.Enable() + } else { + mgr.Disable() + } +} + +// RunEnabled runs the manager. This method blocks until an error occurs or the given context is canceled. +func (mgr *Manager) RunEnabled(ctx context.Context) error { + leaser := databroker.NewLeaser("identity_manager", time.Second*30, mgr) + return leaser.Run(ctx) +} + +// RunLeased runs the identity manager when a lease is acquired. +func (mgr *Manager) RunLeased(ctx context.Context) error { + ctx = withLog(ctx) + update := make(chan updateRecordsMessage, 1) + clear := make(chan struct{}, 1) + + syncer := newDataBrokerSyncer(ctx, mgr.cfg, update, clear) + + eg, ctx := errgroup.WithContext(ctx) + eg.Go(func() error { + return syncer.Run(ctx) + }) + eg.Go(func() error { + return mgr.refreshLoop(ctx, update, clear) + }) + + return eg.Wait() +} + +// GetDataBrokerServiceClient gets the databroker client. +func (mgr *Manager) GetDataBrokerServiceClient() databroker.DataBrokerServiceClient { + return mgr.cfg.Load().dataBrokerClient +} + +func (mgr *Manager) now() time.Time { + return mgr.cfg.Load().now() +} + +func (mgr *Manager) refreshLoop(ctx context.Context, update <-chan updateRecordsMessage, clear <-chan struct{}) error { + // wait for initial sync + select { + case <-ctx.Done(): + return ctx.Err() + case <-clear: + mgr.reset() + } + select { + case <-ctx.Done(): + return ctx.Err() + case msg := <-update: + mgr.onUpdateRecords(ctx, msg) + } + + log.Debug(ctx). + Int("sessions", mgr.sessions.Len()). + Int("users", mgr.users.Len()). + Msg("initial sync complete") + + // start refreshing + maxWait := time.Minute * 10 + var nextTime time.Time + + timer := time.NewTimer(0) + defer timer.Stop() + + for { + select { + case <-ctx.Done(): + return ctx.Err() + case <-clear: + mgr.reset() + case msg := <-update: + mgr.onUpdateRecords(ctx, msg) + case <-timer.C: + } + + now := mgr.now() + nextTime = now.Add(maxWait) + + // refresh sessions + for { + tm, key := mgr.sessionScheduler.Next() + if now.Before(tm) { + if tm.Before(nextTime) { + nextTime = tm + } + break + } + mgr.sessionScheduler.Remove(key) + + userID, sessionID := fromSessionSchedulerKey(key) + mgr.refreshSession(ctx, userID, sessionID) + } + + // refresh users + for { + tm, key := mgr.userScheduler.Next() + if now.Before(tm) { + if tm.Before(nextTime) { + nextTime = tm + } + break + } + mgr.userScheduler.Remove(key) + + mgr.refreshUser(ctx, key) + } + + metrics.RecordIdentityManagerLastRefresh(ctx) + timer.Reset(time.Until(nextTime)) + } +} + +// refreshSession handles two distinct session lifecycle events: +// +// 1. If the session itself has expired, delete the session. +// 2. If the session's underlying OAuth2 access token is nearing expiration +// (but the session itself is still valid), refresh the access token. +// +// After a successful access token refresh, this method will also trigger a +// user info refresh. If an access token refresh or a user info refresh fails +// with a permanent error, the session will be deleted. +func (mgr *Manager) refreshSession(ctx context.Context, userID, sessionID string) { + log.Info(ctx). + Str("user_id", userID). + Str("session_id", sessionID). + Msg("refreshing session") + + s, ok := mgr.sessions.Get(userID, sessionID) + if !ok { + log.Warn(ctx). + Str("user_id", userID). + Str("session_id", sessionID). + Msg("no session found for refresh") + return + } + + s.lastRefresh = mgr.now() + + if mgr.refreshSessionInternal(ctx, userID, sessionID, &s) { + mgr.sessions.ReplaceOrInsert(s) + mgr.sessionScheduler.Add(s.NextRefresh(), toSessionSchedulerKey(userID, sessionID)) + } +} + +// refreshSessionInternal performs the core refresh logic and returns true if +// the session should be scheduled for refresh again, or false if not. +func (mgr *Manager) refreshSessionInternal( + ctx context.Context, userID, sessionID string, s *Session, +) bool { + authenticator := mgr.cfg.Load().authenticator + if authenticator == nil { + log.Info(ctx). + Str("user_id", userID). + Str("session_id", sessionID). + Msg("no authenticator defined, deleting session") + mgr.deleteSession(ctx, userID, sessionID) + return false + } + + expiry := s.GetExpiresAt().AsTime() + if !expiry.After(mgr.now()) { + log.Info(ctx). + Str("user_id", userID). + Str("session_id", sessionID). + Msg("deleting expired session") + mgr.deleteSession(ctx, userID, sessionID) + return false + } + + if s.Session == nil || s.Session.OauthToken == nil { + log.Warn(ctx). + Str("user_id", userID). + Str("session_id", sessionID). + Msg("no session oauth2 token found for refresh") + return false + } + + newToken, err := authenticator.Refresh(ctx, FromOAuthToken(s.OauthToken), s) + metrics.RecordIdentityManagerSessionRefresh(ctx, err) + mgr.recordLastError(metrics_ids.IdentityManagerLastSessionRefreshError, err) + if isTemporaryError(err) { + log.Error(ctx).Err(err). + Str("user_id", s.GetUserId()). + Str("session_id", s.GetId()). + Msg("failed to refresh oauth2 token") + return true + } else if err != nil { + log.Error(ctx).Err(err). + Str("user_id", s.GetUserId()). + Str("session_id", s.GetId()). + Msg("failed to refresh oauth2 token, deleting session") + mgr.deleteSession(ctx, userID, sessionID) + return false + } + s.OauthToken = ToOAuthToken(newToken) + + err = authenticator.UpdateUserInfo(ctx, FromOAuthToken(s.OauthToken), s) + metrics.RecordIdentityManagerUserRefresh(ctx, err) + mgr.recordLastError(metrics_ids.IdentityManagerLastUserRefreshError, err) + if isTemporaryError(err) { + log.Error(ctx).Err(err). + Str("user_id", s.GetUserId()). + Str("session_id", s.GetId()). + Msg("failed to update user info") + return true + } else if err != nil { + log.Error(ctx).Err(err). + Str("user_id", s.GetUserId()). + Str("session_id", s.GetId()). + Msg("failed to update user info, deleting session") + mgr.deleteSession(ctx, userID, sessionID) + return false + } + + fm, err := fieldmaskpb.New(s.Session, "oauth_token", "id_token", "claims") + if err != nil { + log.Error(ctx).Err(err).Msg("internal error") + return false + } + + if _, err := session.Patch(ctx, mgr.cfg.Load().dataBrokerClient, s.Session, fm); err != nil { + log.Error(ctx).Err(err). + Str("user_id", s.GetUserId()). + Str("session_id", s.GetId()). + Msg("failed to update session") + } + return true +} + +func (mgr *Manager) refreshUser(ctx context.Context, userID string) { + log.Info(ctx). + Str("user_id", userID). + Msg("refreshing user") + + authenticator := mgr.cfg.Load().authenticator + if authenticator == nil { + return + } + + u, ok := mgr.users.Get(userID) + if !ok { + log.Warn(ctx). + Str("user_id", userID). + Msg("no user found for refresh") + return + } + u.lastRefresh = mgr.now() + mgr.userScheduler.Add(u.NextRefresh(), u.GetId()) + + for _, s := range mgr.sessions.GetSessionsForUser(userID) { + if s.Session == nil || s.Session.OauthToken == nil { + log.Warn(ctx). + Str("user_id", userID). + Msg("no session oauth2 token found for refresh") + continue + } + + err := authenticator.UpdateUserInfo(ctx, FromOAuthToken(s.OauthToken), &u) + metrics.RecordIdentityManagerUserRefresh(ctx, err) + mgr.recordLastError(metrics_ids.IdentityManagerLastUserRefreshError, err) + if isTemporaryError(err) { + log.Error(ctx).Err(err). + Str("user_id", s.GetUserId()). + Str("session_id", s.GetId()). + Msg("failed to update user info") + return + } else if err != nil { + log.Error(ctx).Err(err). + Str("user_id", s.GetUserId()). + Str("session_id", s.GetId()). + Msg("failed to update user info, deleting session") + mgr.deleteSession(ctx, userID, s.GetId()) + continue + } + + res, err := databroker.Put(ctx, mgr.cfg.Load().dataBrokerClient, u.User) + if err != nil { + log.Error(ctx).Err(err). + Str("user_id", s.GetUserId()). + Str("session_id", s.GetId()). + Msg("failed to update user") + continue + } + + mgr.onUpdateUser(ctx, res.GetRecords()[0], u.User) + } +} + +func (mgr *Manager) onUpdateRecords(ctx context.Context, msg updateRecordsMessage) { + for _, record := range msg.records { + switch record.GetType() { + case grpcutil.GetTypeURL(new(session.Session)): + var pbSession session.Session + err := record.GetData().UnmarshalTo(&pbSession) + if err != nil { + log.Warn(ctx).Msgf("error unmarshaling session: %s", err) + continue + } + mgr.onUpdateSession(record, &pbSession) + case grpcutil.GetTypeURL(new(user.User)): + var pbUser user.User + err := record.GetData().UnmarshalTo(&pbUser) + if err != nil { + log.Warn(ctx).Msgf("error unmarshaling user: %s", err) + continue + } + mgr.onUpdateUser(ctx, record, &pbUser) + } + } +} + +func (mgr *Manager) onUpdateSession(record *databroker.Record, session *session.Session) { + mgr.sessionScheduler.Remove(toSessionSchedulerKey(session.GetUserId(), session.GetId())) + + if record.GetDeletedAt() != nil { + mgr.sessions.Delete(session.GetUserId(), session.GetId()) + return + } + + // update session + s, _ := mgr.sessions.Get(session.GetUserId(), session.GetId()) + if s.lastRefresh.IsZero() { + s.lastRefresh = mgr.now() + } + s.gracePeriod = mgr.cfg.Load().sessionRefreshGracePeriod + s.coolOffDuration = mgr.cfg.Load().sessionRefreshCoolOffDuration + s.Session = session + mgr.sessions.ReplaceOrInsert(s) + mgr.sessionScheduler.Add(s.NextRefresh(), toSessionSchedulerKey(session.GetUserId(), session.GetId())) +} + +func (mgr *Manager) onUpdateUser(_ context.Context, record *databroker.Record, user *user.User) { + mgr.userScheduler.Remove(user.GetId()) + + if record.GetDeletedAt() != nil { + mgr.users.Delete(user.GetId()) + return + } + + u, _ := mgr.users.Get(user.GetId()) + u.lastRefresh = mgr.cfg.Load().now() + u.User = user + mgr.users.ReplaceOrInsert(u) + mgr.userScheduler.Add(u.NextRefresh(), u.GetId()) +} + +func (mgr *Manager) deleteSession(ctx context.Context, userID, sessionID string) { + mgr.sessionScheduler.Remove(toSessionSchedulerKey(userID, sessionID)) + mgr.sessions.Delete(userID, sessionID) + + client := mgr.cfg.Load().dataBrokerClient + res, err := client.Get(ctx, &databroker.GetRequest{ + Type: grpcutil.GetTypeURL(new(session.Session)), + Id: sessionID, + }) + if status.Code(err) == codes.NotFound { + return + } else if err != nil { + log.Error(ctx).Err(err). + Str("session_id", sessionID). + Msg("failed to delete session") + return + } + + record := res.GetRecord() + record.DeletedAt = timestamppb.Now() + + _, err = client.Put(ctx, &databroker.PutRequest{ + Records: []*databroker.Record{record}, + }) + if err != nil { + log.Error(ctx).Err(err). + Str("session_id", sessionID). + Msg("failed to delete session") + return + } +} + +// reset resets all the manager datastructures to their initial state +func (mgr *Manager) reset() { + mgr.sessions = sessionCollection{BTree: btree.New(8)} + mgr.users = userCollection{BTree: btree.New(8)} +} + +func (mgr *Manager) recordLastError(id string, err error) { + if err == nil { + return + } + evtMgr := mgr.cfg.Load().eventMgr + if evtMgr == nil { + return + } + evtMgr.Dispatch(&events.LastError{ + Time: timestamppb.Now(), + Message: err.Error(), + Id: id, + }) +} + +func isTemporaryError(err error) bool { + if err == nil { + return false + } + if errors.Is(err, context.DeadlineExceeded) || errors.Is(err, context.Canceled) { + return true + } + var hasTemporary interface{ Temporary() bool } + if errors.As(err, &hasTemporary) && hasTemporary.Temporary() { + return true + } + return false +} diff --git a/internal/identity/legacymanager/manager_test.go b/internal/identity/legacymanager/manager_test.go new file mode 100644 index 000000000..4536c061a --- /dev/null +++ b/internal/identity/legacymanager/manager_test.go @@ -0,0 +1,360 @@ +package legacymanager + +import ( + "context" + "errors" + "fmt" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "go.uber.org/mock/gomock" + "golang.org/x/oauth2" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" + "google.golang.org/protobuf/proto" + "google.golang.org/protobuf/types/known/fieldmaskpb" + "google.golang.org/protobuf/types/known/timestamppb" + + "github.com/pomerium/pomerium/internal/events" + "github.com/pomerium/pomerium/internal/identity/identity" + "github.com/pomerium/pomerium/pkg/grpc/databroker" + "github.com/pomerium/pomerium/pkg/grpc/databroker/mock_databroker" + "github.com/pomerium/pomerium/pkg/grpc/session" + "github.com/pomerium/pomerium/pkg/grpc/user" + metrics_ids "github.com/pomerium/pomerium/pkg/metrics" + "github.com/pomerium/pomerium/pkg/protoutil" +) + +type mockAuthenticator struct { + refreshResult *oauth2.Token + refreshError error + revokeError error + updateUserInfoError error +} + +func (mock *mockAuthenticator) Refresh(_ context.Context, _ *oauth2.Token, _ identity.State) (*oauth2.Token, error) { + return mock.refreshResult, mock.refreshError +} + +func (mock *mockAuthenticator) Revoke(_ context.Context, _ *oauth2.Token) error { + return mock.revokeError +} + +func (mock *mockAuthenticator) UpdateUserInfo(_ context.Context, _ *oauth2.Token, _ any) error { + return mock.updateUserInfoError +} + +func TestManager_refresh(t *testing.T) { + ctrl := gomock.NewController(t) + ctx, clearTimeout := context.WithTimeout(context.Background(), time.Second*10) + t.Cleanup(clearTimeout) + + client := mock_databroker.NewMockDataBrokerServiceClient(ctrl) + mgr := New(WithDataBrokerClient(client)) + mgr.onUpdateRecords(ctx, updateRecordsMessage{ + records: []*databroker.Record{ + databroker.NewRecord(&session.Session{ + Id: "s1", + UserId: "u1", + OauthToken: &session.OAuthToken{}, + ExpiresAt: timestamppb.New(time.Now().Add(time.Second * 10)), + }), + databroker.NewRecord(&user.User{ + Id: "u1", + }), + }, + }) + client.EXPECT().Get(gomock.Any(), gomock.Any()).Return(nil, status.Error(codes.NotFound, "not found")) + mgr.refreshSession(ctx, "u1", "s1") + mgr.refreshUser(ctx, "u1") +} + +func TestManager_onUpdateRecords(t *testing.T) { + ctrl := gomock.NewController(t) + + ctx, clearTimeout := context.WithTimeout(context.Background(), time.Second*10) + defer clearTimeout() + + now := time.Now() + + mgr := New( + WithDataBrokerClient(mock_databroker.NewMockDataBrokerServiceClient(ctrl)), + WithNow(func() time.Time { + return now + }), + ) + + mgr.onUpdateRecords(ctx, updateRecordsMessage{ + records: []*databroker.Record{ + mkRecord(&session.Session{Id: "session1", UserId: "user1"}), + mkRecord(&user.User{Id: "user1", Name: "user 1", Email: "user1@example.com"}), + }, + }) + + if _, ok := mgr.sessions.Get("user1", "session1"); assert.True(t, ok) { + tm, id := mgr.sessionScheduler.Next() + assert.Equal(t, now.Add(10*time.Second), tm) + assert.Equal(t, "user1\037session1", id) + } + if _, ok := mgr.users.Get("user1"); assert.True(t, ok) { + tm, id := mgr.userScheduler.Next() + assert.Equal(t, now.Add(userRefreshInterval), tm) + assert.Equal(t, "user1", id) + } +} + +func TestManager_onUpdateSession(t *testing.T) { + startTime := time.Date(2023, 10, 19, 12, 0, 0, 0, time.UTC) + + s := &session.Session{ + Id: "session-id", + UserId: "user-id", + OauthToken: &session.OAuthToken{ + AccessToken: "access-token", + ExpiresAt: timestamppb.New(startTime.Add(5 * time.Minute)), + }, + IssuedAt: timestamppb.New(startTime), + ExpiresAt: timestamppb.New(startTime.Add(24 * time.Hour)), + } + + assertNextScheduled := func(t *testing.T, mgr *Manager, expectedTime time.Time) { + t.Helper() + tm, key := mgr.sessionScheduler.Next() + assert.Equal(t, expectedTime, tm) + assert.Equal(t, "user-id\037session-id", key) + } + + t.Run("initial refresh event when not expiring soon", func(t *testing.T) { + now := startTime + mgr := New(WithNow(func() time.Time { return now })) + + // When the Manager first becomes aware of a session it should schedule + // a refresh event for one minute before access token expiration. + mgr.onUpdateSession(mkRecord(s), s) + assertNextScheduled(t, mgr, startTime.Add(4*time.Minute)) + }) + t.Run("initial refresh event when expiring soon", func(t *testing.T) { + now := startTime + mgr := New(WithNow(func() time.Time { return now })) + + // When the Manager first becomes aware of a session, if that session + // is expiring within the gracePeriod (1 minute), it should schedule a + // refresh event for as soon as possible, subject to the + // coolOffDuration (10 seconds). + now = now.Add(4*time.Minute + 30*time.Second) // 30 s before expiration + mgr.onUpdateSession(mkRecord(s), s) + assertNextScheduled(t, mgr, now.Add(10*time.Second)) + }) + t.Run("update near scheduled refresh", func(t *testing.T) { + now := startTime + mgr := New(WithNow(func() time.Time { return now })) + + mgr.onUpdateSession(mkRecord(s), s) + assertNextScheduled(t, mgr, startTime.Add(4*time.Minute)) + + // If a session is updated close to the time when it is scheduled to be + // refreshed, the scheduled refresh event should not be pushed back. + now = now.Add(3*time.Minute + 55*time.Second) // 5 s before refresh + mgr.onUpdateSession(mkRecord(s), s) + assertNextScheduled(t, mgr, now.Add(5*time.Second)) + + // However, if an update changes the access token validity, the refresh + // event should be rescheduled accordingly. (This should be uncommon, + // as only the refresh loop itself should modify the access token.) + s2 := proto.Clone(s).(*session.Session) + s2.OauthToken.ExpiresAt = timestamppb.New(now.Add(5 * time.Minute)) + mgr.onUpdateSession(mkRecord(s2), s2) + assertNextScheduled(t, mgr, now.Add(4*time.Minute)) + }) + t.Run("session record deleted", func(t *testing.T) { + now := startTime + mgr := New(WithNow(func() time.Time { return now })) + + mgr.onUpdateSession(mkRecord(s), s) + assertNextScheduled(t, mgr, startTime.Add(4*time.Minute)) + + // If a session is deleted, any scheduled refresh event should be canceled. + record := mkRecord(s) + record.DeletedAt = timestamppb.New(now) + mgr.onUpdateSession(record, s) + _, key := mgr.sessionScheduler.Next() + assert.Empty(t, key) + }) +} + +func TestManager_refreshSession(t *testing.T) { + startTime := time.Date(2023, 10, 19, 12, 0, 0, 0, time.UTC) + + var auth mockAuthenticator + + ctrl := gomock.NewController(t) + client := mock_databroker.NewMockDataBrokerServiceClient(ctrl) + + now := startTime + mgr := New( + WithDataBrokerClient(client), + WithNow(func() time.Time { return now }), + WithAuthenticator(&auth), + ) + + // Initialize the Manager with a new session. + s := &session.Session{ + Id: "session-id", + UserId: "user-id", + OauthToken: &session.OAuthToken{ + AccessToken: "access-token", + ExpiresAt: timestamppb.New(startTime.Add(5 * time.Minute)), + RefreshToken: "refresh-token", + }, + IssuedAt: timestamppb.New(startTime), + ExpiresAt: timestamppb.New(startTime.Add(24 * time.Hour)), + } + mgr.sessions.ReplaceOrInsert(Session{ + Session: s, + lastRefresh: startTime, + gracePeriod: time.Minute, + coolOffDuration: 10 * time.Second, + }) + + // If OAuth2 token refresh fails with a temporary error, the manager should + // still reschedule another refresh attempt. + now = now.Add(4 * time.Minute) + auth.refreshError = context.DeadlineExceeded + mgr.refreshSession(context.Background(), "user-id", "session-id") + + tm, key := mgr.sessionScheduler.Next() + assert.Equal(t, now.Add(10*time.Second), tm) + assert.Equal(t, "user-id\037session-id", key) + + // Simulate a successful token refresh on the second attempt. The manager + // should store the updated session in the databroker and schedule another + // refresh event. + now = now.Add(10 * time.Second) + auth.refreshResult, auth.refreshError = &oauth2.Token{ + AccessToken: "new-access-token", + RefreshToken: "new-refresh-token", + Expiry: now.Add(5 * time.Minute), + }, nil + expectedSession := proto.Clone(s).(*session.Session) + expectedSession.OauthToken = &session.OAuthToken{ + AccessToken: "new-access-token", + ExpiresAt: timestamppb.New(now.Add(5 * time.Minute)), + RefreshToken: "new-refresh-token", + } + client.EXPECT().Patch(gomock.Any(), objectsAreEqualMatcher{ + &databroker.PatchRequest{ + Records: []*databroker.Record{{ + Type: "type.googleapis.com/session.Session", + Id: "session-id", + Data: protoutil.NewAny(expectedSession), + }}, + FieldMask: &fieldmaskpb.FieldMask{ + Paths: []string{"oauth_token", "id_token", "claims"}, + }, + }, + }). + Return(nil /* this result is currently unused */, nil) + mgr.refreshSession(context.Background(), "user-id", "session-id") + + tm, key = mgr.sessionScheduler.Next() + assert.Equal(t, now.Add(4*time.Minute), tm) + assert.Equal(t, "user-id\037session-id", key) +} + +func TestManager_reportErrors(t *testing.T) { + ctrl := gomock.NewController(t) + + ctx, clearTimeout := context.WithTimeout(context.Background(), time.Second*10) + defer clearTimeout() + + evtMgr := events.New() + received := make(chan events.Event, 1) + handle := evtMgr.Register(func(evt events.Event) { + received <- evt + }) + defer evtMgr.Unregister(handle) + + expectMsg := func(id, msg string) { + t.Helper() + assert.Eventually(t, func() bool { + select { + case evt := <-received: + lastErr := evt.(*events.LastError) + return msg == lastErr.Message && id == lastErr.Id + default: + return false + } + }, time.Second, time.Millisecond*20, msg) + } + + s := &session.Session{ + Id: "session1", + UserId: "user1", + OauthToken: &session.OAuthToken{ + ExpiresAt: timestamppb.New(time.Now().Add(time.Hour)), + }, + ExpiresAt: timestamppb.New(time.Now().Add(time.Hour)), + } + + client := mock_databroker.NewMockDataBrokerServiceClient(ctrl) + client.EXPECT().Get(gomock.Any(), gomock.Any()).AnyTimes().Return(&databroker.GetResponse{Record: databroker.NewRecord(s)}, nil) + client.EXPECT().Put(gomock.Any(), gomock.Any()).AnyTimes() + mgr := New( + WithEventManager(evtMgr), + WithDataBrokerClient(client), + WithAuthenticator(&mockAuthenticator{ + refreshError: errors.New("update session"), + updateUserInfoError: errors.New("update user info"), + }), + ) + + mgr.onUpdateRecords(ctx, updateRecordsMessage{ + records: []*databroker.Record{ + mkRecord(s), + mkRecord(&user.User{Id: "user1", Name: "user 1", Email: "user1@example.com"}), + }, + }) + + mgr.refreshUser(ctx, "user1") + expectMsg(metrics_ids.IdentityManagerLastUserRefreshError, "update user info") + + mgr.onUpdateRecords(ctx, updateRecordsMessage{ + records: []*databroker.Record{ + mkRecord(s), + mkRecord(&user.User{Id: "user1", Name: "user 1", Email: "user1@example.com"}), + }, + }) + + mgr.refreshSession(ctx, "user1", "session1") + expectMsg(metrics_ids.IdentityManagerLastSessionRefreshError, "update session") +} + +func mkRecord(msg recordable) *databroker.Record { + data := protoutil.NewAny(msg) + return &databroker.Record{ + Type: data.GetTypeUrl(), + Id: msg.GetId(), + Data: data, + } +} + +type recordable interface { + proto.Message + GetId() string +} + +// objectsAreEqualMatcher implements gomock.Matcher using ObjectsAreEqual. This +// is especially helpful when working with pointers, as it will compare the +// underlying values rather than the pointers themselves. +type objectsAreEqualMatcher struct { + expected interface{} +} + +func (m objectsAreEqualMatcher) Matches(x interface{}) bool { + return assert.ObjectsAreEqual(m.expected, x) +} + +func (m objectsAreEqualMatcher) String() string { + return fmt.Sprintf("is equal to %v (%T)", m.expected, m.expected) +} diff --git a/internal/identity/legacymanager/misc.go b/internal/identity/legacymanager/misc.go new file mode 100644 index 000000000..55e73605b --- /dev/null +++ b/internal/identity/legacymanager/misc.go @@ -0,0 +1,46 @@ +package legacymanager + +import ( + "strings" + + "golang.org/x/oauth2" + "google.golang.org/protobuf/types/known/timestamppb" + + "github.com/pomerium/pomerium/pkg/grpc/session" +) + +func toSessionSchedulerKey(userID, sessionID string) string { + return userID + "\037" + sessionID +} + +func fromSessionSchedulerKey(key string) (userID, sessionID string) { + idx := strings.Index(key, "\037") + if idx >= 0 { + userID = key[:idx] + sessionID = key[idx+1:] + } else { + userID = key + } + return userID, sessionID +} + +// FromOAuthToken converts a session oauth token to oauth2.Token. +func FromOAuthToken(token *session.OAuthToken) *oauth2.Token { + return &oauth2.Token{ + AccessToken: token.GetAccessToken(), + TokenType: token.GetTokenType(), + RefreshToken: token.GetRefreshToken(), + Expiry: token.GetExpiresAt().AsTime(), + } +} + +// ToOAuthToken converts an oauth2.Token to a session oauth token. +func ToOAuthToken(token *oauth2.Token) *session.OAuthToken { + expiry := timestamppb.New(token.Expiry) + return &session.OAuthToken{ + AccessToken: token.AccessToken, + TokenType: token.TokenType, + RefreshToken: token.RefreshToken, + ExpiresAt: expiry, + } +} diff --git a/internal/identity/legacymanager/sync.go b/internal/identity/legacymanager/sync.go new file mode 100644 index 000000000..94810a749 --- /dev/null +++ b/internal/identity/legacymanager/sync.go @@ -0,0 +1,55 @@ +package legacymanager + +import ( + "context" + + "github.com/pomerium/pomerium/internal/atomicutil" + "github.com/pomerium/pomerium/pkg/grpc/databroker" +) + +type dataBrokerSyncer struct { + cfg *atomicutil.Value[*config] + + update chan<- updateRecordsMessage + clear chan<- struct{} + + syncer *databroker.Syncer +} + +func newDataBrokerSyncer( + _ context.Context, + cfg *atomicutil.Value[*config], + update chan<- updateRecordsMessage, + clear chan<- struct{}, +) *dataBrokerSyncer { + syncer := &dataBrokerSyncer{ + cfg: cfg, + + update: update, + clear: clear, + } + syncer.syncer = databroker.NewSyncer("identity_manager", syncer) + return syncer +} + +func (syncer *dataBrokerSyncer) Run(ctx context.Context) (err error) { + return syncer.syncer.Run(ctx) +} + +func (syncer *dataBrokerSyncer) ClearRecords(ctx context.Context) { + select { + case <-ctx.Done(): + case syncer.clear <- struct{}{}: + } +} + +func (syncer *dataBrokerSyncer) GetDataBrokerServiceClient() databroker.DataBrokerServiceClient { + return syncer.cfg.Load().dataBrokerClient +} + +func (syncer *dataBrokerSyncer) UpdateRecords(ctx context.Context, _ uint64, records []*databroker.Record) { + select { + case <-ctx.Done(): + case syncer.update <- updateRecordsMessage{records: records}: + } +} diff --git a/internal/identity/manager/config.go b/internal/identity/manager/config.go index 1e41e767f..5fbca8b67 100644 --- a/internal/identity/manager/config.go +++ b/internal/identity/manager/config.go @@ -21,6 +21,7 @@ type config struct { updateUserInfoInterval time.Duration now func() time.Time eventMgr *events.Manager + enabled bool } func newConfig(options ...Option) *config { @@ -29,6 +30,7 @@ func newConfig(options ...Option) *config { WithSessionRefreshCoolOffDuration(defaultSessionRefreshCoolOffDuration)(cfg) WithNow(time.Now)(cfg) WithUpdateUserInfoInterval(defaultUpdateUserInfoInterval)(cfg) + WithEnabled(true)(cfg) for _, option := range options { option(cfg) } @@ -75,8 +77,15 @@ func WithNow(now func() time.Time) Option { // WithEventManager passes an event manager to record events func WithEventManager(mgr *events.Manager) Option { - return func(c *config) { - c.eventMgr = mgr + return func(cfg *config) { + cfg.eventMgr = mgr + } +} + +// WithEnabled sets the enabled option in the config. +func WithEnabled(enabled bool) Option { + return func(cfg *config) { + cfg.enabled = enabled } } diff --git a/internal/identity/manager/manager.go b/internal/identity/manager/manager.go index 6e7d14653..c68678fe8 100644 --- a/internal/identity/manager/manager.go +++ b/internal/identity/manager/manager.go @@ -16,6 +16,7 @@ import ( "google.golang.org/protobuf/types/known/timestamppb" "github.com/pomerium/pomerium/internal/atomicutil" + "github.com/pomerium/pomerium/internal/enabler" "github.com/pomerium/pomerium/internal/events" "github.com/pomerium/pomerium/internal/identity/identity" "github.com/pomerium/pomerium/internal/log" @@ -36,6 +37,7 @@ type Authenticator interface { // A Manager refreshes identity information using session and user data. type Manager struct { + enabler.Enabler cfg *atomicutil.Value[*config] mu sync.Mutex @@ -55,6 +57,7 @@ func New( refreshSessionSchedulers: make(map[string]*refreshSessionScheduler), updateUserInfoSchedulers: make(map[string]*updateUserInfoScheduler), } + mgr.Enabler = enabler.New("identity_manager", mgr, true) mgr.UpdateConfig(options...) return mgr } @@ -62,6 +65,11 @@ func New( // UpdateConfig updates the manager with the new options. func (mgr *Manager) UpdateConfig(options ...Option) { mgr.cfg.Store(newConfig(options...)) + if mgr.cfg.Load().enabled { + mgr.Enable() + } else { + mgr.Disable() + } } // GetDataBrokerServiceClient gets the databroker client. @@ -69,8 +77,8 @@ func (mgr *Manager) GetDataBrokerServiceClient() databroker.DataBrokerServiceCli return mgr.cfg.Load().dataBrokerClient } -// Run runs the manager. This method blocks until an error occurs or the given context is canceled. -func (mgr *Manager) Run(ctx context.Context) error { +// RunEnabled runs the manager. This method blocks until an error occurs or the given context is canceled. +func (mgr *Manager) RunEnabled(ctx context.Context) error { leaser := databroker.NewLeaser("identity_manager", time.Second*30, mgr) return leaser.Run(ctx) }