mirror of
https://github.com/pomerium/pomerium.git
synced 2025-08-02 00:10:45 +02:00
protoutil: add OverwriteMasked method
Add a method to copy selected fields from one proto message to another (of the same type), using a FieldMask. This is intended for use in a new databroker Patch method.
This commit is contained in:
parent
2472490075
commit
6b434b48f4
3 changed files with 216 additions and 0 deletions
95
pkg/protoutil/fieldmask.go
Normal file
95
pkg/protoutil/fieldmask.go
Normal file
|
@ -0,0 +1,95 @@
|
|||
package protoutil
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"golang.org/x/exp/maps"
|
||||
"google.golang.org/protobuf/proto"
|
||||
"google.golang.org/protobuf/reflect/protoreflect"
|
||||
"google.golang.org/protobuf/types/known/fieldmaskpb"
|
||||
)
|
||||
|
||||
// OverwriteMasked copies values from src to dst subject to a field mask. It
|
||||
// will will return an error if dst and src are not the same type of message,
|
||||
// or if some path in the field mask is not valid for that message type.
|
||||
func OverwriteMasked(dst, src proto.Message, m *fieldmaskpb.FieldMask) error {
|
||||
return newFieldMaskTree(m).overwrite(dst.ProtoReflect(), src.ProtoReflect())
|
||||
}
|
||||
|
||||
// fieldMaskTree represents a FieldMask as a tree, making it simpler to operate
|
||||
// on messages recursively.
|
||||
type fieldMaskTree map[string]fieldMaskTree
|
||||
|
||||
func newFieldMaskTree(m *fieldmaskpb.FieldMask) fieldMaskTree {
|
||||
var t fieldMaskTree
|
||||
for _, p := range m.GetPaths() {
|
||||
t.addFieldPath(p)
|
||||
}
|
||||
return t
|
||||
}
|
||||
|
||||
// This is inspired by FieldMaskTree.java from the Java protobuf library:
|
||||
// https://github.com/protocolbuffers/protobuf/blob/3667102d9/java/util/src/main/java/com/google/protobuf/util/FieldMaskTree.java#L76
|
||||
func (t *fieldMaskTree) addFieldPath(path string) {
|
||||
if *t == nil {
|
||||
*t = make(map[string]fieldMaskTree)
|
||||
}
|
||||
|
||||
parts := strings.Split(path, ".")
|
||||
|
||||
node := *t
|
||||
for _, p := range parts {
|
||||
m := node[p]
|
||||
if m == nil {
|
||||
m = make(fieldMaskTree)
|
||||
node[p] = m
|
||||
} else if len(m) == 0 {
|
||||
return
|
||||
}
|
||||
node = m
|
||||
}
|
||||
maps.Clear(node)
|
||||
}
|
||||
|
||||
// ErrDescriptorMismatch indicates an operation could not be performed because
|
||||
// two proto messages did not have identical descriptors.
|
||||
var ErrDescriptorMismatch = errors.New("descriptor mismatch")
|
||||
|
||||
func (t fieldMaskTree) overwrite(dst, src protoreflect.Message) error {
|
||||
dd, sd := dst.Descriptor(), src.Descriptor()
|
||||
if dd != sd {
|
||||
return fmt.Errorf("%w: %v, %v", ErrDescriptorMismatch, dd.FullName(), sd.FullName())
|
||||
}
|
||||
|
||||
fields := dd.Fields()
|
||||
|
||||
for p, subTree := range t {
|
||||
f := fields.ByName(protoreflect.Name(p))
|
||||
if f == nil {
|
||||
return fmt.Errorf("cannot overwrite unknown field %q in message %v", p, dd.FullName())
|
||||
}
|
||||
|
||||
if len(subTree) > 0 {
|
||||
if f.Cardinality() == protoreflect.Repeated || f.Kind() != protoreflect.MessageKind {
|
||||
return fmt.Errorf("cannot overwrite sub-fields of field %q in message %v",
|
||||
f.TextName(), dd.FullName())
|
||||
}
|
||||
if !src.Has(f) && !src.Has(f) {
|
||||
// no need to copy fields that don't exist
|
||||
continue
|
||||
}
|
||||
subTree.overwrite(dst.Mutable(f).Message(), src.Get(f).Message())
|
||||
continue
|
||||
}
|
||||
|
||||
if src.Has(f) {
|
||||
dst.Set(f, src.Get(f))
|
||||
} else {
|
||||
dst.Clear(f)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
66
pkg/protoutil/fieldmask_test.go
Normal file
66
pkg/protoutil/fieldmask_test.go
Normal file
|
@ -0,0 +1,66 @@
|
|||
package protoutil_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"google.golang.org/protobuf/types/known/fieldmaskpb"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/testutil"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/session"
|
||||
"github.com/pomerium/pomerium/pkg/protoutil"
|
||||
)
|
||||
|
||||
func TestOverwriteMasked(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
s1 := &session.Session{
|
||||
Id: "session-id",
|
||||
IssuedAt: timestamppb.New(time.Date(2023, 10, 25, 10, 0, 0, 0, time.UTC)),
|
||||
}
|
||||
s2 := &session.Session{
|
||||
Id: "new-session-id",
|
||||
AccessedAt: timestamppb.New(time.Date(2023, 10, 25, 12, 0, 0, 0, time.UTC)),
|
||||
OauthToken: &session.OAuthToken{
|
||||
AccessToken: "new-access-token",
|
||||
TokenType: "bearer",
|
||||
},
|
||||
}
|
||||
|
||||
m, err := fieldmaskpb.New(s2,
|
||||
"issued_at", "accessed_at", "oauth_token.access_token", "id_token.raw")
|
||||
require.NoError(t, err)
|
||||
|
||||
err = protoutil.OverwriteMasked(s1, s2, m)
|
||||
require.NoError(t, err)
|
||||
|
||||
testutil.AssertProtoJSONEqual(t, `{
|
||||
"id": "session-id",
|
||||
"accessedAt": "2023-10-25T12:00:00Z",
|
||||
"oauthToken": {
|
||||
"accessToken": "new-access-token"
|
||||
}
|
||||
}`, s1)
|
||||
}
|
||||
|
||||
func TestOverwriteMaskedErrors(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var s1, s2 session.Session
|
||||
var o session.OAuthToken
|
||||
|
||||
err := protoutil.OverwriteMasked(&s1, &s2, &fieldmaskpb.FieldMask{Paths: []string{"foo"}})
|
||||
assert.Equal(t, `cannot overwrite unknown field "foo" in message session.Session`, err.Error())
|
||||
|
||||
err = protoutil.OverwriteMasked(&s1, &s2,
|
||||
&fieldmaskpb.FieldMask{Paths: []string{"device_credentials.type_id"}})
|
||||
assert.Equal(t, `cannot overwrite sub-fields of field "device_credentials" in message `+
|
||||
"session.Session", err.Error())
|
||||
|
||||
m, _ := fieldmaskpb.New(&s1, "expires_at")
|
||||
err = protoutil.OverwriteMasked(&s1, &o, m)
|
||||
assert.Equal(t, "descriptor mismatch: session.Session, session.OAuthToken", err.Error())
|
||||
}
|
55
pkg/protoutil/fieldmasktree_test.go
Normal file
55
pkg/protoutil/fieldmasktree_test.go
Normal file
|
@ -0,0 +1,55 @@
|
|||
package protoutil
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"google.golang.org/protobuf/types/known/fieldmaskpb"
|
||||
)
|
||||
|
||||
func TestFieldMaskTree(t *testing.T) {
|
||||
t.Run("empty", func(t *testing.T) {
|
||||
tr := newFieldMaskTree(&fieldmaskpb.FieldMask{})
|
||||
assert.Equal(t, fieldMaskTree(nil), tr)
|
||||
})
|
||||
t.Run("basic", func(t *testing.T) {
|
||||
tr := newFieldMaskTree(&fieldmaskpb.FieldMask{
|
||||
Paths: []string{"foo", "bar", "baz"},
|
||||
})
|
||||
assert.Equal(t, fieldMaskTree{
|
||||
"foo": {},
|
||||
"bar": {},
|
||||
"baz": {},
|
||||
}, tr)
|
||||
})
|
||||
t.Run("nested", func(t *testing.T) {
|
||||
tr := newFieldMaskTree(&fieldmaskpb.FieldMask{
|
||||
Paths: []string{"foo.bar.baz", "foo.bar.xyz", "foo.quux"},
|
||||
})
|
||||
assert.Equal(t, fieldMaskTree{
|
||||
"foo": {
|
||||
"bar": {
|
||||
"baz": {},
|
||||
"xyz": {},
|
||||
},
|
||||
"quux": {},
|
||||
},
|
||||
}, tr)
|
||||
})
|
||||
t.Run("overlapping fields 1", func(t *testing.T) {
|
||||
tr := newFieldMaskTree(&fieldmaskpb.FieldMask{
|
||||
Paths: []string{"foo", "foo.bar"},
|
||||
})
|
||||
assert.Equal(t, fieldMaskTree{
|
||||
"foo": {},
|
||||
}, tr)
|
||||
})
|
||||
t.Run("overlapping fields 2", func(t *testing.T) {
|
||||
tr := newFieldMaskTree(&fieldmaskpb.FieldMask{
|
||||
Paths: []string{"foo.bar", "foo"},
|
||||
})
|
||||
assert.Equal(t, fieldMaskTree{
|
||||
"foo": {},
|
||||
}, tr)
|
||||
})
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue