mirror of
https://github.com/pomerium/pomerium.git
synced 2025-06-02 02:42:57 +02:00
use custom default http transport (#1576)
* use custom default http transport * Update config/http.go Co-authored-by: bobby <1544881+desimone@users.noreply.github.com> * Update config/http.go Co-authored-by: bobby <1544881+desimone@users.noreply.github.com> * return early Co-authored-by: bobby <1544881+desimone@users.noreply.github.com>
This commit is contained in:
parent
1910125e6f
commit
ccdd1e5586
5 changed files with 139 additions and 25 deletions
50
config/http.go
Normal file
50
config/http.go
Normal file
|
@ -0,0 +1,50 @@
|
|||
package config
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"net/http"
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/log"
|
||||
"github.com/pomerium/pomerium/pkg/cryptutil"
|
||||
)
|
||||
|
||||
type httpTransport struct {
|
||||
underlying *http.Transport
|
||||
transport atomic.Value
|
||||
}
|
||||
|
||||
// NewHTTPTransport creates a new http transport. If CA or CAFile is set, the transport will
|
||||
// add the CA to system cert pool.
|
||||
func NewHTTPTransport(src Source) http.RoundTripper {
|
||||
t := new(httpTransport)
|
||||
t.underlying, _ = http.DefaultTransport.(*http.Transport)
|
||||
src.OnConfigChange(func(cfg *Config) {
|
||||
t.update(cfg.Options)
|
||||
})
|
||||
t.update(src.GetConfig().Options)
|
||||
return t
|
||||
}
|
||||
|
||||
func (t *httpTransport) update(options *Options) {
|
||||
nt := new(http.Transport)
|
||||
if t.underlying != nil {
|
||||
nt = t.underlying.Clone()
|
||||
}
|
||||
if options.CA != "" || options.CAFile != "" {
|
||||
rootCAs, err := cryptutil.GetCertPool(options.CA, options.CAFile)
|
||||
if err == nil {
|
||||
nt.TLSClientConfig = &tls.Config{
|
||||
RootCAs: rootCAs,
|
||||
}
|
||||
} else {
|
||||
log.Error().Err(err).Msg("config: error getting cert pool")
|
||||
}
|
||||
}
|
||||
t.transport.Store(nt)
|
||||
}
|
||||
|
||||
// RoundTrip executes an HTTP request.
|
||||
func (t *httpTransport) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
return t.transport.Load().(http.RoundTripper).RoundTrip(req)
|
||||
}
|
47
config/http_test.go
Normal file
47
config/http_test.go
Normal file
|
@ -0,0 +1,47 @@
|
|||
package config
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
// this cert is the cert used by httptest when creating a TLS server
|
||||
var localCert = `
|
||||
-----BEGIN CERTIFICATE-----
|
||||
MIICEzCCAXygAwIBAgIQMIMChMLGrR+QvmQvpwAU6zANBgkqhkiG9w0BAQsFADAS
|
||||
MRAwDgYDVQQKEwdBY21lIENvMCAXDTcwMDEwMTAwMDAwMFoYDzIwODQwMTI5MTYw
|
||||
MDAwWjASMRAwDgYDVQQKEwdBY21lIENvMIGfMA0GCSqGSIb3DQEBAQUAA4GNADCB
|
||||
iQKBgQDuLnQAI3mDgey3VBzWnB2L39JUU4txjeVE6myuDqkM/uGlfjb9SjY1bIw4
|
||||
iA5sBBZzHi3z0h1YV8QPuxEbi4nW91IJm2gsvvZhIrCHS3l6afab4pZBl2+XsDul
|
||||
rKBxKKtD1rGxlG4LjncdabFn9gvLZad2bSysqz/qTAUStTvqJQIDAQABo2gwZjAO
|
||||
BgNVHQ8BAf8EBAMCAqQwEwYDVR0lBAwwCgYIKwYBBQUHAwEwDwYDVR0TAQH/BAUw
|
||||
AwEB/zAuBgNVHREEJzAlggtleGFtcGxlLmNvbYcEfwAAAYcQAAAAAAAAAAAAAAAA
|
||||
AAAAATANBgkqhkiG9w0BAQsFAAOBgQCEcetwO59EWk7WiJsG4x8SY+UIAA+flUI9
|
||||
tyC4lNhbcF2Idq9greZwbYCqTTTr2XiRNSMLCOjKyI7ukPoPjo16ocHj+P3vZGfs
|
||||
h1fIw3cSS2OolhloGw/XM6RWPWtPAlGykKLciQrBru5NAPvCMsb/I1DAceTiotQM
|
||||
fblo6RBxUQ==
|
||||
-----END CERTIFICATE-----
|
||||
`
|
||||
|
||||
func TestHTTPTransport(t *testing.T) {
|
||||
s := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
defer s.Close()
|
||||
|
||||
src := NewStaticSource(&Config{
|
||||
Options: &Options{
|
||||
CA: base64.StdEncoding.EncodeToString([]byte(localCert)),
|
||||
},
|
||||
})
|
||||
transport := NewHTTPTransport(src)
|
||||
client := &http.Client{
|
||||
Transport: transport,
|
||||
}
|
||||
_, err := client.Get(s.URL)
|
||||
assert.NoError(t, err)
|
||||
}
|
|
@ -6,6 +6,7 @@ import (
|
|||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/signal"
|
||||
"syscall"
|
||||
|
@ -45,6 +46,9 @@ func Run(ctx context.Context, configFile string) error {
|
|||
|
||||
src = databroker.NewConfigSource(src)
|
||||
|
||||
// override the default http transport so we can use the custom CA in the TLS client config (#1570)
|
||||
http.DefaultTransport = config.NewHTTPTransport(src)
|
||||
|
||||
logMgr := config.NewLogManager(src)
|
||||
defer logMgr.Close()
|
||||
metricsMgr := config.NewMetricsManager(src)
|
||||
|
|
|
@ -3,10 +3,45 @@ package cryptutil
|
|||
import (
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
|
||||
"github.com/caddyserver/certmagic"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/log"
|
||||
)
|
||||
|
||||
// GetCertPool gets a cert pool for the given CA or CAFile.
|
||||
func GetCertPool(ca, caFile string) (*x509.CertPool, error) {
|
||||
rootCAs, err := x509.SystemCertPool()
|
||||
if err != nil {
|
||||
log.Error().Msg("pkg/cryptutil: failed getting system cert pool making new one")
|
||||
rootCAs = x509.NewCertPool()
|
||||
}
|
||||
if ca == "" && caFile == "" {
|
||||
return rootCAs, nil
|
||||
}
|
||||
|
||||
var data []byte
|
||||
if ca != "" {
|
||||
data, err = base64.StdEncoding.DecodeString(ca)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to decode certificate authority: %w", err)
|
||||
}
|
||||
} else {
|
||||
data, err = ioutil.ReadFile(caFile)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("certificate authority file %v not readable: %w", caFile, err)
|
||||
}
|
||||
}
|
||||
if ok := rootCAs.AppendCertsFromPEM(data); !ok {
|
||||
return nil, fmt.Errorf("failed to append CA cert to certPool")
|
||||
}
|
||||
log.Debug().Msg("pkg/cryptutil: added custom certificate authority")
|
||||
return rootCAs, nil
|
||||
}
|
||||
|
||||
// GetCertificateForDomain returns the tls Certificate which matches the given domain name.
|
||||
// It should handle both exact matches and wildcard matches. If none of those match, the first certificate will be used.
|
||||
// Finally if there are no matching certificates one will be generated.
|
||||
|
|
|
@ -3,11 +3,8 @@ package grpc
|
|||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net"
|
||||
"net/url"
|
||||
"strconv"
|
||||
|
@ -22,6 +19,7 @@ import (
|
|||
"github.com/pomerium/pomerium/internal/log"
|
||||
"github.com/pomerium/pomerium/internal/telemetry"
|
||||
"github.com/pomerium/pomerium/internal/telemetry/requestid"
|
||||
"github.com/pomerium/pomerium/pkg/cryptutil"
|
||||
"github.com/pomerium/pomerium/pkg/grpcutil"
|
||||
)
|
||||
|
||||
|
@ -100,29 +98,9 @@ func NewGRPCClientConn(opts *Options) (*grpc.ClientConn, error) {
|
|||
log.Info().Str("addr", connAddr).Msg("internal/grpc: grpc with insecure")
|
||||
dialOptions = append(dialOptions, grpc.WithInsecure())
|
||||
} else {
|
||||
rootCAs, err := x509.SystemCertPool()
|
||||
rootCAs, err := cryptutil.GetCertPool(opts.CA, opts.CAFile)
|
||||
if err != nil {
|
||||
log.Warn().Msg("internal/grpc: failed getting system cert pool making new one")
|
||||
rootCAs = x509.NewCertPool()
|
||||
}
|
||||
if opts.CA != "" || opts.CAFile != "" {
|
||||
var ca []byte
|
||||
var err error
|
||||
if opts.CA != "" {
|
||||
ca, err = base64.StdEncoding.DecodeString(opts.CA)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to decode certificate authority: %w", err)
|
||||
}
|
||||
} else {
|
||||
ca, err = ioutil.ReadFile(opts.CAFile)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("certificate authority file %v not readable: %w", opts.CAFile, err)
|
||||
}
|
||||
}
|
||||
if ok := rootCAs.AppendCertsFromPEM(ca); !ok {
|
||||
return nil, fmt.Errorf("failed to append CA cert to certPool")
|
||||
}
|
||||
log.Debug().Msg("internal/grpc: added custom certificate authority")
|
||||
return nil, err
|
||||
}
|
||||
|
||||
cert := credentials.NewTLS(&tls.Config{RootCAs: rootCAs})
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue