mirror of
https://github.com/pomerium/pomerium.git
synced 2025-06-06 12:52:53 +02:00
add retry package (#4458)
This commit is contained in:
parent
0d29401192
commit
5ddfc74645
5 changed files with 416 additions and 0 deletions
76
internal/retry/config.go
Normal file
76
internal/retry/config.go
Normal file
|
@ -0,0 +1,76 @@
|
||||||
|
package retry
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"reflect"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/cenkalti/backoff/v4"
|
||||||
|
)
|
||||||
|
|
||||||
|
type config struct {
|
||||||
|
maxInterval time.Duration
|
||||||
|
watches []watch
|
||||||
|
|
||||||
|
backoff.BackOff
|
||||||
|
}
|
||||||
|
|
||||||
|
// watch is a helper struct to watch multiple channels
|
||||||
|
type watch struct {
|
||||||
|
name string
|
||||||
|
ch reflect.Value
|
||||||
|
fn func(context.Context) error
|
||||||
|
this bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// Option configures the retry handler
|
||||||
|
type Option func(*config)
|
||||||
|
|
||||||
|
// WithWatch adds a watch to the retry handler
|
||||||
|
// that will be triggered when a value is received on the channel
|
||||||
|
// and the function will be called, also within a retry handler
|
||||||
|
func WithWatch[T any](name string, ch <-chan T, fn func(context.Context) error) Option {
|
||||||
|
return func(cfg *config) {
|
||||||
|
cfg.watches = append(cfg.watches, watch{name: name, ch: reflect.ValueOf(ch), fn: fn, this: false})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithMaxInterval sets the upper bound for the retry handler
|
||||||
|
func WithMaxInterval(d time.Duration) Option {
|
||||||
|
return func(cfg *config) {
|
||||||
|
cfg.maxInterval = d
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func newConfig(opts ...Option) ([]watch, backoff.BackOff) {
|
||||||
|
cfg := new(config)
|
||||||
|
for _, opt := range []Option{
|
||||||
|
WithMaxInterval(time.Minute * 5),
|
||||||
|
} {
|
||||||
|
opt(cfg)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, opt := range opts {
|
||||||
|
opt(cfg)
|
||||||
|
}
|
||||||
|
|
||||||
|
for i, w := range cfg.watches {
|
||||||
|
cfg.watches[i].fn = withRetry(cfg, w)
|
||||||
|
}
|
||||||
|
|
||||||
|
bo := backoff.NewExponentialBackOff()
|
||||||
|
bo.MaxInterval = cfg.maxInterval
|
||||||
|
bo.MaxElapsedTime = 0
|
||||||
|
|
||||||
|
return cfg.watches, bo
|
||||||
|
}
|
||||||
|
|
||||||
|
func withRetry(cfg *config, w watch) func(context.Context) error {
|
||||||
|
if w.fn == nil {
|
||||||
|
return func(_ context.Context) error { return nil }
|
||||||
|
}
|
||||||
|
|
||||||
|
return func(ctx context.Context) error {
|
||||||
|
return Retry(ctx, w.name, w.fn, WithMaxInterval(cfg.maxInterval))
|
||||||
|
}
|
||||||
|
}
|
48
internal/retry/error.go
Normal file
48
internal/retry/error.go
Normal file
|
@ -0,0 +1,48 @@
|
||||||
|
package retry
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TerminalError is an error that should not be retried
|
||||||
|
type TerminalError interface {
|
||||||
|
error
|
||||||
|
IsTerminal()
|
||||||
|
}
|
||||||
|
|
||||||
|
// terminalError is an error that should not be retried
|
||||||
|
type terminalError struct {
|
||||||
|
Err error
|
||||||
|
}
|
||||||
|
|
||||||
|
// Error implements error for terminalError
|
||||||
|
func (e *terminalError) Error() string {
|
||||||
|
return fmt.Sprintf("terminal error: %v", e.Err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Unwrap implements errors.Unwrap for terminalError
|
||||||
|
func (e *terminalError) Unwrap() error {
|
||||||
|
return e.Err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Is implements errors.Is for terminalError
|
||||||
|
func (e *terminalError) Is(err error) bool {
|
||||||
|
//nolint:errorlint
|
||||||
|
_, ok := err.(*terminalError)
|
||||||
|
return ok
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsTerminal implements TerminalError for terminalError
|
||||||
|
func (e *terminalError) IsTerminal() {}
|
||||||
|
|
||||||
|
// NewTerminalError creates a new terminal error that cannot be retried
|
||||||
|
func NewTerminalError(err error) error {
|
||||||
|
return &terminalError{Err: err}
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsTerminalError returns true if the error is a terminal error
|
||||||
|
func IsTerminalError(err error) bool {
|
||||||
|
var te TerminalError
|
||||||
|
return errors.As(err, &te)
|
||||||
|
}
|
29
internal/retry/error_test.go
Normal file
29
internal/retry/error_test.go
Normal file
|
@ -0,0 +1,29 @@
|
||||||
|
package retry_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
|
"github.com/pomerium/pomerium/internal/retry"
|
||||||
|
)
|
||||||
|
|
||||||
|
type testError string
|
||||||
|
|
||||||
|
func (e testError) Error() string {
|
||||||
|
return string(e)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e testError) IsTerminal() {}
|
||||||
|
|
||||||
|
func TestError(t *testing.T) {
|
||||||
|
t.Run("local terminal error", func(t *testing.T) {
|
||||||
|
err := fmt.Errorf("wrap: %w", retry.NewTerminalError(fmt.Errorf("inner")))
|
||||||
|
require.True(t, retry.IsTerminalError(err))
|
||||||
|
})
|
||||||
|
t.Run("external terminal error", func(t *testing.T) {
|
||||||
|
err := fmt.Errorf("wrap: %w", testError("inner"))
|
||||||
|
require.True(t, retry.IsTerminalError(err))
|
||||||
|
})
|
||||||
|
}
|
139
internal/retry/retry.go
Normal file
139
internal/retry/retry.go
Normal file
|
@ -0,0 +1,139 @@
|
||||||
|
// Package retry provides a retry loop with exponential back-off
|
||||||
|
// while watching arbitrary signal channels for side effects.
|
||||||
|
package retry
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"reflect"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Retry retries a function (with exponential back-off) until it succeeds.
|
||||||
|
// It additionally watches arbitrary channels and calls the handler function when a value is received.
|
||||||
|
// Handler functions are also retried with exponential back-off.
|
||||||
|
// If a terminal error is returned from the handler function, the retry loop is aborted.
|
||||||
|
// If the context is canceled, the retry loop is aborted.
|
||||||
|
func Retry(
|
||||||
|
ctx context.Context,
|
||||||
|
name string,
|
||||||
|
fn func(context.Context) error,
|
||||||
|
opts ...Option,
|
||||||
|
) error {
|
||||||
|
watches, backoff := newConfig(opts...)
|
||||||
|
ticker := time.NewTicker(backoff.NextBackOff())
|
||||||
|
defer ticker.Stop()
|
||||||
|
|
||||||
|
s := makeSelect(ctx, watches, name, ticker.C, fn)
|
||||||
|
|
||||||
|
restart:
|
||||||
|
for {
|
||||||
|
err := fn(ctx)
|
||||||
|
if err == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if IsTerminalError(err) {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
backoff.Reset()
|
||||||
|
backoff:
|
||||||
|
for {
|
||||||
|
ticker.Reset(backoff.NextBackOff())
|
||||||
|
|
||||||
|
next, err := s.Exec(ctx)
|
||||||
|
switch next {
|
||||||
|
case nextRestart:
|
||||||
|
continue restart
|
||||||
|
case nextBackoff:
|
||||||
|
continue backoff
|
||||||
|
case nextExit:
|
||||||
|
return err
|
||||||
|
default:
|
||||||
|
panic("unreachable")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type selectCase struct {
|
||||||
|
watches []watch
|
||||||
|
cases []reflect.SelectCase
|
||||||
|
}
|
||||||
|
|
||||||
|
func makeSelect(
|
||||||
|
ctx context.Context,
|
||||||
|
watches []watch,
|
||||||
|
name string,
|
||||||
|
ch <-chan time.Time,
|
||||||
|
fn func(context.Context) error,
|
||||||
|
) *selectCase {
|
||||||
|
watches = append(watches,
|
||||||
|
watch{
|
||||||
|
name: "context",
|
||||||
|
fn: func(ctx context.Context) error {
|
||||||
|
// unreachable, the context handler will never be called
|
||||||
|
// as its channel can only be closed
|
||||||
|
return ctx.Err()
|
||||||
|
},
|
||||||
|
ch: reflect.ValueOf(ctx.Done()),
|
||||||
|
},
|
||||||
|
watch{
|
||||||
|
name: name,
|
||||||
|
fn: fn,
|
||||||
|
ch: reflect.ValueOf(ch),
|
||||||
|
this: true,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
cases := make([]reflect.SelectCase, 0, len(watches))
|
||||||
|
for _, w := range watches {
|
||||||
|
cases = append(cases, reflect.SelectCase{
|
||||||
|
Dir: reflect.SelectRecv,
|
||||||
|
Chan: w.ch,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
return &selectCase{
|
||||||
|
watches: watches,
|
||||||
|
cases: cases,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type next int
|
||||||
|
|
||||||
|
const (
|
||||||
|
nextRestart next = iota // try again from the beginning
|
||||||
|
nextBackoff // backoff and try again
|
||||||
|
nextExit // exit
|
||||||
|
)
|
||||||
|
|
||||||
|
func (s *selectCase) Exec(ctx context.Context) (next, error) {
|
||||||
|
chosen, _, ok := reflect.Select(s.cases)
|
||||||
|
if !ok {
|
||||||
|
return nextExit, fmt.Errorf("watch %s closed", s.watches[chosen].name)
|
||||||
|
}
|
||||||
|
|
||||||
|
w := s.watches[chosen]
|
||||||
|
|
||||||
|
err := w.fn(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return onError(w, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !w.this {
|
||||||
|
return nextRestart, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return nextExit, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func onError(w watch, err error) (next, error) {
|
||||||
|
if IsTerminalError(err) {
|
||||||
|
return nextExit, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if w.this {
|
||||||
|
return nextBackoff, fmt.Errorf("retry %s failed: %w", w.name, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
panic("unreachable, as watches are wrapped in retries and may only return terminal errors")
|
||||||
|
}
|
124
internal/retry/retry_test.go
Normal file
124
internal/retry/retry_test.go
Normal file
|
@ -0,0 +1,124 @@
|
||||||
|
package retry_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
|
"github.com/pomerium/pomerium/internal/retry"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestRetry(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
limit := retry.WithMaxInterval(time.Second * 5)
|
||||||
|
|
||||||
|
t.Run("no error", func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
err := retry.Retry(ctx, "test", func(_ context.Context) error {
|
||||||
|
return nil
|
||||||
|
}, limit)
|
||||||
|
require.NoError(t, err)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("eventually succeeds", func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
i := 0
|
||||||
|
err := retry.Retry(ctx, "test", func(_ context.Context) error {
|
||||||
|
if i++; i > 2 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return fmt.Errorf("transient %d", i)
|
||||||
|
}, limit)
|
||||||
|
require.NoError(t, err)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("eventually fails", func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
i := 0
|
||||||
|
err := retry.Retry(ctx, "test", func(_ context.Context) error {
|
||||||
|
if i++; i > 2 {
|
||||||
|
return retry.NewTerminalError(errors.New("the end"))
|
||||||
|
}
|
||||||
|
return fmt.Errorf("transient %d", i)
|
||||||
|
})
|
||||||
|
require.Error(t, err)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("context canceled", func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
ctx, cancel := context.WithCancel(ctx)
|
||||||
|
cancel()
|
||||||
|
err := retry.Retry(ctx, "test", func(_ context.Context) error {
|
||||||
|
return fmt.Errorf("retry")
|
||||||
|
})
|
||||||
|
require.Error(t, err)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("context canceled after retry", func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
ctx, cancel := context.WithCancel(ctx)
|
||||||
|
t.Cleanup(cancel)
|
||||||
|
err := retry.Retry(ctx, "test", func(_ context.Context) error {
|
||||||
|
cancel()
|
||||||
|
return fmt.Errorf("retry")
|
||||||
|
})
|
||||||
|
require.Error(t, err)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("success after watch hook", func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
ch := make(chan struct{}, 1)
|
||||||
|
ch <- struct{}{}
|
||||||
|
ok := false
|
||||||
|
err := retry.Retry(ctx, "test", func(_ context.Context) error {
|
||||||
|
if ok {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return fmt.Errorf("retry")
|
||||||
|
}, retry.WithWatch("watch", ch, func(_ context.Context) error {
|
||||||
|
ok = true
|
||||||
|
return nil
|
||||||
|
}), limit)
|
||||||
|
require.NoError(t, err)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("success after watch hook retried", func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
ch := make(chan struct{}, 1)
|
||||||
|
ch <- struct{}{}
|
||||||
|
ok := false
|
||||||
|
i := 0
|
||||||
|
err := retry.Retry(ctx, "test", func(_ context.Context) error {
|
||||||
|
if ok {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return fmt.Errorf("retry test")
|
||||||
|
}, retry.WithWatch("watch", ch, func(_ context.Context) error {
|
||||||
|
if i++; i > 1 {
|
||||||
|
ok = true
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return fmt.Errorf("retry watch")
|
||||||
|
}), limit)
|
||||||
|
require.NoError(t, err)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("watch hook fails", func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
ch := make(chan struct{}, 1)
|
||||||
|
ch <- struct{}{}
|
||||||
|
err := retry.Retry(ctx, "test", func(_ context.Context) error {
|
||||||
|
return fmt.Errorf("retry")
|
||||||
|
}, retry.WithWatch("watch", ch, func(_ context.Context) error {
|
||||||
|
return retry.NewTerminalError(fmt.Errorf("watch"))
|
||||||
|
}), limit)
|
||||||
|
require.Error(t, err)
|
||||||
|
})
|
||||||
|
}
|
Loading…
Add table
Add a link
Reference in a new issue