update tests

This commit is contained in:
Denis Mishin 2025-04-25 14:32:25 -04:00
parent 0135b8b9aa
commit 428ef62b90
2 changed files with 15 additions and 15 deletions

View file

@ -47,7 +47,7 @@ func CreateCode(
return "", err
}
ciphertext := cryptutil.Encrypt(cipher, b, getAD(ad, typ))
ciphertext := cryptutil.Encrypt(cipher, b, []byte(ad))
return base64.StdEncoding.EncodeToString(ciphertext), nil
}
@ -62,7 +62,7 @@ func DecryptCode(
if err != nil {
return nil, fmt.Errorf("base64 decode: %w", err)
}
plaintext, err := cryptutil.Decrypt(cipher, b, getAD(ad, typ))
plaintext, err := cryptutil.Decrypt(cipher, b, []byte(ad))
if err != nil {
return nil, fmt.Errorf("decrypt: %w", err)
}
@ -71,15 +71,15 @@ func DecryptCode(
if err != nil {
return nil, fmt.Errorf("unmarshal: %w", err)
}
if v.ExpiresAt == nil {
return nil, fmt.Errorf("expiration is nil")
err = protovalidate.Validate(&v)
if err != nil {
return nil, fmt.Errorf("validate: %w", err)
}
if v.GrantType != typ {
return nil, fmt.Errorf("code type mismatch: expected %v, got %v", typ, v.GrantType)
}
if v.ExpiresAt.AsTime().Before(now) {
return nil, fmt.Errorf("code expired")
}
return &v, nil
}
func getAD(ad string, typ oauth21proto.CodeType) []byte {
return []byte(fmt.Sprintf("%s:%s", ad, typ.String()))
}

View file

@ -149,7 +149,7 @@ func TestDecryptCode(t *testing.T) {
}
codeBytes, err := proto.Marshal(codeNoExpiry)
require.NoError(t, err)
ciphertext := cryptutil.Encrypt(testCipher, codeBytes, getAD("test-ad", CodeTypeAuthorization))
ciphertext := cryptutil.Encrypt(testCipher, codeBytes, []byte("test-ad"))
codeNoExpiryStr := base64.StdEncoding.EncodeToString(ciphertext)
tests := []struct {
@ -191,7 +191,7 @@ func TestDecryptCode(t *testing.T) {
ad: "test-ad",
now: now,
wantErr: true,
errMessage: "decrypt",
errMessage: "code type mismatch",
},
{
name: "expired code",
@ -211,7 +211,7 @@ func TestDecryptCode(t *testing.T) {
ad: "test-ad",
now: now,
wantErr: true,
errMessage: "expiration is nil",
errMessage: "expires_at: value is required",
},
{
name: "invalid base64",
@ -231,7 +231,7 @@ func TestDecryptCode(t *testing.T) {
ad: "wrong-ad",
now: now,
wantErr: true,
errMessage: "decrypt",
errMessage: "code type mismatch",
},
{
name: "unspecified code type",
@ -241,7 +241,7 @@ func TestDecryptCode(t *testing.T) {
ad: "test-ad",
now: now,
wantErr: true,
errMessage: "decrypt",
errMessage: "code type mismatch",
},
{
name: "undefined code type",
@ -260,13 +260,13 @@ func TestDecryptCode(t *testing.T) {
got, err := DecryptCode(tc.typ, tc.code, tc.cipher, tc.ad, tc.now)
if tc.wantErr {
assert.Error(t, err)
require.Error(t, err)
if tc.errMessage != "" {
assert.Contains(t, err.Error(), tc.errMessage)
}
assert.Nil(t, got)
} else {
assert.NoError(t, err)
require.NoError(t, err)
require.NotNil(t, got)
diff := cmp.Diff(tc.want, got, protocmp.Transform())