zero/api: reset token and url cache if 401 is received (#5256)

zero/api: reset token cache if 401 is received
This commit is contained in:
Denis Mishin 2024-09-03 15:40:28 -04:00 committed by GitHub
parent a04d1a450c
commit ce12e51cf5
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
8 changed files with 91 additions and 32 deletions

View file

@ -63,7 +63,7 @@ func NewAPI(ctx context.Context, opts ...Option) (*API, error) {
tokenCache := token_api.NewCache(fetcher, cfg.apiToken) tokenCache := token_api.NewCache(fetcher, cfg.apiToken)
clusterClient, err := cluster_api.NewAuthorizedClient(cfg.clusterAPIEndpoint, tokenCache.GetToken, cfg.httpClient) clusterClient, err := cluster_api.NewAuthorizedClient(cfg.clusterAPIEndpoint, tokenCache, cfg.httpClient)
if err != nil { if err != nil {
return nil, fmt.Errorf("error creating cluster client: %w", err) return nil, fmt.Errorf("error creating cluster client: %w", err)
} }
@ -104,14 +104,14 @@ func (api *API) Watch(ctx context.Context, opts ...WatchOption) error {
// GetClusterBootstrapConfig fetches the bootstrap configuration from the cluster API // GetClusterBootstrapConfig fetches the bootstrap configuration from the cluster API
func (api *API) GetClusterBootstrapConfig(ctx context.Context) (*cluster_api.BootstrapConfig, error) { func (api *API) GetClusterBootstrapConfig(ctx context.Context) (*cluster_api.BootstrapConfig, error) {
return apierror.CheckResponse[cluster_api.BootstrapConfig]( return apierror.CheckResponse(
api.cluster.GetClusterBootstrapConfigWithResponse(ctx), api.cluster.GetClusterBootstrapConfigWithResponse(ctx),
) )
} }
// GetClusterResourceBundles fetches the resource bundles from the cluster API // GetClusterResourceBundles fetches the resource bundles from the cluster API
func (api *API) GetClusterResourceBundles(ctx context.Context) (*cluster_api.GetBundlesResponse, error) { func (api *API) GetClusterResourceBundles(ctx context.Context) (*cluster_api.GetBundlesResponse, error) {
return apierror.CheckResponse[cluster_api.GetBundlesResponse]( return apierror.CheckResponse(
api.cluster.GetClusterResourceBundlesWithResponse(ctx), api.cluster.GetClusterResourceBundlesWithResponse(ctx),
) )
} }

View file

@ -56,6 +56,10 @@ func (api *API) DownloadClusterResourceBundle(
return newContentNotModifiedDownloadResult(resp.Header.Get("Last-Modified") != current.LastModified), nil return newContentNotModifiedDownloadResult(resp.Header.Get("Last-Modified") != current.LastModified), nil
} }
if resp.StatusCode == http.StatusUnauthorized {
api.downloadURLCache.Delete(id)
}
if resp.StatusCode != http.StatusOK { if resp.StatusCode != http.StatusOK {
return nil, httpDownloadError(ctx, resp) return nil, httpDownloadError(ctx, resp)
} }
@ -107,6 +111,10 @@ func (api *API) HeadClusterResourceBundle(
Str("status", resp.Status). Str("status", resp.Status).
Msg("bundle metadata request") Msg("bundle metadata request")
if resp.StatusCode == http.StatusUnauthorized {
api.downloadURLCache.Delete(id)
}
if resp.StatusCode != http.StatusOK { if resp.StatusCode != http.StatusOK {
return nil, httpDownloadError(ctx, resp) return nil, httpDownloadError(ctx, resp)
} }
@ -180,7 +188,7 @@ func (api *API) getDownloadParams(ctx context.Context, id string) (*cluster_api.
func (api *API) updateBundleDownloadParams(ctx context.Context, id string) (*cluster_api.DownloadCacheEntry, error) { func (api *API) updateBundleDownloadParams(ctx context.Context, id string) (*cluster_api.DownloadCacheEntry, error) {
now := time.Now() now := time.Now()
resp, err := apierror.CheckResponse[cluster_api.DownloadBundleResponse]( resp, err := apierror.CheckResponse(
api.cluster.DownloadClusterResourceBundleWithResponse(ctx, id), api.cluster.DownloadClusterResourceBundleWithResponse(ctx, id),
) )
if err != nil { if err != nil {
@ -197,11 +205,13 @@ func (api *API) updateBundleDownloadParams(ctx context.Context, id string) (*clu
return nil, fmt.Errorf("parse url: %w", err) return nil, fmt.Errorf("parse url: %w", err)
} }
expires := now.Add(time.Duration(expiresSeconds) * time.Second)
param := cluster_api.DownloadCacheEntry{ param := cluster_api.DownloadCacheEntry{
URL: *u, URL: *u,
ExpiresAt: now.Add(time.Duration(expiresSeconds) * time.Second), ExpiresAt: expires,
CaptureHeaders: resp.CaptureMetadataHeaders, CaptureHeaders: resp.CaptureMetadataHeaders,
} }
log.Ctx(ctx).Debug().Time("expires", expires).Msg("bundle download URL updated")
api.downloadURLCache.Set(id, param) api.downloadURLCache.Set(id, param)
return &param, nil return &param, nil
} }
@ -323,7 +333,7 @@ func isXML(ct string) bool {
} }
func extractMetadata(header http.Header, keys []string) map[string]string { func extractMetadata(header http.Header, keys []string) map[string]string {
log.Info().Interface("header", header).Msg("extract metadata") log.Debug().Interface("header", header).Msg("extract metadata")
m := make(map[string]string) m := make(map[string]string)
for _, k := range keys { for _, k := range keys {
v := header.Get(k) v := header.Get(k)

View file

@ -3,7 +3,6 @@ package token
import ( import (
"context" "context"
"fmt"
"sync/atomic" "sync/atomic"
"time" "time"
) )
@ -21,7 +20,7 @@ type Cache struct {
fetcher Fetcher fetcher Fetcher
lock chan struct{} lock chan struct{}
token atomic.Value token atomic.Pointer[Token]
} }
// Fetcher is a function that fetches a new token // Fetcher is a function that fetches a new token
@ -42,11 +41,13 @@ func (t *Token) ExpiresAfter(tm time.Time) bool {
// NewCache creates a new token cache // NewCache creates a new token cache
func NewCache(fetcher Fetcher, refreshToken string) *Cache { func NewCache(fetcher Fetcher, refreshToken string) *Cache {
return &Cache{ c := &Cache{
lock: make(chan struct{}, 1), lock: make(chan struct{}, 1),
fetcher: fetcher, fetcher: fetcher,
refreshToken: refreshToken, refreshToken: refreshToken,
} }
c.token.Store(nil)
return c
} }
func (c *Cache) timeNow() time.Time { func (c *Cache) timeNow() time.Time {
@ -56,19 +57,33 @@ func (c *Cache) timeNow() time.Time {
return time.Now() return time.Now()
} }
func (c *Cache) Reset() {
c.token.Store(nil)
}
// GetToken returns the current token if its at least `minTTL` from expiration, or fetches a new one. // GetToken returns the current token if its at least `minTTL` from expiration, or fetches a new one.
func (c *Cache) GetToken(ctx context.Context, minTTL time.Duration) (string, error) { func (c *Cache) GetToken(ctx context.Context, minTTL time.Duration) (string, error) {
minExpiration := c.timeNow().Add(minTTL) minExpiration := c.timeNow().Add(minTTL)
token, ok := c.token.Load().(*Token) token := c.token.Load()
if ok && token.ExpiresAfter(minExpiration) { if token.ExpiresAfter(minExpiration) {
return token.Bearer, nil return token.Bearer, nil
} }
return c.forceRefreshToken(ctx, minExpiration) return c.ForceRefreshToken(ctx)
} }
func (c *Cache) forceRefreshToken(ctx context.Context, minExpiration time.Time) (string, error) { func (c *Cache) ForceRefreshToken(ctx context.Context) (string, error) {
var current string
token := c.token.Load()
if token != nil {
current = token.Bearer
}
return c.forceRefreshToken(ctx, current)
}
func (c *Cache) forceRefreshToken(ctx context.Context, current string) (string, error) {
select { select {
case c.lock <- struct{}{}: case c.lock <- struct{}{}:
case <-ctx.Done(): case <-ctx.Done():
@ -81,8 +96,8 @@ func (c *Cache) forceRefreshToken(ctx context.Context, minExpiration time.Time)
ctx, cancel := context.WithTimeout(ctx, maxLockWait) ctx, cancel := context.WithTimeout(ctx, maxLockWait)
defer cancel() defer cancel()
token, ok := c.token.Load().(*Token) token := c.token.Load()
if ok && token.ExpiresAfter(minExpiration) { if token != nil && token.Bearer != current {
return token.Bearer, nil return token.Bearer, nil
} }
@ -92,9 +107,5 @@ func (c *Cache) forceRefreshToken(ctx context.Context, minExpiration time.Time)
} }
c.token.Store(token) c.token.Store(token)
if token.Expires.Before(minExpiration) {
return "", fmt.Errorf("new token cannot satisfy TTL: %v", minExpiration.Sub(token.Expires))
}
return token.Bearer, nil return token.Bearer, nil
} }

View file

@ -2,6 +2,7 @@ package token_test
import ( import (
"context" "context"
"fmt"
"testing" "testing"
"time" "time"
@ -51,15 +52,28 @@ func TestCache(t *testing.T) {
assert.Equal(t, "bearer-3", bearer) assert.Equal(t, "bearer-3", bearer)
}) })
t.Run("token cannot fit minTTL", func(t *testing.T) { t.Run("reset", func(t *testing.T) {
t.Parallel() t.Parallel()
var calls int
fetcher := func(_ context.Context, _ string) (*token.Token, error) { fetcher := func(_ context.Context, _ string) (*token.Token, error) {
return &token.Token{"ok-bearer", time.Now().Add(time.Minute)}, nil calls++
return &token.Token{fmt.Sprintf("bearer-%d", calls), time.Now().Add(time.Hour)}, nil
} }
ctx := context.Background()
c := token.NewCache(fetcher, "test-refresh-token") c := token.NewCache(fetcher, "test-refresh-token")
_, err := c.GetToken(context.Background(), time.Minute*2) got, err := c.GetToken(ctx, time.Minute*2)
assert.Error(t, err) assert.NoError(t, err)
assert.Equal(t, "bearer-1", got)
got, err = c.GetToken(ctx, time.Minute*2)
assert.NoError(t, err)
assert.Equal(t, "bearer-1", got)
c.Reset()
got, err = c.GetToken(ctx, time.Minute*2)
assert.NoError(t, err)
assert.Equal(t, "bearer-2", got)
}) })
} }

View file

@ -17,18 +17,23 @@ const (
var userAgent = version.UserAgent() var userAgent = version.UserAgent()
type client struct { type client struct {
tokenProvider TokenProviderFn tokenProvider TokenCache
httpClient *http.Client httpClient *http.Client
minTokenTTL time.Duration minTokenTTL time.Duration
} }
// TokenProviderFn is a function that returns a token that is expected to be valid for at least minTTL // TokenCache interface for fetching and caching tokens
type TokenProviderFn func(ctx context.Context, minTTL time.Duration) (string, error) type TokenCache interface {
// GetToken returns a token that is expected to be valid for at least minTTL duration
GetToken(ctx context.Context, minTTL time.Duration) (string, error)
// Reset resets the token cache
Reset()
}
// NewAuthorizedClient creates a new HTTP client that will automatically add an authorization header // NewAuthorizedClient creates a new HTTP client that will automatically add an authorization header
func NewAuthorizedClient( func NewAuthorizedClient(
endpoint string, endpoint string,
tokenProvider TokenProviderFn, tokenProvider TokenCache,
httpClient *http.Client, httpClient *http.Client,
) (ClientWithResponsesInterface, error) { ) (ClientWithResponsesInterface, error) {
c := &client{ c := &client{
@ -43,12 +48,21 @@ func NewAuthorizedClient(
func (c *client) Do(req *http.Request) (*http.Response, error) { func (c *client) Do(req *http.Request) (*http.Response, error) {
ctx := req.Context() ctx := req.Context()
token, err := c.tokenProvider(ctx, c.minTokenTTL) token, err := c.tokenProvider.GetToken(ctx, c.minTokenTTL)
if err != nil { if err != nil {
return nil, fmt.Errorf("error getting token: %w", err) return nil, fmt.Errorf("error getting token: %w", err)
} }
req.Header.Set("Authorization", "Bearer "+token) req.Header.Set("Authorization", "Bearer "+token)
req.Header.Set("User-Agent", userAgent) req.Header.Set("User-Agent", userAgent)
return c.httpClient.Do(req) resp, err := c.httpClient.Do(req)
if err != nil {
return nil, err
}
if resp.StatusCode == http.StatusUnauthorized {
c.tokenProvider.Reset()
}
return resp, nil
} }

View file

@ -46,7 +46,7 @@ func TestAPIClient(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
tokenCache := token.NewCache(fetcher, "refresh-token") tokenCache := token.NewCache(fetcher, "refresh-token")
client, err := api.NewAuthorizedClient(srv.URL, tokenCache.GetToken, http.DefaultClient) client, err := api.NewAuthorizedClient(srv.URL, tokenCache, http.DefaultClient)
require.NoError(t, err) require.NoError(t, err)
resp, err := client.ExchangeClusterIdentityTokenWithResponse(context.Background(), resp, err := client.ExchangeClusterIdentityTokenWithResponse(context.Background(),

View file

@ -6,6 +6,7 @@ import (
"strconv" "strconv"
"time" "time"
"github.com/pomerium/pomerium/internal/log"
"github.com/pomerium/pomerium/internal/zero/apierror" "github.com/pomerium/pomerium/internal/zero/apierror"
"github.com/pomerium/pomerium/internal/zero/token" "github.com/pomerium/pomerium/internal/zero/token"
) )
@ -20,7 +21,7 @@ func NewTokenFetcher(endpoint string, opts ...ClientOption) (token.Fetcher, erro
return func(ctx context.Context, refreshToken string) (*token.Token, error) { return func(ctx context.Context, refreshToken string) (*token.Token, error) {
now := time.Now() now := time.Now()
resp, err := apierror.CheckResponse[ExchangeTokenResponse](client.ExchangeClusterIdentityTokenWithResponse(ctx, ExchangeTokenRequest{ resp, err := apierror.CheckResponse(client.ExchangeClusterIdentityTokenWithResponse(ctx, ExchangeTokenRequest{
RefreshToken: refreshToken, RefreshToken: refreshToken,
})) }))
if err != nil { if err != nil {
@ -32,9 +33,11 @@ func NewTokenFetcher(endpoint string, opts ...ClientOption) (token.Fetcher, erro
return nil, fmt.Errorf("error parsing expires in: %w", err) return nil, fmt.Errorf("error parsing expires in: %w", err)
} }
expires := now.Add(time.Duration(expiresSeconds) * time.Second)
log.Ctx(ctx).Debug().Time("expires", expires).Msg("fetched new Bearer token")
return &token.Token{ return &token.Token{
Bearer: resp.IdToken, Bearer: resp.IdToken,
Expires: now.Add(time.Duration(expiresSeconds) * time.Second), Expires: expires,
}, nil }, nil
}, nil }, nil
} }

View file

@ -29,6 +29,13 @@ func NewURLCache() *URLCache {
} }
} }
func (c *URLCache) Delete(key string) {
c.mx.Lock()
defer c.mx.Unlock()
delete(c.cache, key)
}
// Get gets the cache entry for the given key, if it exists and has not expired. // Get gets the cache entry for the given key, if it exists and has not expired.
func (c *URLCache) Get(key string, minTTL time.Duration) (*DownloadCacheEntry, bool) { func (c *URLCache) Get(key string, minTTL time.Duration) (*DownloadCacheEntry, bool) {
c.mx.RLock() c.mx.RLock()