1
0
Fork 0
mirror of https://github.com/pomerium/pomerium.git synced 2025-05-15 01:57:45 +02:00

postgres: upgrade to pgx v5 ()

This commit is contained in:
Caleb Doxsey 2022-12-19 12:47:35 -07:00 committed by GitHub
parent f99ea7c8ad
commit c048af7523
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
7 changed files with 47 additions and 143 deletions
pkg/storage/postgres

View file

@ -8,9 +8,9 @@ import (
"strings"
"time"
"github.com/jackc/pgconn"
"github.com/jackc/pgtype"
"github.com/jackc/pgx/v4"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgconn"
"github.com/jackc/pgx/v5/pgtype"
"google.golang.org/protobuf/encoding/protojson"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/reflect/protoregistry"
@ -100,7 +100,7 @@ func getNextChangedRecord(ctx context.Context, q querier, recordType string, aft
for {
var recordID string
var version uint64
var data pgtype.JSONB
var data []byte
var modifiedAt pgtype.Timestamptz
var deletedAt pgtype.Timestamptz
query := `
@ -125,7 +125,7 @@ func getNextChangedRecord(ctx context.Context, q querier, recordType string, aft
}
afterRecordVersion = version
any, err := protoutil.UnmarshalAnyJSON(data.Bytes)
any, err := protoutil.UnmarshalAnyJSON(data)
if isUnknownType(err) {
// ignore
continue
@ -155,15 +155,15 @@ func getOptions(ctx context.Context, q querier, recordType string) (*databroker.
return nil, err
}
options := new(databroker.Options)
if capacity.Status == pgtype.Present {
options.Capacity = proto.Uint64(uint64(capacity.Int))
if capacity.Valid {
options.Capacity = proto.Uint64(uint64(capacity.Int64))
}
return options, nil
}
func getRecord(ctx context.Context, q querier, recordType, recordID string) (*databroker.Record, error) {
var version uint64
var data pgtype.JSONB
var data []byte
var modifiedAt pgtype.Timestamptz
err := q.QueryRow(ctx, `
SELECT version, data, modified_at
@ -176,7 +176,7 @@ func getRecord(ctx context.Context, q querier, recordType, recordID string) (*da
return nil, fmt.Errorf("postgres: failed to execute query: %w", err)
}
any, err := protoutil.UnmarshalAnyJSON(data.Bytes)
any, err := protoutil.UnmarshalAnyJSON(data)
if isUnknownType(err) {
return nil, storage.ErrNotFound
} else if err != nil {
@ -220,14 +220,14 @@ func listRecords(ctx context.Context, q querier, expr storage.FilterExpression,
for rows.Next() {
var recordType, id string
var version uint64
var data pgtype.JSONB
var data []byte
var modifiedAt pgtype.Timestamptz
err = rows.Scan(&recordType, &id, &version, &data, &modifiedAt)
if err != nil {
return nil, fmt.Errorf("postgres: failed to scan row: %w", err)
}
any, err := protoutil.UnmarshalAnyJSON(data.Bytes)
any, err := protoutil.UnmarshalAnyJSON(data)
if isUnknownType(err) {
// ignore records with an unknown type
continue
@ -308,9 +308,10 @@ func putRecordAndChange(ctx context.Context, q querier, record *databroker.Recor
modifiedAt := timestamptzFromTimestamppb(record.GetModifiedAt())
deletedAt := timestamptzFromTimestamppb(record.GetDeletedAt())
indexCIDR := &pgtype.Text{Status: pgtype.Null}
indexCIDR := &pgtype.Text{Valid: false}
if cidr := storage.GetRecordIndexCIDR(record.GetData()); cidr != nil {
_ = indexCIDR.Set(cidr.String())
indexCIDR.String = cidr.String()
indexCIDR.Valid = true
}
query := `
@ -359,10 +360,10 @@ func putService(ctx context.Context, q querier, svc *registry.Service, expiresAt
}
func setOptions(ctx context.Context, q querier, recordType string, options *databroker.Options) error {
capacity := pgtype.Int8{Status: pgtype.Null}
capacity := pgtype.Int8{}
if options != nil && options.Capacity != nil {
capacity.Int = int64(options.GetCapacity())
capacity.Status = pgtype.Present
capacity.Int64 = int64(options.GetCapacity())
capacity.Valid = true
}
_, err := q.Exec(ctx, `
@ -384,21 +385,16 @@ func signalServiceChange(ctx context.Context, q querier) error {
return err
}
func jsonbFromAny(any *anypb.Any) (pgtype.JSONB, error) {
func jsonbFromAny(any *anypb.Any) ([]byte, error) {
if any == nil {
return pgtype.JSONB{Status: pgtype.Null}, nil
return nil, nil
}
bs, err := protojson.Marshal(any)
if err != nil {
return pgtype.JSONB{Status: pgtype.Null}, err
}
return pgtype.JSONB{Bytes: bs, Status: pgtype.Present}, nil
return protojson.Marshal(any)
}
func timestamppbFromTimestamptz(ts pgtype.Timestamptz) *timestamppb.Timestamp {
if ts.Status != pgtype.Present {
if !ts.Valid {
return nil
}
return timestamppb.New(ts.Time)
@ -406,9 +402,9 @@ func timestamppbFromTimestamptz(ts pgtype.Timestamptz) *timestamppb.Timestamp {
func timestamptzFromTimestamppb(ts *timestamppb.Timestamp) pgtype.Timestamptz {
if !ts.IsValid() {
return pgtype.Timestamptz{Status: pgtype.Null}
return pgtype.Timestamptz{}
}
return pgtype.Timestamptz{Time: ts.AsTime(), Status: pgtype.Present}
return pgtype.Timestamptz{Time: ts.AsTime(), Valid: true}
}
func isNotFound(err error) bool {