mirror of https://github.com/pomerium/pomerium.git (synced 2025-08-03 00:40:25 +02:00)
wip
This commit is contained in: parent 2eaa4291ae, commit 0e3e5ff494
3 changed files with 122 additions and 38 deletions
@@ -148,7 +148,14 @@ func (mgr *Manager) onUpdateSession(ctx context.Context, s *session.Session) {
 	mgr.dataStore.putSession(s)
 	rss, ok := mgr.refreshSessionSchedulers[s.GetId()]
 	if !ok {
-		rss = newRefreshSessionScheduler(ctx, mgr, s.GetId())
+		rss = newRefreshSessionScheduler(
+			ctx,
+			mgr.cfg.Load().now,
+			mgr.cfg.Load().sessionRefreshGracePeriod,
+			mgr.cfg.Load().sessionRefreshCoolOffDuration,
+			mgr.refreshSession,
+			s.GetId(),
+		)
 		mgr.refreshSessionSchedulers[s.GetId()] = rss
 	}
 	rss.Update(s)
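Note on this hunk: the constructor arguments are snapshotted from mgr.cfg.Load() once, at creation time, so a scheduler no longer reads live config through the Manager on every tick, and it no longer needs a Manager at all. That decoupling is what makes the new unit test further down possible. A minimal sketch of constructing one in isolation, assuming package-internal test code (fakeNow and refreshed are illustrative names, not part of this change):

	// Hypothetical: drive the scheduler with a fixed clock, no Manager needed.
	fakeNow := func() time.Time { return time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC) }
	refreshed := make(chan string, 1)
	rss := newRefreshSessionScheduler(
		ctx,
		fakeNow,        // was mgr.cfg.Load().now
		10*time.Second, // sessionRefreshGracePeriod
		time.Second,    // sessionRefreshCoolOffDuration
		func(ctx context.Context, sessionID string) { refreshed <- sessionID },
		"session-1",
	)
	defer rss.Stop()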
@@ -162,7 +169,12 @@ func (mgr *Manager) onUpdateUser(ctx context.Context, u *user.User) {
 	mgr.dataStore.putUser(u)
 	_, ok := mgr.updateUserInfoSchedulers[u.GetId()]
 	if !ok {
-		uuis := newUpdateUserInfoScheduler(ctx, mgr.cfg.Load().updateUserInfoInterval, mgr.updateUserInfo, u.GetId())
+		uuis := newUpdateUserInfoScheduler(
+			ctx,
+			mgr.cfg.Load().updateUserInfoInterval,
+			mgr.updateUserInfo,
+			u.GetId(),
+		)
 		mgr.updateUserInfoSchedulers[u.GetId()] = uuis
 	}
 	mgr.mu.Unlock()
@@ -63,8 +63,12 @@ func (uuis *updateUserInfoScheduler) run(ctx context.Context) {
 }
 
 type refreshSessionScheduler struct {
-	mgr       *Manager
-	sessionID string
+	now                           func() time.Time
+	sessionRefreshGracePeriod     time.Duration
+	sessionRefreshCoolOffDuration time.Duration
+	refreshSession                func(ctx context.Context, sessionID string)
+	sessionID                     string
 
 	lastRefresh atomic.Pointer[time.Time]
 	next        chan time.Time
 	cancel      context.CancelFunc
@@ -72,16 +76,22 @@ type refreshSessionScheduler struct {
 
 func newRefreshSessionScheduler(
 	ctx context.Context,
-	mgr *Manager,
+	now func() time.Time,
+	sessionRefreshGracePeriod time.Duration,
+	sessionRefreshCoolOffDuration time.Duration,
+	refreshSession func(ctx context.Context, sessionID string),
 	sessionID string,
 ) *refreshSessionScheduler {
 	rss := &refreshSessionScheduler{
-		mgr:       mgr,
-		sessionID: sessionID,
-		next:      make(chan time.Time, 1),
+		now:                           now,
+		sessionRefreshGracePeriod:     sessionRefreshGracePeriod,
+		sessionRefreshCoolOffDuration: sessionRefreshCoolOffDuration,
+		refreshSession:                refreshSession,
+		sessionID:                     sessionID,
+		next:                          make(chan time.Time, 1),
 	}
-	now := rss.mgr.cfg.Load().now()
-	rss.lastRefresh.Store(&now)
+	tm := now()
+	rss.lastRefresh.Store(&tm)
 	ctx = context.WithoutCancel(ctx)
 	ctx, rss.cancel = context.WithCancel(ctx)
 	go rss.run(ctx)
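A note on the two context calls at the end: context.WithoutCancel (Go 1.21+) keeps the parent's values but severs its cancellation, so the scheduler goroutine is not killed when the caller's request-scoped context ends; the following WithCancel then attaches a cancellation owned by the scheduler itself, which is what Stop invokes. The same pattern in isolation, as a sketch (startDetachedWorker is a hypothetical helper, not part of this change):

	// Start a background worker that outlives its caller's ctx but can
	// still be shut down via the returned stop function.
	func startDetachedWorker(ctx context.Context, work func(context.Context)) (stop func()) {
		ctx = context.WithoutCancel(ctx) // keep values, drop caller cancellation
		ctx, cancel := context.WithCancel(ctx)
		go work(ctx)
		return cancel
	}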
@@ -92,8 +102,8 @@ func (rss *refreshSessionScheduler) Update(s *session.Session) {
 	due := nextSessionRefresh(
 		s,
 		*rss.lastRefresh.Load(),
-		rss.mgr.cfg.Load().sessionRefreshGracePeriod,
-		rss.mgr.cfg.Load().sessionRefreshCoolOffDuration,
+		rss.sessionRefreshGracePeriod,
+		rss.sessionRefreshCoolOffDuration,
 	)
 	for {
 		select {
@@ -114,6 +124,12 @@ func (rss *refreshSessionScheduler) Stop() {
 
 func (rss *refreshSessionScheduler) run(ctx context.Context) {
 	var timer *time.Timer
+	// ensure we clean up any orphaned timers
+	defer func() {
+		if timer != nil {
+			timer.Stop()
+		}
+	}()
 
 	// wait for the first update
 	select {
@@ -122,7 +138,6 @@ func (rss *refreshSessionScheduler) run(ctx context.Context) {
 	case due := <-rss.next:
 		delay := max(time.Until(due), 0)
 		timer = time.NewTimer(delay)
-		defer timer.Stop()
 	}
 
 	// wait for updates or for the timer to trigger
@@ -132,15 +147,13 @@ func (rss *refreshSessionScheduler) run(ctx context.Context) {
 			return
 		case due := <-rss.next:
 			delay := max(time.Until(due), 0)
-			// stop the current timer and reset it
-			if !timer.Stop() {
-				<-timer.C
-			}
-			timer.Reset(delay)
+			// stop the existing timer and start a new one
+			timer.Stop()
+			timer = time.NewTimer(delay)
 		case <-timer.C:
-			now := rss.mgr.cfg.Load().now()
-			rss.lastRefresh.Store(&now)
-			rss.mgr.refreshSession(ctx, rss.sessionID)
+			tm := rss.now()
+			rss.lastRefresh.Store(&tm)
+			rss.refreshSession(ctx, rss.sessionID)
 		}
 	}
 }
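Two things changed in the timer handling here. First, the old drain pattern `if !timer.Stop() { <-timer.C }` is only safe when nothing else receives from timer.C, but this loop also receives from timer.C in the same select; once the timer had fired and been consumed by that case, the drain would block forever. Second, a plain `defer timer.Stop()` evaluates its receiver when the defer is registered, so after re-arming by replacement it would only ever stop the first timer; the deferred closure added in the earlier hunk reads the variable at exit and stops whichever timer is current. A sketch of the replacement pattern (rearm is an illustrative helper, not in the diff):

	// Re-arm by replacement: safe whether or not the old timer already
	// fired, since the stale channel is never drained, only dropped.
	func rearm(timer *time.Timer, delay time.Duration) *time.Timer {
		timer.Stop() // return value deliberately ignored
		return time.NewTimer(delay)
	}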
@@ -7,29 +7,58 @@ import (
 	"time"
 
 	"github.com/stretchr/testify/assert"
+	"google.golang.org/protobuf/types/known/timestamppb"
+
+	"github.com/pomerium/pomerium/pkg/grpc/session"
 )
 
+func TestRefreshSessionScheduler(t *testing.T) {
+	t.Parallel()
+
+	var calls safeSlice[time.Time]
+	ctx := context.Background()
+	sessionRefreshGracePeriod := time.Millisecond
+	sessionRefreshCoolOffDuration := time.Millisecond
+	rss := newRefreshSessionScheduler(
+		ctx,
+		time.Now,
+		sessionRefreshGracePeriod,
+		sessionRefreshCoolOffDuration,
+		func(ctx context.Context, sessionID string) {
+			calls.Append(time.Now())
+		},
+		"S1",
+	)
+	t.Cleanup(rss.Stop)
+
+	rss.Update(&session.Session{ExpiresAt: timestamppb.Now()})
+
+	assert.Eventually(t, func() bool {
+		return calls.Len() == 1
+	}, 100*time.Millisecond, 10*time.Millisecond, "should trigger once")
+
+	rss.Update(&session.Session{ExpiresAt: timestamppb.Now()})
+
+	assert.Eventually(t, func() bool {
+		return calls.Len() == 2
+	}, 100*time.Millisecond, 10*time.Millisecond, "should trigger again")
+}
+
 func TestUpdateUserInfoScheduler(t *testing.T) {
 	t.Parallel()
 
-	var mu sync.Mutex
-	var calls []time.Time
+	var calls safeSlice[time.Time]
 
 	ctx := context.Background()
 	userUpdateInfoInterval := 100 * time.Millisecond
 	uuis := newUpdateUserInfoScheduler(ctx, userUpdateInfoInterval, func(ctx context.Context, userID string) {
-		mu.Lock()
-		calls = append(calls, time.Now())
-		mu.Unlock()
+		calls.Append(time.Now())
 	}, "U1")
 	t.Cleanup(uuis.Stop)
 
 	// should eventually trigger
 	assert.Eventually(t, func() bool {
-		mu.Lock()
-		n := len(calls)
-		mu.Unlock()
-		return n == 1
+		return calls.Len() == 1
 	}, 3*userUpdateInfoInterval, userUpdateInfoInterval/10, "should trigger once")
 
 	uuis.Reset()
@@ -37,18 +66,48 @@ func TestUpdateUserInfoScheduler(t *testing.T) {
 	uuis.Reset()
 
 	assert.Eventually(t, func() bool {
-		mu.Lock()
-		n := len(calls)
-		mu.Unlock()
-		return n == 2
+		return calls.Len() == 2
 	}, 3*userUpdateInfoInterval, userUpdateInfoInterval/10, "should trigger once after multiple resets")
 
-	mu.Lock()
 	var diff time.Duration
-	if len(calls) >= 2 {
-		diff = calls[len(calls)-1].Sub(calls[len(calls)-2])
+	if calls.Len() >= 2 {
+		diff = calls.At(calls.Len() - 1).Sub(calls.At(calls.Len() - 2))
 	}
-	mu.Unlock()
 
 	assert.GreaterOrEqual(t, diff, userUpdateInfoInterval, "delay should exceed interval")
 
+	uuis.Reset()
+	uuis.Stop()
+
+	time.Sleep(3 * userUpdateInfoInterval)
+
+	assert.Equal(t, 2, calls.Len(), "should not trigger again after stopping")
 }
+
+type safeSlice[T any] struct {
+	mu       sync.Mutex
+	elements []T
+}
+
+func (s *safeSlice[T]) Append(elements ...T) {
+	s.mu.Lock()
+	s.elements = append(s.elements, elements...)
+	s.mu.Unlock()
+}
+
+func (s *safeSlice[T]) At(idx int) T {
+	var el T
+	s.mu.Lock()
+	if idx >= 0 && idx < len(s.elements) {
+		el = s.elements[idx]
+	}
+	s.mu.Unlock()
+	return el
+}
+
+func (s *safeSlice[T]) Len() int {
+	s.mu.Lock()
+	n := len(s.elements)
+	s.mu.Unlock()
+	return n
+}
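The safeSlice helper exists because the schedulers invoke their callbacks on background goroutines while the test goroutine polls the results; without a mutex these tests would fail under go test -race. All three accessors take the lock, so a typical interaction is simply:

	// Concurrent Append (scheduler callback) and Len/At (test assertions) are safe.
	var calls safeSlice[time.Time]
	go calls.Append(time.Now()) // writer goroutine
	n := calls.Len()            // reader
	first := calls.At(0)        // zero value when idx is out of range
	_, _ = n, first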