package trace_test

import (
	"context"
	"errors"
	"net/http"
	"net/http/httptest"
	"testing"
	"time"

	"github.com/gorilla/mux"
	"github.com/pomerium/pomerium/internal/telemetry/trace"
	"github.com/stretchr/testify/assert"
	"go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp"
	sdktrace "go.opentelemetry.io/otel/sdk/trace"
	oteltrace "go.opentelemetry.io/otel/trace"
	"google.golang.org/grpc/metadata"
	"google.golang.org/grpc/stats"
)

func TestHTTPMiddleware(t *testing.T) {
	router := mux.NewRouter()
	tp := sdktrace.NewTracerProvider()
	router.Use(trace.NewHTTPMiddleware(
		otelhttp.WithTracerProvider(tp),
	))
	router.Path("/foo").HandlerFunc(func(_ http.ResponseWriter, r *http.Request) {
		span := oteltrace.SpanFromContext(r.Context())
		assert.Equal(t, "Server: GET /foo", span.(interface{ Name() string }).Name())
	}).Methods(http.MethodGet)
	w := httptest.NewRecorder()
	ctx, span := tp.Tracer("test").Start(context.Background(), "test")
	router.ServeHTTP(w, httptest.NewRequestWithContext(ctx, http.MethodGet, "/foo", nil))
	span.End()
}

// mockHandler is a stats.Handler whose methods delegate to optional function
// fields, so each test only stubs the callbacks it cares about. When a field
// is nil, HandleConn and HandleRPC do nothing and TagConn and TagRPC return
// ctx unchanged, rather than panicking on a nil function call.
type mockHandler struct {
	handleConn func(ctx context.Context, cs stats.ConnStats)
	handleRPC  func(ctx context.Context, rs stats.RPCStats)
	tagConn    func(ctx context.Context, info *stats.ConnTagInfo) context.Context
	tagRPC     func(ctx context.Context, info *stats.RPCTagInfo) context.Context
}

// HandleConn implements stats.Handler.
func (m *mockHandler) HandleConn(ctx context.Context, cs stats.ConnStats) {
	if m.handleConn != nil {
		m.handleConn(ctx, cs)
	}
}

// HandleRPC implements stats.Handler.
func (m *mockHandler) HandleRPC(ctx context.Context, rs stats.RPCStats) {
	if m.handleRPC != nil {
		m.handleRPC(ctx, rs)
	}
}

// TagConn implements stats.Handler. Without a tagConn stub it returns ctx as-is.
func (m *mockHandler) TagConn(ctx context.Context, info *stats.ConnTagInfo) context.Context {
	if m.tagConn == nil {
		return ctx
	}
	return m.tagConn(ctx, info)
}

// TagRPC implements stats.Handler. Without a tagRPC stub it returns ctx as-is.
func (m *mockHandler) TagRPC(ctx context.Context, info *stats.RPCTagInfo) context.Context {
	if m.tagRPC == nil {
		return ctx
	}
	return m.tagRPC(ctx, info)
}

var _ stats.Handler = (*mockHandler)(nil)

// TestStatsInterceptor verifies that an interceptor installed with
// trace.WithStatsInterceptor can replace the *stats.Begin and *stats.End
// messages delivered to the base handler. It also checks that the caller's
// original messages are left untouched.
func TestStatsInterceptor(t *testing.T) {
	var outBegin *stats.Begin
	var outEnd *stats.End
	base := &mockHandler{
		handleRPC: func(_ context.Context, rs stats.RPCStats) {
			switch rs := rs.(type) {
			case *stats.Begin:
				outBegin = rs
			case *stats.End:
				outEnd = rs
			}
		},
	}
	// Returns fresh copies: Begin is shifted one minute earlier and End gets
	// a replacement error. Any other message type passes through unchanged.
	interceptor := func(_ context.Context, rs stats.RPCStats) stats.RPCStats {
		switch rs := rs.(type) {
		case *stats.Begin:
			return &stats.Begin{
				Client:                    rs.Client,
				BeginTime:                 rs.BeginTime.Add(-1 * time.Minute),
				FailFast:                  rs.FailFast,
				IsClientStream:            rs.IsClientStream,
				IsServerStream:            rs.IsServerStream,
				IsTransparentRetryAttempt: rs.IsTransparentRetryAttempt,
			}
		case *stats.End:
			return &stats.End{
				Client:    rs.Client,
				BeginTime: rs.BeginTime,
				EndTime:   rs.EndTime,
				Trailer:   rs.Trailer,
				Error:     errors.New("modified"),
			}
		}
		return rs
	}
	handler := trace.NewClientStatsHandler(
		base,
		trace.WithStatsInterceptor(interceptor),
	)

	inBegin := &stats.Begin{
		Client:                    true,
		BeginTime:                 time.Now(),
		FailFast:                  true,
		IsClientStream:            true,
		IsServerStream:            false,
		IsTransparentRetryAttempt: false,
	}
	handler.HandleRPC(context.Background(), inBegin)
	// Stop if the base handler never saw a Begin; the field checks below would
	// otherwise dereference a nil pointer and panic.
	if !assert.NotNil(t, outBegin, "base handler did not receive a *stats.Begin") {
		t.FailNow()
	}
	assert.NotSame(t, inBegin, outBegin)
	wantBeginTime := inBegin.BeginTime.Add(-1 * time.Minute)
	assert.Truef(t, wantBeginTime.Equal(outBegin.BeginTime),
		"BeginTime = %v, want %v", outBegin.BeginTime, wantBeginTime)
	assert.Equal(t, inBegin.Client, outBegin.Client)
	assert.Equal(t, inBegin.FailFast, outBegin.FailFast)
	assert.Equal(t, inBegin.IsClientStream, outBegin.IsClientStream)
	assert.Equal(t, inBegin.IsServerStream, outBegin.IsServerStream)
	assert.Equal(t, inBegin.IsTransparentRetryAttempt, outBegin.IsTransparentRetryAttempt)

	inEnd := &stats.End{
		Client:    true,
		BeginTime: time.Now(),
		EndTime:   time.Now().Add(1 * time.Minute),
		Trailer:   metadata.Pairs("a", "b", "c", "d"),
		Error:     errors.New("input"),
	}
	handler.HandleRPC(context.Background(), inEnd)
	if !assert.NotNil(t, outEnd, "base handler did not receive a *stats.End") {
		t.FailNow()
	}
	assert.NotSame(t, inEnd, outEnd)
	assert.Equal(t, inEnd.Client, outEnd.Client)
	assert.Truef(t, inEnd.BeginTime.Equal(outEnd.BeginTime),
		"BeginTime = %v, want %v", outEnd.BeginTime, inEnd.BeginTime)
	assert.Truef(t, inEnd.EndTime.Equal(outEnd.EndTime),
		"EndTime = %v, want %v", outEnd.EndTime, inEnd.EndTime)
	assert.Equal(t, inEnd.Trailer, outEnd.Trailer)
	// EqualError fails cleanly (instead of panicking) if the error is nil.
	assert.EqualError(t, inEnd.Error, "input")
	assert.EqualError(t, outEnd.Error, "modified")
}

// TestStatsInterceptor_Nil verifies that with a nil stats interceptor the
// client stats handler forwards every stats.Handler call to the base handler.
// Each call must receive the caller's context and the exact stats/tag-info
// value, and the contexts returned by the base TagConn/TagRPC must reach the
// caller.
func TestStatsInterceptor_Nil(t *testing.T) {
	type callKey struct{}
	type tagKey struct{}

	var outCtx context.Context
	var outConnStats stats.ConnStats
	var outRPCStats stats.RPCStats
	var outConnTagInfo *stats.ConnTagInfo
	var outRPCTagInfo *stats.RPCTagInfo
	base := &mockHandler{
		handleConn: func(ctx context.Context, cs stats.ConnStats) {
			outCtx = ctx
			outConnStats = cs
		},
		handleRPC: func(ctx context.Context, rs stats.RPCStats) {
			outCtx = ctx
			outRPCStats = rs
		},
		tagConn: func(ctx context.Context, info *stats.ConnTagInfo) context.Context {
			outCtx = ctx
			outConnTagInfo = info
			return context.WithValue(ctx, tagKey{}, "TagConn")
		},
		tagRPC: func(ctx context.Context, info *stats.RPCTagInfo) context.Context {
			outCtx = ctx
			outRPCTagInfo = info
			return context.WithValue(ctx, tagKey{}, "TagRPC")
		},
	}
	handler := trace.NewClientStatsHandler(
		base,
		trace.WithStatsInterceptor(nil),
	)

	// Give each call its own context. Otherwise a call that never reaches the
	// base handler could pass by matching outCtx left over from an earlier call.
	newCtx := func(call string) context.Context {
		return context.WithValue(context.Background(), callKey{}, call)
	}

	connCtx := newCtx("HandleConn")
	inConnStats := &stats.ConnBegin{}
	handler.HandleConn(connCtx, inConnStats)
	assert.Equal(t, connCtx, outCtx)
	assert.Same(t, inConnStats, outConnStats)

	rpcCtx := newCtx("HandleRPC")
	inRPCStats := &stats.Begin{}
	handler.HandleRPC(rpcCtx, inRPCStats)
	assert.Equal(t, rpcCtx, outCtx)
	assert.Same(t, inRPCStats, outRPCStats)

	tagConnCtx := newCtx("TagConn")
	inConnTagInfo := &stats.ConnTagInfo{}
	gotCtx := handler.TagConn(tagConnCtx, inConnTagInfo)
	assert.Equal(t, tagConnCtx, outCtx)
	assert.Same(t, inConnTagInfo, outConnTagInfo)
	assert.Equal(t, "TagConn", gotCtx.Value(tagKey{}), "TagConn result was not propagated")

	tagRPCCtx := newCtx("TagRPC")
	inRPCTagInfo := &stats.RPCTagInfo{}
	gotCtx = handler.TagRPC(tagRPCCtx, inRPCTagInfo)
	assert.Equal(t, tagRPCCtx, outCtx)
	assert.Same(t, inRPCTagInfo, outRPCTagInfo)
	assert.Equal(t, "TagRPC", gotCtx.Value(tagKey{}), "TagRPC result was not propagated")
}

func TestStatsInterceptor_Bug(t *testing.T) {
	handler := trace.NewClientStatsHandler(
		&mockHandler{
			handleRPC: func(_ context.Context, _ stats.RPCStats) {
				t.Error("should not be reached")
			},
		},
		trace.WithStatsInterceptor(func(_ context.Context, rs stats.RPCStats) stats.RPCStats {
			_ = rs.(*stats.Begin)
			return &stats.End{}
		}),
	)
	assert.PanicsWithValue(t, "bug: stats interceptor returned a message of a different type", func() {
		handler.HandleRPC(context.Background(), &stats.Begin{})
	})
}