use context

This commit is contained in:
Caleb Doxsey 2023-11-01 13:08:37 -06:00
parent d1350f0447
commit 328d2d500b
3 changed files with 24 additions and 24 deletions

View file

@ -21,7 +21,6 @@ import (
type ChangeListener = func(context.Context, *Config) type ChangeListener = func(context.Context, *Config)
type changeDispatcherEvent struct { type changeDispatcherEvent struct {
ctx context.Context
cfg *Config cfg *Config
} }
@ -32,16 +31,15 @@ type ChangeDispatcher struct {
// Trigger triggers a change. // Trigger triggers a change.
func (dispatcher *ChangeDispatcher) Trigger(ctx context.Context, cfg *Config) { func (dispatcher *ChangeDispatcher) Trigger(ctx context.Context, cfg *Config) {
dispatcher.target.Dispatch(changeDispatcherEvent{ dispatcher.target.Dispatch(ctx, changeDispatcherEvent{
ctx: ctx,
cfg: cfg, cfg: cfg,
}) })
} }
// OnConfigChange adds a listener. // OnConfigChange adds a listener.
func (dispatcher *ChangeDispatcher) OnConfigChange(_ context.Context, li ChangeListener) { func (dispatcher *ChangeDispatcher) OnConfigChange(_ context.Context, li ChangeListener) {
dispatcher.target.AddListener(func(evt changeDispatcherEvent) { dispatcher.target.AddListener(func(ctx context.Context, evt changeDispatcherEvent) {
li(evt.ctx, evt.cfg) li(ctx, evt.cfg)
}) })
} }

View file

@ -10,7 +10,7 @@ import (
type ( type (
// A Listener is a function that listens for events of type T. // A Listener is a function that listens for events of type T.
Listener[T any] func(T) Listener[T any] func(ctx context.Context, event T)
// A Handle represents a listener. // A Handle represents a listener.
Handle string Handle string
@ -22,6 +22,7 @@ type (
handle Handle handle Handle
} }
dispatchEvent[T any] struct { dispatchEvent[T any] struct {
ctx context.Context
event T event T
} }
) )
@ -52,7 +53,7 @@ type Target[T any] struct {
addListenerCh chan addListenerEvent[T] addListenerCh chan addListenerEvent[T]
removeListenerCh chan removeListenerEvent[T] removeListenerCh chan removeListenerEvent[T]
dispatchCh chan dispatchEvent[T] dispatchCh chan dispatchEvent[T]
listeners map[Handle]chan T listeners map[Handle]chan dispatchEvent[T]
} }
// AddListener adds a listener to the target. // AddListener adds a listener to the target.
@ -78,13 +79,13 @@ func (t *Target[T]) Close() {
t.cancel(errors.New("target closed")) t.cancel(errors.New("target closed"))
} }
// Dispatch dispatches an event to any listeners. // Dispatch dispatches an event to all listeners.
func (t *Target[T]) Dispatch(evt T) { func (t *Target[T]) Dispatch(ctx context.Context, evt T) {
t.init() t.init()
select { select {
case <-t.ctx.Done(): case <-t.ctx.Done():
case t.dispatchCh <- dispatchEvent[T]{evt}: case t.dispatchCh <- dispatchEvent[T]{ctx: ctx, event: evt}:
} }
} }
@ -104,7 +105,7 @@ func (t *Target[T]) init() {
t.addListenerCh = make(chan addListenerEvent[T], 1) t.addListenerCh = make(chan addListenerEvent[T], 1)
t.removeListenerCh = make(chan removeListenerEvent[T], 1) t.removeListenerCh = make(chan removeListenerEvent[T], 1)
t.dispatchCh = make(chan dispatchEvent[T], 1) t.dispatchCh = make(chan dispatchEvent[T], 1)
t.listeners = map[Handle]chan T{} t.listeners = map[Handle]chan dispatchEvent[T]{}
go t.run() go t.run()
}) })
} }
@ -120,7 +121,7 @@ func (t *Target[T]) run() {
case evt := <-t.removeListenerCh: case evt := <-t.removeListenerCh:
t.removeListener(evt.handle) t.removeListener(evt.handle)
case evt := <-t.dispatchCh: case evt := <-t.dispatchCh:
t.dispatch(evt.event) t.dispatch(evt.ctx, evt.event)
} }
} }
} }
@ -128,7 +129,7 @@ func (t *Target[T]) run() {
// these functions are not thread-safe. They are intended to be called only by "run". // these functions are not thread-safe. They are intended to be called only by "run".
func (t *Target[T]) addListener(listener Listener[T], handle Handle) { func (t *Target[T]) addListener(listener Listener[T], handle Handle) {
ch := make(chan T, 1) ch := make(chan dispatchEvent[T], 1)
t.listeners[handle] = ch t.listeners[handle] = ch
// start a goroutine to send events to the listener // start a goroutine to send events to the listener
go func() { go func() {
@ -136,7 +137,7 @@ func (t *Target[T]) addListener(listener Listener[T], handle Handle) {
select { select {
case <-t.ctx.Done(): case <-t.ctx.Done():
case evt := <-ch: case evt := <-ch:
listener(evt) listener(evt.ctx, evt.event)
} }
} }
}() }()
@ -153,13 +154,13 @@ func (t *Target[T]) removeListener(handle Handle) {
delete(t.listeners, handle) delete(t.listeners, handle)
} }
func (t *Target[T]) dispatch(evt T) { func (t *Target[T]) dispatch(ctx context.Context, evt T) {
// loop over all the listeners and send the event to them // loop over all the listeners and send the event to them
for _, ch := range t.listeners { for _, ch := range t.listeners {
select { select {
case <-t.ctx.Done(): case <-t.ctx.Done():
return return
case ch <- evt: case ch <- dispatchEvent[T]{ctx: ctx, event: evt}:
} }
} }
} }

View file

@ -1,6 +1,7 @@
package events_test package events_test
import ( import (
"context"
"sync/atomic" "sync/atomic"
"testing" "testing"
"time" "time"
@ -14,16 +15,16 @@ func TestTarget(t *testing.T) {
t.Parallel() t.Parallel()
var target events.Target[int64] var target events.Target[int64]
defer target.Close() t.Cleanup(target.Close)
var calls1, calls2, calls3 atomic.Int64 var calls1, calls2, calls3 atomic.Int64
h1 := target.AddListener(func(i int64) { h1 := target.AddListener(func(_ context.Context, i int64) {
calls1.Add(i) calls1.Add(i)
}) })
h2 := target.AddListener(func(i int64) { h2 := target.AddListener(func(_ context.Context, i int64) {
calls2.Add(i) calls2.Add(i)
}) })
h3 := target.AddListener(func(i int64) { h3 := target.AddListener(func(_ context.Context, i int64) {
calls3.Add(i) calls3.Add(i)
}) })
@ -35,18 +36,18 @@ func TestTarget(t *testing.T) {
assert.Eventually(t, func() bool { return calls3.Load() == i3 }, time.Millisecond*10, time.Microsecond*100) assert.Eventually(t, func() bool { return calls3.Load() == i3 }, time.Millisecond*10, time.Microsecond*100)
} }
target.Dispatch(1) target.Dispatch(context.Background(), 1)
shouldBe(1, 1, 1) shouldBe(1, 1, 1)
target.RemoveListener(h2) target.RemoveListener(h2)
target.Dispatch(2) target.Dispatch(context.Background(), 2)
shouldBe(3, 1, 3) shouldBe(3, 1, 3)
target.RemoveListener(h1) target.RemoveListener(h1)
target.Dispatch(3) target.Dispatch(context.Background(), 3)
shouldBe(3, 1, 6) shouldBe(3, 1, 6)
target.RemoveListener(h3) target.RemoveListener(h3)
target.Dispatch(4) target.Dispatch(context.Background(), 4)
shouldBe(3, 1, 6) shouldBe(3, 1, 6)
} }