storage/postgres: implement patch operation (#4656)

Implement the new Patch() method for the Postgres storage backend.
This commit is contained in:
Kenneth Jenkins 2023-11-02 12:07:36 -07:00 committed by GitHub
parent 4f648e9ac1
commit 4842002ed7
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 97 additions and 4 deletions

View file

@ -10,6 +10,7 @@ import (
"github.com/cenkalti/backoff/v4" "github.com/cenkalti/backoff/v4"
"github.com/jackc/pgx/v5/pgxpool" "github.com/jackc/pgx/v5/pgxpool"
"google.golang.org/protobuf/types/known/fieldmaskpb"
"google.golang.org/protobuf/types/known/timestamppb" "google.golang.org/protobuf/types/known/timestamppb"
"github.com/pomerium/pomerium/internal/log" "github.com/pomerium/pomerium/internal/log"
@ -140,7 +141,7 @@ func (backend *Backend) Get(
return nil, err return nil, err
} }
return getRecord(ctx, conn, recordType, recordID) return getRecord(ctx, conn, recordType, recordID, lockModeNone)
} }
// GetOptions returns the options for the given record type. // GetOptions returns the options for the given record type.
@ -239,6 +240,42 @@ func (backend *Backend) Put(
return serverVersion, err return serverVersion, err
} }
// Patch updates specific fields of existing records in Postgres.
func (backend *Backend) Patch(
ctx context.Context,
records []*databroker.Record,
fields *fieldmaskpb.FieldMask,
) (uint64, []*databroker.Record, error) {
ctx, cancel := contextutil.Merge(ctx, backend.closeCtx)
defer cancel()
serverVersion, pool, err := backend.init(ctx)
if err != nil {
return serverVersion, nil, err
}
patchedRecords := make([]*databroker.Record, 0, len(records))
now := timestamppb.Now()
for _, record := range records {
record = dup(record)
record.ModifiedAt = now
err := patchRecord(ctx, pool, record, fields)
if storage.IsNotFound(err) {
continue
} else if err != nil {
err = fmt.Errorf("storage/postgres: error patching record %q of type %q: %w",
record.GetId(), record.GetType(), err)
return serverVersion, patchedRecords, err
}
patchedRecords = append(patchedRecords, record)
}
err = signalRecordChange(ctx, pool)
return serverVersion, patchedRecords, err
}
// SetOptions sets the options for the given record type. // SetOptions sets the options for the given record type.
func (backend *Backend) SetOptions( func (backend *Backend) SetOptions(
ctx context.Context, ctx context.Context,

View file

@ -18,6 +18,7 @@ import (
"github.com/pomerium/pomerium/pkg/grpc/databroker" "github.com/pomerium/pomerium/pkg/grpc/databroker"
"github.com/pomerium/pomerium/pkg/protoutil" "github.com/pomerium/pomerium/pkg/protoutil"
"github.com/pomerium/pomerium/pkg/storage" "github.com/pomerium/pomerium/pkg/storage"
"github.com/pomerium/pomerium/pkg/storage/storagetest"
) )
const maxWait = time.Minute * 10 const maxWait = time.Minute * 10
@ -188,6 +189,10 @@ func TestBackend(t *testing.T) {
assert.Equal(t, []string{"capacity-test", "latest-test", "sync-test", "test-1", "unknown"}, types) assert.Equal(t, []string{"capacity-test", "latest-test", "sync-test", "test-1", "unknown"}, types)
}) })
t.Run("patch", func(t *testing.T) {
storagetest.TestBackendPatch(t, ctx, backend)
})
return nil return nil
})) }))
} }

View file

