diff --git a/pkg/protoutil/fieldmask.go b/pkg/protoutil/fieldmask.go new file mode 100644 index 000000000..b7170320d --- /dev/null +++ b/pkg/protoutil/fieldmask.go @@ -0,0 +1,95 @@ +package protoutil + +import ( + "errors" + "fmt" + "strings" + + "golang.org/x/exp/maps" + "google.golang.org/protobuf/proto" + "google.golang.org/protobuf/reflect/protoreflect" + "google.golang.org/protobuf/types/known/fieldmaskpb" +) + +// OverwriteMasked copies values from src to dst subject to a field mask. It +// will will return an error if dst and src are not the same type of message, +// or if some path in the field mask is not valid for that message type. +func OverwriteMasked(dst, src proto.Message, m *fieldmaskpb.FieldMask) error { + return newFieldMaskTree(m).overwrite(dst.ProtoReflect(), src.ProtoReflect()) +} + +// fieldMaskTree represents a FieldMask as a tree, making it simpler to operate +// on messages recursively. +type fieldMaskTree map[string]fieldMaskTree + +func newFieldMaskTree(m *fieldmaskpb.FieldMask) fieldMaskTree { + var t fieldMaskTree + for _, p := range m.GetPaths() { + t.addFieldPath(p) + } + return t +} + +// This is inspired by FieldMaskTree.java from the Java protobuf library: +// https://github.com/protocolbuffers/protobuf/blob/3667102d9/java/util/src/main/java/com/google/protobuf/util/FieldMaskTree.java#L76 +func (t *fieldMaskTree) addFieldPath(path string) { + if *t == nil { + *t = make(map[string]fieldMaskTree) + } + + parts := strings.Split(path, ".") + + node := *t + for _, p := range parts { + m := node[p] + if m == nil { + m = make(fieldMaskTree) + node[p] = m + } else if len(m) == 0 { + return + } + node = m + } + maps.Clear(node) +} + +// ErrDescriptorMismatch indicates an operation could not be performed because +// two proto messages did not have identical descriptors. +var ErrDescriptorMismatch = errors.New("descriptor mismatch") + +func (t fieldMaskTree) overwrite(dst, src protoreflect.Message) error { + dd, sd := dst.Descriptor(), src.Descriptor() + if dd != sd { + return fmt.Errorf("%w: %v, %v", ErrDescriptorMismatch, dd.FullName(), sd.FullName()) + } + + fields := dd.Fields() + + for p, subTree := range t { + f := fields.ByName(protoreflect.Name(p)) + if f == nil { + return fmt.Errorf("cannot overwrite unknown field %q in message %v", p, dd.FullName()) + } + + if len(subTree) > 0 { + if f.Cardinality() == protoreflect.Repeated || f.Kind() != protoreflect.MessageKind { + return fmt.Errorf("cannot overwrite sub-fields of field %q in message %v", + f.TextName(), dd.FullName()) + } + if !src.Has(f) && !src.Has(f) { + // no need to copy fields that don't exist + continue + } + subTree.overwrite(dst.Mutable(f).Message(), src.Get(f).Message()) + continue + } + + if src.Has(f) { + dst.Set(f, src.Get(f)) + } else { + dst.Clear(f) + } + } + + return nil +} diff --git a/pkg/protoutil/fieldmask_test.go b/pkg/protoutil/fieldmask_test.go new file mode 100644 index 000000000..281b32df6 --- /dev/null +++ b/pkg/protoutil/fieldmask_test.go @@ -0,0 +1,66 @@ +package protoutil_test + +import ( + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "google.golang.org/protobuf/types/known/fieldmaskpb" + "google.golang.org/protobuf/types/known/timestamppb" + + "github.com/pomerium/pomerium/internal/testutil" + "github.com/pomerium/pomerium/pkg/grpc/session" + "github.com/pomerium/pomerium/pkg/protoutil" +) + +func TestOverwriteMasked(t *testing.T) { + t.Parallel() + + s1 := &session.Session{ + Id: "session-id", + IssuedAt: timestamppb.New(time.Date(2023, 10, 25, 10, 0, 0, 0, time.UTC)), + } + s2 := &session.Session{ + Id: "new-session-id", + AccessedAt: timestamppb.New(time.Date(2023, 10, 25, 12, 0, 0, 0, time.UTC)), + OauthToken: &session.OAuthToken{ + AccessToken: "new-access-token", + TokenType: "bearer", + }, + } + + m, err := fieldmaskpb.New(s2, + "issued_at", "accessed_at", "oauth_token.access_token", "id_token.raw") + require.NoError(t, err) + + err = protoutil.OverwriteMasked(s1, s2, m) + require.NoError(t, err) + + testutil.AssertProtoJSONEqual(t, `{ + "id": "session-id", + "accessedAt": "2023-10-25T12:00:00Z", + "oauthToken": { + "accessToken": "new-access-token" + } + }`, s1) +} + +func TestOverwriteMaskedErrors(t *testing.T) { + t.Parallel() + + var s1, s2 session.Session + var o session.OAuthToken + + err := protoutil.OverwriteMasked(&s1, &s2, &fieldmaskpb.FieldMask{Paths: []string{"foo"}}) + assert.Equal(t, `cannot overwrite unknown field "foo" in message session.Session`, err.Error()) + + err = protoutil.OverwriteMasked(&s1, &s2, + &fieldmaskpb.FieldMask{Paths: []string{"device_credentials.type_id"}}) + assert.Equal(t, `cannot overwrite sub-fields of field "device_credentials" in message `+ + "session.Session", err.Error()) + + m, _ := fieldmaskpb.New(&s1, "expires_at") + err = protoutil.OverwriteMasked(&s1, &o, m) + assert.Equal(t, "descriptor mismatch: session.Session, session.OAuthToken", err.Error()) +} diff --git a/pkg/protoutil/fieldmasktree_test.go b/pkg/protoutil/fieldmasktree_test.go new file mode 100644 index 000000000..020e41b1f --- /dev/null +++ b/pkg/protoutil/fieldmasktree_test.go @@ -0,0 +1,55 @@ +package protoutil + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "google.golang.org/protobuf/types/known/fieldmaskpb" +) + +func TestFieldMaskTree(t *testing.T) { + t.Run("empty", func(t *testing.T) { + tr := newFieldMaskTree(&fieldmaskpb.FieldMask{}) + assert.Equal(t, fieldMaskTree(nil), tr) + }) + t.Run("basic", func(t *testing.T) { + tr := newFieldMaskTree(&fieldmaskpb.FieldMask{ + Paths: []string{"foo", "bar", "baz"}, + }) + assert.Equal(t, fieldMaskTree{ + "foo": {}, + "bar": {}, + "baz": {}, + }, tr) + }) + t.Run("nested", func(t *testing.T) { + tr := newFieldMaskTree(&fieldmaskpb.FieldMask{ + Paths: []string{"foo.bar.baz", "foo.bar.xyz", "foo.quux"}, + }) + assert.Equal(t, fieldMaskTree{ + "foo": { + "bar": { + "baz": {}, + "xyz": {}, + }, + "quux": {}, + }, + }, tr) + }) + t.Run("overlapping fields 1", func(t *testing.T) { + tr := newFieldMaskTree(&fieldmaskpb.FieldMask{ + Paths: []string{"foo", "foo.bar"}, + }) + assert.Equal(t, fieldMaskTree{ + "foo": {}, + }, tr) + }) + t.Run("overlapping fields 2", func(t *testing.T) { + tr := newFieldMaskTree(&fieldmaskpb.FieldMask{ + Paths: []string{"foo.bar", "foo"}, + }) + assert.Equal(t, fieldMaskTree{ + "foo": {}, + }, tr) + }) +}