config: fix DefaultTransport so it is still a *http.Transport (#3257)

* config: fix DefaultTransport so it is still a *http.Transport

* remove printlns

* Update config/http.go

Co-authored-by: Denis Mishin <dmishin@pomerium.com>

* remove unnecessary check

Co-authored-by: Denis Mishin <dmishin@pomerium.com>
This commit is contained in:
Caleb Doxsey 2022-04-08 11:07:37 -06:00 committed by GitHub
parent 424a7e4de8
commit c5550d28ed
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -3,63 +3,49 @@ package config
import ( import (
"context" "context"
"crypto/tls" "crypto/tls"
"net"
"net/http" "net/http"
"sync/atomic" "sync"
"github.com/pomerium/pomerium/internal/log" "github.com/pomerium/pomerium/internal/log"
"github.com/pomerium/pomerium/internal/tripper" "github.com/pomerium/pomerium/internal/tripper"
"github.com/pomerium/pomerium/pkg/cryptutil" "github.com/pomerium/pomerium/pkg/cryptutil"
"github.com/rs/zerolog"
) )
type httpTransport struct {
underlying *http.Transport
transport atomic.Value
}
// NewHTTPTransport creates a new http transport. If CA or CAFile is set, the transport will // NewHTTPTransport creates a new http transport. If CA or CAFile is set, the transport will
// add the CA to system cert pool. // add the CA to system cert pool.
func NewHTTPTransport(src Source) http.RoundTripper { func NewHTTPTransport(src Source) *http.Transport {
ctx := log.WithContext(context.TODO(), func(c zerolog.Context) zerolog.Context { var (
return c.Caller() lock sync.Mutex
}) tlsConfig *tls.Config
t := new(httpTransport) )
t.underlying, _ = http.DefaultTransport.(*http.Transport) update := func(ctx context.Context, cfg *Config) {
src.OnConfigChange(ctx, func(ctx context.Context, cfg *Config) { rootCAs, err := cryptutil.GetCertPool(cfg.Options.CA, cfg.Options.CAFile)
t.update(ctx, cfg.Options)
})
t.update(ctx, src.GetConfig().Options)
return t
}
func (t *httpTransport) update(ctx context.Context, options *Options) {
nt := new(http.Transport)
if t.underlying != nil {
nt = t.underlying.Clone()
}
if options.CA != "" || options.CAFile != "" {
rootCAs, err := cryptutil.GetCertPool(options.CA, options.CAFile)
if err == nil { if err == nil {
nt.TLSClientConfig = &tls.Config{ lock.Lock()
tlsConfig = &tls.Config{
RootCAs: rootCAs, RootCAs: rootCAs,
MinVersion: tls.VersionTLS12, MinVersion: tls.VersionTLS12,
} }
lock.Unlock()
} else { } else {
log.Error(ctx).Err(err).Msg("config: error getting cert pool") log.Error(ctx).Err(err).Msg("config: error getting cert pool")
} }
} }
t.transport.Store(nt) src.OnConfigChange(context.Background(), update)
} update(context.Background(), src.GetConfig())
// RoundTrip executes an HTTP request. transport := http.DefaultTransport.(*http.Transport).Clone()
func (t *httpTransport) RoundTrip(req *http.Request) (*http.Response, error) { transport.DialTLSContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
return t.transport.Load().(http.RoundTripper).RoundTrip(req) lock.Lock()
} d := &tls.Dialer{
Config: tlsConfig,
// Clone returns a clone of the transport. }
func (t *httpTransport) Clone() *http.Transport { lock.Unlock()
return t.transport.Load().(*http.Transport).Clone() return d.DialContext(ctx, network, addr)
}
transport.ForceAttemptHTTP2 = true
return transport
} }
// NewPolicyHTTPTransport creates a new http RoundTripper for a policy. // NewPolicyHTTPTransport creates a new http RoundTripper for a policy.