package tracetest

import (
	"cmp"
	"context"
	"encoding/binary"
	"encoding/json"
	"fmt"
	"maps"
	"runtime"
	"slices"
	"strings"
	"sync"
	"testing"
	"time"
	"unique"

	gocmp "github.com/google/go-cmp/cmp"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
	oteltrace "go.opentelemetry.io/otel/trace"
	coltracepb "go.opentelemetry.io/proto/otlp/collector/trace/v1"
	commonv1 "go.opentelemetry.io/proto/otlp/common/v1"
	resourcev1 "go.opentelemetry.io/proto/otlp/resource/v1"
	tracev1 "go.opentelemetry.io/proto/otlp/trace/v1"
	"google.golang.org/protobuf/encoding/protojson"
	"google.golang.org/protobuf/proto"
	"google.golang.org/protobuf/testing/protocmp"

	"github.com/pomerium/pomerium/pkg/telemetry/trace"
)

// Small numeric identifiers used to build test fixtures. The methods defined
// on each type expand the number into the corresponding ID bytes, scope,
// resource, or schema string.
type (
	// Trace is a numeric stand-in for a trace ID (see Trace.ID and Trace.B).
	Trace    uint32
	// Span is a numeric stand-in for a span ID (see Span.ID and Span.B).
	Span     uint32
	// Scope is a numeric stand-in for an instrumentation scope (see Scope.Make).
	Scope    uint32
	// Schema is a numeric stand-in for a schema; its String form is what
	// Scope.Make and Resource.Make pass along as the schema argument.
	Schema   uint32
	// Resource is a numeric stand-in for a resource (see Resource.Make).
	Resource uint32
)

// String returns the label "Trace <n>".
func (n Trace) String() string {
	return fmt.Sprint("Trace ", uint32(n))
}

// String returns the label "Span <n>".
func (n Span) String() string {
	return fmt.Sprint("Span ", uint32(n))
}

// String returns the label "Scope <n>".
func (n Scope) String() string {
	return fmt.Sprint("Scope ", uint32(n))
}

// String returns the label "Schema <n>".
func (n Schema) String() string {
	return fmt.Sprint("Schema ", uint32(n))
}

// String returns the label "Resource <n>".
func (n Resource) String() string {
	return fmt.Sprint("Resource ", uint32(n))
}

// ID converts the bytes produced by B into a trace ID handle using
// trace.ToTraceID. The second result of ToTraceID is deliberately discarded.
func (n Trace) ID() unique.Handle[oteltrace.TraceID] {
	handle, _ := trace.ToTraceID(n.B())
	return handle
}

// B returns n encoded as trace ID bytes: the number is written big-endian
// into the last four bytes and every leading byte is left as zero.
func (n Trace) B() []byte {
	var raw oteltrace.TraceID
	tail := raw[len(raw)-4:]
	binary.BigEndian.PutUint32(tail, uint32(n))
	return raw[:]
}

// ID converts the bytes produced by B into a span ID using trace.ToSpanID.
// The second result of ToSpanID is deliberately discarded.
func (n Span) ID() oteltrace.SpanID {
	spanID, _ := trace.ToSpanID(n.B())
	return spanID
}

// B returns n encoded as span ID bytes: the number is written big-endian
// into the last four bytes and every leading byte is left as zero.
func (n Span) B() []byte {
	var raw oteltrace.SpanID
	tail := raw[len(raw)-4:]
	binary.BigEndian.PutUint32(tail, uint32(n))
	return raw[:]
}

// Make builds a ScopeInfo for this scope number. The instrumentation scope is
// named after n.String(), has version "v1", and carries a single integer "id"
// attribute equal to n. Only the first schema in s is used; when none is
// given, Schema(0) is used.
func (n Scope) Make(s ...Schema) *ScopeInfo {
	schema := Schema(0)
	if len(s) > 0 {
		schema = s[0]
	}

	idAttr := &commonv1.KeyValue{
		Key: "id",
		Value: &commonv1.AnyValue{
			Value: &commonv1.AnyValue_IntValue{IntValue: int64(n)},
		},
	}
	scope := &commonv1.InstrumentationScope{
		Name:       n.String(),
		Version:    "v1",
		Attributes: []*commonv1.KeyValue{idAttr},
	}
	return NewScopeInfo(scope, schema.String())
}

// Make builds a ResourceInfo for this resource number. The resource carries a
// string "name" attribute (n.String()) followed by an integer "id" attribute
// (n). Only the first schema in s is used; when none is given, Schema(0) is
// used.
func (n Resource) Make(s ...Schema) *ResourceInfo {
	schema := Schema(0)
	if len(s) > 0 {
		schema = s[0]
	}

	nameAttr := &commonv1.KeyValue{
		Key: "name",
		Value: &commonv1.AnyValue{
			Value: &commonv1.AnyValue_StringValue{StringValue: n.String()},
		},
	}
	idAttr := &commonv1.KeyValue{
		Key: "id",
		Value: &commonv1.AnyValue{
			Value: &commonv1.AnyValue_IntValue{IntValue: int64(n)},
		},
	}
	resource := &resourcev1.Resource{
		Attributes: []*commonv1.KeyValue{nameAttr, idAttr},
	}
	return NewResourceInfo(resource, schema.String())
}

// Traceparent formats a traceparent header value of the form
// "00-<trace id>-<span id>-<flags>", where flags is "01" when sampled is true
// and "00" otherwise.
func Traceparent(trace Trace, span Span, sampled bool) string {
	flags := "00"
	if sampled {
		flags = "01"
	}
	traceID := trace.ID().Value()
	spanID := span.ID()
	return fmt.Sprintf("00-%s-%s-%s", traceID, spanID, flags)
}

// TraceResults wraps a set of resource spans and exposes lazily computed
// views over them. Construct it with NewTraceResults, which wires up the
// GetResources and GetTraces fields.
type TraceResults struct {
	// resourceSpans is the raw input the views are computed from.
	resourceSpans []*tracev1.ResourceSpans

	// GetResources returns the Resource of every resource span, in input
	// order. NewTraceResults wraps it in sync.OnceValue, so the result is
	// computed once and then reused.
	GetResources func() []*resourcev1.Resource
	// GetTraces returns the spans grouped into traces. NewTraceResults wraps
	// it in sync.OnceValue, so the result is computed once and then reused.
	GetTraces    func() *Traces
}

// Traces indexes a collection of traces in three ways.
type Traces struct {
	// ByID maps each trace ID to its details.
	ByID          map[unique.Handle[oteltrace.TraceID]]*TraceDetails
	// ByName groups traces by TraceDetails.Name.
	ByName        map[string]TraceDetailsList
	// ByParticipant groups traces by each service name that contributed at
	// least one span; a trace appears once per participating service.
	ByParticipant map[string]TraceDetailsList
}

// WithoutErrors returns a new Traces containing only the traces that have no
// recorded error spans. Name and participant groups that end up empty after
// filtering are omitted. The receiver is not modified.
func (t *Traces) WithoutErrors() *Traces {
	// filterGroups applies the list-level filter to every entry and drops
	// entries whose filtered list is empty.
	filterGroups := func(groups map[string]TraceDetailsList) map[string]TraceDetailsList {
		kept := make(map[string]TraceDetailsList)
		for key, list := range groups {
			if clean := list.WithoutErrors(); len(clean) > 0 {
				kept[key] = clean
			}
		}
		return kept
	}

	cleanByID := make(map[unique.Handle[oteltrace.TraceID]]*TraceDetails, len(t.ByID))
	for id, details := range t.ByID {
		if len(details.Errors) == 0 {
			cleanByID[id] = details
		}
	}

	return &Traces{
		ByID:          cleanByID,
		ByName:        filterGroups(t.ByName),
		ByParticipant: filterGroups(t.ByParticipant),
	}
}

// TraceDetails aggregates all spans that share one trace ID, along with
// summary information derived from them.
type TraceDetails struct {
	// ID is the trace ID shared by all spans.
	ID        unique.Handle[oteltrace.TraceID]
	// Name is "<service>: <span name>" of the first span after sorting by
	// start time.
	Name      string
	// Spans holds the spans of the trace, sorted by start time.
	Spans     []*SpanDetails
	// Services is the sorted list of distinct, non-empty service names seen
	// across the spans.
	Services  []string
	// StartTime is the earliest span start time.
	StartTime time.Time
	// EndTime is the latest span end time.
	EndTime   time.Time
	// Duration is EndTime minus StartTime.
	Duration  time.Duration
	Errors    []int // indexes into Spans
}

// Equal compares td (the "actual" trace) against other (the "expected"
// trace). It returns true and an empty string when they match; otherwise it
// returns false and a description of the first difference found.
//
// The ID, span count, start time, and end time are compared first. Only if
// all of those agree are the spans diffed pairwise, in order, using go-cmp
// with the protocmp transform.
func (td *TraceDetails) Equal(other *TraceDetails) (bool, string) {
	switch {
	case td.ID != other.ID:
		return false, fmt.Sprintf("traces are trivially not equal: ID %s (actual) != %s (expected)", td.ID.Value(), other.ID.Value())
	case len(td.Spans) != len(other.Spans):
		return false, fmt.Sprintf("traces are trivially not equal: len(spans) %d (actual) != %d (expected)", len(td.Spans), len(other.Spans))
	case !td.StartTime.Equal(other.StartTime):
		return false, fmt.Sprintf("traces are trivially not equal: start time %s (actual) != %s (expected)", td.StartTime, other.StartTime)
	case !td.EndTime.Equal(other.EndTime):
		return false, fmt.Sprintf("traces are trivially not equal: end time %s (actual) != %s (expected)", td.EndTime, other.EndTime)
	}

	// The span counts are known to match here, so indexing other.Spans is safe.
	for i, span := range td.Spans {
		if diff := gocmp.Diff(span, other.Spans[i], protocmp.Transform()); diff != "" {
			return false, diff
		}
	}
	return true, ""
}

// TraceDetailsList is a list of traces with filtering helpers.
type TraceDetailsList []*TraceDetails

// WithoutExportRPCs returns a new list that omits every trace whose Name
// contains the OTLP trace export RPC method name. The receiver is not
// modified.
func (list TraceDetailsList) WithoutExportRPCs() TraceDetailsList {
	const exportMethod = "opentelemetry.proto.collector.trace.v1.TraceService/Export"

	kept := make(TraceDetailsList, 0, len(list))
	for i := range list {
		if !strings.Contains(list[i].Name, exportMethod) {
			kept = append(kept, list[i])
		}
	}
	return kept
}

// WithoutErrors returns a new list containing only the traces that have no
// entries in Errors. The receiver is not modified.
func (list TraceDetailsList) WithoutErrors() TraceDetailsList {
	kept := make(TraceDetailsList, 0, len(list))
	for i := range list {
		if len(list[i].Errors) == 0 {
			kept = append(kept, list[i])
		}
	}
	return kept
}

// SpanTree arranges the spans of the trace into a parent/child tree.
//
// A synthetic root node is registered under the all-zero span ID, so spans
// whose parent ID converts to the zero ID become children of the root. When a
// span references a parent ID that no span in the trace has, an empty
// placeholder node is created for that parent and recorded in
// DetachedParents. Every node's children are sorted by start time.
func (td *TraceDetails) SpanTree() *SpanTree {
	var rootID oteltrace.SpanID // all zeroes

	// First pass: one node per span, plus the synthetic root.
	nodes := make(map[oteltrace.SpanID]*SpanTreeNode, len(td.Spans)+1)
	nodes[rootID] = &SpanTreeNode{}
	for _, sd := range td.Spans {
		id, _ := trace.ToSpanID(sd.Raw.SpanId)
		nodes[id] = &SpanTreeNode{Span: sd}
	}

	// Second pass: link each span to its parent, creating placeholders for
	// parents that are missing from the trace.
	detached := map[oteltrace.SpanID]*SpanTreeNode{}
	for _, sd := range td.Spans {
		id, _ := trace.ToSpanID(sd.Raw.SpanId)
		parentID, _ := trace.ToSpanID(sd.Raw.ParentSpanId)

		parent, known := nodes[parentID]
		if !known {
			parent = &SpanTreeNode{}
			detached[parentID] = parent
			nodes[parentID] = parent
		}
		child := nodes[id]
		child.Parent = parent
		parent.Children = append(parent.Children, child)
	}

	byStartTime := func(a, b *SpanTreeNode) int {
		return cmp.Compare(a.Span.Raw.StartTimeUnixNano, b.Span.Raw.StartTimeUnixNano)
	}
	for _, node := range nodes {
		slices.SortFunc(node.Children, byStartTime)
	}

	return &SpanTree{
		Root:            nodes[rootID],
		DetachedParents: detached,
	}
}

// SpanDetails pairs a raw span with the resource and scope it was reported
// under, plus a few values derived from the raw span.
type SpanDetails struct {
	// Raw is the original span message.
	Raw       *tracev1.Span
	// Resource is the resource of the enclosing ResourceSpans.
	Resource  *resourcev1.Resource
	// Scope is the instrumentation scope of the enclosing ScopeSpans.
	Scope     *commonv1.InstrumentationScope
	// StartTime is Raw.StartTimeUnixNano as a time.Time.
	StartTime time.Time
	// EndTime is Raw.EndTimeUnixNano as a time.Time.
	EndTime   time.Time
	// Duration is the end timestamp minus the start timestamp.
	Duration  time.Duration
	// Service is the string value of the resource's "service.name"
	// attribute, or "" if the resource has no such attribute.
	Service   string
}

// NewTraceResults wraps resourceSpans in a TraceResults. The GetResources and
// GetTraces accessors are wrapped in sync.OnceValue, so each view is computed
// on first use and then reused.
func NewTraceResults(resourceSpans []*tracev1.ResourceSpans) *TraceResults {
	results := new(TraceResults)
	results.resourceSpans = resourceSpans
	results.GetResources = sync.OnceValue(results.computeResources)
	results.GetTraces = sync.OnceValue(results.computeTraces)
	return results
}

// computeResources collects the Resource of every resource span, preserving
// input order. The result is always non-nil.
func (tr *TraceResults) computeResources() []*resourcev1.Resource {
	out := make([]*resourcev1.Resource, len(tr.resourceSpans))
	for i, rs := range tr.resourceSpans {
		out[i] = rs.Resource
	}
	return out
}

// computeTraces groups every span in the recorded resource spans by trace ID
// and builds the ByID, ByName, and ByParticipant indexes.
//
// For each trace, spans are sorted by start time; the trace start/end times
// are the earliest start and latest end across its spans; Services is the
// sorted set of non-empty service names; and Name is
// "<service>: <span name>" of the earliest span. Errors holds the indexes
// (into the sorted Spans slice) of spans whose status code is ERROR.
func (tr *TraceResults) computeTraces() *Traces {
	tracesByID := map[unique.Handle[oteltrace.TraceID]]*TraceDetails{}
	for _, resSpan := range tr.resourceSpans {
		resource := resSpan.GetResource()

		// The service name depends only on the resource, so resolve it once
		// per resource rather than once per span. The generated getters are
		// nil-safe, so a ResourceSpans without a Resource yields "".
		svc := ""
		for _, attr := range resource.GetAttributes() {
			if attr.GetKey() == "service.name" {
				svc = attr.GetValue().GetStringValue()
				break
			}
		}

		for _, scopeSpans := range resSpan.GetScopeSpans() {
			scope := scopeSpans.GetScope()
			for _, span := range scopeSpans.GetSpans() {
				traceID, _ := trace.ToTraceID(span.TraceId)
				details, ok := tracesByID[traceID]
				if !ok {
					details = &TraceDetails{
						ID: traceID,
					}
					tracesByID[traceID] = details
				}
				details.Spans = append(details.Spans, &SpanDetails{
					Raw:       span,
					Resource:  resource,
					Scope:     scope,
					StartTime: time.Unix(0, int64(span.StartTimeUnixNano)),
					EndTime:   time.Unix(0, int64(span.EndTimeUnixNano)),
					Duration:  time.Duration(span.EndTimeUnixNano - span.StartTimeUnixNano),
					Service:   svc,
				})
			}
		}
	}

	tracesByName := map[string]TraceDetailsList{}
	tracesByParticipant := map[string]TraceDetailsList{}
	// sort spans by start time, then derive the per-trace summary fields
	for _, td := range tracesByID {
		slices.SortFunc(td.Spans, func(a, b *SpanDetails) int {
			return cmp.Compare(a.Raw.StartTimeUnixNano, b.Raw.StartTimeUnixNano)
		})
		// Every TraceDetails is created together with its first span, so
		// td.Spans is never empty here.
		startTime := td.Spans[0].Raw.StartTimeUnixNano
		endTime := td.Spans[0].Raw.EndTimeUnixNano
		serviceNames := map[string]struct{}{}
		for i, span := range td.Spans {
			startTime = min(startTime, span.Raw.StartTimeUnixNano)
			endTime = max(endTime, span.Raw.EndTimeUnixNano)
			if span.Service != "" {
				serviceNames[span.Service] = struct{}{}
			}
			// Error indexes must be recorded after sorting; indexes taken at
			// append time would no longer point at the right spans.
			if span.Raw.GetStatus().GetCode() == tracev1.Status_STATUS_CODE_ERROR {
				td.Errors = append(td.Errors, i)
			}
		}
		td.StartTime = time.Unix(0, int64(startTime))
		td.EndTime = time.Unix(0, int64(endTime))
		td.Duration = td.EndTime.Sub(td.StartTime)
		td.Services = slices.Sorted(maps.Keys(serviceNames))
		td.Name = fmt.Sprintf("%s: %s", td.Spans[0].Service, td.Spans[0].Raw.Name)
		tracesByName[td.Name] = append(tracesByName[td.Name], td)
		for svc := range serviceNames {
			tracesByParticipant[svc] = append(tracesByParticipant[svc], td)
		}
	}

	return &Traces{
		ByID:          tracesByID,
		ByName:        tracesByName,
		ByParticipant: tracesByParticipant,
	}
}

// SpanTree is the parent/child arrangement of a trace's spans, as built by
// TraceDetails.SpanTree.
type SpanTree struct {
	// Root is the synthetic node registered under the all-zero span ID.
	Root            *SpanTreeNode
	// DetachedParents holds a placeholder node for every parent span ID that
	// some span refers to but that no span in the trace has, keyed by that
	// parent ID.
	DetachedParents map[oteltrace.SpanID]*SpanTreeNode
}

// SpanTreeNode is a single node of a SpanTree.
type SpanTreeNode struct {
	// Span is the span at this node. It is left nil on the synthetic root
	// and on detached-parent placeholders when they are created.
	Span     *SpanDetails
	// Parent is the node this span was attached to.
	Parent   *SpanTreeNode
	// Children are the attached child nodes, sorted by span start time.
	Children []*SpanTreeNode
}

// Match describes one expectation checked by TraceResults.MatchTraces.
type Match struct {
	// Name is the trace name to look up in Traces.ByName. Matches with an
	// empty Name are skipped by MatchTraces.
	Name       string
	// TraceCount is the expected number of traces with this name. MatchTraces
	// recognizes int (exact count), GreaterOrEqual, Greater, GreaterThanMatch,
	// EqualToMatch, and Any. Any other value (including nil) makes no count
	// assertion, but the trace must still exist.
	TraceCount any
	// Services, when non-nil, is the set of services every trace with this
	// name must have (compared ignoring order).
	Services   []string
}

// Values accepted by Match.TraceCount in addition to a plain int.
type (
	// GreaterOrEqual asserts that the trace count is >= the given value.
	GreaterOrEqual int
	// Greater asserts that the trace count is > the given value.
	Greater        int

	// Any makes no assertions on the trace count. If the trace is not found, it
	// doesn't count against the Exact match option.
	Any struct{}

	// EqualToMatch asserts that the value is the same as the value of another
	// match (by name)
	EqualToMatch string
	// GreaterThanMatch asserts that the value is greater than the value of
	// another match (by name)
	GreaterThanMatch string
)

// MatchOptions enables additional checks in TraceResults.MatchTraces.
type MatchOptions struct {
	// If true, asserts that there is exactly one [Match] entry per result
	Exact bool
	// If true, asserts that no traces contain detached spans
	CheckDetachedSpans bool
}

// MatchTraces checks the computed traces against the given matches and
// reports failures on t.
//
// For every match with a non-empty Name:
//   - the name must be unique among the matches (a duplicate aborts the test
//     via require);
//   - if traces with that name exist, the count assertion selected by the
//     dynamic type of TraceCount is applied, and, when Services is non-nil,
//     each such trace must have exactly those services (order-insensitive);
//   - if no traces with that name exist, an error is reported unless
//     TraceCount is Any.
//
// With opts.CheckDetachedSpans, every trace's span tree must have no detached
// parents; the children of any detached parent are logged for debugging.
//
// With opts.Exact, the sorted set of match names must equal the sorted set of
// trace names found, except that Any matches with no corresponding trace are
// not counted as expected.
func (tr *TraceResults) MatchTraces(t testing.TB, opts MatchOptions, matches ...Match) {
	t.Helper()
	traces := tr.GetTraces()
	matchArgsByName := map[string]Match{}
	for i, m := range matches {
		if m.Name != "" {
			require.NotContains(t, matchArgsByName, m.Name, "duplicate name")
			matchArgsByName[m.Name] = m
			if traceDetails, ok := traces.ByName[m.Name]; ok {
				// The dynamic type of TraceCount selects the count assertion.
				// Types not listed here (including nil) assert nothing.
				switch tc := m.TraceCount.(type) {
				case GreaterOrEqual:
					assert.GreaterOrEqualf(t, len(traceDetails), int(tc),
						"[match %d]: expected %q to have >=%d traces, but found %d",
						i+1, m.Name, int(tc), len(traceDetails))
				case Greater:
					assert.Greaterf(t, len(traceDetails), int(tc),
						"[match %d]: expected %q to have >%d traces, but found %d",
						i+1, m.Name, int(tc), len(traceDetails))
				case GreaterThanMatch:
					// Compare against the trace count of another name; a name
					// with no traces counts as zero.
					assert.Greaterf(t, len(traceDetails), len(traces.ByName[string(tc)]),
						"[match %d]: expected %q to have >%d traces (value of %s), but found %d",
						i+1, m.Name, len(traces.ByName[string(tc)]), string(tc), len(traceDetails))
				case EqualToMatch:
					assert.Equalf(t, len(traceDetails), len(traces.ByName[string(tc)]),
						"[match %d]: expected %q to have %d traces (value of %s), but found %d",
						i+1, m.Name, len(traces.ByName[string(tc)]), string(tc), len(traceDetails))
				case Any:
					// no assertion on the count
				case int:
					// Pluralize "trace" in the failure message.
					s := "s"
					if tc == 1 {
						s = ""
					}
					assert.Lenf(t, traceDetails, tc,
						"[match %d]: expected %q to have %d trace%s, but found %d",
						i+1, m.Name, tc, s, len(traceDetails))
				}

				if m.Services != nil {
					for _, trace := range traceDetails {
						assert.ElementsMatch(t, m.Services, trace.Services)
					}
				}
			} else if _, ok := m.TraceCount.(Any); !ok {
				// Only Any matches are allowed to have no traces at all.
				t.Errorf("no traces with name %q found", m.Name)
			}
		}
	}
	if opts.CheckDetachedSpans {
		for _, trace := range traces.ByID {
			tree := trace.SpanTree()
			if !assert.Empty(t, tree.DetachedParents) {
				// Log the orphaned children of each missing parent to help
				// diagnose the failure.
				for spanID, node := range tree.DetachedParents {
					t.Log("------------------------------------")
					t.Logf("span id: %s", spanID)
					if len(node.Children) != 0 {
						t.Log("children:")
					}
					for _, c := range node.Children {
						t.Log(protojson.Format(c.Span.Raw))
					}
					t.Log("------------------------------------")
				}
			}
		}
	}
	if opts.Exact {
		expected := slices.Sorted(maps.Keys(matchArgsByName))
		actual := slices.Sorted(maps.Keys(traces.ByName))
		// Any matches that found nothing are not expected to appear in the
		// actual trace names, so drop them before comparing.
		for name, match := range matchArgsByName {
			if _, ok := traces.ByName[name]; !ok {
				if _, ok := match.TraceCount.(Any); ok {
					expected = slices.DeleteFunc(expected, func(s string) bool { return s == name })
				}
			}
		}
		assert.Equal(t, expected, actual)
	}
}

// AssertEqual reports an error on t for every difference between the traces
// in tr (actual) and those in expectedResults (expected):
//   - an expected trace ID that is missing from the actual results,
//   - a trace present in both whose TraceDetails.Equal comparison fails (the
//     diff is included in the message), and
//   - an actual trace ID that is not in the expected results.
//
// msgFmtAndArgs is an optional context message appended to each error in
// parentheses. When its first element is a string, it is used as a format
// string for the remaining elements; otherwise all elements are rendered with
// fmt.Sprint instead of panicking on a failed type assertion.
func (tr *TraceResults) AssertEqual(t testing.TB, expectedResults *TraceResults, msgFmtAndArgs ...any) {
	t.Helper()

	// Render the optional context message once; it is the same for every
	// error reported below.
	suffix := ""
	if len(msgFmtAndArgs) > 0 {
		if format, ok := msgFmtAndArgs[0].(string); ok {
			suffix = " (" + fmt.Sprintf(format, msgFmtAndArgs[1:]...) + ")"
		} else {
			suffix = " (" + fmt.Sprint(msgFmtAndArgs...) + ")"
		}
	}

	actualTraces := tr.GetTraces()
	expectedTraces := expectedResults.GetTraces()
	for traceID, expected := range expectedTraces.ByID {
		actual, ok := actualTraces.ByID[traceID]
		if !ok {
			t.Errorf("expected trace id %s not found%s", traceID.Value().String(), suffix)
			continue
		}
		if equal, diff := actual.Equal(expected); !equal {
			t.Errorf("trace %s is not equal%s:\n%s", traceID.Value().String(), suffix, diff)
		}
	}
	for traceID := range actualTraces.ByID {
		if _, ok := expectedTraces.ByID[traceID]; !ok {
			t.Errorf("unexpected trace id %s found%s", traceID.Value().String(), suffix)
		}
	}
}

// FlattenResourceSpans merges several lists of resource spans into a single
// list. Every span is inserted into a Buffer together with the ResourceInfo
// and ScopeInfo built from its enclosing resource and scope (including their
// schema URLs), and the buffer's flushed contents are returned.
// NOTE(review): the grouping performed by Buffer is defined outside this
// function — confirm its semantics there.
func FlattenResourceSpans(lists [][]*tracev1.ResourceSpans) []*tracev1.ResourceSpans {
	buf := NewBuffer()
	for _, resourceSpans := range lists {
		for _, rs := range resourceSpans {
			resInfo := NewResourceInfo(rs.Resource, rs.SchemaUrl)
			for _, ss := range rs.ScopeSpans {
				scopeInfo := NewScopeInfo(ss.Scope, ss.SchemaUrl)
				for _, sp := range ss.Spans {
					buf.Insert(resInfo, scopeInfo, sp)
				}
			}
		}
	}
	return buf.Flush()
}

// FlattenExportRequests collects the ResourceSpans of every export request,
// in order, and flattens them with FlattenResourceSpans.
func FlattenExportRequests(reqs []*coltracepb.ExportTraceServiceRequest) []*tracev1.ResourceSpans {
	perRequest := make([][]*tracev1.ResourceSpans, 0, len(reqs))
	for _, r := range reqs {
		perRequest = append(perRequest, r.ResourceSpans)
	}
	return FlattenResourceSpans(perRequest)
}

// EventRecording is a timestamped sequence of recorded export requests that
// can be time-shifted (Normalize) and played back (Replay).
type EventRecording struct {
	// events are the recorded events; LoadEventRecording rejects input whose
	// timestamps go backwards.
	events       []trace.DebugEvent
	// normalizedTo is the start time passed to the most recent Normalize call
	// that had events to shift; it is the zero time until then.
	normalizedTo time.Time
}

// LoadEventRecording decodes raw as a JSON array of debug events and verifies
// that no event has a timestamp earlier than the event before it. It returns
// the JSON decoding error, or an "invalid timestamps" error naming the first
// out-of-order pair.
func LoadEventRecording(raw []byte) (*EventRecording, error) {
	var events []trace.DebugEvent
	if err := json.Unmarshal(raw, &events); err != nil {
		return nil, err
	}
	for i, ev := range events {
		if i == 0 {
			continue // the first event has no predecessor to compare with
		}
		if ev.Timestamp.Before(events[i-1].Timestamp) {
			return nil, fmt.Errorf("invalid timestamps: event %d occurred before event %d", i, i-1)
		}
	}
	return &EventRecording{events: events}, nil
}

// Normalize shifts the whole recording in time so that the first event occurs
// at startTime. The same offset is applied to every event timestamp and, in
// place, to the start/end times of every span and the time of every span
// event inside the recorded requests. A recording with no events is left
// untouched, including normalizedTo.
func (er *EventRecording) Normalize(startTime time.Time) {
	if len(er.events) == 0 {
		return
	}
	er.normalizedTo = startTime

	offset := startTime.Sub(er.events[0].Timestamp)
	// Unsigned addition of a converted negative offset wraps around to the
	// same result as a signed subtraction would give.
	delta := uint64(offset)

	for i := range er.events {
		ev := &er.events[i]
		ev.Timestamp = ev.Timestamp.Add(offset)
		for _, rs := range ev.Request.ResourceSpans {
			for _, ss := range rs.ScopeSpans {
				for _, sp := range ss.Spans {
					sp.StartTimeUnixNano += delta
					sp.EndTimeUnixNano += delta
					for _, spanEvent := range sp.Events {
						spanEvent.TimeUnixNano += delta
					}
				}
			}
		}
	}
}

// NormalizedTo returns the start time recorded by the most recent Normalize
// call that had events to shift, or the zero time if there was none.
func (er *EventRecording) NormalizedTo() time.Time {
	return er.normalizedTo
}

// EventCallbackFunc is the callback invoked by EventRecording.Replay for each
// recorded export request.
type EventCallbackFunc = func(ctx context.Context, req *coltracepb.ExportTraceServiceRequest) (*coltracepb.ExportTraceServiceResponse, error)

// Events returns the recorded events. The returned slice is the recording's
// own slice, not a copy.
func (er *EventRecording) Events() []trace.DebugEvent {
	return er.events
}

// Clone returns a copy of the recording. Each event's request is copied with
// proto.Clone and its timestamp is carried over, so normalizing or replaying
// the clone does not modify the original's requests. normalizedTo is copied
// as well.
func (er *EventRecording) Clone() *EventRecording {
	events := make([]trace.DebugEvent, len(er.events))
	for i := range er.events {
		src := &er.events[i]
		events[i] = trace.DebugEvent{
			Timestamp: src.Timestamp,
			Request:   proto.Clone(src.Request).(*coltracepb.ExportTraceServiceRequest),
		}
	}
	return &EventRecording{
		events:       events,
		normalizedTo: er.normalizedTo,
	}
}

// Replay plays the recording back in real time. It first normalizes the
// recording to the current time (mutating it), then invokes callback once per
// event, each in its own goroutine, sleeping between events for the interval
// that separated them in the recording. It returns after every callback has
// finished.
//
// Replaying an empty recording is a no-op. The callback's response and error
// are discarded, so Replay currently always returns nil.
func (er *EventRecording) Replay(callback EventCallbackFunc) error {
	// Without this guard the capacity computed below would be -1 and make
	// would panic.
	if len(er.events) == 0 {
		return nil
	}

	// NOTE(review): presumably pinned to an OS thread to reduce timing jitter
	// between sleeps — confirm the intent.
	runtime.LockOSThread()
	defer runtime.UnlockOSThread()

	// Gaps between consecutive events; Normalize shifts all timestamps by the
	// same offset, so these are unaffected by it.
	durations := make([]time.Duration, 0, len(er.events)-1)
	for i := 1; i < len(er.events); i++ {
		durations = append(durations, er.events[i].Timestamp.Sub(er.events[i-1].Timestamp))
	}

	var wg sync.WaitGroup
	wg.Add(len(er.events))
	er.Normalize(time.Now())
	for i, ev := range er.events {
		// ev is a per-iteration variable (Go 1.22+ loop semantics), so each
		// goroutine sees its own event.
		go func() {
			defer wg.Done()
			// Playback is best-effort: the result is intentionally ignored.
			_, _ = callback(context.Background(), ev.Request)
		}()
		if i < len(durations) {
			time.Sleep(durations[i])
		}
	}
	wg.Wait()
	return nil
}