mirror of
https://github.com/pomerium/pomerium.git
synced 2025-07-23 19:49:13 +02:00
Merge remote-tracking branch 'origin/main' into feature/zero
This commit is contained in:
commit
58d8f406a9
8 changed files with 350 additions and 57 deletions
|
@ -54,6 +54,9 @@ func (u *User) UnmarshalJSON(data []byte) error {
|
|||
// A Session is a session managed by the Manager.
|
||||
type Session struct {
|
||||
*session.Session
|
||||
// lastRefresh is the time of the last refresh attempt (which may or may
|
||||
// not have succeeded), or else the time the Manager first became aware of
|
||||
// the session (if it has not yet attempted to refresh this session).
|
||||
lastRefresh time.Time
|
||||
// gracePeriod is the amount of time before expiration to attempt a refresh.
|
||||
gracePeriod time.Duration
|
||||
|
|
|
@ -107,6 +107,10 @@ func (mgr *Manager) GetDataBrokerServiceClient() databroker.DataBrokerServiceCli
|
|||
return mgr.cfg.Load().dataBrokerClient
|
||||
}
|
||||
|
||||
func (mgr *Manager) now() time.Time {
|
||||
return mgr.cfg.Load().now()
|
||||
}
|
||||
|
||||
func (mgr *Manager) refreshLoop(ctx context.Context, update <-chan updateRecordsMessage, clear <-chan struct{}) error {
|
||||
// wait for initial sync
|
||||
select {
|
||||
|
@ -145,7 +149,7 @@ func (mgr *Manager) refreshLoop(ctx context.Context, update <-chan updateRecords
|
|||
case <-timer.C:
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
now := mgr.now()
|
||||
nextTime = now.Add(maxWait)
|
||||
|
||||
// refresh sessions
|
||||
|
@ -182,22 +186,21 @@ func (mgr *Manager) refreshLoop(ctx context.Context, update <-chan updateRecords
|
|||
}
|
||||
}
|
||||
|
||||
// refreshSession handles two distinct session lifecycle events:
|
||||
//
|
||||
// 1. If the session itself has expired, delete the session.
|
||||
// 2. If the session's underlying OAuth2 access token is nearing expiration
|
||||
// (but the session itself is still valid), refresh the access token.
|
||||
//
|
||||
// After a successful access token refresh, this method will also trigger a
|
||||
// user info refresh. If an access token refresh or a user info refresh fails
|
||||
// with a permanent error, the session will be deleted.
|
||||
func (mgr *Manager) refreshSession(ctx context.Context, userID, sessionID string) {
|
||||
log.Info(ctx).
|
||||
Str("user_id", userID).
|
||||
Str("session_id", sessionID).
|
||||
Msg("refreshing session")
|
||||
|
||||
authenticator := mgr.cfg.Load().authenticator
|
||||
if authenticator == nil {
|
||||
log.Info(ctx).
|
||||
Str("user_id", userID).
|
||||
Str("session_id", sessionID).
|
||||
Msg("no authenticator defined, deleting session")
|
||||
mgr.deleteSession(ctx, userID, sessionID)
|
||||
return
|
||||
}
|
||||
|
||||
s, ok := mgr.sessions.Get(userID, sessionID)
|
||||
if !ok {
|
||||
log.Warn(ctx).
|
||||
|
@ -207,14 +210,37 @@ func (mgr *Manager) refreshSession(ctx context.Context, userID, sessionID string
|
|||
return
|
||||
}
|
||||
|
||||
s.lastRefresh = mgr.now()
|
||||
|
||||
if mgr.refreshSessionInternal(ctx, userID, sessionID, &s) {
|
||||
mgr.sessions.ReplaceOrInsert(s)
|
||||
mgr.sessionScheduler.Add(s.NextRefresh(), toSessionSchedulerKey(userID, sessionID))
|
||||
}
|
||||
}
|
||||
|
||||
// refreshSessionInternal performs the core refresh logic and returns true if
|
||||
// the session should be scheduled for refresh again, or false if not.
|
||||
func (mgr *Manager) refreshSessionInternal(
|
||||
ctx context.Context, userID, sessionID string, s *Session,
|
||||
) bool {
|
||||
authenticator := mgr.cfg.Load().authenticator
|
||||
if authenticator == nil {
|
||||
log.Info(ctx).
|
||||
Str("user_id", userID).
|
||||
Str("session_id", sessionID).
|
||||
Msg("no authenticator defined, deleting session")
|
||||
mgr.deleteSession(ctx, userID, sessionID)
|
||||
return false
|
||||
}
|
||||
|
||||
expiry := s.GetExpiresAt().AsTime()
|
||||
if !expiry.After(time.Now()) {
|
||||
if !expiry.After(mgr.now()) {
|
||||
log.Info(ctx).
|
||||
Str("user_id", userID).
|
||||
Str("session_id", sessionID).
|
||||
Msg("deleting expired session")
|
||||
mgr.deleteSession(ctx, userID, sessionID)
|
||||
return
|
||||
return false
|
||||
}
|
||||
|
||||
if s.Session == nil || s.Session.OauthToken == nil {
|
||||
|
@ -222,10 +248,10 @@ func (mgr *Manager) refreshSession(ctx context.Context, userID, sessionID string
|
|||
Str("user_id", userID).
|
||||
Str("session_id", sessionID).
|
||||
Msg("no session oauth2 token found for refresh")
|
||||
return
|
||||
return false
|
||||
}
|
||||
|
||||
newToken, err := authenticator.Refresh(ctx, FromOAuthToken(s.OauthToken), &s)
|
||||
newToken, err := authenticator.Refresh(ctx, FromOAuthToken(s.OauthToken), s)
|
||||
metrics.RecordIdentityManagerSessionRefresh(ctx, err)
|
||||
mgr.recordLastError(metrics_ids.IdentityManagerLastSessionRefreshError, err)
|
||||
if isTemporaryError(err) {
|
||||
|
@ -233,18 +259,18 @@ func (mgr *Manager) refreshSession(ctx context.Context, userID, sessionID string
|
|||
Str("user_id", s.GetUserId()).
|
||||
Str("session_id", s.GetId()).
|
||||
Msg("failed to refresh oauth2 token")
|
||||
return
|
||||
return true
|
||||
} else if err != nil {
|
||||
log.Error(ctx).Err(err).
|
||||
Str("user_id", s.GetUserId()).
|
||||
Str("session_id", s.GetId()).
|
||||
Msg("failed to refresh oauth2 token, deleting session")
|
||||
mgr.deleteSession(ctx, userID, sessionID)
|
||||
return
|
||||
return false
|
||||
}
|
||||
s.OauthToken = ToOAuthToken(newToken)
|
||||
|
||||
err = authenticator.UpdateUserInfo(ctx, FromOAuthToken(s.OauthToken), &s)
|
||||
err = authenticator.UpdateUserInfo(ctx, FromOAuthToken(s.OauthToken), s)
|
||||
metrics.RecordIdentityManagerUserRefresh(ctx, err)
|
||||
mgr.recordLastError(metrics_ids.IdentityManagerLastUserRefreshError, err)
|
||||
if isTemporaryError(err) {
|
||||
|
@ -252,26 +278,23 @@ func (mgr *Manager) refreshSession(ctx context.Context, userID, sessionID string
|
|||
Str("user_id", s.GetUserId()).
|
||||
Str("session_id", s.GetId()).
|
||||
Msg("failed to update user info")
|
||||
return
|
||||
return true
|
||||
} else if err != nil {
|
||||
log.Error(ctx).Err(err).
|
||||
Str("user_id", s.GetUserId()).
|
||||
Str("session_id", s.GetId()).
|
||||
Msg("failed to update user info, deleting session")
|
||||
mgr.deleteSession(ctx, userID, sessionID)
|
||||
return
|
||||
return false
|
||||
}
|
||||
|
||||
res, err := session.Put(ctx, mgr.cfg.Load().dataBrokerClient, s.Session)
|
||||
if err != nil {
|
||||
if _, err := session.Put(ctx, mgr.cfg.Load().dataBrokerClient, s.Session); err != nil {
|
||||
log.Error(ctx).Err(err).
|
||||
Str("user_id", s.GetUserId()).
|
||||
Str("session_id", s.GetId()).
|
||||
Msg("failed to update session")
|
||||
return
|
||||
}
|
||||
|
||||
mgr.onUpdateSession(ctx, res.GetRecord(), s.Session)
|
||||
return true
|
||||
}
|
||||
|
||||
func (mgr *Manager) refreshUser(ctx context.Context, userID string) {
|
||||
|
@ -291,7 +314,7 @@ func (mgr *Manager) refreshUser(ctx context.Context, userID string) {
|
|||
Msg("no user found for refresh")
|
||||
return
|
||||
}
|
||||
u.lastRefresh = time.Now()
|
||||
u.lastRefresh = mgr.now()
|
||||
mgr.userScheduler.Add(u.NextRefresh(), u.GetId())
|
||||
|
||||
for _, s := range mgr.sessions.GetSessionsForUser(userID) {
|
||||
|
@ -343,7 +366,7 @@ func (mgr *Manager) onUpdateRecords(ctx context.Context, msg updateRecordsMessag
|
|||
log.Warn(ctx).Msgf("error unmarshaling session: %s", err)
|
||||
continue
|
||||
}
|
||||
mgr.onUpdateSession(ctx, record, &pbSession)
|
||||
mgr.onUpdateSession(record, &pbSession)
|
||||
case grpcutil.GetTypeURL(new(user.User)):
|
||||
var pbUser user.User
|
||||
err := record.GetData().UnmarshalTo(&pbUser)
|
||||
|
@ -356,7 +379,7 @@ func (mgr *Manager) onUpdateRecords(ctx context.Context, msg updateRecordsMessag
|
|||
}
|
||||
}
|
||||
|
||||
func (mgr *Manager) onUpdateSession(_ context.Context, record *databroker.Record, session *session.Session) {
|
||||
func (mgr *Manager) onUpdateSession(record *databroker.Record, session *session.Session) {
|
||||
mgr.sessionScheduler.Remove(toSessionSchedulerKey(session.GetUserId(), session.GetId()))
|
||||
|
||||
if record.GetDeletedAt() != nil {
|
||||
|
@ -366,7 +389,9 @@ func (mgr *Manager) onUpdateSession(_ context.Context, record *databroker.Record
|
|||
|
||||
// update session
|
||||
s, _ := mgr.sessions.Get(session.GetUserId(), session.GetId())
|
||||
s.lastRefresh = time.Now()
|
||||
if s.lastRefresh.IsZero() {
|
||||
s.lastRefresh = mgr.now()
|
||||
}
|
||||
s.gracePeriod = mgr.cfg.Load().sessionRefreshGracePeriod
|
||||
s.coolOffDuration = mgr.cfg.Load().sessionRefreshCoolOffDuration
|
||||
s.Session = session
|
||||
|
|
|
@ -3,6 +3,7 @@ package manager
|
|||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
|
@ -24,18 +25,23 @@ import (
|
|||
"github.com/pomerium/pomerium/pkg/protoutil"
|
||||
)
|
||||
|
||||
type mockAuthenticator struct{}
|
||||
|
||||
func (mock mockAuthenticator) Refresh(_ context.Context, _ *oauth2.Token, _ identity.State) (*oauth2.Token, error) {
|
||||
return nil, errors.New("update session")
|
||||
type mockAuthenticator struct {
|
||||
refreshResult *oauth2.Token
|
||||
refreshError error
|
||||
revokeError error
|
||||
updateUserInfoError error
|
||||
}
|
||||
|
||||
func (mock mockAuthenticator) Revoke(_ context.Context, _ *oauth2.Token) error {
|
||||
return errors.New("not implemented")
|
||||
func (mock *mockAuthenticator) Refresh(_ context.Context, _ *oauth2.Token, _ identity.State) (*oauth2.Token, error) {
|
||||
return mock.refreshResult, mock.refreshError
|
||||
}
|
||||
|
||||
func (mock mockAuthenticator) UpdateUserInfo(_ context.Context, _ *oauth2.Token, _ any) error {
|
||||
return errors.New("update user info")
|
||||
func (mock *mockAuthenticator) Revoke(_ context.Context, _ *oauth2.Token) error {
|
||||
return mock.revokeError
|
||||
}
|
||||
|
||||
func (mock *mockAuthenticator) UpdateUserInfo(_ context.Context, _ *oauth2.Token, _ any) error {
|
||||
return mock.updateUserInfoError
|
||||
}
|
||||
|
||||
func TestManager_refresh(t *testing.T) {
|
||||
|
@ -86,6 +92,9 @@ func TestManager_onUpdateRecords(t *testing.T) {
|
|||
})
|
||||
|
||||
if _, ok := mgr.sessions.Get("user1", "session1"); assert.True(t, ok) {
|
||||
tm, id := mgr.sessionScheduler.Next()
|
||||
assert.Equal(t, now.Add(10*time.Second), tm)
|
||||
assert.Equal(t, "user1\037session1", id)
|
||||
}
|
||||
if _, ok := mgr.users.Get("user1"); assert.True(t, ok) {
|
||||
tm, id := mgr.userScheduler.Next()
|
||||
|
@ -94,6 +103,158 @@ func TestManager_onUpdateRecords(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestManager_onUpdateSession(t *testing.T) {
|
||||
startTime := time.Date(2023, 10, 19, 12, 0, 0, 0, time.UTC)
|
||||
|
||||
s := &session.Session{
|
||||
Id: "session-id",
|
||||
UserId: "user-id",
|
||||
OauthToken: &session.OAuthToken{
|
||||
AccessToken: "access-token",
|
||||
ExpiresAt: timestamppb.New(startTime.Add(5 * time.Minute)),
|
||||
},
|
||||
IssuedAt: timestamppb.New(startTime),
|
||||
ExpiresAt: timestamppb.New(startTime.Add(24 * time.Hour)),
|
||||
}
|
||||
|
||||
assertNextScheduled := func(t *testing.T, mgr *Manager, expectedTime time.Time) {
|
||||
t.Helper()
|
||||
tm, key := mgr.sessionScheduler.Next()
|
||||
assert.Equal(t, expectedTime, tm)
|
||||
assert.Equal(t, "user-id\037session-id", key)
|
||||
}
|
||||
|
||||
t.Run("initial refresh event when not expiring soon", func(t *testing.T) {
|
||||
now := startTime
|
||||
mgr := New(WithNow(func() time.Time { return now }))
|
||||
|
||||
// When the Manager first becomes aware of a session it should schedule
|
||||
// a refresh event for one minute before access token expiration.
|
||||
mgr.onUpdateSession(mkRecord(s), s)
|
||||
assertNextScheduled(t, mgr, startTime.Add(4*time.Minute))
|
||||
})
|
||||
t.Run("initial refresh event when expiring soon", func(t *testing.T) {
|
||||
now := startTime
|
||||
mgr := New(WithNow(func() time.Time { return now }))
|
||||
|
||||
// When the Manager first becomes aware of a session, if that session
|
||||
// is expiring within the gracePeriod (1 minute), it should schedule a
|
||||
// refresh event for as soon as possible, subject to the
|
||||
// coolOffDuration (10 seconds).
|
||||
now = now.Add(4*time.Minute + 30*time.Second) // 30 s before expiration
|
||||
mgr.onUpdateSession(mkRecord(s), s)
|
||||
assertNextScheduled(t, mgr, now.Add(10*time.Second))
|
||||
})
|
||||
t.Run("update near scheduled refresh", func(t *testing.T) {
|
||||
now := startTime
|
||||
mgr := New(WithNow(func() time.Time { return now }))
|
||||
|
||||
mgr.onUpdateSession(mkRecord(s), s)
|
||||
assertNextScheduled(t, mgr, startTime.Add(4*time.Minute))
|
||||
|
||||
// If a session is updated close to the time when it is scheduled to be
|
||||
// refreshed, the scheduled refresh event should not be pushed back.
|
||||
now = now.Add(3*time.Minute + 55*time.Second) // 5 s before refresh
|
||||
mgr.onUpdateSession(mkRecord(s), s)
|
||||
assertNextScheduled(t, mgr, now.Add(5*time.Second))
|
||||
|
||||
// However, if an update changes the access token validity, the refresh
|
||||
// event should be rescheduled accordingly. (This should be uncommon,
|
||||
// as only the refresh loop itself should modify the access token.)
|
||||
s2 := proto.Clone(s).(*session.Session)
|
||||
s2.OauthToken.ExpiresAt = timestamppb.New(now.Add(5 * time.Minute))
|
||||
mgr.onUpdateSession(mkRecord(s2), s2)
|
||||
assertNextScheduled(t, mgr, now.Add(4*time.Minute))
|
||||
})
|
||||
t.Run("session record deleted", func(t *testing.T) {
|
||||
now := startTime
|
||||
mgr := New(WithNow(func() time.Time { return now }))
|
||||
|
||||
mgr.onUpdateSession(mkRecord(s), s)
|
||||
assertNextScheduled(t, mgr, startTime.Add(4*time.Minute))
|
||||
|
||||
// If a session is deleted, any scheduled refresh event should be canceled.
|
||||
record := mkRecord(s)
|
||||
record.DeletedAt = timestamppb.New(now)
|
||||
mgr.onUpdateSession(record, s)
|
||||
_, key := mgr.sessionScheduler.Next()
|
||||
assert.Empty(t, key)
|
||||
})
|
||||
}
|
||||
|
||||
func TestManager_refreshSession(t *testing.T) {
|
||||
startTime := time.Date(2023, 10, 19, 12, 0, 0, 0, time.UTC)
|
||||
|
||||
var auth mockAuthenticator
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
client := mock_databroker.NewMockDataBrokerServiceClient(ctrl)
|
||||
|
||||
now := startTime
|
||||
mgr := New(
|
||||
WithDataBrokerClient(client),
|
||||
WithNow(func() time.Time { return now }),
|
||||
WithAuthenticator(&auth),
|
||||
)
|
||||
|
||||
// Initialize the Manager with a new session.
|
||||
s := &session.Session{
|
||||
Id: "session-id",
|
||||
UserId: "user-id",
|
||||
OauthToken: &session.OAuthToken{
|
||||
AccessToken: "access-token",
|
||||
ExpiresAt: timestamppb.New(startTime.Add(5 * time.Minute)),
|
||||
RefreshToken: "refresh-token",
|
||||
},
|
||||
IssuedAt: timestamppb.New(startTime),
|
||||
ExpiresAt: timestamppb.New(startTime.Add(24 * time.Hour)),
|
||||
}
|
||||
mgr.sessions.ReplaceOrInsert(Session{
|
||||
Session: s,
|
||||
lastRefresh: startTime,
|
||||
gracePeriod: time.Minute,
|
||||
coolOffDuration: 10 * time.Second,
|
||||
})
|
||||
|
||||
// If OAuth2 token refresh fails with a temporary error, the manager should
|
||||
// still reschedule another refresh attempt.
|
||||
now = now.Add(4 * time.Minute)
|
||||
auth.refreshError = context.DeadlineExceeded
|
||||
mgr.refreshSession(context.Background(), "user-id", "session-id")
|
||||
|
||||
tm, key := mgr.sessionScheduler.Next()
|
||||
assert.Equal(t, now.Add(10*time.Second), tm)
|
||||
assert.Equal(t, "user-id\037session-id", key)
|
||||
|
||||
// Simulate a successful token refresh on the second attempt. The manager
|
||||
// should store the updated session in the databroker and schedule another
|
||||
// refresh event.
|
||||
now = now.Add(10 * time.Second)
|
||||
auth.refreshResult, auth.refreshError = &oauth2.Token{
|
||||
AccessToken: "new-access-token",
|
||||
RefreshToken: "new-refresh-token",
|
||||
Expiry: now.Add(5 * time.Minute),
|
||||
}, nil
|
||||
expectedSession := proto.Clone(s).(*session.Session)
|
||||
expectedSession.OauthToken = &session.OAuthToken{
|
||||
AccessToken: "new-access-token",
|
||||
ExpiresAt: timestamppb.New(now.Add(5 * time.Minute)),
|
||||
RefreshToken: "new-refresh-token",
|
||||
}
|
||||
client.EXPECT().Put(gomock.Any(),
|
||||
objectsAreEqualMatcher{&databroker.PutRequest{Records: []*databroker.Record{{
|
||||
Type: "type.googleapis.com/session.Session",
|
||||
Id: "session-id",
|
||||
Data: protoutil.NewAny(expectedSession),
|
||||
}}}}).
|
||||
Return(nil /* this result is currently unused */, nil)
|
||||
mgr.refreshSession(context.Background(), "user-id", "session-id")
|
||||
|
||||
tm, key = mgr.sessionScheduler.Next()
|
||||
assert.Equal(t, now.Add(4*time.Minute), tm)
|
||||
assert.Equal(t, "user-id\037session-id", key)
|
||||
}
|
||||
|
||||
func TestManager_reportErrors(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
|
||||
|
@ -135,7 +296,10 @@ func TestManager_reportErrors(t *testing.T) {
|
|||
mgr := New(
|
||||
WithEventManager(evtMgr),
|
||||
WithDataBrokerClient(client),
|
||||
WithAuthenticator(mockAuthenticator{}),
|
||||
WithAuthenticator(&mockAuthenticator{
|
||||
refreshError: errors.New("update session"),
|
||||
updateUserInfoError: errors.New("update user info"),
|
||||
}),
|
||||
)
|
||||
|
||||
mgr.onUpdateRecords(ctx, updateRecordsMessage{
|
||||
|
@ -172,3 +336,18 @@ type recordable interface {
|
|||
proto.Message
|
||||
GetId() string
|
||||
}
|
||||
|
||||
// objectsAreEqualMatcher implements gomock.Matcher using ObjectsAreEqual. This
|
||||
// is especially helpful when working with pointers, as it will compare the
|
||||
// underlying values rather than the pointers themselves.
|
||||
type objectsAreEqualMatcher struct {
|
||||
expected interface{}
|
||||
}
|
||||
|
||||
func (m objectsAreEqualMatcher) Matches(x interface{}) bool {
|
||||
return assert.ObjectsAreEqual(m.expected, x)
|
||||
}
|
||||
|
||||
func (m objectsAreEqualMatcher) String() string {
|
||||
return fmt.Sprintf("is equal to %v (%T)", m.expected, m.expected)
|
||||
}
|
||||
|
|
|
@ -130,16 +130,9 @@ func (p *Provider) Authenticate(ctx context.Context, code string, v identity.Sta
|
|||
|
||||
// Refresh renews a user's session.
|
||||
func (p *Provider) Refresh(ctx context.Context, t *oauth2.Token, v identity.State) (*oauth2.Token, error) {
|
||||
if t == nil {
|
||||
return nil, oidc.ErrMissingAccessToken
|
||||
}
|
||||
if t.RefreshToken == "" {
|
||||
return nil, oidc.ErrMissingRefreshToken
|
||||
}
|
||||
|
||||
newToken, err := p.oauth.TokenSource(ctx, t).Token()
|
||||
newToken, err := oidc.Refresh(ctx, p.oauth, t)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("identity/apple: refresh failed: %w", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if rawIDToken, ok := newToken.Extra("id_token").(string); ok {
|
||||
|
|
|
@ -213,16 +213,9 @@ func (p *Provider) Refresh(ctx context.Context, t *oauth2.Token, v identity.Stat
|
|||
return nil, err
|
||||
}
|
||||
|
||||
if t == nil {
|
||||
return nil, ErrMissingAccessToken
|
||||
}
|
||||
if t.RefreshToken == "" {
|
||||
return nil, ErrMissingRefreshToken
|
||||
}
|
||||
|
||||
newToken, err := oa.TokenSource(ctx, t).Token()
|
||||
newToken, err := Refresh(ctx, oa, t)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("identity/oidc: refresh failed: %w", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Many identity providers _will not_ return `id_token` on refresh
|
||||
|
|
29
internal/identity/oidc/refresh.go
Normal file
29
internal/identity/oidc/refresh.go
Normal file
|
@ -0,0 +1,29 @@
|
|||
package oidc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"golang.org/x/oauth2"
|
||||
)
|
||||
|
||||
// Refresh requests a new oauth2.Token based on an existing Token and the
|
||||
// provided Config. The existing Token must contain a refresh token.
|
||||
func Refresh(ctx context.Context, cfg *oauth2.Config, t *oauth2.Token) (*oauth2.Token, error) {
|
||||
if t == nil || t.RefreshToken == "" {
|
||||
return nil, ErrMissingRefreshToken
|
||||
}
|
||||
|
||||
// Note: the TokenSource returned by oauth2.Config has its own threshold
|
||||
// for determining when to attempt a refresh. In order to force a refresh
|
||||
// we can remove the current AccessToken.
|
||||
t = &oauth2.Token{
|
||||
TokenType: t.TokenType,
|
||||
RefreshToken: t.RefreshToken,
|
||||
}
|
||||
newToken, err := cfg.TokenSource(ctx, t).Token()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("identity/oidc: refresh failed: %w", err)
|
||||
}
|
||||
return newToken, nil
|
||||
}
|
71
internal/identity/oidc/refresh_test.go
Normal file
71
internal/identity/oidc/refresh_test.go
Normal file
|
@ -0,0 +1,71 @@
|
|||
package oidc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/oauth2"
|
||||
)
|
||||
|
||||
func TestRefresh(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx, clearTimeout := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
t.Cleanup(clearTimeout)
|
||||
|
||||
s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.Write([]byte(`{
|
||||
"access_token": "NEW_TOKEN",
|
||||
"refresh_token": "NEW_REFRESH_TOKEN",
|
||||
"expires_in": 3600
|
||||
}`))
|
||||
}))
|
||||
t.Cleanup(s.Close)
|
||||
|
||||
cfg := &oauth2.Config{Endpoint: oauth2.Endpoint{TokenURL: s.URL}}
|
||||
|
||||
token := &oauth2.Token{
|
||||
AccessToken: "OLD_TOKEN",
|
||||
RefreshToken: "OLD_REFRESH_TOKEN",
|
||||
|
||||
// Even if a token is not expiring soon, Refresh() should still perform
|
||||
// the refresh.
|
||||
Expiry: time.Now().Add(time.Hour),
|
||||
}
|
||||
require.True(t, token.Valid())
|
||||
|
||||
newToken, err := Refresh(ctx, cfg, token)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "NEW_TOKEN", newToken.AccessToken)
|
||||
assert.Equal(t, "NEW_REFRESH_TOKEN", newToken.RefreshToken)
|
||||
}
|
||||
|
||||
func TestRefresh_errors(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx, clearTimeout := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
t.Cleanup(clearTimeout)
|
||||
|
||||
s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Write([]byte("{}"))
|
||||
}))
|
||||
t.Cleanup(s.Close)
|
||||
|
||||
cfg := &oauth2.Config{Endpoint: oauth2.Endpoint{TokenURL: s.URL}}
|
||||
|
||||
_, err := Refresh(ctx, cfg, nil)
|
||||
assert.Equal(t, ErrMissingRefreshToken, err)
|
||||
|
||||
_, err = Refresh(ctx, cfg, &oauth2.Token{})
|
||||
assert.Equal(t, ErrMissingRefreshToken, err)
|
||||
|
||||
_, err = Refresh(ctx, cfg, &oauth2.Token{RefreshToken: "REFRESH_TOKEN"})
|
||||
assert.Equal(t, "identity/oidc: refresh failed: oauth2: server response missing access_token",
|
||||
err.Error())
|
||||
}
|
|
@ -5,7 +5,7 @@ PATH="$PATH:$(go env GOPATH)/bin"
|
|||
export PATH
|
||||
|
||||
_project_root="$(cd "$(dirname "${BASH_SOURCE[0]}")" >/dev/null 2>&1 && pwd)/.."
|
||||
_envoy_version=1.27.1
|
||||
_envoy_version=1.28.0
|
||||
_dir="$_project_root/pkg/envoy/files"
|
||||
|
||||
for _target in darwin-amd64 darwin-arm64 linux-amd64 linux-arm64; do
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue