mcp: add global runtime flag (#5604)

## Summary

Adds global runtime flag to enable/disable MCP support. (off by
default).

```yaml
runtime_flags:
  mcp: true
```

## Related issues

Fix:
https://linear.app/pomerium/issue/ENG-2367/place-mcp-support-behind-a-runtime-flag

## User Explanation

<!-- How would you explain this change to the user? If this
change doesn't create any user-facing changes, you can leave
this blank. If filled out, add the `docs` label -->

## Checklist

- [x] reference any related issues
- [ ] updated unit tests
- [ ] add appropriate label (`enhancement`, `bug`, `breaking`,
`dependencies`, `ci`)
- [ ] ready for review
This commit is contained in:
Denis Mishin 2025-05-02 16:33:42 -04:00 committed by GitHub
parent d1559eaa86
commit 1a19ccabd8
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
10 changed files with 100 additions and 31 deletions

View file

@ -21,6 +21,7 @@ import (
"github.com/pomerium/pomerium/authorize/checkrequest" "github.com/pomerium/pomerium/authorize/checkrequest"
"github.com/pomerium/pomerium/authorize/evaluator" "github.com/pomerium/pomerium/authorize/evaluator"
"github.com/pomerium/pomerium/config"
"github.com/pomerium/pomerium/internal/httputil" "github.com/pomerium/pomerium/internal/httputil"
"github.com/pomerium/pomerium/internal/log" "github.com/pomerium/pomerium/internal/log"
"github.com/pomerium/pomerium/internal/urlutil" "github.com/pomerium/pomerium/internal/urlutil"
@ -358,8 +359,10 @@ func (a *Authorize) userInfoEndpointURL(in *envoy_service_auth_v3.CheckRequest)
} }
func (a *Authorize) shouldRedirect(in *envoy_service_auth_v3.CheckRequest, request *evaluator.Request) bool { func (a *Authorize) shouldRedirect(in *envoy_service_auth_v3.CheckRequest, request *evaluator.Request) bool {
if request.Policy.IsMCPServer() { if a.currentConfig.Load().Options.IsRuntimeFlagSet(config.RuntimeFlagMCP) {
return false if request.Policy.IsMCPServer() {
return false
}
} }
requestHeaders := in.GetAttributes().GetRequest().GetHttp().GetHeaders() requestHeaders := in.GetAttributes().GetRequest().GetHttp().GetHeaders()

View file

@ -113,7 +113,8 @@ func TestAuthorize_handleResult(t *testing.T) {
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, 495, int(res.GetDeniedResponse().GetStatus().GetCode())) assert.Equal(t, 495, int(res.GetDeniedResponse().GetStatus().GetCode()))
}) })
t.Run("mcp-route-unauthenticated", func(t *testing.T) { t.Run("mcp-route-unauthenticated, mcp flag is on", func(t *testing.T) {
opt.RuntimeFlags[config.RuntimeFlagMCP] = true
res, err := a.handleResult(context.Background(), res, err := a.handleResult(context.Background(),
&envoy_service_auth_v3.CheckRequest{}, &envoy_service_auth_v3.CheckRequest{},
&evaluator.Request{ &evaluator.Request{
@ -125,6 +126,19 @@ func TestAuthorize_handleResult(t *testing.T) {
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, 401, int(res.GetDeniedResponse().GetStatus().GetCode())) assert.Equal(t, 401, int(res.GetDeniedResponse().GetStatus().GetCode()))
}) })
t.Run("mcp-route-unauthenticated, mcp flag is off", func(t *testing.T) {
opt.RuntimeFlags[config.RuntimeFlagMCP] = false
res, err := a.handleResult(context.Background(),
&envoy_service_auth_v3.CheckRequest{},
&evaluator.Request{
Policy: &config.Policy{MCP: &config.MCP{}},
},
&evaluator.Result{
Allow: evaluator.NewRuleResult(false, criteria.ReasonUserUnauthenticated),
})
assert.NoError(t, err)
assert.Equal(t, 302, int(res.GetDeniedResponse().GetStatus().GetCode()))
})
} }
func TestAuthorize_okResponse(t *testing.T) { func TestAuthorize_okResponse(t *testing.T) {

View file

@ -123,13 +123,15 @@ func (a *Authorize) maybeGetSessionFromRequest(
hreq *http.Request, hreq *http.Request,
policy *config.Policy, policy *config.Policy,
) (*session.Session, error) { ) (*session.Session, error) {
if policy.IsMCPServer() { if a.currentConfig.Load().Options.IsRuntimeFlagSet(config.RuntimeFlagMCP) {
s, err := a.getMCPSession(ctx, hreq) if policy.IsMCPServer() {
if err != nil { s, err := a.getMCPSession(ctx, hreq)
log.Ctx(ctx).Error().Err(err).Msg("error getting mcp session") if err != nil {
return nil, err log.Ctx(ctx).Error().Err(err).Msg("error getting mcp session")
return nil, err
}
return s, nil
} }
return s, nil
} }
// attempt to create a session from an incoming idp token // attempt to create a session from an incoming idp token

View file

@ -59,13 +59,17 @@ func newAuthorizeStateFromConfig(
previousEvaluator = previousState.evaluator previousEvaluator = previousState.evaluator
} }
mcp, err := mcp.New(ctx, mcp.DefaultPrefix, cfg) var evaluatorOptions []evaluator.Option
if err != nil { if cfg.Options.IsRuntimeFlagSet(config.RuntimeFlagMCP) {
return nil, fmt.Errorf("authorize: failed to create mcp handler: %w", err) mcp, err := mcp.New(ctx, mcp.DefaultPrefix, cfg)
if err != nil {
return nil, fmt.Errorf("authorize: failed to create mcp handler: %w", err)
}
state.mcp = mcp
evaluatorOptions = append(evaluatorOptions, evaluator.WithMCPAccessTokenProvider(mcp))
} }
state.mcp = mcp
state.evaluator, err = newPolicyEvaluator(ctx, cfg.Options, store, previousEvaluator, evaluator.WithMCPAccessTokenProvider(mcp)) state.evaluator, err = newPolicyEvaluator(ctx, cfg.Options, store, previousEvaluator, evaluatorOptions...)
if err != nil { if err != nil {
return nil, fmt.Errorf("authorize: failed to update policy with options: %w", err) return nil, fmt.Errorf("authorize: failed to update policy with options: %w", err)
} }

View file

@ -72,7 +72,7 @@ func (b *Builder) buildPomeriumHTTPRoutes(
) )
// Only add oauth-authorization-server route if there's an MCP policy for this host // Only add oauth-authorization-server route if there's an MCP policy for this host
if isMCPHost { if options.IsRuntimeFlagSet(config.RuntimeFlagMCP) && isMCPHost {
routes = append(routes, b.buildControlPlanePathRoute(options, "/.well-known/oauth-authorization-server")) routes = append(routes, b.buildControlPlanePathRoute(options, "/.well-known/oauth-authorization-server"))
} }
} }

View file

@ -2331,7 +2331,9 @@ func Test_buildPomeriumHTTPRoutesWithMCP(t *testing.T) {
MCP: &config.MCP{}, // This marks the policy as an MCP policy MCP: &config.MCP{}, // This marks the policy as an MCP policy
}, },
}, },
RuntimeFlags: config.DefaultRuntimeFlags(),
} }
options.RuntimeFlags[config.RuntimeFlagMCP] = true
routes, err := b.buildPomeriumHTTPRoutes(options, "example.com", true) routes, err := b.buildPomeriumHTTPRoutes(options, "example.com", true)
require.NoError(t, err) require.NoError(t, err)
@ -2347,4 +2349,38 @@ func Test_buildPomeriumHTTPRoutesWithMCP(t *testing.T) {
`+routeString("path", "/.well-known/oauth-authorization-server")+` `+routeString("path", "/.well-known/oauth-authorization-server")+`
]`, routes) ]`, routes)
}) })
t.Run("with MCP policy, runtime flag is off", func(t *testing.T) {
b := &Builder{filemgr: filemgr.NewManager()}
options := &config.Options{
Services: "all",
AuthenticateURLString: "https://authenticate.example.com",
Policies: []config.Policy{
{
From: "https://example.com",
To: mustParseWeightedURLs(t, "https://to.example.com"),
},
{
From: "https://mcp.example.com",
To: mustParseWeightedURLs(t, "https://mcp-backend.example.com"),
MCP: &config.MCP{}, // This marks the policy as an MCP policy
},
},
RuntimeFlags: config.DefaultRuntimeFlags(),
}
options.RuntimeFlags[config.RuntimeFlagMCP] = false
routes, err := b.buildPomeriumHTTPRoutes(options, "example.com", true)
require.NoError(t, err)
// Verify the expected route structures
testutil.AssertProtoJSONEqual(t, `[
`+routeString("path", "/ping")+`,
`+routeString("path", "/healthz")+`,
`+routeString("path", "/.pomerium")+`,
`+routeString("prefix", "/.pomerium/")+`,
`+routeString("path", "/.well-known/pomerium")+`,
`+routeString("prefix", "/.well-known/pomerium/")+`
]`, routes)
})
} }

