From ddb2b416652322266363dd7907170a8099e99de1 Mon Sep 17 00:00:00 2001 From: Caleb Doxsey Date: Wed, 1 Nov 2023 09:50:33 -0600 Subject: [PATCH] core/config: refactor change dispatcher --- config/config_source.go | 25 +++--- internal/events/target.go | 160 +++++++++++++++++++++++++++++++++ internal/events/target_test.go | 52 +++++++++++ 3 files changed, 226 insertions(+), 11 deletions(-) create mode 100644 internal/events/target.go create mode 100644 internal/events/target_test.go diff --git a/config/config_source.go b/config/config_source.go index 1900f3906..6d4558ca0 100644 --- a/config/config_source.go +++ b/config/config_source.go @@ -10,6 +10,7 @@ import ( "github.com/google/uuid" "github.com/rs/zerolog" + "github.com/pomerium/pomerium/internal/events" "github.com/pomerium/pomerium/internal/fileutil" "github.com/pomerium/pomerium/internal/log" "github.com/pomerium/pomerium/internal/telemetry/metrics" @@ -19,27 +20,29 @@ import ( // A ChangeListener is called when configuration changes. type ChangeListener = func(context.Context, *Config) +type changeDispatcherEvent struct { + ctx context.Context + cfg *Config +} + // A ChangeDispatcher manages listeners on config changes. type ChangeDispatcher struct { - sync.Mutex - onConfigChangeListeners []ChangeListener + target events.Target[changeDispatcherEvent] } // Trigger triggers a change. func (dispatcher *ChangeDispatcher) Trigger(ctx context.Context, cfg *Config) { - dispatcher.Lock() - defer dispatcher.Unlock() - - for _, li := range dispatcher.onConfigChangeListeners { - li(ctx, cfg) - } + dispatcher.target.Dispatch(changeDispatcherEvent{ + ctx: ctx, + cfg: cfg, + }) } // OnConfigChange adds a listener. func (dispatcher *ChangeDispatcher) OnConfigChange(_ context.Context, li ChangeListener) { - dispatcher.Lock() - defer dispatcher.Unlock() - dispatcher.onConfigChangeListeners = append(dispatcher.onConfigChangeListeners, li) + dispatcher.target.AddListener(func(evt changeDispatcherEvent) { + li(evt.ctx, evt.cfg) + }) } // A Source gets configuration. diff --git a/internal/events/target.go b/internal/events/target.go new file mode 100644 index 000000000..3c09a6f67 --- /dev/null +++ b/internal/events/target.go @@ -0,0 +1,160 @@ +package events + +import ( + "context" + "sync" + + "github.com/google/uuid" +) + +type ( + // A Listener is a function that listens for events of type T. + Listener[T any] func(T) + // A Handle represents a listener. + Handle string + + addListenerEvent[T any] struct { + listener Listener[T] + handle Handle + } + removeListenerEvent[T any] struct { + handle Handle + } + dispatchEvent[T any] struct { + event T + } +) + +// A Target is a target for events. +// +// Listeners are added with AddListener with a function to be called when the event occurs. +// AddListener returns a Handle which can be used to remove a listener with RemoveListener. +// +// Dispatch dispatches events to all the registered listeners. +// +// Target is safe to use in its zero state. +// +// The first time any method of Target is called a background goroutine is started that handles +// any requests and maintains the state of the listeners. Each listener also starts a +// separate goroutine so that all listeners can be invoked concurrently. +// +// The channels to the main goroutine and to the listener goroutines have a size of 1 so typically +// methods and dispatches will return immediately. However a slow listener will cause the next event +// dispatch to block. This is the opposite behavior from Manager. +// +// Close will cancel all the goroutines. Subsequent calls to AddListener, RemoveListener, Close and +// Dispatch are no-ops. +type Target[T any] struct { + initOnce sync.Once + ctx context.Context + cancel context.CancelFunc + addListenerCh chan addListenerEvent[T] + removeListenerCh chan removeListenerEvent[T] + dispatchCh chan dispatchEvent[T] + listeners map[Handle]chan T +} + +// AddListener adds a listener to the target. +func (t *Target[T]) AddListener(listener Listener[T]) Handle { + t.init() + + // using a handle is necessary because you can't use a function as a map key. + handle := Handle(uuid.NewString()) + + select { + case <-t.ctx.Done(): + case t.addListenerCh <- addListenerEvent[T]{listener, handle}: + } + + return handle +} + +// Close closes the event target. This can be called multiple times safely. +// Once closed the target cannot be used. +func (t *Target[T]) Close() { + t.init() + + t.cancel() +} + +// Dispatch dispatches an event to any listeners. +func (t *Target[T]) Dispatch(evt T) { + t.init() + + select { + case <-t.ctx.Done(): + case t.dispatchCh <- dispatchEvent[T]{evt}: + } +} + +// RemoveListener removes a listener from the target. +func (t *Target[T]) RemoveListener(handle Handle) { + t.init() + + select { + case <-t.ctx.Done(): + case t.removeListenerCh <- removeListenerEvent[T]{handle}: + } +} + +func (t *Target[T]) init() { + t.initOnce.Do(func() { + t.ctx, t.cancel = context.WithCancel(context.Background()) + t.addListenerCh = make(chan addListenerEvent[T], 1) + t.removeListenerCh = make(chan removeListenerEvent[T], 1) + t.dispatchCh = make(chan dispatchEvent[T], 1) + t.listeners = map[Handle]chan T{} + go t.run() + }) +} + +func (t *Target[T]) run() { + // listen for add/remove/dispatch events and call functions + for { + select { + case <-t.ctx.Done(): + return + case evt := <-t.addListenerCh: + t.addListener(evt.listener, evt.handle) + case evt := <-t.removeListenerCh: + t.removeListener(evt.handle) + case evt := <-t.dispatchCh: + t.dispatch(evt.event) + } + } +} + +// these functions are not thread-safe. They are intended to be called only by "run". + +func (t *Target[T]) addListener(listener Listener[T], handle Handle) { + ch := make(chan T, 1) + t.listeners[handle] = ch + // start a goroutine to send events to the listener + go func() { + for evt := range ch { + listener(evt) + } + }() +} + +func (t *Target[T]) removeListener(handle Handle) { + ch, ok := t.listeners[handle] + if !ok { + // nothing to do since the listener doesn't exist + return + } + // close the channel to kill the goroutine + close(ch) + delete(t.listeners, handle) +} + +func (t *Target[T]) dispatch(evt T) { + // loop over all the listeners and send the event to them + for _, ch := range t.listeners { + select { + case <-t.ctx.Done(): + return + case ch <- evt: + } + } +} diff --git a/internal/events/target_test.go b/internal/events/target_test.go new file mode 100644 index 000000000..8fa38cc97 --- /dev/null +++ b/internal/events/target_test.go @@ -0,0 +1,52 @@ +package events_test + +import ( + "sync/atomic" + "testing" + "time" + + "github.com/stretchr/testify/assert" + + "github.com/pomerium/pomerium/internal/events" +) + +func TestTarget(t *testing.T) { + t.Parallel() + + var target events.Target[int64] + defer target.Close() + + var calls1, calls2, calls3 atomic.Int64 + h1 := target.AddListener(func(i int64) { + calls1.Add(i) + }) + h2 := target.AddListener(func(i int64) { + calls2.Add(i) + }) + h3 := target.AddListener(func(i int64) { + calls3.Add(i) + }) + + shouldBe := func(i1, i2, i3 int64) { + t.Helper() + + assert.Eventually(t, func() bool { return calls1.Load() == i1 }, time.Millisecond*10, time.Microsecond*100) + assert.Eventually(t, func() bool { return calls2.Load() == i2 }, time.Millisecond*10, time.Microsecond*100) + assert.Eventually(t, func() bool { return calls3.Load() == i3 }, time.Millisecond*10, time.Microsecond*100) + } + + target.Dispatch(1) + shouldBe(1, 1, 1) + + target.RemoveListener(h2) + target.Dispatch(2) + shouldBe(3, 1, 3) + + target.RemoveListener(h1) + target.Dispatch(3) + shouldBe(3, 1, 6) + + target.RemoveListener(h3) + target.Dispatch(4) + shouldBe(3, 1, 6) +}