mirror of
https://github.com/pomerium/pomerium.git
synced 2025-05-29 00:47:17 +02:00
fix concurrency race (#1675)
This commit is contained in:
parent
6e33067eef
commit
35f871ad42
2 changed files with 37 additions and 4 deletions
|
@ -49,7 +49,7 @@ type DB struct {
|
|||
|
||||
lastVersion uint64
|
||||
|
||||
mu sync.Mutex
|
||||
mu sync.RWMutex
|
||||
byID *btree.BTree
|
||||
byVersion *btree.BTree
|
||||
deletedIDs []string
|
||||
|
@ -100,6 +100,9 @@ func (db *DB) Close() error {
|
|||
|
||||
// Delete marks a record as deleted.
|
||||
func (db *DB) Delete(_ context.Context, id string) error {
|
||||
db.mu.Lock()
|
||||
defer db.mu.Unlock()
|
||||
|
||||
defer db.onchange.Broadcast()
|
||||
db.replaceOrInsert(id, func(record *databroker.Record) {
|
||||
record.DeletedAt = ptypes.TimestampNow()
|
||||
|
@ -110,6 +113,9 @@ func (db *DB) Delete(_ context.Context, id string) error {
|
|||
|
||||
// Get gets a record from the db.
|
||||
func (db *DB) Get(_ context.Context, id string) (*databroker.Record, error) {
|
||||
db.mu.RLock()
|
||||
defer db.mu.RUnlock()
|
||||
|
||||
record, ok := db.byID.Get(byIDRecord{Record: &databroker.Record{Id: id}}).(byIDRecord)
|
||||
if !ok {
|
||||
return nil, errors.New("not found")
|
||||
|
@ -119,6 +125,9 @@ func (db *DB) Get(_ context.Context, id string) (*databroker.Record, error) {
|
|||
|
||||
// GetAll gets all the records in the db.
|
||||
func (db *DB) GetAll(_ context.Context) ([]*databroker.Record, error) {
|
||||
db.mu.RLock()
|
||||
defer db.mu.RUnlock()
|
||||
|
||||
var records []*databroker.Record
|
||||
db.byID.Ascend(func(item btree.Item) bool {
|
||||
records = append(records, item.(byIDRecord).Record)
|
||||
|
@ -129,6 +138,9 @@ func (db *DB) GetAll(_ context.Context) ([]*databroker.Record, error) {
|
|||
|
||||
// List lists all the changes since the given version.
|
||||
func (db *DB) List(_ context.Context, sinceVersion string) ([]*databroker.Record, error) {
|
||||
db.mu.RLock()
|
||||
defer db.mu.RUnlock()
|
||||
|
||||
var records []*databroker.Record
|
||||
db.byVersion.AscendGreaterOrEqual(byVersionRecord{Record: &databroker.Record{Version: sinceVersion}}, func(i btree.Item) bool {
|
||||
record := i.(byVersionRecord)
|
||||
|
@ -142,6 +154,9 @@ func (db *DB) List(_ context.Context, sinceVersion string) ([]*databroker.Record
|
|||
|
||||
// Put replaces or inserts a record in the db.
|
||||
func (db *DB) Put(_ context.Context, id string, data *anypb.Any) error {
|
||||
db.mu.Lock()
|
||||
defer db.mu.Unlock()
|
||||
|
||||
defer db.onchange.Broadcast()
|
||||
db.replaceOrInsert(id, func(record *databroker.Record) {
|
||||
record.Data = data
|
||||
|
@ -165,9 +180,6 @@ func (db *DB) Watch(ctx context.Context) <-chan struct{} {
|
|||
}
|
||||
|
||||
func (db *DB) replaceOrInsert(id string, f func(record *databroker.Record)) {
|
||||
db.mu.Lock()
|
||||
defer db.mu.Unlock()
|
||||
|
||||
record, ok := db.byID.Get(byIDRecord{Record: &databroker.Record{Id: id}}).(byIDRecord)
|
||||
if ok {
|
||||
db.byVersion.Delete(byVersionRecord(record))
|
||||
|
|
|
@ -8,6 +8,7 @@ import (
|
|||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/sync/errgroup"
|
||||
"google.golang.org/protobuf/types/known/anypb"
|
||||
)
|
||||
|
||||
|
@ -85,3 +86,23 @@ func TestDB(t *testing.T) {
|
|||
assert.Nil(t, record)
|
||||
})
|
||||
}
|
||||
|
||||
func TestConcurrency(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
db := NewDB("example", 2)
|
||||
|
||||
eg, ctx := errgroup.WithContext(ctx)
|
||||
eg.Go(func() error {
|
||||
for i := 0; i < 1000; i++ {
|
||||
_, _ = db.List(ctx, "")
|
||||
}
|
||||
return nil
|
||||
})
|
||||
eg.Go(func() error {
|
||||
for i := 0; i < 1000; i++ {
|
||||
db.Put(ctx, fmt.Sprint(i), new(anypb.Any))
|
||||
}
|
||||
return nil
|
||||
})
|
||||
eg.Wait()
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue