pomerium/authenticate/circuit/breaker.go
Bobby 0766725ff8
proxy : add HTTP endpoint support (#13)
proxy : Add HTTP (insecure) endpoint support, closes #11.

* Fix typos
* Fixed additional typos and an ineffectual assignment
* Update route configuration in docs
2019-01-15 15:06:16 -08:00

329 lines
7.9 KiB
Go

// Package circuit implements the Circuit Breaker pattern.
// https://docs.microsoft.com/en-us/azure/architecture/patterns/circuit-breaker
package circuit // import "github.com/pomerium/pomerium/internal/circuit"
import (
"fmt"
"math"
"math/rand"
"sync"
"time"
"github.com/benbjohnson/clock"
)
// State identifies which mode a Breaker is in.
type State int

// The modes a Breaker moves between.
const (
	// StateClosed lets every request through.
	StateClosed State = 0
	// StateHalfOpen admits at most HalfOpenConcurrentRequests requests at a
	// time; enough consecutive successes close the breaker, while a failure
	// reopens it.
	StateHalfOpen State = 1
	// StateOpen rejects every request; once the backoff has expired the
	// breaker becomes half-open the next time its state is checked.
	StateOpen State = 2
)
// ShouldTripFunc reports, from the current Counts, whether a closed breaker
// should trip into the open state.
type ShouldTripFunc func(Counts) bool

// ShouldResetFunc reports, from the current Counts, whether a half-open
// breaker should reset to the closed state.
type ShouldResetFunc func(Counts) bool

// BackoffDurationFunc computes, from the current Counts, how long the breaker
// stays open before it may become half-open.
type BackoffDurationFunc func(Counts) time.Duration

// StateChangeHook is called with the previous and the new State whenever a
// Breaker changes state.
type StateChangeHook func(prev, to State)

// BackoffHook is called whenever a backoff is set, with the chosen duration
// and the time at which that backoff expires.
type BackoffHook func(duration time.Duration, reset time.Time)
var (
	// DefaultShouldTripFunc is the ShouldTripFunc used when none is
	// configured: it trips the breaker after three consecutive failures.
	DefaultShouldTripFunc = func(c Counts) bool {
		return c.ConsecutiveFailures >= 3
	}

	// DefaultShouldResetFunc is the ShouldResetFunc used when none is
	// configured: it resets the breaker after three consecutive successes.
	DefaultShouldResetFunc = func(c Counts) bool {
		return c.ConsecutiveSuccesses >= 3
	}

	// DefaultBackoffDurationFunc is the BackoffDurationFunc used when none is
	// configured: full-jitter exponential backoff from a 500ms base, capped
	// at 100s.
	DefaultBackoffDurationFunc = ExponentialBackoffDuration(100*time.Second, 500*time.Millisecond)
)
// ErrOpenState is the error Breaker.Call returns, without running the wrapped
// function, when the breaker is open or when it is half-open and already has
// its maximum number of requests in flight.
type ErrOpenState struct{}

// Error implements the error interface.
func (*ErrOpenState) Error() string {
	const msg = "circuit breaker is open"
	return msg
}
// ExponentialBackoffDuration returns a function, usable as a
// BackoffDurationFunc, that implements exponential backoff with "full jitter":
//
//	sleep = random_between(0, min(maxBackoff, baseTimeout * 2^ConsecutiveFailures))
//
// (see https://aws.amazon.com/blogs/architecture/exponential-backoff-and-jitter/).
// Randomness comes from the package-level math/rand source, so individual
// results are not reproducible; for non-negative arguments each result lies
// in [0, min(maxBackoff, baseTimeout*2^ConsecutiveFailures)).
func ExponentialBackoffDuration(maxBackoff, baseTimeout time.Duration) func(Counts) time.Duration {
	// For any baseTimeout >= 1ns, baseTimeout*2^64 already exceeds every
	// representable Duration, so clamping the exponent here never changes
	// which side of the min wins. It does keep the product finite: without
	// it, a zero baseTimeout combined with an exponent >= 1024 computes
	// 0 * +Inf = NaN, and converting NaN to a Duration is
	// implementation-defined (a huge negative value on amd64).
	const maxExponent = 64
	return func(counts Counts) time.Duration {
		attempt := math.Min(float64(counts.ConsecutiveFailures), maxExponent)
		backoff := math.Min(float64(maxBackoff), float64(baseTimeout)*math.Exp2(attempt))
		return time.Duration(rand.Float64() * backoff)
	}
}
// String implements fmt.Stringer. Known states map to "closed", "half-open"
// and "open"; any other value is rendered as "unknown state: <n>".
func (s State) String() string {
	names := [...]string{
		StateClosed:   "closed",
		StateHalfOpen: "half-open",
		StateOpen:     "open",
	}
	if s >= 0 && int(s) < len(names) {
		return names[s]
	}
	return fmt.Sprintf("unknown state: %d", s)
}
// Counts tracks the number of in-flight requests together with the current
// streak of consecutive successes or failures.
type Counts struct {
	CurrentRequests      int
	ConsecutiveSuccesses int
	ConsecutiveFailures  int
}

// onRequest records that a request has started.
func (c *Counts) onRequest() {
	c.CurrentRequests++
}

// afterRequest records that a request has finished, whatever its outcome.
func (c *Counts) afterRequest() {
	c.CurrentRequests--
}

// onSuccess extends the success streak and breaks any failure streak.
func (c *Counts) onSuccess() {
	c.ConsecutiveSuccesses, c.ConsecutiveFailures = c.ConsecutiveSuccesses+1, 0
}

// onFailure extends the failure streak and breaks any success streak.
func (c *Counts) onFailure() {
	c.ConsecutiveSuccesses, c.ConsecutiveFailures = 0, c.ConsecutiveFailures+1
}

// clear resets both streaks. CurrentRequests is deliberately left alone,
// since in-flight requests still decrement it when they finish.
func (c *Counts) clear() {
	c.ConsecutiveSuccesses, c.ConsecutiveFailures = 0, 0
}
// Options configures a Breaker built by NewBreaker. Every field is optional:
// NewBreaker substitutes a default for each one that is left unset, and also
// accepts a nil *Options.
type Options struct {
	// HalfOpenConcurrentRequests caps how many requests may be in flight
	// while the breaker is half-open. Values <= 0 mean 1.
	HalfOpenConcurrentRequests int
	// ShouldTripFunc decides, from the current Counts, when a closed breaker
	// trips open. Defaults to DefaultShouldTripFunc.
	ShouldTripFunc ShouldTripFunc
	// ShouldResetFunc decides, from the current Counts, when a half-open
	// breaker closes again and lets all requests through. Defaults to
	// DefaultShouldResetFunc.
	ShouldResetFunc ShouldResetFunc
	// BackoffDurationFunc computes, from the current Counts, how long the
	// breaker stays open. Defaults to DefaultBackoffDurationFunc.
	BackoffDurationFunc BackoffDurationFunc

	// OnStateChange, if set, is called on every state change.
	OnStateChange StateChangeHook
	// OnBackoff, if set, is called whenever a backoff is set, with its
	// duration and expiry time.
	OnBackoff BackoffHook

	// TestClock, if set, replaces the clock from clock.New(); intended for
	// tests.
	TestClock clock.Clock
}
// Breaker is a circuit-breaker state machine that rejects requests while
// recent failures suggest they are likely to fail. Build one with NewBreaker,
// which fills in the function fields below.
type Breaker struct {
	// Behaviour, set by NewBreaker from Options or the package defaults.
	halfOpenRequests    int
	shouldTripFunc      ShouldTripFunc
	shouldResetFunc     ShouldResetFunc
	backoffDurationFunc BackoffDurationFunc

	// Optional observers; nil means not set.
	onStateChange StateChangeHook
	onBackoff     BackoffHook

	// Time source; replaceable through Options.TestClock for tests.
	clock clock.Clock

	// mutex is held by beforeRequest and afterRequest while they read and
	// update the mutable fields that follow.
	mutex          sync.Mutex
	state          State
	counts         Counts
	backoffExpires time.Time
	// generation is incremented on every state change; afterRequest drops
	// outcomes of requests that began in an earlier generation.
	generation int
}
// NewBreaker returns a closed Breaker configured from opts.
//
// opts may be nil. Any field left unset falls back to its default:
//   - one concurrent half-open request
//   - DefaultShouldTripFunc
//   - DefaultShouldResetFunc
//   - DefaultBackoffDurationFunc
//   - the clock returned by clock.New()
func NewBreaker(opts *Options) *Breaker {
	if opts == nil {
		opts = &Options{}
	}

	b := &Breaker{
		halfOpenRequests:    1,
		shouldTripFunc:      DefaultShouldTripFunc,
		shouldResetFunc:     DefaultShouldResetFunc,
		backoffDurationFunc: DefaultBackoffDurationFunc,
		onStateChange:       opts.OnStateChange,
		onBackoff:           opts.OnBackoff,
	}

	if n := opts.HalfOpenConcurrentRequests; n > 0 {
		b.halfOpenRequests = n
	}
	if f := opts.BackoffDurationFunc; f != nil {
		b.backoffDurationFunc = f
	}
	if f := opts.ShouldTripFunc; f != nil {
		b.shouldTripFunc = f
	}
	if f := opts.ShouldResetFunc; f != nil {
		b.shouldResetFunc = f
	}

	if opts.TestClock != nil {
		b.clock = opts.TestClock
	} else {
		b.clock = clock.New()
	}

	b.setState(StateClosed)
	return b
}
// Call runs f if the Breaker currently admits a request.
//
// If the breaker rejects the request, Call returns (nil, *ErrOpenState)
// immediately without running f. Otherwise it returns f's result and error
// unchanged. A nil error from f is recorded as a success and a non-nil error
// as a failure.
//
// If f panics or calls runtime.Goexit, the request is still released and
// recorded as a failure before unwinding continues. Without this, a
// panicking f would leave Counts.CurrentRequests permanently incremented,
// which can pin a half-open breaker at its concurrency limit forever.
func (b *Breaker) Call(f func() (interface{}, error)) (interface{}, error) {
	generation, err := b.beforeRequest()
	if err != nil {
		return nil, err
	}

	completed := false
	defer func() {
		// completed is still false here only if f did not return normally.
		// Using a flag rather than recover() lets the panic propagate with
		// its original stack, and it avoids a second afterRequest call if
		// the afterRequest below (e.g. a user hook) panics.
		if !completed {
			b.afterRequest(false, generation)
		}
	}()

	result, err := f()
	completed = true
	b.afterRequest(err == nil, generation)
	return result, err
}
// beforeRequest decides, under b.mutex, whether a new request may start.
//
// An open breaker always refuses. A half-open breaker refuses once
// halfOpenRequests requests are already in flight. On refusal it returns
// *ErrOpenState; otherwise it counts the request as in flight. Either way the
// current generation is returned so the outcome can later be matched to it.
func (b *Breaker) beforeRequest() (int, error) {
	b.mutex.Lock()
	defer b.mutex.Unlock()

	state, gen := b.currentState()
	halfOpenFull := state == StateHalfOpen && b.counts.CurrentRequests >= b.halfOpenRequests
	if state == StateOpen || halfOpenFull {
		return gen, &ErrOpenState{}
	}

	b.counts.onRequest()
	return gen, nil
}
// afterRequest releases the in-flight slot taken by beforeRequest and, if the
// breaker is still in the generation the request started in, records the
// outcome through onSuccess or onFailure. Outcomes of requests that began
// before a state change are dropped.
func (b *Breaker) afterRequest(success bool, prevGeneration int) {
	b.mutex.Lock()
	defer b.mutex.Unlock()

	b.counts.afterRequest()
	state, gen := b.currentState()
	switch {
	case gen != prevGeneration:
		// Stale result from an earlier generation; ignore it.
	case success:
		b.onSuccess(state)
	default:
		b.onFailure(state)
	}
}
// onSuccess records a successful request. While half-open, once
// shouldResetFunc is satisfied the breaker closes and both streaks are
// cleared.
func (b *Breaker) onSuccess(state State) {
	b.counts.onSuccess()
	if state != StateHalfOpen || !b.shouldResetFunc(b.counts) {
		return
	}
	b.setState(StateClosed)
	b.counts.clear()
}
// onFailure records a failed request and updates the breaker accordingly:
//   - closed: trips open (with a fresh backoff) once shouldTripFunc agrees;
//   - half-open: reopens immediately with a fresh backoff;
//   - open: only refreshes the backoff.
func (b *Breaker) onFailure(state State) {
	b.counts.onFailure()
	switch state {
	case StateClosed:
		if !b.shouldTripFunc(b.counts) {
			return
		}
		b.setState(StateOpen)
		// The streaks are cleared before computing the backoff, so the first
		// backoff after tripping sees zero consecutive failures.
		b.counts.clear()
		b.setBackoff()
	case StateHalfOpen:
		b.setState(StateOpen)
		fallthrough
	case StateOpen:
		b.setBackoff()
	}
}
// setBackoff asks backoffDurationFunc for a duration based on the current
// counts, stores when that backoff expires, and reports both values to
// onBackoff if it is set.
func (b *Breaker) setBackoff() {
	d := b.backoffDurationFunc(b.counts)
	b.backoffExpires = b.clock.Now().Add(d)
	if hook := b.onBackoff; hook != nil {
		hook(d, b.backoffExpires)
	}
}
// currentState returns the breaker's state and generation. It first promotes
// an open breaker to half-open if the clock has moved past backoffExpires.
// beforeRequest and afterRequest call it with b.mutex held.
func (b *Breaker) currentState() (State, int) {
	if b.state == StateOpen && b.clock.Now().After(b.backoffExpires) {
		b.setState(StateHalfOpen)
	}
	return b.state, b.generation
}
// newGeneration advances the generation counter by one, marking every
// request started before this point as belonging to an older generation.
func (b *Breaker) newGeneration() {
	b.generation += 1
}
// setState moves the breaker to state.
//
// Moving to the state it is already in does nothing. Otherwise the
// generation is advanced, so afterRequest will drop outcomes of requests
// started under the old state. Then onStateChange, if set, is called with the
// old and new states.
func (b *Breaker) setState(state State) {
	prev := b.state
	if prev == state {
		return
	}
	b.newGeneration()
	b.state = state
	if hook := b.onStateChange; hook != nil {
		hook(prev, state)
	}
}