// Package token provides a thread-safe cache of an authorization token that may be used across HTTP and gRPC clients.
package token

import (
	"context"
	"errors"
	"sync/atomic"
	"time"
)

// maxLockWait is the timeout applied to a single token fetch, which runs
// while the refresh lock is held.
const maxLockWait = 30 * time.Second

// Cache is a thread-safe cache of an authorization token
// that may be used across HTTP and gRPC clients.
//
// Create instances with NewCache so that the internal lock channel is
// initialized.
type Cache struct {
	// TimeNow, when non-nil, replaces time.Now as the clock used to decide
	// whether the cached token is still fresh enough.
	TimeNow func() time.Time

	// refreshToken is the credential passed to fetcher on every refresh;
	// fetcher is the function that obtains a new token.
	refreshToken string
	fetcher      Fetcher

	// lock is a one-slot channel used as a mutex around refreshes, so that
	// waiting for it can be abandoned when a context ends. token holds the
	// most recently stored token, or nil when none is cached.
	lock  chan struct{}
	token atomic.Pointer[Token]
}

// Fetcher is a function that fetches a new token using the given refresh
// token. The cache calls it with a context that carries a timeout, and
// propagates any error it returns to the caller unchanged.
type Fetcher func(ctx context.Context, refreshToken string) (*Token, error)

// Token is a bearer token together with its expiration time.
type Token struct {
	// Bearer is the bearer token value. The cache also uses it to tell one
	// token from another when deciding whether a refresh already happened.
	Bearer string
	// Expires is the time at which the token expires.
	Expires time.Time
}

// ExpiresAfter reports whether the token's expiration lies strictly after tm.
// It is safe to call on a nil *Token, for which it always reports false.
func (t *Token) ExpiresAfter(tm time.Time) bool {
	if t == nil {
		return false
	}
	return tm.Before(t.Expires)
}

// NewCache creates a new, empty token cache that obtains tokens by calling
// fetcher with refreshToken.
func NewCache(fetcher Fetcher, refreshToken string) *Cache {
	cache := new(Cache)
	cache.refreshToken = refreshToken
	cache.fetcher = fetcher
	// A single slot lets the channel serve as the refresh lock: a send
	// acquires it and a receive releases it.
	cache.lock = make(chan struct{}, 1)
	// Start out with no cached token.
	cache.token.Store(nil)
	return cache
}

// timeNow returns the current time from the TimeNow override when one is
// set, and from time.Now otherwise.
func (c *Cache) timeNow() time.Time {
	clock := c.TimeNow
	if clock == nil {
		clock = time.Now
	}
	return clock()
}

// Reset discards the cached token, so the next GetToken call has to fetch a
// new one.
func (c *Cache) Reset() {
	var none *Token
	c.token.Store(none)
}

// GetToken returns the cached bearer token if it is at least minTTL away
// from expiration, and otherwise refreshes the token and returns the result.
//
// It returns an error if ctx ends while waiting for a concurrent refresh, or
// if the refresh itself fails.
func (c *Cache) GetToken(ctx context.Context, minTTL time.Duration) (string, error) {
	minExpiration := c.timeNow().Add(minTTL)

	token := c.token.Load()
	if token.ExpiresAfter(minExpiration) {
		return token.Bearer, nil
	}

	// Refresh relative to the token that was just judged stale. Re-reading
	// the cache here could pick up a token another goroutine stored in the
	// meantime, and using that as the baseline would trigger a redundant
	// fetch that replaces a token which was only just renewed.
	var current string
	if token != nil {
		current = token.Bearer
	}
	return c.forceRefreshToken(ctx, current)
}

// ForceRefreshToken refreshes the token regardless of its expiration time and
// returns the new bearer value. If another caller replaces the currently
// cached token while this call waits for the refresh lock, that caller's
// token is returned instead of fetching again.
func (c *Cache) ForceRefreshToken(ctx context.Context) (string, error) {
	if cached := c.token.Load(); cached != nil {
		return c.forceRefreshToken(ctx, cached.Bearer)
	}
	// Nothing cached: refresh with an empty baseline.
	return c.forceRefreshToken(ctx, "")
}

// forceRefreshToken fetches and stores a new token while holding the refresh
// lock, unless the cached token has already been replaced.
//
// current is the bearer value the caller observed before deciding to refresh
// (empty if it observed no token). If, once the lock is held, the cache holds
// a token with a different bearer, another caller has refreshed in the
// meantime and that token is returned without fetching.
//
// It returns ctx.Err() if ctx ends before the lock is acquired, the fetcher's
// error if the fetch fails, and an error if the fetcher returns a nil token
// without an error. On any error the cached token is left untouched.
func (c *Cache) forceRefreshToken(ctx context.Context, current string) (string, error) {
	// The one-slot channel serves as a mutex whose acquisition can be
	// abandoned when the caller's context ends.
	select {
	case c.lock <- struct{}{}:
	case <-ctx.Done():
		return "", ctx.Err()
	}
	defer func() {
		<-c.lock
	}()

	// Another caller refreshed while this one was waiting for the lock.
	if token := c.token.Load(); token != nil && token.Bearer != current {
		return token.Bearer, nil
	}

	// Bound the fetch so the lock cannot be held indefinitely.
	ctx, cancel := context.WithTimeout(ctx, maxLockWait)
	defer cancel()

	token, err := c.fetcher(ctx, c.refreshToken)
	if err != nil {
		return "", err
	}
	if token == nil {
		// Storing nil would wipe the cache, and reading token.Bearer would
		// panic; report the broken fetcher instead.
		return "", errors.New("token: fetcher returned a nil token without an error")
	}
	c.token.Store(token)

	return token.Bearer, nil
}