pomerium/internal/identity/oidc/refresh_test.go
Kenneth Jenkins 39a477c510
identity: override TokenSource expiry behavior (#4632)
The current session refresh loop attempts to refresh access tokens when
they are due to expire in less than one minute. However, the code to
perform the refresh relies on a TokenSource from the x/oauth2 package,
which has its own internal 'expiryDelta' threshold, with a default of
10 seconds. As a result, the first four or five attempts to refresh a
particular access token will not actually refresh the token. The refresh
will happen only when the access token is within 10 seconds of expiring.

Instead, before we obtain a new TokenSource, first clear any existing
access token. This causes the TokenSource to consider the token invalid,
triggering a refresh. This should give the refresh loop more control
over when refreshes happen.

Consolidate this logic in a new Refresh() method in the oidc package.
Add unit tests for this new method.
2023-10-23 08:20:04 -07:00

71 lines
1.8 KiB
Go

package oidc
import (
"context"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/oauth2"
)
func TestRefresh(t *testing.T) {
t.Parallel()
ctx, clearTimeout := context.WithTimeout(context.Background(), 10*time.Second)
t.Cleanup(clearTimeout)
s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
w.Write([]byte(`{
"access_token": "NEW_TOKEN",
"refresh_token": "NEW_REFRESH_TOKEN",
"expires_in": 3600
}`))
}))
t.Cleanup(s.Close)
cfg := &oauth2.Config{Endpoint: oauth2.Endpoint{TokenURL: s.URL}}
token := &oauth2.Token{
AccessToken: "OLD_TOKEN",
RefreshToken: "OLD_REFRESH_TOKEN",
// Even if a token is not expiring soon, Refresh() should still perform
// the refresh.
Expiry: time.Now().Add(time.Hour),
}
require.True(t, token.Valid())
newToken, err := Refresh(ctx, cfg, token)
require.NoError(t, err)
assert.Equal(t, "NEW_TOKEN", newToken.AccessToken)
assert.Equal(t, "NEW_REFRESH_TOKEN", newToken.RefreshToken)
}
// TestRefresh_errors covers Refresh's error paths:
//   - a nil token is rejected with ErrMissingRefreshToken;
//   - a token with no refresh token is rejected with ErrMissingRefreshToken;
//   - a token endpoint response with no access_token yields an error whose
//     message is prefixed with "identity/oidc: refresh failed: ".
func TestRefresh_errors(t *testing.T) {
	t.Parallel()

	ctx, clearTimeout := context.WithTimeout(context.Background(), 10*time.Second)
	t.Cleanup(clearTimeout)

	s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// Respond successfully, but with a body that contains no access_token.
		w.Write([]byte("{}"))
	}))
	t.Cleanup(s.Close)

	cfg := &oauth2.Config{Endpoint: oauth2.Endpoint{TokenURL: s.URL}}

	_, err := Refresh(ctx, cfg, nil)
	assert.ErrorIs(t, err, ErrMissingRefreshToken)

	_, err = Refresh(ctx, cfg, &oauth2.Token{})
	assert.ErrorIs(t, err, ErrMissingRefreshToken)

	// EqualError reports a clean failure if err is nil. Calling err.Error()
	// directly would panic with a nil dereference instead.
	_, err = Refresh(ctx, cfg, &oauth2.Token{RefreshToken: "REFRESH_TOKEN"})
	assert.EqualError(t, err,
		"identity/oidc: refresh failed: oauth2: server response missing access_token")
}