package hpke

import (
	"fmt"
	"net/url"

	"github.com/klauspost/compress/zstd"
)

// URL parameter names for the V1 format: the sender's public key and the
// sealed (encrypted) query string.
const (
	paramSenderPublicKey = "pomerium_hpke_sender_pub"
	paramQuery           = "pomerium_hpke_query"
)

// URL parameter names for the V2 format. They carry the same two pieces of
// data as V1 under shorter, single-letter names.
const (
	paramSenderPublicKeyV2 = "k"
	paramQueryV2           = "q"
)

// IsEncryptedURL reports whether values carry an HPKE encrypted query in
// either the V1 or the V2 parameter format.
func IsEncryptedURL(values url.Values) bool {
	if IsEncryptedURLV1(values) {
		return true
	}
	return IsEncryptedURLV2(values)
}

// IsEncryptedURLV1 reports whether values contain both V1 HPKE parameters:
// the sender public key and the sealed query. Either one alone is not enough.
func IsEncryptedURLV1(values url.Values) bool {
	if !values.Has(paramSenderPublicKey) {
		return false
	}
	return values.Has(paramQuery)
}

// IsEncryptedURLV2 reports whether values contain both V2 HPKE parameters:
// the sender public key and the sealed query. Either one alone is not enough.
func IsEncryptedURLV2(values url.Values) bool {
	if !values.Has(paramSenderPublicKeyV2) {
		return false
	}
	return values.Has(paramQueryV2)
}

// EncryptURLValuesFunc is the signature shared by EncryptURLValuesV1 and
// EncryptURLValuesV2: given the sender's private key, the receiver's public
// key and the plaintext values, it returns the encrypted url.Values.
type EncryptURLValuesFunc func(
	senderPrivateKey *PrivateKey,
	receiverPublicKey *PublicKey,
	values url.Values,
) (encrypted url.Values, err error)

// EncryptURLValuesV1 encrypts URL values using the Seal method.
//
// Any HPKE parameters (V1 or V2) already present in values are dropped first.
// The remaining values are encoded as a plain query string and sealed from
// senderPrivateKey to receiverPublicKey. The result is a new url.Values that
// holds only two parameters: the sender's public key and the encoded sealed
// query. values itself is not modified.
func EncryptURLValuesV1(
	senderPrivateKey *PrivateKey,
	receiverPublicKey *PublicKey,
	values url.Values,
) (encrypted url.Values, err error) {
	values = withoutHPKEParams(values)

	encoded := encodeQueryStringV1(values)

	sealed, err := Seal(senderPrivateKey, receiverPublicKey, encoded)
	if err != nil {
		return nil, fmt.Errorf("hpke: failed to seal URL values: %w", err)
	}

	return url.Values{
		paramSenderPublicKey: {senderPrivateKey.PublicKey().String()},
		paramQuery:           {encode(sealed)},
	}, nil
}

// EncryptURLValuesV2 encrypts URL values using the Seal method and compresses the query string.
//
// Any HPKE parameters (V1 or V2) already present in values are dropped first.
// The remaining values are encoded as a query string and zstd-compressed
// (see encodeQueryStringV2). The compressed bytes are sealed from
// senderPrivateKey to receiverPublicKey. The result is a new url.Values that
// holds only the two short V2 parameters: the sender's public key and the
// encoded sealed query. values itself is not modified.
func EncryptURLValuesV2(
	senderPrivateKey *PrivateKey,
	receiverPublicKey *PublicKey,
	values url.Values,
) (encrypted url.Values, err error) {
	values = withoutHPKEParams(values)

	encoded := encodeQueryStringV2(values)

	sealed, err := Seal(senderPrivateKey, receiverPublicKey, encoded)
	if err != nil {
		return nil, fmt.Errorf("hpke: failed to seal URL values: %w", err)
	}

	return url.Values{
		paramSenderPublicKeyV2: {senderPrivateKey.PublicKey().String()},
		paramQueryV2:           {encode(sealed)},
	}, nil
}

// DecryptURLValues decrypts URL values using the Open method.
//
// The format is chosen from the parameters present. V1 is checked first, so
// if both the V1 and V2 parameter pairs are present, the V1 pair is used.
//
// The sender's public key is read from the URL and used, together with
// receiverPrivateKey, to open the sealed query. The result combines two sets
// of values:
//   - every non-HPKE parameter that was already in encrypted, and
//   - the decrypted parameters, which replace any same-named plaintext ones.
//
// An error is returned if no complete parameter pair is present, or if any
// step (key parsing, decoding, opening, query parsing) fails.
func DecryptURLValues(
	receiverPrivateKey *PrivateKey,
	encrypted url.Values,
) (senderPublicKey *PublicKey, values url.Values, err error) {
	var (
		keyParam, queryParam string
		decodeQuery          func([]byte) (url.Values, error)
	)
	switch {
	case IsEncryptedURLV1(encrypted):
		keyParam, queryParam, decodeQuery = paramSenderPublicKey, paramQuery, decodeQueryStringV1
	case IsEncryptedURLV2(encrypted):
		keyParam, queryParam, decodeQuery = paramSenderPublicKeyV2, paramQueryV2, decodeQueryStringV2
	default:
		return nil, nil, fmt.Errorf("hpke: missing query parameters")
	}

	senderPublicKey, decrypted, err := openSealedQuery(
		receiverPrivateKey,
		encrypted.Get(keyParam),
		encrypted.Get(queryParam),
		decodeQuery,
	)
	if err != nil {
		return nil, nil, err
	}

	values = withoutHPKEParams(encrypted)
	for k, vs := range decrypted {
		values[k] = vs
	}

	return senderPublicKey, values, nil
}

