core/config: refactor change dispatcher (#4657)

* core/config: refactor change dispatcher

* update test

* close listener go routine when context is canceled

* use cancel cause

* use context

* add more time

* more time
This commit is contained in:
Caleb Doxsey 2023-11-01 13:52:23 -06:00 committed by GitHub
parent 53573dc046
commit e0693e54f0
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 241 additions and 14 deletions

View file

@ -10,6 +10,7 @@ import (
"github.com/google/uuid"
"github.com/rs/zerolog"
"github.com/pomerium/pomerium/internal/events"
"github.com/pomerium/pomerium/internal/fileutil"
"github.com/pomerium/pomerium/internal/log"
"github.com/pomerium/pomerium/internal/telemetry/metrics"
@ -19,27 +20,27 @@ import (
// A ChangeListener is called when configuration changes.
type ChangeListener = func(context.Context, *Config)
type changeDispatcherEvent struct {
cfg *Config
}
// A ChangeDispatcher manages listeners on config changes.
type ChangeDispatcher struct {
sync.Mutex
onConfigChangeListeners []ChangeListener
target events.Target[changeDispatcherEvent]
}
// Trigger triggers a change.
func (dispatcher *ChangeDispatcher) Trigger(ctx context.Context, cfg *Config) {
dispatcher.Lock()
defer dispatcher.Unlock()
for _, li := range dispatcher.onConfigChangeListeners {
li(ctx, cfg)
}
dispatcher.target.Dispatch(ctx, changeDispatcherEvent{
cfg: cfg,
})
}
// OnConfigChange adds a listener.
func (dispatcher *ChangeDispatcher) OnConfigChange(_ context.Context, li ChangeListener) {
dispatcher.Lock()
defer dispatcher.Unlock()
dispatcher.onConfigChangeListeners = append(dispatcher.onConfigChangeListeners, li)
dispatcher.target.AddListener(func(ctx context.Context, evt changeDispatcherEvent) {
li(ctx, evt.cfg)
})
}
// A Source gets configuration.

View file

@ -3,7 +3,9 @@ package config_test
import (
"context"
"errors"
"sync/atomic"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
@ -13,6 +15,8 @@ import (
)
func TestLayeredConfig(t *testing.T) {
t.Parallel()
ctx := context.Background()
t.Run("error on initial build", func(t *testing.T) {
@ -33,12 +37,15 @@ func TestLayeredConfig(t *testing.T) {
})
require.NoError(t, err)
var dst *config.Config
var dst atomic.Pointer[config.Config]
dst.Store(layered.GetConfig())
layered.OnConfigChange(ctx, func(ctx context.Context, c *config.Config) {
dst = c
dst.Store(c)
})
underlying.SetConfig(ctx, &config.Config{Options: &config.Options{DeriveInternalDomainCert: proto.String("b.com")}})
assert.Equal(t, "b.com", dst.Options.GetDeriveInternalDomain())
assert.Eventually(t, func() bool {
return dst.Load().Options.GetDeriveInternalDomain() == "b.com"
}, 10*time.Second, time.Millisecond)
})
}

166
internal/events/target.go Normal file
View file

@ -0,0 +1,166 @@
package events
import (
"context"
"errors"
"sync"
"github.com/google/uuid"
)
type (
// A Listener is a function that listens for events of type T.
Listener[T any] func(ctx context.Context, event T)
// A Handle represents a listener.
Handle string
addListenerEvent[T any] struct {
listener Listener[T]
handle Handle
}
removeListenerEvent[T any] struct {
handle Handle
}
dispatchEvent[T any] struct {
ctx context.Context
event T
}
)
// A Target is a target for events.
//
// Listeners are added with AddListener with a function to be called when the event occurs.
// AddListener returns a Handle which can be used to remove a listener with RemoveListener.
//
// Dispatch dispatches events to all the registered listeners.
//
// Target is safe to use in its zero state.
//
// The first time any method of Target is called a background goroutine is started that handles
// any requests and maintains the state of the listeners. Each listener also starts a
// separate goroutine so that all listeners can be invoked concurrently.
//
// The channels to the main goroutine and to the listener goroutines have a size of 1 so typically
// methods and dispatches will return immediately. However a slow listener will cause the next event
// dispatch to block. This is the opposite behavior from Manager.
//
// Close will cancel all the goroutines. Subsequent calls to AddListener, RemoveListener, Close and
// Dispatch are no-ops.
type Target[T any] struct {
initOnce sync.Once
ctx context.Context
cancel context.CancelCauseFunc
addListenerCh chan addListenerEvent[T]
removeListenerCh chan removeListenerEvent[T]
dispatchCh chan dispatchEvent[T]
listeners map[Handle]chan dispatchEvent[T]
}
// AddListener adds a listener to the target.
func (t *Target[T]) AddListener(listener Listener[T]) Handle {
t.init()
// using a handle is necessary because you can't use a function as a map key.
handle := Handle(uuid.NewString())
select {
case <-t.ctx.Done():
case t.addListenerCh <- addListenerEvent[T]{listener, handle}:
}
return handle
}
// Close closes the event target. This can be called multiple times safely.
// Once closed the target cannot be used.
func (t *Target[T]) Close() {
t.init()
t.cancel(errors.New("target closed"))
}
// Dispatch dispatches an event to all listeners.
func (t *Target[T]) Dispatch(ctx context.Context, evt T) {
t.init()
select {
case <-t.ctx.Done():
case t.dispatchCh <- dispatchEvent[T]{ctx: ctx, event: evt}:
}
}
// RemoveListener removes a listener from the target.
func (t *Target[T]) RemoveListener(handle Handle) {
t.init()
select {
case <-t.ctx.Done():
case t.removeListenerCh <- removeListenerEvent[T]{handle}:
}
}
func (t *Target[T]) init() {
t.initOnce.Do(func() {
t.ctx, t.cancel = context.WithCancelCause(context.Background())
t.addListenerCh = make(chan addListenerEvent[T], 1)
t.removeListenerCh = make(chan removeListenerEvent[T], 1)
t.dispatchCh = make(chan dispatchEvent[T], 1)
t.listeners = map[Handle]chan dispatchEvent[T]{}
go t.run()
})
}
func (t *Target[T]) run() {
// listen for add/remove/dispatch events and call functions
for {
select {
case <-t.ctx.Done():
return
case evt := <-t.addListenerCh:
t.addListener(evt.listener, evt.handle)
case evt := <-t.removeListenerCh:
t.removeListener(evt.handle)
case evt := <-t.dispatchCh:
t.dispatch(evt.ctx, evt.event)
}
}
}
// these functions are not thread-safe. They are intended to be called only by "run".
func (t *Target[T]) addListener(listener Listener[T], handle Handle) {
ch := make(chan dispatchEvent[T], 1)
t.listeners[handle] = ch
// start a goroutine to send events to the listener
go func() {
for {
select {
case <-t.ctx.Done():
case evt := <-ch:
listener(evt.ctx, evt.event)
}
}
}()
}
func (t *Target[T]) removeListener(handle Handle) {
ch, ok := t.listeners[handle]
if !ok {
// nothing to do since the listener doesn't exist
return
}
// close the channel to kill the goroutine
close(ch)
delete(t.listeners, handle)
}
func (t *Target[T]) dispatch(ctx context.Context, evt T) {
// loop over all the listeners and send the event to them
for _, ch := range t.listeners {
select {
case <-t.ctx.Done():
return
case ch <- dispatchEvent[T]{ctx: ctx, event: evt}:
}
}
}

View file

@ -0,0 +1,53 @@
package events_test
import (
"context"
"sync/atomic"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/pomerium/pomerium/internal/events"
)
func TestTarget(t *testing.T) {
t.Parallel()
var target events.Target[int64]
t.Cleanup(target.Close)
var calls1, calls2, calls3 atomic.Int64
h1 := target.AddListener(func(_ context.Context, i int64) {
calls1.Add(i)
})
h2 := target.AddListener(func(_ context.Context, i int64) {
calls2.Add(i)
})
h3 := target.AddListener(func(_ context.Context, i int64) {
calls3.Add(i)
})
shouldBe := func(i1, i2, i3 int64) {
t.Helper()
assert.Eventually(t, func() bool { return calls1.Load() == i1 }, time.Second, time.Millisecond)
assert.Eventually(t, func() bool { return calls2.Load() == i2 }, time.Second, time.Millisecond)
assert.Eventually(t, func() bool { return calls3.Load() == i3 }, time.Second, time.Millisecond)
}
target.Dispatch(context.Background(), 1)
shouldBe(1, 1, 1)
target.RemoveListener(h2)
target.Dispatch(context.Background(), 2)
shouldBe(3, 1, 3)
target.RemoveListener(h1)
target.Dispatch(context.Background(), 3)
shouldBe(3, 1, 6)
target.RemoveListener(h3)
target.Dispatch(context.Background(), 4)
shouldBe(3, 1, 6)
}