mirror of
https://github.com/pomerium/pomerium.git
synced 2025-05-29 08:57:18 +02:00
fix concurrency race (#1675)
This commit is contained in:
parent
6e33067eef
commit
35f871ad42
2 changed files with 37 additions and 4 deletions
|
@ -49,7 +49,7 @@ type DB struct {
|
||||||
|
|
||||||
lastVersion uint64
|
lastVersion uint64
|
||||||
|
|
||||||
mu sync.Mutex
|
mu sync.RWMutex
|
||||||
byID *btree.BTree
|
byID *btree.BTree
|
||||||
byVersion *btree.BTree
|
byVersion *btree.BTree
|
||||||
deletedIDs []string
|
deletedIDs []string
|
||||||
|
@ -100,6 +100,9 @@ func (db *DB) Close() error {
|
||||||
|
|
||||||
// Delete marks a record as deleted.
|
// Delete marks a record as deleted.
|
||||||
func (db *DB) Delete(_ context.Context, id string) error {
|
func (db *DB) Delete(_ context.Context, id string) error {
|
||||||
|
db.mu.Lock()
|
||||||
|
defer db.mu.Unlock()
|
||||||
|
|
||||||
defer db.onchange.Broadcast()
|
defer db.onchange.Broadcast()
|
||||||
db.replaceOrInsert(id, func(record *databroker.Record) {
|
db.replaceOrInsert(id, func(record *databroker.Record) {
|
||||||
record.DeletedAt = ptypes.TimestampNow()
|
record.DeletedAt = ptypes.TimestampNow()
|
||||||
|
@ -110,6 +113,9 @@ func (db *DB) Delete(_ context.Context, id string) error {
|
||||||
|
|
||||||
// Get gets a record from the db.
|
// Get gets a record from the db.
|
||||||
func (db *DB) Get(_ context.Context, id string) (*databroker.Record, error) {
|
func (db *DB) Get(_ context.Context, id string) (*databroker.Record, error) {
|
||||||
|
db.mu.RLock()
|
||||||
|
defer db.mu.RUnlock()
|
||||||
|
|
||||||
record, ok := db.byID.Get(byIDRecord{Record: &databroker.Record{Id: id}}).(byIDRecord)
|
record, ok := db.byID.Get(byIDRecord{Record: &databroker.Record{Id: id}}).(byIDRecord)
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, errors.New("not found")
|
return nil, errors.New("not found")
|
||||||
|
@ -119,6 +125,9 @@ func (db *DB) Get(_ context.Context, id string) (*databroker.Record, error) {
|
||||||
|
|
||||||
// GetAll gets all the records in the db.
|
// GetAll gets all the records in the db.
|
||||||
func (db *DB) GetAll(_ context.Context) ([]*databroker.Record, error) {
|
func (db *DB) GetAll(_ context.Context) ([]*databroker.Record, error) {
|
||||||
|
db.mu.RLock()
|
||||||
|
defer db.mu.RUnlock()
|
||||||
|
|
||||||
var records []*databroker.Record
|
var records []*databroker.Record
|
||||||
db.byID.Ascend(func(item btree.Item) bool {
|
db.byID.Ascend(func(item btree.Item) bool {
|
||||||
records = append(records, item.(byIDRecord).Record)
|
records = append(records, item.(byIDRecord).Record)
|
||||||
|
@ -129,6 +138,9 @@ func (db *DB) GetAll(_ context.Context) ([]*databroker.Record, error) {
|
||||||
|
|
||||||
// List lists all the changes since the given version.
|
// List lists all the changes since the given version.
|
||||||
func (db *DB) List(_ context.Context, sinceVersion string) ([]*databroker.Record, error) {
|
func (db *DB) List(_ context.Context, sinceVersion string) ([]*databroker.Record, error) {
|
||||||
|
db.mu.RLock()
|
||||||
|
defer db.mu.RUnlock()
|
||||||
|
|
||||||
var records []*databroker.Record
|
var records []*databroker.Record
|
||||||
db.byVersion.AscendGreaterOrEqual(byVersionRecord{Record: &databroker.Record{Version: sinceVersion}}, func(i btree.Item) bool {
|
db.byVersion.AscendGreaterOrEqual(byVersionRecord{Record: &databroker.Record{Version: sinceVersion}}, func(i btree.Item) bool {
|
||||||
record := i.(byVersionRecord)
|
record := i.(byVersionRecord)
|
||||||
|
@ -142,6 +154,9 @@ func (db *DB) List(_ context.Context, sinceVersion string) ([]*databroker.Record
|
||||||
|
|
||||||
// Put replaces or inserts a record in the db.
|
// Put replaces or inserts a record in the db.
|
||||||
func (db *DB) Put(_ context.Context, id string, data *anypb.Any) error {
|
func (db *DB) Put(_ context.Context, id string, data *anypb.Any) error {
|
||||||
|
db.mu.Lock()
|
||||||
|
defer db.mu.Unlock()
|
||||||
|
|
||||||
defer db.onchange.Broadcast()
|
defer db.onchange.Broadcast()
|
||||||
db.replaceOrInsert(id, func(record *databroker.Record) {
|
db.replaceOrInsert(id, func(record *databroker.Record) {
|
||||||
record.Data = data
|
record.Data = data
|
||||||
|
@ -165,9 +180,6 @@ func (db *DB) Watch(ctx context.Context) <-chan struct{} {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (db *DB) replaceOrInsert(id string, f func(record *databroker.Record)) {
|
func (db *DB) replaceOrInsert(id string, f func(record *databroker.Record)) {
|
||||||
db.mu.Lock()
|
|
||||||
defer db.mu.Unlock()
|
|
||||||
|
|
||||||
record, ok := db.byID.Get(byIDRecord{Record: &databroker.Record{Id: id}}).(byIDRecord)
|
record, ok := db.byID.Get(byIDRecord{Record: &databroker.Record{Id: id}}).(byIDRecord)
|
||||||
if ok {
|
if ok {
|
||||||
db.byVersion.Delete(byVersionRecord(record))
|
db.byVersion.Delete(byVersionRecord(record))
|
||||||
|
|
|
@ -8,6 +8,7 @@ import (
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
|
"golang.org/x/sync/errgroup"
|
||||||
"google.golang.org/protobuf/types/known/anypb"
|
"google.golang.org/protobuf/types/known/anypb"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -85,3 +86,23 @@ func TestDB(t *testing.T) {
|
||||||
assert.Nil(t, record)
|
assert.Nil(t, record)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestConcurrency(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
db := NewDB("example", 2)
|
||||||
|
|
||||||
|
eg, ctx := errgroup.WithContext(ctx)
|
||||||
|
eg.Go(func() error {
|
||||||
|
for i := 0; i < 1000; i++ {
|
||||||
|
_, _ = db.List(ctx, "")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
eg.Go(func() error {
|
||||||
|
for i := 0; i < 1000; i++ {
|
||||||
|
db.Put(ctx, fmt.Sprint(i), new(anypb.Any))
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
eg.Wait()
|
||||||
|
}
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue