mirror of
https://github.com/pomerium/pomerium.git
synced 2025-05-09 23:27:43 +02:00
cache: support databroker option changes (#1294)
This commit is contained in:
parent
31205c0c29
commit
a1378c81f8
16 changed files with 408 additions and 179 deletions
|
@ -3,7 +3,10 @@ package databroker
|
|||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"reflect"
|
||||
"sort"
|
||||
"sync"
|
||||
|
@ -11,6 +14,7 @@ import (
|
|||
|
||||
"github.com/golang/protobuf/ptypes"
|
||||
"github.com/golang/protobuf/ptypes/empty"
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/google/uuid"
|
||||
"github.com/rs/zerolog"
|
||||
"google.golang.org/grpc/codes"
|
||||
|
@ -33,35 +37,39 @@ const (
|
|||
syncBatchSize = 100
|
||||
)
|
||||
|
||||
// newUUID returns a new UUID. This make it easy to stub out in tests.
|
||||
var newUUID = uuid.New
|
||||
|
||||
// Server implements the databroker service using an in memory database.
|
||||
type Server struct {
|
||||
version string
|
||||
cfg *serverConfig
|
||||
log zerolog.Logger
|
||||
cfg *serverConfig
|
||||
log zerolog.Logger
|
||||
|
||||
mu sync.RWMutex
|
||||
version string
|
||||
byType map[string]storage.Backend
|
||||
onTypechange *signal.Signal
|
||||
}
|
||||
|
||||
// New creates a new server.
|
||||
func New(options ...ServerOption) *Server {
|
||||
cfg := newServerConfig(options...)
|
||||
srv := &Server{
|
||||
version: uuid.New().String(),
|
||||
cfg: cfg,
|
||||
log: log.With().Str("service", "databroker").Logger(),
|
||||
log: log.With().Str("service", "databroker").Logger(),
|
||||
|
||||
byType: make(map[string]storage.Backend),
|
||||
onTypechange: signal.New(),
|
||||
}
|
||||
srv.initVersion()
|
||||
srv.UpdateConfig(options...)
|
||||
|
||||
go func() {
|
||||
ticker := time.NewTicker(cfg.deletePermanentlyAfter / 2)
|
||||
ticker := time.NewTicker(time.Minute)
|
||||
defer ticker.Stop()
|
||||
|
||||
for range ticker.C {
|
||||
srv.mu.RLock()
|
||||
tm := time.Now().Add(-srv.cfg.deletePermanentlyAfter)
|
||||
srv.mu.RUnlock()
|
||||
|
||||
var recordTypes []string
|
||||
srv.mu.RLock()
|
||||
for recordType := range srv.byType {
|
||||
|
@ -70,11 +78,11 @@ func New(options ...ServerOption) *Server {
|
|||
srv.mu.RUnlock()
|
||||
|
||||
for _, recordType := range recordTypes {
|
||||
db, err := srv.getDB(recordType)
|
||||
db, _, err := srv.getDB(recordType, true)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
db.ClearDeleted(context.Background(), time.Now().Add(-cfg.deletePermanentlyAfter))
|
||||
db.ClearDeleted(context.Background(), tm)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
@ -82,7 +90,7 @@ func New(options ...ServerOption) *Server {
|
|||
}
|
||||
|
||||
func (srv *Server) initVersion() {
|
||||
dbServerVersion, err := srv.getDB(recordTypeServerVersion)
|
||||
dbServerVersion, _, err := srv.getDB(recordTypeServerVersion, false)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("failed to init server version")
|
||||
return
|
||||
|
@ -98,12 +106,36 @@ func (srv *Server) initVersion() {
|
|||
return
|
||||
}
|
||||
|
||||
srv.version = newUUID().String()
|
||||
data, _ := ptypes.MarshalAny(&databroker.ServerVersion{Version: srv.version})
|
||||
if err := dbServerVersion.Put(context.Background(), serverVersionKey, data); err != nil {
|
||||
srv.log.Warn().Err(err).Msg("failed to save server version.")
|
||||
}
|
||||
}
|
||||
|
||||
// UpdateConfig updates the server with the new options.
|
||||
func (srv *Server) UpdateConfig(options ...ServerOption) {
|
||||
srv.mu.Lock()
|
||||
defer srv.mu.Unlock()
|
||||
|
||||
cfg := newServerConfig(options...)
|
||||
if cmp.Equal(cfg, srv.cfg, cmp.AllowUnexported(serverConfig{})) {
|
||||
log.Debug().Msg("databroker: no changes detected, re-using existing DBs")
|
||||
return
|
||||
}
|
||||
srv.cfg = cfg
|
||||
|
||||
for t, db := range srv.byType {
|
||||
err := db.Close()
|
||||
if err != nil {
|
||||
log.Warn().Err(err).Msg("databroker: error closing backend")
|
||||
}
|
||||
delete(srv.byType, t)
|
||||
}
|
||||
|
||||
srv.initVersion()
|
||||
}
|
||||
|
||||
// Delete deletes a record from the in-memory list.
|
||||
func (srv *Server) Delete(ctx context.Context, req *databroker.DeleteRequest) (*empty.Empty, error) {
|
||||
_, span := trace.StartSpan(ctx, "databroker.grpc.Delete")
|
||||
|
@ -113,7 +145,7 @@ func (srv *Server) Delete(ctx context.Context, req *databroker.DeleteRequest) (*
|
|||
Str("id", req.GetId()).
|
||||
Msg("delete")
|
||||
|
||||
db, err := srv.getDB(req.GetType())
|
||||
db, _, err := srv.getDB(req.GetType(), true)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -134,7 +166,7 @@ func (srv *Server) Get(ctx context.Context, req *databroker.GetRequest) (*databr
|
|||
Str("id", req.GetId()).
|
||||
Msg("get")
|
||||
|
||||
db, err := srv.getDB(req.GetType())
|
||||
db, _, err := srv.getDB(req.GetType(), true)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -156,7 +188,7 @@ func (srv *Server) GetAll(ctx context.Context, req *databroker.GetAllRequest) (*
|
|||
Str("type", req.GetType()).
|
||||
Msg("get all")
|
||||
|
||||
db, err := srv.getDB(req.GetType())
|
||||
db, version, err := srv.getDB(req.GetType(), true)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -167,7 +199,7 @@ func (srv *Server) GetAll(ctx context.Context, req *databroker.GetAllRequest) (*
|
|||
}
|
||||
|
||||
if len(all) == 0 {
|
||||
return &databroker.GetAllResponse{ServerVersion: srv.version}, nil
|
||||
return &databroker.GetAllResponse{ServerVersion: version}, nil
|
||||
}
|
||||
|
||||
var recordVersion string
|
||||
|
@ -182,7 +214,7 @@ func (srv *Server) GetAll(ctx context.Context, req *databroker.GetAllRequest) (*
|
|||
}
|
||||
|
||||
return &databroker.GetAllResponse{
|
||||
ServerVersion: srv.version,
|
||||
ServerVersion: version,
|
||||
RecordVersion: recordVersion,
|
||||
Records: records,
|
||||
}, nil
|
||||
|
@ -197,7 +229,7 @@ func (srv *Server) Set(ctx context.Context, req *databroker.SetRequest) (*databr
|
|||
Str("id", req.GetId()).
|
||||
Msg("set")
|
||||
|
||||
db, err := srv.getDB(req.GetType())
|
||||
db, version, err := srv.getDB(req.GetType(), true)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -210,11 +242,13 @@ func (srv *Server) Set(ctx context.Context, req *databroker.SetRequest) (*databr
|
|||
}
|
||||
return &databroker.SetResponse{
|
||||
Record: record,
|
||||
ServerVersion: srv.version,
|
||||
ServerVersion: version,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (srv *Server) doSync(ctx context.Context, recordVersion *string, db storage.Backend, stream databroker.DataBrokerService_SyncServer) error {
|
||||
func (srv *Server) doSync(ctx context.Context,
|
||||
serverVersion string, recordVersion *string,
|
||||
db storage.Backend, stream databroker.DataBrokerService_SyncServer) error {
|
||||
updated, err := db.List(ctx, *recordVersion)
|
||||
if err != nil {
|
||||
return err
|
||||
|
@ -232,7 +266,7 @@ func (srv *Server) doSync(ctx context.Context, recordVersion *string, db storage
|
|||
j = len(updated)
|
||||
}
|
||||
if err := stream.Send(&databroker.SyncResponse{
|
||||
ServerVersion: srv.version,
|
||||
ServerVersion: serverVersion,
|
||||
Records: updated[i:j],
|
||||
}); err != nil {
|
||||
return err
|
||||
|
@ -251,34 +285,34 @@ func (srv *Server) Sync(req *databroker.SyncRequest, stream databroker.DataBroke
|
|||
Str("record_version", req.GetRecordVersion()).
|
||||
Msg("sync")
|
||||
|
||||
db, serverVersion, err := srv.getDB(req.GetType(), true)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
recordVersion := req.GetRecordVersion()
|
||||
// reset record version if the server versions don't match
|
||||
if req.GetServerVersion() != srv.version {
|
||||
if req.GetServerVersion() != serverVersion {
|
||||
recordVersion = ""
|
||||
// send the new server version to the client
|
||||
err := stream.Send(&databroker.SyncResponse{
|
||||
ServerVersion: srv.version,
|
||||
ServerVersion: serverVersion,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
db, err := srv.getDB(req.GetType())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
ctx := stream.Context()
|
||||
ch := db.Watch(ctx)
|
||||
|
||||
// Do first sync, so we won't missed anything.
|
||||
if err := srv.doSync(ctx, &recordVersion, db, stream); err != nil {
|
||||
if err := srv.doSync(ctx, serverVersion, &recordVersion, db, stream); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for range ch {
|
||||
if err := srv.doSync(ctx, &recordVersion, db, stream); err != nil {
|
||||
if err := srv.doSync(ctx, serverVersion, &recordVersion, db, stream); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
@ -335,29 +369,56 @@ func (srv *Server) SyncTypes(req *emptypb.Empty, stream databroker.DataBrokerSer
|
|||
}
|
||||
}
|
||||
|
||||
func (srv *Server) getDB(recordType string) (storage.Backend, error) {
|
||||
func (srv *Server) getDB(recordType string, lock bool) (db storage.Backend, version string, err error) {
|
||||
// double-checked locking:
|
||||
// first try the read lock, then re-try with the write lock, and finally create a new db if nil
|
||||
srv.mu.RLock()
|
||||
db := srv.byType[recordType]
|
||||
srv.mu.RUnlock()
|
||||
if lock {
|
||||
srv.mu.RLock()
|
||||
}
|
||||
db = srv.byType[recordType]
|
||||
version = srv.version
|
||||
if lock {
|
||||
srv.mu.RUnlock()
|
||||
}
|
||||
if db == nil {
|
||||
srv.mu.Lock()
|
||||
if lock {
|
||||
srv.mu.Lock()
|
||||
}
|
||||
db = srv.byType[recordType]
|
||||
version = srv.version
|
||||
var err error
|
||||
if db == nil {
|
||||
db, err = srv.newDB(recordType)
|
||||
srv.byType[recordType] = db
|
||||
}
|
||||
srv.mu.Unlock()
|
||||
if lock {
|
||||
srv.mu.Unlock()
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, "", err
|
||||
}
|
||||
}
|
||||
return db, nil
|
||||
return db, version, nil
|
||||
}
|
||||
|
||||
func (srv *Server) newDB(recordType string) (db storage.Backend, err error) {
|
||||
caCertPool := x509.NewCertPool()
|
||||
if srv.cfg.storageCAFile != "" {
|
||||
if caCert, err := ioutil.ReadFile(srv.cfg.storageCAFile); err == nil {
|
||||
caCertPool.AppendCertsFromPEM(caCert)
|
||||
} else {
|
||||
log.Warn().Err(err).Msg("failed to read databroker CA file")
|
||||
}
|
||||
}
|
||||
tlsConfig := &tls.Config{
|
||||
RootCAs: caCertPool,
|
||||
// nolint: gosec
|
||||
InsecureSkipVerify: srv.cfg.storageCertSkipVerify,
|
||||
}
|
||||
if srv.cfg.storageCertificate != nil {
|
||||
tlsConfig.Certificates = []tls.Certificate{*srv.cfg.storageCertificate}
|
||||
}
|
||||
|
||||
switch srv.cfg.storageType {
|
||||
case config.StorageInMemoryName:
|
||||
return inmemory.NewDB(recordType, srv.cfg.btreeDegree), nil
|
||||
|
@ -366,7 +427,7 @@ func (srv *Server) newDB(recordType string) (db storage.Backend, err error) {
|
|||
srv.cfg.storageConnectionString,
|
||||
recordType,
|
||||
int64(srv.cfg.deletePermanentlyAfter.Seconds()),
|
||||
redis.WithTLSConfig(srv.cfg.storageTLSConfig),
|
||||
redis.WithTLSConfig(tlsConfig),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create new redis storage: %w", err)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue