mirror of
https://github.com/pomerium/pomerium.git
synced 2025-05-24 14:37:12 +02:00
Add storage backend interface (#1072)
* pkg: add storage package Which contains storage.Backend interface to initial support for multiple backend storage. * pkg/storage: add inmemory storage * internal/databroker: use storage.Backend interface Instead of implementing multiple databroker server implementation for each kind of storage backend, we use only one databroker server implementation, which is supported multiple storage backends, which satisfy storage.Backend interface.
This commit is contained in:
parent
a70254ab76
commit
2f84dd2aff
8 changed files with 88 additions and 45 deletions
7
cache/databroker.go
vendored
7
cache/databroker.go
vendored
|
@ -3,7 +3,7 @@ package cache
|
|||
import (
|
||||
"google.golang.org/grpc"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/databroker/memory"
|
||||
internal_databroker "github.com/pomerium/pomerium/internal/databroker"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
||||
)
|
||||
|
||||
|
@ -14,10 +14,7 @@ type DataBrokerServer struct {
|
|||
|
||||
// NewDataBrokerServer creates a new databroker service server.
|
||||
func NewDataBrokerServer(grpcServer *grpc.Server) *DataBrokerServer {
|
||||
srv := &DataBrokerServer{
|
||||
// just wrap the in-memory data broker server
|
||||
DataBrokerServiceServer: memory.New(),
|
||||
}
|
||||
srv := &DataBrokerServer{DataBrokerServiceServer: internal_databroker.New()}
|
||||
databroker.RegisterDataBrokerServiceServer(grpcServer, srv)
|
||||
return srv
|
||||
}
|
||||
|
|
1
go.sum
1
go.sum
|
@ -737,6 +737,7 @@ google.golang.org/api v0.18.0/go.mod h1:BwFmGc8tA3vsd7r/7kR8DY7iEEGSU04BFxCo5jP/
|
|||
google.golang.org/api v0.20.0/go.mod h1:BwFmGc8tA3vsd7r/7kR8DY7iEEGSU04BFxCo5jP/sfE=
|
||||
google.golang.org/api v0.28.0 h1:jMF5hhVfMkTZwHW1SDpKq5CkgWLXOb31Foaca9Zr3oM=
|
||||
google.golang.org/api v0.28.0/go.mod h1:lIXQywCXRcnZPGlsd8NbLnOjtAoL6em04bJ9+z0MncE=
|
||||
google.golang.org/api v0.29.0 h1:BaiDisFir8O4IJxvAabCGGkQ6yCJegNQqSVoYUNAnbk=
|
||||
google.golang.org/api v0.29.0/go.mod h1:Lcubydp8VUV7KeIHD9z2Bys/sm/vGKnG1UHuDBSrHWM=
|
||||
google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM=
|
||||
google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4=
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
package memory
|
||||
package databroker
|
||||
|
||||
import "time"
|
||||
|
|
@ -1,5 +1,5 @@
|
|||
// Package memory contains an in-memory data broker implementation.
|
||||
package memory
|
||||
// Package databroker contains a data broker implementation.
|
||||
package databroker
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
@ -18,6 +18,8 @@ import (
|
|||
"github.com/pomerium/pomerium/internal/log"
|
||||
"github.com/pomerium/pomerium/internal/telemetry/trace"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
||||
"github.com/pomerium/pomerium/pkg/storage"
|
||||
"github.com/pomerium/pomerium/pkg/storage/inmemory"
|
||||
)
|
||||
|
||||
// Server implements the databroker service using an in memory database.
|
||||
|
@ -27,7 +29,7 @@ type Server struct {
|
|||
log zerolog.Logger
|
||||
|
||||
mu sync.RWMutex
|
||||
byType map[string]*DB
|
||||
byType map[string]storage.Backend
|
||||
onchange *Signal
|
||||
}
|
||||
|
||||
|
@ -39,7 +41,7 @@ func New(options ...ServerOption) *Server {
|
|||
cfg: cfg,
|
||||
log: log.With().Str("service", "databroker").Logger(),
|
||||
|
||||
byType: make(map[string]*DB),
|
||||
byType: make(map[string]storage.Backend),
|
||||
onchange: NewSignal(),
|
||||
}
|
||||
go func() {
|
||||
|
@ -55,7 +57,7 @@ func New(options ...ServerOption) *Server {
|
|||
srv.mu.RUnlock()
|
||||
|
||||
for _, recordType := range recordTypes {
|
||||
srv.getDB(recordType).ClearDeleted(time.Now().Add(-cfg.deletePermanentlyAfter))
|
||||
srv.getDB(recordType).ClearDeleted(context.Background(), time.Now().Add(-cfg.deletePermanentlyAfter))
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
@ -73,7 +75,9 @@ func (srv *Server) Delete(ctx context.Context, req *databroker.DeleteRequest) (*
|
|||
|
||||
defer srv.onchange.Broadcast()
|
||||
|
||||
srv.getDB(req.GetType()).Delete(req.GetId())
|
||||
if err := srv.getDB(req.GetType()).Delete(ctx, req.GetId()); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return new(empty.Empty), nil
|
||||
}
|
||||
|
@ -87,7 +91,7 @@ func (srv *Server) Get(ctx context.Context, req *databroker.GetRequest) (*databr
|
|||
Str("id", req.GetId()).
|
||||
Msg("get")
|
||||
|
||||
record := srv.getDB(req.GetType()).Get(req.GetId())
|
||||
record := srv.getDB(req.GetType()).Get(ctx, req.GetId())
|
||||
if record == nil {
|
||||
return nil, status.Error(codes.NotFound, "record not found")
|
||||
}
|
||||
|
@ -102,7 +106,7 @@ func (srv *Server) GetAll(ctx context.Context, req *databroker.GetAllRequest) (*
|
|||
Str("type", req.GetType()).
|
||||
Msg("get all")
|
||||
|
||||
records := srv.getDB(req.GetType()).GetAll()
|
||||
records := srv.getDB(req.GetType()).GetAll(ctx)
|
||||
var recordVersion string
|
||||
for _, record := range records {
|
||||
if record.GetVersion() > recordVersion {
|
||||
|
@ -128,8 +132,10 @@ func (srv *Server) Set(ctx context.Context, req *databroker.SetRequest) (*databr
|
|||
defer srv.onchange.Broadcast()
|
||||
|
||||
db := srv.getDB(req.GetType())
|
||||
db.Set(req.GetId(), req.GetData())
|
||||
record := db.Get(req.GetId())
|
||||
if err := db.Put(ctx, req.GetId(), req.GetData()); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
record := db.Get(ctx, req.GetId())
|
||||
|
||||
return &databroker.SetResponse{
|
||||
Record: record,
|
||||
|
@ -158,7 +164,7 @@ func (srv *Server) Sync(req *databroker.SyncRequest, stream databroker.DataBroke
|
|||
ch := srv.onchange.Bind()
|
||||
defer srv.onchange.Unbind(ch)
|
||||
for {
|
||||
updated := db.List(recordVersion)
|
||||
updated := db.List(context.Background(), recordVersion)
|
||||
|
||||
if len(updated) > 0 {
|
||||
sort.Slice(updated, func(i, j int) bool {
|
||||
|
@ -232,7 +238,7 @@ func (srv *Server) SyncTypes(req *emptypb.Empty, stream databroker.DataBrokerSer
|
|||
}
|
||||
}
|
||||
|
||||
func (srv *Server) getDB(recordType string) *DB {
|
||||
func (srv *Server) getDB(recordType string) storage.Backend {
|
||||
// double-checked locking:
|
||||
// first try the read lock, then re-try with the write lock, and finally create a new db if nil
|
||||
srv.mu.RLock()
|
||||
|
@ -242,7 +248,7 @@ func (srv *Server) getDB(recordType string) *DB {
|
|||
srv.mu.Lock()
|
||||
db = srv.byType[recordType]
|
||||
if db == nil {
|
||||
db = NewDB(recordType, srv.cfg.btreeDegree)
|
||||
db = inmemory.NewDB(recordType, srv.cfg.btreeDegree)
|
||||
srv.byType[recordType] = db
|
||||
}
|
||||
srv.mu.Unlock()
|
|
@ -1,4 +1,4 @@
|
|||
package memory
|
||||
package databroker
|
||||
|
||||
import "sync"
|
||||
|
|
@ -1,6 +1,7 @@
|
|||
package memory
|
||||
package inmemory
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
|
@ -12,8 +13,11 @@ import (
|
|||
"google.golang.org/protobuf/types/known/anypb"
|
||||
|
||||
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
||||
"github.com/pomerium/pomerium/pkg/storage"
|
||||
)
|
||||
|
||||
var _ storage.Backend = (*DB)(nil)
|
||||
|
||||
type byIDRecord struct {
|
||||
*databroker.Record
|
||||
}
|
||||
|
@ -52,7 +56,7 @@ func NewDB(recordType string, btreeDegree int) *DB {
|
|||
}
|
||||
|
||||
// ClearDeleted clears all the currently deleted records older than the given cutoff.
|
||||
func (db *DB) ClearDeleted(cutoff time.Time) {
|
||||
func (db *DB) ClearDeleted(_ context.Context, cutoff time.Time) {
|
||||
db.mu.Lock()
|
||||
defer db.mu.Unlock()
|
||||
|
||||
|
@ -71,15 +75,16 @@ func (db *DB) ClearDeleted(cutoff time.Time) {
|
|||
}
|
||||
|
||||
// Delete marks a record as deleted.
|
||||
func (db *DB) Delete(id string) {
|
||||
func (db *DB) Delete(_ context.Context, id string) error {
|
||||
db.replaceOrInsert(id, func(record *databroker.Record) {
|
||||
record.DeletedAt = ptypes.TimestampNow()
|
||||
db.deletedIDs = append(db.deletedIDs, id)
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
// Get gets a record from the db.
|
||||
func (db *DB) Get(id string) *databroker.Record {
|
||||
func (db *DB) Get(_ context.Context, id string) *databroker.Record {
|
||||
record, ok := db.byID.Get(byIDRecord{Record: &databroker.Record{Id: id}}).(byIDRecord)
|
||||
if !ok {
|
||||
return nil
|
||||
|
@ -88,7 +93,7 @@ func (db *DB) Get(id string) *databroker.Record {
|
|||
}
|
||||
|
||||
// GetAll gets all the records in the db.
|
||||
func (db *DB) GetAll() []*databroker.Record {
|
||||
func (db *DB) GetAll(_ context.Context) []*databroker.Record {
|
||||
var records []*databroker.Record
|
||||
db.byID.Ascend(func(item btree.Item) bool {
|
||||
records = append(records, item.(byIDRecord).Record)
|
||||
|
@ -98,7 +103,7 @@ func (db *DB) GetAll() []*databroker.Record {
|
|||
}
|
||||
|
||||
// List lists all the changes since the given version.
|
||||
func (db *DB) List(sinceVersion string) []*databroker.Record {
|
||||
func (db *DB) List(_ context.Context, sinceVersion string) []*databroker.Record {
|
||||
var records []*databroker.Record
|
||||
db.byVersion.AscendGreaterOrEqual(byVersionRecord{Record: &databroker.Record{Version: sinceVersion}}, func(i btree.Item) bool {
|
||||
record := i.(byVersionRecord)
|
||||
|
@ -110,11 +115,12 @@ func (db *DB) List(sinceVersion string) []*databroker.Record {
|
|||
return records
|
||||
}
|
||||
|
||||
// Set replaces or inserts a record in the db.
|
||||
func (db *DB) Set(id string, data *anypb.Any) {
|
||||
// Put replaces or inserts a record in the db.
|
||||
func (db *DB) Put(_ context.Context, id string, data *anypb.Any) error {
|
||||
db.replaceOrInsert(id, func(record *databroker.Record) {
|
||||
record.Data = data
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
func (db *DB) replaceOrInsert(id string, f func(record *databroker.Record)) {
|
|
@ -1,6 +1,7 @@
|
|||
package memory
|
||||
package inmemory
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
|
@ -10,14 +11,15 @@ import (
|
|||
)
|
||||
|
||||
func TestDB(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
db := NewDB("example", 2)
|
||||
t.Run("get missing record", func(t *testing.T) {
|
||||
assert.Nil(t, db.Get("abcd"))
|
||||
assert.Nil(t, db.Get(ctx, "abcd"))
|
||||
})
|
||||
t.Run("get record", func(t *testing.T) {
|
||||
data := new(anypb.Any)
|
||||
db.Set("abcd", data)
|
||||
record := db.Get("abcd")
|
||||
assert.NoError(t, db.Put(ctx, "abcd", data))
|
||||
record := db.Get(ctx, "abcd")
|
||||
if assert.NotNil(t, record) {
|
||||
assert.NotNil(t, record.CreatedAt)
|
||||
assert.Equal(t, data, record.Data)
|
||||
|
@ -29,32 +31,32 @@ func TestDB(t *testing.T) {
|
|||
}
|
||||
})
|
||||
t.Run("delete record", func(t *testing.T) {
|
||||
db.Delete("abcd")
|
||||
record := db.Get("abcd")
|
||||
assert.NoError(t, db.Delete(ctx, "abcd"))
|
||||
record := db.Get(ctx, "abcd")
|
||||
if assert.NotNil(t, record) {
|
||||
assert.NotNil(t, record.DeletedAt)
|
||||
}
|
||||
})
|
||||
t.Run("clear deleted", func(t *testing.T) {
|
||||
db.ClearDeleted(time.Now().Add(time.Second))
|
||||
assert.Nil(t, db.Get("abcd"))
|
||||
db.ClearDeleted(ctx, time.Now().Add(time.Second))
|
||||
assert.Nil(t, db.Get(ctx, "abcd"))
|
||||
})
|
||||
t.Run("keep remaining", func(t *testing.T) {
|
||||
data := new(anypb.Any)
|
||||
db.Set("abcd", data)
|
||||
db.Delete("abcd")
|
||||
db.ClearDeleted(time.Now().Add(-10 * time.Second))
|
||||
assert.NotNil(t, db.Get("abcd"))
|
||||
db.ClearDeleted(time.Now().Add(time.Second))
|
||||
assert.NoError(t, db.Put(ctx, "abcd", data))
|
||||
assert.NoError(t, db.Delete(ctx, "abcd"))
|
||||
db.ClearDeleted(ctx, time.Now().Add(-10*time.Second))
|
||||
assert.NotNil(t, db.Get(ctx, "abcd"))
|
||||
db.ClearDeleted(ctx, time.Now().Add(time.Second))
|
||||
})
|
||||
t.Run("list", func(t *testing.T) {
|
||||
for i := 0; i < 10; i++ {
|
||||
data := new(anypb.Any)
|
||||
db.Set(fmt.Sprintf("%02d", i), data)
|
||||
assert.NoError(t, db.Put(ctx, fmt.Sprintf("%02d", i), data))
|
||||
}
|
||||
|
||||
assert.Len(t, db.List(""), 10)
|
||||
assert.Len(t, db.List("00000000000A"), 4)
|
||||
assert.Len(t, db.List("00000000000F"), 0)
|
||||
assert.Len(t, db.List(ctx, ""), 10)
|
||||
assert.Len(t, db.List(ctx, "00000000000A"), 4)
|
||||
assert.Len(t, db.List(ctx, "00000000000F"), 0)
|
||||
})
|
||||
}
|
31
pkg/storage/storage.go
Normal file
31
pkg/storage/storage.go
Normal file
|
@ -0,0 +1,31 @@
|
|||
package storage
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"google.golang.org/protobuf/types/known/anypb"
|
||||
|
||||
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
||||
)
|
||||
|
||||
// Backend is the interface required for a storage backend.
|
||||
type Backend interface {
|
||||
// Put is used to insert or update a record.
|
||||
Put(ctx context.Context, id string, data *anypb.Any) error
|
||||
|
||||
// Get is used to retrieve a record.
|
||||
Get(ctx context.Context, id string) *databroker.Record
|
||||
|
||||
// GetAll is used to retrieve all the records.
|
||||
GetAll(ctx context.Context) []*databroker.Record
|
||||
|
||||
// List is used to retrieve all the records since a version.
|
||||
List(ctx context.Context, sinceVersion string) []*databroker.Record
|
||||
|
||||
// Delete is used to mark a record as deleted.
|
||||
Delete(ctx context.Context, id string) error
|
||||
|
||||
// ClearDeleted is used clear marked delete records.
|
||||
ClearDeleted(ctx context.Context, cutoff time.Time)
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue