mcp: add mcp_tool to ppl (#5662)

## Summary

Adds `mcp_tool` PPL criterion, that matches MCP tool names like 

```yaml
  - from: https://db.localhost.pomerium.io
    to: http://localhost:3000/mcp
    policy:
      allow:
        and:
          - email: 
              in: ["user@pomerium.com"]
          - mcp_tool:
              in: ["list_tables", "read_table", "search_records"]
    mcp: {}
```

## Related issues

Fix
https://linear.app/pomerium/issue/ENG-2393/mcp-authorize-each-incoming-request-to-an-mcp-route

## User Explanation

<!-- How would you explain this change to the user? If this
change doesn't create any user-facing changes, you can leave
this blank. If filled out, add the `docs` label -->

## Checklist

- [x] reference any related issues
- [x] updated unit tests
- [x] add appropriate label (`enhancement`, `bug`, `breaking`,
`dependencies`, `ci`)
- [x] ready for review
This commit is contained in:
Denis Mishin 2025-06-23 09:43:43 -07:00 committed by GitHub
parent 55dd6ba7d0
commit f9e7308f12
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
8 changed files with 265 additions and 0 deletions

View file

@ -4,6 +4,7 @@ package evaluator
import (
"context"
"encoding/base64"
"encoding/json"
"encoding/pem"
"fmt"
"net/http"
@ -38,6 +39,7 @@ type Request struct {
Policy *config.Policy
HTTP RequestHTTP
Session RequestSession
MCP RequestMCP
EnvoyRouteChecksum uint64
EnvoyRouteID string
}
@ -133,6 +135,12 @@ type RequestSession struct {
ID string `json:"id"`
}
// RequestMCP is the MCP field in the request.
type RequestMCP struct {
Tool string `json:"tool,omitempty"`
Method string `json:"method,omitempty"`
}
// Result is the result of evaluation.
type Result struct {
Allow RuleResult
@ -373,6 +381,7 @@ func (e *Evaluator) evaluatePolicy(ctx context.Context, req *Request) (*PolicyRe
return policyEvaluator.Evaluate(ctx, &PolicyRequest{
HTTP: req.HTTP,
MCP: req.MCP,
Session: req.Session,
IsValidClientCertificate: isValidClientCertificate,
})
@ -468,3 +477,36 @@ func carryOverJWTAssertion(dst http.Header, src map[string]string) {
dst.Add(jwtForKey, jwtFor)
}
}
// RequestMCPFromCheckRequest populates a RequestMCP from an Envoy CheckRequest proto for MCP routes.
func RequestMCPFromCheckRequest(
in *envoy_service_auth_v3.CheckRequest,
) (RequestMCP, bool) {
var mcpReq RequestMCP
body := in.GetAttributes().GetRequest().GetHttp().GetBody()
if body == "" {
return mcpReq, false
}
var jsonRPCReq struct {
Method string `json:"method"`
Params map[string]any `json:"params,omitempty"`
}
if err := json.Unmarshal([]byte(body), &jsonRPCReq); err != nil {
return mcpReq, false
}
mcpReq.Method = jsonRPCReq.Method
if jsonRPCReq.Method == "tools/call" {
if name, exists := jsonRPCReq.Params["name"]; exists {
if toolName, ok := name.(string); ok {
mcpReq.Tool = toolName
}
}
}
return mcpReq, true
}

View file

@ -235,6 +235,21 @@ func TestEvaluator(t *testing.T) {
},
},
},
{
To: config.WeightedURLs{{URL: *mustParseURL("https://to13.example.com")}},
MCP: &config.MCP{},
Policy: &config.PPLPolicy{
Policy: &parser.Policy{
Rules: []parser.Rule{{
Action: parser.ActionAllow,
And: []parser.Criterion{
{Name: "mcp_tool", Data: parser.Object{"is": parser.String("tool_name")}},
{Name: "email", Data: parser.Object{"is": parser.String("a@example.com")}},
},
}},
},
},
},
}
options := []Option{
WithAuthenticateURL("https://authn.example.com"),
@ -653,6 +668,36 @@ func TestEvaluator(t *testing.T) {
require.NoError(t, err)
assert.True(t, res.Allow.Value)
})
t.Run("mcp", func(t *testing.T) {
t.Run("allowed tool name", func(t *testing.T) {
res, err := eval(t, options, []proto.Message{
&session.Session{
Id: "session1",
UserId: "user1",
},
&user.User{
Id: "user1",
Email: "a@example.com",
},
}, &Request{
Policy: policies[12],
Session: RequestSession{
ID: "session1",
},
HTTP: RequestHTTP{
Method: http.MethodGet,
URL: "https://from.example.com",
},
MCP: RequestMCP{
Tool: "tool_name",
Method: "tools/call",
},
})
require.NoError(t, err)
assert.True(t, res.Allow.Value)
assert.False(t, res.Deny.Value)
})
})
}
func TestEvaluator_EvaluateInternal(t *testing.T) {

View file

@ -21,6 +21,7 @@ import (
// PolicyRequest is the input to policy evaluation.
type PolicyRequest struct {
HTTP RequestHTTP `json:"http"`
MCP RequestMCP `json:"mcp"`
Session RequestSession `json:"session"`
IsValidClientCertificate bool `json:"is_valid_client_certificate"`
}

View file

@ -24,6 +24,7 @@ import (
"github.com/pomerium/pomerium/pkg/grpc/session"
"github.com/pomerium/pomerium/pkg/grpc/user"
"github.com/pomerium/pomerium/pkg/grpcutil"
"github.com/pomerium/pomerium/pkg/policy/criteria"
"github.com/pomerium/pomerium/pkg/storage"
"github.com/pomerium/pomerium/pkg/telemetry/requestid"
)
@ -64,6 +65,21 @@ func (a *Authorize) Check(ctx context.Context, in *envoy_service_auth_v3.CheckRe
u, _ = a.getDataBrokerUser(ctx, s.GetUserId()) // ignore any missing user error
}
// For MCP routes that only require authentication (not full authorization),
// if we have a valid session, allow the request without running policy evaluation
// as policy for MCP may contain check for i.e. tool calls that are not relevant at this stage.
if a.currentConfig.Load().Options.IsRuntimeFlagSet(config.RuntimeFlagMCP) {
if req.Policy.IsMCPServer() && strings.HasPrefix(hreq.URL.Path, mcp.DefaultPrefix) {
if s != nil {
return a.requireLoginResponse(ctx, in, req)
}
a.logAuthorizeCheck(ctx, req, &evaluator.Result{
Allow: evaluator.NewRuleResult(true, criteria.ReasonMCPHandshake),
}, s, u)
return a.okResponse(make(http.Header)), nil
}
}
res, err := state.evaluator.Evaluate(ctx, req)
if err != nil {
log.Ctx(ctx).Error().Err(err).Str("request-id", requestID).Msg("error during OPA evaluation")
@ -189,6 +205,25 @@ func (a *Authorize) getEvaluatorRequestFromCheckRequest(
EnvoyRouteID: envoyconfig.ExtAuthzContextExtensionsRouteID(attrs.GetContextExtensions()),
}
req.Policy = a.getMatchingPolicy(req.EnvoyRouteID)
if req.Policy.IsMCPServer() {
var ok bool
req.MCP, ok = evaluator.RequestMCPFromCheckRequest(in)
if !ok {
log.Ctx(ctx).Error().
Str("request-id", requestid.FromContext(ctx)).
Str("route_id", req.EnvoyRouteID).
Msg("failed to parse MCP request from check request")
} else {
log.Ctx(ctx).Debug().
Str("request-id", requestid.FromContext(ctx)).
Str("route_id", req.EnvoyRouteID).
Str("mcp_tool", req.MCP.Tool).
Str("mcp_method", req.MCP.Method).
Msg("authorize request from check request")
}
}
return req, nil
}

View file

@ -32,6 +32,7 @@ type (
Input struct {
HTTP InputHTTP `json:"http"`
Session InputSession `json:"session"`
MCP InputMCP `json:"mcp"`
IsValidClientCertificate bool `json:"is_valid_client_certificate"`
}
InputHTTP struct {
@ -43,6 +44,10 @@ type (
InputSession struct {
ID string `json:"id"`
}
InputMCP struct {
Tool string `json:"tool"`
Method string `json:"method"`
}
ClientCertificateInfo struct {
Presented bool `json:"presented"`
Leaf string `json:"leaf"`

View file

@ -0,0 +1,60 @@
package criteria
import (
"github.com/open-policy-agent/opa/ast"
"github.com/pomerium/pomerium/pkg/policy/generator"
"github.com/pomerium/pomerium/pkg/policy/parser"
)
type mcpToolCriterion struct {
g *Generator
}
func (mcpToolCriterion) DataType() CriterionDataType {
return CriterionDataTypeStringMatcher
}
func (mcpToolCriterion) Name() string {
return "mcp_tool"
}
func (c mcpToolCriterion) GenerateRule(_ string, data parser.Value) (*ast.Rule, []*ast.Rule, error) {
r1 := c.g.NewRule(c.Name())
r1.Head.Value = NewCriterionTerm(true, ReasonMCPNotAToolCall)
r1.Body = ast.Body{
ast.MustParseExpr(`input.mcp.method != "tools/call"`),
}
r2 := &ast.Rule{
Head: generator.NewHead("", NewCriterionTerm(true, ReasonMCPToolOK)),
Body: ast.Body{
ast.MustParseExpr(`input.mcp.method == "tools/call"`),
},
}
toolRef := ast.RefTerm(ast.VarTerm("input"), ast.VarTerm("mcp"), ast.VarTerm("tool"))
err := matchString(&r2.Body, toolRef, data)
if err != nil {
return nil, nil, err
}
r1.Else = r2
r3 := &ast.Rule{
Head: generator.NewHead("", NewCriterionTerm(false, ReasonMCPToolUnauthorized)),
Body: ast.Body{
ast.NewExpr(ast.BooleanTerm(true)),
},
}
r2.Else = r3
return r1, nil, nil
}
// MCPTool returns a Criterion which matches an MCP tool name.
func MCPTool(generator *Generator) Criterion {
return mcpToolCriterion{g: generator}
}
func init() {
Register(MCPTool)
}

View file

@ -0,0 +1,71 @@
package criteria
import (
"testing"
"github.com/stretchr/testify/require"
"github.com/pomerium/pomerium/pkg/grpc/databroker"
)
func TestMCPTool(t *testing.T) {
t.Run("ok", func(t *testing.T) {
res, err := evaluate(t, `
allow:
and:
- mcp_tool:
is: list_tables
`, []*databroker.Record{}, Input{MCP: InputMCP{Tool: "list_tables", Method: "tools/call"}})
require.NoError(t, err)
require.Equal(t, A{true, A{ReasonMCPToolOK}, M{}}, res["allow"])
require.Equal(t, A{false, A{}}, res["deny"])
})
t.Run("unauthorized", func(t *testing.T) {
res, err := evaluate(t, `
allow:
and:
- mcp_tool:
is: list_tables
`, []*databroker.Record{}, Input{MCP: InputMCP{Tool: "read_table", Method: "tools/call"}})
require.NoError(t, err)
require.Equal(t, A{false, A{ReasonMCPToolUnauthorized}, M{}}, res["allow"])
require.Equal(t, A{false, A{}}, res["deny"])
})
t.Run("in list", func(t *testing.T) {
res, err := evaluate(t, `
allow:
and:
- mcp_tool:
in: ["list_tables", "read_table"]
`, []*databroker.Record{}, Input{MCP: InputMCP{Tool: "list_tables", Method: "tools/call"}})
require.NoError(t, err)
require.Equal(t, A{true, A{ReasonMCPToolOK}, M{}}, res["allow"])
require.Equal(t, A{false, A{}}, res["deny"])
})
t.Run("not in list", func(t *testing.T) {
res, err := evaluate(t, `
allow:
and:
- mcp_tool:
in: ["list_tables", "read_table"]
`, []*databroker.Record{}, Input{MCP: InputMCP{Tool: "delete_table", Method: "tools/call"}})
require.NoError(t, err)
require.Equal(t, A{false, A{ReasonMCPToolUnauthorized}, M{}}, res["allow"])
require.Equal(t, A{false, A{}}, res["deny"])
})
t.Run("non-tools/call method should pass", func(t *testing.T) {
res, err := evaluate(t, `
allow:
and:
- mcp_tool:
is: list_tables
`, []*databroker.Record{}, Input{MCP: InputMCP{Method: "some/other_method"}})
require.NoError(t, err)
require.Equal(t, A{true, A{ReasonMCPNotAToolCall}, M{}}, res["allow"])
require.Equal(t, A{false, A{}}, res["deny"])
})
}

View file

@ -28,6 +28,12 @@ const (
ReasonHTTPPathOK = "http-path-ok"
ReasonHTTPPathUnauthorized = "http-path-unauthorized"
ReasonInvalidClientCertificate = "invalid-client-certificate"
ReasonMCPHandshake = "mcp-handshake" // part of MCP protocol handshake
ReasonMCPMethodOK = "mcp-method-ok"
ReasonMCPMethodUnauthorized = "mcp-method-unauthorized"
ReasonMCPToolOK = "mcp-tool-ok"
ReasonMCPNotAToolCall = "mcp-not-a-tool-call" // MCP method is not a tool call
ReasonMCPToolUnauthorized = "mcp-tool-unauthorized"
ReasonNonCORSRequest = "non-cors-request"
ReasonNonPomeriumRoute = "non-pomerium-route"
ReasonPomeriumRoute = "pomerium-route"