mirror of
https://github.com/pomerium/pomerium.git
synced 2025-04-28 18:06:34 +02:00
update tests
This commit is contained in:
parent
0135b8b9aa
commit
428ef62b90
2 changed files with 15 additions and 15 deletions
|
@ -47,7 +47,7 @@ func CreateCode(
|
|||
return "", err
|
||||
}
|
||||
|
||||
ciphertext := cryptutil.Encrypt(cipher, b, getAD(ad, typ))
|
||||
ciphertext := cryptutil.Encrypt(cipher, b, []byte(ad))
|
||||
return base64.StdEncoding.EncodeToString(ciphertext), nil
|
||||
}
|
||||
|
||||
|
@ -62,7 +62,7 @@ func DecryptCode(
|
|||
if err != nil {
|
||||
return nil, fmt.Errorf("base64 decode: %w", err)
|
||||
}
|
||||
plaintext, err := cryptutil.Decrypt(cipher, b, getAD(ad, typ))
|
||||
plaintext, err := cryptutil.Decrypt(cipher, b, []byte(ad))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("decrypt: %w", err)
|
||||
}
|
||||
|
@ -71,15 +71,15 @@ func DecryptCode(
|
|||
if err != nil {
|
||||
return nil, fmt.Errorf("unmarshal: %w", err)
|
||||
}
|
||||
if v.ExpiresAt == nil {
|
||||
return nil, fmt.Errorf("expiration is nil")
|
||||
err = protovalidate.Validate(&v)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("validate: %w", err)
|
||||
}
|
||||
if v.GrantType != typ {
|
||||
return nil, fmt.Errorf("code type mismatch: expected %v, got %v", typ, v.GrantType)
|
||||
}
|
||||
if v.ExpiresAt.AsTime().Before(now) {
|
||||
return nil, fmt.Errorf("code expired")
|
||||
}
|
||||
return &v, nil
|
||||
}
|
||||
|
||||
func getAD(ad string, typ oauth21proto.CodeType) []byte {
|
||||
return []byte(fmt.Sprintf("%s:%s", ad, typ.String()))
|
||||
}
|
||||
|
|
|
@ -149,7 +149,7 @@ func TestDecryptCode(t *testing.T) {
|
|||
}
|
||||
codeBytes, err := proto.Marshal(codeNoExpiry)
|
||||
require.NoError(t, err)
|
||||
ciphertext := cryptutil.Encrypt(testCipher, codeBytes, getAD("test-ad", CodeTypeAuthorization))
|
||||
ciphertext := cryptutil.Encrypt(testCipher, codeBytes, []byte("test-ad"))
|
||||
codeNoExpiryStr := base64.StdEncoding.EncodeToString(ciphertext)
|
||||
|
||||
tests := []struct {
|
||||
|
@ -191,7 +191,7 @@ func TestDecryptCode(t *testing.T) {
|
|||
ad: "test-ad",
|
||||
now: now,
|
||||
wantErr: true,
|
||||
errMessage: "decrypt",
|
||||
errMessage: "code type mismatch",
|
||||
},
|
||||
{
|
||||
name: "expired code",
|
||||
|
@ -211,7 +211,7 @@ func TestDecryptCode(t *testing.T) {
|
|||
ad: "test-ad",
|
||||
now: now,
|
||||
wantErr: true,
|
||||
errMessage: "expiration is nil",
|
||||
errMessage: "expires_at: value is required",
|
||||
},
|
||||
{
|
||||
name: "invalid base64",
|
||||
|
@ -231,7 +231,7 @@ func TestDecryptCode(t *testing.T) {
|
|||
ad: "wrong-ad",
|
||||
now: now,
|
||||
wantErr: true,
|
||||
errMessage: "decrypt",
|
||||
errMessage: "code type mismatch",
|
||||
},
|
||||
{
|
||||
name: "unspecified code type",
|
||||
|
@ -241,7 +241,7 @@ func TestDecryptCode(t *testing.T) {
|
|||
ad: "test-ad",
|
||||
now: now,
|
||||
wantErr: true,
|
||||
errMessage: "decrypt",
|
||||
errMessage: "code type mismatch",
|
||||
},
|
||||
{
|
||||
name: "undefined code type",
|
||||
|
@ -260,13 +260,13 @@ func TestDecryptCode(t *testing.T) {
|
|||
got, err := DecryptCode(tc.typ, tc.code, tc.cipher, tc.ad, tc.now)
|
||||
|
||||
if tc.wantErr {
|
||||
assert.Error(t, err)
|
||||
require.Error(t, err)
|
||||
if tc.errMessage != "" {
|
||||
assert.Contains(t, err.Error(), tc.errMessage)
|
||||
}
|
||||
assert.Nil(t, got)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, got)
|
||||
|
||||
diff := cmp.Diff(tc.want, got, protocmp.Transform())
|
||||
|
|
Loading…
Add table
Reference in a new issue