From d405a53b90bac04085bd65bc497399660fa689f8 Mon Sep 17 00:00:00 2001 From: Denis Mishin Date: Fri, 1 Mar 2024 11:36:08 -0500 Subject: [PATCH] zero: simplify control loop lease retry code (#4979) zero: simplify lease control loop --- internal/zero/leaser/leaser.go | 12 ++- internal/zero/leaser/restart.go | 99 ------------------------ internal/zero/leaser/restart_test.go | 110 --------------------------- 3 files changed, 8 insertions(+), 213 deletions(-) delete mode 100644 internal/zero/leaser/restart.go delete mode 100644 internal/zero/leaser/restart_test.go diff --git a/internal/zero/leaser/leaser.go b/internal/zero/leaser/leaser.go index 9018fd239..a9976946c 100644 --- a/internal/zero/leaser/leaser.go +++ b/internal/zero/leaser/leaser.go @@ -5,6 +5,7 @@ import ( "context" "time" + "github.com/cenkalti/backoff/v4" "golang.org/x/sync/errgroup" "github.com/pomerium/pomerium/pkg/grpc/databroker" @@ -23,7 +24,7 @@ func (c *service) GetDataBrokerServiceClient() databroker.DataBrokerServiceClien // RunLeased implements the databroker.LeaseHandler interface. func (c *service) RunLeased(ctx context.Context) error { eg, ctx := errgroup.WithContext(ctx) - for _, fn := range c.funcs { + for _, fn := range append(c.funcs, c.databrokerChangeMonitor) { fn := fn eg.Go(func() error { return fn(ctx) @@ -42,8 +43,11 @@ func Run( client: client, funcs: funcs, } + b := backoff.NewExponentialBackOff() + b.MaxElapsedTime = 0 leaser := databroker.NewLeaser("zero-ctrl", time.Second*30, srv) - return RunWithRestart(ctx, func(ctx context.Context) error { - return leaser.Run(ctx) - }, srv.databrokerChangeMonitor) + return backoff.Retry( + func() error { return leaser.Run(ctx) }, + backoff.WithContext(b, ctx), + ) } diff --git a/internal/zero/leaser/restart.go b/internal/zero/leaser/restart.go deleted file mode 100644 index 6a852c97d..000000000 --- a/internal/zero/leaser/restart.go +++ /dev/null @@ -1,99 +0,0 @@ -package leaser - -import ( - "context" - "errors" - "fmt" - "sync" - "time" - - "github.com/cenkalti/backoff/v4" - - "github.com/pomerium/pomerium/internal/log" -) - -// RunWithRestart executes execFn. -// The execution would be restarted, by means of canceling the context provided to execFn, each time restartFn returns. -// the error returned by restartFn is purely informational and does not affect the execution; may be nil. -// the loop is stopped when the context provided to RunWithRestart is canceled or execFn returns an error unrelated to its context cancellation. -func RunWithRestart( - ctx context.Context, - execFn func(context.Context) error, - restartFn func(context.Context) error, -) error { - contexts := make(chan context.Context) - - ctx, cancel := context.WithCancel(ctx) - defer cancel() - - var wg sync.WaitGroup - wg.Add(2) - - var err error - go func() { - err = restartWithContext(contexts, execFn) - cancel() - wg.Done() - }() - go func() { - restartContexts(ctx, contexts, restartFn) - wg.Done() - }() - - wg.Wait() - return err -} - -func restartContexts( - base context.Context, - contexts chan<- context.Context, - restartFn func(context.Context) error, -) { - bo := backoff.NewExponentialBackOff() - bo.MaxElapsedTime = 0 // never stop - - ticker := time.NewTicker(bo.InitialInterval) - defer ticker.Stop() - - defer close(contexts) - for base.Err() == nil { - start := time.Now() - ctx, cancel := context.WithCancelCause(base) - select { - case contexts <- ctx: - err := restartFn(ctx) - cancel(fmt.Errorf("requesting restart: %w", err)) - case <-base.Done(): - cancel(fmt.Errorf("parent context canceled: %w", base.Err())) - return - } - - if time.Since(start) > bo.MaxInterval { - bo.Reset() - } - next := bo.NextBackOff() - ticker.Reset(next) - - log.Ctx(ctx).Info().Msgf("restarting zero control loop in %s", next.String()) - - select { - case <-base.Done(): - return - case <-ticker.C: - } - } -} - -func restartWithContext( - contexts <-chan context.Context, - execFn func(context.Context) error, -) error { - var err error - for ctx := range contexts { - err = execFn(ctx) - if ctx.Err() == nil || !errors.Is(err, ctx.Err()) { - return err - } - } - return err -} diff --git a/internal/zero/leaser/restart_test.go b/internal/zero/leaser/restart_test.go deleted file mode 100644 index 4e17a52cd..000000000 --- a/internal/zero/leaser/restart_test.go +++ /dev/null @@ -1,110 +0,0 @@ -package leaser_test - -import ( - "context" - "errors" - "fmt" - "testing" - - "github.com/stretchr/testify/assert" - - "github.com/pomerium/pomerium/internal/zero/leaser" -) - -func TestRestart(t *testing.T) { - t.Parallel() - - for i := 0; i < 20; i++ { - t.Run(fmt.Sprintf("quit on error %d", i), func(t *testing.T) { - t.Parallel() - - errExpected := errors.New("execFn error") - count := 0 - err := leaser.RunWithRestart(context.Background(), - func(context.Context) error { - count++ - if count == 1 { - return errExpected - } - return errors.New("execFn should not be called more than once") - }, - func(ctx context.Context) error { - <-ctx.Done() - return ctx.Err() - }, - ) - assert.ErrorIs(t, err, errExpected) - }) - - t.Run(fmt.Sprintf("quit on no error %d", i), func(t *testing.T) { - t.Parallel() - - count := 0 - err := leaser.RunWithRestart(context.Background(), - func(context.Context) error { - count++ - if count == 1 { - return nil - } - return errors.New("execFn should not be called more than once") - }, - func(ctx context.Context) error { - <-ctx.Done() - return ctx.Err() - }, - ) - assert.NoError(t, err) - }) - - t.Run(fmt.Sprintf("parent context canceled %d", i), func(t *testing.T) { - t.Parallel() - - ctx, cancel := context.WithCancel(context.Background()) - t.Cleanup(cancel) - - ready := make(chan struct{}) - err := leaser.RunWithRestart(ctx, - func(context.Context) error { - <-ready - cancel() - return ctx.Err() - }, - func(context.Context) error { - close(ready) - <-ctx.Done() - return ctx.Err() - }, - ) - assert.ErrorIs(t, err, context.Canceled) - }) - - t.Run(fmt.Sprintf("triggers restart %d", i), func(t *testing.T) { - t.Parallel() - - ctx, cancel := context.WithCancel(context.Background()) - t.Cleanup(cancel) - - errExpected := errors.New("execFn error") - count := 0 - ready := make(chan struct{}) - err := leaser.RunWithRestart(ctx, - func(ctx context.Context) error { - count++ - if count == 1 { // wait for us to be restarted - close(ready) - <-ctx.Done() - return ctx.Err() - } else if count == 2 { // just quit - return errExpected - } - return errors.New("execFn should not be called more than twice") - }, - func(ctx context.Context) error { - <-ready - return errors.New("restart required") - }, - ) - assert.ErrorIs(t, err, errExpected) - }) - } -}