mirror of
https://github.com/pomerium/pomerium.git
synced 2025-08-03 16:59:22 +02:00
mcp: extend code usage (#5588)
This commit is contained in:
parent
9e4947c62f
commit
4dd5357fe3
6 changed files with 244 additions and 52 deletions
|
@ -14,7 +14,14 @@ import (
|
|||
"github.com/pomerium/pomerium/pkg/cryptutil"
|
||||
)
|
||||
|
||||
const (
|
||||
CodeTypeAuthorization = oauth21proto.CodeType_CODE_TYPE_AUTHORIZATION
|
||||
CodeTypeRefresh = oauth21proto.CodeType_CODE_TYPE_REFRESH
|
||||
CodeTypeAccess = oauth21proto.CodeType_CODE_TYPE_ACCESS
|
||||
)
|
||||
|
||||
func CreateCode(
|
||||
typ oauth21proto.CodeType,
|
||||
id string,
|
||||
expires time.Time,
|
||||
ad string,
|
||||
|
@ -27,6 +34,7 @@ func CreateCode(
|
|||
v := oauth21proto.Code{
|
||||
Id: id,
|
||||
ExpiresAt: timestamppb.New(expires),
|
||||
GrantType: typ,
|
||||
}
|
||||
|
||||
err := protovalidate.Validate(&v)
|
||||
|
@ -44,6 +52,7 @@ func CreateCode(
|
|||
}
|
||||
|
||||
func DecryptCode(
|
||||
typ oauth21proto.CodeType,
|
||||
code string,
|
||||
cipher cipher.AEAD,
|
||||
ad string,
|
||||
|
@ -62,8 +71,12 @@ func DecryptCode(
|
|||
if err != nil {
|
||||
return nil, fmt.Errorf("unmarshal: %w", err)
|
||||
}
|
||||
if v.ExpiresAt == nil {
|
||||
return nil, fmt.Errorf("expiration is nil")
|
||||
err = protovalidate.Validate(&v)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("validate: %w", err)
|
||||
}
|
||||
if v.GrantType != typ {
|
||||
return nil, fmt.Errorf("code type mismatch: expected %v, got %v", typ, v.GrantType)
|
||||
}
|
||||
if v.ExpiresAt.AsTime().Before(now) {
|
||||
return nil, fmt.Errorf("code expired")
|
||||
|
|
|
@ -24,6 +24,7 @@ func TestCreateCode(t *testing.T) {
|
|||
|
||||
tests := []struct {
|
||||
name string
|
||||
typ oauth21proto.CodeType
|
||||
id string
|
||||
expires time.Time
|
||||
ad string
|
||||
|
@ -32,7 +33,26 @@ func TestCreateCode(t *testing.T) {
|
|||
errMessage string
|
||||
}{
|
||||
{
|
||||
name: "valid code",
|
||||
name: "valid authorization code",
|
||||
typ: CodeTypeAuthorization,
|
||||
id: "test-id",
|
||||
expires: time.Now().Add(time.Hour),
|
||||
ad: "test-ad",
|
||||
cipher: testCipher,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "valid refresh code",
|
||||
typ: CodeTypeRefresh,
|
||||
id: "test-id",
|
||||
expires: time.Now().Add(time.Hour),
|
||||
ad: "test-ad",
|
||||
cipher: testCipher,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "valid access code",
|
||||
typ: CodeTypeAccess,
|
||||
id: "test-id",
|
||||
expires: time.Now().Add(time.Hour),
|
||||
ad: "test-ad",
|
||||
|
@ -41,6 +61,7 @@ func TestCreateCode(t *testing.T) {
|
|||
},
|
||||
{
|
||||
name: "empty id",
|
||||
typ: CodeTypeAuthorization,
|
||||
id: "",
|
||||
expires: time.Now().Add(time.Hour),
|
||||
ad: "test-ad",
|
||||
|
@ -50,6 +71,7 @@ func TestCreateCode(t *testing.T) {
|
|||
},
|
||||
{
|
||||
name: "empty expires",
|
||||
typ: CodeTypeAuthorization,
|
||||
id: "test-id",
|
||||
expires: time.Time{},
|
||||
ad: "test-ad",
|
||||
|
@ -57,11 +79,31 @@ func TestCreateCode(t *testing.T) {
|
|||
wantErr: true,
|
||||
errMessage: "validate",
|
||||
},
|
||||
{
|
||||
name: "invalid code type",
|
||||
typ: 0, // Unspecified type
|
||||
id: "test-id",
|
||||
expires: time.Now().Add(time.Hour),
|
||||
ad: "test-ad",
|
||||
cipher: testCipher,
|
||||
wantErr: true,
|
||||
errMessage: "validate",
|
||||
},
|
||||
{
|
||||
name: "undefined code type",
|
||||
typ: 99, // Undefined type
|
||||
id: "test-id",
|
||||
expires: time.Now().Add(time.Hour),
|
||||
ad: "test-ad",
|
||||
cipher: testCipher,
|
||||
wantErr: true,
|
||||
errMessage: "validate",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
code, err := CreateCode(tc.id, tc.expires, tc.ad, tc.cipher)
|
||||
code, err := CreateCode(tc.typ, tc.id, tc.expires, tc.ad, tc.cipher)
|
||||
|
||||
if tc.wantErr {
|
||||
assert.Error(t, err)
|
||||
|
@ -73,9 +115,10 @@ func TestCreateCode(t *testing.T) {
|
|||
assert.NoError(t, err)
|
||||
assert.NotEmpty(t, code)
|
||||
|
||||
decodedCode, err := DecryptCode(code, tc.cipher, tc.ad, time.Now())
|
||||
decodedCode, err := DecryptCode(tc.typ, code, tc.cipher, tc.ad, time.Now())
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, tc.id, decodedCode.Id)
|
||||
assert.Equal(t, tc.typ, decodedCode.GrantType)
|
||||
assert.True(t, proto.Equal(timestamppb.New(tc.expires), decodedCode.ExpiresAt))
|
||||
}
|
||||
})
|
||||
|
@ -91,14 +134,18 @@ func TestDecryptCode(t *testing.T) {
|
|||
future := now.Add(time.Hour)
|
||||
past := now.Add(-time.Hour)
|
||||
|
||||
validCode, err := CreateCode("test-id", future, "test-ad", testCipher)
|
||||
validCode, err := CreateCode(CodeTypeAuthorization, "test-id", future, "test-ad", testCipher)
|
||||
require.NoError(t, err)
|
||||
|
||||
expiredCode, err := CreateCode("expired-id", past, "test-ad", testCipher)
|
||||
validRefreshCode, err := CreateCode(CodeTypeRefresh, "refresh-id", future, "test-ad", testCipher)
|
||||
require.NoError(t, err)
|
||||
|
||||
expiredCode, err := CreateCode(CodeTypeAuthorization, "expired-id", past, "test-ad", testCipher)
|
||||
require.NoError(t, err)
|
||||
|
||||
codeNoExpiry := &oauth21proto.Code{
|
||||
Id: "no-expiry",
|
||||
Id: "no-expiry",
|
||||
GrantType: CodeTypeAuthorization,
|
||||
}
|
||||
codeBytes, err := proto.Marshal(codeNoExpiry)
|
||||
require.NoError(t, err)
|
||||
|
@ -107,6 +154,7 @@ func TestDecryptCode(t *testing.T) {
|
|||
|
||||
tests := []struct {
|
||||
name string
|
||||
typ oauth21proto.CodeType
|
||||
code string
|
||||
cipher cipher.AEAD
|
||||
ad string
|
||||
|
@ -117,15 +165,37 @@ func TestDecryptCode(t *testing.T) {
|
|||
}{
|
||||
{
|
||||
name: "valid code",
|
||||
typ: CodeTypeAuthorization,
|
||||
code: validCode,
|
||||
cipher: testCipher,
|
||||
ad: "test-ad",
|
||||
now: now,
|
||||
want: &oauth21proto.Code{Id: "test-id", ExpiresAt: timestamppb.New(future)},
|
||||
want: &oauth21proto.Code{Id: "test-id", ExpiresAt: timestamppb.New(future), GrantType: CodeTypeAuthorization},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "valid refresh code",
|
||||
typ: CodeTypeRefresh,
|
||||
code: validRefreshCode,
|
||||
cipher: testCipher,
|
||||
ad: "test-ad",
|
||||
now: now,
|
||||
want: &oauth21proto.Code{Id: "refresh-id", ExpiresAt: timestamppb.New(future), GrantType: CodeTypeRefresh},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "wrong code type",
|
||||
typ: CodeTypeAccess, // Using wrong type
|
||||
code: validCode, // This was created with Authorization type
|
||||
cipher: testCipher,
|
||||
ad: "test-ad",
|
||||
now: now,
|
||||
wantErr: true,
|
||||
errMessage: "code type mismatch",
|
||||
},
|
||||
{
|
||||
name: "expired code",
|
||||
typ: CodeTypeAuthorization,
|
||||
code: expiredCode,
|
||||
cipher: testCipher,
|
||||
ad: "test-ad",
|
||||
|
@ -135,15 +205,17 @@ func TestDecryptCode(t *testing.T) {
|
|||
},
|
||||
{
|
||||
name: "nil expiration",
|
||||
typ: CodeTypeAuthorization,
|
||||
code: codeNoExpiryStr,
|
||||
cipher: testCipher,
|
||||
ad: "test-ad",
|
||||
now: now,
|
||||
wantErr: true,
|
||||
errMessage: "expiration is nil",
|
||||
errMessage: "expires_at: value is required",
|
||||
},
|
||||
{
|
||||
name: "invalid base64",
|
||||
typ: CodeTypeAuthorization,
|
||||
code: "not-base64",
|
||||
cipher: testCipher,
|
||||
ad: "test-ad",
|
||||
|
@ -153,27 +225,48 @@ func TestDecryptCode(t *testing.T) {
|
|||
},
|
||||
{
|
||||
name: "wrong authentication data",
|
||||
typ: CodeTypeAuthorization,
|
||||
code: validCode,
|
||||
cipher: testCipher,
|
||||
ad: "wrong-ad",
|
||||
now: now,
|
||||
wantErr: true,
|
||||
errMessage: "decrypt",
|
||||
errMessage: "message authentication failed",
|
||||
},
|
||||
{
|
||||
name: "unspecified code type",
|
||||
typ: 0, // Unspecified type
|
||||
code: validCode,
|
||||
cipher: testCipher,
|
||||
ad: "test-ad",
|
||||
now: now,
|
||||
wantErr: true,
|
||||
errMessage: "code type mismatch",
|
||||
},
|
||||
{
|
||||
name: "undefined code type",
|
||||
typ: 99, // undefined type
|
||||
code: validCode,
|
||||
cipher: testCipher,
|
||||
ad: "test-ad",
|
||||
now: now,
|
||||
wantErr: true,
|
||||
errMessage: "code type mismatch",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
got, err := DecryptCode(tc.code, tc.cipher, tc.ad, tc.now)
|
||||
got, err := DecryptCode(tc.typ, tc.code, tc.cipher, tc.ad, tc.now)
|
||||
|
||||
if tc.wantErr {
|
||||
assert.Error(t, err)
|
||||
require.Error(t, err)
|
||||
if tc.errMessage != "" {
|
||||
assert.Contains(t, err.Error(), tc.errMessage)
|
||||
}
|
||||
assert.Nil(t, got)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, got)
|
||||
|
||||
diff := cmp.Diff(tc.want, got, protocmp.Transform())
|
||||
|
|
|
@ -81,6 +81,7 @@ func (srv *Handler) AuthorizationResponse(
|
|||
req *oauth21proto.AuthorizationRequest,
|
||||
) {
|
||||
code, err := CreateCode(
|
||||
CodeTypeAuthorization,
|
||||
id,
|
||||
time.Now().Add(time.Minute*10),
|
||||
req.ClientId,
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue