mirror of
https://github.com/pomerium/pomerium.git
synced 2025-07-23 11:39:32 +02:00
wip
This commit is contained in:
parent
a6577fd570
commit
a80ef11763
11 changed files with 574 additions and 949 deletions
|
@ -196,16 +196,15 @@ func (s *Stateful) PersistSession(
|
|||
sess.SetRawIDToken(claims.RawIDToken)
|
||||
sess.AddClaims(claims.Flatten())
|
||||
|
||||
var managerUser manager.User
|
||||
managerUser.User, _ = user.Get(ctx, s.dataBrokerClient, sess.GetUserId())
|
||||
if managerUser.User == nil {
|
||||
u, _ := user.Get(ctx, s.dataBrokerClient, sess.GetUserId())
|
||||
if u == nil {
|
||||
// if no user exists yet, create a new one
|
||||
managerUser.User = &user.User{
|
||||
u = &user.User{
|
||||
Id: sess.GetUserId(),
|
||||
}
|
||||
}
|
||||
populateUserFromClaims(managerUser.User, claims.Claims)
|
||||
_, err := databroker.Put(ctx, s.dataBrokerClient, managerUser.User)
|
||||
populateUserFromClaims(u, claims.Claims)
|
||||
_, err := databroker.Put(ctx, s.dataBrokerClient, u)
|
||||
if err != nil {
|
||||
return fmt.Errorf("authenticate: error saving user: %w", err)
|
||||
}
|
||||
|
|
|
@ -10,6 +10,7 @@ import (
|
|||
var (
|
||||
defaultSessionRefreshGracePeriod = 1 * time.Minute
|
||||
defaultSessionRefreshCoolOffDuration = 10 * time.Second
|
||||
defaultUpdateUserInfoInterval = 10 * time.Minute
|
||||
)
|
||||
|
||||
type config struct {
|
||||
|
@ -17,6 +18,7 @@ type config struct {
|
|||
dataBrokerClient databroker.DataBrokerServiceClient
|
||||
sessionRefreshGracePeriod time.Duration
|
||||
sessionRefreshCoolOffDuration time.Duration
|
||||
updateUserInfoInterval time.Duration
|
||||
now func() time.Time
|
||||
eventMgr *events.Manager
|
||||
}
|
||||
|
@ -26,6 +28,7 @@ func newConfig(options ...Option) *config {
|
|||
WithSessionRefreshGracePeriod(defaultSessionRefreshGracePeriod)(cfg)
|
||||
WithSessionRefreshCoolOffDuration(defaultSessionRefreshCoolOffDuration)(cfg)
|
||||
WithNow(time.Now)(cfg)
|
||||
WithUpdateUserInfoInterval(defaultUpdateUserInfoInterval)(cfg)
|
||||
for _, option := range options {
|
||||
option(cfg)
|
||||
}
|
||||
|
@ -76,3 +79,10 @@ func WithEventManager(mgr *events.Manager) Option {
|
|||
c.eventMgr = mgr
|
||||
}
|
||||
}
|
||||
|
||||
// WithUpdateUserInfoInterval sets the update user info interval in the config.
|
||||
func WithUpdateUserInfoInterval(dur time.Duration) Option {
|
||||
return func(cfg *config) {
|
||||
cfg.updateUserInfoInterval = dur
|
||||
}
|
||||
}
|
||||
|
|
|
@ -2,9 +2,9 @@ package manager
|
|||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
"github.com/google/btree"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/identity"
|
||||
|
@ -12,66 +12,18 @@ import (
|
|||
"github.com/pomerium/pomerium/pkg/grpc/user"
|
||||
)
|
||||
|
||||
const userRefreshInterval = 10 * time.Minute
|
||||
|
||||
// A User is a user managed by the Manager.
|
||||
type User struct {
|
||||
*user.User
|
||||
lastRefresh time.Time
|
||||
}
|
||||
|
||||
// NextRefresh returns the next time the user information needs to be refreshed.
|
||||
func (u User) NextRefresh() time.Time {
|
||||
return u.lastRefresh.Add(userRefreshInterval)
|
||||
}
|
||||
|
||||
// UnmarshalJSON unmarshals json data into the user object.
|
||||
func (u *User) UnmarshalJSON(data []byte) error {
|
||||
if u.User == nil {
|
||||
u.User = new(user.User)
|
||||
}
|
||||
|
||||
var raw map[string]json.RawMessage
|
||||
err := json.Unmarshal(data, &raw)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if name, ok := raw["name"]; ok {
|
||||
_ = json.Unmarshal(name, &u.User.Name)
|
||||
delete(raw, "name")
|
||||
}
|
||||
if email, ok := raw["email"]; ok {
|
||||
_ = json.Unmarshal(email, &u.User.Email)
|
||||
delete(raw, "email")
|
||||
}
|
||||
|
||||
u.AddClaims(identity.NewClaimsFromRaw(raw).Flatten())
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// A Session is a session managed by the Manager.
|
||||
type Session struct {
|
||||
*session.Session
|
||||
// lastRefresh is the time of the last refresh attempt (which may or may
|
||||
// not have succeeded), or else the time the Manager first became aware of
|
||||
// the session (if it has not yet attempted to refresh this session).
|
||||
lastRefresh time.Time
|
||||
// gracePeriod is the amount of time before expiration to attempt a refresh.
|
||||
gracePeriod time.Duration
|
||||
// coolOffDuration is the amount of time to wait before attempting another refresh.
|
||||
coolOffDuration time.Duration
|
||||
}
|
||||
|
||||
// NextRefresh returns the next time the session needs to be refreshed.
|
||||
func (s Session) NextRefresh() time.Time {
|
||||
func nextSessionRefresh(
|
||||
s *session.Session,
|
||||
lastRefresh time.Time,
|
||||
gracePeriod time.Duration,
|
||||
coolOffDuration time.Duration,
|
||||
) time.Time {
|
||||
var tm time.Time
|
||||
|
||||
if s.GetOauthToken().GetExpiresAt() != nil {
|
||||
expiry := s.GetOauthToken().GetExpiresAt().AsTime()
|
||||
if s.GetOauthToken().GetExpiresAt().IsValid() && !expiry.IsZero() {
|
||||
expiry = expiry.Add(-s.gracePeriod)
|
||||
expiry = expiry.Add(-gracePeriod)
|
||||
if tm.IsZero() || expiry.Before(tm) {
|
||||
tm = expiry
|
||||
}
|
||||
|
@ -88,7 +40,7 @@ func (s Session) NextRefresh() time.Time {
|
|||
}
|
||||
|
||||
// don't refresh any quicker than the cool-off duration
|
||||
min := s.lastRefresh.Add(s.coolOffDuration)
|
||||
min := lastRefresh.Add(coolOffDuration)
|
||||
if tm.Before(min) {
|
||||
tm = min
|
||||
}
|
||||
|
@ -96,10 +48,33 @@ func (s Session) NextRefresh() time.Time {
|
|||
return tm
|
||||
}
|
||||
|
||||
// UnmarshalJSON unmarshals json data into the session object.
|
||||
func (s *Session) UnmarshalJSON(data []byte) error {
|
||||
if s.Session == nil {
|
||||
s.Session = new(session.Session)
|
||||
type multiUnmarshaler []any
|
||||
|
||||
func newMultiUnmarshaler(args ...any) *multiUnmarshaler {
|
||||
return (*multiUnmarshaler)(&args)
|
||||
}
|
||||
|
||||
func (dst *multiUnmarshaler) UnmarshalJSON(data []byte) error {
|
||||
var err error
|
||||
for _, o := range *dst {
|
||||
if o != nil {
|
||||
err = errors.Join(err, json.Unmarshal(data, o))
|
||||
}
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
type sessionUnmarshaler struct {
|
||||
*session.Session
|
||||
}
|
||||
|
||||
func newSessionUnmarshaler(s *session.Session) *sessionUnmarshaler {
|
||||
return &sessionUnmarshaler{Session: s}
|
||||
}
|
||||
|
||||
func (dst *sessionUnmarshaler) UnmarshalJSON(data []byte) error {
|
||||
if dst.Session == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
var raw map[string]json.RawMessage
|
||||
|
@ -108,159 +83,67 @@ func (s *Session) UnmarshalJSON(data []byte) error {
|
|||
return err
|
||||
}
|
||||
|
||||
if s.Session.IdToken == nil {
|
||||
s.Session.IdToken = new(session.IDToken)
|
||||
if dst.Session.IdToken == nil {
|
||||
dst.Session.IdToken = new(session.IDToken)
|
||||
}
|
||||
|
||||
if iss, ok := raw["iss"]; ok {
|
||||
_ = json.Unmarshal(iss, &s.Session.IdToken.Issuer)
|
||||
_ = json.Unmarshal(iss, &dst.Session.IdToken.Issuer)
|
||||
delete(raw, "iss")
|
||||
}
|
||||
if sub, ok := raw["sub"]; ok {
|
||||
_ = json.Unmarshal(sub, &s.Session.IdToken.Subject)
|
||||
_ = json.Unmarshal(sub, &dst.Session.IdToken.Subject)
|
||||
delete(raw, "sub")
|
||||
}
|
||||
if exp, ok := raw["exp"]; ok {
|
||||
var secs int64
|
||||
if err := json.Unmarshal(exp, &secs); err == nil {
|
||||
s.Session.IdToken.ExpiresAt = timestamppb.New(time.Unix(secs, 0))
|
||||
dst.Session.IdToken.ExpiresAt = timestamppb.New(time.Unix(secs, 0))
|
||||
}
|
||||
delete(raw, "exp")
|
||||
}
|
||||
if iat, ok := raw["iat"]; ok {
|
||||
var secs int64
|
||||
if err := json.Unmarshal(iat, &secs); err == nil {
|
||||
s.Session.IdToken.IssuedAt = timestamppb.New(time.Unix(secs, 0))
|
||||
dst.Session.IdToken.IssuedAt = timestamppb.New(time.Unix(secs, 0))
|
||||
}
|
||||
delete(raw, "iat")
|
||||
}
|
||||
|
||||
s.AddClaims(identity.NewClaimsFromRaw(raw).Flatten())
|
||||
dst.Session.AddClaims(identity.NewClaimsFromRaw(raw).Flatten())
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
type sessionCollectionItem struct {
|
||||
Session
|
||||
type userUnmarshaler struct {
|
||||
*user.User
|
||||
}
|
||||
|
||||
func (item sessionCollectionItem) Less(than btree.Item) bool {
|
||||
xUserID, yUserID := item.GetUserId(), than.(sessionCollectionItem).GetUserId()
|
||||
switch {
|
||||
case xUserID < yUserID:
|
||||
return true
|
||||
case yUserID < xUserID:
|
||||
return false
|
||||
func newUserUnmarshaler(u *user.User) *userUnmarshaler {
|
||||
return &userUnmarshaler{User: u}
|
||||
}
|
||||
|
||||
func (dst *userUnmarshaler) UnmarshalJSON(data []byte) error {
|
||||
if dst.User == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
xID, yID := item.GetId(), than.(sessionCollectionItem).GetId()
|
||||
switch {
|
||||
case xID < yID:
|
||||
return true
|
||||
case yID < xID:
|
||||
return false
|
||||
var raw map[string]json.RawMessage
|
||||
err := json.Unmarshal(data, &raw)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
type sessionCollection struct {
|
||||
*btree.BTree
|
||||
}
|
||||
|
||||
func (c *sessionCollection) Delete(userID, sessionID string) {
|
||||
c.BTree.Delete(sessionCollectionItem{
|
||||
Session: Session{
|
||||
Session: &session.Session{
|
||||
UserId: userID,
|
||||
Id: sessionID,
|
||||
},
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
func (c *sessionCollection) Get(userID, sessionID string) (Session, bool) {
|
||||
item := c.BTree.Get(sessionCollectionItem{
|
||||
Session: Session{
|
||||
Session: &session.Session{
|
||||
UserId: userID,
|
||||
Id: sessionID,
|
||||
},
|
||||
},
|
||||
})
|
||||
if item == nil {
|
||||
return Session{}, false
|
||||
if name, ok := raw["name"]; ok {
|
||||
_ = json.Unmarshal(name, &dst.User.Name)
|
||||
delete(raw, "name")
|
||||
}
|
||||
return item.(sessionCollectionItem).Session, true
|
||||
}
|
||||
|
||||
// GetSessionsForUser gets all the sessions for the given user.
|
||||
func (c *sessionCollection) GetSessionsForUser(userID string) []Session {
|
||||
var sessions []Session
|
||||
c.AscendGreaterOrEqual(sessionCollectionItem{
|
||||
Session: Session{
|
||||
Session: &session.Session{
|
||||
UserId: userID,
|
||||
},
|
||||
},
|
||||
}, func(item btree.Item) bool {
|
||||
s := item.(sessionCollectionItem).Session
|
||||
if s.UserId != userID {
|
||||
return false
|
||||
}
|
||||
|
||||
sessions = append(sessions, s)
|
||||
return true
|
||||
})
|
||||
return sessions
|
||||
}
|
||||
|
||||
func (c *sessionCollection) ReplaceOrInsert(s Session) {
|
||||
c.BTree.ReplaceOrInsert(sessionCollectionItem{Session: s})
|
||||
}
|
||||
|
||||
type userCollectionItem struct {
|
||||
User
|
||||
}
|
||||
|
||||
func (item userCollectionItem) Less(than btree.Item) bool {
|
||||
xID, yID := item.GetId(), than.(userCollectionItem).GetId()
|
||||
switch {
|
||||
case xID < yID:
|
||||
return true
|
||||
case yID < xID:
|
||||
return false
|
||||
if email, ok := raw["email"]; ok {
|
||||
_ = json.Unmarshal(email, &dst.User.Email)
|
||||
delete(raw, "email")
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
type userCollection struct {
|
||||
*btree.BTree
|
||||
}
|
||||
dst.User.AddClaims(identity.NewClaimsFromRaw(raw).Flatten())
|
||||
|
||||
func (c *userCollection) Delete(userID string) {
|
||||
c.BTree.Delete(userCollectionItem{
|
||||
User: User{
|
||||
User: &user.User{
|
||||
Id: userID,
|
||||
},
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
func (c *userCollection) Get(userID string) (User, bool) {
|
||||
item := c.BTree.Get(userCollectionItem{
|
||||
User: User{
|
||||
User: &user.User{
|
||||
Id: userID,
|
||||
},
|
||||
},
|
||||
})
|
||||
if item == nil {
|
||||
return User{}, false
|
||||
}
|
||||
return item.(userCollectionItem).User, true
|
||||
}
|
||||
|
||||
func (c *userCollection) ReplaceOrInsert(u User) {
|
||||
c.BTree.ReplaceOrInsert(userCollectionItem{User: u})
|
||||
return nil
|
||||
}
|
||||
|
|
|
@ -11,20 +11,20 @@ import (
|
|||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
|
||||
"github.com/pomerium/pomerium/pkg/grpc/session"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/user"
|
||||
"github.com/pomerium/pomerium/pkg/protoutil"
|
||||
)
|
||||
|
||||
func TestUser_UnmarshalJSON(t *testing.T) {
|
||||
var u User
|
||||
u := new(user.User)
|
||||
err := json.Unmarshal([]byte(`{
|
||||
"name": "joe",
|
||||
"email": "joe@test.com",
|
||||
"some-other-claim": "xyz"
|
||||
}`), &u)
|
||||
}`), newUserUnmarshaler(u))
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, u.User)
|
||||
assert.Equal(t, "joe", u.User.Name)
|
||||
assert.Equal(t, "joe@test.com", u.User.Email)
|
||||
assert.Equal(t, "joe", u.Name)
|
||||
assert.Equal(t, "joe@test.com", u.Email)
|
||||
assert.Equal(t, map[string]*structpb.ListValue{
|
||||
"some-other-claim": {Values: []*structpb.Value{protoutil.ToStruct("xyz")}},
|
||||
}, u.Claims)
|
||||
|
@ -32,42 +32,38 @@ func TestUser_UnmarshalJSON(t *testing.T) {
|
|||
|
||||
func TestSession_NextRefresh(t *testing.T) {
|
||||
tm1 := time.Date(2020, 6, 5, 12, 0, 0, 0, time.UTC)
|
||||
s := Session{
|
||||
Session: &session.Session{},
|
||||
lastRefresh: tm1,
|
||||
gracePeriod: time.Second * 10,
|
||||
coolOffDuration: time.Minute,
|
||||
}
|
||||
assert.Equal(t, tm1.Add(time.Minute), s.NextRefresh())
|
||||
s := &session.Session{}
|
||||
gracePeriod := time.Second * 10
|
||||
coolOffDuration := time.Minute
|
||||
assert.Equal(t, tm1.Add(time.Minute), nextSessionRefresh(s, tm1, gracePeriod, coolOffDuration))
|
||||
|
||||
tm2 := time.Date(2020, 6, 5, 13, 0, 0, 0, time.UTC)
|
||||
s.OauthToken = &session.OAuthToken{
|
||||
ExpiresAt: timestamppb.New(tm2),
|
||||
}
|
||||
assert.Equal(t, tm2.Add(-time.Second*10), s.NextRefresh())
|
||||
assert.Equal(t, tm2.Add(-time.Second*10), nextSessionRefresh(s, tm1, gracePeriod, coolOffDuration))
|
||||
|
||||
tm3 := time.Date(2020, 6, 5, 12, 15, 0, 0, time.UTC)
|
||||
s.ExpiresAt = timestamppb.New(tm3)
|
||||
assert.Equal(t, tm3, s.NextRefresh())
|
||||
assert.Equal(t, tm3, nextSessionRefresh(s, tm1, gracePeriod, coolOffDuration))
|
||||
}
|
||||
|
||||
func TestSession_UnmarshalJSON(t *testing.T) {
|
||||
tm := time.Date(2020, 6, 5, 12, 0, 0, 0, time.UTC)
|
||||
var s Session
|
||||
s := new(session.Session)
|
||||
err := json.Unmarshal([]byte(`{
|
||||
"iss": "https://some.issuer.com",
|
||||
"sub": "subject",
|
||||
"exp": `+fmt.Sprint(tm.Unix())+`,
|
||||
"iat": `+fmt.Sprint(tm.Unix())+`,
|
||||
"some-other-claim": "xyz"
|
||||
}`), &s)
|
||||
}`), newSessionUnmarshaler(s))
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, s.Session)
|
||||
assert.NotNil(t, s.Session.IdToken)
|
||||
assert.Equal(t, "https://some.issuer.com", s.Session.IdToken.Issuer)
|
||||
assert.Equal(t, "subject", s.Session.IdToken.Subject)
|
||||
assert.Equal(t, timestamppb.New(tm), s.Session.IdToken.ExpiresAt)
|
||||
assert.Equal(t, timestamppb.New(tm), s.Session.IdToken.IssuedAt)
|
||||
assert.NotNil(t, s.IdToken)
|
||||
assert.Equal(t, "https://some.issuer.com", s.IdToken.Issuer)
|
||||
assert.Equal(t, "subject", s.IdToken.Subject)
|
||||
assert.Equal(t, timestamppb.New(tm), s.IdToken.ExpiresAt)
|
||||
assert.Equal(t, timestamppb.New(tm), s.IdToken.IssuedAt)
|
||||
assert.Equal(t, map[string]*structpb.ListValue{
|
||||
"some-other-claim": {Values: []*structpb.Value{protoutil.ToStruct("xyz")}},
|
||||
}, s.Claims)
|
||||
|
|
|
@ -3,7 +3,6 @@ package manager
|
|||
import (
|
||||
"cmp"
|
||||
"slices"
|
||||
"sync"
|
||||
|
||||
"google.golang.org/protobuf/proto"
|
||||
|
||||
|
@ -11,44 +10,54 @@ import (
|
|||
"github.com/pomerium/pomerium/pkg/grpc/user"
|
||||
)
|
||||
|
||||
// dataStore stores session and user data. All public methods are thread-safe.
|
||||
// dataStore stores session and user data
|
||||
type dataStore struct {
|
||||
mu sync.Mutex
|
||||
sessions map[string]*session.Session
|
||||
users map[string]*user.User
|
||||
userIDToSessionIDs map[string]map[string]struct{}
|
||||
}
|
||||
|
||||
func newDataStore() *dataStore {
|
||||
return &dataStore{
|
||||
sessions: make(map[string]*session.Session),
|
||||
users: make(map[string]*user.User),
|
||||
userIDToSessionIDs: make(map[string]map[string]struct{}),
|
||||
ds := new(dataStore)
|
||||
ds.deleteAllSessions()
|
||||
ds.deleteAllUsers()
|
||||
return ds
|
||||
}
|
||||
|
||||
func (ds *dataStore) deleteAllSessions() {
|
||||
ds.sessions = make(map[string]*session.Session)
|
||||
ds.userIDToSessionIDs = make(map[string]map[string]struct{})
|
||||
}
|
||||
|
||||
func (ds *dataStore) deleteAllUsers() {
|
||||
ds.users = make(map[string]*user.User)
|
||||
}
|
||||
|
||||
func (ds *dataStore) deleteSession(sessionID string) {
|
||||
s := ds.sessions[sessionID]
|
||||
delete(ds.sessions, sessionID)
|
||||
if s.GetUserId() == "" {
|
||||
return
|
||||
}
|
||||
|
||||
m := ds.userIDToSessionIDs[s.GetUserId()]
|
||||
if m != nil {
|
||||
delete(m, s.GetId())
|
||||
}
|
||||
if len(m) == 0 {
|
||||
delete(ds.userIDToSessionIDs, s.GetUserId())
|
||||
}
|
||||
}
|
||||
|
||||
// DeleteSession deletes a session.
|
||||
func (ds *dataStore) DeleteSession(sID string) {
|
||||
ds.mu.Lock()
|
||||
ds.deleteSessionLocked(sID)
|
||||
ds.mu.Unlock()
|
||||
}
|
||||
|
||||
// DeleteUser deletes a user.
|
||||
func (ds *dataStore) DeleteUser(userID string) {
|
||||
ds.mu.Lock()
|
||||
func (ds *dataStore) deleteUser(userID string) {
|
||||
delete(ds.users, userID)
|
||||
ds.mu.Unlock()
|
||||
}
|
||||
|
||||
// GetSessionAndUser gets a session and its associated user.
|
||||
func (ds *dataStore) GetSessionAndUser(sessionID string) (s *session.Session, u *user.User) {
|
||||
ds.mu.Lock()
|
||||
func (ds *dataStore) getSessionAndUser(sessionID string) (s *session.Session, u *user.User) {
|
||||
s = ds.sessions[sessionID]
|
||||
if s.GetUserId() != "" {
|
||||
u = ds.users[s.GetUserId()]
|
||||
}
|
||||
ds.mu.Unlock()
|
||||
|
||||
// clone to avoid sharing memory
|
||||
s = clone(s)
|
||||
|
@ -56,14 +65,11 @@ func (ds *dataStore) GetSessionAndUser(sessionID string) (s *session.Session, u
|
|||
return s, u
|
||||
}
|
||||
|
||||
// GetUserAndSessions gets a user and all of its associated sessions.
|
||||
func (ds *dataStore) GetUserAndSessions(userID string) (u *user.User, ss []*session.Session) {
|
||||
ds.mu.Lock()
|
||||
func (ds *dataStore) getUserAndSessions(userID string) (u *user.User, ss []*session.Session) {
|
||||
u = ds.users[userID]
|
||||
for sessionID := range ds.userIDToSessionIDs[userID] {
|
||||
ss = append(ss, ds.sessions[sessionID])
|
||||
}
|
||||
ds.mu.Unlock()
|
||||
|
||||
// remove nils and sort by id
|
||||
ss = slices.Compact(ss)
|
||||
|
@ -79,14 +85,12 @@ func (ds *dataStore) GetUserAndSessions(userID string) (u *user.User, ss []*sess
|
|||
return u, ss
|
||||
}
|
||||
|
||||
// PutSession stores the session.
|
||||
func (ds *dataStore) PutSession(s *session.Session) {
|
||||
func (ds *dataStore) putSession(s *session.Session) {
|
||||
// clone to avoid sharing memory
|
||||
s = clone(s)
|
||||
|
||||
ds.mu.Lock()
|
||||
if s.GetId() != "" {
|
||||
ds.deleteSessionLocked(s.GetId())
|
||||
ds.deleteSession(s.GetId())
|
||||
ds.sessions[s.GetId()] = s
|
||||
if s.GetUserId() != "" {
|
||||
m, ok := ds.userIDToSessionIDs[s.GetUserId()]
|
||||
|
@ -97,35 +101,15 @@ func (ds *dataStore) PutSession(s *session.Session) {
|
|||
m[s.GetId()] = struct{}{}
|
||||
}
|
||||
}
|
||||
ds.mu.Unlock()
|
||||
}
|
||||
|
||||
// PutUser stores the user.
|
||||
func (ds *dataStore) PutUser(u *user.User) {
|
||||
func (ds *dataStore) putUser(u *user.User) {
|
||||
// clone to avoid sharing memory
|
||||
u = clone(u)
|
||||
|
||||
ds.mu.Lock()
|
||||
if u.GetId() != "" {
|
||||
ds.users[u.GetId()] = u
|
||||
}
|
||||
ds.mu.Unlock()
|
||||
}
|
||||
|
||||
func (ds *dataStore) deleteSessionLocked(sID string) {
|
||||
s := ds.sessions[sID]
|
||||
delete(ds.sessions, sID)
|
||||
if s.GetUserId() == "" {
|
||||
return
|
||||
}
|
||||
|
||||
m := ds.userIDToSessionIDs[s.GetUserId()]
|
||||
if m != nil {
|
||||
delete(m, s.GetId())
|
||||
}
|
||||
if len(m) == 0 {
|
||||
delete(ds.userIDToSessionIDs, s.GetUserId())
|
||||
}
|
||||
}
|
||||
|
||||
// clone clones a protobuf message
|
||||
|
|
|
@ -15,32 +15,32 @@ func TestDataStore(t *testing.T) {
|
|||
t.Parallel()
|
||||
|
||||
ds := newDataStore()
|
||||
s, u := ds.GetSessionAndUser("S1")
|
||||
s, u := ds.getSessionAndUser("S1")
|
||||
assert.Nil(t, s, "should return a nil session when none exists")
|
||||
assert.Nil(t, u, "should return a nil user when none exists")
|
||||
|
||||
u, ss := ds.GetUserAndSessions("U1")
|
||||
u, ss := ds.getUserAndSessions("U1")
|
||||
assert.Nil(t, u, "should return a nil user when none exists")
|
||||
assert.Empty(t, ss, "should return an empty list of sessions when no user exists")
|
||||
|
||||
s = &session.Session{Id: "S1", UserId: "U1"}
|
||||
ds.PutSession(s)
|
||||
s1, u1 := ds.GetSessionAndUser("S1")
|
||||
ds.putSession(s)
|
||||
s1, u1 := ds.getSessionAndUser("S1")
|
||||
assert.NotNil(t, s1, "should return a non-nil session")
|
||||
assert.False(t, s == s1, "should return different pointers")
|
||||
assert.Empty(t, cmp.Diff(s, s1, protocmp.Transform()), "should be the same as was entered")
|
||||
assert.Nil(t, u1, "should return a nil user when only the session exists")
|
||||
|
||||
ds.PutUser(&user.User{
|
||||
ds.putUser(&user.User{
|
||||
Id: "U1",
|
||||
})
|
||||
_, u1 = ds.GetSessionAndUser("S1")
|
||||
_, u1 = ds.getSessionAndUser("S1")
|
||||
assert.NotNil(t, u1, "should return a user now that it has been added")
|
||||
|
||||
ds.PutSession(&session.Session{Id: "S4", UserId: "U1"})
|
||||
ds.PutSession(&session.Session{Id: "S3", UserId: "U1"})
|
||||
ds.PutSession(&session.Session{Id: "S2", UserId: "U1"})
|
||||
u, ss = ds.GetUserAndSessions("U1")
|
||||
ds.putSession(&session.Session{Id: "S4", UserId: "U1"})
|
||||
ds.putSession(&session.Session{Id: "S3", UserId: "U1"})
|
||||
ds.putSession(&session.Session{Id: "S2", UserId: "U1"})
|
||||
u, ss = ds.getUserAndSessions("U1")
|
||||
assert.NotNil(t, u)
|
||||
assert.Empty(t, cmp.Diff(ss, []*session.Session{
|
||||
{Id: "S1", UserId: "U1"},
|
||||
|
@ -49,9 +49,9 @@ func TestDataStore(t *testing.T) {
|
|||
{Id: "S4", UserId: "U1"},
|
||||
}, protocmp.Transform()), "should return all sessions in id order")
|
||||
|
||||
ds.DeleteSession("S4")
|
||||
ds.deleteSession("S4")
|
||||
|
||||
u, ss = ds.GetUserAndSessions("U1")
|
||||
u, ss = ds.getUserAndSessions("U1")
|
||||
assert.NotNil(t, u)
|
||||
assert.Empty(t, cmp.Diff(ss, []*session.Session{
|
||||
{Id: "S1", UserId: "U1"},
|
||||
|
@ -59,8 +59,8 @@ func TestDataStore(t *testing.T) {
|
|||
{Id: "S3", UserId: "U1"},
|
||||
}, protocmp.Transform()), "should return all sessions in id order")
|
||||
|
||||
ds.DeleteUser("U1")
|
||||
u, ss = ds.GetUserAndSessions("U1")
|
||||
ds.deleteUser("U1")
|
||||
u, ss = ds.getUserAndSessions("U1")
|
||||
assert.Nil(t, u)
|
||||
assert.Empty(t, cmp.Diff(ss, []*session.Session{
|
||||
{Id: "S1", UserId: "U1"},
|
||||
|
@ -68,11 +68,11 @@ func TestDataStore(t *testing.T) {
|
|||
{Id: "S3", UserId: "U1"},
|
||||
}, protocmp.Transform()), "should still return all sessions in id order")
|
||||
|
||||
ds.DeleteSession("S1")
|
||||
ds.DeleteSession("S2")
|
||||
ds.DeleteSession("S3")
|
||||
ds.deleteSession("S1")
|
||||
ds.deleteSession("S2")
|
||||
ds.deleteSession("S3")
|
||||
|
||||
u, ss = ds.GetUserAndSessions("U1")
|
||||
u, ss = ds.getUserAndSessions("U1")
|
||||
assert.Nil(t, u)
|
||||
assert.Empty(t, ss)
|
||||
}
|
||||
|
|
|
@ -3,10 +3,9 @@ package manager
|
|||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/google/btree"
|
||||
"github.com/rs/zerolog"
|
||||
"golang.org/x/oauth2"
|
||||
"golang.org/x/sync/errgroup"
|
||||
|
@ -19,7 +18,6 @@ import (
|
|||
"github.com/pomerium/pomerium/internal/events"
|
||||
"github.com/pomerium/pomerium/internal/identity/identity"
|
||||
"github.com/pomerium/pomerium/internal/log"
|
||||
"github.com/pomerium/pomerium/internal/scheduler"
|
||||
"github.com/pomerium/pomerium/internal/telemetry/metrics"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/session"
|
||||
|
@ -35,21 +33,14 @@ type Authenticator interface {
|
|||
UpdateUserInfo(context.Context, *oauth2.Token, interface{}) error
|
||||
}
|
||||
|
||||
type (
|
||||
updateRecordsMessage struct {
|
||||
records []*databroker.Record
|
||||
}
|
||||
)
|
||||
|
||||
// A Manager refreshes identity information using session and user data.
|
||||
type Manager struct {
|
||||
cfg *atomicutil.Value[*config]
|
||||
|
||||
sessionScheduler *scheduler.Scheduler
|
||||
userScheduler *scheduler.Scheduler
|
||||
|
||||
sessions sessionCollection
|
||||
users userCollection
|
||||
mu sync.Mutex
|
||||
dataStore *dataStore
|
||||
refreshSessionSchedulers map[string]*refreshSessionScheduler
|
||||
updateUserInfoSchedulers map[string]*updateUserInfoScheduler
|
||||
}
|
||||
|
||||
// New creates a new identity manager.
|
||||
|
@ -59,25 +50,24 @@ func New(
|
|||
mgr := &Manager{
|
||||
cfg: atomicutil.NewValue(newConfig()),
|
||||
|
||||
sessionScheduler: scheduler.New(),
|
||||
userScheduler: scheduler.New(),
|
||||
dataStore: newDataStore(),
|
||||
refreshSessionSchedulers: make(map[string]*refreshSessionScheduler),
|
||||
updateUserInfoSchedulers: make(map[string]*updateUserInfoScheduler),
|
||||
}
|
||||
mgr.reset()
|
||||
mgr.UpdateConfig(options...)
|
||||
return mgr
|
||||
}
|
||||
|
||||
func withLog(ctx context.Context) context.Context {
|
||||
return log.WithContext(ctx, func(c zerolog.Context) zerolog.Context {
|
||||
return c.Str("service", "identity_manager")
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateConfig updates the manager with the new options.
|
||||
func (mgr *Manager) UpdateConfig(options ...Option) {
|
||||
mgr.cfg.Store(newConfig(options...))
|
||||
}
|
||||
|
||||
// GetDataBrokerServiceClient gets the databroker client.
|
||||
func (mgr *Manager) GetDataBrokerServiceClient() databroker.DataBrokerServiceClient {
|
||||
return mgr.cfg.Load().dataBrokerClient
|
||||
}
|
||||
|
||||
// Run runs the manager. This method blocks until an error occurs or the given context is canceled.
|
||||
func (mgr *Manager) Run(ctx context.Context) error {
|
||||
leaser := databroker.NewLeaser("identity_manager", time.Second*30, mgr)
|
||||
|
@ -86,173 +76,144 @@ func (mgr *Manager) Run(ctx context.Context) error {
|
|||
|
||||
// RunLeased runs the identity manager when a lease is acquired.
|
||||
func (mgr *Manager) RunLeased(ctx context.Context) error {
|
||||
ctx = withLog(ctx)
|
||||
update := make(chan updateRecordsMessage, 1)
|
||||
clear := make(chan struct{}, 1)
|
||||
|
||||
syncer := newDataBrokerSyncer(ctx, mgr.cfg, update, clear)
|
||||
|
||||
ctx = log.WithContext(ctx, func(c zerolog.Context) zerolog.Context {
|
||||
return c.Str("service", "identity_manager")
|
||||
})
|
||||
eg, ctx := errgroup.WithContext(ctx)
|
||||
eg.Go(func() error {
|
||||
return syncer.Run(ctx)
|
||||
sessionSyncer := newSessionSyncer(mgr)
|
||||
defer sessionSyncer.Close()
|
||||
return sessionSyncer.Run(ctx)
|
||||
})
|
||||
eg.Go(func() error {
|
||||
return mgr.refreshLoop(ctx, update, clear)
|
||||
userSyncer := newUserSyncer(mgr)
|
||||
defer userSyncer.Close()
|
||||
return userSyncer.Run(ctx)
|
||||
})
|
||||
|
||||
return eg.Wait()
|
||||
}
|
||||
|
||||
// GetDataBrokerServiceClient gets the databroker client.
|
||||
func (mgr *Manager) GetDataBrokerServiceClient() databroker.DataBrokerServiceClient {
|
||||
return mgr.cfg.Load().dataBrokerClient
|
||||
func (mgr *Manager) onDeleteAllSessions(ctx context.Context) {
|
||||
log.Ctx(ctx).Debug().Msg("all session deleted")
|
||||
|
||||
mgr.mu.Lock()
|
||||
mgr.dataStore.deleteAllSessions()
|
||||
for sID, rss := range mgr.refreshSessionSchedulers {
|
||||
rss.Stop()
|
||||
delete(mgr.refreshSessionSchedulers, sID)
|
||||
}
|
||||
mgr.mu.Unlock()
|
||||
}
|
||||
|
||||
func (mgr *Manager) now() time.Time {
|
||||
return mgr.cfg.Load().now()
|
||||
func (mgr *Manager) onDeleteAllUsers(ctx context.Context) {
|
||||
log.Ctx(ctx).Debug().Msg("all users deleted")
|
||||
|
||||
mgr.mu.Lock()
|
||||
mgr.dataStore.deleteAllUsers()
|
||||
for uID, uuis := range mgr.updateUserInfoSchedulers {
|
||||
uuis.Stop()
|
||||
delete(mgr.updateUserInfoSchedulers, uID)
|
||||
}
|
||||
mgr.mu.Unlock()
|
||||
}
|
||||
|
||||
func (mgr *Manager) refreshLoop(ctx context.Context, update <-chan updateRecordsMessage, clear <-chan struct{}) error {
|
||||
// wait for initial sync
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case <-clear:
|
||||
mgr.reset()
|
||||
}
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case msg := <-update:
|
||||
mgr.onUpdateRecords(ctx, msg)
|
||||
}
|
||||
|
||||
log.Debug(ctx).
|
||||
Int("sessions", mgr.sessions.Len()).
|
||||
Int("users", mgr.users.Len()).
|
||||
Msg("initial sync complete")
|
||||
|
||||
// start refreshing
|
||||
maxWait := time.Minute * 10
|
||||
var nextTime time.Time
|
||||
|
||||
timer := time.NewTimer(0)
|
||||
defer timer.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case <-clear:
|
||||
mgr.reset()
|
||||
case msg := <-update:
|
||||
mgr.onUpdateRecords(ctx, msg)
|
||||
case <-timer.C:
|
||||
}
|
||||
|
||||
now := mgr.now()
|
||||
nextTime = now.Add(maxWait)
|
||||
|
||||
// refresh sessions
|
||||
for {
|
||||
tm, key := mgr.sessionScheduler.Next()
|
||||
if now.Before(tm) {
|
||||
if tm.Before(nextTime) {
|
||||
nextTime = tm
|
||||
}
|
||||
break
|
||||
}
|
||||
mgr.sessionScheduler.Remove(key)
|
||||
|
||||
userID, sessionID := fromSessionSchedulerKey(key)
|
||||
mgr.refreshSession(ctx, userID, sessionID)
|
||||
}
|
||||
|
||||
// refresh users
|
||||
for {
|
||||
tm, key := mgr.userScheduler.Next()
|
||||
if now.Before(tm) {
|
||||
if tm.Before(nextTime) {
|
||||
nextTime = tm
|
||||
}
|
||||
break
|
||||
}
|
||||
mgr.userScheduler.Remove(key)
|
||||
|
||||
mgr.refreshUser(ctx, key)
|
||||
}
|
||||
|
||||
metrics.RecordIdentityManagerLastRefresh(ctx)
|
||||
timer.Reset(time.Until(nextTime))
|
||||
func (mgr *Manager) onDeleteSession(ctx context.Context, sessionID string) {
|
||||
log.Ctx(ctx).Debug().Str("session_id", sessionID).Msg("session deleted")
|
||||
|
||||
mgr.mu.Lock()
|
||||
mgr.dataStore.deleteSession(sessionID)
|
||||
if rss, ok := mgr.refreshSessionSchedulers[sessionID]; ok {
|
||||
rss.Stop()
|
||||
delete(mgr.refreshSessionSchedulers, sessionID)
|
||||
}
|
||||
mgr.mu.Unlock()
|
||||
}
|
||||
|
||||
// refreshSession handles two distinct session lifecycle events:
|
||||
//
|
||||
// 1. If the session itself has expired, delete the session.
|
||||
// 2. If the session's underlying OAuth2 access token is nearing expiration
|
||||
// (but the session itself is still valid), refresh the access token.
|
||||
//
|
||||
// After a successful access token refresh, this method will also trigger a
|
||||
// user info refresh. If an access token refresh or a user info refresh fails
|
||||
// with a permanent error, the session will be deleted.
|
||||
func (mgr *Manager) refreshSession(ctx context.Context, userID, sessionID string) {
|
||||
log.Info(ctx).
|
||||
Str("user_id", userID).
|
||||
func (mgr *Manager) onDeleteUser(ctx context.Context, userID string) {
|
||||
log.Ctx(ctx).Debug().Str("user_id", userID).Msg("user deleted")
|
||||
|
||||
mgr.mu.Lock()
|
||||
mgr.dataStore.deleteUser(userID)
|
||||
if uuis, ok := mgr.updateUserInfoSchedulers[userID]; ok {
|
||||
uuis.Stop()
|
||||
delete(mgr.updateUserInfoSchedulers, userID)
|
||||
}
|
||||
mgr.mu.Unlock()
|
||||
}
|
||||
|
||||
func (mgr *Manager) onUpdateSession(ctx context.Context, s *session.Session) {
|
||||
log.Ctx(ctx).Debug().Str("session_id", s.GetId()).Msg("session updated")
|
||||
|
||||
mgr.mu.Lock()
|
||||
mgr.dataStore.putSession(s)
|
||||
rss, ok := mgr.refreshSessionSchedulers[s.GetId()]
|
||||
if !ok {
|
||||
rss = newRefreshSessionScheduler(ctx, mgr, s.GetId())
|
||||
mgr.refreshSessionSchedulers[s.GetId()] = rss
|
||||
}
|
||||
rss.Update(s)
|
||||
mgr.mu.Unlock()
|
||||
}
|
||||
|
||||
func (mgr *Manager) onUpdateUser(ctx context.Context, u *user.User) {
|
||||
log.Ctx(ctx).Debug().Str("user_id", u.GetId()).Msg("user updated")
|
||||
|
||||
mgr.mu.Lock()
|
||||
mgr.dataStore.putUser(u)
|
||||
_, ok := mgr.updateUserInfoSchedulers[u.GetId()]
|
||||
if !ok {
|
||||
uuis := newUpdateUserInfoScheduler(ctx, mgr, u.GetId())
|
||||
mgr.updateUserInfoSchedulers[u.GetId()] = uuis
|
||||
}
|
||||
mgr.mu.Unlock()
|
||||
}
|
||||
|
||||
func (mgr *Manager) refreshSession(ctx context.Context, sessionID string) {
|
||||
log.Ctx(ctx).Debug().
|
||||
Str("session_id", sessionID).
|
||||
Msg("refreshing session")
|
||||
|
||||
s, ok := mgr.sessions.Get(userID, sessionID)
|
||||
if !ok {
|
||||
log.Warn(ctx).
|
||||
Str("user_id", userID).
|
||||
mgr.mu.Lock()
|
||||
s, u := mgr.dataStore.getSessionAndUser(sessionID)
|
||||
mgr.mu.Unlock()
|
||||
|
||||
if s == nil {
|
||||
log.Ctx(ctx).Warn().
|
||||
Str("user_id", u.GetId()).
|
||||
Str("session_id", sessionID).
|
||||
Msg("no session found for refresh")
|
||||
return
|
||||
}
|
||||
|
||||
s.lastRefresh = mgr.now()
|
||||
|
||||
if mgr.refreshSessionInternal(ctx, userID, sessionID, &s) {
|
||||
mgr.sessions.ReplaceOrInsert(s)
|
||||
mgr.sessionScheduler.Add(s.NextRefresh(), toSessionSchedulerKey(userID, sessionID))
|
||||
}
|
||||
}
|
||||
|
||||
// refreshSessionInternal performs the core refresh logic and returns true if
|
||||
// the session should be scheduled for refresh again, or false if not.
|
||||
func (mgr *Manager) refreshSessionInternal(
|
||||
ctx context.Context, userID, sessionID string, s *Session,
|
||||
) bool {
|
||||
authenticator := mgr.cfg.Load().authenticator
|
||||
if authenticator == nil {
|
||||
log.Info(ctx).
|
||||
Str("user_id", userID).
|
||||
Str("session_id", sessionID).
|
||||
log.Ctx(ctx).Info().
|
||||
Str("user_id", s.GetUserId()).
|
||||
Str("session_id", s.GetId()).
|
||||
Msg("no authenticator defined, deleting session")
|
||||
mgr.deleteSession(ctx, userID, sessionID)
|
||||
return false
|
||||
mgr.deleteSession(ctx, sessionID)
|
||||
return
|
||||
}
|
||||
|
||||
expiry := s.GetExpiresAt().AsTime()
|
||||
if !expiry.After(mgr.now()) {
|
||||
if !expiry.After(mgr.cfg.Load().now()) {
|
||||
log.Info(ctx).
|
||||
Str("user_id", userID).
|
||||
Str("session_id", sessionID).
|
||||
Str("user_id", s.GetUserId()).
|
||||
Str("session_id", s.GetId()).
|
||||
Msg("deleting expired session")
|
||||
mgr.deleteSession(ctx, userID, sessionID)
|
||||
return false
|
||||
mgr.deleteSession(ctx, sessionID)
|
||||
return
|
||||
}
|
||||
|
||||
if s.Session == nil || s.Session.OauthToken == nil {
|
||||
if s.GetOauthToken() == nil {
|
||||
log.Warn(ctx).
|
||||
Str("user_id", userID).
|
||||
Str("session_id", sessionID).
|
||||
Str("user_id", s.GetUserId()).
|
||||
Str("session_id", s.GetId()).
|
||||
Msg("no session oauth2 token found for refresh")
|
||||
return false
|
||||
return
|
||||
}
|
||||
|
||||
newToken, err := authenticator.Refresh(ctx, FromOAuthToken(s.OauthToken), s)
|
||||
newToken, err := authenticator.Refresh(ctx, FromOAuthToken(s.OauthToken), newSessionUnmarshaler(s))
|
||||
metrics.RecordIdentityManagerSessionRefresh(ctx, err)
|
||||
mgr.recordLastError(metrics_ids.IdentityManagerLastSessionRefreshError, err)
|
||||
if isTemporaryError(err) {
|
||||
|
@ -260,18 +221,18 @@ func (mgr *Manager) refreshSessionInternal(
|
|||
Str("user_id", s.GetUserId()).
|
||||
Str("session_id", s.GetId()).
|
||||
Msg("failed to refresh oauth2 token")
|
||||
return true
|
||||
return
|
||||
} else if err != nil {
|
||||
log.Error(ctx).Err(err).
|
||||
Str("user_id", s.GetUserId()).
|
||||
Str("session_id", s.GetId()).
|
||||
Msg("failed to refresh oauth2 token, deleting session")
|
||||
mgr.deleteSession(ctx, userID, sessionID)
|
||||
return false
|
||||
mgr.deleteSession(ctx, sessionID)
|
||||
return
|
||||
}
|
||||
s.OauthToken = ToOAuthToken(newToken)
|
||||
|
||||
err = authenticator.UpdateUserInfo(ctx, FromOAuthToken(s.OauthToken), s)
|
||||
err = authenticator.UpdateUserInfo(ctx, FromOAuthToken(s.OauthToken), newMultiUnmarshaler(newUserUnmarshaler(u), newSessionUnmarshaler(s)))
|
||||
metrics.RecordIdentityManagerUserRefresh(ctx, err)
|
||||
mgr.recordLastError(metrics_ids.IdentityManagerLastUserRefreshError, err)
|
||||
if isTemporaryError(err) {
|
||||
|
@ -279,184 +240,162 @@ func (mgr *Manager) refreshSessionInternal(
|
|||
Str("user_id", s.GetUserId()).
|
||||
Str("session_id", s.GetId()).
|
||||
Msg("failed to update user info")
|
||||
return true
|
||||
return
|
||||
} else if err != nil {
|
||||
log.Error(ctx).Err(err).
|
||||
Str("user_id", s.GetUserId()).
|
||||
Str("session_id", s.GetId()).
|
||||
Msg("failed to update user info, deleting session")
|
||||
mgr.deleteSession(ctx, userID, sessionID)
|
||||
return false
|
||||
mgr.deleteSession(ctx, sessionID)
|
||||
return
|
||||
}
|
||||
|
||||
fm, err := fieldmaskpb.New(s.Session, "oauth_token", "id_token", "claims")
|
||||
if err != nil {
|
||||
log.Error(ctx).Err(err).Msg("internal error")
|
||||
return false
|
||||
mgr.updateSession(ctx, s)
|
||||
if u != nil {
|
||||
mgr.updateUser(ctx, u)
|
||||
}
|
||||
|
||||
if _, err := session.Patch(ctx, mgr.cfg.Load().dataBrokerClient, s.Session, fm); err != nil {
|
||||
log.Error(ctx).Err(err).
|
||||
Str("user_id", s.GetUserId()).
|
||||
Str("session_id", s.GetId()).
|
||||
Msg("failed to update session")
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func (mgr *Manager) refreshUser(ctx context.Context, userID string) {
|
||||
log.Info(ctx).
|
||||
Str("user_id", userID).
|
||||
Msg("refreshing user")
|
||||
func (mgr *Manager) updateUserInfo(ctx context.Context, userID string) {
|
||||
log.Ctx(ctx).Info().Str("user_id", userID).Msg("updating user info")
|
||||
|
||||
authenticator := mgr.cfg.Load().authenticator
|
||||
if authenticator == nil {
|
||||
return
|
||||
}
|
||||
|
||||
u, ok := mgr.users.Get(userID)
|
||||
if !ok {
|
||||
log.Warn(ctx).
|
||||
mgr.mu.Lock()
|
||||
u, ss := mgr.dataStore.getUserAndSessions(userID)
|
||||
mgr.mu.Unlock()
|
||||
|
||||
if u == nil {
|
||||
log.Ctx(ctx).Warn().
|
||||
Str("user_id", userID).
|
||||
Msg("no user found for refresh")
|
||||
Msg("no user found for update")
|
||||
return
|
||||
}
|
||||
u.lastRefresh = mgr.now()
|
||||
mgr.userScheduler.Add(u.NextRefresh(), u.GetId())
|
||||
|
||||
for _, s := range mgr.sessions.GetSessionsForUser(userID) {
|
||||
if s.Session == nil || s.Session.OauthToken == nil {
|
||||
log.Warn(ctx).
|
||||
Str("user_id", userID).
|
||||
Msg("no session oauth2 token found for refresh")
|
||||
for _, s := range ss {
|
||||
if s.GetOauthToken() == nil {
|
||||
log.Ctx(ctx).Warn().
|
||||
Str("user_id", s.GetUserId()).
|
||||
Str("session_id", s.GetId()).
|
||||
Msg("no session oauth2 token found for updating user info")
|
||||
continue
|
||||
}
|
||||
|
||||
err := authenticator.UpdateUserInfo(ctx, FromOAuthToken(s.OauthToken), &u)
|
||||
err := authenticator.UpdateUserInfo(ctx, FromOAuthToken(s.GetOauthToken()), newMultiUnmarshaler(newUserUnmarshaler(u), newSessionUnmarshaler(s)))
|
||||
metrics.RecordIdentityManagerUserRefresh(ctx, err)
|
||||
mgr.recordLastError(metrics_ids.IdentityManagerLastUserRefreshError, err)
|
||||
if isTemporaryError(err) {
|
||||
log.Error(ctx).Err(err).
|
||||
log.Ctx(ctx).Error().Err(err).
|
||||
Str("user_id", s.GetUserId()).
|
||||
Str("session_id", s.GetId()).
|
||||
Msg("failed to update user info")
|
||||
return
|
||||
continue
|
||||
} else if err != nil {
|
||||
log.Error(ctx).Err(err).
|
||||
log.Ctx(ctx).Error().Err(err).
|
||||
Str("user_id", s.GetUserId()).
|
||||
Str("session_id", s.GetId()).
|
||||
Msg("failed to update user info, deleting session")
|
||||
mgr.deleteSession(ctx, userID, s.GetId())
|
||||
mgr.deleteSession(ctx, s.GetId())
|
||||
continue
|
||||
}
|
||||
|
||||
res, err := databroker.Put(ctx, mgr.cfg.Load().dataBrokerClient, u.User)
|
||||
if err != nil {
|
||||
log.Error(ctx).Err(err).
|
||||
Str("user_id", s.GetUserId()).
|
||||
Str("session_id", s.GetId()).
|
||||
Msg("failed to update user")
|
||||
continue
|
||||
}
|
||||
|
||||
mgr.onUpdateUser(ctx, res.GetRecords()[0], u.User)
|
||||
mgr.updateSession(ctx, s)
|
||||
mgr.updateUser(ctx, u)
|
||||
}
|
||||
}
|
||||
|
||||
func (mgr *Manager) onUpdateRecords(ctx context.Context, msg updateRecordsMessage) {
|
||||
for _, record := range msg.records {
|
||||
switch record.GetType() {
|
||||
case grpcutil.GetTypeURL(new(session.Session)):
|
||||
var pbSession session.Session
|
||||
err := record.GetData().UnmarshalTo(&pbSession)
|
||||
if err != nil {
|
||||
log.Warn(ctx).Msgf("error unmarshaling session: %s", err)
|
||||
continue
|
||||
}
|
||||
mgr.onUpdateSession(record, &pbSession)
|
||||
case grpcutil.GetTypeURL(new(user.User)):
|
||||
var pbUser user.User
|
||||
err := record.GetData().UnmarshalTo(&pbUser)
|
||||
if err != nil {
|
||||
log.Warn(ctx).Msgf("error unmarshaling user: %s", err)
|
||||
continue
|
||||
}
|
||||
mgr.onUpdateUser(ctx, record, &pbUser)
|
||||
}
|
||||
// deleteSession deletes a session from the databroke, the local data store, and the schedulers
|
||||
func (mgr *Manager) deleteSession(ctx context.Context, sessionID string) {
|
||||
log.Ctx(ctx).Debug().
|
||||
Str("session_id", sessionID).
|
||||
Msg("deleting session")
|
||||
|
||||
mgr.mu.Lock()
|
||||
mgr.dataStore.deleteSession(sessionID)
|
||||
if rss, ok := mgr.refreshSessionSchedulers[sessionID]; ok {
|
||||
rss.Stop()
|
||||
delete(mgr.refreshSessionSchedulers, sessionID)
|
||||
}
|
||||
}
|
||||
mgr.mu.Unlock()
|
||||
|
||||
func (mgr *Manager) onUpdateSession(record *databroker.Record, session *session.Session) {
|
||||
mgr.sessionScheduler.Remove(toSessionSchedulerKey(session.GetUserId(), session.GetId()))
|
||||
|
||||
if record.GetDeletedAt() != nil {
|
||||
mgr.sessions.Delete(session.GetUserId(), session.GetId())
|
||||
return
|
||||
}
|
||||
|
||||
// update session
|
||||
s, _ := mgr.sessions.Get(session.GetUserId(), session.GetId())
|
||||
if s.lastRefresh.IsZero() {
|
||||
s.lastRefresh = mgr.now()
|
||||
}
|
||||
s.gracePeriod = mgr.cfg.Load().sessionRefreshGracePeriod
|
||||
s.coolOffDuration = mgr.cfg.Load().sessionRefreshCoolOffDuration
|
||||
s.Session = session
|
||||
mgr.sessions.ReplaceOrInsert(s)
|
||||
mgr.sessionScheduler.Add(s.NextRefresh(), toSessionSchedulerKey(session.GetUserId(), session.GetId()))
|
||||
}
|
||||
|
||||
func (mgr *Manager) onUpdateUser(_ context.Context, record *databroker.Record, user *user.User) {
|
||||
mgr.userScheduler.Remove(user.GetId())
|
||||
|
||||
if record.GetDeletedAt() != nil {
|
||||
mgr.users.Delete(user.GetId())
|
||||
return
|
||||
}
|
||||
|
||||
u, _ := mgr.users.Get(user.GetId())
|
||||
u.lastRefresh = mgr.cfg.Load().now()
|
||||
u.User = user
|
||||
mgr.users.ReplaceOrInsert(u)
|
||||
mgr.userScheduler.Add(u.NextRefresh(), u.GetId())
|
||||
}
|
||||
|
||||
func (mgr *Manager) deleteSession(ctx context.Context, userID, sessionID string) {
|
||||
mgr.sessionScheduler.Remove(toSessionSchedulerKey(userID, sessionID))
|
||||
mgr.sessions.Delete(userID, sessionID)
|
||||
|
||||
client := mgr.cfg.Load().dataBrokerClient
|
||||
res, err := client.Get(ctx, &databroker.GetRequest{
|
||||
res, err := mgr.cfg.Load().dataBrokerClient.Get(ctx, &databroker.GetRequest{
|
||||
Type: grpcutil.GetTypeURL(new(session.Session)),
|
||||
Id: sessionID,
|
||||
})
|
||||
if status.Code(err) == codes.NotFound {
|
||||
return
|
||||
} else if err != nil {
|
||||
log.Error(ctx).Err(err).
|
||||
Str("session_id", sessionID).
|
||||
Msg("failed to delete session")
|
||||
return
|
||||
}
|
||||
|
||||
record := res.GetRecord()
|
||||
record.DeletedAt = timestamppb.Now()
|
||||
|
||||
_, err = client.Put(ctx, &databroker.PutRequest{
|
||||
_, err = mgr.cfg.Load().dataBrokerClient.Put(ctx, &databroker.PutRequest{
|
||||
Records: []*databroker.Record{record},
|
||||
})
|
||||
if err != nil {
|
||||
log.Error(ctx).Err(err).
|
||||
log.Ctx(ctx).Error().Err(err).
|
||||
Str("session_id", sessionID).
|
||||
Msg("failed to delete session")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// reset resets all the manager datastructures to their initial state
|
||||
func (mgr *Manager) reset() {
|
||||
mgr.sessions = sessionCollection{BTree: btree.New(8)}
|
||||
mgr.users = userCollection{BTree: btree.New(8)}
|
||||
func (mgr *Manager) updateSession(ctx context.Context, s *session.Session) {
|
||||
log.Ctx(ctx).Debug().
|
||||
Str("user_id", s.GetUserId()).
|
||||
Str("session_id", s.GetId()).
|
||||
Msg("updating session")
|
||||
|
||||
fm, err := fieldmaskpb.New(s, "oauth_token", "id_token", "claims")
|
||||
if err != nil {
|
||||
log.Error(ctx).Err(err).
|
||||
Str("user_id", s.GetUserId()).
|
||||
Str("session_id", s.GetId()).
|
||||
Msg("failed to create fieldmask for session")
|
||||
return
|
||||
}
|
||||
|
||||
_, err = session.Patch(ctx, mgr.cfg.Load().dataBrokerClient, s, fm)
|
||||
if err != nil {
|
||||
log.Error(ctx).Err(err).
|
||||
Str("user_id", s.GetUserId()).
|
||||
Str("session_id", s.GetId()).
|
||||
Msg("failed to patch updated session record")
|
||||
return
|
||||
}
|
||||
|
||||
mgr.mu.Lock()
|
||||
mgr.dataStore.putSession(s)
|
||||
if rss, ok := mgr.refreshSessionSchedulers[s.GetId()]; ok {
|
||||
rss.Update(s)
|
||||
}
|
||||
mgr.mu.Unlock()
|
||||
}
|
||||
|
||||
// updateUser updates the user in the databroker, the local data store, and resets the scheduler
|
||||
func (mgr *Manager) updateUser(ctx context.Context, u *user.User) {
|
||||
log.Ctx(ctx).Debug().
|
||||
Str("user_id", u.GetId()).
|
||||
Msg("updating user")
|
||||
|
||||
_, err := databroker.Put(ctx, mgr.cfg.Load().dataBrokerClient, u)
|
||||
if err != nil {
|
||||
log.Ctx(ctx).Error().
|
||||
Str("user_id", u.GetId()).
|
||||
Err(err).
|
||||
Msg("failed to store updated user record")
|
||||
return
|
||||
}
|
||||
|
||||
mgr.mu.Lock()
|
||||
mgr.dataStore.putUser(u)
|
||||
if uuis, ok := mgr.updateUserInfoSchedulers[u.GetId()]; ok {
|
||||
uuis.Reset()
|
||||
}
|
||||
mgr.mu.Unlock()
|
||||
}
|
||||
|
||||
func (mgr *Manager) recordLastError(id string, err error) {
|
||||
|
@ -473,17 +412,3 @@ func (mgr *Manager) recordLastError(id string, err error) {
|
|||
Id: id,
|
||||
})
|
||||
}
|
||||
|
||||
func isTemporaryError(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
if errors.Is(err, context.DeadlineExceeded) || errors.Is(err, context.Canceled) {
|
||||
return true
|
||||
}
|
||||
var hasTemporary interface{ Temporary() bool }
|
||||
if errors.As(err, &hasTemporary) && hasTemporary.Temporary() {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
|
|
@ -2,28 +2,10 @@ package manager
|
|||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"go.uber.org/mock/gomock"
|
||||
"golang.org/x/oauth2"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
"google.golang.org/protobuf/proto"
|
||||
"google.golang.org/protobuf/types/known/fieldmaskpb"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/events"
|
||||
"github.com/pomerium/pomerium/internal/identity/identity"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/databroker/mock_databroker"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/session"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/user"
|
||||
metrics_ids "github.com/pomerium/pomerium/pkg/metrics"
|
||||
"github.com/pomerium/pomerium/pkg/protoutil"
|
||||
)
|
||||
|
||||
type mockAuthenticator struct {
|
||||
|
@ -44,317 +26,3 @@ func (mock *mockAuthenticator) Revoke(_ context.Context, _ *oauth2.Token) error
|
|||
func (mock *mockAuthenticator) UpdateUserInfo(_ context.Context, _ *oauth2.Token, _ any) error {
|
||||
return mock.updateUserInfoError
|
||||
}
|
||||
|
||||
func TestManager_refresh(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
ctx, clearTimeout := context.WithTimeout(context.Background(), time.Second*10)
|
||||
t.Cleanup(clearTimeout)
|
||||
|
||||
client := mock_databroker.NewMockDataBrokerServiceClient(ctrl)
|
||||
mgr := New(WithDataBrokerClient(client))
|
||||
mgr.onUpdateRecords(ctx, updateRecordsMessage{
|
||||
records: []*databroker.Record{
|
||||
databroker.NewRecord(&session.Session{
|
||||
Id: "s1",
|
||||
UserId: "u1",
|
||||
OauthToken: &session.OAuthToken{},
|
||||
ExpiresAt: timestamppb.New(time.Now().Add(time.Second * 10)),
|
||||
}),
|
||||
databroker.NewRecord(&user.User{
|
||||
Id: "u1",
|
||||
}),
|
||||
},
|
||||
})
|
||||
client.EXPECT().Get(gomock.Any(), gomock.Any()).Return(nil, status.Error(codes.NotFound, "not found"))
|
||||
mgr.refreshSession(ctx, "u1", "s1")
|
||||
mgr.refreshUser(ctx, "u1")
|
||||
}
|
||||
|
||||
func TestManager_onUpdateRecords(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
|
||||
ctx, clearTimeout := context.WithTimeout(context.Background(), time.Second*10)
|
||||
defer clearTimeout()
|
||||
|
||||
now := time.Now()
|
||||
|
||||
mgr := New(
|
||||
WithDataBrokerClient(mock_databroker.NewMockDataBrokerServiceClient(ctrl)),
|
||||
WithNow(func() time.Time {
|
||||
return now
|
||||
}),
|
||||
)
|
||||
|
||||
mgr.onUpdateRecords(ctx, updateRecordsMessage{
|
||||
records: []*databroker.Record{
|
||||
mkRecord(&session.Session{Id: "session1", UserId: "user1"}),
|
||||
mkRecord(&user.User{Id: "user1", Name: "user 1", Email: "user1@example.com"}),
|
||||
},
|
||||
})
|
||||
|
||||
if _, ok := mgr.sessions.Get("user1", "session1"); assert.True(t, ok) {
|
||||
tm, id := mgr.sessionScheduler.Next()
|
||||
assert.Equal(t, now.Add(10*time.Second), tm)
|
||||
assert.Equal(t, "user1\037session1", id)
|
||||
}
|
||||
if _, ok := mgr.users.Get("user1"); assert.True(t, ok) {
|
||||
tm, id := mgr.userScheduler.Next()
|
||||
assert.Equal(t, now.Add(userRefreshInterval), tm)
|
||||
assert.Equal(t, "user1", id)
|
||||
}
|
||||
}
|
||||
|
||||
func TestManager_onUpdateSession(t *testing.T) {
|
||||
startTime := time.Date(2023, 10, 19, 12, 0, 0, 0, time.UTC)
|
||||
|
||||
s := &session.Session{
|
||||
Id: "session-id",
|
||||
UserId: "user-id",
|
||||
OauthToken: &session.OAuthToken{
|
||||
AccessToken: "access-token",
|
||||
ExpiresAt: timestamppb.New(startTime.Add(5 * time.Minute)),
|
||||
},
|
||||
IssuedAt: timestamppb.New(startTime),
|
||||
ExpiresAt: timestamppb.New(startTime.Add(24 * time.Hour)),
|
||||
}
|
||||
|
||||
assertNextScheduled := func(t *testing.T, mgr *Manager, expectedTime time.Time) {
|
||||
t.Helper()
|
||||
tm, key := mgr.sessionScheduler.Next()
|
||||
assert.Equal(t, expectedTime, tm)
|
||||
assert.Equal(t, "user-id\037session-id", key)
|
||||
}
|
||||
|
||||
t.Run("initial refresh event when not expiring soon", func(t *testing.T) {
|
||||
now := startTime
|
||||
mgr := New(WithNow(func() time.Time { return now }))
|
||||
|
||||
// When the Manager first becomes aware of a session it should schedule
|
||||
// a refresh event for one minute before access token expiration.
|
||||
mgr.onUpdateSession(mkRecord(s), s)
|
||||
assertNextScheduled(t, mgr, startTime.Add(4*time.Minute))
|
||||
})
|
||||
t.Run("initial refresh event when expiring soon", func(t *testing.T) {
|
||||
now := startTime
|
||||
mgr := New(WithNow(func() time.Time { return now }))
|
||||
|
||||
// When the Manager first becomes aware of a session, if that session
|
||||
// is expiring within the gracePeriod (1 minute), it should schedule a
|
||||
// refresh event for as soon as possible, subject to the
|
||||
// coolOffDuration (10 seconds).
|
||||
now = now.Add(4*time.Minute + 30*time.Second) // 30 s before expiration
|
||||
mgr.onUpdateSession(mkRecord(s), s)
|
||||
assertNextScheduled(t, mgr, now.Add(10*time.Second))
|
||||
})
|
||||
t.Run("update near scheduled refresh", func(t *testing.T) {
|
||||
now := startTime
|
||||
mgr := New(WithNow(func() time.Time { return now }))
|
||||
|
||||
mgr.onUpdateSession(mkRecord(s), s)
|
||||
assertNextScheduled(t, mgr, startTime.Add(4*time.Minute))
|
||||
|
||||
// If a session is updated close to the time when it is scheduled to be
|
||||
// refreshed, the scheduled refresh event should not be pushed back.
|
||||
now = now.Add(3*time.Minute + 55*time.Second) // 5 s before refresh
|
||||
mgr.onUpdateSession(mkRecord(s), s)
|
||||
assertNextScheduled(t, mgr, now.Add(5*time.Second))
|
||||
|
||||
// However, if an update changes the access token validity, the refresh
|
||||
// event should be rescheduled accordingly. (This should be uncommon,
|
||||
// as only the refresh loop itself should modify the access token.)
|
||||
s2 := proto.Clone(s).(*session.Session)
|
||||
s2.OauthToken.ExpiresAt = timestamppb.New(now.Add(5 * time.Minute))
|
||||
mgr.onUpdateSession(mkRecord(s2), s2)
|
||||
assertNextScheduled(t, mgr, now.Add(4*time.Minute))
|
||||
})
|
||||
t.Run("session record deleted", func(t *testing.T) {
|
||||
now := startTime
|
||||
mgr := New(WithNow(func() time.Time { return now }))
|
||||
|
||||
mgr.onUpdateSession(mkRecord(s), s)
|
||||
assertNextScheduled(t, mgr, startTime.Add(4*time.Minute))
|
||||
|
||||
// If a session is deleted, any scheduled refresh event should be canceled.
|
||||
record := mkRecord(s)
|
||||
record.DeletedAt = timestamppb.New(now)
|
||||
mgr.onUpdateSession(record, s)
|
||||
_, key := mgr.sessionScheduler.Next()
|
||||
assert.Empty(t, key)
|
||||
})
|
||||
}
|
||||
|
||||
func TestManager_refreshSession(t *testing.T) {
|
||||
startTime := time.Date(2023, 10, 19, 12, 0, 0, 0, time.UTC)
|
||||
|
||||
var auth mockAuthenticator
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
client := mock_databroker.NewMockDataBrokerServiceClient(ctrl)
|
||||
|
||||
now := startTime
|
||||
mgr := New(
|
||||
WithDataBrokerClient(client),
|
||||
WithNow(func() time.Time { return now }),
|
||||
WithAuthenticator(&auth),
|
||||
)
|
||||
|
||||
// Initialize the Manager with a new session.
|
||||
s := &session.Session{
|
||||
Id: "session-id",
|
||||
UserId: "user-id",
|
||||
OauthToken: &session.OAuthToken{
|
||||
AccessToken: "access-token",
|
||||
ExpiresAt: timestamppb.New(startTime.Add(5 * time.Minute)),
|
||||
RefreshToken: "refresh-token",
|
||||
},
|
||||
IssuedAt: timestamppb.New(startTime),
|
||||
ExpiresAt: timestamppb.New(startTime.Add(24 * time.Hour)),
|
||||
}
|
||||
mgr.sessions.ReplaceOrInsert(Session{
|
||||
Session: s,
|
||||
lastRefresh: startTime,
|
||||
gracePeriod: time.Minute,
|
||||
coolOffDuration: 10 * time.Second,
|
||||
})
|
||||
|
||||
// If OAuth2 token refresh fails with a temporary error, the manager should
|
||||
// still reschedule another refresh attempt.
|
||||
now = now.Add(4 * time.Minute)
|
||||
auth.refreshError = context.DeadlineExceeded
|
||||
mgr.refreshSession(context.Background(), "user-id", "session-id")
|
||||
|
||||
tm, key := mgr.sessionScheduler.Next()
|
||||
assert.Equal(t, now.Add(10*time.Second), tm)
|
||||
assert.Equal(t, "user-id\037session-id", key)
|
||||
|
||||
// Simulate a successful token refresh on the second attempt. The manager
|
||||
// should store the updated session in the databroker and schedule another
|
||||
// refresh event.
|
||||
now = now.Add(10 * time.Second)
|
||||
auth.refreshResult, auth.refreshError = &oauth2.Token{
|
||||
AccessToken: "new-access-token",
|
||||
RefreshToken: "new-refresh-token",
|
||||
Expiry: now.Add(5 * time.Minute),
|
||||
}, nil
|
||||
expectedSession := proto.Clone(s).(*session.Session)
|
||||
expectedSession.OauthToken = &session.OAuthToken{
|
||||
AccessToken: "new-access-token",
|
||||
ExpiresAt: timestamppb.New(now.Add(5 * time.Minute)),
|
||||
RefreshToken: "new-refresh-token",
|
||||
}
|
||||
client.EXPECT().Patch(gomock.Any(), objectsAreEqualMatcher{
|
||||
&databroker.PatchRequest{
|
||||
Records: []*databroker.Record{{
|
||||
Type: "type.googleapis.com/session.Session",
|
||||
Id: "session-id",
|
||||
Data: protoutil.NewAny(expectedSession),
|
||||
}},
|
||||
FieldMask: &fieldmaskpb.FieldMask{
|
||||
Paths: []string{"oauth_token", "id_token", "claims"},
|
||||
},
|
||||
},
|
||||
}).
|
||||
Return(nil /* this result is currently unused */, nil)
|
||||
mgr.refreshSession(context.Background(), "user-id", "session-id")
|
||||
|
||||
tm, key = mgr.sessionScheduler.Next()
|
||||
assert.Equal(t, now.Add(4*time.Minute), tm)
|
||||
assert.Equal(t, "user-id\037session-id", key)
|
||||
}
|
||||
|
||||
func TestManager_reportErrors(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
|
||||
ctx, clearTimeout := context.WithTimeout(context.Background(), time.Second*10)
|
||||
defer clearTimeout()
|
||||
|
||||
evtMgr := events.New()
|
||||
received := make(chan events.Event, 1)
|
||||
handle := evtMgr.Register(func(evt events.Event) {
|
||||
received <- evt
|
||||
})
|
||||
defer evtMgr.Unregister(handle)
|
||||
|
||||
expectMsg := func(id, msg string) {
|
||||
t.Helper()
|
||||
assert.Eventually(t, func() bool {
|
||||
select {
|
||||
case evt := <-received:
|
||||
lastErr := evt.(*events.LastError)
|
||||
return msg == lastErr.Message && id == lastErr.Id
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}, time.Second, time.Millisecond*20, msg)
|
||||
}
|
||||
|
||||
s := &session.Session{
|
||||
Id: "session1",
|
||||
UserId: "user1",
|
||||
OauthToken: &session.OAuthToken{
|
||||
ExpiresAt: timestamppb.New(time.Now().Add(time.Hour)),
|
||||
},
|
||||
ExpiresAt: timestamppb.New(time.Now().Add(time.Hour)),
|
||||
}
|
||||
|
||||
client := mock_databroker.NewMockDataBrokerServiceClient(ctrl)
|
||||
client.EXPECT().Get(gomock.Any(), gomock.Any()).AnyTimes().Return(&databroker.GetResponse{Record: databroker.NewRecord(s)}, nil)
|
||||
client.EXPECT().Put(gomock.Any(), gomock.Any()).AnyTimes()
|
||||
mgr := New(
|
||||
WithEventManager(evtMgr),
|
||||
WithDataBrokerClient(client),
|
||||
WithAuthenticator(&mockAuthenticator{
|
||||
refreshError: errors.New("update session"),
|
||||
updateUserInfoError: errors.New("update user info"),
|
||||
}),
|
||||
)
|
||||
|
||||
mgr.onUpdateRecords(ctx, updateRecordsMessage{
|
||||
records: []*databroker.Record{
|
||||
mkRecord(s),
|
||||
mkRecord(&user.User{Id: "user1", Name: "user 1", Email: "user1@example.com"}),
|
||||
},
|
||||
})
|
||||
|
||||
mgr.refreshUser(ctx, "user1")
|
||||
expectMsg(metrics_ids.IdentityManagerLastUserRefreshError, "update user info")
|
||||
|
||||
mgr.onUpdateRecords(ctx, updateRecordsMessage{
|
||||
records: []*databroker.Record{
|
||||
mkRecord(s),
|
||||
mkRecord(&user.User{Id: "user1", Name: "user 1", Email: "user1@example.com"}),
|
||||
},
|
||||
})
|
||||
|
||||
mgr.refreshSession(ctx, "user1", "session1")
|
||||
expectMsg(metrics_ids.IdentityManagerLastSessionRefreshError, "update session")
|
||||
}
|
||||
|
||||
func mkRecord(msg recordable) *databroker.Record {
|
||||
data := protoutil.NewAny(msg)
|
||||
return &databroker.Record{
|
||||
Type: data.GetTypeUrl(),
|
||||
Id: msg.GetId(),
|
||||
Data: data,
|
||||
}
|
||||
}
|
||||
|
||||
type recordable interface {
|
||||
proto.Message
|
||||
GetId() string
|
||||
}
|
||||
|
||||
// objectsAreEqualMatcher implements gomock.Matcher using ObjectsAreEqual. This
|
||||
// is especially helpful when working with pointers, as it will compare the
|
||||
// underlying values rather than the pointers themselves.
|
||||
type objectsAreEqualMatcher struct {
|
||||
expected interface{}
|
||||
}
|
||||
|
||||
func (m objectsAreEqualMatcher) Matches(x interface{}) bool {
|
||||
return assert.ObjectsAreEqual(m.expected, x)
|
||||
}
|
||||
|
||||
func (m objectsAreEqualMatcher) String() string {
|
||||
return fmt.Sprintf("is equal to %v (%T)", m.expected, m.expected)
|
||||
}
|
||||
|
|
|
@ -1,7 +1,8 @@
|
|||
package manager
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"context"
|
||||
"errors"
|
||||
|
||||
"golang.org/x/oauth2"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
|
@ -9,21 +10,6 @@ import (
|
|||
"github.com/pomerium/pomerium/pkg/grpc/session"
|
||||
)
|
||||
|
||||
func toSessionSchedulerKey(userID, sessionID string) string {
|
||||
return userID + "\037" + sessionID
|
||||
}
|
||||
|
||||
func fromSessionSchedulerKey(key string) (userID, sessionID string) {
|
||||
idx := strings.Index(key, "\037")
|
||||
if idx >= 0 {
|
||||
userID = key[:idx]
|
||||
sessionID = key[idx+1:]
|
||||
} else {
|
||||
userID = key
|
||||
}
|
||||
return userID, sessionID
|
||||
}
|
||||
|
||||
// FromOAuthToken converts a session oauth token to oauth2.Token.
|
||||
func FromOAuthToken(token *session.OAuthToken) *oauth2.Token {
|
||||
return &oauth2.Token{
|
||||
|
@ -44,3 +30,17 @@ func ToOAuthToken(token *oauth2.Token) *session.OAuthToken {
|
|||
ExpiresAt: expiry,
|
||||
}
|
||||
}
|
||||
|
||||
func isTemporaryError(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
if errors.Is(err, context.DeadlineExceeded) || errors.Is(err, context.Canceled) {
|
||||
return true
|
||||
}
|
||||
var hasTemporary interface{ Temporary() bool }
|
||||
if errors.As(err, &hasTemporary) && hasTemporary.Temporary() {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
|
138
internal/identity/manager/schedulers.go
Normal file
138
internal/identity/manager/schedulers.go
Normal file
|
@ -0,0 +1,138 @@
|
|||
package manager
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/pomerium/pomerium/pkg/grpc/session"
|
||||
)
|
||||
|
||||
type updateUserInfoScheduler struct {
|
||||
mgr *Manager
|
||||
userID string
|
||||
reset chan struct{}
|
||||
cancel context.CancelFunc
|
||||
}
|
||||
|
||||
func newUpdateUserInfoScheduler(ctx context.Context, mgr *Manager, userID string) *updateUserInfoScheduler {
|
||||
uuis := &updateUserInfoScheduler{
|
||||
mgr: mgr,
|
||||
userID: userID,
|
||||
reset: make(chan struct{}, 1),
|
||||
}
|
||||
ctx = context.WithoutCancel(ctx)
|
||||
ctx, uuis.cancel = context.WithCancel(ctx)
|
||||
go uuis.run(ctx)
|
||||
return uuis
|
||||
}
|
||||
|
||||
func (uuis *updateUserInfoScheduler) Reset() {
|
||||
select {
|
||||
case uuis.reset <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
func (uuis *updateUserInfoScheduler) Stop() {
|
||||
uuis.cancel()
|
||||
}
|
||||
|
||||
func (uuis *updateUserInfoScheduler) run(ctx context.Context) {
|
||||
ticker := time.NewTicker(uuis.mgr.cfg.Load().updateUserInfoInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-uuis.reset:
|
||||
ticker.Reset(uuis.mgr.cfg.Load().updateUserInfoInterval)
|
||||
case <-ticker.C:
|
||||
uuis.mgr.updateUserInfo(ctx, uuis.userID)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
type refreshSessionScheduler struct {
|
||||
mgr *Manager
|
||||
sessionID string
|
||||
lastRefresh atomic.Pointer[time.Time]
|
||||
next chan time.Time
|
||||
cancel context.CancelFunc
|
||||
}
|
||||
|
||||
func newRefreshSessionScheduler(
|
||||
ctx context.Context,
|
||||
mgr *Manager,
|
||||
sessionID string,
|
||||
) *refreshSessionScheduler {
|
||||
rss := &refreshSessionScheduler{
|
||||
mgr: mgr,
|
||||
sessionID: sessionID,
|
||||
next: make(chan time.Time, 1),
|
||||
}
|
||||
now := rss.mgr.cfg.Load().now()
|
||||
rss.lastRefresh.Store(&now)
|
||||
ctx = context.WithoutCancel(ctx)
|
||||
ctx, rss.cancel = context.WithCancel(ctx)
|
||||
go rss.run(ctx)
|
||||
return rss
|
||||
}
|
||||
|
||||
func (rss *refreshSessionScheduler) Update(s *session.Session) {
|
||||
due := nextSessionRefresh(
|
||||
s,
|
||||
*rss.lastRefresh.Load(),
|
||||
rss.mgr.cfg.Load().sessionRefreshGracePeriod,
|
||||
rss.mgr.cfg.Load().sessionRefreshCoolOffDuration,
|
||||
)
|
||||
for {
|
||||
select {
|
||||
case <-rss.next:
|
||||
default:
|
||||
}
|
||||
select {
|
||||
case rss.next <- due:
|
||||
return
|
||||
default:
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (rss *refreshSessionScheduler) Stop() {
|
||||
rss.cancel()
|
||||
}
|
||||
|
||||
func (rss *refreshSessionScheduler) run(ctx context.Context) {
|
||||
var timer *time.Timer
|
||||
|
||||
// wait for the first update
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case due := <-rss.next:
|
||||
delay := max(time.Until(due), 0)
|
||||
timer = time.NewTimer(delay)
|
||||
defer timer.Stop()
|
||||
}
|
||||
|
||||
// wait for updates or for the timer to trigger
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case due := <-rss.next:
|
||||
delay := max(time.Until(due), 0)
|
||||
// stop the current timer and reset it
|
||||
if !timer.Stop() {
|
||||
<-timer.C
|
||||
}
|
||||
timer.Reset(delay)
|
||||
case <-timer.C:
|
||||
now := rss.mgr.cfg.Load().now()
|
||||
rss.lastRefresh.Store(&now)
|
||||
rss.mgr.refreshSession(ctx, rss.sessionID)
|
||||
}
|
||||
}
|
||||
}
|
|
@ -3,53 +3,75 @@ package manager
|
|||
import (
|
||||
"context"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/atomicutil"
|
||||
"github.com/pomerium/pomerium/internal/log"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/session"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/user"
|
||||
"github.com/pomerium/pomerium/pkg/grpcutil"
|
||||
)
|
||||
|
||||
type dataBrokerSyncer struct {
|
||||
cfg *atomicutil.Value[*config]
|
||||
|
||||
update chan<- updateRecordsMessage
|
||||
clear chan<- struct{}
|
||||
|
||||
syncer *databroker.Syncer
|
||||
type sessionSyncerHandler struct {
|
||||
mgr *Manager
|
||||
}
|
||||
|
||||
func newDataBrokerSyncer(
|
||||
_ context.Context,
|
||||
cfg *atomicutil.Value[*config],
|
||||
update chan<- updateRecordsMessage,
|
||||
clear chan<- struct{},
|
||||
) *dataBrokerSyncer {
|
||||
syncer := &dataBrokerSyncer{
|
||||
cfg: cfg,
|
||||
|
||||
update: update,
|
||||
clear: clear,
|
||||
}
|
||||
syncer.syncer = databroker.NewSyncer("identity_manager", syncer)
|
||||
return syncer
|
||||
func newSessionSyncer(mgr *Manager) *databroker.Syncer {
|
||||
return databroker.NewSyncer("identity_manager/sessions", sessionSyncerHandler{mgr: mgr},
|
||||
databroker.WithTypeURL(grpcutil.GetTypeURL(new(session.Session))))
|
||||
}
|
||||
|
||||
func (syncer *dataBrokerSyncer) Run(ctx context.Context) (err error) {
|
||||
return syncer.syncer.Run(ctx)
|
||||
func (h sessionSyncerHandler) ClearRecords(ctx context.Context) {
|
||||
h.mgr.onDeleteAllSessions(ctx)
|
||||
}
|
||||
|
||||
func (syncer *dataBrokerSyncer) ClearRecords(ctx context.Context) {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
case syncer.clear <- struct{}{}:
|
||||
func (h sessionSyncerHandler) GetDataBrokerServiceClient() databroker.DataBrokerServiceClient {
|
||||
return h.mgr.cfg.Load().dataBrokerClient
|
||||
}
|
||||
|
||||
func (h sessionSyncerHandler) UpdateRecords(ctx context.Context, _ uint64, records []*databroker.Record) {
|
||||
for _, record := range records {
|
||||
if record.GetDeletedAt() != nil {
|
||||
h.mgr.onDeleteSession(ctx, record.GetId())
|
||||
} else {
|
||||
var s session.Session
|
||||
err := record.Data.UnmarshalTo(&s)
|
||||
if err != nil {
|
||||
log.Ctx(ctx).Warn().Err(err).Msg("invalid data in session record, ignoring")
|
||||
} else {
|
||||
h.mgr.onUpdateSession(ctx, &s)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (syncer *dataBrokerSyncer) GetDataBrokerServiceClient() databroker.DataBrokerServiceClient {
|
||||
return syncer.cfg.Load().dataBrokerClient
|
||||
type userSyncerHandler struct {
|
||||
mgr *Manager
|
||||
}
|
||||
|
||||
func (syncer *dataBrokerSyncer) UpdateRecords(ctx context.Context, _ uint64, records []*databroker.Record) {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
case syncer.update <- updateRecordsMessage{records: records}:
|
||||
func newUserSyncer(mgr *Manager) *databroker.Syncer {
|
||||
return databroker.NewSyncer("identity_manager/users", userSyncerHandler{mgr: mgr},
|
||||
databroker.WithTypeURL(grpcutil.GetTypeURL(new(user.User))))
|
||||
}
|
||||
|
||||
func (h userSyncerHandler) ClearRecords(ctx context.Context) {
|
||||
h.mgr.onDeleteAllUsers(ctx)
|
||||
}
|
||||
|
||||
func (h userSyncerHandler) GetDataBrokerServiceClient() databroker.DataBrokerServiceClient {
|
||||
return h.mgr.cfg.Load().dataBrokerClient
|
||||
}
|
||||
|
||||
func (h userSyncerHandler) UpdateRecords(ctx context.Context, _ uint64, records []*databroker.Record) {
|
||||
for _, record := range records {
|
||||
if record.GetDeletedAt() != nil {
|
||||
h.mgr.onDeleteUser(ctx, record.GetId())
|
||||
} else {
|
||||
var u user.User
|
||||
err := record.Data.UnmarshalTo(&u)
|
||||
if err != nil {
|
||||
log.Ctx(ctx).Warn().Err(err).Msg("invalid data in user record, ignoring")
|
||||
} else {
|
||||
h.mgr.onUpdateUser(ctx, &u)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue