mirror of
https://github.com/pomerium/pomerium.git
synced 2025-08-03 08:50:42 +02:00
core/telemetry: move requestid to pkg directory (#4911)
This commit is contained in:
parent
803baeb9e1
commit
4301da3648
18 changed files with 13 additions and 13 deletions
|
@ -12,8 +12,8 @@ import (
|
|||
|
||||
"github.com/pomerium/pomerium/internal/log"
|
||||
"github.com/pomerium/pomerium/internal/telemetry"
|
||||
"github.com/pomerium/pomerium/internal/telemetry/requestid"
|
||||
"github.com/pomerium/pomerium/pkg/grpcutil"
|
||||
"github.com/pomerium/pomerium/pkg/telemetry/requestid"
|
||||
)
|
||||
|
||||
// Options contains options for connecting to a pomerium rpc service.
|
||||
|
|
40
pkg/telemetry/requestid/grpc_client.go
Normal file
40
pkg/telemetry/requestid/grpc_client.go
Normal file
|
@ -0,0 +1,40 @@
|
|||
package requestid
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/metadata"
|
||||
)
|
||||
|
||||
// StreamClientInterceptor returns a new gRPC StreamClientInterceptor which puts the request ID in the outgoing
|
||||
// metadata.
|
||||
func StreamClientInterceptor() grpc.StreamClientInterceptor {
|
||||
return func(ctx context.Context,
|
||||
desc *grpc.StreamDesc, cc *grpc.ClientConn,
|
||||
method string, streamer grpc.Streamer, opts ...grpc.CallOption,
|
||||
) (grpc.ClientStream, error) {
|
||||
ctx = toMetadata(ctx)
|
||||
return streamer(ctx, desc, cc, method, opts...)
|
||||
}
|
||||
}
|
||||
|
||||
// UnaryClientInterceptor returns a new gRPC UnaryClientInterceptor which puts the request ID in the outgoing
|
||||
// metadata.
|
||||
func UnaryClientInterceptor() grpc.UnaryClientInterceptor {
|
||||
return func(ctx context.Context,
|
||||
method string, req, reply interface{},
|
||||
cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption,
|
||||
) error {
|
||||
ctx = toMetadata(ctx)
|
||||
return invoker(ctx, method, req, reply, cc, opts...)
|
||||
}
|
||||
}
|
||||
|
||||
func toMetadata(ctx context.Context) context.Context {
|
||||
requestID := FromContext(ctx)
|
||||
if requestID == "" {
|
||||
requestID = New()
|
||||
}
|
||||
return metadata.AppendToOutgoingContext(ctx, headerName, requestID)
|
||||
}
|
56
pkg/telemetry/requestid/grpc_server.go
Normal file
56
pkg/telemetry/requestid/grpc_server.go
Normal file
|
@ -0,0 +1,56 @@
|
|||
package requestid
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/metadata"
|
||||
)
|
||||
|
||||
type grpcStream struct {
|
||||
grpc.ServerStream
|
||||
ctx context.Context
|
||||
}
|
||||
|
||||
func (ss grpcStream) Context() context.Context {
|
||||
return ss.ctx
|
||||
}
|
||||
|
||||
// StreamServerInterceptor returns a new gRPC StreamServerInterceptor which populates the request id
|
||||
// from the incoming metadata.
|
||||
func StreamServerInterceptor() grpc.StreamServerInterceptor {
|
||||
return func(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
|
||||
ctx := ss.Context()
|
||||
requestID := fromMetadata(ctx)
|
||||
ctx = WithValue(ctx, requestID)
|
||||
ss = grpcStream{
|
||||
ServerStream: ss,
|
||||
ctx: ctx,
|
||||
}
|
||||
return handler(srv, ss)
|
||||
}
|
||||
}
|
||||
|
||||
// UnaryServerInterceptor returns a new gRPC UnaryServerInterceptor which populates the request id
|
||||
// from the incoming metadata.
|
||||
func UnaryServerInterceptor() grpc.UnaryServerInterceptor {
|
||||
return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp interface{}, err error) {
|
||||
requestID := fromMetadata(ctx)
|
||||
ctx = WithValue(ctx, requestID)
|
||||
return handler(ctx, req)
|
||||
}
|
||||
}
|
||||
|
||||
func fromMetadata(ctx context.Context) string {
|
||||
md, ok := metadata.FromIncomingContext(ctx)
|
||||
if !ok {
|
||||
return New()
|
||||
}
|
||||
|
||||
headers := md.Get(headerName)
|
||||
if len(headers) == 0 || headers[0] == "" {
|
||||
return New()
|
||||
}
|
||||
|
||||
return headers[0]
|
||||
}
|
48
pkg/telemetry/requestid/http.go
Normal file
48
pkg/telemetry/requestid/http.go
Normal file
|
@ -0,0 +1,48 @@
|
|||
package requestid
|
||||
|
||||
import "net/http"
|
||||
|
||||
type transport struct {
|
||||
base http.RoundTripper
|
||||
}
|
||||
|
||||
// NewRoundTripper creates a new RoundTripper which adds the request id to the outgoing headers.
|
||||
func NewRoundTripper(base http.RoundTripper) http.RoundTripper {
|
||||
return &transport{base: base}
|
||||
}
|
||||
|
||||
func (t *transport) RoundTrip(req *http.Request) (res *http.Response, err error) {
|
||||
requestID := FromContext(req.Context())
|
||||
if requestID != "" && req.Header.Get(headerName) == "" {
|
||||
req.Header.Set(headerName, requestID)
|
||||
}
|
||||
|
||||
return t.base.RoundTrip(req)
|
||||
}
|
||||
|
||||
type httpMiddleware struct {
|
||||
next http.Handler
|
||||
}
|
||||
|
||||
// HTTPMiddleware creates a new http middleware that populates the request id.
|
||||
func HTTPMiddleware() func(next http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return httpMiddleware{next: next}
|
||||
}
|
||||
}
|
||||
|
||||
func (h httpMiddleware) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
requestID := FromHTTPHeader(r.Header)
|
||||
if requestID == "" {
|
||||
requestID = New()
|
||||
}
|
||||
ctx := WithValue(r.Context(), requestID)
|
||||
r = r.WithContext(ctx)
|
||||
h.next.ServeHTTP(w, r)
|
||||
}
|
||||
|
||||
// FromHTTPHeader returns the request id in the HTTP header. If no request id exists,
|
||||
// an empty string is returned.
|
||||
func FromHTTPHeader(hdr http.Header) string {
|
||||
return hdr.Get(headerName)
|
||||
}
|
33
pkg/telemetry/requestid/requestid.go
Normal file
33
pkg/telemetry/requestid/requestid.go
Normal file
|
@ -0,0 +1,33 @@
|
|||
// Package requestid has functions for working with x-request-id in http/gRPC requests.
|
||||
package requestid
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/google/uuid"
|
||||
|
||||
"github.com/pomerium/pomerium/pkg/encoding/base58"
|
||||
)
|
||||
|
||||
const headerName = "x-request-id"
|
||||
|
||||
type contextKey struct{}
|
||||
|
||||
// WithValue returns a new context from the parent context with a request id value set.
|
||||
func WithValue(parent context.Context, requestID string) context.Context {
|
||||
return context.WithValue(parent, contextKey{}, requestID)
|
||||
}
|
||||
|
||||
// FromContext gets the request id from a context.
|
||||
func FromContext(ctx context.Context) string {
|
||||
if id, ok := ctx.Value(contextKey{}).(string); ok {
|
||||
return id
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// New creates a new request id.
|
||||
func New() string {
|
||||
id := uuid.New()
|
||||
return base58.Encode(id[:])
|
||||
}
|
15
pkg/telemetry/requestid/requestid_test.go
Normal file
15
pkg/telemetry/requestid/requestid_test.go
Normal file
|
@ -0,0 +1,15 @@
|
|||
package requestid
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestFromContext(t *testing.T) {
|
||||
id := New()
|
||||
ctx := WithValue(context.Background(), id)
|
||||
ctxID := FromContext(ctx)
|
||||
assert.Equal(t, ctxID, id)
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue