diff --git a/internal/telemetry/trace/server.go b/internal/telemetry/trace/server.go index 8f732df88..dd56402e3 100644 --- a/internal/telemetry/trace/server.go +++ b/internal/telemetry/trace/server.go @@ -3,7 +3,6 @@ package trace import ( "context" "encoding/base64" - "encoding/hex" "errors" "fmt" "net" @@ -28,7 +27,6 @@ import ( "go.opentelemetry.io/otel/exporters/otlp/otlptrace" "go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc" sdktrace "go.opentelemetry.io/otel/sdk/trace" - "go.opentelemetry.io/otel/trace" oteltrace "go.opentelemetry.io/otel/trace" ) @@ -148,13 +146,19 @@ func (r *ResourceInfo) computeID() string { return base64.StdEncoding.EncodeToString(hash.Sum(nil)) } +type SpanObserver interface { + ObserveReference(id oteltrace.SpanID) + Observe(id oteltrace.SpanID) + Wait() +} + type spanObserver struct { mu sync.Mutex - referencedIDs map[unique.Handle[oteltrace.SpanID]]bool + referencedIDs map[oteltrace.SpanID]bool unobservedIDs sync.WaitGroup } -func (obs *spanObserver) ObserveReference(id unique.Handle[oteltrace.SpanID]) { +func (obs *spanObserver) ObserveReference(id oteltrace.SpanID) { obs.mu.Lock() defer obs.mu.Unlock() if _, referenced := obs.referencedIDs[id]; !referenced { @@ -163,7 +167,7 @@ func (obs *spanObserver) ObserveReference(id unique.Handle[oteltrace.SpanID]) { } } -func (obs *spanObserver) Observe(id unique.Handle[oteltrace.SpanID]) { +func (obs *spanObserver) Observe(id oteltrace.SpanID) { obs.mu.Lock() defer obs.mu.Unlock() if observed, referenced := obs.referencedIDs[id]; !observed { // NB: subtle condition @@ -178,6 +182,12 @@ func (obs *spanObserver) Wait() { obs.unobservedIDs.Wait() } +type noopSpanObserver struct{} + +func (noopSpanObserver) ObserveReference(oteltrace.SpanID) {} +func (noopSpanObserver) Observe(oteltrace.SpanID) {} +func (noopSpanObserver) Wait() {} + type SpanExportQueue struct { mu sync.Mutex pendingResourcesByTraceId map[unique.Handle[oteltrace.TraceID]]*PendingResources @@ -185,23 +195,28 @@ type SpanExportQueue struct { uploadC chan []*tracev1.ResourceSpans closing bool closed chan struct{} - debugLevel int - debugAllObservedSpans map[unique.Handle[oteltrace.SpanID]]*tracev1.Span + debugFlags DebugFlags + debugAllObservedSpans map[oteltrace.SpanID]*tracev1.Span tracker *spanTracker - observer *spanObserver + observer SpanObserver } func NewSpanExportQueue(ctx context.Context, client otlptrace.Client) *SpanExportQueue { - observer := &spanObserver{referencedIDs: make(map[unique.Handle[oteltrace.SpanID]]bool)} - debugLevel := systemContextFromContext(ctx).DebugLevel + debug := systemContextFromContext(ctx).DebugFlags + var observer SpanObserver + if debug.Check(TrackSpanReferences) { + observer = &spanObserver{referencedIDs: make(map[oteltrace.SpanID]bool)} + } else { + observer = noopSpanObserver{} + } q := &SpanExportQueue{ pendingResourcesByTraceId: make(map[unique.Handle[oteltrace.TraceID]]*PendingResources), knownTraceIdMappings: make(map[unique.Handle[oteltrace.TraceID]]unique.Handle[oteltrace.TraceID]), uploadC: make(chan []*tracev1.ResourceSpans, 8), closed: make(chan struct{}), - debugLevel: debugLevel, - debugAllObservedSpans: make(map[unique.Handle[oteltrace.SpanID]]*tracev1.Span), - tracker: &spanTracker{observer: observer, debugLevel: debugLevel}, + debugFlags: debug, + debugAllObservedSpans: make(map[oteltrace.SpanID]*tracev1.Span), + tracker: &spanTracker{observer: observer, debugFlags: debug}, observer: observer, } go func() { @@ -267,13 +282,12 @@ func (q *SpanExportQueue) Enqueue(ctx context.Context, req *coltracepb.ExportTra for _, scope := range resource.ScopeSpans { for _, span := range scope.Spans { formatSpanName(span) - spanId := unique.Make(oteltrace.SpanID(span.SpanId)) - parentSpanId := parentSpanID(span.ParentSpanId) - if q.debugLevel >= 1 { + spanId := oteltrace.SpanID(span.SpanId) + if q.debugFlags.Check(TrackAllSpans) { q.debugAllObservedSpans[spanId] = span } - if parentSpanId != rootSpanId { - q.observer.ObserveReference(parentSpanId) + if len(span.ParentSpanId) != 0 { // if parent is not a root span + q.observer.ObserveReference(oteltrace.SpanID(span.ParentSpanId)) continue } spanTraceId := unique.Make(oteltrace.TraceID(span.TraceId)) @@ -319,10 +333,10 @@ func (q *SpanExportQueue) Enqueue(ctx context.Context, req *coltracepb.ExportTra for _, scope := range resource.ScopeSpans { var knownSpans []*tracev1.Span for _, span := range scope.Spans { - spanID := unique.Make(oteltrace.SpanID(span.SpanId)) - spanTraceId := unique.Make(oteltrace.TraceID(span.TraceId)) + spanID := oteltrace.SpanID(span.SpanId) + traceId := unique.Make(oteltrace.TraceID(span.TraceId)) q.observer.Observe(spanID) - if mapping, ok := q.knownTraceIdMappings[spanTraceId]; ok { + if mapping, ok := q.knownTraceIdMappings[traceId]; ok { id := mapping.Value() copy(span.TraceId, id[:]) knownSpans = append(knownSpans, span) @@ -357,15 +371,6 @@ var ( ErrMissingParentSpans = errors.New("exporter shut down with missing parent spans") ) -var rootSpanId = unique.Make(oteltrace.SpanID([8]byte{})) - -func parentSpanID(value []byte) unique.Handle[oteltrace.SpanID] { - if len(value) == 0 { - return rootSpanId - } - return unique.Make(oteltrace.SpanID(value)) -} - func (q *SpanExportQueue) WaitForSpans(maxDuration time.Duration) error { done := make(chan struct{}) go func() { @@ -391,116 +396,122 @@ func (q *SpanExportQueue) Close(ctx context.Context) error { case <-q.closed: q.mu.Lock() defer q.mu.Unlock() - if q.debugLevel >= 1 { + if q.debugFlags.Check(TrackSpanReferences) { var unknownParentIds []string - for id, known := range q.observer.referencedIDs { + for id, known := range q.observer.(*spanObserver).referencedIDs { if !known { - unknownParentIds = append(unknownParentIds, id.Value().String()) + unknownParentIds = append(unknownParentIds, id.String()) } } if len(unknownParentIds) > 0 { - msg := strings.Builder{} - msg.WriteString("==================================================\n") - msg.WriteString("WARNING: parent spans referenced but never seen:\n") + msg := startMsg("WARNING: parent spans referenced but never seen:\n") for _, str := range unknownParentIds { msg.WriteString(str) msg.WriteString("\n") } - msg.WriteString("==================================================\n") - fmt.Fprint(os.Stderr, msg.String()) + endMsg(msg) } } + didWarn := false incomplete := len(q.pendingResourcesByTraceId) > 0 - if incomplete || q.debugLevel >= 3 { - msg := strings.Builder{} - if incomplete && q.debugLevel >= 1 { - msg.WriteString("==================================================\n") - msg.WriteString("WARNING: exporter shut down with incomplete traces\n") - for k, v := range q.pendingResourcesByTraceId { - msg.WriteString(fmt.Sprintf("- Trace: %s\n", k.Value())) - for _, pendingScope := range v.scopesByResourceID { - msg.WriteString(" - Resource:\n") - for _, v := range pendingScope.resource.Resource.Attributes { - msg.WriteString(fmt.Sprintf(" %s=%s\n", v.Key, v.Value.String())) + if incomplete && q.debugFlags.Check(WarnOnIncompleteTraces) { + didWarn = true + msg := startMsg("WARNING: exporter shut down with incomplete traces\n") + for k, v := range q.pendingResourcesByTraceId { + fmt.Fprintf(msg, "- Trace: %s\n", k.Value()) + for _, pendingScope := range v.scopesByResourceID { + msg.WriteString(" - Resource:\n") + for _, v := range pendingScope.resource.Resource.Attributes { + fmt.Fprintf(msg, " %s=%s\n", v.Key, v.Value.String()) + } + for _, scope := range pendingScope.spansByScope { + if scope.scope != nil { + fmt.Fprintf(msg, " Scope: %s\n", scope.scope.Name) + } else { + msg.WriteString(" Scope: (unknown)\n") } - for _, scope := range pendingScope.spansByScope { - if scope.scope != nil { - msg.WriteString(fmt.Sprintf(" Scope: %s\n", scope.scope.Name)) - } else { - msg.WriteString(" Scope: (unknown)\n") + msg.WriteString(" Spans:\n") + longestName := 0 + for _, span := range scope.spans { + longestName = max(longestName, len(span.Name)+2) + } + for _, span := range scope.spans { + parentSpanId := oteltrace.SpanID(span.ParentSpanId) + _, seenParent := q.debugAllObservedSpans[parentSpanId] + var missing string + if !seenParent { + missing = " [missing]" } - msg.WriteString(" Spans:\n") - longestName := 0 - for _, span := range scope.spans { - longestName = max(longestName, len(span.Name)+2) - } - for _, span := range scope.spans { - parentSpanId := parentSpanID(span.ParentSpanId) - _, seenParent := q.debugAllObservedSpans[parentSpanId] - var missing string - if !seenParent { - missing = " [missing]" - } - msg.WriteString(fmt.Sprintf(" - %-*s (trace: %s | span: %s | parent:%s %s)\n", longestName, - "'"+span.Name+"'", hex.EncodeToString(span.TraceId), hex.EncodeToString(span.SpanId), missing, parentSpanId.Value())) - for _, attr := range span.Attributes { - if attr.Key == "caller" { - msg.WriteString(fmt.Sprintf(" => caller: '%s'\n", attr.Value.GetStringValue())) - } + fmt.Fprintf(msg, " - %-*s (trace: %s | span: %s | parent:%s %s)\n", longestName, + "'"+span.Name+"'", oteltrace.SpanID(span.TraceId), oteltrace.SpanID(span.SpanId), missing, parentSpanId) + for _, attr := range span.Attributes { + if attr.Key == "caller" { + fmt.Fprintf(msg, " => caller: '%s'\n", attr.Value.GetStringValue()) + break } } } } } - msg.WriteString("==================================================\n") - } - if (incomplete && q.debugLevel >= 2) || (!incomplete && q.debugLevel >= 3) { - msg.WriteString("==================================================\n") - msg.WriteString("Known trace ids:\n") - for k, v := range q.knownTraceIdMappings { - if k != v { - msg.WriteString(fmt.Sprintf("%s => %s\n", k.Value(), v.Value())) - } else { - msg.WriteString(fmt.Sprintf("%s (no change)\n", k.Value())) - } - } - msg.WriteString("==================================================\n") - msg.WriteString("All exported spans:\n") - longestName := 0 - for _, span := range q.debugAllObservedSpans { - longestName = max(longestName, len(span.Name)+2) - } - for _, span := range q.debugAllObservedSpans { - traceid := span.TraceId - spanid := span.SpanId - msg.WriteString(fmt.Sprintf("%-*s (trace: %s | span: %s | parent: %s)", longestName, - "'"+span.Name+"'", hex.EncodeToString(traceid[:]), hex.EncodeToString(spanid[:]), parentSpanID(span.ParentSpanId).Value())) - var foundCaller bool - for _, attr := range span.Attributes { - if attr.Key == "caller" { - msg.WriteString(fmt.Sprintf(" => %s\n", attr.Value.GetStringValue())) - foundCaller = true - break - } - } - if !foundCaller { - msg.WriteString("\n") - } - } - msg.WriteString("==================================================\n") - } - if msg.Len() > 0 { - fmt.Fprint(os.Stderr, msg.String()) - } - if incomplete { - return ErrIncompleteTraces } + endMsg(msg) } + + if q.debugFlags.Check(LogTraceIDMappings) || (didWarn && q.debugFlags.Check(LogTraceIDMappingsOnWarn)) { + msg := startMsg("Known trace ids:\n") + for k, v := range q.knownTraceIdMappings { + if k != v { + fmt.Fprintf(msg, "%s => %s\n", k.Value(), v.Value()) + } else { + fmt.Fprintf(msg, "%s (no change)\n", k.Value()) + } + } + endMsg(msg) + } + if q.debugFlags.Check(LogAllSpans) || (didWarn && q.debugFlags.Check(LogAllSpansOnWarn)) { + msg := startMsg("All exported spans:\n") + longestName := 0 + for _, span := range q.debugAllObservedSpans { + longestName = max(longestName, len(span.Name)+2) + } + for _, span := range q.debugAllObservedSpans { + fmt.Fprintf(msg, "%-*s (trace: %s | span: %s | parent: %s)", longestName, + "'"+span.Name+"'", oteltrace.TraceID(span.TraceId), oteltrace.SpanID(span.SpanId), oteltrace.SpanID(span.ParentSpanId)) + var foundCaller bool + for _, attr := range span.Attributes { + if attr.Key == "caller" { + fmt.Fprintf(msg, " => %s\n", attr.Value.GetStringValue()) + foundCaller = true + break + } + } + if !foundCaller { + msg.WriteString("\n") + } + } + endMsg(msg) + } + if incomplete { + return ErrIncompleteTraces + } + log.Ctx(ctx).Debug().Msg("exporter shut down") return nil } } +func startMsg(title string) *strings.Builder { + msg := &strings.Builder{} + msg.WriteString("\n==================================================\n") + msg.WriteString(title) + return msg +} + +func endMsg(msg *strings.Builder) { + msg.WriteString("==================================================\n") + fmt.Fprint(os.Stderr, msg.String()) +} + func formatSpanName(span *tracev1.Span) { hasPath := strings.Contains(span.GetName(), "${path}") hasHost := strings.Contains(span.GetName(), "${host}") @@ -605,15 +616,15 @@ func (srv *ExporterServer) Shutdown(ctx context.Context) error { type spanTracker struct { inflightSpans sync.Map allSpans sync.Map - debugLevel int - observer *spanObserver + debugFlags DebugFlags + observer SpanObserver shutdownOnce sync.Once } type spanInfo struct { Name string - SpanContext trace.SpanContext - Parent trace.SpanContext + SpanContext oteltrace.SpanContext + Parent oteltrace.SpanContext } // ForceFlush implements trace.SpanProcessor. @@ -623,16 +634,16 @@ func (t *spanTracker) ForceFlush(ctx context.Context) error { // OnEnd implements trace.SpanProcessor. func (t *spanTracker) OnEnd(s sdktrace.ReadOnlySpan) { - id := unique.Make(s.SpanContext().SpanID()) + id := s.SpanContext().SpanID() t.inflightSpans.Delete(id) } // OnStart implements trace.SpanProcessor. func (t *spanTracker) OnStart(parent context.Context, s sdktrace.ReadWriteSpan) { - id := unique.Make(s.SpanContext().SpanID()) + id := s.SpanContext().SpanID() t.inflightSpans.Store(id, struct{}{}) t.observer.Observe(id) - if t.debugLevel >= 3 { + if t.debugFlags.Check(TrackAllSpans) { t.allSpans.Store(id, &spanInfo{ Name: s.Name(), SpanContext: s.SpanContext(), @@ -643,50 +654,67 @@ func (t *spanTracker) OnStart(parent context.Context, s sdktrace.ReadWriteSpan) // Shutdown implements trace.SpanProcessor. func (t *spanTracker) Shutdown(ctx context.Context) error { + if t.debugFlags == 0 { + return nil + } t.shutdownOnce.Do(func() { - msg := strings.Builder{} - if t.debugLevel >= 1 { - incompleteSpans := []*spanInfo{} - t.inflightSpans.Range(func(key, value any) bool { - if info, ok := t.allSpans.Load(key); ok { - incompleteSpans = append(incompleteSpans, info.(*spanInfo)) + didWarn := false + if t.debugFlags.Check(WarnOnIncompleteSpans) { + if t.debugFlags.Check(TrackAllSpans) { + incompleteSpans := []*spanInfo{} + t.inflightSpans.Range(func(key, value any) bool { + if info, ok := t.allSpans.Load(key); ok { + incompleteSpans = append(incompleteSpans, info.(*spanInfo)) + } + return true + }) + if len(incompleteSpans) > 0 { + didWarn = true + msg := startMsg("WARNING: spans not ended:\n") + longestName := 0 + for _, span := range incompleteSpans { + longestName = max(longestName, len(span.Name)+2) + } + for _, span := range incompleteSpans { + fmt.Fprintf(msg, "%-*s (trace: %s | span: %s | parent: %s)\n", longestName, "'"+span.Name+"'", + span.SpanContext.TraceID(), span.SpanContext.SpanID(), span.Parent.SpanID()) + } + endMsg(msg) } - return true - }) - if len(incompleteSpans) > 0 { - msg.WriteString("==================================================\n") - msg.WriteString("WARNING: spans not ended:\n") - longestName := 0 - for _, span := range incompleteSpans { - longestName = max(longestName, len(span.Name)+2) + } else { + incompleteSpans := []string{} + t.inflightSpans.Range(func(key, value any) bool { + incompleteSpans = append(incompleteSpans, key.(string)) + return true + }) + if len(incompleteSpans) > 0 { + didWarn = true + msg := startMsg("WARNING: spans not ended:\n") + for _, span := range incompleteSpans { + fmt.Fprintf(msg, "%s\n", span) + } + msg.WriteString("Note: set TrackAllObservedSpans flag for more info\n") + endMsg(msg) } - for _, span := range incompleteSpans { - msg.WriteString(fmt.Sprintf("%-*s (trace: %s | span: %s | parent: %s)\n", longestName, "'"+span.Name+"'", - span.SpanContext.TraceID(), span.SpanContext.SpanID(), span.Parent.SpanID())) - } - msg.WriteString("==================================================\n") } } - if t.debugLevel >= 3 { + + if t.debugFlags.Check(LogAllSpans) || (t.debugFlags.Check(LogAllSpansOnWarn) && didWarn) { allSpans := []*spanInfo{} t.allSpans.Range(func(key, value any) bool { allSpans = append(allSpans, value.(*spanInfo)) return true }) - msg.WriteString("==================================================\n") - msg.WriteString("All observed spans:\n") + msg := startMsg("All observed spans:\n") longestName := 0 for _, span := range allSpans { longestName = max(longestName, len(span.Name)+2) } for _, span := range allSpans { - msg.WriteString(fmt.Sprintf("%-*s (trace: %s | span: %s | parent: %s)\n", longestName, "'"+span.Name+"'", - span.SpanContext.TraceID(), span.SpanContext.SpanID(), span.Parent.SpanID())) + fmt.Fprintf(msg, "%-*s (trace: %s | span: %s | parent: %s)\n", longestName, "'"+span.Name+"'", + span.SpanContext.TraceID(), span.SpanContext.SpanID(), span.Parent.SpanID()) } - msg.WriteString("==================================================\n") - } - if msg.Len() > 0 { - fmt.Fprint(os.Stderr, msg.String()) + endMsg(msg) } }) diff --git a/internal/telemetry/trace/trace.go b/internal/telemetry/trace/trace.go index 89aa10154..7287d7edb 100644 --- a/internal/telemetry/trace/trace.go +++ b/internal/telemetry/trace/trace.go @@ -26,8 +26,64 @@ type systemContextKeyType struct{} var systemContextKey systemContextKeyType +type DebugFlags uint32 + +const ( + // If set, adds the "caller" attribute to each trace with the source location + // where the trace was started. + TrackSpanCallers = (1 << iota) + + // If set, keeps track of all span references and will attempt to wait for + // all traces to complete when shutting down a trace context. + // Use with caution, this will cause increasing memory usage over time. + TrackSpanReferences = (1 << iota) + + // If set, keeps track of all observed spans, including span context and + // all attributes. + // Use with caution, this will cause significantly increasing memory usage + // over time. + TrackAllSpans = (1 << iota) + + // If set, will log all trace ID mappings on close. + LogTraceIDMappings = (1 << iota) + + // If set, will log all spans observed by the exporter on close. These spans + // may belong to incomplete traces. + // + // Enables [TrackAllSpans] + LogAllSpans = (1 << iota) | TrackAllSpans + + // If set, will log all exported spans when a warning is issued on close + // (requires warning flags to also be set) + // + // Enables [TrackAllSpans] + LogAllSpansOnWarn = (1 << iota) | TrackAllSpans + + // If set, will log all trace ID mappings when a warning is issued on close. + // (requires warning flags to also be set) + LogTraceIDMappingsOnWarn = (1 << iota) + + // If set, will print a warning to stderr on close if there are any incomplete + // traces (traces with no observed root spans) + WarnOnIncompleteTraces = (1 << iota) + + // If set, will print a warning to stderr on close if there are any incomplete + // spans (spans started, but not ended) + WarnOnIncompleteSpans = (1 << iota) + + // If set, will print a warning to stderr on close if there are any spans + // which reference unknown parent spans. + // + // Enables [TrackSpanReferences] + WarnOnUnresolvedReferences = (1 << iota) | TrackSpanReferences +) + +func (df DebugFlags) Check(flags DebugFlags) bool { + return (df & flags) == flags +} + type Options struct { - DebugLevel int + DebugFlags DebugFlags } type systemContext struct { @@ -119,7 +175,7 @@ func NewTracerProvider(ctx context.Context, serviceName string) trace.TracerProv for _, proc := range sys.exporterServer.SpanProcessors() { options = append(options, sdktrace.WithSpanProcessor(proc)) } - if sys.DebugLevel >= 1 { + if sys.DebugFlags.Check(TrackSpanCallers) { options = append(options, sdktrace.WithSpanProcessor(&stackTraceProcessor{}), ) diff --git a/internal/testenv/environment.go b/internal/testenv/environment.go index 461e529de..2cc6d50d7 100644 --- a/internal/testenv/environment.go +++ b/internal/testenv/environment.go @@ -33,6 +33,7 @@ import ( "github.com/pomerium/pomerium/config" "github.com/pomerium/pomerium/config/envoyconfig/filemgr" + databroker_service "github.com/pomerium/pomerium/databroker" "github.com/pomerium/pomerium/internal/log" "github.com/pomerium/pomerium/internal/telemetry/trace" "github.com/pomerium/pomerium/internal/testenv/envutil" @@ -41,6 +42,8 @@ import ( "github.com/pomerium/pomerium/pkg/envoy" "github.com/pomerium/pomerium/pkg/grpc/databroker" "github.com/pomerium/pomerium/pkg/health" + "github.com/pomerium/pomerium/pkg/identity/legacymanager" + "github.com/pomerium/pomerium/pkg/identity/manager" "github.com/pomerium/pomerium/pkg/netutil" "github.com/pomerium/pomerium/pkg/slices" "github.com/rs/zerolog" @@ -225,9 +228,10 @@ type environment struct { } type EnvironmentOptions struct { - debug bool - pauseOnFailure bool - forceSilent bool + debug bool + pauseOnFailure bool + forceSilent bool + traceDebugFlags trace.DebugFlags } type EnvironmentOption func(*EnvironmentOptions) @@ -265,13 +269,27 @@ func Silent(silent ...bool) EnvironmentOption { } } +func TraceDebugFlags(flags trace.DebugFlags) EnvironmentOption { + return func(o *EnvironmentOptions) { + o.traceDebugFlags = flags + } +} + +func AddTraceDebugFlags(flags trace.DebugFlags) EnvironmentOption { + return func(o *EnvironmentOptions) { + o.traceDebugFlags |= flags + } +} + var setGrpcLoggerOnce sync.Once +const defaultTraceDebugFlags = trace.TrackSpanCallers + var ( flagDebug = flag.Bool("env.debug", false, "enables test environment debug logging (equivalent to Debug() option)") flagPauseOnFailure = flag.Bool("env.pause-on-failure", false, "enables pausing the test environment on failure (equivalent to PauseOnFailure() option)") flagSilent = flag.Bool("env.silent", false, "suppresses all test environment output (equivalent to Silent() option)") - flagTraceDebugLevel = flag.Int("env.trace-debug-level", 0, "trace debug level") + flagTraceDebugFlags = flag.Uint("env.trace-debug-flags", defaultTraceDebugFlags, "trace debug flags (equivalent to TraceDebugFlags() option)") ) func New(t testing.TB, opts ...EnvironmentOption) Environment { @@ -279,9 +297,10 @@ func New(t testing.TB, opts ...EnvironmentOption) Environment { t.Skip("test environment only supported on linux") } options := EnvironmentOptions{ - debug: *flagDebug, - pauseOnFailure: *flagPauseOnFailure, - forceSilent: *flagSilent, + debug: *flagDebug, + pauseOnFailure: *flagPauseOnFailure, + forceSilent: *flagSilent, + traceDebugFlags: trace.DebugFlags(*flagTraceDebugFlags), } options.apply(opts...) if testing.Short() { @@ -323,7 +342,7 @@ func New(t testing.TB, opts ...EnvironmentOption) Environment { logger := zerolog.New(writer).With().Timestamp().Logger().Level(zerolog.DebugLevel) ctx := trace.Options{ - DebugLevel: *flagTraceDebugLevel, + DebugFlags: options.traceDebugFlags, }.NewContext(context.Background()) ctx = logger.WithContext(ctx) tracerProvider := trace.NewTracerProvider(ctx, "Test Environment") @@ -565,6 +584,10 @@ func (e *environment) Start() { opts := []pomerium.Option{ pomerium.WithOverrideFileManager(fileMgr), pomerium.WithEnvoyServerOptions(envoy.WithExitGracePeriod(10 * time.Second)), + pomerium.WithDataBrokerServerOptions( + databroker_service.WithManagerOptions(manager.WithLeaseTTL(1*time.Second)), + databroker_service.WithLegacyManagerOptions(legacymanager.WithLeaseTTL(1*time.Second)), + ), } envoyBinaryPath := filepath.Join(e.workspaceFolder, fmt.Sprintf("pkg/envoy/files/envoy-%s-%s", runtime.GOOS, runtime.GOARCH)) if envutil.EnvoyProfilerAvailable(envoyBinaryPath) { diff --git a/internal/testenv/selftests/tracing_test.go b/internal/testenv/selftests/tracing_test.go index 47912de94..ffc5005db 100644 --- a/internal/testenv/selftests/tracing_test.go +++ b/internal/testenv/selftests/tracing_test.go @@ -4,21 +4,47 @@ import ( "context" "io" "net/http" + "os" "testing" + "time" "github.com/pomerium/pomerium/config" + "github.com/pomerium/pomerium/internal/telemetry/trace" "github.com/pomerium/pomerium/internal/testenv" "github.com/pomerium/pomerium/internal/testenv/scenarios" "github.com/pomerium/pomerium/internal/testenv/snippets" "github.com/pomerium/pomerium/internal/testenv/upstreams" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "go.opentelemetry.io/otel/trace" + oteltrace "go.opentelemetry.io/otel/trace" + "google.golang.org/grpc" + "google.golang.org/grpc/connectivity" + "google.golang.org/grpc/credentials/insecure" ) func TestOTLPTracing(t *testing.T) { - t.Setenv("OTEL_EXPORTER_OTLP_TRACES_ENDPOINT", "http://localhost:4317") - env := testenv.New(t) + tracesEndpoint := os.Getenv("OTEL_EXPORTER_OTLP_TRACES_ENDPOINT") + if tracesEndpoint == "" { + tracesEndpoint = "http://localhost:4317" + os.Setenv("OTEL_EXPORTER_OTLP_TRACES_ENDPOINT", tracesEndpoint) + } + client, err := grpc.NewClient(tracesEndpoint, grpc.WithTransportCredentials(insecure.NewCredentials())) + require.NoError(t, err) + client.Connect() + ctx, ca := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer ca() + if !client.WaitForStateChange(ctx, connectivity.Ready) { + t.Skip("OTLP server offline: " + tracesEndpoint) + } + client.Close() + + env := testenv.New(t, testenv.AddTraceDebugFlags( + trace.WarnOnIncompleteSpans| + trace.WarnOnIncompleteTraces| + trace.WarnOnUnresolvedReferences| + trace.LogTraceIDMappingsOnWarn| + trace.LogAllSpansOnWarn, + )) defer env.Stop() env.Add(testenv.ModifierFunc(func(ctx context.Context, cfg *config.Config) { cfg.Options.ProxyLogLevel = config.LogLevelInfo @@ -43,13 +69,13 @@ func TestOTLPTracing(t *testing.T) { env.Start() snippets.WaitStartupComplete(env) - ctx, span := env.Tracer().Start(env.Context(), "Authenticate", trace.WithNewRoot()) + ctx, span := env.Tracer().Start(env.Context(), "Authenticate", oteltrace.WithNewRoot()) + defer span.End() resp, err := up.Get(route, upstreams.AuthenticateAs("foo@example.com"), upstreams.Path("/foo"), upstreams.Context(ctx)) - span.End() require.NoError(t, err) body, err := io.ReadAll(resp.Body) - require.NoError(t, err) - resp.Body.Close() + assert.NoError(t, err) + assert.NoError(t, resp.Body.Close()) assert.Equal(t, resp.StatusCode, 200) assert.Equal(t, "OK", string(body)) } diff --git a/internal/testenv/upstreams/http.go b/internal/testenv/upstreams/http.go index f2bbbb217..f04256bae 100644 --- a/internal/testenv/upstreams/http.go +++ b/internal/testenv/upstreams/http.go @@ -26,6 +26,8 @@ import ( "go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp" "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/codes" + semconv "go.opentelemetry.io/otel/semconv/v1.26.0" oteltrace "go.opentelemetry.io/otel/trace" "google.golang.org/protobuf/proto" ) @@ -156,8 +158,9 @@ type httpUpstream struct { clientCache sync.Map // map[testenv.Route]*http.Client - router *mux.Router - tracerProvider values.MutableValue[oteltrace.TracerProvider] + router *mux.Router + serverTracerProvider values.MutableValue[oteltrace.TracerProvider] + clientTracerProvider values.MutableValue[oteltrace.TracerProvider] } var ( @@ -172,11 +175,12 @@ func HTTP(tlsConfig values.Value[*tls.Config], opts ...HTTPUpstreamOption) HTTPU } options.apply(opts...) up := &httpUpstream{ - HTTPUpstreamOptions: options, - serverPort: values.Deferred[int](), - router: mux.NewRouter(), - tlsConfig: tlsConfig, - tracerProvider: values.Deferred[oteltrace.TracerProvider](), + HTTPUpstreamOptions: options, + serverPort: values.Deferred[int](), + router: mux.NewRouter(), + tlsConfig: tlsConfig, + serverTracerProvider: values.Deferred[oteltrace.TracerProvider](), + clientTracerProvider: values.Deferred[oteltrace.TracerProvider](), } up.RecordCaller() return up @@ -214,15 +218,16 @@ func (h *httpUpstream) Run(ctx context.Context) error { if h.tlsConfig != nil { tlsConfig = h.tlsConfig.Value() } - h.tracerProvider.Resolve(trace.NewTracerProvider(ctx, h.displayName)) - h.router.Use(trace.NewHTTPMiddleware(otelhttp.WithTracerProvider(h.tracerProvider.Value()))) + h.serverTracerProvider.Resolve(trace.NewTracerProvider(ctx, h.displayName)) + h.clientTracerProvider.Resolve(trace.NewTracerProvider(ctx, "HTTP Client")) + h.router.Use(trace.NewHTTPMiddleware(otelhttp.WithTracerProvider(h.serverTracerProvider.Value()))) server := &http.Server{ Handler: h.router, TLSConfig: tlsConfig, - BaseContext: func(net.Listener) context.Context { - return ctx - }, + // BaseContext: func(net.Listener) context.Context { + // return ctx + // }, } errC := make(chan error, 1) go func() { @@ -263,6 +268,12 @@ func (h *httpUpstream) Do(method string, r testenv.Route, opts ...RequestOption) RawQuery: options.query.Encode(), }) } + ctx, span := trace.Continue(options.requestCtx, "httpUpstream.Do", oteltrace.WithAttributes( + attribute.String("method", method), + attribute.String("url", u.String()), + )) + options.requestCtx = ctx + defer span.End() newClient := func() *http.Client { c := http.Client{ @@ -272,7 +283,7 @@ func (h *httpUpstream) Do(method string, r testenv.Route, opts ...RequestOption) Certificates: options.clientCerts, }, }, - otelhttp.WithTracerProvider(h.tracerProvider.Value()), + otelhttp.WithTracerProvider(h.clientTracerProvider.Value()), otelhttp.WithSpanNameFormatter(func(operation string, r *http.Request) string { return fmt.Sprintf("Client: %s %s", r.Method, r.URL.Path) }), @@ -288,16 +299,20 @@ func (h *httpUpstream) Do(method string, r testenv.Route, opts ...RequestOption) var cachedClient any var ok bool if cachedClient, ok = h.clientCache.Load(r); !ok { + span.AddEvent("creating new http client") cachedClient, _ = h.clientCache.LoadOrStore(r, newClient()) + } else { + span.AddEvent("using cached http client") } client = cachedClient.(*http.Client) } var resp *http.Response - if err := retry.Retry(options.requestCtx, "http", func(ctx context.Context) error { - req, err := http.NewRequestWithContext(options.requestCtx, method, u.String(), nil) + resendCount := 0 + if err := retry.Retry(ctx, "http", func(ctx context.Context) error { + req, err := http.NewRequestWithContext(ctx, method, u.String(), nil) if err != nil { - return err + return retry.NewTerminalError(err) } switch body := options.body.(type) { case string: @@ -309,7 +324,7 @@ func (h *httpUpstream) Do(method string, r testenv.Route, opts ...RequestOption) case proto.Message: buf, err := proto.Marshal(body) if err != nil { - return err + return retry.NewTerminalError(err) } req.Body = io.NopCloser(bytes.NewReader(buf)) req.Header.Set("Content-Type", "application/octet-stream") @@ -330,52 +345,70 @@ func (h *httpUpstream) Do(method string, r testenv.Route, opts ...RequestOption) } // retry on connection refused if err != nil { + span.RecordError(err) var opErr *net.OpError if errors.As(err, &opErr) && opErr.Op == "dial" && opErr.Err.Error() == "connect: connection refused" { - oteltrace.SpanFromContext(ctx).AddEvent("Retrying on dial error") + span.AddEvent("Retrying on dial error") return err } return retry.NewTerminalError(err) } if resp.StatusCode/100 == 5 { - if err := resp.Body.Close(); err != nil { - panic(err) - } - oteltrace.SpanFromContext(ctx).AddEvent("Retrying on 5xx error", oteltrace.WithAttributes( + resendCount++ + io.ReadAll(resp.Body) + resp.Body.Close() + span.SetAttributes(semconv.HTTPRequestResendCount(resendCount)) + span.AddEvent("Retrying on 5xx error", oteltrace.WithAttributes( attribute.String("status", resp.Status), )) return errors.New(http.StatusText(resp.StatusCode)) } + span.SetStatus(codes.Ok, "request completed successfully") return nil - }, retry.WithMaxInterval(100*time.Millisecond)); err != nil { + }, + retry.WithInitialInterval(1*time.Millisecond), + retry.WithMaxInterval(100*time.Millisecond), + ); err != nil { return nil, err } return resp, nil } func authenticateFlow(ctx context.Context, client *http.Client, req *http.Request, email string) (*http.Response, error) { + span := oteltrace.SpanFromContext(ctx) var res *http.Response originalHostname := req.URL.Hostname() res, err := client.Do(req) if err != nil { + span.RecordError(err) return nil, err } location := res.Request.URL if location.Hostname() == originalHostname { // already authenticated - return res, err + span.SetStatus(codes.Ok, "already authenticated") + return res, nil } - defer res.Body.Close() fs := forms.Parse(res.Body) + io.ReadAll(res.Body) + res.Body.Close() if len(fs) > 0 { f := fs[0] f.Inputs["email"] = email f.Inputs["token_expiration"] = strconv.Itoa(int((time.Hour * 24).Seconds())) + span.AddEvent("submitting form", oteltrace.WithAttributes(attribute.String("location", location.String()))) formReq, err := f.NewRequestWithContext(ctx, location) if err != nil { + span.RecordError(err) return nil, err } - return client.Do(formReq) + resp, err := client.Do(formReq) + if err != nil { + span.RecordError(err) + return nil, err + } + span.SetStatus(codes.Ok, "form submitted successfully") + return resp, nil } return nil, fmt.Errorf("test bug: expected IDP login form") }