mirror of
https://github.com/pomerium/pomerium.git
synced 2025-08-06 10:21:05 +02:00
wip
This commit is contained in:
parent
2eaa4291ae
commit
0e3e5ff494
3 changed files with 122 additions and 38 deletions
|
@ -148,7 +148,14 @@ func (mgr *Manager) onUpdateSession(ctx context.Context, s *session.Session) {
|
||||||
mgr.dataStore.putSession(s)
|
mgr.dataStore.putSession(s)
|
||||||
rss, ok := mgr.refreshSessionSchedulers[s.GetId()]
|
rss, ok := mgr.refreshSessionSchedulers[s.GetId()]
|
||||||
if !ok {
|
if !ok {
|
||||||
rss = newRefreshSessionScheduler(ctx, mgr, s.GetId())
|
rss = newRefreshSessionScheduler(
|
||||||
|
ctx,
|
||||||
|
mgr.cfg.Load().now,
|
||||||
|
mgr.cfg.Load().sessionRefreshGracePeriod,
|
||||||
|
mgr.cfg.Load().sessionRefreshCoolOffDuration,
|
||||||
|
mgr.refreshSession,
|
||||||
|
s.GetId(),
|
||||||
|
)
|
||||||
mgr.refreshSessionSchedulers[s.GetId()] = rss
|
mgr.refreshSessionSchedulers[s.GetId()] = rss
|
||||||
}
|
}
|
||||||
rss.Update(s)
|
rss.Update(s)
|
||||||
|
@ -162,7 +169,12 @@ func (mgr *Manager) onUpdateUser(ctx context.Context, u *user.User) {
|
||||||
mgr.dataStore.putUser(u)
|
mgr.dataStore.putUser(u)
|
||||||
_, ok := mgr.updateUserInfoSchedulers[u.GetId()]
|
_, ok := mgr.updateUserInfoSchedulers[u.GetId()]
|
||||||
if !ok {
|
if !ok {
|
||||||
uuis := newUpdateUserInfoScheduler(ctx, mgr.cfg.Load().updateUserInfoInterval, mgr.updateUserInfo, u.GetId())
|
uuis := newUpdateUserInfoScheduler(
|
||||||
|
ctx,
|
||||||
|
mgr.cfg.Load().updateUserInfoInterval,
|
||||||
|
mgr.updateUserInfo,
|
||||||
|
u.GetId(),
|
||||||
|
)
|
||||||
mgr.updateUserInfoSchedulers[u.GetId()] = uuis
|
mgr.updateUserInfoSchedulers[u.GetId()] = uuis
|
||||||
}
|
}
|
||||||
mgr.mu.Unlock()
|
mgr.mu.Unlock()
|
||||||
|
|
|
@ -63,8 +63,12 @@ func (uuis *updateUserInfoScheduler) run(ctx context.Context) {
|
||||||
}
|
}
|
||||||
|
|
||||||
type refreshSessionScheduler struct {
|
type refreshSessionScheduler struct {
|
||||||
mgr *Manager
|
now func() time.Time
|
||||||
|
sessionRefreshGracePeriod time.Duration
|
||||||
|
sessionRefreshCoolOffDuration time.Duration
|
||||||
|
refreshSession func(ctx context.Context, sesionID string)
|
||||||
sessionID string
|
sessionID string
|
||||||
|
|
||||||
lastRefresh atomic.Pointer[time.Time]
|
lastRefresh atomic.Pointer[time.Time]
|
||||||
next chan time.Time
|
next chan time.Time
|
||||||
cancel context.CancelFunc
|
cancel context.CancelFunc
|
||||||
|
@ -72,16 +76,22 @@ type refreshSessionScheduler struct {
|
||||||
|
|
||||||
func newRefreshSessionScheduler(
|
func newRefreshSessionScheduler(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
mgr *Manager,
|
now func() time.Time,
|
||||||
|
sessionRefreshGracePeriod time.Duration,
|
||||||
|
sessionRefreshCoolOffDuration time.Duration,
|
||||||
|
refreshSession func(ctx context.Context, sesionID string),
|
||||||
sessionID string,
|
sessionID string,
|
||||||
) *refreshSessionScheduler {
|
) *refreshSessionScheduler {
|
||||||
rss := &refreshSessionScheduler{
|
rss := &refreshSessionScheduler{
|
||||||
mgr: mgr,
|
now: now,
|
||||||
|
sessionRefreshGracePeriod: sessionRefreshGracePeriod,
|
||||||
|
sessionRefreshCoolOffDuration: sessionRefreshCoolOffDuration,
|
||||||
|
refreshSession: refreshSession,
|
||||||
sessionID: sessionID,
|
sessionID: sessionID,
|
||||||
next: make(chan time.Time, 1),
|
next: make(chan time.Time, 1),
|
||||||
}
|
}
|
||||||
now := rss.mgr.cfg.Load().now()
|
tm := now()
|
||||||
rss.lastRefresh.Store(&now)
|
rss.lastRefresh.Store(&tm)
|
||||||
ctx = context.WithoutCancel(ctx)
|
ctx = context.WithoutCancel(ctx)
|
||||||
ctx, rss.cancel = context.WithCancel(ctx)
|
ctx, rss.cancel = context.WithCancel(ctx)
|
||||||
go rss.run(ctx)
|
go rss.run(ctx)
|
||||||
|
@ -92,8 +102,8 @@ func (rss *refreshSessionScheduler) Update(s *session.Session) {
|
||||||
due := nextSessionRefresh(
|
due := nextSessionRefresh(
|
||||||
s,
|
s,
|
||||||
*rss.lastRefresh.Load(),
|
*rss.lastRefresh.Load(),
|
||||||
rss.mgr.cfg.Load().sessionRefreshGracePeriod,
|
rss.sessionRefreshGracePeriod,
|
||||||
rss.mgr.cfg.Load().sessionRefreshCoolOffDuration,
|
rss.sessionRefreshCoolOffDuration,
|
||||||
)
|
)
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
|
@ -114,6 +124,12 @@ func (rss *refreshSessionScheduler) Stop() {
|
||||||
|
|
||||||
func (rss *refreshSessionScheduler) run(ctx context.Context) {
|
func (rss *refreshSessionScheduler) run(ctx context.Context) {
|
||||||
var timer *time.Timer
|
var timer *time.Timer
|
||||||
|
// ensure we clean up any orphaned timers
|
||||||
|
defer func() {
|
||||||
|
if timer != nil {
|
||||||
|
timer.Stop()
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
// wait for the first update
|
// wait for the first update
|
||||||
select {
|
select {
|
||||||
|
@ -122,7 +138,6 @@ func (rss *refreshSessionScheduler) run(ctx context.Context) {
|
||||||
case due := <-rss.next:
|
case due := <-rss.next:
|
||||||
delay := max(time.Until(due), 0)
|
delay := max(time.Until(due), 0)
|
||||||
timer = time.NewTimer(delay)
|
timer = time.NewTimer(delay)
|
||||||
defer timer.Stop()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// wait for updates or for the timer to trigger
|
// wait for updates or for the timer to trigger
|
||||||
|
@ -132,15 +147,13 @@ func (rss *refreshSessionScheduler) run(ctx context.Context) {
|
||||||
return
|
return
|
||||||
case due := <-rss.next:
|
case due := <-rss.next:
|
||||||
delay := max(time.Until(due), 0)
|
delay := max(time.Until(due), 0)
|
||||||
// stop the current timer and reset it
|
// stop the existing timer and start a new one
|
||||||
if !timer.Stop() {
|
timer.Stop()
|
||||||
<-timer.C
|
timer = time.NewTimer(delay)
|
||||||
}
|
|
||||||
timer.Reset(delay)
|
|
||||||
case <-timer.C:
|
case <-timer.C:
|
||||||
now := rss.mgr.cfg.Load().now()
|
tm := rss.now()
|
||||||
rss.lastRefresh.Store(&now)
|
rss.lastRefresh.Store(&tm)
|
||||||
rss.mgr.refreshSession(ctx, rss.sessionID)
|
rss.refreshSession(ctx, rss.sessionID)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -7,29 +7,58 @@ import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
|
"google.golang.org/protobuf/types/known/timestamppb"
|
||||||
|
|
||||||
|
"github.com/pomerium/pomerium/pkg/grpc/session"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
func TestRefreshSessionScheduler(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
var calls safeSlice[time.Time]
|
||||||
|
ctx := context.Background()
|
||||||
|
sessionRefreshGracePeriod := time.Millisecond
|
||||||
|
sessionRefreshCoolOffDuration := time.Millisecond
|
||||||
|
rss := newRefreshSessionScheduler(
|
||||||
|
ctx,
|
||||||
|
time.Now,
|
||||||
|
sessionRefreshGracePeriod,
|
||||||
|
sessionRefreshCoolOffDuration,
|
||||||
|
func(ctx context.Context, sesionID string) {
|
||||||
|
calls.Append(time.Now())
|
||||||
|
},
|
||||||
|
"S1",
|
||||||
|
)
|
||||||
|
t.Cleanup(rss.Stop)
|
||||||
|
|
||||||
|
rss.Update(&session.Session{ExpiresAt: timestamppb.Now()})
|
||||||
|
|
||||||
|
assert.Eventually(t, func() bool {
|
||||||
|
return calls.Len() == 1
|
||||||
|
}, 100*time.Millisecond, 10*time.Millisecond, "should trigger once")
|
||||||
|
|
||||||
|
rss.Update(&session.Session{ExpiresAt: timestamppb.Now()})
|
||||||
|
|
||||||
|
assert.Eventually(t, func() bool {
|
||||||
|
return calls.Len() == 2
|
||||||
|
}, 100*time.Millisecond, 10*time.Millisecond, "should trigger again")
|
||||||
|
}
|
||||||
|
|
||||||
func TestUpdateUserInfoScheduler(t *testing.T) {
|
func TestUpdateUserInfoScheduler(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
var mu sync.Mutex
|
var calls safeSlice[time.Time]
|
||||||
var calls []time.Time
|
|
||||||
|
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
userUpdateInfoInterval := 100 * time.Millisecond
|
userUpdateInfoInterval := 100 * time.Millisecond
|
||||||
uuis := newUpdateUserInfoScheduler(ctx, userUpdateInfoInterval, func(ctx context.Context, userID string) {
|
uuis := newUpdateUserInfoScheduler(ctx, userUpdateInfoInterval, func(ctx context.Context, userID string) {
|
||||||
mu.Lock()
|
calls.Append(time.Now())
|
||||||
calls = append(calls, time.Now())
|
|
||||||
mu.Unlock()
|
|
||||||
}, "U1")
|
}, "U1")
|
||||||
t.Cleanup(uuis.Stop)
|
t.Cleanup(uuis.Stop)
|
||||||
|
|
||||||
// should eventually trigger
|
// should eventually trigger
|
||||||
assert.Eventually(t, func() bool {
|
assert.Eventually(t, func() bool {
|
||||||
mu.Lock()
|
return calls.Len() == 1
|
||||||
n := len(calls)
|
|
||||||
mu.Unlock()
|
|
||||||
return n == 1
|
|
||||||
}, 3*userUpdateInfoInterval, userUpdateInfoInterval/10, "should trigger once")
|
}, 3*userUpdateInfoInterval, userUpdateInfoInterval/10, "should trigger once")
|
||||||
|
|
||||||
uuis.Reset()
|
uuis.Reset()
|
||||||
|
@ -37,18 +66,48 @@ func TestUpdateUserInfoScheduler(t *testing.T) {
|
||||||
uuis.Reset()
|
uuis.Reset()
|
||||||
|
|
||||||
assert.Eventually(t, func() bool {
|
assert.Eventually(t, func() bool {
|
||||||
mu.Lock()
|
return calls.Len() == 2
|
||||||
n := len(calls)
|
|
||||||
mu.Unlock()
|
|
||||||
return n == 2
|
|
||||||
}, 3*userUpdateInfoInterval, userUpdateInfoInterval/10, "should trigger once after multiple resets")
|
}, 3*userUpdateInfoInterval, userUpdateInfoInterval/10, "should trigger once after multiple resets")
|
||||||
|
|
||||||
mu.Lock()
|
|
||||||
var diff time.Duration
|
var diff time.Duration
|
||||||
if len(calls) >= 2 {
|
if calls.Len() >= 2 {
|
||||||
diff = calls[len(calls)-1].Sub(calls[len(calls)-2])
|
diff = calls.At(calls.Len() - 1).Sub(calls.At(calls.Len() - 2))
|
||||||
}
|
}
|
||||||
mu.Unlock()
|
|
||||||
|
|
||||||
assert.GreaterOrEqual(t, diff, userUpdateInfoInterval, "delay should exceed interval")
|
assert.GreaterOrEqual(t, diff, userUpdateInfoInterval, "delay should exceed interval")
|
||||||
|
|
||||||
|
uuis.Reset()
|
||||||
|
uuis.Stop()
|
||||||
|
|
||||||
|
time.Sleep(3 * userUpdateInfoInterval)
|
||||||
|
|
||||||
|
assert.Equal(t, 2, calls.Len(), "should not trigger again after stopping")
|
||||||
|
}
|
||||||
|
|
||||||
|
type safeSlice[T any] struct {
|
||||||
|
mu sync.Mutex
|
||||||
|
elements []T
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *safeSlice[T]) Append(elements ...T) {
|
||||||
|
s.mu.Lock()
|
||||||
|
s.elements = append(s.elements, elements...)
|
||||||
|
s.mu.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *safeSlice[T]) At(idx int) T {
|
||||||
|
var el T
|
||||||
|
s.mu.Lock()
|
||||||
|
if idx >= 0 && idx < len(s.elements) {
|
||||||
|
el = s.elements[idx]
|
||||||
|
}
|
||||||
|
s.mu.Unlock()
|
||||||
|
return el
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *safeSlice[T]) Len() int {
|
||||||
|
s.mu.Lock()
|
||||||
|
n := len(s.elements)
|
||||||
|
s.mu.Unlock()
|
||||||
|
return n
|
||||||
}
|
}
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue