diff --git a/config/envoyconfig/listeners.go b/config/envoyconfig/listeners.go index 4ce3d2fac..c19cffad4 100644 --- a/config/envoyconfig/listeners.go +++ b/config/envoyconfig/listeners.go @@ -6,6 +6,7 @@ import ( "fmt" "net" "net/url" + "strings" "time" envoy_config_core_v3 "github.com/envoyproxy/go-control-plane/envoy/config/core/v3" @@ -118,7 +119,7 @@ func (b *Builder) buildMainListener(ctx context.Context, cfg *config.Config) (*e } listenerFilters = append(listenerFilters, TLSInspectorFilter()) - chains, err := b.buildFilterChains(cfg.Options, cfg.Options.Addr, + chains, err := b.buildFilterChains(cfg, cfg.Options.Addr, func(tlsDomain string, httpDomains []string) (*envoy_config_listener_v3.FilterChain, error) { filter, err := b.buildMainHTTPConnectionManagerFilter(cfg.Options, httpDomains, tlsDomain) if err != nil { @@ -235,15 +236,15 @@ func (b *Builder) buildMetricsListener(cfg *config.Config) (*envoy_config_listen } func (b *Builder) buildFilterChains( - options *config.Options, addr string, + cfg *config.Config, addr string, callback func(tlsDomain string, httpDomains []string) (*envoy_config_listener_v3.FilterChain, error), ) ([]*envoy_config_listener_v3.FilterChain, error) { - allDomains, err := getAllRouteableDomains(options, addr) + allDomains, err := getAllRouteableDomains(cfg.Options, addr) if err != nil { return nil, err } - tlsDomains, err := getAllTLSDomains(options, addr) + tlsDomains, err := getAllTLSDomains(cfg, addr) if err != nil { return nil, err } @@ -251,7 +252,7 @@ func (b *Builder) buildFilterChains( var chains []*envoy_config_listener_v3.FilterChain chains = append(chains, b.buildACMETLSALPNFilterChain()) for _, domain := range tlsDomains { - routeableDomains, err := getRouteableDomainsForTLSServerName(options, addr, domain) + routeableDomains, err := getRouteableDomainsForTLSServerName(cfg.Options, addr, domain) if err != nil { return nil, err } @@ -341,7 +342,9 @@ func (b *Builder) buildMainHTTPConnectionManagerFilter( LuaFilter(luascripts.CleanUpstream), LuaFilter(luascripts.RewriteHeaders), } - if tlsDomain != "" && tlsDomain != "*" { + // only return 421s for non-wildcard domains because the lua script doesn't understand how to + // parse wildcards properly + if tlsDomain != "" && !strings.Contains(tlsDomain, "*") { filters = append(filters, LuaFilter(fmt.Sprintf(luascripts.FixMisdirected, tlsDomain))) } filters = append(filters, HTTPRouterFilter()) @@ -438,7 +441,7 @@ func (b *Builder) buildGRPCListener(ctx context.Context, cfg *config.Config) (*e return li, nil } - chains, err := b.buildFilterChains(cfg.Options, cfg.Options.GRPCAddr, + chains, err := b.buildFilterChains(cfg, cfg.Options.GRPCAddr, func(tlsDomain string, httpDomains []string) (*envoy_config_listener_v3.FilterChain, error) { filterChain := &envoy_config_listener_v3.FilterChain{ Filters: []*envoy_config_listener_v3.Filter{filter}, @@ -658,14 +661,14 @@ func getAllRouteableDomains(options *config.Options, addr string) ([]string, err return allDomains.ToSlice(), nil } -func getAllTLSDomains(options *config.Options, addr string) ([]string, error) { - allDomains, err := getAllRouteableDomains(options, addr) +func getAllTLSDomains(cfg *config.Config, addr string) ([]string, error) { + domains := sets.NewSorted[string]() + + routeableDomains, err := getAllRouteableDomains(cfg.Options, addr) if err != nil { return nil, err } - - domains := sets.NewSorted[string]() - for _, hp := range allDomains { + for _, hp := range routeableDomains { if d, _, err := net.SplitHostPort(hp); err == nil { domains.Add(d) } else { @@ -673,6 +676,16 @@ func getAllTLSDomains(options *config.Options, addr string) ([]string, error) { } } + certs, err := cfg.AllCertificates() + if err != nil { + return nil, err + } + for i := range certs { + for _, domain := range cryptutil.GetCertificateDomains(&certs[i]) { + domains.Add(domain) + } + } + return domains.ToSlice(), nil } diff --git a/config/envoyconfig/listeners_test.go b/config/envoyconfig/listeners_test.go index 861dcac08..97cfb516e 100644 --- a/config/envoyconfig/listeners_test.go +++ b/config/envoyconfig/listeners_test.go @@ -2,6 +2,7 @@ package envoyconfig import ( "context" + "encoding/base64" "os" "path/filepath" "testing" @@ -13,6 +14,7 @@ import ( "github.com/pomerium/pomerium/config" "github.com/pomerium/pomerium/config/envoyconfig/filemgr" "github.com/pomerium/pomerium/internal/testutil" + "github.com/pomerium/pomerium/pkg/cryptutil" ) const ( @@ -726,6 +728,11 @@ func Test_buildDownstreamTLSContext(t *testing.T) { } func Test_getAllDomains(t *testing.T) { + cert, err := cryptutil.GenerateSelfSignedCertificate("*.unknown.example.com") + require.NoError(t, err) + certPEM, keyPEM, err := cryptutil.EncodeCertificate(cert) + require.NoError(t, err) + options := &config.Options{ Addr: "127.0.0.1:9000", GRPCAddr: "127.0.0.1:9001", @@ -738,6 +745,8 @@ func Test_getAllDomains(t *testing.T) { {Source: &config.StringURL{URL: mustParseURL(t, "https://b.example.com")}}, {Source: &config.StringURL{URL: mustParseURL(t, "https://c.example.com")}}, }, + Cert: base64.StdEncoding.EncodeToString(certPEM), + Key: base64.StdEncoding.EncodeToString(keyPEM), } t.Run("routable", func(t *testing.T) { t.Run("http", func(t *testing.T) { @@ -786,9 +795,10 @@ func Test_getAllDomains(t *testing.T) { }) t.Run("tls", func(t *testing.T) { t.Run("http", func(t *testing.T) { - actual, err := getAllTLSDomains(options, "127.0.0.1:9000") + actual, err := getAllTLSDomains(&config.Config{Options: options}, "127.0.0.1:9000") require.NoError(t, err) expect := []string{ + "*.unknown.example.com", "a.example.com", "authenticate.example.com", "b.example.com", @@ -797,9 +807,10 @@ func Test_getAllDomains(t *testing.T) { assert.Equal(t, expect, actual) }) t.Run("grpc", func(t *testing.T) { - actual, err := getAllTLSDomains(options, "127.0.0.1:9001") + actual, err := getAllTLSDomains(&config.Config{Options: options}, "127.0.0.1:9001") require.NoError(t, err) expect := []string{ + "*.unknown.example.com", "authorize.example.com", "cache.example.com", } diff --git a/pkg/cryptutil/certificates.go b/pkg/cryptutil/certificates.go index 68d5f4686..d35fb07f6 100644 --- a/pkg/cryptutil/certificates.go +++ b/pkg/cryptutil/certificates.go @@ -219,6 +219,21 @@ func GenerateSelfSignedCertificate(domain string, configure ...func(*x509.Certif return &cert, nil } +// EncodeCertificate encodes a TLS certificate into PEM compatible byte slices. +// Returns `nil`, `nil` if there is an error marshaling the PKCS8 private key. +func EncodeCertificate(cert *tls.Certificate) (pemCertificateBytes, pemKeyBytes []byte, err error) { + if cert == nil || len(cert.Certificate) == 0 { + return nil, nil, nil + } + publicKeyBytes := cert.Certificate[0] + privateKeyBytes, err := x509.MarshalPKCS8PrivateKey(cert.PrivateKey) + if err != nil { + return nil, nil, err + } + return pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: publicKeyBytes}), + pem.EncodeToMemory(&pem.Block{Type: "PRIVATE KEY", Bytes: privateKeyBytes}), nil +} + // ParsePEMCertificate parses a PEM encoded certificate block. func ParsePEMCertificate(raw []byte) (*x509.Certificate, error) { data := raw diff --git a/pkg/cryptutil/certificates_test.go b/pkg/cryptutil/certificates_test.go index ba1e0e927..ae7d9d918 100644 --- a/pkg/cryptutil/certificates_test.go +++ b/pkg/cryptutil/certificates_test.go @@ -165,3 +165,18 @@ func TestPrivateKeyMarshaling(t *testing.T) { t.Fatal("private key encoding did not match") } } + +func TestEncodeCertificate(t *testing.T) { + t.Run("nil", func(t *testing.T) { + cert, key, err := EncodeCertificate(nil) + assert.NoError(t, err) + assert.Nil(t, cert) + assert.Nil(t, key) + }) + t.Run("empty certificate", func(t *testing.T) { + cert, key, err := EncodeCertificate(&tls.Certificate{}) + assert.NoError(t, err) + assert.Nil(t, cert) + assert.Nil(t, key) + }) +} diff --git a/pkg/cryptutil/tls.go b/pkg/cryptutil/tls.go index 05140f228..1d26e2c47 100644 --- a/pkg/cryptutil/tls.go +++ b/pkg/cryptutil/tls.go @@ -63,6 +63,30 @@ func GetCertificateForDomain(certificates []tls.Certificate, domain string) (*tl return GenerateSelfSignedCertificate(domain) } +// GetCertificateDomains gets all the certificate's matching domain names. +// Will return an empty slice if certificate is nil, empty, or x509 parsing fails. +func GetCertificateDomains(cert *tls.Certificate) []string { + if cert == nil || len(cert.Certificate) == 0 { + return nil + } + + xcert, err := x509.ParseCertificate(cert.Certificate[0]) + if err != nil { + return nil + } + + var domains []string + if xcert.Subject.CommonName != "" { + domains = append(domains, xcert.Subject.CommonName) + } + for _, dnsName := range xcert.DNSNames { + if dnsName != "" { + domains = append(domains, dnsName) + } + } + return domains +} + func matchesDomain(cert *tls.Certificate, domain string) bool { if cert == nil || len(cert.Certificate) == 0 { return false diff --git a/pkg/cryptutil/tls_test.go b/pkg/cryptutil/tls_test.go index 2e39e9a63..b62155fa8 100644 --- a/pkg/cryptutil/tls_test.go +++ b/pkg/cryptutil/tls_test.go @@ -5,6 +5,7 @@ import ( "testing" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestGetCertificateForDomain(t *testing.T) { @@ -62,3 +63,9 @@ func TestGetCertificateForDomain(t *testing.T) { assert.NotNil(t, found) }) } + +func TestGetCertificateDomains(t *testing.T) { + cert, err := GenerateSelfSignedCertificate("www.example.com") + require.NoError(t, err) + assert.Equal(t, []string{"www.example.com"}, GetCertificateDomains(cert)) +}