mirror of
https://github.com/pomerium/pomerium.git
synced 2025-05-02 20:06:03 +02:00
identity: rework session refresh error handling (#4638) Currently, if a temporary error occurs while attempting to refresh an OAuth2 token, the identity manager won't schedule another attempt. Instead, update the session refresh logic so that it will retry after temporary errors. Extract the bulk of this logic into a separate method that returns a boolean indicating whether to schedule another refresh. Update the unit test to simulate a temporary error during OAuth2 token refresh. Co-authored-by: Kenneth Jenkins <51246568+kenjenkins@users.noreply.github.com>
482 lines
13 KiB
Go
482 lines
13 KiB
Go
// Package manager contains an identity manager responsible for refreshing sessions and creating users.
|
|
package manager
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"time"
|
|
|
|
"github.com/google/btree"
|
|
"github.com/rs/zerolog"
|
|
"golang.org/x/oauth2"
|
|
"golang.org/x/sync/errgroup"
|
|
"google.golang.org/grpc/codes"
|
|
"google.golang.org/grpc/status"
|
|
"google.golang.org/protobuf/types/known/timestamppb"
|
|
|
|
"github.com/pomerium/pomerium/internal/atomicutil"
|
|
"github.com/pomerium/pomerium/internal/events"
|
|
"github.com/pomerium/pomerium/internal/identity/identity"
|
|
"github.com/pomerium/pomerium/internal/log"
|
|
"github.com/pomerium/pomerium/internal/scheduler"
|
|
"github.com/pomerium/pomerium/internal/telemetry/metrics"
|
|
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
|
"github.com/pomerium/pomerium/pkg/grpc/session"
|
|
"github.com/pomerium/pomerium/pkg/grpc/user"
|
|
"github.com/pomerium/pomerium/pkg/grpcutil"
|
|
metrics_ids "github.com/pomerium/pomerium/pkg/metrics"
|
|
)
|
|
|
|
// Authenticator is an identity.Provider with only the methods needed by the manager.
|
|
type Authenticator interface {
|
|
Refresh(context.Context, *oauth2.Token, identity.State) (*oauth2.Token, error)
|
|
Revoke(context.Context, *oauth2.Token) error
|
|
UpdateUserInfo(context.Context, *oauth2.Token, interface{}) error
|
|
}
|
|
|
|
type (
|
|
updateRecordsMessage struct {
|
|
records []*databroker.Record
|
|
}
|
|
)
|
|
|
|
// A Manager refreshes identity information using session and user data.
|
|
type Manager struct {
|
|
cfg *atomicutil.Value[*config]
|
|
|
|
sessionScheduler *scheduler.Scheduler
|
|
userScheduler *scheduler.Scheduler
|
|
|
|
sessions sessionCollection
|
|
users userCollection
|
|
}
|
|
|
|
// New creates a new identity manager.
|
|
func New(
|
|
options ...Option,
|
|
) *Manager {
|
|
mgr := &Manager{
|
|
cfg: atomicutil.NewValue(newConfig()),
|
|
|
|
sessionScheduler: scheduler.New(),
|
|
userScheduler: scheduler.New(),
|
|
}
|
|
mgr.reset()
|
|
mgr.UpdateConfig(options...)
|
|
return mgr
|
|
}
|
|
|
|
func withLog(ctx context.Context) context.Context {
|
|
return log.WithContext(ctx, func(c zerolog.Context) zerolog.Context {
|
|
return c.Str("service", "identity_manager")
|
|
})
|
|
}
|
|
|
|
// UpdateConfig updates the manager with the new options.
|
|
func (mgr *Manager) UpdateConfig(options ...Option) {
|
|
mgr.cfg.Store(newConfig(options...))
|
|
}
|
|
|
|
// Run runs the manager. This method blocks until an error occurs or the given context is canceled.
|
|
func (mgr *Manager) Run(ctx context.Context) error {
|
|
leaser := databroker.NewLeaser("identity_manager", time.Second*30, mgr)
|
|
return leaser.Run(ctx)
|
|
}
|
|
|
|
// RunLeased runs the identity manager when a lease is acquired.
|
|
func (mgr *Manager) RunLeased(ctx context.Context) error {
|
|
ctx = withLog(ctx)
|
|
update := make(chan updateRecordsMessage, 1)
|
|
clear := make(chan struct{}, 1)
|
|
|
|
syncer := newDataBrokerSyncer(ctx, mgr.cfg, update, clear)
|
|
|
|
eg, ctx := errgroup.WithContext(ctx)
|
|
eg.Go(func() error {
|
|
return syncer.Run(ctx)
|
|
})
|
|
eg.Go(func() error {
|
|
return mgr.refreshLoop(ctx, update, clear)
|
|
})
|
|
|
|
return eg.Wait()
|
|
}
|
|
|
|
// GetDataBrokerServiceClient gets the databroker client.
|
|
func (mgr *Manager) GetDataBrokerServiceClient() databroker.DataBrokerServiceClient {
|
|
return mgr.cfg.Load().dataBrokerClient
|
|
}
|
|
|
|
func (mgr *Manager) now() time.Time {
|
|
return mgr.cfg.Load().now()
|
|
}
|
|
|
|
func (mgr *Manager) refreshLoop(ctx context.Context, update <-chan updateRecordsMessage, clear <-chan struct{}) error {
|
|
// wait for initial sync
|
|
select {
|
|
case <-ctx.Done():
|
|
return ctx.Err()
|
|
case <-clear:
|
|
mgr.reset()
|
|
}
|
|
select {
|
|
case <-ctx.Done():
|
|
return ctx.Err()
|
|
case msg := <-update:
|
|
mgr.onUpdateRecords(ctx, msg)
|
|
}
|
|
|
|
log.Info(ctx).
|
|
Int("sessions", mgr.sessions.Len()).
|
|
Int("users", mgr.users.Len()).
|
|
Msg("initial sync complete")
|
|
|
|
// start refreshing
|
|
maxWait := time.Minute * 10
|
|
var nextTime time.Time
|
|
|
|
timer := time.NewTimer(0)
|
|
defer timer.Stop()
|
|
|
|
for {
|
|
select {
|
|
case <-ctx.Done():
|
|
return ctx.Err()
|
|
case <-clear:
|
|
mgr.reset()
|
|
case msg := <-update:
|
|
mgr.onUpdateRecords(ctx, msg)
|
|
case <-timer.C:
|
|
}
|
|
|
|
now := mgr.now()
|
|
nextTime = now.Add(maxWait)
|
|
|
|
// refresh sessions
|
|
for {
|
|
tm, key := mgr.sessionScheduler.Next()
|
|
if now.Before(tm) {
|
|
if tm.Before(nextTime) {
|
|
nextTime = tm
|
|
}
|
|
break
|
|
}
|
|
mgr.sessionScheduler.Remove(key)
|
|
|
|
userID, sessionID := fromSessionSchedulerKey(key)
|
|
mgr.refreshSession(ctx, userID, sessionID)
|
|
}
|
|
|
|
// refresh users
|
|
for {
|
|
tm, key := mgr.userScheduler.Next()
|
|
if now.Before(tm) {
|
|
if tm.Before(nextTime) {
|
|
nextTime = tm
|
|
}
|
|
break
|
|
}
|
|
mgr.userScheduler.Remove(key)
|
|
|
|
mgr.refreshUser(ctx, key)
|
|
}
|
|
|
|
metrics.RecordIdentityManagerLastRefresh(ctx)
|
|
timer.Reset(time.Until(nextTime))
|
|
}
|
|
}
|
|
|
|
// refreshSession handles two distinct session lifecycle events:
|
|
//
|
|
// 1. If the session itself has expired, delete the session.
|
|
// 2. If the session's underlying OAuth2 access token is nearing expiration
|
|
// (but the session itself is still valid), refresh the access token.
|
|
//
|
|
// After a successful access token refresh, this method will also trigger a
|
|
// user info refresh. If an access token refresh or a user info refresh fails
|
|
// with a permanent error, the session will be deleted.
|
|
func (mgr *Manager) refreshSession(ctx context.Context, userID, sessionID string) {
|
|
log.Info(ctx).
|
|
Str("user_id", userID).
|
|
Str("session_id", sessionID).
|
|
Msg("refreshing session")
|
|
|
|
s, ok := mgr.sessions.Get(userID, sessionID)
|
|
if !ok {
|
|
log.Warn(ctx).
|
|
Str("user_id", userID).
|
|
Str("session_id", sessionID).
|
|
Msg("no session found for refresh")
|
|
return
|
|
}
|
|
|
|
s.lastRefresh = mgr.now()
|
|
|
|
if mgr.refreshSessionInternal(ctx, userID, sessionID, &s) {
|
|
mgr.sessions.ReplaceOrInsert(s)
|
|
mgr.sessionScheduler.Add(s.NextRefresh(), toSessionSchedulerKey(userID, sessionID))
|
|
}
|
|
}
|
|
|
|
// refreshSessionInternal performs the core refresh logic and returns true if
|
|
// the session should be scheduled for refresh again, or false if not.
|
|
func (mgr *Manager) refreshSessionInternal(
|
|
ctx context.Context, userID, sessionID string, s *Session,
|
|
) bool {
|
|
authenticator := mgr.cfg.Load().authenticator
|
|
if authenticator == nil {
|
|
log.Info(ctx).
|
|
Str("user_id", userID).
|
|
Str("session_id", sessionID).
|
|
Msg("no authenticator defined, deleting session")
|
|
mgr.deleteSession(ctx, userID, sessionID)
|
|
return false
|
|
}
|
|
|
|
expiry := s.GetExpiresAt().AsTime()
|
|
if !expiry.After(mgr.now()) {
|
|
log.Info(ctx).
|
|
Str("user_id", userID).
|
|
Str("session_id", sessionID).
|
|
Msg("deleting expired session")
|
|
mgr.deleteSession(ctx, userID, sessionID)
|
|
return false
|
|
}
|
|
|
|
if s.Session == nil || s.Session.OauthToken == nil {
|
|
log.Warn(ctx).
|
|
Str("user_id", userID).
|
|
Str("session_id", sessionID).
|
|
Msg("no session oauth2 token found for refresh")
|
|
return false
|
|
}
|
|
|
|
newToken, err := authenticator.Refresh(ctx, FromOAuthToken(s.OauthToken), s)
|
|
metrics.RecordIdentityManagerSessionRefresh(ctx, err)
|
|
mgr.recordLastError(metrics_ids.IdentityManagerLastSessionRefreshError, err)
|
|
if isTemporaryError(err) {
|
|
log.Error(ctx).Err(err).
|
|
Str("user_id", s.GetUserId()).
|
|
Str("session_id", s.GetId()).
|
|
Msg("failed to refresh oauth2 token")
|
|
return true
|
|
} else if err != nil {
|
|
log.Error(ctx).Err(err).
|
|
Str("user_id", s.GetUserId()).
|
|
Str("session_id", s.GetId()).
|
|
Msg("failed to refresh oauth2 token, deleting session")
|
|
mgr.deleteSession(ctx, userID, sessionID)
|
|
return false
|
|
}
|
|
s.OauthToken = ToOAuthToken(newToken)
|
|
|
|
err = authenticator.UpdateUserInfo(ctx, FromOAuthToken(s.OauthToken), s)
|
|
metrics.RecordIdentityManagerUserRefresh(ctx, err)
|
|
mgr.recordLastError(metrics_ids.IdentityManagerLastUserRefreshError, err)
|
|
if isTemporaryError(err) {
|
|
log.Error(ctx).Err(err).
|
|
Str("user_id", s.GetUserId()).
|
|
Str("session_id", s.GetId()).
|
|
Msg("failed to update user info")
|
|
return true
|
|
} else if err != nil {
|
|
log.Error(ctx).Err(err).
|
|
Str("user_id", s.GetUserId()).
|
|
Str("session_id", s.GetId()).
|
|
Msg("failed to update user info, deleting session")
|
|
mgr.deleteSession(ctx, userID, sessionID)
|
|
return false
|
|
}
|
|
|
|
if _, err := session.Put(ctx, mgr.cfg.Load().dataBrokerClient, s.Session); err != nil {
|
|
log.Error(ctx).Err(err).
|
|
Str("user_id", s.GetUserId()).
|
|
Str("session_id", s.GetId()).
|
|
Msg("failed to update session")
|
|
}
|
|
return true
|
|
}
|
|
|
|
func (mgr *Manager) refreshUser(ctx context.Context, userID string) {
|
|
log.Info(ctx).
|
|
Str("user_id", userID).
|
|
Msg("refreshing user")
|
|
|
|
authenticator := mgr.cfg.Load().authenticator
|
|
if authenticator == nil {
|
|
return
|
|
}
|
|
|
|
u, ok := mgr.users.Get(userID)
|
|
if !ok {
|
|
log.Warn(ctx).
|
|
Str("user_id", userID).
|
|
Msg("no user found for refresh")
|
|
return
|
|
}
|
|
u.lastRefresh = mgr.now()
|
|
mgr.userScheduler.Add(u.NextRefresh(), u.GetId())
|
|
|
|
for _, s := range mgr.sessions.GetSessionsForUser(userID) {
|
|
if s.Session == nil || s.Session.OauthToken == nil {
|
|
log.Warn(ctx).
|
|
Str("user_id", userID).
|
|
Msg("no session oauth2 token found for refresh")
|
|
continue
|
|
}
|
|
|
|
err := authenticator.UpdateUserInfo(ctx, FromOAuthToken(s.OauthToken), &u)
|
|
metrics.RecordIdentityManagerUserRefresh(ctx, err)
|
|
mgr.recordLastError(metrics_ids.IdentityManagerLastUserRefreshError, err)
|
|
if isTemporaryError(err) {
|
|
log.Error(ctx).Err(err).
|
|
Str("user_id", s.GetUserId()).
|
|
Str("session_id", s.GetId()).
|
|
Msg("failed to update user info")
|
|
return
|
|
} else if err != nil {
|
|
log.Error(ctx).Err(err).
|
|
Str("user_id", s.GetUserId()).
|
|
Str("session_id", s.GetId()).
|
|
Msg("failed to update user info, deleting session")
|
|
mgr.deleteSession(ctx, userID, s.GetId())
|
|
continue
|
|
}
|
|
|
|
res, err := databroker.Put(ctx, mgr.cfg.Load().dataBrokerClient, u.User)
|
|
if err != nil {
|
|
log.Error(ctx).Err(err).
|
|
Str("user_id", s.GetUserId()).
|
|
Str("session_id", s.GetId()).
|
|
Msg("failed to update user")
|
|
continue
|
|
}
|
|
|
|
mgr.onUpdateUser(ctx, res.GetRecords()[0], u.User)
|
|
}
|
|
}
|
|
|
|
func (mgr *Manager) onUpdateRecords(ctx context.Context, msg updateRecordsMessage) {
|
|
for _, record := range msg.records {
|
|
switch record.GetType() {
|
|
case grpcutil.GetTypeURL(new(session.Session)):
|
|
var pbSession session.Session
|
|
err := record.GetData().UnmarshalTo(&pbSession)
|
|
if err != nil {
|
|
log.Warn(ctx).Msgf("error unmarshaling session: %s", err)
|
|
continue
|
|
}
|
|
mgr.onUpdateSession(record, &pbSession)
|
|
case grpcutil.GetTypeURL(new(user.User)):
|
|
var pbUser user.User
|
|
err := record.GetData().UnmarshalTo(&pbUser)
|
|
if err != nil {
|
|
log.Warn(ctx).Msgf("error unmarshaling user: %s", err)
|
|
continue
|
|
}
|
|
mgr.onUpdateUser(ctx, record, &pbUser)
|
|
}
|
|
}
|
|
}
|
|
|
|
func (mgr *Manager) onUpdateSession(record *databroker.Record, session *session.Session) {
|
|
mgr.sessionScheduler.Remove(toSessionSchedulerKey(session.GetUserId(), session.GetId()))
|
|
|
|
if record.GetDeletedAt() != nil {
|
|
mgr.sessions.Delete(session.GetUserId(), session.GetId())
|
|
return
|
|
}
|
|
|
|
// update session
|
|
s, _ := mgr.sessions.Get(session.GetUserId(), session.GetId())
|
|
if s.lastRefresh.IsZero() {
|
|
s.lastRefresh = mgr.now()
|
|
}
|
|
s.gracePeriod = mgr.cfg.Load().sessionRefreshGracePeriod
|
|
s.coolOffDuration = mgr.cfg.Load().sessionRefreshCoolOffDuration
|
|
s.Session = session
|
|
mgr.sessions.ReplaceOrInsert(s)
|
|
mgr.sessionScheduler.Add(s.NextRefresh(), toSessionSchedulerKey(session.GetUserId(), session.GetId()))
|
|
}
|
|
|
|
func (mgr *Manager) onUpdateUser(_ context.Context, record *databroker.Record, user *user.User) {
|
|
mgr.userScheduler.Remove(user.GetId())
|
|
|
|
if record.GetDeletedAt() != nil {
|
|
mgr.users.Delete(user.GetId())
|
|
return
|
|
}
|
|
|
|
u, _ := mgr.users.Get(user.GetId())
|
|
u.lastRefresh = mgr.cfg.Load().now()
|
|
u.User = user
|
|
mgr.users.ReplaceOrInsert(u)
|
|
mgr.userScheduler.Add(u.NextRefresh(), u.GetId())
|
|
}
|
|
|
|
func (mgr *Manager) deleteSession(ctx context.Context, userID, sessionID string) {
|
|
mgr.sessionScheduler.Remove(toSessionSchedulerKey(userID, sessionID))
|
|
mgr.sessions.Delete(userID, sessionID)
|
|
|
|
client := mgr.cfg.Load().dataBrokerClient
|
|
res, err := client.Get(ctx, &databroker.GetRequest{
|
|
Type: grpcutil.GetTypeURL(new(session.Session)),
|
|
Id: sessionID,
|
|
})
|
|
if status.Code(err) == codes.NotFound {
|
|
return
|
|
} else if err != nil {
|
|
log.Error(ctx).Err(err).
|
|
Str("session_id", sessionID).
|
|
Msg("failed to delete session")
|
|
return
|
|
}
|
|
|
|
record := res.GetRecord()
|
|
record.DeletedAt = timestamppb.Now()
|
|
|
|
_, err = client.Put(ctx, &databroker.PutRequest{
|
|
Records: []*databroker.Record{record},
|
|
})
|
|
if err != nil {
|
|
log.Error(ctx).Err(err).
|
|
Str("session_id", sessionID).
|
|
Msg("failed to delete session")
|
|
return
|
|
}
|
|
}
|
|
|
|
// reset resets all the manager datastructures to their initial state
|
|
func (mgr *Manager) reset() {
|
|
mgr.sessions = sessionCollection{BTree: btree.New(8)}
|
|
mgr.users = userCollection{BTree: btree.New(8)}
|
|
}
|
|
|
|
func (mgr *Manager) recordLastError(id string, err error) {
|
|
if err == nil {
|
|
return
|
|
}
|
|
evtMgr := mgr.cfg.Load().eventMgr
|
|
if evtMgr == nil {
|
|
return
|
|
}
|
|
evtMgr.Dispatch(&events.LastError{
|
|
Time: timestamppb.Now(),
|
|
Message: err.Error(),
|
|
Id: id,
|
|
})
|
|
}
|
|
|
|
// isTemporaryError reports whether err is worth retrying. That covers context
// deadline/cancellation errors and any error in the chain that implements
// Temporary() bool and returns true. A nil error is never temporary.
func isTemporaryError(err error) bool {
	switch {
	case err == nil:
		return false
	case errors.Is(err, context.DeadlineExceeded), errors.Is(err, context.Canceled):
		return true
	}

	var tmp interface{ Temporary() bool }
	return errors.As(err, &tmp) && tmp.Temporary()
}
|