mirror of
https://github.com/pomerium/pomerium.git
synced 2025-05-28 08:27:26 +02:00
databroker: store server version in backend (#2142)
This commit is contained in:
parent
1b698053f6
commit
91c7dc742f
15 changed files with 317 additions and 333 deletions
|
@ -12,8 +12,6 @@ import (
|
|||
"github.com/google/go-cmp/cmp"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
"google.golang.org/protobuf/types/known/anypb"
|
||||
"google.golang.org/protobuf/types/known/wrapperspb"
|
||||
|
||||
"github.com/pomerium/pomerium/config"
|
||||
"github.com/pomerium/pomerium/internal/log"
|
||||
|
@ -26,17 +24,11 @@ import (
|
|||
"github.com/pomerium/pomerium/pkg/storage/redis"
|
||||
)
|
||||
|
||||
const (
|
||||
recordTypeServerVersion = "server_version"
|
||||
serverVersionKey = "version"
|
||||
)
|
||||
|
||||
// Server implements the databroker service using an in memory database.
|
||||
type Server struct {
|
||||
cfg *serverConfig
|
||||
|
||||
mu sync.RWMutex
|
||||
version uint64
|
||||
backend storage.Backend
|
||||
}
|
||||
|
||||
|
@ -47,40 +39,6 @@ func New(options ...ServerOption) *Server {
|
|||
return srv
|
||||
}
|
||||
|
||||
func (srv *Server) initVersion(ctx context.Context) {
|
||||
db, _, err := srv.getBackendLocked()
|
||||
if err != nil {
|
||||
log.Error(ctx).Err(err).Msg("failed to init server version")
|
||||
return
|
||||
}
|
||||
|
||||
// Get version from storage first.
|
||||
r, err := db.Get(ctx, recordTypeServerVersion, serverVersionKey)
|
||||
switch {
|
||||
case err == nil:
|
||||
var sv wrapperspb.UInt64Value
|
||||
if err := r.GetData().UnmarshalTo(&sv); err == nil {
|
||||
log.Debug(ctx).Uint64("server_version", sv.Value).Msg("got db version from Backend")
|
||||
srv.version = sv.Value
|
||||
}
|
||||
return
|
||||
case errors.Is(err, storage.ErrNotFound): // no server version, so we'll create a new one
|
||||
case err != nil:
|
||||
log.Error(ctx).Err(err).Msg("failed to retrieve server version")
|
||||
return
|
||||
}
|
||||
|
||||
srv.version = cryptutil.NewRandomUInt64()
|
||||
data, _ := anypb.New(wrapperspb.UInt64(srv.version))
|
||||
if err := db.Put(context.Background(), &databroker.Record{
|
||||
Type: recordTypeServerVersion,
|
||||
Id: serverVersionKey,
|
||||
Data: data,
|
||||
}); err != nil {
|
||||
log.Warn(ctx).Err(err).Msg("failed to save server version.")
|
||||
}
|
||||
}
|
||||
|
||||
// UpdateConfig updates the server with the new options.
|
||||
func (srv *Server) UpdateConfig(options ...ServerOption) {
|
||||
srv.mu.Lock()
|
||||
|
@ -102,8 +60,6 @@ func (srv *Server) UpdateConfig(options ...ServerOption) {
|
|||
}
|
||||
srv.backend = nil
|
||||
}
|
||||
|
||||
srv.initVersion(ctx)
|
||||
}
|
||||
|
||||
// Get gets a record from the in-memory list.
|
||||
|
@ -116,7 +72,7 @@ func (srv *Server) Get(ctx context.Context, req *databroker.GetRequest) (*databr
|
|||
Str("id", req.GetId()).
|
||||
Msg("get")
|
||||
|
||||
db, version, err := srv.getBackend()
|
||||
db, err := srv.getBackend()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -130,8 +86,7 @@ func (srv *Server) Get(ctx context.Context, req *databroker.GetRequest) (*databr
|
|||
return nil, status.Error(codes.NotFound, "record not found")
|
||||
}
|
||||
return &databroker.GetResponse{
|
||||
Record: record,
|
||||
ServerVersion: version,
|
||||
Record: record,
|
||||
}, nil
|
||||
}
|
||||
|
||||
|
@ -149,7 +104,7 @@ func (srv *Server) Query(ctx context.Context, req *databroker.QueryRequest) (*da
|
|||
|
||||
query := strings.ToLower(req.GetQuery())
|
||||
|
||||
db, _, err := srv.getBackend()
|
||||
db, err := srv.getBackend()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -189,15 +144,17 @@ func (srv *Server) Put(ctx context.Context, req *databroker.PutRequest) (*databr
|
|||
Str("id", record.GetId()).
|
||||
Msg("put")
|
||||
|
||||
db, version, err := srv.getBackend()
|
||||
db, err := srv.getBackend()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := db.Put(ctx, record); err != nil {
|
||||
|
||||
serverVersion, err := db.Put(ctx, record)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &databroker.PutResponse{
|
||||
ServerVersion: version,
|
||||
ServerVersion: serverVersion,
|
||||
Record: record,
|
||||
}, nil
|
||||
}
|
||||
|
@ -207,7 +164,7 @@ func (srv *Server) SetOptions(ctx context.Context, req *databroker.SetOptionsReq
|
|||
_, span := trace.StartSpan(ctx, "databroker.grpc.SetOptions")
|
||||
defer span.End()
|
||||
|
||||
backend, _, err := srv.getBackend()
|
||||
backend, err := srv.getBackend()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -226,29 +183,25 @@ func (srv *Server) SetOptions(ctx context.Context, req *databroker.SetOptionsReq
|
|||
|
||||
// Sync streams updates for the given record type.
|
||||
func (srv *Server) Sync(req *databroker.SyncRequest, stream databroker.DataBrokerService_SyncServer) error {
|
||||
_, span := trace.StartSpan(stream.Context(), "databroker.grpc.Sync")
|
||||
ctx := stream.Context()
|
||||
ctx, span := trace.StartSpan(ctx, "databroker.grpc.Sync")
|
||||
defer span.End()
|
||||
log.Info(stream.Context()).
|
||||
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
defer cancel()
|
||||
|
||||
log.Info(ctx).
|
||||
Str("peer", grpcutil.GetPeerAddr(stream.Context())).
|
||||
Uint64("server_version", req.GetServerVersion()).
|
||||
Uint64("record_version", req.GetRecordVersion()).
|
||||
Msg("sync")
|
||||
|
||||
backend, serverVersion, err := srv.getBackend()
|
||||
backend, err := srv.getBackend()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// reset record version if the server versions don't match
|
||||
if req.GetServerVersion() != serverVersion {
|
||||
return status.Errorf(codes.Aborted, "invalid server version, got %d, expected: %d", req.GetServerVersion(), serverVersion)
|
||||
}
|
||||
|
||||
ctx := stream.Context()
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
defer cancel()
|
||||
|
||||
recordStream, err := backend.Sync(ctx, req.GetRecordVersion())
|
||||
recordStream, err := backend.Sync(ctx, req.GetServerVersion(), req.GetRecordVersion())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -256,8 +209,7 @@ func (srv *Server) Sync(req *databroker.SyncRequest, stream databroker.DataBroke
|
|||
|
||||
for recordStream.Next(true) {
|
||||
err = stream.Send(&databroker.SyncResponse{
|
||||
ServerVersion: serverVersion,
|
||||
Record: recordStream.Record(),
|
||||
Record: recordStream.Record(),
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
|
@ -269,23 +221,24 @@ func (srv *Server) Sync(req *databroker.SyncRequest, stream databroker.DataBroke
|
|||
|
||||
// SyncLatest returns the latest value of every record in the databroker as a stream of records.
|
||||
func (srv *Server) SyncLatest(req *databroker.SyncLatestRequest, stream databroker.DataBrokerService_SyncLatestServer) error {
|
||||
_, span := trace.StartSpan(stream.Context(), "databroker.grpc.SyncLatest")
|
||||
ctx := stream.Context()
|
||||
ctx, span := trace.StartSpan(ctx, "databroker.grpc.SyncLatest")
|
||||
defer span.End()
|
||||
log.Info(stream.Context()).
|
||||
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
defer cancel()
|
||||
|
||||
log.Info(ctx).
|
||||
Str("peer", grpcutil.GetPeerAddr(stream.Context())).
|
||||
Str("type", req.GetType()).
|
||||
Msg("sync latest")
|
||||
|
||||
backend, serverVersion, err := srv.getBackend()
|
||||
backend, err := srv.getBackend()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
ctx := stream.Context()
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
defer cancel()
|
||||
|
||||
records, latestRecordVersion, err := backend.GetAll(ctx)
|
||||
records, versions, err := backend.GetAll(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -306,25 +259,20 @@ func (srv *Server) SyncLatest(req *databroker.SyncLatestRequest, stream databrok
|
|||
// always send the server version last in case there are no records
|
||||
return stream.Send(&databroker.SyncLatestResponse{
|
||||
Response: &databroker.SyncLatestResponse_Versions{
|
||||
Versions: &databroker.Versions{
|
||||
ServerVersion: serverVersion,
|
||||
LatestRecordVersion: latestRecordVersion,
|
||||
},
|
||||
Versions: versions,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
func (srv *Server) getBackend() (backend storage.Backend, version uint64, err error) {
|
||||
func (srv *Server) getBackend() (backend storage.Backend, err error) {
|
||||
// double-checked locking:
|
||||
// first try the read lock, then re-try with the write lock, and finally create a new backend if nil
|
||||
srv.mu.RLock()
|
||||
backend = srv.backend
|
||||
version = srv.version
|
||||
srv.mu.RUnlock()
|
||||
if backend == nil {
|
||||
srv.mu.Lock()
|
||||
backend = srv.backend
|
||||
version = srv.version
|
||||
var err error
|
||||
if backend == nil {
|
||||
backend, err = srv.newBackendLocked()
|
||||
|
@ -332,24 +280,10 @@ func (srv *Server) getBackend() (backend storage.Backend, version uint64, err er
|
|||
}
|
||||
srv.mu.Unlock()
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return backend, version, nil
|
||||
}
|
||||
|
||||
func (srv *Server) getBackendLocked() (backend storage.Backend, version uint64, err error) {
|
||||
backend = srv.backend
|
||||
version = srv.version
|
||||
if backend == nil {
|
||||
var err error
|
||||
backend, err = srv.newBackendLocked()
|
||||
srv.backend = backend
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
}
|
||||
return backend, version, nil
|
||||
return backend, nil
|
||||
}
|
||||
|
||||
func (srv *Server) newBackendLocked() (backend storage.Backend, err error) {
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue