move internal/telemetry/trace => pkg/telemetry/trace (#5541)

Joe Kralicky 2025-03-25 10:43:04 -04:00 committed by GitHub
parent ab5f3ac7f3
commit a96ab2fe93
GPG key ID: B5690EEEBB952194
49 changed files with 40 additions and 40 deletions


@@ -23,11 +23,11 @@ import (
"github.com/pomerium/pomerium/internal/events"
"github.com/pomerium/pomerium/internal/log"
"github.com/pomerium/pomerium/internal/registry"
"github.com/pomerium/pomerium/internal/telemetry/trace"
"github.com/pomerium/pomerium/internal/version"
derivecert_config "github.com/pomerium/pomerium/pkg/derivecert/config"
"github.com/pomerium/pomerium/pkg/envoy"
"github.com/pomerium/pomerium/pkg/envoy/files"
"github.com/pomerium/pomerium/pkg/telemetry/trace"
"github.com/pomerium/pomerium/proxy"
oteltrace "go.opentelemetry.io/otel/trace"
)


@@ -16,11 +16,11 @@ import (
"golang.org/x/oauth2"
"github.com/pomerium/pomerium/internal/httputil"
"github.com/pomerium/pomerium/internal/telemetry/trace"
"github.com/pomerium/pomerium/internal/urlutil"
"github.com/pomerium/pomerium/internal/version"
"github.com/pomerium/pomerium/pkg/identity/identity"
"github.com/pomerium/pomerium/pkg/identity/oauth"
"github.com/pomerium/pomerium/pkg/telemetry/trace"
)
// Name identifies the generic OpenID Connect provider.


@@ -0,0 +1,28 @@
package trace
import (
"net/url"
"go.opentelemetry.io/otel/propagation"
)
type PomeriumURLQueryCarrier url.Values
// Get implements propagation.TextMapCarrier.
func (q PomeriumURLQueryCarrier) Get(key string) string {
return url.Values(q).Get("pomerium_" + key)
}
// Set implements propagation.TextMapCarrier.
func (q PomeriumURLQueryCarrier) Set(key string, value string) {
url.Values(q).Set("pomerium_"+key, value)
}
// Keys implements propagation.TextMapCarrier.
func (q PomeriumURLQueryCarrier) Keys() []string {
// Keys is never called by the otel propagation code, and it is unclear
// how it would be implemented for this carrier.
panic("unimplemented")
}
var _ propagation.TextMapCarrier = PomeriumURLQueryCarrier{}
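// Illustrative usage sketch (not part of this commit): the carrier pairs with
// an OpenTelemetry propagator to round-trip trace context through
// "pomerium_"-prefixed query parameters. propagation.TraceContext only calls
// Get and Set, so the unimplemented Keys method is never reached. The function
// names are hypothetical, and a context import is assumed.
func injectTraceContextIntoQuery(ctx context.Context, q url.Values) {
	// Writes pomerium_traceparent (and pomerium_tracestate, if set) into q.
	propagation.TraceContext{}.Inject(ctx, PomeriumURLQueryCarrier(q))
}

func extractTraceContextFromQuery(ctx context.Context, q url.Values) context.Context {
	// Reads pomerium_traceparent / pomerium_tracestate back out of q.
	return propagation.TraceContext{}.Extract(ctx, PomeriumURLQueryCarrier(q))
}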


@@ -0,0 +1,30 @@
package trace_test
import (
"net/url"
"testing"
"github.com/pomerium/pomerium/pkg/telemetry/trace"
"github.com/stretchr/testify/assert"
)
func TestPomeriumURLQueryCarrier(t *testing.T) {
t.Parallel()
values := url.Values{}
carrier := trace.PomeriumURLQueryCarrier(values)
assert.Empty(t, carrier.Get("foo"))
carrier.Set("foo", "bar")
assert.Equal(t, url.Values{
"pomerium_foo": []string{"bar"},
}, values)
assert.Equal(t, "bar", carrier.Get("foo"))
carrier.Set("foo", "bar2")
assert.Equal(t, url.Values{
"pomerium_foo": []string{"bar2"},
}, values)
assert.Equal(t, "bar2", carrier.Get("foo"))
assert.Panics(t, func() {
carrier.Keys()
})
}


@@ -0,0 +1,301 @@
package trace
import (
"context"
"errors"
"fmt"
"net/url"
"os"
"strings"
"sync"
"time"
"github.com/pomerium/pomerium/config/otelconfig"
"go.opentelemetry.io/otel/exporters/otlp/otlptrace"
"go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc"
"go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp"
oteltrace "go.opentelemetry.io/otel/trace"
"go.opentelemetry.io/otel/trace/noop"
v1 "go.opentelemetry.io/proto/otlp/trace/v1"
)
var (
ErrNoClient = errors.New("no client")
ErrClientStopped = errors.New("client is stopped")
)
// SyncClient wraps an underlying [otlptrace.Client] which can be swapped out
// for a different client (e.g. in response to a config update) safely and in
// a way that does not lose spans.
type SyncClient interface {
otlptrace.Client
// Update safely replaces the current trace client with the one provided.
// The new client must be unstarted. The old client (if any) will be stopped.
//
// This function is NOT reentrant; callers must use appropriate locking.
Update(ctx context.Context, newClient otlptrace.Client) error
}
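// Illustrative usage sketch (not part of this commit), e.g. reacting to a
// config reload: build a new OTLP client from the updated options and swap it
// into the SyncClient in place. The function name and error handling are
// hypothetical.
func exampleSwapOnConfigChange(ctx context.Context, sc SyncClient, updated otelconfig.Config) error {
	newClient, err := NewTraceClientFromConfig(updated)
	if err != nil {
		// Keep the existing client if the new configuration is invalid.
		return err
	}
	// Update starts newClient, then stops and replaces the previous client;
	// concurrent UploadTraces calls block until the swap completes.
	return sc.Update(ctx, newClient)
}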
// NewSyncClient creates a new [SyncClient] with an initial underlying client.
//
// The client can be nil; if so, calling any method on the SyncClient will
// return ErrNoClient.
func NewSyncClient(client otlptrace.Client) SyncClient {
return &syncClient{
client: client,
}
}
type syncClient struct {
mu sync.Mutex
client otlptrace.Client
waitForNewClient chan struct{}
}
var _ SyncClient = (*syncClient)(nil)
// Start implements otlptrace.Client.
func (ac *syncClient) Start(ctx context.Context) error {
ac.mu.Lock()
defer ac.mu.Unlock()
if ac.waitForNewClient != nil {
panic("bug: Start called during Stop or Update")
}
if ac.client == nil {
return ErrNoClient
}
return ac.client.Start(ctx)
}
// Stop implements otlptrace.Client.
func (ac *syncClient) Stop(ctx context.Context) error {
ac.mu.Lock()
defer ac.mu.Unlock()
if ac.waitForNewClient != nil {
panic("bug: Stop called concurrently")
}
if ac.client == nil {
return ErrNoClient
}
return ac.resetLocked(ctx, nil)
}
func (ac *syncClient) resetLocked(ctx context.Context, newClient otlptrace.Client) error {
var stop func(context.Context) error
if ac.client != nil {
stop = ac.client.Stop
}
ac.waitForNewClient = make(chan struct{})
ac.mu.Unlock()
var err error
if stop != nil {
err = stop(ctx)
}
ac.mu.Lock()
close(ac.waitForNewClient)
ac.waitForNewClient = nil
ac.client = newClient
return err
}
// UploadTraces implements otlptrace.Client.
func (ac *syncClient) UploadTraces(ctx context.Context, protoSpans []*v1.ResourceSpans) error {
ac.mu.Lock()
if ac.waitForNewClient != nil {
wait := ac.waitForNewClient
ac.mu.Unlock()
select {
case <-wait:
ac.mu.Lock()
case <-ctx.Done():
return context.Cause(ctx)
}
} else if ac.client == nil {
ac.mu.Unlock()
return ErrNoClient
}
client := ac.client
ac.mu.Unlock()
if client == nil {
return ErrClientStopped
}
return client.UploadTraces(ctx, protoSpans)
}
func (ac *syncClient) Update(ctx context.Context, newClient otlptrace.Client) error {
if newClient != nil {
if err := newClient.Start(ctx); err != nil {
return fmt.Errorf("error starting new client: %w", err)
}
}
ac.mu.Lock()
defer ac.mu.Unlock()
if ac.waitForNewClient != nil {
panic("bug: Update called during Stop")
}
if newClient == ac.client {
return nil
}
return ac.resetLocked(ctx, newClient)
}
func NewTraceClientFromConfig(opts otelconfig.Config) (otlptrace.Client, error) {
if IsOtelSDKDisabled() {
return NoopClient{}, nil
}
if opts.OtelTracesExporter == nil {
return NoopClient{}, nil
}
switch *opts.OtelTracesExporter {
case "otlp":
var endpoint, protocol string
var signalSpecificEndpoint bool
if opts.OtelExporterOtlpTracesEndpoint != nil {
endpoint = *opts.OtelExporterOtlpTracesEndpoint
signalSpecificEndpoint = true
} else if opts.OtelExporterOtlpEndpoint != nil {
endpoint = *opts.OtelExporterOtlpEndpoint
signalSpecificEndpoint = false
}
if opts.OtelExporterOtlpTracesProtocol != nil {
protocol = *opts.OtelExporterOtlpTracesProtocol
} else if opts.OtelExporterOtlpProtocol != nil {
protocol = *opts.OtelExporterOtlpProtocol
}
if protocol == "" {
protocol = BestEffortProtocolFromOTLPEndpoint(endpoint, signalSpecificEndpoint)
}
var headersList []string
if len(opts.OtelExporterOtlpTracesHeaders) > 0 {
headersList = opts.OtelExporterOtlpTracesHeaders
} else if len(opts.OtelExporterOtlpHeaders) > 0 {
headersList = opts.OtelExporterOtlpHeaders
}
headers := map[string]string{}
for _, kv := range headersList {
k, v, ok := strings.Cut(kv, "=")
if ok {
headers[k] = v
}
}
defaultTimeout := 10 * time.Second // otel default (not exported)
if opts.OtelExporterOtlpTimeout != nil {
defaultTimeout = max(0, time.Duration(*opts.OtelExporterOtlpTimeout)*time.Millisecond)
}
switch strings.ToLower(strings.TrimSpace(protocol)) {
case "grpc":
return otlptracegrpc.NewClient(
otlptracegrpc.WithEndpointURL(endpoint),
otlptracegrpc.WithHeaders(headers),
otlptracegrpc.WithTimeout(defaultTimeout),
), nil
case "http/protobuf", "":
return otlptracehttp.NewClient(
otlptracehttp.WithEndpointURL(endpoint),
otlptracehttp.WithHeaders(headers),
otlptracehttp.WithTimeout(defaultTimeout),
), nil
default:
return nil, fmt.Errorf(`unknown otlp trace exporter protocol %q, expected one of ["grpc", "http/protobuf"]`, protocol)
}
case "none", "noop", "":
return NoopClient{}, nil
default:
return nil, fmt.Errorf(`unknown otlp trace exporter %q, expected one of ["otlp", "none"]`, *opts.OtelTracesExporter)
}
}
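// Illustrative sketch (not part of this commit): building a client directly
// from an otelconfig.Config. The generic ptr helper is hypothetical, and the
// pointer-to-string field types are assumed from how they are dereferenced
// above. With the protocol left unset, BestEffortProtocolFromOTLPEndpoint
// infers "grpc" from the well-known port 4317.
func ptr[T any](v T) *T { return &v }

func exampleTraceClient() (otlptrace.Client, error) {
	return NewTraceClientFromConfig(otelconfig.Config{
		OtelTracesExporter:             ptr("otlp"),
		OtelExporterOtlpTracesEndpoint: ptr("http://127.0.0.1:4317"),
	})
}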
func BestEffortProtocolFromOTLPEndpoint(endpoint string, specificEnv bool) string {
if endpoint == "" {
return ""
}
u, err := url.Parse(endpoint)
if err != nil {
return ""
}
switch u.Port() {
case "4318":
return "http/protobuf"
case "4317":
return "grpc"
default:
// For http, if the signal-specific form of the endpoint env variable
// (e.g. $OTEL_EXPORTER_OTLP_TRACES_ENDPOINT) is used, the /v1/<signal>
// ^^^^^^
// path must be present. Otherwise, the path must _not_ be present,
// because the sdk will add it.
// This doesn't apply to grpc endpoints, so assume grpc if there is a
// conflict here.
hasPath := len(strings.Trim(u.Path, "/")) > 0
switch {
case hasPath && specificEnv:
return "http/protobuf"
case !hasPath && specificEnv:
return "grpc"
case hasPath && !specificEnv:
// would be invalid for http, so assume it's grpc on a subpath
return "grpc"
case !hasPath && !specificEnv:
// could be either, but default to http
return "http/protobuf"
}
panic("unreachable")
}
}
type NoopClient struct{}
// Start implements otlptrace.Client.
func (n NoopClient) Start(context.Context) error {
return nil
}
// Stop implements otlptrace.Client.
func (n NoopClient) Stop(context.Context) error {
return nil
}
// UploadTraces implements otlptrace.Client.
func (n NoopClient) UploadTraces(context.Context, []*v1.ResourceSpans) error {
return nil
}
// ValidNoopSpan is the same as noop.Span, except with a "valid" span context
// (has a non-zero trace and span ID).
//
// Adding this into a context as follows:
//
// ctx = oteltrace.ContextWithSpan(ctx, trace.ValidNoopSpan{})
//
// will prevent some usages of the global tracer provider by libraries such
// as otelhttp, which only uses the global provider if the context's span
// is "invalid".
type ValidNoopSpan struct {
noop.Span
}
var noopTraceID = oteltrace.TraceID{
0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF,
0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF,
}
var noopSpanID = oteltrace.SpanID{
0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF,
}
// SpanContext implements trace.Span.
func (n ValidNoopSpan) SpanContext() oteltrace.SpanContext {
return n.Span.SpanContext().WithTraceID(noopTraceID).WithSpanID(noopSpanID)
}
var _ oteltrace.Span = ValidNoopSpan{}
func IsOtelSDKDisabled() bool {
return os.Getenv("OTEL_SDK_DISABLED") == "true"
}


@@ -0,0 +1,580 @@
package trace_test
import (
"context"
"fmt"
"os"
"path/filepath"
"runtime"
"strings"
"sync/atomic"
"testing"
"time"
"github.com/pomerium/pomerium/config"
"github.com/pomerium/pomerium/internal/log"
"github.com/pomerium/pomerium/internal/testenv"
"github.com/pomerium/pomerium/internal/testenv/scenarios"
"github.com/pomerium/pomerium/internal/testenv/snippets"
. "github.com/pomerium/pomerium/internal/testutil/tracetest" //nolint:revive
"github.com/pomerium/pomerium/internal/testutil/tracetest/mock_otlptrace"
"github.com/pomerium/pomerium/internal/version"
"github.com/pomerium/pomerium/pkg/telemetry/trace"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.opentelemetry.io/otel"
tracev1 "go.opentelemetry.io/proto/otlp/trace/v1"
"go.uber.org/mock/gomock"
)
func TestSyncClient(t *testing.T) {
t.Run("No client", func(t *testing.T) {
sc := trace.NewSyncClient(nil)
assert.ErrorIs(t, sc.Start(context.Background()), trace.ErrNoClient)
assert.ErrorIs(t, sc.UploadTraces(context.Background(), nil), trace.ErrNoClient)
assert.ErrorIs(t, sc.Stop(context.Background()), trace.ErrNoClient)
})
t.Run("Valid client", func(t *testing.T) {
ctrl := gomock.NewController(t)
mockClient := mock_otlptrace.NewMockClient(ctrl)
start := mockClient.EXPECT().
Start(gomock.Any()).
Return(nil)
upload := mockClient.EXPECT().
UploadTraces(gomock.Any(), gomock.Any()).
Return(nil).
After(start)
mockClient.EXPECT().
Stop(gomock.Any()).
Return(nil).
After(upload)
sc := trace.NewSyncClient(mockClient)
assert.NoError(t, sc.Start(context.Background()))
assert.NoError(t, sc.UploadTraces(context.Background(), []*tracev1.ResourceSpans{}))
assert.NoError(t, sc.Stop(context.Background()))
})
t.Run("Update", func(t *testing.T) {
ctrl := gomock.NewController(t)
mockClient1 := mock_otlptrace.NewMockClient(ctrl)
mockClient2 := mock_otlptrace.NewMockClient(ctrl)
start1 := mockClient1.EXPECT().
Start(gomock.Any()).
Return(nil)
upload1 := mockClient1.EXPECT().
UploadTraces(gomock.Any(), gomock.Any()).
Return(nil).
After(start1)
start2 := mockClient2.EXPECT().
Start(gomock.Any()).
Return(nil).
After(upload1)
stop1 := mockClient1.EXPECT().
Stop(gomock.Any()).
Return(nil).
After(start2)
upload2 := mockClient2.EXPECT().
UploadTraces(gomock.Any(), gomock.Any()).
Return(nil).
After(stop1)
mockClient2.EXPECT().
Stop(gomock.Any()).
Return(nil).
After(upload2)
sc := trace.NewSyncClient(mockClient1)
assert.NoError(t, sc.Start(context.Background()))
assert.NoError(t, sc.UploadTraces(context.Background(), []*tracev1.ResourceSpans{}))
assert.NoError(t, sc.Update(context.Background(), mockClient2))
assert.NoError(t, sc.UploadTraces(context.Background(), []*tracev1.ResourceSpans{}))
assert.NoError(t, sc.Stop(context.Background()))
})
t.Run("Update from nil client to non-nil client", func(t *testing.T) {
ctrl := gomock.NewController(t)
sc := trace.NewSyncClient(nil)
mockClient := mock_otlptrace.NewMockClient(ctrl)
start := mockClient.EXPECT().
Start(gomock.Any()).
Return(nil)
upload := mockClient.EXPECT().
UploadTraces(gomock.Any(), gomock.Any()).
Return(nil).
After(start)
mockClient.EXPECT().
Stop(gomock.Any()).
Return(nil).
After(upload)
assert.NoError(t, sc.Update(context.Background(), mockClient))
assert.NoError(t, sc.UploadTraces(context.Background(), []*tracev1.ResourceSpans{}))
assert.NoError(t, sc.Stop(context.Background()))
})
t.Run("Update from non-nil client to nil client", func(t *testing.T) {
ctrl := gomock.NewController(t)
sc := trace.NewSyncClient(nil)
{
mockClient := mock_otlptrace.NewMockClient(ctrl)
start := mockClient.EXPECT().
Start(gomock.Any()).
Return(nil)
mockClient.EXPECT().
Stop(gomock.Any()).
Return(nil).
After(start)
assert.NoError(t, sc.Update(context.Background(), mockClient))
}
sc.Update(context.Background(), nil)
assert.ErrorIs(t, sc.UploadTraces(context.Background(), []*tracev1.ResourceSpans{}), trace.ErrNoClient)
})
spinWait := func(counter *atomic.Int32, until int32) error {
startTime := time.Now()
for counter.Load() != until {
if time.Since(startTime) > 1*time.Second {
return fmt.Errorf("timed out waiting for counter to equal %d", until)
}
}
return nil
}
t.Run("Concurrent UploadTraces", func(t *testing.T) {
ctrl := gomock.NewController(t)
mockClient1 := mock_otlptrace.NewMockClient(ctrl)
count := atomic.Int32{}
unlock := make(chan struct{})
concurrency := min(runtime.NumCPU(), 4)
mockClient1.EXPECT().
UploadTraces(gomock.Any(), gomock.Any()).
DoAndReturn(func(context.Context, []*tracev1.ResourceSpans) error {
count.Add(1)
defer count.Add(-1)
<-unlock
return nil
}).
Times(concurrency)
sc := trace.NewSyncClient(mockClient1)
start := make(chan struct{})
for range concurrency {
go func() {
runtime.LockOSThread()
defer runtime.UnlockOSThread()
<-start
require.NoError(t, sc.UploadTraces(context.Background(), []*tracev1.ResourceSpans{}))
}()
}
runtime.LockOSThread()
defer runtime.UnlockOSThread()
close(start)
assert.NoError(t, spinWait(&count, int32(concurrency)))
})
t.Run("Concurrent Update/UploadTraces", func(t *testing.T) {
runtime.LockOSThread()
defer runtime.UnlockOSThread()
ctrl := gomock.NewController(t)
mockClient1 := mock_otlptrace.NewMockClient(ctrl)
mockClient2 := mock_otlptrace.NewMockClient(ctrl)
uploadTracesCount1 := atomic.Int32{}
uploadTracesCount2 := atomic.Int32{}
unlock1 := make(chan struct{})
unlock2 := make(chan struct{})
waitForStop := make(chan struct{})
concurrency := min(runtime.NumCPU(), 4)
// start 1 -> upload 1 -> start 2 -> stop 1 -> upload 2 -> stop 2
fStart1 := mockClient1.EXPECT().
Start(gomock.Any()).
Return(nil)
fUpload1 := mockClient1.EXPECT().
UploadTraces(gomock.Any(), gomock.Any()).
DoAndReturn(func(context.Context, []*tracev1.ResourceSpans) error {
// called from non-test threads
uploadTracesCount1.Add(1)
defer uploadTracesCount1.Add(-1)
<-unlock1
return nil
}).
Times(concurrency).
After(fStart1)
fStart2 := mockClient2.EXPECT().
Start(gomock.Any()).
Return(nil).
After(fUpload1)
fStop1 := mockClient1.EXPECT().
Stop(gomock.Any()).
DoAndReturn(func(context.Context) error {
// called from test thread
close(unlock1)
assert.NoError(t, spinWait(&uploadTracesCount1, 0))
return nil
}).
After(fStart2)
fUpload2 := mockClient2.EXPECT().
UploadTraces(gomock.Any(), gomock.Any()).
DoAndReturn(func(context.Context, []*tracev1.ResourceSpans) error {
// called from non-test threads
uploadTracesCount2.Add(1)
defer uploadTracesCount2.Add(-1)
<-unlock2
return nil
}).
Times(concurrency).
After(fStop1)
mockClient2.EXPECT().
Stop(gomock.Any()).
DoAndReturn(func(context.Context) error {
// called from test thread
close(unlock2)
assert.NoError(t, spinWait(&uploadTracesCount2, 0))
close(waitForStop)
// no way around sleeping here - we have to give the other threads time
// to call UploadTraces and block waiting on waitForNewClient to be
// closed, which happens after this function returns
time.Sleep(10 * time.Millisecond)
return nil
}).
After(fUpload2)
sc := trace.NewSyncClient(mockClient1)
require.NoError(t, sc.Start(context.Background()))
for range concurrency {
go func() {
require.NoError(t, sc.UploadTraces(context.Background(), []*tracev1.ResourceSpans{}))
}()
}
require.NoError(t, spinWait(&uploadTracesCount1, int32(concurrency)))
// at this point, all calls to UploadTraces for client1 are blocked
for range concurrency {
go func() {
<-unlock1 // wait for client1.Stop
// after this, calls to UploadTraces will block waiting for the
// new client, instead of using the old one we're about to close
require.NoError(t, sc.UploadTraces(context.Background(), []*tracev1.ResourceSpans{}))
}()
}
require.NoError(t, sc.Update(context.Background(), mockClient2))
require.NoError(t, spinWait(&uploadTracesCount2, int32(concurrency)))
// at this point, all calls to UploadTraces for client2 are blocked.
// while SyncClient is waiting for the underlying client to stop during
// sc.Stop(), *new* calls to sc.UploadTraces will wait for it to stop, then
// error with trace.ErrClientStopped, but the previous calls blocked in
// client2 will complete without error.
for range concurrency {
go func() {
<-waitForStop
assert.ErrorIs(t, sc.UploadTraces(context.Background(), []*tracev1.ResourceSpans{}), trace.ErrClientStopped)
}()
}
assert.NoError(t, sc.Stop(context.Background()))
// sanity checks
assert.ErrorIs(t, sc.UploadTraces(context.Background(), []*tracev1.ResourceSpans{}), trace.ErrNoClient)
assert.ErrorIs(t, sc.Start(context.Background()), trace.ErrNoClient)
assert.ErrorIs(t, sc.Stop(context.Background()), trace.ErrNoClient)
assert.NoError(t, sc.Update(context.Background(), nil))
})
}
type errHandler struct {
err error
}
var _ otel.ErrorHandler = (*errHandler)(nil)
func (h *errHandler) Handle(err error) {
h.err = err
}
func TestNewTraceClientFromConfig(t *testing.T) {
env := testenv.New(t, testenv.WithTraceDebugFlags(testenv.StandardTraceDebugFlags))
receiver := scenarios.NewOTLPTraceReceiver()
env.Add(receiver)
grpcEndpoint := receiver.GRPCEndpointURL()
httpEndpoint := receiver.HTTPEndpointURL()
emptyConfigFilePath := filepath.Join(env.TempDir(), "empty_config.yaml")
require.NoError(t, os.WriteFile(emptyConfigFilePath, []byte("{}"), 0o644))
env.Start()
snippets.WaitStartupComplete(env)
for _, tc := range []struct {
name string
env map[string]string
newClientErr string
uploadErr bool
expectNoSpans bool
expectHeaders map[string][]string
}{
{
name: "GRPC endpoint, unset protocol",
env: map[string]string{
"OTEL_TRACES_EXPORTER": "otlp",
"OTEL_EXPORTER_OTLP_TRACES_ENDPOINT": grpcEndpoint.Value(),
},
},
{
name: "GRPC endpoint, empty protocol",
env: map[string]string{
"OTEL_TRACES_EXPORTER": "otlp",
"OTEL_EXPORTER_OTLP_TRACES_ENDPOINT": grpcEndpoint.Value(),
"OTEL_EXPORTER_OTLP_TRACES_PROTOCOL": "",
},
},
{
name: "GRPC endpoint, alternate env, unset protocol",
env: map[string]string{
"OTEL_TRACES_EXPORTER": "otlp",
"OTEL_EXPORTER_OTLP_ENDPOINT": grpcEndpoint.Value(),
},
uploadErr: true,
},
{
name: "GRPC endpoint, alternate env, empty protocol",
env: map[string]string{
"OTEL_TRACES_EXPORTER": "otlp",
"OTEL_EXPORTER_OTLP_ENDPOINT": grpcEndpoint.Value(),
"OTEL_EXPORTER_OTLP_PROTOCOL": "",
},
uploadErr: true,
},
{
name: "HTTP endpoint, unset protocol",
env: map[string]string{
"OTEL_TRACES_EXPORTER": "otlp",
"OTEL_EXPORTER_OTLP_TRACES_ENDPOINT": httpEndpoint.Value(),
},
},
{
name: "HTTP endpoint, empty protocol",
env: map[string]string{
"OTEL_TRACES_EXPORTER": "otlp",
"OTEL_EXPORTER_OTLP_TRACES_ENDPOINT": httpEndpoint.Value(),
"OTEL_EXPORTER_OTLP_TRACES_PROTOCOL": "",
},
},
{
name: "HTTP endpoint, alternate env, unset protocol",
env: map[string]string{
"OTEL_TRACES_EXPORTER": "otlp",
"OTEL_EXPORTER_OTLP_ENDPOINT": strings.TrimSuffix(httpEndpoint.Value(), "/v1/traces"), // path is added automatically by the sdk here
},
},
{
name: "HTTP endpoint, alternate env, empty protocol",
env: map[string]string{
"OTEL_TRACES_EXPORTER": "otlp",
"OTEL_EXPORTER_OTLP_ENDPOINT": strings.TrimSuffix(httpEndpoint.Value(), "/v1/traces"),
"OTEL_EXPORTER_OTLP_PROTOCOL": "",
},
},
{
name: "GRPC endpoint, explicit protocol",
env: map[string]string{
"OTEL_TRACES_EXPORTER": "otlp",
"OTEL_EXPORTER_OTLP_TRACES_ENDPOINT": grpcEndpoint.Value(),
"OTEL_EXPORTER_OTLP_TRACES_PROTOCOL": "grpc",
},
},
{
name: "HTTP endpoint, explicit protocol",
env: map[string]string{
"OTEL_TRACES_EXPORTER": "otlp",
"OTEL_EXPORTER_OTLP_TRACES_ENDPOINT": httpEndpoint.Value(),
"OTEL_EXPORTER_OTLP_TRACES_PROTOCOL": "http/protobuf",
},
},
{
name: "exporter unset",
env: map[string]string{
"OTEL_TRACES_EXPORTER": "",
"OTEL_EXPORTER_OTLP_TRACES_ENDPOINT": httpEndpoint.Value(),
"OTEL_EXPORTER_OTLP_TRACES_PROTOCOL": "http/protobuf",
},
expectNoSpans: true,
},
{
name: "exporter noop",
env: map[string]string{
"OTEL_TRACES_EXPORTER": "noop",
"OTEL_EXPORTER_OTLP_TRACES_ENDPOINT": httpEndpoint.Value(),
"OTEL_EXPORTER_OTLP_TRACES_PROTOCOL": "http/protobuf",
},
expectNoSpans: true,
},
{
name: "exporter none",
env: map[string]string{
"OTEL_TRACES_EXPORTER": "none",
"OTEL_EXPORTER_OTLP_TRACES_ENDPOINT": httpEndpoint.Value(),
"OTEL_EXPORTER_OTLP_TRACES_PROTOCOL": "http/protobuf",
},
expectNoSpans: true,
},
{
name: "invalid exporter",
env: map[string]string{
"OTEL_TRACES_EXPORTER": "invalid",
},
newClientErr: `unknown otlp trace exporter "invalid", expected one of ["otlp", "none"]`,
},
{
name: "invalid protocol",
env: map[string]string{
"OTEL_TRACES_EXPORTER": "otlp",
"OTEL_EXPORTER_OTLP_TRACES_ENDPOINT": grpcEndpoint.Value(),
"OTEL_EXPORTER_OTLP_TRACES_PROTOCOL": "invalid",
},
newClientErr: `unknown otlp trace exporter protocol "invalid", expected one of ["grpc", "http/protobuf"]`,
},
{
name: "valid configuration, but sdk disabled",
env: map[string]string{
"OTEL_TRACES_EXPORTER": "otlp",
"OTEL_EXPORTER_OTLP_TRACES_ENDPOINT": grpcEndpoint.Value(),
"OTEL_EXPORTER_OTLP_TRACES_PROTOCOL": "grpc",
"OTEL_SDK_DISABLED": "true",
},
expectNoSpans: true,
},
{
name: "valid configuration, wrong value for sdk disabled env",
env: map[string]string{
"OTEL_TRACES_EXPORTER": "otlp",
"OTEL_EXPORTER_OTLP_TRACES_ENDPOINT": grpcEndpoint.Value(),
"OTEL_EXPORTER_OTLP_TRACES_PROTOCOL": "grpc",
"OTEL_SDK_DISABLED": "1", // only "true" works according to the spec
},
},
{
name: "endpoint variable precedence",
env: map[string]string{
"OTEL_TRACES_EXPORTER": "otlp",
"OTEL_EXPORTER_OTLP_ENDPOINT": "invalid",
"OTEL_EXPORTER_OTLP_TRACES_ENDPOINT": grpcEndpoint.Value(), // should take precedence
"OTEL_EXPORTER_OTLP_PROTOCOL": "grpc",
},
},
{
name: "protocol variable precedence",
env: map[string]string{
"OTEL_TRACES_EXPORTER": "otlp",
"OTEL_EXPORTER_OTLP_PROTOCOL": "invalid",
"OTEL_EXPORTER_OTLP_TRACES_PROTOCOL": "grpc", // should take precedence
"OTEL_EXPORTER_OTLP_ENDPOINT": grpcEndpoint.Value(),
},
},
{
name: "valid exporter, trace headers",
env: map[string]string{
"OTEL_TRACES_EXPORTER": "otlp",
"OTEL_EXPORTER_OTLP_TRACES_ENDPOINT": httpEndpoint.Value(),
"OTEL_EXPORTER_OTLP_TRACES_PROTOCOL": "http/protobuf",
"OTEL_EXPORTER_OTLP_TRACES_HEADERS": "foo=bar,bar=baz",
},
expectHeaders: map[string][]string{
"foo": {"bar"},
"bar": {"baz"},
},
},
{
name: "valid exporter, alt headers",
env: map[string]string{
"OTEL_TRACES_EXPORTER": "otlp",
"OTEL_EXPORTER_OTLP_TRACES_ENDPOINT": httpEndpoint.Value(),
"OTEL_EXPORTER_OTLP_TRACES_PROTOCOL": "http/protobuf",
"OTEL_EXPORTER_OTLP_HEADERS": "foo=bar,bar=baz",
},
expectHeaders: map[string][]string{
"foo": {"bar"},
"bar": {"baz"},
},
},
{
name: "headers variable precedence",
env: map[string]string{
"OTEL_TRACES_EXPORTER": "otlp",
"OTEL_EXPORTER_OTLP_TRACES_ENDPOINT": httpEndpoint.Value(),
"OTEL_EXPORTER_OTLP_TRACES_PROTOCOL": "http/protobuf",
"OTEL_EXPORTER_OTLP_HEADERS": "a=1,b=2,c=3",
"OTEL_EXPORTER_OTLP_TRACES_HEADERS": "a=2,d=4",
},
expectHeaders: map[string][]string{
"a": {"2"},
"d": {"4"},
},
},
} {
t.Run(tc.name, func(t *testing.T) {
for k, v := range tc.env {
t.Setenv(k, v)
}
cfg, err := config.NewFileOrEnvironmentSource(context.Background(), emptyConfigFilePath, version.FullVersion())
require.NoError(t, err)
remoteClient, err := trace.NewTraceClientFromConfig(cfg.GetConfig().Options.Tracing)
if tc.newClientErr != "" {
assert.ErrorContains(t, err, tc.newClientErr)
return
}
require.NoError(t, err)
ctx := trace.NewContext(log.Ctx(env.Context()).WithContext(context.Background()), remoteClient)
tp := trace.NewTracerProvider(ctx, t.Name())
_, span := tp.Tracer(trace.PomeriumCoreTracer).Start(ctx, "test span")
span.End()
if tc.uploadErr {
assert.Error(t, trace.ForceFlush(ctx))
assert.NoError(t, trace.ShutdownContext(ctx))
return
}
assert.NoError(t, trace.ShutdownContext(ctx))
if tc.expectHeaders != nil {
for _, req := range receiver.ReceivedRequests() {
assert.Subset(t, req.Metadata, tc.expectHeaders, "missing expected headers")
}
}
results := NewTraceResults(receiver.FlushResourceSpans())
if tc.expectNoSpans {
results.MatchTraces(t, MatchOptions{Exact: true})
} else {
results.MatchTraces(t, MatchOptions{
Exact: true,
}, Match{Name: t.Name() + ": test span", TraceCount: 1, Services: []string{t.Name()}})
}
})
}
}
func TestBestEffortProtocolFromOTLPEndpoint(t *testing.T) {
t.Run("Well-known port numbers", func(t *testing.T) {
assert.Equal(t, "grpc", trace.BestEffortProtocolFromOTLPEndpoint("http://127.0.0.1:4317", true))
assert.Equal(t, "http/protobuf", trace.BestEffortProtocolFromOTLPEndpoint("http://127.0.0.1:4318", true))
})
t.Run("path presence", func(t *testing.T) {
assert.Equal(t, "http/protobuf", trace.BestEffortProtocolFromOTLPEndpoint("http://127.0.0.1:12345", false))
assert.Equal(t, "grpc", trace.BestEffortProtocolFromOTLPEndpoint("http://127.0.0.1:12345", true))
assert.Equal(t, "grpc", trace.BestEffortProtocolFromOTLPEndpoint("http://127.0.0.1:12345/v1/traces", false))
assert.Equal(t, "http/protobuf", trace.BestEffortProtocolFromOTLPEndpoint("http://127.0.0.1:12345/v1/traces", true))
})
t.Run("invalid inputs", func(t *testing.T) {
assert.Equal(t, "", trace.BestEffortProtocolFromOTLPEndpoint("", false))
assert.Equal(t, "", trace.BestEffortProtocolFromOTLPEndpoint("http://\x7f", false))
})
}


@@ -0,0 +1,524 @@
package trace
import (
"context"
"encoding/binary"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
"io"
"os"
"runtime"
"slices"
"strings"
"sync"
"time"
"go.opentelemetry.io/otel/attribute"
sdktrace "go.opentelemetry.io/otel/sdk/trace"
oteltrace "go.opentelemetry.io/otel/trace"
coltracepb "go.opentelemetry.io/proto/otlp/collector/trace/v1"
"google.golang.org/grpc/metadata"
"google.golang.org/protobuf/encoding/protojson"
)
type DebugFlags uint32
const (
// If set, adds the "caller" attribute to each span with the source location
// where the span was started.
TrackSpanCallers = (1 << iota)
// If set, keeps track of all span references and will attempt to wait for
// all traces to complete when shutting down a trace context.
// Use with caution: this will cause memory usage to grow over time.
TrackSpanReferences = (1 << iota)
// If set, keeps track of all observed spans, including span context and
// all attributes.
// Use with caution: this will cause memory usage to grow significantly
// over time.
TrackAllSpans = (1 << iota) | TrackSpanCallers
// If set, will log all trace IDs and their span counts on close.
//
// Enables [TrackAllSpans]
LogTraceIDs = (1 << iota) | TrackAllSpans
// If set, will log all spans observed by the exporter on close. These spans
// may belong to incomplete traces.
//
// Enables [TrackAllSpans]
LogAllSpans = (1 << iota) | TrackAllSpans
// If set, will log all exported spans when a warning is issued on close
// (requires warning flags to also be set)
//
// Enables [TrackAllSpans]
LogAllSpansOnWarn = (1 << iota) | TrackAllSpans
// If set, will log all trace ID mappings when a warning is issued on close.
// (requires warning flags to also be set)
LogTraceIDsOnWarn = (1 << iota)
// If set, will print a warning to stderr on close if there are any incomplete
// traces (traces with no observed root spans)
WarnOnIncompleteTraces = (1 << iota)
// If set, will print a warning to stderr on close if there are any incomplete
// spans (spans started, but not ended)
WarnOnIncompleteSpans = (1 << iota)
// If set, will print a warning to stderr on close if there are any spans
// which reference unknown parent spans.
//
// Enables [TrackSpanReferences]
WarnOnUnresolvedReferences = (1 << iota) | TrackSpanReferences
// If set, configures Envoy to flush every span individually, disabling its
// internal buffer.
EnvoyFlushEverySpan = (1 << iota)
)
func (df DebugFlags) Check(flags DebugFlags) bool {
return (df & flags) == flags
}
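// Illustrative sketch (not part of this commit) of how the composite flags
// interact with Check: LogAllSpans implies TrackAllSpans, which in turn
// implies TrackSpanCallers, so all three checks succeed for the same value.
func exampleDebugFlags() {
	flags := DebugFlags(LogAllSpans | WarnOnIncompleteSpans)
	_ = flags.Check(LogAllSpans)         // true
	_ = flags.Check(TrackAllSpans)       // true: implied by LogAllSpans
	_ = flags.Check(TrackSpanCallers)    // true: implied by TrackAllSpans
	_ = flags.Check(TrackSpanReferences) // false: not implied by these flags
}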
var (
ErrIncompleteSpans = errors.New("exporter shut down with incomplete spans")
ErrMissingParentSpans = errors.New("exporter shut down with missing parent spans")
)
// WaitForSpans will block up to the given max duration and wait for all
// in-flight spans from tracers created with the given context to end. This
// function can be called more than once, and is safe to call from multiple
// goroutines in parallel.
//
// This requires the [TrackSpanReferences] debug flag to have been set with
// [Options.NewContext]. Otherwise, this function is a no-op and will return
// immediately.
//
// If this function blocks for more than 10 seconds, it will print a warning
// to stderr containing a list of span IDs it is waiting for, and the IDs of
// their parents (if known). Additionally, if the [TrackAllSpans] debug flag
// is set, details about parent spans will be displayed, including call site
// and trace ID.
func WaitForSpans(ctx context.Context, maxDuration time.Duration) error {
if sys := systemContextFromContext(ctx); sys != nil && sys.observer != nil {
done := make(chan struct{})
go func() {
defer close(done)
sys.observer.wait(10 * time.Second)
}()
select {
case <-done:
return nil
case <-time.After(maxDuration):
return ErrMissingParentSpans
}
}
return nil
}
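// Illustrative usage sketch (not part of this commit): draining in-flight
// spans at shutdown. This only has an effect when the TrackSpanReferences
// debug flag was set when the trace context was created; otherwise
// WaitForSpans returns immediately. The function name is hypothetical.
func exampleDrainSpans(ctx context.Context) {
	if err := WaitForSpans(ctx, 30*time.Second); err != nil {
		// ErrMissingParentSpans: some referenced spans were never observed
		// within the deadline.
		fmt.Fprintln(os.Stderr, "trace shutdown:", err)
	}
}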
func DebugFlagsFromContext(ctx context.Context) DebugFlags {
if sys := systemContextFromContext(ctx); sys != nil {
return sys.options.DebugFlags
}
return 0
}
type stackTraceProcessor struct{}
// ForceFlush implements trace.SpanProcessor.
func (s *stackTraceProcessor) ForceFlush(context.Context) error {
return nil
}
// OnEnd implements trace.SpanProcessor.
func (*stackTraceProcessor) OnEnd(sdktrace.ReadOnlySpan) {
}
// OnStart implements trace.SpanProcessor.
func (*stackTraceProcessor) OnStart(_ context.Context, s sdktrace.ReadWriteSpan) {
_, file, line, _ := runtime.Caller(2)
s.SetAttributes(attribute.String("caller", fmt.Sprintf("%s:%d", file, line)))
}
// Shutdown implements trace.SpanProcessor.
func (s *stackTraceProcessor) Shutdown(context.Context) error {
return nil
}
var debugMessageWriter io.Writer
func startMsg(title string) *strings.Builder {
msg := &strings.Builder{}
msg.WriteString("\n==================================================\n")
msg.WriteString(title)
return msg
}
func endMsg(msg *strings.Builder) {
msg.WriteString("==================================================\n")
w := debugMessageWriter
if w == nil {
w = os.Stderr
}
fmt.Fprint(w, msg.String())
}
type DebugEvent struct {
Timestamp time.Time `json:"timestamp"`
Request *coltracepb.ExportTraceServiceRequest `json:"request"`
}
func (e DebugEvent) MarshalJSON() ([]byte, error) {
type debugEvent struct {
Timestamp time.Time `json:"timestamp"`
Request json.RawMessage `json:"request"`
}
reqData, _ := protojson.Marshal(e.Request)
return json.Marshal(debugEvent{
Timestamp: e.Timestamp,
Request: reqData,
})
}
func (e *DebugEvent) UnmarshalJSON(b []byte) error {
type debugEvent struct {
Timestamp time.Time `json:"timestamp"`
Request json.RawMessage `json:"request"`
}
var ev debugEvent
if err := json.Unmarshal(b, &ev); err != nil {
return err
}
e.Timestamp = ev.Timestamp
var msg coltracepb.ExportTraceServiceRequest
if err := protojson.Unmarshal(ev.Request, &msg); err != nil {
return err
}
e.Request = &msg
return nil
}
const shardCount = 64
type (
shardedSet [shardCount]map[oteltrace.SpanID]struct{}
shardedLocks [shardCount]sync.Mutex
)
func (s *shardedSet) Range(f func(key oteltrace.SpanID)) {
for i := range shardCount {
for k := range s[i] {
f(k)
}
}
}
func (s *shardedLocks) LockAll() {
for i := range shardCount {
s[i].Lock()
}
}
func (s *shardedLocks) UnlockAll() {
for i := range shardCount {
s[i].Unlock()
}
}
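// Illustrative note (not part of this commit): span IDs map to one of the 64
// shards by interpreting the 8-byte ID as a big-endian integer modulo
// shardCount, the same computation used by OnStart and OnEnd below.
func exampleShardIndex(id oteltrace.SpanID) uint64 {
	return binary.BigEndian.Uint64(id[:]) % shardCount
}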
type spanTracker struct {
inflightSpansMu shardedLocks
inflightSpans shardedSet
allSpans sync.Map
debugFlags DebugFlags
observer *spanObserver
shutdownOnce sync.Once
}
func newSpanTracker(observer *spanObserver, debugFlags DebugFlags) *spanTracker {
st := &spanTracker{
observer: observer,
debugFlags: debugFlags,
}
for i := range len(st.inflightSpans) {
st.inflightSpans[i] = make(map[oteltrace.SpanID]struct{})
}
return st
}
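// Illustrative sketch (not part of this commit): wiring the tracker and
// observer into an SDK tracer provider so span starts and ends are recorded
// and, with TrackSpanReferences, parent references are tracked by the
// observer.
func exampleTrackedProvider(flags DebugFlags) *sdktrace.TracerProvider {
	obs := newSpanObserver()
	return sdktrace.NewTracerProvider(
		sdktrace.WithSpanProcessor(newSpanTracker(obs, flags)),
	)
}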
type spanInfo struct {
Name string
SpanContext oteltrace.SpanContext
Parent oteltrace.SpanContext
caller string
startTime time.Time
}
// ForceFlush implements trace.SpanProcessor.
func (t *spanTracker) ForceFlush(context.Context) error {
return nil
}
// OnEnd implements trace.SpanProcessor.
func (t *spanTracker) OnEnd(s sdktrace.ReadOnlySpan) {
id := s.SpanContext().SpanID()
bucket := binary.BigEndian.Uint64(id[:]) % shardCount
t.inflightSpansMu[bucket].Lock()
defer t.inflightSpansMu[bucket].Unlock()
delete(t.inflightSpans[bucket], id)
}
// OnStart implements trace.SpanProcessor.
func (t *spanTracker) OnStart(_ context.Context, s sdktrace.ReadWriteSpan) {
id := s.SpanContext().SpanID()
bucket := binary.BigEndian.Uint64(id[:]) % shardCount
t.inflightSpansMu[bucket].Lock()
defer t.inflightSpansMu[bucket].Unlock()
t.inflightSpans[bucket][id] = struct{}{}
if t.debugFlags.Check(TrackSpanReferences) {
if s.Parent().IsValid() {
t.observer.ObserveReference(s.Parent().SpanID(), id)
}
t.observer.Observe(id)
}
if t.debugFlags.Check(TrackAllSpans) {
var caller string
for _, attr := range s.Attributes() {
if attr.Key == "caller" {
caller = attr.Value.AsString()
break
}
}
t.allSpans.Store(id, &spanInfo{
Name: s.Name(),
SpanContext: s.SpanContext(),
Parent: s.Parent(),
caller: caller,
startTime: s.StartTime(),
})
}
}
// Shutdown implements trace.SpanProcessor.
func (t *spanTracker) Shutdown(_ context.Context) error {
if t.debugFlags == 0 {
return nil
}
didWarn := false
t.shutdownOnce.Do(func() {
if t.debugFlags.Check(WarnOnUnresolvedReferences) {
var unknownParentIDs []string
for id, via := range t.observer.referencedIDs {
if via.IsValid() {
if t.debugFlags.Check(TrackAllSpans) {
if viaSpan, ok := t.allSpans.Load(via); ok {
unknownParentIDs = append(unknownParentIDs, fmt.Sprintf("%s via %s (%s)", id, via, viaSpan.(*spanInfo).Name))
} else {
unknownParentIDs = append(unknownParentIDs, fmt.Sprintf("%s via %s", id, via))
}
}
}
}
if len(unknownParentIDs) > 0 {
didWarn = true
msg := startMsg("WARNING: parent spans referenced but never seen:\n")
for _, str := range unknownParentIDs {
msg.WriteString(str)
msg.WriteString("\n")
}
endMsg(msg)
}
}
if t.debugFlags.Check(WarnOnIncompleteSpans) {
if t.debugFlags.Check(TrackAllSpans) {
incompleteSpans := []*spanInfo{}
t.inflightSpansMu.LockAll()
t.inflightSpans.Range(func(key oteltrace.SpanID) {
if info, ok := t.allSpans.Load(key); ok {
incompleteSpans = append(incompleteSpans, info.(*spanInfo))
}
})
t.inflightSpansMu.UnlockAll()
if len(incompleteSpans) > 0 {
didWarn = true
msg := startMsg("WARNING: spans not ended:\n")
longestName := 0
for _, span := range incompleteSpans {
longestName = max(longestName, len(span.Name)+2)
}
for _, span := range incompleteSpans {
var startedAt string
if span.caller != "" {
startedAt = " | started at: " + span.caller
}
fmt.Fprintf(msg, "%-*s (trace: %s | span: %s | parent: %s%s)\n", longestName, "'"+span.Name+"'",
span.SpanContext.TraceID(), span.SpanContext.SpanID(), span.Parent.SpanID(), startedAt)
}
endMsg(msg)
}
} else {
incompleteSpans := []oteltrace.SpanID{}
t.inflightSpansMu.LockAll()
t.inflightSpans.Range(func(key oteltrace.SpanID) {
incompleteSpans = append(incompleteSpans, key)
})
t.inflightSpansMu.UnlockAll()
if len(incompleteSpans) > 0 {
didWarn = true
msg := startMsg("WARNING: spans not ended:\n")
for _, span := range incompleteSpans {
fmt.Fprintf(msg, "%s\n", span)
}
msg.WriteString("Note: set TrackAllSpans flag for more info\n")
endMsg(msg)
}
}
}
if t.debugFlags.Check(LogAllSpans) || (t.debugFlags.Check(LogAllSpansOnWarn) && didWarn) {
allSpans := []*spanInfo{}
t.allSpans.Range(func(_, value any) bool {
allSpans = append(allSpans, value.(*spanInfo))
return true
})
slices.SortFunc(allSpans, func(a, b *spanInfo) int {
return a.startTime.Compare(b.startTime)
})
msg := startMsg("All observed spans:\n")
longestName := 0
for _, span := range allSpans {
longestName = max(longestName, len(span.Name)+2)
}
for _, span := range allSpans {
var startedAt string
if span.caller != "" {
startedAt = " | started at: " + span.caller
}
fmt.Fprintf(msg, "%-*s (trace: %s | span: %s | parent: %s%s)\n", longestName, "'"+span.Name+"'",
span.SpanContext.TraceID(), span.SpanContext.SpanID(), span.Parent.SpanID(), startedAt)
}
endMsg(msg)
}
if t.debugFlags.Check(LogTraceIDs) || (didWarn && t.debugFlags.Check(LogTraceIDsOnWarn)) {
msg := startMsg("Known trace ids:\n")
traceIDs := map[oteltrace.TraceID]int{}
t.allSpans.Range(func(_, value any) bool {
v := value.(*spanInfo)
traceIDs[v.SpanContext.TraceID()]++
return true
})
for id, n := range traceIDs {
fmt.Fprintf(msg, "%s (%d spans)\n", id.String(), n)
}
endMsg(msg)
}
})
if didWarn {
return ErrIncompleteSpans
}
return nil
}
func newSpanObserver() *spanObserver {
return &spanObserver{
referencedIDs: map[oteltrace.SpanID]oteltrace.SpanID{},
cond: sync.NewCond(&sync.Mutex{}),
}
}
type spanObserver struct {
cond *sync.Cond
referencedIDs map[oteltrace.SpanID]oteltrace.SpanID
unobservedIDs int
}
func (obs *spanObserver) ObserveReference(id oteltrace.SpanID, via oteltrace.SpanID) {
obs.cond.L.Lock()
defer obs.cond.L.Unlock()
if _, referenced := obs.referencedIDs[id]; !referenced {
obs.referencedIDs[id] = via // referenced, but not observed
// It is possible for new unobserved references to come in while waiting,
// but incrementing the counter wouldn't satisfy the condition so we don't
// need to signal the waiters
obs.unobservedIDs++
}
}
func (obs *spanObserver) Observe(id oteltrace.SpanID) {
obs.cond.L.Lock()
defer obs.cond.L.Unlock()
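// Three cases are possible here:
// 1. id was never referenced: record it as observed (zeroSpanID) so a later
// ObserveReference for it is not counted as unobserved.
// 2. id was referenced but not yet observed (the stored value is the valid
// "via" span ID): mark it observed, decrement the counter, and wake waiters.
// 3. id was already observed (stored value is zeroSpanID): do nothing.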
if observed, referenced := obs.referencedIDs[id]; !referenced || observed.IsValid() { // NB: subtle condition
obs.referencedIDs[id] = zeroSpanID
if referenced {
obs.unobservedIDs--
obs.cond.Broadcast()
}
}
}
func (obs *spanObserver) wait(warnAfter time.Duration) {
done := make(chan struct{})
defer close(done)
go func() {
select {
case <-done:
return
case <-time.After(warnAfter):
obs.debugWarnWaiting()
}
}()
obs.cond.L.Lock()
for obs.unobservedIDs > 0 {
obs.cond.Wait()
}
obs.cond.L.Unlock()
}
func (obs *spanObserver) debugWarnWaiting() {
obs.cond.L.Lock()
msg := startMsg(fmt.Sprintf("Waiting on %d unobserved spans:\n", obs.unobservedIDs))
for id, via := range obs.referencedIDs {
if via.IsValid() {
fmt.Fprintf(msg, "%s via %s\n", id, via)
}
}
endMsg(msg)
obs.cond.L.Unlock()
}
func (srv *ExporterServer) observeExport(ctx context.Context, req *coltracepb.ExportTraceServiceRequest) {
isLocal := len(metadata.ValueFromIncomingContext(ctx, localExporterMetadataKey)) != 0
if isLocal {
return
}
for _, res := range req.ResourceSpans {
for _, scope := range res.ScopeSpans {
for _, span := range scope.Spans {
id, ok := ToSpanID(span.SpanId)
if !ok {
continue
}
srv.observer.Observe(id)
for _, attr := range span.Attributes {
if attr.Key != "pomerium.external-parent-span" {
continue
}
if bytes, err := hex.DecodeString(attr.Value.GetStringValue()); err == nil {
if id, ok := ToSpanID(bytes); ok {
srv.observer.Observe(id)
}
}
break
}
}
}
}
}


@@ -0,0 +1,284 @@
package trace_test
import (
"bytes"
"context"
"fmt"
"runtime"
"sync/atomic"
"testing"
"time"
. "github.com/pomerium/pomerium/internal/testutil/tracetest" //nolint:revive
"github.com/pomerium/pomerium/pkg/telemetry/trace"
"github.com/stretchr/testify/assert"
sdktrace "go.opentelemetry.io/otel/sdk/trace"
oteltrace "go.opentelemetry.io/otel/trace"
)
func TestSpanObserver(t *testing.T) {
t.Run("observe single reference", func(t *testing.T) {
obs := trace.NewSpanObserver()
assert.Equal(t, []oteltrace.SpanID{}, obs.XUnobservedIDs())
obs.ObserveReference(Span(1).ID(), Span(2).ID())
assert.Equal(t, []oteltrace.SpanID{Span(1).ID()}, obs.XUnobservedIDs())
obs.Observe(Span(1).ID())
assert.Equal(t, []oteltrace.SpanID{}, obs.XUnobservedIDs())
})
t.Run("observe multiple references", func(t *testing.T) {
obs := trace.NewSpanObserver()
obs.ObserveReference(Span(1).ID(), Span(2).ID())
obs.ObserveReference(Span(1).ID(), Span(3).ID())
obs.ObserveReference(Span(1).ID(), Span(4).ID())
assert.Equal(t, []oteltrace.SpanID{Span(1).ID()}, obs.XUnobservedIDs())
obs.Observe(Span(1).ID())
assert.Equal(t, []oteltrace.SpanID{}, obs.XUnobservedIDs())
})
t.Run("observe before reference", func(t *testing.T) {
obs := trace.NewSpanObserver()
obs.Observe(Span(1).ID())
assert.Equal(t, []oteltrace.SpanID{}, obs.XUnobservedIDs())
obs.ObserveReference(Span(1).ID(), Span(2).ID())
assert.Equal(t, []oteltrace.SpanID{}, obs.XUnobservedIDs())
})
t.Run("wait", func(t *testing.T) {
obs := trace.NewSpanObserver()
obs.ObserveReference(Span(1).ID(), Span(2).ID())
obs.Observe(Span(2).ID())
obs.ObserveReference(Span(3).ID(), Span(4).ID())
obs.Observe(Span(4).ID())
obs.ObserveReference(Span(5).ID(), Span(6).ID())
obs.Observe(Span(6).ID())
waitOkToExit := atomic.Bool{}
waitExited := atomic.Bool{}
go func() {
defer waitExited.Store(true)
obs.XWait()
assert.True(t, waitOkToExit.Load(), "wait exited early")
}()
time.Sleep(10 * time.Millisecond)
assert.False(t, waitExited.Load())
obs.Observe(Span(1).ID())
time.Sleep(10 * time.Millisecond)
assert.False(t, waitExited.Load())
obs.Observe(Span(3).ID())
time.Sleep(10 * time.Millisecond)
assert.False(t, waitExited.Load())
waitOkToExit.Store(true)
obs.Observe(Span(5).ID())
assert.Eventually(t, waitExited.Load, 100*time.Millisecond, 10*time.Millisecond)
})
t.Run("new references observed during wait", func(t *testing.T) {
obs := trace.NewSpanObserver()
obs.ObserveReference(Span(1).ID(), Span(2).ID())
obs.Observe(Span(2).ID())
obs.ObserveReference(Span(3).ID(), Span(4).ID())
obs.Observe(Span(4).ID())
obs.ObserveReference(Span(5).ID(), Span(6).ID())
obs.Observe(Span(6).ID())
waitOkToExit := atomic.Bool{}
waitExited := atomic.Bool{}
go func() {
defer waitExited.Store(true)
obs.XWait()
assert.True(t, waitOkToExit.Load(), "wait exited early")
}()
assert.Equal(t, []oteltrace.SpanID{Span(1).ID(), Span(3).ID(), Span(5).ID()}, obs.XUnobservedIDs())
time.Sleep(10 * time.Millisecond)
assert.False(t, waitExited.Load())
obs.Observe(Span(1).ID())
assert.Equal(t, []oteltrace.SpanID{Span(3).ID(), Span(5).ID()}, obs.XUnobservedIDs())
time.Sleep(10 * time.Millisecond)
assert.False(t, waitExited.Load())
obs.Observe(Span(3).ID())
assert.Equal(t, []oteltrace.SpanID{Span(5).ID()}, obs.XUnobservedIDs())
time.Sleep(10 * time.Millisecond)
assert.False(t, waitExited.Load())
// observe a new reference
obs.ObserveReference(Span(7).ID(), Span(8).ID())
obs.Observe(Span(8).ID())
assert.Equal(t, []oteltrace.SpanID{Span(5).ID(), Span(7).ID()}, obs.XUnobservedIDs())
time.Sleep(10 * time.Millisecond)
assert.False(t, waitExited.Load())
obs.Observe(Span(5).ID())
assert.Equal(t, []oteltrace.SpanID{Span(7).ID()}, obs.XUnobservedIDs())
time.Sleep(10 * time.Millisecond)
assert.False(t, waitExited.Load())
waitOkToExit.Store(true)
obs.Observe(Span(7).ID())
assert.Equal(t, []oteltrace.SpanID{}, obs.XUnobservedIDs())
assert.Eventually(t, waitExited.Load, 100*time.Millisecond, 10*time.Millisecond)
})
t.Run("multiple waiters", func(t *testing.T) {
t.Parallel()
obs := trace.NewSpanObserver()
obs.ObserveReference(Span(1).ID(), Span(2).ID())
obs.Observe(Span(2).ID())
waitersExited := atomic.Int32{}
for range 10 {
go func() {
defer waitersExited.Add(1)
obs.XWait()
}()
}
assert.Equal(t, []oteltrace.SpanID{Span(1).ID()}, obs.XUnobservedIDs())
time.Sleep(10 * time.Millisecond)
assert.Equal(t, int32(0), waitersExited.Load())
obs.Observe(Span(1).ID())
assert.Eventually(t, func() bool {
return waitersExited.Load() == 10
}, 100*time.Millisecond, 10*time.Millisecond)
})
}
func TestSpanTracker(t *testing.T) {
t.Run("no debug flags", func(t *testing.T) {
t.Parallel()
obs := trace.NewSpanObserver()
tracker := trace.NewSpanTracker(obs, 0)
tp := sdktrace.NewTracerProvider(sdktrace.WithSpanProcessor(tracker))
tracer := tp.Tracer("test")
assert.Equal(t, []oteltrace.SpanID{}, tracker.XInflightSpans())
_, span1 := tracer.Start(context.Background(), "span 1")
assert.Equal(t, []oteltrace.SpanID{span1.SpanContext().SpanID()}, tracker.XInflightSpans())
assert.Equal(t, []oteltrace.SpanID{}, obs.XObservedIDs())
span1.End()
assert.Equal(t, []oteltrace.SpanID{}, tracker.XInflightSpans())
assert.Equal(t, []oteltrace.SpanID{}, obs.XObservedIDs())
})
t.Run("with TrackSpanReferences debug flag", func(t *testing.T) {
t.Parallel()
obs := trace.NewSpanObserver()
tracker := trace.NewSpanTracker(obs, trace.TrackSpanReferences)
tp := sdktrace.NewTracerProvider(sdktrace.WithSpanProcessor(tracker))
tracer := tp.Tracer("test")
assert.Equal(t, []oteltrace.SpanID{}, tracker.XInflightSpans())
_, span1 := tracer.Start(context.Background(), "span 1")
assert.Equal(t, []oteltrace.SpanID{span1.SpanContext().SpanID()}, tracker.XInflightSpans())
assert.Equal(t, []oteltrace.SpanID{span1.SpanContext().SpanID()}, obs.XObservedIDs())
span1.End()
assert.Equal(t, []oteltrace.SpanID{}, tracker.XInflightSpans())
assert.Equal(t, []oteltrace.SpanID{span1.SpanContext().SpanID()}, obs.XObservedIDs())
})
}
func TestSpanTrackerWarnings(t *testing.T) {
t.Run("WarnOnIncompleteSpans", func(t *testing.T) {
var buf bytes.Buffer
trace.SetDebugMessageWriterForTest(t, &buf)
obs := trace.NewSpanObserver()
tracker := trace.NewSpanTracker(obs, trace.WarnOnIncompleteSpans)
tp := sdktrace.NewTracerProvider(sdktrace.WithSpanProcessor(tracker))
tracer := tp.Tracer("test")
_, span1 := tracer.Start(context.Background(), "span 1")
assert.ErrorIs(t, tp.Shutdown(context.Background()), trace.ErrIncompleteSpans)
assert.Equal(t, fmt.Sprintf(`
==================================================
WARNING: spans not ended:
%s
Note: set TrackAllSpans flag for more info
==================================================
`, span1.SpanContext().SpanID()), buf.String())
})
t.Run("WarnOnIncompleteSpans with TrackAllSpans", func(t *testing.T) {
var buf bytes.Buffer
trace.SetDebugMessageWriterForTest(t, &buf)
obs := trace.NewSpanObserver()
tracker := trace.NewSpanTracker(obs, trace.WarnOnIncompleteSpans|trace.TrackAllSpans)
tp := sdktrace.NewTracerProvider(sdktrace.WithSpanProcessor(tracker))
tracer := tp.Tracer("test")
_, span1 := tracer.Start(context.Background(), "span 1")
assert.ErrorIs(t, tp.Shutdown(context.Background()), trace.ErrIncompleteSpans)
assert.Equal(t, fmt.Sprintf(`
==================================================
WARNING: spans not ended:
'span 1' (trace: %s | span: %s | parent: 0000000000000000)
==================================================
`, span1.SpanContext().TraceID(), span1.SpanContext().SpanID()), buf.String())
})
t.Run("WarnOnIncompleteSpans with TrackAllSpans and stackTraceProcessor", func(t *testing.T) {
var buf bytes.Buffer
trace.SetDebugMessageWriterForTest(t, &buf)
obs := trace.NewSpanObserver()
tracker := trace.NewSpanTracker(obs, trace.WarnOnIncompleteSpans|trace.TrackAllSpans)
tp := sdktrace.NewTracerProvider(sdktrace.WithSpanProcessor(&trace.XStackTraceProcessor{}), sdktrace.WithSpanProcessor(tracker))
tracer := tp.Tracer("test")
_, span1 := tracer.Start(context.Background(), "span 1")
_, file, line, _ := runtime.Caller(0)
line--
assert.ErrorIs(t, tp.Shutdown(context.Background()), trace.ErrIncompleteSpans)
assert.Equal(t, fmt.Sprintf(`
==================================================
WARNING: spans not ended:
'span 1' (trace: %s | span: %s | parent: 0000000000000000 | started at: %s:%d)
==================================================
`, span1.SpanContext().TraceID(), span1.SpanContext().SpanID(), file, line), buf.String())
})
t.Run("LogAllSpansOnWarn", func(t *testing.T) {
var buf bytes.Buffer
trace.SetDebugMessageWriterForTest(t, &buf)
obs := trace.NewSpanObserver()
tracker := trace.NewSpanTracker(obs, trace.WarnOnIncompleteSpans|trace.TrackAllSpans|trace.LogAllSpansOnWarn)
tp := sdktrace.NewTracerProvider(sdktrace.WithSpanProcessor(&trace.XStackTraceProcessor{}), sdktrace.WithSpanProcessor(tracker))
tracer := tp.Tracer("test")
_, span1 := tracer.Start(context.Background(), "span 1")
time.Sleep(10 * time.Millisecond)
span1.End()
time.Sleep(10 * time.Millisecond)
_, span2 := tracer.Start(context.Background(), "span 2")
_, file, line, _ := runtime.Caller(0)
line--
tp.Shutdown(context.Background())
assert.Equal(t,
fmt.Sprintf(`
==================================================
WARNING: spans not ended:
'span 2' (trace: %[1]s | span: %[2]s | parent: 0000000000000000 | started at: %[3]s:%[4]d)
==================================================
==================================================
All observed spans:
'span 1' (trace: %[5]s | span: %[6]s | parent: 0000000000000000 | started at: %[3]s:%[7]d)
'span 2' (trace: %[1]s | span: %[2]s | parent: 0000000000000000 | started at: %[3]s:%[4]d)
==================================================
`,
span2.SpanContext().TraceID(), span2.SpanContext().SpanID(), file, line,
span1.SpanContext().TraceID(), span1.SpanContext().SpanID(), line-4,
), buf.String())
})
}


@@ -0,0 +1,71 @@
package trace
import (
"context"
"os"
"strconv"
"go.opentelemetry.io/contrib/propagators/autoprop"
"go.opentelemetry.io/otel"
sdktrace "go.opentelemetry.io/otel/sdk/trace"
"go.opentelemetry.io/otel/trace"
"go.opentelemetry.io/otel/trace/embedded"
)
// PomeriumCoreTracer should be used for all tracers created in pomerium core.
const PomeriumCoreTracer = "pomerium.io/core"
func init() {
otel.SetTextMapPropagator(autoprop.NewTextMapPropagator())
}
// UseGlobalPanicTracer sets the global tracer provider to one whose tracers
// panic when starting spans. This can be used to locate errant usages of the
// global tracer, and is enabled automatically in some tests. It is otherwise
// not used by default, since pomerium is used as a library in some places that
// might use the global tracer provider.
func UseGlobalPanicTracer() {
otel.SetTracerProvider(panicTracerProvider{})
}
type panicTracerProvider struct {
embedded.TracerProvider
}
// Tracer implements trace.TracerProvider.
func (w panicTracerProvider) Tracer(string, ...trace.TracerOption) trace.Tracer {
return panicTracer{}
}
type panicTracer struct {
embedded.Tracer
}
var _ trace.Tracer = panicTracer{}
// Start implements trace.Tracer.
func (p panicTracer) Start(context.Context, string, ...trace.SpanStartOption) (context.Context, trace.Span) {
panic("global tracer used")
}
// functions below mimic those with the same name in otel/sdk/internal/env/env.go
func BatchSpanProcessorScheduleDelay() int {
const defaultValue = sdktrace.DefaultScheduleDelay
if v, ok := os.LookupEnv("OTEL_BSP_SCHEDULE_DELAY"); ok {
if n, err := strconv.Atoi(v); err == nil {
return n
}
}
return defaultValue
}
func BatchSpanProcessorMaxExportBatchSize() int {
const defaultValue = sdktrace.DefaultMaxExportBatchSize
if v, ok := os.LookupEnv("OTEL_BSP_MAX_EXPORT_BATCH_SIZE"); ok {
if n, err := strconv.Atoi(v); err == nil {
return n
}
}
return defaultValue
}
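// Illustrative sketch (not part of this commit): feeding the env-derived
// values into a batch span processor, mirroring what the SDK does with the
// same variables internally. A time import is assumed.
func exampleBatchProcessor(exporter sdktrace.SpanExporter) sdktrace.SpanProcessor {
	return sdktrace.NewBatchSpanProcessor(exporter,
		sdktrace.WithBatchTimeout(time.Duration(BatchSpanProcessorScheduleDelay())*time.Millisecond),
		sdktrace.WithMaxExportBatchSize(BatchSpanProcessorMaxExportBatchSize()),
	)
}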


@@ -0,0 +1,22 @@
package trace_test
import (
"context"
"testing"
"github.com/pomerium/pomerium/pkg/telemetry/trace"
"github.com/stretchr/testify/assert"
"go.opentelemetry.io/otel"
"go.opentelemetry.io/otel/trace/noop"
)
func TestUseGlobalPanicTracer(t *testing.T) {
t.Cleanup(func() {
otel.SetTracerProvider(noop.NewTracerProvider())
})
trace.UseGlobalPanicTracer()
tracer := otel.GetTracerProvider().Tracer("test")
assert.Panics(t, func() {
tracer.Start(context.Background(), "span")
})
}


@@ -0,0 +1,13 @@
package trace_test
import (
"os"
"testing"
"github.com/pomerium/pomerium/pkg/telemetry/trace"
)
func TestMain(m *testing.M) {
trace.UseGlobalPanicTracer()
os.Exit(m.Run())
}


@@ -0,0 +1,96 @@
package trace
import (
"context"
"fmt"
"net/http"
"reflect"
"github.com/gorilla/mux"
"go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp"
"google.golang.org/grpc/stats"
)
func NewHTTPMiddleware(opts ...otelhttp.Option) mux.MiddlewareFunc {
return otelhttp.NewMiddleware("Server: %s %s", append(opts, otelhttp.WithSpanNameFormatter(func(operation string, r *http.Request) string {
routeStr := ""
route := mux.CurrentRoute(r)
if route != nil {
var err error
routeStr, err = route.GetPathTemplate()
if err != nil {
routeStr, err = route.GetPathRegexp()
if err != nil {
routeStr = ""
}
}
}
return fmt.Sprintf(operation, r.Method, routeStr)
}))...)
}
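// Illustrative usage sketch (not part of this commit): attaching the
// middleware to a router so handler spans are named like "Server: GET /foo".
// The function name is hypothetical.
func exampleRouter(opts ...otelhttp.Option) *mux.Router {
	r := mux.NewRouter()
	r.Use(NewHTTPMiddleware(opts...))
	return r
}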
type clientStatsHandlerWrapper struct {
ClientStatsHandlerOptions
base stats.Handler
}
type ClientStatsHandlerOptions struct {
statsInterceptor func(ctx context.Context, rs stats.RPCStats) stats.RPCStats
}
type ClientStatsHandlerOption func(*ClientStatsHandlerOptions)
func (o *ClientStatsHandlerOptions) apply(opts ...ClientStatsHandlerOption) {
for _, op := range opts {
op(o)
}
}
// WithStatsInterceptor calls the given function to modify the rpc stats before
// passing it to the stats handler during HandleRPC events.
//
// The interceptor MUST NOT modify the RPCStats it is given. It should instead
// return a copy of the underlying object with the same type, with any
// modifications made to the copy.
func WithStatsInterceptor(statsInterceptor func(ctx context.Context, rs stats.RPCStats) stats.RPCStats) ClientStatsHandlerOption {
return func(o *ClientStatsHandlerOptions) {
o.statsInterceptor = statsInterceptor
}
}
func NewClientStatsHandler(base stats.Handler, opts ...ClientStatsHandlerOption) stats.Handler {
options := ClientStatsHandlerOptions{}
options.apply(opts...)
return &clientStatsHandlerWrapper{
ClientStatsHandlerOptions: options,
base: base,
}
}
// HandleConn implements stats.Handler.
func (w *clientStatsHandlerWrapper) HandleConn(ctx context.Context, stats stats.ConnStats) {
w.base.HandleConn(ctx, stats)
}
// HandleRPC implements stats.Handler.
func (w *clientStatsHandlerWrapper) HandleRPC(ctx context.Context, stats stats.RPCStats) {
if w.statsInterceptor != nil {
modified := w.statsInterceptor(ctx, stats)
if reflect.TypeOf(stats) != reflect.TypeOf(modified) {
panic("bug: stats interceptor returned a message of a different type")
}
w.base.HandleRPC(ctx, modified)
} else {
w.base.HandleRPC(ctx, stats)
}
}
// TagConn implements stats.Handler.
func (w *clientStatsHandlerWrapper) TagConn(ctx context.Context, info *stats.ConnTagInfo) context.Context {
return w.base.TagConn(ctx, info)
}
// TagRPC implements stats.Handler.
func (w *clientStatsHandlerWrapper) TagRPC(ctx context.Context, info *stats.RPCTagInfo) context.Context {
return w.base.TagRPC(ctx, info)
}
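// Illustrative sketch (not part of this commit): wrapping a base stats handler
// so RPC stats can be adjusted before the base handler records them. The
// otelgrpc import (go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc)
// is an assumption here and is not referenced by this commit.
func exampleClientStatsHandler() stats.Handler {
	return NewClientStatsHandler(
		otelgrpc.NewClientHandler(),
		WithStatsInterceptor(func(_ context.Context, rs stats.RPCStats) stats.RPCStats {
			// Return the stats unchanged, or a modified copy of the same
			// concrete type; the wrapper panics on a type mismatch.
			return rs
		}),
	)
}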


@@ -0,0 +1,210 @@
package trace_test
import (
"context"
"errors"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/gorilla/mux"
"github.com/pomerium/pomerium/pkg/telemetry/trace"
"github.com/stretchr/testify/assert"
"go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp"
sdktrace "go.opentelemetry.io/otel/sdk/trace"
oteltrace "go.opentelemetry.io/otel/trace"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/stats"
)
func TestHTTPMiddleware(t *testing.T) {
router := mux.NewRouter()
tp := sdktrace.NewTracerProvider()
router.Use(trace.NewHTTPMiddleware(
otelhttp.WithTracerProvider(tp),
))
router.Path("/foo").HandlerFunc(func(_ http.ResponseWriter, r *http.Request) {
span := oteltrace.SpanFromContext(r.Context())
assert.Equal(t, "Server: GET /foo", span.(interface{ Name() string }).Name())
}).Methods(http.MethodGet)
w := httptest.NewRecorder()
ctx, span := tp.Tracer("test").Start(context.Background(), "test")
router.ServeHTTP(w, httptest.NewRequestWithContext(ctx, http.MethodGet, "/foo", nil))
span.End()
}
type mockHandler struct {
handleConn func(ctx context.Context, stats stats.ConnStats)
handleRPC func(ctx context.Context, stats stats.RPCStats)
tagConn func(ctx context.Context, info *stats.ConnTagInfo) context.Context
tagRPC func(ctx context.Context, info *stats.RPCTagInfo) context.Context
}
// HandleConn implements stats.Handler.
func (m *mockHandler) HandleConn(ctx context.Context, stats stats.ConnStats) {
m.handleConn(ctx, stats)
}
// HandleRPC implements stats.Handler.
func (m *mockHandler) HandleRPC(ctx context.Context, stats stats.RPCStats) {
m.handleRPC(ctx, stats)
}
// TagConn implements stats.Handler.
func (m *mockHandler) TagConn(ctx context.Context, info *stats.ConnTagInfo) context.Context {
return m.tagConn(ctx, info)
}
// TagRPC implements stats.Handler.
func (m *mockHandler) TagRPC(ctx context.Context, info *stats.RPCTagInfo) context.Context {
return m.tagRPC(ctx, info)
}
var _ stats.Handler = (*mockHandler)(nil)
func TestStatsInterceptor(t *testing.T) {
var outBegin *stats.Begin
var outEnd *stats.End
base := &mockHandler{
handleRPC: func(_ context.Context, rs stats.RPCStats) {
switch rs := rs.(type) {
case *stats.Begin:
outBegin = rs
case *stats.End:
outEnd = rs
}
},
}
interceptor := func(_ context.Context, rs stats.RPCStats) stats.RPCStats {
switch rs := rs.(type) {
case *stats.Begin:
return &stats.Begin{
Client: rs.Client,
BeginTime: rs.BeginTime.Add(-1 * time.Minute),
FailFast: rs.FailFast,
IsClientStream: rs.IsClientStream,
IsServerStream: rs.IsServerStream,
IsTransparentRetryAttempt: rs.IsTransparentRetryAttempt,
}
case *stats.End:
return &stats.End{
Client: rs.Client,
BeginTime: rs.BeginTime,
EndTime: rs.EndTime,
Trailer: rs.Trailer,
Error: errors.New("modified"),
}
}
return rs
}
handler := trace.NewClientStatsHandler(
base,
trace.WithStatsInterceptor(interceptor),
)
inBegin := &stats.Begin{
Client: true,
BeginTime: time.Now(),
FailFast: true,
IsClientStream: true,
IsServerStream: false,
IsTransparentRetryAttempt: false,
}
handler.HandleRPC(context.Background(), inBegin)
assert.NotNil(t, outBegin)
assert.NotSame(t, inBegin, outBegin)
assert.Equal(t, inBegin.BeginTime.Add(-1*time.Minute), outBegin.BeginTime)
assert.Equal(t, inBegin.Client, outBegin.Client)
assert.Equal(t, inBegin.FailFast, outBegin.FailFast)
assert.Equal(t, inBegin.IsClientStream, outBegin.IsClientStream)
assert.Equal(t, inBegin.IsServerStream, outBegin.IsServerStream)
assert.Equal(t, inBegin.IsTransparentRetryAttempt, outBegin.IsTransparentRetryAttempt)
inEnd := &stats.End{
Client: true,
BeginTime: time.Now(),
EndTime: time.Now().Add(1 * time.Minute),
Trailer: metadata.Pairs("a", "b", "c", "d"),
Error: errors.New("input"),
}
handler.HandleRPC(context.Background(), inEnd)
assert.NotNil(t, outEnd)
assert.NotSame(t, inEnd, outEnd)
assert.Equal(t, inEnd.Client, outEnd.Client)
assert.Equal(t, inEnd.BeginTime, outEnd.BeginTime)
assert.Equal(t, inEnd.EndTime, outEnd.EndTime)
assert.Equal(t, inEnd.Trailer, outEnd.Trailer)
assert.Equal(t, "input", inEnd.Error.Error())
assert.Equal(t, "modified", outEnd.Error.Error())
}
func TestStatsInterceptor_Nil(t *testing.T) {
var outCtx context.Context
var outConnStats stats.ConnStats
var outRPCStats stats.RPCStats
var outConnTagInfo *stats.ConnTagInfo
var outRPCTagInfo *stats.RPCTagInfo
base := &mockHandler{
handleConn: func(ctx context.Context, stats stats.ConnStats) {
outCtx = ctx
outConnStats = stats
},
handleRPC: func(ctx context.Context, stats stats.RPCStats) {
outCtx = ctx
outRPCStats = stats
},
tagConn: func(ctx context.Context, info *stats.ConnTagInfo) context.Context {
outCtx = ctx
outConnTagInfo = info
return ctx
},
tagRPC: func(ctx context.Context, info *stats.RPCTagInfo) context.Context {
outCtx = ctx
outRPCTagInfo = info
return ctx
},
}
handler := trace.NewClientStatsHandler(
base,
trace.WithStatsInterceptor(nil),
)
inCtx := context.Background()
inConnStats := &stats.ConnBegin{}
inRPCStats := &stats.Begin{}
inConnTagInfo := &stats.ConnTagInfo{}
inRPCTagInfo := &stats.RPCTagInfo{}
handler.HandleConn(inCtx, inConnStats)
assert.Equal(t, inCtx, outCtx)
assert.Same(t, inConnStats, outConnStats)
handler.HandleRPC(inCtx, inRPCStats)
assert.Equal(t, inCtx, outCtx)
assert.Same(t, inRPCStats, outRPCStats)
handler.TagConn(inCtx, inConnTagInfo)
assert.Equal(t, inCtx, outCtx)
assert.Same(t, inConnTagInfo, outConnTagInfo)
handler.TagRPC(inCtx, inRPCTagInfo)
assert.Equal(t, inCtx, outCtx)
assert.Same(t, inRPCTagInfo, outRPCTagInfo)
}
func TestStatsInterceptor_Bug(t *testing.T) {
handler := trace.NewClientStatsHandler(
&mockHandler{
handleRPC: func(_ context.Context, _ stats.RPCStats) {
t.Error("should not be reached")
},
},
trace.WithStatsInterceptor(func(_ context.Context, rs stats.RPCStats) stats.RPCStats {
_ = rs.(*stats.Begin)
return &stats.End{}
}),
)
assert.PanicsWithValue(t, "bug: stats interceptor returned a message of a different type", func() {
handler.HandleRPC(context.Background(), &stats.Begin{})
})
}

View file

@ -0,0 +1,103 @@
package trace
import (
"context"
"errors"
"fmt"
"net"
"time"
"github.com/pomerium/pomerium/internal/log"
coltracepb "go.opentelemetry.io/proto/otlp/collector/trace/v1"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials/insecure"
"google.golang.org/grpc/test/bufconn"
"go.opentelemetry.io/otel/exporters/otlp/otlptrace"
"go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc"
)
const localExporterMetadataKey = "x-local-exporter"
// Export implements ptraceotlp.GRPCServer.
func (srv *ExporterServer) Export(ctx context.Context, req *coltracepb.ExportTraceServiceRequest) (*coltracepb.ExportTraceServiceResponse, error) {
if srv.observer != nil {
srv.observeExport(ctx, req)
}
if err := srv.remoteClient.UploadTraces(ctx, req.GetResourceSpans()); err != nil {
log.Ctx(ctx).Err(err).Msg("error uploading traces")
return nil, err
}
return &coltracepb.ExportTraceServiceResponse{}, nil
}
// ExporterServer is an in-process OTLP trace collector. Local tracer providers
// export spans to it over an in-memory gRPC connection, and it forwards the
// received spans to the configured remote client.
type ExporterServer struct {
coltracepb.UnimplementedTraceServiceServer
server *grpc.Server
observer *spanObserver
remoteClient otlptrace.Client
cc *grpc.ClientConn
}
// NewServer creates an ExporterServer using the remote client and span
// observer from the trace system context embedded in ctx (see NewContext).
func NewServer(ctx context.Context) *ExporterServer {
sys := systemContextFromContext(ctx)
ex := &ExporterServer{
remoteClient: sys.remoteClient,
observer: sys.observer,
server: grpc.NewServer(grpc.Creds(insecure.NewCredentials())),
}
coltracepb.RegisterTraceServiceServer(ex.server, ex)
return ex
}
// Start starts the remote client and begins serving the in-process collector
// on an in-memory (bufconn) listener.
func (srv *ExporterServer) Start(ctx context.Context) {
lis := bufconn.Listen(2 * 1024 * 1024)
go func() {
if err := srv.remoteClient.Start(ctx); err != nil {
if !errors.Is(err, ErrNoClient) {
panic(fmt.Errorf("bug: %w", err))
}
}
_ = srv.server.Serve(lis)
}()
cc, err := grpc.NewClient("passthrough://ignore",
grpc.WithContextDialer(func(context.Context, string) (net.Conn, error) {
return lis.Dial()
}), grpc.WithTransportCredentials(insecure.NewCredentials()))
if err != nil {
panic(err)
}
srv.cc = cc
}
// NewClient returns an otlptrace.Client that exports to this in-process server.
func (srv *ExporterServer) NewClient() otlptrace.Client {
return otlptracegrpc.NewClient(
otlptracegrpc.WithGRPCConn(srv.cc),
otlptracegrpc.WithTimeout(1*time.Minute),
otlptracegrpc.WithHeaders(map[string]string{
localExporterMetadataKey: "1",
}),
)
}
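// Illustrative sketch (not part of this change): building an OTLP exporter
// backed by this in-process server, mirroring what NewTracerProvider does
// internally. The ctx here is assumed to carry the trace system context
// created by NewContext.
//
//	srv := NewServer(ctx)
//	srv.Start(ctx)
//	exp, err := otlptrace.New(ctx, srv.NewClient())
//	if err != nil {
//		panic(err)
//	}
//	_ = exp // hand the exporter to an SDK tracer provider
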
// Shutdown gracefully stops the gRPC server, waits for in-flight spans to be
// exported, and stops the remote client.
func (srv *ExporterServer) Shutdown(ctx context.Context) error {
stopped := make(chan struct{})
go func() {
srv.server.GracefulStop()
close(stopped)
}()
select {
case <-stopped:
case <-ctx.Done():
return context.Cause(ctx)
}
var errs []error
if err := WaitForSpans(ctx, 30*time.Second); err != nil {
errs = append(errs, err)
}
if err := srv.remoteClient.Stop(ctx); err != nil {
errs = append(errs, err)
}
srv.cc.Close()
return errors.Join(errs...)
}

View file

@ -0,0 +1,216 @@
package trace
import (
"context"
"errors"
"fmt"
"runtime"
"sync"
"sync/atomic"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/exporters/otlp/otlptrace"
"go.opentelemetry.io/otel/sdk/resource"
sdktrace "go.opentelemetry.io/otel/sdk/trace"
semconv "go.opentelemetry.io/otel/semconv/v1.26.0"
"go.opentelemetry.io/otel/trace"
"go.opentelemetry.io/otel/trace/noop"
coltracepb "go.opentelemetry.io/proto/otlp/collector/trace/v1"
)
// Options configures the tracing machinery installed by NewContext.
type Options struct {
DebugFlags DebugFlags
}
// NewContext is like the package-level [NewContext], but applies the
// receiver's options to the new context.
func (op Options) NewContext(parent context.Context, remoteClient otlptrace.Client) context.Context {
if systemContextFromContext(parent) != nil {
panic("parent already contains trace system context")
}
if remoteClient == nil {
panic("remoteClient cannot be nil (use trace.NoopClient instead)")
}
sys := &systemContext{
options: op,
remoteClient: remoteClient,
tpm: &tracerProviderManager{},
}
if op.DebugFlags.Check(TrackSpanReferences) {
sys.observer = newSpanObserver()
}
ctx := context.WithValue(parent, systemContextKey, sys)
sys.exporterServer = NewServer(ctx)
sys.exporterServer.Start(ctx)
return ctx
}
// NewContext creates a new top-level background context with tracing machinery
// and configuration that will be used when creating new tracer providers.
//
// Any context created with NewContext should eventually be shut down by calling
// [ShutdownContext] to ensure all traces are exported.
//
// The parent context should be context.Background(), or a background context
// containing a logger. If any context in the parent's hierarchy was created
// by NewContext, this will panic.
func NewContext(parent context.Context, remoteClient otlptrace.Client) context.Context {
return Options{}.NewContext(parent, remoteClient)
}
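// Illustrative lifecycle sketch (not part of this change), from an importing
// package; the collector endpoint is an assumption for the example.
//
//	client := otlptracegrpc.NewClient(otlptracegrpc.WithEndpoint("collector:4317"))
//	ctx := trace.NewContext(context.Background(), client)
//	defer func() {
//		_ = trace.ShutdownContext(ctx)
//	}()
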
// NewTracerProvider creates a new [trace.TracerProvider] with the given service
// name and options.
//
// A context returned by [NewContext] must exist somewhere in the hierarchy of
// ctx, otherwise a no-op TracerProvider is returned. The configuration embedded
// within that context will be used to configure its resource attributes and
// exporter automatically.
func NewTracerProvider(ctx context.Context, serviceName string, opts ...sdktrace.TracerProviderOption) trace.TracerProvider {
sys := systemContextFromContext(ctx)
if sys == nil {
return noop.NewTracerProvider()
}
_, file, line, _ := runtime.Caller(1)
exp, err := otlptrace.New(ctx, sys.exporterServer.NewClient())
if err != nil {
panic(err)
}
r, err := resource.Merge(
resource.Default(),
resource.NewWithAttributes(
semconv.SchemaURL,
semconv.ServiceName(serviceName),
attribute.String("provider.created_at", fmt.Sprintf("%s:%d", file, line)),
),
)
if err != nil {
panic(err)
}
options := []sdktrace.TracerProviderOption{}
if sys.options.DebugFlags.Check(TrackSpanCallers) {
options = append(options, sdktrace.WithSpanProcessor(&stackTraceProcessor{}))
}
if sys.options.DebugFlags.Check(TrackSpanReferences) {
tracker := newSpanTracker(sys.observer, sys.options.DebugFlags)
options = append(options, sdktrace.WithSpanProcessor(tracker))
}
options = append(append(options,
sdktrace.WithBatcher(exp),
sdktrace.WithResource(r),
), opts...)
tp := sdktrace.NewTracerProvider(options...)
sys.tpm.Add(tp)
return tp
}
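// Illustrative sketch (not part of this change), from an importing package;
// the service and tracer names are assumptions, and ctx must descend from a
// context returned by NewContext.
//
//	tp := trace.NewTracerProvider(ctx, "Example Service")
//	ctx, span := tp.Tracer("example").Start(ctx, "DoWork")
//	defer span.End()
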
// Continue starts a new span using the tracer provider of the span in the given
// context.
//
// In most cases, it is better to start spans directly from a specific tracer,
// obtained via dependency injection or some other mechanism. This function is
// useful in shared code where the tracer used to start the span is not
// necessarily the same every time, but can change based on the call site.
func Continue(ctx context.Context, name string, o ...trace.SpanStartOption) (context.Context, trace.Span) {
return trace.SpanFromContext(ctx).
TracerProvider().
Tracer(PomeriumCoreTracer).
Start(ctx, name, o...)
}
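// Illustrative sketch (not part of this change): continuing a trace inside a
// shared helper, whatever tracer provider created the caller's span. The
// helper name is hypothetical.
//
//	func doSharedWork(ctx context.Context) error {
//		ctx, span := trace.Continue(ctx, "doSharedWork")
//		defer span.End()
//		// ... use ctx for downstream calls ...
//		return nil
//	}
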
// ShutdownContext will gracefully shut down all tracing resources created with
// a context returned by [NewContext], including all tracer providers and the
// underlying exporter and remote client.
//
// This should be called once before exiting; subsequent calls are no-ops.
//
// The provided context does not necessarily need to be the exact context
// returned by [NewContext]; it can be anywhere in its context hierarchy and
// this function will have the same effect.
func ShutdownContext(ctx context.Context) error {
sys := systemContextFromContext(ctx)
if sys == nil {
panic("context was not created with trace.NewContext")
}
if !sys.shutdown.CompareAndSwap(false, true) {
return nil
}
var errs []error
if err := sys.tpm.ShutdownAll(context.Background()); err != nil {
errs = append(errs, fmt.Errorf("error shutting down tracer providers: %w", err))
}
if err := sys.exporterServer.Shutdown(context.Background()); err != nil && !errors.Is(err, ErrNoClient) {
errs = append(errs, fmt.Errorf("error shutting down trace exporter: %w", err))
}
return errors.Join(errs...)
}
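// Illustrative sketch (not part of this change): ShutdownContext accepts any
// context descended from the one returned by NewContext, e.g. a cancelable
// child used to run servers.
//
//	srvCtx, cancel := context.WithCancel(ctx) // ctx from NewContext
//	defer cancel()
//	// ... run servers with srvCtx ...
//	_ = trace.ShutdownContext(srvCtx)
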
// ExporterServerFromContext returns the in-process trace collector associated
// with the given context, or nil if the context was not created by NewContext.
func ExporterServerFromContext(ctx context.Context) coltracepb.TraceServiceServer {
if sys := systemContextFromContext(ctx); sys != nil {
return sys.exporterServer
}
return nil
}
// RemoteClientFromContext returns the remote OTLP client associated with the
// given context, or nil if the context was not created by NewContext.
func RemoteClientFromContext(ctx context.Context) otlptrace.Client {
if sys := systemContextFromContext(ctx); sys != nil {
return sys.remoteClient
}
return nil
}
// ForceFlush immediately exports all spans that have not yet been exported for
// all tracer providers created using the given context.
func ForceFlush(ctx context.Context) error {
if sys := systemContextFromContext(ctx); sys != nil {
var errs []error
for _, tp := range sys.tpm.tracerProviders {
errs = append(errs, tp.ForceFlush(ctx))
}
return errors.Join(errs...)
}
return nil
}
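// Illustrative sketch (not part of this change): flushing pending spans, e.g.
// in a test, before inspecting whatever the remote client received.
//
//	_, span := trace.Continue(ctx, "operation")
//	span.End()
//	if err := trace.ForceFlush(ctx); err != nil {
//		t.Fatal(err)
//	}
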
type systemContextKeyType struct{}
var systemContextKey systemContextKeyType
type systemContext struct {
options Options
remoteClient otlptrace.Client
tpm *tracerProviderManager
observer *spanObserver
exporterServer *ExporterServer
shutdown atomic.Bool
}
func systemContextFromContext(ctx context.Context) *systemContext {
sys, _ := ctx.Value(systemContextKey).(*systemContext)
return sys
}
type tracerProviderManager struct {
mu sync.Mutex
tracerProviders []*sdktrace.TracerProvider
}
func (tpm *tracerProviderManager) ShutdownAll(ctx context.Context) error {
tpm.mu.Lock()
defer tpm.mu.Unlock()
var errs []error
for _, tp := range tpm.tracerProviders {
errs = append(errs, tp.ForceFlush(ctx))
}
for _, tp := range tpm.tracerProviders {
errs = append(errs, tp.Shutdown(ctx))
}
	// truncate, rather than clear in place, so that a later iteration (e.g.
	// ForceFlush) does not encounter nil providers
	tpm.tracerProviders = tpm.tracerProviders[:0]
return errors.Join(errs...)
}
func (tpm *tracerProviderManager) Add(tp *sdktrace.TracerProvider) {
tpm.mu.Lock()
defer tpm.mu.Unlock()
tpm.tracerProviders = append(tpm.tracerProviders, tp)
}

View file

@ -0,0 +1,72 @@
package trace
import (
"cmp"
"io"
"slices"
"testing"
"time"
oteltrace "go.opentelemetry.io/otel/trace"
)
// Exports of unexported identifiers for use by external tests.
var (
NewSpanObserver = newSpanObserver
NewSpanTracker = newSpanTracker
)
type XStackTraceProcessor = stackTraceProcessor
func (obs *spanObserver) XWait() {
obs.wait(5 * time.Second)
}
func (obs *spanObserver) XUnobservedIDs() []oteltrace.SpanID {
obs.cond.L.Lock()
defer obs.cond.L.Unlock()
ids := []oteltrace.SpanID{}
for k, v := range obs.referencedIDs {
if v.IsValid() {
ids = append(ids, k)
}
}
slices.SortFunc(ids, func(a, b oteltrace.SpanID) int {
return cmp.Compare(a.String(), b.String())
})
return ids
}
func (obs *spanObserver) XObservedIDs() []oteltrace.SpanID {
obs.cond.L.Lock()
defer obs.cond.L.Unlock()
ids := []oteltrace.SpanID{}
for k, v := range obs.referencedIDs {
if !v.IsValid() {
ids = append(ids, k)
}
}
slices.SortFunc(ids, func(a, b oteltrace.SpanID) int {
return cmp.Compare(a.String(), b.String())
})
return ids
}
func (t *spanTracker) XInflightSpans() []oteltrace.SpanID {
ids := []oteltrace.SpanID{}
t.inflightSpansMu.LockAll()
t.inflightSpans.Range(func(key oteltrace.SpanID) {
ids = append(ids, key)
})
t.inflightSpansMu.UnlockAll()
slices.SortFunc(ids, func(a, b oteltrace.SpanID) int {
return cmp.Compare(a.String(), b.String())
})
return ids
}
func SetDebugMessageWriterForTest(t testing.TB, w io.Writer) {
debugMessageWriter = w
t.Cleanup(func() {
debugMessageWriter = nil
})
}

View file

@ -0,0 +1,32 @@
package trace
import (
"unique"
oteltrace "go.opentelemetry.io/otel/trace"
)
var (
zeroSpanID oteltrace.SpanID
zeroTraceID = unique.Make(oteltrace.TraceID([16]byte{}))
)
// ToSpanID converts a byte slice to a span ID. An empty slice yields the zero
// span ID; any length other than 0 or 8 is reported as invalid.
func ToSpanID(bytes []byte) (oteltrace.SpanID, bool) {
switch len(bytes) {
case 0:
return zeroSpanID, true
case 8:
return oteltrace.SpanID(bytes), true
}
return zeroSpanID, false
}
// ToTraceID converts a byte slice to a trace ID handle. An empty slice yields
// the zero trace ID; any length other than 0 or 16 is reported as invalid.
func ToTraceID(bytes []byte) (unique.Handle[oteltrace.TraceID], bool) {
switch len(bytes) {
case 0:
return zeroTraceID, true
case 16:
return unique.Make(oteltrace.TraceID(bytes)), true
}
return zeroTraceID, false
}
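// Illustrative sketch (not part of this change): converting the raw IDs from
// an OTLP protobuf span, where a missing ID is an empty byte slice. The helper
// name and the "tracev1" import alias are assumptions for the example.
//
//	func spanIDsFromProto(s *tracev1.Span) (oteltrace.TraceID, oteltrace.SpanID, bool) {
//		traceID, ok1 := ToTraceID(s.GetTraceId())
//		spanID, ok2 := ToSpanID(s.GetSpanId())
//		return traceID.Value(), spanID, ok1 && ok2
//	}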