mirror of
https://github.com/pomerium/pomerium.git
synced 2025-04-29 18:36:30 +02:00
zero/api: reset token and url cache if 401 is received (#5256)
zero/api: reset token cache if 401 is received
This commit is contained in:
parent
a04d1a450c
commit
ce12e51cf5
8 changed files with 91 additions and 32 deletions
|
@ -63,7 +63,7 @@ func NewAPI(ctx context.Context, opts ...Option) (*API, error) {
|
||||||
|
|
||||||
tokenCache := token_api.NewCache(fetcher, cfg.apiToken)
|
tokenCache := token_api.NewCache(fetcher, cfg.apiToken)
|
||||||
|
|
||||||
clusterClient, err := cluster_api.NewAuthorizedClient(cfg.clusterAPIEndpoint, tokenCache.GetToken, cfg.httpClient)
|
clusterClient, err := cluster_api.NewAuthorizedClient(cfg.clusterAPIEndpoint, tokenCache, cfg.httpClient)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("error creating cluster client: %w", err)
|
return nil, fmt.Errorf("error creating cluster client: %w", err)
|
||||||
}
|
}
|
||||||
|
@ -104,14 +104,14 @@ func (api *API) Watch(ctx context.Context, opts ...WatchOption) error {
|
||||||
|
|
||||||
// GetClusterBootstrapConfig fetches the bootstrap configuration from the cluster API
|
// GetClusterBootstrapConfig fetches the bootstrap configuration from the cluster API
|
||||||
func (api *API) GetClusterBootstrapConfig(ctx context.Context) (*cluster_api.BootstrapConfig, error) {
|
func (api *API) GetClusterBootstrapConfig(ctx context.Context) (*cluster_api.BootstrapConfig, error) {
|
||||||
return apierror.CheckResponse[cluster_api.BootstrapConfig](
|
return apierror.CheckResponse(
|
||||||
api.cluster.GetClusterBootstrapConfigWithResponse(ctx),
|
api.cluster.GetClusterBootstrapConfigWithResponse(ctx),
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetClusterResourceBundles fetches the resource bundles from the cluster API
|
// GetClusterResourceBundles fetches the resource bundles from the cluster API
|
||||||
func (api *API) GetClusterResourceBundles(ctx context.Context) (*cluster_api.GetBundlesResponse, error) {
|
func (api *API) GetClusterResourceBundles(ctx context.Context) (*cluster_api.GetBundlesResponse, error) {
|
||||||
return apierror.CheckResponse[cluster_api.GetBundlesResponse](
|
return apierror.CheckResponse(
|
||||||
api.cluster.GetClusterResourceBundlesWithResponse(ctx),
|
api.cluster.GetClusterResourceBundlesWithResponse(ctx),
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
|
@ -56,6 +56,10 @@ func (api *API) DownloadClusterResourceBundle(
|
||||||
return newContentNotModifiedDownloadResult(resp.Header.Get("Last-Modified") != current.LastModified), nil
|
return newContentNotModifiedDownloadResult(resp.Header.Get("Last-Modified") != current.LastModified), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if resp.StatusCode == http.StatusUnauthorized {
|
||||||
|
api.downloadURLCache.Delete(id)
|
||||||
|
}
|
||||||
|
|
||||||
if resp.StatusCode != http.StatusOK {
|
if resp.StatusCode != http.StatusOK {
|
||||||
return nil, httpDownloadError(ctx, resp)
|
return nil, httpDownloadError(ctx, resp)
|
||||||
}
|
}
|
||||||
|
@ -107,6 +111,10 @@ func (api *API) HeadClusterResourceBundle(
|
||||||
Str("status", resp.Status).
|
Str("status", resp.Status).
|
||||||
Msg("bundle metadata request")
|
Msg("bundle metadata request")
|
||||||
|
|
||||||
|
if resp.StatusCode == http.StatusUnauthorized {
|
||||||
|
api.downloadURLCache.Delete(id)
|
||||||
|
}
|
||||||
|
|
||||||
if resp.StatusCode != http.StatusOK {
|
if resp.StatusCode != http.StatusOK {
|
||||||
return nil, httpDownloadError(ctx, resp)
|
return nil, httpDownloadError(ctx, resp)
|
||||||
}
|
}
|
||||||
|
@ -180,7 +188,7 @@ func (api *API) getDownloadParams(ctx context.Context, id string) (*cluster_api.
|
||||||
func (api *API) updateBundleDownloadParams(ctx context.Context, id string) (*cluster_api.DownloadCacheEntry, error) {
|
func (api *API) updateBundleDownloadParams(ctx context.Context, id string) (*cluster_api.DownloadCacheEntry, error) {
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
|
|
||||||
resp, err := apierror.CheckResponse[cluster_api.DownloadBundleResponse](
|
resp, err := apierror.CheckResponse(
|
||||||
api.cluster.DownloadClusterResourceBundleWithResponse(ctx, id),
|
api.cluster.DownloadClusterResourceBundleWithResponse(ctx, id),
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -197,11 +205,13 @@ func (api *API) updateBundleDownloadParams(ctx context.Context, id string) (*clu
|
||||||
return nil, fmt.Errorf("parse url: %w", err)
|
return nil, fmt.Errorf("parse url: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
expires := now.Add(time.Duration(expiresSeconds) * time.Second)
|
||||||
param := cluster_api.DownloadCacheEntry{
|
param := cluster_api.DownloadCacheEntry{
|
||||||
URL: *u,
|
URL: *u,
|
||||||
ExpiresAt: now.Add(time.Duration(expiresSeconds) * time.Second),
|
ExpiresAt: expires,
|
||||||
CaptureHeaders: resp.CaptureMetadataHeaders,
|
CaptureHeaders: resp.CaptureMetadataHeaders,
|
||||||
}
|
}
|
||||||
|
log.Ctx(ctx).Debug().Time("expires", expires).Msg("bundle download URL updated")
|
||||||
api.downloadURLCache.Set(id, param)
|
api.downloadURLCache.Set(id, param)
|
||||||
return ¶m, nil
|
return ¶m, nil
|
||||||
}
|
}
|
||||||
|
@ -323,7 +333,7 @@ func isXML(ct string) bool {
|
||||||
}
|
}
|
||||||
|
|
||||||
func extractMetadata(header http.Header, keys []string) map[string]string {
|
func extractMetadata(header http.Header, keys []string) map[string]string {
|
||||||
log.Info().Interface("header", header).Msg("extract metadata")
|
log.Debug().Interface("header", header).Msg("extract metadata")
|
||||||
m := make(map[string]string)
|
m := make(map[string]string)
|
||||||
for _, k := range keys {
|
for _, k := range keys {
|
||||||
v := header.Get(k)
|
v := header.Get(k)
|
||||||
|
|
|
@ -3,7 +3,6 @@ package token
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
@ -21,7 +20,7 @@ type Cache struct {
|
||||||
fetcher Fetcher
|
fetcher Fetcher
|
||||||
|
|
||||||
lock chan struct{}
|
lock chan struct{}
|
||||||
token atomic.Value
|
token atomic.Pointer[Token]
|
||||||
}
|
}
|
||||||
|
|
||||||
// Fetcher is a function that fetches a new token
|
// Fetcher is a function that fetches a new token
|
||||||
|
@ -42,11 +41,13 @@ func (t *Token) ExpiresAfter(tm time.Time) bool {
|
||||||
|
|
||||||
// NewCache creates a new token cache
|
// NewCache creates a new token cache
|
||||||
func NewCache(fetcher Fetcher, refreshToken string) *Cache {
|
func NewCache(fetcher Fetcher, refreshToken string) *Cache {
|
||||||
return &Cache{
|
c := &Cache{
|
||||||
lock: make(chan struct{}, 1),
|
lock: make(chan struct{}, 1),
|
||||||
fetcher: fetcher,
|
fetcher: fetcher,
|
||||||
refreshToken: refreshToken,
|
refreshToken: refreshToken,
|
||||||
}
|
}
|
||||||
|
c.token.Store(nil)
|
||||||
|
return c
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Cache) timeNow() time.Time {
|
func (c *Cache) timeNow() time.Time {
|
||||||
|
@ -56,19 +57,33 @@ func (c *Cache) timeNow() time.Time {
|
||||||
return time.Now()
|
return time.Now()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *Cache) Reset() {
|
||||||
|
c.token.Store(nil)
|
||||||
|
}
|
||||||
|
|
||||||
// GetToken returns the current token if its at least `minTTL` from expiration, or fetches a new one.
|
// GetToken returns the current token if its at least `minTTL` from expiration, or fetches a new one.
|
||||||
func (c *Cache) GetToken(ctx context.Context, minTTL time.Duration) (string, error) {
|
func (c *Cache) GetToken(ctx context.Context, minTTL time.Duration) (string, error) {
|
||||||
minExpiration := c.timeNow().Add(minTTL)
|
minExpiration := c.timeNow().Add(minTTL)
|
||||||
|
|
||||||
token, ok := c.token.Load().(*Token)
|
token := c.token.Load()
|
||||||
if ok && token.ExpiresAfter(minExpiration) {
|
if token.ExpiresAfter(minExpiration) {
|
||||||
return token.Bearer, nil
|
return token.Bearer, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
return c.forceRefreshToken(ctx, minExpiration)
|
return c.ForceRefreshToken(ctx)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Cache) forceRefreshToken(ctx context.Context, minExpiration time.Time) (string, error) {
|
func (c *Cache) ForceRefreshToken(ctx context.Context) (string, error) {
|
||||||
|
var current string
|
||||||
|
token := c.token.Load()
|
||||||
|
if token != nil {
|
||||||
|
current = token.Bearer
|
||||||
|
}
|
||||||
|
|
||||||
|
return c.forceRefreshToken(ctx, current)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Cache) forceRefreshToken(ctx context.Context, current string) (string, error) {
|
||||||
select {
|
select {
|
||||||
case c.lock <- struct{}{}:
|
case c.lock <- struct{}{}:
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
|
@ -81,8 +96,8 @@ func (c *Cache) forceRefreshToken(ctx context.Context, minExpiration time.Time)
|
||||||
ctx, cancel := context.WithTimeout(ctx, maxLockWait)
|
ctx, cancel := context.WithTimeout(ctx, maxLockWait)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
token, ok := c.token.Load().(*Token)
|
token := c.token.Load()
|
||||||
if ok && token.ExpiresAfter(minExpiration) {
|
if token != nil && token.Bearer != current {
|
||||||
return token.Bearer, nil
|
return token.Bearer, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -92,9 +107,5 @@ func (c *Cache) forceRefreshToken(ctx context.Context, minExpiration time.Time)
|
||||||
}
|
}
|
||||||
c.token.Store(token)
|
c.token.Store(token)
|
||||||
|
|
||||||
if token.Expires.Before(minExpiration) {
|
|
||||||
return "", fmt.Errorf("new token cannot satisfy TTL: %v", minExpiration.Sub(token.Expires))
|
|
||||||
}
|
|
||||||
|
|
||||||
return token.Bearer, nil
|
return token.Bearer, nil
|
||||||
}
|
}
|
||||||
|
|
|
@ -2,6 +2,7 @@ package token_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"fmt"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
@ -51,15 +52,28 @@ func TestCache(t *testing.T) {
|
||||||
assert.Equal(t, "bearer-3", bearer)
|
assert.Equal(t, "bearer-3", bearer)
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("token cannot fit minTTL", func(t *testing.T) {
|
t.Run("reset", func(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
|
var calls int
|
||||||
fetcher := func(_ context.Context, _ string) (*token.Token, error) {
|
fetcher := func(_ context.Context, _ string) (*token.Token, error) {
|
||||||
return &token.Token{"ok-bearer", time.Now().Add(time.Minute)}, nil
|
calls++
|
||||||
|
return &token.Token{fmt.Sprintf("bearer-%d", calls), time.Now().Add(time.Hour)}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
c := token.NewCache(fetcher, "test-refresh-token")
|
c := token.NewCache(fetcher, "test-refresh-token")
|
||||||
_, err := c.GetToken(context.Background(), time.Minute*2)
|
got, err := c.GetToken(ctx, time.Minute*2)
|
||||||
assert.Error(t, err)
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, "bearer-1", got)
|
||||||
|
|
||||||
|
got, err = c.GetToken(ctx, time.Minute*2)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, "bearer-1", got)
|
||||||
|
|
||||||
|
c.Reset()
|
||||||
|
got, err = c.GetToken(ctx, time.Minute*2)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, "bearer-2", got)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
|
@ -17,18 +17,23 @@ const (
|
||||||
var userAgent = version.UserAgent()
|
var userAgent = version.UserAgent()
|
||||||
|
|
||||||
type client struct {
|
type client struct {
|
||||||
tokenProvider TokenProviderFn
|
tokenProvider TokenCache
|
||||||
httpClient *http.Client
|
httpClient *http.Client
|
||||||
minTokenTTL time.Duration
|
minTokenTTL time.Duration
|
||||||
}
|
}
|
||||||
|
|
||||||
// TokenProviderFn is a function that returns a token that is expected to be valid for at least minTTL
|
// TokenCache interface for fetching and caching tokens
|
||||||
type TokenProviderFn func(ctx context.Context, minTTL time.Duration) (string, error)
|
type TokenCache interface {
|
||||||
|
// GetToken returns a token that is expected to be valid for at least minTTL duration
|
||||||
|
GetToken(ctx context.Context, minTTL time.Duration) (string, error)
|
||||||
|
// Reset resets the token cache
|
||||||
|
Reset()
|
||||||
|
}
|
||||||
|
|
||||||
// NewAuthorizedClient creates a new HTTP client that will automatically add an authorization header
|
// NewAuthorizedClient creates a new HTTP client that will automatically add an authorization header
|
||||||
func NewAuthorizedClient(
|
func NewAuthorizedClient(
|
||||||
endpoint string,
|
endpoint string,
|
||||||
tokenProvider TokenProviderFn,
|
tokenProvider TokenCache,
|
||||||
httpClient *http.Client,
|
httpClient *http.Client,
|
||||||
) (ClientWithResponsesInterface, error) {
|
) (ClientWithResponsesInterface, error) {
|
||||||
c := &client{
|
c := &client{
|
||||||
|
@ -43,12 +48,21 @@ func NewAuthorizedClient(
|
||||||
|
|
||||||
func (c *client) Do(req *http.Request) (*http.Response, error) {
|
func (c *client) Do(req *http.Request) (*http.Response, error) {
|
||||||
ctx := req.Context()
|
ctx := req.Context()
|
||||||
token, err := c.tokenProvider(ctx, c.minTokenTTL)
|
token, err := c.tokenProvider.GetToken(ctx, c.minTokenTTL)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("error getting token: %w", err)
|
return nil, fmt.Errorf("error getting token: %w", err)
|
||||||
}
|
}
|
||||||
req.Header.Set("Authorization", "Bearer "+token)
|
req.Header.Set("Authorization", "Bearer "+token)
|
||||||
req.Header.Set("User-Agent", userAgent)
|
req.Header.Set("User-Agent", userAgent)
|
||||||
|
|
||||||
return c.httpClient.Do(req)
|
resp, err := c.httpClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if resp.StatusCode == http.StatusUnauthorized {
|
||||||
|
c.tokenProvider.Reset()
|
||||||
|
}
|
||||||
|
|
||||||
|
return resp, nil
|
||||||
}
|
}
|
||||||
|
|
|
@ -46,7 +46,7 @@ func TestAPIClient(t *testing.T) {
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
tokenCache := token.NewCache(fetcher, "refresh-token")
|
tokenCache := token.NewCache(fetcher, "refresh-token")
|
||||||
client, err := api.NewAuthorizedClient(srv.URL, tokenCache.GetToken, http.DefaultClient)
|
client, err := api.NewAuthorizedClient(srv.URL, tokenCache, http.DefaultClient)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
resp, err := client.ExchangeClusterIdentityTokenWithResponse(context.Background(),
|
resp, err := client.ExchangeClusterIdentityTokenWithResponse(context.Background(),
|
||||||
|
|
|
@ -6,6 +6,7 @@ import (
|
||||||
"strconv"
|
"strconv"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/pomerium/pomerium/internal/log"
|
||||||
"github.com/pomerium/pomerium/internal/zero/apierror"
|
"github.com/pomerium/pomerium/internal/zero/apierror"
|
||||||
"github.com/pomerium/pomerium/internal/zero/token"
|
"github.com/pomerium/pomerium/internal/zero/token"
|
||||||
)
|
)
|
||||||
|
@ -20,7 +21,7 @@ func NewTokenFetcher(endpoint string, opts ...ClientOption) (token.Fetcher, erro
|
||||||
return func(ctx context.Context, refreshToken string) (*token.Token, error) {
|
return func(ctx context.Context, refreshToken string) (*token.Token, error) {
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
|
|
||||||
resp, err := apierror.CheckResponse[ExchangeTokenResponse](client.ExchangeClusterIdentityTokenWithResponse(ctx, ExchangeTokenRequest{
|
resp, err := apierror.CheckResponse(client.ExchangeClusterIdentityTokenWithResponse(ctx, ExchangeTokenRequest{
|
||||||
RefreshToken: refreshToken,
|
RefreshToken: refreshToken,
|
||||||
}))
|
}))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -32,9 +33,11 @@ func NewTokenFetcher(endpoint string, opts ...ClientOption) (token.Fetcher, erro
|
||||||
return nil, fmt.Errorf("error parsing expires in: %w", err)
|
return nil, fmt.Errorf("error parsing expires in: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
expires := now.Add(time.Duration(expiresSeconds) * time.Second)
|
||||||
|
log.Ctx(ctx).Debug().Time("expires", expires).Msg("fetched new Bearer token")
|
||||||
return &token.Token{
|
return &token.Token{
|
||||||
Bearer: resp.IdToken,
|
Bearer: resp.IdToken,
|
||||||
Expires: now.Add(time.Duration(expiresSeconds) * time.Second),
|
Expires: expires,
|
||||||
}, nil
|
}, nil
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
|
@ -29,6 +29,13 @@ func NewURLCache() *URLCache {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *URLCache) Delete(key string) {
|
||||||
|
c.mx.Lock()
|
||||||
|
defer c.mx.Unlock()
|
||||||
|
|
||||||
|
delete(c.cache, key)
|
||||||
|
}
|
||||||
|
|
||||||
// Get gets the cache entry for the given key, if it exists and has not expired.
|
// Get gets the cache entry for the given key, if it exists and has not expired.
|
||||||
func (c *URLCache) Get(key string, minTTL time.Duration) (*DownloadCacheEntry, bool) {
|
func (c *URLCache) Get(key string, minTTL time.Duration) (*DownloadCacheEntry, bool) {
|
||||||
c.mx.RLock()
|
c.mx.RLock()
|
||||||
|
|
Loading…
Add table
Reference in a new issue