envoy: use envoy request id for logging across systems with http and gRPC (#691)

This commit is contained in:
Caleb Doxsey 2020-05-12 06:55:55 -06:00 committed by Travis Groth
parent 593c47f8ac
commit 41855e5419
16 changed files with 228 additions and 253 deletions

View file

@ -6,15 +6,18 @@ import (
"time"
"github.com/gorilla/handlers"
"github.com/pomerium/pomerium/internal/frontend"
"github.com/pomerium/pomerium/internal/httputil"
"github.com/pomerium/pomerium/internal/log"
"github.com/pomerium/pomerium/internal/middleware"
"github.com/pomerium/pomerium/internal/telemetry/requestid"
"github.com/pomerium/pomerium/internal/version"
)
func (srv *Server) addHTTPMiddleware() {
root := srv.HTTPRouter
root.Use(requestid.HTTPMiddleware())
root.Use(log.NewHandler(log.Logger))
root.Use(log.AccessHandler(func(r *http.Request, status, size int, duration time.Duration) {
log.FromRequest(r).Debug().
@ -31,7 +34,7 @@ func (srv *Server) addHTTPMiddleware() {
root.Use(log.RemoteAddrHandler("ip"))
root.Use(log.UserAgentHandler("user_agent"))
root.Use(log.RefererHandler("referer"))
root.Use(log.RequestIDHandler("req_id", "Request-Id"))
root.Use(log.RequestIDHandler("request-id"))
root.Use(middleware.Healthcheck("/ping", version.UserAgent()))
root.HandleFunc("/healthz", httputil.HealthCheck)
root.HandleFunc("/ping", httputil.HealthCheck)

View file

@ -14,6 +14,7 @@ import (
"github.com/pomerium/pomerium/config"
"github.com/pomerium/pomerium/internal/log"
"github.com/pomerium/pomerium/internal/telemetry/requestid"
)
type versionedOptions struct {
@ -58,7 +59,10 @@ func NewServer() (*Server, error) {
if err != nil {
return nil, err
}
srv.GRPCServer = grpc.NewServer()
srv.GRPCServer = grpc.NewServer(
grpc.UnaryInterceptor(requestid.UnaryServerInterceptor()),
grpc.StreamInterceptor(requestid.StreamServerInterceptor()),
)
reflection.Register(srv.GRPCServer)
srv.registerXDSHandlers()
srv.registerAccessLogHandlers()

View file

@ -14,6 +14,7 @@ import (
"github.com/pomerium/pomerium/internal/log"
"github.com/pomerium/pomerium/internal/telemetry/metrics"
"github.com/pomerium/pomerium/internal/telemetry/requestid"
"go.opencensus.io/plugin/ocgrpc"
"google.golang.org/grpc"
@ -60,7 +61,12 @@ func NewGRPCClientConn(opts *Options) (*grpc.ClientConn, error) {
connAddr = fmt.Sprintf("%s:%d", connAddr, defaultGRPCPort)
}
dialOptions := []grpc.DialOption{
grpc.WithChainUnaryInterceptor(metrics.GRPCClientInterceptor(opts.ServiceName), grpcTimeoutInterceptor(opts.RequestTimeout)),
grpc.WithChainUnaryInterceptor(
requestid.UnaryClientInterceptor(),
metrics.GRPCClientInterceptor(opts.ServiceName),
grpcTimeoutInterceptor(opts.RequestTimeout),
),
grpc.WithStreamInterceptor(requestid.StreamClientInterceptor()),
grpc.WithStatsHandler(&ocgrpc.ClientHandler{}),
grpc.WithDefaultCallOptions([]grpc.CallOption{grpc.WaitForReady(true)}...),
}

View file

@ -1,109 +0,0 @@
package grpc
import (
"crypto/tls"
"errors"
"net"
"os"
"os/signal"
"sync"
"syscall"
"github.com/pomerium/pomerium/internal/log"
"github.com/pomerium/pomerium/internal/telemetry/metrics"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/keepalive"
)
// NewServer creates a new gRPC serve.
// It is the callers responsibility to close the resturned server.
func NewServer(opt *ServerOptions, registrationFn func(s *grpc.Server), wg *sync.WaitGroup) (*grpc.Server, error) {
if opt == nil {
opt = defaultServerOptions
} else {
opt.applyServerDefaults()
}
ln, err := net.Listen("tcp", opt.Addr)
if err != nil {
return nil, err
}
grpcOpts := []grpc.ServerOption{
grpc.StatsHandler(metrics.NewGRPCServerStatsHandler(opt.ServiceName)),
grpc.KeepaliveParams(opt.KeepaliveParams),
}
if len(opt.TLSCertificate) == 1 {
cert := credentials.NewServerTLSFromCert(&opt.TLSCertificate[0])
grpcOpts = append(grpcOpts, grpc.Creds(cert))
} else if !opt.InsecureServer {
return nil, errors.New("internal/grpc: unexpected number of certificates")
}
srv := grpc.NewServer(grpcOpts...)
registrationFn(srv)
log.Info().
Str("addr", opt.Addr).
Bool("insecure", opt.InsecureServer).
Str("service", opt.ServiceName).
Interface("grpc-service-info", srv.GetServiceInfo()).
Msg("internal/grpc: registered")
wg.Add(1)
go func() {
defer wg.Done()
if err := srv.Serve(ln); err != grpc.ErrServerStopped {
log.Error().Str("addr", opt.Addr).Err(err).Msg("internal/grpc: unexpected shutdown")
}
}()
return srv, nil
}
// ServerOptions contains the configurations settings for a gRPC server.
type ServerOptions struct {
// Addr specifies the host and port on which the server should serve
// gRPC requests. If empty, ":443" is used.
Addr string
// TLS certificates to use, if any.
TLSCertificate []tls.Certificate
// InsecureServer when enabled disables all transport security.
// In this mode, Pomerium is susceptible to man-in-the-middle attacks.
// This should be used only for testing.
InsecureServer bool
// KeepaliveParams sets GRPC keepalive.ServerParameters
KeepaliveParams keepalive.ServerParameters
// ServiceName specifies the service name for telemetry exposition
ServiceName string
}
var defaultServerOptions = &ServerOptions{
Addr: ":443",
}
func (o *ServerOptions) applyServerDefaults() {
if o.Addr == "" {
o.Addr = defaultServerOptions.Addr
}
}
// Shutdown attempts to shut down the server when a os interrupt or sigterm
// signal are received without interrupting any
// active connections. Shutdown stops the server from
// accepting new connections and RPCs and blocks until all the pending RPCs are
// finished.
func Shutdown(srv *grpc.Server) {
sigint := make(chan os.Signal, 1)
signal.Notify(sigint, os.Interrupt)
signal.Notify(sigint, syscall.SIGTERM)
rec := <-sigint
log.Info().Str("signal", rec.String()).Msg("internal/grpc: shutting down servers")
srv.GracefulStop()
log.Info().Str("signal", rec.String()).Msg("internal/grpc: shut down servers")
}

View file

@ -1,92 +0,0 @@
package grpc
import (
"crypto/tls"
"encoding/base64"
"os"
"os/signal"
"sync"
"syscall"
"testing"
"time"
"github.com/pomerium/pomerium/internal/cryptutil"
"google.golang.org/grpc"
"google.golang.org/grpc/keepalive"
)
const privKey = `-----BEGIN EC PRIVATE KEY-----
MHcCAQEEIMQiDy26/R4ca/OdnjIf8OEDeHcw8yB5SDV9FD500CW5oAoGCCqGSM49
AwEHoUQDQgAEFumdSrEe9dnPEUU3LuyC8l6MM6PefNgpSsRL4GrD22XITMjqDKFr
jqJTf0Fo1ZWm4v+Eds6s88rsLzEC+cKLRQ==
-----END EC PRIVATE KEY-----`
const pubKey = `-----BEGIN CERTIFICATE-----
MIIBeDCCAR+gAwIBAgIUUGE8w2S7XzpkVLbNq5QUxyVOwqEwCgYIKoZIzj0EAwIw
ETEPMA0GA1UEAwwGdW51c2VkMCAXDTE5MDcxNTIzNDQyOVoYDzQ3NTcwNjExMjM0
NDI5WjARMQ8wDQYDVQQDDAZ1bnVzZWQwWTATBgcqhkjOPQIBBggqhkjOPQMBBwNC
AAQW6Z1KsR712c8RRTcu7ILyXowzo9582ClKxEvgasPbZchMyOoMoWuOolN/QWjV
labi/4R2zqzzyuwvMQL5wotFo1MwUTAdBgNVHQ4EFgQURYdcaniRqBHXeaM79LtV
pyJ4EwAwHwYDVR0jBBgwFoAURYdcaniRqBHXeaM79LtVpyJ4EwAwDwYDVR0TAQH/
BAUwAwEB/zAKBggqhkjOPQQDAgNHADBEAiBHbhVnGbwXqaMZ1dB8eBAK56jyeWDZ
2PWXmFMTu7+RywIgaZ7UwVNB2k7KjEEBiLm0PIRcpJmczI2cP9+ZMIkPHHw=
-----END CERTIFICATE-----`
func TestNewServer(t *testing.T) {
// to make friendly to testing environments where 443 requires root
defaultServerOptions.Addr = ":0"
certb64, err := cryptutil.CertifcateFromBase64(
base64.StdEncoding.EncodeToString([]byte(pubKey)),
base64.StdEncoding.EncodeToString([]byte(privKey)))
if err != nil {
t.Fatal(err)
}
tests := []struct {
name string
opt *ServerOptions
registrationFn func(s *grpc.Server)
wg *sync.WaitGroup
wantNil bool
wantErr bool
}{
{"simple", &ServerOptions{Addr: ":0", InsecureServer: true}, func(s *grpc.Server) {}, &sync.WaitGroup{}, false, false},
{"simple keepalive options", &ServerOptions{Addr: ":0", InsecureServer: true, KeepaliveParams: keepalive.ServerParameters{MaxConnectionAge: 5 * time.Minute}}, func(s *grpc.Server) {}, &sync.WaitGroup{}, false, false},
{"bad tcp port", &ServerOptions{Addr: ":9999999"}, func(s *grpc.Server) {}, &sync.WaitGroup{}, true, true},
{"with cert", &ServerOptions{Addr: ":0", TLSCertificate: []tls.Certificate{*certb64}}, func(s *grpc.Server) {}, &sync.WaitGroup{}, false, false},
{"with multiple certs", &ServerOptions{Addr: ":0", TLSCertificate: []tls.Certificate{*certb64, *certb64}}, func(s *grpc.Server) {}, &sync.WaitGroup{}, true, true},
{"with no certs or insecure", &ServerOptions{Addr: ":0", TLSCertificate: []tls.Certificate{}}, func(s *grpc.Server) {}, &sync.WaitGroup{}, true, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := NewServer(tt.opt, tt.registrationFn, tt.wg)
if (err != nil) != tt.wantErr {
t.Errorf("NewServer() error = %v, wantErr %v", err, tt.wantErr)
return
}
if (got == nil) != tt.wantNil {
t.Errorf("NewServer() = %v, want %v", got, tt.wantNil)
}
if got != nil {
// simulate a sigterm and cleanup the server
c := make(chan os.Signal, 1)
signal.Notify(c, syscall.SIGINT)
defer signal.Stop(c)
go Shutdown(got)
syscall.Kill(syscall.Getpid(), syscall.SIGINT)
waitSig(t, c, syscall.SIGINT)
}
})
}
}
func waitSig(t *testing.T, c <-chan os.Signal, sig os.Signal) {
select {
case s := <-c:
if s != sig {
t.Fatalf("signal was %v, want %v", s, sig)
}
case <-time.After(1 * time.Second):
t.Fatalf("timeout waiting for %v", sig)
}
}

View file

@ -13,6 +13,8 @@ import (
"time"
"go.opencensus.io/plugin/ochttp"
"github.com/pomerium/pomerium/internal/telemetry/requestid"
)
// ErrTokenRevoked signifies a token revokation or expiration error
@ -22,7 +24,7 @@ var ErrTokenRevoked = errors.New("token expired or revoked")
var DefaultClient = &http.Client{
Timeout: 1 * time.Minute,
//todo(bdd): incorporate metrics.HTTPMetricsRoundTripper
Transport: &ochttp.Transport{},
Transport: requestid.NewRoundTripper(&ochttp.Transport{}),
}
// Client provides a simple helper interface to make HTTP requests

View file

@ -7,6 +7,7 @@ import (
"github.com/pomerium/pomerium/internal/frontend"
"github.com/pomerium/pomerium/internal/log"
"github.com/pomerium/pomerium/internal/telemetry/requestid"
"github.com/pomerium/pomerium/internal/urlutil"
"github.com/pomerium/pomerium/internal/version"
)
@ -65,10 +66,7 @@ func (e *HTTPError) ErrorResponse(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(e.Status)
log.FromRequest(r).Info().Err(e).Msg("httputil: ErrorResponse")
var requestID string
if id, ok := log.IDFromRequest(r); ok {
requestID = id
}
requestID := requestid.FromContext(r.Context())
response := errResponse{
Status: e.Status,
StatusText: http.StatusText(e.Status),

View file

@ -1,15 +1,16 @@
package log
import (
"context"
"crypto/rand"
"fmt"
"net"
"net/http"
"time"
"github.com/pomerium/pomerium/internal/middleware/responsewriter"
"github.com/rs/zerolog"
"github.com/pomerium/pomerium/internal/middleware/responsewriter"
"github.com/pomerium/pomerium/internal/telemetry/requestid"
)
// NewHandler injects log into requests context.
@ -115,45 +116,18 @@ func RefererHandler(fieldKey string) func(next http.Handler) http.Handler {
}
}
type idKey struct{}
// IDFromRequest returns the unique id associated to the request if any.
func IDFromRequest(r *http.Request) (id string, ok bool) {
if r == nil {
return
}
return IDFromCtx(r.Context())
}
// IDFromCtx returns the unique id associated to the context if any.
func IDFromCtx(ctx context.Context) (id string, ok bool) {
id, ok = ctx.Value(idKey{}).(string)
return
}
// RequestIDHandler returns a handler setting a unique id to the request which can
// be gathered using IDFromRequest(req). This generated id is added as a field to the
// logger using the passed fieldKey as field name. The id is also added as a response
// header if the headerName is not empty.
func RequestIDHandler(fieldKey, headerName string) func(next http.Handler) http.Handler {
// RequestIDHandler adds the request's id as a field to the context's logger
// using fieldKey as field key.
func RequestIDHandler(fieldKey string) func(next http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
id, ok := IDFromRequest(r)
if !ok {
id = uuid()
ctx = context.WithValue(ctx, idKey{}, id)
r = r.WithContext(ctx)
}
if fieldKey != "" {
log := zerolog.Ctx(ctx)
requestID := requestid.FromContext(r.Context())
if requestID != "" {
log := zerolog.Ctx(r.Context())
log.UpdateContext(func(c zerolog.Context) zerolog.Context {
return c.Str(fieldKey, id)
return c.Str(fieldKey, requestID)
})
}
if headerName != "" {
w.Header().Set(headerName, id)
}
next.ServeHTTP(w, r)
})
}

View file

@ -14,6 +14,8 @@ import (
"github.com/google/go-cmp/cmp"
"github.com/rs/zerolog"
"github.com/pomerium/pomerium/internal/telemetry/requestid"
)
func TestGenerateUUID(t *testing.T) {
@ -172,21 +174,19 @@ func TestRequestIDHandler(t *testing.T) {
out := &bytes.Buffer{}
r := &http.Request{
Header: http.Header{
"Referer": []string{"http://foo.com/bar"},
"X-Request-Id": []string{"1234"},
},
}
h := RequestIDHandler("id", "Request-Id")(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
id, ok := IDFromRequest(r)
if !ok {
t.Fatal("Missing id in request")
}
h := RequestIDHandler("request-id")(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
requestID := requestid.FromContext(r.Context())
l := FromRequest(r)
l.Log().Msg("")
if want, got := fmt.Sprintf(`{"id":"%s"}`+"\n", id), decodeIfBinary(out); want != got {
if want, got := fmt.Sprintf(`{"request-id":"%s"}`+"\n", requestID), decodeIfBinary(out); want != got {
t.Errorf("Invalid log output, got: %s, want: %s", got, want)
}
}))
h = NewHandler(zerolog.New(out))(h)
h = requestid.HTTPMiddleware()(h)
h.ServeHTTP(httptest.NewRecorder(), r)
}

View file

@ -0,0 +1,46 @@
package requestid
import (
"context"
"google.golang.org/grpc"
"google.golang.org/grpc/metadata"
)
// StreamClientInterceptor returns a new gRPC StreamClientInterceptor which puts the request ID in the outgoing
// metadata.
func StreamClientInterceptor() grpc.StreamClientInterceptor {
return func(ctx context.Context,
desc *grpc.StreamDesc, cc *grpc.ClientConn,
method string, streamer grpc.Streamer, opts ...grpc.CallOption,
) (grpc.ClientStream, error) {
toMetadata(ctx)
return streamer(ctx, desc, cc, method, opts...)
}
}
// UnaryClientInterceptor returns a new gRPC UnaryClientInterceptor which puts the request ID in the outgoing
// metadata.
func UnaryClientInterceptor() grpc.UnaryClientInterceptor {
return func(ctx context.Context,
method string, req, reply interface{},
cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption,
) error {
toMetadata(ctx)
return invoker(ctx, method, req, reply, cc, opts...)
}
}
func toMetadata(ctx context.Context) {
requestID := FromContext(ctx)
if requestID == "" {
requestID = New()
}
md, ok := metadata.FromOutgoingContext(ctx)
if !ok {
return
}
md.Set(headerName, requestID)
}

View file

@ -0,0 +1,56 @@
package requestid
import (
"context"
"google.golang.org/grpc"
"google.golang.org/grpc/metadata"
)
type grpcStream struct {
grpc.ServerStream
ctx context.Context
}
func (ss grpcStream) Context() context.Context {
return ss.ctx
}
// StreamServerInterceptor returns a new gRPC StreamServerInterceptor which populates the request id
// from the incoming metadata.
func StreamServerInterceptor() grpc.StreamServerInterceptor {
return func(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
ctx := ss.Context()
requestID := fromMetadata(ctx)
ctx = WithValue(ctx, requestID)
ss = grpcStream{
ServerStream: ss,
ctx: ctx,
}
return handler(srv, ss)
}
}
// UnaryServerInterceptor returns a new gRPC UnaryServerInterceptor which populates the request id
// from the incoming metadata.
func UnaryServerInterceptor() grpc.UnaryServerInterceptor {
return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp interface{}, err error) {
requestID := fromMetadata(ctx)
ctx = WithValue(ctx, requestID)
return handler(ctx, req)
}
}
func fromMetadata(ctx context.Context) string {
md, ok := metadata.FromIncomingContext(ctx)
if !ok {
return New()
}
headers := md.Get(headerName)
if len(headers) == 0 || headers[0] == "" {
return New()
}
return headers[0]
}

View file

@ -0,0 +1,48 @@
package requestid
import "net/http"
type transport struct {
base http.RoundTripper
}
// NewRoundTripper creates a new RoundTripper which adds the request id to the outgoing headers.
func NewRoundTripper(base http.RoundTripper) http.RoundTripper {
return &transport{base: base}
}
func (t *transport) RoundTrip(req *http.Request) (res *http.Response, err error) {
requestID := FromContext(req.Context())
if requestID != "" && req.Header.Get(headerName) == "" {
req.Header.Set(headerName, requestID)
}
return t.base.RoundTrip(req)
}
type httpMiddleware struct {
next http.Handler
}
// HTTPMiddleware creates a new http middleware that populates the request id.
func HTTPMiddleware() func(next http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return httpMiddleware{next: next}
}
}
func (h httpMiddleware) ServeHTTP(w http.ResponseWriter, r *http.Request) {
requestID := FromHTTPHeader(r.Header)
if requestID == "" {
requestID = New()
}
ctx := WithValue(r.Context(), requestID)
r = r.WithContext(ctx)
h.next.ServeHTTP(w, r)
}
// FromHTTPHeader returns the request id in the HTTP header. If no request id exists,
// an empty string is returned.
func FromHTTPHeader(hdr http.Header) string {
return hdr.Get(headerName)
}

View file

@ -0,0 +1,30 @@
// Package requestid has functions for working with x-request-id in http/gRPC requests.
package requestid
import (
"context"
shortuuid "github.com/lithammer/shortuuid/v3"
)
const headerName = "x-request-id"
var contextKey struct{}
// WithValue returns a new context from the parent context with a request id value set.
func WithValue(parent context.Context, requestID string) context.Context {
return context.WithValue(parent, contextKey, requestID)
}
// FromContext gets the request id from a context.
func FromContext(ctx context.Context) string {
if id, ok := ctx.Value(contextKey).(string); ok {
return id
}
return ""
}
// New creates a new request id.
func New() string {
return shortuuid.New()
}