postgres: registry support (#3454)

This commit is contained in:
Caleb Doxsey 2022-07-13 09:14:47 -06:00 committed by GitHub
parent ca8db7b619
commit 24a9d627cd
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
9 changed files with 436 additions and 36 deletions

View file

@ -3,6 +3,7 @@ package databroker
import ( import (
"context" "context"
"fmt" "fmt"
"io"
"github.com/pomerium/pomerium/config" "github.com/pomerium/pomerium/config"
"github.com/pomerium/pomerium/internal/log" "github.com/pomerium/pomerium/internal/log"
@ -11,6 +12,7 @@ import (
"github.com/pomerium/pomerium/internal/registry/redis" "github.com/pomerium/pomerium/internal/registry/redis"
"github.com/pomerium/pomerium/internal/telemetry/trace" "github.com/pomerium/pomerium/internal/telemetry/trace"
registrypb "github.com/pomerium/pomerium/pkg/grpc/registry" registrypb "github.com/pomerium/pomerium/pkg/grpc/registry"
"github.com/pomerium/pomerium/pkg/storage"
) )
type registryWatchServer struct { type registryWatchServer struct {
@ -66,6 +68,11 @@ func (srv *Server) Watch(req *registrypb.ListRequest, stream registrypb.Registry
} }
func (srv *Server) getRegistry() (registry.Interface, error) { func (srv *Server) getRegistry() (registry.Interface, error) {
backend, err := srv.getBackend()
if err != nil {
return nil, err
}
// double-checked locking // double-checked locking
srv.mu.RLock() srv.mu.RLock()
r := srv.registry r := srv.registry
@ -75,7 +82,7 @@ func (srv *Server) getRegistry() (registry.Interface, error) {
r = srv.registry r = srv.registry
var err error var err error
if r == nil { if r == nil {
r, err = srv.newRegistryLocked() r, err = srv.newRegistryLocked(backend)
srv.registry = r srv.registry = r
} }
srv.mu.Unlock() srv.mu.Unlock()
@ -86,11 +93,21 @@ func (srv *Server) getRegistry() (registry.Interface, error) {
return r, nil return r, nil
} }
func (srv *Server) newRegistryLocked() (registry.Interface, error) { func (srv *Server) newRegistryLocked(backend storage.Backend) (registry.Interface, error) {
ctx := context.Background() ctx := context.Background()
if hasRegistryServer, ok := backend.(interface {
RegistryServer() registrypb.RegistryServer
}); ok {
log.Info(ctx).Msg("using registry via storage")
return struct {
io.Closer
registrypb.RegistryServer
}{backend, hasRegistryServer.RegistryServer()}, nil
}
switch srv.cfg.storageType { switch srv.cfg.storageType {
case config.StorageInMemoryName, config.StoragePostgresName: case config.StorageInMemoryName:
log.Info(ctx).Msg("using in-memory registry") log.Info(ctx).Msg("using in-memory registry")
return inmemory.New(ctx, srv.cfg.registryTTL), nil return inmemory.New(ctx, srv.cfg.registryTTL), nil
case config.StorageRedisName: case config.StorageRedisName:

31
internal/sets/hash.go Normal file
View file

@ -0,0 +1,31 @@
package sets
// A Hash is a set implemented via a map.
type Hash[T comparable] struct {
m map[T]struct{}
}
// NewHash creates a new Hash set.
func NewHash[T comparable]() *Hash[T] {
return &Hash[T]{
m: make(map[T]struct{}),
}
}
// Add adds a value to the set.
func (s *Hash[T]) Add(elements ...T) {
for _, element := range elements {
s.m[element] = struct{}{}
}
}
// Has returns true if the element is in the set.
func (s *Hash[T]) Has(element T) bool {
_, ok := s.m[element]
return ok
}
// Size returns the size of the set.
func (s *Hash[T]) Size() int {
return len(s.m)
}

View file

@ -2,6 +2,7 @@ package postgres
import ( import (
"context" "context"
"errors"
"fmt" "fmt"
"sync" "sync"
"time" "time"
@ -20,9 +21,10 @@ import (
// Backend is a storage Backend implemented with Postgres. // Backend is a storage Backend implemented with Postgres.
type Backend struct { type Backend struct {
cfg *config cfg *config
dsn string dsn string
onChange *signal.Signal onRecordChange *signal.Signal
onServiceChange *signal.Signal
closeCtx context.Context closeCtx context.Context
close context.CancelFunc close context.CancelFunc
@ -35,11 +37,13 @@ type Backend struct {
// New creates a new Backend. // New creates a new Backend.
func New(dsn string, options ...Option) *Backend { func New(dsn string, options ...Option) *Backend {
backend := &Backend{ backend := &Backend{
cfg: getConfig(options...), cfg: getConfig(options...),
dsn: dsn, dsn: dsn,
onChange: signal.New(), onRecordChange: signal.New(),
onServiceChange: signal.New(),
} }
backend.closeCtx, backend.close = context.WithCancel(context.Background()) backend.closeCtx, backend.close = context.WithCancel(context.Background())
go backend.doPeriodically(func(ctx context.Context) error { go backend.doPeriodically(func(ctx context.Context) error {
_, pool, err := backend.init(ctx) _, pool, err := backend.init(ctx)
if err != nil { if err != nil {
@ -48,32 +52,64 @@ func New(dsn string, options ...Option) *Backend {
return deleteChangesBefore(ctx, pool, time.Now().Add(-backend.cfg.expiry)) return deleteChangesBefore(ctx, pool, time.Now().Add(-backend.cfg.expiry))
}, time.Minute) }, time.Minute)
go backend.doPeriodically(func(ctx context.Context) error { go backend.doPeriodically(func(ctx context.Context) error {
_, pool, err := backend.init(backend.closeCtx) _, pool, err := backend.init(ctx)
if err != nil { if err != nil {
return err return err
} }
conn, err := pool.Acquire(ctx) rowCount, err := deleteExpiredServices(ctx, pool, time.Now())
if err != nil { if err != nil {
return err return err
} }
defer conn.Release() if rowCount > 0 {
err = signalServiceChange(ctx, pool)
_, err = conn.Exec(ctx, `LISTEN `+recordChangeNotifyName) if err != nil {
if err != nil { return err
return err }
} }
_, err = conn.Conn().WaitForNotification(ctx)
if err != nil {
return err
}
backend.onChange.Broadcast(ctx)
return nil return nil
}, time.Millisecond*100) }, backend.cfg.registryTTL/2)
// listen for changes and broadcast them via signals
for _, row := range []struct {
signal *signal.Signal
channel string
}{
{backend.onRecordChange, recordChangeNotifyName},
{backend.onServiceChange, serviceChangeNotifyName},
} {
sig, ch := row.signal, row.channel
go backend.doPeriodically(func(ctx context.Context) error {
_, pool, err := backend.init(backend.closeCtx)
if err != nil {
return err
}
conn, err := pool.Acquire(ctx)
if err != nil {
return err
}
defer conn.Release()
_, err = conn.Exec(ctx, `LISTEN `+ch)
if err != nil {
return err
}
_, err = conn.Conn().WaitForNotification(ctx)
if err != nil {
return err
}
sig.Broadcast(ctx)
return nil
}, time.Millisecond*100)
}
return backend return backend
} }
@ -327,7 +363,9 @@ func (backend *Backend) doPeriodically(f func(ctx context.Context) error, dur ti
case <-ticker.C: case <-ticker.C:
} }
} else { } else {
log.Error(ctx).Err(err).Msg("storage/postgres") if !errors.Is(err, context.Canceled) {
log.Error(ctx).Err(err).Msg("storage/postgres")
}
select { select {
case <-backend.closeCtx.Done(): case <-backend.closeCtx.Done():
return return

View file

@ -112,6 +112,22 @@ var migrations = []func(context.Context, pgx.Tx) error{
return err return err
} }
return nil
},
3: func(ctx context.Context, tx pgx.Tx) error {
_, err := tx.Exec(ctx, `
CREATE TABLE `+schemaName+`.`+servicesTableName+` (
kind TEXT NOT NULL,
endpoint TEXT NOT NULL,
expires_at TIMESTAMPTZ NOT NULL,
PRIMARY KEY (kind, endpoint)
)
`)
if err != nil {
return err
}
return nil return nil
}, },
} }

View file

@ -4,10 +4,14 @@ import (
"time" "time"
) )
const defaultExpiry = time.Hour * 24 const (
defaultExpiry = time.Hour * 24
defaultRegistryTTL = time.Second * 30
)
type config struct { type config struct {
expiry time.Duration expiry time.Duration
registryTTL time.Duration
} }
// Option customizes a Backend. // Option customizes a Backend.
@ -20,9 +24,17 @@ func WithExpiry(expiry time.Duration) Option {
} }
} }
// WithRegistryTTL sets the default registry TTL.
func WithRegistryTTL(ttl time.Duration) Option {
return func(cfg *config) {
cfg.registryTTL = ttl
}
}
func getConfig(options ...Option) *config { func getConfig(options ...Option) *config {
cfg := new(config) cfg := new(config)
WithExpiry(defaultExpiry)(cfg) WithExpiry(defaultExpiry)(cfg)
WithRegistryTTL(defaultRegistryTTL)(cfg)
for _, o := range options { for _, o := range options {
o(cfg) o(cfg)
} }

View file

@ -16,17 +16,20 @@ import (
"google.golang.org/protobuf/types/known/timestamppb" "google.golang.org/protobuf/types/known/timestamppb"
"github.com/pomerium/pomerium/pkg/grpc/databroker" "github.com/pomerium/pomerium/pkg/grpc/databroker"
"github.com/pomerium/pomerium/pkg/grpc/registry"
"github.com/pomerium/pomerium/pkg/storage" "github.com/pomerium/pomerium/pkg/storage"
) )
var ( var (
schemaName = "pomerium" schemaName = "pomerium"
migrationInfoTableName = "migration_info" migrationInfoTableName = "migration_info"
recordsTableName = "records" recordsTableName = "records"
recordChangesTableName = "record_changes" recordChangesTableName = "record_changes"
recordChangeNotifyName = "pomerium_record_change" recordChangeNotifyName = "pomerium_record_change"
recordOptionsTableName = "record_options" recordOptionsTableName = "record_options"
leasesTableName = "leases" leasesTableName = "leases"
serviceChangeNotifyName = "pomerium_service_change"
servicesTableName = "services"
) )
type querier interface { type querier interface {
@ -43,6 +46,17 @@ func deleteChangesBefore(ctx context.Context, q querier, cutoff time.Time) error
return err return err
} }
func deleteExpiredServices(ctx context.Context, q querier, cutoff time.Time) (rowCount int64, err error) {
cmd, err := q.Exec(ctx, `
DELETE FROM `+schemaName+`.`+servicesTableName+`
WHERE expires_at < $1
`, cutoff)
if err != nil {
return 0, err
}
return cmd.RowsAffected(), nil
}
func dup(record *databroker.Record) *databroker.Record { func dup(record *databroker.Record) *databroker.Record {
return proto.Clone(record).(*databroker.Record) return proto.Clone(record).(*databroker.Record)
} }
@ -221,6 +235,40 @@ func listRecords(ctx context.Context, q querier, expr storage.FilterExpression,
return records, rows.Err() return records, rows.Err()
} }
func listServices(ctx context.Context, q querier) ([]*registry.Service, error) {
var services []*registry.Service
query := `
SELECT kind, endpoint
FROM ` + schemaName + `.` + servicesTableName + `
ORDER BY kind, endpoint
`
rows, err := q.Query(ctx, query)
if err != nil {
return nil, err
}
defer rows.Close()
for rows.Next() {
var kind, endpoint string
err = rows.Scan(&kind, &endpoint)
if err != nil {
return nil, err
}
services = append(services, &registry.Service{
Kind: registry.ServiceKind(registry.ServiceKind_value[kind]),
Endpoint: endpoint,
})
}
err = rows.Err()
if err != nil {
return nil, err
}
return services, nil
}
func maybeAcquireLease(ctx context.Context, q querier, leaseName, leaseID string, ttl time.Duration) (leaseHolderID string, err error) { func maybeAcquireLease(ctx context.Context, q querier, leaseName, leaseID string, ttl time.Duration) (leaseHolderID string, err error) {
tbl := schemaName + "." + leasesTableName tbl := schemaName + "." + leasesTableName
expiresAt := timestamptzFromTimestamppb(timestamppb.New(time.Now().Add(ttl))) expiresAt := timestamptzFromTimestamppb(timestamppb.New(time.Now().Add(ttl)))
@ -283,6 +331,17 @@ func putRecordAndChange(ctx context.Context, q querier, record *databroker.Recor
return nil return nil
} }
func putService(ctx context.Context, q querier, svc *registry.Service, expiresAt time.Time) error {
query := `
INSERT INTO ` + schemaName + `.` + servicesTableName + ` (kind, endpoint, expires_at)
VALUES ($1, $2, $3)
ON CONFLICT (kind, endpoint) DO UPDATE
SET expires_at=$3
`
_, err := q.Exec(ctx, query, svc.GetKind().String(), svc.GetEndpoint(), expiresAt)
return err
}
func setOptions(ctx context.Context, q querier, recordType string, options *databroker.Options) error { func setOptions(ctx context.Context, q querier, recordType string, options *databroker.Options) error {
capacity := pgtype.Int8{Status: pgtype.Null} capacity := pgtype.Int8{Status: pgtype.Null}
if options != nil && options.Capacity != nil { if options != nil && options.Capacity != nil {
@ -304,6 +363,11 @@ func signalRecordChange(ctx context.Context, q querier) error {
return err return err
} }
func signalServiceChange(ctx context.Context, q querier) error {
_, err := q.Exec(ctx, `NOTIFY `+serviceChangeNotifyName)
return err
}
func jsonbFromAny(any *anypb.Any) (pgtype.JSONB, error) { func jsonbFromAny(any *anypb.Any) (pgtype.JSONB, error) {
if any == nil { if any == nil {
return pgtype.JSONB{Status: pgtype.Null}, nil return pgtype.JSONB{Status: pgtype.Null}, nil

View file

@ -0,0 +1,109 @@
package postgres
import (
"context"
"time"
"github.com/golang/protobuf/proto"
"google.golang.org/protobuf/types/known/durationpb"
"github.com/pomerium/pomerium/internal/sets"
"github.com/pomerium/pomerium/pkg/grpc/registry"
)
type registryServer struct {
*Backend
}
// RegistryServer returns a registry.RegistryServer for the backend.
func (backend *Backend) RegistryServer() registry.RegistryServer {
return registryServer{backend}
}
// List lists services.
func (backend registryServer) List(
ctx context.Context,
req *registry.ListRequest,
) (*registry.ServiceList, error) {
_, pool, err := backend.init(ctx)
if err != nil {
return nil, err
}
all, err := listServices(ctx, pool)
if err != nil {
return nil, err
}
res := new(registry.ServiceList)
s := sets.NewHash[registry.ServiceKind]()
s.Add(req.GetKinds()...)
for _, svc := range all {
if s.Size() == 0 || s.Has(svc.GetKind()) {
res.Services = append(res.Services, svc)
}
}
return res, nil
}
// Report registers services.
func (backend registryServer) Report(
ctx context.Context,
req *registry.RegisterRequest,
) (*registry.RegisterResponse, error) {
_, pool, err := backend.init(ctx)
if err != nil {
return nil, err
}
for _, svc := range req.GetServices() {
err = putService(ctx, pool, svc, time.Now().Add(backend.cfg.registryTTL))
if err != nil {
return nil, err
}
}
err = signalServiceChange(ctx, pool)
if err != nil {
return nil, err
}
return &registry.RegisterResponse{
CallBackAfter: durationpb.New(backend.cfg.registryTTL / 2),
}, nil
}
// Watch watches services.
func (backend registryServer) Watch(
req *registry.ListRequest,
srv registry.Registry_WatchServer,
) error {
ch := backend.onServiceChange.Bind()
defer backend.onServiceChange.Unbind(ch)
ticker := time.NewTicker(watchPollInterval)
defer ticker.Stop()
var prev *registry.ServiceList
for i := 0; ; i++ {
res, err := backend.List(srv.Context(), req)
if err != nil {
return err
}
if i == 0 || !proto.Equal(res, prev) {
err = srv.Send(res)
if err != nil {
return err
}
prev = res
}
select {
case <-srv.Context().Done():
return srv.Context().Err()
case <-ch:
case <-ticker.C:
}
}
}

View file

@ -0,0 +1,113 @@
package postgres
import (
"context"
"errors"
"fmt"
"os"
"runtime"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/sync/errgroup"
"github.com/pomerium/pomerium/internal/testutil"
"github.com/pomerium/pomerium/pkg/grpc/registry"
)
type mockRegistryWatchServer struct {
registry.Registry_WatchServer
context context.Context
send func(*registry.ServiceList) error
}
func (m mockRegistryWatchServer) Context() context.Context {
return m.context
}
func (m mockRegistryWatchServer) Send(res *registry.ServiceList) error {
return m.send(res)
}
func TestRegistry(t *testing.T) {
if os.Getenv("GITHUB_ACTION") != "" && runtime.GOOS == "darwin" {
t.Skip("Github action can not run docker on MacOS")
}
ctx, clearTimeout := context.WithTimeout(context.Background(), time.Second*10)
defer clearTimeout()
require.NoError(t, testutil.WithTestPostgres(func(dsn string) error {
backend := New(dsn)
defer backend.Close()
eg, ctx := errgroup.WithContext(ctx)
listResults := make(chan *registry.ServiceList)
eg.Go(func() error {
srv := mockRegistryWatchServer{
context: ctx,
send: func(res *registry.ServiceList) error {
select {
case <-ctx.Done():
return ctx.Err()
case listResults <- res:
}
return nil
},
}
err := backend.RegistryServer().Watch(&registry.ListRequest{
Kinds: []registry.ServiceKind{
registry.ServiceKind_AUTHENTICATE,
registry.ServiceKind_CONSOLE,
},
}, srv)
if errors.Is(err, context.Canceled) {
return nil
}
return err
})
eg.Go(func() error {
select {
case <-ctx.Done():
return ctx.Err()
case res := <-listResults:
testutil.AssertProtoEqual(t, &registry.ServiceList{}, res)
}
res, err := backend.RegistryServer().Report(ctx, &registry.RegisterRequest{
Services: []*registry.Service{
{Kind: registry.ServiceKind_AUTHENTICATE, Endpoint: "authenticate.example.com"},
{Kind: registry.ServiceKind_AUTHORIZE, Endpoint: "authorize.example.com"},
{Kind: registry.ServiceKind_CONSOLE, Endpoint: "console.example.com"},
},
})
if err != nil {
return fmt.Errorf("error reporting status: %w", err)
}
assert.NotEqual(t, 0, res.GetCallBackAfter())
select {
case <-ctx.Done():
return ctx.Err()
case res := <-listResults:
testutil.AssertProtoEqual(t, &registry.ServiceList{
Services: []*registry.Service{
{Kind: registry.ServiceKind_AUTHENTICATE, Endpoint: "authenticate.example.com"},
{Kind: registry.ServiceKind_CONSOLE, Endpoint: "console.example.com"},
},
}, res)
}
return context.Canceled
})
err := eg.Wait()
if errors.Is(err, context.Canceled) {
err = nil
}
assert.NoError(t, err)
return nil
}))
}

View file

@ -104,7 +104,7 @@ func newChangedRecordStream(
recordType: recordType, recordType: recordType,
recordVersion: recordVersion, recordVersion: recordVersion,
ticker: time.NewTicker(watchPollInterval), ticker: time.NewTicker(watchPollInterval),
changed: backend.onChange.Bind(), changed: backend.onRecordChange.Bind(),
} }
stream.ctx, stream.cancel = contextutil.Merge(ctx, backend.closeCtx) stream.ctx, stream.cancel = contextutil.Merge(ctx, backend.closeCtx)
return stream return stream
@ -113,7 +113,7 @@ func newChangedRecordStream(
func (stream *changedRecordStream) Close() error { func (stream *changedRecordStream) Close() error {
stream.cancel() stream.cancel()
stream.ticker.Stop() stream.ticker.Stop()
stream.backend.onChange.Unbind(stream.changed) stream.backend.onRecordChange.Unbind(stream.changed)
return nil return nil
} }