diff --git a/config/config_source.go b/config/config_source.go index dd5d92b5c..9c20534de 100644 --- a/config/config_source.go +++ b/config/config_source.go @@ -10,6 +10,7 @@ import ( "github.com/google/uuid" "github.com/rs/zerolog" + "github.com/pomerium/pomerium/internal/events" "github.com/pomerium/pomerium/internal/fileutil" "github.com/pomerium/pomerium/internal/log" "github.com/pomerium/pomerium/internal/telemetry/metrics" @@ -19,27 +20,27 @@ import ( // A ChangeListener is called when configuration changes. type ChangeListener = func(context.Context, *Config) +type changeDispatcherEvent struct { + cfg *Config +} + // A ChangeDispatcher manages listeners on config changes. type ChangeDispatcher struct { - sync.Mutex - onConfigChangeListeners []ChangeListener + target events.Target[changeDispatcherEvent] } // Trigger triggers a change. func (dispatcher *ChangeDispatcher) Trigger(ctx context.Context, cfg *Config) { - dispatcher.Lock() - defer dispatcher.Unlock() - - for _, li := range dispatcher.onConfigChangeListeners { - li(ctx, cfg) - } + dispatcher.target.Dispatch(ctx, changeDispatcherEvent{ + cfg: cfg, + }) } // OnConfigChange adds a listener. func (dispatcher *ChangeDispatcher) OnConfigChange(_ context.Context, li ChangeListener) { - dispatcher.Lock() - defer dispatcher.Unlock() - dispatcher.onConfigChangeListeners = append(dispatcher.onConfigChangeListeners, li) + dispatcher.target.AddListener(func(ctx context.Context, evt changeDispatcherEvent) { + li(ctx, evt.cfg) + }) } // A Source gets configuration. diff --git a/config/layered_test.go b/config/layered_test.go index a8da007c8..97aebbeca 100644 --- a/config/layered_test.go +++ b/config/layered_test.go @@ -3,7 +3,9 @@ package config_test import ( "context" "errors" + "sync/atomic" "testing" + "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -13,6 +15,8 @@ import ( ) func TestLayeredConfig(t *testing.T) { + t.Parallel() + ctx := context.Background() t.Run("error on initial build", func(t *testing.T) { @@ -33,12 +37,15 @@ func TestLayeredConfig(t *testing.T) { }) require.NoError(t, err) - var dst *config.Config + var dst atomic.Pointer[config.Config] + dst.Store(layered.GetConfig()) layered.OnConfigChange(ctx, func(ctx context.Context, c *config.Config) { - dst = c + dst.Store(c) }) underlying.SetConfig(ctx, &config.Config{Options: &config.Options{DeriveInternalDomainCert: proto.String("b.com")}}) - assert.Equal(t, "b.com", dst.Options.GetDeriveInternalDomain()) + assert.Eventually(t, func() bool { + return dst.Load().Options.GetDeriveInternalDomain() == "b.com" + }, 10*time.Second, time.Millisecond) }) } diff --git a/internal/events/target.go b/internal/events/target.go new file mode 100644 index 000000000..7bc0806a2 --- /dev/null +++ b/internal/events/target.go @@ -0,0 +1,166 @@ +package events + +import ( + "context" + "errors" + "sync" + + "github.com/google/uuid" +) + +type ( + // A Listener is a function that listens for events of type T. + Listener[T any] func(ctx context.Context, event T) + // A Handle represents a listener. + Handle string + + addListenerEvent[T any] struct { + listener Listener[T] + handle Handle + } + removeListenerEvent[T any] struct { + handle Handle + } + dispatchEvent[T any] struct { + ctx context.Context + event T + } +) + +// A Target is a target for events. +// +// Listeners are added with AddListener with a function to be called when the event occurs. +// AddListener returns a Handle which can be used to remove a listener with RemoveListener. +// +// Dispatch dispatches events to all the registered listeners. +// +// Target is safe to use in its zero state. +// +// The first time any method of Target is called a background goroutine is started that handles +// any requests and maintains the state of the listeners. Each listener also starts a +// separate goroutine so that all listeners can be invoked concurrently. +// +// The channels to the main goroutine and to the listener goroutines have a size of 1 so typically +// methods and dispatches will return immediately. However a slow listener will cause the next event +// dispatch to block. This is the opposite behavior from Manager. +// +// Close will cancel all the goroutines. Subsequent calls to AddListener, RemoveListener, Close and +// Dispatch are no-ops. +type Target[T any] struct { + initOnce sync.Once + ctx context.Context + cancel context.CancelCauseFunc + addListenerCh chan addListenerEvent[T] + removeListenerCh chan removeListenerEvent[T] + dispatchCh chan dispatchEvent[T] + listeners map[Handle]chan dispatchEvent[T] +} + +// AddListener adds a listener to the target. +func (t *Target[T]) AddListener(listener Listener[T]) Handle { + t.init() + + // using a handle is necessary because you can't use a function as a map key. + handle := Handle(uuid.NewString()) + + select { + case <-t.ctx.Done(): + case t.addListenerCh <- addListenerEvent[T]{listener, handle}: + } + + return handle +} + +// Close closes the event target. This can be called multiple times safely. +// Once closed the target cannot be used. +func (t *Target[T]) Close() { + t.init() + + t.cancel(errors.New("target closed")) +} + +// Dispatch dispatches an event to all listeners. +func (t *Target[T]) Dispatch(ctx context.Context, evt T) { + t.init() + + select { + case <-t.ctx.Done(): + case t.dispatchCh <- dispatchEvent[T]{ctx: ctx, event: evt}: + } +} + +// RemoveListener removes a listener from the target. +func (t *Target[T]) RemoveListener(handle Handle) { + t.init() + + select { + case <-t.ctx.Done(): + case t.removeListenerCh <- removeListenerEvent[T]{handle}: + } +} + +func (t *Target[T]) init() { + t.initOnce.Do(func() { + t.ctx, t.cancel = context.WithCancelCause(context.Background()) + t.addListenerCh = make(chan addListenerEvent[T], 1) + t.removeListenerCh = make(chan removeListenerEvent[T], 1) + t.dispatchCh = make(chan dispatchEvent[T], 1) + t.listeners = map[Handle]chan dispatchEvent[T]{} + go t.run() + }) +} + +func (t *Target[T]) run() { + // listen for add/remove/dispatch events and call functions + for { + select { + case <-t.ctx.Done(): + return + case evt := <-t.addListenerCh: + t.addListener(evt.listener, evt.handle) + case evt := <-t.removeListenerCh: + t.removeListener(evt.handle) + case evt := <-t.dispatchCh: + t.dispatch(evt.ctx, evt.event) + } + } +} + +// these functions are not thread-safe. They are intended to be called only by "run". + +func (t *Target[T]) addListener(listener Listener[T], handle Handle) { + ch := make(chan dispatchEvent[T], 1) + t.listeners[handle] = ch + // start a goroutine to send events to the listener + go func() { + for { + select { + case <-t.ctx.Done(): + case evt := <-ch: + listener(evt.ctx, evt.event) + } + } + }() +} + +func (t *Target[T]) removeListener(handle Handle) { + ch, ok := t.listeners[handle] + if !ok { + // nothing to do since the listener doesn't exist + return + } + // close the channel to kill the goroutine + close(ch) + delete(t.listeners, handle) +} + +func (t *Target[T]) dispatch(ctx context.Context, evt T) { + // loop over all the listeners and send the event to them + for _, ch := range t.listeners { + select { + case <-t.ctx.Done(): + return + case ch <- dispatchEvent[T]{ctx: ctx, event: evt}: + } + } +} diff --git a/internal/events/target_test.go b/internal/events/target_test.go new file mode 100644 index 000000000..f91a5c6d6 --- /dev/null +++ b/internal/events/target_test.go @@ -0,0 +1,53 @@ +package events_test + +import ( + "context" + "sync/atomic" + "testing" + "time" + + "github.com/stretchr/testify/assert" + + "github.com/pomerium/pomerium/internal/events" +) + +func TestTarget(t *testing.T) { + t.Parallel() + + var target events.Target[int64] + t.Cleanup(target.Close) + + var calls1, calls2, calls3 atomic.Int64 + h1 := target.AddListener(func(_ context.Context, i int64) { + calls1.Add(i) + }) + h2 := target.AddListener(func(_ context.Context, i int64) { + calls2.Add(i) + }) + h3 := target.AddListener(func(_ context.Context, i int64) { + calls3.Add(i) + }) + + shouldBe := func(i1, i2, i3 int64) { + t.Helper() + + assert.Eventually(t, func() bool { return calls1.Load() == i1 }, time.Second, time.Millisecond) + assert.Eventually(t, func() bool { return calls2.Load() == i2 }, time.Second, time.Millisecond) + assert.Eventually(t, func() bool { return calls3.Load() == i3 }, time.Second, time.Millisecond) + } + + target.Dispatch(context.Background(), 1) + shouldBe(1, 1, 1) + + target.RemoveListener(h2) + target.Dispatch(context.Background(), 2) + shouldBe(3, 1, 3) + + target.RemoveListener(h1) + target.Dispatch(context.Background(), 3) + shouldBe(3, 1, 6) + + target.RemoveListener(h3) + target.Dispatch(context.Background(), 4) + shouldBe(3, 1, 6) +}