mirror of
https://github.com/pomerium/pomerium.git
synced 2025-07-01 17:08:16 +02:00
core/identity: add data store for thread-safe storage of sessions and users
This commit is contained in:
parent
494dc4accc
commit
a6577fd570
2 changed files with 218 additions and 0 deletions
140
internal/identity/manager/datastore.go
Normal file
140
internal/identity/manager/datastore.go
Normal file
|
@ -0,0 +1,140 @@
|
|||
package manager
|
||||
|
||||
import (
|
||||
"cmp"
|
||||
"slices"
|
||||
"sync"
|
||||
|
||||
"google.golang.org/protobuf/proto"
|
||||
|
||||
"github.com/pomerium/pomerium/pkg/grpc/session"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/user"
|
||||
)
|
||||
|
||||
// dataStore stores session and user data. All public methods are thread-safe.
|
||||
type dataStore struct {
|
||||
mu sync.Mutex
|
||||
sessions map[string]*session.Session
|
||||
users map[string]*user.User
|
||||
userIDToSessionIDs map[string]map[string]struct{}
|
||||
}
|
||||
|
||||
func newDataStore() *dataStore {
|
||||
return &dataStore{
|
||||
sessions: make(map[string]*session.Session),
|
||||
users: make(map[string]*user.User),
|
||||
userIDToSessionIDs: make(map[string]map[string]struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
// DeleteSession deletes a session.
|
||||
func (ds *dataStore) DeleteSession(sID string) {
|
||||
ds.mu.Lock()
|
||||
ds.deleteSessionLocked(sID)
|
||||
ds.mu.Unlock()
|
||||
}
|
||||
|
||||
// DeleteUser deletes a user.
|
||||
func (ds *dataStore) DeleteUser(userID string) {
|
||||
ds.mu.Lock()
|
||||
delete(ds.users, userID)
|
||||
ds.mu.Unlock()
|
||||
}
|
||||
|
||||
// GetSessionAndUser gets a session and its associated user.
|
||||
func (ds *dataStore) GetSessionAndUser(sessionID string) (s *session.Session, u *user.User) {
|
||||
ds.mu.Lock()
|
||||
s = ds.sessions[sessionID]
|
||||
if s.GetUserId() != "" {
|
||||
u = ds.users[s.GetUserId()]
|
||||
}
|
||||
ds.mu.Unlock()
|
||||
|
||||
// clone to avoid sharing memory
|
||||
s = clone(s)
|
||||
u = clone(u)
|
||||
return s, u
|
||||
}
|
||||
|
||||
// GetUserAndSessions gets a user and all of its associated sessions.
|
||||
func (ds *dataStore) GetUserAndSessions(userID string) (u *user.User, ss []*session.Session) {
|
||||
ds.mu.Lock()
|
||||
u = ds.users[userID]
|
||||
for sessionID := range ds.userIDToSessionIDs[userID] {
|
||||
ss = append(ss, ds.sessions[sessionID])
|
||||
}
|
||||
ds.mu.Unlock()
|
||||
|
||||
// remove nils and sort by id
|
||||
ss = slices.Compact(ss)
|
||||
slices.SortFunc(ss, func(a, b *session.Session) int {
|
||||
return cmp.Compare(a.GetId(), b.GetId())
|
||||
})
|
||||
|
||||
// clone to avoid sharing memory
|
||||
u = clone(u)
|
||||
for i := range ss {
|
||||
ss[i] = clone(ss[i])
|
||||
}
|
||||
return u, ss
|
||||
}
|
||||
|
||||
// PutSession stores the session.
|
||||
func (ds *dataStore) PutSession(s *session.Session) {
|
||||
// clone to avoid sharing memory
|
||||
s = clone(s)
|
||||
|
||||
ds.mu.Lock()
|
||||
if s.GetId() != "" {
|
||||
ds.deleteSessionLocked(s.GetId())
|
||||
ds.sessions[s.GetId()] = s
|
||||
if s.GetUserId() != "" {
|
||||
m, ok := ds.userIDToSessionIDs[s.GetUserId()]
|
||||
if !ok {
|
||||
m = make(map[string]struct{})
|
||||
ds.userIDToSessionIDs[s.GetUserId()] = m
|
||||
}
|
||||
m[s.GetId()] = struct{}{}
|
||||
}
|
||||
}
|
||||
ds.mu.Unlock()
|
||||
}
|
||||
|
||||
// PutUser stores the user.
|
||||
func (ds *dataStore) PutUser(u *user.User) {
|
||||
// clone to avoid sharing memory
|
||||
u = clone(u)
|
||||
|
||||
ds.mu.Lock()
|
||||
if u.GetId() != "" {
|
||||
ds.users[u.GetId()] = u
|
||||
}
|
||||
ds.mu.Unlock()
|
||||
}
|
||||
|
||||
func (ds *dataStore) deleteSessionLocked(sID string) {
|
||||
s := ds.sessions[sID]
|
||||
delete(ds.sessions, sID)
|
||||
if s.GetUserId() == "" {
|
||||
return
|
||||
}
|
||||
|
||||
m := ds.userIDToSessionIDs[s.GetUserId()]
|
||||
if m != nil {
|
||||
delete(m, s.GetId())
|
||||
}
|
||||
if len(m) == 0 {
|
||||
delete(ds.userIDToSessionIDs, s.GetUserId())
|
||||
}
|
||||
}
|
||||
|
||||
// clone clones a protobuf message
|
||||
func clone[T any, U interface {
|
||||
*T
|
||||
proto.Message
|
||||
}](src U) U {
|
||||
if src == nil {
|
||||
return src
|
||||
}
|
||||
return proto.Clone(src).(U)
|
||||
}
|
78
internal/identity/manager/datastore_test.go
Normal file
78
internal/identity/manager/datastore_test.go
Normal file
|
@ -0,0 +1,78 @@
|
|||
package manager
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"google.golang.org/protobuf/testing/protocmp"
|
||||
|
||||
"github.com/pomerium/pomerium/pkg/grpc/session"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/user"
|
||||
)
|
||||
|
||||
func TestDataStore(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ds := newDataStore()
|
||||
s, u := ds.GetSessionAndUser("S1")
|
||||
assert.Nil(t, s, "should return a nil session when none exists")
|
||||
assert.Nil(t, u, "should return a nil user when none exists")
|
||||
|
||||
u, ss := ds.GetUserAndSessions("U1")
|
||||
assert.Nil(t, u, "should return a nil user when none exists")
|
||||
assert.Empty(t, ss, "should return an empty list of sessions when no user exists")
|
||||
|
||||
s = &session.Session{Id: "S1", UserId: "U1"}
|
||||
ds.PutSession(s)
|
||||
s1, u1 := ds.GetSessionAndUser("S1")
|
||||
assert.NotNil(t, s1, "should return a non-nil session")
|
||||
assert.False(t, s == s1, "should return different pointers")
|
||||
assert.Empty(t, cmp.Diff(s, s1, protocmp.Transform()), "should be the same as was entered")
|
||||
assert.Nil(t, u1, "should return a nil user when only the session exists")
|
||||
|
||||
ds.PutUser(&user.User{
|
||||
Id: "U1",
|
||||
})
|
||||
_, u1 = ds.GetSessionAndUser("S1")
|
||||
assert.NotNil(t, u1, "should return a user now that it has been added")
|
||||
|
||||
ds.PutSession(&session.Session{Id: "S4", UserId: "U1"})
|
||||
ds.PutSession(&session.Session{Id: "S3", UserId: "U1"})
|
||||
ds.PutSession(&session.Session{Id: "S2", UserId: "U1"})
|
||||
u, ss = ds.GetUserAndSessions("U1")
|
||||
assert.NotNil(t, u)
|
||||
assert.Empty(t, cmp.Diff(ss, []*session.Session{
|
||||
{Id: "S1", UserId: "U1"},
|
||||
{Id: "S2", UserId: "U1"},
|
||||
{Id: "S3", UserId: "U1"},
|
||||
{Id: "S4", UserId: "U1"},
|
||||
}, protocmp.Transform()), "should return all sessions in id order")
|
||||
|
||||
ds.DeleteSession("S4")
|
||||
|
||||
u, ss = ds.GetUserAndSessions("U1")
|
||||
assert.NotNil(t, u)
|
||||
assert.Empty(t, cmp.Diff(ss, []*session.Session{
|
||||
{Id: "S1", UserId: "U1"},
|
||||
{Id: "S2", UserId: "U1"},
|
||||
{Id: "S3", UserId: "U1"},
|
||||
}, protocmp.Transform()), "should return all sessions in id order")
|
||||
|
||||
ds.DeleteUser("U1")
|
||||
u, ss = ds.GetUserAndSessions("U1")
|
||||
assert.Nil(t, u)
|
||||
assert.Empty(t, cmp.Diff(ss, []*session.Session{
|
||||
{Id: "S1", UserId: "U1"},
|
||||
{Id: "S2", UserId: "U1"},
|
||||
{Id: "S3", UserId: "U1"},
|
||||
}, protocmp.Transform()), "should still return all sessions in id order")
|
||||
|
||||
ds.DeleteSession("S1")
|
||||
ds.DeleteSession("S2")
|
||||
ds.DeleteSession("S3")
|
||||
|
||||
u, ss = ds.GetUserAndSessions("U1")
|
||||
assert.Nil(t, u)
|
||||
assert.Empty(t, ss)
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue