mcp: add global runtime flag (#5604)

## Summary

Adds global runtime flag to enable/disable MCP support. (off by
default).

```yaml
runtime_flags:
  mcp: true
```

## Related issues

Fix:
https://linear.app/pomerium/issue/ENG-2367/place-mcp-support-behind-a-runtime-flag

## User Explanation

<!-- How would you explain this change to the user? If this
change doesn't create any user-facing changes, you can leave
this blank. If filled out, add the `docs` label -->

## Checklist

- [x] reference any related issues
- [ ] updated unit tests
- [ ] add appropriate label (`enhancement`, `bug`, `breaking`,
`dependencies`, `ci`)
- [ ] ready for review
This commit is contained in:
Denis Mishin 2025-05-02 16:33:42 -04:00 committed by GitHub
parent d1559eaa86
commit 1a19ccabd8
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
10 changed files with 100 additions and 31 deletions

View file

@ -21,6 +21,7 @@ import (
"github.com/pomerium/pomerium/authorize/checkrequest"
"github.com/pomerium/pomerium/authorize/evaluator"
"github.com/pomerium/pomerium/config"
"github.com/pomerium/pomerium/internal/httputil"
"github.com/pomerium/pomerium/internal/log"
"github.com/pomerium/pomerium/internal/urlutil"
@ -358,9 +359,11 @@ func (a *Authorize) userInfoEndpointURL(in *envoy_service_auth_v3.CheckRequest)
}
func (a *Authorize) shouldRedirect(in *envoy_service_auth_v3.CheckRequest, request *evaluator.Request) bool {
if a.currentConfig.Load().Options.IsRuntimeFlagSet(config.RuntimeFlagMCP) {
if request.Policy.IsMCPServer() {
return false
}
}
requestHeaders := in.GetAttributes().GetRequest().GetHttp().GetHeaders()
if requestHeaders == nil {

View file

@ -113,7 +113,8 @@ func TestAuthorize_handleResult(t *testing.T) {
assert.NoError(t, err)
assert.Equal(t, 495, int(res.GetDeniedResponse().GetStatus().GetCode()))
})
t.Run("mcp-route-unauthenticated", func(t *testing.T) {
t.Run("mcp-route-unauthenticated, mcp flag is on", func(t *testing.T) {
opt.RuntimeFlags[config.RuntimeFlagMCP] = true
res, err := a.handleResult(context.Background(),
&envoy_service_auth_v3.CheckRequest{},
&evaluator.Request{
@ -125,6 +126,19 @@ func TestAuthorize_handleResult(t *testing.T) {
assert.NoError(t, err)
assert.Equal(t, 401, int(res.GetDeniedResponse().GetStatus().GetCode()))
})
t.Run("mcp-route-unauthenticated, mcp flag is off", func(t *testing.T) {
opt.RuntimeFlags[config.RuntimeFlagMCP] = false
res, err := a.handleResult(context.Background(),
&envoy_service_auth_v3.CheckRequest{},
&evaluator.Request{
Policy: &config.Policy{MCP: &config.MCP{}},
},
&evaluator.Result{
Allow: evaluator.NewRuleResult(false, criteria.ReasonUserUnauthenticated),
})
assert.NoError(t, err)
assert.Equal(t, 302, int(res.GetDeniedResponse().GetStatus().GetCode()))
})
}
func TestAuthorize_okResponse(t *testing.T) {

View file

@ -123,6 +123,7 @@ func (a *Authorize) maybeGetSessionFromRequest(
hreq *http.Request,
policy *config.Policy,
) (*session.Session, error) {
if a.currentConfig.Load().Options.IsRuntimeFlagSet(config.RuntimeFlagMCP) {
if policy.IsMCPServer() {
s, err := a.getMCPSession(ctx, hreq)
if err != nil {
@ -131,6 +132,7 @@ func (a *Authorize) maybeGetSessionFromRequest(
}
return s, nil
}
}
// attempt to create a session from an incoming idp token
return config.NewIncomingIDPTokenSessionCreator(

View file

@ -59,13 +59,17 @@ func newAuthorizeStateFromConfig(
previousEvaluator = previousState.evaluator
}
var evaluatorOptions []evaluator.Option
if cfg.Options.IsRuntimeFlagSet(config.RuntimeFlagMCP) {
mcp, err := mcp.New(ctx, mcp.DefaultPrefix, cfg)
if err != nil {
return nil, fmt.Errorf("authorize: failed to create mcp handler: %w", err)
}
state.mcp = mcp
evaluatorOptions = append(evaluatorOptions, evaluator.WithMCPAccessTokenProvider(mcp))
}
state.evaluator, err = newPolicyEvaluator(ctx, cfg.Options, store, previousEvaluator, evaluator.WithMCPAccessTokenProvider(mcp))
state.evaluator, err = newPolicyEvaluator(ctx, cfg.Options, store, previousEvaluator, evaluatorOptions...)
if err != nil {
return nil, fmt.Errorf("authorize: failed to update policy with options: %w", err)
}

View file

@ -72,7 +72,7 @@ func (b *Builder) buildPomeriumHTTPRoutes(
)
// Only add oauth-authorization-server route if there's an MCP policy for this host
if isMCPHost {
if options.IsRuntimeFlagSet(config.RuntimeFlagMCP) && isMCPHost {
routes = append(routes, b.buildControlPlanePathRoute(options, "/.well-known/oauth-authorization-server"))
}
}

View file

@ -2331,7 +2331,9 @@ func Test_buildPomeriumHTTPRoutesWithMCP(t *testing.T) {
MCP: &config.MCP{}, // This marks the policy as an MCP policy
},
},
RuntimeFlags: config.DefaultRuntimeFlags(),
}
options.RuntimeFlags[config.RuntimeFlagMCP] = true
routes, err := b.buildPomeriumHTTPRoutes(options, "example.com", true)
require.NoError(t, err)
@ -2347,4 +2349,38 @@ func Test_buildPomeriumHTTPRoutesWithMCP(t *testing.T) {
`+routeString("path", "/.well-known/oauth-authorization-server")+`
]`, routes)
})
t.Run("with MCP policy, runtime flag is off", func(t *testing.T) {
b := &Builder{filemgr: filemgr.NewManager()}
options := &config.Options{
Services: "all",
AuthenticateURLString: "https://authenticate.example.com",
Policies: []config.Policy{
{
From: "https://example.com",
To: mustParseWeightedURLs(t, "https://to.example.com"),
},
{
From: "https://mcp.example.com",
To: mustParseWeightedURLs(t, "https://mcp-backend.example.com"),
MCP: &config.MCP{}, // This marks the policy as an MCP policy
},
},
RuntimeFlags: config.DefaultRuntimeFlags(),
}
options.RuntimeFlags[config.RuntimeFlagMCP] = false
routes, err := b.buildPomeriumHTTPRoutes(options, "example.com", true)
require.NoError(t, err)
// Verify the expected route structures
testutil.AssertProtoJSONEqual(t, `[
`+routeString("path", "/ping")+`,
`+routeString("path", "/healthz")+`,
`+routeString("path", "/.pomerium")+`,
`+routeString("prefix", "/.pomerium/")+`,
`+routeString("path", "/.well-known/pomerium")+`,
`+routeString("prefix", "/.well-known/pomerium/")+`
]`, routes)
})
}

View file

@ -29,6 +29,9 @@ var (
// RuntimeFlagAuthorizeUseSyncedData enables synced data for querying the databroker for
// certain types of data.
RuntimeFlagAuthorizeUseSyncedData = runtimeFlag("authorize_use_synced_data", true)
// RuntimeFlagMCP enables the MCP services for the authorize service
RuntimeFlagMCP = runtimeFlag("mcp", false)
)
// RuntimeFlag is a runtime flag that can flip on/off certain features

View file

@ -81,9 +81,11 @@ func (srv *Server) mountCommonEndpoints(root *mux.Router, cfg *config.Config) er
root.Path("/.well-known/pomerium/jwks.json").Methods(http.MethodGet).Handler(traceHandler(handlers.JWKSHandler(signingKey)))
root.Path(urlutil.HPKEPublicKeyPath).Methods(http.MethodGet).Handler(traceHandler(hpke_handlers.HPKEPublicKeyHandler(hpkePublicKey)))
if cfg.Options.IsRuntimeFlagSet(config.RuntimeFlagMCP) {
root.Path("/.well-known/oauth-authorization-server").
Methods(http.MethodGet, http.MethodOptions).
Handler(mcp.AuthorizationServerMetadataHandler(mcp.DefaultPrefix))
}
return nil
}

View file

@ -24,8 +24,10 @@ func (p *Proxy) registerDashboardHandlers(r *mux.Router, opts *config.Options) *
h := httputil.DashboardSubrouter(r)
h.Use(middleware.SetHeaders(httputil.HeadersContentSecurityPolicy))
if opts.IsRuntimeFlagSet(config.RuntimeFlagMCP) {
// model context protocol
h.PathPrefix("/mcp").Handler(p.mcp.Load().HandlerFunc())
}
// special pomerium endpoints for users to view their session
h.Path("/").Handler(httputil.HandlerFunc(p.userInfo)).Methods(http.MethodGet)

View file

@ -76,18 +76,19 @@ func New(ctx context.Context, cfg *config.Config) (*Proxy, error) {
return nil, err
}
mcp, err := mcp.New(ctx, mcp.DefaultPrefix, cfg)
if err != nil {
return nil, fmt.Errorf("proxy: failed to create mcp handler: %w", err)
}
p := &Proxy{
tracerProvider: tracerProvider,
state: atomicutil.NewValue(state),
currentConfig: atomicutil.NewValue(&config.Config{Options: config.NewDefaultOptions()}),
currentRouter: atomicutil.NewValue(httputil.NewRouter()),
logoProvider: portal.NewLogoProvider(),
mcp: atomicutil.NewValue(mcp),
}
if cfg.Options.IsRuntimeFlagSet(config.RuntimeFlagMCP) {
mcp, err := mcp.New(ctx, mcp.DefaultPrefix, cfg)
if err != nil {
return nil, fmt.Errorf("proxy: failed to create mcp handler: %w", err)
}
p.mcp = atomicutil.NewValue(mcp)
}
p.OnConfigChange(ctx, cfg)
p.webauthn = webauthn.New(p.getWebauthnState)
@ -110,12 +111,14 @@ func (p *Proxy) OnConfigChange(ctx context.Context, cfg *config.Config) {
return
}
if cfg.Options.IsRuntimeFlagSet(config.RuntimeFlagMCP) {
mcp, err := mcp.New(ctx, mcp.DefaultPrefix, cfg)
if err != nil {
log.Ctx(ctx).Error().Err(err).Msg("proxy: failed to update proxy state from configuration settings")
} else {
p.mcp.Store(mcp)
}
}
p.currentConfig.Store(cfg)
if err := p.setHandlers(ctx, cfg.Options); err != nil {