update tests

This commit is contained in:
Denis Mishin 2025-04-25 14:32:25 -04:00
parent 0135b8b9aa
commit 428ef62b90
2 changed files with 15 additions and 15 deletions

View file

@ -47,7 +47,7 @@ func CreateCode(
return "", err return "", err
} }
ciphertext := cryptutil.Encrypt(cipher, b, getAD(ad, typ)) ciphertext := cryptutil.Encrypt(cipher, b, []byte(ad))
return base64.StdEncoding.EncodeToString(ciphertext), nil return base64.StdEncoding.EncodeToString(ciphertext), nil
} }
@ -62,7 +62,7 @@ func DecryptCode(
if err != nil { if err != nil {
return nil, fmt.Errorf("base64 decode: %w", err) return nil, fmt.Errorf("base64 decode: %w", err)
} }
plaintext, err := cryptutil.Decrypt(cipher, b, getAD(ad, typ)) plaintext, err := cryptutil.Decrypt(cipher, b, []byte(ad))
if err != nil { if err != nil {
return nil, fmt.Errorf("decrypt: %w", err) return nil, fmt.Errorf("decrypt: %w", err)
} }
@ -71,15 +71,15 @@ func DecryptCode(
if err != nil { if err != nil {
return nil, fmt.Errorf("unmarshal: %w", err) return nil, fmt.Errorf("unmarshal: %w", err)
} }
if v.ExpiresAt == nil { err = protovalidate.Validate(&v)
return nil, fmt.Errorf("expiration is nil") if err != nil {
return nil, fmt.Errorf("validate: %w", err)
}
if v.GrantType != typ {
return nil, fmt.Errorf("code type mismatch: expected %v, got %v", typ, v.GrantType)
} }
if v.ExpiresAt.AsTime().Before(now) { if v.ExpiresAt.AsTime().Before(now) {
return nil, fmt.Errorf("code expired") return nil, fmt.Errorf("code expired")
} }
return &v, nil return &v, nil
} }
func getAD(ad string, typ oauth21proto.CodeType) []byte {
return []byte(fmt.Sprintf("%s:%s", ad, typ.String()))
}

View file

@ -149,7 +149,7 @@ func TestDecryptCode(t *testing.T) {
} }
codeBytes, err := proto.Marshal(codeNoExpiry) codeBytes, err := proto.Marshal(codeNoExpiry)
require.NoError(t, err) require.NoError(t, err)
ciphertext := cryptutil.Encrypt(testCipher, codeBytes, getAD("test-ad", CodeTypeAuthorization)) ciphertext := cryptutil.Encrypt(testCipher, codeBytes, []byte("test-ad"))
codeNoExpiryStr := base64.StdEncoding.EncodeToString(ciphertext) codeNoExpiryStr := base64.StdEncoding.EncodeToString(ciphertext)
tests := []struct { tests := []struct {
@ -191,7 +191,7 @@ func TestDecryptCode(t *testing.T) {
ad: "test-ad", ad: "test-ad",
now: now, now: now,
wantErr: true, wantErr: true,
errMessage: "decrypt", errMessage: "code type mismatch",
}, },
{ {
name: "expired code", name: "expired code",
@ -211,7 +211,7 @@ func TestDecryptCode(t *testing.T) {
ad: "test-ad", ad: "test-ad",
now: now, now: now,
wantErr: true, wantErr: true,
errMessage: "expiration is nil", errMessage: "expires_at: value is required",
}, },
{ {
name: "invalid base64", name: "invalid base64",
@ -231,7 +231,7 @@ func TestDecryptCode(t *testing.T) {
ad: "wrong-ad", ad: "wrong-ad",
now: now, now: now,
wantErr: true, wantErr: true,
errMessage: "decrypt", errMessage: "code type mismatch",
}, },
{ {
name: "unspecified code type", name: "unspecified code type",
@ -241,7 +241,7 @@ func TestDecryptCode(t *testing.T) {
ad: "test-ad", ad: "test-ad",
now: now, now: now,
wantErr: true, wantErr: true,
errMessage: "decrypt", errMessage: "code type mismatch",
}, },
{ {
name: "undefined code type", name: "undefined code type",
@ -260,13 +260,13 @@ func TestDecryptCode(t *testing.T) {
got, err := DecryptCode(tc.typ, tc.code, tc.cipher, tc.ad, tc.now) got, err := DecryptCode(tc.typ, tc.code, tc.cipher, tc.ad, tc.now)
if tc.wantErr { if tc.wantErr {
assert.Error(t, err) require.Error(t, err)
if tc.errMessage != "" { if tc.errMessage != "" {
assert.Contains(t, err.Error(), tc.errMessage) assert.Contains(t, err.Error(), tc.errMessage)
} }
assert.Nil(t, got) assert.Nil(t, got)
} else { } else {
assert.NoError(t, err) require.NoError(t, err)
require.NotNil(t, got) require.NotNil(t, got)
diff := cmp.Diff(tc.want, got, protocmp.Transform()) diff := cmp.Diff(tc.want, got, protocmp.Transform())