package upstreams

import (
	"context"
	"crypto/tls"
	"errors"
	"fmt"
	"io"
	"net"
	"net/http"
	"net/http/cookiejar"
	"net/http/httptrace"
	"net/url"
	"os"
	"sync"
	"time"

	"github.com/gorilla/mux"
	"github.com/gorilla/websocket"
	"github.com/pomerium/pomerium/internal/testenv"
	"github.com/pomerium/pomerium/internal/testenv/snippets"
	"github.com/pomerium/pomerium/internal/testenv/values"
	"github.com/pomerium/pomerium/pkg/telemetry/trace"
	"go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp"

	"go.opentelemetry.io/otel/attribute"
	"go.opentelemetry.io/otel/codes"
	oteltrace "go.opentelemetry.io/otel/trace"
)

// Protocol names the HTTP protocol a request should be dialed with. It is the
// value type accepted by the DialProtocol request option.
type Protocol string

// Dial protocol identifiers for HTTP/1.1, HTTP/2 and HTTP/3 respectively.
const (
	DialHTTP1 = Protocol("http/1.1")
	DialHTTP2 = Protocol("h2")
	DialHTTP3 = Protocol("h3")
)

// RequestOptions collects the per-request settings assembled from the
// RequestOption values passed to Get, Post, Do and DialWS.
type RequestOptions struct {
	// requestCtx is the context the request runs under. Do and DialWS seed it
	// with the environment context and later replace it with the context of
	// the client span they start. It can be overridden with the Context option.
	requestCtx     context.Context
	// path is resolved against the route URL (see Path).
	path           string
	// query is encoded into the request URL (see Query).
	query          url.Values
	// headers holds the header map supplied via Headers.
	headers        map[string]string
	// authenticateAs holds the email supplied via AuthenticateAs.
	authenticateAs string
	// body holds the value supplied via Body.
	body           any
	// clientCerts accumulates the certificates added via ClientCert; newClient
	// installs them as the TLS client certificates.
	clientCerts    []tls.Certificate
	// clientHook, if set, is called by getRouteClient with the route's client.
	clientHook     func(*http.Client) *http.Client
	// dialerHook, if set, is called by DialWS with the dialer and target URL.
	dialerHook     func(*websocket.Dialer, *url.URL) (*websocket.Dialer, *url.URL)
	// dialProtocol holds the protocol supplied via DialProtocol.
	dialProtocol   Protocol
	// trace, if set, is attached to the request context by Do.
	trace          *httptrace.ClientTrace
}

// RequestOption is a functional option that mutates a RequestOptions value.
type RequestOption func(*RequestOptions)

// Format implements fmt.Formatter and always panics. A RequestOption is never
// meant to be formatted; the panic flags the case where one was mistakenly
// passed to an assert function instead of a request method.
func (RequestOption) Format(fmt.State, rune) {
	const msg = "test bug: request option mistakenly passed to assert function"
	panic(msg)
}

// apply runs every option in opts against o, in the order given.
func (o *RequestOptions) apply(opts ...RequestOption) {
	for i := range opts {
		opts[i](o)
	}
}

// Path sets the path of the request. When no path is given, the request URL
// matches the route URL exactly.
func Path(path string) RequestOption {
	return func(ro *RequestOptions) { ro.path = path }
}

// Query sets the optional query parameters of the request.
func Query(query url.Values) RequestOption {
	return func(ro *RequestOptions) { ro.query = query }
}

// Headers sets the optional headers of the request. The given map is stored
// as-is and replaces any map supplied by an earlier Headers option.
func Headers(headers map[string]string) RequestOption {
	return func(ro *RequestOptions) { ro.headers = headers }
}

// AuthenticateAs records the email the request should be authenticated as.
// The authentication itself is performed by the request helper, outside this
// function.
func AuthenticateAs(email string) RequestOption {
	return func(ro *RequestOptions) { ro.authenticateAs = email }
}

// ClientHook registers a function that can edit or replace the http client
// right before it is used. Whenever a request is about to start, the hook is
// called with the client that would be used for it, and whatever the hook
// returns becomes the client actually used for that request. The hook may
// return its input (modified or not) or an entirely different client.
//
// Note: the Transport of the client passed to the hook will always be a
// [*Transport]. That transport's underlying transport will always be
// a [*otelhttp.Transport].
func ClientHook(f func(*http.Client) *http.Client) RequestOption {
	return func(ro *RequestOptions) { ro.clientHook = f }
}

// DialerHook registers a function that can edit or replace the websocket
// dialer right before it is used. When a websocket request is about to start
// (via the DialWS method), the hook is called with the dialer that would be
// used and the destination URL (including the wss:// scheme, and the path if
// one is present). The dialer and URL it returns are the ones actually used
// for that request.
//
// If ClientHook is also set, both will be called. The dialer passed to this
// hook will have its TLSClientConfig and Jar fields set from the client.
func DialerHook(f func(*websocket.Dialer, *url.URL) (*websocket.Dialer, *url.URL)) RequestOption {
	return func(ro *RequestOptions) { ro.dialerHook = f }
}

// DialProtocol records the protocol (one of the Dial* constants) the request
// should be dialed with. How the value is consumed is up to the request
// helper, outside this function.
func DialProtocol(protocol Protocol) RequestOption {
	return func(ro *RequestOptions) { ro.dialProtocol = protocol }
}

// Context overrides the context the request runs under. Without it, Do and
// DialWS use the environment's context.
func Context(ctx context.Context) RequestOption {
	return func(ro *RequestOptions) { ro.requestCtx = ctx }
}

// WithClientTrace registers httptrace hooks for the request. Do attaches the
// given ClientTrace to the request context before the request is made.
func WithClientTrace(ct *httptrace.ClientTrace) RequestOption {
	return func(ro *RequestOptions) { ro.trace = ct }
}

// Body sets the body of the request. The argument may be any of:
//   - string
//   - []byte
//   - io.Reader
//   - proto.Message
//   - any json-encodable type
//
// If the argument is encoded as json, the Content-Type header will be set to
// "application/json". If the argument is a proto.Message, the Content-Type
// header will be set to "application/octet-stream". The value is only stored
// here; encoding happens when the request is built.
func Body(body any) RequestOption {
	return func(ro *RequestOptions) { ro.body = body }
}

// ClientCert adds a client certificate to the request. It accepts either a
// *testenv.Certificate or a *tls.Certificate; the pointed-to value is copied
// into the options when the option is applied.
func ClientCert[T interface {
	*testenv.Certificate | *tls.Certificate
}](cert T) RequestOption {
	return func(ro *RequestOptions) {
		asTLS := (*tls.Certificate)(cert)
		ro.clientCerts = append(ro.clientCerts, *asTLS)
	}
}

// HTTPUpstreamOptions holds the configuration of an HTTP upstream. It
// currently consists only of the embedded CommonUpstreamOptions.
type HTTPUpstreamOptions struct {
	CommonUpstreamOptions
}

// HTTPUpstreamOption is an option accepted by HTTP. Implementations modify
// the given HTTPUpstreamOptions in applyHTTP.
type HTTPUpstreamOption interface {
	applyHTTP(*HTTPUpstreamOptions)
}

// HTTPUpstream represents a HTTP server which can be used as the target for
// one or more Pomerium routes in a test environment.
//
// The Handle() method can be used to add handlers to the server-side HTTP
// router, while the Get(), Post(), and (generic) Do() methods can be used to
// make client-side requests.
type HTTPUpstream interface {
	testenv.Upstream

	// Handle registers f as the server-side handler for path and returns the
	// resulting router route.
	Handle(path string, f func(http.ResponseWriter, *http.Request)) *mux.Route
	// HandleWS registers a server-side websocket handler for path. Incoming
	// requests are upgraded with upgrader and the connection is passed to f.
	HandleWS(path string, upgrader websocket.Upgrader, f func(conn *websocket.Conn) error) *mux.Route

	// Get makes a client-side GET request to the route r.
	Get(r testenv.Route, opts ...RequestOption) (*http.Response, error)
	// Post makes a client-side POST request to the route r.
	Post(r testenv.Route, opts ...RequestOption) (*http.Response, error)
	// Do makes a client-side request to the route r using the given method.
	Do(method string, r testenv.Route, opts ...RequestOption) (*http.Response, error)
	// DialWS opens a client-side websocket connection to the route r and
	// calls f with it.
	DialWS(r testenv.Route, f func(conn *websocket.Conn) error, opts ...RequestOption) error
}

// httpUpstream is the HTTPUpstream implementation returned by HTTP.
type httpUpstream struct {
	HTTPUpstreamOptions
	testenv.Aggregate
	// serverPort is resolved by Run with the port the listener was bound to.
	serverPort values.MutableValue[int]
	// tlsConfig, when non-nil, makes Run listen with TLS using its value.
	tlsConfig  values.Value[*tls.Config]

	clientCache sync.Map // map[testenv.Route]*http.Client

	// router is the server-side handler; Handle and HandleWS add routes to it.
	router               *mux.Router
	// serverTracerProvider and clientTracerProvider are resolved by Run.
	serverTracerProvider values.MutableValue[oteltrace.TracerProvider]
	clientTracerProvider values.MutableValue[oteltrace.TracerProvider]
	// clientTracer is derived from clientTracerProvider (bound in HTTP) and is
	// used to start the client-side spans in Do and DialWS.
	clientTracer         values.Value[oteltrace.Tracer]
}

// Compile-time checks that *httpUpstream satisfies the interfaces it is
// returned as.
var (
	_ testenv.Upstream = (*httpUpstream)(nil)
	_ HTTPUpstream     = (*httpUpstream)(nil)
)

// HTTP creates a new HTTP upstream server.
//
// tlsConfig may be nil, in which case the server listens on plain TCP when it
// is run. The display name defaults to "HTTP Upstream" and can be changed by
// the given options. The server port and tracer providers are deferred values
// that are resolved once Run is called.
func HTTP(tlsConfig values.Value[*tls.Config], opts ...HTTPUpstreamOption) HTTPUpstream {
	var cfg HTTPUpstreamOptions
	cfg.CommonUpstreamOptions.displayName = "HTTP Upstream"
	for i := range opts {
		opts[i].applyHTTP(&cfg)
	}

	up := &httpUpstream{
		HTTPUpstreamOptions:  cfg,
		tlsConfig:            tlsConfig,
		router:               mux.NewRouter(),
		serverPort:           values.Deferred[int](),
		serverTracerProvider: values.Deferred[oteltrace.TracerProvider](),
		clientTracerProvider: values.Deferred[oteltrace.TracerProvider](),
	}

	// The client tracer becomes available as soon as the client tracer
	// provider has been resolved.
	tracerFrom := func(tp oteltrace.TracerProvider) oteltrace.Tracer {
		return tp.Tracer(trace.PomeriumCoreTracer)
	}
	up.clientTracer = values.Bind(up.clientTracerProvider, tracerFrom)

	up.RecordCaller()
	return up
}

// Addr implements HTTPUpstream. It returns the server's "host:port" address
// as a value that becomes available once the server port has been resolved.
func (h *httpUpstream) Addr() values.Value[string] {
	hostPort := func(port int) string {
		return fmt.Sprintf("%s:%d", h.Env().Host(), port)
	}
	return values.Bind(h.serverPort, hostPort)
}

// Handle implements HTTPUpstream. It registers f on the server-side router as
// the handler for path and returns the new route.
func (h *httpUpstream) Handle(path string, f func(http.ResponseWriter, *http.Request)) *mux.Route {
	return h.router.Handle(path, http.HandlerFunc(f))
}

// HandleWS implements HTTPUpstream. It registers a websocket handler for path:
// each request is upgraded with upgrader and the resulting connection is
// handed to f. The connection is closed when f returns.
//
// An upgrade failure is recorded on the "HandleWS" span and answered with a
// 400 status and the error text. An error returned by f is recorded on the
// span and printed to stderr, except for io.EOF which is treated as a normal
// end of the connection.
func (h *httpUpstream) HandleWS(path string, upgrader websocket.Upgrader, f func(*websocket.Conn) error) *mux.Route {
	serveWS := func(w http.ResponseWriter, r *http.Request) {
		ctx, span := trace.Continue(r.Context(), "HandleWS")
		defer span.End()

		conn, upgradeErr := upgrader.Upgrade(w, r.WithContext(ctx), nil)
		if upgradeErr != nil {
			span.SetStatus(codes.Error, upgradeErr.Error())
			w.WriteHeader(http.StatusBadRequest)
			_, _ = w.Write([]byte(upgradeErr.Error()))
			return
		}
		defer conn.Close()

		if handlerErr := f(conn); handlerErr != nil && !errors.Is(handlerErr, io.EOF) {
			span.SetStatus(codes.Error, handlerErr.Error())
			fmt.Fprintf(os.Stderr, "websocket error: %s\n", handlerErr.Error())
		}
	}
	return h.router.HandleFunc(path, serveWS)
}

// Route implements HTTPUpstream.
//
// It creates a new policy route whose target URL points at this server,
// registers the route with the upstream via Add, and returns it. The port in
// the target URL is filled in once the server port has been resolved by Run.
//
// The URL scheme follows the listener that Run creates: "https" when the
// upstream was constructed with a TLS config, "http" otherwise.
func (h *httpUpstream) Route() testenv.RouteStub {
	scheme := "http"
	if h.tlsConfig != nil {
		// Run wraps the listener in TLS in this case; a plain http:// target
		// would not be able to talk to it.
		scheme = "https"
	}
	r := &testenv.PolicyRoute{}
	r.To(values.Bind(h.serverPort, func(port int) string {
		return fmt.Sprintf("%s://%s:%d", scheme, h.Env().Host(), port)
	}))
	h.Add(r)
	return r
}

// Run implements HTTPUpstream.
//
// It binds a listener on an ephemeral port of the environment host (TLS when
// a TLS config was supplied, plain TCP otherwise), resolves the deferred
// server port and tracer providers, installs the tracing middleware on the
// router, and then serves requests. It returns when the server stops with an
// error, or, once ctx is done, after shutting the server down, in which case
// the cause of the context cancellation is returned.
func (h *httpUpstream) Run(ctx context.Context) error {
	var (
		listener net.Listener
		err      error
	)
	addr := fmt.Sprintf("%s:0", h.Env().Host())
	if h.tlsConfig == nil {
		listener, err = net.Listen("tcp", addr)
	} else {
		listener, err = tls.Listen("tcp", addr, h.tlsConfig.Value())
	}
	if err != nil {
		return err
	}
	h.serverPort.Resolve(listener.Addr().(*net.TCPAddr).Port)

	// Tracer providers: an override from the options wins, otherwise a fresh
	// provider is created for the server and client sides respectively.
	if h.serverTracerProviderOverride == nil {
		h.serverTracerProvider.Resolve(trace.NewTracerProvider(ctx, h.displayName))
	} else {
		h.serverTracerProvider.Resolve(h.serverTracerProviderOverride)
	}
	if h.clientTracerProviderOverride == nil {
		h.clientTracerProvider.Resolve(trace.NewTracerProvider(ctx, "HTTP Client"))
	} else {
		h.clientTracerProvider.Resolve(h.clientTracerProviderOverride)
	}
	h.router.Use(trace.NewHTTPMiddleware(otelhttp.WithTracerProvider(h.serverTracerProvider.Value())))

	server := &http.Server{
		Handler: h.router,
		// BaseContext: func(net.Listener) context.Context {
		// 	return ctx
		// },
	}
	serve := func() error { return server.Serve(listener) }
	stop := func() { _ = server.Shutdown(context.Background()) }

	// With delayShutdown set, the serve/stop lifecycle is delegated to the
	// delayed-shutdown helper.
	if h.delayShutdown {
		return snippets.RunWithDelayedShutdown(ctx, serve, stop)()
	}

	// Buffered so the serving goroutine can always deliver its result and
	// exit, even if nobody is receiving any more.
	errC := make(chan error, 1)
	go func() { errC <- serve() }()

	select {
	case serveErr := <-errC:
		return serveErr
	case <-ctx.Done():
		stop()
		return context.Cause(ctx)
	}
}

