From ea5badda77861ed285072899b66904c6e28722c9 Mon Sep 17 00:00:00 2001 From: Denis Mishin Date: Fri, 25 Apr 2025 15:43:42 -0400 Subject: [PATCH 1/3] mcp: token: handle authorization_code request (pt2) --- internal/mcp/handler_token.go | 41 ++++++++- internal/mcp/storage.go | 19 +++++ internal/mcp/token.go | 6 ++ internal/oauth21/gen/token.pb.go | 132 ++++++++++++++++++++++++++--- internal/oauth21/proto/token.proto | 25 ++++++ 5 files changed, 209 insertions(+), 14 deletions(-) diff --git a/internal/mcp/handler_token.go b/internal/mcp/handler_token.go index a601a746a..1efa38e07 100644 --- a/internal/mcp/handler_token.go +++ b/internal/mcp/handler_token.go @@ -1,9 +1,14 @@ package mcp import ( + "encoding/json" "net/http" "time" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" + "google.golang.org/protobuf/proto" + "github.com/pomerium/pomerium/internal/log" "github.com/pomerium/pomerium/internal/oauth21" oauth21proto "github.com/pomerium/pomerium/internal/oauth21/gen" @@ -61,5 +66,39 @@ func (srv *Handler) handleAuthorizationCodeToken(w http.ResponseWriter, r *http. return } - http.Error(w, "Not Implemented", http.StatusNotImplemented) + session, err := srv.storage.GetSession(ctx, authReq.SessionId) + if status.Code(err) == codes.NotFound { + oauth21.ErrorResponse(w, http.StatusBadRequest, oauth21.InvalidGrant) + return + } + + accessToken, err := CreateAccessToken(session, srv.cipher) + if err != nil { + http.Error(w, "internal error", http.StatusInternalServerError) + return + } + + expiresIn := time.Until(session.ExpiresAt.AsTime()) + if expiresIn < 0 { + oauth21.ErrorResponse(w, http.StatusBadRequest, oauth21.InvalidGrant) + return + } + + resp := &oauth21proto.TokenResponse{ + AccessToken: accessToken, + TokenType: "Bearer", + ExpiresIn: proto.Int64(int64(expiresIn.Seconds())), + } + + data, err := json.Marshal(resp) // not using protojson.Marshal here because it emits numbers as strings, which is valid, but for some reason Node.js / mcp typescript SDK doesn't like it + if err != nil { + log.Ctx(ctx).Error().Err(err).Msg("failed to marshal token response") + http.Error(w, "internal error", http.StatusInternalServerError) + return + } + w.Header().Set("Content-Type", "application/json") + w.Header().Set("Cache-Control", "no-store") + w.Header().Set("Pragma", "no-cache") + w.WriteHeader(http.StatusOK) + _, _ = w.Write(data) } diff --git a/internal/mcp/storage.go b/internal/mcp/storage.go index 3fd3d0e4f..1bc12f569 100644 --- a/internal/mcp/storage.go +++ b/internal/mcp/storage.go @@ -11,6 +11,7 @@ import ( oauth21proto "github.com/pomerium/pomerium/internal/oauth21/gen" rfc7591v1 "github.com/pomerium/pomerium/internal/rfc7591" "github.com/pomerium/pomerium/pkg/grpc/databroker" + "github.com/pomerium/pomerium/pkg/grpc/session" "github.com/pomerium/pomerium/pkg/protoutil" ) @@ -106,3 +107,21 @@ func (storage *Storage) GetAuthorizationRequest( return v, nil } + +func (storage *Storage) GetSession(ctx context.Context, id string) (*session.Session, error) { + v := new(session.Session) + rec, err := storage.client.Get(ctx, &databroker.GetRequest{ + Type: protoutil.GetTypeURL(v), + Id: id, + }) + if err != nil { + return nil, fmt.Errorf("failed to get session by ID: %w", err) + } + + err = anypb.UnmarshalTo(rec.Record.Data, v, proto.UnmarshalOptions{}) + if err != nil { + return nil, fmt.Errorf("failed to unmarshal session: %w", err) + } + + return v, nil +} diff --git a/internal/mcp/token.go b/internal/mcp/token.go index 238cecc38..1d942ace9 100644 --- a/internal/mcp/token.go +++ b/internal/mcp/token.go @@ -1,10 +1,12 @@ package mcp import ( + "crypto/cipher" "fmt" "github.com/pomerium/pomerium/internal/oauth21" "github.com/pomerium/pomerium/internal/oauth21/gen" + "github.com/pomerium/pomerium/pkg/grpc/session" ) func AuthorizeTokenRequest( @@ -35,3 +37,7 @@ func AuthorizeTokenRequest( return nil } + +func CreateAccessToken(src *session.Session, cipher cipher.AEAD) (string, error) { + return CreateCode(CodeTypeAccess, src.Id, src.ExpiresAt.AsTime(), "", cipher) +} diff --git a/internal/oauth21/gen/token.pb.go b/internal/oauth21/gen/token.pb.go index f3d04c0ef..6b26c2fa1 100644 --- a/internal/oauth21/gen/token.pb.go +++ b/internal/oauth21/gen/token.pb.go @@ -134,6 +134,93 @@ func (x *TokenRequest) GetClientSecret() string { return "" } +// Represents a successful response from the Token Endpoint (Section 3.2.3). +type TokenResponse struct { + state protoimpl.MessageState `protogen:"open.v1"` + // REQUIRED. The access token issued by the authorization server. + AccessToken string `protobuf:"bytes,1,opt,name=access_token,json=accessToken,proto3" json:"access_token,omitempty"` + // REQUIRED. The type of the token issued (e.g., "Bearer"). Value is case-insensitive. + // See Section 1.4 and Section 6.1. + TokenType string `protobuf:"bytes,2,opt,name=token_type,json=tokenType,proto3" json:"token_type,omitempty"` + // RECOMMENDED. The lifetime in seconds of the access token. + // If omitted, the AS should provide expiration via other means or document the default. + ExpiresIn *int64 `protobuf:"varint,3,opt,name=expires_in,json=expiresIn,proto3,oneof" json:"expires_in,omitempty"` + // OPTIONAL. The refresh token, which can be used to obtain new access tokens. + // Issued based on AS policy and the original grant type. + RefreshToken *string `protobuf:"bytes,4,opt,name=refresh_token,json=refreshToken,proto3,oneof" json:"refresh_token,omitempty"` + // RECOMMENDED if the issued scope is identical to the scope requested by the client, + // otherwise REQUIRED. The scope of the access token granted. Space-delimited list. + // See Section 1.4.1. + Scope *string `protobuf:"bytes,5,opt,name=scope,proto3,oneof" json:"scope,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *TokenResponse) Reset() { + *x = TokenResponse{} + mi := &file_token_proto_msgTypes[1] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *TokenResponse) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*TokenResponse) ProtoMessage() {} + +func (x *TokenResponse) ProtoReflect() protoreflect.Message { + mi := &file_token_proto_msgTypes[1] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use TokenResponse.ProtoReflect.Descriptor instead. +func (*TokenResponse) Descriptor() ([]byte, []int) { + return file_token_proto_rawDescGZIP(), []int{1} +} + +func (x *TokenResponse) GetAccessToken() string { + if x != nil { + return x.AccessToken + } + return "" +} + +func (x *TokenResponse) GetTokenType() string { + if x != nil { + return x.TokenType + } + return "" +} + +func (x *TokenResponse) GetExpiresIn() int64 { + if x != nil && x.ExpiresIn != nil { + return *x.ExpiresIn + } + return 0 +} + +func (x *TokenResponse) GetRefreshToken() string { + if x != nil && x.RefreshToken != nil { + return *x.RefreshToken + } + return "" +} + +func (x *TokenResponse) GetScope() string { + if x != nil && x.Scope != nil { + return *x.Scope + } + return "" +} + var File_token_proto protoreflect.FileDescriptor var file_token_proto_rawDesc = string([]byte{ @@ -192,16 +279,33 @@ var file_token_proto_rawDesc = string([]byte{ 0x63, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x5f, 0x69, 0x64, 0x42, 0x10, 0x0a, 0x0e, 0x5f, 0x72, 0x65, 0x66, 0x72, 0x65, 0x73, 0x68, 0x5f, 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x42, 0x08, 0x0a, 0x06, 0x5f, 0x73, 0x63, 0x6f, 0x70, 0x65, 0x42, 0x10, 0x0a, 0x0e, 0x5f, 0x63, 0x6c, 0x69, 0x65, 0x6e, 0x74, - 0x5f, 0x73, 0x65, 0x63, 0x72, 0x65, 0x74, 0x42, 0x88, 0x01, 0x0a, 0x0b, 0x63, 0x6f, 0x6d, 0x2e, - 0x6f, 0x61, 0x75, 0x74, 0x68, 0x32, 0x31, 0x42, 0x0a, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x50, 0x72, - 0x6f, 0x74, 0x6f, 0x50, 0x01, 0x5a, 0x31, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, - 0x6d, 0x2f, 0x70, 0x6f, 0x6d, 0x65, 0x72, 0x69, 0x75, 0x6d, 0x2f, 0x70, 0x6f, 0x6d, 0x65, 0x72, - 0x69, 0x75, 0x6d, 0x2f, 0x69, 0x6e, 0x74, 0x65, 0x72, 0x6e, 0x61, 0x6c, 0x2f, 0x6f, 0x61, 0x75, - 0x74, 0x68, 0x32, 0x31, 0x2f, 0x67, 0x65, 0x6e, 0xa2, 0x02, 0x03, 0x4f, 0x58, 0x58, 0xaa, 0x02, - 0x07, 0x4f, 0x61, 0x75, 0x74, 0x68, 0x32, 0x31, 0xca, 0x02, 0x07, 0x4f, 0x61, 0x75, 0x74, 0x68, - 0x32, 0x31, 0xe2, 0x02, 0x13, 0x4f, 0x61, 0x75, 0x74, 0x68, 0x32, 0x31, 0x5c, 0x47, 0x50, 0x42, - 0x4d, 0x65, 0x74, 0x61, 0x64, 0x61, 0x74, 0x61, 0xea, 0x02, 0x07, 0x4f, 0x61, 0x75, 0x74, 0x68, - 0x32, 0x31, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, + 0x5f, 0x73, 0x65, 0x63, 0x72, 0x65, 0x74, 0x22, 0x92, 0x02, 0x0a, 0x0d, 0x54, 0x6f, 0x6b, 0x65, + 0x6e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x2a, 0x0a, 0x0c, 0x61, 0x63, 0x63, + 0x65, 0x73, 0x73, 0x5f, 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x42, + 0x07, 0xba, 0x48, 0x04, 0x72, 0x02, 0x10, 0x01, 0x52, 0x0b, 0x61, 0x63, 0x63, 0x65, 0x73, 0x73, + 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x12, 0x26, 0x0a, 0x0a, 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x5f, 0x74, + 0x79, 0x70, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x42, 0x07, 0xba, 0x48, 0x04, 0x72, 0x02, + 0x10, 0x01, 0x52, 0x09, 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x54, 0x79, 0x70, 0x65, 0x12, 0x2b, 0x0a, + 0x0a, 0x65, 0x78, 0x70, 0x69, 0x72, 0x65, 0x73, 0x5f, 0x69, 0x6e, 0x18, 0x03, 0x20, 0x01, 0x28, + 0x03, 0x42, 0x07, 0xba, 0x48, 0x04, 0x22, 0x02, 0x28, 0x00, 0x48, 0x00, 0x52, 0x09, 0x65, 0x78, + 0x70, 0x69, 0x72, 0x65, 0x73, 0x49, 0x6e, 0x88, 0x01, 0x01, 0x12, 0x31, 0x0a, 0x0d, 0x72, 0x65, + 0x66, 0x72, 0x65, 0x73, 0x68, 0x5f, 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x18, 0x04, 0x20, 0x01, 0x28, + 0x09, 0x42, 0x07, 0xba, 0x48, 0x04, 0x72, 0x02, 0x10, 0x01, 0x48, 0x01, 0x52, 0x0c, 0x72, 0x65, + 0x66, 0x72, 0x65, 0x73, 0x68, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x88, 0x01, 0x01, 0x12, 0x22, 0x0a, + 0x05, 0x73, 0x63, 0x6f, 0x70, 0x65, 0x18, 0x05, 0x20, 0x01, 0x28, 0x09, 0x42, 0x07, 0xba, 0x48, + 0x04, 0x72, 0x02, 0x10, 0x01, 0x48, 0x02, 0x52, 0x05, 0x73, 0x63, 0x6f, 0x70, 0x65, 0x88, 0x01, + 0x01, 0x42, 0x0d, 0x0a, 0x0b, 0x5f, 0x65, 0x78, 0x70, 0x69, 0x72, 0x65, 0x73, 0x5f, 0x69, 0x6e, + 0x42, 0x10, 0x0a, 0x0e, 0x5f, 0x72, 0x65, 0x66, 0x72, 0x65, 0x73, 0x68, 0x5f, 0x74, 0x6f, 0x6b, + 0x65, 0x6e, 0x42, 0x08, 0x0a, 0x06, 0x5f, 0x73, 0x63, 0x6f, 0x70, 0x65, 0x42, 0x88, 0x01, 0x0a, + 0x0b, 0x63, 0x6f, 0x6d, 0x2e, 0x6f, 0x61, 0x75, 0x74, 0x68, 0x32, 0x31, 0x42, 0x0a, 0x54, 0x6f, + 0x6b, 0x65, 0x6e, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x50, 0x01, 0x5a, 0x31, 0x67, 0x69, 0x74, 0x68, + 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x70, 0x6f, 0x6d, 0x65, 0x72, 0x69, 0x75, 0x6d, 0x2f, + 0x70, 0x6f, 0x6d, 0x65, 0x72, 0x69, 0x75, 0x6d, 0x2f, 0x69, 0x6e, 0x74, 0x65, 0x72, 0x6e, 0x61, + 0x6c, 0x2f, 0x6f, 0x61, 0x75, 0x74, 0x68, 0x32, 0x31, 0x2f, 0x67, 0x65, 0x6e, 0xa2, 0x02, 0x03, + 0x4f, 0x58, 0x58, 0xaa, 0x02, 0x07, 0x4f, 0x61, 0x75, 0x74, 0x68, 0x32, 0x31, 0xca, 0x02, 0x07, + 0x4f, 0x61, 0x75, 0x74, 0x68, 0x32, 0x31, 0xe2, 0x02, 0x13, 0x4f, 0x61, 0x75, 0x74, 0x68, 0x32, + 0x31, 0x5c, 0x47, 0x50, 0x42, 0x4d, 0x65, 0x74, 0x61, 0x64, 0x61, 0x74, 0x61, 0xea, 0x02, 0x07, + 0x4f, 0x61, 0x75, 0x74, 0x68, 0x32, 0x31, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, }) var ( @@ -216,9 +320,10 @@ func file_token_proto_rawDescGZIP() []byte { return file_token_proto_rawDescData } -var file_token_proto_msgTypes = make([]protoimpl.MessageInfo, 1) +var file_token_proto_msgTypes = make([]protoimpl.MessageInfo, 2) var file_token_proto_goTypes = []any{ - (*TokenRequest)(nil), // 0: oauth21.TokenRequest + (*TokenRequest)(nil), // 0: oauth21.TokenRequest + (*TokenResponse)(nil), // 1: oauth21.TokenResponse } var file_token_proto_depIdxs = []int32{ 0, // [0:0] is the sub-list for method output_type @@ -234,13 +339,14 @@ func file_token_proto_init() { return } file_token_proto_msgTypes[0].OneofWrappers = []any{} + file_token_proto_msgTypes[1].OneofWrappers = []any{} type x struct{} out := protoimpl.TypeBuilder{ File: protoimpl.DescBuilder{ GoPackagePath: reflect.TypeOf(x{}).PkgPath(), RawDescriptor: unsafe.Slice(unsafe.StringData(file_token_proto_rawDesc), len(file_token_proto_rawDesc)), NumEnums: 0, - NumMessages: 1, + NumMessages: 2, NumExtensions: 0, NumServices: 0, }, diff --git a/internal/oauth21/proto/token.proto b/internal/oauth21/proto/token.proto index 430c7b29e..d3e0eec52 100644 --- a/internal/oauth21/proto/token.proto +++ b/internal/oauth21/proto/token.proto @@ -81,3 +81,28 @@ message TokenRequest { // The client secret. optional string client_secret = 7 [(buf.validate.field).string.min_len = 1]; } + +// Represents a successful response from the Token Endpoint (Section 3.2.3). +message TokenResponse { + // REQUIRED. The access token issued by the authorization server. + string access_token = 1 [(buf.validate.field).string.min_len = 1]; + + // REQUIRED. The type of the token issued (e.g., "Bearer"). Value is case-insensitive. + // See Section 1.4 and Section 6.1. + string token_type = 2 [(buf.validate.field).string.min_len = 1]; + + // RECOMMENDED. The lifetime in seconds of the access token. + // If omitted, the AS should provide expiration via other means or document the default. + optional int64 expires_in = 3 [(buf.validate.field).int64.gte = 0]; + + // OPTIONAL. The refresh token, which can be used to obtain new access tokens. + // Issued based on AS policy and the original grant type. + optional string refresh_token = 4 [(buf.validate.field).string.min_len = 1]; + + // RECOMMENDED if the issued scope is identical to the scope requested by the client, + // otherwise REQUIRED. The scope of the access token granted. Space-delimited list. + // See Section 1.4.1. + optional string scope = 5 [(buf.validate.field).string = { + min_len: 1, + }]; +} From f89babc585a5ee29961deb52323f0485568201dc Mon Sep 17 00:00:00 2001 From: Denis Mishin Date: Fri, 25 Apr 2025 15:51:16 -0400 Subject: [PATCH 2/3] rm authorization request --- internal/mcp/handler_token.go | 9 ++++++++- internal/mcp/storage.go | 20 ++++++++++++++++++++ 2 files changed, 28 insertions(+), 1 deletion(-) diff --git a/internal/mcp/handler_token.go b/internal/mcp/handler_token.go index 1efa38e07..53d418576 100644 --- a/internal/mcp/handler_token.go +++ b/internal/mcp/handler_token.go @@ -66,6 +66,14 @@ func (srv *Handler) handleAuthorizationCodeToken(w http.ResponseWriter, r *http. return } + // The authorization server MUST return an access token only once for a given authorization code. + // https://datatracker.ietf.org/doc/html/draft-ietf-oauth-v2-1-12#section-4.1.3 + err = srv.storage.DeleteAuthorizationRequest(ctx, code.Id) + if err != nil { + http.Error(w, "internal error", http.StatusInternalServerError) + return + } + session, err := srv.storage.GetSession(ctx, authReq.SessionId) if status.Code(err) == codes.NotFound { oauth21.ErrorResponse(w, http.StatusBadRequest, oauth21.InvalidGrant) @@ -92,7 +100,6 @@ func (srv *Handler) handleAuthorizationCodeToken(w http.ResponseWriter, r *http. data, err := json.Marshal(resp) // not using protojson.Marshal here because it emits numbers as strings, which is valid, but for some reason Node.js / mcp typescript SDK doesn't like it if err != nil { - log.Ctx(ctx).Error().Err(err).Msg("failed to marshal token response") http.Error(w, "internal error", http.StatusInternalServerError) return } diff --git a/internal/mcp/storage.go b/internal/mcp/storage.go index 1bc12f569..69da0824d 100644 --- a/internal/mcp/storage.go +++ b/internal/mcp/storage.go @@ -7,6 +7,7 @@ import ( "github.com/google/uuid" "google.golang.org/protobuf/proto" "google.golang.org/protobuf/types/known/anypb" + "google.golang.org/protobuf/types/known/timestamppb" oauth21proto "github.com/pomerium/pomerium/internal/oauth21/gen" rfc7591v1 "github.com/pomerium/pomerium/internal/rfc7591" @@ -108,6 +109,25 @@ func (storage *Storage) GetAuthorizationRequest( return v, nil } +func (storage *Storage) DeleteAuthorizationRequest( + ctx context.Context, + id string, +) error { + data := protoutil.NewAny(&oauth21proto.AuthorizationRequest{}) + _, err := storage.client.Put(ctx, &databroker.PutRequest{ + Records: []*databroker.Record{{ + Id: id, + Data: data, + Type: data.TypeUrl, + DeletedAt: timestamppb.Now(), + }}, + }) + if err != nil { + return fmt.Errorf("failed to delete authorization request by ID: %w", err) + } + return nil +} + func (storage *Storage) GetSession(ctx context.Context, id string) (*session.Session, error) { v := new(session.Session) rec, err := storage.client.Get(ctx, &databroker.GetRequest{ From 0478d46568c7ec9ffd8b8a1882a4e14b928ea558 Mon Sep 17 00:00:00 2001 From: Denis Mishin Date: Fri, 25 Apr 2025 15:52:23 -0400 Subject: [PATCH 3/3] return invalid_grant when authorization request not found --- internal/mcp/handler_token.go | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/internal/mcp/handler_token.go b/internal/mcp/handler_token.go index 53d418576..1e03d1861 100644 --- a/internal/mcp/handler_token.go +++ b/internal/mcp/handler_token.go @@ -55,10 +55,13 @@ func (srv *Handler) handleAuthorizationCodeToken(w http.ResponseWriter, r *http. } authReq, err := srv.storage.GetAuthorizationRequest(ctx, code.Id) - if err != nil { + if status.Code(err) == codes.NotFound { oauth21.ErrorResponse(w, http.StatusBadRequest, oauth21.InvalidGrant) return } + if err != nil { + http.Error(w, "internal error", http.StatusInternalServerError) + } err = AuthorizeTokenRequest(req, authReq) if err != nil {