Merge remote-tracking branch 'origin/main' into feature/zero

This commit is contained in:
Denis Mishin 2023-10-24 21:42:40 -04:00
commit 58d8f406a9
8 changed files with 350 additions and 57 deletions

View file

@ -54,6 +54,9 @@ func (u *User) UnmarshalJSON(data []byte) error {
// A Session is a session managed by the Manager.
type Session struct {
*session.Session
// lastRefresh is the time of the last refresh attempt (which may or may
// not have succeeded), or else the time the Manager first became aware of
// the session (if it has not yet attempted to refresh this session).
lastRefresh time.Time
// gracePeriod is the amount of time before expiration to attempt a refresh.
gracePeriod time.Duration

View file

@ -107,6 +107,10 @@ func (mgr *Manager) GetDataBrokerServiceClient() databroker.DataBrokerServiceCli
return mgr.cfg.Load().dataBrokerClient
}
func (mgr *Manager) now() time.Time {
return mgr.cfg.Load().now()
}
func (mgr *Manager) refreshLoop(ctx context.Context, update <-chan updateRecordsMessage, clear <-chan struct{}) error {
// wait for initial sync
select {
@ -145,7 +149,7 @@ func (mgr *Manager) refreshLoop(ctx context.Context, update <-chan updateRecords
case <-timer.C:
}
now := time.Now()
now := mgr.now()
nextTime = now.Add(maxWait)
// refresh sessions
@ -182,22 +186,21 @@ func (mgr *Manager) refreshLoop(ctx context.Context, update <-chan updateRecords
}
}
// refreshSession handles two distinct session lifecycle events:
//
// 1. If the session itself has expired, delete the session.
// 2. If the session's underlying OAuth2 access token is nearing expiration
// (but the session itself is still valid), refresh the access token.
//
// After a successful access token refresh, this method will also trigger a
// user info refresh. If an access token refresh or a user info refresh fails
// with a permanent error, the session will be deleted.
func (mgr *Manager) refreshSession(ctx context.Context, userID, sessionID string) {
log.Info(ctx).
Str("user_id", userID).
Str("session_id", sessionID).
Msg("refreshing session")
authenticator := mgr.cfg.Load().authenticator
if authenticator == nil {
log.Info(ctx).
Str("user_id", userID).
Str("session_id", sessionID).
Msg("no authenticator defined, deleting session")
mgr.deleteSession(ctx, userID, sessionID)
return
}
s, ok := mgr.sessions.Get(userID, sessionID)
if !ok {
log.Warn(ctx).
@ -207,14 +210,37 @@ func (mgr *Manager) refreshSession(ctx context.Context, userID, sessionID string
return
}
s.lastRefresh = mgr.now()
if mgr.refreshSessionInternal(ctx, userID, sessionID, &s) {
mgr.sessions.ReplaceOrInsert(s)
mgr.sessionScheduler.Add(s.NextRefresh(), toSessionSchedulerKey(userID, sessionID))
}
}
// refreshSessionInternal performs the core refresh logic and returns true if
// the session should be scheduled for refresh again, or false if not.
func (mgr *Manager) refreshSessionInternal(
ctx context.Context, userID, sessionID string, s *Session,
) bool {
authenticator := mgr.cfg.Load().authenticator
if authenticator == nil {
log.Info(ctx).
Str("user_id", userID).
Str("session_id", sessionID).
Msg("no authenticator defined, deleting session")
mgr.deleteSession(ctx, userID, sessionID)
return false
}
expiry := s.GetExpiresAt().AsTime()
if !expiry.After(time.Now()) {
if !expiry.After(mgr.now()) {
log.Info(ctx).
Str("user_id", userID).
Str("session_id", sessionID).
Msg("deleting expired session")
mgr.deleteSession(ctx, userID, sessionID)
return
return false
}
if s.Session == nil || s.Session.OauthToken == nil {
@ -222,10 +248,10 @@ func (mgr *Manager) refreshSession(ctx context.Context, userID, sessionID string
Str("user_id", userID).
Str("session_id", sessionID).
Msg("no session oauth2 token found for refresh")
return
return false
}
newToken, err := authenticator.Refresh(ctx, FromOAuthToken(s.OauthToken), &s)
newToken, err := authenticator.Refresh(ctx, FromOAuthToken(s.OauthToken), s)
metrics.RecordIdentityManagerSessionRefresh(ctx, err)
mgr.recordLastError(metrics_ids.IdentityManagerLastSessionRefreshError, err)
if isTemporaryError(err) {
@ -233,18 +259,18 @@ func (mgr *Manager) refreshSession(ctx context.Context, userID, sessionID string
Str("user_id", s.GetUserId()).
Str("session_id", s.GetId()).
Msg("failed to refresh oauth2 token")
return
return true
} else if err != nil {
log.Error(ctx).Err(err).
Str("user_id", s.GetUserId()).
Str("session_id", s.GetId()).
Msg("failed to refresh oauth2 token, deleting session")
mgr.deleteSession(ctx, userID, sessionID)
return
return false
}
s.OauthToken = ToOAuthToken(newToken)
err = authenticator.UpdateUserInfo(ctx, FromOAuthToken(s.OauthToken), &s)
err = authenticator.UpdateUserInfo(ctx, FromOAuthToken(s.OauthToken), s)
metrics.RecordIdentityManagerUserRefresh(ctx, err)
mgr.recordLastError(metrics_ids.IdentityManagerLastUserRefreshError, err)
if isTemporaryError(err) {
@ -252,26 +278,23 @@ func (mgr *Manager) refreshSession(ctx context.Context, userID, sessionID string
Str("user_id", s.GetUserId()).
Str("session_id", s.GetId()).
Msg("failed to update user info")
return
return true
} else if err != nil {
log.Error(ctx).Err(err).
Str("user_id", s.GetUserId()).
Str("session_id", s.GetId()).
Msg("failed to update user info, deleting session")
mgr.deleteSession(ctx, userID, sessionID)
return
return false
}
res, err := session.Put(ctx, mgr.cfg.Load().dataBrokerClient, s.Session)
if err != nil {
if _, err := session.Put(ctx, mgr.cfg.Load().dataBrokerClient, s.Session); err != nil {
log.Error(ctx).Err(err).
Str("user_id", s.GetUserId()).
Str("session_id", s.GetId()).
Msg("failed to update session")
return
}
mgr.onUpdateSession(ctx, res.GetRecord(), s.Session)
return true
}
func (mgr *Manager) refreshUser(ctx context.Context, userID string) {
@ -291,7 +314,7 @@ func (mgr *Manager) refreshUser(ctx context.Context, userID string) {
Msg("no user found for refresh")
return
}
u.lastRefresh = time.Now()
u.lastRefresh = mgr.now()
mgr.userScheduler.Add(u.NextRefresh(), u.GetId())
for _, s := range mgr.sessions.GetSessionsForUser(userID) {
@ -343,7 +366,7 @@ func (mgr *Manager) onUpdateRecords(ctx context.Context, msg updateRecordsMessag
log.Warn(ctx).Msgf("error unmarshaling session: %s", err)
continue
}
mgr.onUpdateSession(ctx, record, &pbSession)
mgr.onUpdateSession(record, &pbSession)
case grpcutil.GetTypeURL(new(user.User)):
var pbUser user.User
err := record.GetData().UnmarshalTo(&pbUser)
@ -356,7 +379,7 @@ func (mgr *Manager) onUpdateRecords(ctx context.Context, msg updateRecordsMessag
}
}
func (mgr *Manager) onUpdateSession(_ context.Context, record *databroker.Record, session *session.Session) {
func (mgr *Manager) onUpdateSession(record *databroker.Record, session *session.Session) {
mgr.sessionScheduler.Remove(toSessionSchedulerKey(session.GetUserId(), session.GetId()))
if record.GetDeletedAt() != nil {
@ -366,7 +389,9 @@ func (mgr *Manager) onUpdateSession(_ context.Context, record *databroker.Record
// update session
s, _ := mgr.sessions.Get(session.GetUserId(), session.GetId())
s.lastRefresh = time.Now()
if s.lastRefresh.IsZero() {
s.lastRefresh = mgr.now()
}
s.gracePeriod = mgr.cfg.Load().sessionRefreshGracePeriod
s.coolOffDuration = mgr.cfg.Load().sessionRefreshCoolOffDuration
s.Session = session

View file

@ -3,6 +3,7 @@ package manager
import (
"context"
"errors"
"fmt"
"testing"
"time"
@ -24,18 +25,23 @@ import (
"github.com/pomerium/pomerium/pkg/protoutil"
)
type mockAuthenticator struct{}
func (mock mockAuthenticator) Refresh(_ context.Context, _ *oauth2.Token, _ identity.State) (*oauth2.Token, error) {
return nil, errors.New("update session")
type mockAuthenticator struct {
refreshResult *oauth2.Token
refreshError error
revokeError error
updateUserInfoError error
}
func (mock mockAuthenticator) Revoke(_ context.Context, _ *oauth2.Token) error {
return errors.New("not implemented")
func (mock *mockAuthenticator) Refresh(_ context.Context, _ *oauth2.Token, _ identity.State) (*oauth2.Token, error) {
return mock.refreshResult, mock.refreshError
}
func (mock mockAuthenticator) UpdateUserInfo(_ context.Context, _ *oauth2.Token, _ any) error {
return errors.New("update user info")
func (mock *mockAuthenticator) Revoke(_ context.Context, _ *oauth2.Token) error {
return mock.revokeError
}
func (mock *mockAuthenticator) UpdateUserInfo(_ context.Context, _ *oauth2.Token, _ any) error {
return mock.updateUserInfoError
}
func TestManager_refresh(t *testing.T) {
@ -86,6 +92,9 @@ func TestManager_onUpdateRecords(t *testing.T) {
})
if _, ok := mgr.sessions.Get("user1", "session1"); assert.True(t, ok) {
tm, id := mgr.sessionScheduler.Next()
assert.Equal(t, now.Add(10*time.Second), tm)
assert.Equal(t, "user1\037session1", id)
}
if _, ok := mgr.users.Get("user1"); assert.True(t, ok) {
tm, id := mgr.userScheduler.Next()
@ -94,6 +103,158 @@ func TestManager_onUpdateRecords(t *testing.T) {
}
}
func TestManager_onUpdateSession(t *testing.T) {
startTime := time.Date(2023, 10, 19, 12, 0, 0, 0, time.UTC)
s := &session.Session{
Id: "session-id",
UserId: "user-id",
OauthToken: &session.OAuthToken{
AccessToken: "access-token",
ExpiresAt: timestamppb.New(startTime.Add(5 * time.Minute)),
},
IssuedAt: timestamppb.New(startTime),
ExpiresAt: timestamppb.New(startTime.Add(24 * time.Hour)),
}
assertNextScheduled := func(t *testing.T, mgr *Manager, expectedTime time.Time) {
t.Helper()
tm, key := mgr.sessionScheduler.Next()
assert.Equal(t, expectedTime, tm)
assert.Equal(t, "user-id\037session-id", key)
}
t.Run("initial refresh event when not expiring soon", func(t *testing.T) {
now := startTime
mgr := New(WithNow(func() time.Time { return now }))
// When the Manager first becomes aware of a session it should schedule
// a refresh event for one minute before access token expiration.
mgr.onUpdateSession(mkRecord(s), s)
assertNextScheduled(t, mgr, startTime.Add(4*time.Minute))
})
t.Run("initial refresh event when expiring soon", func(t *testing.T) {
now := startTime
mgr := New(WithNow(func() time.Time { return now }))
// When the Manager first becomes aware of a session, if that session
// is expiring within the gracePeriod (1 minute), it should schedule a
// refresh event for as soon as possible, subject to the
// coolOffDuration (10 seconds).
now = now.Add(4*time.Minute + 30*time.Second) // 30 s before expiration
mgr.onUpdateSession(mkRecord(s), s)
assertNextScheduled(t, mgr, now.Add(10*time.Second))
})
t.Run("update near scheduled refresh", func(t *testing.T) {
now := startTime
mgr := New(WithNow(func() time.Time { return now }))
mgr.onUpdateSession(mkRecord(s), s)
assertNextScheduled(t, mgr, startTime.Add(4*time.Minute))
// If a session is updated close to the time when it is scheduled to be
// refreshed, the scheduled refresh event should not be pushed back.
now = now.Add(3*time.Minute + 55*time.Second) // 5 s before refresh
mgr.onUpdateSession(mkRecord(s), s)
assertNextScheduled(t, mgr, now.Add(5*time.Second))
// However, if an update changes the access token validity, the refresh
// event should be rescheduled accordingly. (This should be uncommon,
// as only the refresh loop itself should modify the access token.)
s2 := proto.Clone(s).(*session.Session)
s2.OauthToken.ExpiresAt = timestamppb.New(now.Add(5 * time.Minute))
mgr.onUpdateSession(mkRecord(s2), s2)
assertNextScheduled(t, mgr, now.Add(4*time.Minute))
})
t.Run("session record deleted", func(t *testing.T) {
now := startTime
mgr := New(WithNow(func() time.Time { return now }))
mgr.onUpdateSession(mkRecord(s), s)
assertNextScheduled(t, mgr, startTime.Add(4*time.Minute))
// If a session is deleted, any scheduled refresh event should be canceled.
record := mkRecord(s)
record.DeletedAt = timestamppb.New(now)
mgr.onUpdateSession(record, s)
_, key := mgr.sessionScheduler.Next()
assert.Empty(t, key)
})
}
func TestManager_refreshSession(t *testing.T) {
startTime := time.Date(2023, 10, 19, 12, 0, 0, 0, time.UTC)
var auth mockAuthenticator
ctrl := gomock.NewController(t)
client := mock_databroker.NewMockDataBrokerServiceClient(ctrl)
now := startTime
mgr := New(
WithDataBrokerClient(client),
WithNow(func() time.Time { return now }),
WithAuthenticator(&auth),
)
// Initialize the Manager with a new session.
s := &session.Session{
Id: "session-id",
UserId: "user-id",
OauthToken: &session.OAuthToken{
AccessToken: "access-token",
ExpiresAt: timestamppb.New(startTime.Add(5 * time.Minute)),
RefreshToken: "refresh-token",
},
IssuedAt: timestamppb.New(startTime),
ExpiresAt: timestamppb.New(startTime.Add(24 * time.Hour)),
}
mgr.sessions.ReplaceOrInsert(Session{
Session: s,
lastRefresh: startTime,
gracePeriod: time.Minute,
coolOffDuration: 10 * time.Second,
})
// If OAuth2 token refresh fails with a temporary error, the manager should
// still reschedule another refresh attempt.
now = now.Add(4 * time.Minute)
auth.refreshError = context.DeadlineExceeded
mgr.refreshSession(context.Background(), "user-id", "session-id")
tm, key := mgr.sessionScheduler.Next()
assert.Equal(t, now.Add(10*time.Second), tm)
assert.Equal(t, "user-id\037session-id", key)
// Simulate a successful token refresh on the second attempt. The manager
// should store the updated session in the databroker and schedule another
// refresh event.
now = now.Add(10 * time.Second)
auth.refreshResult, auth.refreshError = &oauth2.Token{
AccessToken: "new-access-token",
RefreshToken: "new-refresh-token",
Expiry: now.Add(5 * time.Minute),
}, nil
expectedSession := proto.Clone(s).(*session.Session)
expectedSession.OauthToken = &session.OAuthToken{
AccessToken: "new-access-token",
ExpiresAt: timestamppb.New(now.Add(5 * time.Minute)),
RefreshToken: "new-refresh-token",
}
client.EXPECT().Put(gomock.Any(),
objectsAreEqualMatcher{&databroker.PutRequest{Records: []*databroker.Record{{
Type: "type.googleapis.com/session.Session",
Id: "session-id",
Data: protoutil.NewAny(expectedSession),
}}}}).
Return(nil /* this result is currently unused */, nil)
mgr.refreshSession(context.Background(), "user-id", "session-id")
tm, key = mgr.sessionScheduler.Next()
assert.Equal(t, now.Add(4*time.Minute), tm)
assert.Equal(t, "user-id\037session-id", key)
}
func TestManager_reportErrors(t *testing.T) {
ctrl := gomock.NewController(t)
@ -135,7 +296,10 @@ func TestManager_reportErrors(t *testing.T) {
mgr := New(
WithEventManager(evtMgr),
WithDataBrokerClient(client),
WithAuthenticator(mockAuthenticator{}),
WithAuthenticator(&mockAuthenticator{
refreshError: errors.New("update session"),
updateUserInfoError: errors.New("update user info"),
}),
)
mgr.onUpdateRecords(ctx, updateRecordsMessage{
@ -172,3 +336,18 @@ type recordable interface {
proto.Message
GetId() string
}
// objectsAreEqualMatcher implements gomock.Matcher using ObjectsAreEqual. This
// is especially helpful when working with pointers, as it will compare the
// underlying values rather than the pointers themselves.
type objectsAreEqualMatcher struct {
expected interface{}
}
func (m objectsAreEqualMatcher) Matches(x interface{}) bool {
return assert.ObjectsAreEqual(m.expected, x)
}
func (m objectsAreEqualMatcher) String() string {
return fmt.Sprintf("is equal to %v (%T)", m.expected, m.expected)
}

View file

@ -130,16 +130,9 @@ func (p *Provider) Authenticate(ctx context.Context, code string, v identity.Sta
// Refresh renews a user's session.
func (p *Provider) Refresh(ctx context.Context, t *oauth2.Token, v identity.State) (*oauth2.Token, error) {
if t == nil {
return nil, oidc.ErrMissingAccessToken
}
if t.RefreshToken == "" {
return nil, oidc.ErrMissingRefreshToken
}
newToken, err := p.oauth.TokenSource(ctx, t).Token()
newToken, err := oidc.Refresh(ctx, p.oauth, t)
if err != nil {
return nil, fmt.Errorf("identity/apple: refresh failed: %w", err)
return nil, err
}
if rawIDToken, ok := newToken.Extra("id_token").(string); ok {

View file

@ -213,16 +213,9 @@ func (p *Provider) Refresh(ctx context.Context, t *oauth2.Token, v identity.Stat
return nil, err
}
if t == nil {
return nil, ErrMissingAccessToken
}
if t.RefreshToken == "" {
return nil, ErrMissingRefreshToken
}
newToken, err := oa.TokenSource(ctx, t).Token()
newToken, err := Refresh(ctx, oa, t)
if err != nil {
return nil, fmt.Errorf("identity/oidc: refresh failed: %w", err)
return nil, err
}
// Many identity providers _will not_ return `id_token` on refresh

View file

@ -0,0 +1,29 @@
package oidc
import (
"context"
"fmt"
"golang.org/x/oauth2"
)
// Refresh requests a new oauth2.Token based on an existing Token and the
// provided Config. The existing Token must contain a refresh token.
func Refresh(ctx context.Context, cfg *oauth2.Config, t *oauth2.Token) (*oauth2.Token, error) {
if t == nil || t.RefreshToken == "" {
return nil, ErrMissingRefreshToken
}
// Note: the TokenSource returned by oauth2.Config has its own threshold
// for determining when to attempt a refresh. In order to force a refresh
// we can remove the current AccessToken.
t = &oauth2.Token{
TokenType: t.TokenType,
RefreshToken: t.RefreshToken,
}
newToken, err := cfg.TokenSource(ctx, t).Token()
if err != nil {
return nil, fmt.Errorf("identity/oidc: refresh failed: %w", err)
}
return newToken, nil
}

View file

@ -0,0 +1,71 @@
package oidc
import (
"context"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/oauth2"
)
func TestRefresh(t *testing.T) {
t.Parallel()
ctx, clearTimeout := context.WithTimeout(context.Background(), 10*time.Second)
t.Cleanup(clearTimeout)
s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
w.Write([]byte(`{
"access_token": "NEW_TOKEN",
"refresh_token": "NEW_REFRESH_TOKEN",
"expires_in": 3600
}`))
}))
t.Cleanup(s.Close)
cfg := &oauth2.Config{Endpoint: oauth2.Endpoint{TokenURL: s.URL}}
token := &oauth2.Token{
AccessToken: "OLD_TOKEN",
RefreshToken: "OLD_REFRESH_TOKEN",
// Even if a token is not expiring soon, Refresh() should still perform
// the refresh.
Expiry: time.Now().Add(time.Hour),
}
require.True(t, token.Valid())
newToken, err := Refresh(ctx, cfg, token)
require.NoError(t, err)
assert.Equal(t, "NEW_TOKEN", newToken.AccessToken)
assert.Equal(t, "NEW_REFRESH_TOKEN", newToken.RefreshToken)
}
func TestRefresh_errors(t *testing.T) {
t.Parallel()
ctx, clearTimeout := context.WithTimeout(context.Background(), 10*time.Second)
t.Cleanup(clearTimeout)
s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte("{}"))
}))
t.Cleanup(s.Close)
cfg := &oauth2.Config{Endpoint: oauth2.Endpoint{TokenURL: s.URL}}
_, err := Refresh(ctx, cfg, nil)
assert.Equal(t, ErrMissingRefreshToken, err)
_, err = Refresh(ctx, cfg, &oauth2.Token{})
assert.Equal(t, ErrMissingRefreshToken, err)
_, err = Refresh(ctx, cfg, &oauth2.Token{RefreshToken: "REFRESH_TOKEN"})
assert.Equal(t, "identity/oidc: refresh failed: oauth2: server response missing access_token",
err.Error())
}

View file

@ -5,7 +5,7 @@ PATH="$PATH:$(go env GOPATH)/bin"
export PATH
_project_root="$(cd "$(dirname "${BASH_SOURCE[0]}")" >/dev/null 2>&1 && pwd)/.."
_envoy_version=1.27.1
_envoy_version=1.28.0
_dir="$_project_root/pkg/envoy/files"
for _target in darwin-amd64 darwin-arm64 linux-amd64 linux-arm64; do