pomerium/internal/mcp/code_test.go
2025-04-25 14:47:11 -04:00

277 lines
6.8 KiB
Go

package mcp
import (
"crypto/cipher"
"encoding/base64"
"testing"
"time"
"github.com/google/go-cmp/cmp"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/testing/protocmp"
"google.golang.org/protobuf/types/known/timestamppb"
oauth21proto "github.com/pomerium/pomerium/internal/oauth21/gen"
"github.com/pomerium/pomerium/pkg/cryptutil"
)
// TestCreateCode checks that CreateCode produces an encrypted code that
// round-trips through DecryptCode for each valid code type (authorization,
// refresh, access), and that invalid inputs (empty id, zero expiry,
// unspecified or undefined code type) are rejected with an error mentioning
// "validate" and an empty code string.
func TestCreateCode(t *testing.T) {
	key := cryptutil.NewKey()
	testCipher, err := cryptutil.NewAEADCipher(key)
	require.NoError(t, err)

	// One reference time for the whole table so every valid case shares the
	// same expiry, one hour in the future.
	now := time.Now()
	inAnHour := now.Add(time.Hour)

	tests := []struct {
		name       string
		typ        oauth21proto.CodeType
		id         string
		expires    time.Time
		ad         string
		cipher     cipher.AEAD
		wantErr    bool
		errMessage string
	}{
		{
			name:    "valid authorization code",
			typ:     CodeTypeAuthorization,
			id:      "test-id",
			expires: inAnHour,
			ad:      "test-ad",
			cipher:  testCipher,
			wantErr: false,
		},
		{
			name:    "valid refresh code",
			typ:     CodeTypeRefresh,
			id:      "test-id",
			expires: inAnHour,
			ad:      "test-ad",
			cipher:  testCipher,
			wantErr: false,
		},
		{
			name:    "valid access code",
			typ:     CodeTypeAccess,
			id:      "test-id",
			expires: inAnHour,
			ad:      "test-ad",
			cipher:  testCipher,
			wantErr: false,
		},
		{
			name:       "empty id",
			typ:        CodeTypeAuthorization,
			id:         "",
			expires:    inAnHour,
			ad:         "test-ad",
			cipher:     testCipher,
			wantErr:    true,
			errMessage: "validate",
		},
		{
			name:       "empty expires",
			typ:        CodeTypeAuthorization,
			id:         "test-id",
			expires:    time.Time{},
			ad:         "test-ad",
			cipher:     testCipher,
			wantErr:    true,
			errMessage: "validate",
		},
		{
			name:       "invalid code type",
			typ:        0, // Unspecified type
			id:         "test-id",
			expires:    inAnHour,
			ad:         "test-ad",
			cipher:     testCipher,
			wantErr:    true,
			errMessage: "validate",
		},
		{
			name:       "undefined code type",
			typ:        99, // Undefined type
			id:         "test-id",
			expires:    inAnHour,
			ad:         "test-ad",
			cipher:     testCipher,
			wantErr:    true,
			errMessage: "validate",
		},
	}

	for _, tc := range tests {
		t.Run(tc.name, func(t *testing.T) {
			code, err := CreateCode(tc.typ, tc.id, tc.expires, tc.ad, tc.cipher)
			if tc.wantErr {
				// require (not assert): if err is unexpectedly nil, the
				// err.Error() call below would otherwise panic.
				require.Error(t, err)
				if tc.errMessage != "" {
					assert.Contains(t, err.Error(), tc.errMessage)
				}
				assert.Empty(t, code)
				return
			}

			// Stop early on failure so DecryptCode is never fed an empty code.
			require.NoError(t, err)
			require.NotEmpty(t, code)

			decodedCode, err := DecryptCode(tc.typ, code, tc.cipher, tc.ad, time.Now())
			require.NoError(t, err)
			assert.Equal(t, tc.id, decodedCode.Id)
			assert.Equal(t, tc.typ, decodedCode.GrantType)
			// cmp.Diff with protocmp reports what differs, unlike proto.Equal.
			assert.Empty(t, cmp.Diff(timestamppb.New(tc.expires), decodedCode.ExpiresAt, protocmp.Transform()))
		})
	}
}
// TestDecryptCode exercises DecryptCode against a table of encrypted codes.
// It covers the success path for authorization and refresh codes, and these
// failure paths:
//   - a type mismatch, including unspecified and undefined code types
//   - an expired code
//   - a code with no expiry
//   - malformed base64
//   - mismatched additional authenticated data
func TestDecryptCode(t *testing.T) {
	secret := cryptutil.NewKey()
	aead, err := cryptutil.NewAEADCipher(secret)
	require.NoError(t, err)

	now := time.Now()
	future := now.Add(time.Hour)
	past := now.Add(-time.Hour)

	// mustCreate builds a code bound to the "test-ad" associated data,
	// failing the test immediately if creation errors.
	mustCreate := func(typ oauth21proto.CodeType, id string, expiresAt time.Time) string {
		t.Helper()
		encoded, createErr := CreateCode(typ, id, expiresAt, "test-ad", aead)
		require.NoError(t, createErr)
		return encoded
	}
	authCode := mustCreate(CodeTypeAuthorization, "test-id", future)
	refreshCode := mustCreate(CodeTypeRefresh, "refresh-id", future)
	staleCode := mustCreate(CodeTypeAuthorization, "expired-id", past)

	// CreateCode would reject a missing expiry, so this code is marshaled and
	// encrypted by hand to reach DecryptCode's own validation.
	plain, err := proto.Marshal(&oauth21proto.Code{
		Id:        "no-expiry",
		GrantType: CodeTypeAuthorization,
	})
	require.NoError(t, err)
	sealed := cryptutil.Encrypt(aead, plain, []byte("test-ad"))
	noExpiryCode := base64.StdEncoding.EncodeToString(sealed)

	tests := []struct {
		name       string
		typ        oauth21proto.CodeType
		code       string
		cipher     cipher.AEAD
		ad         string
		now        time.Time
		want       *oauth21proto.Code
		wantErr    bool
		errMessage string
	}{
		{
			name:   "valid code",
			typ:    CodeTypeAuthorization,
			code:   authCode,
			cipher: aead,
			ad:     "test-ad",
			now:    now,
			want:   &oauth21proto.Code{Id: "test-id", ExpiresAt: timestamppb.New(future), GrantType: CodeTypeAuthorization},
		},
		{
			name:   "valid refresh code",
			typ:    CodeTypeRefresh,
			code:   refreshCode,
			cipher: aead,
			ad:     "test-ad",
			now:    now,
			want:   &oauth21proto.Code{Id: "refresh-id", ExpiresAt: timestamppb.New(future), GrantType: CodeTypeRefresh},
		},
		{
			name:       "wrong code type",
			typ:        CodeTypeAccess, // decrypting an authorization code as an access code
			code:       authCode,
			cipher:     aead,
			ad:         "test-ad",
			now:        now,
			wantErr:    true,
			errMessage: "code type mismatch",
		},
		{
			name:       "expired code",
			typ:        CodeTypeAuthorization,
			code:       staleCode,
			cipher:     aead,
			ad:         "test-ad",
			now:        now,
			wantErr:    true,
			errMessage: "code expired",
		},
		{
			name:       "nil expiration",
			typ:        CodeTypeAuthorization,
			code:       noExpiryCode,
			cipher:     aead,
			ad:         "test-ad",
			now:        now,
			wantErr:    true,
			errMessage: "expires_at: value is required",
		},
		{
			name:       "invalid base64",
			typ:        CodeTypeAuthorization,
			code:       "not-base64",
			cipher:     aead,
			ad:         "test-ad",
			now:        now,
			wantErr:    true,
			errMessage: "base64 decode",
		},
		{
			name:       "wrong authentication data",
			typ:        CodeTypeAuthorization,
			code:       authCode,
			cipher:     aead,
			ad:         "wrong-ad",
			now:        now,
			wantErr:    true,
			errMessage: "message authentication failed",
		},
		{
			name:       "unspecified code type",
			typ:        0, // the zero (unspecified) enum value
			code:       authCode,
			cipher:     aead,
			ad:         "test-ad",
			now:        now,
			wantErr:    true,
			errMessage: "code type mismatch",
		},
		{
			name:       "undefined code type",
			typ:        99, // a value outside the declared enum
			code:       authCode,
			cipher:     aead,
			ad:         "test-ad",
			now:        now,
			wantErr:    true,
			errMessage: "code type mismatch",
		},
	}

	for _, tc := range tests {
		t.Run(tc.name, func(t *testing.T) {
			decoded, decErr := DecryptCode(tc.typ, tc.code, tc.cipher, tc.ad, tc.now)

			if !tc.wantErr {
				require.NoError(t, decErr)
				require.NotNil(t, decoded)
				assert.Empty(t, cmp.Diff(tc.want, decoded, protocmp.Transform()))
				return
			}

			require.Error(t, decErr)
			if tc.errMessage != "" {
				assert.Contains(t, decErr.Error(), tc.errMessage)
			}
			assert.Nil(t, decoded)
		})
	}
}