mcp: pass access token to the upstream (#5593)

This commit is contained in:
Denis Mishin 2025-04-29 12:13:18 -04:00 committed by GitHub
parent b9e3a5d301
commit 5b024a8ada
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
15 changed files with 774 additions and 719 deletions

View file

@ -16,7 +16,6 @@ import (
"github.com/pomerium/pomerium/config"
"github.com/pomerium/pomerium/internal/atomicutil"
"github.com/pomerium/pomerium/internal/log"
"github.com/pomerium/pomerium/internal/mcp"
"github.com/pomerium/pomerium/internal/telemetry/metrics"
"github.com/pomerium/pomerium/pkg/cryptutil"
"github.com/pomerium/pomerium/pkg/grpc/databroker"
@ -29,7 +28,6 @@ type Authorize struct {
store *store.Store
currentConfig *atomicutil.Value[*config.Config]
accessTracker *AccessTracker
mcp *atomicutil.Value[*mcp.Handler]
tracerProvider oteltrace.TracerProvider
tracer oteltrace.Tracer
@ -40,17 +38,11 @@ func New(ctx context.Context, cfg *config.Config) (*Authorize, error) {
tracerProvider := trace.NewTracerProvider(ctx, "Authorize")
tracer := tracerProvider.Tracer(trace.PomeriumCoreTracer)
mcp, err := mcp.New(ctx, mcp.DefaultPrefix, cfg)
if err != nil {
return nil, fmt.Errorf("authorize: failed to create mcp handler: %w", err)
}
a := &Authorize{
currentConfig: atomicutil.NewValue(cfg),
store: store.New(),
tracerProvider: tracerProvider,
tracer: tracer,
mcp: atomicutil.NewValue(mcp),
}
a.accessTracker = NewAccessTracker(a, accessTrackerMaxSize, accessTrackerDebouncePeriod)
@ -93,6 +85,7 @@ func validateOptions(o *config.Options) error {
func newPolicyEvaluator(
ctx context.Context,
opts *config.Options, store *store.Store, previous *evaluator.Evaluator,
evaluatorOpts ...evaluator.Option,
) (*evaluator.Evaluator, error) {
metrics.AddPolicyCountCallback("pomerium-authorize", func() int64 {
return int64(opts.NumPolicies())
@ -136,7 +129,7 @@ func newPolicyEvaluator(
}
allPolicies := slices.Collect(opts.GetAllPolicies())
return evaluator.New(ctx, store, previous,
evaluatorOpts = append([]evaluator.Option{
evaluator.WithPolicies(allPolicies),
evaluator.WithClientCA(clientCA),
evaluator.WithAddDefaultClientCertificateRule(addDefaultClientCertificateRule),
@ -148,7 +141,8 @@ func newPolicyEvaluator(
evaluator.WithJWTClaimsHeaders(opts.JWTClaimsHeaders),
evaluator.WithJWTGroupsFilter(opts.JWTGroupsFilter),
evaluator.WithDefaultJWTIssuerFormat(opts.JWTIssuerFormat),
)
}, evaluatorOpts...)
return evaluator.New(ctx, store, previous, evaluatorOpts...)
}
// OnConfigChange updates internal structures based on config.Options
@ -160,11 +154,4 @@ func (a *Authorize) OnConfigChange(ctx context.Context, cfg *config.Config) {
} else {
a.state.Store(newState)
}
mcp, err := mcp.New(ctx, mcp.DefaultPrefix, cfg)
if err != nil {
log.Ctx(ctx).Error().Err(err).Msg("proxy: failed to update authorize state from configuration settings")
} else {
a.mcp.Store(mcp)
}
}

View file

@ -99,7 +99,7 @@ func (a *Authorize) handleResultDenied(
case invalidClientCertReason(reasons):
denyStatusCode = httputil.StatusInvalidClientCertificate
denyStatusText = httputil.DetailsText(httputil.StatusInvalidClientCertificate)
case request.Policy.IsMCP():
case request.Policy.IsMCPServer():
denyStatusCode = http.StatusUnauthorized
denyStatusText = httputil.DetailsText(http.StatusUnauthorized)
}
@ -358,7 +358,7 @@ func (a *Authorize) userInfoEndpointURL(in *envoy_service_auth_v3.CheckRequest)
}
func (a *Authorize) shouldRedirect(in *envoy_service_auth_v3.CheckRequest, request *evaluator.Request) bool {
if request.Policy.IsMCP() {
if request.Policy.IsMCPServer() {
return false
}

View file

@ -1,6 +1,8 @@
package evaluator
import (
"time"
"github.com/pomerium/pomerium/config"
"github.com/pomerium/pomerium/internal/hashutil"
)
@ -17,6 +19,7 @@ type evaluatorConfig struct {
JWTClaimsHeaders config.JWTClaimHeaders
JWTGroupsFilter config.JWTGroupsFilter
DefaultJWTIssuerFormat config.JWTIssuerFormat
MCPAccessTokenProvider func(string, time.Time) (string, error) `hash:"-"`
}
// cacheKey() returns a hash over the configuration, except for the policies.
@ -113,3 +116,10 @@ func WithDefaultJWTIssuerFormat(format config.JWTIssuerFormat) Option {
cfg.DefaultJWTIssuerFormat = format
}
}
// WithMCPAccessTokenProvider sets the MCP access token in the config.
func WithMCPAccessTokenProvider(fn func(sessionID string, expires time.Time) (string, error)) Option {
return func(cfg *evaluatorConfig) {
cfg.MCPAccessTokenProvider = fn
}
}

View file

@ -383,6 +383,7 @@ func updateStore(store *store.Store, cfg *evaluatorConfig) error {
store.UpdateDefaultJWTIssuerFormat(cfg.DefaultJWTIssuerFormat)
store.UpdateRoutePolicies(cfg.Policies)
store.UpdateSigningKey(jwk)
store.UpdateMCPAccessTokenProvider(cfg.MCPAccessTokenProvider)
return nil
}

View file

@ -93,6 +93,27 @@ func (e *headersEvaluatorEvaluation) fillJWTClaimHeaders(ctx context.Context) er
return nil
}
func (e *headersEvaluatorEvaluation) fillMCPHeaders() error {
if e.request == nil ||
e.request.Policy == nil ||
e.request.Policy.MCP == nil ||
!e.request.Policy.MCP.PassUpstreamAccessToken {
return nil
}
if e.request.Session.ID == "" {
return nil
}
accessToken, err := e.evaluator.store.GetMCPAccessTokenProvider()(e.request.Session.ID, e.now.Add(time.Hour))
if err != nil {
return fmt.Errorf("authorize/header-evaluator: error getting MCP access token: %w", err)
}
e.response.Headers.Set("Authorization", "Bearer "+accessToken)
return nil
}
func (e *headersEvaluatorEvaluation) fillKubernetesHeaders(ctx context.Context) {
if e.request.Policy == nil {
return
@ -177,6 +198,10 @@ func (e *headersEvaluatorEvaluation) fillHeaders(ctx context.Context) error {
if err := e.fillJWTClaimHeaders(ctx); err != nil {
return err
}
err := e.fillMCPHeaders()
if err != nil {
return err
}
e.fillKubernetesHeaders(ctx)
e.fillGoogleCloudServerlessHeaders(ctx)
e.fillRoutingKeyHeaders()

View file

@ -123,7 +123,7 @@ func (a *Authorize) maybeGetSessionFromRequest(
hreq *http.Request,
policy *config.Policy,
) (*session.Session, error) {
if policy.IsMCP() {
if policy.IsMCPServer() {
s, err := a.getMCPSession(ctx, hreq)
if err != nil {
log.Ctx(ctx).Error().Err(err).Msg("error getting mcp session")
@ -165,8 +165,8 @@ func (a *Authorize) getMCPSession(
}
accessToken := auth[len(prefix):]
sessionID, ok := a.mcp.Load().GetSessionIDFromAccessToken(ctx, accessToken)
if !ok {
sessionID, err := a.state.Load().mcp.GetSessionIDFromAccessToken(accessToken)
if err != nil {
return nil, fmt.Errorf("no session found for access token: %w", sessions.ErrNoSessionFound)
}

View file

@ -20,6 +20,7 @@ import (
"google.golang.org/protobuf/types/known/timestamppb"
"github.com/pomerium/pomerium/config"
"github.com/pomerium/pomerium/internal/atomicutil"
"github.com/pomerium/pomerium/internal/log"
"github.com/pomerium/pomerium/pkg/grpc/databroker"
"github.com/pomerium/pomerium/pkg/storage"
@ -35,6 +36,8 @@ type Store struct {
jwtGroupsFilter atomic.Pointer[config.JWTGroupsFilter]
defaultJWTIssuerFormat atomic.Pointer[config.JWTIssuerFormat]
signingKey atomic.Pointer[jose.JSONWebKey]
mcpAccessTokenProvider atomicutil.Value[func(string, time.Time) (string, error)]
}
// New creates a new Store.
@ -78,6 +81,15 @@ func (s *Store) GetSigningKey() *jose.JSONWebKey {
return s.signingKey.Load()
}
func (s *Store) GetMCPAccessTokenProvider() func(string, time.Time) (string, error) {
if f := s.mcpAccessTokenProvider.Load(); f != nil {
return f
}
return func(string, time.Time) (string, error) {
return "", fmt.Errorf("no mcp access token provider")
}
}
// UpdateGoogleCloudServerlessAuthenticationServiceAccount updates the google cloud serverless authentication
// service account in the store.
func (s *Store) UpdateGoogleCloudServerlessAuthenticationServiceAccount(serviceAccount string) {
@ -115,6 +127,11 @@ func (s *Store) UpdateSigningKey(signingKey *jose.JSONWebKey) {
s.signingKey.Store(signingKey)
}
func (s *Store) UpdateMCPAccessTokenProvider(mcpAccessTokenProvider func(string, time.Time) (string, error)) {
// This isn't used by the Rego code, so we don't need to write it to the opastorage.Store instance.
s.mcpAccessTokenProvider.Store(mcpAccessTokenProvider)
}
func (s *Store) write(rawPath string, value any) {
ctx := context.TODO()
err := opastorage.Txn(ctx, s.Store, opastorage.WriteParams, func(txn opastorage.Transaction) error {

View file

@ -14,6 +14,7 @@ import (
"github.com/pomerium/pomerium/authorize/internal/store"
"github.com/pomerium/pomerium/config"
"github.com/pomerium/pomerium/internal/authenticateflow"
"github.com/pomerium/pomerium/internal/mcp"
"github.com/pomerium/pomerium/pkg/grpc"
"github.com/pomerium/pomerium/pkg/grpc/databroker"
"github.com/pomerium/pomerium/pkg/grpc/session"
@ -36,6 +37,7 @@ type authorizeState struct {
sessionStore *config.SessionStore
authenticateFlow authenticateFlow
syncQueriers map[string]storage.Querier
mcp *mcp.Handler
}
func newAuthorizeStateFromConfig(
@ -57,7 +59,13 @@ func newAuthorizeStateFromConfig(
previousEvaluator = previousState.evaluator
}
state.evaluator, err = newPolicyEvaluator(ctx, cfg.Options, store, previousEvaluator)
mcp, err := mcp.New(ctx, mcp.DefaultPrefix, cfg)
if err != nil {
return nil, fmt.Errorf("authorize: failed to create mcp handler: %w", err)
}
state.mcp = mcp
state.evaluator, err = newPolicyEvaluator(ctx, cfg.Options, store, previousEvaluator, evaluator.WithMCPAccessTokenProvider(mcp.CreateAccessTokenForSession))
if err != nil {
return nil, fmt.Errorf("authorize: failed to update policy with options: %w", err)
}

View file

@ -671,6 +671,7 @@ func MCPFromPB(src *configpb.MCP) *MCP {
return nil
}
var v MCP
v.PassUpstreamAccessToken = src.GetPassUpstreamAccessToken()
if uo := src.GetUpstreamOauth2(); uo != nil {
v.UpstreamOAuth2 = &UpstreamOAuth2{
ClientID: uo.GetClientId(),
@ -708,6 +709,7 @@ func MCPToPB(src *MCP) *configpb.MCP {
return nil
}
v := new(configpb.MCP)
v.PassUpstreamAccessToken = proto.Bool(src.PassUpstreamAccessToken)
if src.UpstreamOAuth2 != nil {
var authStyle *configpb.OAuth2AuthStyle
switch src.UpstreamOAuth2.Endpoint.AuthStyle {

View file

@ -1309,7 +1309,7 @@ func (o *Options) GetAllRouteableHTTPHosts() ([]string, map[string]bool, error)
hosts.InsertSlice(domains)
// Track if the domains are associated with an MCP policy
if policy.IsMCP() {
if policy.IsMCPServer() {
for _, domain := range domains {
mcpHosts[domain] = true
}
@ -1321,7 +1321,7 @@ func (o *Options) GetAllRouteableHTTPHosts() ([]string, map[string]bool, error)
hosts.InsertSlice(tlsDomains)
// Track if the TLS domains are associated with an MCP policy
if policy.IsMCP() {
if policy.IsMCPServer() {
for _, domain := range tlsDomains {
mcpHosts[domain] = true
}

View file

@ -211,6 +211,8 @@ type Policy struct {
type MCP struct {
// UpstreamOAuth2 specifies that before the request reaches the MCP upstream server, it should acquire an OAuth2 token
UpstreamOAuth2 *UpstreamOAuth2 `mapstructure:"upstream_oauth2" yaml:"upstream_oauth2,omitempty" json:"upstream_oauth2,omitempty"`
// PassUpstreamAccessToken indicates whether to pass the upstream access token in the `Authorization: Bearer` header that is suitable for calling the MCP routes
PassUpstreamAccessToken bool `mapstructure:"pass_upstream_access_token" yaml:"pass_upstream_access_token,omitempty" json:"pass_upstream_access_token,omitempty"`
}
type UpstreamOAuth2 struct {
@ -859,9 +861,9 @@ func (p *Policy) IsForKubernetes() bool {
return p.KubernetesServiceAccountTokenFile != "" || p.KubernetesServiceAccountToken != ""
}
// IsMCP returns true if the route is for the Model Context Protocol upstream server.
func (p *Policy) IsMCP() bool {
return p != nil && p.MCP != nil
// IsMCPServer returns true if the route is for the Model Context Protocol upstream server.
func (p *Policy) IsMCPServer() bool {
return p != nil && p.MCP != nil && !p.MCP.PassUpstreamAccessToken
}
// IsTCP returns true if the route is for TCP.

View file

@ -1,7 +1,6 @@
package mcp
import (
"context"
"encoding/json"
"net/http"
"time"
@ -88,7 +87,7 @@ func (srv *Handler) handleAuthorizationCodeToken(w http.ResponseWriter, r *http.
return
}
accessToken, err := CreateAccessToken(session, srv.cipher)
accessToken, err := srv.CreateAccessTokenForSession(session.Id, session.ExpiresAt.AsTime())
if err != nil {
http.Error(w, "internal error", http.StatusInternalServerError)
return
@ -117,12 +116,3 @@ func (srv *Handler) handleAuthorizationCodeToken(w http.ResponseWriter, r *http.
w.WriteHeader(http.StatusOK)
_, _ = w.Write(data)
}
func (srv *Handler) GetSessionIDFromAccessToken(ctx context.Context, accessToken string) (string, bool) {
sessionID, err := DecryptAccessToken(accessToken, srv.cipher)
if err != nil {
log.Ctx(ctx).Error().Err(err).Msg("failed to decrypt access token")
return "", false
}
return sessionID, true
}

View file

@ -1,12 +1,10 @@
package mcp
import (
"crypto/cipher"
"fmt"
"time"
"github.com/pomerium/pomerium/internal/oauth21"
"github.com/pomerium/pomerium/pkg/grpc/session"
)
func CheckPKCE(
@ -30,13 +28,13 @@ func CheckPKCE(
}
// CreateAuthorizationCode creates an access token based on the session
func CreateAccessToken(src *session.Session, cipher cipher.AEAD) (string, error) {
return CreateCode(CodeTypeAccess, src.Id, src.ExpiresAt.AsTime(), "", cipher)
func (srv *Handler) CreateAccessTokenForSession(id string, expiresAt time.Time) (string, error) {
return CreateCode(CodeTypeAccess, id, expiresAt, "", srv.cipher)
}
// DecryptAuthorizationCode decrypts the authorization code and returns the underlying session ID
func DecryptAccessToken(accessToken string, cipher cipher.AEAD) (string, error) {
code, err := DecryptCode(CodeTypeAccess, accessToken, cipher, "", time.Now())
func (srv *Handler) GetSessionIDFromAccessToken(accessToken string) (string, error) {
code, err := DecryptCode(CodeTypeAccess, accessToken, srv.cipher, "", time.Now())
if err != nil {
return "", err
}

File diff suppressed because it is too large Load diff

View file

@ -153,6 +153,7 @@ message Route {
message MCP {
optional UpstreamOAuth2 upstream_oauth2 = 1;
optional bool pass_upstream_access_token = 2;
}
message UpstreamOAuth2 {