From 51fa4838856a2dd74430a19d1634f95559045037 Mon Sep 17 00:00:00 2001 From: Joe Kralicky Date: Thu, 5 Dec 2024 04:37:56 +0000 Subject: [PATCH] various bugfixes and improvements --- config/envoyconfig/listeners_main.go | 13 +- internal/telemetry/trace/debug.go | 32 ++ internal/telemetry/trace/global.go | 32 ++ internal/telemetry/trace/middleware.go | 52 +++ internal/telemetry/trace/server.go | 506 +++++++++++++++++---- internal/telemetry/trace/trace.go | 261 ++++------- internal/telemetry/trace/util.go | 46 ++ internal/testenv/environment.go | 41 +- internal/testenv/selftests/tracing_test.go | 55 +++ internal/testenv/upstreams/http.go | 74 +-- pkg/cmd/pomerium/pomerium.go | 9 +- pkg/envoy/envoy.go | 4 + 12 files changed, 819 insertions(+), 306 deletions(-) create mode 100644 internal/telemetry/trace/debug.go create mode 100644 internal/telemetry/trace/global.go create mode 100644 internal/telemetry/trace/util.go create mode 100644 internal/testenv/selftests/tracing_test.go diff --git a/config/envoyconfig/listeners_main.go b/config/envoyconfig/listeners_main.go index fc3c17247..563e0d558 100644 --- a/config/envoyconfig/listeners_main.go +++ b/config/envoyconfig/listeners_main.go @@ -11,6 +11,7 @@ import ( envoy_extensions_access_loggers_grpc_v3 "github.com/envoyproxy/go-control-plane/envoy/extensions/access_loggers/grpc/v3" envoy_extensions_filters_http_header_to_metadata "github.com/envoyproxy/go-control-plane/envoy/extensions/filters/http/header_to_metadata/v3" envoy_extensions_filters_network_http_connection_manager "github.com/envoyproxy/go-control-plane/envoy/extensions/filters/network/http_connection_manager/v3" + envoy_extensions_tracers_otel "github.com/envoyproxy/go-control-plane/envoy/extensions/tracers/opentelemetry/resource_detectors/v3" metadatav3 "github.com/envoyproxy/go-control-plane/envoy/type/metadata/v3" envoy_tracing_v3 "github.com/envoyproxy/go-control-plane/envoy/type/tracing/v3" envoy_type_v3 
"github.com/envoyproxy/go-control-plane/envoy/type/v3" @@ -202,7 +203,7 @@ func (b *Builder) buildMainHTTPConnectionManagerFilter( RandomSampling: &envoy_type_v3.Percent{Value: cfg.Options.TracingSampleRate * 100}, ClientSampling: &envoy_type_v3.Percent{Value: cfg.Options.TracingSampleRate * 100}, Verbose: true, - SpawnUpstreamSpan: wrapperspb.Bool(false), + SpawnUpstreamSpan: wrapperspb.Bool(true), Provider: &tracev3.Tracing_Http{ Name: "envoy.tracers.opentelemetry", ConfigType: &tracev3.Tracing_Http_TypedConfig{ @@ -215,6 +216,16 @@ func (b *Builder) buildMainHTTPConnectionManagerFilter( }, }, ServiceName: "Envoy", + ResourceDetectors: []*envoy_config_core_v3.TypedExtensionConfig{ + { + Name: "envoy.tracers.opentelemetry.resource_detectors.static_config", + TypedConfig: marshalAny(&envoy_extensions_tracers_otel.StaticConfigResourceDetectorConfig{ + Attributes: map[string]string{ + "pomerium.envoy": "true", + }, + }), + }, + }, }), }, }, diff --git a/internal/telemetry/trace/debug.go b/internal/telemetry/trace/debug.go new file mode 100644 index 000000000..b541a5e06 --- /dev/null +++ b/internal/telemetry/trace/debug.go @@ -0,0 +1,32 @@ +package trace + +import ( + "context" + "fmt" + "runtime" + + "go.opentelemetry.io/otel/attribute" + sdktrace "go.opentelemetry.io/otel/sdk/trace" +) + +type stackTraceProcessor struct{} + +// ForceFlush implements trace.SpanProcessor. +func (s *stackTraceProcessor) ForceFlush(ctx context.Context) error { + return nil +} + +// OnEnd implements trace.SpanProcessor. +func (*stackTraceProcessor) OnEnd(s sdktrace.ReadOnlySpan) { +} + +// OnStart implements trace.SpanProcessor. +func (*stackTraceProcessor) OnStart(parent context.Context, s sdktrace.ReadWriteSpan) { + _, file, line, _ := runtime.Caller(2) + s.SetAttributes(attribute.String("caller", fmt.Sprintf("%s:%d", file, line))) +} + +// Shutdown implements trace.SpanProcessor. 
+func (s *stackTraceProcessor) Shutdown(ctx context.Context) error { + return nil +} diff --git a/internal/telemetry/trace/global.go b/internal/telemetry/trace/global.go new file mode 100644 index 000000000..6042ba62a --- /dev/null +++ b/internal/telemetry/trace/global.go @@ -0,0 +1,32 @@ +package trace + +import ( + "context" + + "go.opentelemetry.io/otel/trace" + "go.opentelemetry.io/otel/trace/embedded" +) + +const PomeriumCoreTracer = "pomerium.io/core" + +type panicTracerProvider struct { + embedded.TracerProvider +} + +// Tracer implements trace.TracerProvider. +func (w panicTracerProvider) Tracer(name string, options ...trace.TracerOption) trace.Tracer { + return panicTracer{} +} + +type panicTracer struct { + embedded.Tracer +} + +// Start implements trace.Tracer. +func (p panicTracer) Start(ctx context.Context, spanName string, opts ...trace.SpanStartOption) (context.Context, trace.Span) { + panic("global tracer used") +} + +func Continue(ctx context.Context, name string, o ...trace.SpanStartOption) (context.Context, trace.Span) { + return trace.SpanFromContext(ctx).TracerProvider().Tracer(PomeriumCoreTracer).Start(ctx, name, o...) 
+} diff --git a/internal/telemetry/trace/middleware.go b/internal/telemetry/trace/middleware.go index f163e8eb2..5639af3a2 100644 --- a/internal/telemetry/trace/middleware.go +++ b/internal/telemetry/trace/middleware.go @@ -1,6 +1,7 @@ package trace import ( + "context" "fmt" "net/http" @@ -8,6 +9,8 @@ import ( "go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp" "go.opentelemetry.io/otel" "go.opentelemetry.io/otel/propagation" + "google.golang.org/grpc/metadata" + "google.golang.org/grpc/stats" ) func NewHTTPMiddleware(opts ...otelhttp.Option) func(http.Handler) http.Handler { @@ -41,3 +44,52 @@ func NewHTTPMiddleware(opts ...otelhttp.Option) func(http.Handler) http.Handler }) } } + +func NewStatsHandler(base stats.Handler) stats.Handler { + return &statsHandlerWrapper{ + base: base, + } +} + +type statsHandlerWrapper struct { + base stats.Handler +} + +func (w *statsHandlerWrapper) wrapContext(ctx context.Context) context.Context { + md, ok := metadata.FromIncomingContext(ctx) + if !ok { + return ctx + } + traceparent := md.Get("traceparent") + xPomeriumTraceparent := md.Get("x-pomerium-traceparent") + if len(traceparent) > 0 && traceparent[0] != "" && len(xPomeriumTraceparent) > 0 && xPomeriumTraceparent[0] != "" { + newTracectx, err := ParseTraceparent(xPomeriumTraceparent[0]) + if err != nil { + return ctx + } + + md.Set("traceparent", ReplaceTraceID(traceparent[0], newTracectx.TraceID())) + return metadata.NewIncomingContext(ctx, md) + } + return ctx +} + +// HandleConn implements stats.Handler. +func (w *statsHandlerWrapper) HandleConn(ctx context.Context, stats stats.ConnStats) { + w.base.HandleConn(w.wrapContext(ctx), stats) +} + +// HandleRPC implements stats.Handler. +func (w *statsHandlerWrapper) HandleRPC(ctx context.Context, stats stats.RPCStats) { + w.base.HandleRPC(w.wrapContext(ctx), stats) +} + +// TagConn implements stats.Handler. 
+func (w *statsHandlerWrapper) TagConn(ctx context.Context, info *stats.ConnTagInfo) context.Context { + return w.base.TagConn(w.wrapContext(ctx), info) +} + +// TagRPC implements stats.Handler. +func (w *statsHandlerWrapper) TagRPC(ctx context.Context, info *stats.RPCTagInfo) context.Context { + return w.base.TagRPC(w.wrapContext(ctx), info) +} diff --git a/internal/telemetry/trace/server.go b/internal/telemetry/trace/server.go index 12c5ff9d0..91b1dd5f5 100644 --- a/internal/telemetry/trace/server.go +++ b/internal/telemetry/trace/server.go @@ -3,10 +3,16 @@ package trace import ( "context" "encoding/base64" + "encoding/hex" + "errors" + "fmt" "net" "net/url" + "os" "strings" "sync" + "time" + "unique" coltracepb "go.opentelemetry.io/proto/otlp/collector/trace/v1" commonv1 "go.opentelemetry.io/proto/otlp/common/v1" @@ -21,6 +27,8 @@ import ( "github.com/pomerium/pomerium/internal/log" "go.opentelemetry.io/otel/exporters/otlp/otlptrace" "go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc" + sdktrace "go.opentelemetry.io/otel/sdk/trace" + "go.opentelemetry.io/otel/trace" oteltrace "go.opentelemetry.io/otel/trace" ) @@ -57,16 +65,12 @@ func (ptr *PendingScopes) Insert(scope *commonv1.InstrumentationScope, scopeSche spans.Insert(span) } -func (ptr *PendingScopes) Delete(scope *commonv1.InstrumentationScope) (cascade bool) { - delete(ptr.spansByScope, scope.GetName()) - return len(ptr.spansByScope) == 0 -} - -func (ptr *PendingScopes) AsScopeSpansList(rewriteTraceId oteltrace.TraceID) []*tracev1.ScopeSpans { +func (ptr *PendingScopes) AsScopeSpansList(rewriteTraceId unique.Handle[oteltrace.TraceID]) []*tracev1.ScopeSpans { out := make([]*tracev1.ScopeSpans, 0, len(ptr.spansByScope)) for _, spans := range ptr.spansByScope { for _, span := range spans.spans { - span.TraceId = rewriteTraceId[:] + id := rewriteTraceId.Value() + copy(span.TraceId, id[:]) } scopeSpans := &tracev1.ScopeSpans{ Scope: spans.scope, @@ -101,15 +105,7 @@ func (ptr *PendingResources) 
Insert(resource *ResourceInfo, scope *commonv1.Inst scopes.Insert(scope, scopeSchema, span) } -func (ptr *PendingResources) Delete(resource *ResourceInfo, scope *commonv1.InstrumentationScope) (cascade bool) { - resourceEq := resource.ID() - if ptr.scopesByResourceID[resourceEq].Delete(scope) { - delete(ptr.scopesByResourceID, resourceEq) - } - return len(ptr.scopesByResourceID) == 0 -} - -func (ptr *PendingResources) AsResourceSpans(rewriteTraceId oteltrace.TraceID) []*tracev1.ResourceSpans { +func (ptr *PendingResources) AsResourceSpans(rewriteTraceId unique.Handle[oteltrace.TraceID]) []*tracev1.ResourceSpans { out := make([]*tracev1.ResourceSpans, 0, len(ptr.scopesByResourceID)) for _, scopes := range ptr.scopesByResourceID { resourceSpans := &tracev1.ResourceSpans{ @@ -152,28 +148,67 @@ func (r *ResourceInfo) computeID() string { return base64.StdEncoding.EncodeToString(hash.Sum(nil)) } +type spanObserver struct { + mu sync.Mutex + referencedIDs map[unique.Handle[oteltrace.SpanID]]bool + unobservedIDs sync.WaitGroup +} + +func (obs *spanObserver) ObserveReference(id unique.Handle[oteltrace.SpanID]) { + obs.mu.Lock() + defer obs.mu.Unlock() + if _, referenced := obs.referencedIDs[id]; !referenced { + obs.referencedIDs[id] = false // referenced, but not observed + obs.unobservedIDs.Add(1) + } +} + +func (obs *spanObserver) Observe(id unique.Handle[oteltrace.SpanID]) { + obs.mu.Lock() + defer obs.mu.Unlock() + if observed, referenced := obs.referencedIDs[id]; !observed { // NB: subtle condition + obs.referencedIDs[id] = true + if referenced { + obs.unobservedIDs.Done() + } + } +} + +func (obs *spanObserver) Wait() { + obs.unobservedIDs.Wait() +} + type SpanExportQueue struct { mu sync.Mutex - pendingResourcesByTraceId map[string]*PendingResources - knownTraceIdMappings map[string]oteltrace.TraceID + pendingResourcesByTraceId map[unique.Handle[oteltrace.TraceID]]*PendingResources + knownTraceIdMappings 
map[unique.Handle[oteltrace.TraceID]]unique.Handle[oteltrace.TraceID] uploadC chan []*tracev1.ResourceSpans + closing bool + closed chan struct{} + debugLevel int + debugAllObservedSpans map[unique.Handle[oteltrace.SpanID]]*tracev1.Span + tracker *spanTracker + observer *spanObserver } func NewSpanExportQueue(ctx context.Context, client otlptrace.Client) *SpanExportQueue { + observer := &spanObserver{referencedIDs: make(map[unique.Handle[oteltrace.SpanID]]bool)} + debugLevel := systemContextFromContext(ctx).DebugLevel q := &SpanExportQueue{ - pendingResourcesByTraceId: make(map[string]*PendingResources), - knownTraceIdMappings: make(map[string]oteltrace.TraceID), + pendingResourcesByTraceId: make(map[unique.Handle[oteltrace.TraceID]]*PendingResources), + knownTraceIdMappings: make(map[unique.Handle[oteltrace.TraceID]]unique.Handle[oteltrace.TraceID]), uploadC: make(chan []*tracev1.ResourceSpans, 8), + closed: make(chan struct{}), + debugLevel: debugLevel, + debugAllObservedSpans: make(map[unique.Handle[oteltrace.SpanID]]*tracev1.Span), + tracker: &spanTracker{observer: observer, debugLevel: debugLevel}, + observer: observer, } go func() { - for { - select { - case <-ctx.Done(): - return - case resourceSpans := <-q.uploadC: - if err := client.UploadTraces(ctx, resourceSpans); err != nil { - log.Ctx(ctx).Err(err).Msg("error uploading traces") - } + defer close(q.closed) + for resourceSpans := range q.uploadC { + if err := client.UploadTraces(context.Background(), resourceSpans); err != nil { + log.Ctx(ctx).Err(err).Msg("error uploading traces") } } }() @@ -186,48 +221,65 @@ type WithSchema[T any] struct { } func (q *SpanExportQueue) insertPendingSpanLocked(resource *ResourceInfo, scope *commonv1.InstrumentationScope, scopeSchema string, span *tracev1.Span) { - spanTraceIdHex := oteltrace.TraceID(span.TraceId).String() + spanTraceId := unique.Make(oteltrace.TraceID(span.TraceId)) var pendingTraceResources *PendingResources - if ptr, ok := 
q.pendingResourcesByTraceId[spanTraceIdHex]; ok { + if ptr, ok := q.pendingResourcesByTraceId[spanTraceId]; ok { pendingTraceResources = ptr } else { pendingTraceResources = NewPendingResources() - q.pendingResourcesByTraceId[spanTraceIdHex] = pendingTraceResources + q.pendingResourcesByTraceId[spanTraceId] = pendingTraceResources } pendingTraceResources.Insert(resource, scope, scopeSchema, span) } -func (q *SpanExportQueue) resolveTraceIdMappingLocked(resource *ResourceInfo, scope *commonv1.InstrumentationScope, scopeSchema string, span *tracev1.Span, mapping oteltrace.TraceID) { - originalTraceIdHex := oteltrace.TraceID(span.TraceId).String() - q.insertPendingSpanLocked(resource, scope, scopeSchema, span) - q.knownTraceIdMappings[originalTraceIdHex] = mapping - toUpload := q.pendingResourcesByTraceId[originalTraceIdHex].AsResourceSpans(mapping) - if q.pendingResourcesByTraceId[originalTraceIdHex].Delete(resource, scope) { - delete(q.pendingResourcesByTraceId, originalTraceIdHex) +func (q *SpanExportQueue) resolveTraceIdMappingLocked(original, mapping unique.Handle[oteltrace.TraceID]) [][]*tracev1.ResourceSpans { + q.knownTraceIdMappings[original] = mapping + + toUpload := [][]*tracev1.ResourceSpans{} + if originalPending, ok := q.pendingResourcesByTraceId[original]; ok { + resourceSpans := originalPending.AsResourceSpans(mapping) + delete(q.pendingResourcesByTraceId, original) + toUpload = append(toUpload, resourceSpans) } - q.uploadC <- toUpload + + if original != mapping { + q.knownTraceIdMappings[mapping] = mapping + if targetPending, ok := q.pendingResourcesByTraceId[mapping]; ok { + resourceSpans := targetPending.AsResourceSpans(mapping) + delete(q.pendingResourcesByTraceId, mapping) + toUpload = append(toUpload, resourceSpans) + } + } + return toUpload } -func (q *SpanExportQueue) Enqueue(ctx context.Context, req *coltracepb.ExportTraceServiceRequest) { +var ErrShuttingDown = errors.New("exporter is shutting down") + +func (q *SpanExportQueue) Enqueue(ctx 
context.Context, req *coltracepb.ExportTraceServiceRequest) error { q.mu.Lock() defer q.mu.Unlock() + if q.closing { + return ErrShuttingDown + } - var immediateUpload []*tracev1.ResourceSpans + var toUpload [][]*tracev1.ResourceSpans for _, resource := range req.ResourceSpans { - resourceInfo := newResourceInfo(resource.Resource, resource.SchemaUrl) - knownResources := &tracev1.ResourceSpans{ - Resource: resource.Resource, - SchemaUrl: resource.SchemaUrl, - } for _, scope := range resource.ScopeSpans { - var knownSpans []*tracev1.Span for _, span := range scope.Spans { - spanTraceId := oteltrace.TraceID(span.TraceId) - spanTraceIdHex := oteltrace.TraceID(span.TraceId).String() - formatSpanName(span) - if len(span.ParentSpanId) == 0 { - // observed a new root span + spanId := unique.Make(oteltrace.SpanID(span.SpanId)) + parentSpanId := parentSpanID(span.ParentSpanId) + if q.debugLevel >= 1 { + q.debugAllObservedSpans[spanId] = span + } + if parentSpanId != rootSpanId { + q.observer.ObserveReference(parentSpanId) + continue + } + spanTraceId := unique.Make(oteltrace.TraceID(span.TraceId)) + + if _, ok := q.knownTraceIdMappings[spanTraceId]; !ok { + // observed a new root span with an unknown trace id var pomeriumTraceparent string for _, attr := range span.Attributes { if attr.Key == "pomerium.traceparent" { @@ -235,11 +287,11 @@ func (q *SpanExportQueue) Enqueue(ctx context.Context, req *coltracepb.ExportTra break } } - var targetTraceID oteltrace.TraceID + var mappedTraceID unique.Handle[oteltrace.TraceID] if pomeriumTraceparent == "" { // no replacement id, map the trace to itself and release pending spans - targetTraceID = spanTraceId + mappedTraceID = spanTraceId } else { // this root span has an alternate traceparent. 
permanently rewrite // all spans of the old trace id to use the new trace id @@ -248,33 +300,204 @@ func (q *SpanExportQueue) Enqueue(ctx context.Context, req *coltracepb.ExportTra log.Ctx(ctx).Err(err).Msg("error processing trace") continue } - targetTraceID = tp.TraceID() + mappedTraceID = unique.Make(tp.TraceID()) } - q.resolveTraceIdMappingLocked(resourceInfo, scope.Scope, scope.SchemaUrl, span, targetTraceID) + toUpload = append(toUpload, q.resolveTraceIdMappingLocked(spanTraceId, mappedTraceID)...) + } + } + } + } + + var knownResources []*tracev1.ResourceSpans + for _, resource := range req.ResourceSpans { + resourceInfo := newResourceInfo(resource.Resource, resource.SchemaUrl) + knownResource := &tracev1.ResourceSpans{ + Resource: resource.Resource, + SchemaUrl: resource.SchemaUrl, + } + for _, scope := range resource.ScopeSpans { + var knownSpans []*tracev1.Span + for _, span := range scope.Spans { + spanID := unique.Make(oteltrace.SpanID(span.SpanId)) + spanTraceId := unique.Make(oteltrace.TraceID(span.TraceId)) + q.observer.Observe(spanID) + if mapping, ok := q.knownTraceIdMappings[spanTraceId]; ok { + id := mapping.Value() + copy(span.TraceId, id[:]) + knownSpans = append(knownSpans, span) } else { - if rewrite, ok := q.knownTraceIdMappings[spanTraceIdHex]; ok { - span.TraceId = rewrite[:] - knownSpans = append(knownSpans, span) - } else { - q.insertPendingSpanLocked(resourceInfo, scope.Scope, scope.SchemaUrl, span) - } + q.insertPendingSpanLocked(resourceInfo, scope.Scope, scope.SchemaUrl, span) } } if len(knownSpans) > 0 { - knownResources.ScopeSpans = append(knownResources.ScopeSpans, &tracev1.ScopeSpans{ + knownResource.ScopeSpans = append(knownResource.ScopeSpans, &tracev1.ScopeSpans{ Scope: scope.Scope, SchemaUrl: scope.SchemaUrl, Spans: knownSpans, }) } } - if len(knownResources.ScopeSpans) > 0 { - immediateUpload = append(immediateUpload, knownResources) + if len(knownResource.ScopeSpans) > 0 { + knownResources = append(knownResources, 
knownResource) } } - if len(immediateUpload) > 0 { - q.uploadC <- immediateUpload + if len(knownResources) > 0 { + toUpload = append(toUpload, knownResources) + } + for _, res := range toUpload { + q.uploadC <- res + } + return nil +} + +var ( + ErrIncompleteTraces = errors.New("exporter shut down with incomplete traces") + ErrIncompleteUploads = errors.New("exporter shut down with pending trace uploads") + ErrMissingParentSpans = errors.New("exporter shut down with missing parent spans") +) + +var rootSpanId = unique.Make(oteltrace.SpanID([8]byte{})) + +func parentSpanID(value []byte) unique.Handle[oteltrace.SpanID] { + if len(value) == 0 { + return rootSpanId + } + return unique.Make(oteltrace.SpanID(value)) +} + +func (q *SpanExportQueue) WaitForSpans(maxDuration time.Duration) error { + done := make(chan struct{}) + go func() { + defer close(done) + q.observer.Wait() + }() + select { + case <-done: + return nil + case <-time.After(maxDuration): + return ErrMissingParentSpans + } +} + +func (q *SpanExportQueue) Close(ctx context.Context) error { + q.mu.Lock() + q.closing = true + close(q.uploadC) + q.mu.Unlock() + select { + case <-ctx.Done(): + return context.Cause(ctx) + case <-q.closed: + q.mu.Lock() + defer q.mu.Unlock() + if q.debugLevel >= 1 { + var unknownParentIds []string + for id, known := range q.observer.referencedIDs { + if !known { + unknownParentIds = append(unknownParentIds, id.Value().String()) + } + } + if len(unknownParentIds) > 0 { + msg := strings.Builder{} + msg.WriteString("==================================================\n") + msg.WriteString("WARNING: parent spans referenced but never seen:\n") + for _, str := range unknownParentIds { + msg.WriteString(str) + msg.WriteString("\n") + } + msg.WriteString("==================================================\n") + fmt.Fprint(os.Stderr, msg.String()) + } + } + incomplete := len(q.pendingResourcesByTraceId) > 0 + if incomplete || q.debugLevel >= 3 { + msg := strings.Builder{} + if incomplete 
&& q.debugLevel >= 1 { + msg.WriteString("==================================================\n") + msg.WriteString("WARNING: exporter shut down with incomplete traces\n") + for k, v := range q.pendingResourcesByTraceId { + msg.WriteString(fmt.Sprintf("- Trace: %s\n", k.Value())) + for _, pendingScope := range v.scopesByResourceID { + msg.WriteString(" - Resource:\n") + for _, v := range pendingScope.resource.Resource.Attributes { + msg.WriteString(fmt.Sprintf(" %s=%s\n", v.Key, v.Value.String())) + } + for _, scope := range pendingScope.spansByScope { + if scope.scope != nil { + msg.WriteString(fmt.Sprintf(" Scope: %s\n", scope.scope.Name)) + } else { + msg.WriteString(" Scope: (unknown)\n") + } + msg.WriteString(" Spans:\n") + longestName := 0 + for _, span := range scope.spans { + longestName = max(longestName, len(span.Name)+2) + } + for _, span := range scope.spans { + parentSpanId := parentSpanID(span.ParentSpanId) + _, seenParent := q.debugAllObservedSpans[parentSpanId] + var missing string + if !seenParent { + missing = " [missing]" + } + msg.WriteString(fmt.Sprintf(" - %-*s (trace: %s | span: %s | parent:%s %s)\n", longestName, + "'"+span.Name+"'", hex.EncodeToString(span.TraceId), hex.EncodeToString(span.SpanId), missing, parentSpanId.Value())) + for _, attr := range span.Attributes { + if attr.Key == "caller" { + msg.WriteString(fmt.Sprintf(" => caller: '%s'\n", attr.Value.GetStringValue())) + } + } + } + } + } + } + msg.WriteString("==================================================\n") + } + if (incomplete && q.debugLevel >= 2) || (!incomplete && q.debugLevel >= 3) { + msg.WriteString("==================================================\n") + msg.WriteString("Known trace ids:\n") + for k, v := range q.knownTraceIdMappings { + if k != v { + msg.WriteString(fmt.Sprintf("%s => %s\n", k.Value(), v.Value())) + } else { + msg.WriteString(fmt.Sprintf("%s (no change)\n", k.Value())) + } + } + 
msg.WriteString("==================================================\n") + msg.WriteString("All exported spans:\n") + longestName := 0 + for _, span := range q.debugAllObservedSpans { + longestName = max(longestName, len(span.Name)+2) + } + for _, span := range q.debugAllObservedSpans { + traceid := span.TraceId + spanid := span.SpanId + msg.WriteString(fmt.Sprintf("%-*s (trace: %s | span: %s | parent: %s)", longestName, + "'"+span.Name+"'", hex.EncodeToString(traceid[:]), hex.EncodeToString(spanid[:]), parentSpanID(span.ParentSpanId).Value())) + var foundCaller bool + for _, attr := range span.Attributes { + if attr.Key == "caller" { + msg.WriteString(fmt.Sprintf(" => %s\n", attr.Value.GetStringValue())) + foundCaller = true + break + } + } + if !foundCaller { + msg.WriteString("\n") + } + } + msg.WriteString("==================================================\n") + } + if msg.Len() > 0 { + fmt.Fprint(os.Stderr, msg.String()) + } + if incomplete { + return ErrIncompleteTraces + } + } + log.Ctx(ctx).Debug().Msg("exporter shut down") + return nil } } @@ -308,28 +531,35 @@ func formatSpanName(span *tracev1.Span) { } // Export implements ptraceotlp.GRPCServer. 
-func (srv *Server) Export(ctx context.Context, req *coltracepb.ExportTraceServiceRequest) (*coltracepb.ExportTraceServiceResponse, error) { +func (srv *ExporterServer) Export(ctx context.Context, req *coltracepb.ExportTraceServiceRequest) (*coltracepb.ExportTraceServiceResponse, error) { srv.spanExportQueue.Enqueue(ctx, req) return &coltracepb.ExportTraceServiceResponse{}, nil } -type Server struct { +type ExporterServer struct { coltracepb.UnimplementedTraceServiceServer spanExportQueue *SpanExportQueue + server *grpc.Server + remoteClient otlptrace.Client + cc *grpc.ClientConn } -func NewServer(ctx context.Context, client otlptrace.Client) *Server { - client.Start(ctx) - return &Server{ - spanExportQueue: NewSpanExportQueue(ctx, client), +func NewServer(ctx context.Context, remoteClient otlptrace.Client) *ExporterServer { + if err := remoteClient.Start(ctx); err != nil { + panic(err) } + ex := &ExporterServer{ + spanExportQueue: NewSpanExportQueue(ctx, remoteClient), + remoteClient: remoteClient, + server: grpc.NewServer(grpc.Creds(insecure.NewCredentials())), + } + coltracepb.RegisterTraceServiceServer(ex.server, ex) + return ex } -func (srv *Server) Start(ctx context.Context) otlptrace.Client { +func (srv *ExporterServer) Start(ctx context.Context) { lis := bufconn.Listen(4096) - gs := grpc.NewServer(grpc.Creds(insecure.NewCredentials())) - coltracepb.RegisterTraceServiceServer(gs, srv) - go gs.Serve(lis) + go srv.server.Serve(lis) cc, err := grpc.NewClient("passthrough://ignore", grpc.WithContextDialer(func(context.Context, string) (net.Conn, error) { return lis.Dial() @@ -337,5 +567,125 @@ func (srv *Server) Start(ctx context.Context) otlptrace.Client { if err != nil { panic(err) } - return otlptracegrpc.NewClient(otlptracegrpc.WithGRPCConn(cc)) + srv.cc = cc +} + +func (srv *ExporterServer) NewClient() otlptrace.Client { + return otlptracegrpc.NewClient(otlptracegrpc.WithGRPCConn(srv.cc)) +} + +func (srv *ExporterServer) SpanProcessors() 
[]sdktrace.SpanProcessor { + return []sdktrace.SpanProcessor{srv.spanExportQueue.tracker} +} + +func (srv *ExporterServer) Shutdown(ctx context.Context) error { + stopped := make(chan struct{}) + go func() { + srv.server.GracefulStop() + close(stopped) + }() + select { + case <-stopped: + case <-ctx.Done(): + return context.Cause(ctx) + } + var errs []error + if err := srv.spanExportQueue.WaitForSpans(5 * time.Second); err != nil { + errs = append(errs, err) + } + if err := srv.spanExportQueue.Close(ctx); err != nil { + errs = append(errs, err) + } + if err := srv.remoteClient.Stop(ctx); err != nil { + errs = append(errs, err) + } + return errors.Join(errs...) +} + +type spanTracker struct { + inflightSpans sync.Map + allSpans sync.Map + debugLevel int + observer *spanObserver +} + +type spanInfo struct { + Name string + SpanContext trace.SpanContext + Parent trace.SpanContext +} + +// ForceFlush implements trace.SpanProcessor. +func (t *spanTracker) ForceFlush(ctx context.Context) error { + return nil +} + +// OnEnd implements trace.SpanProcessor. +func (t *spanTracker) OnEnd(s sdktrace.ReadOnlySpan) { + id := unique.Make(s.SpanContext().SpanID()) + t.inflightSpans.Delete(id) +} + +// OnStart implements trace.SpanProcessor. +func (t *spanTracker) OnStart(parent context.Context, s sdktrace.ReadWriteSpan) { + id := unique.Make(s.SpanContext().SpanID()) + t.inflightSpans.Store(id, struct{}{}) + t.observer.Observe(id) + if t.debugLevel >= 3 { + t.allSpans.Store(id, &spanInfo{ + Name: s.Name(), + SpanContext: s.SpanContext(), + Parent: s.Parent(), + }) + } +} + +// Shutdown implements trace.SpanProcessor. 
+func (t *spanTracker) Shutdown(ctx context.Context) error { + msg := strings.Builder{} + if t.debugLevel >= 1 { + incompleteSpans := []*spanInfo{} + t.inflightSpans.Range(func(key, value any) bool { + if info, ok := t.allSpans.Load(key); ok { + incompleteSpans = append(incompleteSpans, info.(*spanInfo)) + } + return true + }) + if len(incompleteSpans) > 0 { + msg.WriteString("==================================================\n") + msg.WriteString("WARNING: spans not ended:\n") + longestName := 0 + for _, span := range incompleteSpans { + longestName = max(longestName, len(span.Name)+2) + } + for _, span := range incompleteSpans { + msg.WriteString(fmt.Sprintf("%-*s (trace: %s | span: %s | parent: %s)\n", longestName, "'"+span.Name+"'", + span.SpanContext.TraceID(), span.SpanContext.SpanID(), span.Parent.SpanID())) + } + msg.WriteString("==================================================\n") + } + } + if t.debugLevel >= 3 { + allSpans := []*spanInfo{} + t.allSpans.Range(func(key, value any) bool { + allSpans = append(allSpans, value.(*spanInfo)) + return true + }) + msg.WriteString("==================================================\n") + msg.WriteString("All observed spans:\n") + longestName := 0 + for _, span := range allSpans { + longestName = max(longestName, len(span.Name)+2) + } + for _, span := range allSpans { + msg.WriteString(fmt.Sprintf("%-*s (trace: %s | span: %s | parent: %s)\n", longestName, "'"+span.Name+"'", + span.SpanContext.TraceID(), span.SpanContext.SpanID(), span.Parent.SpanID())) + } + msg.WriteString("==================================================\n") + } + if msg.Len() > 0 { + fmt.Fprint(os.Stderr, msg.String()) + } + + return nil } diff --git a/internal/telemetry/trace/trace.go b/internal/telemetry/trace/trace.go index 9b811a9c0..89aa10154 100644 --- a/internal/telemetry/trace/trace.go +++ b/internal/telemetry/trace/trace.go @@ -2,13 +2,12 @@ package trace import ( "context" - "encoding/hex" "errors" "fmt" "os" "runtime" - "strconv" - 
"strings" + "sync" + "time" "go.opentelemetry.io/otel" "go.opentelemetry.io/otel/attribute" @@ -20,62 +19,88 @@ import ( sdktrace "go.opentelemetry.io/otel/sdk/trace" semconv "go.opentelemetry.io/otel/semconv/v1.26.0" "go.opentelemetry.io/otel/trace" - "go.opentelemetry.io/otel/trace/embedded" coltracepb "go.opentelemetry.io/proto/otlp/collector/trace/v1" - "google.golang.org/grpc/metadata" - "google.golang.org/grpc/stats" ) -type ( - clientKeyType struct{} - exporterKeyType struct{} - tracerProviderKeyType struct{} - serverKeyType struct{} -) +type systemContextKeyType struct{} -var ( - exporterKey exporterKeyType - tracerProviderKey tracerProviderKeyType - serverKey serverKeyType -) +var systemContextKey systemContextKeyType -type shutdownFunc func(options ...trace.SpanEndOption) +type Options struct { + DebugLevel int +} + +type systemContext struct { + Options + tpm *tracerProviderManager + exporterServer *ExporterServer +} + +func systemContextFromContext(ctx context.Context) *systemContext { + return ctx.Value(systemContextKey).(*systemContext) +} func init() { otel.SetTextMapPropagator(propagation.NewCompositeTextMapPropagator(propagation.TraceContext{}, propagation.Baggage{})) otel.SetTracerProvider(panicTracerProvider{}) } -type panicTracerProvider struct { - embedded.TracerProvider +var _ trace.Tracer = panicTracer{} + +type tracerProviderManager struct { + mu sync.Mutex + tracerProviders []*sdktrace.TracerProvider } -// Tracer implements trace.TracerProvider. -func (w panicTracerProvider) Tracer(name string, options ...trace.TracerOption) trace.Tracer { - panic("global tracer used") +func (tpm *tracerProviderManager) ShutdownAll(ctx context.Context) error { + tpm.mu.Lock() + defer tpm.mu.Unlock() + var errs []error + for _, tp := range tpm.tracerProviders { + errs = append(errs, tp.ForceFlush(ctx)) + } + for _, tp := range tpm.tracerProviders { + errs = append(errs, tp.Shutdown(ctx)) + } + clear(tpm.tracerProviders) + return errors.Join(errs...) 
+} + +func (tpm *tracerProviderManager) Add(tp *sdktrace.TracerProvider) { + tpm.mu.Lock() + defer tpm.mu.Unlock() + tpm.tracerProviders = append(tpm.tracerProviders, tp) +} + +func (op Options) NewContext(ctx context.Context) context.Context { + var remoteClient otlptrace.Client + if os.Getenv("OTEL_EXPORTER_OTLP_PROTOCOL") == "http/protobuf" { + remoteClient = otlptracehttp.NewClient() + } else { + remoteClient = otlptracegrpc.NewClient() + } + sys := &systemContext{ + Options: op, + tpm: &tracerProviderManager{}, + } + ctx = context.WithValue(ctx, systemContextKey, sys) + sys.exporterServer = NewServer(ctx, remoteClient) + sys.exporterServer.Start(ctx) + + return ctx } func NewContext(ctx context.Context) context.Context { - var realClient otlptrace.Client - if os.Getenv("OTEL_EXPORTER_OTLP_PROTOCOL") == "http/protobuf" { - realClient = otlptracehttp.NewClient() - } else { - realClient = otlptracegrpc.NewClient() - } - srv := NewServer(ctx, realClient) - localClient := srv.Start(ctx) - exp, err := otlptrace.New(ctx, localClient) - if err != nil { - panic(err) - } - ctx = context.WithValue(ctx, exporterKey, exp) - ctx = context.WithValue(ctx, serverKey, srv) - return ctx + return Options{}.NewContext(ctx) } func NewTracerProvider(ctx context.Context, serviceName string) trace.TracerProvider { _, file, line, _ := runtime.Caller(1) - exp := ctx.Value(exporterKey).(sdktrace.SpanExporter) + sys := systemContextFromContext(ctx) + exp, err := otlptrace.New(ctx, sys.exporterServer.NewClient()) + if err != nil { + panic(err) + } r, err := resource.Merge( resource.Default(), resource.NewWithAttributes( @@ -87,146 +112,40 @@ func NewTracerProvider(ctx context.Context, serviceName string) trace.TracerProv if err != nil { panic(err) } - return sdktrace.NewTracerProvider( - sdktrace.WithSpanProcessor(&stackTraceProcessor{}), + options := []sdktrace.TracerProviderOption{ sdktrace.WithBatcher(exp), sdktrace.WithResource(r), - ) -} - -type stackTraceProcessor struct{} - -// 
ForceFlush implements trace.SpanProcessor. -func (s *stackTraceProcessor) ForceFlush(ctx context.Context) error { - return nil -} - -// OnEnd implements trace.SpanProcessor. -func (*stackTraceProcessor) OnEnd(s sdktrace.ReadOnlySpan) { -} - -// OnStart implements trace.SpanProcessor. -func (*stackTraceProcessor) OnStart(parent context.Context, s sdktrace.ReadWriteSpan) { - _, file, line, _ := runtime.Caller(2) - s.SetAttributes(attribute.String("caller", fmt.Sprintf("%s:%d", file, line))) -} - -// Shutdown implements trace.SpanProcessor. -func (s *stackTraceProcessor) Shutdown(ctx context.Context) error { - return nil -} - -func ForceFlush(ctx context.Context) error { - if tp, ok := trace.SpanFromContext(ctx).TracerProvider().(interface { - ForceFlush(context.Context) error - }); ok { - return tp.ForceFlush(context.Background()) } - return nil + for _, proc := range sys.exporterServer.SpanProcessors() { + options = append(options, sdktrace.WithSpanProcessor(proc)) + } + if sys.DebugLevel >= 1 { + options = append(options, + sdktrace.WithSpanProcessor(&stackTraceProcessor{}), + ) + } + tp := sdktrace.NewTracerProvider(options...) + sys.tpm.Add(tp) + return tp } -func Shutdown(ctx context.Context) error { - _ = ForceFlush(ctx) - exporter := ctx.Value(exporterKey).(sdktrace.SpanExporter) - return exporter.Shutdown(context.Background()) +func ShutdownContext(ctx context.Context) error { + var errs []error + sys := systemContextFromContext(ctx) + + if err := sys.tpm.ShutdownAll(context.Background()); err != nil { + errs = append(errs, fmt.Errorf("(*tracerProviderManager).ShutdownAll: %w", err)) + } + if err := sys.exporterServer.Shutdown(context.Background()); err != nil { + errs = append(errs, fmt.Errorf("(*Server).Shutdown: %w", err)) + } + return errors.Join(errs...) 
} func ExporterServerFromContext(ctx context.Context) coltracepb.TraceServiceServer { - return ctx.Value(serverKey).(coltracepb.TraceServiceServer) + return systemContextFromContext(ctx).exporterServer } -const PomeriumCoreTracer = "pomerium.io/core" - -// StartSpan starts a new child span of the current span in the context. If -// there is no span in the context, creates a new trace and span. -// -// Returned context contains the newly created span. You can use it to -// propagate the returned span in process. -func Continue(ctx context.Context, name string, o ...trace.SpanStartOption) (context.Context, trace.Span) { - return trace.SpanFromContext(ctx).TracerProvider().Tracer(PomeriumCoreTracer).Start(ctx, name, o...) -} - -func ParseTraceparent(traceparent string) (trace.SpanContext, error) { - parts := strings.Split(traceparent, "-") - if len(parts) != 4 { - return trace.SpanContext{}, errors.New("malformed traceparent") - } - traceId, err := trace.TraceIDFromHex(parts[1]) - if err != nil { - return trace.SpanContext{}, err - } - spanId, err := trace.SpanIDFromHex(parts[2]) - if err != nil { - return trace.SpanContext{}, err - } - traceFlags, err := strconv.ParseUint(parts[3], 6, 32) - if err != nil { - return trace.SpanContext{}, err - } - if len(traceId) != 16 || len(spanId) != 8 { - return trace.SpanContext{}, errors.New("malformed traceparent") - } - return trace.NewSpanContext(trace.SpanContextConfig{ - TraceID: traceId, - SpanID: spanId, - TraceFlags: trace.TraceFlags(traceFlags), - }), nil -} - -func ReplaceTraceID(traceparent string, newTraceID trace.TraceID) string { - parts := strings.Split(traceparent, "-") - if len(parts) != 4 { - return traceparent - } - parts[1] = hex.EncodeToString(newTraceID[:]) - return strings.Join(parts, "-") -} - -func NewStatsHandler(base stats.Handler) stats.Handler { - return &wrapperStatsHandler{ - base: base, - } -} - -type wrapperStatsHandler struct { - base stats.Handler -} - -func (w *wrapperStatsHandler) 
wrapContext(ctx context.Context) context.Context { - md, ok := metadata.FromIncomingContext(ctx) - if !ok { - return ctx - } - traceparent := md.Get("traceparent") - xPomeriumTraceparent := md.Get("x-pomerium-traceparent") - if len(traceparent) > 0 && traceparent[0] != "" && len(xPomeriumTraceparent) > 0 && xPomeriumTraceparent[0] != "" { - newTracectx, err := ParseTraceparent(xPomeriumTraceparent[0]) - if err != nil { - return ctx - } - - md.Set("traceparent", ReplaceTraceID(traceparent[0], newTracectx.TraceID())) - return metadata.NewIncomingContext(ctx, md) - } - return ctx -} - -// HandleConn implements stats.Handler. -func (w *wrapperStatsHandler) HandleConn(ctx context.Context, stats stats.ConnStats) { - w.base.HandleConn(w.wrapContext(ctx), stats) -} - -// HandleRPC implements stats.Handler. -func (w *wrapperStatsHandler) HandleRPC(ctx context.Context, stats stats.RPCStats) { - w.base.HandleRPC(w.wrapContext(ctx), stats) -} - -// TagConn implements stats.Handler. -func (w *wrapperStatsHandler) TagConn(ctx context.Context, info *stats.ConnTagInfo) context.Context { - return w.base.TagConn(w.wrapContext(ctx), info) -} - -// TagRPC implements stats.Handler. 
-func (w *wrapperStatsHandler) TagRPC(ctx context.Context, info *stats.RPCTagInfo) context.Context {
-	return w.base.TagRPC(w.wrapContext(ctx), info)
+func WaitForSpans(ctx context.Context, maxDuration time.Duration) error {
+	return systemContextFromContext(ctx).exporterServer.spanExportQueue.WaitForSpans(maxDuration)
 }
diff --git a/internal/telemetry/trace/util.go b/internal/telemetry/trace/util.go
new file mode 100644
index 000000000..6d3163ecb
--- /dev/null
+++ b/internal/telemetry/trace/util.go
@@ -0,0 +1,46 @@
+package trace
+
+import (
+	"encoding/hex"
+	"errors"
+	"strconv"
+	"strings"
+
+	"go.opentelemetry.io/otel/trace"
+)
+
+func ParseTraceparent(traceparent string) (trace.SpanContext, error) {
+	parts := strings.Split(traceparent, "-")
+	if len(parts) != 4 {
+		return trace.SpanContext{}, errors.New("malformed traceparent")
+	}
+	traceId, err := trace.TraceIDFromHex(parts[1])
+	if err != nil {
+		return trace.SpanContext{}, err
+	}
+	spanId, err := trace.SpanIDFromHex(parts[2])
+	if err != nil {
+		return trace.SpanContext{}, err
+	}
+	traceFlags, err := strconv.ParseUint(parts[3], 16, 8)
+	if err != nil {
+		return trace.SpanContext{}, err
+	}
+	if len(traceId) != 16 || len(spanId) != 8 {
+		return trace.SpanContext{}, errors.New("malformed traceparent")
+	}
+	return trace.NewSpanContext(trace.SpanContextConfig{
+		TraceID:    traceId,
+		SpanID:     spanId,
+		TraceFlags: trace.TraceFlags(traceFlags),
+	}), nil
+}
+
+func ReplaceTraceID(traceparent string, newTraceID trace.TraceID) string {
+	parts := strings.Split(traceparent, "-")
+	if len(parts) != 4 {
+		return traceparent
+	}
+	parts[1] = hex.EncodeToString(newTraceID[:])
+	return strings.Join(parts, "-")
+}
diff --git a/internal/testenv/environment.go b/internal/testenv/environment.go
index 2cb78c50c..461e529de 100644
--- a/internal/testenv/environment.go
+++ b/internal/testenv/environment.go
@@ -211,6 +211,7 @@ type environment struct {
 	logWriter      *log.MultiWriter
 	tracerProvider oteltrace.TracerProvider
 	tracer 
oteltrace.Tracer + rootSpan oteltrace.Span mods []WithCaller[Modifier] tasks []WithCaller[Task] @@ -267,9 +268,10 @@ func Silent(silent ...bool) EnvironmentOption { var setGrpcLoggerOnce sync.Once var ( - flagDebug = flag.Bool("env.debug", false, "enables test environment debug logging (equivalent to Debug() option)") - flagPauseOnFailure = flag.Bool("env.pause-on-failure", false, "enables pausing the test environment on failure (equivalent to PauseOnFailure() option)") - flagSilent = flag.Bool("env.silent", false, "suppresses all test environment output (equivalent to Silent() option)") + flagDebug = flag.Bool("env.debug", false, "enables test environment debug logging (equivalent to Debug() option)") + flagPauseOnFailure = flag.Bool("env.pause-on-failure", false, "enables pausing the test environment on failure (equivalent to PauseOnFailure() option)") + flagSilent = flag.Bool("env.silent", false, "suppresses all test environment output (equivalent to Silent() option)") + flagTraceDebugLevel = flag.Int("env.trace-debug-level", 0, "trace debug level") ) func New(t testing.TB, opts ...EnvironmentOption) Environment { @@ -320,14 +322,16 @@ func New(t testing.TB, opts ...EnvironmentOption) Environment { }) logger := zerolog.New(writer).With().Timestamp().Logger().Level(zerolog.DebugLevel) - ctx, cancel := context.WithCancelCause(logger.WithContext(trace.NewContext(context.Background()))) - t.Cleanup(func() { - trace.Shutdown(ctx) - }) + ctx := trace.Options{ + DebugLevel: *flagTraceDebugLevel, + }.NewContext(context.Background()) + ctx = logger.WithContext(ctx) tracerProvider := trace.NewTracerProvider(ctx, "Test Environment") tracer := tracerProvider.Tracer(trace.PomeriumCoreTracer) - ctx, span := tracer.Start(ctx, t.Name()) + ctx, span := tracer.Start(ctx, t.Name(), oteltrace.WithNewRoot()) require.NoError(t, err) + + ctx, cancel := context.WithCancelCause(ctx) taskErrGroup, ctx := errgroup.WithContext(ctx) e := &environment{ @@ -352,14 +356,13 @@ func New(t 
testing.TB, opts ...EnvironmentOption) Environment { ctx: ctx, cancel: cancel, tracerProvider: tracerProvider, - tracer: tracerProvider.Tracer(trace.PomeriumCoreTracer), + tracer: tracer, logWriter: writer, taskErrGroup: taskErrGroup, stateChangeListeners: make(map[EnvironmentState][]func()), + rootSpan: span, } - e.OnStateChanged(Stopped, func() { - span.End() - }) + _, err = rand.Read(e.sharedSecret[:]) require.NoError(t, err) _, err = rand.Read(e.cookieSecret[:]) @@ -561,6 +564,7 @@ func (e *environment) Start() { opts := []pomerium.Option{ pomerium.WithOverrideFileManager(fileMgr), + pomerium.WithEnvoyServerOptions(envoy.WithExitGracePeriod(10 * time.Second)), } envoyBinaryPath := filepath.Join(e.workspaceFolder, fmt.Sprintf("pkg/envoy/files/envoy-%s-%s", runtime.GOOS, runtime.GOARCH)) if envutil.EnvoyProfilerAvailable(envoyBinaryPath) { @@ -591,10 +595,7 @@ func (e *environment) Start() { } if len(envVars) > 0 { e.debugf("adding envoy env vars: %v\n", envVars) - opts = append(opts, pomerium.WithEnvoyServerOptions( - envoy.WithExtraEnvVars(envVars...), - envoy.WithExitGracePeriod(10*time.Second), // allow envoy time to flush pprof data to disk - )) + opts = append(opts, pomerium.WithEnvoyServerOptions(envoy.WithExtraEnvVars(envVars...))) } } else { e.debugf("envoy profiling not available") @@ -602,7 +603,11 @@ func (e *environment) Start() { pom := pomerium.New(opts...) 
e.OnStateChanged(Stopping, func() { - pom.Shutdown() + if err := pom.Shutdown(ctx); err != nil { + log.Ctx(ctx).Err(err).Msg("error shutting down pomerium server") + } else { + e.debugf("pomerium server shut down without error") + } }) pom.Start(ctx, e.tracerProvider, e.src) return pom.Wait() @@ -742,6 +747,8 @@ func (e *environment) Stop() { err := e.taskErrGroup.Wait() e.advanceState(Stopped) e.debugf("stop: done waiting") + e.rootSpan.End() + assert.NoError(e.t, trace.ShutdownContext(e.ctx)) assert.ErrorIs(e.t, err, ErrCauseManualStop) }) } diff --git a/internal/testenv/selftests/tracing_test.go b/internal/testenv/selftests/tracing_test.go new file mode 100644 index 000000000..47912de94 --- /dev/null +++ b/internal/testenv/selftests/tracing_test.go @@ -0,0 +1,55 @@ +package selftests_test + +import ( + "context" + "io" + "net/http" + "testing" + + "github.com/pomerium/pomerium/config" + "github.com/pomerium/pomerium/internal/testenv" + "github.com/pomerium/pomerium/internal/testenv/scenarios" + "github.com/pomerium/pomerium/internal/testenv/snippets" + "github.com/pomerium/pomerium/internal/testenv/upstreams" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.opentelemetry.io/otel/trace" +) + +func TestOTLPTracing(t *testing.T) { + t.Setenv("OTEL_EXPORTER_OTLP_TRACES_ENDPOINT", "http://localhost:4317") + env := testenv.New(t) + defer env.Stop() + env.Add(testenv.ModifierFunc(func(ctx context.Context, cfg *config.Config) { + cfg.Options.ProxyLogLevel = config.LogLevelInfo + })) + up := upstreams.HTTP(nil, upstreams.WithDisplayName("Upstream")) + up.Handle("/foo", func(w http.ResponseWriter, req *http.Request) { + w.Write([]byte("OK")) + }) + env.Add(scenarios.NewIDP([]*scenarios.User{ + { + Email: "foo@example.com", + FirstName: "Firstname", + LastName: "Lastname", + }, + })) + + route := up.Route(). + From(env.SubdomainURL("foo")). 
+		PPL(`{"allow":{"and":[{"email":{"is":"foo@example.com"}}]}}`)
+
+	env.AddUpstream(up)
+	env.Start()
+	snippets.WaitStartupComplete(env)
+
+	ctx, span := env.Tracer().Start(env.Context(), "Authenticate", trace.WithNewRoot())
+	resp, err := up.Get(route, upstreams.AuthenticateAs("foo@example.com"), upstreams.Path("/foo"), upstreams.Context(ctx))
+	span.End()
+	require.NoError(t, err)
+	body, err := io.ReadAll(resp.Body)
+	require.NoError(t, err)
+	resp.Body.Close()
+	assert.Equal(t, 200, resp.StatusCode)
+	assert.Equal(t, "OK", string(body))
+}
diff --git a/internal/testenv/upstreams/http.go b/internal/testenv/upstreams/http.go
index ac74bc108..f2bbbb217 100644
--- a/internal/testenv/upstreams/http.go
+++ b/internal/testenv/upstreams/http.go
@@ -157,7 +157,7 @@ type httpUpstream struct {
 	clientCache sync.Map // map[testenv.Route]*http.Client
 
 	router *mux.Router
-	tracerProvider oteltrace.TracerProvider
+	tracerProvider values.MutableValue[oteltrace.TracerProvider]
 }
 
 var (
@@ -176,6 +176,7 @@ func HTTP(tlsConfig values.Value[*tls.Config], opts ...HTTPUpstreamOption) HTTPU
 		serverPort: values.Deferred[int](),
 		router:     mux.NewRouter(),
 		tlsConfig:  tlsConfig,
+		tracerProvider: values.Deferred[oteltrace.TracerProvider](),
 	}
 	up.RecordCaller()
 	return up
@@ -213,8 +214,8 @@ func (h *httpUpstream) Run(ctx context.Context) error {
 	if h.tlsConfig != nil {
 		tlsConfig = h.tlsConfig.Value()
 	}
-	h.router.Use(trace.NewHTTPMiddleware(otelhttp.WithTracerProvider(h.tracerProvider)))
-	h.tracerProvider = trace.NewTracerProvider(ctx, h.displayName)
+	h.tracerProvider.Resolve(trace.NewTracerProvider(ctx, h.displayName))
+	h.router.Use(trace.NewHTTPMiddleware(otelhttp.WithTracerProvider(h.tracerProvider.Value())))
 
 	server := &http.Server{
 		Handler: h.router,
@@ -263,34 +264,6 @@ func (h *httpUpstream) Do(method string, r testenv.Route, opts ...RequestOption)
 		})
 	}
 
-	req, err := http.NewRequestWithContext(options.requestCtx, method, u.String(), nil)
-	if err != nil {
-		return nil, err
-	}
-	switch 
body := options.body.(type) { - case string: - req.Body = io.NopCloser(strings.NewReader(body)) - case []byte: - req.Body = io.NopCloser(bytes.NewReader(body)) - case io.Reader: - req.Body = io.NopCloser(body) - case proto.Message: - buf, err := proto.Marshal(body) - if err != nil { - return nil, err - } - req.Body = io.NopCloser(bytes.NewReader(buf)) - req.Header.Set("Content-Type", "application/octet-stream") - default: - buf, err := json.Marshal(body) - if err != nil { - panic(fmt.Sprintf("unsupported body type: %T", body)) - } - req.Body = io.NopCloser(bytes.NewReader(buf)) - req.Header.Set("Content-Type", "application/json") - case nil: - } - newClient := func() *http.Client { c := http.Client{ Transport: otelhttp.NewTransport(&http.Transport{ @@ -299,7 +272,7 @@ func (h *httpUpstream) Do(method string, r testenv.Route, opts ...RequestOption) Certificates: options.clientCerts, }, }, - otelhttp.WithTracerProvider(h.tracerProvider), + otelhttp.WithTracerProvider(h.tracerProvider.Value()), otelhttp.WithSpanNameFormatter(func(operation string, r *http.Request) string { return fmt.Sprintf("Client: %s %s", r.Method, r.URL.Path) }), @@ -322,11 +295,38 @@ func (h *httpUpstream) Do(method string, r testenv.Route, opts ...RequestOption) var resp *http.Response if err := retry.Retry(options.requestCtx, "http", func(ctx context.Context) error { - var err error + req, err := http.NewRequestWithContext(options.requestCtx, method, u.String(), nil) + if err != nil { + return err + } + switch body := options.body.(type) { + case string: + req.Body = io.NopCloser(strings.NewReader(body)) + case []byte: + req.Body = io.NopCloser(bytes.NewReader(body)) + case io.Reader: + req.Body = io.NopCloser(body) + case proto.Message: + buf, err := proto.Marshal(body) + if err != nil { + return err + } + req.Body = io.NopCloser(bytes.NewReader(buf)) + req.Header.Set("Content-Type", "application/octet-stream") + default: + buf, err := json.Marshal(body) + if err != nil { + 
panic(fmt.Sprintf("unsupported body type: %T", body)) + } + req.Body = io.NopCloser(bytes.NewReader(buf)) + req.Header.Set("Content-Type", "application/json") + case nil: + } + if options.authenticateAs != "" { - resp, err = authenticateFlow(ctx, client, req, options.authenticateAs) //nolint:bodyclose + resp, err = authenticateFlow(ctx, client, req, options.authenticateAs) } else { - resp, err = client.Do(req) //nolint:bodyclose + resp, err = client.Do(req) } // retry on connection refused if err != nil { @@ -338,6 +338,9 @@ func (h *httpUpstream) Do(method string, r testenv.Route, opts ...RequestOption) return retry.NewTerminalError(err) } if resp.StatusCode/100 == 5 { + if err := resp.Body.Close(); err != nil { + panic(err) + } oteltrace.SpanFromContext(ctx).AddEvent("Retrying on 5xx error", oteltrace.WithAttributes( attribute.String("status", resp.Status), )) @@ -357,7 +360,6 @@ func authenticateFlow(ctx context.Context, client *http.Client, req *http.Reques if err != nil { return nil, err } - location := res.Request.URL if location.Hostname() == originalHostname { // already authenticated diff --git a/pkg/cmd/pomerium/pomerium.go b/pkg/cmd/pomerium/pomerium.go index d048bedca..457c5f7d7 100644 --- a/pkg/cmd/pomerium/pomerium.go +++ b/pkg/cmd/pomerium/pomerium.go @@ -205,10 +205,13 @@ func (p *Pomerium) Start(ctx context.Context, tracerProvider oteltrace.TracerPro return nil } -func (p *Pomerium) Shutdown() error { - _ = p.envoyServer.Close() // this only errors if signaling envoy fails +func (p *Pomerium) Shutdown(ctx context.Context) error { + _ = trace.WaitForSpans(ctx, p.envoyServer.ExitGracePeriod()) + var errs []error + errs = append(errs, p.envoyServer.Close()) // this only errors if signaling envoy fails p.cancel(ErrShutdown) - return p.Wait() + errs = append(errs, p.Wait()) + return errors.Join(errs...) 
} func (p *Pomerium) Wait() error { diff --git a/pkg/envoy/envoy.go b/pkg/envoy/envoy.go index f5c1a7dac..7824b8749 100644 --- a/pkg/envoy/envoy.go +++ b/pkg/envoy/envoy.go @@ -61,6 +61,10 @@ type ServerOptions struct { exitGracePeriod time.Duration } +func (opts *ServerOptions) ExitGracePeriod() time.Duration { + return opts.exitGracePeriod +} + type ServerOption func(*ServerOptions) func (o *ServerOptions) apply(opts ...ServerOption) {