From bfb218a79a489d8c2291b9ad34827f9853535ffc Mon Sep 17 00:00:00 2001 From: Caleb Doxsey Date: Tue, 30 Aug 2022 11:53:30 -0600 Subject: [PATCH] envoyconfig: add virtual host domains for certificates in addition to routes --- config/envoyconfig/listeners.go | 35 ++++++++++++++++++---------- config/envoyconfig/listeners_test.go | 14 +++++++++-- pkg/cryptutil/certificates.go | 11 +++++++++ pkg/cryptutil/tls.go | 23 ++++++++++++++++++ 4 files changed, 69 insertions(+), 14 deletions(-) diff --git a/config/envoyconfig/listeners.go b/config/envoyconfig/listeners.go index 4ce3d2fac..00ca00909 100644 --- a/config/envoyconfig/listeners.go +++ b/config/envoyconfig/listeners.go @@ -6,6 +6,7 @@ import ( "fmt" "net" "net/url" + "strings" "time" envoy_config_core_v3 "github.com/envoyproxy/go-control-plane/envoy/config/core/v3" @@ -118,7 +119,7 @@ func (b *Builder) buildMainListener(ctx context.Context, cfg *config.Config) (*e } listenerFilters = append(listenerFilters, TLSInspectorFilter()) - chains, err := b.buildFilterChains(cfg.Options, cfg.Options.Addr, + chains, err := b.buildFilterChains(cfg, cfg.Options.Addr, func(tlsDomain string, httpDomains []string) (*envoy_config_listener_v3.FilterChain, error) { filter, err := b.buildMainHTTPConnectionManagerFilter(cfg.Options, httpDomains, tlsDomain) if err != nil { @@ -235,15 +236,15 @@ func (b *Builder) buildMetricsListener(cfg *config.Config) (*envoy_config_listen } func (b *Builder) buildFilterChains( - options *config.Options, addr string, + cfg *config.Config, addr string, callback func(tlsDomain string, httpDomains []string) (*envoy_config_listener_v3.FilterChain, error), ) ([]*envoy_config_listener_v3.FilterChain, error) { - allDomains, err := getAllRouteableDomains(options, addr) + allDomains, err := getAllRouteableDomains(cfg.Options, addr) if err != nil { return nil, err } - tlsDomains, err := getAllTLSDomains(options, addr) + tlsDomains, err := getAllTLSDomains(cfg, addr) if err != nil { return nil, err } @@ -251,7 +252,7 @@ func (b *Builder) buildFilterChains( var chains []*envoy_config_listener_v3.FilterChain chains = append(chains, b.buildACMETLSALPNFilterChain()) for _, domain := range tlsDomains { - routeableDomains, err := getRouteableDomainsForTLSServerName(options, addr, domain) + routeableDomains, err := getRouteableDomainsForTLSServerName(cfg.Options, addr, domain) if err != nil { return nil, err } @@ -341,7 +342,7 @@ func (b *Builder) buildMainHTTPConnectionManagerFilter( LuaFilter(luascripts.CleanUpstream), LuaFilter(luascripts.RewriteHeaders), } - if tlsDomain != "" && tlsDomain != "*" { + if tlsDomain != "" && !strings.Contains(tlsDomain, "*") { filters = append(filters, LuaFilter(fmt.Sprintf(luascripts.FixMisdirected, tlsDomain))) } filters = append(filters, HTTPRouterFilter()) @@ -438,7 +439,7 @@ func (b *Builder) buildGRPCListener(ctx context.Context, cfg *config.Config) (*e return li, nil } - chains, err := b.buildFilterChains(cfg.Options, cfg.Options.GRPCAddr, + chains, err := b.buildFilterChains(cfg, cfg.Options.GRPCAddr, func(tlsDomain string, httpDomains []string) (*envoy_config_listener_v3.FilterChain, error) { filterChain := &envoy_config_listener_v3.FilterChain{ Filters: []*envoy_config_listener_v3.Filter{filter}, @@ -658,14 +659,14 @@ func getAllRouteableDomains(options *config.Options, addr string) ([]string, err return allDomains.ToSlice(), nil } -func getAllTLSDomains(options *config.Options, addr string) ([]string, error) { - allDomains, err := getAllRouteableDomains(options, addr) +func getAllTLSDomains(cfg *config.Config, addr string) ([]string, error) { + domains := sets.NewSorted[string]() + + routeableDomains, err := getAllRouteableDomains(cfg.Options, addr) if err != nil { return nil, err } - - domains := sets.NewSorted[string]() - for _, hp := range allDomains { + for _, hp := range routeableDomains { if d, _, err := net.SplitHostPort(hp); err == nil { domains.Add(d) } else { @@ -673,6 +674,16 @@ func getAllTLSDomains(options *config.Options, addr string) ([]string, error) { } } + certs, err := cfg.AllCertificates() + if err != nil { + return nil, err + } + for i := range certs { + for _, domain := range cryptutil.GetCertificateDomains(&certs[i]) { + domains.Add(domain) + } + } + return domains.ToSlice(), nil } diff --git a/config/envoyconfig/listeners_test.go b/config/envoyconfig/listeners_test.go index 861dcac08..5ae10db13 100644 --- a/config/envoyconfig/listeners_test.go +++ b/config/envoyconfig/listeners_test.go @@ -2,6 +2,7 @@ package envoyconfig import ( "context" + "encoding/base64" "os" "path/filepath" "testing" @@ -13,6 +14,7 @@ import ( "github.com/pomerium/pomerium/config" "github.com/pomerium/pomerium/config/envoyconfig/filemgr" "github.com/pomerium/pomerium/internal/testutil" + "github.com/pomerium/pomerium/pkg/cryptutil" ) const ( @@ -726,6 +728,10 @@ func Test_buildDownstreamTLSContext(t *testing.T) { } func Test_getAllDomains(t *testing.T) { + cert, err := cryptutil.GenerateSelfSignedCertificate("*.unknown.example.com") + require.NoError(t, err) + certPEM, keyPEM := cryptutil.EncodeCertificate(cert) + options := &config.Options{ Addr: "127.0.0.1:9000", GRPCAddr: "127.0.0.1:9001", @@ -738,6 +744,8 @@ func Test_getAllDomains(t *testing.T) { {Source: &config.StringURL{URL: mustParseURL(t, "https://b.example.com")}}, {Source: &config.StringURL{URL: mustParseURL(t, "https://c.example.com")}}, }, + Cert: base64.StdEncoding.EncodeToString(certPEM), + Key: base64.StdEncoding.EncodeToString(keyPEM), } t.Run("routable", func(t *testing.T) { t.Run("http", func(t *testing.T) { @@ -786,9 +794,10 @@ func Test_getAllDomains(t *testing.T) { }) t.Run("tls", func(t *testing.T) { t.Run("http", func(t *testing.T) { - actual, err := getAllTLSDomains(options, "127.0.0.1:9000") + actual, err := getAllTLSDomains(&config.Config{Options: options}, "127.0.0.1:9000") require.NoError(t, err) expect := []string{ + "*.unknown.example.com", "a.example.com", "authenticate.example.com", "b.example.com", @@ -797,9 +806,10 @@ func Test_getAllDomains(t *testing.T) { assert.Equal(t, expect, actual) }) t.Run("grpc", func(t *testing.T) { - actual, err := getAllTLSDomains(options, "127.0.0.1:9001") + actual, err := getAllTLSDomains(&config.Config{Options: options}, "127.0.0.1:9001") require.NoError(t, err) expect := []string{ + "*.unknown.example.com", "authorize.example.com", "cache.example.com", } diff --git a/pkg/cryptutil/certificates.go b/pkg/cryptutil/certificates.go index e383a8f25..259a73734 100644 --- a/pkg/cryptutil/certificates.go +++ b/pkg/cryptutil/certificates.go @@ -219,6 +219,17 @@ func GenerateSelfSignedCertificate(domain string, configure ...func(*x509.Certif return &cert, nil } +// EncodeCertificate encodes a TLS certificate into PEM compatible byte slices. +func EncodeCertificate(cert *tls.Certificate) (pemCertificateBytes, pemKeyBytes []byte) { + publicKeyBytes := cert.Certificate[0] + privateKeyBytes, err := x509.MarshalPKCS8PrivateKey(cert.PrivateKey) + if err != nil { + return + } + return pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: publicKeyBytes}), + pem.EncodeToMemory(&pem.Block{Type: "PRIVATE KEY", Bytes: privateKeyBytes}) +} + // ParsePEMCertificate parses a PEM encoded certificate block. func ParsePEMCertificate(raw []byte) (*x509.Certificate, error) { data := raw diff --git a/pkg/cryptutil/tls.go b/pkg/cryptutil/tls.go index 05140f228..7b25fcc23 100644 --- a/pkg/cryptutil/tls.go +++ b/pkg/cryptutil/tls.go @@ -63,6 +63,29 @@ func GetCertificateForDomain(certificates []tls.Certificate, domain string) (*tl return GenerateSelfSignedCertificate(domain) } +// GetCertificateDomains gets all the certificate's matching domain names. +func GetCertificateDomains(cert *tls.Certificate) []string { + if cert == nil || len(cert.Certificate) == 0 { + return nil + } + + xcert, err := x509.ParseCertificate(cert.Certificate[0]) + if err != nil { + return nil + } + + var domains []string + if xcert.Subject.CommonName != "" { + domains = append(domains, xcert.Subject.CommonName) + } + for _, dnsName := range xcert.DNSNames { + if dnsName != "" { + domains = append(domains, dnsName) + } + } + return domains +} + func matchesDomain(cert *tls.Certificate, domain string) bool { if cert == nil || len(cert.Certificate) == 0 { return false