pomerium/pkg/logutil/scrub.go

157 lines
4.2 KiB
Go

package logutil
import (
"crypto/hmac"
"crypto/sha256"
"encoding/base64"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/reflect/protoreflect"
"google.golang.org/protobuf/types/known/anypb"
"google.golang.org/protobuf/types/known/wrapperspb"
"github.com/pomerium/pomerium/pkg/protoutil"
)
// Scrubber scrubs potentially sensitive string and bytes values out of
// protobuf messages, while letting explicitly whitelisted fields through.
type Scrubber struct {
	// key is used as the HMAC key when hashing values.
	key string
	// whitelist is the set of fully-qualified field names, in the form
	// "<full type name>.<field name>", that are exempt from scrubbing.
	whitelist map[string]struct{}
}
// NewScrubber returns a Scrubber that hashes values using key. The returned
// Scrubber starts with an empty whitelist, so every string and bytes field is
// scrubbed until fields are added via Whitelist.
func NewScrubber(key string) *Scrubber {
	s := new(Scrubber)
	s.key = key
	s.whitelist = make(map[string]struct{})
	return s
}
// Whitelist exempts the given fields of a message type from scrubbing.
// typeName must be the full protobuf type name (e.g. google.protobuf.Any) and
// each entry of fieldNames the name of a field on that type; they are stored
// as "<typeName>.<fieldName>". The receiver is returned so calls can be chained.
func (s *Scrubber) Whitelist(typeName string, fieldNames ...string) *Scrubber {
	prefix := typeName + "."
	for i := range fieldNames {
		s.whitelist[prefix+fieldNames[i]] = struct{}{}
	}
	return s
}
// ScrubProto returns a scrubbed copy of msg. A new message of the same type is
// allocated and filled from msg, with every string and bytes value that is not
// whitelisted replaced by its scrubbed form.
func (s *Scrubber) ScrubProto(msg proto.Message) proto.Message {
	original := msg.ProtoReflect()
	scrubbed := original.New()
	s.scrubProtoMessage(scrubbed, original)
	return scrubbed.Interface()
}
// scrubProtoMessage fills dst from src, scrubbing as it goes. For every field
// that src.Range visits:
//   - whitelisted fields are set on dst with the source value as-is;
//   - lists, maps and nested messages are scrubbed recursively;
//   - bytes and string scalars are replaced by their hashed form;
//   - all other scalars are copied unchanged.
//
// Unknown fields of src are appended to dst unmodified. When both messages are
// google.protobuf.Any, the work is delegated to scrubProtoAny instead.
func (s *Scrubber) scrubProtoMessage(dst, src protoreflect.Message) {
	// The traversal is modelled on protobuf-go's merge implementation:
	// https://github.com/protocolbuffers/protobuf-go/blob/v1.25.0/proto/merge.go

	// An Any holds its payload as serialized bytes, so it is handled by a
	// dedicated routine rather than field by field.
	if anySrc, isAny := src.Interface().(*anypb.Any); isAny {
		if anyDst, isAny := dst.Interface().(*anypb.Any); isAny {
			s.scrubProtoAny(anyDst, anySrc)
			return
		}
	}

	scrubField := func(field protoreflect.FieldDescriptor, value protoreflect.Value) bool {
		// Whitelist keys are full field names, which is what Whitelist builds
		// from "<type name>.<field name>".
		if _, allowed := s.whitelist[string(field.FullName())]; allowed {
			dst.Set(field, value)
			return true
		}

		switch kind := field.Kind(); {
		case field.IsList():
			s.scrubProtoList(dst.Mutable(field).List(), value.List(), field)
		case field.IsMap():
			// Only the map's value descriptor is needed: keys are kept as-is.
			s.scrubProtoMap(dst.Mutable(field).Map(), value.Map(), field.MapValue())
		case field.Message() != nil:
			s.scrubProtoMessage(dst.Mutable(field).Message(), value.Message())
		case kind == protoreflect.BytesKind:
			dst.Set(field, protoreflect.ValueOfBytes(s.hmacBytes(value.Bytes())))
		case kind == protoreflect.StringKind:
			dst.Set(field, protoreflect.ValueOfString(s.hmacString(value.String())))
		default:
			dst.Set(field, value)
		}
		// Always continue so that every field is visited.
		return true
	}
	src.Range(scrubField)

	if unknown := src.GetUnknown(); len(unknown) > 0 {
		dst.SetUnknown(append(dst.GetUnknown(), unknown...))
	}
}
// scrubProtoAny writes a scrubbed version of src into dst. The payload of src
// is unpacked, scrubbed as a regular message and packed again. If it cannot be
// unpacked, dst instead carries a wrapped bytes value holding the hash of the
// raw payload.
func (s *Scrubber) scrubProtoAny(dst, src *anypb.Any) {
	unpacked, err := src.UnmarshalNew()
	if err != nil {
		// Unpacking fails, for example, when the payload's type is not
		// registered. Fall back to hashing the serialized bytes.
		hashed := wrapperspb.Bytes(s.hmacBytes(src.Value))
		fallback := protoutil.NewAny(hashed)
		dst.TypeUrl, dst.Value = fallback.TypeUrl, fallback.Value
		return
	}

	in := unpacked.ProtoReflect()
	out := in.New()
	s.scrubProtoMessage(out, in)

	packed := protoutil.NewAny(out.Interface())
	dst.TypeUrl, dst.Value = packed.TypeUrl, packed.Value
}
// scrubProtoList appends a scrubbed copy of each element of src to dst, in
// order. fd is the descriptor of the list field and determines how elements
// are treated: messages are scrubbed recursively, bytes and strings are
// hashed, and any other element kind is appended unchanged.
func (s *Scrubber) scrubProtoList(dst, src protoreflect.List, fd protoreflect.FieldDescriptor) {
	// The element type is the same for the whole list, so inspect it once.
	holdsMessages := fd.Message() != nil
	kind := fd.Kind()

	count := src.Len()
	for idx := 0; idx < count; idx++ {
		elem := src.Get(idx)
		switch {
		case holdsMessages:
			scrubbed := dst.NewElement()
			s.scrubProtoMessage(scrubbed.Message(), elem.Message())
			dst.Append(scrubbed)
		case kind == protoreflect.BytesKind:
			dst.Append(protoreflect.ValueOfBytes(s.hmacBytes(elem.Bytes())))
		case kind == protoreflect.StringKind:
			dst.Append(protoreflect.ValueOfString(s.hmacString(elem.String())))
		default:
			dst.Append(elem)
		}
	}
}
// scrubProtoMap stores a scrubbed copy of every entry of src in dst. fd is the
// descriptor used to classify the map's values: messages are scrubbed
// recursively, bytes and strings are hashed, and any other value kind is
// stored unchanged. Map keys are always kept as they are.
func (s *Scrubber) scrubProtoMap(dst, src protoreflect.Map, fd protoreflect.FieldDescriptor) {
	// The value type is the same for the whole map, so inspect it once.
	holdsMessages := fd.Message() != nil
	kind := fd.Kind()

	scrubEntry := func(key protoreflect.MapKey, value protoreflect.Value) bool {
		switch {
		case holdsMessages:
			scrubbed := dst.NewValue()
			s.scrubProtoMessage(scrubbed.Message(), value.Message())
			dst.Set(key, scrubbed)
		case kind == protoreflect.BytesKind:
			dst.Set(key, protoreflect.ValueOfBytes(s.hmacBytes(value.Bytes())))
		case kind == protoreflect.StringKind:
			dst.Set(key, protoreflect.ValueOfString(s.hmacString(value.String())))
		default:
			dst.Set(key, value)
		}
		// Always continue so that every entry is visited.
		return true
	}
	src.Range(scrubEntry)
}
// hmacBytes returns the HMAC-SHA256 digest of v, keyed with the scrubber's
// key. The result is always sha256.Size (32) bytes long and v is not modified.
//
// The input has to be fed to the hash with Write. hash.Hash.Sum(b) does not
// hash b; it appends the current digest to b. Passing v to Sum would therefore
// return the original bytes followed by the MAC of an empty message, leaking
// the very data that is supposed to be scrubbed, and could write into v's
// backing array.
func (s *Scrubber) hmacBytes(v []byte) []byte {
	h := hmac.New(sha256.New, []byte(s.key))
	// A hash.Hash's Write never returns an error.
	_, _ = h.Write(v)
	return h.Sum(nil)
}
// hmacString passes the bytes of v through hmacBytes and returns the result
// encoded with standard (padded) base64, so it can be stored in a string field.
func (s *Scrubber) hmacString(v string) string {
	digest := s.hmacBytes([]byte(v))
	return base64.StdEncoding.EncodeToString(digest)
}