package testenv

import (
	"context"
	"net"
	"net/http"
	"strings"
	"time"
)

// localDomainName is the domain suffix that the dialers in this file redirect
// to the loopback interface instead of resolving.
const localDomainName = "localhost.pomerium.io"

// DialContextFunc is the dial function signature used by net.Dialer.DialContext
// and http.Transport.DialContext.
type DialContextFunc = func(ctx context.Context, network string, addr string) (net.Conn, error)

// defaultDialer is the dialer used for every connection that is not handed to
// a caller-supplied dial function. It uses a 30 second connect timeout, a
// 30 second keep-alive period, and prefers Go's built-in resolver.
var defaultDialer = &net.Dialer{
	Timeout:   30 * time.Second,
	KeepAlive: 30 * time.Second,
	Resolver: &net.Resolver{
		PreferGo: true,
	},
}

// init installs the loopback-rewriting dialer on http.DefaultTransport.
// The type assertion is unchecked, so this panics at package initialization
// if http.DefaultTransport has been replaced by something other than an
// *http.Transport.
func init() {
	http.DefaultTransport.(*http.Transport).DialContext = OverrideDialContext(defaultDialer.DialContext)
}

// OverrideDialContext wraps defaultDialContext so that TCP and UDP dials to a
// host ending in localDomainName are sent to the loopback address instead:
// 127.0.0.1 for "tcp", "tcp4", "udp" and "udp4", and ::1 for "tcp6" and "udp6".
//
// Every other dial is forwarded with its address unchanged. That includes
// networks whose addresses are not host:port pairs (for example "unix") and
// hosts outside localDomainName. A malformed host:port address on a TCP or UDP
// network is reported as an error without calling defaultDialContext.
func OverrideDialContext(defaultDialContext DialContextFunc) DialContextFunc {
	return func(ctx context.Context, network, addr string) (net.Conn, error) {
		var loopback string
		switch network {
		case "tcp", "tcp4", "udp", "udp4":
			loopback = "127.0.0.1"
		case "tcp6", "udp6":
			loopback = "::1"
		default:
			// Not a host:port network (e.g. a unix socket path); parsing or
			// rebuilding the address here would only corrupt it.
			return defaultDialContext(ctx, network, addr)
		}

		host, port, err := maybeSplitHostPort(addr)
		if err != nil {
			return nil, err
		}
		if !strings.HasSuffix(host, localDomainName) {
			return defaultDialContext(ctx, network, addr)
		}
		if port == "" {
			// net.JoinHostPort would append a stray ":" for an empty port.
			return defaultDialContext(ctx, network, loopback)
		}
		return defaultDialContext(ctx, network, net.JoinHostPort(loopback, port))
	}
}

// maybeSplitHostPort splits s into host and port when it contains a colon.
// A string without a colon is returned as the host with an empty port, so
// port-less addresses are tolerated rather than rejected.
func maybeSplitHostPort(s string) (string, string, error) {
	if strings.Contains(s, ":") {
		return net.SplitHostPort(s)
	}
	return s, "", nil
}

// GRPCContextDialer dials target using defaultDialer. It has the signature
// expected by gRPC's context-dialer option.
//
// Targets of the form "unix:path" or "unix:///absolute/path" are dialed as
// unix domain sockets. All other targets must be host:port pairs; when the
// host ends in localDomainName the connection is made to 127.0.0.1 on the
// same port, otherwise the target is dialed over TCP as given.
func GRPCContextDialer(ctx context.Context, target string) (net.Conn, error) {
	if strings.HasPrefix(target, "unix:") {
		// Strip the scheme: "unix:///abs" -> "/abs", "unix:rel" -> "rel".
		path := strings.TrimPrefix(target, "unix:")
		path = strings.TrimPrefix(path, "//")
		return defaultDialer.DialContext(ctx, "unix", path)
	}

	host, port, err := net.SplitHostPort(target)
	if err != nil {
		return nil, err
	}
	if strings.HasSuffix(host, localDomainName) {
		return defaultDialer.DialContext(ctx, "tcp", "127.0.0.1:"+port)
	}
	return defaultDialer.DialContext(ctx, "tcp", target)
}