1
0
Fork 0
mirror of https://github.com/pomerium/pomerium.git synced 2025-05-15 10:07:47 +02:00

storage/postgres: implement patch operation ()

Implement the new Patch() method for the Postgres storage backend.
This commit is contained in:
Kenneth Jenkins 2023-11-02 12:07:36 -07:00 committed by GitHub
parent 4f648e9ac1
commit 4842002ed7
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 97 additions and 4 deletions
pkg/storage/postgres

View file

@ -12,10 +12,12 @@ import (
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgconn"
"github.com/jackc/pgx/v5/pgtype"
"github.com/jackc/pgx/v5/pgxpool"
"google.golang.org/protobuf/encoding/protojson"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/reflect/protoregistry"
"google.golang.org/protobuf/types/known/anypb"
"google.golang.org/protobuf/types/known/fieldmaskpb"
"google.golang.org/protobuf/types/known/timestamppb"
"github.com/pomerium/pomerium/pkg/grpc/databroker"
@ -160,15 +162,24 @@ func getOptions(ctx context.Context, q querier, recordType string) (*databroker.
return options, nil
}
func getRecord(ctx context.Context, q querier, recordType, recordID string) (*databroker.Record, error) {
type lockMode string
const (
lockModeNone lockMode = ""
lockModeUpdate lockMode = "FOR UPDATE"
)
func getRecord(
ctx context.Context, q querier, recordType, recordID string, lockMode lockMode,
) (*databroker.Record, error) {
var version uint64
var data []byte
var modifiedAt pgtype.Timestamptz
err := q.QueryRow(ctx, `
SELECT version, data, modified_at
FROM `+schemaName+`.`+recordsTableName+`
WHERE type=$1 AND id=$2
`, recordType, recordID).Scan(&version, &data, &modifiedAt)
WHERE type=$1 AND id=$2 `+string(lockMode),
recordType, recordID).Scan(&version, &data, &modifiedAt)
if isNotFound(err) {
return nil, storage.ErrNotFound
} else if err != nil {
@ -378,6 +389,34 @@ func putRecordAndChange(ctx context.Context, q querier, record *databroker.Recor
return nil
}
// patchRecord updates specific fields of an existing record.
func patchRecord(
ctx context.Context, p *pgxpool.Pool, record *databroker.Record, fields *fieldmaskpb.FieldMask,
) error {
tx, err := p.Begin(ctx)
if err != nil {
return err
}
defer func() { _ = tx.Rollback(ctx) }()
existing, err := getRecord(ctx, tx, record.GetType(), record.GetId(), lockModeUpdate)
if isNotFound(err) {
return storage.ErrNotFound
} else if err != nil {
return err
}
if err := storage.PatchRecord(existing, record, fields); err != nil {
return err
}
if err := putRecordAndChange(ctx, tx, record); err != nil {
return err
}
return tx.Commit(ctx)
}
func putService(ctx context.Context, q querier, svc *registry.Service, expiresAt time.Time) error {
query := `
INSERT INTO ` + schemaName + `.` + servicesTableName + ` (kind, endpoint, expires_at)