testenv: avoid dns lookups for *.localhost.pomerium.io (#5372)

* testenv: avoid dns lookups for localhost.pomerium.io

* linter pass
This commit is contained in:
Joe Kralicky 2024-12-02 12:29:15 -05:00 committed by GitHub
parent 55e13f9608
commit 39e789529e
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 142 additions and 8 deletions

64
internal/testenv/dns.go Normal file
View file

@ -0,0 +1,64 @@
package testenv
import (
"context"
"net"
"net/http"
"strings"
"time"
)
const localDomainName = "localhost.pomerium.io"
type DialContextFunc = func(ctx context.Context, network string, addr string) (net.Conn, error)
var defaultDialer = &net.Dialer{
Timeout: 30 * time.Second,
KeepAlive: 30 * time.Second,
Resolver: &net.Resolver{
PreferGo: true,
},
}
func init() {
http.DefaultTransport.(*http.Transport).DialContext = OverrideDialContext(defaultDialer.DialContext)
}
func OverrideDialContext(defaultDialContext DialContextFunc) DialContextFunc {
return func(ctx context.Context, network, addr string) (net.Conn, error) {
host, port, err := maybeSplitHostPort(addr)
if err != nil {
return nil, err
}
if strings.HasSuffix(host, localDomainName) {
switch network {
case "tcp", "tcp4", "udp", "udp4":
host = "127.0.0.1"
case "tcp6", "udp6":
host = "::1"
}
}
return defaultDialContext(ctx, network, net.JoinHostPort(host, port))
}
}
func maybeSplitHostPort(s string) (string, string, error) {
if strings.Contains(s, ":") {
return net.SplitHostPort(s)
}
return s, "", nil
}
func GRPCContextDialer(ctx context.Context, target string) (net.Conn, error) {
if strings.HasPrefix(target, "unix") {
return defaultDialer.DialContext(ctx, "tcp", target)
}
host, port, err := net.SplitHostPort(target)
if err != nil {
return nil, err
}
if strings.HasSuffix(host, localDomainName) {
return defaultDialer.DialContext(ctx, "tcp", "127.0.0.1:"+port)
}
return defaultDialer.DialContext(ctx, "tcp", target)
}