mirror of
https://github.com/pomerium/pomerium.git
synced 2025-05-22 21:47:16 +02:00
mcp: extend code usage (#5588)
This commit is contained in:
parent
9e4947c62f
commit
4dd5357fe3
6 changed files with 244 additions and 52 deletions
|
@ -24,6 +24,7 @@ func TestCreateCode(t *testing.T) {
|
|||
|
||||
tests := []struct {
|
||||
name string
|
||||
typ oauth21proto.CodeType
|
||||
id string
|
||||
expires time.Time
|
||||
ad string
|
||||
|
@ -32,7 +33,26 @@ func TestCreateCode(t *testing.T) {
|
|||
errMessage string
|
||||
}{
|
||||
{
|
||||
name: "valid code",
|
||||
name: "valid authorization code",
|
||||
typ: CodeTypeAuthorization,
|
||||
id: "test-id",
|
||||
expires: time.Now().Add(time.Hour),
|
||||
ad: "test-ad",
|
||||
cipher: testCipher,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "valid refresh code",
|
||||
typ: CodeTypeRefresh,
|
||||
id: "test-id",
|
||||
expires: time.Now().Add(time.Hour),
|
||||
ad: "test-ad",
|
||||
cipher: testCipher,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "valid access code",
|
||||
typ: CodeTypeAccess,
|
||||
id: "test-id",
|
||||
expires: time.Now().Add(time.Hour),
|
||||
ad: "test-ad",
|
||||
|
@ -41,6 +61,7 @@ func TestCreateCode(t *testing.T) {
|
|||
},
|
||||
{
|
||||
name: "empty id",
|
||||
typ: CodeTypeAuthorization,
|
||||
id: "",
|
||||
expires: time.Now().Add(time.Hour),
|
||||
ad: "test-ad",
|
||||
|
@ -50,6 +71,7 @@ func TestCreateCode(t *testing.T) {
|
|||
},
|
||||
{
|
||||
name: "empty expires",
|
||||
typ: CodeTypeAuthorization,
|
||||
id: "test-id",
|
||||
expires: time.Time{},
|
||||
ad: "test-ad",
|
||||
|
@ -57,11 +79,31 @@ func TestCreateCode(t *testing.T) {
|
|||
wantErr: true,
|
||||
errMessage: "validate",
|
||||
},
|
||||
{
|
||||
name: "invalid code type",
|
||||
typ: 0, // Unspecified type
|
||||
id: "test-id",
|
||||
expires: time.Now().Add(time.Hour),
|
||||
ad: "test-ad",
|
||||
cipher: testCipher,
|
||||
wantErr: true,
|
||||
errMessage: "validate",
|
||||
},
|
||||
{
|
||||
name: "undefined code type",
|
||||
typ: 99, // Undefined type
|
||||
id: "test-id",
|
||||
expires: time.Now().Add(time.Hour),
|
||||
ad: "test-ad",
|
||||
cipher: testCipher,
|
||||
wantErr: true,
|
||||
errMessage: "validate",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
code, err := CreateCode(tc.id, tc.expires, tc.ad, tc.cipher)
|
||||
code, err := CreateCode(tc.typ, tc.id, tc.expires, tc.ad, tc.cipher)
|
||||
|
||||
if tc.wantErr {
|
||||
assert.Error(t, err)
|
||||
|
@ -73,9 +115,10 @@ func TestCreateCode(t *testing.T) {
|
|||
assert.NoError(t, err)
|
||||
assert.NotEmpty(t, code)
|
||||
|
||||
decodedCode, err := DecryptCode(code, tc.cipher, tc.ad, time.Now())
|
||||
decodedCode, err := DecryptCode(tc.typ, code, tc.cipher, tc.ad, time.Now())
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, tc.id, decodedCode.Id)
|
||||
assert.Equal(t, tc.typ, decodedCode.GrantType)
|
||||
assert.True(t, proto.Equal(timestamppb.New(tc.expires), decodedCode.ExpiresAt))
|
||||
}
|
||||
})
|
||||
|
@ -91,14 +134,18 @@ func TestDecryptCode(t *testing.T) {
|
|||
future := now.Add(time.Hour)
|
||||
past := now.Add(-time.Hour)
|
||||
|
||||
validCode, err := CreateCode("test-id", future, "test-ad", testCipher)
|
||||
validCode, err := CreateCode(CodeTypeAuthorization, "test-id", future, "test-ad", testCipher)
|
||||
require.NoError(t, err)
|
||||
|
||||
expiredCode, err := CreateCode("expired-id", past, "test-ad", testCipher)
|
||||
validRefreshCode, err := CreateCode(CodeTypeRefresh, "refresh-id", future, "test-ad", testCipher)
|
||||
require.NoError(t, err)
|
||||
|
||||
expiredCode, err := CreateCode(CodeTypeAuthorization, "expired-id", past, "test-ad", testCipher)
|
||||
require.NoError(t, err)
|
||||
|
||||
codeNoExpiry := &oauth21proto.Code{
|
||||
Id: "no-expiry",
|
||||
Id: "no-expiry",
|
||||
GrantType: CodeTypeAuthorization,
|
||||
}
|
||||
codeBytes, err := proto.Marshal(codeNoExpiry)
|
||||
require.NoError(t, err)
|
||||
|
@ -107,6 +154,7 @@ func TestDecryptCode(t *testing.T) {
|
|||
|
||||
tests := []struct {
|
||||
name string
|
||||
typ oauth21proto.CodeType
|
||||
code string
|
||||
cipher cipher.AEAD
|
||||
ad string
|
||||
|
@ -117,15 +165,37 @@ func TestDecryptCode(t *testing.T) {
|
|||
}{
|
||||
{
|
||||
name: "valid code",
|
||||
typ: CodeTypeAuthorization,
|
||||
code: validCode,
|
||||
cipher: testCipher,
|
||||
ad: "test-ad",
|
||||
now: now,
|
||||
want: &oauth21proto.Code{Id: "test-id", ExpiresAt: timestamppb.New(future)},
|
||||
want: &oauth21proto.Code{Id: "test-id", ExpiresAt: timestamppb.New(future), GrantType: CodeTypeAuthorization},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "valid refresh code",
|
||||
typ: CodeTypeRefresh,
|
||||
code: validRefreshCode,
|
||||
cipher: testCipher,
|
||||
ad: "test-ad",
|
||||
now: now,
|
||||
want: &oauth21proto.Code{Id: "refresh-id", ExpiresAt: timestamppb.New(future), GrantType: CodeTypeRefresh},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "wrong code type",
|
||||
typ: CodeTypeAccess, // Using wrong type
|
||||
code: validCode, // This was created with Authorization type
|
||||
cipher: testCipher,
|
||||
ad: "test-ad",
|
||||
now: now,
|
||||
wantErr: true,
|
||||
errMessage: "code type mismatch",
|
||||
},
|
||||
{
|
||||
name: "expired code",
|
||||
typ: CodeTypeAuthorization,
|
||||
code: expiredCode,
|
||||
cipher: testCipher,
|
||||
ad: "test-ad",
|
||||
|
@ -135,15 +205,17 @@ func TestDecryptCode(t *testing.T) {
|
|||
},
|
||||
{
|
||||
name: "nil expiration",
|
||||
typ: CodeTypeAuthorization,
|
||||
code: codeNoExpiryStr,
|
||||
cipher: testCipher,
|
||||
ad: "test-ad",
|
||||
now: now,
|
||||
wantErr: true,
|
||||
errMessage: "expiration is nil",
|
||||
errMessage: "expires_at: value is required",
|
||||
},
|
||||
{
|
||||
name: "invalid base64",
|
||||
typ: CodeTypeAuthorization,
|
||||
code: "not-base64",
|
||||
cipher: testCipher,
|
||||
ad: "test-ad",
|
||||
|
@ -153,27 +225,48 @@ func TestDecryptCode(t *testing.T) {
|
|||
},
|
||||
{
|
||||
name: "wrong authentication data",
|
||||
typ: CodeTypeAuthorization,
|
||||
code: validCode,
|
||||
cipher: testCipher,
|
||||
ad: "wrong-ad",
|
||||
now: now,
|
||||
wantErr: true,
|
||||
errMessage: "decrypt",
|
||||
errMessage: "message authentication failed",
|
||||
},
|
||||
{
|
||||
name: "unspecified code type",
|
||||
typ: 0, // Unspecified type
|
||||
code: validCode,
|
||||
cipher: testCipher,
|
||||
ad: "test-ad",
|
||||
now: now,
|
||||
wantErr: true,
|
||||
errMessage: "code type mismatch",
|
||||
},
|
||||
{
|
||||
name: "undefined code type",
|
||||
typ: 99, // undefined type
|
||||
code: validCode,
|
||||
cipher: testCipher,
|
||||
ad: "test-ad",
|
||||
now: now,
|
||||
wantErr: true,
|
||||
errMessage: "code type mismatch",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
got, err := DecryptCode(tc.code, tc.cipher, tc.ad, tc.now)
|
||||
got, err := DecryptCode(tc.typ, tc.code, tc.cipher, tc.ad, tc.now)
|
||||
|
||||
if tc.wantErr {
|
||||
assert.Error(t, err)
|
||||
require.Error(t, err)
|
||||
if tc.errMessage != "" {
|
||||
assert.Contains(t, err.Error(), tc.errMessage)
|
||||
}
|
||||
assert.Nil(t, got)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, got)
|
||||
|
||||
diff := cmp.Diff(tc.want, got, protocmp.Transform())
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue