diff --git a/internal/identity/manager/datastore.go b/internal/identity/manager/datastore.go new file mode 100644 index 000000000..a8aeea78a --- /dev/null +++ b/internal/identity/manager/datastore.go @@ -0,0 +1,140 @@ +package manager + +import ( + "cmp" + "slices" + "sync" + + "google.golang.org/protobuf/proto" + + "github.com/pomerium/pomerium/pkg/grpc/session" + "github.com/pomerium/pomerium/pkg/grpc/user" +) + +// dataStore stores session and user data. All public methods are thread-safe. +type dataStore struct { + mu sync.Mutex + sessions map[string]*session.Session + users map[string]*user.User + userIDToSessionIDs map[string]map[string]struct{} +} + +func newDataStore() *dataStore { + return &dataStore{ + sessions: make(map[string]*session.Session), + users: make(map[string]*user.User), + userIDToSessionIDs: make(map[string]map[string]struct{}), + } +} + +// DeleteSession deletes a session. +func (ds *dataStore) DeleteSession(sID string) { + ds.mu.Lock() + ds.deleteSessionLocked(sID) + ds.mu.Unlock() +} + +// DeleteUser deletes a user. +func (ds *dataStore) DeleteUser(userID string) { + ds.mu.Lock() + delete(ds.users, userID) + ds.mu.Unlock() +} + +// GetSessionAndUser gets a session and its associated user. +func (ds *dataStore) GetSessionAndUser(sessionID string) (s *session.Session, u *user.User) { + ds.mu.Lock() + s = ds.sessions[sessionID] + if s.GetUserId() != "" { + u = ds.users[s.GetUserId()] + } + ds.mu.Unlock() + + // clone to avoid sharing memory + s = clone(s) + u = clone(u) + return s, u +} + +// GetUserAndSessions gets a user and all of its associated sessions. +func (ds *dataStore) GetUserAndSessions(userID string) (u *user.User, ss []*session.Session) { + ds.mu.Lock() + u = ds.users[userID] + for sessionID := range ds.userIDToSessionIDs[userID] { + ss = append(ss, ds.sessions[sessionID]) + } + ds.mu.Unlock() + + // remove nils and sort by id + ss = slices.Compact(ss) + slices.SortFunc(ss, func(a, b *session.Session) int { + return cmp.Compare(a.GetId(), b.GetId()) + }) + + // clone to avoid sharing memory + u = clone(u) + for i := range ss { + ss[i] = clone(ss[i]) + } + return u, ss +} + +// PutSession stores the session. +func (ds *dataStore) PutSession(s *session.Session) { + // clone to avoid sharing memory + s = clone(s) + + ds.mu.Lock() + if s.GetId() != "" { + ds.deleteSessionLocked(s.GetId()) + ds.sessions[s.GetId()] = s + if s.GetUserId() != "" { + m, ok := ds.userIDToSessionIDs[s.GetUserId()] + if !ok { + m = make(map[string]struct{}) + ds.userIDToSessionIDs[s.GetUserId()] = m + } + m[s.GetId()] = struct{}{} + } + } + ds.mu.Unlock() +} + +// PutUser stores the user. +func (ds *dataStore) PutUser(u *user.User) { + // clone to avoid sharing memory + u = clone(u) + + ds.mu.Lock() + if u.GetId() != "" { + ds.users[u.GetId()] = u + } + ds.mu.Unlock() +} + +func (ds *dataStore) deleteSessionLocked(sID string) { + s := ds.sessions[sID] + delete(ds.sessions, sID) + if s.GetUserId() == "" { + return + } + + m := ds.userIDToSessionIDs[s.GetUserId()] + if m != nil { + delete(m, s.GetId()) + } + if len(m) == 0 { + delete(ds.userIDToSessionIDs, s.GetUserId()) + } +} + +// clone clones a protobuf message +func clone[T any, U interface { + *T + proto.Message +}](src U) U { + if src == nil { + return src + } + return proto.Clone(src).(U) +} diff --git a/internal/identity/manager/datastore_test.go b/internal/identity/manager/datastore_test.go new file mode 100644 index 000000000..55592aa43 --- /dev/null +++ b/internal/identity/manager/datastore_test.go @@ -0,0 +1,78 @@ +package manager + +import ( + "testing" + + "github.com/google/go-cmp/cmp" + "github.com/stretchr/testify/assert" + "google.golang.org/protobuf/testing/protocmp" + + "github.com/pomerium/pomerium/pkg/grpc/session" + "github.com/pomerium/pomerium/pkg/grpc/user" +) + +func TestDataStore(t *testing.T) { + t.Parallel() + + ds := newDataStore() + s, u := ds.GetSessionAndUser("S1") + assert.Nil(t, s, "should return a nil session when none exists") + assert.Nil(t, u, "should return a nil user when none exists") + + u, ss := ds.GetUserAndSessions("U1") + assert.Nil(t, u, "should return a nil user when none exists") + assert.Empty(t, ss, "should return an empty list of sessions when no user exists") + + s = &session.Session{Id: "S1", UserId: "U1"} + ds.PutSession(s) + s1, u1 := ds.GetSessionAndUser("S1") + assert.NotNil(t, s1, "should return a non-nil session") + assert.False(t, s == s1, "should return different pointers") + assert.Empty(t, cmp.Diff(s, s1, protocmp.Transform()), "should be the same as was entered") + assert.Nil(t, u1, "should return a nil user when only the session exists") + + ds.PutUser(&user.User{ + Id: "U1", + }) + _, u1 = ds.GetSessionAndUser("S1") + assert.NotNil(t, u1, "should return a user now that it has been added") + + ds.PutSession(&session.Session{Id: "S4", UserId: "U1"}) + ds.PutSession(&session.Session{Id: "S3", UserId: "U1"}) + ds.PutSession(&session.Session{Id: "S2", UserId: "U1"}) + u, ss = ds.GetUserAndSessions("U1") + assert.NotNil(t, u) + assert.Empty(t, cmp.Diff(ss, []*session.Session{ + {Id: "S1", UserId: "U1"}, + {Id: "S2", UserId: "U1"}, + {Id: "S3", UserId: "U1"}, + {Id: "S4", UserId: "U1"}, + }, protocmp.Transform()), "should return all sessions in id order") + + ds.DeleteSession("S4") + + u, ss = ds.GetUserAndSessions("U1") + assert.NotNil(t, u) + assert.Empty(t, cmp.Diff(ss, []*session.Session{ + {Id: "S1", UserId: "U1"}, + {Id: "S2", UserId: "U1"}, + {Id: "S3", UserId: "U1"}, + }, protocmp.Transform()), "should return all sessions in id order") + + ds.DeleteUser("U1") + u, ss = ds.GetUserAndSessions("U1") + assert.Nil(t, u) + assert.Empty(t, cmp.Diff(ss, []*session.Session{ + {Id: "S1", UserId: "U1"}, + {Id: "S2", UserId: "U1"}, + {Id: "S3", UserId: "U1"}, + }, protocmp.Transform()), "should still return all sessions in id order") + + ds.DeleteSession("S1") + ds.DeleteSession("S2") + ds.DeleteSession("S3") + + u, ss = ds.GetUserAndSessions("U1") + assert.Nil(t, u) + assert.Empty(t, ss) +}