mirror of
https://github.com/pomerium/pomerium.git
synced 2025-05-15 01:57:45 +02:00
postgres: upgrade to pgx v5 (#3826)
This commit is contained in:
parent
f99ea7c8ad
commit
c048af7523
7 changed files with 47 additions and 143 deletions
pkg/storage/postgres
|
@ -8,9 +8,9 @@ import (
|
|||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/jackc/pgconn"
|
||||
"github.com/jackc/pgtype"
|
||||
"github.com/jackc/pgx/v4"
|
||||
"github.com/jackc/pgx/v5"
|
||||
"github.com/jackc/pgx/v5/pgconn"
|
||||
"github.com/jackc/pgx/v5/pgtype"
|
||||
"google.golang.org/protobuf/encoding/protojson"
|
||||
"google.golang.org/protobuf/proto"
|
||||
"google.golang.org/protobuf/reflect/protoregistry"
|
||||
|
@ -100,7 +100,7 @@ func getNextChangedRecord(ctx context.Context, q querier, recordType string, aft
|
|||
for {
|
||||
var recordID string
|
||||
var version uint64
|
||||
var data pgtype.JSONB
|
||||
var data []byte
|
||||
var modifiedAt pgtype.Timestamptz
|
||||
var deletedAt pgtype.Timestamptz
|
||||
query := `
|
||||
|
@ -125,7 +125,7 @@ func getNextChangedRecord(ctx context.Context, q querier, recordType string, aft
|
|||
}
|
||||
afterRecordVersion = version
|
||||
|
||||
any, err := protoutil.UnmarshalAnyJSON(data.Bytes)
|
||||
any, err := protoutil.UnmarshalAnyJSON(data)
|
||||
if isUnknownType(err) {
|
||||
// ignore
|
||||
continue
|
||||
|
@ -155,15 +155,15 @@ func getOptions(ctx context.Context, q querier, recordType string) (*databroker.
|
|||
return nil, err
|
||||
}
|
||||
options := new(databroker.Options)
|
||||
if capacity.Status == pgtype.Present {
|
||||
options.Capacity = proto.Uint64(uint64(capacity.Int))
|
||||
if capacity.Valid {
|
||||
options.Capacity = proto.Uint64(uint64(capacity.Int64))
|
||||
}
|
||||
return options, nil
|
||||
}
|
||||
|
||||
func getRecord(ctx context.Context, q querier, recordType, recordID string) (*databroker.Record, error) {
|
||||
var version uint64
|
||||
var data pgtype.JSONB
|
||||
var data []byte
|
||||
var modifiedAt pgtype.Timestamptz
|
||||
err := q.QueryRow(ctx, `
|
||||
SELECT version, data, modified_at
|
||||
|
@ -176,7 +176,7 @@ func getRecord(ctx context.Context, q querier, recordType, recordID string) (*da
|
|||
return nil, fmt.Errorf("postgres: failed to execute query: %w", err)
|
||||
}
|
||||
|
||||
any, err := protoutil.UnmarshalAnyJSON(data.Bytes)
|
||||
any, err := protoutil.UnmarshalAnyJSON(data)
|
||||
if isUnknownType(err) {
|
||||
return nil, storage.ErrNotFound
|
||||
} else if err != nil {
|
||||
|
@ -220,14 +220,14 @@ func listRecords(ctx context.Context, q querier, expr storage.FilterExpression,
|
|||
for rows.Next() {
|
||||
var recordType, id string
|
||||
var version uint64
|
||||
var data pgtype.JSONB
|
||||
var data []byte
|
||||
var modifiedAt pgtype.Timestamptz
|
||||
err = rows.Scan(&recordType, &id, &version, &data, &modifiedAt)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("postgres: failed to scan row: %w", err)
|
||||
}
|
||||
|
||||
any, err := protoutil.UnmarshalAnyJSON(data.Bytes)
|
||||
any, err := protoutil.UnmarshalAnyJSON(data)
|
||||
if isUnknownType(err) {
|
||||
// ignore records with an unknown type
|
||||
continue
|
||||
|
@ -308,9 +308,10 @@ func putRecordAndChange(ctx context.Context, q querier, record *databroker.Recor
|
|||
|
||||
modifiedAt := timestamptzFromTimestamppb(record.GetModifiedAt())
|
||||
deletedAt := timestamptzFromTimestamppb(record.GetDeletedAt())
|
||||
indexCIDR := &pgtype.Text{Status: pgtype.Null}
|
||||
indexCIDR := &pgtype.Text{Valid: false}
|
||||
if cidr := storage.GetRecordIndexCIDR(record.GetData()); cidr != nil {
|
||||
_ = indexCIDR.Set(cidr.String())
|
||||
indexCIDR.String = cidr.String()
|
||||
indexCIDR.Valid = true
|
||||
}
|
||||
|
||||
query := `
|
||||
|
@ -359,10 +360,10 @@ func putService(ctx context.Context, q querier, svc *registry.Service, expiresAt
|
|||
}
|
||||
|
||||
func setOptions(ctx context.Context, q querier, recordType string, options *databroker.Options) error {
|
||||
capacity := pgtype.Int8{Status: pgtype.Null}
|
||||
capacity := pgtype.Int8{}
|
||||
if options != nil && options.Capacity != nil {
|
||||
capacity.Int = int64(options.GetCapacity())
|
||||
capacity.Status = pgtype.Present
|
||||
capacity.Int64 = int64(options.GetCapacity())
|
||||
capacity.Valid = true
|
||||
}
|
||||
|
||||
_, err := q.Exec(ctx, `
|
||||
|
@ -384,21 +385,16 @@ func signalServiceChange(ctx context.Context, q querier) error {
|
|||
return err
|
||||
}
|
||||
|
||||
func jsonbFromAny(any *anypb.Any) (pgtype.JSONB, error) {
|
||||
func jsonbFromAny(any *anypb.Any) ([]byte, error) {
|
||||
if any == nil {
|
||||
return pgtype.JSONB{Status: pgtype.Null}, nil
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
bs, err := protojson.Marshal(any)
|
||||
if err != nil {
|
||||
return pgtype.JSONB{Status: pgtype.Null}, err
|
||||
}
|
||||
|
||||
return pgtype.JSONB{Bytes: bs, Status: pgtype.Present}, nil
|
||||
return protojson.Marshal(any)
|
||||
}
|
||||
|
||||
func timestamppbFromTimestamptz(ts pgtype.Timestamptz) *timestamppb.Timestamp {
|
||||
if ts.Status != pgtype.Present {
|
||||
if !ts.Valid {
|
||||
return nil
|
||||
}
|
||||
return timestamppb.New(ts.Time)
|
||||
|
@ -406,9 +402,9 @@ func timestamppbFromTimestamptz(ts pgtype.Timestamptz) *timestamppb.Timestamp {
|
|||
|
||||
func timestamptzFromTimestamppb(ts *timestamppb.Timestamp) pgtype.Timestamptz {
|
||||
if !ts.IsValid() {
|
||||
return pgtype.Timestamptz{Status: pgtype.Null}
|
||||
return pgtype.Timestamptz{}
|
||||
}
|
||||
return pgtype.Timestamptz{Time: ts.AsTime(), Status: pgtype.Present}
|
||||
return pgtype.Timestamptz{Time: ts.AsTime(), Valid: true}
|
||||
}
|
||||
|
||||
func isNotFound(err error) bool {
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue