mirror of
https://github.com/pomerium/pomerium.git
synced 2025-04-30 19:06:33 +02:00
187 lines
4.5 KiB
Go
187 lines
4.5 KiB
Go
package circuit // import "github.com/pomerium/pomerium/internal/circuit"
|
|
|
|
import (
|
|
"errors"
|
|
"sync"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/benbjohnson/clock"
|
|
)
|
|
|
|
// errFailed is the error that fail returns on every call.
var errFailed = errors.New("failed")
|
|
|
|
func fail() (interface{}, error) {
|
|
return nil, errFailed
|
|
}
|
|
|
|
func succeed() (interface{}, error) {
|
|
return nil, nil
|
|
}
|
|
|
|
func TestCircuitBreaker(t *testing.T) {
|
|
mock := clock.NewMock()
|
|
threshold := 3
|
|
timeout := time.Duration(2) * time.Second
|
|
trip := func(c Counts) bool { return c.ConsecutiveFailures > threshold }
|
|
reset := func(c Counts) bool { return c.ConsecutiveSuccesses > threshold }
|
|
backoff := func(c Counts) time.Duration { return timeout }
|
|
stateChange := func(p, c State) { t.Logf("state change from %s to %s\n", p, c) }
|
|
cb := NewBreaker(&Options{
|
|
TestClock: mock,
|
|
ShouldTripFunc: trip,
|
|
ShouldResetFunc: reset,
|
|
BackoffDurationFunc: backoff,
|
|
OnStateChange: stateChange,
|
|
})
|
|
state, _ := cb.currentState()
|
|
if state != StateClosed {
|
|
t.Fatalf("expected state to start %s, got %s", StateClosed, state)
|
|
}
|
|
|
|
for i := 0; i <= threshold; i++ {
|
|
_, err := cb.Call(fail)
|
|
if err == nil {
|
|
t.Fatalf("expected to error, got nil")
|
|
}
|
|
state, _ := cb.currentState()
|
|
t.Logf("iteration %#v", i)
|
|
if i == threshold {
|
|
// we expect this to be the case to trip the circuit
|
|
if state != StateOpen {
|
|
t.Fatalf("expected state to be %s, got %s", StateOpen, state)
|
|
}
|
|
} else if state != StateClosed {
|
|
// this is a normal failure case
|
|
t.Fatalf("expected state to be %s, got %s", StateClosed, state)
|
|
}
|
|
}
|
|
|
|
_, err := cb.Call(fail)
|
|
switch err.(type) {
|
|
case *ErrOpenState:
|
|
// this is the expected case
|
|
break
|
|
default:
|
|
t.Errorf("%#v", cb.counts)
|
|
t.Fatalf("expected to get open state failure, got %s", err)
|
|
}
|
|
|
|
// we advance time by the timeout and a hair
|
|
mock.Add(timeout + time.Duration(1)*time.Millisecond)
|
|
state, _ = cb.currentState()
|
|
if state != StateHalfOpen {
|
|
t.Fatalf("expected state to be %s, got %s", StateHalfOpen, state)
|
|
}
|
|
|
|
for i := 0; i <= threshold; i++ {
|
|
_, err := cb.Call(succeed)
|
|
if err != nil {
|
|
t.Fatalf("expected to get no error, got %s", err)
|
|
}
|
|
state, _ := cb.currentState()
|
|
t.Logf("iteration %#v", i)
|
|
if i == threshold {
|
|
// we expect this to be the case that ressets the circuit
|
|
if state != StateClosed {
|
|
t.Fatalf("expected state to be %s, got %s", StateClosed, state)
|
|
}
|
|
} else if state != StateHalfOpen {
|
|
t.Fatalf("expected state to be %s, got %s", StateHalfOpen, state)
|
|
}
|
|
}
|
|
|
|
state, _ = cb.currentState()
|
|
if state != StateClosed {
|
|
t.Fatalf("expected state to be %s, got %s", StateClosed, state)
|
|
}
|
|
}
|
|
|
|
func TestExponentialBackOffFunc(t *testing.T) {
|
|
baseTimeout := time.Duration(1) * time.Millisecond
|
|
// Note Expected is an upper range case
|
|
cases := []struct {
|
|
FailureCount int
|
|
Expected time.Duration
|
|
}{
|
|
{
|
|
FailureCount: 0,
|
|
Expected: time.Duration(1) * time.Millisecond,
|
|
},
|
|
{
|
|
FailureCount: 1,
|
|
Expected: time.Duration(2) * time.Millisecond,
|
|
},
|
|
{
|
|
FailureCount: 2,
|
|
Expected: time.Duration(4) * time.Millisecond,
|
|
},
|
|
{
|
|
FailureCount: 3,
|
|
Expected: time.Duration(8) * time.Millisecond,
|
|
},
|
|
{
|
|
FailureCount: 4,
|
|
Expected: time.Duration(16) * time.Millisecond,
|
|
},
|
|
{
|
|
FailureCount: 5,
|
|
Expected: time.Duration(32) * time.Millisecond,
|
|
},
|
|
{
|
|
FailureCount: 6,
|
|
Expected: time.Duration(64) * time.Millisecond,
|
|
},
|
|
{
|
|
FailureCount: 7,
|
|
Expected: time.Duration(128) * time.Millisecond,
|
|
},
|
|
{
|
|
FailureCount: 8,
|
|
Expected: time.Duration(256) * time.Millisecond,
|
|
},
|
|
{
|
|
FailureCount: 9,
|
|
Expected: time.Duration(512) * time.Millisecond,
|
|
},
|
|
{
|
|
FailureCount: 10,
|
|
Expected: time.Duration(1024) * time.Millisecond,
|
|
},
|
|
}
|
|
|
|
f := ExponentialBackoffDuration(time.Duration(1)*time.Hour, baseTimeout)
|
|
for _, tc := range cases {
|
|
got := f(Counts{ConsecutiveFailures: tc.FailureCount})
|
|
t.Logf("got backoff %#v", got)
|
|
if got > tc.Expected {
|
|
t.Errorf("got %#v but expected less than %#v", got, tc.Expected)
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestCircuitBreakerClosedParallel(t *testing.T) {
|
|
cb := NewBreaker(nil)
|
|
numReqs := 10000
|
|
wg := &sync.WaitGroup{}
|
|
routine := func(wg *sync.WaitGroup) {
|
|
for i := 0; i < numReqs; i++ {
|
|
cb.Call(succeed)
|
|
}
|
|
wg.Done()
|
|
}
|
|
|
|
numRoutines := 10
|
|
for i := 0; i < numRoutines; i++ {
|
|
wg.Add(1)
|
|
go routine(wg)
|
|
}
|
|
|
|
total := numReqs * numRoutines
|
|
|
|
wg.Wait()
|
|
|
|
if cb.counts.ConsecutiveSuccesses != total {
|
|
t.Fatalf("expected to get total requests %d, got %d", total, cb.counts.ConsecutiveSuccesses)
|
|
}
|
|
}
|