mirror of
https://github.com/pomerium/pomerium.git
synced 2025-05-24 06:27:17 +02:00
Add storage backend interface (#1072)
* pkg: add storage package Which contains storage.Backend interface to initial support for multiple backend storage. * pkg/storage: add inmemory storage * internal/databroker: use storage.Backend interface Instead of implementing multiple databroker server implementation for each kind of storage backend, we use only one databroker server implementation, which is supported multiple storage backends, which satisfy storage.Backend interface.
This commit is contained in:
parent
a70254ab76
commit
2f84dd2aff
8 changed files with 88 additions and 45 deletions
7
cache/databroker.go
vendored
7
cache/databroker.go
vendored
|
@ -3,7 +3,7 @@ package cache
|
||||||
import (
|
import (
|
||||||
"google.golang.org/grpc"
|
"google.golang.org/grpc"
|
||||||
|
|
||||||
"github.com/pomerium/pomerium/internal/databroker/memory"
|
internal_databroker "github.com/pomerium/pomerium/internal/databroker"
|
||||||
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -14,10 +14,7 @@ type DataBrokerServer struct {
|
||||||
|
|
||||||
// NewDataBrokerServer creates a new databroker service server.
|
// NewDataBrokerServer creates a new databroker service server.
|
||||||
func NewDataBrokerServer(grpcServer *grpc.Server) *DataBrokerServer {
|
func NewDataBrokerServer(grpcServer *grpc.Server) *DataBrokerServer {
|
||||||
srv := &DataBrokerServer{
|
srv := &DataBrokerServer{DataBrokerServiceServer: internal_databroker.New()}
|
||||||
// just wrap the in-memory data broker server
|
|
||||||
DataBrokerServiceServer: memory.New(),
|
|
||||||
}
|
|
||||||
databroker.RegisterDataBrokerServiceServer(grpcServer, srv)
|
databroker.RegisterDataBrokerServiceServer(grpcServer, srv)
|
||||||
return srv
|
return srv
|
||||||
}
|
}
|
||||||
|
|
1
go.sum
1
go.sum
|
@ -737,6 +737,7 @@ google.golang.org/api v0.18.0/go.mod h1:BwFmGc8tA3vsd7r/7kR8DY7iEEGSU04BFxCo5jP/
|
||||||
google.golang.org/api v0.20.0/go.mod h1:BwFmGc8tA3vsd7r/7kR8DY7iEEGSU04BFxCo5jP/sfE=
|
google.golang.org/api v0.20.0/go.mod h1:BwFmGc8tA3vsd7r/7kR8DY7iEEGSU04BFxCo5jP/sfE=
|
||||||
google.golang.org/api v0.28.0 h1:jMF5hhVfMkTZwHW1SDpKq5CkgWLXOb31Foaca9Zr3oM=
|
google.golang.org/api v0.28.0 h1:jMF5hhVfMkTZwHW1SDpKq5CkgWLXOb31Foaca9Zr3oM=
|
||||||
google.golang.org/api v0.28.0/go.mod h1:lIXQywCXRcnZPGlsd8NbLnOjtAoL6em04bJ9+z0MncE=
|
google.golang.org/api v0.28.0/go.mod h1:lIXQywCXRcnZPGlsd8NbLnOjtAoL6em04bJ9+z0MncE=
|
||||||
|
google.golang.org/api v0.29.0 h1:BaiDisFir8O4IJxvAabCGGkQ6yCJegNQqSVoYUNAnbk=
|
||||||
google.golang.org/api v0.29.0/go.mod h1:Lcubydp8VUV7KeIHD9z2Bys/sm/vGKnG1UHuDBSrHWM=
|
google.golang.org/api v0.29.0/go.mod h1:Lcubydp8VUV7KeIHD9z2Bys/sm/vGKnG1UHuDBSrHWM=
|
||||||
google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM=
|
google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM=
|
||||||
google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4=
|
google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4=
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
package memory
|
package databroker
|
||||||
|
|
||||||
import "time"
|
import "time"
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
// Package memory contains an in-memory data broker implementation.
|
// Package databroker contains a data broker implementation.
|
||||||
package memory
|
package databroker
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
@ -18,6 +18,8 @@ import (
|
||||||
"github.com/pomerium/pomerium/internal/log"
|
"github.com/pomerium/pomerium/internal/log"
|
||||||
"github.com/pomerium/pomerium/internal/telemetry/trace"
|
"github.com/pomerium/pomerium/internal/telemetry/trace"
|
||||||
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
||||||
|
"github.com/pomerium/pomerium/pkg/storage"
|
||||||
|
"github.com/pomerium/pomerium/pkg/storage/inmemory"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Server implements the databroker service using an in memory database.
|
// Server implements the databroker service using an in memory database.
|
||||||
|
@ -27,7 +29,7 @@ type Server struct {
|
||||||
log zerolog.Logger
|
log zerolog.Logger
|
||||||
|
|
||||||
mu sync.RWMutex
|
mu sync.RWMutex
|
||||||
byType map[string]*DB
|
byType map[string]storage.Backend
|
||||||
onchange *Signal
|
onchange *Signal
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -39,7 +41,7 @@ func New(options ...ServerOption) *Server {
|
||||||
cfg: cfg,
|
cfg: cfg,
|
||||||
log: log.With().Str("service", "databroker").Logger(),
|
log: log.With().Str("service", "databroker").Logger(),
|
||||||
|
|
||||||
byType: make(map[string]*DB),
|
byType: make(map[string]storage.Backend),
|
||||||
onchange: NewSignal(),
|
onchange: NewSignal(),
|
||||||
}
|
}
|
||||||
go func() {
|
go func() {
|
||||||
|
@ -55,7 +57,7 @@ func New(options ...ServerOption) *Server {
|
||||||
srv.mu.RUnlock()
|
srv.mu.RUnlock()
|
||||||
|
|
||||||
for _, recordType := range recordTypes {
|
for _, recordType := range recordTypes {
|
||||||
srv.getDB(recordType).ClearDeleted(time.Now().Add(-cfg.deletePermanentlyAfter))
|
srv.getDB(recordType).ClearDeleted(context.Background(), time.Now().Add(-cfg.deletePermanentlyAfter))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
@ -73,7 +75,9 @@ func (srv *Server) Delete(ctx context.Context, req *databroker.DeleteRequest) (*
|
||||||
|
|
||||||
defer srv.onchange.Broadcast()
|
defer srv.onchange.Broadcast()
|
||||||
|
|
||||||
srv.getDB(req.GetType()).Delete(req.GetId())
|
if err := srv.getDB(req.GetType()).Delete(ctx, req.GetId()); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
return new(empty.Empty), nil
|
return new(empty.Empty), nil
|
||||||
}
|
}
|
||||||
|
@ -87,7 +91,7 @@ func (srv *Server) Get(ctx context.Context, req *databroker.GetRequest) (*databr
|
||||||
Str("id", req.GetId()).
|
Str("id", req.GetId()).
|
||||||
Msg("get")
|
Msg("get")
|
||||||
|
|
||||||
record := srv.getDB(req.GetType()).Get(req.GetId())
|
record := srv.getDB(req.GetType()).Get(ctx, req.GetId())
|
||||||
if record == nil {
|
if record == nil {
|
||||||
return nil, status.Error(codes.NotFound, "record not found")
|
return nil, status.Error(codes.NotFound, "record not found")
|
||||||
}
|
}
|
||||||
|
@ -102,7 +106,7 @@ func (srv *Server) GetAll(ctx context.Context, req *databroker.GetAllRequest) (*
|
||||||
Str("type", req.GetType()).
|
Str("type", req.GetType()).
|
||||||
Msg("get all")
|
Msg("get all")
|
||||||
|
|
||||||
records := srv.getDB(req.GetType()).GetAll()
|
records := srv.getDB(req.GetType()).GetAll(ctx)
|
||||||
var recordVersion string
|
var recordVersion string
|
||||||
for _, record := range records {
|
for _, record := range records {
|
||||||
if record.GetVersion() > recordVersion {
|
if record.GetVersion() > recordVersion {
|
||||||
|
@ -128,8 +132,10 @@ func (srv *Server) Set(ctx context.Context, req *databroker.SetRequest) (*databr
|
||||||
defer srv.onchange.Broadcast()
|
defer srv.onchange.Broadcast()
|
||||||
|
|
||||||
db := srv.getDB(req.GetType())
|
db := srv.getDB(req.GetType())
|
||||||
db.Set(req.GetId(), req.GetData())
|
if err := db.Put(ctx, req.GetId(), req.GetData()); err != nil {
|
||||||
record := db.Get(req.GetId())
|
return nil, err
|
||||||
|
}
|
||||||
|
record := db.Get(ctx, req.GetId())
|
||||||
|
|
||||||
return &databroker.SetResponse{
|
return &databroker.SetResponse{
|
||||||
Record: record,
|
Record: record,
|
||||||
|
@ -158,7 +164,7 @@ func (srv *Server) Sync(req *databroker.SyncRequest, stream databroker.DataBroke
|
||||||
ch := srv.onchange.Bind()
|
ch := srv.onchange.Bind()
|
||||||
defer srv.onchange.Unbind(ch)
|
defer srv.onchange.Unbind(ch)
|
||||||
for {
|
for {
|
||||||
updated := db.List(recordVersion)
|
updated := db.List(context.Background(), recordVersion)
|
||||||
|
|
||||||
if len(updated) > 0 {
|
if len(updated) > 0 {
|
||||||
sort.Slice(updated, func(i, j int) bool {
|
sort.Slice(updated, func(i, j int) bool {
|
||||||
|
@ -232,7 +238,7 @@ func (srv *Server) SyncTypes(req *emptypb.Empty, stream databroker.DataBrokerSer
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (srv *Server) getDB(recordType string) *DB {
|
func (srv *Server) getDB(recordType string) storage.Backend {
|
||||||
// double-checked locking:
|
// double-checked locking:
|
||||||
// first try the read lock, then re-try with the write lock, and finally create a new db if nil
|
// first try the read lock, then re-try with the write lock, and finally create a new db if nil
|
||||||
srv.mu.RLock()
|
srv.mu.RLock()
|
||||||
|
@ -242,7 +248,7 @@ func (srv *Server) getDB(recordType string) *DB {
|
||||||
srv.mu.Lock()
|
srv.mu.Lock()
|
||||||
db = srv.byType[recordType]
|
db = srv.byType[recordType]
|
||||||
if db == nil {
|
if db == nil {
|
||||||
db = NewDB(recordType, srv.cfg.btreeDegree)
|
db = inmemory.NewDB(recordType, srv.cfg.btreeDegree)
|
||||||
srv.byType[recordType] = db
|
srv.byType[recordType] = db
|
||||||
}
|
}
|
||||||
srv.mu.Unlock()
|
srv.mu.Unlock()
|
|
@ -1,4 +1,4 @@
|
||||||
package memory
|
package databroker
|
||||||
|
|
||||||
import "sync"
|
import "sync"
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
package memory
|
package inmemory
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"sync"
|
"sync"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
|
@ -12,8 +13,11 @@ import (
|
||||||
"google.golang.org/protobuf/types/known/anypb"
|
"google.golang.org/protobuf/types/known/anypb"
|
||||||
|
|
||||||
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
||||||
|
"github.com/pomerium/pomerium/pkg/storage"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
var _ storage.Backend = (*DB)(nil)
|
||||||
|
|
||||||
type byIDRecord struct {
|
type byIDRecord struct {
|
||||||
*databroker.Record
|
*databroker.Record
|
||||||
}
|
}
|
||||||
|
@ -52,7 +56,7 @@ func NewDB(recordType string, btreeDegree int) *DB {
|
||||||
}
|
}
|
||||||
|
|
||||||
// ClearDeleted clears all the currently deleted records older than the given cutoff.
|
// ClearDeleted clears all the currently deleted records older than the given cutoff.
|
||||||
func (db *DB) ClearDeleted(cutoff time.Time) {
|
func (db *DB) ClearDeleted(_ context.Context, cutoff time.Time) {
|
||||||
db.mu.Lock()
|
db.mu.Lock()
|
||||||
defer db.mu.Unlock()
|
defer db.mu.Unlock()
|
||||||
|
|
||||||
|
@ -71,15 +75,16 @@ func (db *DB) ClearDeleted(cutoff time.Time) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Delete marks a record as deleted.
|
// Delete marks a record as deleted.
|
||||||
func (db *DB) Delete(id string) {
|
func (db *DB) Delete(_ context.Context, id string) error {
|
||||||
db.replaceOrInsert(id, func(record *databroker.Record) {
|
db.replaceOrInsert(id, func(record *databroker.Record) {
|
||||||
record.DeletedAt = ptypes.TimestampNow()
|
record.DeletedAt = ptypes.TimestampNow()
|
||||||
db.deletedIDs = append(db.deletedIDs, id)
|
db.deletedIDs = append(db.deletedIDs, id)
|
||||||
})
|
})
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get gets a record from the db.
|
// Get gets a record from the db.
|
||||||
func (db *DB) Get(id string) *databroker.Record {
|
func (db *DB) Get(_ context.Context, id string) *databroker.Record {
|
||||||
record, ok := db.byID.Get(byIDRecord{Record: &databroker.Record{Id: id}}).(byIDRecord)
|
record, ok := db.byID.Get(byIDRecord{Record: &databroker.Record{Id: id}}).(byIDRecord)
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil
|
return nil
|
||||||
|
@ -88,7 +93,7 @@ func (db *DB) Get(id string) *databroker.Record {
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetAll gets all the records in the db.
|
// GetAll gets all the records in the db.
|
||||||
func (db *DB) GetAll() []*databroker.Record {
|
func (db *DB) GetAll(_ context.Context) []*databroker.Record {
|
||||||
var records []*databroker.Record
|
var records []*databroker.Record
|
||||||
db.byID.Ascend(func(item btree.Item) bool {
|
db.byID.Ascend(func(item btree.Item) bool {
|
||||||
records = append(records, item.(byIDRecord).Record)
|
records = append(records, item.(byIDRecord).Record)
|
||||||
|
@ -98,7 +103,7 @@ func (db *DB) GetAll() []*databroker.Record {
|
||||||
}
|
}
|
||||||
|
|
||||||
// List lists all the changes since the given version.
|
// List lists all the changes since the given version.
|
||||||
func (db *DB) List(sinceVersion string) []*databroker.Record {
|
func (db *DB) List(_ context.Context, sinceVersion string) []*databroker.Record {
|
||||||
var records []*databroker.Record
|
var records []*databroker.Record
|
||||||
db.byVersion.AscendGreaterOrEqual(byVersionRecord{Record: &databroker.Record{Version: sinceVersion}}, func(i btree.Item) bool {
|
db.byVersion.AscendGreaterOrEqual(byVersionRecord{Record: &databroker.Record{Version: sinceVersion}}, func(i btree.Item) bool {
|
||||||
record := i.(byVersionRecord)
|
record := i.(byVersionRecord)
|
||||||
|
@ -110,11 +115,12 @@ func (db *DB) List(sinceVersion string) []*databroker.Record {
|
||||||
return records
|
return records
|
||||||
}
|
}
|
||||||
|
|
||||||
// Set replaces or inserts a record in the db.
|
// Put replaces or inserts a record in the db.
|
||||||
func (db *DB) Set(id string, data *anypb.Any) {
|
func (db *DB) Put(_ context.Context, id string, data *anypb.Any) error {
|
||||||
db.replaceOrInsert(id, func(record *databroker.Record) {
|
db.replaceOrInsert(id, func(record *databroker.Record) {
|
||||||
record.Data = data
|
record.Data = data
|
||||||
})
|
})
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (db *DB) replaceOrInsert(id string, f func(record *databroker.Record)) {
|
func (db *DB) replaceOrInsert(id string, f func(record *databroker.Record)) {
|
|
@ -1,6 +1,7 @@
|
||||||
package memory
|
package inmemory
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
@ -10,14 +11,15 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestDB(t *testing.T) {
|
func TestDB(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
db := NewDB("example", 2)
|
db := NewDB("example", 2)
|
||||||
t.Run("get missing record", func(t *testing.T) {
|
t.Run("get missing record", func(t *testing.T) {
|
||||||
assert.Nil(t, db.Get("abcd"))
|
assert.Nil(t, db.Get(ctx, "abcd"))
|
||||||
})
|
})
|
||||||
t.Run("get record", func(t *testing.T) {
|
t.Run("get record", func(t *testing.T) {
|
||||||
data := new(anypb.Any)
|
data := new(anypb.Any)
|
||||||
db.Set("abcd", data)
|
assert.NoError(t, db.Put(ctx, "abcd", data))
|
||||||
record := db.Get("abcd")
|
record := db.Get(ctx, "abcd")
|
||||||
if assert.NotNil(t, record) {
|
if assert.NotNil(t, record) {
|
||||||
assert.NotNil(t, record.CreatedAt)
|
assert.NotNil(t, record.CreatedAt)
|
||||||
assert.Equal(t, data, record.Data)
|
assert.Equal(t, data, record.Data)
|
||||||
|
@ -29,32 +31,32 @@ func TestDB(t *testing.T) {
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
t.Run("delete record", func(t *testing.T) {
|
t.Run("delete record", func(t *testing.T) {
|
||||||
db.Delete("abcd")
|
assert.NoError(t, db.Delete(ctx, "abcd"))
|
||||||
record := db.Get("abcd")
|
record := db.Get(ctx, "abcd")
|
||||||
if assert.NotNil(t, record) {
|
if assert.NotNil(t, record) {
|
||||||
assert.NotNil(t, record.DeletedAt)
|
assert.NotNil(t, record.DeletedAt)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
t.Run("clear deleted", func(t *testing.T) {
|
t.Run("clear deleted", func(t *testing.T) {
|
||||||
db.ClearDeleted(time.Now().Add(time.Second))
|
db.ClearDeleted(ctx, time.Now().Add(time.Second))
|
||||||
assert.Nil(t, db.Get("abcd"))
|
assert.Nil(t, db.Get(ctx, "abcd"))
|
||||||
})
|
})
|
||||||
t.Run("keep remaining", func(t *testing.T) {
|
t.Run("keep remaining", func(t *testing.T) {
|
||||||
data := new(anypb.Any)
|
data := new(anypb.Any)
|
||||||
db.Set("abcd", data)
|
assert.NoError(t, db.Put(ctx, "abcd", data))
|
||||||
db.Delete("abcd")
|
assert.NoError(t, db.Delete(ctx, "abcd"))
|
||||||
db.ClearDeleted(time.Now().Add(-10 * time.Second))
|
db.ClearDeleted(ctx, time.Now().Add(-10*time.Second))
|
||||||
assert.NotNil(t, db.Get("abcd"))
|
assert.NotNil(t, db.Get(ctx, "abcd"))
|
||||||
db.ClearDeleted(time.Now().Add(time.Second))
|
db.ClearDeleted(ctx, time.Now().Add(time.Second))
|
||||||
})
|
})
|
||||||
t.Run("list", func(t *testing.T) {
|
t.Run("list", func(t *testing.T) {
|
||||||
for i := 0; i < 10; i++ {
|
for i := 0; i < 10; i++ {
|
||||||
data := new(anypb.Any)
|
data := new(anypb.Any)
|
||||||
db.Set(fmt.Sprintf("%02d", i), data)
|
assert.NoError(t, db.Put(ctx, fmt.Sprintf("%02d", i), data))
|
||||||
}
|
}
|
||||||
|
|
||||||
assert.Len(t, db.List(""), 10)
|
assert.Len(t, db.List(ctx, ""), 10)
|
||||||
assert.Len(t, db.List("00000000000A"), 4)
|
assert.Len(t, db.List(ctx, "00000000000A"), 4)
|
||||||
assert.Len(t, db.List("00000000000F"), 0)
|
assert.Len(t, db.List(ctx, "00000000000F"), 0)
|
||||||
})
|
})
|
||||||
}
|
}
|
31
pkg/storage/storage.go
Normal file
31
pkg/storage/storage.go
Normal file
|
@ -0,0 +1,31 @@
|
||||||
|
package storage
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"google.golang.org/protobuf/types/known/anypb"
|
||||||
|
|
||||||
|
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Backend is the interface required for a storage backend.
|
||||||
|
type Backend interface {
|
||||||
|
// Put is used to insert or update a record.
|
||||||
|
Put(ctx context.Context, id string, data *anypb.Any) error
|
||||||
|
|
||||||
|
// Get is used to retrieve a record.
|
||||||
|
Get(ctx context.Context, id string) *databroker.Record
|
||||||
|
|
||||||
|
// GetAll is used to retrieve all the records.
|
||||||
|
GetAll(ctx context.Context) []*databroker.Record
|
||||||
|
|
||||||
|
// List is used to retrieve all the records since a version.
|
||||||
|
List(ctx context.Context, sinceVersion string) []*databroker.Record
|
||||||
|
|
||||||
|
// Delete is used to mark a record as deleted.
|
||||||
|
Delete(ctx context.Context, id string) error
|
||||||
|
|
||||||
|
// ClearDeleted is used clear marked delete records.
|
||||||
|
ClearDeleted(ctx context.Context, cutoff time.Time)
|
||||||
|
}
|
Loading…
Add table
Add a link
Reference in a new issue