mirror of
https://github.com/pomerium/pomerium.git
synced 2025-04-29 18:36:30 +02:00
core/storage: hijack connections for notification listeners (#4806)
This commit is contained in:
parent
4559320463
commit
1780fefa72
2 changed files with 46 additions and 36 deletions
|
@ -74,42 +74,9 @@ func New(dsn string, options ...Option) *Backend {
|
|||
return nil
|
||||
}, backend.cfg.registryTTL/2)
|
||||
|
||||
// listen for changes and broadcast them via signals
|
||||
for _, row := range []struct {
|
||||
signal *signal.Signal
|
||||
channel string
|
||||
}{
|
||||
{backend.onRecordChange, recordChangeNotifyName},
|
||||
{backend.onServiceChange, serviceChangeNotifyName},
|
||||
} {
|
||||
sig, ch := row.signal, row.channel
|
||||
go backend.doPeriodically(func(ctx context.Context) error {
|
||||
_, pool, err := backend.init(backend.closeCtx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
conn, err := pool.Acquire(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer conn.Release()
|
||||
|
||||
_, err = conn.Exec(ctx, `LISTEN `+ch)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = conn.Conn().WaitForNotification(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
sig.Broadcast(ctx)
|
||||
|
||||
return nil
|
||||
return backend.listenForNotifications(ctx)
|
||||
}, time.Millisecond*100)
|
||||
}
|
||||
|
||||
return backend
|
||||
}
|
||||
|
@ -433,6 +400,46 @@ func (backend *Backend) doPeriodically(f func(ctx context.Context) error, dur ti
|
|||
}
|
||||
}
|
||||
|
||||
func (backend *Backend) listenForNotifications(ctx context.Context) error {
|
||||
_, pool, err := backend.init(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error initializing pool for notifications: %w", err)
|
||||
}
|
||||
|
||||
poolConn, err := pool.Acquire(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error acquiring connection from pool for notifications: %w", err)
|
||||
}
|
||||
|
||||
// hijack the connection so the pool can be left for short-lived queries
|
||||
// and so that LISTENs don't leak to other queries
|
||||
conn := poolConn.Hijack()
|
||||
defer conn.Close(ctx)
|
||||
|
||||
for _, ch := range []string{recordChangeNotifyName, serviceChangeNotifyName} {
|
||||
_, err = conn.Exec(ctx, `LISTEN `+ch)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error listening on channel %s for notifications: %w", ch, err)
|
||||
}
|
||||
}
|
||||
|
||||
// for each notification broadcast the signal
|
||||
for {
|
||||
n, err := conn.WaitForNotification(ctx)
|
||||
if err != nil {
|
||||
// on error we'll close the connection to stop listening
|
||||
return fmt.Errorf("error receiving notification: %w", err)
|
||||
}
|
||||
|
||||
switch n.Channel {
|
||||
case recordChangeNotifyName:
|
||||
backend.onRecordChange.Broadcast(ctx)
|
||||
case serviceChangeNotifyName:
|
||||
backend.onServiceChange.Broadcast(ctx)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ParseConfig parses a DSN into a pgxpool.Config.
|
||||
func ParseConfig(dsn string) (*pgxpool.Config, error) {
|
||||
config, err := pgxpool.ParseConfig(dsn)
|
||||
|
|
|
@ -195,6 +195,9 @@ func TestBackend(t *testing.T) {
|
|||
storagetest.TestBackendPatch(t, ctx, backend)
|
||||
})
|
||||
|
||||
assert.Equal(t, int32(0), backend.pool.Stat().AcquiredConns(),
|
||||
"acquired connections should be released")
|
||||
|
||||
return nil
|
||||
}))
|
||||
}
|
||||
|
|
Loading…
Add table
Reference in a new issue