databroker: store server version in backend (#2142)

This commit is contained in:
Caleb Doxsey 2021-04-28 09:12:52 -06:00 committed by GitHub
parent 1b698053f6
commit 91c7dc742f
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
15 changed files with 317 additions and 333 deletions

View file

@ -12,8 +12,6 @@ import (
"github.com/google/go-cmp/cmp"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"google.golang.org/protobuf/types/known/anypb"
"google.golang.org/protobuf/types/known/wrapperspb"
"github.com/pomerium/pomerium/config"
"github.com/pomerium/pomerium/internal/log"
@ -26,17 +24,11 @@ import (
"github.com/pomerium/pomerium/pkg/storage/redis"
)
const (
recordTypeServerVersion = "server_version"
serverVersionKey = "version"
)
// Server implements the databroker service using an in memory database.
type Server struct {
cfg *serverConfig
mu sync.RWMutex
version uint64
backend storage.Backend
}
@ -47,40 +39,6 @@ func New(options ...ServerOption) *Server {
return srv
}
func (srv *Server) initVersion(ctx context.Context) {
db, _, err := srv.getBackendLocked()
if err != nil {
log.Error(ctx).Err(err).Msg("failed to init server version")
return
}
// Get version from storage first.
r, err := db.Get(ctx, recordTypeServerVersion, serverVersionKey)
switch {
case err == nil:
var sv wrapperspb.UInt64Value
if err := r.GetData().UnmarshalTo(&sv); err == nil {
log.Debug(ctx).Uint64("server_version", sv.Value).Msg("got db version from Backend")
srv.version = sv.Value
}
return
case errors.Is(err, storage.ErrNotFound): // no server version, so we'll create a new one
case err != nil:
log.Error(ctx).Err(err).Msg("failed to retrieve server version")
return
}
srv.version = cryptutil.NewRandomUInt64()
data, _ := anypb.New(wrapperspb.UInt64(srv.version))
if err := db.Put(context.Background(), &databroker.Record{
Type: recordTypeServerVersion,
Id: serverVersionKey,
Data: data,
}); err != nil {
log.Warn(ctx).Err(err).Msg("failed to save server version.")
}
}
// UpdateConfig updates the server with the new options.
func (srv *Server) UpdateConfig(options ...ServerOption) {
srv.mu.Lock()
@ -102,8 +60,6 @@ func (srv *Server) UpdateConfig(options ...ServerOption) {
}
srv.backend = nil
}
srv.initVersion(ctx)
}
// Get gets a record from the in-memory list.
@ -116,7 +72,7 @@ func (srv *Server) Get(ctx context.Context, req *databroker.GetRequest) (*databr
Str("id", req.GetId()).
Msg("get")
db, version, err := srv.getBackend()
db, err := srv.getBackend()
if err != nil {
return nil, err
}
@ -130,8 +86,7 @@ func (srv *Server) Get(ctx context.Context, req *databroker.GetRequest) (*databr
return nil, status.Error(codes.NotFound, "record not found")
}
return &databroker.GetResponse{
Record: record,
ServerVersion: version,
Record: record,
}, nil
}
@ -149,7 +104,7 @@ func (srv *Server) Query(ctx context.Context, req *databroker.QueryRequest) (*da
query := strings.ToLower(req.GetQuery())
db, _, err := srv.getBackend()
db, err := srv.getBackend()
if err != nil {
return nil, err
}
@ -189,15 +144,17 @@ func (srv *Server) Put(ctx context.Context, req *databroker.PutRequest) (*databr
Str("id", record.GetId()).
Msg("put")
db, version, err := srv.getBackend()
db, err := srv.getBackend()
if err != nil {
return nil, err
}
if err := db.Put(ctx, record); err != nil {
serverVersion, err := db.Put(ctx, record)
if err != nil {
return nil, err
}
return &databroker.PutResponse{
ServerVersion: version,
ServerVersion: serverVersion,
Record: record,
}, nil
}
@ -207,7 +164,7 @@ func (srv *Server) SetOptions(ctx context.Context, req *databroker.SetOptionsReq
_, span := trace.StartSpan(ctx, "databroker.grpc.SetOptions")
defer span.End()
backend, _, err := srv.getBackend()
backend, err := srv.getBackend()
if err != nil {
return nil, err
}
@ -226,29 +183,25 @@ func (srv *Server) SetOptions(ctx context.Context, req *databroker.SetOptionsReq
// Sync streams updates for the given record type.
func (srv *Server) Sync(req *databroker.SyncRequest, stream databroker.DataBrokerService_SyncServer) error {
_, span := trace.StartSpan(stream.Context(), "databroker.grpc.Sync")
ctx := stream.Context()
ctx, span := trace.StartSpan(ctx, "databroker.grpc.Sync")
defer span.End()
log.Info(stream.Context()).
ctx, cancel := context.WithCancel(ctx)
defer cancel()
log.Info(ctx).
Str("peer", grpcutil.GetPeerAddr(stream.Context())).
Uint64("server_version", req.GetServerVersion()).
Uint64("record_version", req.GetRecordVersion()).
Msg("sync")
backend, serverVersion, err := srv.getBackend()
backend, err := srv.getBackend()
if err != nil {
return err
}
// reset record version if the server versions don't match
if req.GetServerVersion() != serverVersion {
return status.Errorf(codes.Aborted, "invalid server version, got %d, expected: %d", req.GetServerVersion(), serverVersion)
}
ctx := stream.Context()
ctx, cancel := context.WithCancel(ctx)
defer cancel()
recordStream, err := backend.Sync(ctx, req.GetRecordVersion())
recordStream, err := backend.Sync(ctx, req.GetServerVersion(), req.GetRecordVersion())
if err != nil {
return err
}
@ -256,8 +209,7 @@ func (srv *Server) Sync(req *databroker.SyncRequest, stream databroker.DataBroke
for recordStream.Next(true) {
err = stream.Send(&databroker.SyncResponse{
ServerVersion: serverVersion,
Record: recordStream.Record(),
Record: recordStream.Record(),
})
if err != nil {
return err
@ -269,23 +221,24 @@ func (srv *Server) Sync(req *databroker.SyncRequest, stream databroker.DataBroke
// SyncLatest returns the latest value of every record in the databroker as a stream of records.
func (srv *Server) SyncLatest(req *databroker.SyncLatestRequest, stream databroker.DataBrokerService_SyncLatestServer) error {
_, span := trace.StartSpan(stream.Context(), "databroker.grpc.SyncLatest")
ctx := stream.Context()
ctx, span := trace.StartSpan(ctx, "databroker.grpc.SyncLatest")
defer span.End()
log.Info(stream.Context()).
ctx, cancel := context.WithCancel(ctx)
defer cancel()
log.Info(ctx).
Str("peer", grpcutil.GetPeerAddr(stream.Context())).
Str("type", req.GetType()).
Msg("sync latest")
backend, serverVersion, err := srv.getBackend()
backend, err := srv.getBackend()
if err != nil {
return err
}
ctx := stream.Context()
ctx, cancel := context.WithCancel(ctx)
defer cancel()
records, latestRecordVersion, err := backend.GetAll(ctx)
records, versions, err := backend.GetAll(ctx)
if err != nil {
return err
}
@ -306,25 +259,20 @@ func (srv *Server) SyncLatest(req *databroker.SyncLatestRequest, stream databrok
// always send the server version last in case there are no records
return stream.Send(&databroker.SyncLatestResponse{
Response: &databroker.SyncLatestResponse_Versions{
Versions: &databroker.Versions{
ServerVersion: serverVersion,
LatestRecordVersion: latestRecordVersion,
},
Versions: versions,
},
})
}
func (srv *Server) getBackend() (backend storage.Backend, version uint64, err error) {
func (srv *Server) getBackend() (backend storage.Backend, err error) {
// double-checked locking:
// first try the read lock, then re-try with the write lock, and finally create a new backend if nil
srv.mu.RLock()
backend = srv.backend
version = srv.version
srv.mu.RUnlock()
if backend == nil {
srv.mu.Lock()
backend = srv.backend
version = srv.version
var err error
if backend == nil {
backend, err = srv.newBackendLocked()
@ -332,24 +280,10 @@ func (srv *Server) getBackend() (backend storage.Backend, version uint64, err er
}
srv.mu.Unlock()
if err != nil {
return nil, 0, err
return nil, err
}
}
return backend, version, nil
}
func (srv *Server) getBackendLocked() (backend storage.Backend, version uint64, err error) {
backend = srv.backend
version = srv.version
if backend == nil {
var err error
backend, err = srv.newBackendLocked()
srv.backend = backend
if err != nil {
return nil, 0, err
}
}
return backend, version, nil
return backend, nil
}
func (srv *Server) newBackendLocked() (backend storage.Backend, err error) {