zero: simplify control loop lease retry code (#4979)

zero: simplify lease control loop
This commit is contained in:
Denis Mishin 2024-03-01 11:36:08 -05:00 committed by GitHub
parent a2bf995642
commit d405a53b90
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 8 additions and 213 deletions

View file

@ -5,6 +5,7 @@ import (
"context"
"time"
"github.com/cenkalti/backoff/v4"
"golang.org/x/sync/errgroup"
"github.com/pomerium/pomerium/pkg/grpc/databroker"
@ -23,7 +24,7 @@ func (c *service) GetDataBrokerServiceClient() databroker.DataBrokerServiceClien
// RunLeased implements the databroker.LeaseHandler interface.
func (c *service) RunLeased(ctx context.Context) error {
eg, ctx := errgroup.WithContext(ctx)
for _, fn := range c.funcs {
for _, fn := range append(c.funcs, c.databrokerChangeMonitor) {
fn := fn
eg.Go(func() error {
return fn(ctx)
@ -42,8 +43,11 @@ func Run(
client: client,
funcs: funcs,
}
b := backoff.NewExponentialBackOff()
b.MaxElapsedTime = 0
leaser := databroker.NewLeaser("zero-ctrl", time.Second*30, srv)
return RunWithRestart(ctx, func(ctx context.Context) error {
return leaser.Run(ctx)
}, srv.databrokerChangeMonitor)
return backoff.Retry(
func() error { return leaser.Run(ctx) },
backoff.WithContext(b, ctx),
)
}

View file

@ -1,99 +0,0 @@
package leaser
import (
"context"
"errors"
"fmt"
"sync"
"time"
"github.com/cenkalti/backoff/v4"
"github.com/pomerium/pomerium/internal/log"
)
// RunWithRestart executes execFn.
// The execution would be restarted, by means of canceling the context provided to execFn, each time restartFn returns.
// the error returned by restartFn is purely informational and does not affect the execution; may be nil.
// the loop is stopped when the context provided to RunWithRestart is canceled or execFn returns an error unrelated to its context cancellation.
func RunWithRestart(
ctx context.Context,
execFn func(context.Context) error,
restartFn func(context.Context) error,
) error {
contexts := make(chan context.Context)
ctx, cancel := context.WithCancel(ctx)
defer cancel()
var wg sync.WaitGroup
wg.Add(2)
var err error
go func() {
err = restartWithContext(contexts, execFn)
cancel()
wg.Done()
}()
go func() {
restartContexts(ctx, contexts, restartFn)
wg.Done()
}()
wg.Wait()
return err
}
func restartContexts(
base context.Context,
contexts chan<- context.Context,
restartFn func(context.Context) error,
) {
bo := backoff.NewExponentialBackOff()
bo.MaxElapsedTime = 0 // never stop
ticker := time.NewTicker(bo.InitialInterval)
defer ticker.Stop()
defer close(contexts)
for base.Err() == nil {
start := time.Now()
ctx, cancel := context.WithCancelCause(base)
select {
case contexts <- ctx:
err := restartFn(ctx)
cancel(fmt.Errorf("requesting restart: %w", err))
case <-base.Done():
cancel(fmt.Errorf("parent context canceled: %w", base.Err()))
return
}
if time.Since(start) > bo.MaxInterval {
bo.Reset()
}
next := bo.NextBackOff()
ticker.Reset(next)
log.Ctx(ctx).Info().Msgf("restarting zero control loop in %s", next.String())
select {
case <-base.Done():
return
case <-ticker.C:
}
}
}
func restartWithContext(
contexts <-chan context.Context,
execFn func(context.Context) error,
) error {
var err error
for ctx := range contexts {
err = execFn(ctx)
if ctx.Err() == nil || !errors.Is(err, ctx.Err()) {
return err
}
}
return err
}

View file

@ -1,110 +0,0 @@
package leaser_test
import (
"context"
"errors"
"fmt"
"testing"
"github.com/stretchr/testify/assert"
"github.com/pomerium/pomerium/internal/zero/leaser"
)
func TestRestart(t *testing.T) {
t.Parallel()
for i := 0; i < 20; i++ {
t.Run(fmt.Sprintf("quit on error %d", i), func(t *testing.T) {
t.Parallel()
errExpected := errors.New("execFn error")
count := 0
err := leaser.RunWithRestart(context.Background(),
func(context.Context) error {
count++
if count == 1 {
return errExpected
}
return errors.New("execFn should not be called more than once")
},
func(ctx context.Context) error {
<-ctx.Done()
return ctx.Err()
},
)
assert.ErrorIs(t, err, errExpected)
})
t.Run(fmt.Sprintf("quit on no error %d", i), func(t *testing.T) {
t.Parallel()
count := 0
err := leaser.RunWithRestart(context.Background(),
func(context.Context) error {
count++
if count == 1 {
return nil
}
return errors.New("execFn should not be called more than once")
},
func(ctx context.Context) error {
<-ctx.Done()
return ctx.Err()
},
)
assert.NoError(t, err)
})
t.Run(fmt.Sprintf("parent context canceled %d", i), func(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithCancel(context.Background())
t.Cleanup(cancel)
ready := make(chan struct{})
err := leaser.RunWithRestart(ctx,
func(context.Context) error {
<-ready
cancel()
return ctx.Err()
},
func(context.Context) error {
close(ready)
<-ctx.Done()
return ctx.Err()
},
)
assert.ErrorIs(t, err, context.Canceled)
})
t.Run(fmt.Sprintf("triggers restart %d", i), func(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithCancel(context.Background())
t.Cleanup(cancel)
errExpected := errors.New("execFn error")
count := 0
ready := make(chan struct{})
err := leaser.RunWithRestart(ctx,
func(ctx context.Context) error {
count++
if count == 1 { // wait for us to be restarted
close(ready)
<-ctx.Done()
return ctx.Err()
} else if count == 2 { // just quit
return errExpected
}
return errors.New("execFn should not be called more than twice")
},
func(ctx context.Context) error {
<-ready
return errors.New("restart required")
},
)
assert.ErrorIs(t, err, errExpected)
})
}
}