mirror of
https://github.com/pomerium/pomerium.git
synced 2025-08-04 09:19:39 +02:00
envoy: use envoy request id for logging across systems with http and gRPC (#691)
This commit is contained in:
parent
593c47f8ac
commit
41855e5419
16 changed files with 228 additions and 253 deletions
|
@ -6,15 +6,18 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/gorilla/handlers"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/frontend"
|
||||
"github.com/pomerium/pomerium/internal/httputil"
|
||||
"github.com/pomerium/pomerium/internal/log"
|
||||
"github.com/pomerium/pomerium/internal/middleware"
|
||||
"github.com/pomerium/pomerium/internal/telemetry/requestid"
|
||||
"github.com/pomerium/pomerium/internal/version"
|
||||
)
|
||||
|
||||
func (srv *Server) addHTTPMiddleware() {
|
||||
root := srv.HTTPRouter
|
||||
root.Use(requestid.HTTPMiddleware())
|
||||
root.Use(log.NewHandler(log.Logger))
|
||||
root.Use(log.AccessHandler(func(r *http.Request, status, size int, duration time.Duration) {
|
||||
log.FromRequest(r).Debug().
|
||||
|
@ -31,7 +34,7 @@ func (srv *Server) addHTTPMiddleware() {
|
|||
root.Use(log.RemoteAddrHandler("ip"))
|
||||
root.Use(log.UserAgentHandler("user_agent"))
|
||||
root.Use(log.RefererHandler("referer"))
|
||||
root.Use(log.RequestIDHandler("req_id", "Request-Id"))
|
||||
root.Use(log.RequestIDHandler("request-id"))
|
||||
root.Use(middleware.Healthcheck("/ping", version.UserAgent()))
|
||||
root.HandleFunc("/healthz", httputil.HealthCheck)
|
||||
root.HandleFunc("/ping", httputil.HealthCheck)
|
||||
|
|
|
@ -14,6 +14,7 @@ import (
|
|||
|
||||
"github.com/pomerium/pomerium/config"
|
||||
"github.com/pomerium/pomerium/internal/log"
|
||||
"github.com/pomerium/pomerium/internal/telemetry/requestid"
|
||||
)
|
||||
|
||||
type versionedOptions struct {
|
||||
|
@ -58,7 +59,10 @@ func NewServer() (*Server, error) {
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
srv.GRPCServer = grpc.NewServer()
|
||||
srv.GRPCServer = grpc.NewServer(
|
||||
grpc.UnaryInterceptor(requestid.UnaryServerInterceptor()),
|
||||
grpc.StreamInterceptor(requestid.StreamServerInterceptor()),
|
||||
)
|
||||
reflection.Register(srv.GRPCServer)
|
||||
srv.registerXDSHandlers()
|
||||
srv.registerAccessLogHandlers()
|
||||
|
|
|
@ -14,6 +14,7 @@ import (
|
|||
|
||||
"github.com/pomerium/pomerium/internal/log"
|
||||
"github.com/pomerium/pomerium/internal/telemetry/metrics"
|
||||
"github.com/pomerium/pomerium/internal/telemetry/requestid"
|
||||
|
||||
"go.opencensus.io/plugin/ocgrpc"
|
||||
"google.golang.org/grpc"
|
||||
|
@ -60,7 +61,12 @@ func NewGRPCClientConn(opts *Options) (*grpc.ClientConn, error) {
|
|||
connAddr = fmt.Sprintf("%s:%d", connAddr, defaultGRPCPort)
|
||||
}
|
||||
dialOptions := []grpc.DialOption{
|
||||
grpc.WithChainUnaryInterceptor(metrics.GRPCClientInterceptor(opts.ServiceName), grpcTimeoutInterceptor(opts.RequestTimeout)),
|
||||
grpc.WithChainUnaryInterceptor(
|
||||
requestid.UnaryClientInterceptor(),
|
||||
metrics.GRPCClientInterceptor(opts.ServiceName),
|
||||
grpcTimeoutInterceptor(opts.RequestTimeout),
|
||||
),
|
||||
grpc.WithStreamInterceptor(requestid.StreamClientInterceptor()),
|
||||
grpc.WithStatsHandler(&ocgrpc.ClientHandler{}),
|
||||
grpc.WithDefaultCallOptions([]grpc.CallOption{grpc.WaitForReady(true)}...),
|
||||
}
|
||||
|
|
|
@ -1,109 +0,0 @@
|
|||
package grpc
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"net"
|
||||
"os"
|
||||
"os/signal"
|
||||
"sync"
|
||||
"syscall"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/log"
|
||||
"github.com/pomerium/pomerium/internal/telemetry/metrics"
|
||||
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/credentials"
|
||||
"google.golang.org/grpc/keepalive"
|
||||
)
|
||||
|
||||
// NewServer creates a new gRPC serve.
|
||||
// It is the callers responsibility to close the resturned server.
|
||||
func NewServer(opt *ServerOptions, registrationFn func(s *grpc.Server), wg *sync.WaitGroup) (*grpc.Server, error) {
|
||||
if opt == nil {
|
||||
opt = defaultServerOptions
|
||||
} else {
|
||||
opt.applyServerDefaults()
|
||||
}
|
||||
ln, err := net.Listen("tcp", opt.Addr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
grpcOpts := []grpc.ServerOption{
|
||||
grpc.StatsHandler(metrics.NewGRPCServerStatsHandler(opt.ServiceName)),
|
||||
grpc.KeepaliveParams(opt.KeepaliveParams),
|
||||
}
|
||||
|
||||
if len(opt.TLSCertificate) == 1 {
|
||||
cert := credentials.NewServerTLSFromCert(&opt.TLSCertificate[0])
|
||||
grpcOpts = append(grpcOpts, grpc.Creds(cert))
|
||||
} else if !opt.InsecureServer {
|
||||
return nil, errors.New("internal/grpc: unexpected number of certificates")
|
||||
}
|
||||
|
||||
srv := grpc.NewServer(grpcOpts...)
|
||||
registrationFn(srv)
|
||||
log.Info().
|
||||
Str("addr", opt.Addr).
|
||||
Bool("insecure", opt.InsecureServer).
|
||||
Str("service", opt.ServiceName).
|
||||
Interface("grpc-service-info", srv.GetServiceInfo()).
|
||||
Msg("internal/grpc: registered")
|
||||
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
if err := srv.Serve(ln); err != grpc.ErrServerStopped {
|
||||
log.Error().Str("addr", opt.Addr).Err(err).Msg("internal/grpc: unexpected shutdown")
|
||||
}
|
||||
}()
|
||||
|
||||
return srv, nil
|
||||
}
|
||||
|
||||
// ServerOptions contains the configurations settings for a gRPC server.
|
||||
type ServerOptions struct {
|
||||
// Addr specifies the host and port on which the server should serve
|
||||
// gRPC requests. If empty, ":443" is used.
|
||||
Addr string
|
||||
|
||||
// TLS certificates to use, if any.
|
||||
TLSCertificate []tls.Certificate
|
||||
|
||||
// InsecureServer when enabled disables all transport security.
|
||||
// In this mode, Pomerium is susceptible to man-in-the-middle attacks.
|
||||
// This should be used only for testing.
|
||||
InsecureServer bool
|
||||
|
||||
// KeepaliveParams sets GRPC keepalive.ServerParameters
|
||||
KeepaliveParams keepalive.ServerParameters
|
||||
|
||||
// ServiceName specifies the service name for telemetry exposition
|
||||
ServiceName string
|
||||
}
|
||||
|
||||
var defaultServerOptions = &ServerOptions{
|
||||
Addr: ":443",
|
||||
}
|
||||
|
||||
func (o *ServerOptions) applyServerDefaults() {
|
||||
if o.Addr == "" {
|
||||
o.Addr = defaultServerOptions.Addr
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
// Shutdown attempts to shut down the server when a os interrupt or sigterm
|
||||
// signal are received without interrupting any
|
||||
// active connections. Shutdown stops the server from
|
||||
// accepting new connections and RPCs and blocks until all the pending RPCs are
|
||||
// finished.
|
||||
func Shutdown(srv *grpc.Server) {
|
||||
sigint := make(chan os.Signal, 1)
|
||||
signal.Notify(sigint, os.Interrupt)
|
||||
signal.Notify(sigint, syscall.SIGTERM)
|
||||
rec := <-sigint
|
||||
log.Info().Str("signal", rec.String()).Msg("internal/grpc: shutting down servers")
|
||||
srv.GracefulStop()
|
||||
log.Info().Str("signal", rec.String()).Msg("internal/grpc: shut down servers")
|
||||
}
|
|
@ -1,92 +0,0 @@
|
|||
package grpc
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"encoding/base64"
|
||||
"os"
|
||||
"os/signal"
|
||||
"sync"
|
||||
"syscall"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/cryptutil"
|
||||
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/keepalive"
|
||||
)
|
||||
|
||||
const privKey = `-----BEGIN EC PRIVATE KEY-----
|
||||
MHcCAQEEIMQiDy26/R4ca/OdnjIf8OEDeHcw8yB5SDV9FD500CW5oAoGCCqGSM49
|
||||
AwEHoUQDQgAEFumdSrEe9dnPEUU3LuyC8l6MM6PefNgpSsRL4GrD22XITMjqDKFr
|
||||
jqJTf0Fo1ZWm4v+Eds6s88rsLzEC+cKLRQ==
|
||||
-----END EC PRIVATE KEY-----`
|
||||
const pubKey = `-----BEGIN CERTIFICATE-----
|
||||
MIIBeDCCAR+gAwIBAgIUUGE8w2S7XzpkVLbNq5QUxyVOwqEwCgYIKoZIzj0EAwIw
|
||||
ETEPMA0GA1UEAwwGdW51c2VkMCAXDTE5MDcxNTIzNDQyOVoYDzQ3NTcwNjExMjM0
|
||||
NDI5WjARMQ8wDQYDVQQDDAZ1bnVzZWQwWTATBgcqhkjOPQIBBggqhkjOPQMBBwNC
|
||||
AAQW6Z1KsR712c8RRTcu7ILyXowzo9582ClKxEvgasPbZchMyOoMoWuOolN/QWjV
|
||||
labi/4R2zqzzyuwvMQL5wotFo1MwUTAdBgNVHQ4EFgQURYdcaniRqBHXeaM79LtV
|
||||
pyJ4EwAwHwYDVR0jBBgwFoAURYdcaniRqBHXeaM79LtVpyJ4EwAwDwYDVR0TAQH/
|
||||
BAUwAwEB/zAKBggqhkjOPQQDAgNHADBEAiBHbhVnGbwXqaMZ1dB8eBAK56jyeWDZ
|
||||
2PWXmFMTu7+RywIgaZ7UwVNB2k7KjEEBiLm0PIRcpJmczI2cP9+ZMIkPHHw=
|
||||
-----END CERTIFICATE-----`
|
||||
|
||||
func TestNewServer(t *testing.T) {
|
||||
// to make friendly to testing environments where 443 requires root
|
||||
defaultServerOptions.Addr = ":0"
|
||||
certb64, err := cryptutil.CertifcateFromBase64(
|
||||
base64.StdEncoding.EncodeToString([]byte(pubKey)),
|
||||
base64.StdEncoding.EncodeToString([]byte(privKey)))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
opt *ServerOptions
|
||||
registrationFn func(s *grpc.Server)
|
||||
wg *sync.WaitGroup
|
||||
wantNil bool
|
||||
wantErr bool
|
||||
}{
|
||||
{"simple", &ServerOptions{Addr: ":0", InsecureServer: true}, func(s *grpc.Server) {}, &sync.WaitGroup{}, false, false},
|
||||
{"simple keepalive options", &ServerOptions{Addr: ":0", InsecureServer: true, KeepaliveParams: keepalive.ServerParameters{MaxConnectionAge: 5 * time.Minute}}, func(s *grpc.Server) {}, &sync.WaitGroup{}, false, false},
|
||||
{"bad tcp port", &ServerOptions{Addr: ":9999999"}, func(s *grpc.Server) {}, &sync.WaitGroup{}, true, true},
|
||||
{"with cert", &ServerOptions{Addr: ":0", TLSCertificate: []tls.Certificate{*certb64}}, func(s *grpc.Server) {}, &sync.WaitGroup{}, false, false},
|
||||
{"with multiple certs", &ServerOptions{Addr: ":0", TLSCertificate: []tls.Certificate{*certb64, *certb64}}, func(s *grpc.Server) {}, &sync.WaitGroup{}, true, true},
|
||||
{"with no certs or insecure", &ServerOptions{Addr: ":0", TLSCertificate: []tls.Certificate{}}, func(s *grpc.Server) {}, &sync.WaitGroup{}, true, true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := NewServer(tt.opt, tt.registrationFn, tt.wg)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("NewServer() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if (got == nil) != tt.wantNil {
|
||||
t.Errorf("NewServer() = %v, want %v", got, tt.wantNil)
|
||||
}
|
||||
if got != nil {
|
||||
// simulate a sigterm and cleanup the server
|
||||
c := make(chan os.Signal, 1)
|
||||
signal.Notify(c, syscall.SIGINT)
|
||||
defer signal.Stop(c)
|
||||
go Shutdown(got)
|
||||
syscall.Kill(syscall.Getpid(), syscall.SIGINT)
|
||||
waitSig(t, c, syscall.SIGINT)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func waitSig(t *testing.T, c <-chan os.Signal, sig os.Signal) {
|
||||
select {
|
||||
case s := <-c:
|
||||
if s != sig {
|
||||
t.Fatalf("signal was %v, want %v", s, sig)
|
||||
}
|
||||
case <-time.After(1 * time.Second):
|
||||
t.Fatalf("timeout waiting for %v", sig)
|
||||
}
|
||||
}
|
|
@ -13,6 +13,8 @@ import (
|
|||
"time"
|
||||
|
||||
"go.opencensus.io/plugin/ochttp"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/telemetry/requestid"
|
||||
)
|
||||
|
||||
// ErrTokenRevoked signifies a token revokation or expiration error
|
||||
|
@ -22,7 +24,7 @@ var ErrTokenRevoked = errors.New("token expired or revoked")
|
|||
var DefaultClient = &http.Client{
|
||||
Timeout: 1 * time.Minute,
|
||||
//todo(bdd): incorporate metrics.HTTPMetricsRoundTripper
|
||||
Transport: &ochttp.Transport{},
|
||||
Transport: requestid.NewRoundTripper(&ochttp.Transport{}),
|
||||
}
|
||||
|
||||
// Client provides a simple helper interface to make HTTP requests
|
||||
|
|
|
@ -7,6 +7,7 @@ import (
|
|||
|
||||
"github.com/pomerium/pomerium/internal/frontend"
|
||||
"github.com/pomerium/pomerium/internal/log"
|
||||
"github.com/pomerium/pomerium/internal/telemetry/requestid"
|
||||
"github.com/pomerium/pomerium/internal/urlutil"
|
||||
"github.com/pomerium/pomerium/internal/version"
|
||||
)
|
||||
|
@ -65,10 +66,7 @@ func (e *HTTPError) ErrorResponse(w http.ResponseWriter, r *http.Request) {
|
|||
w.WriteHeader(e.Status)
|
||||
|
||||
log.FromRequest(r).Info().Err(e).Msg("httputil: ErrorResponse")
|
||||
var requestID string
|
||||
if id, ok := log.IDFromRequest(r); ok {
|
||||
requestID = id
|
||||
}
|
||||
requestID := requestid.FromContext(r.Context())
|
||||
response := errResponse{
|
||||
Status: e.Status,
|
||||
StatusText: http.StatusText(e.Status),
|
||||
|
|
|
@ -1,15 +1,16 @@
|
|||
package log
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/middleware/responsewriter"
|
||||
"github.com/rs/zerolog"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/middleware/responsewriter"
|
||||
"github.com/pomerium/pomerium/internal/telemetry/requestid"
|
||||
)
|
||||
|
||||
// NewHandler injects log into requests context.
|
||||
|
@ -115,45 +116,18 @@ func RefererHandler(fieldKey string) func(next http.Handler) http.Handler {
|
|||
}
|
||||
}
|
||||
|
||||
type idKey struct{}
|
||||
|
||||
// IDFromRequest returns the unique id associated to the request if any.
|
||||
func IDFromRequest(r *http.Request) (id string, ok bool) {
|
||||
if r == nil {
|
||||
return
|
||||
}
|
||||
return IDFromCtx(r.Context())
|
||||
}
|
||||
|
||||
// IDFromCtx returns the unique id associated to the context if any.
|
||||
func IDFromCtx(ctx context.Context) (id string, ok bool) {
|
||||
id, ok = ctx.Value(idKey{}).(string)
|
||||
return
|
||||
}
|
||||
|
||||
// RequestIDHandler returns a handler setting a unique id to the request which can
|
||||
// be gathered using IDFromRequest(req). This generated id is added as a field to the
|
||||
// logger using the passed fieldKey as field name. The id is also added as a response
|
||||
// header if the headerName is not empty.
|
||||
func RequestIDHandler(fieldKey, headerName string) func(next http.Handler) http.Handler {
|
||||
// RequestIDHandler adds the request's id as a field to the context's logger
|
||||
// using fieldKey as field key.
|
||||
func RequestIDHandler(fieldKey string) func(next http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
id, ok := IDFromRequest(r)
|
||||
if !ok {
|
||||
id = uuid()
|
||||
ctx = context.WithValue(ctx, idKey{}, id)
|
||||
r = r.WithContext(ctx)
|
||||
}
|
||||
if fieldKey != "" {
|
||||
log := zerolog.Ctx(ctx)
|
||||
requestID := requestid.FromContext(r.Context())
|
||||
if requestID != "" {
|
||||
log := zerolog.Ctx(r.Context())
|
||||
log.UpdateContext(func(c zerolog.Context) zerolog.Context {
|
||||
return c.Str(fieldKey, id)
|
||||
return c.Str(fieldKey, requestID)
|
||||
})
|
||||
}
|
||||
if headerName != "" {
|
||||
w.Header().Set(headerName, id)
|
||||
}
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
|
|
|
@ -14,6 +14,8 @@ import (
|
|||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/rs/zerolog"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/telemetry/requestid"
|
||||
)
|
||||
|
||||
func TestGenerateUUID(t *testing.T) {
|
||||
|
@ -172,21 +174,19 @@ func TestRequestIDHandler(t *testing.T) {
|
|||
out := &bytes.Buffer{}
|
||||
r := &http.Request{
|
||||
Header: http.Header{
|
||||
"Referer": []string{"http://foo.com/bar"},
|
||||
"X-Request-Id": []string{"1234"},
|
||||
},
|
||||
}
|
||||
h := RequestIDHandler("id", "Request-Id")(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
id, ok := IDFromRequest(r)
|
||||
if !ok {
|
||||
t.Fatal("Missing id in request")
|
||||
}
|
||||
h := RequestIDHandler("request-id")(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
requestID := requestid.FromContext(r.Context())
|
||||
l := FromRequest(r)
|
||||
l.Log().Msg("")
|
||||
if want, got := fmt.Sprintf(`{"id":"%s"}`+"\n", id), decodeIfBinary(out); want != got {
|
||||
if want, got := fmt.Sprintf(`{"request-id":"%s"}`+"\n", requestID), decodeIfBinary(out); want != got {
|
||||
t.Errorf("Invalid log output, got: %s, want: %s", got, want)
|
||||
}
|
||||
}))
|
||||
h = NewHandler(zerolog.New(out))(h)
|
||||
h = requestid.HTTPMiddleware()(h)
|
||||
h.ServeHTTP(httptest.NewRecorder(), r)
|
||||
}
|
||||
|
||||
|
|
46
internal/telemetry/requestid/grpc_client.go
Normal file
46
internal/telemetry/requestid/grpc_client.go
Normal file
|
@ -0,0 +1,46 @@
|
|||
package requestid
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/metadata"
|
||||
)
|
||||
|
||||
// StreamClientInterceptor returns a new gRPC StreamClientInterceptor which puts the request ID in the outgoing
|
||||
// metadata.
|
||||
func StreamClientInterceptor() grpc.StreamClientInterceptor {
|
||||
return func(ctx context.Context,
|
||||
desc *grpc.StreamDesc, cc *grpc.ClientConn,
|
||||
method string, streamer grpc.Streamer, opts ...grpc.CallOption,
|
||||
) (grpc.ClientStream, error) {
|
||||
toMetadata(ctx)
|
||||
return streamer(ctx, desc, cc, method, opts...)
|
||||
}
|
||||
}
|
||||
|
||||
// UnaryClientInterceptor returns a new gRPC UnaryClientInterceptor which puts the request ID in the outgoing
|
||||
// metadata.
|
||||
func UnaryClientInterceptor() grpc.UnaryClientInterceptor {
|
||||
return func(ctx context.Context,
|
||||
method string, req, reply interface{},
|
||||
cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption,
|
||||
) error {
|
||||
toMetadata(ctx)
|
||||
return invoker(ctx, method, req, reply, cc, opts...)
|
||||
}
|
||||
}
|
||||
|
||||
func toMetadata(ctx context.Context) {
|
||||
requestID := FromContext(ctx)
|
||||
if requestID == "" {
|
||||
requestID = New()
|
||||
}
|
||||
|
||||
md, ok := metadata.FromOutgoingContext(ctx)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
md.Set(headerName, requestID)
|
||||
}
|
56
internal/telemetry/requestid/grpc_server.go
Normal file
56
internal/telemetry/requestid/grpc_server.go
Normal file
|
@ -0,0 +1,56 @@
|
|||
package requestid
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/metadata"
|
||||
)
|
||||
|
||||
type grpcStream struct {
|
||||
grpc.ServerStream
|
||||
ctx context.Context
|
||||
}
|
||||
|
||||
func (ss grpcStream) Context() context.Context {
|
||||
return ss.ctx
|
||||
}
|
||||
|
||||
// StreamServerInterceptor returns a new gRPC StreamServerInterceptor which populates the request id
|
||||
// from the incoming metadata.
|
||||
func StreamServerInterceptor() grpc.StreamServerInterceptor {
|
||||
return func(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
|
||||
ctx := ss.Context()
|
||||
requestID := fromMetadata(ctx)
|
||||
ctx = WithValue(ctx, requestID)
|
||||
ss = grpcStream{
|
||||
ServerStream: ss,
|
||||
ctx: ctx,
|
||||
}
|
||||
return handler(srv, ss)
|
||||
}
|
||||
}
|
||||
|
||||
// UnaryServerInterceptor returns a new gRPC UnaryServerInterceptor which populates the request id
|
||||
// from the incoming metadata.
|
||||
func UnaryServerInterceptor() grpc.UnaryServerInterceptor {
|
||||
return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp interface{}, err error) {
|
||||
requestID := fromMetadata(ctx)
|
||||
ctx = WithValue(ctx, requestID)
|
||||
return handler(ctx, req)
|
||||
}
|
||||
}
|
||||
|
||||
func fromMetadata(ctx context.Context) string {
|
||||
md, ok := metadata.FromIncomingContext(ctx)
|
||||
if !ok {
|
||||
return New()
|
||||
}
|
||||
|
||||
headers := md.Get(headerName)
|
||||
if len(headers) == 0 || headers[0] == "" {
|
||||
return New()
|
||||
}
|
||||
|
||||
return headers[0]
|
||||
}
|
48
internal/telemetry/requestid/http.go
Normal file
48
internal/telemetry/requestid/http.go
Normal file
|
@ -0,0 +1,48 @@
|
|||
package requestid
|
||||
|
||||
import "net/http"
|
||||
|
||||
type transport struct {
|
||||
base http.RoundTripper
|
||||
}
|
||||
|
||||
// NewRoundTripper creates a new RoundTripper which adds the request id to the outgoing headers.
|
||||
func NewRoundTripper(base http.RoundTripper) http.RoundTripper {
|
||||
return &transport{base: base}
|
||||
}
|
||||
|
||||
func (t *transport) RoundTrip(req *http.Request) (res *http.Response, err error) {
|
||||
requestID := FromContext(req.Context())
|
||||
if requestID != "" && req.Header.Get(headerName) == "" {
|
||||
req.Header.Set(headerName, requestID)
|
||||
}
|
||||
|
||||
return t.base.RoundTrip(req)
|
||||
}
|
||||
|
||||
type httpMiddleware struct {
|
||||
next http.Handler
|
||||
}
|
||||
|
||||
// HTTPMiddleware creates a new http middleware that populates the request id.
|
||||
func HTTPMiddleware() func(next http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return httpMiddleware{next: next}
|
||||
}
|
||||
}
|
||||
|
||||
func (h httpMiddleware) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
requestID := FromHTTPHeader(r.Header)
|
||||
if requestID == "" {
|
||||
requestID = New()
|
||||
}
|
||||
ctx := WithValue(r.Context(), requestID)
|
||||
r = r.WithContext(ctx)
|
||||
h.next.ServeHTTP(w, r)
|
||||
}
|
||||
|
||||
// FromHTTPHeader returns the request id in the HTTP header. If no request id exists,
|
||||
// an empty string is returned.
|
||||
func FromHTTPHeader(hdr http.Header) string {
|
||||
return hdr.Get(headerName)
|
||||
}
|
30
internal/telemetry/requestid/requestid.go
Normal file
30
internal/telemetry/requestid/requestid.go
Normal file
|
@ -0,0 +1,30 @@
|
|||
// Package requestid has functions for working with x-request-id in http/gRPC requests.
|
||||
package requestid
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
shortuuid "github.com/lithammer/shortuuid/v3"
|
||||
)
|
||||
|
||||
const headerName = "x-request-id"
|
||||
|
||||
var contextKey struct{}
|
||||
|
||||
// WithValue returns a new context from the parent context with a request id value set.
|
||||
func WithValue(parent context.Context, requestID string) context.Context {
|
||||
return context.WithValue(parent, contextKey, requestID)
|
||||
}
|
||||
|
||||
// FromContext gets the request id from a context.
|
||||
func FromContext(ctx context.Context) string {
|
||||
if id, ok := ctx.Value(contextKey).(string); ok {
|
||||
return id
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// New creates a new request id.
|
||||
func New() string {
|
||||
return shortuuid.New()
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue