diff --git a/cache/databroker.go b/cache/databroker.go index b31e0866e..100cedca9 100644 --- a/cache/databroker.go +++ b/cache/databroker.go @@ -3,7 +3,7 @@ package cache import ( "google.golang.org/grpc" - "github.com/pomerium/pomerium/internal/databroker/memory" + internal_databroker "github.com/pomerium/pomerium/internal/databroker" "github.com/pomerium/pomerium/pkg/grpc/databroker" ) @@ -14,10 +14,7 @@ type DataBrokerServer struct { // NewDataBrokerServer creates a new databroker service server. func NewDataBrokerServer(grpcServer *grpc.Server) *DataBrokerServer { - srv := &DataBrokerServer{ - // just wrap the in-memory data broker server - DataBrokerServiceServer: memory.New(), - } + srv := &DataBrokerServer{DataBrokerServiceServer: internal_databroker.New()} databroker.RegisterDataBrokerServiceServer(grpcServer, srv) return srv } diff --git a/go.sum b/go.sum index c5e79df60..a416b2b31 100644 --- a/go.sum +++ b/go.sum @@ -737,6 +737,7 @@ google.golang.org/api v0.18.0/go.mod h1:BwFmGc8tA3vsd7r/7kR8DY7iEEGSU04BFxCo5jP/ google.golang.org/api v0.20.0/go.mod h1:BwFmGc8tA3vsd7r/7kR8DY7iEEGSU04BFxCo5jP/sfE= google.golang.org/api v0.28.0 h1:jMF5hhVfMkTZwHW1SDpKq5CkgWLXOb31Foaca9Zr3oM= google.golang.org/api v0.28.0/go.mod h1:lIXQywCXRcnZPGlsd8NbLnOjtAoL6em04bJ9+z0MncE= +google.golang.org/api v0.29.0 h1:BaiDisFir8O4IJxvAabCGGkQ6yCJegNQqSVoYUNAnbk= google.golang.org/api v0.29.0/go.mod h1:Lcubydp8VUV7KeIHD9z2Bys/sm/vGKnG1UHuDBSrHWM= google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM= google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= diff --git a/internal/databroker/memory/config.go b/internal/databroker/config.go similarity index 98% rename from internal/databroker/memory/config.go rename to internal/databroker/config.go index ad88a7eba..b68ebf2f8 100644 --- a/internal/databroker/memory/config.go +++ b/internal/databroker/config.go @@ -1,4 +1,4 @@ -package memory +package databroker import "time" diff --git a/internal/databroker/memory/server.go b/internal/databroker/server.go similarity index 87% rename from internal/databroker/memory/server.go rename to internal/databroker/server.go index 5617efe17..f35c456a2 100644 --- a/internal/databroker/memory/server.go +++ b/internal/databroker/server.go @@ -1,5 +1,5 @@ -// Package memory contains an in-memory data broker implementation. -package memory +// Package databroker contains a data broker implementation. +package databroker import ( "context" @@ -18,6 +18,8 @@ import ( "github.com/pomerium/pomerium/internal/log" "github.com/pomerium/pomerium/internal/telemetry/trace" "github.com/pomerium/pomerium/pkg/grpc/databroker" + "github.com/pomerium/pomerium/pkg/storage" + "github.com/pomerium/pomerium/pkg/storage/inmemory" ) // Server implements the databroker service using an in memory database. @@ -27,7 +29,7 @@ type Server struct { log zerolog.Logger mu sync.RWMutex - byType map[string]*DB + byType map[string]storage.Backend onchange *Signal } @@ -39,7 +41,7 @@ func New(options ...ServerOption) *Server { cfg: cfg, log: log.With().Str("service", "databroker").Logger(), - byType: make(map[string]*DB), + byType: make(map[string]storage.Backend), onchange: NewSignal(), } go func() { @@ -55,7 +57,7 @@ func New(options ...ServerOption) *Server { srv.mu.RUnlock() for _, recordType := range recordTypes { - srv.getDB(recordType).ClearDeleted(time.Now().Add(-cfg.deletePermanentlyAfter)) + srv.getDB(recordType).ClearDeleted(context.Background(), time.Now().Add(-cfg.deletePermanentlyAfter)) } } }() @@ -73,7 +75,9 @@ func (srv *Server) Delete(ctx context.Context, req *databroker.DeleteRequest) (* defer srv.onchange.Broadcast() - srv.getDB(req.GetType()).Delete(req.GetId()) + if err := srv.getDB(req.GetType()).Delete(ctx, req.GetId()); err != nil { + return nil, err + } return new(empty.Empty), nil } @@ -87,7 +91,7 @@ func (srv *Server) Get(ctx context.Context, req *databroker.GetRequest) (*databr Str("id", req.GetId()). Msg("get") - record := srv.getDB(req.GetType()).Get(req.GetId()) + record := srv.getDB(req.GetType()).Get(ctx, req.GetId()) if record == nil { return nil, status.Error(codes.NotFound, "record not found") } @@ -102,7 +106,7 @@ func (srv *Server) GetAll(ctx context.Context, req *databroker.GetAllRequest) (* Str("type", req.GetType()). Msg("get all") - records := srv.getDB(req.GetType()).GetAll() + records := srv.getDB(req.GetType()).GetAll(ctx) var recordVersion string for _, record := range records { if record.GetVersion() > recordVersion { @@ -128,8 +132,10 @@ func (srv *Server) Set(ctx context.Context, req *databroker.SetRequest) (*databr defer srv.onchange.Broadcast() db := srv.getDB(req.GetType()) - db.Set(req.GetId(), req.GetData()) - record := db.Get(req.GetId()) + if err := db.Put(ctx, req.GetId(), req.GetData()); err != nil { + return nil, err + } + record := db.Get(ctx, req.GetId()) return &databroker.SetResponse{ Record: record, @@ -158,7 +164,7 @@ func (srv *Server) Sync(req *databroker.SyncRequest, stream databroker.DataBroke ch := srv.onchange.Bind() defer srv.onchange.Unbind(ch) for { - updated := db.List(recordVersion) + updated := db.List(context.Background(), recordVersion) if len(updated) > 0 { sort.Slice(updated, func(i, j int) bool { @@ -232,7 +238,7 @@ func (srv *Server) SyncTypes(req *emptypb.Empty, stream databroker.DataBrokerSer } } -func (srv *Server) getDB(recordType string) *DB { +func (srv *Server) getDB(recordType string) storage.Backend { // double-checked locking: // first try the read lock, then re-try with the write lock, and finally create a new db if nil srv.mu.RLock() @@ -242,7 +248,7 @@ func (srv *Server) getDB(recordType string) *DB { srv.mu.Lock() db = srv.byType[recordType] if db == nil { - db = NewDB(recordType, srv.cfg.btreeDegree) + db = inmemory.NewDB(recordType, srv.cfg.btreeDegree) srv.byType[recordType] = db } srv.mu.Unlock() diff --git a/internal/databroker/memory/signal.go b/internal/databroker/signal.go similarity index 98% rename from internal/databroker/memory/signal.go rename to internal/databroker/signal.go index 5901419a2..14436871a 100644 --- a/internal/databroker/memory/signal.go +++ b/internal/databroker/signal.go @@ -1,4 +1,4 @@ -package memory +package databroker import "sync" diff --git a/internal/databroker/memory/db.go b/pkg/storage/inmemory/inmemory.go similarity index 84% rename from internal/databroker/memory/db.go rename to pkg/storage/inmemory/inmemory.go index d1a78953b..c8ddb3606 100644 --- a/internal/databroker/memory/db.go +++ b/pkg/storage/inmemory/inmemory.go @@ -1,6 +1,7 @@ -package memory +package inmemory import ( + "context" "fmt" "sync" "sync/atomic" @@ -12,8 +13,11 @@ import ( "google.golang.org/protobuf/types/known/anypb" "github.com/pomerium/pomerium/pkg/grpc/databroker" + "github.com/pomerium/pomerium/pkg/storage" ) +var _ storage.Backend = (*DB)(nil) + type byIDRecord struct { *databroker.Record } @@ -52,7 +56,7 @@ func NewDB(recordType string, btreeDegree int) *DB { } // ClearDeleted clears all the currently deleted records older than the given cutoff. -func (db *DB) ClearDeleted(cutoff time.Time) { +func (db *DB) ClearDeleted(_ context.Context, cutoff time.Time) { db.mu.Lock() defer db.mu.Unlock() @@ -71,15 +75,16 @@ func (db *DB) ClearDeleted(cutoff time.Time) { } // Delete marks a record as deleted. -func (db *DB) Delete(id string) { +func (db *DB) Delete(_ context.Context, id string) error { db.replaceOrInsert(id, func(record *databroker.Record) { record.DeletedAt = ptypes.TimestampNow() db.deletedIDs = append(db.deletedIDs, id) }) + return nil } // Get gets a record from the db. -func (db *DB) Get(id string) *databroker.Record { +func (db *DB) Get(_ context.Context, id string) *databroker.Record { record, ok := db.byID.Get(byIDRecord{Record: &databroker.Record{Id: id}}).(byIDRecord) if !ok { return nil @@ -88,7 +93,7 @@ func (db *DB) Get(id string) *databroker.Record { } // GetAll gets all the records in the db. -func (db *DB) GetAll() []*databroker.Record { +func (db *DB) GetAll(_ context.Context) []*databroker.Record { var records []*databroker.Record db.byID.Ascend(func(item btree.Item) bool { records = append(records, item.(byIDRecord).Record) @@ -98,7 +103,7 @@ func (db *DB) GetAll() []*databroker.Record { } // List lists all the changes since the given version. -func (db *DB) List(sinceVersion string) []*databroker.Record { +func (db *DB) List(_ context.Context, sinceVersion string) []*databroker.Record { var records []*databroker.Record db.byVersion.AscendGreaterOrEqual(byVersionRecord{Record: &databroker.Record{Version: sinceVersion}}, func(i btree.Item) bool { record := i.(byVersionRecord) @@ -110,11 +115,12 @@ func (db *DB) List(sinceVersion string) []*databroker.Record { return records } -// Set replaces or inserts a record in the db. -func (db *DB) Set(id string, data *anypb.Any) { +// Put replaces or inserts a record in the db. +func (db *DB) Put(_ context.Context, id string, data *anypb.Any) error { db.replaceOrInsert(id, func(record *databroker.Record) { record.Data = data }) + return nil } func (db *DB) replaceOrInsert(id string, f func(record *databroker.Record)) { diff --git a/internal/databroker/memory/db_test.go b/pkg/storage/inmemory/inmemory_test.go similarity index 55% rename from internal/databroker/memory/db_test.go rename to pkg/storage/inmemory/inmemory_test.go index 8d6bc966c..0b9eda112 100644 --- a/internal/databroker/memory/db_test.go +++ b/pkg/storage/inmemory/inmemory_test.go @@ -1,6 +1,7 @@ -package memory +package inmemory import ( + "context" "fmt" "testing" "time" @@ -10,14 +11,15 @@ import ( ) func TestDB(t *testing.T) { + ctx := context.Background() db := NewDB("example", 2) t.Run("get missing record", func(t *testing.T) { - assert.Nil(t, db.Get("abcd")) + assert.Nil(t, db.Get(ctx, "abcd")) }) t.Run("get record", func(t *testing.T) { data := new(anypb.Any) - db.Set("abcd", data) - record := db.Get("abcd") + assert.NoError(t, db.Put(ctx, "abcd", data)) + record := db.Get(ctx, "abcd") if assert.NotNil(t, record) { assert.NotNil(t, record.CreatedAt) assert.Equal(t, data, record.Data) @@ -29,32 +31,32 @@ func TestDB(t *testing.T) { } }) t.Run("delete record", func(t *testing.T) { - db.Delete("abcd") - record := db.Get("abcd") + assert.NoError(t, db.Delete(ctx, "abcd")) + record := db.Get(ctx, "abcd") if assert.NotNil(t, record) { assert.NotNil(t, record.DeletedAt) } }) t.Run("clear deleted", func(t *testing.T) { - db.ClearDeleted(time.Now().Add(time.Second)) - assert.Nil(t, db.Get("abcd")) + db.ClearDeleted(ctx, time.Now().Add(time.Second)) + assert.Nil(t, db.Get(ctx, "abcd")) }) t.Run("keep remaining", func(t *testing.T) { data := new(anypb.Any) - db.Set("abcd", data) - db.Delete("abcd") - db.ClearDeleted(time.Now().Add(-10 * time.Second)) - assert.NotNil(t, db.Get("abcd")) - db.ClearDeleted(time.Now().Add(time.Second)) + assert.NoError(t, db.Put(ctx, "abcd", data)) + assert.NoError(t, db.Delete(ctx, "abcd")) + db.ClearDeleted(ctx, time.Now().Add(-10*time.Second)) + assert.NotNil(t, db.Get(ctx, "abcd")) + db.ClearDeleted(ctx, time.Now().Add(time.Second)) }) t.Run("list", func(t *testing.T) { for i := 0; i < 10; i++ { data := new(anypb.Any) - db.Set(fmt.Sprintf("%02d", i), data) + assert.NoError(t, db.Put(ctx, fmt.Sprintf("%02d", i), data)) } - assert.Len(t, db.List(""), 10) - assert.Len(t, db.List("00000000000A"), 4) - assert.Len(t, db.List("00000000000F"), 0) + assert.Len(t, db.List(ctx, ""), 10) + assert.Len(t, db.List(ctx, "00000000000A"), 4) + assert.Len(t, db.List(ctx, "00000000000F"), 0) }) } diff --git a/pkg/storage/storage.go b/pkg/storage/storage.go new file mode 100644 index 000000000..9796c1ac1 --- /dev/null +++ b/pkg/storage/storage.go @@ -0,0 +1,31 @@ +package storage + +import ( + "context" + "time" + + "google.golang.org/protobuf/types/known/anypb" + + "github.com/pomerium/pomerium/pkg/grpc/databroker" +) + +// Backend is the interface required for a storage backend. +type Backend interface { + // Put is used to insert or update a record. + Put(ctx context.Context, id string, data *anypb.Any) error + + // Get is used to retrieve a record. + Get(ctx context.Context, id string) *databroker.Record + + // GetAll is used to retrieve all the records. + GetAll(ctx context.Context) []*databroker.Record + + // List is used to retrieve all the records since a version. + List(ctx context.Context, sinceVersion string) []*databroker.Record + + // Delete is used to mark a record as deleted. + Delete(ctx context.Context, id string) error + + // ClearDeleted is used clear marked delete records. + ClearDeleted(ctx context.Context, cutoff time.Time) +}