mcp: extend code usage (#5588)

This commit is contained in:
Denis Mishin 2025-04-25 14:47:11 -04:00 committed by GitHub
parent 9e4947c62f
commit 4dd5357fe3
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 244 additions and 52 deletions

View file

@ -24,6 +24,7 @@ func TestCreateCode(t *testing.T) {
tests := []struct {
name string
typ oauth21proto.CodeType
id string
expires time.Time
ad string
@ -32,7 +33,26 @@ func TestCreateCode(t *testing.T) {
errMessage string
}{
{
name: "valid code",
name: "valid authorization code",
typ: CodeTypeAuthorization,
id: "test-id",
expires: time.Now().Add(time.Hour),
ad: "test-ad",
cipher: testCipher,
wantErr: false,
},
{
name: "valid refresh code",
typ: CodeTypeRefresh,
id: "test-id",
expires: time.Now().Add(time.Hour),
ad: "test-ad",
cipher: testCipher,
wantErr: false,
},
{
name: "valid access code",
typ: CodeTypeAccess,
id: "test-id",
expires: time.Now().Add(time.Hour),
ad: "test-ad",
@ -41,6 +61,7 @@ func TestCreateCode(t *testing.T) {
},
{
name: "empty id",
typ: CodeTypeAuthorization,
id: "",
expires: time.Now().Add(time.Hour),
ad: "test-ad",
@ -50,6 +71,7 @@ func TestCreateCode(t *testing.T) {
},
{
name: "empty expires",
typ: CodeTypeAuthorization,
id: "test-id",
expires: time.Time{},
ad: "test-ad",
@ -57,11 +79,31 @@ func TestCreateCode(t *testing.T) {
wantErr: true,
errMessage: "validate",
},
{
name: "invalid code type",
typ: 0, // Unspecified type
id: "test-id",
expires: time.Now().Add(time.Hour),
ad: "test-ad",
cipher: testCipher,
wantErr: true,
errMessage: "validate",
},
{
name: "undefined code type",
typ: 99, // Undefined type
id: "test-id",
expires: time.Now().Add(time.Hour),
ad: "test-ad",
cipher: testCipher,
wantErr: true,
errMessage: "validate",
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
code, err := CreateCode(tc.id, tc.expires, tc.ad, tc.cipher)
code, err := CreateCode(tc.typ, tc.id, tc.expires, tc.ad, tc.cipher)
if tc.wantErr {
assert.Error(t, err)
@ -73,9 +115,10 @@ func TestCreateCode(t *testing.T) {
assert.NoError(t, err)
assert.NotEmpty(t, code)
decodedCode, err := DecryptCode(code, tc.cipher, tc.ad, time.Now())
decodedCode, err := DecryptCode(tc.typ, code, tc.cipher, tc.ad, time.Now())
require.NoError(t, err)
assert.Equal(t, tc.id, decodedCode.Id)
assert.Equal(t, tc.typ, decodedCode.GrantType)
assert.True(t, proto.Equal(timestamppb.New(tc.expires), decodedCode.ExpiresAt))
}
})
@ -91,14 +134,18 @@ func TestDecryptCode(t *testing.T) {
future := now.Add(time.Hour)
past := now.Add(-time.Hour)
validCode, err := CreateCode("test-id", future, "test-ad", testCipher)
validCode, err := CreateCode(CodeTypeAuthorization, "test-id", future, "test-ad", testCipher)
require.NoError(t, err)
expiredCode, err := CreateCode("expired-id", past, "test-ad", testCipher)
validRefreshCode, err := CreateCode(CodeTypeRefresh, "refresh-id", future, "test-ad", testCipher)
require.NoError(t, err)
expiredCode, err := CreateCode(CodeTypeAuthorization, "expired-id", past, "test-ad", testCipher)
require.NoError(t, err)
codeNoExpiry := &oauth21proto.Code{
Id: "no-expiry",
Id: "no-expiry",
GrantType: CodeTypeAuthorization,
}
codeBytes, err := proto.Marshal(codeNoExpiry)
require.NoError(t, err)
@ -107,6 +154,7 @@ func TestDecryptCode(t *testing.T) {
tests := []struct {
name string
typ oauth21proto.CodeType
code string
cipher cipher.AEAD
ad string
@ -117,15 +165,37 @@ func TestDecryptCode(t *testing.T) {
}{
{
name: "valid code",
typ: CodeTypeAuthorization,
code: validCode,
cipher: testCipher,
ad: "test-ad",
now: now,
want: &oauth21proto.Code{Id: "test-id", ExpiresAt: timestamppb.New(future)},
want: &oauth21proto.Code{Id: "test-id", ExpiresAt: timestamppb.New(future), GrantType: CodeTypeAuthorization},
wantErr: false,
},
{
name: "valid refresh code",
typ: CodeTypeRefresh,
code: validRefreshCode,
cipher: testCipher,
ad: "test-ad",
now: now,
want: &oauth21proto.Code{Id: "refresh-id", ExpiresAt: timestamppb.New(future), GrantType: CodeTypeRefresh},
wantErr: false,
},
{
name: "wrong code type",
typ: CodeTypeAccess, // Using wrong type
code: validCode, // This was created with Authorization type
cipher: testCipher,
ad: "test-ad",
now: now,
wantErr: true,
errMessage: "code type mismatch",
},
{
name: "expired code",
typ: CodeTypeAuthorization,
code: expiredCode,
cipher: testCipher,
ad: "test-ad",
@ -135,15 +205,17 @@ func TestDecryptCode(t *testing.T) {
},
{
name: "nil expiration",
typ: CodeTypeAuthorization,
code: codeNoExpiryStr,
cipher: testCipher,
ad: "test-ad",
now: now,
wantErr: true,
errMessage: "expiration is nil",
errMessage: "expires_at: value is required",
},
{
name: "invalid base64",
typ: CodeTypeAuthorization,
code: "not-base64",
cipher: testCipher,
ad: "test-ad",
@ -153,27 +225,48 @@ func TestDecryptCode(t *testing.T) {
},
{
name: "wrong authentication data",
typ: CodeTypeAuthorization,
code: validCode,
cipher: testCipher,
ad: "wrong-ad",
now: now,
wantErr: true,
errMessage: "decrypt",
errMessage: "message authentication failed",
},
{
name: "unspecified code type",
typ: 0, // Unspecified type
code: validCode,
cipher: testCipher,
ad: "test-ad",
now: now,
wantErr: true,
errMessage: "code type mismatch",
},
{
name: "undefined code type",
typ: 99, // undefined type
code: validCode,
cipher: testCipher,
ad: "test-ad",
now: now,
wantErr: true,
errMessage: "code type mismatch",
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
got, err := DecryptCode(tc.code, tc.cipher, tc.ad, tc.now)
got, err := DecryptCode(tc.typ, tc.code, tc.cipher, tc.ad, tc.now)
if tc.wantErr {
assert.Error(t, err)
require.Error(t, err)
if tc.errMessage != "" {
assert.Contains(t, err.Error(), tc.errMessage)
}
assert.Nil(t, got)
} else {
assert.NoError(t, err)
require.NoError(t, err)
require.NotNil(t, got)
diff := cmp.Diff(tc.want, got, protocmp.Transform())