mirror of
https://github.com/pomerium/pomerium.git
synced 2025-05-15 18:17:49 +02:00
core/config: refactor change dispatcher (#4657)
* core/config: refactor change dispatcher * update test * close listener go routine when context is canceled * use cancel cause * use context * add more time * more time
This commit is contained in:
parent
53573dc046
commit
e0693e54f0
4 changed files with 241 additions and 14 deletions
|
@ -10,6 +10,7 @@ import (
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
"github.com/rs/zerolog"
|
"github.com/rs/zerolog"
|
||||||
|
|
||||||
|
"github.com/pomerium/pomerium/internal/events"
|
||||||
"github.com/pomerium/pomerium/internal/fileutil"
|
"github.com/pomerium/pomerium/internal/fileutil"
|
||||||
"github.com/pomerium/pomerium/internal/log"
|
"github.com/pomerium/pomerium/internal/log"
|
||||||
"github.com/pomerium/pomerium/internal/telemetry/metrics"
|
"github.com/pomerium/pomerium/internal/telemetry/metrics"
|
||||||
|
@ -19,27 +20,27 @@ import (
|
||||||
// A ChangeListener is called when configuration changes.
|
// A ChangeListener is called when configuration changes.
|
||||||
type ChangeListener = func(context.Context, *Config)
|
type ChangeListener = func(context.Context, *Config)
|
||||||
|
|
||||||
|
type changeDispatcherEvent struct {
|
||||||
|
cfg *Config
|
||||||
|
}
|
||||||
|
|
||||||
// A ChangeDispatcher manages listeners on config changes.
|
// A ChangeDispatcher manages listeners on config changes.
|
||||||
type ChangeDispatcher struct {
|
type ChangeDispatcher struct {
|
||||||
sync.Mutex
|
target events.Target[changeDispatcherEvent]
|
||||||
onConfigChangeListeners []ChangeListener
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Trigger triggers a change.
|
// Trigger triggers a change.
|
||||||
func (dispatcher *ChangeDispatcher) Trigger(ctx context.Context, cfg *Config) {
|
func (dispatcher *ChangeDispatcher) Trigger(ctx context.Context, cfg *Config) {
|
||||||
dispatcher.Lock()
|
dispatcher.target.Dispatch(ctx, changeDispatcherEvent{
|
||||||
defer dispatcher.Unlock()
|
cfg: cfg,
|
||||||
|
})
|
||||||
for _, li := range dispatcher.onConfigChangeListeners {
|
|
||||||
li(ctx, cfg)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// OnConfigChange adds a listener.
|
// OnConfigChange adds a listener.
|
||||||
func (dispatcher *ChangeDispatcher) OnConfigChange(_ context.Context, li ChangeListener) {
|
func (dispatcher *ChangeDispatcher) OnConfigChange(_ context.Context, li ChangeListener) {
|
||||||
dispatcher.Lock()
|
dispatcher.target.AddListener(func(ctx context.Context, evt changeDispatcherEvent) {
|
||||||
defer dispatcher.Unlock()
|
li(ctx, evt.cfg)
|
||||||
dispatcher.onConfigChangeListeners = append(dispatcher.onConfigChangeListeners, li)
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// A Source gets configuration.
|
// A Source gets configuration.
|
||||||
|
|
|
@ -3,7 +3,9 @@ package config_test
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
|
"sync/atomic"
|
||||||
"testing"
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
|
@ -13,6 +15,8 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestLayeredConfig(t *testing.T) {
|
func TestLayeredConfig(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
|
|
||||||
t.Run("error on initial build", func(t *testing.T) {
|
t.Run("error on initial build", func(t *testing.T) {
|
||||||
|
@ -33,12 +37,15 @@ func TestLayeredConfig(t *testing.T) {
|
||||||
})
|
})
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
var dst *config.Config
|
var dst atomic.Pointer[config.Config]
|
||||||
|
dst.Store(layered.GetConfig())
|
||||||
layered.OnConfigChange(ctx, func(ctx context.Context, c *config.Config) {
|
layered.OnConfigChange(ctx, func(ctx context.Context, c *config.Config) {
|
||||||
dst = c
|
dst.Store(c)
|
||||||
})
|
})
|
||||||
|
|
||||||
underlying.SetConfig(ctx, &config.Config{Options: &config.Options{DeriveInternalDomainCert: proto.String("b.com")}})
|
underlying.SetConfig(ctx, &config.Config{Options: &config.Options{DeriveInternalDomainCert: proto.String("b.com")}})
|
||||||
assert.Equal(t, "b.com", dst.Options.GetDeriveInternalDomain())
|
assert.Eventually(t, func() bool {
|
||||||
|
return dst.Load().Options.GetDeriveInternalDomain() == "b.com"
|
||||||
|
}, 10*time.Second, time.Millisecond)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
166
internal/events/target.go
Normal file
166
internal/events/target.go
Normal file
|
@ -0,0 +1,166 @@
|
||||||
|
package events
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
"github.com/google/uuid"
|
||||||
|
)
|
||||||
|
|
||||||
|
type (
|
||||||
|
// A Listener is a function that listens for events of type T.
|
||||||
|
Listener[T any] func(ctx context.Context, event T)
|
||||||
|
// A Handle represents a listener.
|
||||||
|
Handle string
|
||||||
|
|
||||||
|
addListenerEvent[T any] struct {
|
||||||
|
listener Listener[T]
|
||||||
|
handle Handle
|
||||||
|
}
|
||||||
|
removeListenerEvent[T any] struct {
|
||||||
|
handle Handle
|
||||||
|
}
|
||||||
|
dispatchEvent[T any] struct {
|
||||||
|
ctx context.Context
|
||||||
|
event T
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
// A Target is a target for events.
|
||||||
|
//
|
||||||
|
// Listeners are added with AddListener with a function to be called when the event occurs.
|
||||||
|
// AddListener returns a Handle which can be used to remove a listener with RemoveListener.
|
||||||
|
//
|
||||||
|
// Dispatch dispatches events to all the registered listeners.
|
||||||
|
//
|
||||||
|
// Target is safe to use in its zero state.
|
||||||
|
//
|
||||||
|
// The first time any method of Target is called a background goroutine is started that handles
|
||||||
|
// any requests and maintains the state of the listeners. Each listener also starts a
|
||||||
|
// separate goroutine so that all listeners can be invoked concurrently.
|
||||||
|
//
|
||||||
|
// The channels to the main goroutine and to the listener goroutines have a size of 1 so typically
|
||||||
|
// methods and dispatches will return immediately. However a slow listener will cause the next event
|
||||||
|
// dispatch to block. This is the opposite behavior from Manager.
|
||||||
|
//
|
||||||
|
// Close will cancel all the goroutines. Subsequent calls to AddListener, RemoveListener, Close and
|
||||||
|
// Dispatch are no-ops.
|
||||||
|
type Target[T any] struct {
|
||||||
|
initOnce sync.Once
|
||||||
|
ctx context.Context
|
||||||
|
cancel context.CancelCauseFunc
|
||||||
|
addListenerCh chan addListenerEvent[T]
|
||||||
|
removeListenerCh chan removeListenerEvent[T]
|
||||||
|
dispatchCh chan dispatchEvent[T]
|
||||||
|
listeners map[Handle]chan dispatchEvent[T]
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddListener adds a listener to the target.
|
||||||
|
func (t *Target[T]) AddListener(listener Listener[T]) Handle {
|
||||||
|
t.init()
|
||||||
|
|
||||||
|
// using a handle is necessary because you can't use a function as a map key.
|
||||||
|
handle := Handle(uuid.NewString())
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-t.ctx.Done():
|
||||||
|
case t.addListenerCh <- addListenerEvent[T]{listener, handle}:
|
||||||
|
}
|
||||||
|
|
||||||
|
return handle
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close closes the event target. This can be called multiple times safely.
|
||||||
|
// Once closed the target cannot be used.
|
||||||
|
func (t *Target[T]) Close() {
|
||||||
|
t.init()
|
||||||
|
|
||||||
|
t.cancel(errors.New("target closed"))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Dispatch dispatches an event to all listeners.
|
||||||
|
func (t *Target[T]) Dispatch(ctx context.Context, evt T) {
|
||||||
|
t.init()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-t.ctx.Done():
|
||||||
|
case t.dispatchCh <- dispatchEvent[T]{ctx: ctx, event: evt}:
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// RemoveListener removes a listener from the target.
|
||||||
|
func (t *Target[T]) RemoveListener(handle Handle) {
|
||||||
|
t.init()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-t.ctx.Done():
|
||||||
|
case t.removeListenerCh <- removeListenerEvent[T]{handle}:
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *Target[T]) init() {
|
||||||
|
t.initOnce.Do(func() {
|
||||||
|
t.ctx, t.cancel = context.WithCancelCause(context.Background())
|
||||||
|
t.addListenerCh = make(chan addListenerEvent[T], 1)
|
||||||
|
t.removeListenerCh = make(chan removeListenerEvent[T], 1)
|
||||||
|
t.dispatchCh = make(chan dispatchEvent[T], 1)
|
||||||
|
t.listeners = map[Handle]chan dispatchEvent[T]{}
|
||||||
|
go t.run()
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *Target[T]) run() {
|
||||||
|
// listen for add/remove/dispatch events and call functions
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-t.ctx.Done():
|
||||||
|
return
|
||||||
|
case evt := <-t.addListenerCh:
|
||||||
|
t.addListener(evt.listener, evt.handle)
|
||||||
|
case evt := <-t.removeListenerCh:
|
||||||
|
t.removeListener(evt.handle)
|
||||||
|
case evt := <-t.dispatchCh:
|
||||||
|
t.dispatch(evt.ctx, evt.event)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// these functions are not thread-safe. They are intended to be called only by "run".
|
||||||
|
|
||||||
|
func (t *Target[T]) addListener(listener Listener[T], handle Handle) {
|
||||||
|
ch := make(chan dispatchEvent[T], 1)
|
||||||
|
t.listeners[handle] = ch
|
||||||
|
// start a goroutine to send events to the listener
|
||||||
|
go func() {
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-t.ctx.Done():
|
||||||
|
case evt := <-ch:
|
||||||
|
listener(evt.ctx, evt.event)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *Target[T]) removeListener(handle Handle) {
|
||||||
|
ch, ok := t.listeners[handle]
|
||||||
|
if !ok {
|
||||||
|
// nothing to do since the listener doesn't exist
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// close the channel to kill the goroutine
|
||||||
|
close(ch)
|
||||||
|
delete(t.listeners, handle)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *Target[T]) dispatch(ctx context.Context, evt T) {
|
||||||
|
// loop over all the listeners and send the event to them
|
||||||
|
for _, ch := range t.listeners {
|
||||||
|
select {
|
||||||
|
case <-t.ctx.Done():
|
||||||
|
return
|
||||||
|
case ch <- dispatchEvent[T]{ctx: ctx, event: evt}:
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
53
internal/events/target_test.go
Normal file
53
internal/events/target_test.go
Normal file
|
@ -0,0 +1,53 @@
|
||||||
|
package events_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"sync/atomic"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
|
||||||
|
"github.com/pomerium/pomerium/internal/events"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestTarget(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
var target events.Target[int64]
|
||||||
|
t.Cleanup(target.Close)
|
||||||
|
|
||||||
|
var calls1, calls2, calls3 atomic.Int64
|
||||||
|
h1 := target.AddListener(func(_ context.Context, i int64) {
|
||||||
|
calls1.Add(i)
|
||||||
|
})
|
||||||
|
h2 := target.AddListener(func(_ context.Context, i int64) {
|
||||||
|
calls2.Add(i)
|
||||||
|
})
|
||||||
|
h3 := target.AddListener(func(_ context.Context, i int64) {
|
||||||
|
calls3.Add(i)
|
||||||
|
})
|
||||||
|
|
||||||
|
shouldBe := func(i1, i2, i3 int64) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
assert.Eventually(t, func() bool { return calls1.Load() == i1 }, time.Second, time.Millisecond)
|
||||||
|
assert.Eventually(t, func() bool { return calls2.Load() == i2 }, time.Second, time.Millisecond)
|
||||||
|
assert.Eventually(t, func() bool { return calls3.Load() == i3 }, time.Second, time.Millisecond)
|
||||||
|
}
|
||||||
|
|
||||||
|
target.Dispatch(context.Background(), 1)
|
||||||
|
shouldBe(1, 1, 1)
|
||||||
|
|
||||||
|
target.RemoveListener(h2)
|
||||||
|
target.Dispatch(context.Background(), 2)
|
||||||
|
shouldBe(3, 1, 3)
|
||||||
|
|
||||||
|
target.RemoveListener(h1)
|
||||||
|
target.Dispatch(context.Background(), 3)
|
||||||
|
shouldBe(3, 1, 6)
|
||||||
|
|
||||||
|
target.RemoveListener(h3)
|
||||||
|
target.Dispatch(context.Background(), 4)
|
||||||
|
shouldBe(3, 1, 6)
|
||||||
|
}
|
Loading…
Add table
Add a link
Reference in a new issue