// Mirror of https://github.com/pomerium/pomerium.git
// (synced 2025-04-29 18:36:30 +02:00; 157 lines, 4.2 KiB, Go).

package logutil
|
|
|
|
import (
|
|
"crypto/hmac"
|
|
"crypto/sha256"
|
|
"encoding/base64"
|
|
|
|
"google.golang.org/protobuf/proto"
|
|
"google.golang.org/protobuf/reflect/protoreflect"
|
|
"google.golang.org/protobuf/types/known/anypb"
|
|
"google.golang.org/protobuf/types/known/wrapperspb"
|
|
|
|
"github.com/pomerium/pomerium/pkg/protoutil"
|
|
)
|
|
|
|
// A Scrubber scrubs potentially sensitive strings from protobuf messages.
//
// String and bytes values that are not whitelisted are replaced with an HMAC
// keyed by key; whitelisted fields are copied through untouched.
type Scrubber struct {
	// key is the secret used to key the HMAC applied to scrubbed values.
	key string
	// whitelist is a set of field names of the form
	// "<full message type name>.<field name>" whose values are not scrubbed.
	whitelist map[string]struct{}
}
|
|
|
|
// NewScrubber creates a new Scrubber.
|
|
func NewScrubber(key string) *Scrubber {
|
|
return &Scrubber{
|
|
key: key,
|
|
whitelist: map[string]struct{}{},
|
|
}
|
|
}
|
|
|
|
// Whitelist whitelists fields for a given type. The type name should be the full
|
|
// protobuf typename (ie google.protobuf.Any).
|
|
func (s *Scrubber) Whitelist(typeName string, fieldNames ...string) *Scrubber {
|
|
for _, fieldName := range fieldNames {
|
|
s.whitelist[typeName+"."+fieldName] = struct{}{}
|
|
}
|
|
return s
|
|
}
|
|
|
|
// ScrubProto takes in a protobuf message, clones it and scrubs any non-whitelisted strings.
|
|
func (s *Scrubber) ScrubProto(msg proto.Message) proto.Message {
|
|
src := msg.ProtoReflect()
|
|
dst := src.New()
|
|
|
|
s.scrubProtoMessage(dst, src)
|
|
|
|
return dst.Interface()
|
|
}
|
|
|
|
func (s *Scrubber) scrubProtoMessage(dst, src protoreflect.Message) {
|
|
// most of this code is based on
|
|
// https://github.com/protocolbuffers/protobuf-go/blob/v1.25.0/proto/merge.go
|
|
if srcany, ok := src.Interface().(*anypb.Any); ok {
|
|
if dstany, ok := dst.Interface().(*anypb.Any); ok {
|
|
s.scrubProtoAny(dstany, srcany)
|
|
return
|
|
}
|
|
}
|
|
|
|
src.Range(func(fd protoreflect.FieldDescriptor, v protoreflect.Value) bool {
|
|
// skip whitelisted fields
|
|
if _, ok := s.whitelist[string(fd.FullName())]; ok {
|
|
dst.Set(fd, v)
|
|
return true
|
|
}
|
|
|
|
switch {
|
|
case fd.IsList():
|
|
s.scrubProtoList(dst.Mutable(fd).List(), v.List(), fd)
|
|
case fd.IsMap():
|
|
s.scrubProtoMap(dst.Mutable(fd).Map(), v.Map(), fd.MapValue())
|
|
case fd.Message() != nil:
|
|
s.scrubProtoMessage(dst.Mutable(fd).Message(), v.Message())
|
|
case fd.Kind() == protoreflect.BytesKind:
|
|
nv := s.hmacBytes(v.Bytes())
|
|
dst.Set(fd, protoreflect.ValueOfBytes(nv))
|
|
case fd.Kind() == protoreflect.StringKind:
|
|
nv := s.hmacString(v.String())
|
|
dst.Set(fd, protoreflect.ValueOfString(nv))
|
|
default:
|
|
dst.Set(fd, v)
|
|
}
|
|
return true
|
|
})
|
|
|
|
if len(src.GetUnknown()) > 0 {
|
|
dst.SetUnknown(append(dst.GetUnknown(), src.GetUnknown()...))
|
|
}
|
|
}
|
|
|
|
func (s *Scrubber) scrubProtoAny(dst, src *anypb.Any) {
|
|
msg, err := src.UnmarshalNew()
|
|
if err != nil {
|
|
// this will happen if a type isn't registered.
|
|
// So we will just hash the raw data.
|
|
a := protoutil.NewAny(wrapperspb.Bytes(s.hmacBytes(src.Value)))
|
|
dst.TypeUrl = a.TypeUrl
|
|
dst.Value = a.Value
|
|
return
|
|
}
|
|
|
|
srcmsg := msg.ProtoReflect()
|
|
dstmsg := srcmsg.New()
|
|
|
|
s.scrubProtoMessage(dstmsg, srcmsg)
|
|
|
|
a := protoutil.NewAny(dstmsg.Interface())
|
|
dst.TypeUrl = a.TypeUrl
|
|
dst.Value = a.Value
|
|
}
|
|
|
|
func (s *Scrubber) scrubProtoList(dst, src protoreflect.List, fd protoreflect.FieldDescriptor) {
|
|
for i, n := 0, src.Len(); i < n; i++ {
|
|
switch v := src.Get(i); {
|
|
case fd.Message() != nil:
|
|
dstv := dst.NewElement()
|
|
s.scrubProtoMessage(dstv.Message(), v.Message())
|
|
dst.Append(dstv)
|
|
case fd.Kind() == protoreflect.BytesKind:
|
|
nv := s.hmacBytes(v.Bytes())
|
|
dst.Append(protoreflect.ValueOfBytes(nv))
|
|
case fd.Kind() == protoreflect.StringKind:
|
|
nv := s.hmacString(v.String())
|
|
dst.Append(protoreflect.ValueOfString(nv))
|
|
default:
|
|
dst.Append(v)
|
|
}
|
|
}
|
|
}
|
|
|
|
func (s *Scrubber) scrubProtoMap(dst, src protoreflect.Map, fd protoreflect.FieldDescriptor) {
|
|
src.Range(func(k protoreflect.MapKey, v protoreflect.Value) bool {
|
|
switch {
|
|
case fd.Message() != nil:
|
|
dstv := dst.NewValue()
|
|
s.scrubProtoMessage(dstv.Message(), v.Message())
|
|
dst.Set(k, dstv)
|
|
case fd.Kind() == protoreflect.BytesKind:
|
|
nv := s.hmacBytes(v.Bytes())
|
|
dst.Set(k, protoreflect.ValueOfBytes(nv))
|
|
case fd.Kind() == protoreflect.StringKind:
|
|
nv := s.hmacString(v.String())
|
|
dst.Set(k, protoreflect.ValueOfString(nv))
|
|
default:
|
|
dst.Set(k, v)
|
|
}
|
|
return true
|
|
})
|
|
}
|
|
|
|
func (s *Scrubber) hmacBytes(v []byte) []byte {
|
|
h := hmac.New(sha256.New, []byte(s.key))
|
|
return h.Sum(v)
|
|
}
|
|
|
|
func (s *Scrubber) hmacString(v string) string {
|
|
return base64.StdEncoding.EncodeToString(s.hmacBytes([]byte(v)))
|
|
}
|