mirror of
https://github.com/pomerium/pomerium.git
synced 2025-05-19 03:57:17 +02:00
postgres: use CTE and GENERATED version number instead of serialized transaction (#3408)
* postgres: use CTE and GENERATED version number instead of serialized transaction * update server version * fix indexing CIDRs
This commit is contained in:
parent
a7bd284b52
commit
a2d5d8062b
5 changed files with 81 additions and 82 deletions
|
@ -158,53 +158,32 @@ func (backend *Backend) Put(
|
|||
return 0, err
|
||||
}
|
||||
|
||||
err = pool.BeginTxFunc(ctx, pgx.TxOptions{
|
||||
IsoLevel: pgx.Serializable,
|
||||
AccessMode: pgx.ReadWrite,
|
||||
}, func(tx pgx.Tx) error {
|
||||
now := timestamppb.Now()
|
||||
now := timestamppb.Now()
|
||||
|
||||
recordVersion, err := getLatestRecordVersion(ctx, tx)
|
||||
// add all the records
|
||||
recordTypes := map[string]struct{}{}
|
||||
for i, record := range records {
|
||||
recordTypes[record.GetType()] = struct{}{}
|
||||
|
||||
record = dup(record)
|
||||
record.ModifiedAt = now
|
||||
err := putRecordAndChange(ctx, pool, record)
|
||||
if err != nil {
|
||||
return fmt.Errorf("storage/postgres: error getting latest record version: %w", err)
|
||||
return serverVersion, fmt.Errorf("storage/postgres: error saving record: %w", err)
|
||||
}
|
||||
records[i] = record
|
||||
}
|
||||
|
||||
// add all the records
|
||||
recordTypes := map[string]struct{}{}
|
||||
for i, record := range records {
|
||||
recordTypes[record.GetType()] = struct{}{}
|
||||
|
||||
record = dup(record)
|
||||
record.ModifiedAt = now
|
||||
record.Version = recordVersion + uint64(i) + 1
|
||||
err := putRecordChange(ctx, tx, record)
|
||||
if err != nil {
|
||||
return fmt.Errorf("storage/postgres: error saving record change: %w", err)
|
||||
}
|
||||
|
||||
err = putRecord(ctx, tx, record)
|
||||
if err != nil {
|
||||
return fmt.Errorf("storage/postgres: error saving record: %w", err)
|
||||
}
|
||||
records[i] = record
|
||||
// enforce options for each record type
|
||||
for recordType := range recordTypes {
|
||||
options, err := getOptions(ctx, pool, recordType)
|
||||
if err != nil {
|
||||
return serverVersion, fmt.Errorf("storage/postgres: error getting options: %w", err)
|
||||
}
|
||||
|
||||
// enforce options for each record type
|
||||
for recordType := range recordTypes {
|
||||
options, err := getOptions(ctx, tx, recordType)
|
||||
if err != nil {
|
||||
return fmt.Errorf("storage/postgres: error getting options: %w", err)
|
||||
}
|
||||
err = enforceOptions(ctx, tx, recordType, options)
|
||||
if err != nil {
|
||||
return fmt.Errorf("storage/postgres: error enforcing options: %w", err)
|
||||
}
|
||||
err = enforceOptions(ctx, pool, recordType, options)
|
||||
if err != nil {
|
||||
return serverVersion, fmt.Errorf("storage/postgres: error enforcing options: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return serverVersion, err
|
||||
}
|
||||
|
||||
err = signalRecordChange(ctx, pool)
|
||||
|
|
|
@ -77,6 +77,41 @@ var migrations = []func(context.Context, pgx.Tx) error{
|
|||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
},
|
||||
2: func(ctx context.Context, tx pgx.Tx) error {
|
||||
serverVersion := uint64(cryptutil.NewRandomUInt32())
|
||||
_, err := tx.Exec(ctx, `
|
||||
UPDATE `+schemaName+`.`+migrationInfoTableName+`
|
||||
SET server_version = $1
|
||||
`, serverVersion)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = tx.Exec(ctx, `
|
||||
DELETE FROM `+schemaName+`.`+recordChangesTableName+`
|
||||
`)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = tx.Exec(ctx, `
|
||||
DELETE FROM `+schemaName+`.`+recordsTableName+`
|
||||
`)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = tx.Exec(ctx, `
|
||||
ALTER TABLE `+schemaName+`.`+recordChangesTableName+`
|
||||
ALTER COLUMN version
|
||||
ADD GENERATED BY DEFAULT AS IDENTITY
|
||||
`)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
|
|
@ -225,7 +225,7 @@ func maybeAcquireLease(ctx context.Context, q querier, leaseName, leaseID string
|
|||
return leaseHolderID, err
|
||||
}
|
||||
|
||||
func putRecordChange(ctx context.Context, q querier, record *databroker.Record) error {
|
||||
func putRecordAndChange(ctx context.Context, q querier, record *databroker.Record) error {
|
||||
data, err := jsonbFromAny(record.GetData())
|
||||
if err != nil {
|
||||
return err
|
||||
|
@ -233,40 +233,38 @@ func putRecordChange(ctx context.Context, q querier, record *databroker.Record)
|
|||
|
||||
modifiedAt := timestamptzFromTimestamppb(record.GetModifiedAt())
|
||||
deletedAt := timestamptzFromTimestamppb(record.GetDeletedAt())
|
||||
_, err = q.Exec(ctx, `
|
||||
INSERT INTO `+schemaName+`.`+recordChangesTableName+` (type, id, version, data, modified_at, deleted_at)
|
||||
VALUES ($1, $2, $3, $4, $5, $6)
|
||||
`, record.GetType(), record.GetId(), record.GetVersion(), data, modifiedAt, deletedAt)
|
||||
if err != nil {
|
||||
return err
|
||||
indexCIDR := &pgtype.Text{Status: pgtype.Null}
|
||||
if cidr := storage.GetRecordIndexCIDR(record.GetData()); cidr != nil {
|
||||
_ = indexCIDR.Set(cidr.String())
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func putRecord(ctx context.Context, q querier, record *databroker.Record) error {
|
||||
data, err := jsonbFromAny(record.GetData())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
modifiedAt := timestamptzFromTimestamppb(record.GetModifiedAt())
|
||||
if record.GetDeletedAt() == nil {
|
||||
_, err = q.Exec(ctx, `
|
||||
INSERT INTO `+schemaName+`.`+recordsTableName+` (type, id, version, data, modified_at)
|
||||
query := `
|
||||
WITH t1 AS (
|
||||
INSERT INTO ` + schemaName + `.` + recordChangesTableName + ` (type, id, data, modified_at, deleted_at)
|
||||
VALUES ($1, $2, $3, $4, $5)
|
||||
RETURNING *
|
||||
)
|
||||
`
|
||||
if record.GetDeletedAt() == nil {
|
||||
query += `
|
||||
INSERT INTO ` + schemaName + `.` + recordsTableName + ` (type, id, version, data, modified_at, index_cidr)
|
||||
VALUES ($1, $2, (SELECT version FROM t1), $3, $4, $6)
|
||||
ON CONFLICT (type, id) DO UPDATE
|
||||
SET version=$3, data=$4, modified_at=$5
|
||||
`, record.GetType(), record.GetId(), record.GetVersion(), data, modifiedAt)
|
||||
SET version=(SELECT version FROM t1), data=$3, modified_at=$4, index_cidr=$6
|
||||
RETURNING ` + schemaName + `.` + recordsTableName + `.version
|
||||
`
|
||||
} else {
|
||||
_, err = q.Exec(ctx, `
|
||||
DELETE FROM `+schemaName+`.`+recordsTableName+`
|
||||
WHERE type=$1 AND id=$2 AND version<$3
|
||||
`, record.GetType(), record.GetId(), record.GetVersion())
|
||||
query += `
|
||||
DELETE FROM ` + schemaName + `.` + recordsTableName + `
|
||||
WHERE type=$1 AND id=$2
|
||||
RETURNING ` + schemaName + `.` + recordsTableName + `.version
|
||||
`
|
||||
}
|
||||
err = q.QueryRow(ctx, query, record.GetType(), record.GetId(), data, modifiedAt, deletedAt, indexCIDR).Scan(&record.Version)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue