mirror of
https://github.com/pomerium/pomerium.git
synced 2025-07-19 01:28:51 +02:00
use context
This commit is contained in:
parent
d1350f0447
commit
328d2d500b
3 changed files with 24 additions and 24 deletions
|
@ -21,7 +21,6 @@ import (
|
|||
type ChangeListener = func(context.Context, *Config)
|
||||
|
||||
type changeDispatcherEvent struct {
|
||||
ctx context.Context
|
||||
cfg *Config
|
||||
}
|
||||
|
||||
|
@ -32,16 +31,15 @@ type ChangeDispatcher struct {
|
|||
|
||||
// Trigger triggers a change.
|
||||
func (dispatcher *ChangeDispatcher) Trigger(ctx context.Context, cfg *Config) {
|
||||
dispatcher.target.Dispatch(changeDispatcherEvent{
|
||||
ctx: ctx,
|
||||
dispatcher.target.Dispatch(ctx, changeDispatcherEvent{
|
||||
cfg: cfg,
|
||||
})
|
||||
}
|
||||
|
||||
// OnConfigChange adds a listener.
|
||||
func (dispatcher *ChangeDispatcher) OnConfigChange(_ context.Context, li ChangeListener) {
|
||||
dispatcher.target.AddListener(func(evt changeDispatcherEvent) {
|
||||
li(evt.ctx, evt.cfg)
|
||||
dispatcher.target.AddListener(func(ctx context.Context, evt changeDispatcherEvent) {
|
||||
li(ctx, evt.cfg)
|
||||
})
|
||||
}
|
||||
|
||||
|
|
|
@ -10,7 +10,7 @@ import (
|
|||
|
||||
type (
|
||||
// A Listener is a function that listens for events of type T.
|
||||
Listener[T any] func(T)
|
||||
Listener[T any] func(ctx context.Context, event T)
|
||||
// A Handle represents a listener.
|
||||
Handle string
|
||||
|
||||
|
@ -22,6 +22,7 @@ type (
|
|||
handle Handle
|
||||
}
|
||||
dispatchEvent[T any] struct {
|
||||
ctx context.Context
|
||||
event T
|
||||
}
|
||||
)
|
||||
|
@ -52,7 +53,7 @@ type Target[T any] struct {
|
|||
addListenerCh chan addListenerEvent[T]
|
||||
removeListenerCh chan removeListenerEvent[T]
|
||||
dispatchCh chan dispatchEvent[T]
|
||||
listeners map[Handle]chan T
|
||||
listeners map[Handle]chan dispatchEvent[T]
|
||||
}
|
||||
|
||||
// AddListener adds a listener to the target.
|
||||
|
@ -78,13 +79,13 @@ func (t *Target[T]) Close() {
|
|||
t.cancel(errors.New("target closed"))
|
||||
}
|
||||
|
||||
// Dispatch dispatches an event to any listeners.
|
||||
func (t *Target[T]) Dispatch(evt T) {
|
||||
// Dispatch dispatches an event to all listeners.
|
||||
func (t *Target[T]) Dispatch(ctx context.Context, evt T) {
|
||||
t.init()
|
||||
|
||||
select {
|
||||
case <-t.ctx.Done():
|
||||
case t.dispatchCh <- dispatchEvent[T]{evt}:
|
||||
case t.dispatchCh <- dispatchEvent[T]{ctx: ctx, event: evt}:
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -104,7 +105,7 @@ func (t *Target[T]) init() {
|
|||
t.addListenerCh = make(chan addListenerEvent[T], 1)
|
||||
t.removeListenerCh = make(chan removeListenerEvent[T], 1)
|
||||
t.dispatchCh = make(chan dispatchEvent[T], 1)
|
||||
t.listeners = map[Handle]chan T{}
|
||||
t.listeners = map[Handle]chan dispatchEvent[T]{}
|
||||
go t.run()
|
||||
})
|
||||
}
|
||||
|
@ -120,7 +121,7 @@ func (t *Target[T]) run() {
|
|||
case evt := <-t.removeListenerCh:
|
||||
t.removeListener(evt.handle)
|
||||
case evt := <-t.dispatchCh:
|
||||
t.dispatch(evt.event)
|
||||
t.dispatch(evt.ctx, evt.event)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -128,7 +129,7 @@ func (t *Target[T]) run() {
|
|||
// these functions are not thread-safe. They are intended to be called only by "run".
|
||||
|
||||
func (t *Target[T]) addListener(listener Listener[T], handle Handle) {
|
||||
ch := make(chan T, 1)
|
||||
ch := make(chan dispatchEvent[T], 1)
|
||||
t.listeners[handle] = ch
|
||||
// start a goroutine to send events to the listener
|
||||
go func() {
|
||||
|
@ -136,7 +137,7 @@ func (t *Target[T]) addListener(listener Listener[T], handle Handle) {
|
|||
select {
|
||||
case <-t.ctx.Done():
|
||||
case evt := <-ch:
|
||||
listener(evt)
|
||||
listener(evt.ctx, evt.event)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
@ -153,13 +154,13 @@ func (t *Target[T]) removeListener(handle Handle) {
|
|||
delete(t.listeners, handle)
|
||||
}
|
||||
|
||||
func (t *Target[T]) dispatch(evt T) {
|
||||
func (t *Target[T]) dispatch(ctx context.Context, evt T) {
|
||||
// loop over all the listeners and send the event to them
|
||||
for _, ch := range t.listeners {
|
||||
select {
|
||||
case <-t.ctx.Done():
|
||||
return
|
||||
case ch <- evt:
|
||||
case ch <- dispatchEvent[T]{ctx: ctx, event: evt}:
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package events_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
@ -14,16 +15,16 @@ func TestTarget(t *testing.T) {
|
|||
t.Parallel()
|
||||
|
||||
var target events.Target[int64]
|
||||
defer target.Close()
|
||||
t.Cleanup(target.Close)
|
||||
|
||||
var calls1, calls2, calls3 atomic.Int64
|
||||
h1 := target.AddListener(func(i int64) {
|
||||
h1 := target.AddListener(func(_ context.Context, i int64) {
|
||||
calls1.Add(i)
|
||||
})
|
||||
h2 := target.AddListener(func(i int64) {
|
||||
h2 := target.AddListener(func(_ context.Context, i int64) {
|
||||
calls2.Add(i)
|
||||
})
|
||||
h3 := target.AddListener(func(i int64) {
|
||||
h3 := target.AddListener(func(_ context.Context, i int64) {
|
||||
calls3.Add(i)
|
||||
})
|
||||
|
||||
|
@ -35,18 +36,18 @@ func TestTarget(t *testing.T) {
|
|||
assert.Eventually(t, func() bool { return calls3.Load() == i3 }, time.Millisecond*10, time.Microsecond*100)
|
||||
}
|
||||
|
||||
target.Dispatch(1)
|
||||
target.Dispatch(context.Background(), 1)
|
||||
shouldBe(1, 1, 1)
|
||||
|
||||
target.RemoveListener(h2)
|
||||
target.Dispatch(2)
|
||||
target.Dispatch(context.Background(), 2)
|
||||
shouldBe(3, 1, 3)
|
||||
|
||||
target.RemoveListener(h1)
|
||||
target.Dispatch(3)
|
||||
target.Dispatch(context.Background(), 3)
|
||||
shouldBe(3, 1, 6)
|
||||
|
||||
target.RemoveListener(h3)
|
||||
target.Dispatch(4)
|
||||
target.Dispatch(context.Background(), 4)
|
||||
shouldBe(3, 1, 6)
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue