pomerium/authenticate/circuit/breaker_test.go
2019-01-02 12:13:36 -08:00

187 lines
4.5 KiB
Go

package circuit // import "github.com/pomerium/pomerium/internal/circuit"
import (
"errors"
"sync"
"testing"
"time"
"github.com/benbjohnson/clock"
)
// errFailed is the sentinel error returned by fail.
var errFailed = errors.New("failed")

// fail is a stub operation handed to the breaker's Call method in these
// tests. It never yields a value and always reports errFailed.
func fail() (interface{}, error) {
	var result interface{}
	return result, errFailed
}
// succeed is a stub operation handed to the breaker's Call method in these
// tests. It yields a nil value and never reports an error.
func succeed() (interface{}, error) {
	var result interface{}
	return result, nil
}
// TestCircuitBreaker walks a breaker through closed -> open -> half-open ->
// closed. It uses a mock clock, so the backoff period is skipped
// deterministically instead of by sleeping.
func TestCircuitBreaker(t *testing.T) {
	mock := clock.NewMock()
	threshold := 3
	timeout := 2 * time.Second
	// Both predicates use a strict ">", so each phase below needs
	// threshold+1 consecutive calls before the state changes.
	trip := func(c Counts) bool { return c.ConsecutiveFailures > threshold }
	reset := func(c Counts) bool { return c.ConsecutiveSuccesses > threshold }
	backoff := func(c Counts) time.Duration { return timeout }
	stateChange := func(p, c State) { t.Logf("state change from %s to %s", p, c) }
	cb := NewBreaker(&Options{
		TestClock:           mock,
		ShouldTripFunc:      trip,
		ShouldResetFunc:     reset,
		BackoffDurationFunc: backoff,
		OnStateChange:       stateChange,
	})

	// expectState aborts the test unless the breaker currently reports want.
	expectState := func(want State) {
		t.Helper()
		if got, _ := cb.currentState(); got != want {
			t.Fatalf("expected state to be %s, got %s", want, got)
		}
	}

	if state, _ := cb.currentState(); state != StateClosed {
		t.Fatalf("expected state to start %s, got %s", StateClosed, state)
	}

	// The first threshold failures keep the circuit closed; failure number
	// threshold+1 trips it open.
	for i := 0; i <= threshold; i++ {
		if _, err := cb.Call(fail); err == nil {
			t.Fatalf("expected to error, got nil")
		}
		t.Logf("iteration %#v", i)
		if i == threshold {
			expectState(StateOpen)
		} else {
			expectState(StateClosed)
		}
	}

	// While the circuit is open, Call is expected to fail with *ErrOpenState.
	// A comma-ok assertion never panics, even when err is nil.
	_, err := cb.Call(fail)
	if _, ok := err.(*ErrOpenState); !ok {
		t.Errorf("%#v", cb.counts)
		t.Fatalf("expected to get open state failure, got %v", err)
	}

	// Advance the mock clock just past the backoff; the breaker should then
	// report half-open.
	mock.Add(timeout + time.Millisecond)
	expectState(StateHalfOpen)

	// The first threshold successes keep the circuit half-open; success
	// number threshold+1 resets it to closed.
	for i := 0; i <= threshold; i++ {
		if _, err := cb.Call(succeed); err != nil {
			t.Fatalf("expected to get no error, got %s", err)
		}
		t.Logf("iteration %#v", i)
		if i == threshold {
			expectState(StateClosed)
		} else {
			expectState(StateHalfOpen)
		}
	}
}
// TestExponentialBackOffFunc checks that the backoff returned by
// ExponentialBackoffDuration stays within [0, baseTimeout*2^failures] for
// failure counts 0 through 10. The Expected values are upper bounds, not
// exact results.
//
// The first argument (1h) is presumably the maximum backoff. It is far above
// every bound checked here, so it should not affect these cases; confirm
// against the ExponentialBackoffDuration signature.
func TestExponentialBackOffFunc(t *testing.T) {
	baseTimeout := time.Millisecond
	cases := []struct {
		FailureCount int
		Expected     time.Duration // inclusive upper bound on the backoff
	}{
		{FailureCount: 0, Expected: 1 * time.Millisecond},
		{FailureCount: 1, Expected: 2 * time.Millisecond},
		{FailureCount: 2, Expected: 4 * time.Millisecond},
		{FailureCount: 3, Expected: 8 * time.Millisecond},
		{FailureCount: 4, Expected: 16 * time.Millisecond},
		{FailureCount: 5, Expected: 32 * time.Millisecond},
		{FailureCount: 6, Expected: 64 * time.Millisecond},
		{FailureCount: 7, Expected: 128 * time.Millisecond},
		{FailureCount: 8, Expected: 256 * time.Millisecond},
		{FailureCount: 9, Expected: 512 * time.Millisecond},
		{FailureCount: 10, Expected: 1024 * time.Millisecond},
	}
	f := ExponentialBackoffDuration(time.Hour, baseTimeout)
	for _, tc := range cases {
		got := f(Counts{ConsecutiveFailures: tc.FailureCount})
		t.Logf("failure count %d: got backoff %v", tc.FailureCount, got)
		// The bound is inclusive, and a negative wait is never a valid backoff.
		if got < 0 || got > tc.Expected {
			t.Errorf("failure count %d: got %v but expected between 0 and %v",
				tc.FailureCount, got, tc.Expected)
		}
	}
}
func TestCircuitBreakerClosedParallel(t *testing.T) {
cb := NewBreaker(nil)
numReqs := 10000
wg := &sync.WaitGroup{}
routine := func(wg *sync.WaitGroup) {
for i := 0; i < numReqs; i++ {
cb.Call(succeed)
}
wg.Done()
}
numRoutines := 10
for i := 0; i < numRoutines; i++ {
wg.Add(1)
go routine(wg)
}
total := numReqs * numRoutines
wg.Wait()
if cb.counts.ConsecutiveSuccesses != total {
t.Fatalf("expected to get total requests %d, got %d", total, cb.counts.ConsecutiveSuccesses)
}
}