envoyconfig: clean up filter chain construction (#3844)

* cleanup filter chain construction

* rename domains to server names

* rename to hosts

* fix tests

* update function name

* improved domaain matching
This commit is contained in:
Caleb Doxsey 2022-12-27 10:07:26 -07:00 committed by GitHub
parent a49f86d023
commit 67e12101fa
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
10 changed files with 405 additions and 246 deletions

View file

@ -12,16 +12,16 @@ import (
func (b *Builder) buildVirtualHost( func (b *Builder) buildVirtualHost(
options *config.Options, options *config.Options,
name string, name string,
domain string, host string,
requireStrictTransportSecurity bool, requireStrictTransportSecurity bool,
) (*envoy_config_route_v3.VirtualHost, error) { ) (*envoy_config_route_v3.VirtualHost, error) {
vh := &envoy_config_route_v3.VirtualHost{ vh := &envoy_config_route_v3.VirtualHost{
Name: name, Name: name,
Domains: []string{domain}, Domains: []string{host},
} }
// these routes match /.pomerium/... and similar paths // these routes match /.pomerium/... and similar paths
rs, err := b.buildPomeriumHTTPRoutes(options, domain) rs, err := b.buildPomeriumHTTPRoutes(options, host)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View file

@ -25,6 +25,7 @@ import (
"github.com/pomerium/pomerium/internal/log" "github.com/pomerium/pomerium/internal/log"
"github.com/pomerium/pomerium/internal/sets" "github.com/pomerium/pomerium/internal/sets"
"github.com/pomerium/pomerium/internal/telemetry/metrics" "github.com/pomerium/pomerium/internal/telemetry/metrics"
"github.com/pomerium/pomerium/internal/urlutil"
"github.com/pomerium/pomerium/pkg/cryptutil" "github.com/pomerium/pomerium/pkg/cryptutil"
) )
@ -99,51 +100,50 @@ func (b *Builder) BuildListeners(ctx context.Context, cfg *config.Config) ([]*en
} }
func (b *Builder) buildMainListener(ctx context.Context, cfg *config.Config) (*envoy_config_listener_v3.Listener, error) { func (b *Builder) buildMainListener(ctx context.Context, cfg *config.Config) (*envoy_config_listener_v3.Listener, error) {
listenerFilters := []*envoy_config_listener_v3.ListenerFilter{} li := newEnvoyListener("http-ingress")
if cfg.Options.UseProxyProtocol { if cfg.Options.UseProxyProtocol {
listenerFilters = append(listenerFilters, ProxyProtocolFilter()) li.ListenerFilters = append(li.ListenerFilters, ProxyProtocolFilter())
} }
if cfg.Options.InsecureServer { if cfg.Options.InsecureServer {
allDomains, err := getAllRouteableDomains(cfg.Options, cfg.Options.Addr)
if err != nil {
return nil, err
}
filter, err := b.buildMainHTTPConnectionManagerFilter(cfg.Options, allDomains, false)
if err != nil {
return nil, err
}
li := newEnvoyListener("http-ingress")
li.Address = buildAddress(cfg.Options.Addr, 80) li.Address = buildAddress(cfg.Options.Addr, 80)
li.ListenerFilters = listenerFilters
filter, err := b.buildMainHTTPConnectionManagerFilter(cfg.Options, false)
if err != nil {
return nil, err
}
li.FilterChains = []*envoy_config_listener_v3.FilterChain{{ li.FilterChains = []*envoy_config_listener_v3.FilterChain{{
Filters: []*envoy_config_listener_v3.Filter{ Filters: []*envoy_config_listener_v3.Filter{
filter, filter,
}, },
}} }}
return li, nil } else {
} li.Address = buildAddress(cfg.Options.Addr, 443)
listenerFilters = append(listenerFilters, TLSInspectorFilter()) li.ListenerFilters = append(li.ListenerFilters, TLSInspectorFilter())
chains, err := b.buildFilterChains(cfg, cfg.Options.Addr, allCertificates, _ := cfg.AllCertificates()
func(tlsDomain string, httpDomains []string) (*envoy_config_listener_v3.FilterChain, error) {
allCertificates, _ := cfg.AllCertificates() serverNames, err := getAllServerNames(cfg, cfg.Options.Addr)
requireStrictTransportSecurity := cryptutil.HasCertificateForDomain(allCertificates, tlsDomain) if err != nil {
filter, err := b.buildMainHTTPConnectionManagerFilter(cfg.Options, httpDomains, requireStrictTransportSecurity) return nil, err
}
for _, serverName := range serverNames {
requireStrictTransportSecurity := cryptutil.HasCertificateForServerName(allCertificates, serverName)
filter, err := b.buildMainHTTPConnectionManagerFilter(cfg.Options, requireStrictTransportSecurity)
if err != nil { if err != nil {
return nil, err return nil, err
} }
filterChain := &envoy_config_listener_v3.FilterChain{ filterChain := &envoy_config_listener_v3.FilterChain{
Filters: []*envoy_config_listener_v3.Filter{filter}, Filters: []*envoy_config_listener_v3.Filter{filter},
} }
if tlsDomain != "*" { if serverName != "*" {
filterChain.FilterChainMatch = &envoy_config_listener_v3.FilterChainMatch{ filterChain.FilterChainMatch = &envoy_config_listener_v3.FilterChainMatch{
ServerNames: []string{tlsDomain}, ServerNames: []string{serverName},
} }
} }
tlsContext := b.buildDownstreamTLSContext(ctx, cfg, tlsDomain) tlsContext := b.buildDownstreamTLSContext(ctx, cfg, serverName)
if tlsContext != nil { if tlsContext != nil {
tlsConfig := marshalAny(tlsContext) tlsConfig := marshalAny(tlsContext)
filterChain.TransportSocket = &envoy_config_core_v3.TransportSocket{ filterChain.TransportSocket = &envoy_config_core_v3.TransportSocket{
@ -153,16 +153,9 @@ func (b *Builder) buildMainListener(ctx context.Context, cfg *config.Config) (*e
}, },
} }
} }
return filterChain, nil li.FilterChains = append(li.FilterChains, filterChain)
}) }
if err != nil {
return nil, err
} }
li := newEnvoyListener("https-ingress")
li.Address = buildAddress(cfg.Options.Addr, 443)
li.ListenerFilters = listenerFilters
li.FilterChains = chains
return li, nil return li, nil
} }
@ -245,42 +238,8 @@ func (b *Builder) buildMetricsListener(cfg *config.Config) (*envoy_config_listen
return li, nil return li, nil
} }
func (b *Builder) buildFilterChains(
cfg *config.Config, addr string,
callback func(tlsDomain string, httpDomains []string) (*envoy_config_listener_v3.FilterChain, error),
) ([]*envoy_config_listener_v3.FilterChain, error) {
allDomains, err := getAllRouteableDomains(cfg.Options, addr)
if err != nil {
return nil, err
}
tlsDomains, err := getAllTLSDomains(cfg, addr)
if err != nil {
return nil, err
}
var chains []*envoy_config_listener_v3.FilterChain
chains = append(chains, b.buildACMETLSALPNFilterChain())
for _, domain := range tlsDomains {
chain, err := callback(domain, allDomains)
if err != nil {
return nil, err
}
chains = append(chains, chain)
}
// if there are no SNI matches we match on HTTP host
chain, err := callback("*", allDomains)
if err != nil {
return nil, err
}
chains = append(chains, chain)
return chains, nil
}
func (b *Builder) buildMainHTTPConnectionManagerFilter( func (b *Builder) buildMainHTTPConnectionManagerFilter(
options *config.Options, options *config.Options,
domains []string,
requireStrictTransportSecurity bool, requireStrictTransportSecurity bool,
) (*envoy_config_listener_v3.Filter, error) { ) (*envoy_config_listener_v3.Filter, error) {
authorizeURLs, err := options.GetInternalAuthorizeURLs() authorizeURLs, err := options.GetInternalAuthorizeURLs()
@ -293,17 +252,22 @@ func (b *Builder) buildMainHTTPConnectionManagerFilter(
return nil, err return nil, err
} }
allHosts, err := getAllRouteableHosts(options, options.Addr)
if err != nil {
return nil, err
}
var virtualHosts []*envoy_config_route_v3.VirtualHost var virtualHosts []*envoy_config_route_v3.VirtualHost
for _, domain := range domains { for _, host := range allHosts {
vh, err := b.buildVirtualHost(options, domain, domain, requireStrictTransportSecurity) vh, err := b.buildVirtualHost(options, host, host, requireStrictTransportSecurity)
if err != nil { if err != nil {
return nil, err return nil, err
} }
if options.Addr == options.GetGRPCAddr() { if options.Addr == options.GetGRPCAddr() {
// if this is a gRPC service domain and we're supposed to handle that, add those routes // if this is a gRPC service domain and we're supposed to handle that, add those routes
if (config.IsAuthorize(options.Services) && hostsMatchDomain(authorizeURLs, domain)) || if (config.IsAuthorize(options.Services) && urlsMatchHost(authorizeURLs, host)) ||
(config.IsDataBroker(options.Services) && hostsMatchDomain(dataBrokerURLs, domain)) { (config.IsDataBroker(options.Services) && urlsMatchHost(dataBrokerURLs, host)) {
rs, err := b.buildGRPCRoutes() rs, err := b.buildGRPCRoutes()
if err != nil { if err != nil {
return nil, err return nil, err
@ -314,7 +278,7 @@ func (b *Builder) buildMainHTTPConnectionManagerFilter(
// if we're the proxy, add all the policy routes // if we're the proxy, add all the policy routes
if config.IsProxy(options.Services) { if config.IsProxy(options.Services) {
rs, err := b.buildPolicyRoutes(options, domain) rs, err := b.buildPolicyRoutes(options, host)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -445,28 +409,35 @@ func (b *Builder) buildGRPCListener(ctx context.Context, cfg *config.Config) (*e
return nil, err return nil, err
} }
li := newEnvoyListener("grpc-ingress")
if cfg.Options.GetGRPCInsecure() { if cfg.Options.GetGRPCInsecure() {
li := newEnvoyListener("grpc-ingress")
li.Address = buildAddress(cfg.Options.GetGRPCAddr(), 80) li.Address = buildAddress(cfg.Options.GetGRPCAddr(), 80)
li.FilterChains = []*envoy_config_listener_v3.FilterChain{{ li.FilterChains = []*envoy_config_listener_v3.FilterChain{{
Filters: []*envoy_config_listener_v3.Filter{ Filters: []*envoy_config_listener_v3.Filter{
filter, filter,
}, },
}} }}
return li, nil } else {
} li.Address = buildAddress(cfg.Options.GetGRPCAddr(), 443)
li.ListenerFilters = []*envoy_config_listener_v3.ListenerFilter{
TLSInspectorFilter(),
}
chains, err := b.buildFilterChains(cfg, cfg.Options.GRPCAddr, serverNames, err := getAllServerNames(cfg, cfg.Options.GRPCAddr)
func(tlsDomain string, httpDomains []string) (*envoy_config_listener_v3.FilterChain, error) { if err != nil {
return nil, err
}
for _, serverName := range serverNames {
filterChain := &envoy_config_listener_v3.FilterChain{ filterChain := &envoy_config_listener_v3.FilterChain{
Filters: []*envoy_config_listener_v3.Filter{filter}, Filters: []*envoy_config_listener_v3.Filter{filter},
} }
if tlsDomain != "*" { if serverName != "*" {
filterChain.FilterChainMatch = &envoy_config_listener_v3.FilterChainMatch{ filterChain.FilterChainMatch = &envoy_config_listener_v3.FilterChainMatch{
ServerNames: []string{tlsDomain}, ServerNames: []string{serverName},
} }
} }
tlsContext := b.buildDownstreamTLSContext(ctx, cfg, tlsDomain) tlsContext := b.buildDownstreamTLSContext(ctx, cfg, serverName)
if tlsContext != nil { if tlsContext != nil {
tlsConfig := marshalAny(tlsContext) tlsConfig := marshalAny(tlsContext)
filterChain.TransportSocket = &envoy_config_core_v3.TransportSocket{ filterChain.TransportSocket = &envoy_config_core_v3.TransportSocket{
@ -476,18 +447,9 @@ func (b *Builder) buildGRPCListener(ctx context.Context, cfg *config.Config) (*e
}, },
} }
} }
return filterChain, nil li.FilterChains = append(li.FilterChains, filterChain)
}) }
if err != nil {
return nil, err
} }
li := newEnvoyListener("grpc-ingress")
li.Address = buildAddress(cfg.Options.GetGRPCAddr(), 443)
li.ListenerFilters = []*envoy_config_listener_v3.ListenerFilter{
TLSInspectorFilter(),
}
li.FilterChains = chains
return li, nil return li, nil
} }
@ -548,23 +510,23 @@ func (b *Builder) buildRouteConfiguration(name string, virtualHosts []*envoy_con
func (b *Builder) buildDownstreamTLSContext(ctx context.Context, func (b *Builder) buildDownstreamTLSContext(ctx context.Context,
cfg *config.Config, cfg *config.Config,
domain string, serverName string,
) *envoy_extensions_transport_sockets_tls_v3.DownstreamTlsContext { ) *envoy_extensions_transport_sockets_tls_v3.DownstreamTlsContext {
certs, err := cfg.AllCertificates() certs, err := cfg.AllCertificates()
if err != nil { if err != nil {
log.Warn(ctx).Str("domain", domain).Err(err).Msg("failed to get all certificates from config") log.Warn(ctx).Str("domain", serverName).Err(err).Msg("failed to get all certificates from config")
return nil return nil
} }
cert, err := cryptutil.GetCertificateForDomain(certs, domain) cert, err := cryptutil.GetCertificateForServerName(certs, serverName)
if err != nil { if err != nil {
log.Warn(ctx).Str("domain", domain).Err(err).Msg("failed to get certificate for domain") log.Warn(ctx).Str("domain", serverName).Err(err).Msg("failed to get certificate for domain")
return nil return nil
} }
err = validateCertificate(cert) err = validateCertificate(cert)
if err != nil { if err != nil {
log.Warn(ctx).Str("domain", domain).Err(err).Msg("invalid certificate for domain") log.Warn(ctx).Str("domain", serverName).Err(err).Msg("invalid certificate for domain")
return nil return nil
} }
@ -584,14 +546,14 @@ func (b *Builder) buildDownstreamTLSContext(ctx context.Context,
TlsParams: tlsParams, TlsParams: tlsParams,
TlsCertificates: []*envoy_extensions_transport_sockets_tls_v3.TlsCertificate{envoyCert}, TlsCertificates: []*envoy_extensions_transport_sockets_tls_v3.TlsCertificate{envoyCert},
AlpnProtocols: alpnProtocols, AlpnProtocols: alpnProtocols,
ValidationContextType: b.buildDownstreamValidationContext(ctx, cfg, domain), ValidationContextType: b.buildDownstreamValidationContext(ctx, cfg, serverName),
}, },
} }
} }
func (b *Builder) buildDownstreamValidationContext(ctx context.Context, func (b *Builder) buildDownstreamValidationContext(ctx context.Context,
cfg *config.Config, cfg *config.Config,
domain string, serverName string,
) *envoy_extensions_transport_sockets_tls_v3.CommonTlsContext_ValidationContext { ) *envoy_extensions_transport_sockets_tls_v3.CommonTlsContext_ValidationContext {
needsClientCert := false needsClientCert := false
@ -599,7 +561,7 @@ func (b *Builder) buildDownstreamValidationContext(ctx context.Context,
needsClientCert = true needsClientCert = true
} }
if !needsClientCert { if !needsClientCert {
for _, p := range getPoliciesForDomain(cfg.Options, domain) { for _, p := range getPoliciesForServerName(cfg.Options, serverName) {
if p.TLSDownstreamClientCA != "" { if p.TLSDownstreamClientCA != "" {
needsClientCert = true needsClientCert = true
break break
@ -632,40 +594,41 @@ func (b *Builder) buildDownstreamValidationContext(ctx context.Context,
return vc return vc
} }
func getAllRouteableDomains(options *config.Options, addr string) ([]string, error) { func getAllRouteableHosts(options *config.Options, addr string) ([]string, error) {
allDomains := sets.NewSorted[string]() allHosts := sets.NewSorted[string]()
if addr == options.Addr { if addr == options.Addr {
domains, err := options.GetAllRouteableHTTPDomains() hosts, err := options.GetAllRouteableHTTPHosts()
if err != nil { if err != nil {
return nil, err return nil, err
} }
allDomains.Add(domains...) allHosts.Add(hosts...)
} }
if addr == options.GetGRPCAddr() { if addr == options.GetGRPCAddr() {
domains, err := options.GetAllRouteableGRPCDomains() hosts, err := options.GetAllRouteableGRPCHosts()
if err != nil { if err != nil {
return nil, err return nil, err
} }
allDomains.Add(domains...) allHosts.Add(hosts...)
} }
return allDomains.ToSlice(), nil return allHosts.ToSlice(), nil
} }
func getAllTLSDomains(cfg *config.Config, addr string) ([]string, error) { func getAllServerNames(cfg *config.Config, addr string) ([]string, error) {
domains := sets.NewSorted[string]() serverNames := sets.NewSorted[string]()
serverNames.Add("*")
routeableDomains, err := getAllRouteableDomains(cfg.Options, addr) routeableHosts, err := getAllRouteableHosts(cfg.Options, addr)
if err != nil { if err != nil {
return nil, err return nil, err
} }
for _, hp := range routeableDomains { for _, hp := range routeableHosts {
if d, _, err := net.SplitHostPort(hp); err == nil { if h, _, err := net.SplitHostPort(hp); err == nil {
domains.Add(d) serverNames.Add(h)
} else { } else {
domains.Add(hp) serverNames.Add(hp)
} }
} }
@ -674,24 +637,24 @@ func getAllTLSDomains(cfg *config.Config, addr string) ([]string, error) {
return nil, err return nil, err
} }
for i := range certs { for i := range certs {
for _, domain := range cryptutil.GetCertificateDomains(&certs[i]) { for _, domain := range cryptutil.GetCertificateServerNames(&certs[i]) {
domains.Add(domain) serverNames.Add(domain)
} }
} }
return domains.ToSlice(), nil return serverNames.ToSlice(), nil
} }
func hostsMatchDomain(urls []*url.URL, host string) bool { func urlsMatchHost(urls []*url.URL, host string) bool {
for _, u := range urls { for _, u := range urls {
if hostMatchesDomain(u, host) { if urlMatchesHost(u, host) {
return true return true
} }
} }
return false return false
} }
func hostMatchesDomain(u *url.URL, host string) bool { func urlMatchesHost(u *url.URL, host string) bool {
if u == nil { if u == nil {
return false return false
} }
@ -718,10 +681,10 @@ func hostMatchesDomain(u *url.URL, host string) bool {
return h1 == h2 && p1 == p2 return h1 == h2 && p1 == p2
} }
func getPoliciesForDomain(options *config.Options, domain string) []config.Policy { func getPoliciesForServerName(options *config.Options, serverName string) []config.Policy {
var policies []config.Policy var policies []config.Policy
for _, p := range options.GetAllPolicies() { for _, p := range options.GetAllPolicies() {
if p.Source != nil && p.Source.URL.Hostname() == domain { if p.Source != nil && urlutil.MatchesServerName(*p.Source.URL, serverName) {
policies = append(policies, p) policies = append(policies, p)
} }
} }

View file

@ -129,7 +129,8 @@ func Test_buildMainHTTPConnectionManagerFilter(t *testing.T) {
options := config.NewDefaultOptions() options := config.NewDefaultOptions()
options.SkipXffAppend = true options.SkipXffAppend = true
options.XffNumTrustedHops = 1 options.XffNumTrustedHops = 1
filter, err := b.buildMainHTTPConnectionManagerFilter(options, []string{"example.com"}, true) options.AuthenticateURLString = "https://authenticate.example.com"
filter, err := b.buildMainHTTPConnectionManagerFilter(options, true)
require.NoError(t, err) require.NoError(t, err)
testutil.AssertProtoJSONEqual(t, `{ testutil.AssertProtoJSONEqual(t, `{
"name": "envoy.filters.network.http_connection_manager", "name": "envoy.filters.network.http_connection_manager",
@ -220,8 +221,8 @@ func Test_buildMainHTTPConnectionManagerFilter(t *testing.T) {
"name": "main", "name": "main",
"virtualHosts": [ "virtualHosts": [
{ {
"name": "example.com", "name": "authenticate.example.com",
"domains": ["example.com"], "domains": ["authenticate.example.com"],
"responseHeadersToAdd": [{ "responseHeadersToAdd": [{
"appendAction": "OVERWRITE_IF_EXISTS_OR_ADD", "appendAction": "OVERWRITE_IF_EXISTS_OR_ADD",
"header": { "header": {
@ -366,6 +367,216 @@ func Test_buildMainHTTPConnectionManagerFilter(t *testing.T) {
"disabled": true "disabled": true
} }
} }
},
{
"name": "pomerium-path-/oauth2/callback",
"match": {
"path": "/oauth2/callback"
},
"route": {
"cluster": "pomerium-control-plane-http"
},
"typedPerFilterConfig": {
"envoy.filters.http.ext_authz": {
"@type": "type.googleapis.com/envoy.extensions.filters.http.ext_authz.v3.ExtAuthzPerRoute",
"disabled": true
}
}
},
{
"name": "pomerium-path-/",
"match": {
"path": "/"
},
"route": {
"cluster": "pomerium-control-plane-http"
},
"typedPerFilterConfig": {
"envoy.filters.http.ext_authz": {
"@type": "type.googleapis.com/envoy.extensions.filters.http.ext_authz.v3.ExtAuthzPerRoute",
"disabled": true
}
}
}
]
},
{
"name": "authenticate.example.com:443",
"domains": ["authenticate.example.com:443"],
"responseHeadersToAdd": [{
"appendAction": "OVERWRITE_IF_EXISTS_OR_ADD",
"header": {
"key": "Strict-Transport-Security",
"value": "max-age=31536000; includeSubDomains; preload"
}
},
{
"appendAction": "OVERWRITE_IF_EXISTS_OR_ADD",
"header": {
"key": "X-Frame-Options",
"value": "SAMEORIGIN"
}
},
{
"appendAction": "OVERWRITE_IF_EXISTS_OR_ADD",
"header": {
"key": "X-XSS-Protection",
"value": "1; mode=block"
}
}],
"routes": [
{
"name": "pomerium-path-/.pomerium/jwt",
"match": {
"path": "/.pomerium/jwt"
},
"route": {
"cluster": "pomerium-control-plane-http"
}
},
{
"name": "pomerium-path-/.pomerium/webauthn",
"match": {
"path": "/.pomerium/webauthn"
},
"route": {
"cluster": "pomerium-control-plane-http"
}
},
{
"name": "pomerium-path-/ping",
"match": {
"path": "/ping"
},
"route": {
"cluster": "pomerium-control-plane-http"
},
"typedPerFilterConfig": {
"envoy.filters.http.ext_authz": {
"@type": "type.googleapis.com/envoy.extensions.filters.http.ext_authz.v3.ExtAuthzPerRoute",
"disabled": true
}
}
},
{
"name": "pomerium-path-/healthz",
"match": {
"path": "/healthz"
},
"route": {
"cluster": "pomerium-control-plane-http"
},
"typedPerFilterConfig": {
"envoy.filters.http.ext_authz": {
"@type": "type.googleapis.com/envoy.extensions.filters.http.ext_authz.v3.ExtAuthzPerRoute",
"disabled": true
}
}
},
{
"name": "pomerium-path-/.pomerium",
"match": {
"path": "/.pomerium"
},
"route": {
"cluster": "pomerium-control-plane-http"
},
"typedPerFilterConfig": {
"envoy.filters.http.ext_authz": {
"@type": "type.googleapis.com/envoy.extensions.filters.http.ext_authz.v3.ExtAuthzPerRoute",
"disabled": true
}
}
},
{
"name": "pomerium-prefix-/.pomerium/",
"match": {
"prefix": "/.pomerium/"
},
"route": {
"cluster": "pomerium-control-plane-http"
},
"typedPerFilterConfig": {
"envoy.filters.http.ext_authz": {
"@type": "type.googleapis.com/envoy.extensions.filters.http.ext_authz.v3.ExtAuthzPerRoute",
"disabled": true
}
}
},
{
"name": "pomerium-path-/.well-known/pomerium",
"match": {
"path": "/.well-known/pomerium"
},
"route": {
"cluster": "pomerium-control-plane-http"
},
"typedPerFilterConfig": {
"envoy.filters.http.ext_authz": {
"@type": "type.googleapis.com/envoy.extensions.filters.http.ext_authz.v3.ExtAuthzPerRoute",
"disabled": true
}
}
},
{
"name": "pomerium-prefix-/.well-known/pomerium/",
"match": {
"prefix": "/.well-known/pomerium/"
},
"route": {
"cluster": "pomerium-control-plane-http"
},
"typedPerFilterConfig": {
"envoy.filters.http.ext_authz": {
"@type": "type.googleapis.com/envoy.extensions.filters.http.ext_authz.v3.ExtAuthzPerRoute",
"disabled": true
}
}
},
{
"name": "pomerium-path-/robots.txt",
"match": {
"path": "/robots.txt"
},
"route": {
"cluster": "pomerium-control-plane-http"
},
"typedPerFilterConfig": {
"envoy.filters.http.ext_authz": {
"@type": "type.googleapis.com/envoy.extensions.filters.http.ext_authz.v3.ExtAuthzPerRoute",
"disabled": true
}
}
},
{
"name": "pomerium-path-/oauth2/callback",
"match": {
"path": "/oauth2/callback"
},
"route": {
"cluster": "pomerium-control-plane-http"
},
"typedPerFilterConfig": {
"envoy.filters.http.ext_authz": {
"@type": "type.googleapis.com/envoy.extensions.filters.http.ext_authz.v3.ExtAuthzPerRoute",
"disabled": true
}
}
},
{
"name": "pomerium-path-/",
"match": {
"path": "/"
},
"route": {
"cluster": "pomerium-control-plane-http"
},
"typedPerFilterConfig": {
"envoy.filters.http.ext_authz": {
"@type": "type.googleapis.com/envoy.extensions.filters.http.ext_authz.v3.ExtAuthzPerRoute",
"disabled": true
}
}
} }
] ]
}, },
@ -779,7 +990,7 @@ func Test_getAllDomains(t *testing.T) {
} }
t.Run("routable", func(t *testing.T) { t.Run("routable", func(t *testing.T) {
t.Run("http", func(t *testing.T) { t.Run("http", func(t *testing.T) {
actual, err := getAllRouteableDomains(options, "127.0.0.1:9000") actual, err := getAllRouteableHosts(options, "127.0.0.1:9000")
require.NoError(t, err) require.NoError(t, err)
expect := []string{ expect := []string{
"a.example.com", "a.example.com",
@ -794,7 +1005,7 @@ func Test_getAllDomains(t *testing.T) {
assert.Equal(t, expect, actual) assert.Equal(t, expect, actual)
}) })
t.Run("grpc", func(t *testing.T) { t.Run("grpc", func(t *testing.T) {
actual, err := getAllRouteableDomains(options, "127.0.0.1:9001") actual, err := getAllRouteableHosts(options, "127.0.0.1:9001")
require.NoError(t, err) require.NoError(t, err)
expect := []string{ expect := []string{
"authorize.example.com:9001", "authorize.example.com:9001",
@ -805,7 +1016,7 @@ func Test_getAllDomains(t *testing.T) {
t.Run("both", func(t *testing.T) { t.Run("both", func(t *testing.T) {
newOptions := *options newOptions := *options
newOptions.GRPCAddr = newOptions.Addr newOptions.GRPCAddr = newOptions.Addr
actual, err := getAllRouteableDomains(&newOptions, "127.0.0.1:9000") actual, err := getAllRouteableHosts(&newOptions, "127.0.0.1:9000")
require.NoError(t, err) require.NoError(t, err)
expect := []string{ expect := []string{
"a.example.com", "a.example.com",
@ -824,9 +1035,10 @@ func Test_getAllDomains(t *testing.T) {
}) })
t.Run("tls", func(t *testing.T) { t.Run("tls", func(t *testing.T) {
t.Run("http", func(t *testing.T) { t.Run("http", func(t *testing.T) {
actual, err := getAllTLSDomains(&config.Config{Options: options}, "127.0.0.1:9000") actual, err := getAllServerNames(&config.Config{Options: options}, "127.0.0.1:9000")
require.NoError(t, err) require.NoError(t, err)
expect := []string{ expect := []string{
"*",
"*.unknown.example.com", "*.unknown.example.com",
"a.example.com", "a.example.com",
"authenticate.example.com", "authenticate.example.com",
@ -836,9 +1048,10 @@ func Test_getAllDomains(t *testing.T) {
assert.Equal(t, expect, actual) assert.Equal(t, expect, actual)
}) })
t.Run("grpc", func(t *testing.T) { t.Run("grpc", func(t *testing.T) {
actual, err := getAllTLSDomains(&config.Config{Options: options}, "127.0.0.1:9001") actual, err := getAllServerNames(&config.Config{Options: options}, "127.0.0.1:9001")
require.NoError(t, err) require.NoError(t, err)
expect := []string{ expect := []string{
"*",
"*.unknown.example.com", "*.unknown.example.com",
"authorize.example.com", "authorize.example.com",
"cache.example.com", "cache.example.com",
@ -848,14 +1061,31 @@ func Test_getAllDomains(t *testing.T) {
}) })
} }
func Test_hostMatchesDomain(t *testing.T) { func Test_urlMatchesHost(t *testing.T) {
assert.True(t, hostMatchesDomain(mustParseURL(t, "http://example.com"), "example.com")) t.Parallel()
assert.True(t, hostMatchesDomain(mustParseURL(t, "http://example.com"), "example.com:80"))
assert.True(t, hostMatchesDomain(mustParseURL(t, "https://example.com"), "example.com:443")) for _, tc := range []struct {
assert.True(t, hostMatchesDomain(mustParseURL(t, "https://example.com:443"), "example.com:443")) name string
assert.True(t, hostMatchesDomain(mustParseURL(t, "https://example.com:443"), "example.com")) sourceURL string
assert.False(t, hostMatchesDomain(mustParseURL(t, "http://example.com:81"), "example.com")) host string
assert.False(t, hostMatchesDomain(mustParseURL(t, "http://example.com:81"), "example.com:80")) matches bool
}{
{"no port", "http://example.com", "example.com", true},
{"host http port", "http://example.com", "example.com:80", true},
{"host https port", "https://example.com", "example.com:443", true},
{"with port", "https://example.com:443", "example.com:443", true},
{"url port", "https://example.com:443", "example.com", true},
{"non standard port", "http://example.com:81", "example.com", false},
{"non standard host port", "http://example.com:81", "example.com:80", false},
} {
tc := tc
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
assert.Equal(t, tc.matches, urlMatchesHost(mustParseURL(t, tc.sourceURL), tc.host),
"urlMatchesHost(%s,%s)", tc.sourceURL, tc.host)
})
}
} }
func Test_buildRouteConfiguration(t *testing.T) { func Test_buildRouteConfiguration(t *testing.T) {

View file

@ -47,12 +47,12 @@ func (b *Builder) buildGRPCRoutes() ([]*envoy_config_route_v3.Route, error) {
}}, nil }}, nil
} }
func (b *Builder) buildPomeriumHTTPRoutes(options *config.Options, domain string) ([]*envoy_config_route_v3.Route, error) { func (b *Builder) buildPomeriumHTTPRoutes(options *config.Options, host string) ([]*envoy_config_route_v3.Route, error) {
var routes []*envoy_config_route_v3.Route var routes []*envoy_config_route_v3.Route
// if this is the pomerium proxy in front of the the authenticate service, don't add // if this is the pomerium proxy in front of the the authenticate service, don't add
// these routes since they will be handled by authenticate // these routes since they will be handled by authenticate
isFrontingAuthenticate, err := isProxyFrontingAuthenticate(options, domain) isFrontingAuthenticate, err := isProxyFrontingAuthenticate(options, host)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -70,7 +70,7 @@ func (b *Builder) buildPomeriumHTTPRoutes(options *config.Options, domain string
b.buildControlPlanePrefixRoute("/.well-known/pomerium/", false), b.buildControlPlanePrefixRoute("/.well-known/pomerium/", false),
) )
// per #837, only add robots.txt if there are no unauthenticated routes // per #837, only add robots.txt if there are no unauthenticated routes
if !hasPublicPolicyMatchingURL(options, url.URL{Scheme: "https", Host: domain, Path: "/robots.txt"}) { if !hasPublicPolicyMatchingURL(options, url.URL{Scheme: "https", Host: host, Path: "/robots.txt"}) {
routes = append(routes, b.buildControlPlanePathRoute("/robots.txt", false)) routes = append(routes, b.buildControlPlanePathRoute("/robots.txt", false))
} }
} }
@ -79,7 +79,7 @@ func (b *Builder) buildPomeriumHTTPRoutes(options *config.Options, domain string
if err != nil { if err != nil {
return nil, err return nil, err
} }
if config.IsAuthenticate(options.Services) && hostMatchesDomain(authenticateURL, domain) { if config.IsAuthenticate(options.Services) && urlMatchesHost(authenticateURL, host) {
routes = append(routes, routes = append(routes,
b.buildControlPlanePathRoute(options.AuthenticateCallbackPath, false), b.buildControlPlanePathRoute(options.AuthenticateCallbackPath, false),
b.buildControlPlanePathRoute("/", false), b.buildControlPlanePathRoute("/", false),
@ -151,12 +151,12 @@ func getClusterStatsName(policy *config.Policy) string {
return "" return ""
} }
func (b *Builder) buildPolicyRoutes(options *config.Options, domain string) ([]*envoy_config_route_v3.Route, error) { func (b *Builder) buildPolicyRoutes(options *config.Options, host string) ([]*envoy_config_route_v3.Route, error) {
var routes []*envoy_config_route_v3.Route var routes []*envoy_config_route_v3.Route
for i, p := range options.GetAllPolicies() { for i, p := range options.GetAllPolicies() {
policy := p policy := p
if !hostMatchesDomain(policy.Source.URL, domain) { if !urlMatchesHost(policy.Source.URL, host) {
continue continue
} }
@ -188,7 +188,7 @@ func (b *Builder) buildPolicyRoutes(options *config.Options, domain string) ([]*
} }
// disable authentication entirely when the proxy is fronting authenticate // disable authentication entirely when the proxy is fronting authenticate
isFrontingAuthenticate, err := isProxyFrontingAuthenticate(options, domain) isFrontingAuthenticate, err := isProxyFrontingAuthenticate(options, host)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -497,13 +497,13 @@ func hasPublicPolicyMatchingURL(options *config.Options, requestURL url.URL) boo
return false return false
} }
func isProxyFrontingAuthenticate(options *config.Options, domain string) (bool, error) { func isProxyFrontingAuthenticate(options *config.Options, host string) (bool, error) {
authenticateURL, err := options.GetAuthenticateURL() authenticateURL, err := options.GetAuthenticateURL()
if err != nil { if err != nil {
return false, err return false, err
} }
if !config.IsAuthenticate(options.Services) && hostMatchesDomain(authenticateURL, domain) { if !config.IsAuthenticate(options.Services) && urlMatchesHost(authenticateURL, host) {
return true, nil return true, nil
} }

View file

@ -1015,15 +1015,9 @@ func (o *Options) GetCodecType() CodecType {
return o.CodecType return o.CodecType
} }
// GetAllRouteableGRPCDomains returns all the possible gRPC domains handled by the Pomerium options. // GetAllRouteableGRPCHosts returns all the possible gRPC hosts handled by the Pomerium options.
func (o *Options) GetAllRouteableGRPCDomains() ([]string, error) { func (o *Options) GetAllRouteableGRPCHosts() ([]string, error) {
return o.GetAllRouteableGRPCDomainsForTLSServerName("") hosts := sets.NewSorted[string]()
}
// GetAllRouteableGRPCDomainsForTLSServerName returns all the possible gRPC domains handled by the Pomerium options
// for the given TLS server name.
func (o *Options) GetAllRouteableGRPCDomainsForTLSServerName(tlsServerName string) ([]string, error) {
domains := sets.NewSorted[string]()
// authorize urls // authorize urls
if IsAll(o.Services) { if IsAll(o.Services) {
@ -1032,11 +1026,7 @@ func (o *Options) GetAllRouteableGRPCDomainsForTLSServerName(tlsServerName strin
return nil, err return nil, err
} }
for _, u := range authorizeURLs { for _, u := range authorizeURLs {
for _, h := range urlutil.GetDomainsForURL(*u) { hosts.Add(urlutil.GetDomainsForURL(*u)...)
if tlsServerName == "" || urlutil.StripPort(h) == tlsServerName {
domains.Add(h)
}
}
} }
} else if IsAuthorize(o.Services) { } else if IsAuthorize(o.Services) {
authorizeURLs, err := o.GetInternalAuthorizeURLs() authorizeURLs, err := o.GetInternalAuthorizeURLs()
@ -1044,11 +1034,7 @@ func (o *Options) GetAllRouteableGRPCDomainsForTLSServerName(tlsServerName strin
return nil, err return nil, err
} }
for _, u := range authorizeURLs { for _, u := range authorizeURLs {
for _, h := range urlutil.GetDomainsForURL(*u) { hosts.Add(urlutil.GetDomainsForURL(*u)...)
if tlsServerName == "" || urlutil.StripPort(h) == tlsServerName {
domains.Add(h)
}
}
} }
} }
@ -1059,11 +1045,7 @@ func (o *Options) GetAllRouteableGRPCDomainsForTLSServerName(tlsServerName strin
return nil, err return nil, err
} }
for _, u := range dataBrokerURLs { for _, u := range dataBrokerURLs {
for _, h := range urlutil.GetDomainsForURL(*u) { hosts.Add(urlutil.GetDomainsForURL(*u)...)
if tlsServerName == "" || urlutil.StripPort(h) == tlsServerName {
domains.Add(h)
}
}
} }
} else if IsDataBroker(o.Services) { } else if IsDataBroker(o.Services) {
dataBrokerURLs, err := o.GetInternalDataBrokerURLs() dataBrokerURLs, err := o.GetInternalDataBrokerURLs()
@ -1071,71 +1053,42 @@ func (o *Options) GetAllRouteableGRPCDomainsForTLSServerName(tlsServerName strin
return nil, err return nil, err
} }
for _, u := range dataBrokerURLs { for _, u := range dataBrokerURLs {
for _, h := range urlutil.GetDomainsForURL(*u) { hosts.Add(urlutil.GetDomainsForURL(*u)...)
if tlsServerName == "" || urlutil.StripPort(h) == tlsServerName {
domains.Add(h)
}
}
} }
} }
return domains.ToSlice(), nil return hosts.ToSlice(), nil
} }
// GetAllRouteableHTTPDomains returns all the possible HTTP domains handled by the Pomerium options. // GetAllRouteableHTTPHosts returns all the possible HTTP hosts handled by the Pomerium options.
func (o *Options) GetAllRouteableHTTPDomains() ([]string, error) { func (o *Options) GetAllRouteableHTTPHosts() ([]string, error) {
return o.GetAllRouteableHTTPDomainsForTLSServerName("") hosts := sets.NewSorted[string]()
}
// GetAllRouteableHTTPDomainsForTLSServerName returns all the possible HTTP domains handled by the Pomerium options
// for the given TLS server name.
func (o *Options) GetAllRouteableHTTPDomainsForTLSServerName(tlsServerName string) ([]string, error) {
domains := sets.NewSorted[string]()
if IsAuthenticate(o.Services) { if IsAuthenticate(o.Services) {
authenticateURL, err := o.GetInternalAuthenticateURL() authenticateURL, err := o.GetInternalAuthenticateURL()
if err != nil { if err != nil {
return nil, err return nil, err
} }
for _, h := range urlutil.GetDomainsForURL(*authenticateURL) { hosts.Add(urlutil.GetDomainsForURL(*authenticateURL)...)
if tlsServerName == "" || urlutil.StripPort(h) == tlsServerName {
domains.Add(h)
}
}
authenticateURL, err = o.GetAuthenticateURL() authenticateURL, err = o.GetAuthenticateURL()
if err != nil { if err != nil {
return nil, err return nil, err
} }
for _, h := range urlutil.GetDomainsForURL(*authenticateURL) { hosts.Add(urlutil.GetDomainsForURL(*authenticateURL)...)
if tlsServerName == "" || urlutil.StripPort(h) == tlsServerName {
domains.Add(h)
}
}
} }
// policy urls // policy urls
if IsProxy(o.Services) { if IsProxy(o.Services) {
for _, policy := range o.GetAllPolicies() { for _, policy := range o.GetAllPolicies() {
for _, h := range urlutil.GetDomainsForURL(*policy.Source.URL) { hosts.Add(urlutil.GetDomainsForURL(*policy.Source.URL)...)
if tlsServerName == "" ||
policy.TLSDownstreamServerName == tlsServerName ||
urlutil.StripPort(h) == tlsServerName {
domains.Add(h)
}
}
if policy.TLSDownstreamServerName != "" { if policy.TLSDownstreamServerName != "" {
tlsURL := policy.Source.URL.ResolveReference(&url.URL{Host: policy.TLSDownstreamServerName}) tlsURL := policy.Source.URL.ResolveReference(&url.URL{Host: policy.TLSDownstreamServerName})
for _, h := range urlutil.GetDomainsForURL(*tlsURL) { hosts.Add(urlutil.GetDomainsForURL(*tlsURL)...)
if tlsServerName == "" ||
urlutil.StripPort(h) == tlsServerName {
domains.Add(h)
}
}
} }
} }
} }
return domains.ToSlice(), nil return hosts.ToSlice(), nil
} }
// GetClientSecret gets the client secret. // GetClientSecret gets the client secret.

View file

@ -666,14 +666,14 @@ func TestOptions_GetOauthOptions(t *testing.T) {
assert.Equal(t, u.Hostname(), oauthOptions.RedirectURL.Hostname()) assert.Equal(t, u.Hostname(), oauthOptions.RedirectURL.Hostname())
} }
func TestOptions_GetAllRouteableGRPCDomains(t *testing.T) { func TestOptions_GetAllRouteableGRPCHosts(t *testing.T) {
opts := &Options{ opts := &Options{
AuthenticateURLString: "https://authenticate.example.com", AuthenticateURLString: "https://authenticate.example.com",
AuthorizeURLString: "https://authorize.example.com", AuthorizeURLString: "https://authorize.example.com",
DataBrokerURLString: "https://databroker.example.com", DataBrokerURLString: "https://databroker.example.com",
Services: "all", Services: "all",
} }
domains, err := opts.GetAllRouteableGRPCDomains() hosts, err := opts.GetAllRouteableGRPCHosts()
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, []string{ assert.Equal(t, []string{
@ -681,10 +681,10 @@ func TestOptions_GetAllRouteableGRPCDomains(t *testing.T) {
"authorize.example.com:443", "authorize.example.com:443",
"databroker.example.com", "databroker.example.com",
"databroker.example.com:443", "databroker.example.com:443",
}, domains) }, hosts)
} }
func TestOptions_GetAllRouteableHTTPDomains(t *testing.T) { func TestOptions_GetAllRouteableHTTPHosts(t *testing.T) {
p1 := Policy{From: "https://from1.example.com"} p1 := Policy{From: "https://from1.example.com"}
p1.Validate() p1.Validate()
p2 := Policy{From: "https://from2.example.com"} p2 := Policy{From: "https://from2.example.com"}
@ -699,7 +699,7 @@ func TestOptions_GetAllRouteableHTTPDomains(t *testing.T) {
Policies: []Policy{p1, p2, p3}, Policies: []Policy{p1, p2, p3},
Services: "all", Services: "all",
} }
domains, err := opts.GetAllRouteableHTTPDomains() hosts, err := opts.GetAllRouteableHTTPHosts()
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, []string{ assert.Equal(t, []string{
@ -713,7 +713,7 @@ func TestOptions_GetAllRouteableHTTPDomains(t *testing.T) {
"from2.example.com:443", "from2.example.com:443",
"from3.example.com", "from3.example.com",
"from3.example.com:443", "from3.example.com:443",
}, domains) }, hosts)
} }
func TestOptions_ApplySettings(t *testing.T) { func TestOptions_ApplySettings(t *testing.T) {

View file

@ -8,6 +8,8 @@ import (
"net/url" "net/url"
"strings" "strings"
"time" "time"
"github.com/caddyserver/certmagic"
) )
const ( const (
@ -160,3 +162,8 @@ func GetExternalRequest(internalURL, externalURL *url.URL, r *http.Request) *htt
} }
return er return er
} }
// MatchesServerName returnes true if the url's host matches the given server name.
func MatchesServerName(u url.URL, serverName string) bool {
return certmagic.MatchWildcard(u.Hostname(), serverName)
}

View file

@ -166,3 +166,9 @@ func TestJoin(t *testing.T) {
assert.Equal(t, "/x/y/z/", Join("/x", "/y/z/")) assert.Equal(t, "/x/y/z/", Join("/x", "/y/z/"))
assert.Equal(t, "/x/y/z/", Join("/x/", "/y/z/")) assert.Equal(t, "/x/y/z/", Join("/x/", "/y/z/"))
} }
func TestMatchesServerName(t *testing.T) {
t.Run("wildcard", func(t *testing.T) {
assert.True(t, MatchesServerName(MustParseAndValidateURL("https://domain.example.com"), "*.example.com"))
})
}

View file

@ -44,36 +44,36 @@ func GetCertPool(ca, caFile string) (*x509.CertPool, error) {
return rootCAs, nil return rootCAs, nil
} }
// GetCertificateForDomain returns the tls Certificate which matches the given domain name. // GetCertificateForServerName returns the tls Certificate which matches the given server name.
// It should handle both exact matches and wildcard matches. If none of those match, the first certificate will be used. // It should handle both exact matches and wildcard matches. If none of those match, the first certificate will be used.
// Finally if there are no matching certificates one will be generated. // Finally if there are no matching certificates one will be generated.
func GetCertificateForDomain(certificates []tls.Certificate, domain string) (*tls.Certificate, error) { func GetCertificateForServerName(certificates []tls.Certificate, serverName string) (*tls.Certificate, error) {
// first try a direct name match // first try a direct name match
for i := range certificates { for i := range certificates {
if matchesDomain(&certificates[i], domain) { if matchesServerName(&certificates[i], serverName) {
return &certificates[i], nil return &certificates[i], nil
} }
} }
log.WarnNoTLSCertificate(domain) log.WarnNoTLSCertificate(serverName)
// finally fall back to a generated, self-signed certificate // finally fall back to a generated, self-signed certificate
return GenerateSelfSignedCertificate(domain) return GenerateSelfSignedCertificate(serverName)
} }
// HasCertificateForDomain returns true if a TLS certificate matches the given domain. // HasCertificateForServerName returns true if a TLS certificate matches the given server name.
func HasCertificateForDomain(certificates []tls.Certificate, domain string) bool { func HasCertificateForServerName(certificates []tls.Certificate, serverName string) bool {
for i := range certificates { for i := range certificates {
if matchesDomain(&certificates[i], domain) { if matchesServerName(&certificates[i], serverName) {
return true return true
} }
} }
return false return false
} }
// GetCertificateDomains gets all the certificate's matching domain names. // GetCertificateServerNames gets all the certificate's server names.
// Will return an empty slice if certificate is nil, empty, or x509 parsing fails. // Will return an empty slice if certificate is nil, empty, or x509 parsing fails.
func GetCertificateDomains(cert *tls.Certificate) []string { func GetCertificateServerNames(cert *tls.Certificate) []string {
if cert == nil || len(cert.Certificate) == 0 { if cert == nil || len(cert.Certificate) == 0 {
return nil return nil
} }
@ -83,19 +83,19 @@ func GetCertificateDomains(cert *tls.Certificate) []string {
return nil return nil
} }
var domains []string var serverNames []string
if xcert.Subject.CommonName != "" { if xcert.Subject.CommonName != "" {
domains = append(domains, xcert.Subject.CommonName) serverNames = append(serverNames, xcert.Subject.CommonName)
} }
for _, dnsName := range xcert.DNSNames { for _, dnsName := range xcert.DNSNames {
if dnsName != "" { if dnsName != "" {
domains = append(domains, dnsName) serverNames = append(serverNames, dnsName)
} }
} }
return domains return serverNames
} }
func matchesDomain(cert *tls.Certificate, domain string) bool { func matchesServerName(cert *tls.Certificate, serverName string) bool {
if cert == nil || len(cert.Certificate) == 0 { if cert == nil || len(cert.Certificate) == 0 {
return false return false
} }
@ -105,12 +105,12 @@ func matchesDomain(cert *tls.Certificate, domain string) bool {
return false return false
} }
if certmagic.MatchWildcard(domain, xcert.Subject.CommonName) { if certmagic.MatchWildcard(serverName, xcert.Subject.CommonName) {
return true return true
} }
for _, san := range xcert.DNSNames { for _, san := range xcert.DNSNames {
if certmagic.MatchWildcard(domain, san) { if certmagic.MatchWildcard(serverName, san) {
return true return true
} }
} }

View file

@ -8,10 +8,10 @@ import (
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
func TestGetCertificateForDomain(t *testing.T) { func TestGetCertificateForServerName(t *testing.T) {
gen := func(t *testing.T, domain string) *tls.Certificate { gen := func(t *testing.T, serverName string) *tls.Certificate {
cert, err := GenerateSelfSignedCertificate(domain) cert, err := GenerateSelfSignedCertificate(serverName)
if !assert.NoError(t, err, "error generating certificate for: %s", domain) { if !assert.NoError(t, err, "error generating certificate for: %s", serverName) {
t.FailNow() t.FailNow()
} }
return cert return cert
@ -23,7 +23,7 @@ func TestGetCertificateForDomain(t *testing.T) {
*gen(t, "b.example.com"), *gen(t, "b.example.com"),
} }
found, err := GetCertificateForDomain(certs, "b.example.com") found, err := GetCertificateForServerName(certs, "b.example.com")
if !assert.NoError(t, err) { if !assert.NoError(t, err) {
return return
} }
@ -35,7 +35,7 @@ func TestGetCertificateForDomain(t *testing.T) {
*gen(t, "*.example.com"), *gen(t, "*.example.com"),
} }
found, err := GetCertificateForDomain(certs, "b.example.com") found, err := GetCertificateForServerName(certs, "b.example.com")
if !assert.NoError(t, err) { if !assert.NoError(t, err) {
return return
} }
@ -46,7 +46,7 @@ func TestGetCertificateForDomain(t *testing.T) {
*gen(t, "a.example.com"), *gen(t, "a.example.com"),
} }
found, err := GetCertificateForDomain(certs, "b.example.com") found, err := GetCertificateForServerName(certs, "b.example.com")
if !assert.NoError(t, err) { if !assert.NoError(t, err) {
return return
} }
@ -56,7 +56,7 @@ func TestGetCertificateForDomain(t *testing.T) {
t.Run("generate", func(t *testing.T) { t.Run("generate", func(t *testing.T) {
certs := []tls.Certificate{} certs := []tls.Certificate{}
found, err := GetCertificateForDomain(certs, "b.example.com") found, err := GetCertificateForServerName(certs, "b.example.com")
if !assert.NoError(t, err) { if !assert.NoError(t, err) {
return return
} }
@ -64,8 +64,8 @@ func TestGetCertificateForDomain(t *testing.T) {
}) })
} }
func TestGetCertificateDomains(t *testing.T) { func TestGetCertificateServerNames(t *testing.T) {
cert, err := GenerateSelfSignedCertificate("www.example.com") cert, err := GenerateSelfSignedCertificate("www.example.com")
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, []string{"www.example.com"}, GetCertificateDomains(cert)) assert.Equal(t, []string{"www.example.com"}, GetCertificateServerNames(cert))
} }