diff --git a/config/envoyconfig/http_connection_manager.go b/config/envoyconfig/http_connection_manager.go index bbb4bffac..0fac5d3e5 100644 --- a/config/envoyconfig/http_connection_manager.go +++ b/config/envoyconfig/http_connection_manager.go @@ -12,16 +12,16 @@ import ( func (b *Builder) buildVirtualHost( options *config.Options, name string, - domain string, + host string, requireStrictTransportSecurity bool, ) (*envoy_config_route_v3.VirtualHost, error) { vh := &envoy_config_route_v3.VirtualHost{ Name: name, - Domains: []string{domain}, + Domains: []string{host}, } // these routes match /.pomerium/... and similar paths - rs, err := b.buildPomeriumHTTPRoutes(options, domain) + rs, err := b.buildPomeriumHTTPRoutes(options, host) if err != nil { return nil, err } diff --git a/config/envoyconfig/listeners.go b/config/envoyconfig/listeners.go index 567f45583..88d0d76b8 100644 --- a/config/envoyconfig/listeners.go +++ b/config/envoyconfig/listeners.go @@ -25,6 +25,7 @@ import ( "github.com/pomerium/pomerium/internal/log" "github.com/pomerium/pomerium/internal/sets" "github.com/pomerium/pomerium/internal/telemetry/metrics" + "github.com/pomerium/pomerium/internal/urlutil" "github.com/pomerium/pomerium/pkg/cryptutil" ) @@ -99,51 +100,50 @@ func (b *Builder) BuildListeners(ctx context.Context, cfg *config.Config) ([]*en } func (b *Builder) buildMainListener(ctx context.Context, cfg *config.Config) (*envoy_config_listener_v3.Listener, error) { - listenerFilters := []*envoy_config_listener_v3.ListenerFilter{} + li := newEnvoyListener("http-ingress") if cfg.Options.UseProxyProtocol { - listenerFilters = append(listenerFilters, ProxyProtocolFilter()) + li.ListenerFilters = append(li.ListenerFilters, ProxyProtocolFilter()) } if cfg.Options.InsecureServer { - allDomains, err := getAllRouteableDomains(cfg.Options, cfg.Options.Addr) - if err != nil { - return nil, err - } - - filter, err := b.buildMainHTTPConnectionManagerFilter(cfg.Options, allDomains, false) - if err != nil { - return nil, err - } - - li := newEnvoyListener("http-ingress") li.Address = buildAddress(cfg.Options.Addr, 80) - li.ListenerFilters = listenerFilters + + filter, err := b.buildMainHTTPConnectionManagerFilter(cfg.Options, false) + if err != nil { + return nil, err + } + li.FilterChains = []*envoy_config_listener_v3.FilterChain{{ Filters: []*envoy_config_listener_v3.Filter{ filter, }, }} - return li, nil - } - listenerFilters = append(listenerFilters, TLSInspectorFilter()) + } else { + li.Address = buildAddress(cfg.Options.Addr, 443) + li.ListenerFilters = append(li.ListenerFilters, TLSInspectorFilter()) - chains, err := b.buildFilterChains(cfg, cfg.Options.Addr, - func(tlsDomain string, httpDomains []string) (*envoy_config_listener_v3.FilterChain, error) { - allCertificates, _ := cfg.AllCertificates() - requireStrictTransportSecurity := cryptutil.HasCertificateForDomain(allCertificates, tlsDomain) - filter, err := b.buildMainHTTPConnectionManagerFilter(cfg.Options, httpDomains, requireStrictTransportSecurity) + allCertificates, _ := cfg.AllCertificates() + + serverNames, err := getAllServerNames(cfg, cfg.Options.Addr) + if err != nil { + return nil, err + } + + for _, serverName := range serverNames { + requireStrictTransportSecurity := cryptutil.HasCertificateForServerName(allCertificates, serverName) + filter, err := b.buildMainHTTPConnectionManagerFilter(cfg.Options, requireStrictTransportSecurity) if err != nil { return nil, err } filterChain := &envoy_config_listener_v3.FilterChain{ Filters: []*envoy_config_listener_v3.Filter{filter}, } - if tlsDomain != "*" { + if serverName != "*" { filterChain.FilterChainMatch = &envoy_config_listener_v3.FilterChainMatch{ - ServerNames: []string{tlsDomain}, + ServerNames: []string{serverName}, } } - tlsContext := b.buildDownstreamTLSContext(ctx, cfg, tlsDomain) + tlsContext := b.buildDownstreamTLSContext(ctx, cfg, serverName) if tlsContext != nil { tlsConfig := marshalAny(tlsContext) filterChain.TransportSocket = &envoy_config_core_v3.TransportSocket{ @@ -153,16 +153,9 @@ func (b *Builder) buildMainListener(ctx context.Context, cfg *config.Config) (*e }, } } - return filterChain, nil - }) - if err != nil { - return nil, err + li.FilterChains = append(li.FilterChains, filterChain) + } } - - li := newEnvoyListener("https-ingress") - li.Address = buildAddress(cfg.Options.Addr, 443) - li.ListenerFilters = listenerFilters - li.FilterChains = chains return li, nil } @@ -245,42 +238,8 @@ func (b *Builder) buildMetricsListener(cfg *config.Config) (*envoy_config_listen return li, nil } -func (b *Builder) buildFilterChains( - cfg *config.Config, addr string, - callback func(tlsDomain string, httpDomains []string) (*envoy_config_listener_v3.FilterChain, error), -) ([]*envoy_config_listener_v3.FilterChain, error) { - allDomains, err := getAllRouteableDomains(cfg.Options, addr) - if err != nil { - return nil, err - } - - tlsDomains, err := getAllTLSDomains(cfg, addr) - if err != nil { - return nil, err - } - - var chains []*envoy_config_listener_v3.FilterChain - chains = append(chains, b.buildACMETLSALPNFilterChain()) - for _, domain := range tlsDomains { - chain, err := callback(domain, allDomains) - if err != nil { - return nil, err - } - chains = append(chains, chain) - } - - // if there are no SNI matches we match on HTTP host - chain, err := callback("*", allDomains) - if err != nil { - return nil, err - } - chains = append(chains, chain) - return chains, nil -} - func (b *Builder) buildMainHTTPConnectionManagerFilter( options *config.Options, - domains []string, requireStrictTransportSecurity bool, ) (*envoy_config_listener_v3.Filter, error) { authorizeURLs, err := options.GetInternalAuthorizeURLs() @@ -293,17 +252,22 @@ func (b *Builder) buildMainHTTPConnectionManagerFilter( return nil, err } + allHosts, err := getAllRouteableHosts(options, options.Addr) + if err != nil { + return nil, err + } + var virtualHosts []*envoy_config_route_v3.VirtualHost - for _, domain := range domains { - vh, err := b.buildVirtualHost(options, domain, domain, requireStrictTransportSecurity) + for _, host := range allHosts { + vh, err := b.buildVirtualHost(options, host, host, requireStrictTransportSecurity) if err != nil { return nil, err } if options.Addr == options.GetGRPCAddr() { // if this is a gRPC service domain and we're supposed to handle that, add those routes - if (config.IsAuthorize(options.Services) && hostsMatchDomain(authorizeURLs, domain)) || - (config.IsDataBroker(options.Services) && hostsMatchDomain(dataBrokerURLs, domain)) { + if (config.IsAuthorize(options.Services) && urlsMatchHost(authorizeURLs, host)) || + (config.IsDataBroker(options.Services) && urlsMatchHost(dataBrokerURLs, host)) { rs, err := b.buildGRPCRoutes() if err != nil { return nil, err @@ -314,7 +278,7 @@ func (b *Builder) buildMainHTTPConnectionManagerFilter( // if we're the proxy, add all the policy routes if config.IsProxy(options.Services) { - rs, err := b.buildPolicyRoutes(options, domain) + rs, err := b.buildPolicyRoutes(options, host) if err != nil { return nil, err } @@ -445,28 +409,35 @@ func (b *Builder) buildGRPCListener(ctx context.Context, cfg *config.Config) (*e return nil, err } + li := newEnvoyListener("grpc-ingress") if cfg.Options.GetGRPCInsecure() { - li := newEnvoyListener("grpc-ingress") li.Address = buildAddress(cfg.Options.GetGRPCAddr(), 80) li.FilterChains = []*envoy_config_listener_v3.FilterChain{{ Filters: []*envoy_config_listener_v3.Filter{ filter, }, }} - return li, nil - } + } else { + li.Address = buildAddress(cfg.Options.GetGRPCAddr(), 443) + li.ListenerFilters = []*envoy_config_listener_v3.ListenerFilter{ + TLSInspectorFilter(), + } - chains, err := b.buildFilterChains(cfg, cfg.Options.GRPCAddr, - func(tlsDomain string, httpDomains []string) (*envoy_config_listener_v3.FilterChain, error) { + serverNames, err := getAllServerNames(cfg, cfg.Options.GRPCAddr) + if err != nil { + return nil, err + } + + for _, serverName := range serverNames { filterChain := &envoy_config_listener_v3.FilterChain{ Filters: []*envoy_config_listener_v3.Filter{filter}, } - if tlsDomain != "*" { + if serverName != "*" { filterChain.FilterChainMatch = &envoy_config_listener_v3.FilterChainMatch{ - ServerNames: []string{tlsDomain}, + ServerNames: []string{serverName}, } } - tlsContext := b.buildDownstreamTLSContext(ctx, cfg, tlsDomain) + tlsContext := b.buildDownstreamTLSContext(ctx, cfg, serverName) if tlsContext != nil { tlsConfig := marshalAny(tlsContext) filterChain.TransportSocket = &envoy_config_core_v3.TransportSocket{ @@ -476,18 +447,9 @@ func (b *Builder) buildGRPCListener(ctx context.Context, cfg *config.Config) (*e }, } } - return filterChain, nil - }) - if err != nil { - return nil, err + li.FilterChains = append(li.FilterChains, filterChain) + } } - - li := newEnvoyListener("grpc-ingress") - li.Address = buildAddress(cfg.Options.GetGRPCAddr(), 443) - li.ListenerFilters = []*envoy_config_listener_v3.ListenerFilter{ - TLSInspectorFilter(), - } - li.FilterChains = chains return li, nil } @@ -548,23 +510,23 @@ func (b *Builder) buildRouteConfiguration(name string, virtualHosts []*envoy_con func (b *Builder) buildDownstreamTLSContext(ctx context.Context, cfg *config.Config, - domain string, + serverName string, ) *envoy_extensions_transport_sockets_tls_v3.DownstreamTlsContext { certs, err := cfg.AllCertificates() if err != nil { - log.Warn(ctx).Str("domain", domain).Err(err).Msg("failed to get all certificates from config") + log.Warn(ctx).Str("domain", serverName).Err(err).Msg("failed to get all certificates from config") return nil } - cert, err := cryptutil.GetCertificateForDomain(certs, domain) + cert, err := cryptutil.GetCertificateForServerName(certs, serverName) if err != nil { - log.Warn(ctx).Str("domain", domain).Err(err).Msg("failed to get certificate for domain") + log.Warn(ctx).Str("domain", serverName).Err(err).Msg("failed to get certificate for domain") return nil } err = validateCertificate(cert) if err != nil { - log.Warn(ctx).Str("domain", domain).Err(err).Msg("invalid certificate for domain") + log.Warn(ctx).Str("domain", serverName).Err(err).Msg("invalid certificate for domain") return nil } @@ -584,14 +546,14 @@ func (b *Builder) buildDownstreamTLSContext(ctx context.Context, TlsParams: tlsParams, TlsCertificates: []*envoy_extensions_transport_sockets_tls_v3.TlsCertificate{envoyCert}, AlpnProtocols: alpnProtocols, - ValidationContextType: b.buildDownstreamValidationContext(ctx, cfg, domain), + ValidationContextType: b.buildDownstreamValidationContext(ctx, cfg, serverName), }, } } func (b *Builder) buildDownstreamValidationContext(ctx context.Context, cfg *config.Config, - domain string, + serverName string, ) *envoy_extensions_transport_sockets_tls_v3.CommonTlsContext_ValidationContext { needsClientCert := false @@ -599,7 +561,7 @@ func (b *Builder) buildDownstreamValidationContext(ctx context.Context, needsClientCert = true } if !needsClientCert { - for _, p := range getPoliciesForDomain(cfg.Options, domain) { + for _, p := range getPoliciesForServerName(cfg.Options, serverName) { if p.TLSDownstreamClientCA != "" { needsClientCert = true break @@ -632,40 +594,41 @@ func (b *Builder) buildDownstreamValidationContext(ctx context.Context, return vc } -func getAllRouteableDomains(options *config.Options, addr string) ([]string, error) { - allDomains := sets.NewSorted[string]() +func getAllRouteableHosts(options *config.Options, addr string) ([]string, error) { + allHosts := sets.NewSorted[string]() if addr == options.Addr { - domains, err := options.GetAllRouteableHTTPDomains() + hosts, err := options.GetAllRouteableHTTPHosts() if err != nil { return nil, err } - allDomains.Add(domains...) + allHosts.Add(hosts...) } if addr == options.GetGRPCAddr() { - domains, err := options.GetAllRouteableGRPCDomains() + hosts, err := options.GetAllRouteableGRPCHosts() if err != nil { return nil, err } - allDomains.Add(domains...) + allHosts.Add(hosts...) } - return allDomains.ToSlice(), nil + return allHosts.ToSlice(), nil } -func getAllTLSDomains(cfg *config.Config, addr string) ([]string, error) { - domains := sets.NewSorted[string]() +func getAllServerNames(cfg *config.Config, addr string) ([]string, error) { + serverNames := sets.NewSorted[string]() + serverNames.Add("*") - routeableDomains, err := getAllRouteableDomains(cfg.Options, addr) + routeableHosts, err := getAllRouteableHosts(cfg.Options, addr) if err != nil { return nil, err } - for _, hp := range routeableDomains { - if d, _, err := net.SplitHostPort(hp); err == nil { - domains.Add(d) + for _, hp := range routeableHosts { + if h, _, err := net.SplitHostPort(hp); err == nil { + serverNames.Add(h) } else { - domains.Add(hp) + serverNames.Add(hp) } } @@ -674,24 +637,24 @@ func getAllTLSDomains(cfg *config.Config, addr string) ([]string, error) { return nil, err } for i := range certs { - for _, domain := range cryptutil.GetCertificateDomains(&certs[i]) { - domains.Add(domain) + for _, domain := range cryptutil.GetCertificateServerNames(&certs[i]) { + serverNames.Add(domain) } } - return domains.ToSlice(), nil + return serverNames.ToSlice(), nil } -func hostsMatchDomain(urls []*url.URL, host string) bool { +func urlsMatchHost(urls []*url.URL, host string) bool { for _, u := range urls { - if hostMatchesDomain(u, host) { + if urlMatchesHost(u, host) { return true } } return false } -func hostMatchesDomain(u *url.URL, host string) bool { +func urlMatchesHost(u *url.URL, host string) bool { if u == nil { return false } @@ -718,10 +681,10 @@ func hostMatchesDomain(u *url.URL, host string) bool { return h1 == h2 && p1 == p2 } -func getPoliciesForDomain(options *config.Options, domain string) []config.Policy { +func getPoliciesForServerName(options *config.Options, serverName string) []config.Policy { var policies []config.Policy for _, p := range options.GetAllPolicies() { - if p.Source != nil && p.Source.URL.Hostname() == domain { + if p.Source != nil && urlutil.MatchesServerName(*p.Source.URL, serverName) { policies = append(policies, p) } } diff --git a/config/envoyconfig/listeners_test.go b/config/envoyconfig/listeners_test.go index b49753c21..ab44e74e0 100644 --- a/config/envoyconfig/listeners_test.go +++ b/config/envoyconfig/listeners_test.go @@ -129,7 +129,8 @@ func Test_buildMainHTTPConnectionManagerFilter(t *testing.T) { options := config.NewDefaultOptions() options.SkipXffAppend = true options.XffNumTrustedHops = 1 - filter, err := b.buildMainHTTPConnectionManagerFilter(options, []string{"example.com"}, true) + options.AuthenticateURLString = "https://authenticate.example.com" + filter, err := b.buildMainHTTPConnectionManagerFilter(options, true) require.NoError(t, err) testutil.AssertProtoJSONEqual(t, `{ "name": "envoy.filters.network.http_connection_manager", @@ -220,8 +221,8 @@ func Test_buildMainHTTPConnectionManagerFilter(t *testing.T) { "name": "main", "virtualHosts": [ { - "name": "example.com", - "domains": ["example.com"], + "name": "authenticate.example.com", + "domains": ["authenticate.example.com"], "responseHeadersToAdd": [{ "appendAction": "OVERWRITE_IF_EXISTS_OR_ADD", "header": { @@ -366,6 +367,216 @@ func Test_buildMainHTTPConnectionManagerFilter(t *testing.T) { "disabled": true } } + }, + { + "name": "pomerium-path-/oauth2/callback", + "match": { + "path": "/oauth2/callback" + }, + "route": { + "cluster": "pomerium-control-plane-http" + }, + "typedPerFilterConfig": { + "envoy.filters.http.ext_authz": { + "@type": "type.googleapis.com/envoy.extensions.filters.http.ext_authz.v3.ExtAuthzPerRoute", + "disabled": true + } + } + }, + { + "name": "pomerium-path-/", + "match": { + "path": "/" + }, + "route": { + "cluster": "pomerium-control-plane-http" + }, + "typedPerFilterConfig": { + "envoy.filters.http.ext_authz": { + "@type": "type.googleapis.com/envoy.extensions.filters.http.ext_authz.v3.ExtAuthzPerRoute", + "disabled": true + } + } + } + ] + }, + { + "name": "authenticate.example.com:443", + "domains": ["authenticate.example.com:443"], + "responseHeadersToAdd": [{ + "appendAction": "OVERWRITE_IF_EXISTS_OR_ADD", + "header": { + "key": "Strict-Transport-Security", + "value": "max-age=31536000; includeSubDomains; preload" + } + }, + { + "appendAction": "OVERWRITE_IF_EXISTS_OR_ADD", + "header": { + "key": "X-Frame-Options", + "value": "SAMEORIGIN" + } + }, + { + "appendAction": "OVERWRITE_IF_EXISTS_OR_ADD", + "header": { + "key": "X-XSS-Protection", + "value": "1; mode=block" + } + }], + "routes": [ + { + "name": "pomerium-path-/.pomerium/jwt", + "match": { + "path": "/.pomerium/jwt" + }, + "route": { + "cluster": "pomerium-control-plane-http" + } + }, + { + "name": "pomerium-path-/.pomerium/webauthn", + "match": { + "path": "/.pomerium/webauthn" + }, + "route": { + "cluster": "pomerium-control-plane-http" + } + }, + { + "name": "pomerium-path-/ping", + "match": { + "path": "/ping" + }, + "route": { + "cluster": "pomerium-control-plane-http" + }, + "typedPerFilterConfig": { + "envoy.filters.http.ext_authz": { + "@type": "type.googleapis.com/envoy.extensions.filters.http.ext_authz.v3.ExtAuthzPerRoute", + "disabled": true + } + } + }, + { + "name": "pomerium-path-/healthz", + "match": { + "path": "/healthz" + }, + "route": { + "cluster": "pomerium-control-plane-http" + }, + "typedPerFilterConfig": { + "envoy.filters.http.ext_authz": { + "@type": "type.googleapis.com/envoy.extensions.filters.http.ext_authz.v3.ExtAuthzPerRoute", + "disabled": true + } + } + }, + { + "name": "pomerium-path-/.pomerium", + "match": { + "path": "/.pomerium" + }, + "route": { + "cluster": "pomerium-control-plane-http" + }, + "typedPerFilterConfig": { + "envoy.filters.http.ext_authz": { + "@type": "type.googleapis.com/envoy.extensions.filters.http.ext_authz.v3.ExtAuthzPerRoute", + "disabled": true + } + } + }, + { + "name": "pomerium-prefix-/.pomerium/", + "match": { + "prefix": "/.pomerium/" + }, + "route": { + "cluster": "pomerium-control-plane-http" + }, + "typedPerFilterConfig": { + "envoy.filters.http.ext_authz": { + "@type": "type.googleapis.com/envoy.extensions.filters.http.ext_authz.v3.ExtAuthzPerRoute", + "disabled": true + } + } + }, + { + "name": "pomerium-path-/.well-known/pomerium", + "match": { + "path": "/.well-known/pomerium" + }, + "route": { + "cluster": "pomerium-control-plane-http" + }, + "typedPerFilterConfig": { + "envoy.filters.http.ext_authz": { + "@type": "type.googleapis.com/envoy.extensions.filters.http.ext_authz.v3.ExtAuthzPerRoute", + "disabled": true + } + } + }, + { + "name": "pomerium-prefix-/.well-known/pomerium/", + "match": { + "prefix": "/.well-known/pomerium/" + }, + "route": { + "cluster": "pomerium-control-plane-http" + }, + "typedPerFilterConfig": { + "envoy.filters.http.ext_authz": { + "@type": "type.googleapis.com/envoy.extensions.filters.http.ext_authz.v3.ExtAuthzPerRoute", + "disabled": true + } + } + }, + { + "name": "pomerium-path-/robots.txt", + "match": { + "path": "/robots.txt" + }, + "route": { + "cluster": "pomerium-control-plane-http" + }, + "typedPerFilterConfig": { + "envoy.filters.http.ext_authz": { + "@type": "type.googleapis.com/envoy.extensions.filters.http.ext_authz.v3.ExtAuthzPerRoute", + "disabled": true + } + } + }, + { + "name": "pomerium-path-/oauth2/callback", + "match": { + "path": "/oauth2/callback" + }, + "route": { + "cluster": "pomerium-control-plane-http" + }, + "typedPerFilterConfig": { + "envoy.filters.http.ext_authz": { + "@type": "type.googleapis.com/envoy.extensions.filters.http.ext_authz.v3.ExtAuthzPerRoute", + "disabled": true + } + } + }, + { + "name": "pomerium-path-/", + "match": { + "path": "/" + }, + "route": { + "cluster": "pomerium-control-plane-http" + }, + "typedPerFilterConfig": { + "envoy.filters.http.ext_authz": { + "@type": "type.googleapis.com/envoy.extensions.filters.http.ext_authz.v3.ExtAuthzPerRoute", + "disabled": true + } + } } ] }, @@ -779,7 +990,7 @@ func Test_getAllDomains(t *testing.T) { } t.Run("routable", func(t *testing.T) { t.Run("http", func(t *testing.T) { - actual, err := getAllRouteableDomains(options, "127.0.0.1:9000") + actual, err := getAllRouteableHosts(options, "127.0.0.1:9000") require.NoError(t, err) expect := []string{ "a.example.com", @@ -794,7 +1005,7 @@ func Test_getAllDomains(t *testing.T) { assert.Equal(t, expect, actual) }) t.Run("grpc", func(t *testing.T) { - actual, err := getAllRouteableDomains(options, "127.0.0.1:9001") + actual, err := getAllRouteableHosts(options, "127.0.0.1:9001") require.NoError(t, err) expect := []string{ "authorize.example.com:9001", @@ -805,7 +1016,7 @@ func Test_getAllDomains(t *testing.T) { t.Run("both", func(t *testing.T) { newOptions := *options newOptions.GRPCAddr = newOptions.Addr - actual, err := getAllRouteableDomains(&newOptions, "127.0.0.1:9000") + actual, err := getAllRouteableHosts(&newOptions, "127.0.0.1:9000") require.NoError(t, err) expect := []string{ "a.example.com", @@ -824,9 +1035,10 @@ func Test_getAllDomains(t *testing.T) { }) t.Run("tls", func(t *testing.T) { t.Run("http", func(t *testing.T) { - actual, err := getAllTLSDomains(&config.Config{Options: options}, "127.0.0.1:9000") + actual, err := getAllServerNames(&config.Config{Options: options}, "127.0.0.1:9000") require.NoError(t, err) expect := []string{ + "*", "*.unknown.example.com", "a.example.com", "authenticate.example.com", @@ -836,9 +1048,10 @@ func Test_getAllDomains(t *testing.T) { assert.Equal(t, expect, actual) }) t.Run("grpc", func(t *testing.T) { - actual, err := getAllTLSDomains(&config.Config{Options: options}, "127.0.0.1:9001") + actual, err := getAllServerNames(&config.Config{Options: options}, "127.0.0.1:9001") require.NoError(t, err) expect := []string{ + "*", "*.unknown.example.com", "authorize.example.com", "cache.example.com", @@ -848,14 +1061,31 @@ func Test_getAllDomains(t *testing.T) { }) } -func Test_hostMatchesDomain(t *testing.T) { - assert.True(t, hostMatchesDomain(mustParseURL(t, "http://example.com"), "example.com")) - assert.True(t, hostMatchesDomain(mustParseURL(t, "http://example.com"), "example.com:80")) - assert.True(t, hostMatchesDomain(mustParseURL(t, "https://example.com"), "example.com:443")) - assert.True(t, hostMatchesDomain(mustParseURL(t, "https://example.com:443"), "example.com:443")) - assert.True(t, hostMatchesDomain(mustParseURL(t, "https://example.com:443"), "example.com")) - assert.False(t, hostMatchesDomain(mustParseURL(t, "http://example.com:81"), "example.com")) - assert.False(t, hostMatchesDomain(mustParseURL(t, "http://example.com:81"), "example.com:80")) +func Test_urlMatchesHost(t *testing.T) { + t.Parallel() + + for _, tc := range []struct { + name string + sourceURL string + host string + matches bool + }{ + {"no port", "http://example.com", "example.com", true}, + {"host http port", "http://example.com", "example.com:80", true}, + {"host https port", "https://example.com", "example.com:443", true}, + {"with port", "https://example.com:443", "example.com:443", true}, + {"url port", "https://example.com:443", "example.com", true}, + {"non standard port", "http://example.com:81", "example.com", false}, + {"non standard host port", "http://example.com:81", "example.com:80", false}, + } { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + assert.Equal(t, tc.matches, urlMatchesHost(mustParseURL(t, tc.sourceURL), tc.host), + "urlMatchesHost(%s,%s)", tc.sourceURL, tc.host) + }) + } } func Test_buildRouteConfiguration(t *testing.T) { diff --git a/config/envoyconfig/routes.go b/config/envoyconfig/routes.go index e837bdc74..96758cf2c 100644 --- a/config/envoyconfig/routes.go +++ b/config/envoyconfig/routes.go @@ -47,12 +47,12 @@ func (b *Builder) buildGRPCRoutes() ([]*envoy_config_route_v3.Route, error) { }}, nil } -func (b *Builder) buildPomeriumHTTPRoutes(options *config.Options, domain string) ([]*envoy_config_route_v3.Route, error) { +func (b *Builder) buildPomeriumHTTPRoutes(options *config.Options, host string) ([]*envoy_config_route_v3.Route, error) { var routes []*envoy_config_route_v3.Route // if this is the pomerium proxy in front of the the authenticate service, don't add // these routes since they will be handled by authenticate - isFrontingAuthenticate, err := isProxyFrontingAuthenticate(options, domain) + isFrontingAuthenticate, err := isProxyFrontingAuthenticate(options, host) if err != nil { return nil, err } @@ -70,7 +70,7 @@ func (b *Builder) buildPomeriumHTTPRoutes(options *config.Options, domain string b.buildControlPlanePrefixRoute("/.well-known/pomerium/", false), ) // per #837, only add robots.txt if there are no unauthenticated routes - if !hasPublicPolicyMatchingURL(options, url.URL{Scheme: "https", Host: domain, Path: "/robots.txt"}) { + if !hasPublicPolicyMatchingURL(options, url.URL{Scheme: "https", Host: host, Path: "/robots.txt"}) { routes = append(routes, b.buildControlPlanePathRoute("/robots.txt", false)) } } @@ -79,7 +79,7 @@ func (b *Builder) buildPomeriumHTTPRoutes(options *config.Options, domain string if err != nil { return nil, err } - if config.IsAuthenticate(options.Services) && hostMatchesDomain(authenticateURL, domain) { + if config.IsAuthenticate(options.Services) && urlMatchesHost(authenticateURL, host) { routes = append(routes, b.buildControlPlanePathRoute(options.AuthenticateCallbackPath, false), b.buildControlPlanePathRoute("/", false), @@ -151,12 +151,12 @@ func getClusterStatsName(policy *config.Policy) string { return "" } -func (b *Builder) buildPolicyRoutes(options *config.Options, domain string) ([]*envoy_config_route_v3.Route, error) { +func (b *Builder) buildPolicyRoutes(options *config.Options, host string) ([]*envoy_config_route_v3.Route, error) { var routes []*envoy_config_route_v3.Route for i, p := range options.GetAllPolicies() { policy := p - if !hostMatchesDomain(policy.Source.URL, domain) { + if !urlMatchesHost(policy.Source.URL, host) { continue } @@ -188,7 +188,7 @@ func (b *Builder) buildPolicyRoutes(options *config.Options, domain string) ([]* } // disable authentication entirely when the proxy is fronting authenticate - isFrontingAuthenticate, err := isProxyFrontingAuthenticate(options, domain) + isFrontingAuthenticate, err := isProxyFrontingAuthenticate(options, host) if err != nil { return nil, err } @@ -497,13 +497,13 @@ func hasPublicPolicyMatchingURL(options *config.Options, requestURL url.URL) boo return false } -func isProxyFrontingAuthenticate(options *config.Options, domain string) (bool, error) { +func isProxyFrontingAuthenticate(options *config.Options, host string) (bool, error) { authenticateURL, err := options.GetAuthenticateURL() if err != nil { return false, err } - if !config.IsAuthenticate(options.Services) && hostMatchesDomain(authenticateURL, domain) { + if !config.IsAuthenticate(options.Services) && urlMatchesHost(authenticateURL, host) { return true, nil } diff --git a/config/options.go b/config/options.go index d912e18b2..662dd3a91 100644 --- a/config/options.go +++ b/config/options.go @@ -1015,15 +1015,9 @@ func (o *Options) GetCodecType() CodecType { return o.CodecType } -// GetAllRouteableGRPCDomains returns all the possible gRPC domains handled by the Pomerium options. -func (o *Options) GetAllRouteableGRPCDomains() ([]string, error) { - return o.GetAllRouteableGRPCDomainsForTLSServerName("") -} - -// GetAllRouteableGRPCDomainsForTLSServerName returns all the possible gRPC domains handled by the Pomerium options -// for the given TLS server name. -func (o *Options) GetAllRouteableGRPCDomainsForTLSServerName(tlsServerName string) ([]string, error) { - domains := sets.NewSorted[string]() +// GetAllRouteableGRPCHosts returns all the possible gRPC hosts handled by the Pomerium options. +func (o *Options) GetAllRouteableGRPCHosts() ([]string, error) { + hosts := sets.NewSorted[string]() // authorize urls if IsAll(o.Services) { @@ -1032,11 +1026,7 @@ func (o *Options) GetAllRouteableGRPCDomainsForTLSServerName(tlsServerName strin return nil, err } for _, u := range authorizeURLs { - for _, h := range urlutil.GetDomainsForURL(*u) { - if tlsServerName == "" || urlutil.StripPort(h) == tlsServerName { - domains.Add(h) - } - } + hosts.Add(urlutil.GetDomainsForURL(*u)...) } } else if IsAuthorize(o.Services) { authorizeURLs, err := o.GetInternalAuthorizeURLs() @@ -1044,11 +1034,7 @@ func (o *Options) GetAllRouteableGRPCDomainsForTLSServerName(tlsServerName strin return nil, err } for _, u := range authorizeURLs { - for _, h := range urlutil.GetDomainsForURL(*u) { - if tlsServerName == "" || urlutil.StripPort(h) == tlsServerName { - domains.Add(h) - } - } + hosts.Add(urlutil.GetDomainsForURL(*u)...) } } @@ -1059,11 +1045,7 @@ func (o *Options) GetAllRouteableGRPCDomainsForTLSServerName(tlsServerName strin return nil, err } for _, u := range dataBrokerURLs { - for _, h := range urlutil.GetDomainsForURL(*u) { - if tlsServerName == "" || urlutil.StripPort(h) == tlsServerName { - domains.Add(h) - } - } + hosts.Add(urlutil.GetDomainsForURL(*u)...) } } else if IsDataBroker(o.Services) { dataBrokerURLs, err := o.GetInternalDataBrokerURLs() @@ -1071,71 +1053,42 @@ func (o *Options) GetAllRouteableGRPCDomainsForTLSServerName(tlsServerName strin return nil, err } for _, u := range dataBrokerURLs { - for _, h := range urlutil.GetDomainsForURL(*u) { - if tlsServerName == "" || urlutil.StripPort(h) == tlsServerName { - domains.Add(h) - } - } + hosts.Add(urlutil.GetDomainsForURL(*u)...) } } - return domains.ToSlice(), nil + return hosts.ToSlice(), nil } -// GetAllRouteableHTTPDomains returns all the possible HTTP domains handled by the Pomerium options. -func (o *Options) GetAllRouteableHTTPDomains() ([]string, error) { - return o.GetAllRouteableHTTPDomainsForTLSServerName("") -} - -// GetAllRouteableHTTPDomainsForTLSServerName returns all the possible HTTP domains handled by the Pomerium options -// for the given TLS server name. -func (o *Options) GetAllRouteableHTTPDomainsForTLSServerName(tlsServerName string) ([]string, error) { - domains := sets.NewSorted[string]() +// GetAllRouteableHTTPHosts returns all the possible HTTP hosts handled by the Pomerium options. +func (o *Options) GetAllRouteableHTTPHosts() ([]string, error) { + hosts := sets.NewSorted[string]() if IsAuthenticate(o.Services) { authenticateURL, err := o.GetInternalAuthenticateURL() if err != nil { return nil, err } - for _, h := range urlutil.GetDomainsForURL(*authenticateURL) { - if tlsServerName == "" || urlutil.StripPort(h) == tlsServerName { - domains.Add(h) - } - } + hosts.Add(urlutil.GetDomainsForURL(*authenticateURL)...) authenticateURL, err = o.GetAuthenticateURL() if err != nil { return nil, err } - for _, h := range urlutil.GetDomainsForURL(*authenticateURL) { - if tlsServerName == "" || urlutil.StripPort(h) == tlsServerName { - domains.Add(h) - } - } + hosts.Add(urlutil.GetDomainsForURL(*authenticateURL)...) } // policy urls if IsProxy(o.Services) { for _, policy := range o.GetAllPolicies() { - for _, h := range urlutil.GetDomainsForURL(*policy.Source.URL) { - if tlsServerName == "" || - policy.TLSDownstreamServerName == tlsServerName || - urlutil.StripPort(h) == tlsServerName { - domains.Add(h) - } - } + hosts.Add(urlutil.GetDomainsForURL(*policy.Source.URL)...) if policy.TLSDownstreamServerName != "" { tlsURL := policy.Source.URL.ResolveReference(&url.URL{Host: policy.TLSDownstreamServerName}) - for _, h := range urlutil.GetDomainsForURL(*tlsURL) { - if tlsServerName == "" || - urlutil.StripPort(h) == tlsServerName { - domains.Add(h) - } - } + hosts.Add(urlutil.GetDomainsForURL(*tlsURL)...) } } } - return domains.ToSlice(), nil + return hosts.ToSlice(), nil } // GetClientSecret gets the client secret. diff --git a/config/options_test.go b/config/options_test.go index 37f6d2e41..1c2c350bf 100644 --- a/config/options_test.go +++ b/config/options_test.go @@ -666,14 +666,14 @@ func TestOptions_GetOauthOptions(t *testing.T) { assert.Equal(t, u.Hostname(), oauthOptions.RedirectURL.Hostname()) } -func TestOptions_GetAllRouteableGRPCDomains(t *testing.T) { +func TestOptions_GetAllRouteableGRPCHosts(t *testing.T) { opts := &Options{ AuthenticateURLString: "https://authenticate.example.com", AuthorizeURLString: "https://authorize.example.com", DataBrokerURLString: "https://databroker.example.com", Services: "all", } - domains, err := opts.GetAllRouteableGRPCDomains() + hosts, err := opts.GetAllRouteableGRPCHosts() assert.NoError(t, err) assert.Equal(t, []string{ @@ -681,10 +681,10 @@ func TestOptions_GetAllRouteableGRPCDomains(t *testing.T) { "authorize.example.com:443", "databroker.example.com", "databroker.example.com:443", - }, domains) + }, hosts) } -func TestOptions_GetAllRouteableHTTPDomains(t *testing.T) { +func TestOptions_GetAllRouteableHTTPHosts(t *testing.T) { p1 := Policy{From: "https://from1.example.com"} p1.Validate() p2 := Policy{From: "https://from2.example.com"} @@ -699,7 +699,7 @@ func TestOptions_GetAllRouteableHTTPDomains(t *testing.T) { Policies: []Policy{p1, p2, p3}, Services: "all", } - domains, err := opts.GetAllRouteableHTTPDomains() + hosts, err := opts.GetAllRouteableHTTPHosts() assert.NoError(t, err) assert.Equal(t, []string{ @@ -713,7 +713,7 @@ func TestOptions_GetAllRouteableHTTPDomains(t *testing.T) { "from2.example.com:443", "from3.example.com", "from3.example.com:443", - }, domains) + }, hosts) } func TestOptions_ApplySettings(t *testing.T) { diff --git a/internal/urlutil/url.go b/internal/urlutil/url.go index 9c8fa443b..405635724 100644 --- a/internal/urlutil/url.go +++ b/internal/urlutil/url.go @@ -8,6 +8,8 @@ import ( "net/url" "strings" "time" + + "github.com/caddyserver/certmagic" ) const ( @@ -160,3 +162,8 @@ func GetExternalRequest(internalURL, externalURL *url.URL, r *http.Request) *htt } return er } + +// MatchesServerName returnes true if the url's host matches the given server name. +func MatchesServerName(u url.URL, serverName string) bool { + return certmagic.MatchWildcard(u.Hostname(), serverName) +} diff --git a/internal/urlutil/url_test.go b/internal/urlutil/url_test.go index 06e8ee4c4..a253a3e7d 100644 --- a/internal/urlutil/url_test.go +++ b/internal/urlutil/url_test.go @@ -166,3 +166,9 @@ func TestJoin(t *testing.T) { assert.Equal(t, "/x/y/z/", Join("/x", "/y/z/")) assert.Equal(t, "/x/y/z/", Join("/x/", "/y/z/")) } + +func TestMatchesServerName(t *testing.T) { + t.Run("wildcard", func(t *testing.T) { + assert.True(t, MatchesServerName(MustParseAndValidateURL("https://domain.example.com"), "*.example.com")) + }) +} diff --git a/pkg/cryptutil/tls.go b/pkg/cryptutil/tls.go index 93f0e073e..395f9d1f5 100644 --- a/pkg/cryptutil/tls.go +++ b/pkg/cryptutil/tls.go @@ -44,36 +44,36 @@ func GetCertPool(ca, caFile string) (*x509.CertPool, error) { return rootCAs, nil } -// GetCertificateForDomain returns the tls Certificate which matches the given domain name. +// GetCertificateForServerName returns the tls Certificate which matches the given server name. // It should handle both exact matches and wildcard matches. If none of those match, the first certificate will be used. // Finally if there are no matching certificates one will be generated. -func GetCertificateForDomain(certificates []tls.Certificate, domain string) (*tls.Certificate, error) { +func GetCertificateForServerName(certificates []tls.Certificate, serverName string) (*tls.Certificate, error) { // first try a direct name match for i := range certificates { - if matchesDomain(&certificates[i], domain) { + if matchesServerName(&certificates[i], serverName) { return &certificates[i], nil } } - log.WarnNoTLSCertificate(domain) + log.WarnNoTLSCertificate(serverName) // finally fall back to a generated, self-signed certificate - return GenerateSelfSignedCertificate(domain) + return GenerateSelfSignedCertificate(serverName) } -// HasCertificateForDomain returns true if a TLS certificate matches the given domain. -func HasCertificateForDomain(certificates []tls.Certificate, domain string) bool { +// HasCertificateForServerName returns true if a TLS certificate matches the given server name. +func HasCertificateForServerName(certificates []tls.Certificate, serverName string) bool { for i := range certificates { - if matchesDomain(&certificates[i], domain) { + if matchesServerName(&certificates[i], serverName) { return true } } return false } -// GetCertificateDomains gets all the certificate's matching domain names. +// GetCertificateServerNames gets all the certificate's server names. // Will return an empty slice if certificate is nil, empty, or x509 parsing fails. -func GetCertificateDomains(cert *tls.Certificate) []string { +func GetCertificateServerNames(cert *tls.Certificate) []string { if cert == nil || len(cert.Certificate) == 0 { return nil } @@ -83,19 +83,19 @@ func GetCertificateDomains(cert *tls.Certificate) []string { return nil } - var domains []string + var serverNames []string if xcert.Subject.CommonName != "" { - domains = append(domains, xcert.Subject.CommonName) + serverNames = append(serverNames, xcert.Subject.CommonName) } for _, dnsName := range xcert.DNSNames { if dnsName != "" { - domains = append(domains, dnsName) + serverNames = append(serverNames, dnsName) } } - return domains + return serverNames } -func matchesDomain(cert *tls.Certificate, domain string) bool { +func matchesServerName(cert *tls.Certificate, serverName string) bool { if cert == nil || len(cert.Certificate) == 0 { return false } @@ -105,12 +105,12 @@ func matchesDomain(cert *tls.Certificate, domain string) bool { return false } - if certmagic.MatchWildcard(domain, xcert.Subject.CommonName) { + if certmagic.MatchWildcard(serverName, xcert.Subject.CommonName) { return true } for _, san := range xcert.DNSNames { - if certmagic.MatchWildcard(domain, san) { + if certmagic.MatchWildcard(serverName, san) { return true } } diff --git a/pkg/cryptutil/tls_test.go b/pkg/cryptutil/tls_test.go index b62155fa8..a6bbdebec 100644 --- a/pkg/cryptutil/tls_test.go +++ b/pkg/cryptutil/tls_test.go @@ -8,10 +8,10 @@ import ( "github.com/stretchr/testify/require" ) -func TestGetCertificateForDomain(t *testing.T) { - gen := func(t *testing.T, domain string) *tls.Certificate { - cert, err := GenerateSelfSignedCertificate(domain) - if !assert.NoError(t, err, "error generating certificate for: %s", domain) { +func TestGetCertificateForServerName(t *testing.T) { + gen := func(t *testing.T, serverName string) *tls.Certificate { + cert, err := GenerateSelfSignedCertificate(serverName) + if !assert.NoError(t, err, "error generating certificate for: %s", serverName) { t.FailNow() } return cert @@ -23,7 +23,7 @@ func TestGetCertificateForDomain(t *testing.T) { *gen(t, "b.example.com"), } - found, err := GetCertificateForDomain(certs, "b.example.com") + found, err := GetCertificateForServerName(certs, "b.example.com") if !assert.NoError(t, err) { return } @@ -35,7 +35,7 @@ func TestGetCertificateForDomain(t *testing.T) { *gen(t, "*.example.com"), } - found, err := GetCertificateForDomain(certs, "b.example.com") + found, err := GetCertificateForServerName(certs, "b.example.com") if !assert.NoError(t, err) { return } @@ -46,7 +46,7 @@ func TestGetCertificateForDomain(t *testing.T) { *gen(t, "a.example.com"), } - found, err := GetCertificateForDomain(certs, "b.example.com") + found, err := GetCertificateForServerName(certs, "b.example.com") if !assert.NoError(t, err) { return } @@ -56,7 +56,7 @@ func TestGetCertificateForDomain(t *testing.T) { t.Run("generate", func(t *testing.T) { certs := []tls.Certificate{} - found, err := GetCertificateForDomain(certs, "b.example.com") + found, err := GetCertificateForServerName(certs, "b.example.com") if !assert.NoError(t, err) { return } @@ -64,8 +64,8 @@ func TestGetCertificateForDomain(t *testing.T) { }) } -func TestGetCertificateDomains(t *testing.T) { +func TestGetCertificateServerNames(t *testing.T) { cert, err := GenerateSelfSignedCertificate("www.example.com") require.NoError(t, err) - assert.Equal(t, []string{"www.example.com"}, GetCertificateDomains(cert)) + assert.Equal(t, []string{"www.example.com"}, GetCertificateServerNames(cert)) }