package selftests_test

import (
	"net"
	"net/http"
	"net/http/httptrace"
	"testing"

	"github.com/stretchr/testify/require"

	"github.com/pomerium/pomerium/config"
	"github.com/pomerium/pomerium/internal/testenv"
	"github.com/pomerium/pomerium/internal/testenv/snippets"
	"github.com/pomerium/pomerium/internal/testenv/upstreams"
)

func TestDNSOverrides(t *testing.T) {
	env := testenv.New(t)
	h := upstreams.HTTP(nil)
	h.Handle("/", func(w http.ResponseWriter, _ *http.Request) {
		w.Write([]byte("OK"))
	})
	route := h.Route().From(env.SubdomainURL("foo")).Policy(func(p *config.Policy) {
		p.AllowPublicUnauthenticatedAccess = true
	})
	env.AddUpstream(h)

	env.Start()
	snippets.WaitStartupComplete(env)

	var traceHostPort, traceRemoteAddr string
	var dnsStartCalled, dnsEndCalled bool
	trace := httptrace.ClientTrace{
		DNSStart: func(_ httptrace.DNSStartInfo) {
			dnsStartCalled = true
		},
		DNSDone: func(_ httptrace.DNSDoneInfo) {
			dnsEndCalled = true
		},
		GetConn: func(hostPort string) {
			traceHostPort = hostPort
		},
		GotConn: func(gci httptrace.GotConnInfo) {
			traceRemoteAddr = gci.Conn.RemoteAddr().String()
		},
	}
	resp, err := h.Get(route, upstreams.WithClientTrace(&trace))
	require.NoError(t, err)
	require.False(t, dnsStartCalled)
	require.False(t, dnsEndCalled)
	require.Equal(t, http.StatusOK, resp.StatusCode)
	require.Equal(t, route.URL().Value(), "https://"+traceHostPort)
	host, _, err := net.SplitHostPort(traceRemoteAddr)
	require.NoError(t, err)
	require.True(t, net.ParseIP(host).IsLoopback())
}