envoyconfig: clean up filter chain construction (#3844)

* cleanup filter chain construction

* rename domains to server names

* rename to hosts

* fix tests

* update function name

* improved domaain matching
This commit is contained in:
Caleb Doxsey 2022-12-27 10:07:26 -07:00 committed by GitHub
parent a49f86d023
commit 67e12101fa
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
10 changed files with 405 additions and 246 deletions

View file

@ -12,16 +12,16 @@ import (
func (b *Builder) buildVirtualHost(
options *config.Options,
name string,
domain string,
host string,
requireStrictTransportSecurity bool,
) (*envoy_config_route_v3.VirtualHost, error) {
vh := &envoy_config_route_v3.VirtualHost{
Name: name,
Domains: []string{domain},
Domains: []string{host},
}
// these routes match /.pomerium/... and similar paths
rs, err := b.buildPomeriumHTTPRoutes(options, domain)
rs, err := b.buildPomeriumHTTPRoutes(options, host)
if err != nil {
return nil, err
}

View file

@ -25,6 +25,7 @@ import (
"github.com/pomerium/pomerium/internal/log"
"github.com/pomerium/pomerium/internal/sets"
"github.com/pomerium/pomerium/internal/telemetry/metrics"
"github.com/pomerium/pomerium/internal/urlutil"
"github.com/pomerium/pomerium/pkg/cryptutil"
)
@ -99,51 +100,50 @@ func (b *Builder) BuildListeners(ctx context.Context, cfg *config.Config) ([]*en
}
func (b *Builder) buildMainListener(ctx context.Context, cfg *config.Config) (*envoy_config_listener_v3.Listener, error) {
listenerFilters := []*envoy_config_listener_v3.ListenerFilter{}
li := newEnvoyListener("http-ingress")
if cfg.Options.UseProxyProtocol {
listenerFilters = append(listenerFilters, ProxyProtocolFilter())
li.ListenerFilters = append(li.ListenerFilters, ProxyProtocolFilter())
}
if cfg.Options.InsecureServer {
allDomains, err := getAllRouteableDomains(cfg.Options, cfg.Options.Addr)
if err != nil {
return nil, err
}
filter, err := b.buildMainHTTPConnectionManagerFilter(cfg.Options, allDomains, false)
if err != nil {
return nil, err
}
li := newEnvoyListener("http-ingress")
li.Address = buildAddress(cfg.Options.Addr, 80)
li.ListenerFilters = listenerFilters
filter, err := b.buildMainHTTPConnectionManagerFilter(cfg.Options, false)
if err != nil {
return nil, err
}
li.FilterChains = []*envoy_config_listener_v3.FilterChain{{
Filters: []*envoy_config_listener_v3.Filter{
filter,
},
}}
return li, nil
}
listenerFilters = append(listenerFilters, TLSInspectorFilter())
} else {
li.Address = buildAddress(cfg.Options.Addr, 443)
li.ListenerFilters = append(li.ListenerFilters, TLSInspectorFilter())
chains, err := b.buildFilterChains(cfg, cfg.Options.Addr,
func(tlsDomain string, httpDomains []string) (*envoy_config_listener_v3.FilterChain, error) {
allCertificates, _ := cfg.AllCertificates()
requireStrictTransportSecurity := cryptutil.HasCertificateForDomain(allCertificates, tlsDomain)
filter, err := b.buildMainHTTPConnectionManagerFilter(cfg.Options, httpDomains, requireStrictTransportSecurity)
serverNames, err := getAllServerNames(cfg, cfg.Options.Addr)
if err != nil {
return nil, err
}
for _, serverName := range serverNames {
requireStrictTransportSecurity := cryptutil.HasCertificateForServerName(allCertificates, serverName)
filter, err := b.buildMainHTTPConnectionManagerFilter(cfg.Options, requireStrictTransportSecurity)
if err != nil {
return nil, err
}
filterChain := &envoy_config_listener_v3.FilterChain{
Filters: []*envoy_config_listener_v3.Filter{filter},
}
if tlsDomain != "*" {
if serverName != "*" {
filterChain.FilterChainMatch = &envoy_config_listener_v3.FilterChainMatch{
ServerNames: []string{tlsDomain},
ServerNames: []string{serverName},
}
}
tlsContext := b.buildDownstreamTLSContext(ctx, cfg, tlsDomain)
tlsContext := b.buildDownstreamTLSContext(ctx, cfg, serverName)
if tlsContext != nil {
tlsConfig := marshalAny(tlsContext)
filterChain.TransportSocket = &envoy_config_core_v3.TransportSocket{
@ -153,16 +153,9 @@ func (b *Builder) buildMainListener(ctx context.Context, cfg *config.Config) (*e
},
}
}
return filterChain, nil
})
if err != nil {
return nil, err
li.FilterChains = append(li.FilterChains, filterChain)
}
}
li := newEnvoyListener("https-ingress")
li.Address = buildAddress(cfg.Options.Addr, 443)
li.ListenerFilters = listenerFilters
li.FilterChains = chains
return li, nil
}
@ -245,42 +238,8 @@ func (b *Builder) buildMetricsListener(cfg *config.Config) (*envoy_config_listen
return li, nil
}
func (b *Builder) buildFilterChains(
cfg *config.Config, addr string,
callback func(tlsDomain string, httpDomains []string) (*envoy_config_listener_v3.FilterChain, error),
) ([]*envoy_config_listener_v3.FilterChain, error) {
allDomains, err := getAllRouteableDomains(cfg.Options, addr)
if err != nil {
return nil, err
}
tlsDomains, err := getAllTLSDomains(cfg, addr)
if err != nil {
return nil, err
}
var chains []*envoy_config_listener_v3.FilterChain
chains = append(chains, b.buildACMETLSALPNFilterChain())
for _, domain := range tlsDomains {
chain, err := callback(domain, allDomains)
if err != nil {
return nil, err
}
chains = append(chains, chain)
}
// if there are no SNI matches we match on HTTP host
chain, err := callback("*", allDomains)
if err != nil {
return nil, err
}
chains = append(chains, chain)
return chains, nil
}
func (b *Builder) buildMainHTTPConnectionManagerFilter(
options *config.Options,
domains []string,
requireStrictTransportSecurity bool,
) (*envoy_config_listener_v3.Filter, error) {
authorizeURLs, err := options.GetInternalAuthorizeURLs()
@ -293,17 +252,22 @@ func (b *Builder) buildMainHTTPConnectionManagerFilter(
return nil, err
}
allHosts, err := getAllRouteableHosts(options, options.Addr)
if err != nil {
return nil, err
}
var virtualHosts []*envoy_config_route_v3.VirtualHost
for _, domain := range domains {
vh, err := b.buildVirtualHost(options, domain, domain, requireStrictTransportSecurity)
for _, host := range allHosts {
vh, err := b.buildVirtualHost(options, host, host, requireStrictTransportSecurity)
if err != nil {
return nil, err
}
if options.Addr == options.GetGRPCAddr() {
// if this is a gRPC service domain and we're supposed to handle that, add those routes
if (config.IsAuthorize(options.Services) && hostsMatchDomain(authorizeURLs, domain)) ||
(config.IsDataBroker(options.Services) && hostsMatchDomain(dataBrokerURLs, domain)) {
if (config.IsAuthorize(options.Services) && urlsMatchHost(authorizeURLs, host)) ||
(config.IsDataBroker(options.Services) && urlsMatchHost(dataBrokerURLs, host)) {
rs, err := b.buildGRPCRoutes()
if err != nil {
return nil, err
@ -314,7 +278,7 @@ func (b *Builder) buildMainHTTPConnectionManagerFilter(
// if we're the proxy, add all the policy routes
if config.IsProxy(options.Services) {
rs, err := b.buildPolicyRoutes(options, domain)
rs, err := b.buildPolicyRoutes(options, host)
if err != nil {
return nil, err
}
@ -445,28 +409,35 @@ func (b *Builder) buildGRPCListener(ctx context.Context, cfg *config.Config) (*e
return nil, err
}
if cfg.Options.GetGRPCInsecure() {
li := newEnvoyListener("grpc-ingress")
if cfg.Options.GetGRPCInsecure() {
li.Address = buildAddress(cfg.Options.GetGRPCAddr(), 80)
li.FilterChains = []*envoy_config_listener_v3.FilterChain{{
Filters: []*envoy_config_listener_v3.Filter{
filter,
},
}}
return li, nil
} else {
li.Address = buildAddress(cfg.Options.GetGRPCAddr(), 443)
li.ListenerFilters = []*envoy_config_listener_v3.ListenerFilter{
TLSInspectorFilter(),
}
chains, err := b.buildFilterChains(cfg, cfg.Options.GRPCAddr,
func(tlsDomain string, httpDomains []string) (*envoy_config_listener_v3.FilterChain, error) {
serverNames, err := getAllServerNames(cfg, cfg.Options.GRPCAddr)
if err != nil {
return nil, err
}
for _, serverName := range serverNames {
filterChain := &envoy_config_listener_v3.FilterChain{
Filters: []*envoy_config_listener_v3.Filter{filter},
}
if tlsDomain != "*" {
if serverName != "*" {
filterChain.FilterChainMatch = &envoy_config_listener_v3.FilterChainMatch{
ServerNames: []string{tlsDomain},
ServerNames: []string{serverName},
}
}
tlsContext := b.buildDownstreamTLSContext(ctx, cfg, tlsDomain)
tlsContext := b.buildDownstreamTLSContext(ctx, cfg, serverName)
if tlsContext != nil {
tlsConfig := marshalAny(tlsContext)
filterChain.TransportSocket = &envoy_config_core_v3.TransportSocket{
@ -476,18 +447,9 @@ func (b *Builder) buildGRPCListener(ctx context.Context, cfg *config.Config) (*e
},
}
}
return filterChain, nil
})
if err != nil {
return nil, err
li.FilterChains = append(li.FilterChains, filterChain)
}
li := newEnvoyListener("grpc-ingress")
li.Address = buildAddress(cfg.Options.GetGRPCAddr(), 443)
li.ListenerFilters = []*envoy_config_listener_v3.ListenerFilter{
TLSInspectorFilter(),
}
li.FilterChains = chains
return li, nil
}
@ -548,23 +510,23 @@ func (b *Builder) buildRouteConfiguration(name string, virtualHosts []*envoy_con
func (b *Builder) buildDownstreamTLSContext(ctx context.Context,
cfg *config.Config,
domain string,
serverName string,
) *envoy_extensions_transport_sockets_tls_v3.DownstreamTlsContext {
certs, err := cfg.AllCertificates()
if err != nil {
log.Warn(ctx).Str("domain", domain).Err(err).Msg("failed to get all certificates from config")
log.Warn(ctx).Str("domain", serverName).Err(err).Msg("failed to get all certificates from config")
return nil
}
cert, err := cryptutil.GetCertificateForDomain(certs, domain)
cert, err := cryptutil.GetCertificateForServerName(certs, serverName)
if err != nil {
log.Warn(ctx).Str("domain", domain).Err(err).Msg("failed to get certificate for domain")
log.Warn(ctx).Str("domain", serverName).Err(err).Msg("failed to get certificate for domain")
return nil
}
err = validateCertificate(cert)
if err != nil {
log.Warn(ctx).Str("domain", domain).Err(err).Msg("invalid certificate for domain")
log.Warn(ctx).Str("domain", serverName).Err(err).Msg("invalid certificate for domain")
return nil
}
@ -584,14 +546,14 @@ func (b *Builder) buildDownstreamTLSContext(ctx context.Context,
TlsParams: tlsParams,
TlsCertificates: []*envoy_extensions_transport_sockets_tls_v3.TlsCertificate{envoyCert},
AlpnProtocols: alpnProtocols,
ValidationContextType: b.buildDownstreamValidationContext(ctx, cfg, domain),
ValidationContextType: b.buildDownstreamValidationContext(ctx, cfg, serverName),
},
}
}
func (b *Builder) buildDownstreamValidationContext(ctx context.Context,
cfg *config.Config,
domain string,
serverName string,
) *envoy_extensions_transport_sockets_tls_v3.CommonTlsContext_ValidationContext {
needsClientCert := false
@ -599,7 +561,7 @@ func (b *Builder) buildDownstreamValidationContext(ctx context.Context,
needsClientCert = true
}
if !needsClientCert {
for _, p := range getPoliciesForDomain(cfg.Options, domain) {
for _, p := range getPoliciesForServerName(cfg.Options, serverName) {
if p.TLSDownstreamClientCA != "" {
needsClientCert = true
break
@ -632,40 +594,41 @@ func (b *Builder) buildDownstreamValidationContext(ctx context.Context,
return vc
}
func getAllRouteableDomains(options *config.Options, addr string) ([]string, error) {
allDomains := sets.NewSorted[string]()
func getAllRouteableHosts(options *config.Options, addr string) ([]string, error) {
allHosts := sets.NewSorted[string]()
if addr == options.Addr {
domains, err := options.GetAllRouteableHTTPDomains()
hosts, err := options.GetAllRouteableHTTPHosts()
if err != nil {
return nil, err
}
allDomains.Add(domains...)
allHosts.Add(hosts...)
}
if addr == options.GetGRPCAddr() {
domains, err := options.GetAllRouteableGRPCDomains()
hosts, err := options.GetAllRouteableGRPCHosts()
if err != nil {
return nil, err
}
allDomains.Add(domains...)
allHosts.Add(hosts...)
}
return allDomains.ToSlice(), nil
return allHosts.ToSlice(), nil
}
func getAllTLSDomains(cfg *config.Config, addr string) ([]string, error) {
domains := sets.NewSorted[string]()
func getAllServerNames(cfg *config.Config, addr string) ([]string, error) {
serverNames := sets.NewSorted[string]()
serverNames.Add("*")
routeableDomains, err := getAllRouteableDomains(cfg.Options, addr)
routeableHosts, err := getAllRouteableHosts(cfg.Options, addr)
if err != nil {
return nil, err
}
for _, hp := range routeableDomains {
if d, _, err := net.SplitHostPort(hp); err == nil {
domains.Add(d)
for _, hp := range routeableHosts {
if h, _, err := net.SplitHostPort(hp); err == nil {
serverNames.Add(h)
} else {
domains.Add(hp)
serverNames.Add(hp)
}
}
@ -674,24 +637,24 @@ func getAllTLSDomains(cfg *config.Config, addr string) ([]string, error) {
return nil, err
}
for i := range certs {
for _, domain := range cryptutil.GetCertificateDomains(&certs[i]) {
domains.Add(domain)
for _, domain := range cryptutil.GetCertificateServerNames(&certs[i]) {
serverNames.Add(domain)
}
}
return domains.ToSlice(), nil
return serverNames.ToSlice(), nil
}
func hostsMatchDomain(urls []*url.URL, host string) bool {
func urlsMatchHost(urls []*url.URL, host string) bool {
for _, u := range urls {
if hostMatchesDomain(u, host) {
if urlMatchesHost(u, host) {
return true
}
}
return false
}
func hostMatchesDomain(u *url.URL, host string) bool {
func urlMatchesHost(u *url.URL, host string) bool {
if u == nil {
return false
}
@ -718,10 +681,10 @@ func hostMatchesDomain(u *url.URL, host string) bool {
return h1 == h2 && p1 == p2
}
func getPoliciesForDomain(options *config.Options, domain string) []config.Policy {
func getPoliciesForServerName(options *config.Options, serverName string) []config.Policy {
var policies []config.Policy
for _, p := range options.GetAllPolicies() {
if p.Source != nil && p.Source.URL.Hostname() == domain {
if p.Source != nil && urlutil.MatchesServerName(*p.Source.URL, serverName) {
policies = append(policies, p)
}
}

View file

@ -129,7 +129,8 @@ func Test_buildMainHTTPConnectionManagerFilter(t *testing.T) {
options := config.NewDefaultOptions()
options.SkipXffAppend = true
options.XffNumTrustedHops = 1
filter, err := b.buildMainHTTPConnectionManagerFilter(options, []string{"example.com"}, true)
options.AuthenticateURLString = "https://authenticate.example.com"
filter, err := b.buildMainHTTPConnectionManagerFilter(options, true)
require.NoError(t, err)
testutil.AssertProtoJSONEqual(t, `{
"name": "envoy.filters.network.http_connection_manager",
@ -220,8 +221,8 @@ func Test_buildMainHTTPConnectionManagerFilter(t *testing.T) {
"name": "main",
"virtualHosts": [
{
"name": "example.com",
"domains": ["example.com"],
"name": "authenticate.example.com",
"domains": ["authenticate.example.com"],
"responseHeadersToAdd": [{
"appendAction": "OVERWRITE_IF_EXISTS_OR_ADD",
"header": {
@ -366,6 +367,216 @@ func Test_buildMainHTTPConnectionManagerFilter(t *testing.T) {
"disabled": true
}
}
},
{
"name": "pomerium-path-/oauth2/callback",
"match": {
"path": "/oauth2/callback"
},
"route": {
"cluster": "pomerium-control-plane-http"
},
"typedPerFilterConfig": {
"envoy.filters.http.ext_authz": {
"@type": "type.googleapis.com/envoy.extensions.filters.http.ext_authz.v3.ExtAuthzPerRoute",
"disabled": true
}
}
},
{
"name": "pomerium-path-/",
"match": {
"path": "/"
},
"route": {
"cluster": "pomerium-control-plane-http"
},
"typedPerFilterConfig": {
"envoy.filters.http.ext_authz": {
"@type": "type.googleapis.com/envoy.extensions.filters.http.ext_authz.v3.ExtAuthzPerRoute",
"disabled": true
}
}
}
]
},
{
"name": "authenticate.example.com:443",
"domains": ["authenticate.example.com:443"],
"responseHeadersToAdd": [{
"appendAction": "OVERWRITE_IF_EXISTS_OR_ADD",
"header": {
"key": "Strict-Transport-Security",
"value": "max-age=31536000; includeSubDomains; preload"
}
},
{
"appendAction": "OVERWRITE_IF_EXISTS_OR_ADD",
"header": {
"key": "X-Frame-Options",
"value": "SAMEORIGIN"
}
},
{
"appendAction": "OVERWRITE_IF_EXISTS_OR_ADD",
"header": {
"key": "X-XSS-Protection",
"value": "1; mode=block"
}
}],
"routes": [
{
"name": "pomerium-path-/.pomerium/jwt",
"match": {
"path": "/.pomerium/jwt"
},
"route": {
"cluster": "pomerium-control-plane-http"
}
},
{
"name": "pomerium-path-/.pomerium/webauthn",
"match": {
"path": "/.pomerium/webauthn"
},
"route": {
"cluster": "pomerium-control-plane-http"
}
},
{
"name": "pomerium-path-/ping",
"match": {
"path": "/ping"
},
"route": {
"cluster": "pomerium-control-plane-http"
},
"typedPerFilterConfig": {
"envoy.filters.http.ext_authz": {
"@type": "type.googleapis.com/envoy.extensions.filters.http.ext_authz.v3.ExtAuthzPerRoute",
"disabled": true
}
}
},
{
"name": "pomerium-path-/healthz",
"match": {
"path": "/healthz"
},
"route": {
"cluster": "pomerium-control-plane-http"
},
"typedPerFilterConfig": {
"envoy.filters.http.ext_authz": {
"@type": "type.googleapis.com/envoy.extensions.filters.http.ext_authz.v3.ExtAuthzPerRoute",
"disabled": true
}
}
},
{
"name": "pomerium-path-/.pomerium",
"match": {
"path": "/.pomerium"
},
"route": {
"cluster": "pomerium-control-plane-http"
},
"typedPerFilterConfig": {
"envoy.filters.http.ext_authz": {
"@type": "type.googleapis.com/envoy.extensions.filters.http.ext_authz.v3.ExtAuthzPerRoute",
"disabled": true
}
}
},
{
"name": "pomerium-prefix-/.pomerium/",
"match": {
"prefix": "/.pomerium/"
},
"route": {
"cluster": "pomerium-control-plane-http"
},
"typedPerFilterConfig": {
"envoy.filters.http.ext_authz": {
"@type": "type.googleapis.com/envoy.extensions.filters.http.ext_authz.v3.ExtAuthzPerRoute",
"disabled": true
}
}
},
{
"name": "pomerium-path-/.well-known/pomerium",
"match": {
"path": "/.well-known/pomerium"
},
"route": {
"cluster": "pomerium-control-plane-http"
},
"typedPerFilterConfig": {
"envoy.filters.http.ext_authz": {
"@type": "type.googleapis.com/envoy.extensions.filters.http.ext_authz.v3.ExtAuthzPerRoute",
"disabled": true
}
}
},
{
"name": "pomerium-prefix-/.well-known/pomerium/",
"match": {
"prefix": "/.well-known/pomerium/"
},
"route": {
"cluster": "pomerium-control-plane-http"
},
"typedPerFilterConfig": {
"envoy.filters.http.ext_authz": {
"@type": "type.googleapis.com/envoy.extensions.filters.http.ext_authz.v3.ExtAuthzPerRoute",
"disabled": true
}
}
},
{
"name": "pomerium-path-/robots.txt",
"match": {
"path": "/robots.txt"
},
"route": {
"cluster": "pomerium-control-plane-http"
},
"typedPerFilterConfig": {
"envoy.filters.http.ext_authz": {
"@type": "type.googleapis.com/envoy.extensions.filters.http.ext_authz.v3.ExtAuthzPerRoute",
"disabled": true
}
}
},
{
"name": "pomerium-path-/oauth2/callback",
"match": {
"path": "/oauth2/callback"
},
"route": {
"cluster": "pomerium-control-plane-http"
},
"typedPerFilterConfig": {
"envoy.filters.http.ext_authz": {
"@type": "type.googleapis.com/envoy.extensions.filters.http.ext_authz.v3.ExtAuthzPerRoute",
"disabled": true
}
}
},
{
"name": "pomerium-path-/",
"match": {
"path": "/"
},
"route": {
"cluster": "pomerium-control-plane-http"
},
"typedPerFilterConfig": {
"envoy.filters.http.ext_authz": {
"@type": "type.googleapis.com/envoy.extensions.filters.http.ext_authz.v3.ExtAuthzPerRoute",
"disabled": true
}
}
}
]
},
@ -779,7 +990,7 @@ func Test_getAllDomains(t *testing.T) {
}
t.Run("routable", func(t *testing.T) {
t.Run("http", func(t *testing.T) {
actual, err := getAllRouteableDomains(options, "127.0.0.1:9000")
actual, err := getAllRouteableHosts(options, "127.0.0.1:9000")
require.NoError(t, err)
expect := []string{
"a.example.com",
@ -794,7 +1005,7 @@ func Test_getAllDomains(t *testing.T) {
assert.Equal(t, expect, actual)
})
t.Run("grpc", func(t *testing.T) {
actual, err := getAllRouteableDomains(options, "127.0.0.1:9001")
actual, err := getAllRouteableHosts(options, "127.0.0.1:9001")
require.NoError(t, err)
expect := []string{
"authorize.example.com:9001",
@ -805,7 +1016,7 @@ func Test_getAllDomains(t *testing.T) {
t.Run("both", func(t *testing.T) {
newOptions := *options
newOptions.GRPCAddr = newOptions.Addr
actual, err := getAllRouteableDomains(&newOptions, "127.0.0.1:9000")
actual, err := getAllRouteableHosts(&newOptions, "127.0.0.1:9000")
require.NoError(t, err)
expect := []string{
"a.example.com",
@ -824,9 +1035,10 @@ func Test_getAllDomains(t *testing.T) {
})
t.Run("tls", func(t *testing.T) {
t.Run("http", func(t *testing.T) {
actual, err := getAllTLSDomains(&config.Config{Options: options}, "127.0.0.1:9000")
actual, err := getAllServerNames(&config.Config{Options: options}, "127.0.0.1:9000")
require.NoError(t, err)
expect := []string{
"*",
"*.unknown.example.com",
"a.example.com",
"authenticate.example.com",
@ -836,9 +1048,10 @@ func Test_getAllDomains(t *testing.T) {
assert.Equal(t, expect, actual)
})
t.Run("grpc", func(t *testing.T) {
actual, err := getAllTLSDomains(&config.Config{Options: options}, "127.0.0.1:9001")
actual, err := getAllServerNames(&config.Config{Options: options}, "127.0.0.1:9001")
require.NoError(t, err)
expect := []string{
"*",
"*.unknown.example.com",
"authorize.example.com",
"cache.example.com",
@ -848,14 +1061,31 @@ func Test_getAllDomains(t *testing.T) {
})
}
func Test_hostMatchesDomain(t *testing.T) {
assert.True(t, hostMatchesDomain(mustParseURL(t, "http://example.com"), "example.com"))
assert.True(t, hostMatchesDomain(mustParseURL(t, "http://example.com"), "example.com:80"))
assert.True(t, hostMatchesDomain(mustParseURL(t, "https://example.com"), "example.com:443"))
assert.True(t, hostMatchesDomain(mustParseURL(t, "https://example.com:443"), "example.com:443"))
assert.True(t, hostMatchesDomain(mustParseURL(t, "https://example.com:443"), "example.com"))
assert.False(t, hostMatchesDomain(mustParseURL(t, "http://example.com:81"), "example.com"))
assert.False(t, hostMatchesDomain(mustParseURL(t, "http://example.com:81"), "example.com:80"))
func Test_urlMatchesHost(t *testing.T) {
t.Parallel()
for _, tc := range []struct {
name string
sourceURL string
host string
matches bool
}{
{"no port", "http://example.com", "example.com", true},
{"host http port", "http://example.com", "example.com:80", true},
{"host https port", "https://example.com", "example.com:443", true},
{"with port", "https://example.com:443", "example.com:443", true},
{"url port", "https://example.com:443", "example.com", true},
{"non standard port", "http://example.com:81", "example.com", false},
{"non standard host port", "http://example.com:81", "example.com:80", false},
} {
tc := tc
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
assert.Equal(t, tc.matches, urlMatchesHost(mustParseURL(t, tc.sourceURL), tc.host),
"urlMatchesHost(%s,%s)", tc.sourceURL, tc.host)
})
}
}
func Test_buildRouteConfiguration(t *testing.T) {

View file

@ -47,12 +47,12 @@ func (b *Builder) buildGRPCRoutes() ([]*envoy_config_route_v3.Route, error) {
}}, nil
}
func (b *Builder) buildPomeriumHTTPRoutes(options *config.Options, domain string) ([]*envoy_config_route_v3.Route, error) {
func (b *Builder) buildPomeriumHTTPRoutes(options *config.Options, host string) ([]*envoy_config_route_v3.Route, error) {
var routes []*envoy_config_route_v3.Route
// if this is the pomerium proxy in front of the the authenticate service, don't add
// these routes since they will be handled by authenticate
isFrontingAuthenticate, err := isProxyFrontingAuthenticate(options, domain)
isFrontingAuthenticate, err := isProxyFrontingAuthenticate(options, host)
if err != nil {
return nil, err
}
@ -70,7 +70,7 @@ func (b *Builder) buildPomeriumHTTPRoutes(options *config.Options, domain string
b.buildControlPlanePrefixRoute("/.well-known/pomerium/", false),
)
// per #837, only add robots.txt if there are no unauthenticated routes
if !hasPublicPolicyMatchingURL(options, url.URL{Scheme: "https", Host: domain, Path: "/robots.txt"}) {
if !hasPublicPolicyMatchingURL(options, url.URL{Scheme: "https", Host: host, Path: "/robots.txt"}) {
routes = append(routes, b.buildControlPlanePathRoute("/robots.txt", false))
}
}
@ -79,7 +79,7 @@ func (b *Builder) buildPomeriumHTTPRoutes(options *config.Options, domain string
if err != nil {
return nil, err
}
if config.IsAuthenticate(options.Services) && hostMatchesDomain(authenticateURL, domain) {
if config.IsAuthenticate(options.Services) && urlMatchesHost(authenticateURL, host) {
routes = append(routes,
b.buildControlPlanePathRoute(options.AuthenticateCallbackPath, false),
b.buildControlPlanePathRoute("/", false),
@ -151,12 +151,12 @@ func getClusterStatsName(policy *config.Policy) string {
return ""
}
func (b *Builder) buildPolicyRoutes(options *config.Options, domain string) ([]*envoy_config_route_v3.Route, error) {
func (b *Builder) buildPolicyRoutes(options *config.Options, host string) ([]*envoy_config_route_v3.Route, error) {
var routes []*envoy_config_route_v3.Route
for i, p := range options.GetAllPolicies() {
policy := p
if !hostMatchesDomain(policy.Source.URL, domain) {
if !urlMatchesHost(policy.Source.URL, host) {
continue
}
@ -188,7 +188,7 @@ func (b *Builder) buildPolicyRoutes(options *config.Options, domain string) ([]*
}
// disable authentication entirely when the proxy is fronting authenticate
isFrontingAuthenticate, err := isProxyFrontingAuthenticate(options, domain)
isFrontingAuthenticate, err := isProxyFrontingAuthenticate(options, host)
if err != nil {
return nil, err
}
@ -497,13 +497,13 @@ func hasPublicPolicyMatchingURL(options *config.Options, requestURL url.URL) boo
return false
}
func isProxyFrontingAuthenticate(options *config.Options, domain string) (bool, error) {
func isProxyFrontingAuthenticate(options *config.Options, host string) (bool, error) {
authenticateURL, err := options.GetAuthenticateURL()
if err != nil {
return false, err
}
if !config.IsAuthenticate(options.Services) && hostMatchesDomain(authenticateURL, domain) {
if !config.IsAuthenticate(options.Services) && urlMatchesHost(authenticateURL, host) {
return true, nil
}

View file

@ -1015,15 +1015,9 @@ func (o *Options) GetCodecType() CodecType {
return o.CodecType
}
// GetAllRouteableGRPCDomains returns all the possible gRPC domains handled by the Pomerium options.
func (o *Options) GetAllRouteableGRPCDomains() ([]string, error) {
return o.GetAllRouteableGRPCDomainsForTLSServerName("")
}
// GetAllRouteableGRPCDomainsForTLSServerName returns all the possible gRPC domains handled by the Pomerium options
// for the given TLS server name.
func (o *Options) GetAllRouteableGRPCDomainsForTLSServerName(tlsServerName string) ([]string, error) {
domains := sets.NewSorted[string]()
// GetAllRouteableGRPCHosts returns all the possible gRPC hosts handled by the Pomerium options.
func (o *Options) GetAllRouteableGRPCHosts() ([]string, error) {
hosts := sets.NewSorted[string]()
// authorize urls
if IsAll(o.Services) {
@ -1032,11 +1026,7 @@ func (o *Options) GetAllRouteableGRPCDomainsForTLSServerName(tlsServerName strin
return nil, err
}
for _, u := range authorizeURLs {
for _, h := range urlutil.GetDomainsForURL(*u) {
if tlsServerName == "" || urlutil.StripPort(h) == tlsServerName {
domains.Add(h)
}
}
hosts.Add(urlutil.GetDomainsForURL(*u)...)
}
} else if IsAuthorize(o.Services) {
authorizeURLs, err := o.GetInternalAuthorizeURLs()
@ -1044,11 +1034,7 @@ func (o *Options) GetAllRouteableGRPCDomainsForTLSServerName(tlsServerName strin
return nil, err
}
for _, u := range authorizeURLs {
for _, h := range urlutil.GetDomainsForURL(*u) {
if tlsServerName == "" || urlutil.StripPort(h) == tlsServerName {
domains.Add(h)
}
}
hosts.Add(urlutil.GetDomainsForURL(*u)...)
}
}
@ -1059,11 +1045,7 @@ func (o *Options) GetAllRouteableGRPCDomainsForTLSServerName(tlsServerName strin
return nil, err
}
for _, u := range dataBrokerURLs {
for _, h := range urlutil.GetDomainsForURL(*u) {
if tlsServerName == "" || urlutil.StripPort(h) == tlsServerName {
domains.Add(h)
}
}
hosts.Add(urlutil.GetDomainsForURL(*u)...)
}
} else if IsDataBroker(o.Services) {
dataBrokerURLs, err := o.GetInternalDataBrokerURLs()
@ -1071,71 +1053,42 @@ func (o *Options) GetAllRouteableGRPCDomainsForTLSServerName(tlsServerName strin
return nil, err
}
for _, u := range dataBrokerURLs {
for _, h := range urlutil.GetDomainsForURL(*u) {
if tlsServerName == "" || urlutil.StripPort(h) == tlsServerName {
domains.Add(h)
}
}
hosts.Add(urlutil.GetDomainsForURL(*u)...)
}
}
return domains.ToSlice(), nil
return hosts.ToSlice(), nil
}
// GetAllRouteableHTTPDomains returns all the possible HTTP domains handled by the Pomerium options.
func (o *Options) GetAllRouteableHTTPDomains() ([]string, error) {
return o.GetAllRouteableHTTPDomainsForTLSServerName("")
}
// GetAllRouteableHTTPDomainsForTLSServerName returns all the possible HTTP domains handled by the Pomerium options
// for the given TLS server name.
func (o *Options) GetAllRouteableHTTPDomainsForTLSServerName(tlsServerName string) ([]string, error) {
domains := sets.NewSorted[string]()
// GetAllRouteableHTTPHosts returns all the possible HTTP hosts handled by the Pomerium options.
func (o *Options) GetAllRouteableHTTPHosts() ([]string, error) {
hosts := sets.NewSorted[string]()
if IsAuthenticate(o.Services) {
authenticateURL, err := o.GetInternalAuthenticateURL()
if err != nil {
return nil, err
}
for _, h := range urlutil.GetDomainsForURL(*authenticateURL) {
if tlsServerName == "" || urlutil.StripPort(h) == tlsServerName {
domains.Add(h)
}
}
hosts.Add(urlutil.GetDomainsForURL(*authenticateURL)...)
authenticateURL, err = o.GetAuthenticateURL()
if err != nil {
return nil, err
}
for _, h := range urlutil.GetDomainsForURL(*authenticateURL) {
if tlsServerName == "" || urlutil.StripPort(h) == tlsServerName {
domains.Add(h)
}
}
hosts.Add(urlutil.GetDomainsForURL(*authenticateURL)...)
}
// policy urls
if IsProxy(o.Services) {
for _, policy := range o.GetAllPolicies() {
for _, h := range urlutil.GetDomainsForURL(*policy.Source.URL) {
if tlsServerName == "" ||
policy.TLSDownstreamServerName == tlsServerName ||
urlutil.StripPort(h) == tlsServerName {
domains.Add(h)
}
}
hosts.Add(urlutil.GetDomainsForURL(*policy.Source.URL)...)
if policy.TLSDownstreamServerName != "" {
tlsURL := policy.Source.URL.ResolveReference(&url.URL{Host: policy.TLSDownstreamServerName})
for _, h := range urlutil.GetDomainsForURL(*tlsURL) {
if tlsServerName == "" ||
urlutil.StripPort(h) == tlsServerName {
domains.Add(h)
}
}
hosts.Add(urlutil.GetDomainsForURL(*tlsURL)...)
}
}
}
return domains.ToSlice(), nil
return hosts.ToSlice(), nil
}
// GetClientSecret gets the client secret.

View file

@ -666,14 +666,14 @@ func TestOptions_GetOauthOptions(t *testing.T) {
assert.Equal(t, u.Hostname(), oauthOptions.RedirectURL.Hostname())
}
func TestOptions_GetAllRouteableGRPCDomains(t *testing.T) {
func TestOptions_GetAllRouteableGRPCHosts(t *testing.T) {
opts := &Options{
AuthenticateURLString: "https://authenticate.example.com",
AuthorizeURLString: "https://authorize.example.com",
DataBrokerURLString: "https://databroker.example.com",
Services: "all",
}
domains, err := opts.GetAllRouteableGRPCDomains()
hosts, err := opts.GetAllRouteableGRPCHosts()
assert.NoError(t, err)
assert.Equal(t, []string{
@ -681,10 +681,10 @@ func TestOptions_GetAllRouteableGRPCDomains(t *testing.T) {
"authorize.example.com:443",
"databroker.example.com",
"databroker.example.com:443",
}, domains)
}, hosts)
}
func TestOptions_GetAllRouteableHTTPDomains(t *testing.T) {
func TestOptions_GetAllRouteableHTTPHosts(t *testing.T) {
p1 := Policy{From: "https://from1.example.com"}
p1.Validate()
p2 := Policy{From: "https://from2.example.com"}
@ -699,7 +699,7 @@ func TestOptions_GetAllRouteableHTTPDomains(t *testing.T) {
Policies: []Policy{p1, p2, p3},
Services: "all",
}
domains, err := opts.GetAllRouteableHTTPDomains()
hosts, err := opts.GetAllRouteableHTTPHosts()
assert.NoError(t, err)
assert.Equal(t, []string{
@ -713,7 +713,7 @@ func TestOptions_GetAllRouteableHTTPDomains(t *testing.T) {
"from2.example.com:443",
"from3.example.com",
"from3.example.com:443",
}, domains)
}, hosts)
}
func TestOptions_ApplySettings(t *testing.T) {

View file

@ -8,6 +8,8 @@ import (
"net/url"
"strings"
"time"
"github.com/caddyserver/certmagic"
)
const (
@ -160,3 +162,8 @@ func GetExternalRequest(internalURL, externalURL *url.URL, r *http.Request) *htt
}
return er
}
// MatchesServerName returnes true if the url's host matches the given server name.
func MatchesServerName(u url.URL, serverName string) bool {
return certmagic.MatchWildcard(u.Hostname(), serverName)
}

View file

@ -166,3 +166,9 @@ func TestJoin(t *testing.T) {
assert.Equal(t, "/x/y/z/", Join("/x", "/y/z/"))
assert.Equal(t, "/x/y/z/", Join("/x/", "/y/z/"))
}
func TestMatchesServerName(t *testing.T) {
t.Run("wildcard", func(t *testing.T) {
assert.True(t, MatchesServerName(MustParseAndValidateURL("https://domain.example.com"), "*.example.com"))
})
}

View file

@ -44,36 +44,36 @@ func GetCertPool(ca, caFile string) (*x509.CertPool, error) {
return rootCAs, nil
}
// GetCertificateForDomain returns the tls Certificate which matches the given domain name.
// GetCertificateForServerName returns the tls Certificate which matches the given server name.
// It should handle both exact matches and wildcard matches. If none of those match, the first certificate will be used.
// Finally if there are no matching certificates one will be generated.
func GetCertificateForDomain(certificates []tls.Certificate, domain string) (*tls.Certificate, error) {
func GetCertificateForServerName(certificates []tls.Certificate, serverName string) (*tls.Certificate, error) {
// first try a direct name match
for i := range certificates {
if matchesDomain(&certificates[i], domain) {
if matchesServerName(&certificates[i], serverName) {
return &certificates[i], nil
}
}
log.WarnNoTLSCertificate(domain)
log.WarnNoTLSCertificate(serverName)
// finally fall back to a generated, self-signed certificate
return GenerateSelfSignedCertificate(domain)
return GenerateSelfSignedCertificate(serverName)
}
// HasCertificateForDomain returns true if a TLS certificate matches the given domain.
func HasCertificateForDomain(certificates []tls.Certificate, domain string) bool {
// HasCertificateForServerName returns true if a TLS certificate matches the given server name.
func HasCertificateForServerName(certificates []tls.Certificate, serverName string) bool {
for i := range certificates {
if matchesDomain(&certificates[i], domain) {
if matchesServerName(&certificates[i], serverName) {
return true
}
}
return false
}
// GetCertificateDomains gets all the certificate's matching domain names.
// GetCertificateServerNames gets all the certificate's server names.
// Will return an empty slice if certificate is nil, empty, or x509 parsing fails.
func GetCertificateDomains(cert *tls.Certificate) []string {
func GetCertificateServerNames(cert *tls.Certificate) []string {
if cert == nil || len(cert.Certificate) == 0 {
return nil
}
@ -83,19 +83,19 @@ func GetCertificateDomains(cert *tls.Certificate) []string {
return nil
}
var domains []string
var serverNames []string
if xcert.Subject.CommonName != "" {
domains = append(domains, xcert.Subject.CommonName)
serverNames = append(serverNames, xcert.Subject.CommonName)
}
for _, dnsName := range xcert.DNSNames {
if dnsName != "" {
domains = append(domains, dnsName)
serverNames = append(serverNames, dnsName)
}
}
return domains
return serverNames
}
func matchesDomain(cert *tls.Certificate, domain string) bool {
func matchesServerName(cert *tls.Certificate, serverName string) bool {
if cert == nil || len(cert.Certificate) == 0 {
return false
}
@ -105,12 +105,12 @@ func matchesDomain(cert *tls.Certificate, domain string) bool {
return false
}
if certmagic.MatchWildcard(domain, xcert.Subject.CommonName) {
if certmagic.MatchWildcard(serverName, xcert.Subject.CommonName) {
return true
}
for _, san := range xcert.DNSNames {
if certmagic.MatchWildcard(domain, san) {
if certmagic.MatchWildcard(serverName, san) {
return true
}
}

View file

@ -8,10 +8,10 @@ import (
"github.com/stretchr/testify/require"
)
func TestGetCertificateForDomain(t *testing.T) {
gen := func(t *testing.T, domain string) *tls.Certificate {
cert, err := GenerateSelfSignedCertificate(domain)
if !assert.NoError(t, err, "error generating certificate for: %s", domain) {
func TestGetCertificateForServerName(t *testing.T) {
gen := func(t *testing.T, serverName string) *tls.Certificate {
cert, err := GenerateSelfSignedCertificate(serverName)
if !assert.NoError(t, err, "error generating certificate for: %s", serverName) {
t.FailNow()
}
return cert
@ -23,7 +23,7 @@ func TestGetCertificateForDomain(t *testing.T) {
*gen(t, "b.example.com"),
}
found, err := GetCertificateForDomain(certs, "b.example.com")
found, err := GetCertificateForServerName(certs, "b.example.com")
if !assert.NoError(t, err) {
return
}
@ -35,7 +35,7 @@ func TestGetCertificateForDomain(t *testing.T) {
*gen(t, "*.example.com"),
}
found, err := GetCertificateForDomain(certs, "b.example.com")
found, err := GetCertificateForServerName(certs, "b.example.com")
if !assert.NoError(t, err) {
return
}
@ -46,7 +46,7 @@ func TestGetCertificateForDomain(t *testing.T) {
*gen(t, "a.example.com"),
}
found, err := GetCertificateForDomain(certs, "b.example.com")
found, err := GetCertificateForServerName(certs, "b.example.com")
if !assert.NoError(t, err) {
return
}
@ -56,7 +56,7 @@ func TestGetCertificateForDomain(t *testing.T) {
t.Run("generate", func(t *testing.T) {
certs := []tls.Certificate{}
found, err := GetCertificateForDomain(certs, "b.example.com")
found, err := GetCertificateForServerName(certs, "b.example.com")
if !assert.NoError(t, err) {
return
}
@ -64,8 +64,8 @@ func TestGetCertificateForDomain(t *testing.T) {
})
}
func TestGetCertificateDomains(t *testing.T) {
func TestGetCertificateServerNames(t *testing.T) {
cert, err := GenerateSelfSignedCertificate("www.example.com")
require.NoError(t, err)
assert.Equal(t, []string{"www.example.com"}, GetCertificateDomains(cert))
assert.Equal(t, []string{"www.example.com"}, GetCertificateServerNames(cert))
}