mirror of
https://github.com/pomerium/pomerium.git
synced 2025-05-18 11:37:08 +02:00
postgres: handle unknown types (#3632)
This commit is contained in:
parent
95753de85d
commit
3fec00f2a8
2 changed files with 97 additions and 52 deletions
|
@ -163,6 +163,24 @@ func TestBackend(t *testing.T) {
|
||||||
assert.NoError(t, stream.Err())
|
assert.NoError(t, stream.Err())
|
||||||
})
|
})
|
||||||
|
|
||||||
|
t.Run("unknown type", func(t *testing.T) {
|
||||||
|
_, err := backend.pool.Exec(ctx, `
|
||||||
|
INSERT INTO `+schemaName+"."+recordsTableName+` (type, id, version, data)
|
||||||
|
VALUES ('unknown', '1', 1000, '{"@type":"UNKNOWN","value":{}}')
|
||||||
|
`)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
_, err = backend.Get(ctx, "unknown", "1")
|
||||||
|
assert.ErrorIs(t, err, storage.ErrNotFound)
|
||||||
|
|
||||||
|
_, _, stream, err := backend.SyncLatest(ctx, "unknown-test", nil)
|
||||||
|
if assert.NoError(t, err) {
|
||||||
|
_, err := storage.RecordStreamToList(stream)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
stream.Close()
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}))
|
}))
|
||||||
}
|
}
|
||||||
|
|
|
@ -5,6 +5,7 @@ import (
|
||||||
"context"
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/jackc/pgconn"
|
"github.com/jackc/pgconn"
|
||||||
|
@ -12,6 +13,7 @@ import (
|
||||||
"github.com/jackc/pgx/v4"
|
"github.com/jackc/pgx/v4"
|
||||||
"google.golang.org/protobuf/encoding/protojson"
|
"google.golang.org/protobuf/encoding/protojson"
|
||||||
"google.golang.org/protobuf/proto"
|
"google.golang.org/protobuf/proto"
|
||||||
|
"google.golang.org/protobuf/reflect/protoregistry"
|
||||||
"google.golang.org/protobuf/types/known/anypb"
|
"google.golang.org/protobuf/types/known/anypb"
|
||||||
"google.golang.org/protobuf/types/known/timestamppb"
|
"google.golang.org/protobuf/types/known/timestamppb"
|
||||||
|
|
||||||
|
@ -94,6 +96,7 @@ func getLatestRecordVersion(ctx context.Context, q querier) (recordVersion uint6
|
||||||
}
|
}
|
||||||
|
|
||||||
func getNextChangedRecord(ctx context.Context, q querier, recordType string, afterRecordVersion uint64) (*databroker.Record, error) {
|
func getNextChangedRecord(ctx context.Context, q querier, recordType string, afterRecordVersion uint64) (*databroker.Record, error) {
|
||||||
|
for {
|
||||||
var recordID string
|
var recordID string
|
||||||
var version uint64
|
var version uint64
|
||||||
var data pgtype.JSONB
|
var data pgtype.JSONB
|
||||||
|
@ -119,10 +122,14 @@ func getNextChangedRecord(ctx context.Context, q querier, recordType string, aft
|
||||||
} else if err != nil {
|
} else if err != nil {
|
||||||
return nil, fmt.Errorf("error querying next changed record: %w", err)
|
return nil, fmt.Errorf("error querying next changed record: %w", err)
|
||||||
}
|
}
|
||||||
|
afterRecordVersion = version
|
||||||
|
|
||||||
var any anypb.Any
|
var any anypb.Any
|
||||||
err = protojson.Unmarshal(data.Bytes, &any)
|
err = protojson.Unmarshal(data.Bytes, &any)
|
||||||
if err != nil {
|
if isUnknownType(err) {
|
||||||
|
// ignore
|
||||||
|
continue
|
||||||
|
} else if err != nil {
|
||||||
return nil, fmt.Errorf("error unmarshaling changed record data: %w", err)
|
return nil, fmt.Errorf("error unmarshaling changed record data: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -135,6 +142,7 @@ func getNextChangedRecord(ctx context.Context, q querier, recordType string, aft
|
||||||
DeletedAt: timestamppbFromTimestamptz(deletedAt),
|
DeletedAt: timestamppbFromTimestamptz(deletedAt),
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func getOptions(ctx context.Context, q querier, recordType string) (*databroker.Options, error) {
|
func getOptions(ctx context.Context, q querier, recordType string) (*databroker.Options, error) {
|
||||||
var capacity pgtype.Int8
|
var capacity pgtype.Int8
|
||||||
|
@ -165,13 +173,15 @@ func getRecord(ctx context.Context, q querier, recordType, recordID string) (*da
|
||||||
if isNotFound(err) {
|
if isNotFound(err) {
|
||||||
return nil, storage.ErrNotFound
|
return nil, storage.ErrNotFound
|
||||||
} else if err != nil {
|
} else if err != nil {
|
||||||
return nil, err
|
return nil, fmt.Errorf("postgres: failed to execute query: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
var any anypb.Any
|
var any anypb.Any
|
||||||
err = protojson.Unmarshal(data.Bytes, &any)
|
err = protojson.Unmarshal(data.Bytes, &any)
|
||||||
if err != nil {
|
if isUnknownType(err) {
|
||||||
return nil, err
|
return nil, storage.ErrNotFound
|
||||||
|
} else if err != nil {
|
||||||
|
return nil, fmt.Errorf("postgres: failed to unmarshal data: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return &databroker.Record{
|
return &databroker.Record{
|
||||||
|
@ -193,7 +203,7 @@ func listRecords(ctx context.Context, q querier, expr storage.FilterExpression,
|
||||||
query += "WHERE "
|
query += "WHERE "
|
||||||
err := addFilterExpressionToQuery(&query, &args, expr)
|
err := addFilterExpressionToQuery(&query, &args, expr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, fmt.Errorf("postgres: failed to add filter to query: %w", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
query += `
|
query += `
|
||||||
|
@ -203,7 +213,7 @@ func listRecords(ctx context.Context, q querier, expr storage.FilterExpression,
|
||||||
`
|
`
|
||||||
rows, err := q.Query(ctx, query, args...)
|
rows, err := q.Query(ctx, query, args...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, fmt.Errorf("postgres: failed to execute query: %w", err)
|
||||||
}
|
}
|
||||||
defer rows.Close()
|
defer rows.Close()
|
||||||
|
|
||||||
|
@ -215,13 +225,16 @@ func listRecords(ctx context.Context, q querier, expr storage.FilterExpression,
|
||||||
var modifiedAt pgtype.Timestamptz
|
var modifiedAt pgtype.Timestamptz
|
||||||
err = rows.Scan(&recordType, &id, &version, &data, &modifiedAt)
|
err = rows.Scan(&recordType, &id, &version, &data, &modifiedAt)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, fmt.Errorf("postgres: failed to scan row: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
var any anypb.Any
|
var any anypb.Any
|
||||||
err = protojson.Unmarshal(data.Bytes, &any)
|
err = protojson.Unmarshal(data.Bytes, &any)
|
||||||
if err != nil {
|
if isUnknownType(err) {
|
||||||
return nil, err
|
// ignore records with an unknown type
|
||||||
|
continue
|
||||||
|
} else if err != nil {
|
||||||
|
return nil, fmt.Errorf("postgres: failed to unmarshal data: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
records = append(records, &databroker.Record{
|
records = append(records, &databroker.Record{
|
||||||
|
@ -232,7 +245,12 @@ func listRecords(ctx context.Context, q querier, expr storage.FilterExpression,
|
||||||
ModifiedAt: timestamppbFromTimestamptz(modifiedAt),
|
ModifiedAt: timestamppbFromTimestamptz(modifiedAt),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
return records, rows.Err()
|
err = rows.Err()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("postgres: error iterating over rows: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return records, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func listServices(ctx context.Context, q querier) ([]*registry.Service, error) {
|
func listServices(ctx context.Context, q querier) ([]*registry.Service, error) {
|
||||||
|
@ -245,7 +263,7 @@ func listServices(ctx context.Context, q querier) ([]*registry.Service, error) {
|
||||||
`
|
`
|
||||||
rows, err := q.Query(ctx, query)
|
rows, err := q.Query(ctx, query)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, fmt.Errorf("postgres: failed to execute query: %w", err)
|
||||||
}
|
}
|
||||||
defer rows.Close()
|
defer rows.Close()
|
||||||
|
|
||||||
|
@ -253,7 +271,7 @@ func listServices(ctx context.Context, q querier) ([]*registry.Service, error) {
|
||||||
var kind, endpoint string
|
var kind, endpoint string
|
||||||
err = rows.Scan(&kind, &endpoint)
|
err = rows.Scan(&kind, &endpoint)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, fmt.Errorf("postgres: failed to scan row: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
services = append(services, ®istry.Service{
|
services = append(services, ®istry.Service{
|
||||||
|
@ -263,7 +281,7 @@ func listServices(ctx context.Context, q querier) ([]*registry.Service, error) {
|
||||||
}
|
}
|
||||||
err = rows.Err()
|
err = rows.Err()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, fmt.Errorf("postgres: error iterating over rows: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return services, nil
|
return services, nil
|
||||||
|
@ -287,7 +305,7 @@ func maybeAcquireLease(ctx context.Context, q querier, leaseName, leaseID string
|
||||||
func putRecordAndChange(ctx context.Context, q querier, record *databroker.Record) error {
|
func putRecordAndChange(ctx context.Context, q querier, record *databroker.Record) error {
|
||||||
data, err := jsonbFromAny(record.GetData())
|
data, err := jsonbFromAny(record.GetData())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return fmt.Errorf("postgres: failed to convert any to json: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
modifiedAt := timestamptzFromTimestamppb(record.GetModifiedAt())
|
modifiedAt := timestamptzFromTimestamppb(record.GetModifiedAt())
|
||||||
|
@ -325,7 +343,7 @@ func putRecordAndChange(ctx context.Context, q querier, record *databroker.Recor
|
||||||
}
|
}
|
||||||
err = q.QueryRow(ctx, query, args...).Scan(&record.Version)
|
err = q.QueryRow(ctx, query, args...).Scan(&record.Version)
|
||||||
if err != nil && !isNotFound(err) {
|
if err != nil && !isNotFound(err) {
|
||||||
return err
|
return fmt.Errorf("postgres: failed to execute query: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
|
@ -398,3 +416,12 @@ func timestamptzFromTimestamppb(ts *timestamppb.Timestamp) pgtype.Timestamptz {
|
||||||
func isNotFound(err error) bool {
|
func isNotFound(err error) bool {
|
||||||
return errors.Is(err, pgx.ErrNoRows) || errors.Is(err, storage.ErrNotFound)
|
return errors.Is(err, pgx.ErrNoRows) || errors.Is(err, storage.ErrNotFound)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func isUnknownType(err error) bool {
|
||||||
|
if err == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
return errors.Is(err, protoregistry.NotFound) ||
|
||||||
|
strings.Contains(err.Error(), "unable to resolve") // protojson doesn't wrap errors so check for the string
|
||||||
|
}
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue