From 428ef62b90eb1aeb2332eb7066a0af0cc9ab86d2 Mon Sep 17 00:00:00 2001 From: Denis Mishin Date: Fri, 25 Apr 2025 14:32:25 -0400 Subject: [PATCH] update tests --- internal/mcp/code.go | 16 ++++++++-------- internal/mcp/code_test.go | 14 +++++++------- 2 files changed, 15 insertions(+), 15 deletions(-) diff --git a/internal/mcp/code.go b/internal/mcp/code.go index 2f4c5afa1..8dced0ac6 100644 --- a/internal/mcp/code.go +++ b/internal/mcp/code.go @@ -47,7 +47,7 @@ func CreateCode( return "", err } - ciphertext := cryptutil.Encrypt(cipher, b, getAD(ad, typ)) + ciphertext := cryptutil.Encrypt(cipher, b, []byte(ad)) return base64.StdEncoding.EncodeToString(ciphertext), nil } @@ -62,7 +62,7 @@ func DecryptCode( if err != nil { return nil, fmt.Errorf("base64 decode: %w", err) } - plaintext, err := cryptutil.Decrypt(cipher, b, getAD(ad, typ)) + plaintext, err := cryptutil.Decrypt(cipher, b, []byte(ad)) if err != nil { return nil, fmt.Errorf("decrypt: %w", err) } @@ -71,15 +71,15 @@ func DecryptCode( if err != nil { return nil, fmt.Errorf("unmarshal: %w", err) } - if v.ExpiresAt == nil { - return nil, fmt.Errorf("expiration is nil") + err = protovalidate.Validate(&v) + if err != nil { + return nil, fmt.Errorf("validate: %w", err) + } + if v.GrantType != typ { + return nil, fmt.Errorf("code type mismatch: expected %v, got %v", typ, v.GrantType) } if v.ExpiresAt.AsTime().Before(now) { return nil, fmt.Errorf("code expired") } return &v, nil } - -func getAD(ad string, typ oauth21proto.CodeType) []byte { - return []byte(fmt.Sprintf("%s:%s", ad, typ.String())) -} diff --git a/internal/mcp/code_test.go b/internal/mcp/code_test.go index f28bcfff4..3c7b227f8 100644 --- a/internal/mcp/code_test.go +++ b/internal/mcp/code_test.go @@ -149,7 +149,7 @@ func TestDecryptCode(t *testing.T) { } codeBytes, err := proto.Marshal(codeNoExpiry) require.NoError(t, err) - ciphertext := cryptutil.Encrypt(testCipher, codeBytes, getAD("test-ad", CodeTypeAuthorization)) + ciphertext := cryptutil.Encrypt(testCipher, codeBytes, []byte("test-ad")) codeNoExpiryStr := base64.StdEncoding.EncodeToString(ciphertext) tests := []struct { @@ -191,7 +191,7 @@ func TestDecryptCode(t *testing.T) { ad: "test-ad", now: now, wantErr: true, - errMessage: "decrypt", + errMessage: "code type mismatch", }, { name: "expired code", @@ -211,7 +211,7 @@ func TestDecryptCode(t *testing.T) { ad: "test-ad", now: now, wantErr: true, - errMessage: "expiration is nil", + errMessage: "expires_at: value is required", }, { name: "invalid base64", @@ -231,7 +231,7 @@ func TestDecryptCode(t *testing.T) { ad: "wrong-ad", now: now, wantErr: true, - errMessage: "decrypt", + errMessage: "code type mismatch", }, { name: "unspecified code type", @@ -241,7 +241,7 @@ func TestDecryptCode(t *testing.T) { ad: "test-ad", now: now, wantErr: true, - errMessage: "decrypt", + errMessage: "code type mismatch", }, { name: "undefined code type", @@ -260,13 +260,13 @@ func TestDecryptCode(t *testing.T) { got, err := DecryptCode(tc.typ, tc.code, tc.cipher, tc.ad, tc.now) if tc.wantErr { - assert.Error(t, err) + require.Error(t, err) if tc.errMessage != "" { assert.Contains(t, err.Error(), tc.errMessage) } assert.Nil(t, got) } else { - assert.NoError(t, err) + require.NoError(t, err) require.NotNil(t, got) diff := cmp.Diff(tc.want, got, protocmp.Transform())