mirror of
https://github.com/pomerium/pomerium.git
synced 2025-05-23 14:07:11 +02:00
registry: implement redis backend (#2179)
This commit is contained in:
parent
28155314e9
commit
a54d43b937
21 changed files with 772 additions and 64 deletions
|
@ -22,6 +22,7 @@ import (
|
||||||
"github.com/pomerium/pomerium/internal/version"
|
"github.com/pomerium/pomerium/internal/version"
|
||||||
"github.com/pomerium/pomerium/pkg/cryptutil"
|
"github.com/pomerium/pomerium/pkg/cryptutil"
|
||||||
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
||||||
|
"github.com/pomerium/pomerium/pkg/grpc/registry"
|
||||||
"github.com/pomerium/pomerium/pkg/grpcutil"
|
"github.com/pomerium/pomerium/pkg/grpcutil"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -116,6 +117,7 @@ func (c *DataBroker) OnConfigChange(ctx context.Context, cfg *config.Config) {
|
||||||
func (c *DataBroker) Register(grpcServer *grpc.Server) {
|
func (c *DataBroker) Register(grpcServer *grpc.Server) {
|
||||||
databroker.RegisterDataBrokerServiceServer(grpcServer, c.dataBrokerServer)
|
databroker.RegisterDataBrokerServiceServer(grpcServer, c.dataBrokerServer)
|
||||||
directory.RegisterDirectoryServiceServer(grpcServer, c)
|
directory.RegisterDirectoryServiceServer(grpcServer, c)
|
||||||
|
registry.RegisterRegistryServer(grpcServer, c.dataBrokerServer)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Run runs the databroker components.
|
// Run runs the databroker components.
|
||||||
|
|
|
@ -8,6 +8,7 @@ import (
|
||||||
"github.com/pomerium/pomerium/config"
|
"github.com/pomerium/pomerium/config"
|
||||||
"github.com/pomerium/pomerium/internal/databroker"
|
"github.com/pomerium/pomerium/internal/databroker"
|
||||||
databrokerpb "github.com/pomerium/pomerium/pkg/grpc/databroker"
|
databrokerpb "github.com/pomerium/pomerium/pkg/grpc/databroker"
|
||||||
|
registrypb "github.com/pomerium/pomerium/pkg/grpc/registry"
|
||||||
"github.com/pomerium/pomerium/pkg/grpcutil"
|
"github.com/pomerium/pomerium/pkg/grpcutil"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -51,6 +52,8 @@ func (srv *dataBrokerServer) setKey(cfg *config.Config) {
|
||||||
srv.sharedKey.Store(bs)
|
srv.sharedKey.Store(bs)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Databroker functions
|
||||||
|
|
||||||
func (srv *dataBrokerServer) Get(ctx context.Context, req *databrokerpb.GetRequest) (*databrokerpb.GetResponse, error) {
|
func (srv *dataBrokerServer) Get(ctx context.Context, req *databrokerpb.GetRequest) (*databrokerpb.GetResponse, error) {
|
||||||
if err := grpcutil.RequireSignedJWT(ctx, srv.sharedKey.Load().([]byte)); err != nil {
|
if err := grpcutil.RequireSignedJWT(ctx, srv.sharedKey.Load().([]byte)); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -92,3 +95,26 @@ func (srv *dataBrokerServer) SyncLatest(req *databrokerpb.SyncLatestRequest, str
|
||||||
}
|
}
|
||||||
return srv.server.SyncLatest(req, stream)
|
return srv.server.SyncLatest(req, stream)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Registry functions
|
||||||
|
|
||||||
|
func (srv *dataBrokerServer) Report(ctx context.Context, req *registrypb.RegisterRequest) (*registrypb.RegisterResponse, error) {
|
||||||
|
if err := grpcutil.RequireSignedJWT(ctx, srv.sharedKey.Load().([]byte)); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return srv.server.Report(ctx, req)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (srv *dataBrokerServer) List(ctx context.Context, req *registrypb.ListRequest) (*registrypb.ServiceList, error) {
|
||||||
|
if err := grpcutil.RequireSignedJWT(ctx, srv.sharedKey.Load().([]byte)); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return srv.server.List(ctx, req)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (srv *dataBrokerServer) Watch(req *registrypb.ListRequest, stream registrypb.Registry_WatchServer) error {
|
||||||
|
if err := grpcutil.RequireSignedJWT(stream.Context(), srv.sharedKey.Load().([]byte)); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return srv.server.Watch(req, stream)
|
||||||
|
}
|
||||||
|
|
|
@ -1,7 +0,0 @@
|
||||||
package pomerium
|
|
||||||
|
|
||||||
import "time"
|
|
||||||
|
|
||||||
const (
|
|
||||||
registryTTL = time.Minute
|
|
||||||
)
|
|
|
@ -27,7 +27,6 @@ import (
|
||||||
"github.com/pomerium/pomerium/internal/registry"
|
"github.com/pomerium/pomerium/internal/registry"
|
||||||
"github.com/pomerium/pomerium/internal/urlutil"
|
"github.com/pomerium/pomerium/internal/urlutil"
|
||||||
"github.com/pomerium/pomerium/internal/version"
|
"github.com/pomerium/pomerium/internal/version"
|
||||||
registry_pb "github.com/pomerium/pomerium/pkg/grpc/registry"
|
|
||||||
"github.com/pomerium/pomerium/proxy"
|
"github.com/pomerium/pomerium/proxy"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -110,10 +109,6 @@ func Run(ctx context.Context, configFile string) error {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("setting up databroker: %w", err)
|
return fmt.Errorf("setting up databroker: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err = setupRegistryServer(src, controlPlane); err != nil {
|
|
||||||
return fmt.Errorf("setting up registry: %w", err)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if err = setupRegistryReporter(ctx, src); err != nil {
|
if err = setupRegistryReporter(ctx, src); err != nil {
|
||||||
|
@ -213,13 +208,6 @@ func setupDataBroker(ctx context.Context, src config.Source, controlPlane *contr
|
||||||
return svc, nil
|
return svc, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func setupRegistryServer(src config.Source, controlPlane *controlplane.Server) error {
|
|
||||||
svc := registry.NewInMemoryServer(context.TODO(), registryTTL)
|
|
||||||
registry_pb.RegisterRegistryServer(controlPlane.GRPCServer, svc)
|
|
||||||
log.Info(context.TODO()).Msg("enabled service discovery")
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func setupRegistryReporter(ctx context.Context, src config.Source) error {
|
func setupRegistryReporter(ctx context.Context, src config.Source) error {
|
||||||
reporter := new(registry.Reporter)
|
reporter := new(registry.Reporter)
|
||||||
src.OnConfigChange(ctx, reporter.OnConfigChange)
|
src.OnConfigChange(ctx, reporter.OnConfigChange)
|
||||||
|
|
|
@ -17,6 +17,8 @@ var (
|
||||||
DefaultStorageType = "memory"
|
DefaultStorageType = "memory"
|
||||||
// DefaultGetAllPageSize is the default page size for GetAll calls.
|
// DefaultGetAllPageSize is the default page size for GetAll calls.
|
||||||
DefaultGetAllPageSize = 50
|
DefaultGetAllPageSize = 50
|
||||||
|
// DefaultRegistryTTL is the default registry time to live.
|
||||||
|
DefaultRegistryTTL = time.Minute
|
||||||
)
|
)
|
||||||
|
|
||||||
type serverConfig struct {
|
type serverConfig struct {
|
||||||
|
@ -28,6 +30,7 @@ type serverConfig struct {
|
||||||
storageCertSkipVerify bool
|
storageCertSkipVerify bool
|
||||||
storageCertificate *tls.Certificate
|
storageCertificate *tls.Certificate
|
||||||
getAllPageSize int
|
getAllPageSize int
|
||||||
|
registryTTL time.Duration
|
||||||
}
|
}
|
||||||
|
|
||||||
func newServerConfig(options ...ServerOption) *serverConfig {
|
func newServerConfig(options ...ServerOption) *serverConfig {
|
||||||
|
@ -35,6 +38,7 @@ func newServerConfig(options ...ServerOption) *serverConfig {
|
||||||
WithDeletePermanentlyAfter(DefaultDeletePermanentlyAfter)(cfg)
|
WithDeletePermanentlyAfter(DefaultDeletePermanentlyAfter)(cfg)
|
||||||
WithStorageType(DefaultStorageType)(cfg)
|
WithStorageType(DefaultStorageType)(cfg)
|
||||||
WithGetAllPageSize(DefaultGetAllPageSize)(cfg)
|
WithGetAllPageSize(DefaultGetAllPageSize)(cfg)
|
||||||
|
WithRegistryTTL(DefaultRegistryTTL)(cfg)
|
||||||
for _, option := range options {
|
for _, option := range options {
|
||||||
option(cfg)
|
option(cfg)
|
||||||
}
|
}
|
||||||
|
@ -60,6 +64,13 @@ func WithGetAllPageSize(pageSize int) ServerOption {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// WithRegistryTTL sets the registry time to live in the config.
|
||||||
|
func WithRegistryTTL(ttl time.Duration) ServerOption {
|
||||||
|
return func(cfg *serverConfig) {
|
||||||
|
cfg.registryTTL = ttl
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// WithGetSharedKey sets the secret in the config.
|
// WithGetSharedKey sets the secret in the config.
|
||||||
func WithGetSharedKey(getSharedKey func() ([]byte, error)) ServerOption {
|
func WithGetSharedKey(getSharedKey func() ([]byte, error)) ServerOption {
|
||||||
return func(cfg *serverConfig) {
|
return func(cfg *serverConfig) {
|
||||||
|
|
109
internal/databroker/registry.go
Normal file
109
internal/databroker/registry.go
Normal file
|
@ -0,0 +1,109 @@
|
||||||
|
package databroker
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"github.com/pomerium/pomerium/config"
|
||||||
|
"github.com/pomerium/pomerium/internal/log"
|
||||||
|
"github.com/pomerium/pomerium/internal/registry"
|
||||||
|
"github.com/pomerium/pomerium/internal/registry/inmemory"
|
||||||
|
"github.com/pomerium/pomerium/internal/registry/redis"
|
||||||
|
"github.com/pomerium/pomerium/internal/telemetry/trace"
|
||||||
|
registrypb "github.com/pomerium/pomerium/pkg/grpc/registry"
|
||||||
|
)
|
||||||
|
|
||||||
|
type registryWatchServer struct {
|
||||||
|
registrypb.Registry_WatchServer
|
||||||
|
ctx context.Context
|
||||||
|
}
|
||||||
|
|
||||||
|
func (stream registryWatchServer) Context() context.Context {
|
||||||
|
return stream.ctx
|
||||||
|
}
|
||||||
|
|
||||||
|
// Report calls the registry Report method.
|
||||||
|
func (srv *Server) Report(ctx context.Context, req *registrypb.RegisterRequest) (*registrypb.RegisterResponse, error) {
|
||||||
|
ctx, span := trace.StartSpan(ctx, "databroker.grpc.Report")
|
||||||
|
defer span.End()
|
||||||
|
|
||||||
|
r, err := srv.getRegistry()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return r.Report(ctx, req)
|
||||||
|
}
|
||||||
|
|
||||||
|
// List calls the registry List method.
|
||||||
|
func (srv *Server) List(ctx context.Context, req *registrypb.ListRequest) (*registrypb.ServiceList, error) {
|
||||||
|
ctx, span := trace.StartSpan(ctx, "databroker.grpc.List")
|
||||||
|
defer span.End()
|
||||||
|
|
||||||
|
r, err := srv.getRegistry()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return r.List(ctx, req)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Watch calls the registry Watch method.
|
||||||
|
func (srv *Server) Watch(req *registrypb.ListRequest, stream registrypb.Registry_WatchServer) error {
|
||||||
|
ctx := stream.Context()
|
||||||
|
ctx, span := trace.StartSpan(ctx, "databroker.grpc.Watch")
|
||||||
|
defer span.End()
|
||||||
|
|
||||||
|
r, err := srv.getRegistry()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return r.Watch(req, registryWatchServer{
|
||||||
|
Registry_WatchServer: stream,
|
||||||
|
ctx: ctx,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (srv *Server) getRegistry() (registry.Interface, error) {
|
||||||
|
// double-checked locking
|
||||||
|
srv.mu.RLock()
|
||||||
|
r := srv.registry
|
||||||
|
srv.mu.RUnlock()
|
||||||
|
if r == nil {
|
||||||
|
srv.mu.Lock()
|
||||||
|
r = srv.registry
|
||||||
|
var err error
|
||||||
|
if r == nil {
|
||||||
|
r, err = srv.newRegistryLocked()
|
||||||
|
srv.registry = r
|
||||||
|
}
|
||||||
|
srv.mu.Unlock()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return r, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (srv *Server) newRegistryLocked() (registry.Interface, error) {
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
switch srv.cfg.storageType {
|
||||||
|
case config.StorageInMemoryName:
|
||||||
|
log.Info(ctx).Msg("using in-memory registry")
|
||||||
|
return inmemory.New(ctx, srv.cfg.registryTTL), nil
|
||||||
|
case config.StorageRedisName:
|
||||||
|
log.Info(ctx).Msg("using redis registry")
|
||||||
|
r, err := redis.New(
|
||||||
|
srv.cfg.storageConnectionString,
|
||||||
|
redis.WithTLSConfig(srv.getTLSConfigLocked(ctx)),
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to create new redis registry: %w", err)
|
||||||
|
}
|
||||||
|
return r, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, fmt.Errorf("unsupported registry type: %s", srv.cfg.storageType)
|
||||||
|
}
|
|
@ -15,6 +15,7 @@ import (
|
||||||
|
|
||||||
"github.com/pomerium/pomerium/config"
|
"github.com/pomerium/pomerium/config"
|
||||||
"github.com/pomerium/pomerium/internal/log"
|
"github.com/pomerium/pomerium/internal/log"
|
||||||
|
"github.com/pomerium/pomerium/internal/registry"
|
||||||
"github.com/pomerium/pomerium/internal/telemetry/trace"
|
"github.com/pomerium/pomerium/internal/telemetry/trace"
|
||||||
"github.com/pomerium/pomerium/pkg/cryptutil"
|
"github.com/pomerium/pomerium/pkg/cryptutil"
|
||||||
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
||||||
|
@ -28,8 +29,9 @@ import (
|
||||||
type Server struct {
|
type Server struct {
|
||||||
cfg *serverConfig
|
cfg *serverConfig
|
||||||
|
|
||||||
mu sync.RWMutex
|
mu sync.RWMutex
|
||||||
backend storage.Backend
|
backend storage.Backend
|
||||||
|
registry registry.Interface
|
||||||
}
|
}
|
||||||
|
|
||||||
// New creates a new server.
|
// New creates a new server.
|
||||||
|
@ -60,6 +62,14 @@ func (srv *Server) UpdateConfig(options ...ServerOption) {
|
||||||
}
|
}
|
||||||
srv.backend = nil
|
srv.backend = nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if srv.registry != nil {
|
||||||
|
err := srv.registry.Close()
|
||||||
|
if err != nil {
|
||||||
|
log.Error(ctx).Err(err).Msg("databroker: error closing registry")
|
||||||
|
}
|
||||||
|
srv.registry = nil
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get gets a record from the in-memory list.
|
// Get gets a record from the in-memory list.
|
||||||
|
@ -288,18 +298,6 @@ func (srv *Server) getBackend() (backend storage.Backend, err error) {
|
||||||
|
|
||||||
func (srv *Server) newBackendLocked() (backend storage.Backend, err error) {
|
func (srv *Server) newBackendLocked() (backend storage.Backend, err error) {
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
caCertPool, err := cryptutil.GetCertPool("", srv.cfg.storageCAFile)
|
|
||||||
if err != nil {
|
|
||||||
log.Warn(ctx).Err(err).Msg("failed to read databroker CA file")
|
|
||||||
}
|
|
||||||
tlsConfig := &tls.Config{
|
|
||||||
RootCAs: caCertPool,
|
|
||||||
// nolint: gosec
|
|
||||||
InsecureSkipVerify: srv.cfg.storageCertSkipVerify,
|
|
||||||
}
|
|
||||||
if srv.cfg.storageCertificate != nil {
|
|
||||||
tlsConfig.Certificates = []tls.Certificate{*srv.cfg.storageCertificate}
|
|
||||||
}
|
|
||||||
|
|
||||||
switch srv.cfg.storageType {
|
switch srv.cfg.storageType {
|
||||||
case config.StorageInMemoryName:
|
case config.StorageInMemoryName:
|
||||||
|
@ -309,7 +307,7 @@ func (srv *Server) newBackendLocked() (backend storage.Backend, err error) {
|
||||||
log.Info(ctx).Msg("using redis store")
|
log.Info(ctx).Msg("using redis store")
|
||||||
backend, err = redis.New(
|
backend, err = redis.New(
|
||||||
srv.cfg.storageConnectionString,
|
srv.cfg.storageConnectionString,
|
||||||
redis.WithTLSConfig(tlsConfig),
|
redis.WithTLSConfig(srv.getTLSConfigLocked(ctx)),
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to create new redis storage: %w", err)
|
return nil, fmt.Errorf("failed to create new redis storage: %w", err)
|
||||||
|
@ -325,3 +323,19 @@ func (srv *Server) newBackendLocked() (backend storage.Backend, err error) {
|
||||||
}
|
}
|
||||||
return backend, nil
|
return backend, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (srv *Server) getTLSConfigLocked(ctx context.Context) *tls.Config {
|
||||||
|
caCertPool, err := cryptutil.GetCertPool("", srv.cfg.storageCAFile)
|
||||||
|
if err != nil {
|
||||||
|
log.Warn(ctx).Err(err).Msg("failed to read databroker CA file")
|
||||||
|
}
|
||||||
|
tlsConfig := &tls.Config{
|
||||||
|
RootCAs: caCertPool,
|
||||||
|
// nolint: gosec
|
||||||
|
InsecureSkipVerify: srv.cfg.storageCertSkipVerify,
|
||||||
|
}
|
||||||
|
if srv.cfg.storageCertificate != nil {
|
||||||
|
tlsConfig.Certificates = []tls.Certificate{*srv.cfg.storageCertificate}
|
||||||
|
}
|
||||||
|
return tlsConfig
|
||||||
|
}
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
package redis
|
package redisutil
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
|
@ -43,15 +43,16 @@ var (
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
func newClientFromURL(rawurl string, tlsConfig *tls.Config) (redis.UniversalClient, error) {
|
// NewClientFromURL creates a new redis client by parsing the raw URL.
|
||||||
u, err := url.Parse(rawurl)
|
func NewClientFromURL(rawURL string, tlsConfig *tls.Config) (redis.UniversalClient, error) {
|
||||||
|
u, err := url.Parse(rawURL)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
switch {
|
switch {
|
||||||
case standardSchemes.Has(u.Scheme):
|
case standardSchemes.Has(u.Scheme):
|
||||||
opts, err := redis.ParseURL(rawurl)
|
opts, err := redis.ParseURL(rawURL)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -62,7 +63,7 @@ func newClientFromURL(rawurl string, tlsConfig *tls.Config) (redis.UniversalClie
|
||||||
return redis.NewClient(opts), nil
|
return redis.NewClient(opts), nil
|
||||||
|
|
||||||
case clusterSchemes.Has(u.Scheme):
|
case clusterSchemes.Has(u.Scheme):
|
||||||
opts, err := ParseClusterURL(rawurl)
|
opts, err := ParseClusterURL(rawURL)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -72,7 +73,7 @@ func newClientFromURL(rawurl string, tlsConfig *tls.Config) (redis.UniversalClie
|
||||||
return redis.NewClusterClient(opts), nil
|
return redis.NewClusterClient(opts), nil
|
||||||
|
|
||||||
case sentinelSchemes.Has(u.Scheme):
|
case sentinelSchemes.Has(u.Scheme):
|
||||||
opts, err := ParseSentinelURL(rawurl)
|
opts, err := ParseSentinelURL(rawURL)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -82,7 +83,7 @@ func newClientFromURL(rawurl string, tlsConfig *tls.Config) (redis.UniversalClie
|
||||||
return redis.NewFailoverClient(opts), nil
|
return redis.NewFailoverClient(opts), nil
|
||||||
|
|
||||||
case sentinelClusterSchemes.Has(u.Scheme):
|
case sentinelClusterSchemes.Has(u.Scheme):
|
||||||
opts, err := ParseSentinelURL(rawurl)
|
opts, err := ParseSentinelURL(rawURL)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
|
@ -1,4 +1,4 @@
|
||||||
package redis
|
package redisutil
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"net/url"
|
"net/url"
|
5
internal/redisutil/redisutil.go
Normal file
5
internal/redisutil/redisutil.go
Normal file
|
@ -0,0 +1,5 @@
|
||||||
|
// Package redisutil contains functions for working with redis.
|
||||||
|
package redisutil
|
||||||
|
|
||||||
|
// KeyPrefix is the prefix used for all redis keys.
|
||||||
|
const KeyPrefix = "{pomerium_v3}."
|
|
@ -1,14 +1,8 @@
|
||||||
package registry
|
package registry
|
||||||
|
|
||||||
import (
|
import "time"
|
||||||
"time"
|
|
||||||
)
|
|
||||||
|
|
||||||
const (
|
const (
|
||||||
// callAfterTTLFactor will request to report back again after TTL/callAfterTTLFactor time
|
|
||||||
callAfterTTLFactor = 2
|
|
||||||
// purgeAfterTTLFactor will purge keys with TTL * purgeAfterTTLFactor time
|
|
||||||
purgeAfterTTLFactor = 1
|
|
||||||
// min reporting ttl
|
// min reporting ttl
|
||||||
minTTL = time.Second
|
minTTL = time.Second
|
||||||
// path metrics are available at
|
// path metrics are available at
|
||||||
|
|
8
internal/registry/inmemory/constants.go
Normal file
8
internal/registry/inmemory/constants.go
Normal file
|
@ -0,0 +1,8 @@
|
||||||
|
package inmemory
|
||||||
|
|
||||||
|
const (
|
||||||
|
// callAfterTTLFactor will request to report back again after TTL/callAfterTTLFactor time
|
||||||
|
callAfterTTLFactor = 2
|
||||||
|
// purgeAfterTTLFactor will purge keys with TTL * purgeAfterTTLFactor time
|
||||||
|
purgeAfterTTLFactor = 1
|
||||||
|
)
|
|
@ -1,10 +1,12 @@
|
||||||
package registry
|
// Package inmemory implements an in-memory registry.
|
||||||
|
package inmemory
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/pomerium/pomerium/internal/registry"
|
||||||
"github.com/pomerium/pomerium/internal/signal"
|
"github.com/pomerium/pomerium/internal/signal"
|
||||||
pb "github.com/pomerium/pomerium/pkg/grpc/registry"
|
pb "github.com/pomerium/pomerium/pkg/grpc/registry"
|
||||||
|
|
||||||
|
@ -31,9 +33,9 @@ type inMemoryKey struct {
|
||||||
endpoint string
|
endpoint string
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewInMemoryServer constructs a new registry tracking service that operates in RAM
|
// New constructs a new registry tracking service that operates in RAM
|
||||||
// as such, it is not usable for multi-node deployment where REDIS or other alternative should be used
|
// as such, it is not usable for multi-node deployment where REDIS or other alternative should be used
|
||||||
func NewInMemoryServer(ctx context.Context, ttl time.Duration) pb.RegistryServer {
|
func New(ctx context.Context, ttl time.Duration) registry.Interface {
|
||||||
srv := &inMemoryServer{
|
srv := &inMemoryServer{
|
||||||
ttl: ttl,
|
ttl: ttl,
|
||||||
regs: make(map[inMemoryKey]*timestamppb.Timestamp),
|
regs: make(map[inMemoryKey]*timestamppb.Timestamp),
|
||||||
|
@ -57,6 +59,11 @@ func (s *inMemoryServer) periodicCheck(ctx context.Context) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Close closes the in memory server.
|
||||||
|
func (s *inMemoryServer) Close() error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// Report is periodically sent by each service to confirm it is still serving with the registry
|
// Report is periodically sent by each service to confirm it is still serving with the registry
|
||||||
// data is persisted with a certain TTL
|
// data is persisted with a certain TTL
|
||||||
func (s *inMemoryServer) Report(ctx context.Context, req *pb.RegisterRequest) (*pb.RegisterResponse, error) {
|
func (s *inMemoryServer) Report(ctx context.Context, req *pb.RegisterRequest) (*pb.RegisterResponse, error) {
|
|
@ -1,4 +1,4 @@
|
||||||
package registry_test
|
package inmemory
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
@ -8,7 +8,6 @@ import (
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/pomerium/pomerium/internal/registry"
|
|
||||||
pb "github.com/pomerium/pomerium/pkg/grpc/registry"
|
pb "github.com/pomerium/pomerium/pkg/grpc/registry"
|
||||||
|
|
||||||
"github.com/google/go-cmp/cmp"
|
"github.com/google/go-cmp/cmp"
|
||||||
|
@ -182,7 +181,7 @@ func newTestRegistry() (context.Context, pb.RegistryClient, func(), error) {
|
||||||
gs := grpc.NewServer()
|
gs := grpc.NewServer()
|
||||||
|
|
||||||
ttl := time.Second
|
ttl := time.Second
|
||||||
pb.RegisterRegistryServer(gs, registry.NewInMemoryServer(ctx, ttl))
|
pb.RegisterRegistryServer(gs, New(ctx, ttl))
|
||||||
|
|
||||||
go gs.Serve(l)
|
go gs.Serve(l)
|
||||||
cancel.Append(gs.Stop)
|
cancel.Append(gs.Stop)
|
20
internal/registry/redis/lua/lua.go
Normal file
20
internal/registry/redis/lua/lua.go
Normal file
|
@ -0,0 +1,20 @@
|
||||||
|
// Package lua contains lua source code.
|
||||||
|
package lua
|
||||||
|
|
||||||
|
import (
|
||||||
|
"embed"
|
||||||
|
)
|
||||||
|
|
||||||
|
//go:embed registry.lua
|
||||||
|
var fs embed.FS
|
||||||
|
|
||||||
|
// Registry is the registry lua script
|
||||||
|
var Registry string
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
bs, err := fs.ReadFile("registry.lua")
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
Registry = string(bs)
|
||||||
|
}
|
20
internal/registry/redis/lua/registry.lua
Normal file
20
internal/registry/redis/lua/registry.lua
Normal file
|
@ -0,0 +1,20 @@
|
||||||
|
-- ARGV = [current time in seconds, ttl in seconds, services ...]
|
||||||
|
local current_time = ARGV[1]
|
||||||
|
local ttl = ARGV[2]
|
||||||
|
|
||||||
|
-- update the service list
|
||||||
|
for i = 3, #ARGV, 1 do
|
||||||
|
redis.call('HSET', KEYS[1], ARGV[i], current_time + ttl)
|
||||||
|
end
|
||||||
|
|
||||||
|
-- retrieve all the services, removing any that have expired
|
||||||
|
local svcs = {}
|
||||||
|
local kvs = redis.call('HGETALL', KEYS[1])
|
||||||
|
for i = 1, #kvs, 2 do
|
||||||
|
if kvs[i + 1] < current_time then
|
||||||
|
redis.call('HDEL', KEYS[1], kvs[i])
|
||||||
|
else
|
||||||
|
table.insert(svcs, kvs[i])
|
||||||
|
end
|
||||||
|
end
|
||||||
|
return svcs
|
48
internal/registry/redis/option.go
Normal file
48
internal/registry/redis/option.go
Normal file
|
@ -0,0 +1,48 @@
|
||||||
|
package redis
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/tls"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
const defaultTTL = time.Second * 30
|
||||||
|
|
||||||
|
type config struct {
|
||||||
|
tls *tls.Config
|
||||||
|
ttl time.Duration
|
||||||
|
getNow func() time.Time
|
||||||
|
}
|
||||||
|
|
||||||
|
// An Option modifies the config..
|
||||||
|
type Option func(*config)
|
||||||
|
|
||||||
|
// WithGetNow sets the time.Now function in the config.
|
||||||
|
func WithGetNow(getNow func() time.Time) Option {
|
||||||
|
return func(cfg *config) {
|
||||||
|
cfg.getNow = getNow
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithTLSConfig sets the tls.Config in the config.
|
||||||
|
func WithTLSConfig(tlsConfig *tls.Config) Option {
|
||||||
|
return func(cfg *config) {
|
||||||
|
cfg.tls = tlsConfig
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithTTL sets the ttl in the config.
|
||||||
|
func WithTTL(ttl time.Duration) Option {
|
||||||
|
return func(cfg *config) {
|
||||||
|
cfg.ttl = ttl
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func getConfig(options ...Option) *config {
|
||||||
|
cfg := new(config)
|
||||||
|
WithGetNow(time.Now)(cfg)
|
||||||
|
WithTTL(defaultTTL)(cfg)
|
||||||
|
for _, o := range options {
|
||||||
|
o(cfg)
|
||||||
|
}
|
||||||
|
return cfg
|
||||||
|
}
|
254
internal/registry/redis/redis.go
Normal file
254
internal/registry/redis/redis.go
Normal file
|
@ -0,0 +1,254 @@
|
||||||
|
// Package redis implements a registry in redis.
|
||||||
|
package redis
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"sort"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/cenkalti/backoff/v4"
|
||||||
|
"github.com/go-redis/redis/v8"
|
||||||
|
"google.golang.org/protobuf/proto"
|
||||||
|
"google.golang.org/protobuf/types/known/durationpb"
|
||||||
|
|
||||||
|
"github.com/pomerium/pomerium/internal/log"
|
||||||
|
"github.com/pomerium/pomerium/internal/redisutil"
|
||||||
|
"github.com/pomerium/pomerium/internal/registry"
|
||||||
|
"github.com/pomerium/pomerium/internal/registry/redis/lua"
|
||||||
|
"github.com/pomerium/pomerium/internal/signal"
|
||||||
|
registrypb "github.com/pomerium/pomerium/pkg/grpc/registry"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
registryKey = redisutil.KeyPrefix + "registry"
|
||||||
|
registryUpdateKey = redisutil.KeyPrefix + "registry_changed_ch"
|
||||||
|
|
||||||
|
pollInterval = time.Second * 30
|
||||||
|
)
|
||||||
|
|
||||||
|
type impl struct {
|
||||||
|
cfg *config
|
||||||
|
|
||||||
|
client redis.UniversalClient
|
||||||
|
onChange *signal.Signal
|
||||||
|
|
||||||
|
closeOnce sync.Once
|
||||||
|
closed chan struct{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// New creates a new registry implementation backend by redis.
|
||||||
|
func New(rawURL string, options ...Option) (registry.Interface, error) {
|
||||||
|
cfg := getConfig(options...)
|
||||||
|
|
||||||
|
client, err := redisutil.NewClientFromURL(rawURL, cfg.tls)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
i := &impl{
|
||||||
|
cfg: cfg,
|
||||||
|
client: client,
|
||||||
|
onChange: signal.New(),
|
||||||
|
closed: make(chan struct{}),
|
||||||
|
}
|
||||||
|
go i.listenForChanges(context.Background())
|
||||||
|
return i, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (i *impl) Report(ctx context.Context, req *registrypb.RegisterRequest) (*registrypb.RegisterResponse, error) {
|
||||||
|
_, err := i.runReport(ctx, req.GetServices())
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return ®istrypb.RegisterResponse{
|
||||||
|
CallBackAfter: durationpb.New(i.cfg.ttl / 2),
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (i *impl) List(ctx context.Context, req *registrypb.ListRequest) (*registrypb.ServiceList, error) {
|
||||||
|
all, err := i.runReport(ctx, nil)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
include := map[registrypb.ServiceKind]struct{}{}
|
||||||
|
for _, kind := range req.GetKinds() {
|
||||||
|
include[kind] = struct{}{}
|
||||||
|
}
|
||||||
|
|
||||||
|
filtered := make([]*registrypb.Service, 0, len(all))
|
||||||
|
for _, svc := range all {
|
||||||
|
if _, ok := include[svc.GetKind()]; !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
filtered = append(filtered, svc)
|
||||||
|
}
|
||||||
|
|
||||||
|
sort.Slice(filtered, func(i, j int) bool {
|
||||||
|
{
|
||||||
|
iv, jv := filtered[i].GetKind(), filtered[j].GetKind()
|
||||||
|
switch {
|
||||||
|
case iv < jv:
|
||||||
|
return true
|
||||||
|
case jv < iv:
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
iv, jv := filtered[i].GetEndpoint(), filtered[j].GetEndpoint()
|
||||||
|
switch {
|
||||||
|
case iv < jv:
|
||||||
|
return true
|
||||||
|
case jv < iv:
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return false
|
||||||
|
})
|
||||||
|
|
||||||
|
return ®istrypb.ServiceList{
|
||||||
|
Services: filtered,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (i *impl) Watch(req *registrypb.ListRequest, stream registrypb.Registry_WatchServer) error {
|
||||||
|
// listen for changes
|
||||||
|
ch := i.onChange.Bind()
|
||||||
|
defer i.onChange.Unbind(ch)
|
||||||
|
|
||||||
|
// force a check periodically
|
||||||
|
poll := time.NewTicker(pollInterval)
|
||||||
|
defer poll.Stop()
|
||||||
|
|
||||||
|
var prev *registrypb.ServiceList
|
||||||
|
for {
|
||||||
|
// retrieve the most recent list of services
|
||||||
|
lst, err := i.List(stream.Context(), req)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// only send a new list if something changed
|
||||||
|
if !proto.Equal(prev, lst) {
|
||||||
|
err = stream.Send(lst)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
prev = lst
|
||||||
|
|
||||||
|
// wait for an update
|
||||||
|
select {
|
||||||
|
case <-i.closed:
|
||||||
|
return nil
|
||||||
|
case <-stream.Context().Done():
|
||||||
|
return stream.Context().Err()
|
||||||
|
case <-ch:
|
||||||
|
case <-poll.C:
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (i *impl) Close() error {
|
||||||
|
var err error
|
||||||
|
i.closeOnce.Do(func() {
|
||||||
|
err = i.client.Close()
|
||||||
|
close(i.closed)
|
||||||
|
})
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (i *impl) listenForChanges(ctx context.Context) {
|
||||||
|
ctx, cancel := context.WithCancel(ctx)
|
||||||
|
go func() {
|
||||||
|
<-i.closed
|
||||||
|
cancel()
|
||||||
|
}()
|
||||||
|
|
||||||
|
bo := backoff.NewExponentialBackOff()
|
||||||
|
bo.MaxElapsedTime = 0
|
||||||
|
|
||||||
|
outer:
|
||||||
|
for {
|
||||||
|
pubsub := i.client.Subscribe(ctx, registryUpdateKey)
|
||||||
|
for {
|
||||||
|
msg, err := pubsub.Receive(ctx)
|
||||||
|
if err != nil {
|
||||||
|
_ = pubsub.Close()
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return
|
||||||
|
case <-time.After(bo.NextBackOff()):
|
||||||
|
}
|
||||||
|
continue outer
|
||||||
|
}
|
||||||
|
bo.Reset()
|
||||||
|
|
||||||
|
switch msg.(type) {
|
||||||
|
case *redis.Message:
|
||||||
|
i.onChange.Broadcast(ctx)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (i *impl) runReport(ctx context.Context, updates []*registrypb.Service) ([]*registrypb.Service, error) {
|
||||||
|
args := []interface{}{
|
||||||
|
i.cfg.getNow().UnixNano() / int64(time.Millisecond), // current_time
|
||||||
|
i.cfg.ttl.Milliseconds(), // ttl
|
||||||
|
}
|
||||||
|
for _, svc := range updates {
|
||||||
|
args = append(args, i.getRegistryHashKey(svc))
|
||||||
|
}
|
||||||
|
res, err := i.client.Eval(ctx, lua.Registry, []string{registryKey}, args...).Result()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
_, err = i.client.Publish(ctx, registryUpdateKey, time.Now().Format(time.RFC3339Nano)).Result()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if values, ok := res.([]interface{}); ok {
|
||||||
|
var all []*registrypb.Service
|
||||||
|
for _, value := range values {
|
||||||
|
svc, err := i.getServiceFromRegistryHashKey(fmt.Sprint(value))
|
||||||
|
if err != nil {
|
||||||
|
log.Warn(ctx).Err(err).Msg("redis: invalid service")
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
all = append(all, svc)
|
||||||
|
}
|
||||||
|
return all, nil
|
||||||
|
}
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (i *impl) getServiceFromRegistryHashKey(key string) (*registrypb.Service, error) {
|
||||||
|
idx := strings.Index(key, "|")
|
||||||
|
if idx == -1 {
|
||||||
|
return nil, fmt.Errorf("redis: invalid service entry in hash: %s", key)
|
||||||
|
}
|
||||||
|
|
||||||
|
svcKindStr := key[:idx]
|
||||||
|
svcEndpointStr := key[idx+1:]
|
||||||
|
|
||||||
|
svcKind, ok := registrypb.ServiceKind_value[svcKindStr]
|
||||||
|
if !ok {
|
||||||
|
return nil, fmt.Errorf("redis: unknown service kind: %s", svcKindStr)
|
||||||
|
}
|
||||||
|
|
||||||
|
svc := ®istrypb.Service{
|
||||||
|
Kind: registrypb.ServiceKind(svcKind),
|
||||||
|
Endpoint: svcEndpointStr,
|
||||||
|
}
|
||||||
|
return svc, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (i *impl) getRegistryHashKey(svc *registrypb.Service) string {
|
||||||
|
return svc.GetKind().String() + "|" + svc.GetEndpoint()
|
||||||
|
}
|
196
internal/registry/redis/redis_test.go
Normal file
196
internal/registry/redis/redis_test.go
Normal file
|
@ -0,0 +1,196 @@
|
||||||
|
package redis
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net"
|
||||||
|
"os"
|
||||||
|
"runtime"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
"golang.org/x/sync/errgroup"
|
||||||
|
"google.golang.org/grpc"
|
||||||
|
|
||||||
|
"github.com/pomerium/pomerium/internal/testutil"
|
||||||
|
registrypb "github.com/pomerium/pomerium/pkg/grpc/registry"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestReport(t *testing.T) {
|
||||||
|
if os.Getenv("GITHUB_ACTION") != "" && runtime.GOOS == "darwin" {
|
||||||
|
t.Skip("Github action can not run docker on MacOS")
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
require.NoError(t, testutil.WithTestRedis(false, func(rawURL string) error {
|
||||||
|
tm := time.Now()
|
||||||
|
|
||||||
|
i, err := New(rawURL,
|
||||||
|
WithGetNow(func() time.Time {
|
||||||
|
return tm
|
||||||
|
}),
|
||||||
|
WithTTL(time.Second*10))
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer func() { _ = i.Close() }()
|
||||||
|
|
||||||
|
_, err = i.Report(ctx, ®istrypb.RegisterRequest{
|
||||||
|
Services: []*registrypb.Service{
|
||||||
|
{Kind: registrypb.ServiceKind_AUTHORIZE, Endpoint: "https://authorize.example.com"},
|
||||||
|
{Kind: registrypb.ServiceKind_AUTHENTICATE, Endpoint: "https://authenticate.example.com"},
|
||||||
|
{Kind: registrypb.ServiceKind_PROXY, Endpoint: "https://proxy.example.com"},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// move forward 5 seconds
|
||||||
|
tm = tm.Add(time.Second * 5)
|
||||||
|
_, err = i.Report(ctx, ®istrypb.RegisterRequest{
|
||||||
|
Services: []*registrypb.Service{
|
||||||
|
{Kind: registrypb.ServiceKind_AUTHENTICATE, Endpoint: "https://authenticate.example.com"},
|
||||||
|
{Kind: registrypb.ServiceKind_PROXY, Endpoint: "https://proxy.example.com"},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
lst, err := i.List(ctx, ®istrypb.ListRequest{
|
||||||
|
Kinds: []registrypb.ServiceKind{
|
||||||
|
registrypb.ServiceKind_AUTHORIZE,
|
||||||
|
registrypb.ServiceKind_PROXY,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, []*registrypb.Service{
|
||||||
|
{Kind: registrypb.ServiceKind_AUTHORIZE, Endpoint: "https://authorize.example.com"},
|
||||||
|
{Kind: registrypb.ServiceKind_PROXY, Endpoint: "https://proxy.example.com"},
|
||||||
|
}, lst.GetServices(), "should list selected services")
|
||||||
|
|
||||||
|
// move forward 6 seconds
|
||||||
|
tm = tm.Add(time.Second * 6)
|
||||||
|
lst, err = i.List(ctx, ®istrypb.ListRequest{
|
||||||
|
Kinds: []registrypb.ServiceKind{
|
||||||
|
registrypb.ServiceKind_AUTHORIZE,
|
||||||
|
registrypb.ServiceKind_PROXY,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, []*registrypb.Service{
|
||||||
|
{Kind: registrypb.ServiceKind_PROXY, Endpoint: "https://proxy.example.com"},
|
||||||
|
}, lst.GetServices(), "should expire old services")
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWatch(t *testing.T) {
|
||||||
|
if os.Getenv("GITHUB_ACTION") != "" && runtime.GOOS == "darwin" {
|
||||||
|
t.Skip("Github action can not run docker on MacOS")
|
||||||
|
}
|
||||||
|
|
||||||
|
require.NoError(t, testutil.WithTestRedis(false, func(rawURL string) error {
|
||||||
|
ctx, clearTimeout := context.WithTimeout(context.Background(), time.Second*15)
|
||||||
|
defer clearTimeout()
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(ctx)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
tm := time.Now()
|
||||||
|
i, err := New(rawURL,
|
||||||
|
WithGetNow(func() time.Time {
|
||||||
|
return tm
|
||||||
|
}),
|
||||||
|
WithTTL(time.Second*10))
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
li, err := net.Listen("tcp", "127.0.0.1:0")
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer li.Close()
|
||||||
|
|
||||||
|
srv := grpc.NewServer()
|
||||||
|
registrypb.RegisterRegistryServer(srv, i)
|
||||||
|
eg, ctx := errgroup.WithContext(ctx)
|
||||||
|
eg.Go(func() error {
|
||||||
|
<-ctx.Done()
|
||||||
|
srv.Stop()
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
eg.Go(func() error {
|
||||||
|
return srv.Serve(li)
|
||||||
|
})
|
||||||
|
eg.Go(func() error {
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
cc, err := grpc.Dial(li.Addr().String(), grpc.WithInsecure())
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
client := registrypb.NewRegistryClient(cc)
|
||||||
|
|
||||||
|
// store the initial services
|
||||||
|
_, err = client.Report(ctx, ®istrypb.RegisterRequest{
|
||||||
|
Services: []*registrypb.Service{
|
||||||
|
{Kind: registrypb.ServiceKind_AUTHENTICATE, Endpoint: "http://authenticate1.example.com"},
|
||||||
|
{Kind: registrypb.ServiceKind_AUTHORIZE, Endpoint: "http://authorize2.example.com"},
|
||||||
|
{Kind: registrypb.ServiceKind_AUTHORIZE, Endpoint: "http://authorize1.example.com"},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
stream, err := client.Watch(ctx, ®istrypb.ListRequest{
|
||||||
|
Kinds: []registrypb.ServiceKind{
|
||||||
|
registrypb.ServiceKind_AUTHORIZE,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer func() { _ = stream.CloseSend() }()
|
||||||
|
|
||||||
|
lst, err := stream.Recv()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
assert.Equal(t, []*registrypb.Service{
|
||||||
|
{Kind: registrypb.ServiceKind_AUTHORIZE, Endpoint: "http://authorize1.example.com"},
|
||||||
|
{Kind: registrypb.ServiceKind_AUTHORIZE, Endpoint: "http://authorize2.example.com"},
|
||||||
|
}, lst.GetServices())
|
||||||
|
|
||||||
|
// update authenticate
|
||||||
|
_, err = client.Report(ctx, ®istrypb.RegisterRequest{
|
||||||
|
Services: []*registrypb.Service{
|
||||||
|
{Kind: registrypb.ServiceKind_AUTHENTICATE, Endpoint: "http://authenticate1.example.com"},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// add an authorize
|
||||||
|
_, err = client.Report(ctx, ®istrypb.RegisterRequest{
|
||||||
|
Services: []*registrypb.Service{
|
||||||
|
{Kind: registrypb.ServiceKind_AUTHORIZE, Endpoint: "http://authorize3.example.com"},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
lst, err = stream.Recv()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
assert.Equal(t, []*registrypb.Service{
|
||||||
|
{Kind: registrypb.ServiceKind_AUTHORIZE, Endpoint: "http://authorize1.example.com"},
|
||||||
|
{Kind: registrypb.ServiceKind_AUTHORIZE, Endpoint: "http://authorize2.example.com"},
|
||||||
|
{Kind: registrypb.ServiceKind_AUTHORIZE, Endpoint: "http://authorize3.example.com"},
|
||||||
|
}, lst.GetServices())
|
||||||
|
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
require.NoError(t, eg.Wait())
|
||||||
|
return nil
|
||||||
|
}))
|
||||||
|
}
|
|
@ -1,2 +1,14 @@
|
||||||
// Package registry implements a service registry server.
|
// Package registry implements a service registry server.
|
||||||
package registry
|
package registry
|
||||||
|
|
||||||
|
import (
|
||||||
|
"io"
|
||||||
|
|
||||||
|
registrypb "github.com/pomerium/pomerium/pkg/grpc/registry"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Interface is a registry implementation.
|
||||||
|
type Interface interface {
|
||||||
|
registrypb.RegistryServer
|
||||||
|
io.Closer
|
||||||
|
}
|
||||||
|
|
|
@ -9,11 +9,12 @@ import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/cenkalti/backoff/v4"
|
"github.com/cenkalti/backoff/v4"
|
||||||
redis "github.com/go-redis/redis/v8"
|
"github.com/go-redis/redis/v8"
|
||||||
"github.com/golang/protobuf/proto"
|
"github.com/golang/protobuf/proto"
|
||||||
"google.golang.org/protobuf/types/known/timestamppb"
|
"google.golang.org/protobuf/types/known/timestamppb"
|
||||||
|
|
||||||
"github.com/pomerium/pomerium/internal/log"
|
"github.com/pomerium/pomerium/internal/log"
|
||||||
|
"github.com/pomerium/pomerium/internal/redisutil"
|
||||||
"github.com/pomerium/pomerium/internal/signal"
|
"github.com/pomerium/pomerium/internal/signal"
|
||||||
"github.com/pomerium/pomerium/internal/telemetry/metrics"
|
"github.com/pomerium/pomerium/internal/telemetry/metrics"
|
||||||
"github.com/pomerium/pomerium/internal/telemetry/trace"
|
"github.com/pomerium/pomerium/internal/telemetry/trace"
|
||||||
|
@ -28,14 +29,14 @@ const (
|
||||||
|
|
||||||
// we rely on transactions in redis, so all redis-cluster keys need to be
|
// we rely on transactions in redis, so all redis-cluster keys need to be
|
||||||
// on the same node. Using a `hash tag` gives us this capability.
|
// on the same node. Using a `hash tag` gives us this capability.
|
||||||
serverVersionKey = "{pomerium_v3}.server_version"
|
serverVersionKey = redisutil.KeyPrefix + "server_version"
|
||||||
lastVersionKey = "{pomerium_v3}.last_version"
|
lastVersionKey = redisutil.KeyPrefix + "last_version"
|
||||||
lastVersionChKey = "{pomerium_v3}.last_version_ch"
|
lastVersionChKey = redisutil.KeyPrefix + "last_version_ch"
|
||||||
recordHashKey = "{pomerium_v3}.records"
|
recordHashKey = redisutil.KeyPrefix + "records"
|
||||||
changesSetKey = "{pomerium_v3}.changes"
|
changesSetKey = redisutil.KeyPrefix + "changes"
|
||||||
optionsKey = "{pomerium_v3}.options"
|
optionsKey = redisutil.KeyPrefix + "options"
|
||||||
|
|
||||||
recordTypeChangesKeyTpl = "{pomerium_v3}.changes.%s"
|
recordTypeChangesKeyTpl = redisutil.KeyPrefix + "changes.%s"
|
||||||
)
|
)
|
||||||
|
|
||||||
// custom errors
|
// custom errors
|
||||||
|
@ -76,7 +77,7 @@ func New(rawURL string, options ...Option) (*Backend, error) {
|
||||||
onChange: signal.New(),
|
onChange: signal.New(),
|
||||||
}
|
}
|
||||||
var err error
|
var err error
|
||||||
backend.client, err = newClientFromURL(rawURL, backend.cfg.tls)
|
backend.client, err = redisutil.NewClientFromURL(rawURL, backend.cfg.tls)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue