// Mirror of https://github.com/pomerium/pomerium.git
// (synced 2025-04-28 18:06:34 +02:00; 210 lines, 6.3 KiB, Go).

package trace_test
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/gorilla/mux"
|
|
"github.com/pomerium/pomerium/pkg/telemetry/trace"
|
|
"github.com/stretchr/testify/assert"
|
|
"go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp"
|
|
sdktrace "go.opentelemetry.io/otel/sdk/trace"
|
|
oteltrace "go.opentelemetry.io/otel/trace"
|
|
"google.golang.org/grpc/metadata"
|
|
"google.golang.org/grpc/stats"
|
|
)
|
|
|
|
func TestHTTPMiddleware(t *testing.T) {
|
|
router := mux.NewRouter()
|
|
tp := sdktrace.NewTracerProvider()
|
|
router.Use(trace.NewHTTPMiddleware(
|
|
otelhttp.WithTracerProvider(tp),
|
|
))
|
|
router.Path("/foo").HandlerFunc(func(_ http.ResponseWriter, r *http.Request) {
|
|
span := oteltrace.SpanFromContext(r.Context())
|
|
assert.Equal(t, "Server: GET /foo", span.(interface{ Name() string }).Name())
|
|
}).Methods(http.MethodGet)
|
|
w := httptest.NewRecorder()
|
|
ctx, span := tp.Tracer("test").Start(context.Background(), "test")
|
|
router.ServeHTTP(w, httptest.NewRequestWithContext(ctx, http.MethodGet, "/foo", nil))
|
|
span.End()
|
|
}
|
|
|
|
type mockHandler struct {
|
|
handleConn func(ctx context.Context, stats stats.ConnStats)
|
|
handleRPC func(ctx context.Context, stats stats.RPCStats)
|
|
tagConn func(ctx context.Context, info *stats.ConnTagInfo) context.Context
|
|
tagRPC func(ctx context.Context, info *stats.RPCTagInfo) context.Context
|
|
}
|
|
|
|
// HandleConn implements stats.Handler.
|
|
func (m *mockHandler) HandleConn(ctx context.Context, stats stats.ConnStats) {
|
|
m.handleConn(ctx, stats)
|
|
}
|
|
|
|
// HandleRPC implements stats.Handler.
|
|
func (m *mockHandler) HandleRPC(ctx context.Context, stats stats.RPCStats) {
|
|
m.handleRPC(ctx, stats)
|
|
}
|
|
|
|
// TagConn implements stats.Handler.
|
|
func (m *mockHandler) TagConn(ctx context.Context, info *stats.ConnTagInfo) context.Context {
|
|
return m.tagConn(ctx, info)
|
|
}
|
|
|
|
// TagRPC implements stats.Handler.
|
|
func (m *mockHandler) TagRPC(ctx context.Context, info *stats.RPCTagInfo) context.Context {
|
|
return m.tagRPC(ctx, info)
|
|
}
|
|
|
|
// Compile-time check that *mockHandler satisfies stats.Handler.
var (
	_ stats.Handler = (*mockHandler)(nil)
)
|
|
|
|
func TestStatsInterceptor(t *testing.T) {
|
|
var outBegin *stats.Begin
|
|
var outEnd *stats.End
|
|
base := &mockHandler{
|
|
handleRPC: func(_ context.Context, rs stats.RPCStats) {
|
|
switch rs := rs.(type) {
|
|
case *stats.Begin:
|
|
outBegin = rs
|
|
case *stats.End:
|
|
outEnd = rs
|
|
}
|
|
},
|
|
}
|
|
interceptor := func(_ context.Context, rs stats.RPCStats) stats.RPCStats {
|
|
switch rs := rs.(type) {
|
|
case *stats.Begin:
|
|
return &stats.Begin{
|
|
Client: rs.Client,
|
|
BeginTime: rs.BeginTime.Add(-1 * time.Minute),
|
|
FailFast: rs.FailFast,
|
|
IsClientStream: rs.IsClientStream,
|
|
IsServerStream: rs.IsServerStream,
|
|
IsTransparentRetryAttempt: rs.IsTransparentRetryAttempt,
|
|
}
|
|
case *stats.End:
|
|
return &stats.End{
|
|
Client: rs.Client,
|
|
BeginTime: rs.BeginTime,
|
|
EndTime: rs.EndTime,
|
|
Trailer: rs.Trailer,
|
|
Error: errors.New("modified"),
|
|
}
|
|
}
|
|
return rs
|
|
}
|
|
handler := trace.NewClientStatsHandler(
|
|
base,
|
|
trace.WithStatsInterceptor(interceptor),
|
|
)
|
|
inBegin := &stats.Begin{
|
|
Client: true,
|
|
BeginTime: time.Now(),
|
|
FailFast: true,
|
|
IsClientStream: true,
|
|
IsServerStream: false,
|
|
IsTransparentRetryAttempt: false,
|
|
}
|
|
handler.HandleRPC(context.Background(), inBegin)
|
|
assert.NotNil(t, outBegin)
|
|
assert.NotSame(t, inBegin, outBegin)
|
|
assert.Equal(t, inBegin.BeginTime.Add(-1*time.Minute), outBegin.BeginTime)
|
|
assert.Equal(t, inBegin.Client, outBegin.Client)
|
|
assert.Equal(t, inBegin.FailFast, outBegin.FailFast)
|
|
assert.Equal(t, inBegin.IsClientStream, outBegin.IsClientStream)
|
|
assert.Equal(t, inBegin.IsServerStream, outBegin.IsServerStream)
|
|
assert.Equal(t, inBegin.IsTransparentRetryAttempt, outBegin.IsTransparentRetryAttempt)
|
|
|
|
inEnd := &stats.End{
|
|
Client: true,
|
|
BeginTime: time.Now(),
|
|
EndTime: time.Now().Add(1 * time.Minute),
|
|
Trailer: metadata.Pairs("a", "b", "c", "d"),
|
|
Error: errors.New("input"),
|
|
}
|
|
handler.HandleRPC(context.Background(), inEnd)
|
|
assert.NotNil(t, outEnd)
|
|
assert.NotSame(t, inEnd, outEnd)
|
|
assert.Equal(t, inEnd.Client, outEnd.Client)
|
|
assert.Equal(t, inEnd.BeginTime, outEnd.BeginTime)
|
|
assert.Equal(t, inEnd.EndTime, outEnd.EndTime)
|
|
assert.Equal(t, inEnd.Trailer, outEnd.Trailer)
|
|
assert.Equal(t, "input", inEnd.Error.Error())
|
|
assert.Equal(t, "modified", outEnd.Error.Error())
|
|
}
|
|
|
|
func TestStatsInterceptor_Nil(t *testing.T) {
|
|
var outCtx context.Context
|
|
var outConnStats stats.ConnStats
|
|
var outRPCStats stats.RPCStats
|
|
var outConnTagInfo *stats.ConnTagInfo
|
|
var outRPCTagInfo *stats.RPCTagInfo
|
|
base := &mockHandler{
|
|
handleConn: func(ctx context.Context, stats stats.ConnStats) {
|
|
outCtx = ctx
|
|
outConnStats = stats
|
|
},
|
|
handleRPC: func(ctx context.Context, stats stats.RPCStats) {
|
|
outCtx = ctx
|
|
outRPCStats = stats
|
|
},
|
|
tagConn: func(ctx context.Context, info *stats.ConnTagInfo) context.Context {
|
|
outCtx = ctx
|
|
outConnTagInfo = info
|
|
return ctx
|
|
},
|
|
tagRPC: func(ctx context.Context, info *stats.RPCTagInfo) context.Context {
|
|
outCtx = ctx
|
|
outRPCTagInfo = info
|
|
return ctx
|
|
},
|
|
}
|
|
handler := trace.NewClientStatsHandler(
|
|
base,
|
|
trace.WithStatsInterceptor(nil),
|
|
)
|
|
|
|
inCtx := context.Background()
|
|
inConnStats := &stats.ConnBegin{}
|
|
inRPCStats := &stats.Begin{}
|
|
inConnTagInfo := &stats.ConnTagInfo{}
|
|
inRPCTagInfo := &stats.RPCTagInfo{}
|
|
|
|
handler.HandleConn(inCtx, inConnStats)
|
|
assert.Equal(t, inCtx, outCtx)
|
|
assert.Same(t, inConnStats, outConnStats)
|
|
|
|
handler.HandleRPC(inCtx, inRPCStats)
|
|
assert.Equal(t, inCtx, outCtx)
|
|
assert.Same(t, inRPCStats, outRPCStats)
|
|
|
|
handler.TagConn(inCtx, inConnTagInfo)
|
|
assert.Equal(t, inCtx, outCtx)
|
|
assert.Same(t, inConnTagInfo, outConnTagInfo)
|
|
|
|
handler.TagRPC(inCtx, inRPCTagInfo)
|
|
assert.Equal(t, inCtx, outCtx)
|
|
assert.Same(t, inRPCTagInfo, outRPCTagInfo)
|
|
}
|
|
|
|
func TestStatsInterceptor_Bug(t *testing.T) {
|
|
handler := trace.NewClientStatsHandler(
|
|
&mockHandler{
|
|
handleRPC: func(_ context.Context, _ stats.RPCStats) {
|
|
t.Error("should not be reached")
|
|
},
|
|
},
|
|
trace.WithStatsInterceptor(func(_ context.Context, rs stats.RPCStats) stats.RPCStats {
|
|
_ = rs.(*stats.Begin)
|
|
return &stats.End{}
|
|
}),
|
|
)
|
|
assert.PanicsWithValue(t, "bug: stats interceptor returned a message of a different type", func() {
|
|
handler.HandleRPC(context.Background(), &stats.Begin{})
|
|
})
|
|
}
|