mirror of
https://github.com/pomerium/pomerium.git
synced 2025-05-13 00:58:06 +02:00
zero/api: reset token and url cache if 401 is received (#5256)
zero/api: reset token cache if 401 is received
This commit is contained in:
parent
a04d1a450c
commit
ce12e51cf5
8 changed files with 91 additions and 32 deletions
|
@ -3,7 +3,6 @@ package token
|
|||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
@ -21,7 +20,7 @@ type Cache struct {
|
|||
fetcher Fetcher
|
||||
|
||||
lock chan struct{}
|
||||
token atomic.Value
|
||||
token atomic.Pointer[Token]
|
||||
}
|
||||
|
||||
// Fetcher is a function that fetches a new token
|
||||
|
@ -42,11 +41,13 @@ func (t *Token) ExpiresAfter(tm time.Time) bool {
|
|||
|
||||
// NewCache creates a new token cache
|
||||
func NewCache(fetcher Fetcher, refreshToken string) *Cache {
|
||||
return &Cache{
|
||||
c := &Cache{
|
||||
lock: make(chan struct{}, 1),
|
||||
fetcher: fetcher,
|
||||
refreshToken: refreshToken,
|
||||
}
|
||||
c.token.Store(nil)
|
||||
return c
|
||||
}
|
||||
|
||||
func (c *Cache) timeNow() time.Time {
|
||||
|
@ -56,19 +57,33 @@ func (c *Cache) timeNow() time.Time {
|
|||
return time.Now()
|
||||
}
|
||||
|
||||
func (c *Cache) Reset() {
|
||||
c.token.Store(nil)
|
||||
}
|
||||
|
||||
// GetToken returns the current token if its at least `minTTL` from expiration, or fetches a new one.
|
||||
func (c *Cache) GetToken(ctx context.Context, minTTL time.Duration) (string, error) {
|
||||
minExpiration := c.timeNow().Add(minTTL)
|
||||
|
||||
token, ok := c.token.Load().(*Token)
|
||||
if ok && token.ExpiresAfter(minExpiration) {
|
||||
token := c.token.Load()
|
||||
if token.ExpiresAfter(minExpiration) {
|
||||
return token.Bearer, nil
|
||||
}
|
||||
|
||||
return c.forceRefreshToken(ctx, minExpiration)
|
||||
return c.ForceRefreshToken(ctx)
|
||||
}
|
||||
|
||||
func (c *Cache) forceRefreshToken(ctx context.Context, minExpiration time.Time) (string, error) {
|
||||
func (c *Cache) ForceRefreshToken(ctx context.Context) (string, error) {
|
||||
var current string
|
||||
token := c.token.Load()
|
||||
if token != nil {
|
||||
current = token.Bearer
|
||||
}
|
||||
|
||||
return c.forceRefreshToken(ctx, current)
|
||||
}
|
||||
|
||||
func (c *Cache) forceRefreshToken(ctx context.Context, current string) (string, error) {
|
||||
select {
|
||||
case c.lock <- struct{}{}:
|
||||
case <-ctx.Done():
|
||||
|
@ -81,8 +96,8 @@ func (c *Cache) forceRefreshToken(ctx context.Context, minExpiration time.Time)
|
|||
ctx, cancel := context.WithTimeout(ctx, maxLockWait)
|
||||
defer cancel()
|
||||
|
||||
token, ok := c.token.Load().(*Token)
|
||||
if ok && token.ExpiresAfter(minExpiration) {
|
||||
token := c.token.Load()
|
||||
if token != nil && token.Bearer != current {
|
||||
return token.Bearer, nil
|
||||
}
|
||||
|
||||
|
@ -92,9 +107,5 @@ func (c *Cache) forceRefreshToken(ctx context.Context, minExpiration time.Time)
|
|||
}
|
||||
c.token.Store(token)
|
||||
|
||||
if token.Expires.Before(minExpiration) {
|
||||
return "", fmt.Errorf("new token cannot satisfy TTL: %v", minExpiration.Sub(token.Expires))
|
||||
}
|
||||
|
||||
return token.Bearer, nil
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue