package trace_test

import (
	"bytes"
	"context"
	"fmt"
	"runtime"
	"sync/atomic"
	"testing"
	"time"

	"github.com/pomerium/pomerium/internal/telemetry/trace"
	. "github.com/pomerium/pomerium/internal/testutil/tracetest" //nolint:revive
	"github.com/stretchr/testify/assert"
	sdktrace "go.opentelemetry.io/otel/sdk/trace"
	oteltrace "go.opentelemetry.io/otel/trace"
)

// spanObserverWaitTimeout bounds how long the tests below wait for goroutines
// blocked in XWait to return once their last outstanding reference has been
// observed. assert.Eventually returns as soon as the condition holds, so a
// generous bound does not slow passing runs; it only keeps the tests from
// flaking when goroutine scheduling is slow (e.g. on a loaded CI machine).
// A too-short bound is also dangerous: if the subtest ends while a waiter is
// still blocked, that goroutine's later assert would run against a completed
// test and panic.
const spanObserverWaitTimeout = time.Second

// TestSpanObserver checks which span IDs a SpanObserver reports as unobserved
// after ObserveReference/Observe calls, and that XWait blocks until every
// referenced ID has been observed.
func TestSpanObserver(t *testing.T) {
	t.Run("observe single reference", func(t *testing.T) {
		obs := trace.NewSpanObserver()
		assert.Equal(t, []oteltrace.SpanID{}, obs.XUnobservedIDs())
		// ObserveReference(a, b) leaves a listed as unobserved until Observe(a).
		obs.ObserveReference(Span(1).ID(), Span(2).ID())
		assert.Equal(t, []oteltrace.SpanID{Span(1).ID()}, obs.XUnobservedIDs())
		obs.Observe(Span(1).ID())
		assert.Equal(t, []oteltrace.SpanID{}, obs.XUnobservedIDs())
	})

	t.Run("observe multiple references", func(t *testing.T) {
		obs := trace.NewSpanObserver()
		obs.ObserveReference(Span(1).ID(), Span(2).ID())
		obs.ObserveReference(Span(1).ID(), Span(3).ID())
		obs.ObserveReference(Span(1).ID(), Span(4).ID())
		// Repeated references to the same ID are reported only once.
		assert.Equal(t, []oteltrace.SpanID{Span(1).ID()}, obs.XUnobservedIDs())
		obs.Observe(Span(1).ID())
		assert.Equal(t, []oteltrace.SpanID{}, obs.XUnobservedIDs())
	})

	t.Run("observe before reference", func(t *testing.T) {
		obs := trace.NewSpanObserver()
		obs.Observe(Span(1).ID())
		assert.Equal(t, []oteltrace.SpanID{}, obs.XUnobservedIDs())
		// A reference to an already-observed ID must not make it unobserved.
		obs.ObserveReference(Span(1).ID(), Span(2).ID())
		assert.Equal(t, []oteltrace.SpanID{}, obs.XUnobservedIDs())
	})

	t.Run("wait", func(t *testing.T) {
		obs := trace.NewSpanObserver()
		obs.ObserveReference(Span(1).ID(), Span(2).ID())
		obs.Observe(Span(2).ID())
		obs.ObserveReference(Span(3).ID(), Span(4).ID())
		obs.Observe(Span(4).ID())
		obs.ObserveReference(Span(5).ID(), Span(6).ID())
		obs.Observe(Span(6).ID())

		var waitOkToExit, waitExited atomic.Bool
		go func() {
			defer waitExited.Store(true)
			obs.XWait()
			assert.True(t, waitOkToExit.Load(), "wait exited early")
		}()

		time.Sleep(10 * time.Millisecond)
		assert.False(t, waitExited.Load())

		obs.Observe(Span(1).ID())
		time.Sleep(10 * time.Millisecond)
		assert.False(t, waitExited.Load())

		obs.Observe(Span(3).ID())
		time.Sleep(10 * time.Millisecond)
		assert.False(t, waitExited.Load())

		// Span 5 is the last unobserved reference, so XWait may now return.
		waitOkToExit.Store(true)
		obs.Observe(Span(5).ID())
		assert.Eventually(t, waitExited.Load, spanObserverWaitTimeout, 1*time.Millisecond)
	})

	t.Run("new references observed during wait", func(t *testing.T) {
		obs := trace.NewSpanObserver()
		obs.ObserveReference(Span(1).ID(), Span(2).ID())
		obs.Observe(Span(2).ID())
		obs.ObserveReference(Span(3).ID(), Span(4).ID())
		obs.Observe(Span(4).ID())
		obs.ObserveReference(Span(5).ID(), Span(6).ID())
		obs.Observe(Span(6).ID())

		var waitOkToExit, waitExited atomic.Bool
		go func() {
			defer waitExited.Store(true)
			obs.XWait()
			assert.True(t, waitOkToExit.Load(), "wait exited early")
		}()

		assert.Equal(t, []oteltrace.SpanID{Span(1).ID(), Span(3).ID(), Span(5).ID()}, obs.XUnobservedIDs())
		time.Sleep(10 * time.Millisecond)
		assert.False(t, waitExited.Load())

		obs.Observe(Span(1).ID())
		assert.Equal(t, []oteltrace.SpanID{Span(3).ID(), Span(5).ID()}, obs.XUnobservedIDs())
		time.Sleep(10 * time.Millisecond)
		assert.False(t, waitExited.Load())

		obs.Observe(Span(3).ID())
		assert.Equal(t, []oteltrace.SpanID{Span(5).ID()}, obs.XUnobservedIDs())
		time.Sleep(10 * time.Millisecond)
		assert.False(t, waitExited.Load())

		// observe a new reference
		obs.ObserveReference(Span(7).ID(), Span(8).ID())
		obs.Observe(Span(8).ID())
		assert.Equal(t, []oteltrace.SpanID{Span(5).ID(), Span(7).ID()}, obs.XUnobservedIDs())
		time.Sleep(10 * time.Millisecond)
		assert.False(t, waitExited.Load())

		obs.Observe(Span(5).ID())
		assert.Equal(t, []oteltrace.SpanID{Span(7).ID()}, obs.XUnobservedIDs())
		time.Sleep(10 * time.Millisecond)
		assert.False(t, waitExited.Load())

		// Span 7 (added while XWait was already blocked) is now the last one.
		waitOkToExit.Store(true)
		obs.Observe(Span(7).ID())
		assert.Equal(t, []oteltrace.SpanID{}, obs.XUnobservedIDs())
		assert.Eventually(t, waitExited.Load, spanObserverWaitTimeout, 1*time.Millisecond)
	})

	t.Run("multiple waiters", func(t *testing.T) {
		t.Parallel()
		const numWaiters = 10

		obs := trace.NewSpanObserver()
		obs.ObserveReference(Span(1).ID(), Span(2).ID())
		obs.Observe(Span(2).ID())

		var waitersExited atomic.Int32
		for range numWaiters {
			go func() {
				defer waitersExited.Add(1)
				obs.XWait()
			}()
		}

		assert.Equal(t, []oteltrace.SpanID{Span(1).ID()}, obs.XUnobservedIDs())
		time.Sleep(10 * time.Millisecond)
		assert.Equal(t, int32(0), waitersExited.Load())

		// A single Observe must release every blocked waiter.
		obs.Observe(Span(1).ID())

		assert.Eventually(t, func() bool {
			return waitersExited.Load() == numWaiters
		}, spanObserverWaitTimeout, 1*time.Millisecond)
	})
}

