mirror of
https://github.com/pomerium/pomerium.git
synced 2025-04-30 19:06:33 +02:00
111 lines
2.3 KiB
Go
111 lines
2.3 KiB
Go
// Package token provides a thread-safe cache of an authorization token that may be used across HTTP and gRPC clients.
|
|
package token
|
|
|
|
import (
	"context"
	"errors"
	"sync/atomic"
	"time"
)
|
|
|
|
// maxLockWait is the timeout applied to the context handed to the fetcher.
// The timeout starts once the refresh lock has been acquired, so it bounds
// how long a single refresh attempt can keep that lock.
const maxLockWait = 30 * time.Second
|
|
|
|
// Cache is a thread-safe cache of a authorization token
|
|
// that may be used across http and grpc clients
|
|
type Cache struct {
|
|
TimeNow func() time.Time
|
|
|
|
refreshToken string
|
|
fetcher Fetcher
|
|
|
|
lock chan struct{}
|
|
token atomic.Pointer[Token]
|
|
}
|
|
|
|
// Fetcher is a function that fetches a new token
|
|
type Fetcher func(ctx context.Context, refreshToken string) (*Token, error)
|
|
|
|
// Token pairs a bearer credential with its expiration time.
type Token struct {
	// Bearer is the bearer token value.
	Bearer string
	// Expires is the instant at which the token expires.
	Expires time.Time
}
|
|
|
|
// ExpiresAfter returns true if the token expires after the given time
|
|
func (t *Token) ExpiresAfter(tm time.Time) bool {
|
|
return t != nil && t.Expires.After(tm)
|
|
}
|
|
|
|
// NewCache creates a new token cache
|
|
func NewCache(fetcher Fetcher, refreshToken string) *Cache {
|
|
c := &Cache{
|
|
lock: make(chan struct{}, 1),
|
|
fetcher: fetcher,
|
|
refreshToken: refreshToken,
|
|
}
|
|
c.token.Store(nil)
|
|
return c
|
|
}
|
|
|
|
func (c *Cache) timeNow() time.Time {
|
|
if c.TimeNow != nil {
|
|
return c.TimeNow()
|
|
}
|
|
return time.Now()
|
|
}
|
|
|
|
func (c *Cache) Reset() {
|
|
c.token.Store(nil)
|
|
}
|
|
|
|
// GetToken returns the current token if its at least `minTTL` from expiration, or fetches a new one.
|
|
func (c *Cache) GetToken(ctx context.Context, minTTL time.Duration) (string, error) {
|
|
minExpiration := c.timeNow().Add(minTTL)
|
|
|
|
token := c.token.Load()
|
|
if token.ExpiresAfter(minExpiration) {
|
|
return token.Bearer, nil
|
|
}
|
|
|
|
return c.ForceRefreshToken(ctx)
|
|
}
|
|
|
|
func (c *Cache) ForceRefreshToken(ctx context.Context) (string, error) {
|
|
var current string
|
|
token := c.token.Load()
|
|
if token != nil {
|
|
current = token.Bearer
|
|
}
|
|
|
|
return c.forceRefreshToken(ctx, current)
|
|
}
|
|
|
|
func (c *Cache) forceRefreshToken(ctx context.Context, current string) (string, error) {
|
|
select {
|
|
case c.lock <- struct{}{}:
|
|
case <-ctx.Done():
|
|
return "", ctx.Err()
|
|
}
|
|
defer func() {
|
|
<-c.lock
|
|
}()
|
|
|
|
ctx, cancel := context.WithTimeout(ctx, maxLockWait)
|
|
defer cancel()
|
|
|
|
token := c.token.Load()
|
|
if token != nil && token.Bearer != current {
|
|
return token.Bearer, nil
|
|
}
|
|
|
|
token, err := c.fetcher(ctx, c.refreshToken)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
c.token.Store(token)
|
|
|
|
return token.Bearer, nil
|
|
}
|