mirror of
https://github.com/pomerium/pomerium.git
synced 2025-04-30 10:56:28 +02:00
zero: simplify control loop lease retry code (#4979)
zero: simplify lease control loop
This commit is contained in:
parent
a2bf995642
commit
d405a53b90
3 changed files with 8 additions and 213 deletions
|
@ -5,6 +5,7 @@ import (
|
||||||
"context"
|
"context"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/cenkalti/backoff/v4"
|
||||||
"golang.org/x/sync/errgroup"
|
"golang.org/x/sync/errgroup"
|
||||||
|
|
||||||
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
||||||
|
@ -23,7 +24,7 @@ func (c *service) GetDataBrokerServiceClient() databroker.DataBrokerServiceClien
|
||||||
// RunLeased implements the databroker.LeaseHandler interface.
|
// RunLeased implements the databroker.LeaseHandler interface.
|
||||||
func (c *service) RunLeased(ctx context.Context) error {
|
func (c *service) RunLeased(ctx context.Context) error {
|
||||||
eg, ctx := errgroup.WithContext(ctx)
|
eg, ctx := errgroup.WithContext(ctx)
|
||||||
for _, fn := range c.funcs {
|
for _, fn := range append(c.funcs, c.databrokerChangeMonitor) {
|
||||||
fn := fn
|
fn := fn
|
||||||
eg.Go(func() error {
|
eg.Go(func() error {
|
||||||
return fn(ctx)
|
return fn(ctx)
|
||||||
|
@ -42,8 +43,11 @@ func Run(
|
||||||
client: client,
|
client: client,
|
||||||
funcs: funcs,
|
funcs: funcs,
|
||||||
}
|
}
|
||||||
|
b := backoff.NewExponentialBackOff()
|
||||||
|
b.MaxElapsedTime = 0
|
||||||
leaser := databroker.NewLeaser("zero-ctrl", time.Second*30, srv)
|
leaser := databroker.NewLeaser("zero-ctrl", time.Second*30, srv)
|
||||||
return RunWithRestart(ctx, func(ctx context.Context) error {
|
return backoff.Retry(
|
||||||
return leaser.Run(ctx)
|
func() error { return leaser.Run(ctx) },
|
||||||
}, srv.databrokerChangeMonitor)
|
backoff.WithContext(b, ctx),
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,99 +0,0 @@
|
||||||
package leaser
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
"sync"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/cenkalti/backoff/v4"
|
|
||||||
|
|
||||||
"github.com/pomerium/pomerium/internal/log"
|
|
||||||
)
|
|
||||||
|
|
||||||
// RunWithRestart executes execFn.
|
|
||||||
// The execution would be restarted, by means of canceling the context provided to execFn, each time restartFn returns.
|
|
||||||
// the error returned by restartFn is purely informational and does not affect the execution; may be nil.
|
|
||||||
// the loop is stopped when the context provided to RunWithRestart is canceled or execFn returns an error unrelated to its context cancellation.
|
|
||||||
func RunWithRestart(
|
|
||||||
ctx context.Context,
|
|
||||||
execFn func(context.Context) error,
|
|
||||||
restartFn func(context.Context) error,
|
|
||||||
) error {
|
|
||||||
contexts := make(chan context.Context)
|
|
||||||
|
|
||||||
ctx, cancel := context.WithCancel(ctx)
|
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
var wg sync.WaitGroup
|
|
||||||
wg.Add(2)
|
|
||||||
|
|
||||||
var err error
|
|
||||||
go func() {
|
|
||||||
err = restartWithContext(contexts, execFn)
|
|
||||||
cancel()
|
|
||||||
wg.Done()
|
|
||||||
}()
|
|
||||||
go func() {
|
|
||||||
restartContexts(ctx, contexts, restartFn)
|
|
||||||
wg.Done()
|
|
||||||
}()
|
|
||||||
|
|
||||||
wg.Wait()
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
func restartContexts(
|
|
||||||
base context.Context,
|
|
||||||
contexts chan<- context.Context,
|
|
||||||
restartFn func(context.Context) error,
|
|
||||||
) {
|
|
||||||
bo := backoff.NewExponentialBackOff()
|
|
||||||
bo.MaxElapsedTime = 0 // never stop
|
|
||||||
|
|
||||||
ticker := time.NewTicker(bo.InitialInterval)
|
|
||||||
defer ticker.Stop()
|
|
||||||
|
|
||||||
defer close(contexts)
|
|
||||||
for base.Err() == nil {
|
|
||||||
start := time.Now()
|
|
||||||
ctx, cancel := context.WithCancelCause(base)
|
|
||||||
select {
|
|
||||||
case contexts <- ctx:
|
|
||||||
err := restartFn(ctx)
|
|
||||||
cancel(fmt.Errorf("requesting restart: %w", err))
|
|
||||||
case <-base.Done():
|
|
||||||
cancel(fmt.Errorf("parent context canceled: %w", base.Err()))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if time.Since(start) > bo.MaxInterval {
|
|
||||||
bo.Reset()
|
|
||||||
}
|
|
||||||
next := bo.NextBackOff()
|
|
||||||
ticker.Reset(next)
|
|
||||||
|
|
||||||
log.Ctx(ctx).Info().Msgf("restarting zero control loop in %s", next.String())
|
|
||||||
|
|
||||||
select {
|
|
||||||
case <-base.Done():
|
|
||||||
return
|
|
||||||
case <-ticker.C:
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func restartWithContext(
|
|
||||||
contexts <-chan context.Context,
|
|
||||||
execFn func(context.Context) error,
|
|
||||||
) error {
|
|
||||||
var err error
|
|
||||||
for ctx := range contexts {
|
|
||||||
err = execFn(ctx)
|
|
||||||
if ctx.Err() == nil || !errors.Is(err, ctx.Err()) {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return err
|
|
||||||
}
|
|
|
@ -1,110 +0,0 @@
|
||||||
package leaser_test
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
|
|
||||||
"github.com/pomerium/pomerium/internal/zero/leaser"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestRestart(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
for i := 0; i < 20; i++ {
|
|
||||||
t.Run(fmt.Sprintf("quit on error %d", i), func(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
errExpected := errors.New("execFn error")
|
|
||||||
count := 0
|
|
||||||
err := leaser.RunWithRestart(context.Background(),
|
|
||||||
func(context.Context) error {
|
|
||||||
count++
|
|
||||||
if count == 1 {
|
|
||||||
return errExpected
|
|
||||||
}
|
|
||||||
return errors.New("execFn should not be called more than once")
|
|
||||||
},
|
|
||||||
func(ctx context.Context) error {
|
|
||||||
<-ctx.Done()
|
|
||||||
return ctx.Err()
|
|
||||||
},
|
|
||||||
)
|
|
||||||
assert.ErrorIs(t, err, errExpected)
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run(fmt.Sprintf("quit on no error %d", i), func(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
count := 0
|
|
||||||
err := leaser.RunWithRestart(context.Background(),
|
|
||||||
func(context.Context) error {
|
|
||||||
count++
|
|
||||||
if count == 1 {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
return errors.New("execFn should not be called more than once")
|
|
||||||
},
|
|
||||||
func(ctx context.Context) error {
|
|
||||||
<-ctx.Done()
|
|
||||||
return ctx.Err()
|
|
||||||
},
|
|
||||||
)
|
|
||||||
assert.NoError(t, err)
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run(fmt.Sprintf("parent context canceled %d", i), func(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
|
||||||
t.Cleanup(cancel)
|
|
||||||
|
|
||||||
ready := make(chan struct{})
|
|
||||||
err := leaser.RunWithRestart(ctx,
|
|
||||||
func(context.Context) error {
|
|
||||||
<-ready
|
|
||||||
cancel()
|
|
||||||
return ctx.Err()
|
|
||||||
},
|
|
||||||
func(context.Context) error {
|
|
||||||
close(ready)
|
|
||||||
<-ctx.Done()
|
|
||||||
return ctx.Err()
|
|
||||||
},
|
|
||||||
)
|
|
||||||
assert.ErrorIs(t, err, context.Canceled)
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run(fmt.Sprintf("triggers restart %d", i), func(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
|
||||||
t.Cleanup(cancel)
|
|
||||||
|
|
||||||
errExpected := errors.New("execFn error")
|
|
||||||
count := 0
|
|
||||||
ready := make(chan struct{})
|
|
||||||
err := leaser.RunWithRestart(ctx,
|
|
||||||
func(ctx context.Context) error {
|
|
||||||
count++
|
|
||||||
if count == 1 { // wait for us to be restarted
|
|
||||||
close(ready)
|
|
||||||
<-ctx.Done()
|
|
||||||
return ctx.Err()
|
|
||||||
} else if count == 2 { // just quit
|
|
||||||
return errExpected
|
|
||||||
}
|
|
||||||
return errors.New("execFn should not be called more than twice")
|
|
||||||
},
|
|
||||||
func(ctx context.Context) error {
|
|
||||||
<-ready
|
|
||||||
return errors.New("restart required")
|
|
||||||
},
|
|
||||||
)
|
|
||||||
assert.ErrorIs(t, err, errExpected)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
Loading…
Add table
Reference in a new issue