// Package jwtutil contains functions for working with JWTs.
package jwtutil

import (
	"bytes"
	"encoding/json"
	"fmt"
	"reflect"
	"time"
)

// Claims is the set of claims carried by a JWT, keyed by claim name. Values
// are untyped; the Get* methods convert individual claims to typed values.
type Claims map[string]any

// UnmarshalJSON decodes raw into claims, replacing any previous contents.
//
// Numbers are decoded as json.Number rather than float64, so their original
// textual representation is kept. On a decoding error the receiver is left
// untouched and the error is returned as-is.
func (claims *Claims) UnmarshalJSON(raw []byte) error {
	decoder := json.NewDecoder(bytes.NewReader(raw))
	decoder.UseNumber()

	parsed := make(map[string]any)
	if err := decoder.Decode(&parsed); err != nil {
		return err
	}

	*claims = Claims(parsed)
	return nil
}

// registered claims

// GetIssuer returns the "iss" claim as a string. ok reports whether the
// claim is present.
func (claims Claims) GetIssuer() (issuer string, ok bool) {
	issuer, ok = claims.GetString("iss")
	return issuer, ok
}

// GetSubject returns the "sub" claim as a string. ok reports whether the
// claim is present.
func (claims Claims) GetSubject() (subject string, ok bool) {
	subject, ok = claims.GetString("sub")
	return subject, ok
}

// GetAudience returns the "aud" claim as a slice of strings. ok reports
// whether the claim is present.
func (claims Claims) GetAudience() (audiences []string, ok bool) {
	audiences, ok = claims.GetStringSlice("aud")
	return audiences, ok
}

// GetExpirationTime returns the "exp" claim as a time. ok is whatever
// GetNumericDate reports for that claim.
func (claims Claims) GetExpirationTime() (expirationTime time.Time, ok bool) {
	expirationTime, ok = claims.GetNumericDate("exp")
	return expirationTime, ok
}

// GetNotBefore returns the "nbf" claim as a time. ok is whatever
// GetNumericDate reports for that claim.
func (claims Claims) GetNotBefore() (notBefore time.Time, ok bool) {
	notBefore, ok = claims.GetNumericDate("nbf")
	return notBefore, ok
}

// GetIssuedAt returns the "iat" claim as a time. ok is whatever
// GetNumericDate reports for that claim.
func (claims Claims) GetIssuedAt() (issuedAt time.Time, ok bool) {
	issuedAt, ok = claims.GetNumericDate("iat")
	return issuedAt, ok
}

// GetJWTID returns the "jti" claim as a string. ok reports whether the
// claim is present.
func (claims Claims) GetJWTID() (jwtID string, ok bool) {
	jwtID, ok = claims.GetString("jti")
	return jwtID, ok
}

// custom claims

// GetUserID returns the user identifier: the "oid" claim when present,
// otherwise the "sub" claim. ok is false when neither claim exists.
func (claims Claims) GetUserID() (userID string, ok bool) {
	// "oid" takes precedence over "sub".
	if userID, ok = claims.GetString("oid"); ok {
		return userID, true
	}
	if userID, ok = claims.GetSubject(); ok {
		return userID, true
	}
	return "", false
}

// GetNumericDate returns the claim called name as a time, interpreting its
// value as a number of seconds since the Unix epoch.
//
// Every built-in integer and floating-point type is accepted, as is
// json.Number (the representation produced by Claims.UnmarshalJSON).
// Fractional seconds are truncated. ok is false when the claim is missing,
// is not numeric, or does not fit in an int64 number of seconds; in that
// case tm is the zero time.
func (claims Claims) GetNumericDate(name string) (tm time.Time, ok bool) {
	// Indexing a nil map is safe, so no separate nil check is required.
	raw, found := claims[name]
	if !found {
		return tm, false
	}

	switch v := raw.(type) {
	case float32:
		return numericDateFromFloat(float64(v))
	case float64:
		return numericDateFromFloat(v)
	case int64:
		return time.Unix(v, 0), true
	case int32:
		return time.Unix(int64(v), 0), true
	case int16:
		return time.Unix(int64(v), 0), true
	case int8:
		return time.Unix(int64(v), 0), true
	case int:
		return time.Unix(int64(v), 0), true
	case uint64:
		// Values above MaxInt64 would wrap to a negative timestamp.
		if v > 1<<63-1 {
			return tm, false
		}
		return time.Unix(int64(v), 0), true
	case uint32:
		return time.Unix(int64(v), 0), true
	case uint16:
		return time.Unix(int64(v), 0), true
	case uint8:
		return time.Unix(int64(v), 0), true
	case uint:
		if uint64(v) > 1<<63-1 {
			return tm, false
		}
		return time.Unix(int64(v), 0), true
	case json.Number:
		// Prefer the exact integer form; fall back to float parsing for
		// values such as "1700000000.5" or "1.7e9". Each error is checked
		// in its own scope so a successful fallback is not masked by the
		// earlier Int64 failure.
		if i, err := v.Int64(); err == nil {
			return time.Unix(i, 0), true
		}
		f, err := v.Float64()
		if err != nil {
			return tm, false
		}
		return numericDateFromFloat(f)
	}

	return tm, false
}

// numericDateFromFloat converts f, a number of seconds since the Unix epoch,
// to a time by truncating the fractional part. It reports false for NaN and
// for values outside the int64 range, whose integer conversion is not
// well defined.
func numericDateFromFloat(f float64) (time.Time, bool) {
	// 2^63 is exactly representable as a float64; valid values lie in
	// [-2^63, 2^63). Written as a positive test so that NaN is rejected.
	const limit = float64(1 << 63)
	if !(f >= -limit && f < limit) {
		return time.Time{}, false
	}
	return time.Unix(int64(f), 0), true
}

// GetString returns the claim called name converted to a string with
// toString. ok reports whether the claim is present; when it is not, value
// is the empty string.
func (claims Claims) GetString(name string) (value string, ok bool) {
	if raw, found := claims[name]; found {
		return toString(raw), true
	}
	return "", false
}

// GetStringSlice returns the claim called name converted to a slice of
// strings with toStringSlice. ok reports whether the claim is present; when
// it is not, values is nil.
func (claims Claims) GetStringSlice(name string) (values []string, ok bool) {
	if raw, found := claims[name]; found {
		return toStringSlice(raw), true
	}
	return nil, false
}

// toString converts data to a string. Strings are returned unchanged; any
// other value is formatted with fmt.Sprint.
func toString(data any) string {
	if s, isString := data.(string); isString {
		return s
	}
	return fmt.Sprint(data)
}

// toStringSlice converts obj to a slice of strings. When obj is a slice of
// any element type, each element is converted with toString; any other value
// is converted with toString and returned as a one-element slice.
func toStringSlice(obj any) []string {
	rv := reflect.ValueOf(obj)
	if rv.Kind() != reflect.Slice {
		return []string{toString(obj)}
	}

	n := rv.Len()
	out := make([]string, 0, n)
	for i := 0; i < n; i++ {
		out = append(out, toString(rv.Index(i).Interface()))
	}
	return out
}