@ -12,10 +12,12 @@ import (
"github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgconn" "github.com/jackc/pgx/v5/pgconn"
"github.com/jackc/pgx/v5/pgtype" "github.com/jackc/pgx/v5/pgtype"
"github.com/jackc/pgx/v5/pgxpool"
"google.golang.org/protobuf/encoding/protojson" "google.golang.org/protobuf/encoding/protojson"
"google.golang.org/protobuf/proto" "google.golang.org/protobuf/proto"
"google.golang.org/protobuf/reflect/protoregistry" "google.golang.org/protobuf/reflect/protoregistry"
"google.golang.org/protobuf/types/known/anypb" "google.golang.org/protobuf/types/known/anypb"
"google.golang.org/protobuf/types/known/fieldmaskpb"
"google.golang.org/protobuf/types/known/timestamppb" "google.golang.org/protobuf/types/known/timestamppb"
"github.com/pomerium/pomerium/pkg/grpc/databroker" "github.com/pomerium/pomerium/pkg/grpc/databroker"
@ -160,15 +162,24 @@ func getOptions(ctx context.Context, q querier, recordType string) (*databroker.
return options, nil return options, nil
} }
func getRecord(ctx context.Context, q querier, recordType, recordID string) (*databroker.Record, error) { type lockMode string
const (
lockModeNone lockMode = ""
lockModeUpdate lockMode = "FOR UPDATE"
)
func getRecord(
ctx context.Context, q querier, recordType, recordID string, lockMode lockMode,
) (*databroker.Record, error) {
var version uint64 var version uint64
var data []byte var data []byte
var modifiedAt pgtype.Timestamptz var modifiedAt pgtype.Timestamptz
err := q.QueryRow(ctx, ` err := q.QueryRow(ctx, `
SELECT version, data, modified_at SELECT version, data, modified_at
FROM `+schemaName+`.`+recordsTableName+` FROM `+schemaName+`.`+recordsTableName+`
WHERE type=$1 AND id=$2 WHERE type=$1 AND id=$2 `+string(lockMode),
`, recordType, recordID).Scan(&version, &data, &modifiedAt) recordType, recordID).Scan(&version, &data, &modifiedAt)
if isNotFound(err) { if isNotFound(err) {
return nil, storage.ErrNotFound return nil, storage.ErrNotFound
} else if err != nil { } else if err != nil {
@ -378,6 +389,34 @@ func putRecordAndChange(ctx context.Context, q querier, record *databroker.Recor
return nil return nil
} }
// patchRecord updates specific fields of an existing record.
func patchRecord(
ctx context.Context, p *pgxpool.Pool, record *databroker.Record, fields *fieldmaskpb.FieldMask,
) error {
tx, err := p.Begin(ctx)
if err != nil {
return err
}
defer func() { _ = tx.Rollback(ctx) }()
existing, err := getRecord(ctx, tx, record.GetType(), record.GetId(), lockModeUpdate)
if isNotFound(err) {
return storage.ErrNotFound
} else if err != nil {
return err
}
if err := storage.PatchRecord(existing, record, fields); err != nil {
return err
}
if err := putRecordAndChange(ctx, tx, record); err != nil {
return err
}
return tx.Commit(ctx)
}
func putService(ctx context.Context, q querier, svc *registry.Service, expiresAt time.Time) error { func putService(ctx context.Context, q querier, svc *registry.Service, expiresAt time.Time) error {
query := ` query := `
INSERT INTO ` + schemaName + `.` + servicesTableName + ` (kind, endpoint, expires_at) INSERT INTO ` + schemaName + `.` + servicesTableName + ` (kind, endpoint, expires_at)

View file

@ -10,6 +10,7 @@ import (
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"google.golang.org/protobuf/types/known/anypb" "google.golang.org/protobuf/types/known/anypb"
"google.golang.org/protobuf/types/known/fieldmaskpb" "google.golang.org/protobuf/types/known/fieldmaskpb"
"google.golang.org/protobuf/types/known/timestamppb"
"github.com/pomerium/pomerium/internal/testutil" "github.com/pomerium/pomerium/internal/testutil"
"github.com/pomerium/pomerium/pkg/grpc/databroker" "github.com/pomerium/pomerium/pkg/grpc/databroker"
@ -104,8 +105,19 @@ func TestBackendPatch(t *testing.T, ctx context.Context, backend BackendWithPatc
}`, updated[1].Data) }`, updated[1].Data)
// Verify that the updates will indeed be seen by a subsequent Get(). // Verify that the updates will indeed be seen by a subsequent Get().
// Note: first truncate the modified_at timestamps to 1 µs precision, as
// that is the maximum precision supported by Postgres.
r1, _ := backend.Get(ctx, "type.googleapis.com/session.Session", "session-1") r1, _ := backend.Get(ctx, "type.googleapis.com/session.Session", "session-1")
truncateTimestamps(updated[0].ModifiedAt, r1.ModifiedAt)
testutil.AssertProtoEqual(t, updated[0], r1) testutil.AssertProtoEqual(t, updated[0], r1)
r3, _ := backend.Get(ctx, "type.googleapis.com/session.Session", "session-3") r3, _ := backend.Get(ctx, "type.googleapis.com/session.Session", "session-3")
truncateTimestamps(updated[1].ModifiedAt, r3.ModifiedAt)
testutil.AssertProtoEqual(t, updated[1], r3) testutil.AssertProtoEqual(t, updated[1], r3)
} }
// truncateTimestamps truncates Timestamp messages to 1 µs precision.
func truncateTimestamps(ts ...*timestamppb.Timestamp) {
for _, t := range ts {
t.Nanos = (t.Nanos / 1000) * 1000
}
}