mirror of
https://github.com/pomerium/pomerium.git
synced 2025-05-01 19:36:32 +02:00
494 lines
14 KiB
Go
494 lines
14 KiB
Go
// Package postgres contains an implementation of the storage.Backend backed by postgres.
|
|
package postgres
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"fmt"
|
|
"sort"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/jackc/pgx/v5"
|
|
"github.com/jackc/pgx/v5/pgconn"
|
|
"github.com/jackc/pgx/v5/pgtype"
|
|
"github.com/jackc/pgx/v5/pgxpool"
|
|
"google.golang.org/protobuf/encoding/protojson"
|
|
"google.golang.org/protobuf/proto"
|
|
"google.golang.org/protobuf/reflect/protoregistry"
|
|
"google.golang.org/protobuf/types/known/anypb"
|
|
"google.golang.org/protobuf/types/known/fieldmaskpb"
|
|
"google.golang.org/protobuf/types/known/timestamppb"
|
|
|
|
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
|
"github.com/pomerium/pomerium/pkg/grpc/registry"
|
|
"github.com/pomerium/pomerium/pkg/protoutil"
|
|
"github.com/pomerium/pomerium/pkg/storage"
|
|
)
|
|
|
|
// Postgres object names used by this backend. Every table lives inside
// schemaName; the *NotifyName values are channel names passed to NOTIFY.
var (
	// schemaName is the postgres schema that every table below lives in.
	schemaName = "pomerium"

	// Table names. Queries qualify them as schemaName + "." + name.
	migrationInfoTableName = "migration_info" // NOTE(review): not referenced in this section; presumably used by migrations elsewhere
	recordsTableName       = "records"        // latest state, upserted on (type, id)
	recordChangesTableName = "record_changes" // one row per put/delete, read back ordered by version
	recordOptionsTableName = "record_options" // per-type options such as capacity
	leasesTableName        = "leases"         // named leases held by an id until expires_at
	servicesTableName      = "services"       // registry services keyed by (kind, endpoint)

	// Channel names sent with NOTIFY when records or services change.
	recordChangeNotifyName  = "pomerium_record_change"
	serviceChangeNotifyName = "pomerium_service_change"
)
|
|
|
|
type querier interface {
|
|
Exec(ctx context.Context, sql string, arguments ...any) (commandTag pgconn.CommandTag, err error)
|
|
Query(ctx context.Context, sql string, args ...any) (pgx.Rows, error)
|
|
QueryRow(ctx context.Context, sql string, args ...any) pgx.Row
|
|
}
|
|
|
|
func deleteChangesBefore(ctx context.Context, q querier, cutoff time.Time) error {
|
|
_, err := q.Exec(ctx, `
|
|
DELETE FROM `+schemaName+`.`+recordChangesTableName+`
|
|
WHERE modified_at < $1
|
|
`, cutoff)
|
|
return err
|
|
}
|
|
|
|
func deleteExpiredServices(ctx context.Context, q querier, cutoff time.Time) (rowCount int64, err error) {
|
|
cmd, err := q.Exec(ctx, `
|
|
DELETE FROM `+schemaName+`.`+servicesTableName+`
|
|
WHERE expires_at < $1
|
|
`, cutoff)
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
return cmd.RowsAffected(), nil
|
|
}
|
|
|
|
func dup(record *databroker.Record) *databroker.Record {
|
|
return proto.Clone(record).(*databroker.Record)
|
|
}
|
|
|
|
func enforceOptions(ctx context.Context, q querier, recordType string, options *databroker.Options) error {
|
|
if options == nil || options.Capacity == nil {
|
|
return nil
|
|
}
|
|
|
|
_, err := q.Exec(ctx, `
|
|
DELETE FROM `+schemaName+`.`+recordsTableName+`
|
|
WHERE type=$1
|
|
AND id NOT IN (
|
|
SELECT id
|
|
FROM `+schemaName+`.`+recordsTableName+`
|
|
WHERE type=$1
|
|
ORDER BY version DESC
|
|
LIMIT $2
|
|
)
|
|
`, recordType, options.GetCapacity())
|
|
return err
|
|
}
|
|
|
|
func getLatestRecordVersion(ctx context.Context, q querier) (recordVersion uint64, err error) {
|
|
err = q.QueryRow(ctx, `
|
|
SELECT version
|
|
FROM `+schemaName+`.`+recordChangesTableName+`
|
|
ORDER BY version DESC
|
|
LIMIT 1
|
|
`).Scan(&recordVersion)
|
|
if isNotFound(err) {
|
|
err = nil
|
|
}
|
|
return recordVersion, err
|
|
}
|
|
|
|
func getNextChangedRecord(ctx context.Context, q querier, recordType string, afterRecordVersion uint64) (*databroker.Record, error) {
|
|
var recordID string
|
|
var version uint64
|
|
var data []byte
|
|
var modifiedAt pgtype.Timestamptz
|
|
var deletedAt pgtype.Timestamptz
|
|
query := `
|
|
SELECT type, id, version, data, modified_at, deleted_at
|
|
FROM ` + schemaName + `.` + recordChangesTableName + `
|
|
WHERE version > $1
|
|
`
|
|
args := []any{afterRecordVersion}
|
|
if recordType != "" {
|
|
query += ` AND type = $2`
|
|
args = append(args, recordType)
|
|
}
|
|
query += `
|
|
ORDER BY version ASC
|
|
LIMIT 1
|
|
`
|
|
err := q.QueryRow(ctx, query, args...).Scan(&recordType, &recordID, &version, &data, &modifiedAt, &deletedAt)
|
|
if isNotFound(err) {
|
|
return nil, storage.ErrNotFound
|
|
} else if err != nil {
|
|
return nil, fmt.Errorf("error querying next changed record: %w", err)
|
|
}
|
|
|
|
// data may be nil if a record is deleted
|
|
var a *anypb.Any
|
|
if len(data) != 0 {
|
|
a, err = protoutil.UnmarshalAnyJSON(data)
|
|
if isUnknownType(err) {
|
|
a = protoutil.ToAny(protoutil.ToStruct(map[string]string{
|
|
"id": recordID,
|
|
}))
|
|
} else if err != nil {
|
|
return nil, fmt.Errorf("error unmarshaling changed record data: %w", err)
|
|
}
|
|
}
|
|
|
|
return &databroker.Record{
|
|
Version: version,
|
|
Type: recordType,
|
|
Id: recordID,
|
|
Data: a,
|
|
ModifiedAt: timestamppbFromTimestamptz(modifiedAt),
|
|
DeletedAt: timestamppbFromTimestamptz(deletedAt),
|
|
}, nil
|
|
}
|
|
|
|
func getOptions(ctx context.Context, q querier, recordType string) (*databroker.Options, error) {
|
|
var capacity pgtype.Int8
|
|
err := q.QueryRow(ctx, `
|
|
SELECT capacity
|
|
FROM `+schemaName+`.`+recordOptionsTableName+`
|
|
WHERE type=$1
|
|
`, recordType).Scan(&capacity)
|
|
if err != nil && !isNotFound(err) {
|
|
return nil, err
|
|
}
|
|
options := new(databroker.Options)
|
|
if capacity.Valid {
|
|
options.Capacity = proto.Uint64(uint64(capacity.Int64))
|
|
}
|
|
return options, nil
|
|
}
|
|
|
|
// lockMode is an optional row-locking clause. getRecord appends it verbatim
// to the end of its SELECT statement.
type lockMode string

const (
	// lockModeNone appends nothing, producing a plain SELECT.
	lockModeNone lockMode = ""
	// lockModeUpdate appends FOR UPDATE. In postgres this locks the selected
	// rows against concurrent modification until the enclosing transaction ends.
	lockModeUpdate lockMode = "FOR UPDATE"
)
|
|
|
|
func getRecord(
|
|
ctx context.Context, q querier, recordType, recordID string, lockMode lockMode,
|
|
) (*databroker.Record, error) {
|
|
var version uint64
|
|
var data []byte
|
|
var modifiedAt pgtype.Timestamptz
|
|
err := q.QueryRow(ctx, `
|
|
SELECT version, data, modified_at
|
|
FROM `+schemaName+`.`+recordsTableName+`
|
|
WHERE type=$1 AND id=$2 `+string(lockMode),
|
|
recordType, recordID).Scan(&version, &data, &modifiedAt)
|
|
if isNotFound(err) {
|
|
return nil, storage.ErrNotFound
|
|
} else if err != nil {
|
|
return nil, fmt.Errorf("postgres: failed to execute query: %w", err)
|
|
}
|
|
|
|
a, err := protoutil.UnmarshalAnyJSON(data)
|
|
if isUnknownType(err) {
|
|
return nil, storage.ErrNotFound
|
|
} else if err != nil {
|
|
return nil, fmt.Errorf("postgres: failed to unmarshal data: %w", err)
|
|
}
|
|
|
|
return &databroker.Record{
|
|
Version: version,
|
|
Type: recordType,
|
|
Id: recordID,
|
|
Data: a,
|
|
ModifiedAt: timestamppbFromTimestamptz(modifiedAt),
|
|
}, nil
|
|
}
|
|
|
|
func listRecords(ctx context.Context, q querier, expr storage.FilterExpression, offset, limit int) ([]*databroker.Record, error) {
|
|
args := []any{offset, limit}
|
|
query := `
|
|
SELECT type, id, version, data, modified_at
|
|
FROM ` + schemaName + `.` + recordsTableName + `
|
|
`
|
|
if expr != nil {
|
|
query += "WHERE "
|
|
err := addFilterExpressionToQuery(&query, &args, expr)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("postgres: failed to add filter to query: %w", err)
|
|
}
|
|
}
|
|
query += `
|
|
ORDER BY type, id
|
|
LIMIT $2
|
|
OFFSET $1
|
|
`
|
|
rows, err := q.Query(ctx, query, args...)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("postgres: failed to execute query: %w", err)
|
|
}
|
|
defer rows.Close()
|
|
|
|
var records []*databroker.Record
|
|
for rows.Next() {
|
|
var recordType, id string
|
|
var version uint64
|
|
var data []byte
|
|
var modifiedAt pgtype.Timestamptz
|
|
err = rows.Scan(&recordType, &id, &version, &data, &modifiedAt)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("postgres: failed to scan row: %w", err)
|
|
}
|
|
|
|
a, err := protoutil.UnmarshalAnyJSON(data)
|
|
if isUnknownType(err) {
|
|
a = protoutil.ToAny(protoutil.ToStruct(map[string]string{
|
|
"id": id,
|
|
}))
|
|
} else if err != nil {
|
|
return nil, fmt.Errorf("postgres: failed to unmarshal data: %w", err)
|
|
}
|
|
|
|
records = append(records, &databroker.Record{
|
|
Version: version,
|
|
Type: recordType,
|
|
Id: id,
|
|
Data: a,
|
|
ModifiedAt: timestamppbFromTimestamptz(modifiedAt),
|
|
})
|
|
}
|
|
err = rows.Err()
|
|
if err != nil {
|
|
return nil, fmt.Errorf("postgres: error iterating over rows: %w", err)
|
|
}
|
|
|
|
return records, nil
|
|
}
|
|
|
|
func listServices(ctx context.Context, q querier) ([]*registry.Service, error) {
|
|
var services []*registry.Service
|
|
|
|
query := `
|
|
SELECT kind, endpoint
|
|
FROM ` + schemaName + `.` + servicesTableName + `
|
|
ORDER BY kind, endpoint
|
|
`
|
|
rows, err := q.Query(ctx, query)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("postgres: failed to execute query: %w", err)
|
|
}
|
|
defer rows.Close()
|
|
|
|
for rows.Next() {
|
|
var kind, endpoint string
|
|
err = rows.Scan(&kind, &endpoint)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("postgres: failed to scan row: %w", err)
|
|
}
|
|
|
|
services = append(services, ®istry.Service{
|
|
Kind: registry.ServiceKind(registry.ServiceKind_value[kind]),
|
|
Endpoint: endpoint,
|
|
})
|
|
}
|
|
err = rows.Err()
|
|
if err != nil {
|
|
return nil, fmt.Errorf("postgres: error iterating over rows: %w", err)
|
|
}
|
|
|
|
return services, nil
|
|
}
|
|
|
|
func listTypes(ctx context.Context, q querier) ([]string, error) {
|
|
query := `
|
|
SELECT DISTINCT type
|
|
FROM ` + schemaName + `.` + recordsTableName + `
|
|
`
|
|
rows, err := q.Query(ctx, query)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("postgres: failed to execute query: %w", err)
|
|
}
|
|
defer rows.Close()
|
|
|
|
var types []string
|
|
for rows.Next() {
|
|
var recordType string
|
|
err = rows.Scan(&recordType)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("postgres: failed to scan row: %w", err)
|
|
}
|
|
|
|
types = append(types, recordType)
|
|
}
|
|
err = rows.Err()
|
|
if err != nil {
|
|
return nil, fmt.Errorf("postgres: error iterating over rows: %w", err)
|
|
}
|
|
|
|
sort.Strings(types)
|
|
return types, nil
|
|
}
|
|
|
|
func maybeAcquireLease(ctx context.Context, q querier, leaseName, leaseID string, ttl time.Duration) (leaseHolderID string, err error) {
|
|
tbl := schemaName + "." + leasesTableName
|
|
expiresAt := timestamptzFromTimestamppb(timestamppb.New(time.Now().Add(ttl)))
|
|
now := timestamptzFromTimestamppb(timestamppb.Now())
|
|
err = q.QueryRow(ctx, `
|
|
INSERT INTO `+tbl+` (name, id, expires_at)
|
|
VALUES ($1, $2, $3)
|
|
ON CONFLICT (name) DO UPDATE
|
|
SET id=CASE WHEN `+tbl+`.expires_at<$4 OR `+tbl+`.id=$2 THEN $2 ELSE `+tbl+`.id END,
|
|
expires_at=CASE WHEN `+tbl+`.expires_at<$4 OR `+tbl+`.id=$2 THEN $3 ELSE `+tbl+`.expires_at END
|
|
RETURNING `+tbl+`.id
|
|
`, leaseName, leaseID, expiresAt, now).Scan(&leaseHolderID)
|
|
return leaseHolderID, err
|
|
}
|
|
|
|
func putRecordAndChange(ctx context.Context, q querier, record *databroker.Record) error {
|
|
data, err := jsonbFromAny(record.GetData())
|
|
if err != nil {
|
|
return fmt.Errorf("postgres: failed to convert any to json: %w", err)
|
|
}
|
|
|
|
modifiedAt := timestamptzFromTimestamppb(record.GetModifiedAt())
|
|
deletedAt := timestamptzFromTimestamppb(record.GetDeletedAt())
|
|
indexCIDR := &pgtype.Text{Valid: false}
|
|
if cidr := storage.GetRecordIndexCIDR(record.GetData()); cidr != nil {
|
|
indexCIDR.String = cidr.String()
|
|
indexCIDR.Valid = true
|
|
}
|
|
|
|
query := `
|
|
WITH t1 AS (
|
|
INSERT INTO ` + schemaName + `.` + recordChangesTableName + ` (type, id, data, modified_at, deleted_at)
|
|
VALUES ($1, $2, $3, $4, $5)
|
|
RETURNING *
|
|
)
|
|
`
|
|
args := []any{
|
|
record.GetType(), record.GetId(), data, modifiedAt, deletedAt,
|
|
}
|
|
if record.GetDeletedAt() == nil {
|
|
query += `
|
|
INSERT INTO ` + schemaName + `.` + recordsTableName + ` (type, id, version, data, modified_at, index_cidr)
|
|
VALUES ($1, $2, (SELECT version FROM t1), $3, $4, $6)
|
|
ON CONFLICT (type, id) DO UPDATE
|
|
SET version=(SELECT version FROM t1), data=$3, modified_at=$4, index_cidr=$6
|
|
RETURNING ` + schemaName + `.` + recordsTableName + `.version
|
|
`
|
|
args = append(args, indexCIDR)
|
|
} else {
|
|
query += `
|
|
DELETE FROM ` + schemaName + `.` + recordsTableName + `
|
|
WHERE type=$1 AND id=$2
|
|
RETURNING ` + schemaName + `.` + recordsTableName + `.version
|
|
`
|
|
}
|
|
err = q.QueryRow(ctx, query, args...).Scan(&record.Version)
|
|
if err != nil && !isNotFound(err) {
|
|
return fmt.Errorf("postgres: failed to execute query: %w", err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// patchRecord updates specific fields of an existing record.
|
|
func patchRecord(
|
|
ctx context.Context, p *pgxpool.Pool, record *databroker.Record, fields *fieldmaskpb.FieldMask,
|
|
) error {
|
|
tx, err := p.Begin(ctx)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer func() { _ = tx.Rollback(ctx) }()
|
|
|
|
existing, err := getRecord(ctx, tx, record.GetType(), record.GetId(), lockModeUpdate)
|
|
if isNotFound(err) {
|
|
return storage.ErrNotFound
|
|
} else if err != nil {
|
|
return err
|
|
}
|
|
|
|
if err := storage.PatchRecord(existing, record, fields); err != nil {
|
|
return err
|
|
}
|
|
|
|
if err := putRecordAndChange(ctx, tx, record); err != nil {
|
|
return err
|
|
}
|
|
|
|
return tx.Commit(ctx)
|
|
}
|
|
|
|
func putService(ctx context.Context, q querier, svc *registry.Service, expiresAt time.Time) error {
|
|
query := `
|
|
INSERT INTO ` + schemaName + `.` + servicesTableName + ` (kind, endpoint, expires_at)
|
|
VALUES ($1, $2, $3)
|
|
ON CONFLICT (kind, endpoint) DO UPDATE
|
|
SET expires_at=$3
|
|
`
|
|
_, err := q.Exec(ctx, query, svc.GetKind().String(), svc.GetEndpoint(), expiresAt)
|
|
return err
|
|
}
|
|
|
|
func setOptions(ctx context.Context, q querier, recordType string, options *databroker.Options) error {
|
|
capacity := pgtype.Int8{}
|
|
if options != nil && options.Capacity != nil {
|
|
capacity.Int64 = int64(options.GetCapacity())
|
|
capacity.Valid = true
|
|
}
|
|
|
|
_, err := q.Exec(ctx, `
|
|
INSERT INTO `+schemaName+`.`+recordOptionsTableName+` (type, capacity)
|
|
VALUES ($1, $2)
|
|
ON CONFLICT (type) DO UPDATE
|
|
SET capacity=$2
|
|
`, recordType, capacity)
|
|
return err
|
|
}
|
|
|
|
func signalRecordChange(ctx context.Context, q querier) error {
|
|
_, err := q.Exec(ctx, `NOTIFY `+recordChangeNotifyName)
|
|
return err
|
|
}
|
|
|
|
func signalServiceChange(ctx context.Context, q querier) error {
|
|
_, err := q.Exec(ctx, `NOTIFY `+serviceChangeNotifyName)
|
|
return err
|
|
}
|
|
|
|
func jsonbFromAny(any *anypb.Any) ([]byte, error) {
|
|
if any == nil {
|
|
return nil, nil
|
|
}
|
|
|
|
return protojson.Marshal(any)
|
|
}
|
|
|
|
func timestamppbFromTimestamptz(ts pgtype.Timestamptz) *timestamppb.Timestamp {
|
|
if !ts.Valid {
|
|
return nil
|
|
}
|
|
return timestamppb.New(ts.Time)
|
|
}
|
|
|
|
func timestamptzFromTimestamppb(ts *timestamppb.Timestamp) pgtype.Timestamptz {
|
|
if !ts.IsValid() {
|
|
return pgtype.Timestamptz{}
|
|
}
|
|
return pgtype.Timestamptz{Time: ts.AsTime(), Valid: true}
|
|
}
|
|
|
|
func isNotFound(err error) bool {
|
|
return errors.Is(err, pgx.ErrNoRows) || errors.Is(err, storage.ErrNotFound)
|
|
}
|
|
|
|
func isUnknownType(err error) bool {
|
|
if err == nil {
|
|
return false
|
|
}
|
|
|
|
return errors.Is(err, protoregistry.NotFound) ||
|
|
strings.Contains(err.Error(), "unable to resolve") // protojson doesn't wrap errors so check for the string
|
|
}
|