testenv: avoid dns lookups for *.localhost.pomerium.io (#5372)

* testenv: avoid dns lookups for localhost.pomerium.io

* linter pass
This commit is contained in:
Joe Kralicky 2024-12-02 12:29:15 -05:00 committed by GitHub
parent 55e13f9608
commit 39e789529e
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 142 additions and 8 deletions

64
internal/testenv/dns.go Normal file
View file

@ -0,0 +1,64 @@
package testenv
import (
"context"
"net"
"net/http"
"strings"
"time"
)
const localDomainName = "localhost.pomerium.io"
type DialContextFunc = func(ctx context.Context, network string, addr string) (net.Conn, error)
var defaultDialer = &net.Dialer{
Timeout: 30 * time.Second,
KeepAlive: 30 * time.Second,
Resolver: &net.Resolver{
PreferGo: true,
},
}
func init() {
http.DefaultTransport.(*http.Transport).DialContext = OverrideDialContext(defaultDialer.DialContext)
}
func OverrideDialContext(defaultDialContext DialContextFunc) DialContextFunc {
return func(ctx context.Context, network, addr string) (net.Conn, error) {
host, port, err := maybeSplitHostPort(addr)
if err != nil {
return nil, err
}
if strings.HasSuffix(host, localDomainName) {
switch network {
case "tcp", "tcp4", "udp", "udp4":
host = "127.0.0.1"
case "tcp6", "udp6":
host = "::1"
}
}
return defaultDialContext(ctx, network, net.JoinHostPort(host, port))
}
}
func maybeSplitHostPort(s string) (string, string, error) {
if strings.Contains(s, ":") {
return net.SplitHostPort(s)
}
return s, "", nil
}
func GRPCContextDialer(ctx context.Context, target string) (net.Conn, error) {
if strings.HasPrefix(target, "unix") {
return defaultDialer.DialContext(ctx, "tcp", target)
}
host, port, err := net.SplitHostPort(target)
if err != nil {
return nil, err
}
if strings.HasSuffix(host, localDomainName) {
return defaultDialer.DialContext(ctx, "tcp", "127.0.0.1:"+port)
}
return defaultDialer.DialContext(ctx, "tcp", target)
}

View file

@ -0,0 +1,55 @@
package selftests_test
import (
"net"
"net/http"
"net/http/httptrace"
"testing"
"github.com/pomerium/pomerium/config"
"github.com/pomerium/pomerium/internal/testenv"
"github.com/pomerium/pomerium/internal/testenv/snippets"
"github.com/pomerium/pomerium/internal/testenv/upstreams"
"github.com/stretchr/testify/require"
)
func TestDNSOverrides(t *testing.T) {
env := testenv.New(t)
h := upstreams.HTTP(nil)
h.Handle("/", func(w http.ResponseWriter, _ *http.Request) {
w.Write([]byte("OK"))
})
route := h.Route().From(env.SubdomainURL("foo")).Policy(func(p *config.Policy) {
p.AllowPublicUnauthenticatedAccess = true
})
env.AddUpstream(h)
env.Start()
snippets.WaitStartupComplete(env)
var traceHostPort, traceRemoteAddr string
var dnsStartCalled, dnsEndCalled bool
trace := httptrace.ClientTrace{
DNSStart: func(_ httptrace.DNSStartInfo) {
dnsStartCalled = true
},
DNSDone: func(_ httptrace.DNSDoneInfo) {
dnsEndCalled = true
},
GetConn: func(hostPort string) {
traceHostPort = hostPort
},
GotConn: func(gci httptrace.GotConnInfo) {
traceRemoteAddr = gci.Conn.RemoteAddr().String()
},
}
resp, err := h.Get(route, upstreams.WithClientTrace(&trace))
require.NoError(t, err)
require.False(t, dnsStartCalled)
require.False(t, dnsEndCalled)
require.Equal(t, http.StatusOK, resp.StatusCode)
require.Equal(t, route.URL().Value(), "https://"+traceHostPort)
host, _, err := net.SplitHostPort(traceRemoteAddr)
require.NoError(t, err)
require.True(t, net.ParseIP(host).IsLoopback())
}

View file

@ -128,6 +128,7 @@ func (g *grpcUpstream) Run(ctx context.Context) error {
func (g *grpcUpstream) Dial(r testenv.Route, dialOpts ...grpc.DialOption) *grpc.ClientConn {
dialOpts = append(dialOpts,
grpc.WithContextDialer(testenv.GRPCContextDialer),
grpc.WithTransportCredentials(credentials.NewClientTLSFromCert(g.Env().ServerCAs(), "")),
grpc.WithDefaultCallOptions(grpc.WaitForReady(true)),
)

View file

@ -11,6 +11,7 @@ import (
"net"
"net/http"
"net/http/cookiejar"
"net/http/httptrace"
"net/url"
"strconv"
"strings"
@ -22,6 +23,7 @@ import (
"github.com/pomerium/pomerium/internal/retry"
"github.com/pomerium/pomerium/internal/testenv"
"github.com/pomerium/pomerium/internal/testenv/values"
"github.com/pomerium/pomerium/pkg/telemetry/requestid"
"google.golang.org/protobuf/proto"
)
@ -33,6 +35,7 @@ type RequestOptions struct {
body any
clientCerts []tls.Certificate
client *http.Client
trace *httptrace.ClientTrace
}
type RequestOption func(*RequestOptions)
@ -77,6 +80,12 @@ func Client(c *http.Client) RequestOption {
}
}
func WithClientTrace(ct *httptrace.ClientTrace) RequestOption {
return func(o *RequestOptions) {
o.trace = ct
}
}
// Body sets the body of the request.
// The argument can be one of the following types:
// - string
@ -220,7 +229,11 @@ func (h *httpUpstream) Do(method string, r testenv.Route, opts ...RequestOption)
RawQuery: options.query.Encode(),
})
}
req, err := http.NewRequest(method, u.String(), nil)
ctx := h.Env().Context()
if options.trace != nil {
ctx = httptrace.WithClientTrace(ctx, options.trace)
}
req, err := http.NewRequestWithContext(ctx, method, u.String(), nil)
if err != nil {
return nil, err
}
@ -249,13 +262,14 @@ func (h *httpUpstream) Do(method string, r testenv.Route, opts ...RequestOption)
}
newClient := func() *http.Client {
transport := http.DefaultTransport.(*http.Transport).Clone()
transport.TLSClientConfig = &tls.Config{
RootCAs: h.Env().ServerCAs(),
Certificates: options.clientCerts,
}
transport.DialTLSContext = nil
c := http.Client{
Transport: &http.Transport{
TLSClientConfig: &tls.Config{
RootCAs: h.Env().ServerCAs(),
Certificates: options.clientCerts,
},
},
Transport: requestid.NewRoundTripper(transport),
}
c.Jar, _ = cookiejar.New(&cookiejar.Options{})
return &c
@ -273,7 +287,7 @@ func (h *httpUpstream) Do(method string, r testenv.Route, opts ...RequestOption)
}
var resp *http.Response
if err := retry.Retry(h.Env().Context(), "http", func(ctx context.Context) error {
if err := retry.Retry(ctx, "http", func(ctx context.Context) error {
var err error
if options.authenticateAs != "" {
resp, err = authenticateFlow(ctx, client, req, options.authenticateAs) //nolint:bodyclose