package databroker import ( "context" "fmt" "io" "github.com/pomerium/pomerium/config" "github.com/pomerium/pomerium/internal/log" "github.com/pomerium/pomerium/internal/registry" "github.com/pomerium/pomerium/internal/registry/inmemory" "github.com/pomerium/pomerium/internal/registry/redis" "github.com/pomerium/pomerium/internal/telemetry/trace" registrypb "github.com/pomerium/pomerium/pkg/grpc/registry" "github.com/pomerium/pomerium/pkg/storage" ) type registryWatchServer struct { registrypb.Registry_WatchServer ctx context.Context } func (stream registryWatchServer) Context() context.Context { return stream.ctx } // Report calls the registry Report method. func (srv *Server) Report(ctx context.Context, req *registrypb.RegisterRequest) (*registrypb.RegisterResponse, error) { ctx, span := trace.StartSpan(ctx, "databroker.grpc.Report") defer span.End() r, err := srv.getRegistry() if err != nil { return nil, err } return r.Report(ctx, req) } // List calls the registry List method. func (srv *Server) List(ctx context.Context, req *registrypb.ListRequest) (*registrypb.ServiceList, error) { ctx, span := trace.StartSpan(ctx, "databroker.grpc.List") defer span.End() r, err := srv.getRegistry() if err != nil { return nil, err } return r.List(ctx, req) } // Watch calls the registry Watch method. func (srv *Server) Watch(req *registrypb.ListRequest, stream registrypb.Registry_WatchServer) error { ctx := stream.Context() ctx, span := trace.StartSpan(ctx, "databroker.grpc.Watch") defer span.End() r, err := srv.getRegistry() if err != nil { return err } return r.Watch(req, registryWatchServer{ Registry_WatchServer: stream, ctx: ctx, }) } func (srv *Server) getRegistry() (registry.Interface, error) { backend, err := srv.getBackend() if err != nil { return nil, err } // double-checked locking srv.mu.RLock() r := srv.registry srv.mu.RUnlock() if r == nil { srv.mu.Lock() r = srv.registry var err error if r == nil { r, err = srv.newRegistryLocked(backend) srv.registry = r } srv.mu.Unlock() if err != nil { return nil, err } } return r, nil } func (srv *Server) newRegistryLocked(backend storage.Backend) (registry.Interface, error) { ctx := context.Background() if hasRegistryServer, ok := backend.(interface { RegistryServer() registrypb.RegistryServer }); ok { log.Info(ctx).Msg("using registry via storage") return struct { io.Closer registrypb.RegistryServer }{backend, hasRegistryServer.RegistryServer()}, nil } switch srv.cfg.storageType { case config.StorageInMemoryName: log.Info(ctx).Msg("using in-memory registry") return inmemory.New(ctx, srv.cfg.registryTTL), nil case config.StorageRedisName: log.Info(ctx).Msg("using redis registry") r, err := redis.New( srv.cfg.storageConnectionString, redis.WithTLSConfig(srv.getTLSConfigLocked(ctx)), ) if err != nil { return nil, fmt.Errorf("failed to create new redis registry: %w", err) } return r, nil } return nil, fmt.Errorf("unsupported registry type: %s", srv.cfg.storageType) }