Add storage backend interface (#1072)

* pkg: add storage package

Which contains storage.Backend interface to initial support for multiple
backend storage.

* pkg/storage: add inmemory storage

* internal/databroker: use storage.Backend interface

Instead of implementing multiple databroker server implementation for
each kind of storage backend, we use only one databroker server
implementation, which is supported multiple storage backends, which
satisfy storage.Backend interface.
This commit is contained in:
Cuong Manh Le 2020-07-15 09:42:01 +07:00 committed by GitHub
parent a70254ab76
commit 2f84dd2aff
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
8 changed files with 88 additions and 45 deletions

7
cache/databroker.go vendored
View file

@ -3,7 +3,7 @@ package cache
import ( import (
"google.golang.org/grpc" "google.golang.org/grpc"
"github.com/pomerium/pomerium/internal/databroker/memory" internal_databroker "github.com/pomerium/pomerium/internal/databroker"
"github.com/pomerium/pomerium/pkg/grpc/databroker" "github.com/pomerium/pomerium/pkg/grpc/databroker"
) )
@ -14,10 +14,7 @@ type DataBrokerServer struct {
// NewDataBrokerServer creates a new databroker service server. // NewDataBrokerServer creates a new databroker service server.
func NewDataBrokerServer(grpcServer *grpc.Server) *DataBrokerServer { func NewDataBrokerServer(grpcServer *grpc.Server) *DataBrokerServer {
srv := &DataBrokerServer{ srv := &DataBrokerServer{DataBrokerServiceServer: internal_databroker.New()}
// just wrap the in-memory data broker server
DataBrokerServiceServer: memory.New(),
}
databroker.RegisterDataBrokerServiceServer(grpcServer, srv) databroker.RegisterDataBrokerServiceServer(grpcServer, srv)
return srv return srv
} }

1
go.sum
View file

@ -737,6 +737,7 @@ google.golang.org/api v0.18.0/go.mod h1:BwFmGc8tA3vsd7r/7kR8DY7iEEGSU04BFxCo5jP/
google.golang.org/api v0.20.0/go.mod h1:BwFmGc8tA3vsd7r/7kR8DY7iEEGSU04BFxCo5jP/sfE= google.golang.org/api v0.20.0/go.mod h1:BwFmGc8tA3vsd7r/7kR8DY7iEEGSU04BFxCo5jP/sfE=
google.golang.org/api v0.28.0 h1:jMF5hhVfMkTZwHW1SDpKq5CkgWLXOb31Foaca9Zr3oM= google.golang.org/api v0.28.0 h1:jMF5hhVfMkTZwHW1SDpKq5CkgWLXOb31Foaca9Zr3oM=
google.golang.org/api v0.28.0/go.mod h1:lIXQywCXRcnZPGlsd8NbLnOjtAoL6em04bJ9+z0MncE= google.golang.org/api v0.28.0/go.mod h1:lIXQywCXRcnZPGlsd8NbLnOjtAoL6em04bJ9+z0MncE=
google.golang.org/api v0.29.0 h1:BaiDisFir8O4IJxvAabCGGkQ6yCJegNQqSVoYUNAnbk=
google.golang.org/api v0.29.0/go.mod h1:Lcubydp8VUV7KeIHD9z2Bys/sm/vGKnG1UHuDBSrHWM= google.golang.org/api v0.29.0/go.mod h1:Lcubydp8VUV7KeIHD9z2Bys/sm/vGKnG1UHuDBSrHWM=
google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM= google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM=
google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4=

View file

@ -1,4 +1,4 @@
package memory package databroker
import "time" import "time"

View file

@ -1,5 +1,5 @@
// Package memory contains an in-memory data broker implementation. // Package databroker contains a data broker implementation.
package memory package databroker
import ( import (
"context" "context"
@ -18,6 +18,8 @@ import (
"github.com/pomerium/pomerium/internal/log" "github.com/pomerium/pomerium/internal/log"
"github.com/pomerium/pomerium/internal/telemetry/trace" "github.com/pomerium/pomerium/internal/telemetry/trace"
"github.com/pomerium/pomerium/pkg/grpc/databroker" "github.com/pomerium/pomerium/pkg/grpc/databroker"
"github.com/pomerium/pomerium/pkg/storage"
"github.com/pomerium/pomerium/pkg/storage/inmemory"
) )
// Server implements the databroker service using an in memory database. // Server implements the databroker service using an in memory database.
@ -27,7 +29,7 @@ type Server struct {
log zerolog.Logger log zerolog.Logger
mu sync.RWMutex mu sync.RWMutex
byType map[string]*DB byType map[string]storage.Backend
onchange *Signal onchange *Signal
} }
@ -39,7 +41,7 @@ func New(options ...ServerOption) *Server {
cfg: cfg, cfg: cfg,
log: log.With().Str("service", "databroker").Logger(), log: log.With().Str("service", "databroker").Logger(),
byType: make(map[string]*DB), byType: make(map[string]storage.Backend),
onchange: NewSignal(), onchange: NewSignal(),
} }
go func() { go func() {
@ -55,7 +57,7 @@ func New(options ...ServerOption) *Server {
srv.mu.RUnlock() srv.mu.RUnlock()
for _, recordType := range recordTypes { for _, recordType := range recordTypes {
srv.getDB(recordType).ClearDeleted(time.Now().Add(-cfg.deletePermanentlyAfter)) srv.getDB(recordType).ClearDeleted(context.Background(), time.Now().Add(-cfg.deletePermanentlyAfter))
} }
} }
}() }()
@ -73,7 +75,9 @@ func (srv *Server) Delete(ctx context.Context, req *databroker.DeleteRequest) (*
defer srv.onchange.Broadcast() defer srv.onchange.Broadcast()
srv.getDB(req.GetType()).Delete(req.GetId()) if err := srv.getDB(req.GetType()).Delete(ctx, req.GetId()); err != nil {
return nil, err
}
return new(empty.Empty), nil return new(empty.Empty), nil
} }
@ -87,7 +91,7 @@ func (srv *Server) Get(ctx context.Context, req *databroker.GetRequest) (*databr
Str("id", req.GetId()). Str("id", req.GetId()).
Msg("get") Msg("get")
record := srv.getDB(req.GetType()).Get(req.GetId()) record := srv.getDB(req.GetType()).Get(ctx, req.GetId())
if record == nil { if record == nil {
return nil, status.Error(codes.NotFound, "record not found") return nil, status.Error(codes.NotFound, "record not found")
} }
@ -102,7 +106,7 @@ func (srv *Server) GetAll(ctx context.Context, req *databroker.GetAllRequest) (*
Str("type", req.GetType()). Str("type", req.GetType()).
Msg("get all") Msg("get all")
records := srv.getDB(req.GetType()).GetAll() records := srv.getDB(req.GetType()).GetAll(ctx)
var recordVersion string var recordVersion string
for _, record := range records { for _, record := range records {
if record.GetVersion() > recordVersion { if record.GetVersion() > recordVersion {
@ -128,8 +132,10 @@ func (srv *Server) Set(ctx context.Context, req *databroker.SetRequest) (*databr
defer srv.onchange.Broadcast() defer srv.onchange.Broadcast()
db := srv.getDB(req.GetType()) db := srv.getDB(req.GetType())
db.Set(req.GetId(), req.GetData()) if err := db.Put(ctx, req.GetId(), req.GetData()); err != nil {
record := db.Get(req.GetId()) return nil, err
}
record := db.Get(ctx, req.GetId())
return &databroker.SetResponse{ return &databroker.SetResponse{
Record: record, Record: record,
@ -158,7 +164,7 @@ func (srv *Server) Sync(req *databroker.SyncRequest, stream databroker.DataBroke
ch := srv.onchange.Bind() ch := srv.onchange.Bind()
defer srv.onchange.Unbind(ch) defer srv.onchange.Unbind(ch)
for { for {
updated := db.List(recordVersion) updated := db.List(context.Background(), recordVersion)
if len(updated) > 0 { if len(updated) > 0 {
sort.Slice(updated, func(i, j int) bool { sort.Slice(updated, func(i, j int) bool {
@ -232,7 +238,7 @@ func (srv *Server) SyncTypes(req *emptypb.Empty, stream databroker.DataBrokerSer
} }
} }
func (srv *Server) getDB(recordType string) *DB { func (srv *Server) getDB(recordType string) storage.Backend {
// double-checked locking: // double-checked locking:
// first try the read lock, then re-try with the write lock, and finally create a new db if nil // first try the read lock, then re-try with the write lock, and finally create a new db if nil
srv.mu.RLock() srv.mu.RLock()
@ -242,7 +248,7 @@ func (srv *Server) getDB(recordType string) *DB {
srv.mu.Lock() srv.mu.Lock()
db = srv.byType[recordType] db = srv.byType[recordType]
if db == nil { if db == nil {
db = NewDB(recordType, srv.cfg.btreeDegree) db = inmemory.NewDB(recordType, srv.cfg.btreeDegree)
srv.byType[recordType] = db srv.byType[recordType] = db
} }
srv.mu.Unlock() srv.mu.Unlock()

View file

@ -1,4 +1,4 @@
package memory package databroker
import "sync" import "sync"

View file

@ -1,6 +1,7 @@
package memory package inmemory
import ( import (
"context"
"fmt" "fmt"
"sync" "sync"
"sync/atomic" "sync/atomic"
@ -12,8 +13,11 @@ import (
"google.golang.org/protobuf/types/known/anypb" "google.golang.org/protobuf/types/known/anypb"
"github.com/pomerium/pomerium/pkg/grpc/databroker" "github.com/pomerium/pomerium/pkg/grpc/databroker"
"github.com/pomerium/pomerium/pkg/storage"
) )
var _ storage.Backend = (*DB)(nil)
type byIDRecord struct { type byIDRecord struct {
*databroker.Record *databroker.Record
} }
@ -52,7 +56,7 @@ func NewDB(recordType string, btreeDegree int) *DB {
} }
// ClearDeleted clears all the currently deleted records older than the given cutoff. // ClearDeleted clears all the currently deleted records older than the given cutoff.
func (db *DB) ClearDeleted(cutoff time.Time) { func (db *DB) ClearDeleted(_ context.Context, cutoff time.Time) {
db.mu.Lock() db.mu.Lock()
defer db.mu.Unlock() defer db.mu.Unlock()
@ -71,15 +75,16 @@ func (db *DB) ClearDeleted(cutoff time.Time) {
} }
// Delete marks a record as deleted. // Delete marks a record as deleted.
func (db *DB) Delete(id string) { func (db *DB) Delete(_ context.Context, id string) error {
db.replaceOrInsert(id, func(record *databroker.Record) { db.replaceOrInsert(id, func(record *databroker.Record) {
record.DeletedAt = ptypes.TimestampNow() record.DeletedAt = ptypes.TimestampNow()
db.deletedIDs = append(db.deletedIDs, id) db.deletedIDs = append(db.deletedIDs, id)
}) })
return nil
} }
// Get gets a record from the db. // Get gets a record from the db.
func (db *DB) Get(id string) *databroker.Record { func (db *DB) Get(_ context.Context, id string) *databroker.Record {
record, ok := db.byID.Get(byIDRecord{Record: &databroker.Record{Id: id}}).(byIDRecord) record, ok := db.byID.Get(byIDRecord{Record: &databroker.Record{Id: id}}).(byIDRecord)
if !ok { if !ok {
return nil return nil
@ -88,7 +93,7 @@ func (db *DB) Get(id string) *databroker.Record {
} }
// GetAll gets all the records in the db. // GetAll gets all the records in the db.
func (db *DB) GetAll() []*databroker.Record { func (db *DB) GetAll(_ context.Context) []*databroker.Record {
var records []*databroker.Record var records []*databroker.Record
db.byID.Ascend(func(item btree.Item) bool { db.byID.Ascend(func(item btree.Item) bool {
records = append(records, item.(byIDRecord).Record) records = append(records, item.(byIDRecord).Record)
@ -98,7 +103,7 @@ func (db *DB) GetAll() []*databroker.Record {
} }
// List lists all the changes since the given version. // List lists all the changes since the given version.
func (db *DB) List(sinceVersion string) []*databroker.Record { func (db *DB) List(_ context.Context, sinceVersion string) []*databroker.Record {
var records []*databroker.Record var records []*databroker.Record
db.byVersion.AscendGreaterOrEqual(byVersionRecord{Record: &databroker.Record{Version: sinceVersion}}, func(i btree.Item) bool { db.byVersion.AscendGreaterOrEqual(byVersionRecord{Record: &databroker.Record{Version: sinceVersion}}, func(i btree.Item) bool {
record := i.(byVersionRecord) record := i.(byVersionRecord)
@ -110,11 +115,12 @@ func (db *DB) List(sinceVersion string) []*databroker.Record {
return records return records
} }
// Set replaces or inserts a record in the db. // Put replaces or inserts a record in the db.
func (db *DB) Set(id string, data *anypb.Any) { func (db *DB) Put(_ context.Context, id string, data *anypb.Any) error {
db.replaceOrInsert(id, func(record *databroker.Record) { db.replaceOrInsert(id, func(record *databroker.Record) {
record.Data = data record.Data = data
}) })
return nil
} }
func (db *DB) replaceOrInsert(id string, f func(record *databroker.Record)) { func (db *DB) replaceOrInsert(id string, f func(record *databroker.Record)) {

View file

@ -1,6 +1,7 @@
package memory package inmemory
import ( import (
"context"
"fmt" "fmt"
"testing" "testing"
"time" "time"
@ -10,14 +11,15 @@ import (
) )
func TestDB(t *testing.T) { func TestDB(t *testing.T) {
ctx := context.Background()
db := NewDB("example", 2) db := NewDB("example", 2)
t.Run("get missing record", func(t *testing.T) { t.Run("get missing record", func(t *testing.T) {
assert.Nil(t, db.Get("abcd")) assert.Nil(t, db.Get(ctx, "abcd"))
}) })
t.Run("get record", func(t *testing.T) { t.Run("get record", func(t *testing.T) {
data := new(anypb.Any) data := new(anypb.Any)
db.Set("abcd", data) assert.NoError(t, db.Put(ctx, "abcd", data))
record := db.Get("abcd") record := db.Get(ctx, "abcd")
if assert.NotNil(t, record) { if assert.NotNil(t, record) {
assert.NotNil(t, record.CreatedAt) assert.NotNil(t, record.CreatedAt)
assert.Equal(t, data, record.Data) assert.Equal(t, data, record.Data)
@ -29,32 +31,32 @@ func TestDB(t *testing.T) {
} }
}) })
t.Run("delete record", func(t *testing.T) { t.Run("delete record", func(t *testing.T) {
db.Delete("abcd") assert.NoError(t, db.Delete(ctx, "abcd"))
record := db.Get("abcd") record := db.Get(ctx, "abcd")
if assert.NotNil(t, record) { if assert.NotNil(t, record) {
assert.NotNil(t, record.DeletedAt) assert.NotNil(t, record.DeletedAt)
} }
}) })
t.Run("clear deleted", func(t *testing.T) { t.Run("clear deleted", func(t *testing.T) {
db.ClearDeleted(time.Now().Add(time.Second)) db.ClearDeleted(ctx, time.Now().Add(time.Second))
assert.Nil(t, db.Get("abcd")) assert.Nil(t, db.Get(ctx, "abcd"))
}) })
t.Run("keep remaining", func(t *testing.T) { t.Run("keep remaining", func(t *testing.T) {
data := new(anypb.Any) data := new(anypb.Any)
db.Set("abcd", data) assert.NoError(t, db.Put(ctx, "abcd", data))
db.Delete("abcd") assert.NoError(t, db.Delete(ctx, "abcd"))
db.ClearDeleted(time.Now().Add(-10 * time.Second)) db.ClearDeleted(ctx, time.Now().Add(-10*time.Second))
assert.NotNil(t, db.Get("abcd")) assert.NotNil(t, db.Get(ctx, "abcd"))
db.ClearDeleted(time.Now().Add(time.Second)) db.ClearDeleted(ctx, time.Now().Add(time.Second))
}) })
t.Run("list", func(t *testing.T) { t.Run("list", func(t *testing.T) {
for i := 0; i < 10; i++ { for i := 0; i < 10; i++ {
data := new(anypb.Any) data := new(anypb.Any)
db.Set(fmt.Sprintf("%02d", i), data) assert.NoError(t, db.Put(ctx, fmt.Sprintf("%02d", i), data))
} }
assert.Len(t, db.List(""), 10) assert.Len(t, db.List(ctx, ""), 10)
assert.Len(t, db.List("00000000000A"), 4) assert.Len(t, db.List(ctx, "00000000000A"), 4)
assert.Len(t, db.List("00000000000F"), 0) assert.Len(t, db.List(ctx, "00000000000F"), 0)
}) })
} }

31
pkg/storage/storage.go Normal file
View file

@ -0,0 +1,31 @@
package storage
import (
"context"
"time"
"google.golang.org/protobuf/types/known/anypb"
"github.com/pomerium/pomerium/pkg/grpc/databroker"
)
// Backend is the interface required for a storage backend.
type Backend interface {
// Put is used to insert or update a record.
Put(ctx context.Context, id string, data *anypb.Any) error
// Get is used to retrieve a record.
Get(ctx context.Context, id string) *databroker.Record
// GetAll is used to retrieve all the records.
GetAll(ctx context.Context) []*databroker.Record
// List is used to retrieve all the records since a version.
List(ctx context.Context, sinceVersion string) []*databroker.Record
// Delete is used to mark a record as deleted.
Delete(ctx context.Context, id string) error
// ClearDeleted is used clear marked delete records.
ClearDeleted(ctx context.Context, cutoff time.Time)
}