mcp: authorize: load session from the access token (#5591)

This commit is contained in:
Denis Mishin 2025-04-28 16:32:06 -04:00 committed by GitHub
parent 0602f5e00d
commit daaf5b8e30
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 112 additions and 17 deletions

View file

@ -16,6 +16,7 @@ import (
"github.com/pomerium/pomerium/config"
"github.com/pomerium/pomerium/internal/atomicutil"
"github.com/pomerium/pomerium/internal/log"
"github.com/pomerium/pomerium/internal/mcp"
"github.com/pomerium/pomerium/internal/telemetry/metrics"
"github.com/pomerium/pomerium/pkg/cryptutil"
"github.com/pomerium/pomerium/pkg/grpc/databroker"
@ -28,6 +29,7 @@ type Authorize struct {
store *store.Store
currentConfig *atomicutil.Value[*config.Config]
accessTracker *AccessTracker
mcp *atomicutil.Value[*mcp.Handler]
tracerProvider oteltrace.TracerProvider
tracer oteltrace.Tracer
@ -37,11 +39,18 @@ type Authorize struct {
func New(ctx context.Context, cfg *config.Config) (*Authorize, error) {
tracerProvider := trace.NewTracerProvider(ctx, "Authorize")
tracer := tracerProvider.Tracer(trace.PomeriumCoreTracer)
mcp, err := mcp.New(ctx, mcp.DefaultPrefix, cfg)
if err != nil {
return nil, fmt.Errorf("authorize: failed to create mcp handler: %w", err)
}
a := &Authorize{
currentConfig: atomicutil.NewValue(cfg),
store: store.New(),
tracerProvider: tracerProvider,
tracer: tracer,
mcp: atomicutil.NewValue(mcp),
}
a.accessTracker = NewAccessTracker(a, accessTrackerMaxSize, accessTrackerDebouncePeriod)
@ -151,4 +160,11 @@ func (a *Authorize) OnConfigChange(ctx context.Context, cfg *config.Config) {
} else {
a.state.Store(newState)
}
mcp, err := mcp.New(ctx, mcp.DefaultPrefix, cfg)
if err != nil {
log.Ctx(ctx).Error().Err(err).Msg("proxy: failed to update authorize state from configuration settings")
} else {
a.mcp.Store(mcp)
}
}

View file

@ -21,7 +21,9 @@ import (
"github.com/pomerium/pomerium/internal/sessions"
"github.com/pomerium/pomerium/pkg/contextutil"
"github.com/pomerium/pomerium/pkg/grpc/databroker"
"github.com/pomerium/pomerium/pkg/grpc/session"
"github.com/pomerium/pomerium/pkg/grpc/user"
"github.com/pomerium/pomerium/pkg/grpcutil"
"github.com/pomerium/pomerium/pkg/storage"
"github.com/pomerium/pomerium/pkg/telemetry/requestid"
)
@ -88,29 +90,14 @@ func (a *Authorize) loadSession(
) (s sessionOrServiceAccount, err error) {
requestID := requestid.FromHTTPHeader(hreq.Header)
// attempt to create a session from an incoming idp token
s, err = config.NewIncomingIDPTokenSessionCreator(
func(ctx context.Context, recordType, recordID string) (*databroker.Record, error) {
return storage.GetDataBrokerRecord(ctx, recordType, recordID, 0)
},
func(ctx context.Context, records []*databroker.Record) error {
_, err := a.state.Load().dataBrokerClient.Put(ctx, &databroker.PutRequest{
Records: records,
})
if err != nil {
return err
}
storage.InvalidateCacheForDataBrokerRecords(ctx, records...)
return nil
},
).CreateSession(ctx, a.currentConfig.Load(), req.Policy, hreq)
s, err = a.maybeGetSessionFromRequest(ctx, hreq, req.Policy)
if err == nil {
return s, nil
} else if !errors.Is(err, sessions.ErrNoSessionFound) {
log.Ctx(ctx).Info().
Str("request-id", requestID).
Err(err).
Msg("error creating session for incoming idp token")
Msg("error creating session from incoming request")
return nil, err
}
@ -131,6 +118,76 @@ func (a *Authorize) loadSession(
return s, nil
}
func (a *Authorize) maybeGetSessionFromRequest(
ctx context.Context,
hreq *http.Request,
policy *config.Policy,
) (*session.Session, error) {
if policy.IsMCP() {
s, err := a.getMCPSession(ctx, hreq)
if err != nil {
log.Ctx(ctx).Error().Err(err).Msg("error getting mcp session")
return nil, err
}
return s, nil
}
// attempt to create a session from an incoming idp token
return config.NewIncomingIDPTokenSessionCreator(
func(ctx context.Context, recordType, recordID string) (*databroker.Record, error) {
return storage.GetDataBrokerRecord(ctx, recordType, recordID, 0)
},
func(ctx context.Context, records []*databroker.Record) error {
_, err := a.state.Load().dataBrokerClient.Put(ctx, &databroker.PutRequest{
Records: records,
})
if err != nil {
return err
}
storage.InvalidateCacheForDataBrokerRecords(ctx, records...)
return nil
},
).CreateSession(ctx, a.currentConfig.Load(), policy, hreq)
}
func (a *Authorize) getMCPSession(
ctx context.Context,
hreq *http.Request,
) (*session.Session, error) {
auth := hreq.Header.Get(httputil.HeaderAuthorization)
if auth == "" {
return nil, fmt.Errorf("no authorization header was provided: %w", sessions.ErrNoSessionFound)
}
prefix := "Bearer "
if !strings.HasPrefix(strings.ToLower(auth), strings.ToLower(prefix)) {
return nil, fmt.Errorf("authorization header does not start with %q: %w", prefix, sessions.ErrNoSessionFound)
}
accessToken := auth[len(prefix):]
sessionID, ok := a.mcp.Load().GetSessionIDFromAccessToken(ctx, accessToken)
if !ok {
return nil, fmt.Errorf("no session found for access token: %w", sessions.ErrNoSessionFound)
}
record, err := storage.GetDataBrokerRecord(ctx, grpcutil.GetTypeURL(new(session.Session)), sessionID, 0)
if storage.IsNotFound(err) {
return nil, fmt.Errorf("session databroker record not found: %w", sessions.ErrNoSessionFound)
}
msg, err := record.GetData().UnmarshalNew()
if err != nil {
return nil, fmt.Errorf("error unmarshalling session: %w: %w", err, sessions.ErrNoSessionFound)
}
s, ok := msg.(*session.Session)
if !ok {
return nil, fmt.Errorf("unexpected session type: %T: %w", msg, sessions.ErrNoSessionFound)
}
return s, nil
}
func (a *Authorize) getEvaluatorRequestFromCheckRequest(
ctx context.Context,
in *envoy_service_auth_v3.CheckRequest,

View file

@ -1,6 +1,7 @@
package mcp
import (
"context"
"encoding/json"
"net/http"
"time"
@ -116,3 +117,12 @@ func (srv *Handler) handleAuthorizationCodeToken(w http.ResponseWriter, r *http.
w.WriteHeader(http.StatusOK)
_, _ = w.Write(data)
}
func (srv *Handler) GetSessionIDFromAccessToken(ctx context.Context, accessToken string) (string, bool) {
sessionID, err := DecryptAccessToken(accessToken, srv.cipher)
if err != nil {
log.Ctx(ctx).Error().Err(err).Msg("failed to decrypt access token")
return "", false
}
return sessionID, true
}

View file

@ -3,6 +3,7 @@ package mcp
import (
"crypto/cipher"
"fmt"
"time"
"github.com/pomerium/pomerium/internal/oauth21"
"github.com/pomerium/pomerium/pkg/grpc/session"
@ -28,6 +29,17 @@ func CheckPKCE(
return nil
}
// CreateAuthorizationCode creates an access token based on the session
func CreateAccessToken(src *session.Session, cipher cipher.AEAD) (string, error) {
return CreateCode(CodeTypeAccess, src.Id, src.ExpiresAt.AsTime(), "", cipher)
}
// DecryptAuthorizationCode decrypts the authorization code and returns the underlying session ID
func DecryptAccessToken(accessToken string, cipher cipher.AEAD) (string, error) {
code, err := DecryptCode(CodeTypeAccess, accessToken, cipher, "", time.Now())
if err != nil {
return "", err
}
return code.Id, nil
}