diff --git a/pkg/cryptutil/encrypt.go b/pkg/cryptutil/encrypt.go index 0e0b3f31f..452c54463 100644 --- a/pkg/cryptutil/encrypt.go +++ b/pkg/cryptutil/encrypt.go @@ -44,7 +44,7 @@ func Decrypt(a cipher.AEAD, data, ad []byte) ([]byte, error) { nonce := data[size:] plaintext, err := a.Open(nil, nonce, ciphertext, ad) if err != nil { - return nil, err + return nil, fmt.Errorf("cryptutil: decryption failed (mismatched keys?): %w", err) } return plaintext, nil } diff --git a/pkg/cryptutil/encrypt_test.go b/pkg/cryptutil/encrypt_test.go index 98d364f64..94a41419e 100644 --- a/pkg/cryptutil/encrypt_test.go +++ b/pkg/cryptutil/encrypt_test.go @@ -23,22 +23,22 @@ func TestEncodeAndDecodeAccessToken(t *testing.T) { t.Fatalf("plaintext is not encrypted plaintext:%v ciphertext:%x", plaintext, ciphertext) } - got, err := Decrypt(c, ciphertext, nil) + diffKey, err := NewAEADCipher(NewKey()) if err != nil { - t.Fatalf("unexpected err decrypting: %v", err) + t.Fatalf("unexpected err: %v", err) } + // key mismatch + _, err = Decrypt(diffKey, ciphertext, nil) + assert.Error(t, err) - // if less than 32 bytes, fail + // bad data size _, err = Decrypt(c, []byte("oh"), nil) - if err == nil { - t.Fatalf("should fail if <32 bytes output: %v", err) - } + assert.Error(t, err) - if !reflect.DeepEqual(got, plaintext) { - t.Logf(" got: %v", got) - t.Logf("want: %v", plaintext) - t.Fatal("got unexpected decrypted value") - } + // good + got, err := Decrypt(c, ciphertext, nil) + assert.NoError(t, err) + assert.Equal(t, got, plaintext) } func TestNewAEADCipher(t *testing.T) {