Add storage backend interface (#1072)

* pkg: add storage package

Which contains storage.Backend interface to initial support for multiple
backend storage.

* pkg/storage: add inmemory storage

* internal/databroker: use storage.Backend interface

Instead of implementing multiple databroker server implementation for
each kind of storage backend, we use only one databroker server
implementation, which is supported multiple storage backends, which
satisfy storage.Backend interface.
This commit is contained in:
Cuong Manh Le 2020-07-15 09:42:01 +07:00 committed by GitHub
parent a70254ab76
commit 2f84dd2aff
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
8 changed files with 88 additions and 45 deletions

7
cache/databroker.go vendored
View file

@ -3,7 +3,7 @@ package cache
import (
"google.golang.org/grpc"
"github.com/pomerium/pomerium/internal/databroker/memory"
internal_databroker "github.com/pomerium/pomerium/internal/databroker"
"github.com/pomerium/pomerium/pkg/grpc/databroker"
)
@ -14,10 +14,7 @@ type DataBrokerServer struct {
// NewDataBrokerServer creates a new databroker service server.
func NewDataBrokerServer(grpcServer *grpc.Server) *DataBrokerServer {
srv := &DataBrokerServer{
// just wrap the in-memory data broker server
DataBrokerServiceServer: memory.New(),
}
srv := &DataBrokerServer{DataBrokerServiceServer: internal_databroker.New()}
databroker.RegisterDataBrokerServiceServer(grpcServer, srv)
return srv
}

1
go.sum
View file

@ -737,6 +737,7 @@ google.golang.org/api v0.18.0/go.mod h1:BwFmGc8tA3vsd7r/7kR8DY7iEEGSU04BFxCo5jP/
google.golang.org/api v0.20.0/go.mod h1:BwFmGc8tA3vsd7r/7kR8DY7iEEGSU04BFxCo5jP/sfE=
google.golang.org/api v0.28.0 h1:jMF5hhVfMkTZwHW1SDpKq5CkgWLXOb31Foaca9Zr3oM=
google.golang.org/api v0.28.0/go.mod h1:lIXQywCXRcnZPGlsd8NbLnOjtAoL6em04bJ9+z0MncE=
google.golang.org/api v0.29.0 h1:BaiDisFir8O4IJxvAabCGGkQ6yCJegNQqSVoYUNAnbk=
google.golang.org/api v0.29.0/go.mod h1:Lcubydp8VUV7KeIHD9z2Bys/sm/vGKnG1UHuDBSrHWM=
google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM=
google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4=

View file

@ -1,4 +1,4 @@
package memory
package databroker
import "time"

View file

@ -1,5 +1,5 @@
// Package memory contains an in-memory data broker implementation.
package memory
// Package databroker contains a data broker implementation.
package databroker
import (
"context"
@ -18,6 +18,8 @@ import (
"github.com/pomerium/pomerium/internal/log"
"github.com/pomerium/pomerium/internal/telemetry/trace"
"github.com/pomerium/pomerium/pkg/grpc/databroker"
"github.com/pomerium/pomerium/pkg/storage"
"github.com/pomerium/pomerium/pkg/storage/inmemory"
)
// Server implements the databroker service using an in memory database.
@ -27,7 +29,7 @@ type Server struct {
log zerolog.Logger
mu sync.RWMutex
byType map[string]*DB
byType map[string]storage.Backend
onchange *Signal
}
@ -39,7 +41,7 @@ func New(options ...ServerOption) *Server {
cfg: cfg,
log: log.With().Str("service", "databroker").Logger(),
byType: make(map[string]*DB),
byType: make(map[string]storage.Backend),
onchange: NewSignal(),
}
go func() {
@ -55,7 +57,7 @@ func New(options ...ServerOption) *Server {
srv.mu.RUnlock()
for _, recordType := range recordTypes {
srv.getDB(recordType).ClearDeleted(time.Now().Add(-cfg.deletePermanentlyAfter))
srv.getDB(recordType).ClearDeleted(context.Background(), time.Now().Add(-cfg.deletePermanentlyAfter))
}
}
}()
@ -73,7 +75,9 @@ func (srv *Server) Delete(ctx context.Context, req *databroker.DeleteRequest) (*
defer srv.onchange.Broadcast()
srv.getDB(req.GetType()).Delete(req.GetId())
if err := srv.getDB(req.GetType()).Delete(ctx, req.GetId()); err != nil {
return nil, err
}
return new(empty.Empty), nil
}
@ -87,7 +91,7 @@ func (srv *Server) Get(ctx context.Context, req *databroker.GetRequest) (*databr
Str("id", req.GetId()).
Msg("get")
record := srv.getDB(req.GetType()).Get(req.GetId())
record := srv.getDB(req.GetType()).Get(ctx, req.GetId())
if record == nil {
return nil, status.Error(codes.NotFound, "record not found")
}
@ -102,7 +106,7 @@ func (srv *Server) GetAll(ctx context.Context, req *databroker.GetAllRequest) (*
Str("type", req.GetType()).
Msg("get all")
records := srv.getDB(req.GetType()).GetAll()
records := srv.getDB(req.GetType()).GetAll(ctx)
var recordVersion string
for _, record := range records {
if record.GetVersion() > recordVersion {
@ -128,8 +132,10 @@ func (srv *Server) Set(ctx context.Context, req *databroker.SetRequest) (*databr
defer srv.onchange.Broadcast()
db := srv.getDB(req.GetType())
db.Set(req.GetId(), req.GetData())
record := db.Get(req.GetId())
if err := db.Put(ctx, req.GetId(), req.GetData()); err != nil {
return nil, err
}
record := db.Get(ctx, req.GetId())
return &databroker.SetResponse{
Record: record,
@ -158,7 +164,7 @@ func (srv *Server) Sync(req *databroker.SyncRequest, stream databroker.DataBroke
ch := srv.onchange.Bind()
defer srv.onchange.Unbind(ch)
for {
updated := db.List(recordVersion)
updated := db.List(context.Background(), recordVersion)
if len(updated) > 0 {
sort.Slice(updated, func(i, j int) bool {
@ -232,7 +238,7 @@ func (srv *Server) SyncTypes(req *emptypb.Empty, stream databroker.DataBrokerSer
}
}
func (srv *Server) getDB(recordType string) *DB {
func (srv *Server) getDB(recordType string) storage.Backend {
// double-checked locking:
// first try the read lock, then re-try with the write lock, and finally create a new db if nil
srv.mu.RLock()
@ -242,7 +248,7 @@ func (srv *Server) getDB(recordType string) *DB {
srv.mu.Lock()
db = srv.byType[recordType]
if db == nil {
db = NewDB(recordType, srv.cfg.btreeDegree)
db = inmemory.NewDB(recordType, srv.cfg.btreeDegree)
srv.byType[recordType] = db
}
srv.mu.Unlock()

View file

@ -1,4 +1,4 @@
package memory
package databroker
import "sync"

View file

@ -1,6 +1,7 @@
package memory
package inmemory
import (
"context"
"fmt"
"sync"
"sync/atomic"
@ -12,8 +13,11 @@ import (
"google.golang.org/protobuf/types/known/anypb"
"github.com/pomerium/pomerium/pkg/grpc/databroker"
"github.com/pomerium/pomerium/pkg/storage"
)
var _ storage.Backend = (*DB)(nil)
type byIDRecord struct {
*databroker.Record
}
@ -52,7 +56,7 @@ func NewDB(recordType string, btreeDegree int) *DB {
}
// ClearDeleted clears all the currently deleted records older than the given cutoff.
func (db *DB) ClearDeleted(cutoff time.Time) {
func (db *DB) ClearDeleted(_ context.Context, cutoff time.Time) {
db.mu.Lock()
defer db.mu.Unlock()
@ -71,15 +75,16 @@ func (db *DB) ClearDeleted(cutoff time.Time) {
}
// Delete marks a record as deleted.
func (db *DB) Delete(id string) {
func (db *DB) Delete(_ context.Context, id string) error {
db.replaceOrInsert(id, func(record *databroker.Record) {
record.DeletedAt = ptypes.TimestampNow()
db.deletedIDs = append(db.deletedIDs, id)
})
return nil
}
// Get gets a record from the db.
func (db *DB) Get(id string) *databroker.Record {
func (db *DB) Get(_ context.Context, id string) *databroker.Record {
record, ok := db.byID.Get(byIDRecord{Record: &databroker.Record{Id: id}}).(byIDRecord)
if !ok {
return nil
@ -88,7 +93,7 @@ func (db *DB) Get(id string) *databroker.Record {
}
// GetAll gets all the records in the db.
func (db *DB) GetAll() []*databroker.Record {
func (db *DB) GetAll(_ context.Context) []*databroker.Record {
var records []*databroker.Record
db.byID.Ascend(func(item btree.Item) bool {
records = append(records, item.(byIDRecord).Record)
@ -98,7 +103,7 @@ func (db *DB) GetAll() []*databroker.Record {
}
// List lists all the changes since the given version.
func (db *DB) List(sinceVersion string) []*databroker.Record {
func (db *DB) List(_ context.Context, sinceVersion string) []*databroker.Record {
var records []*databroker.Record
db.byVersion.AscendGreaterOrEqual(byVersionRecord{Record: &databroker.Record{Version: sinceVersion}}, func(i btree.Item) bool {
record := i.(byVersionRecord)
@ -110,11 +115,12 @@ func (db *DB) List(sinceVersion string) []*databroker.Record {
return records
}
// Set replaces or inserts a record in the db.
func (db *DB) Set(id string, data *anypb.Any) {
// Put replaces or inserts a record in the db.
func (db *DB) Put(_ context.Context, id string, data *anypb.Any) error {
db.replaceOrInsert(id, func(record *databroker.Record) {
record.Data = data
})
return nil
}
func (db *DB) replaceOrInsert(id string, f func(record *databroker.Record)) {

View file

@ -1,6 +1,7 @@
package memory
package inmemory
import (
"context"
"fmt"
"testing"
"time"
@ -10,14 +11,15 @@ import (
)
func TestDB(t *testing.T) {
ctx := context.Background()
db := NewDB("example", 2)
t.Run("get missing record", func(t *testing.T) {
assert.Nil(t, db.Get("abcd"))
assert.Nil(t, db.Get(ctx, "abcd"))
})
t.Run("get record", func(t *testing.T) {
data := new(anypb.Any)
db.Set("abcd", data)
record := db.Get("abcd")
assert.NoError(t, db.Put(ctx, "abcd", data))
record := db.Get(ctx, "abcd")
if assert.NotNil(t, record) {
assert.NotNil(t, record.CreatedAt)
assert.Equal(t, data, record.Data)
@ -29,32 +31,32 @@ func TestDB(t *testing.T) {
}
})
t.Run("delete record", func(t *testing.T) {
db.Delete("abcd")
record := db.Get("abcd")
assert.NoError(t, db.Delete(ctx, "abcd"))
record := db.Get(ctx, "abcd")
if assert.NotNil(t, record) {
assert.NotNil(t, record.DeletedAt)
}
})
t.Run("clear deleted", func(t *testing.T) {
db.ClearDeleted(time.Now().Add(time.Second))
assert.Nil(t, db.Get("abcd"))
db.ClearDeleted(ctx, time.Now().Add(time.Second))
assert.Nil(t, db.Get(ctx, "abcd"))
})
t.Run("keep remaining", func(t *testing.T) {
data := new(anypb.Any)
db.Set("abcd", data)
db.Delete("abcd")
db.ClearDeleted(time.Now().Add(-10 * time.Second))
assert.NotNil(t, db.Get("abcd"))
db.ClearDeleted(time.Now().Add(time.Second))
assert.NoError(t, db.Put(ctx, "abcd", data))
assert.NoError(t, db.Delete(ctx, "abcd"))
db.ClearDeleted(ctx, time.Now().Add(-10*time.Second))
assert.NotNil(t, db.Get(ctx, "abcd"))
db.ClearDeleted(ctx, time.Now().Add(time.Second))
})
t.Run("list", func(t *testing.T) {
for i := 0; i < 10; i++ {
data := new(anypb.Any)
db.Set(fmt.Sprintf("%02d", i), data)
assert.NoError(t, db.Put(ctx, fmt.Sprintf("%02d", i), data))
}
assert.Len(t, db.List(""), 10)
assert.Len(t, db.List("00000000000A"), 4)
assert.Len(t, db.List("00000000000F"), 0)
assert.Len(t, db.List(ctx, ""), 10)
assert.Len(t, db.List(ctx, "00000000000A"), 4)
assert.Len(t, db.List(ctx, "00000000000F"), 0)
})
}

31
pkg/storage/storage.go Normal file
View file

@ -0,0 +1,31 @@
package storage
import (
"context"
"time"
"google.golang.org/protobuf/types/known/anypb"
"github.com/pomerium/pomerium/pkg/grpc/databroker"
)
// Backend is the interface required for a storage backend.
type Backend interface {
// Put is used to insert or update a record.
Put(ctx context.Context, id string, data *anypb.Any) error
// Get is used to retrieve a record.
Get(ctx context.Context, id string) *databroker.Record
// GetAll is used to retrieve all the records.
GetAll(ctx context.Context) []*databroker.Record
// List is used to retrieve all the records since a version.
List(ctx context.Context, sinceVersion string) []*databroker.Record
// Delete is used to mark a record as deleted.
Delete(ctx context.Context, id string) error
// ClearDeleted is used clear marked delete records.
ClearDeleted(ctx context.Context, cutoff time.Time)
}