mirror of
https://github.com/pomerium/pomerium.git
synced 2025-05-20 12:37:16 +02:00
storage/postgres: implement patch operation (#4656)
Implement the new Patch() method for the Postgres storage backend.
This commit is contained in:
parent
4f648e9ac1
commit
4842002ed7
4 changed files with 97 additions and 4 deletions
|
@ -10,6 +10,7 @@ import (
|
||||||
|
|
||||||
"github.com/cenkalti/backoff/v4"
|
"github.com/cenkalti/backoff/v4"
|
||||||
"github.com/jackc/pgx/v5/pgxpool"
|
"github.com/jackc/pgx/v5/pgxpool"
|
||||||
|
"google.golang.org/protobuf/types/known/fieldmaskpb"
|
||||||
"google.golang.org/protobuf/types/known/timestamppb"
|
"google.golang.org/protobuf/types/known/timestamppb"
|
||||||
|
|
||||||
"github.com/pomerium/pomerium/internal/log"
|
"github.com/pomerium/pomerium/internal/log"
|
||||||
|
@ -140,7 +141,7 @@ func (backend *Backend) Get(
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return getRecord(ctx, conn, recordType, recordID)
|
return getRecord(ctx, conn, recordType, recordID, lockModeNone)
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetOptions returns the options for the given record type.
|
// GetOptions returns the options for the given record type.
|
||||||
|
@ -239,6 +240,42 @@ func (backend *Backend) Put(
|
||||||
return serverVersion, err
|
return serverVersion, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Patch updates specific fields of existing records in Postgres.
|
||||||
|
func (backend *Backend) Patch(
|
||||||
|
ctx context.Context,
|
||||||
|
records []*databroker.Record,
|
||||||
|
fields *fieldmaskpb.FieldMask,
|
||||||
|
) (uint64, []*databroker.Record, error) {
|
||||||
|
ctx, cancel := contextutil.Merge(ctx, backend.closeCtx)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
serverVersion, pool, err := backend.init(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return serverVersion, nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
patchedRecords := make([]*databroker.Record, 0, len(records))
|
||||||
|
|
||||||
|
now := timestamppb.Now()
|
||||||
|
|
||||||
|
for _, record := range records {
|
||||||
|
record = dup(record)
|
||||||
|
record.ModifiedAt = now
|
||||||
|
err := patchRecord(ctx, pool, record, fields)
|
||||||
|
if storage.IsNotFound(err) {
|
||||||
|
continue
|
||||||
|
} else if err != nil {
|
||||||
|
err = fmt.Errorf("storage/postgres: error patching record %q of type %q: %w",
|
||||||
|
record.GetId(), record.GetType(), err)
|
||||||
|
return serverVersion, patchedRecords, err
|
||||||
|
}
|
||||||
|
patchedRecords = append(patchedRecords, record)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = signalRecordChange(ctx, pool)
|
||||||
|
return serverVersion, patchedRecords, err
|
||||||
|
}
|
||||||
|
|
||||||
// SetOptions sets the options for the given record type.
|
// SetOptions sets the options for the given record type.
|
||||||
func (backend *Backend) SetOptions(
|
func (backend *Backend) SetOptions(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
|
|
|
@ -18,6 +18,7 @@ import (
|
||||||
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
||||||
"github.com/pomerium/pomerium/pkg/protoutil"
|
"github.com/pomerium/pomerium/pkg/protoutil"
|
||||||
"github.com/pomerium/pomerium/pkg/storage"
|
"github.com/pomerium/pomerium/pkg/storage"
|
||||||
|
"github.com/pomerium/pomerium/pkg/storage/storagetest"
|
||||||
)
|
)
|
||||||
|
|
||||||
const maxWait = time.Minute * 10
|
const maxWait = time.Minute * 10
|
||||||
|
@ -188,6 +189,10 @@ func TestBackend(t *testing.T) {
|
||||||
assert.Equal(t, []string{"capacity-test", "latest-test", "sync-test", "test-1", "unknown"}, types)
|
assert.Equal(t, []string{"capacity-test", "latest-test", "sync-test", "test-1", "unknown"}, types)
|
||||||
})
|
})
|
||||||
|
|
||||||
|
t.Run("patch", func(t *testing.T) {
|
||||||
|
storagetest.TestBackendPatch(t, ctx, backend)
|
||||||
|
})
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}))
|
}))
|
||||||
}
|
}
|
||||||
|
|
|
@ -12,10 +12,12 @@ import (
|
||||||
"github.com/jackc/pgx/v5"
|
"github.com/jackc/pgx/v5"
|
||||||
"github.com/jackc/pgx/v5/pgconn"
|
"github.com/jackc/pgx/v5/pgconn"
|
||||||
"github.com/jackc/pgx/v5/pgtype"
|
"github.com/jackc/pgx/v5/pgtype"
|
||||||
|
"github.com/jackc/pgx/v5/pgxpool"
|
||||||
"google.golang.org/protobuf/encoding/protojson"
|
"google.golang.org/protobuf/encoding/protojson"
|
||||||
"google.golang.org/protobuf/proto"
|
"google.golang.org/protobuf/proto"
|
||||||
"google.golang.org/protobuf/reflect/protoregistry"
|
"google.golang.org/protobuf/reflect/protoregistry"
|
||||||
"google.golang.org/protobuf/types/known/anypb"
|
"google.golang.org/protobuf/types/known/anypb"
|
||||||
|
"google.golang.org/protobuf/types/known/fieldmaskpb"
|
||||||
"google.golang.org/protobuf/types/known/timestamppb"
|
"google.golang.org/protobuf/types/known/timestamppb"
|
||||||
|
|
||||||
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
||||||
|
@ -160,15 +162,24 @@ func getOptions(ctx context.Context, q querier, recordType string) (*databroker.
|
||||||
return options, nil
|
return options, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func getRecord(ctx context.Context, q querier, recordType, recordID string) (*databroker.Record, error) {
|
type lockMode string
|
||||||
|
|
||||||
|
const (
|
||||||
|
lockModeNone lockMode = ""
|
||||||
|
lockModeUpdate lockMode = "FOR UPDATE"
|
||||||
|
)
|
||||||
|
|
||||||
|
func getRecord(
|
||||||
|
ctx context.Context, q querier, recordType, recordID string, lockMode lockMode,
|
||||||
|
) (*databroker.Record, error) {
|
||||||
var version uint64
|
var version uint64
|
||||||
var data []byte
|
var data []byte
|
||||||
var modifiedAt pgtype.Timestamptz
|
var modifiedAt pgtype.Timestamptz
|
||||||
err := q.QueryRow(ctx, `
|
err := q.QueryRow(ctx, `
|
||||||
SELECT version, data, modified_at
|
SELECT version, data, modified_at
|
||||||
FROM `+schemaName+`.`+recordsTableName+`
|
FROM `+schemaName+`.`+recordsTableName+`
|
||||||
WHERE type=$1 AND id=$2
|
WHERE type=$1 AND id=$2 `+string(lockMode),
|
||||||
`, recordType, recordID).Scan(&version, &data, &modifiedAt)
|
recordType, recordID).Scan(&version, &data, &modifiedAt)
|
||||||
if isNotFound(err) {
|
if isNotFound(err) {
|
||||||
return nil, storage.ErrNotFound
|
return nil, storage.ErrNotFound
|
||||||
} else if err != nil {
|
} else if err != nil {
|
||||||
|
@ -378,6 +389,34 @@ func putRecordAndChange(ctx context.Context, q querier, record *databroker.Recor
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// patchRecord updates specific fields of an existing record.
|
||||||
|
func patchRecord(
|
||||||
|
ctx context.Context, p *pgxpool.Pool, record *databroker.Record, fields *fieldmaskpb.FieldMask,
|
||||||
|
) error {
|
||||||
|
tx, err := p.Begin(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer func() { _ = tx.Rollback(ctx) }()
|
||||||
|
|
||||||
|
existing, err := getRecord(ctx, tx, record.GetType(), record.GetId(), lockModeUpdate)
|
||||||
|
if isNotFound(err) {
|
||||||
|
return storage.ErrNotFound
|
||||||
|
} else if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := storage.PatchRecord(existing, record, fields); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := putRecordAndChange(ctx, tx, record); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return tx.Commit(ctx)
|
||||||
|
}
|
||||||
|
|
||||||
func putService(ctx context.Context, q querier, svc *registry.Service, expiresAt time.Time) error {
|
func putService(ctx context.Context, q querier, svc *registry.Service, expiresAt time.Time) error {
|
||||||
query := `
|
query := `
|
||||||
INSERT INTO ` + schemaName + `.` + servicesTableName + ` (kind, endpoint, expires_at)
|
INSERT INTO ` + schemaName + `.` + servicesTableName + ` (kind, endpoint, expires_at)
|
||||||
|
|
|
@ -10,6 +10,7 @@ import (
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
"google.golang.org/protobuf/types/known/anypb"
|
"google.golang.org/protobuf/types/known/anypb"
|
||||||
"google.golang.org/protobuf/types/known/fieldmaskpb"
|
"google.golang.org/protobuf/types/known/fieldmaskpb"
|
||||||
|
"google.golang.org/protobuf/types/known/timestamppb"
|
||||||
|
|
||||||
"github.com/pomerium/pomerium/internal/testutil"
|
"github.com/pomerium/pomerium/internal/testutil"
|
||||||
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
||||||
|
@ -104,8 +105,19 @@ func TestBackendPatch(t *testing.T, ctx context.Context, backend BackendWithPatc
|
||||||
}`, updated[1].Data)
|
}`, updated[1].Data)
|
||||||
|
|
||||||
// Verify that the updates will indeed be seen by a subsequent Get().
|
// Verify that the updates will indeed be seen by a subsequent Get().
|
||||||
|
// Note: first truncate the modified_at timestamps to 1 µs precision, as
|
||||||
|
// that is the maximum precision supported by Postgres.
|
||||||
r1, _ := backend.Get(ctx, "type.googleapis.com/session.Session", "session-1")
|
r1, _ := backend.Get(ctx, "type.googleapis.com/session.Session", "session-1")
|
||||||
|
truncateTimestamps(updated[0].ModifiedAt, r1.ModifiedAt)
|
||||||
testutil.AssertProtoEqual(t, updated[0], r1)
|
testutil.AssertProtoEqual(t, updated[0], r1)
|
||||||
r3, _ := backend.Get(ctx, "type.googleapis.com/session.Session", "session-3")
|
r3, _ := backend.Get(ctx, "type.googleapis.com/session.Session", "session-3")
|
||||||
|
truncateTimestamps(updated[1].ModifiedAt, r3.ModifiedAt)
|
||||||
testutil.AssertProtoEqual(t, updated[1], r3)
|
testutil.AssertProtoEqual(t, updated[1], r3)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// truncateTimestamps truncates Timestamp messages to 1 µs precision.
|
||||||
|
func truncateTimestamps(ts ...*timestamppb.Timestamp) {
|
||||||
|
for _, t := range ts {
|
||||||
|
t.Nanos = (t.Nanos / 1000) * 1000
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue