mirror of
https://github.com/pomerium/pomerium.git
synced 2025-05-15 10:07:47 +02:00
core/config: refactor change dispatcher (#4657)
* core/config: refactor change dispatcher * update test * close listener go routine when context is canceled * use cancel cause * use context * add more time * more time
This commit is contained in:
parent
53573dc046
commit
e0693e54f0
4 changed files with 241 additions and 14 deletions
|
@ -10,6 +10,7 @@ import (
|
|||
"github.com/google/uuid"
|
||||
"github.com/rs/zerolog"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/events"
|
||||
"github.com/pomerium/pomerium/internal/fileutil"
|
||||
"github.com/pomerium/pomerium/internal/log"
|
||||
"github.com/pomerium/pomerium/internal/telemetry/metrics"
|
||||
|
@ -19,27 +20,27 @@ import (
|
|||
// A ChangeListener is called when configuration changes.
|
||||
type ChangeListener = func(context.Context, *Config)
|
||||
|
||||
type changeDispatcherEvent struct {
|
||||
cfg *Config
|
||||
}
|
||||
|
||||
// A ChangeDispatcher manages listeners on config changes.
|
||||
type ChangeDispatcher struct {
|
||||
sync.Mutex
|
||||
onConfigChangeListeners []ChangeListener
|
||||
target events.Target[changeDispatcherEvent]
|
||||
}
|
||||
|
||||
// Trigger triggers a change.
|
||||
func (dispatcher *ChangeDispatcher) Trigger(ctx context.Context, cfg *Config) {
|
||||
dispatcher.Lock()
|
||||
defer dispatcher.Unlock()
|
||||
|
||||
for _, li := range dispatcher.onConfigChangeListeners {
|
||||
li(ctx, cfg)
|
||||
}
|
||||
dispatcher.target.Dispatch(ctx, changeDispatcherEvent{
|
||||
cfg: cfg,
|
||||
})
|
||||
}
|
||||
|
||||
// OnConfigChange adds a listener.
|
||||
func (dispatcher *ChangeDispatcher) OnConfigChange(_ context.Context, li ChangeListener) {
|
||||
dispatcher.Lock()
|
||||
defer dispatcher.Unlock()
|
||||
dispatcher.onConfigChangeListeners = append(dispatcher.onConfigChangeListeners, li)
|
||||
dispatcher.target.AddListener(func(ctx context.Context, evt changeDispatcherEvent) {
|
||||
li(ctx, evt.cfg)
|
||||
})
|
||||
}
|
||||
|
||||
// A Source gets configuration.
|
||||
|
|
|
@ -3,7 +3,9 @@ package config_test
|
|||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
@ -13,6 +15,8 @@ import (
|
|||
)
|
||||
|
||||
func TestLayeredConfig(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("error on initial build", func(t *testing.T) {
|
||||
|
@ -33,12 +37,15 @@ func TestLayeredConfig(t *testing.T) {
|
|||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
var dst *config.Config
|
||||
var dst atomic.Pointer[config.Config]
|
||||
dst.Store(layered.GetConfig())
|
||||
layered.OnConfigChange(ctx, func(ctx context.Context, c *config.Config) {
|
||||
dst = c
|
||||
dst.Store(c)
|
||||
})
|
||||
|
||||
underlying.SetConfig(ctx, &config.Config{Options: &config.Options{DeriveInternalDomainCert: proto.String("b.com")}})
|
||||
assert.Equal(t, "b.com", dst.Options.GetDeriveInternalDomain())
|
||||
assert.Eventually(t, func() bool {
|
||||
return dst.Load().Options.GetDeriveInternalDomain() == "b.com"
|
||||
}, 10*time.Second, time.Millisecond)
|
||||
})
|
||||
}
|
||||
|
|
166
internal/events/target.go
Normal file
166
internal/events/target.go
Normal file
|
@ -0,0 +1,166 @@
|
|||
package events
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"sync"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
type (
|
||||
// A Listener is a function that listens for events of type T.
|
||||
Listener[T any] func(ctx context.Context, event T)
|
||||
// A Handle represents a listener.
|
||||
Handle string
|
||||
|
||||
addListenerEvent[T any] struct {
|
||||
listener Listener[T]
|
||||
handle Handle
|
||||
}
|
||||
removeListenerEvent[T any] struct {
|
||||
handle Handle
|
||||
}
|
||||
dispatchEvent[T any] struct {
|
||||
ctx context.Context
|
||||
event T
|
||||
}
|
||||
)
|
||||
|
||||
// A Target is a target for events.
|
||||
//
|
||||
// Listeners are added with AddListener with a function to be called when the event occurs.
|
||||
// AddListener returns a Handle which can be used to remove a listener with RemoveListener.
|
||||
//
|
||||
// Dispatch dispatches events to all the registered listeners.
|
||||
//
|
||||
// Target is safe to use in its zero state.
|
||||
//
|
||||
// The first time any method of Target is called a background goroutine is started that handles
|
||||
// any requests and maintains the state of the listeners. Each listener also starts a
|
||||
// separate goroutine so that all listeners can be invoked concurrently.
|
||||
//
|
||||
// The channels to the main goroutine and to the listener goroutines have a size of 1 so typically
|
||||
// methods and dispatches will return immediately. However a slow listener will cause the next event
|
||||
// dispatch to block. This is the opposite behavior from Manager.
|
||||
//
|
||||
// Close will cancel all the goroutines. Subsequent calls to AddListener, RemoveListener, Close and
|
||||
// Dispatch are no-ops.
|
||||
type Target[T any] struct {
|
||||
initOnce sync.Once
|
||||
ctx context.Context
|
||||
cancel context.CancelCauseFunc
|
||||
addListenerCh chan addListenerEvent[T]
|
||||
removeListenerCh chan removeListenerEvent[T]
|
||||
dispatchCh chan dispatchEvent[T]
|
||||
listeners map[Handle]chan dispatchEvent[T]
|
||||
}
|
||||
|
||||
// AddListener adds a listener to the target.
|
||||
func (t *Target[T]) AddListener(listener Listener[T]) Handle {
|
||||
t.init()
|
||||
|
||||
// using a handle is necessary because you can't use a function as a map key.
|
||||
handle := Handle(uuid.NewString())
|
||||
|
||||
select {
|
||||
case <-t.ctx.Done():
|
||||
case t.addListenerCh <- addListenerEvent[T]{listener, handle}:
|
||||
}
|
||||
|
||||
return handle
|
||||
}
|
||||
|
||||
// Close closes the event target. This can be called multiple times safely.
|
||||
// Once closed the target cannot be used.
|
||||
func (t *Target[T]) Close() {
|
||||
t.init()
|
||||
|
||||
t.cancel(errors.New("target closed"))
|
||||
}
|
||||
|
||||
// Dispatch dispatches an event to all listeners.
|
||||
func (t *Target[T]) Dispatch(ctx context.Context, evt T) {
|
||||
t.init()
|
||||
|
||||
select {
|
||||
case <-t.ctx.Done():
|
||||
case t.dispatchCh <- dispatchEvent[T]{ctx: ctx, event: evt}:
|
||||
}
|
||||
}
|
||||
|
||||
// RemoveListener removes a listener from the target.
|
||||
func (t *Target[T]) RemoveListener(handle Handle) {
|
||||
t.init()
|
||||
|
||||
select {
|
||||
case <-t.ctx.Done():
|
||||
case t.removeListenerCh <- removeListenerEvent[T]{handle}:
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Target[T]) init() {
|
||||
t.initOnce.Do(func() {
|
||||
t.ctx, t.cancel = context.WithCancelCause(context.Background())
|
||||
t.addListenerCh = make(chan addListenerEvent[T], 1)
|
||||
t.removeListenerCh = make(chan removeListenerEvent[T], 1)
|
||||
t.dispatchCh = make(chan dispatchEvent[T], 1)
|
||||
t.listeners = map[Handle]chan dispatchEvent[T]{}
|
||||
go t.run()
|
||||
})
|
||||
}
|
||||
|
||||
func (t *Target[T]) run() {
|
||||
// listen for add/remove/dispatch events and call functions
|
||||
for {
|
||||
select {
|
||||
case <-t.ctx.Done():
|
||||
return
|
||||
case evt := <-t.addListenerCh:
|
||||
t.addListener(evt.listener, evt.handle)
|
||||
case evt := <-t.removeListenerCh:
|
||||
t.removeListener(evt.handle)
|
||||
case evt := <-t.dispatchCh:
|
||||
t.dispatch(evt.ctx, evt.event)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// these functions are not thread-safe. They are intended to be called only by "run".
|
||||
|
||||
func (t *Target[T]) addListener(listener Listener[T], handle Handle) {
|
||||
ch := make(chan dispatchEvent[T], 1)
|
||||
t.listeners[handle] = ch
|
||||
// start a goroutine to send events to the listener
|
||||
go func() {
|
||||
for {
|
||||
select {
|
||||
case <-t.ctx.Done():
|
||||
case evt := <-ch:
|
||||
listener(evt.ctx, evt.event)
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func (t *Target[T]) removeListener(handle Handle) {
|
||||
ch, ok := t.listeners[handle]
|
||||
if !ok {
|
||||
// nothing to do since the listener doesn't exist
|
||||
return
|
||||
}
|
||||
// close the channel to kill the goroutine
|
||||
close(ch)
|
||||
delete(t.listeners, handle)
|
||||
}
|
||||
|
||||
func (t *Target[T]) dispatch(ctx context.Context, evt T) {
|
||||
// loop over all the listeners and send the event to them
|
||||
for _, ch := range t.listeners {
|
||||
select {
|
||||
case <-t.ctx.Done():
|
||||
return
|
||||
case ch <- dispatchEvent[T]{ctx: ctx, event: evt}:
|
||||
}
|
||||
}
|
||||
}
|
53
internal/events/target_test.go
Normal file
53
internal/events/target_test.go
Normal file
|
@ -0,0 +1,53 @@
|
|||
package events_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/events"
|
||||
)
|
||||
|
||||
func TestTarget(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var target events.Target[int64]
|
||||
t.Cleanup(target.Close)
|
||||
|
||||
var calls1, calls2, calls3 atomic.Int64
|
||||
h1 := target.AddListener(func(_ context.Context, i int64) {
|
||||
calls1.Add(i)
|
||||
})
|
||||
h2 := target.AddListener(func(_ context.Context, i int64) {
|
||||
calls2.Add(i)
|
||||
})
|
||||
h3 := target.AddListener(func(_ context.Context, i int64) {
|
||||
calls3.Add(i)
|
||||
})
|
||||
|
||||
shouldBe := func(i1, i2, i3 int64) {
|
||||
t.Helper()
|
||||
|
||||
assert.Eventually(t, func() bool { return calls1.Load() == i1 }, time.Second, time.Millisecond)
|
||||
assert.Eventually(t, func() bool { return calls2.Load() == i2 }, time.Second, time.Millisecond)
|
||||
assert.Eventually(t, func() bool { return calls3.Load() == i3 }, time.Second, time.Millisecond)
|
||||
}
|
||||
|
||||
target.Dispatch(context.Background(), 1)
|
||||
shouldBe(1, 1, 1)
|
||||
|
||||
target.RemoveListener(h2)
|
||||
target.Dispatch(context.Background(), 2)
|
||||
shouldBe(3, 1, 3)
|
||||
|
||||
target.RemoveListener(h1)
|
||||
target.Dispatch(context.Background(), 3)
|
||||
shouldBe(3, 1, 6)
|
||||
|
||||
target.RemoveListener(h3)
|
||||
target.Dispatch(context.Background(), 4)
|
||||
shouldBe(3, 1, 6)
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue