package clusterping

import (
	"context"
	"crypto/tls"
	"encoding/json"
	"errors"
	"fmt"
	"mime"
	"net"
	"net/http"
	"net/url"

	"github.com/go-jose/go-jose/v3"

	"github.com/pomerium/pomerium/internal/version"
)

// JWKSPath is the URL path used by GetJWKSURL to locate the JSON Web Key Set
// document on a host.
const JWKSPath = "/.well-known/pomerium/jwks.json"

// CheckErrorCode identifies the category of failure reported by a CheckError.
type CheckErrorCode int

// The known check error codes. Values are assigned sequentially starting at 0.
const (
	// ErrInvalidCert is used when TLS certificate verification fails.
	ErrInvalidCert = CheckErrorCode(iota)
	// ErrDNSError is used when the request fails with a DNS error.
	ErrDNSError
	// ErrConnectionError is used for other network-level request failures.
	ErrConnectionError
	// ErrKeyNotFound is used when the expected key is absent from the JWKS.
	ErrKeyNotFound
	// ErrUnexpectedResponse is used when the response status, content type,
	// or body is not what was expected.
	ErrUnexpectedResponse
)

// CheckError is an error tagged with a CheckErrorCode.
//
// Code classifies the failure and Err holds the underlying error, which is
// the value returned by Unwrap.
type CheckError struct {
	Code CheckErrorCode
	Err  error
}

// NewCheckError returns a pointer to a new CheckError carrying the given code
// and underlying error.
func NewCheckError(code CheckErrorCode, err error) *CheckError {
	checkErr := new(CheckError)
	checkErr.Code = code
	checkErr.Err = err
	return checkErr
}

// errorCodeToString maps each CheckErrorCode to the human-readable label that
// (*CheckError).Error uses as its message prefix.
var errorCodeToString = map[CheckErrorCode]string{
	ErrInvalidCert:        "invalid certificate",
	ErrDNSError:           "DNS error",
	ErrConnectionError:    "connection error",
	ErrKeyNotFound:        "key not found",
	ErrUnexpectedResponse: "unexpected response",
}

// Error implements the error interface. The message has the form
// "<label>: <underlying error>", where the label is looked up in
// errorCodeToString; a code with no entry there yields an empty label.
func (e *CheckError) Error() string {
	label := errorCodeToString[e.Code]
	return fmt.Sprintf("%s: %v", label, e.Err)
}

// Unwrap returns the underlying error, allowing errors.Is and errors.As to
// inspect the error wrapped by this CheckError.
func (e *CheckError) Unwrap() error {
	return e.Err
}

// GetJWKSURL builds the HTTPS URL of the JWKS document for host by combining
// the "https" scheme, the given host, and JWKSPath.
func GetJWKSURL(host string) string {
	u := url.URL{Scheme: "https", Host: host, Path: JWKSPath}
	return u.String()
}

// CheckKey fetches the key set served at jwksURL using client and verifies
// that it contains a key with the same key ID as key.
//
// Any error from fetching the key set is returned unchanged. If the fetch
// succeeds but no key matches, a *CheckError with code ErrKeyNotFound is
// returned. A nil return means a matching key was found.
func CheckKey(
	ctx context.Context,
	jwksURL string,
	key jose.JSONWebKey,
	client *http.Client,
) error {
	remoteKeys, fetchErr := fetchKeys(ctx, client, jwksURL)
	if fetchErr != nil {
		return fetchErr
	}

	if containsKey(remoteKeys, key) {
		return nil
	}

	notFound := fmt.Errorf("key %s not found in JWKS", key.KeyID)
	return NewCheckError(ErrKeyNotFound, notFound)
}

// containsKey reports whether any element of keys has the same KeyID as key.
// Only the key IDs are compared; no other field of the keys is inspected.
func containsKey(keys []jose.JSONWebKey, key jose.JSONWebKey) bool {
	wantID := key.KeyID
	for i := range keys {
		if keys[i].KeyID == wantID {
			return true
		}
	}
	return false
}

// fetchKeys performs a GET request against jwksURL using client and decodes
// the response body as a JSON object whose "keys" member lists JSON Web Keys.
//
// Errors are reported as follows:
//   - a failure to build the request is wrapped and returned as a plain error;
//   - a failure to perform the request is classified by convertRequestError;
//   - a non-200 status, a media type other than application/json, or a body
//     that cannot be decoded yields a *CheckError with code
//     ErrUnexpectedResponse.
//
// On success the decoded keys are returned.
func fetchKeys(ctx context.Context, client *http.Client, jwksURL string) ([]jose.JSONWebKey, error) {
	req, err := http.NewRequestWithContext(ctx, http.MethodGet, jwksURL, nil)
	if err != nil {
		return nil, fmt.Errorf("error creating request: %w", err)
	}
	req.Header.Set("Accept", "application/json")
	req.Header.Set("User-Agent", version.UserAgent())
	resp, err := client.Do(req)
	if err != nil {
		return nil, convertRequestError(err)
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		return nil, NewCheckError(ErrUnexpectedResponse, fmt.Errorf("unexpected status code %d", resp.StatusCode))
	}

	// Compare the parsed media type rather than the raw header value so that
	// equivalent headers carrying parameters or different letter case (for
	// example "application/json; charset=utf-8") are not rejected. A missing
	// or malformed header fails to parse and is reported as unexpected.
	contentType := resp.Header.Get("Content-Type")
	mediaType, _, err := mime.ParseMediaType(contentType)
	if err != nil || mediaType != "application/json" {
		return nil, NewCheckError(ErrUnexpectedResponse, fmt.Errorf("unexpected content type %s", contentType))
	}

	var jwks struct {
		Keys []jose.JSONWebKey `json:"keys"`
	}
	if err := json.NewDecoder(resp.Body).Decode(&jwks); err != nil {
		return nil, NewCheckError(ErrUnexpectedResponse, fmt.Errorf("error decoding response: %w", err))
	}

	return jwks.Keys, nil
}

// convertRequestError classifies an error returned from performing an HTTP
// request. It returns a *CheckError wrapping err with:
//   - ErrInvalidCert when the chain contains a *tls.CertificateVerificationError,
//   - ErrDNSError when it contains a *net.DNSError,
//   - ErrConnectionError when it contains any other net.Error.
//
// Anything else is wrapped as a plain "error making request" error.
//
// The DNS case is tested before the generic net.Error case because
// *net.DNSError also satisfies net.Error.
//
// NOTE(review): *url.Error, the error type documented for http.Client.Do,
// itself satisfies net.Error, so the final fallback is probably unreachable
// for errors coming from client.Do. Confirm whether that is intended.
func convertRequestError(err error) error {
	var (
		certErr *tls.CertificateVerificationError
		dnsErr  *net.DNSError
		netErr  net.Error
	)

	switch {
	case errors.As(err, &certErr):
		return NewCheckError(ErrInvalidCert, err)
	case errors.As(err, &dnsErr):
		return NewCheckError(ErrDNSError, err)
	case errors.As(err, &netErr):
		return NewCheckError(ErrConnectionError, err)
	default:
		return fmt.Errorf("error making request: %w", err)
	}
}