mcp: authorize request (pt2) (#5586)

This commit is contained in:
Denis Mishin 2025-04-24 15:11:19 -04:00 committed by GitHub
parent 63ccf6ab93
commit 9e4947c62f
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
9 changed files with 567 additions and 6 deletions

50
internal/mcp/cipher.go Normal file
View file

@ -0,0 +1,50 @@
package mcp
import (
"crypto/cipher"
"crypto/sha256"
"fmt"
"io"
"golang.org/x/crypto/hkdf"
"github.com/pomerium/pomerium/config"
"github.com/pomerium/pomerium/pkg/cryptutil"
)
func getCipher(
cfg *config.Config,
) (cipher.AEAD, error) {
secret, err := cfg.Options.GetSharedKey()
if err != nil {
return nil, fmt.Errorf("shared key: %w", err)
}
rnd := hkdf.New(sha256.New, secret, nil, []byte("model-context-protocol"))
cipher, err := initCipher(rnd)
if err != nil {
return nil, fmt.Errorf("new aead cipher: %w", err)
}
return cipher, nil
}
func readKey(r io.Reader) ([]byte, error) {
b := make([]byte, cryptutil.DefaultKeySize)
_, err := io.ReadFull(r, b)
if err != nil {
return nil, fmt.Errorf("read from hkdf: %w", err)
}
return b, nil
}
func initCipher(r io.Reader) (cipher.AEAD, error) {
cipherKey, err := readKey(r)
if err != nil {
return nil, fmt.Errorf("read key: %w", err)
}
cipher, err := cryptutil.NewAEADCipher(cipherKey)
if err != nil {
return nil, fmt.Errorf("new aead cipher: %w", err)
}
return cipher, nil
}

72
internal/mcp/code.go Normal file
View file

@ -0,0 +1,72 @@
package mcp
import (
"crypto/cipher"
"encoding/base64"
"fmt"
"time"
"github.com/bufbuild/protovalidate-go"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/types/known/timestamppb"
oauth21proto "github.com/pomerium/pomerium/internal/oauth21/gen"
"github.com/pomerium/pomerium/pkg/cryptutil"
)
func CreateCode(
id string,
expires time.Time,
ad string,
cipher cipher.AEAD,
) (string, error) {
if expires.IsZero() {
return "", fmt.Errorf("validate: zero expiration")
}
v := oauth21proto.Code{
Id: id,
ExpiresAt: timestamppb.New(expires),
}
err := protovalidate.Validate(&v)
if err != nil {
return "", fmt.Errorf("validate: %w", err)
}
b, err := proto.Marshal(&v)
if err != nil {
return "", err
}
ciphertext := cryptutil.Encrypt(cipher, b, []byte(ad))
return base64.StdEncoding.EncodeToString(ciphertext), nil
}
func DecryptCode(
code string,
cipher cipher.AEAD,
ad string,
now time.Time,
) (*oauth21proto.Code, error) {
b, err := base64.StdEncoding.DecodeString(code)
if err != nil {
return nil, fmt.Errorf("base64 decode: %w", err)
}
plaintext, err := cryptutil.Decrypt(cipher, b, []byte(ad))
if err != nil {
return nil, fmt.Errorf("decrypt: %w", err)
}
var v oauth21proto.Code
err = proto.Unmarshal(plaintext, &v)
if err != nil {
return nil, fmt.Errorf("unmarshal: %w", err)
}
if v.ExpiresAt == nil {
return nil, fmt.Errorf("expiration is nil")
}
if v.ExpiresAt.AsTime().Before(now) {
return nil, fmt.Errorf("code expired")
}
return &v, nil
}

184
internal/mcp/code_test.go Normal file
View file

@ -0,0 +1,184 @@
package mcp
import (
"crypto/cipher"
"encoding/base64"
"testing"
"time"
"github.com/google/go-cmp/cmp"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/testing/protocmp"
"google.golang.org/protobuf/types/known/timestamppb"
oauth21proto "github.com/pomerium/pomerium/internal/oauth21/gen"
"github.com/pomerium/pomerium/pkg/cryptutil"
)
func TestCreateCode(t *testing.T) {
key := cryptutil.NewKey()
testCipher, err := cryptutil.NewAEADCipher(key)
require.NoError(t, err)
tests := []struct {
name string
id string
expires time.Time
ad string
cipher cipher.AEAD
wantErr bool
errMessage string
}{
{
name: "valid code",
id: "test-id",
expires: time.Now().Add(time.Hour),
ad: "test-ad",
cipher: testCipher,
wantErr: false,
},
{
name: "empty id",
id: "",
expires: time.Now().Add(time.Hour),
ad: "test-ad",
cipher: testCipher,
wantErr: true,
errMessage: "validate",
},
{
name: "empty expires",
id: "test-id",
expires: time.Time{},
ad: "test-ad",
cipher: testCipher,
wantErr: true,
errMessage: "validate",
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
code, err := CreateCode(tc.id, tc.expires, tc.ad, tc.cipher)
if tc.wantErr {
assert.Error(t, err)
if tc.errMessage != "" {
assert.Contains(t, err.Error(), tc.errMessage)
}
assert.Empty(t, code)
} else {
assert.NoError(t, err)
assert.NotEmpty(t, code)
decodedCode, err := DecryptCode(code, tc.cipher, tc.ad, time.Now())
require.NoError(t, err)
assert.Equal(t, tc.id, decodedCode.Id)
assert.True(t, proto.Equal(timestamppb.New(tc.expires), decodedCode.ExpiresAt))
}
})
}
}
func TestDecryptCode(t *testing.T) {
key := cryptutil.NewKey()
testCipher, err := cryptutil.NewAEADCipher(key)
require.NoError(t, err)
now := time.Now()
future := now.Add(time.Hour)
past := now.Add(-time.Hour)
validCode, err := CreateCode("test-id", future, "test-ad", testCipher)
require.NoError(t, err)
expiredCode, err := CreateCode("expired-id", past, "test-ad", testCipher)
require.NoError(t, err)
codeNoExpiry := &oauth21proto.Code{
Id: "no-expiry",
}
codeBytes, err := proto.Marshal(codeNoExpiry)
require.NoError(t, err)
ciphertext := cryptutil.Encrypt(testCipher, codeBytes, []byte("test-ad"))
codeNoExpiryStr := base64.StdEncoding.EncodeToString(ciphertext)
tests := []struct {
name string
code string
cipher cipher.AEAD
ad string
now time.Time
want *oauth21proto.Code
wantErr bool
errMessage string
}{
{
name: "valid code",
code: validCode,
cipher: testCipher,
ad: "test-ad",
now: now,
want: &oauth21proto.Code{Id: "test-id", ExpiresAt: timestamppb.New(future)},
wantErr: false,
},
{
name: "expired code",
code: expiredCode,
cipher: testCipher,
ad: "test-ad",
now: now,
wantErr: true,
errMessage: "code expired",
},
{
name: "nil expiration",
code: codeNoExpiryStr,
cipher: testCipher,
ad: "test-ad",
now: now,
wantErr: true,
errMessage: "expiration is nil",
},
{
name: "invalid base64",
code: "not-base64",
cipher: testCipher,
ad: "test-ad",
now: now,
wantErr: true,
errMessage: "base64 decode",
},
{
name: "wrong authentication data",
code: validCode,
cipher: testCipher,
ad: "wrong-ad",
now: now,
wantErr: true,
errMessage: "decrypt",
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
got, err := DecryptCode(tc.code, tc.cipher, tc.ad, tc.now)
if tc.wantErr {
assert.Error(t, err)
if tc.errMessage != "" {
assert.Contains(t, err.Error(), tc.errMessage)
}
assert.Nil(t, got)
} else {
assert.NoError(t, err)
require.NotNil(t, got)
diff := cmp.Diff(tc.want, got, protocmp.Transform())
assert.Empty(t, diff)
}
})
}
}

View file

@ -2,6 +2,7 @@ package mcp
import (
"context"
"crypto/cipher"
"fmt"
"net/http"
"path"
@ -32,6 +33,7 @@ type Handler struct {
prefix string
trace oteltrace.TracerProvider
storage *Storage
cipher cipher.AEAD
}
func New(
@ -46,10 +48,16 @@ func New(
return nil, fmt.Errorf("databroker client: %w", err)
}
cipher, err := getCipher(cfg)
if err != nil {
return nil, fmt.Errorf("get cipher: %w", err)
}
return &Handler{
prefix: prefix,
trace: tracerProvider,
storage: NewStorage(client),
cipher: cipher,
}, nil
}

View file

@ -1,9 +1,12 @@
package mcp
import (
"context"
"errors"
"fmt"
"net/http"
"net/url"
"time"
"github.com/go-jose/go-jose/v3/jwt"
"google.golang.org/grpc/codes"
@ -12,6 +15,7 @@ import (
"github.com/pomerium/pomerium/internal/httputil"
"github.com/pomerium/pomerium/internal/log"
"github.com/pomerium/pomerium/internal/oauth21"
oauth21proto "github.com/pomerium/pomerium/internal/oauth21/gen"
)
// Authorize handles the /authorize endpoint.
@ -37,8 +41,9 @@ func (srv *Handler) Authorize(w http.ResponseWriter, r *http.Request) {
return
}
client, err := srv.storage.GetClientByID(ctx, v.ClientId)
client, err := srv.storage.GetClient(ctx, v.ClientId)
if err != nil && status.Code(err) == codes.NotFound {
log.Ctx(ctx).Error().Err(err).Str("id", v.ClientId).Msg("client id not found")
oauth21.ErrorResponse(w, http.StatusUnauthorized, oauth21.InvalidClient)
return
}
@ -56,14 +61,48 @@ func (srv *Handler) Authorize(w http.ResponseWriter, r *http.Request) {
return
}
_, err = srv.storage.CreateAuthorizationRequest(ctx, v)
id, err := srv.storage.CreateAuthorizationRequest(ctx, v)
if err != nil {
log.Ctx(ctx).Error().Err(err).Msg("failed to store authorization request")
http.Error(w, "cannot create authorization request", http.StatusInternalServerError)
return
}
http.Error(w, "not implemented", http.StatusNotImplemented)
srv.AuthorizationResponse(ctx, w, r, id, v)
}
// AuthorizationResponse generates the successful authorization response
// see https://datatracker.ietf.org/doc/html/draft-ietf-oauth-v2-1-12#section-4.1.2
func (srv *Handler) AuthorizationResponse(
ctx context.Context,
w http.ResponseWriter,
r *http.Request,
id string,
req *oauth21proto.AuthorizationRequest,
) {
code, err := CreateCode(
id,
time.Now().Add(time.Minute*10),
req.ClientId,
srv.cipher,
)
if err != nil {
log.Ctx(ctx).Error().Err(err).Msg("failed to create code")
http.Error(w, "internal server error", http.StatusInternalServerError)
return
}
to, err := url.Parse(req.GetRedirectUri())
if err != nil {
log.Ctx(ctx).Error().Err(err).Msg("failed to parse redirect uri")
http.Error(w, "invalid redirect uri", http.StatusBadRequest)
return
}
q := to.Query()
q.Set("code", code)
q.Set("state", req.GetState())
to.RawQuery = q.Encode()
http.Redirect(w, r, to.String(), http.StatusFound)
}
func getSessionFromRequest(r *http.Request) (string, error) {

View file

@ -46,7 +46,7 @@ func (storage *Storage) RegisterClient(
return id, nil
}
func (storage *Storage) GetClientByID(
func (storage *Storage) GetClient(
ctx context.Context,
id string,
) (*rfc7591v1.ClientMetadata, error) {
@ -85,3 +85,24 @@ func (storage *Storage) CreateAuthorizationRequest(
}
return id, nil
}
func (storage *Storage) GetAuthorizationRequest(
ctx context.Context,
id string,
) (*oauth21proto.AuthorizationRequest, error) {
v := new(oauth21proto.AuthorizationRequest)
rec, err := storage.client.Get(ctx, &databroker.GetRequest{
Type: protoutil.GetTypeURL(v),
Id: id,
})
if err != nil {
return nil, fmt.Errorf("failed to get authorization request by ID: %w", err)
}
err = anypb.UnmarshalTo(rec.Record.Data, v, proto.UnmarshalOptions{})
if err != nil {
return nil, fmt.Errorf("failed to unmarshal authorization request: %w", err)
}
return v, nil
}

View file

@ -14,6 +14,7 @@ import (
"github.com/pomerium/pomerium/internal/databroker"
"github.com/pomerium/pomerium/internal/mcp"
oauth21proto "github.com/pomerium/pomerium/internal/oauth21/gen"
rfc7591v1 "github.com/pomerium/pomerium/internal/rfc7591"
"github.com/pomerium/pomerium/internal/testutil"
databroker_grpc "github.com/pomerium/pomerium/pkg/grpc/databroker"
@ -50,15 +51,26 @@ func TestStorage(t *testing.T) {
require.NoError(t, err)
client := databroker_grpc.NewDataBrokerServiceClient(conn)
storage := mcp.NewStorage(client)
t.Run("client registration", func(t *testing.T) {
storage := mcp.NewStorage(client)
t.Parallel()
id, err := storage.RegisterClient(ctx, &rfc7591v1.ClientMetadata{})
require.NoError(t, err)
require.NotEmpty(t, id)
_, err = storage.GetClientByID(ctx, id)
_, err = storage.GetClient(ctx, id)
require.NoError(t, err)
})
t.Run("authorization request", func(t *testing.T) {
t.Parallel()
id, err := storage.CreateAuthorizationRequest(ctx, &oauth21proto.AuthorizationRequest{})
require.NoError(t, err)
_, err = storage.GetAuthorizationRequest(ctx, id)
require.NoError(t, err)
})
}

View file

@ -0,0 +1,155 @@
// Code generated by protoc-gen-go. DO NOT EDIT.
// versions:
// protoc-gen-go v1.36.5
// protoc (unknown)
// source: code.proto
package gen
import (
_ "buf.build/gen/go/bufbuild/protovalidate/protocolbuffers/go/buf/validate"
protoreflect "google.golang.org/protobuf/reflect/protoreflect"
protoimpl "google.golang.org/protobuf/runtime/protoimpl"
timestamppb "google.golang.org/protobuf/types/known/timestamppb"
reflect "reflect"
sync "sync"
unsafe "unsafe"
)
const (
// Verify that this generated code is sufficiently up-to-date.
_ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion)
// Verify that runtime/protoimpl is sufficiently up-to-date.
_ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20)
)
// Code is a code used in the authorization code flow.
type Code struct {
state protoimpl.MessageState `protogen:"open.v1"`
Id string `protobuf:"bytes,1,opt,name=id,proto3" json:"id,omitempty"`
ExpiresAt *timestamppb.Timestamp `protobuf:"bytes,2,opt,name=expires_at,json=expiresAt,proto3" json:"expires_at,omitempty"`
unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache
}
func (x *Code) Reset() {
*x = Code{}
mi := &file_code_proto_msgTypes[0]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
func (x *Code) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*Code) ProtoMessage() {}
func (x *Code) ProtoReflect() protoreflect.Message {
mi := &file_code_proto_msgTypes[0]
if x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use Code.ProtoReflect.Descriptor instead.
func (*Code) Descriptor() ([]byte, []int) {
return file_code_proto_rawDescGZIP(), []int{0}
}
func (x *Code) GetId() string {
if x != nil {
return x.Id
}
return ""
}
func (x *Code) GetExpiresAt() *timestamppb.Timestamp {
if x != nil {
return x.ExpiresAt
}
return nil
}
var File_code_proto protoreflect.FileDescriptor
var file_code_proto_rawDesc = string([]byte{
0x0a, 0x0a, 0x63, 0x6f, 0x64, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x12, 0x07, 0x6f, 0x61,
0x75, 0x74, 0x68, 0x32, 0x31, 0x1a, 0x1f, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2f, 0x70, 0x72,
0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2f, 0x74, 0x69, 0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d, 0x70,
0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x1a, 0x1b, 0x62, 0x75, 0x66, 0x2f, 0x76, 0x61, 0x6c, 0x69,
0x64, 0x61, 0x74, 0x65, 0x2f, 0x76, 0x61, 0x6c, 0x69, 0x64, 0x61, 0x74, 0x65, 0x2e, 0x70, 0x72,
0x6f, 0x74, 0x6f, 0x22, 0x65, 0x0a, 0x04, 0x43, 0x6f, 0x64, 0x65, 0x12, 0x1a, 0x0a, 0x02, 0x69,
0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x42, 0x0a, 0xba, 0x48, 0x07, 0xc8, 0x01, 0x01, 0x72,
0x02, 0x10, 0x01, 0x52, 0x02, 0x69, 0x64, 0x12, 0x41, 0x0a, 0x0a, 0x65, 0x78, 0x70, 0x69, 0x72,
0x65, 0x73, 0x5f, 0x61, 0x74, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x67, 0x6f,
0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x54, 0x69,
0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, 0x42, 0x06, 0xba, 0x48, 0x03, 0xc8, 0x01, 0x01, 0x52,
0x09, 0x65, 0x78, 0x70, 0x69, 0x72, 0x65, 0x73, 0x41, 0x74, 0x42, 0x9c, 0x01, 0x0a, 0x0b, 0x63,
0x6f, 0x6d, 0x2e, 0x6f, 0x61, 0x75, 0x74, 0x68, 0x32, 0x31, 0x42, 0x09, 0x43, 0x6f, 0x64, 0x65,
0x50, 0x72, 0x6f, 0x74, 0x6f, 0x50, 0x01, 0x5a, 0x46, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e,
0x63, 0x6f, 0x6d, 0x2f, 0x62, 0x75, 0x66, 0x62, 0x75, 0x69, 0x6c, 0x64, 0x2f, 0x62, 0x75, 0x66,
0x2d, 0x65, 0x78, 0x61, 0x6d, 0x70, 0x6c, 0x65, 0x73, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x76,
0x61, 0x6c, 0x69, 0x64, 0x61, 0x74, 0x65, 0x2f, 0x71, 0x75, 0x69, 0x63, 0x6b, 0x73, 0x74, 0x61,
0x72, 0x74, 0x2d, 0x67, 0x6f, 0x2f, 0x73, 0x74, 0x61, 0x72, 0x74, 0x2f, 0x67, 0x65, 0x6e, 0xa2,
0x02, 0x03, 0x4f, 0x58, 0x58, 0xaa, 0x02, 0x07, 0x4f, 0x61, 0x75, 0x74, 0x68, 0x32, 0x31, 0xca,
0x02, 0x07, 0x4f, 0x61, 0x75, 0x74, 0x68, 0x32, 0x31, 0xe2, 0x02, 0x13, 0x4f, 0x61, 0x75, 0x74,
0x68, 0x32, 0x31, 0x5c, 0x47, 0x50, 0x42, 0x4d, 0x65, 0x74, 0x61, 0x64, 0x61, 0x74, 0x61, 0xea,
0x02, 0x07, 0x4f, 0x61, 0x75, 0x74, 0x68, 0x32, 0x31, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f,
0x33,
})
var (
file_code_proto_rawDescOnce sync.Once
file_code_proto_rawDescData []byte
)
func file_code_proto_rawDescGZIP() []byte {
file_code_proto_rawDescOnce.Do(func() {
file_code_proto_rawDescData = protoimpl.X.CompressGZIP(unsafe.Slice(unsafe.StringData(file_code_proto_rawDesc), len(file_code_proto_rawDesc)))
})
return file_code_proto_rawDescData
}
var file_code_proto_msgTypes = make([]protoimpl.MessageInfo, 1)
var file_code_proto_goTypes = []any{
(*Code)(nil), // 0: oauth21.Code
(*timestamppb.Timestamp)(nil), // 1: google.protobuf.Timestamp
}
var file_code_proto_depIdxs = []int32{
1, // 0: oauth21.Code.expires_at:type_name -> google.protobuf.Timestamp
1, // [1:1] is the sub-list for method output_type
1, // [1:1] is the sub-list for method input_type
1, // [1:1] is the sub-list for extension type_name
1, // [1:1] is the sub-list for extension extendee
0, // [0:1] is the sub-list for field type_name
}
func init() { file_code_proto_init() }
func file_code_proto_init() {
if File_code_proto != nil {
return
}
type x struct{}
out := protoimpl.TypeBuilder{
File: protoimpl.DescBuilder{
GoPackagePath: reflect.TypeOf(x{}).PkgPath(),
RawDescriptor: unsafe.Slice(unsafe.StringData(file_code_proto_rawDesc), len(file_code_proto_rawDesc)),
NumEnums: 0,
NumMessages: 1,
NumExtensions: 0,
NumServices: 0,
},
GoTypes: file_code_proto_goTypes,
DependencyIndexes: file_code_proto_depIdxs,
MessageInfos: file_code_proto_msgTypes,
}.Build()
File_code_proto = out.File
file_code_proto_goTypes = nil
file_code_proto_depIdxs = nil
}

View file

@ -0,0 +1,20 @@
syntax = "proto3";
package oauth21;
import "google/protobuf/timestamp.proto";
import "buf/validate/validate.proto";
option go_package = "github.com/pomerium/pomerium/internal/oauth21/gen";
// Code is a code used in the authorization code flow.
message Code {
string id = 1 [
(buf.validate.field).required = true,
(buf.validate.field).string = {
min_len : 1,
}
];
google.protobuf.Timestamp expires_at = 2
[ (buf.validate.field).required = true ];
}