package trace

import (
	"context"
	"encoding/binary"
	"encoding/hex"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"os"
	"runtime"
	"slices"
	"strings"
	"sync"
	"time"

	"go.opentelemetry.io/otel/attribute"
	sdktrace "go.opentelemetry.io/otel/sdk/trace"
	oteltrace "go.opentelemetry.io/otel/trace"
	coltracepb "go.opentelemetry.io/proto/otlp/collector/trace/v1"
	"google.golang.org/grpc/metadata"
	"google.golang.org/protobuf/encoding/protojson"
)

// DebugFlags is a bit set of tracing debug options. Several flags are defined
// as combinations that also switch on the flags they depend on.
type DebugFlags uint32

const (
	// TrackSpanCallers adds a "caller" attribute to each span containing the
	// source location where the span was started.
	TrackSpanCallers = (1 << iota)

	// TrackSpanReferences records parent/child span references so that an
	// attempt can be made to wait for all traces to complete when a trace
	// context is shut down.
	// Use with caution: memory usage grows over time.
	TrackSpanReferences = (1 << iota)

	// TrackAllSpans records the name, span context, parent span context,
	// caller and start time of every started span.
	// Use with caution: memory usage grows significantly over time.
	//
	// Implies [TrackSpanCallers].
	TrackAllSpans = (1 << iota) | TrackSpanCallers

	// LogTraceIDs logs every recorded trace ID together with its span count
	// on close.
	//
	// Implies [TrackAllSpans].
	LogTraceIDs = (1 << iota) | TrackAllSpans

	// LogAllSpans logs every recorded span on close. The logged spans may
	// belong to incomplete traces.
	//
	// Implies [TrackAllSpans].
	LogAllSpans = (1 << iota) | TrackAllSpans

	// LogAllSpansOnWarn logs every recorded span on close, but only when a
	// warning was issued (so a Warn* flag must also be set).
	//
	// Implies [TrackAllSpans].
	LogAllSpansOnWarn = (1 << iota) | TrackAllSpans

	// LogTraceIDsOnWarn logs the recorded trace IDs on close, but only when a
	// warning was issued (so a Warn* flag must also be set). It does not imply
	// [TrackAllSpans]; trace IDs are gathered from spans recorded by that
	// flag, so without it the list is empty.
	LogTraceIDsOnWarn = (1 << iota)

	// WarnOnIncompleteTraces prints a warning on close if any traces are
	// incomplete (no root span was observed for them).
	WarnOnIncompleteTraces = (1 << iota)

	// WarnOnIncompleteSpans prints a warning on close if any spans were
	// started but never ended.
	WarnOnIncompleteSpans = (1 << iota)

	// WarnOnUnresolvedReferences prints a warning on close if any span
	// referenced a parent span that was never observed.
	//
	// Implies [TrackSpanReferences].
	WarnOnUnresolvedReferences = (1 << iota) | TrackSpanReferences

	// EnvoyFlushEverySpan configures Envoy to flush every span individually,
	// disabling its internal buffer.
	EnvoyFlushEverySpan = (1 << iota)
)

// Check reports whether every bit of flags is also set in df. Because
// composite flags include the bits of the flags they imply, Check(X) is true
// only when all of X's component bits are present. Check(0) is always true.
func (df DebugFlags) Check(flags DebugFlags) bool {
	// flags &^ df keeps exactly the requested bits that df is missing.
	return flags&^df == 0
}

var (
	// ErrIncompleteSpans is returned by the span tracker's Shutdown when it
	// printed a warning on close (see the Warn* debug flags).
	ErrIncompleteSpans    = errors.New("exporter shut down with incomplete spans")
	// ErrMissingParentSpans is returned by WaitForSpans when maxDuration
	// elapses before every referenced span has been observed.
	ErrMissingParentSpans = errors.New("exporter shut down with missing parent spans")
)

// WaitForSpans blocks until every span referenced by tracers created with ctx
// has been observed, or until maxDuration elapses, whichever happens first.
// It may be called any number of times, including concurrently from several
// goroutines.
//
// It requires the [TrackSpanReferences] debug flag to have been set with
// [Options.NewContext]. When ctx carries no span observer it returns nil
// immediately.
//
// If the wait lasts longer than 10 seconds, a warning is printed (to stderr
// unless a debug message writer is configured). The warning lists each span
// ID still being waited for, together with the ID of the span that referenced
// it.
//
// It returns nil once all referenced spans were observed, or
// [ErrMissingParentSpans] on timeout.
//
// NOTE: a timeout does not cancel the background wait. That goroutine keeps
// blocking until all referenced spans are observed. If maxDuration is shorter
// than 10 seconds, the warning may therefore be printed after this function
// has already returned.
func WaitForSpans(ctx context.Context, maxDuration time.Duration) error {
	sys := systemContextFromContext(ctx)
	if sys == nil || sys.observer == nil {
		return nil
	}

	allObserved := make(chan struct{})
	go func() {
		sys.observer.wait(10 * time.Second)
		close(allObserved)
	}()

	select {
	case <-time.After(maxDuration):
		return ErrMissingParentSpans
	case <-allObserved:
		return nil
	}
}

// DebugFlagsFromContext returns the DebugFlags from the options of the system
// context found in ctx, or 0 when ctx carries no system context.
func DebugFlagsFromContext(ctx context.Context) DebugFlags {
	sys := systemContextFromContext(ctx)
	if sys == nil {
		return 0
	}
	return sys.options.DebugFlags
}

// stackTraceProcessor is a trace.SpanProcessor that stores, in the "caller"
// span attribute, the source location ("file:line") that started each span.
type stackTraceProcessor struct{}

// ForceFlush implements trace.SpanProcessor. Nothing is buffered, so it
// always returns nil.
func (*stackTraceProcessor) ForceFlush(context.Context) error {
	return nil
}

// OnEnd implements trace.SpanProcessor. It is a no-op.
func (*stackTraceProcessor) OnEnd(sdktrace.ReadOnlySpan) {
}

// OnStart implements trace.SpanProcessor.
//
// It records the location of stack frame 2 relative to this method. Frame 0
// is OnStart itself, frame 1 is whatever invoked OnStart, and frame 2 is that
// function's caller.
// NOTE(review): this assumes the SDK's Tracer.Start calls OnStart directly,
// so frame 2 is the code that started the span. Confirm this if tracers are
// wrapped.
func (*stackTraceProcessor) OnStart(_ context.Context, span sdktrace.ReadWriteSpan) {
	_, file, line, ok := runtime.Caller(2)
	if !ok {
		// Leave the attribute unset instead of recording a meaningless ":0";
		// consumers treat an empty caller as "unknown".
		return
	}
	span.SetAttributes(attribute.String("caller", fmt.Sprintf("%s:%d", file, line)))
}

// Shutdown implements trace.SpanProcessor. There is nothing to release, so it
// always returns nil.
func (*stackTraceProcessor) Shutdown(context.Context) error {
	return nil
}

// debugMessageWriter, when non-nil, receives the debug messages emitted by
// endMsg in place of os.Stderr.
var debugMessageWriter io.Writer

// startMsg begins a debug message: a leading newline, a horizontal rule line,
// then title. Callers append the body and hand the builder to endMsg.
func startMsg(title string) *strings.Builder {
	var b strings.Builder
	b.WriteString("\n==================================================\n")
	b.WriteString(title)
	return &b
}

// endMsg appends a closing rule line to msg and writes the entire message to
// debugMessageWriter, or to os.Stderr when that is nil.
func endMsg(msg *strings.Builder) {
	msg.WriteString("==================================================\n")
	var out io.Writer = os.Stderr
	if w := debugMessageWriter; w != nil {
		out = w
	}
	fmt.Fprint(out, msg.String())
}

// DebugEvent pairs an OTLP trace export request with a timestamp. Its JSON
// form is {"timestamp": ..., "request": ...}, where the request is encoded
// with protojson rather than encoding/json.
type DebugEvent struct {
	Timestamp time.Time                             `json:"timestamp"`
	Request   *coltracepb.ExportTraceServiceRequest `json:"request"`
}

// MarshalJSON implements json.Marshaler. The Request field is encoded with
// protojson and embedded as raw JSON. An error from protojson is returned
// instead of being silently replaced by a null request.
func (e DebugEvent) MarshalJSON() ([]byte, error) {
	type debugEvent struct {
		Timestamp time.Time       `json:"timestamp"`
		Request   json.RawMessage `json:"request"`
	}
	reqData, err := protojson.Marshal(e.Request)
	if err != nil {
		return nil, err
	}
	return json.Marshal(debugEvent{
		Timestamp: e.Timestamp,
		Request:   reqData,
	})
}

// UnmarshalJSON implements json.Unmarshaler, reversing MarshalJSON: the
// "request" value is decoded with protojson into a newly allocated request.
// On error, e.Request is left unchanged.
func (e *DebugEvent) UnmarshalJSON(b []byte) error {
	type debugEvent struct {
		Timestamp time.Time       `json:"timestamp"`
		Request   json.RawMessage `json:"request"`
	}
	var ev debugEvent
	if err := json.Unmarshal(b, &ev); err != nil {
		return err
	}
	e.Timestamp = ev.Timestamp
	var msg coltracepb.ExportTraceServiceRequest
	if err := protojson.Unmarshal(ev.Request, &msg); err != nil {
		return err
	}
	e.Request = &msg
	return nil
}

// shardCount is the number of independently locked shards the in-flight span
// set is split into.
const shardCount = 64

type (
	// shardedSet is a set of span IDs split across shardCount maps. The
	// tracker picks a shard from the span ID and guards shard i with the
	// mutex at index i of a shardedLocks.
	shardedSet [shardCount]map[oteltrace.SpanID]struct{}
	// shardedLocks holds one mutex per shard of a shardedSet.
	shardedLocks [shardCount]sync.Mutex
)

// Range calls f once for every span ID in the set, visiting shards in index
// order. It does no locking itself; callers racing with writers should hold
// all shard locks (see LockAll).
func (s *shardedSet) Range(f func(key oteltrace.SpanID)) {
	for _, shard := range s {
		for key := range shard {
			f(key)
		}
	}
}

// LockAll acquires every shard lock, in ascending index order.
func (s *shardedLocks) LockAll() {
	for idx := range s {
		s[idx].Lock()
	}
}

// UnlockAll releases every shard lock acquired by LockAll.
func (s *shardedLocks) UnlockAll() {
	for idx := range s {
		s[idx].Unlock()
	}
}

// spanTracker is a trace.SpanProcessor that keeps the set of in-flight spans.
// Depending on debugFlags, it also forwards span references to an observer
// and records a spanInfo for every started span. Its Shutdown reports the
// diagnostics selected by debugFlags.
type spanTracker struct {
	// inflightSpansMu[i] guards inflightSpans[i].
	inflightSpansMu shardedLocks
	// inflightSpans holds the IDs of spans that were started but not ended.
	inflightSpans shardedSet
	// allSpans maps oteltrace.SpanID to *spanInfo; filled only with TrackAllSpans.
	allSpans   sync.Map
	debugFlags DebugFlags
	// observer receives span references when TrackSpanReferences is set.
	observer *spanObserver
	// shutdownOnce makes Shutdown's reporting run at most once.
	shutdownOnce sync.Once
}

// newSpanTracker returns a spanTracker using the given observer and debug
// flags, with every in-flight shard map allocated.
func newSpanTracker(observer *spanObserver, debugFlags DebugFlags) *spanTracker {
	tracker := &spanTracker{debugFlags: debugFlags, observer: observer}
	for shard := range tracker.inflightSpans {
		tracker.inflightSpans[shard] = map[oteltrace.SpanID]struct{}{}
	}
	return tracker
}

// spanInfo is a snapshot of a span taken when it started. spanTracker.OnStart
// stores one per span when TrackAllSpans is set, and Shutdown uses them for
// its reports.
type spanInfo struct {
	Name        string
	SpanContext oteltrace.SpanContext
	// Parent is the span context of the span's parent (may be invalid for roots).
	Parent      oteltrace.SpanContext
	// caller is the value of the span's "caller" attribute at start, or "".
	caller      string
	startTime   time.Time
}

// ForceFlush implements trace.SpanProcessor. The tracker buffers nothing, so
// there is nothing to flush and it always returns nil.
func (t *spanTracker) ForceFlush(_ context.Context) error { return nil }

// OnEnd implements trace.SpanProcessor by removing the span from the
// in-flight set. The shard is chosen by reading the 8-byte span ID as a
// big-endian uint64, the same way OnStart chooses it.
func (t *spanTracker) OnEnd(s sdktrace.ReadOnlySpan) {
	spanID := s.SpanContext().SpanID()
	shard := binary.BigEndian.Uint64(spanID[:]) % shardCount
	mu := &t.inflightSpansMu[shard]
	mu.Lock()
	delete(t.inflightSpans[shard], spanID)
	mu.Unlock()
}

// OnStart implements trace.SpanProcessor.
//
// It always adds the span to the in-flight set. With TrackSpanReferences, it
// reports the parent reference (if any) and the span itself to the observer.
// With TrackAllSpans, it stores a spanInfo snapshot, including the "caller"
// attribute when present. The span's shard lock is held for the whole call.
func (t *spanTracker) OnStart(_ context.Context, s sdktrace.ReadWriteSpan) {
	sc := s.SpanContext()
	id := sc.SpanID()
	shard := binary.BigEndian.Uint64(id[:]) % shardCount
	mu := &t.inflightSpansMu[shard]
	mu.Lock()
	defer mu.Unlock()
	t.inflightSpans[shard][id] = struct{}{}

	parent := s.Parent()
	if t.debugFlags.Check(TrackSpanReferences) {
		if parent.IsValid() {
			t.observer.ObserveReference(parent.SpanID(), id)
		}
		t.observer.Observe(id)
	}

	if !t.debugFlags.Check(TrackAllSpans) {
		return
	}
	caller := ""
	for _, kv := range s.Attributes() {
		if kv.Key == "caller" {
			caller = kv.Value.AsString()
			break
		}
	}
	t.allSpans.Store(id, &spanInfo{
		Name:        s.Name(),
		SpanContext: sc,
		Parent:      parent,
		caller:      caller,
		startTime:   s.StartTime(),
	})
}

// Shutdown implements trace.SpanProcessor.
//
// If no debug flags are set, it does nothing. Otherwise, the first call
// prints the diagnostics selected by the flags, to stderr unless a debug
// message writer is configured:
//
//   - WarnOnUnresolvedReferences: parent spans that some started span
//     referenced but that were never observed.
//   - WarnOnIncompleteSpans: spans that were started but never ended.
//   - LogAllSpans, or LogAllSpansOnWarn after a warning: every recorded
//     span, ordered by start time.
//   - LogTraceIDs, or LogTraceIDsOnWarn after a warning: each recorded trace
//     ID with its span count.
//
// It returns ErrIncompleteSpans if either warning was printed. Later calls
// print nothing and return nil.
func (t *spanTracker) Shutdown(_ context.Context) error {
	if t.debugFlags == 0 {
		return nil
	}
	didWarn := false
	t.shutdownOnce.Do(func() {
		warnedRefs := t.debugFlags.Check(WarnOnUnresolvedReferences) && t.reportUnresolvedReferences()
		warnedSpans := t.debugFlags.Check(WarnOnIncompleteSpans) && t.reportIncompleteSpans()
		didWarn = warnedRefs || warnedSpans

		if t.debugFlags.Check(LogAllSpans) || (didWarn && t.debugFlags.Check(LogAllSpansOnWarn)) {
			t.logAllObservedSpans()
		}
		if t.debugFlags.Check(LogTraceIDs) || (didWarn && t.debugFlags.Check(LogTraceIDsOnWarn)) {
			t.logKnownTraceIDs()
		}
	})
	if didWarn {
		return ErrIncompleteSpans
	}
	return nil
}

// reportUnresolvedReferences prints, in sorted order, every span ID that was
// referenced but never observed, and reports whether it printed anything.
// When TrackAllSpans recorded the referencing span, its name is included.
func (t *spanTracker) reportUnresolvedReferences() bool {
	var lines []string
	// referencedIDs is mutated under observer.cond.L (by Observe and
	// ObserveReference, possibly from other goroutines), so it must be read
	// under that lock too; an unlocked map iteration is a data race.
	t.observer.cond.L.Lock()
	for id, via := range t.observer.referencedIDs {
		if !via.IsValid() {
			continue // already observed
		}
		// Without TrackAllSpans allSpans is empty and the plain form is used.
		if viaSpan, ok := t.allSpans.Load(via); ok {
			lines = append(lines, fmt.Sprintf("%s via %s (%s)", id, via, viaSpan.(*spanInfo).Name))
		} else {
			lines = append(lines, fmt.Sprintf("%s via %s", id, via))
		}
	}
	t.observer.cond.L.Unlock()

	if len(lines) == 0 {
		return false
	}
	slices.Sort(lines)
	msg := startMsg("WARNING: parent spans referenced but never seen:\n")
	for _, line := range lines {
		msg.WriteString(line)
		msg.WriteString("\n")
	}
	endMsg(msg)
	return true
}

// reportIncompleteSpans prints the spans that were started but not ended and
// reports whether it printed anything. With TrackAllSpans the spans are shown
// in detail, ordered by start time; otherwise only their sorted IDs are shown.
func (t *spanTracker) reportIncompleteSpans() bool {
	if t.debugFlags.Check(TrackAllSpans) {
		var incomplete []*spanInfo
		t.inflightSpansMu.LockAll()
		t.inflightSpans.Range(func(id oteltrace.SpanID) {
			if info, ok := t.allSpans.Load(id); ok {
				incomplete = append(incomplete, info.(*spanInfo))
			}
		})
		t.inflightSpansMu.UnlockAll()
		if len(incomplete) == 0 {
			return false
		}
		slices.SortFunc(incomplete, func(a, b *spanInfo) int {
			return a.startTime.Compare(b.startTime)
		})
		msg := startMsg("WARNING: spans not ended:\n")
		t.writeSpanInfoTable(msg, incomplete)
		endMsg(msg)
		return true
	}

	var incomplete []oteltrace.SpanID
	t.inflightSpansMu.LockAll()
	t.inflightSpans.Range(func(id oteltrace.SpanID) {
		incomplete = append(incomplete, id)
	})
	t.inflightSpansMu.UnlockAll()
	if len(incomplete) == 0 {
		return false
	}
	slices.SortFunc(incomplete, func(a, b oteltrace.SpanID) int {
		return slices.Compare(a[:], b[:])
	})
	msg := startMsg("WARNING: spans not ended:\n")
	for _, id := range incomplete {
		fmt.Fprintf(msg, "%s\n", id)
	}
	msg.WriteString("Note: set TrackAllSpans flag for more info\n")
	endMsg(msg)
	return true
}

// logAllObservedSpans prints every span recorded in allSpans, ordered by
// start time.
func (t *spanTracker) logAllObservedSpans() {
	var spans []*spanInfo
	t.allSpans.Range(func(_, value any) bool {
		spans = append(spans, value.(*spanInfo))
		return true
	})
	slices.SortFunc(spans, func(a, b *spanInfo) int {
		return a.startTime.Compare(b.startTime)
	})
	msg := startMsg("All observed spans:\n")
	t.writeSpanInfoTable(msg, spans)
	endMsg(msg)
}

// logKnownTraceIDs prints each trace ID found in allSpans, in sorted order,
// with the number of recorded spans belonging to it.
func (t *spanTracker) logKnownTraceIDs() {
	counts := map[oteltrace.TraceID]int{}
	t.allSpans.Range(func(_, value any) bool {
		counts[value.(*spanInfo).SpanContext.TraceID()]++
		return true
	})
	ids := make([]oteltrace.TraceID, 0, len(counts))
	for id := range counts {
		ids = append(ids, id)
	}
	slices.SortFunc(ids, func(a, b oteltrace.TraceID) int {
		return slices.Compare(a[:], b[:])
	})
	msg := startMsg("Known trace ids:\n")
	for _, id := range ids {
		fmt.Fprintf(msg, "%s (%d spans)\n", id.String(), counts[id])
	}
	endMsg(msg)
}

// writeSpanInfoTable appends one line per span to msg. The quoted span names
// are padded to a common width so the detail columns line up.
func (*spanTracker) writeSpanInfoTable(msg *strings.Builder, spans []*spanInfo) {
	width := 0
	for _, span := range spans {
		width = max(width, len(span.Name)+2) // +2 for the surrounding quotes
	}
	for _, span := range spans {
		var startedAt string
		if span.caller != "" {
			startedAt = " | started at: " + span.caller
		}
		fmt.Fprintf(msg, "%-*s (trace: %s | span: %s | parent: %s%s)\n", width, "'"+span.Name+"'",
			span.SpanContext.TraceID(), span.SpanContext.SpanID(), span.Parent.SpanID(), startedAt)
	}
}

// newSpanObserver returns a spanObserver with no recorded references. Its
// condition variable is backed by a fresh mutex.
func newSpanObserver() *spanObserver {
	obs := &spanObserver{
		cond: sync.NewCond(new(sync.Mutex)),
	}
	obs.referencedIDs = make(map[oteltrace.SpanID]oteltrace.SpanID)
	return obs
}

// spanObserver tracks span IDs that were referenced (by OnStart, as a
// span's parent) and whether they have since been observed. Its methods
// access the fields while holding cond.L.
type spanObserver struct {
	// cond.L guards the fields below; cond is broadcast whenever an
	// outstanding reference is resolved by Observe.
	cond          *sync.Cond
	// referencedIDs maps a span ID to the ID of the span that referenced it
	// while still unobserved, or to zeroSpanID once it has been observed.
	referencedIDs map[oteltrace.SpanID]oteltrace.SpanID
	// unobservedIDs counts referenced IDs not yet observed; wait blocks
	// until it reaches zero.
	unobservedIDs int
}

// ObserveReference records that span via referenced span id. If id was
// neither referenced nor observed before, it is counted as outstanding until
// Observe(id) is called; otherwise nothing changes.
func (obs *spanObserver) ObserveReference(id oteltrace.SpanID, via oteltrace.SpanID) {
	obs.cond.L.Lock()
	defer obs.cond.L.Unlock()
	if _, known := obs.referencedIDs[id]; known {
		return
	}
	obs.referencedIDs[id] = via // referenced, but not yet observed
	// Raising the outstanding count can never satisfy a waiter's condition,
	// so the cond does not need to be signalled here.
	obs.unobservedIDs++
}

// Observe marks span id as seen.
//
// If id was referenced and is still outstanding, the outstanding count is
// decremented and all waiters are woken. If id was never referenced, it is
// recorded as observed so that a later ObserveReference does not count it as
// outstanding. If it was already observed, nothing changes.
func (obs *spanObserver) Observe(id oteltrace.SpanID) {
	obs.cond.L.Lock()
	defer obs.cond.L.Unlock()
	via, referenced := obs.referencedIDs[id]
	switch {
	case !referenced:
		// First sighting; nobody is waiting on it.
		obs.referencedIDs[id] = zeroSpanID
	case via.IsValid():
		// Referenced earlier and still outstanding: resolve it.
		obs.referencedIDs[id] = zeroSpanID
		obs.unobservedIDs--
		obs.cond.Broadcast()
	}
}

// wait blocks until no referenced span remains unobserved. If that takes
// longer than warnAfter, debugWarnWaiting is called once to print the
// outstanding references, and waiting continues.
func (obs *spanObserver) wait(warnAfter time.Duration) {
	warnTimer := time.AfterFunc(warnAfter, obs.debugWarnWaiting)
	// Runs after the unlock below; cancels the warning if it has not fired yet.
	defer warnTimer.Stop()

	obs.cond.L.Lock()
	defer obs.cond.L.Unlock()
	for obs.unobservedIDs > 0 {
		obs.cond.Wait()
	}
}

// debugWarnWaiting prints how many referenced spans are still unobserved,
// followed by one "<id> via <referencing span id>" line per outstanding
// reference, sorted by ID.
//
// The state is snapshotted under cond.L and the message is written after
// unlocking. Doing I/O while holding the lock would stall Observe and
// ObserveReference, and therefore every span start that reports references.
func (obs *spanObserver) debugWarnWaiting() {
	obs.cond.L.Lock()
	count := obs.unobservedIDs
	var lines []string
	for id, via := range obs.referencedIDs {
		if via.IsValid() {
			lines = append(lines, fmt.Sprintf("%s via %s\n", id, via))
		}
	}
	obs.cond.L.Unlock()

	slices.Sort(lines)
	msg := startMsg(fmt.Sprintf("Waiting on %d unobserved spans:\n", count))
	for _, line := range lines {
		msg.WriteString(line)
	}
	endMsg(msg)
}

// observeExport marks every span in req as observed, along with any span
// named by a span's "pomerium.external-parent-span" attribute. That attribute
// is parsed as a hex-encoded span ID; values that fail to decode are skipped.
// Requests whose incoming metadata carries localExporterMetadataKey are
// ignored entirely.
// NOTE(review): presumably those requests come from the in-process exporter,
// whose spans are already observed elsewhere; verify.
func (srv *ExporterServer) observeExport(ctx context.Context, req *coltracepb.ExportTraceServiceRequest) {
	if len(metadata.ValueFromIncomingContext(ctx, localExporterMetadataKey)) > 0 {
		return
	}
	for _, resourceSpans := range req.ResourceSpans {
		for _, scopeSpans := range resourceSpans.ScopeSpans {
			for _, span := range scopeSpans.Spans {
				spanID, ok := ToSpanID(span.SpanId)
				if !ok {
					continue
				}
				srv.observer.Observe(spanID)
				for _, kv := range span.Attributes {
					if kv.Key == "pomerium.external-parent-span" {
						raw, err := hex.DecodeString(kv.Value.GetStringValue())
						if err == nil {
							if parentID, ok := ToSpanID(raw); ok {
								srv.observer.Observe(parentID)
							}
						}
						break // only the first matching attribute is considered
					}
				}
			}
		}
	}
}