mirror of
https://github.com/pomerium/pomerium.git
synced 2025-08-04 01:09:36 +02:00
Add storage backend interface (#1072)
* pkg: add storage package Which contains storage.Backend interface to initial support for multiple backend storage. * pkg/storage: add inmemory storage * internal/databroker: use storage.Backend interface Instead of implementing multiple databroker server implementation for each kind of storage backend, we use only one databroker server implementation, which is supported multiple storage backends, which satisfy storage.Backend interface.
This commit is contained in:
parent
a70254ab76
commit
2f84dd2aff
8 changed files with 88 additions and 45 deletions
|
@ -1,4 +1,4 @@
|
|||
package memory
|
||||
package databroker
|
||||
|
||||
import "time"
|
||||
|
|
@ -1,141 +0,0 @@
|
|||
package memory
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/golang/protobuf/ptypes"
|
||||
"github.com/google/btree"
|
||||
"google.golang.org/protobuf/proto"
|
||||
"google.golang.org/protobuf/types/known/anypb"
|
||||
|
||||
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
||||
)
|
||||
|
||||
type byIDRecord struct {
|
||||
*databroker.Record
|
||||
}
|
||||
|
||||
func (k byIDRecord) Less(than btree.Item) bool {
|
||||
return k.Id < than.(byIDRecord).Id
|
||||
}
|
||||
|
||||
type byVersionRecord struct {
|
||||
*databroker.Record
|
||||
}
|
||||
|
||||
func (k byVersionRecord) Less(than btree.Item) bool {
|
||||
return k.Version < than.(byVersionRecord).Version
|
||||
}
|
||||
|
||||
// DB is an in-memory database of records using b-trees.
|
||||
type DB struct {
|
||||
recordType string
|
||||
|
||||
lastVersion uint64
|
||||
|
||||
mu sync.Mutex
|
||||
byID *btree.BTree
|
||||
byVersion *btree.BTree
|
||||
deletedIDs []string
|
||||
}
|
||||
|
||||
// NewDB creates a new in-memory database for the given record type.
|
||||
func NewDB(recordType string, btreeDegree int) *DB {
|
||||
return &DB{
|
||||
recordType: recordType,
|
||||
byID: btree.New(btreeDegree),
|
||||
byVersion: btree.New(btreeDegree),
|
||||
}
|
||||
}
|
||||
|
||||
// ClearDeleted clears all the currently deleted records older than the given cutoff.
|
||||
func (db *DB) ClearDeleted(cutoff time.Time) {
|
||||
db.mu.Lock()
|
||||
defer db.mu.Unlock()
|
||||
|
||||
var remaining []string
|
||||
for _, id := range db.deletedIDs {
|
||||
record, _ := db.byID.Get(byIDRecord{Record: &databroker.Record{Id: id}}).(byIDRecord)
|
||||
ts, _ := ptypes.Timestamp(record.DeletedAt)
|
||||
if ts.Before(cutoff) {
|
||||
db.byID.Delete(record)
|
||||
db.byVersion.Delete(byVersionRecord(record))
|
||||
} else {
|
||||
remaining = append(remaining, id)
|
||||
}
|
||||
}
|
||||
db.deletedIDs = remaining
|
||||
}
|
||||
|
||||
// Delete marks a record as deleted.
|
||||
func (db *DB) Delete(id string) {
|
||||
db.replaceOrInsert(id, func(record *databroker.Record) {
|
||||
record.DeletedAt = ptypes.TimestampNow()
|
||||
db.deletedIDs = append(db.deletedIDs, id)
|
||||
})
|
||||
}
|
||||
|
||||
// Get gets a record from the db.
|
||||
func (db *DB) Get(id string) *databroker.Record {
|
||||
record, ok := db.byID.Get(byIDRecord{Record: &databroker.Record{Id: id}}).(byIDRecord)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
return record.Record
|
||||
}
|
||||
|
||||
// GetAll gets all the records in the db.
|
||||
func (db *DB) GetAll() []*databroker.Record {
|
||||
var records []*databroker.Record
|
||||
db.byID.Ascend(func(item btree.Item) bool {
|
||||
records = append(records, item.(byIDRecord).Record)
|
||||
return true
|
||||
})
|
||||
return records
|
||||
}
|
||||
|
||||
// List lists all the changes since the given version.
|
||||
func (db *DB) List(sinceVersion string) []*databroker.Record {
|
||||
var records []*databroker.Record
|
||||
db.byVersion.AscendGreaterOrEqual(byVersionRecord{Record: &databroker.Record{Version: sinceVersion}}, func(i btree.Item) bool {
|
||||
record := i.(byVersionRecord)
|
||||
if record.Version > sinceVersion {
|
||||
records = append(records, record.Record)
|
||||
}
|
||||
return true
|
||||
})
|
||||
return records
|
||||
}
|
||||
|
||||
// Set replaces or inserts a record in the db.
|
||||
func (db *DB) Set(id string, data *anypb.Any) {
|
||||
db.replaceOrInsert(id, func(record *databroker.Record) {
|
||||
record.Data = data
|
||||
})
|
||||
}
|
||||
|
||||
func (db *DB) replaceOrInsert(id string, f func(record *databroker.Record)) {
|
||||
db.mu.Lock()
|
||||
defer db.mu.Unlock()
|
||||
|
||||
record, ok := db.byID.Get(byIDRecord{Record: &databroker.Record{Id: id}}).(byIDRecord)
|
||||
if ok {
|
||||
db.byVersion.Delete(byVersionRecord(record))
|
||||
record.Record = proto.Clone(record.Record).(*databroker.Record)
|
||||
} else {
|
||||
record.Record = new(databroker.Record)
|
||||
}
|
||||
f(record.Record)
|
||||
if record.CreatedAt == nil {
|
||||
record.CreatedAt = ptypes.TimestampNow()
|
||||
}
|
||||
record.ModifiedAt = ptypes.TimestampNow()
|
||||
record.Type = db.recordType
|
||||
record.Id = id
|
||||
record.Version = fmt.Sprintf("%012X", atomic.AddUint64(&db.lastVersion, 1))
|
||||
db.byID.ReplaceOrInsert(record)
|
||||
db.byVersion.ReplaceOrInsert(byVersionRecord(record))
|
||||
}
|
|
@ -1,60 +0,0 @@
|
|||
package memory
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"google.golang.org/protobuf/types/known/anypb"
|
||||
)
|
||||
|
||||
func TestDB(t *testing.T) {
|
||||
db := NewDB("example", 2)
|
||||
t.Run("get missing record", func(t *testing.T) {
|
||||
assert.Nil(t, db.Get("abcd"))
|
||||
})
|
||||
t.Run("get record", func(t *testing.T) {
|
||||
data := new(anypb.Any)
|
||||
db.Set("abcd", data)
|
||||
record := db.Get("abcd")
|
||||
if assert.NotNil(t, record) {
|
||||
assert.NotNil(t, record.CreatedAt)
|
||||
assert.Equal(t, data, record.Data)
|
||||
assert.Nil(t, record.DeletedAt)
|
||||
assert.Equal(t, "abcd", record.Id)
|
||||
assert.NotNil(t, record.ModifiedAt)
|
||||
assert.Equal(t, "example", record.Type)
|
||||
assert.Equal(t, "000000000001", record.Version)
|
||||
}
|
||||
})
|
||||
t.Run("delete record", func(t *testing.T) {
|
||||
db.Delete("abcd")
|
||||
record := db.Get("abcd")
|
||||
if assert.NotNil(t, record) {
|
||||
assert.NotNil(t, record.DeletedAt)
|
||||
}
|
||||
})
|
||||
t.Run("clear deleted", func(t *testing.T) {
|
||||
db.ClearDeleted(time.Now().Add(time.Second))
|
||||
assert.Nil(t, db.Get("abcd"))
|
||||
})
|
||||
t.Run("keep remaining", func(t *testing.T) {
|
||||
data := new(anypb.Any)
|
||||
db.Set("abcd", data)
|
||||
db.Delete("abcd")
|
||||
db.ClearDeleted(time.Now().Add(-10 * time.Second))
|
||||
assert.NotNil(t, db.Get("abcd"))
|
||||
db.ClearDeleted(time.Now().Add(time.Second))
|
||||
})
|
||||
t.Run("list", func(t *testing.T) {
|
||||
for i := 0; i < 10; i++ {
|
||||
data := new(anypb.Any)
|
||||
db.Set(fmt.Sprintf("%02d", i), data)
|
||||
}
|
||||
|
||||
assert.Len(t, db.List(""), 10)
|
||||
assert.Len(t, db.List("00000000000A"), 4)
|
||||
assert.Len(t, db.List("00000000000F"), 0)
|
||||
})
|
||||
}
|
|
@ -1,5 +1,5 @@
|
|||
// Package memory contains an in-memory data broker implementation.
|
||||
package memory
|
||||
// Package databroker contains a data broker implementation.
|
||||
package databroker
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
@ -18,6 +18,8 @@ import (
|
|||
"github.com/pomerium/pomerium/internal/log"
|
||||
"github.com/pomerium/pomerium/internal/telemetry/trace"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
||||
"github.com/pomerium/pomerium/pkg/storage"
|
||||
"github.com/pomerium/pomerium/pkg/storage/inmemory"
|
||||
)
|
||||
|
||||
// Server implements the databroker service using an in memory database.
|
||||
|
@ -27,7 +29,7 @@ type Server struct {
|
|||
log zerolog.Logger
|
||||
|
||||
mu sync.RWMutex
|
||||
byType map[string]*DB
|
||||
byType map[string]storage.Backend
|
||||
onchange *Signal
|
||||
}
|
||||
|
||||
|
@ -39,7 +41,7 @@ func New(options ...ServerOption) *Server {
|
|||
cfg: cfg,
|
||||
log: log.With().Str("service", "databroker").Logger(),
|
||||
|
||||
byType: make(map[string]*DB),
|
||||
byType: make(map[string]storage.Backend),
|
||||
onchange: NewSignal(),
|
||||
}
|
||||
go func() {
|
||||
|
@ -55,7 +57,7 @@ func New(options ...ServerOption) *Server {
|
|||
srv.mu.RUnlock()
|
||||
|
||||
for _, recordType := range recordTypes {
|
||||
srv.getDB(recordType).ClearDeleted(time.Now().Add(-cfg.deletePermanentlyAfter))
|
||||
srv.getDB(recordType).ClearDeleted(context.Background(), time.Now().Add(-cfg.deletePermanentlyAfter))
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
@ -73,7 +75,9 @@ func (srv *Server) Delete(ctx context.Context, req *databroker.DeleteRequest) (*
|
|||
|
||||
defer srv.onchange.Broadcast()
|
||||
|
||||
srv.getDB(req.GetType()).Delete(req.GetId())
|
||||
if err := srv.getDB(req.GetType()).Delete(ctx, req.GetId()); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return new(empty.Empty), nil
|
||||
}
|
||||
|
@ -87,7 +91,7 @@ func (srv *Server) Get(ctx context.Context, req *databroker.GetRequest) (*databr
|
|||
Str("id", req.GetId()).
|
||||
Msg("get")
|
||||
|
||||
record := srv.getDB(req.GetType()).Get(req.GetId())
|
||||
record := srv.getDB(req.GetType()).Get(ctx, req.GetId())
|
||||
if record == nil {
|
||||
return nil, status.Error(codes.NotFound, "record not found")
|
||||
}
|
||||
|
@ -102,7 +106,7 @@ func (srv *Server) GetAll(ctx context.Context, req *databroker.GetAllRequest) (*
|
|||
Str("type", req.GetType()).
|
||||
Msg("get all")
|
||||
|
||||
records := srv.getDB(req.GetType()).GetAll()
|
||||
records := srv.getDB(req.GetType()).GetAll(ctx)
|
||||
var recordVersion string
|
||||
for _, record := range records {
|
||||
if record.GetVersion() > recordVersion {
|
||||
|
@ -128,8 +132,10 @@ func (srv *Server) Set(ctx context.Context, req *databroker.SetRequest) (*databr
|
|||
defer srv.onchange.Broadcast()
|
||||
|
||||
db := srv.getDB(req.GetType())
|
||||
db.Set(req.GetId(), req.GetData())
|
||||
record := db.Get(req.GetId())
|
||||
if err := db.Put(ctx, req.GetId(), req.GetData()); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
record := db.Get(ctx, req.GetId())
|
||||
|
||||
return &databroker.SetResponse{
|
||||
Record: record,
|
||||
|
@ -158,7 +164,7 @@ func (srv *Server) Sync(req *databroker.SyncRequest, stream databroker.DataBroke
|
|||
ch := srv.onchange.Bind()
|
||||
defer srv.onchange.Unbind(ch)
|
||||
for {
|
||||
updated := db.List(recordVersion)
|
||||
updated := db.List(context.Background(), recordVersion)
|
||||
|
||||
if len(updated) > 0 {
|
||||
sort.Slice(updated, func(i, j int) bool {
|
||||
|
@ -232,7 +238,7 @@ func (srv *Server) SyncTypes(req *emptypb.Empty, stream databroker.DataBrokerSer
|
|||
}
|
||||
}
|
||||
|
||||
func (srv *Server) getDB(recordType string) *DB {
|
||||
func (srv *Server) getDB(recordType string) storage.Backend {
|
||||
// double-checked locking:
|
||||
// first try the read lock, then re-try with the write lock, and finally create a new db if nil
|
||||
srv.mu.RLock()
|
||||
|
@ -242,7 +248,7 @@ func (srv *Server) getDB(recordType string) *DB {
|
|||
srv.mu.Lock()
|
||||
db = srv.byType[recordType]
|
||||
if db == nil {
|
||||
db = NewDB(recordType, srv.cfg.btreeDegree)
|
||||
db = inmemory.NewDB(recordType, srv.cfg.btreeDegree)
|
||||
srv.byType[recordType] = db
|
||||
}
|
||||
srv.mu.Unlock()
|
|
@ -1,4 +1,4 @@
|
|||
package memory
|
||||
package databroker
|
||||
|
||||
import "sync"
|
||||
|
Loading…
Add table
Add a link
Reference in a new issue