diff --git a/config/config_source.go b/config/config_source.go index b94e6e028..9c20534de 100644 --- a/config/config_source.go +++ b/config/config_source.go @@ -21,7 +21,6 @@ import ( type ChangeListener = func(context.Context, *Config) type changeDispatcherEvent struct { - ctx context.Context cfg *Config } @@ -32,16 +31,15 @@ type ChangeDispatcher struct { // Trigger triggers a change. func (dispatcher *ChangeDispatcher) Trigger(ctx context.Context, cfg *Config) { - dispatcher.target.Dispatch(changeDispatcherEvent{ - ctx: ctx, + dispatcher.target.Dispatch(ctx, changeDispatcherEvent{ cfg: cfg, }) } // OnConfigChange adds a listener. func (dispatcher *ChangeDispatcher) OnConfigChange(_ context.Context, li ChangeListener) { - dispatcher.target.AddListener(func(evt changeDispatcherEvent) { - li(evt.ctx, evt.cfg) + dispatcher.target.AddListener(func(ctx context.Context, evt changeDispatcherEvent) { + li(ctx, evt.cfg) }) } diff --git a/internal/events/target.go b/internal/events/target.go index 4c41c346d..7bc0806a2 100644 --- a/internal/events/target.go +++ b/internal/events/target.go @@ -10,7 +10,7 @@ import ( type ( // A Listener is a function that listens for events of type T. - Listener[T any] func(T) + Listener[T any] func(ctx context.Context, event T) // A Handle represents a listener. Handle string @@ -22,6 +22,7 @@ type ( handle Handle } dispatchEvent[T any] struct { + ctx context.Context event T } ) @@ -52,7 +53,7 @@ type Target[T any] struct { addListenerCh chan addListenerEvent[T] removeListenerCh chan removeListenerEvent[T] dispatchCh chan dispatchEvent[T] - listeners map[Handle]chan T + listeners map[Handle]chan dispatchEvent[T] } // AddListener adds a listener to the target. @@ -78,13 +79,13 @@ func (t *Target[T]) Close() { t.cancel(errors.New("target closed")) } -// Dispatch dispatches an event to any listeners. -func (t *Target[T]) Dispatch(evt T) { +// Dispatch dispatches an event to all listeners. +func (t *Target[T]) Dispatch(ctx context.Context, evt T) { t.init() select { case <-t.ctx.Done(): - case t.dispatchCh <- dispatchEvent[T]{evt}: + case t.dispatchCh <- dispatchEvent[T]{ctx: ctx, event: evt}: } } @@ -104,7 +105,7 @@ func (t *Target[T]) init() { t.addListenerCh = make(chan addListenerEvent[T], 1) t.removeListenerCh = make(chan removeListenerEvent[T], 1) t.dispatchCh = make(chan dispatchEvent[T], 1) - t.listeners = map[Handle]chan T{} + t.listeners = map[Handle]chan dispatchEvent[T]{} go t.run() }) } @@ -120,7 +121,7 @@ func (t *Target[T]) run() { case evt := <-t.removeListenerCh: t.removeListener(evt.handle) case evt := <-t.dispatchCh: - t.dispatch(evt.event) + t.dispatch(evt.ctx, evt.event) } } } @@ -128,7 +129,7 @@ func (t *Target[T]) run() { // these functions are not thread-safe. They are intended to be called only by "run". func (t *Target[T]) addListener(listener Listener[T], handle Handle) { - ch := make(chan T, 1) + ch := make(chan dispatchEvent[T], 1) t.listeners[handle] = ch // start a goroutine to send events to the listener go func() { @@ -136,7 +137,7 @@ func (t *Target[T]) addListener(listener Listener[T], handle Handle) { select { case <-t.ctx.Done(): case evt := <-ch: - listener(evt) + listener(evt.ctx, evt.event) } } }() @@ -153,13 +154,13 @@ func (t *Target[T]) removeListener(handle Handle) { delete(t.listeners, handle) } -func (t *Target[T]) dispatch(evt T) { +func (t *Target[T]) dispatch(ctx context.Context, evt T) { // loop over all the listeners and send the event to them for _, ch := range t.listeners { select { case <-t.ctx.Done(): return - case ch <- evt: + case ch <- dispatchEvent[T]{ctx: ctx, event: evt}: } } } diff --git a/internal/events/target_test.go b/internal/events/target_test.go index 8fa38cc97..f31655485 100644 --- a/internal/events/target_test.go +++ b/internal/events/target_test.go @@ -1,6 +1,7 @@ package events_test import ( + "context" "sync/atomic" "testing" "time" @@ -14,16 +15,16 @@ func TestTarget(t *testing.T) { t.Parallel() var target events.Target[int64] - defer target.Close() + t.Cleanup(target.Close) var calls1, calls2, calls3 atomic.Int64 - h1 := target.AddListener(func(i int64) { + h1 := target.AddListener(func(_ context.Context, i int64) { calls1.Add(i) }) - h2 := target.AddListener(func(i int64) { + h2 := target.AddListener(func(_ context.Context, i int64) { calls2.Add(i) }) - h3 := target.AddListener(func(i int64) { + h3 := target.AddListener(func(_ context.Context, i int64) { calls3.Add(i) }) @@ -35,18 +36,18 @@ func TestTarget(t *testing.T) { assert.Eventually(t, func() bool { return calls3.Load() == i3 }, time.Millisecond*10, time.Microsecond*100) } - target.Dispatch(1) + target.Dispatch(context.Background(), 1) shouldBe(1, 1, 1) target.RemoveListener(h2) - target.Dispatch(2) + target.Dispatch(context.Background(), 2) shouldBe(3, 1, 3) target.RemoveListener(h1) - target.Dispatch(3) + target.Dispatch(context.Background(), 3) shouldBe(3, 1, 6) target.RemoveListener(h3) - target.Dispatch(4) + target.Dispatch(context.Background(), 4) shouldBe(3, 1, 6) }