diff --git a/internal/zero/api/api.go b/internal/zero/api/api.go index 9f62e0936..b83d9b1a7 100644 --- a/internal/zero/api/api.go +++ b/internal/zero/api/api.go @@ -63,7 +63,7 @@ func NewAPI(ctx context.Context, opts ...Option) (*API, error) { tokenCache := token_api.NewCache(fetcher, cfg.apiToken) - clusterClient, err := cluster_api.NewAuthorizedClient(cfg.clusterAPIEndpoint, tokenCache.GetToken, cfg.httpClient) + clusterClient, err := cluster_api.NewAuthorizedClient(cfg.clusterAPIEndpoint, tokenCache, cfg.httpClient) if err != nil { return nil, fmt.Errorf("error creating cluster client: %w", err) } @@ -104,14 +104,14 @@ func (api *API) Watch(ctx context.Context, opts ...WatchOption) error { // GetClusterBootstrapConfig fetches the bootstrap configuration from the cluster API func (api *API) GetClusterBootstrapConfig(ctx context.Context) (*cluster_api.BootstrapConfig, error) { - return apierror.CheckResponse[cluster_api.BootstrapConfig]( + return apierror.CheckResponse( api.cluster.GetClusterBootstrapConfigWithResponse(ctx), ) } // GetClusterResourceBundles fetches the resource bundles from the cluster API func (api *API) GetClusterResourceBundles(ctx context.Context) (*cluster_api.GetBundlesResponse, error) { - return apierror.CheckResponse[cluster_api.GetBundlesResponse]( + return apierror.CheckResponse( api.cluster.GetClusterResourceBundlesWithResponse(ctx), ) } diff --git a/internal/zero/api/download.go b/internal/zero/api/download.go index f536badbf..0a7463f5c 100644 --- a/internal/zero/api/download.go +++ b/internal/zero/api/download.go @@ -56,6 +56,10 @@ func (api *API) DownloadClusterResourceBundle( return newContentNotModifiedDownloadResult(resp.Header.Get("Last-Modified") != current.LastModified), nil } + if resp.StatusCode == http.StatusUnauthorized { + api.downloadURLCache.Delete(id) + } + if resp.StatusCode != http.StatusOK { return nil, httpDownloadError(ctx, resp) } @@ -107,6 +111,10 @@ func (api *API) HeadClusterResourceBundle( Str("status", resp.Status). Msg("bundle metadata request") + if resp.StatusCode == http.StatusUnauthorized { + api.downloadURLCache.Delete(id) + } + if resp.StatusCode != http.StatusOK { return nil, httpDownloadError(ctx, resp) } @@ -180,7 +188,7 @@ func (api *API) getDownloadParams(ctx context.Context, id string) (*cluster_api. func (api *API) updateBundleDownloadParams(ctx context.Context, id string) (*cluster_api.DownloadCacheEntry, error) { now := time.Now() - resp, err := apierror.CheckResponse[cluster_api.DownloadBundleResponse]( + resp, err := apierror.CheckResponse( api.cluster.DownloadClusterResourceBundleWithResponse(ctx, id), ) if err != nil { @@ -197,11 +205,13 @@ func (api *API) updateBundleDownloadParams(ctx context.Context, id string) (*clu return nil, fmt.Errorf("parse url: %w", err) } + expires := now.Add(time.Duration(expiresSeconds) * time.Second) param := cluster_api.DownloadCacheEntry{ URL: *u, - ExpiresAt: now.Add(time.Duration(expiresSeconds) * time.Second), + ExpiresAt: expires, CaptureHeaders: resp.CaptureMetadataHeaders, } + log.Ctx(ctx).Debug().Time("expires", expires).Msg("bundle download URL updated") api.downloadURLCache.Set(id, param) return ¶m, nil } @@ -323,7 +333,7 @@ func isXML(ct string) bool { } func extractMetadata(header http.Header, keys []string) map[string]string { - log.Info().Interface("header", header).Msg("extract metadata") + log.Debug().Interface("header", header).Msg("extract metadata") m := make(map[string]string) for _, k := range keys { v := header.Get(k) diff --git a/internal/zero/token/cache.go b/internal/zero/token/cache.go index abd0a1a06..bcd119ea5 100644 --- a/internal/zero/token/cache.go +++ b/internal/zero/token/cache.go @@ -3,7 +3,6 @@ package token import ( "context" - "fmt" "sync/atomic" "time" ) @@ -21,7 +20,7 @@ type Cache struct { fetcher Fetcher lock chan struct{} - token atomic.Value + token atomic.Pointer[Token] } // Fetcher is a function that fetches a new token @@ -42,11 +41,13 @@ func (t *Token) ExpiresAfter(tm time.Time) bool { // NewCache creates a new token cache func NewCache(fetcher Fetcher, refreshToken string) *Cache { - return &Cache{ + c := &Cache{ lock: make(chan struct{}, 1), fetcher: fetcher, refreshToken: refreshToken, } + c.token.Store(nil) + return c } func (c *Cache) timeNow() time.Time { @@ -56,19 +57,33 @@ func (c *Cache) timeNow() time.Time { return time.Now() } +func (c *Cache) Reset() { + c.token.Store(nil) +} + // GetToken returns the current token if its at least `minTTL` from expiration, or fetches a new one. func (c *Cache) GetToken(ctx context.Context, minTTL time.Duration) (string, error) { minExpiration := c.timeNow().Add(minTTL) - token, ok := c.token.Load().(*Token) - if ok && token.ExpiresAfter(minExpiration) { + token := c.token.Load() + if token.ExpiresAfter(minExpiration) { return token.Bearer, nil } - return c.forceRefreshToken(ctx, minExpiration) + return c.ForceRefreshToken(ctx) } -func (c *Cache) forceRefreshToken(ctx context.Context, minExpiration time.Time) (string, error) { +func (c *Cache) ForceRefreshToken(ctx context.Context) (string, error) { + var current string + token := c.token.Load() + if token != nil { + current = token.Bearer + } + + return c.forceRefreshToken(ctx, current) +} + +func (c *Cache) forceRefreshToken(ctx context.Context, current string) (string, error) { select { case c.lock <- struct{}{}: case <-ctx.Done(): @@ -81,8 +96,8 @@ func (c *Cache) forceRefreshToken(ctx context.Context, minExpiration time.Time) ctx, cancel := context.WithTimeout(ctx, maxLockWait) defer cancel() - token, ok := c.token.Load().(*Token) - if ok && token.ExpiresAfter(minExpiration) { + token := c.token.Load() + if token != nil && token.Bearer != current { return token.Bearer, nil } @@ -92,9 +107,5 @@ func (c *Cache) forceRefreshToken(ctx context.Context, minExpiration time.Time) } c.token.Store(token) - if token.Expires.Before(minExpiration) { - return "", fmt.Errorf("new token cannot satisfy TTL: %v", minExpiration.Sub(token.Expires)) - } - return token.Bearer, nil } diff --git a/internal/zero/token/cache_test.go b/internal/zero/token/cache_test.go index f8de4b363..6a55bf9e5 100644 --- a/internal/zero/token/cache_test.go +++ b/internal/zero/token/cache_test.go @@ -2,6 +2,7 @@ package token_test import ( "context" + "fmt" "testing" "time" @@ -51,15 +52,28 @@ func TestCache(t *testing.T) { assert.Equal(t, "bearer-3", bearer) }) - t.Run("token cannot fit minTTL", func(t *testing.T) { + t.Run("reset", func(t *testing.T) { t.Parallel() + var calls int fetcher := func(_ context.Context, _ string) (*token.Token, error) { - return &token.Token{"ok-bearer", time.Now().Add(time.Minute)}, nil + calls++ + return &token.Token{fmt.Sprintf("bearer-%d", calls), time.Now().Add(time.Hour)}, nil } + ctx := context.Background() c := token.NewCache(fetcher, "test-refresh-token") - _, err := c.GetToken(context.Background(), time.Minute*2) - assert.Error(t, err) + got, err := c.GetToken(ctx, time.Minute*2) + assert.NoError(t, err) + assert.Equal(t, "bearer-1", got) + + got, err = c.GetToken(ctx, time.Minute*2) + assert.NoError(t, err) + assert.Equal(t, "bearer-1", got) + + c.Reset() + got, err = c.GetToken(ctx, time.Minute*2) + assert.NoError(t, err) + assert.Equal(t, "bearer-2", got) }) } diff --git a/pkg/zero/cluster/client.go b/pkg/zero/cluster/client.go index 932e88f4a..066c37f81 100644 --- a/pkg/zero/cluster/client.go +++ b/pkg/zero/cluster/client.go @@ -17,18 +17,23 @@ const ( var userAgent = version.UserAgent() type client struct { - tokenProvider TokenProviderFn + tokenProvider TokenCache httpClient *http.Client minTokenTTL time.Duration } -// TokenProviderFn is a function that returns a token that is expected to be valid for at least minTTL -type TokenProviderFn func(ctx context.Context, minTTL time.Duration) (string, error) +// TokenCache interface for fetching and caching tokens +type TokenCache interface { + // GetToken returns a token that is expected to be valid for at least minTTL duration + GetToken(ctx context.Context, minTTL time.Duration) (string, error) + // Reset resets the token cache + Reset() +} // NewAuthorizedClient creates a new HTTP client that will automatically add an authorization header func NewAuthorizedClient( endpoint string, - tokenProvider TokenProviderFn, + tokenProvider TokenCache, httpClient *http.Client, ) (ClientWithResponsesInterface, error) { c := &client{ @@ -43,12 +48,21 @@ func NewAuthorizedClient( func (c *client) Do(req *http.Request) (*http.Response, error) { ctx := req.Context() - token, err := c.tokenProvider(ctx, c.minTokenTTL) + token, err := c.tokenProvider.GetToken(ctx, c.minTokenTTL) if err != nil { return nil, fmt.Errorf("error getting token: %w", err) } req.Header.Set("Authorization", "Bearer "+token) req.Header.Set("User-Agent", userAgent) - return c.httpClient.Do(req) + resp, err := c.httpClient.Do(req) + if err != nil { + return nil, err + } + + if resp.StatusCode == http.StatusUnauthorized { + c.tokenProvider.Reset() + } + + return resp, nil } diff --git a/pkg/zero/cluster/client_test.go b/pkg/zero/cluster/client_test.go index 0de8f2cdb..182ebf9df 100644 --- a/pkg/zero/cluster/client_test.go +++ b/pkg/zero/cluster/client_test.go @@ -46,7 +46,7 @@ func TestAPIClient(t *testing.T) { require.NoError(t, err) tokenCache := token.NewCache(fetcher, "refresh-token") - client, err := api.NewAuthorizedClient(srv.URL, tokenCache.GetToken, http.DefaultClient) + client, err := api.NewAuthorizedClient(srv.URL, tokenCache, http.DefaultClient) require.NoError(t, err) resp, err := client.ExchangeClusterIdentityTokenWithResponse(context.Background(), diff --git a/pkg/zero/cluster/token_fetcher.go b/pkg/zero/cluster/token_fetcher.go index 76f6d87d3..ba1e040de 100644 --- a/pkg/zero/cluster/token_fetcher.go +++ b/pkg/zero/cluster/token_fetcher.go @@ -6,6 +6,7 @@ import ( "strconv" "time" + "github.com/pomerium/pomerium/internal/log" "github.com/pomerium/pomerium/internal/zero/apierror" "github.com/pomerium/pomerium/internal/zero/token" ) @@ -20,7 +21,7 @@ func NewTokenFetcher(endpoint string, opts ...ClientOption) (token.Fetcher, erro return func(ctx context.Context, refreshToken string) (*token.Token, error) { now := time.Now() - resp, err := apierror.CheckResponse[ExchangeTokenResponse](client.ExchangeClusterIdentityTokenWithResponse(ctx, ExchangeTokenRequest{ + resp, err := apierror.CheckResponse(client.ExchangeClusterIdentityTokenWithResponse(ctx, ExchangeTokenRequest{ RefreshToken: refreshToken, })) if err != nil { @@ -32,9 +33,11 @@ func NewTokenFetcher(endpoint string, opts ...ClientOption) (token.Fetcher, erro return nil, fmt.Errorf("error parsing expires in: %w", err) } + expires := now.Add(time.Duration(expiresSeconds) * time.Second) + log.Ctx(ctx).Debug().Time("expires", expires).Msg("fetched new Bearer token") return &token.Token{ Bearer: resp.IdToken, - Expires: now.Add(time.Duration(expiresSeconds) * time.Second), + Expires: expires, }, nil }, nil } diff --git a/pkg/zero/cluster/urlcache.go b/pkg/zero/cluster/urlcache.go index 09350c931..d7d8a1189 100644 --- a/pkg/zero/cluster/urlcache.go +++ b/pkg/zero/cluster/urlcache.go @@ -29,6 +29,13 @@ func NewURLCache() *URLCache { } } +func (c *URLCache) Delete(key string) { + c.mx.Lock() + defer c.mx.Unlock() + + delete(c.cache, key) +} + // Get gets the cache entry for the given key, if it exists and has not expired. func (c *URLCache) Get(key string, minTTL time.Duration) (*DownloadCacheEntry, bool) { c.mx.RLock()