package cliutil

import (
	"encoding/json"
	"errors"
	"fmt"
	"io/ioutil"
	"os"
	"path/filepath"
	"sync"
	"time"

	"github.com/martinlindhe/base36"
	"gopkg.in/square/go-jose.v2"

	"github.com/pomerium/pomerium/pkg/cryptutil"
)

// Sentinel errors returned by the JWT caches in this package. Compare against
// them with errors.Is.

// ErrExpired reports that a cached JWT's "exp" claim lies in the past.
var ErrExpired = errors.New("expired")

// ErrInvalid reports that a cached JWT could not be parsed or its claims
// could not be decoded.
var ErrInvalid = errors.New("invalid")

// ErrNotFound reports that nothing is cached under the requested key.
var ErrNotFound = errors.New("not found")

// A JWTCache loads and stores JWTs.
//
// Entries are addressed by a caller-chosen string key.
type JWTCache interface {
	// StoreJWT saves rawJWT under key.
	StoreJWT(key, rawJWT string) error
	// LoadJWT returns the raw JWT saved under key. The implementations in
	// this package return ErrNotFound when nothing is saved, and may return
	// the token together with ErrExpired or ErrInvalid.
	LoadJWT(key string) (rawJWT string, err error)
	// DeleteJWT removes whatever is saved under key.
	DeleteJWT(key string) error
}

// A LocalJWTCache stores files in the user's cache directory.
//
// Every cached JWT lives in its own file directly under dir. The file name is
// derived from a hash of the cache key, not from the key itself.
type LocalJWTCache struct {
	// dir is the directory that holds the cached JWT files.
	dir string
}

// NewLocalJWTCache creates a new LocalJWTCache rooted at
// <user cache dir>/pomerium-cli/jwts, creating the directory if necessary.
//
// Missing directories are created with mode 0700 (before umask). They hold
// cached JWTs, so other local users have no reason to list them. os.MkdirAll
// leaves the permissions of directories that already exist untouched.
//
// It returns an error if the user cache directory cannot be determined or
// the jwts directory cannot be created.
func NewLocalJWTCache() (*LocalJWTCache, error) {
	root, err := os.UserCacheDir()
	if err != nil {
		return nil, fmt.Errorf("error determining user cache directory: %w", err)
	}

	dir := filepath.Join(root, "pomerium-cli", "jwts")

	if err := os.MkdirAll(dir, 0o700); err != nil {
		return nil, fmt.Errorf("error creating user cache directory: %w", err)
	}

	return &LocalJWTCache{
		dir: dir,
	}, nil
}

// DeleteJWT deletes a raw JWT from the local cache.
//
// Removing a key that has no cached file succeeds silently. Any other
// filesystem error from os.Remove is returned as-is.
func (cache *LocalJWTCache) DeleteJWT(key string) error {
	target := filepath.Join(cache.dir, cache.fileName(key))
	if removeErr := os.Remove(target); removeErr != nil && !os.IsNotExist(removeErr) {
		return removeErr
	}
	return nil
}

// LoadJWT loads a raw JWT from the local cache.
//
// A missing file yields ErrNotFound, and other read failures are returned
// unchanged. When the file is read successfully, its contents are returned
// together with the result of checkExpiry. The token therefore comes back even
// when the error is ErrExpired or ErrInvalid.
func (cache *LocalJWTCache) LoadJWT(key string) (rawJWT string, err error) {
	contents, readErr := ioutil.ReadFile(filepath.Join(cache.dir, cache.fileName(key)))
	switch {
	case os.IsNotExist(readErr):
		return "", ErrNotFound
	case readErr != nil:
		return "", readErr
	}

	rawJWT = string(contents)
	return rawJWT, checkExpiry(rawJWT)
}

// StoreJWT stores a raw JWT in the local cache.
//
// The token is first written to a temporary file in the cache directory and
// then renamed over the final path. On POSIX systems the rename is atomic, so
// a concurrent LoadJWT (for example from another pomerium-cli process) sees
// either the previous token or the new one, never a partially written file.
// ioutil.TempFile creates the file with mode 0600 (before umask), matching the
// permissions the token file was previously written with.
func (cache *LocalJWTCache) StoreJWT(key string, rawJWT string) error {
	name := cache.fileName(key)
	path := filepath.Join(cache.dir, name)

	tmp, err := ioutil.TempFile(cache.dir, name+".*.tmp")
	if err != nil {
		return err
	}
	tmpPath := tmp.Name()

	_, err = tmp.WriteString(rawJWT)
	// Always close the temp file, but only surface the close error when the
	// write itself succeeded.
	if closeErr := tmp.Close(); err == nil {
		err = closeErr
	}
	if err == nil {
		err = os.Rename(tmpPath, path)
	}
	if err != nil {
		// Best-effort cleanup so failed stores do not leave stray temp files;
		// the write/rename error is the one worth returning.
		_ = os.Remove(tmpPath)
		return err
	}

	return nil
}

