protoutil: add OverwriteMasked method

Add a method to copy selected fields from one proto message to another
(of the same type), using a FieldMask. This is intended for use in a new
databroker Patch method.
This commit is contained in:
Kenneth Jenkins 2023-10-30 09:03:05 -07:00
parent 2472490075
commit 6b434b48f4
3 changed files with 216 additions and 0 deletions

View file

@ -0,0 +1,95 @@
package protoutil
import (
"errors"
"fmt"
"strings"
"golang.org/x/exp/maps"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/reflect/protoreflect"
"google.golang.org/protobuf/types/known/fieldmaskpb"
)
// OverwriteMasked copies values from src to dst subject to a field mask. It
// will will return an error if dst and src are not the same type of message,
// or if some path in the field mask is not valid for that message type.
func OverwriteMasked(dst, src proto.Message, m *fieldmaskpb.FieldMask) error {
return newFieldMaskTree(m).overwrite(dst.ProtoReflect(), src.ProtoReflect())
}
// fieldMaskTree represents a FieldMask as a tree, making it simpler to operate
// on messages recursively.
type fieldMaskTree map[string]fieldMaskTree
func newFieldMaskTree(m *fieldmaskpb.FieldMask) fieldMaskTree {
var t fieldMaskTree
for _, p := range m.GetPaths() {
t.addFieldPath(p)
}
return t
}
// This is inspired by FieldMaskTree.java from the Java protobuf library:
// https://github.com/protocolbuffers/protobuf/blob/3667102d9/java/util/src/main/java/com/google/protobuf/util/FieldMaskTree.java#L76
func (t *fieldMaskTree) addFieldPath(path string) {
if *t == nil {
*t = make(map[string]fieldMaskTree)
}
parts := strings.Split(path, ".")
node := *t
for _, p := range parts {
m := node[p]
if m == nil {
m = make(fieldMaskTree)
node[p] = m
} else if len(m) == 0 {
return
}
node = m
}
maps.Clear(node)
}
// ErrDescriptorMismatch indicates an operation could not be performed because
// two proto messages did not have identical descriptors.
var ErrDescriptorMismatch = errors.New("descriptor mismatch")
func (t fieldMaskTree) overwrite(dst, src protoreflect.Message) error {
dd, sd := dst.Descriptor(), src.Descriptor()
if dd != sd {
return fmt.Errorf("%w: %v, %v", ErrDescriptorMismatch, dd.FullName(), sd.FullName())
}
fields := dd.Fields()
for p, subTree := range t {
f := fields.ByName(protoreflect.Name(p))
if f == nil {
return fmt.Errorf("cannot overwrite unknown field %q in message %v", p, dd.FullName())
}
if len(subTree) > 0 {
if f.Cardinality() == protoreflect.Repeated || f.Kind() != protoreflect.MessageKind {
return fmt.Errorf("cannot overwrite sub-fields of field %q in message %v",
f.TextName(), dd.FullName())
}
if !src.Has(f) && !src.Has(f) {
// no need to copy fields that don't exist
continue
}
subTree.overwrite(dst.Mutable(f).Message(), src.Get(f).Message())
continue
}
if src.Has(f) {
dst.Set(f, src.Get(f))
} else {
dst.Clear(f)
}
}
return nil
}

View file

@ -0,0 +1,66 @@
package protoutil_test
import (
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"google.golang.org/protobuf/types/known/fieldmaskpb"
"google.golang.org/protobuf/types/known/timestamppb"
"github.com/pomerium/pomerium/internal/testutil"
"github.com/pomerium/pomerium/pkg/grpc/session"
"github.com/pomerium/pomerium/pkg/protoutil"
)
func TestOverwriteMasked(t *testing.T) {
t.Parallel()
s1 := &session.Session{
Id: "session-id",
IssuedAt: timestamppb.New(time.Date(2023, 10, 25, 10, 0, 0, 0, time.UTC)),
}
s2 := &session.Session{
Id: "new-session-id",
AccessedAt: timestamppb.New(time.Date(2023, 10, 25, 12, 0, 0, 0, time.UTC)),
OauthToken: &session.OAuthToken{
AccessToken: "new-access-token",
TokenType: "bearer",
},
}
m, err := fieldmaskpb.New(s2,
"issued_at", "accessed_at", "oauth_token.access_token", "id_token.raw")
require.NoError(t, err)
err = protoutil.OverwriteMasked(s1, s2, m)
require.NoError(t, err)
testutil.AssertProtoJSONEqual(t, `{
"id": "session-id",
"accessedAt": "2023-10-25T12:00:00Z",
"oauthToken": {
"accessToken": "new-access-token"
}
}`, s1)
}
func TestOverwriteMaskedErrors(t *testing.T) {
t.Parallel()
var s1, s2 session.Session
var o session.OAuthToken
err := protoutil.OverwriteMasked(&s1, &s2, &fieldmaskpb.FieldMask{Paths: []string{"foo"}})
assert.Equal(t, `cannot overwrite unknown field "foo" in message session.Session`, err.Error())
err = protoutil.OverwriteMasked(&s1, &s2,
&fieldmaskpb.FieldMask{Paths: []string{"device_credentials.type_id"}})
assert.Equal(t, `cannot overwrite sub-fields of field "device_credentials" in message `+
"session.Session", err.Error())
m, _ := fieldmaskpb.New(&s1, "expires_at")
err = protoutil.OverwriteMasked(&s1, &o, m)
assert.Equal(t, "descriptor mismatch: session.Session, session.OAuthToken", err.Error())
}

View file

@ -0,0 +1,55 @@
package protoutil
import (
"testing"
"github.com/stretchr/testify/assert"
"google.golang.org/protobuf/types/known/fieldmaskpb"
)
func TestFieldMaskTree(t *testing.T) {
t.Run("empty", func(t *testing.T) {
tr := newFieldMaskTree(&fieldmaskpb.FieldMask{})
assert.Equal(t, fieldMaskTree(nil), tr)
})
t.Run("basic", func(t *testing.T) {
tr := newFieldMaskTree(&fieldmaskpb.FieldMask{
Paths: []string{"foo", "bar", "baz"},
})
assert.Equal(t, fieldMaskTree{
"foo": {},
"bar": {},
"baz": {},
}, tr)
})
t.Run("nested", func(t *testing.T) {
tr := newFieldMaskTree(&fieldmaskpb.FieldMask{
Paths: []string{"foo.bar.baz", "foo.bar.xyz", "foo.quux"},
})
assert.Equal(t, fieldMaskTree{
"foo": {
"bar": {
"baz": {},
"xyz": {},
},
"quux": {},
},
}, tr)
})
t.Run("overlapping fields 1", func(t *testing.T) {
tr := newFieldMaskTree(&fieldmaskpb.FieldMask{
Paths: []string{"foo", "foo.bar"},
})
assert.Equal(t, fieldMaskTree{
"foo": {},
}, tr)
})
t.Run("overlapping fields 2", func(t *testing.T) {
tr := newFieldMaskTree(&fieldmaskpb.FieldMask{
Paths: []string{"foo.bar", "foo"},
})
assert.Equal(t, fieldMaskTree{
"foo": {},
}, tr)
})
}