envoyconfig: add virtual host domains for certificates in addition to routes (#3593)

* envoyconfig: add virtual host domains for certificates in addition to routes

* Update pkg/cryptutil/certificates.go

Co-authored-by: bobby <1544881+desimone@users.noreply.github.com>

* Update pkg/cryptutil/tls.go

Co-authored-by: bobby <1544881+desimone@users.noreply.github.com>

* comments

Co-authored-by: bobby <1544881+desimone@users.noreply.github.com>
This commit is contained in:
Caleb Doxsey 2022-08-31 10:35:45 -06:00 committed by GitHub
parent 23c42da8ec
commit 33794ff316
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 99 additions and 14 deletions

View file

@ -6,6 +6,7 @@ import (
"fmt" "fmt"
"net" "net"
"net/url" "net/url"
"strings"
"time" "time"
envoy_config_core_v3 "github.com/envoyproxy/go-control-plane/envoy/config/core/v3" envoy_config_core_v3 "github.com/envoyproxy/go-control-plane/envoy/config/core/v3"
@ -118,7 +119,7 @@ func (b *Builder) buildMainListener(ctx context.Context, cfg *config.Config) (*e
} }
listenerFilters = append(listenerFilters, TLSInspectorFilter()) listenerFilters = append(listenerFilters, TLSInspectorFilter())
chains, err := b.buildFilterChains(cfg.Options, cfg.Options.Addr, chains, err := b.buildFilterChains(cfg, cfg.Options.Addr,
func(tlsDomain string, httpDomains []string) (*envoy_config_listener_v3.FilterChain, error) { func(tlsDomain string, httpDomains []string) (*envoy_config_listener_v3.FilterChain, error) {
filter, err := b.buildMainHTTPConnectionManagerFilter(cfg.Options, httpDomains, tlsDomain) filter, err := b.buildMainHTTPConnectionManagerFilter(cfg.Options, httpDomains, tlsDomain)
if err != nil { if err != nil {
@ -235,15 +236,15 @@ func (b *Builder) buildMetricsListener(cfg *config.Config) (*envoy_config_listen
} }
func (b *Builder) buildFilterChains( func (b *Builder) buildFilterChains(
options *config.Options, addr string, cfg *config.Config, addr string,
callback func(tlsDomain string, httpDomains []string) (*envoy_config_listener_v3.FilterChain, error), callback func(tlsDomain string, httpDomains []string) (*envoy_config_listener_v3.FilterChain, error),
) ([]*envoy_config_listener_v3.FilterChain, error) { ) ([]*envoy_config_listener_v3.FilterChain, error) {
allDomains, err := getAllRouteableDomains(options, addr) allDomains, err := getAllRouteableDomains(cfg.Options, addr)
if err != nil { if err != nil {
return nil, err return nil, err
} }
tlsDomains, err := getAllTLSDomains(options, addr) tlsDomains, err := getAllTLSDomains(cfg, addr)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -251,7 +252,7 @@ func (b *Builder) buildFilterChains(
var chains []*envoy_config_listener_v3.FilterChain var chains []*envoy_config_listener_v3.FilterChain
chains = append(chains, b.buildACMETLSALPNFilterChain()) chains = append(chains, b.buildACMETLSALPNFilterChain())
for _, domain := range tlsDomains { for _, domain := range tlsDomains {
routeableDomains, err := getRouteableDomainsForTLSServerName(options, addr, domain) routeableDomains, err := getRouteableDomainsForTLSServerName(cfg.Options, addr, domain)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -341,7 +342,9 @@ func (b *Builder) buildMainHTTPConnectionManagerFilter(
LuaFilter(luascripts.CleanUpstream), LuaFilter(luascripts.CleanUpstream),
LuaFilter(luascripts.RewriteHeaders), LuaFilter(luascripts.RewriteHeaders),
} }
if tlsDomain != "" && tlsDomain != "*" { // only return 421s for non-wildcard domains because the lua script doesn't understand how to
// parse wildcards properly
if tlsDomain != "" && !strings.Contains(tlsDomain, "*") {
filters = append(filters, LuaFilter(fmt.Sprintf(luascripts.FixMisdirected, tlsDomain))) filters = append(filters, LuaFilter(fmt.Sprintf(luascripts.FixMisdirected, tlsDomain)))
} }
filters = append(filters, HTTPRouterFilter()) filters = append(filters, HTTPRouterFilter())
@ -438,7 +441,7 @@ func (b *Builder) buildGRPCListener(ctx context.Context, cfg *config.Config) (*e
return li, nil return li, nil
} }
chains, err := b.buildFilterChains(cfg.Options, cfg.Options.GRPCAddr, chains, err := b.buildFilterChains(cfg, cfg.Options.GRPCAddr,
func(tlsDomain string, httpDomains []string) (*envoy_config_listener_v3.FilterChain, error) { func(tlsDomain string, httpDomains []string) (*envoy_config_listener_v3.FilterChain, error) {
filterChain := &envoy_config_listener_v3.FilterChain{ filterChain := &envoy_config_listener_v3.FilterChain{
Filters: []*envoy_config_listener_v3.Filter{filter}, Filters: []*envoy_config_listener_v3.Filter{filter},
@ -658,14 +661,14 @@ func getAllRouteableDomains(options *config.Options, addr string) ([]string, err
return allDomains.ToSlice(), nil return allDomains.ToSlice(), nil
} }
func getAllTLSDomains(options *config.Options, addr string) ([]string, error) { func getAllTLSDomains(cfg *config.Config, addr string) ([]string, error) {
allDomains, err := getAllRouteableDomains(options, addr) domains := sets.NewSorted[string]()
routeableDomains, err := getAllRouteableDomains(cfg.Options, addr)
if err != nil { if err != nil {
return nil, err return nil, err
} }
for _, hp := range routeableDomains {
domains := sets.NewSorted[string]()
for _, hp := range allDomains {
if d, _, err := net.SplitHostPort(hp); err == nil { if d, _, err := net.SplitHostPort(hp); err == nil {
domains.Add(d) domains.Add(d)
} else { } else {
@ -673,6 +676,16 @@ func getAllTLSDomains(options *config.Options, addr string) ([]string, error) {
} }
} }
certs, err := cfg.AllCertificates()
if err != nil {
return nil, err
}
for i := range certs {
for _, domain := range cryptutil.GetCertificateDomains(&certs[i]) {
domains.Add(domain)
}
}
return domains.ToSlice(), nil return domains.ToSlice(), nil
} }

View file

@ -2,6 +2,7 @@ package envoyconfig
import ( import (
"context" "context"
"encoding/base64"
"os" "os"
"path/filepath" "path/filepath"
"testing" "testing"
@ -13,6 +14,7 @@ import (
"github.com/pomerium/pomerium/config" "github.com/pomerium/pomerium/config"
"github.com/pomerium/pomerium/config/envoyconfig/filemgr" "github.com/pomerium/pomerium/config/envoyconfig/filemgr"
"github.com/pomerium/pomerium/internal/testutil" "github.com/pomerium/pomerium/internal/testutil"
"github.com/pomerium/pomerium/pkg/cryptutil"
) )
const ( const (
@ -726,6 +728,11 @@ func Test_buildDownstreamTLSContext(t *testing.T) {
} }
func Test_getAllDomains(t *testing.T) { func Test_getAllDomains(t *testing.T) {
cert, err := cryptutil.GenerateSelfSignedCertificate("*.unknown.example.com")
require.NoError(t, err)
certPEM, keyPEM, err := cryptutil.EncodeCertificate(cert)
require.NoError(t, err)
options := &config.Options{ options := &config.Options{
Addr: "127.0.0.1:9000", Addr: "127.0.0.1:9000",
GRPCAddr: "127.0.0.1:9001", GRPCAddr: "127.0.0.1:9001",
@ -738,6 +745,8 @@ func Test_getAllDomains(t *testing.T) {
{Source: &config.StringURL{URL: mustParseURL(t, "https://b.example.com")}}, {Source: &config.StringURL{URL: mustParseURL(t, "https://b.example.com")}},
{Source: &config.StringURL{URL: mustParseURL(t, "https://c.example.com")}}, {Source: &config.StringURL{URL: mustParseURL(t, "https://c.example.com")}},
}, },
Cert: base64.StdEncoding.EncodeToString(certPEM),
Key: base64.StdEncoding.EncodeToString(keyPEM),
} }
t.Run("routable", func(t *testing.T) { t.Run("routable", func(t *testing.T) {
t.Run("http", func(t *testing.T) { t.Run("http", func(t *testing.T) {
@ -786,9 +795,10 @@ func Test_getAllDomains(t *testing.T) {
}) })
t.Run("tls", func(t *testing.T) { t.Run("tls", func(t *testing.T) {
t.Run("http", func(t *testing.T) { t.Run("http", func(t *testing.T) {
actual, err := getAllTLSDomains(options, "127.0.0.1:9000") actual, err := getAllTLSDomains(&config.Config{Options: options}, "127.0.0.1:9000")
require.NoError(t, err) require.NoError(t, err)
expect := []string{ expect := []string{
"*.unknown.example.com",
"a.example.com", "a.example.com",
"authenticate.example.com", "authenticate.example.com",
"b.example.com", "b.example.com",
@ -797,9 +807,10 @@ func Test_getAllDomains(t *testing.T) {
assert.Equal(t, expect, actual) assert.Equal(t, expect, actual)
}) })
t.Run("grpc", func(t *testing.T) { t.Run("grpc", func(t *testing.T) {
actual, err := getAllTLSDomains(options, "127.0.0.1:9001") actual, err := getAllTLSDomains(&config.Config{Options: options}, "127.0.0.1:9001")
require.NoError(t, err) require.NoError(t, err)
expect := []string{ expect := []string{
"*.unknown.example.com",
"authorize.example.com", "authorize.example.com",
"cache.example.com", "cache.example.com",
} }

View file

@ -219,6 +219,21 @@ func GenerateSelfSignedCertificate(domain string, configure ...func(*x509.Certif
return &cert, nil return &cert, nil
} }
// EncodeCertificate encodes a TLS certificate into PEM compatible byte slices.
// Returns `nil`, `nil` if there is an error marshaling the PKCS8 private key.
func EncodeCertificate(cert *tls.Certificate) (pemCertificateBytes, pemKeyBytes []byte, err error) {
if cert == nil || len(cert.Certificate) == 0 {
return nil, nil, nil
}
publicKeyBytes := cert.Certificate[0]
privateKeyBytes, err := x509.MarshalPKCS8PrivateKey(cert.PrivateKey)
if err != nil {
return nil, nil, err
}
return pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: publicKeyBytes}),
pem.EncodeToMemory(&pem.Block{Type: "PRIVATE KEY", Bytes: privateKeyBytes}), nil
}
// ParsePEMCertificate parses a PEM encoded certificate block. // ParsePEMCertificate parses a PEM encoded certificate block.
func ParsePEMCertificate(raw []byte) (*x509.Certificate, error) { func ParsePEMCertificate(raw []byte) (*x509.Certificate, error) {
data := raw data := raw

View file

@ -165,3 +165,18 @@ func TestPrivateKeyMarshaling(t *testing.T) {
t.Fatal("private key encoding did not match") t.Fatal("private key encoding did not match")
} }
} }
func TestEncodeCertificate(t *testing.T) {
t.Run("nil", func(t *testing.T) {
cert, key, err := EncodeCertificate(nil)
assert.NoError(t, err)
assert.Nil(t, cert)
assert.Nil(t, key)
})
t.Run("empty certificate", func(t *testing.T) {
cert, key, err := EncodeCertificate(&tls.Certificate{})
assert.NoError(t, err)
assert.Nil(t, cert)
assert.Nil(t, key)
})
}

View file

@ -63,6 +63,30 @@ func GetCertificateForDomain(certificates []tls.Certificate, domain string) (*tl
return GenerateSelfSignedCertificate(domain) return GenerateSelfSignedCertificate(domain)
} }
// GetCertificateDomains gets all the certificate's matching domain names.
// Will return an empty slice if certificate is nil, empty, or x509 parsing fails.
func GetCertificateDomains(cert *tls.Certificate) []string {
if cert == nil || len(cert.Certificate) == 0 {
return nil
}
xcert, err := x509.ParseCertificate(cert.Certificate[0])
if err != nil {
return nil
}
var domains []string
if xcert.Subject.CommonName != "" {
domains = append(domains, xcert.Subject.CommonName)
}
for _, dnsName := range xcert.DNSNames {
if dnsName != "" {
domains = append(domains, dnsName)
}
}
return domains
}
func matchesDomain(cert *tls.Certificate, domain string) bool { func matchesDomain(cert *tls.Certificate, domain string) bool {
if cert == nil || len(cert.Certificate) == 0 { if cert == nil || len(cert.Certificate) == 0 {
return false return false

View file

@ -5,6 +5,7 @@ import (
"testing" "testing"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
) )
func TestGetCertificateForDomain(t *testing.T) { func TestGetCertificateForDomain(t *testing.T) {
@ -62,3 +63,9 @@ func TestGetCertificateForDomain(t *testing.T) {
assert.NotNil(t, found) assert.NotNil(t, found)
}) })
} }
func TestGetCertificateDomains(t *testing.T) {
cert, err := GenerateSelfSignedCertificate("www.example.com")
require.NoError(t, err)
assert.Equal(t, []string{"www.example.com"}, GetCertificateDomains(cert))
}