core/identity: refactor identity manager (#5091)

* core/identity: add data store for thread-safe storage of sessions and users

* wip

* add test

* wip

* clean up context

* fix nil session error

* add stop message

* remove log

* use origin context

* use base context for manager calls

* use manager context for syncers too

* add runtime flag

* rename legacy lease

* add comment

* use NotSame

* add comment

* Update internal/identity/manager/manager.go

Co-authored-by: Kenneth Jenkins <51246568+kenjenkins@users.noreply.github.com>

* lint

---------

Co-authored-by: Kenneth Jenkins <51246568+kenjenkins@users.noreply.github.com>
This commit is contained in:
Caleb Doxsey 2024-05-02 10:27:06 -06:00 committed by GitHub
parent e30d90206d
commit a95423b310
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
21 changed files with 2284 additions and 876 deletions

View file

@ -8,6 +8,9 @@ var (
// RuntimeFlagMatchAnyIncomingPort enables ignoring the incoming port when matching routes
RuntimeFlagMatchAnyIncomingPort = runtimeFlag("match_any_incoming_port", true)
// RuntimeFlagLegacyIdentityManager enables the legacy identity manager
RuntimeFlagLegacyIdentityManager = runtimeFlag("legacy_identity_manager", false)
)
// RuntimeFlag is a runtime flag that can flip on/off certain features

View file

@ -18,6 +18,7 @@ import (
"github.com/pomerium/pomerium/internal/atomicutil"
"github.com/pomerium/pomerium/internal/events"
"github.com/pomerium/pomerium/internal/identity"
"github.com/pomerium/pomerium/internal/identity/legacymanager"
"github.com/pomerium/pomerium/internal/identity/manager"
"github.com/pomerium/pomerium/internal/log"
"github.com/pomerium/pomerium/internal/telemetry"
@ -34,6 +35,7 @@ import (
type DataBroker struct {
dataBrokerServer *dataBrokerServer
manager *manager.Manager
legacyManager *legacymanager.Manager
eventsMgr *events.Manager
localListener net.Listener
@ -158,6 +160,12 @@ func (c *DataBroker) update(ctx context.Context, cfg *config.Config) error {
options := []manager.Option{
manager.WithDataBrokerClient(dataBrokerClient),
manager.WithEventManager(c.eventsMgr),
manager.WithEnabled(!cfg.Options.IsRuntimeFlagSet(config.RuntimeFlagLegacyIdentityManager)),
}
legacyOptions := []legacymanager.Option{
legacymanager.WithDataBrokerClient(dataBrokerClient),
legacymanager.WithEventManager(c.eventsMgr),
legacymanager.WithEnabled(cfg.Options.IsRuntimeFlagSet(config.RuntimeFlagLegacyIdentityManager)),
}
if cfg.Options.SupportsUserRefresh() {
@ -166,6 +174,7 @@ func (c *DataBroker) update(ctx context.Context, cfg *config.Config) error {
log.Error(ctx).Err(err).Msg("databroker: failed to create authenticator")
} else {
options = append(options, manager.WithAuthenticator(authenticator))
legacyOptions = append(legacyOptions, legacymanager.WithAuthenticator(authenticator))
}
} else {
log.Info(ctx).Msg("databroker: disabling refresh of user sessions")
@ -177,6 +186,12 @@ func (c *DataBroker) update(ctx context.Context, cfg *config.Config) error {
c.manager.UpdateConfig(options...)
}
if c.legacyManager == nil {
c.legacyManager = legacymanager.New(legacyOptions...)
} else {
c.legacyManager.UpdateConfig(legacyOptions...)
}
return nil
}

View file

@ -196,16 +196,15 @@ func (s *Stateful) PersistSession(
sess.SetRawIDToken(claims.RawIDToken)
sess.AddClaims(claims.Flatten())
var managerUser manager.User
managerUser.User, _ = user.Get(ctx, s.dataBrokerClient, sess.GetUserId())
if managerUser.User == nil {
u, _ := user.Get(ctx, s.dataBrokerClient, sess.GetUserId())
if u == nil {
// if no user exists yet, create a new one
managerUser.User = &user.User{
u = &user.User{
Id: sess.GetUserId(),
}
}
populateUserFromClaims(managerUser.User, claims.Claims)
_, err := databroker.Put(ctx, s.dataBrokerClient, managerUser.User)
populateUserFromClaims(u, claims.Claims)
_, err := databroker.Put(ctx, s.dataBrokerClient, u)
if err != nil {
return fmt.Errorf("authenticate: error saving user: %w", err)
}

View file

@ -0,0 +1,87 @@
package legacymanager
import (
"time"
"github.com/pomerium/pomerium/internal/events"
"github.com/pomerium/pomerium/pkg/grpc/databroker"
)
var (
defaultSessionRefreshGracePeriod = 1 * time.Minute
defaultSessionRefreshCoolOffDuration = 10 * time.Second
)
type config struct {
authenticator Authenticator
dataBrokerClient databroker.DataBrokerServiceClient
sessionRefreshGracePeriod time.Duration
sessionRefreshCoolOffDuration time.Duration
now func() time.Time
eventMgr *events.Manager
enabled bool
}
func newConfig(options ...Option) *config {
cfg := new(config)
WithSessionRefreshGracePeriod(defaultSessionRefreshGracePeriod)(cfg)
WithSessionRefreshCoolOffDuration(defaultSessionRefreshCoolOffDuration)(cfg)
WithNow(time.Now)(cfg)
WithEnabled(true)(cfg)
for _, option := range options {
option(cfg)
}
return cfg
}
// An Option customizes the configuration used for the identity manager.
type Option func(*config)
// WithAuthenticator sets the authenticator in the config.
func WithAuthenticator(authenticator Authenticator) Option {
return func(cfg *config) {
cfg.authenticator = authenticator
}
}
// WithDataBrokerClient sets the databroker client in the config.
func WithDataBrokerClient(dataBrokerClient databroker.DataBrokerServiceClient) Option {
return func(cfg *config) {
cfg.dataBrokerClient = dataBrokerClient
}
}
// WithSessionRefreshGracePeriod sets the session refresh grace period used by the manager.
func WithSessionRefreshGracePeriod(dur time.Duration) Option {
return func(cfg *config) {
cfg.sessionRefreshGracePeriod = dur
}
}
// WithSessionRefreshCoolOffDuration sets the session refresh cool-off duration used by the manager.
func WithSessionRefreshCoolOffDuration(dur time.Duration) Option {
return func(cfg *config) {
cfg.sessionRefreshCoolOffDuration = dur
}
}
// WithNow customizes the time.Now function used by the manager.
func WithNow(now func() time.Time) Option {
return func(cfg *config) {
cfg.now = now
}
}
// WithEventManager passes an event manager to record events
func WithEventManager(mgr *events.Manager) Option {
return func(cfg *config) {
cfg.eventMgr = mgr
}
}
// WithEnabled sets the enabled option in the config.
func WithEnabled(enabled bool) Option {
return func(cfg *config) {
cfg.enabled = enabled
}
}

View file

@ -0,0 +1,266 @@
package legacymanager
import (
"encoding/json"
"time"
"github.com/google/btree"
"google.golang.org/protobuf/types/known/timestamppb"
"github.com/pomerium/pomerium/internal/identity"
"github.com/pomerium/pomerium/pkg/grpc/session"
"github.com/pomerium/pomerium/pkg/grpc/user"
)
const userRefreshInterval = 10 * time.Minute
// A User is a user managed by the Manager.
type User struct {
*user.User
lastRefresh time.Time
}
// NextRefresh returns the next time the user information needs to be refreshed.
func (u User) NextRefresh() time.Time {
return u.lastRefresh.Add(userRefreshInterval)
}
// UnmarshalJSON unmarshals json data into the user object.
func (u *User) UnmarshalJSON(data []byte) error {
if u.User == nil {
u.User = new(user.User)
}
var raw map[string]json.RawMessage
err := json.Unmarshal(data, &raw)
if err != nil {
return err
}
if name, ok := raw["name"]; ok {
_ = json.Unmarshal(name, &u.User.Name)
delete(raw, "name")
}
if email, ok := raw["email"]; ok {
_ = json.Unmarshal(email, &u.User.Email)
delete(raw, "email")
}
u.AddClaims(identity.NewClaimsFromRaw(raw).Flatten())
return nil
}
// A Session is a session managed by the Manager.
type Session struct {
*session.Session
// lastRefresh is the time of the last refresh attempt (which may or may
// not have succeeded), or else the time the Manager first became aware of
// the session (if it has not yet attempted to refresh this session).
lastRefresh time.Time
// gracePeriod is the amount of time before expiration to attempt a refresh.
gracePeriod time.Duration
// coolOffDuration is the amount of time to wait before attempting another refresh.
coolOffDuration time.Duration
}
// NextRefresh returns the next time the session needs to be refreshed.
func (s Session) NextRefresh() time.Time {
var tm time.Time
if s.GetOauthToken().GetExpiresAt() != nil {
expiry := s.GetOauthToken().GetExpiresAt().AsTime()
if s.GetOauthToken().GetExpiresAt().IsValid() && !expiry.IsZero() {
expiry = expiry.Add(-s.gracePeriod)
if tm.IsZero() || expiry.Before(tm) {
tm = expiry
}
}
}
if s.GetExpiresAt() != nil {
expiry := s.GetExpiresAt().AsTime()
if s.GetExpiresAt().IsValid() && !expiry.IsZero() {
if tm.IsZero() || expiry.Before(tm) {
tm = expiry
}
}
}
// don't refresh any quicker than the cool-off duration
min := s.lastRefresh.Add(s.coolOffDuration)
if tm.Before(min) {
tm = min
}
return tm
}
// UnmarshalJSON unmarshals json data into the session object.
func (s *Session) UnmarshalJSON(data []byte) error {
if s.Session == nil {
s.Session = new(session.Session)
}
var raw map[string]json.RawMessage
err := json.Unmarshal(data, &raw)
if err != nil {
return err
}
if s.Session.IdToken == nil {
s.Session.IdToken = new(session.IDToken)
}
if iss, ok := raw["iss"]; ok {
_ = json.Unmarshal(iss, &s.Session.IdToken.Issuer)
delete(raw, "iss")
}
if sub, ok := raw["sub"]; ok {
_ = json.Unmarshal(sub, &s.Session.IdToken.Subject)
delete(raw, "sub")
}
if exp, ok := raw["exp"]; ok {
var secs int64
if err := json.Unmarshal(exp, &secs); err == nil {
s.Session.IdToken.ExpiresAt = timestamppb.New(time.Unix(secs, 0))
}
delete(raw, "exp")
}
if iat, ok := raw["iat"]; ok {
var secs int64
if err := json.Unmarshal(iat, &secs); err == nil {
s.Session.IdToken.IssuedAt = timestamppb.New(time.Unix(secs, 0))
}
delete(raw, "iat")
}
s.AddClaims(identity.NewClaimsFromRaw(raw).Flatten())
return nil
}
type sessionCollectionItem struct {
Session
}
func (item sessionCollectionItem) Less(than btree.Item) bool {
xUserID, yUserID := item.GetUserId(), than.(sessionCollectionItem).GetUserId()
switch {
case xUserID < yUserID:
return true
case yUserID < xUserID:
return false
}
xID, yID := item.GetId(), than.(sessionCollectionItem).GetId()
switch {
case xID < yID:
return true
case yID < xID:
return false
}
return false
}
type sessionCollection struct {
*btree.BTree
}
func (c *sessionCollection) Delete(userID, sessionID string) {
c.BTree.Delete(sessionCollectionItem{
Session: Session{
Session: &session.Session{
UserId: userID,
Id: sessionID,
},
},
})
}
func (c *sessionCollection) Get(userID, sessionID string) (Session, bool) {
item := c.BTree.Get(sessionCollectionItem{
Session: Session{
Session: &session.Session{
UserId: userID,
Id: sessionID,
},
},
})
if item == nil {
return Session{}, false
}
return item.(sessionCollectionItem).Session, true
}
// GetSessionsForUser gets all the sessions for the given user.
func (c *sessionCollection) GetSessionsForUser(userID string) []Session {
var sessions []Session
c.AscendGreaterOrEqual(sessionCollectionItem{
Session: Session{
Session: &session.Session{
UserId: userID,
},
},
}, func(item btree.Item) bool {
s := item.(sessionCollectionItem).Session
if s.UserId != userID {
return false
}
sessions = append(sessions, s)
return true
})
return sessions
}
func (c *sessionCollection) ReplaceOrInsert(s Session) {
c.BTree.ReplaceOrInsert(sessionCollectionItem{Session: s})
}
type userCollectionItem struct {
User
}
func (item userCollectionItem) Less(than btree.Item) bool {
xID, yID := item.GetId(), than.(userCollectionItem).GetId()
switch {
case xID < yID:
return true
case yID < xID:
return false
}
return false
}
type userCollection struct {
*btree.BTree
}
func (c *userCollection) Delete(userID string) {
c.BTree.Delete(userCollectionItem{
User: User{
User: &user.User{
Id: userID,
},
},
})
}
func (c *userCollection) Get(userID string) (User, bool) {
item := c.BTree.Get(userCollectionItem{
User: User{
User: &user.User{
Id: userID,
},
},
})
if item == nil {
return User{}, false
}
return item.(userCollectionItem).User, true
}
func (c *userCollection) ReplaceOrInsert(u User) {
c.BTree.ReplaceOrInsert(userCollectionItem{User: u})
}

View file

@ -0,0 +1,74 @@
package legacymanager
import (
"encoding/json"
"fmt"
"testing"
"time"
"github.com/stretchr/testify/assert"
"google.golang.org/protobuf/types/known/structpb"
"google.golang.org/protobuf/types/known/timestamppb"
"github.com/pomerium/pomerium/pkg/grpc/session"
"github.com/pomerium/pomerium/pkg/protoutil"
)
func TestUser_UnmarshalJSON(t *testing.T) {
var u User
err := json.Unmarshal([]byte(`{
"name": "joe",
"email": "joe@test.com",
"some-other-claim": "xyz"
}`), &u)
assert.NoError(t, err)
assert.NotNil(t, u.User)
assert.Equal(t, "joe", u.User.Name)
assert.Equal(t, "joe@test.com", u.User.Email)
assert.Equal(t, map[string]*structpb.ListValue{
"some-other-claim": {Values: []*structpb.Value{protoutil.ToStruct("xyz")}},
}, u.Claims)
}
func TestSession_NextRefresh(t *testing.T) {
tm1 := time.Date(2020, 6, 5, 12, 0, 0, 0, time.UTC)
s := Session{
Session: &session.Session{},
lastRefresh: tm1,
gracePeriod: time.Second * 10,
coolOffDuration: time.Minute,
}
assert.Equal(t, tm1.Add(time.Minute), s.NextRefresh())
tm2 := time.Date(2020, 6, 5, 13, 0, 0, 0, time.UTC)
s.OauthToken = &session.OAuthToken{
ExpiresAt: timestamppb.New(tm2),
}
assert.Equal(t, tm2.Add(-time.Second*10), s.NextRefresh())
tm3 := time.Date(2020, 6, 5, 12, 15, 0, 0, time.UTC)
s.ExpiresAt = timestamppb.New(tm3)
assert.Equal(t, tm3, s.NextRefresh())
}
func TestSession_UnmarshalJSON(t *testing.T) {
tm := time.Date(2020, 6, 5, 12, 0, 0, 0, time.UTC)
var s Session
err := json.Unmarshal([]byte(`{
"iss": "https://some.issuer.com",
"sub": "subject",
"exp": `+fmt.Sprint(tm.Unix())+`,
"iat": `+fmt.Sprint(tm.Unix())+`,
"some-other-claim": "xyz"
}`), &s)
assert.NoError(t, err)
assert.NotNil(t, s.Session)
assert.NotNil(t, s.Session.IdToken)
assert.Equal(t, "https://some.issuer.com", s.Session.IdToken.Issuer)
assert.Equal(t, "subject", s.Session.IdToken.Subject)
assert.Equal(t, timestamppb.New(tm), s.Session.IdToken.ExpiresAt)
assert.Equal(t, timestamppb.New(tm), s.Session.IdToken.IssuedAt)
assert.Equal(t, map[string]*structpb.ListValue{
"some-other-claim": {Values: []*structpb.Value{protoutil.ToStruct("xyz")}},
}, s.Claims)
}

View file

@ -0,0 +1,497 @@
// Package legacymanager contains an identity manager responsible for refreshing sessions and creating users.
package legacymanager
import (
"context"
"errors"
"time"
"github.com/google/btree"
"github.com/rs/zerolog"
"golang.org/x/oauth2"
"golang.org/x/sync/errgroup"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"google.golang.org/protobuf/types/known/fieldmaskpb"
"google.golang.org/protobuf/types/known/timestamppb"
"github.com/pomerium/pomerium/internal/atomicutil"
"github.com/pomerium/pomerium/internal/enabler"
"github.com/pomerium/pomerium/internal/events"
"github.com/pomerium/pomerium/internal/identity/identity"
"github.com/pomerium/pomerium/internal/log"
"github.com/pomerium/pomerium/internal/scheduler"
"github.com/pomerium/pomerium/internal/telemetry/metrics"
"github.com/pomerium/pomerium/pkg/grpc/databroker"
"github.com/pomerium/pomerium/pkg/grpc/session"
"github.com/pomerium/pomerium/pkg/grpc/user"
"github.com/pomerium/pomerium/pkg/grpcutil"
metrics_ids "github.com/pomerium/pomerium/pkg/metrics"
)
// Authenticator is an identity.Provider with only the methods needed by the manager.
type Authenticator interface {
Refresh(context.Context, *oauth2.Token, identity.State) (*oauth2.Token, error)
Revoke(context.Context, *oauth2.Token) error
UpdateUserInfo(context.Context, *oauth2.Token, interface{}) error
}
type (
updateRecordsMessage struct {
records []*databroker.Record
}
)
// A Manager refreshes identity information using session and user data.
type Manager struct {
enabler.Enabler
cfg *atomicutil.Value[*config]
sessionScheduler *scheduler.Scheduler
userScheduler *scheduler.Scheduler
sessions sessionCollection
users userCollection
}
// New creates a new identity manager.
func New(
options ...Option,
) *Manager {
mgr := &Manager{
cfg: atomicutil.NewValue(newConfig()),
sessionScheduler: scheduler.New(),
userScheduler: scheduler.New(),
}
mgr.Enabler = enabler.New("identity_manager", mgr, true)
mgr.reset()
mgr.UpdateConfig(options...)
return mgr
}
func withLog(ctx context.Context) context.Context {
return log.WithContext(ctx, func(c zerolog.Context) zerolog.Context {
return c.Str("service", "identity_manager")
})
}
// UpdateConfig updates the manager with the new options.
func (mgr *Manager) UpdateConfig(options ...Option) {
mgr.cfg.Store(newConfig(options...))
if mgr.cfg.Load().enabled {
mgr.Enable()
} else {
mgr.Disable()
}
}
// RunEnabled runs the manager. This method blocks until an error occurs or the given context is canceled.
func (mgr *Manager) RunEnabled(ctx context.Context) error {
leaser := databroker.NewLeaser("identity_manager", time.Second*30, mgr)
return leaser.Run(ctx)
}
// RunLeased runs the identity manager when a lease is acquired.
func (mgr *Manager) RunLeased(ctx context.Context) error {
ctx = withLog(ctx)
update := make(chan updateRecordsMessage, 1)
clear := make(chan struct{}, 1)
syncer := newDataBrokerSyncer(ctx, mgr.cfg, update, clear)
eg, ctx := errgroup.WithContext(ctx)
eg.Go(func() error {
return syncer.Run(ctx)
})
eg.Go(func() error {
return mgr.refreshLoop(ctx, update, clear)
})
return eg.Wait()
}
// GetDataBrokerServiceClient gets the databroker client.
func (mgr *Manager) GetDataBrokerServiceClient() databroker.DataBrokerServiceClient {
return mgr.cfg.Load().dataBrokerClient
}
func (mgr *Manager) now() time.Time {
return mgr.cfg.Load().now()
}
func (mgr *Manager) refreshLoop(ctx context.Context, update <-chan updateRecordsMessage, clear <-chan struct{}) error {
// wait for initial sync
select {
case <-ctx.Done():
return ctx.Err()
case <-clear:
mgr.reset()
}
select {
case <-ctx.Done():
return ctx.Err()
case msg := <-update:
mgr.onUpdateRecords(ctx, msg)
}
log.Debug(ctx).
Int("sessions", mgr.sessions.Len()).
Int("users", mgr.users.Len()).
Msg("initial sync complete")
// start refreshing
maxWait := time.Minute * 10
var nextTime time.Time
timer := time.NewTimer(0)
defer timer.Stop()
for {
select {
case <-ctx.Done():
return ctx.Err()
case <-clear:
mgr.reset()
case msg := <-update:
mgr.onUpdateRecords(ctx, msg)
case <-timer.C:
}
now := mgr.now()
nextTime = now.Add(maxWait)
// refresh sessions
for {
tm, key := mgr.sessionScheduler.Next()
if now.Before(tm) {
if tm.Before(nextTime) {
nextTime = tm
}
break
}
mgr.sessionScheduler.Remove(key)
userID, sessionID := fromSessionSchedulerKey(key)
mgr.refreshSession(ctx, userID, sessionID)
}
// refresh users
for {
tm, key := mgr.userScheduler.Next()
if now.Before(tm) {
if tm.Before(nextTime) {
nextTime = tm
}
break
}
mgr.userScheduler.Remove(key)
mgr.refreshUser(ctx, key)
}
metrics.RecordIdentityManagerLastRefresh(ctx)
timer.Reset(time.Until(nextTime))
}
}
// refreshSession handles two distinct session lifecycle events:
//
// 1. If the session itself has expired, delete the session.
// 2. If the session's underlying OAuth2 access token is nearing expiration
// (but the session itself is still valid), refresh the access token.
//
// After a successful access token refresh, this method will also trigger a
// user info refresh. If an access token refresh or a user info refresh fails
// with a permanent error, the session will be deleted.
func (mgr *Manager) refreshSession(ctx context.Context, userID, sessionID string) {
log.Info(ctx).
Str("user_id", userID).
Str("session_id", sessionID).
Msg("refreshing session")
s, ok := mgr.sessions.Get(userID, sessionID)
if !ok {
log.Warn(ctx).
Str("user_id", userID).
Str("session_id", sessionID).
Msg("no session found for refresh")
return
}
s.lastRefresh = mgr.now()
if mgr.refreshSessionInternal(ctx, userID, sessionID, &s) {
mgr.sessions.ReplaceOrInsert(s)
mgr.sessionScheduler.Add(s.NextRefresh(), toSessionSchedulerKey(userID, sessionID))
}
}
// refreshSessionInternal performs the core refresh logic and returns true if
// the session should be scheduled for refresh again, or false if not.
func (mgr *Manager) refreshSessionInternal(
ctx context.Context, userID, sessionID string, s *Session,
) bool {
authenticator := mgr.cfg.Load().authenticator
if authenticator == nil {
log.Info(ctx).
Str("user_id", userID).
Str("session_id", sessionID).
Msg("no authenticator defined, deleting session")
mgr.deleteSession(ctx, userID, sessionID)
return false
}
expiry := s.GetExpiresAt().AsTime()
if !expiry.After(mgr.now()) {
log.Info(ctx).
Str("user_id", userID).
Str("session_id", sessionID).
Msg("deleting expired session")
mgr.deleteSession(ctx, userID, sessionID)
return false
}
if s.Session == nil || s.Session.OauthToken == nil {
log.Warn(ctx).
Str("user_id", userID).
Str("session_id", sessionID).
Msg("no session oauth2 token found for refresh")
return false
}
newToken, err := authenticator.Refresh(ctx, FromOAuthToken(s.OauthToken), s)
metrics.RecordIdentityManagerSessionRefresh(ctx, err)
mgr.recordLastError(metrics_ids.IdentityManagerLastSessionRefreshError, err)
if isTemporaryError(err) {
log.Error(ctx).Err(err).
Str("user_id", s.GetUserId()).
Str("session_id", s.GetId()).
Msg("failed to refresh oauth2 token")
return true
} else if err != nil {
log.Error(ctx).Err(err).
Str("user_id", s.GetUserId()).
Str("session_id", s.GetId()).
Msg("failed to refresh oauth2 token, deleting session")
mgr.deleteSession(ctx, userID, sessionID)
return false
}
s.OauthToken = ToOAuthToken(newToken)
err = authenticator.UpdateUserInfo(ctx, FromOAuthToken(s.OauthToken), s)
metrics.RecordIdentityManagerUserRefresh(ctx, err)
mgr.recordLastError(metrics_ids.IdentityManagerLastUserRefreshError, err)
if isTemporaryError(err) {
log.Error(ctx).Err(err).
Str("user_id", s.GetUserId()).
Str("session_id", s.GetId()).
Msg("failed to update user info")
return true
} else if err != nil {
log.Error(ctx).Err(err).
Str("user_id", s.GetUserId()).
Str("session_id", s.GetId()).
Msg("failed to update user info, deleting session")
mgr.deleteSession(ctx, userID, sessionID)
return false
}
fm, err := fieldmaskpb.New(s.Session, "oauth_token", "id_token", "claims")
if err != nil {
log.Error(ctx).Err(err).Msg("internal error")
return false
}
if _, err := session.Patch(ctx, mgr.cfg.Load().dataBrokerClient, s.Session, fm); err != nil {
log.Error(ctx).Err(err).
Str("user_id", s.GetUserId()).
Str("session_id", s.GetId()).
Msg("failed to update session")
}
return true
}
func (mgr *Manager) refreshUser(ctx context.Context, userID string) {
log.Info(ctx).
Str("user_id", userID).
Msg("refreshing user")
authenticator := mgr.cfg.Load().authenticator
if authenticator == nil {
return
}
u, ok := mgr.users.Get(userID)
if !ok {
log.Warn(ctx).
Str("user_id", userID).
Msg("no user found for refresh")
return
}
u.lastRefresh = mgr.now()
mgr.userScheduler.Add(u.NextRefresh(), u.GetId())
for _, s := range mgr.sessions.GetSessionsForUser(userID) {
if s.Session == nil || s.Session.OauthToken == nil {
log.Warn(ctx).
Str("user_id", userID).
Msg("no session oauth2 token found for refresh")
continue
}
err := authenticator.UpdateUserInfo(ctx, FromOAuthToken(s.OauthToken), &u)
metrics.RecordIdentityManagerUserRefresh(ctx, err)
mgr.recordLastError(metrics_ids.IdentityManagerLastUserRefreshError, err)
if isTemporaryError(err) {
log.Error(ctx).Err(err).
Str("user_id", s.GetUserId()).
Str("session_id", s.GetId()).
Msg("failed to update user info")
return
} else if err != nil {
log.Error(ctx).Err(err).
Str("user_id", s.GetUserId()).
Str("session_id", s.GetId()).
Msg("failed to update user info, deleting session")
mgr.deleteSession(ctx, userID, s.GetId())
continue
}
res, err := databroker.Put(ctx, mgr.cfg.Load().dataBrokerClient, u.User)
if err != nil {
log.Error(ctx).Err(err).
Str("user_id", s.GetUserId()).
Str("session_id", s.GetId()).
Msg("failed to update user")
continue
}
mgr.onUpdateUser(ctx, res.GetRecords()[0], u.User)
}
}
func (mgr *Manager) onUpdateRecords(ctx context.Context, msg updateRecordsMessage) {
for _, record := range msg.records {
switch record.GetType() {
case grpcutil.GetTypeURL(new(session.Session)):
var pbSession session.Session
err := record.GetData().UnmarshalTo(&pbSession)
if err != nil {
log.Warn(ctx).Msgf("error unmarshaling session: %s", err)
continue
}
mgr.onUpdateSession(record, &pbSession)
case grpcutil.GetTypeURL(new(user.User)):
var pbUser user.User
err := record.GetData().UnmarshalTo(&pbUser)
if err != nil {
log.Warn(ctx).Msgf("error unmarshaling user: %s", err)
continue
}
mgr.onUpdateUser(ctx, record, &pbUser)
}
}
}
func (mgr *Manager) onUpdateSession(record *databroker.Record, session *session.Session) {
mgr.sessionScheduler.Remove(toSessionSchedulerKey(session.GetUserId(), session.GetId()))
if record.GetDeletedAt() != nil {
mgr.sessions.Delete(session.GetUserId(), session.GetId())
return
}
// update session
s, _ := mgr.sessions.Get(session.GetUserId(), session.GetId())
if s.lastRefresh.IsZero() {
s.lastRefresh = mgr.now()
}
s.gracePeriod = mgr.cfg.Load().sessionRefreshGracePeriod
s.coolOffDuration = mgr.cfg.Load().sessionRefreshCoolOffDuration
s.Session = session
mgr.sessions.ReplaceOrInsert(s)
mgr.sessionScheduler.Add(s.NextRefresh(), toSessionSchedulerKey(session.GetUserId(), session.GetId()))
}
func (mgr *Manager) onUpdateUser(_ context.Context, record *databroker.Record, user *user.User) {
mgr.userScheduler.Remove(user.GetId())
if record.GetDeletedAt() != nil {
mgr.users.Delete(user.GetId())
return
}
u, _ := mgr.users.Get(user.GetId())
u.lastRefresh = mgr.cfg.Load().now()
u.User = user
mgr.users.ReplaceOrInsert(u)
mgr.userScheduler.Add(u.NextRefresh(), u.GetId())
}
func (mgr *Manager) deleteSession(ctx context.Context, userID, sessionID string) {
mgr.sessionScheduler.Remove(toSessionSchedulerKey(userID, sessionID))
mgr.sessions.Delete(userID, sessionID)
client := mgr.cfg.Load().dataBrokerClient
res, err := client.Get(ctx, &databroker.GetRequest{
Type: grpcutil.GetTypeURL(new(session.Session)),
Id: sessionID,
})
if status.Code(err) == codes.NotFound {
return
} else if err != nil {
log.Error(ctx).Err(err).
Str("session_id", sessionID).
Msg("failed to delete session")
return
}
record := res.GetRecord()
record.DeletedAt = timestamppb.Now()
_, err = client.Put(ctx, &databroker.PutRequest{
Records: []*databroker.Record{record},
})
if err != nil {
log.Error(ctx).Err(err).
Str("session_id", sessionID).
Msg("failed to delete session")
return
}
}
// reset resets all the manager datastructures to their initial state
func (mgr *Manager) reset() {
mgr.sessions = sessionCollection{BTree: btree.New(8)}
mgr.users = userCollection{BTree: btree.New(8)}
}
func (mgr *Manager) recordLastError(id string, err error) {
if err == nil {
return
}
evtMgr := mgr.cfg.Load().eventMgr
if evtMgr == nil {
return
}
evtMgr.Dispatch(&events.LastError{
Time: timestamppb.Now(),
Message: err.Error(),
Id: id,
})
}
func isTemporaryError(err error) bool {
if err == nil {
return false
}
if errors.Is(err, context.DeadlineExceeded) || errors.Is(err, context.Canceled) {
return true
}
var hasTemporary interface{ Temporary() bool }
if errors.As(err, &hasTemporary) && hasTemporary.Temporary() {
return true
}
return false
}

View file

@ -0,0 +1,360 @@
package legacymanager
import (
"context"
"errors"
"fmt"
"testing"
"time"
"github.com/stretchr/testify/assert"
"go.uber.org/mock/gomock"
"golang.org/x/oauth2"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/types/known/fieldmaskpb"
"google.golang.org/protobuf/types/known/timestamppb"
"github.com/pomerium/pomerium/internal/events"
"github.com/pomerium/pomerium/internal/identity/identity"
"github.com/pomerium/pomerium/pkg/grpc/databroker"
"github.com/pomerium/pomerium/pkg/grpc/databroker/mock_databroker"
"github.com/pomerium/pomerium/pkg/grpc/session"
"github.com/pomerium/pomerium/pkg/grpc/user"
metrics_ids "github.com/pomerium/pomerium/pkg/metrics"
"github.com/pomerium/pomerium/pkg/protoutil"
)
type mockAuthenticator struct {
refreshResult *oauth2.Token
refreshError error
revokeError error
updateUserInfoError error
}
func (mock *mockAuthenticator) Refresh(_ context.Context, _ *oauth2.Token, _ identity.State) (*oauth2.Token, error) {
return mock.refreshResult, mock.refreshError
}
func (mock *mockAuthenticator) Revoke(_ context.Context, _ *oauth2.Token) error {
return mock.revokeError
}
func (mock *mockAuthenticator) UpdateUserInfo(_ context.Context, _ *oauth2.Token, _ any) error {
return mock.updateUserInfoError
}
func TestManager_refresh(t *testing.T) {
ctrl := gomock.NewController(t)
ctx, clearTimeout := context.WithTimeout(context.Background(), time.Second*10)
t.Cleanup(clearTimeout)
client := mock_databroker.NewMockDataBrokerServiceClient(ctrl)
mgr := New(WithDataBrokerClient(client))
mgr.onUpdateRecords(ctx, updateRecordsMessage{
records: []*databroker.Record{
databroker.NewRecord(&session.Session{
Id: "s1",
UserId: "u1",
OauthToken: &session.OAuthToken{},
ExpiresAt: timestamppb.New(time.Now().Add(time.Second * 10)),
}),
databroker.NewRecord(&user.User{
Id: "u1",
}),
},
})
client.EXPECT().Get(gomock.Any(), gomock.Any()).Return(nil, status.Error(codes.NotFound, "not found"))
mgr.refreshSession(ctx, "u1", "s1")
mgr.refreshUser(ctx, "u1")
}
func TestManager_onUpdateRecords(t *testing.T) {
ctrl := gomock.NewController(t)
ctx, clearTimeout := context.WithTimeout(context.Background(), time.Second*10)
defer clearTimeout()
now := time.Now()
mgr := New(
WithDataBrokerClient(mock_databroker.NewMockDataBrokerServiceClient(ctrl)),
WithNow(func() time.Time {
return now
}),
)
mgr.onUpdateRecords(ctx, updateRecordsMessage{
records: []*databroker.Record{
mkRecord(&session.Session{Id: "session1", UserId: "user1"}),
mkRecord(&user.User{Id: "user1", Name: "user 1", Email: "user1@example.com"}),
},
})
if _, ok := mgr.sessions.Get("user1", "session1"); assert.True(t, ok) {
tm, id := mgr.sessionScheduler.Next()
assert.Equal(t, now.Add(10*time.Second), tm)
assert.Equal(t, "user1\037session1", id)
}
if _, ok := mgr.users.Get("user1"); assert.True(t, ok) {
tm, id := mgr.userScheduler.Next()
assert.Equal(t, now.Add(userRefreshInterval), tm)
assert.Equal(t, "user1", id)
}
}
func TestManager_onUpdateSession(t *testing.T) {
startTime := time.Date(2023, 10, 19, 12, 0, 0, 0, time.UTC)
s := &session.Session{
Id: "session-id",
UserId: "user-id",
OauthToken: &session.OAuthToken{
AccessToken: "access-token",
ExpiresAt: timestamppb.New(startTime.Add(5 * time.Minute)),
},
IssuedAt: timestamppb.New(startTime),
ExpiresAt: timestamppb.New(startTime.Add(24 * time.Hour)),
}
assertNextScheduled := func(t *testing.T, mgr *Manager, expectedTime time.Time) {
t.Helper()
tm, key := mgr.sessionScheduler.Next()
assert.Equal(t, expectedTime, tm)
assert.Equal(t, "user-id\037session-id", key)
}
t.Run("initial refresh event when not expiring soon", func(t *testing.T) {
now := startTime
mgr := New(WithNow(func() time.Time { return now }))
// When the Manager first becomes aware of a session it should schedule
// a refresh event for one minute before access token expiration.
mgr.onUpdateSession(mkRecord(s), s)
assertNextScheduled(t, mgr, startTime.Add(4*time.Minute))
})
t.Run("initial refresh event when expiring soon", func(t *testing.T) {
now := startTime
mgr := New(WithNow(func() time.Time { return now }))
// When the Manager first becomes aware of a session, if that session
// is expiring within the gracePeriod (1 minute), it should schedule a
// refresh event for as soon as possible, subject to the
// coolOffDuration (10 seconds).
now = now.Add(4*time.Minute + 30*time.Second) // 30 s before expiration
mgr.onUpdateSession(mkRecord(s), s)
assertNextScheduled(t, mgr, now.Add(10*time.Second))
})
t.Run("update near scheduled refresh", func(t *testing.T) {
now := startTime
mgr := New(WithNow(func() time.Time { return now }))
mgr.onUpdateSession(mkRecord(s), s)
assertNextScheduled(t, mgr, startTime.Add(4*time.Minute))
// If a session is updated close to the time when it is scheduled to be
// refreshed, the scheduled refresh event should not be pushed back.
now = now.Add(3*time.Minute + 55*time.Second) // 5 s before refresh
mgr.onUpdateSession(mkRecord(s), s)
assertNextScheduled(t, mgr, now.Add(5*time.Second))
// However, if an update changes the access token validity, the refresh
// event should be rescheduled accordingly. (This should be uncommon,
// as only the refresh loop itself should modify the access token.)
s2 := proto.Clone(s).(*session.Session)
s2.OauthToken.ExpiresAt = timestamppb.New(now.Add(5 * time.Minute))
mgr.onUpdateSession(mkRecord(s2), s2)
assertNextScheduled(t, mgr, now.Add(4*time.Minute))
})
t.Run("session record deleted", func(t *testing.T) {
now := startTime
mgr := New(WithNow(func() time.Time { return now }))
mgr.onUpdateSession(mkRecord(s), s)
assertNextScheduled(t, mgr, startTime.Add(4*time.Minute))
// If a session is deleted, any scheduled refresh event should be canceled.
record := mkRecord(s)
record.DeletedAt = timestamppb.New(now)
mgr.onUpdateSession(record, s)
_, key := mgr.sessionScheduler.Next()
assert.Empty(t, key)
})
}
func TestManager_refreshSession(t *testing.T) {
startTime := time.Date(2023, 10, 19, 12, 0, 0, 0, time.UTC)
var auth mockAuthenticator
ctrl := gomock.NewController(t)
client := mock_databroker.NewMockDataBrokerServiceClient(ctrl)
now := startTime
mgr := New(
WithDataBrokerClient(client),
WithNow(func() time.Time { return now }),
WithAuthenticator(&auth),
)
// Initialize the Manager with a new session.
s := &session.Session{
Id: "session-id",
UserId: "user-id",
OauthToken: &session.OAuthToken{
AccessToken: "access-token",
ExpiresAt: timestamppb.New(startTime.Add(5 * time.Minute)),
RefreshToken: "refresh-token",
},
IssuedAt: timestamppb.New(startTime),
ExpiresAt: timestamppb.New(startTime.Add(24 * time.Hour)),
}
mgr.sessions.ReplaceOrInsert(Session{
Session: s,
lastRefresh: startTime,
gracePeriod: time.Minute,
coolOffDuration: 10 * time.Second,
})
// If OAuth2 token refresh fails with a temporary error, the manager should
// still reschedule another refresh attempt.
now = now.Add(4 * time.Minute)
auth.refreshError = context.DeadlineExceeded
mgr.refreshSession(context.Background(), "user-id", "session-id")
tm, key := mgr.sessionScheduler.Next()
assert.Equal(t, now.Add(10*time.Second), tm)
assert.Equal(t, "user-id\037session-id", key)
// Simulate a successful token refresh on the second attempt. The manager
// should store the updated session in the databroker and schedule another
// refresh event.
now = now.Add(10 * time.Second)
auth.refreshResult, auth.refreshError = &oauth2.Token{
AccessToken: "new-access-token",
RefreshToken: "new-refresh-token",
Expiry: now.Add(5 * time.Minute),
}, nil
expectedSession := proto.Clone(s).(*session.Session)
expectedSession.OauthToken = &session.OAuthToken{
AccessToken: "new-access-token",
ExpiresAt: timestamppb.New(now.Add(5 * time.Minute)),
RefreshToken: "new-refresh-token",
}
client.EXPECT().Patch(gomock.Any(), objectsAreEqualMatcher{
&databroker.PatchRequest{
Records: []*databroker.Record{{
Type: "type.googleapis.com/session.Session",
Id: "session-id",
Data: protoutil.NewAny(expectedSession),
}},
FieldMask: &fieldmaskpb.FieldMask{
Paths: []string{"oauth_token", "id_token", "claims"},
},
},
}).
Return(nil /* this result is currently unused */, nil)
mgr.refreshSession(context.Background(), "user-id", "session-id")
tm, key = mgr.sessionScheduler.Next()
assert.Equal(t, now.Add(4*time.Minute), tm)
assert.Equal(t, "user-id\037session-id", key)
}
func TestManager_reportErrors(t *testing.T) {
ctrl := gomock.NewController(t)
ctx, clearTimeout := context.WithTimeout(context.Background(), time.Second*10)
defer clearTimeout()
evtMgr := events.New()
received := make(chan events.Event, 1)
handle := evtMgr.Register(func(evt events.Event) {
received <- evt
})
defer evtMgr.Unregister(handle)
expectMsg := func(id, msg string) {
t.Helper()
assert.Eventually(t, func() bool {
select {
case evt := <-received:
lastErr := evt.(*events.LastError)
return msg == lastErr.Message && id == lastErr.Id
default:
return false
}
}, time.Second, time.Millisecond*20, msg)
}
s := &session.Session{
Id: "session1",
UserId: "user1",
OauthToken: &session.OAuthToken{
ExpiresAt: timestamppb.New(time.Now().Add(time.Hour)),
},
ExpiresAt: timestamppb.New(time.Now().Add(time.Hour)),
}
client := mock_databroker.NewMockDataBrokerServiceClient(ctrl)
client.EXPECT().Get(gomock.Any(), gomock.Any()).AnyTimes().Return(&databroker.GetResponse{Record: databroker.NewRecord(s)}, nil)
client.EXPECT().Put(gomock.Any(), gomock.Any()).AnyTimes()
mgr := New(
WithEventManager(evtMgr),
WithDataBrokerClient(client),
WithAuthenticator(&mockAuthenticator{
refreshError: errors.New("update session"),
updateUserInfoError: errors.New("update user info"),
}),
)
mgr.onUpdateRecords(ctx, updateRecordsMessage{
records: []*databroker.Record{
mkRecord(s),
mkRecord(&user.User{Id: "user1", Name: "user 1", Email: "user1@example.com"}),
},
})
mgr.refreshUser(ctx, "user1")
expectMsg(metrics_ids.IdentityManagerLastUserRefreshError, "update user info")
mgr.onUpdateRecords(ctx, updateRecordsMessage{
records: []*databroker.Record{
mkRecord(s),
mkRecord(&user.User{Id: "user1", Name: "user 1", Email: "user1@example.com"}),
},
})
mgr.refreshSession(ctx, "user1", "session1")
expectMsg(metrics_ids.IdentityManagerLastSessionRefreshError, "update session")
}
func mkRecord(msg recordable) *databroker.Record {
data := protoutil.NewAny(msg)
return &databroker.Record{
Type: data.GetTypeUrl(),
Id: msg.GetId(),
Data: data,
}
}
type recordable interface {
proto.Message
GetId() string
}
// objectsAreEqualMatcher implements gomock.Matcher using ObjectsAreEqual. This
// is especially helpful when working with pointers, as it will compare the
// underlying values rather than the pointers themselves.
type objectsAreEqualMatcher struct {
expected interface{}
}
func (m objectsAreEqualMatcher) Matches(x interface{}) bool {
return assert.ObjectsAreEqual(m.expected, x)
}
func (m objectsAreEqualMatcher) String() string {
return fmt.Sprintf("is equal to %v (%T)", m.expected, m.expected)
}

View file

@ -0,0 +1,46 @@
package legacymanager
import (
"strings"
"golang.org/x/oauth2"
"google.golang.org/protobuf/types/known/timestamppb"
"github.com/pomerium/pomerium/pkg/grpc/session"
)
func toSessionSchedulerKey(userID, sessionID string) string {
return userID + "\037" + sessionID
}
func fromSessionSchedulerKey(key string) (userID, sessionID string) {
idx := strings.Index(key, "\037")
if idx >= 0 {
userID = key[:idx]
sessionID = key[idx+1:]
} else {
userID = key
}
return userID, sessionID
}
// FromOAuthToken converts a session oauth token to oauth2.Token.
func FromOAuthToken(token *session.OAuthToken) *oauth2.Token {
return &oauth2.Token{
AccessToken: token.GetAccessToken(),
TokenType: token.GetTokenType(),
RefreshToken: token.GetRefreshToken(),
Expiry: token.GetExpiresAt().AsTime(),
}
}
// ToOAuthToken converts an oauth2.Token to a session oauth token.
func ToOAuthToken(token *oauth2.Token) *session.OAuthToken {
expiry := timestamppb.New(token.Expiry)
return &session.OAuthToken{
AccessToken: token.AccessToken,
TokenType: token.TokenType,
RefreshToken: token.RefreshToken,
ExpiresAt: expiry,
}
}

View file

@ -0,0 +1,55 @@
package legacymanager
import (
"context"
"github.com/pomerium/pomerium/internal/atomicutil"
"github.com/pomerium/pomerium/pkg/grpc/databroker"
)
type dataBrokerSyncer struct {
cfg *atomicutil.Value[*config]
update chan<- updateRecordsMessage
clear chan<- struct{}
syncer *databroker.Syncer
}
func newDataBrokerSyncer(
_ context.Context,
cfg *atomicutil.Value[*config],
update chan<- updateRecordsMessage,
clear chan<- struct{},
) *dataBrokerSyncer {
syncer := &dataBrokerSyncer{
cfg: cfg,
update: update,
clear: clear,
}
syncer.syncer = databroker.NewSyncer("identity_manager", syncer)
return syncer
}
func (syncer *dataBrokerSyncer) Run(ctx context.Context) (err error) {
return syncer.syncer.Run(ctx)
}
func (syncer *dataBrokerSyncer) ClearRecords(ctx context.Context) {
select {
case <-ctx.Done():
case syncer.clear <- struct{}{}:
}
}
func (syncer *dataBrokerSyncer) GetDataBrokerServiceClient() databroker.DataBrokerServiceClient {
return syncer.cfg.Load().dataBrokerClient
}
func (syncer *dataBrokerSyncer) UpdateRecords(ctx context.Context, _ uint64, records []*databroker.Record) {
select {
case <-ctx.Done():
case syncer.update <- updateRecordsMessage{records: records}:
}
}

View file

@ -10,6 +10,7 @@ import (
var (
defaultSessionRefreshGracePeriod = 1 * time.Minute
defaultSessionRefreshCoolOffDuration = 10 * time.Second
defaultUpdateUserInfoInterval = 10 * time.Minute
)
type config struct {
@ -17,6 +18,7 @@ type config struct {
dataBrokerClient databroker.DataBrokerServiceClient
sessionRefreshGracePeriod time.Duration
sessionRefreshCoolOffDuration time.Duration
updateUserInfoInterval time.Duration
now func() time.Time
eventMgr *events.Manager
enabled bool
@ -27,6 +29,7 @@ func newConfig(options ...Option) *config {
WithSessionRefreshGracePeriod(defaultSessionRefreshGracePeriod)(cfg)
WithSessionRefreshCoolOffDuration(defaultSessionRefreshCoolOffDuration)(cfg)
WithNow(time.Now)(cfg)
WithUpdateUserInfoInterval(defaultUpdateUserInfoInterval)(cfg)
WithEnabled(true)(cfg)
for _, option := range options {
option(cfg)
@ -85,3 +88,10 @@ func WithEnabled(enabled bool) Option {
cfg.enabled = enabled
}
}
// WithUpdateUserInfoInterval sets the update user info interval in the config.
func WithUpdateUserInfoInterval(dur time.Duration) Option {
return func(cfg *config) {
cfg.updateUserInfoInterval = dur
}
}

View file

@ -2,9 +2,9 @@ package manager
import (
"encoding/json"
"errors"
"time"
"github.com/google/btree"
"google.golang.org/protobuf/types/known/timestamppb"
"github.com/pomerium/pomerium/internal/identity"
@ -12,66 +12,18 @@ import (
"github.com/pomerium/pomerium/pkg/grpc/user"
)
const userRefreshInterval = 10 * time.Minute
// A User is a user managed by the Manager.
type User struct {
*user.User
lastRefresh time.Time
}
// NextRefresh returns the next time the user information needs to be refreshed.
func (u User) NextRefresh() time.Time {
return u.lastRefresh.Add(userRefreshInterval)
}
// UnmarshalJSON unmarshals json data into the user object.
func (u *User) UnmarshalJSON(data []byte) error {
if u.User == nil {
u.User = new(user.User)
}
var raw map[string]json.RawMessage
err := json.Unmarshal(data, &raw)
if err != nil {
return err
}
if name, ok := raw["name"]; ok {
_ = json.Unmarshal(name, &u.User.Name)
delete(raw, "name")
}
if email, ok := raw["email"]; ok {
_ = json.Unmarshal(email, &u.User.Email)
delete(raw, "email")
}
u.AddClaims(identity.NewClaimsFromRaw(raw).Flatten())
return nil
}
// A Session is a session managed by the Manager.
type Session struct {
*session.Session
// lastRefresh is the time of the last refresh attempt (which may or may
// not have succeeded), or else the time the Manager first became aware of
// the session (if it has not yet attempted to refresh this session).
lastRefresh time.Time
// gracePeriod is the amount of time before expiration to attempt a refresh.
gracePeriod time.Duration
// coolOffDuration is the amount of time to wait before attempting another refresh.
coolOffDuration time.Duration
}
// NextRefresh returns the next time the session needs to be refreshed.
func (s Session) NextRefresh() time.Time {
func nextSessionRefresh(
s *session.Session,
lastRefresh time.Time,
gracePeriod time.Duration,
coolOffDuration time.Duration,
) time.Time {
var tm time.Time
if s.GetOauthToken().GetExpiresAt() != nil {
expiry := s.GetOauthToken().GetExpiresAt().AsTime()
if s.GetOauthToken().GetExpiresAt().IsValid() && !expiry.IsZero() {
expiry = expiry.Add(-s.gracePeriod)
expiry = expiry.Add(-gracePeriod)
if tm.IsZero() || expiry.Before(tm) {
tm = expiry
}
@ -88,7 +40,7 @@ func (s Session) NextRefresh() time.Time {
}
// don't refresh any quicker than the cool-off duration
min := s.lastRefresh.Add(s.coolOffDuration)
min := lastRefresh.Add(coolOffDuration)
if tm.Before(min) {
tm = min
}
@ -96,10 +48,35 @@ func (s Session) NextRefresh() time.Time {
return tm
}
// UnmarshalJSON unmarshals json data into the session object.
func (s *Session) UnmarshalJSON(data []byte) error {
if s.Session == nil {
s.Session = new(session.Session)
// a multiUnmarshaler is used as the target of the json Unmarshal function to
// unmarshal a single JSON value into multiple destinations.
type multiUnmarshaler []any
func newMultiUnmarshaler(args ...any) *multiUnmarshaler {
return (*multiUnmarshaler)(&args)
}
func (dst *multiUnmarshaler) UnmarshalJSON(data []byte) error {
var err error
for _, o := range *dst {
if o != nil {
err = errors.Join(err, json.Unmarshal(data, o))
}
}
return err
}
type sessionUnmarshaler struct {
*session.Session
}
func newSessionUnmarshaler(s *session.Session) *sessionUnmarshaler {
return &sessionUnmarshaler{Session: s}
}
func (dst *sessionUnmarshaler) UnmarshalJSON(data []byte) error {
if dst.Session == nil {
return nil
}
var raw map[string]json.RawMessage
@ -108,159 +85,67 @@ func (s *Session) UnmarshalJSON(data []byte) error {
return err
}
if s.Session.IdToken == nil {
s.Session.IdToken = new(session.IDToken)
if dst.Session.IdToken == nil {
dst.Session.IdToken = new(session.IDToken)
}
if iss, ok := raw["iss"]; ok {
_ = json.Unmarshal(iss, &s.Session.IdToken.Issuer)
_ = json.Unmarshal(iss, &dst.Session.IdToken.Issuer)
delete(raw, "iss")
}
if sub, ok := raw["sub"]; ok {
_ = json.Unmarshal(sub, &s.Session.IdToken.Subject)
_ = json.Unmarshal(sub, &dst.Session.IdToken.Subject)
delete(raw, "sub")
}
if exp, ok := raw["exp"]; ok {
var secs int64
if err := json.Unmarshal(exp, &secs); err == nil {
s.Session.IdToken.ExpiresAt = timestamppb.New(time.Unix(secs, 0))
dst.Session.IdToken.ExpiresAt = timestamppb.New(time.Unix(secs, 0))
}
delete(raw, "exp")
}
if iat, ok := raw["iat"]; ok {
var secs int64
if err := json.Unmarshal(iat, &secs); err == nil {
s.Session.IdToken.IssuedAt = timestamppb.New(time.Unix(secs, 0))
dst.Session.IdToken.IssuedAt = timestamppb.New(time.Unix(secs, 0))
}
delete(raw, "iat")
}
s.AddClaims(identity.NewClaimsFromRaw(raw).Flatten())
dst.Session.AddClaims(identity.NewClaimsFromRaw(raw).Flatten())
return nil
}
type sessionCollectionItem struct {
Session
type userUnmarshaler struct {
*user.User
}
func (item sessionCollectionItem) Less(than btree.Item) bool {
xUserID, yUserID := item.GetUserId(), than.(sessionCollectionItem).GetUserId()
switch {
case xUserID < yUserID:
return true
case yUserID < xUserID:
return false
func newUserUnmarshaler(u *user.User) *userUnmarshaler {
return &userUnmarshaler{User: u}
}
func (dst *userUnmarshaler) UnmarshalJSON(data []byte) error {
if dst.User == nil {
return nil
}
xID, yID := item.GetId(), than.(sessionCollectionItem).GetId()
switch {
case xID < yID:
return true
case yID < xID:
return false
}
return false
}
type sessionCollection struct {
*btree.BTree
}
func (c *sessionCollection) Delete(userID, sessionID string) {
c.BTree.Delete(sessionCollectionItem{
Session: Session{
Session: &session.Session{
UserId: userID,
Id: sessionID,
},
},
})
}
func (c *sessionCollection) Get(userID, sessionID string) (Session, bool) {
item := c.BTree.Get(sessionCollectionItem{
Session: Session{
Session: &session.Session{
UserId: userID,
Id: sessionID,
},
},
})
if item == nil {
return Session{}, false
}
return item.(sessionCollectionItem).Session, true
}
// GetSessionsForUser gets all the sessions for the given user.
func (c *sessionCollection) GetSessionsForUser(userID string) []Session {
var sessions []Session
c.AscendGreaterOrEqual(sessionCollectionItem{
Session: Session{
Session: &session.Session{
UserId: userID,
},
},
}, func(item btree.Item) bool {
s := item.(sessionCollectionItem).Session
if s.UserId != userID {
return false
var raw map[string]json.RawMessage
err := json.Unmarshal(data, &raw)
if err != nil {
return err
}
sessions = append(sessions, s)
return true
})
return sessions
}
func (c *sessionCollection) ReplaceOrInsert(s Session) {
c.BTree.ReplaceOrInsert(sessionCollectionItem{Session: s})
}
type userCollectionItem struct {
User
}
func (item userCollectionItem) Less(than btree.Item) bool {
xID, yID := item.GetId(), than.(userCollectionItem).GetId()
switch {
case xID < yID:
return true
case yID < xID:
return false
if name, ok := raw["name"]; ok {
_ = json.Unmarshal(name, &dst.User.Name)
delete(raw, "name")
}
return false
}
type userCollection struct {
*btree.BTree
}
func (c *userCollection) Delete(userID string) {
c.BTree.Delete(userCollectionItem{
User: User{
User: &user.User{
Id: userID,
},
},
})
}
func (c *userCollection) Get(userID string) (User, bool) {
item := c.BTree.Get(userCollectionItem{
User: User{
User: &user.User{
Id: userID,
},
},
})
if item == nil {
return User{}, false
if email, ok := raw["email"]; ok {
_ = json.Unmarshal(email, &dst.User.Email)
delete(raw, "email")
}
return item.(userCollectionItem).User, true
}
func (c *userCollection) ReplaceOrInsert(u User) {
c.BTree.ReplaceOrInsert(userCollectionItem{User: u})
dst.User.AddClaims(identity.NewClaimsFromRaw(raw).Flatten())
return nil
}

View file

@ -11,20 +11,20 @@ import (
"google.golang.org/protobuf/types/known/timestamppb"
"github.com/pomerium/pomerium/pkg/grpc/session"
"github.com/pomerium/pomerium/pkg/grpc/user"
"github.com/pomerium/pomerium/pkg/protoutil"
)
func TestUser_UnmarshalJSON(t *testing.T) {
var u User
u := new(user.User)
err := json.Unmarshal([]byte(`{
"name": "joe",
"email": "joe@test.com",
"some-other-claim": "xyz"
}`), &u)
}`), newUserUnmarshaler(u))
assert.NoError(t, err)
assert.NotNil(t, u.User)
assert.Equal(t, "joe", u.User.Name)
assert.Equal(t, "joe@test.com", u.User.Email)
assert.Equal(t, "joe", u.Name)
assert.Equal(t, "joe@test.com", u.Email)
assert.Equal(t, map[string]*structpb.ListValue{
"some-other-claim": {Values: []*structpb.Value{protoutil.ToStruct("xyz")}},
}, u.Claims)
@ -32,42 +32,38 @@ func TestUser_UnmarshalJSON(t *testing.T) {
func TestSession_NextRefresh(t *testing.T) {
tm1 := time.Date(2020, 6, 5, 12, 0, 0, 0, time.UTC)
s := Session{
Session: &session.Session{},
lastRefresh: tm1,
gracePeriod: time.Second * 10,
coolOffDuration: time.Minute,
}
assert.Equal(t, tm1.Add(time.Minute), s.NextRefresh())
s := &session.Session{}
gracePeriod := time.Second * 10
coolOffDuration := time.Minute
assert.Equal(t, tm1.Add(time.Minute), nextSessionRefresh(s, tm1, gracePeriod, coolOffDuration))
tm2 := time.Date(2020, 6, 5, 13, 0, 0, 0, time.UTC)
s.OauthToken = &session.OAuthToken{
ExpiresAt: timestamppb.New(tm2),
}
assert.Equal(t, tm2.Add(-time.Second*10), s.NextRefresh())
assert.Equal(t, tm2.Add(-time.Second*10), nextSessionRefresh(s, tm1, gracePeriod, coolOffDuration))
tm3 := time.Date(2020, 6, 5, 12, 15, 0, 0, time.UTC)
s.ExpiresAt = timestamppb.New(tm3)
assert.Equal(t, tm3, s.NextRefresh())
assert.Equal(t, tm3, nextSessionRefresh(s, tm1, gracePeriod, coolOffDuration))
}
func TestSession_UnmarshalJSON(t *testing.T) {
tm := time.Date(2020, 6, 5, 12, 0, 0, 0, time.UTC)
var s Session
s := new(session.Session)
err := json.Unmarshal([]byte(`{
"iss": "https://some.issuer.com",
"sub": "subject",
"exp": `+fmt.Sprint(tm.Unix())+`,
"iat": `+fmt.Sprint(tm.Unix())+`,
"some-other-claim": "xyz"
}`), &s)
}`), newSessionUnmarshaler(s))
assert.NoError(t, err)
assert.NotNil(t, s.Session)
assert.NotNil(t, s.Session.IdToken)
assert.Equal(t, "https://some.issuer.com", s.Session.IdToken.Issuer)
assert.Equal(t, "subject", s.Session.IdToken.Subject)
assert.Equal(t, timestamppb.New(tm), s.Session.IdToken.ExpiresAt)
assert.Equal(t, timestamppb.New(tm), s.Session.IdToken.IssuedAt)
assert.NotNil(t, s.IdToken)
assert.Equal(t, "https://some.issuer.com", s.IdToken.Issuer)
assert.Equal(t, "subject", s.IdToken.Subject)
assert.Equal(t, timestamppb.New(tm), s.IdToken.ExpiresAt)
assert.Equal(t, timestamppb.New(tm), s.IdToken.IssuedAt)
assert.Equal(t, map[string]*structpb.ListValue{
"some-other-claim": {Values: []*structpb.Value{protoutil.ToStruct("xyz")}},
}, s.Claims)

View file

@ -0,0 +1,124 @@
package manager
import (
"cmp"
"slices"
"google.golang.org/protobuf/proto"
"github.com/pomerium/pomerium/pkg/grpc/session"
"github.com/pomerium/pomerium/pkg/grpc/user"
)
// dataStore stores session and user data
type dataStore struct {
sessions map[string]*session.Session
users map[string]*user.User
userIDToSessionIDs map[string]map[string]struct{}
}
func newDataStore() *dataStore {
ds := new(dataStore)
ds.deleteAllSessions()
ds.deleteAllUsers()
return ds
}
func (ds *dataStore) deleteAllSessions() {
ds.sessions = make(map[string]*session.Session)
ds.userIDToSessionIDs = make(map[string]map[string]struct{})
}
func (ds *dataStore) deleteAllUsers() {
ds.users = make(map[string]*user.User)
}
func (ds *dataStore) deleteSession(sessionID string) {
s := ds.sessions[sessionID]
delete(ds.sessions, sessionID)
if s.GetUserId() == "" {
return
}
m := ds.userIDToSessionIDs[s.GetUserId()]
if m != nil {
delete(m, s.GetId())
}
if len(m) == 0 {
delete(ds.userIDToSessionIDs, s.GetUserId())
}
}
func (ds *dataStore) deleteUser(userID string) {
delete(ds.users, userID)
}
func (ds *dataStore) getSessionAndUser(sessionID string) (s *session.Session, u *user.User) {
s = ds.sessions[sessionID]
if s.GetUserId() != "" {
u = ds.users[s.GetUserId()]
}
// clone to avoid sharing memory
s = clone(s)
u = clone(u)
return s, u
}
func (ds *dataStore) getUserAndSessions(userID string) (u *user.User, ss []*session.Session) {
u = ds.users[userID]
for sessionID := range ds.userIDToSessionIDs[userID] {
ss = append(ss, ds.sessions[sessionID])
}
// remove nils and sort by id
ss = slices.Compact(ss)
slices.SortFunc(ss, func(a, b *session.Session) int {
return cmp.Compare(a.GetId(), b.GetId())
})
// clone to avoid sharing memory
u = clone(u)
for i := range ss {
ss[i] = clone(ss[i])
}
return u, ss
}
func (ds *dataStore) putSession(s *session.Session) {
// clone to avoid sharing memory
s = clone(s)
if s.GetId() != "" {
ds.deleteSession(s.GetId())
ds.sessions[s.GetId()] = s
if s.GetUserId() != "" {
m, ok := ds.userIDToSessionIDs[s.GetUserId()]
if !ok {
m = make(map[string]struct{})
ds.userIDToSessionIDs[s.GetUserId()] = m
}
m[s.GetId()] = struct{}{}
}
}
}
func (ds *dataStore) putUser(u *user.User) {
// clone to avoid sharing memory
u = clone(u)
if u.GetId() != "" {
ds.users[u.GetId()] = u
}
}
// clone clones a protobuf message
func clone[T any, U interface {
*T
proto.Message
}](src U) U {
if src == nil {
return src
}
return proto.Clone(src).(U)
}

View file

@ -0,0 +1,78 @@
package manager
import (
"testing"
"github.com/google/go-cmp/cmp"
"github.com/stretchr/testify/assert"
"google.golang.org/protobuf/testing/protocmp"
"github.com/pomerium/pomerium/pkg/grpc/session"
"github.com/pomerium/pomerium/pkg/grpc/user"
)
func TestDataStore(t *testing.T) {
t.Parallel()
ds := newDataStore()
s, u := ds.getSessionAndUser("S1")
assert.Nil(t, s, "should return a nil session when none exists")
assert.Nil(t, u, "should return a nil user when none exists")
u, ss := ds.getUserAndSessions("U1")
assert.Nil(t, u, "should return a nil user when none exists")
assert.Empty(t, ss, "should return an empty list of sessions when no user exists")
s = &session.Session{Id: "S1", UserId: "U1"}
ds.putSession(s)
s1, u1 := ds.getSessionAndUser("S1")
assert.NotNil(t, s1, "should return a non-nil session")
assert.NotSame(t, s, s1, "should return different pointers")
assert.Empty(t, cmp.Diff(s, s1, protocmp.Transform()), "should be the same as was entered")
assert.Nil(t, u1, "should return a nil user when only the session exists")
ds.putUser(&user.User{
Id: "U1",
})
_, u1 = ds.getSessionAndUser("S1")
assert.NotNil(t, u1, "should return a user now that it has been added")
ds.putSession(&session.Session{Id: "S4", UserId: "U1"})
ds.putSession(&session.Session{Id: "S3", UserId: "U1"})
ds.putSession(&session.Session{Id: "S2", UserId: "U1"})
u, ss = ds.getUserAndSessions("U1")
assert.NotNil(t, u)
assert.Empty(t, cmp.Diff(ss, []*session.Session{
{Id: "S1", UserId: "U1"},
{Id: "S2", UserId: "U1"},
{Id: "S3", UserId: "U1"},
{Id: "S4", UserId: "U1"},
}, protocmp.Transform()), "should return all sessions in id order")
ds.deleteSession("S4")
u, ss = ds.getUserAndSessions("U1")
assert.NotNil(t, u)
assert.Empty(t, cmp.Diff(ss, []*session.Session{
{Id: "S1", UserId: "U1"},
{Id: "S2", UserId: "U1"},
{Id: "S3", UserId: "U1"},
}, protocmp.Transform()), "should return all sessions in id order")
ds.deleteUser("U1")
u, ss = ds.getUserAndSessions("U1")
assert.Nil(t, u)
assert.Empty(t, cmp.Diff(ss, []*session.Session{
{Id: "S1", UserId: "U1"},
{Id: "S2", UserId: "U1"},
{Id: "S3", UserId: "U1"},
}, protocmp.Transform()), "should still return all sessions in id order")
ds.deleteSession("S1")
ds.deleteSession("S2")
ds.deleteSession("S3")
u, ss = ds.getUserAndSessions("U1")
assert.Nil(t, u)
assert.Empty(t, ss)
}

View file

@ -3,10 +3,10 @@ package manager
import (
"context"
"errors"
"fmt"
"sync"
"time"
"github.com/google/btree"
"github.com/rs/zerolog"
"golang.org/x/oauth2"
"golang.org/x/sync/errgroup"
@ -20,7 +20,6 @@ import (
"github.com/pomerium/pomerium/internal/events"
"github.com/pomerium/pomerium/internal/identity/identity"
"github.com/pomerium/pomerium/internal/log"
"github.com/pomerium/pomerium/internal/scheduler"
"github.com/pomerium/pomerium/internal/telemetry/metrics"
"github.com/pomerium/pomerium/pkg/grpc/databroker"
"github.com/pomerium/pomerium/pkg/grpc/session"
@ -36,22 +35,15 @@ type Authenticator interface {
UpdateUserInfo(context.Context, *oauth2.Token, interface{}) error
}
type (
updateRecordsMessage struct {
records []*databroker.Record
}
)
// A Manager refreshes identity information using session and user data.
type Manager struct {
enabler.Enabler
cfg *atomicutil.Value[*config]
sessionScheduler *scheduler.Scheduler
userScheduler *scheduler.Scheduler
sessions sessionCollection
users userCollection
mu sync.Mutex
dataStore *dataStore
refreshSessionSchedulers map[string]*refreshSessionScheduler
updateUserInfoSchedulers map[string]*updateUserInfoScheduler
}
// New creates a new identity manager.
@ -61,21 +53,15 @@ func New(
mgr := &Manager{
cfg: atomicutil.NewValue(newConfig()),
sessionScheduler: scheduler.New(),
userScheduler: scheduler.New(),
dataStore: newDataStore(),
refreshSessionSchedulers: make(map[string]*refreshSessionScheduler),
updateUserInfoSchedulers: make(map[string]*updateUserInfoScheduler),
}
mgr.Enabler = enabler.New("identity_manager", mgr, true)
mgr.reset()
mgr.UpdateConfig(options...)
return mgr
}
func withLog(ctx context.Context) context.Context {
return log.WithContext(ctx, func(c zerolog.Context) zerolog.Context {
return c.Str("service", "identity_manager")
})
}
// UpdateConfig updates the manager with the new options.
func (mgr *Manager) UpdateConfig(options ...Option) {
mgr.cfg.Store(newConfig(options...))
@ -86,6 +72,11 @@ func (mgr *Manager) UpdateConfig(options ...Option) {
}
}
// GetDataBrokerServiceClient gets the databroker client.
func (mgr *Manager) GetDataBrokerServiceClient() databroker.DataBrokerServiceClient {
return mgr.cfg.Load().dataBrokerClient
}
// RunEnabled runs the manager. This method blocks until an error occurs or the given context is canceled.
func (mgr *Manager) RunEnabled(ctx context.Context) error {
leaser := databroker.NewLeaser("identity_manager", time.Second*30, mgr)
@ -94,173 +85,156 @@ func (mgr *Manager) RunEnabled(ctx context.Context) error {
// RunLeased runs the identity manager when a lease is acquired.
func (mgr *Manager) RunLeased(ctx context.Context) error {
ctx = withLog(ctx)
update := make(chan updateRecordsMessage, 1)
clear := make(chan struct{}, 1)
syncer := newDataBrokerSyncer(ctx, mgr.cfg, update, clear)
ctx = log.WithContext(ctx, func(c zerolog.Context) zerolog.Context {
return c.Str("service", "identity_manager")
})
eg, ctx := errgroup.WithContext(ctx)
eg.Go(func() error {
return syncer.Run(ctx)
sessionSyncer := newSessionSyncer(ctx, mgr)
defer sessionSyncer.Close()
return fmt.Errorf("session syncer error: %w", sessionSyncer.Run(ctx))
})
eg.Go(func() error {
return mgr.refreshLoop(ctx, update, clear)
userSyncer := newUserSyncer(ctx, mgr)
defer userSyncer.Close()
return fmt.Errorf("user syncer error: %w", userSyncer.Run(ctx))
})
return eg.Wait()
}
// GetDataBrokerServiceClient gets the databroker client.
func (mgr *Manager) GetDataBrokerServiceClient() databroker.DataBrokerServiceClient {
return mgr.cfg.Load().dataBrokerClient
func (mgr *Manager) onDeleteAllSessions(ctx context.Context) {
log.Ctx(ctx).Debug().Msg("all session deleted")
mgr.mu.Lock()
mgr.dataStore.deleteAllSessions()
for sID, rss := range mgr.refreshSessionSchedulers {
rss.Stop()
delete(mgr.refreshSessionSchedulers, sID)
}
mgr.mu.Unlock()
}
func (mgr *Manager) now() time.Time {
return mgr.cfg.Load().now()
func (mgr *Manager) onDeleteAllUsers(ctx context.Context) {
log.Ctx(ctx).Debug().Msg("all users deleted")
mgr.mu.Lock()
mgr.dataStore.deleteAllUsers()
for uID, uuis := range mgr.updateUserInfoSchedulers {
uuis.Stop()
delete(mgr.updateUserInfoSchedulers, uID)
}
mgr.mu.Unlock()
}
func (mgr *Manager) refreshLoop(ctx context.Context, update <-chan updateRecordsMessage, clear <-chan struct{}) error {
// wait for initial sync
select {
case <-ctx.Done():
return ctx.Err()
case <-clear:
mgr.reset()
}
select {
case <-ctx.Done():
return ctx.Err()
case msg := <-update:
mgr.onUpdateRecords(ctx, msg)
}
func (mgr *Manager) onDeleteSession(ctx context.Context, sessionID string) {
log.Ctx(ctx).Debug().Str("session_id", sessionID).Msg("session deleted")
log.Debug(ctx).
Int("sessions", mgr.sessions.Len()).
Int("users", mgr.users.Len()).
Msg("initial sync complete")
// start refreshing
maxWait := time.Minute * 10
var nextTime time.Time
timer := time.NewTimer(0)
defer timer.Stop()
for {
select {
case <-ctx.Done():
return ctx.Err()
case <-clear:
mgr.reset()
case msg := <-update:
mgr.onUpdateRecords(ctx, msg)
case <-timer.C:
}
now := mgr.now()
nextTime = now.Add(maxWait)
// refresh sessions
for {
tm, key := mgr.sessionScheduler.Next()
if now.Before(tm) {
if tm.Before(nextTime) {
nextTime = tm
}
break
}
mgr.sessionScheduler.Remove(key)
userID, sessionID := fromSessionSchedulerKey(key)
mgr.refreshSession(ctx, userID, sessionID)
}
// refresh users
for {
tm, key := mgr.userScheduler.Next()
if now.Before(tm) {
if tm.Before(nextTime) {
nextTime = tm
}
break
}
mgr.userScheduler.Remove(key)
mgr.refreshUser(ctx, key)
}
metrics.RecordIdentityManagerLastRefresh(ctx)
timer.Reset(time.Until(nextTime))
mgr.mu.Lock()
mgr.dataStore.deleteSession(sessionID)
if rss, ok := mgr.refreshSessionSchedulers[sessionID]; ok {
rss.Stop()
delete(mgr.refreshSessionSchedulers, sessionID)
}
mgr.mu.Unlock()
}
// refreshSession handles two distinct session lifecycle events:
//
// 1. If the session itself has expired, delete the session.
// 2. If the session's underlying OAuth2 access token is nearing expiration
// (but the session itself is still valid), refresh the access token.
//
// After a successful access token refresh, this method will also trigger a
// user info refresh. If an access token refresh or a user info refresh fails
// with a permanent error, the session will be deleted.
func (mgr *Manager) refreshSession(ctx context.Context, userID, sessionID string) {
log.Info(ctx).
Str("user_id", userID).
func (mgr *Manager) onDeleteUser(ctx context.Context, userID string) {
log.Ctx(ctx).Debug().Str("user_id", userID).Msg("user deleted")
mgr.mu.Lock()
mgr.dataStore.deleteUser(userID)
if uuis, ok := mgr.updateUserInfoSchedulers[userID]; ok {
uuis.Stop()
delete(mgr.updateUserInfoSchedulers, userID)
}
mgr.mu.Unlock()
}
func (mgr *Manager) onUpdateSession(ctx context.Context, s *session.Session) {
log.Ctx(ctx).Debug().Str("session_id", s.GetId()).Msg("session updated")
mgr.mu.Lock()
mgr.dataStore.putSession(s)
rss, ok := mgr.refreshSessionSchedulers[s.GetId()]
if !ok {
rss = newRefreshSessionScheduler(
ctx,
mgr.cfg.Load().now,
mgr.cfg.Load().sessionRefreshGracePeriod,
mgr.cfg.Load().sessionRefreshCoolOffDuration,
mgr.refreshSession,
s.GetId(),
)
mgr.refreshSessionSchedulers[s.GetId()] = rss
}
rss.Update(s)
mgr.mu.Unlock()
}
func (mgr *Manager) onUpdateUser(ctx context.Context, u *user.User) {
log.Ctx(ctx).Debug().Str("user_id", u.GetId()).Msg("user updated")
mgr.mu.Lock()
mgr.dataStore.putUser(u)
_, ok := mgr.updateUserInfoSchedulers[u.GetId()]
if !ok {
uuis := newUpdateUserInfoScheduler(
ctx,
mgr.cfg.Load().updateUserInfoInterval,
mgr.updateUserInfo,
u.GetId(),
)
mgr.updateUserInfoSchedulers[u.GetId()] = uuis
}
mgr.mu.Unlock()
}
func (mgr *Manager) refreshSession(ctx context.Context, sessionID string) {
log.Ctx(ctx).Debug().
Str("session_id", sessionID).
Msg("refreshing session")
s, ok := mgr.sessions.Get(userID, sessionID)
if !ok {
log.Warn(ctx).
Str("user_id", userID).
mgr.mu.Lock()
s, u := mgr.dataStore.getSessionAndUser(sessionID)
mgr.mu.Unlock()
if s == nil {
log.Ctx(ctx).Warn().
Str("user_id", u.GetId()).
Str("session_id", sessionID).
Msg("no session found for refresh")
return
}
s.lastRefresh = mgr.now()
if mgr.refreshSessionInternal(ctx, userID, sessionID, &s) {
mgr.sessions.ReplaceOrInsert(s)
mgr.sessionScheduler.Add(s.NextRefresh(), toSessionSchedulerKey(userID, sessionID))
}
}
// refreshSessionInternal performs the core refresh logic and returns true if
// the session should be scheduled for refresh again, or false if not.
func (mgr *Manager) refreshSessionInternal(
ctx context.Context, userID, sessionID string, s *Session,
) bool {
authenticator := mgr.cfg.Load().authenticator
if authenticator == nil {
log.Info(ctx).
Str("user_id", userID).
Str("session_id", sessionID).
log.Ctx(ctx).Info().
Str("user_id", s.GetUserId()).
Str("session_id", s.GetId()).
Msg("no authenticator defined, deleting session")
mgr.deleteSession(ctx, userID, sessionID)
return false
mgr.deleteSession(ctx, sessionID)
return
}
expiry := s.GetExpiresAt().AsTime()
if !expiry.After(mgr.now()) {
if !expiry.After(mgr.cfg.Load().now()) {
log.Info(ctx).
Str("user_id", userID).
Str("session_id", sessionID).
Str("user_id", s.GetUserId()).
Str("session_id", s.GetId()).
Msg("deleting expired session")
mgr.deleteSession(ctx, userID, sessionID)
return false
mgr.deleteSession(ctx, sessionID)
return
}
if s.Session == nil || s.Session.OauthToken == nil {
if s.GetOauthToken() == nil {
log.Warn(ctx).
Str("user_id", userID).
Str("session_id", sessionID).
Str("user_id", s.GetUserId()).
Str("session_id", s.GetId()).
Msg("no session oauth2 token found for refresh")
return false
return
}
newToken, err := authenticator.Refresh(ctx, FromOAuthToken(s.OauthToken), s)
newToken, err := authenticator.Refresh(ctx, FromOAuthToken(s.OauthToken), newSessionUnmarshaler(s))
metrics.RecordIdentityManagerSessionRefresh(ctx, err)
mgr.recordLastError(metrics_ids.IdentityManagerLastSessionRefreshError, err)
if isTemporaryError(err) {
@ -268,18 +242,18 @@ func (mgr *Manager) refreshSessionInternal(
Str("user_id", s.GetUserId()).
Str("session_id", s.GetId()).
Msg("failed to refresh oauth2 token")
return true
return
} else if err != nil {
log.Error(ctx).Err(err).
Str("user_id", s.GetUserId()).
Str("session_id", s.GetId()).
Msg("failed to refresh oauth2 token, deleting session")
mgr.deleteSession(ctx, userID, sessionID)
return false
mgr.deleteSession(ctx, sessionID)
return
}
s.OauthToken = ToOAuthToken(newToken)
err = authenticator.UpdateUserInfo(ctx, FromOAuthToken(s.OauthToken), s)
err = authenticator.UpdateUserInfo(ctx, FromOAuthToken(s.OauthToken), newMultiUnmarshaler(newUserUnmarshaler(u), newSessionUnmarshaler(s)))
metrics.RecordIdentityManagerUserRefresh(ctx, err)
mgr.recordLastError(metrics_ids.IdentityManagerLastUserRefreshError, err)
if isTemporaryError(err) {
@ -287,161 +261,94 @@ func (mgr *Manager) refreshSessionInternal(
Str("user_id", s.GetUserId()).
Str("session_id", s.GetId()).
Msg("failed to update user info")
return true
return
} else if err != nil {
log.Error(ctx).Err(err).
Str("user_id", s.GetUserId()).
Str("session_id", s.GetId()).
Msg("failed to update user info, deleting session")
mgr.deleteSession(ctx, userID, sessionID)
return false
mgr.deleteSession(ctx, sessionID)
return
}
fm, err := fieldmaskpb.New(s.Session, "oauth_token", "id_token", "claims")
if err != nil {
log.Error(ctx).Err(err).Msg("internal error")
return false
mgr.updateSession(ctx, s)
if u != nil {
mgr.updateUser(ctx, u)
}
if _, err := session.Patch(ctx, mgr.cfg.Load().dataBrokerClient, s.Session, fm); err != nil {
log.Error(ctx).Err(err).
Str("user_id", s.GetUserId()).
Str("session_id", s.GetId()).
Msg("failed to update session")
}
return true
}
func (mgr *Manager) refreshUser(ctx context.Context, userID string) {
log.Info(ctx).
Str("user_id", userID).
Msg("refreshing user")
func (mgr *Manager) updateUserInfo(ctx context.Context, userID string) {
log.Ctx(ctx).Info().Str("user_id", userID).Msg("updating user info")
authenticator := mgr.cfg.Load().authenticator
if authenticator == nil {
return
}
u, ok := mgr.users.Get(userID)
if !ok {
log.Warn(ctx).
mgr.mu.Lock()
u, ss := mgr.dataStore.getUserAndSessions(userID)
mgr.mu.Unlock()
if u == nil {
log.Ctx(ctx).Warn().
Str("user_id", userID).
Msg("no user found for refresh")
Msg("no user found for update")
return
}
u.lastRefresh = mgr.now()
mgr.userScheduler.Add(u.NextRefresh(), u.GetId())
for _, s := range mgr.sessions.GetSessionsForUser(userID) {
if s.Session == nil || s.Session.OauthToken == nil {
log.Warn(ctx).
Str("user_id", userID).
Msg("no session oauth2 token found for refresh")
for _, s := range ss {
if s.GetOauthToken() == nil {
log.Ctx(ctx).Warn().
Str("user_id", s.GetUserId()).
Str("session_id", s.GetId()).
Msg("no session oauth2 token found for updating user info")
continue
}
err := authenticator.UpdateUserInfo(ctx, FromOAuthToken(s.OauthToken), &u)
err := authenticator.UpdateUserInfo(ctx, FromOAuthToken(s.GetOauthToken()), newUserUnmarshaler(u))
metrics.RecordIdentityManagerUserRefresh(ctx, err)
mgr.recordLastError(metrics_ids.IdentityManagerLastUserRefreshError, err)
if isTemporaryError(err) {
log.Error(ctx).Err(err).
log.Ctx(ctx).Error().Err(err).
Str("user_id", s.GetUserId()).
Str("session_id", s.GetId()).
Msg("failed to update user info")
return
continue
} else if err != nil {
log.Error(ctx).Err(err).
log.Ctx(ctx).Error().Err(err).
Str("user_id", s.GetUserId()).
Str("session_id", s.GetId()).
Msg("failed to update user info, deleting session")
mgr.deleteSession(ctx, userID, s.GetId())
mgr.deleteSession(ctx, s.GetId())
continue
}
res, err := databroker.Put(ctx, mgr.cfg.Load().dataBrokerClient, u.User)
if err != nil {
log.Error(ctx).Err(err).
Str("user_id", s.GetUserId()).
Str("session_id", s.GetId()).
Msg("failed to update user")
continue
}
mgr.onUpdateUser(ctx, res.GetRecords()[0], u.User)
mgr.updateUser(ctx, u)
}
}
func (mgr *Manager) onUpdateRecords(ctx context.Context, msg updateRecordsMessage) {
for _, record := range msg.records {
switch record.GetType() {
case grpcutil.GetTypeURL(new(session.Session)):
var pbSession session.Session
err := record.GetData().UnmarshalTo(&pbSession)
if err != nil {
log.Warn(ctx).Msgf("error unmarshaling session: %s", err)
continue
// deleteSession deletes a session from the databroke, the local data store, and the schedulers
func (mgr *Manager) deleteSession(ctx context.Context, sessionID string) {
log.Ctx(ctx).Debug().
Str("session_id", sessionID).
Msg("deleting session")
mgr.mu.Lock()
mgr.dataStore.deleteSession(sessionID)
if rss, ok := mgr.refreshSessionSchedulers[sessionID]; ok {
rss.Stop()
delete(mgr.refreshSessionSchedulers, sessionID)
}
mgr.onUpdateSession(record, &pbSession)
case grpcutil.GetTypeURL(new(user.User)):
var pbUser user.User
err := record.GetData().UnmarshalTo(&pbUser)
if err != nil {
log.Warn(ctx).Msgf("error unmarshaling user: %s", err)
continue
}
mgr.onUpdateUser(ctx, record, &pbUser)
}
}
}
mgr.mu.Unlock()
func (mgr *Manager) onUpdateSession(record *databroker.Record, session *session.Session) {
mgr.sessionScheduler.Remove(toSessionSchedulerKey(session.GetUserId(), session.GetId()))
if record.GetDeletedAt() != nil {
mgr.sessions.Delete(session.GetUserId(), session.GetId())
return
}
// update session
s, _ := mgr.sessions.Get(session.GetUserId(), session.GetId())
if s.lastRefresh.IsZero() {
s.lastRefresh = mgr.now()
}
s.gracePeriod = mgr.cfg.Load().sessionRefreshGracePeriod
s.coolOffDuration = mgr.cfg.Load().sessionRefreshCoolOffDuration
s.Session = session
mgr.sessions.ReplaceOrInsert(s)
mgr.sessionScheduler.Add(s.NextRefresh(), toSessionSchedulerKey(session.GetUserId(), session.GetId()))
}
func (mgr *Manager) onUpdateUser(_ context.Context, record *databroker.Record, user *user.User) {
mgr.userScheduler.Remove(user.GetId())
if record.GetDeletedAt() != nil {
mgr.users.Delete(user.GetId())
return
}
u, _ := mgr.users.Get(user.GetId())
u.lastRefresh = mgr.cfg.Load().now()
u.User = user
mgr.users.ReplaceOrInsert(u)
mgr.userScheduler.Add(u.NextRefresh(), u.GetId())
}
func (mgr *Manager) deleteSession(ctx context.Context, userID, sessionID string) {
mgr.sessionScheduler.Remove(toSessionSchedulerKey(userID, sessionID))
mgr.sessions.Delete(userID, sessionID)
client := mgr.cfg.Load().dataBrokerClient
res, err := client.Get(ctx, &databroker.GetRequest{
res, err := mgr.cfg.Load().dataBrokerClient.Get(ctx, &databroker.GetRequest{
Type: grpcutil.GetTypeURL(new(session.Session)),
Id: sessionID,
})
if status.Code(err) == codes.NotFound {
return
} else if err != nil {
log.Error(ctx).Err(err).
log.Ctx(ctx).Error().Err(err).
Str("session_id", sessionID).
Msg("failed to delete session")
return
@ -450,21 +357,72 @@ func (mgr *Manager) deleteSession(ctx context.Context, userID, sessionID string)
record := res.GetRecord()
record.DeletedAt = timestamppb.Now()
_, err = client.Put(ctx, &databroker.PutRequest{
_, err = mgr.cfg.Load().dataBrokerClient.Put(ctx, &databroker.PutRequest{
Records: []*databroker.Record{record},
})
if err != nil {
log.Error(ctx).Err(err).
log.Ctx(ctx).Error().Err(err).
Str("session_id", sessionID).
Msg("failed to delete session")
return
}
}
// reset resets all the manager datastructures to their initial state
func (mgr *Manager) reset() {
mgr.sessions = sessionCollection{BTree: btree.New(8)}
mgr.users = userCollection{BTree: btree.New(8)}
func (mgr *Manager) updateSession(ctx context.Context, s *session.Session) {
log.Ctx(ctx).Debug().
Str("user_id", s.GetUserId()).
Str("session_id", s.GetId()).
Msg("updating session")
fm, err := fieldmaskpb.New(s, "oauth_token", "id_token", "claims")
if err != nil {
log.Error(ctx).Err(err).
Str("user_id", s.GetUserId()).
Str("session_id", s.GetId()).
Msg("failed to create fieldmask for session")
return
}
_, err = session.Patch(ctx, mgr.cfg.Load().dataBrokerClient, s, fm)
if err != nil {
log.Error(ctx).Err(err).
Str("user_id", s.GetUserId()).
Str("session_id", s.GetId()).
Msg("failed to patch updated session record")
return
}
mgr.mu.Lock()
mgr.dataStore.putSession(s)
if rss, ok := mgr.refreshSessionSchedulers[s.GetId()]; ok {
rss.Update(s)
}
mgr.mu.Unlock()
}
// updateUser updates the user in the databroker, the local data store, and resets the scheduler.
// (Whenever we refresh a session, we also refresh the user info. By resetting the user info
// scheduler here we can avoid refreshing user info more often than necessary.)
func (mgr *Manager) updateUser(ctx context.Context, u *user.User) {
log.Ctx(ctx).Debug().
Str("user_id", u.GetId()).
Msg("updating user")
_, err := databroker.Put(ctx, mgr.cfg.Load().dataBrokerClient, u)
if err != nil {
log.Ctx(ctx).Error().
Str("user_id", u.GetId()).
Err(err).
Msg("failed to store updated user record")
return
}
mgr.mu.Lock()
mgr.dataStore.putUser(u)
if uuis, ok := mgr.updateUserInfoSchedulers[u.GetId()]; ok {
uuis.Reset()
}
mgr.mu.Unlock()
}
func (mgr *Manager) recordLastError(id string, err error) {
@ -481,17 +439,3 @@ func (mgr *Manager) recordLastError(id string, err error) {
Id: id,
})
}
func isTemporaryError(err error) bool {
if err == nil {
return false
}
if errors.Is(err, context.DeadlineExceeded) || errors.Is(err, context.Canceled) {
return true
}
var hasTemporary interface{ Temporary() bool }
if errors.As(err, &hasTemporary) && hasTemporary.Temporary() {
return true
}
return false
}

View file

@ -2,28 +2,10 @@ package manager
import (
"context"
"errors"
"fmt"
"testing"
"time"
"github.com/stretchr/testify/assert"
"go.uber.org/mock/gomock"
"golang.org/x/oauth2"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/types/known/fieldmaskpb"
"google.golang.org/protobuf/types/known/timestamppb"
"github.com/pomerium/pomerium/internal/events"
"github.com/pomerium/pomerium/internal/identity/identity"
"github.com/pomerium/pomerium/pkg/grpc/databroker"
"github.com/pomerium/pomerium/pkg/grpc/databroker/mock_databroker"
"github.com/pomerium/pomerium/pkg/grpc/session"
"github.com/pomerium/pomerium/pkg/grpc/user"
metrics_ids "github.com/pomerium/pomerium/pkg/metrics"
"github.com/pomerium/pomerium/pkg/protoutil"
)
type mockAuthenticator struct {
@ -44,317 +26,3 @@ func (mock *mockAuthenticator) Revoke(_ context.Context, _ *oauth2.Token) error
func (mock *mockAuthenticator) UpdateUserInfo(_ context.Context, _ *oauth2.Token, _ any) error {
return mock.updateUserInfoError
}
func TestManager_refresh(t *testing.T) {
ctrl := gomock.NewController(t)
ctx, clearTimeout := context.WithTimeout(context.Background(), time.Second*10)
t.Cleanup(clearTimeout)
client := mock_databroker.NewMockDataBrokerServiceClient(ctrl)
mgr := New(WithDataBrokerClient(client))
mgr.onUpdateRecords(ctx, updateRecordsMessage{
records: []*databroker.Record{
databroker.NewRecord(&session.Session{
Id: "s1",
UserId: "u1",
OauthToken: &session.OAuthToken{},
ExpiresAt: timestamppb.New(time.Now().Add(time.Second * 10)),
}),
databroker.NewRecord(&user.User{
Id: "u1",
}),
},
})
client.EXPECT().Get(gomock.Any(), gomock.Any()).Return(nil, status.Error(codes.NotFound, "not found"))
mgr.refreshSession(ctx, "u1", "s1")
mgr.refreshUser(ctx, "u1")
}
func TestManager_onUpdateRecords(t *testing.T) {
ctrl := gomock.NewController(t)
ctx, clearTimeout := context.WithTimeout(context.Background(), time.Second*10)
defer clearTimeout()
now := time.Now()
mgr := New(
WithDataBrokerClient(mock_databroker.NewMockDataBrokerServiceClient(ctrl)),
WithNow(func() time.Time {
return now
}),
)
mgr.onUpdateRecords(ctx, updateRecordsMessage{
records: []*databroker.Record{
mkRecord(&session.Session{Id: "session1", UserId: "user1"}),
mkRecord(&user.User{Id: "user1", Name: "user 1", Email: "user1@example.com"}),
},
})
if _, ok := mgr.sessions.Get("user1", "session1"); assert.True(t, ok) {
tm, id := mgr.sessionScheduler.Next()
assert.Equal(t, now.Add(10*time.Second), tm)
assert.Equal(t, "user1\037session1", id)
}
if _, ok := mgr.users.Get("user1"); assert.True(t, ok) {
tm, id := mgr.userScheduler.Next()
assert.Equal(t, now.Add(userRefreshInterval), tm)
assert.Equal(t, "user1", id)
}
}
func TestManager_onUpdateSession(t *testing.T) {
startTime := time.Date(2023, 10, 19, 12, 0, 0, 0, time.UTC)
s := &session.Session{
Id: "session-id",
UserId: "user-id",
OauthToken: &session.OAuthToken{
AccessToken: "access-token",
ExpiresAt: timestamppb.New(startTime.Add(5 * time.Minute)),
},
IssuedAt: timestamppb.New(startTime),
ExpiresAt: timestamppb.New(startTime.Add(24 * time.Hour)),
}
assertNextScheduled := func(t *testing.T, mgr *Manager, expectedTime time.Time) {
t.Helper()
tm, key := mgr.sessionScheduler.Next()
assert.Equal(t, expectedTime, tm)
assert.Equal(t, "user-id\037session-id", key)
}
t.Run("initial refresh event when not expiring soon", func(t *testing.T) {
now := startTime
mgr := New(WithNow(func() time.Time { return now }))
// When the Manager first becomes aware of a session it should schedule
// a refresh event for one minute before access token expiration.
mgr.onUpdateSession(mkRecord(s), s)
assertNextScheduled(t, mgr, startTime.Add(4*time.Minute))
})
t.Run("initial refresh event when expiring soon", func(t *testing.T) {
now := startTime
mgr := New(WithNow(func() time.Time { return now }))
// When the Manager first becomes aware of a session, if that session
// is expiring within the gracePeriod (1 minute), it should schedule a
// refresh event for as soon as possible, subject to the
// coolOffDuration (10 seconds).
now = now.Add(4*time.Minute + 30*time.Second) // 30 s before expiration
mgr.onUpdateSession(mkRecord(s), s)
assertNextScheduled(t, mgr, now.Add(10*time.Second))
})
t.Run("update near scheduled refresh", func(t *testing.T) {
now := startTime
mgr := New(WithNow(func() time.Time { return now }))
mgr.onUpdateSession(mkRecord(s), s)
assertNextScheduled(t, mgr, startTime.Add(4*time.Minute))
// If a session is updated close to the time when it is scheduled to be
// refreshed, the scheduled refresh event should not be pushed back.
now = now.Add(3*time.Minute + 55*time.Second) // 5 s before refresh
mgr.onUpdateSession(mkRecord(s), s)
assertNextScheduled(t, mgr, now.Add(5*time.Second))
// However, if an update changes the access token validity, the refresh
// event should be rescheduled accordingly. (This should be uncommon,
// as only the refresh loop itself should modify the access token.)
s2 := proto.Clone(s).(*session.Session)
s2.OauthToken.ExpiresAt = timestamppb.New(now.Add(5 * time.Minute))
mgr.onUpdateSession(mkRecord(s2), s2)
assertNextScheduled(t, mgr, now.Add(4*time.Minute))
})
t.Run("session record deleted", func(t *testing.T) {
now := startTime
mgr := New(WithNow(func() time.Time { return now }))
mgr.onUpdateSession(mkRecord(s), s)
assertNextScheduled(t, mgr, startTime.Add(4*time.Minute))
// If a session is deleted, any scheduled refresh event should be canceled.
record := mkRecord(s)
record.DeletedAt = timestamppb.New(now)
mgr.onUpdateSession(record, s)
_, key := mgr.sessionScheduler.Next()
assert.Empty(t, key)
})
}
func TestManager_refreshSession(t *testing.T) {
startTime := time.Date(2023, 10, 19, 12, 0, 0, 0, time.UTC)
var auth mockAuthenticator
ctrl := gomock.NewController(t)
client := mock_databroker.NewMockDataBrokerServiceClient(ctrl)
now := startTime
mgr := New(
WithDataBrokerClient(client),
WithNow(func() time.Time { return now }),
WithAuthenticator(&auth),
)
// Initialize the Manager with a new session.
s := &session.Session{
Id: "session-id",
UserId: "user-id",
OauthToken: &session.OAuthToken{
AccessToken: "access-token",
ExpiresAt: timestamppb.New(startTime.Add(5 * time.Minute)),
RefreshToken: "refresh-token",
},
IssuedAt: timestamppb.New(startTime),
ExpiresAt: timestamppb.New(startTime.Add(24 * time.Hour)),
}
mgr.sessions.ReplaceOrInsert(Session{
Session: s,
lastRefresh: startTime,
gracePeriod: time.Minute,
coolOffDuration: 10 * time.Second,
})
// If OAuth2 token refresh fails with a temporary error, the manager should
// still reschedule another refresh attempt.
now = now.Add(4 * time.Minute)
auth.refreshError = context.DeadlineExceeded
mgr.refreshSession(context.Background(), "user-id", "session-id")
tm, key := mgr.sessionScheduler.Next()
assert.Equal(t, now.Add(10*time.Second), tm)
assert.Equal(t, "user-id\037session-id", key)
// Simulate a successful token refresh on the second attempt. The manager
// should store the updated session in the databroker and schedule another
// refresh event.
now = now.Add(10 * time.Second)
auth.refreshResult, auth.refreshError = &oauth2.Token{
AccessToken: "new-access-token",
RefreshToken: "new-refresh-token",
Expiry: now.Add(5 * time.Minute),
}, nil
expectedSession := proto.Clone(s).(*session.Session)
expectedSession.OauthToken = &session.OAuthToken{
AccessToken: "new-access-token",
ExpiresAt: timestamppb.New(now.Add(5 * time.Minute)),
RefreshToken: "new-refresh-token",
}
client.EXPECT().Patch(gomock.Any(), objectsAreEqualMatcher{
&databroker.PatchRequest{
Records: []*databroker.Record{{
Type: "type.googleapis.com/session.Session",
Id: "session-id",
Data: protoutil.NewAny(expectedSession),
}},
FieldMask: &fieldmaskpb.FieldMask{
Paths: []string{"oauth_token", "id_token", "claims"},
},
},
}).
Return(nil /* this result is currently unused */, nil)
mgr.refreshSession(context.Background(), "user-id", "session-id")
tm, key = mgr.sessionScheduler.Next()
assert.Equal(t, now.Add(4*time.Minute), tm)
assert.Equal(t, "user-id\037session-id", key)
}
func TestManager_reportErrors(t *testing.T) {
ctrl := gomock.NewController(t)
ctx, clearTimeout := context.WithTimeout(context.Background(), time.Second*10)
defer clearTimeout()
evtMgr := events.New()
received := make(chan events.Event, 1)
handle := evtMgr.Register(func(evt events.Event) {
received <- evt
})
defer evtMgr.Unregister(handle)
expectMsg := func(id, msg string) {
t.Helper()
assert.Eventually(t, func() bool {
select {
case evt := <-received:
lastErr := evt.(*events.LastError)
return msg == lastErr.Message && id == lastErr.Id
default:
return false
}
}, time.Second, time.Millisecond*20, msg)
}
s := &session.Session{
Id: "session1",
UserId: "user1",
OauthToken: &session.OAuthToken{
ExpiresAt: timestamppb.New(time.Now().Add(time.Hour)),
},
ExpiresAt: timestamppb.New(time.Now().Add(time.Hour)),
}
client := mock_databroker.NewMockDataBrokerServiceClient(ctrl)
client.EXPECT().Get(gomock.Any(), gomock.Any()).AnyTimes().Return(&databroker.GetResponse{Record: databroker.NewRecord(s)}, nil)
client.EXPECT().Put(gomock.Any(), gomock.Any()).AnyTimes()
mgr := New(
WithEventManager(evtMgr),
WithDataBrokerClient(client),
WithAuthenticator(&mockAuthenticator{
refreshError: errors.New("update session"),
updateUserInfoError: errors.New("update user info"),
}),
)
mgr.onUpdateRecords(ctx, updateRecordsMessage{
records: []*databroker.Record{
mkRecord(s),
mkRecord(&user.User{Id: "user1", Name: "user 1", Email: "user1@example.com"}),
},
})
mgr.refreshUser(ctx, "user1")
expectMsg(metrics_ids.IdentityManagerLastUserRefreshError, "update user info")
mgr.onUpdateRecords(ctx, updateRecordsMessage{
records: []*databroker.Record{
mkRecord(s),
mkRecord(&user.User{Id: "user1", Name: "user 1", Email: "user1@example.com"}),
},
})
mgr.refreshSession(ctx, "user1", "session1")
expectMsg(metrics_ids.IdentityManagerLastSessionRefreshError, "update session")
}
func mkRecord(msg recordable) *databroker.Record {
data := protoutil.NewAny(msg)
return &databroker.Record{
Type: data.GetTypeUrl(),
Id: msg.GetId(),
Data: data,
}
}
type recordable interface {
proto.Message
GetId() string
}
// objectsAreEqualMatcher implements gomock.Matcher using ObjectsAreEqual. This
// is especially helpful when working with pointers, as it will compare the
// underlying values rather than the pointers themselves.
type objectsAreEqualMatcher struct {
expected interface{}
}
func (m objectsAreEqualMatcher) Matches(x interface{}) bool {
return assert.ObjectsAreEqual(m.expected, x)
}
func (m objectsAreEqualMatcher) String() string {
return fmt.Sprintf("is equal to %v (%T)", m.expected, m.expected)
}

View file

@ -1,7 +1,8 @@
package manager
import (
"strings"
"context"
"errors"
"golang.org/x/oauth2"
"google.golang.org/protobuf/types/known/timestamppb"
@ -9,21 +10,6 @@ import (
"github.com/pomerium/pomerium/pkg/grpc/session"
)
func toSessionSchedulerKey(userID, sessionID string) string {
return userID + "\037" + sessionID
}
func fromSessionSchedulerKey(key string) (userID, sessionID string) {
idx := strings.Index(key, "\037")
if idx >= 0 {
userID = key[:idx]
sessionID = key[idx+1:]
} else {
userID = key
}
return userID, sessionID
}
// FromOAuthToken converts a session oauth token to oauth2.Token.
func FromOAuthToken(token *session.OAuthToken) *oauth2.Token {
return &oauth2.Token{
@ -44,3 +30,17 @@ func ToOAuthToken(token *oauth2.Token) *session.OAuthToken {
ExpiresAt: expiry,
}
}
func isTemporaryError(err error) bool {
if err == nil {
return false
}
if errors.Is(err, context.DeadlineExceeded) || errors.Is(err, context.Canceled) {
return true
}
var hasTemporary interface{ Temporary() bool }
if errors.As(err, &hasTemporary) && hasTemporary.Temporary() {
return true
}
return false
}

View file

@ -0,0 +1,164 @@
package manager
import (
"context"
"sync/atomic"
"time"
"github.com/pomerium/pomerium/pkg/grpc/session"
)
type updateUserInfoScheduler struct {
baseCtx context.Context
updateUserInfoInterval time.Duration
updateUserInfo func(ctx context.Context, userID string)
userID string
reset chan struct{}
cancel context.CancelFunc
}
func newUpdateUserInfoScheduler(
ctx context.Context,
updateUserInfoInterval time.Duration,
updateUserInfo func(ctx context.Context, userID string),
userID string,
) *updateUserInfoScheduler {
uuis := &updateUserInfoScheduler{
baseCtx: ctx,
updateUserInfoInterval: updateUserInfoInterval,
updateUserInfo: updateUserInfo,
userID: userID,
reset: make(chan struct{}, 1),
}
ctx, uuis.cancel = context.WithCancel(context.WithoutCancel(uuis.baseCtx))
go uuis.run(ctx)
return uuis
}
func (uuis *updateUserInfoScheduler) Reset() {
// trigger a reset by sending to the reset channel, which is buffered,
// so if we can't proceed there's already a pending reset and no need
// to wait
select {
case uuis.reset <- struct{}{}:
default:
}
}
func (uuis *updateUserInfoScheduler) Stop() {
uuis.cancel()
}
func (uuis *updateUserInfoScheduler) run(ctx context.Context) {
ticker := time.NewTicker(uuis.updateUserInfoInterval)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return
case <-uuis.reset:
ticker.Reset(uuis.updateUserInfoInterval)
case <-ticker.C:
uuis.updateUserInfo(uuis.baseCtx, uuis.userID)
}
}
}
type refreshSessionScheduler struct {
baseCtx context.Context
now func() time.Time
sessionRefreshGracePeriod time.Duration
sessionRefreshCoolOffDuration time.Duration
refreshSession func(ctx context.Context, sesionID string)
sessionID string
lastRefresh atomic.Pointer[time.Time]
next chan time.Time
cancel context.CancelFunc
}
func newRefreshSessionScheduler(
ctx context.Context,
now func() time.Time,
sessionRefreshGracePeriod time.Duration,
sessionRefreshCoolOffDuration time.Duration,
refreshSession func(ctx context.Context, sesionID string),
sessionID string,
) *refreshSessionScheduler {
rss := &refreshSessionScheduler{
baseCtx: ctx,
now: now,
sessionRefreshGracePeriod: sessionRefreshGracePeriod,
sessionRefreshCoolOffDuration: sessionRefreshCoolOffDuration,
refreshSession: refreshSession,
sessionID: sessionID,
next: make(chan time.Time, 1),
}
tm := now()
rss.lastRefresh.Store(&tm)
ctx, rss.cancel = context.WithCancel(context.WithoutCancel(rss.baseCtx))
go rss.run(ctx)
return rss
}
func (rss *refreshSessionScheduler) Update(s *session.Session) {
due := nextSessionRefresh(
s,
*rss.lastRefresh.Load(),
rss.sessionRefreshGracePeriod,
rss.sessionRefreshCoolOffDuration,
)
for {
select {
case <-rss.next:
default:
}
select {
case rss.next <- due:
return
default:
}
}
}
func (rss *refreshSessionScheduler) Stop() {
rss.cancel()
}
func (rss *refreshSessionScheduler) run(ctx context.Context) {
var timer *time.Timer
// ensure we clean up any orphaned timers
defer func() {
if timer != nil {
timer.Stop()
}
}()
// wait for the first update
select {
case <-ctx.Done():
return
case due := <-rss.next:
delay := max(time.Until(due), 0)
timer = time.NewTimer(delay)
}
// wait for updates or for the timer to trigger
for {
select {
case <-ctx.Done():
return
case due := <-rss.next:
delay := max(time.Until(due), 0)
// stop the existing timer and start a new one
timer.Stop()
timer = time.NewTimer(delay)
case <-timer.C:
tm := rss.now()
rss.lastRefresh.Store(&tm)
rss.refreshSession(rss.baseCtx, rss.sessionID)
}
}
}

View file

@ -0,0 +1,113 @@
package manager
import (
"context"
"sync"
"testing"
"time"
"github.com/stretchr/testify/assert"
"google.golang.org/protobuf/types/known/timestamppb"
"github.com/pomerium/pomerium/pkg/grpc/session"
)
func TestRefreshSessionScheduler(t *testing.T) {
t.Parallel()
var calls safeSlice[time.Time]
ctx := context.Background()
sessionRefreshGracePeriod := time.Millisecond
sessionRefreshCoolOffDuration := time.Millisecond
rss := newRefreshSessionScheduler(
ctx,
time.Now,
sessionRefreshGracePeriod,
sessionRefreshCoolOffDuration,
func(ctx context.Context, sesionID string) {
calls.Append(time.Now())
},
"S1",
)
t.Cleanup(rss.Stop)
rss.Update(&session.Session{ExpiresAt: timestamppb.Now()})
assert.Eventually(t, func() bool {
return calls.Len() == 1
}, 100*time.Millisecond, 10*time.Millisecond, "should trigger once")
rss.Update(&session.Session{ExpiresAt: timestamppb.Now()})
assert.Eventually(t, func() bool {
return calls.Len() == 2
}, 100*time.Millisecond, 10*time.Millisecond, "should trigger again")
}
func TestUpdateUserInfoScheduler(t *testing.T) {
t.Parallel()
var calls safeSlice[time.Time]
ctx := context.Background()
userUpdateInfoInterval := 100 * time.Millisecond
uuis := newUpdateUserInfoScheduler(ctx, userUpdateInfoInterval, func(ctx context.Context, userID string) {
calls.Append(time.Now())
}, "U1")
t.Cleanup(uuis.Stop)
// should eventually trigger
assert.Eventually(t, func() bool {
return calls.Len() == 1
}, 3*userUpdateInfoInterval, userUpdateInfoInterval/10, "should trigger once")
uuis.Reset()
uuis.Reset()
uuis.Reset()
assert.Eventually(t, func() bool {
return calls.Len() == 2
}, 3*userUpdateInfoInterval, userUpdateInfoInterval/10, "should trigger once after multiple resets")
var diff time.Duration
if calls.Len() >= 2 {
diff = calls.At(calls.Len() - 1).Sub(calls.At(calls.Len() - 2))
}
assert.GreaterOrEqual(t, diff, userUpdateInfoInterval, "delay should exceed interval")
uuis.Reset()
uuis.Stop()
time.Sleep(3 * userUpdateInfoInterval)
assert.Equal(t, 2, calls.Len(), "should not trigger again after stopping")
}
type safeSlice[T any] struct {
mu sync.Mutex
elements []T
}
func (s *safeSlice[T]) Append(elements ...T) {
s.mu.Lock()
s.elements = append(s.elements, elements...)
s.mu.Unlock()
}
func (s *safeSlice[T]) At(idx int) T {
var el T
s.mu.Lock()
if idx >= 0 && idx < len(s.elements) {
el = s.elements[idx]
}
s.mu.Unlock()
return el
}
func (s *safeSlice[T]) Len() int {
s.mu.Lock()
n := len(s.elements)
s.mu.Unlock()
return n
}

View file

@ -3,53 +3,77 @@ package manager
import (
"context"
"github.com/pomerium/pomerium/internal/atomicutil"
"github.com/pomerium/pomerium/internal/log"
"github.com/pomerium/pomerium/pkg/grpc/databroker"
"github.com/pomerium/pomerium/pkg/grpc/session"
"github.com/pomerium/pomerium/pkg/grpc/user"
"github.com/pomerium/pomerium/pkg/grpcutil"
)
type dataBrokerSyncer struct {
cfg *atomicutil.Value[*config]
update chan<- updateRecordsMessage
clear chan<- struct{}
syncer *databroker.Syncer
type sessionSyncerHandler struct {
baseCtx context.Context
mgr *Manager
}
func newDataBrokerSyncer(
_ context.Context,
cfg *atomicutil.Value[*config],
update chan<- updateRecordsMessage,
clear chan<- struct{},
) *dataBrokerSyncer {
syncer := &dataBrokerSyncer{
cfg: cfg,
func newSessionSyncer(ctx context.Context, mgr *Manager) *databroker.Syncer {
return databroker.NewSyncer("identity_manager/sessions", sessionSyncerHandler{baseCtx: ctx, mgr: mgr},
databroker.WithTypeURL(grpcutil.GetTypeURL(new(session.Session))))
}
update: update,
clear: clear,
func (h sessionSyncerHandler) ClearRecords(ctx context.Context) {
h.mgr.onDeleteAllSessions(ctx)
}
func (h sessionSyncerHandler) GetDataBrokerServiceClient() databroker.DataBrokerServiceClient {
return h.mgr.cfg.Load().dataBrokerClient
}
func (h sessionSyncerHandler) UpdateRecords(ctx context.Context, _ uint64, records []*databroker.Record) {
for _, record := range records {
if record.GetDeletedAt() != nil {
h.mgr.onDeleteSession(h.baseCtx, record.GetId())
} else {
var s session.Session
err := record.Data.UnmarshalTo(&s)
if err != nil {
log.Ctx(ctx).Warn().Err(err).Msg("invalid data in session record, ignoring")
} else {
h.mgr.onUpdateSession(h.baseCtx, &s)
}
}
syncer.syncer = databroker.NewSyncer("identity_manager", syncer)
return syncer
}
func (syncer *dataBrokerSyncer) Run(ctx context.Context) (err error) {
return syncer.syncer.Run(ctx)
}
func (syncer *dataBrokerSyncer) ClearRecords(ctx context.Context) {
select {
case <-ctx.Done():
case syncer.clear <- struct{}{}:
}
}
func (syncer *dataBrokerSyncer) GetDataBrokerServiceClient() databroker.DataBrokerServiceClient {
return syncer.cfg.Load().dataBrokerClient
type userSyncerHandler struct {
baseCtx context.Context
mgr *Manager
}
func (syncer *dataBrokerSyncer) UpdateRecords(ctx context.Context, _ uint64, records []*databroker.Record) {
select {
case <-ctx.Done():
case syncer.update <- updateRecordsMessage{records: records}:
func newUserSyncer(ctx context.Context, mgr *Manager) *databroker.Syncer {
return databroker.NewSyncer("identity_manager/users", userSyncerHandler{baseCtx: ctx, mgr: mgr},
databroker.WithTypeURL(grpcutil.GetTypeURL(new(user.User))))
}
func (h userSyncerHandler) ClearRecords(ctx context.Context) {
h.mgr.onDeleteAllUsers(ctx)
}
func (h userSyncerHandler) GetDataBrokerServiceClient() databroker.DataBrokerServiceClient {
return h.mgr.cfg.Load().dataBrokerClient
}
func (h userSyncerHandler) UpdateRecords(ctx context.Context, _ uint64, records []*databroker.Record) {
for _, record := range records {
if record.GetDeletedAt() != nil {
h.mgr.onDeleteUser(h.baseCtx, record.GetId())
} else {
var u user.User
err := record.Data.UnmarshalTo(&u)
if err != nil {
log.Ctx(ctx).Warn().Err(err).Msg("invalid data in user record, ignoring")
} else {
h.mgr.onUpdateUser(h.baseCtx, &u)
}
}
}
}