pomerium/internal/mcp/handler_token.go
Denis Mishin 777b3b12d2
mcp: client registration/token fixes (#5649)
## Summary

Fixes to MCP code registration and token requests. 

1. ease some requirements on fields that are RECOMMENDED 
2. fill in defaults
3. store both request and response in the client registration
4. check client secret in the /token request

## Related issues

- Fixes
https://linear.app/pomerium/issue/ENG-2462/mcp-ignore-unknown-grant-types-in-the-client-registration
- Fixes
https://linear.app/pomerium/issue/ENG-2461/mcp-support-client-secret-in-dynamic-client-registration
 
## User Explanation

<!-- How would you explain this change to the user? If this
change doesn't create any user-facing changes, you can leave
this blank. If filled out, add the `docs` label -->

## Checklist

- [x] reference any related issues
- [x] updated unit tests
- [x] add appropriate label (`enhancement`, `bug`, `breaking`,
`dependencies`, `ci`)
- [ ] ready for review
2025-06-11 11:28:24 -04:00

179 lines
5.9 KiB
Go

package mcp
import (
"encoding/json"
"fmt"
"net/http"
"time"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"google.golang.org/protobuf/proto"
"github.com/pomerium/pomerium/internal/log"
"github.com/pomerium/pomerium/internal/oauth21"
oauth21proto "github.com/pomerium/pomerium/internal/oauth21/gen"
rfc7591v1 "github.com/pomerium/pomerium/internal/rfc7591"
)
// Token handles the /token endpoint.
func (srv *Handler) Token(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
http.Error(w, "Method Not Allowed", http.StatusMethodNotAllowed)
return
}
ctx := r.Context()
req, err := srv.getTokenRequest(r)
if err != nil {
log.Ctx(ctx).Error().Err(err).Msg("get token request failed")
oauth21.ErrorResponse(w, http.StatusBadRequest, oauth21.InvalidRequest)
return
}
switch req.GrantType {
case "authorization_code":
log.Ctx(ctx).Debug().Msg("handling authorization_code token request")
srv.handleAuthorizationCodeToken(w, r, req)
default:
log.Ctx(ctx).Error().Msgf("unsupported grant type: %s", req.GrantType)
oauth21.ErrorResponse(w, http.StatusBadRequest, oauth21.UnsupportedGrantType)
return
}
}
func (srv *Handler) handleAuthorizationCodeToken(w http.ResponseWriter, r *http.Request, tokenReq *oauth21proto.TokenRequest) {
ctx := r.Context()
if tokenReq.ClientId == nil {
log.Ctx(ctx).Error().Msg("missing client_id in token request")
oauth21.ErrorResponse(w, http.StatusBadRequest, oauth21.InvalidClient)
return
}
if tokenReq.Code == nil {
log.Ctx(ctx).Error().Msg("missing code in token request")
oauth21.ErrorResponse(w, http.StatusBadRequest, oauth21.InvalidGrant)
return
}
code, err := DecryptCode(CodeTypeAuthorization, *tokenReq.Code, srv.cipher, *tokenReq.ClientId, time.Now())
if err != nil {
log.Ctx(ctx).Error().Err(err).Msg("failed to decrypt authorization code")
oauth21.ErrorResponse(w, http.StatusBadRequest, oauth21.InvalidGrant)
return
}
authReq, err := srv.storage.GetAuthorizationRequest(ctx, code.Id)
if status.Code(err) == codes.NotFound {
log.Ctx(ctx).Error().Msg("authorization request not found")
oauth21.ErrorResponse(w, http.StatusBadRequest, oauth21.InvalidGrant)
return
}
if err != nil {
log.Ctx(ctx).Error().Err(err).Msg("failed to get authorization request and client")
http.Error(w, "internal error", http.StatusInternalServerError)
}
if *tokenReq.ClientId != authReq.ClientId {
log.Ctx(ctx).Error().Msgf("client ID mismatch: %s != %s", *tokenReq.ClientId, authReq.ClientId)
oauth21.ErrorResponse(w, http.StatusBadRequest, oauth21.InvalidGrant)
return
}
err = CheckPKCE(authReq.GetCodeChallengeMethod(), authReq.GetCodeChallenge(), tokenReq.GetCodeVerifier())
if err != nil {
log.Ctx(ctx).Error().Err(err).Msg("failed to check PKCE")
oauth21.ErrorResponse(w, http.StatusBadRequest, oauth21.InvalidGrant)
return
}
// The authorization server MUST return an access token only once for a given authorization code.
// https://datatracker.ietf.org/doc/html/draft-ietf-oauth-v2-1-12#section-4.1.3
err = srv.storage.DeleteAuthorizationRequest(ctx, code.Id)
if err != nil {
log.Ctx(ctx).Error().Err(err).Msg("failed to delete authorization request")
http.Error(w, "internal error", http.StatusInternalServerError)
return
}
session, err := srv.storage.GetSession(ctx, authReq.SessionId)
if status.Code(err) == codes.NotFound {
log.Ctx(ctx).Error().Msg("session not found")
oauth21.ErrorResponse(w, http.StatusBadRequest, oauth21.InvalidGrant)
return
}
accessToken, err := srv.GetAccessTokenForSession(session.Id, session.ExpiresAt.AsTime())
if err != nil {
log.Ctx(ctx).Error().Err(err).Msg("failed to get access token for session")
http.Error(w, "internal error", http.StatusInternalServerError)
return
}
expiresIn := time.Until(session.ExpiresAt.AsTime())
if expiresIn < 0 {
log.Ctx(ctx).Error().Msg("session has already expired")
oauth21.ErrorResponse(w, http.StatusBadRequest, oauth21.InvalidGrant)
return
}
resp := &oauth21proto.TokenResponse{
AccessToken: accessToken,
TokenType: "Bearer",
ExpiresIn: proto.Int64(int64(expiresIn.Seconds())),
}
data, err := json.Marshal(resp) // not using protojson.Marshal here because it emits numbers as strings, which is valid, but for some reason Node.js / mcp typescript SDK doesn't like it
if err != nil {
log.Ctx(ctx).Error().Err(err).Msg("failed to marshal token response")
http.Error(w, "internal error", http.StatusInternalServerError)
return
}
w.Header().Set("Content-Type", "application/json")
w.Header().Set("Cache-Control", "no-store")
w.Header().Set("Pragma", "no-cache")
w.WriteHeader(http.StatusOK)
_, _ = w.Write(data)
}
func (srv *Handler) getTokenRequest(
r *http.Request,
) (*oauth21proto.TokenRequest, error) {
tokenReq, err := oauth21.ParseTokenRequest(r)
if err != nil {
return nil, fmt.Errorf("failed to parse token request: %w", err)
}
ctx := r.Context()
clientReg, err := srv.storage.GetClient(ctx, tokenReq.GetClientId())
if err != nil {
return nil, fmt.Errorf("failed to get client registration: %w", err)
}
m := clientReg.ResponseMetadata.GetTokenEndpointAuthMethod()
if m == rfc7591v1.TokenEndpointAuthMethodNone {
return tokenReq, nil
}
secret := clientReg.ClientSecret
if secret == nil {
return nil, fmt.Errorf("client registration does not have a client secret")
}
if expires := secret.ExpiresAt; expires != nil && expires.AsTime().Before(time.Now()) {
return nil, fmt.Errorf("client registration client secret has expired")
}
switch m {
case rfc7591v1.TokenEndpointAuthMethodClientSecretBasic:
case rfc7591v1.TokenEndpointAuthMethodClientSecretPost:
if tokenReq.ClientSecret == nil {
return nil, fmt.Errorf("client_secret was not provided")
}
if tokenReq.GetClientSecret() != secret.Value {
return nil, fmt.Errorf("client secret mismatch")
}
default:
return nil, fmt.Errorf("unsupported token endpoint authentication method: %s", m)
}
return tokenReq, nil
}