mirror of
https://github.com/pomerium/pomerium.git
synced 2025-08-01 07:50:26 +02:00
identity: override TokenSource expiry behavior (#4632)
The current session refresh loop attempts to refresh access tokens when they are due to expire in less than one minute. However, the code to perform the refresh relies on a TokenSource from the x/oauth2 package, which has its own internal 'expiryDelta' threshold, with a default of 10 seconds. As a result, the first four or five attempts to refresh a particular access token will not actually refresh the token. The refresh will happen only when the access token is within 10 seconds of expiring. Instead, before we obtain a new TokenSource, first clear any existing access token. This causes the TokenSource to consider the token invalid, triggering a refresh. This should give the refresh loop more control over when refreshes happen. Consolidate this logic in a new Refresh() method in the oidc package. Add unit tests for this new method.
This commit is contained in:
parent
c32005d0fe
commit
ad962009ca
4 changed files with 104 additions and 18 deletions
|
@ -130,16 +130,9 @@ func (p *Provider) Authenticate(ctx context.Context, code string, v identity.Sta
|
|||
|
||||
// Refresh renews a user's session.
|
||||
func (p *Provider) Refresh(ctx context.Context, t *oauth2.Token, v identity.State) (*oauth2.Token, error) {
|
||||
if t == nil {
|
||||
return nil, oidc.ErrMissingAccessToken
|
||||
}
|
||||
if t.RefreshToken == "" {
|
||||
return nil, oidc.ErrMissingRefreshToken
|
||||
}
|
||||
|
||||
newToken, err := p.oauth.TokenSource(ctx, t).Token()
|
||||
newToken, err := oidc.Refresh(ctx, p.oauth, t)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("identity/apple: refresh failed: %w", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if rawIDToken, ok := newToken.Extra("id_token").(string); ok {
|
||||
|
|
|
@ -213,16 +213,9 @@ func (p *Provider) Refresh(ctx context.Context, t *oauth2.Token, v identity.Stat
|
|||
return nil, err
|
||||
}
|
||||
|
||||
if t == nil {
|
||||
return nil, ErrMissingAccessToken
|
||||
}
|
||||
if t.RefreshToken == "" {
|
||||
return nil, ErrMissingRefreshToken
|
||||
}
|
||||
|
||||
newToken, err := oa.TokenSource(ctx, t).Token()
|
||||
newToken, err := Refresh(ctx, oa, t)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("identity/oidc: refresh failed: %w", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Many identity providers _will not_ return `id_token` on refresh
|
||||
|
|
29
internal/identity/oidc/refresh.go
Normal file
29
internal/identity/oidc/refresh.go
Normal file
|
@ -0,0 +1,29 @@
|
|||
package oidc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"golang.org/x/oauth2"
|
||||
)
|
||||
|
||||
// Refresh requests a new oauth2.Token based on an existing Token and the
|
||||
// provided Config. The existing Token must contain a refresh token.
|
||||
func Refresh(ctx context.Context, cfg *oauth2.Config, t *oauth2.Token) (*oauth2.Token, error) {
|
||||
if t == nil || t.RefreshToken == "" {
|
||||
return nil, ErrMissingRefreshToken
|
||||
}
|
||||
|
||||
// Note: the TokenSource returned by oauth2.Config has its own threshold
|
||||
// for determining when to attempt a refresh. In order to force a refresh
|
||||
// we can remove the current AccessToken.
|
||||
t = &oauth2.Token{
|
||||
TokenType: t.TokenType,
|
||||
RefreshToken: t.RefreshToken,
|
||||
}
|
||||
newToken, err := cfg.TokenSource(ctx, t).Token()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("identity/oidc: refresh failed: %w", err)
|
||||
}
|
||||
return newToken, nil
|
||||
}
|
71
internal/identity/oidc/refresh_test.go
Normal file
71
internal/identity/oidc/refresh_test.go
Normal file
|
@ -0,0 +1,71 @@
|
|||
package oidc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/oauth2"
|
||||
)
|
||||
|
||||
func TestRefresh(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx, clearTimeout := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
t.Cleanup(clearTimeout)
|
||||
|
||||
s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.Write([]byte(`{
|
||||
"access_token": "NEW_TOKEN",
|
||||
"refresh_token": "NEW_REFRESH_TOKEN",
|
||||
"expires_in": 3600
|
||||
}`))
|
||||
}))
|
||||
t.Cleanup(s.Close)
|
||||
|
||||
cfg := &oauth2.Config{Endpoint: oauth2.Endpoint{TokenURL: s.URL}}
|
||||
|
||||
token := &oauth2.Token{
|
||||
AccessToken: "OLD_TOKEN",
|
||||
RefreshToken: "OLD_REFRESH_TOKEN",
|
||||
|
||||
// Even if a token is not expiring soon, Refresh() should still perform
|
||||
// the refresh.
|
||||
Expiry: time.Now().Add(time.Hour),
|
||||
}
|
||||
require.True(t, token.Valid())
|
||||
|
||||
newToken, err := Refresh(ctx, cfg, token)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "NEW_TOKEN", newToken.AccessToken)
|
||||
assert.Equal(t, "NEW_REFRESH_TOKEN", newToken.RefreshToken)
|
||||
}
|
||||
|
||||
func TestRefresh_errors(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx, clearTimeout := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
t.Cleanup(clearTimeout)
|
||||
|
||||
s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Write([]byte("{}"))
|
||||
}))
|
||||
t.Cleanup(s.Close)
|
||||
|
||||
cfg := &oauth2.Config{Endpoint: oauth2.Endpoint{TokenURL: s.URL}}
|
||||
|
||||
_, err := Refresh(ctx, cfg, nil)
|
||||
assert.Equal(t, ErrMissingRefreshToken, err)
|
||||
|
||||
_, err = Refresh(ctx, cfg, &oauth2.Token{})
|
||||
assert.Equal(t, ErrMissingRefreshToken, err)
|
||||
|
||||
_, err = Refresh(ctx, cfg, &oauth2.Token{RefreshToken: "REFRESH_TOKEN"})
|
||||
assert.Equal(t, "identity/oidc: refresh failed: oauth2: server response missing access_token",
|
||||
err.Error())
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue