[tracing] refactor to use custom extension for trace id editing (#5420)

refactor to use custom extension for trace id editing
Joe Kralicky 2025-01-08 16:06:33 -05:00
parent de68673819
commit 86bf8a1d5f
No known key found for this signature in database
GPG key ID: 75C4875F34A9FB79
36 changed files with 1144 additions and 2672 deletions


@ -3,6 +3,7 @@ package authenticate
import ( import (
"context" "context"
"encoding/base64" "encoding/base64"
"encoding/hex"
"errors" "errors"
"fmt" "fmt"
"net/http" "net/http"
@ -21,6 +22,7 @@ import (
"github.com/pomerium/pomerium/internal/log" "github.com/pomerium/pomerium/internal/log"
"github.com/pomerium/pomerium/internal/middleware" "github.com/pomerium/pomerium/internal/middleware"
"github.com/pomerium/pomerium/internal/sessions" "github.com/pomerium/pomerium/internal/sessions"
"github.com/pomerium/pomerium/internal/telemetry/trace"
"github.com/pomerium/pomerium/internal/urlutil" "github.com/pomerium/pomerium/internal/urlutil"
"github.com/pomerium/pomerium/pkg/cryptutil" "github.com/pomerium/pomerium/pkg/cryptutil"
"github.com/pomerium/pomerium/pkg/identity" "github.com/pomerium/pomerium/pkg/identity"
@ -282,9 +284,20 @@ func (a *Authenticate) reauthenticateOrFail(w http.ResponseWriter, r *http.Reque
state.sessionStore.ClearSession(w, r) state.sessionStore.ClearSession(w, r)
redirectURL := state.redirectURL.ResolveReference(r.URL) redirectURL := state.redirectURL.ResolveReference(r.URL)
redirectURLValues := redirectURL.Query()
var traceID string
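// A W3C traceparent value is exactly 55 characters:
// "00-" + 32 hex trace-id chars + "-" + 16 hex span-id chars + "-" + 2 hex flag chars.
// Keep the 16-byte trace ID plus the 1-byte flags and pack them into a single
// base64url value so they survive the round trip through the encrypted state parameter.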
if tp := trace.PomeriumURLQueryCarrier(redirectURLValues).Get("traceparent"); len(tp) == 55 {
if traceIDBytes, err := hex.DecodeString(tp[3:35]); err == nil {
traceFlags, _ := hex.DecodeString(tp[53:55])
if len(traceFlags) != 1 {
traceFlags = []byte{0}
}
traceID = base64.RawURLEncoding.EncodeToString(append(traceIDBytes, traceFlags[0]))
}
}
nonce := csrf.Token(r) nonce := csrf.Token(r)
now := time.Now().Unix() now := time.Now().Unix()
b := []byte(fmt.Sprintf("%s|%d|", nonce, now)) b := []byte(fmt.Sprintf("%s|%d|%s|", nonce, now, traceID))
enc := cryptutil.Encrypt(state.cookieCipher, []byte(redirectURL.String()), b) enc := cryptutil.Encrypt(state.cookieCipher, []byte(redirectURL.String()), b)
b = append(b, enc...) b = append(b, enc...)
encodedState := base64.URLEncoding.EncodeToString(b) encodedState := base64.URLEncoding.EncodeToString(b)
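Aside (not part of the diff): the packed value can be unpacked with the reverse operations. A minimal sketch, where decodeTraceIDAndFlags is a hypothetical helper and the trace ID is the W3C trace-context example value:

package main

import (
    "encoding/base64"
    "encoding/hex"
    "errors"
    "fmt"
)

// decodeTraceIDAndFlags reverses the packing used above:
// 16 trace-ID bytes followed by 1 flags byte, base64url without padding.
func decodeTraceIDAndFlags(s string) (traceID string, flags byte, err error) {
    raw, err := base64.RawURLEncoding.DecodeString(s)
    if err != nil {
        return "", 0, err
    }
    if len(raw) != 17 {
        return "", 0, errors.New("unexpected length")
    }
    return hex.EncodeToString(raw[:16]), raw[16], nil
}

func main() {
    traceIDBytes, _ := hex.DecodeString("4bf92f3577b34da6a3ce929d0e0e4736")
    packed := base64.RawURLEncoding.EncodeToString(append(traceIDBytes, 0x01)) // flags = 01 (sampled)
    id, flags, _ := decodeTraceIDAndFlags(packed)
    fmt.Println(id, flags) // 4bf92f3577b34da6a3ce929d0e0e4736 1
}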
@ -306,10 +319,6 @@ func (a *Authenticate) OAuthCallback(w http.ResponseWriter, r *http.Request) err
if err != nil { if err != nil {
return fmt.Errorf("authenticate.OAuthCallback: %w", err) return fmt.Errorf("authenticate.OAuthCallback: %w", err)
} }
q := redirect.Query()
if traceparent := q.Get(urlutil.QueryTraceparent); traceparent != "" {
w.Header().Set("X-Pomerium-Traceparent", traceparent)
}
httputil.Redirect(w, r, redirect.String(), http.StatusFound) httputil.Redirect(w, r, redirect.String(), http.StatusFound)
return nil return nil
} }
@ -350,21 +359,20 @@ func (a *Authenticate) getOAuthCallback(w http.ResponseWriter, r *http.Request)
} }
// split state into concat'd components // split state into concat'd components
// (nonce|timestamp|redirect_url|encrypted_data(redirect_url)+mac(nonce,ts)) // (nonce|timestamp|trace_id+flags|encrypted_data(redirect_url)+mac(nonce,ts))
statePayload := strings.SplitN(string(bytes), "|", 3) statePayload := strings.SplitN(string(bytes), "|", 4)
if len(statePayload) != 3 { if len(statePayload) != 4 {
return nil, httputil.NewError(http.StatusBadRequest, fmt.Errorf("state malformed, size: %d", len(statePayload))) return nil, httputil.NewError(http.StatusBadRequest, fmt.Errorf("state malformed, size: %d", len(statePayload)))
} }
// Use our AEAD construct to enforce secrecy and authenticity: // Use our AEAD construct to enforce secrecy and authenticity:
// mac: to validate the nonce again, and above timestamp // mac: to validate the nonce again, and above timestamp
// decrypt: to prevent leaking 'redirect_uri' to IdP or logs // decrypt: to prevent leaking 'redirect_uri' to IdP or logs
b := []byte(fmt.Sprint(statePayload[0], "|", statePayload[1], "|")) b := []byte(fmt.Sprint(statePayload[0], "|", statePayload[1], "|", statePayload[2], "|"))
redirectString, err := cryptutil.Decrypt(state.cookieCipher, []byte(statePayload[2]), b) redirectString, err := cryptutil.Decrypt(state.cookieCipher, []byte(statePayload[3]), b)
if err != nil { if err != nil {
return nil, httputil.NewError(http.StatusBadRequest, err) return nil, httputil.NewError(http.StatusBadRequest, err)
} }
redirectURL, err := urlutil.ParseAndValidateURL(string(redirectString)) redirectURL, err := urlutil.ParseAndValidateURL(string(redirectString))
if err != nil { if err != nil {
return nil, httputil.NewError(http.StatusBadRequest, err) return nil, httputil.NewError(http.StatusBadRequest, err)
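Aside (not part of the diff): the first three segments are passed as AEAD associated data, so editing the nonce, timestamp, or packed trace ID invalidates the ciphertext. A minimal sketch with the standard library's AES-GCM standing in for cryptutil and the cookie cipher (the real code also base64-encodes the whole state):

package main

import (
    "crypto/aes"
    "crypto/cipher"
    "crypto/rand"
    "fmt"
    "strings"
)

func main() {
    key := make([]byte, 32)
    rand.Read(key)
    block, _ := aes.NewCipher(key)
    aead, _ := cipher.NewGCM(block)

    // state layout: nonce|timestamp|trace_id+flags|ciphertext
    prefix := "csrf-nonce|1736370393|packed-trace-id|" // illustrative values
    nonce := make([]byte, aead.NonceSize())
    rand.Read(nonce)
    sealed := aead.Seal(nonce, nonce, []byte("https://route.example.com/"), []byte(prefix))
    state := prefix + string(sealed)

    // the callback re-derives the associated data from the first three segments
    parts := strings.SplitN(state, "|", 4)
    ad := []byte(parts[0] + "|" + parts[1] + "|" + parts[2] + "|")
    raw := []byte(parts[3])
    plaintext, err := aead.Open(nil, raw[:aead.NonceSize()], raw[aead.NonceSize():], ad)
    fmt.Println(string(plaintext), err) // https://route.example.com/ <nil>
}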


@ -368,9 +368,8 @@ func TestAuthenticate_OAuthCallback(t *testing.T) {
params.Add("error", tt.paramErr) params.Add("error", tt.paramErr)
params.Add("code", tt.code) params.Add("code", tt.code)
nonce := cryptutil.NewBase64Key() // mock csrf nonce := cryptutil.NewBase64Key() // mock csrf
// (nonce|timestamp|redirect_url|encrypt(redirect_url),mac(nonce,ts)) // (nonce|timestamp|trace_id+flags|encrypt(redirect_url),mac(nonce,ts))
b := []byte(fmt.Sprintf("%s|%d|%s", nonce, tt.ts, tt.extraMac)) b := []byte(fmt.Sprintf("%s|%d||%s", nonce, tt.ts, tt.extraMac))
enc := cryptutil.Encrypt(a.state.Load().cookieCipher, []byte(tt.redirectURI), b) enc := cryptutil.Encrypt(a.state.Load().cookieCipher, []byte(tt.redirectURI), b)
b = append(b, enc...) b = append(b, enc...)
encodedState := base64.URLEncoding.EncodeToString(b) encodedState := base64.URLEncoding.EncodeToString(b)


@ -232,11 +232,8 @@ func (a *Authorize) requireLoginResponse(
headers := http.Header{} headers := http.Header{}
if id := in.GetAttributes().GetRequest().GetHttp().GetHeaders()["traceparent"]; id != "" { if id := in.GetAttributes().GetRequest().GetHttp().GetHeaders()["traceparent"]; id != "" {
headers["X-Pomerium-Traceparent"] = []string{id}
headers["X-Pomerium-Tracestate"] = []string{"pomerium.traceparent=" + id} // TODO: this might not be necessary anymore
signInURLQuery = url.Values{} signInURLQuery = url.Values{}
signInURLQuery.Add("pomerium_traceparent", id) signInURLQuery.Add("pomerium_traceparent", id)
signInURLQuery.Add("pomerium_tracestate", "pomerium.traceparent="+id)
} }
redirectTo, err := state.authenticateFlow.AuthenticateSignInURL( redirectTo, err := state.authenticateFlow.AuthenticateSignInURL(
ctx, signInURLQuery, &checkRequestURL, idp.GetId()) ctx, signInURLQuery, &checkRequestURL, idp.GetId())


@ -0,0 +1,211 @@
// Code generated by protoc-gen-go. DO NOT EDIT.
// versions:
// protoc-gen-go v1.35.2
// protoc (unknown)
// source: github.com/pomerium/pomerium/config/envoyconfig/extensions/pomerium_otel.proto
package extensions
import (
_ "github.com/cncf/xds/go/udpa/annotations"
v3 "github.com/envoyproxy/go-control-plane/envoy/config/core/v3"
protoreflect "google.golang.org/protobuf/reflect/protoreflect"
protoimpl "google.golang.org/protobuf/runtime/protoimpl"
reflect "reflect"
sync "sync"
)
const (
// Verify that this generated code is sufficiently up-to-date.
_ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion)
// Verify that runtime/protoimpl is sufficiently up-to-date.
_ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20)
)
type OpenTelemetryConfig struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
GrpcService *v3.GrpcService `protobuf:"bytes,1,opt,name=grpc_service,json=grpcService,proto3" json:"grpc_service,omitempty"`
HttpService *v3.HttpService `protobuf:"bytes,3,opt,name=http_service,json=httpService,proto3" json:"http_service,omitempty"`
ServiceName string `protobuf:"bytes,2,opt,name=service_name,json=serviceName,proto3" json:"service_name,omitempty"`
ResourceDetectors []*v3.TypedExtensionConfig `protobuf:"bytes,4,rep,name=resource_detectors,json=resourceDetectors,proto3" json:"resource_detectors,omitempty"`
Sampler *v3.TypedExtensionConfig `protobuf:"bytes,5,opt,name=sampler,proto3" json:"sampler,omitempty"`
}
func (x *OpenTelemetryConfig) Reset() {
*x = OpenTelemetryConfig{}
mi := &file_github_com_pomerium_pomerium_config_envoyconfig_extensions_pomerium_otel_proto_msgTypes[0]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
func (x *OpenTelemetryConfig) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*OpenTelemetryConfig) ProtoMessage() {}
func (x *OpenTelemetryConfig) ProtoReflect() protoreflect.Message {
mi := &file_github_com_pomerium_pomerium_config_envoyconfig_extensions_pomerium_otel_proto_msgTypes[0]
if x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use OpenTelemetryConfig.ProtoReflect.Descriptor instead.
func (*OpenTelemetryConfig) Descriptor() ([]byte, []int) {
return file_github_com_pomerium_pomerium_config_envoyconfig_extensions_pomerium_otel_proto_rawDescGZIP(), []int{0}
}
func (x *OpenTelemetryConfig) GetGrpcService() *v3.GrpcService {
if x != nil {
return x.GrpcService
}
return nil
}
func (x *OpenTelemetryConfig) GetHttpService() *v3.HttpService {
if x != nil {
return x.HttpService
}
return nil
}
func (x *OpenTelemetryConfig) GetServiceName() string {
if x != nil {
return x.ServiceName
}
return ""
}
func (x *OpenTelemetryConfig) GetResourceDetectors() []*v3.TypedExtensionConfig {
if x != nil {
return x.ResourceDetectors
}
return nil
}
func (x *OpenTelemetryConfig) GetSampler() *v3.TypedExtensionConfig {
if x != nil {
return x.Sampler
}
return nil
}
var File_github_com_pomerium_pomerium_config_envoyconfig_extensions_pomerium_otel_proto protoreflect.FileDescriptor
var file_github_com_pomerium_pomerium_config_envoyconfig_extensions_pomerium_otel_proto_rawDesc = []byte{
0x0a, 0x4e, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x70, 0x6f, 0x6d,
0x65, 0x72, 0x69, 0x75, 0x6d, 0x2f, 0x70, 0x6f, 0x6d, 0x65, 0x72, 0x69, 0x75, 0x6d, 0x2f, 0x63,
0x6f, 0x6e, 0x66, 0x69, 0x67, 0x2f, 0x65, 0x6e, 0x76, 0x6f, 0x79, 0x63, 0x6f, 0x6e, 0x66, 0x69,
0x67, 0x2f, 0x65, 0x78, 0x74, 0x65, 0x6e, 0x73, 0x69, 0x6f, 0x6e, 0x73, 0x2f, 0x70, 0x6f, 0x6d,
0x65, 0x72, 0x69, 0x75, 0x6d, 0x5f, 0x6f, 0x74, 0x65, 0x6c, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f,
0x12, 0x13, 0x70, 0x6f, 0x6d, 0x65, 0x72, 0x69, 0x75, 0x6d, 0x2e, 0x65, 0x78, 0x74, 0x65, 0x6e,
0x73, 0x69, 0x6f, 0x6e, 0x73, 0x1a, 0x24, 0x65, 0x6e, 0x76, 0x6f, 0x79, 0x2f, 0x63, 0x6f, 0x6e,
0x66, 0x69, 0x67, 0x2f, 0x63, 0x6f, 0x72, 0x65, 0x2f, 0x76, 0x33, 0x2f, 0x65, 0x78, 0x74, 0x65,
0x6e, 0x73, 0x69, 0x6f, 0x6e, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x1a, 0x27, 0x65, 0x6e, 0x76,
0x6f, 0x79, 0x2f, 0x63, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x2f, 0x63, 0x6f, 0x72, 0x65, 0x2f, 0x76,
0x33, 0x2f, 0x67, 0x72, 0x70, 0x63, 0x5f, 0x73, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x2e, 0x70,
0x72, 0x6f, 0x74, 0x6f, 0x1a, 0x27, 0x65, 0x6e, 0x76, 0x6f, 0x79, 0x2f, 0x63, 0x6f, 0x6e, 0x66,
0x69, 0x67, 0x2f, 0x63, 0x6f, 0x72, 0x65, 0x2f, 0x76, 0x33, 0x2f, 0x68, 0x74, 0x74, 0x70, 0x5f,
0x73, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x1a, 0x1e, 0x75,
0x64, 0x70, 0x61, 0x2f, 0x61, 0x6e, 0x6e, 0x6f, 0x74, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x2f,
0x6d, 0x69, 0x67, 0x72, 0x61, 0x74, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x1a, 0x1d, 0x75,
0x64, 0x70, 0x61, 0x2f, 0x61, 0x6e, 0x6e, 0x6f, 0x74, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x2f,
0x73, 0x74, 0x61, 0x74, 0x75, 0x73, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x22, 0x93, 0x03, 0x0a,
0x13, 0x4f, 0x70, 0x65, 0x6e, 0x54, 0x65, 0x6c, 0x65, 0x6d, 0x65, 0x74, 0x72, 0x79, 0x43, 0x6f,
0x6e, 0x66, 0x69, 0x67, 0x12, 0x5b, 0x0a, 0x0c, 0x67, 0x72, 0x70, 0x63, 0x5f, 0x73, 0x65, 0x72,
0x76, 0x69, 0x63, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x21, 0x2e, 0x65, 0x6e, 0x76,
0x6f, 0x79, 0x2e, 0x63, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x2e, 0x63, 0x6f, 0x72, 0x65, 0x2e, 0x76,
0x33, 0x2e, 0x47, 0x72, 0x70, 0x63, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x42, 0x15, 0xf2,
0x98, 0xfe, 0x8f, 0x05, 0x0f, 0x12, 0x0d, 0x6f, 0x74, 0x6c, 0x70, 0x5f, 0x65, 0x78, 0x70, 0x6f,
0x72, 0x74, 0x65, 0x72, 0x52, 0x0b, 0x67, 0x72, 0x70, 0x63, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63,
0x65, 0x12, 0x5b, 0x0a, 0x0c, 0x68, 0x74, 0x74, 0x70, 0x5f, 0x73, 0x65, 0x72, 0x76, 0x69, 0x63,
0x65, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x21, 0x2e, 0x65, 0x6e, 0x76, 0x6f, 0x79, 0x2e,
0x63, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x2e, 0x63, 0x6f, 0x72, 0x65, 0x2e, 0x76, 0x33, 0x2e, 0x48,
0x74, 0x74, 0x70, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x42, 0x15, 0xf2, 0x98, 0xfe, 0x8f,
0x05, 0x0f, 0x12, 0x0d, 0x6f, 0x74, 0x6c, 0x70, 0x5f, 0x65, 0x78, 0x70, 0x6f, 0x72, 0x74, 0x65,
0x72, 0x52, 0x0b, 0x68, 0x74, 0x74, 0x70, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x12, 0x21,
0x0a, 0x0c, 0x73, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x5f, 0x6e, 0x61, 0x6d, 0x65, 0x18, 0x02,
0x20, 0x01, 0x28, 0x09, 0x52, 0x0b, 0x73, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x4e, 0x61, 0x6d,
0x65, 0x12, 0x59, 0x0a, 0x12, 0x72, 0x65, 0x73, 0x6f, 0x75, 0x72, 0x63, 0x65, 0x5f, 0x64, 0x65,
0x74, 0x65, 0x63, 0x74, 0x6f, 0x72, 0x73, 0x18, 0x04, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x2a, 0x2e,
0x65, 0x6e, 0x76, 0x6f, 0x79, 0x2e, 0x63, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x2e, 0x63, 0x6f, 0x72,
0x65, 0x2e, 0x76, 0x33, 0x2e, 0x54, 0x79, 0x70, 0x65, 0x64, 0x45, 0x78, 0x74, 0x65, 0x6e, 0x73,
0x69, 0x6f, 0x6e, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x11, 0x72, 0x65, 0x73, 0x6f, 0x75,
0x72, 0x63, 0x65, 0x44, 0x65, 0x74, 0x65, 0x63, 0x74, 0x6f, 0x72, 0x73, 0x12, 0x44, 0x0a, 0x07,
0x73, 0x61, 0x6d, 0x70, 0x6c, 0x65, 0x72, 0x18, 0x05, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x2a, 0x2e,
0x65, 0x6e, 0x76, 0x6f, 0x79, 0x2e, 0x63, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x2e, 0x63, 0x6f, 0x72,
0x65, 0x2e, 0x76, 0x33, 0x2e, 0x54, 0x79, 0x70, 0x65, 0x64, 0x45, 0x78, 0x74, 0x65, 0x6e, 0x73,
0x69, 0x6f, 0x6e, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x07, 0x73, 0x61, 0x6d, 0x70, 0x6c,
0x65, 0x72, 0x42, 0x44, 0xba, 0x80, 0xc8, 0xd1, 0x06, 0x02, 0x10, 0x02, 0x5a, 0x3a, 0x67, 0x69,
0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x70, 0x6f, 0x6d, 0x65, 0x72, 0x69, 0x75,
0x6d, 0x2f, 0x70, 0x6f, 0x6d, 0x65, 0x72, 0x69, 0x75, 0x6d, 0x2f, 0x63, 0x6f, 0x6e, 0x66, 0x69,
0x67, 0x2f, 0x65, 0x6e, 0x76, 0x6f, 0x79, 0x63, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x2f, 0x65, 0x78,
0x74, 0x65, 0x6e, 0x73, 0x69, 0x6f, 0x6e, 0x73, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33,
}
var (
file_github_com_pomerium_pomerium_config_envoyconfig_extensions_pomerium_otel_proto_rawDescOnce sync.Once
file_github_com_pomerium_pomerium_config_envoyconfig_extensions_pomerium_otel_proto_rawDescData = file_github_com_pomerium_pomerium_config_envoyconfig_extensions_pomerium_otel_proto_rawDesc
)
func file_github_com_pomerium_pomerium_config_envoyconfig_extensions_pomerium_otel_proto_rawDescGZIP() []byte {
file_github_com_pomerium_pomerium_config_envoyconfig_extensions_pomerium_otel_proto_rawDescOnce.Do(func() {
file_github_com_pomerium_pomerium_config_envoyconfig_extensions_pomerium_otel_proto_rawDescData = protoimpl.X.CompressGZIP(file_github_com_pomerium_pomerium_config_envoyconfig_extensions_pomerium_otel_proto_rawDescData)
})
return file_github_com_pomerium_pomerium_config_envoyconfig_extensions_pomerium_otel_proto_rawDescData
}
var file_github_com_pomerium_pomerium_config_envoyconfig_extensions_pomerium_otel_proto_msgTypes = make([]protoimpl.MessageInfo, 1)
var file_github_com_pomerium_pomerium_config_envoyconfig_extensions_pomerium_otel_proto_goTypes = []any{
(*OpenTelemetryConfig)(nil), // 0: pomerium.extensions.OpenTelemetryConfig
(*v3.GrpcService)(nil), // 1: envoy.config.core.v3.GrpcService
(*v3.HttpService)(nil), // 2: envoy.config.core.v3.HttpService
(*v3.TypedExtensionConfig)(nil), // 3: envoy.config.core.v3.TypedExtensionConfig
}
var file_github_com_pomerium_pomerium_config_envoyconfig_extensions_pomerium_otel_proto_depIdxs = []int32{
1, // 0: pomerium.extensions.OpenTelemetryConfig.grpc_service:type_name -> envoy.config.core.v3.GrpcService
2, // 1: pomerium.extensions.OpenTelemetryConfig.http_service:type_name -> envoy.config.core.v3.HttpService
3, // 2: pomerium.extensions.OpenTelemetryConfig.resource_detectors:type_name -> envoy.config.core.v3.TypedExtensionConfig
3, // 3: pomerium.extensions.OpenTelemetryConfig.sampler:type_name -> envoy.config.core.v3.TypedExtensionConfig
4, // [4:4] is the sub-list for method output_type
4, // [4:4] is the sub-list for method input_type
4, // [4:4] is the sub-list for extension type_name
4, // [4:4] is the sub-list for extension extendee
0, // [0:4] is the sub-list for field type_name
}
func init() {
file_github_com_pomerium_pomerium_config_envoyconfig_extensions_pomerium_otel_proto_init()
}
func file_github_com_pomerium_pomerium_config_envoyconfig_extensions_pomerium_otel_proto_init() {
if File_github_com_pomerium_pomerium_config_envoyconfig_extensions_pomerium_otel_proto != nil {
return
}
type x struct{}
out := protoimpl.TypeBuilder{
File: protoimpl.DescBuilder{
GoPackagePath: reflect.TypeOf(x{}).PkgPath(),
RawDescriptor: file_github_com_pomerium_pomerium_config_envoyconfig_extensions_pomerium_otel_proto_rawDesc,
NumEnums: 0,
NumMessages: 1,
NumExtensions: 0,
NumServices: 0,
},
GoTypes: file_github_com_pomerium_pomerium_config_envoyconfig_extensions_pomerium_otel_proto_goTypes,
DependencyIndexes: file_github_com_pomerium_pomerium_config_envoyconfig_extensions_pomerium_otel_proto_depIdxs,
MessageInfos: file_github_com_pomerium_pomerium_config_envoyconfig_extensions_pomerium_otel_proto_msgTypes,
}.Build()
File_github_com_pomerium_pomerium_config_envoyconfig_extensions_pomerium_otel_proto = out.File
file_github_com_pomerium_pomerium_config_envoyconfig_extensions_pomerium_otel_proto_rawDesc = nil
file_github_com_pomerium_pomerium_config_envoyconfig_extensions_pomerium_otel_proto_goTypes = nil
file_github_com_pomerium_pomerium_config_envoyconfig_extensions_pomerium_otel_proto_depIdxs = nil
}


@ -0,0 +1,19 @@
syntax = "proto3";
package pomerium.extensions;
import "envoy/config/core/v3/extension.proto";
import "envoy/config/core/v3/grpc_service.proto";
import "envoy/config/core/v3/http_service.proto";
import "udpa/annotations/migrate.proto";
import "udpa/annotations/status.proto";
option (udpa.annotations.file_status).package_version_status = ACTIVE;
message OpenTelemetryConfig {
envoy.config.core.v3.GrpcService grpc_service = 1 [(udpa.annotations.field_migrate).oneof_promotion = "otlp_exporter"];
envoy.config.core.v3.HttpService http_service = 3 [(udpa.annotations.field_migrate).oneof_promotion = "otlp_exporter"];
string service_name = 2;
repeated envoy.config.core.v3.TypedExtensionConfig resource_detectors = 4;
envoy.config.core.v3.TypedExtensionConfig sampler = 5;
}
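Aside (not part of the diff): with the generated bindings above, this message is built like any other typed tracer config. A minimal sketch with illustrative values:

package main

import (
    "fmt"

    corev3 "github.com/envoyproxy/go-control-plane/envoy/config/core/v3"

    "github.com/pomerium/pomerium/config/envoyconfig/extensions"
)

func main() {
    cfg := &extensions.OpenTelemetryConfig{
        GrpcService: &corev3.GrpcService{
            TargetSpecifier: &corev3.GrpcService_EnvoyGrpc_{
                EnvoyGrpc: &corev3.GrpcService_EnvoyGrpc{ClusterName: "pomerium-control-plane-grpc"},
            },
        },
        ServiceName: "pomerium",
    }
    fmt.Println(cfg.GetServiceName())
}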


@ -34,16 +34,6 @@ func ExtAuthzFilter(grpcClientTimeout *durationpb.Duration) *envoy_extensions_fi
ClusterName: "pomerium-authorize", ClusterName: "pomerium-authorize",
}, },
}, },
InitialMetadata: []*envoy_config_core_v3.HeaderValue{
{
Key: "x-pomerium-traceparent",
Value: `%DYNAMIC_METADATA(pomerium.internal:traceparent)%`,
},
{
Key: "x-pomerium-tracestate",
Value: `%DYNAMIC_METADATA(pomerium.internal:tracestate)%`,
},
},
}, },
}, },
MetadataContextNamespaces: []string{"com.pomerium.client-certificate-info"}, MetadataContextNamespaces: []string{"com.pomerium.client-certificate-info"},


@ -39,22 +39,6 @@ func (b *Builder) buildVirtualHost(
return nil, err return nil, err
} }
vh.Routes = append(vh.Routes, rs...) vh.Routes = append(vh.Routes, rs...)
vh.RequestHeadersToAdd = []*envoy_config_core_v3.HeaderValueOption{
{
Header: &envoy_config_core_v3.HeaderValue{
Key: "x-pomerium-traceparent",
Value: `%DYNAMIC_METADATA(pomerium.internal:traceparent)%`,
},
AppendAction: envoy_config_core_v3.HeaderValueOption_OVERWRITE_IF_EXISTS_OR_ADD,
},
{
Header: &envoy_config_core_v3.HeaderValue{
Key: "x-pomerium-tracestate",
Value: `%DYNAMIC_METADATA(pomerium.internal:tracestate)%`,
},
AppendAction: envoy_config_core_v3.HeaderValueOption_APPEND_IF_EXISTS_OR_ADD,
},
}
return vh, nil return vh, nil
} }


@ -103,7 +103,7 @@ func TestBuildListeners(t *testing.T) {
}] }]
} }
} }
}`, httpConfig.Get("httpFilters.7").String(), }`, httpConfig.Get("httpFilters.6").String(),
"should add alt-svc header") "should add alt-svc header")
case "quic-ingress": case "quic-ingress":
hasQUIC = true hasQUIC = true
@ -149,6 +149,7 @@ func Test_buildMainHTTPConnectionManagerFilter(t *testing.T) {
options.SkipXffAppend = true options.SkipXffAppend = true
options.XffNumTrustedHops = 1 options.XffNumTrustedHops = 1
options.AuthenticateURLString = "https://authenticate.example.com" options.AuthenticateURLString = "https://authenticate.example.com"
options.TracingProvider = "otlp"
filter, err := b.buildMainHTTPConnectionManagerFilter(context.Background(), &config.Config{Options: options}, false, false) filter, err := b.buildMainHTTPConnectionManagerFilter(context.Background(), &config.Config{Options: options}, false, false)
require.NoError(t, err) require.NoError(t, err)


@ -42,21 +42,6 @@ func TestBuilder_buildMainRouteConfiguration(t *testing.T) {
{ {
"name": "catch-all", "name": "catch-all",
"domains": ["*"], "domains": ["*"],
"requestHeadersToAdd": [
{
"appendAction": "OVERWRITE_IF_EXISTS_OR_ADD",
"header": {
"key": "x-pomerium-traceparent",
"value": "%DYNAMIC_METADATA(pomerium.internal:traceparent)%"
}
},
{
"header": {
"key": "x-pomerium-tracestate",
"value": "%DYNAMIC_METADATA(pomerium.internal:tracestate)%"
}
}
],
"routes": [ "routes": [
`+protojson.Format(b.buildControlPlanePathRoute(cfg.Options, "/ping"))+`, `+protojson.Format(b.buildControlPlanePathRoute(cfg.Options, "/ping"))+`,
`+protojson.Format(b.buildControlPlanePathRoute(cfg.Options, "/healthz"))+`, `+protojson.Format(b.buildControlPlanePathRoute(cfg.Options, "/healthz"))+`,


@ -32,60 +32,6 @@
} }
], ],
"httpFilters": [ "httpFilters": [
{
"name": "envoy.filters.http.header_to_metadata",
"typedConfig": {
"@type": "type.googleapis.com/envoy.extensions.filters.http.header_to_metadata.v3.Config",
"requestRules": [
{
"header": "x-pomerium-traceparent",
"onHeaderPresent": {
"metadataNamespace": "pomerium.internal",
"key": "traceparent"
}
},
{
"header": "x-pomerium-tracestate",
"onHeaderPresent": {
"metadataNamespace": "pomerium.internal",
"key": "tracestate"
}
},
{
"header": "x-pomerium-external-parent-span",
"onHeaderPresent": {
"key": "external-parent-span",
"metadataNamespace": "pomerium.internal"
},
"remove": true
},
{
"header": "x-pomerium-sampling-decision",
"onHeaderPresent": {
"metadataNamespace": "pomerium.internal",
"key": "sampling-decision"
},
"remove": true
}
],
"responseRules": [
{
"header": "x-pomerium-traceparent",
"onHeaderPresent": {
"metadataNamespace": "pomerium.internal",
"key": "traceparent"
}
},
{
"header": "x-pomerium-tracestate",
"onHeaderPresent": {
"metadataNamespace": "pomerium.internal",
"key": "tracestate"
}
}
]
}
},
{ {
"name": "envoy.filters.http.lua", "name": "envoy.filters.http.lua",
"typedConfig": { "typedConfig": {
@ -112,17 +58,7 @@
"envoyGrpc": { "envoyGrpc": {
"clusterName": "pomerium-authorize" "clusterName": "pomerium-authorize"
}, },
"timeout": "10s", "timeout": "10s"
"initialMetadata": [
{
"key": "x-pomerium-traceparent",
"value": "%DYNAMIC_METADATA(pomerium.internal:traceparent)%"
},
{
"key": "x-pomerium-tracestate",
"value": "%DYNAMIC_METADATA(pomerium.internal:tracestate)%"
}
]
}, },
"transportApiVersion": "V3", "transportApiVersion": "V3",
"statusOnError": { "statusOnError": {
@ -265,60 +201,10 @@
}, },
"verbose": true, "verbose": true,
"maxPathTagLength": 1024, "maxPathTagLength": 1024,
"customTags": [
{
"tag": "pomerium.traceparent",
"metadata": {
"kind": {
"request": {}
},
"metadataKey": {
"key": "pomerium.internal",
"path": [
{
"key": "traceparent"
}
]
}
}
},
{
"tag": "pomerium.tracestate",
"metadata": {
"kind": {
"request": {}
},
"metadataKey": {
"key": "pomerium.internal",
"path": [
{
"key": "tracestate"
}
]
}
}
},
{
"metadata": {
"kind": {
"request": {}
},
"metadataKey": {
"key": "pomerium.internal",
"path": [
{
"key": "external-parent-span"
}
]
}
},
"tag": "pomerium.external-parent-span"
}
],
"provider": { "provider": {
"name": "envoy.tracers.opentelemetry", "name": "envoy.tracers.pomerium_otel",
"typedConfig": { "typedConfig": {
"@type": "type.googleapis.com/envoy.config.trace.v3.OpenTelemetryConfig", "@type": "type.googleapis.com/pomerium.extensions.OpenTelemetryConfig",
"grpcService": { "grpcService": {
"envoyGrpc": { "envoyGrpc": {
"clusterName": "pomerium-control-plane-grpc" "clusterName": "pomerium-control-plane-grpc"


@ -1,17 +1,13 @@
package envoyconfig package envoyconfig
import ( import (
"fmt"
"os" "os"
"strconv" "strconv"
envoy_config_core_v3 "github.com/envoyproxy/go-control-plane/envoy/config/core/v3" envoy_config_core_v3 "github.com/envoyproxy/go-control-plane/envoy/config/core/v3"
tracev3 "github.com/envoyproxy/go-control-plane/envoy/config/trace/v3" tracev3 "github.com/envoyproxy/go-control-plane/envoy/config/trace/v3"
envoy_extensions_filters_http_header_to_metadata "github.com/envoyproxy/go-control-plane/envoy/extensions/filters/http/header_to_metadata/v3"
envoy_extensions_filters_network_http_connection_manager "github.com/envoyproxy/go-control-plane/envoy/extensions/filters/network/http_connection_manager/v3" envoy_extensions_filters_network_http_connection_manager "github.com/envoyproxy/go-control-plane/envoy/extensions/filters/network/http_connection_manager/v3"
envoy_extensions_tracers_otel "github.com/envoyproxy/go-control-plane/envoy/extensions/tracers/opentelemetry/resource_detectors/v3" envoy_extensions_tracers_otel "github.com/envoyproxy/go-control-plane/envoy/extensions/tracers/opentelemetry/resource_detectors/v3"
metadatav3 "github.com/envoyproxy/go-control-plane/envoy/type/metadata/v3"
envoy_tracing_v3 "github.com/envoyproxy/go-control-plane/envoy/type/tracing/v3"
envoy_type_v3 "github.com/envoyproxy/go-control-plane/envoy/type/v3" envoy_type_v3 "github.com/envoyproxy/go-control-plane/envoy/type/v3"
"github.com/pomerium/pomerium/config" "github.com/pomerium/pomerium/config"
"github.com/pomerium/pomerium/config/envoyconfig/extensions" "github.com/pomerium/pomerium/config/envoyconfig/extensions"
@ -40,9 +36,6 @@ func applyTracingConfig(
if !isTracingEnabled(opts) { if !isTracingEnabled(opts) {
return return
} }
mgr.HttpFilters = append([]*envoy_extensions_filters_network_http_connection_manager.HttpFilter{
tracingMetadataFilter(),
}, mgr.HttpFilters...)
mgr.EarlyHeaderMutationExtensions = []*envoy_config_core_v3.TypedExtensionConfig{ mgr.EarlyHeaderMutationExtensions = []*envoy_config_core_v3.TypedExtensionConfig{
{ {
@ -63,38 +56,14 @@ func applyTracingConfig(
maxPathTagLength = max(64, uint32(num)) maxPathTagLength = max(64, uint32(num))
} }
} }
requestTag := func(key string) *envoy_tracing_v3.CustomTag {
return &envoy_tracing_v3.CustomTag{
Tag: fmt.Sprintf("pomerium.%s", key),
Type: &envoy_tracing_v3.CustomTag_Metadata_{
Metadata: &envoy_tracing_v3.CustomTag_Metadata{
Kind: &metadatav3.MetadataKind{
Kind: &metadatav3.MetadataKind_Request_{
Request: &metadatav3.MetadataKind_Request{},
},
},
MetadataKey: &metadatav3.MetadataKey{
Key: "pomerium.internal",
Path: []*metadatav3.MetadataKey_PathSegment{
{
Segment: &metadatav3.MetadataKey_PathSegment_Key{
Key: key,
},
},
},
},
},
},
}
}
mgr.Tracing = &envoy_extensions_filters_network_http_connection_manager.HttpConnectionManager_Tracing{ mgr.Tracing = &envoy_extensions_filters_network_http_connection_manager.HttpConnectionManager_Tracing{
RandomSampling: &envoy_type_v3.Percent{Value: opts.TracingSampleRate * 100}, RandomSampling: &envoy_type_v3.Percent{Value: opts.TracingSampleRate * 100},
Verbose: true, Verbose: true,
SpawnUpstreamSpan: wrapperspb.Bool(true), SpawnUpstreamSpan: wrapperspb.Bool(true),
Provider: &tracev3.Tracing_Http{ Provider: &tracev3.Tracing_Http{
Name: "envoy.tracers.opentelemetry", Name: "envoy.tracers.pomerium_otel",
ConfigType: &tracev3.Tracing_Http_TypedConfig{ ConfigType: &tracev3.Tracing_Http_TypedConfig{
TypedConfig: marshalAny(&tracev3.OpenTelemetryConfig{ TypedConfig: marshalAny(&extensions.OpenTelemetryConfig{
GrpcService: &envoy_config_core_v3.GrpcService{ GrpcService: &envoy_config_core_v3.GrpcService{
TargetSpecifier: &envoy_config_core_v3.GrpcService_EnvoyGrpc_{ TargetSpecifier: &envoy_config_core_v3.GrpcService_EnvoyGrpc_{
EnvoyGrpc: &envoy_config_core_v3.GrpcService_EnvoyGrpc{ EnvoyGrpc: &envoy_config_core_v3.GrpcService_EnvoyGrpc{
@ -118,62 +87,5 @@ func applyTracingConfig(
}, },
// this allows full URLs to be displayed in traces, they are otherwise truncated // this allows full URLs to be displayed in traces, they are otherwise truncated
MaxPathTagLength: wrapperspb.UInt32(maxPathTagLength), MaxPathTagLength: wrapperspb.UInt32(maxPathTagLength),
CustomTags: []*envoy_tracing_v3.CustomTag{
requestTag("traceparent"),
requestTag("tracestate"),
requestTag("external-parent-span"),
},
}
}
func tracingMetadataFilter() *envoy_extensions_filters_network_http_connection_manager.HttpFilter {
traceparentRule := &envoy_extensions_filters_http_header_to_metadata.Config_Rule{
Header: "x-pomerium-traceparent",
OnHeaderPresent: &envoy_extensions_filters_http_header_to_metadata.Config_KeyValuePair{
MetadataNamespace: "pomerium.internal",
Key: "traceparent",
},
Remove: false,
}
tracestateRule := &envoy_extensions_filters_http_header_to_metadata.Config_Rule{
Header: "x-pomerium-tracestate",
OnHeaderPresent: &envoy_extensions_filters_http_header_to_metadata.Config_KeyValuePair{
MetadataNamespace: "pomerium.internal",
Key: "tracestate",
},
Remove: false,
}
externalParentSpanRule := &envoy_extensions_filters_http_header_to_metadata.Config_Rule{
Header: "x-pomerium-external-parent-span",
OnHeaderPresent: &envoy_extensions_filters_http_header_to_metadata.Config_KeyValuePair{
MetadataNamespace: "pomerium.internal",
Key: "external-parent-span",
},
Remove: true,
}
samplingDecisionRule := &envoy_extensions_filters_http_header_to_metadata.Config_Rule{
Header: "x-pomerium-sampling-decision",
OnHeaderPresent: &envoy_extensions_filters_http_header_to_metadata.Config_KeyValuePair{
MetadataNamespace: "pomerium.internal",
Key: "sampling-decision",
},
Remove: true,
}
return &envoy_extensions_filters_network_http_connection_manager.HttpFilter{
Name: "envoy.filters.http.header_to_metadata",
ConfigType: &envoy_extensions_filters_network_http_connection_manager.HttpFilter_TypedConfig{
TypedConfig: marshalAny(&envoy_extensions_filters_http_header_to_metadata.Config{
RequestRules: []*envoy_extensions_filters_http_header_to_metadata.Config_Rule{
traceparentRule,
tracestateRule,
externalParentSpanRule,
samplingDecisionRule,
},
ResponseRules: []*envoy_extensions_filters_http_header_to_metadata.Config_Rule{
traceparentRule,
tracestateRule,
},
}),
},
} }
} }


@ -96,7 +96,7 @@ func New(ctx context.Context, cfg *config.Config, eventsMgr *events.Manager, opt
// No metrics handler because we have one in the control plane. Add one // No metrics handler because we have one in the control plane. Add one
// if we no longer register with that grpc Server // if we no longer register with that grpc Server
localGRPCServer := grpc.NewServer( localGRPCServer := grpc.NewServer(
grpc.StatsHandler(trace.NewServerStatsHandler(otelgrpc.NewServerHandler(otelgrpc.WithTracerProvider(tracerProvider)))), grpc.StatsHandler(otelgrpc.NewServerHandler(otelgrpc.WithTracerProvider(tracerProvider))),
grpc.ChainStreamInterceptor(log.StreamServerInterceptor(log.Ctx(ctx)), si), grpc.ChainStreamInterceptor(log.StreamServerInterceptor(log.Ctx(ctx)), si),
grpc.ChainUnaryInterceptor(log.UnaryServerInterceptor(log.Ctx(ctx)), ui), grpc.ChainUnaryInterceptor(log.UnaryServerInterceptor(log.Ctx(ctx)), ui),
) )

go.mod

@ -15,6 +15,7 @@ require (
github.com/cenkalti/backoff/v4 v4.3.0 github.com/cenkalti/backoff/v4 v4.3.0
github.com/cespare/xxhash/v2 v2.3.0 github.com/cespare/xxhash/v2 v2.3.0
github.com/cloudflare/circl v1.5.0 github.com/cloudflare/circl v1.5.0
github.com/cncf/xds/go v0.0.0-20240905190251-b4127c9b8d78
github.com/coreos/go-oidc/v3 v3.11.0 github.com/coreos/go-oidc/v3 v3.11.0
github.com/docker/docker v27.4.1+incompatible github.com/docker/docker v27.4.1+incompatible
github.com/envoyproxy/go-control-plane/envoy v1.32.2 github.com/envoyproxy/go-control-plane/envoy v1.32.2
@ -133,7 +134,6 @@ require (
github.com/aws/smithy-go v1.22.1 // indirect github.com/aws/smithy-go v1.22.1 // indirect
github.com/beorn7/perks v1.0.1 // indirect github.com/beorn7/perks v1.0.1 // indirect
github.com/caddyserver/zerossl v0.1.3 // indirect github.com/caddyserver/zerossl v0.1.3 // indirect
github.com/cncf/xds/go v0.0.0-20240905190251-b4127c9b8d78 // indirect
github.com/containerd/log v0.1.0 // indirect github.com/containerd/log v0.1.0 // indirect
github.com/containerd/platforms v0.2.1 // indirect github.com/containerd/platforms v0.2.1 // indirect
github.com/cpuguy83/dockercfg v0.3.2 // indirect github.com/cpuguy83/dockercfg v0.3.2 // indirect


@ -114,7 +114,7 @@ func NewServer(
), ),
) )
srv.GRPCServer = grpc.NewServer( srv.GRPCServer = grpc.NewServer(
grpc.StatsHandler(trace.NewServerStatsHandler(otelgrpc.NewServerHandler(otelgrpc.WithTracerProvider(tracerProvider)))), grpc.StatsHandler(otelgrpc.NewServerHandler(otelgrpc.WithTracerProvider(tracerProvider))),
grpc.ChainUnaryInterceptor( grpc.ChainUnaryInterceptor(
log.UnaryServerInterceptor(log.Ctx(ctx)), log.UnaryServerInterceptor(log.Ctx(ctx)),
requestid.UnaryServerInterceptor(), requestid.UnaryServerInterceptor(),


@ -3,7 +3,6 @@ package trace_test
import ( import (
"context" "context"
"fmt" "fmt"
"regexp"
"runtime" "runtime"
"strings" "strings"
"sync/atomic" "sync/atomic"
@ -12,11 +11,11 @@ import (
"github.com/pomerium/pomerium/internal/log" "github.com/pomerium/pomerium/internal/log"
"github.com/pomerium/pomerium/internal/telemetry/trace" "github.com/pomerium/pomerium/internal/telemetry/trace"
"github.com/pomerium/pomerium/internal/telemetry/trace/mock_otlptrace"
"github.com/pomerium/pomerium/internal/testenv" "github.com/pomerium/pomerium/internal/testenv"
"github.com/pomerium/pomerium/internal/testenv/scenarios" "github.com/pomerium/pomerium/internal/testenv/scenarios"
"github.com/pomerium/pomerium/internal/testenv/snippets" "github.com/pomerium/pomerium/internal/testenv/snippets"
"github.com/pomerium/pomerium/internal/testutil" . "github.com/pomerium/pomerium/internal/testutil/tracetest" //nolint:revive
"github.com/pomerium/pomerium/internal/testutil/tracetest/mock_otlptrace"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"go.opentelemetry.io/otel" "go.opentelemetry.io/otel"
@ -281,7 +280,7 @@ func TestNewRemoteClientFromEnv(t *testing.T) {
"OTEL_TRACES_EXPORTER": "otlp", "OTEL_TRACES_EXPORTER": "otlp",
"OTEL_EXPORTER_OTLP_ENDPOINT": grpcEndpoint.Value(), "OTEL_EXPORTER_OTLP_ENDPOINT": grpcEndpoint.Value(),
}, },
uploadErr: "net/http: HTTP/1.x transport connection broken", uploadErr: "net/http: HTTP/1.x transport connection broken: malformed HTTP response",
}, },
{ {
name: "HTTP endpoint, auto protocol", name: "HTTP endpoint, auto protocol",
@ -314,7 +313,7 @@ func TestNewRemoteClientFromEnv(t *testing.T) {
}, },
}, },
{ {
name: "no exporter", name: "exporter unset",
env: map[string]string{ env: map[string]string{
"OTEL_TRACES_EXPORTER": "", "OTEL_TRACES_EXPORTER": "",
"OTEL_EXPORTER_OTLP_TRACES_ENDPOINT": httpEndpoint.Value(), "OTEL_EXPORTER_OTLP_TRACES_ENDPOINT": httpEndpoint.Value(),
@ -323,7 +322,7 @@ func TestNewRemoteClientFromEnv(t *testing.T) {
expectNoSpans: true, expectNoSpans: true,
}, },
{ {
name: "no exporter", name: "exporter noop",
env: map[string]string{ env: map[string]string{
"OTEL_TRACES_EXPORTER": "noop", "OTEL_TRACES_EXPORTER": "noop",
"OTEL_EXPORTER_OTLP_TRACES_ENDPOINT": httpEndpoint.Value(), "OTEL_EXPORTER_OTLP_TRACES_ENDPOINT": httpEndpoint.Value(),
@ -332,7 +331,7 @@ func TestNewRemoteClientFromEnv(t *testing.T) {
expectNoSpans: true, expectNoSpans: true,
}, },
{ {
name: "no exporter", name: "exporter none",
env: map[string]string{ env: map[string]string{
"OTEL_TRACES_EXPORTER": "none", "OTEL_TRACES_EXPORTER": "none",
"OTEL_EXPORTER_OTLP_TRACES_ENDPOINT": httpEndpoint.Value(), "OTEL_EXPORTER_OTLP_TRACES_ENDPOINT": httpEndpoint.Value(),
@ -403,19 +402,6 @@ func TestNewRemoteClientFromEnv(t *testing.T) {
otel.SetErrorHandler(handler) otel.SetErrorHandler(handler)
t.Cleanup(func() { otel.SetErrorHandler(oldErrHandler) }) t.Cleanup(func() { otel.SetErrorHandler(oldErrHandler) })
if tc.uploadErr != "" {
recorder := env.NewLogRecorder(testenv.WithSkipCloseDelay())
defer func() {
recorder.Match([]map[string]any{
{
"level": "error",
"error": regexp.MustCompile(`^Post "[^"]+": net/http: HTTP/1.x transport connection broken: malformed HTTP response.*$`),
"message": "error uploading traces",
},
})
}()
}
remoteClient := trace.NewRemoteClientFromEnv() remoteClient := trace.NewRemoteClientFromEnv()
ctx := trace.Options{ ctx := trace.Options{
RemoteClient: remoteClient, RemoteClient: remoteClient,
@ -431,20 +417,20 @@ func TestNewRemoteClientFromEnv(t *testing.T) {
_, span := tp.Tracer(trace.PomeriumCoreTracer).Start(ctx, "test span") _, span := tp.Tracer(trace.PomeriumCoreTracer).Start(ctx, "test span")
span.End() span.End()
assert.NoError(t, trace.ForceFlush(ctx))
assert.NoError(t, trace.ShutdownContext(ctx))
if tc.uploadErr != "" { if tc.uploadErr != "" {
assert.ErrorContains(t, trace.ForceFlush(ctx), tc.uploadErr)
assert.NoError(t, trace.ShutdownContext(ctx))
return return
} }
assert.NoError(t, trace.ShutdownContext(ctx))
results := testutil.NewTraceResults(receiver.FlushResourceSpans()) results := NewTraceResults(receiver.FlushResourceSpans())
if tc.expectNoSpans { if tc.expectNoSpans {
results.MatchTraces(t, testutil.MatchOptions{Exact: true}) results.MatchTraces(t, MatchOptions{Exact: true})
} else { } else {
results.MatchTraces(t, testutil.MatchOptions{ results.MatchTraces(t, MatchOptions{
Exact: true, Exact: true,
}, testutil.Match{Name: t.Name() + ": test span", TraceCount: 1, Services: []string{t.Name()}}) }, Match{Name: t.Name() + ": test span", TraceCount: 1, Services: []string{t.Name()}})
} }
}) })
} }


@ -2,14 +2,23 @@ package trace
import ( import (
"context" "context"
"encoding/binary"
"encoding/json"
"errors"
"fmt" "fmt"
"io" "io"
"os" "os"
"runtime" "runtime"
"slices"
"strings" "strings"
"sync"
"time"
"go.opentelemetry.io/otel/attribute" "go.opentelemetry.io/otel/attribute"
sdktrace "go.opentelemetry.io/otel/sdk/trace" sdktrace "go.opentelemetry.io/otel/sdk/trace"
oteltrace "go.opentelemetry.io/otel/trace"
coltracepb "go.opentelemetry.io/proto/otlp/collector/trace/v1"
"google.golang.org/protobuf/encoding/protojson"
) )
type DebugFlags uint32 type DebugFlags uint32
@ -30,8 +39,10 @@ const (
// over time. // over time.
TrackAllSpans = (1 << iota) | TrackSpanCallers TrackAllSpans = (1 << iota) | TrackSpanCallers
// If set, will log all trace ID mappings on close. // If set, will log all trace IDs and their span counts on close.
LogTraceIDMappings = (1 << iota) //
// Enables [TrackAllSpans]
LogTraceIDs = (1 << iota) | TrackAllSpans
// If set, will log all spans observed by the exporter on close. These spans // If set, will log all spans observed by the exporter on close. These spans
// may belong to incomplete traces. // may belong to incomplete traces.
@ -39,12 +50,6 @@ const (
// Enables [TrackAllSpans] // Enables [TrackAllSpans]
LogAllSpans = (1 << iota) | TrackAllSpans LogAllSpans = (1 << iota) | TrackAllSpans
// If set, will log the raw json payloads and timestamps of export requests
// on close.
// Use with caution, this will cause significantly increasing memory usage
// over time.
LogAllEvents = (1 << iota)
// If set, will log all exported spans when a warning is issued on close // If set, will log all exported spans when a warning is issued on close
// (requires warning flags to also be set) // (requires warning flags to also be set)
// //
@ -53,7 +58,7 @@ const (
// If set, will log all trace ID mappings when a warning is issued on close. // If set, will log all trace ID mappings when a warning is issued on close.
// (requires warning flags to also be set) // (requires warning flags to also be set)
LogTraceIDMappingsOnWarn = (1 << iota) LogTraceIDsOnWarn = (1 << iota)
// If set, will print a warning to stderr on close if there are any incomplete // If set, will print a warning to stderr on close if there are any incomplete
// traces (traces with no observed root spans) // traces (traces with no observed root spans)
@ -78,6 +83,49 @@ func (df DebugFlags) Check(flags DebugFlags) bool {
return (df & flags) == flags return (df & flags) == flags
} }
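Aside (not part of the diff): several flags imply others by OR-ing in their bits (LogTraceIDs and LogAllSpans both enable TrackAllSpans), and Check tests for the full bit pattern. A small illustration:

package main

import (
    "fmt"

    "github.com/pomerium/pomerium/internal/telemetry/trace"
)

func main() {
    flags := trace.LogTraceIDs | trace.WarnOnIncompleteSpans
    fmt.Println(flags.Check(trace.TrackAllSpans))         // true: implied by LogTraceIDs
    fmt.Println(flags.Check(trace.WarnOnIncompleteSpans)) // true: set explicitly
    fmt.Println(flags.Check(trace.LogAllSpans))           // false: its bit is not set
}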
var (
ErrIncompleteSpans = errors.New("exporter shut down with incomplete spans")
ErrMissingParentSpans = errors.New("exporter shut down with missing parent spans")
)
// WaitForSpans will block up to the given max duration and wait for all
// in-flight spans from tracers created with the given context to end. This
// function can be called more than once, and is safe to call from multiple
// goroutines in parallel.
//
// This requires the [TrackSpanReferences] debug flag to have been set with
// [Options.NewContext]. Otherwise, this function is a no-op and will return
// immediately.
//
// If this function blocks for more than 10 seconds, it will print a warning
// to stderr containing a list of span IDs it is waiting for, and the IDs of
// their parents (if known). Additionally, if the [TrackAllSpans] debug flag
// is set, details about parent spans will be displayed, including call site
// and trace ID.
func WaitForSpans(ctx context.Context, maxDuration time.Duration) error {
if sys := systemContextFromContext(ctx); sys != nil && sys.observer != nil {
done := make(chan struct{})
go func() {
defer close(done)
sys.observer.wait(10 * time.Second)
}()
select {
case <-done:
return nil
case <-time.After(maxDuration):
return ErrMissingParentSpans
}
}
return nil
}
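Aside (not part of the diff): a usage sketch for a shutdown path. The exact Options.NewContext signature is not shown in this diff and is assumed here; the flag and error values come from the code above:

package main

import (
    "context"
    "log"
    "time"

    "github.com/pomerium/pomerium/internal/telemetry/trace"
)

func main() {
    // Assumption: Options.NewContext derives a context carrying a tracing
    // system configured with these options.
    ctx := trace.Options{DebugFlags: trace.TrackSpanReferences}.NewContext(context.Background())

    // ... start and end spans using tracer providers derived from ctx ...

    // Block until every referenced parent span has been observed, or give up
    // after 30 seconds and report the missing references.
    if err := trace.WaitForSpans(ctx, 30*time.Second); err != nil {
        log.Printf("incomplete traces at shutdown: %v", err)
    }
    _ = trace.ShutdownContext(ctx)
}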
func DebugFlagsFromContext(ctx context.Context) DebugFlags {
if sys := systemContextFromContext(ctx); sys != nil {
return sys.options.DebugFlags
}
return 0
}
type stackTraceProcessor struct{} type stackTraceProcessor struct{}
// ForceFlush implements trace.SpanProcessor. // ForceFlush implements trace.SpanProcessor.
@ -117,3 +165,329 @@ func endMsg(msg *strings.Builder) {
} }
fmt.Fprint(w, msg.String()) fmt.Fprint(w, msg.String())
} }
type DebugEvent struct {
Timestamp time.Time `json:"timestamp"`
Request *coltracepb.ExportTraceServiceRequest `json:"request"`
}
func (e DebugEvent) MarshalJSON() ([]byte, error) {
type debugEvent struct {
Timestamp time.Time `json:"timestamp"`
Request json.RawMessage `json:"request"`
}
reqData, _ := protojson.Marshal(e.Request)
return json.Marshal(debugEvent{
Timestamp: e.Timestamp,
Request: reqData,
})
}
func (e *DebugEvent) UnmarshalJSON(b []byte) error {
type debugEvent struct {
Timestamp time.Time `json:"timestamp"`
Request json.RawMessage `json:"request"`
}
var ev debugEvent
if err := json.Unmarshal(b, &ev); err != nil {
return err
}
e.Timestamp = ev.Timestamp
var msg coltracepb.ExportTraceServiceRequest
if err := protojson.Unmarshal(ev.Request, &msg); err != nil {
return err
}
e.Request = &msg
return nil
}
const shardCount = 64
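// Spans are spread across shardCount lock-striped sets, indexed by the span ID
// interpreted as a big-endian uint64 modulo shardCount, so concurrent
// OnStart/OnEnd calls mostly contend on different mutexes.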
type (
shardedSet [shardCount]map[oteltrace.SpanID]struct{}
shardedLocks [shardCount]sync.Mutex
)
func (s *shardedSet) Range(f func(key oteltrace.SpanID)) {
for i := range shardCount {
for k := range s[i] {
f(k)
}
}
}
func (s *shardedLocks) LockAll() {
for i := range shardCount {
s[i].Lock()
}
}
func (s *shardedLocks) UnlockAll() {
for i := range shardCount {
s[i].Unlock()
}
}
type spanTracker struct {
inflightSpansMu shardedLocks
inflightSpans shardedSet
allSpans sync.Map
debugFlags DebugFlags
observer *spanObserver
shutdownOnce sync.Once
}
func newSpanTracker(observer *spanObserver, debugFlags DebugFlags) *spanTracker {
st := &spanTracker{
observer: observer,
debugFlags: debugFlags,
}
for i := range len(st.inflightSpans) {
st.inflightSpans[i] = make(map[oteltrace.SpanID]struct{})
}
return st
}
type spanInfo struct {
Name string
SpanContext oteltrace.SpanContext
Parent oteltrace.SpanContext
caller string
startTime time.Time
}
// ForceFlush implements trace.SpanProcessor.
func (t *spanTracker) ForceFlush(context.Context) error {
return nil
}
// OnEnd implements trace.SpanProcessor.
func (t *spanTracker) OnEnd(s sdktrace.ReadOnlySpan) {
id := s.SpanContext().SpanID()
bucket := binary.BigEndian.Uint64(id[:]) % shardCount
t.inflightSpansMu[bucket].Lock()
defer t.inflightSpansMu[bucket].Unlock()
delete(t.inflightSpans[bucket], id)
}
// OnStart implements trace.SpanProcessor.
func (t *spanTracker) OnStart(_ context.Context, s sdktrace.ReadWriteSpan) {
id := s.SpanContext().SpanID()
bucket := binary.BigEndian.Uint64(id[:]) % shardCount
t.inflightSpansMu[bucket].Lock()
defer t.inflightSpansMu[bucket].Unlock()
t.inflightSpans[bucket][id] = struct{}{}
if t.debugFlags.Check(TrackSpanReferences) {
if s.Parent().IsValid() {
t.observer.ObserveReference(s.Parent().SpanID(), id)
}
t.observer.Observe(id)
}
if t.debugFlags.Check(TrackAllSpans) {
var caller string
for _, attr := range s.Attributes() {
if attr.Key == "caller" {
caller = attr.Value.AsString()
break
}
}
t.allSpans.Store(id, &spanInfo{
Name: s.Name(),
SpanContext: s.SpanContext(),
Parent: s.Parent(),
caller: caller,
startTime: s.StartTime(),
})
}
}
// Shutdown implements trace.SpanProcessor.
func (t *spanTracker) Shutdown(_ context.Context) error {
if t.debugFlags == 0 {
return nil
}
didWarn := false
t.shutdownOnce.Do(func() {
if t.debugFlags.Check(WarnOnUnresolvedReferences) {
var unknownParentIDs []string
for id, via := range t.observer.referencedIDs {
if via.IsValid() {
if t.debugFlags.Check(TrackAllSpans) {
if viaSpan, ok := t.allSpans.Load(via); ok {
unknownParentIDs = append(unknownParentIDs, fmt.Sprintf("%s via %s (%s)", id, via, viaSpan.(*spanInfo).Name))
} else {
unknownParentIDs = append(unknownParentIDs, fmt.Sprintf("%s via %s", id, via))
}
}
}
}
if len(unknownParentIDs) > 0 {
didWarn = true
msg := startMsg("WARNING: parent spans referenced but never seen:\n")
for _, str := range unknownParentIDs {
msg.WriteString(str)
msg.WriteString("\n")
}
endMsg(msg)
}
}
if t.debugFlags.Check(WarnOnIncompleteSpans) {
if t.debugFlags.Check(TrackAllSpans) {
incompleteSpans := []*spanInfo{}
t.inflightSpansMu.LockAll()
t.inflightSpans.Range(func(key oteltrace.SpanID) {
if info, ok := t.allSpans.Load(key); ok {
incompleteSpans = append(incompleteSpans, info.(*spanInfo))
}
})
t.inflightSpansMu.UnlockAll()
if len(incompleteSpans) > 0 {
didWarn = true
msg := startMsg("WARNING: spans not ended:\n")
longestName := 0
for _, span := range incompleteSpans {
longestName = max(longestName, len(span.Name)+2)
}
for _, span := range incompleteSpans {
var startedAt string
if span.caller != "" {
startedAt = " | started at: " + span.caller
}
fmt.Fprintf(msg, "%-*s (trace: %s | span: %s | parent: %s%s)\n", longestName, "'"+span.Name+"'",
span.SpanContext.TraceID(), span.SpanContext.SpanID(), span.Parent.SpanID(), startedAt)
}
endMsg(msg)
}
} else {
incompleteSpans := []oteltrace.SpanID{}
t.inflightSpansMu.LockAll()
t.inflightSpans.Range(func(key oteltrace.SpanID) {
incompleteSpans = append(incompleteSpans, key)
})
t.inflightSpansMu.UnlockAll()
if len(incompleteSpans) > 0 {
didWarn = true
msg := startMsg("WARNING: spans not ended:\n")
for _, span := range incompleteSpans {
fmt.Fprintf(msg, "%s\n", span)
}
msg.WriteString("Note: set TrackAllSpans flag for more info\n")
endMsg(msg)
}
}
}
if t.debugFlags.Check(LogAllSpans) || (t.debugFlags.Check(LogAllSpansOnWarn) && didWarn) {
allSpans := []*spanInfo{}
t.allSpans.Range(func(_, value any) bool {
allSpans = append(allSpans, value.(*spanInfo))
return true
})
slices.SortFunc(allSpans, func(a, b *spanInfo) int {
return a.startTime.Compare(b.startTime)
})
msg := startMsg("All observed spans:\n")
longestName := 0
for _, span := range allSpans {
longestName = max(longestName, len(span.Name)+2)
}
for _, span := range allSpans {
var startedAt string
if span.caller != "" {
startedAt = " | started at: " + span.caller
}
fmt.Fprintf(msg, "%-*s (trace: %s | span: %s | parent: %s%s)\n", longestName, "'"+span.Name+"'",
span.SpanContext.TraceID(), span.SpanContext.SpanID(), span.Parent.SpanID(), startedAt)
}
endMsg(msg)
}
if t.debugFlags.Check(LogTraceIDs) || (didWarn && t.debugFlags.Check(LogTraceIDsOnWarn)) {
msg := startMsg("Known trace ids:\n")
traceIDs := map[oteltrace.TraceID]int{}
t.allSpans.Range(func(_, value any) bool {
v := value.(*spanInfo)
traceIDs[v.SpanContext.TraceID()]++
return true
})
for id, n := range traceIDs {
fmt.Fprintf(msg, "%s (%d spans)\n", id.String(), n)
}
endMsg(msg)
}
})
if didWarn {
return ErrIncompleteSpans
}
return nil
}
func newSpanObserver() *spanObserver {
return &spanObserver{
referencedIDs: map[oteltrace.SpanID]oteltrace.SpanID{},
cond: sync.NewCond(&sync.Mutex{}),
}
}
type spanObserver struct {
cond *sync.Cond
referencedIDs map[oteltrace.SpanID]oteltrace.SpanID
unobservedIDs int
}
func (obs *spanObserver) ObserveReference(id oteltrace.SpanID, via oteltrace.SpanID) {
obs.cond.L.Lock()
defer obs.cond.L.Unlock()
if _, referenced := obs.referencedIDs[id]; !referenced {
obs.referencedIDs[id] = via // referenced, but not observed
// It is possible for new unobserved references to come in while waiting,
// but incrementing the counter wouldn't satisfy the condition so we don't
// need to signal the waiters
obs.unobservedIDs++
}
}
func (obs *spanObserver) Observe(id oteltrace.SpanID) {
obs.cond.L.Lock()
defer obs.cond.L.Unlock()
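// The map value is the "via" span ID: a valid ID means the span has been
// referenced as a parent but not yet observed; the zero ID means it has already
// been observed. Only the referenced-but-unobserved case decrements the counter
// below, so observing the same ID twice cannot double-decrement.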
if observed, referenced := obs.referencedIDs[id]; !referenced || observed.IsValid() { // NB: subtle condition
obs.referencedIDs[id] = zeroSpanID
if referenced {
obs.unobservedIDs--
obs.cond.Broadcast()
}
}
}
func (obs *spanObserver) wait(warnAfter time.Duration) {
done := make(chan struct{})
defer close(done)
go func() {
select {
case <-done:
return
case <-time.After(warnAfter):
obs.debugWarnWaiting()
}
}()
obs.cond.L.Lock()
for obs.unobservedIDs > 0 {
obs.cond.Wait()
}
obs.cond.L.Unlock()
}
func (obs *spanObserver) debugWarnWaiting() {
obs.cond.L.Lock()
msg := startMsg(fmt.Sprintf("Waiting on %d unobserved spans:\n", obs.unobservedIDs))
for id, via := range obs.referencedIDs {
if via.IsValid() {
fmt.Fprintf(msg, "%s via %s\n", id, via)
}
}
endMsg(msg)
obs.cond.L.Unlock()
}


@ -0,0 +1,288 @@
package trace_test
import (
"bytes"
"context"
"fmt"
"runtime"
"sync/atomic"
"testing"
"time"
"github.com/pomerium/pomerium/internal/telemetry/trace"
. "github.com/pomerium/pomerium/internal/testutil/tracetest" //nolint:revive
"github.com/stretchr/testify/assert"
sdktrace "go.opentelemetry.io/otel/sdk/trace"
oteltrace "go.opentelemetry.io/otel/trace"
)
func TestSpanObserver(t *testing.T) {
t.Run("observe single reference", func(t *testing.T) {
obs := trace.NewSpanObserver()
assert.Equal(t, []oteltrace.SpanID{}, obs.XUnobservedIDs())
obs.ObserveReference(Span(1).ID(), Span(2).ID())
assert.Equal(t, []oteltrace.SpanID{Span(1).ID()}, obs.XUnobservedIDs())
obs.Observe(Span(1).ID())
assert.Equal(t, []oteltrace.SpanID{}, obs.XUnobservedIDs())
})
t.Run("observe multiple references", func(t *testing.T) {
obs := trace.NewSpanObserver()
obs.ObserveReference(Span(1).ID(), Span(2).ID())
obs.ObserveReference(Span(1).ID(), Span(3).ID())
obs.ObserveReference(Span(1).ID(), Span(4).ID())
assert.Equal(t, []oteltrace.SpanID{Span(1).ID()}, obs.XUnobservedIDs())
obs.Observe(Span(1).ID())
assert.Equal(t, []oteltrace.SpanID{}, obs.XUnobservedIDs())
})
t.Run("observe before reference", func(t *testing.T) {
obs := trace.NewSpanObserver()
obs.Observe(Span(1).ID())
assert.Equal(t, []oteltrace.SpanID{}, obs.XUnobservedIDs())
obs.ObserveReference(Span(1).ID(), Span(2).ID())
assert.Equal(t, []oteltrace.SpanID{}, obs.XUnobservedIDs())
})
t.Run("wait", func(t *testing.T) {
obs := trace.NewSpanObserver()
obs.ObserveReference(Span(1).ID(), Span(2).ID())
obs.Observe(Span(2).ID())
obs.ObserveReference(Span(3).ID(), Span(4).ID())
obs.Observe(Span(4).ID())
obs.ObserveReference(Span(5).ID(), Span(6).ID())
obs.Observe(Span(6).ID())
waitOkToExit := atomic.Bool{}
waitExited := atomic.Bool{}
go func() {
defer waitExited.Store(true)
obs.XWait()
assert.True(t, waitOkToExit.Load(), "wait exited early")
}()
time.Sleep(10 * time.Millisecond)
assert.False(t, waitExited.Load())
obs.Observe(Span(1).ID())
time.Sleep(10 * time.Millisecond)
assert.False(t, waitExited.Load())
obs.Observe(Span(3).ID())
time.Sleep(10 * time.Millisecond)
assert.False(t, waitExited.Load())
waitOkToExit.Store(true)
obs.Observe(Span(5).ID())
assert.Eventually(t, waitExited.Load, 10*time.Millisecond, 1*time.Millisecond)
})
t.Run("new references observed during wait", func(t *testing.T) {
obs := trace.NewSpanObserver()
obs.ObserveReference(Span(1).ID(), Span(2).ID())
obs.Observe(Span(2).ID())
obs.ObserveReference(Span(3).ID(), Span(4).ID())
obs.Observe(Span(4).ID())
obs.ObserveReference(Span(5).ID(), Span(6).ID())
obs.Observe(Span(6).ID())
waitOkToExit := atomic.Bool{}
waitExited := atomic.Bool{}
go func() {
defer waitExited.Store(true)
obs.XWait()
assert.True(t, waitOkToExit.Load(), "wait exited early")
}()
assert.Equal(t, []oteltrace.SpanID{Span(1).ID(), Span(3).ID(), Span(5).ID()}, obs.XUnobservedIDs())
time.Sleep(10 * time.Millisecond)
assert.False(t, waitExited.Load())
obs.Observe(Span(1).ID())
assert.Equal(t, []oteltrace.SpanID{Span(3).ID(), Span(5).ID()}, obs.XUnobservedIDs())
time.Sleep(10 * time.Millisecond)
assert.False(t, waitExited.Load())
obs.Observe(Span(3).ID())
assert.Equal(t, []oteltrace.SpanID{Span(5).ID()}, obs.XUnobservedIDs())
time.Sleep(10 * time.Millisecond)
assert.False(t, waitExited.Load())
// observe a new reference
obs.ObserveReference(Span(7).ID(), Span(8).ID())
obs.Observe(Span(8).ID())
assert.Equal(t, []oteltrace.SpanID{Span(5).ID(), Span(7).ID()}, obs.XUnobservedIDs())
time.Sleep(10 * time.Millisecond)
assert.False(t, waitExited.Load())
obs.Observe(Span(5).ID())
assert.Equal(t, []oteltrace.SpanID{Span(7).ID()}, obs.XUnobservedIDs())
time.Sleep(10 * time.Millisecond)
assert.False(t, waitExited.Load())
waitOkToExit.Store(true)
obs.Observe(Span(7).ID())
assert.Equal(t, []oteltrace.SpanID{}, obs.XUnobservedIDs())
assert.Eventually(t, waitExited.Load, 10*time.Millisecond, 1*time.Millisecond)
})
t.Run("multiple waiters", func(t *testing.T) {
t.Parallel()
obs := trace.NewSpanObserver()
obs.ObserveReference(Span(1).ID(), Span(2).ID())
obs.Observe(Span(2).ID())
waitersExited := atomic.Int32{}
for range 10 {
go func() {
defer waitersExited.Add(1)
obs.XWait()
}()
}
assert.Equal(t, []oteltrace.SpanID{Span(1).ID()}, obs.XUnobservedIDs())
time.Sleep(10 * time.Millisecond)
assert.Equal(t, int32(0), waitersExited.Load())
obs.Observe(Span(1).ID())
startTime := time.Now()
for waitersExited.Load() != 10 {
if time.Since(startTime) > 1*time.Millisecond {
t.Fatal("timed out")
}
runtime.Gosched()
}
})
}
func TestSpanTracker(t *testing.T) {
t.Run("no debug flags", func(t *testing.T) {
t.Parallel()
obs := trace.NewSpanObserver()
tracker := trace.NewSpanTracker(obs, 0)
tp := sdktrace.NewTracerProvider(sdktrace.WithSpanProcessor(tracker))
tracer := tp.Tracer("test")
assert.Equal(t, []oteltrace.SpanID{}, tracker.XInflightSpans())
_, span1 := tracer.Start(context.Background(), "span 1")
assert.Equal(t, []oteltrace.SpanID{span1.SpanContext().SpanID()}, tracker.XInflightSpans())
assert.Equal(t, []oteltrace.SpanID{}, obs.XObservedIDs())
span1.End()
assert.Equal(t, []oteltrace.SpanID{}, tracker.XInflightSpans())
assert.Equal(t, []oteltrace.SpanID{}, obs.XObservedIDs())
})
t.Run("with TrackSpanReferences debug flag", func(t *testing.T) {
t.Parallel()
obs := trace.NewSpanObserver()
tracker := trace.NewSpanTracker(obs, trace.TrackSpanReferences)
tp := sdktrace.NewTracerProvider(sdktrace.WithSpanProcessor(tracker))
tracer := tp.Tracer("test")
assert.Equal(t, []oteltrace.SpanID{}, tracker.XInflightSpans())
_, span1 := tracer.Start(context.Background(), "span 1")
assert.Equal(t, []oteltrace.SpanID{span1.SpanContext().SpanID()}, tracker.XInflightSpans())
assert.Equal(t, []oteltrace.SpanID{span1.SpanContext().SpanID()}, obs.XObservedIDs())
span1.End()
assert.Equal(t, []oteltrace.SpanID{}, tracker.XInflightSpans())
assert.Equal(t, []oteltrace.SpanID{span1.SpanContext().SpanID()}, obs.XObservedIDs())
})
}
func TestSpanTrackerWarnings(t *testing.T) {
t.Run("WarnOnIncompleteSpans", func(t *testing.T) {
var buf bytes.Buffer
trace.SetDebugMessageWriterForTest(t, &buf)
obs := trace.NewSpanObserver()
tracker := trace.NewSpanTracker(obs, trace.WarnOnIncompleteSpans)
tp := sdktrace.NewTracerProvider(sdktrace.WithSpanProcessor(tracker))
tracer := tp.Tracer("test")
_, span1 := tracer.Start(context.Background(), "span 1")
assert.ErrorIs(t, tp.Shutdown(context.Background()), trace.ErrIncompleteSpans)
assert.Equal(t, fmt.Sprintf(`
==================================================
WARNING: spans not ended:
%s
Note: set TrackAllSpans flag for more info
==================================================
`, span1.SpanContext().SpanID()), buf.String())
})
t.Run("WarnOnIncompleteSpans with TrackAllSpans", func(t *testing.T) {
var buf bytes.Buffer
trace.SetDebugMessageWriterForTest(t, &buf)
obs := trace.NewSpanObserver()
tracker := trace.NewSpanTracker(obs, trace.WarnOnIncompleteSpans|trace.TrackAllSpans)
tp := sdktrace.NewTracerProvider(sdktrace.WithSpanProcessor(tracker))
tracer := tp.Tracer("test")
_, span1 := tracer.Start(context.Background(), "span 1")
assert.ErrorIs(t, tp.Shutdown(context.Background()), trace.ErrIncompleteSpans)
assert.Equal(t, fmt.Sprintf(`
==================================================
WARNING: spans not ended:
'span 1' (trace: %s | span: %s | parent: 0000000000000000)
==================================================
`, span1.SpanContext().TraceID(), span1.SpanContext().SpanID()), buf.String())
})
t.Run("WarnOnIncompleteSpans with TrackAllSpans and stackTraceProcessor", func(t *testing.T) {
var buf bytes.Buffer
trace.SetDebugMessageWriterForTest(t, &buf)
obs := trace.NewSpanObserver()
tracker := trace.NewSpanTracker(obs, trace.WarnOnIncompleteSpans|trace.TrackAllSpans)
tp := sdktrace.NewTracerProvider(sdktrace.WithSpanProcessor(&trace.XStackTraceProcessor{}), sdktrace.WithSpanProcessor(tracker))
tracer := tp.Tracer("test")
_, span1 := tracer.Start(context.Background(), "span 1")
_, file, line, _ := runtime.Caller(0)
line--
assert.ErrorIs(t, tp.Shutdown(context.Background()), trace.ErrIncompleteSpans)
assert.Equal(t, fmt.Sprintf(`
==================================================
WARNING: spans not ended:
'span 1' (trace: %s | span: %s | parent: 0000000000000000 | started at: %s:%d)
==================================================
`, span1.SpanContext().TraceID(), span1.SpanContext().SpanID(), file, line), buf.String())
})
t.Run("LogAllSpansOnWarn", func(t *testing.T) {
var buf bytes.Buffer
trace.SetDebugMessageWriterForTest(t, &buf)
obs := trace.NewSpanObserver()
tracker := trace.NewSpanTracker(obs, trace.WarnOnIncompleteSpans|trace.TrackAllSpans|trace.LogAllSpansOnWarn)
tp := sdktrace.NewTracerProvider(sdktrace.WithSpanProcessor(&trace.XStackTraceProcessor{}), sdktrace.WithSpanProcessor(tracker))
tracer := tp.Tracer("test")
_, span1 := tracer.Start(context.Background(), "span 1")
time.Sleep(10 * time.Millisecond)
span1.End()
time.Sleep(10 * time.Millisecond)
_, span2 := tracer.Start(context.Background(), "span 2")
_, file, line, _ := runtime.Caller(0)
line--
tp.Shutdown(context.Background())
assert.Equal(t,
fmt.Sprintf(`
==================================================
WARNING: spans not ended:
'span 2' (trace: %[1]s | span: %[2]s | parent: 0000000000000000 | started at: %[3]s:%[4]d)
==================================================
==================================================
All observed spans:
'span 1' (trace: %[5]s | span: %[6]s | parent: 0000000000000000 | started at: %[3]s:%[7]d)
'span 2' (trace: %[1]s | span: %[2]s | parent: 0000000000000000 | started at: %[3]s:%[4]d)
==================================================
`,
span2.SpanContext().TraceID(), span2.SpanContext().SpanID(), file, line,
span1.SpanContext().TraceID(), span1.SpanContext().SpanID(), line-4,
), buf.String())
})
}

View file

@ -8,15 +8,11 @@ import (
"github.com/gorilla/mux" "github.com/gorilla/mux"
"go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp" "go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp"
"go.opentelemetry.io/otel"
"go.opentelemetry.io/otel/propagation"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/stats" "google.golang.org/grpc/stats"
) )
func NewHTTPMiddleware(opts ...otelhttp.Option) func(http.Handler) http.Handler { func NewHTTPMiddleware(opts ...otelhttp.Option) mux.MiddlewareFunc {
return func(next http.Handler) http.Handler { return otelhttp.NewMiddleware("Server: %s %s", append(opts, otelhttp.WithSpanNameFormatter(func(operation string, r *http.Request) string {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
routeStr := "" routeStr := ""
route := mux.CurrentRoute(r) route := mux.CurrentRoute(r)
if route != nil { if route != nil {
@ -29,70 +25,8 @@ func NewHTTPMiddleware(opts ...otelhttp.Option) func(http.Handler) http.Handler
} }
} }
} }
traceparent := r.Header.Get("Traceparent") return fmt.Sprintf(operation, r.Method, routeStr)
if traceparent != "" { }))...)
xPomeriumTraceparent := r.Header.Get("X-Pomerium-Traceparent")
if xPomeriumTraceparent != "" {
sc, err := ParseTraceparent(xPomeriumTraceparent)
if err == nil {
r.Header.Set("Traceparent", WithTraceFromSpanContext(traceparent, sc))
ctx := otel.GetTextMapPropagator().Extract(r.Context(), propagation.HeaderCarrier(r.Header))
r = r.WithContext(ctx)
}
}
}
otelhttp.NewHandler(next, fmt.Sprintf("Server: %s %s", r.Method, routeStr), opts...).ServeHTTP(w, r)
})
}
}
func NewServerStatsHandler(base stats.Handler) stats.Handler {
return &serverStatsHandlerWrapper{
base: base,
}
}
type serverStatsHandlerWrapper struct {
base stats.Handler
}
func (w *serverStatsHandlerWrapper) wrapContext(ctx context.Context) context.Context {
md, ok := metadata.FromIncomingContext(ctx)
if !ok {
return ctx
}
traceparent := md.Get("traceparent")
xPomeriumTraceparent := md.Get("x-pomerium-traceparent")
if len(traceparent) > 0 && traceparent[0] != "" && len(xPomeriumTraceparent) > 0 && xPomeriumTraceparent[0] != "" {
newTracectx, err := ParseTraceparent(xPomeriumTraceparent[0])
if err != nil {
return ctx
}
md.Set("traceparent", WithTraceFromSpanContext(traceparent[0], newTracectx))
return metadata.NewIncomingContext(ctx, md)
}
return ctx
}
// HandleConn implements stats.Handler.
func (w *serverStatsHandlerWrapper) HandleConn(ctx context.Context, stats stats.ConnStats) {
w.base.HandleConn(w.wrapContext(ctx), stats)
}
// HandleRPC implements stats.Handler.
func (w *serverStatsHandlerWrapper) HandleRPC(ctx context.Context, stats stats.RPCStats) {
w.base.HandleRPC(w.wrapContext(ctx), stats)
}
// TagConn implements stats.Handler.
func (w *serverStatsHandlerWrapper) TagConn(ctx context.Context, info *stats.ConnTagInfo) context.Context {
return w.base.TagConn(w.wrapContext(ctx), info)
}
// TagRPC implements stats.Handler.
func (w *serverStatsHandlerWrapper) TagRPC(ctx context.Context, info *stats.RPCTagInfo) context.Context {
return w.base.TagRPC(w.wrapContext(ctx), info)
} }
type clientStatsHandlerWrapper struct { type clientStatsHandlerWrapper struct {

View file

@ -3,161 +3,35 @@ package trace_test
import ( import (
"context" "context"
"errors" "errors"
"net"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"testing" "testing"
"time" "time"
"github.com/gorilla/mux"
"github.com/pomerium/pomerium/internal/telemetry/trace" "github.com/pomerium/pomerium/internal/telemetry/trace"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc"
"go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp" "go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp"
sdktrace "go.opentelemetry.io/otel/sdk/trace"
oteltrace "go.opentelemetry.io/otel/trace" oteltrace "go.opentelemetry.io/otel/trace"
"go.opentelemetry.io/otel/trace/noop"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials/insecure"
"google.golang.org/grpc/interop/grpc_testing"
"google.golang.org/grpc/metadata" "google.golang.org/grpc/metadata"
"google.golang.org/grpc/stats" "google.golang.org/grpc/stats"
"google.golang.org/grpc/test/bufconn"
) )
var cases = []struct {
name string
setTraceparent string
setPomeriumTraceparent string
check func(t testing.TB, ctx context.Context)
}{
{
name: "x-pomerium-traceparent not present",
setTraceparent: Traceparent(Trace(1), Span(1), true),
check: func(t testing.TB, ctx context.Context) {
span := oteltrace.SpanFromContext(ctx)
assert.Equal(t, Trace(1).ID().Value(), span.SpanContext().TraceID())
assert.Equal(t, Span(1).ID(), span.SpanContext().SpanID())
assert.True(t, span.SpanContext().IsSampled())
},
},
{
name: "x-pomerium-traceparent present",
setTraceparent: Traceparent(Trace(2), Span(2), true),
setPomeriumTraceparent: Traceparent(Trace(1), Span(1), true),
check: func(t testing.TB, ctx context.Context) {
span := oteltrace.SpanFromContext(ctx)
assert.Equal(t, Trace(1).ID().Value(), span.SpanContext().TraceID())
assert.Equal(t, Span(2).ID(), span.SpanContext().SpanID())
assert.True(t, span.SpanContext().IsSampled())
},
},
{
name: "x-pomerium-traceparent present, force sampling off",
setTraceparent: Traceparent(Trace(2), Span(2), true),
setPomeriumTraceparent: Traceparent(Trace(1), Span(1), false),
check: func(t testing.TB, ctx context.Context) {
span := oteltrace.SpanFromContext(ctx)
assert.Equal(t, Trace(1).ID().Value(), span.SpanContext().TraceID())
assert.Equal(t, Span(2).ID(), span.SpanContext().SpanID())
assert.Equal(t, false, span.SpanContext().IsSampled())
},
},
{
name: "x-pomerium-traceparent present, force sampling on",
setTraceparent: Traceparent(Trace(2), Span(2), false),
setPomeriumTraceparent: Traceparent(Trace(1), Span(1), true),
check: func(t testing.TB, ctx context.Context) {
span := oteltrace.SpanFromContext(ctx)
assert.Equal(t, Trace(1).ID().Value(), span.SpanContext().TraceID())
assert.Equal(t, Span(2).ID(), span.SpanContext().SpanID())
assert.Equal(t, true, span.SpanContext().IsSampled())
},
},
{
name: "malformed x-pomerium-traceparent",
setTraceparent: Traceparent(Trace(2), Span(2), false),
setPomeriumTraceparent: "00-xxxxxx-yyyyyy-03",
check: func(t testing.TB, ctx context.Context) {
span := oteltrace.SpanFromContext(ctx)
assert.Equal(t, Trace(2).ID().Value(), span.SpanContext().TraceID())
assert.Equal(t, Span(2).ID(), span.SpanContext().SpanID())
assert.Equal(t, false, span.SpanContext().IsSampled())
},
},
}
func TestHTTPMiddleware(t *testing.T) { func TestHTTPMiddleware(t *testing.T) {
for _, tc := range cases { router := mux.NewRouter()
t.Run(tc.name, func(t *testing.T) { tp := sdktrace.NewTracerProvider()
r := httptest.NewRequest(http.MethodGet, "/foo", nil) router.Use(trace.NewHTTPMiddleware(
if tc.setTraceparent != "" { otelhttp.WithTracerProvider(tp),
r.Header.Add("Traceparent", tc.setTraceparent) ))
} router.Path("/foo").HandlerFunc(func(_ http.ResponseWriter, r *http.Request) {
if tc.setPomeriumTraceparent != "" { span := oteltrace.SpanFromContext(r.Context())
r.Header.Add("X-Pomerium-Traceparent", tc.setPomeriumTraceparent) assert.Equal(t, "Server: GET /foo", span.(interface{ Name() string }).Name())
} }).Methods(http.MethodGet)
w := httptest.NewRecorder() w := httptest.NewRecorder()
trace.NewHTTPMiddleware( ctx, span := tp.Tracer("test").Start(context.Background(), "test")
otelhttp.WithTracerProvider(noop.NewTracerProvider()), router.ServeHTTP(w, httptest.NewRequestWithContext(ctx, http.MethodGet, "/foo", nil))
)(http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) { span.End()
tc.check(t, r.Context())
})).ServeHTTP(w, r)
})
}
}
func TestGRPCMiddleware(t *testing.T) {
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
srv := grpc.NewServer(
grpc.StatsHandler(trace.NewServerStatsHandler(otelgrpc.NewServerHandler(
otelgrpc.WithTracerProvider(noop.NewTracerProvider())))),
grpc.Creds(insecure.NewCredentials()),
)
lis := bufconn.Listen(4096)
grpc_testing.RegisterTestServiceServer(srv, &testServer{
fn: func(ctx context.Context) {
tc.check(t, ctx)
},
})
go srv.Serve(lis)
t.Cleanup(srv.Stop)
client, err := grpc.NewClient("passthrough://ignore",
grpc.WithTransportCredentials(insecure.NewCredentials()),
grpc.WithStatsHandler(otelgrpc.NewClientHandler(
otelgrpc.WithTracerProvider(noop.NewTracerProvider()))),
grpc.WithContextDialer(func(ctx context.Context, _ string) (net.Conn, error) {
return lis.DialContext(ctx)
}),
)
require.NoError(t, err)
ctx := context.Background()
if tc.setTraceparent != "" {
ctx = metadata.AppendToOutgoingContext(ctx,
"traceparent", tc.setTraceparent,
)
}
if tc.setPomeriumTraceparent != "" {
ctx = metadata.AppendToOutgoingContext(ctx,
"x-pomerium-traceparent", tc.setPomeriumTraceparent,
)
}
_, err = grpc_testing.NewTestServiceClient(client).EmptyCall(ctx, &grpc_testing.Empty{})
assert.NoError(t, err)
})
}
}
type testServer struct {
grpc_testing.UnimplementedTestServiceServer
fn func(ctx context.Context)
}
func (ts *testServer) EmptyCall(ctx context.Context, _ *grpc_testing.Empty) (*grpc_testing.Empty, error) {
ts.fn(ctx)
return &grpc_testing.Empty{}, nil
} }
type mockHandler struct { type mockHandler struct {

View file

@ -1,900 +0,0 @@
package trace
import (
"context"
"encoding/binary"
"encoding/json"
"errors"
"fmt"
"os"
"slices"
"strconv"
"sync"
"sync/atomic"
"time"
"unique"
lru "github.com/hashicorp/golang-lru/v2"
"github.com/pomerium/pomerium/internal/log"
"github.com/rs/zerolog"
"go.opentelemetry.io/otel/exporters/otlp/otlptrace"
sdktrace "go.opentelemetry.io/otel/sdk/trace"
oteltrace "go.opentelemetry.io/otel/trace"
coltracepb "go.opentelemetry.io/proto/otlp/collector/trace/v1"
tracev1 "go.opentelemetry.io/proto/otlp/trace/v1"
"google.golang.org/protobuf/encoding/protojson"
"google.golang.org/protobuf/proto"
)
var (
maxPendingTraces atomic.Int32
maxCachedTraceIDs atomic.Int32
)
func init() {
envOrDefault := func(envName string, def int32) int32 {
if val, ok := os.LookupEnv(envName); ok {
if num, err := strconv.ParseInt(val, 10, 32); err == nil {
return int32(num)
}
}
return def
}
maxPendingTraces.Store(envOrDefault("POMERIUM_OTEL_MAX_PENDING_TRACES", 8192))
maxCachedTraceIDs.Store(envOrDefault("POMERIUM_OTEL_MAX_CACHED_TRACE_IDS", 16384))
}
func SetMaxPendingTraces(num int32) {
maxPendingTraces.Store(max(num, 0))
}
func SetMaxCachedTraceIDs(num int32) {
maxCachedTraceIDs.Store(max(num, 0))
}
type eviction struct {
traceID unique.Handle[oteltrace.TraceID]
buf *Buffer
}
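// SpanExportQueue buffers exported spans by trace ID until each trace is resolved:
// once a root span for the trace is seen, its buffered spans are uploaded (rewriting
// trace IDs taken from a pomerium.traceparent attribute when present), and traces
// resolved as unsampled are dropped.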
type SpanExportQueue struct {
closing chan struct{}
uploadC chan []*tracev1.ResourceSpans
requestC chan *coltracepb.ExportTraceServiceRequest
evictC chan eviction
client otlptrace.Client
pendingResourcesByTraceID *lru.Cache[unique.Handle[oteltrace.TraceID], *Buffer]
knownTraceIDMappings *lru.Cache[unique.Handle[oteltrace.TraceID], unique.Handle[oteltrace.TraceID]]
tracker *spanTracker
observer *spanObserver
debugEvents []DebugEvent
logger *zerolog.Logger
debugFlags DebugFlags
debugAllEnqueuedSpans map[oteltrace.SpanID]*tracev1.Span
wg sync.WaitGroup
}
func NewSpanExportQueue(ctx context.Context, client otlptrace.Client) *SpanExportQueue {
debug := DebugFlagsFromContext(ctx)
var observer *spanObserver
if debug.Check(TrackSpanReferences) {
observer = newSpanObserver()
}
q := &SpanExportQueue{
logger: log.Ctx(ctx),
client: client,
closing: make(chan struct{}),
uploadC: make(chan []*tracev1.ResourceSpans, 64),
requestC: make(chan *coltracepb.ExportTraceServiceRequest, 256),
evictC: make(chan eviction, 64),
debugFlags: debug,
debugAllEnqueuedSpans: make(map[oteltrace.SpanID]*tracev1.Span),
tracker: newSpanTracker(observer, debug),
observer: observer,
}
var err error
q.pendingResourcesByTraceID, err = lru.NewWithEvict(int(maxPendingTraces.Load()), q.onEvict)
if err != nil {
panic(err)
}
q.knownTraceIDMappings, err = lru.New[unique.Handle[oteltrace.TraceID], unique.Handle[oteltrace.TraceID]](int(maxCachedTraceIDs.Load()))
if err != nil {
panic(err)
}
q.wg.Add(2)
go q.runUploader()
go q.runProcessor()
return q
}
func (q *SpanExportQueue) runUploader() {
defer q.wg.Done()
for resourceSpans := range q.uploadC {
ctx, ca := context.WithTimeout(context.Background(), 10*time.Second)
if err := q.client.UploadTraces(ctx, resourceSpans); err != nil {
q.logger.Err(err).Msg("error uploading traces")
}
ca()
}
}
func (q *SpanExportQueue) runProcessor() {
defer q.wg.Done()
for {
select {
case req := <-q.requestC:
q.processRequestLocked(req)
case ev := <-q.evictC:
q.processEvictionLocked(ev)
case <-q.closing:
for {
select {
case req := <-q.requestC:
q.processRequestLocked(req)
case ev := <-q.evictC:
q.processEvictionLocked(ev)
default: // all channels empty
close(q.uploadC)
return
}
}
}
}
}
func (q *SpanExportQueue) onEvict(traceID unique.Handle[oteltrace.TraceID], buf *Buffer) {
q.evictC <- eviction{
traceID: traceID,
buf: buf,
}
}
func (q *SpanExportQueue) insertPendingSpanLocked(
resource *ResourceInfo,
scope *ScopeInfo,
traceID unique.Handle[oteltrace.TraceID],
span *tracev1.Span,
) {
var pendingTraceResources *Buffer
if ptr, ok := q.pendingResourcesByTraceID.Get(traceID); ok {
pendingTraceResources = ptr
} else {
pendingTraceResources = NewBuffer()
q.pendingResourcesByTraceID.Add(traceID, pendingTraceResources)
}
pendingTraceResources.Insert(resource, scope, span)
}
func (q *SpanExportQueue) resolveTraceIDMappingLocked(out *Buffer, original, target unique.Handle[oteltrace.TraceID]) {
q.knownTraceIDMappings.Add(original, target)
if target == zeroTraceID && original != zeroTraceID {
// mapping a trace id to zero indicates we should drop the trace
q.pendingResourcesByTraceID.Remove(original)
return
}
if originalPending, ok := q.pendingResourcesByTraceID.Peek(original); ok {
if original == target {
out.Merge(originalPending)
} else {
// check if the target id is also pending
if targetPending, ok := q.pendingResourcesByTraceID.Peek(target); ok {
targetPending.MergeAs(originalPending, target)
} else {
out.MergeAs(originalPending, target)
}
}
q.pendingResourcesByTraceID.Remove(original)
}
}
func (q *SpanExportQueue) getTraceIDMappingLocked(id unique.Handle[oteltrace.TraceID]) (unique.Handle[oteltrace.TraceID], bool) {
v, ok := q.knownTraceIDMappings.Get(id)
return v, ok
}
func (q *SpanExportQueue) isKnownTracePendingLocked(id unique.Handle[oteltrace.TraceID]) bool {
_, ok := q.pendingResourcesByTraceID.Get(id) // will update the key's recent-ness in the lru
return ok
}
var ErrShuttingDown = errors.New("exporter is shutting down")
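// Enqueue hands an export request to the background processor. It returns
// ErrShuttingDown if the queue has already begun closing.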
func (q *SpanExportQueue) Enqueue(_ context.Context, req *coltracepb.ExportTraceServiceRequest) error {
select {
case <-q.closing:
return ErrShuttingDown
default:
q.requestC <- req
return nil
}
}
func (q *SpanExportQueue) processRequestLocked(req *coltracepb.ExportTraceServiceRequest) {
if q.debugFlags.Check(LogAllEvents) {
q.debugEvents = append(q.debugEvents, DebugEvent{
Timestamp: time.Now(),
Request: proto.Clone(req).(*coltracepb.ExportTraceServiceRequest),
})
}
// Spans are processed in two passes:
// 1. Look through each span to check if we have not yet seen its trace ID.
// If we haven't, and the span is a root span (no parent, or marked as such
// by us), mark the trace as observed, and (if indicated) keep track of the
// trace ID we need to rewrite it as, so that other spans we see later in
// this trace can also be rewritten the same way.
// If we find a new trace ID for which there are pending non-root spans,
// collect them and rewrite their trace IDs (if necessary), and prepare
// them to be uploaded.
//
// At this point, all trace IDs for the spans in the request are known.
//
// 2. Look through each span again, this time to filter out any spans in
// the request which belong to "pending" traces (known trace IDs for which
// we have not yet seen a root span), adding them to the list of pending
// spans for their corresponding trace IDs. They will be uploaded in the
// future once we have observed a root span for those traces, or if they
// are evicted by the queue.
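// Illustrative example: suppose span B (a child in trace T) arrives before its root
// span A. Pass 2 parks B in the pending buffer for T. When A later arrives, pass 1
// resolves T (rewriting it if A carries a pomerium.traceparent attribute) and the
// parked spans are flushed in the same upload as A; if A turns out to be unsampled,
// T maps to the zero trace ID and the parked spans are dropped instead.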
// Pass 1
toUpload := NewBuffer()
for _, resource := range req.ResourceSpans {
for _, scope := range resource.ScopeSpans {
SPANS:
for _, span := range scope.Spans {
FormatSpanName(span)
spanID, ok := ToSpanID(span.SpanId)
if !ok {
continue
}
if q.debugFlags.Check(TrackAllSpans) {
q.debugAllEnqueuedSpans[spanID] = span
}
trackSpanReferences := q.debugFlags.Check(TrackSpanReferences)
parentSpanID, ok := ToSpanID(span.ParentSpanId)
if !ok {
continue
}
traceID, ok := ToTraceID(span.TraceId)
if !ok {
continue
}
if trackSpanReferences {
q.observer.Observe(spanID)
}
if mapping, ok := q.getTraceIDMappingLocked(traceID); ok {
if trackSpanReferences && mapping != zeroTraceID && parentSpanID.IsValid() {
q.observer.ObserveReference(parentSpanID, spanID)
}
} else {
// Observed a new trace ID. Check if the span is a root span
isRootSpan := !parentSpanID.IsValid() // no parent == root span
// Assume the trace is sampled, because it was exported. span.Flags
// is an unreliable way to detect whether the span was sampled,
// because neither envoy nor opentelemetry-go encode the sampling
// decision there, assuming unsampled spans would not be exported
// (an assumption that does not hold for tail-based sampling strategies)
// https://github.com/open-telemetry/opentelemetry-proto/issues/166
isSampled := true
mappedTraceID := traceID
for _, attr := range span.Attributes {
switch attr.Key {
case "pomerium.traceparent":
tp, err := ParseTraceparent(attr.GetValue().GetStringValue())
if err != nil {
data, _ := protojson.Marshal(span)
q.logger.
Err(err).
Str("span", string(data)).
Msg("error processing span")
continue SPANS
}
mappedTraceID = unique.Make(tp.TraceID())
// use the sampling decision from pomerium.traceparent instead
isSampled = tp.IsSampled()
case "pomerium.external-parent-span":
// This is a non-root span whose parent we do not expect to see
// here. For example, if a request originated externally from a
// system that is uploading its own spans out-of-band from us,
// we will never observe a root span for this trace and it would
// otherwise get stuck in the queue.
if !isRootSpan && q.debugFlags.Check(TrackSpanReferences) {
value, err := oteltrace.SpanIDFromHex(attr.GetValue().GetStringValue())
if err != nil {
data, _ := protojson.Marshal(span)
q.logger.
Err(err).
Str("span", string(data)).
Msg("error processing span: invalid value for pomerium.external-parent-span")
} else {
q.observer.Observe(value) // mark this id as observed
}
}
isRootSpan = true
}
}
if q.debugFlags.Check(TrackSpanReferences) {
if isSampled && parentSpanID.IsValid() {
q.observer.ObserveReference(parentSpanID, spanID)
}
}
if !isSampled {
// We have observed a new trace that is not sampled (regardless of
// whether or not it is a root span). Resolve it using the zero
// trace ID to indicate that all spans for this trace should be
// dropped.
q.resolveTraceIDMappingLocked(toUpload, traceID, zeroTraceID)
} else if isRootSpan {
// We have observed a new trace that is sampled and is a root span.
// Resolve it using the mapped trace ID (if present), or its own
// trace ID (indicating it does not need to be rewritten).
// If the mapped trace is pending, this does not flush pending
// spans to the output buffer (toUpload), but instead merges them
// into the mapped trace's pending buffer.
q.resolveTraceIDMappingLocked(toUpload, traceID, mappedTraceID)
}
}
}
}
}
// Pass 2
for _, resource := range req.ResourceSpans {
resourceInfo := NewResourceInfo(resource.Resource, resource.SchemaUrl)
for _, scope := range resource.ScopeSpans {
scopeInfo := NewScopeInfo(scope.Scope, scope.SchemaUrl)
for _, span := range scope.Spans {
traceID, ok := ToTraceID(span.TraceId)
if !ok {
continue
}
if mapping, hasMapping := q.getTraceIDMappingLocked(traceID); hasMapping {
if mapping == zeroTraceID {
continue // the trace has been dropped
}
id := mapping.Value()
copy(span.TraceId, id[:])
// traceID = mapping
if q.isKnownTracePendingLocked(mapping) {
q.insertPendingSpanLocked(resourceInfo, scopeInfo, mapping, span)
} else {
toUpload.Insert(resourceInfo, scopeInfo, span)
}
} else {
q.insertPendingSpanLocked(resourceInfo, scopeInfo, traceID, span)
}
}
}
}
if resourceSpans := toUpload.Flush(); len(resourceSpans) > 0 {
q.uploadC <- resourceSpans
}
}
func (q *SpanExportQueue) processEvictionLocked(ev eviction) {
if ev.buf.IsEmpty() {
// an empty buffer means the trace already resolved and its spans were flushed
// before the entry was removed; only non-empty buffers were evicted with spans
// still pending
return
} else if mapping, ok := q.knownTraceIDMappings.Get(ev.traceID); ok && mapping == zeroTraceID {
q.logger.Debug().
Str("traceID", ev.traceID.Value().String()).
Msg("dropping unsampled trace")
return
}
select {
case q.uploadC <- ev.buf.Flush():
q.logger.Warn().
Str("traceID", ev.traceID.Value().String()).
Msg("trace export buffer is full, uploading oldest incomplete trace")
default:
q.logger.Warn().
Str("traceID", ev.traceID.Value().String()).
Msg("trace export buffer and upload queues are full, dropping trace")
}
}
var (
ErrIncompleteTraces = errors.New("exporter shut down with incomplete traces")
ErrIncompleteSpans = errors.New("exporter shut down with incomplete spans")
ErrIncompleteUploads = errors.New("exporter shut down with pending trace uploads")
ErrMissingParentSpans = errors.New("exporter shut down with missing parent spans")
)
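// WaitForSpans blocks until every span ID referenced as a parent has also been
// observed, or until maxDuration elapses (returning ErrMissingParentSpans). It is
// a no-op unless the TrackSpanReferences debug flag is set.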
func (q *SpanExportQueue) WaitForSpans(maxDuration time.Duration) error {
if !q.debugFlags.Check(TrackSpanReferences) {
return nil
}
done := make(chan struct{})
go func() {
defer close(done)
q.observer.wait(q.debugAllEnqueuedSpans, 10*time.Second)
}()
select {
case <-done:
return nil
case <-time.After(maxDuration):
return ErrMissingParentSpans
}
}
func (q *SpanExportQueue) Close(ctx context.Context) error {
closed := make(chan struct{})
go func() {
q.wg.Wait()
close(closed)
}()
close(q.closing)
select {
case <-ctx.Done():
log.Ctx(ctx).Error().Msg("exporter stopped before all traces could be exported")
return context.Cause(ctx)
case <-closed:
err := q.runOnCloseChecksLocked()
log.Ctx(ctx).Debug().Err(err).Msg("exporter stopped")
return err
}
}
func (q *SpanExportQueue) runOnCloseChecksLocked() error {
didWarn := false
if q.debugFlags.Check(TrackSpanReferences) {
var unknownParentIDs []string
for id, via := range q.observer.referencedIDs {
if via.IsValid() {
if q.debugFlags.Check(TrackAllSpans) {
if viaSpan, ok := q.debugAllEnqueuedSpans[via]; ok {
unknownParentIDs = append(unknownParentIDs, fmt.Sprintf("%s via %s (%s)", id, via, viaSpan.Name))
} else {
unknownParentIDs = append(unknownParentIDs, fmt.Sprintf("%s via %s", id, via))
}
}
}
}
if len(unknownParentIDs) > 0 {
didWarn = true
msg := startMsg("WARNING: parent spans referenced but never seen:\n")
for _, str := range unknownParentIDs {
msg.WriteString(str)
msg.WriteString("\n")
}
endMsg(msg)
}
}
incomplete := q.pendingResourcesByTraceID.Len() > 0
if incomplete && q.debugFlags.Check(WarnOnIncompleteTraces) {
didWarn = true
msg := startMsg("WARNING: exporter shut down with incomplete traces\n")
keys := q.pendingResourcesByTraceID.Keys()
values := q.pendingResourcesByTraceID.Values()
for i, k := range keys {
v := values[i]
fmt.Fprintf(msg, "- Trace: %s\n", k.Value())
for _, pendingScope := range v.scopesByResourceID {
msg.WriteString(" - Resource:\n")
for _, v := range pendingScope.resource.Resource.Attributes {
fmt.Fprintf(msg, " %s=%s\n", v.Key, v.Value.String())
}
for _, spanBuffer := range pendingScope.spansByScope {
if spanBuffer.scope != nil {
fmt.Fprintf(msg, " Scope: %s\n", spanBuffer.scope.ID())
} else {
msg.WriteString(" Scope: (unknown)\n")
}
msg.WriteString(" Spans:\n")
longestName := 0
for _, span := range spanBuffer.spans {
longestName = max(longestName, len(span.Name)+2)
}
for _, span := range spanBuffer.spans {
spanID, ok := ToSpanID(span.SpanId)
if !ok {
continue
}
traceID, ok := ToTraceID(span.TraceId)
if !ok {
continue
}
parentSpanID, ok := ToSpanID(span.ParentSpanId)
if !ok {
continue
}
_, seenParent := q.debugAllEnqueuedSpans[parentSpanID]
var missing string
if !seenParent {
missing = " [missing]"
}
fmt.Fprintf(msg, " - %-*s (trace: %s | span: %s | parent:%s %s)\n", longestName,
"'"+span.Name+"'", traceID.Value(), spanID, missing, parentSpanID)
for _, attr := range span.Attributes {
if attr.Key == "caller" {
fmt.Fprintf(msg, " => caller: '%s'\n", attr.Value.GetStringValue())
break
}
}
}
}
}
}
endMsg(msg)
}
if q.debugFlags.Check(LogTraceIDMappings) || (didWarn && q.debugFlags.Check(LogTraceIDMappingsOnWarn)) {
msg := startMsg("Known trace ids:\n")
keys := q.knownTraceIDMappings.Keys()
values := q.knownTraceIDMappings.Values()
for i, k := range keys {
v := values[i]
if k != v {
if v == zeroTraceID {
fmt.Fprintf(msg, "%s (dropped)\n", k.Value())
} else {
fmt.Fprintf(msg, "%s => %s\n", k.Value(), v.Value())
}
} else {
fmt.Fprintf(msg, "%s (no change)\n", k.Value())
}
}
endMsg(msg)
}
if q.debugFlags.Check(LogAllSpans) || (didWarn && q.debugFlags.Check(LogAllSpansOnWarn)) {
msg := startMsg("All exported spans:\n")
longestName := 0
for _, span := range q.debugAllEnqueuedSpans {
longestName = max(longestName, len(span.Name)+2)
}
for _, span := range q.debugAllEnqueuedSpans {
spanID, ok := ToSpanID(span.SpanId)
if !ok {
continue
}
traceID, ok := ToTraceID(span.TraceId)
if !ok {
continue
}
parentSpanID, ok := ToSpanID(span.ParentSpanId)
if !ok {
continue
}
fmt.Fprintf(msg, "%-*s (trace: %s | span: %s | parent: %s)", longestName,
"'"+span.Name+"'", traceID.Value(), spanID, parentSpanID)
var foundCaller bool
for _, attr := range span.Attributes {
if attr.Key == "caller" {
fmt.Fprintf(msg, " => %s\n", attr.Value.GetStringValue())
foundCaller = true
break
}
}
if !foundCaller {
msg.WriteString("\n")
}
}
endMsg(msg)
}
if q.debugFlags.Check(LogAllEvents) {
msg := startMsg("All Events:\n")
msg.WriteByte('[')
for i, event := range q.debugEvents {
msg.WriteString("\n ")
eventData, _ := json.Marshal(event)
msg.Write(eventData)
if i < len(q.debugEvents)-1 {
msg.WriteByte(',')
} else {
msg.WriteString("\n]")
}
}
msg.WriteByte('\n')
endMsg(msg)
}
if incomplete {
return ErrIncompleteTraces
}
return nil
}
type DebugEvent struct {
Timestamp time.Time `json:"timestamp"`
Request *coltracepb.ExportTraceServiceRequest `json:"request"`
}
func (e DebugEvent) MarshalJSON() ([]byte, error) {
type debugEvent struct {
Timestamp time.Time `json:"timestamp"`
Request json.RawMessage `json:"request"`
}
reqData, _ := protojson.Marshal(e.Request)
return json.Marshal(debugEvent{
Timestamp: e.Timestamp,
Request: reqData,
})
}
func (e *DebugEvent) UnmarshalJSON(b []byte) error {
type debugEvent struct {
Timestamp time.Time `json:"timestamp"`
Request json.RawMessage `json:"request"`
}
var ev debugEvent
if err := json.Unmarshal(b, &ev); err != nil {
return err
}
e.Timestamp = ev.Timestamp
var msg coltracepb.ExportTraceServiceRequest
if err := protojson.Unmarshal(ev.Request, &msg); err != nil {
return err
}
e.Request = &msg
return nil
}
const shardCount = 64
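// Span IDs are sharded across 64 buckets (keyed by the big-endian value of the ID,
// see OnStart/OnEnd below) so that concurrent span starts and ends contend on
// per-bucket locks instead of a single global mutex.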
type (
shardedSet [shardCount]map[oteltrace.SpanID]struct{}
shardedLocks [shardCount]sync.Mutex
)
func (s *shardedSet) Range(f func(key oteltrace.SpanID)) {
for i := range shardCount {
for k := range s[i] {
f(k)
}
}
}
func (s *shardedLocks) LockAll() {
for i := range shardCount {
s[i].Lock()
}
}
func (s *shardedLocks) UnlockAll() {
for i := range shardCount {
s[i].Unlock()
}
}
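// spanTracker is a SpanProcessor that records which spans are currently in flight
// and, depending on the debug flags, reports spans that were started but never
// ended when the tracer provider shuts down.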
type spanTracker struct {
inflightSpansMu shardedLocks
inflightSpans shardedSet
allSpans sync.Map
debugFlags DebugFlags
observer *spanObserver
shutdownOnce sync.Once
}
func newSpanTracker(observer *spanObserver, debugFlags DebugFlags) *spanTracker {
st := &spanTracker{
observer: observer,
debugFlags: debugFlags,
}
for i := range len(st.inflightSpans) {
st.inflightSpans[i] = make(map[oteltrace.SpanID]struct{})
}
return st
}
type spanInfo struct {
Name string
SpanContext oteltrace.SpanContext
Parent oteltrace.SpanContext
caller string
startTime time.Time
}
// ForceFlush implements trace.SpanProcessor.
func (t *spanTracker) ForceFlush(context.Context) error {
return nil
}
// OnEnd implements trace.SpanProcessor.
func (t *spanTracker) OnEnd(s sdktrace.ReadOnlySpan) {
id := s.SpanContext().SpanID()
bucket := binary.BigEndian.Uint64(id[:]) % shardCount
t.inflightSpansMu[bucket].Lock()
defer t.inflightSpansMu[bucket].Unlock()
delete(t.inflightSpans[bucket], id)
}
// OnStart implements trace.SpanProcessor.
func (t *spanTracker) OnStart(_ context.Context, s sdktrace.ReadWriteSpan) {
id := s.SpanContext().SpanID()
bucket := binary.BigEndian.Uint64(id[:]) % shardCount
t.inflightSpansMu[bucket].Lock()
defer t.inflightSpansMu[bucket].Unlock()
t.inflightSpans[bucket][id] = struct{}{}
if t.debugFlags.Check(TrackSpanReferences) {
t.observer.Observe(id)
}
if t.debugFlags.Check(TrackAllSpans) {
var caller string
for _, attr := range s.Attributes() {
if attr.Key == "caller" {
caller = attr.Value.AsString()
break
}
}
t.allSpans.Store(id, &spanInfo{
Name: s.Name(),
SpanContext: s.SpanContext(),
Parent: s.Parent(),
caller: caller,
startTime: s.StartTime(),
})
}
}
// Shutdown implements trace.SpanProcessor.
func (t *spanTracker) Shutdown(_ context.Context) error {
if t.debugFlags == 0 {
return nil
}
didWarn := false
t.shutdownOnce.Do(func() {
if t.debugFlags.Check(WarnOnIncompleteSpans) {
if t.debugFlags.Check(TrackAllSpans) {
incompleteSpans := []*spanInfo{}
t.inflightSpansMu.LockAll()
t.inflightSpans.Range(func(key oteltrace.SpanID) {
if info, ok := t.allSpans.Load(key); ok {
incompleteSpans = append(incompleteSpans, info.(*spanInfo))
}
})
t.inflightSpansMu.UnlockAll()
if len(incompleteSpans) > 0 {
didWarn = true
msg := startMsg("WARNING: spans not ended:\n")
longestName := 0
for _, span := range incompleteSpans {
longestName = max(longestName, len(span.Name)+2)
}
for _, span := range incompleteSpans {
var startedAt string
if span.caller != "" {
startedAt = " | started at: " + span.caller
}
fmt.Fprintf(msg, "%-*s (trace: %s | span: %s | parent: %s%s)\n", longestName, "'"+span.Name+"'",
span.SpanContext.TraceID(), span.SpanContext.SpanID(), span.Parent.SpanID(), startedAt)
}
endMsg(msg)
}
} else {
incompleteSpans := []oteltrace.SpanID{}
t.inflightSpansMu.LockAll()
t.inflightSpans.Range(func(key oteltrace.SpanID) {
incompleteSpans = append(incompleteSpans, key)
})
t.inflightSpansMu.UnlockAll()
if len(incompleteSpans) > 0 {
didWarn = true
msg := startMsg("WARNING: spans not ended:\n")
for _, span := range incompleteSpans {
fmt.Fprintf(msg, "%s\n", span)
}
msg.WriteString("Note: set TrackAllSpans flag for more info\n")
endMsg(msg)
}
}
}
if t.debugFlags.Check(LogAllSpans) || (t.debugFlags.Check(LogAllSpansOnWarn) && didWarn) {
allSpans := []*spanInfo{}
t.allSpans.Range(func(_, value any) bool {
allSpans = append(allSpans, value.(*spanInfo))
return true
})
slices.SortFunc(allSpans, func(a, b *spanInfo) int {
return a.startTime.Compare(b.startTime)
})
msg := startMsg("All observed spans:\n")
longestName := 0
for _, span := range allSpans {
longestName = max(longestName, len(span.Name)+2)
}
for _, span := range allSpans {
var startedAt string
if span.caller != "" {
startedAt = " | started at: " + span.caller
}
fmt.Fprintf(msg, "%-*s (trace: %s | span: %s | parent: %s%s)\n", longestName, "'"+span.Name+"'",
span.SpanContext.TraceID(), span.SpanContext.SpanID(), span.Parent.SpanID(), startedAt)
}
endMsg(msg)
}
})
if didWarn {
return ErrIncompleteSpans
}
return nil
}
func newSpanObserver() *spanObserver {
return &spanObserver{
referencedIDs: map[oteltrace.SpanID]oteltrace.SpanID{},
cond: sync.NewCond(&sync.Mutex{}),
}
}
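// spanObserver tracks span IDs that have been referenced as parents but not yet
// observed directly, so that shutdown logic can wait for (or report) missing
// parent spans.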
type spanObserver struct {
cond *sync.Cond
referencedIDs map[oteltrace.SpanID]oteltrace.SpanID
unobservedIDs int
}
func (obs *spanObserver) ObserveReference(id oteltrace.SpanID, via oteltrace.SpanID) {
obs.cond.L.Lock()
defer obs.cond.L.Unlock()
if _, referenced := obs.referencedIDs[id]; !referenced {
obs.referencedIDs[id] = via // referenced, but not observed
// It is possible for new unobserved references to come in while waiting,
// but incrementing the counter wouldn't satisfy the condition so we don't
// need to signal the waiters
obs.unobservedIDs++
}
}
func (obs *spanObserver) Observe(id oteltrace.SpanID) {
obs.cond.L.Lock()
defer obs.cond.L.Unlock()
if observed, referenced := obs.referencedIDs[id]; !referenced || observed.IsValid() { // NB: subtle condition
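// A valid stored value means the ID was referenced but not yet observed; zeroSpanID
// means it was already observed. This branch therefore covers both "never seen" and
// "referenced but unobserved"; re-observing an already-observed ID falls through so
// the unobserved counter is never decremented twice.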
obs.referencedIDs[id] = zeroSpanID
if referenced {
obs.unobservedIDs--
obs.cond.Broadcast()
}
}
}
func (obs *spanObserver) wait(debugAllEnqueuedSpans map[oteltrace.SpanID]*tracev1.Span, warnAfter time.Duration) {
done := make(chan struct{})
defer close(done)
go func() {
select {
case <-done:
return
case <-time.After(warnAfter):
obs.debugWarnWaiting(debugAllEnqueuedSpans)
}
}()
obs.cond.L.Lock()
for obs.unobservedIDs > 0 {
obs.cond.Wait()
}
obs.cond.L.Unlock()
}
func (obs *spanObserver) debugWarnWaiting(debugAllEnqueuedSpans map[oteltrace.SpanID]*tracev1.Span) {
obs.cond.L.Lock()
msg := startMsg(fmt.Sprintf("Waiting on %d unobserved spans:\n", obs.unobservedIDs))
for id, via := range obs.referencedIDs {
if via.IsValid() {
fmt.Fprintf(msg, "%s via %s", id, via)
if span := debugAllEnqueuedSpans[id]; span != nil {
createdAt := "(unknown)"
for _, attr := range span.Attributes {
if attr.Key == "caller" {
createdAt = attr.Value.GetStringValue()
break
}
}
fmt.Fprintf(msg, "'%s' (trace: %s | created: %s)\n", span.GetName(), span.TraceId, createdAt)
} else {
msg.WriteString("\n")
}
}
}
endMsg(msg)
obs.cond.L.Unlock()
}

View file

@ -1,798 +0,0 @@
package trace_test
import (
"bytes"
"context"
"embed"
"fmt"
"io/fs"
"regexp"
"runtime"
"strings"
"sync"
"sync/atomic"
"testing"
"time"
"github.com/pomerium/pomerium/internal/telemetry/trace"
"github.com/pomerium/pomerium/internal/telemetry/trace/mock_otlptrace"
"github.com/pomerium/pomerium/internal/testutil"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
sdktrace "go.opentelemetry.io/otel/sdk/trace"
oteltrace "go.opentelemetry.io/otel/trace"
coltracepb "go.opentelemetry.io/proto/otlp/collector/trace/v1"
commonv1 "go.opentelemetry.io/proto/otlp/common/v1"
tracev1 "go.opentelemetry.io/proto/otlp/trace/v1"
"go.uber.org/mock/gomock"
"google.golang.org/protobuf/encoding/protojson"
"google.golang.org/protobuf/proto"
)
//go:embed testdata
var testdata embed.FS
func TestSpanExportQueue_Replay(t *testing.T) {
for _, tc := range []struct {
name string
file string
check func(t testing.TB, inputs, outputs *testutil.TraceResults)
}{
{
name: "single trace",
file: "testdata/recording_01_single_trace.json",
check: func(t testing.TB, inputs, outputs *testutil.TraceResults) {
inputs.AssertEqual(t, outputs)
},
},
{
name: "rewriting multiple traces",
file: "testdata/recording_02_multi_trace.json",
check: func(t testing.TB, inputs, outputs *testutil.TraceResults) {
inputTraces := inputs.GetTraces().WithoutErrors()
outputTraces := outputs.GetTraces().WithoutErrors()
// find upstream trace
var inputUpstreamTrace, outputUpstreamTrace *testutil.TraceDetails
isUpstreamTrace := func(v *testutil.TraceDetails) bool {
if strings.HasPrefix(v.Name, "Envoy: ingress:") {
for _, attr := range v.Spans[0].Raw.Attributes {
if attr.Key == "http.url" {
if regexp.MustCompile(`https://127\.0\.0\.1:\d+/foo`).MatchString(attr.Value.GetStringValue()) {
return true
}
}
}
}
return false
}
for _, v := range inputTraces.ByID {
if isUpstreamTrace(v) {
inputUpstreamTrace = v
break
}
}
for _, v := range outputTraces.ByID {
if isUpstreamTrace(v) {
outputUpstreamTrace = v
break
}
}
equal, diff := inputUpstreamTrace.Equal(outputUpstreamTrace)
if !equal {
assert.Failf(t, "upstream traces not equal:\n%s", diff)
return
}
// find downstream traces
// should be composed of:
// - 'ingress: GET foo.localhost.pomerium.io:<port>/foo'
// - 'internal: GET authenticate.localhost.pomerium.io:<port>/.pomerium/sign_in' (unauthorized)
// - 'internal: GET authenticate.localhost.pomerium.io:<port>/oauth2/callback'
// - 'internal: GET authenticate.localhost.pomerium.io:<port>/.pomerium/sign_in' (authorized)
// - 'internal: GET foo.localhost.pomerium.io:<port>/.pomerium/callback/'
envoyOutputTraces := outputTraces.ByParticipant["Envoy"]
// there should be two
require.Len(t, envoyOutputTraces, 2)
// find which one is not the upstream trace
var downstreamTrace *testutil.TraceDetails
if envoyOutputTraces[0].ID == outputUpstreamTrace.ID {
downstreamTrace = envoyOutputTraces[1]
} else {
downstreamTrace = envoyOutputTraces[0]
}
tree := downstreamTrace.SpanTree()
require.Empty(t, tree.DetachedParents)
parts := tree.Root.Children
require.Len(t, parts, 5)
assert.True(t, regexp.MustCompile(`ingress: GET foo\.localhost\.pomerium\.io:\d+/foo`).MatchString(parts[0].Span.Raw.Name))
assert.True(t, regexp.MustCompile(`internal: GET authenticate\.localhost\.pomerium\.io:\d+/.pomerium/sign_in`).MatchString(parts[1].Span.Raw.Name))
assert.True(t, regexp.MustCompile(`internal: GET authenticate\.localhost\.pomerium\.io:\d+/oauth2/callback`).MatchString(parts[2].Span.Raw.Name))
assert.True(t, regexp.MustCompile(`internal: GET authenticate\.localhost\.pomerium\.io:\d+/.pomerium/sign_in`).MatchString(parts[3].Span.Raw.Name))
assert.True(t, regexp.MustCompile(`internal: GET foo\.localhost\.pomerium\.io:\d+/.pomerium/callback`).MatchString(parts[4].Span.Raw.Name))
},
},
} {
t.Run(tc.name, func(t *testing.T) {
ctrl := gomock.NewController(t)
mockClient := mock_otlptrace.NewMockClient(ctrl)
var resultsMu sync.Mutex
outputSpans := [][]*tracev1.ResourceSpans{}
mockClient.EXPECT().
UploadTraces(gomock.Any(), gomock.Any()).
DoAndReturn(func(_ context.Context, protoSpans []*tracev1.ResourceSpans) error {
resultsMu.Lock()
defer resultsMu.Unlock()
outputSpans = append(outputSpans, protoSpans)
return nil
}).
AnyTimes()
recording1, err := fs.ReadFile(testdata, tc.file)
require.NoError(t, err)
rec, err := testutil.LoadEventRecording(recording1)
require.NoError(t, err)
ctx := trace.Options{
DebugFlags: trace.TrackAllSpans | trace.WarnOnIncompleteSpans | trace.WarnOnIncompleteTraces | trace.WarnOnUnresolvedReferences,
}.NewContext(context.Background())
queue := trace.NewSpanExportQueue(ctx, mockClient)
recCloned := rec.Clone()
err = rec.Replay(func(ctx context.Context, req *coltracepb.ExportTraceServiceRequest) (*coltracepb.ExportTraceServiceResponse, error) {
return &coltracepb.ExportTraceServiceResponse{}, queue.Enqueue(ctx, req)
})
assert.NoError(t, err)
// wait for all calls to UploadTraces to complete
ctx, ca := context.WithTimeout(context.Background(), 1*time.Second)
defer ca()
assert.NoError(t, queue.Close(ctx))
recCloned.Normalize(rec.NormalizedTo())
inputRequests := []*coltracepb.ExportTraceServiceRequest{}
for _, ev := range recCloned.Events() {
inputRequests = append(inputRequests, ev.Request)
}
inputs := testutil.NewTraceResults(testutil.FlattenExportRequests(inputRequests))
outputs := testutil.NewTraceResults(testutil.FlattenResourceSpans(outputSpans))
tc.check(t, inputs, outputs)
})
}
}
func TestSpanExportQueue_Enqueue(t *testing.T) {
type (
mapped struct {
s Span
t Trace
}
action struct {
exports []Span
uploads []any // int|mapped|*tracev1.Span
}
testCase struct {
name string
spans []*tracev1.Span // note: span ids are set automatically by index
actions []action
// if actionSets is present, repeats the same test case for each entry
actionSets [][]action
}
)
traceparent := func(trace Trace, span Span, sampled ...bool) *commonv1.KeyValue {
if len(sampled) == 0 {
sampled = append(sampled, true)
}
return &commonv1.KeyValue{
Key: "pomerium.traceparent",
Value: &commonv1.AnyValue{Value: &commonv1.AnyValue_StringValue{
StringValue: Traceparent(trace, span, sampled[0]),
}},
}
}
externalParent := func(span Span) *commonv1.KeyValue {
return &commonv1.KeyValue{
Key: "pomerium.external-parent-span",
Value: &commonv1.AnyValue{Value: &commonv1.AnyValue_StringValue{
StringValue: span.ID().String(),
}},
}
}
attrs := func(kvs ...*commonv1.KeyValue) []*commonv1.KeyValue { return kvs }
cases := []testCase{
{
name: "single trace",
spans: []*tracev1.Span{
// |<========>| Span 1
// | <======> | Span 2
// | <====> | Span 3
// T123456789A-
Span(1): {
TraceId: Trace(1).B(),
ParentSpanId: nil,
StartTimeUnixNano: 1,
EndTimeUnixNano: 0xA,
},
Span(2): {
TraceId: Trace(1).B(),
ParentSpanId: Span(1).B(),
StartTimeUnixNano: 2,
EndTimeUnixNano: 9,
},
Span(3): {
TraceId: Trace(1).B(),
ParentSpanId: Span(2).B(),
StartTimeUnixNano: 3,
EndTimeUnixNano: 8,
},
},
actionSets: [][]action{
// root span first
{
{exports: []Span{1}, uploads: []any{1}},
{exports: []Span{2, 3}, uploads: []any{2, 3}},
},
{
{exports: []Span{1, 2}, uploads: []any{1, 2}},
{exports: []Span{3}, uploads: []any{3}},
},
{
{exports: []Span{1, 2, 3}, uploads: []any{1, 2, 3}},
},
{
{exports: []Span{1, 3, 2}, uploads: []any{1, 2, 3}},
},
{
{exports: []Span{1}, uploads: []any{1}},
{exports: []Span{2}, uploads: []any{2}},
{exports: []Span{3}, uploads: []any{3}},
},
{
{exports: []Span{1}, uploads: []any{1}},
{exports: []Span{3}, uploads: []any{3}},
{exports: []Span{2}, uploads: []any{2}},
},
// root span last
{
{exports: []Span{2}, uploads: []any{}},
{exports: []Span{3}, uploads: []any{}},
{exports: []Span{1}, uploads: []any{1, 2, 3}},
},
{
{exports: []Span{3}, uploads: []any{}},
{exports: []Span{2}, uploads: []any{}},
{exports: []Span{1}, uploads: []any{1, 2, 3}},
},
{
{exports: []Span{2, 3}, uploads: []any{}},
{exports: []Span{1}, uploads: []any{1, 2, 3}},
},
{
{exports: []Span{3, 2}, uploads: []any{}},
{exports: []Span{1}, uploads: []any{1, 2, 3}},
},
{
{exports: []Span{3}, uploads: []any{}},
{exports: []Span{2, 1}, uploads: []any{1, 2, 3}},
},
{
{exports: []Span{2, 3, 1}, uploads: []any{1, 2, 3}},
},
// root span in the middle
{
{exports: []Span{2}, uploads: []any{}},
{exports: []Span{1}, uploads: []any{1, 2}},
{exports: []Span{3}, uploads: []any{3}},
},
{
{exports: []Span{3}, uploads: []any{}},
{exports: []Span{1}, uploads: []any{1, 3}},
{exports: []Span{2}, uploads: []any{2}},
},
{
{exports: []Span{3}, uploads: []any{}},
{exports: []Span{1, 2}, uploads: []any{1, 2, 3}},
},
{
{exports: []Span{2}, uploads: []any{}},
{exports: []Span{1, 3}, uploads: []any{1, 2, 3}},
},
},
},
{
name: "two correlated traces",
spans: []*tracev1.Span{
// |<=====> | Span 1 (Trace 1)
// | <===> | Span 2 (Trace 1)
// | <=> | Span 3 (Trace 1)
// | <======>| Span 4 (Trace 2)
// | <====> | Span 5 (Trace 2)
// T123456789ABCDEF-
Span(1): {
TraceId: Trace(1).B(),
ParentSpanId: nil,
StartTimeUnixNano: 1,
EndTimeUnixNano: 7,
},
Span(2): {
TraceId: Trace(1).B(),
ParentSpanId: Span(1).B(),
StartTimeUnixNano: 2,
EndTimeUnixNano: 6,
},
Span(3): {
TraceId: Trace(1).B(),
ParentSpanId: Span(2).B(),
StartTimeUnixNano: 3,
EndTimeUnixNano: 5,
},
Span(4): {
TraceId: Trace(2).B(),
ParentSpanId: nil,
Attributes: attrs(traceparent(Trace(1), Span(1))),
StartTimeUnixNano: 8,
EndTimeUnixNano: 0xF,
},
Span(5): {
TraceId: Trace(2).B(),
ParentSpanId: Span(4).B(),
Attributes: attrs(traceparent(Trace(1), Span(1))),
StartTimeUnixNano: 9,
EndTimeUnixNano: 0xE,
},
},
actionSets: [][]action{
0: {
{
exports: []Span{1, 2, 3, 4, 5},
uploads: []any{1, 2, 3, mapped{4, Trace(1)}, mapped{5, Trace(1)}},
},
},
1: {
{exports: []Span{2, 3, 5}, uploads: []any{}},
{
exports: []Span{1, 4},
uploads: []any{1, 2, 3, mapped{4, Trace(1)}, mapped{5, Trace(1)}},
},
},
2: {
{exports: []Span{2, 3, 5}, uploads: []any{}},
{
exports: []Span{1},
uploads: []any{1, 2, 3},
},
{
exports: []Span{4},
uploads: []any{mapped{4, Trace(1)}, mapped{5, Trace(1)}},
},
},
3: {
{exports: []Span{2, 3, 5}, uploads: []any{}},
{
exports: []Span{4, 1},
uploads: []any{1, 2, 3, mapped{4, Trace(1)}, mapped{5, Trace(1)}},
},
},
4: {
{exports: []Span{2, 3, 5}, uploads: []any{}},
{exports: []Span{4}, uploads: []any{}}, // root span, but mapped to a pending trace
{
exports: []Span{1},
uploads: []any{1, 2, 3, mapped{4, Trace(1)}, mapped{5, Trace(1)}},
},
},
},
},
{
name: "external parent",
spans: []*tracev1.Span{
// |??????????| Span 1 (external)
// | <======> | Span 2 (internal)
// | <====> | Span 3
// T123456789A-
Span(2): {
TraceId: Trace(1).B(),
ParentSpanId: Span(1).B(),
StartTimeUnixNano: 2,
EndTimeUnixNano: 9,
Attributes: attrs(externalParent(Span(1))),
},
Span(3): {
TraceId: Trace(1).B(),
ParentSpanId: Span(2).B(),
StartTimeUnixNano: 3,
EndTimeUnixNano: 8,
},
},
actionSets: [][]action{
{
{exports: []Span{3}, uploads: []any{}},
{exports: []Span{2}, uploads: []any{2, 3}},
},
{
{exports: []Span{2, 3}, uploads: []any{2, 3}},
},
{
{exports: []Span{3, 2}, uploads: []any{3, 2}},
},
},
},
}
generatedCases := []testCase{}
for _, tc := range cases {
for i, s := range tc.spans {
if s == nil {
continue
}
s.SpanId = Span(i).B()
}
if len(tc.actionSets) > 0 {
generated := []testCase{}
for i, actions := range tc.actionSets {
generated = append(generated, testCase{
name: fmt.Sprintf("%s (action set %d of %d)", tc.name, i+1, len(tc.actionSets)),
spans: tc.spans,
actions: actions,
})
}
generatedCases = append(generatedCases, generated...)
} else {
generatedCases = append(generatedCases, tc)
}
}
for _, tc := range generatedCases {
t.Run(tc.name, func(t *testing.T) {
ctrl := gomock.NewController(t)
mockClient := mock_otlptrace.NewMockClient(ctrl)
var resultsMu sync.Mutex
outputSpans := make(chan []*tracev1.ResourceSpans, 64)
mockClient.EXPECT().
UploadTraces(gomock.Any(), gomock.Any()).
DoAndReturn(func(_ context.Context, protoSpans []*tracev1.ResourceSpans) error {
resultsMu.Lock()
defer resultsMu.Unlock()
outputSpans <- protoSpans
return nil
}).
AnyTimes()
ctx := trace.Options{
DebugFlags: trace.TrackAllSpans | trace.WarnOnIncompleteSpans | trace.WarnOnIncompleteTraces | trace.WarnOnUnresolvedReferences,
}.NewContext(context.Background())
queue := trace.NewSpanExportQueue(ctx, mockClient)
for actionIdx, action := range tc.actions {
spans := []*tracev1.Span{}
for _, idx := range action.exports {
spans = append(spans, proto.Clone(tc.spans[idx]).(*tracev1.Span))
}
assert.NoError(t, queue.Enqueue(ctx, &coltracepb.ExportTraceServiceRequest{
ResourceSpans: []*tracev1.ResourceSpans{
{
Resource: Resource(1).Make().Resource,
ScopeSpans: []*tracev1.ScopeSpans{{Scope: Scope(1).Make().Scope, Spans: spans}},
},
},
}))
if len(action.uploads) == 0 {
for range 5 {
runtime.Gosched()
require.Empty(t, outputSpans)
}
continue
}
expectedSpans := &tracev1.ResourceSpans{
Resource: Resource(1).Make().Resource,
ScopeSpans: []*tracev1.ScopeSpans{{Scope: Scope(1).Make().Scope}},
}
for _, expectedUpload := range action.uploads {
switch up := expectedUpload.(type) {
case int:
expectedSpans.ScopeSpans[0].Spans = append(expectedSpans.ScopeSpans[0].Spans, tc.spans[up])
case mapped:
clone := proto.Clone(tc.spans[up.s]).(*tracev1.Span)
clone.TraceId = up.t.B()
expectedSpans.ScopeSpans[0].Spans = append(expectedSpans.ScopeSpans[0].Spans, clone)
case *tracev1.Span:
expectedSpans.ScopeSpans[0].Spans = append(expectedSpans.ScopeSpans[0].Spans, up)
default:
panic(fmt.Sprintf("test bug: unexpected type: %T", up))
}
}
select {
case resourceSpans := <-outputSpans:
expected := testutil.NewTraceResults([]*tracev1.ResourceSpans{expectedSpans})
actual := testutil.NewTraceResults(resourceSpans)
actual.AssertEqual(t, expected, "action %d/%d", actionIdx+1, len(tc.actions))
case <-time.After(1 * time.Second):
t.Fatalf("timed out waiting for upload (action %d/%d)", actionIdx+1, len(tc.actions))
}
}
if !t.Failed() {
close(outputSpans)
// ensure the queue is read fully
if !assert.Empty(t, outputSpans) {
for _, out := range <-outputSpans {
t.Log(protojson.Format(out))
}
}
}
})
}
}
func TestSpanObserver(t *testing.T) {
t.Run("observe single reference", func(t *testing.T) {
obs := trace.NewSpanObserver()
assert.Equal(t, []oteltrace.SpanID{}, obs.XUnobservedIDs())
obs.ObserveReference(Span(1).ID(), Span(2).ID())
assert.Equal(t, []oteltrace.SpanID{Span(1).ID()}, obs.XUnobservedIDs())
obs.Observe(Span(1).ID())
assert.Equal(t, []oteltrace.SpanID{}, obs.XUnobservedIDs())
})
t.Run("observe multiple references", func(t *testing.T) {
obs := trace.NewSpanObserver()
obs.ObserveReference(Span(1).ID(), Span(2).ID())
obs.ObserveReference(Span(1).ID(), Span(3).ID())
obs.ObserveReference(Span(1).ID(), Span(4).ID())
assert.Equal(t, []oteltrace.SpanID{Span(1).ID()}, obs.XUnobservedIDs())
obs.Observe(Span(1).ID())
assert.Equal(t, []oteltrace.SpanID{}, obs.XUnobservedIDs())
})
t.Run("observe before reference", func(t *testing.T) {
obs := trace.NewSpanObserver()
obs.Observe(Span(1).ID())
assert.Equal(t, []oteltrace.SpanID{}, obs.XUnobservedIDs())
obs.ObserveReference(Span(1).ID(), Span(2).ID())
assert.Equal(t, []oteltrace.SpanID{}, obs.XUnobservedIDs())
})
t.Run("wait", func(t *testing.T) {
obs := trace.NewSpanObserver()
obs.ObserveReference(Span(1).ID(), Span(2).ID())
obs.Observe(Span(2).ID())
obs.ObserveReference(Span(3).ID(), Span(4).ID())
obs.Observe(Span(4).ID())
obs.ObserveReference(Span(5).ID(), Span(6).ID())
obs.Observe(Span(6).ID())
waitOkToExit := atomic.Bool{}
waitExited := atomic.Bool{}
go func() {
defer waitExited.Store(true)
obs.XWait()
assert.True(t, waitOkToExit.Load(), "wait exited early")
}()
time.Sleep(10 * time.Millisecond)
assert.False(t, waitExited.Load())
obs.Observe(Span(1).ID())
time.Sleep(10 * time.Millisecond)
assert.False(t, waitExited.Load())
obs.Observe(Span(3).ID())
time.Sleep(10 * time.Millisecond)
assert.False(t, waitExited.Load())
waitOkToExit.Store(true)
obs.Observe(Span(5).ID())
assert.Eventually(t, waitExited.Load, 10*time.Millisecond, 1*time.Millisecond)
})
t.Run("new references observed during wait", func(t *testing.T) {
obs := trace.NewSpanObserver()
obs.ObserveReference(Span(1).ID(), Span(2).ID())
obs.Observe(Span(2).ID())
obs.ObserveReference(Span(3).ID(), Span(4).ID())
obs.Observe(Span(4).ID())
obs.ObserveReference(Span(5).ID(), Span(6).ID())
obs.Observe(Span(6).ID())
waitOkToExit := atomic.Bool{}
waitExited := atomic.Bool{}
go func() {
defer waitExited.Store(true)
obs.XWait()
assert.True(t, waitOkToExit.Load(), "wait exited early")
}()
assert.Equal(t, []oteltrace.SpanID{Span(1).ID(), Span(3).ID(), Span(5).ID()}, obs.XUnobservedIDs())
time.Sleep(10 * time.Millisecond)
assert.False(t, waitExited.Load())
obs.Observe(Span(1).ID())
assert.Equal(t, []oteltrace.SpanID{Span(3).ID(), Span(5).ID()}, obs.XUnobservedIDs())
time.Sleep(10 * time.Millisecond)
assert.False(t, waitExited.Load())
obs.Observe(Span(3).ID())
assert.Equal(t, []oteltrace.SpanID{Span(5).ID()}, obs.XUnobservedIDs())
time.Sleep(10 * time.Millisecond)
assert.False(t, waitExited.Load())
// observe a new reference
obs.ObserveReference(Span(7).ID(), Span(8).ID())
obs.Observe(Span(8).ID())
assert.Equal(t, []oteltrace.SpanID{Span(5).ID(), Span(7).ID()}, obs.XUnobservedIDs())
time.Sleep(10 * time.Millisecond)
assert.False(t, waitExited.Load())
obs.Observe(Span(5).ID())
assert.Equal(t, []oteltrace.SpanID{Span(7).ID()}, obs.XUnobservedIDs())
time.Sleep(10 * time.Millisecond)
assert.False(t, waitExited.Load())
waitOkToExit.Store(true)
obs.Observe(Span(7).ID())
assert.Equal(t, []oteltrace.SpanID{}, obs.XUnobservedIDs())
assert.Eventually(t, waitExited.Load, 10*time.Millisecond, 1*time.Millisecond)
})
t.Run("multiple waiters", func(t *testing.T) {
t.Parallel()
obs := trace.NewSpanObserver()
obs.ObserveReference(Span(1).ID(), Span(2).ID())
obs.Observe(Span(2).ID())
waitersExited := atomic.Int32{}
for range 10 {
go func() {
defer waitersExited.Add(1)
obs.XWait()
}()
}
assert.Equal(t, []oteltrace.SpanID{Span(1).ID()}, obs.XUnobservedIDs())
time.Sleep(10 * time.Millisecond)
assert.Equal(t, int32(0), waitersExited.Load())
obs.Observe(Span(1).ID())
startTime := time.Now()
for waitersExited.Load() != 10 {
if time.Since(startTime) > 1*time.Millisecond {
t.Fatal("timed out")
}
runtime.Gosched()
}
})
}
func TestSpanTracker(t *testing.T) {
t.Run("no debug flags", func(t *testing.T) {
t.Parallel()
obs := trace.NewSpanObserver()
tracker := trace.NewSpanTracker(obs, 0)
tp := sdktrace.NewTracerProvider(sdktrace.WithSpanProcessor(tracker))
tracer := tp.Tracer("test")
assert.Equal(t, []oteltrace.SpanID{}, tracker.XInflightSpans())
_, span1 := tracer.Start(context.Background(), "span 1")
assert.Equal(t, []oteltrace.SpanID{span1.SpanContext().SpanID()}, tracker.XInflightSpans())
assert.Equal(t, []oteltrace.SpanID{}, obs.XObservedIDs())
span1.End()
assert.Equal(t, []oteltrace.SpanID{}, tracker.XInflightSpans())
assert.Equal(t, []oteltrace.SpanID{}, obs.XObservedIDs())
})
t.Run("with TrackSpanReferences debug flag", func(t *testing.T) {
t.Parallel()
obs := trace.NewSpanObserver()
tracker := trace.NewSpanTracker(obs, trace.TrackSpanReferences)
tp := sdktrace.NewTracerProvider(sdktrace.WithSpanProcessor(tracker))
tracer := tp.Tracer("test")
assert.Equal(t, []oteltrace.SpanID{}, tracker.XInflightSpans())
_, span1 := tracer.Start(context.Background(), "span 1")
assert.Equal(t, []oteltrace.SpanID{span1.SpanContext().SpanID()}, tracker.XInflightSpans())
assert.Equal(t, []oteltrace.SpanID{span1.SpanContext().SpanID()}, obs.XObservedIDs())
span1.End()
assert.Equal(t, []oteltrace.SpanID{}, tracker.XInflightSpans())
assert.Equal(t, []oteltrace.SpanID{span1.SpanContext().SpanID()}, obs.XObservedIDs())
})
}
func TestSpanTrackerWarnings(t *testing.T) {
t.Run("WarnOnIncompleteSpans", func(t *testing.T) {
var buf bytes.Buffer
trace.SetDebugMessageWriterForTest(t, &buf)
obs := trace.NewSpanObserver()
tracker := trace.NewSpanTracker(obs, trace.WarnOnIncompleteSpans)
tp := sdktrace.NewTracerProvider(sdktrace.WithSpanProcessor(tracker))
tracer := tp.Tracer("test")
_, span1 := tracer.Start(context.Background(), "span 1")
assert.ErrorIs(t, tp.Shutdown(context.Background()), trace.ErrIncompleteSpans)
assert.Equal(t, fmt.Sprintf(`
==================================================
WARNING: spans not ended:
%s
Note: set TrackAllSpans flag for more info
==================================================
`, span1.SpanContext().SpanID()), buf.String())
})
t.Run("WarnOnIncompleteSpans with TrackAllSpans", func(t *testing.T) {
var buf bytes.Buffer
trace.SetDebugMessageWriterForTest(t, &buf)
obs := trace.NewSpanObserver()
tracker := trace.NewSpanTracker(obs, trace.WarnOnIncompleteSpans|trace.TrackAllSpans)
tp := sdktrace.NewTracerProvider(sdktrace.WithSpanProcessor(tracker))
tracer := tp.Tracer("test")
_, span1 := tracer.Start(context.Background(), "span 1")
assert.ErrorIs(t, tp.Shutdown(context.Background()), trace.ErrIncompleteSpans)
assert.Equal(t, fmt.Sprintf(`
==================================================
WARNING: spans not ended:
'span 1' (trace: %s | span: %s | parent: 0000000000000000)
==================================================
`, span1.SpanContext().TraceID(), span1.SpanContext().SpanID()), buf.String())
})
t.Run("WarnOnIncompleteSpans with TrackAllSpans and stackTraceProcessor", func(t *testing.T) {
var buf bytes.Buffer
trace.SetDebugMessageWriterForTest(t, &buf)
obs := trace.NewSpanObserver()
tracker := trace.NewSpanTracker(obs, trace.WarnOnIncompleteSpans|trace.TrackAllSpans)
tp := sdktrace.NewTracerProvider(sdktrace.WithSpanProcessor(&trace.XStackTraceProcessor{}), sdktrace.WithSpanProcessor(tracker))
tracer := tp.Tracer("test")
_, span1 := tracer.Start(context.Background(), "span 1")
_, file, line, _ := runtime.Caller(0)
line--
assert.ErrorIs(t, tp.Shutdown(context.Background()), trace.ErrIncompleteSpans)
assert.Equal(t, fmt.Sprintf(`
==================================================
WARNING: spans not ended:
'span 1' (trace: %s | span: %s | parent: 0000000000000000 | started at: %s:%d)
==================================================
`, span1.SpanContext().TraceID(), span1.SpanContext().SpanID(), file, line), buf.String())
})
t.Run("LogAllSpansOnWarn", func(t *testing.T) {
var buf bytes.Buffer
trace.SetDebugMessageWriterForTest(t, &buf)
obs := trace.NewSpanObserver()
tracker := trace.NewSpanTracker(obs, trace.WarnOnIncompleteSpans|trace.TrackAllSpans|trace.LogAllSpansOnWarn)
tp := sdktrace.NewTracerProvider(sdktrace.WithSpanProcessor(&trace.XStackTraceProcessor{}), sdktrace.WithSpanProcessor(tracker))
tracer := tp.Tracer("test")
_, span1 := tracer.Start(context.Background(), "span 1")
time.Sleep(10 * time.Millisecond)
span1.End()
time.Sleep(10 * time.Millisecond)
_, span2 := tracer.Start(context.Background(), "span 2")
_, file, line, _ := runtime.Caller(0)
line--
tp.Shutdown(context.Background())
assert.Equal(t,
fmt.Sprintf(`
==================================================
WARNING: spans not ended:
'span 2' (trace: %[1]s | span: %[2]s | parent: 0000000000000000 | started at: %[3]s:%[4]d)
==================================================
==================================================
All observed spans:
'span 1' (trace: %[5]s | span: %[6]s | parent: 0000000000000000 | started at: %[3]s:%[7]d)
'span 2' (trace: %[1]s | span: %[2]s | parent: 0000000000000000 | started at: %[3]s:%[4]d)
==================================================
`,
span2.SpanContext().TraceID(), span2.SpanContext().SpanID(), file, line,
span1.SpanContext().TraceID(), span1.SpanContext().SpanID(), line-4,
), buf.String())
})
}
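The tests above drive the span observer and tracker directly. For context, here is a minimal sketch of how these debug facilities are enabled in application code, assuming the trace.Options / NewContext / NewTracerProvider API shown in the provider changes later in this diff; the service name, flag combination, and remote client are illustrative only.

```go
package main

import (
	"context"

	"github.com/pomerium/pomerium/internal/telemetry/trace"
	"go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc"
)

func main() {
	// Illustrative wiring only: the DebugFlags field switches on the
	// tracker/observer machinery exercised by the tests above.
	ctx := trace.Options{
		RemoteClient: otlptracegrpc.NewClient(),
		DebugFlags: trace.TrackSpanCallers |
			trace.TrackSpanReferences |
			trace.WarnOnIncompleteSpans,
	}.NewContext(context.Background())

	tp := trace.NewTracerProvider(ctx, "Example Service")
	defer func() { _ = tp.Shutdown(context.Background()) }()

	_, span := tp.Tracer("example").Start(ctx, "example span")
	span.End()
}
```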

View file

@ -6,19 +6,37 @@ import (
"net" "net"
"time" "time"
"github.com/pomerium/pomerium/internal/log"
coltracepb "go.opentelemetry.io/proto/otlp/collector/trace/v1" coltracepb "go.opentelemetry.io/proto/otlp/collector/trace/v1"
"google.golang.org/grpc" "google.golang.org/grpc"
"google.golang.org/grpc/credentials/insecure" "google.golang.org/grpc/credentials/insecure"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/test/bufconn" "google.golang.org/grpc/test/bufconn"
"go.opentelemetry.io/otel/exporters/otlp/otlptrace" "go.opentelemetry.io/otel/exporters/otlp/otlptrace"
"go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc" "go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc"
sdktrace "go.opentelemetry.io/otel/sdk/trace"
) )
const localExporterMetadataKey = "x-local-exporter"
// Export implements ptraceotlp.GRPCServer. // Export implements ptraceotlp.GRPCServer.
func (srv *ExporterServer) Export(ctx context.Context, req *coltracepb.ExportTraceServiceRequest) (*coltracepb.ExportTraceServiceResponse, error) { func (srv *ExporterServer) Export(ctx context.Context, req *coltracepb.ExportTraceServiceRequest) (*coltracepb.ExportTraceServiceResponse, error) {
if err := srv.spanExportQueue.Enqueue(ctx, req); err != nil { if srv.observer != nil {
isLocal := len(metadata.ValueFromIncomingContext(ctx, localExporterMetadataKey)) != 0
if !isLocal {
for _, res := range req.ResourceSpans {
for _, scope := range res.ScopeSpans {
for _, span := range scope.Spans {
if id, ok := ToSpanID(span.SpanId); ok {
srv.observer.Observe(id)
}
}
}
}
}
}
if err := srv.remoteClient.UploadTraces(ctx, req.GetResourceSpans()); err != nil {
log.Ctx(ctx).Err(err).Msg("error uploading traces")
return nil, err return nil, err
} }
return &coltracepb.ExportTraceServiceResponse{}, nil return &coltracepb.ExportTraceServiceResponse{}, nil
@ -26,16 +44,17 @@ func (srv *ExporterServer) Export(ctx context.Context, req *coltracepb.ExportTra
type ExporterServer struct { type ExporterServer struct {
coltracepb.UnimplementedTraceServiceServer coltracepb.UnimplementedTraceServiceServer
spanExportQueue *SpanExportQueue
server *grpc.Server server *grpc.Server
observer *spanObserver
remoteClient otlptrace.Client remoteClient otlptrace.Client
cc *grpc.ClientConn cc *grpc.ClientConn
} }
func NewServer(ctx context.Context, remoteClient otlptrace.Client) *ExporterServer { func NewServer(ctx context.Context) *ExporterServer {
sys := systemContextFromContext(ctx)
ex := &ExporterServer{ ex := &ExporterServer{
spanExportQueue: NewSpanExportQueue(ctx, remoteClient), remoteClient: sys.options.RemoteClient,
remoteClient: remoteClient, observer: sys.observer,
server: grpc.NewServer(grpc.Creds(insecure.NewCredentials())), server: grpc.NewServer(grpc.Creds(insecure.NewCredentials())),
} }
coltracepb.RegisterTraceServiceServer(ex.server, ex) coltracepb.RegisterTraceServiceServer(ex.server, ex)
@ -64,13 +83,12 @@ func (srv *ExporterServer) NewClient() otlptrace.Client {
return otlptracegrpc.NewClient( return otlptracegrpc.NewClient(
otlptracegrpc.WithGRPCConn(srv.cc), otlptracegrpc.WithGRPCConn(srv.cc),
otlptracegrpc.WithTimeout(1*time.Minute), otlptracegrpc.WithTimeout(1*time.Minute),
otlptracegrpc.WithHeaders(map[string]string{
localExporterMetadataKey: "1",
}),
) )
} }
func (srv *ExporterServer) SpanProcessors() []sdktrace.SpanProcessor {
return []sdktrace.SpanProcessor{srv.spanExportQueue.tracker}
}
func (srv *ExporterServer) Shutdown(ctx context.Context) error { func (srv *ExporterServer) Shutdown(ctx context.Context) error {
stopped := make(chan struct{}) stopped := make(chan struct{})
go func() { go func() {
@ -83,10 +101,7 @@ func (srv *ExporterServer) Shutdown(ctx context.Context) error {
return context.Cause(ctx) return context.Cause(ctx)
} }
var errs []error var errs []error
if err := srv.spanExportQueue.WaitForSpans(30 * time.Second); err != nil { if err := WaitForSpans(ctx, 30*time.Second); err != nil {
errs = append(errs, err)
}
if err := srv.spanExportQueue.Close(ctx); err != nil {
errs = append(errs, err) errs = append(errs, err)
} }
if err := srv.remoteClient.Stop(ctx); err != nil { if err := srv.remoteClient.Stop(ctx); err != nil {

View file

@ -1,9 +0,0 @@
These trace recordings are generated as follows:
- recording_01_single_trace.json:
`go test -v -run "^TestOTLPTracing$" -env.trace-debug-flags=+32 github.com/pomerium/pomerium/internal/testenv/selftests | grep -ozP "(?s)(?<=All Events:\n).*?(?=\n=====)"`
- recording_02_multi_trace.json:
`go test -v -run "^TestOTLPTracing_TraceCorrelation$" -env.trace-debug-flags=+32 github.com/pomerium/pomerium/internal/testenv/selftests | grep -ozP "(?s)(?<=All Events:\n).*?(?=\n=====)"`

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View file

@ -7,7 +7,6 @@ import (
"runtime" "runtime"
"sync" "sync"
"sync/atomic" "sync/atomic"
"time"
"go.opentelemetry.io/otel/attribute" "go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/exporters/otlp/otlptrace" "go.opentelemetry.io/otel/exporters/otlp/otlptrace"
@ -35,8 +34,11 @@ func (op Options) NewContext(parent context.Context) context.Context {
options: op, options: op,
tpm: &tracerProviderManager{}, tpm: &tracerProviderManager{},
} }
if op.DebugFlags.Check(TrackSpanReferences) {
sys.observer = newSpanObserver()
}
ctx := context.WithValue(parent, systemContextKey, sys) ctx := context.WithValue(parent, systemContextKey, sys)
sys.exporterServer = NewServer(ctx, op.RemoteClient) sys.exporterServer = NewServer(ctx)
sys.exporterServer.Start(ctx) sys.exporterServer.Start(ctx)
return ctx return ctx
} }
@ -86,13 +88,14 @@ func NewTracerProvider(ctx context.Context, serviceName string, opts ...sdktrace
if sys.options.DebugFlags.Check(TrackSpanCallers) { if sys.options.DebugFlags.Check(TrackSpanCallers) {
options = append(options, sdktrace.WithSpanProcessor(&stackTraceProcessor{})) options = append(options, sdktrace.WithSpanProcessor(&stackTraceProcessor{}))
} }
if sys.options.DebugFlags.Check(TrackSpanReferences) {
tracker := newSpanTracker(sys.observer, sys.options.DebugFlags)
options = append(options, sdktrace.WithSpanProcessor(tracker))
}
options = append(append(options, options = append(append(options,
sdktrace.WithBatcher(exp), sdktrace.WithBatcher(exp),
sdktrace.WithResource(r), sdktrace.WithResource(r),
), opts...) ), opts...)
for _, proc := range sys.exporterServer.SpanProcessors() {
options = append(options, sdktrace.WithSpanProcessor(proc))
}
tp := sdktrace.NewTracerProvider(options...) tp := sdktrace.NewTracerProvider(options...)
sys.tpm.Add(tp) sys.tpm.Add(tp)
return tp return tp
@ -156,34 +159,6 @@ func RemoteClientFromContext(ctx context.Context) otlptrace.Client {
return nil return nil
} }
func DebugFlagsFromContext(ctx context.Context) DebugFlags {
if sys := systemContextFromContext(ctx); sys != nil {
return sys.options.DebugFlags
}
return 0
}
// WaitForSpans will block up to the given max duration and wait for all
// in-flight spans from tracers created with the given context to end. This
// function can be called more than once, and is safe to call from multiple
// goroutines in parallel.
//
// This requires the [TrackSpanReferences] debug flag to have been set with
// [Options.NewContext]. Otherwise, this function is a no-op and will return
// immediately.
//
// If this function blocks for more than 10 seconds, it will print a warning
// to stderr containing a list of span IDs it is waiting for, and the IDs of
// their parents (if known). Additionally, if the [TrackAllSpans] debug flag
// is set, details about parent spans will be displayed, including call site
// and trace ID.
func WaitForSpans(ctx context.Context, maxDuration time.Duration) error {
if sys := systemContextFromContext(ctx); sys != nil {
return sys.exporterServer.spanExportQueue.WaitForSpans(maxDuration)
}
return nil
}
// ForceFlush immediately exports all spans that have not yet been exported for // ForceFlush immediately exports all spans that have not yet been exported for
// all tracer providers created using the given context. // all tracer providers created using the given context.
func ForceFlush(ctx context.Context) error { func ForceFlush(ctx context.Context) error {
@ -204,6 +179,7 @@ var systemContextKey systemContextKeyType
type systemContext struct { type systemContext struct {
options Options options Options
tpm *tracerProviderManager tpm *tracerProviderManager
observer *spanObserver
exporterServer *ExporterServer exporterServer *ExporterServer
shutdown atomic.Bool shutdown atomic.Bool
} }
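The doc comment removed above still describes the WaitForSpans contract (it is a no-op unless the TrackSpanReferences flag was set via Options.NewContext), and the call now lives in ExporterServer.Shutdown rather than the export queue. A minimal caller-side sketch under that assumption; the helper name and timeout are illustrative:

```go
import (
	"context"
	"time"

	"github.com/pomerium/pomerium/internal/log"
	"github.com/pomerium/pomerium/internal/telemetry/trace"
)

// shutdownTracing is a sketch of draining tracing at process exit: flush
// anything still buffered, then wait (bounded) for spans that are open.
// ctx must be the context returned by trace.Options.NewContext; without
// the TrackSpanReferences flag, WaitForSpans returns immediately.
func shutdownTracing(ctx context.Context) {
	if err := trace.ForceFlush(ctx); err != nil {
		log.Ctx(ctx).Err(err).Msg("trace flush failed")
	}
	if err := trace.WaitForSpans(ctx, 10*time.Second); err != nil {
		log.Ctx(ctx).Err(err).Msg("timed out waiting for spans to end")
	}
}
```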

View file

@ -18,7 +18,7 @@ var (
type XStackTraceProcessor = stackTraceProcessor type XStackTraceProcessor = stackTraceProcessor
func (obs *spanObserver) XWait() { func (obs *spanObserver) XWait() {
obs.wait(nil, 5*time.Second) obs.wait(5 * time.Second)
} }
func (obs *spanObserver) XUnobservedIDs() []oteltrace.SpanID { func (obs *spanObserver) XUnobservedIDs() []oteltrace.SpanID {

View file

@ -1,85 +1,11 @@
package trace package trace
import ( import (
"encoding/hex"
"fmt"
"net/url"
"strings"
"unique" "unique"
"go.opentelemetry.io/otel/attribute"
oteltrace "go.opentelemetry.io/otel/trace" oteltrace "go.opentelemetry.io/otel/trace"
commonv1 "go.opentelemetry.io/proto/otlp/common/v1"
tracev1 "go.opentelemetry.io/proto/otlp/trace/v1"
) )
func ParseTraceparent(traceparent string) (oteltrace.SpanContext, error) {
parts := strings.Split(traceparent, "-")
if len(parts) != 4 {
return oteltrace.SpanContext{}, fmt.Errorf("malformed traceparent: expected 4 segments, found %d", len(parts))
}
traceID, err := oteltrace.TraceIDFromHex(parts[1])
if err != nil {
return oteltrace.SpanContext{}, fmt.Errorf("malformed traceparent: invalid trace ID: %w", err)
}
spanID, err := oteltrace.SpanIDFromHex(parts[2])
if err != nil {
return oteltrace.SpanContext{}, fmt.Errorf("malformed traceparent: invalid span ID: %w", err)
}
var traceFlags oteltrace.TraceFlags
if flags, err := hex.DecodeString(parts[3]); err != nil {
return oteltrace.SpanContext{}, fmt.Errorf("malformed traceparent: invalid trace flags: %w", err)
} else if len(flags) == 1 {
traceFlags = oteltrace.TraceFlags(flags[0])
} else {
return oteltrace.SpanContext{}, fmt.Errorf("malformed traceparent: invalid trace flags of size %d", len(flags))
}
if len(traceID) != 16 {
return oteltrace.SpanContext{}, fmt.Errorf("malformed traceparent: invalid trace ID of size %d", len(traceID))
}
if len(spanID) != 8 {
return oteltrace.SpanContext{}, fmt.Errorf("malformed traceparent: invalid span ID of size %d", len(spanID))
}
return oteltrace.NewSpanContext(oteltrace.SpanContextConfig{
TraceID: traceID,
SpanID: spanID,
TraceFlags: traceFlags,
}), nil
}
// WithTraceFromSpanContext returns a copy of traceparent with the trace ID
// (2nd segment) and trace flags (4th segment) replaced with the corresponding
// values from spanContext.
func WithTraceFromSpanContext(traceparent string, spanContext oteltrace.SpanContext) string {
parts := strings.Split(traceparent, "-")
if len(parts) != 4 {
return traceparent
}
parts[1] = spanContext.TraceID().String()
parts[3] = spanContext.TraceFlags().String()
return strings.Join(parts, "-")
}
func FormatSpanName(span *tracev1.Span) {
hasVariables := strings.Contains(span.GetName(), "${")
if hasVariables {
replacements := make([]string, 0, 6)
for _, attr := range span.Attributes {
switch attr.Key {
case "http.url":
u, _ := url.Parse(attr.Value.GetStringValue())
replacements = append(replacements,
"${path}", u.Path,
"${host}", u.Host,
)
case "http.method":
replacements = append(replacements, "${method}", attr.Value.GetStringValue())
}
}
span.Name = strings.NewReplacer(replacements...).Replace(span.Name)
}
}
var ( var (
zeroSpanID oteltrace.SpanID zeroSpanID oteltrace.SpanID
zeroTraceID = unique.Make(oteltrace.TraceID([16]byte{})) zeroTraceID = unique.Make(oteltrace.TraceID([16]byte{}))
@ -104,33 +30,3 @@ func ToTraceID(bytes []byte) (unique.Handle[oteltrace.TraceID], bool) {
} }
return zeroTraceID, false return zeroTraceID, false
} }
func NewAttributeSet(kvs ...*commonv1.KeyValue) attribute.Set {
attrs := make([]attribute.KeyValue, len(kvs))
for i, kv := range kvs {
var value attribute.Value
switch v := kv.Value.Value.(type) {
case *commonv1.AnyValue_BoolValue:
value = attribute.BoolValue(v.BoolValue)
case *commonv1.AnyValue_BytesValue:
value = attribute.StringValue(string(v.BytesValue))
case *commonv1.AnyValue_DoubleValue:
value = attribute.Float64Value(v.DoubleValue)
case *commonv1.AnyValue_IntValue:
value = attribute.Int64Value(v.IntValue)
case *commonv1.AnyValue_StringValue:
value = attribute.StringValue(v.StringValue)
case *commonv1.AnyValue_ArrayValue:
panic("unimplemented")
case *commonv1.AnyValue_KvlistValue:
panic("unimplemented")
default:
panic(fmt.Sprintf("unexpected v1.isAnyValue_Value: %#v", v))
}
attrs[i] = attribute.KeyValue{
Key: attribute.Key(kv.Key),
Value: value,
}
}
return attribute.NewSet(attrs...)
}

View file

@ -286,7 +286,7 @@ const StandardTraceDebugFlags = trace.TrackSpanCallers |
trace.WarnOnIncompleteSpans | trace.WarnOnIncompleteSpans |
trace.WarnOnIncompleteTraces | trace.WarnOnIncompleteTraces |
trace.WarnOnUnresolvedReferences | trace.WarnOnUnresolvedReferences |
trace.LogTraceIDMappingsOnWarn | trace.LogTraceIDsOnWarn |
trace.LogAllSpansOnWarn trace.LogAllSpansOnWarn
func WithTraceDebugFlags(flags trace.DebugFlags) EnvironmentOption { func WithTraceDebugFlags(flags trace.DebugFlags) EnvironmentOption {

View file

@ -13,7 +13,7 @@ import (
"github.com/pomerium/pomerium/internal/testenv" "github.com/pomerium/pomerium/internal/testenv"
"github.com/pomerium/pomerium/internal/testenv/upstreams" "github.com/pomerium/pomerium/internal/testenv/upstreams"
"github.com/pomerium/pomerium/internal/testenv/values" "github.com/pomerium/pomerium/internal/testenv/values"
"github.com/pomerium/pomerium/internal/testutil" "github.com/pomerium/pomerium/internal/testutil/tracetest"
"go.opentelemetry.io/otel/exporters/otlp/otlptrace" "go.opentelemetry.io/otel/exporters/otlp/otlptrace"
"go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc" "go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc"
"go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp" "go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp"
@ -136,7 +136,7 @@ func (rec *OTLPTraceReceiver) PeekResourceSpans() []*tracev1.ResourceSpans {
} }
func (rec *OTLPTraceReceiver) peekResourceSpansLocked() []*tracev1.ResourceSpans { func (rec *OTLPTraceReceiver) peekResourceSpansLocked() []*tracev1.ResourceSpans {
return testutil.FlattenExportRequests(rec.receivedRequests) return tracetest.FlattenExportRequests(rec.receivedRequests)
} }
func (rec *OTLPTraceReceiver) FlushResourceSpans() []*tracev1.ResourceSpans { func (rec *OTLPTraceReceiver) FlushResourceSpans() []*tracev1.ResourceSpans {

View file

@ -18,7 +18,7 @@ import (
"github.com/pomerium/pomerium/internal/testenv/scenarios" "github.com/pomerium/pomerium/internal/testenv/scenarios"
"github.com/pomerium/pomerium/internal/testenv/snippets" "github.com/pomerium/pomerium/internal/testenv/snippets"
"github.com/pomerium/pomerium/internal/testenv/upstreams" "github.com/pomerium/pomerium/internal/testenv/upstreams"
"github.com/pomerium/pomerium/internal/testutil" . "github.com/pomerium/pomerium/internal/testutil/tracetest" //nolint:revive
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite" "github.com/stretchr/testify/suite"
@ -29,7 +29,7 @@ import (
oteltrace "go.opentelemetry.io/otel/trace" oteltrace "go.opentelemetry.io/otel/trace"
) )
func otlpTraceReceiverOrFromEnv(t *testing.T) (modifier testenv.Modifier, newRemoteClient func() otlptrace.Client, getResults func() *testutil.TraceResults) { func otlpTraceReceiverOrFromEnv(t *testing.T) (modifier testenv.Modifier, newRemoteClient func() otlptrace.Client, getResults func() *TraceResults) {
t.Setenv("OTEL_TRACES_EXPORTER", "otlp") t.Setenv("OTEL_TRACES_EXPORTER", "otlp")
tracesEndpoint := os.Getenv("OTEL_EXPORTER_OTLP_TRACES_ENDPOINT") tracesEndpoint := os.Getenv("OTEL_EXPORTER_OTLP_TRACES_ENDPOINT")
if tracesEndpoint == "" { if tracesEndpoint == "" {
@ -40,8 +40,8 @@ func otlpTraceReceiverOrFromEnv(t *testing.T) (modifier testenv.Modifier, newRem
func() otlptrace.Client { func() otlptrace.Client {
return srv.NewGRPCClient() return srv.NewGRPCClient()
}, },
func() *testutil.TraceResults { func() *TraceResults {
return testutil.NewTraceResults(srv.FlushResourceSpans()) return NewTraceResults(srv.FlushResourceSpans())
} }
} }
} }
@ -111,18 +111,18 @@ func TestOTLPTracing(t *testing.T) {
) )
results.MatchTraces(t, results.MatchTraces(t,
testutil.MatchOptions{ MatchOptions{
Exact: true, Exact: true,
CheckDetachedSpans: true, CheckDetachedSpans: true,
}, },
testutil.Match{Name: testEnvironmentLocalTest, TraceCount: 1, Services: []string{"Test Environment", "Control Plane", "Data Broker"}}, Match{Name: testEnvironmentLocalTest, TraceCount: 1, Services: []string{"Test Environment", "Control Plane", "Data Broker"}},
testutil.Match{Name: testEnvironmentAuthenticate, TraceCount: 1, Services: allServices}, Match{Name: testEnvironmentAuthenticate, TraceCount: 1, Services: allServices},
testutil.Match{Name: authenticateOAuth2Client, TraceCount: testutil.Greater(0)}, Match{Name: authenticateOAuth2Client, TraceCount: Greater(0)},
testutil.Match{Name: idpServerGetUserinfo, TraceCount: testutil.EqualToMatch(authenticateOAuth2Client)}, Match{Name: idpServerGetUserinfo, TraceCount: EqualToMatch(authenticateOAuth2Client)},
testutil.Match{Name: idpServerPostToken, TraceCount: testutil.EqualToMatch(authenticateOAuth2Client)}, Match{Name: idpServerPostToken, TraceCount: EqualToMatch(authenticateOAuth2Client)},
testutil.Match{Name: controlPlaneEnvoyDiscovery, TraceCount: 1}, Match{Name: controlPlaneEnvoyDiscovery, TraceCount: 1},
testutil.Match{Name: controlPlaneExport, TraceCount: testutil.Greater(0)}, Match{Name: controlPlaneExport, TraceCount: Greater(0)},
testutil.Match{Name: controlPlaneEnvoyAccessLogs, TraceCount: testutil.Any{}}, Match{Name: controlPlaneEnvoyAccessLogs, TraceCount: Any{}},
) )
} }
} }
@ -164,17 +164,15 @@ func TestOTLPTracing_TraceCorrelation(t *testing.T) {
if getResults != nil { if getResults != nil {
results := getResults() results := getResults()
traces := results.GetTraces() traces := results.GetTraces()
downstreamTraces := traces.ByName[fmt.Sprintf("Envoy: ingress: GET foo.localhost.pomerium.io:%d/foo", env.Ports().ProxyHTTP.Value())].WithoutErrors() // one unauthenticated (ends in /.pomerium/callback redirect), one authenticated
upstreamTraces := traces.ByName[fmt.Sprintf("Envoy: ingress: GET 127.0.0.1:%d/foo", up.Port().Value())].WithoutErrors() assert.Len(t, traces.ByName[fmt.Sprintf("Envoy: ingress: GET foo.localhost.pomerium.io:%d/foo", env.Ports().ProxyHTTP.Value())].WithoutErrors(), 2)
assert.Len(t, upstreamTraces, 1)
assert.Len(t, downstreamTraces, 1)
} }
} }
type SamplingTestSuite struct { type SamplingTestSuite struct {
suite.Suite suite.Suite
env testenv.Environment env testenv.Environment
getResults func() *testutil.TraceResults getResults func() *TraceResults
route testenv.Route route testenv.Route
upstream upstreams.HTTPUpstream upstream upstreams.HTTPUpstream
@ -308,6 +306,7 @@ func (s *SamplingTestSuite) TestExternalTraceparentNeverSample() {
if (len(traces.ByParticipant)) != 0 { if (len(traces.ByParticipant)) != 0 {
// whether or not these show up is timing dependent, but not important // whether or not these show up is timing dependent, but not important
possibleTraces := map[string]struct{}{ possibleTraces := map[string]struct{}{
"Test Environment: Start": {},
"IDP: Server: POST /oidc/token": {}, "IDP: Server: POST /oidc/token": {},
"IDP: Server: GET /oidc/userinfo": {}, "IDP: Server: GET /oidc/userinfo": {},
"Authenticate: OAuth2 Client: GET /.well-known/jwks.json": {}, "Authenticate: OAuth2 Client: GET /.well-known/jwks.json": {},
@ -315,7 +314,7 @@ func (s *SamplingTestSuite) TestExternalTraceparentNeverSample() {
actual := slices.Collect(maps.Keys(traces.ByName)) actual := slices.Collect(maps.Keys(traces.ByName))
for _, name := range actual { for _, name := range actual {
if _, ok := possibleTraces[name]; !ok { if _, ok := possibleTraces[name]; !ok {
s.Failf("unexpected trace: %s", name) s.Fail("unexpected trace: " + name)
} }
} }
} }
@ -384,8 +383,8 @@ func TestExternalSpans(t *testing.T) {
if getResults != nil { if getResults != nil {
results := getResults() results := getResults()
results.MatchTraces(t, testutil.MatchOptions{CheckDetachedSpans: true}, results.MatchTraces(t, MatchOptions{CheckDetachedSpans: true},
testutil.Match{Name: "External: External Root", TraceCount: 1, Services: []string{ Match{Name: "External: External Root", TraceCount: 1, Services: []string{
"Authorize", "Authorize",
"Authenticate", "Authenticate",
"Control Plane", "Control Plane",

View file

@ -142,9 +142,9 @@ func (g *grpcUpstream) Run(ctx context.Context) error {
} }
server := grpc.NewServer(append(g.serverOpts, server := grpc.NewServer(append(g.serverOpts,
grpc.Creds(g.creds), grpc.Creds(g.creds),
grpc.StatsHandler(trace.NewServerStatsHandler(otelgrpc.NewServerHandler( grpc.StatsHandler(otelgrpc.NewServerHandler(
otelgrpc.WithTracerProvider(g.serverTracerProvider.Value()), otelgrpc.WithTracerProvider(g.serverTracerProvider.Value()),
))), )),
)...) )...)
for _, s := range g.services { for _, s := range g.services {
server.RegisterService(s.desc, s.impl) server.RegisterService(s.desc, s.impl)

View file

@ -1,4 +1,4 @@
package trace package tracetest
import ( import (
"cmp" "cmp"
@ -6,10 +6,8 @@ import (
"maps" "maps"
"slices" "slices"
"sync" "sync"
"unique"
"github.com/pomerium/pomerium/internal/hashutil" "github.com/pomerium/pomerium/internal/hashutil"
oteltrace "go.opentelemetry.io/otel/trace"
commonv1 "go.opentelemetry.io/proto/otlp/common/v1" commonv1 "go.opentelemetry.io/proto/otlp/common/v1"
resourcev1 "go.opentelemetry.io/proto/otlp/resource/v1" resourcev1 "go.opentelemetry.io/proto/otlp/resource/v1"
tracev1 "go.opentelemetry.io/proto/otlp/trace/v1" tracev1 "go.opentelemetry.io/proto/otlp/trace/v1"
@ -72,30 +70,6 @@ func (rb *ResourceBuffer) Flush() []*tracev1.ScopeSpans {
return out return out
} }
func (rb *ResourceBuffer) FlushAs(rewriteTraceID unique.Handle[oteltrace.TraceID]) []*tracev1.ScopeSpans {
out := make([]*tracev1.ScopeSpans, 0, len(rb.spansByScope))
for _, key := range slices.Sorted(maps.Keys(rb.spansByScope)) {
spans := rb.spansByScope[key]
{
id := rewriteTraceID.Value()
for _, span := range spans.spans {
copy(span.TraceId, id[:])
}
}
slices.SortStableFunc(spans.spans, func(a, b *tracev1.Span) int {
return cmp.Compare(a.StartTimeUnixNano, b.StartTimeUnixNano)
})
scopeSpans := &tracev1.ScopeSpans{
Scope: spans.scope.Scope,
SchemaUrl: spans.scope.Schema,
Spans: spans.spans,
}
out = append(out, scopeSpans)
}
clear(rb.spansByScope)
return out
}
func (rb *ResourceBuffer) Merge(other *ResourceBuffer) { func (rb *ResourceBuffer) Merge(other *ResourceBuffer) {
for scope, otherSpans := range other.spansByScope { for scope, otherSpans := range other.spansByScope {
if ourSpans, ok := rb.spansByScope[scope]; !ok { if ourSpans, ok := rb.spansByScope[scope]; !ok {
@ -107,23 +81,6 @@ func (rb *ResourceBuffer) Merge(other *ResourceBuffer) {
clear(other.spansByScope) clear(other.spansByScope)
} }
func (rb *ResourceBuffer) MergeAs(other *ResourceBuffer, rewriteTraceID unique.Handle[oteltrace.TraceID]) {
for scope, otherSpans := range other.spansByScope {
{
id := rewriteTraceID.Value()
for _, span := range otherSpans.spans {
copy(span.TraceId, id[:])
}
}
if ourSpans, ok := rb.spansByScope[scope]; !ok {
rb.spansByScope[scope] = otherSpans
} else {
ourSpans.Insert(otherSpans.spans...)
}
}
clear(other.spansByScope)
}
type Buffer struct { type Buffer struct {
scopesByResourceID map[string]*ResourceBuffer scopesByResourceID map[string]*ResourceBuffer
} }
@ -161,21 +118,6 @@ func (b *Buffer) Flush() []*tracev1.ResourceSpans {
return out return out
} }
func (b *Buffer) FlushAs(rewriteTraceID unique.Handle[oteltrace.TraceID]) []*tracev1.ResourceSpans {
out := make([]*tracev1.ResourceSpans, 0, len(b.scopesByResourceID))
for _, key := range slices.Sorted(maps.Keys(b.scopesByResourceID)) {
scopes := b.scopesByResourceID[key]
resourceSpans := &tracev1.ResourceSpans{
Resource: scopes.resource.Resource,
ScopeSpans: scopes.FlushAs(rewriteTraceID),
SchemaUrl: scopes.resource.Schema,
}
out = append(out, resourceSpans)
}
clear(b.scopesByResourceID)
return out
}
func (b *Buffer) Merge(other *Buffer) { func (b *Buffer) Merge(other *Buffer) {
if b != nil { if b != nil {
for k, otherV := range other.scopesByResourceID { for k, otherV := range other.scopesByResourceID {
@ -189,21 +131,6 @@ func (b *Buffer) Merge(other *Buffer) {
clear(other.scopesByResourceID) clear(other.scopesByResourceID)
} }
func (b *Buffer) MergeAs(other *Buffer, rewriteTraceID unique.Handle[oteltrace.TraceID]) {
if b != nil {
for k, otherV := range other.scopesByResourceID {
if v, ok := b.scopesByResourceID[k]; !ok {
newRb := NewResourceBuffer(otherV.resource)
newRb.MergeAs(otherV, rewriteTraceID)
b.scopesByResourceID[k] = newRb
} else {
v.MergeAs(otherV, rewriteTraceID)
}
}
}
clear(other.scopesByResourceID)
}
func (b *Buffer) IsEmpty() bool { func (b *Buffer) IsEmpty() bool {
return len(b.scopesByResourceID) == 0 return len(b.scopesByResourceID) == 0
} }

View file

@ -1,110 +1,13 @@
package trace_test package tracetest
import ( import (
"encoding/binary"
"fmt"
"testing" "testing"
"unique"
"github.com/pomerium/pomerium/internal/telemetry/trace"
"github.com/pomerium/pomerium/internal/testutil" "github.com/pomerium/pomerium/internal/testutil"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
oteltrace "go.opentelemetry.io/otel/trace"
commonv1 "go.opentelemetry.io/proto/otlp/common/v1"
resourcev1 "go.opentelemetry.io/proto/otlp/resource/v1"
tracev1 "go.opentelemetry.io/proto/otlp/trace/v1" tracev1 "go.opentelemetry.io/proto/otlp/trace/v1"
) )
type (
Trace uint32
Span uint32
Scope uint32
Schema uint32
Resource uint32
)
func (n Trace) String() string { return fmt.Sprintf("Trace %d", n) }
func (n Span) String() string { return fmt.Sprintf("Span %d", n) }
func (n Scope) String() string { return fmt.Sprintf("Scope %d", n) }
func (n Schema) String() string { return fmt.Sprintf("Schema %d", n) }
func (n Resource) String() string { return fmt.Sprintf("Resource %d", n) }
func (n Trace) ID() unique.Handle[oteltrace.TraceID] {
id, _ := trace.ToTraceID(n.B())
return id
}
func (n Trace) B() []byte {
var id oteltrace.TraceID
binary.BigEndian.PutUint32(id[12:], uint32(n))
return id[:]
}
func (n Span) ID() oteltrace.SpanID {
id, _ := trace.ToSpanID(n.B())
return id
}
func (n Span) B() []byte {
var id oteltrace.SpanID
binary.BigEndian.PutUint32(id[4:], uint32(n))
return id[:]
}
func (n Scope) Make(s ...Schema) *trace.ScopeInfo {
if len(s) == 0 {
s = append(s, Schema(0))
}
return trace.NewScopeInfo(&commonv1.InstrumentationScope{
Name: n.String(),
Version: "v1",
Attributes: []*commonv1.KeyValue{
{
Key: "id",
Value: &commonv1.AnyValue{
Value: &commonv1.AnyValue_IntValue{
IntValue: int64(n),
},
},
},
},
}, s[0].String())
}
func (n Resource) Make(s ...Schema) *trace.ResourceInfo {
if len(s) == 0 {
s = append(s, Schema(0))
}
return trace.NewResourceInfo(&resourcev1.Resource{
Attributes: []*commonv1.KeyValue{
{
Key: "name",
Value: &commonv1.AnyValue{
Value: &commonv1.AnyValue_StringValue{
StringValue: n.String(),
},
},
},
{
Key: "id",
Value: &commonv1.AnyValue{
Value: &commonv1.AnyValue_IntValue{
IntValue: int64(n),
},
},
},
},
}, s[0].String())
}
func Traceparent(trace Trace, span Span, sampled bool) string {
sampledStr := "00"
if sampled {
sampledStr = "01"
}
return fmt.Sprintf("00-%s-%s-%s", trace.ID().Value(), span.ID(), sampledStr)
}
func TestBuffer(t *testing.T) { func TestBuffer(t *testing.T) {
t.Parallel() t.Parallel()
@ -128,8 +31,8 @@ func TestBuffer(t *testing.T) {
{TraceId: Trace(2).B(), SpanId: Span(16).B(), StartTimeUnixNano: 16}, {TraceId: Trace(2).B(), SpanId: Span(16).B(), StartTimeUnixNano: 16},
} }
newTestBuffer := func() *trace.Buffer { newTestBuffer := func() *Buffer {
b := trace.NewBuffer() b := NewBuffer()
b.Insert(Resource(1).Make(), Scope(1).Make(), s[0]) b.Insert(Resource(1).Make(), Scope(1).Make(), s[0])
b.Insert(Resource(1).Make(), Scope(1).Make(), s[1]) b.Insert(Resource(1).Make(), Scope(1).Make(), s[1])
b.Insert(Resource(1).Make(), Scope(1).Make(), s[2]) b.Insert(Resource(1).Make(), Scope(1).Make(), s[2])
@ -191,26 +94,12 @@ func TestBuffer(t *testing.T) {
assert.True(t, b.IsEmpty()) assert.True(t, b.IsEmpty())
testutil.AssertProtoEqual(t, newExpectedSpans(), actual) testutil.AssertProtoEqual(t, newExpectedSpans(), actual)
}) })
t.Run("FlushAs", func(t *testing.T) {
b := newTestBuffer()
actual := b.FlushAs(Trace(100).ID())
assert.True(t, b.IsEmpty())
expected := newExpectedSpans()
for _, resourceSpans := range expected {
for _, scopeSpans := range resourceSpans.ScopeSpans {
for _, span := range scopeSpans.Spans {
span.TraceId = Trace(100).B()
}
}
}
testutil.AssertProtoEqual(t, expected, actual)
})
t.Run("Default scope", func(t *testing.T) { t.Run("Default scope", func(t *testing.T) {
b := trace.NewBuffer() b := NewBuffer()
b.Insert(Resource(1).Make(Schema(2)), trace.NewScopeInfo(nil, ""), s[0]) b.Insert(Resource(1).Make(Schema(2)), NewScopeInfo(nil, ""), s[0])
b.Insert(Resource(1).Make(Schema(2)), trace.NewScopeInfo(nil, ""), s[1]) b.Insert(Resource(1).Make(Schema(2)), NewScopeInfo(nil, ""), s[1])
b.Insert(Resource(1).Make(Schema(2)), trace.NewScopeInfo(nil, ""), s[2]) b.Insert(Resource(1).Make(Schema(2)), NewScopeInfo(nil, ""), s[2])
actual := b.Flush() actual := b.Flush()
testutil.AssertProtoEqual(t, []*tracev1.ResourceSpans{ testutil.AssertProtoEqual(t, []*tracev1.ResourceSpans{
{ {

View file

@ -1,8 +1,9 @@
package testutil package tracetest
import ( import (
"cmp" "cmp"
"context" "context"
"encoding/binary"
"encoding/json" "encoding/json"
"fmt" "fmt"
"maps" "maps"
@ -28,6 +29,96 @@ import (
"google.golang.org/protobuf/testing/protocmp" "google.golang.org/protobuf/testing/protocmp"
) )
type (
Trace uint32
Span uint32
Scope uint32
Schema uint32
Resource uint32
)
func (n Trace) String() string { return fmt.Sprintf("Trace %d", n) }
func (n Span) String() string { return fmt.Sprintf("Span %d", n) }
func (n Scope) String() string { return fmt.Sprintf("Scope %d", n) }
func (n Schema) String() string { return fmt.Sprintf("Schema %d", n) }
func (n Resource) String() string { return fmt.Sprintf("Resource %d", n) }
func (n Trace) ID() unique.Handle[oteltrace.TraceID] {
id, _ := trace.ToTraceID(n.B())
return id
}
func (n Trace) B() []byte {
var id oteltrace.TraceID
binary.BigEndian.PutUint32(id[12:], uint32(n))
return id[:]
}
func (n Span) ID() oteltrace.SpanID {
id, _ := trace.ToSpanID(n.B())
return id
}
func (n Span) B() []byte {
var id oteltrace.SpanID
binary.BigEndian.PutUint32(id[4:], uint32(n))
return id[:]
}
func (n Scope) Make(s ...Schema) *ScopeInfo {
if len(s) == 0 {
s = append(s, Schema(0))
}
return NewScopeInfo(&commonv1.InstrumentationScope{
Name: n.String(),
Version: "v1",
Attributes: []*commonv1.KeyValue{
{
Key: "id",
Value: &commonv1.AnyValue{
Value: &commonv1.AnyValue_IntValue{
IntValue: int64(n),
},
},
},
},
}, s[0].String())
}
func (n Resource) Make(s ...Schema) *ResourceInfo {
if len(s) == 0 {
s = append(s, Schema(0))
}
return NewResourceInfo(&resourcev1.Resource{
Attributes: []*commonv1.KeyValue{
{
Key: "name",
Value: &commonv1.AnyValue{
Value: &commonv1.AnyValue_StringValue{
StringValue: n.String(),
},
},
},
{
Key: "id",
Value: &commonv1.AnyValue{
Value: &commonv1.AnyValue_IntValue{
IntValue: int64(n),
},
},
},
},
}, s[0].String())
}
func Traceparent(trace Trace, span Span, sampled bool) string {
sampledStr := "00"
if sampled {
sampledStr = "01"
}
return fmt.Sprintf("00-%s-%s-%s", trace.ID().Value(), span.ID(), sampledStr)
}
type TraceResults struct { type TraceResults struct {
resourceSpans []*tracev1.ResourceSpans resourceSpans []*tracev1.ResourceSpans
@ -86,11 +177,7 @@ type TraceDetails struct {
func (td *TraceDetails) Equal(other *TraceDetails) (bool, string) { func (td *TraceDetails) Equal(other *TraceDetails) (bool, string) {
diffSpans := func(a, b []*SpanDetails) (bool, string) { diffSpans := func(a, b []*SpanDetails) (bool, string) {
for i := range len(a) { for i := range len(a) {
aRaw := proto.Clone(a[i].Raw).(*tracev1.Span) diff := gocmp.Diff(a[i], b[i], protocmp.Transform())
trace.FormatSpanName(aRaw)
bRaw := proto.Clone(b[i].Raw).(*tracev1.Span)
trace.FormatSpanName(bRaw)
diff := gocmp.Diff(aRaw, bRaw, protocmp.Transform())
if diff != "" { if diff != "" {
return false, diff return false, diff
} }
@ -426,12 +513,12 @@ func (tr *TraceResults) AssertEqual(t testing.TB, expectedResults *TraceResults,
} }
func FlattenResourceSpans(lists [][]*tracev1.ResourceSpans) []*tracev1.ResourceSpans { func FlattenResourceSpans(lists [][]*tracev1.ResourceSpans) []*tracev1.ResourceSpans {
res := trace.NewBuffer() res := NewBuffer()
for _, list := range lists { for _, list := range lists {
for _, resource := range list { for _, resource := range list {
resInfo := trace.NewResourceInfo(resource.Resource, resource.SchemaUrl) resInfo := NewResourceInfo(resource.Resource, resource.SchemaUrl)
for _, scope := range resource.ScopeSpans { for _, scope := range resource.ScopeSpans {
scopeInfo := trace.NewScopeInfo(scope.Scope, scope.SchemaUrl) scopeInfo := NewScopeInfo(scope.Scope, scope.SchemaUrl)
for _, span := range scope.Spans { for _, span := range scope.Spans {
res.Insert(resInfo, scopeInfo, span) res.Insert(resInfo, scopeInfo, span)
} }