package protoutil import ( "google.golang.org/protobuf/proto" "google.golang.org/protobuf/reflect/protoreflect" "google.golang.org/protobuf/types/known/anypb" ) // TransformFunc is a function that transforms a protobuf value into a new protobuf value. type TransformFunc func(protoreflect.FieldDescriptor, protoreflect.Value) (protoreflect.Value, error) // Transform takes in a protobuf message and transforms any basic values with the given function. func Transform(msg proto.Message, f TransformFunc) (proto.Message, error) { t := transformer{callback: f} src := msg.ProtoReflect() dst := src.New() err := t.transformMessage(dst, src) if err != nil { return nil, err } return dst.Interface(), nil } type transformer struct { callback TransformFunc } func (t transformer) transformAny(dst, src *anypb.Any) error { msg, err := src.UnmarshalNew() if err != nil { return err } srcMsg := msg.ProtoReflect() dstMsg := srcMsg.New() err = t.transformMessage(dstMsg, srcMsg) if err != nil { return err } a := NewAny(dstMsg.Interface()) dst.TypeUrl = a.TypeUrl dst.Value = a.Value return nil } func (t transformer) transformList(fd protoreflect.FieldDescriptor, dst, src protoreflect.List) error { for i, n := 0, src.Len(); i < n; i++ { v := src.Get(i) switch vv := v.Interface().(type) { case protoreflect.Message: nv := dst.NewElement() err := t.transformMessage(nv.Message(), vv) if err != nil { return err } dst.Append(nv) default: nv, err := t.callback(fd, v) if err != nil { return err } dst.Append(nv) } } return nil } func (t transformer) transformMap(fd protoreflect.FieldDescriptor, dst, src protoreflect.Map) error { var err error src.Range(func(k protoreflect.MapKey, v protoreflect.Value) bool { switch vv := v.Interface().(type) { case protoreflect.Message: nv := dst.NewValue() err := t.transformMessage(nv.Message(), vv) if err != nil { return false } dst.Set(k, nv) default: nv, err := t.callback(fd, v) if err != nil { return false } dst.Set(k, nv) } return true }) return err } func (t transformer) transformMessage(dst, src protoreflect.Message) error { // most of this code is based on // https://github.com/protocolbuffers/protobuf-go/blob/v1.25.0/proto/merge.go if srcAny, ok := src.Interface().(*anypb.Any); ok { if dstAny, ok := dst.Interface().(*anypb.Any); ok { return t.transformAny(dstAny, srcAny) } } var err error src.Range(func(fd protoreflect.FieldDescriptor, v protoreflect.Value) bool { switch { case fd.IsList(): err = t.transformList(fd, dst.Mutable(fd).List(), v.List()) if err != nil { return false } case fd.IsMap(): err = t.transformMap(fd, dst.Mutable(fd).Map(), v.Map()) if err != nil { return false } case fd.Message() != nil: err = t.transformMessage(dst.Mutable(fd).Message(), v.Message()) if err != nil { return false } default: var nv protoreflect.Value nv, err = t.callback(fd, v) if err != nil { return false } dst.Set(fd, nv) } return true }) if err != nil { return err } if len(src.GetUnknown()) > 0 { dst.SetUnknown(append(dst.GetUnknown(), src.GetUnknown()...)) } return nil }