mirror of
https://github.com/pomerium/pomerium.git
synced 2025-06-01 10:22:43 +02:00
protoutil: add generic transformer (#2023)
This commit is contained in:
parent
dda6a9af60
commit
5a33012950
2 changed files with 275 additions and 0 deletions
141
pkg/protoutil/transform.go
Normal file
141
pkg/protoutil/transform.go
Normal file
|
@ -0,0 +1,141 @@
|
||||||
|
package protoutil
|
||||||
|
|
||||||
|
import (
|
||||||
|
"google.golang.org/protobuf/proto"
|
||||||
|
"google.golang.org/protobuf/reflect/protoreflect"
|
||||||
|
"google.golang.org/protobuf/types/known/anypb"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TransformFunc is a function that transforms a protobuf value into a new protobuf value.
|
||||||
|
type TransformFunc func(protoreflect.FieldDescriptor, protoreflect.Value) (protoreflect.Value, error)
|
||||||
|
|
||||||
|
// Transform takes in a protobuf message and transforms any basic values with the given function.
|
||||||
|
func Transform(msg proto.Message, f TransformFunc) (proto.Message, error) {
|
||||||
|
t := transformer{callback: f}
|
||||||
|
src := msg.ProtoReflect()
|
||||||
|
dst := src.New()
|
||||||
|
err := t.transformMessage(dst, src)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return dst.Interface(), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type transformer struct {
|
||||||
|
callback TransformFunc
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t transformer) transformAny(dst, src *anypb.Any) error {
|
||||||
|
msg, err := src.UnmarshalNew()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
srcMsg := msg.ProtoReflect()
|
||||||
|
dstMsg := srcMsg.New()
|
||||||
|
|
||||||
|
err = t.transformMessage(dstMsg, srcMsg)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
a, err := anypb.New(dstMsg.Interface())
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
dst.TypeUrl = a.TypeUrl
|
||||||
|
dst.Value = a.Value
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t transformer) transformList(fd protoreflect.FieldDescriptor, dst, src protoreflect.List) error {
|
||||||
|
for i, n := 0, src.Len(); i < n; i++ {
|
||||||
|
v := src.Get(i)
|
||||||
|
switch vv := v.Interface().(type) {
|
||||||
|
case protoreflect.Message:
|
||||||
|
nv := dst.NewElement()
|
||||||
|
err := t.transformMessage(nv.Message(), vv)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
dst.Append(nv)
|
||||||
|
default:
|
||||||
|
nv, err := t.callback(fd, v)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
dst.Append(nv)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t transformer) transformMap(fd protoreflect.FieldDescriptor, dst, src protoreflect.Map) error {
|
||||||
|
var err error
|
||||||
|
src.Range(func(k protoreflect.MapKey, v protoreflect.Value) bool {
|
||||||
|
switch vv := v.Interface().(type) {
|
||||||
|
case protoreflect.Message:
|
||||||
|
nv := dst.NewValue()
|
||||||
|
err := t.transformMessage(nv.Message(), vv)
|
||||||
|
if err != nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
dst.Set(k, nv)
|
||||||
|
default:
|
||||||
|
nv, err := t.callback(fd, v)
|
||||||
|
if err != nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
dst.Set(k, nv)
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
})
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t transformer) transformMessage(dst, src protoreflect.Message) error {
|
||||||
|
// most of this code is based on
|
||||||
|
// https://github.com/protocolbuffers/protobuf-go/blob/v1.25.0/proto/merge.go
|
||||||
|
if srcAny, ok := src.Interface().(*anypb.Any); ok {
|
||||||
|
if dstAny, ok := dst.Interface().(*anypb.Any); ok {
|
||||||
|
return t.transformAny(dstAny, srcAny)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var err error
|
||||||
|
src.Range(func(fd protoreflect.FieldDescriptor, v protoreflect.Value) bool {
|
||||||
|
switch {
|
||||||
|
case fd.IsList():
|
||||||
|
err = t.transformList(fd, dst.Mutable(fd).List(), v.List())
|
||||||
|
if err != nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
case fd.IsMap():
|
||||||
|
err = t.transformMap(fd, dst.Mutable(fd).Map(), v.Map())
|
||||||
|
if err != nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
case fd.Message() != nil:
|
||||||
|
err = t.transformMessage(dst.Mutable(fd).Message(), v.Message())
|
||||||
|
if err != nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
var nv protoreflect.Value
|
||||||
|
nv, err = t.callback(fd, v)
|
||||||
|
if err != nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
dst.Set(fd, nv)
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(src.GetUnknown()) > 0 {
|
||||||
|
dst.SetUnknown(append(dst.GetUnknown(), src.GetUnknown()...))
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
134
pkg/protoutil/transform_test.go
Normal file
134
pkg/protoutil/transform_test.go
Normal file
|
@ -0,0 +1,134 @@
|
||||||
|
package protoutil
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
envoy_config_core_v3 "github.com/envoyproxy/go-control-plane/envoy/config/core/v3"
|
||||||
|
envoy_service_auth_v3 "github.com/envoyproxy/go-control-plane/envoy/service/auth/v3"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
"google.golang.org/protobuf/reflect/protoreflect"
|
||||||
|
"google.golang.org/protobuf/types/known/timestamppb"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestTransform(t *testing.T) {
|
||||||
|
t1 := time.Now()
|
||||||
|
original := &envoy_service_auth_v3.CheckRequest{
|
||||||
|
Attributes: &envoy_service_auth_v3.AttributeContext{
|
||||||
|
Source: &envoy_service_auth_v3.AttributeContext_Peer{
|
||||||
|
Address: &envoy_config_core_v3.Address{
|
||||||
|
Address: &envoy_config_core_v3.Address_SocketAddress{
|
||||||
|
SocketAddress: &envoy_config_core_v3.SocketAddress{
|
||||||
|
Protocol: envoy_config_core_v3.SocketAddress_TCP,
|
||||||
|
Address: "SOURCE",
|
||||||
|
PortSpecifier: &envoy_config_core_v3.SocketAddress_PortValue{
|
||||||
|
PortValue: 1234,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Service: "SERVICE",
|
||||||
|
Labels: map[string]string{
|
||||||
|
"LABEL_KEY": "LABEL_VALUE",
|
||||||
|
},
|
||||||
|
Principal: "PRINCIPAL",
|
||||||
|
Certificate: "CERTIFICATE",
|
||||||
|
},
|
||||||
|
Destination: &envoy_service_auth_v3.AttributeContext_Peer{
|
||||||
|
Address: &envoy_config_core_v3.Address{
|
||||||
|
Address: &envoy_config_core_v3.Address_SocketAddress{
|
||||||
|
SocketAddress: &envoy_config_core_v3.SocketAddress{
|
||||||
|
Protocol: envoy_config_core_v3.SocketAddress_TCP,
|
||||||
|
Address: "DESTINATION",
|
||||||
|
PortSpecifier: &envoy_config_core_v3.SocketAddress_PortValue{
|
||||||
|
PortValue: 5678,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Service: "SERVICE",
|
||||||
|
Labels: map[string]string{
|
||||||
|
"LABEL_KEY": "LABEL_VALUE",
|
||||||
|
},
|
||||||
|
Principal: "PRINCIPAL",
|
||||||
|
Certificate: "CERTIFICATE",
|
||||||
|
},
|
||||||
|
Request: &envoy_service_auth_v3.AttributeContext_Request{
|
||||||
|
Time: timestamppb.New(t1),
|
||||||
|
Http: &envoy_service_auth_v3.AttributeContext_HttpRequest{
|
||||||
|
Id: "REQUEST_ID",
|
||||||
|
Method: "METHOD",
|
||||||
|
Headers: map[string]string{
|
||||||
|
"HEADER_KEY": "HEADER_VALUE",
|
||||||
|
},
|
||||||
|
Path: "PATH",
|
||||||
|
Host: "HOST",
|
||||||
|
Scheme: "SCHEME",
|
||||||
|
Query: "QUERY",
|
||||||
|
Fragment: "FRAGMENT",
|
||||||
|
Size: 23,
|
||||||
|
Protocol: "PROTOCOL",
|
||||||
|
Body: "BODY",
|
||||||
|
RawBody: []byte("RAW_BODY"),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
transformed, err := Transform(original, func(fd protoreflect.FieldDescriptor, v protoreflect.Value) (protoreflect.Value, error) {
|
||||||
|
switch vv := v.Interface().(type) {
|
||||||
|
case []byte:
|
||||||
|
return protoreflect.ValueOfBytes(append([]byte("TRANSFORM_"), vv...)), nil
|
||||||
|
case string:
|
||||||
|
return protoreflect.ValueOfString("TRANSFORM_" + vv), nil
|
||||||
|
}
|
||||||
|
return v, nil
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
if msg, ok := transformed.(*envoy_service_auth_v3.CheckRequest); assert.True(t, ok) {
|
||||||
|
assert.Equal(t, "TRANSFORM_SOURCE",
|
||||||
|
msg.GetAttributes().GetSource().GetAddress().GetSocketAddress().GetAddress())
|
||||||
|
assert.Equal(t, "TRANSFORM_SERVICE",
|
||||||
|
msg.GetAttributes().GetSource().GetService())
|
||||||
|
assert.Equal(t, map[string]string{"LABEL_KEY": "TRANSFORM_LABEL_VALUE"},
|
||||||
|
msg.GetAttributes().GetSource().GetLabels())
|
||||||
|
assert.Equal(t, "TRANSFORM_PRINCIPAL",
|
||||||
|
msg.GetAttributes().GetSource().GetPrincipal())
|
||||||
|
assert.Equal(t, "TRANSFORM_CERTIFICATE",
|
||||||
|
msg.GetAttributes().GetSource().GetCertificate())
|
||||||
|
|
||||||
|
assert.Equal(t, "TRANSFORM_DESTINATION",
|
||||||
|
msg.GetAttributes().GetDestination().GetAddress().GetSocketAddress().GetAddress())
|
||||||
|
assert.Equal(t, "TRANSFORM_SERVICE",
|
||||||
|
msg.GetAttributes().GetDestination().GetService())
|
||||||
|
assert.Equal(t, map[string]string{"LABEL_KEY": "TRANSFORM_LABEL_VALUE"},
|
||||||
|
msg.GetAttributes().GetDestination().GetLabels())
|
||||||
|
assert.Equal(t, "TRANSFORM_PRINCIPAL",
|
||||||
|
msg.GetAttributes().GetDestination().GetPrincipal())
|
||||||
|
assert.Equal(t, "TRANSFORM_CERTIFICATE",
|
||||||
|
msg.GetAttributes().GetDestination().GetCertificate())
|
||||||
|
|
||||||
|
assert.Equal(t, "TRANSFORM_REQUEST_ID",
|
||||||
|
msg.GetAttributes().GetRequest().GetHttp().GetId())
|
||||||
|
assert.Equal(t, "TRANSFORM_METHOD",
|
||||||
|
msg.GetAttributes().GetRequest().GetHttp().GetMethod())
|
||||||
|
assert.Equal(t, map[string]string{"HEADER_KEY": "TRANSFORM_HEADER_VALUE"},
|
||||||
|
msg.GetAttributes().GetRequest().GetHttp().GetHeaders())
|
||||||
|
assert.Equal(t, "TRANSFORM_PATH",
|
||||||
|
msg.GetAttributes().GetRequest().GetHttp().GetPath())
|
||||||
|
assert.Equal(t, "TRANSFORM_HOST",
|
||||||
|
msg.GetAttributes().GetRequest().GetHttp().GetHost())
|
||||||
|
assert.Equal(t, "TRANSFORM_SCHEME",
|
||||||
|
msg.GetAttributes().GetRequest().GetHttp().GetScheme())
|
||||||
|
assert.Equal(t, "TRANSFORM_QUERY",
|
||||||
|
msg.GetAttributes().GetRequest().GetHttp().GetQuery())
|
||||||
|
assert.Equal(t, "TRANSFORM_FRAGMENT",
|
||||||
|
msg.GetAttributes().GetRequest().GetHttp().GetFragment())
|
||||||
|
assert.Equal(t, "TRANSFORM_PROTOCOL",
|
||||||
|
msg.GetAttributes().GetRequest().GetHttp().GetProtocol())
|
||||||
|
assert.Equal(t, "TRANSFORM_BODY",
|
||||||
|
msg.GetAttributes().GetRequest().GetHttp().GetBody())
|
||||||
|
assert.Equal(t, []byte("TRANSFORM_RAW_BODY"),
|
||||||
|
msg.GetAttributes().GetRequest().GetHttp().GetRawBody())
|
||||||
|
}
|
||||||
|
}
|
Loading…
Add table
Add a link
Reference in a new issue