package config

import (
	"bytes"
	"encoding/base64"
	"encoding/json"
	"fmt"
	"net/url"
	"reflect"
	"strconv"
	"strings"
	"unicode"

	envoy_config_cluster_v3 "github.com/envoyproxy/go-control-plane/envoy/config/cluster/v3"
	"github.com/mitchellh/mapstructure"
	"github.com/volatiletech/null/v9"
	"google.golang.org/protobuf/encoding/protojson"
	"google.golang.org/protobuf/proto"
	"gopkg.in/yaml.v3"

	"github.com/pomerium/pomerium/internal/httputil"
	"github.com/pomerium/pomerium/internal/urlutil"
	"github.com/pomerium/pomerium/pkg/policy/parser"
)

// decodeNullBoolHookFunc returns a mapstructure decode hook for null.Bool
// targets. Input destined for a null.Bool is round-tripped through JSON, so any
// value that null.Bool's JSON decoding accepts is accepted here. Values bound
// for any other type pass through untouched.
func decodeNullBoolHookFunc() mapstructure.DecodeHookFunc {
	nullBoolType := reflect.TypeOf(null.Bool{})
	return func(_, to reflect.Type, data interface{}) (interface{}, error) {
		if to != nullBoolType {
			return data, nil
		}

		encoded, err := json.Marshal(data)
		if err != nil {
			return nil, err
		}

		var decoded null.Bool
		if err := json.Unmarshal(encoded, &decoded); err != nil {
			return nil, err
		}
		return decoded, nil
	}
}

// JWTClaimHeaders are headers to add to a request based on IDP claims.
// Each key is a header name and each value is the claim it is filled from.
type JWTClaimHeaders map[string]string

// NewJWTClaimHeaders creates a JWTClaimHeaders map from a slice of claims,
// mapping the header name httputil.PomeriumJWTHeaderName derives for each claim
// back to that claim.
func NewJWTClaimHeaders(claims ...string) JWTClaimHeaders {
	headers := make(JWTClaimHeaders, len(claims))
	for _, c := range claims {
		headers[httputil.PomeriumJWTHeaderName(c)] = c
	}
	return headers
}

// UnmarshalJSON unmarshals JSON data into the JWTClaimHeaders. Three shapes are
// accepted. An object is used directly as a header-name to claim-name mapping,
// with non-string values stringified via fmt.Sprint. An array is treated as a
// list of claim names (each element stringified via fmt.Sprint) and expanded
// with NewJWTClaimHeaders. A string is split on commas and whitespace and then
// treated like the array form. Any other input is an error.
func (hdrs *JWTClaimHeaders) UnmarshalJSON(data []byte) error {
	var m map[string]interface{}
	if json.Unmarshal(data, &m) == nil {
		out := make(JWTClaimHeaders, len(m))
		for k, v := range m {
			out[k] = fmt.Sprint(v)
		}
		*hdrs = out
		return nil
	}

	var a []interface{}
	if json.Unmarshal(data, &a) == nil {
		claims := make([]string, 0, len(a))
		for _, v := range a {
			claims = append(claims, fmt.Sprint(v))
		}
		*hdrs = NewJWTClaimHeaders(claims...)
		return nil
	}

	var s string
	if json.Unmarshal(data, &s) == nil {
		*hdrs = NewJWTClaimHeaders(strings.FieldsFunc(s, func(r rune) bool {
			return r == ',' || unicode.IsSpace(r)
		})...)
		return nil
	}

	// the string form is accepted above, so the message lists it as well
	return fmt.Errorf("JWTClaimHeaders must be an object, an array of values, or a string, got: %s", data)
}

// UnmarshalYAML unmarshals YAML data into the JWTClaimHeaders. The YAML is
// decoded generically, its map keys are converted to strings with serializable,
// and the result is re-encoded as JSON and handed to UnmarshalJSON. Both formats
// therefore accept the same shapes.
func (hdrs *JWTClaimHeaders) UnmarshalYAML(unmarshal func(interface{}) error) error {
	var raw interface{}
	if err := unmarshal(&raw); err != nil {
		return err
	}

	converted, err := serializable(raw)
	if err != nil {
		return err
	}

	encoded, err := json.Marshal(converted)
	if err != nil {
		return err
	}
	return hdrs.UnmarshalJSON(encoded)
}

// decodeJWTClaimHeadersHookFunc returns a mapstructure decode hook for
// JWTClaimHeaders targets. It re-encodes the raw input as JSON and decodes it
// with JWTClaimHeaders.UnmarshalJSON. Other targets pass through unchanged.
func decodeJWTClaimHeadersHookFunc() mapstructure.DecodeHookFunc {
	claimHeadersType := reflect.TypeOf(JWTClaimHeaders{})
	return func(_, to reflect.Type, data interface{}) (interface{}, error) {
		if to != claimHeadersType {
			return data, nil
		}

		encoded, err := json.Marshal(data)
		if err != nil {
			return nil, err
		}

		var headers JWTClaimHeaders
		if err := json.Unmarshal(encoded, &headers); err != nil {
			return nil, err
		}
		return headers, nil
	}
}

// A StringSlice is a slice of strings.
type StringSlice []string

// NewStringSlice creates a new StringSlice from values (the slice is not copied).
func NewStringSlice(values ...string) StringSlice {
	return values
}

// Parser states used by StringSlice.UnmarshalJSON. arrayValue, objectKey and
// objectValue say what the next scalar token inside the innermost open
// container is; array and object name the two container kinds.
const (
	array = iota
	arrayValue
	object
	objectKey
	objectValue
)

// UnmarshalJSON unmarshals a JSON document into the string slice.
//
// Every scalar (string, number or bool) in the document is collected, in
// document order and regardless of nesting depth. Numbers and bools are
// formatted with fmt.Sprint, while nulls and object keys are skipped. So `"a"`,
// `["a","b"]` and `{"x":"a","y":["b"]}` all decode to a flat slice of values.
// Invalid JSON is reported as an error.
func (slc *StringSlice) UnmarshalJSON(data []byte) error {
	// json.Unmarshal validates before calling this method, but direct callers may
	// not; without this check a malformed document would be silently truncated.
	if !json.Valid(data) {
		return fmt.Errorf("invalid JSON for StringSlice: %s", data)
	}

	// stateStack holds one parser state per open container. The base entry lets
	// a bare top-level scalar be collected like an array element.
	stateStack := []int{arrayValue}

	var vals []string
	dec := json.NewDecoder(bytes.NewReader(data))
	for {
		token, err := dec.Token()
		if err != nil {
			// For valid input this is io.EOF once the document is exhausted
			// (Token can also fail on a number that overflows float64); either
			// way, stop collecting.
			break
		}

		top := len(stateStack) - 1
		if delim, ok := token.(json.Delim); ok {
			switch delim {
			case '[', '{':
				// A container nested as an object value fills that value slot,
				// so the enclosing object must expect a key once it closes.
				if stateStack[top] == objectValue {
					stateStack[top] = objectKey
				}
				if delim == '[' {
					stateStack = append(stateStack, arrayValue)
				} else {
					stateStack = append(stateStack, objectKey)
				}
			case ']', '}':
				stateStack = stateStack[:top]
			}
			continue
		}

		switch stateStack[top] {
		case objectKey:
			// keys are never collected; the next scalar is this key's value
			stateStack[top] = objectValue
			continue
		case objectValue:
			stateStack[top] = objectKey
		}

		switch t := token.(type) {
		case bool, float64, json.Number:
			vals = append(vals, fmt.Sprint(t))
		case string:
			vals = append(vals, t)
		}
	}
	*slc = StringSlice(vals)
	return nil
}

// UnmarshalYAML unmarshals a YAML document into the string slice. The YAML is
// decoded into a generic value, re-encoded as JSON and passed to UnmarshalJSON,
// so both formats accept the same shapes.
func (slc *StringSlice) UnmarshalYAML(unmarshal func(interface{}) error) error {
	var i interface{}
	err := unmarshal(&i)
	if err != nil {
		return err
	}
	bs, err := json.Marshal(i)
	if err != nil {
		return err
	}
	return slc.UnmarshalJSON(bs)
}

