From 24a9d627cd98da74d87dcbb8ce8799a61faa5ce0 Mon Sep 17 00:00:00 2001 From: Caleb Doxsey Date: Wed, 13 Jul 2022 09:14:47 -0600 Subject: [PATCH] postgres: registry support (#3454) --- internal/databroker/registry.go | 23 +++++- internal/sets/hash.go | 31 +++++++ pkg/storage/postgres/backend.go | 82 ++++++++++++++----- pkg/storage/postgres/migrate.go | 16 ++++ pkg/storage/postgres/option.go | 16 +++- pkg/storage/postgres/postgres.go | 78 ++++++++++++++++-- pkg/storage/postgres/registry.go | 109 +++++++++++++++++++++++++ pkg/storage/postgres/registry_test.go | 113 ++++++++++++++++++++++++++ pkg/storage/postgres/stream.go | 4 +- 9 files changed, 436 insertions(+), 36 deletions(-) create mode 100644 internal/sets/hash.go create mode 100644 pkg/storage/postgres/registry.go create mode 100644 pkg/storage/postgres/registry_test.go diff --git a/internal/databroker/registry.go b/internal/databroker/registry.go index b858246d7..e2d3d4997 100644 --- a/internal/databroker/registry.go +++ b/internal/databroker/registry.go @@ -3,6 +3,7 @@ package databroker import ( "context" "fmt" + "io" "github.com/pomerium/pomerium/config" "github.com/pomerium/pomerium/internal/log" @@ -11,6 +12,7 @@ import ( "github.com/pomerium/pomerium/internal/registry/redis" "github.com/pomerium/pomerium/internal/telemetry/trace" registrypb "github.com/pomerium/pomerium/pkg/grpc/registry" + "github.com/pomerium/pomerium/pkg/storage" ) type registryWatchServer struct { @@ -66,6 +68,11 @@ func (srv *Server) Watch(req *registrypb.ListRequest, stream registrypb.Registry } func (srv *Server) getRegistry() (registry.Interface, error) { + backend, err := srv.getBackend() + if err != nil { + return nil, err + } + // double-checked locking srv.mu.RLock() r := srv.registry @@ -75,7 +82,7 @@ func (srv *Server) getRegistry() (registry.Interface, error) { r = srv.registry var err error if r == nil { - r, err = srv.newRegistryLocked() + r, err = srv.newRegistryLocked(backend) srv.registry = r } srv.mu.Unlock() @@ 
-86,11 +93,21 @@ func (srv *Server) getRegistry() (registry.Interface, error) { return r, nil } -func (srv *Server) newRegistryLocked() (registry.Interface, error) { +func (srv *Server) newRegistryLocked(backend storage.Backend) (registry.Interface, error) { ctx := context.Background() + if hasRegistryServer, ok := backend.(interface { + RegistryServer() registrypb.RegistryServer + }); ok { + log.Info(ctx).Msg("using registry via storage") + return struct { + io.Closer + registrypb.RegistryServer + }{backend, hasRegistryServer.RegistryServer()}, nil + } + switch srv.cfg.storageType { - case config.StorageInMemoryName, config.StoragePostgresName: + case config.StorageInMemoryName: log.Info(ctx).Msg("using in-memory registry") return inmemory.New(ctx, srv.cfg.registryTTL), nil case config.StorageRedisName: diff --git a/internal/sets/hash.go b/internal/sets/hash.go new file mode 100644 index 000000000..b9e6a2275 --- /dev/null +++ b/internal/sets/hash.go @@ -0,0 +1,31 @@ +package sets + +// A Hash is a set implemented via a map. +type Hash[T comparable] struct { + m map[T]struct{} +} + +// NewHash creates a new Hash set. +func NewHash[T comparable]() *Hash[T] { + return &Hash[T]{ + m: make(map[T]struct{}), + } +} + +// Add adds a value to the set. +func (s *Hash[T]) Add(elements ...T) { + for _, element := range elements { + s.m[element] = struct{}{} + } +} + +// Has returns true if the element is in the set. +func (s *Hash[T]) Has(element T) bool { + _, ok := s.m[element] + return ok +} + +// Size returns the size of the set. +func (s *Hash[T]) Size() int { + return len(s.m) +} diff --git a/pkg/storage/postgres/backend.go b/pkg/storage/postgres/backend.go index d5f6641ac..81e48a33e 100644 --- a/pkg/storage/postgres/backend.go +++ b/pkg/storage/postgres/backend.go @@ -2,6 +2,7 @@ package postgres import ( "context" + "errors" "fmt" "sync" "time" @@ -20,9 +21,10 @@ import ( // Backend is a storage Backend implemented with Postgres. 
type Backend struct { - cfg *config - dsn string - onChange *signal.Signal + cfg *config + dsn string + onRecordChange *signal.Signal + onServiceChange *signal.Signal closeCtx context.Context close context.CancelFunc @@ -35,11 +37,13 @@ type Backend struct { // New creates a new Backend. func New(dsn string, options ...Option) *Backend { backend := &Backend{ - cfg: getConfig(options...), - dsn: dsn, - onChange: signal.New(), + cfg: getConfig(options...), + dsn: dsn, + onRecordChange: signal.New(), + onServiceChange: signal.New(), } backend.closeCtx, backend.close = context.WithCancel(context.Background()) + go backend.doPeriodically(func(ctx context.Context) error { _, pool, err := backend.init(ctx) if err != nil { @@ -48,32 +52,64 @@ func New(dsn string, options ...Option) *Backend { return deleteChangesBefore(ctx, pool, time.Now().Add(-backend.cfg.expiry)) }, time.Minute) + go backend.doPeriodically(func(ctx context.Context) error { - _, pool, err := backend.init(backend.closeCtx) + _, pool, err := backend.init(ctx) if err != nil { return err } - conn, err := pool.Acquire(ctx) + rowCount, err := deleteExpiredServices(ctx, pool, time.Now()) if err != nil { return err } - defer conn.Release() - - _, err = conn.Exec(ctx, `LISTEN `+recordChangeNotifyName) - if err != nil { - return err + if rowCount > 0 { + err = signalServiceChange(ctx, pool) + if err != nil { + return err + } } - _, err = conn.Conn().WaitForNotification(ctx) - if err != nil { - return err - } - - backend.onChange.Broadcast(ctx) - return nil - }, time.Millisecond*100) + }, backend.cfg.registryTTL/2) + + // listen for changes and broadcast them via signals + for _, row := range []struct { + signal *signal.Signal + channel string + }{ + {backend.onRecordChange, recordChangeNotifyName}, + {backend.onServiceChange, serviceChangeNotifyName}, + } { + sig, ch := row.signal, row.channel + go backend.doPeriodically(func(ctx context.Context) error { + _, pool, err := backend.init(backend.closeCtx) + if err != 
nil { + return err + } + + conn, err := pool.Acquire(ctx) + if err != nil { + return err + } + defer conn.Release() + + _, err = conn.Exec(ctx, `LISTEN `+ch) + if err != nil { + return err + } + + _, err = conn.Conn().WaitForNotification(ctx) + if err != nil { + return err + } + + sig.Broadcast(ctx) + + return nil + }, time.Millisecond*100) + } + return backend } @@ -327,7 +363,9 @@ func (backend *Backend) doPeriodically(f func(ctx context.Context) error, dur ti case <-ticker.C: } } else { - log.Error(ctx).Err(err).Msg("storage/postgres") + if !errors.Is(err, context.Canceled) { + log.Error(ctx).Err(err).Msg("storage/postgres") + } select { case <-backend.closeCtx.Done(): return diff --git a/pkg/storage/postgres/migrate.go b/pkg/storage/postgres/migrate.go index d65e2809b..33fab28b4 100644 --- a/pkg/storage/postgres/migrate.go +++ b/pkg/storage/postgres/migrate.go @@ -112,6 +112,22 @@ var migrations = []func(context.Context, pgx.Tx) error{ return err } + return nil + }, + 3: func(ctx context.Context, tx pgx.Tx) error { + _, err := tx.Exec(ctx, ` + CREATE TABLE `+schemaName+`.`+servicesTableName+` ( + kind TEXT NOT NULL, + endpoint TEXT NOT NULL, + expires_at TIMESTAMPTZ NOT NULL, + + PRIMARY KEY (kind, endpoint) + ) + `) + if err != nil { + return err + } + return nil }, } diff --git a/pkg/storage/postgres/option.go b/pkg/storage/postgres/option.go index 6626f3b56..7e548d25a 100644 --- a/pkg/storage/postgres/option.go +++ b/pkg/storage/postgres/option.go @@ -4,10 +4,14 @@ import ( "time" ) -const defaultExpiry = time.Hour * 24 +const ( + defaultExpiry = time.Hour * 24 + defaultRegistryTTL = time.Second * 30 +) type config struct { - expiry time.Duration + expiry time.Duration + registryTTL time.Duration } // Option customizes a Backend. @@ -20,9 +24,17 @@ func WithExpiry(expiry time.Duration) Option { } } +// WithRegistryTTL sets the default registry TTL. 
+func WithRegistryTTL(ttl time.Duration) Option { + return func(cfg *config) { + cfg.registryTTL = ttl + } +} + func getConfig(options ...Option) *config { cfg := new(config) WithExpiry(defaultExpiry)(cfg) + WithRegistryTTL(defaultRegistryTTL)(cfg) for _, o := range options { o(cfg) } diff --git a/pkg/storage/postgres/postgres.go b/pkg/storage/postgres/postgres.go index 0c26a6081..e3e2fe60d 100644 --- a/pkg/storage/postgres/postgres.go +++ b/pkg/storage/postgres/postgres.go @@ -16,17 +16,20 @@ import ( "google.golang.org/protobuf/types/known/timestamppb" "github.com/pomerium/pomerium/pkg/grpc/databroker" + "github.com/pomerium/pomerium/pkg/grpc/registry" "github.com/pomerium/pomerium/pkg/storage" ) var ( - schemaName = "pomerium" - migrationInfoTableName = "migration_info" - recordsTableName = "records" - recordChangesTableName = "record_changes" - recordChangeNotifyName = "pomerium_record_change" - recordOptionsTableName = "record_options" - leasesTableName = "leases" + schemaName = "pomerium" + migrationInfoTableName = "migration_info" + recordsTableName = "records" + recordChangesTableName = "record_changes" + recordChangeNotifyName = "pomerium_record_change" + recordOptionsTableName = "record_options" + leasesTableName = "leases" + serviceChangeNotifyName = "pomerium_service_change" + servicesTableName = "services" ) type querier interface { @@ -43,6 +46,17 @@ func deleteChangesBefore(ctx context.Context, q querier, cutoff time.Time) error return err } +func deleteExpiredServices(ctx context.Context, q querier, cutoff time.Time) (rowCount int64, err error) { + cmd, err := q.Exec(ctx, ` + DELETE FROM `+schemaName+`.`+servicesTableName+` + WHERE expires_at < $1 + `, cutoff) + if err != nil { + return 0, err + } + return cmd.RowsAffected(), nil +} + func dup(record *databroker.Record) *databroker.Record { return proto.Clone(record).(*databroker.Record) } @@ -221,6 +235,40 @@ func listRecords(ctx context.Context, q querier, expr storage.FilterExpression, return 
records, rows.Err() } +func listServices(ctx context.Context, q querier) ([]*registry.Service, error) { + var services []*registry.Service + + query := ` + SELECT kind, endpoint + FROM ` + schemaName + `.` + servicesTableName + ` + ORDER BY kind, endpoint + ` + rows, err := q.Query(ctx, query) + if err != nil { + return nil, err + } + defer rows.Close() + + for rows.Next() { + var kind, endpoint string + err = rows.Scan(&kind, &endpoint) + if err != nil { + return nil, err + } + + services = append(services, &registry.Service{ + Kind: registry.ServiceKind(registry.ServiceKind_value[kind]), + Endpoint: endpoint, + }) + } + err = rows.Err() + if err != nil { + return nil, err + } + + return services, nil +} + func maybeAcquireLease(ctx context.Context, q querier, leaseName, leaseID string, ttl time.Duration) (leaseHolderID string, err error) { tbl := schemaName + "." + leasesTableName expiresAt := timestamptzFromTimestamppb(timestamppb.New(time.Now().Add(ttl))) @@ -283,6 +331,17 @@ func putRecordAndChange(ctx context.Context, q querier, record *databroker.Recor return nil } +func putService(ctx context.Context, q querier, svc *registry.Service, expiresAt time.Time) error { + query := ` + INSERT INTO ` + schemaName + `.` + servicesTableName + ` (kind, endpoint, expires_at) + VALUES ($1, $2, $3) + ON CONFLICT (kind, endpoint) DO UPDATE + SET expires_at=$3 + ` + _, err := q.Exec(ctx, query, svc.GetKind().String(), svc.GetEndpoint(), expiresAt) + return err +} + func setOptions(ctx context.Context, q querier, recordType string, options *databroker.Options) error { capacity := pgtype.Int8{Status: pgtype.Null} if options != nil && options.Capacity != nil { @@ -304,6 +363,11 @@ func signalRecordChange(ctx context.Context, q querier) error { return err } +func signalServiceChange(ctx context.Context, q querier) error { + _, err := q.Exec(ctx, `NOTIFY `+serviceChangeNotifyName) + return err +} + func jsonbFromAny(any *anypb.Any) (pgtype.JSONB, error) { if any == nil { return 
pgtype.JSONB{Status: pgtype.Null}, nil diff --git a/pkg/storage/postgres/registry.go b/pkg/storage/postgres/registry.go new file mode 100644 index 000000000..d0e7dfd12 --- /dev/null +++ b/pkg/storage/postgres/registry.go @@ -0,0 +1,109 @@ +package postgres + +import ( + "context" + "time" + + "github.com/golang/protobuf/proto" + "google.golang.org/protobuf/types/known/durationpb" + + "github.com/pomerium/pomerium/internal/sets" + "github.com/pomerium/pomerium/pkg/grpc/registry" +) + +type registryServer struct { + *Backend +} + +// RegistryServer returns a registry.RegistryServer for the backend. +func (backend *Backend) RegistryServer() registry.RegistryServer { + return registryServer{backend} +} + +// List lists services. +func (backend registryServer) List( + ctx context.Context, + req *registry.ListRequest, +) (*registry.ServiceList, error) { + _, pool, err := backend.init(ctx) + if err != nil { + return nil, err + } + + all, err := listServices(ctx, pool) + if err != nil { + return nil, err + } + + res := new(registry.ServiceList) + s := sets.NewHash[registry.ServiceKind]() + s.Add(req.GetKinds()...) + for _, svc := range all { + if s.Size() == 0 || s.Has(svc.GetKind()) { + res.Services = append(res.Services, svc) + } + } + return res, nil +} + +// Report registers services. +func (backend registryServer) Report( + ctx context.Context, + req *registry.RegisterRequest, +) (*registry.RegisterResponse, error) { + _, pool, err := backend.init(ctx) + if err != nil { + return nil, err + } + + for _, svc := range req.GetServices() { + err = putService(ctx, pool, svc, time.Now().Add(backend.cfg.registryTTL)) + if err != nil { + return nil, err + } + } + + err = signalServiceChange(ctx, pool) + if err != nil { + return nil, err + } + + return &registry.RegisterResponse{ + CallBackAfter: durationpb.New(backend.cfg.registryTTL / 2), + }, nil +} + +// Watch watches services. 
+func (backend registryServer) Watch( + req *registry.ListRequest, + srv registry.Registry_WatchServer, +) error { + ch := backend.onServiceChange.Bind() + defer backend.onServiceChange.Unbind(ch) + + ticker := time.NewTicker(watchPollInterval) + defer ticker.Stop() + + var prev *registry.ServiceList + for i := 0; ; i++ { + res, err := backend.List(srv.Context(), req) + if err != nil { + return err + } + + if i == 0 || !proto.Equal(res, prev) { + err = srv.Send(res) + if err != nil { + return err + } + prev = res + } + + select { + case <-srv.Context().Done(): + return srv.Context().Err() + case <-ch: + case <-ticker.C: + } + } +} diff --git a/pkg/storage/postgres/registry_test.go b/pkg/storage/postgres/registry_test.go new file mode 100644 index 000000000..1a3dfc9ff --- /dev/null +++ b/pkg/storage/postgres/registry_test.go @@ -0,0 +1,113 @@ +package postgres + +import ( + "context" + "errors" + "fmt" + "os" + "runtime" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "golang.org/x/sync/errgroup" + + "github.com/pomerium/pomerium/internal/testutil" + "github.com/pomerium/pomerium/pkg/grpc/registry" +) + +type mockRegistryWatchServer struct { + registry.Registry_WatchServer + context context.Context + send func(*registry.ServiceList) error +} + +func (m mockRegistryWatchServer) Context() context.Context { + return m.context +} + +func (m mockRegistryWatchServer) Send(res *registry.ServiceList) error { + return m.send(res) +} + +func TestRegistry(t *testing.T) { + if os.Getenv("GITHUB_ACTION") != "" && runtime.GOOS == "darwin" { + t.Skip("Github action can not run docker on MacOS") + } + + ctx, clearTimeout := context.WithTimeout(context.Background(), time.Second*10) + defer clearTimeout() + + require.NoError(t, testutil.WithTestPostgres(func(dsn string) error { + backend := New(dsn) + defer backend.Close() + + eg, ctx := errgroup.WithContext(ctx) + listResults := make(chan *registry.ServiceList) + eg.Go(func() 
error { + srv := mockRegistryWatchServer{ + context: ctx, + send: func(res *registry.ServiceList) error { + select { + case <-ctx.Done(): + return ctx.Err() + case listResults <- res: + } + return nil + }, + } + err := backend.RegistryServer().Watch(&registry.ListRequest{ + Kinds: []registry.ServiceKind{ + registry.ServiceKind_AUTHENTICATE, + registry.ServiceKind_CONSOLE, + }, + }, srv) + if errors.Is(err, context.Canceled) { + return nil + } + return err + }) + eg.Go(func() error { + select { + case <-ctx.Done(): + return ctx.Err() + case res := <-listResults: + testutil.AssertProtoEqual(t, &registry.ServiceList{}, res) + } + + res, err := backend.RegistryServer().Report(ctx, &registry.RegisterRequest{ + Services: []*registry.Service{ + {Kind: registry.ServiceKind_AUTHENTICATE, Endpoint: "authenticate.example.com"}, + {Kind: registry.ServiceKind_AUTHORIZE, Endpoint: "authorize.example.com"}, + {Kind: registry.ServiceKind_CONSOLE, Endpoint: "console.example.com"}, + }, + }) + if err != nil { + return fmt.Errorf("error reporting status: %w", err) + } + assert.NotEqual(t, 0, res.GetCallBackAfter()) + + select { + case <-ctx.Done(): + return ctx.Err() + case res := <-listResults: + testutil.AssertProtoEqual(t, &registry.ServiceList{ + Services: []*registry.Service{ + {Kind: registry.ServiceKind_AUTHENTICATE, Endpoint: "authenticate.example.com"}, + {Kind: registry.ServiceKind_CONSOLE, Endpoint: "console.example.com"}, + }, + }, res) + } + + return context.Canceled + }) + err := eg.Wait() + if errors.Is(err, context.Canceled) { + err = nil + } + assert.NoError(t, err) + + return nil + })) +} diff --git a/pkg/storage/postgres/stream.go b/pkg/storage/postgres/stream.go index feaf7b01b..3dc265f67 100644 --- a/pkg/storage/postgres/stream.go +++ b/pkg/storage/postgres/stream.go @@ -104,7 +104,7 @@ func newChangedRecordStream( recordType: recordType, recordVersion: recordVersion, ticker: time.NewTicker(watchPollInterval), - changed: 
backend.onRecordChange.Bind(), } stream.ctx, stream.cancel = contextutil.Merge(ctx, backend.closeCtx) return stream @@ -113,7 +113,7 @@ func newChangedRecordStream( func (stream *changedRecordStream) Close() error { stream.cancel() stream.ticker.Stop() - stream.backend.onChange.Unbind(stream.changed) + stream.backend.onRecordChange.Unbind(stream.changed) return nil }