mirror of
https://github.com/pomerium/pomerium.git
synced 2025-04-29 10:26:29 +02:00
zero: simplify control loop lease retry code (#4979)
zero: simplify lease control loop
This commit is contained in:
parent
a2bf995642
commit
d405a53b90
3 changed files with 8 additions and 213 deletions
|
@ -5,6 +5,7 @@ import (
|
|||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/cenkalti/backoff/v4"
|
||||
"golang.org/x/sync/errgroup"
|
||||
|
||||
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
||||
|
@ -23,7 +24,7 @@ func (c *service) GetDataBrokerServiceClient() databroker.DataBrokerServiceClien
|
|||
// RunLeased implements the databroker.LeaseHandler interface.
|
||||
func (c *service) RunLeased(ctx context.Context) error {
|
||||
eg, ctx := errgroup.WithContext(ctx)
|
||||
for _, fn := range c.funcs {
|
||||
for _, fn := range append(c.funcs, c.databrokerChangeMonitor) {
|
||||
fn := fn
|
||||
eg.Go(func() error {
|
||||
return fn(ctx)
|
||||
|
@ -42,8 +43,11 @@ func Run(
|
|||
client: client,
|
||||
funcs: funcs,
|
||||
}
|
||||
b := backoff.NewExponentialBackOff()
|
||||
b.MaxElapsedTime = 0
|
||||
leaser := databroker.NewLeaser("zero-ctrl", time.Second*30, srv)
|
||||
return RunWithRestart(ctx, func(ctx context.Context) error {
|
||||
return leaser.Run(ctx)
|
||||
}, srv.databrokerChangeMonitor)
|
||||
return backoff.Retry(
|
||||
func() error { return leaser.Run(ctx) },
|
||||
backoff.WithContext(b, ctx),
|
||||
)
|
||||
}
|
||||
|
|
|
@ -1,99 +0,0 @@
|
|||
package leaser
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/cenkalti/backoff/v4"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/log"
|
||||
)
|
||||
|
||||
// RunWithRestart executes execFn.
|
||||
// The execution would be restarted, by means of canceling the context provided to execFn, each time restartFn returns.
|
||||
// the error returned by restartFn is purely informational and does not affect the execution; may be nil.
|
||||
// the loop is stopped when the context provided to RunWithRestart is canceled or execFn returns an error unrelated to its context cancellation.
|
||||
func RunWithRestart(
|
||||
ctx context.Context,
|
||||
execFn func(context.Context) error,
|
||||
restartFn func(context.Context) error,
|
||||
) error {
|
||||
contexts := make(chan context.Context)
|
||||
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
defer cancel()
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(2)
|
||||
|
||||
var err error
|
||||
go func() {
|
||||
err = restartWithContext(contexts, execFn)
|
||||
cancel()
|
||||
wg.Done()
|
||||
}()
|
||||
go func() {
|
||||
restartContexts(ctx, contexts, restartFn)
|
||||
wg.Done()
|
||||
}()
|
||||
|
||||
wg.Wait()
|
||||
return err
|
||||
}
|
||||
|
||||
func restartContexts(
|
||||
base context.Context,
|
||||
contexts chan<- context.Context,
|
||||
restartFn func(context.Context) error,
|
||||
) {
|
||||
bo := backoff.NewExponentialBackOff()
|
||||
bo.MaxElapsedTime = 0 // never stop
|
||||
|
||||
ticker := time.NewTicker(bo.InitialInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
defer close(contexts)
|
||||
for base.Err() == nil {
|
||||
start := time.Now()
|
||||
ctx, cancel := context.WithCancelCause(base)
|
||||
select {
|
||||
case contexts <- ctx:
|
||||
err := restartFn(ctx)
|
||||
cancel(fmt.Errorf("requesting restart: %w", err))
|
||||
case <-base.Done():
|
||||
cancel(fmt.Errorf("parent context canceled: %w", base.Err()))
|
||||
return
|
||||
}
|
||||
|
||||
if time.Since(start) > bo.MaxInterval {
|
||||
bo.Reset()
|
||||
}
|
||||
next := bo.NextBackOff()
|
||||
ticker.Reset(next)
|
||||
|
||||
log.Ctx(ctx).Info().Msgf("restarting zero control loop in %s", next.String())
|
||||
|
||||
select {
|
||||
case <-base.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// restartWithContext invokes execFn once for every context received from
// contexts.
//
// It keeps going only while a run ends because of its own context: the
// context must be done and the returned error must match ctx.Err(). Any other
// outcome (including a nil error) ends the loop and is returned right away.
// When the channel is closed, the result of the most recent run is returned,
// or nil if execFn was never called.
func restartWithContext(
	contexts <-chan context.Context,
	execFn func(context.Context) error,
) error {
	var lastErr error
	for {
		ctx, ok := <-contexts
		if !ok {
			// Producer is gone; report whatever the final run yielded.
			return lastErr
		}

		lastErr = execFn(ctx)

		// Continue only if the run was stopped by its own context.
		if ctxErr := ctx.Err(); ctxErr != nil && errors.Is(lastErr, ctxErr) {
			continue
		}
		return lastErr
	}
}
|
|
@ -1,110 +0,0 @@
|
|||
package leaser_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/zero/leaser"
|
||||
)
|
||||
|
||||
func TestRestart(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
for i := 0; i < 20; i++ {
|
||||
t.Run(fmt.Sprintf("quit on error %d", i), func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
errExpected := errors.New("execFn error")
|
||||
count := 0
|
||||
err := leaser.RunWithRestart(context.Background(),
|
||||
func(context.Context) error {
|
||||
count++
|
||||
if count == 1 {
|
||||
return errExpected
|
||||
}
|
||||
return errors.New("execFn should not be called more than once")
|
||||
},
|
||||
func(ctx context.Context) error {
|
||||
<-ctx.Done()
|
||||
return ctx.Err()
|
||||
},
|
||||
)
|
||||
assert.ErrorIs(t, err, errExpected)
|
||||
})
|
||||
|
||||
t.Run(fmt.Sprintf("quit on no error %d", i), func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
count := 0
|
||||
err := leaser.RunWithRestart(context.Background(),
|
||||
func(context.Context) error {
|
||||
count++
|
||||
if count == 1 {
|
||||
return nil
|
||||
}
|
||||
return errors.New("execFn should not be called more than once")
|
||||
},
|
||||
func(ctx context.Context) error {
|
||||
<-ctx.Done()
|
||||
return ctx.Err()
|
||||
},
|
||||
)
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run(fmt.Sprintf("parent context canceled %d", i), func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
t.Cleanup(cancel)
|
||||
|
||||
ready := make(chan struct{})
|
||||
err := leaser.RunWithRestart(ctx,
|
||||
func(context.Context) error {
|
||||
<-ready
|
||||
cancel()
|
||||
return ctx.Err()
|
||||
},
|
||||
func(context.Context) error {
|
||||
close(ready)
|
||||
<-ctx.Done()
|
||||
return ctx.Err()
|
||||
},
|
||||
)
|
||||
assert.ErrorIs(t, err, context.Canceled)
|
||||
})
|
||||
|
||||
t.Run(fmt.Sprintf("triggers restart %d", i), func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
t.Cleanup(cancel)
|
||||
|
||||
errExpected := errors.New("execFn error")
|
||||
count := 0
|
||||
ready := make(chan struct{})
|
||||
err := leaser.RunWithRestart(ctx,
|
||||
func(ctx context.Context) error {
|
||||
count++
|
||||
if count == 1 { // wait for us to be restarted
|
||||
close(ready)
|
||||
<-ctx.Done()
|
||||
return ctx.Err()
|
||||
} else if count == 2 { // just quit
|
||||
return errExpected
|
||||
}
|
||||
return errors.New("execFn should not be called more than twice")
|
||||
},
|
||||
func(ctx context.Context) error {
|
||||
<-ready
|
||||
return errors.New("restart required")
|
||||
},
|
||||
)
|
||||
assert.ErrorIs(t, err, errExpected)
|
||||
})
|
||||
}
|
||||
}
|
Loading…
Add table
Reference in a new issue