fix concurrency race (#1675)

This commit is contained in:
Caleb Doxsey 2020-12-11 14:43:26 -07:00 committed by GitHub
parent 6e33067eef
commit 35f871ad42
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 37 additions and 4 deletions

View file

@ -49,7 +49,7 @@ type DB struct {
lastVersion uint64
mu sync.Mutex
mu sync.RWMutex
byID *btree.BTree
byVersion *btree.BTree
deletedIDs []string
@ -100,6 +100,9 @@ func (db *DB) Close() error {
// Delete marks a record as deleted.
func (db *DB) Delete(_ context.Context, id string) error {
db.mu.Lock()
defer db.mu.Unlock()
defer db.onchange.Broadcast()
db.replaceOrInsert(id, func(record *databroker.Record) {
record.DeletedAt = ptypes.TimestampNow()
@ -110,6 +113,9 @@ func (db *DB) Delete(_ context.Context, id string) error {
// Get gets a record from the db.
func (db *DB) Get(_ context.Context, id string) (*databroker.Record, error) {
db.mu.RLock()
defer db.mu.RUnlock()
record, ok := db.byID.Get(byIDRecord{Record: &databroker.Record{Id: id}}).(byIDRecord)
if !ok {
return nil, errors.New("not found")
@ -119,6 +125,9 @@ func (db *DB) Get(_ context.Context, id string) (*databroker.Record, error) {
// GetAll gets all the records in the db.
func (db *DB) GetAll(_ context.Context) ([]*databroker.Record, error) {
db.mu.RLock()
defer db.mu.RUnlock()
var records []*databroker.Record
db.byID.Ascend(func(item btree.Item) bool {
records = append(records, item.(byIDRecord).Record)
@ -129,6 +138,9 @@ func (db *DB) GetAll(_ context.Context) ([]*databroker.Record, error) {
// List lists all the changes since the given version.
func (db *DB) List(_ context.Context, sinceVersion string) ([]*databroker.Record, error) {
db.mu.RLock()
defer db.mu.RUnlock()
var records []*databroker.Record
db.byVersion.AscendGreaterOrEqual(byVersionRecord{Record: &databroker.Record{Version: sinceVersion}}, func(i btree.Item) bool {
record := i.(byVersionRecord)
@ -142,6 +154,9 @@ func (db *DB) List(_ context.Context, sinceVersion string) ([]*databroker.Record
// Put replaces or inserts a record in the db.
func (db *DB) Put(_ context.Context, id string, data *anypb.Any) error {
db.mu.Lock()
defer db.mu.Unlock()
defer db.onchange.Broadcast()
db.replaceOrInsert(id, func(record *databroker.Record) {
record.Data = data
@ -165,9 +180,6 @@ func (db *DB) Watch(ctx context.Context) <-chan struct{} {
}
func (db *DB) replaceOrInsert(id string, f func(record *databroker.Record)) {
db.mu.Lock()
defer db.mu.Unlock()
record, ok := db.byID.Get(byIDRecord{Record: &databroker.Record{Id: id}}).(byIDRecord)
if ok {
db.byVersion.Delete(byVersionRecord(record))

View file

@ -8,6 +8,7 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/sync/errgroup"
"google.golang.org/protobuf/types/known/anypb"
)
@ -85,3 +86,23 @@ func TestDB(t *testing.T) {
assert.Nil(t, record)
})
}
func TestConcurrency(t *testing.T) {
ctx := context.Background()
db := NewDB("example", 2)
eg, ctx := errgroup.WithContext(ctx)
eg.Go(func() error {
for i := 0; i < 1000; i++ {
_, _ = db.List(ctx, "")
}
return nil
})
eg.Go(func() error {
for i := 0; i < 1000; i++ {
db.Put(ctx, fmt.Sprint(i), new(anypb.Any))
}
return nil
})
eg.Wait()
}