postgres: registry support (#3454)

This commit is contained in:
Caleb Doxsey 2022-07-13 09:14:47 -06:00 committed by GitHub
parent ca8db7b619
commit 24a9d627cd
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
9 changed files with 436 additions and 36 deletions

View file

@ -3,6 +3,7 @@ package databroker
import (
"context"
"fmt"
"io"
"github.com/pomerium/pomerium/config"
"github.com/pomerium/pomerium/internal/log"
@ -11,6 +12,7 @@ import (
"github.com/pomerium/pomerium/internal/registry/redis"
"github.com/pomerium/pomerium/internal/telemetry/trace"
registrypb "github.com/pomerium/pomerium/pkg/grpc/registry"
"github.com/pomerium/pomerium/pkg/storage"
)
type registryWatchServer struct {
@ -66,6 +68,11 @@ func (srv *Server) Watch(req *registrypb.ListRequest, stream registrypb.Registry
}
func (srv *Server) getRegistry() (registry.Interface, error) {
backend, err := srv.getBackend()
if err != nil {
return nil, err
}
// double-checked locking
srv.mu.RLock()
r := srv.registry
@ -75,7 +82,7 @@ func (srv *Server) getRegistry() (registry.Interface, error) {
r = srv.registry
var err error
if r == nil {
r, err = srv.newRegistryLocked()
r, err = srv.newRegistryLocked(backend)
srv.registry = r
}
srv.mu.Unlock()
@ -86,11 +93,21 @@ func (srv *Server) getRegistry() (registry.Interface, error) {
return r, nil
}
func (srv *Server) newRegistryLocked() (registry.Interface, error) {
func (srv *Server) newRegistryLocked(backend storage.Backend) (registry.Interface, error) {
ctx := context.Background()
if hasRegistryServer, ok := backend.(interface {
RegistryServer() registrypb.RegistryServer
}); ok {
log.Info(ctx).Msg("using registry via storage")
return struct {
io.Closer
registrypb.RegistryServer
}{backend, hasRegistryServer.RegistryServer()}, nil
}
switch srv.cfg.storageType {
case config.StorageInMemoryName, config.StoragePostgresName:
case config.StorageInMemoryName:
log.Info(ctx).Msg("using in-memory registry")
return inmemory.New(ctx, srv.cfg.registryTTL), nil
case config.StorageRedisName:

31
internal/sets/hash.go Normal file
View file

@ -0,0 +1,31 @@
package sets
// A Hash is a set implemented via a map.
type Hash[T comparable] struct {
m map[T]struct{}
}
// NewHash creates a new Hash set.
func NewHash[T comparable]() *Hash[T] {
return &Hash[T]{
m: make(map[T]struct{}),
}
}
// Add adds a value to the set.
func (s *Hash[T]) Add(elements ...T) {
for _, element := range elements {
s.m[element] = struct{}{}
}
}
// Has returns true if the element is in the set.
func (s *Hash[T]) Has(element T) bool {
_, ok := s.m[element]
return ok
}
// Size returns the size of the set.
func (s *Hash[T]) Size() int {
return len(s.m)
}

View file

@ -2,6 +2,7 @@ package postgres
import (
"context"
"errors"
"fmt"
"sync"
"time"
@ -20,9 +21,10 @@ import (
// Backend is a storage Backend implemented with Postgres.
type Backend struct {
cfg *config
dsn string
onChange *signal.Signal
cfg *config
dsn string
onRecordChange *signal.Signal
onServiceChange *signal.Signal
closeCtx context.Context
close context.CancelFunc
@ -35,11 +37,13 @@ type Backend struct {
// New creates a new Backend.
func New(dsn string, options ...Option) *Backend {
backend := &Backend{
cfg: getConfig(options...),
dsn: dsn,
onChange: signal.New(),
cfg: getConfig(options...),
dsn: dsn,
onRecordChange: signal.New(),
onServiceChange: signal.New(),
}
backend.closeCtx, backend.close = context.WithCancel(context.Background())
go backend.doPeriodically(func(ctx context.Context) error {
_, pool, err := backend.init(ctx)
if err != nil {
@ -48,32 +52,64 @@ func New(dsn string, options ...Option) *Backend {
return deleteChangesBefore(ctx, pool, time.Now().Add(-backend.cfg.expiry))
}, time.Minute)
go backend.doPeriodically(func(ctx context.Context) error {
_, pool, err := backend.init(backend.closeCtx)
_, pool, err := backend.init(ctx)
if err != nil {
return err
}
conn, err := pool.Acquire(ctx)
rowCount, err := deleteExpiredServices(ctx, pool, time.Now())
if err != nil {
return err
}
defer conn.Release()
_, err = conn.Exec(ctx, `LISTEN `+recordChangeNotifyName)
if err != nil {
return err
if rowCount > 0 {
err = signalServiceChange(ctx, pool)
if err != nil {
return err
}
}
_, err = conn.Conn().WaitForNotification(ctx)
if err != nil {
return err
}
backend.onChange.Broadcast(ctx)
return nil
}, time.Millisecond*100)
}, backend.cfg.registryTTL/2)
// listen for changes and broadcast them via signals
for _, row := range []struct {
signal *signal.Signal
channel string
}{
{backend.onRecordChange, recordChangeNotifyName},
{backend.onServiceChange, serviceChangeNotifyName},
} {
sig, ch := row.signal, row.channel
go backend.doPeriodically(func(ctx context.Context) error {
_, pool, err := backend.init(backend.closeCtx)
if err != nil {
return err
}
conn, err := pool.Acquire(ctx)
if err != nil {
return err
}
defer conn.Release()
_, err = conn.Exec(ctx, `LISTEN `+ch)
if err != nil {
return err
}
_, err = conn.Conn().WaitForNotification(ctx)
if err != nil {
return err
}
sig.Broadcast(ctx)
return nil
}, time.Millisecond*100)
}
return backend
}
@ -327,7 +363,9 @@ func (backend *Backend) doPeriodically(f func(ctx context.Context) error, dur ti
case <-ticker.C:
}
} else {
log.Error(ctx).Err(err).Msg("storage/postgres")
if !errors.Is(err, context.Canceled) {
log.Error(ctx).Err(err).Msg("storage/postgres")
}
select {
case <-backend.closeCtx.Done():
return

View file

@ -112,6 +112,22 @@ var migrations = []func(context.Context, pgx.Tx) error{
return err
}
return nil
},
3: func(ctx context.Context, tx pgx.Tx) error {
_, err := tx.Exec(ctx, `
CREATE TABLE `+schemaName+`.`+servicesTableName+` (
kind TEXT NOT NULL,
endpoint TEXT NOT NULL,
expires_at TIMESTAMPTZ NOT NULL,
PRIMARY KEY (kind, endpoint)
)
`)
if err != nil {
return err
}
return nil
},
}

View file

@ -4,10 +4,14 @@ import (
"time"
)
const defaultExpiry = time.Hour * 24
const (
defaultExpiry = time.Hour * 24
defaultRegistryTTL = time.Second * 30
)
type config struct {
expiry time.Duration
expiry time.Duration
registryTTL time.Duration
}
// Option customizes a Backend.
@ -20,9 +24,17 @@ func WithExpiry(expiry time.Duration) Option {
}
}
// WithRegistryTTL sets the default registry TTL.
func WithRegistryTTL(ttl time.Duration) Option {
return func(cfg *config) {
cfg.registryTTL = ttl
}
}
func getConfig(options ...Option) *config {
cfg := new(config)
WithExpiry(defaultExpiry)(cfg)
WithRegistryTTL(defaultRegistryTTL)(cfg)
for _, o := range options {
o(cfg)
}

View file

@ -16,17 +16,20 @@ import (
"google.golang.org/protobuf/types/known/timestamppb"
"github.com/pomerium/pomerium/pkg/grpc/databroker"
"github.com/pomerium/pomerium/pkg/grpc/registry"
"github.com/pomerium/pomerium/pkg/storage"
)
var (
schemaName = "pomerium"
migrationInfoTableName = "migration_info"
recordsTableName = "records"
recordChangesTableName = "record_changes"
recordChangeNotifyName = "pomerium_record_change"
recordOptionsTableName = "record_options"
leasesTableName = "leases"
schemaName = "pomerium"
migrationInfoTableName = "migration_info"
recordsTableName = "records"
recordChangesTableName = "record_changes"
recordChangeNotifyName = "pomerium_record_change"
recordOptionsTableName = "record_options"
leasesTableName = "leases"
serviceChangeNotifyName = "pomerium_service_change"
servicesTableName = "services"
)
type querier interface {
@ -43,6 +46,17 @@ func deleteChangesBefore(ctx context.Context, q querier, cutoff time.Time) error
return err
}
func deleteExpiredServices(ctx context.Context, q querier, cutoff time.Time) (rowCount int64, err error) {
cmd, err := q.Exec(ctx, `
DELETE FROM `+schemaName+`.`+servicesTableName+`
WHERE expires_at < $1
`, cutoff)
if err != nil {
return 0, err
}
return cmd.RowsAffected(), nil
}
func dup(record *databroker.Record) *databroker.Record {
return proto.Clone(record).(*databroker.Record)
}
@ -221,6 +235,40 @@ func listRecords(ctx context.Context, q querier, expr storage.FilterExpression,
return records, rows.Err()
}
func listServices(ctx context.Context, q querier) ([]*registry.Service, error) {
var services []*registry.Service
query := `
SELECT kind, endpoint
FROM ` + schemaName + `.` + servicesTableName + `
ORDER BY kind, endpoint
`
rows, err := q.Query(ctx, query)
if err != nil {
return nil, err
}
defer rows.Close()
for rows.Next() {
var kind, endpoint string
err = rows.Scan(&kind, &endpoint)
if err != nil {
return nil, err
}
services = append(services, &registry.Service{
Kind: registry.ServiceKind(registry.ServiceKind_value[kind]),
Endpoint: endpoint,
})
}
err = rows.Err()
if err != nil {
return nil, err
}
return services, nil
}
func maybeAcquireLease(ctx context.Context, q querier, leaseName, leaseID string, ttl time.Duration) (leaseHolderID string, err error) {
tbl := schemaName + "." + leasesTableName
expiresAt := timestamptzFromTimestamppb(timestamppb.New(time.Now().Add(ttl)))
@ -283,6 +331,17 @@ func putRecordAndChange(ctx context.Context, q querier, record *databroker.Recor
return nil
}
func putService(ctx context.Context, q querier, svc *registry.Service, expiresAt time.Time) error {
query := `
INSERT INTO ` + schemaName + `.` + servicesTableName + ` (kind, endpoint, expires_at)
VALUES ($1, $2, $3)
ON CONFLICT (kind, endpoint) DO UPDATE
SET expires_at=$3
`
_, err := q.Exec(ctx, query, svc.GetKind().String(), svc.GetEndpoint(), expiresAt)
return err
}
func setOptions(ctx context.Context, q querier, recordType string, options *databroker.Options) error {
capacity := pgtype.Int8{Status: pgtype.Null}
if options != nil && options.Capacity != nil {
@ -304,6 +363,11 @@ func signalRecordChange(ctx context.Context, q querier) error {
return err
}
func signalServiceChange(ctx context.Context, q querier) error {
_, err := q.Exec(ctx, `NOTIFY `+serviceChangeNotifyName)
return err
}
func jsonbFromAny(any *anypb.Any) (pgtype.JSONB, error) {
if any == nil {
return pgtype.JSONB{Status: pgtype.Null}, nil

View file

@ -0,0 +1,109 @@
package postgres
import (
"context"
"time"
"github.com/golang/protobuf/proto"
"google.golang.org/protobuf/types/known/durationpb"
"github.com/pomerium/pomerium/internal/sets"
"github.com/pomerium/pomerium/pkg/grpc/registry"
)
type registryServer struct {
*Backend
}
// RegistryServer returns a registry.RegistryServer for the backend.
func (backend *Backend) RegistryServer() registry.RegistryServer {
return registryServer{backend}
}
// List lists services.
func (backend registryServer) List(
ctx context.Context,
req *registry.ListRequest,
) (*registry.ServiceList, error) {
_, pool, err := backend.init(ctx)
if err != nil {
return nil, err
}
all, err := listServices(ctx, pool)
if err != nil {
return nil, err
}
res := new(registry.ServiceList)
s := sets.NewHash[registry.ServiceKind]()
s.Add(req.GetKinds()...)
for _, svc := range all {
if s.Size() == 0 || s.Has(svc.GetKind()) {
res.Services = append(res.Services, svc)
}
}
return res, nil
}
// Report registers services.
func (backend registryServer) Report(
ctx context.Context,
req *registry.RegisterRequest,
) (*registry.RegisterResponse, error) {
_, pool, err := backend.init(ctx)
if err != nil {
return nil, err
}
for _, svc := range req.GetServices() {
err = putService(ctx, pool, svc, time.Now().Add(backend.cfg.registryTTL))
if err != nil {
return nil, err
}
}
err = signalServiceChange(ctx, pool)
if err != nil {
return nil, err
}
return &registry.RegisterResponse{
CallBackAfter: durationpb.New(backend.cfg.registryTTL / 2),
}, nil
}
// Watch watches services.
func (backend registryServer) Watch(
req *registry.ListRequest,
srv registry.Registry_WatchServer,
) error {
ch := backend.onServiceChange.Bind()
defer backend.onServiceChange.Unbind(ch)
ticker := time.NewTicker(watchPollInterval)
defer ticker.Stop()
var prev *registry.ServiceList
for i := 0; ; i++ {
res, err := backend.List(srv.Context(), req)
if err != nil {
return err
}
if i == 0 || !proto.Equal(res, prev) {
err = srv.Send(res)
if err != nil {
return err
}
prev = res
}
select {
case <-srv.Context().Done():
return srv.Context().Err()
case <-ch:
case <-ticker.C:
}
}
}

View file

@ -0,0 +1,113 @@
package postgres
import (
"context"
"errors"
"fmt"
"os"
"runtime"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/sync/errgroup"
"github.com/pomerium/pomerium/internal/testutil"
"github.com/pomerium/pomerium/pkg/grpc/registry"
)
type mockRegistryWatchServer struct {
registry.Registry_WatchServer
context context.Context
send func(*registry.ServiceList) error
}
func (m mockRegistryWatchServer) Context() context.Context {
return m.context
}
func (m mockRegistryWatchServer) Send(res *registry.ServiceList) error {
return m.send(res)
}
func TestRegistry(t *testing.T) {
if os.Getenv("GITHUB_ACTION") != "" && runtime.GOOS == "darwin" {
t.Skip("Github action can not run docker on MacOS")
}
ctx, clearTimeout := context.WithTimeout(context.Background(), time.Second*10)
defer clearTimeout()
require.NoError(t, testutil.WithTestPostgres(func(dsn string) error {
backend := New(dsn)
defer backend.Close()
eg, ctx := errgroup.WithContext(ctx)
listResults := make(chan *registry.ServiceList)
eg.Go(func() error {
srv := mockRegistryWatchServer{
context: ctx,
send: func(res *registry.ServiceList) error {
select {
case <-ctx.Done():
return ctx.Err()
case listResults <- res:
}
return nil
},
}
err := backend.RegistryServer().Watch(&registry.ListRequest{
Kinds: []registry.ServiceKind{
registry.ServiceKind_AUTHENTICATE,
registry.ServiceKind_CONSOLE,
},
}, srv)
if errors.Is(err, context.Canceled) {
return nil
}
return err
})
eg.Go(func() error {
select {
case <-ctx.Done():
return ctx.Err()
case res := <-listResults:
testutil.AssertProtoEqual(t, &registry.ServiceList{}, res)
}
res, err := backend.RegistryServer().Report(ctx, &registry.RegisterRequest{
Services: []*registry.Service{
{Kind: registry.ServiceKind_AUTHENTICATE, Endpoint: "authenticate.example.com"},
{Kind: registry.ServiceKind_AUTHORIZE, Endpoint: "authorize.example.com"},
{Kind: registry.ServiceKind_CONSOLE, Endpoint: "console.example.com"},
},
})
if err != nil {
return fmt.Errorf("error reporting status: %w", err)
}
assert.NotEqual(t, 0, res.GetCallBackAfter())
select {
case <-ctx.Done():
return ctx.Err()
case res := <-listResults:
testutil.AssertProtoEqual(t, &registry.ServiceList{
Services: []*registry.Service{
{Kind: registry.ServiceKind_AUTHENTICATE, Endpoint: "authenticate.example.com"},
{Kind: registry.ServiceKind_CONSOLE, Endpoint: "console.example.com"},
},
}, res)
}
return context.Canceled
})
err := eg.Wait()
if errors.Is(err, context.Canceled) {
err = nil
}
assert.NoError(t, err)
return nil
}))
}

View file

@ -104,7 +104,7 @@ func newChangedRecordStream(
recordType: recordType,
recordVersion: recordVersion,
ticker: time.NewTicker(watchPollInterval),
changed: backend.onChange.Bind(),
changed: backend.onRecordChange.Bind(),
}
stream.ctx, stream.cancel = contextutil.Merge(ctx, backend.closeCtx)
return stream
@ -113,7 +113,7 @@ func newChangedRecordStream(
func (stream *changedRecordStream) Close() error {
stream.cancel()
stream.ticker.Stop()
stream.backend.onChange.Unbind(stream.changed)
stream.backend.onRecordChange.Unbind(stream.changed)
return nil
}