various bugfixes and improvements

This commit is contained in:
Joe Kralicky 2024-12-05 04:37:56 +00:00
parent e221c8af84
commit 51fa483885
No known key found for this signature in database
GPG key ID: 75C4875F34A9FB79
12 changed files with 819 additions and 306 deletions

View file

@ -11,6 +11,7 @@ import (
envoy_extensions_access_loggers_grpc_v3 "github.com/envoyproxy/go-control-plane/envoy/extensions/access_loggers/grpc/v3"
envoy_extensions_filters_http_header_to_metadata "github.com/envoyproxy/go-control-plane/envoy/extensions/filters/http/header_to_metadata/v3"
envoy_extensions_filters_network_http_connection_manager "github.com/envoyproxy/go-control-plane/envoy/extensions/filters/network/http_connection_manager/v3"
envoy_extensions_tracers_otel "github.com/envoyproxy/go-control-plane/envoy/extensions/tracers/opentelemetry/resource_detectors/v3"
metadatav3 "github.com/envoyproxy/go-control-plane/envoy/type/metadata/v3"
envoy_tracing_v3 "github.com/envoyproxy/go-control-plane/envoy/type/tracing/v3"
envoy_type_v3 "github.com/envoyproxy/go-control-plane/envoy/type/v3"
@ -202,7 +203,7 @@ func (b *Builder) buildMainHTTPConnectionManagerFilter(
RandomSampling: &envoy_type_v3.Percent{Value: cfg.Options.TracingSampleRate * 100},
ClientSampling: &envoy_type_v3.Percent{Value: cfg.Options.TracingSampleRate * 100},
Verbose: true,
SpawnUpstreamSpan: wrapperspb.Bool(false),
SpawnUpstreamSpan: wrapperspb.Bool(true),
Provider: &tracev3.Tracing_Http{
Name: "envoy.tracers.opentelemetry",
ConfigType: &tracev3.Tracing_Http_TypedConfig{
@ -215,6 +216,16 @@ func (b *Builder) buildMainHTTPConnectionManagerFilter(
},
},
ServiceName: "Envoy",
ResourceDetectors: []*envoy_config_core_v3.TypedExtensionConfig{
{
Name: "envoy.tracers.opentelemetry.resource_detectors.static_config",
TypedConfig: marshalAny(&envoy_extensions_tracers_otel.StaticConfigResourceDetectorConfig{
Attributes: map[string]string{
"pomerium.envoy": "true",
},
}),
},
},
}),
},
},

View file

@ -0,0 +1,32 @@
package trace
import (
"context"
"fmt"
"runtime"
"go.opentelemetry.io/otel/attribute"
sdktrace "go.opentelemetry.io/otel/sdk/trace"
)
// stackTraceProcessor is a trace.SpanProcessor that tags every span, at start
// time, with the source file:line of the call site that created it. Intended
// for debugging span provenance; enabled only at elevated debug levels.
type stackTraceProcessor struct{}

// ForceFlush implements trace.SpanProcessor. No-op: this processor buffers nothing.
func (s *stackTraceProcessor) ForceFlush(ctx context.Context) error {
	return nil
}

// OnEnd implements trace.SpanProcessor. No-op.
func (*stackTraceProcessor) OnEnd(s sdktrace.ReadOnlySpan) {
}

// OnStart implements trace.SpanProcessor. It records the creating call site
// as a "caller" attribute on the span.
//
// NOTE(review): runtime.Caller(2) assumes a fixed call depth between the
// user's span-start call and this hook (skipping OnStart itself and one SDK
// frame); if the sdktrace internals change, the captured frame will be off —
// confirm against the SDK version in use.
func (*stackTraceProcessor) OnStart(parent context.Context, s sdktrace.ReadWriteSpan) {
	_, file, line, _ := runtime.Caller(2)
	s.SetAttributes(attribute.String("caller", fmt.Sprintf("%s:%d", file, line)))
}

// Shutdown implements trace.SpanProcessor. No-op: nothing to release.
func (s *stackTraceProcessor) Shutdown(ctx context.Context) error {
	return nil
}

View file

@ -0,0 +1,32 @@
package trace
import (
"context"
"go.opentelemetry.io/otel/trace"
"go.opentelemetry.io/otel/trace/embedded"
)
// PomeriumCoreTracer is the instrumentation scope name used for spans started
// by pomerium core code via Continue.
const PomeriumCoreTracer = "pomerium.io/core"

// panicTracerProvider is installed as the global otel tracer provider so that
// any accidental use of the global tracer fails loudly: spans must instead be
// started from a tracer provider carried in a context.
type panicTracerProvider struct {
	embedded.TracerProvider
}

// Tracer implements trace.TracerProvider. It hands out a tracer that panics
// on first use.
func (w panicTracerProvider) Tracer(name string, options ...trace.TracerOption) trace.Tracer {
	return panicTracer{}
}

// panicTracer panics when a span is started through it.
type panicTracer struct {
	embedded.Tracer
}

// Start implements trace.Tracer. Always panics; see panicTracerProvider.
func (p panicTracer) Start(ctx context.Context, spanName string, opts ...trace.SpanStartOption) (context.Context, trace.Span) {
	panic("global tracer used")
}

// Continue starts a new span named name as a child of the span in ctx, using
// that span's own tracer provider (never the global one). The returned
// context carries the new span.
func Continue(ctx context.Context, name string, o ...trace.SpanStartOption) (context.Context, trace.Span) {
	return trace.SpanFromContext(ctx).TracerProvider().Tracer(PomeriumCoreTracer).Start(ctx, name, o...)
}

View file

@ -1,6 +1,7 @@
package trace
import (
"context"
"fmt"
"net/http"
@ -8,6 +9,8 @@ import (
"go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp"
"go.opentelemetry.io/otel"
"go.opentelemetry.io/otel/propagation"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/stats"
)
func NewHTTPMiddleware(opts ...otelhttp.Option) func(http.Handler) http.Handler {
@ -41,3 +44,52 @@ func NewHTTPMiddleware(opts ...otelhttp.Option) func(http.Handler) http.Handler
})
}
}
// NewStatsHandler wraps base so that, before delegating each callback, the
// incoming gRPC metadata's "traceparent" has its trace id rewritten from the
// "x-pomerium-traceparent" header (when both are present).
func NewStatsHandler(base stats.Handler) stats.Handler {
	return &statsHandlerWrapper{
		base: base,
	}
}

// statsHandlerWrapper decorates a stats.Handler with trace-id rewriting of
// incoming metadata; see wrapContext.
type statsHandlerWrapper struct {
	base stats.Handler
}
// wrapContext returns a context whose incoming "traceparent" metadata has its
// trace id replaced with the trace id parsed from "x-pomerium-traceparent",
// when both headers are present and non-empty. If either header is missing,
// or the pomerium traceparent fails to parse, ctx is returned unchanged.
func (w *statsHandlerWrapper) wrapContext(ctx context.Context) context.Context {
	md, ok := metadata.FromIncomingContext(ctx)
	if !ok {
		return ctx
	}
	traceparent := md.Get("traceparent")
	xPomeriumTraceparent := md.Get("x-pomerium-traceparent")
	if len(traceparent) > 0 && traceparent[0] != "" && len(xPomeriumTraceparent) > 0 && xPomeriumTraceparent[0] != "" {
		newTracectx, err := ParseTraceparent(xPomeriumTraceparent[0])
		if err != nil {
			return ctx
		}
		// FromIncomingContext returns metadata that must not be modified in
		// place (it is shared with the transport and other readers of ctx);
		// mutate a copy and attach that to the new context instead.
		md = md.Copy()
		md.Set("traceparent", ReplaceTraceID(traceparent[0], newTracectx.TraceID()))
		return metadata.NewIncomingContext(ctx, md)
	}
	return ctx
}
// HandleConn implements stats.Handler. Delegates with a rewritten context.
func (w *statsHandlerWrapper) HandleConn(ctx context.Context, stats stats.ConnStats) {
	w.base.HandleConn(w.wrapContext(ctx), stats)
}

// HandleRPC implements stats.Handler. Delegates with a rewritten context.
func (w *statsHandlerWrapper) HandleRPC(ctx context.Context, stats stats.RPCStats) {
	w.base.HandleRPC(w.wrapContext(ctx), stats)
}

// TagConn implements stats.Handler. Delegates with a rewritten context.
func (w *statsHandlerWrapper) TagConn(ctx context.Context, info *stats.ConnTagInfo) context.Context {
	return w.base.TagConn(w.wrapContext(ctx), info)
}

// TagRPC implements stats.Handler. Delegates with a rewritten context.
func (w *statsHandlerWrapper) TagRPC(ctx context.Context, info *stats.RPCTagInfo) context.Context {
	return w.base.TagRPC(w.wrapContext(ctx), info)
}

View file

@ -3,10 +3,16 @@ package trace
import (
"context"
"encoding/base64"
"encoding/hex"
"errors"
"fmt"
"net"
"net/url"
"os"
"strings"
"sync"
"time"
"unique"
coltracepb "go.opentelemetry.io/proto/otlp/collector/trace/v1"
commonv1 "go.opentelemetry.io/proto/otlp/common/v1"
@ -21,6 +27,8 @@ import (
"github.com/pomerium/pomerium/internal/log"
"go.opentelemetry.io/otel/exporters/otlp/otlptrace"
"go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc"
sdktrace "go.opentelemetry.io/otel/sdk/trace"
"go.opentelemetry.io/otel/trace"
oteltrace "go.opentelemetry.io/otel/trace"
)
@ -57,16 +65,12 @@ func (ptr *PendingScopes) Insert(scope *commonv1.InstrumentationScope, scopeSche
spans.Insert(span)
}
func (ptr *PendingScopes) Delete(scope *commonv1.InstrumentationScope) (cascade bool) {
delete(ptr.spansByScope, scope.GetName())
return len(ptr.spansByScope) == 0
}
func (ptr *PendingScopes) AsScopeSpansList(rewriteTraceId oteltrace.TraceID) []*tracev1.ScopeSpans {
func (ptr *PendingScopes) AsScopeSpansList(rewriteTraceId unique.Handle[oteltrace.TraceID]) []*tracev1.ScopeSpans {
out := make([]*tracev1.ScopeSpans, 0, len(ptr.spansByScope))
for _, spans := range ptr.spansByScope {
for _, span := range spans.spans {
span.TraceId = rewriteTraceId[:]
id := rewriteTraceId.Value()
copy(span.TraceId, id[:])
}
scopeSpans := &tracev1.ScopeSpans{
Scope: spans.scope,
@ -101,15 +105,7 @@ func (ptr *PendingResources) Insert(resource *ResourceInfo, scope *commonv1.Inst
scopes.Insert(scope, scopeSchema, span)
}
func (ptr *PendingResources) Delete(resource *ResourceInfo, scope *commonv1.InstrumentationScope) (cascade bool) {
resourceEq := resource.ID()
if ptr.scopesByResourceID[resourceEq].Delete(scope) {
delete(ptr.scopesByResourceID, resourceEq)
}
return len(ptr.scopesByResourceID) == 0
}
func (ptr *PendingResources) AsResourceSpans(rewriteTraceId oteltrace.TraceID) []*tracev1.ResourceSpans {
func (ptr *PendingResources) AsResourceSpans(rewriteTraceId unique.Handle[oteltrace.TraceID]) []*tracev1.ResourceSpans {
out := make([]*tracev1.ResourceSpans, 0, len(ptr.scopesByResourceID))
for _, scopes := range ptr.scopesByResourceID {
resourceSpans := &tracev1.ResourceSpans{
@ -152,28 +148,67 @@ func (r *ResourceInfo) computeID() string {
return base64.StdEncoding.EncodeToString(hash.Sum(nil))
}
// spanObserver tracks which span ids have been referenced (named as some
// span's parent) and which have actually been observed. Wait blocks until
// every referenced id has also been observed, letting shutdown hold off
// until all parent spans have arrived.
type spanObserver struct {
	mu sync.Mutex
	// referencedIDs maps a span id to whether it has been observed; an entry
	// whose value is false was referenced as a parent but not yet seen.
	referencedIDs map[unique.Handle[oteltrace.SpanID]]bool
	// unobservedIDs counts entries in referencedIDs whose value is false.
	unobservedIDs sync.WaitGroup
}

// ObserveReference records that id was named as some span's parent. A
// previously unknown id is marked referenced-but-not-observed and adds one
// to the wait group until Observe(id) is called.
func (obs *spanObserver) ObserveReference(id unique.Handle[oteltrace.SpanID]) {
	obs.mu.Lock()
	defer obs.mu.Unlock()
	if _, referenced := obs.referencedIDs[id]; !referenced {
		obs.referencedIDs[id] = false // referenced, but not observed
		obs.unobservedIDs.Add(1)
	}
}

// Observe records that the span with the given id has been seen.
func (obs *spanObserver) Observe(id unique.Handle[oteltrace.SpanID]) {
	obs.mu.Lock()
	defer obs.mu.Unlock()
	if observed, referenced := obs.referencedIDs[id]; !observed { // NB: subtle condition
		// `observed` is false both for brand-new ids (missing from the map)
		// and for ids recorded by ObserveReference; only the latter hold a
		// waitgroup slot, hence the inner `referenced` check before Done.
		obs.referencedIDs[id] = true
		if referenced {
			obs.unobservedIDs.Done()
		}
	}
}

// Wait blocks until every referenced parent span id has been observed.
func (obs *spanObserver) Wait() {
	obs.unobservedIDs.Wait()
}
type SpanExportQueue struct {
mu sync.Mutex
pendingResourcesByTraceId map[string]*PendingResources
knownTraceIdMappings map[string]oteltrace.TraceID
pendingResourcesByTraceId map[unique.Handle[oteltrace.TraceID]]*PendingResources
knownTraceIdMappings map[unique.Handle[oteltrace.TraceID]]unique.Handle[oteltrace.TraceID]
uploadC chan []*tracev1.ResourceSpans
closing bool
closed chan struct{}
debugLevel int
debugAllObservedSpans map[unique.Handle[oteltrace.SpanID]]*tracev1.Span
tracker *spanTracker
observer *spanObserver
}
func NewSpanExportQueue(ctx context.Context, client otlptrace.Client) *SpanExportQueue {
observer := &spanObserver{referencedIDs: make(map[unique.Handle[oteltrace.SpanID]]bool)}
debugLevel := systemContextFromContext(ctx).DebugLevel
q := &SpanExportQueue{
pendingResourcesByTraceId: make(map[string]*PendingResources),
knownTraceIdMappings: make(map[string]oteltrace.TraceID),
pendingResourcesByTraceId: make(map[unique.Handle[oteltrace.TraceID]]*PendingResources),
knownTraceIdMappings: make(map[unique.Handle[oteltrace.TraceID]]unique.Handle[oteltrace.TraceID]),
uploadC: make(chan []*tracev1.ResourceSpans, 8),
closed: make(chan struct{}),
debugLevel: debugLevel,
debugAllObservedSpans: make(map[unique.Handle[oteltrace.SpanID]]*tracev1.Span),
tracker: &spanTracker{observer: observer, debugLevel: debugLevel},
observer: observer,
}
go func() {
for {
select {
case <-ctx.Done():
return
case resourceSpans := <-q.uploadC:
if err := client.UploadTraces(ctx, resourceSpans); err != nil {
log.Ctx(ctx).Err(err).Msg("error uploading traces")
}
defer close(q.closed)
for resourceSpans := range q.uploadC {
if err := client.UploadTraces(context.Background(), resourceSpans); err != nil {
log.Ctx(ctx).Err(err).Msg("error uploading traces")
}
}
}()
@ -186,48 +221,65 @@ type WithSchema[T any] struct {
}
func (q *SpanExportQueue) insertPendingSpanLocked(resource *ResourceInfo, scope *commonv1.InstrumentationScope, scopeSchema string, span *tracev1.Span) {
spanTraceIdHex := oteltrace.TraceID(span.TraceId).String()
spanTraceId := unique.Make(oteltrace.TraceID(span.TraceId))
var pendingTraceResources *PendingResources
if ptr, ok := q.pendingResourcesByTraceId[spanTraceIdHex]; ok {
if ptr, ok := q.pendingResourcesByTraceId[spanTraceId]; ok {
pendingTraceResources = ptr
} else {
pendingTraceResources = NewPendingResources()
q.pendingResourcesByTraceId[spanTraceIdHex] = pendingTraceResources
q.pendingResourcesByTraceId[spanTraceId] = pendingTraceResources
}
pendingTraceResources.Insert(resource, scope, scopeSchema, span)
}
func (q *SpanExportQueue) resolveTraceIdMappingLocked(resource *ResourceInfo, scope *commonv1.InstrumentationScope, scopeSchema string, span *tracev1.Span, mapping oteltrace.TraceID) {
originalTraceIdHex := oteltrace.TraceID(span.TraceId).String()
q.insertPendingSpanLocked(resource, scope, scopeSchema, span)
q.knownTraceIdMappings[originalTraceIdHex] = mapping
toUpload := q.pendingResourcesByTraceId[originalTraceIdHex].AsResourceSpans(mapping)
if q.pendingResourcesByTraceId[originalTraceIdHex].Delete(resource, scope) {
delete(q.pendingResourcesByTraceId, originalTraceIdHex)
func (q *SpanExportQueue) resolveTraceIdMappingLocked(original, mapping unique.Handle[oteltrace.TraceID]) [][]*tracev1.ResourceSpans {
q.knownTraceIdMappings[original] = mapping
toUpload := [][]*tracev1.ResourceSpans{}
if originalPending, ok := q.pendingResourcesByTraceId[original]; ok {
resourceSpans := originalPending.AsResourceSpans(mapping)
delete(q.pendingResourcesByTraceId, original)
toUpload = append(toUpload, resourceSpans)
}
q.uploadC <- toUpload
if original != mapping {
q.knownTraceIdMappings[mapping] = mapping
if targetPending, ok := q.pendingResourcesByTraceId[mapping]; ok {
resourceSpans := targetPending.AsResourceSpans(mapping)
delete(q.pendingResourcesByTraceId, mapping)
toUpload = append(toUpload, resourceSpans)
}
}
return toUpload
}
func (q *SpanExportQueue) Enqueue(ctx context.Context, req *coltracepb.ExportTraceServiceRequest) {
var ErrShuttingDown = errors.New("exporter is shutting down")
func (q *SpanExportQueue) Enqueue(ctx context.Context, req *coltracepb.ExportTraceServiceRequest) error {
q.mu.Lock()
defer q.mu.Unlock()
if q.closing {
return ErrShuttingDown
}
var immediateUpload []*tracev1.ResourceSpans
var toUpload [][]*tracev1.ResourceSpans
for _, resource := range req.ResourceSpans {
resourceInfo := newResourceInfo(resource.Resource, resource.SchemaUrl)
knownResources := &tracev1.ResourceSpans{
Resource: resource.Resource,
SchemaUrl: resource.SchemaUrl,
}
for _, scope := range resource.ScopeSpans {
var knownSpans []*tracev1.Span
for _, span := range scope.Spans {
spanTraceId := oteltrace.TraceID(span.TraceId)
spanTraceIdHex := oteltrace.TraceID(span.TraceId).String()
formatSpanName(span)
if len(span.ParentSpanId) == 0 {
// observed a new root span
spanId := unique.Make(oteltrace.SpanID(span.SpanId))
parentSpanId := parentSpanID(span.ParentSpanId)
if q.debugLevel >= 1 {
q.debugAllObservedSpans[spanId] = span
}
if parentSpanId != rootSpanId {
q.observer.ObserveReference(parentSpanId)
continue
}
spanTraceId := unique.Make(oteltrace.TraceID(span.TraceId))
if _, ok := q.knownTraceIdMappings[spanTraceId]; !ok {
// observed a new root span with an unknown trace id
var pomeriumTraceparent string
for _, attr := range span.Attributes {
if attr.Key == "pomerium.traceparent" {
@ -235,11 +287,11 @@ func (q *SpanExportQueue) Enqueue(ctx context.Context, req *coltracepb.ExportTra
break
}
}
var targetTraceID oteltrace.TraceID
var mappedTraceID unique.Handle[oteltrace.TraceID]
if pomeriumTraceparent == "" {
// no replacement id, map the trace to itself and release pending spans
targetTraceID = spanTraceId
mappedTraceID = spanTraceId
} else {
// this root span has an alternate traceparent. permanently rewrite
// all spans of the old trace id to use the new trace id
@ -248,33 +300,204 @@ func (q *SpanExportQueue) Enqueue(ctx context.Context, req *coltracepb.ExportTra
log.Ctx(ctx).Err(err).Msg("error processing trace")
continue
}
targetTraceID = tp.TraceID()
mappedTraceID = unique.Make(tp.TraceID())
}
q.resolveTraceIdMappingLocked(resourceInfo, scope.Scope, scope.SchemaUrl, span, targetTraceID)
toUpload = append(toUpload, q.resolveTraceIdMappingLocked(spanTraceId, mappedTraceID)...)
}
}
}
}
var knownResources []*tracev1.ResourceSpans
for _, resource := range req.ResourceSpans {
resourceInfo := newResourceInfo(resource.Resource, resource.SchemaUrl)
knownResource := &tracev1.ResourceSpans{
Resource: resource.Resource,
SchemaUrl: resource.SchemaUrl,
}
for _, scope := range resource.ScopeSpans {
var knownSpans []*tracev1.Span
for _, span := range scope.Spans {
spanID := unique.Make(oteltrace.SpanID(span.SpanId))
spanTraceId := unique.Make(oteltrace.TraceID(span.TraceId))
q.observer.Observe(spanID)
if mapping, ok := q.knownTraceIdMappings[spanTraceId]; ok {
id := mapping.Value()
copy(span.TraceId, id[:])
knownSpans = append(knownSpans, span)
} else {
if rewrite, ok := q.knownTraceIdMappings[spanTraceIdHex]; ok {
span.TraceId = rewrite[:]
knownSpans = append(knownSpans, span)
} else {
q.insertPendingSpanLocked(resourceInfo, scope.Scope, scope.SchemaUrl, span)
}
q.insertPendingSpanLocked(resourceInfo, scope.Scope, scope.SchemaUrl, span)
}
}
if len(knownSpans) > 0 {
knownResources.ScopeSpans = append(knownResources.ScopeSpans, &tracev1.ScopeSpans{
knownResource.ScopeSpans = append(knownResource.ScopeSpans, &tracev1.ScopeSpans{
Scope: scope.Scope,
SchemaUrl: scope.SchemaUrl,
Spans: knownSpans,
})
}
}
if len(knownResources.ScopeSpans) > 0 {
immediateUpload = append(immediateUpload, knownResources)
if len(knownResource.ScopeSpans) > 0 {
knownResources = append(knownResources, knownResource)
}
}
if len(immediateUpload) > 0 {
q.uploadC <- immediateUpload
if len(knownResources) > 0 {
toUpload = append(toUpload, knownResources)
}
for _, res := range toUpload {
q.uploadC <- res
}
return nil
}
// Sentinel errors reported by the exporter during shutdown.
var (
	ErrIncompleteTraces   = errors.New("exporter shut down with incomplete traces")
	ErrIncompleteUploads  = errors.New("exporter shut down with pending trace uploads")
	ErrMissingParentSpans = errors.New("exporter shut down with missing parent spans")
)

// rootSpanId is the canonical handle for the all-zero span id, used to mark
// spans that have no parent (trace roots).
var rootSpanId = unique.Make(oteltrace.SpanID([8]byte{}))

// parentSpanID converts a raw parent-span-id field into a unique handle,
// mapping an empty (absent) value to rootSpanId.
func parentSpanID(value []byte) unique.Handle[oteltrace.SpanID] {
	if len(value) == 0 {
		return rootSpanId
	}
	return unique.Make(oteltrace.SpanID(value))
}
// WaitForSpans blocks until every span id referenced as a parent has been
// observed, or until maxDuration elapses, in which case it returns
// ErrMissingParentSpans.
func (q *SpanExportQueue) WaitForSpans(maxDuration time.Duration) error {
	allObserved := make(chan struct{})
	go func() {
		// Note: if the timeout fires first, this goroutine lingers until the
		// observer is eventually satisfied.
		q.observer.Wait()
		close(allObserved)
	}()
	timeout := time.NewTimer(maxDuration)
	defer timeout.Stop()
	select {
	case <-allObserved:
		return nil
	case <-timeout.C:
		return ErrMissingParentSpans
	}
}
// Close shuts down the export queue: it marks the queue as closing, closes
// the upload channel (stopping the uploader goroutine once drained), then
// waits for the uploader to finish or for ctx to end. At debug levels >= 1
// it prints diagnostics to stderr about unresolved parent references and
// traces whose spans were never uploaded.
//
// Returns ErrIncompleteTraces if spans for unresolved trace ids were still
// pending at shutdown, or the context's cause if ctx ends first.
func (q *SpanExportQueue) Close(ctx context.Context) error {
	q.mu.Lock()
	q.closing = true
	close(q.uploadC)
	q.mu.Unlock()
	select {
	case <-ctx.Done():
		return context.Cause(ctx)
	case <-q.closed:
		q.mu.Lock()
		defer q.mu.Unlock()
		if q.debugLevel >= 1 {
			// Report parent span ids that were referenced but never observed.
			var unknownParentIds []string
			for id, known := range q.observer.referencedIDs {
				if !known {
					unknownParentIds = append(unknownParentIds, id.Value().String())
				}
			}
			if len(unknownParentIds) > 0 {
				msg := strings.Builder{}
				msg.WriteString("==================================================\n")
				msg.WriteString("WARNING: parent spans referenced but never seen:\n")
				for _, str := range unknownParentIds {
					msg.WriteString(str)
					msg.WriteString("\n")
				}
				msg.WriteString("==================================================\n")
				fmt.Fprint(os.Stderr, msg.String())
			}
		}
		incomplete := len(q.pendingResourcesByTraceId) > 0
		if incomplete || q.debugLevel >= 3 {
			msg := strings.Builder{}
			if incomplete && q.debugLevel >= 1 {
				// Dump every pending (never-uploaded) trace, grouped by
				// resource and scope, flagging spans whose parent was
				// never observed.
				msg.WriteString("==================================================\n")
				msg.WriteString("WARNING: exporter shut down with incomplete traces\n")
				for k, v := range q.pendingResourcesByTraceId {
					msg.WriteString(fmt.Sprintf("- Trace: %s\n", k.Value()))
					for _, pendingScope := range v.scopesByResourceID {
						msg.WriteString(" - Resource:\n")
						for _, v := range pendingScope.resource.Resource.Attributes {
							msg.WriteString(fmt.Sprintf(" %s=%s\n", v.Key, v.Value.String()))
						}
						for _, scope := range pendingScope.spansByScope {
							if scope.scope != nil {
								msg.WriteString(fmt.Sprintf(" Scope: %s\n", scope.scope.Name))
							} else {
								msg.WriteString(" Scope: (unknown)\n")
							}
							msg.WriteString(" Spans:\n")
							// Pad span names to the longest for aligned output.
							longestName := 0
							for _, span := range scope.spans {
								longestName = max(longestName, len(span.Name)+2)
							}
							for _, span := range scope.spans {
								parentSpanId := parentSpanID(span.ParentSpanId)
								_, seenParent := q.debugAllObservedSpans[parentSpanId]
								var missing string
								if !seenParent {
									missing = " [missing]"
								}
								msg.WriteString(fmt.Sprintf(" - %-*s (trace: %s | span: %s | parent:%s %s)\n", longestName,
									"'"+span.Name+"'", hex.EncodeToString(span.TraceId), hex.EncodeToString(span.SpanId), missing, parentSpanId.Value()))
								for _, attr := range span.Attributes {
									if attr.Key == "caller" {
										msg.WriteString(fmt.Sprintf(" => caller: '%s'\n", attr.Value.GetStringValue()))
									}
								}
							}
						}
					}
				}
				msg.WriteString("==================================================\n")
			}
			if (incomplete && q.debugLevel >= 2) || (!incomplete && q.debugLevel >= 3) {
				// Dump the trace id mapping table and every span observed by
				// the queue (debugAllObservedSpans is only populated when
				// debugLevel >= 1).
				msg.WriteString("==================================================\n")
				msg.WriteString("Known trace ids:\n")
				for k, v := range q.knownTraceIdMappings {
					if k != v {
						msg.WriteString(fmt.Sprintf("%s => %s\n", k.Value(), v.Value()))
					} else {
						msg.WriteString(fmt.Sprintf("%s (no change)\n", k.Value()))
					}
				}
				msg.WriteString("==================================================\n")
				msg.WriteString("All exported spans:\n")
				longestName := 0
				for _, span := range q.debugAllObservedSpans {
					longestName = max(longestName, len(span.Name)+2)
				}
				for _, span := range q.debugAllObservedSpans {
					traceid := span.TraceId
					spanid := span.SpanId
					msg.WriteString(fmt.Sprintf("%-*s (trace: %s | span: %s | parent: %s)", longestName,
						"'"+span.Name+"'", hex.EncodeToString(traceid[:]), hex.EncodeToString(spanid[:]), parentSpanID(span.ParentSpanId).Value()))
					var foundCaller bool
					for _, attr := range span.Attributes {
						if attr.Key == "caller" {
							msg.WriteString(fmt.Sprintf(" => %s\n", attr.Value.GetStringValue()))
							foundCaller = true
							break
						}
					}
					if !foundCaller {
						msg.WriteString("\n")
					}
				}
				msg.WriteString("==================================================\n")
			}
			if msg.Len() > 0 {
				fmt.Fprint(os.Stderr, msg.String())
			}
			if incomplete {
				return ErrIncompleteTraces
			}
		}
		log.Ctx(ctx).Debug().Msg("exporter shut down")
		return nil
	}
}
@ -308,28 +531,35 @@ func formatSpanName(span *tracev1.Span) {
}
// Export implements ptraceotlp.GRPCServer.
func (srv *Server) Export(ctx context.Context, req *coltracepb.ExportTraceServiceRequest) (*coltracepb.ExportTraceServiceResponse, error) {
func (srv *ExporterServer) Export(ctx context.Context, req *coltracepb.ExportTraceServiceRequest) (*coltracepb.ExportTraceServiceResponse, error) {
srv.spanExportQueue.Enqueue(ctx, req)
return &coltracepb.ExportTraceServiceResponse{}, nil
}
type Server struct {
type ExporterServer struct {
coltracepb.UnimplementedTraceServiceServer
spanExportQueue *SpanExportQueue
server *grpc.Server
remoteClient otlptrace.Client
cc *grpc.ClientConn
}
func NewServer(ctx context.Context, client otlptrace.Client) *Server {
client.Start(ctx)
return &Server{
spanExportQueue: NewSpanExportQueue(ctx, client),
func NewServer(ctx context.Context, remoteClient otlptrace.Client) *ExporterServer {
if err := remoteClient.Start(ctx); err != nil {
panic(err)
}
ex := &ExporterServer{
spanExportQueue: NewSpanExportQueue(ctx, remoteClient),
remoteClient: remoteClient,
server: grpc.NewServer(grpc.Creds(insecure.NewCredentials())),
}
coltracepb.RegisterTraceServiceServer(ex.server, ex)
return ex
}
func (srv *Server) Start(ctx context.Context) otlptrace.Client {
func (srv *ExporterServer) Start(ctx context.Context) {
lis := bufconn.Listen(4096)
gs := grpc.NewServer(grpc.Creds(insecure.NewCredentials()))
coltracepb.RegisterTraceServiceServer(gs, srv)
go gs.Serve(lis)
go srv.server.Serve(lis)
cc, err := grpc.NewClient("passthrough://ignore",
grpc.WithContextDialer(func(context.Context, string) (net.Conn, error) {
return lis.Dial()
@ -337,5 +567,125 @@ func (srv *Server) Start(ctx context.Context) otlptrace.Client {
if err != nil {
panic(err)
}
return otlptracegrpc.NewClient(otlptracegrpc.WithGRPCConn(cc))
srv.cc = cc
}
// NewClient returns an otlptrace.Client that exports into this in-process
// exporter server over its bufconn connection. Start must be called first so
// that srv.cc is populated.
func (srv *ExporterServer) NewClient() otlptrace.Client {
	return otlptracegrpc.NewClient(otlptracegrpc.WithGRPCConn(srv.cc))
}

// SpanProcessors returns the span processors that should be installed on
// every tracer provider exporting through this server (the span tracker that
// feeds the export queue's observer).
func (srv *ExporterServer) SpanProcessors() []sdktrace.SpanProcessor {
	return []sdktrace.SpanProcessor{srv.spanExportQueue.tracker}
}
// Shutdown gracefully stops the exporter server: it stops the local gRPC
// server, waits up to 5 seconds for all referenced parent spans to be
// observed, closes the span export queue, and stops the remote client. All
// errors encountered along the way are joined. If ctx ends before the gRPC
// server finishes stopping, the context's cause is returned immediately.
func (srv *ExporterServer) Shutdown(ctx context.Context) error {
	stopped := make(chan struct{})
	go func() {
		srv.server.GracefulStop()
		close(stopped)
	}()
	select {
	case <-stopped:
	case <-ctx.Done():
		return context.Cause(ctx)
	}
	var errs []error
	if err := srv.spanExportQueue.WaitForSpans(5 * time.Second); err != nil {
		errs = append(errs, err)
	}
	if err := srv.spanExportQueue.Close(ctx); err != nil {
		errs = append(errs, err)
	}
	if err := srv.remoteClient.Stop(ctx); err != nil {
		errs = append(errs, err)
	}
	return errors.Join(errs...)
}
// spanTracker is a trace.SpanProcessor that records which locally-started
// spans are still in flight (started but not ended) and reports every
// started span id to the spanObserver. At debugLevel >= 3 it also retains
// details of every span ever started, for shutdown diagnostics.
type spanTracker struct {
	inflightSpans sync.Map // key: unique.Handle[oteltrace.SpanID]; value: struct{}
	allSpans      sync.Map // key: unique.Handle[oteltrace.SpanID]; value: *spanInfo (debugLevel >= 3 only)
	debugLevel    int
	observer      *spanObserver
}

// spanInfo captures the identifying details of a span for debug output.
type spanInfo struct {
	Name        string
	SpanContext trace.SpanContext
	Parent      trace.SpanContext
}

// ForceFlush implements trace.SpanProcessor. No-op: nothing is buffered.
func (t *spanTracker) ForceFlush(ctx context.Context) error {
	return nil
}

// OnEnd implements trace.SpanProcessor. Marks the span as no longer in flight.
func (t *spanTracker) OnEnd(s sdktrace.ReadOnlySpan) {
	id := unique.Make(s.SpanContext().SpanID())
	t.inflightSpans.Delete(id)
}

// OnStart implements trace.SpanProcessor. Marks the span as in flight and
// reports it to the observer; at debugLevel >= 3, also records its details.
func (t *spanTracker) OnStart(parent context.Context, s sdktrace.ReadWriteSpan) {
	id := unique.Make(s.SpanContext().SpanID())
	t.inflightSpans.Store(id, struct{}{})
	t.observer.Observe(id)
	if t.debugLevel >= 3 {
		t.allSpans.Store(id, &spanInfo{
			Name:        s.Name(),
			SpanContext: s.SpanContext(),
			Parent:      s.Parent(),
		})
	}
}
// Shutdown implements trace.SpanProcessor. It prints diagnostics to stderr:
// at debugLevel >= 1, spans that were started but never ended; at
// debugLevel >= 3, every span observed.
//
// NOTE(review): the "spans not ended" report looks up details in allSpans,
// which OnStart only populates when debugLevel >= 3 — so at levels 1-2 the
// in-flight warning can never list anything. Confirm whether that is
// intentional.
func (t *spanTracker) Shutdown(ctx context.Context) error {
	msg := strings.Builder{}
	if t.debugLevel >= 1 {
		incompleteSpans := []*spanInfo{}
		t.inflightSpans.Range(func(key, value any) bool {
			if info, ok := t.allSpans.Load(key); ok {
				incompleteSpans = append(incompleteSpans, info.(*spanInfo))
			}
			return true
		})
		if len(incompleteSpans) > 0 {
			msg.WriteString("==================================================\n")
			msg.WriteString("WARNING: spans not ended:\n")
			// Pad span names to the longest for aligned output.
			longestName := 0
			for _, span := range incompleteSpans {
				longestName = max(longestName, len(span.Name)+2)
			}
			for _, span := range incompleteSpans {
				msg.WriteString(fmt.Sprintf("%-*s (trace: %s | span: %s | parent: %s)\n", longestName, "'"+span.Name+"'",
					span.SpanContext.TraceID(), span.SpanContext.SpanID(), span.Parent.SpanID()))
			}
			msg.WriteString("==================================================\n")
		}
	}
	if t.debugLevel >= 3 {
		allSpans := []*spanInfo{}
		t.allSpans.Range(func(key, value any) bool {
			allSpans = append(allSpans, value.(*spanInfo))
			return true
		})
		msg.WriteString("==================================================\n")
		msg.WriteString("All observed spans:\n")
		longestName := 0
		for _, span := range allSpans {
			longestName = max(longestName, len(span.Name)+2)
		}
		for _, span := range allSpans {
			msg.WriteString(fmt.Sprintf("%-*s (trace: %s | span: %s | parent: %s)\n", longestName, "'"+span.Name+"'",
				span.SpanContext.TraceID(), span.SpanContext.SpanID(), span.Parent.SpanID()))
		}
		msg.WriteString("==================================================\n")
	}
	if msg.Len() > 0 {
		fmt.Fprint(os.Stderr, msg.String())
	}
	return nil
}

View file

@ -2,13 +2,12 @@ package trace
import (
"context"
"encoding/hex"
"errors"
"fmt"
"os"
"runtime"
"strconv"
"strings"
"sync"
"time"
"go.opentelemetry.io/otel"
"go.opentelemetry.io/otel/attribute"
@ -20,62 +19,88 @@ import (
sdktrace "go.opentelemetry.io/otel/sdk/trace"
semconv "go.opentelemetry.io/otel/semconv/v1.26.0"
"go.opentelemetry.io/otel/trace"
"go.opentelemetry.io/otel/trace/embedded"
coltracepb "go.opentelemetry.io/proto/otlp/collector/trace/v1"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/stats"
)
type (
clientKeyType struct{}
exporterKeyType struct{}
tracerProviderKeyType struct{}
serverKeyType struct{}
)
type systemContextKeyType struct{}
var (
exporterKey exporterKeyType
tracerProviderKey tracerProviderKeyType
serverKey serverKeyType
)
var systemContextKey systemContextKeyType
// shutdownFunc matches the signature of a span End function.
type shutdownFunc func(options ...trace.SpanEndOption)

// Options configures the tracing system created by NewContext.
type Options struct {
	// DebugLevel controls the verbosity of shutdown diagnostics (0 = off).
	DebugLevel int
}

// systemContext is the per-context tracing system state stored under
// systemContextKey by Options.NewContext.
type systemContext struct {
	Options
	tpm            *tracerProviderManager
	exporterServer *ExporterServer
}

// systemContextFromContext extracts the tracing system state from ctx.
// NOTE(review): the type assertion panics if ctx was not built via
// NewContext — confirm all callers run under a trace-enabled context.
func systemContextFromContext(ctx context.Context) *systemContext {
	return ctx.Value(systemContextKey).(*systemContext)
}
func init() {
otel.SetTextMapPropagator(propagation.NewCompositeTextMapPropagator(propagation.TraceContext{}, propagation.Baggage{}))
otel.SetTracerProvider(panicTracerProvider{})
}
type panicTracerProvider struct {
embedded.TracerProvider
var _ trace.Tracer = panicTracer{}
// tracerProviderManager tracks every tracer provider created through
// NewTracerProvider so they can all be flushed and shut down together.
type tracerProviderManager struct {
	mu              sync.Mutex
	tracerProviders []*sdktrace.TracerProvider
}
// Tracer implements trace.TracerProvider.
func (w panicTracerProvider) Tracer(name string, options ...trace.TracerOption) trace.Tracer {
panic("global tracer used")
// ShutdownAll force-flushes, then shuts down, every tracked tracer provider
// and empties the tracked list. All errors are joined into one.
func (tpm *tracerProviderManager) ShutdownAll(ctx context.Context) error {
	tpm.mu.Lock()
	defer tpm.mu.Unlock()
	var errs []error
	// Flush everything first so spans buffered in one provider can still be
	// exported before any provider is shut down.
	for _, tp := range tpm.tracerProviders {
		errs = append(errs, tp.ForceFlush(ctx))
	}
	for _, tp := range tpm.tracerProviders {
		errs = append(errs, tp.Shutdown(ctx))
	}
	// Bug fix: the previous clear(tpm.tracerProviders) zeroes the elements
	// but keeps the slice length, leaving nil *TracerProvider entries that
	// would be iterated (and dereferenced) by a second ShutdownAll, and that
	// later Adds would append after. Truncate the slice instead.
	tpm.tracerProviders = tpm.tracerProviders[:0]
	return errors.Join(errs...)
}
// Add registers tp so that a later ShutdownAll flushes and shuts it down.
func (tpm *tracerProviderManager) Add(tp *sdktrace.TracerProvider) {
	tpm.mu.Lock()
	defer tpm.mu.Unlock()
	tpm.tracerProviders = append(tpm.tracerProviders, tp)
}
// NewContext initializes the tracing system for this context tree: it picks
// the remote OTLP client (HTTP when OTEL_EXPORTER_OTLP_PROTOCOL is
// "http/protobuf", otherwise gRPC), stores the system state in the returned
// context, and starts the in-process exporter server.
func (op Options) NewContext(ctx context.Context) context.Context {
	var remoteClient otlptrace.Client
	if os.Getenv("OTEL_EXPORTER_OTLP_PROTOCOL") == "http/protobuf" {
		remoteClient = otlptracehttp.NewClient()
	} else {
		remoteClient = otlptracegrpc.NewClient()
	}
	sys := &systemContext{
		Options: op,
		tpm:     &tracerProviderManager{},
	}
	// Attach the system state before creating the server so NewServer can
	// read it (e.g. the debug level) from ctx.
	ctx = context.WithValue(ctx, systemContextKey, sys)
	sys.exporterServer = NewServer(ctx, remoteClient)
	sys.exporterServer.Start(ctx)
	return ctx
}
func NewContext(ctx context.Context) context.Context {
var realClient otlptrace.Client
if os.Getenv("OTEL_EXPORTER_OTLP_PROTOCOL") == "http/protobuf" {
realClient = otlptracehttp.NewClient()
} else {
realClient = otlptracegrpc.NewClient()
}
srv := NewServer(ctx, realClient)
localClient := srv.Start(ctx)
exp, err := otlptrace.New(ctx, localClient)
if err != nil {
panic(err)
}
ctx = context.WithValue(ctx, exporterKey, exp)
ctx = context.WithValue(ctx, serverKey, srv)
return ctx
return Options{}.NewContext(ctx)
}
func NewTracerProvider(ctx context.Context, serviceName string) trace.TracerProvider {
_, file, line, _ := runtime.Caller(1)
exp := ctx.Value(exporterKey).(sdktrace.SpanExporter)
sys := systemContextFromContext(ctx)
exp, err := otlptrace.New(ctx, sys.exporterServer.NewClient())
if err != nil {
panic(err)
}
r, err := resource.Merge(
resource.Default(),
resource.NewWithAttributes(
@ -87,146 +112,40 @@ func NewTracerProvider(ctx context.Context, serviceName string) trace.TracerProv
if err != nil {
panic(err)
}
return sdktrace.NewTracerProvider(
sdktrace.WithSpanProcessor(&stackTraceProcessor{}),
options := []sdktrace.TracerProviderOption{
sdktrace.WithBatcher(exp),
sdktrace.WithResource(r),
)
}
type stackTraceProcessor struct{}
// ForceFlush implements trace.SpanProcessor.
func (s *stackTraceProcessor) ForceFlush(ctx context.Context) error {
return nil
}
// OnEnd implements trace.SpanProcessor.
func (*stackTraceProcessor) OnEnd(s sdktrace.ReadOnlySpan) {
}
// OnStart implements trace.SpanProcessor.
func (*stackTraceProcessor) OnStart(parent context.Context, s sdktrace.ReadWriteSpan) {
_, file, line, _ := runtime.Caller(2)
s.SetAttributes(attribute.String("caller", fmt.Sprintf("%s:%d", file, line)))
}
// Shutdown implements trace.SpanProcessor.
func (s *stackTraceProcessor) Shutdown(ctx context.Context) error {
return nil
}
func ForceFlush(ctx context.Context) error {
if tp, ok := trace.SpanFromContext(ctx).TracerProvider().(interface {
ForceFlush(context.Context) error
}); ok {
return tp.ForceFlush(context.Background())
}
return nil
for _, proc := range sys.exporterServer.SpanProcessors() {
options = append(options, sdktrace.WithSpanProcessor(proc))
}
if sys.DebugLevel >= 1 {
options = append(options,
sdktrace.WithSpanProcessor(&stackTraceProcessor{}),
)
}
tp := sdktrace.NewTracerProvider(options...)
sys.tpm.Add(tp)
return tp
}
func Shutdown(ctx context.Context) error {
_ = ForceFlush(ctx)
exporter := ctx.Value(exporterKey).(sdktrace.SpanExporter)
return exporter.Shutdown(context.Background())
// ShutdownContext tears down the tracing system stored in ctx: it flushes
// and shuts down every tracer provider, then shuts down the exporter server,
// joining any errors. Background contexts are used so shutdown proceeds even
// if ctx itself is already canceled.
func ShutdownContext(ctx context.Context) error {
	var errs []error
	sys := systemContextFromContext(ctx)
	if err := sys.tpm.ShutdownAll(context.Background()); err != nil {
		errs = append(errs, fmt.Errorf("(*tracerProviderManager).ShutdownAll: %w", err))
	}
	if err := sys.exporterServer.Shutdown(context.Background()); err != nil {
		errs = append(errs, fmt.Errorf("(*Server).Shutdown: %w", err))
	}
	return errors.Join(errs...)
}
func ExporterServerFromContext(ctx context.Context) coltracepb.TraceServiceServer {
return ctx.Value(serverKey).(coltracepb.TraceServiceServer)
return systemContextFromContext(ctx).exporterServer
}
const PomeriumCoreTracer = "pomerium.io/core"
// Continue starts a new child span of the current span in the context. If
// there is no span in the context, it creates a new trace and span.
// (The comment previously referred to a "StartSpan" function; Go doc
// comments must begin with the declared identifier.)
//
// The returned context contains the newly created span. You can use it to
// propagate the returned span in process.
func Continue(ctx context.Context, name string, o ...trace.SpanStartOption) (context.Context, trace.Span) {
	return trace.SpanFromContext(ctx).TracerProvider().Tracer(PomeriumCoreTracer).Start(ctx, name, o...)
}
// ParseTraceparent parses a W3C traceparent header value of the form
// "version-traceid-spanid-flags" into a trace.SpanContext. It returns an
// error if the value is malformed or any field fails to parse.
func ParseTraceparent(traceparent string) (trace.SpanContext, error) {
	parts := strings.Split(traceparent, "-")
	if len(parts) != 4 {
		return trace.SpanContext{}, errors.New("malformed traceparent")
	}
	traceID, err := trace.TraceIDFromHex(parts[1])
	if err != nil {
		return trace.SpanContext{}, err
	}
	spanID, err := trace.SpanIDFromHex(parts[2])
	if err != nil {
		return trace.SpanContext{}, err
	}
	// Trace flags are a single hex-encoded byte (e.g. "01"). The original
	// parsed with base 6, rejecting valid values such as "0a"; bit size 8
	// matches the field's one-byte width.
	traceFlags, err := strconv.ParseUint(parts[3], 16, 8)
	if err != nil {
		return trace.SpanContext{}, err
	}
	// Note: trace.TraceID and trace.SpanID are fixed-size arrays, so the
	// original len(...) checks here were always false and have been removed;
	// TraceIDFromHex/SpanIDFromHex already validate the hex lengths.
	return trace.NewSpanContext(trace.SpanContextConfig{
		TraceID:    traceID,
		SpanID:     spanID,
		TraceFlags: trace.TraceFlags(traceFlags),
	}), nil
}
// ReplaceTraceID returns traceparent with its trace-id field replaced by
// newTraceID. Input that does not have the four-field traceparent shape is
// returned unchanged.
func ReplaceTraceID(traceparent string, newTraceID trace.TraceID) string {
	fields := strings.Split(traceparent, "-")
	if len(fields) != 4 {
		return traceparent
	}
	fields[1] = hex.EncodeToString(newTraceID[:])
	return strings.Join(fields, "-")
}
// NewStatsHandler wraps base so that every call first passes its context
// through wrapContext, allowing an x-pomerium-traceparent metadata header to
// override the trace ID seen by the wrapped handler.
func NewStatsHandler(base stats.Handler) stats.Handler {
	return &wrapperStatsHandler{base: base}
}
// wrapperStatsHandler delegates all stats.Handler calls to base, rewriting
// each context via wrapContext first.
type wrapperStatsHandler struct {
	base stats.Handler // the wrapped handler; never nil when built via NewStatsHandler
}
// wrapContext inspects incoming gRPC metadata for both a standard
// "traceparent" header and a pomerium-specific "x-pomerium-traceparent"
// header. When both are present and non-empty, it rewrites the standard
// traceparent's trace ID with the one from the pomerium header, so spans
// recorded downstream join the pomerium-initiated trace. On missing
// metadata, missing headers, or a parse failure, ctx is returned unmodified.
func (w *wrapperStatsHandler) wrapContext(ctx context.Context) context.Context {
	md, ok := metadata.FromIncomingContext(ctx)
	if !ok {
		return ctx
	}
	traceparent := md.Get("traceparent")
	xPomeriumTraceparent := md.Get("x-pomerium-traceparent")
	if len(traceparent) > 0 && traceparent[0] != "" && len(xPomeriumTraceparent) > 0 && xPomeriumTraceparent[0] != "" {
		newTracectx, err := ParseTraceparent(xPomeriumTraceparent[0])
		if err != nil {
			// Malformed override header: leave the context as-is rather
			// than failing the RPC.
			return ctx
		}
		// Only the trace ID is replaced; span ID and flags of the original
		// traceparent are preserved.
		md.Set("traceparent", ReplaceTraceID(traceparent[0], newTracectx.TraceID()))
		return metadata.NewIncomingContext(ctx, md)
	}
	return ctx
}
// HandleConn implements stats.Handler, delegating to the wrapped handler
// with a rewritten context. (The parameter is renamed so it no longer
// shadows the stats package.)
func (w *wrapperStatsHandler) HandleConn(ctx context.Context, cs stats.ConnStats) {
	w.base.HandleConn(w.wrapContext(ctx), cs)
}
// HandleRPC implements stats.Handler, delegating to the wrapped handler
// with a rewritten context. (The parameter is renamed so it no longer
// shadows the stats package.)
func (w *wrapperStatsHandler) HandleRPC(ctx context.Context, rs stats.RPCStats) {
	w.base.HandleRPC(w.wrapContext(ctx), rs)
}
// TagConn implements stats.Handler, delegating to the wrapped handler with
// a rewritten context.
func (w *wrapperStatsHandler) TagConn(ctx context.Context, info *stats.ConnTagInfo) context.Context {
	wrapped := w.wrapContext(ctx)
	return w.base.TagConn(wrapped, info)
}
// TagRPC implements stats.Handler.
func (w *wrapperStatsHandler) TagRPC(ctx context.Context, info *stats.RPCTagInfo) context.Context {
return w.base.TagRPC(w.wrapContext(ctx), info)
// WaitForSpans blocks until the exporter server's span export queue has
// drained, or until maxDuration elapses, whichever comes first. The error,
// if any, comes from the underlying queue wait.
func WaitForSpans(ctx context.Context, maxDuration time.Duration) error {
	return systemContextFromContext(ctx).exporterServer.spanExportQueue.WaitForSpans(maxDuration)
}

View file

@ -0,0 +1,46 @@
package trace
import (
"encoding/hex"
"errors"
"strconv"
"strings"
"go.opentelemetry.io/otel/trace"
)
// ParseTraceparent parses a W3C traceparent header value of the form
// "version-traceid-spanid-flags" into a trace.SpanContext. It returns an
// error if the value is malformed or any field fails to parse.
func ParseTraceparent(traceparent string) (trace.SpanContext, error) {
	parts := strings.Split(traceparent, "-")
	if len(parts) != 4 {
		return trace.SpanContext{}, errors.New("malformed traceparent")
	}
	traceID, err := trace.TraceIDFromHex(parts[1])
	if err != nil {
		return trace.SpanContext{}, err
	}
	spanID, err := trace.SpanIDFromHex(parts[2])
	if err != nil {
		return trace.SpanContext{}, err
	}
	// Trace flags are a single hex-encoded byte (e.g. "01"). The original
	// parsed with base 6, rejecting valid values such as "0a"; bit size 8
	// matches the field's one-byte width.
	traceFlags, err := strconv.ParseUint(parts[3], 16, 8)
	if err != nil {
		return trace.SpanContext{}, err
	}
	// Note: trace.TraceID and trace.SpanID are fixed-size arrays, so the
	// original len(...) checks here were always false and have been removed;
	// TraceIDFromHex/SpanIDFromHex already validate the hex lengths.
	return trace.NewSpanContext(trace.SpanContextConfig{
		TraceID:    traceID,
		SpanID:     spanID,
		TraceFlags: trace.TraceFlags(traceFlags),
	}), nil
}
// ReplaceTraceID returns traceparent with its trace-id field replaced by
// newTraceID. Input that does not have the four-field traceparent shape is
// returned unchanged.
func ReplaceTraceID(traceparent string, newTraceID trace.TraceID) string {
	fields := strings.Split(traceparent, "-")
	if len(fields) != 4 {
		return traceparent
	}
	fields[1] = hex.EncodeToString(newTraceID[:])
	return strings.Join(fields, "-")
}

View file

@ -211,6 +211,7 @@ type environment struct {
logWriter *log.MultiWriter
tracerProvider oteltrace.TracerProvider
tracer oteltrace.Tracer
rootSpan oteltrace.Span
mods []WithCaller[Modifier]
tasks []WithCaller[Task]
@ -267,9 +268,10 @@ func Silent(silent ...bool) EnvironmentOption {
var setGrpcLoggerOnce sync.Once
var (
flagDebug = flag.Bool("env.debug", false, "enables test environment debug logging (equivalent to Debug() option)")
flagPauseOnFailure = flag.Bool("env.pause-on-failure", false, "enables pausing the test environment on failure (equivalent to PauseOnFailure() option)")
flagSilent = flag.Bool("env.silent", false, "suppresses all test environment output (equivalent to Silent() option)")
flagDebug = flag.Bool("env.debug", false, "enables test environment debug logging (equivalent to Debug() option)")
flagPauseOnFailure = flag.Bool("env.pause-on-failure", false, "enables pausing the test environment on failure (equivalent to PauseOnFailure() option)")
flagSilent = flag.Bool("env.silent", false, "suppresses all test environment output (equivalent to Silent() option)")
flagTraceDebugLevel = flag.Int("env.trace-debug-level", 0, "trace debug level")
)
func New(t testing.TB, opts ...EnvironmentOption) Environment {
@ -320,14 +322,16 @@ func New(t testing.TB, opts ...EnvironmentOption) Environment {
})
logger := zerolog.New(writer).With().Timestamp().Logger().Level(zerolog.DebugLevel)
ctx, cancel := context.WithCancelCause(logger.WithContext(trace.NewContext(context.Background())))
t.Cleanup(func() {
trace.Shutdown(ctx)
})
ctx := trace.Options{
DebugLevel: *flagTraceDebugLevel,
}.NewContext(context.Background())
ctx = logger.WithContext(ctx)
tracerProvider := trace.NewTracerProvider(ctx, "Test Environment")
tracer := tracerProvider.Tracer(trace.PomeriumCoreTracer)
ctx, span := tracer.Start(ctx, t.Name())
ctx, span := tracer.Start(ctx, t.Name(), oteltrace.WithNewRoot())
require.NoError(t, err)
ctx, cancel := context.WithCancelCause(ctx)
taskErrGroup, ctx := errgroup.WithContext(ctx)
e := &environment{
@ -352,14 +356,13 @@ func New(t testing.TB, opts ...EnvironmentOption) Environment {
ctx: ctx,
cancel: cancel,
tracerProvider: tracerProvider,
tracer: tracerProvider.Tracer(trace.PomeriumCoreTracer),
tracer: tracer,
logWriter: writer,
taskErrGroup: taskErrGroup,
stateChangeListeners: make(map[EnvironmentState][]func()),
rootSpan: span,
}
e.OnStateChanged(Stopped, func() {
span.End()
})
_, err = rand.Read(e.sharedSecret[:])
require.NoError(t, err)
_, err = rand.Read(e.cookieSecret[:])
@ -561,6 +564,7 @@ func (e *environment) Start() {
opts := []pomerium.Option{
pomerium.WithOverrideFileManager(fileMgr),
pomerium.WithEnvoyServerOptions(envoy.WithExitGracePeriod(10 * time.Second)),
}
envoyBinaryPath := filepath.Join(e.workspaceFolder, fmt.Sprintf("pkg/envoy/files/envoy-%s-%s", runtime.GOOS, runtime.GOARCH))
if envutil.EnvoyProfilerAvailable(envoyBinaryPath) {
@ -591,10 +595,7 @@ func (e *environment) Start() {
}
if len(envVars) > 0 {
e.debugf("adding envoy env vars: %v\n", envVars)
opts = append(opts, pomerium.WithEnvoyServerOptions(
envoy.WithExtraEnvVars(envVars...),
envoy.WithExitGracePeriod(10*time.Second), // allow envoy time to flush pprof data to disk
))
opts = append(opts, pomerium.WithEnvoyServerOptions(envoy.WithExtraEnvVars(envVars...)))
}
} else {
e.debugf("envoy profiling not available")
@ -602,7 +603,11 @@ func (e *environment) Start() {
pom := pomerium.New(opts...)
e.OnStateChanged(Stopping, func() {
pom.Shutdown()
if err := pom.Shutdown(ctx); err != nil {
log.Ctx(ctx).Err(err).Msg("error shutting down pomerium server")
} else {
e.debugf("pomerium server shut down without error")
}
})
pom.Start(ctx, e.tracerProvider, e.src)
return pom.Wait()
@ -742,6 +747,8 @@ func (e *environment) Stop() {
err := e.taskErrGroup.Wait()
e.advanceState(Stopped)
e.debugf("stop: done waiting")
e.rootSpan.End()
assert.NoError(e.t, trace.ShutdownContext(e.ctx))
assert.ErrorIs(e.t, err, ErrCauseManualStop)
})
}

View file

@ -0,0 +1,55 @@
package selftests_test
import (
"context"
"io"
"net/http"
"testing"
"github.com/pomerium/pomerium/config"
"github.com/pomerium/pomerium/internal/testenv"
"github.com/pomerium/pomerium/internal/testenv/scenarios"
"github.com/pomerium/pomerium/internal/testenv/snippets"
"github.com/pomerium/pomerium/internal/testenv/upstreams"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.opentelemetry.io/otel/trace"
)
// TestOTLPTracing exercises the full tracing pipeline: it stands up a test
// environment with an IDP and an upstream, performs an authenticated request
// inside a root span, and verifies the request succeeds.
func TestOTLPTracing(t *testing.T) {
	t.Setenv("OTEL_EXPORTER_OTLP_TRACES_ENDPOINT", "http://localhost:4317")
	env := testenv.New(t)
	defer env.Stop()

	env.Add(testenv.ModifierFunc(func(ctx context.Context, cfg *config.Config) {
		cfg.Options.ProxyLogLevel = config.LogLevelInfo
	}))

	up := upstreams.HTTP(nil, upstreams.WithDisplayName("Upstream"))
	up.Handle("/foo", func(w http.ResponseWriter, req *http.Request) {
		w.Write([]byte("OK"))
	})

	env.Add(scenarios.NewIDP([]*scenarios.User{
		{
			Email:     "foo@example.com",
			FirstName: "Firstname",
			LastName:  "Lastname",
		},
	}))

	// Fix: the original PPL policy string was malformed JSON — an "and"
	// array containing a bare `"email":{...}` entry instead of an object
	// wrapping the criterion.
	route := up.Route().
		From(env.SubdomainURL("foo")).
		PPL(`{"allow":{"and":[{"email":{"is":"foo@example.com"}}]}}`)
	env.AddUpstream(up)

	env.Start()
	snippets.WaitStartupComplete(env)

	ctx, span := env.Tracer().Start(env.Context(), "Authenticate", trace.WithNewRoot())
	resp, err := up.Get(route, upstreams.AuthenticateAs("foo@example.com"), upstreams.Path("/foo"), upstreams.Context(ctx))
	span.End()
	require.NoError(t, err)

	body, err := io.ReadAll(resp.Body)
	require.NoError(t, err)
	resp.Body.Close()
	// Fix: testify takes (t, expected, actual); the original had the
	// arguments reversed for the status-code assertion.
	assert.Equal(t, 200, resp.StatusCode)
	assert.Equal(t, "OK", string(body))
}

View file

@ -157,7 +157,7 @@ type httpUpstream struct {
clientCache sync.Map // map[testenv.Route]*http.Client
router *mux.Router
tracerProvider oteltrace.TracerProvider
tracerProvider values.MutableValue[oteltrace.TracerProvider]
}
var (
@ -176,6 +176,7 @@ func HTTP(tlsConfig values.Value[*tls.Config], opts ...HTTPUpstreamOption) HTTPU
serverPort: values.Deferred[int](),
router: mux.NewRouter(),
tlsConfig: tlsConfig,
tracerProvider: values.Deferred[oteltrace.TracerProvider](),
}
up.RecordCaller()
return up
@ -213,8 +214,8 @@ func (h *httpUpstream) Run(ctx context.Context) error {
if h.tlsConfig != nil {
tlsConfig = h.tlsConfig.Value()
}
h.router.Use(trace.NewHTTPMiddleware(otelhttp.WithTracerProvider(h.tracerProvider)))
h.tracerProvider = trace.NewTracerProvider(ctx, h.displayName)
h.tracerProvider.Resolve(trace.NewTracerProvider(ctx, h.displayName))
h.router.Use(trace.NewHTTPMiddleware(otelhttp.WithTracerProvider(h.tracerProvider.Value())))
server := &http.Server{
Handler: h.router,
@ -263,34 +264,6 @@ func (h *httpUpstream) Do(method string, r testenv.Route, opts ...RequestOption)
})
}
req, err := http.NewRequestWithContext(options.requestCtx, method, u.String(), nil)
if err != nil {
return nil, err
}
switch body := options.body.(type) {
case string:
req.Body = io.NopCloser(strings.NewReader(body))
case []byte:
req.Body = io.NopCloser(bytes.NewReader(body))
case io.Reader:
req.Body = io.NopCloser(body)
case proto.Message:
buf, err := proto.Marshal(body)
if err != nil {
return nil, err
}
req.Body = io.NopCloser(bytes.NewReader(buf))
req.Header.Set("Content-Type", "application/octet-stream")
default:
buf, err := json.Marshal(body)
if err != nil {
panic(fmt.Sprintf("unsupported body type: %T", body))
}
req.Body = io.NopCloser(bytes.NewReader(buf))
req.Header.Set("Content-Type", "application/json")
case nil:
}
newClient := func() *http.Client {
c := http.Client{
Transport: otelhttp.NewTransport(&http.Transport{
@ -299,7 +272,7 @@ func (h *httpUpstream) Do(method string, r testenv.Route, opts ...RequestOption)
Certificates: options.clientCerts,
},
},
otelhttp.WithTracerProvider(h.tracerProvider),
otelhttp.WithTracerProvider(h.tracerProvider.Value()),
otelhttp.WithSpanNameFormatter(func(operation string, r *http.Request) string {
return fmt.Sprintf("Client: %s %s", r.Method, r.URL.Path)
}),
@ -322,11 +295,38 @@ func (h *httpUpstream) Do(method string, r testenv.Route, opts ...RequestOption)
var resp *http.Response
if err := retry.Retry(options.requestCtx, "http", func(ctx context.Context) error {
var err error
req, err := http.NewRequestWithContext(options.requestCtx, method, u.String(), nil)
if err != nil {
return err
}
switch body := options.body.(type) {
case string:
req.Body = io.NopCloser(strings.NewReader(body))
case []byte:
req.Body = io.NopCloser(bytes.NewReader(body))
case io.Reader:
req.Body = io.NopCloser(body)
case proto.Message:
buf, err := proto.Marshal(body)
if err != nil {
return err
}
req.Body = io.NopCloser(bytes.NewReader(buf))
req.Header.Set("Content-Type", "application/octet-stream")
default:
buf, err := json.Marshal(body)
if err != nil {
panic(fmt.Sprintf("unsupported body type: %T", body))
}
req.Body = io.NopCloser(bytes.NewReader(buf))
req.Header.Set("Content-Type", "application/json")
case nil:
}
if options.authenticateAs != "" {
resp, err = authenticateFlow(ctx, client, req, options.authenticateAs) //nolint:bodyclose
resp, err = authenticateFlow(ctx, client, req, options.authenticateAs)
} else {
resp, err = client.Do(req) //nolint:bodyclose
resp, err = client.Do(req)
}
// retry on connection refused
if err != nil {
@ -338,6 +338,9 @@ func (h *httpUpstream) Do(method string, r testenv.Route, opts ...RequestOption)
return retry.NewTerminalError(err)
}
if resp.StatusCode/100 == 5 {
if err := resp.Body.Close(); err != nil {
panic(err)
}
oteltrace.SpanFromContext(ctx).AddEvent("Retrying on 5xx error", oteltrace.WithAttributes(
attribute.String("status", resp.Status),
))
@ -357,7 +360,6 @@ func authenticateFlow(ctx context.Context, client *http.Client, req *http.Reques
if err != nil {
return nil, err
}
location := res.Request.URL
if location.Hostname() == originalHostname {
// already authenticated

View file

@ -205,10 +205,13 @@ func (p *Pomerium) Start(ctx context.Context, tracerProvider oteltrace.TracerPro
return nil
}
func (p *Pomerium) Shutdown() error {
_ = p.envoyServer.Close() // this only errors if signaling envoy fails
func (p *Pomerium) Shutdown(ctx context.Context) error {
_ = trace.WaitForSpans(ctx, p.envoyServer.ExitGracePeriod())
var errs []error
errs = append(errs, p.envoyServer.Close()) // this only errors if signaling envoy fails
p.cancel(ErrShutdown)
return p.Wait()
errs = append(errs, p.Wait())
return errors.Join(errs...)
}
func (p *Pomerium) Wait() error {

View file

@ -61,6 +61,10 @@ type ServerOptions struct {
exitGracePeriod time.Duration
}
// ExitGracePeriod returns the configured grace period to wait after
// signaling envoy to exit.
func (opts *ServerOptions) ExitGracePeriod() time.Duration {
	return opts.exitGracePeriod
}
type ServerOption func(*ServerOptions)
func (o *ServerOptions) apply(opts ...ServerOption) {