mirror of
https://github.com/pomerium/pomerium.git
synced 2025-08-02 00:10:45 +02:00
mcp: client registration/token fixes (#5649)
## Summary Fixes to MCP code registration and token requests. 1. ease some requirements on fields that are RECOMMENDED 2. fill in defaults 3. store both request and response in the client registration 4. check client secret in the /token request ## Related issues - Fixes https://linear.app/pomerium/issue/ENG-2462/mcp-ignore-unknown-grant-types-in-the-client-registration - Fixes https://linear.app/pomerium/issue/ENG-2461/mcp-support-client-secret-in-dynamic-client-registration ## User Explanation <!-- How would you explain this change to the user? If this change doesn't create any user-facing changes, you can leave this blank. If filled out, add the `docs` label --> ## Checklist - [x] reference any related issues - [x] updated unit tests - [x] add appropriate label (`enhancement`, `bug`, `breaking`, `dependencies`, `ci`) - [ ] ready for review
This commit is contained in:
parent
200f2e8164
commit
777b3b12d2
20 changed files with 1440 additions and 461 deletions
|
@ -2,6 +2,7 @@ package mcp
|
|||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
|
@ -12,6 +13,7 @@ import (
|
|||
"github.com/pomerium/pomerium/internal/log"
|
||||
"github.com/pomerium/pomerium/internal/oauth21"
|
||||
oauth21proto "github.com/pomerium/pomerium/internal/oauth21/gen"
|
||||
rfc7591v1 "github.com/pomerium/pomerium/internal/rfc7591"
|
||||
)
|
||||
|
||||
// Token handles the /token endpoint.
|
||||
|
@ -21,54 +23,65 @@ func (srv *Handler) Token(w http.ResponseWriter, r *http.Request) {
|
|||
return
|
||||
}
|
||||
|
||||
req, err := oauth21.ParseTokenRequest(r)
|
||||
ctx := r.Context()
|
||||
req, err := srv.getTokenRequest(r)
|
||||
if err != nil {
|
||||
log.Ctx(r.Context()).Error().Err(err).Msg("failed to parse token request")
|
||||
log.Ctx(ctx).Error().Err(err).Msg("get token request failed")
|
||||
oauth21.ErrorResponse(w, http.StatusBadRequest, oauth21.InvalidRequest)
|
||||
return
|
||||
}
|
||||
|
||||
switch req.GrantType {
|
||||
case "authorization_code":
|
||||
log.Ctx(ctx).Debug().Msg("handling authorization_code token request")
|
||||
srv.handleAuthorizationCodeToken(w, r, req)
|
||||
default:
|
||||
log.Ctx(ctx).Error().Msgf("unsupported grant type: %s", req.GrantType)
|
||||
oauth21.ErrorResponse(w, http.StatusBadRequest, oauth21.UnsupportedGrantType)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func (srv *Handler) handleAuthorizationCodeToken(w http.ResponseWriter, r *http.Request, req *oauth21proto.TokenRequest) {
|
||||
func (srv *Handler) handleAuthorizationCodeToken(w http.ResponseWriter, r *http.Request, tokenReq *oauth21proto.TokenRequest) {
|
||||
ctx := r.Context()
|
||||
|
||||
if req.ClientId == nil {
|
||||
if tokenReq.ClientId == nil {
|
||||
log.Ctx(ctx).Error().Msg("missing client_id in token request")
|
||||
oauth21.ErrorResponse(w, http.StatusBadRequest, oauth21.InvalidClient)
|
||||
return
|
||||
}
|
||||
if req.Code == nil {
|
||||
if tokenReq.Code == nil {
|
||||
log.Ctx(ctx).Error().Msg("missing code in token request")
|
||||
oauth21.ErrorResponse(w, http.StatusBadRequest, oauth21.InvalidGrant)
|
||||
return
|
||||
}
|
||||
code, err := DecryptCode(CodeTypeAuthorization, *req.Code, srv.cipher, *req.ClientId, time.Now())
|
||||
code, err := DecryptCode(CodeTypeAuthorization, *tokenReq.Code, srv.cipher, *tokenReq.ClientId, time.Now())
|
||||
if err != nil {
|
||||
log.Ctx(ctx).Error().Err(err).Msg("failed to decrypt authorization code")
|
||||
oauth21.ErrorResponse(w, http.StatusBadRequest, oauth21.InvalidGrant)
|
||||
return
|
||||
}
|
||||
|
||||
authReq, err := srv.storage.GetAuthorizationRequest(ctx, code.Id)
|
||||
if status.Code(err) == codes.NotFound {
|
||||
log.Ctx(ctx).Error().Msg("authorization request not found")
|
||||
oauth21.ErrorResponse(w, http.StatusBadRequest, oauth21.InvalidGrant)
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
log.Ctx(ctx).Error().Err(err).Msg("failed to get authorization request and client")
|
||||
http.Error(w, "internal error", http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
if *req.ClientId != authReq.ClientId {
|
||||
if *tokenReq.ClientId != authReq.ClientId {
|
||||
log.Ctx(ctx).Error().Msgf("client ID mismatch: %s != %s", *tokenReq.ClientId, authReq.ClientId)
|
||||
oauth21.ErrorResponse(w, http.StatusBadRequest, oauth21.InvalidGrant)
|
||||
return
|
||||
}
|
||||
|
||||
err = CheckPKCE(authReq.GetCodeChallengeMethod(), authReq.GetCodeChallenge(), req.GetCodeVerifier())
|
||||
err = CheckPKCE(authReq.GetCodeChallengeMethod(), authReq.GetCodeChallenge(), tokenReq.GetCodeVerifier())
|
||||
if err != nil {
|
||||
log.Ctx(ctx).Error().Err(err).Msg("failed to check PKCE")
|
||||
oauth21.ErrorResponse(w, http.StatusBadRequest, oauth21.InvalidGrant)
|
||||
return
|
||||
}
|
||||
|
@ -77,24 +90,28 @@ func (srv *Handler) handleAuthorizationCodeToken(w http.ResponseWriter, r *http.
|
|||
// https://datatracker.ietf.org/doc/html/draft-ietf-oauth-v2-1-12#section-4.1.3
|
||||
err = srv.storage.DeleteAuthorizationRequest(ctx, code.Id)
|
||||
if err != nil {
|
||||
log.Ctx(ctx).Error().Err(err).Msg("failed to delete authorization request")
|
||||
http.Error(w, "internal error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
session, err := srv.storage.GetSession(ctx, authReq.SessionId)
|
||||
if status.Code(err) == codes.NotFound {
|
||||
log.Ctx(ctx).Error().Msg("session not found")
|
||||
oauth21.ErrorResponse(w, http.StatusBadRequest, oauth21.InvalidGrant)
|
||||
return
|
||||
}
|
||||
|
||||
accessToken, err := srv.GetAccessTokenForSession(session.Id, session.ExpiresAt.AsTime())
|
||||
if err != nil {
|
||||
log.Ctx(ctx).Error().Err(err).Msg("failed to get access token for session")
|
||||
http.Error(w, "internal error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
expiresIn := time.Until(session.ExpiresAt.AsTime())
|
||||
if expiresIn < 0 {
|
||||
log.Ctx(ctx).Error().Msg("session has already expired")
|
||||
oauth21.ErrorResponse(w, http.StatusBadRequest, oauth21.InvalidGrant)
|
||||
return
|
||||
}
|
||||
|
@ -107,6 +124,7 @@ func (srv *Handler) handleAuthorizationCodeToken(w http.ResponseWriter, r *http.
|
|||
|
||||
data, err := json.Marshal(resp) // not using protojson.Marshal here because it emits numbers as strings, which is valid, but for some reason Node.js / mcp typescript SDK doesn't like it
|
||||
if err != nil {
|
||||
log.Ctx(ctx).Error().Err(err).Msg("failed to marshal token response")
|
||||
http.Error(w, "internal error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
@ -116,3 +134,46 @@ func (srv *Handler) handleAuthorizationCodeToken(w http.ResponseWriter, r *http.
|
|||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write(data)
|
||||
}
|
||||
|
||||
func (srv *Handler) getTokenRequest(
|
||||
r *http.Request,
|
||||
) (*oauth21proto.TokenRequest, error) {
|
||||
tokenReq, err := oauth21.ParseTokenRequest(r)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse token request: %w", err)
|
||||
}
|
||||
|
||||
ctx := r.Context()
|
||||
clientReg, err := srv.storage.GetClient(ctx, tokenReq.GetClientId())
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get client registration: %w", err)
|
||||
}
|
||||
|
||||
m := clientReg.ResponseMetadata.GetTokenEndpointAuthMethod()
|
||||
if m == rfc7591v1.TokenEndpointAuthMethodNone {
|
||||
return tokenReq, nil
|
||||
}
|
||||
|
||||
secret := clientReg.ClientSecret
|
||||
if secret == nil {
|
||||
return nil, fmt.Errorf("client registration does not have a client secret")
|
||||
}
|
||||
if expires := secret.ExpiresAt; expires != nil && expires.AsTime().Before(time.Now()) {
|
||||
return nil, fmt.Errorf("client registration client secret has expired")
|
||||
}
|
||||
|
||||
switch m {
|
||||
case rfc7591v1.TokenEndpointAuthMethodClientSecretBasic:
|
||||
case rfc7591v1.TokenEndpointAuthMethodClientSecretPost:
|
||||
if tokenReq.ClientSecret == nil {
|
||||
return nil, fmt.Errorf("client_secret was not provided")
|
||||
}
|
||||
if tokenReq.GetClientSecret() != secret.Value {
|
||||
return nil, fmt.Errorf("client secret mismatch")
|
||||
}
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported token endpoint authentication method: %s", m)
|
||||
}
|
||||
|
||||
return tokenReq, nil
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue