diff --git a/go.sum b/go.sum index 64d0292f2..0e2016176 100644 --- a/go.sum +++ b/go.sum @@ -787,8 +787,6 @@ google.golang.org/genproto v0.0.0-20200224152610-e50cd9704f63/go.mod h1:55QSHmfG google.golang.org/genproto v0.0.0-20200305110556-506484158171/go.mod h1:55QSHmfGQM9UVYDPBsyGGes0y52j32PQ3BqQfXhyH3c= google.golang.org/genproto v0.0.0-20200331122359-1ee6d9798940/go.mod h1:55QSHmfGQM9UVYDPBsyGGes0y52j32PQ3BqQfXhyH3c= google.golang.org/genproto v0.0.0-20200526211855-cb27e3aa2013/go.mod h1:NbSheEEYHJ7i3ixzK3sjbqSGDJWnxyFXZblF3eUsNvo= -google.golang.org/genproto v0.0.0-20200715011427-11fb19a81f2c h1:6DWnZZ6EY/59QRRQttZKiktVL23UuQYs7uy75MhhLRM= -google.golang.org/genproto v0.0.0-20200715011427-11fb19a81f2c/go.mod h1:FWY/as6DDZQgahTzZj3fqbO1CbirC29ZNUFHwi0/+no= google.golang.org/genproto v0.0.0-20200726014623-da3ae01ef02d h1:HJaAqDnKreMkv+AQyf1Mcw0jEmL9kKBNL07RDJu1N/k= google.golang.org/genproto v0.0.0-20200726014623-da3ae01ef02d/go.mod h1:FWY/as6DDZQgahTzZj3fqbO1CbirC29ZNUFHwi0/+no= google.golang.org/grpc v1.17.0/go.mod h1:6QZJwpn2B+Zp71q/5VxRsJ6NXXVCE5NRUHRo+f3cWCs= diff --git a/internal/controlplane/xds_listeners.go b/internal/controlplane/xds_listeners.go index 707123d29..45bf2eae5 100644 --- a/internal/controlplane/xds_listeners.go +++ b/internal/controlplane/xds_listeners.go @@ -2,6 +2,8 @@ package controlplane import ( "encoding/base64" + "net" + "net/url" "sort" "time" @@ -126,8 +128,8 @@ func buildMainHTTPConnectionManagerFilter(options *config.Options, domains []str if options.Addr == options.GRPCAddr { // if this is a gRPC service domain and we're supposed to handle that, add those routes - if (config.IsAuthorize(options.Services) && domain == options.GetAuthorizeURL().Host) || - (config.IsCache(options.Services) && domain == options.GetDataBrokerURL().Host) { + if (config.IsAuthorize(options.Services) && hostMatchesDomain(options.GetAuthorizeURL(), domain)) || + (config.IsCache(options.Services) && hostMatchesDomain(options.GetDataBrokerURL(), domain)) { vh.Routes = append(vh.Routes, buildGRPCRoutes()...) } } @@ -225,9 +227,8 @@ func buildMainHTTPConnectionManagerFilter(options *config.Options, domains []str RandomSampling: &envoy_type_v3.Percent{Value: options.TracingSampleRate * 100}, }, // See https://www.envoyproxy.io/docs/envoy/latest/configuration/http/http_conn_man/headers#x-forwarded-for - UseRemoteAddress: &wrappers.BoolValue{Value: true}, - SkipXffAppend: false, - StripMatchingHostPort: true, + UseRemoteAddress: &wrappers.BoolValue{Value: true}, + SkipXffAppend: false, }) return &envoy_config_listener_v3.Filter{ @@ -397,20 +398,30 @@ func buildDownstreamTLSContext(options *config.Options, domain string) *envoy_ex func getAllRouteableDomains(options *config.Options, addr string) []string { lookup := map[string]struct{}{} if config.IsAuthenticate(options.Services) && addr == options.Addr { - lookup[options.GetAuthenticateURL().Host] = struct{}{} + for _, h := range getDomainsForURL(options.GetAuthenticateURL()) { + lookup[h] = struct{}{} + } } if config.IsAuthorize(options.Services) && addr == options.GRPCAddr { - lookup[options.GetAuthorizeURL().Host] = struct{}{} + for _, h := range getDomainsForURL(options.GetAuthorizeURL()) { + lookup[h] = struct{}{} + } } if config.IsCache(options.Services) && addr == options.GRPCAddr { - lookup[options.GetDataBrokerURL().Host] = struct{}{} + for _, h := range getDomainsForURL(options.GetDataBrokerURL()) { + lookup[h] = struct{}{} + } } if config.IsProxy(options.Services) && addr == options.Addr { for _, policy := range options.Policies { - lookup[policy.Source.Host] = struct{}{} + for _, h := range getDomainsForURL(policy.Source.URL) { + lookup[h] = struct{}{} + } } if options.ForwardAuthURL != nil { - lookup[options.ForwardAuthURL.Host] = struct{}{} + for _, h := range getDomainsForURL(options.GetForwardAuthURL()) { + lookup[h] = struct{}{} + } } } @@ -422,3 +433,45 @@ func getAllRouteableDomains(options *config.Options, addr string) []string { return domains } + +func getDomainsForURL(u *url.URL) []string { + var defaultPort string + if u.Scheme == "http" { + defaultPort = "80" + } else { + defaultPort = "443" + } + + // for hosts like 'example.com:1234' we only return one route + if _, p, err := net.SplitHostPort(u.Host); err == nil { + if p != defaultPort { + return []string{u.Host} + } + } + + // for everything else we return two routes: 'example.com' and 'example.com:443' + return []string{u.Hostname(), net.JoinHostPort(u.Hostname(), defaultPort)} +} + +func hostMatchesDomain(u *url.URL, host string) bool { + var defaultPort string + if u.Scheme == "http" { + defaultPort = "80" + } else { + defaultPort = "443" + } + + h1, p1, err := net.SplitHostPort(u.Host) + if err != nil { + h1 = u.Host + p1 = defaultPort + } + + h2, p2, err := net.SplitHostPort(host) + if err != nil { + h2 = host + p2 = defaultPort + } + + return h1 == h2 && p1 == p2 +} diff --git a/internal/controlplane/xds_listeners_test.go b/internal/controlplane/xds_listeners_test.go index 7d72ae964..8c0e332d7 100644 --- a/internal/controlplane/xds_listeners_test.go +++ b/internal/controlplane/xds_listeners_test.go @@ -308,7 +308,6 @@ func Test_buildMainHTTPConnectionManagerFilter(t *testing.T) { "validateClusters": false }, "statPrefix": "ingress", - "stripMatchingHostPort": true, "tracing": { "randomSampling": { "value": 0.01 @@ -370,7 +369,7 @@ func Test_getAllRouteableDomains(t *testing.T) { AuthorizeURL: mustParseURL("https://authorize.example.com:9001"), DataBrokerURL: mustParseURL("https://cache.example.com:9001"), Policies: []config.Policy{ - {Source: &config.StringURL{URL: mustParseURL("https://a.example.com")}}, + {Source: &config.StringURL{URL: mustParseURL("http://a.example.com")}}, {Source: &config.StringURL{URL: mustParseURL("https://b.example.com")}}, {Source: &config.StringURL{URL: mustParseURL("https://c.example.com")}}, }, @@ -379,9 +378,13 @@ func Test_getAllRouteableDomains(t *testing.T) { actual := getAllRouteableDomains(options, "127.0.0.1:9000") expect := []string{ "a.example.com", + "a.example.com:80", "authenticate.example.com", + "authenticate.example.com:443", "b.example.com", + "b.example.com:443", "c.example.com", + "c.example.com:443", } assert.Equal(t, expect, actual) }) @@ -395,6 +398,16 @@ func Test_getAllRouteableDomains(t *testing.T) { }) } +func Test_hostMatchesDomain(t *testing.T) { + assert.True(t, hostMatchesDomain(mustParseURL("http://example.com"), "example.com")) + assert.True(t, hostMatchesDomain(mustParseURL("http://example.com"), "example.com:80")) + assert.True(t, hostMatchesDomain(mustParseURL("https://example.com"), "example.com:443")) + assert.True(t, hostMatchesDomain(mustParseURL("https://example.com:443"), "example.com:443")) + assert.True(t, hostMatchesDomain(mustParseURL("https://example.com:443"), "example.com")) + assert.False(t, hostMatchesDomain(mustParseURL("http://example.com:81"), "example.com")) + assert.False(t, hostMatchesDomain(mustParseURL("http://example.com:81"), "example.com:80")) +} + func Test_buildRouteConfiguration(t *testing.T) { virtualHosts := make([]*envoy_config_route_v3.VirtualHost, 10) routeConfig := buildRouteConfiguration("test-route-configuration", virtualHosts) diff --git a/internal/controlplane/xds_routes.go b/internal/controlplane/xds_routes.go index 8d06c823c..d3d096226 100644 --- a/internal/controlplane/xds_routes.go +++ b/internal/controlplane/xds_routes.go @@ -50,11 +50,11 @@ func buildPomeriumHTTPRoutes(options *config.Options, domain string) []*envoy_co buildControlPlanePrefixRoute("/.well-known/pomerium/"), } // if we're handling authentication, add the oauth2 callback url - if config.IsAuthenticate(options.Services) && domain == options.GetAuthenticateURL().Host { + if config.IsAuthenticate(options.Services) && hostMatchesDomain(options.GetAuthenticateURL(), domain) { routes = append(routes, buildControlPlanePathRoute(options.AuthenticateCallbackPath)) } // if we're the proxy and this is the forward-auth url - if config.IsProxy(options.Services) && options.ForwardAuthURL != nil && domain == options.ForwardAuthURL.Host { + if config.IsProxy(options.Services) && options.ForwardAuthURL != nil && hostMatchesDomain(options.GetForwardAuthURL(), domain) { routes = append(routes, buildControlPlanePrefixRoute("/")) } return routes @@ -103,7 +103,7 @@ func buildPolicyRoutes(options *config.Options, domain string) []*envoy_config_r responseHeadersToAdd := toEnvoyHeaders(options.Headers) for i, policy := range options.Policies { - if policy.Source.Host != domain { + if !hostMatchesDomain(policy.Source.URL, domain) { continue }