diff --git a/internal/benchmarks/latency_bench_test.go b/internal/benchmarks/latency_bench_test.go
index 17aa19a96..d25185942 100644
--- a/internal/benchmarks/latency_bench_test.go
+++ b/internal/benchmarks/latency_bench_test.go
@@ -9,7 +9,6 @@ import (
     "testing"
     "time"

-    "github.com/pomerium/pomerium/config"
     "github.com/pomerium/pomerium/internal/testenv"
     "github.com/pomerium/pomerium/internal/testenv/envutil"
     "github.com/pomerium/pomerium/internal/testenv/scenarios"
@@ -49,8 +48,8 @@ func TestRequestLatency(t *testing.T) {
     for i := range numRoutes {
         routes[i] = up.Route().
             From(env.SubdomainURL(fmt.Sprintf("from-%d", i))).
-            Policy(func(p *config.Policy) { p.AllowPublicUnauthenticatedAccess = true })
-        // PPL(fmt.Sprintf(`{"allow":{"and":["email":{"is":"user%d@example.com"}]}}`, i))
+            // Policy(func(p *config.Policy) { p.AllowPublicUnauthenticatedAccess = true })
+            PPL(fmt.Sprintf(`{"allow":{"and":["email":{"is":"user%d@example.com"}]}}`, i))
     }

     env.AddUpstream(up)
diff --git a/internal/telemetry/trace/queue.go b/internal/telemetry/trace/queue.go
index 9b3b47c93..e7ed06d31 100644
--- a/internal/telemetry/trace/queue.go
+++ b/internal/telemetry/trace/queue.go
@@ -2,6 +2,7 @@ package trace

 import (
     "context"
+    "encoding/binary"
     "encoding/json"
     "errors"
     "fmt"
@@ -51,20 +52,26 @@ func SetMaxCachedTraceIDs(num int32) {
     maxCachedTraceIDs.Store(max(num, 0))
 }

+type eviction struct {
+    traceID unique.Handle[oteltrace.TraceID]
+    buf     *Buffer
+}
+
 type SpanExportQueue struct {
-    mu     sync.Mutex
-    logger *zerolog.Logger
+    closing  chan struct{}
+    uploadC  chan []*tracev1.ResourceSpans
+    requestC chan *coltracepb.ExportTraceServiceRequest
+    evictC   chan eviction
     client                    otlptrace.Client
     pendingResourcesByTraceID *lru.Cache[unique.Handle[oteltrace.TraceID], *Buffer]
     knownTraceIDMappings      *lru.Cache[unique.Handle[oteltrace.TraceID], unique.Handle[oteltrace.TraceID]]
-    uploadC                   chan []*tracev1.ResourceSpans
-    closing                   bool
-    closed                    chan struct{}
-    debugFlags                DebugFlags
-    debugAllEnqueuedSpans     map[oteltrace.SpanID]*tracev1.Span
     tracker                   *spanTracker
     observer                  *spanObserver
     debugEvents               []DebugEvent
+    logger                *zerolog.Logger
+    debugFlags            DebugFlags
+    debugAllEnqueuedSpans map[oteltrace.SpanID]*tracev1.Span
+    wg                    sync.WaitGroup
 }

 func NewSpanExportQueue(ctx context.Context, client otlptrace.Client) *SpanExportQueue {
@@ -76,8 +83,10 @@ func NewSpanExportQueue(ctx context.Context, client otlptrace.Client) *SpanExpor
     q := &SpanExportQueue{
         logger:                log.Ctx(ctx),
         client:                client,
+        closing:               make(chan struct{}),
         uploadC:               make(chan []*tracev1.ResourceSpans, 64),
-        closed:                make(chan struct{}),
+        requestC:              make(chan *coltracepb.ExportTraceServiceRequest, 256),
+        evictC:                make(chan eviction, 64),
         debugFlags:            debug,
         debugAllEnqueuedSpans: make(map[oteltrace.SpanID]*tracev1.Span),
         tracker:               newSpanTracker(observer, debug),
@@ -92,12 +101,14 @@ func NewSpanExportQueue(ctx context.Context, client otlptrace.Client) *SpanExpor
     if err != nil {
         panic(err)
     }
+    q.wg.Add(2)
     go q.runUploader()
+    go q.runProcessor()
     return q
 }

 func (q *SpanExportQueue) runUploader() {
-    defer close(q.closed)
+    defer q.wg.Done()
     for resourceSpans := range q.uploadC {
         ctx, ca := context.WithTimeout(context.Background(), 10*time.Second)
         if err := q.client.UploadTraces(ctx, resourceSpans); err != nil {
@@ -107,26 +118,34 @@
     }
 }

-func (q *SpanExportQueue) onEvict(traceID unique.Handle[oteltrace.TraceID], buf *Buffer) {
-    if buf.IsEmpty() {
-        // if the buffer is not empty, it was evicted automatically
-        return
-    } else if mapping, ok := q.knownTraceIDMappings.Get(traceID); ok && mapping == zeroTraceID {
-        q.logger.Debug().
-            Str("traceID", traceID.Value().String()).
-            Msg("dropping unsampled trace")
-        return
+func (q *SpanExportQueue) runProcessor() {
+    defer q.wg.Done()
+    for {
+        select {
+        case req := <-q.requestC:
+            q.processRequestLocked(req)
+        case ev := <-q.evictC:
+            q.processEvictionLocked(ev)
+        case <-q.closing:
+            for {
+                select {
+                case req := <-q.requestC:
+                    q.processRequestLocked(req)
+                case ev := <-q.evictC:
+                    q.processEvictionLocked(ev)
+                default: // all channels empty
+                    close(q.uploadC)
+                    return
+                }
+            }
+        }
     }
+}

-    select {
-    case q.uploadC <- buf.Flush():
-        q.logger.Warn().
-            Str("traceID", traceID.Value().String()).
-            Msg("trace export buffer is full, uploading oldest incomplete trace")
-    default:
-        q.logger.Warn().
-            Str("traceID", traceID.Value().String()).
-            Msg("trace export buffer and upload queues are full, dropping trace")
+func (q *SpanExportQueue) onEvict(traceID unique.Handle[oteltrace.TraceID], buf *Buffer) {
+    q.evictC <- eviction{
+        traceID: traceID,
+        buf:     buf,
     }
 }

@@ -183,13 +202,17 @@ func (q *SpanExportQueue) isKnownTracePendingLocked(id unique.Handle[oteltrace.T

 var ErrShuttingDown = errors.New("exporter is shutting down")

-func (q *SpanExportQueue) Enqueue(ctx context.Context, req *coltracepb.ExportTraceServiceRequest) error {
-    q.mu.Lock()
-    defer q.mu.Unlock()
-    if q.closing {
+func (q *SpanExportQueue) Enqueue(_ context.Context, req *coltracepb.ExportTraceServiceRequest) error {
+    select {
+    case <-q.closing:
         return ErrShuttingDown
+    default:
+        q.requestC <- req
+        return nil
     }
+}

+func (q *SpanExportQueue) processRequestLocked(req *coltracepb.ExportTraceServiceRequest) {
     if q.debugFlags.Check(LogAllEvents) {
         q.debugEvents = append(q.debugEvents, DebugEvent{
             Timestamp: time.Now(),
@@ -265,7 +288,7 @@
             tp, err := ParseTraceparent(attr.GetValue().GetStringValue())
             if err != nil {
                 data, _ := protojson.Marshal(span)
-                log.Ctx(ctx).
+                q.logger.
                     Err(err).
                     Str("span", string(data)).
                     Msg("error processing span")
@@ -284,7 +307,7 @@
             value, err := oteltrace.SpanIDFromHex(attr.GetValue().GetStringValue())
             if err != nil {
                 data, _ := protojson.Marshal(span)
-                log.Ctx(ctx).
+                q.logger.
                     Err(err).
                     Str("span", string(data)).
                     Msg("error processing span: invalid value for pomerium.external-parent-span")
@@ -354,7 +377,29 @@ func (q *SpanExportQueue) Enqueue(ctx context.Context, req *coltracepb.ExportTra
     if resourceSpans := toUpload.Flush(); len(resourceSpans) > 0 {
         q.uploadC <- resourceSpans
     }
-    return nil
+}
+
+func (q *SpanExportQueue) processEvictionLocked(ev eviction) {
+    if ev.buf.IsEmpty() {
+        // an empty buffer has nothing left to upload; a non-empty buffer was evicted automatically by the LRU
+        return
+    } else if mapping, ok := q.knownTraceIDMappings.Get(ev.traceID); ok && mapping == zeroTraceID {
+        q.logger.Debug().
+            Str("traceID", ev.traceID.Value().String()).
+            Msg("dropping unsampled trace")
+        return
+    }
+
+    select {
+    case q.uploadC <- ev.buf.Flush():
+        q.logger.Warn().
+            Str("traceID", ev.traceID.Value().String()).
+            Msg("trace export buffer is full, uploading oldest incomplete trace")
+    default:
+        q.logger.Warn().
+            Str("traceID", ev.traceID.Value().String()).
+ Msg("trace export buffer and upload queues are full, dropping trace") + } } var ( @@ -382,20 +427,17 @@ func (q *SpanExportQueue) WaitForSpans(maxDuration time.Duration) error { } func (q *SpanExportQueue) Close(ctx context.Context) error { - q.mu.Lock() - q.closing = true - close(q.uploadC) - q.mu.Unlock() + closed := make(chan struct{}) + go func() { + q.wg.Wait() + close(closed) + }() + close(q.closing) select { case <-ctx.Done(): log.Ctx(ctx).Error().Msg("exporter stopped before all traces could be exported") - // drain uploadC - for range q.uploadC { - } return context.Cause(ctx) - case <-q.closed: - q.mu.Lock() - defer q.mu.Unlock() + case <-closed: err := q.runOnCloseChecksLocked() log.Ctx(ctx).Debug().Err(err).Msg("exporter stopped") return err @@ -595,19 +637,51 @@ func (e *DebugEvent) UnmarshalJSON(b []byte) error { return nil } +const shardCount = 64 + +type ( + shardedSet [shardCount]map[oteltrace.SpanID]struct{} + shardedLocks [shardCount]sync.Mutex +) + +func (s *shardedSet) Range(f func(key oteltrace.SpanID)) { + for i := range shardCount { + for k := range s[i] { + f(k) + } + } +} + +func (s *shardedLocks) LockAll() { + for i := range shardCount { + s[i].Lock() + } +} + +func (s *shardedLocks) UnlockAll() { + for i := range shardCount { + s[i].Unlock() + } +} + type spanTracker struct { - inflightSpans sync.Map - allSpans sync.Map - debugFlags DebugFlags - observer *spanObserver - shutdownOnce sync.Once + inflightSpansMu shardedLocks + inflightSpans shardedSet + allSpans sync.Map + debugFlags DebugFlags + observer *spanObserver + shutdownOnce sync.Once } func newSpanTracker(observer *spanObserver, debugFlags DebugFlags) *spanTracker { - return &spanTracker{ + st := &spanTracker{ observer: observer, debugFlags: debugFlags, } + for i := range len(st.inflightSpans) { + st.inflightSpans[i] = make(map[oteltrace.SpanID]struct{}) + } + return st } type spanInfo struct { @@ -626,13 +700,20 @@ func (t *spanTracker) ForceFlush(context.Context) error { // OnEnd implements trace.SpanProcessor. func (t *spanTracker) OnEnd(s sdktrace.ReadOnlySpan) { id := s.SpanContext().SpanID() - t.inflightSpans.Delete(id) + bucket := binary.BigEndian.Uint64(id[:]) % shardCount + t.inflightSpansMu[bucket].Lock() + defer t.inflightSpansMu[bucket].Unlock() + delete(t.inflightSpans[bucket], id) } // OnStart implements trace.SpanProcessor. 
 func (t *spanTracker) OnStart(_ context.Context, s sdktrace.ReadWriteSpan) {
     id := s.SpanContext().SpanID()
-    t.inflightSpans.Store(id, struct{}{})
+    bucket := binary.BigEndian.Uint64(id[:]) % shardCount
+    t.inflightSpansMu[bucket].Lock()
+    defer t.inflightSpansMu[bucket].Unlock()
+    t.inflightSpans[bucket][id] = struct{}{}
+
     if t.debugFlags.Check(TrackSpanReferences) {
         t.observer.Observe(id)
     }
@@ -664,12 +745,13 @@ func (t *spanTracker) Shutdown(_ context.Context) error {
     if t.debugFlags.Check(WarnOnIncompleteSpans) {
         if t.debugFlags.Check(TrackAllSpans) {
             incompleteSpans := []*spanInfo{}
-            t.inflightSpans.Range(func(key, _ any) bool {
+            t.inflightSpansMu.LockAll()
+            t.inflightSpans.Range(func(key oteltrace.SpanID) {
                 if info, ok := t.allSpans.Load(key); ok {
                     incompleteSpans = append(incompleteSpans, info.(*spanInfo))
                 }
-                return true
             })
+            t.inflightSpansMu.UnlockAll()
             if len(incompleteSpans) > 0 {
                 didWarn = true
                 msg := startMsg("WARNING: spans not ended:\n")
@@ -689,10 +771,11 @@
             }
         } else {
             incompleteSpans := []oteltrace.SpanID{}
-            t.inflightSpans.Range(func(key, _ any) bool {
-                incompleteSpans = append(incompleteSpans, key.(oteltrace.SpanID))
-                return true
+            t.inflightSpansMu.LockAll()
+            t.inflightSpans.Range(func(key oteltrace.SpanID) {
+                incompleteSpans = append(incompleteSpans, key)
             })
+            t.inflightSpansMu.UnlockAll()
             if len(incompleteSpans) > 0 {
                 didWarn = true
                 msg := startMsg("WARNING: spans not ended:\n")
diff --git a/internal/telemetry/trace/trace_export_test.go b/internal/telemetry/trace/trace_export_test.go
index 9917d6c61..35360bce0 100644
--- a/internal/telemetry/trace/trace_export_test.go
+++ b/internal/telemetry/trace/trace_export_test.go
@@ -53,10 +53,11 @@ func (obs *spanObserver) XObservedIDs() []oteltrace.SpanID {

 func (t *spanTracker) XInflightSpans() []oteltrace.SpanID {
     ids := []oteltrace.SpanID{}
-    t.inflightSpans.Range(func(key, _ any) bool {
-        ids = append(ids, key.(oteltrace.SpanID))
-        return true
+    t.inflightSpansMu.LockAll()
+    t.inflightSpans.Range(func(key oteltrace.SpanID) {
+        ids = append(ids, key)
     })
+    t.inflightSpansMu.UnlockAll()
     slices.SortFunc(ids, func(a, b oteltrace.SpanID) int {
         return cmp.Compare(a.String(), b.String())
     })
diff --git a/internal/testenv/environment.go b/internal/testenv/environment.go
index 0c78ea94a..00744b9e8 100644
--- a/internal/testenv/environment.go
+++ b/internal/testenv/environment.go
@@ -303,7 +303,7 @@ func WithTraceClient(traceClient otlptrace.Client) EnvironmentOption {

 var setGrpcLoggerOnce sync.Once

-const defaultTraceDebugFlags = trace.TrackSpanCallers
+const defaultTraceDebugFlags = trace.TrackSpanCallers | trace.TrackSpanReferences

 var (
     flagDebug = flag.Bool("env.debug", false, "enables test environment debug logging (equivalent to Debug() option)")
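The spanTracker change above swaps a single sync.Map for sixty-four map shards, each guarded by its own mutex: the shard is chosen from the big-endian integer value of the 8-byte span ID, and the shutdown paths lock every shard before iterating. Below is a minimal, standalone sketch of that pattern; the names shardedSpanSet, spanID, and Snapshot are illustrative only and do not appear in the patch.

// Standalone sketch of the sharded-set pattern; names are illustrative,
// with spanID standing in for oteltrace.SpanID (an 8-byte array).
package main

import (
    "encoding/binary"
    "fmt"
    "sync"
)

const shardCount = 64

type spanID [8]byte

// shardedSpanSet spreads membership across shardCount buckets so concurrent
// Add/Remove calls for different spans usually lock different mutexes.
type shardedSpanSet struct {
    mu      [shardCount]sync.Mutex
    buckets [shardCount]map[spanID]struct{}
}

func newShardedSpanSet() *shardedSpanSet {
    s := &shardedSpanSet{}
    for i := range s.buckets {
        s.buckets[i] = make(map[spanID]struct{})
    }
    return s
}

// bucketFor mirrors the bucketing in OnStart/OnEnd: interpret the 8-byte ID
// as a big-endian integer and reduce it modulo shardCount.
func bucketFor(id spanID) uint64 {
    return binary.BigEndian.Uint64(id[:]) % shardCount
}

func (s *shardedSpanSet) Add(id spanID) {
    b := bucketFor(id)
    s.mu[b].Lock()
    defer s.mu[b].Unlock()
    s.buckets[b][id] = struct{}{}
}

func (s *shardedSpanSet) Remove(id spanID) {
    b := bucketFor(id)
    s.mu[b].Lock()
    defer s.mu[b].Unlock()
    delete(s.buckets[b], id)
}

// Snapshot locks every shard (as Shutdown does via LockAll/UnlockAll) and
// copies the members out for a consistent view.
func (s *shardedSpanSet) Snapshot() []spanID {
    for i := range s.mu {
        s.mu[i].Lock()
    }
    defer func() {
        for i := range s.mu {
            s.mu[i].Unlock()
        }
    }()
    var out []spanID
    for i := range s.buckets {
        for id := range s.buckets[i] {
            out = append(out, id)
        }
    }
    return out
}

func main() {
    set := newShardedSpanSet()
    set.Add(spanID{1, 2, 3, 4, 5, 6, 7, 8})
    set.Add(spanID{8, 7, 6, 5, 4, 3, 2, 1})
    set.Remove(spanID{1, 2, 3, 4, 5, 6, 7, 8})
    fmt.Println(len(set.Snapshot())) // prints 1
}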