[tracing] refactor to use custom extension for trace id editing (#5420)

refactor to use custom extension for trace id editing
This commit is contained in:
Joe Kralicky 2025-01-08 16:06:33 -05:00
parent de68673819
commit 86bf8a1d5f
No known key found for this signature in database
GPG key ID: 75C4875F34A9FB79
36 changed files with 1144 additions and 2672 deletions

View file

@ -3,161 +3,35 @@ package trace_test
import (
"context"
"errors"
"net"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/gorilla/mux"
"github.com/pomerium/pomerium/internal/telemetry/trace"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc"
"go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp"
sdktrace "go.opentelemetry.io/otel/sdk/trace"
oteltrace "go.opentelemetry.io/otel/trace"
"go.opentelemetry.io/otel/trace/noop"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials/insecure"
"google.golang.org/grpc/interop/grpc_testing"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/stats"
"google.golang.org/grpc/test/bufconn"
)
var cases = []struct {
name string
setTraceparent string
setPomeriumTraceparent string
check func(t testing.TB, ctx context.Context)
}{
{
name: "x-pomerium-traceparent not present",
setTraceparent: Traceparent(Trace(1), Span(1), true),
check: func(t testing.TB, ctx context.Context) {
span := oteltrace.SpanFromContext(ctx)
assert.Equal(t, Trace(1).ID().Value(), span.SpanContext().TraceID())
assert.Equal(t, Span(1).ID(), span.SpanContext().SpanID())
assert.True(t, span.SpanContext().IsSampled())
},
},
{
name: "x-pomerium-traceparent present",
setTraceparent: Traceparent(Trace(2), Span(2), true),
setPomeriumTraceparent: Traceparent(Trace(1), Span(1), true),
check: func(t testing.TB, ctx context.Context) {
span := oteltrace.SpanFromContext(ctx)
assert.Equal(t, Trace(1).ID().Value(), span.SpanContext().TraceID())
assert.Equal(t, Span(2).ID(), span.SpanContext().SpanID())
assert.True(t, span.SpanContext().IsSampled())
},
},
{
name: "x-pomerium-traceparent present, force sampling off",
setTraceparent: Traceparent(Trace(2), Span(2), true),
setPomeriumTraceparent: Traceparent(Trace(1), Span(1), false),
check: func(t testing.TB, ctx context.Context) {
span := oteltrace.SpanFromContext(ctx)
assert.Equal(t, Trace(1).ID().Value(), span.SpanContext().TraceID())
assert.Equal(t, Span(2).ID(), span.SpanContext().SpanID())
assert.Equal(t, false, span.SpanContext().IsSampled())
},
},
{
name: "x-pomerium-traceparent present, force sampling on",
setTraceparent: Traceparent(Trace(2), Span(2), false),
setPomeriumTraceparent: Traceparent(Trace(1), Span(1), true),
check: func(t testing.TB, ctx context.Context) {
span := oteltrace.SpanFromContext(ctx)
assert.Equal(t, Trace(1).ID().Value(), span.SpanContext().TraceID())
assert.Equal(t, Span(2).ID(), span.SpanContext().SpanID())
assert.Equal(t, true, span.SpanContext().IsSampled())
},
},
{
name: "malformed x-pomerium-traceparent",
setTraceparent: Traceparent(Trace(2), Span(2), false),
setPomeriumTraceparent: "00-xxxxxx-yyyyyy-03",
check: func(t testing.TB, ctx context.Context) {
span := oteltrace.SpanFromContext(ctx)
assert.Equal(t, Trace(2).ID().Value(), span.SpanContext().TraceID())
assert.Equal(t, Span(2).ID(), span.SpanContext().SpanID())
assert.Equal(t, false, span.SpanContext().IsSampled())
},
},
}
func TestHTTPMiddleware(t *testing.T) {
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
r := httptest.NewRequest(http.MethodGet, "/foo", nil)
if tc.setTraceparent != "" {
r.Header.Add("Traceparent", tc.setTraceparent)
}
if tc.setPomeriumTraceparent != "" {
r.Header.Add("X-Pomerium-Traceparent", tc.setPomeriumTraceparent)
}
w := httptest.NewRecorder()
trace.NewHTTPMiddleware(
otelhttp.WithTracerProvider(noop.NewTracerProvider()),
)(http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) {
tc.check(t, r.Context())
})).ServeHTTP(w, r)
})
}
}
func TestGRPCMiddleware(t *testing.T) {
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
srv := grpc.NewServer(
grpc.StatsHandler(trace.NewServerStatsHandler(otelgrpc.NewServerHandler(
otelgrpc.WithTracerProvider(noop.NewTracerProvider())))),
grpc.Creds(insecure.NewCredentials()),
)
lis := bufconn.Listen(4096)
grpc_testing.RegisterTestServiceServer(srv, &testServer{
fn: func(ctx context.Context) {
tc.check(t, ctx)
},
})
go srv.Serve(lis)
t.Cleanup(srv.Stop)
client, err := grpc.NewClient("passthrough://ignore",
grpc.WithTransportCredentials(insecure.NewCredentials()),
grpc.WithStatsHandler(otelgrpc.NewClientHandler(
otelgrpc.WithTracerProvider(noop.NewTracerProvider()))),
grpc.WithContextDialer(func(ctx context.Context, _ string) (net.Conn, error) {
return lis.DialContext(ctx)
}),
)
require.NoError(t, err)
ctx := context.Background()
if tc.setTraceparent != "" {
ctx = metadata.AppendToOutgoingContext(ctx,
"traceparent", tc.setTraceparent,
)
}
if tc.setPomeriumTraceparent != "" {
ctx = metadata.AppendToOutgoingContext(ctx,
"x-pomerium-traceparent", tc.setPomeriumTraceparent,
)
}
_, err = grpc_testing.NewTestServiceClient(client).EmptyCall(ctx, &grpc_testing.Empty{})
assert.NoError(t, err)
})
}
}
type testServer struct {
grpc_testing.UnimplementedTestServiceServer
fn func(ctx context.Context)
}
func (ts *testServer) EmptyCall(ctx context.Context, _ *grpc_testing.Empty) (*grpc_testing.Empty, error) {
ts.fn(ctx)
return &grpc_testing.Empty{}, nil
router := mux.NewRouter()
tp := sdktrace.NewTracerProvider()
router.Use(trace.NewHTTPMiddleware(
otelhttp.WithTracerProvider(tp),
))
router.Path("/foo").HandlerFunc(func(_ http.ResponseWriter, r *http.Request) {
span := oteltrace.SpanFromContext(r.Context())
assert.Equal(t, "Server: GET /foo", span.(interface{ Name() string }).Name())
}).Methods(http.MethodGet)
w := httptest.NewRecorder()
ctx, span := tp.Tracer("test").Start(context.Background(), "test")
router.ServeHTTP(w, httptest.NewRequestWithContext(ctx, http.MethodGet, "/foo", nil))
span.End()
}
type mockHandler struct {