From 1a19ccabd88cf50c13c182e23188acf697413499 Mon Sep 17 00:00:00 2001 From: Denis Mishin Date: Fri, 2 May 2025 16:33:42 -0400 Subject: [PATCH] mcp: add global runtime flag (#5604) ## Summary Adds global runtime flag to enable/disable MCP support. (off by default). ```yaml runtime_flags: mcp: true ``` ## Related issues Fix: https://linear.app/pomerium/issue/ENG-2367/place-mcp-support-behind-a-runtime-flag ## User Explanation ## Checklist - [x] reference any related issues - [ ] updated unit tests - [ ] add appropriate label (`enhancement`, `bug`, `breaking`, `dependencies`, `ci`) - [ ] ready for review --- authorize/check_response.go | 7 ++++-- authorize/check_response_test.go | 16 +++++++++++++- authorize/grpc.go | 14 ++++++------ authorize/state.go | 14 +++++++----- config/envoyconfig/routes.go | 2 +- config/envoyconfig/routes_test.go | 36 +++++++++++++++++++++++++++++++ config/runtime_flags.go | 3 +++ internal/controlplane/http.go | 8 ++++--- proxy/handlers.go | 6 ++++-- proxy/proxy.go | 25 +++++++++++---------- 10 files changed, 100 insertions(+), 31 deletions(-) diff --git a/authorize/check_response.go b/authorize/check_response.go index 751c0b925..4665eecae 100644 --- a/authorize/check_response.go +++ b/authorize/check_response.go @@ -21,6 +21,7 @@ import ( "github.com/pomerium/pomerium/authorize/checkrequest" "github.com/pomerium/pomerium/authorize/evaluator" + "github.com/pomerium/pomerium/config" "github.com/pomerium/pomerium/internal/httputil" "github.com/pomerium/pomerium/internal/log" "github.com/pomerium/pomerium/internal/urlutil" @@ -358,8 +359,10 @@ func (a *Authorize) userInfoEndpointURL(in *envoy_service_auth_v3.CheckRequest) } func (a *Authorize) shouldRedirect(in *envoy_service_auth_v3.CheckRequest, request *evaluator.Request) bool { - if request.Policy.IsMCPServer() { - return false + if a.currentConfig.Load().Options.IsRuntimeFlagSet(config.RuntimeFlagMCP) { + if request.Policy.IsMCPServer() { + return false + } } requestHeaders := in.GetAttributes().GetRequest().GetHttp().GetHeaders() diff --git a/authorize/check_response_test.go b/authorize/check_response_test.go index e98c13548..cb2b9987a 100644 --- a/authorize/check_response_test.go +++ b/authorize/check_response_test.go @@ -113,7 +113,8 @@ func TestAuthorize_handleResult(t *testing.T) { assert.NoError(t, err) assert.Equal(t, 495, int(res.GetDeniedResponse().GetStatus().GetCode())) }) - t.Run("mcp-route-unauthenticated", func(t *testing.T) { + t.Run("mcp-route-unauthenticated, mcp flag is on", func(t *testing.T) { + opt.RuntimeFlags[config.RuntimeFlagMCP] = true res, err := a.handleResult(context.Background(), &envoy_service_auth_v3.CheckRequest{}, &evaluator.Request{ @@ -125,6 +126,19 @@ func TestAuthorize_handleResult(t *testing.T) { assert.NoError(t, err) assert.Equal(t, 401, int(res.GetDeniedResponse().GetStatus().GetCode())) }) + t.Run("mcp-route-unauthenticated, mcp flag is off", func(t *testing.T) { + opt.RuntimeFlags[config.RuntimeFlagMCP] = false + res, err := a.handleResult(context.Background(), + &envoy_service_auth_v3.CheckRequest{}, + &evaluator.Request{ + Policy: &config.Policy{MCP: &config.MCP{}}, + }, + &evaluator.Result{ + Allow: evaluator.NewRuleResult(false, criteria.ReasonUserUnauthenticated), + }) + assert.NoError(t, err) + assert.Equal(t, 302, int(res.GetDeniedResponse().GetStatus().GetCode())) + }) } func TestAuthorize_okResponse(t *testing.T) { diff --git a/authorize/grpc.go b/authorize/grpc.go index f716ceec7..792370fa9 100644 --- a/authorize/grpc.go +++ b/authorize/grpc.go @@ -123,13 +123,15 @@ func (a *Authorize) maybeGetSessionFromRequest( hreq *http.Request, policy *config.Policy, ) (*session.Session, error) { - if policy.IsMCPServer() { - s, err := a.getMCPSession(ctx, hreq) - if err != nil { - log.Ctx(ctx).Error().Err(err).Msg("error getting mcp session") - return nil, err + if a.currentConfig.Load().Options.IsRuntimeFlagSet(config.RuntimeFlagMCP) { + if policy.IsMCPServer() { + s, err := a.getMCPSession(ctx, hreq) + if err != nil { + log.Ctx(ctx).Error().Err(err).Msg("error getting mcp session") + return nil, err + } + return s, nil } - return s, nil } // attempt to create a session from an incoming idp token diff --git a/authorize/state.go b/authorize/state.go index e6545a9ce..a67f9da39 100644 --- a/authorize/state.go +++ b/authorize/state.go @@ -59,13 +59,17 @@ func newAuthorizeStateFromConfig( previousEvaluator = previousState.evaluator } - mcp, err := mcp.New(ctx, mcp.DefaultPrefix, cfg) - if err != nil { - return nil, fmt.Errorf("authorize: failed to create mcp handler: %w", err) + var evaluatorOptions []evaluator.Option + if cfg.Options.IsRuntimeFlagSet(config.RuntimeFlagMCP) { + mcp, err := mcp.New(ctx, mcp.DefaultPrefix, cfg) + if err != nil { + return nil, fmt.Errorf("authorize: failed to create mcp handler: %w", err) + } + state.mcp = mcp + evaluatorOptions = append(evaluatorOptions, evaluator.WithMCPAccessTokenProvider(mcp)) } - state.mcp = mcp - state.evaluator, err = newPolicyEvaluator(ctx, cfg.Options, store, previousEvaluator, evaluator.WithMCPAccessTokenProvider(mcp)) + state.evaluator, err = newPolicyEvaluator(ctx, cfg.Options, store, previousEvaluator, evaluatorOptions...) if err != nil { return nil, fmt.Errorf("authorize: failed to update policy with options: %w", err) } diff --git a/config/envoyconfig/routes.go b/config/envoyconfig/routes.go index f96783930..394394708 100644 --- a/config/envoyconfig/routes.go +++ b/config/envoyconfig/routes.go @@ -72,7 +72,7 @@ func (b *Builder) buildPomeriumHTTPRoutes( ) // Only add oauth-authorization-server route if there's an MCP policy for this host - if isMCPHost { + if options.IsRuntimeFlagSet(config.RuntimeFlagMCP) && isMCPHost { routes = append(routes, b.buildControlPlanePathRoute(options, "/.well-known/oauth-authorization-server")) } } diff --git a/config/envoyconfig/routes_test.go b/config/envoyconfig/routes_test.go index aedaf6638..75a6f968d 100644 --- a/config/envoyconfig/routes_test.go +++ b/config/envoyconfig/routes_test.go @@ -2331,7 +2331,9 @@ func Test_buildPomeriumHTTPRoutesWithMCP(t *testing.T) { MCP: &config.MCP{}, // This marks the policy as an MCP policy }, }, + RuntimeFlags: config.DefaultRuntimeFlags(), } + options.RuntimeFlags[config.RuntimeFlagMCP] = true routes, err := b.buildPomeriumHTTPRoutes(options, "example.com", true) require.NoError(t, err) @@ -2347,4 +2349,38 @@ func Test_buildPomeriumHTTPRoutesWithMCP(t *testing.T) { `+routeString("path", "/.well-known/oauth-authorization-server")+` ]`, routes) }) + + t.Run("with MCP policy, runtime flag is off", func(t *testing.T) { + b := &Builder{filemgr: filemgr.NewManager()} + options := &config.Options{ + Services: "all", + AuthenticateURLString: "https://authenticate.example.com", + Policies: []config.Policy{ + { + From: "https://example.com", + To: mustParseWeightedURLs(t, "https://to.example.com"), + }, + { + From: "https://mcp.example.com", + To: mustParseWeightedURLs(t, "https://mcp-backend.example.com"), + MCP: &config.MCP{}, // This marks the policy as an MCP policy + }, + }, + RuntimeFlags: config.DefaultRuntimeFlags(), + } + options.RuntimeFlags[config.RuntimeFlagMCP] = false + + routes, err := b.buildPomeriumHTTPRoutes(options, "example.com", true) + require.NoError(t, err) + + // Verify the expected route structures + testutil.AssertProtoJSONEqual(t, `[ + `+routeString("path", "/ping")+`, + `+routeString("path", "/healthz")+`, + `+routeString("path", "/.pomerium")+`, + `+routeString("prefix", "/.pomerium/")+`, + `+routeString("path", "/.well-known/pomerium")+`, + `+routeString("prefix", "/.well-known/pomerium/")+` + ]`, routes) + }) } diff --git a/config/runtime_flags.go b/config/runtime_flags.go index b4331d79a..930b44a4d 100644 --- a/config/runtime_flags.go +++ b/config/runtime_flags.go @@ -29,6 +29,9 @@ var ( // RuntimeFlagAuthorizeUseSyncedData enables synced data for querying the databroker for // certain types of data. RuntimeFlagAuthorizeUseSyncedData = runtimeFlag("authorize_use_synced_data", true) + + // RuntimeFlagMCP enables the MCP services for the authorize service + RuntimeFlagMCP = runtimeFlag("mcp", false) ) // RuntimeFlag is a runtime flag that can flip on/off certain features diff --git a/internal/controlplane/http.go b/internal/controlplane/http.go index ab8c9e77d..4ccd6261d 100644 --- a/internal/controlplane/http.go +++ b/internal/controlplane/http.go @@ -81,9 +81,11 @@ func (srv *Server) mountCommonEndpoints(root *mux.Router, cfg *config.Config) er root.Path("/.well-known/pomerium/jwks.json").Methods(http.MethodGet).Handler(traceHandler(handlers.JWKSHandler(signingKey))) root.Path(urlutil.HPKEPublicKeyPath).Methods(http.MethodGet).Handler(traceHandler(hpke_handlers.HPKEPublicKeyHandler(hpkePublicKey))) - root.Path("/.well-known/oauth-authorization-server"). - Methods(http.MethodGet, http.MethodOptions). - Handler(mcp.AuthorizationServerMetadataHandler(mcp.DefaultPrefix)) + if cfg.Options.IsRuntimeFlagSet(config.RuntimeFlagMCP) { + root.Path("/.well-known/oauth-authorization-server"). + Methods(http.MethodGet, http.MethodOptions). + Handler(mcp.AuthorizationServerMetadataHandler(mcp.DefaultPrefix)) + } return nil } diff --git a/proxy/handlers.go b/proxy/handlers.go index 5856ba1b7..6292845a0 100644 --- a/proxy/handlers.go +++ b/proxy/handlers.go @@ -24,8 +24,10 @@ func (p *Proxy) registerDashboardHandlers(r *mux.Router, opts *config.Options) * h := httputil.DashboardSubrouter(r) h.Use(middleware.SetHeaders(httputil.HeadersContentSecurityPolicy)) - // model context protocol - h.PathPrefix("/mcp").Handler(p.mcp.Load().HandlerFunc()) + if opts.IsRuntimeFlagSet(config.RuntimeFlagMCP) { + // model context protocol + h.PathPrefix("/mcp").Handler(p.mcp.Load().HandlerFunc()) + } // special pomerium endpoints for users to view their session h.Path("/").Handler(httputil.HandlerFunc(p.userInfo)).Methods(http.MethodGet) diff --git a/proxy/proxy.go b/proxy/proxy.go index 935ec79d8..faa88ff56 100644 --- a/proxy/proxy.go +++ b/proxy/proxy.go @@ -76,18 +76,19 @@ func New(ctx context.Context, cfg *config.Config) (*Proxy, error) { return nil, err } - mcp, err := mcp.New(ctx, mcp.DefaultPrefix, cfg) - if err != nil { - return nil, fmt.Errorf("proxy: failed to create mcp handler: %w", err) - } - p := &Proxy{ tracerProvider: tracerProvider, state: atomicutil.NewValue(state), currentConfig: atomicutil.NewValue(&config.Config{Options: config.NewDefaultOptions()}), currentRouter: atomicutil.NewValue(httputil.NewRouter()), logoProvider: portal.NewLogoProvider(), - mcp: atomicutil.NewValue(mcp), + } + if cfg.Options.IsRuntimeFlagSet(config.RuntimeFlagMCP) { + mcp, err := mcp.New(ctx, mcp.DefaultPrefix, cfg) + if err != nil { + return nil, fmt.Errorf("proxy: failed to create mcp handler: %w", err) + } + p.mcp = atomicutil.NewValue(mcp) } p.OnConfigChange(ctx, cfg) p.webauthn = webauthn.New(p.getWebauthnState) @@ -110,11 +111,13 @@ func (p *Proxy) OnConfigChange(ctx context.Context, cfg *config.Config) { return } - mcp, err := mcp.New(ctx, mcp.DefaultPrefix, cfg) - if err != nil { - log.Ctx(ctx).Error().Err(err).Msg("proxy: failed to update proxy state from configuration settings") - } else { - p.mcp.Store(mcp) + if cfg.Options.IsRuntimeFlagSet(config.RuntimeFlagMCP) { + mcp, err := mcp.New(ctx, mcp.DefaultPrefix, cfg) + if err != nil { + log.Ctx(ctx).Error().Err(err).Msg("proxy: failed to update proxy state from configuration settings") + } else { + p.mcp.Store(mcp) + } } p.currentConfig.Store(cfg)