// openSealedQuery opens one sealed query and decodes its plaintext.
//
// It parses rawSenderPublicKey, decodes rawQuery into the sealed message,
// opens it with receiverPrivateKey, and turns the resulting plaintext into
// url.Values with decodeQuery. It is shared by the V1 and V2 paths of
// DecryptURLValues; only the parameter names and decodeQuery differ.
func openSealedQuery(
	receiverPrivateKey *PrivateKey,
	rawSenderPublicKey, rawQuery string,
	decodeQuery func([]byte) (url.Values, error),
) (*PublicKey, url.Values, error) {
	senderPublicKey, err := PublicKeyFromString(rawSenderPublicKey)
	if err != nil {
		return nil, nil, fmt.Errorf("hpke: invalid sender public key parameter: %w", err)
	}

	sealed, err := decode(rawQuery)
	if err != nil {
		return nil, nil, fmt.Errorf("hpke: failed decoding query parameter: %w", err)
	}

	message, err := Open(receiverPrivateKey, senderPublicKey, sealed)
	if err != nil {
		return nil, nil, fmt.Errorf("hpke: failed to open sealed message: %w", err)
	}

	decrypted, err := decodeQuery(message)
	if err != nil {
		return nil, nil, fmt.Errorf("hpke: invalid query parameter: %w", err)
	}

	return senderPublicKey, decrypted, nil
}

// withoutHPKEParams returns a fresh url.Values with every entry of values
// except the V1 and V2 HPKE parameters. The input map is left unmodified.
// The value slices are shared with it, not copied. A nil input yields an
// empty, non-nil map.
func withoutHPKEParams(values url.Values) url.Values {
	result := make(url.Values)
	for key, entries := range values {
		switch key {
		case paramSenderPublicKey, paramQuery, paramSenderPublicKeyV2, paramQueryV2:
			continue
		}
		result[key] = entries
	}
	return result
}

// encodeQueryStringV1 serializes values as a standard URL query string
// (keys sorted, as produced by url.Values.Encode), uncompressed, as bytes.
func encodeQueryStringV1(values url.Values) []byte {
	query := values.Encode()
	return []byte(query)
}

// zstdWindowSize is the compression window used for V2 query strings.
const zstdWindowSize = 8 << 10 // 8kiB

// zstdEncoder compresses V2 query strings at the default speed level with a
// zstdWindowSize window.
var zstdEncoder = func() *zstd.Encoder {
	enc, err := zstd.NewWriter(nil,
		zstd.WithEncoderLevel(zstd.SpeedDefault),
		zstd.WithWindowSize(zstdWindowSize),
	)
	if err != nil {
		// The options above are fixed, so an error means they are invalid.
		// Fail loudly at init rather than discarding the error and leaving a
		// broken encoder to be discovered at the first encode.
		panic(fmt.Sprintf("hpke: invalid zstd encoder options: %v", err))
	}
	return enc
}()

// encodeQueryStringV2 serializes values as a URL query string and compresses
// the result with the package-level zstd encoder, returning a new slice.
func encodeQueryStringV2(values url.Values) []byte {
	plain := []byte(values.Encode())
	return zstdEncoder.EncodeAll(plain, nil)
}

// decodeQueryStringV1 parses raw as an uncompressed URL query string. Any
// parse error from url.ParseQuery is returned unchanged.
func decodeQueryStringV1(raw []byte) (url.Values, error) {
	query := string(raw)
	return url.ParseQuery(query)
}

// zstdMaxDecodedSize caps how many bytes a single V2 query may decompress to.
// The sender public key is taken from the URL itself (see DecryptURLValues),
// so the compressed payload must be treated as untrusted. Without a cap, a
// tiny crafted zstd frame could expand into a very large allocation.
const zstdMaxDecodedSize = 1 << 20 // 1MiB

// zstdDecoder decompresses V2 query strings. It runs in low-memory mode and
// rejects frames whose decoded output would exceed zstdMaxDecodedSize.
var zstdDecoder = func() *zstd.Decoder {
	dec, err := zstd.NewReader(nil,
		zstd.WithDecoderLowmem(true),
		zstd.WithDecoderMaxMemory(zstdMaxDecodedSize),
	)
	if err != nil {
		// The options above are fixed, so an error means they are invalid.
		// Fail loudly at init rather than discarding the error and leaving a
		// broken decoder to be discovered at the first decode.
		panic(fmt.Sprintf("hpke: invalid zstd decoder options: %v", err))
	}
	return dec
}()

// decodeQueryStringV2 decompresses raw with the package-level zstd decoder
// and parses the result as a URL query string. Decompression errors are
// returned as-is; otherwise the url.ParseQuery result is returned.
func decodeQueryStringV2(raw []byte) (url.Values, error) {
	decompressed, err := zstdDecoder.DecodeAll(raw, nil)
	if err == nil {
		return url.ParseQuery(string(decompressed))
	}
	return nil, err
}