mirror of
https://github.com/pomerium/pomerium.git
synced 2025-04-29 18:36:30 +02:00
This also replaces instances where we manually write "return ctx.Err()" with "return context.Cause(ctx)" which is functionally identical, but will also correctly propagate cause errors if present.
471 lines
12 KiB
Go
471 lines
12 KiB
Go
// Package databroker contains a data broker implementation.
|
|
package databroker
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"fmt"
|
|
"strings"
|
|
"sync"
|
|
|
|
"github.com/google/go-cmp/cmp"
|
|
"github.com/google/uuid"
|
|
"google.golang.org/grpc/codes"
|
|
"google.golang.org/grpc/status"
|
|
"google.golang.org/protobuf/types/known/emptypb"
|
|
|
|
"github.com/pomerium/pomerium/config"
|
|
"github.com/pomerium/pomerium/internal/log"
|
|
"github.com/pomerium/pomerium/internal/registry"
|
|
"github.com/pomerium/pomerium/internal/telemetry/trace"
|
|
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
|
"github.com/pomerium/pomerium/pkg/storage"
|
|
"github.com/pomerium/pomerium/pkg/storage/inmemory"
|
|
"github.com/pomerium/pomerium/pkg/storage/postgres"
|
|
)
|
|
|
|
// Server implements the databroker service using an in memory database.
|
|
type Server struct {
|
|
cfg *serverConfig
|
|
|
|
mu sync.RWMutex
|
|
backend storage.Backend
|
|
backendCtx context.Context
|
|
registry registry.Interface
|
|
}
|
|
|
|
// New creates a new server.
|
|
func New(ctx context.Context, options ...ServerOption) *Server {
|
|
srv := &Server{
|
|
backendCtx: ctx,
|
|
}
|
|
srv.UpdateConfig(ctx, options...)
|
|
return srv
|
|
}
|
|
|
|
// UpdateConfig updates the server with the new options.
|
|
func (srv *Server) UpdateConfig(ctx context.Context, options ...ServerOption) {
|
|
srv.mu.Lock()
|
|
defer srv.mu.Unlock()
|
|
|
|
cfg := newServerConfig(options...)
|
|
if cmp.Equal(cfg, srv.cfg, cmp.AllowUnexported(serverConfig{})) {
|
|
log.Ctx(ctx).Debug().Msg("databroker: no changes detected, re-using existing DBs")
|
|
return
|
|
}
|
|
srv.cfg = cfg
|
|
|
|
if srv.backend != nil {
|
|
err := srv.backend.Close()
|
|
if err != nil {
|
|
log.Ctx(ctx).Error().Err(err).Msg("databroker: error closing backend")
|
|
}
|
|
srv.backend = nil
|
|
}
|
|
|
|
if srv.registry != nil {
|
|
err := srv.registry.Close()
|
|
if err != nil {
|
|
log.Ctx(ctx).Error().Err(err).Msg("databroker: error closing registry")
|
|
}
|
|
srv.registry = nil
|
|
}
|
|
}
|
|
|
|
// AcquireLease acquires a lease.
|
|
func (srv *Server) AcquireLease(ctx context.Context, req *databroker.AcquireLeaseRequest) (*databroker.AcquireLeaseResponse, error) {
|
|
ctx, span := trace.StartSpan(ctx, "databroker.grpc.AcquireLease")
|
|
defer span.End()
|
|
log.Ctx(ctx).Debug().
|
|
Str("name", req.GetName()).
|
|
Dur("duration", req.GetDuration().AsDuration()).
|
|
Msg("acquire lease")
|
|
|
|
db, err := srv.getBackend(ctx)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
leaseID := uuid.NewString()
|
|
acquired, err := db.Lease(ctx, req.GetName(), leaseID, req.GetDuration().AsDuration())
|
|
if err != nil {
|
|
return nil, err
|
|
} else if !acquired {
|
|
return nil, status.Error(codes.AlreadyExists, "lease is already taken")
|
|
}
|
|
|
|
return &databroker.AcquireLeaseResponse{
|
|
Id: leaseID,
|
|
}, nil
|
|
}
|
|
|
|
// Get gets a record from the in-memory list.
|
|
func (srv *Server) Get(ctx context.Context, req *databroker.GetRequest) (*databroker.GetResponse, error) {
|
|
ctx, span := trace.StartSpan(ctx, "databroker.grpc.Get")
|
|
defer span.End()
|
|
log.Ctx(ctx).Debug().
|
|
Str("type", req.GetType()).
|
|
Str("id", req.GetId()).
|
|
Msg("get")
|
|
|
|
db, err := srv.getBackend(ctx)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
record, err := db.Get(ctx, req.GetType(), req.GetId())
|
|
switch {
|
|
case errors.Is(err, storage.ErrNotFound):
|
|
return nil, status.Error(codes.NotFound, "record not found")
|
|
case err != nil:
|
|
return nil, status.Error(codes.Internal, err.Error())
|
|
case record.DeletedAt != nil:
|
|
return nil, status.Error(codes.NotFound, "record not found")
|
|
}
|
|
return &databroker.GetResponse{
|
|
Record: record,
|
|
}, nil
|
|
}
|
|
|
|
// ListTypes lists all the record types.
|
|
func (srv *Server) ListTypes(ctx context.Context, _ *emptypb.Empty) (*databroker.ListTypesResponse, error) {
|
|
ctx, span := trace.StartSpan(ctx, "databroker.grpc.ListTypes")
|
|
defer span.End()
|
|
log.Ctx(ctx).Debug().Msg("list types")
|
|
|
|
db, err := srv.getBackend(ctx)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
types, err := db.ListTypes(ctx)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return &databroker.ListTypesResponse{Types: types}, nil
|
|
}
|
|
|
|
// Query queries for records.
|
|
func (srv *Server) Query(ctx context.Context, req *databroker.QueryRequest) (*databroker.QueryResponse, error) {
|
|
ctx, span := trace.StartSpan(ctx, "databroker.grpc.Query")
|
|
defer span.End()
|
|
log.Ctx(ctx).Debug().
|
|
Str("type", req.GetType()).
|
|
Str("query", req.GetQuery()).
|
|
Int64("offset", req.GetOffset()).
|
|
Int64("limit", req.GetLimit()).
|
|
Interface("filter", req.GetFilter()).
|
|
Msg("query")
|
|
|
|
query := strings.ToLower(req.GetQuery())
|
|
|
|
db, err := srv.getBackend(ctx)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
expr, err := storage.FilterExpressionFromStruct(req.GetFilter())
|
|
if err != nil {
|
|
return nil, status.Errorf(codes.InvalidArgument, "invalid query filter: %v", err)
|
|
}
|
|
|
|
serverVersion, recordVersion, stream, err := db.SyncLatest(ctx, req.GetType(), expr)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer stream.Close()
|
|
|
|
var filtered []*databroker.Record
|
|
for stream.Next(false) {
|
|
record := stream.Record()
|
|
|
|
if query != "" && !storage.MatchAny(record.GetData(), query) {
|
|
continue
|
|
}
|
|
|
|
filtered = append(filtered, record)
|
|
}
|
|
if stream.Err() != nil {
|
|
return nil, stream.Err()
|
|
}
|
|
|
|
records, totalCount := databroker.ApplyOffsetAndLimit(filtered, int(req.GetOffset()), int(req.GetLimit()))
|
|
return &databroker.QueryResponse{
|
|
Records: records,
|
|
TotalCount: int64(totalCount),
|
|
ServerVersion: serverVersion,
|
|
RecordVersion: recordVersion,
|
|
}, nil
|
|
}
|
|
|
|
// Put updates an existing record or adds a new one.
|
|
func (srv *Server) Put(ctx context.Context, req *databroker.PutRequest) (*databroker.PutResponse, error) {
|
|
ctx, span := trace.StartSpan(ctx, "databroker.grpc.Put")
|
|
defer span.End()
|
|
|
|
records := req.GetRecords()
|
|
if len(records) == 1 {
|
|
log.Ctx(ctx).Debug().
|
|
Str("record-type", records[0].GetType()).
|
|
Str("record-id", records[0].GetId()).
|
|
Msg("put")
|
|
} else {
|
|
var recordType string
|
|
for _, record := range records {
|
|
recordType = record.GetType()
|
|
}
|
|
log.Ctx(ctx).Debug().
|
|
Int("record-count", len(records)).
|
|
Str("record-type", recordType).
|
|
Msg("put")
|
|
}
|
|
|
|
db, err := srv.getBackend(ctx)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
serverVersion, err := db.Put(ctx, records)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
res := &databroker.PutResponse{
|
|
ServerVersion: serverVersion,
|
|
Records: records,
|
|
}
|
|
|
|
return res, nil
|
|
}
|
|
|
|
// Patch updates specific fields of an existing record.
|
|
func (srv *Server) Patch(ctx context.Context, req *databroker.PatchRequest) (*databroker.PatchResponse, error) {
|
|
ctx, span := trace.StartSpan(ctx, "databroker.grpc.Patch")
|
|
defer span.End()
|
|
|
|
records := req.GetRecords()
|
|
if len(records) == 1 {
|
|
log.Ctx(ctx).Debug().
|
|
Str("record-type", records[0].GetType()).
|
|
Str("record-id", records[0].GetId()).
|
|
Msg("patch")
|
|
} else {
|
|
var recordType string
|
|
for _, record := range records {
|
|
recordType = record.GetType()
|
|
}
|
|
log.Ctx(ctx).Debug().
|
|
Int("record-count", len(records)).
|
|
Str("record-type", recordType).
|
|
Msg("patch")
|
|
}
|
|
|
|
db, err := srv.getBackend(ctx)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
serverVersion, patchedRecords, err := db.Patch(ctx, records, req.GetFieldMask())
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
res := &databroker.PatchResponse{
|
|
ServerVersion: serverVersion,
|
|
Records: patchedRecords,
|
|
}
|
|
|
|
return res, nil
|
|
}
|
|
|
|
// ReleaseLease releases a lease.
|
|
func (srv *Server) ReleaseLease(ctx context.Context, req *databroker.ReleaseLeaseRequest) (*emptypb.Empty, error) {
|
|
ctx, span := trace.StartSpan(ctx, "databroker.grpc.ReleaseLease")
|
|
defer span.End()
|
|
log.Ctx(ctx).Debug().
|
|
Str("name", req.GetName()).
|
|
Str("id", req.GetId()).
|
|
Msg("release lease")
|
|
|
|
db, err := srv.getBackend(ctx)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
_, err = db.Lease(ctx, req.GetName(), req.GetId(), -1)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return new(emptypb.Empty), nil
|
|
}
|
|
|
|
// RenewLease releases a lease.
|
|
func (srv *Server) RenewLease(ctx context.Context, req *databroker.RenewLeaseRequest) (*emptypb.Empty, error) {
|
|
ctx, span := trace.StartSpan(ctx, "databroker.grpc.RenewLease")
|
|
defer span.End()
|
|
log.Ctx(ctx).Debug().
|
|
Str("name", req.GetName()).
|
|
Str("id", req.GetId()).
|
|
Dur("duration", req.GetDuration().AsDuration()).
|
|
Msg("renew lease")
|
|
|
|
db, err := srv.getBackend(ctx)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
acquired, err := db.Lease(ctx, req.GetName(), req.GetId(), req.GetDuration().AsDuration())
|
|
if err != nil {
|
|
return nil, err
|
|
} else if !acquired {
|
|
return nil, status.Error(codes.AlreadyExists, "lease no longer held")
|
|
}
|
|
|
|
return new(emptypb.Empty), nil
|
|
}
|
|
|
|
// SetOptions sets options for a type in the databroker.
|
|
func (srv *Server) SetOptions(ctx context.Context, req *databroker.SetOptionsRequest) (*databroker.SetOptionsResponse, error) {
|
|
ctx, span := trace.StartSpan(ctx, "databroker.grpc.SetOptions")
|
|
defer span.End()
|
|
|
|
backend, err := srv.getBackend(ctx)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
err = backend.SetOptions(ctx, req.GetType(), req.GetOptions())
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
options, err := backend.GetOptions(ctx, req.GetType())
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return &databroker.SetOptionsResponse{
|
|
Options: options,
|
|
}, nil
|
|
}
|
|
|
|
// Sync streams updates for the given record type.
|
|
func (srv *Server) Sync(req *databroker.SyncRequest, stream databroker.DataBrokerService_SyncServer) error {
|
|
ctx := stream.Context()
|
|
ctx, span := trace.StartSpan(ctx, "databroker.grpc.Sync")
|
|
defer span.End()
|
|
|
|
ctx, cancel := context.WithCancel(ctx)
|
|
defer cancel()
|
|
|
|
log.Ctx(ctx).
|
|
Debug().
|
|
Uint64("server_version", req.GetServerVersion()).
|
|
Uint64("record_version", req.GetRecordVersion()).
|
|
Msg("sync")
|
|
|
|
backend, err := srv.getBackend(ctx)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
recordStream, err := backend.Sync(ctx, req.GetType(), req.GetServerVersion(), req.GetRecordVersion())
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer func() { _ = recordStream.Close() }()
|
|
|
|
for recordStream.Next(true) {
|
|
err = stream.Send(&databroker.SyncResponse{
|
|
Record: recordStream.Record(),
|
|
})
|
|
if err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
return recordStream.Err()
|
|
}
|
|
|
|
// SyncLatest returns the latest value of every record in the databroker as a stream of records.
|
|
func (srv *Server) SyncLatest(req *databroker.SyncLatestRequest, stream databroker.DataBrokerService_SyncLatestServer) error {
|
|
ctx := stream.Context()
|
|
ctx, span := trace.StartSpan(ctx, "databroker.grpc.SyncLatest")
|
|
defer span.End()
|
|
|
|
ctx, cancel := context.WithCancel(ctx)
|
|
defer cancel()
|
|
|
|
log.Ctx(ctx).Debug().
|
|
Str("type", req.GetType()).
|
|
Msg("sync latest")
|
|
|
|
backend, err := srv.getBackend(ctx)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
serverVersion, recordVersion, recordStream, err := backend.SyncLatest(ctx, req.GetType(), nil)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
for recordStream.Next(false) {
|
|
record := recordStream.Record()
|
|
if req.GetType() == "" || req.GetType() == record.GetType() {
|
|
err = stream.Send(&databroker.SyncLatestResponse{
|
|
Response: &databroker.SyncLatestResponse_Record{
|
|
Record: record,
|
|
},
|
|
})
|
|
if err != nil {
|
|
return err
|
|
}
|
|
}
|
|
}
|
|
if recordStream.Err() != nil {
|
|
return err
|
|
}
|
|
|
|
// always send the server version last in case there are no records
|
|
return stream.Send(&databroker.SyncLatestResponse{
|
|
Response: &databroker.SyncLatestResponse_Versions{
|
|
Versions: &databroker.Versions{
|
|
ServerVersion: serverVersion,
|
|
LatestRecordVersion: recordVersion,
|
|
},
|
|
},
|
|
})
|
|
}
|
|
|
|
func (srv *Server) getBackend(ctx context.Context) (backend storage.Backend, err error) {
|
|
// double-checked locking:
|
|
// first try the read lock, then re-try with the write lock, and finally create a new backend if nil
|
|
srv.mu.RLock()
|
|
backend = srv.backend
|
|
srv.mu.RUnlock()
|
|
if backend == nil {
|
|
srv.mu.Lock()
|
|
backend = srv.backend
|
|
var err error
|
|
if backend == nil {
|
|
backend, err = srv.newBackendLocked(ctx)
|
|
srv.backend = backend
|
|
}
|
|
srv.mu.Unlock()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
return backend, nil
|
|
}
|
|
|
|
func (srv *Server) newBackendLocked(ctx context.Context) (storage.Backend, error) {
|
|
switch srv.cfg.storageType {
|
|
case config.StorageInMemoryName:
|
|
log.Ctx(ctx).Info().Msg("initializing new in-memory store")
|
|
return inmemory.New(), nil
|
|
case config.StoragePostgresName:
|
|
log.Ctx(ctx).Info().Msg("initializing new postgres store")
|
|
// NB: the context passed to postgres.New here is a separate context scoped
|
|
// to the lifetime of the server itself. 'ctx' may be a short-lived request
|
|
// context, since the backend is lazy-initialized.
|
|
return postgres.New(srv.backendCtx, srv.cfg.storageConnectionString), nil
|
|
default:
|
|
return nil, fmt.Errorf("unsupported storage type: %s", srv.cfg.storageType)
|
|
}
|
|
}
|