mirror of
https://github.com/pomerium/pomerium.git
synced 2025-06-03 11:22:45 +02:00
testenv: avoid dns lookups for *.localhost.pomerium.io (#5372)
* testenv: avoid dns lookups for localhost.pomerium.io * linter pass
This commit is contained in:
parent
55e13f9608
commit
39e789529e
4 changed files with 142 additions and 8 deletions
64
internal/testenv/dns.go
Normal file
64
internal/testenv/dns.go
Normal file
|
@ -0,0 +1,64 @@
|
|||
package testenv
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
const localDomainName = "localhost.pomerium.io"
|
||||
|
||||
type DialContextFunc = func(ctx context.Context, network string, addr string) (net.Conn, error)
|
||||
|
||||
var defaultDialer = &net.Dialer{
|
||||
Timeout: 30 * time.Second,
|
||||
KeepAlive: 30 * time.Second,
|
||||
Resolver: &net.Resolver{
|
||||
PreferGo: true,
|
||||
},
|
||||
}
|
||||
|
||||
func init() {
|
||||
http.DefaultTransport.(*http.Transport).DialContext = OverrideDialContext(defaultDialer.DialContext)
|
||||
}
|
||||
|
||||
func OverrideDialContext(defaultDialContext DialContextFunc) DialContextFunc {
|
||||
return func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
host, port, err := maybeSplitHostPort(addr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if strings.HasSuffix(host, localDomainName) {
|
||||
switch network {
|
||||
case "tcp", "tcp4", "udp", "udp4":
|
||||
host = "127.0.0.1"
|
||||
case "tcp6", "udp6":
|
||||
host = "::1"
|
||||
}
|
||||
}
|
||||
return defaultDialContext(ctx, network, net.JoinHostPort(host, port))
|
||||
}
|
||||
}
|
||||
|
||||
func maybeSplitHostPort(s string) (string, string, error) {
|
||||
if strings.Contains(s, ":") {
|
||||
return net.SplitHostPort(s)
|
||||
}
|
||||
return s, "", nil
|
||||
}
|
||||
|
||||
func GRPCContextDialer(ctx context.Context, target string) (net.Conn, error) {
|
||||
if strings.HasPrefix(target, "unix") {
|
||||
return defaultDialer.DialContext(ctx, "tcp", target)
|
||||
}
|
||||
host, port, err := net.SplitHostPort(target)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if strings.HasSuffix(host, localDomainName) {
|
||||
return defaultDialer.DialContext(ctx, "tcp", "127.0.0.1:"+port)
|
||||
}
|
||||
return defaultDialer.DialContext(ctx, "tcp", target)
|
||||
}
|
55
internal/testenv/selftests/dns_test.go
Normal file
55
internal/testenv/selftests/dns_test.go
Normal file
|
@ -0,0 +1,55 @@
|
|||
package selftests_test
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httptrace"
|
||||
"testing"
|
||||
|
||||
"github.com/pomerium/pomerium/config"
|
||||
"github.com/pomerium/pomerium/internal/testenv"
|
||||
"github.com/pomerium/pomerium/internal/testenv/snippets"
|
||||
"github.com/pomerium/pomerium/internal/testenv/upstreams"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestDNSOverrides(t *testing.T) {
|
||||
env := testenv.New(t)
|
||||
h := upstreams.HTTP(nil)
|
||||
h.Handle("/", func(w http.ResponseWriter, _ *http.Request) {
|
||||
w.Write([]byte("OK"))
|
||||
})
|
||||
route := h.Route().From(env.SubdomainURL("foo")).Policy(func(p *config.Policy) {
|
||||
p.AllowPublicUnauthenticatedAccess = true
|
||||
})
|
||||
env.AddUpstream(h)
|
||||
|
||||
env.Start()
|
||||
snippets.WaitStartupComplete(env)
|
||||
|
||||
var traceHostPort, traceRemoteAddr string
|
||||
var dnsStartCalled, dnsEndCalled bool
|
||||
trace := httptrace.ClientTrace{
|
||||
DNSStart: func(_ httptrace.DNSStartInfo) {
|
||||
dnsStartCalled = true
|
||||
},
|
||||
DNSDone: func(_ httptrace.DNSDoneInfo) {
|
||||
dnsEndCalled = true
|
||||
},
|
||||
GetConn: func(hostPort string) {
|
||||
traceHostPort = hostPort
|
||||
},
|
||||
GotConn: func(gci httptrace.GotConnInfo) {
|
||||
traceRemoteAddr = gci.Conn.RemoteAddr().String()
|
||||
},
|
||||
}
|
||||
resp, err := h.Get(route, upstreams.WithClientTrace(&trace))
|
||||
require.NoError(t, err)
|
||||
require.False(t, dnsStartCalled)
|
||||
require.False(t, dnsEndCalled)
|
||||
require.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
require.Equal(t, route.URL().Value(), "https://"+traceHostPort)
|
||||
host, _, err := net.SplitHostPort(traceRemoteAddr)
|
||||
require.NoError(t, err)
|
||||
require.True(t, net.ParseIP(host).IsLoopback())
|
||||
}
|
|
@ -128,6 +128,7 @@ func (g *grpcUpstream) Run(ctx context.Context) error {
|
|||
|
||||
func (g *grpcUpstream) Dial(r testenv.Route, dialOpts ...grpc.DialOption) *grpc.ClientConn {
|
||||
dialOpts = append(dialOpts,
|
||||
grpc.WithContextDialer(testenv.GRPCContextDialer),
|
||||
grpc.WithTransportCredentials(credentials.NewClientTLSFromCert(g.Env().ServerCAs(), "")),
|
||||
grpc.WithDefaultCallOptions(grpc.WaitForReady(true)),
|
||||
)
|
||||
|
|
|
@ -11,6 +11,7 @@ import (
|
|||
"net"
|
||||
"net/http"
|
||||
"net/http/cookiejar"
|
||||
"net/http/httptrace"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
@ -22,6 +23,7 @@ import (
|
|||
"github.com/pomerium/pomerium/internal/retry"
|
||||
"github.com/pomerium/pomerium/internal/testenv"
|
||||
"github.com/pomerium/pomerium/internal/testenv/values"
|
||||
"github.com/pomerium/pomerium/pkg/telemetry/requestid"
|
||||
"google.golang.org/protobuf/proto"
|
||||
)
|
||||
|
||||
|
@ -33,6 +35,7 @@ type RequestOptions struct {
|
|||
body any
|
||||
clientCerts []tls.Certificate
|
||||
client *http.Client
|
||||
trace *httptrace.ClientTrace
|
||||
}
|
||||
|
||||
type RequestOption func(*RequestOptions)
|
||||
|
@ -77,6 +80,12 @@ func Client(c *http.Client) RequestOption {
|
|||
}
|
||||
}
|
||||
|
||||
func WithClientTrace(ct *httptrace.ClientTrace) RequestOption {
|
||||
return func(o *RequestOptions) {
|
||||
o.trace = ct
|
||||
}
|
||||
}
|
||||
|
||||
// Body sets the body of the request.
|
||||
// The argument can be one of the following types:
|
||||
// - string
|
||||
|
@ -220,7 +229,11 @@ func (h *httpUpstream) Do(method string, r testenv.Route, opts ...RequestOption)
|
|||
RawQuery: options.query.Encode(),
|
||||
})
|
||||
}
|
||||
req, err := http.NewRequest(method, u.String(), nil)
|
||||
ctx := h.Env().Context()
|
||||
if options.trace != nil {
|
||||
ctx = httptrace.WithClientTrace(ctx, options.trace)
|
||||
}
|
||||
req, err := http.NewRequestWithContext(ctx, method, u.String(), nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -249,13 +262,14 @@ func (h *httpUpstream) Do(method string, r testenv.Route, opts ...RequestOption)
|
|||
}
|
||||
|
||||
newClient := func() *http.Client {
|
||||
transport := http.DefaultTransport.(*http.Transport).Clone()
|
||||
transport.TLSClientConfig = &tls.Config{
|
||||
RootCAs: h.Env().ServerCAs(),
|
||||
Certificates: options.clientCerts,
|
||||
}
|
||||
transport.DialTLSContext = nil
|
||||
c := http.Client{
|
||||
Transport: &http.Transport{
|
||||
TLSClientConfig: &tls.Config{
|
||||
RootCAs: h.Env().ServerCAs(),
|
||||
Certificates: options.clientCerts,
|
||||
},
|
||||
},
|
||||
Transport: requestid.NewRoundTripper(transport),
|
||||
}
|
||||
c.Jar, _ = cookiejar.New(&cookiejar.Options{})
|
||||
return &c
|
||||
|
@ -273,7 +287,7 @@ func (h *httpUpstream) Do(method string, r testenv.Route, opts ...RequestOption)
|
|||
}
|
||||
|
||||
var resp *http.Response
|
||||
if err := retry.Retry(h.Env().Context(), "http", func(ctx context.Context) error {
|
||||
if err := retry.Retry(ctx, "http", func(ctx context.Context) error {
|
||||
var err error
|
||||
if options.authenticateAs != "" {
|
||||
resp, err = authenticateFlow(ctx, client, req, options.authenticateAs) //nolint:bodyclose
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue