mcp: authorize: response with code

This commit is contained in:
Denis Mishin 2025-04-24 10:23:30 -04:00
parent 52af622cc4
commit 565f08db49
7 changed files with 346 additions and 4 deletions

50
internal/mcp/cipher.go Normal file
View file

@ -0,0 +1,50 @@
package mcp
import (
"crypto/cipher"
"crypto/sha256"
"fmt"
"io"
"golang.org/x/crypto/hkdf"
"github.com/pomerium/pomerium/config"
"github.com/pomerium/pomerium/pkg/cryptutil"
)
func getCipher(
cfg *config.Config,
) (cipher.AEAD, error) {
secret, err := cfg.Options.GetSharedKey()
if err != nil {
return nil, fmt.Errorf("shared key: %w", err)
}
rnd := hkdf.New(sha256.New, secret, nil, nil)
cipher, err := initCipher(rnd)
if err != nil {
return nil, fmt.Errorf("new aead cipher: %w", err)
}
return cipher, nil
}
func readKey(r io.Reader) ([]byte, error) {
b := make([]byte, cryptutil.DefaultKeySize)
_, err := io.ReadFull(r, b)
if err != nil {
return nil, fmt.Errorf("read from hkdf: %w", err)
}
return b, nil
}
func initCipher(r io.Reader) (cipher.AEAD, error) {
cipherKey, err := readKey(r)
if err != nil {
return nil, fmt.Errorf("read key: %w", err)
}
cipher, err := cryptutil.NewAEADCipher(cipherKey)
if err != nil {
return nil, fmt.Errorf("new aead cipher: %w", err)
}
return cipher, nil
}

68
internal/mcp/code.go Normal file
View file

@ -0,0 +1,68 @@
package mcp
import (
"crypto/cipher"
"encoding/base64"
"fmt"
"time"
"github.com/bufbuild/protovalidate-go"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/types/known/timestamppb"
oauth21proto "github.com/pomerium/pomerium/internal/oauth21/gen"
"github.com/pomerium/pomerium/pkg/cryptutil"
)
func CreateCode(
id string,
expires time.Time,
ad string,
cipher cipher.AEAD,
) (string, error) {
v := oauth21proto.Code{
Id: id,
ExpiresAt: timestamppb.New(expires),
}
err := protovalidate.Validate(&v)
if err != nil {
return "", fmt.Errorf("validate: %w", err)
}
b, err := proto.Marshal(&v)
if err != nil {
return "", err
}
ciphertext := cryptutil.Encrypt(cipher, b, []byte(ad))
return base64.StdEncoding.EncodeToString(ciphertext), nil
}
func DecryptCode(
code string,
cipher cipher.AEAD,
ad string,
now time.Time,
) (*oauth21proto.Code, error) {
b, err := base64.StdEncoding.DecodeString(code)
if err != nil {
return nil, fmt.Errorf("base64 decode: %w", err)
}
plaintext, err := cryptutil.Decrypt(cipher, b, []byte(ad))
if err != nil {
return nil, fmt.Errorf("decrypt: %w", err)
}
var v oauth21proto.Code
err = proto.Unmarshal(plaintext, &v)
if err != nil {
return nil, fmt.Errorf("unmarshal: %w", err)
}
if v.ExpiresAt == nil {
return nil, fmt.Errorf("expiration is nil")
}
if v.ExpiresAt.AsTime().Before(now) {
return nil, fmt.Errorf("code expired")
}
return &v, nil
}

View file

@ -2,6 +2,7 @@ package mcp
import (
"context"
"crypto/cipher"
"fmt"
"net/http"
"path"
@ -32,6 +33,7 @@ type Handler struct {
prefix string
trace oteltrace.TracerProvider
storage *Storage
cipher cipher.AEAD
}
func New(
@ -46,10 +48,16 @@ func New(
return nil, fmt.Errorf("databroker client: %w", err)
}
cipher, err := getCipher(cfg)
if err != nil {
return nil, fmt.Errorf("get cipher: %w", err)
}
return &Handler{
prefix: prefix,
trace: tracerProvider,
storage: NewStorage(client),
cipher: cipher,
}, nil
}

View file

@ -1,14 +1,18 @@
package mcp
import (
"context"
"errors"
"net/http"
"net/url"
"time"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"github.com/pomerium/pomerium/internal/log"
"github.com/pomerium/pomerium/internal/oauth21"
oauth21proto "github.com/pomerium/pomerium/internal/oauth21/gen"
)
// Authorize handles the /authorize endpoint.
@ -29,6 +33,7 @@ func (srv *Handler) Authorize(w http.ResponseWriter, r *http.Request) {
client, err := srv.storage.GetClientByID(ctx, v.ClientId)
if err != nil && status.Code(err) == codes.NotFound {
log.Ctx(ctx).Error().Err(err).Str("id", v.ClientId).Msg("client id not found")
oauth21.ErrorResponse(w, http.StatusUnauthorized, oauth21.InvalidClient)
return
}
@ -46,12 +51,46 @@ func (srv *Handler) Authorize(w http.ResponseWriter, r *http.Request) {
return
}
_, err = srv.storage.CreateAuthorizationRequest(ctx, v)
id, err := srv.storage.CreateAuthorizationRequest(ctx, v)
if err != nil {
log.Ctx(ctx).Error().Err(err).Msg("failed to store authorization request")
http.Error(w, "cannot create authorization request", http.StatusInternalServerError)
return
}
http.Error(w, "not implemented", http.StatusNotImplemented)
srv.AuthorizationResponse(ctx, w, r, id, v)
}
// AuthorizationResponse generates the successful authorization response
// see https://datatracker.ietf.org/doc/html/draft-ietf-oauth-v2-1-12#section-4.1.2
func (srv *Handler) AuthorizationResponse(
ctx context.Context,
w http.ResponseWriter,
r *http.Request,
id string,
req *oauth21proto.AuthorizationRequest,
) {
code, err := CreateCode(
id,
time.Now().Add(time.Minute*10),
req.ClientId,
srv.cipher,
)
if err != nil {
log.Ctx(ctx).Error().Err(err).Msg("failed to create code")
http.Error(w, "internal server error", http.StatusInternalServerError)
return
}
to, err := url.Parse(req.GetRedirectUri())
if err != nil {
log.Ctx(ctx).Error().Err(err).Msg("failed to parse redirect uri")
http.Error(w, "invalid redirect uri", http.StatusBadRequest)
return
}
q := to.Query()
q.Set("code", code)
q.Set("state", req.GetState())
to.RawQuery = q.Encode()
http.Redirect(w, r, to.String(), http.StatusFound)
}

View file

@ -50,14 +50,16 @@ func (storage *Storage) GetClientByID(
ctx context.Context,
id string,
) (*rfc7591v1.ClientMetadata, error) {
v := new(rfc7591v1.ClientMetadata)
rec, err := storage.client.Get(ctx, &databroker.GetRequest{
Type: protoutil.GetTypeURL(v),
Id: id,
})
if err != nil {
return nil, fmt.Errorf("failed to get client by ID: %w", err)
}
v := new(rfc7591v1.ClientMetadata)
err = anypb.UnmarshalTo(rec.Record.Data, v, proto.UnmarshalOptions{})
if err != nil {
return nil, fmt.Errorf("failed to unmarshal client registration request: %w", err)

View file

@ -0,0 +1,155 @@
// Code generated by protoc-gen-go. DO NOT EDIT.
// versions:
// protoc-gen-go v1.36.5
// protoc (unknown)
// source: code.proto
package gen
import (
_ "buf.build/gen/go/bufbuild/protovalidate/protocolbuffers/go/buf/validate"
protoreflect "google.golang.org/protobuf/reflect/protoreflect"
protoimpl "google.golang.org/protobuf/runtime/protoimpl"
timestamppb "google.golang.org/protobuf/types/known/timestamppb"
reflect "reflect"
sync "sync"
unsafe "unsafe"
)
const (
// Verify that this generated code is sufficiently up-to-date.
_ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion)
// Verify that runtime/protoimpl is sufficiently up-to-date.
_ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20)
)
// Code is a code used in the authorization code flow.
type Code struct {
state protoimpl.MessageState `protogen:"open.v1"`
Id string `protobuf:"bytes,1,opt,name=id,proto3" json:"id,omitempty"`
ExpiresAt *timestamppb.Timestamp `protobuf:"bytes,2,opt,name=expires_at,json=expiresAt,proto3" json:"expires_at,omitempty"`
unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache
}
func (x *Code) Reset() {
*x = Code{}
mi := &file_code_proto_msgTypes[0]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
func (x *Code) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*Code) ProtoMessage() {}
func (x *Code) ProtoReflect() protoreflect.Message {
mi := &file_code_proto_msgTypes[0]
if x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use Code.ProtoReflect.Descriptor instead.
func (*Code) Descriptor() ([]byte, []int) {
return file_code_proto_rawDescGZIP(), []int{0}
}
func (x *Code) GetId() string {
if x != nil {
return x.Id
}
return ""
}
func (x *Code) GetExpiresAt() *timestamppb.Timestamp {
if x != nil {
return x.ExpiresAt
}
return nil
}
var File_code_proto protoreflect.FileDescriptor
var file_code_proto_rawDesc = string([]byte{
0x0a, 0x0a, 0x63, 0x6f, 0x64, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x12, 0x07, 0x6f, 0x61,
0x75, 0x74, 0x68, 0x32, 0x31, 0x1a, 0x1f, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2f, 0x70, 0x72,
0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2f, 0x74, 0x69, 0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d, 0x70,
0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x1a, 0x1b, 0x62, 0x75, 0x66, 0x2f, 0x76, 0x61, 0x6c, 0x69,
0x64, 0x61, 0x74, 0x65, 0x2f, 0x76, 0x61, 0x6c, 0x69, 0x64, 0x61, 0x74, 0x65, 0x2e, 0x70, 0x72,
0x6f, 0x74, 0x6f, 0x22, 0x65, 0x0a, 0x04, 0x43, 0x6f, 0x64, 0x65, 0x12, 0x1a, 0x0a, 0x02, 0x69,
0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x42, 0x0a, 0xba, 0x48, 0x07, 0xc8, 0x01, 0x01, 0x72,
0x02, 0x10, 0x01, 0x52, 0x02, 0x69, 0x64, 0x12, 0x41, 0x0a, 0x0a, 0x65, 0x78, 0x70, 0x69, 0x72,
0x65, 0x73, 0x5f, 0x61, 0x74, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x67, 0x6f,
0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x54, 0x69,
0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, 0x42, 0x06, 0xba, 0x48, 0x03, 0xc8, 0x01, 0x01, 0x52,
0x09, 0x65, 0x78, 0x70, 0x69, 0x72, 0x65, 0x73, 0x41, 0x74, 0x42, 0x9c, 0x01, 0x0a, 0x0b, 0x63,
0x6f, 0x6d, 0x2e, 0x6f, 0x61, 0x75, 0x74, 0x68, 0x32, 0x31, 0x42, 0x09, 0x43, 0x6f, 0x64, 0x65,
0x50, 0x72, 0x6f, 0x74, 0x6f, 0x50, 0x01, 0x5a, 0x46, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e,
0x63, 0x6f, 0x6d, 0x2f, 0x62, 0x75, 0x66, 0x62, 0x75, 0x69, 0x6c, 0x64, 0x2f, 0x62, 0x75, 0x66,
0x2d, 0x65, 0x78, 0x61, 0x6d, 0x70, 0x6c, 0x65, 0x73, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x76,
0x61, 0x6c, 0x69, 0x64, 0x61, 0x74, 0x65, 0x2f, 0x71, 0x75, 0x69, 0x63, 0x6b, 0x73, 0x74, 0x61,
0x72, 0x74, 0x2d, 0x67, 0x6f, 0x2f, 0x73, 0x74, 0x61, 0x72, 0x74, 0x2f, 0x67, 0x65, 0x6e, 0xa2,
0x02, 0x03, 0x4f, 0x58, 0x58, 0xaa, 0x02, 0x07, 0x4f, 0x61, 0x75, 0x74, 0x68, 0x32, 0x31, 0xca,
0x02, 0x07, 0x4f, 0x61, 0x75, 0x74, 0x68, 0x32, 0x31, 0xe2, 0x02, 0x13, 0x4f, 0x61, 0x75, 0x74,
0x68, 0x32, 0x31, 0x5c, 0x47, 0x50, 0x42, 0x4d, 0x65, 0x74, 0x61, 0x64, 0x61, 0x74, 0x61, 0xea,
0x02, 0x07, 0x4f, 0x61, 0x75, 0x74, 0x68, 0x32, 0x31, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f,
0x33,
})
var (
file_code_proto_rawDescOnce sync.Once
file_code_proto_rawDescData []byte
)
func file_code_proto_rawDescGZIP() []byte {
file_code_proto_rawDescOnce.Do(func() {
file_code_proto_rawDescData = protoimpl.X.CompressGZIP(unsafe.Slice(unsafe.StringData(file_code_proto_rawDesc), len(file_code_proto_rawDesc)))
})
return file_code_proto_rawDescData
}
var file_code_proto_msgTypes = make([]protoimpl.MessageInfo, 1)
var file_code_proto_goTypes = []any{
(*Code)(nil), // 0: oauth21.Code
(*timestamppb.Timestamp)(nil), // 1: google.protobuf.Timestamp
}
var file_code_proto_depIdxs = []int32{
1, // 0: oauth21.Code.expires_at:type_name -> google.protobuf.Timestamp
1, // [1:1] is the sub-list for method output_type
1, // [1:1] is the sub-list for method input_type
1, // [1:1] is the sub-list for extension type_name
1, // [1:1] is the sub-list for extension extendee
0, // [0:1] is the sub-list for field type_name
}
func init() { file_code_proto_init() }
func file_code_proto_init() {
if File_code_proto != nil {
return
}
type x struct{}
out := protoimpl.TypeBuilder{
File: protoimpl.DescBuilder{
GoPackagePath: reflect.TypeOf(x{}).PkgPath(),
RawDescriptor: unsafe.Slice(unsafe.StringData(file_code_proto_rawDesc), len(file_code_proto_rawDesc)),
NumEnums: 0,
NumMessages: 1,
NumExtensions: 0,
NumServices: 0,
},
GoTypes: file_code_proto_goTypes,
DependencyIndexes: file_code_proto_depIdxs,
MessageInfos: file_code_proto_msgTypes,
}.Build()
File_code_proto = out.File
file_code_proto_goTypes = nil
file_code_proto_depIdxs = nil
}

View file

@ -0,0 +1,20 @@
syntax = "proto3";
package oauth21;
import "google/protobuf/timestamp.proto";
import "buf/validate/validate.proto";
option go_package = "github.com/pomerium/pomerium/internal/oauth21/gen";
// Code is a code used in the authorization code flow.
message Code {
string id = 1 [
(buf.validate.field).required = true,
(buf.validate.field).string = {
min_len : 1,
}
];
google.protobuf.Timestamp expires_at = 2
[ (buf.validate.field).required = true ];
}