From 8924b1a5fc58bfc6aaa46bd89e87e732fc549dc5 Mon Sep 17 00:00:00 2001 From: Caleb Doxsey Date: Fri, 9 Apr 2021 12:26:46 -0600 Subject: [PATCH] config: use tls_custom_ca from policy if available (#2077) --- config/http.go | 13 ++++++++++++- config/http_test.go | 36 ++++++++++++++++++++++++++++++++++++ 2 files changed, 48 insertions(+), 1 deletion(-) diff --git a/config/http.go b/config/http.go index e5684ca32..2c18d3130 100644 --- a/config/http.go +++ b/config/http.go @@ -78,7 +78,18 @@ func NewPolicyHTTPTransport(options *Options, policy *Policy) http.RoundTripper tlsClientConfig.MinVersion = tls.VersionTLS12 isCustomClientConfig = true } else { - log.Error().Err(err).Msg("config: error getting cert pool") + log.Error().Err(err).Msg("config: error getting ca cert pool") + } + } + + if policy.TLSCustomCA != "" || policy.TLSCustomCAFile != "" { + rootCAs, err := cryptutil.GetCertPool(policy.TLSCustomCA, policy.TLSCustomCAFile) + if err == nil { + tlsClientConfig.RootCAs = rootCAs + tlsClientConfig.MinVersion = tls.VersionTLS12 + isCustomClientConfig = true + } else { + log.Error().Err(err).Msg("config: error getting custom ca cert pool") } } diff --git a/config/http_test.go b/config/http_test.go index cb22d943a..3fbfb3d57 100644 --- a/config/http_test.go +++ b/config/http_test.go @@ -45,3 +45,39 @@ func TestHTTPTransport(t *testing.T) { _, err := client.Get(s.URL) assert.NoError(t, err) } + +func TestPolicyHTTPTransport(t *testing.T) { + s := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + defer s.Close() + + get := func(options *Options, policy *Policy) (*http.Response, error) { + transport := NewPolicyHTTPTransport(options, policy) + client := &http.Client{ + Transport: transport, + } + return client.Get(s.URL) + } + + t.Run("default", func(t *testing.T) { + _, err := get(&Options{}, &Policy{}) + assert.Error(t, err) + }) + t.Run("skip verify", func(t *testing.T) { + _, err := get(&Options{}, &Policy{TLSSkipVerify: true}) + assert.NoError(t, err) + }) + t.Run("ca", func(t *testing.T) { + _, err := get(&Options{ + CA: base64.StdEncoding.EncodeToString([]byte(localCert)), + }, &Policy{}) + assert.NoError(t, err) + }) + t.Run("custom ca", func(t *testing.T) { + _, err := get(&Options{}, &Policy{ + TLSCustomCA: base64.StdEncoding.EncodeToString([]byte(localCert)), + }) + assert.NoError(t, err) + }) +}