package mcp

import (
	"crypto/cipher"
	"encoding/base64"
	"testing"
	"time"

	"github.com/google/go-cmp/cmp"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
	"google.golang.org/protobuf/proto"
	"google.golang.org/protobuf/testing/protocmp"
	"google.golang.org/protobuf/types/known/timestamppb"

	oauth21proto "github.com/pomerium/pomerium/internal/oauth21/gen"
	"github.com/pomerium/pomerium/pkg/cryptutil"
)

// TestCreateCode runs CreateCode against a table of valid and invalid inputs.
//
// Cases that expect success are round-tripped through DecryptCode and the
// decoded id, grant type and expiry are compared with the inputs. Cases that
// expect failure must return an error whose text contains errMessage (when
// set) and an empty code string.
func TestCreateCode(t *testing.T) {
	key := cryptutil.NewKey()
	testCipher, err := cryptutil.NewAEADCipher(key)
	require.NoError(t, err)

	tests := []struct {
		name       string
		typ        oauth21proto.CodeType
		id         string
		expires    time.Time
		ad         string
		cipher     cipher.AEAD
		wantErr    bool
		errMessage string
	}{
		{
			name:    "valid authorization code",
			typ:     CodeTypeAuthorization,
			id:      "test-id",
			expires: time.Now().Add(time.Hour),
			ad:      "test-ad",
			cipher:  testCipher,
			wantErr: false,
		},
		{
			name:    "valid refresh code",
			typ:     CodeTypeRefresh,
			id:      "test-id",
			expires: time.Now().Add(time.Hour),
			ad:      "test-ad",
			cipher:  testCipher,
			wantErr: false,
		},
		{
			name:    "valid access code",
			typ:     CodeTypeAccess,
			id:      "test-id",
			expires: time.Now().Add(time.Hour),
			ad:      "test-ad",
			cipher:  testCipher,
			wantErr: false,
		},
		{
			name:       "empty id",
			typ:        CodeTypeAuthorization,
			id:         "",
			expires:    time.Now().Add(time.Hour),
			ad:         "test-ad",
			cipher:     testCipher,
			wantErr:    true,
			errMessage: "validate",
		},
		{
			name:       "empty expires",
			typ:        CodeTypeAuthorization,
			id:         "test-id",
			expires:    time.Time{},
			ad:         "test-ad",
			cipher:     testCipher,
			wantErr:    true,
			errMessage: "validate",
		},
		{
			name:       "invalid code type",
			typ:        0, // Unspecified type
			id:         "test-id",
			expires:    time.Now().Add(time.Hour),
			ad:         "test-ad",
			cipher:     testCipher,
			wantErr:    true,
			errMessage: "validate",
		},
		{
			name:       "undefined code type",
			typ:        99, // Undefined type
			id:         "test-id",
			expires:    time.Now().Add(time.Hour),
			ad:         "test-ad",
			cipher:     testCipher,
			wantErr:    true,
			errMessage: "validate",
		},
	}

	for _, tc := range tests {
		t.Run(tc.name, func(t *testing.T) {
			code, err := CreateCode(tc.typ, tc.id, tc.expires, tc.ad, tc.cipher)

			if tc.wantErr {
				// require (not assert) so that an unexpected nil error ends
				// the subtest here rather than panicking on err.Error().
				require.Error(t, err)
				if tc.errMessage != "" {
					assert.Contains(t, err.Error(), tc.errMessage)
				}
				assert.Empty(t, code)
				return
			}

			// Stop before the round trip if creation itself failed; decoding
			// an empty code would only add a second, misleading failure.
			require.NoError(t, err)
			assert.NotEmpty(t, code)

			decodedCode, err := DecryptCode(tc.typ, code, tc.cipher, tc.ad, time.Now())
			require.NoError(t, err)
			assert.Equal(t, tc.id, decodedCode.Id)
			assert.Equal(t, tc.typ, decodedCode.GrantType)
			assert.True(t, proto.Equal(timestamppb.New(tc.expires), decodedCode.ExpiresAt))
		})
	}
}

// TestDecryptCode drives DecryptCode with codes minted by CreateCode and with
// one hand-encrypted code whose ExpiresAt is left unset.
//
// Successful cases compare the decoded message with the expected proto using
// cmp.Diff and protocmp.Transform. Failing cases require an error containing
// errText (when set) and a nil result.
func TestDecryptCode(t *testing.T) {
	aead, err := cryptutil.NewAEADCipher(cryptutil.NewKey())
	require.NoError(t, err)

	// Associated data every fixture code is encrypted with.
	const boundAD = "test-ad"

	issuedAt := time.Now()
	notAfter := issuedAt.Add(time.Hour)
	longAgo := issuedAt.Add(-time.Hour)

	// mint builds a fixture code via CreateCode and aborts the test on error.
	mint := func(typ oauth21proto.CodeType, id string, expires time.Time) string {
		t.Helper()
		out, err := CreateCode(typ, id, expires, boundAD, aead)
		require.NoError(t, err)
		return out
	}

	authToken := mint(CodeTypeAuthorization, "test-id", notAfter)
	refreshToken := mint(CodeTypeRefresh, "refresh-id", notAfter)
	staleToken := mint(CodeTypeAuthorization, "expired-id", longAgo)

	// Encrypt a message without ExpiresAt directly, bypassing CreateCode.
	rawNoExpiry, err := proto.Marshal(&oauth21proto.Code{
		Id:        "no-expiry",
		GrantType: CodeTypeAuthorization,
	})
	require.NoError(t, err)
	noExpiryToken := base64.StdEncoding.EncodeToString(
		cryptutil.Encrypt(aead, rawNoExpiry, []byte(boundAD)),
	)

	cases := []struct {
		name       string
		grant      oauth21proto.CodeType
		token      string
		aead       cipher.AEAD
		ad         string
		at         time.Time
		expect     *oauth21proto.Code
		shouldFail bool
		errText    string
	}{
		{
			name:   "valid code",
			grant:  CodeTypeAuthorization,
			token:  authToken,
			aead:   aead,
			ad:     boundAD,
			at:     issuedAt,
			expect: &oauth21proto.Code{Id: "test-id", ExpiresAt: timestamppb.New(notAfter), GrantType: CodeTypeAuthorization},
		},
		{
			name:   "valid refresh code",
			grant:  CodeTypeRefresh,
			token:  refreshToken,
			aead:   aead,
			ad:     boundAD,
			at:     issuedAt,
			expect: &oauth21proto.Code{Id: "refresh-id", ExpiresAt: timestamppb.New(notAfter), GrantType: CodeTypeRefresh},
		},
		{
			name:       "wrong code type",
			grant:      CodeTypeAccess, // requested type differs from the minted one
			token:      authToken,      // minted as an authorization code
			aead:       aead,
			ad:         boundAD,
			at:         issuedAt,
			shouldFail: true,
			errText:    "code type mismatch",
		},
		{
			name:       "expired code",
			grant:      CodeTypeAuthorization,
			token:      staleToken,
			aead:       aead,
			ad:         boundAD,
			at:         issuedAt,
			shouldFail: true,
			errText:    "code expired",
		},
		{
			name:       "nil expiration",
			grant:      CodeTypeAuthorization,
			token:      noExpiryToken,
			aead:       aead,
			ad:         boundAD,
			at:         issuedAt,
			shouldFail: true,
			errText:    "expires_at: value is required",
		},
		{
			name:       "invalid base64",
			grant:      CodeTypeAuthorization,
			token:      "not-base64",
			aead:       aead,
			ad:         boundAD,
			at:         issuedAt,
			shouldFail: true,
			errText:    "base64 decode",
		},
		{
			name:       "wrong authentication data",
			grant:      CodeTypeAuthorization,
			token:      authToken,
			aead:       aead,
			ad:         "wrong-ad",
			at:         issuedAt,
			shouldFail: true,
			errText:    "message authentication failed",
		},
		{
			name:       "unspecified code type",
			grant:      0, // zero value of the enum
			token:      authToken,
			aead:       aead,
			ad:         boundAD,
			at:         issuedAt,
			shouldFail: true,
			errText:    "code type mismatch",
		},
		{
			name:       "undefined code type",
			grant:      99, // value not used by any case above
			token:      authToken,
			aead:       aead,
			ad:         boundAD,
			at:         issuedAt,
			shouldFail: true,
			errText:    "code type mismatch",
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			got, err := DecryptCode(tc.grant, tc.token, tc.aead, tc.ad, tc.at)

			if !tc.shouldFail {
				require.NoError(t, err)
				require.NotNil(t, got)
				assert.Empty(t, cmp.Diff(tc.expect, got, protocmp.Transform()))
				return
			}

			require.Error(t, err)
			if tc.errText != "" {
				assert.Contains(t, err.Error(), tc.errText)
			}
			assert.Nil(t, got)
		})
	}
}