package config

import (
	"context"
	"crypto/tls"
	"net"
	"net/http"
	"sync"

	"github.com/pomerium/pomerium/internal/log"
	"github.com/pomerium/pomerium/internal/tripper"
	"github.com/pomerium/pomerium/pkg/cryptutil"
)

// NewHTTPTransport creates a new http transport whose TLS root CAs follow the
// CA / CAFile options of src. The pool comes from cryptutil.GetCertPool and is
// rebuilt every time src reports a config change. If building the pool fails,
// the error is logged and the previously loaded TLS config stays in effect.
//
// TLS connections are established by a custom DialTLSContext. Per the net/http
// docs, the transport's TLSClientConfig and TLSHandshakeTimeout are therefore
// not used for HTTPS requests.
func NewHTTPTransport(src Source) *http.Transport {
	// A custom DialTLSContext must do ALPN itself. Without offering "h2" here,
	// ForceAttemptHTTP2 below could never produce an HTTP/2 connection.
	alpnProtos := []string{"h2", "http/1.1"}

	var (
		lock sync.Mutex
		// Until a CA pool loads successfully, use a default config
		// (nil RootCAs selects the system roots) that still offers h2.
		tlsConfig = &tls.Config{
			MinVersion: tls.VersionTLS12,
			NextProtos: alpnProtos,
		}
	)
	update := func(ctx context.Context, cfg *Config) {
		rootCAs, err := cryptutil.GetCertPool(cfg.Options.CA, cfg.Options.CAFile)
		if err != nil {
			log.Error(ctx).Err(err).Msg("config: error getting cert pool")
			return
		}
		next := &tls.Config{
			RootCAs:    rootCAs,
			MinVersion: tls.VersionTLS12,
			NextProtos: alpnProtos,
		}
		lock.Lock()
		tlsConfig = next
		lock.Unlock()
	}
	src.OnConfigChange(context.Background(), update)
	update(context.Background(), src.GetConfig())

	transport := http.DefaultTransport.(*http.Transport).Clone()
	transport.DialTLSContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
		// Snapshot the current config under the lock. Configs are never
		// mutated after being published, so sharing them across dials is safe.
		lock.Lock()
		cfg := tlsConfig
		lock.Unlock()
		d := &tls.Dialer{Config: cfg}
		return d.DialContext(ctx, network, addr)
	}
	transport.ForceAttemptHTTP2 = true
	return transport
}

// NewPolicyHTTPTransport builds the http.RoundTripper used for a single policy.
//
// It starts from a clone of http.DefaultTransport and applies the policy's TLS
// settings: skip-verify, the global CA (options.CA / options.CAFile), a
// policy-specific CA, a client certificate and a server-name override. When
// disableHTTP2 is true, HTTP/2 is switched off on the cloned transport. The
// result is wrapped in an (empty) tripper chain.
func NewPolicyHTTPTransport(options *Options, policy *Policy, disableHTTP2 bool) http.RoundTripper {
	base := http.DefaultTransport.(interface {
		Clone() *http.Transport
	}).Clone()
	chain := tripper.NewChain()

	// The net/http docs say a client disables HTTP/2 by setting
	// Transport.TLSNextProto to a non-nil, empty map.
	if disableHTTP2 {
		base.TLSNextProto = map[string]func(authority string, c *tls.Conn) http.RoundTripper{}
		base.ForceAttemptHTTP2 = false
	}

	var (
		cfg        tls.Config
		customized bool
	)

	// loadRoots sets cfg.RootCAs from ca/caFile. On failure it logs errMsg and
	// leaves cfg untouched.
	loadRoots := func(ca, caFile, errMsg string) {
		pool, err := cryptutil.GetCertPool(ca, caFile)
		if err != nil {
			log.Error(context.TODO()).Err(err).Msg(errMsg)
			return
		}
		cfg.RootCAs = pool
		cfg.MinVersion = tls.VersionTLS12
		customized = true
	}

	if policy.TLSSkipVerify {
		cfg.InsecureSkipVerify = true
		customized = true
	}
	if options.CA != "" || options.CAFile != "" {
		loadRoots(options.CA, options.CAFile, "config: error getting ca cert pool")
	}
	// Loaded second, so a policy-level CA replaces the global one.
	if policy.TLSCustomCA != "" || policy.TLSCustomCAFile != "" {
		loadRoots(policy.TLSCustomCA, policy.TLSCustomCAFile, "config: error getting custom ca cert pool")
	}
	if cert := policy.ClientCertificate; cert != nil {
		cfg.Certificates = []tls.Certificate{*cert}
		customized = true
	}
	// TLSUpstreamServerName takes precedence over TLSServerName.
	switch {
	case policy.TLSUpstreamServerName != "":
		cfg.ServerName = policy.TLSUpstreamServerName
		customized = true
	case policy.TLSServerName != "":
		cfg.ServerName = policy.TLSServerName
		customized = true
	}

	// Only install a TLSClientConfig when something was customized. A nil
	// TLSClientConfig keeps the transport on its default TLS configuration.
	if customized {
		base.DialTLSContext = nil
		base.TLSClientConfig = &cfg
	}
	return chain.Then(base)
}

// GetTLSClientTransport returns an http transport that uses the TLS client
// configuration from cfg.GetTLSClientConfig, which accounts for custom CAs.
//
// The transport is a clone of http.DefaultTransport, so it keeps the default
// proxy-from-environment handling, dial and TLS-handshake timeouts, and idle
// connection limits. A bare http.Transport has none of these: a stalled
// handshake could block forever and idle connections would never be reaped.
//
// It returns the error from cfg.GetTLSClientConfig, if any.
func GetTLSClientTransport(cfg *Config) (*http.Transport, error) {
	tlsConfig, err := cfg.GetTLSClientConfig()
	if err != nil {
		return nil, err
	}
	transport := http.DefaultTransport.(*http.Transport).Clone()
	transport.TLSClientConfig = tlsConfig
	transport.ForceAttemptHTTP2 = true
	return transport, nil
}