mirror of
https://github.com/pomerium/pomerium.git
synced 2025-05-23 14:07:11 +02:00
envoyconfig: add virtual host domains for certificates in addition to routes (#3593)
* envoyconfig: add virtual host domains for certificates in addition to routes * Update pkg/cryptutil/certificates.go Co-authored-by: bobby <1544881+desimone@users.noreply.github.com> * Update pkg/cryptutil/tls.go Co-authored-by: bobby <1544881+desimone@users.noreply.github.com> * comments Co-authored-by: bobby <1544881+desimone@users.noreply.github.com>
This commit is contained in:
parent
23c42da8ec
commit
33794ff316
6 changed files with 99 additions and 14 deletions
|
@ -6,6 +6,7 @@ import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"net/url"
|
"net/url"
|
||||||
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
envoy_config_core_v3 "github.com/envoyproxy/go-control-plane/envoy/config/core/v3"
|
envoy_config_core_v3 "github.com/envoyproxy/go-control-plane/envoy/config/core/v3"
|
||||||
|
@ -118,7 +119,7 @@ func (b *Builder) buildMainListener(ctx context.Context, cfg *config.Config) (*e
|
||||||
}
|
}
|
||||||
listenerFilters = append(listenerFilters, TLSInspectorFilter())
|
listenerFilters = append(listenerFilters, TLSInspectorFilter())
|
||||||
|
|
||||||
chains, err := b.buildFilterChains(cfg.Options, cfg.Options.Addr,
|
chains, err := b.buildFilterChains(cfg, cfg.Options.Addr,
|
||||||
func(tlsDomain string, httpDomains []string) (*envoy_config_listener_v3.FilterChain, error) {
|
func(tlsDomain string, httpDomains []string) (*envoy_config_listener_v3.FilterChain, error) {
|
||||||
filter, err := b.buildMainHTTPConnectionManagerFilter(cfg.Options, httpDomains, tlsDomain)
|
filter, err := b.buildMainHTTPConnectionManagerFilter(cfg.Options, httpDomains, tlsDomain)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -235,15 +236,15 @@ func (b *Builder) buildMetricsListener(cfg *config.Config) (*envoy_config_listen
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *Builder) buildFilterChains(
|
func (b *Builder) buildFilterChains(
|
||||||
options *config.Options, addr string,
|
cfg *config.Config, addr string,
|
||||||
callback func(tlsDomain string, httpDomains []string) (*envoy_config_listener_v3.FilterChain, error),
|
callback func(tlsDomain string, httpDomains []string) (*envoy_config_listener_v3.FilterChain, error),
|
||||||
) ([]*envoy_config_listener_v3.FilterChain, error) {
|
) ([]*envoy_config_listener_v3.FilterChain, error) {
|
||||||
allDomains, err := getAllRouteableDomains(options, addr)
|
allDomains, err := getAllRouteableDomains(cfg.Options, addr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
tlsDomains, err := getAllTLSDomains(options, addr)
|
tlsDomains, err := getAllTLSDomains(cfg, addr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -251,7 +252,7 @@ func (b *Builder) buildFilterChains(
|
||||||
var chains []*envoy_config_listener_v3.FilterChain
|
var chains []*envoy_config_listener_v3.FilterChain
|
||||||
chains = append(chains, b.buildACMETLSALPNFilterChain())
|
chains = append(chains, b.buildACMETLSALPNFilterChain())
|
||||||
for _, domain := range tlsDomains {
|
for _, domain := range tlsDomains {
|
||||||
routeableDomains, err := getRouteableDomainsForTLSServerName(options, addr, domain)
|
routeableDomains, err := getRouteableDomainsForTLSServerName(cfg.Options, addr, domain)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -341,7 +342,9 @@ func (b *Builder) buildMainHTTPConnectionManagerFilter(
|
||||||
LuaFilter(luascripts.CleanUpstream),
|
LuaFilter(luascripts.CleanUpstream),
|
||||||
LuaFilter(luascripts.RewriteHeaders),
|
LuaFilter(luascripts.RewriteHeaders),
|
||||||
}
|
}
|
||||||
if tlsDomain != "" && tlsDomain != "*" {
|
// only return 421s for non-wildcard domains because the lua script doesn't understand how to
|
||||||
|
// parse wildcards properly
|
||||||
|
if tlsDomain != "" && !strings.Contains(tlsDomain, "*") {
|
||||||
filters = append(filters, LuaFilter(fmt.Sprintf(luascripts.FixMisdirected, tlsDomain)))
|
filters = append(filters, LuaFilter(fmt.Sprintf(luascripts.FixMisdirected, tlsDomain)))
|
||||||
}
|
}
|
||||||
filters = append(filters, HTTPRouterFilter())
|
filters = append(filters, HTTPRouterFilter())
|
||||||
|
@ -438,7 +441,7 @@ func (b *Builder) buildGRPCListener(ctx context.Context, cfg *config.Config) (*e
|
||||||
return li, nil
|
return li, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
chains, err := b.buildFilterChains(cfg.Options, cfg.Options.GRPCAddr,
|
chains, err := b.buildFilterChains(cfg, cfg.Options.GRPCAddr,
|
||||||
func(tlsDomain string, httpDomains []string) (*envoy_config_listener_v3.FilterChain, error) {
|
func(tlsDomain string, httpDomains []string) (*envoy_config_listener_v3.FilterChain, error) {
|
||||||
filterChain := &envoy_config_listener_v3.FilterChain{
|
filterChain := &envoy_config_listener_v3.FilterChain{
|
||||||
Filters: []*envoy_config_listener_v3.Filter{filter},
|
Filters: []*envoy_config_listener_v3.Filter{filter},
|
||||||
|
@ -658,14 +661,14 @@ func getAllRouteableDomains(options *config.Options, addr string) ([]string, err
|
||||||
return allDomains.ToSlice(), nil
|
return allDomains.ToSlice(), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func getAllTLSDomains(options *config.Options, addr string) ([]string, error) {
|
func getAllTLSDomains(cfg *config.Config, addr string) ([]string, error) {
|
||||||
allDomains, err := getAllRouteableDomains(options, addr)
|
domains := sets.NewSorted[string]()
|
||||||
|
|
||||||
|
routeableDomains, err := getAllRouteableDomains(cfg.Options, addr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
for _, hp := range routeableDomains {
|
||||||
domains := sets.NewSorted[string]()
|
|
||||||
for _, hp := range allDomains {
|
|
||||||
if d, _, err := net.SplitHostPort(hp); err == nil {
|
if d, _, err := net.SplitHostPort(hp); err == nil {
|
||||||
domains.Add(d)
|
domains.Add(d)
|
||||||
} else {
|
} else {
|
||||||
|
@ -673,6 +676,16 @@ func getAllTLSDomains(options *config.Options, addr string) ([]string, error) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
certs, err := cfg.AllCertificates()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
for i := range certs {
|
||||||
|
for _, domain := range cryptutil.GetCertificateDomains(&certs[i]) {
|
||||||
|
domains.Add(domain)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return domains.ToSlice(), nil
|
return domains.ToSlice(), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -2,6 +2,7 @@ package envoyconfig
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"encoding/base64"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"testing"
|
"testing"
|
||||||
|
@ -13,6 +14,7 @@ import (
|
||||||
"github.com/pomerium/pomerium/config"
|
"github.com/pomerium/pomerium/config"
|
||||||
"github.com/pomerium/pomerium/config/envoyconfig/filemgr"
|
"github.com/pomerium/pomerium/config/envoyconfig/filemgr"
|
||||||
"github.com/pomerium/pomerium/internal/testutil"
|
"github.com/pomerium/pomerium/internal/testutil"
|
||||||
|
"github.com/pomerium/pomerium/pkg/cryptutil"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
@ -726,6 +728,11 @@ func Test_buildDownstreamTLSContext(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func Test_getAllDomains(t *testing.T) {
|
func Test_getAllDomains(t *testing.T) {
|
||||||
|
cert, err := cryptutil.GenerateSelfSignedCertificate("*.unknown.example.com")
|
||||||
|
require.NoError(t, err)
|
||||||
|
certPEM, keyPEM, err := cryptutil.EncodeCertificate(cert)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
options := &config.Options{
|
options := &config.Options{
|
||||||
Addr: "127.0.0.1:9000",
|
Addr: "127.0.0.1:9000",
|
||||||
GRPCAddr: "127.0.0.1:9001",
|
GRPCAddr: "127.0.0.1:9001",
|
||||||
|
@ -738,6 +745,8 @@ func Test_getAllDomains(t *testing.T) {
|
||||||
{Source: &config.StringURL{URL: mustParseURL(t, "https://b.example.com")}},
|
{Source: &config.StringURL{URL: mustParseURL(t, "https://b.example.com")}},
|
||||||
{Source: &config.StringURL{URL: mustParseURL(t, "https://c.example.com")}},
|
{Source: &config.StringURL{URL: mustParseURL(t, "https://c.example.com")}},
|
||||||
},
|
},
|
||||||
|
Cert: base64.StdEncoding.EncodeToString(certPEM),
|
||||||
|
Key: base64.StdEncoding.EncodeToString(keyPEM),
|
||||||
}
|
}
|
||||||
t.Run("routable", func(t *testing.T) {
|
t.Run("routable", func(t *testing.T) {
|
||||||
t.Run("http", func(t *testing.T) {
|
t.Run("http", func(t *testing.T) {
|
||||||
|
@ -786,9 +795,10 @@ func Test_getAllDomains(t *testing.T) {
|
||||||
})
|
})
|
||||||
t.Run("tls", func(t *testing.T) {
|
t.Run("tls", func(t *testing.T) {
|
||||||
t.Run("http", func(t *testing.T) {
|
t.Run("http", func(t *testing.T) {
|
||||||
actual, err := getAllTLSDomains(options, "127.0.0.1:9000")
|
actual, err := getAllTLSDomains(&config.Config{Options: options}, "127.0.0.1:9000")
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
expect := []string{
|
expect := []string{
|
||||||
|
"*.unknown.example.com",
|
||||||
"a.example.com",
|
"a.example.com",
|
||||||
"authenticate.example.com",
|
"authenticate.example.com",
|
||||||
"b.example.com",
|
"b.example.com",
|
||||||
|
@ -797,9 +807,10 @@ func Test_getAllDomains(t *testing.T) {
|
||||||
assert.Equal(t, expect, actual)
|
assert.Equal(t, expect, actual)
|
||||||
})
|
})
|
||||||
t.Run("grpc", func(t *testing.T) {
|
t.Run("grpc", func(t *testing.T) {
|
||||||
actual, err := getAllTLSDomains(options, "127.0.0.1:9001")
|
actual, err := getAllTLSDomains(&config.Config{Options: options}, "127.0.0.1:9001")
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
expect := []string{
|
expect := []string{
|
||||||
|
"*.unknown.example.com",
|
||||||
"authorize.example.com",
|
"authorize.example.com",
|
||||||
"cache.example.com",
|
"cache.example.com",
|
||||||
}
|
}
|
||||||
|
|
|
@ -219,6 +219,21 @@ func GenerateSelfSignedCertificate(domain string, configure ...func(*x509.Certif
|
||||||
return &cert, nil
|
return &cert, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// EncodeCertificate encodes a TLS certificate into PEM compatible byte slices.
|
||||||
|
// Returns `nil`, `nil` if there is an error marshaling the PKCS8 private key.
|
||||||
|
func EncodeCertificate(cert *tls.Certificate) (pemCertificateBytes, pemKeyBytes []byte, err error) {
|
||||||
|
if cert == nil || len(cert.Certificate) == 0 {
|
||||||
|
return nil, nil, nil
|
||||||
|
}
|
||||||
|
publicKeyBytes := cert.Certificate[0]
|
||||||
|
privateKeyBytes, err := x509.MarshalPKCS8PrivateKey(cert.PrivateKey)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
return pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: publicKeyBytes}),
|
||||||
|
pem.EncodeToMemory(&pem.Block{Type: "PRIVATE KEY", Bytes: privateKeyBytes}), nil
|
||||||
|
}
|
||||||
|
|
||||||
// ParsePEMCertificate parses a PEM encoded certificate block.
|
// ParsePEMCertificate parses a PEM encoded certificate block.
|
||||||
func ParsePEMCertificate(raw []byte) (*x509.Certificate, error) {
|
func ParsePEMCertificate(raw []byte) (*x509.Certificate, error) {
|
||||||
data := raw
|
data := raw
|
||||||
|
|
|
@ -165,3 +165,18 @@ func TestPrivateKeyMarshaling(t *testing.T) {
|
||||||
t.Fatal("private key encoding did not match")
|
t.Fatal("private key encoding did not match")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestEncodeCertificate(t *testing.T) {
|
||||||
|
t.Run("nil", func(t *testing.T) {
|
||||||
|
cert, key, err := EncodeCertificate(nil)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Nil(t, cert)
|
||||||
|
assert.Nil(t, key)
|
||||||
|
})
|
||||||
|
t.Run("empty certificate", func(t *testing.T) {
|
||||||
|
cert, key, err := EncodeCertificate(&tls.Certificate{})
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Nil(t, cert)
|
||||||
|
assert.Nil(t, key)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
|
@ -63,6 +63,30 @@ func GetCertificateForDomain(certificates []tls.Certificate, domain string) (*tl
|
||||||
return GenerateSelfSignedCertificate(domain)
|
return GenerateSelfSignedCertificate(domain)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetCertificateDomains gets all the certificate's matching domain names.
|
||||||
|
// Will return an empty slice if certificate is nil, empty, or x509 parsing fails.
|
||||||
|
func GetCertificateDomains(cert *tls.Certificate) []string {
|
||||||
|
if cert == nil || len(cert.Certificate) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
xcert, err := x509.ParseCertificate(cert.Certificate[0])
|
||||||
|
if err != nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var domains []string
|
||||||
|
if xcert.Subject.CommonName != "" {
|
||||||
|
domains = append(domains, xcert.Subject.CommonName)
|
||||||
|
}
|
||||||
|
for _, dnsName := range xcert.DNSNames {
|
||||||
|
if dnsName != "" {
|
||||||
|
domains = append(domains, dnsName)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return domains
|
||||||
|
}
|
||||||
|
|
||||||
func matchesDomain(cert *tls.Certificate, domain string) bool {
|
func matchesDomain(cert *tls.Certificate, domain string) bool {
|
||||||
if cert == nil || len(cert.Certificate) == 0 {
|
if cert == nil || len(cert.Certificate) == 0 {
|
||||||
return false
|
return false
|
||||||
|
|
|
@ -5,6 +5,7 @@ import (
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestGetCertificateForDomain(t *testing.T) {
|
func TestGetCertificateForDomain(t *testing.T) {
|
||||||
|
@ -62,3 +63,9 @@ func TestGetCertificateForDomain(t *testing.T) {
|
||||||
assert.NotNil(t, found)
|
assert.NotNil(t, found)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestGetCertificateDomains(t *testing.T) {
|
||||||
|
cert, err := GenerateSelfSignedCertificate("www.example.com")
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, []string{"www.example.com"}, GetCertificateDomains(cert))
|
||||||
|
}
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue