pomerium/internal/testutil/tracetest/tracing.go
Caleb Doxsey c47055bece
upgrade to go v1.24 (#5562)
* upgrade to go v1.24

* add a macOS-specific //nolint comment too

---------

Co-authored-by: Kenneth Jenkins <51246568+kenjenkins@users.noreply.github.com>
2025-04-02 15:53:09 -06:00

630 lines
18 KiB
Go

package tracetest
import (
"cmp"
"context"
"encoding/binary"
"encoding/json"
"fmt"
"maps"
"runtime"
"slices"
"strings"
"sync"
"testing"
"time"
"unique"
gocmp "github.com/google/go-cmp/cmp"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
oteltrace "go.opentelemetry.io/otel/trace"
coltracepb "go.opentelemetry.io/proto/otlp/collector/trace/v1"
commonv1 "go.opentelemetry.io/proto/otlp/common/v1"
resourcev1 "go.opentelemetry.io/proto/otlp/resource/v1"
tracev1 "go.opentelemetry.io/proto/otlp/trace/v1"
"google.golang.org/protobuf/encoding/protojson"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/testing/protocmp"
"github.com/pomerium/pomerium/pkg/telemetry/trace"
)
// Small numeric identifiers used to build test fixtures. Each one is a
// uint32 that is rendered as a label by its String method.
type (
// Trace is a numeric trace identifier; Trace.B encodes it into the trailing
// four bytes of an otherwise zero trace ID.
Trace uint32
// Span is a numeric span identifier; Span.B encodes it into the trailing
// four bytes of an otherwise zero span ID.
Span uint32
// Scope is a numeric instrumentation scope identifier; see Scope.Make.
Scope uint32
// Schema is a numeric schema identifier; its String form is what Scope.Make
// and Resource.Make pass along as the schema argument.
Schema uint32
// Resource is a numeric resource identifier; see Resource.Make.
Resource uint32
)
// String returns the label "Trace <n>".
func (n Trace) String() string {
	return fmt.Sprintf("Trace %d", uint32(n))
}

// String returns the label "Span <n>".
func (n Span) String() string {
	return fmt.Sprintf("Span %d", uint32(n))
}

// String returns the label "Scope <n>".
func (n Scope) String() string {
	return fmt.Sprintf("Scope %d", uint32(n))
}

// String returns the label "Schema <n>".
func (n Schema) String() string {
	return fmt.Sprintf("Schema %d", uint32(n))
}

// String returns the label "Resource <n>".
func (n Resource) String() string {
	return fmt.Sprintf("Resource %d", uint32(n))
}
// ID returns the trace ID handle obtained by passing n.B() to
// trace.ToTraceID.
func (n Trace) ID() unique.Handle[oteltrace.TraceID] {
// The error is discarded: B always returns a full-length trace ID.
// NOTE(review): presumably ToTraceID only fails on malformed input — confirm.
id, _ := trace.ToTraceID(n.B())
return id
}
// B returns the raw bytes of the trace ID for n: all zero except the last
// four bytes, which hold n in big-endian order.
func (n Trace) B() []byte {
	raw := oteltrace.TraceID{}
	// Write n into the trailing 4 bytes of the ID.
	tail := raw[len(raw)-4:]
	binary.BigEndian.PutUint32(tail, uint32(n))
	return raw[:]
}
// ID returns the span ID obtained by passing n.B() to trace.ToSpanID.
func (n Span) ID() oteltrace.SpanID {
// The error is discarded: B always returns a full-length span ID.
// NOTE(review): presumably ToSpanID only fails on malformed input — confirm.
id, _ := trace.ToSpanID(n.B())
return id
}
// B returns the raw bytes of the span ID for n: all zero except the last
// four bytes, which hold n in big-endian order.
func (n Span) B() []byte {
	raw := oteltrace.SpanID{}
	// Write n into the trailing 4 bytes of the ID.
	tail := raw[len(raw)-4:]
	binary.BigEndian.PutUint32(tail, uint32(n))
	return raw[:]
}
// Make builds a ScopeInfo for scope n. The instrumentation scope is named
// after n.String(), has version "v1", and carries a single integer "id"
// attribute equal to n. Only the first schema in s is used; when s is empty
// Schema(0) is used instead.
func (n Scope) Make(s ...Schema) *ScopeInfo {
	var schema Schema // zero value is the default schema
	if len(s) > 0 {
		schema = s[0]
	}

	idAttr := &commonv1.KeyValue{
		Key: "id",
		Value: &commonv1.AnyValue{
			Value: &commonv1.AnyValue_IntValue{IntValue: int64(n)},
		},
	}
	scope := &commonv1.InstrumentationScope{
		Name:       n.String(),
		Version:    "v1",
		Attributes: []*commonv1.KeyValue{idAttr},
	}
	return NewScopeInfo(scope, schema.String())
}
// Make builds a ResourceInfo for resource n. The resource carries two
// attributes, in this order: a string "name" set to n.String() and an integer
// "id" equal to n. Only the first schema in s is used; when s is empty
// Schema(0) is used instead.
func (n Resource) Make(s ...Schema) *ResourceInfo {
	var schema Schema // zero value is the default schema
	if len(s) > 0 {
		schema = s[0]
	}

	nameAttr := &commonv1.KeyValue{
		Key: "name",
		Value: &commonv1.AnyValue{
			Value: &commonv1.AnyValue_StringValue{StringValue: n.String()},
		},
	}
	idAttr := &commonv1.KeyValue{
		Key: "id",
		Value: &commonv1.AnyValue{
			Value: &commonv1.AnyValue_IntValue{IntValue: int64(n)},
		},
	}
	resource := &resourcev1.Resource{
		Attributes: []*commonv1.KeyValue{nameAttr, idAttr},
	}
	return NewResourceInfo(resource, schema.String())
}
// Traceparent formats a traceparent-style string of the form
// "00-<trace id>-<span id>-<flags>", where flags is "01" when sampled is true
// and "00" otherwise.
func Traceparent(trace Trace, span Span, sampled bool) string {
	var flags string
	if sampled {
		flags = "01"
	} else {
		flags = "00"
	}
	traceID := trace.ID().Value()
	spanID := span.ID()
	return fmt.Sprintf("00-%s-%s-%s", traceID, spanID, flags)
}
// TraceResults wraps a set of resource spans and exposes views derived from
// them. Values are created with NewTraceResults, which populates the
// function fields.
type TraceResults struct {
// resourceSpans is the raw input the views are computed from.
resourceSpans []*tracev1.ResourceSpans
// GetResources returns the resource of every entry in resourceSpans.
// NewTraceResults wraps the computation in sync.OnceValue.
GetResources func() []*resourcev1.Resource
// GetTraces returns the spans grouped into traces. NewTraceResults wraps
// the computation in sync.OnceValue.
GetTraces func() *Traces
}
// Traces indexes a set of TraceDetails in three ways.
type Traces struct {
// ByID maps a trace ID to the details of that trace.
ByID map[unique.Handle[oteltrace.TraceID]]*TraceDetails
// ByName groups traces by TraceDetails.Name.
ByName map[string]TraceDetailsList
// ByParticipant groups traces by the name of each service that
// contributed at least one span to them.
ByParticipant map[string]TraceDetailsList
}
// WithoutErrors returns a new Traces containing only the traces that have no
// entries in their Errors list. Name and participant keys whose lists become
// empty after filtering are omitted. The receiver is not modified.
func (t *Traces) WithoutErrors() *Traces {
	// filterLists applies TraceDetailsList.WithoutErrors to every entry and
	// drops the keys that end up with nothing left.
	filterLists := func(src map[string]TraceDetailsList) map[string]TraceDetailsList {
		dst := make(map[string]TraceDetailsList)
		for key, list := range src {
			if kept := list.WithoutErrors(); len(kept) > 0 {
				dst[key] = kept
			}
		}
		return dst
	}

	clean := &Traces{
		ByID:          make(map[unique.Handle[oteltrace.TraceID]]*TraceDetails, len(t.ByID)),
		ByName:        filterLists(t.ByName),
		ByParticipant: filterLists(t.ByParticipant),
	}
	for id, details := range t.ByID {
		if len(details.Errors) == 0 {
			clean.ByID[id] = details
		}
	}
	return clean
}
// TraceDetails aggregates all spans that share one trace ID, together with
// summary information derived from those spans.
type TraceDetails struct {
// ID is the trace ID shared by every span in Spans.
ID unique.Handle[oteltrace.TraceID]
// Name is "<service>: <span name>" of the first entry in Spans.
Name string
// Spans holds the spans of this trace, sorted by start time.
Spans []*SpanDetails
// Services is the sorted set of non-empty service names seen in Spans.
Services []string
// StartTime is the earliest span start time in the trace.
StartTime time.Time
// EndTime is the latest span end time in the trace.
EndTime time.Time
// Duration is EndTime minus StartTime.
Duration time.Duration
Errors []int // indexes into Spans
}
// Equal compares td (the actual trace) against other (the expected trace).
// It returns true and an empty string when they match; otherwise it returns
// false and a description of the first difference found.
//
// The cheap checks run first, in this order: trace ID, number of spans, start
// time, end time. Only if all of them pass are the spans compared pairwise by
// position with go-cmp (using protocmp.Transform), stopping at the first
// non-empty diff.
func (td *TraceDetails) Equal(other *TraceDetails) (bool, string) {
	if td.ID != other.ID {
		return false, fmt.Sprintf("traces are trivially not equal: ID %s (actual) != %s (expected)", td.ID.Value(), other.ID.Value())
	}
	if actualLen, expectedLen := len(td.Spans), len(other.Spans); actualLen != expectedLen {
		return false, fmt.Sprintf("traces are trivially not equal: len(spans) %d (actual) != %d (expected)", actualLen, expectedLen)
	}
	if !td.StartTime.Equal(other.StartTime) {
		return false, fmt.Sprintf("traces are trivially not equal: start time %s (actual) != %s (expected)", td.StartTime, other.StartTime)
	}
	if !td.EndTime.Equal(other.EndTime) {
		return false, fmt.Sprintf("traces are trivially not equal: end time %s (actual) != %s (expected)", td.EndTime, other.EndTime)
	}

	// Both span lists have the same length at this point, so indexing
	// other.Spans with the same position is safe.
	for idx, actualSpan := range td.Spans {
		if diff := gocmp.Diff(actualSpan, other.Spans[idx], protocmp.Transform()); diff != "" {
			return false, diff
		}
	}
	return true, ""
}
// TraceDetailsList is a list of traces that offers filtering helpers.
type TraceDetailsList []*TraceDetails
// WithoutExportRPCs returns a new list that omits every trace whose Name
// contains the OTLP trace export RPC method name. The result is never nil and
// the receiver is not modified.
func (list TraceDetailsList) WithoutExportRPCs() TraceDetailsList {
	const exportMethod = "opentelemetry.proto.collector.trace.v1.TraceService/Export"

	kept := make(TraceDetailsList, 0, len(list))
	for idx := range list {
		if details := list[idx]; !strings.Contains(details.Name, exportMethod) {
			kept = append(kept, details)
		}
	}
	return kept
}
// WithoutErrors returns a new list holding only the traces whose Errors list
// is empty. The result is never nil and the receiver is not modified.
func (list TraceDetailsList) WithoutErrors() TraceDetailsList {
	kept := make(TraceDetailsList, 0, len(list))
	for idx := range list {
		if details := list[idx]; len(details.Errors) == 0 {
			kept = append(kept, details)
		}
	}
	return kept
}
// SpanTree arranges the spans of the trace into a parent/child tree.
//
// A synthetic root node (with no Span) is registered under the all-zero span
// ID. When a span references a parent ID that does not belong to any span in
// the trace, an empty placeholder node is created for that ID and recorded in
// DetachedParents. The children of every node are sorted by start time.
func (td *TraceDetails) SpanTree() *SpanTree {
	var rootID oteltrace.SpanID // all-zero ID identifies the synthetic root

	// First pass: register a node for the root and for every span, so that
	// the second pass can tell known parents from missing ones.
	nodes := make(map[oteltrace.SpanID]*SpanTreeNode, len(td.Spans)+1)
	nodes[rootID] = new(SpanTreeNode)
	for _, details := range td.Spans {
		id, _ := trace.ToSpanID(details.Raw.SpanId)
		nodes[id] = &SpanTreeNode{Span: details}
	}

	// Second pass: link each span's node to its parent, creating placeholder
	// parents where needed.
	detached := make(map[oteltrace.SpanID]*SpanTreeNode)
	for _, details := range td.Spans {
		id, _ := trace.ToSpanID(details.Raw.SpanId)
		parentID, _ := trace.ToSpanID(details.Raw.ParentSpanId)

		parent, known := nodes[parentID]
		if !known {
			parent = new(SpanTreeNode)
			detached[parentID] = parent
			nodes[parentID] = parent
		}
		child := nodes[id]
		child.Parent = parent
		parent.Children = append(parent.Children, child)
	}

	byStartTime := func(a, b *SpanTreeNode) int {
		return cmp.Compare(a.Span.Raw.StartTimeUnixNano, b.Span.Raw.StartTimeUnixNano)
	}
	for _, node := range nodes {
		slices.SortFunc(node.Children, byStartTime)
	}

	return &SpanTree{
		Root:            nodes[rootID],
		DetachedParents: detached,
	}
}
// SpanDetails pairs a raw span with the resource and scope it was reported
// under, plus a few values derived from the raw span.
type SpanDetails struct {
// Raw is the span as received.
Raw *tracev1.Span
// Resource is the resource of the ResourceSpans entry containing Raw.
Resource *resourcev1.Resource
// Scope is the instrumentation scope of the ScopeSpans entry containing Raw.
Scope *commonv1.InstrumentationScope
// StartTime is Raw.StartTimeUnixNano converted to a time.Time.
StartTime time.Time
// EndTime is Raw.EndTimeUnixNano converted to a time.Time.
EndTime time.Time
// Duration is the end timestamp minus the start timestamp.
Duration time.Duration
// Service is the value of the resource's "service.name" attribute, or ""
// if the resource has none.
Service string
}
// NewTraceResults wraps resourceSpans in a TraceResults. The derived views
// (GetResources, GetTraces) are each computed on first use and cached via
// sync.OnceValue.
func NewTraceResults(resourceSpans []*tracev1.ResourceSpans) *TraceResults {
	results := new(TraceResults)
	results.resourceSpans = resourceSpans
	results.GetTraces = sync.OnceValue(results.computeTraces)
	results.GetResources = sync.OnceValue(results.computeResources)
	return results
}
// computeResources returns the Resource of every ResourceSpans entry, in
// input order. The result is never nil.
func (tr *TraceResults) computeResources() []*resourcev1.Resource {
	out := make([]*resourcev1.Resource, len(tr.resourceSpans))
	for idx, entry := range tr.resourceSpans {
		out[idx] = entry.Resource
	}
	return out
}
// computeTraces groups every span in tr.resourceSpans by trace ID and derives
// per-trace summary data.
//
// For each trace the spans are sorted by start time, after which the
// following are computed: the overall start/end time and duration, the sorted
// set of participating service names, the indexes of spans whose status code
// is ERROR, and the trace name ("<service>: <span name>" of the earliest
// span). The traces are returned indexed by ID, by name, and by participating
// service.
func (tr *TraceResults) computeTraces() *Traces {
	tracesByID := map[unique.Handle[oteltrace.TraceID]]*TraceDetails{}
	for _, resSpan := range tr.resourceSpans {
		resource := resSpan.Resource

		// The service name depends only on the resource, so look it up once
		// per resource. The generated getters are nil-safe, which matters
		// because a ResourceSpans entry may carry no resource at all.
		svc := ""
		for _, attr := range resource.GetAttributes() {
			if attr.GetKey() == "service.name" {
				svc = attr.GetValue().GetStringValue()
				break
			}
		}

		for _, scopeSpans := range resSpan.ScopeSpans {
			scope := scopeSpans.Scope
			for _, span := range scopeSpans.Spans {
				traceID, _ := trace.ToTraceID(span.TraceId)
				details, ok := tracesByID[traceID]
				if !ok {
					details = &TraceDetails{ID: traceID}
					tracesByID[traceID] = details
				}
				details.Spans = append(details.Spans, &SpanDetails{
					Raw:       span,
					Resource:  resource,
					Scope:     scope,
					StartTime: time.Unix(0, int64(span.StartTimeUnixNano)),
					EndTime:   time.Unix(0, int64(span.EndTimeUnixNano)),
					Duration:  time.Duration(span.EndTimeUnixNano - span.StartTimeUnixNano),
					Service:   svc,
				})
			}
		}
	}

	tracesByName := map[string]TraceDetailsList{}
	tracesByParticipant := map[string]TraceDetailsList{}
	// sort spans by start time and compute durations
	for _, td := range tracesByID {
		slices.SortFunc(td.Spans, func(a, b *SpanDetails) int {
			return cmp.Compare(a.Raw.StartTimeUnixNano, b.Raw.StartTimeUnixNano)
		})

		// Every TraceDetails is created together with its first span, so
		// td.Spans is never empty here.
		startTime := td.Spans[0].Raw.StartTimeUnixNano
		endTime := td.Spans[0].Raw.EndTimeUnixNano
		serviceNames := map[string]struct{}{}
		for i, span := range td.Spans {
			startTime = min(startTime, span.Raw.StartTimeUnixNano)
			endTime = max(endTime, span.Raw.EndTimeUnixNano)
			if span.Service != "" {
				serviceNames[span.Service] = struct{}{}
			}
			// Error indexes must be recorded only after sorting; indexes
			// captured while appending would point at the wrong spans once
			// the slice has been reordered.
			if span.Raw.GetStatus().GetCode() == tracev1.Status_STATUS_CODE_ERROR {
				td.Errors = append(td.Errors, i)
			}
		}
		td.StartTime = time.Unix(0, int64(startTime))
		td.EndTime = time.Unix(0, int64(endTime))
		td.Duration = td.EndTime.Sub(td.StartTime)
		td.Services = slices.Sorted(maps.Keys(serviceNames))
		td.Name = fmt.Sprintf("%s: %s", td.Spans[0].Service, td.Spans[0].Raw.Name)

		tracesByName[td.Name] = append(tracesByName[td.Name], td)
		for svc := range serviceNames {
			tracesByParticipant[svc] = append(tracesByParticipant[svc], td)
		}
	}

	return &Traces{
		ByID:          tracesByID,
		ByName:        tracesByName,
		ByParticipant: tracesByParticipant,
	}
}
// SpanTree is the tree form of a trace's spans, as built by
// TraceDetails.SpanTree.
type SpanTree struct {
// Root is the node registered under the all-zero span ID.
Root *SpanTreeNode
// DetachedParents holds a placeholder node for every parent span ID that
// is referenced by a span but does not belong to any span in the trace.
DetachedParents map[oteltrace.SpanID]*SpanTreeNode
}
// SpanTreeNode is a single node of a SpanTree.
type SpanTreeNode struct {
// Span is the span at this node; it is left nil for the synthetic root
// and for detached-parent placeholders.
Span *SpanDetails
// Parent is the node of the parent span.
Parent *SpanTreeNode
// Children are the child nodes, sorted by span start time.
Children []*SpanTreeNode
}
// Match describes the expectations for all traces that share one name; it is
// consumed by TraceResults.MatchTraces.
type Match struct {
// Name is looked up in Traces.ByName. Matches with an empty Name are
// skipped by MatchTraces.
Name string
// TraceCount constrains the number of traces with this name. MatchTraces
// recognizes an int (exact count), GreaterOrEqual, Greater,
// GreaterThanMatch, EqualToMatch and Any; other values make no assertion
// about the count.
TraceCount any
// Services, when non-nil, must contain the same elements as the Services
// of every trace with this name.
Services []string
}
// Values accepted by Match.TraceCount in addition to a plain int.
type (
// GreaterOrEqual asserts that the trace count is at least the given value.
GreaterOrEqual int
// Greater asserts that the trace count is strictly greater than the given
// value.
Greater int
// Any makes no assertions on the trace count. If the trace is not found, it
// doesn't count against the Exact match option.
Any struct{}
// EqualToMatch asserts that the value is the same as the value of another
// match (by name)
EqualToMatch string
// GreaterThanMatch asserts that the value is greater than the value of
// another match (by name)
GreaterThanMatch string
)
// MatchOptions toggles the additional checks performed by
// TraceResults.MatchTraces.
type MatchOptions struct {
// If true, asserts that there is exactly one [Match] entry per result
Exact bool
// If true, asserts that no traces contain detached spans
CheckDetachedSpans bool
}
// MatchTraces checks the recorded traces against the given matches and
// reports every violation on t.
//
// For each match with a non-empty Name:
//   - the name must be unique among the matches (fatal otherwise);
//   - if traces with that name exist, their number is checked against
//     TraceCount and, when Services is non-nil, each trace's Services must
//     contain the same elements;
//   - if no such traces exist, an error is reported unless TraceCount is Any.
//
// With opts.CheckDetachedSpans, every trace must have an empty set of
// detached parents; offending parent IDs and their children are logged.
//
// With opts.Exact, the set of trace names present must equal the set of match
// names, ignoring Any matches that were not found.
func (tr *TraceResults) MatchTraces(t testing.TB, opts MatchOptions, matches ...Match) {
	t.Helper()
	traces := tr.GetTraces()
	matchArgsByName := map[string]Match{}
	for i, m := range matches {
		if m.Name == "" {
			continue
		}
		require.NotContains(t, matchArgsByName, m.Name, "duplicate name")
		matchArgsByName[m.Name] = m

		traceDetails, ok := traces.ByName[m.Name]
		if !ok {
			if _, isAny := m.TraceCount.(Any); !isAny {
				t.Errorf("no traces with name %q found", m.Name)
			}
			continue
		}

		switch tc := m.TraceCount.(type) {
		case GreaterOrEqual:
			assert.GreaterOrEqualf(t, len(traceDetails), int(tc),
				"[match %d]: expected %q to have >=%d traces, but found %d",
				i+1, m.Name, int(tc), len(traceDetails))
		case Greater:
			assert.Greaterf(t, len(traceDetails), int(tc),
				"[match %d]: expected %q to have >%d traces, but found %d",
				i+1, m.Name, int(tc), len(traceDetails))
		case GreaterThanMatch:
			assert.Greaterf(t, len(traceDetails), len(traces.ByName[string(tc)]),
				"[match %d]: expected %q to have >%d traces (value of %s), but found %d",
				i+1, m.Name, len(traces.ByName[string(tc)]), string(tc), len(traceDetails))
		case EqualToMatch:
			assert.Equalf(t, len(traceDetails), len(traces.ByName[string(tc)]),
				"[match %d]: expected %q to have %d traces (value of %s), but found %d",
				i+1, m.Name, len(traces.ByName[string(tc)]), string(tc), len(traceDetails))
		case Any:
			// no assertion on the count
		case int:
			s := "s"
			if tc == 1 {
				s = ""
			}
			assert.Lenf(t, traceDetails, tc,
				"[match %d]: expected %q to have %d trace%s, but found %d",
				i+1, m.Name, tc, s, len(traceDetails))
		}

		if m.Services != nil {
			for _, td := range traceDetails {
				assert.ElementsMatch(t, m.Services, td.Services)
			}
		}
	}

	if opts.CheckDetachedSpans {
		for _, td := range traces.ByID {
			tree := td.SpanTree()
			if assert.Empty(t, tree.DetachedParents) {
				continue
			}
			// Log the children of each missing parent to help locate where
			// the trace was broken.
			for spanID, node := range tree.DetachedParents {
				t.Log("------------------------------------")
				t.Logf("span id: %s", spanID)
				if len(node.Children) != 0 {
					t.Log("children:")
				}
				for _, c := range node.Children {
					t.Log(protojson.Format(c.Span.Raw))
				}
				t.Log("------------------------------------")
			}
		}
	}

	if opts.Exact {
		expected := slices.Sorted(maps.Keys(matchArgsByName))
		actual := slices.Sorted(maps.Keys(traces.ByName))
		// An Any match that found nothing is allowed to be absent.
		for name, match := range matchArgsByName {
			if _, ok := traces.ByName[name]; !ok {
				if _, ok := match.TraceCount.(Any); ok {
					expected = slices.DeleteFunc(expected, func(s string) bool { return s == name })
				}
			}
		}
		// Deleting entries can leave expected as an empty non-nil slice while
		// actual is nil; assert.Equal treats those as different, so only
		// compare when at least one side has elements.
		if len(expected) > 0 || len(actual) > 0 {
			assert.Equal(t, expected, actual)
		}
	}
}
// AssertEqual reports an error on t for every difference between the traces
// in tr (actual) and those in expectedResults:
//   - an expected trace ID that is missing from the actual traces,
//   - a trace present in both that is not equal according to
//     TraceDetails.Equal (the diff is included in the message),
//   - an actual trace ID that is not expected.
//
// msgFmtAndArgs is optional; when given, its first element must be a format
// string and the rest its arguments. The formatted text is appended in
// parentheses to every error message.
func (tr *TraceResults) AssertEqual(t testing.TB, expectedResults *TraceResults, msgFmtAndArgs ...any) {
	t.Helper()
	actualByID := tr.GetTraces().ByID
	expectedByID := expectedResults.GetTraces().ByID

	annotated := len(msgFmtAndArgs) > 0
	// userMsg renders the caller-supplied message. It is only invoked when a
	// failure is being reported and a message was provided.
	userMsg := func() string {
		return fmt.Sprintf(msgFmtAndArgs[0].(string), msgFmtAndArgs[1:]...)
	}

	for traceID, expected := range expectedByID {
		idStr := traceID.Value().String()
		actual, found := actualByID[traceID]
		if !found {
			switch {
			case annotated:
				t.Errorf("expected trace id %s not found (%s)", idStr, userMsg())
			default:
				t.Errorf("expected trace id %s not found", idStr)
			}
			continue
		}
		equal, diff := actual.Equal(expected)
		if equal {
			continue
		}
		switch {
		case annotated:
			t.Errorf("trace %s is not equal (%s):\n%s", idStr, userMsg(), diff)
		default:
			t.Errorf("trace %s is not equal:\n%s", idStr, diff)
		}
	}

	for traceID := range actualByID {
		if _, wanted := expectedByID[traceID]; wanted {
			continue
		}
		idStr := traceID.Value().String()
		switch {
		case annotated:
			t.Errorf("unexpected trace id %s found (%s)", idStr, userMsg())
		default:
			t.Errorf("unexpected trace id %s found", idStr)
		}
	}
}
// FlattenResourceSpans feeds every span from every list into a fresh buffer,
// keyed by its resource (with schema URL) and scope (with schema URL), and
// returns whatever the buffer yields on Flush.
// NOTE(review): presumably the buffer merges spans that share a resource and
// scope into a single entry — confirm against the Buffer implementation.
func FlattenResourceSpans(lists [][]*tracev1.ResourceSpans) []*tracev1.ResourceSpans {
	buf := NewBuffer()
	for listIdx := range lists {
		for _, resSpans := range lists[listIdx] {
			resInfo := NewResourceInfo(resSpans.Resource, resSpans.SchemaUrl)
			for _, scopeSpans := range resSpans.ScopeSpans {
				scopeInfo := NewScopeInfo(scopeSpans.Scope, scopeSpans.SchemaUrl)
				for _, span := range scopeSpans.Spans {
					buf.Insert(resInfo, scopeInfo, span)
				}
			}
		}
	}
	return buf.Flush()
}
// FlattenExportRequests collects the ResourceSpans of each export request, in
// order, and flattens them with FlattenResourceSpans.
func FlattenExportRequests(reqs []*coltracepb.ExportTraceServiceRequest) []*tracev1.ResourceSpans {
	perRequest := make([][]*tracev1.ResourceSpans, 0, len(reqs))
	for _, req := range reqs {
		perRequest = append(perRequest, req.ResourceSpans)
	}
	return FlattenResourceSpans(perRequest)
}
// EventRecording is a time-ordered sequence of recorded debug events that can
// be shifted in time (Normalize) and replayed (Replay).
type EventRecording struct {
// events are the recorded events, in non-decreasing timestamp order when
// loaded via LoadEventRecording.
events []trace.DebugEvent
// normalizedTo is the start time passed to the most recent Normalize call
// that had events to shift; it is the zero time otherwise.
normalizedTo time.Time
}
// LoadEventRecording decodes raw as a JSON array of debug events. It returns
// the JSON decoding error, if any, or an error if some event has a timestamp
// earlier than that of the event before it.
func LoadEventRecording(raw []byte) (*EventRecording, error) {
	recording := &EventRecording{events: []trace.DebugEvent{}}
	if err := json.Unmarshal(raw, &recording.events); err != nil {
		return nil, err
	}

	// Walk adjacent pairs and reject any that go backwards in time.
	for prev, cur := 0, 1; cur < len(recording.events); prev, cur = cur, cur+1 {
		if recording.events[cur].Timestamp.Before(recording.events[prev].Timestamp) {
			return nil, fmt.Errorf("invalid timestamps: event %d occurred before event %d", cur, prev)
		}
	}
	return recording, nil
}
// Normalize shifts the whole recording in time so that the first event occurs
// at startTime. The same offset is applied, in place, to every event
// timestamp and to the start, end and event timestamps of every span in
// every event's request. It does nothing when the recording is empty.
func (er *EventRecording) Normalize(startTime time.Time) {
	if len(er.events) == 0 {
		return
	}
	er.normalizedTo = startTime

	delta := startTime.Sub(er.events[0].Timestamp)
	// Span timestamps are unsigned nanoseconds; adding the converted offset
	// relies on unsigned wrap-around to also handle negative offsets.
	shift := uint64(delta)

	for idx := range er.events {
		event := &er.events[idx]
		event.Timestamp = event.Timestamp.Add(delta)
		for _, resSpans := range event.Request.ResourceSpans {
			for _, scopeSpans := range resSpans.ScopeSpans {
				for _, span := range scopeSpans.Spans {
					span.StartTimeUnixNano += shift
					span.EndTimeUnixNano += shift
					for _, spanEvent := range span.Events {
						spanEvent.TimeUnixNano += shift
					}
				}
			}
		}
	}
}
// NormalizedTo returns the start time recorded by Normalize, or the zero time
// if Normalize has not recorded one.
func (er *EventRecording) NormalizedTo() time.Time {
return er.normalizedTo
}
// EventCallbackFunc is the callback invoked by EventRecording.Replay for each
// recorded export request.
type EventCallbackFunc = func(ctx context.Context, req *coltracepb.ExportTraceServiceRequest) (*coltracepb.ExportTraceServiceResponse, error)
// Events returns the recorded events. The returned slice is the recording's
// own slice, not a copy.
func (er *EventRecording) Events() []trace.DebugEvent {
return er.events
}
// Clone returns a copy of the recording. Each event's request is duplicated
// with proto.Clone, so modifying the copy (for example via Normalize) leaves
// the original's requests untouched.
func (er *EventRecording) Clone() *EventRecording {
	dup := &EventRecording{
		events:       make([]trace.DebugEvent, len(er.events)),
		normalizedTo: er.normalizedTo,
	}
	for idx := range er.events {
		src := &er.events[idx]
		dup.events[idx] = trace.DebugEvent{
			Timestamp: src.Timestamp,
			Request:   proto.Clone(src.Request).(*coltracepb.ExportTraceServiceRequest),
		}
	}
	return dup
}
// Replay re-sends every recorded request through callback, preserving the
// original spacing between consecutive events.
//
// The recording is first normalized to the current time (which modifies it in
// place). Each callback then runs in its own goroutine with a background
// context, and Replay sleeps for the recorded gap before starting the next
// one. Replay returns once all callbacks have finished.
//
// Callback results are deliberately discarded and the returned error is
// always nil. Replaying an empty recording is a no-op.
func (er *EventRecording) Replay(callback EventCallbackFunc) error {
	if len(er.events) == 0 {
		// Nothing to send; this also avoids requesting a slice with negative
		// capacity below.
		return nil
	}

	// NOTE(review): presumably the OS thread is pinned to reduce scheduling
	// jitter in the sleeps below — confirm.
	runtime.LockOSThread()
	defer runtime.UnlockOSThread()

	// Capture the gaps between consecutive events up front.
	durations := make([]time.Duration, 0, len(er.events)-1)
	for i := 1; i < len(er.events); i++ {
		durations = append(durations, er.events[i].Timestamp.Sub(er.events[i-1].Timestamp))
	}

	var wg sync.WaitGroup
	wg.Add(len(er.events))
	er.Normalize(time.Now())
	for i, ev := range er.events {
		// Pass the request explicitly so the goroutine does not depend on
		// per-iteration loop variable semantics.
		go func(req *coltracepb.ExportTraceServiceRequest) {
			defer wg.Done()
			_, _ = callback(context.Background(), req)
		}(ev.Request)
		if i < len(durations) {
			time.Sleep(durations[i])
		}
	}
	wg.Wait()
	return nil
}