// hash turns an arbitrary cache key into a base36 string suitable for use in
// a file name. It digests the key with cryptutil.Hash, passing
// "LocalJWTCache" as the first argument (presumably a domain-separation tag;
// see cryptutil.Hash), then base36-encodes the digest.
func (cache *LocalJWTCache) hash(str string) string {
	return base36.EncodeBytes(cryptutil.Hash("LocalJWTCache", []byte(str)))
}

// fileName returns the base name (not the full path) of the file that holds
// the JWT cached under key. It is the hashed key plus a ".jwt" extension.
func (cache *LocalJWTCache) fileName(key string) string {
	const ext = ".jwt"
	base := cache.hash(key)
	return base + ext
}

// A MemoryJWTCache stores JWTs in an in-memory map.
//
// Construct one with NewMemoryJWTCache. Because the struct embeds a mutex,
// it must not be copied after first use.
type MemoryJWTCache struct {
	// mu serializes every access to entries.
	mu sync.Mutex
	// entries maps cache keys to raw JWT strings.
	entries map[string]string
}

// NewMemoryJWTCache creates a new in-memory JWT cache that starts out empty.
func NewMemoryJWTCache() *MemoryJWTCache {
	cache := new(MemoryJWTCache)
	cache.entries = map[string]string{}
	return cache
}

// DeleteJWT deletes a JWT from the in-memory map.
//
// It always returns nil. Deleting an absent key is a no-op.
func (cache *MemoryJWTCache) DeleteJWT(key string) error {
	// The built-in delete cannot panic, even on a nil map, so an explicit
	// unlock is safe here without defer.
	cache.mu.Lock()
	delete(cache.entries, key)
	cache.mu.Unlock()

	return nil
}

// LoadJWT loads a JWT from the in-memory map.
//
// An absent key yields ErrNotFound. A present entry is returned together with
// the result of checkExpiry, so the token comes back even alongside
// ErrExpired or ErrInvalid.
func (cache *MemoryJWTCache) LoadJWT(key string) (rawJWT string, err error) {
	cache.mu.Lock()
	defer cache.mu.Unlock()

	stored, found := cache.entries[key]
	if found {
		return stored, checkExpiry(stored)
	}
	return "", ErrNotFound
}

// StoreJWT stores a JWT in the in-memory map, replacing any previous entry
// for key. It always returns nil.
//
// The map is allocated on first use, so a zero-value MemoryJWTCache (not
// created via NewMemoryJWTCache) no longer panics with a nil-map write.
func (cache *MemoryJWTCache) StoreJWT(key string, rawJWT string) error {
	cache.mu.Lock()
	defer cache.mu.Unlock()

	if cache.entries == nil {
		cache.entries = make(map[string]string)
	}
	cache.entries[key] = rawJWT

	return nil
}

// checkExpiry decides whether a cached rawJWT may still be used, based only
// on its "exp" claim.
//
// The signature is NOT verified. The payload is inspected purely to decide
// whether to reuse a cached token.
//
// Return values:
//   - ErrInvalid: rawJWT is not a parseable JWS, or its payload cannot be
//     decoded into a claims object with a numeric "exp".
//   - ErrExpired: "exp" is earlier than the current time. A missing or null
//     "exp" decodes as 0 and is therefore treated as expired.
//   - nil: otherwise.
func checkExpiry(rawJWT string) error {
	tok, err := jose.ParseSigned(rawJWT)
	if err != nil {
		return ErrInvalid
	}

	// RFC 7519 NumericDate values may be non-integer (e.g. 1700000000.5), so
	// decode "exp" as float64. Decoding into int64 would reject such tokens
	// as invalid.
	var claims struct {
		Expiry float64 `json:"exp"`
	}
	if err := json.Unmarshal(tok.UnsafePayloadWithoutVerification(), &claims); err != nil {
		return ErrInvalid
	}

	// Compare in (fractional) seconds since the Unix epoch. This also avoids
	// converting an arbitrarily large "exp" into an int64 or time.Time.
	now := float64(time.Now().UnixNano()) / float64(time.Second)
	if claims.Expiry < now {
		return ErrExpired
	}

	return nil
}