// WeightedURL is an upstream URL with an optional load balancing weight.
type WeightedURL struct {
	URL url.URL
	// LbWeight is a relative load balancer weight for this upstream URL;
	// zero means no weight was assigned.
	LbWeight uint32
}

// Validate returns errHostnameMustBeSpecified when the URL has no host name,
// errSchemeMustBeSpecified when it has no scheme, and nil otherwise. The
// host name is checked first.
func (u *WeightedURL) Validate() error {
	switch {
	case u.URL.Hostname() == "":
		return errHostnameMustBeSpecified
	case u.URL.Scheme == "":
		return errSchemeMustBeSpecified
	}
	return nil
}

// ParseWeightedURL parses dst, a URL optionally followed by ",<weight>", into a
// WeightedURL. URL parse failures are wrapped with the offending URL, and a URL
// without a host name yields errHostnameMustBeSpecified.
func ParseWeightedURL(dst string) (*WeightedURL, error) {
	rawURL, weight, err := weightedString(dst)
	if err != nil {
		return nil, err
	}

	parsed, err := urlutil.ParseAndValidateURL(rawURL)
	switch {
	case err != nil:
		return nil, fmt.Errorf("%s: %w", rawURL, err)
	case parsed.Hostname() == "":
		return nil, errHostnameMustBeSpecified
	}

	return &WeightedURL{URL: *parsed, LbWeight: weight}, nil
}

// String returns the WeightedURL as a string. This is the bare URL when no
// weight is assigned, or "{url=<url>, weight=<n>}" otherwise.
func (u *WeightedURL) String() string {
	if u.LbWeight != 0 {
		return fmt.Sprintf("{url=%s, weight=%d}", u.URL.String(), u.LbWeight)
	}
	return u.URL.String()
}

// WeightedURLs is a slice of WeightedURLs.
type WeightedURLs []WeightedURL

// ParseWeightedUrls parses each entry of urls with ParseWeightedURL and then
// validates the group as a whole (see WeightedURLs.Validate). An empty input or
// a mix of weighted and unweighted URLs is an error.
func ParseWeightedUrls(urls ...string) (WeightedURLs, error) {
	parsed := make(WeightedURLs, len(urls))
	for i, raw := range urls {
		u, err := ParseWeightedURL(raw)
		if err != nil {
			return nil, err
		}
		parsed[i] = *u
	}

	if _, err := parsed.Validate(); err != nil {
		return nil, err
	}
	return parsed, nil
}

// HasWeight indicates if url group has weights assigned
type HasWeight bool

// Validate checks every URL in the group and reports whether the group is
// weighted. It fails with errEmptyUrls for an empty group, with the first
// invalid URL's error (prefixed by that URL), or with errEndpointWeightsSpec
// when some URLs carry a weight and others do not.
func (urls WeightedURLs) Validate() (HasWeight, error) {
	if len(urls) == 0 {
		return false, errEmptyUrls
	}

	var weighted, unweighted int
	for i := range urls {
		u := &urls[i]
		if err := u.Validate(); err != nil {
			return false, fmt.Errorf("%s: %w", u.String(), err)
		}
		if u.LbWeight != 0 {
			weighted++
		} else {
			unweighted++
		}
	}

	// either all URLs have weights or none do
	if weighted > 0 && unweighted > 0 {
		return false, errEndpointWeightsSpec
	}
	return HasWeight(weighted > 0), nil
}

// Flatten converts the weighted URL group into parallel slices of URL strings
// and load balancer weights, after validating the group with Validate. When no
// URL carries a weight, the returned weights slice is nil. Otherwise weights[i]
// is the weight of urls[i].
func (urls WeightedURLs) Flatten() ([]string, []uint32, error) {
	hasWeight, err := urls.Validate()
	if err != nil {
		return nil, nil, err
	}

	strs := make([]string, 0, len(urls))
	var weights []uint32
	if hasWeight {
		// only build the weights slice when it is actually returned
		weights = make([]uint32, 0, len(urls))
	}

	for i := range urls {
		strs = append(strs, urls[i].URL.String())
		if hasWeight {
			weights = append(weights, urls[i].LbWeight)
		}
	}
	return strs, weights, nil
}

