mirror of
https://github.com/pomerium/pomerium.git
synced 2025-05-21 21:17:13 +02:00
zero/api: reset token and url cache if 401 is received (#5256)
zero/api: reset token cache if 401 is received
This commit is contained in:
parent
a04d1a450c
commit
ce12e51cf5
8 changed files with 91 additions and 32 deletions
|
@ -3,7 +3,6 @@ package token
|
|||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
@ -21,7 +20,7 @@ type Cache struct {
|
|||
fetcher Fetcher
|
||||
|
||||
lock chan struct{}
|
||||
token atomic.Value
|
||||
token atomic.Pointer[Token]
|
||||
}
|
||||
|
||||
// Fetcher is a function that fetches a new token
|
||||
|
@ -42,11 +41,13 @@ func (t *Token) ExpiresAfter(tm time.Time) bool {
|
|||
|
||||
// NewCache creates a new token cache
|
||||
func NewCache(fetcher Fetcher, refreshToken string) *Cache {
|
||||
return &Cache{
|
||||
c := &Cache{
|
||||
lock: make(chan struct{}, 1),
|
||||
fetcher: fetcher,
|
||||
refreshToken: refreshToken,
|
||||
}
|
||||
c.token.Store(nil)
|
||||
return c
|
||||
}
|
||||
|
||||
func (c *Cache) timeNow() time.Time {
|
||||
|
@ -56,19 +57,33 @@ func (c *Cache) timeNow() time.Time {
|
|||
return time.Now()
|
||||
}
|
||||
|
||||
func (c *Cache) Reset() {
|
||||
c.token.Store(nil)
|
||||
}
|
||||
|
||||
// GetToken returns the current token if its at least `minTTL` from expiration, or fetches a new one.
|
||||
func (c *Cache) GetToken(ctx context.Context, minTTL time.Duration) (string, error) {
|
||||
minExpiration := c.timeNow().Add(minTTL)
|
||||
|
||||
token, ok := c.token.Load().(*Token)
|
||||
if ok && token.ExpiresAfter(minExpiration) {
|
||||
token := c.token.Load()
|
||||
if token.ExpiresAfter(minExpiration) {
|
||||
return token.Bearer, nil
|
||||
}
|
||||
|
||||
return c.forceRefreshToken(ctx, minExpiration)
|
||||
return c.ForceRefreshToken(ctx)
|
||||
}
|
||||
|
||||
func (c *Cache) forceRefreshToken(ctx context.Context, minExpiration time.Time) (string, error) {
|
||||
func (c *Cache) ForceRefreshToken(ctx context.Context) (string, error) {
|
||||
var current string
|
||||
token := c.token.Load()
|
||||
if token != nil {
|
||||
current = token.Bearer
|
||||
}
|
||||
|
||||
return c.forceRefreshToken(ctx, current)
|
||||
}
|
||||
|
||||
func (c *Cache) forceRefreshToken(ctx context.Context, current string) (string, error) {
|
||||
select {
|
||||
case c.lock <- struct{}{}:
|
||||
case <-ctx.Done():
|
||||
|
@ -81,8 +96,8 @@ func (c *Cache) forceRefreshToken(ctx context.Context, minExpiration time.Time)
|
|||
ctx, cancel := context.WithTimeout(ctx, maxLockWait)
|
||||
defer cancel()
|
||||
|
||||
token, ok := c.token.Load().(*Token)
|
||||
if ok && token.ExpiresAfter(minExpiration) {
|
||||
token := c.token.Load()
|
||||
if token != nil && token.Bearer != current {
|
||||
return token.Bearer, nil
|
||||
}
|
||||
|
||||
|
@ -92,9 +107,5 @@ func (c *Cache) forceRefreshToken(ctx context.Context, minExpiration time.Time)
|
|||
}
|
||||
c.token.Store(token)
|
||||
|
||||
if token.Expires.Before(minExpiration) {
|
||||
return "", fmt.Errorf("new token cannot satisfy TTL: %v", minExpiration.Sub(token.Expires))
|
||||
}
|
||||
|
||||
return token.Bearer, nil
|
||||
}
|
||||
|
|
|
@ -2,6 +2,7 @@ package token_test
|
|||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
|
@ -51,15 +52,28 @@ func TestCache(t *testing.T) {
|
|||
assert.Equal(t, "bearer-3", bearer)
|
||||
})
|
||||
|
||||
t.Run("token cannot fit minTTL", func(t *testing.T) {
|
||||
t.Run("reset", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var calls int
|
||||
fetcher := func(_ context.Context, _ string) (*token.Token, error) {
|
||||
return &token.Token{"ok-bearer", time.Now().Add(time.Minute)}, nil
|
||||
calls++
|
||||
return &token.Token{fmt.Sprintf("bearer-%d", calls), time.Now().Add(time.Hour)}, nil
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
c := token.NewCache(fetcher, "test-refresh-token")
|
||||
_, err := c.GetToken(context.Background(), time.Minute*2)
|
||||
assert.Error(t, err)
|
||||
got, err := c.GetToken(ctx, time.Minute*2)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "bearer-1", got)
|
||||
|
||||
got, err = c.GetToken(ctx, time.Minute*2)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "bearer-1", got)
|
||||
|
||||
c.Reset()
|
||||
got, err = c.GetToken(ctx, time.Minute*2)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "bearer-2", got)
|
||||
})
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue