mirror of
https://github.com/pomerium/pomerium.git
synced 2025-06-06 12:52:53 +02:00
testenv: avoid dns lookups for *.localhost.pomerium.io (#5372)
* testenv: avoid dns lookups for localhost.pomerium.io * linter pass
This commit is contained in:
parent
55e13f9608
commit
39e789529e
4 changed files with 142 additions and 8 deletions
64
internal/testenv/dns.go
Normal file
64
internal/testenv/dns.go
Normal file
|
@ -0,0 +1,64 @@
|
||||||
|
package testenv
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
const localDomainName = "localhost.pomerium.io"
|
||||||
|
|
||||||
|
type DialContextFunc = func(ctx context.Context, network string, addr string) (net.Conn, error)
|
||||||
|
|
||||||
|
var defaultDialer = &net.Dialer{
|
||||||
|
Timeout: 30 * time.Second,
|
||||||
|
KeepAlive: 30 * time.Second,
|
||||||
|
Resolver: &net.Resolver{
|
||||||
|
PreferGo: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
http.DefaultTransport.(*http.Transport).DialContext = OverrideDialContext(defaultDialer.DialContext)
|
||||||
|
}
|
||||||
|
|
||||||
|
func OverrideDialContext(defaultDialContext DialContextFunc) DialContextFunc {
|
||||||
|
return func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||||
|
host, port, err := maybeSplitHostPort(addr)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if strings.HasSuffix(host, localDomainName) {
|
||||||
|
switch network {
|
||||||
|
case "tcp", "tcp4", "udp", "udp4":
|
||||||
|
host = "127.0.0.1"
|
||||||
|
case "tcp6", "udp6":
|
||||||
|
host = "::1"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return defaultDialContext(ctx, network, net.JoinHostPort(host, port))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func maybeSplitHostPort(s string) (string, string, error) {
|
||||||
|
if strings.Contains(s, ":") {
|
||||||
|
return net.SplitHostPort(s)
|
||||||
|
}
|
||||||
|
return s, "", nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func GRPCContextDialer(ctx context.Context, target string) (net.Conn, error) {
|
||||||
|
if strings.HasPrefix(target, "unix") {
|
||||||
|
return defaultDialer.DialContext(ctx, "tcp", target)
|
||||||
|
}
|
||||||
|
host, port, err := net.SplitHostPort(target)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if strings.HasSuffix(host, localDomainName) {
|
||||||
|
return defaultDialer.DialContext(ctx, "tcp", "127.0.0.1:"+port)
|
||||||
|
}
|
||||||
|
return defaultDialer.DialContext(ctx, "tcp", target)
|
||||||
|
}
|
55
internal/testenv/selftests/dns_test.go
Normal file
55
internal/testenv/selftests/dns_test.go
Normal file
|
@ -0,0 +1,55 @@
|
||||||
|
package selftests_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptrace"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/pomerium/pomerium/config"
|
||||||
|
"github.com/pomerium/pomerium/internal/testenv"
|
||||||
|
"github.com/pomerium/pomerium/internal/testenv/snippets"
|
||||||
|
"github.com/pomerium/pomerium/internal/testenv/upstreams"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestDNSOverrides(t *testing.T) {
|
||||||
|
env := testenv.New(t)
|
||||||
|
h := upstreams.HTTP(nil)
|
||||||
|
h.Handle("/", func(w http.ResponseWriter, _ *http.Request) {
|
||||||
|
w.Write([]byte("OK"))
|
||||||
|
})
|
||||||
|
route := h.Route().From(env.SubdomainURL("foo")).Policy(func(p *config.Policy) {
|
||||||
|
p.AllowPublicUnauthenticatedAccess = true
|
||||||
|
})
|
||||||
|
env.AddUpstream(h)
|
||||||
|
|
||||||
|
env.Start()
|
||||||
|
snippets.WaitStartupComplete(env)
|
||||||
|
|
||||||
|
var traceHostPort, traceRemoteAddr string
|
||||||
|
var dnsStartCalled, dnsEndCalled bool
|
||||||
|
trace := httptrace.ClientTrace{
|
||||||
|
DNSStart: func(_ httptrace.DNSStartInfo) {
|
||||||
|
dnsStartCalled = true
|
||||||
|
},
|
||||||
|
DNSDone: func(_ httptrace.DNSDoneInfo) {
|
||||||
|
dnsEndCalled = true
|
||||||
|
},
|
||||||
|
GetConn: func(hostPort string) {
|
||||||
|
traceHostPort = hostPort
|
||||||
|
},
|
||||||
|
GotConn: func(gci httptrace.GotConnInfo) {
|
||||||
|
traceRemoteAddr = gci.Conn.RemoteAddr().String()
|
||||||
|
},
|
||||||
|
}
|
||||||
|
resp, err := h.Get(route, upstreams.WithClientTrace(&trace))
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.False(t, dnsStartCalled)
|
||||||
|
require.False(t, dnsEndCalled)
|
||||||
|
require.Equal(t, http.StatusOK, resp.StatusCode)
|
||||||
|
require.Equal(t, route.URL().Value(), "https://"+traceHostPort)
|
||||||
|
host, _, err := net.SplitHostPort(traceRemoteAddr)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.True(t, net.ParseIP(host).IsLoopback())
|
||||||
|
}
|
|
@ -128,6 +128,7 @@ func (g *grpcUpstream) Run(ctx context.Context) error {
|
||||||
|
|
||||||
func (g *grpcUpstream) Dial(r testenv.Route, dialOpts ...grpc.DialOption) *grpc.ClientConn {
|
func (g *grpcUpstream) Dial(r testenv.Route, dialOpts ...grpc.DialOption) *grpc.ClientConn {
|
||||||
dialOpts = append(dialOpts,
|
dialOpts = append(dialOpts,
|
||||||
|
grpc.WithContextDialer(testenv.GRPCContextDialer),
|
||||||
grpc.WithTransportCredentials(credentials.NewClientTLSFromCert(g.Env().ServerCAs(), "")),
|
grpc.WithTransportCredentials(credentials.NewClientTLSFromCert(g.Env().ServerCAs(), "")),
|
||||||
grpc.WithDefaultCallOptions(grpc.WaitForReady(true)),
|
grpc.WithDefaultCallOptions(grpc.WaitForReady(true)),
|
||||||
)
|
)
|
||||||
|
|
|
@ -11,6 +11,7 @@ import (
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/cookiejar"
|
"net/http/cookiejar"
|
||||||
|
"net/http/httptrace"
|
||||||
"net/url"
|
"net/url"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
@ -22,6 +23,7 @@ import (
|
||||||
"github.com/pomerium/pomerium/internal/retry"
|
"github.com/pomerium/pomerium/internal/retry"
|
||||||
"github.com/pomerium/pomerium/internal/testenv"
|
"github.com/pomerium/pomerium/internal/testenv"
|
||||||
"github.com/pomerium/pomerium/internal/testenv/values"
|
"github.com/pomerium/pomerium/internal/testenv/values"
|
||||||
|
"github.com/pomerium/pomerium/pkg/telemetry/requestid"
|
||||||
"google.golang.org/protobuf/proto"
|
"google.golang.org/protobuf/proto"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -33,6 +35,7 @@ type RequestOptions struct {
|
||||||
body any
|
body any
|
||||||
clientCerts []tls.Certificate
|
clientCerts []tls.Certificate
|
||||||
client *http.Client
|
client *http.Client
|
||||||
|
trace *httptrace.ClientTrace
|
||||||
}
|
}
|
||||||
|
|
||||||
type RequestOption func(*RequestOptions)
|
type RequestOption func(*RequestOptions)
|
||||||
|
@ -77,6 +80,12 @@ func Client(c *http.Client) RequestOption {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func WithClientTrace(ct *httptrace.ClientTrace) RequestOption {
|
||||||
|
return func(o *RequestOptions) {
|
||||||
|
o.trace = ct
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Body sets the body of the request.
|
// Body sets the body of the request.
|
||||||
// The argument can be one of the following types:
|
// The argument can be one of the following types:
|
||||||
// - string
|
// - string
|
||||||
|
@ -220,7 +229,11 @@ func (h *httpUpstream) Do(method string, r testenv.Route, opts ...RequestOption)
|
||||||
RawQuery: options.query.Encode(),
|
RawQuery: options.query.Encode(),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
req, err := http.NewRequest(method, u.String(), nil)
|
ctx := h.Env().Context()
|
||||||
|
if options.trace != nil {
|
||||||
|
ctx = httptrace.WithClientTrace(ctx, options.trace)
|
||||||
|
}
|
||||||
|
req, err := http.NewRequestWithContext(ctx, method, u.String(), nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -249,13 +262,14 @@ func (h *httpUpstream) Do(method string, r testenv.Route, opts ...RequestOption)
|
||||||
}
|
}
|
||||||
|
|
||||||
newClient := func() *http.Client {
|
newClient := func() *http.Client {
|
||||||
|
transport := http.DefaultTransport.(*http.Transport).Clone()
|
||||||
|
transport.TLSClientConfig = &tls.Config{
|
||||||
|
RootCAs: h.Env().ServerCAs(),
|
||||||
|
Certificates: options.clientCerts,
|
||||||
|
}
|
||||||
|
transport.DialTLSContext = nil
|
||||||
c := http.Client{
|
c := http.Client{
|
||||||
Transport: &http.Transport{
|
Transport: requestid.NewRoundTripper(transport),
|
||||||
TLSClientConfig: &tls.Config{
|
|
||||||
RootCAs: h.Env().ServerCAs(),
|
|
||||||
Certificates: options.clientCerts,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
c.Jar, _ = cookiejar.New(&cookiejar.Options{})
|
c.Jar, _ = cookiejar.New(&cookiejar.Options{})
|
||||||
return &c
|
return &c
|
||||||
|
@ -273,7 +287,7 @@ func (h *httpUpstream) Do(method string, r testenv.Route, opts ...RequestOption)
|
||||||
}
|
}
|
||||||
|
|
||||||
var resp *http.Response
|
var resp *http.Response
|
||||||
if err := retry.Retry(h.Env().Context(), "http", func(ctx context.Context) error {
|
if err := retry.Retry(ctx, "http", func(ctx context.Context) error {
|
||||||
var err error
|
var err error
|
||||||
if options.authenticateAs != "" {
|
if options.authenticateAs != "" {
|
||||||
resp, err = authenticateFlow(ctx, client, req, options.authenticateAs) //nolint:bodyclose
|
resp, err = authenticateFlow(ctx, client, req, options.authenticateAs) //nolint:bodyclose
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue