This commit is contained in:
Caleb Doxsey 2024-04-19 15:23:55 -06:00
parent a6577fd570
commit a80ef11763
11 changed files with 574 additions and 949 deletions

View file

@ -196,16 +196,15 @@ func (s *Stateful) PersistSession(
sess.SetRawIDToken(claims.RawIDToken) sess.SetRawIDToken(claims.RawIDToken)
sess.AddClaims(claims.Flatten()) sess.AddClaims(claims.Flatten())
var managerUser manager.User u, _ := user.Get(ctx, s.dataBrokerClient, sess.GetUserId())
managerUser.User, _ = user.Get(ctx, s.dataBrokerClient, sess.GetUserId()) if u == nil {
if managerUser.User == nil {
// if no user exists yet, create a new one // if no user exists yet, create a new one
managerUser.User = &user.User{ u = &user.User{
Id: sess.GetUserId(), Id: sess.GetUserId(),
} }
} }
populateUserFromClaims(managerUser.User, claims.Claims) populateUserFromClaims(u, claims.Claims)
_, err := databroker.Put(ctx, s.dataBrokerClient, managerUser.User) _, err := databroker.Put(ctx, s.dataBrokerClient, u)
if err != nil { if err != nil {
return fmt.Errorf("authenticate: error saving user: %w", err) return fmt.Errorf("authenticate: error saving user: %w", err)
} }

View file

@ -10,6 +10,7 @@ import (
var ( var (
defaultSessionRefreshGracePeriod = 1 * time.Minute defaultSessionRefreshGracePeriod = 1 * time.Minute
defaultSessionRefreshCoolOffDuration = 10 * time.Second defaultSessionRefreshCoolOffDuration = 10 * time.Second
defaultUpdateUserInfoInterval = 10 * time.Minute
) )
type config struct { type config struct {
@ -17,6 +18,7 @@ type config struct {
dataBrokerClient databroker.DataBrokerServiceClient dataBrokerClient databroker.DataBrokerServiceClient
sessionRefreshGracePeriod time.Duration sessionRefreshGracePeriod time.Duration
sessionRefreshCoolOffDuration time.Duration sessionRefreshCoolOffDuration time.Duration
updateUserInfoInterval time.Duration
now func() time.Time now func() time.Time
eventMgr *events.Manager eventMgr *events.Manager
} }
@ -26,6 +28,7 @@ func newConfig(options ...Option) *config {
WithSessionRefreshGracePeriod(defaultSessionRefreshGracePeriod)(cfg) WithSessionRefreshGracePeriod(defaultSessionRefreshGracePeriod)(cfg)
WithSessionRefreshCoolOffDuration(defaultSessionRefreshCoolOffDuration)(cfg) WithSessionRefreshCoolOffDuration(defaultSessionRefreshCoolOffDuration)(cfg)
WithNow(time.Now)(cfg) WithNow(time.Now)(cfg)
WithUpdateUserInfoInterval(defaultUpdateUserInfoInterval)(cfg)
for _, option := range options { for _, option := range options {
option(cfg) option(cfg)
} }
@ -76,3 +79,10 @@ func WithEventManager(mgr *events.Manager) Option {
c.eventMgr = mgr c.eventMgr = mgr
} }
} }
// WithUpdateUserInfoInterval sets the update user info interval in the config.
func WithUpdateUserInfoInterval(dur time.Duration) Option {
return func(cfg *config) {
cfg.updateUserInfoInterval = dur
}
}

View file

@ -2,9 +2,9 @@ package manager
import ( import (
"encoding/json" "encoding/json"
"errors"
"time" "time"
"github.com/google/btree"
"google.golang.org/protobuf/types/known/timestamppb" "google.golang.org/protobuf/types/known/timestamppb"
"github.com/pomerium/pomerium/internal/identity" "github.com/pomerium/pomerium/internal/identity"
@ -12,66 +12,18 @@ import (
"github.com/pomerium/pomerium/pkg/grpc/user" "github.com/pomerium/pomerium/pkg/grpc/user"
) )
const userRefreshInterval = 10 * time.Minute func nextSessionRefresh(
s *session.Session,
// A User is a user managed by the Manager. lastRefresh time.Time,
type User struct { gracePeriod time.Duration,
*user.User coolOffDuration time.Duration,
lastRefresh time.Time ) time.Time {
}
// NextRefresh returns the next time the user information needs to be refreshed.
func (u User) NextRefresh() time.Time {
return u.lastRefresh.Add(userRefreshInterval)
}
// UnmarshalJSON unmarshals json data into the user object.
func (u *User) UnmarshalJSON(data []byte) error {
if u.User == nil {
u.User = new(user.User)
}
var raw map[string]json.RawMessage
err := json.Unmarshal(data, &raw)
if err != nil {
return err
}
if name, ok := raw["name"]; ok {
_ = json.Unmarshal(name, &u.User.Name)
delete(raw, "name")
}
if email, ok := raw["email"]; ok {
_ = json.Unmarshal(email, &u.User.Email)
delete(raw, "email")
}
u.AddClaims(identity.NewClaimsFromRaw(raw).Flatten())
return nil
}
// A Session is a session managed by the Manager.
type Session struct {
*session.Session
// lastRefresh is the time of the last refresh attempt (which may or may
// not have succeeded), or else the time the Manager first became aware of
// the session (if it has not yet attempted to refresh this session).
lastRefresh time.Time
// gracePeriod is the amount of time before expiration to attempt a refresh.
gracePeriod time.Duration
// coolOffDuration is the amount of time to wait before attempting another refresh.
coolOffDuration time.Duration
}
// NextRefresh returns the next time the session needs to be refreshed.
func (s Session) NextRefresh() time.Time {
var tm time.Time var tm time.Time
if s.GetOauthToken().GetExpiresAt() != nil { if s.GetOauthToken().GetExpiresAt() != nil {
expiry := s.GetOauthToken().GetExpiresAt().AsTime() expiry := s.GetOauthToken().GetExpiresAt().AsTime()
if s.GetOauthToken().GetExpiresAt().IsValid() && !expiry.IsZero() { if s.GetOauthToken().GetExpiresAt().IsValid() && !expiry.IsZero() {
expiry = expiry.Add(-s.gracePeriod) expiry = expiry.Add(-gracePeriod)
if tm.IsZero() || expiry.Before(tm) { if tm.IsZero() || expiry.Before(tm) {
tm = expiry tm = expiry
} }
@ -88,7 +40,7 @@ func (s Session) NextRefresh() time.Time {
} }
// don't refresh any quicker than the cool-off duration // don't refresh any quicker than the cool-off duration
min := s.lastRefresh.Add(s.coolOffDuration) min := lastRefresh.Add(coolOffDuration)
if tm.Before(min) { if tm.Before(min) {
tm = min tm = min
} }
@ -96,10 +48,33 @@ func (s Session) NextRefresh() time.Time {
return tm return tm
} }
// UnmarshalJSON unmarshals json data into the session object. type multiUnmarshaler []any
func (s *Session) UnmarshalJSON(data []byte) error {
if s.Session == nil { func newMultiUnmarshaler(args ...any) *multiUnmarshaler {
s.Session = new(session.Session) return (*multiUnmarshaler)(&args)
}
func (dst *multiUnmarshaler) UnmarshalJSON(data []byte) error {
var err error
for _, o := range *dst {
if o != nil {
err = errors.Join(err, json.Unmarshal(data, o))
}
}
return err
}
type sessionUnmarshaler struct {
*session.Session
}
func newSessionUnmarshaler(s *session.Session) *sessionUnmarshaler {
return &sessionUnmarshaler{Session: s}
}
func (dst *sessionUnmarshaler) UnmarshalJSON(data []byte) error {
if dst.Session == nil {
return nil
} }
var raw map[string]json.RawMessage var raw map[string]json.RawMessage
@ -108,159 +83,67 @@ func (s *Session) UnmarshalJSON(data []byte) error {
return err return err
} }
if s.Session.IdToken == nil { if dst.Session.IdToken == nil {
s.Session.IdToken = new(session.IDToken) dst.Session.IdToken = new(session.IDToken)
} }
if iss, ok := raw["iss"]; ok { if iss, ok := raw["iss"]; ok {
_ = json.Unmarshal(iss, &s.Session.IdToken.Issuer) _ = json.Unmarshal(iss, &dst.Session.IdToken.Issuer)
delete(raw, "iss") delete(raw, "iss")
} }
if sub, ok := raw["sub"]; ok { if sub, ok := raw["sub"]; ok {
_ = json.Unmarshal(sub, &s.Session.IdToken.Subject) _ = json.Unmarshal(sub, &dst.Session.IdToken.Subject)
delete(raw, "sub") delete(raw, "sub")
} }
if exp, ok := raw["exp"]; ok { if exp, ok := raw["exp"]; ok {
var secs int64 var secs int64
if err := json.Unmarshal(exp, &secs); err == nil { if err := json.Unmarshal(exp, &secs); err == nil {
s.Session.IdToken.ExpiresAt = timestamppb.New(time.Unix(secs, 0)) dst.Session.IdToken.ExpiresAt = timestamppb.New(time.Unix(secs, 0))
} }
delete(raw, "exp") delete(raw, "exp")
} }
if iat, ok := raw["iat"]; ok { if iat, ok := raw["iat"]; ok {
var secs int64 var secs int64
if err := json.Unmarshal(iat, &secs); err == nil { if err := json.Unmarshal(iat, &secs); err == nil {
s.Session.IdToken.IssuedAt = timestamppb.New(time.Unix(secs, 0)) dst.Session.IdToken.IssuedAt = timestamppb.New(time.Unix(secs, 0))
} }
delete(raw, "iat") delete(raw, "iat")
} }
s.AddClaims(identity.NewClaimsFromRaw(raw).Flatten()) dst.Session.AddClaims(identity.NewClaimsFromRaw(raw).Flatten())
return nil return nil
} }
type sessionCollectionItem struct { type userUnmarshaler struct {
Session *user.User
} }
func (item sessionCollectionItem) Less(than btree.Item) bool { func newUserUnmarshaler(u *user.User) *userUnmarshaler {
xUserID, yUserID := item.GetUserId(), than.(sessionCollectionItem).GetUserId() return &userUnmarshaler{User: u}
switch { }
case xUserID < yUserID:
return true func (dst *userUnmarshaler) UnmarshalJSON(data []byte) error {
case yUserID < xUserID: if dst.User == nil {
return false return nil
} }
xID, yID := item.GetId(), than.(sessionCollectionItem).GetId() var raw map[string]json.RawMessage
switch { err := json.Unmarshal(data, &raw)
case xID < yID: if err != nil {
return true return err
case yID < xID:
return false
} }
return false
}
type sessionCollection struct { if name, ok := raw["name"]; ok {
*btree.BTree _ = json.Unmarshal(name, &dst.User.Name)
} delete(raw, "name")
func (c *sessionCollection) Delete(userID, sessionID string) {
c.BTree.Delete(sessionCollectionItem{
Session: Session{
Session: &session.Session{
UserId: userID,
Id: sessionID,
},
},
})
}
func (c *sessionCollection) Get(userID, sessionID string) (Session, bool) {
item := c.BTree.Get(sessionCollectionItem{
Session: Session{
Session: &session.Session{
UserId: userID,
Id: sessionID,
},
},
})
if item == nil {
return Session{}, false
} }
return item.(sessionCollectionItem).Session, true if email, ok := raw["email"]; ok {
} _ = json.Unmarshal(email, &dst.User.Email)
delete(raw, "email")
// GetSessionsForUser gets all the sessions for the given user.
func (c *sessionCollection) GetSessionsForUser(userID string) []Session {
var sessions []Session
c.AscendGreaterOrEqual(sessionCollectionItem{
Session: Session{
Session: &session.Session{
UserId: userID,
},
},
}, func(item btree.Item) bool {
s := item.(sessionCollectionItem).Session
if s.UserId != userID {
return false
}
sessions = append(sessions, s)
return true
})
return sessions
}
func (c *sessionCollection) ReplaceOrInsert(s Session) {
c.BTree.ReplaceOrInsert(sessionCollectionItem{Session: s})
}
type userCollectionItem struct {
User
}
func (item userCollectionItem) Less(than btree.Item) bool {
xID, yID := item.GetId(), than.(userCollectionItem).GetId()
switch {
case xID < yID:
return true
case yID < xID:
return false
} }
return false
}
type userCollection struct { dst.User.AddClaims(identity.NewClaimsFromRaw(raw).Flatten())
*btree.BTree
}
func (c *userCollection) Delete(userID string) { return nil
c.BTree.Delete(userCollectionItem{
User: User{
User: &user.User{
Id: userID,
},
},
})
}
func (c *userCollection) Get(userID string) (User, bool) {
item := c.BTree.Get(userCollectionItem{
User: User{
User: &user.User{
Id: userID,
},
},
})
if item == nil {
return User{}, false
}
return item.(userCollectionItem).User, true
}
func (c *userCollection) ReplaceOrInsert(u User) {
c.BTree.ReplaceOrInsert(userCollectionItem{User: u})
} }

View file

@ -11,20 +11,20 @@ import (
"google.golang.org/protobuf/types/known/timestamppb" "google.golang.org/protobuf/types/known/timestamppb"
"github.com/pomerium/pomerium/pkg/grpc/session" "github.com/pomerium/pomerium/pkg/grpc/session"
"github.com/pomerium/pomerium/pkg/grpc/user"
"github.com/pomerium/pomerium/pkg/protoutil" "github.com/pomerium/pomerium/pkg/protoutil"
) )
func TestUser_UnmarshalJSON(t *testing.T) { func TestUser_UnmarshalJSON(t *testing.T) {
var u User u := new(user.User)
err := json.Unmarshal([]byte(`{ err := json.Unmarshal([]byte(`{
"name": "joe", "name": "joe",
"email": "joe@test.com", "email": "joe@test.com",
"some-other-claim": "xyz" "some-other-claim": "xyz"
}`), &u) }`), newUserUnmarshaler(u))
assert.NoError(t, err) assert.NoError(t, err)
assert.NotNil(t, u.User) assert.Equal(t, "joe", u.Name)
assert.Equal(t, "joe", u.User.Name) assert.Equal(t, "joe@test.com", u.Email)
assert.Equal(t, "joe@test.com", u.User.Email)
assert.Equal(t, map[string]*structpb.ListValue{ assert.Equal(t, map[string]*structpb.ListValue{
"some-other-claim": {Values: []*structpb.Value{protoutil.ToStruct("xyz")}}, "some-other-claim": {Values: []*structpb.Value{protoutil.ToStruct("xyz")}},
}, u.Claims) }, u.Claims)
@ -32,42 +32,38 @@ func TestUser_UnmarshalJSON(t *testing.T) {
func TestSession_NextRefresh(t *testing.T) { func TestSession_NextRefresh(t *testing.T) {
tm1 := time.Date(2020, 6, 5, 12, 0, 0, 0, time.UTC) tm1 := time.Date(2020, 6, 5, 12, 0, 0, 0, time.UTC)
s := Session{ s := &session.Session{}
Session: &session.Session{}, gracePeriod := time.Second * 10
lastRefresh: tm1, coolOffDuration := time.Minute
gracePeriod: time.Second * 10, assert.Equal(t, tm1.Add(time.Minute), nextSessionRefresh(s, tm1, gracePeriod, coolOffDuration))
coolOffDuration: time.Minute,
}
assert.Equal(t, tm1.Add(time.Minute), s.NextRefresh())
tm2 := time.Date(2020, 6, 5, 13, 0, 0, 0, time.UTC) tm2 := time.Date(2020, 6, 5, 13, 0, 0, 0, time.UTC)
s.OauthToken = &session.OAuthToken{ s.OauthToken = &session.OAuthToken{
ExpiresAt: timestamppb.New(tm2), ExpiresAt: timestamppb.New(tm2),
} }
assert.Equal(t, tm2.Add(-time.Second*10), s.NextRefresh()) assert.Equal(t, tm2.Add(-time.Second*10), nextSessionRefresh(s, tm1, gracePeriod, coolOffDuration))
tm3 := time.Date(2020, 6, 5, 12, 15, 0, 0, time.UTC) tm3 := time.Date(2020, 6, 5, 12, 15, 0, 0, time.UTC)
s.ExpiresAt = timestamppb.New(tm3) s.ExpiresAt = timestamppb.New(tm3)
assert.Equal(t, tm3, s.NextRefresh()) assert.Equal(t, tm3, nextSessionRefresh(s, tm1, gracePeriod, coolOffDuration))
} }
func TestSession_UnmarshalJSON(t *testing.T) { func TestSession_UnmarshalJSON(t *testing.T) {
tm := time.Date(2020, 6, 5, 12, 0, 0, 0, time.UTC) tm := time.Date(2020, 6, 5, 12, 0, 0, 0, time.UTC)
var s Session s := new(session.Session)
err := json.Unmarshal([]byte(`{ err := json.Unmarshal([]byte(`{
"iss": "https://some.issuer.com", "iss": "https://some.issuer.com",
"sub": "subject", "sub": "subject",
"exp": `+fmt.Sprint(tm.Unix())+`, "exp": `+fmt.Sprint(tm.Unix())+`,
"iat": `+fmt.Sprint(tm.Unix())+`, "iat": `+fmt.Sprint(tm.Unix())+`,
"some-other-claim": "xyz" "some-other-claim": "xyz"
}`), &s) }`), newSessionUnmarshaler(s))
assert.NoError(t, err) assert.NoError(t, err)
assert.NotNil(t, s.Session) assert.NotNil(t, s.IdToken)
assert.NotNil(t, s.Session.IdToken) assert.Equal(t, "https://some.issuer.com", s.IdToken.Issuer)
assert.Equal(t, "https://some.issuer.com", s.Session.IdToken.Issuer) assert.Equal(t, "subject", s.IdToken.Subject)
assert.Equal(t, "subject", s.Session.IdToken.Subject) assert.Equal(t, timestamppb.New(tm), s.IdToken.ExpiresAt)
assert.Equal(t, timestamppb.New(tm), s.Session.IdToken.ExpiresAt) assert.Equal(t, timestamppb.New(tm), s.IdToken.IssuedAt)
assert.Equal(t, timestamppb.New(tm), s.Session.IdToken.IssuedAt)
assert.Equal(t, map[string]*structpb.ListValue{ assert.Equal(t, map[string]*structpb.ListValue{
"some-other-claim": {Values: []*structpb.Value{protoutil.ToStruct("xyz")}}, "some-other-claim": {Values: []*structpb.Value{protoutil.ToStruct("xyz")}},
}, s.Claims) }, s.Claims)

View file

@ -3,7 +3,6 @@ package manager
import ( import (
"cmp" "cmp"
"slices" "slices"
"sync"
"google.golang.org/protobuf/proto" "google.golang.org/protobuf/proto"
@ -11,44 +10,54 @@ import (
"github.com/pomerium/pomerium/pkg/grpc/user" "github.com/pomerium/pomerium/pkg/grpc/user"
) )
// dataStore stores session and user data. All public methods are thread-safe. // dataStore stores session and user data
type dataStore struct { type dataStore struct {
mu sync.Mutex
sessions map[string]*session.Session sessions map[string]*session.Session
users map[string]*user.User users map[string]*user.User
userIDToSessionIDs map[string]map[string]struct{} userIDToSessionIDs map[string]map[string]struct{}
} }
func newDataStore() *dataStore { func newDataStore() *dataStore {
return &dataStore{ ds := new(dataStore)
sessions: make(map[string]*session.Session), ds.deleteAllSessions()
users: make(map[string]*user.User), ds.deleteAllUsers()
userIDToSessionIDs: make(map[string]map[string]struct{}), return ds
}
func (ds *dataStore) deleteAllSessions() {
ds.sessions = make(map[string]*session.Session)
ds.userIDToSessionIDs = make(map[string]map[string]struct{})
}
func (ds *dataStore) deleteAllUsers() {
ds.users = make(map[string]*user.User)
}
func (ds *dataStore) deleteSession(sessionID string) {
s := ds.sessions[sessionID]
delete(ds.sessions, sessionID)
if s.GetUserId() == "" {
return
}
m := ds.userIDToSessionIDs[s.GetUserId()]
if m != nil {
delete(m, s.GetId())
}
if len(m) == 0 {
delete(ds.userIDToSessionIDs, s.GetUserId())
} }
} }
// DeleteSession deletes a session. func (ds *dataStore) deleteUser(userID string) {
func (ds *dataStore) DeleteSession(sID string) {
ds.mu.Lock()
ds.deleteSessionLocked(sID)
ds.mu.Unlock()
}
// DeleteUser deletes a user.
func (ds *dataStore) DeleteUser(userID string) {
ds.mu.Lock()
delete(ds.users, userID) delete(ds.users, userID)
ds.mu.Unlock()
} }
// GetSessionAndUser gets a session and its associated user. func (ds *dataStore) getSessionAndUser(sessionID string) (s *session.Session, u *user.User) {
func (ds *dataStore) GetSessionAndUser(sessionID string) (s *session.Session, u *user.User) {
ds.mu.Lock()
s = ds.sessions[sessionID] s = ds.sessions[sessionID]
if s.GetUserId() != "" { if s.GetUserId() != "" {
u = ds.users[s.GetUserId()] u = ds.users[s.GetUserId()]
} }
ds.mu.Unlock()
// clone to avoid sharing memory // clone to avoid sharing memory
s = clone(s) s = clone(s)
@ -56,14 +65,11 @@ func (ds *dataStore) GetSessionAndUser(sessionID string) (s *session.Session, u
return s, u return s, u
} }
// GetUserAndSessions gets a user and all of its associated sessions. func (ds *dataStore) getUserAndSessions(userID string) (u *user.User, ss []*session.Session) {
func (ds *dataStore) GetUserAndSessions(userID string) (u *user.User, ss []*session.Session) {
ds.mu.Lock()
u = ds.users[userID] u = ds.users[userID]
for sessionID := range ds.userIDToSessionIDs[userID] { for sessionID := range ds.userIDToSessionIDs[userID] {
ss = append(ss, ds.sessions[sessionID]) ss = append(ss, ds.sessions[sessionID])
} }
ds.mu.Unlock()
// remove nils and sort by id // remove nils and sort by id
ss = slices.Compact(ss) ss = slices.Compact(ss)
@ -79,14 +85,12 @@ func (ds *dataStore) GetUserAndSessions(userID string) (u *user.User, ss []*sess
return u, ss return u, ss
} }
// PutSession stores the session. func (ds *dataStore) putSession(s *session.Session) {
func (ds *dataStore) PutSession(s *session.Session) {
// clone to avoid sharing memory // clone to avoid sharing memory
s = clone(s) s = clone(s)
ds.mu.Lock()
if s.GetId() != "" { if s.GetId() != "" {
ds.deleteSessionLocked(s.GetId()) ds.deleteSession(s.GetId())
ds.sessions[s.GetId()] = s ds.sessions[s.GetId()] = s
if s.GetUserId() != "" { if s.GetUserId() != "" {
m, ok := ds.userIDToSessionIDs[s.GetUserId()] m, ok := ds.userIDToSessionIDs[s.GetUserId()]
@ -97,35 +101,15 @@ func (ds *dataStore) PutSession(s *session.Session) {
m[s.GetId()] = struct{}{} m[s.GetId()] = struct{}{}
} }
} }
ds.mu.Unlock()
} }
// PutUser stores the user. func (ds *dataStore) putUser(u *user.User) {
func (ds *dataStore) PutUser(u *user.User) {
// clone to avoid sharing memory // clone to avoid sharing memory
u = clone(u) u = clone(u)
ds.mu.Lock()
if u.GetId() != "" { if u.GetId() != "" {
ds.users[u.GetId()] = u ds.users[u.GetId()] = u
} }
ds.mu.Unlock()
}
func (ds *dataStore) deleteSessionLocked(sID string) {
s := ds.sessions[sID]
delete(ds.sessions, sID)
if s.GetUserId() == "" {
return
}
m := ds.userIDToSessionIDs[s.GetUserId()]
if m != nil {
delete(m, s.GetId())
}
if len(m) == 0 {
delete(ds.userIDToSessionIDs, s.GetUserId())
}
} }
// clone clones a protobuf message // clone clones a protobuf message

View file

@ -15,32 +15,32 @@ func TestDataStore(t *testing.T) {
t.Parallel() t.Parallel()
ds := newDataStore() ds := newDataStore()
s, u := ds.GetSessionAndUser("S1") s, u := ds.getSessionAndUser("S1")
assert.Nil(t, s, "should return a nil session when none exists") assert.Nil(t, s, "should return a nil session when none exists")
assert.Nil(t, u, "should return a nil user when none exists") assert.Nil(t, u, "should return a nil user when none exists")
u, ss := ds.GetUserAndSessions("U1") u, ss := ds.getUserAndSessions("U1")
assert.Nil(t, u, "should return a nil user when none exists") assert.Nil(t, u, "should return a nil user when none exists")
assert.Empty(t, ss, "should return an empty list of sessions when no user exists") assert.Empty(t, ss, "should return an empty list of sessions when no user exists")
s = &session.Session{Id: "S1", UserId: "U1"} s = &session.Session{Id: "S1", UserId: "U1"}
ds.PutSession(s) ds.putSession(s)
s1, u1 := ds.GetSessionAndUser("S1") s1, u1 := ds.getSessionAndUser("S1")
assert.NotNil(t, s1, "should return a non-nil session") assert.NotNil(t, s1, "should return a non-nil session")
assert.False(t, s == s1, "should return different pointers") assert.False(t, s == s1, "should return different pointers")
assert.Empty(t, cmp.Diff(s, s1, protocmp.Transform()), "should be the same as was entered") assert.Empty(t, cmp.Diff(s, s1, protocmp.Transform()), "should be the same as was entered")
assert.Nil(t, u1, "should return a nil user when only the session exists") assert.Nil(t, u1, "should return a nil user when only the session exists")
ds.PutUser(&user.User{ ds.putUser(&user.User{
Id: "U1", Id: "U1",
}) })
_, u1 = ds.GetSessionAndUser("S1") _, u1 = ds.getSessionAndUser("S1")
assert.NotNil(t, u1, "should return a user now that it has been added") assert.NotNil(t, u1, "should return a user now that it has been added")
ds.PutSession(&session.Session{Id: "S4", UserId: "U1"}) ds.putSession(&session.Session{Id: "S4", UserId: "U1"})
ds.PutSession(&session.Session{Id: "S3", UserId: "U1"}) ds.putSession(&session.Session{Id: "S3", UserId: "U1"})
ds.PutSession(&session.Session{Id: "S2", UserId: "U1"}) ds.putSession(&session.Session{Id: "S2", UserId: "U1"})
u, ss = ds.GetUserAndSessions("U1") u, ss = ds.getUserAndSessions("U1")
assert.NotNil(t, u) assert.NotNil(t, u)
assert.Empty(t, cmp.Diff(ss, []*session.Session{ assert.Empty(t, cmp.Diff(ss, []*session.Session{
{Id: "S1", UserId: "U1"}, {Id: "S1", UserId: "U1"},
@ -49,9 +49,9 @@ func TestDataStore(t *testing.T) {
{Id: "S4", UserId: "U1"}, {Id: "S4", UserId: "U1"},
}, protocmp.Transform()), "should return all sessions in id order") }, protocmp.Transform()), "should return all sessions in id order")
ds.DeleteSession("S4") ds.deleteSession("S4")
u, ss = ds.GetUserAndSessions("U1") u, ss = ds.getUserAndSessions("U1")
assert.NotNil(t, u) assert.NotNil(t, u)
assert.Empty(t, cmp.Diff(ss, []*session.Session{ assert.Empty(t, cmp.Diff(ss, []*session.Session{
{Id: "S1", UserId: "U1"}, {Id: "S1", UserId: "U1"},
@ -59,8 +59,8 @@ func TestDataStore(t *testing.T) {
{Id: "S3", UserId: "U1"}, {Id: "S3", UserId: "U1"},
}, protocmp.Transform()), "should return all sessions in id order") }, protocmp.Transform()), "should return all sessions in id order")
ds.DeleteUser("U1") ds.deleteUser("U1")
u, ss = ds.GetUserAndSessions("U1") u, ss = ds.getUserAndSessions("U1")
assert.Nil(t, u) assert.Nil(t, u)
assert.Empty(t, cmp.Diff(ss, []*session.Session{ assert.Empty(t, cmp.Diff(ss, []*session.Session{
{Id: "S1", UserId: "U1"}, {Id: "S1", UserId: "U1"},
@ -68,11 +68,11 @@ func TestDataStore(t *testing.T) {
{Id: "S3", UserId: "U1"}, {Id: "S3", UserId: "U1"},
}, protocmp.Transform()), "should still return all sessions in id order") }, protocmp.Transform()), "should still return all sessions in id order")
ds.DeleteSession("S1") ds.deleteSession("S1")
ds.DeleteSession("S2") ds.deleteSession("S2")
ds.DeleteSession("S3") ds.deleteSession("S3")
u, ss = ds.GetUserAndSessions("U1") u, ss = ds.getUserAndSessions("U1")
assert.Nil(t, u) assert.Nil(t, u)
assert.Empty(t, ss) assert.Empty(t, ss)
} }

View file

@ -3,10 +3,9 @@ package manager
import ( import (
"context" "context"
"errors" "sync"
"time" "time"
"github.com/google/btree"
"github.com/rs/zerolog" "github.com/rs/zerolog"
"golang.org/x/oauth2" "golang.org/x/oauth2"
"golang.org/x/sync/errgroup" "golang.org/x/sync/errgroup"
@ -19,7 +18,6 @@ import (
"github.com/pomerium/pomerium/internal/events" "github.com/pomerium/pomerium/internal/events"
"github.com/pomerium/pomerium/internal/identity/identity" "github.com/pomerium/pomerium/internal/identity/identity"
"github.com/pomerium/pomerium/internal/log" "github.com/pomerium/pomerium/internal/log"
"github.com/pomerium/pomerium/internal/scheduler"
"github.com/pomerium/pomerium/internal/telemetry/metrics" "github.com/pomerium/pomerium/internal/telemetry/metrics"
"github.com/pomerium/pomerium/pkg/grpc/databroker" "github.com/pomerium/pomerium/pkg/grpc/databroker"
"github.com/pomerium/pomerium/pkg/grpc/session" "github.com/pomerium/pomerium/pkg/grpc/session"
@ -35,21 +33,14 @@ type Authenticator interface {
UpdateUserInfo(context.Context, *oauth2.Token, interface{}) error UpdateUserInfo(context.Context, *oauth2.Token, interface{}) error
} }
type (
updateRecordsMessage struct {
records []*databroker.Record
}
)
// A Manager refreshes identity information using session and user data. // A Manager refreshes identity information using session and user data.
type Manager struct { type Manager struct {
cfg *atomicutil.Value[*config] cfg *atomicutil.Value[*config]
sessionScheduler *scheduler.Scheduler mu sync.Mutex
userScheduler *scheduler.Scheduler dataStore *dataStore
refreshSessionSchedulers map[string]*refreshSessionScheduler
sessions sessionCollection updateUserInfoSchedulers map[string]*updateUserInfoScheduler
users userCollection
} }
// New creates a new identity manager. // New creates a new identity manager.
@ -59,25 +50,24 @@ func New(
mgr := &Manager{ mgr := &Manager{
cfg: atomicutil.NewValue(newConfig()), cfg: atomicutil.NewValue(newConfig()),
sessionScheduler: scheduler.New(), dataStore: newDataStore(),
userScheduler: scheduler.New(), refreshSessionSchedulers: make(map[string]*refreshSessionScheduler),
updateUserInfoSchedulers: make(map[string]*updateUserInfoScheduler),
} }
mgr.reset()
mgr.UpdateConfig(options...) mgr.UpdateConfig(options...)
return mgr return mgr
} }
func withLog(ctx context.Context) context.Context {
return log.WithContext(ctx, func(c zerolog.Context) zerolog.Context {
return c.Str("service", "identity_manager")
})
}
// UpdateConfig updates the manager with the new options. // UpdateConfig updates the manager with the new options.
func (mgr *Manager) UpdateConfig(options ...Option) { func (mgr *Manager) UpdateConfig(options ...Option) {
mgr.cfg.Store(newConfig(options...)) mgr.cfg.Store(newConfig(options...))
} }
// GetDataBrokerServiceClient gets the databroker client.
func (mgr *Manager) GetDataBrokerServiceClient() databroker.DataBrokerServiceClient {
return mgr.cfg.Load().dataBrokerClient
}
// Run runs the manager. This method blocks until an error occurs or the given context is canceled. // Run runs the manager. This method blocks until an error occurs or the given context is canceled.
func (mgr *Manager) Run(ctx context.Context) error { func (mgr *Manager) Run(ctx context.Context) error {
leaser := databroker.NewLeaser("identity_manager", time.Second*30, mgr) leaser := databroker.NewLeaser("identity_manager", time.Second*30, mgr)
@ -86,173 +76,144 @@ func (mgr *Manager) Run(ctx context.Context) error {
// RunLeased runs the identity manager when a lease is acquired. // RunLeased runs the identity manager when a lease is acquired.
func (mgr *Manager) RunLeased(ctx context.Context) error { func (mgr *Manager) RunLeased(ctx context.Context) error {
ctx = withLog(ctx) ctx = log.WithContext(ctx, func(c zerolog.Context) zerolog.Context {
update := make(chan updateRecordsMessage, 1) return c.Str("service", "identity_manager")
clear := make(chan struct{}, 1) })
syncer := newDataBrokerSyncer(ctx, mgr.cfg, update, clear)
eg, ctx := errgroup.WithContext(ctx) eg, ctx := errgroup.WithContext(ctx)
eg.Go(func() error { eg.Go(func() error {
return syncer.Run(ctx) sessionSyncer := newSessionSyncer(mgr)
defer sessionSyncer.Close()
return sessionSyncer.Run(ctx)
}) })
eg.Go(func() error { eg.Go(func() error {
return mgr.refreshLoop(ctx, update, clear) userSyncer := newUserSyncer(mgr)
defer userSyncer.Close()
return userSyncer.Run(ctx)
}) })
return eg.Wait() return eg.Wait()
} }
// GetDataBrokerServiceClient gets the databroker client. func (mgr *Manager) onDeleteAllSessions(ctx context.Context) {
func (mgr *Manager) GetDataBrokerServiceClient() databroker.DataBrokerServiceClient { log.Ctx(ctx).Debug().Msg("all session deleted")
return mgr.cfg.Load().dataBrokerClient
mgr.mu.Lock()
mgr.dataStore.deleteAllSessions()
for sID, rss := range mgr.refreshSessionSchedulers {
rss.Stop()
delete(mgr.refreshSessionSchedulers, sID)
}
mgr.mu.Unlock()
} }
func (mgr *Manager) now() time.Time { func (mgr *Manager) onDeleteAllUsers(ctx context.Context) {
return mgr.cfg.Load().now() log.Ctx(ctx).Debug().Msg("all users deleted")
mgr.mu.Lock()
mgr.dataStore.deleteAllUsers()
for uID, uuis := range mgr.updateUserInfoSchedulers {
uuis.Stop()
delete(mgr.updateUserInfoSchedulers, uID)
}
mgr.mu.Unlock()
} }
func (mgr *Manager) refreshLoop(ctx context.Context, update <-chan updateRecordsMessage, clear <-chan struct{}) error { func (mgr *Manager) onDeleteSession(ctx context.Context, sessionID string) {
// wait for initial sync log.Ctx(ctx).Debug().Str("session_id", sessionID).Msg("session deleted")
select {
case <-ctx.Done(): mgr.mu.Lock()
return ctx.Err() mgr.dataStore.deleteSession(sessionID)
case <-clear: if rss, ok := mgr.refreshSessionSchedulers[sessionID]; ok {
mgr.reset() rss.Stop()
} delete(mgr.refreshSessionSchedulers, sessionID)
select {
case <-ctx.Done():
return ctx.Err()
case msg := <-update:
mgr.onUpdateRecords(ctx, msg)
}
log.Debug(ctx).
Int("sessions", mgr.sessions.Len()).
Int("users", mgr.users.Len()).
Msg("initial sync complete")
// start refreshing
maxWait := time.Minute * 10
var nextTime time.Time
timer := time.NewTimer(0)
defer timer.Stop()
for {
select {
case <-ctx.Done():
return ctx.Err()
case <-clear:
mgr.reset()
case msg := <-update:
mgr.onUpdateRecords(ctx, msg)
case <-timer.C:
}
now := mgr.now()
nextTime = now.Add(maxWait)
// refresh sessions
for {
tm, key := mgr.sessionScheduler.Next()
if now.Before(tm) {
if tm.Before(nextTime) {
nextTime = tm
}
break
}
mgr.sessionScheduler.Remove(key)
userID, sessionID := fromSessionSchedulerKey(key)
mgr.refreshSession(ctx, userID, sessionID)
}
// refresh users
for {
tm, key := mgr.userScheduler.Next()
if now.Before(tm) {
if tm.Before(nextTime) {
nextTime = tm
}
break
}
mgr.userScheduler.Remove(key)
mgr.refreshUser(ctx, key)
}
metrics.RecordIdentityManagerLastRefresh(ctx)
timer.Reset(time.Until(nextTime))
} }
mgr.mu.Unlock()
} }
// refreshSession handles two distinct session lifecycle events: func (mgr *Manager) onDeleteUser(ctx context.Context, userID string) {
// log.Ctx(ctx).Debug().Str("user_id", userID).Msg("user deleted")
// 1. If the session itself has expired, delete the session.
// 2. If the session's underlying OAuth2 access token is nearing expiration mgr.mu.Lock()
// (but the session itself is still valid), refresh the access token. mgr.dataStore.deleteUser(userID)
// if uuis, ok := mgr.updateUserInfoSchedulers[userID]; ok {
// After a successful access token refresh, this method will also trigger a uuis.Stop()
// user info refresh. If an access token refresh or a user info refresh fails delete(mgr.updateUserInfoSchedulers, userID)
// with a permanent error, the session will be deleted. }
func (mgr *Manager) refreshSession(ctx context.Context, userID, sessionID string) { mgr.mu.Unlock()
log.Info(ctx). }
Str("user_id", userID).
func (mgr *Manager) onUpdateSession(ctx context.Context, s *session.Session) {
log.Ctx(ctx).Debug().Str("session_id", s.GetId()).Msg("session updated")
mgr.mu.Lock()
mgr.dataStore.putSession(s)
rss, ok := mgr.refreshSessionSchedulers[s.GetId()]
if !ok {
rss = newRefreshSessionScheduler(ctx, mgr, s.GetId())
mgr.refreshSessionSchedulers[s.GetId()] = rss
}
rss.Update(s)
mgr.mu.Unlock()
}
func (mgr *Manager) onUpdateUser(ctx context.Context, u *user.User) {
log.Ctx(ctx).Debug().Str("user_id", u.GetId()).Msg("user updated")
mgr.mu.Lock()
mgr.dataStore.putUser(u)
_, ok := mgr.updateUserInfoSchedulers[u.GetId()]
if !ok {
uuis := newUpdateUserInfoScheduler(ctx, mgr, u.GetId())
mgr.updateUserInfoSchedulers[u.GetId()] = uuis
}
mgr.mu.Unlock()
}
func (mgr *Manager) refreshSession(ctx context.Context, sessionID string) {
log.Ctx(ctx).Debug().
Str("session_id", sessionID). Str("session_id", sessionID).
Msg("refreshing session") Msg("refreshing session")
s, ok := mgr.sessions.Get(userID, sessionID) mgr.mu.Lock()
if !ok { s, u := mgr.dataStore.getSessionAndUser(sessionID)
log.Warn(ctx). mgr.mu.Unlock()
Str("user_id", userID).
if s == nil {
log.Ctx(ctx).Warn().
Str("user_id", u.GetId()).
Str("session_id", sessionID). Str("session_id", sessionID).
Msg("no session found for refresh") Msg("no session found for refresh")
return return
} }
s.lastRefresh = mgr.now()
if mgr.refreshSessionInternal(ctx, userID, sessionID, &s) {
mgr.sessions.ReplaceOrInsert(s)
mgr.sessionScheduler.Add(s.NextRefresh(), toSessionSchedulerKey(userID, sessionID))
}
}
// refreshSessionInternal performs the core refresh logic and returns true if
// the session should be scheduled for refresh again, or false if not.
func (mgr *Manager) refreshSessionInternal(
ctx context.Context, userID, sessionID string, s *Session,
) bool {
authenticator := mgr.cfg.Load().authenticator authenticator := mgr.cfg.Load().authenticator
if authenticator == nil { if authenticator == nil {
log.Info(ctx). log.Ctx(ctx).Info().
Str("user_id", userID). Str("user_id", s.GetUserId()).
Str("session_id", sessionID). Str("session_id", s.GetId()).
Msg("no authenticator defined, deleting session") Msg("no authenticator defined, deleting session")
mgr.deleteSession(ctx, userID, sessionID) mgr.deleteSession(ctx, sessionID)
return false return
} }
expiry := s.GetExpiresAt().AsTime() expiry := s.GetExpiresAt().AsTime()
if !expiry.After(mgr.now()) { if !expiry.After(mgr.cfg.Load().now()) {
log.Info(ctx). log.Info(ctx).
Str("user_id", userID). Str("user_id", s.GetUserId()).
Str("session_id", sessionID). Str("session_id", s.GetId()).
Msg("deleting expired session") Msg("deleting expired session")
mgr.deleteSession(ctx, userID, sessionID) mgr.deleteSession(ctx, sessionID)
return false return
} }
if s.Session == nil || s.Session.OauthToken == nil { if s.GetOauthToken() == nil {
log.Warn(ctx). log.Warn(ctx).
Str("user_id", userID). Str("user_id", s.GetUserId()).
Str("session_id", sessionID). Str("session_id", s.GetId()).
Msg("no session oauth2 token found for refresh") Msg("no session oauth2 token found for refresh")
return false return
} }
newToken, err := authenticator.Refresh(ctx, FromOAuthToken(s.OauthToken), s) newToken, err := authenticator.Refresh(ctx, FromOAuthToken(s.OauthToken), newSessionUnmarshaler(s))
metrics.RecordIdentityManagerSessionRefresh(ctx, err) metrics.RecordIdentityManagerSessionRefresh(ctx, err)
mgr.recordLastError(metrics_ids.IdentityManagerLastSessionRefreshError, err) mgr.recordLastError(metrics_ids.IdentityManagerLastSessionRefreshError, err)
if isTemporaryError(err) { if isTemporaryError(err) {
@ -260,18 +221,18 @@ func (mgr *Manager) refreshSessionInternal(
Str("user_id", s.GetUserId()). Str("user_id", s.GetUserId()).
Str("session_id", s.GetId()). Str("session_id", s.GetId()).
Msg("failed to refresh oauth2 token") Msg("failed to refresh oauth2 token")
return true return
} else if err != nil { } else if err != nil {
log.Error(ctx).Err(err). log.Error(ctx).Err(err).
Str("user_id", s.GetUserId()). Str("user_id", s.GetUserId()).
Str("session_id", s.GetId()). Str("session_id", s.GetId()).
Msg("failed to refresh oauth2 token, deleting session") Msg("failed to refresh oauth2 token, deleting session")
mgr.deleteSession(ctx, userID, sessionID) mgr.deleteSession(ctx, sessionID)
return false return
} }
s.OauthToken = ToOAuthToken(newToken) s.OauthToken = ToOAuthToken(newToken)
err = authenticator.UpdateUserInfo(ctx, FromOAuthToken(s.OauthToken), s) err = authenticator.UpdateUserInfo(ctx, FromOAuthToken(s.OauthToken), newMultiUnmarshaler(newUserUnmarshaler(u), newSessionUnmarshaler(s)))
metrics.RecordIdentityManagerUserRefresh(ctx, err) metrics.RecordIdentityManagerUserRefresh(ctx, err)
mgr.recordLastError(metrics_ids.IdentityManagerLastUserRefreshError, err) mgr.recordLastError(metrics_ids.IdentityManagerLastUserRefreshError, err)
if isTemporaryError(err) { if isTemporaryError(err) {
@ -279,184 +240,162 @@ func (mgr *Manager) refreshSessionInternal(
Str("user_id", s.GetUserId()). Str("user_id", s.GetUserId()).
Str("session_id", s.GetId()). Str("session_id", s.GetId()).
Msg("failed to update user info") Msg("failed to update user info")
return true return
} else if err != nil { } else if err != nil {
log.Error(ctx).Err(err). log.Error(ctx).Err(err).
Str("user_id", s.GetUserId()). Str("user_id", s.GetUserId()).
Str("session_id", s.GetId()). Str("session_id", s.GetId()).
Msg("failed to update user info, deleting session") Msg("failed to update user info, deleting session")
mgr.deleteSession(ctx, userID, sessionID) mgr.deleteSession(ctx, sessionID)
return false return
} }
fm, err := fieldmaskpb.New(s.Session, "oauth_token", "id_token", "claims") mgr.updateSession(ctx, s)
if err != nil { if u != nil {
log.Error(ctx).Err(err).Msg("internal error") mgr.updateUser(ctx, u)
return false
} }
if _, err := session.Patch(ctx, mgr.cfg.Load().dataBrokerClient, s.Session, fm); err != nil {
log.Error(ctx).Err(err).
Str("user_id", s.GetUserId()).
Str("session_id", s.GetId()).
Msg("failed to update session")
}
return true
} }
func (mgr *Manager) refreshUser(ctx context.Context, userID string) { func (mgr *Manager) updateUserInfo(ctx context.Context, userID string) {
log.Info(ctx). log.Ctx(ctx).Info().Str("user_id", userID).Msg("updating user info")
Str("user_id", userID).
Msg("refreshing user")
authenticator := mgr.cfg.Load().authenticator authenticator := mgr.cfg.Load().authenticator
if authenticator == nil { if authenticator == nil {
return return
} }
u, ok := mgr.users.Get(userID) mgr.mu.Lock()
if !ok { u, ss := mgr.dataStore.getUserAndSessions(userID)
log.Warn(ctx). mgr.mu.Unlock()
if u == nil {
log.Ctx(ctx).Warn().
Str("user_id", userID). Str("user_id", userID).
Msg("no user found for refresh") Msg("no user found for update")
return return
} }
u.lastRefresh = mgr.now()
mgr.userScheduler.Add(u.NextRefresh(), u.GetId())
for _, s := range mgr.sessions.GetSessionsForUser(userID) { for _, s := range ss {
if s.Session == nil || s.Session.OauthToken == nil { if s.GetOauthToken() == nil {
log.Warn(ctx). log.Ctx(ctx).Warn().
Str("user_id", userID). Str("user_id", s.GetUserId()).
Msg("no session oauth2 token found for refresh") Str("session_id", s.GetId()).
Msg("no session oauth2 token found for updating user info")
continue continue
} }
err := authenticator.UpdateUserInfo(ctx, FromOAuthToken(s.OauthToken), &u) err := authenticator.UpdateUserInfo(ctx, FromOAuthToken(s.GetOauthToken()), newMultiUnmarshaler(newUserUnmarshaler(u), newSessionUnmarshaler(s)))
metrics.RecordIdentityManagerUserRefresh(ctx, err) metrics.RecordIdentityManagerUserRefresh(ctx, err)
mgr.recordLastError(metrics_ids.IdentityManagerLastUserRefreshError, err) mgr.recordLastError(metrics_ids.IdentityManagerLastUserRefreshError, err)
if isTemporaryError(err) { if isTemporaryError(err) {
log.Error(ctx).Err(err). log.Ctx(ctx).Error().Err(err).
Str("user_id", s.GetUserId()). Str("user_id", s.GetUserId()).
Str("session_id", s.GetId()). Str("session_id", s.GetId()).
Msg("failed to update user info") Msg("failed to update user info")
return continue
} else if err != nil { } else if err != nil {
log.Error(ctx).Err(err). log.Ctx(ctx).Error().Err(err).
Str("user_id", s.GetUserId()). Str("user_id", s.GetUserId()).
Str("session_id", s.GetId()). Str("session_id", s.GetId()).
Msg("failed to update user info, deleting session") Msg("failed to update user info, deleting session")
mgr.deleteSession(ctx, userID, s.GetId()) mgr.deleteSession(ctx, s.GetId())
continue continue
} }
res, err := databroker.Put(ctx, mgr.cfg.Load().dataBrokerClient, u.User) mgr.updateSession(ctx, s)
if err != nil { mgr.updateUser(ctx, u)
log.Error(ctx).Err(err).
Str("user_id", s.GetUserId()).
Str("session_id", s.GetId()).
Msg("failed to update user")
continue
}
mgr.onUpdateUser(ctx, res.GetRecords()[0], u.User)
} }
} }
func (mgr *Manager) onUpdateRecords(ctx context.Context, msg updateRecordsMessage) { // deleteSession deletes a session from the databroke, the local data store, and the schedulers
for _, record := range msg.records { func (mgr *Manager) deleteSession(ctx context.Context, sessionID string) {
switch record.GetType() { log.Ctx(ctx).Debug().
case grpcutil.GetTypeURL(new(session.Session)): Str("session_id", sessionID).
var pbSession session.Session Msg("deleting session")
err := record.GetData().UnmarshalTo(&pbSession)
if err != nil { mgr.mu.Lock()
log.Warn(ctx).Msgf("error unmarshaling session: %s", err) mgr.dataStore.deleteSession(sessionID)
continue if rss, ok := mgr.refreshSessionSchedulers[sessionID]; ok {
} rss.Stop()
mgr.onUpdateSession(record, &pbSession) delete(mgr.refreshSessionSchedulers, sessionID)
case grpcutil.GetTypeURL(new(user.User)):
var pbUser user.User
err := record.GetData().UnmarshalTo(&pbUser)
if err != nil {
log.Warn(ctx).Msgf("error unmarshaling user: %s", err)
continue
}
mgr.onUpdateUser(ctx, record, &pbUser)
}
} }
} mgr.mu.Unlock()
func (mgr *Manager) onUpdateSession(record *databroker.Record, session *session.Session) { res, err := mgr.cfg.Load().dataBrokerClient.Get(ctx, &databroker.GetRequest{
mgr.sessionScheduler.Remove(toSessionSchedulerKey(session.GetUserId(), session.GetId()))
if record.GetDeletedAt() != nil {
mgr.sessions.Delete(session.GetUserId(), session.GetId())
return
}
// update session
s, _ := mgr.sessions.Get(session.GetUserId(), session.GetId())
if s.lastRefresh.IsZero() {
s.lastRefresh = mgr.now()
}
s.gracePeriod = mgr.cfg.Load().sessionRefreshGracePeriod
s.coolOffDuration = mgr.cfg.Load().sessionRefreshCoolOffDuration
s.Session = session
mgr.sessions.ReplaceOrInsert(s)
mgr.sessionScheduler.Add(s.NextRefresh(), toSessionSchedulerKey(session.GetUserId(), session.GetId()))
}
func (mgr *Manager) onUpdateUser(_ context.Context, record *databroker.Record, user *user.User) {
mgr.userScheduler.Remove(user.GetId())
if record.GetDeletedAt() != nil {
mgr.users.Delete(user.GetId())
return
}
u, _ := mgr.users.Get(user.GetId())
u.lastRefresh = mgr.cfg.Load().now()
u.User = user
mgr.users.ReplaceOrInsert(u)
mgr.userScheduler.Add(u.NextRefresh(), u.GetId())
}
func (mgr *Manager) deleteSession(ctx context.Context, userID, sessionID string) {
mgr.sessionScheduler.Remove(toSessionSchedulerKey(userID, sessionID))
mgr.sessions.Delete(userID, sessionID)
client := mgr.cfg.Load().dataBrokerClient
res, err := client.Get(ctx, &databroker.GetRequest{
Type: grpcutil.GetTypeURL(new(session.Session)), Type: grpcutil.GetTypeURL(new(session.Session)),
Id: sessionID, Id: sessionID,
}) })
if status.Code(err) == codes.NotFound { if status.Code(err) == codes.NotFound {
return return
} else if err != nil {
log.Error(ctx).Err(err).
Str("session_id", sessionID).
Msg("failed to delete session")
return
} }
record := res.GetRecord() record := res.GetRecord()
record.DeletedAt = timestamppb.Now() record.DeletedAt = timestamppb.Now()
_, err = client.Put(ctx, &databroker.PutRequest{ _, err = mgr.cfg.Load().dataBrokerClient.Put(ctx, &databroker.PutRequest{
Records: []*databroker.Record{record}, Records: []*databroker.Record{record},
}) })
if err != nil { if err != nil {
log.Error(ctx).Err(err). log.Ctx(ctx).Error().Err(err).
Str("session_id", sessionID). Str("session_id", sessionID).
Msg("failed to delete session") Msg("failed to delete session")
return return
} }
} }
// reset resets all the manager datastructures to their initial state func (mgr *Manager) updateSession(ctx context.Context, s *session.Session) {
func (mgr *Manager) reset() { log.Ctx(ctx).Debug().
mgr.sessions = sessionCollection{BTree: btree.New(8)} Str("user_id", s.GetUserId()).
mgr.users = userCollection{BTree: btree.New(8)} Str("session_id", s.GetId()).
Msg("updating session")
fm, err := fieldmaskpb.New(s, "oauth_token", "id_token", "claims")
if err != nil {
log.Error(ctx).Err(err).
Str("user_id", s.GetUserId()).
Str("session_id", s.GetId()).
Msg("failed to create fieldmask for session")
return
}
_, err = session.Patch(ctx, mgr.cfg.Load().dataBrokerClient, s, fm)
if err != nil {
log.Error(ctx).Err(err).
Str("user_id", s.GetUserId()).
Str("session_id", s.GetId()).
Msg("failed to patch updated session record")
return
}
mgr.mu.Lock()
mgr.dataStore.putSession(s)
if rss, ok := mgr.refreshSessionSchedulers[s.GetId()]; ok {
rss.Update(s)
}
mgr.mu.Unlock()
}
// updateUser updates the user in the databroker, the local data store, and resets the scheduler
func (mgr *Manager) updateUser(ctx context.Context, u *user.User) {
log.Ctx(ctx).Debug().
Str("user_id", u.GetId()).
Msg("updating user")
_, err := databroker.Put(ctx, mgr.cfg.Load().dataBrokerClient, u)
if err != nil {
log.Ctx(ctx).Error().
Str("user_id", u.GetId()).
Err(err).
Msg("failed to store updated user record")
return
}
mgr.mu.Lock()
mgr.dataStore.putUser(u)
if uuis, ok := mgr.updateUserInfoSchedulers[u.GetId()]; ok {
uuis.Reset()
}
mgr.mu.Unlock()
} }
func (mgr *Manager) recordLastError(id string, err error) { func (mgr *Manager) recordLastError(id string, err error) {
@ -473,17 +412,3 @@ func (mgr *Manager) recordLastError(id string, err error) {
Id: id, Id: id,
}) })
} }
func isTemporaryError(err error) bool {
if err == nil {
return false
}
if errors.Is(err, context.DeadlineExceeded) || errors.Is(err, context.Canceled) {
return true
}
var hasTemporary interface{ Temporary() bool }
if errors.As(err, &hasTemporary) && hasTemporary.Temporary() {
return true
}
return false
}

View file

@ -2,28 +2,10 @@ package manager
import ( import (
"context" "context"
"errors"
"fmt"
"testing"
"time"
"github.com/stretchr/testify/assert"
"go.uber.org/mock/gomock"
"golang.org/x/oauth2" "golang.org/x/oauth2"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/types/known/fieldmaskpb"
"google.golang.org/protobuf/types/known/timestamppb"
"github.com/pomerium/pomerium/internal/events"
"github.com/pomerium/pomerium/internal/identity/identity" "github.com/pomerium/pomerium/internal/identity/identity"
"github.com/pomerium/pomerium/pkg/grpc/databroker"
"github.com/pomerium/pomerium/pkg/grpc/databroker/mock_databroker"
"github.com/pomerium/pomerium/pkg/grpc/session"
"github.com/pomerium/pomerium/pkg/grpc/user"
metrics_ids "github.com/pomerium/pomerium/pkg/metrics"
"github.com/pomerium/pomerium/pkg/protoutil"
) )
type mockAuthenticator struct { type mockAuthenticator struct {
@ -44,317 +26,3 @@ func (mock *mockAuthenticator) Revoke(_ context.Context, _ *oauth2.Token) error
func (mock *mockAuthenticator) UpdateUserInfo(_ context.Context, _ *oauth2.Token, _ any) error { func (mock *mockAuthenticator) UpdateUserInfo(_ context.Context, _ *oauth2.Token, _ any) error {
return mock.updateUserInfoError return mock.updateUserInfoError
} }
func TestManager_refresh(t *testing.T) {
ctrl := gomock.NewController(t)
ctx, clearTimeout := context.WithTimeout(context.Background(), time.Second*10)
t.Cleanup(clearTimeout)
client := mock_databroker.NewMockDataBrokerServiceClient(ctrl)
mgr := New(WithDataBrokerClient(client))
mgr.onUpdateRecords(ctx, updateRecordsMessage{
records: []*databroker.Record{
databroker.NewRecord(&session.Session{
Id: "s1",
UserId: "u1",
OauthToken: &session.OAuthToken{},
ExpiresAt: timestamppb.New(time.Now().Add(time.Second * 10)),
}),
databroker.NewRecord(&user.User{
Id: "u1",
}),
},
})
client.EXPECT().Get(gomock.Any(), gomock.Any()).Return(nil, status.Error(codes.NotFound, "not found"))
mgr.refreshSession(ctx, "u1", "s1")
mgr.refreshUser(ctx, "u1")
}
func TestManager_onUpdateRecords(t *testing.T) {
ctrl := gomock.NewController(t)
ctx, clearTimeout := context.WithTimeout(context.Background(), time.Second*10)
defer clearTimeout()
now := time.Now()
mgr := New(
WithDataBrokerClient(mock_databroker.NewMockDataBrokerServiceClient(ctrl)),
WithNow(func() time.Time {
return now
}),
)
mgr.onUpdateRecords(ctx, updateRecordsMessage{
records: []*databroker.Record{
mkRecord(&session.Session{Id: "session1", UserId: "user1"}),
mkRecord(&user.User{Id: "user1", Name: "user 1", Email: "user1@example.com"}),
},
})
if _, ok := mgr.sessions.Get("user1", "session1"); assert.True(t, ok) {
tm, id := mgr.sessionScheduler.Next()
assert.Equal(t, now.Add(10*time.Second), tm)
assert.Equal(t, "user1\037session1", id)
}
if _, ok := mgr.users.Get("user1"); assert.True(t, ok) {
tm, id := mgr.userScheduler.Next()
assert.Equal(t, now.Add(userRefreshInterval), tm)
assert.Equal(t, "user1", id)
}
}
func TestManager_onUpdateSession(t *testing.T) {
startTime := time.Date(2023, 10, 19, 12, 0, 0, 0, time.UTC)
s := &session.Session{
Id: "session-id",
UserId: "user-id",
OauthToken: &session.OAuthToken{
AccessToken: "access-token",
ExpiresAt: timestamppb.New(startTime.Add(5 * time.Minute)),
},
IssuedAt: timestamppb.New(startTime),
ExpiresAt: timestamppb.New(startTime.Add(24 * time.Hour)),
}
assertNextScheduled := func(t *testing.T, mgr *Manager, expectedTime time.Time) {
t.Helper()
tm, key := mgr.sessionScheduler.Next()
assert.Equal(t, expectedTime, tm)
assert.Equal(t, "user-id\037session-id", key)
}
t.Run("initial refresh event when not expiring soon", func(t *testing.T) {
now := startTime
mgr := New(WithNow(func() time.Time { return now }))
// When the Manager first becomes aware of a session it should schedule
// a refresh event for one minute before access token expiration.
mgr.onUpdateSession(mkRecord(s), s)
assertNextScheduled(t, mgr, startTime.Add(4*time.Minute))
})
t.Run("initial refresh event when expiring soon", func(t *testing.T) {
now := startTime
mgr := New(WithNow(func() time.Time { return now }))
// When the Manager first becomes aware of a session, if that session
// is expiring within the gracePeriod (1 minute), it should schedule a
// refresh event for as soon as possible, subject to the
// coolOffDuration (10 seconds).
now = now.Add(4*time.Minute + 30*time.Second) // 30 s before expiration
mgr.onUpdateSession(mkRecord(s), s)
assertNextScheduled(t, mgr, now.Add(10*time.Second))
})
t.Run("update near scheduled refresh", func(t *testing.T) {
now := startTime
mgr := New(WithNow(func() time.Time { return now }))
mgr.onUpdateSession(mkRecord(s), s)
assertNextScheduled(t, mgr, startTime.Add(4*time.Minute))
// If a session is updated close to the time when it is scheduled to be
// refreshed, the scheduled refresh event should not be pushed back.
now = now.Add(3*time.Minute + 55*time.Second) // 5 s before refresh
mgr.onUpdateSession(mkRecord(s), s)
assertNextScheduled(t, mgr, now.Add(5*time.Second))
// However, if an update changes the access token validity, the refresh
// event should be rescheduled accordingly. (This should be uncommon,
// as only the refresh loop itself should modify the access token.)
s2 := proto.Clone(s).(*session.Session)
s2.OauthToken.ExpiresAt = timestamppb.New(now.Add(5 * time.Minute))
mgr.onUpdateSession(mkRecord(s2), s2)
assertNextScheduled(t, mgr, now.Add(4*time.Minute))
})
t.Run("session record deleted", func(t *testing.T) {
now := startTime
mgr := New(WithNow(func() time.Time { return now }))
mgr.onUpdateSession(mkRecord(s), s)
assertNextScheduled(t, mgr, startTime.Add(4*time.Minute))
// If a session is deleted, any scheduled refresh event should be canceled.
record := mkRecord(s)
record.DeletedAt = timestamppb.New(now)
mgr.onUpdateSession(record, s)
_, key := mgr.sessionScheduler.Next()
assert.Empty(t, key)
})
}
func TestManager_refreshSession(t *testing.T) {
startTime := time.Date(2023, 10, 19, 12, 0, 0, 0, time.UTC)
var auth mockAuthenticator
ctrl := gomock.NewController(t)
client := mock_databroker.NewMockDataBrokerServiceClient(ctrl)
now := startTime
mgr := New(
WithDataBrokerClient(client),
WithNow(func() time.Time { return now }),
WithAuthenticator(&auth),
)
// Initialize the Manager with a new session.
s := &session.Session{
Id: "session-id",
UserId: "user-id",
OauthToken: &session.OAuthToken{
AccessToken: "access-token",
ExpiresAt: timestamppb.New(startTime.Add(5 * time.Minute)),
RefreshToken: "refresh-token",
},
IssuedAt: timestamppb.New(startTime),
ExpiresAt: timestamppb.New(startTime.Add(24 * time.Hour)),
}
mgr.sessions.ReplaceOrInsert(Session{
Session: s,
lastRefresh: startTime,
gracePeriod: time.Minute,
coolOffDuration: 10 * time.Second,
})
// If OAuth2 token refresh fails with a temporary error, the manager should
// still reschedule another refresh attempt.
now = now.Add(4 * time.Minute)
auth.refreshError = context.DeadlineExceeded
mgr.refreshSession(context.Background(), "user-id", "session-id")
tm, key := mgr.sessionScheduler.Next()
assert.Equal(t, now.Add(10*time.Second), tm)
assert.Equal(t, "user-id\037session-id", key)
// Simulate a successful token refresh on the second attempt. The manager
// should store the updated session in the databroker and schedule another
// refresh event.
now = now.Add(10 * time.Second)
auth.refreshResult, auth.refreshError = &oauth2.Token{
AccessToken: "new-access-token",
RefreshToken: "new-refresh-token",
Expiry: now.Add(5 * time.Minute),
}, nil
expectedSession := proto.Clone(s).(*session.Session)
expectedSession.OauthToken = &session.OAuthToken{
AccessToken: "new-access-token",
ExpiresAt: timestamppb.New(now.Add(5 * time.Minute)),
RefreshToken: "new-refresh-token",
}
client.EXPECT().Patch(gomock.Any(), objectsAreEqualMatcher{
&databroker.PatchRequest{
Records: []*databroker.Record{{
Type: "type.googleapis.com/session.Session",
Id: "session-id",
Data: protoutil.NewAny(expectedSession),
}},
FieldMask: &fieldmaskpb.FieldMask{
Paths: []string{"oauth_token", "id_token", "claims"},
},
},
}).
Return(nil /* this result is currently unused */, nil)
mgr.refreshSession(context.Background(), "user-id", "session-id")
tm, key = mgr.sessionScheduler.Next()
assert.Equal(t, now.Add(4*time.Minute), tm)
assert.Equal(t, "user-id\037session-id", key)
}
func TestManager_reportErrors(t *testing.T) {
ctrl := gomock.NewController(t)
ctx, clearTimeout := context.WithTimeout(context.Background(), time.Second*10)
defer clearTimeout()
evtMgr := events.New()
received := make(chan events.Event, 1)
handle := evtMgr.Register(func(evt events.Event) {
received <- evt
})
defer evtMgr.Unregister(handle)
expectMsg := func(id, msg string) {
t.Helper()
assert.Eventually(t, func() bool {
select {
case evt := <-received:
lastErr := evt.(*events.LastError)
return msg == lastErr.Message && id == lastErr.Id
default:
return false
}
}, time.Second, time.Millisecond*20, msg)
}
s := &session.Session{
Id: "session1",
UserId: "user1",
OauthToken: &session.OAuthToken{
ExpiresAt: timestamppb.New(time.Now().Add(time.Hour)),
},
ExpiresAt: timestamppb.New(time.Now().Add(time.Hour)),
}
client := mock_databroker.NewMockDataBrokerServiceClient(ctrl)
client.EXPECT().Get(gomock.Any(), gomock.Any()).AnyTimes().Return(&databroker.GetResponse{Record: databroker.NewRecord(s)}, nil)
client.EXPECT().Put(gomock.Any(), gomock.Any()).AnyTimes()
mgr := New(
WithEventManager(evtMgr),
WithDataBrokerClient(client),
WithAuthenticator(&mockAuthenticator{
refreshError: errors.New("update session"),
updateUserInfoError: errors.New("update user info"),
}),
)
mgr.onUpdateRecords(ctx, updateRecordsMessage{
records: []*databroker.Record{
mkRecord(s),
mkRecord(&user.User{Id: "user1", Name: "user 1", Email: "user1@example.com"}),
},
})
mgr.refreshUser(ctx, "user1")
expectMsg(metrics_ids.IdentityManagerLastUserRefreshError, "update user info")
mgr.onUpdateRecords(ctx, updateRecordsMessage{
records: []*databroker.Record{
mkRecord(s),
mkRecord(&user.User{Id: "user1", Name: "user 1", Email: "user1@example.com"}),
},
})
mgr.refreshSession(ctx, "user1", "session1")
expectMsg(metrics_ids.IdentityManagerLastSessionRefreshError, "update session")
}
func mkRecord(msg recordable) *databroker.Record {
data := protoutil.NewAny(msg)
return &databroker.Record{
Type: data.GetTypeUrl(),
Id: msg.GetId(),
Data: data,
}
}
type recordable interface {
proto.Message
GetId() string
}
// objectsAreEqualMatcher implements gomock.Matcher using ObjectsAreEqual. This
// is especially helpful when working with pointers, as it will compare the
// underlying values rather than the pointers themselves.
type objectsAreEqualMatcher struct {
expected interface{}
}
func (m objectsAreEqualMatcher) Matches(x interface{}) bool {
return assert.ObjectsAreEqual(m.expected, x)
}
func (m objectsAreEqualMatcher) String() string {
return fmt.Sprintf("is equal to %v (%T)", m.expected, m.expected)
}

View file

@ -1,7 +1,8 @@
package manager package manager
import ( import (
"strings" "context"
"errors"
"golang.org/x/oauth2" "golang.org/x/oauth2"
"google.golang.org/protobuf/types/known/timestamppb" "google.golang.org/protobuf/types/known/timestamppb"
@ -9,21 +10,6 @@ import (
"github.com/pomerium/pomerium/pkg/grpc/session" "github.com/pomerium/pomerium/pkg/grpc/session"
) )
func toSessionSchedulerKey(userID, sessionID string) string {
return userID + "\037" + sessionID
}
func fromSessionSchedulerKey(key string) (userID, sessionID string) {
idx := strings.Index(key, "\037")
if idx >= 0 {
userID = key[:idx]
sessionID = key[idx+1:]
} else {
userID = key
}
return userID, sessionID
}
// FromOAuthToken converts a session oauth token to oauth2.Token. // FromOAuthToken converts a session oauth token to oauth2.Token.
func FromOAuthToken(token *session.OAuthToken) *oauth2.Token { func FromOAuthToken(token *session.OAuthToken) *oauth2.Token {
return &oauth2.Token{ return &oauth2.Token{
@ -44,3 +30,17 @@ func ToOAuthToken(token *oauth2.Token) *session.OAuthToken {
ExpiresAt: expiry, ExpiresAt: expiry,
} }
} }
func isTemporaryError(err error) bool {
if err == nil {
return false
}
if errors.Is(err, context.DeadlineExceeded) || errors.Is(err, context.Canceled) {
return true
}
var hasTemporary interface{ Temporary() bool }
if errors.As(err, &hasTemporary) && hasTemporary.Temporary() {
return true
}
return false
}

View file

@ -0,0 +1,138 @@
package manager
import (
"context"
"sync/atomic"
"time"
"github.com/pomerium/pomerium/pkg/grpc/session"
)
type updateUserInfoScheduler struct {
mgr *Manager
userID string
reset chan struct{}
cancel context.CancelFunc
}
func newUpdateUserInfoScheduler(ctx context.Context, mgr *Manager, userID string) *updateUserInfoScheduler {
uuis := &updateUserInfoScheduler{
mgr: mgr,
userID: userID,
reset: make(chan struct{}, 1),
}
ctx = context.WithoutCancel(ctx)
ctx, uuis.cancel = context.WithCancel(ctx)
go uuis.run(ctx)
return uuis
}
func (uuis *updateUserInfoScheduler) Reset() {
select {
case uuis.reset <- struct{}{}:
default:
}
}
func (uuis *updateUserInfoScheduler) Stop() {
uuis.cancel()
}
func (uuis *updateUserInfoScheduler) run(ctx context.Context) {
ticker := time.NewTicker(uuis.mgr.cfg.Load().updateUserInfoInterval)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return
case <-uuis.reset:
ticker.Reset(uuis.mgr.cfg.Load().updateUserInfoInterval)
case <-ticker.C:
uuis.mgr.updateUserInfo(ctx, uuis.userID)
}
}
}
type refreshSessionScheduler struct {
mgr *Manager
sessionID string
lastRefresh atomic.Pointer[time.Time]
next chan time.Time
cancel context.CancelFunc
}
func newRefreshSessionScheduler(
ctx context.Context,
mgr *Manager,
sessionID string,
) *refreshSessionScheduler {
rss := &refreshSessionScheduler{
mgr: mgr,
sessionID: sessionID,
next: make(chan time.Time, 1),
}
now := rss.mgr.cfg.Load().now()
rss.lastRefresh.Store(&now)
ctx = context.WithoutCancel(ctx)
ctx, rss.cancel = context.WithCancel(ctx)
go rss.run(ctx)
return rss
}
func (rss *refreshSessionScheduler) Update(s *session.Session) {
due := nextSessionRefresh(
s,
*rss.lastRefresh.Load(),
rss.mgr.cfg.Load().sessionRefreshGracePeriod,
rss.mgr.cfg.Load().sessionRefreshCoolOffDuration,
)
for {
select {
case <-rss.next:
default:
}
select {
case rss.next <- due:
return
default:
}
}
}
func (rss *refreshSessionScheduler) Stop() {
rss.cancel()
}
func (rss *refreshSessionScheduler) run(ctx context.Context) {
var timer *time.Timer
// wait for the first update
select {
case <-ctx.Done():
return
case due := <-rss.next:
delay := max(time.Until(due), 0)
timer = time.NewTimer(delay)
defer timer.Stop()
}
// wait for updates or for the timer to trigger
for {
select {
case <-ctx.Done():
return
case due := <-rss.next:
delay := max(time.Until(due), 0)
// stop the current timer and reset it
if !timer.Stop() {
<-timer.C
}
timer.Reset(delay)
case <-timer.C:
now := rss.mgr.cfg.Load().now()
rss.lastRefresh.Store(&now)
rss.mgr.refreshSession(ctx, rss.sessionID)
}
}
}

View file

@ -3,53 +3,75 @@ package manager
import ( import (
"context" "context"
"github.com/pomerium/pomerium/internal/atomicutil" "github.com/pomerium/pomerium/internal/log"
"github.com/pomerium/pomerium/pkg/grpc/databroker" "github.com/pomerium/pomerium/pkg/grpc/databroker"
"github.com/pomerium/pomerium/pkg/grpc/session"
"github.com/pomerium/pomerium/pkg/grpc/user"
"github.com/pomerium/pomerium/pkg/grpcutil"
) )
type dataBrokerSyncer struct { type sessionSyncerHandler struct {
cfg *atomicutil.Value[*config] mgr *Manager
update chan<- updateRecordsMessage
clear chan<- struct{}
syncer *databroker.Syncer
} }
func newDataBrokerSyncer( func newSessionSyncer(mgr *Manager) *databroker.Syncer {
_ context.Context, return databroker.NewSyncer("identity_manager/sessions", sessionSyncerHandler{mgr: mgr},
cfg *atomicutil.Value[*config], databroker.WithTypeURL(grpcutil.GetTypeURL(new(session.Session))))
update chan<- updateRecordsMessage,
clear chan<- struct{},
) *dataBrokerSyncer {
syncer := &dataBrokerSyncer{
cfg: cfg,
update: update,
clear: clear,
}
syncer.syncer = databroker.NewSyncer("identity_manager", syncer)
return syncer
} }
func (syncer *dataBrokerSyncer) Run(ctx context.Context) (err error) { func (h sessionSyncerHandler) ClearRecords(ctx context.Context) {
return syncer.syncer.Run(ctx) h.mgr.onDeleteAllSessions(ctx)
} }
func (syncer *dataBrokerSyncer) ClearRecords(ctx context.Context) { func (h sessionSyncerHandler) GetDataBrokerServiceClient() databroker.DataBrokerServiceClient {
select { return h.mgr.cfg.Load().dataBrokerClient
case <-ctx.Done(): }
case syncer.clear <- struct{}{}:
func (h sessionSyncerHandler) UpdateRecords(ctx context.Context, _ uint64, records []*databroker.Record) {
for _, record := range records {
if record.GetDeletedAt() != nil {
h.mgr.onDeleteSession(ctx, record.GetId())
} else {
var s session.Session
err := record.Data.UnmarshalTo(&s)
if err != nil {
log.Ctx(ctx).Warn().Err(err).Msg("invalid data in session record, ignoring")
} else {
h.mgr.onUpdateSession(ctx, &s)
}
}
} }
} }
func (syncer *dataBrokerSyncer) GetDataBrokerServiceClient() databroker.DataBrokerServiceClient { type userSyncerHandler struct {
return syncer.cfg.Load().dataBrokerClient mgr *Manager
} }
func (syncer *dataBrokerSyncer) UpdateRecords(ctx context.Context, _ uint64, records []*databroker.Record) { func newUserSyncer(mgr *Manager) *databroker.Syncer {
select { return databroker.NewSyncer("identity_manager/users", userSyncerHandler{mgr: mgr},
case <-ctx.Done(): databroker.WithTypeURL(grpcutil.GetTypeURL(new(user.User))))
case syncer.update <- updateRecordsMessage{records: records}: }
func (h userSyncerHandler) ClearRecords(ctx context.Context) {
h.mgr.onDeleteAllUsers(ctx)
}
func (h userSyncerHandler) GetDataBrokerServiceClient() databroker.DataBrokerServiceClient {
return h.mgr.cfg.Load().dataBrokerClient
}
func (h userSyncerHandler) UpdateRecords(ctx context.Context, _ uint64, records []*databroker.Record) {
for _, record := range records {
if record.GetDeletedAt() != nil {
h.mgr.onDeleteUser(ctx, record.GetId())
} else {
var u user.User
err := record.Data.UnmarshalTo(&u)
if err != nil {
log.Ctx(ctx).Warn().Err(err).Msg("invalid data in user record, ignoring")
} else {
h.mgr.onUpdateUser(ctx, &u)
}
}
} }
} }