From a80ef11763cbcd0513c12b8bea4be4bbc9abeb5b Mon Sep 17 00:00:00 2001 From: Caleb Doxsey Date: Fri, 19 Apr 2024 15:23:55 -0600 Subject: [PATCH] wip --- internal/authenticateflow/stateful.go | 11 +- internal/identity/manager/config.go | 10 + internal/identity/manager/data.go | 245 +++------- internal/identity/manager/data_test.go | 40 +- internal/identity/manager/datastore.go | 86 ++-- internal/identity/manager/datastore_test.go | 36 +- internal/identity/manager/manager.go | 503 +++++++++----------- internal/identity/manager/manager_test.go | 332 ------------- internal/identity/manager/misc.go | 32 +- internal/identity/manager/schedulers.go | 138 ++++++ internal/identity/manager/sync.go | 90 ++-- 11 files changed, 574 insertions(+), 949 deletions(-) create mode 100644 internal/identity/manager/schedulers.go diff --git a/internal/authenticateflow/stateful.go b/internal/authenticateflow/stateful.go index 64030e4f7..1c671da73 100644 --- a/internal/authenticateflow/stateful.go +++ b/internal/authenticateflow/stateful.go @@ -196,16 +196,15 @@ func (s *Stateful) PersistSession( sess.SetRawIDToken(claims.RawIDToken) sess.AddClaims(claims.Flatten()) - var managerUser manager.User - managerUser.User, _ = user.Get(ctx, s.dataBrokerClient, sess.GetUserId()) - if managerUser.User == nil { + u, _ := user.Get(ctx, s.dataBrokerClient, sess.GetUserId()) + if u == nil { // if no user exists yet, create a new one - managerUser.User = &user.User{ + u = &user.User{ Id: sess.GetUserId(), } } - populateUserFromClaims(managerUser.User, claims.Claims) - _, err := databroker.Put(ctx, s.dataBrokerClient, managerUser.User) + populateUserFromClaims(u, claims.Claims) + _, err := databroker.Put(ctx, s.dataBrokerClient, u) if err != nil { return fmt.Errorf("authenticate: error saving user: %w", err) } diff --git a/internal/identity/manager/config.go b/internal/identity/manager/config.go index fbc8f811b..1e41e767f 100644 --- a/internal/identity/manager/config.go +++ b/internal/identity/manager/config.go @@ -10,6 +10,7 @@ import ( var ( defaultSessionRefreshGracePeriod = 1 * time.Minute defaultSessionRefreshCoolOffDuration = 10 * time.Second + defaultUpdateUserInfoInterval = 10 * time.Minute ) type config struct { @@ -17,6 +18,7 @@ type config struct { dataBrokerClient databroker.DataBrokerServiceClient sessionRefreshGracePeriod time.Duration sessionRefreshCoolOffDuration time.Duration + updateUserInfoInterval time.Duration now func() time.Time eventMgr *events.Manager } @@ -26,6 +28,7 @@ func newConfig(options ...Option) *config { WithSessionRefreshGracePeriod(defaultSessionRefreshGracePeriod)(cfg) WithSessionRefreshCoolOffDuration(defaultSessionRefreshCoolOffDuration)(cfg) WithNow(time.Now)(cfg) + WithUpdateUserInfoInterval(defaultUpdateUserInfoInterval)(cfg) for _, option := range options { option(cfg) } @@ -76,3 +79,10 @@ func WithEventManager(mgr *events.Manager) Option { c.eventMgr = mgr } } + +// WithUpdateUserInfoInterval sets the update user info interval in the config. +func WithUpdateUserInfoInterval(dur time.Duration) Option { + return func(cfg *config) { + cfg.updateUserInfoInterval = dur + } +} diff --git a/internal/identity/manager/data.go b/internal/identity/manager/data.go index 2c8637102..35e4c2bc7 100644 --- a/internal/identity/manager/data.go +++ b/internal/identity/manager/data.go @@ -2,9 +2,9 @@ package manager import ( "encoding/json" + "errors" "time" - "github.com/google/btree" "google.golang.org/protobuf/types/known/timestamppb" "github.com/pomerium/pomerium/internal/identity" @@ -12,66 +12,18 @@ import ( "github.com/pomerium/pomerium/pkg/grpc/user" ) -const userRefreshInterval = 10 * time.Minute - -// A User is a user managed by the Manager. -type User struct { - *user.User - lastRefresh time.Time -} - -// NextRefresh returns the next time the user information needs to be refreshed. -func (u User) NextRefresh() time.Time { - return u.lastRefresh.Add(userRefreshInterval) -} - -// UnmarshalJSON unmarshals json data into the user object. -func (u *User) UnmarshalJSON(data []byte) error { - if u.User == nil { - u.User = new(user.User) - } - - var raw map[string]json.RawMessage - err := json.Unmarshal(data, &raw) - if err != nil { - return err - } - - if name, ok := raw["name"]; ok { - _ = json.Unmarshal(name, &u.User.Name) - delete(raw, "name") - } - if email, ok := raw["email"]; ok { - _ = json.Unmarshal(email, &u.User.Email) - delete(raw, "email") - } - - u.AddClaims(identity.NewClaimsFromRaw(raw).Flatten()) - - return nil -} - -// A Session is a session managed by the Manager. -type Session struct { - *session.Session - // lastRefresh is the time of the last refresh attempt (which may or may - // not have succeeded), or else the time the Manager first became aware of - // the session (if it has not yet attempted to refresh this session). - lastRefresh time.Time - // gracePeriod is the amount of time before expiration to attempt a refresh. - gracePeriod time.Duration - // coolOffDuration is the amount of time to wait before attempting another refresh. - coolOffDuration time.Duration -} - -// NextRefresh returns the next time the session needs to be refreshed. -func (s Session) NextRefresh() time.Time { +func nextSessionRefresh( + s *session.Session, + lastRefresh time.Time, + gracePeriod time.Duration, + coolOffDuration time.Duration, +) time.Time { var tm time.Time if s.GetOauthToken().GetExpiresAt() != nil { expiry := s.GetOauthToken().GetExpiresAt().AsTime() if s.GetOauthToken().GetExpiresAt().IsValid() && !expiry.IsZero() { - expiry = expiry.Add(-s.gracePeriod) + expiry = expiry.Add(-gracePeriod) if tm.IsZero() || expiry.Before(tm) { tm = expiry } @@ -88,7 +40,7 @@ func (s Session) NextRefresh() time.Time { } // don't refresh any quicker than the cool-off duration - min := s.lastRefresh.Add(s.coolOffDuration) + min := lastRefresh.Add(coolOffDuration) if tm.Before(min) { tm = min } @@ -96,10 +48,33 @@ func (s Session) NextRefresh() time.Time { return tm } -// UnmarshalJSON unmarshals json data into the session object. -func (s *Session) UnmarshalJSON(data []byte) error { - if s.Session == nil { - s.Session = new(session.Session) +type multiUnmarshaler []any + +func newMultiUnmarshaler(args ...any) *multiUnmarshaler { + return (*multiUnmarshaler)(&args) +} + +func (dst *multiUnmarshaler) UnmarshalJSON(data []byte) error { + var err error + for _, o := range *dst { + if o != nil { + err = errors.Join(err, json.Unmarshal(data, o)) + } + } + return err +} + +type sessionUnmarshaler struct { + *session.Session +} + +func newSessionUnmarshaler(s *session.Session) *sessionUnmarshaler { + return &sessionUnmarshaler{Session: s} +} + +func (dst *sessionUnmarshaler) UnmarshalJSON(data []byte) error { + if dst.Session == nil { + return nil } var raw map[string]json.RawMessage @@ -108,159 +83,67 @@ func (s *Session) UnmarshalJSON(data []byte) error { return err } - if s.Session.IdToken == nil { - s.Session.IdToken = new(session.IDToken) + if dst.Session.IdToken == nil { + dst.Session.IdToken = new(session.IDToken) } if iss, ok := raw["iss"]; ok { - _ = json.Unmarshal(iss, &s.Session.IdToken.Issuer) + _ = json.Unmarshal(iss, &dst.Session.IdToken.Issuer) delete(raw, "iss") } if sub, ok := raw["sub"]; ok { - _ = json.Unmarshal(sub, &s.Session.IdToken.Subject) + _ = json.Unmarshal(sub, &dst.Session.IdToken.Subject) delete(raw, "sub") } if exp, ok := raw["exp"]; ok { var secs int64 if err := json.Unmarshal(exp, &secs); err == nil { - s.Session.IdToken.ExpiresAt = timestamppb.New(time.Unix(secs, 0)) + dst.Session.IdToken.ExpiresAt = timestamppb.New(time.Unix(secs, 0)) } delete(raw, "exp") } if iat, ok := raw["iat"]; ok { var secs int64 if err := json.Unmarshal(iat, &secs); err == nil { - s.Session.IdToken.IssuedAt = timestamppb.New(time.Unix(secs, 0)) + dst.Session.IdToken.IssuedAt = timestamppb.New(time.Unix(secs, 0)) } delete(raw, "iat") } - s.AddClaims(identity.NewClaimsFromRaw(raw).Flatten()) + dst.Session.AddClaims(identity.NewClaimsFromRaw(raw).Flatten()) return nil } -type sessionCollectionItem struct { - Session +type userUnmarshaler struct { + *user.User } -func (item sessionCollectionItem) Less(than btree.Item) bool { - xUserID, yUserID := item.GetUserId(), than.(sessionCollectionItem).GetUserId() - switch { - case xUserID < yUserID: - return true - case yUserID < xUserID: - return false +func newUserUnmarshaler(u *user.User) *userUnmarshaler { + return &userUnmarshaler{User: u} +} + +func (dst *userUnmarshaler) UnmarshalJSON(data []byte) error { + if dst.User == nil { + return nil } - xID, yID := item.GetId(), than.(sessionCollectionItem).GetId() - switch { - case xID < yID: - return true - case yID < xID: - return false + var raw map[string]json.RawMessage + err := json.Unmarshal(data, &raw) + if err != nil { + return err } - return false -} -type sessionCollection struct { - *btree.BTree -} - -func (c *sessionCollection) Delete(userID, sessionID string) { - c.BTree.Delete(sessionCollectionItem{ - Session: Session{ - Session: &session.Session{ - UserId: userID, - Id: sessionID, - }, - }, - }) -} - -func (c *sessionCollection) Get(userID, sessionID string) (Session, bool) { - item := c.BTree.Get(sessionCollectionItem{ - Session: Session{ - Session: &session.Session{ - UserId: userID, - Id: sessionID, - }, - }, - }) - if item == nil { - return Session{}, false + if name, ok := raw["name"]; ok { + _ = json.Unmarshal(name, &dst.User.Name) + delete(raw, "name") } - return item.(sessionCollectionItem).Session, true -} - -// GetSessionsForUser gets all the sessions for the given user. -func (c *sessionCollection) GetSessionsForUser(userID string) []Session { - var sessions []Session - c.AscendGreaterOrEqual(sessionCollectionItem{ - Session: Session{ - Session: &session.Session{ - UserId: userID, - }, - }, - }, func(item btree.Item) bool { - s := item.(sessionCollectionItem).Session - if s.UserId != userID { - return false - } - - sessions = append(sessions, s) - return true - }) - return sessions -} - -func (c *sessionCollection) ReplaceOrInsert(s Session) { - c.BTree.ReplaceOrInsert(sessionCollectionItem{Session: s}) -} - -type userCollectionItem struct { - User -} - -func (item userCollectionItem) Less(than btree.Item) bool { - xID, yID := item.GetId(), than.(userCollectionItem).GetId() - switch { - case xID < yID: - return true - case yID < xID: - return false + if email, ok := raw["email"]; ok { + _ = json.Unmarshal(email, &dst.User.Email) + delete(raw, "email") } - return false -} -type userCollection struct { - *btree.BTree -} + dst.User.AddClaims(identity.NewClaimsFromRaw(raw).Flatten()) -func (c *userCollection) Delete(userID string) { - c.BTree.Delete(userCollectionItem{ - User: User{ - User: &user.User{ - Id: userID, - }, - }, - }) -} - -func (c *userCollection) Get(userID string) (User, bool) { - item := c.BTree.Get(userCollectionItem{ - User: User{ - User: &user.User{ - Id: userID, - }, - }, - }) - if item == nil { - return User{}, false - } - return item.(userCollectionItem).User, true -} - -func (c *userCollection) ReplaceOrInsert(u User) { - c.BTree.ReplaceOrInsert(userCollectionItem{User: u}) + return nil } diff --git a/internal/identity/manager/data_test.go b/internal/identity/manager/data_test.go index 675cf091e..3432589d2 100644 --- a/internal/identity/manager/data_test.go +++ b/internal/identity/manager/data_test.go @@ -11,20 +11,20 @@ import ( "google.golang.org/protobuf/types/known/timestamppb" "github.com/pomerium/pomerium/pkg/grpc/session" + "github.com/pomerium/pomerium/pkg/grpc/user" "github.com/pomerium/pomerium/pkg/protoutil" ) func TestUser_UnmarshalJSON(t *testing.T) { - var u User + u := new(user.User) err := json.Unmarshal([]byte(`{ "name": "joe", "email": "joe@test.com", "some-other-claim": "xyz" - }`), &u) + }`), newUserUnmarshaler(u)) assert.NoError(t, err) - assert.NotNil(t, u.User) - assert.Equal(t, "joe", u.User.Name) - assert.Equal(t, "joe@test.com", u.User.Email) + assert.Equal(t, "joe", u.Name) + assert.Equal(t, "joe@test.com", u.Email) assert.Equal(t, map[string]*structpb.ListValue{ "some-other-claim": {Values: []*structpb.Value{protoutil.ToStruct("xyz")}}, }, u.Claims) @@ -32,42 +32,38 @@ func TestUser_UnmarshalJSON(t *testing.T) { func TestSession_NextRefresh(t *testing.T) { tm1 := time.Date(2020, 6, 5, 12, 0, 0, 0, time.UTC) - s := Session{ - Session: &session.Session{}, - lastRefresh: tm1, - gracePeriod: time.Second * 10, - coolOffDuration: time.Minute, - } - assert.Equal(t, tm1.Add(time.Minute), s.NextRefresh()) + s := &session.Session{} + gracePeriod := time.Second * 10 + coolOffDuration := time.Minute + assert.Equal(t, tm1.Add(time.Minute), nextSessionRefresh(s, tm1, gracePeriod, coolOffDuration)) tm2 := time.Date(2020, 6, 5, 13, 0, 0, 0, time.UTC) s.OauthToken = &session.OAuthToken{ ExpiresAt: timestamppb.New(tm2), } - assert.Equal(t, tm2.Add(-time.Second*10), s.NextRefresh()) + assert.Equal(t, tm2.Add(-time.Second*10), nextSessionRefresh(s, tm1, gracePeriod, coolOffDuration)) tm3 := time.Date(2020, 6, 5, 12, 15, 0, 0, time.UTC) s.ExpiresAt = timestamppb.New(tm3) - assert.Equal(t, tm3, s.NextRefresh()) + assert.Equal(t, tm3, nextSessionRefresh(s, tm1, gracePeriod, coolOffDuration)) } func TestSession_UnmarshalJSON(t *testing.T) { tm := time.Date(2020, 6, 5, 12, 0, 0, 0, time.UTC) - var s Session + s := new(session.Session) err := json.Unmarshal([]byte(`{ "iss": "https://some.issuer.com", "sub": "subject", "exp": `+fmt.Sprint(tm.Unix())+`, "iat": `+fmt.Sprint(tm.Unix())+`, "some-other-claim": "xyz" - }`), &s) + }`), newSessionUnmarshaler(s)) assert.NoError(t, err) - assert.NotNil(t, s.Session) - assert.NotNil(t, s.Session.IdToken) - assert.Equal(t, "https://some.issuer.com", s.Session.IdToken.Issuer) - assert.Equal(t, "subject", s.Session.IdToken.Subject) - assert.Equal(t, timestamppb.New(tm), s.Session.IdToken.ExpiresAt) - assert.Equal(t, timestamppb.New(tm), s.Session.IdToken.IssuedAt) + assert.NotNil(t, s.IdToken) + assert.Equal(t, "https://some.issuer.com", s.IdToken.Issuer) + assert.Equal(t, "subject", s.IdToken.Subject) + assert.Equal(t, timestamppb.New(tm), s.IdToken.ExpiresAt) + assert.Equal(t, timestamppb.New(tm), s.IdToken.IssuedAt) assert.Equal(t, map[string]*structpb.ListValue{ "some-other-claim": {Values: []*structpb.Value{protoutil.ToStruct("xyz")}}, }, s.Claims) diff --git a/internal/identity/manager/datastore.go b/internal/identity/manager/datastore.go index a8aeea78a..4d981ea44 100644 --- a/internal/identity/manager/datastore.go +++ b/internal/identity/manager/datastore.go @@ -3,7 +3,6 @@ package manager import ( "cmp" "slices" - "sync" "google.golang.org/protobuf/proto" @@ -11,44 +10,54 @@ import ( "github.com/pomerium/pomerium/pkg/grpc/user" ) -// dataStore stores session and user data. All public methods are thread-safe. +// dataStore stores session and user data type dataStore struct { - mu sync.Mutex sessions map[string]*session.Session users map[string]*user.User userIDToSessionIDs map[string]map[string]struct{} } func newDataStore() *dataStore { - return &dataStore{ - sessions: make(map[string]*session.Session), - users: make(map[string]*user.User), - userIDToSessionIDs: make(map[string]map[string]struct{}), + ds := new(dataStore) + ds.deleteAllSessions() + ds.deleteAllUsers() + return ds +} + +func (ds *dataStore) deleteAllSessions() { + ds.sessions = make(map[string]*session.Session) + ds.userIDToSessionIDs = make(map[string]map[string]struct{}) +} + +func (ds *dataStore) deleteAllUsers() { + ds.users = make(map[string]*user.User) +} + +func (ds *dataStore) deleteSession(sessionID string) { + s := ds.sessions[sessionID] + delete(ds.sessions, sessionID) + if s.GetUserId() == "" { + return + } + + m := ds.userIDToSessionIDs[s.GetUserId()] + if m != nil { + delete(m, s.GetId()) + } + if len(m) == 0 { + delete(ds.userIDToSessionIDs, s.GetUserId()) } } -// DeleteSession deletes a session. -func (ds *dataStore) DeleteSession(sID string) { - ds.mu.Lock() - ds.deleteSessionLocked(sID) - ds.mu.Unlock() -} - -// DeleteUser deletes a user. -func (ds *dataStore) DeleteUser(userID string) { - ds.mu.Lock() +func (ds *dataStore) deleteUser(userID string) { delete(ds.users, userID) - ds.mu.Unlock() } -// GetSessionAndUser gets a session and its associated user. -func (ds *dataStore) GetSessionAndUser(sessionID string) (s *session.Session, u *user.User) { - ds.mu.Lock() +func (ds *dataStore) getSessionAndUser(sessionID string) (s *session.Session, u *user.User) { s = ds.sessions[sessionID] if s.GetUserId() != "" { u = ds.users[s.GetUserId()] } - ds.mu.Unlock() // clone to avoid sharing memory s = clone(s) @@ -56,14 +65,11 @@ func (ds *dataStore) GetSessionAndUser(sessionID string) (s *session.Session, u return s, u } -// GetUserAndSessions gets a user and all of its associated sessions. -func (ds *dataStore) GetUserAndSessions(userID string) (u *user.User, ss []*session.Session) { - ds.mu.Lock() +func (ds *dataStore) getUserAndSessions(userID string) (u *user.User, ss []*session.Session) { u = ds.users[userID] for sessionID := range ds.userIDToSessionIDs[userID] { ss = append(ss, ds.sessions[sessionID]) } - ds.mu.Unlock() // remove nils and sort by id ss = slices.Compact(ss) @@ -79,14 +85,12 @@ func (ds *dataStore) GetUserAndSessions(userID string) (u *user.User, ss []*sess return u, ss } -// PutSession stores the session. -func (ds *dataStore) PutSession(s *session.Session) { +func (ds *dataStore) putSession(s *session.Session) { // clone to avoid sharing memory s = clone(s) - ds.mu.Lock() if s.GetId() != "" { - ds.deleteSessionLocked(s.GetId()) + ds.deleteSession(s.GetId()) ds.sessions[s.GetId()] = s if s.GetUserId() != "" { m, ok := ds.userIDToSessionIDs[s.GetUserId()] @@ -97,35 +101,15 @@ func (ds *dataStore) PutSession(s *session.Session) { m[s.GetId()] = struct{}{} } } - ds.mu.Unlock() } -// PutUser stores the user. -func (ds *dataStore) PutUser(u *user.User) { +func (ds *dataStore) putUser(u *user.User) { // clone to avoid sharing memory u = clone(u) - ds.mu.Lock() if u.GetId() != "" { ds.users[u.GetId()] = u } - ds.mu.Unlock() -} - -func (ds *dataStore) deleteSessionLocked(sID string) { - s := ds.sessions[sID] - delete(ds.sessions, sID) - if s.GetUserId() == "" { - return - } - - m := ds.userIDToSessionIDs[s.GetUserId()] - if m != nil { - delete(m, s.GetId()) - } - if len(m) == 0 { - delete(ds.userIDToSessionIDs, s.GetUserId()) - } } // clone clones a protobuf message diff --git a/internal/identity/manager/datastore_test.go b/internal/identity/manager/datastore_test.go index 55592aa43..c75d26b2c 100644 --- a/internal/identity/manager/datastore_test.go +++ b/internal/identity/manager/datastore_test.go @@ -15,32 +15,32 @@ func TestDataStore(t *testing.T) { t.Parallel() ds := newDataStore() - s, u := ds.GetSessionAndUser("S1") + s, u := ds.getSessionAndUser("S1") assert.Nil(t, s, "should return a nil session when none exists") assert.Nil(t, u, "should return a nil user when none exists") - u, ss := ds.GetUserAndSessions("U1") + u, ss := ds.getUserAndSessions("U1") assert.Nil(t, u, "should return a nil user when none exists") assert.Empty(t, ss, "should return an empty list of sessions when no user exists") s = &session.Session{Id: "S1", UserId: "U1"} - ds.PutSession(s) - s1, u1 := ds.GetSessionAndUser("S1") + ds.putSession(s) + s1, u1 := ds.getSessionAndUser("S1") assert.NotNil(t, s1, "should return a non-nil session") assert.False(t, s == s1, "should return different pointers") assert.Empty(t, cmp.Diff(s, s1, protocmp.Transform()), "should be the same as was entered") assert.Nil(t, u1, "should return a nil user when only the session exists") - ds.PutUser(&user.User{ + ds.putUser(&user.User{ Id: "U1", }) - _, u1 = ds.GetSessionAndUser("S1") + _, u1 = ds.getSessionAndUser("S1") assert.NotNil(t, u1, "should return a user now that it has been added") - ds.PutSession(&session.Session{Id: "S4", UserId: "U1"}) - ds.PutSession(&session.Session{Id: "S3", UserId: "U1"}) - ds.PutSession(&session.Session{Id: "S2", UserId: "U1"}) - u, ss = ds.GetUserAndSessions("U1") + ds.putSession(&session.Session{Id: "S4", UserId: "U1"}) + ds.putSession(&session.Session{Id: "S3", UserId: "U1"}) + ds.putSession(&session.Session{Id: "S2", UserId: "U1"}) + u, ss = ds.getUserAndSessions("U1") assert.NotNil(t, u) assert.Empty(t, cmp.Diff(ss, []*session.Session{ {Id: "S1", UserId: "U1"}, @@ -49,9 +49,9 @@ func TestDataStore(t *testing.T) { {Id: "S4", UserId: "U1"}, }, protocmp.Transform()), "should return all sessions in id order") - ds.DeleteSession("S4") + ds.deleteSession("S4") - u, ss = ds.GetUserAndSessions("U1") + u, ss = ds.getUserAndSessions("U1") assert.NotNil(t, u) assert.Empty(t, cmp.Diff(ss, []*session.Session{ {Id: "S1", UserId: "U1"}, @@ -59,8 +59,8 @@ func TestDataStore(t *testing.T) { {Id: "S3", UserId: "U1"}, }, protocmp.Transform()), "should return all sessions in id order") - ds.DeleteUser("U1") - u, ss = ds.GetUserAndSessions("U1") + ds.deleteUser("U1") + u, ss = ds.getUserAndSessions("U1") assert.Nil(t, u) assert.Empty(t, cmp.Diff(ss, []*session.Session{ {Id: "S1", UserId: "U1"}, @@ -68,11 +68,11 @@ func TestDataStore(t *testing.T) { {Id: "S3", UserId: "U1"}, }, protocmp.Transform()), "should still return all sessions in id order") - ds.DeleteSession("S1") - ds.DeleteSession("S2") - ds.DeleteSession("S3") + ds.deleteSession("S1") + ds.deleteSession("S2") + ds.deleteSession("S3") - u, ss = ds.GetUserAndSessions("U1") + u, ss = ds.getUserAndSessions("U1") assert.Nil(t, u) assert.Empty(t, ss) } diff --git a/internal/identity/manager/manager.go b/internal/identity/manager/manager.go index 1ae18063f..02843e4cb 100644 --- a/internal/identity/manager/manager.go +++ b/internal/identity/manager/manager.go @@ -3,10 +3,9 @@ package manager import ( "context" - "errors" + "sync" "time" - "github.com/google/btree" "github.com/rs/zerolog" "golang.org/x/oauth2" "golang.org/x/sync/errgroup" @@ -19,7 +18,6 @@ import ( "github.com/pomerium/pomerium/internal/events" "github.com/pomerium/pomerium/internal/identity/identity" "github.com/pomerium/pomerium/internal/log" - "github.com/pomerium/pomerium/internal/scheduler" "github.com/pomerium/pomerium/internal/telemetry/metrics" "github.com/pomerium/pomerium/pkg/grpc/databroker" "github.com/pomerium/pomerium/pkg/grpc/session" @@ -35,21 +33,14 @@ type Authenticator interface { UpdateUserInfo(context.Context, *oauth2.Token, interface{}) error } -type ( - updateRecordsMessage struct { - records []*databroker.Record - } -) - // A Manager refreshes identity information using session and user data. type Manager struct { cfg *atomicutil.Value[*config] - sessionScheduler *scheduler.Scheduler - userScheduler *scheduler.Scheduler - - sessions sessionCollection - users userCollection + mu sync.Mutex + dataStore *dataStore + refreshSessionSchedulers map[string]*refreshSessionScheduler + updateUserInfoSchedulers map[string]*updateUserInfoScheduler } // New creates a new identity manager. @@ -59,25 +50,24 @@ func New( mgr := &Manager{ cfg: atomicutil.NewValue(newConfig()), - sessionScheduler: scheduler.New(), - userScheduler: scheduler.New(), + dataStore: newDataStore(), + refreshSessionSchedulers: make(map[string]*refreshSessionScheduler), + updateUserInfoSchedulers: make(map[string]*updateUserInfoScheduler), } - mgr.reset() mgr.UpdateConfig(options...) return mgr } -func withLog(ctx context.Context) context.Context { - return log.WithContext(ctx, func(c zerolog.Context) zerolog.Context { - return c.Str("service", "identity_manager") - }) -} - // UpdateConfig updates the manager with the new options. func (mgr *Manager) UpdateConfig(options ...Option) { mgr.cfg.Store(newConfig(options...)) } +// GetDataBrokerServiceClient gets the databroker client. +func (mgr *Manager) GetDataBrokerServiceClient() databroker.DataBrokerServiceClient { + return mgr.cfg.Load().dataBrokerClient +} + // Run runs the manager. This method blocks until an error occurs or the given context is canceled. func (mgr *Manager) Run(ctx context.Context) error { leaser := databroker.NewLeaser("identity_manager", time.Second*30, mgr) @@ -86,173 +76,144 @@ func (mgr *Manager) Run(ctx context.Context) error { // RunLeased runs the identity manager when a lease is acquired. func (mgr *Manager) RunLeased(ctx context.Context) error { - ctx = withLog(ctx) - update := make(chan updateRecordsMessage, 1) - clear := make(chan struct{}, 1) - - syncer := newDataBrokerSyncer(ctx, mgr.cfg, update, clear) - + ctx = log.WithContext(ctx, func(c zerolog.Context) zerolog.Context { + return c.Str("service", "identity_manager") + }) eg, ctx := errgroup.WithContext(ctx) eg.Go(func() error { - return syncer.Run(ctx) + sessionSyncer := newSessionSyncer(mgr) + defer sessionSyncer.Close() + return sessionSyncer.Run(ctx) }) eg.Go(func() error { - return mgr.refreshLoop(ctx, update, clear) + userSyncer := newUserSyncer(mgr) + defer userSyncer.Close() + return userSyncer.Run(ctx) }) - return eg.Wait() } -// GetDataBrokerServiceClient gets the databroker client. -func (mgr *Manager) GetDataBrokerServiceClient() databroker.DataBrokerServiceClient { - return mgr.cfg.Load().dataBrokerClient +func (mgr *Manager) onDeleteAllSessions(ctx context.Context) { + log.Ctx(ctx).Debug().Msg("all session deleted") + + mgr.mu.Lock() + mgr.dataStore.deleteAllSessions() + for sID, rss := range mgr.refreshSessionSchedulers { + rss.Stop() + delete(mgr.refreshSessionSchedulers, sID) + } + mgr.mu.Unlock() } -func (mgr *Manager) now() time.Time { - return mgr.cfg.Load().now() +func (mgr *Manager) onDeleteAllUsers(ctx context.Context) { + log.Ctx(ctx).Debug().Msg("all users deleted") + + mgr.mu.Lock() + mgr.dataStore.deleteAllUsers() + for uID, uuis := range mgr.updateUserInfoSchedulers { + uuis.Stop() + delete(mgr.updateUserInfoSchedulers, uID) + } + mgr.mu.Unlock() } -func (mgr *Manager) refreshLoop(ctx context.Context, update <-chan updateRecordsMessage, clear <-chan struct{}) error { - // wait for initial sync - select { - case <-ctx.Done(): - return ctx.Err() - case <-clear: - mgr.reset() - } - select { - case <-ctx.Done(): - return ctx.Err() - case msg := <-update: - mgr.onUpdateRecords(ctx, msg) - } - - log.Debug(ctx). - Int("sessions", mgr.sessions.Len()). - Int("users", mgr.users.Len()). - Msg("initial sync complete") - - // start refreshing - maxWait := time.Minute * 10 - var nextTime time.Time - - timer := time.NewTimer(0) - defer timer.Stop() - - for { - select { - case <-ctx.Done(): - return ctx.Err() - case <-clear: - mgr.reset() - case msg := <-update: - mgr.onUpdateRecords(ctx, msg) - case <-timer.C: - } - - now := mgr.now() - nextTime = now.Add(maxWait) - - // refresh sessions - for { - tm, key := mgr.sessionScheduler.Next() - if now.Before(tm) { - if tm.Before(nextTime) { - nextTime = tm - } - break - } - mgr.sessionScheduler.Remove(key) - - userID, sessionID := fromSessionSchedulerKey(key) - mgr.refreshSession(ctx, userID, sessionID) - } - - // refresh users - for { - tm, key := mgr.userScheduler.Next() - if now.Before(tm) { - if tm.Before(nextTime) { - nextTime = tm - } - break - } - mgr.userScheduler.Remove(key) - - mgr.refreshUser(ctx, key) - } - - metrics.RecordIdentityManagerLastRefresh(ctx) - timer.Reset(time.Until(nextTime)) +func (mgr *Manager) onDeleteSession(ctx context.Context, sessionID string) { + log.Ctx(ctx).Debug().Str("session_id", sessionID).Msg("session deleted") + + mgr.mu.Lock() + mgr.dataStore.deleteSession(sessionID) + if rss, ok := mgr.refreshSessionSchedulers[sessionID]; ok { + rss.Stop() + delete(mgr.refreshSessionSchedulers, sessionID) } + mgr.mu.Unlock() } -// refreshSession handles two distinct session lifecycle events: -// -// 1. If the session itself has expired, delete the session. -// 2. If the session's underlying OAuth2 access token is nearing expiration -// (but the session itself is still valid), refresh the access token. -// -// After a successful access token refresh, this method will also trigger a -// user info refresh. If an access token refresh or a user info refresh fails -// with a permanent error, the session will be deleted. -func (mgr *Manager) refreshSession(ctx context.Context, userID, sessionID string) { - log.Info(ctx). - Str("user_id", userID). +func (mgr *Manager) onDeleteUser(ctx context.Context, userID string) { + log.Ctx(ctx).Debug().Str("user_id", userID).Msg("user deleted") + + mgr.mu.Lock() + mgr.dataStore.deleteUser(userID) + if uuis, ok := mgr.updateUserInfoSchedulers[userID]; ok { + uuis.Stop() + delete(mgr.updateUserInfoSchedulers, userID) + } + mgr.mu.Unlock() +} + +func (mgr *Manager) onUpdateSession(ctx context.Context, s *session.Session) { + log.Ctx(ctx).Debug().Str("session_id", s.GetId()).Msg("session updated") + + mgr.mu.Lock() + mgr.dataStore.putSession(s) + rss, ok := mgr.refreshSessionSchedulers[s.GetId()] + if !ok { + rss = newRefreshSessionScheduler(ctx, mgr, s.GetId()) + mgr.refreshSessionSchedulers[s.GetId()] = rss + } + rss.Update(s) + mgr.mu.Unlock() +} + +func (mgr *Manager) onUpdateUser(ctx context.Context, u *user.User) { + log.Ctx(ctx).Debug().Str("user_id", u.GetId()).Msg("user updated") + + mgr.mu.Lock() + mgr.dataStore.putUser(u) + _, ok := mgr.updateUserInfoSchedulers[u.GetId()] + if !ok { + uuis := newUpdateUserInfoScheduler(ctx, mgr, u.GetId()) + mgr.updateUserInfoSchedulers[u.GetId()] = uuis + } + mgr.mu.Unlock() +} + +func (mgr *Manager) refreshSession(ctx context.Context, sessionID string) { + log.Ctx(ctx).Debug(). Str("session_id", sessionID). Msg("refreshing session") - s, ok := mgr.sessions.Get(userID, sessionID) - if !ok { - log.Warn(ctx). - Str("user_id", userID). + mgr.mu.Lock() + s, u := mgr.dataStore.getSessionAndUser(sessionID) + mgr.mu.Unlock() + + if s == nil { + log.Ctx(ctx).Warn(). + Str("user_id", u.GetId()). Str("session_id", sessionID). Msg("no session found for refresh") return } - s.lastRefresh = mgr.now() - - if mgr.refreshSessionInternal(ctx, userID, sessionID, &s) { - mgr.sessions.ReplaceOrInsert(s) - mgr.sessionScheduler.Add(s.NextRefresh(), toSessionSchedulerKey(userID, sessionID)) - } -} - -// refreshSessionInternal performs the core refresh logic and returns true if -// the session should be scheduled for refresh again, or false if not. -func (mgr *Manager) refreshSessionInternal( - ctx context.Context, userID, sessionID string, s *Session, -) bool { authenticator := mgr.cfg.Load().authenticator if authenticator == nil { - log.Info(ctx). - Str("user_id", userID). - Str("session_id", sessionID). + log.Ctx(ctx).Info(). + Str("user_id", s.GetUserId()). + Str("session_id", s.GetId()). Msg("no authenticator defined, deleting session") - mgr.deleteSession(ctx, userID, sessionID) - return false + mgr.deleteSession(ctx, sessionID) + return } expiry := s.GetExpiresAt().AsTime() - if !expiry.After(mgr.now()) { + if !expiry.After(mgr.cfg.Load().now()) { log.Info(ctx). - Str("user_id", userID). - Str("session_id", sessionID). + Str("user_id", s.GetUserId()). + Str("session_id", s.GetId()). Msg("deleting expired session") - mgr.deleteSession(ctx, userID, sessionID) - return false + mgr.deleteSession(ctx, sessionID) + return } - if s.Session == nil || s.Session.OauthToken == nil { + if s.GetOauthToken() == nil { log.Warn(ctx). - Str("user_id", userID). - Str("session_id", sessionID). + Str("user_id", s.GetUserId()). + Str("session_id", s.GetId()). Msg("no session oauth2 token found for refresh") - return false + return } - newToken, err := authenticator.Refresh(ctx, FromOAuthToken(s.OauthToken), s) + newToken, err := authenticator.Refresh(ctx, FromOAuthToken(s.OauthToken), newSessionUnmarshaler(s)) metrics.RecordIdentityManagerSessionRefresh(ctx, err) mgr.recordLastError(metrics_ids.IdentityManagerLastSessionRefreshError, err) if isTemporaryError(err) { @@ -260,18 +221,18 @@ func (mgr *Manager) refreshSessionInternal( Str("user_id", s.GetUserId()). Str("session_id", s.GetId()). Msg("failed to refresh oauth2 token") - return true + return } else if err != nil { log.Error(ctx).Err(err). Str("user_id", s.GetUserId()). Str("session_id", s.GetId()). Msg("failed to refresh oauth2 token, deleting session") - mgr.deleteSession(ctx, userID, sessionID) - return false + mgr.deleteSession(ctx, sessionID) + return } s.OauthToken = ToOAuthToken(newToken) - err = authenticator.UpdateUserInfo(ctx, FromOAuthToken(s.OauthToken), s) + err = authenticator.UpdateUserInfo(ctx, FromOAuthToken(s.OauthToken), newMultiUnmarshaler(newUserUnmarshaler(u), newSessionUnmarshaler(s))) metrics.RecordIdentityManagerUserRefresh(ctx, err) mgr.recordLastError(metrics_ids.IdentityManagerLastUserRefreshError, err) if isTemporaryError(err) { @@ -279,184 +240,162 @@ func (mgr *Manager) refreshSessionInternal( Str("user_id", s.GetUserId()). Str("session_id", s.GetId()). Msg("failed to update user info") - return true + return } else if err != nil { log.Error(ctx).Err(err). Str("user_id", s.GetUserId()). Str("session_id", s.GetId()). Msg("failed to update user info, deleting session") - mgr.deleteSession(ctx, userID, sessionID) - return false + mgr.deleteSession(ctx, sessionID) + return } - fm, err := fieldmaskpb.New(s.Session, "oauth_token", "id_token", "claims") - if err != nil { - log.Error(ctx).Err(err).Msg("internal error") - return false + mgr.updateSession(ctx, s) + if u != nil { + mgr.updateUser(ctx, u) } - - if _, err := session.Patch(ctx, mgr.cfg.Load().dataBrokerClient, s.Session, fm); err != nil { - log.Error(ctx).Err(err). - Str("user_id", s.GetUserId()). - Str("session_id", s.GetId()). - Msg("failed to update session") - } - return true } -func (mgr *Manager) refreshUser(ctx context.Context, userID string) { - log.Info(ctx). - Str("user_id", userID). - Msg("refreshing user") +func (mgr *Manager) updateUserInfo(ctx context.Context, userID string) { + log.Ctx(ctx).Info().Str("user_id", userID).Msg("updating user info") authenticator := mgr.cfg.Load().authenticator if authenticator == nil { return } - u, ok := mgr.users.Get(userID) - if !ok { - log.Warn(ctx). + mgr.mu.Lock() + u, ss := mgr.dataStore.getUserAndSessions(userID) + mgr.mu.Unlock() + + if u == nil { + log.Ctx(ctx).Warn(). Str("user_id", userID). - Msg("no user found for refresh") + Msg("no user found for update") return } - u.lastRefresh = mgr.now() - mgr.userScheduler.Add(u.NextRefresh(), u.GetId()) - for _, s := range mgr.sessions.GetSessionsForUser(userID) { - if s.Session == nil || s.Session.OauthToken == nil { - log.Warn(ctx). - Str("user_id", userID). - Msg("no session oauth2 token found for refresh") + for _, s := range ss { + if s.GetOauthToken() == nil { + log.Ctx(ctx).Warn(). + Str("user_id", s.GetUserId()). + Str("session_id", s.GetId()). + Msg("no session oauth2 token found for updating user info") continue } - err := authenticator.UpdateUserInfo(ctx, FromOAuthToken(s.OauthToken), &u) + err := authenticator.UpdateUserInfo(ctx, FromOAuthToken(s.GetOauthToken()), newMultiUnmarshaler(newUserUnmarshaler(u), newSessionUnmarshaler(s))) metrics.RecordIdentityManagerUserRefresh(ctx, err) mgr.recordLastError(metrics_ids.IdentityManagerLastUserRefreshError, err) if isTemporaryError(err) { - log.Error(ctx).Err(err). + log.Ctx(ctx).Error().Err(err). Str("user_id", s.GetUserId()). Str("session_id", s.GetId()). Msg("failed to update user info") - return + continue } else if err != nil { - log.Error(ctx).Err(err). + log.Ctx(ctx).Error().Err(err). Str("user_id", s.GetUserId()). Str("session_id", s.GetId()). Msg("failed to update user info, deleting session") - mgr.deleteSession(ctx, userID, s.GetId()) + mgr.deleteSession(ctx, s.GetId()) continue } - res, err := databroker.Put(ctx, mgr.cfg.Load().dataBrokerClient, u.User) - if err != nil { - log.Error(ctx).Err(err). - Str("user_id", s.GetUserId()). - Str("session_id", s.GetId()). - Msg("failed to update user") - continue - } - - mgr.onUpdateUser(ctx, res.GetRecords()[0], u.User) + mgr.updateSession(ctx, s) + mgr.updateUser(ctx, u) } } -func (mgr *Manager) onUpdateRecords(ctx context.Context, msg updateRecordsMessage) { - for _, record := range msg.records { - switch record.GetType() { - case grpcutil.GetTypeURL(new(session.Session)): - var pbSession session.Session - err := record.GetData().UnmarshalTo(&pbSession) - if err != nil { - log.Warn(ctx).Msgf("error unmarshaling session: %s", err) - continue - } - mgr.onUpdateSession(record, &pbSession) - case grpcutil.GetTypeURL(new(user.User)): - var pbUser user.User - err := record.GetData().UnmarshalTo(&pbUser) - if err != nil { - log.Warn(ctx).Msgf("error unmarshaling user: %s", err) - continue - } - mgr.onUpdateUser(ctx, record, &pbUser) - } +// deleteSession deletes a session from the databroke, the local data store, and the schedulers +func (mgr *Manager) deleteSession(ctx context.Context, sessionID string) { + log.Ctx(ctx).Debug(). + Str("session_id", sessionID). + Msg("deleting session") + + mgr.mu.Lock() + mgr.dataStore.deleteSession(sessionID) + if rss, ok := mgr.refreshSessionSchedulers[sessionID]; ok { + rss.Stop() + delete(mgr.refreshSessionSchedulers, sessionID) } -} + mgr.mu.Unlock() -func (mgr *Manager) onUpdateSession(record *databroker.Record, session *session.Session) { - mgr.sessionScheduler.Remove(toSessionSchedulerKey(session.GetUserId(), session.GetId())) - - if record.GetDeletedAt() != nil { - mgr.sessions.Delete(session.GetUserId(), session.GetId()) - return - } - - // update session - s, _ := mgr.sessions.Get(session.GetUserId(), session.GetId()) - if s.lastRefresh.IsZero() { - s.lastRefresh = mgr.now() - } - s.gracePeriod = mgr.cfg.Load().sessionRefreshGracePeriod - s.coolOffDuration = mgr.cfg.Load().sessionRefreshCoolOffDuration - s.Session = session - mgr.sessions.ReplaceOrInsert(s) - mgr.sessionScheduler.Add(s.NextRefresh(), toSessionSchedulerKey(session.GetUserId(), session.GetId())) -} - -func (mgr *Manager) onUpdateUser(_ context.Context, record *databroker.Record, user *user.User) { - mgr.userScheduler.Remove(user.GetId()) - - if record.GetDeletedAt() != nil { - mgr.users.Delete(user.GetId()) - return - } - - u, _ := mgr.users.Get(user.GetId()) - u.lastRefresh = mgr.cfg.Load().now() - u.User = user - mgr.users.ReplaceOrInsert(u) - mgr.userScheduler.Add(u.NextRefresh(), u.GetId()) -} - -func (mgr *Manager) deleteSession(ctx context.Context, userID, sessionID string) { - mgr.sessionScheduler.Remove(toSessionSchedulerKey(userID, sessionID)) - mgr.sessions.Delete(userID, sessionID) - - client := mgr.cfg.Load().dataBrokerClient - res, err := client.Get(ctx, &databroker.GetRequest{ + res, err := mgr.cfg.Load().dataBrokerClient.Get(ctx, &databroker.GetRequest{ Type: grpcutil.GetTypeURL(new(session.Session)), Id: sessionID, }) if status.Code(err) == codes.NotFound { return - } else if err != nil { - log.Error(ctx).Err(err). - Str("session_id", sessionID). - Msg("failed to delete session") - return } record := res.GetRecord() record.DeletedAt = timestamppb.Now() - _, err = client.Put(ctx, &databroker.PutRequest{ + _, err = mgr.cfg.Load().dataBrokerClient.Put(ctx, &databroker.PutRequest{ Records: []*databroker.Record{record}, }) if err != nil { - log.Error(ctx).Err(err). + log.Ctx(ctx).Error().Err(err). Str("session_id", sessionID). Msg("failed to delete session") return } } -// reset resets all the manager datastructures to their initial state -func (mgr *Manager) reset() { - mgr.sessions = sessionCollection{BTree: btree.New(8)} - mgr.users = userCollection{BTree: btree.New(8)} +func (mgr *Manager) updateSession(ctx context.Context, s *session.Session) { + log.Ctx(ctx).Debug(). + Str("user_id", s.GetUserId()). + Str("session_id", s.GetId()). + Msg("updating session") + + fm, err := fieldmaskpb.New(s, "oauth_token", "id_token", "claims") + if err != nil { + log.Error(ctx).Err(err). + Str("user_id", s.GetUserId()). + Str("session_id", s.GetId()). + Msg("failed to create fieldmask for session") + return + } + + _, err = session.Patch(ctx, mgr.cfg.Load().dataBrokerClient, s, fm) + if err != nil { + log.Error(ctx).Err(err). + Str("user_id", s.GetUserId()). + Str("session_id", s.GetId()). + Msg("failed to patch updated session record") + return + } + + mgr.mu.Lock() + mgr.dataStore.putSession(s) + if rss, ok := mgr.refreshSessionSchedulers[s.GetId()]; ok { + rss.Update(s) + } + mgr.mu.Unlock() +} + +// updateUser updates the user in the databroker, the local data store, and resets the scheduler +func (mgr *Manager) updateUser(ctx context.Context, u *user.User) { + log.Ctx(ctx).Debug(). + Str("user_id", u.GetId()). + Msg("updating user") + + _, err := databroker.Put(ctx, mgr.cfg.Load().dataBrokerClient, u) + if err != nil { + log.Ctx(ctx).Error(). + Str("user_id", u.GetId()). + Err(err). + Msg("failed to store updated user record") + return + } + + mgr.mu.Lock() + mgr.dataStore.putUser(u) + if uuis, ok := mgr.updateUserInfoSchedulers[u.GetId()]; ok { + uuis.Reset() + } + mgr.mu.Unlock() } func (mgr *Manager) recordLastError(id string, err error) { @@ -473,17 +412,3 @@ func (mgr *Manager) recordLastError(id string, err error) { Id: id, }) } - -func isTemporaryError(err error) bool { - if err == nil { - return false - } - if errors.Is(err, context.DeadlineExceeded) || errors.Is(err, context.Canceled) { - return true - } - var hasTemporary interface{ Temporary() bool } - if errors.As(err, &hasTemporary) && hasTemporary.Temporary() { - return true - } - return false -} diff --git a/internal/identity/manager/manager_test.go b/internal/identity/manager/manager_test.go index 9e4704d23..f5225b35f 100644 --- a/internal/identity/manager/manager_test.go +++ b/internal/identity/manager/manager_test.go @@ -2,28 +2,10 @@ package manager import ( "context" - "errors" - "fmt" - "testing" - "time" - "github.com/stretchr/testify/assert" - "go.uber.org/mock/gomock" "golang.org/x/oauth2" - "google.golang.org/grpc/codes" - "google.golang.org/grpc/status" - "google.golang.org/protobuf/proto" - "google.golang.org/protobuf/types/known/fieldmaskpb" - "google.golang.org/protobuf/types/known/timestamppb" - "github.com/pomerium/pomerium/internal/events" "github.com/pomerium/pomerium/internal/identity/identity" - "github.com/pomerium/pomerium/pkg/grpc/databroker" - "github.com/pomerium/pomerium/pkg/grpc/databroker/mock_databroker" - "github.com/pomerium/pomerium/pkg/grpc/session" - "github.com/pomerium/pomerium/pkg/grpc/user" - metrics_ids "github.com/pomerium/pomerium/pkg/metrics" - "github.com/pomerium/pomerium/pkg/protoutil" ) type mockAuthenticator struct { @@ -44,317 +26,3 @@ func (mock *mockAuthenticator) Revoke(_ context.Context, _ *oauth2.Token) error func (mock *mockAuthenticator) UpdateUserInfo(_ context.Context, _ *oauth2.Token, _ any) error { return mock.updateUserInfoError } - -func TestManager_refresh(t *testing.T) { - ctrl := gomock.NewController(t) - ctx, clearTimeout := context.WithTimeout(context.Background(), time.Second*10) - t.Cleanup(clearTimeout) - - client := mock_databroker.NewMockDataBrokerServiceClient(ctrl) - mgr := New(WithDataBrokerClient(client)) - mgr.onUpdateRecords(ctx, updateRecordsMessage{ - records: []*databroker.Record{ - databroker.NewRecord(&session.Session{ - Id: "s1", - UserId: "u1", - OauthToken: &session.OAuthToken{}, - ExpiresAt: timestamppb.New(time.Now().Add(time.Second * 10)), - }), - databroker.NewRecord(&user.User{ - Id: "u1", - }), - }, - }) - client.EXPECT().Get(gomock.Any(), gomock.Any()).Return(nil, status.Error(codes.NotFound, "not found")) - mgr.refreshSession(ctx, "u1", "s1") - mgr.refreshUser(ctx, "u1") -} - -func TestManager_onUpdateRecords(t *testing.T) { - ctrl := gomock.NewController(t) - - ctx, clearTimeout := context.WithTimeout(context.Background(), time.Second*10) - defer clearTimeout() - - now := time.Now() - - mgr := New( - WithDataBrokerClient(mock_databroker.NewMockDataBrokerServiceClient(ctrl)), - WithNow(func() time.Time { - return now - }), - ) - - mgr.onUpdateRecords(ctx, updateRecordsMessage{ - records: []*databroker.Record{ - mkRecord(&session.Session{Id: "session1", UserId: "user1"}), - mkRecord(&user.User{Id: "user1", Name: "user 1", Email: "user1@example.com"}), - }, - }) - - if _, ok := mgr.sessions.Get("user1", "session1"); assert.True(t, ok) { - tm, id := mgr.sessionScheduler.Next() - assert.Equal(t, now.Add(10*time.Second), tm) - assert.Equal(t, "user1\037session1", id) - } - if _, ok := mgr.users.Get("user1"); assert.True(t, ok) { - tm, id := mgr.userScheduler.Next() - assert.Equal(t, now.Add(userRefreshInterval), tm) - assert.Equal(t, "user1", id) - } -} - -func TestManager_onUpdateSession(t *testing.T) { - startTime := time.Date(2023, 10, 19, 12, 0, 0, 0, time.UTC) - - s := &session.Session{ - Id: "session-id", - UserId: "user-id", - OauthToken: &session.OAuthToken{ - AccessToken: "access-token", - ExpiresAt: timestamppb.New(startTime.Add(5 * time.Minute)), - }, - IssuedAt: timestamppb.New(startTime), - ExpiresAt: timestamppb.New(startTime.Add(24 * time.Hour)), - } - - assertNextScheduled := func(t *testing.T, mgr *Manager, expectedTime time.Time) { - t.Helper() - tm, key := mgr.sessionScheduler.Next() - assert.Equal(t, expectedTime, tm) - assert.Equal(t, "user-id\037session-id", key) - } - - t.Run("initial refresh event when not expiring soon", func(t *testing.T) { - now := startTime - mgr := New(WithNow(func() time.Time { return now })) - - // When the Manager first becomes aware of a session it should schedule - // a refresh event for one minute before access token expiration. - mgr.onUpdateSession(mkRecord(s), s) - assertNextScheduled(t, mgr, startTime.Add(4*time.Minute)) - }) - t.Run("initial refresh event when expiring soon", func(t *testing.T) { - now := startTime - mgr := New(WithNow(func() time.Time { return now })) - - // When the Manager first becomes aware of a session, if that session - // is expiring within the gracePeriod (1 minute), it should schedule a - // refresh event for as soon as possible, subject to the - // coolOffDuration (10 seconds). - now = now.Add(4*time.Minute + 30*time.Second) // 30 s before expiration - mgr.onUpdateSession(mkRecord(s), s) - assertNextScheduled(t, mgr, now.Add(10*time.Second)) - }) - t.Run("update near scheduled refresh", func(t *testing.T) { - now := startTime - mgr := New(WithNow(func() time.Time { return now })) - - mgr.onUpdateSession(mkRecord(s), s) - assertNextScheduled(t, mgr, startTime.Add(4*time.Minute)) - - // If a session is updated close to the time when it is scheduled to be - // refreshed, the scheduled refresh event should not be pushed back. - now = now.Add(3*time.Minute + 55*time.Second) // 5 s before refresh - mgr.onUpdateSession(mkRecord(s), s) - assertNextScheduled(t, mgr, now.Add(5*time.Second)) - - // However, if an update changes the access token validity, the refresh - // event should be rescheduled accordingly. (This should be uncommon, - // as only the refresh loop itself should modify the access token.) - s2 := proto.Clone(s).(*session.Session) - s2.OauthToken.ExpiresAt = timestamppb.New(now.Add(5 * time.Minute)) - mgr.onUpdateSession(mkRecord(s2), s2) - assertNextScheduled(t, mgr, now.Add(4*time.Minute)) - }) - t.Run("session record deleted", func(t *testing.T) { - now := startTime - mgr := New(WithNow(func() time.Time { return now })) - - mgr.onUpdateSession(mkRecord(s), s) - assertNextScheduled(t, mgr, startTime.Add(4*time.Minute)) - - // If a session is deleted, any scheduled refresh event should be canceled. - record := mkRecord(s) - record.DeletedAt = timestamppb.New(now) - mgr.onUpdateSession(record, s) - _, key := mgr.sessionScheduler.Next() - assert.Empty(t, key) - }) -} - -func TestManager_refreshSession(t *testing.T) { - startTime := time.Date(2023, 10, 19, 12, 0, 0, 0, time.UTC) - - var auth mockAuthenticator - - ctrl := gomock.NewController(t) - client := mock_databroker.NewMockDataBrokerServiceClient(ctrl) - - now := startTime - mgr := New( - WithDataBrokerClient(client), - WithNow(func() time.Time { return now }), - WithAuthenticator(&auth), - ) - - // Initialize the Manager with a new session. - s := &session.Session{ - Id: "session-id", - UserId: "user-id", - OauthToken: &session.OAuthToken{ - AccessToken: "access-token", - ExpiresAt: timestamppb.New(startTime.Add(5 * time.Minute)), - RefreshToken: "refresh-token", - }, - IssuedAt: timestamppb.New(startTime), - ExpiresAt: timestamppb.New(startTime.Add(24 * time.Hour)), - } - mgr.sessions.ReplaceOrInsert(Session{ - Session: s, - lastRefresh: startTime, - gracePeriod: time.Minute, - coolOffDuration: 10 * time.Second, - }) - - // If OAuth2 token refresh fails with a temporary error, the manager should - // still reschedule another refresh attempt. - now = now.Add(4 * time.Minute) - auth.refreshError = context.DeadlineExceeded - mgr.refreshSession(context.Background(), "user-id", "session-id") - - tm, key := mgr.sessionScheduler.Next() - assert.Equal(t, now.Add(10*time.Second), tm) - assert.Equal(t, "user-id\037session-id", key) - - // Simulate a successful token refresh on the second attempt. The manager - // should store the updated session in the databroker and schedule another - // refresh event. - now = now.Add(10 * time.Second) - auth.refreshResult, auth.refreshError = &oauth2.Token{ - AccessToken: "new-access-token", - RefreshToken: "new-refresh-token", - Expiry: now.Add(5 * time.Minute), - }, nil - expectedSession := proto.Clone(s).(*session.Session) - expectedSession.OauthToken = &session.OAuthToken{ - AccessToken: "new-access-token", - ExpiresAt: timestamppb.New(now.Add(5 * time.Minute)), - RefreshToken: "new-refresh-token", - } - client.EXPECT().Patch(gomock.Any(), objectsAreEqualMatcher{ - &databroker.PatchRequest{ - Records: []*databroker.Record{{ - Type: "type.googleapis.com/session.Session", - Id: "session-id", - Data: protoutil.NewAny(expectedSession), - }}, - FieldMask: &fieldmaskpb.FieldMask{ - Paths: []string{"oauth_token", "id_token", "claims"}, - }, - }, - }). - Return(nil /* this result is currently unused */, nil) - mgr.refreshSession(context.Background(), "user-id", "session-id") - - tm, key = mgr.sessionScheduler.Next() - assert.Equal(t, now.Add(4*time.Minute), tm) - assert.Equal(t, "user-id\037session-id", key) -} - -func TestManager_reportErrors(t *testing.T) { - ctrl := gomock.NewController(t) - - ctx, clearTimeout := context.WithTimeout(context.Background(), time.Second*10) - defer clearTimeout() - - evtMgr := events.New() - received := make(chan events.Event, 1) - handle := evtMgr.Register(func(evt events.Event) { - received <- evt - }) - defer evtMgr.Unregister(handle) - - expectMsg := func(id, msg string) { - t.Helper() - assert.Eventually(t, func() bool { - select { - case evt := <-received: - lastErr := evt.(*events.LastError) - return msg == lastErr.Message && id == lastErr.Id - default: - return false - } - }, time.Second, time.Millisecond*20, msg) - } - - s := &session.Session{ - Id: "session1", - UserId: "user1", - OauthToken: &session.OAuthToken{ - ExpiresAt: timestamppb.New(time.Now().Add(time.Hour)), - }, - ExpiresAt: timestamppb.New(time.Now().Add(time.Hour)), - } - - client := mock_databroker.NewMockDataBrokerServiceClient(ctrl) - client.EXPECT().Get(gomock.Any(), gomock.Any()).AnyTimes().Return(&databroker.GetResponse{Record: databroker.NewRecord(s)}, nil) - client.EXPECT().Put(gomock.Any(), gomock.Any()).AnyTimes() - mgr := New( - WithEventManager(evtMgr), - WithDataBrokerClient(client), - WithAuthenticator(&mockAuthenticator{ - refreshError: errors.New("update session"), - updateUserInfoError: errors.New("update user info"), - }), - ) - - mgr.onUpdateRecords(ctx, updateRecordsMessage{ - records: []*databroker.Record{ - mkRecord(s), - mkRecord(&user.User{Id: "user1", Name: "user 1", Email: "user1@example.com"}), - }, - }) - - mgr.refreshUser(ctx, "user1") - expectMsg(metrics_ids.IdentityManagerLastUserRefreshError, "update user info") - - mgr.onUpdateRecords(ctx, updateRecordsMessage{ - records: []*databroker.Record{ - mkRecord(s), - mkRecord(&user.User{Id: "user1", Name: "user 1", Email: "user1@example.com"}), - }, - }) - - mgr.refreshSession(ctx, "user1", "session1") - expectMsg(metrics_ids.IdentityManagerLastSessionRefreshError, "update session") -} - -func mkRecord(msg recordable) *databroker.Record { - data := protoutil.NewAny(msg) - return &databroker.Record{ - Type: data.GetTypeUrl(), - Id: msg.GetId(), - Data: data, - } -} - -type recordable interface { - proto.Message - GetId() string -} - -// objectsAreEqualMatcher implements gomock.Matcher using ObjectsAreEqual. This -// is especially helpful when working with pointers, as it will compare the -// underlying values rather than the pointers themselves. -type objectsAreEqualMatcher struct { - expected interface{} -} - -func (m objectsAreEqualMatcher) Matches(x interface{}) bool { - return assert.ObjectsAreEqual(m.expected, x) -} - -func (m objectsAreEqualMatcher) String() string { - return fmt.Sprintf("is equal to %v (%T)", m.expected, m.expected) -} diff --git a/internal/identity/manager/misc.go b/internal/identity/manager/misc.go index efb4fc0da..feb920ec6 100644 --- a/internal/identity/manager/misc.go +++ b/internal/identity/manager/misc.go @@ -1,7 +1,8 @@ package manager import ( - "strings" + "context" + "errors" "golang.org/x/oauth2" "google.golang.org/protobuf/types/known/timestamppb" @@ -9,21 +10,6 @@ import ( "github.com/pomerium/pomerium/pkg/grpc/session" ) -func toSessionSchedulerKey(userID, sessionID string) string { - return userID + "\037" + sessionID -} - -func fromSessionSchedulerKey(key string) (userID, sessionID string) { - idx := strings.Index(key, "\037") - if idx >= 0 { - userID = key[:idx] - sessionID = key[idx+1:] - } else { - userID = key - } - return userID, sessionID -} - // FromOAuthToken converts a session oauth token to oauth2.Token. func FromOAuthToken(token *session.OAuthToken) *oauth2.Token { return &oauth2.Token{ @@ -44,3 +30,17 @@ func ToOAuthToken(token *oauth2.Token) *session.OAuthToken { ExpiresAt: expiry, } } + +func isTemporaryError(err error) bool { + if err == nil { + return false + } + if errors.Is(err, context.DeadlineExceeded) || errors.Is(err, context.Canceled) { + return true + } + var hasTemporary interface{ Temporary() bool } + if errors.As(err, &hasTemporary) && hasTemporary.Temporary() { + return true + } + return false +} diff --git a/internal/identity/manager/schedulers.go b/internal/identity/manager/schedulers.go new file mode 100644 index 000000000..1a44ee811 --- /dev/null +++ b/internal/identity/manager/schedulers.go @@ -0,0 +1,138 @@ +package manager + +import ( + "context" + "sync/atomic" + "time" + + "github.com/pomerium/pomerium/pkg/grpc/session" +) + +type updateUserInfoScheduler struct { + mgr *Manager + userID string + reset chan struct{} + cancel context.CancelFunc +} + +func newUpdateUserInfoScheduler(ctx context.Context, mgr *Manager, userID string) *updateUserInfoScheduler { + uuis := &updateUserInfoScheduler{ + mgr: mgr, + userID: userID, + reset: make(chan struct{}, 1), + } + ctx = context.WithoutCancel(ctx) + ctx, uuis.cancel = context.WithCancel(ctx) + go uuis.run(ctx) + return uuis +} + +func (uuis *updateUserInfoScheduler) Reset() { + select { + case uuis.reset <- struct{}{}: + default: + } +} + +func (uuis *updateUserInfoScheduler) Stop() { + uuis.cancel() +} + +func (uuis *updateUserInfoScheduler) run(ctx context.Context) { + ticker := time.NewTicker(uuis.mgr.cfg.Load().updateUserInfoInterval) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + return + case <-uuis.reset: + ticker.Reset(uuis.mgr.cfg.Load().updateUserInfoInterval) + case <-ticker.C: + uuis.mgr.updateUserInfo(ctx, uuis.userID) + } + } +} + +type refreshSessionScheduler struct { + mgr *Manager + sessionID string + lastRefresh atomic.Pointer[time.Time] + next chan time.Time + cancel context.CancelFunc +} + +func newRefreshSessionScheduler( + ctx context.Context, + mgr *Manager, + sessionID string, +) *refreshSessionScheduler { + rss := &refreshSessionScheduler{ + mgr: mgr, + sessionID: sessionID, + next: make(chan time.Time, 1), + } + now := rss.mgr.cfg.Load().now() + rss.lastRefresh.Store(&now) + ctx = context.WithoutCancel(ctx) + ctx, rss.cancel = context.WithCancel(ctx) + go rss.run(ctx) + return rss +} + +func (rss *refreshSessionScheduler) Update(s *session.Session) { + due := nextSessionRefresh( + s, + *rss.lastRefresh.Load(), + rss.mgr.cfg.Load().sessionRefreshGracePeriod, + rss.mgr.cfg.Load().sessionRefreshCoolOffDuration, + ) + for { + select { + case <-rss.next: + default: + } + select { + case rss.next <- due: + return + default: + } + } +} + +func (rss *refreshSessionScheduler) Stop() { + rss.cancel() +} + +func (rss *refreshSessionScheduler) run(ctx context.Context) { + var timer *time.Timer + + // wait for the first update + select { + case <-ctx.Done(): + return + case due := <-rss.next: + delay := max(time.Until(due), 0) + timer = time.NewTimer(delay) + defer timer.Stop() + } + + // wait for updates or for the timer to trigger + for { + select { + case <-ctx.Done(): + return + case due := <-rss.next: + delay := max(time.Until(due), 0) + // stop the current timer and reset it + if !timer.Stop() { + <-timer.C + } + timer.Reset(delay) + case <-timer.C: + now := rss.mgr.cfg.Load().now() + rss.lastRefresh.Store(&now) + rss.mgr.refreshSession(ctx, rss.sessionID) + } + } +} diff --git a/internal/identity/manager/sync.go b/internal/identity/manager/sync.go index f8707cb0c..b588d804c 100644 --- a/internal/identity/manager/sync.go +++ b/internal/identity/manager/sync.go @@ -3,53 +3,75 @@ package manager import ( "context" - "github.com/pomerium/pomerium/internal/atomicutil" + "github.com/pomerium/pomerium/internal/log" "github.com/pomerium/pomerium/pkg/grpc/databroker" + "github.com/pomerium/pomerium/pkg/grpc/session" + "github.com/pomerium/pomerium/pkg/grpc/user" + "github.com/pomerium/pomerium/pkg/grpcutil" ) -type dataBrokerSyncer struct { - cfg *atomicutil.Value[*config] - - update chan<- updateRecordsMessage - clear chan<- struct{} - - syncer *databroker.Syncer +type sessionSyncerHandler struct { + mgr *Manager } -func newDataBrokerSyncer( - _ context.Context, - cfg *atomicutil.Value[*config], - update chan<- updateRecordsMessage, - clear chan<- struct{}, -) *dataBrokerSyncer { - syncer := &dataBrokerSyncer{ - cfg: cfg, - - update: update, - clear: clear, - } - syncer.syncer = databroker.NewSyncer("identity_manager", syncer) - return syncer +func newSessionSyncer(mgr *Manager) *databroker.Syncer { + return databroker.NewSyncer("identity_manager/sessions", sessionSyncerHandler{mgr: mgr}, + databroker.WithTypeURL(grpcutil.GetTypeURL(new(session.Session)))) } -func (syncer *dataBrokerSyncer) Run(ctx context.Context) (err error) { - return syncer.syncer.Run(ctx) +func (h sessionSyncerHandler) ClearRecords(ctx context.Context) { + h.mgr.onDeleteAllSessions(ctx) } -func (syncer *dataBrokerSyncer) ClearRecords(ctx context.Context) { - select { - case <-ctx.Done(): - case syncer.clear <- struct{}{}: +func (h sessionSyncerHandler) GetDataBrokerServiceClient() databroker.DataBrokerServiceClient { + return h.mgr.cfg.Load().dataBrokerClient +} + +func (h sessionSyncerHandler) UpdateRecords(ctx context.Context, _ uint64, records []*databroker.Record) { + for _, record := range records { + if record.GetDeletedAt() != nil { + h.mgr.onDeleteSession(ctx, record.GetId()) + } else { + var s session.Session + err := record.Data.UnmarshalTo(&s) + if err != nil { + log.Ctx(ctx).Warn().Err(err).Msg("invalid data in session record, ignoring") + } else { + h.mgr.onUpdateSession(ctx, &s) + } + } } } -func (syncer *dataBrokerSyncer) GetDataBrokerServiceClient() databroker.DataBrokerServiceClient { - return syncer.cfg.Load().dataBrokerClient +type userSyncerHandler struct { + mgr *Manager } -func (syncer *dataBrokerSyncer) UpdateRecords(ctx context.Context, _ uint64, records []*databroker.Record) { - select { - case <-ctx.Done(): - case syncer.update <- updateRecordsMessage{records: records}: +func newUserSyncer(mgr *Manager) *databroker.Syncer { + return databroker.NewSyncer("identity_manager/users", userSyncerHandler{mgr: mgr}, + databroker.WithTypeURL(grpcutil.GetTypeURL(new(user.User)))) +} + +func (h userSyncerHandler) ClearRecords(ctx context.Context) { + h.mgr.onDeleteAllUsers(ctx) +} + +func (h userSyncerHandler) GetDataBrokerServiceClient() databroker.DataBrokerServiceClient { + return h.mgr.cfg.Load().dataBrokerClient +} + +func (h userSyncerHandler) UpdateRecords(ctx context.Context, _ uint64, records []*databroker.Record) { + for _, record := range records { + if record.GetDeletedAt() != nil { + h.mgr.onDeleteUser(ctx, record.GetId()) + } else { + var u user.User + err := record.Data.UnmarshalTo(&u) + if err != nil { + log.Ctx(ctx).Warn().Err(err).Msg("invalid data in user record, ignoring") + } else { + h.mgr.onUpdateUser(ctx, &u) + } + } } }