mcp: split mcp into server and client for better option grouping (#5666)

This commit is contained in:
Denis Mishin 2025-06-24 10:21:32 -07:00 committed by GitHub
parent d36c48a2bc
commit db6449ecca
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
11 changed files with 1156 additions and 934 deletions

View file

@ -117,7 +117,7 @@ func TestAuthorize_handleResult(t *testing.T) {
res, err := a.handleResult(t.Context(),
&envoy_service_auth_v3.CheckRequest{},
&evaluator.Request{
Policy: &config.Policy{MCP: &config.MCP{}},
Policy: &config.Policy{MCP: &config.MCP{Server: &config.MCPServer{}}},
},
&evaluator.Result{
Allow: evaluator.NewRuleResult(false, criteria.ReasonUserUnauthenticated),
@ -130,7 +130,7 @@ func TestAuthorize_handleResult(t *testing.T) {
res, err := a.handleResult(t.Context(),
&envoy_service_auth_v3.CheckRequest{},
&evaluator.Request{
Policy: &config.Policy{MCP: &config.MCP{}},
Policy: &config.Policy{MCP: &config.MCP{Server: &config.MCPServer{}}},
},
&evaluator.Result{
Allow: evaluator.NewRuleResult(false, criteria.ReasonUserUnauthenticated),

View file

@ -104,7 +104,7 @@ func (e *headersEvaluatorEvaluation) fillMCPHeaders(ctx context.Context) (err er
}
var accessToken string
if e.request.Policy.MCP.IsUpstreamClientNeedsAccessToken() {
if e.request.Policy.IsMCPClient() {
accessToken, err = p.GetAccessTokenForSession(e.request.Session.ID, time.Now().Add(5*time.Minute))
if err != nil {
return fmt.Errorf("authorize/header-evaluator: error getting MCP access token: %w", err)
@ -113,7 +113,7 @@ func (e *headersEvaluatorEvaluation) fillMCPHeaders(ctx context.Context) (err er
return nil
}
if e.request.Policy.MCP.HasUpstreamOAuth2() {
if e.request.Policy.MCP.GetServerUpstreamOAuth2() != nil {
user := e.getUser(ctx)
accessToken, err = p.GetUpstreamOAuth2Token(ctx, e.request.HTTP.Host, user.Id)
if err != nil {

View file

@ -686,11 +686,50 @@ func (f JWTIssuerFormat) Valid() bool {
}
func MCPFromPB(src *configpb.MCP) *MCP {
if srv := src.GetServer(); srv != nil {
return &MCP{Server: mcpServerFromPB(srv)}
} else if cli := src.GetClient(); cli != nil {
return &MCP{Client: mcpClientFromPB(cli)}
}
return nil
}
func MCPToPB(src *MCP) *configpb.MCP {
if src == nil {
return nil
}
var v MCP
v.PassUpstreamAccessToken = src.GetPassUpstreamAccessToken()
if src.Client != nil {
return &configpb.MCP{
Mode: &configpb.MCP_Client{
Client: mcpClientToPB(src.Client),
},
}
}
if src.Server != nil {
return &configpb.MCP{
Mode: &configpb.MCP_Server{
Server: mcpServerToPB(src.Server),
},
}
}
return nil
}
func mcpClientFromPB(*configpb.MCPClient) *MCPClient {
return &MCPClient{}
}
func mcpClientToPB(src *MCPClient) *configpb.MCPClient {
if src == nil {
return nil
}
return &configpb.MCPClient{}
}
func mcpServerFromPB(src *configpb.MCPServer) *MCPServer {
v := MCPServer{
MaxRequestBytes: src.MaxRequestBytes,
}
if uo := src.GetUpstreamOauth2(); uo != nil {
v.UpstreamOAuth2 = &UpstreamOAuth2{
ClientID: uo.GetClientId(),
@ -699,10 +738,30 @@ func MCPFromPB(src *configpb.MCP) *MCP {
Scopes: uo.GetScopes(),
}
}
return &v
}
func mcpServerToPB(src *MCPServer) *configpb.MCPServer {
if src == nil {
return nil
}
srv := &configpb.MCPServer{
MaxRequestBytes: src.MaxRequestBytes,
}
if src.UpstreamOAuth2 != nil {
srv.UpstreamOauth2 = &configpb.UpstreamOAuth2{
ClientId: src.UpstreamOAuth2.ClientID,
ClientSecret: src.UpstreamOAuth2.ClientSecret,
Oauth2Endpoint: OAuth2EndpointToPB(src.UpstreamOAuth2.Endpoint),
Scopes: src.UpstreamOAuth2.Scopes,
}
}
return srv
}
func OAuth2EndpointFromPB(src *configpb.OAuth2Endpoint) OAuth2Endpoint {
if src == nil {
return OAuth2Endpoint{}
@ -723,34 +782,20 @@ func OAuth2EndpointFromPB(src *configpb.OAuth2Endpoint) OAuth2Endpoint {
}
}
func MCPToPB(src *MCP) *configpb.MCP {
if src == nil {
return nil
}
v := new(configpb.MCP)
v.PassUpstreamAccessToken = proto.Bool(src.PassUpstreamAccessToken)
if src.UpstreamOAuth2 != nil {
var authStyle *configpb.OAuth2AuthStyle
switch src.UpstreamOAuth2.Endpoint.AuthStyle {
case OAuth2EndpointAuthStyleInHeader:
authStyle = configpb.OAuth2AuthStyle_OAUTH2_AUTH_STYLE_IN_HEADER.Enum()
case OAuth2EndpointAuthStyleInParams:
authStyle = configpb.OAuth2AuthStyle_OAUTH2_AUTH_STYLE_IN_PARAMS.Enum()
default:
authStyle = nil
}
v.UpstreamOauth2 = &configpb.UpstreamOAuth2{
ClientId: src.UpstreamOAuth2.ClientID,
ClientSecret: src.UpstreamOAuth2.ClientSecret,
Oauth2Endpoint: &configpb.OAuth2Endpoint{
AuthUrl: src.UpstreamOAuth2.Endpoint.AuthURL,
TokenUrl: src.UpstreamOAuth2.Endpoint.TokenURL,
AuthStyle: authStyle,
},
Scopes: src.UpstreamOAuth2.Scopes,
}
func OAuth2EndpointToPB(src OAuth2Endpoint) *configpb.OAuth2Endpoint {
var authStyle *configpb.OAuth2AuthStyle
switch src.AuthStyle {
case OAuth2EndpointAuthStyleInHeader:
authStyle = configpb.OAuth2AuthStyle_OAUTH2_AUTH_STYLE_IN_HEADER.Enum()
case OAuth2EndpointAuthStyleInParams:
authStyle = configpb.OAuth2AuthStyle_OAUTH2_AUTH_STYLE_IN_PARAMS.Enum()
default:
authStyle = configpb.OAuth2AuthStyle_OAUTH2_AUTH_STYLE_UNSPECIFIED.Enum()
}
return v
return &configpb.OAuth2Endpoint{
AuthUrl: src.AuthURL,
TokenUrl: src.TokenURL,
AuthStyle: authStyle,
}
}

View file

@ -328,7 +328,7 @@ func (b *Builder) buildRouteForPolicyAndMatch(
extAuthzOpts := MakeExtAuthzContextExtensions(false, routeID, routeChecksum)
extAuthzCfg := PerFilterConfigExtAuthzContextExtensions(extAuthzOpts)
if policy.IsMCPServer() {
extAuthzCfg = PerFilterConfigExtAuthzContextExtensionsWithBody(policy.MCP.GetMaxRequestBytes(), extAuthzOpts)
extAuthzCfg = PerFilterConfigExtAuthzContextExtensionsWithBody(policy.MCP.Server.GetMaxRequestBytes(), extAuthzOpts)
}
route.TypedPerFilterConfig = map[string]*anypb.Any{
PerFilterConfigExtAuthzName: extAuthzCfg,

View file

@ -2361,7 +2361,7 @@ func Test_buildPomeriumHTTPRoutesWithMCP(t *testing.T) {
{
From: "https://mcp.example.com",
To: mustParseWeightedURLs(t, "https://mcp-backend.example.com"),
MCP: &config.MCP{}, // This marks the policy as an MCP policy
MCP: &config.MCP{Server: &config.MCPServer{}}, // This marks the policy as an MCP policy
},
},
RuntimeFlags: config.DefaultRuntimeFlags(),
@ -2396,7 +2396,7 @@ func Test_buildPomeriumHTTPRoutesWithMCP(t *testing.T) {
{
From: "https://mcp.example.com",
To: mustParseWeightedURLs(t, "https://mcp-backend.example.com"),
MCP: &config.MCP{}, // This marks the policy as an MCP policy
MCP: &config.MCP{Server: &config.MCPServer{}}, // This marks the policy as an MCP policy
},
},
RuntimeFlags: config.DefaultRuntimeFlags(),

View file

@ -912,7 +912,7 @@ func TestOptions_GetAllRouteableHTTPHosts(t *testing.T) {
assert.NoError(t, p2.Validate())
p3 := Policy{From: "https://from3.example.com", TLSDownstreamServerName: "from.example.com", To: to}
assert.NoError(t, p3.Validate())
p4 := Policy{From: "https://from4.example.com", MCP: &MCP{}, To: to}
p4 := Policy{From: "https://from4.example.com", MCP: &MCP{Server: &MCPServer{}}, To: to}
assert.NoError(t, p4.Validate())
opts := &Options{
@ -1587,15 +1587,22 @@ func TestRoute_FromToProto(t *testing.T) {
for i := range pb.LoadBalancingWeights {
pb.LoadBalancingWeights[i] = mathrand.Uint32N(10000) + 1
}
pb.Mcp.UpstreamOauth2.Oauth2Endpoint.AuthStyle = nil
case 1:
pb.Redirect, err = redirectGen.Gen()
require.NoError(t, err)
pb.Mcp.UpstreamOauth2.Oauth2Endpoint.AuthStyle = configpb.OAuth2AuthStyle_OAUTH2_AUTH_STYLE_IN_PARAMS.Enum()
pb.Mcp = &configpb.MCP{
Mode: &configpb.MCP_Client{
Client: &configpb.MCPClient{},
},
}
case 2:
pb.Response, err = responseGen.Gen()
require.NoError(t, err)
pb.Mcp.UpstreamOauth2.Oauth2Endpoint.AuthStyle = configpb.OAuth2AuthStyle_OAUTH2_AUTH_STYLE_IN_HEADER.Enum()
pb.Mcp = &configpb.MCP{
Mode: &configpb.MCP_Server{
Server: &configpb.MCPServer{},
},
}
}
return pb
}

View file

@ -212,15 +212,23 @@ type Policy struct {
// MCP is an experimental support for Model Context Protocol upstreams configuration
type MCP struct {
// exactly one of server or client should be specified
Server *MCPServer `mapstructure:"server" yaml:"server,omitempty" json:"server,omitempty"`
Client *MCPClient `mapstructure:"client" yaml:"client,omitempty" json:"client,omitempty"`
}
// MCPServer holds configuration for an MCP server route
type MCPServer struct {
// UpstreamOAuth2 specifies that before the request reaches the MCP upstream server, it should acquire an OAuth2 token
UpstreamOAuth2 *UpstreamOAuth2 `mapstructure:"upstream_oauth2" yaml:"upstream_oauth2,omitempty" json:"upstream_oauth2,omitempty"`
// PassUpstreamAccessToken indicates whether to pass the upstream access token in the `Authorization: Bearer` header that is suitable for calling the MCP routes
PassUpstreamAccessToken bool `mapstructure:"pass_upstream_access_token" yaml:"pass_upstream_access_token,omitempty" json:"pass_upstream_access_token,omitempty"`
// MaxRequestBytes is the maximum request body size in bytes that can be sent to the MCP server
MaxRequestBytes *uint32 `mapstructure:"max_request_bytes" yaml:"max_request_bytes,omitempty" json:"max_request_bytes,omitempty"`
}
func (p *MCP) GetMaxRequestBytes() uint32 {
// MCPClient holds configuration for an MCP client route
type MCPClient struct{}
func (p *MCPServer) GetMaxRequestBytes() uint32 {
if p == nil || p.MaxRequestBytes == nil {
return 4 * 1024
}
@ -228,13 +236,11 @@ func (p *MCP) GetMaxRequestBytes() uint32 {
}
// HasUpstreamOAuth2 checks if the route is for the MCP Server and if it has an upstream OAuth2 configuration
func (p *MCP) HasUpstreamOAuth2() bool {
return p != nil && p.UpstreamOAuth2 != nil
}
// IsUpstreamClientNeedsAccessToken checks if the route is for the MCP Client and if it needs to pass the upstream access token
func (p *MCP) IsUpstreamClientNeedsAccessToken() bool {
return p != nil && p.PassUpstreamAccessToken
func (p *MCP) GetServerUpstreamOAuth2() *UpstreamOAuth2 {
if p != nil && p.Server != nil {
return p.Server.UpstreamOAuth2
}
return nil
}
type UpstreamOAuth2 struct {
@ -756,6 +762,9 @@ func (p *Policy) Validate() error {
return fmt.Errorf("config: depends_on is limited to 5 additional redirect hosts, got %v", p.DependsOn)
}
if p.MCP != nil && p.MCP.Server == nil && p.MCP.Client == nil {
return fmt.Errorf("config: mcp must have either server or client set")
}
return nil
}
@ -893,12 +902,12 @@ func (p *Policy) IsForKubernetes() bool {
// IsMCPServer returns true if the route is for the Model Context Protocol upstream server.
func (p *Policy) IsMCPServer() bool {
return p != nil && p.MCP != nil && !p.MCP.PassUpstreamAccessToken
return p != nil && p.MCP != nil && p.MCP.Server != nil
}
// IsMCPClient returns true if the route is for the Model Context Protocol client application upstream.
func (p *Policy) IsMCPClient() bool {
return p != nil && p.MCP != nil && p.MCP.PassUpstreamAccessToken
return p != nil && p.MCP != nil && p.MCP.Client != nil
}
// IsTCP returns true if the route is for TCP.

View file

@ -138,21 +138,21 @@ func BuildHostInfo(cfg *config.Config, prefix string) (map[string]ServerHostInfo
Host: host,
URL: policy.GetFrom(),
}
if policy.MCP.UpstreamOAuth2 != nil {
if oa := policy.MCP.GetServerUpstreamOAuth2(); oa != nil {
v.Config = &oauth2.Config{
ClientID: policy.MCP.UpstreamOAuth2.ClientID,
ClientSecret: policy.MCP.UpstreamOAuth2.ClientSecret,
ClientID: oa.ClientID,
ClientSecret: oa.ClientSecret,
Endpoint: oauth2.Endpoint{
AuthURL: policy.MCP.UpstreamOAuth2.Endpoint.AuthURL,
TokenURL: policy.MCP.UpstreamOAuth2.Endpoint.TokenURL,
AuthStyle: authStyleEnum(policy.MCP.UpstreamOAuth2.Endpoint.AuthStyle),
AuthURL: oa.Endpoint.AuthURL,
TokenURL: oa.Endpoint.TokenURL,
AuthStyle: authStyleEnum(oa.Endpoint.AuthStyle),
},
RedirectURL: (&url.URL{
Scheme: "https",
Host: host,
Path: path.Join(prefix, oauthCallbackEndpoint),
}).String(),
Scopes: policy.MCP.UpstreamOAuth2.Scopes,
Scopes: oa.Scopes,
}
}
servers[host] = v

View file

@ -25,13 +25,13 @@ func TestBuildOAuthConfig(t *testing.T) {
Description: "description-1",
LogoURL: "https://logo.example.com",
From: "https://mcp1.example.com",
MCP: &config.MCP{},
MCP: &config.MCP{Server: &config.MCPServer{}},
},
{
Name: "mcp-2",
From: "https://mcp2.example.com",
MCP: &config.MCP{
UpstreamOAuth2: &config.UpstreamOAuth2{
Server: &config.MCPServer{UpstreamOAuth2: &config.UpstreamOAuth2{
ClientID: "client_id",
ClientSecret: "client_secret",
Endpoint: config.OAuth2Endpoint{
@ -39,22 +39,18 @@ func TestBuildOAuthConfig(t *testing.T) {
TokenURL: "https://auth.example.com/token",
AuthStyle: config.OAuth2EndpointAuthStyleInParams,
},
},
}},
},
},
{
Name: "mcp-client-1",
From: "https://client1.example.com",
MCP: &config.MCP{
PassUpstreamAccessToken: true,
},
MCP: &config.MCP{Client: &config.MCPClient{}},
},
{
Name: "mcp-client-2",
From: "https://client2.example.com",
MCP: &config.MCP{
PassUpstreamAccessToken: true,
},
MCP: &config.MCP{Client: &config.MCPClient{}},
},
},
},
@ -105,14 +101,12 @@ func TestHostInfo_IsMCPClientForHost(t *testing.T) {
{
Name: "mcp-server",
From: "https://server.example.com",
MCP: &config.MCP{},
MCP: &config.MCP{Server: &config.MCPServer{}},
},
{
Name: "mcp-client",
From: "https://client.example.com",
MCP: &config.MCP{
PassUpstreamAccessToken: true,
},
MCP: &config.MCP{Client: &config.MCPClient{}},
},
},
},

File diff suppressed because it is too large Load diff

View file

@ -17,7 +17,7 @@ message Config {
message RouteRewriteHeader {
string header = 1;
oneof matcher {
oneof matcher {
string prefix = 3;
}
string value = 2;
@ -101,7 +101,7 @@ message Route {
repeated string allowed_users = 4 [deprecated = true];
// repeated string allowed_groups = 5 [ deprecated = true ];
repeated string allowed_domains = 6 [deprecated = true];
repeated string allowed_domains = 6 [deprecated = true];
map<string, google.protobuf.ListValue> allowed_idp_claims = 32 [deprecated = true];
string prefix = 7;
@ -170,15 +170,24 @@ message Route {
optional StringList idp_access_token_allowed_audiences = 69;
bool show_error_details = 59;
optional MCP mcp = 72;
optional MCP mcp = 72;
optional CircuitBreakerThresholds circuit_breaker_thresholds = 73;
}
message MCP {
optional UpstreamOAuth2 upstream_oauth2 = 1;
optional bool pass_upstream_access_token = 2;
oneof mode {
MCPServer server = 1;
MCPClient client = 2;
}
}
message MCPServer {
optional UpstreamOAuth2 upstream_oauth2 = 1;
optional uint32 max_request_bytes = 2;
}
message MCPClient {}
message UpstreamOAuth2 {
string client_id = 1;
string client_secret = 2;
@ -236,18 +245,18 @@ message Settings {
repeated string values = 1;
}
optional string installation_id = 71;
optional string log_level = 3;
optional StringList access_log_fields = 114;
optional StringList authorize_log_fields = 115;
optional string proxy_log_level = 4;
optional string shared_secret = 5;
optional string services = 6;
optional string address = 7;
optional bool insecure_server = 8;
optional string dns_lookup_family = 60;
repeated Certificate certificates = 9;
optional string http_redirect_addr = 10;
optional string installation_id = 71;
optional string log_level = 3;
optional StringList access_log_fields = 114;
optional StringList authorize_log_fields = 115;
optional string proxy_log_level = 4;
optional string shared_secret = 5;
optional string services = 6;
optional string address = 7;
optional bool insecure_server = 8;
optional string dns_lookup_family = 60;
repeated Certificate certificates = 9;
optional string http_redirect_addr = 10;
optional google.protobuf.Duration timeout_read = 11;
optional google.protobuf.Duration timeout_write = 12;
optional google.protobuf.Duration timeout_idle = 13;
@ -259,7 +268,7 @@ message Settings {
optional string cookie_secret = 17;
optional string cookie_domain = 18;
// optional bool cookie_secure = 19;
optional bool cookie_http_only = 20;
optional bool cookie_http_only = 20;
optional google.protobuf.Duration cookie_expire = 21;
optional string cookie_same_site = 113;
optional string idp_client_id = 22;
@ -280,10 +289,10 @@ message Settings {
optional string signing_key = 36;
map<string, string> set_response_headers = 69;
// repeated string jwt_claims_headers = 37;
map<string, string> jwt_claims_headers = 63;
optional IssuerFormat jwt_issuer_format = 139;
repeated string jwt_groups_filter = 119;
optional BearerTokenFormat bearer_token_format = 138;
map<string, string> jwt_claims_headers = 63;
optional IssuerFormat jwt_issuer_format = 139;
repeated string jwt_groups_filter = 119;
optional BearerTokenFormat bearer_token_format = 138;
optional google.protobuf.Duration default_upstream_timeout = 39;
optional string metrics_address = 40;
optional string metrics_basic_auth = 64;
@ -304,17 +313,17 @@ message Settings {
optional google.protobuf.Duration otel_exporter_otlp_traces_timeout = 133;
optional google.protobuf.Duration otel_bsp_schedule_delay = 134;
optional int32 otel_bsp_max_export_batch_size = 135;
reserved 41 to 45, 98; // legacy tracing fields
optional string grpc_address = 46;
optional bool grpc_insecure = 47;
reserved 41 to 45, 98; // legacy tracing fields
optional string grpc_address = 46;
optional bool grpc_insecure = 47;
optional google.protobuf.Duration grpc_client_timeout = 99;
reserved 100; // grpc_client_dns_roundrobin
reserved 100; // grpc_client_dns_roundrobin
// optional string forward_auth_url = 50;
repeated string databroker_service_urls = 52;
optional string databroker_internal_service_url = 84;
optional string databroker_storage_type = 101;
optional string databroker_storage_connection_string = 102;
reserved 106; // databroker_storage_tls_skip_verify
reserved 106; // databroker_storage_tls_skip_verify
optional DownstreamMtlsSettings downstream_mtls = 116;
// optional string client_ca = 53;
// optional string client_crl = 74;