From ccdd1e5586c4c6fc935445f8654eef7b20fb0df0 Mon Sep 17 00:00:00 2001 From: Caleb Doxsey Date: Wed, 4 Nov 2020 15:35:10 -0700 Subject: [PATCH] use custom default http transport (#1576) * use custom default http transport * Update config/http.go Co-authored-by: bobby <1544881+desimone@users.noreply.github.com> * Update config/http.go Co-authored-by: bobby <1544881+desimone@users.noreply.github.com> * return early Co-authored-by: bobby <1544881+desimone@users.noreply.github.com> --- config/http.go | 50 +++++++++++++++++++++++++++++++ config/http_test.go | 47 +++++++++++++++++++++++++++++ internal/cmd/pomerium/pomerium.go | 4 +++ pkg/cryptutil/tls.go | 35 ++++++++++++++++++++++ pkg/grpc/client.go | 28 ++--------------- 5 files changed, 139 insertions(+), 25 deletions(-) create mode 100644 config/http.go create mode 100644 config/http_test.go diff --git a/config/http.go b/config/http.go new file mode 100644 index 000000000..83267c238 --- /dev/null +++ b/config/http.go @@ -0,0 +1,50 @@ +package config + +import ( + "crypto/tls" + "net/http" + "sync/atomic" + + "github.com/pomerium/pomerium/internal/log" + "github.com/pomerium/pomerium/pkg/cryptutil" +) + +type httpTransport struct { + underlying *http.Transport + transport atomic.Value +} + +// NewHTTPTransport creates a new http transport. If CA or CAFile is set, the transport will +// add the CA to system cert pool. +func NewHTTPTransport(src Source) http.RoundTripper { + t := new(httpTransport) + t.underlying, _ = http.DefaultTransport.(*http.Transport) + src.OnConfigChange(func(cfg *Config) { + t.update(cfg.Options) + }) + t.update(src.GetConfig().Options) + return t +} + +func (t *httpTransport) update(options *Options) { + nt := new(http.Transport) + if t.underlying != nil { + nt = t.underlying.Clone() + } + if options.CA != "" || options.CAFile != "" { + rootCAs, err := cryptutil.GetCertPool(options.CA, options.CAFile) + if err == nil { + nt.TLSClientConfig = &tls.Config{ + RootCAs: rootCAs, + } + } else { + log.Error().Err(err).Msg("config: error getting cert pool") + } + } + t.transport.Store(nt) +} + +// RoundTrip executes an HTTP request. +func (t *httpTransport) RoundTrip(req *http.Request) (*http.Response, error) { + return t.transport.Load().(http.RoundTripper).RoundTrip(req) +} diff --git a/config/http_test.go b/config/http_test.go new file mode 100644 index 000000000..cb22d943a --- /dev/null +++ b/config/http_test.go @@ -0,0 +1,47 @@ +package config + +import ( + "encoding/base64" + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/assert" +) + +// this cert is the cert used by httptest when creating a TLS server +var localCert = ` +-----BEGIN CERTIFICATE----- +MIICEzCCAXygAwIBAgIQMIMChMLGrR+QvmQvpwAU6zANBgkqhkiG9w0BAQsFADAS +MRAwDgYDVQQKEwdBY21lIENvMCAXDTcwMDEwMTAwMDAwMFoYDzIwODQwMTI5MTYw +MDAwWjASMRAwDgYDVQQKEwdBY21lIENvMIGfMA0GCSqGSIb3DQEBAQUAA4GNADCB +iQKBgQDuLnQAI3mDgey3VBzWnB2L39JUU4txjeVE6myuDqkM/uGlfjb9SjY1bIw4 +iA5sBBZzHi3z0h1YV8QPuxEbi4nW91IJm2gsvvZhIrCHS3l6afab4pZBl2+XsDul +rKBxKKtD1rGxlG4LjncdabFn9gvLZad2bSysqz/qTAUStTvqJQIDAQABo2gwZjAO +BgNVHQ8BAf8EBAMCAqQwEwYDVR0lBAwwCgYIKwYBBQUHAwEwDwYDVR0TAQH/BAUw +AwEB/zAuBgNVHREEJzAlggtleGFtcGxlLmNvbYcEfwAAAYcQAAAAAAAAAAAAAAAA +AAAAATANBgkqhkiG9w0BAQsFAAOBgQCEcetwO59EWk7WiJsG4x8SY+UIAA+flUI9 +tyC4lNhbcF2Idq9greZwbYCqTTTr2XiRNSMLCOjKyI7ukPoPjo16ocHj+P3vZGfs +h1fIw3cSS2OolhloGw/XM6RWPWtPAlGykKLciQrBru5NAPvCMsb/I1DAceTiotQM +fblo6RBxUQ== +-----END CERTIFICATE----- +` + +func TestHTTPTransport(t *testing.T) { + s := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + defer s.Close() + + src := NewStaticSource(&Config{ + Options: &Options{ + CA: base64.StdEncoding.EncodeToString([]byte(localCert)), + }, + }) + transport := NewHTTPTransport(src) + client := &http.Client{ + Transport: transport, + } + _, err := client.Get(s.URL) + assert.NoError(t, err) +} diff --git a/internal/cmd/pomerium/pomerium.go b/internal/cmd/pomerium/pomerium.go index 77472bc2f..a7a679158 100644 --- a/internal/cmd/pomerium/pomerium.go +++ b/internal/cmd/pomerium/pomerium.go @@ -6,6 +6,7 @@ import ( "context" "fmt" "net" + "net/http" "os" "os/signal" "syscall" @@ -45,6 +46,9 @@ func Run(ctx context.Context, configFile string) error { src = databroker.NewConfigSource(src) + // override the default http transport so we can use the custom CA in the TLS client config (#1570) + http.DefaultTransport = config.NewHTTPTransport(src) + logMgr := config.NewLogManager(src) defer logMgr.Close() metricsMgr := config.NewMetricsManager(src) diff --git a/pkg/cryptutil/tls.go b/pkg/cryptutil/tls.go index 4b705df99..185240ed4 100644 --- a/pkg/cryptutil/tls.go +++ b/pkg/cryptutil/tls.go @@ -3,10 +3,45 @@ package cryptutil import ( "crypto/tls" "crypto/x509" + "encoding/base64" + "fmt" + "io/ioutil" "github.com/caddyserver/certmagic" + + "github.com/pomerium/pomerium/internal/log" ) +// GetCertPool gets a cert pool for the given CA or CAFile. +func GetCertPool(ca, caFile string) (*x509.CertPool, error) { + rootCAs, err := x509.SystemCertPool() + if err != nil { + log.Error().Msg("pkg/cryptutil: failed getting system cert pool making new one") + rootCAs = x509.NewCertPool() + } + if ca == "" && caFile == "" { + return rootCAs, nil + } + + var data []byte + if ca != "" { + data, err = base64.StdEncoding.DecodeString(ca) + if err != nil { + return nil, fmt.Errorf("failed to decode certificate authority: %w", err) + } + } else { + data, err = ioutil.ReadFile(caFile) + if err != nil { + return nil, fmt.Errorf("certificate authority file %v not readable: %w", caFile, err) + } + } + if ok := rootCAs.AppendCertsFromPEM(data); !ok { + return nil, fmt.Errorf("failed to append CA cert to certPool") + } + log.Debug().Msg("pkg/cryptutil: added custom certificate authority") + return rootCAs, nil +} + // GetCertificateForDomain returns the tls Certificate which matches the given domain name. // It should handle both exact matches and wildcard matches. If none of those match, the first certificate will be used. // Finally if there are no matching certificates one will be generated. diff --git a/pkg/grpc/client.go b/pkg/grpc/client.go index 63d678666..33b17dac4 100644 --- a/pkg/grpc/client.go +++ b/pkg/grpc/client.go @@ -3,11 +3,8 @@ package grpc import ( "context" "crypto/tls" - "crypto/x509" - "encoding/base64" "errors" "fmt" - "io/ioutil" "net" "net/url" "strconv" @@ -22,6 +19,7 @@ import ( "github.com/pomerium/pomerium/internal/log" "github.com/pomerium/pomerium/internal/telemetry" "github.com/pomerium/pomerium/internal/telemetry/requestid" + "github.com/pomerium/pomerium/pkg/cryptutil" "github.com/pomerium/pomerium/pkg/grpcutil" ) @@ -100,29 +98,9 @@ func NewGRPCClientConn(opts *Options) (*grpc.ClientConn, error) { log.Info().Str("addr", connAddr).Msg("internal/grpc: grpc with insecure") dialOptions = append(dialOptions, grpc.WithInsecure()) } else { - rootCAs, err := x509.SystemCertPool() + rootCAs, err := cryptutil.GetCertPool(opts.CA, opts.CAFile) if err != nil { - log.Warn().Msg("internal/grpc: failed getting system cert pool making new one") - rootCAs = x509.NewCertPool() - } - if opts.CA != "" || opts.CAFile != "" { - var ca []byte - var err error - if opts.CA != "" { - ca, err = base64.StdEncoding.DecodeString(opts.CA) - if err != nil { - return nil, fmt.Errorf("failed to decode certificate authority: %w", err) - } - } else { - ca, err = ioutil.ReadFile(opts.CAFile) - if err != nil { - return nil, fmt.Errorf("certificate authority file %v not readable: %w", opts.CAFile, err) - } - } - if ok := rootCAs.AppendCertsFromPEM(ca); !ok { - return nil, fmt.Errorf("failed to append CA cert to certPool") - } - log.Debug().Msg("internal/grpc: added custom certificate authority") + return nil, err } cert := credentials.NewTLS(&tls.Config{RootCAs: rootCAs})