package protoutil

import (
	"fmt"
	"testing"
	"time"

	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
	"google.golang.org/protobuf/encoding/protojson"
	"google.golang.org/protobuf/proto"
	"google.golang.org/protobuf/types/known/structpb"
	"google.golang.org/protobuf/types/known/wrapperspb"

	"github.com/pomerium/pomerium/pkg/cryptutil"
)

func TestEncryptor_Encrypt(t *testing.T) {
	t.Run("simple", func(t *testing.T) {
		kek, err := cryptutil.GenerateKeyEncryptionKey()
		require.NoError(t, err)
		enc := NewEncryptor(kek.Public())
		sealed, err := enc.Encrypt(wrapperspb.String("HELLO WORLD"))
		require.NoError(t, err)
		require.Equal(t, kek.Public().ID(), sealed.GetKeyId())
		require.NotEmpty(t, sealed.GetDataEncryptionKey())
		require.Equal(t, "type.googleapis.com/google.protobuf.StringValue", sealed.GetMessageType())
		require.NotEmpty(t, sealed.GetEncryptedMessage())
	})

	t.Run("reuse dek", func(t *testing.T) {
		kek, err := cryptutil.GenerateKeyEncryptionKey()
		require.NoError(t, err)
		enc := NewEncryptor(kek.Public())
		s1, err := enc.Encrypt(wrapperspb.String("HELLO WORLD"))
		require.NoError(t, err)
		s2, err := enc.Encrypt(wrapperspb.String("HELLO WORLD"))
		require.NoError(t, err)
		assert.Equal(t, s1.GetDataEncryptionKey(), s2.GetDataEncryptionKey())
	})
	t.Run("rotate dek", func(t *testing.T) {
		kek, err := cryptutil.GenerateKeyEncryptionKey()
		require.NoError(t, err)
		enc := NewEncryptor(kek.Public())
		s1, err := enc.Encrypt(wrapperspb.String("HELLO WORLD"))
		require.NoError(t, err)
		enc.nextRotate = time.Now()
		s2, err := enc.Encrypt(wrapperspb.String("HELLO WORLD"))
		require.NoError(t, err)
		assert.NotEqual(t, s1.GetDataEncryptionKey(), s2.GetDataEncryptionKey())
	})
}

func TestDecryptor_Decrypt(t *testing.T) {
	expect := wrapperspb.String("HELLO WORLD")

	kek, err := cryptutil.GenerateKeyEncryptionKey()
	require.NoError(t, err)

	enc := NewEncryptor(kek.Public())
	sealed, err := enc.Encrypt(expect)
	require.NoError(t, err)

	dec := NewDecryptor(cryptutil.KeyEncryptionKeySourceFunc(func(id string) (*cryptutil.PrivateKeyEncryptionKey, error) {
		require.Equal(t, kek.ID(), id)
		return kek, nil
	}))
	opened, err := dec.Decrypt(sealed)
	require.NoError(t, err)
	assertProtoEqual(t, expect, opened)
}

func assertProtoEqual(t *testing.T, x, y proto.Message) {
	xbs, _ := protojson.Marshal(x)
	ybs, _ := protojson.Marshal(y)
	assert.True(t, proto.Equal(x, y), "%s != %s", xbs, ybs)
}

func BenchmarkEncrypt(b *testing.B) {
	m := map[string]interface{}{}
	for i := 0; i < 10; i++ {
		mm := map[string]interface{}{}
		for j := 0; j < 10; j++ {
			mm[fmt.Sprintf("key%d", j)] = fmt.Sprintf("value%d", j)
		}
		m[fmt.Sprintf("key%d", i)] = mm
	}

	obj, err := structpb.NewStruct(m)
	require.NoError(b, err)

	kek, err := cryptutil.GenerateKeyEncryptionKey()
	require.NoError(b, err)
	enc := NewEncryptor(kek.Public())

	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		_, err := enc.Encrypt(obj)
		require.NoError(b, err)
	}
}