package upstreams

import (
	"bytes"
	"context"
	"crypto/tls"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"net"
	"net/http"
	"net/http/cookiejar"
	"net/http/httptrace"
	"net/url"
	"strconv"
	"strings"
	"sync"
	"time"

	"github.com/gorilla/mux"
	"github.com/pomerium/pomerium/integration/forms"
	"github.com/pomerium/pomerium/internal/retry"
	"github.com/pomerium/pomerium/internal/telemetry/trace"
	"github.com/pomerium/pomerium/internal/testenv"
	"github.com/pomerium/pomerium/internal/testenv/snippets"
	"github.com/pomerium/pomerium/internal/testenv/values"
	"go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp"

	"go.opentelemetry.io/otel/attribute"
	"go.opentelemetry.io/otel/codes"
	semconv "go.opentelemetry.io/otel/semconv/v1.26.0"
	oteltrace "go.opentelemetry.io/otel/trace"
	"google.golang.org/protobuf/proto"
)

// RequestOptions collects the per-request settings that RequestOption
// functions configure for a single Get, Post, or Do call.
type RequestOptions struct {
	// requestCtx is the parent context of the request. Do initializes it to
	// the environment context before options are applied.
	requestCtx     context.Context
	// path, when non-empty, is resolved against the route URL.
	path           string
	// query, when non-nil, is encoded and used as the request's raw query.
	query          url.Values
	// headers holds the headers supplied through the Headers option.
	headers        map[string]string
	// authenticateAs, when non-empty, makes Do run the login flow with this
	// email instead of issuing a plain request.
	authenticateAs string
	// body is the request payload; see Body for the accepted types.
	body           any
	// clientCerts are the TLS client certificates used when Do builds a new
	// HTTP client.
	clientCerts    []tls.Certificate
	// client, when non-nil, is used instead of the per-route cached client.
	client         *http.Client
	// trace, when non-nil, is attached to the request context with
	// httptrace.WithClientTrace.
	trace          *httptrace.ClientTrace
}

// RequestOption configures a RequestOptions value. Options are passed to
// Get, Post, and Do to customize an individual request.
type RequestOption func(*RequestOptions)

// apply invokes every option against o, in the order the options were given.
func (o *RequestOptions) apply(opts ...RequestOption) {
	for i := range opts {
		opts[i](o)
	}
}

// Path sets the request path, which is resolved against the route URL. When
// neither Path nor Query is given, the request goes to the route URL as-is.
func Path(path string) RequestOption {
	return func(ro *RequestOptions) {
		ro.path = path
	}
}

// Query sets the query parameters of the request. The values are encoded and
// become the raw query of the request URL.
func Query(query url.Values) RequestOption {
	return func(ro *RequestOptions) {
		ro.query = query
	}
}

// Headers adds optional headers to the request.
//
// Entries are merged into the headers set by any earlier Headers option; for
// a repeated key the later value wins. The supplied map is copied rather than
// retained, so the caller may keep modifying it afterwards.
func Headers(headers map[string]string) RequestOption {
	return func(o *RequestOptions) {
		if len(headers) == 0 {
			return
		}
		if o.headers == nil {
			o.headers = make(map[string]string, len(headers))
		}
		for name, value := range headers {
			o.headers[name] = value
		}
	}
}

// AuthenticateAs makes the request go through the login flow, submitting the
// login form with the given email address. An empty email leaves the request
// unauthenticated.
func AuthenticateAs(email string) RequestOption {
	return func(ro *RequestOptions) {
		ro.authenticateAs = email
	}
}

// Client makes the request use c instead of the HTTP client that is cached
// per route.
func Client(c *http.Client) RequestOption {
	return func(ro *RequestOptions) {
		ro.client = c
	}
}

// Context sets the parent context of the request, replacing the environment
// context that is used by default.
func Context(ctx context.Context) RequestOption {
	return func(ro *RequestOptions) {
		ro.requestCtx = ctx
	}
}

// WithClientTrace attaches ct to the request context so that its hooks are
// invoked while the request is performed.
func WithClientTrace(ct *httptrace.ClientTrace) RequestOption {
	return func(ro *RequestOptions) {
		ro.trace = ct
	}
}

