From 39e789529ea408103bb354de832945485fc83c61 Mon Sep 17 00:00:00 2001 From: Joe Kralicky Date: Mon, 2 Dec 2024 12:29:15 -0500 Subject: [PATCH] testenv: avoid dns lookups for *.localhost.pomerium.io (#5372) * testenv: avoid dns lookups for localhost.pomerium.io * linter pass --- internal/testenv/dns.go | 64 ++++++++++++++++++++++++++ internal/testenv/selftests/dns_test.go | 55 ++++++++++++++++++++++ internal/testenv/upstreams/grpc.go | 1 + internal/testenv/upstreams/http.go | 30 ++++++++---- 4 files changed, 142 insertions(+), 8 deletions(-) create mode 100644 internal/testenv/dns.go create mode 100644 internal/testenv/selftests/dns_test.go diff --git a/internal/testenv/dns.go b/internal/testenv/dns.go new file mode 100644 index 000000000..d1bf0c46c --- /dev/null +++ b/internal/testenv/dns.go @@ -0,0 +1,64 @@ +package testenv + +import ( + "context" + "net" + "net/http" + "strings" + "time" +) + +const localDomainName = "localhost.pomerium.io" + +type DialContextFunc = func(ctx context.Context, network string, addr string) (net.Conn, error) + +var defaultDialer = &net.Dialer{ + Timeout: 30 * time.Second, + KeepAlive: 30 * time.Second, + Resolver: &net.Resolver{ + PreferGo: true, + }, +} + +func init() { + http.DefaultTransport.(*http.Transport).DialContext = OverrideDialContext(defaultDialer.DialContext) +} + +func OverrideDialContext(defaultDialContext DialContextFunc) DialContextFunc { + return func(ctx context.Context, network, addr string) (net.Conn, error) { + host, port, err := maybeSplitHostPort(addr) + if err != nil { + return nil, err + } + if strings.HasSuffix(host, localDomainName) { + switch network { + case "tcp", "tcp4", "udp", "udp4": + host = "127.0.0.1" + case "tcp6", "udp6": + host = "::1" + } + } + return defaultDialContext(ctx, network, net.JoinHostPort(host, port)) + } +} + +func maybeSplitHostPort(s string) (string, string, error) { + if strings.Contains(s, ":") { + return net.SplitHostPort(s) + } + return s, "", nil +} + +func GRPCContextDialer(ctx context.Context, target string) (net.Conn, error) { + if strings.HasPrefix(target, "unix") { + return defaultDialer.DialContext(ctx, "tcp", target) + } + host, port, err := net.SplitHostPort(target) + if err != nil { + return nil, err + } + if strings.HasSuffix(host, localDomainName) { + return defaultDialer.DialContext(ctx, "tcp", "127.0.0.1:"+port) + } + return defaultDialer.DialContext(ctx, "tcp", target) +} diff --git a/internal/testenv/selftests/dns_test.go b/internal/testenv/selftests/dns_test.go new file mode 100644 index 000000000..c04fc4efd --- /dev/null +++ b/internal/testenv/selftests/dns_test.go @@ -0,0 +1,55 @@ +package selftests_test + +import ( + "net" + "net/http" + "net/http/httptrace" + "testing" + + "github.com/pomerium/pomerium/config" + "github.com/pomerium/pomerium/internal/testenv" + "github.com/pomerium/pomerium/internal/testenv/snippets" + "github.com/pomerium/pomerium/internal/testenv/upstreams" + "github.com/stretchr/testify/require" +) + +func TestDNSOverrides(t *testing.T) { + env := testenv.New(t) + h := upstreams.HTTP(nil) + h.Handle("/", func(w http.ResponseWriter, _ *http.Request) { + w.Write([]byte("OK")) + }) + route := h.Route().From(env.SubdomainURL("foo")).Policy(func(p *config.Policy) { + p.AllowPublicUnauthenticatedAccess = true + }) + env.AddUpstream(h) + + env.Start() + snippets.WaitStartupComplete(env) + + var traceHostPort, traceRemoteAddr string + var dnsStartCalled, dnsEndCalled bool + trace := httptrace.ClientTrace{ + DNSStart: func(_ httptrace.DNSStartInfo) { + dnsStartCalled = true + }, + DNSDone: func(_ httptrace.DNSDoneInfo) { + dnsEndCalled = true + }, + GetConn: func(hostPort string) { + traceHostPort = hostPort + }, + GotConn: func(gci httptrace.GotConnInfo) { + traceRemoteAddr = gci.Conn.RemoteAddr().String() + }, + } + resp, err := h.Get(route, upstreams.WithClientTrace(&trace)) + require.NoError(t, err) + require.False(t, dnsStartCalled) + require.False(t, dnsEndCalled) + require.Equal(t, http.StatusOK, resp.StatusCode) + require.Equal(t, route.URL().Value(), "https://"+traceHostPort) + host, _, err := net.SplitHostPort(traceRemoteAddr) + require.NoError(t, err) + require.True(t, net.ParseIP(host).IsLoopback()) +} diff --git a/internal/testenv/upstreams/grpc.go b/internal/testenv/upstreams/grpc.go index f46801392..fb22e08a7 100644 --- a/internal/testenv/upstreams/grpc.go +++ b/internal/testenv/upstreams/grpc.go @@ -128,6 +128,7 @@ func (g *grpcUpstream) Run(ctx context.Context) error { func (g *grpcUpstream) Dial(r testenv.Route, dialOpts ...grpc.DialOption) *grpc.ClientConn { dialOpts = append(dialOpts, + grpc.WithContextDialer(testenv.GRPCContextDialer), grpc.WithTransportCredentials(credentials.NewClientTLSFromCert(g.Env().ServerCAs(), "")), grpc.WithDefaultCallOptions(grpc.WaitForReady(true)), ) diff --git a/internal/testenv/upstreams/http.go b/internal/testenv/upstreams/http.go index 0138c18f3..2ec020108 100644 --- a/internal/testenv/upstreams/http.go +++ b/internal/testenv/upstreams/http.go @@ -11,6 +11,7 @@ import ( "net" "net/http" "net/http/cookiejar" + "net/http/httptrace" "net/url" "strconv" "strings" @@ -22,6 +23,7 @@ import ( "github.com/pomerium/pomerium/internal/retry" "github.com/pomerium/pomerium/internal/testenv" "github.com/pomerium/pomerium/internal/testenv/values" + "github.com/pomerium/pomerium/pkg/telemetry/requestid" "google.golang.org/protobuf/proto" ) @@ -33,6 +35,7 @@ type RequestOptions struct { body any clientCerts []tls.Certificate client *http.Client + trace *httptrace.ClientTrace } type RequestOption func(*RequestOptions) @@ -77,6 +80,12 @@ func Client(c *http.Client) RequestOption { } } +func WithClientTrace(ct *httptrace.ClientTrace) RequestOption { + return func(o *RequestOptions) { + o.trace = ct + } +} + // Body sets the body of the request. // The argument can be one of the following types: // - string @@ -220,7 +229,11 @@ func (h *httpUpstream) Do(method string, r testenv.Route, opts ...RequestOption) RawQuery: options.query.Encode(), }) } - req, err := http.NewRequest(method, u.String(), nil) + ctx := h.Env().Context() + if options.trace != nil { + ctx = httptrace.WithClientTrace(ctx, options.trace) + } + req, err := http.NewRequestWithContext(ctx, method, u.String(), nil) if err != nil { return nil, err } @@ -249,13 +262,14 @@ func (h *httpUpstream) Do(method string, r testenv.Route, opts ...RequestOption) } newClient := func() *http.Client { + transport := http.DefaultTransport.(*http.Transport).Clone() + transport.TLSClientConfig = &tls.Config{ + RootCAs: h.Env().ServerCAs(), + Certificates: options.clientCerts, + } + transport.DialTLSContext = nil c := http.Client{ - Transport: &http.Transport{ - TLSClientConfig: &tls.Config{ - RootCAs: h.Env().ServerCAs(), - Certificates: options.clientCerts, - }, - }, + Transport: requestid.NewRoundTripper(transport), } c.Jar, _ = cookiejar.New(&cookiejar.Options{}) return &c @@ -273,7 +287,7 @@ func (h *httpUpstream) Do(method string, r testenv.Route, opts ...RequestOption) } var resp *http.Response - if err := retry.Retry(h.Env().Context(), "http", func(ctx context.Context) error { + if err := retry.Retry(ctx, "http", func(ctx context.Context) error { var err error if options.authenticateAs != "" { resp, err = authenticateFlow(ctx, client, req, options.authenticateAs) //nolint:bodyclose