mirror of
https://github.com/pomerium/pomerium.git
synced 2025-04-28 09:56:31 +02:00
131 lines
3.7 KiB
Go
131 lines
3.7 KiB
Go
package config
|
|
|
|
import (
|
|
"context"
|
|
"crypto/tls"
|
|
"net"
|
|
"net/http"
|
|
"sync"
|
|
|
|
"github.com/pomerium/pomerium/internal/log"
|
|
"github.com/pomerium/pomerium/internal/tripper"
|
|
"github.com/pomerium/pomerium/pkg/cryptutil"
|
|
)
|
|
|
|
// NewHTTPTransport creates a new http transport. If CA or CAFile is set, the transport will
|
|
// add the CA to system cert pool.
|
|
func NewHTTPTransport(src Source) *http.Transport {
|
|
var (
|
|
lock sync.Mutex
|
|
tlsConfig *tls.Config
|
|
)
|
|
update := func(ctx context.Context, cfg *Config) {
|
|
rootCAs, err := cryptutil.GetCertPool(cfg.Options.CA, cfg.Options.CAFile)
|
|
if err == nil {
|
|
lock.Lock()
|
|
tlsConfig = &tls.Config{
|
|
RootCAs: rootCAs,
|
|
MinVersion: tls.VersionTLS12,
|
|
}
|
|
lock.Unlock()
|
|
} else {
|
|
log.Error(ctx).Err(err).Msg("config: error getting cert pool")
|
|
}
|
|
}
|
|
src.OnConfigChange(context.Background(), update)
|
|
update(context.Background(), src.GetConfig())
|
|
|
|
transport := http.DefaultTransport.(*http.Transport).Clone()
|
|
transport.DialTLSContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
|
|
lock.Lock()
|
|
d := &tls.Dialer{
|
|
Config: tlsConfig,
|
|
}
|
|
lock.Unlock()
|
|
return d.DialContext(ctx, network, addr)
|
|
}
|
|
transport.ForceAttemptHTTP2 = true
|
|
return transport
|
|
}
|
|
|
|
// NewPolicyHTTPTransport creates a new http RoundTripper for a policy.
|
|
func NewPolicyHTTPTransport(options *Options, policy *Policy, disableHTTP2 bool) http.RoundTripper {
|
|
transport := http.DefaultTransport.(interface {
|
|
Clone() *http.Transport
|
|
}).Clone()
|
|
c := tripper.NewChain()
|
|
|
|
// according to the docs:
|
|
//
|
|
// Programs that must disable HTTP/2 can do so by setting Transport.TLSNextProto (for clients) or
|
|
// Server.TLSNextProto (for servers) to a non-nil, empty map.
|
|
//
|
|
if disableHTTP2 {
|
|
transport.TLSNextProto = map[string]func(authority string, c *tls.Conn) http.RoundTripper{}
|
|
transport.ForceAttemptHTTP2 = false
|
|
}
|
|
|
|
var tlsClientConfig tls.Config
|
|
var isCustomClientConfig bool
|
|
|
|
if policy.TLSSkipVerify {
|
|
tlsClientConfig.InsecureSkipVerify = true
|
|
isCustomClientConfig = true
|
|
}
|
|
|
|
if options.CA != "" || options.CAFile != "" {
|
|
rootCAs, err := cryptutil.GetCertPool(options.CA, options.CAFile)
|
|
if err == nil {
|
|
tlsClientConfig.RootCAs = rootCAs
|
|
tlsClientConfig.MinVersion = tls.VersionTLS12
|
|
isCustomClientConfig = true
|
|
} else {
|
|
log.Error(context.TODO()).Err(err).Msg("config: error getting ca cert pool")
|
|
}
|
|
}
|
|
|
|
if policy.TLSCustomCA != "" || policy.TLSCustomCAFile != "" {
|
|
rootCAs, err := cryptutil.GetCertPool(policy.TLSCustomCA, policy.TLSCustomCAFile)
|
|
if err == nil {
|
|
tlsClientConfig.RootCAs = rootCAs
|
|
tlsClientConfig.MinVersion = tls.VersionTLS12
|
|
isCustomClientConfig = true
|
|
} else {
|
|
log.Error(context.TODO()).Err(err).Msg("config: error getting custom ca cert pool")
|
|
}
|
|
}
|
|
|
|
if policy.ClientCertificate != nil {
|
|
tlsClientConfig.Certificates = []tls.Certificate{*policy.ClientCertificate}
|
|
isCustomClientConfig = true
|
|
}
|
|
|
|
if policy.TLSServerName != "" {
|
|
tlsClientConfig.ServerName = policy.TLSServerName
|
|
isCustomClientConfig = true
|
|
}
|
|
if policy.TLSUpstreamServerName != "" {
|
|
tlsClientConfig.ServerName = policy.TLSUpstreamServerName
|
|
isCustomClientConfig = true
|
|
}
|
|
|
|
// We avoid setting a custom client config unless we have to as
|
|
// if TLSClientConfig is nil, the default configuration is used.
|
|
if isCustomClientConfig {
|
|
transport.DialTLSContext = nil
|
|
transport.TLSClientConfig = &tlsClientConfig
|
|
}
|
|
return c.Then(transport)
|
|
}
|
|
|
|
// GetTLSClientTransport returns http transport accounting for custom CAs from config
|
|
func GetTLSClientTransport(cfg *Config) (*http.Transport, error) {
|
|
tlsConfig, err := cfg.GetTLSClientConfig()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return &http.Transport{
|
|
TLSClientConfig: tlsConfig,
|
|
ForceAttemptHTTP2: true,
|
|
}, nil
|
|
}
|