mirror of
https://github.com/pomerium/pomerium.git
synced 2025-08-06 10:21:05 +02:00
add test
This commit is contained in:
parent
a80ef11763
commit
2eaa4291ae
3 changed files with 72 additions and 10 deletions
|
@ -162,7 +162,7 @@ func (mgr *Manager) onUpdateUser(ctx context.Context, u *user.User) {
|
||||||
mgr.dataStore.putUser(u)
|
mgr.dataStore.putUser(u)
|
||||||
_, ok := mgr.updateUserInfoSchedulers[u.GetId()]
|
_, ok := mgr.updateUserInfoSchedulers[u.GetId()]
|
||||||
if !ok {
|
if !ok {
|
||||||
uuis := newUpdateUserInfoScheduler(ctx, mgr, u.GetId())
|
uuis := newUpdateUserInfoScheduler(ctx, mgr.cfg.Load().updateUserInfoInterval, mgr.updateUserInfo, u.GetId())
|
||||||
mgr.updateUserInfoSchedulers[u.GetId()] = uuis
|
mgr.updateUserInfoSchedulers[u.GetId()] = uuis
|
||||||
}
|
}
|
||||||
mgr.mu.Unlock()
|
mgr.mu.Unlock()
|
||||||
|
|
|
@ -9,17 +9,25 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
type updateUserInfoScheduler struct {
|
type updateUserInfoScheduler struct {
|
||||||
mgr *Manager
|
updateUserInfoInterval time.Duration
|
||||||
userID string
|
updateUserInfo func(ctx context.Context, userID string)
|
||||||
|
userID string
|
||||||
|
|
||||||
reset chan struct{}
|
reset chan struct{}
|
||||||
cancel context.CancelFunc
|
cancel context.CancelFunc
|
||||||
}
|
}
|
||||||
|
|
||||||
func newUpdateUserInfoScheduler(ctx context.Context, mgr *Manager, userID string) *updateUserInfoScheduler {
|
func newUpdateUserInfoScheduler(
|
||||||
|
ctx context.Context,
|
||||||
|
updateUserInfoInterval time.Duration,
|
||||||
|
updateUserInfo func(ctx context.Context, userID string),
|
||||||
|
userID string,
|
||||||
|
) *updateUserInfoScheduler {
|
||||||
uuis := &updateUserInfoScheduler{
|
uuis := &updateUserInfoScheduler{
|
||||||
mgr: mgr,
|
updateUserInfoInterval: updateUserInfoInterval,
|
||||||
userID: userID,
|
updateUserInfo: updateUserInfo,
|
||||||
reset: make(chan struct{}, 1),
|
userID: userID,
|
||||||
|
reset: make(chan struct{}, 1),
|
||||||
}
|
}
|
||||||
ctx = context.WithoutCancel(ctx)
|
ctx = context.WithoutCancel(ctx)
|
||||||
ctx, uuis.cancel = context.WithCancel(ctx)
|
ctx, uuis.cancel = context.WithCancel(ctx)
|
||||||
|
@ -39,7 +47,7 @@ func (uuis *updateUserInfoScheduler) Stop() {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (uuis *updateUserInfoScheduler) run(ctx context.Context) {
|
func (uuis *updateUserInfoScheduler) run(ctx context.Context) {
|
||||||
ticker := time.NewTicker(uuis.mgr.cfg.Load().updateUserInfoInterval)
|
ticker := time.NewTicker(uuis.updateUserInfoInterval)
|
||||||
defer ticker.Stop()
|
defer ticker.Stop()
|
||||||
|
|
||||||
for {
|
for {
|
||||||
|
@ -47,9 +55,9 @@ func (uuis *updateUserInfoScheduler) run(ctx context.Context) {
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
return
|
return
|
||||||
case <-uuis.reset:
|
case <-uuis.reset:
|
||||||
ticker.Reset(uuis.mgr.cfg.Load().updateUserInfoInterval)
|
ticker.Reset(uuis.updateUserInfoInterval)
|
||||||
case <-ticker.C:
|
case <-ticker.C:
|
||||||
uuis.mgr.updateUserInfo(ctx, uuis.userID)
|
uuis.updateUserInfo(ctx, uuis.userID)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
54
internal/identity/manager/schedulers_test.go
Normal file
54
internal/identity/manager/schedulers_test.go
Normal file
|
@ -0,0 +1,54 @@
|
||||||
|
package manager
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"sync"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestUpdateUserInfoScheduler(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
var mu sync.Mutex
|
||||||
|
var calls []time.Time
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
userUpdateInfoInterval := 100 * time.Millisecond
|
||||||
|
uuis := newUpdateUserInfoScheduler(ctx, userUpdateInfoInterval, func(ctx context.Context, userID string) {
|
||||||
|
mu.Lock()
|
||||||
|
calls = append(calls, time.Now())
|
||||||
|
mu.Unlock()
|
||||||
|
}, "U1")
|
||||||
|
t.Cleanup(uuis.Stop)
|
||||||
|
|
||||||
|
// should eventually trigger
|
||||||
|
assert.Eventually(t, func() bool {
|
||||||
|
mu.Lock()
|
||||||
|
n := len(calls)
|
||||||
|
mu.Unlock()
|
||||||
|
return n == 1
|
||||||
|
}, 3*userUpdateInfoInterval, userUpdateInfoInterval/10, "should trigger once")
|
||||||
|
|
||||||
|
uuis.Reset()
|
||||||
|
uuis.Reset()
|
||||||
|
uuis.Reset()
|
||||||
|
|
||||||
|
assert.Eventually(t, func() bool {
|
||||||
|
mu.Lock()
|
||||||
|
n := len(calls)
|
||||||
|
mu.Unlock()
|
||||||
|
return n == 2
|
||||||
|
}, 3*userUpdateInfoInterval, userUpdateInfoInterval/10, "should trigger once after multiple resets")
|
||||||
|
|
||||||
|
mu.Lock()
|
||||||
|
var diff time.Duration
|
||||||
|
if len(calls) >= 2 {
|
||||||
|
diff = calls[len(calls)-1].Sub(calls[len(calls)-2])
|
||||||
|
}
|
||||||
|
mu.Unlock()
|
||||||
|
|
||||||
|
assert.GreaterOrEqual(t, diff, userUpdateInfoInterval, "delay should exceed interval")
|
||||||
|
}
|
Loading…
Add table
Add a link
Reference in a new issue