// Get implements HTTPUpstream. It is shorthand for Do with the GET method.
func (h *httpUpstream) Get(r testenv.Route, opts ...RequestOption) (*http.Response, error) {
	const method = http.MethodGet
	return h.Do(method, r, opts...)
}

// Post implements HTTPUpstream. It is shorthand for Do with the POST method.
func (h *httpUpstream) Post(r testenv.Route, opts ...RequestOption) (*http.Response, error) {
	const method = http.MethodPost
	return h.Do(method, r, opts...)
}

// Transport is the http.RoundTripper installed on clients created by this
// upstream. Its RoundTrip method comes from the embedded otelhttp transport;
// Base additionally exposes the plain *http.Transport underneath so that its
// settings (for example the TLS client config) can be inspected or changed.
type Transport struct {
	*otelhttp.Transport
	// The underlying http.Transport instance wrapped by the otelhttp.Transport.
	Base *http.Transport
}

// Compile-time check that Transport satisfies http.RoundTripper.
var _ http.RoundTripper = Transport{}

// newClient builds a fresh http client for this upstream. The client trusts
// the environment's server CAs, presents the client certificates from
// options, records client spans through the upstream's client tracer
// provider, and has its own empty cookie jar.
func (h *httpUpstream) newClient(options *RequestOptions) *http.Client {
	base := http.DefaultTransport.(*http.Transport).Clone()
	// NOTE(review): presumably cleared so that TLSClientConfig below is always
	// honored; net/http ignores TLSClientConfig when DialTLSContext is set.
	base.DialTLSContext = nil
	base.TLSClientConfig = &tls.Config{
		RootCAs:      h.Env().ServerCAs(),
		Certificates: options.clientCerts,
	}

	spanName := func(_ string, r *http.Request) string {
		return fmt.Sprintf("Client: %s %s", r.Method, r.URL.Path)
	}
	instrumented := otelhttp.NewTransport(base,
		otelhttp.WithTracerProvider(h.clientTracerProvider.Value()),
		otelhttp.WithSpanNameFormatter(spanName),
	)

	jar, _ := cookiejar.New(&cookiejar.Options{})
	return &http.Client{
		Transport: &Transport{
			Transport: instrumented,
			Base:      base,
		},
		Jar: jar,
	}
}

// getRouteClient returns the http client to use for requests to route r.
//
// Clients are cached per route: the first request to a route creates one from
// the given options, later requests reuse it. Because of that, the client
// certificates in options only take effect when the client is first created.
// Whether the client was created or reused is recorded as an event on the
// span found in options.requestCtx. If a client hook is set, it is called
// with the cached client and its result is returned instead.
func (h *httpUpstream) getRouteClient(r testenv.Route, options *RequestOptions) *http.Client {
	span := oteltrace.SpanFromContext(options.requestCtx)

	entry, found := h.clientCache.Load(r)
	if found {
		span.AddEvent("using cached http client")
	} else {
		span.AddEvent("creating new http client")
		// LoadOrStore keeps whichever client was stored first if another
		// goroutine raced us here.
		entry, _ = h.clientCache.LoadOrStore(r, h.newClient(options))
	}

	client := entry.(*http.Client)
	if hook := options.clientHook; hook != nil {
		client = hook(client)
	}
	return client
}

// Do implements HTTPUpstream.
//
// It sends a request with the given method to route r. The target URL is the
// route URL, with the path and query options (if any) resolved against it.
// The request runs inside an "httpUpstream.Do" client span and, if a client
// trace was supplied, with that trace attached to its context. Building the
// request and the authentication flow are delegated to
// doAuthenticatedRequest.
func (h *httpUpstream) Do(method string, r testenv.Route, opts ...RequestOption) (*http.Response, error) {
	options := RequestOptions{requestCtx: h.Env().Context()}
	options.apply(opts...)

	target, err := url.Parse(r.URL().Value())
	if err != nil {
		return nil, err
	}
	if options.path != "" || options.query != nil {
		ref := url.URL{
			Path:     options.path,
			RawQuery: options.query.Encode(),
		}
		target = target.ResolveReference(&ref)
	}
	rawURL := target.String()

	ctx, span := h.clientTracer.Value().Start(options.requestCtx, "httpUpstream.Do", oteltrace.WithAttributes(
		attribute.String("method", method),
		attribute.String("url", rawURL),
	))
	defer span.End()
	if options.trace != nil {
		ctx = httptrace.WithClientTrace(ctx, options.trace)
	}
	// From here on the span context is the request context, so that client
	// lookup and the request itself are attributed to the span.
	options.requestCtx = ctx

	newRequest := func(ctx context.Context) (*http.Request, error) {
		return http.NewRequestWithContext(ctx, method, rawURL, nil)
	}
	routeClient := func(context.Context) *http.Client {
		return h.getRouteClient(r, &options)
	}
	return doAuthenticatedRequest(options.requestCtx, newRequest, routeClient, &options)
}

// DialWS implements HTTPUpstream.
//
// It opens a websocket connection to route r and calls f with it; the
// connection is closed when f returns. The target URL is the route URL with
// its scheme forced to "wss" and the path and query options (if any) resolved
// against it. The dialer reuses the TLS client config and cookie jar of the
// route's http client, and can be edited or replaced with the DialerHook
// option.
//
// It returns an error if the route URL cannot be parsed or if dialing fails
// (wrapped with a "DialContext: " prefix); otherwise it returns whatever f
// returns.
func (h *httpUpstream) DialWS(r testenv.Route, f func(conn *websocket.Conn) error, opts ...RequestOption) error {
	options := RequestOptions{
		requestCtx: h.Env().Context(),
	}
	options.apply(opts...)
	u, err := url.Parse(r.URL().Value())
	if err != nil {
		return err
	}
	u.Scheme = "wss"
	if options.path != "" || options.query != nil {
		u = u.ResolveReference(&url.URL{
			Path:     options.path,
			RawQuery: options.query.Encode(),
		})
	}
	ctx, span := h.clientTracer.Value().Start(options.requestCtx, "httpUpstream.Dial", oteltrace.WithAttributes(
		attribute.String("url", u.String()),
	))
	options.requestCtx = ctx
	defer span.End()

	client := h.getRouteClient(r, &options)
	// Note: the type assertion below panics if a ClientHook returned a client
	// whose Transport is not a *Transport.
	d := &websocket.Dialer{
		HandshakeTimeout: 10 * time.Second,
		TLSClientConfig:  client.Transport.(*Transport).Base.TLSClientConfig,
		Jar:              client.Jar,
	}
	if options.dialerHook != nil {
		d, u = options.dialerHook(d, u)
	}
	conn, resp, err := d.DialContext(options.requestCtx, u.String(), nil)
	if err != nil {
		// The response is only non-nil when the server answered the
		// handshake; for failures before that (connection refused, TLS
		// errors, canceled context, ...) it is nil and must not be touched.
		if resp != nil && resp.Body != nil {
			resp.Body.Close()
		}
		return fmt.Errorf("DialContext: %w", err)
	}
	defer conn.Close()

	return f(conn)
}