package tcptunnel

import (
	"bufio"
	"bytes"
	"context"
	"io"
	"net"
	"net/http"
	"net/http/httptest"
	"strings"
	"testing"
	"time"

	"github.com/stretchr/testify/assert"
)

func TestTunnel(t *testing.T) {
	ctx, clearTimeout := context.WithTimeout(context.Background(), time.Second*10)
	defer clearTimeout()

	backend, err := net.Listen("tcp", "127.0.0.1:0")
	if !assert.NoError(t, err) {
		return
	}
	defer func() { _ = backend.Close() }()

	go func() {
		for {
			conn, err := backend.Accept()
			if err != nil {
				return
			}
			go func() {
				defer func() { _ = conn.Close() }()

				ln, _, _ := bufio.NewReader(conn).ReadLine()
				assert.Equal(t, "HELLO WORLD", string(ln))
			}()
		}
	}()

	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if !assert.Equal(t, "CONNECT", r.Method) {
			return
		}
		if !assert.Equal(t, "example.com:9999", r.RequestURI) {
			return
		}

		w.WriteHeader(200)

		in, brw, err := w.(http.Hijacker).Hijack()
		if !assert.NoError(t, err) {
			return
		}
		defer func() { _ = in.Close() }()

		out, err := net.Dial("tcp", backend.Addr().String())
		if !assert.NoError(t, err) {
			return
		}
		defer func() { _ = out.Close() }()

		errc := make(chan error, 2)
		go func() {
			_, err := io.Copy(in, out)
			errc <- err
		}()
		go func() {
			_, err := io.Copy(out, deBuffer(brw.Reader, in))
			errc <- err
		}()
		<-errc
	}))
	defer srv.Close()

	var buf bytes.Buffer
	tun := New(
		WithDestinationHost("example.com:9999"),
		WithProxyHost(srv.Listener.Addr().String()))
	err = tun.Run(ctx, readWriter{strings.NewReader("HELLO WORLD\n"), &buf})
	if !assert.NoError(t, err) {
		return
	}
}

type readWriter struct {
	io.Reader
	io.Writer
}