diff --git a/pkg/storage/postgres/backend.go b/pkg/storage/postgres/backend.go index 306142460..ccf44a0b7 100644 --- a/pkg/storage/postgres/backend.go +++ b/pkg/storage/postgres/backend.go @@ -10,6 +10,7 @@ import ( "github.com/cenkalti/backoff/v4" "github.com/jackc/pgx/v5/pgxpool" + "google.golang.org/protobuf/types/known/fieldmaskpb" "google.golang.org/protobuf/types/known/timestamppb" "github.com/pomerium/pomerium/internal/log" @@ -140,7 +141,7 @@ func (backend *Backend) Get( return nil, err } - return getRecord(ctx, conn, recordType, recordID) + return getRecord(ctx, conn, recordType, recordID, lockModeNone) } // GetOptions returns the options for the given record type. @@ -239,6 +240,42 @@ func (backend *Backend) Put( return serverVersion, err } +// Patch updates specific fields of existing records in Postgres. +func (backend *Backend) Patch( + ctx context.Context, + records []*databroker.Record, + fields *fieldmaskpb.FieldMask, +) (uint64, []*databroker.Record, error) { + ctx, cancel := contextutil.Merge(ctx, backend.closeCtx) + defer cancel() + + serverVersion, pool, err := backend.init(ctx) + if err != nil { + return serverVersion, nil, err + } + + patchedRecords := make([]*databroker.Record, 0, len(records)) + + now := timestamppb.Now() + + for _, record := range records { + record = dup(record) + record.ModifiedAt = now + err := patchRecord(ctx, pool, record, fields) + if storage.IsNotFound(err) { + continue + } else if err != nil { + err = fmt.Errorf("storage/postgres: error patching record %q of type %q: %w", + record.GetId(), record.GetType(), err) + return serverVersion, patchedRecords, err + } + patchedRecords = append(patchedRecords, record) + } + + err = signalRecordChange(ctx, pool) + return serverVersion, patchedRecords, err +} + // SetOptions sets the options for the given record type. func (backend *Backend) SetOptions( ctx context.Context, diff --git a/pkg/storage/postgres/backend_test.go b/pkg/storage/postgres/backend_test.go index 6d82b4bbc..8964c5732 100644 --- a/pkg/storage/postgres/backend_test.go +++ b/pkg/storage/postgres/backend_test.go @@ -18,6 +18,7 @@ import ( "github.com/pomerium/pomerium/pkg/grpc/databroker" "github.com/pomerium/pomerium/pkg/protoutil" "github.com/pomerium/pomerium/pkg/storage" + "github.com/pomerium/pomerium/pkg/storage/storagetest" ) const maxWait = time.Minute * 10 @@ -188,6 +189,10 @@ func TestBackend(t *testing.T) { assert.Equal(t, []string{"capacity-test", "latest-test", "sync-test", "test-1", "unknown"}, types) }) + t.Run("patch", func(t *testing.T) { + storagetest.TestBackendPatch(t, ctx, backend) + }) + return nil })) } diff --git a/pkg/storage/postgres/postgres.go b/pkg/storage/postgres/postgres.go index 36efb771d..69a10fd8b 100644 --- a/pkg/storage/postgres/postgres.go +++ b/pkg/storage/postgres/postgres.go @@ -12,10 +12,12 @@ import ( "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgconn" "github.com/jackc/pgx/v5/pgtype" + "github.com/jackc/pgx/v5/pgxpool" "google.golang.org/protobuf/encoding/protojson" "google.golang.org/protobuf/proto" "google.golang.org/protobuf/reflect/protoregistry" "google.golang.org/protobuf/types/known/anypb" + "google.golang.org/protobuf/types/known/fieldmaskpb" "google.golang.org/protobuf/types/known/timestamppb" "github.com/pomerium/pomerium/pkg/grpc/databroker" @@ -160,15 +162,24 @@ func getOptions(ctx context.Context, q querier, recordType string) (*databroker. return options, nil } -func getRecord(ctx context.Context, q querier, recordType, recordID string) (*databroker.Record, error) { +type lockMode string + +const ( + lockModeNone lockMode = "" + lockModeUpdate lockMode = "FOR UPDATE" +) + +func getRecord( + ctx context.Context, q querier, recordType, recordID string, lockMode lockMode, +) (*databroker.Record, error) { var version uint64 var data []byte var modifiedAt pgtype.Timestamptz err := q.QueryRow(ctx, ` SELECT version, data, modified_at FROM `+schemaName+`.`+recordsTableName+` - WHERE type=$1 AND id=$2 - `, recordType, recordID).Scan(&version, &data, &modifiedAt) + WHERE type=$1 AND id=$2 `+string(lockMode), + recordType, recordID).Scan(&version, &data, &modifiedAt) if isNotFound(err) { return nil, storage.ErrNotFound } else if err != nil { @@ -378,6 +389,34 @@ func putRecordAndChange(ctx context.Context, q querier, record *databroker.Recor return nil } +// patchRecord updates specific fields of an existing record. +func patchRecord( + ctx context.Context, p *pgxpool.Pool, record *databroker.Record, fields *fieldmaskpb.FieldMask, +) error { + tx, err := p.Begin(ctx) + if err != nil { + return err + } + defer func() { _ = tx.Rollback(ctx) }() + + existing, err := getRecord(ctx, tx, record.GetType(), record.GetId(), lockModeUpdate) + if isNotFound(err) { + return storage.ErrNotFound + } else if err != nil { + return err + } + + if err := storage.PatchRecord(existing, record, fields); err != nil { + return err + } + + if err := putRecordAndChange(ctx, tx, record); err != nil { + return err + } + + return tx.Commit(ctx) +} + func putService(ctx context.Context, q querier, svc *registry.Service, expiresAt time.Time) error { query := ` INSERT INTO ` + schemaName + `.` + servicesTableName + ` (kind, endpoint, expires_at) diff --git a/pkg/storage/storagetest/storagetest.go b/pkg/storage/storagetest/storagetest.go index b8a290550..e34e4f00b 100644 --- a/pkg/storage/storagetest/storagetest.go +++ b/pkg/storage/storagetest/storagetest.go @@ -10,6 +10,7 @@ import ( "github.com/stretchr/testify/require" "google.golang.org/protobuf/types/known/anypb" "google.golang.org/protobuf/types/known/fieldmaskpb" + "google.golang.org/protobuf/types/known/timestamppb" "github.com/pomerium/pomerium/internal/testutil" "github.com/pomerium/pomerium/pkg/grpc/databroker" @@ -104,8 +105,19 @@ func TestBackendPatch(t *testing.T, ctx context.Context, backend BackendWithPatc }`, updated[1].Data) // Verify that the updates will indeed be seen by a subsequent Get(). + // Note: first truncate the modified_at timestamps to 1 µs precision, as + // that is the maximum precision supported by Postgres. r1, _ := backend.Get(ctx, "type.googleapis.com/session.Session", "session-1") + truncateTimestamps(updated[0].ModifiedAt, r1.ModifiedAt) testutil.AssertProtoEqual(t, updated[0], r1) r3, _ := backend.Get(ctx, "type.googleapis.com/session.Session", "session-3") + truncateTimestamps(updated[1].ModifiedAt, r3.ModifiedAt) testutil.AssertProtoEqual(t, updated[1], r3) } + +// truncateTimestamps truncates Timestamp messages to 1 µs precision. +func truncateTimestamps(ts ...*timestamppb.Timestamp) { + for _, t := range ts { + t.Nanos = (t.Nanos / 1000) * 1000 + } +}