mcp: token: handle authorization_code (pt2) (#5589)

This commit is contained in:
Denis Mishin 2025-04-28 14:37:19 -04:00 committed by GitHub
parent 7b9c392531
commit 0602f5e00d
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 240 additions and 16 deletions

View file

@ -1,9 +1,14 @@
package mcp
import (
"encoding/json"
"net/http"
"time"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"google.golang.org/protobuf/proto"
"github.com/pomerium/pomerium/internal/log"
"github.com/pomerium/pomerium/internal/oauth21"
oauth21proto "github.com/pomerium/pomerium/internal/oauth21/gen"
@ -39,7 +44,6 @@ func (srv *Handler) handleAuthorizationCodeToken(w http.ResponseWriter, r *http.
oauth21.ErrorResponse(w, http.StatusBadRequest, oauth21.InvalidClient)
return
}
if req.Code == nil {
oauth21.ErrorResponse(w, http.StatusBadRequest, oauth21.InvalidGrant)
return
@ -51,10 +55,13 @@ func (srv *Handler) handleAuthorizationCodeToken(w http.ResponseWriter, r *http.
}
authReq, err := srv.storage.GetAuthorizationRequest(ctx, code.Id)
if err != nil {
if status.Code(err) == codes.NotFound {
oauth21.ErrorResponse(w, http.StatusBadRequest, oauth21.InvalidGrant)
return
}
if err != nil {
http.Error(w, "internal error", http.StatusInternalServerError)
}
if *req.ClientId != authReq.ClientId {
oauth21.ErrorResponse(w, http.StatusBadRequest, oauth21.InvalidGrant)
@ -66,5 +73,46 @@ func (srv *Handler) handleAuthorizationCodeToken(w http.ResponseWriter, r *http.
return
}
http.Error(w, "Not Implemented", http.StatusNotImplemented)
// The authorization server MUST return an access token only once for a given authorization code.
// https://datatracker.ietf.org/doc/html/draft-ietf-oauth-v2-1-12#section-4.1.3
err = srv.storage.DeleteAuthorizationRequest(ctx, code.Id)
if err != nil {
http.Error(w, "internal error", http.StatusInternalServerError)
return
}
session, err := srv.storage.GetSession(ctx, authReq.SessionId)
if status.Code(err) == codes.NotFound {
oauth21.ErrorResponse(w, http.StatusBadRequest, oauth21.InvalidGrant)
return
}
accessToken, err := CreateAccessToken(session, srv.cipher)
if err != nil {
http.Error(w, "internal error", http.StatusInternalServerError)
return
}
expiresIn := time.Until(session.ExpiresAt.AsTime())
if expiresIn < 0 {
oauth21.ErrorResponse(w, http.StatusBadRequest, oauth21.InvalidGrant)
return
}
resp := &oauth21proto.TokenResponse{
AccessToken: accessToken,
TokenType: "Bearer",
ExpiresIn: proto.Int64(int64(expiresIn.Seconds())),
}
data, err := json.Marshal(resp) // not using protojson.Marshal here because it emits numbers as strings, which is valid, but for some reason Node.js / mcp typescript SDK doesn't like it
if err != nil {
http.Error(w, "internal error", http.StatusInternalServerError)
return
}
w.Header().Set("Content-Type", "application/json")
w.Header().Set("Cache-Control", "no-store")
w.Header().Set("Pragma", "no-cache")
w.WriteHeader(http.StatusOK)
_, _ = w.Write(data)
}