// Body sets the body of the request.
//
// Accepted argument types:
//   - string
//   - []byte
//   - io.Reader
//   - proto.Message (sent with Content-Type "application/octet-stream")
//   - any other json-encodable value (sent with Content-Type
//     "application/json")
//
// A nil argument results in a request without a body.
func Body(body any) RequestOption {
	return func(ro *RequestOptions) {
		ro.body = body
	}
}

// ClientCert appends a client certificate to the request options. Both
// *testenv.Certificate and *tls.Certificate are accepted; the pointer is
// converted to *tls.Certificate and dereferenced when the option is applied.
func ClientCert[T interface {
	*testenv.Certificate | *tls.Certificate
}](cert T) RequestOption {
	return func(ro *RequestOptions) {
		tlsCert := (*tls.Certificate)(cert)
		ro.clientCerts = append(ro.clientCerts, *tlsCert)
	}
}

// HTTPUpstreamOptions holds the configuration of an HTTP upstream. It embeds
// CommonUpstreamOptions, from which the upstream reads its display name,
// delayed-shutdown flag, and tracer provider overrides.
type HTTPUpstreamOptions struct {
	CommonUpstreamOptions
}

// HTTPUpstreamOption is an option that can be passed to HTTP. Each option
// modifies the HTTPUpstreamOptions it is applied to.
type HTTPUpstreamOption interface {
	applyHTTP(*HTTPUpstreamOptions)
}

// HTTPUpstream represents a HTTP server which can be used as the target for
// one or more Pomerium routes in a test environment.
//
// The Handle() method can be used to add handlers to the server-side HTTP
// router, while the Get(), Post(), and (generic) Do() methods can be used to
// make client-side requests.
type HTTPUpstream interface {
	testenv.Upstream

	// Handle registers f as the server-side handler for path and returns the
	// created route so that it can be refined further.
	Handle(path string, f func(http.ResponseWriter, *http.Request)) *mux.Route

	// Get sends a GET request to the given route.
	Get(r testenv.Route, opts ...RequestOption) (*http.Response, error)
	// Post sends a POST request to the given route.
	Post(r testenv.Route, opts ...RequestOption) (*http.Response, error)
	// Do sends a request with an arbitrary method to the given route.
	Do(method string, r testenv.Route, opts ...RequestOption) (*http.Response, error)
}

// httpUpstream is the HTTPUpstream implementation returned by HTTP.
type httpUpstream struct {
	HTTPUpstreamOptions
	testenv.Aggregate
	// serverPort is resolved by Run with the port of the listener it opens.
	serverPort values.MutableValue[int]
	// tlsConfig is the optional server TLS configuration passed to HTTP; it
	// may be nil.
	tlsConfig  values.Value[*tls.Config]

	// clientCache holds the HTTP client that Do creates for each route, so
	// later requests to the same route reuse it.
	clientCache sync.Map // map[testenv.Route]*http.Client

	// router dispatches incoming requests to the handlers added via Handle.
	router               *mux.Router
	// serverTracerProvider and clientTracerProvider are resolved by Run,
	// either from the configured overrides or with newly created providers.
	serverTracerProvider values.MutableValue[oteltrace.TracerProvider]
	clientTracerProvider values.MutableValue[oteltrace.TracerProvider]
	// clientTracer is derived from clientTracerProvider and is used by Do to
	// start the span that wraps each request.
	clientTracer         values.Value[oteltrace.Tracer]
}

// Compile-time checks that *httpUpstream satisfies the interfaces it is
// returned as.
var (
	_ testenv.Upstream = (*httpUpstream)(nil)
	_ HTTPUpstream     = (*httpUpstream)(nil)
)

// HTTP creates a new HTTP upstream server.
//
// tlsConfig may be nil. The server port and both tracer providers are left
// unresolved here; they receive their values once Run is called.
func HTTP(tlsConfig values.Value[*tls.Config], opts ...HTTPUpstreamOption) HTTPUpstream {
	cfg := HTTPUpstreamOptions{
		CommonUpstreamOptions: CommonUpstreamOptions{
			displayName: "HTTP Upstream",
		},
	}
	for i := range opts {
		opts[i].applyHTTP(&cfg)
	}

	upstream := &httpUpstream{
		HTTPUpstreamOptions:  cfg,
		tlsConfig:            tlsConfig,
		router:               mux.NewRouter(),
		serverPort:           values.Deferred[int](),
		serverTracerProvider: values.Deferred[oteltrace.TracerProvider](),
		clientTracerProvider: values.Deferred[oteltrace.TracerProvider](),
	}
	// The client tracer is derived from the client tracer provider.
	newTracer := func(tp oteltrace.TracerProvider) oteltrace.Tracer {
		return tp.Tracer(trace.PomeriumCoreTracer)
	}
	upstream.clientTracer = values.Bind(upstream.clientTracerProvider, newTracer)
	upstream.RecordCaller()
	return upstream
}

// Port implements HTTPUpstream. The returned value is resolved by Run with
// the port of the listener it opens.
func (h *httpUpstream) Port() values.Value[int] {
	return h.serverPort
}

// Handle implements HTTPUpstream. It registers f on the server-side router
// for the given path and returns the resulting route.
func (h *httpUpstream) Handle(path string, f func(http.ResponseWriter, *http.Request)) *mux.Route {
	return h.router.HandleFunc(path, f)
}

// Route implements HTTPUpstream. It creates a new policy route whose target
// is this server's loopback address (using the port resolved by Run), adds it
// to the upstream, and returns it.
func (h *httpUpstream) Route() testenv.RouteStub {
	const scheme = "http"
	target := values.Bind(h.serverPort, func(port int) string {
		return fmt.Sprintf("%s://127.0.0.1:%d", scheme, port)
	})

	route := &testenv.PolicyRoute{}
	route.To(target)
	h.Add(route)
	return route
}

// Run implements HTTPUpstream.
//
// It listens on an ephemeral loopback port, resolves the server port and the
// tracer providers, installs the tracing middleware on the router, and then
// serves until ctx is done or the server fails. When ctx is done the server
// is shut down and the cancellation cause is returned.
func (h *httpUpstream) Run(ctx context.Context) error {
	ln, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		return err
	}
	h.serverPort.Resolve(ln.Addr().(*net.TCPAddr).Port)

	var serverTLS *tls.Config
	if h.tlsConfig != nil {
		serverTLS = h.tlsConfig.Value()
	}

	// Prefer explicitly configured tracer providers; otherwise create new ones.
	if override := h.serverTracerProviderOverride; override != nil {
		h.serverTracerProvider.Resolve(override)
	} else {
		h.serverTracerProvider.Resolve(trace.NewTracerProvider(ctx, h.displayName))
	}
	if override := h.clientTracerProviderOverride; override != nil {
		h.clientTracerProvider.Resolve(override)
	} else {
		h.clientTracerProvider.Resolve(trace.NewTracerProvider(ctx, "HTTP Client"))
	}
	h.router.Use(trace.NewHTTPMiddleware(otelhttp.WithTracerProvider(h.serverTracerProvider.Value())))

	// NOTE(review): TLSConfig is set, but the listener is plain TCP and Serve
	// (not ServeTLS) is used, so this alone does not appear to enable TLS.
	// Confirm that this is intended.
	srv := &http.Server{
		Handler:   h.router,
		TLSConfig: serverTLS,
	}
	serve := func() error {
		return srv.Serve(ln)
	}
	shutdown := func() {
		_ = srv.Shutdown(context.Background())
	}

	if h.delayShutdown {
		return snippets.RunWithDelayedShutdown(ctx, serve, shutdown)()
	}

	// Buffered so the serving goroutine can always deliver its result and
	// exit, even after this function has already returned.
	serveErr := make(chan error, 1)
	go func() {
		serveErr <- serve()
	}()
	select {
	case err := <-serveErr:
		return err
	case <-ctx.Done():
		shutdown()
		return context.Cause(ctx)
	}
}

// Get implements HTTPUpstream. It is shorthand for Do with the GET method.
func (h *httpUpstream) Get(r testenv.Route, opts ...RequestOption) (*http.Response, error) {
	return h.Do(http.MethodGet, r, opts...)
}

// Post implements HTTPUpstream. It is shorthand for Do with the POST method.
func (h *httpUpstream) Post(r testenv.Route, opts ...RequestOption) (*http.Response, error) {
	return h.Do(http.MethodPost, r, opts...)
}

// Do implements HTTPUpstream.
//
// It sends a request with the given method to route r. The route URL is used
// as-is unless a Path or Query option is supplied, in which case they are
// resolved against it. Headers supplied through the Headers option are set on
// every attempt, after the body's Content-Type, so they can override it.
//
// Unless a Client option is given, one HTTP client (with its own cookie jar)
// is created per route on first use and reused afterwards. ClientCert options
// therefore only take effect on the call that creates that client.
//
// The request is retried when the connection is refused or the response has
// a 5xx status; any other error is reported to the retry helper as terminal.
// On success the caller owns the response body and must close it.
func (h *httpUpstream) Do(method string, r testenv.Route, opts ...RequestOption) (*http.Response, error) {
	options := RequestOptions{
		requestCtx: h.Env().Context(),
	}
	options.apply(opts...)
	u, err := url.Parse(r.URL().Value())
	if err != nil {
		return nil, err
	}
	if options.path != "" || options.query != nil {
		u = u.ResolveReference(&url.URL{
			Path:     options.path,
			RawQuery: options.query.Encode(),
		})
	}
	ctx, span := h.clientTracer.Value().Start(options.requestCtx, "httpUpstream.Do", oteltrace.WithAttributes(
		attribute.String("method", method),
		attribute.String("url", u.String()),
	))
	defer span.End()
	if options.trace != nil {
		ctx = httptrace.WithClientTrace(ctx, options.trace)
	}
	options.requestCtx = ctx

	// newClient builds a client that trusts the environment's server CAs and
	// presents the client certificates from this call's options.
	newClient := func() *http.Client {
		transport := http.DefaultTransport.(*http.Transport).Clone()
		transport.TLSClientConfig = &tls.Config{
			RootCAs:      h.Env().ServerCAs(),
			Certificates: options.clientCerts,
		}
		transport.DialTLSContext = nil
		c := http.Client{
			Transport: otelhttp.NewTransport(transport,
				otelhttp.WithTracerProvider(h.clientTracerProvider.Value()),
				otelhttp.WithSpanNameFormatter(func(_ string, r *http.Request) string {
					return fmt.Sprintf("Client: %s %s", r.Method, r.URL.Path)
				}),
			),
		}
		c.Jar, _ = cookiejar.New(&cookiejar.Options{})
		return &c
	}
	var client *http.Client
	if options.client != nil {
		client = options.client
	} else {
		cachedClient, ok := h.clientCache.Load(r)
		if !ok {
			span.AddEvent("creating new http client")
			// LoadOrStore keeps the first stored client if another goroutine
			// created one for this route in the meantime.
			cachedClient, _ = h.clientCache.LoadOrStore(r, newClient())
		} else {
			span.AddEvent("using cached http client")
		}
		client = cachedClient.(*http.Client)
	}

	var resp *http.Response
	resendCount := 0
	if err := retry.Retry(ctx, "http", func(ctx context.Context) error {
		// A fresh request is built for every attempt.
		req, err := http.NewRequestWithContext(ctx, method, u.String(), nil)
		if err != nil {
			return retry.NewTerminalError(err)
		}
		switch body := options.body.(type) {
		case nil:
			// no request body
		case string:
			req.Body = io.NopCloser(strings.NewReader(body))
		case []byte:
			req.Body = io.NopCloser(bytes.NewReader(body))
		case io.Reader:
			// NOTE: the same reader is reused on every attempt, so a retry
			// only sends whatever the previous attempt left unread.
			req.Body = io.NopCloser(body)
		case proto.Message:
			buf, err := proto.Marshal(body)
			if err != nil {
				return retry.NewTerminalError(err)
			}
			req.Body = io.NopCloser(bytes.NewReader(buf))
			req.Header.Set("Content-Type", "application/octet-stream")
		default:
			buf, err := json.Marshal(body)
			if err != nil {
				panic(fmt.Sprintf("unsupported body type: %T", body))
			}
			req.Body = io.NopCloser(bytes.NewReader(buf))
			req.Header.Set("Content-Type", "application/json")
		}

		// Apply caller-supplied headers after the body so that they can
		// override the Content-Type chosen above.
		for name, value := range options.headers {
			if strings.EqualFold(name, "Host") {
				// For client requests net/http takes the Host header from
				// req.Host, not from the Header map.
				req.Host = value
				continue
			}
			req.Header.Set(name, value)
		}

		if options.authenticateAs != "" {
			resp, err = authenticateFlow(ctx, client, req, options.authenticateAs) //nolint:bodyclose
		} else {
			resp, err = client.Do(req) //nolint:bodyclose
		}
		// retry on connection refused
		if err != nil {
			span.RecordError(err)
			var opErr *net.OpError
			if errors.As(err, &opErr) && opErr.Op == "dial" && opErr.Err.Error() == "connect: connection refused" {
				span.AddEvent("Retrying on dial error")
				return err
			}
			return retry.NewTerminalError(err)
		}
		if resp.StatusCode/100 == 5 {
			resendCount++
			// Drain and close the body so the connection can be reused by
			// the next attempt.
			_, _ = io.Copy(io.Discard, resp.Body)
			_ = resp.Body.Close()
			span.SetAttributes(semconv.HTTPRequestResendCount(resendCount))
			span.AddEvent("Retrying on 5xx error", oteltrace.WithAttributes(
				attribute.String("status", resp.Status),
			))
			return errors.New(http.StatusText(resp.StatusCode))
		}
		span.SetStatus(codes.Ok, "request completed successfully")
		return nil
	},
		retry.WithInitialInterval(1*time.Millisecond),
		retry.WithMaxInterval(100*time.Millisecond),
	); err != nil {
		return nil, err
	}
	return resp, nil
}

