mirror of
https://github.com/pomerium/pomerium.git
synced 2025-07-04 02:18:42 +02:00
postgres: databroker storage backend (#3370)
* wip * storage: add filtering to SyncLatest * don't increment the record version, so intermediate changes are requested * databroker: add support for query filtering * fill server and record version * postgres: databroker storage backend * wip * serialize puts * add test * skip tests for macos * add test * return error from protojson * set data * exclude postgres from cover tests
This commit is contained in:
parent
550698b1ca
commit
1c2aad2de6
21 changed files with 1573 additions and 17 deletions
358
pkg/storage/postgres/backend.go
Normal file
358
pkg/storage/postgres/backend.go
Normal file
|
@ -0,0 +1,358 @@
|
|||
package postgres
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/cenkalti/backoff/v4"
|
||||
"github.com/jackc/pgx/v4"
|
||||
"github.com/jackc/pgx/v4/pgxpool"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/log"
|
||||
"github.com/pomerium/pomerium/internal/signal"
|
||||
"github.com/pomerium/pomerium/pkg/contextutil"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
||||
"github.com/pomerium/pomerium/pkg/storage"
|
||||
)
|
||||
|
||||
// Backend is a storage Backend implemented with Postgres.
|
||||
type Backend struct {
|
||||
cfg *config
|
||||
dsn string
|
||||
onChange *signal.Signal
|
||||
|
||||
closeCtx context.Context
|
||||
close context.CancelFunc
|
||||
|
||||
mu sync.RWMutex
|
||||
pool *pgxpool.Pool
|
||||
serverVersion uint64
|
||||
}
|
||||
|
||||
// New creates a new Backend.
|
||||
func New(dsn string, options ...Option) *Backend {
|
||||
backend := &Backend{
|
||||
cfg: getConfig(options...),
|
||||
dsn: dsn,
|
||||
onChange: signal.New(),
|
||||
}
|
||||
backend.closeCtx, backend.close = context.WithCancel(context.Background())
|
||||
go backend.doPeriodically(func(ctx context.Context) error {
|
||||
_, pool, err := backend.init(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return deleteChangesBefore(ctx, pool, time.Now().Add(-backend.cfg.expiry))
|
||||
}, time.Minute)
|
||||
go backend.doPeriodically(func(ctx context.Context) error {
|
||||
_, pool, err := backend.init(backend.closeCtx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
conn, err := pool.Acquire(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer conn.Release()
|
||||
|
||||
_, err = conn.Exec(ctx, `LISTEN `+recordChangeNotifyName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = conn.Conn().WaitForNotification(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
backend.onChange.Broadcast(ctx)
|
||||
|
||||
return nil
|
||||
}, time.Millisecond*100)
|
||||
return backend
|
||||
}
|
||||
|
||||
// Close closes the underlying database connection.
|
||||
func (backend *Backend) Close() error {
|
||||
backend.mu.Lock()
|
||||
defer backend.mu.Unlock()
|
||||
|
||||
backend.close()
|
||||
|
||||
if backend.pool != nil {
|
||||
backend.pool.Close()
|
||||
backend.pool = nil
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Get gets a record from the database.
|
||||
func (backend *Backend) Get(
|
||||
ctx context.Context,
|
||||
recordType, recordID string,
|
||||
) (*databroker.Record, error) {
|
||||
ctx, cancel := contextutil.Merge(ctx, backend.closeCtx)
|
||||
defer cancel()
|
||||
|
||||
_, conn, err := backend.init(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return getRecord(ctx, conn, recordType, recordID)
|
||||
}
|
||||
|
||||
// GetOptions returns the options for the given record type.
|
||||
func (backend *Backend) GetOptions(
|
||||
ctx context.Context,
|
||||
recordType string,
|
||||
) (*databroker.Options, error) {
|
||||
ctx, cancel := contextutil.Merge(ctx, backend.closeCtx)
|
||||
defer cancel()
|
||||
|
||||
_, conn, err := backend.init(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return getOptions(ctx, conn, recordType)
|
||||
}
|
||||
|
||||
// Lease attempts to acquire a lease for the given name.
|
||||
func (backend *Backend) Lease(
|
||||
ctx context.Context,
|
||||
leaseName, leaseID string,
|
||||
ttl time.Duration,
|
||||
) (acquired bool, err error) {
|
||||
ctx, cancel := contextutil.Merge(ctx, backend.closeCtx)
|
||||
defer cancel()
|
||||
|
||||
_, conn, err := backend.init(ctx)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
leaseHolderID, err := maybeAcquireLease(ctx, conn, leaseName, leaseID, ttl)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
return leaseHolderID == leaseID, nil
|
||||
}
|
||||
|
||||
// Put puts a record into Postgres.
|
||||
func (backend *Backend) Put(
|
||||
ctx context.Context,
|
||||
records []*databroker.Record,
|
||||
) (serverVersion uint64, err error) {
|
||||
ctx, cancel := contextutil.Merge(ctx, backend.closeCtx)
|
||||
defer cancel()
|
||||
|
||||
serverVersion, pool, err := backend.init(ctx)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
err = pool.BeginTxFunc(ctx, pgx.TxOptions{
|
||||
IsoLevel: pgx.Serializable,
|
||||
AccessMode: pgx.ReadWrite,
|
||||
}, func(tx pgx.Tx) error {
|
||||
now := timestamppb.Now()
|
||||
|
||||
recordVersion, err := getLatestRecordVersion(ctx, tx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("storage/postgres: error getting latest record version: %w", err)
|
||||
}
|
||||
|
||||
// add all the records
|
||||
recordTypes := map[string]struct{}{}
|
||||
for i, record := range records {
|
||||
recordTypes[record.GetType()] = struct{}{}
|
||||
|
||||
record = dup(record)
|
||||
record.ModifiedAt = now
|
||||
record.Version = recordVersion + uint64(i) + 1
|
||||
err := putRecordChange(ctx, tx, record)
|
||||
if err != nil {
|
||||
return fmt.Errorf("storage/postgres: error saving record change: %w", err)
|
||||
}
|
||||
|
||||
err = putRecord(ctx, tx, record)
|
||||
if err != nil {
|
||||
return fmt.Errorf("storage/postgres: error saving record: %w", err)
|
||||
}
|
||||
records[i] = record
|
||||
}
|
||||
|
||||
// enforce options for each record type
|
||||
for recordType := range recordTypes {
|
||||
options, err := getOptions(ctx, tx, recordType)
|
||||
if err != nil {
|
||||
return fmt.Errorf("storage/postgres: error getting options: %w", err)
|
||||
}
|
||||
err = enforceOptions(ctx, tx, recordType, options)
|
||||
if err != nil {
|
||||
return fmt.Errorf("storage/postgres: error enforcing options: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return serverVersion, err
|
||||
}
|
||||
|
||||
err = signalRecordChange(ctx, pool)
|
||||
return serverVersion, err
|
||||
}
|
||||
|
||||
// SetOptions sets the options for the given record type.
|
||||
func (backend *Backend) SetOptions(
|
||||
ctx context.Context,
|
||||
recordType string,
|
||||
options *databroker.Options,
|
||||
) error {
|
||||
ctx, cancel := contextutil.Merge(ctx, backend.closeCtx)
|
||||
defer cancel()
|
||||
|
||||
_, conn, err := backend.init(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return setOptions(ctx, conn, recordType, options)
|
||||
}
|
||||
|
||||
// Sync syncs the records.
|
||||
func (backend *Backend) Sync(
|
||||
ctx context.Context,
|
||||
serverVersion, recordVersion uint64,
|
||||
) (storage.RecordStream, error) {
|
||||
// the original ctx will be used for the stream, this ctx used for pre-stream calls
|
||||
callCtx, cancel := contextutil.Merge(ctx, backend.closeCtx)
|
||||
defer cancel()
|
||||
|
||||
currentServerVersion, _, err := backend.init(callCtx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if currentServerVersion != serverVersion {
|
||||
return nil, storage.ErrInvalidServerVersion
|
||||
}
|
||||
|
||||
return newChangedRecordStream(ctx, backend, recordVersion), nil
|
||||
}
|
||||
|
||||
// SyncLatest syncs the latest version of each record.
|
||||
func (backend *Backend) SyncLatest(
|
||||
ctx context.Context,
|
||||
recordType string,
|
||||
expr storage.FilterExpression,
|
||||
) (serverVersion, recordVersion uint64, stream storage.RecordStream, err error) {
|
||||
// the original ctx will be used for the stream, this ctx used for pre-stream calls
|
||||
callCtx, cancel := contextutil.Merge(ctx, backend.closeCtx)
|
||||
defer cancel()
|
||||
|
||||
serverVersion, pool, err := backend.init(callCtx)
|
||||
if err != nil {
|
||||
return 0, 0, nil, err
|
||||
}
|
||||
|
||||
recordVersion, err = getLatestRecordVersion(callCtx, pool)
|
||||
if err != nil {
|
||||
return 0, 0, nil, err
|
||||
}
|
||||
|
||||
if recordType != "" {
|
||||
f := storage.EqualsFilterExpression{
|
||||
Fields: []string{"type"},
|
||||
Value: recordType,
|
||||
}
|
||||
if expr != nil {
|
||||
expr = storage.AndFilterExpression{expr, f}
|
||||
} else {
|
||||
expr = f
|
||||
}
|
||||
}
|
||||
|
||||
stream = newRecordStream(ctx, backend, expr)
|
||||
return serverVersion, recordVersion, stream, nil
|
||||
}
|
||||
|
||||
func (backend *Backend) init(ctx context.Context) (serverVersion uint64, pool *pgxpool.Pool, err error) {
|
||||
backend.mu.RLock()
|
||||
serverVersion = backend.serverVersion
|
||||
pool = backend.pool
|
||||
backend.mu.RUnlock()
|
||||
|
||||
if pool != nil {
|
||||
return serverVersion, pool, nil
|
||||
}
|
||||
|
||||
backend.mu.Lock()
|
||||
defer backend.mu.Unlock()
|
||||
|
||||
// double-checked locking, might have already initialized, so just return
|
||||
serverVersion = backend.serverVersion
|
||||
pool = backend.pool
|
||||
if pool != nil {
|
||||
return serverVersion, pool, nil
|
||||
}
|
||||
|
||||
config, err := pgxpool.ParseConfig(backend.dsn)
|
||||
if err != nil {
|
||||
return serverVersion, nil, err
|
||||
}
|
||||
|
||||
pool, err = pgxpool.ConnectConfig(context.Background(), config)
|
||||
if err != nil {
|
||||
return serverVersion, nil, err
|
||||
}
|
||||
|
||||
err = pool.BeginFunc(ctx, func(tx pgx.Tx) error {
|
||||
var err error
|
||||
serverVersion, err = migrate(ctx, tx)
|
||||
return err
|
||||
})
|
||||
if err != nil {
|
||||
return serverVersion, nil, err
|
||||
}
|
||||
|
||||
backend.serverVersion = serverVersion
|
||||
backend.pool = pool
|
||||
return serverVersion, pool, nil
|
||||
}
|
||||
|
||||
func (backend *Backend) doPeriodically(f func(ctx context.Context) error, dur time.Duration) {
|
||||
ctx := backend.closeCtx
|
||||
|
||||
ticker := time.NewTicker(dur)
|
||||
defer ticker.Stop()
|
||||
|
||||
bo := backoff.NewExponentialBackOff()
|
||||
bo.MaxElapsedTime = 0
|
||||
|
||||
for {
|
||||
err := f(ctx)
|
||||
if err == nil {
|
||||
bo.Reset()
|
||||
select {
|
||||
case <-backend.closeCtx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
}
|
||||
} else {
|
||||
log.Error(ctx).Err(err).Msg("storage/postgres")
|
||||
select {
|
||||
case <-backend.closeCtx.Done():
|
||||
return
|
||||
case <-time.After(bo.NextBackOff()):
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue