cache mcp policy condition

This commit is contained in:
Denis Mishin 2025-04-22 20:58:06 -04:00
parent ba03bb732b
commit 3c8e440b58
7 changed files with 72 additions and 66 deletions

View file

@ -20,6 +20,7 @@ func (b *Builder) buildVirtualHost(
options *config.Options, options *config.Options,
name string, name string,
host string, host string,
hasMCPPolicy bool,
) (*envoy_config_route_v3.VirtualHost, error) { ) (*envoy_config_route_v3.VirtualHost, error) {
vh := &envoy_config_route_v3.VirtualHost{ vh := &envoy_config_route_v3.VirtualHost{
Name: name, Name: name,
@ -36,7 +37,7 @@ func (b *Builder) buildVirtualHost(
} }
// these routes match /.pomerium/... and similar paths // these routes match /.pomerium/... and similar paths
rs, err := b.buildPomeriumHTTPRoutes(options, host) rs, err := b.buildPomeriumHTTPRoutes(options, host, hasMCPPolicy)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View file

@ -50,14 +50,14 @@ func (b *Builder) buildMainRouteConfiguration(
return nil, err return nil, err
} }
allHosts, err := getAllRouteableHosts(cfg.Options, cfg.Options.Addr) allHosts, mcpHosts, err := getAllRouteableHosts(cfg.Options, cfg.Options.Addr)
if err != nil { if err != nil {
return nil, err return nil, err
} }
var virtualHosts []*envoy_config_route_v3.VirtualHost var virtualHosts []*envoy_config_route_v3.VirtualHost
for _, host := range allHosts { for _, host := range allHosts {
vh, err := b.buildVirtualHost(cfg.Options, host, host) vh, err := b.buildVirtualHost(cfg.Options, host, host, mcpHosts[host])
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -88,7 +88,7 @@ func (b *Builder) buildMainRouteConfiguration(
} }
} }
vh, err := b.buildVirtualHost(cfg.Options, "catch-all", "*") vh, err := b.buildVirtualHost(cfg.Options, "catch-all", "*", false)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -106,21 +106,28 @@ func (b *Builder) buildMainRouteConfiguration(
return rc, nil return rc, nil
} }
func getAllRouteableHosts(options *config.Options, addr string) ([]string, error) { func getAllRouteableHosts(options *config.Options, addr string) ([]string, map[string]bool, error) {
allHosts := set.NewTreeSet(cmp.Compare[string]) allHosts := set.NewTreeSet(cmp.Compare[string])
mcpHosts := make(map[string]bool)
if addr == options.Addr { if addr == options.Addr {
hosts, err := options.GetAllRouteableHTTPHosts() hosts, hostsMCP, err := options.GetAllRouteableHTTPHosts()
if err != nil { if err != nil {
return nil, err return nil, nil, err
} }
allHosts.InsertSlice(hosts) allHosts.InsertSlice(hosts)
// Merge any MCP hosts
for host, isMCP := range hostsMCP {
if isMCP {
mcpHosts[host] = true
}
}
} }
if addr == options.GetGRPCAddr() { if addr == options.GetGRPCAddr() {
hosts, err := options.GetAllRouteableGRPCHosts() hosts, err := options.GetAllRouteableGRPCHosts()
if err != nil { if err != nil {
return nil, err return nil, nil, err
} }
allHosts.InsertSlice(hosts) allHosts.InsertSlice(hosts)
} }
@ -131,7 +138,7 @@ func getAllRouteableHosts(options *config.Options, addr string) ([]string, error
filtered = append(filtered, host) filtered = append(filtered, host)
} }
} }
return filtered, nil return filtered, mcpHosts, nil
} }
func newRouteConfiguration(name string, virtualHosts []*envoy_config_route_v3.VirtualHost) *envoy_config_route_v3.RouteConfiguration { func newRouteConfiguration(name string, virtualHosts []*envoy_config_route_v3.VirtualHost) *envoy_config_route_v3.RouteConfiguration {

View file

@ -195,7 +195,7 @@ func Test_getAllDomains(t *testing.T) {
} }
t.Run("routable", func(t *testing.T) { t.Run("routable", func(t *testing.T) {
t.Run("http", func(t *testing.T) { t.Run("http", func(t *testing.T) {
actual, err := getAllRouteableHosts(options, "127.0.0.1:9000") actual, _, err := getAllRouteableHosts(options, "127.0.0.1:9000")
require.NoError(t, err) require.NoError(t, err)
expect := []string{ expect := []string{
"a.example.com", "a.example.com",
@ -214,7 +214,7 @@ func Test_getAllDomains(t *testing.T) {
assert.Equal(t, expect, actual) assert.Equal(t, expect, actual)
}) })
t.Run("grpc", func(t *testing.T) { t.Run("grpc", func(t *testing.T) {
actual, err := getAllRouteableHosts(options, "127.0.0.1:9001") actual, _, err := getAllRouteableHosts(options, "127.0.0.1:9001")
require.NoError(t, err) require.NoError(t, err)
expect := []string{ expect := []string{
"authorize.example.com:9001", "authorize.example.com:9001",
@ -225,7 +225,7 @@ func Test_getAllDomains(t *testing.T) {
t.Run("both", func(t *testing.T) { t.Run("both", func(t *testing.T) {
newOptions := *options newOptions := *options
newOptions.GRPCAddr = newOptions.Addr newOptions.GRPCAddr = newOptions.Addr
actual, err := getAllRouteableHosts(&newOptions, "127.0.0.1:9000") actual, _, err := getAllRouteableHosts(&newOptions, "127.0.0.1:9000")
require.NoError(t, err) require.NoError(t, err)
expect := []string{ expect := []string{
"a.example.com", "a.example.com",
@ -252,7 +252,7 @@ func Test_getAllDomains(t *testing.T) {
options.Policies = []config.Policy{ options.Policies = []config.Policy{
{From: "https://a.example.com"}, {From: "https://a.example.com"},
} }
actual, err := getAllRouteableHosts(options, ":443") actual, _, err := getAllRouteableHosts(options, ":443")
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, []string{"a.example.com"}, actual) assert.Equal(t, []string{"a.example.com"}, actual)
}) })

View file

@ -50,6 +50,7 @@ func (b *Builder) buildGRPCRoutes() ([]*envoy_config_route_v3.Route, error) {
func (b *Builder) buildPomeriumHTTPRoutes( func (b *Builder) buildPomeriumHTTPRoutes(
options *config.Options, options *config.Options,
host string, host string,
isMCPHost bool,
) ([]*envoy_config_route_v3.Route, error) { ) ([]*envoy_config_route_v3.Route, error) {
var routes []*envoy_config_route_v3.Route var routes []*envoy_config_route_v3.Route
@ -70,15 +71,8 @@ func (b *Builder) buildPomeriumHTTPRoutes(
b.buildControlPlanePrefixRoute(options, "/.well-known/pomerium/"), b.buildControlPlanePrefixRoute(options, "/.well-known/pomerium/"),
) )
// Only add oauth-authorization-server route if there's an MCP policy // Only add oauth-authorization-server route if there's an MCP policy for this host
hasMCPPolicy := false if isMCPHost {
for _, policy := range options.GetAllPoliciesIndexed() {
if policy.IsMCP() {
hasMCPPolicy = true
break
}
}
if hasMCPPolicy {
routes = append(routes, b.buildControlPlanePathRoute(options, "/.well-known/oauth-authorization-server")) routes = append(routes, b.buildControlPlanePathRoute(options, "/.well-known/oauth-authorization-server"))
} }
} }

View file

@ -104,7 +104,7 @@ func Test_buildPomeriumHTTPRoutes(t *testing.T) {
AuthenticateURLString: "https://authenticate.example.com", AuthenticateURLString: "https://authenticate.example.com",
AuthenticateCallbackPath: "/oauth2/callback", AuthenticateCallbackPath: "/oauth2/callback",
} }
routes, err := b.buildPomeriumHTTPRoutes(options, "authenticate.example.com") routes, err := b.buildPomeriumHTTPRoutes(options, "authenticate.example.com", false)
require.NoError(t, err) require.NoError(t, err)
testutil.AssertProtoJSONEqual(t, `[ testutil.AssertProtoJSONEqual(t, `[
@ -125,7 +125,7 @@ func Test_buildPomeriumHTTPRoutes(t *testing.T) {
AuthenticateURLString: "https://authenticate.example.com", AuthenticateURLString: "https://authenticate.example.com",
AuthenticateCallbackPath: "/oauth2/callback", AuthenticateCallbackPath: "/oauth2/callback",
} }
routes, err := b.buildPomeriumHTTPRoutes(options, "authenticate.example.com") routes, err := b.buildPomeriumHTTPRoutes(options, "authenticate.example.com", false)
require.NoError(t, err) require.NoError(t, err)
testutil.AssertProtoJSONEqual(t, "null", routes) testutil.AssertProtoJSONEqual(t, "null", routes)
}) })
@ -2302,23 +2302,16 @@ func Test_buildPomeriumHTTPRoutesWithMCP(t *testing.T) {
}, },
} }
routes, err := b.buildPomeriumHTTPRoutes(options, "example.com") routes, err := b.buildPomeriumHTTPRoutes(options, "example.com", false)
require.NoError(t, err) require.NoError(t, err)
// Check routes for well-known endpoints
hasOAuthServer := false hasOAuthServer := false
hasPomerium := false
for _, route := range routes { for _, route := range routes {
if route.GetMatch().GetPath() == "/.well-known/oauth-authorization-server" { if route.GetMatch().GetPath() == "/.well-known/oauth-authorization-server" {
hasOAuthServer = true hasOAuthServer = true
} }
if route.GetMatch().GetPath() == "/.well-known/pomerium" {
hasPomerium = true
}
} }
// Verify oauth-authorization-server route is NOT present
assert.True(t, hasPomerium, "/.well-known/pomerium route should be present")
assert.False(t, hasOAuthServer, "/.well-known/oauth-authorization-server route should NOT be present") assert.False(t, hasOAuthServer, "/.well-known/oauth-authorization-server route should NOT be present")
}) })
@ -2340,25 +2333,9 @@ func Test_buildPomeriumHTTPRoutesWithMCP(t *testing.T) {
}, },
} }
routes, err := b.buildPomeriumHTTPRoutes(options, "example.com") routes, err := b.buildPomeriumHTTPRoutes(options, "example.com", true)
require.NoError(t, err) require.NoError(t, err)
// Check routes for well-known endpoints
hasOAuthServer := false
hasPomerium := false
for _, route := range routes {
if route.GetMatch().GetPath() == "/.well-known/oauth-authorization-server" {
hasOAuthServer = true
}
if route.GetMatch().GetPath() == "/.well-known/pomerium" {
hasPomerium = true
}
}
// Verify oauth-authorization-server route IS present
assert.True(t, hasPomerium, "/.well-known/pomerium route should be present")
assert.True(t, hasOAuthServer, "/.well-known/oauth-authorization-server route should be present")
// Verify the expected route structures // Verify the expected route structures
testutil.AssertProtoJSONEqual(t, `[ testutil.AssertProtoJSONEqual(t, `[
`+routeString("path", "/ping")+`, `+routeString("path", "/ping")+`,

View file

@ -1273,23 +1273,27 @@ func (o *Options) GetAllRouteableGRPCHosts() ([]string, error) {
} }
// GetAllRouteableHTTPHosts returns all the possible HTTP hosts handled by the Pomerium options. // GetAllRouteableHTTPHosts returns all the possible HTTP hosts handled by the Pomerium options.
func (o *Options) GetAllRouteableHTTPHosts() ([]string, error) { func (o *Options) GetAllRouteableHTTPHosts() ([]string, map[string]bool, error) {
hosts := goset.NewTreeSet(cmp.Compare[string]) hosts := goset.NewTreeSet(cmp.Compare[string])
mcpHosts := make(map[string]bool)
if IsAuthenticate(o.Services) { if IsAuthenticate(o.Services) {
if o.AuthenticateInternalURLString != "" { if o.AuthenticateInternalURLString != "" {
authenticateURL, err := o.GetInternalAuthenticateURL() authenticateURL, err := o.GetInternalAuthenticateURL()
if err != nil { if err != nil {
return nil, err return nil, nil, err
} }
hosts.InsertSlice(urlutil.GetDomainsForURL(authenticateURL, !o.IsRuntimeFlagSet(RuntimeFlagMatchAnyIncomingPort))) domains := urlutil.GetDomainsForURL(authenticateURL, !o.IsRuntimeFlagSet(RuntimeFlagMatchAnyIncomingPort))
hosts.InsertSlice(domains)
} }
if o.AuthenticateURLString != "" { if o.AuthenticateURLString != "" {
authenticateURL, err := o.GetAuthenticateURL() authenticateURL, err := o.GetAuthenticateURL()
if err != nil { if err != nil {
return nil, err return nil, nil, err
} }
hosts.InsertSlice(urlutil.GetDomainsForURL(authenticateURL, !o.IsRuntimeFlagSet(RuntimeFlagMatchAnyIncomingPort))) domains := urlutil.GetDomainsForURL(authenticateURL, !o.IsRuntimeFlagSet(RuntimeFlagMatchAnyIncomingPort))
hosts.InsertSlice(domains)
} }
} }
@ -1298,18 +1302,35 @@ func (o *Options) GetAllRouteableHTTPHosts() ([]string, error) {
for policy := range o.GetAllPolicies() { for policy := range o.GetAllPolicies() {
fromURL, err := urlutil.ParseAndValidateURL(policy.From) fromURL, err := urlutil.ParseAndValidateURL(policy.From)
if err != nil { if err != nil {
return nil, err return nil, nil, err
}
domains := urlutil.GetDomainsForURL(fromURL, !o.IsRuntimeFlagSet(RuntimeFlagMatchAnyIncomingPort))
hosts.InsertSlice(domains)
// Track if the domains are associated with an MCP policy
if policy.IsMCP() {
for _, domain := range domains {
mcpHosts[domain] = true
}
} }
hosts.InsertSlice(urlutil.GetDomainsForURL(fromURL, !o.IsRuntimeFlagSet(RuntimeFlagMatchAnyIncomingPort)))
if policy.TLSDownstreamServerName != "" { if policy.TLSDownstreamServerName != "" {
tlsURL := fromURL.ResolveReference(&url.URL{Host: policy.TLSDownstreamServerName}) tlsURL := fromURL.ResolveReference(&url.URL{Host: policy.TLSDownstreamServerName})
hosts.InsertSlice(urlutil.GetDomainsForURL(tlsURL, !o.IsRuntimeFlagSet(RuntimeFlagMatchAnyIncomingPort))) tlsDomains := urlutil.GetDomainsForURL(tlsURL, !o.IsRuntimeFlagSet(RuntimeFlagMatchAnyIncomingPort))
hosts.InsertSlice(tlsDomains)
// Track if the TLS domains are associated with an MCP policy
if policy.IsMCP() {
for _, domain := range tlsDomains {
mcpHosts[domain] = true
}
}
} }
} }
} }
return hosts.Slice(), nil return hosts.Slice(), mcpHosts, nil
} }
// GetClientSecret gets the client secret. // GetClientSecret gets the client secret.

View file

@ -888,22 +888,26 @@ func TestOptions_GetAllRouteableGRPCHosts(t *testing.T) {
} }
func TestOptions_GetAllRouteableHTTPHosts(t *testing.T) { func TestOptions_GetAllRouteableHTTPHosts(t *testing.T) {
p1 := Policy{From: "https://from1.example.com"} to := WeightedURLs{{URL: url.URL{Scheme: "https", Host: "to.example.com"}}}
p1.Validate() p1 := Policy{From: "https://from1.example.com", To: to}
p2 := Policy{From: "https://from2.example.com"} assert.NoError(t, p1.Validate())
p2.Validate() p2 := Policy{From: "https://from2.example.com", To: to}
p3 := Policy{From: "https://from3.example.com", TLSDownstreamServerName: "from.example.com"} assert.NoError(t, p2.Validate())
p3.Validate() p3 := Policy{From: "https://from3.example.com", TLSDownstreamServerName: "from.example.com", To: to}
assert.NoError(t, p3.Validate())
p4 := Policy{From: "https://from4.example.com", MCP: &MCP{}, To: to}
assert.NoError(t, p4.Validate())
opts := &Options{ opts := &Options{
AuthenticateURLString: "https://authenticate.example.com", AuthenticateURLString: "https://authenticate.example.com",
AuthorizeURLString: "https://authorize.example.com", AuthorizeURLString: "https://authorize.example.com",
DataBrokerURLString: "https://databroker.example.com", DataBrokerURLString: "https://databroker.example.com",
Policies: []Policy{p1, p2, p3}, Policies: []Policy{p1, p2, p3, p4},
Services: "all", Services: "all",
} }
hosts, err := opts.GetAllRouteableHTTPHosts() hosts, mcpHosts, err := opts.GetAllRouteableHTTPHosts()
assert.NoError(t, err) assert.NoError(t, err)
assert.Empty(t, cmp.Diff(mcpHosts, map[string]bool{"from4.example.com:443": true, "from4.example.com": true}))
assert.Equal(t, []string{ assert.Equal(t, []string{
"authenticate.example.com", "authenticate.example.com",
@ -916,6 +920,8 @@ func TestOptions_GetAllRouteableHTTPHosts(t *testing.T) {
"from2.example.com:443", "from2.example.com:443",
"from3.example.com", "from3.example.com",
"from3.example.com:443", "from3.example.com:443",
"from4.example.com",
"from4.example.com:443",
}, hosts) }, hosts)
} }