mirror of
https://github.com/pomerium/pomerium.git
synced 2025-05-01 11:26:29 +02:00
postgres: registry support (#3454)
This commit is contained in:
parent
ca8db7b619
commit
24a9d627cd
9 changed files with 436 additions and 36 deletions
|
@ -3,6 +3,7 @@ package databroker
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"io"
|
||||||
|
|
||||||
"github.com/pomerium/pomerium/config"
|
"github.com/pomerium/pomerium/config"
|
||||||
"github.com/pomerium/pomerium/internal/log"
|
"github.com/pomerium/pomerium/internal/log"
|
||||||
|
@ -11,6 +12,7 @@ import (
|
||||||
"github.com/pomerium/pomerium/internal/registry/redis"
|
"github.com/pomerium/pomerium/internal/registry/redis"
|
||||||
"github.com/pomerium/pomerium/internal/telemetry/trace"
|
"github.com/pomerium/pomerium/internal/telemetry/trace"
|
||||||
registrypb "github.com/pomerium/pomerium/pkg/grpc/registry"
|
registrypb "github.com/pomerium/pomerium/pkg/grpc/registry"
|
||||||
|
"github.com/pomerium/pomerium/pkg/storage"
|
||||||
)
|
)
|
||||||
|
|
||||||
type registryWatchServer struct {
|
type registryWatchServer struct {
|
||||||
|
@ -66,6 +68,11 @@ func (srv *Server) Watch(req *registrypb.ListRequest, stream registrypb.Registry
|
||||||
}
|
}
|
||||||
|
|
||||||
func (srv *Server) getRegistry() (registry.Interface, error) {
|
func (srv *Server) getRegistry() (registry.Interface, error) {
|
||||||
|
backend, err := srv.getBackend()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
// double-checked locking
|
// double-checked locking
|
||||||
srv.mu.RLock()
|
srv.mu.RLock()
|
||||||
r := srv.registry
|
r := srv.registry
|
||||||
|
@ -75,7 +82,7 @@ func (srv *Server) getRegistry() (registry.Interface, error) {
|
||||||
r = srv.registry
|
r = srv.registry
|
||||||
var err error
|
var err error
|
||||||
if r == nil {
|
if r == nil {
|
||||||
r, err = srv.newRegistryLocked()
|
r, err = srv.newRegistryLocked(backend)
|
||||||
srv.registry = r
|
srv.registry = r
|
||||||
}
|
}
|
||||||
srv.mu.Unlock()
|
srv.mu.Unlock()
|
||||||
|
@ -86,11 +93,21 @@ func (srv *Server) getRegistry() (registry.Interface, error) {
|
||||||
return r, nil
|
return r, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (srv *Server) newRegistryLocked() (registry.Interface, error) {
|
func (srv *Server) newRegistryLocked(backend storage.Backend) (registry.Interface, error) {
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
|
|
||||||
|
if hasRegistryServer, ok := backend.(interface {
|
||||||
|
RegistryServer() registrypb.RegistryServer
|
||||||
|
}); ok {
|
||||||
|
log.Info(ctx).Msg("using registry via storage")
|
||||||
|
return struct {
|
||||||
|
io.Closer
|
||||||
|
registrypb.RegistryServer
|
||||||
|
}{backend, hasRegistryServer.RegistryServer()}, nil
|
||||||
|
}
|
||||||
|
|
||||||
switch srv.cfg.storageType {
|
switch srv.cfg.storageType {
|
||||||
case config.StorageInMemoryName, config.StoragePostgresName:
|
case config.StorageInMemoryName:
|
||||||
log.Info(ctx).Msg("using in-memory registry")
|
log.Info(ctx).Msg("using in-memory registry")
|
||||||
return inmemory.New(ctx, srv.cfg.registryTTL), nil
|
return inmemory.New(ctx, srv.cfg.registryTTL), nil
|
||||||
case config.StorageRedisName:
|
case config.StorageRedisName:
|
||||||
|
|
31
internal/sets/hash.go
Normal file
31
internal/sets/hash.go
Normal file
|
@ -0,0 +1,31 @@
|
||||||
|
package sets
|
||||||
|
|
||||||
|
// A Hash is a set implemented via a map.
|
||||||
|
type Hash[T comparable] struct {
|
||||||
|
m map[T]struct{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewHash creates a new Hash set.
|
||||||
|
func NewHash[T comparable]() *Hash[T] {
|
||||||
|
return &Hash[T]{
|
||||||
|
m: make(map[T]struct{}),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add adds a value to the set.
|
||||||
|
func (s *Hash[T]) Add(elements ...T) {
|
||||||
|
for _, element := range elements {
|
||||||
|
s.m[element] = struct{}{}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Has returns true if the element is in the set.
|
||||||
|
func (s *Hash[T]) Has(element T) bool {
|
||||||
|
_, ok := s.m[element]
|
||||||
|
return ok
|
||||||
|
}
|
||||||
|
|
||||||
|
// Size returns the size of the set.
|
||||||
|
func (s *Hash[T]) Size() int {
|
||||||
|
return len(s.m)
|
||||||
|
}
|
|
@ -2,6 +2,7 @@ package postgres
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
@ -20,9 +21,10 @@ import (
|
||||||
|
|
||||||
// Backend is a storage Backend implemented with Postgres.
|
// Backend is a storage Backend implemented with Postgres.
|
||||||
type Backend struct {
|
type Backend struct {
|
||||||
cfg *config
|
cfg *config
|
||||||
dsn string
|
dsn string
|
||||||
onChange *signal.Signal
|
onRecordChange *signal.Signal
|
||||||
|
onServiceChange *signal.Signal
|
||||||
|
|
||||||
closeCtx context.Context
|
closeCtx context.Context
|
||||||
close context.CancelFunc
|
close context.CancelFunc
|
||||||
|
@ -35,11 +37,13 @@ type Backend struct {
|
||||||
// New creates a new Backend.
|
// New creates a new Backend.
|
||||||
func New(dsn string, options ...Option) *Backend {
|
func New(dsn string, options ...Option) *Backend {
|
||||||
backend := &Backend{
|
backend := &Backend{
|
||||||
cfg: getConfig(options...),
|
cfg: getConfig(options...),
|
||||||
dsn: dsn,
|
dsn: dsn,
|
||||||
onChange: signal.New(),
|
onRecordChange: signal.New(),
|
||||||
|
onServiceChange: signal.New(),
|
||||||
}
|
}
|
||||||
backend.closeCtx, backend.close = context.WithCancel(context.Background())
|
backend.closeCtx, backend.close = context.WithCancel(context.Background())
|
||||||
|
|
||||||
go backend.doPeriodically(func(ctx context.Context) error {
|
go backend.doPeriodically(func(ctx context.Context) error {
|
||||||
_, pool, err := backend.init(ctx)
|
_, pool, err := backend.init(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -48,32 +52,64 @@ func New(dsn string, options ...Option) *Backend {
|
||||||
|
|
||||||
return deleteChangesBefore(ctx, pool, time.Now().Add(-backend.cfg.expiry))
|
return deleteChangesBefore(ctx, pool, time.Now().Add(-backend.cfg.expiry))
|
||||||
}, time.Minute)
|
}, time.Minute)
|
||||||
|
|
||||||
go backend.doPeriodically(func(ctx context.Context) error {
|
go backend.doPeriodically(func(ctx context.Context) error {
|
||||||
_, pool, err := backend.init(backend.closeCtx)
|
_, pool, err := backend.init(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
conn, err := pool.Acquire(ctx)
|
rowCount, err := deleteExpiredServices(ctx, pool, time.Now())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
defer conn.Release()
|
if rowCount > 0 {
|
||||||
|
err = signalServiceChange(ctx, pool)
|
||||||
_, err = conn.Exec(ctx, `LISTEN `+recordChangeNotifyName)
|
if err != nil {
|
||||||
if err != nil {
|
return err
|
||||||
return err
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err = conn.Conn().WaitForNotification(ctx)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
backend.onChange.Broadcast(ctx)
|
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}, time.Millisecond*100)
|
}, backend.cfg.registryTTL/2)
|
||||||
|
|
||||||
|
// listen for changes and broadcast them via signals
|
||||||
|
for _, row := range []struct {
|
||||||
|
signal *signal.Signal
|
||||||
|
channel string
|
||||||
|
}{
|
||||||
|
{backend.onRecordChange, recordChangeNotifyName},
|
||||||
|
{backend.onServiceChange, serviceChangeNotifyName},
|
||||||
|
} {
|
||||||
|
sig, ch := row.signal, row.channel
|
||||||
|
go backend.doPeriodically(func(ctx context.Context) error {
|
||||||
|
_, pool, err := backend.init(backend.closeCtx)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
conn, err := pool.Acquire(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer conn.Release()
|
||||||
|
|
||||||
|
_, err = conn.Exec(ctx, `LISTEN `+ch)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = conn.Conn().WaitForNotification(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
sig.Broadcast(ctx)
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}, time.Millisecond*100)
|
||||||
|
}
|
||||||
|
|
||||||
return backend
|
return backend
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -327,7 +363,9 @@ func (backend *Backend) doPeriodically(f func(ctx context.Context) error, dur ti
|
||||||
case <-ticker.C:
|
case <-ticker.C:
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
log.Error(ctx).Err(err).Msg("storage/postgres")
|
if !errors.Is(err, context.Canceled) {
|
||||||
|
log.Error(ctx).Err(err).Msg("storage/postgres")
|
||||||
|
}
|
||||||
select {
|
select {
|
||||||
case <-backend.closeCtx.Done():
|
case <-backend.closeCtx.Done():
|
||||||
return
|
return
|
||||||
|
|
|
@ -112,6 +112,22 @@ var migrations = []func(context.Context, pgx.Tx) error{
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
3: func(ctx context.Context, tx pgx.Tx) error {
|
||||||
|
_, err := tx.Exec(ctx, `
|
||||||
|
CREATE TABLE `+schemaName+`.`+servicesTableName+` (
|
||||||
|
kind TEXT NOT NULL,
|
||||||
|
endpoint TEXT NOT NULL,
|
||||||
|
expires_at TIMESTAMPTZ NOT NULL,
|
||||||
|
|
||||||
|
PRIMARY KEY (kind, endpoint)
|
||||||
|
)
|
||||||
|
`)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
|
@ -4,10 +4,14 @@ import (
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
const defaultExpiry = time.Hour * 24
|
const (
|
||||||
|
defaultExpiry = time.Hour * 24
|
||||||
|
defaultRegistryTTL = time.Second * 30
|
||||||
|
)
|
||||||
|
|
||||||
type config struct {
|
type config struct {
|
||||||
expiry time.Duration
|
expiry time.Duration
|
||||||
|
registryTTL time.Duration
|
||||||
}
|
}
|
||||||
|
|
||||||
// Option customizes a Backend.
|
// Option customizes a Backend.
|
||||||
|
@ -20,9 +24,17 @@ func WithExpiry(expiry time.Duration) Option {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// WithRegistryTTL sets the default registry TTL.
|
||||||
|
func WithRegistryTTL(ttl time.Duration) Option {
|
||||||
|
return func(cfg *config) {
|
||||||
|
cfg.registryTTL = ttl
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func getConfig(options ...Option) *config {
|
func getConfig(options ...Option) *config {
|
||||||
cfg := new(config)
|
cfg := new(config)
|
||||||
WithExpiry(defaultExpiry)(cfg)
|
WithExpiry(defaultExpiry)(cfg)
|
||||||
|
WithRegistryTTL(defaultRegistryTTL)(cfg)
|
||||||
for _, o := range options {
|
for _, o := range options {
|
||||||
o(cfg)
|
o(cfg)
|
||||||
}
|
}
|
||||||
|
|
|
@ -16,17 +16,20 @@ import (
|
||||||
"google.golang.org/protobuf/types/known/timestamppb"
|
"google.golang.org/protobuf/types/known/timestamppb"
|
||||||
|
|
||||||
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
||||||
|
"github.com/pomerium/pomerium/pkg/grpc/registry"
|
||||||
"github.com/pomerium/pomerium/pkg/storage"
|
"github.com/pomerium/pomerium/pkg/storage"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
schemaName = "pomerium"
|
schemaName = "pomerium"
|
||||||
migrationInfoTableName = "migration_info"
|
migrationInfoTableName = "migration_info"
|
||||||
recordsTableName = "records"
|
recordsTableName = "records"
|
||||||
recordChangesTableName = "record_changes"
|
recordChangesTableName = "record_changes"
|
||||||
recordChangeNotifyName = "pomerium_record_change"
|
recordChangeNotifyName = "pomerium_record_change"
|
||||||
recordOptionsTableName = "record_options"
|
recordOptionsTableName = "record_options"
|
||||||
leasesTableName = "leases"
|
leasesTableName = "leases"
|
||||||
|
serviceChangeNotifyName = "pomerium_service_change"
|
||||||
|
servicesTableName = "services"
|
||||||
)
|
)
|
||||||
|
|
||||||
type querier interface {
|
type querier interface {
|
||||||
|
@ -43,6 +46,17 @@ func deleteChangesBefore(ctx context.Context, q querier, cutoff time.Time) error
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func deleteExpiredServices(ctx context.Context, q querier, cutoff time.Time) (rowCount int64, err error) {
|
||||||
|
cmd, err := q.Exec(ctx, `
|
||||||
|
DELETE FROM `+schemaName+`.`+servicesTableName+`
|
||||||
|
WHERE expires_at < $1
|
||||||
|
`, cutoff)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
return cmd.RowsAffected(), nil
|
||||||
|
}
|
||||||
|
|
||||||
func dup(record *databroker.Record) *databroker.Record {
|
func dup(record *databroker.Record) *databroker.Record {
|
||||||
return proto.Clone(record).(*databroker.Record)
|
return proto.Clone(record).(*databroker.Record)
|
||||||
}
|
}
|
||||||
|
@ -221,6 +235,40 @@ func listRecords(ctx context.Context, q querier, expr storage.FilterExpression,
|
||||||
return records, rows.Err()
|
return records, rows.Err()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func listServices(ctx context.Context, q querier) ([]*registry.Service, error) {
|
||||||
|
var services []*registry.Service
|
||||||
|
|
||||||
|
query := `
|
||||||
|
SELECT kind, endpoint
|
||||||
|
FROM ` + schemaName + `.` + servicesTableName + `
|
||||||
|
ORDER BY kind, endpoint
|
||||||
|
`
|
||||||
|
rows, err := q.Query(ctx, query)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
|
||||||
|
for rows.Next() {
|
||||||
|
var kind, endpoint string
|
||||||
|
err = rows.Scan(&kind, &endpoint)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
services = append(services, ®istry.Service{
|
||||||
|
Kind: registry.ServiceKind(registry.ServiceKind_value[kind]),
|
||||||
|
Endpoint: endpoint,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
err = rows.Err()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return services, nil
|
||||||
|
}
|
||||||
|
|
||||||
func maybeAcquireLease(ctx context.Context, q querier, leaseName, leaseID string, ttl time.Duration) (leaseHolderID string, err error) {
|
func maybeAcquireLease(ctx context.Context, q querier, leaseName, leaseID string, ttl time.Duration) (leaseHolderID string, err error) {
|
||||||
tbl := schemaName + "." + leasesTableName
|
tbl := schemaName + "." + leasesTableName
|
||||||
expiresAt := timestamptzFromTimestamppb(timestamppb.New(time.Now().Add(ttl)))
|
expiresAt := timestamptzFromTimestamppb(timestamppb.New(time.Now().Add(ttl)))
|
||||||
|
@ -283,6 +331,17 @@ func putRecordAndChange(ctx context.Context, q querier, record *databroker.Recor
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func putService(ctx context.Context, q querier, svc *registry.Service, expiresAt time.Time) error {
|
||||||
|
query := `
|
||||||
|
INSERT INTO ` + schemaName + `.` + servicesTableName + ` (kind, endpoint, expires_at)
|
||||||
|
VALUES ($1, $2, $3)
|
||||||
|
ON CONFLICT (kind, endpoint) DO UPDATE
|
||||||
|
SET expires_at=$3
|
||||||
|
`
|
||||||
|
_, err := q.Exec(ctx, query, svc.GetKind().String(), svc.GetEndpoint(), expiresAt)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
func setOptions(ctx context.Context, q querier, recordType string, options *databroker.Options) error {
|
func setOptions(ctx context.Context, q querier, recordType string, options *databroker.Options) error {
|
||||||
capacity := pgtype.Int8{Status: pgtype.Null}
|
capacity := pgtype.Int8{Status: pgtype.Null}
|
||||||
if options != nil && options.Capacity != nil {
|
if options != nil && options.Capacity != nil {
|
||||||
|
@ -304,6 +363,11 @@ func signalRecordChange(ctx context.Context, q querier) error {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func signalServiceChange(ctx context.Context, q querier) error {
|
||||||
|
_, err := q.Exec(ctx, `NOTIFY `+serviceChangeNotifyName)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
func jsonbFromAny(any *anypb.Any) (pgtype.JSONB, error) {
|
func jsonbFromAny(any *anypb.Any) (pgtype.JSONB, error) {
|
||||||
if any == nil {
|
if any == nil {
|
||||||
return pgtype.JSONB{Status: pgtype.Null}, nil
|
return pgtype.JSONB{Status: pgtype.Null}, nil
|
||||||
|
|
109
pkg/storage/postgres/registry.go
Normal file
109
pkg/storage/postgres/registry.go
Normal file
|
@ -0,0 +1,109 @@
|
||||||
|
package postgres
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/golang/protobuf/proto"
|
||||||
|
"google.golang.org/protobuf/types/known/durationpb"
|
||||||
|
|
||||||
|
"github.com/pomerium/pomerium/internal/sets"
|
||||||
|
"github.com/pomerium/pomerium/pkg/grpc/registry"
|
||||||
|
)
|
||||||
|
|
||||||
|
type registryServer struct {
|
||||||
|
*Backend
|
||||||
|
}
|
||||||
|
|
||||||
|
// RegistryServer returns a registry.RegistryServer for the backend.
|
||||||
|
func (backend *Backend) RegistryServer() registry.RegistryServer {
|
||||||
|
return registryServer{backend}
|
||||||
|
}
|
||||||
|
|
||||||
|
// List lists services.
|
||||||
|
func (backend registryServer) List(
|
||||||
|
ctx context.Context,
|
||||||
|
req *registry.ListRequest,
|
||||||
|
) (*registry.ServiceList, error) {
|
||||||
|
_, pool, err := backend.init(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
all, err := listServices(ctx, pool)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
res := new(registry.ServiceList)
|
||||||
|
s := sets.NewHash[registry.ServiceKind]()
|
||||||
|
s.Add(req.GetKinds()...)
|
||||||
|
for _, svc := range all {
|
||||||
|
if s.Size() == 0 || s.Has(svc.GetKind()) {
|
||||||
|
res.Services = append(res.Services, svc)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return res, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Report registers services.
|
||||||
|
func (backend registryServer) Report(
|
||||||
|
ctx context.Context,
|
||||||
|
req *registry.RegisterRequest,
|
||||||
|
) (*registry.RegisterResponse, error) {
|
||||||
|
_, pool, err := backend.init(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, svc := range req.GetServices() {
|
||||||
|
err = putService(ctx, pool, svc, time.Now().Add(backend.cfg.registryTTL))
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
err = signalServiceChange(ctx, pool)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return ®istry.RegisterResponse{
|
||||||
|
CallBackAfter: durationpb.New(backend.cfg.registryTTL / 2),
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Watch watches services.
|
||||||
|
func (backend registryServer) Watch(
|
||||||
|
req *registry.ListRequest,
|
||||||
|
srv registry.Registry_WatchServer,
|
||||||
|
) error {
|
||||||
|
ch := backend.onServiceChange.Bind()
|
||||||
|
defer backend.onServiceChange.Unbind(ch)
|
||||||
|
|
||||||
|
ticker := time.NewTicker(watchPollInterval)
|
||||||
|
defer ticker.Stop()
|
||||||
|
|
||||||
|
var prev *registry.ServiceList
|
||||||
|
for i := 0; ; i++ {
|
||||||
|
res, err := backend.List(srv.Context(), req)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if i == 0 || !proto.Equal(res, prev) {
|
||||||
|
err = srv.Send(res)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
prev = res
|
||||||
|
}
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-srv.Context().Done():
|
||||||
|
return srv.Context().Err()
|
||||||
|
case <-ch:
|
||||||
|
case <-ticker.C:
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
113
pkg/storage/postgres/registry_test.go
Normal file
113
pkg/storage/postgres/registry_test.go
Normal file
|
@ -0,0 +1,113 @@
|
||||||
|
package postgres
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"runtime"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
"golang.org/x/sync/errgroup"
|
||||||
|
|
||||||
|
"github.com/pomerium/pomerium/internal/testutil"
|
||||||
|
"github.com/pomerium/pomerium/pkg/grpc/registry"
|
||||||
|
)
|
||||||
|
|
||||||
|
type mockRegistryWatchServer struct {
|
||||||
|
registry.Registry_WatchServer
|
||||||
|
context context.Context
|
||||||
|
send func(*registry.ServiceList) error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m mockRegistryWatchServer) Context() context.Context {
|
||||||
|
return m.context
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m mockRegistryWatchServer) Send(res *registry.ServiceList) error {
|
||||||
|
return m.send(res)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRegistry(t *testing.T) {
|
||||||
|
if os.Getenv("GITHUB_ACTION") != "" && runtime.GOOS == "darwin" {
|
||||||
|
t.Skip("Github action can not run docker on MacOS")
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx, clearTimeout := context.WithTimeout(context.Background(), time.Second*10)
|
||||||
|
defer clearTimeout()
|
||||||
|
|
||||||
|
require.NoError(t, testutil.WithTestPostgres(func(dsn string) error {
|
||||||
|
backend := New(dsn)
|
||||||
|
defer backend.Close()
|
||||||
|
|
||||||
|
eg, ctx := errgroup.WithContext(ctx)
|
||||||
|
listResults := make(chan *registry.ServiceList)
|
||||||
|
eg.Go(func() error {
|
||||||
|
srv := mockRegistryWatchServer{
|
||||||
|
context: ctx,
|
||||||
|
send: func(res *registry.ServiceList) error {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return ctx.Err()
|
||||||
|
case listResults <- res:
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
err := backend.RegistryServer().Watch(®istry.ListRequest{
|
||||||
|
Kinds: []registry.ServiceKind{
|
||||||
|
registry.ServiceKind_AUTHENTICATE,
|
||||||
|
registry.ServiceKind_CONSOLE,
|
||||||
|
},
|
||||||
|
}, srv)
|
||||||
|
if errors.Is(err, context.Canceled) {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
})
|
||||||
|
eg.Go(func() error {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return ctx.Err()
|
||||||
|
case res := <-listResults:
|
||||||
|
testutil.AssertProtoEqual(t, ®istry.ServiceList{}, res)
|
||||||
|
}
|
||||||
|
|
||||||
|
res, err := backend.RegistryServer().Report(ctx, ®istry.RegisterRequest{
|
||||||
|
Services: []*registry.Service{
|
||||||
|
{Kind: registry.ServiceKind_AUTHENTICATE, Endpoint: "authenticate.example.com"},
|
||||||
|
{Kind: registry.ServiceKind_AUTHORIZE, Endpoint: "authorize.example.com"},
|
||||||
|
{Kind: registry.ServiceKind_CONSOLE, Endpoint: "console.example.com"},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("error reporting status: %w", err)
|
||||||
|
}
|
||||||
|
assert.NotEqual(t, 0, res.GetCallBackAfter())
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return ctx.Err()
|
||||||
|
case res := <-listResults:
|
||||||
|
testutil.AssertProtoEqual(t, ®istry.ServiceList{
|
||||||
|
Services: []*registry.Service{
|
||||||
|
{Kind: registry.ServiceKind_AUTHENTICATE, Endpoint: "authenticate.example.com"},
|
||||||
|
{Kind: registry.ServiceKind_CONSOLE, Endpoint: "console.example.com"},
|
||||||
|
},
|
||||||
|
}, res)
|
||||||
|
}
|
||||||
|
|
||||||
|
return context.Canceled
|
||||||
|
})
|
||||||
|
err := eg.Wait()
|
||||||
|
if errors.Is(err, context.Canceled) {
|
||||||
|
err = nil
|
||||||
|
}
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}))
|
||||||
|
}
|
|
@ -104,7 +104,7 @@ func newChangedRecordStream(
|
||||||
recordType: recordType,
|
recordType: recordType,
|
||||||
recordVersion: recordVersion,
|
recordVersion: recordVersion,
|
||||||
ticker: time.NewTicker(watchPollInterval),
|
ticker: time.NewTicker(watchPollInterval),
|
||||||
changed: backend.onChange.Bind(),
|
changed: backend.onRecordChange.Bind(),
|
||||||
}
|
}
|
||||||
stream.ctx, stream.cancel = contextutil.Merge(ctx, backend.closeCtx)
|
stream.ctx, stream.cancel = contextutil.Merge(ctx, backend.closeCtx)
|
||||||
return stream
|
return stream
|
||||||
|
@ -113,7 +113,7 @@ func newChangedRecordStream(
|
||||||
func (stream *changedRecordStream) Close() error {
|
func (stream *changedRecordStream) Close() error {
|
||||||
stream.cancel()
|
stream.cancel()
|
||||||
stream.ticker.Stop()
|
stream.ticker.Stop()
|
||||||
stream.backend.onChange.Unbind(stream.changed)
|
stream.backend.onRecordChange.Unbind(stream.changed)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Add table
Reference in a new issue