package identity

import (
	"encoding/json"
	"fmt"
	"reflect"

	"google.golang.org/protobuf/types/known/anypb"
	"google.golang.org/protobuf/types/known/structpb"

	"github.com/pomerium/pomerium/pkg/protoutil"
)

// SessionClaims are claims that are attached to a session so we can store the
// raw id token alongside them.
//
// Because Claims is embedded, its methods (including the pointer-receiver
// UnmarshalJSON) are promoted to SessionClaims.
type SessionClaims struct {
	// Claims is the embedded claim map; its entries and methods are
	// accessible directly on a SessionClaims value.
	Claims
	// RawIDToken holds the raw id token. It is assigned by SetRawIDToken.
	RawIDToken string
}

// SetRawIDToken stores rawIDToken in the RawIDToken field, replacing whatever
// value was held there before.
func (sc *SessionClaims) SetRawIDToken(rawIDToken string) {
	sc.RawIDToken = rawIDToken
}

// Claims are JWT claims, held as a map from claim name to the claim's value.
type Claims map[string]interface{}

// NewClaimsFromRaw decodes every raw JSON value in raw and collects the
// results in a new Claims map keyed by the same names. Entries whose raw value
// cannot be decoded are left out of the result; no error is reported for them.
// The returned map is never nil.
func NewClaimsFromRaw(raw map[string]json.RawMessage) Claims {
	decoded := Claims{}
	for name, msg := range raw {
		var value interface{}
		if json.Unmarshal(msg, &value) != nil {
			// Best effort: an undecodable entry is skipped, not fatal.
			continue
		}
		decoded[name] = value
	}
	return decoded
}

// UnmarshalJSON decodes data as a JSON object and copies its members into the
// claims. Entries already present are kept unless data contains the same key,
// in which case the decoded value replaces them. A nil map is allocated up
// front, so the receiver is non-nil afterwards even when decoding fails.
func (claims *Claims) UnmarshalJSON(data []byte) error {
	if *claims == nil {
		*claims = Claims{}
	}
	dst := *claims

	var decoded map[string]interface{}
	if err := json.Unmarshal(data, &decoded); err != nil {
		return err
	}
	for name, value := range decoded {
		dst[name] = value
	}
	return nil
}

// Claims fills v from the claims by encoding them as JSON and decoding that
// JSON into v, so v must be a value json.Unmarshal can decode into (such as a
// pointer to a struct or map). The first encoding or decoding error is
// returned.
func (claims Claims) Claims(v interface{}) error {
	encoded, err := json.Marshal(claims)
	if err == nil {
		err = json.Unmarshal(encoded, v)
	}
	return err
}

// Flatten converts the claims into a single-level FlattenedClaims map.
//
// A value whose kind is a map is flattened recursively and each of its keys is
// appended to the parent key with a "." separator; a value whose kind is a
// slice is copied element by element; any other value is wrapped in a
// one-element slice. For example:
//
//	{ "a": { "b": { "c": 12345 } } } => { "a.b.c": [12345] }
func (claims Claims) Flatten() FlattenedClaims {
	out := FlattenedClaims{}
	for key, value := range claims {
		rv := reflect.ValueOf(value)
		kind := rv.Kind()

		if kind == reflect.Map {
			// Rebuild the nested map as Claims (stringifying its keys) so the
			// same method can flatten it.
			nested := Claims{}
			for it := rv.MapRange(); it.Next(); {
				nested[fmt.Sprint(it.Key().Interface())] = it.Value().Interface()
			}
			for subKey, subValues := range nested.Flatten() {
				out[key+"."+subKey] = subValues
			}
			continue
		}

		if kind == reflect.Slice {
			count := rv.Len()
			elems := make([]interface{}, 0, count)
			for i := 0; i < count; i++ {
				elems = append(elems, rv.Index(i).Interface())
			}
			out[key] = elems
			continue
		}

		out[key] = []interface{}{value}
	}
	return out
}

// ToAnyMap returns a map with the same keys as the claims in which every value
// has been converted with protoutil.ToAny. The returned map is never nil.
func (claims Claims) ToAnyMap() map[string]*anypb.Any {
	converted := make(map[string]*anypb.Any, len(claims))
	for name, value := range claims {
		converted[name] = protoutil.ToAny(value)
	}
	return converted
}

// FlattenedClaims are a set of claims flattened into a single-level map. Each
// key maps to a list of values; Claims.Flatten builds the keys of nested
// claims by joining the path segments with ".".
type FlattenedClaims map[string][]interface{}

// NewFlattenedClaimsFromPB builds FlattenedClaims from a map of protobuf list
// values by encoding m as JSON and decoding that JSON into the result. The
// returned map is never nil; a nil m yields an empty map.
func NewFlattenedClaimsFromPB(m map[string]*structpb.ListValue) FlattenedClaims {
	flattened := FlattenedClaims{}
	if m != nil {
		// The signature has no error return, so encode/decode failures are
		// deliberately ignored and the claims decoded so far are returned.
		encoded, _ := json.Marshal(m)
		_ = json.Unmarshal(encoded, &flattened)
	}
	return flattened
}

// ToPB converts the flattened claims into a map of protobuf list values. A nil
// receiver yields nil. Otherwise every key maps to a ListValue whose elements
// are the claim's values converted, in order, with protoutil.ToStruct.
func (claims FlattenedClaims) ToPB() map[string]*structpb.ListValue {
	if claims == nil {
		return nil
	}

	pb := make(map[string]*structpb.ListValue, len(claims))
	for name, values := range claims {
		list := &structpb.ListValue{
			Values: make([]*structpb.Value, 0, len(values)),
		}
		for _, value := range values {
			list.Values = append(list.Values, protoutil.ToStruct(value))
		}
		pb[name] = list
	}
	return pb
}

// UnmarshalJSON decodes data as (possibly nested) Claims, flattens them and
// merges the result into the receiver, overwriting entries that share a key.
// If decoding fails the error is returned and the receiver is left untouched;
// otherwise a nil receiver map is allocated before the merge.
func (claims *FlattenedClaims) UnmarshalJSON(data []byte) error {
	var nested Claims
	if err := json.Unmarshal(data, &nested); err != nil {
		return err
	}

	dst := *claims
	if dst == nil {
		dst = FlattenedClaims{}
		*claims = dst
	}
	for key, values := range nested.Flatten() {
		dst[key] = values
	}
	return nil
}