View file

@ -29,6 +29,9 @@ var (
// RuntimeFlagAuthorizeUseSyncedData enables synced data for querying the databroker for // RuntimeFlagAuthorizeUseSyncedData enables synced data for querying the databroker for
// certain types of data. // certain types of data.
RuntimeFlagAuthorizeUseSyncedData = runtimeFlag("authorize_use_synced_data", true) RuntimeFlagAuthorizeUseSyncedData = runtimeFlag("authorize_use_synced_data", true)
// RuntimeFlagMCP enables the MCP services for the authorize service
RuntimeFlagMCP = runtimeFlag("mcp", false)
) )
// RuntimeFlag is a runtime flag that can flip on/off certain features // RuntimeFlag is a runtime flag that can flip on/off certain features

View file

@ -81,9 +81,11 @@ func (srv *Server) mountCommonEndpoints(root *mux.Router, cfg *config.Config) er
root.Path("/.well-known/pomerium/jwks.json").Methods(http.MethodGet).Handler(traceHandler(handlers.JWKSHandler(signingKey))) root.Path("/.well-known/pomerium/jwks.json").Methods(http.MethodGet).Handler(traceHandler(handlers.JWKSHandler(signingKey)))
root.Path(urlutil.HPKEPublicKeyPath).Methods(http.MethodGet).Handler(traceHandler(hpke_handlers.HPKEPublicKeyHandler(hpkePublicKey))) root.Path(urlutil.HPKEPublicKeyPath).Methods(http.MethodGet).Handler(traceHandler(hpke_handlers.HPKEPublicKeyHandler(hpkePublicKey)))
root.Path("/.well-known/oauth-authorization-server"). if cfg.Options.IsRuntimeFlagSet(config.RuntimeFlagMCP) {
Methods(http.MethodGet, http.MethodOptions). root.Path("/.well-known/oauth-authorization-server").
Handler(mcp.AuthorizationServerMetadataHandler(mcp.DefaultPrefix)) Methods(http.MethodGet, http.MethodOptions).
Handler(mcp.AuthorizationServerMetadataHandler(mcp.DefaultPrefix))
}
return nil return nil
} }

View file

@ -24,8 +24,10 @@ func (p *Proxy) registerDashboardHandlers(r *mux.Router, opts *config.Options) *
h := httputil.DashboardSubrouter(r) h := httputil.DashboardSubrouter(r)
h.Use(middleware.SetHeaders(httputil.HeadersContentSecurityPolicy)) h.Use(middleware.SetHeaders(httputil.HeadersContentSecurityPolicy))
// model context protocol if opts.IsRuntimeFlagSet(config.RuntimeFlagMCP) {
h.PathPrefix("/mcp").Handler(p.mcp.Load().HandlerFunc()) // model context protocol
h.PathPrefix("/mcp").Handler(p.mcp.Load().HandlerFunc())
}
// special pomerium endpoints for users to view their session // special pomerium endpoints for users to view their session
h.Path("/").Handler(httputil.HandlerFunc(p.userInfo)).Methods(http.MethodGet) h.Path("/").Handler(httputil.HandlerFunc(p.userInfo)).Methods(http.MethodGet)

View file

@ -76,18 +76,19 @@ func New(ctx context.Context, cfg *config.Config) (*Proxy, error) {
return nil, err return nil, err
} }
mcp, err := mcp.New(ctx, mcp.DefaultPrefix, cfg)
if err != nil {
return nil, fmt.Errorf("proxy: failed to create mcp handler: %w", err)
}
p := &Proxy{ p := &Proxy{
tracerProvider: tracerProvider, tracerProvider: tracerProvider,
state: atomicutil.NewValue(state), state: atomicutil.NewValue(state),
currentConfig: atomicutil.NewValue(&config.Config{Options: config.NewDefaultOptions()}), currentConfig: atomicutil.NewValue(&config.Config{Options: config.NewDefaultOptions()}),
currentRouter: atomicutil.NewValue(httputil.NewRouter()), currentRouter: atomicutil.NewValue(httputil.NewRouter()),
logoProvider: portal.NewLogoProvider(), logoProvider: portal.NewLogoProvider(),
mcp: atomicutil.NewValue(mcp), }
if cfg.Options.IsRuntimeFlagSet(config.RuntimeFlagMCP) {
mcp, err := mcp.New(ctx, mcp.DefaultPrefix, cfg)
if err != nil {
return nil, fmt.Errorf("proxy: failed to create mcp handler: %w", err)
}
p.mcp = atomicutil.NewValue(mcp)
} }
p.OnConfigChange(ctx, cfg) p.OnConfigChange(ctx, cfg)
p.webauthn = webauthn.New(p.getWebauthnState) p.webauthn = webauthn.New(p.getWebauthnState)
@ -110,11 +111,13 @@ func (p *Proxy) OnConfigChange(ctx context.Context, cfg *config.Config) {
return return
} }
mcp, err := mcp.New(ctx, mcp.DefaultPrefix, cfg) if cfg.Options.IsRuntimeFlagSet(config.RuntimeFlagMCP) {
if err != nil { mcp, err := mcp.New(ctx, mcp.DefaultPrefix, cfg)
log.Ctx(ctx).Error().Err(err).Msg("proxy: failed to update proxy state from configuration settings") if err != nil {
} else { log.Ctx(ctx).Error().Err(err).Msg("proxy: failed to update proxy state from configuration settings")
p.mcp.Store(mcp) } else {
p.mcp.Store(mcp)
}
} }
p.currentConfig.Store(cfg) p.currentConfig.Store(cfg)