// TestSpanTracker checks that a SpanTracker lists started-but-not-ended spans
// via XInflightSpans, and forwards span IDs to its SpanObserver only when the
// TrackSpanReferences debug flag is set.
func TestSpanTracker(t *testing.T) {
	t.Run("no debug flags", func(t *testing.T) {
		t.Parallel()
		obs := trace.NewSpanObserver()
		tracker := trace.NewSpanTracker(obs, 0)
		tp := sdktrace.NewTracerProvider(sdktrace.WithSpanProcessor(tracker))
		tracer := tp.Tracer("test")
		assert.Equal(t, []oteltrace.SpanID{}, tracker.XInflightSpans())
		_, span1 := tracer.Start(context.Background(), "span 1")
		assert.Equal(t, []oteltrace.SpanID{span1.SpanContext().SpanID()}, tracker.XInflightSpans())
		// Without TrackSpanReferences the observer sees nothing.
		assert.Equal(t, []oteltrace.SpanID{}, obs.XObservedIDs())
		span1.End()
		assert.Equal(t, []oteltrace.SpanID{}, tracker.XInflightSpans())
		assert.Equal(t, []oteltrace.SpanID{}, obs.XObservedIDs())
	})

	t.Run("with TrackSpanReferences debug flag", func(t *testing.T) {
		t.Parallel()
		obs := trace.NewSpanObserver()
		tracker := trace.NewSpanTracker(obs, trace.TrackSpanReferences)
		tp := sdktrace.NewTracerProvider(sdktrace.WithSpanProcessor(tracker))
		tracer := tp.Tracer("test")
		assert.Equal(t, []oteltrace.SpanID{}, tracker.XInflightSpans())
		_, span1 := tracer.Start(context.Background(), "span 1")
		assert.Equal(t, []oteltrace.SpanID{span1.SpanContext().SpanID()}, tracker.XInflightSpans())
		// With TrackSpanReferences the span is observed as soon as it starts.
		assert.Equal(t, []oteltrace.SpanID{span1.SpanContext().SpanID()}, obs.XObservedIDs())
		span1.End()
		assert.Equal(t, []oteltrace.SpanID{}, tracker.XInflightSpans())
		assert.Equal(t, []oteltrace.SpanID{span1.SpanContext().SpanID()}, obs.XObservedIDs())
	})
}

