This commit is contained in:
Caleb Doxsey 2024-04-19 16:07:02 -06:00
parent 2eaa4291ae
commit 0e3e5ff494
3 changed files with 122 additions and 38 deletions

View file

@ -148,7 +148,14 @@ func (mgr *Manager) onUpdateSession(ctx context.Context, s *session.Session) {
mgr.dataStore.putSession(s) mgr.dataStore.putSession(s)
rss, ok := mgr.refreshSessionSchedulers[s.GetId()] rss, ok := mgr.refreshSessionSchedulers[s.GetId()]
if !ok { if !ok {
rss = newRefreshSessionScheduler(ctx, mgr, s.GetId()) rss = newRefreshSessionScheduler(
ctx,
mgr.cfg.Load().now,
mgr.cfg.Load().sessionRefreshGracePeriod,
mgr.cfg.Load().sessionRefreshCoolOffDuration,
mgr.refreshSession,
s.GetId(),
)
mgr.refreshSessionSchedulers[s.GetId()] = rss mgr.refreshSessionSchedulers[s.GetId()] = rss
} }
rss.Update(s) rss.Update(s)
@ -162,7 +169,12 @@ func (mgr *Manager) onUpdateUser(ctx context.Context, u *user.User) {
mgr.dataStore.putUser(u) mgr.dataStore.putUser(u)
_, ok := mgr.updateUserInfoSchedulers[u.GetId()] _, ok := mgr.updateUserInfoSchedulers[u.GetId()]
if !ok { if !ok {
uuis := newUpdateUserInfoScheduler(ctx, mgr.cfg.Load().updateUserInfoInterval, mgr.updateUserInfo, u.GetId()) uuis := newUpdateUserInfoScheduler(
ctx,
mgr.cfg.Load().updateUserInfoInterval,
mgr.updateUserInfo,
u.GetId(),
)
mgr.updateUserInfoSchedulers[u.GetId()] = uuis mgr.updateUserInfoSchedulers[u.GetId()] = uuis
} }
mgr.mu.Unlock() mgr.mu.Unlock()

View file

@ -63,8 +63,12 @@ func (uuis *updateUserInfoScheduler) run(ctx context.Context) {
} }
type refreshSessionScheduler struct { type refreshSessionScheduler struct {
mgr *Manager now func() time.Time
sessionRefreshGracePeriod time.Duration
sessionRefreshCoolOffDuration time.Duration
refreshSession func(ctx context.Context, sesionID string)
sessionID string sessionID string
lastRefresh atomic.Pointer[time.Time] lastRefresh atomic.Pointer[time.Time]
next chan time.Time next chan time.Time
cancel context.CancelFunc cancel context.CancelFunc
@ -72,16 +76,22 @@ type refreshSessionScheduler struct {
func newRefreshSessionScheduler( func newRefreshSessionScheduler(
ctx context.Context, ctx context.Context,
mgr *Manager, now func() time.Time,
sessionRefreshGracePeriod time.Duration,
sessionRefreshCoolOffDuration time.Duration,
refreshSession func(ctx context.Context, sesionID string),
sessionID string, sessionID string,
) *refreshSessionScheduler { ) *refreshSessionScheduler {
rss := &refreshSessionScheduler{ rss := &refreshSessionScheduler{
mgr: mgr, now: now,
sessionRefreshGracePeriod: sessionRefreshGracePeriod,
sessionRefreshCoolOffDuration: sessionRefreshCoolOffDuration,
refreshSession: refreshSession,
sessionID: sessionID, sessionID: sessionID,
next: make(chan time.Time, 1), next: make(chan time.Time, 1),
} }
now := rss.mgr.cfg.Load().now() tm := now()
rss.lastRefresh.Store(&now) rss.lastRefresh.Store(&tm)
ctx = context.WithoutCancel(ctx) ctx = context.WithoutCancel(ctx)
ctx, rss.cancel = context.WithCancel(ctx) ctx, rss.cancel = context.WithCancel(ctx)
go rss.run(ctx) go rss.run(ctx)
@ -92,8 +102,8 @@ func (rss *refreshSessionScheduler) Update(s *session.Session) {
due := nextSessionRefresh( due := nextSessionRefresh(
s, s,
*rss.lastRefresh.Load(), *rss.lastRefresh.Load(),
rss.mgr.cfg.Load().sessionRefreshGracePeriod, rss.sessionRefreshGracePeriod,
rss.mgr.cfg.Load().sessionRefreshCoolOffDuration, rss.sessionRefreshCoolOffDuration,
) )
for { for {
select { select {
@ -114,6 +124,12 @@ func (rss *refreshSessionScheduler) Stop() {
func (rss *refreshSessionScheduler) run(ctx context.Context) { func (rss *refreshSessionScheduler) run(ctx context.Context) {
var timer *time.Timer var timer *time.Timer
// ensure we clean up any orphaned timers
defer func() {
if timer != nil {
timer.Stop()
}
}()
// wait for the first update // wait for the first update
select { select {
@ -122,7 +138,6 @@ func (rss *refreshSessionScheduler) run(ctx context.Context) {
case due := <-rss.next: case due := <-rss.next:
delay := max(time.Until(due), 0) delay := max(time.Until(due), 0)
timer = time.NewTimer(delay) timer = time.NewTimer(delay)
defer timer.Stop()
} }
// wait for updates or for the timer to trigger // wait for updates or for the timer to trigger
@ -132,15 +147,13 @@ func (rss *refreshSessionScheduler) run(ctx context.Context) {
return return
case due := <-rss.next: case due := <-rss.next:
delay := max(time.Until(due), 0) delay := max(time.Until(due), 0)
// stop the current timer and reset it // stop the existing timer and start a new one
if !timer.Stop() { timer.Stop()
<-timer.C timer = time.NewTimer(delay)
}
timer.Reset(delay)
case <-timer.C: case <-timer.C:
now := rss.mgr.cfg.Load().now() tm := rss.now()
rss.lastRefresh.Store(&now) rss.lastRefresh.Store(&tm)
rss.mgr.refreshSession(ctx, rss.sessionID) rss.refreshSession(ctx, rss.sessionID)
} }
} }
} }

View file

@ -7,29 +7,58 @@ import (
"time" "time"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"google.golang.org/protobuf/types/known/timestamppb"
"github.com/pomerium/pomerium/pkg/grpc/session"
) )
func TestRefreshSessionScheduler(t *testing.T) {
t.Parallel()
var calls safeSlice[time.Time]
ctx := context.Background()
sessionRefreshGracePeriod := time.Millisecond
sessionRefreshCoolOffDuration := time.Millisecond
rss := newRefreshSessionScheduler(
ctx,
time.Now,
sessionRefreshGracePeriod,
sessionRefreshCoolOffDuration,
func(ctx context.Context, sesionID string) {
calls.Append(time.Now())
},
"S1",
)
t.Cleanup(rss.Stop)
rss.Update(&session.Session{ExpiresAt: timestamppb.Now()})
assert.Eventually(t, func() bool {
return calls.Len() == 1
}, 100*time.Millisecond, 10*time.Millisecond, "should trigger once")
rss.Update(&session.Session{ExpiresAt: timestamppb.Now()})
assert.Eventually(t, func() bool {
return calls.Len() == 2
}, 100*time.Millisecond, 10*time.Millisecond, "should trigger again")
}
func TestUpdateUserInfoScheduler(t *testing.T) { func TestUpdateUserInfoScheduler(t *testing.T) {
t.Parallel() t.Parallel()
var mu sync.Mutex var calls safeSlice[time.Time]
var calls []time.Time
ctx := context.Background() ctx := context.Background()
userUpdateInfoInterval := 100 * time.Millisecond userUpdateInfoInterval := 100 * time.Millisecond
uuis := newUpdateUserInfoScheduler(ctx, userUpdateInfoInterval, func(ctx context.Context, userID string) { uuis := newUpdateUserInfoScheduler(ctx, userUpdateInfoInterval, func(ctx context.Context, userID string) {
mu.Lock() calls.Append(time.Now())
calls = append(calls, time.Now())
mu.Unlock()
}, "U1") }, "U1")
t.Cleanup(uuis.Stop) t.Cleanup(uuis.Stop)
// should eventually trigger // should eventually trigger
assert.Eventually(t, func() bool { assert.Eventually(t, func() bool {
mu.Lock() return calls.Len() == 1
n := len(calls)
mu.Unlock()
return n == 1
}, 3*userUpdateInfoInterval, userUpdateInfoInterval/10, "should trigger once") }, 3*userUpdateInfoInterval, userUpdateInfoInterval/10, "should trigger once")
uuis.Reset() uuis.Reset()
@ -37,18 +66,48 @@ func TestUpdateUserInfoScheduler(t *testing.T) {
uuis.Reset() uuis.Reset()
assert.Eventually(t, func() bool { assert.Eventually(t, func() bool {
mu.Lock() return calls.Len() == 2
n := len(calls)
mu.Unlock()
return n == 2
}, 3*userUpdateInfoInterval, userUpdateInfoInterval/10, "should trigger once after multiple resets") }, 3*userUpdateInfoInterval, userUpdateInfoInterval/10, "should trigger once after multiple resets")
mu.Lock()
var diff time.Duration var diff time.Duration
if len(calls) >= 2 { if calls.Len() >= 2 {
diff = calls[len(calls)-1].Sub(calls[len(calls)-2]) diff = calls.At(calls.Len() - 1).Sub(calls.At(calls.Len() - 2))
} }
mu.Unlock()
assert.GreaterOrEqual(t, diff, userUpdateInfoInterval, "delay should exceed interval") assert.GreaterOrEqual(t, diff, userUpdateInfoInterval, "delay should exceed interval")
uuis.Reset()
uuis.Stop()
time.Sleep(3 * userUpdateInfoInterval)
assert.Equal(t, 2, calls.Len(), "should not trigger again after stopping")
}
type safeSlice[T any] struct {
mu sync.Mutex
elements []T
}
func (s *safeSlice[T]) Append(elements ...T) {
s.mu.Lock()
s.elements = append(s.elements, elements...)
s.mu.Unlock()
}
func (s *safeSlice[T]) At(idx int) T {
var el T
s.mu.Lock()
if idx >= 0 && idx < len(s.elements) {
el = s.elements[idx]
}
s.mu.Unlock()
return el
}
func (s *safeSlice[T]) Len() int {
s.mu.Lock()
n := len(s.elements)
s.mu.Unlock()
return n
} }