storage: ignore removed fields when deserializing the data (#3768)

ignore removed fields when deserializing the data
This commit is contained in:
Denis Mishin 2022-11-28 11:31:57 -05:00 committed by GitHub
parent 424bdb4e62
commit 1d252f43ee
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 38 additions and 9 deletions

View file

@ -1,6 +1,7 @@
package protoutil package protoutil
import ( import (
"google.golang.org/protobuf/encoding/protojson"
"google.golang.org/protobuf/proto" "google.golang.org/protobuf/proto"
"google.golang.org/protobuf/types/known/anypb" "google.golang.org/protobuf/types/known/anypb"
"google.golang.org/protobuf/types/known/wrapperspb" "google.golang.org/protobuf/types/known/wrapperspb"
@ -58,6 +59,19 @@ func NewAny(msg proto.Message) *anypb.Any {
return a return a
} }
// UnmarshalAnyJSON unmarshals JSON data into Any
func UnmarshalAnyJSON(data []byte) (*anypb.Any, error) {
opts := protojson.UnmarshalOptions{
AllowPartial: true,
DiscardUnknown: true,
}
var val anypb.Any
if err := opts.Unmarshal(data, &val); err != nil {
return nil, err
}
return &val, nil
}
// NewAnyBool creates a new any type from a bool. // NewAnyBool creates a new any type from a bool.
func NewAnyBool(v bool) *anypb.Any { func NewAnyBool(v bool) *anypb.Any {
return NewAny(wrapperspb.Bool(v)) return NewAny(wrapperspb.Bool(v))

View file

@ -19,6 +19,7 @@ import (
"github.com/pomerium/pomerium/pkg/grpc/databroker" "github.com/pomerium/pomerium/pkg/grpc/databroker"
"github.com/pomerium/pomerium/pkg/grpc/registry" "github.com/pomerium/pomerium/pkg/grpc/registry"
"github.com/pomerium/pomerium/pkg/protoutil"
"github.com/pomerium/pomerium/pkg/storage" "github.com/pomerium/pomerium/pkg/storage"
) )
@ -124,8 +125,7 @@ func getNextChangedRecord(ctx context.Context, q querier, recordType string, aft
} }
afterRecordVersion = version afterRecordVersion = version
var any anypb.Any any, err := protoutil.UnmarshalAnyJSON(data.Bytes)
err = protojson.Unmarshal(data.Bytes, &any)
if isUnknownType(err) { if isUnknownType(err) {
// ignore // ignore
continue continue
@ -137,7 +137,7 @@ func getNextChangedRecord(ctx context.Context, q querier, recordType string, aft
Version: version, Version: version,
Type: recordType, Type: recordType,
Id: recordID, Id: recordID,
Data: &any, Data: any,
ModifiedAt: timestamppbFromTimestamptz(modifiedAt), ModifiedAt: timestamppbFromTimestamptz(modifiedAt),
DeletedAt: timestamppbFromTimestamptz(deletedAt), DeletedAt: timestamppbFromTimestamptz(deletedAt),
}, nil }, nil
@ -176,8 +176,7 @@ func getRecord(ctx context.Context, q querier, recordType, recordID string) (*da
return nil, fmt.Errorf("postgres: failed to execute query: %w", err) return nil, fmt.Errorf("postgres: failed to execute query: %w", err)
} }
var any anypb.Any any, err := protoutil.UnmarshalAnyJSON(data.Bytes)
err = protojson.Unmarshal(data.Bytes, &any)
if isUnknownType(err) { if isUnknownType(err) {
return nil, storage.ErrNotFound return nil, storage.ErrNotFound
} else if err != nil { } else if err != nil {
@ -188,7 +187,7 @@ func getRecord(ctx context.Context, q querier, recordType, recordID string) (*da
Version: version, Version: version,
Type: recordType, Type: recordType,
Id: recordID, Id: recordID,
Data: &any, Data: any,
ModifiedAt: timestamppbFromTimestamptz(modifiedAt), ModifiedAt: timestamppbFromTimestamptz(modifiedAt),
}, nil }, nil
} }
@ -228,8 +227,7 @@ func listRecords(ctx context.Context, q querier, expr storage.FilterExpression,
return nil, fmt.Errorf("postgres: failed to scan row: %w", err) return nil, fmt.Errorf("postgres: failed to scan row: %w", err)
} }
var any anypb.Any any, err := protoutil.UnmarshalAnyJSON(data.Bytes)
err = protojson.Unmarshal(data.Bytes, &any)
if isUnknownType(err) { if isUnknownType(err) {
// ignore records with an unknown type // ignore records with an unknown type
continue continue
@ -241,7 +239,7 @@ func listRecords(ctx context.Context, q querier, expr storage.FilterExpression,
Version: version, Version: version,
Type: recordType, Type: recordType,
Id: id, Id: id,
Data: &any, Data: any,
ModifiedAt: timestamppbFromTimestamptz(modifiedAt), ModifiedAt: timestamppbFromTimestamptz(modifiedAt),
}) })
} }

View file

@ -14,6 +14,7 @@ import (
"github.com/pomerium/pomerium/internal/testutil" "github.com/pomerium/pomerium/internal/testutil"
"github.com/pomerium/pomerium/pkg/grpc/registry" "github.com/pomerium/pomerium/pkg/grpc/registry"
"github.com/pomerium/pomerium/pkg/protoutil"
) )
type mockRegistryWatchServer struct { type mockRegistryWatchServer struct {
@ -112,3 +113,19 @@ func TestRegistry(t *testing.T) {
return nil return nil
})) }))
} }
func TestUnmarshalJSONUnknownFields(t *testing.T) {
any, err := protoutil.UnmarshalAnyJSON([]byte(`
{
"@type": "type.googleapis.com/registry.Service",
"kind": "AUTHENTICATE",
"endpoint": "endpoint",
"unknown_field": true
}
`))
require.NoError(t, err)
var val registry.Service
require.NoError(t, any.UnmarshalTo(&val))
assert.Equal(t, registry.ServiceKind_AUTHENTICATE, val.Kind)
assert.Equal(t, "endpoint", val.Endpoint)
}