// TestSpanTrackerWarnings checks the debug messages written when the tracer
// provider shuts down with spans that were never ended, for the various
// combinations of WarnOnIncompleteSpans, TrackAllSpans and LogAllSpansOnWarn.
func TestSpanTrackerWarnings(t *testing.T) {
	t.Run("WarnOnIncompleteSpans", func(t *testing.T) {
		var buf bytes.Buffer
		trace.SetDebugMessageWriterForTest(t, &buf)

		obs := trace.NewSpanObserver()
		tracker := trace.NewSpanTracker(obs, trace.WarnOnIncompleteSpans)
		tp := sdktrace.NewTracerProvider(sdktrace.WithSpanProcessor(tracker))
		tracer := tp.Tracer("test")
		_, span1 := tracer.Start(context.Background(), "span 1")

		assert.ErrorIs(t, tp.Shutdown(context.Background()), trace.ErrIncompleteSpans)

		assert.Equal(t, fmt.Sprintf(`
==================================================
WARNING: spans not ended:
%s
Note: set TrackAllSpans flag for more info
==================================================
`, span1.SpanContext().SpanID()), buf.String())
	})

	t.Run("WarnOnIncompleteSpans with TrackAllSpans", func(t *testing.T) {
		var buf bytes.Buffer
		trace.SetDebugMessageWriterForTest(t, &buf)

		obs := trace.NewSpanObserver()
		tracker := trace.NewSpanTracker(obs, trace.WarnOnIncompleteSpans|trace.TrackAllSpans)
		tp := sdktrace.NewTracerProvider(sdktrace.WithSpanProcessor(tracker))
		tracer := tp.Tracer("test")
		_, span1 := tracer.Start(context.Background(), "span 1")

		assert.ErrorIs(t, tp.Shutdown(context.Background()), trace.ErrIncompleteSpans)

		assert.Equal(t, fmt.Sprintf(`
==================================================
WARNING: spans not ended:
'span 1' (trace: %s | span: %s | parent: 0000000000000000)
==================================================
`, span1.SpanContext().TraceID(), span1.SpanContext().SpanID()), buf.String())
	})

	t.Run("WarnOnIncompleteSpans with TrackAllSpans and stackTraceProcessor", func(t *testing.T) {
		var buf bytes.Buffer
		trace.SetDebugMessageWriterForTest(t, &buf)

		obs := trace.NewSpanObserver()
		tracker := trace.NewSpanTracker(obs, trace.WarnOnIncompleteSpans|trace.TrackAllSpans)
		tp := sdktrace.NewTracerProvider(sdktrace.WithSpanProcessor(&trace.XStackTraceProcessor{}), sdktrace.WithSpanProcessor(tracker))
		tracer := tp.Tracer("test")
		// The two statements below must stay on adjacent lines: the expected
		// "started at" position is derived from runtime.Caller's line number.
		_, span1 := tracer.Start(context.Background(), "span 1")
		_, file, line, _ := runtime.Caller(0)
		line-- // line of the tracer.Start call directly above

		assert.ErrorIs(t, tp.Shutdown(context.Background()), trace.ErrIncompleteSpans)

		assert.Equal(t, fmt.Sprintf(`
==================================================
WARNING: spans not ended:
'span 1' (trace: %s | span: %s | parent: 0000000000000000 | started at: %s:%d)
==================================================
`, span1.SpanContext().TraceID(), span1.SpanContext().SpanID(), file, line), buf.String())
	})

	t.Run("LogAllSpansOnWarn", func(t *testing.T) {
		var buf bytes.Buffer
		trace.SetDebugMessageWriterForTest(t, &buf)

		obs := trace.NewSpanObserver()
		tracker := trace.NewSpanTracker(obs, trace.WarnOnIncompleteSpans|trace.TrackAllSpans|trace.LogAllSpansOnWarn)
		tp := sdktrace.NewTracerProvider(sdktrace.WithSpanProcessor(&trace.XStackTraceProcessor{}), sdktrace.WithSpanProcessor(tracker))
		tracer := tp.Tracer("test")
		// The six statements below must stay on consecutive lines: span1 is
		// expected to have been started exactly 4 lines before span2, and
		// span2 on the line directly above runtime.Caller.
		_, span1 := tracer.Start(context.Background(), "span 1")
		time.Sleep(10 * time.Millisecond)
		span1.End()
		time.Sleep(10 * time.Millisecond)
		_, span2 := tracer.Start(context.Background(), "span 2")
		_, file, line, _ := runtime.Caller(0)
		line-- // line of the span2 tracer.Start call directly above

		// span2 was never ended, so shutdown must report incomplete spans
		// (previously this error was silently discarded).
		assert.ErrorIs(t, tp.Shutdown(context.Background()), trace.ErrIncompleteSpans)

		assert.Equal(t, fmt.Sprintf(`
==================================================
WARNING: spans not ended:
'span 2' (trace: %[1]s | span: %[2]s | parent: 0000000000000000 | started at: %[3]s:%[4]d)
==================================================

==================================================
All observed spans:
'span 1' (trace: %[5]s | span: %[6]s | parent: 0000000000000000 | started at: %[3]s:%[7]d)
'span 2' (trace: %[1]s | span: %[2]s | parent: 0000000000000000 | started at: %[3]s:%[4]d)
==================================================
`,
			span2.SpanContext().TraceID(),
			span2.SpanContext().SpanID(),
			file,
			line,
			span1.SpanContext().TraceID(),
			span1.SpanContext().SpanID(),
			line-4,
		), buf.String())
	})
}