// PPLPolicy is a policy defined using PPL.
type PPLPolicy struct {
	*parser.Policy
}

// UnmarshalJSON parses JSON into a PPL policy using parser.ParseJSON. The
// embedded Policy is set to whatever the parser returns, even when the parser
// also returns an error.
func (ppl *PPLPolicy) UnmarshalJSON(data []byte) error {
	policy, err := parser.ParseJSON(bytes.NewReader(data))
	ppl.Policy = policy
	return err
}

// UnmarshalYAML parses YAML into a PPL policy. The YAML is decoded generically,
// re-encoded as JSON and passed to UnmarshalJSON.
func (ppl *PPLPolicy) UnmarshalYAML(unmarshal func(interface{}) error) error {
	var raw interface{}
	if err := unmarshal(&raw); err != nil {
		return err
	}

	encoded, err := json.Marshal(raw)
	if err != nil {
		return err
	}
	return ppl.UnmarshalJSON(encoded)
}

func decodePPLPolicyHookFunc() mapstructure.DecodeHookFunc {
	return func(f, t reflect.Type, data interface{}) (interface{}, error) {
		if t != reflect.TypeOf(&PPLPolicy{}) {
			return data, nil
		}
		bs, err := json.Marshal(data)
		if err != nil {
			return nil, err
		}
		var ppl PPLPolicy
		err = json.Unmarshal(bs, &ppl)
		if err != nil {
			return nil, err
		}
		return &ppl, nil
	}
}

// DecodePolicyBase64Hook returns a mapstructure decode hook for base64 data.
// When the target type is []Policy and the input is a []string, the input must
// contain exactly one element. That element is decoded with standard (padded)
// base64 and parsed as a YAML list of mappings, and the result is returned for
// further decoding. Any other target or input passes through unchanged.
func DecodePolicyBase64Hook() mapstructure.DecodeHookFunc {
	return func(f, t reflect.Type, data interface{}) (interface{}, error) {
		if t != reflect.TypeOf([]Policy{}) {
			return data, nil
		}

		str, ok := data.([]string)
		if !ok {
			return data, nil
		}

		if len(str) != 1 {
			return nil, fmt.Errorf("base64 policy data: expecting 1, got %d", len(str))
		}

		// named raw rather than bytes so the imported bytes package is not shadowed
		raw, err := base64.StdEncoding.DecodeString(str[0])
		if err != nil {
			return nil, fmt.Errorf("base64 decoding policy data: %w", err)
		}

		var out []map[interface{}]interface{}
		if err = yaml.Unmarshal(raw, &out); err != nil {
			return nil, fmt.Errorf("parsing base64-encoded policy data as yaml: %w", err)
		}

		return out, nil
	}
}

// DecodePolicyHookFunc returns a Decode Hook for mapstructure that prepares raw
// input destined for a Policy. All map keys are converted to strings, so the
// data can be re-serialized to JSON and read with protojson into Envoy's Cluster
// type. The result is then passed through parsePolicy. The top-level input must
// be a mapping with string keys. Other targets pass through unchanged.
func DecodePolicyHookFunc() mapstructure.DecodeHookFunc {
	policyType := reflect.TypeOf(Policy{})
	return func(_, to reflect.Type, data interface{}) (interface{}, error) {
		if to != policyType {
			return data, nil
		}

		converted, err := serializable(data)
		if err != nil {
			return nil, err
		}

		if policyMap, ok := converted.(map[string]interface{}); ok {
			return parsePolicy(policyMap)
		}
		return nil, errKeysMustBeStrings
	}
}

// parsePolicy copies src into a new map and replaces the value under toKey with
// the []WeightedURL that parseTo produces from it. It also stores, under
// envoyOptsKey, the whole of src interpreted as an Envoy Cluster by
// parseEnvoyClusterOpts. Any parse error aborts with a nil map.
func parsePolicy(src map[string]interface{}) (out map[string]interface{}, err error) {
	out = make(map[string]interface{}, len(src))
	for key, val := range src {
		if key != toKey {
			out[key] = val
			continue
		}
		urls, toErr := parseTo(val)
		if toErr != nil {
			return nil, toErr
		}
		out[key] = urls
	}

	clusterOpts, err := parseEnvoyClusterOpts(src)
	if err != nil {
		return nil, err
	}
	out[envoyOptsKey] = clusterOpts

	return out, nil
}

// parseTo converts a raw "to" value into weighted URLs. The value is flattened
// into a list of strings via JSON and StringSlice, then parsed and validated
// with ParseWeightedUrls.
func parseTo(raw interface{}) ([]WeightedURL, error) {
	encoded, err := json.Marshal(raw)
	if err != nil {
		return nil, err
	}

	var destinations StringSlice
	if err := json.Unmarshal(encoded, &destinations); err != nil {
		return nil, err
	}
	return ParseWeightedUrls(destinations...)
}

// weightedString splits str of the form "<url>[,<weight>]" into the URL and its
// load balancing weight. The weight is the text after the last comma, so a URL
// that itself contains commas (for example in its query string) can still carry
// a weight. Without any comma the whole string is the URL and the weight is 0
// (unassigned). An explicit weight must be a base-10 integer that fits in
// uint32. A weight of 0 yields errZeroWeight, and non-numeric text yields the
// strconv parse error.
func weightedString(str string) (string, uint32, error) {
	i := strings.LastIndexByte(str, ',')
	if i < 0 {
		return str, 0, nil
	}

	w, err := strconv.ParseUint(str[i+1:], 10, 32)
	if err != nil {
		return "", 0, err
	}

	if w == 0 {
		return "", 0, errZeroWeight
	}

	return str[:i], uint32(w), nil
}

// parseEnvoyClusterOpts interprets src as an Envoy Cluster document
// (https://www.envoyproxy.io/docs/envoy/latest/api-v3/config/cluster/v3/cluster.proto).
// It decodes src with the protoPartial unmarshal options into a freshly
// allocated Cluster.
func parseEnvoyClusterOpts(src map[string]interface{}) (*envoy_config_cluster_v3.Cluster, error) {
	cluster := &envoy_config_cluster_v3.Cluster{}
	if err := parseJSONPB(src, cluster, protoPartial); err != nil {
		return nil, err
	}
	return cluster, nil
}

// parseJSONPB encodes the intermediate map src as JSON and decodes that JSON
// into dst with the given protojson options. Using protojson rather than
// encoding/json is what makes oneof fields and other protobuf-specific types
// decode correctly.
func parseJSONPB(src map[string]interface{}, dst proto.Message, opts protojson.UnmarshalOptions) error {
	encoded, err := json.Marshal(src)
	if err == nil {
		err = opts.Unmarshal(encoded, dst)
	}
	return err
}

// serializable recursively converts generic decoded data (as produced by YAML
// decoders or mapstructure) into a form encoding/json can marshal. Every
// map[interface{}]interface{} becomes a map[string]interface{}. Maps and
// []interface{} slices are walked so that nested maps are converted as well.
// Any other value, including a nil map[string]interface{}, is returned
// unchanged. A non-string map key yields errKeysMustBeStrings.
func serializable(in interface{}) (interface{}, error) {
	switch typed := in.(type) {
	case map[interface{}]interface{}:
		m := make(map[string]interface{}, len(typed))
		for k, v := range typed {
			kstr, ok := k.(string)
			if !ok {
				return nil, errKeysMustBeStrings
			}
			val, err := serializable(v)
			if err != nil {
				return nil, err
			}
			m[kstr] = val
		}
		return m, nil
	case map[string]interface{}:
		// Keys are already strings, but values may still hold
		// map[interface{}]interface{}, which encoding/json cannot marshal.
		if typed == nil {
			return in, nil
		}
		m := make(map[string]interface{}, len(typed))
		for k, v := range typed {
			val, err := serializable(v)
			if err != nil {
				return nil, err
			}
			m[k] = val
		}
		return m, nil
	case []interface{}:
		out := make([]interface{}, 0, len(typed))
		for _, elem := range typed {
			val, err := serializable(elem)
			if err != nil {
				return nil, err
			}
			out = append(out, val)
		}
		return out, nil
	default:
		return in, nil
	}
}