mirror of
https://github.com/pomerium/pomerium.git
synced 2025-04-29 18:36:30 +02:00
138 lines
3.2 KiB
Go
138 lines
3.2 KiB
Go
package protoutil
|
|
|
|
import (
|
|
"google.golang.org/protobuf/proto"
|
|
"google.golang.org/protobuf/reflect/protoreflect"
|
|
"google.golang.org/protobuf/types/known/anypb"
|
|
)
|
|
|
|
// TransformFunc is a function that transforms a protobuf value into a new protobuf value.
|
|
type TransformFunc func(protoreflect.FieldDescriptor, protoreflect.Value) (protoreflect.Value, error)
|
|
|
|
// Transform takes in a protobuf message and transforms any basic values with the given function.
|
|
func Transform(msg proto.Message, f TransformFunc) (proto.Message, error) {
|
|
t := transformer{callback: f}
|
|
src := msg.ProtoReflect()
|
|
dst := src.New()
|
|
err := t.transformMessage(dst, src)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return dst.Interface(), nil
|
|
}
|
|
|
|
type transformer struct {
|
|
callback TransformFunc
|
|
}
|
|
|
|
func (t transformer) transformAny(dst, src *anypb.Any) error {
|
|
msg, err := src.UnmarshalNew()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
srcMsg := msg.ProtoReflect()
|
|
dstMsg := srcMsg.New()
|
|
|
|
err = t.transformMessage(dstMsg, srcMsg)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
a := NewAny(dstMsg.Interface())
|
|
dst.TypeUrl = a.TypeUrl
|
|
dst.Value = a.Value
|
|
return nil
|
|
}
|
|
|
|
func (t transformer) transformList(fd protoreflect.FieldDescriptor, dst, src protoreflect.List) error {
|
|
for i, n := 0, src.Len(); i < n; i++ {
|
|
v := src.Get(i)
|
|
switch vv := v.Interface().(type) {
|
|
case protoreflect.Message:
|
|
nv := dst.NewElement()
|
|
err := t.transformMessage(nv.Message(), vv)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
dst.Append(nv)
|
|
default:
|
|
nv, err := t.callback(fd, v)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
dst.Append(nv)
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (t transformer) transformMap(fd protoreflect.FieldDescriptor, dst, src protoreflect.Map) error {
|
|
var err error
|
|
src.Range(func(k protoreflect.MapKey, v protoreflect.Value) bool {
|
|
switch vv := v.Interface().(type) {
|
|
case protoreflect.Message:
|
|
nv := dst.NewValue()
|
|
err := t.transformMessage(nv.Message(), vv)
|
|
if err != nil {
|
|
return false
|
|
}
|
|
dst.Set(k, nv)
|
|
default:
|
|
nv, err := t.callback(fd, v)
|
|
if err != nil {
|
|
return false
|
|
}
|
|
dst.Set(k, nv)
|
|
}
|
|
return true
|
|
})
|
|
return err
|
|
}
|
|
|
|
func (t transformer) transformMessage(dst, src protoreflect.Message) error {
|
|
// most of this code is based on
|
|
// https://github.com/protocolbuffers/protobuf-go/blob/v1.25.0/proto/merge.go
|
|
if srcAny, ok := src.Interface().(*anypb.Any); ok {
|
|
if dstAny, ok := dst.Interface().(*anypb.Any); ok {
|
|
return t.transformAny(dstAny, srcAny)
|
|
}
|
|
}
|
|
|
|
var err error
|
|
src.Range(func(fd protoreflect.FieldDescriptor, v protoreflect.Value) bool {
|
|
switch {
|
|
case fd.IsList():
|
|
err = t.transformList(fd, dst.Mutable(fd).List(), v.List())
|
|
if err != nil {
|
|
return false
|
|
}
|
|
case fd.IsMap():
|
|
err = t.transformMap(fd, dst.Mutable(fd).Map(), v.Map())
|
|
if err != nil {
|
|
return false
|
|
}
|
|
case fd.Message() != nil:
|
|
err = t.transformMessage(dst.Mutable(fd).Message(), v.Message())
|
|
if err != nil {
|
|
return false
|
|
}
|
|
default:
|
|
var nv protoreflect.Value
|
|
nv, err = t.callback(fd, v)
|
|
if err != nil {
|
|
return false
|
|
}
|
|
dst.Set(fd, nv)
|
|
}
|
|
return true
|
|
})
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if len(src.GetUnknown()) > 0 {
|
|
dst.SetUnknown(append(dst.GetUnknown(), src.GetUnknown()...))
|
|
}
|
|
return nil
|
|
}
|