use context

This commit is contained in:
Caleb Doxsey 2023-11-01 13:08:37 -06:00
parent d1350f0447
commit 328d2d500b
3 changed files with 24 additions and 24 deletions

View file

@ -21,7 +21,6 @@ import (
type ChangeListener = func(context.Context, *Config)
type changeDispatcherEvent struct {
ctx context.Context
cfg *Config
}
@ -32,16 +31,15 @@ type ChangeDispatcher struct {
// Trigger triggers a change.
func (dispatcher *ChangeDispatcher) Trigger(ctx context.Context, cfg *Config) {
dispatcher.target.Dispatch(changeDispatcherEvent{
ctx: ctx,
dispatcher.target.Dispatch(ctx, changeDispatcherEvent{
cfg: cfg,
})
}
// OnConfigChange adds a listener.
func (dispatcher *ChangeDispatcher) OnConfigChange(_ context.Context, li ChangeListener) {
dispatcher.target.AddListener(func(evt changeDispatcherEvent) {
li(evt.ctx, evt.cfg)
dispatcher.target.AddListener(func(ctx context.Context, evt changeDispatcherEvent) {
li(ctx, evt.cfg)
})
}

View file

@ -10,7 +10,7 @@ import (
type (
// A Listener is a function that listens for events of type T.
Listener[T any] func(T)
Listener[T any] func(ctx context.Context, event T)
// A Handle represents a listener.
Handle string
@ -22,6 +22,7 @@ type (
handle Handle
}
dispatchEvent[T any] struct {
ctx context.Context
event T
}
)
@ -52,7 +53,7 @@ type Target[T any] struct {
addListenerCh chan addListenerEvent[T]
removeListenerCh chan removeListenerEvent[T]
dispatchCh chan dispatchEvent[T]
listeners map[Handle]chan T
listeners map[Handle]chan dispatchEvent[T]
}
// AddListener adds a listener to the target.
@ -78,13 +79,13 @@ func (t *Target[T]) Close() {
t.cancel(errors.New("target closed"))
}
// Dispatch dispatches an event to any listeners.
func (t *Target[T]) Dispatch(evt T) {
// Dispatch dispatches an event to all listeners.
func (t *Target[T]) Dispatch(ctx context.Context, evt T) {
t.init()
select {
case <-t.ctx.Done():
case t.dispatchCh <- dispatchEvent[T]{evt}:
case t.dispatchCh <- dispatchEvent[T]{ctx: ctx, event: evt}:
}
}
@ -104,7 +105,7 @@ func (t *Target[T]) init() {
t.addListenerCh = make(chan addListenerEvent[T], 1)
t.removeListenerCh = make(chan removeListenerEvent[T], 1)
t.dispatchCh = make(chan dispatchEvent[T], 1)
t.listeners = map[Handle]chan T{}
t.listeners = map[Handle]chan dispatchEvent[T]{}
go t.run()
})
}
@ -120,7 +121,7 @@ func (t *Target[T]) run() {
case evt := <-t.removeListenerCh:
t.removeListener(evt.handle)
case evt := <-t.dispatchCh:
t.dispatch(evt.event)
t.dispatch(evt.ctx, evt.event)
}
}
}
@ -128,7 +129,7 @@ func (t *Target[T]) run() {
// these functions are not thread-safe. They are intended to be called only by "run".
func (t *Target[T]) addListener(listener Listener[T], handle Handle) {
ch := make(chan T, 1)
ch := make(chan dispatchEvent[T], 1)
t.listeners[handle] = ch
// start a goroutine to send events to the listener
go func() {
@ -136,7 +137,7 @@ func (t *Target[T]) addListener(listener Listener[T], handle Handle) {
select {
case <-t.ctx.Done():
case evt := <-ch:
listener(evt)
listener(evt.ctx, evt.event)
}
}
}()
@ -153,13 +154,13 @@ func (t *Target[T]) removeListener(handle Handle) {
delete(t.listeners, handle)
}
func (t *Target[T]) dispatch(evt T) {
func (t *Target[T]) dispatch(ctx context.Context, evt T) {
// loop over all the listeners and send the event to them
for _, ch := range t.listeners {
select {
case <-t.ctx.Done():
return
case ch <- evt:
case ch <- dispatchEvent[T]{ctx: ctx, event: evt}:
}
}
}

View file

@ -1,6 +1,7 @@
package events_test
import (
"context"
"sync/atomic"
"testing"
"time"
@ -14,16 +15,16 @@ func TestTarget(t *testing.T) {
t.Parallel()
var target events.Target[int64]
defer target.Close()
t.Cleanup(target.Close)
var calls1, calls2, calls3 atomic.Int64
h1 := target.AddListener(func(i int64) {
h1 := target.AddListener(func(_ context.Context, i int64) {
calls1.Add(i)
})
h2 := target.AddListener(func(i int64) {
h2 := target.AddListener(func(_ context.Context, i int64) {
calls2.Add(i)
})
h3 := target.AddListener(func(i int64) {
h3 := target.AddListener(func(_ context.Context, i int64) {
calls3.Add(i)
})
@ -35,18 +36,18 @@ func TestTarget(t *testing.T) {
assert.Eventually(t, func() bool { return calls3.Load() == i3 }, time.Millisecond*10, time.Microsecond*100)
}
target.Dispatch(1)
target.Dispatch(context.Background(), 1)
shouldBe(1, 1, 1)
target.RemoveListener(h2)
target.Dispatch(2)
target.Dispatch(context.Background(), 2)
shouldBe(3, 1, 3)
target.RemoveListener(h1)
target.Dispatch(3)
target.Dispatch(context.Background(), 3)
shouldBe(3, 1, 6)
target.RemoveListener(h3)
target.Dispatch(4)
target.Dispatch(context.Background(), 4)
shouldBe(3, 1, 6)
}