package logutil import ( "crypto/hmac" "crypto/sha256" "encoding/base64" "google.golang.org/protobuf/proto" "google.golang.org/protobuf/reflect/protoreflect" "google.golang.org/protobuf/types/known/anypb" "google.golang.org/protobuf/types/known/wrapperspb" "github.com/pomerium/pomerium/pkg/protoutil" ) // A Scrubber scrubs potentially sensitive strings from protobuf messages. type Scrubber struct { key string whitelist map[string]struct{} } // NewScrubber creates a new Scrubber. func NewScrubber(key string) *Scrubber { return &Scrubber{ key: key, whitelist: map[string]struct{}{}, } } // Whitelist whitelists fields for a given type. The type name should be the full // protobuf typename (ie google.protobuf.Any). func (s *Scrubber) Whitelist(typeName string, fieldNames ...string) *Scrubber { for _, fieldName := range fieldNames { s.whitelist[typeName+"."+fieldName] = struct{}{} } return s } // ScrubProto takes in a protobuf message, clones it and scrubs any non-whitelisted strings. func (s *Scrubber) ScrubProto(msg proto.Message) proto.Message { src := msg.ProtoReflect() dst := src.New() s.scrubProtoMessage(dst, src) return dst.Interface() } func (s *Scrubber) scrubProtoMessage(dst, src protoreflect.Message) { // most of this code is based on // https://github.com/protocolbuffers/protobuf-go/blob/v1.25.0/proto/merge.go if srcany, ok := src.Interface().(*anypb.Any); ok { if dstany, ok := dst.Interface().(*anypb.Any); ok { s.scrubProtoAny(dstany, srcany) return } } src.Range(func(fd protoreflect.FieldDescriptor, v protoreflect.Value) bool { // skip whitelisted fields if _, ok := s.whitelist[string(fd.FullName())]; ok { dst.Set(fd, v) return true } switch { case fd.IsList(): s.scrubProtoList(dst.Mutable(fd).List(), v.List(), fd) case fd.IsMap(): s.scrubProtoMap(dst.Mutable(fd).Map(), v.Map(), fd.MapValue()) case fd.Message() != nil: s.scrubProtoMessage(dst.Mutable(fd).Message(), v.Message()) case fd.Kind() == protoreflect.BytesKind: nv := s.hmacBytes(v.Bytes()) dst.Set(fd, protoreflect.ValueOfBytes(nv)) case fd.Kind() == protoreflect.StringKind: nv := s.hmacString(v.String()) dst.Set(fd, protoreflect.ValueOfString(nv)) default: dst.Set(fd, v) } return true }) if len(src.GetUnknown()) > 0 { dst.SetUnknown(append(dst.GetUnknown(), src.GetUnknown()...)) } } func (s *Scrubber) scrubProtoAny(dst, src *anypb.Any) { msg, err := src.UnmarshalNew() if err != nil { // this will happen if a type isn't registered. // So we will just hash the raw data. a := protoutil.NewAny(wrapperspb.Bytes(s.hmacBytes(src.Value))) dst.TypeUrl = a.TypeUrl dst.Value = a.Value return } srcmsg := msg.ProtoReflect() dstmsg := srcmsg.New() s.scrubProtoMessage(dstmsg, srcmsg) a := protoutil.NewAny(dstmsg.Interface()) dst.TypeUrl = a.TypeUrl dst.Value = a.Value } func (s *Scrubber) scrubProtoList(dst, src protoreflect.List, fd protoreflect.FieldDescriptor) { for i, n := 0, src.Len(); i < n; i++ { switch v := src.Get(i); { case fd.Message() != nil: dstv := dst.NewElement() s.scrubProtoMessage(dstv.Message(), v.Message()) dst.Append(dstv) case fd.Kind() == protoreflect.BytesKind: nv := s.hmacBytes(v.Bytes()) dst.Append(protoreflect.ValueOfBytes(nv)) case fd.Kind() == protoreflect.StringKind: nv := s.hmacString(v.String()) dst.Append(protoreflect.ValueOfString(nv)) default: dst.Append(v) } } } func (s *Scrubber) scrubProtoMap(dst, src protoreflect.Map, fd protoreflect.FieldDescriptor) { src.Range(func(k protoreflect.MapKey, v protoreflect.Value) bool { switch { case fd.Message() != nil: dstv := dst.NewValue() s.scrubProtoMessage(dstv.Message(), v.Message()) dst.Set(k, dstv) case fd.Kind() == protoreflect.BytesKind: nv := s.hmacBytes(v.Bytes()) dst.Set(k, protoreflect.ValueOfBytes(nv)) case fd.Kind() == protoreflect.StringKind: nv := s.hmacString(v.String()) dst.Set(k, protoreflect.ValueOfString(nv)) default: dst.Set(k, v) } return true }) } func (s *Scrubber) hmacBytes(v []byte) []byte { h := hmac.New(sha256.New, []byte(s.key)) return h.Sum(v) } func (s *Scrubber) hmacString(v string) string { return base64.StdEncoding.EncodeToString(s.hmacBytes([]byte(v))) }