mirror of
https://github.com/pomerium/pomerium.git
synced 2025-07-31 15:29:48 +02:00
mcp: pass access token to the upstream (#5593)
This commit is contained in:
parent
b9e3a5d301
commit
5b024a8ada
15 changed files with 774 additions and 719 deletions
|
@ -16,7 +16,6 @@ import (
|
||||||
"github.com/pomerium/pomerium/config"
|
"github.com/pomerium/pomerium/config"
|
||||||
"github.com/pomerium/pomerium/internal/atomicutil"
|
"github.com/pomerium/pomerium/internal/atomicutil"
|
||||||
"github.com/pomerium/pomerium/internal/log"
|
"github.com/pomerium/pomerium/internal/log"
|
||||||
"github.com/pomerium/pomerium/internal/mcp"
|
|
||||||
"github.com/pomerium/pomerium/internal/telemetry/metrics"
|
"github.com/pomerium/pomerium/internal/telemetry/metrics"
|
||||||
"github.com/pomerium/pomerium/pkg/cryptutil"
|
"github.com/pomerium/pomerium/pkg/cryptutil"
|
||||||
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
||||||
|
@ -29,7 +28,6 @@ type Authorize struct {
|
||||||
store *store.Store
|
store *store.Store
|
||||||
currentConfig *atomicutil.Value[*config.Config]
|
currentConfig *atomicutil.Value[*config.Config]
|
||||||
accessTracker *AccessTracker
|
accessTracker *AccessTracker
|
||||||
mcp *atomicutil.Value[*mcp.Handler]
|
|
||||||
|
|
||||||
tracerProvider oteltrace.TracerProvider
|
tracerProvider oteltrace.TracerProvider
|
||||||
tracer oteltrace.Tracer
|
tracer oteltrace.Tracer
|
||||||
|
@ -40,17 +38,11 @@ func New(ctx context.Context, cfg *config.Config) (*Authorize, error) {
|
||||||
tracerProvider := trace.NewTracerProvider(ctx, "Authorize")
|
tracerProvider := trace.NewTracerProvider(ctx, "Authorize")
|
||||||
tracer := tracerProvider.Tracer(trace.PomeriumCoreTracer)
|
tracer := tracerProvider.Tracer(trace.PomeriumCoreTracer)
|
||||||
|
|
||||||
mcp, err := mcp.New(ctx, mcp.DefaultPrefix, cfg)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("authorize: failed to create mcp handler: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
a := &Authorize{
|
a := &Authorize{
|
||||||
currentConfig: atomicutil.NewValue(cfg),
|
currentConfig: atomicutil.NewValue(cfg),
|
||||||
store: store.New(),
|
store: store.New(),
|
||||||
tracerProvider: tracerProvider,
|
tracerProvider: tracerProvider,
|
||||||
tracer: tracer,
|
tracer: tracer,
|
||||||
mcp: atomicutil.NewValue(mcp),
|
|
||||||
}
|
}
|
||||||
a.accessTracker = NewAccessTracker(a, accessTrackerMaxSize, accessTrackerDebouncePeriod)
|
a.accessTracker = NewAccessTracker(a, accessTrackerMaxSize, accessTrackerDebouncePeriod)
|
||||||
|
|
||||||
|
@ -93,6 +85,7 @@ func validateOptions(o *config.Options) error {
|
||||||
func newPolicyEvaluator(
|
func newPolicyEvaluator(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
opts *config.Options, store *store.Store, previous *evaluator.Evaluator,
|
opts *config.Options, store *store.Store, previous *evaluator.Evaluator,
|
||||||
|
evaluatorOpts ...evaluator.Option,
|
||||||
) (*evaluator.Evaluator, error) {
|
) (*evaluator.Evaluator, error) {
|
||||||
metrics.AddPolicyCountCallback("pomerium-authorize", func() int64 {
|
metrics.AddPolicyCountCallback("pomerium-authorize", func() int64 {
|
||||||
return int64(opts.NumPolicies())
|
return int64(opts.NumPolicies())
|
||||||
|
@ -136,7 +129,7 @@ func newPolicyEvaluator(
|
||||||
}
|
}
|
||||||
|
|
||||||
allPolicies := slices.Collect(opts.GetAllPolicies())
|
allPolicies := slices.Collect(opts.GetAllPolicies())
|
||||||
return evaluator.New(ctx, store, previous,
|
evaluatorOpts = append([]evaluator.Option{
|
||||||
evaluator.WithPolicies(allPolicies),
|
evaluator.WithPolicies(allPolicies),
|
||||||
evaluator.WithClientCA(clientCA),
|
evaluator.WithClientCA(clientCA),
|
||||||
evaluator.WithAddDefaultClientCertificateRule(addDefaultClientCertificateRule),
|
evaluator.WithAddDefaultClientCertificateRule(addDefaultClientCertificateRule),
|
||||||
|
@ -148,7 +141,8 @@ func newPolicyEvaluator(
|
||||||
evaluator.WithJWTClaimsHeaders(opts.JWTClaimsHeaders),
|
evaluator.WithJWTClaimsHeaders(opts.JWTClaimsHeaders),
|
||||||
evaluator.WithJWTGroupsFilter(opts.JWTGroupsFilter),
|
evaluator.WithJWTGroupsFilter(opts.JWTGroupsFilter),
|
||||||
evaluator.WithDefaultJWTIssuerFormat(opts.JWTIssuerFormat),
|
evaluator.WithDefaultJWTIssuerFormat(opts.JWTIssuerFormat),
|
||||||
)
|
}, evaluatorOpts...)
|
||||||
|
return evaluator.New(ctx, store, previous, evaluatorOpts...)
|
||||||
}
|
}
|
||||||
|
|
||||||
// OnConfigChange updates internal structures based on config.Options
|
// OnConfigChange updates internal structures based on config.Options
|
||||||
|
@ -160,11 +154,4 @@ func (a *Authorize) OnConfigChange(ctx context.Context, cfg *config.Config) {
|
||||||
} else {
|
} else {
|
||||||
a.state.Store(newState)
|
a.state.Store(newState)
|
||||||
}
|
}
|
||||||
|
|
||||||
mcp, err := mcp.New(ctx, mcp.DefaultPrefix, cfg)
|
|
||||||
if err != nil {
|
|
||||||
log.Ctx(ctx).Error().Err(err).Msg("proxy: failed to update authorize state from configuration settings")
|
|
||||||
} else {
|
|
||||||
a.mcp.Store(mcp)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -99,7 +99,7 @@ func (a *Authorize) handleResultDenied(
|
||||||
case invalidClientCertReason(reasons):
|
case invalidClientCertReason(reasons):
|
||||||
denyStatusCode = httputil.StatusInvalidClientCertificate
|
denyStatusCode = httputil.StatusInvalidClientCertificate
|
||||||
denyStatusText = httputil.DetailsText(httputil.StatusInvalidClientCertificate)
|
denyStatusText = httputil.DetailsText(httputil.StatusInvalidClientCertificate)
|
||||||
case request.Policy.IsMCP():
|
case request.Policy.IsMCPServer():
|
||||||
denyStatusCode = http.StatusUnauthorized
|
denyStatusCode = http.StatusUnauthorized
|
||||||
denyStatusText = httputil.DetailsText(http.StatusUnauthorized)
|
denyStatusText = httputil.DetailsText(http.StatusUnauthorized)
|
||||||
}
|
}
|
||||||
|
@ -358,7 +358,7 @@ func (a *Authorize) userInfoEndpointURL(in *envoy_service_auth_v3.CheckRequest)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Authorize) shouldRedirect(in *envoy_service_auth_v3.CheckRequest, request *evaluator.Request) bool {
|
func (a *Authorize) shouldRedirect(in *envoy_service_auth_v3.CheckRequest, request *evaluator.Request) bool {
|
||||||
if request.Policy.IsMCP() {
|
if request.Policy.IsMCPServer() {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -1,6 +1,8 @@
|
||||||
package evaluator
|
package evaluator
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/pomerium/pomerium/config"
|
"github.com/pomerium/pomerium/config"
|
||||||
"github.com/pomerium/pomerium/internal/hashutil"
|
"github.com/pomerium/pomerium/internal/hashutil"
|
||||||
)
|
)
|
||||||
|
@ -17,6 +19,7 @@ type evaluatorConfig struct {
|
||||||
JWTClaimsHeaders config.JWTClaimHeaders
|
JWTClaimsHeaders config.JWTClaimHeaders
|
||||||
JWTGroupsFilter config.JWTGroupsFilter
|
JWTGroupsFilter config.JWTGroupsFilter
|
||||||
DefaultJWTIssuerFormat config.JWTIssuerFormat
|
DefaultJWTIssuerFormat config.JWTIssuerFormat
|
||||||
|
MCPAccessTokenProvider func(string, time.Time) (string, error) `hash:"-"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// cacheKey() returns a hash over the configuration, except for the policies.
|
// cacheKey() returns a hash over the configuration, except for the policies.
|
||||||
|
@ -113,3 +116,10 @@ func WithDefaultJWTIssuerFormat(format config.JWTIssuerFormat) Option {
|
||||||
cfg.DefaultJWTIssuerFormat = format
|
cfg.DefaultJWTIssuerFormat = format
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// WithMCPAccessTokenProvider sets the MCP access token in the config.
|
||||||
|
func WithMCPAccessTokenProvider(fn func(sessionID string, expires time.Time) (string, error)) Option {
|
||||||
|
return func(cfg *evaluatorConfig) {
|
||||||
|
cfg.MCPAccessTokenProvider = fn
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -383,6 +383,7 @@ func updateStore(store *store.Store, cfg *evaluatorConfig) error {
|
||||||
store.UpdateDefaultJWTIssuerFormat(cfg.DefaultJWTIssuerFormat)
|
store.UpdateDefaultJWTIssuerFormat(cfg.DefaultJWTIssuerFormat)
|
||||||
store.UpdateRoutePolicies(cfg.Policies)
|
store.UpdateRoutePolicies(cfg.Policies)
|
||||||
store.UpdateSigningKey(jwk)
|
store.UpdateSigningKey(jwk)
|
||||||
|
store.UpdateMCPAccessTokenProvider(cfg.MCPAccessTokenProvider)
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
|
@ -93,6 +93,27 @@ func (e *headersEvaluatorEvaluation) fillJWTClaimHeaders(ctx context.Context) er
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (e *headersEvaluatorEvaluation) fillMCPHeaders() error {
|
||||||
|
if e.request == nil ||
|
||||||
|
e.request.Policy == nil ||
|
||||||
|
e.request.Policy.MCP == nil ||
|
||||||
|
!e.request.Policy.MCP.PassUpstreamAccessToken {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if e.request.Session.ID == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
accessToken, err := e.evaluator.store.GetMCPAccessTokenProvider()(e.request.Session.ID, e.now.Add(time.Hour))
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("authorize/header-evaluator: error getting MCP access token: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
e.response.Headers.Set("Authorization", "Bearer "+accessToken)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func (e *headersEvaluatorEvaluation) fillKubernetesHeaders(ctx context.Context) {
|
func (e *headersEvaluatorEvaluation) fillKubernetesHeaders(ctx context.Context) {
|
||||||
if e.request.Policy == nil {
|
if e.request.Policy == nil {
|
||||||
return
|
return
|
||||||
|
@ -177,6 +198,10 @@ func (e *headersEvaluatorEvaluation) fillHeaders(ctx context.Context) error {
|
||||||
if err := e.fillJWTClaimHeaders(ctx); err != nil {
|
if err := e.fillJWTClaimHeaders(ctx); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
err := e.fillMCPHeaders()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
e.fillKubernetesHeaders(ctx)
|
e.fillKubernetesHeaders(ctx)
|
||||||
e.fillGoogleCloudServerlessHeaders(ctx)
|
e.fillGoogleCloudServerlessHeaders(ctx)
|
||||||
e.fillRoutingKeyHeaders()
|
e.fillRoutingKeyHeaders()
|
||||||
|
|
|
@ -123,7 +123,7 @@ func (a *Authorize) maybeGetSessionFromRequest(
|
||||||
hreq *http.Request,
|
hreq *http.Request,
|
||||||
policy *config.Policy,
|
policy *config.Policy,
|
||||||
) (*session.Session, error) {
|
) (*session.Session, error) {
|
||||||
if policy.IsMCP() {
|
if policy.IsMCPServer() {
|
||||||
s, err := a.getMCPSession(ctx, hreq)
|
s, err := a.getMCPSession(ctx, hreq)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Ctx(ctx).Error().Err(err).Msg("error getting mcp session")
|
log.Ctx(ctx).Error().Err(err).Msg("error getting mcp session")
|
||||||
|
@ -165,8 +165,8 @@ func (a *Authorize) getMCPSession(
|
||||||
}
|
}
|
||||||
|
|
||||||
accessToken := auth[len(prefix):]
|
accessToken := auth[len(prefix):]
|
||||||
sessionID, ok := a.mcp.Load().GetSessionIDFromAccessToken(ctx, accessToken)
|
sessionID, err := a.state.Load().mcp.GetSessionIDFromAccessToken(accessToken)
|
||||||
if !ok {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("no session found for access token: %w", sessions.ErrNoSessionFound)
|
return nil, fmt.Errorf("no session found for access token: %w", sessions.ErrNoSessionFound)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -20,6 +20,7 @@ import (
|
||||||
"google.golang.org/protobuf/types/known/timestamppb"
|
"google.golang.org/protobuf/types/known/timestamppb"
|
||||||
|
|
||||||
"github.com/pomerium/pomerium/config"
|
"github.com/pomerium/pomerium/config"
|
||||||
|
"github.com/pomerium/pomerium/internal/atomicutil"
|
||||||
"github.com/pomerium/pomerium/internal/log"
|
"github.com/pomerium/pomerium/internal/log"
|
||||||
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
||||||
"github.com/pomerium/pomerium/pkg/storage"
|
"github.com/pomerium/pomerium/pkg/storage"
|
||||||
|
@ -35,6 +36,8 @@ type Store struct {
|
||||||
jwtGroupsFilter atomic.Pointer[config.JWTGroupsFilter]
|
jwtGroupsFilter atomic.Pointer[config.JWTGroupsFilter]
|
||||||
defaultJWTIssuerFormat atomic.Pointer[config.JWTIssuerFormat]
|
defaultJWTIssuerFormat atomic.Pointer[config.JWTIssuerFormat]
|
||||||
signingKey atomic.Pointer[jose.JSONWebKey]
|
signingKey atomic.Pointer[jose.JSONWebKey]
|
||||||
|
|
||||||
|
mcpAccessTokenProvider atomicutil.Value[func(string, time.Time) (string, error)]
|
||||||
}
|
}
|
||||||
|
|
||||||
// New creates a new Store.
|
// New creates a new Store.
|
||||||
|
@ -78,6 +81,15 @@ func (s *Store) GetSigningKey() *jose.JSONWebKey {
|
||||||
return s.signingKey.Load()
|
return s.signingKey.Load()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *Store) GetMCPAccessTokenProvider() func(string, time.Time) (string, error) {
|
||||||
|
if f := s.mcpAccessTokenProvider.Load(); f != nil {
|
||||||
|
return f
|
||||||
|
}
|
||||||
|
return func(string, time.Time) (string, error) {
|
||||||
|
return "", fmt.Errorf("no mcp access token provider")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// UpdateGoogleCloudServerlessAuthenticationServiceAccount updates the google cloud serverless authentication
|
// UpdateGoogleCloudServerlessAuthenticationServiceAccount updates the google cloud serverless authentication
|
||||||
// service account in the store.
|
// service account in the store.
|
||||||
func (s *Store) UpdateGoogleCloudServerlessAuthenticationServiceAccount(serviceAccount string) {
|
func (s *Store) UpdateGoogleCloudServerlessAuthenticationServiceAccount(serviceAccount string) {
|
||||||
|
@ -115,6 +127,11 @@ func (s *Store) UpdateSigningKey(signingKey *jose.JSONWebKey) {
|
||||||
s.signingKey.Store(signingKey)
|
s.signingKey.Store(signingKey)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *Store) UpdateMCPAccessTokenProvider(mcpAccessTokenProvider func(string, time.Time) (string, error)) {
|
||||||
|
// This isn't used by the Rego code, so we don't need to write it to the opastorage.Store instance.
|
||||||
|
s.mcpAccessTokenProvider.Store(mcpAccessTokenProvider)
|
||||||
|
}
|
||||||
|
|
||||||
func (s *Store) write(rawPath string, value any) {
|
func (s *Store) write(rawPath string, value any) {
|
||||||
ctx := context.TODO()
|
ctx := context.TODO()
|
||||||
err := opastorage.Txn(ctx, s.Store, opastorage.WriteParams, func(txn opastorage.Transaction) error {
|
err := opastorage.Txn(ctx, s.Store, opastorage.WriteParams, func(txn opastorage.Transaction) error {
|
||||||
|
|
|
@ -14,6 +14,7 @@ import (
|
||||||
"github.com/pomerium/pomerium/authorize/internal/store"
|
"github.com/pomerium/pomerium/authorize/internal/store"
|
||||||
"github.com/pomerium/pomerium/config"
|
"github.com/pomerium/pomerium/config"
|
||||||
"github.com/pomerium/pomerium/internal/authenticateflow"
|
"github.com/pomerium/pomerium/internal/authenticateflow"
|
||||||
|
"github.com/pomerium/pomerium/internal/mcp"
|
||||||
"github.com/pomerium/pomerium/pkg/grpc"
|
"github.com/pomerium/pomerium/pkg/grpc"
|
||||||
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
||||||
"github.com/pomerium/pomerium/pkg/grpc/session"
|
"github.com/pomerium/pomerium/pkg/grpc/session"
|
||||||
|
@ -36,6 +37,7 @@ type authorizeState struct {
|
||||||
sessionStore *config.SessionStore
|
sessionStore *config.SessionStore
|
||||||
authenticateFlow authenticateFlow
|
authenticateFlow authenticateFlow
|
||||||
syncQueriers map[string]storage.Querier
|
syncQueriers map[string]storage.Querier
|
||||||
|
mcp *mcp.Handler
|
||||||
}
|
}
|
||||||
|
|
||||||
func newAuthorizeStateFromConfig(
|
func newAuthorizeStateFromConfig(
|
||||||
|
@ -57,7 +59,13 @@ func newAuthorizeStateFromConfig(
|
||||||
previousEvaluator = previousState.evaluator
|
previousEvaluator = previousState.evaluator
|
||||||
}
|
}
|
||||||
|
|
||||||
state.evaluator, err = newPolicyEvaluator(ctx, cfg.Options, store, previousEvaluator)
|
mcp, err := mcp.New(ctx, mcp.DefaultPrefix, cfg)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("authorize: failed to create mcp handler: %w", err)
|
||||||
|
}
|
||||||
|
state.mcp = mcp
|
||||||
|
|
||||||
|
state.evaluator, err = newPolicyEvaluator(ctx, cfg.Options, store, previousEvaluator, evaluator.WithMCPAccessTokenProvider(mcp.CreateAccessTokenForSession))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("authorize: failed to update policy with options: %w", err)
|
return nil, fmt.Errorf("authorize: failed to update policy with options: %w", err)
|
||||||
}
|
}
|
||||||
|
|
|
@ -671,6 +671,7 @@ func MCPFromPB(src *configpb.MCP) *MCP {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
var v MCP
|
var v MCP
|
||||||
|
v.PassUpstreamAccessToken = src.GetPassUpstreamAccessToken()
|
||||||
if uo := src.GetUpstreamOauth2(); uo != nil {
|
if uo := src.GetUpstreamOauth2(); uo != nil {
|
||||||
v.UpstreamOAuth2 = &UpstreamOAuth2{
|
v.UpstreamOAuth2 = &UpstreamOAuth2{
|
||||||
ClientID: uo.GetClientId(),
|
ClientID: uo.GetClientId(),
|
||||||
|
@ -708,6 +709,7 @@ func MCPToPB(src *MCP) *configpb.MCP {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
v := new(configpb.MCP)
|
v := new(configpb.MCP)
|
||||||
|
v.PassUpstreamAccessToken = proto.Bool(src.PassUpstreamAccessToken)
|
||||||
if src.UpstreamOAuth2 != nil {
|
if src.UpstreamOAuth2 != nil {
|
||||||
var authStyle *configpb.OAuth2AuthStyle
|
var authStyle *configpb.OAuth2AuthStyle
|
||||||
switch src.UpstreamOAuth2.Endpoint.AuthStyle {
|
switch src.UpstreamOAuth2.Endpoint.AuthStyle {
|
||||||
|
|
|
@ -1309,7 +1309,7 @@ func (o *Options) GetAllRouteableHTTPHosts() ([]string, map[string]bool, error)
|
||||||
hosts.InsertSlice(domains)
|
hosts.InsertSlice(domains)
|
||||||
|
|
||||||
// Track if the domains are associated with an MCP policy
|
// Track if the domains are associated with an MCP policy
|
||||||
if policy.IsMCP() {
|
if policy.IsMCPServer() {
|
||||||
for _, domain := range domains {
|
for _, domain := range domains {
|
||||||
mcpHosts[domain] = true
|
mcpHosts[domain] = true
|
||||||
}
|
}
|
||||||
|
@ -1321,7 +1321,7 @@ func (o *Options) GetAllRouteableHTTPHosts() ([]string, map[string]bool, error)
|
||||||
hosts.InsertSlice(tlsDomains)
|
hosts.InsertSlice(tlsDomains)
|
||||||
|
|
||||||
// Track if the TLS domains are associated with an MCP policy
|
// Track if the TLS domains are associated with an MCP policy
|
||||||
if policy.IsMCP() {
|
if policy.IsMCPServer() {
|
||||||
for _, domain := range tlsDomains {
|
for _, domain := range tlsDomains {
|
||||||
mcpHosts[domain] = true
|
mcpHosts[domain] = true
|
||||||
}
|
}
|
||||||
|
|
|
@ -211,6 +211,8 @@ type Policy struct {
|
||||||
type MCP struct {
|
type MCP struct {
|
||||||
// UpstreamOAuth2 specifies that before the request reaches the MCP upstream server, it should acquire an OAuth2 token
|
// UpstreamOAuth2 specifies that before the request reaches the MCP upstream server, it should acquire an OAuth2 token
|
||||||
UpstreamOAuth2 *UpstreamOAuth2 `mapstructure:"upstream_oauth2" yaml:"upstream_oauth2,omitempty" json:"upstream_oauth2,omitempty"`
|
UpstreamOAuth2 *UpstreamOAuth2 `mapstructure:"upstream_oauth2" yaml:"upstream_oauth2,omitempty" json:"upstream_oauth2,omitempty"`
|
||||||
|
// PassUpstreamAccessToken indicates whether to pass the upstream access token in the `Authorization: Bearer` header that is suitable for calling the MCP routes
|
||||||
|
PassUpstreamAccessToken bool `mapstructure:"pass_upstream_access_token" yaml:"pass_upstream_access_token,omitempty" json:"pass_upstream_access_token,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type UpstreamOAuth2 struct {
|
type UpstreamOAuth2 struct {
|
||||||
|
@ -859,9 +861,9 @@ func (p *Policy) IsForKubernetes() bool {
|
||||||
return p.KubernetesServiceAccountTokenFile != "" || p.KubernetesServiceAccountToken != ""
|
return p.KubernetesServiceAccountTokenFile != "" || p.KubernetesServiceAccountToken != ""
|
||||||
}
|
}
|
||||||
|
|
||||||
// IsMCP returns true if the route is for the Model Context Protocol upstream server.
|
// IsMCPServer returns true if the route is for the Model Context Protocol upstream server.
|
||||||
func (p *Policy) IsMCP() bool {
|
func (p *Policy) IsMCPServer() bool {
|
||||||
return p != nil && p.MCP != nil
|
return p != nil && p.MCP != nil && !p.MCP.PassUpstreamAccessToken
|
||||||
}
|
}
|
||||||
|
|
||||||
// IsTCP returns true if the route is for TCP.
|
// IsTCP returns true if the route is for TCP.
|
||||||
|
|
|
@ -1,7 +1,6 @@
|
||||||
package mcp
|
package mcp
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"net/http"
|
"net/http"
|
||||||
"time"
|
"time"
|
||||||
|
@ -88,7 +87,7 @@ func (srv *Handler) handleAuthorizationCodeToken(w http.ResponseWriter, r *http.
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
accessToken, err := CreateAccessToken(session, srv.cipher)
|
accessToken, err := srv.CreateAccessTokenForSession(session.Id, session.ExpiresAt.AsTime())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
http.Error(w, "internal error", http.StatusInternalServerError)
|
http.Error(w, "internal error", http.StatusInternalServerError)
|
||||||
return
|
return
|
||||||
|
@ -117,12 +116,3 @@ func (srv *Handler) handleAuthorizationCodeToken(w http.ResponseWriter, r *http.
|
||||||
w.WriteHeader(http.StatusOK)
|
w.WriteHeader(http.StatusOK)
|
||||||
_, _ = w.Write(data)
|
_, _ = w.Write(data)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (srv *Handler) GetSessionIDFromAccessToken(ctx context.Context, accessToken string) (string, bool) {
|
|
||||||
sessionID, err := DecryptAccessToken(accessToken, srv.cipher)
|
|
||||||
if err != nil {
|
|
||||||
log.Ctx(ctx).Error().Err(err).Msg("failed to decrypt access token")
|
|
||||||
return "", false
|
|
||||||
}
|
|
||||||
return sessionID, true
|
|
||||||
}
|
|
||||||
|
|
|
@ -1,12 +1,10 @@
|
||||||
package mcp
|
package mcp
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"crypto/cipher"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/pomerium/pomerium/internal/oauth21"
|
"github.com/pomerium/pomerium/internal/oauth21"
|
||||||
"github.com/pomerium/pomerium/pkg/grpc/session"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func CheckPKCE(
|
func CheckPKCE(
|
||||||
|
@ -30,13 +28,13 @@ func CheckPKCE(
|
||||||
}
|
}
|
||||||
|
|
||||||
// CreateAuthorizationCode creates an access token based on the session
|
// CreateAuthorizationCode creates an access token based on the session
|
||||||
func CreateAccessToken(src *session.Session, cipher cipher.AEAD) (string, error) {
|
func (srv *Handler) CreateAccessTokenForSession(id string, expiresAt time.Time) (string, error) {
|
||||||
return CreateCode(CodeTypeAccess, src.Id, src.ExpiresAt.AsTime(), "", cipher)
|
return CreateCode(CodeTypeAccess, id, expiresAt, "", srv.cipher)
|
||||||
}
|
}
|
||||||
|
|
||||||
// DecryptAuthorizationCode decrypts the authorization code and returns the underlying session ID
|
// DecryptAuthorizationCode decrypts the authorization code and returns the underlying session ID
|
||||||
func DecryptAccessToken(accessToken string, cipher cipher.AEAD) (string, error) {
|
func (srv *Handler) GetSessionIDFromAccessToken(accessToken string) (string, error) {
|
||||||
code, err := DecryptCode(CodeTypeAccess, accessToken, cipher, "", time.Now())
|
code, err := DecryptCode(CodeTypeAccess, accessToken, srv.cipher, "", time.Now())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
|
|
File diff suppressed because it is too large
Load diff
|
@ -152,7 +152,8 @@ message Route {
|
||||||
}
|
}
|
||||||
|
|
||||||
message MCP {
|
message MCP {
|
||||||
optional UpstreamOAuth2 upstream_oauth2 = 1;
|
optional UpstreamOAuth2 upstream_oauth2 = 1;
|
||||||
|
optional bool pass_upstream_access_token = 2;
|
||||||
}
|
}
|
||||||
|
|
||||||
message UpstreamOAuth2 {
|
message UpstreamOAuth2 {
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue