mirror of
https://github.com/pomerium/pomerium.git
synced 2025-08-06 10:21:05 +02:00
envoy: use envoy request id for logging across systems with http and gRPC (#691)
This commit is contained in:
parent
593c47f8ac
commit
41855e5419
16 changed files with 228 additions and 253 deletions
|
@ -14,6 +14,7 @@ import (
|
||||||
"github.com/pomerium/pomerium/internal/httputil"
|
"github.com/pomerium/pomerium/internal/httputil"
|
||||||
"github.com/pomerium/pomerium/internal/log"
|
"github.com/pomerium/pomerium/internal/log"
|
||||||
"github.com/pomerium/pomerium/internal/sessions"
|
"github.com/pomerium/pomerium/internal/sessions"
|
||||||
|
"github.com/pomerium/pomerium/internal/telemetry/requestid"
|
||||||
"github.com/pomerium/pomerium/internal/telemetry/trace"
|
"github.com/pomerium/pomerium/internal/telemetry/trace"
|
||||||
"github.com/pomerium/pomerium/internal/urlutil"
|
"github.com/pomerium/pomerium/internal/urlutil"
|
||||||
|
|
||||||
|
@ -74,7 +75,8 @@ func (a *Authorize) Check(ctx context.Context, in *envoy_service_auth_v2.CheckRe
|
||||||
|
|
||||||
evt := log.Info().Str("service", "authorize")
|
evt := log.Info().Str("service", "authorize")
|
||||||
// request
|
// request
|
||||||
evt = evt.Str("request-id", hattrs.GetId())
|
evt = evt.Str("request-id", requestid.FromContext(ctx))
|
||||||
|
evt = evt.Strs("check-request-id", hdrs["X-Request-Id"])
|
||||||
evt = evt.Str("method", hattrs.GetMethod())
|
evt = evt.Str("method", hattrs.GetMethod())
|
||||||
evt = evt.Interface("headers", hdrs)
|
evt = evt.Interface("headers", hdrs)
|
||||||
evt = evt.Str("path", hattrs.GetPath())
|
evt = evt.Str("path", hattrs.GetPath())
|
||||||
|
@ -100,6 +102,10 @@ func (a *Authorize) Check(ctx context.Context, in *envoy_service_auth_v2.CheckRe
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if reply.SessionExpired {
|
||||||
|
sesserr = sessions.ErrExpired
|
||||||
|
}
|
||||||
|
|
||||||
switch sesserr {
|
switch sesserr {
|
||||||
case sessions.ErrExpired, sessions.ErrIssuedInTheFuture, sessions.ErrMalformed, sessions.ErrNoSessionFound, sessions.ErrNotValidYet:
|
case sessions.ErrExpired, sessions.ErrIssuedInTheFuture, sessions.ErrMalformed, sessions.ErrNoSessionFound, sessions.ErrNotValidYet:
|
||||||
// redirect to login
|
// redirect to login
|
||||||
|
|
1
go.mod
1
go.mod
|
@ -22,6 +22,7 @@ require (
|
||||||
github.com/gorilla/mux v1.7.4
|
github.com/gorilla/mux v1.7.4
|
||||||
github.com/gorilla/websocket v1.4.0
|
github.com/gorilla/websocket v1.4.0
|
||||||
github.com/kardianos/osext v0.0.0-20190222173326-2bc1f35cddc0 // indirect
|
github.com/kardianos/osext v0.0.0-20190222173326-2bc1f35cddc0 // indirect
|
||||||
|
github.com/lithammer/shortuuid/v3 v3.0.4
|
||||||
github.com/mitchellh/hashstructure v1.0.0
|
github.com/mitchellh/hashstructure v1.0.0
|
||||||
github.com/natefinch/atomic v0.0.0-20150920032501-a62ce929ffcc
|
github.com/natefinch/atomic v0.0.0-20150920032501-a62ce929ffcc
|
||||||
github.com/onsi/ginkgo v1.11.0 // indirect
|
github.com/onsi/ginkgo v1.11.0 // indirect
|
||||||
|
|
2
go.sum
2
go.sum
|
@ -261,6 +261,8 @@ github.com/labbsr0x/bindman-dns-webhook v1.0.2/go.mod h1:p6b+VCXIR8NYKpDr8/dg1HK
|
||||||
github.com/labbsr0x/goh v1.0.1/go.mod h1:8K2UhVoaWXcCU7Lxoa2omWnC8gyW8px7/lmO61c027w=
|
github.com/labbsr0x/goh v1.0.1/go.mod h1:8K2UhVoaWXcCU7Lxoa2omWnC8gyW8px7/lmO61c027w=
|
||||||
github.com/linode/linodego v0.10.0/go.mod h1:cziNP7pbvE3mXIPneHj0oRY8L1WtGEIKlZ8LANE4eXA=
|
github.com/linode/linodego v0.10.0/go.mod h1:cziNP7pbvE3mXIPneHj0oRY8L1WtGEIKlZ8LANE4eXA=
|
||||||
github.com/liquidweb/liquidweb-go v1.6.0/go.mod h1:UDcVnAMDkZxpw4Y7NOHkqoeiGacVLEIG/i5J9cyixzQ=
|
github.com/liquidweb/liquidweb-go v1.6.0/go.mod h1:UDcVnAMDkZxpw4Y7NOHkqoeiGacVLEIG/i5J9cyixzQ=
|
||||||
|
github.com/lithammer/shortuuid/v3 v3.0.4 h1:uj4xhotfY92Y1Oa6n6HUiFn87CdoEHYUlTy0+IgbLrs=
|
||||||
|
github.com/lithammer/shortuuid/v3 v3.0.4/go.mod h1:RviRjexKqIzx/7r1peoAITm6m7gnif/h+0zmolKJjzw=
|
||||||
github.com/magiconair/properties v1.8.1 h1:ZC2Vc7/ZFkGmsVC9KvOjumD+G5lXy2RtTKyzRKO2BQ4=
|
github.com/magiconair/properties v1.8.1 h1:ZC2Vc7/ZFkGmsVC9KvOjumD+G5lXy2RtTKyzRKO2BQ4=
|
||||||
github.com/magiconair/properties v1.8.1/go.mod h1:PppfXfuXeibc/6YijjN8zIbojt8czPbwD3XqdrwzmxQ=
|
github.com/magiconair/properties v1.8.1/go.mod h1:PppfXfuXeibc/6YijjN8zIbojt8czPbwD3XqdrwzmxQ=
|
||||||
github.com/mattn/go-colorable v0.0.9/go.mod h1:9vuHe8Xs5qXnSaW/c/ABM9alt+Vo+STaOChaDxuIBZU=
|
github.com/mattn/go-colorable v0.0.9/go.mod h1:9vuHe8Xs5qXnSaW/c/ABM9alt+Vo+STaOChaDxuIBZU=
|
||||||
|
|
|
@ -6,15 +6,18 @@ import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/gorilla/handlers"
|
"github.com/gorilla/handlers"
|
||||||
|
|
||||||
"github.com/pomerium/pomerium/internal/frontend"
|
"github.com/pomerium/pomerium/internal/frontend"
|
||||||
"github.com/pomerium/pomerium/internal/httputil"
|
"github.com/pomerium/pomerium/internal/httputil"
|
||||||
"github.com/pomerium/pomerium/internal/log"
|
"github.com/pomerium/pomerium/internal/log"
|
||||||
"github.com/pomerium/pomerium/internal/middleware"
|
"github.com/pomerium/pomerium/internal/middleware"
|
||||||
|
"github.com/pomerium/pomerium/internal/telemetry/requestid"
|
||||||
"github.com/pomerium/pomerium/internal/version"
|
"github.com/pomerium/pomerium/internal/version"
|
||||||
)
|
)
|
||||||
|
|
||||||
func (srv *Server) addHTTPMiddleware() {
|
func (srv *Server) addHTTPMiddleware() {
|
||||||
root := srv.HTTPRouter
|
root := srv.HTTPRouter
|
||||||
|
root.Use(requestid.HTTPMiddleware())
|
||||||
root.Use(log.NewHandler(log.Logger))
|
root.Use(log.NewHandler(log.Logger))
|
||||||
root.Use(log.AccessHandler(func(r *http.Request, status, size int, duration time.Duration) {
|
root.Use(log.AccessHandler(func(r *http.Request, status, size int, duration time.Duration) {
|
||||||
log.FromRequest(r).Debug().
|
log.FromRequest(r).Debug().
|
||||||
|
@ -31,7 +34,7 @@ func (srv *Server) addHTTPMiddleware() {
|
||||||
root.Use(log.RemoteAddrHandler("ip"))
|
root.Use(log.RemoteAddrHandler("ip"))
|
||||||
root.Use(log.UserAgentHandler("user_agent"))
|
root.Use(log.UserAgentHandler("user_agent"))
|
||||||
root.Use(log.RefererHandler("referer"))
|
root.Use(log.RefererHandler("referer"))
|
||||||
root.Use(log.RequestIDHandler("req_id", "Request-Id"))
|
root.Use(log.RequestIDHandler("request-id"))
|
||||||
root.Use(middleware.Healthcheck("/ping", version.UserAgent()))
|
root.Use(middleware.Healthcheck("/ping", version.UserAgent()))
|
||||||
root.HandleFunc("/healthz", httputil.HealthCheck)
|
root.HandleFunc("/healthz", httputil.HealthCheck)
|
||||||
root.HandleFunc("/ping", httputil.HealthCheck)
|
root.HandleFunc("/ping", httputil.HealthCheck)
|
||||||
|
|
|
@ -14,6 +14,7 @@ import (
|
||||||
|
|
||||||
"github.com/pomerium/pomerium/config"
|
"github.com/pomerium/pomerium/config"
|
||||||
"github.com/pomerium/pomerium/internal/log"
|
"github.com/pomerium/pomerium/internal/log"
|
||||||
|
"github.com/pomerium/pomerium/internal/telemetry/requestid"
|
||||||
)
|
)
|
||||||
|
|
||||||
type versionedOptions struct {
|
type versionedOptions struct {
|
||||||
|
@ -58,7 +59,10 @@ func NewServer() (*Server, error) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
srv.GRPCServer = grpc.NewServer()
|
srv.GRPCServer = grpc.NewServer(
|
||||||
|
grpc.UnaryInterceptor(requestid.UnaryServerInterceptor()),
|
||||||
|
grpc.StreamInterceptor(requestid.StreamServerInterceptor()),
|
||||||
|
)
|
||||||
reflection.Register(srv.GRPCServer)
|
reflection.Register(srv.GRPCServer)
|
||||||
srv.registerXDSHandlers()
|
srv.registerXDSHandlers()
|
||||||
srv.registerAccessLogHandlers()
|
srv.registerAccessLogHandlers()
|
||||||
|
|
|
@ -14,6 +14,7 @@ import (
|
||||||
|
|
||||||
"github.com/pomerium/pomerium/internal/log"
|
"github.com/pomerium/pomerium/internal/log"
|
||||||
"github.com/pomerium/pomerium/internal/telemetry/metrics"
|
"github.com/pomerium/pomerium/internal/telemetry/metrics"
|
||||||
|
"github.com/pomerium/pomerium/internal/telemetry/requestid"
|
||||||
|
|
||||||
"go.opencensus.io/plugin/ocgrpc"
|
"go.opencensus.io/plugin/ocgrpc"
|
||||||
"google.golang.org/grpc"
|
"google.golang.org/grpc"
|
||||||
|
@ -60,7 +61,12 @@ func NewGRPCClientConn(opts *Options) (*grpc.ClientConn, error) {
|
||||||
connAddr = fmt.Sprintf("%s:%d", connAddr, defaultGRPCPort)
|
connAddr = fmt.Sprintf("%s:%d", connAddr, defaultGRPCPort)
|
||||||
}
|
}
|
||||||
dialOptions := []grpc.DialOption{
|
dialOptions := []grpc.DialOption{
|
||||||
grpc.WithChainUnaryInterceptor(metrics.GRPCClientInterceptor(opts.ServiceName), grpcTimeoutInterceptor(opts.RequestTimeout)),
|
grpc.WithChainUnaryInterceptor(
|
||||||
|
requestid.UnaryClientInterceptor(),
|
||||||
|
metrics.GRPCClientInterceptor(opts.ServiceName),
|
||||||
|
grpcTimeoutInterceptor(opts.RequestTimeout),
|
||||||
|
),
|
||||||
|
grpc.WithStreamInterceptor(requestid.StreamClientInterceptor()),
|
||||||
grpc.WithStatsHandler(&ocgrpc.ClientHandler{}),
|
grpc.WithStatsHandler(&ocgrpc.ClientHandler{}),
|
||||||
grpc.WithDefaultCallOptions([]grpc.CallOption{grpc.WaitForReady(true)}...),
|
grpc.WithDefaultCallOptions([]grpc.CallOption{grpc.WaitForReady(true)}...),
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,109 +0,0 @@
|
||||||
package grpc
|
|
||||||
|
|
||||||
import (
|
|
||||||
"crypto/tls"
|
|
||||||
"errors"
|
|
||||||
"net"
|
|
||||||
"os"
|
|
||||||
"os/signal"
|
|
||||||
"sync"
|
|
||||||
"syscall"
|
|
||||||
|
|
||||||
"github.com/pomerium/pomerium/internal/log"
|
|
||||||
"github.com/pomerium/pomerium/internal/telemetry/metrics"
|
|
||||||
|
|
||||||
"google.golang.org/grpc"
|
|
||||||
"google.golang.org/grpc/credentials"
|
|
||||||
"google.golang.org/grpc/keepalive"
|
|
||||||
)
|
|
||||||
|
|
||||||
// NewServer creates a new gRPC serve.
|
|
||||||
// It is the callers responsibility to close the resturned server.
|
|
||||||
func NewServer(opt *ServerOptions, registrationFn func(s *grpc.Server), wg *sync.WaitGroup) (*grpc.Server, error) {
|
|
||||||
if opt == nil {
|
|
||||||
opt = defaultServerOptions
|
|
||||||
} else {
|
|
||||||
opt.applyServerDefaults()
|
|
||||||
}
|
|
||||||
ln, err := net.Listen("tcp", opt.Addr)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
grpcOpts := []grpc.ServerOption{
|
|
||||||
grpc.StatsHandler(metrics.NewGRPCServerStatsHandler(opt.ServiceName)),
|
|
||||||
grpc.KeepaliveParams(opt.KeepaliveParams),
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(opt.TLSCertificate) == 1 {
|
|
||||||
cert := credentials.NewServerTLSFromCert(&opt.TLSCertificate[0])
|
|
||||||
grpcOpts = append(grpcOpts, grpc.Creds(cert))
|
|
||||||
} else if !opt.InsecureServer {
|
|
||||||
return nil, errors.New("internal/grpc: unexpected number of certificates")
|
|
||||||
}
|
|
||||||
|
|
||||||
srv := grpc.NewServer(grpcOpts...)
|
|
||||||
registrationFn(srv)
|
|
||||||
log.Info().
|
|
||||||
Str("addr", opt.Addr).
|
|
||||||
Bool("insecure", opt.InsecureServer).
|
|
||||||
Str("service", opt.ServiceName).
|
|
||||||
Interface("grpc-service-info", srv.GetServiceInfo()).
|
|
||||||
Msg("internal/grpc: registered")
|
|
||||||
|
|
||||||
wg.Add(1)
|
|
||||||
go func() {
|
|
||||||
defer wg.Done()
|
|
||||||
if err := srv.Serve(ln); err != grpc.ErrServerStopped {
|
|
||||||
log.Error().Str("addr", opt.Addr).Err(err).Msg("internal/grpc: unexpected shutdown")
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
return srv, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// ServerOptions contains the configurations settings for a gRPC server.
|
|
||||||
type ServerOptions struct {
|
|
||||||
// Addr specifies the host and port on which the server should serve
|
|
||||||
// gRPC requests. If empty, ":443" is used.
|
|
||||||
Addr string
|
|
||||||
|
|
||||||
// TLS certificates to use, if any.
|
|
||||||
TLSCertificate []tls.Certificate
|
|
||||||
|
|
||||||
// InsecureServer when enabled disables all transport security.
|
|
||||||
// In this mode, Pomerium is susceptible to man-in-the-middle attacks.
|
|
||||||
// This should be used only for testing.
|
|
||||||
InsecureServer bool
|
|
||||||
|
|
||||||
// KeepaliveParams sets GRPC keepalive.ServerParameters
|
|
||||||
KeepaliveParams keepalive.ServerParameters
|
|
||||||
|
|
||||||
// ServiceName specifies the service name for telemetry exposition
|
|
||||||
ServiceName string
|
|
||||||
}
|
|
||||||
|
|
||||||
var defaultServerOptions = &ServerOptions{
|
|
||||||
Addr: ":443",
|
|
||||||
}
|
|
||||||
|
|
||||||
func (o *ServerOptions) applyServerDefaults() {
|
|
||||||
if o.Addr == "" {
|
|
||||||
o.Addr = defaultServerOptions.Addr
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
// Shutdown attempts to shut down the server when a os interrupt or sigterm
|
|
||||||
// signal are received without interrupting any
|
|
||||||
// active connections. Shutdown stops the server from
|
|
||||||
// accepting new connections and RPCs and blocks until all the pending RPCs are
|
|
||||||
// finished.
|
|
||||||
func Shutdown(srv *grpc.Server) {
|
|
||||||
sigint := make(chan os.Signal, 1)
|
|
||||||
signal.Notify(sigint, os.Interrupt)
|
|
||||||
signal.Notify(sigint, syscall.SIGTERM)
|
|
||||||
rec := <-sigint
|
|
||||||
log.Info().Str("signal", rec.String()).Msg("internal/grpc: shutting down servers")
|
|
||||||
srv.GracefulStop()
|
|
||||||
log.Info().Str("signal", rec.String()).Msg("internal/grpc: shut down servers")
|
|
||||||
}
|
|
|
@ -1,92 +0,0 @@
|
||||||
package grpc
|
|
||||||
|
|
||||||
import (
|
|
||||||
"crypto/tls"
|
|
||||||
"encoding/base64"
|
|
||||||
"os"
|
|
||||||
"os/signal"
|
|
||||||
"sync"
|
|
||||||
"syscall"
|
|
||||||
"testing"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/pomerium/pomerium/internal/cryptutil"
|
|
||||||
|
|
||||||
"google.golang.org/grpc"
|
|
||||||
"google.golang.org/grpc/keepalive"
|
|
||||||
)
|
|
||||||
|
|
||||||
const privKey = `-----BEGIN EC PRIVATE KEY-----
|
|
||||||
MHcCAQEEIMQiDy26/R4ca/OdnjIf8OEDeHcw8yB5SDV9FD500CW5oAoGCCqGSM49
|
|
||||||
AwEHoUQDQgAEFumdSrEe9dnPEUU3LuyC8l6MM6PefNgpSsRL4GrD22XITMjqDKFr
|
|
||||||
jqJTf0Fo1ZWm4v+Eds6s88rsLzEC+cKLRQ==
|
|
||||||
-----END EC PRIVATE KEY-----`
|
|
||||||
const pubKey = `-----BEGIN CERTIFICATE-----
|
|
||||||
MIIBeDCCAR+gAwIBAgIUUGE8w2S7XzpkVLbNq5QUxyVOwqEwCgYIKoZIzj0EAwIw
|
|
||||||
ETEPMA0GA1UEAwwGdW51c2VkMCAXDTE5MDcxNTIzNDQyOVoYDzQ3NTcwNjExMjM0
|
|
||||||
NDI5WjARMQ8wDQYDVQQDDAZ1bnVzZWQwWTATBgcqhkjOPQIBBggqhkjOPQMBBwNC
|
|
||||||
AAQW6Z1KsR712c8RRTcu7ILyXowzo9582ClKxEvgasPbZchMyOoMoWuOolN/QWjV
|
|
||||||
labi/4R2zqzzyuwvMQL5wotFo1MwUTAdBgNVHQ4EFgQURYdcaniRqBHXeaM79LtV
|
|
||||||
pyJ4EwAwHwYDVR0jBBgwFoAURYdcaniRqBHXeaM79LtVpyJ4EwAwDwYDVR0TAQH/
|
|
||||||
BAUwAwEB/zAKBggqhkjOPQQDAgNHADBEAiBHbhVnGbwXqaMZ1dB8eBAK56jyeWDZ
|
|
||||||
2PWXmFMTu7+RywIgaZ7UwVNB2k7KjEEBiLm0PIRcpJmczI2cP9+ZMIkPHHw=
|
|
||||||
-----END CERTIFICATE-----`
|
|
||||||
|
|
||||||
func TestNewServer(t *testing.T) {
|
|
||||||
// to make friendly to testing environments where 443 requires root
|
|
||||||
defaultServerOptions.Addr = ":0"
|
|
||||||
certb64, err := cryptutil.CertifcateFromBase64(
|
|
||||||
base64.StdEncoding.EncodeToString([]byte(pubKey)),
|
|
||||||
base64.StdEncoding.EncodeToString([]byte(privKey)))
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
opt *ServerOptions
|
|
||||||
registrationFn func(s *grpc.Server)
|
|
||||||
wg *sync.WaitGroup
|
|
||||||
wantNil bool
|
|
||||||
wantErr bool
|
|
||||||
}{
|
|
||||||
{"simple", &ServerOptions{Addr: ":0", InsecureServer: true}, func(s *grpc.Server) {}, &sync.WaitGroup{}, false, false},
|
|
||||||
{"simple keepalive options", &ServerOptions{Addr: ":0", InsecureServer: true, KeepaliveParams: keepalive.ServerParameters{MaxConnectionAge: 5 * time.Minute}}, func(s *grpc.Server) {}, &sync.WaitGroup{}, false, false},
|
|
||||||
{"bad tcp port", &ServerOptions{Addr: ":9999999"}, func(s *grpc.Server) {}, &sync.WaitGroup{}, true, true},
|
|
||||||
{"with cert", &ServerOptions{Addr: ":0", TLSCertificate: []tls.Certificate{*certb64}}, func(s *grpc.Server) {}, &sync.WaitGroup{}, false, false},
|
|
||||||
{"with multiple certs", &ServerOptions{Addr: ":0", TLSCertificate: []tls.Certificate{*certb64, *certb64}}, func(s *grpc.Server) {}, &sync.WaitGroup{}, true, true},
|
|
||||||
{"with no certs or insecure", &ServerOptions{Addr: ":0", TLSCertificate: []tls.Certificate{}}, func(s *grpc.Server) {}, &sync.WaitGroup{}, true, true},
|
|
||||||
}
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
got, err := NewServer(tt.opt, tt.registrationFn, tt.wg)
|
|
||||||
if (err != nil) != tt.wantErr {
|
|
||||||
t.Errorf("NewServer() error = %v, wantErr %v", err, tt.wantErr)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if (got == nil) != tt.wantNil {
|
|
||||||
t.Errorf("NewServer() = %v, want %v", got, tt.wantNil)
|
|
||||||
}
|
|
||||||
if got != nil {
|
|
||||||
// simulate a sigterm and cleanup the server
|
|
||||||
c := make(chan os.Signal, 1)
|
|
||||||
signal.Notify(c, syscall.SIGINT)
|
|
||||||
defer signal.Stop(c)
|
|
||||||
go Shutdown(got)
|
|
||||||
syscall.Kill(syscall.Getpid(), syscall.SIGINT)
|
|
||||||
waitSig(t, c, syscall.SIGINT)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func waitSig(t *testing.T, c <-chan os.Signal, sig os.Signal) {
|
|
||||||
select {
|
|
||||||
case s := <-c:
|
|
||||||
if s != sig {
|
|
||||||
t.Fatalf("signal was %v, want %v", s, sig)
|
|
||||||
}
|
|
||||||
case <-time.After(1 * time.Second):
|
|
||||||
t.Fatalf("timeout waiting for %v", sig)
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -13,6 +13,8 @@ import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"go.opencensus.io/plugin/ochttp"
|
"go.opencensus.io/plugin/ochttp"
|
||||||
|
|
||||||
|
"github.com/pomerium/pomerium/internal/telemetry/requestid"
|
||||||
)
|
)
|
||||||
|
|
||||||
// ErrTokenRevoked signifies a token revokation or expiration error
|
// ErrTokenRevoked signifies a token revokation or expiration error
|
||||||
|
@ -22,7 +24,7 @@ var ErrTokenRevoked = errors.New("token expired or revoked")
|
||||||
var DefaultClient = &http.Client{
|
var DefaultClient = &http.Client{
|
||||||
Timeout: 1 * time.Minute,
|
Timeout: 1 * time.Minute,
|
||||||
//todo(bdd): incorporate metrics.HTTPMetricsRoundTripper
|
//todo(bdd): incorporate metrics.HTTPMetricsRoundTripper
|
||||||
Transport: &ochttp.Transport{},
|
Transport: requestid.NewRoundTripper(&ochttp.Transport{}),
|
||||||
}
|
}
|
||||||
|
|
||||||
// Client provides a simple helper interface to make HTTP requests
|
// Client provides a simple helper interface to make HTTP requests
|
||||||
|
|
|
@ -7,6 +7,7 @@ import (
|
||||||
|
|
||||||
"github.com/pomerium/pomerium/internal/frontend"
|
"github.com/pomerium/pomerium/internal/frontend"
|
||||||
"github.com/pomerium/pomerium/internal/log"
|
"github.com/pomerium/pomerium/internal/log"
|
||||||
|
"github.com/pomerium/pomerium/internal/telemetry/requestid"
|
||||||
"github.com/pomerium/pomerium/internal/urlutil"
|
"github.com/pomerium/pomerium/internal/urlutil"
|
||||||
"github.com/pomerium/pomerium/internal/version"
|
"github.com/pomerium/pomerium/internal/version"
|
||||||
)
|
)
|
||||||
|
@ -65,10 +66,7 @@ func (e *HTTPError) ErrorResponse(w http.ResponseWriter, r *http.Request) {
|
||||||
w.WriteHeader(e.Status)
|
w.WriteHeader(e.Status)
|
||||||
|
|
||||||
log.FromRequest(r).Info().Err(e).Msg("httputil: ErrorResponse")
|
log.FromRequest(r).Info().Err(e).Msg("httputil: ErrorResponse")
|
||||||
var requestID string
|
requestID := requestid.FromContext(r.Context())
|
||||||
if id, ok := log.IDFromRequest(r); ok {
|
|
||||||
requestID = id
|
|
||||||
}
|
|
||||||
response := errResponse{
|
response := errResponse{
|
||||||
Status: e.Status,
|
Status: e.Status,
|
||||||
StatusText: http.StatusText(e.Status),
|
StatusText: http.StatusText(e.Status),
|
||||||
|
|
|
@ -1,15 +1,16 @@
|
||||||
package log
|
package log
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
|
||||||
"crypto/rand"
|
"crypto/rand"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/pomerium/pomerium/internal/middleware/responsewriter"
|
|
||||||
"github.com/rs/zerolog"
|
"github.com/rs/zerolog"
|
||||||
|
|
||||||
|
"github.com/pomerium/pomerium/internal/middleware/responsewriter"
|
||||||
|
"github.com/pomerium/pomerium/internal/telemetry/requestid"
|
||||||
)
|
)
|
||||||
|
|
||||||
// NewHandler injects log into requests context.
|
// NewHandler injects log into requests context.
|
||||||
|
@ -115,45 +116,18 @@ func RefererHandler(fieldKey string) func(next http.Handler) http.Handler {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
type idKey struct{}
|
// RequestIDHandler adds the request's id as a field to the context's logger
|
||||||
|
// using fieldKey as field key.
|
||||||
// IDFromRequest returns the unique id associated to the request if any.
|
func RequestIDHandler(fieldKey string) func(next http.Handler) http.Handler {
|
||||||
func IDFromRequest(r *http.Request) (id string, ok bool) {
|
|
||||||
if r == nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
return IDFromCtx(r.Context())
|
|
||||||
}
|
|
||||||
|
|
||||||
// IDFromCtx returns the unique id associated to the context if any.
|
|
||||||
func IDFromCtx(ctx context.Context) (id string, ok bool) {
|
|
||||||
id, ok = ctx.Value(idKey{}).(string)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// RequestIDHandler returns a handler setting a unique id to the request which can
|
|
||||||
// be gathered using IDFromRequest(req). This generated id is added as a field to the
|
|
||||||
// logger using the passed fieldKey as field name. The id is also added as a response
|
|
||||||
// header if the headerName is not empty.
|
|
||||||
func RequestIDHandler(fieldKey, headerName string) func(next http.Handler) http.Handler {
|
|
||||||
return func(next http.Handler) http.Handler {
|
return func(next http.Handler) http.Handler {
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
ctx := r.Context()
|
requestID := requestid.FromContext(r.Context())
|
||||||
id, ok := IDFromRequest(r)
|
if requestID != "" {
|
||||||
if !ok {
|
log := zerolog.Ctx(r.Context())
|
||||||
id = uuid()
|
|
||||||
ctx = context.WithValue(ctx, idKey{}, id)
|
|
||||||
r = r.WithContext(ctx)
|
|
||||||
}
|
|
||||||
if fieldKey != "" {
|
|
||||||
log := zerolog.Ctx(ctx)
|
|
||||||
log.UpdateContext(func(c zerolog.Context) zerolog.Context {
|
log.UpdateContext(func(c zerolog.Context) zerolog.Context {
|
||||||
return c.Str(fieldKey, id)
|
return c.Str(fieldKey, requestID)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
if headerName != "" {
|
|
||||||
w.Header().Set(headerName, id)
|
|
||||||
}
|
|
||||||
next.ServeHTTP(w, r)
|
next.ServeHTTP(w, r)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
|
@ -14,6 +14,8 @@ import (
|
||||||
|
|
||||||
"github.com/google/go-cmp/cmp"
|
"github.com/google/go-cmp/cmp"
|
||||||
"github.com/rs/zerolog"
|
"github.com/rs/zerolog"
|
||||||
|
|
||||||
|
"github.com/pomerium/pomerium/internal/telemetry/requestid"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestGenerateUUID(t *testing.T) {
|
func TestGenerateUUID(t *testing.T) {
|
||||||
|
@ -172,21 +174,19 @@ func TestRequestIDHandler(t *testing.T) {
|
||||||
out := &bytes.Buffer{}
|
out := &bytes.Buffer{}
|
||||||
r := &http.Request{
|
r := &http.Request{
|
||||||
Header: http.Header{
|
Header: http.Header{
|
||||||
"Referer": []string{"http://foo.com/bar"},
|
"X-Request-Id": []string{"1234"},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
h := RequestIDHandler("id", "Request-Id")(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
h := RequestIDHandler("request-id")(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
id, ok := IDFromRequest(r)
|
requestID := requestid.FromContext(r.Context())
|
||||||
if !ok {
|
|
||||||
t.Fatal("Missing id in request")
|
|
||||||
}
|
|
||||||
l := FromRequest(r)
|
l := FromRequest(r)
|
||||||
l.Log().Msg("")
|
l.Log().Msg("")
|
||||||
if want, got := fmt.Sprintf(`{"id":"%s"}`+"\n", id), decodeIfBinary(out); want != got {
|
if want, got := fmt.Sprintf(`{"request-id":"%s"}`+"\n", requestID), decodeIfBinary(out); want != got {
|
||||||
t.Errorf("Invalid log output, got: %s, want: %s", got, want)
|
t.Errorf("Invalid log output, got: %s, want: %s", got, want)
|
||||||
}
|
}
|
||||||
}))
|
}))
|
||||||
h = NewHandler(zerolog.New(out))(h)
|
h = NewHandler(zerolog.New(out))(h)
|
||||||
|
h = requestid.HTTPMiddleware()(h)
|
||||||
h.ServeHTTP(httptest.NewRecorder(), r)
|
h.ServeHTTP(httptest.NewRecorder(), r)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
46
internal/telemetry/requestid/grpc_client.go
Normal file
46
internal/telemetry/requestid/grpc_client.go
Normal file
|
@ -0,0 +1,46 @@
|
||||||
|
package requestid
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
|
||||||
|
"google.golang.org/grpc"
|
||||||
|
"google.golang.org/grpc/metadata"
|
||||||
|
)
|
||||||
|
|
||||||
|
// StreamClientInterceptor returns a new gRPC StreamClientInterceptor which puts the request ID in the outgoing
|
||||||
|
// metadata.
|
||||||
|
func StreamClientInterceptor() grpc.StreamClientInterceptor {
|
||||||
|
return func(ctx context.Context,
|
||||||
|
desc *grpc.StreamDesc, cc *grpc.ClientConn,
|
||||||
|
method string, streamer grpc.Streamer, opts ...grpc.CallOption,
|
||||||
|
) (grpc.ClientStream, error) {
|
||||||
|
toMetadata(ctx)
|
||||||
|
return streamer(ctx, desc, cc, method, opts...)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// UnaryClientInterceptor returns a new gRPC UnaryClientInterceptor which puts the request ID in the outgoing
|
||||||
|
// metadata.
|
||||||
|
func UnaryClientInterceptor() grpc.UnaryClientInterceptor {
|
||||||
|
return func(ctx context.Context,
|
||||||
|
method string, req, reply interface{},
|
||||||
|
cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption,
|
||||||
|
) error {
|
||||||
|
toMetadata(ctx)
|
||||||
|
return invoker(ctx, method, req, reply, cc, opts...)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func toMetadata(ctx context.Context) {
|
||||||
|
requestID := FromContext(ctx)
|
||||||
|
if requestID == "" {
|
||||||
|
requestID = New()
|
||||||
|
}
|
||||||
|
|
||||||
|
md, ok := metadata.FromOutgoingContext(ctx)
|
||||||
|
if !ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
md.Set(headerName, requestID)
|
||||||
|
}
|
56
internal/telemetry/requestid/grpc_server.go
Normal file
56
internal/telemetry/requestid/grpc_server.go
Normal file
|
@ -0,0 +1,56 @@
|
||||||
|
package requestid
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
|
||||||
|
"google.golang.org/grpc"
|
||||||
|
"google.golang.org/grpc/metadata"
|
||||||
|
)
|
||||||
|
|
||||||
|
type grpcStream struct {
|
||||||
|
grpc.ServerStream
|
||||||
|
ctx context.Context
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ss grpcStream) Context() context.Context {
|
||||||
|
return ss.ctx
|
||||||
|
}
|
||||||
|
|
||||||
|
// StreamServerInterceptor returns a new gRPC StreamServerInterceptor which populates the request id
|
||||||
|
// from the incoming metadata.
|
||||||
|
func StreamServerInterceptor() grpc.StreamServerInterceptor {
|
||||||
|
return func(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
|
||||||
|
ctx := ss.Context()
|
||||||
|
requestID := fromMetadata(ctx)
|
||||||
|
ctx = WithValue(ctx, requestID)
|
||||||
|
ss = grpcStream{
|
||||||
|
ServerStream: ss,
|
||||||
|
ctx: ctx,
|
||||||
|
}
|
||||||
|
return handler(srv, ss)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// UnaryServerInterceptor returns a new gRPC UnaryServerInterceptor which populates the request id
|
||||||
|
// from the incoming metadata.
|
||||||
|
func UnaryServerInterceptor() grpc.UnaryServerInterceptor {
|
||||||
|
return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp interface{}, err error) {
|
||||||
|
requestID := fromMetadata(ctx)
|
||||||
|
ctx = WithValue(ctx, requestID)
|
||||||
|
return handler(ctx, req)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func fromMetadata(ctx context.Context) string {
|
||||||
|
md, ok := metadata.FromIncomingContext(ctx)
|
||||||
|
if !ok {
|
||||||
|
return New()
|
||||||
|
}
|
||||||
|
|
||||||
|
headers := md.Get(headerName)
|
||||||
|
if len(headers) == 0 || headers[0] == "" {
|
||||||
|
return New()
|
||||||
|
}
|
||||||
|
|
||||||
|
return headers[0]
|
||||||
|
}
|
48
internal/telemetry/requestid/http.go
Normal file
48
internal/telemetry/requestid/http.go
Normal file
|
@ -0,0 +1,48 @@
|
||||||
|
package requestid
|
||||||
|
|
||||||
|
import "net/http"
|
||||||
|
|
||||||
|
type transport struct {
|
||||||
|
base http.RoundTripper
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewRoundTripper creates a new RoundTripper which adds the request id to the outgoing headers.
|
||||||
|
func NewRoundTripper(base http.RoundTripper) http.RoundTripper {
|
||||||
|
return &transport{base: base}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *transport) RoundTrip(req *http.Request) (res *http.Response, err error) {
|
||||||
|
requestID := FromContext(req.Context())
|
||||||
|
if requestID != "" && req.Header.Get(headerName) == "" {
|
||||||
|
req.Header.Set(headerName, requestID)
|
||||||
|
}
|
||||||
|
|
||||||
|
return t.base.RoundTrip(req)
|
||||||
|
}
|
||||||
|
|
||||||
|
type httpMiddleware struct {
|
||||||
|
next http.Handler
|
||||||
|
}
|
||||||
|
|
||||||
|
// HTTPMiddleware creates a new http middleware that populates the request id.
|
||||||
|
func HTTPMiddleware() func(next http.Handler) http.Handler {
|
||||||
|
return func(next http.Handler) http.Handler {
|
||||||
|
return httpMiddleware{next: next}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h httpMiddleware) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||||
|
requestID := FromHTTPHeader(r.Header)
|
||||||
|
if requestID == "" {
|
||||||
|
requestID = New()
|
||||||
|
}
|
||||||
|
ctx := WithValue(r.Context(), requestID)
|
||||||
|
r = r.WithContext(ctx)
|
||||||
|
h.next.ServeHTTP(w, r)
|
||||||
|
}
|
||||||
|
|
||||||
|
// FromHTTPHeader returns the request id in the HTTP header. If no request id exists,
|
||||||
|
// an empty string is returned.
|
||||||
|
func FromHTTPHeader(hdr http.Header) string {
|
||||||
|
return hdr.Get(headerName)
|
||||||
|
}
|
30
internal/telemetry/requestid/requestid.go
Normal file
30
internal/telemetry/requestid/requestid.go
Normal file
|
@ -0,0 +1,30 @@
|
||||||
|
// Package requestid has functions for working with x-request-id in http/gRPC requests.
|
||||||
|
package requestid
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
|
||||||
|
shortuuid "github.com/lithammer/shortuuid/v3"
|
||||||
|
)
|
||||||
|
|
||||||
|
const headerName = "x-request-id"
|
||||||
|
|
||||||
|
var contextKey struct{}
|
||||||
|
|
||||||
|
// WithValue returns a new context from the parent context with a request id value set.
|
||||||
|
func WithValue(parent context.Context, requestID string) context.Context {
|
||||||
|
return context.WithValue(parent, contextKey, requestID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// FromContext gets the request id from a context.
|
||||||
|
func FromContext(ctx context.Context) string {
|
||||||
|
if id, ok := ctx.Value(contextKey).(string); ok {
|
||||||
|
return id
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
// New creates a new request id.
|
||||||
|
func New() string {
|
||||||
|
return shortuuid.New()
|
||||||
|
}
|
Loading…
Add table
Add a link
Reference in a new issue