diff --git a/pkg/protoutil/transform.go b/pkg/protoutil/transform.go new file mode 100644 index 000000000..597eea119 --- /dev/null +++ b/pkg/protoutil/transform.go @@ -0,0 +1,141 @@ +package protoutil + +import ( + "google.golang.org/protobuf/proto" + "google.golang.org/protobuf/reflect/protoreflect" + "google.golang.org/protobuf/types/known/anypb" +) + +// TransformFunc is a function that transforms a protobuf value into a new protobuf value. +type TransformFunc func(protoreflect.FieldDescriptor, protoreflect.Value) (protoreflect.Value, error) + +// Transform takes in a protobuf message and transforms any basic values with the given function. +func Transform(msg proto.Message, f TransformFunc) (proto.Message, error) { + t := transformer{callback: f} + src := msg.ProtoReflect() + dst := src.New() + err := t.transformMessage(dst, src) + if err != nil { + return nil, err + } + return dst.Interface(), nil +} + +type transformer struct { + callback TransformFunc +} + +func (t transformer) transformAny(dst, src *anypb.Any) error { + msg, err := src.UnmarshalNew() + if err != nil { + return err + } + + srcMsg := msg.ProtoReflect() + dstMsg := srcMsg.New() + + err = t.transformMessage(dstMsg, srcMsg) + if err != nil { + return err + } + + a, err := anypb.New(dstMsg.Interface()) + if err != nil { + return err + } + dst.TypeUrl = a.TypeUrl + dst.Value = a.Value + return nil +} + +func (t transformer) transformList(fd protoreflect.FieldDescriptor, dst, src protoreflect.List) error { + for i, n := 0, src.Len(); i < n; i++ { + v := src.Get(i) + switch vv := v.Interface().(type) { + case protoreflect.Message: + nv := dst.NewElement() + err := t.transformMessage(nv.Message(), vv) + if err != nil { + return err + } + dst.Append(nv) + default: + nv, err := t.callback(fd, v) + if err != nil { + return err + } + dst.Append(nv) + } + } + return nil +} + +func (t transformer) transformMap(fd protoreflect.FieldDescriptor, dst, src protoreflect.Map) error { + var err error + src.Range(func(k protoreflect.MapKey, v protoreflect.Value) bool { + switch vv := v.Interface().(type) { + case protoreflect.Message: + nv := dst.NewValue() + err := t.transformMessage(nv.Message(), vv) + if err != nil { + return false + } + dst.Set(k, nv) + default: + nv, err := t.callback(fd, v) + if err != nil { + return false + } + dst.Set(k, nv) + } + return true + }) + return err +} + +func (t transformer) transformMessage(dst, src protoreflect.Message) error { + // most of this code is based on + // https://github.com/protocolbuffers/protobuf-go/blob/v1.25.0/proto/merge.go + if srcAny, ok := src.Interface().(*anypb.Any); ok { + if dstAny, ok := dst.Interface().(*anypb.Any); ok { + return t.transformAny(dstAny, srcAny) + } + } + + var err error + src.Range(func(fd protoreflect.FieldDescriptor, v protoreflect.Value) bool { + switch { + case fd.IsList(): + err = t.transformList(fd, dst.Mutable(fd).List(), v.List()) + if err != nil { + return false + } + case fd.IsMap(): + err = t.transformMap(fd, dst.Mutable(fd).Map(), v.Map()) + if err != nil { + return false + } + case fd.Message() != nil: + err = t.transformMessage(dst.Mutable(fd).Message(), v.Message()) + if err != nil { + return false + } + default: + var nv protoreflect.Value + nv, err = t.callback(fd, v) + if err != nil { + return false + } + dst.Set(fd, nv) + } + return true + }) + if err != nil { + return err + } + + if len(src.GetUnknown()) > 0 { + dst.SetUnknown(append(dst.GetUnknown(), src.GetUnknown()...)) + } + return nil +} diff --git a/pkg/protoutil/transform_test.go b/pkg/protoutil/transform_test.go new file mode 100644 index 000000000..e8270e2ab --- /dev/null +++ b/pkg/protoutil/transform_test.go @@ -0,0 +1,134 @@ +package protoutil + +import ( + "testing" + "time" + + envoy_config_core_v3 "github.com/envoyproxy/go-control-plane/envoy/config/core/v3" + envoy_service_auth_v3 "github.com/envoyproxy/go-control-plane/envoy/service/auth/v3" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "google.golang.org/protobuf/reflect/protoreflect" + "google.golang.org/protobuf/types/known/timestamppb" +) + +func TestTransform(t *testing.T) { + t1 := time.Now() + original := &envoy_service_auth_v3.CheckRequest{ + Attributes: &envoy_service_auth_v3.AttributeContext{ + Source: &envoy_service_auth_v3.AttributeContext_Peer{ + Address: &envoy_config_core_v3.Address{ + Address: &envoy_config_core_v3.Address_SocketAddress{ + SocketAddress: &envoy_config_core_v3.SocketAddress{ + Protocol: envoy_config_core_v3.SocketAddress_TCP, + Address: "SOURCE", + PortSpecifier: &envoy_config_core_v3.SocketAddress_PortValue{ + PortValue: 1234, + }, + }, + }, + }, + Service: "SERVICE", + Labels: map[string]string{ + "LABEL_KEY": "LABEL_VALUE", + }, + Principal: "PRINCIPAL", + Certificate: "CERTIFICATE", + }, + Destination: &envoy_service_auth_v3.AttributeContext_Peer{ + Address: &envoy_config_core_v3.Address{ + Address: &envoy_config_core_v3.Address_SocketAddress{ + SocketAddress: &envoy_config_core_v3.SocketAddress{ + Protocol: envoy_config_core_v3.SocketAddress_TCP, + Address: "DESTINATION", + PortSpecifier: &envoy_config_core_v3.SocketAddress_PortValue{ + PortValue: 5678, + }, + }, + }, + }, + Service: "SERVICE", + Labels: map[string]string{ + "LABEL_KEY": "LABEL_VALUE", + }, + Principal: "PRINCIPAL", + Certificate: "CERTIFICATE", + }, + Request: &envoy_service_auth_v3.AttributeContext_Request{ + Time: timestamppb.New(t1), + Http: &envoy_service_auth_v3.AttributeContext_HttpRequest{ + Id: "REQUEST_ID", + Method: "METHOD", + Headers: map[string]string{ + "HEADER_KEY": "HEADER_VALUE", + }, + Path: "PATH", + Host: "HOST", + Scheme: "SCHEME", + Query: "QUERY", + Fragment: "FRAGMENT", + Size: 23, + Protocol: "PROTOCOL", + Body: "BODY", + RawBody: []byte("RAW_BODY"), + }, + }, + }, + } + transformed, err := Transform(original, func(fd protoreflect.FieldDescriptor, v protoreflect.Value) (protoreflect.Value, error) { + switch vv := v.Interface().(type) { + case []byte: + return protoreflect.ValueOfBytes(append([]byte("TRANSFORM_"), vv...)), nil + case string: + return protoreflect.ValueOfString("TRANSFORM_" + vv), nil + } + return v, nil + }) + require.NoError(t, err) + if msg, ok := transformed.(*envoy_service_auth_v3.CheckRequest); assert.True(t, ok) { + assert.Equal(t, "TRANSFORM_SOURCE", + msg.GetAttributes().GetSource().GetAddress().GetSocketAddress().GetAddress()) + assert.Equal(t, "TRANSFORM_SERVICE", + msg.GetAttributes().GetSource().GetService()) + assert.Equal(t, map[string]string{"LABEL_KEY": "TRANSFORM_LABEL_VALUE"}, + msg.GetAttributes().GetSource().GetLabels()) + assert.Equal(t, "TRANSFORM_PRINCIPAL", + msg.GetAttributes().GetSource().GetPrincipal()) + assert.Equal(t, "TRANSFORM_CERTIFICATE", + msg.GetAttributes().GetSource().GetCertificate()) + + assert.Equal(t, "TRANSFORM_DESTINATION", + msg.GetAttributes().GetDestination().GetAddress().GetSocketAddress().GetAddress()) + assert.Equal(t, "TRANSFORM_SERVICE", + msg.GetAttributes().GetDestination().GetService()) + assert.Equal(t, map[string]string{"LABEL_KEY": "TRANSFORM_LABEL_VALUE"}, + msg.GetAttributes().GetDestination().GetLabels()) + assert.Equal(t, "TRANSFORM_PRINCIPAL", + msg.GetAttributes().GetDestination().GetPrincipal()) + assert.Equal(t, "TRANSFORM_CERTIFICATE", + msg.GetAttributes().GetDestination().GetCertificate()) + + assert.Equal(t, "TRANSFORM_REQUEST_ID", + msg.GetAttributes().GetRequest().GetHttp().GetId()) + assert.Equal(t, "TRANSFORM_METHOD", + msg.GetAttributes().GetRequest().GetHttp().GetMethod()) + assert.Equal(t, map[string]string{"HEADER_KEY": "TRANSFORM_HEADER_VALUE"}, + msg.GetAttributes().GetRequest().GetHttp().GetHeaders()) + assert.Equal(t, "TRANSFORM_PATH", + msg.GetAttributes().GetRequest().GetHttp().GetPath()) + assert.Equal(t, "TRANSFORM_HOST", + msg.GetAttributes().GetRequest().GetHttp().GetHost()) + assert.Equal(t, "TRANSFORM_SCHEME", + msg.GetAttributes().GetRequest().GetHttp().GetScheme()) + assert.Equal(t, "TRANSFORM_QUERY", + msg.GetAttributes().GetRequest().GetHttp().GetQuery()) + assert.Equal(t, "TRANSFORM_FRAGMENT", + msg.GetAttributes().GetRequest().GetHttp().GetFragment()) + assert.Equal(t, "TRANSFORM_PROTOCOL", + msg.GetAttributes().GetRequest().GetHttp().GetProtocol()) + assert.Equal(t, "TRANSFORM_BODY", + msg.GetAttributes().GetRequest().GetHttp().GetBody()) + assert.Equal(t, []byte("TRANSFORM_RAW_BODY"), + msg.GetAttributes().GetRequest().GetHttp().GetRawBody()) + } +}