mirror of
https://github.com/pomerium/pomerium.git
synced 2025-06-01 18:33:19 +02:00
testenv: avoid dns lookups for *.localhost.pomerium.io (#5372)
* testenv: avoid dns lookups for localhost.pomerium.io * linter pass
This commit is contained in:
parent
55e13f9608
commit
39e789529e
4 changed files with 142 additions and 8 deletions
64
internal/testenv/dns.go
Normal file
64
internal/testenv/dns.go
Normal file
|
@ -0,0 +1,64 @@
|
|||
package testenv
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
const localDomainName = "localhost.pomerium.io"
|
||||
|
||||
type DialContextFunc = func(ctx context.Context, network string, addr string) (net.Conn, error)
|
||||
|
||||
var defaultDialer = &net.Dialer{
|
||||
Timeout: 30 * time.Second,
|
||||
KeepAlive: 30 * time.Second,
|
||||
Resolver: &net.Resolver{
|
||||
PreferGo: true,
|
||||
},
|
||||
}
|
||||
|
||||
func init() {
|
||||
http.DefaultTransport.(*http.Transport).DialContext = OverrideDialContext(defaultDialer.DialContext)
|
||||
}
|
||||
|
||||
func OverrideDialContext(defaultDialContext DialContextFunc) DialContextFunc {
|
||||
return func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
host, port, err := maybeSplitHostPort(addr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if strings.HasSuffix(host, localDomainName) {
|
||||
switch network {
|
||||
case "tcp", "tcp4", "udp", "udp4":
|
||||
host = "127.0.0.1"
|
||||
case "tcp6", "udp6":
|
||||
host = "::1"
|
||||
}
|
||||
}
|
||||
return defaultDialContext(ctx, network, net.JoinHostPort(host, port))
|
||||
}
|
||||
}
|
||||
|
||||
func maybeSplitHostPort(s string) (string, string, error) {
|
||||
if strings.Contains(s, ":") {
|
||||
return net.SplitHostPort(s)
|
||||
}
|
||||
return s, "", nil
|
||||
}
|
||||
|
||||
func GRPCContextDialer(ctx context.Context, target string) (net.Conn, error) {
|
||||
if strings.HasPrefix(target, "unix") {
|
||||
return defaultDialer.DialContext(ctx, "tcp", target)
|
||||
}
|
||||
host, port, err := net.SplitHostPort(target)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if strings.HasSuffix(host, localDomainName) {
|
||||
return defaultDialer.DialContext(ctx, "tcp", "127.0.0.1:"+port)
|
||||
}
|
||||
return defaultDialer.DialContext(ctx, "tcp", target)
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue