Add storage backend interface (#1072)

* pkg: add storage package

Which contains storage.Backend interface to initial support for multiple
backend storage.

* pkg/storage: add inmemory storage

* internal/databroker: use storage.Backend interface

Instead of implementing multiple databroker server implementation for
each kind of storage backend, we use only one databroker server
implementation, which is supported multiple storage backends, which
satisfy storage.Backend interface.
This commit is contained in:
Cuong Manh Le 2020-07-15 09:42:01 +07:00 committed by GitHub
parent a70254ab76
commit 2f84dd2aff
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
8 changed files with 88 additions and 45 deletions

View file

@ -1,4 +1,4 @@
package memory
package databroker
import "time"

View file

@ -1,141 +0,0 @@
package memory
import (
"fmt"
"sync"
"sync/atomic"
"time"
"github.com/golang/protobuf/ptypes"
"github.com/google/btree"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/types/known/anypb"
"github.com/pomerium/pomerium/pkg/grpc/databroker"
)
type byIDRecord struct {
*databroker.Record
}
func (k byIDRecord) Less(than btree.Item) bool {
return k.Id < than.(byIDRecord).Id
}
type byVersionRecord struct {
*databroker.Record
}
func (k byVersionRecord) Less(than btree.Item) bool {
return k.Version < than.(byVersionRecord).Version
}
// DB is an in-memory database of records using b-trees.
type DB struct {
recordType string
lastVersion uint64
mu sync.Mutex
byID *btree.BTree
byVersion *btree.BTree
deletedIDs []string
}
// NewDB creates a new in-memory database for the given record type.
func NewDB(recordType string, btreeDegree int) *DB {
return &DB{
recordType: recordType,
byID: btree.New(btreeDegree),
byVersion: btree.New(btreeDegree),
}
}
// ClearDeleted clears all the currently deleted records older than the given cutoff.
func (db *DB) ClearDeleted(cutoff time.Time) {
db.mu.Lock()
defer db.mu.Unlock()
var remaining []string
for _, id := range db.deletedIDs {
record, _ := db.byID.Get(byIDRecord{Record: &databroker.Record{Id: id}}).(byIDRecord)
ts, _ := ptypes.Timestamp(record.DeletedAt)
if ts.Before(cutoff) {
db.byID.Delete(record)
db.byVersion.Delete(byVersionRecord(record))
} else {
remaining = append(remaining, id)
}
}
db.deletedIDs = remaining
}
// Delete marks a record as deleted.
func (db *DB) Delete(id string) {
db.replaceOrInsert(id, func(record *databroker.Record) {
record.DeletedAt = ptypes.TimestampNow()
db.deletedIDs = append(db.deletedIDs, id)
})
}
// Get gets a record from the db.
func (db *DB) Get(id string) *databroker.Record {
record, ok := db.byID.Get(byIDRecord{Record: &databroker.Record{Id: id}}).(byIDRecord)
if !ok {
return nil
}
return record.Record
}
// GetAll gets all the records in the db.
func (db *DB) GetAll() []*databroker.Record {
var records []*databroker.Record
db.byID.Ascend(func(item btree.Item) bool {
records = append(records, item.(byIDRecord).Record)
return true
})
return records
}
// List lists all the changes since the given version.
func (db *DB) List(sinceVersion string) []*databroker.Record {
var records []*databroker.Record
db.byVersion.AscendGreaterOrEqual(byVersionRecord{Record: &databroker.Record{Version: sinceVersion}}, func(i btree.Item) bool {
record := i.(byVersionRecord)
if record.Version > sinceVersion {
records = append(records, record.Record)
}
return true
})
return records
}
// Set replaces or inserts a record in the db.
func (db *DB) Set(id string, data *anypb.Any) {
db.replaceOrInsert(id, func(record *databroker.Record) {
record.Data = data
})
}
func (db *DB) replaceOrInsert(id string, f func(record *databroker.Record)) {
db.mu.Lock()
defer db.mu.Unlock()
record, ok := db.byID.Get(byIDRecord{Record: &databroker.Record{Id: id}}).(byIDRecord)
if ok {
db.byVersion.Delete(byVersionRecord(record))
record.Record = proto.Clone(record.Record).(*databroker.Record)
} else {
record.Record = new(databroker.Record)
}
f(record.Record)
if record.CreatedAt == nil {
record.CreatedAt = ptypes.TimestampNow()
}
record.ModifiedAt = ptypes.TimestampNow()
record.Type = db.recordType
record.Id = id
record.Version = fmt.Sprintf("%012X", atomic.AddUint64(&db.lastVersion, 1))
db.byID.ReplaceOrInsert(record)
db.byVersion.ReplaceOrInsert(byVersionRecord(record))
}

View file

@ -1,60 +0,0 @@
package memory
import (
"fmt"
"testing"
"time"
"github.com/stretchr/testify/assert"
"google.golang.org/protobuf/types/known/anypb"
)
func TestDB(t *testing.T) {
db := NewDB("example", 2)
t.Run("get missing record", func(t *testing.T) {
assert.Nil(t, db.Get("abcd"))
})
t.Run("get record", func(t *testing.T) {
data := new(anypb.Any)
db.Set("abcd", data)
record := db.Get("abcd")
if assert.NotNil(t, record) {
assert.NotNil(t, record.CreatedAt)
assert.Equal(t, data, record.Data)
assert.Nil(t, record.DeletedAt)
assert.Equal(t, "abcd", record.Id)
assert.NotNil(t, record.ModifiedAt)
assert.Equal(t, "example", record.Type)
assert.Equal(t, "000000000001", record.Version)
}
})
t.Run("delete record", func(t *testing.T) {
db.Delete("abcd")
record := db.Get("abcd")
if assert.NotNil(t, record) {
assert.NotNil(t, record.DeletedAt)
}
})
t.Run("clear deleted", func(t *testing.T) {
db.ClearDeleted(time.Now().Add(time.Second))
assert.Nil(t, db.Get("abcd"))
})
t.Run("keep remaining", func(t *testing.T) {
data := new(anypb.Any)
db.Set("abcd", data)
db.Delete("abcd")
db.ClearDeleted(time.Now().Add(-10 * time.Second))
assert.NotNil(t, db.Get("abcd"))
db.ClearDeleted(time.Now().Add(time.Second))
})
t.Run("list", func(t *testing.T) {
for i := 0; i < 10; i++ {
data := new(anypb.Any)
db.Set(fmt.Sprintf("%02d", i), data)
}
assert.Len(t, db.List(""), 10)
assert.Len(t, db.List("00000000000A"), 4)
assert.Len(t, db.List("00000000000F"), 0)
})
}

View file

@ -1,5 +1,5 @@
// Package memory contains an in-memory data broker implementation.
package memory
// Package databroker contains a data broker implementation.
package databroker
import (
"context"
@ -18,6 +18,8 @@ import (
"github.com/pomerium/pomerium/internal/log"
"github.com/pomerium/pomerium/internal/telemetry/trace"
"github.com/pomerium/pomerium/pkg/grpc/databroker"
"github.com/pomerium/pomerium/pkg/storage"
"github.com/pomerium/pomerium/pkg/storage/inmemory"
)
// Server implements the databroker service using an in memory database.
@ -27,7 +29,7 @@ type Server struct {
log zerolog.Logger
mu sync.RWMutex
byType map[string]*DB
byType map[string]storage.Backend
onchange *Signal
}
@ -39,7 +41,7 @@ func New(options ...ServerOption) *Server {
cfg: cfg,
log: log.With().Str("service", "databroker").Logger(),
byType: make(map[string]*DB),
byType: make(map[string]storage.Backend),
onchange: NewSignal(),
}
go func() {
@ -55,7 +57,7 @@ func New(options ...ServerOption) *Server {
srv.mu.RUnlock()
for _, recordType := range recordTypes {
srv.getDB(recordType).ClearDeleted(time.Now().Add(-cfg.deletePermanentlyAfter))
srv.getDB(recordType).ClearDeleted(context.Background(), time.Now().Add(-cfg.deletePermanentlyAfter))
}
}
}()
@ -73,7 +75,9 @@ func (srv *Server) Delete(ctx context.Context, req *databroker.DeleteRequest) (*
defer srv.onchange.Broadcast()
srv.getDB(req.GetType()).Delete(req.GetId())
if err := srv.getDB(req.GetType()).Delete(ctx, req.GetId()); err != nil {
return nil, err
}
return new(empty.Empty), nil
}
@ -87,7 +91,7 @@ func (srv *Server) Get(ctx context.Context, req *databroker.GetRequest) (*databr
Str("id", req.GetId()).
Msg("get")
record := srv.getDB(req.GetType()).Get(req.GetId())
record := srv.getDB(req.GetType()).Get(ctx, req.GetId())
if record == nil {
return nil, status.Error(codes.NotFound, "record not found")
}
@ -102,7 +106,7 @@ func (srv *Server) GetAll(ctx context.Context, req *databroker.GetAllRequest) (*
Str("type", req.GetType()).
Msg("get all")
records := srv.getDB(req.GetType()).GetAll()
records := srv.getDB(req.GetType()).GetAll(ctx)
var recordVersion string
for _, record := range records {
if record.GetVersion() > recordVersion {
@ -128,8 +132,10 @@ func (srv *Server) Set(ctx context.Context, req *databroker.SetRequest) (*databr
defer srv.onchange.Broadcast()
db := srv.getDB(req.GetType())
db.Set(req.GetId(), req.GetData())
record := db.Get(req.GetId())
if err := db.Put(ctx, req.GetId(), req.GetData()); err != nil {
return nil, err
}
record := db.Get(ctx, req.GetId())
return &databroker.SetResponse{
Record: record,
@ -158,7 +164,7 @@ func (srv *Server) Sync(req *databroker.SyncRequest, stream databroker.DataBroke
ch := srv.onchange.Bind()
defer srv.onchange.Unbind(ch)
for {
updated := db.List(recordVersion)
updated := db.List(context.Background(), recordVersion)
if len(updated) > 0 {
sort.Slice(updated, func(i, j int) bool {
@ -232,7 +238,7 @@ func (srv *Server) SyncTypes(req *emptypb.Empty, stream databroker.DataBrokerSer
}
}
func (srv *Server) getDB(recordType string) *DB {
func (srv *Server) getDB(recordType string) storage.Backend {
// double-checked locking:
// first try the read lock, then re-try with the write lock, and finally create a new db if nil
srv.mu.RLock()
@ -242,7 +248,7 @@ func (srv *Server) getDB(recordType string) *DB {
srv.mu.Lock()
db = srv.byType[recordType]
if db == nil {
db = NewDB(recordType, srv.cfg.btreeDegree)
db = inmemory.NewDB(recordType, srv.cfg.btreeDegree)
srv.byType[recordType] = db
}
srv.mu.Unlock()

View file

@ -1,4 +1,4 @@
package memory
package databroker
import "sync"