mirror of
https://github.com/pomerium/pomerium.git
synced 2025-04-29 10:26:29 +02:00
This also replaces instances where we manually write "return ctx.Err()" with "return context.Cause(ctx)", which behaves identically when no cause is set but also correctly propagates the cause error when one is present.
481 lines
11 KiB
Go
481 lines
11 KiB
Go
package postgres
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"fmt"
|
|
"net"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/cenkalti/backoff/v4"
|
|
"github.com/jackc/pgx/v5/pgxpool"
|
|
"google.golang.org/protobuf/types/known/fieldmaskpb"
|
|
"google.golang.org/protobuf/types/known/timestamppb"
|
|
|
|
"github.com/pomerium/pomerium/internal/log"
|
|
"github.com/pomerium/pomerium/internal/signal"
|
|
"github.com/pomerium/pomerium/pkg/contextutil"
|
|
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
|
"github.com/pomerium/pomerium/pkg/health"
|
|
"github.com/pomerium/pomerium/pkg/storage"
|
|
)
|
|
|
|
// Backend is a storage Backend implemented with Postgres.
|
|
type Backend struct {
|
|
cfg *config
|
|
dsn string
|
|
onRecordChange *signal.Signal
|
|
onServiceChange *signal.Signal
|
|
|
|
closeCtx context.Context
|
|
close context.CancelFunc
|
|
|
|
mu sync.RWMutex
|
|
pool *pgxpool.Pool
|
|
serverVersion uint64
|
|
}
|
|
|
|
// New creates a new Backend.
|
|
func New(ctx context.Context, dsn string, options ...Option) *Backend {
|
|
backend := &Backend{
|
|
cfg: getConfig(options...),
|
|
dsn: dsn,
|
|
onRecordChange: signal.New(),
|
|
onServiceChange: signal.New(),
|
|
}
|
|
backend.closeCtx, backend.close = context.WithCancel(ctx)
|
|
|
|
go backend.doPeriodically(func(ctx context.Context) error {
|
|
_, pool, err := backend.init(ctx)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
return deleteChangesBefore(ctx, pool, time.Now().Add(-backend.cfg.expiry))
|
|
}, time.Minute)
|
|
|
|
go backend.doPeriodically(func(ctx context.Context) error {
|
|
_, pool, err := backend.init(ctx)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
rowCount, err := deleteExpiredServices(ctx, pool, time.Now())
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if rowCount > 0 {
|
|
err = signalServiceChange(ctx, pool)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}, backend.cfg.registryTTL/2)
|
|
|
|
go backend.doPeriodically(func(ctx context.Context) error {
|
|
return backend.listenForNotifications(ctx)
|
|
}, time.Millisecond*100)
|
|
|
|
go backend.doPeriodically(func(ctx context.Context) error {
|
|
err := backend.ping(ctx)
|
|
if err != nil {
|
|
health.ReportError(health.StorageBackend, err, health.StrAttr("backend", "postgres"))
|
|
} else {
|
|
health.ReportOK(health.StorageBackend, health.StrAttr("backend", "postgres"))
|
|
}
|
|
return nil
|
|
}, time.Minute)
|
|
|
|
return backend
|
|
}
|
|
|
|
// Close closes the underlying database connection.
|
|
func (backend *Backend) Close() error {
|
|
backend.mu.Lock()
|
|
defer backend.mu.Unlock()
|
|
|
|
backend.close()
|
|
|
|
if backend.pool != nil {
|
|
backend.pool.Close()
|
|
backend.pool = nil
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// Get gets a record from the database.
|
|
func (backend *Backend) Get(
|
|
ctx context.Context,
|
|
recordType, recordID string,
|
|
) (*databroker.Record, error) {
|
|
ctx, cancel := contextutil.Merge(ctx, backend.closeCtx)
|
|
defer cancel()
|
|
|
|
_, conn, err := backend.init(ctx)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return getRecord(ctx, conn, recordType, recordID, lockModeNone)
|
|
}
|
|
|
|
// GetOptions returns the options for the given record type.
|
|
func (backend *Backend) GetOptions(
|
|
ctx context.Context,
|
|
recordType string,
|
|
) (*databroker.Options, error) {
|
|
ctx, cancel := contextutil.Merge(ctx, backend.closeCtx)
|
|
defer cancel()
|
|
|
|
_, conn, err := backend.init(ctx)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return getOptions(ctx, conn, recordType)
|
|
}
|
|
|
|
// Lease attempts to acquire a lease for the given name.
|
|
func (backend *Backend) Lease(
|
|
ctx context.Context,
|
|
leaseName, leaseID string,
|
|
ttl time.Duration,
|
|
) (acquired bool, err error) {
|
|
ctx, cancel := contextutil.Merge(ctx, backend.closeCtx)
|
|
defer cancel()
|
|
|
|
_, conn, err := backend.init(ctx)
|
|
if err != nil {
|
|
return false, err
|
|
}
|
|
|
|
leaseHolderID, err := maybeAcquireLease(ctx, conn, leaseName, leaseID, ttl)
|
|
if err != nil {
|
|
return false, err
|
|
}
|
|
|
|
return leaseHolderID == leaseID, nil
|
|
}
|
|
|
|
// ListTypes lists the record types.
|
|
func (backend *Backend) ListTypes(ctx context.Context) ([]string, error) {
|
|
ctx, cancel := contextutil.Merge(ctx, backend.closeCtx)
|
|
defer cancel()
|
|
|
|
_, conn, err := backend.init(ctx)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return listTypes(ctx, conn)
|
|
}
|
|
|
|
// Put puts a record into Postgres.
|
|
func (backend *Backend) Put(
|
|
ctx context.Context,
|
|
records []*databroker.Record,
|
|
) (serverVersion uint64, err error) {
|
|
ctx, cancel := contextutil.Merge(ctx, backend.closeCtx)
|
|
defer cancel()
|
|
|
|
serverVersion, pool, err := backend.init(ctx)
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
|
|
now := timestamppb.Now()
|
|
|
|
// add all the records
|
|
recordTypes := map[string]struct{}{}
|
|
for i, record := range records {
|
|
recordTypes[record.GetType()] = struct{}{}
|
|
|
|
record = dup(record)
|
|
record.ModifiedAt = now
|
|
err := putRecordAndChange(ctx, pool, record)
|
|
if err != nil {
|
|
return serverVersion, fmt.Errorf("storage/postgres: error saving record: %w", err)
|
|
}
|
|
records[i] = record
|
|
}
|
|
|
|
// enforce options for each record type
|
|
for recordType := range recordTypes {
|
|
options, err := getOptions(ctx, pool, recordType)
|
|
if err != nil {
|
|
return serverVersion, fmt.Errorf("storage/postgres: error getting options: %w", err)
|
|
}
|
|
err = enforceOptions(ctx, pool, recordType, options)
|
|
if err != nil {
|
|
return serverVersion, fmt.Errorf("storage/postgres: error enforcing options: %w", err)
|
|
}
|
|
}
|
|
|
|
err = signalRecordChange(ctx, pool)
|
|
return serverVersion, err
|
|
}
|
|
|
|
// Patch updates specific fields of existing records in Postgres.
|
|
func (backend *Backend) Patch(
|
|
ctx context.Context,
|
|
records []*databroker.Record,
|
|
fields *fieldmaskpb.FieldMask,
|
|
) (uint64, []*databroker.Record, error) {
|
|
ctx, cancel := contextutil.Merge(ctx, backend.closeCtx)
|
|
defer cancel()
|
|
|
|
serverVersion, pool, err := backend.init(ctx)
|
|
if err != nil {
|
|
return serverVersion, nil, err
|
|
}
|
|
|
|
patchedRecords := make([]*databroker.Record, 0, len(records))
|
|
|
|
now := timestamppb.Now()
|
|
|
|
for _, record := range records {
|
|
record = dup(record)
|
|
record.ModifiedAt = now
|
|
err := patchRecord(ctx, pool, record, fields)
|
|
if storage.IsNotFound(err) {
|
|
continue
|
|
} else if err != nil {
|
|
err = fmt.Errorf("storage/postgres: error patching record %q of type %q: %w",
|
|
record.GetId(), record.GetType(), err)
|
|
return serverVersion, patchedRecords, err
|
|
}
|
|
patchedRecords = append(patchedRecords, record)
|
|
}
|
|
|
|
err = signalRecordChange(ctx, pool)
|
|
return serverVersion, patchedRecords, err
|
|
}
|
|
|
|
// SetOptions sets the options for the given record type.
|
|
func (backend *Backend) SetOptions(
|
|
ctx context.Context,
|
|
recordType string,
|
|
options *databroker.Options,
|
|
) error {
|
|
ctx, cancel := contextutil.Merge(ctx, backend.closeCtx)
|
|
defer cancel()
|
|
|
|
_, conn, err := backend.init(ctx)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
return setOptions(ctx, conn, recordType, options)
|
|
}
|
|
|
|
// Sync syncs the records.
|
|
func (backend *Backend) Sync(
|
|
ctx context.Context,
|
|
recordType string,
|
|
serverVersion, recordVersion uint64,
|
|
) (storage.RecordStream, error) {
|
|
// the original ctx will be used for the stream, this ctx used for pre-stream calls
|
|
callCtx, cancel := contextutil.Merge(ctx, backend.closeCtx)
|
|
defer cancel()
|
|
|
|
currentServerVersion, _, err := backend.init(callCtx)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if currentServerVersion != serverVersion {
|
|
return nil, storage.ErrInvalidServerVersion
|
|
}
|
|
|
|
return newChangedRecordStream(ctx, backend, recordType, recordVersion), nil
|
|
}
|
|
|
|
// SyncLatest syncs the latest version of each record.
|
|
func (backend *Backend) SyncLatest(
|
|
ctx context.Context,
|
|
recordType string,
|
|
expr storage.FilterExpression,
|
|
) (serverVersion, recordVersion uint64, stream storage.RecordStream, err error) {
|
|
// the original ctx will be used for the stream, this ctx used for pre-stream calls
|
|
callCtx, cancel := contextutil.Merge(ctx, backend.closeCtx)
|
|
defer cancel()
|
|
|
|
serverVersion, pool, err := backend.init(callCtx)
|
|
if err != nil {
|
|
return 0, 0, nil, err
|
|
}
|
|
|
|
recordVersion, err = getLatestRecordVersion(callCtx, pool)
|
|
if err != nil {
|
|
return 0, 0, nil, err
|
|
}
|
|
|
|
if recordType != "" {
|
|
f := storage.EqualsFilterExpression{
|
|
Fields: []string{"type"},
|
|
Value: recordType,
|
|
}
|
|
if expr != nil {
|
|
expr = storage.AndFilterExpression{expr, f}
|
|
} else {
|
|
expr = f
|
|
}
|
|
}
|
|
|
|
stream = newRecordStream(ctx, backend, expr)
|
|
return serverVersion, recordVersion, stream, nil
|
|
}
|
|
|
|
func (backend *Backend) init(ctx context.Context) (serverVersion uint64, pool *pgxpool.Pool, err error) {
|
|
backend.mu.RLock()
|
|
serverVersion = backend.serverVersion
|
|
pool = backend.pool
|
|
backend.mu.RUnlock()
|
|
|
|
if pool != nil {
|
|
return serverVersion, pool, nil
|
|
}
|
|
|
|
backend.mu.Lock()
|
|
defer backend.mu.Unlock()
|
|
|
|
// double-checked locking, might have already initialized, so just return
|
|
serverVersion = backend.serverVersion
|
|
pool = backend.pool
|
|
if pool != nil {
|
|
return serverVersion, pool, nil
|
|
}
|
|
|
|
config, err := ParseConfig(backend.dsn)
|
|
if err != nil {
|
|
return serverVersion, nil, err
|
|
}
|
|
|
|
pool, err = pgxpool.NewWithConfig(context.Background(), config)
|
|
if err != nil {
|
|
return serverVersion, nil, err
|
|
}
|
|
|
|
tx, err := pool.Begin(ctx)
|
|
if err != nil {
|
|
return serverVersion, nil, err
|
|
}
|
|
|
|
serverVersion, err = migrate(ctx, tx)
|
|
if err != nil {
|
|
_ = tx.Rollback(ctx)
|
|
return serverVersion, nil, err
|
|
}
|
|
|
|
err = tx.Commit(ctx)
|
|
if err != nil {
|
|
_ = tx.Rollback(ctx)
|
|
return serverVersion, nil, err
|
|
}
|
|
|
|
backend.serverVersion = serverVersion
|
|
backend.pool = pool
|
|
return serverVersion, pool, nil
|
|
}
|
|
|
|
func (backend *Backend) doPeriodically(f func(ctx context.Context) error, dur time.Duration) {
|
|
ctx := backend.closeCtx
|
|
|
|
ticker := time.NewTicker(dur)
|
|
defer ticker.Stop()
|
|
|
|
bo := backoff.NewExponentialBackOff()
|
|
bo.MaxElapsedTime = 0
|
|
|
|
for {
|
|
err := f(ctx)
|
|
if err == nil {
|
|
bo.Reset()
|
|
select {
|
|
case <-backend.closeCtx.Done():
|
|
return
|
|
case <-ticker.C:
|
|
}
|
|
} else {
|
|
if !errors.Is(err, context.Canceled) {
|
|
log.Ctx(ctx).Error().Err(err).Msg("storage/postgres")
|
|
}
|
|
select {
|
|
case <-backend.closeCtx.Done():
|
|
return
|
|
case <-time.After(bo.NextBackOff()):
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
func (backend *Backend) listenForNotifications(ctx context.Context) error {
|
|
_, pool, err := backend.init(ctx)
|
|
if err != nil {
|
|
return fmt.Errorf("error initializing pool for notifications: %w", err)
|
|
}
|
|
|
|
poolConn, err := pool.Acquire(ctx)
|
|
if err != nil {
|
|
return fmt.Errorf("error acquiring connection from pool for notifications: %w", err)
|
|
}
|
|
|
|
// hijack the connection so the pool can be left for short-lived queries
|
|
// and so that LISTENs don't leak to other queries
|
|
conn := poolConn.Hijack()
|
|
defer conn.Close(ctx)
|
|
|
|
for _, ch := range []string{recordChangeNotifyName, serviceChangeNotifyName} {
|
|
_, err = conn.Exec(ctx, `LISTEN `+ch)
|
|
if err != nil {
|
|
return fmt.Errorf("error listening on channel %s for notifications: %w", ch, err)
|
|
}
|
|
}
|
|
|
|
// for each notification broadcast the signal
|
|
for {
|
|
n, err := conn.WaitForNotification(ctx)
|
|
if err != nil {
|
|
// on error we'll close the connection to stop listening
|
|
return fmt.Errorf("error receiving notification: %w", err)
|
|
}
|
|
|
|
switch n.Channel {
|
|
case recordChangeNotifyName:
|
|
backend.onRecordChange.Broadcast(ctx)
|
|
case serviceChangeNotifyName:
|
|
backend.onServiceChange.Broadcast(ctx)
|
|
}
|
|
}
|
|
}
|
|
|
|
func (backend *Backend) ping(ctx context.Context) error {
|
|
_, pool, err := backend.init(ctx)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
return pool.Ping(ctx)
|
|
}
|
|
|
|
// ParseConfig parses a DSN into a pgxpool.Config.
|
|
func ParseConfig(dsn string) (*pgxpool.Config, error) {
|
|
config, err := pgxpool.ParseConfig(dsn)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
config.ConnConfig.LookupFunc = lookup
|
|
return config, nil
|
|
}
|
|
|
|
// lookup resolves host with the default resolver. A DNS "not found" error
// is not reported: in that case it returns nil addresses and a nil error.
// Any other error is returned unchanged together with the resolver's result.
func lookup(ctx context.Context, host string) (addrs []string, err error) {
	addrs, err = net.DefaultResolver.LookupHost(ctx, host)

	// ignore no such host errors
	var dnsErr *net.DNSError
	if errors.As(err, &dnsErr) && dnsErr.IsNotFound {
		return nil, nil
	}

	return addrs, err
}
|