add retry package (#4458)

This commit is contained in:
Denis Mishin 2023-08-16 12:45:07 -04:00 committed by Kenneth Jenkins
parent 0d29401192
commit 5ddfc74645
5 changed files with 416 additions and 0 deletions

76
internal/retry/config.go Normal file
View file

@ -0,0 +1,76 @@
package retry
import (
"context"
"reflect"
"time"
"github.com/cenkalti/backoff/v4"
)
type config struct {
maxInterval time.Duration
watches []watch
backoff.BackOff
}
// watch is a helper struct to watch multiple channels
type watch struct {
name string
ch reflect.Value
fn func(context.Context) error
this bool
}
// Option configures the retry handler
type Option func(*config)
// WithWatch adds a watch to the retry handler
// that will be triggered when a value is received on the channel
// and the function will be called, also within a retry handler
func WithWatch[T any](name string, ch <-chan T, fn func(context.Context) error) Option {
return func(cfg *config) {
cfg.watches = append(cfg.watches, watch{name: name, ch: reflect.ValueOf(ch), fn: fn, this: false})
}
}
// WithMaxInterval sets the upper bound for the retry handler
func WithMaxInterval(d time.Duration) Option {
return func(cfg *config) {
cfg.maxInterval = d
}
}
func newConfig(opts ...Option) ([]watch, backoff.BackOff) {
cfg := new(config)
for _, opt := range []Option{
WithMaxInterval(time.Minute * 5),
} {
opt(cfg)
}
for _, opt := range opts {
opt(cfg)
}
for i, w := range cfg.watches {
cfg.watches[i].fn = withRetry(cfg, w)
}
bo := backoff.NewExponentialBackOff()
bo.MaxInterval = cfg.maxInterval
bo.MaxElapsedTime = 0
return cfg.watches, bo
}
func withRetry(cfg *config, w watch) func(context.Context) error {
if w.fn == nil {
return func(_ context.Context) error { return nil }
}
return func(ctx context.Context) error {
return Retry(ctx, w.name, w.fn, WithMaxInterval(cfg.maxInterval))
}
}

48
internal/retry/error.go Normal file
View file

@ -0,0 +1,48 @@
package retry
import (
"errors"
"fmt"
)
// TerminalError is an error that should not be retried
type TerminalError interface {
error
IsTerminal()
}
// terminalError is an error that should not be retried
type terminalError struct {
Err error
}
// Error implements error for terminalError
func (e *terminalError) Error() string {
return fmt.Sprintf("terminal error: %v", e.Err)
}
// Unwrap implements errors.Unwrap for terminalError
func (e *terminalError) Unwrap() error {
return e.Err
}
// Is implements errors.Is for terminalError
func (e *terminalError) Is(err error) bool {
//nolint:errorlint
_, ok := err.(*terminalError)
return ok
}
// IsTerminal implements TerminalError for terminalError
func (e *terminalError) IsTerminal() {}
// NewTerminalError creates a new terminal error that cannot be retried
func NewTerminalError(err error) error {
return &terminalError{Err: err}
}
// IsTerminalError returns true if the error is a terminal error
func IsTerminalError(err error) bool {
var te TerminalError
return errors.As(err, &te)
}

View file

@ -0,0 +1,29 @@
package retry_test
import (
"fmt"
"testing"
"github.com/stretchr/testify/require"
"github.com/pomerium/pomerium/internal/retry"
)
type testError string
func (e testError) Error() string {
return string(e)
}
func (e testError) IsTerminal() {}
func TestError(t *testing.T) {
t.Run("local terminal error", func(t *testing.T) {
err := fmt.Errorf("wrap: %w", retry.NewTerminalError(fmt.Errorf("inner")))
require.True(t, retry.IsTerminalError(err))
})
t.Run("external terminal error", func(t *testing.T) {
err := fmt.Errorf("wrap: %w", testError("inner"))
require.True(t, retry.IsTerminalError(err))
})
}

139
internal/retry/retry.go Normal file
View file

@ -0,0 +1,139 @@
// Package retry provides a retry loop with exponential back-off
// while watching arbitrary signal channels for side effects.
package retry
import (
"context"
"fmt"
"reflect"
"time"
)
// Retry retries a function (with exponential back-off) until it succeeds.
// It additionally watches arbitrary channels and calls the handler function when a value is received.
// Handler functions are also retried with exponential back-off.
// If a terminal error is returned from the handler function, the retry loop is aborted.
// If the context is canceled, the retry loop is aborted.
func Retry(
ctx context.Context,
name string,
fn func(context.Context) error,
opts ...Option,
) error {
watches, backoff := newConfig(opts...)
ticker := time.NewTicker(backoff.NextBackOff())
defer ticker.Stop()
s := makeSelect(ctx, watches, name, ticker.C, fn)
restart:
for {
err := fn(ctx)
if err == nil {
return nil
}
if IsTerminalError(err) {
return err
}
backoff.Reset()
backoff:
for {
ticker.Reset(backoff.NextBackOff())
next, err := s.Exec(ctx)
switch next {
case nextRestart:
continue restart
case nextBackoff:
continue backoff
case nextExit:
return err
default:
panic("unreachable")
}
}
}
}
type selectCase struct {
watches []watch
cases []reflect.SelectCase
}
func makeSelect(
ctx context.Context,
watches []watch,
name string,
ch <-chan time.Time,
fn func(context.Context) error,
) *selectCase {
watches = append(watches,
watch{
name: "context",
fn: func(ctx context.Context) error {
// unreachable, the context handler will never be called
// as its channel can only be closed
return ctx.Err()
},
ch: reflect.ValueOf(ctx.Done()),
},
watch{
name: name,
fn: fn,
ch: reflect.ValueOf(ch),
this: true,
},
)
cases := make([]reflect.SelectCase, 0, len(watches))
for _, w := range watches {
cases = append(cases, reflect.SelectCase{
Dir: reflect.SelectRecv,
Chan: w.ch,
})
}
return &selectCase{
watches: watches,
cases: cases,
}
}
type next int
const (
nextRestart next = iota // try again from the beginning
nextBackoff // backoff and try again
nextExit // exit
)
func (s *selectCase) Exec(ctx context.Context) (next, error) {
chosen, _, ok := reflect.Select(s.cases)
if !ok {
return nextExit, fmt.Errorf("watch %s closed", s.watches[chosen].name)
}
w := s.watches[chosen]
err := w.fn(ctx)
if err != nil {
return onError(w, err)
}
if !w.this {
return nextRestart, nil
}
return nextExit, nil
}
func onError(w watch, err error) (next, error) {
if IsTerminalError(err) {
return nextExit, err
}
if w.this {
return nextBackoff, fmt.Errorf("retry %s failed: %w", w.name, err)
}
panic("unreachable, as watches are wrapped in retries and may only return terminal errors")
}

View file

@ -0,0 +1,124 @@
package retry_test
import (
"context"
"errors"
"fmt"
"testing"
"time"
"github.com/stretchr/testify/require"
"github.com/pomerium/pomerium/internal/retry"
)
func TestRetry(t *testing.T) {
t.Parallel()
ctx := context.Background()
limit := retry.WithMaxInterval(time.Second * 5)
t.Run("no error", func(t *testing.T) {
t.Parallel()
err := retry.Retry(ctx, "test", func(_ context.Context) error {
return nil
}, limit)
require.NoError(t, err)
})
t.Run("eventually succeeds", func(t *testing.T) {
t.Parallel()
i := 0
err := retry.Retry(ctx, "test", func(_ context.Context) error {
if i++; i > 2 {
return nil
}
return fmt.Errorf("transient %d", i)
}, limit)
require.NoError(t, err)
})
t.Run("eventually fails", func(t *testing.T) {
t.Parallel()
i := 0
err := retry.Retry(ctx, "test", func(_ context.Context) error {
if i++; i > 2 {
return retry.NewTerminalError(errors.New("the end"))
}
return fmt.Errorf("transient %d", i)
})
require.Error(t, err)
})
t.Run("context canceled", func(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithCancel(ctx)
cancel()
err := retry.Retry(ctx, "test", func(_ context.Context) error {
return fmt.Errorf("retry")
})
require.Error(t, err)
})
t.Run("context canceled after retry", func(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithCancel(ctx)
t.Cleanup(cancel)
err := retry.Retry(ctx, "test", func(_ context.Context) error {
cancel()
return fmt.Errorf("retry")
})
require.Error(t, err)
})
t.Run("success after watch hook", func(t *testing.T) {
t.Parallel()
ch := make(chan struct{}, 1)
ch <- struct{}{}
ok := false
err := retry.Retry(ctx, "test", func(_ context.Context) error {
if ok {
return nil
}
return fmt.Errorf("retry")
}, retry.WithWatch("watch", ch, func(_ context.Context) error {
ok = true
return nil
}), limit)
require.NoError(t, err)
})
t.Run("success after watch hook retried", func(t *testing.T) {
t.Parallel()
ch := make(chan struct{}, 1)
ch <- struct{}{}
ok := false
i := 0
err := retry.Retry(ctx, "test", func(_ context.Context) error {
if ok {
return nil
}
return fmt.Errorf("retry test")
}, retry.WithWatch("watch", ch, func(_ context.Context) error {
if i++; i > 1 {
ok = true
return nil
}
return fmt.Errorf("retry watch")
}), limit)
require.NoError(t, err)
})
t.Run("watch hook fails", func(t *testing.T) {
t.Parallel()
ch := make(chan struct{}, 1)
ch <- struct{}{}
err := retry.Retry(ctx, "test", func(_ context.Context) error {
return fmt.Errorf("retry")
}, retry.WithWatch("watch", ch, func(_ context.Context) error {
return retry.NewTerminalError(fmt.Errorf("watch"))
}), limit)
require.Error(t, err)
})
}