mirror of
https://github.com/pomerium/pomerium.git
synced 2025-04-29 18:36:30 +02:00
126 lines
3 KiB
Go
126 lines
3 KiB
Go
package databroker
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"io"
|
|
|
|
"github.com/pomerium/pomerium/config"
|
|
"github.com/pomerium/pomerium/internal/log"
|
|
"github.com/pomerium/pomerium/internal/registry"
|
|
"github.com/pomerium/pomerium/internal/registry/inmemory"
|
|
"github.com/pomerium/pomerium/internal/registry/redis"
|
|
"github.com/pomerium/pomerium/internal/telemetry/trace"
|
|
registrypb "github.com/pomerium/pomerium/pkg/grpc/registry"
|
|
"github.com/pomerium/pomerium/pkg/storage"
|
|
)
|
|
|
|
type registryWatchServer struct {
|
|
registrypb.Registry_WatchServer
|
|
ctx context.Context
|
|
}
|
|
|
|
func (stream registryWatchServer) Context() context.Context {
|
|
return stream.ctx
|
|
}
|
|
|
|
// Report calls the registry Report method.
|
|
func (srv *Server) Report(ctx context.Context, req *registrypb.RegisterRequest) (*registrypb.RegisterResponse, error) {
|
|
ctx, span := trace.StartSpan(ctx, "databroker.grpc.Report")
|
|
defer span.End()
|
|
|
|
r, err := srv.getRegistry()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return r.Report(ctx, req)
|
|
}
|
|
|
|
// List calls the registry List method.
|
|
func (srv *Server) List(ctx context.Context, req *registrypb.ListRequest) (*registrypb.ServiceList, error) {
|
|
ctx, span := trace.StartSpan(ctx, "databroker.grpc.List")
|
|
defer span.End()
|
|
|
|
r, err := srv.getRegistry()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return r.List(ctx, req)
|
|
}
|
|
|
|
// Watch calls the registry Watch method.
|
|
func (srv *Server) Watch(req *registrypb.ListRequest, stream registrypb.Registry_WatchServer) error {
|
|
ctx := stream.Context()
|
|
ctx, span := trace.StartSpan(ctx, "databroker.grpc.Watch")
|
|
defer span.End()
|
|
|
|
r, err := srv.getRegistry()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
return r.Watch(req, registryWatchServer{
|
|
Registry_WatchServer: stream,
|
|
ctx: ctx,
|
|
})
|
|
}
|
|
|
|
func (srv *Server) getRegistry() (registry.Interface, error) {
|
|
backend, err := srv.getBackend()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// double-checked locking
|
|
srv.mu.RLock()
|
|
r := srv.registry
|
|
srv.mu.RUnlock()
|
|
if r == nil {
|
|
srv.mu.Lock()
|
|
r = srv.registry
|
|
var err error
|
|
if r == nil {
|
|
r, err = srv.newRegistryLocked(backend)
|
|
srv.registry = r
|
|
}
|
|
srv.mu.Unlock()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
return r, nil
|
|
}
|
|
|
|
func (srv *Server) newRegistryLocked(backend storage.Backend) (registry.Interface, error) {
|
|
ctx := context.Background()
|
|
|
|
if hasRegistryServer, ok := backend.(interface {
|
|
RegistryServer() registrypb.RegistryServer
|
|
}); ok {
|
|
log.Info(ctx).Msg("using registry via storage")
|
|
return struct {
|
|
io.Closer
|
|
registrypb.RegistryServer
|
|
}{backend, hasRegistryServer.RegistryServer()}, nil
|
|
}
|
|
|
|
switch srv.cfg.storageType {
|
|
case config.StorageInMemoryName:
|
|
log.Info(ctx).Msg("using in-memory registry")
|
|
return inmemory.New(ctx, srv.cfg.registryTTL), nil
|
|
case config.StorageRedisName:
|
|
log.Info(ctx).Msg("using redis registry")
|
|
r, err := redis.New(
|
|
srv.cfg.storageConnectionString,
|
|
redis.WithTLSConfig(srv.getTLSConfigLocked(ctx)),
|
|
)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to create new redis registry: %w", err)
|
|
}
|
|
return r, nil
|
|
}
|
|
|
|
return nil, fmt.Errorf("unsupported registry type: %s", srv.cfg.storageType)
|
|
}
|