This commit is contained in:
Denis Mishin 2025-04-25 15:54:26 -04:00 committed by GitHub
commit e129aa8fa2
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 240 additions and 15 deletions

View file

@ -1,9 +1,14 @@
package mcp
import (
"encoding/json"
"net/http"
"time"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"google.golang.org/protobuf/proto"
"github.com/pomerium/pomerium/internal/log"
"github.com/pomerium/pomerium/internal/oauth21"
oauth21proto "github.com/pomerium/pomerium/internal/oauth21/gen"
@ -50,10 +55,13 @@ func (srv *Handler) handleAuthorizationCodeToken(w http.ResponseWriter, r *http.
}
authReq, err := srv.storage.GetAuthorizationRequest(ctx, code.Id)
if err != nil {
if status.Code(err) == codes.NotFound {
oauth21.ErrorResponse(w, http.StatusBadRequest, oauth21.InvalidGrant)
return
}
if err != nil {
http.Error(w, "internal error", http.StatusInternalServerError)
}
err = AuthorizeTokenRequest(req, authReq)
if err != nil {
@ -61,5 +69,46 @@ func (srv *Handler) handleAuthorizationCodeToken(w http.ResponseWriter, r *http.
return
}
http.Error(w, "Not Implemented", http.StatusNotImplemented)
// The authorization server MUST return an access token only once for a given authorization code.
// https://datatracker.ietf.org/doc/html/draft-ietf-oauth-v2-1-12#section-4.1.3
err = srv.storage.DeleteAuthorizationRequest(ctx, code.Id)
if err != nil {
http.Error(w, "internal error", http.StatusInternalServerError)
return
}
session, err := srv.storage.GetSession(ctx, authReq.SessionId)
if status.Code(err) == codes.NotFound {
oauth21.ErrorResponse(w, http.StatusBadRequest, oauth21.InvalidGrant)
return
}
accessToken, err := CreateAccessToken(session, srv.cipher)
if err != nil {
http.Error(w, "internal error", http.StatusInternalServerError)
return
}
expiresIn := time.Until(session.ExpiresAt.AsTime())
if expiresIn < 0 {
oauth21.ErrorResponse(w, http.StatusBadRequest, oauth21.InvalidGrant)
return
}
resp := &oauth21proto.TokenResponse{
AccessToken: accessToken,
TokenType: "Bearer",
ExpiresIn: proto.Int64(int64(expiresIn.Seconds())),
}
data, err := json.Marshal(resp) // not using protojson.Marshal here because it emits numbers as strings, which is valid, but for some reason Node.js / mcp typescript SDK doesn't like it
if err != nil {
http.Error(w, "internal error", http.StatusInternalServerError)
return
}
w.Header().Set("Content-Type", "application/json")
w.Header().Set("Cache-Control", "no-store")
w.Header().Set("Pragma", "no-cache")
w.WriteHeader(http.StatusOK)
_, _ = w.Write(data)
}

View file

@ -7,10 +7,12 @@ import (
"github.com/google/uuid"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/types/known/anypb"
"google.golang.org/protobuf/types/known/timestamppb"
oauth21proto "github.com/pomerium/pomerium/internal/oauth21/gen"
rfc7591v1 "github.com/pomerium/pomerium/internal/rfc7591"
"github.com/pomerium/pomerium/pkg/grpc/databroker"
"github.com/pomerium/pomerium/pkg/grpc/session"
"github.com/pomerium/pomerium/pkg/protoutil"
)
@ -106,3 +108,40 @@ func (storage *Storage) GetAuthorizationRequest(
return v, nil
}
func (storage *Storage) DeleteAuthorizationRequest(
ctx context.Context,
id string,
) error {
data := protoutil.NewAny(&oauth21proto.AuthorizationRequest{})
_, err := storage.client.Put(ctx, &databroker.PutRequest{
Records: []*databroker.Record{{
Id: id,
Data: data,
Type: data.TypeUrl,
DeletedAt: timestamppb.Now(),
}},
})
if err != nil {
return fmt.Errorf("failed to delete authorization request by ID: %w", err)
}
return nil
}
func (storage *Storage) GetSession(ctx context.Context, id string) (*session.Session, error) {
v := new(session.Session)
rec, err := storage.client.Get(ctx, &databroker.GetRequest{
Type: protoutil.GetTypeURL(v),
Id: id,
})
if err != nil {
return nil, fmt.Errorf("failed to get session by ID: %w", err)
}
err = anypb.UnmarshalTo(rec.Record.Data, v, proto.UnmarshalOptions{})
if err != nil {
return nil, fmt.Errorf("failed to unmarshal session: %w", err)
}
return v, nil
}

View file

@ -1,10 +1,12 @@
package mcp
import (
"crypto/cipher"
"fmt"
"github.com/pomerium/pomerium/internal/oauth21"
"github.com/pomerium/pomerium/internal/oauth21/gen"
"github.com/pomerium/pomerium/pkg/grpc/session"
)
func AuthorizeTokenRequest(
@ -35,3 +37,7 @@ func AuthorizeTokenRequest(
return nil
}
func CreateAccessToken(src *session.Session, cipher cipher.AEAD) (string, error) {
return CreateCode(CodeTypeAccess, src.Id, src.ExpiresAt.AsTime(), "", cipher)
}

View file

@ -134,6 +134,93 @@ func (x *TokenRequest) GetClientSecret() string {
return ""
}
// Represents a successful response from the Token Endpoint (Section 3.2.3).
type TokenResponse struct {
state protoimpl.MessageState `protogen:"open.v1"`
// REQUIRED. The access token issued by the authorization server.
AccessToken string `protobuf:"bytes,1,opt,name=access_token,json=accessToken,proto3" json:"access_token,omitempty"`
// REQUIRED. The type of the token issued (e.g., "Bearer"). Value is case-insensitive.
// See Section 1.4 and Section 6.1.
TokenType string `protobuf:"bytes,2,opt,name=token_type,json=tokenType,proto3" json:"token_type,omitempty"`
// RECOMMENDED. The lifetime in seconds of the access token.
// If omitted, the AS should provide expiration via other means or document the default.
ExpiresIn *int64 `protobuf:"varint,3,opt,name=expires_in,json=expiresIn,proto3,oneof" json:"expires_in,omitempty"`
// OPTIONAL. The refresh token, which can be used to obtain new access tokens.
// Issued based on AS policy and the original grant type.
RefreshToken *string `protobuf:"bytes,4,opt,name=refresh_token,json=refreshToken,proto3,oneof" json:"refresh_token,omitempty"`
// RECOMMENDED if the issued scope is identical to the scope requested by the client,
// otherwise REQUIRED. The scope of the access token granted. Space-delimited list.
// See Section 1.4.1.
Scope *string `protobuf:"bytes,5,opt,name=scope,proto3,oneof" json:"scope,omitempty"`
unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache
}
func (x *TokenResponse) Reset() {
*x = TokenResponse{}
mi := &file_token_proto_msgTypes[1]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
func (x *TokenResponse) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*TokenResponse) ProtoMessage() {}
func (x *TokenResponse) ProtoReflect() protoreflect.Message {
mi := &file_token_proto_msgTypes[1]
if x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use TokenResponse.ProtoReflect.Descriptor instead.
func (*TokenResponse) Descriptor() ([]byte, []int) {
return file_token_proto_rawDescGZIP(), []int{1}
}
func (x *TokenResponse) GetAccessToken() string {
if x != nil {
return x.AccessToken
}
return ""
}
func (x *TokenResponse) GetTokenType() string {
if x != nil {
return x.TokenType
}
return ""
}
func (x *TokenResponse) GetExpiresIn() int64 {
if x != nil && x.ExpiresIn != nil {
return *x.ExpiresIn
}
return 0
}
func (x *TokenResponse) GetRefreshToken() string {
if x != nil && x.RefreshToken != nil {
return *x.RefreshToken
}
return ""
}
func (x *TokenResponse) GetScope() string {
if x != nil && x.Scope != nil {
return *x.Scope
}
return ""
}
var File_token_proto protoreflect.FileDescriptor
var file_token_proto_rawDesc = string([]byte{
@ -192,16 +279,33 @@ var file_token_proto_rawDesc = string([]byte{
0x63, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x5f, 0x69, 0x64, 0x42, 0x10, 0x0a, 0x0e, 0x5f, 0x72, 0x65,
0x66, 0x72, 0x65, 0x73, 0x68, 0x5f, 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x42, 0x08, 0x0a, 0x06, 0x5f,
0x73, 0x63, 0x6f, 0x70, 0x65, 0x42, 0x10, 0x0a, 0x0e, 0x5f, 0x63, 0x6c, 0x69, 0x65, 0x6e, 0x74,
0x5f, 0x73, 0x65, 0x63, 0x72, 0x65, 0x74, 0x42, 0x88, 0x01, 0x0a, 0x0b, 0x63, 0x6f, 0x6d, 0x2e,
0x6f, 0x61, 0x75, 0x74, 0x68, 0x32, 0x31, 0x42, 0x0a, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x50, 0x72,
0x6f, 0x74, 0x6f, 0x50, 0x01, 0x5a, 0x31, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f,
0x6d, 0x2f, 0x70, 0x6f, 0x6d, 0x65, 0x72, 0x69, 0x75, 0x6d, 0x2f, 0x70, 0x6f, 0x6d, 0x65, 0x72,
0x69, 0x75, 0x6d, 0x2f, 0x69, 0x6e, 0x74, 0x65, 0x72, 0x6e, 0x61, 0x6c, 0x2f, 0x6f, 0x61, 0x75,
0x74, 0x68, 0x32, 0x31, 0x2f, 0x67, 0x65, 0x6e, 0xa2, 0x02, 0x03, 0x4f, 0x58, 0x58, 0xaa, 0x02,
0x07, 0x4f, 0x61, 0x75, 0x74, 0x68, 0x32, 0x31, 0xca, 0x02, 0x07, 0x4f, 0x61, 0x75, 0x74, 0x68,
0x32, 0x31, 0xe2, 0x02, 0x13, 0x4f, 0x61, 0x75, 0x74, 0x68, 0x32, 0x31, 0x5c, 0x47, 0x50, 0x42,
0x4d, 0x65, 0x74, 0x61, 0x64, 0x61, 0x74, 0x61, 0xea, 0x02, 0x07, 0x4f, 0x61, 0x75, 0x74, 0x68,
0x32, 0x31, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33,
0x5f, 0x73, 0x65, 0x63, 0x72, 0x65, 0x74, 0x22, 0x92, 0x02, 0x0a, 0x0d, 0x54, 0x6f, 0x6b, 0x65,
0x6e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x2a, 0x0a, 0x0c, 0x61, 0x63, 0x63,
0x65, 0x73, 0x73, 0x5f, 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x42,
0x07, 0xba, 0x48, 0x04, 0x72, 0x02, 0x10, 0x01, 0x52, 0x0b, 0x61, 0x63, 0x63, 0x65, 0x73, 0x73,
0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x12, 0x26, 0x0a, 0x0a, 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x5f, 0x74,
0x79, 0x70, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x42, 0x07, 0xba, 0x48, 0x04, 0x72, 0x02,
0x10, 0x01, 0x52, 0x09, 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x54, 0x79, 0x70, 0x65, 0x12, 0x2b, 0x0a,
0x0a, 0x65, 0x78, 0x70, 0x69, 0x72, 0x65, 0x73, 0x5f, 0x69, 0x6e, 0x18, 0x03, 0x20, 0x01, 0x28,
0x03, 0x42, 0x07, 0xba, 0x48, 0x04, 0x22, 0x02, 0x28, 0x00, 0x48, 0x00, 0x52, 0x09, 0x65, 0x78,
0x70, 0x69, 0x72, 0x65, 0x73, 0x49, 0x6e, 0x88, 0x01, 0x01, 0x12, 0x31, 0x0a, 0x0d, 0x72, 0x65,
0x66, 0x72, 0x65, 0x73, 0x68, 0x5f, 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x18, 0x04, 0x20, 0x01, 0x28,
0x09, 0x42, 0x07, 0xba, 0x48, 0x04, 0x72, 0x02, 0x10, 0x01, 0x48, 0x01, 0x52, 0x0c, 0x72, 0x65,
0x66, 0x72, 0x65, 0x73, 0x68, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x88, 0x01, 0x01, 0x12, 0x22, 0x0a,
0x05, 0x73, 0x63, 0x6f, 0x70, 0x65, 0x18, 0x05, 0x20, 0x01, 0x28, 0x09, 0x42, 0x07, 0xba, 0x48,
0x04, 0x72, 0x02, 0x10, 0x01, 0x48, 0x02, 0x52, 0x05, 0x73, 0x63, 0x6f, 0x70, 0x65, 0x88, 0x01,
0x01, 0x42, 0x0d, 0x0a, 0x0b, 0x5f, 0x65, 0x78, 0x70, 0x69, 0x72, 0x65, 0x73, 0x5f, 0x69, 0x6e,
0x42, 0x10, 0x0a, 0x0e, 0x5f, 0x72, 0x65, 0x66, 0x72, 0x65, 0x73, 0x68, 0x5f, 0x74, 0x6f, 0x6b,
0x65, 0x6e, 0x42, 0x08, 0x0a, 0x06, 0x5f, 0x73, 0x63, 0x6f, 0x70, 0x65, 0x42, 0x88, 0x01, 0x0a,
0x0b, 0x63, 0x6f, 0x6d, 0x2e, 0x6f, 0x61, 0x75, 0x74, 0x68, 0x32, 0x31, 0x42, 0x0a, 0x54, 0x6f,
0x6b, 0x65, 0x6e, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x50, 0x01, 0x5a, 0x31, 0x67, 0x69, 0x74, 0x68,
0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x70, 0x6f, 0x6d, 0x65, 0x72, 0x69, 0x75, 0x6d, 0x2f,
0x70, 0x6f, 0x6d, 0x65, 0x72, 0x69, 0x75, 0x6d, 0x2f, 0x69, 0x6e, 0x74, 0x65, 0x72, 0x6e, 0x61,
0x6c, 0x2f, 0x6f, 0x61, 0x75, 0x74, 0x68, 0x32, 0x31, 0x2f, 0x67, 0x65, 0x6e, 0xa2, 0x02, 0x03,
0x4f, 0x58, 0x58, 0xaa, 0x02, 0x07, 0x4f, 0x61, 0x75, 0x74, 0x68, 0x32, 0x31, 0xca, 0x02, 0x07,
0x4f, 0x61, 0x75, 0x74, 0x68, 0x32, 0x31, 0xe2, 0x02, 0x13, 0x4f, 0x61, 0x75, 0x74, 0x68, 0x32,
0x31, 0x5c, 0x47, 0x50, 0x42, 0x4d, 0x65, 0x74, 0x61, 0x64, 0x61, 0x74, 0x61, 0xea, 0x02, 0x07,
0x4f, 0x61, 0x75, 0x74, 0x68, 0x32, 0x31, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33,
})
var (
@ -216,9 +320,10 @@ func file_token_proto_rawDescGZIP() []byte {
return file_token_proto_rawDescData
}
var file_token_proto_msgTypes = make([]protoimpl.MessageInfo, 1)
var file_token_proto_msgTypes = make([]protoimpl.MessageInfo, 2)
var file_token_proto_goTypes = []any{
(*TokenRequest)(nil), // 0: oauth21.TokenRequest
(*TokenRequest)(nil), // 0: oauth21.TokenRequest
(*TokenResponse)(nil), // 1: oauth21.TokenResponse
}
var file_token_proto_depIdxs = []int32{
0, // [0:0] is the sub-list for method output_type
@ -234,13 +339,14 @@ func file_token_proto_init() {
return
}
file_token_proto_msgTypes[0].OneofWrappers = []any{}
file_token_proto_msgTypes[1].OneofWrappers = []any{}
type x struct{}
out := protoimpl.TypeBuilder{
File: protoimpl.DescBuilder{
GoPackagePath: reflect.TypeOf(x{}).PkgPath(),
RawDescriptor: unsafe.Slice(unsafe.StringData(file_token_proto_rawDesc), len(file_token_proto_rawDesc)),
NumEnums: 0,
NumMessages: 1,
NumMessages: 2,
NumExtensions: 0,
NumServices: 0,
},

View file

@ -81,3 +81,28 @@ message TokenRequest {
// The client secret.
optional string client_secret = 7 [(buf.validate.field).string.min_len = 1];
}
// Represents a successful response from the Token Endpoint (Section 3.2.3).
message TokenResponse {
// REQUIRED. The access token issued by the authorization server.
string access_token = 1 [(buf.validate.field).string.min_len = 1];
// REQUIRED. The type of the token issued (e.g., "Bearer"). Value is case-insensitive.
// See Section 1.4 and Section 6.1.
string token_type = 2 [(buf.validate.field).string.min_len = 1];
// RECOMMENDED. The lifetime in seconds of the access token.
// If omitted, the AS should provide expiration via other means or document the default.
optional int64 expires_in = 3 [(buf.validate.field).int64.gte = 0];
// OPTIONAL. The refresh token, which can be used to obtain new access tokens.
// Issued based on AS policy and the original grant type.
optional string refresh_token = 4 [(buf.validate.field).string.min_len = 1];
// RECOMMENDED if the issued scope is identical to the scope requested by the client,
// otherwise REQUIRED. The scope of the access token granted. Space-delimited list.
// See Section 1.4.1.
optional string scope = 5 [(buf.validate.field).string = {
min_len: 1,
}];
}