From a54d43b937b58589a975d102ba288d816909f921 Mon Sep 17 00:00:00 2001 From: Caleb Doxsey Date: Mon, 10 May 2021 10:33:37 -0600 Subject: [PATCH] registry: implement redis backend (#2179) --- databroker/cache.go | 2 + databroker/databroker.go | 26 ++ internal/cmd/pomerium/constants.go | 7 - internal/cmd/pomerium/pomerium.go | 12 - internal/databroker/config.go | 11 + internal/databroker/registry.go | 109 ++++++++ internal/databroker/server.go | 44 +-- .../redisutil/client.go.go | 15 +- .../redisutil}/client_test.go | 2 +- internal/redisutil/redisutil.go | 5 + internal/registry/constants.go | 8 +- internal/registry/inmemory/constants.go | 8 + internal/registry/{ => inmemory}/inmemory.go | 13 +- .../inmemory_test.go} | 5 +- internal/registry/redis/lua/lua.go | 20 ++ internal/registry/redis/lua/registry.lua | 20 ++ internal/registry/redis/option.go | 48 ++++ internal/registry/redis/redis.go | 254 ++++++++++++++++++ internal/registry/redis/redis_test.go | 196 ++++++++++++++ internal/registry/registry.go | 12 + pkg/storage/redis/redis.go | 19 +- 21 files changed, 772 insertions(+), 64 deletions(-) delete mode 100644 internal/cmd/pomerium/constants.go create mode 100644 internal/databroker/registry.go rename pkg/storage/redis/client.go => internal/redisutil/client.go.go (96%) rename {pkg/storage/redis => internal/redisutil}/client_test.go (99%) create mode 100644 internal/redisutil/redisutil.go create mode 100644 internal/registry/inmemory/constants.go rename internal/registry/{ => inmemory}/inmemory.go (93%) rename internal/registry/{server_test.go => inmemory/inmemory_test.go} (98%) create mode 100644 internal/registry/redis/lua/lua.go create mode 100644 internal/registry/redis/lua/registry.lua create mode 100644 internal/registry/redis/option.go create mode 100644 internal/registry/redis/redis.go create mode 100644 internal/registry/redis/redis_test.go diff --git a/databroker/cache.go b/databroker/cache.go index 71262943b..883eff7eb 100644 --- a/databroker/cache.go +++ b/databroker/cache.go @@ -22,6 +22,7 @@ import ( "github.com/pomerium/pomerium/internal/version" "github.com/pomerium/pomerium/pkg/cryptutil" "github.com/pomerium/pomerium/pkg/grpc/databroker" + "github.com/pomerium/pomerium/pkg/grpc/registry" "github.com/pomerium/pomerium/pkg/grpcutil" ) @@ -116,6 +117,7 @@ func (c *DataBroker) OnConfigChange(ctx context.Context, cfg *config.Config) { func (c *DataBroker) Register(grpcServer *grpc.Server) { databroker.RegisterDataBrokerServiceServer(grpcServer, c.dataBrokerServer) directory.RegisterDirectoryServiceServer(grpcServer, c) + registry.RegisterRegistryServer(grpcServer, c.dataBrokerServer) } // Run runs the databroker components. diff --git a/databroker/databroker.go b/databroker/databroker.go index 5fb6e2d47..2e58a520a 100644 --- a/databroker/databroker.go +++ b/databroker/databroker.go @@ -8,6 +8,7 @@ import ( "github.com/pomerium/pomerium/config" "github.com/pomerium/pomerium/internal/databroker" databrokerpb "github.com/pomerium/pomerium/pkg/grpc/databroker" + registrypb "github.com/pomerium/pomerium/pkg/grpc/registry" "github.com/pomerium/pomerium/pkg/grpcutil" ) @@ -51,6 +52,8 @@ func (srv *dataBrokerServer) setKey(cfg *config.Config) { srv.sharedKey.Store(bs) } +// Databroker functions + func (srv *dataBrokerServer) Get(ctx context.Context, req *databrokerpb.GetRequest) (*databrokerpb.GetResponse, error) { if err := grpcutil.RequireSignedJWT(ctx, srv.sharedKey.Load().([]byte)); err != nil { return nil, err @@ -92,3 +95,26 @@ func (srv *dataBrokerServer) SyncLatest(req *databrokerpb.SyncLatestRequest, str } return srv.server.SyncLatest(req, stream) } + +// Registry functions + +func (srv *dataBrokerServer) Report(ctx context.Context, req *registrypb.RegisterRequest) (*registrypb.RegisterResponse, error) { + if err := grpcutil.RequireSignedJWT(ctx, srv.sharedKey.Load().([]byte)); err != nil { + return nil, err + } + return srv.server.Report(ctx, req) +} + +func (srv *dataBrokerServer) List(ctx context.Context, req *registrypb.ListRequest) (*registrypb.ServiceList, error) { + if err := grpcutil.RequireSignedJWT(ctx, srv.sharedKey.Load().([]byte)); err != nil { + return nil, err + } + return srv.server.List(ctx, req) +} + +func (srv *dataBrokerServer) Watch(req *registrypb.ListRequest, stream registrypb.Registry_WatchServer) error { + if err := grpcutil.RequireSignedJWT(stream.Context(), srv.sharedKey.Load().([]byte)); err != nil { + return err + } + return srv.server.Watch(req, stream) +} diff --git a/internal/cmd/pomerium/constants.go b/internal/cmd/pomerium/constants.go deleted file mode 100644 index 5e99a412a..000000000 --- a/internal/cmd/pomerium/constants.go +++ /dev/null @@ -1,7 +0,0 @@ -package pomerium - -import "time" - -const ( - registryTTL = time.Minute -) diff --git a/internal/cmd/pomerium/pomerium.go b/internal/cmd/pomerium/pomerium.go index 6d280ca76..36776c002 100644 --- a/internal/cmd/pomerium/pomerium.go +++ b/internal/cmd/pomerium/pomerium.go @@ -27,7 +27,6 @@ import ( "github.com/pomerium/pomerium/internal/registry" "github.com/pomerium/pomerium/internal/urlutil" "github.com/pomerium/pomerium/internal/version" - registry_pb "github.com/pomerium/pomerium/pkg/grpc/registry" "github.com/pomerium/pomerium/proxy" ) @@ -110,10 +109,6 @@ func Run(ctx context.Context, configFile string) error { if err != nil { return fmt.Errorf("setting up databroker: %w", err) } - - if err = setupRegistryServer(src, controlPlane); err != nil { - return fmt.Errorf("setting up registry: %w", err) - } } if err = setupRegistryReporter(ctx, src); err != nil { @@ -213,13 +208,6 @@ func setupDataBroker(ctx context.Context, src config.Source, controlPlane *contr return svc, nil } -func setupRegistryServer(src config.Source, controlPlane *controlplane.Server) error { - svc := registry.NewInMemoryServer(context.TODO(), registryTTL) - registry_pb.RegisterRegistryServer(controlPlane.GRPCServer, svc) - log.Info(context.TODO()).Msg("enabled service discovery") - return nil -} - func setupRegistryReporter(ctx context.Context, src config.Source) error { reporter := new(registry.Reporter) src.OnConfigChange(ctx, reporter.OnConfigChange) diff --git a/internal/databroker/config.go b/internal/databroker/config.go index 71fb98723..f13dde38e 100644 --- a/internal/databroker/config.go +++ b/internal/databroker/config.go @@ -17,6 +17,8 @@ var ( DefaultStorageType = "memory" // DefaultGetAllPageSize is the default page size for GetAll calls. DefaultGetAllPageSize = 50 + // DefaultRegistryTTL is the default registry time to live. + DefaultRegistryTTL = time.Minute ) type serverConfig struct { @@ -28,6 +30,7 @@ type serverConfig struct { storageCertSkipVerify bool storageCertificate *tls.Certificate getAllPageSize int + registryTTL time.Duration } func newServerConfig(options ...ServerOption) *serverConfig { @@ -35,6 +38,7 @@ func newServerConfig(options ...ServerOption) *serverConfig { WithDeletePermanentlyAfter(DefaultDeletePermanentlyAfter)(cfg) WithStorageType(DefaultStorageType)(cfg) WithGetAllPageSize(DefaultGetAllPageSize)(cfg) + WithRegistryTTL(DefaultRegistryTTL)(cfg) for _, option := range options { option(cfg) } @@ -60,6 +64,13 @@ func WithGetAllPageSize(pageSize int) ServerOption { } } +// WithRegistryTTL sets the registry time to live in the config. +func WithRegistryTTL(ttl time.Duration) ServerOption { + return func(cfg *serverConfig) { + cfg.registryTTL = ttl + } +} + // WithGetSharedKey sets the secret in the config. func WithGetSharedKey(getSharedKey func() ([]byte, error)) ServerOption { return func(cfg *serverConfig) { diff --git a/internal/databroker/registry.go b/internal/databroker/registry.go new file mode 100644 index 000000000..9d67839de --- /dev/null +++ b/internal/databroker/registry.go @@ -0,0 +1,109 @@ +package databroker + +import ( + "context" + "fmt" + + "github.com/pomerium/pomerium/config" + "github.com/pomerium/pomerium/internal/log" + "github.com/pomerium/pomerium/internal/registry" + "github.com/pomerium/pomerium/internal/registry/inmemory" + "github.com/pomerium/pomerium/internal/registry/redis" + "github.com/pomerium/pomerium/internal/telemetry/trace" + registrypb "github.com/pomerium/pomerium/pkg/grpc/registry" +) + +type registryWatchServer struct { + registrypb.Registry_WatchServer + ctx context.Context +} + +func (stream registryWatchServer) Context() context.Context { + return stream.ctx +} + +// Report calls the registry Report method. +func (srv *Server) Report(ctx context.Context, req *registrypb.RegisterRequest) (*registrypb.RegisterResponse, error) { + ctx, span := trace.StartSpan(ctx, "databroker.grpc.Report") + defer span.End() + + r, err := srv.getRegistry() + if err != nil { + return nil, err + } + + return r.Report(ctx, req) +} + +// List calls the registry List method. +func (srv *Server) List(ctx context.Context, req *registrypb.ListRequest) (*registrypb.ServiceList, error) { + ctx, span := trace.StartSpan(ctx, "databroker.grpc.List") + defer span.End() + + r, err := srv.getRegistry() + if err != nil { + return nil, err + } + + return r.List(ctx, req) +} + +// Watch calls the registry Watch method. +func (srv *Server) Watch(req *registrypb.ListRequest, stream registrypb.Registry_WatchServer) error { + ctx := stream.Context() + ctx, span := trace.StartSpan(ctx, "databroker.grpc.Watch") + defer span.End() + + r, err := srv.getRegistry() + if err != nil { + return err + } + + return r.Watch(req, registryWatchServer{ + Registry_WatchServer: stream, + ctx: ctx, + }) +} + +func (srv *Server) getRegistry() (registry.Interface, error) { + // double-checked locking + srv.mu.RLock() + r := srv.registry + srv.mu.RUnlock() + if r == nil { + srv.mu.Lock() + r = srv.registry + var err error + if r == nil { + r, err = srv.newRegistryLocked() + srv.registry = r + } + srv.mu.Unlock() + if err != nil { + return nil, err + } + } + return r, nil +} + +func (srv *Server) newRegistryLocked() (registry.Interface, error) { + ctx := context.Background() + + switch srv.cfg.storageType { + case config.StorageInMemoryName: + log.Info(ctx).Msg("using in-memory registry") + return inmemory.New(ctx, srv.cfg.registryTTL), nil + case config.StorageRedisName: + log.Info(ctx).Msg("using redis registry") + r, err := redis.New( + srv.cfg.storageConnectionString, + redis.WithTLSConfig(srv.getTLSConfigLocked(ctx)), + ) + if err != nil { + return nil, fmt.Errorf("failed to create new redis registry: %w", err) + } + return r, nil + } + + return nil, fmt.Errorf("unsupported registry type: %s", srv.cfg.storageType) +} diff --git a/internal/databroker/server.go b/internal/databroker/server.go index 0b3542dfa..428dd26d1 100644 --- a/internal/databroker/server.go +++ b/internal/databroker/server.go @@ -15,6 +15,7 @@ import ( "github.com/pomerium/pomerium/config" "github.com/pomerium/pomerium/internal/log" + "github.com/pomerium/pomerium/internal/registry" "github.com/pomerium/pomerium/internal/telemetry/trace" "github.com/pomerium/pomerium/pkg/cryptutil" "github.com/pomerium/pomerium/pkg/grpc/databroker" @@ -28,8 +29,9 @@ import ( type Server struct { cfg *serverConfig - mu sync.RWMutex - backend storage.Backend + mu sync.RWMutex + backend storage.Backend + registry registry.Interface } // New creates a new server. @@ -60,6 +62,14 @@ func (srv *Server) UpdateConfig(options ...ServerOption) { } srv.backend = nil } + + if srv.registry != nil { + err := srv.registry.Close() + if err != nil { + log.Error(ctx).Err(err).Msg("databroker: error closing registry") + } + srv.registry = nil + } } // Get gets a record from the in-memory list. @@ -288,18 +298,6 @@ func (srv *Server) getBackend() (backend storage.Backend, err error) { func (srv *Server) newBackendLocked() (backend storage.Backend, err error) { ctx := context.Background() - caCertPool, err := cryptutil.GetCertPool("", srv.cfg.storageCAFile) - if err != nil { - log.Warn(ctx).Err(err).Msg("failed to read databroker CA file") - } - tlsConfig := &tls.Config{ - RootCAs: caCertPool, - // nolint: gosec - InsecureSkipVerify: srv.cfg.storageCertSkipVerify, - } - if srv.cfg.storageCertificate != nil { - tlsConfig.Certificates = []tls.Certificate{*srv.cfg.storageCertificate} - } switch srv.cfg.storageType { case config.StorageInMemoryName: @@ -309,7 +307,7 @@ func (srv *Server) newBackendLocked() (backend storage.Backend, err error) { log.Info(ctx).Msg("using redis store") backend, err = redis.New( srv.cfg.storageConnectionString, - redis.WithTLSConfig(tlsConfig), + redis.WithTLSConfig(srv.getTLSConfigLocked(ctx)), ) if err != nil { return nil, fmt.Errorf("failed to create new redis storage: %w", err) @@ -325,3 +323,19 @@ func (srv *Server) newBackendLocked() (backend storage.Backend, err error) { } return backend, nil } + +func (srv *Server) getTLSConfigLocked(ctx context.Context) *tls.Config { + caCertPool, err := cryptutil.GetCertPool("", srv.cfg.storageCAFile) + if err != nil { + log.Warn(ctx).Err(err).Msg("failed to read databroker CA file") + } + tlsConfig := &tls.Config{ + RootCAs: caCertPool, + // nolint: gosec + InsecureSkipVerify: srv.cfg.storageCertSkipVerify, + } + if srv.cfg.storageCertificate != nil { + tlsConfig.Certificates = []tls.Certificate{*srv.cfg.storageCertificate} + } + return tlsConfig +} diff --git a/pkg/storage/redis/client.go b/internal/redisutil/client.go.go similarity index 96% rename from pkg/storage/redis/client.go rename to internal/redisutil/client.go.go index 93009e25f..439086436 100644 --- a/pkg/storage/redis/client.go +++ b/internal/redisutil/client.go.go @@ -1,4 +1,4 @@ -package redis +package redisutil import ( "crypto/tls" @@ -43,15 +43,16 @@ var ( ) ) -func newClientFromURL(rawurl string, tlsConfig *tls.Config) (redis.UniversalClient, error) { - u, err := url.Parse(rawurl) +// NewClientFromURL creates a new redis client by parsing the raw URL. +func NewClientFromURL(rawURL string, tlsConfig *tls.Config) (redis.UniversalClient, error) { + u, err := url.Parse(rawURL) if err != nil { return nil, err } switch { case standardSchemes.Has(u.Scheme): - opts, err := redis.ParseURL(rawurl) + opts, err := redis.ParseURL(rawURL) if err != nil { return nil, err } @@ -62,7 +63,7 @@ func newClientFromURL(rawurl string, tlsConfig *tls.Config) (redis.UniversalClie return redis.NewClient(opts), nil case clusterSchemes.Has(u.Scheme): - opts, err := ParseClusterURL(rawurl) + opts, err := ParseClusterURL(rawURL) if err != nil { return nil, err } @@ -72,7 +73,7 @@ func newClientFromURL(rawurl string, tlsConfig *tls.Config) (redis.UniversalClie return redis.NewClusterClient(opts), nil case sentinelSchemes.Has(u.Scheme): - opts, err := ParseSentinelURL(rawurl) + opts, err := ParseSentinelURL(rawURL) if err != nil { return nil, err } @@ -82,7 +83,7 @@ func newClientFromURL(rawurl string, tlsConfig *tls.Config) (redis.UniversalClie return redis.NewFailoverClient(opts), nil case sentinelClusterSchemes.Has(u.Scheme): - opts, err := ParseSentinelURL(rawurl) + opts, err := ParseSentinelURL(rawURL) if err != nil { return nil, err } diff --git a/pkg/storage/redis/client_test.go b/internal/redisutil/client_test.go similarity index 99% rename from pkg/storage/redis/client_test.go rename to internal/redisutil/client_test.go index cf19f1bab..1557d6f24 100644 --- a/pkg/storage/redis/client_test.go +++ b/internal/redisutil/client_test.go @@ -1,4 +1,4 @@ -package redis +package redisutil import ( "net/url" diff --git a/internal/redisutil/redisutil.go b/internal/redisutil/redisutil.go new file mode 100644 index 000000000..b4fde26dd --- /dev/null +++ b/internal/redisutil/redisutil.go @@ -0,0 +1,5 @@ +// Package redisutil contains functions for working with redis. +package redisutil + +// KeyPrefix is the prefix used for all redis keys. +const KeyPrefix = "{pomerium_v3}." diff --git a/internal/registry/constants.go b/internal/registry/constants.go index 9836bbb1d..404382657 100644 --- a/internal/registry/constants.go +++ b/internal/registry/constants.go @@ -1,14 +1,8 @@ package registry -import ( - "time" -) +import "time" const ( - // callAfterTTLFactor will request to report back again after TTL/callAfterTTLFactor time - callAfterTTLFactor = 2 - // purgeAfterTTLFactor will purge keys with TTL * purgeAfterTTLFactor time - purgeAfterTTLFactor = 1 // min reporting ttl minTTL = time.Second // path metrics are available at diff --git a/internal/registry/inmemory/constants.go b/internal/registry/inmemory/constants.go new file mode 100644 index 000000000..47cc4d866 --- /dev/null +++ b/internal/registry/inmemory/constants.go @@ -0,0 +1,8 @@ +package inmemory + +const ( + // callAfterTTLFactor will request to report back again after TTL/callAfterTTLFactor time + callAfterTTLFactor = 2 + // purgeAfterTTLFactor will purge keys with TTL * purgeAfterTTLFactor time + purgeAfterTTLFactor = 1 +) diff --git a/internal/registry/inmemory.go b/internal/registry/inmemory/inmemory.go similarity index 93% rename from internal/registry/inmemory.go rename to internal/registry/inmemory/inmemory.go index bae372670..ade0241fa 100644 --- a/internal/registry/inmemory.go +++ b/internal/registry/inmemory/inmemory.go @@ -1,10 +1,12 @@ -package registry +// Package inmemory implements an in-memory registry. +package inmemory import ( "context" "sync" "time" + "github.com/pomerium/pomerium/internal/registry" "github.com/pomerium/pomerium/internal/signal" pb "github.com/pomerium/pomerium/pkg/grpc/registry" @@ -31,9 +33,9 @@ type inMemoryKey struct { endpoint string } -// NewInMemoryServer constructs a new registry tracking service that operates in RAM +// New constructs a new registry tracking service that operates in RAM // as such, it is not usable for multi-node deployment where REDIS or other alternative should be used -func NewInMemoryServer(ctx context.Context, ttl time.Duration) pb.RegistryServer { +func New(ctx context.Context, ttl time.Duration) registry.Interface { srv := &inMemoryServer{ ttl: ttl, regs: make(map[inMemoryKey]*timestamppb.Timestamp), @@ -57,6 +59,11 @@ func (s *inMemoryServer) periodicCheck(ctx context.Context) { } } +// Close closes the in memory server. +func (s *inMemoryServer) Close() error { + return nil +} + // Report is periodically sent by each service to confirm it is still serving with the registry // data is persisted with a certain TTL func (s *inMemoryServer) Report(ctx context.Context, req *pb.RegisterRequest) (*pb.RegisterResponse, error) { diff --git a/internal/registry/server_test.go b/internal/registry/inmemory/inmemory_test.go similarity index 98% rename from internal/registry/server_test.go rename to internal/registry/inmemory/inmemory_test.go index 458ee0f67..4b366a49a 100644 --- a/internal/registry/server_test.go +++ b/internal/registry/inmemory/inmemory_test.go @@ -1,4 +1,4 @@ -package registry_test +package inmemory import ( "context" @@ -8,7 +8,6 @@ import ( "testing" "time" - "github.com/pomerium/pomerium/internal/registry" pb "github.com/pomerium/pomerium/pkg/grpc/registry" "github.com/google/go-cmp/cmp" @@ -182,7 +181,7 @@ func newTestRegistry() (context.Context, pb.RegistryClient, func(), error) { gs := grpc.NewServer() ttl := time.Second - pb.RegisterRegistryServer(gs, registry.NewInMemoryServer(ctx, ttl)) + pb.RegisterRegistryServer(gs, New(ctx, ttl)) go gs.Serve(l) cancel.Append(gs.Stop) diff --git a/internal/registry/redis/lua/lua.go b/internal/registry/redis/lua/lua.go new file mode 100644 index 000000000..1ec7d5d40 --- /dev/null +++ b/internal/registry/redis/lua/lua.go @@ -0,0 +1,20 @@ +// Package lua contains lua source code. +package lua + +import ( + "embed" +) + +//go:embed registry.lua +var fs embed.FS + +// Registry is the registry lua script +var Registry string + +func init() { + bs, err := fs.ReadFile("registry.lua") + if err != nil { + panic(err) + } + Registry = string(bs) +} diff --git a/internal/registry/redis/lua/registry.lua b/internal/registry/redis/lua/registry.lua new file mode 100644 index 000000000..5e22e648a --- /dev/null +++ b/internal/registry/redis/lua/registry.lua @@ -0,0 +1,20 @@ +-- ARGV = [current time in seconds, ttl in seconds, services ...] +local current_time = ARGV[1] +local ttl = ARGV[2] + +-- update the service list +for i = 3, #ARGV, 1 do + redis.call('HSET', KEYS[1], ARGV[i], current_time + ttl) +end + +-- retrieve all the services, removing any that have expired +local svcs = {} +local kvs = redis.call('HGETALL', KEYS[1]) +for i = 1, #kvs, 2 do + if kvs[i + 1] < current_time then + redis.call('HDEL', KEYS[1], kvs[i]) + else + table.insert(svcs, kvs[i]) + end +end +return svcs diff --git a/internal/registry/redis/option.go b/internal/registry/redis/option.go new file mode 100644 index 000000000..6559a8165 --- /dev/null +++ b/internal/registry/redis/option.go @@ -0,0 +1,48 @@ +package redis + +import ( + "crypto/tls" + "time" +) + +const defaultTTL = time.Second * 30 + +type config struct { + tls *tls.Config + ttl time.Duration + getNow func() time.Time +} + +// An Option modifies the config.. +type Option func(*config) + +// WithGetNow sets the time.Now function in the config. +func WithGetNow(getNow func() time.Time) Option { + return func(cfg *config) { + cfg.getNow = getNow + } +} + +// WithTLSConfig sets the tls.Config in the config. +func WithTLSConfig(tlsConfig *tls.Config) Option { + return func(cfg *config) { + cfg.tls = tlsConfig + } +} + +// WithTTL sets the ttl in the config. +func WithTTL(ttl time.Duration) Option { + return func(cfg *config) { + cfg.ttl = ttl + } +} + +func getConfig(options ...Option) *config { + cfg := new(config) + WithGetNow(time.Now)(cfg) + WithTTL(defaultTTL)(cfg) + for _, o := range options { + o(cfg) + } + return cfg +} diff --git a/internal/registry/redis/redis.go b/internal/registry/redis/redis.go new file mode 100644 index 000000000..fce0e4a07 --- /dev/null +++ b/internal/registry/redis/redis.go @@ -0,0 +1,254 @@ +// Package redis implements a registry in redis. +package redis + +import ( + "context" + "fmt" + "sort" + "strings" + "sync" + "time" + + "github.com/cenkalti/backoff/v4" + "github.com/go-redis/redis/v8" + "google.golang.org/protobuf/proto" + "google.golang.org/protobuf/types/known/durationpb" + + "github.com/pomerium/pomerium/internal/log" + "github.com/pomerium/pomerium/internal/redisutil" + "github.com/pomerium/pomerium/internal/registry" + "github.com/pomerium/pomerium/internal/registry/redis/lua" + "github.com/pomerium/pomerium/internal/signal" + registrypb "github.com/pomerium/pomerium/pkg/grpc/registry" +) + +const ( + registryKey = redisutil.KeyPrefix + "registry" + registryUpdateKey = redisutil.KeyPrefix + "registry_changed_ch" + + pollInterval = time.Second * 30 +) + +type impl struct { + cfg *config + + client redis.UniversalClient + onChange *signal.Signal + + closeOnce sync.Once + closed chan struct{} +} + +// New creates a new registry implementation backend by redis. +func New(rawURL string, options ...Option) (registry.Interface, error) { + cfg := getConfig(options...) + + client, err := redisutil.NewClientFromURL(rawURL, cfg.tls) + if err != nil { + return nil, err + } + + i := &impl{ + cfg: cfg, + client: client, + onChange: signal.New(), + closed: make(chan struct{}), + } + go i.listenForChanges(context.Background()) + return i, nil +} + +func (i *impl) Report(ctx context.Context, req *registrypb.RegisterRequest) (*registrypb.RegisterResponse, error) { + _, err := i.runReport(ctx, req.GetServices()) + if err != nil { + return nil, err + } + return ®istrypb.RegisterResponse{ + CallBackAfter: durationpb.New(i.cfg.ttl / 2), + }, nil +} + +func (i *impl) List(ctx context.Context, req *registrypb.ListRequest) (*registrypb.ServiceList, error) { + all, err := i.runReport(ctx, nil) + if err != nil { + return nil, err + } + + include := map[registrypb.ServiceKind]struct{}{} + for _, kind := range req.GetKinds() { + include[kind] = struct{}{} + } + + filtered := make([]*registrypb.Service, 0, len(all)) + for _, svc := range all { + if _, ok := include[svc.GetKind()]; !ok { + continue + } + filtered = append(filtered, svc) + } + + sort.Slice(filtered, func(i, j int) bool { + { + iv, jv := filtered[i].GetKind(), filtered[j].GetKind() + switch { + case iv < jv: + return true + case jv < iv: + return false + } + } + + { + iv, jv := filtered[i].GetEndpoint(), filtered[j].GetEndpoint() + switch { + case iv < jv: + return true + case jv < iv: + return false + } + } + + return false + }) + + return ®istrypb.ServiceList{ + Services: filtered, + }, nil +} + +func (i *impl) Watch(req *registrypb.ListRequest, stream registrypb.Registry_WatchServer) error { + // listen for changes + ch := i.onChange.Bind() + defer i.onChange.Unbind(ch) + + // force a check periodically + poll := time.NewTicker(pollInterval) + defer poll.Stop() + + var prev *registrypb.ServiceList + for { + // retrieve the most recent list of services + lst, err := i.List(stream.Context(), req) + if err != nil { + return err + } + + // only send a new list if something changed + if !proto.Equal(prev, lst) { + err = stream.Send(lst) + if err != nil { + return err + } + } + prev = lst + + // wait for an update + select { + case <-i.closed: + return nil + case <-stream.Context().Done(): + return stream.Context().Err() + case <-ch: + case <-poll.C: + } + } +} + +func (i *impl) Close() error { + var err error + i.closeOnce.Do(func() { + err = i.client.Close() + close(i.closed) + }) + return err +} + +func (i *impl) listenForChanges(ctx context.Context) { + ctx, cancel := context.WithCancel(ctx) + go func() { + <-i.closed + cancel() + }() + + bo := backoff.NewExponentialBackOff() + bo.MaxElapsedTime = 0 + +outer: + for { + pubsub := i.client.Subscribe(ctx, registryUpdateKey) + for { + msg, err := pubsub.Receive(ctx) + if err != nil { + _ = pubsub.Close() + select { + case <-ctx.Done(): + return + case <-time.After(bo.NextBackOff()): + } + continue outer + } + bo.Reset() + + switch msg.(type) { + case *redis.Message: + i.onChange.Broadcast(ctx) + } + } + } +} + +func (i *impl) runReport(ctx context.Context, updates []*registrypb.Service) ([]*registrypb.Service, error) { + args := []interface{}{ + i.cfg.getNow().UnixNano() / int64(time.Millisecond), // current_time + i.cfg.ttl.Milliseconds(), // ttl + } + for _, svc := range updates { + args = append(args, i.getRegistryHashKey(svc)) + } + res, err := i.client.Eval(ctx, lua.Registry, []string{registryKey}, args...).Result() + if err != nil { + return nil, err + } + _, err = i.client.Publish(ctx, registryUpdateKey, time.Now().Format(time.RFC3339Nano)).Result() + if err != nil { + return nil, err + } + if values, ok := res.([]interface{}); ok { + var all []*registrypb.Service + for _, value := range values { + svc, err := i.getServiceFromRegistryHashKey(fmt.Sprint(value)) + if err != nil { + log.Warn(ctx).Err(err).Msg("redis: invalid service") + continue + } + all = append(all, svc) + } + return all, nil + } + return nil, nil +} + +func (i *impl) getServiceFromRegistryHashKey(key string) (*registrypb.Service, error) { + idx := strings.Index(key, "|") + if idx == -1 { + return nil, fmt.Errorf("redis: invalid service entry in hash: %s", key) + } + + svcKindStr := key[:idx] + svcEndpointStr := key[idx+1:] + + svcKind, ok := registrypb.ServiceKind_value[svcKindStr] + if !ok { + return nil, fmt.Errorf("redis: unknown service kind: %s", svcKindStr) + } + + svc := ®istrypb.Service{ + Kind: registrypb.ServiceKind(svcKind), + Endpoint: svcEndpointStr, + } + return svc, nil +} + +func (i *impl) getRegistryHashKey(svc *registrypb.Service) string { + return svc.GetKind().String() + "|" + svc.GetEndpoint() +} diff --git a/internal/registry/redis/redis_test.go b/internal/registry/redis/redis_test.go new file mode 100644 index 000000000..4092278ab --- /dev/null +++ b/internal/registry/redis/redis_test.go @@ -0,0 +1,196 @@ +package redis + +import ( + "context" + "net" + "os" + "runtime" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "golang.org/x/sync/errgroup" + "google.golang.org/grpc" + + "github.com/pomerium/pomerium/internal/testutil" + registrypb "github.com/pomerium/pomerium/pkg/grpc/registry" +) + +func TestReport(t *testing.T) { + if os.Getenv("GITHUB_ACTION") != "" && runtime.GOOS == "darwin" { + t.Skip("Github action can not run docker on MacOS") + } + + ctx := context.Background() + require.NoError(t, testutil.WithTestRedis(false, func(rawURL string) error { + tm := time.Now() + + i, err := New(rawURL, + WithGetNow(func() time.Time { + return tm + }), + WithTTL(time.Second*10)) + require.NoError(t, err) + defer func() { _ = i.Close() }() + + _, err = i.Report(ctx, ®istrypb.RegisterRequest{ + Services: []*registrypb.Service{ + {Kind: registrypb.ServiceKind_AUTHORIZE, Endpoint: "https://authorize.example.com"}, + {Kind: registrypb.ServiceKind_AUTHENTICATE, Endpoint: "https://authenticate.example.com"}, + {Kind: registrypb.ServiceKind_PROXY, Endpoint: "https://proxy.example.com"}, + }, + }) + require.NoError(t, err) + + // move forward 5 seconds + tm = tm.Add(time.Second * 5) + _, err = i.Report(ctx, ®istrypb.RegisterRequest{ + Services: []*registrypb.Service{ + {Kind: registrypb.ServiceKind_AUTHENTICATE, Endpoint: "https://authenticate.example.com"}, + {Kind: registrypb.ServiceKind_PROXY, Endpoint: "https://proxy.example.com"}, + }, + }) + require.NoError(t, err) + + lst, err := i.List(ctx, ®istrypb.ListRequest{ + Kinds: []registrypb.ServiceKind{ + registrypb.ServiceKind_AUTHORIZE, + registrypb.ServiceKind_PROXY, + }, + }) + require.NoError(t, err) + assert.Equal(t, []*registrypb.Service{ + {Kind: registrypb.ServiceKind_AUTHORIZE, Endpoint: "https://authorize.example.com"}, + {Kind: registrypb.ServiceKind_PROXY, Endpoint: "https://proxy.example.com"}, + }, lst.GetServices(), "should list selected services") + + // move forward 6 seconds + tm = tm.Add(time.Second * 6) + lst, err = i.List(ctx, ®istrypb.ListRequest{ + Kinds: []registrypb.ServiceKind{ + registrypb.ServiceKind_AUTHORIZE, + registrypb.ServiceKind_PROXY, + }, + }) + require.NoError(t, err) + assert.Equal(t, []*registrypb.Service{ + {Kind: registrypb.ServiceKind_PROXY, Endpoint: "https://proxy.example.com"}, + }, lst.GetServices(), "should expire old services") + + return nil + })) +} + +func TestWatch(t *testing.T) { + if os.Getenv("GITHUB_ACTION") != "" && runtime.GOOS == "darwin" { + t.Skip("Github action can not run docker on MacOS") + } + + require.NoError(t, testutil.WithTestRedis(false, func(rawURL string) error { + ctx, clearTimeout := context.WithTimeout(context.Background(), time.Second*15) + defer clearTimeout() + + ctx, cancel := context.WithCancel(ctx) + defer cancel() + + tm := time.Now() + i, err := New(rawURL, + WithGetNow(func() time.Time { + return tm + }), + WithTTL(time.Second*10)) + require.NoError(t, err) + + li, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + defer li.Close() + + srv := grpc.NewServer() + registrypb.RegisterRegistryServer(srv, i) + eg, ctx := errgroup.WithContext(ctx) + eg.Go(func() error { + <-ctx.Done() + srv.Stop() + return nil + }) + eg.Go(func() error { + return srv.Serve(li) + }) + eg.Go(func() error { + defer cancel() + + cc, err := grpc.Dial(li.Addr().String(), grpc.WithInsecure()) + if err != nil { + return err + } + + client := registrypb.NewRegistryClient(cc) + + // store the initial services + _, err = client.Report(ctx, ®istrypb.RegisterRequest{ + Services: []*registrypb.Service{ + {Kind: registrypb.ServiceKind_AUTHENTICATE, Endpoint: "http://authenticate1.example.com"}, + {Kind: registrypb.ServiceKind_AUTHORIZE, Endpoint: "http://authorize2.example.com"}, + {Kind: registrypb.ServiceKind_AUTHORIZE, Endpoint: "http://authorize1.example.com"}, + }, + }) + if err != nil { + return err + } + + stream, err := client.Watch(ctx, ®istrypb.ListRequest{ + Kinds: []registrypb.ServiceKind{ + registrypb.ServiceKind_AUTHORIZE, + }, + }) + if err != nil { + return err + } + defer func() { _ = stream.CloseSend() }() + + lst, err := stream.Recv() + if err != nil { + return err + } + assert.Equal(t, []*registrypb.Service{ + {Kind: registrypb.ServiceKind_AUTHORIZE, Endpoint: "http://authorize1.example.com"}, + {Kind: registrypb.ServiceKind_AUTHORIZE, Endpoint: "http://authorize2.example.com"}, + }, lst.GetServices()) + + // update authenticate + _, err = client.Report(ctx, ®istrypb.RegisterRequest{ + Services: []*registrypb.Service{ + {Kind: registrypb.ServiceKind_AUTHENTICATE, Endpoint: "http://authenticate1.example.com"}, + }, + }) + if err != nil { + return err + } + + // add an authorize + _, err = client.Report(ctx, ®istrypb.RegisterRequest{ + Services: []*registrypb.Service{ + {Kind: registrypb.ServiceKind_AUTHORIZE, Endpoint: "http://authorize3.example.com"}, + }, + }) + if err != nil { + return err + } + + lst, err = stream.Recv() + if err != nil { + return err + } + assert.Equal(t, []*registrypb.Service{ + {Kind: registrypb.ServiceKind_AUTHORIZE, Endpoint: "http://authorize1.example.com"}, + {Kind: registrypb.ServiceKind_AUTHORIZE, Endpoint: "http://authorize2.example.com"}, + {Kind: registrypb.ServiceKind_AUTHORIZE, Endpoint: "http://authorize3.example.com"}, + }, lst.GetServices()) + + return nil + }) + require.NoError(t, eg.Wait()) + return nil + })) +} diff --git a/internal/registry/registry.go b/internal/registry/registry.go index 89ab8cc17..ba2e9617b 100644 --- a/internal/registry/registry.go +++ b/internal/registry/registry.go @@ -1,2 +1,14 @@ // Package registry implements a service registry server. package registry + +import ( + "io" + + registrypb "github.com/pomerium/pomerium/pkg/grpc/registry" +) + +// Interface is a registry implementation. +type Interface interface { + registrypb.RegistryServer + io.Closer +} diff --git a/pkg/storage/redis/redis.go b/pkg/storage/redis/redis.go index 361e11dfc..f6860d6bd 100644 --- a/pkg/storage/redis/redis.go +++ b/pkg/storage/redis/redis.go @@ -9,11 +9,12 @@ import ( "time" "github.com/cenkalti/backoff/v4" - redis "github.com/go-redis/redis/v8" + "github.com/go-redis/redis/v8" "github.com/golang/protobuf/proto" "google.golang.org/protobuf/types/known/timestamppb" "github.com/pomerium/pomerium/internal/log" + "github.com/pomerium/pomerium/internal/redisutil" "github.com/pomerium/pomerium/internal/signal" "github.com/pomerium/pomerium/internal/telemetry/metrics" "github.com/pomerium/pomerium/internal/telemetry/trace" @@ -28,14 +29,14 @@ const ( // we rely on transactions in redis, so all redis-cluster keys need to be // on the same node. Using a `hash tag` gives us this capability. - serverVersionKey = "{pomerium_v3}.server_version" - lastVersionKey = "{pomerium_v3}.last_version" - lastVersionChKey = "{pomerium_v3}.last_version_ch" - recordHashKey = "{pomerium_v3}.records" - changesSetKey = "{pomerium_v3}.changes" - optionsKey = "{pomerium_v3}.options" + serverVersionKey = redisutil.KeyPrefix + "server_version" + lastVersionKey = redisutil.KeyPrefix + "last_version" + lastVersionChKey = redisutil.KeyPrefix + "last_version_ch" + recordHashKey = redisutil.KeyPrefix + "records" + changesSetKey = redisutil.KeyPrefix + "changes" + optionsKey = redisutil.KeyPrefix + "options" - recordTypeChangesKeyTpl = "{pomerium_v3}.changes.%s" + recordTypeChangesKeyTpl = redisutil.KeyPrefix + "changes.%s" ) // custom errors @@ -76,7 +77,7 @@ func New(rawURL string, options ...Option) (*Backend, error) { onChange: signal.New(), } var err error - backend.client, err = newClientFromURL(rawURL, backend.cfg.tls) + backend.client, err = redisutil.NewClientFromURL(rawURL, backend.cfg.tls) if err != nil { return nil, err }