mirror of
https://github.com/pomerium/pomerium.git
synced 2025-04-29 02:16:28 +02:00
64 lines
1.5 KiB
Go
64 lines
1.5 KiB
Go
package testenv
|
|
|
|
import (
|
|
"context"
|
|
"net"
|
|
"net/http"
|
|
"strings"
|
|
"time"
|
|
)
|
|
|
|
// localDomainName is the host suffix that the dialers in this file redirect
// to the loopback address instead of dialing the named host.
const localDomainName = "localhost.pomerium.io"
|
|
|
|
// DialContextFunc is the signature of a context-aware dial function; it has
// the same shape as the DialContext method of net.Dialer.
type DialContextFunc = func(ctx context.Context, network string, addr string) (net.Conn, error)
|
|
|
|
// defaultDialer is the dialer shared by the helpers in this file. It uses a
// 30 second connect timeout, a 30 second keep-alive period, and a resolver
// with PreferGo set.
var defaultDialer = &net.Dialer{
	Timeout:   30 * time.Second,
	KeepAlive: 30 * time.Second,
	Resolver:  &net.Resolver{PreferGo: true},
}
|
|
|
|
func init() {
|
|
http.DefaultTransport.(*http.Transport).DialContext = OverrideDialContext(defaultDialer.DialContext)
|
|
}
|
|
|
|
func OverrideDialContext(defaultDialContext DialContextFunc) DialContextFunc {
|
|
return func(ctx context.Context, network, addr string) (net.Conn, error) {
|
|
host, port, err := maybeSplitHostPort(addr)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if strings.HasSuffix(host, localDomainName) {
|
|
switch network {
|
|
case "tcp", "tcp4", "udp", "udp4":
|
|
host = "127.0.0.1"
|
|
case "tcp6", "udp6":
|
|
host = "::1"
|
|
}
|
|
}
|
|
return defaultDialContext(ctx, network, net.JoinHostPort(host, port))
|
|
}
|
|
}
|
|
|
|
// maybeSplitHostPort splits s into host and port when s contains a colon,
// delegating to net.SplitHostPort (and returning its error, if any). A string
// without any colon is returned unchanged as the host, with an empty port.
func maybeSplitHostPort(s string) (string, string, error) {
	if strings.IndexByte(s, ':') < 0 {
		// No colon at all: treat the whole string as a bare host.
		return s, "", nil
	}
	host, port, err := net.SplitHostPort(s)
	return host, port, err
}
|
|
|
|
func GRPCContextDialer(ctx context.Context, target string) (net.Conn, error) {
|
|
if strings.HasPrefix(target, "unix") {
|
|
return defaultDialer.DialContext(ctx, "tcp", target)
|
|
}
|
|
host, port, err := net.SplitHostPort(target)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if strings.HasSuffix(host, localDomainName) {
|
|
return defaultDialer.DialContext(ctx, "tcp", "127.0.0.1:"+port)
|
|
}
|
|
return defaultDialer.DialContext(ctx, "tcp", target)
|
|
}
|