diff --git a/pkg/storage/inmemory/inmemory.go b/pkg/storage/inmemory/inmemory.go index 9b0b9606c..1ac6674dc 100644 --- a/pkg/storage/inmemory/inmemory.go +++ b/pkg/storage/inmemory/inmemory.go @@ -49,7 +49,7 @@ type DB struct { lastVersion uint64 - mu sync.Mutex + mu sync.RWMutex byID *btree.BTree byVersion *btree.BTree deletedIDs []string @@ -100,6 +100,9 @@ func (db *DB) Close() error { // Delete marks a record as deleted. func (db *DB) Delete(_ context.Context, id string) error { + db.mu.Lock() + defer db.mu.Unlock() + defer db.onchange.Broadcast() db.replaceOrInsert(id, func(record *databroker.Record) { record.DeletedAt = ptypes.TimestampNow() @@ -110,6 +113,9 @@ func (db *DB) Delete(_ context.Context, id string) error { // Get gets a record from the db. func (db *DB) Get(_ context.Context, id string) (*databroker.Record, error) { + db.mu.RLock() + defer db.mu.RUnlock() + record, ok := db.byID.Get(byIDRecord{Record: &databroker.Record{Id: id}}).(byIDRecord) if !ok { return nil, errors.New("not found") @@ -119,6 +125,9 @@ func (db *DB) Get(_ context.Context, id string) (*databroker.Record, error) { // GetAll gets all the records in the db. func (db *DB) GetAll(_ context.Context) ([]*databroker.Record, error) { + db.mu.RLock() + defer db.mu.RUnlock() + var records []*databroker.Record db.byID.Ascend(func(item btree.Item) bool { records = append(records, item.(byIDRecord).Record) @@ -129,6 +138,9 @@ func (db *DB) GetAll(_ context.Context) ([]*databroker.Record, error) { // List lists all the changes since the given version. func (db *DB) List(_ context.Context, sinceVersion string) ([]*databroker.Record, error) { + db.mu.RLock() + defer db.mu.RUnlock() + var records []*databroker.Record db.byVersion.AscendGreaterOrEqual(byVersionRecord{Record: &databroker.Record{Version: sinceVersion}}, func(i btree.Item) bool { record := i.(byVersionRecord) @@ -142,6 +154,9 @@ func (db *DB) List(_ context.Context, sinceVersion string) ([]*databroker.Record // Put replaces or inserts a record in the db. func (db *DB) Put(_ context.Context, id string, data *anypb.Any) error { + db.mu.Lock() + defer db.mu.Unlock() + defer db.onchange.Broadcast() db.replaceOrInsert(id, func(record *databroker.Record) { record.Data = data @@ -165,9 +180,6 @@ func (db *DB) Watch(ctx context.Context) <-chan struct{} { } func (db *DB) replaceOrInsert(id string, f func(record *databroker.Record)) { - db.mu.Lock() - defer db.mu.Unlock() - record, ok := db.byID.Get(byIDRecord{Record: &databroker.Record{Id: id}}).(byIDRecord) if ok { db.byVersion.Delete(byVersionRecord(record)) diff --git a/pkg/storage/inmemory/inmemory_test.go b/pkg/storage/inmemory/inmemory_test.go index d79b5f9a6..9605e6a89 100644 --- a/pkg/storage/inmemory/inmemory_test.go +++ b/pkg/storage/inmemory/inmemory_test.go @@ -8,6 +8,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "golang.org/x/sync/errgroup" "google.golang.org/protobuf/types/known/anypb" ) @@ -85,3 +86,23 @@ func TestDB(t *testing.T) { assert.Nil(t, record) }) } + +func TestConcurrency(t *testing.T) { + ctx := context.Background() + db := NewDB("example", 2) + + eg, ctx := errgroup.WithContext(ctx) + eg.Go(func() error { + for i := 0; i < 1000; i++ { + _, _ = db.List(ctx, "") + } + return nil + }) + eg.Go(func() error { + for i := 0; i < 1000; i++ { + db.Put(ctx, fmt.Sprint(i), new(anypb.Any)) + } + return nil + }) + eg.Wait() +}