// authenticateFlow performs req and, if the client ends up on a different
// host than the one originally requested, treats the response as a login
// page: it parses the first form, fills in the email and a 24-hour token
// expiration, and submits it. The response of the form submission (or of the
// original request, when it stayed on the original host) is returned with its
// body still open.
//
// An error is returned if a request fails or if the page on the other host
// contains no form.
func authenticateFlow(ctx context.Context, client *http.Client, req *http.Request, email string) (*http.Response, error) {
	const tokenLifetime = time.Hour * 24

	span := oteltrace.SpanFromContext(ctx)
	requestedHost := req.URL.Hostname()

	first, err := client.Do(req)
	if err != nil {
		span.RecordError(err)
		return nil, err
	}
	// The URL of the final request shows where redirects took the client.
	landedAt := first.Request.URL
	if landedAt.Hostname() == requestedHost {
		// already authenticated
		span.SetStatus(codes.Ok, "already authenticated")
		return first, nil
	}

	loginForms := forms.Parse(first.Body)
	// Drain whatever the parser left unread before closing the body.
	_, _ = io.Copy(io.Discard, first.Body)
	_ = first.Body.Close()
	if len(loginForms) == 0 {
		return nil, fmt.Errorf("test bug: expected IDP login form")
	}

	form := loginForms[0]
	form.Inputs["email"] = email
	form.Inputs["token_expiration"] = strconv.Itoa(int(tokenLifetime.Seconds()))
	span.AddEvent("submitting form", oteltrace.WithAttributes(attribute.String("location", landedAt.String())))

	submitReq, err := form.NewRequestWithContext(ctx, landedAt)
	if err != nil {
		span.RecordError(err)
		return nil, err
	}
	submitResp, err := client.Do(submitReq)
	if err != nil {
		span.RecordError(err)
		return nil, err
	}
	span.SetStatus(codes.Ok, "form submitted successfully")
	return submitResp, nil
}