// Package databroker contains a data broker implementation. package databroker import ( "context" "errors" "fmt" "strings" "sync" "github.com/google/go-cmp/cmp" "github.com/google/uuid" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" "google.golang.org/protobuf/types/known/emptypb" "github.com/pomerium/pomerium/config" "github.com/pomerium/pomerium/internal/log" "github.com/pomerium/pomerium/internal/registry" "github.com/pomerium/pomerium/pkg/grpc/databroker" "github.com/pomerium/pomerium/pkg/storage" "github.com/pomerium/pomerium/pkg/storage/inmemory" "github.com/pomerium/pomerium/pkg/storage/postgres" "github.com/pomerium/pomerium/pkg/telemetry/trace" oteltrace "go.opentelemetry.io/otel/trace" ) // Server implements the databroker service using an in memory database. type Server struct { cfg *serverConfig mu sync.RWMutex backend storage.Backend backendCtx context.Context registry registry.Interface tracerProvider oteltrace.TracerProvider tracer oteltrace.Tracer } // New creates a new server. func New(ctx context.Context, tracerProvider oteltrace.TracerProvider, options ...ServerOption) *Server { tracer := tracerProvider.Tracer(trace.PomeriumCoreTracer) srv := &Server{ backendCtx: ctx, tracerProvider: tracerProvider, tracer: tracer, } srv.UpdateConfig(ctx, options...) return srv } // UpdateConfig updates the server with the new options. func (srv *Server) UpdateConfig(ctx context.Context, options ...ServerOption) { srv.mu.Lock() defer srv.mu.Unlock() cfg := newServerConfig(options...) if cmp.Equal(cfg, srv.cfg, cmp.AllowUnexported(serverConfig{})) { log.Ctx(ctx).Debug().Msg("databroker: no changes detected, re-using existing DBs") return } srv.cfg = cfg if srv.backend != nil { err := srv.backend.Close() if err != nil { log.Ctx(ctx).Error().Err(err).Msg("databroker: error closing backend") } srv.backend = nil } if srv.registry != nil { err := srv.registry.Close() if err != nil { log.Ctx(ctx).Error().Err(err).Msg("databroker: error closing registry") } srv.registry = nil } } // AcquireLease acquires a lease. func (srv *Server) AcquireLease(ctx context.Context, req *databroker.AcquireLeaseRequest) (*databroker.AcquireLeaseResponse, error) { ctx, span := srv.tracer.Start(ctx, "databroker.grpc.AcquireLease") defer span.End() log.Ctx(ctx).Debug(). Str("name", req.GetName()). Dur("duration", req.GetDuration().AsDuration()). Msg("acquire lease") db, err := srv.getBackend(ctx) if err != nil { return nil, err } leaseID := uuid.NewString() acquired, err := db.Lease(ctx, req.GetName(), leaseID, req.GetDuration().AsDuration()) if err != nil { return nil, err } else if !acquired { return nil, status.Error(codes.AlreadyExists, "lease is already taken") } return &databroker.AcquireLeaseResponse{ Id: leaseID, }, nil } // Get gets a record from the in-memory list. func (srv *Server) Get(ctx context.Context, req *databroker.GetRequest) (*databroker.GetResponse, error) { ctx, span := srv.tracer.Start(ctx, "databroker.grpc.Get") defer span.End() log.Ctx(ctx).Debug(). Str("type", req.GetType()). Str("id", req.GetId()). Msg("get") db, err := srv.getBackend(ctx) if err != nil { return nil, err } record, err := db.Get(ctx, req.GetType(), req.GetId()) switch { case errors.Is(err, storage.ErrNotFound): return nil, status.Error(codes.NotFound, "record not found") case err != nil: return nil, status.Error(codes.Internal, err.Error()) case record.DeletedAt != nil: return nil, status.Error(codes.NotFound, "record not found") } return &databroker.GetResponse{ Record: record, }, nil } // ListTypes lists all the record types. func (srv *Server) ListTypes(ctx context.Context, _ *emptypb.Empty) (*databroker.ListTypesResponse, error) { ctx, span := srv.tracer.Start(ctx, "databroker.grpc.ListTypes") defer span.End() log.Ctx(ctx).Debug().Msg("list types") db, err := srv.getBackend(ctx) if err != nil { return nil, err } types, err := db.ListTypes(ctx) if err != nil { return nil, err } return &databroker.ListTypesResponse{Types: types}, nil } // Query queries for records. func (srv *Server) Query(ctx context.Context, req *databroker.QueryRequest) (*databroker.QueryResponse, error) { ctx, span := srv.tracer.Start(ctx, "databroker.grpc.Query") defer span.End() log.Ctx(ctx).Debug(). Str("type", req.GetType()). Str("query", req.GetQuery()). Int64("offset", req.GetOffset()). Int64("limit", req.GetLimit()). Interface("filter", req.GetFilter()). Msg("query") query := strings.ToLower(req.GetQuery()) db, err := srv.getBackend(ctx) if err != nil { return nil, err } expr, err := storage.FilterExpressionFromStruct(req.GetFilter()) if err != nil { return nil, status.Errorf(codes.InvalidArgument, "invalid query filter: %v", err) } serverVersion, recordVersion, stream, err := db.SyncLatest(ctx, req.GetType(), expr) if err != nil { return nil, err } defer stream.Close() var filtered []*databroker.Record for stream.Next(false) { record := stream.Record() if query != "" && !storage.MatchAny(record.GetData(), query) { continue } filtered = append(filtered, record) } if stream.Err() != nil { return nil, stream.Err() } records, totalCount := databroker.ApplyOffsetAndLimit(filtered, int(req.GetOffset()), int(req.GetLimit())) return &databroker.QueryResponse{ Records: records, TotalCount: int64(totalCount), ServerVersion: serverVersion, RecordVersion: recordVersion, }, nil } // Put updates an existing record or adds a new one. func (srv *Server) Put(ctx context.Context, req *databroker.PutRequest) (*databroker.PutResponse, error) { ctx, span := srv.tracer.Start(ctx, "databroker.grpc.Put") defer span.End() records := req.GetRecords() if len(records) == 1 { log.Ctx(ctx).Debug(). Str("record-type", records[0].GetType()). Str("record-id", records[0].GetId()). Msg("put") } else { var recordType string for _, record := range records { recordType = record.GetType() } log.Ctx(ctx).Debug(). Int("record-count", len(records)). Str("record-type", recordType). Msg("put") } db, err := srv.getBackend(ctx) if err != nil { return nil, err } serverVersion, err := db.Put(ctx, records) if err != nil { return nil, err } res := &databroker.PutResponse{ ServerVersion: serverVersion, Records: records, } return res, nil } // Patch updates specific fields of an existing record. func (srv *Server) Patch(ctx context.Context, req *databroker.PatchRequest) (*databroker.PatchResponse, error) { ctx, span := srv.tracer.Start(ctx, "databroker.grpc.Patch") defer span.End() records := req.GetRecords() if len(records) == 1 { log.Ctx(ctx).Debug(). Str("record-type", records[0].GetType()). Str("record-id", records[0].GetId()). Msg("patch") } else { var recordType string for _, record := range records { recordType = record.GetType() } log.Ctx(ctx).Debug(). Int("record-count", len(records)). Str("record-type", recordType). Msg("patch") } db, err := srv.getBackend(ctx) if err != nil { return nil, err } serverVersion, patchedRecords, err := db.Patch(ctx, records, req.GetFieldMask()) if err != nil { return nil, err } res := &databroker.PatchResponse{ ServerVersion: serverVersion, Records: patchedRecords, } return res, nil } // ReleaseLease releases a lease. func (srv *Server) ReleaseLease(ctx context.Context, req *databroker.ReleaseLeaseRequest) (*emptypb.Empty, error) { ctx, span := srv.tracer.Start(ctx, "databroker.grpc.ReleaseLease") defer span.End() log.Ctx(ctx).Trace(). Str("name", req.GetName()). Str("id", req.GetId()). Msg("release lease") db, err := srv.getBackend(ctx) if err != nil { return nil, err } _, err = db.Lease(ctx, req.GetName(), req.GetId(), -1) if err != nil { return nil, err } return new(emptypb.Empty), nil } // RenewLease releases a lease. func (srv *Server) RenewLease(ctx context.Context, req *databroker.RenewLeaseRequest) (*emptypb.Empty, error) { ctx, span := srv.tracer.Start(ctx, "databroker.grpc.RenewLease") defer span.End() log.Ctx(ctx).Trace(). Str("name", req.GetName()). Str("id", req.GetId()). Dur("duration", req.GetDuration().AsDuration()). Msg("renew lease") db, err := srv.getBackend(ctx) if err != nil { return nil, err } acquired, err := db.Lease(ctx, req.GetName(), req.GetId(), req.GetDuration().AsDuration()) if err != nil { return nil, err } else if !acquired { return nil, status.Error(codes.AlreadyExists, "lease no longer held") } return new(emptypb.Empty), nil } // SetOptions sets options for a type in the databroker. func (srv *Server) SetOptions(ctx context.Context, req *databroker.SetOptionsRequest) (*databroker.SetOptionsResponse, error) { ctx, span := srv.tracer.Start(ctx, "databroker.grpc.SetOptions") defer span.End() backend, err := srv.getBackend(ctx) if err != nil { return nil, err } err = backend.SetOptions(ctx, req.GetType(), req.GetOptions()) if err != nil { return nil, err } options, err := backend.GetOptions(ctx, req.GetType()) if err != nil { return nil, err } return &databroker.SetOptionsResponse{ Options: options, }, nil } // Sync streams updates for the given record type. func (srv *Server) Sync(req *databroker.SyncRequest, stream databroker.DataBrokerService_SyncServer) error { ctx := stream.Context() ctx, span := srv.tracer.Start(ctx, "databroker.grpc.Sync") defer span.End() ctx, cancel := context.WithCancel(ctx) defer cancel() log.Ctx(ctx). Debug(). Uint64("server_version", req.GetServerVersion()). Uint64("record_version", req.GetRecordVersion()). Msg("sync") backend, err := srv.getBackend(ctx) if err != nil { return err } recordStream, err := backend.Sync(ctx, req.GetType(), req.GetServerVersion(), req.GetRecordVersion()) if err != nil { return err } defer func() { _ = recordStream.Close() }() for recordStream.Next(true) { err = stream.Send(&databroker.SyncResponse{ Record: recordStream.Record(), }) if err != nil { return err } } return recordStream.Err() } // SyncLatest returns the latest value of every record in the databroker as a stream of records. func (srv *Server) SyncLatest(req *databroker.SyncLatestRequest, stream databroker.DataBrokerService_SyncLatestServer) error { ctx := stream.Context() ctx, span := srv.tracer.Start(ctx, "databroker.grpc.SyncLatest") defer span.End() ctx, cancel := context.WithCancel(ctx) defer cancel() log.Ctx(ctx).Debug(). Str("type", req.GetType()). Msg("sync latest") backend, err := srv.getBackend(ctx) if err != nil { return err } serverVersion, recordVersion, recordStream, err := backend.SyncLatest(ctx, req.GetType(), nil) if err != nil { return err } for recordStream.Next(false) { record := recordStream.Record() if req.GetType() == "" || req.GetType() == record.GetType() { err = stream.Send(&databroker.SyncLatestResponse{ Response: &databroker.SyncLatestResponse_Record{ Record: record, }, }) if err != nil { return err } } } if recordStream.Err() != nil { return err } // always send the server version last in case there are no records return stream.Send(&databroker.SyncLatestResponse{ Response: &databroker.SyncLatestResponse_Versions{ Versions: &databroker.Versions{ ServerVersion: serverVersion, LatestRecordVersion: recordVersion, }, }, }) } func (srv *Server) getBackend(ctx context.Context) (backend storage.Backend, err error) { // double-checked locking: // first try the read lock, then re-try with the write lock, and finally create a new backend if nil srv.mu.RLock() backend = srv.backend srv.mu.RUnlock() if backend == nil { srv.mu.Lock() backend = srv.backend var err error if backend == nil { backend, err = srv.newBackendLocked(ctx) srv.backend = backend } srv.mu.Unlock() if err != nil { return nil, err } } return backend, nil } func (srv *Server) newBackendLocked(ctx context.Context) (storage.Backend, error) { switch srv.cfg.storageType { case config.StorageInMemoryName: log.Ctx(ctx).Info().Msg("initializing new in-memory store") return inmemory.New(), nil case config.StoragePostgresName: log.Ctx(ctx).Info().Msg("initializing new postgres store") // NB: the context passed to postgres.New here is a separate context scoped // to the lifetime of the server itself. 'ctx' may be a short-lived request // context, since the backend is lazy-initialized. return postgres.New(srv.backendCtx, srv.cfg.storageConnectionString, postgres.WithTracerProvider(srv.tracerProvider)), nil default: return nil, fmt.Errorf("unsupported storage type: %s", srv.cfg.storageType) } }