diff --git a/config/envoyconfig/http_connection_manager.go b/config/envoyconfig/http_connection_manager.go index a9ed55b92..884d8392a 100644 --- a/config/envoyconfig/http_connection_manager.go +++ b/config/envoyconfig/http_connection_manager.go @@ -20,6 +20,7 @@ func (b *Builder) buildVirtualHost( options *config.Options, name string, host string, + hasMCPPolicy bool, ) (*envoy_config_route_v3.VirtualHost, error) { vh := &envoy_config_route_v3.VirtualHost{ Name: name, @@ -36,7 +37,7 @@ func (b *Builder) buildVirtualHost( } // these routes match /.pomerium/... and similar paths - rs, err := b.buildPomeriumHTTPRoutes(options, host) + rs, err := b.buildPomeriumHTTPRoutes(options, host, hasMCPPolicy) if err != nil { return nil, err } diff --git a/config/envoyconfig/route_configurations.go b/config/envoyconfig/route_configurations.go index dbf67a8fa..a88d9e588 100644 --- a/config/envoyconfig/route_configurations.go +++ b/config/envoyconfig/route_configurations.go @@ -50,14 +50,14 @@ func (b *Builder) buildMainRouteConfiguration( return nil, err } - allHosts, err := getAllRouteableHosts(cfg.Options, cfg.Options.Addr) + allHosts, mcpHosts, err := getAllRouteableHosts(cfg.Options, cfg.Options.Addr) if err != nil { return nil, err } var virtualHosts []*envoy_config_route_v3.VirtualHost for _, host := range allHosts { - vh, err := b.buildVirtualHost(cfg.Options, host, host) + vh, err := b.buildVirtualHost(cfg.Options, host, host, mcpHosts[host]) if err != nil { return nil, err } @@ -88,7 +88,7 @@ func (b *Builder) buildMainRouteConfiguration( } } - vh, err := b.buildVirtualHost(cfg.Options, "catch-all", "*") + vh, err := b.buildVirtualHost(cfg.Options, "catch-all", "*", false) if err != nil { return nil, err } @@ -106,21 +106,28 @@ func (b *Builder) buildMainRouteConfiguration( return rc, nil } -func getAllRouteableHosts(options *config.Options, addr string) ([]string, error) { +func getAllRouteableHosts(options *config.Options, addr string) ([]string, map[string]bool, error) { allHosts := set.NewTreeSet(cmp.Compare[string]) + mcpHosts := make(map[string]bool) if addr == options.Addr { - hosts, err := options.GetAllRouteableHTTPHosts() + hosts, hostsMCP, err := options.GetAllRouteableHTTPHosts() if err != nil { - return nil, err + return nil, nil, err } allHosts.InsertSlice(hosts) + // Merge any MCP hosts + for host, isMCP := range hostsMCP { + if isMCP { + mcpHosts[host] = true + } + } } if addr == options.GetGRPCAddr() { hosts, err := options.GetAllRouteableGRPCHosts() if err != nil { - return nil, err + return nil, nil, err } allHosts.InsertSlice(hosts) } @@ -131,7 +138,7 @@ func getAllRouteableHosts(options *config.Options, addr string) ([]string, error filtered = append(filtered, host) } } - return filtered, nil + return filtered, mcpHosts, nil } func newRouteConfiguration(name string, virtualHosts []*envoy_config_route_v3.VirtualHost) *envoy_config_route_v3.RouteConfiguration { diff --git a/config/envoyconfig/route_configurations_test.go b/config/envoyconfig/route_configurations_test.go index 7d55daacd..04afa06ca 100644 --- a/config/envoyconfig/route_configurations_test.go +++ b/config/envoyconfig/route_configurations_test.go @@ -195,7 +195,7 @@ func Test_getAllDomains(t *testing.T) { } t.Run("routable", func(t *testing.T) { t.Run("http", func(t *testing.T) { - actual, err := getAllRouteableHosts(options, "127.0.0.1:9000") + actual, _, err := getAllRouteableHosts(options, "127.0.0.1:9000") require.NoError(t, err) expect := []string{ "a.example.com", @@ -214,7 +214,7 @@ func Test_getAllDomains(t *testing.T) { assert.Equal(t, expect, actual) }) t.Run("grpc", func(t *testing.T) { - actual, err := getAllRouteableHosts(options, "127.0.0.1:9001") + actual, _, err := getAllRouteableHosts(options, "127.0.0.1:9001") require.NoError(t, err) expect := []string{ "authorize.example.com:9001", @@ -225,7 +225,7 @@ func Test_getAllDomains(t *testing.T) { t.Run("both", func(t *testing.T) { newOptions := *options newOptions.GRPCAddr = newOptions.Addr - actual, err := getAllRouteableHosts(&newOptions, "127.0.0.1:9000") + actual, _, err := getAllRouteableHosts(&newOptions, "127.0.0.1:9000") require.NoError(t, err) expect := []string{ "a.example.com", @@ -252,7 +252,7 @@ func Test_getAllDomains(t *testing.T) { options.Policies = []config.Policy{ {From: "https://a.example.com"}, } - actual, err := getAllRouteableHosts(options, ":443") + actual, _, err := getAllRouteableHosts(options, ":443") require.NoError(t, err) assert.Equal(t, []string{"a.example.com"}, actual) }) diff --git a/config/envoyconfig/routes.go b/config/envoyconfig/routes.go index 7f86f6499..f96783930 100644 --- a/config/envoyconfig/routes.go +++ b/config/envoyconfig/routes.go @@ -50,6 +50,7 @@ func (b *Builder) buildGRPCRoutes() ([]*envoy_config_route_v3.Route, error) { func (b *Builder) buildPomeriumHTTPRoutes( options *config.Options, host string, + isMCPHost bool, ) ([]*envoy_config_route_v3.Route, error) { var routes []*envoy_config_route_v3.Route @@ -70,15 +71,8 @@ func (b *Builder) buildPomeriumHTTPRoutes( b.buildControlPlanePrefixRoute(options, "/.well-known/pomerium/"), ) - // Only add oauth-authorization-server route if there's an MCP policy - hasMCPPolicy := false - for _, policy := range options.GetAllPoliciesIndexed() { - if policy.IsMCP() { - hasMCPPolicy = true - break - } - } - if hasMCPPolicy { + // Only add oauth-authorization-server route if there's an MCP policy for this host + if isMCPHost { routes = append(routes, b.buildControlPlanePathRoute(options, "/.well-known/oauth-authorization-server")) } } diff --git a/config/envoyconfig/routes_test.go b/config/envoyconfig/routes_test.go index b97a3d304..aedaf6638 100644 --- a/config/envoyconfig/routes_test.go +++ b/config/envoyconfig/routes_test.go @@ -104,7 +104,7 @@ func Test_buildPomeriumHTTPRoutes(t *testing.T) { AuthenticateURLString: "https://authenticate.example.com", AuthenticateCallbackPath: "/oauth2/callback", } - routes, err := b.buildPomeriumHTTPRoutes(options, "authenticate.example.com") + routes, err := b.buildPomeriumHTTPRoutes(options, "authenticate.example.com", false) require.NoError(t, err) testutil.AssertProtoJSONEqual(t, `[ @@ -125,7 +125,7 @@ func Test_buildPomeriumHTTPRoutes(t *testing.T) { AuthenticateURLString: "https://authenticate.example.com", AuthenticateCallbackPath: "/oauth2/callback", } - routes, err := b.buildPomeriumHTTPRoutes(options, "authenticate.example.com") + routes, err := b.buildPomeriumHTTPRoutes(options, "authenticate.example.com", false) require.NoError(t, err) testutil.AssertProtoJSONEqual(t, "null", routes) }) @@ -2302,23 +2302,16 @@ func Test_buildPomeriumHTTPRoutesWithMCP(t *testing.T) { }, } - routes, err := b.buildPomeriumHTTPRoutes(options, "example.com") + routes, err := b.buildPomeriumHTTPRoutes(options, "example.com", false) require.NoError(t, err) - // Check routes for well-known endpoints hasOAuthServer := false - hasPomerium := false for _, route := range routes { if route.GetMatch().GetPath() == "/.well-known/oauth-authorization-server" { hasOAuthServer = true } - if route.GetMatch().GetPath() == "/.well-known/pomerium" { - hasPomerium = true - } } - // Verify oauth-authorization-server route is NOT present - assert.True(t, hasPomerium, "/.well-known/pomerium route should be present") assert.False(t, hasOAuthServer, "/.well-known/oauth-authorization-server route should NOT be present") }) @@ -2340,25 +2333,9 @@ func Test_buildPomeriumHTTPRoutesWithMCP(t *testing.T) { }, } - routes, err := b.buildPomeriumHTTPRoutes(options, "example.com") + routes, err := b.buildPomeriumHTTPRoutes(options, "example.com", true) require.NoError(t, err) - // Check routes for well-known endpoints - hasOAuthServer := false - hasPomerium := false - for _, route := range routes { - if route.GetMatch().GetPath() == "/.well-known/oauth-authorization-server" { - hasOAuthServer = true - } - if route.GetMatch().GetPath() == "/.well-known/pomerium" { - hasPomerium = true - } - } - - // Verify oauth-authorization-server route IS present - assert.True(t, hasPomerium, "/.well-known/pomerium route should be present") - assert.True(t, hasOAuthServer, "/.well-known/oauth-authorization-server route should be present") - // Verify the expected route structures testutil.AssertProtoJSONEqual(t, `[ `+routeString("path", "/ping")+`, diff --git a/config/options.go b/config/options.go index fe49d03b0..87015ed19 100644 --- a/config/options.go +++ b/config/options.go @@ -1273,23 +1273,27 @@ func (o *Options) GetAllRouteableGRPCHosts() ([]string, error) { } // GetAllRouteableHTTPHosts returns all the possible HTTP hosts handled by the Pomerium options. -func (o *Options) GetAllRouteableHTTPHosts() ([]string, error) { +func (o *Options) GetAllRouteableHTTPHosts() ([]string, map[string]bool, error) { hosts := goset.NewTreeSet(cmp.Compare[string]) + mcpHosts := make(map[string]bool) + if IsAuthenticate(o.Services) { if o.AuthenticateInternalURLString != "" { authenticateURL, err := o.GetInternalAuthenticateURL() if err != nil { - return nil, err + return nil, nil, err } - hosts.InsertSlice(urlutil.GetDomainsForURL(authenticateURL, !o.IsRuntimeFlagSet(RuntimeFlagMatchAnyIncomingPort))) + domains := urlutil.GetDomainsForURL(authenticateURL, !o.IsRuntimeFlagSet(RuntimeFlagMatchAnyIncomingPort)) + hosts.InsertSlice(domains) } if o.AuthenticateURLString != "" { authenticateURL, err := o.GetAuthenticateURL() if err != nil { - return nil, err + return nil, nil, err } - hosts.InsertSlice(urlutil.GetDomainsForURL(authenticateURL, !o.IsRuntimeFlagSet(RuntimeFlagMatchAnyIncomingPort))) + domains := urlutil.GetDomainsForURL(authenticateURL, !o.IsRuntimeFlagSet(RuntimeFlagMatchAnyIncomingPort)) + hosts.InsertSlice(domains) } } @@ -1298,18 +1302,35 @@ func (o *Options) GetAllRouteableHTTPHosts() ([]string, error) { for policy := range o.GetAllPolicies() { fromURL, err := urlutil.ParseAndValidateURL(policy.From) if err != nil { - return nil, err + return nil, nil, err + } + + domains := urlutil.GetDomainsForURL(fromURL, !o.IsRuntimeFlagSet(RuntimeFlagMatchAnyIncomingPort)) + hosts.InsertSlice(domains) + + // Track if the domains are associated with an MCP policy + if policy.IsMCP() { + for _, domain := range domains { + mcpHosts[domain] = true + } } - hosts.InsertSlice(urlutil.GetDomainsForURL(fromURL, !o.IsRuntimeFlagSet(RuntimeFlagMatchAnyIncomingPort))) if policy.TLSDownstreamServerName != "" { tlsURL := fromURL.ResolveReference(&url.URL{Host: policy.TLSDownstreamServerName}) - hosts.InsertSlice(urlutil.GetDomainsForURL(tlsURL, !o.IsRuntimeFlagSet(RuntimeFlagMatchAnyIncomingPort))) + tlsDomains := urlutil.GetDomainsForURL(tlsURL, !o.IsRuntimeFlagSet(RuntimeFlagMatchAnyIncomingPort)) + hosts.InsertSlice(tlsDomains) + + // Track if the TLS domains are associated with an MCP policy + if policy.IsMCP() { + for _, domain := range tlsDomains { + mcpHosts[domain] = true + } + } } } } - return hosts.Slice(), nil + return hosts.Slice(), mcpHosts, nil } // GetClientSecret gets the client secret. diff --git a/config/options_test.go b/config/options_test.go index a91f69711..aff398afb 100644 --- a/config/options_test.go +++ b/config/options_test.go @@ -888,22 +888,26 @@ func TestOptions_GetAllRouteableGRPCHosts(t *testing.T) { } func TestOptions_GetAllRouteableHTTPHosts(t *testing.T) { - p1 := Policy{From: "https://from1.example.com"} - p1.Validate() - p2 := Policy{From: "https://from2.example.com"} - p2.Validate() - p3 := Policy{From: "https://from3.example.com", TLSDownstreamServerName: "from.example.com"} - p3.Validate() + to := WeightedURLs{{URL: url.URL{Scheme: "https", Host: "to.example.com"}}} + p1 := Policy{From: "https://from1.example.com", To: to} + assert.NoError(t, p1.Validate()) + p2 := Policy{From: "https://from2.example.com", To: to} + assert.NoError(t, p2.Validate()) + p3 := Policy{From: "https://from3.example.com", TLSDownstreamServerName: "from.example.com", To: to} + assert.NoError(t, p3.Validate()) + p4 := Policy{From: "https://from4.example.com", MCP: &MCP{}, To: to} + assert.NoError(t, p4.Validate()) opts := &Options{ AuthenticateURLString: "https://authenticate.example.com", AuthorizeURLString: "https://authorize.example.com", DataBrokerURLString: "https://databroker.example.com", - Policies: []Policy{p1, p2, p3}, + Policies: []Policy{p1, p2, p3, p4}, Services: "all", } - hosts, err := opts.GetAllRouteableHTTPHosts() + hosts, mcpHosts, err := opts.GetAllRouteableHTTPHosts() assert.NoError(t, err) + assert.Empty(t, cmp.Diff(mcpHosts, map[string]bool{"from4.example.com:443": true, "from4.example.com": true})) assert.Equal(t, []string{ "authenticate.example.com", @@ -916,6 +920,8 @@ func TestOptions_GetAllRouteableHTTPHosts(t *testing.T) { "from2.example.com:443", "from3.example.com", "from3.example.com:443", + "from4.example.com", + "from4.example.com:443", }, hosts) }