This commit is contained in:
Caleb Doxsey 2024-04-19 16:07:02 -06:00
parent 2eaa4291ae
commit 0e3e5ff494
3 changed files with 122 additions and 38 deletions

View file

@ -148,7 +148,14 @@ func (mgr *Manager) onUpdateSession(ctx context.Context, s *session.Session) {
mgr.dataStore.putSession(s)
rss, ok := mgr.refreshSessionSchedulers[s.GetId()]
if !ok {
rss = newRefreshSessionScheduler(ctx, mgr, s.GetId())
rss = newRefreshSessionScheduler(
ctx,
mgr.cfg.Load().now,
mgr.cfg.Load().sessionRefreshGracePeriod,
mgr.cfg.Load().sessionRefreshCoolOffDuration,
mgr.refreshSession,
s.GetId(),
)
mgr.refreshSessionSchedulers[s.GetId()] = rss
}
rss.Update(s)
@ -162,7 +169,12 @@ func (mgr *Manager) onUpdateUser(ctx context.Context, u *user.User) {
mgr.dataStore.putUser(u)
_, ok := mgr.updateUserInfoSchedulers[u.GetId()]
if !ok {
uuis := newUpdateUserInfoScheduler(ctx, mgr.cfg.Load().updateUserInfoInterval, mgr.updateUserInfo, u.GetId())
uuis := newUpdateUserInfoScheduler(
ctx,
mgr.cfg.Load().updateUserInfoInterval,
mgr.updateUserInfo,
u.GetId(),
)
mgr.updateUserInfoSchedulers[u.GetId()] = uuis
}
mgr.mu.Unlock()

View file

@ -63,8 +63,12 @@ func (uuis *updateUserInfoScheduler) run(ctx context.Context) {
}
type refreshSessionScheduler struct {
mgr *Manager
sessionID string
now func() time.Time
sessionRefreshGracePeriod time.Duration
sessionRefreshCoolOffDuration time.Duration
refreshSession func(ctx context.Context, sesionID string)
sessionID string
lastRefresh atomic.Pointer[time.Time]
next chan time.Time
cancel context.CancelFunc
@ -72,16 +76,22 @@ type refreshSessionScheduler struct {
func newRefreshSessionScheduler(
ctx context.Context,
mgr *Manager,
now func() time.Time,
sessionRefreshGracePeriod time.Duration,
sessionRefreshCoolOffDuration time.Duration,
refreshSession func(ctx context.Context, sesionID string),
sessionID string,
) *refreshSessionScheduler {
rss := &refreshSessionScheduler{
mgr: mgr,
sessionID: sessionID,
next: make(chan time.Time, 1),
now: now,
sessionRefreshGracePeriod: sessionRefreshGracePeriod,
sessionRefreshCoolOffDuration: sessionRefreshCoolOffDuration,
refreshSession: refreshSession,
sessionID: sessionID,
next: make(chan time.Time, 1),
}
now := rss.mgr.cfg.Load().now()
rss.lastRefresh.Store(&now)
tm := now()
rss.lastRefresh.Store(&tm)
ctx = context.WithoutCancel(ctx)
ctx, rss.cancel = context.WithCancel(ctx)
go rss.run(ctx)
@ -92,8 +102,8 @@ func (rss *refreshSessionScheduler) Update(s *session.Session) {
due := nextSessionRefresh(
s,
*rss.lastRefresh.Load(),
rss.mgr.cfg.Load().sessionRefreshGracePeriod,
rss.mgr.cfg.Load().sessionRefreshCoolOffDuration,
rss.sessionRefreshGracePeriod,
rss.sessionRefreshCoolOffDuration,
)
for {
select {
@ -114,6 +124,12 @@ func (rss *refreshSessionScheduler) Stop() {
func (rss *refreshSessionScheduler) run(ctx context.Context) {
var timer *time.Timer
// ensure we clean up any orphaned timers
defer func() {
if timer != nil {
timer.Stop()
}
}()
// wait for the first update
select {
@ -122,7 +138,6 @@ func (rss *refreshSessionScheduler) run(ctx context.Context) {
case due := <-rss.next:
delay := max(time.Until(due), 0)
timer = time.NewTimer(delay)
defer timer.Stop()
}
// wait for updates or for the timer to trigger
@ -132,15 +147,13 @@ func (rss *refreshSessionScheduler) run(ctx context.Context) {
return
case due := <-rss.next:
delay := max(time.Until(due), 0)
// stop the current timer and reset it
if !timer.Stop() {
<-timer.C
}
timer.Reset(delay)
// stop the existing timer and start a new one
timer.Stop()
timer = time.NewTimer(delay)
case <-timer.C:
now := rss.mgr.cfg.Load().now()
rss.lastRefresh.Store(&now)
rss.mgr.refreshSession(ctx, rss.sessionID)
tm := rss.now()
rss.lastRefresh.Store(&tm)
rss.refreshSession(ctx, rss.sessionID)
}
}
}

View file

@ -7,29 +7,58 @@ import (
"time"
"github.com/stretchr/testify/assert"
"google.golang.org/protobuf/types/known/timestamppb"
"github.com/pomerium/pomerium/pkg/grpc/session"
)
func TestRefreshSessionScheduler(t *testing.T) {
t.Parallel()
var calls safeSlice[time.Time]
ctx := context.Background()
sessionRefreshGracePeriod := time.Millisecond
sessionRefreshCoolOffDuration := time.Millisecond
rss := newRefreshSessionScheduler(
ctx,
time.Now,
sessionRefreshGracePeriod,
sessionRefreshCoolOffDuration,
func(ctx context.Context, sesionID string) {
calls.Append(time.Now())
},
"S1",
)
t.Cleanup(rss.Stop)
rss.Update(&session.Session{ExpiresAt: timestamppb.Now()})
assert.Eventually(t, func() bool {
return calls.Len() == 1
}, 100*time.Millisecond, 10*time.Millisecond, "should trigger once")
rss.Update(&session.Session{ExpiresAt: timestamppb.Now()})
assert.Eventually(t, func() bool {
return calls.Len() == 2
}, 100*time.Millisecond, 10*time.Millisecond, "should trigger again")
}
func TestUpdateUserInfoScheduler(t *testing.T) {
t.Parallel()
var mu sync.Mutex
var calls []time.Time
var calls safeSlice[time.Time]
ctx := context.Background()
userUpdateInfoInterval := 100 * time.Millisecond
uuis := newUpdateUserInfoScheduler(ctx, userUpdateInfoInterval, func(ctx context.Context, userID string) {
mu.Lock()
calls = append(calls, time.Now())
mu.Unlock()
calls.Append(time.Now())
}, "U1")
t.Cleanup(uuis.Stop)
// should eventually trigger
assert.Eventually(t, func() bool {
mu.Lock()
n := len(calls)
mu.Unlock()
return n == 1
return calls.Len() == 1
}, 3*userUpdateInfoInterval, userUpdateInfoInterval/10, "should trigger once")
uuis.Reset()
@ -37,18 +66,48 @@ func TestUpdateUserInfoScheduler(t *testing.T) {
uuis.Reset()
assert.Eventually(t, func() bool {
mu.Lock()
n := len(calls)
mu.Unlock()
return n == 2
return calls.Len() == 2
}, 3*userUpdateInfoInterval, userUpdateInfoInterval/10, "should trigger once after multiple resets")
mu.Lock()
var diff time.Duration
if len(calls) >= 2 {
diff = calls[len(calls)-1].Sub(calls[len(calls)-2])
if calls.Len() >= 2 {
diff = calls.At(calls.Len() - 1).Sub(calls.At(calls.Len() - 2))
}
mu.Unlock()
assert.GreaterOrEqual(t, diff, userUpdateInfoInterval, "delay should exceed interval")
uuis.Reset()
uuis.Stop()
time.Sleep(3 * userUpdateInfoInterval)
assert.Equal(t, 2, calls.Len(), "should not trigger again after stopping")
}
type safeSlice[T any] struct {
mu sync.Mutex
elements []T
}
func (s *safeSlice[T]) Append(elements ...T) {
s.mu.Lock()
s.elements = append(s.elements, elements...)
s.mu.Unlock()
}
func (s *safeSlice[T]) At(idx int) T {
var el T
s.mu.Lock()
if idx >= 0 && idx < len(s.elements) {
el = s.elements[idx]
}
s.mu.Unlock()
return el
}
func (s *safeSlice[T]) Len() int {
s.mu.Lock()
n := len(s.elements)
s.mu.Unlock()
return n
}