pomerium/internal/telemetry/trace/queue.go
2024-12-06 21:25:48 +00:00

746 lines
21 KiB
Go

package trace
import (
"context"
"encoding/base64"
"encoding/hex"
"errors"
"fmt"
"net/url"
"strings"
"sync"
"time"
"unique"
"github.com/pomerium/pomerium/internal/hashutil"
"github.com/pomerium/pomerium/internal/log"
"go.opentelemetry.io/otel/exporters/otlp/otlptrace"
sdktrace "go.opentelemetry.io/otel/sdk/trace"
oteltrace "go.opentelemetry.io/otel/trace"
coltracepb "go.opentelemetry.io/proto/otlp/collector/trace/v1"
commonv1 "go.opentelemetry.io/proto/otlp/common/v1"
resourcev1 "go.opentelemetry.io/proto/otlp/resource/v1"
tracev1 "go.opentelemetry.io/proto/otlp/trace/v1"
"google.golang.org/protobuf/proto"
)
type SpanExportQueue struct {
mu sync.Mutex
pendingResourcesByTraceID map[unique.Handle[oteltrace.TraceID]]*PendingResources
knownTraceIDMappings map[unique.Handle[oteltrace.TraceID]]unique.Handle[oteltrace.TraceID]
uploadC chan []*tracev1.ResourceSpans
closing bool
closed chan struct{}
debugFlags DebugFlags
debugAllObservedSpans map[oteltrace.SpanID]*tracev1.Span
tracker *spanTracker
observer SpanObserver
}
func NewSpanExportQueue(ctx context.Context, client otlptrace.Client) *SpanExportQueue {
debug := systemContextFromContext(ctx).DebugFlags
var observer SpanObserver
if debug.Check(TrackSpanReferences) {
observer = &spanObserver{referencedIDs: make(map[oteltrace.SpanID]oteltrace.SpanID)}
} else {
observer = noopSpanObserver{}
}
q := &SpanExportQueue{
pendingResourcesByTraceID: make(map[unique.Handle[oteltrace.TraceID]]*PendingResources),
knownTraceIDMappings: make(map[unique.Handle[oteltrace.TraceID]]unique.Handle[oteltrace.TraceID]),
uploadC: make(chan []*tracev1.ResourceSpans, 8),
closed: make(chan struct{}),
debugFlags: debug,
debugAllObservedSpans: make(map[oteltrace.SpanID]*tracev1.Span),
tracker: &spanTracker{observer: observer, debugFlags: debug},
observer: observer,
}
go func() {
defer close(q.closed)
for resourceSpans := range q.uploadC {
if err := client.UploadTraces(context.Background(), resourceSpans); err != nil {
log.Ctx(ctx).Err(err).Msg("error uploading traces")
}
}
}()
return q
}
func (q *SpanExportQueue) insertPendingSpanLocked(
resource *ResourceInfo,
scope *commonv1.InstrumentationScope,
scopeSchema string,
traceID unique.Handle[oteltrace.TraceID],
span *tracev1.Span,
) {
var pendingTraceResources *PendingResources
if ptr, ok := q.pendingResourcesByTraceID[traceID]; ok {
pendingTraceResources = ptr
} else {
pendingTraceResources = NewPendingResources()
q.pendingResourcesByTraceID[traceID] = pendingTraceResources
}
pendingTraceResources.Insert(resource, scope, scopeSchema, span)
}
func (q *SpanExportQueue) resolveTraceIDMappingLocked(original, mapping unique.Handle[oteltrace.TraceID]) [][]*tracev1.ResourceSpans {
q.knownTraceIDMappings[original] = mapping
toUpload := [][]*tracev1.ResourceSpans{}
if originalPending, ok := q.pendingResourcesByTraceID[original]; ok {
resourceSpans := originalPending.AsResourceSpans(mapping)
delete(q.pendingResourcesByTraceID, original)
toUpload = append(toUpload, resourceSpans)
}
if original != mapping {
q.knownTraceIDMappings[mapping] = mapping
if targetPending, ok := q.pendingResourcesByTraceID[mapping]; ok {
resourceSpans := targetPending.AsResourceSpans(mapping)
delete(q.pendingResourcesByTraceID, mapping)
toUpload = append(toUpload, resourceSpans)
}
}
return toUpload
}
var ErrShuttingDown = errors.New("exporter is shutting down")
func (q *SpanExportQueue) Enqueue(ctx context.Context, req *coltracepb.ExportTraceServiceRequest) error {
q.mu.Lock()
defer q.mu.Unlock()
if q.closing {
return ErrShuttingDown
}
var toUpload [][]*tracev1.ResourceSpans
for _, resource := range req.ResourceSpans {
for _, scope := range resource.ScopeSpans {
for _, span := range scope.Spans {
formatSpanName(span)
spanID, ok := toSpanID(span.SpanId)
if !ok {
continue
}
if q.debugFlags.Check(TrackAllSpans) {
q.debugAllObservedSpans[spanID] = span
}
parentSpanID, ok := toSpanID(span.ParentSpanId)
if !ok {
continue
}
if parentSpanID.IsValid() { // if parent is not a root span
q.observer.ObserveReference(parentSpanID, spanID)
continue
}
traceID, ok := toTraceID(span.TraceId)
if !ok {
continue
}
if _, ok := q.knownTraceIDMappings[traceID]; !ok {
// observed a new root span with an unknown trace id
var pomeriumTraceparent string
for _, attr := range span.Attributes {
if attr.Key == "pomerium.traceparent" {
pomeriumTraceparent = attr.GetValue().GetStringValue()
break
}
}
var mappedTraceID unique.Handle[oteltrace.TraceID]
if pomeriumTraceparent == "" {
// no replacement id, map the trace to itself and release pending spans
mappedTraceID = traceID
} else {
// this root span has an alternate traceparent. permanently rewrite
// all spans of the old trace id to use the new trace id
tp, err := ParseTraceparent(pomeriumTraceparent)
if err != nil {
log.Ctx(ctx).Err(err).Msg("error processing trace")
continue
}
mappedTraceID = unique.Make(tp.TraceID())
}
toUpload = append(toUpload, q.resolveTraceIDMappingLocked(traceID, mappedTraceID)...)
}
}
}
}
var knownResources []*tracev1.ResourceSpans
for _, resource := range req.ResourceSpans {
resourceInfo := newResourceInfo(resource.Resource, resource.SchemaUrl)
knownResource := &tracev1.ResourceSpans{
Resource: resource.Resource,
SchemaUrl: resource.SchemaUrl,
}
for _, scope := range resource.ScopeSpans {
var knownSpans []*tracev1.Span
for _, span := range scope.Spans {
spanID, ok := toSpanID(span.SpanId)
if !ok {
continue
}
traceID, ok := toTraceID(span.TraceId)
if !ok {
continue
}
parentSpanId, ok := toSpanID(span.ParentSpanId)
if !ok {
continue
}
q.observer.Observe(spanID)
if mapping, ok := q.knownTraceIDMappings[traceID]; ok {
id := mapping.Value()
copy(span.TraceId, id[:])
knownSpans = append(knownSpans, span)
} else {
var isInternalRoot bool
if q.debugFlags.Check(TrackSpanReferences) {
if parentSpanId.IsValid() {
for _, attr := range span.Attributes {
if attr.Key == "pomerium.external-parent-span" {
isInternalRoot = true
if bytes, err := hex.DecodeString(attr.Value.GetStringValue()); err == nil {
if spanId, _ := toSpanID(bytes); spanId.IsValid() {
q.observer.Observe(spanId)
}
}
break
}
}
}
}
if isInternalRoot {
toUpload = append(toUpload, q.resolveTraceIDMappingLocked(traceID, traceID)...)
} else {
q.insertPendingSpanLocked(resourceInfo, scope.Scope, scope.SchemaUrl, traceID, span)
}
}
}
if len(knownSpans) > 0 {
knownResource.ScopeSpans = append(knownResource.ScopeSpans, &tracev1.ScopeSpans{
Scope: scope.Scope,
SchemaUrl: scope.SchemaUrl,
Spans: knownSpans,
})
}
}
if len(knownResource.ScopeSpans) > 0 {
knownResources = append(knownResources, knownResource)
}
}
if len(knownResources) > 0 {
toUpload = append(toUpload, knownResources)
}
for _, res := range toUpload {
q.uploadC <- res
}
return nil
}
var (
ErrIncompleteTraces = errors.New("exporter shut down with incomplete traces")
ErrIncompleteUploads = errors.New("exporter shut down with pending trace uploads")
ErrMissingParentSpans = errors.New("exporter shut down with missing parent spans")
)
func (q *SpanExportQueue) WaitForSpans(maxDuration time.Duration) error {
done := make(chan struct{})
go func() {
defer close(done)
q.observer.Wait()
}()
select {
case <-done:
return nil
case <-time.After(maxDuration):
return ErrMissingParentSpans
}
}
func (q *SpanExportQueue) Close(ctx context.Context) error {
q.mu.Lock()
q.closing = true
close(q.uploadC)
q.mu.Unlock()
select {
case <-ctx.Done():
return context.Cause(ctx)
case <-q.closed:
q.mu.Lock()
defer q.mu.Unlock()
didWarn := false
if q.debugFlags.Check(TrackSpanReferences) {
var unknownParentIDs []string
for id, via := range q.observer.(*spanObserver).referencedIDs {
if via.IsValid() {
if q.debugFlags.Check(TrackAllSpans) {
if viaSpan, ok := q.debugAllObservedSpans[via]; ok {
unknownParentIDs = append(unknownParentIDs, fmt.Sprintf("%s via %s (%s)", id, via, viaSpan.Name))
} else {
unknownParentIDs = append(unknownParentIDs, fmt.Sprintf("%s via %s", id, via))
}
}
}
}
if len(unknownParentIDs) > 0 {
didWarn = true
msg := startMsg("WARNING: parent spans referenced but never seen:\n")
for _, str := range unknownParentIDs {
msg.WriteString(str)
msg.WriteString("\n")
}
endMsg(msg)
}
}
incomplete := len(q.pendingResourcesByTraceID) > 0
if incomplete && q.debugFlags.Check(WarnOnIncompleteTraces) {
didWarn = true
msg := startMsg("WARNING: exporter shut down with incomplete traces\n")
for k, v := range q.pendingResourcesByTraceID {
fmt.Fprintf(msg, "- Trace: %s\n", k.Value())
for _, pendingScope := range v.scopesByResourceID {
msg.WriteString(" - Resource:\n")
for _, v := range pendingScope.resource.Resource.Attributes {
fmt.Fprintf(msg, " %s=%s\n", v.Key, v.Value.String())
}
for _, scope := range pendingScope.spansByScope {
if scope.scope != nil {
fmt.Fprintf(msg, " Scope: %s\n", scope.scope.Name)
} else {
msg.WriteString(" Scope: (unknown)\n")
}
msg.WriteString(" Spans:\n")
longestName := 0
for _, span := range scope.spans {
longestName = max(longestName, len(span.Name)+2)
}
for _, span := range scope.spans {
spanID, ok := toSpanID(span.SpanId)
if !ok {
continue
}
traceID, ok := toTraceID(span.TraceId)
if !ok {
continue
}
parentSpanID, ok := toSpanID(span.ParentSpanId)
if !ok {
continue
}
_, seenParent := q.debugAllObservedSpans[parentSpanID]
var missing string
if !seenParent {
missing = " [missing]"
}
fmt.Fprintf(msg, " - %-*s (trace: %s | span: %s | parent:%s %s)\n", longestName,
"'"+span.Name+"'", traceID.Value(), spanID, missing, parentSpanID)
for _, attr := range span.Attributes {
if attr.Key == "caller" {
fmt.Fprintf(msg, " => caller: '%s'\n", attr.Value.GetStringValue())
break
}
}
}
}
}
}
endMsg(msg)
}
if q.debugFlags.Check(LogTraceIDMappings) || (didWarn && q.debugFlags.Check(LogTraceIDMappingsOnWarn)) {
msg := startMsg("Known trace ids:\n")
for k, v := range q.knownTraceIDMappings {
if k != v {
fmt.Fprintf(msg, "%s => %s\n", k.Value(), v.Value())
} else {
fmt.Fprintf(msg, "%s (no change)\n", k.Value())
}
}
endMsg(msg)
}
if q.debugFlags.Check(LogAllSpans) || (didWarn && q.debugFlags.Check(LogAllSpansOnWarn)) {
msg := startMsg("All exported spans:\n")
longestName := 0
for _, span := range q.debugAllObservedSpans {
longestName = max(longestName, len(span.Name)+2)
}
for _, span := range q.debugAllObservedSpans {
spanID, ok := toSpanID(span.SpanId)
if !ok {
continue
}
traceID, ok := toTraceID(span.TraceId)
if !ok {
continue
}
parentSpanID, ok := toSpanID(span.ParentSpanId)
if !ok {
continue
}
fmt.Fprintf(msg, "%-*s (trace: %s | span: %s | parent: %s)", longestName,
"'"+span.Name+"'", traceID.Value(), spanID, parentSpanID)
var foundCaller bool
for _, attr := range span.Attributes {
if attr.Key == "caller" {
fmt.Fprintf(msg, " => %s\n", attr.Value.GetStringValue())
foundCaller = true
break
}
}
if !foundCaller {
msg.WriteString("\n")
}
}
endMsg(msg)
}
if incomplete {
return ErrIncompleteTraces
}
log.Ctx(ctx).Debug().Msg("exporter shut down")
return nil
}
}
type spanTracker struct {
inflightSpans sync.Map
allSpans sync.Map
debugFlags DebugFlags
observer SpanObserver
shutdownOnce sync.Once
}
type spanInfo struct {
Name string
SpanContext oteltrace.SpanContext
Parent oteltrace.SpanContext
}
// ForceFlush implements trace.SpanProcessor.
func (t *spanTracker) ForceFlush(context.Context) error {
return nil
}
// OnEnd implements trace.SpanProcessor.
func (t *spanTracker) OnEnd(s sdktrace.ReadOnlySpan) {
id := s.SpanContext().SpanID()
t.inflightSpans.Delete(id)
}
// OnStart implements trace.SpanProcessor.
func (t *spanTracker) OnStart(_ context.Context, s sdktrace.ReadWriteSpan) {
id := s.SpanContext().SpanID()
t.inflightSpans.Store(id, struct{}{})
t.observer.Observe(id)
if t.debugFlags.Check(TrackAllSpans) {
t.allSpans.Store(id, &spanInfo{
Name: s.Name(),
SpanContext: s.SpanContext(),
Parent: s.Parent(),
})
}
}
// Shutdown implements trace.SpanProcessor.
func (t *spanTracker) Shutdown(_ context.Context) error {
if t.debugFlags == 0 {
return nil
}
t.shutdownOnce.Do(func() {
didWarn := false
if t.debugFlags.Check(WarnOnIncompleteSpans) {
if t.debugFlags.Check(TrackAllSpans) {
incompleteSpans := []*spanInfo{}
t.inflightSpans.Range(func(key, _ any) bool {
if info, ok := t.allSpans.Load(key); ok {
incompleteSpans = append(incompleteSpans, info.(*spanInfo))
}
return true
})
if len(incompleteSpans) > 0 {
didWarn = true
msg := startMsg("WARNING: spans not ended:\n")
longestName := 0
for _, span := range incompleteSpans {
longestName = max(longestName, len(span.Name)+2)
}
for _, span := range incompleteSpans {
fmt.Fprintf(msg, "%-*s (trace: %s | span: %s | parent: %s)\n", longestName, "'"+span.Name+"'",
span.SpanContext.TraceID(), span.SpanContext.SpanID(), span.Parent.SpanID())
}
endMsg(msg)
}
} else {
incompleteSpans := []string{}
t.inflightSpans.Range(func(key, _ any) bool {
incompleteSpans = append(incompleteSpans, key.(string))
return true
})
if len(incompleteSpans) > 0 {
didWarn = true
msg := startMsg("WARNING: spans not ended:\n")
for _, span := range incompleteSpans {
fmt.Fprintf(msg, "%s\n", span)
}
msg.WriteString("Note: set TrackAllSpans flag for more info\n")
endMsg(msg)
}
}
}
if t.debugFlags.Check(LogAllSpans) || (t.debugFlags.Check(LogAllSpansOnWarn) && didWarn) {
allSpans := []*spanInfo{}
t.allSpans.Range(func(_, value any) bool {
allSpans = append(allSpans, value.(*spanInfo))
return true
})
msg := startMsg("All observed spans:\n")
longestName := 0
for _, span := range allSpans {
longestName = max(longestName, len(span.Name)+2)
}
for _, span := range allSpans {
fmt.Fprintf(msg, "%-*s (trace: %s | span: %s | parent: %s)\n", longestName, "'"+span.Name+"'",
span.SpanContext.TraceID(), span.SpanContext.SpanID(), span.Parent.SpanID())
}
endMsg(msg)
}
})
return nil
}
type PendingSpans struct {
scope *commonv1.InstrumentationScope
scopeSchema string
spans []*tracev1.Span
}
func (ps *PendingSpans) Insert(span *tracev1.Span) {
ps.spans = append(ps.spans, span)
}
func NewPendingSpans(scope *commonv1.InstrumentationScope, scopeSchema string) *PendingSpans {
return &PendingSpans{
scope: scope,
scopeSchema: scopeSchema,
}
}
type PendingScopes struct {
resource *ResourceInfo
spansByScope map[string]*PendingSpans
}
func (ps *PendingScopes) Insert(scope *commonv1.InstrumentationScope, scopeSchema string, span *tracev1.Span) {
var spans *PendingSpans
if sp, ok := ps.spansByScope[scope.GetName()]; ok {
spans = sp
} else {
spans = NewPendingSpans(scope, scopeSchema)
ps.spansByScope[scope.GetName()] = spans
}
spans.Insert(span)
}
func (ps *PendingScopes) AsScopeSpansList(rewriteTraceID unique.Handle[oteltrace.TraceID]) []*tracev1.ScopeSpans {
out := make([]*tracev1.ScopeSpans, 0, len(ps.spansByScope))
for _, spans := range ps.spansByScope {
for _, span := range spans.spans {
id := rewriteTraceID.Value()
copy(span.TraceId, id[:])
}
scopeSpans := &tracev1.ScopeSpans{
Scope: spans.scope,
SchemaUrl: spans.scopeSchema,
Spans: spans.spans,
}
out = append(out, scopeSpans)
}
return out
}
func NewPendingScopes(resource *ResourceInfo) *PendingScopes {
return &PendingScopes{
resource: resource,
spansByScope: make(map[string]*PendingSpans),
}
}
type PendingResources struct {
scopesByResourceID map[string]*PendingScopes
}
func (pr *PendingResources) Insert(resource *ResourceInfo, scope *commonv1.InstrumentationScope, scopeSchema string, span *tracev1.Span) {
resourceEq := resource.ID()
var scopes *PendingScopes
if sc, ok := pr.scopesByResourceID[resourceEq]; ok {
scopes = sc
} else {
scopes = NewPendingScopes(resource)
pr.scopesByResourceID[resourceEq] = scopes
}
scopes.Insert(scope, scopeSchema, span)
}
func (pr *PendingResources) AsResourceSpans(rewriteTraceID unique.Handle[oteltrace.TraceID]) []*tracev1.ResourceSpans {
out := make([]*tracev1.ResourceSpans, 0, len(pr.scopesByResourceID))
for _, scopes := range pr.scopesByResourceID {
resourceSpans := &tracev1.ResourceSpans{
Resource: scopes.resource.Resource,
ScopeSpans: scopes.AsScopeSpansList(rewriteTraceID),
SchemaUrl: scopes.resource.Schema,
}
out = append(out, resourceSpans)
}
return out
}
func NewPendingResources() *PendingResources {
return &PendingResources{scopesByResourceID: make(map[string]*PendingScopes)}
}
type ResourceInfo struct {
Resource *resourcev1.Resource
Schema string
ID func() string
}
func newResourceInfo(resource *resourcev1.Resource, resourceSchema string) *ResourceInfo {
r := &ResourceInfo{
Resource: resource,
Schema: resourceSchema,
}
r.ID = sync.OnceValue(r.computeID)
return r
}
func (r *ResourceInfo) computeID() string {
hash := hashutil.NewDigest()
tmp := resourcev1.Resource{
Attributes: r.Resource.Attributes,
}
bytes, _ := proto.Marshal(&tmp)
hash.WriteStringWithLen(r.Schema)
hash.WriteWithLen(bytes)
return base64.StdEncoding.EncodeToString(hash.Sum(nil))
}
type SpanObserver interface {
ObserveReference(id oteltrace.SpanID, via oteltrace.SpanID)
Observe(id oteltrace.SpanID)
Wait()
}
type spanObserver struct {
mu sync.Mutex
referencedIDs map[oteltrace.SpanID]oteltrace.SpanID
unobservedIDs sync.WaitGroup
}
func (obs *spanObserver) ObserveReference(id oteltrace.SpanID, via oteltrace.SpanID) {
obs.mu.Lock()
defer obs.mu.Unlock()
if _, referenced := obs.referencedIDs[id]; !referenced {
obs.referencedIDs[id] = via // referenced, but not observed
obs.unobservedIDs.Add(1)
}
}
func (obs *spanObserver) Observe(id oteltrace.SpanID) {
obs.mu.Lock()
defer obs.mu.Unlock()
if observed, referenced := obs.referencedIDs[id]; !referenced || observed.IsValid() { // NB: subtle condition
obs.referencedIDs[id] = zeroSpanID
if referenced {
obs.unobservedIDs.Done()
}
}
}
func (obs *spanObserver) Wait() {
done := make(chan struct{})
defer close(done)
go func() {
select {
case <-done:
return
case <-time.After(10 * time.Second):
obs.mu.Lock()
msg := startMsg("Waiting on unobserved spans:\n")
for id, via := range obs.referencedIDs {
if via.IsValid() {
fmt.Fprintf(msg, "%s via %s\n", id, via)
}
}
endMsg(msg)
obs.mu.Unlock()
}
}()
obs.unobservedIDs.Wait()
}
type noopSpanObserver struct{}
func (noopSpanObserver) ObserveReference(oteltrace.SpanID, oteltrace.SpanID) {}
func (noopSpanObserver) Observe(oteltrace.SpanID) {}
func (noopSpanObserver) Wait() {}
func formatSpanName(span *tracev1.Span) {
hasPath := strings.Contains(span.GetName(), "${path}")
hasHost := strings.Contains(span.GetName(), "${host}")
hasMethod := strings.Contains(span.GetName(), "${method}")
if hasPath || hasHost || hasMethod {
var u *url.URL
var method string
for _, attr := range span.Attributes {
if attr.Key == "http.url" {
u, _ = url.Parse(attr.Value.GetStringValue())
}
if attr.Key == "http.method" {
method = attr.Value.GetStringValue()
}
}
if u != nil {
if hasPath {
span.Name = strings.ReplaceAll(span.Name, "${path}", u.Path)
}
if hasHost {
span.Name = strings.ReplaceAll(span.Name, "${host}", u.Host)
}
if hasMethod {
span.Name = strings.ReplaceAll(span.Name, "${method}", method)
}
}
}
}
var (
zeroSpanID oteltrace.SpanID
zeroTraceID = unique.Make(oteltrace.TraceID([16]byte{}))
)
func toSpanID(bytes []byte) (oteltrace.SpanID, bool) {
switch len(bytes) {
case 0:
return zeroSpanID, true
case 8:
return oteltrace.SpanID(bytes), true
}
return zeroSpanID, false
}
func toTraceID(bytes []byte) (unique.Handle[oteltrace.TraceID], bool) {
switch len(bytes) {
case 0:
return zeroTraceID, true
case 16:
return unique.Make(oteltrace.TraceID(bytes)), true
}
return zeroTraceID, false
}