core/telemetry: move requestid to pkg directory (#4911)

This commit is contained in:
Caleb Doxsey 2024-01-19 13:18:16 -07:00 committed by GitHub
parent 803baeb9e1
commit 4301da3648
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
18 changed files with 13 additions and 13 deletions

View file

@ -12,8 +12,8 @@ import (
"github.com/pomerium/pomerium/internal/log"
"github.com/pomerium/pomerium/internal/telemetry"
"github.com/pomerium/pomerium/internal/telemetry/requestid"
"github.com/pomerium/pomerium/pkg/grpcutil"
"github.com/pomerium/pomerium/pkg/telemetry/requestid"
)
// Options contains options for connecting to a pomerium rpc service.

View file

@ -0,0 +1,40 @@
package requestid
import (
"context"
"google.golang.org/grpc"
"google.golang.org/grpc/metadata"
)
// StreamClientInterceptor returns a new gRPC StreamClientInterceptor which puts the request ID in the outgoing
// metadata.
func StreamClientInterceptor() grpc.StreamClientInterceptor {
return func(ctx context.Context,
desc *grpc.StreamDesc, cc *grpc.ClientConn,
method string, streamer grpc.Streamer, opts ...grpc.CallOption,
) (grpc.ClientStream, error) {
ctx = toMetadata(ctx)
return streamer(ctx, desc, cc, method, opts...)
}
}
// UnaryClientInterceptor returns a new gRPC UnaryClientInterceptor which puts the request ID in the outgoing
// metadata.
func UnaryClientInterceptor() grpc.UnaryClientInterceptor {
return func(ctx context.Context,
method string, req, reply interface{},
cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption,
) error {
ctx = toMetadata(ctx)
return invoker(ctx, method, req, reply, cc, opts...)
}
}
func toMetadata(ctx context.Context) context.Context {
requestID := FromContext(ctx)
if requestID == "" {
requestID = New()
}
return metadata.AppendToOutgoingContext(ctx, headerName, requestID)
}

View file

@ -0,0 +1,56 @@
package requestid
import (
"context"
"google.golang.org/grpc"
"google.golang.org/grpc/metadata"
)
type grpcStream struct {
grpc.ServerStream
ctx context.Context
}
func (ss grpcStream) Context() context.Context {
return ss.ctx
}
// StreamServerInterceptor returns a new gRPC StreamServerInterceptor which populates the request id
// from the incoming metadata.
func StreamServerInterceptor() grpc.StreamServerInterceptor {
return func(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
ctx := ss.Context()
requestID := fromMetadata(ctx)
ctx = WithValue(ctx, requestID)
ss = grpcStream{
ServerStream: ss,
ctx: ctx,
}
return handler(srv, ss)
}
}
// UnaryServerInterceptor returns a new gRPC UnaryServerInterceptor which populates the request id
// from the incoming metadata.
func UnaryServerInterceptor() grpc.UnaryServerInterceptor {
return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp interface{}, err error) {
requestID := fromMetadata(ctx)
ctx = WithValue(ctx, requestID)
return handler(ctx, req)
}
}
func fromMetadata(ctx context.Context) string {
md, ok := metadata.FromIncomingContext(ctx)
if !ok {
return New()
}
headers := md.Get(headerName)
if len(headers) == 0 || headers[0] == "" {
return New()
}
return headers[0]
}

View file

@ -0,0 +1,48 @@
package requestid
import "net/http"
type transport struct {
base http.RoundTripper
}
// NewRoundTripper creates a new RoundTripper which adds the request id to the outgoing headers.
func NewRoundTripper(base http.RoundTripper) http.RoundTripper {
return &transport{base: base}
}
func (t *transport) RoundTrip(req *http.Request) (res *http.Response, err error) {
requestID := FromContext(req.Context())
if requestID != "" && req.Header.Get(headerName) == "" {
req.Header.Set(headerName, requestID)
}
return t.base.RoundTrip(req)
}
type httpMiddleware struct {
next http.Handler
}
// HTTPMiddleware creates a new http middleware that populates the request id.
func HTTPMiddleware() func(next http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return httpMiddleware{next: next}
}
}
func (h httpMiddleware) ServeHTTP(w http.ResponseWriter, r *http.Request) {
requestID := FromHTTPHeader(r.Header)
if requestID == "" {
requestID = New()
}
ctx := WithValue(r.Context(), requestID)
r = r.WithContext(ctx)
h.next.ServeHTTP(w, r)
}
// FromHTTPHeader returns the request id in the HTTP header. If no request id exists,
// an empty string is returned.
func FromHTTPHeader(hdr http.Header) string {
return hdr.Get(headerName)
}

View file

@ -0,0 +1,33 @@
// Package requestid has functions for working with x-request-id in http/gRPC requests.
package requestid
import (
"context"
"github.com/google/uuid"
"github.com/pomerium/pomerium/pkg/encoding/base58"
)
const headerName = "x-request-id"
type contextKey struct{}
// WithValue returns a new context from the parent context with a request id value set.
func WithValue(parent context.Context, requestID string) context.Context {
return context.WithValue(parent, contextKey{}, requestID)
}
// FromContext gets the request id from a context.
func FromContext(ctx context.Context) string {
if id, ok := ctx.Value(contextKey{}).(string); ok {
return id
}
return ""
}
// New creates a new request id.
func New() string {
id := uuid.New()
return base58.Encode(id[:])
}

View file

@ -0,0 +1,15 @@
package requestid
import (
"context"
"testing"
"github.com/stretchr/testify/assert"
)
func TestFromContext(t *testing.T) {
id := New()
ctx := WithValue(context.Background(), id)
ctxID := FromContext(ctx)
assert.Equal(t, ctxID, id)
}