pomerium/internal/zero/token/cache.go
Denis Mishin ce12e51cf5
zero/api: reset token and url cache if 401 is received (#5256)
zero/api: reset token cache if 401 is received
2024-09-03 15:40:28 -04:00

111 lines
2.3 KiB
Go

// Package token provides a thread-safe cache of an authorization token that may be used across http and grpc clients
package token
import (
	"context"
	"errors"
	"sync/atomic"
	"time"
)
// maxLockWait is the timeout applied to the context that is handed to the
// token fetcher once the refresh lock has been acquired.
const maxLockWait = 30 * time.Second
// Cache is a thread-safe cache of an authorization token
// that may be used across http and grpc clients.
type Cache struct {
// TimeNow optionally overrides the clock; when it is nil, time.Now is used.
TimeNow func() time.Time
// refreshToken is the value passed to the fetcher on every fetch.
refreshToken string
// fetcher is called to obtain a new token.
fetcher Fetcher
// lock is a channel used as a mutex around token refreshes: a send
// acquires it and a receive releases it. NewCache gives it a capacity of one.
lock chan struct{}
// token holds the currently cached token, or nil when nothing is cached.
token atomic.Pointer[Token]
}
// Fetcher is a function that fetches a new token using the given refresh token.
// A nil error is expected to come with a non-nil token: on success the cache
// reads the Bearer field of the returned token.
type Fetcher func(ctx context.Context, refreshToken string) (*Token, error)
// Token is a bearer token together with its expiration time.
type Token struct {
	// Bearer is the bearer token value.
	Bearer string
	// Expires is the instant at which the token expires.
	Expires time.Time
}

// ExpiresAfter reports whether the token expires strictly later than tm.
// It may be called on a nil *Token, in which case it returns false.
func (t *Token) ExpiresAfter(tm time.Time) bool {
	if t == nil {
		return false
	}
	return tm.Before(t.Expires)
}
// NewCache builds a token cache that obtains tokens from fetcher, handing it
// refreshToken on every fetch. The cache starts out without a token.
func NewCache(fetcher Fetcher, refreshToken string) *Cache {
	cache := new(Cache)
	cache.refreshToken = refreshToken
	cache.fetcher = fetcher
	// A single-slot channel serves as the refresh mutex.
	cache.lock = make(chan struct{}, 1)
	cache.token.Store(nil)
	return cache
}
// timeNow returns the current time, using the TimeNow override when one is
// set and time.Now otherwise.
func (c *Cache) timeNow() time.Time {
	clock := c.TimeNow
	if clock == nil {
		clock = time.Now
	}
	return clock()
}
// Reset discards the cached token by storing nil in its place, so later
// lookups no longer find a cached token.
func (c *Cache) Reset() {
	var none *Token
	c.token.Store(none)
}
// GetToken returns the cached bearer token when it expires later than
// minTTL from now; otherwise it returns the result of ForceRefreshToken.
func (c *Cache) GetToken(ctx context.Context, minTTL time.Duration) (string, error) {
	threshold := c.timeNow().Add(minTTL)
	if cached := c.token.Load(); cached.ExpiresAfter(threshold) {
		return cached.Bearer, nil
	}
	// Nothing cached, or the cached token expires too soon.
	return c.ForceRefreshToken(ctx)
}
// ForceRefreshToken requests a token refresh without looking at the cached
// token's expiration. It passes the currently cached bearer (or "" when
// nothing is cached) to forceRefreshToken and returns that call's result.
func (c *Cache) ForceRefreshToken(ctx context.Context) (string, error) {
	known := ""
	if cached := c.token.Load(); cached != nil {
		known = cached.Bearer
	}
	return c.forceRefreshToken(ctx, known)
}
// forceRefreshToken fetches a new token unless the cached one has already
// been replaced since the caller observed current.
//
// current is the bearer value the caller saw before deciding to refresh ("" if
// it saw none). Refreshes are serialized through c.lock; if, once the lock is
// held, the cache contains a token whose bearer differs from current, that
// bearer is returned and no fetch is made.
//
// It returns ctx.Err() if ctx is done while waiting for the lock, the
// fetcher's error if the fetch fails, and an error if the fetcher reports
// success without returning a token. In every error case the cached token is
// left unchanged.
func (c *Cache) forceRefreshToken(ctx context.Context, current string) (string, error) {
	// Sending into the channel acquires the lock; unlike a sync.Mutex the
	// wait can be abandoned when ctx is done.
	select {
	case c.lock <- struct{}{}:
	case <-ctx.Done():
		return "", ctx.Err()
	}
	defer func() { <-c.lock }()

	// Another caller may have refreshed the token while we were waiting.
	if cached := c.token.Load(); cached != nil && cached.Bearer != current {
		return cached.Bearer, nil
	}

	// Bound the fetch itself; the wait for the lock above is limited only by ctx.
	fetchCtx, cancel := context.WithTimeout(ctx, maxLockWait)
	defer cancel()

	fetched, err := c.fetcher(fetchCtx, c.refreshToken)
	if err != nil {
		return "", err
	}
	if fetched == nil {
		// Guard against a fetcher returning (nil, nil): storing nil would wipe
		// the cache and dereferencing it would panic.
		return "", errors.New("token fetcher returned no token")
	}

	c.token.Store(fetched)
	return fetched.Bearer, nil
}