mirror of
https://github.com/pomerium/pomerium.git
synced 2025-05-19 12:07:18 +02:00
postgres: use CTE and GENERATED version number instead of serialized transaction (#3408)
* postgres: use CTE and GENERATED version number instead of serialized transaction * update server version * fix indexing CIDRs
This commit is contained in:
parent
a7bd284b52
commit
a2d5d8062b
5 changed files with 81 additions and 82 deletions
|
@ -2,7 +2,6 @@ package databroker
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
backoff "github.com/cenkalti/backoff/v4"
|
backoff "github.com/cenkalti/backoff/v4"
|
||||||
|
@ -174,12 +173,6 @@ func (syncer *Syncer) sync(ctx context.Context) error {
|
||||||
rec := res.GetRecord()
|
rec := res.GetRecord()
|
||||||
log.Debug(logCtxRec(ctx, rec)).Msg("syncer got record")
|
log.Debug(logCtxRec(ctx, rec)).Msg("syncer got record")
|
||||||
|
|
||||||
if syncer.recordVersion != res.GetRecord().GetVersion()-1 {
|
|
||||||
log.Error(logCtxRec(ctx, rec)).Err(err).
|
|
||||||
Msg("aborted sync due to missing record")
|
|
||||||
syncer.serverVersion = 0
|
|
||||||
return fmt.Errorf("missing record version")
|
|
||||||
}
|
|
||||||
syncer.recordVersion = res.GetRecord().GetVersion()
|
syncer.recordVersion = res.GetRecord().GetVersion()
|
||||||
if syncer.cfg.typeURL == "" || syncer.cfg.typeURL == res.GetRecord().GetType() {
|
if syncer.cfg.typeURL == "" || syncer.cfg.typeURL == res.GetRecord().GetType() {
|
||||||
ctx := logCtxRec(ctx, rec)
|
ctx := logCtxRec(ctx, rec)
|
||||||
|
|
|
@ -205,15 +205,9 @@ func TestSyncer(t *testing.T) {
|
||||||
|
|
||||||
select {
|
select {
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
t.Fatal("6. expected call to clear records due to skipped version")
|
t.Fatal("6. expected call to update records")
|
||||||
case <-clearCh:
|
|
||||||
}
|
|
||||||
|
|
||||||
select {
|
|
||||||
case <-ctx.Done():
|
|
||||||
t.Fatal("7. expected call to update records")
|
|
||||||
case records := <-updateCh:
|
case records := <-updateCh:
|
||||||
testutil.AssertProtoJSONEqual(t, `[{"id": "r3", "version": "1002"}, {"id": "r5", "version": "1004"}]`, records)
|
testutil.AssertProtoJSONEqual(t, `[{"id": "r5", "version": "1004"}]`, records)
|
||||||
}
|
}
|
||||||
|
|
||||||
assert.NoError(t, syncer.Close())
|
assert.NoError(t, syncer.Close())
|
||||||
|
|
|
@ -158,53 +158,32 @@ func (backend *Backend) Put(
|
||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
|
|
||||||
err = pool.BeginTxFunc(ctx, pgx.TxOptions{
|
now := timestamppb.Now()
|
||||||
IsoLevel: pgx.Serializable,
|
|
||||||
AccessMode: pgx.ReadWrite,
|
|
||||||
}, func(tx pgx.Tx) error {
|
|
||||||
now := timestamppb.Now()
|
|
||||||
|
|
||||||
recordVersion, err := getLatestRecordVersion(ctx, tx)
|
// add all the records
|
||||||
|
recordTypes := map[string]struct{}{}
|
||||||
|
for i, record := range records {
|
||||||
|
recordTypes[record.GetType()] = struct{}{}
|
||||||
|
|
||||||
|
record = dup(record)
|
||||||
|
record.ModifiedAt = now
|
||||||
|
err := putRecordAndChange(ctx, pool, record)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("storage/postgres: error getting latest record version: %w", err)
|
return serverVersion, fmt.Errorf("storage/postgres: error saving record: %w", err)
|
||||||
}
|
}
|
||||||
|
records[i] = record
|
||||||
|
}
|
||||||
|
|
||||||
// add all the records
|
// enforce options for each record type
|
||||||
recordTypes := map[string]struct{}{}
|
for recordType := range recordTypes {
|
||||||
for i, record := range records {
|
options, err := getOptions(ctx, pool, recordType)
|
||||||
recordTypes[record.GetType()] = struct{}{}
|
if err != nil {
|
||||||
|
return serverVersion, fmt.Errorf("storage/postgres: error getting options: %w", err)
|
||||||
record = dup(record)
|
|
||||||
record.ModifiedAt = now
|
|
||||||
record.Version = recordVersion + uint64(i) + 1
|
|
||||||
err := putRecordChange(ctx, tx, record)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("storage/postgres: error saving record change: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
err = putRecord(ctx, tx, record)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("storage/postgres: error saving record: %w", err)
|
|
||||||
}
|
|
||||||
records[i] = record
|
|
||||||
}
|
}
|
||||||
|
err = enforceOptions(ctx, pool, recordType, options)
|
||||||
// enforce options for each record type
|
if err != nil {
|
||||||
for recordType := range recordTypes {
|
return serverVersion, fmt.Errorf("storage/postgres: error enforcing options: %w", err)
|
||||||
options, err := getOptions(ctx, tx, recordType)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("storage/postgres: error getting options: %w", err)
|
|
||||||
}
|
|
||||||
err = enforceOptions(ctx, tx, recordType, options)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("storage/postgres: error enforcing options: %w", err)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
return serverVersion, err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
err = signalRecordChange(ctx, pool)
|
err = signalRecordChange(ctx, pool)
|
||||||
|
|
|
@ -77,6 +77,41 @@ var migrations = []func(context.Context, pgx.Tx) error{
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
2: func(ctx context.Context, tx pgx.Tx) error {
|
||||||
|
serverVersion := uint64(cryptutil.NewRandomUInt32())
|
||||||
|
_, err := tx.Exec(ctx, `
|
||||||
|
UPDATE `+schemaName+`.`+migrationInfoTableName+`
|
||||||
|
SET server_version = $1
|
||||||
|
`, serverVersion)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = tx.Exec(ctx, `
|
||||||
|
DELETE FROM `+schemaName+`.`+recordChangesTableName+`
|
||||||
|
`)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = tx.Exec(ctx, `
|
||||||
|
DELETE FROM `+schemaName+`.`+recordsTableName+`
|
||||||
|
`)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = tx.Exec(ctx, `
|
||||||
|
ALTER TABLE `+schemaName+`.`+recordChangesTableName+`
|
||||||
|
ALTER COLUMN version
|
||||||
|
ADD GENERATED BY DEFAULT AS IDENTITY
|
||||||
|
`)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
|
@ -225,7 +225,7 @@ func maybeAcquireLease(ctx context.Context, q querier, leaseName, leaseID string
|
||||||
return leaseHolderID, err
|
return leaseHolderID, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func putRecordChange(ctx context.Context, q querier, record *databroker.Record) error {
|
func putRecordAndChange(ctx context.Context, q querier, record *databroker.Record) error {
|
||||||
data, err := jsonbFromAny(record.GetData())
|
data, err := jsonbFromAny(record.GetData())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
|
@ -233,40 +233,38 @@ func putRecordChange(ctx context.Context, q querier, record *databroker.Record)
|
||||||
|
|
||||||
modifiedAt := timestamptzFromTimestamppb(record.GetModifiedAt())
|
modifiedAt := timestamptzFromTimestamppb(record.GetModifiedAt())
|
||||||
deletedAt := timestamptzFromTimestamppb(record.GetDeletedAt())
|
deletedAt := timestamptzFromTimestamppb(record.GetDeletedAt())
|
||||||
_, err = q.Exec(ctx, `
|
indexCIDR := &pgtype.Text{Status: pgtype.Null}
|
||||||
INSERT INTO `+schemaName+`.`+recordChangesTableName+` (type, id, version, data, modified_at, deleted_at)
|
if cidr := storage.GetRecordIndexCIDR(record.GetData()); cidr != nil {
|
||||||
VALUES ($1, $2, $3, $4, $5, $6)
|
_ = indexCIDR.Set(cidr.String())
|
||||||
`, record.GetType(), record.GetId(), record.GetVersion(), data, modifiedAt, deletedAt)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
query := `
|
||||||
}
|
WITH t1 AS (
|
||||||
|
INSERT INTO ` + schemaName + `.` + recordChangesTableName + ` (type, id, data, modified_at, deleted_at)
|
||||||
func putRecord(ctx context.Context, q querier, record *databroker.Record) error {
|
|
||||||
data, err := jsonbFromAny(record.GetData())
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
modifiedAt := timestamptzFromTimestamppb(record.GetModifiedAt())
|
|
||||||
if record.GetDeletedAt() == nil {
|
|
||||||
_, err = q.Exec(ctx, `
|
|
||||||
INSERT INTO `+schemaName+`.`+recordsTableName+` (type, id, version, data, modified_at)
|
|
||||||
VALUES ($1, $2, $3, $4, $5)
|
VALUES ($1, $2, $3, $4, $5)
|
||||||
|
RETURNING *
|
||||||
|
)
|
||||||
|
`
|
||||||
|
if record.GetDeletedAt() == nil {
|
||||||
|
query += `
|
||||||
|
INSERT INTO ` + schemaName + `.` + recordsTableName + ` (type, id, version, data, modified_at, index_cidr)
|
||||||
|
VALUES ($1, $2, (SELECT version FROM t1), $3, $4, $6)
|
||||||
ON CONFLICT (type, id) DO UPDATE
|
ON CONFLICT (type, id) DO UPDATE
|
||||||
SET version=$3, data=$4, modified_at=$5
|
SET version=(SELECT version FROM t1), data=$3, modified_at=$4, index_cidr=$6
|
||||||
`, record.GetType(), record.GetId(), record.GetVersion(), data, modifiedAt)
|
RETURNING ` + schemaName + `.` + recordsTableName + `.version
|
||||||
|
`
|
||||||
} else {
|
} else {
|
||||||
_, err = q.Exec(ctx, `
|
query += `
|
||||||
DELETE FROM `+schemaName+`.`+recordsTableName+`
|
DELETE FROM ` + schemaName + `.` + recordsTableName + `
|
||||||
WHERE type=$1 AND id=$2 AND version<$3
|
WHERE type=$1 AND id=$2
|
||||||
`, record.GetType(), record.GetId(), record.GetVersion())
|
RETURNING ` + schemaName + `.` + recordsTableName + `.version
|
||||||
|
`
|
||||||
}
|
}
|
||||||
|
err = q.QueryRow(ctx, query, record.GetType(), record.GetId(), data, modifiedAt, deletedAt, indexCIDR).Scan(&record.Version)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue