registry: implement redis backend (#2179)

This commit is contained in:
Caleb Doxsey 2021-05-10 10:33:37 -06:00 committed by GitHub
parent 28155314e9
commit a54d43b937
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
21 changed files with 772 additions and 64 deletions

View file

@ -22,6 +22,7 @@ import (
"github.com/pomerium/pomerium/internal/version" "github.com/pomerium/pomerium/internal/version"
"github.com/pomerium/pomerium/pkg/cryptutil" "github.com/pomerium/pomerium/pkg/cryptutil"
"github.com/pomerium/pomerium/pkg/grpc/databroker" "github.com/pomerium/pomerium/pkg/grpc/databroker"
"github.com/pomerium/pomerium/pkg/grpc/registry"
"github.com/pomerium/pomerium/pkg/grpcutil" "github.com/pomerium/pomerium/pkg/grpcutil"
) )
@ -116,6 +117,7 @@ func (c *DataBroker) OnConfigChange(ctx context.Context, cfg *config.Config) {
func (c *DataBroker) Register(grpcServer *grpc.Server) { func (c *DataBroker) Register(grpcServer *grpc.Server) {
databroker.RegisterDataBrokerServiceServer(grpcServer, c.dataBrokerServer) databroker.RegisterDataBrokerServiceServer(grpcServer, c.dataBrokerServer)
directory.RegisterDirectoryServiceServer(grpcServer, c) directory.RegisterDirectoryServiceServer(grpcServer, c)
registry.RegisterRegistryServer(grpcServer, c.dataBrokerServer)
} }
// Run runs the databroker components. // Run runs the databroker components.

View file

@ -8,6 +8,7 @@ import (
"github.com/pomerium/pomerium/config" "github.com/pomerium/pomerium/config"
"github.com/pomerium/pomerium/internal/databroker" "github.com/pomerium/pomerium/internal/databroker"
databrokerpb "github.com/pomerium/pomerium/pkg/grpc/databroker" databrokerpb "github.com/pomerium/pomerium/pkg/grpc/databroker"
registrypb "github.com/pomerium/pomerium/pkg/grpc/registry"
"github.com/pomerium/pomerium/pkg/grpcutil" "github.com/pomerium/pomerium/pkg/grpcutil"
) )
@ -51,6 +52,8 @@ func (srv *dataBrokerServer) setKey(cfg *config.Config) {
srv.sharedKey.Store(bs) srv.sharedKey.Store(bs)
} }
// Databroker functions
func (srv *dataBrokerServer) Get(ctx context.Context, req *databrokerpb.GetRequest) (*databrokerpb.GetResponse, error) { func (srv *dataBrokerServer) Get(ctx context.Context, req *databrokerpb.GetRequest) (*databrokerpb.GetResponse, error) {
if err := grpcutil.RequireSignedJWT(ctx, srv.sharedKey.Load().([]byte)); err != nil { if err := grpcutil.RequireSignedJWT(ctx, srv.sharedKey.Load().([]byte)); err != nil {
return nil, err return nil, err
@ -92,3 +95,26 @@ func (srv *dataBrokerServer) SyncLatest(req *databrokerpb.SyncLatestRequest, str
} }
return srv.server.SyncLatest(req, stream) return srv.server.SyncLatest(req, stream)
} }
// Registry functions
func (srv *dataBrokerServer) Report(ctx context.Context, req *registrypb.RegisterRequest) (*registrypb.RegisterResponse, error) {
if err := grpcutil.RequireSignedJWT(ctx, srv.sharedKey.Load().([]byte)); err != nil {
return nil, err
}
return srv.server.Report(ctx, req)
}
func (srv *dataBrokerServer) List(ctx context.Context, req *registrypb.ListRequest) (*registrypb.ServiceList, error) {
if err := grpcutil.RequireSignedJWT(ctx, srv.sharedKey.Load().([]byte)); err != nil {
return nil, err
}
return srv.server.List(ctx, req)
}
func (srv *dataBrokerServer) Watch(req *registrypb.ListRequest, stream registrypb.Registry_WatchServer) error {
if err := grpcutil.RequireSignedJWT(stream.Context(), srv.sharedKey.Load().([]byte)); err != nil {
return err
}
return srv.server.Watch(req, stream)
}

View file

@ -1,7 +0,0 @@
package pomerium
import "time"
const (
registryTTL = time.Minute
)

View file

@ -27,7 +27,6 @@ import (
"github.com/pomerium/pomerium/internal/registry" "github.com/pomerium/pomerium/internal/registry"
"github.com/pomerium/pomerium/internal/urlutil" "github.com/pomerium/pomerium/internal/urlutil"
"github.com/pomerium/pomerium/internal/version" "github.com/pomerium/pomerium/internal/version"
registry_pb "github.com/pomerium/pomerium/pkg/grpc/registry"
"github.com/pomerium/pomerium/proxy" "github.com/pomerium/pomerium/proxy"
) )
@ -110,10 +109,6 @@ func Run(ctx context.Context, configFile string) error {
if err != nil { if err != nil {
return fmt.Errorf("setting up databroker: %w", err) return fmt.Errorf("setting up databroker: %w", err)
} }
if err = setupRegistryServer(src, controlPlane); err != nil {
return fmt.Errorf("setting up registry: %w", err)
}
} }
if err = setupRegistryReporter(ctx, src); err != nil { if err = setupRegistryReporter(ctx, src); err != nil {
@ -213,13 +208,6 @@ func setupDataBroker(ctx context.Context, src config.Source, controlPlane *contr
return svc, nil return svc, nil
} }
func setupRegistryServer(src config.Source, controlPlane *controlplane.Server) error {
svc := registry.NewInMemoryServer(context.TODO(), registryTTL)
registry_pb.RegisterRegistryServer(controlPlane.GRPCServer, svc)
log.Info(context.TODO()).Msg("enabled service discovery")
return nil
}
func setupRegistryReporter(ctx context.Context, src config.Source) error { func setupRegistryReporter(ctx context.Context, src config.Source) error {
reporter := new(registry.Reporter) reporter := new(registry.Reporter)
src.OnConfigChange(ctx, reporter.OnConfigChange) src.OnConfigChange(ctx, reporter.OnConfigChange)

View file

@ -17,6 +17,8 @@ var (
DefaultStorageType = "memory" DefaultStorageType = "memory"
// DefaultGetAllPageSize is the default page size for GetAll calls. // DefaultGetAllPageSize is the default page size for GetAll calls.
DefaultGetAllPageSize = 50 DefaultGetAllPageSize = 50
// DefaultRegistryTTL is the default registry time to live.
DefaultRegistryTTL = time.Minute
) )
type serverConfig struct { type serverConfig struct {
@ -28,6 +30,7 @@ type serverConfig struct {
storageCertSkipVerify bool storageCertSkipVerify bool
storageCertificate *tls.Certificate storageCertificate *tls.Certificate
getAllPageSize int getAllPageSize int
registryTTL time.Duration
} }
func newServerConfig(options ...ServerOption) *serverConfig { func newServerConfig(options ...ServerOption) *serverConfig {
@ -35,6 +38,7 @@ func newServerConfig(options ...ServerOption) *serverConfig {
WithDeletePermanentlyAfter(DefaultDeletePermanentlyAfter)(cfg) WithDeletePermanentlyAfter(DefaultDeletePermanentlyAfter)(cfg)
WithStorageType(DefaultStorageType)(cfg) WithStorageType(DefaultStorageType)(cfg)
WithGetAllPageSize(DefaultGetAllPageSize)(cfg) WithGetAllPageSize(DefaultGetAllPageSize)(cfg)
WithRegistryTTL(DefaultRegistryTTL)(cfg)
for _, option := range options { for _, option := range options {
option(cfg) option(cfg)
} }
@ -60,6 +64,13 @@ func WithGetAllPageSize(pageSize int) ServerOption {
} }
} }
// WithRegistryTTL sets the registry time to live in the config.
func WithRegistryTTL(ttl time.Duration) ServerOption {
return func(cfg *serverConfig) {
cfg.registryTTL = ttl
}
}
// WithGetSharedKey sets the secret in the config. // WithGetSharedKey sets the secret in the config.
func WithGetSharedKey(getSharedKey func() ([]byte, error)) ServerOption { func WithGetSharedKey(getSharedKey func() ([]byte, error)) ServerOption {
return func(cfg *serverConfig) { return func(cfg *serverConfig) {

View file

@ -0,0 +1,109 @@
package databroker
import (
"context"
"fmt"
"github.com/pomerium/pomerium/config"
"github.com/pomerium/pomerium/internal/log"
"github.com/pomerium/pomerium/internal/registry"
"github.com/pomerium/pomerium/internal/registry/inmemory"
"github.com/pomerium/pomerium/internal/registry/redis"
"github.com/pomerium/pomerium/internal/telemetry/trace"
registrypb "github.com/pomerium/pomerium/pkg/grpc/registry"
)
type registryWatchServer struct {
registrypb.Registry_WatchServer
ctx context.Context
}
func (stream registryWatchServer) Context() context.Context {
return stream.ctx
}
// Report calls the registry Report method.
func (srv *Server) Report(ctx context.Context, req *registrypb.RegisterRequest) (*registrypb.RegisterResponse, error) {
ctx, span := trace.StartSpan(ctx, "databroker.grpc.Report")
defer span.End()
r, err := srv.getRegistry()
if err != nil {
return nil, err
}
return r.Report(ctx, req)
}
// List calls the registry List method.
func (srv *Server) List(ctx context.Context, req *registrypb.ListRequest) (*registrypb.ServiceList, error) {
ctx, span := trace.StartSpan(ctx, "databroker.grpc.List")
defer span.End()
r, err := srv.getRegistry()
if err != nil {
return nil, err
}
return r.List(ctx, req)
}
// Watch calls the registry Watch method.
func (srv *Server) Watch(req *registrypb.ListRequest, stream registrypb.Registry_WatchServer) error {
ctx := stream.Context()
ctx, span := trace.StartSpan(ctx, "databroker.grpc.Watch")
defer span.End()
r, err := srv.getRegistry()
if err != nil {
return err
}
return r.Watch(req, registryWatchServer{
Registry_WatchServer: stream,
ctx: ctx,
})
}
func (srv *Server) getRegistry() (registry.Interface, error) {
// double-checked locking
srv.mu.RLock()
r := srv.registry
srv.mu.RUnlock()
if r == nil {
srv.mu.Lock()
r = srv.registry
var err error
if r == nil {
r, err = srv.newRegistryLocked()
srv.registry = r
}
srv.mu.Unlock()
if err != nil {
return nil, err
}
}
return r, nil
}
func (srv *Server) newRegistryLocked() (registry.Interface, error) {
ctx := context.Background()
switch srv.cfg.storageType {
case config.StorageInMemoryName:
log.Info(ctx).Msg("using in-memory registry")
return inmemory.New(ctx, srv.cfg.registryTTL), nil
case config.StorageRedisName:
log.Info(ctx).Msg("using redis registry")
r, err := redis.New(
srv.cfg.storageConnectionString,
redis.WithTLSConfig(srv.getTLSConfigLocked(ctx)),
)
if err != nil {
return nil, fmt.Errorf("failed to create new redis registry: %w", err)
}
return r, nil
}
return nil, fmt.Errorf("unsupported registry type: %s", srv.cfg.storageType)
}

View file

@ -15,6 +15,7 @@ import (
"github.com/pomerium/pomerium/config" "github.com/pomerium/pomerium/config"
"github.com/pomerium/pomerium/internal/log" "github.com/pomerium/pomerium/internal/log"
"github.com/pomerium/pomerium/internal/registry"
"github.com/pomerium/pomerium/internal/telemetry/trace" "github.com/pomerium/pomerium/internal/telemetry/trace"
"github.com/pomerium/pomerium/pkg/cryptutil" "github.com/pomerium/pomerium/pkg/cryptutil"
"github.com/pomerium/pomerium/pkg/grpc/databroker" "github.com/pomerium/pomerium/pkg/grpc/databroker"
@ -28,8 +29,9 @@ import (
type Server struct { type Server struct {
cfg *serverConfig cfg *serverConfig
mu sync.RWMutex mu sync.RWMutex
backend storage.Backend backend storage.Backend
registry registry.Interface
} }
// New creates a new server. // New creates a new server.
@ -60,6 +62,14 @@ func (srv *Server) UpdateConfig(options ...ServerOption) {
} }
srv.backend = nil srv.backend = nil
} }
if srv.registry != nil {
err := srv.registry.Close()
if err != nil {
log.Error(ctx).Err(err).Msg("databroker: error closing registry")
}
srv.registry = nil
}
} }
// Get gets a record from the in-memory list. // Get gets a record from the in-memory list.
@ -288,18 +298,6 @@ func (srv *Server) getBackend() (backend storage.Backend, err error) {
func (srv *Server) newBackendLocked() (backend storage.Backend, err error) { func (srv *Server) newBackendLocked() (backend storage.Backend, err error) {
ctx := context.Background() ctx := context.Background()
caCertPool, err := cryptutil.GetCertPool("", srv.cfg.storageCAFile)
if err != nil {
log.Warn(ctx).Err(err).Msg("failed to read databroker CA file")
}
tlsConfig := &tls.Config{
RootCAs: caCertPool,
// nolint: gosec
InsecureSkipVerify: srv.cfg.storageCertSkipVerify,
}
if srv.cfg.storageCertificate != nil {
tlsConfig.Certificates = []tls.Certificate{*srv.cfg.storageCertificate}
}
switch srv.cfg.storageType { switch srv.cfg.storageType {
case config.StorageInMemoryName: case config.StorageInMemoryName:
@ -309,7 +307,7 @@ func (srv *Server) newBackendLocked() (backend storage.Backend, err error) {
log.Info(ctx).Msg("using redis store") log.Info(ctx).Msg("using redis store")
backend, err = redis.New( backend, err = redis.New(
srv.cfg.storageConnectionString, srv.cfg.storageConnectionString,
redis.WithTLSConfig(tlsConfig), redis.WithTLSConfig(srv.getTLSConfigLocked(ctx)),
) )
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to create new redis storage: %w", err) return nil, fmt.Errorf("failed to create new redis storage: %w", err)
@ -325,3 +323,19 @@ func (srv *Server) newBackendLocked() (backend storage.Backend, err error) {
} }
return backend, nil return backend, nil
} }
func (srv *Server) getTLSConfigLocked(ctx context.Context) *tls.Config {
caCertPool, err := cryptutil.GetCertPool("", srv.cfg.storageCAFile)
if err != nil {
log.Warn(ctx).Err(err).Msg("failed to read databroker CA file")
}
tlsConfig := &tls.Config{
RootCAs: caCertPool,
// nolint: gosec
InsecureSkipVerify: srv.cfg.storageCertSkipVerify,
}
if srv.cfg.storageCertificate != nil {
tlsConfig.Certificates = []tls.Certificate{*srv.cfg.storageCertificate}
}
return tlsConfig
}

View file

@ -1,4 +1,4 @@
package redis package redisutil
import ( import (
"crypto/tls" "crypto/tls"
@ -43,15 +43,16 @@ var (
) )
) )
func newClientFromURL(rawurl string, tlsConfig *tls.Config) (redis.UniversalClient, error) { // NewClientFromURL creates a new redis client by parsing the raw URL.
u, err := url.Parse(rawurl) func NewClientFromURL(rawURL string, tlsConfig *tls.Config) (redis.UniversalClient, error) {
u, err := url.Parse(rawURL)
if err != nil { if err != nil {
return nil, err return nil, err
} }
switch { switch {
case standardSchemes.Has(u.Scheme): case standardSchemes.Has(u.Scheme):
opts, err := redis.ParseURL(rawurl) opts, err := redis.ParseURL(rawURL)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -62,7 +63,7 @@ func newClientFromURL(rawurl string, tlsConfig *tls.Config) (redis.UniversalClie
return redis.NewClient(opts), nil return redis.NewClient(opts), nil
case clusterSchemes.Has(u.Scheme): case clusterSchemes.Has(u.Scheme):
opts, err := ParseClusterURL(rawurl) opts, err := ParseClusterURL(rawURL)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -72,7 +73,7 @@ func newClientFromURL(rawurl string, tlsConfig *tls.Config) (redis.UniversalClie
return redis.NewClusterClient(opts), nil return redis.NewClusterClient(opts), nil
case sentinelSchemes.Has(u.Scheme): case sentinelSchemes.Has(u.Scheme):
opts, err := ParseSentinelURL(rawurl) opts, err := ParseSentinelURL(rawURL)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -82,7 +83,7 @@ func newClientFromURL(rawurl string, tlsConfig *tls.Config) (redis.UniversalClie
return redis.NewFailoverClient(opts), nil return redis.NewFailoverClient(opts), nil
case sentinelClusterSchemes.Has(u.Scheme): case sentinelClusterSchemes.Has(u.Scheme):
opts, err := ParseSentinelURL(rawurl) opts, err := ParseSentinelURL(rawURL)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View file

@ -1,4 +1,4 @@
package redis package redisutil
import ( import (
"net/url" "net/url"

View file

@ -0,0 +1,5 @@
// Package redisutil contains functions for working with redis.
package redisutil
// KeyPrefix is the prefix used for all redis keys.
const KeyPrefix = "{pomerium_v3}."

View file

@ -1,14 +1,8 @@
package registry package registry
import ( import "time"
"time"
)
const ( const (
// callAfterTTLFactor will request to report back again after TTL/callAfterTTLFactor time
callAfterTTLFactor = 2
// purgeAfterTTLFactor will purge keys with TTL * purgeAfterTTLFactor time
purgeAfterTTLFactor = 1
// min reporting ttl // min reporting ttl
minTTL = time.Second minTTL = time.Second
// path metrics are available at // path metrics are available at

View file

@ -0,0 +1,8 @@
package inmemory
const (
// callAfterTTLFactor will request to report back again after TTL/callAfterTTLFactor time
callAfterTTLFactor = 2
// purgeAfterTTLFactor will purge keys with TTL * purgeAfterTTLFactor time
purgeAfterTTLFactor = 1
)

View file

@ -1,10 +1,12 @@
package registry // Package inmemory implements an in-memory registry.
package inmemory
import ( import (
"context" "context"
"sync" "sync"
"time" "time"
"github.com/pomerium/pomerium/internal/registry"
"github.com/pomerium/pomerium/internal/signal" "github.com/pomerium/pomerium/internal/signal"
pb "github.com/pomerium/pomerium/pkg/grpc/registry" pb "github.com/pomerium/pomerium/pkg/grpc/registry"
@ -31,9 +33,9 @@ type inMemoryKey struct {
endpoint string endpoint string
} }
// NewInMemoryServer constructs a new registry tracking service that operates in RAM // New constructs a new registry tracking service that operates in RAM
// as such, it is not usable for multi-node deployment where REDIS or other alternative should be used // as such, it is not usable for multi-node deployment where REDIS or other alternative should be used
func NewInMemoryServer(ctx context.Context, ttl time.Duration) pb.RegistryServer { func New(ctx context.Context, ttl time.Duration) registry.Interface {
srv := &inMemoryServer{ srv := &inMemoryServer{
ttl: ttl, ttl: ttl,
regs: make(map[inMemoryKey]*timestamppb.Timestamp), regs: make(map[inMemoryKey]*timestamppb.Timestamp),
@ -57,6 +59,11 @@ func (s *inMemoryServer) periodicCheck(ctx context.Context) {
} }
} }
// Close closes the in memory server.
func (s *inMemoryServer) Close() error {
return nil
}
// Report is periodically sent by each service to confirm it is still serving with the registry // Report is periodically sent by each service to confirm it is still serving with the registry
// data is persisted with a certain TTL // data is persisted with a certain TTL
func (s *inMemoryServer) Report(ctx context.Context, req *pb.RegisterRequest) (*pb.RegisterResponse, error) { func (s *inMemoryServer) Report(ctx context.Context, req *pb.RegisterRequest) (*pb.RegisterResponse, error) {

View file

@ -1,4 +1,4 @@
package registry_test package inmemory
import ( import (
"context" "context"
@ -8,7 +8,6 @@ import (
"testing" "testing"
"time" "time"
"github.com/pomerium/pomerium/internal/registry"
pb "github.com/pomerium/pomerium/pkg/grpc/registry" pb "github.com/pomerium/pomerium/pkg/grpc/registry"
"github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp"
@ -182,7 +181,7 @@ func newTestRegistry() (context.Context, pb.RegistryClient, func(), error) {
gs := grpc.NewServer() gs := grpc.NewServer()
ttl := time.Second ttl := time.Second
pb.RegisterRegistryServer(gs, registry.NewInMemoryServer(ctx, ttl)) pb.RegisterRegistryServer(gs, New(ctx, ttl))
go gs.Serve(l) go gs.Serve(l)
cancel.Append(gs.Stop) cancel.Append(gs.Stop)

View file

@ -0,0 +1,20 @@
// Package lua contains lua source code.
package lua
import (
"embed"
)
//go:embed registry.lua
var fs embed.FS
// Registry is the registry lua script
var Registry string
func init() {
bs, err := fs.ReadFile("registry.lua")
if err != nil {
panic(err)
}
Registry = string(bs)
}

View file

@ -0,0 +1,20 @@
-- ARGV = [current time in seconds, ttl in seconds, services ...]
local current_time = ARGV[1]
local ttl = ARGV[2]
-- update the service list
for i = 3, #ARGV, 1 do
redis.call('HSET', KEYS[1], ARGV[i], current_time + ttl)
end
-- retrieve all the services, removing any that have expired
local svcs = {}
local kvs = redis.call('HGETALL', KEYS[1])
for i = 1, #kvs, 2 do
if kvs[i + 1] < current_time then
redis.call('HDEL', KEYS[1], kvs[i])
else
table.insert(svcs, kvs[i])
end
end
return svcs

View file

@ -0,0 +1,48 @@
package redis
import (
"crypto/tls"
"time"
)
const defaultTTL = time.Second * 30
type config struct {
tls *tls.Config
ttl time.Duration
getNow func() time.Time
}
// An Option modifies the config..
type Option func(*config)
// WithGetNow sets the time.Now function in the config.
func WithGetNow(getNow func() time.Time) Option {
return func(cfg *config) {
cfg.getNow = getNow
}
}
// WithTLSConfig sets the tls.Config in the config.
func WithTLSConfig(tlsConfig *tls.Config) Option {
return func(cfg *config) {
cfg.tls = tlsConfig
}
}
// WithTTL sets the ttl in the config.
func WithTTL(ttl time.Duration) Option {
return func(cfg *config) {
cfg.ttl = ttl
}
}
func getConfig(options ...Option) *config {
cfg := new(config)
WithGetNow(time.Now)(cfg)
WithTTL(defaultTTL)(cfg)
for _, o := range options {
o(cfg)
}
return cfg
}

View file

@ -0,0 +1,254 @@
// Package redis implements a registry in redis.
package redis
import (
"context"
"fmt"
"sort"
"strings"
"sync"
"time"
"github.com/cenkalti/backoff/v4"
"github.com/go-redis/redis/v8"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/types/known/durationpb"
"github.com/pomerium/pomerium/internal/log"
"github.com/pomerium/pomerium/internal/redisutil"
"github.com/pomerium/pomerium/internal/registry"
"github.com/pomerium/pomerium/internal/registry/redis/lua"
"github.com/pomerium/pomerium/internal/signal"
registrypb "github.com/pomerium/pomerium/pkg/grpc/registry"
)
const (
registryKey = redisutil.KeyPrefix + "registry"
registryUpdateKey = redisutil.KeyPrefix + "registry_changed_ch"
pollInterval = time.Second * 30
)
type impl struct {
cfg *config
client redis.UniversalClient
onChange *signal.Signal
closeOnce sync.Once
closed chan struct{}
}
// New creates a new registry implementation backend by redis.
func New(rawURL string, options ...Option) (registry.Interface, error) {
cfg := getConfig(options...)
client, err := redisutil.NewClientFromURL(rawURL, cfg.tls)
if err != nil {
return nil, err
}
i := &impl{
cfg: cfg,
client: client,
onChange: signal.New(),
closed: make(chan struct{}),
}
go i.listenForChanges(context.Background())
return i, nil
}
func (i *impl) Report(ctx context.Context, req *registrypb.RegisterRequest) (*registrypb.RegisterResponse, error) {
_, err := i.runReport(ctx, req.GetServices())
if err != nil {
return nil, err
}
return &registrypb.RegisterResponse{
CallBackAfter: durationpb.New(i.cfg.ttl / 2),
}, nil
}
func (i *impl) List(ctx context.Context, req *registrypb.ListRequest) (*registrypb.ServiceList, error) {
all, err := i.runReport(ctx, nil)
if err != nil {
return nil, err
}
include := map[registrypb.ServiceKind]struct{}{}
for _, kind := range req.GetKinds() {
include[kind] = struct{}{}
}
filtered := make([]*registrypb.Service, 0, len(all))
for _, svc := range all {
if _, ok := include[svc.GetKind()]; !ok {
continue
}
filtered = append(filtered, svc)
}
sort.Slice(filtered, func(i, j int) bool {
{
iv, jv := filtered[i].GetKind(), filtered[j].GetKind()
switch {
case iv < jv:
return true
case jv < iv:
return false
}
}
{
iv, jv := filtered[i].GetEndpoint(), filtered[j].GetEndpoint()
switch {
case iv < jv:
return true
case jv < iv:
return false
}
}
return false
})
return &registrypb.ServiceList{
Services: filtered,
}, nil
}
func (i *impl) Watch(req *registrypb.ListRequest, stream registrypb.Registry_WatchServer) error {
// listen for changes
ch := i.onChange.Bind()
defer i.onChange.Unbind(ch)
// force a check periodically
poll := time.NewTicker(pollInterval)
defer poll.Stop()
var prev *registrypb.ServiceList
for {
// retrieve the most recent list of services
lst, err := i.List(stream.Context(), req)
if err != nil {
return err
}
// only send a new list if something changed
if !proto.Equal(prev, lst) {
err = stream.Send(lst)
if err != nil {
return err
}
}
prev = lst
// wait for an update
select {
case <-i.closed:
return nil
case <-stream.Context().Done():
return stream.Context().Err()
case <-ch:
case <-poll.C:
}
}
}
func (i *impl) Close() error {
var err error
i.closeOnce.Do(func() {
err = i.client.Close()
close(i.closed)
})
return err
}
func (i *impl) listenForChanges(ctx context.Context) {
ctx, cancel := context.WithCancel(ctx)
go func() {
<-i.closed
cancel()
}()
bo := backoff.NewExponentialBackOff()
bo.MaxElapsedTime = 0
outer:
for {
pubsub := i.client.Subscribe(ctx, registryUpdateKey)
for {
msg, err := pubsub.Receive(ctx)
if err != nil {
_ = pubsub.Close()
select {
case <-ctx.Done():
return
case <-time.After(bo.NextBackOff()):
}
continue outer
}
bo.Reset()
switch msg.(type) {
case *redis.Message:
i.onChange.Broadcast(ctx)
}
}
}
}
func (i *impl) runReport(ctx context.Context, updates []*registrypb.Service) ([]*registrypb.Service, error) {
args := []interface{}{
i.cfg.getNow().UnixNano() / int64(time.Millisecond), // current_time
i.cfg.ttl.Milliseconds(), // ttl
}
for _, svc := range updates {
args = append(args, i.getRegistryHashKey(svc))
}
res, err := i.client.Eval(ctx, lua.Registry, []string{registryKey}, args...).Result()
if err != nil {
return nil, err
}
_, err = i.client.Publish(ctx, registryUpdateKey, time.Now().Format(time.RFC3339Nano)).Result()
if err != nil {
return nil, err
}
if values, ok := res.([]interface{}); ok {
var all []*registrypb.Service
for _, value := range values {
svc, err := i.getServiceFromRegistryHashKey(fmt.Sprint(value))
if err != nil {
log.Warn(ctx).Err(err).Msg("redis: invalid service")
continue
}
all = append(all, svc)
}
return all, nil
}
return nil, nil
}
func (i *impl) getServiceFromRegistryHashKey(key string) (*registrypb.Service, error) {
idx := strings.Index(key, "|")
if idx == -1 {
return nil, fmt.Errorf("redis: invalid service entry in hash: %s", key)
}
svcKindStr := key[:idx]
svcEndpointStr := key[idx+1:]
svcKind, ok := registrypb.ServiceKind_value[svcKindStr]
if !ok {
return nil, fmt.Errorf("redis: unknown service kind: %s", svcKindStr)
}
svc := &registrypb.Service{
Kind: registrypb.ServiceKind(svcKind),
Endpoint: svcEndpointStr,
}
return svc, nil
}
func (i *impl) getRegistryHashKey(svc *registrypb.Service) string {
return svc.GetKind().String() + "|" + svc.GetEndpoint()
}

View file

@ -0,0 +1,196 @@
package redis
import (
"context"
"net"
"os"
"runtime"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/sync/errgroup"
"google.golang.org/grpc"
"github.com/pomerium/pomerium/internal/testutil"
registrypb "github.com/pomerium/pomerium/pkg/grpc/registry"
)
func TestReport(t *testing.T) {
if os.Getenv("GITHUB_ACTION") != "" && runtime.GOOS == "darwin" {
t.Skip("Github action can not run docker on MacOS")
}
ctx := context.Background()
require.NoError(t, testutil.WithTestRedis(false, func(rawURL string) error {
tm := time.Now()
i, err := New(rawURL,
WithGetNow(func() time.Time {
return tm
}),
WithTTL(time.Second*10))
require.NoError(t, err)
defer func() { _ = i.Close() }()
_, err = i.Report(ctx, &registrypb.RegisterRequest{
Services: []*registrypb.Service{
{Kind: registrypb.ServiceKind_AUTHORIZE, Endpoint: "https://authorize.example.com"},
{Kind: registrypb.ServiceKind_AUTHENTICATE, Endpoint: "https://authenticate.example.com"},
{Kind: registrypb.ServiceKind_PROXY, Endpoint: "https://proxy.example.com"},
},
})
require.NoError(t, err)
// move forward 5 seconds
tm = tm.Add(time.Second * 5)
_, err = i.Report(ctx, &registrypb.RegisterRequest{
Services: []*registrypb.Service{
{Kind: registrypb.ServiceKind_AUTHENTICATE, Endpoint: "https://authenticate.example.com"},
{Kind: registrypb.ServiceKind_PROXY, Endpoint: "https://proxy.example.com"},
},
})
require.NoError(t, err)
lst, err := i.List(ctx, &registrypb.ListRequest{
Kinds: []registrypb.ServiceKind{
registrypb.ServiceKind_AUTHORIZE,
registrypb.ServiceKind_PROXY,
},
})
require.NoError(t, err)
assert.Equal(t, []*registrypb.Service{
{Kind: registrypb.ServiceKind_AUTHORIZE, Endpoint: "https://authorize.example.com"},
{Kind: registrypb.ServiceKind_PROXY, Endpoint: "https://proxy.example.com"},
}, lst.GetServices(), "should list selected services")
// move forward 6 seconds
tm = tm.Add(time.Second * 6)
lst, err = i.List(ctx, &registrypb.ListRequest{
Kinds: []registrypb.ServiceKind{
registrypb.ServiceKind_AUTHORIZE,
registrypb.ServiceKind_PROXY,
},
})
require.NoError(t, err)
assert.Equal(t, []*registrypb.Service{
{Kind: registrypb.ServiceKind_PROXY, Endpoint: "https://proxy.example.com"},
}, lst.GetServices(), "should expire old services")
return nil
}))
}
func TestWatch(t *testing.T) {
if os.Getenv("GITHUB_ACTION") != "" && runtime.GOOS == "darwin" {
t.Skip("Github action can not run docker on MacOS")
}
require.NoError(t, testutil.WithTestRedis(false, func(rawURL string) error {
ctx, clearTimeout := context.WithTimeout(context.Background(), time.Second*15)
defer clearTimeout()
ctx, cancel := context.WithCancel(ctx)
defer cancel()
tm := time.Now()
i, err := New(rawURL,
WithGetNow(func() time.Time {
return tm
}),
WithTTL(time.Second*10))
require.NoError(t, err)
li, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
defer li.Close()
srv := grpc.NewServer()
registrypb.RegisterRegistryServer(srv, i)
eg, ctx := errgroup.WithContext(ctx)
eg.Go(func() error {
<-ctx.Done()
srv.Stop()
return nil
})
eg.Go(func() error {
return srv.Serve(li)
})
eg.Go(func() error {
defer cancel()
cc, err := grpc.Dial(li.Addr().String(), grpc.WithInsecure())
if err != nil {
return err
}
client := registrypb.NewRegistryClient(cc)
// store the initial services
_, err = client.Report(ctx, &registrypb.RegisterRequest{
Services: []*registrypb.Service{
{Kind: registrypb.ServiceKind_AUTHENTICATE, Endpoint: "http://authenticate1.example.com"},
{Kind: registrypb.ServiceKind_AUTHORIZE, Endpoint: "http://authorize2.example.com"},
{Kind: registrypb.ServiceKind_AUTHORIZE, Endpoint: "http://authorize1.example.com"},
},
})
if err != nil {
return err
}
stream, err := client.Watch(ctx, &registrypb.ListRequest{
Kinds: []registrypb.ServiceKind{
registrypb.ServiceKind_AUTHORIZE,
},
})
if err != nil {
return err
}
defer func() { _ = stream.CloseSend() }()
lst, err := stream.Recv()
if err != nil {
return err
}
assert.Equal(t, []*registrypb.Service{
{Kind: registrypb.ServiceKind_AUTHORIZE, Endpoint: "http://authorize1.example.com"},
{Kind: registrypb.ServiceKind_AUTHORIZE, Endpoint: "http://authorize2.example.com"},
}, lst.GetServices())
// update authenticate
_, err = client.Report(ctx, &registrypb.RegisterRequest{
Services: []*registrypb.Service{
{Kind: registrypb.ServiceKind_AUTHENTICATE, Endpoint: "http://authenticate1.example.com"},
},
})
if err != nil {
return err
}
// add an authorize
_, err = client.Report(ctx, &registrypb.RegisterRequest{
Services: []*registrypb.Service{
{Kind: registrypb.ServiceKind_AUTHORIZE, Endpoint: "http://authorize3.example.com"},
},
})
if err != nil {
return err
}
lst, err = stream.Recv()
if err != nil {
return err
}
assert.Equal(t, []*registrypb.Service{
{Kind: registrypb.ServiceKind_AUTHORIZE, Endpoint: "http://authorize1.example.com"},
{Kind: registrypb.ServiceKind_AUTHORIZE, Endpoint: "http://authorize2.example.com"},
{Kind: registrypb.ServiceKind_AUTHORIZE, Endpoint: "http://authorize3.example.com"},
}, lst.GetServices())
return nil
})
require.NoError(t, eg.Wait())
return nil
}))
}

View file

@ -1,2 +1,14 @@
// Package registry implements a service registry server. // Package registry implements a service registry server.
package registry package registry
import (
"io"
registrypb "github.com/pomerium/pomerium/pkg/grpc/registry"
)
// Interface is a registry implementation.
type Interface interface {
registrypb.RegistryServer
io.Closer
}

View file

@ -9,11 +9,12 @@ import (
"time" "time"
"github.com/cenkalti/backoff/v4" "github.com/cenkalti/backoff/v4"
redis "github.com/go-redis/redis/v8" "github.com/go-redis/redis/v8"
"github.com/golang/protobuf/proto" "github.com/golang/protobuf/proto"
"google.golang.org/protobuf/types/known/timestamppb" "google.golang.org/protobuf/types/known/timestamppb"
"github.com/pomerium/pomerium/internal/log" "github.com/pomerium/pomerium/internal/log"
"github.com/pomerium/pomerium/internal/redisutil"
"github.com/pomerium/pomerium/internal/signal" "github.com/pomerium/pomerium/internal/signal"
"github.com/pomerium/pomerium/internal/telemetry/metrics" "github.com/pomerium/pomerium/internal/telemetry/metrics"
"github.com/pomerium/pomerium/internal/telemetry/trace" "github.com/pomerium/pomerium/internal/telemetry/trace"
@ -28,14 +29,14 @@ const (
// we rely on transactions in redis, so all redis-cluster keys need to be // we rely on transactions in redis, so all redis-cluster keys need to be
// on the same node. Using a `hash tag` gives us this capability. // on the same node. Using a `hash tag` gives us this capability.
serverVersionKey = "{pomerium_v3}.server_version" serverVersionKey = redisutil.KeyPrefix + "server_version"
lastVersionKey = "{pomerium_v3}.last_version" lastVersionKey = redisutil.KeyPrefix + "last_version"
lastVersionChKey = "{pomerium_v3}.last_version_ch" lastVersionChKey = redisutil.KeyPrefix + "last_version_ch"
recordHashKey = "{pomerium_v3}.records" recordHashKey = redisutil.KeyPrefix + "records"
changesSetKey = "{pomerium_v3}.changes" changesSetKey = redisutil.KeyPrefix + "changes"
optionsKey = "{pomerium_v3}.options" optionsKey = redisutil.KeyPrefix + "options"
recordTypeChangesKeyTpl = "{pomerium_v3}.changes.%s" recordTypeChangesKeyTpl = redisutil.KeyPrefix + "changes.%s"
) )
// custom errors // custom errors
@ -76,7 +77,7 @@ func New(rawURL string, options ...Option) (*Backend, error) {
onChange: signal.New(), onChange: signal.New(),
} }
var err error var err error
backend.client, err = newClientFromURL(rawURL, backend.cfg.tls) backend.client, err = redisutil.NewClientFromURL(rawURL, backend.cfg.tls)
if err != nil { if err != nil {
return nil, err return nil, err
} }