zero/api: reset token and url cache if 401 is received (#5256)

zero/api: reset token cache if 401 is received
This commit is contained in:
Denis Mishin 2024-09-03 15:40:28 -04:00 committed by GitHub
parent a04d1a450c
commit ce12e51cf5
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
8 changed files with 91 additions and 32 deletions

View file

@ -17,18 +17,23 @@ const (
var userAgent = version.UserAgent()
type client struct {
tokenProvider TokenProviderFn
tokenProvider TokenCache
httpClient *http.Client
minTokenTTL time.Duration
}
// TokenProviderFn is a function that returns a token that is expected to be valid for at least minTTL
type TokenProviderFn func(ctx context.Context, minTTL time.Duration) (string, error)
// TokenCache interface for fetching and caching tokens
type TokenCache interface {
// GetToken returns a token that is expected to be valid for at least minTTL duration
GetToken(ctx context.Context, minTTL time.Duration) (string, error)
// Reset resets the token cache
Reset()
}
// NewAuthorizedClient creates a new HTTP client that will automatically add an authorization header
func NewAuthorizedClient(
endpoint string,
tokenProvider TokenProviderFn,
tokenProvider TokenCache,
httpClient *http.Client,
) (ClientWithResponsesInterface, error) {
c := &client{
@ -43,12 +48,21 @@ func NewAuthorizedClient(
func (c *client) Do(req *http.Request) (*http.Response, error) {
ctx := req.Context()
token, err := c.tokenProvider(ctx, c.minTokenTTL)
token, err := c.tokenProvider.GetToken(ctx, c.minTokenTTL)
if err != nil {
return nil, fmt.Errorf("error getting token: %w", err)
}
req.Header.Set("Authorization", "Bearer "+token)
req.Header.Set("User-Agent", userAgent)
return c.httpClient.Do(req)
resp, err := c.httpClient.Do(req)
if err != nil {
return nil, err
}
if resp.StatusCode == http.StatusUnauthorized {
c.tokenProvider.Reset()
}
return resp, nil
}

View file

@ -46,7 +46,7 @@ func TestAPIClient(t *testing.T) {
require.NoError(t, err)
tokenCache := token.NewCache(fetcher, "refresh-token")
client, err := api.NewAuthorizedClient(srv.URL, tokenCache.GetToken, http.DefaultClient)
client, err := api.NewAuthorizedClient(srv.URL, tokenCache, http.DefaultClient)
require.NoError(t, err)
resp, err := client.ExchangeClusterIdentityTokenWithResponse(context.Background(),

View file

@ -6,6 +6,7 @@ import (
"strconv"
"time"
"github.com/pomerium/pomerium/internal/log"
"github.com/pomerium/pomerium/internal/zero/apierror"
"github.com/pomerium/pomerium/internal/zero/token"
)
@ -20,7 +21,7 @@ func NewTokenFetcher(endpoint string, opts ...ClientOption) (token.Fetcher, erro
return func(ctx context.Context, refreshToken string) (*token.Token, error) {
now := time.Now()
resp, err := apierror.CheckResponse[ExchangeTokenResponse](client.ExchangeClusterIdentityTokenWithResponse(ctx, ExchangeTokenRequest{
resp, err := apierror.CheckResponse(client.ExchangeClusterIdentityTokenWithResponse(ctx, ExchangeTokenRequest{
RefreshToken: refreshToken,
}))
if err != nil {
@ -32,9 +33,11 @@ func NewTokenFetcher(endpoint string, opts ...ClientOption) (token.Fetcher, erro
return nil, fmt.Errorf("error parsing expires in: %w", err)
}
expires := now.Add(time.Duration(expiresSeconds) * time.Second)
log.Ctx(ctx).Debug().Time("expires", expires).Msg("fetched new Bearer token")
return &token.Token{
Bearer: resp.IdToken,
Expires: now.Add(time.Duration(expiresSeconds) * time.Second),
Expires: expires,
}, nil
}, nil
}

View file

@ -29,6 +29,13 @@ func NewURLCache() *URLCache {
}
}
func (c *URLCache) Delete(key string) {
c.mx.Lock()
defer c.mx.Unlock()
delete(c.cache, key)
}
// Get gets the cache entry for the given key, if it exists and has not expired.
func (c *URLCache) Get(key string, minTTL time.Duration) (*DownloadCacheEntry, bool) {
c.mx.RLock()