package upstreams

import (
	"bytes"
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"net"
	"net/http"
	"strconv"
	"strings"
	"sync"
	"time"

	"go.opentelemetry.io/otel/attribute"
	"go.opentelemetry.io/otel/codes"
	semconv "go.opentelemetry.io/otel/semconv/v1.26.0"
	oteltrace "go.opentelemetry.io/otel/trace"
	"google.golang.org/protobuf/proto"

	"github.com/pomerium/pomerium/integration/forms"
	"github.com/pomerium/pomerium/internal/retry"
)

// ErrRetry is an exported sentinel error value.
//
// NOTE(review): nothing in the visible part of this file references it, and
// its message is only the generic string "error" — confirm the intended use
// against callers elsewhere in the package.
var ErrRetry = errors.New("error")

// doAuthenticatedRequest sends the request produced by newRequest using the
// client returned by getClient, applying the body and headers described by
// options, and returns the resulting response.
//
// The work runs inside retry.Retry, and every attempt builds a fresh request:
//   - options.body is attached according to its dynamic type: string, []byte
//     and io.Reader values are used directly, a proto.Message is marshaled and
//     sent as application/octet-stream, and any other non-nil value is
//     JSON-encoded and sent as application/json. A value that cannot be
//     JSON-encoded causes a panic.
//   - options.headers are added (not set) after the body has been attached.
//   - If options.authenticateAs is non-empty the request is sent through
//     authenticateFlow with that value as the email; otherwise it is sent
//     directly with client.Do.
//
// A dial error whose cause reads "connect: connection refused" and any 5xx
// response are returned as plain (non-terminal) errors so the attempt can be
// repeated; every other failure is wrapped with retry.NewTerminalError. When
// retry.Retry returns an error the response is nil.
//
// NOTE(review): an io.Reader body is handed to each attempt as the same
// reader value, so a repeated attempt only sees whatever the previous attempt
// left unread.
func doAuthenticatedRequest(
	ctx context.Context,
	newRequest func(context.Context) (*http.Request, error),
	getClient func(context.Context) *http.Client,
	options *RequestOptions,
) (*http.Response, error) {
	var resp *http.Response
	resendCount := 0
	client := getClient(ctx)

	if err := retry.Retry(ctx, "http", func(ctx context.Context) error {
		req, err := newRequest(ctx)
		if err != nil {
			return retry.NewTerminalError(err)
		}
		// Make sure the header map is writable before any Header.Set/Add call
		// below: writing to a nil http.Header panics.
		if req.Header == nil {
			req.Header = http.Header{}
		}

		switch body := options.body.(type) {
		case nil:
			// No body to attach.
		case string:
			req.Body = io.NopCloser(strings.NewReader(body))
		case []byte:
			req.Body = io.NopCloser(bytes.NewReader(body))
		case io.Reader:
			req.Body = io.NopCloser(body)
		case proto.Message:
			buf, err := proto.Marshal(body)
			if err != nil {
				return retry.NewTerminalError(err)
			}
			req.Body = io.NopCloser(bytes.NewReader(buf))
			req.Header.Set("Content-Type", "application/octet-stream")
		default:
			// Anything else is sent as JSON; a value that cannot be encoded
			// is treated as a programming error.
			buf, err := json.Marshal(body)
			if err != nil {
				panic(fmt.Sprintf("unsupported body type: %T", body))
			}
			req.Body = io.NopCloser(bytes.NewReader(buf))
			req.Header.Set("Content-Type", "application/json")
		}

		for k, v := range options.headers {
			req.Header.Add(k, v)
		}

		if options.authenticateAs != "" {
			resp, err = authenticateFlow(ctx, client, req, options.authenticateAs, true) //nolint:bodyclose
		} else {
			resp, err = client.Do(req) //nolint:bodyclose
		}

		span := oteltrace.SpanFromContext(ctx)
		if err != nil {
			span.RecordError(err)
			// Only a refused connection is worth another attempt; everything
			// else ends the retry loop.
			var opErr *net.OpError
			if errors.As(err, &opErr) && opErr.Op == "dial" && opErr.Err.Error() == "connect: connection refused" {
				span.AddEvent("Retrying on dial error")
				return err
			}
			return retry.NewTerminalError(err)
		}
		if resp.StatusCode/100 == 5 {
			resendCount++
			// Drain and close the body before the next attempt without
			// buffering the discarded bytes in memory.
			_, _ = io.Copy(io.Discard, resp.Body)
			_ = resp.Body.Close()
			span.SetAttributes(semconv.HTTPRequestResendCount(resendCount))
			span.AddEvent("Retrying on 5xx error", oteltrace.WithAttributes(
				attribute.String("status", resp.Status),
			))
			return errors.New(http.StatusText(resp.StatusCode))
		}
		span.SetStatus(codes.Ok, "request completed successfully")
		return nil
	},
		retry.WithInitialInterval(1*time.Millisecond),
		retry.WithMaxInterval(100*time.Millisecond),
	); err != nil {
		return nil, err
	}
	return resp, nil
}

// authenticateFlow sends req with client and, when the response did not come
// from the originally requested host, completes a login form found in the
// response body on behalf of email.
//
// If checkLocation is true and the hostname of the request that produced the
// response equals the hostname of req, the caller is considered already
// authenticated and that response is returned untouched (its body is left
// open for the caller).
//
// Otherwise the body is parsed with forms.Parse, drained and closed. The first
// form found has its "email" input set to email and its "token_expiration"
// input set to the number of seconds in 24 hours, and is then submitted
// against the response's request URL; the response to that submission is
// returned. If no form is found an error is returned.
//
// Errors from client.Do and from building the form request are recorded on
// the span taken from ctx before being returned.
func authenticateFlow(ctx context.Context, client *http.Client, req *http.Request, email string, checkLocation bool) (*http.Response, error) {
	span := oteltrace.SpanFromContext(ctx)
	originalHostname := req.URL.Hostname()

	res, err := client.Do(req)
	if err != nil {
		span.RecordError(err)
		return nil, err
	}

	// res.Request is the request that produced res, so its URL tells us which
	// host actually answered.
	location := res.Request.URL
	if checkLocation && location.Hostname() == originalHostname {
		// already authenticated
		span.SetStatus(codes.Ok, "already authenticated")
		return res, nil
	}

	fs := forms.Parse(res.Body)
	// Discard whatever the parser left unread, then release the body.
	_, _ = io.Copy(io.Discard, res.Body)
	_ = res.Body.Close()
	if len(fs) == 0 {
		return nil, errors.New("test bug: expected IDP login form")
	}

	f := fs[0]
	f.Inputs["email"] = email
	f.Inputs["token_expiration"] = strconv.Itoa(int((time.Hour * 24).Seconds()))
	span.AddEvent("submitting form", oteltrace.WithAttributes(attribute.String("location", location.String())))
	formReq, err := f.NewRequestWithContext(ctx, location)
	if err != nil {
		span.RecordError(err)
		return nil, err
	}
	resp, err := client.Do(formReq)
	if err != nil {
		span.RecordError(err)
		return nil, err
	}
	span.SetStatus(codes.Ok, "form submitted successfully")
	return resp, nil
}

// rwConn adapts a separate reader/writer pair to the net.Conn interface.
//
// The embedded Conn is one end of an in-memory net.Pipe and is what callers
// read from and write to. Two goroutines started by NewRWConn move bytes
// between the other end of that pipe (remote) and the serverReader /
// serverWriter pair.
type rwConn struct {
	// Conn is the caller-facing end of the pipe.
	net.Conn
	// remote is the opposite end of the pipe, used by the copy goroutines.
	remote net.Conn

	// serverReader supplies the bytes that become readable on Conn.
	serverReader io.ReadCloser
	// serverWriter receives the bytes written to Conn.
	serverWriter io.WriteCloser

	// closeOnce ensures the teardown in Close runs a single time.
	closeOnce sync.Once
	// wg tracks the two copy goroutines so Close can wait for them.
	wg *sync.WaitGroup
}

// Compile-time check that *rwConn satisfies net.Conn.
var _ net.Conn = (*rwConn)(nil)

// NewRWConn returns a net.Conn backed by reader and writer: data read from
// reader becomes readable on the returned connection, and data written to the
// returned connection is forwarded to writer.
//
// Two goroutines are started to do the forwarding. When the reader side stops
// producing data the internal end of the pipe is closed; when the pipe side
// stops producing data writer is closed.
func NewRWConn(reader io.ReadCloser, writer io.WriteCloser) net.Conn {
	local, remote := net.Pipe()
	conn := &rwConn{
		Conn:         local,
		remote:       remote,
		serverReader: reader,
		serverWriter: writer,
		wg:           &sync.WaitGroup{},
	}

	// pump copies src into dst until io.Copy returns, then closes the given
	// closer; copy and close errors are intentionally ignored (best effort).
	pump := func(dst io.Writer, src io.Reader, closeAfter io.Closer) {
		defer conn.wg.Done()
		_, _ = io.Copy(dst, src)
		_ = closeAfter.Close()
	}

	conn.wg.Add(2)
	go pump(remote, reader, remote) // server -> caller
	go pump(writer, remote, writer) // caller -> server
	return conn
}

// Close closes the server reader and the caller-facing end of the pipe, then
// waits for both copy goroutines to finish. The first call returns the joined
// errors of those two Close calls; later calls do nothing and return nil.
func (rwc *rwConn) Close() error {
	var result error
	rwc.closeOnce.Do(func() {
		errReader := rwc.serverReader.Close()
		errLocal := rwc.Conn.Close()
		// Both goroutines must have returned before Close reports completion.
		rwc.wg.Wait()
		result = errors.Join(errLocal, errReader)
	})
	return result
}