package cryptutil

import (
	"encoding/base64"
	"reflect"
	"testing"

	"github.com/stretchr/testify/assert"
)

func TestEncodeAndDecodeAccessToken(t *testing.T) {
	plaintext := []byte("my plain text value")

	key := NewKey()
	c, err := NewAEADCipher(key)
	if err != nil {
		t.Fatalf("unexpected err: %v", err)
	}

	ciphertext := Encrypt(c, plaintext, nil)

	if reflect.DeepEqual(plaintext, ciphertext) {
		t.Fatalf("plaintext is not encrypted plaintext:%v ciphertext:%x", plaintext, ciphertext)
	}

	diffKey, err := NewAEADCipher(NewKey())
	if err != nil {
		t.Fatalf("unexpected err: %v", err)
	}
	// key mismatch
	_, err = Decrypt(diffKey, ciphertext, nil)
	assert.Error(t, err)

	// bad data size
	_, err = Decrypt(c, []byte("oh"), nil)
	assert.Error(t, err)

	// good
	got, err := Decrypt(c, ciphertext, nil)
	assert.NoError(t, err)
	assert.Equal(t, got, plaintext)
}

func TestNewAEADCipher(t *testing.T) {
	t.Parallel()
	tests := []struct {
		name    string
		secret  []byte
		wantErr bool
	}{
		{"simple 32 byte key", NewKey(), false},
		{"key too short", []byte("what is entropy"), true},
		{"key too long", []byte(NewRandomStringN(33)), true},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			_, err := NewAEADCipher(tt.secret)
			if (err != nil) != tt.wantErr {
				t.Errorf("NewAEADCipher() error = %v, wantErr %v", err, tt.wantErr)
				return
			}
		})
	}
}

func BenchmarkAEADCipher(b *testing.B) {
	plaintext := []byte("my plain text value")

	key := NewKey()
	c, err := NewAEADCipher(key)
	if !assert.NoError(b, err) {
		return
	}

	ciphertext := Encrypt(c, plaintext, nil)
	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		Decrypt(c, ciphertext, nil)
	}
}

func TestNewAEADCipherFromBase64(t *testing.T) {
	t.Parallel()
	tests := []struct {
		name    string
		s       string
		wantErr bool
	}{
		{"simple 32 byte key", base64.StdEncoding.EncodeToString(NewKey()), false},
		{"key too short", base64.StdEncoding.EncodeToString([]byte("what is entropy")), true},
		{"key too long", NewRandomStringN(33), true},
		{"bad base 64", string(NewKey()), true},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			_, err := NewAEADCipherFromBase64(tt.s)
			if (err != nil) != tt.wantErr {
				t.Errorf("NewAEADCipherFromBase64() error = %v, wantErr %v", err, tt.wantErr)
				return
			}
		})
	}
}