mirror of
https://github.com/pomerium/pomerium.git
synced 2025-06-02 02:42:57 +02:00
zero/api: reset token and url cache if 401 is received (#5256)
zero/api: reset token cache if 401 is received
This commit is contained in:
parent
a04d1a450c
commit
ce12e51cf5
8 changed files with 91 additions and 32 deletions
|
@ -17,18 +17,23 @@ const (
|
|||
var userAgent = version.UserAgent()
|
||||
|
||||
type client struct {
|
||||
tokenProvider TokenProviderFn
|
||||
tokenProvider TokenCache
|
||||
httpClient *http.Client
|
||||
minTokenTTL time.Duration
|
||||
}
|
||||
|
||||
// TokenProviderFn is a function that returns a token that is expected to be valid for at least minTTL
|
||||
type TokenProviderFn func(ctx context.Context, minTTL time.Duration) (string, error)
|
||||
// TokenCache interface for fetching and caching tokens
|
||||
type TokenCache interface {
|
||||
// GetToken returns a token that is expected to be valid for at least minTTL duration
|
||||
GetToken(ctx context.Context, minTTL time.Duration) (string, error)
|
||||
// Reset resets the token cache
|
||||
Reset()
|
||||
}
|
||||
|
||||
// NewAuthorizedClient creates a new HTTP client that will automatically add an authorization header
|
||||
func NewAuthorizedClient(
|
||||
endpoint string,
|
||||
tokenProvider TokenProviderFn,
|
||||
tokenProvider TokenCache,
|
||||
httpClient *http.Client,
|
||||
) (ClientWithResponsesInterface, error) {
|
||||
c := &client{
|
||||
|
@ -43,12 +48,21 @@ func NewAuthorizedClient(
|
|||
|
||||
func (c *client) Do(req *http.Request) (*http.Response, error) {
|
||||
ctx := req.Context()
|
||||
token, err := c.tokenProvider(ctx, c.minTokenTTL)
|
||||
token, err := c.tokenProvider.GetToken(ctx, c.minTokenTTL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error getting token: %w", err)
|
||||
}
|
||||
req.Header.Set("Authorization", "Bearer "+token)
|
||||
req.Header.Set("User-Agent", userAgent)
|
||||
|
||||
return c.httpClient.Do(req)
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if resp.StatusCode == http.StatusUnauthorized {
|
||||
c.tokenProvider.Reset()
|
||||
}
|
||||
|
||||
return resp, nil
|
||||
}
|
||||
|
|
|
@ -46,7 +46,7 @@ func TestAPIClient(t *testing.T) {
|
|||
require.NoError(t, err)
|
||||
|
||||
tokenCache := token.NewCache(fetcher, "refresh-token")
|
||||
client, err := api.NewAuthorizedClient(srv.URL, tokenCache.GetToken, http.DefaultClient)
|
||||
client, err := api.NewAuthorizedClient(srv.URL, tokenCache, http.DefaultClient)
|
||||
require.NoError(t, err)
|
||||
|
||||
resp, err := client.ExchangeClusterIdentityTokenWithResponse(context.Background(),
|
||||
|
|
|
@ -6,6 +6,7 @@ import (
|
|||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/log"
|
||||
"github.com/pomerium/pomerium/internal/zero/apierror"
|
||||
"github.com/pomerium/pomerium/internal/zero/token"
|
||||
)
|
||||
|
@ -20,7 +21,7 @@ func NewTokenFetcher(endpoint string, opts ...ClientOption) (token.Fetcher, erro
|
|||
return func(ctx context.Context, refreshToken string) (*token.Token, error) {
|
||||
now := time.Now()
|
||||
|
||||
resp, err := apierror.CheckResponse[ExchangeTokenResponse](client.ExchangeClusterIdentityTokenWithResponse(ctx, ExchangeTokenRequest{
|
||||
resp, err := apierror.CheckResponse(client.ExchangeClusterIdentityTokenWithResponse(ctx, ExchangeTokenRequest{
|
||||
RefreshToken: refreshToken,
|
||||
}))
|
||||
if err != nil {
|
||||
|
@ -32,9 +33,11 @@ func NewTokenFetcher(endpoint string, opts ...ClientOption) (token.Fetcher, erro
|
|||
return nil, fmt.Errorf("error parsing expires in: %w", err)
|
||||
}
|
||||
|
||||
expires := now.Add(time.Duration(expiresSeconds) * time.Second)
|
||||
log.Ctx(ctx).Debug().Time("expires", expires).Msg("fetched new Bearer token")
|
||||
return &token.Token{
|
||||
Bearer: resp.IdToken,
|
||||
Expires: now.Add(time.Duration(expiresSeconds) * time.Second),
|
||||
Expires: expires,
|
||||
}, nil
|
||||
}, nil
|
||||
}
|
||||
|
|
|
@ -29,6 +29,13 @@ func NewURLCache() *URLCache {
|
|||
}
|
||||
}
|
||||
|
||||
func (c *URLCache) Delete(key string) {
|
||||
c.mx.Lock()
|
||||
defer c.mx.Unlock()
|
||||
|
||||
delete(c.cache, key)
|
||||
}
|
||||
|
||||
// Get gets the cache entry for the given key, if it exists and has not expired.
|
||||
func (c *URLCache) Get(key string, minTTL time.Duration) (*DownloadCacheEntry, bool) {
|
||||
c.mx.RLock()
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue