pomerium/internal/tripper/chain_test.go
cfanbo 84dad4c612
remove deprecated ioutil usages (#2877)
* fix: Fixed return description error

* config/options: Adjust the position of TracingJaegerAgentEndpoint option

* DOCS: Remove duplicate configuration items

Remove duplicate configuration items of route

* remove deprecated ioutil usages
2021-12-30 10:02:12 -08:00

121 lines
2.7 KiB
Go

package tripper
import (
"fmt"
"io"
"net/http"
"net/http/httptest"
"testing"
)
// mockTransport is a stub http.RoundTripper used as the innermost transport
// in chain tests. Every response it produces carries id as its body.
type mockTransport struct {
	id string
}

// RoundTrip never looks at the request. It records the transport's id into an
// httptest.ResponseRecorder and returns the recorded response with a nil
// error.
func (t *mockTransport) RoundTrip(r *http.Request) (*http.Response, error) {
	rec := httptest.NewRecorder()
	_, _ = rec.WriteString(t.id)

	resp := rec.Result()
	return resp, nil
}
// mockMiddleware returns a RoundTripper constructor for use in tests. The
// RoundTripper it builds calls next, copies next's response body into a fresh
// recorded response and then appends ",<id>", so ids accumulate in the body
// as the call stack unwinds.
//
// If your chain is c1->c2->t, the final body is "t,c2,c1".
//
// An error from next, or from reading next's response body, is returned to
// the caller (wrapped with the middleware id) instead of being discarded.
func mockMiddleware(id string) func(next http.RoundTripper) http.RoundTripper {
	return func(next http.RoundTripper) http.RoundTripper {
		return RoundTripperFunc(func(r *http.Request) (*http.Response, error) {
			resp, err := next.RoundTrip(r)
			if err != nil {
				return nil, fmt.Errorf("mock middleware %q: round trip: %w", id, err)
			}
			// The upstream body is replayed into a new response below, so it
			// has to be closed here; nobody else will see it.
			defer resp.Body.Close()

			body, err := io.ReadAll(resp.Body)
			if err != nil {
				return nil, fmt.Errorf("mock middleware %q: reading body: %w", id, err)
			}

			rec := httptest.NewRecorder()
			_, _ = rec.Write(body)
			_, _ = rec.WriteString("," + id)
			return rec.Result(), nil
		})
	}
}
// TestNew builds a chain from two middlewares, runs a request through it and
// checks both the number of stored constructors and the order in which the
// middlewares wrapped the transport.
func TestNew(t *testing.T) {
	m1 := mockMiddleware("c1")
	m2 := mockMiddleware("c2")
	t1 := &mockTransport{id: "t"}
	// Each mock middleware appends its id after next returns, so the ids
	// appear innermost-first.
	want := "t,c2,c1"

	chain := NewChain(m1, m2)
	req := httptest.NewRequest(http.MethodGet, "/", nil)
	resp, err := chain.Then(t1).RoundTrip(req)
	if err != nil {
		// Fail here rather than dereferencing a possibly nil response below.
		t.Fatalf("RoundTrip returned an unexpected error: %v", err)
	}
	defer resp.Body.Close()

	if len(chain.constructors) != 2 {
		t.Errorf("Wrong number of constructors in chain")
	}

	b, err := io.ReadAll(resp.Body)
	if err != nil {
		t.Fatalf("reading response body: %v", err)
	}
	if string(b) != want {
		t.Errorf("Wrong constructors. want=%s, got=%s", want, b)
	}
}
// TestThenNoMiddleware checks that an empty chain leaves the transport's
// response body untouched.
func TestThenNoMiddleware(t *testing.T) {
	chain := NewChain()
	t1 := &mockTransport{id: "t"}
	want := "t"

	req := httptest.NewRequest(http.MethodGet, "/", nil)
	resp, err := chain.Then(t1).RoundTrip(req)
	if err != nil {
		// Fail here rather than dereferencing a possibly nil response below.
		t.Fatalf("RoundTrip returned an unexpected error: %v", err)
	}
	defer resp.Body.Close()

	b, err := io.ReadAll(resp.Body)
	if err != nil {
		t.Fatalf("reading response body: %v", err)
	}
	if string(b) != want {
		t.Errorf("Wrong constructors. want=%s, got=%s", want, b)
	}
}
// TestNilThen checks that calling Then with a nil transport on an empty chain
// yields http.DefaultTransport.
func TestNilThen(t *testing.T) {
	got := NewChain().Then(nil)
	if got == http.DefaultTransport {
		return
	}
	t.Error("Then does not treat nil as DefaultTransport")
}
// TestAppend checks that Append returns a chain holding the extra constructor
// and that the appended middleware sits inside the ones already present.
func TestAppend(t *testing.T) {
	chain := NewChain(mockMiddleware("c1"))
	if len(chain.constructors) != 1 {
		t.Errorf("Wrong number of constructors in chain")
	}

	chain = chain.Append(mockMiddleware("c2"))
	t1 := &mockTransport{id: "t"}
	// c2 was appended after c1, so it is closer to the transport and its id
	// is written first.
	want := "t,c2,c1"

	req := httptest.NewRequest(http.MethodGet, "/", nil)
	resp, err := chain.Then(t1).RoundTrip(req)
	if err != nil {
		// Fail here rather than dereferencing a possibly nil response below.
		t.Fatalf("RoundTrip returned an unexpected error: %v", err)
	}
	defer resp.Body.Close()

	if len(chain.constructors) != 2 {
		t.Errorf("Wrong number of constructors in chain")
	}

	b, err := io.ReadAll(resp.Body)
	if err != nil {
		t.Fatalf("reading response body: %v", err)
	}
	if string(b) != want {
		t.Errorf("Wrong constructors. want=%s, got=%s", want, b)
	}
}
// TestAppendImmutability checks that Append leaves the receiver chain
// unchanged: after appending to it, the original chain still holds one
// constructor and still produces only that middleware's output.
func TestAppendImmutability(t *testing.T) {
	chain := NewChain(mockMiddleware("c1"))
	// The returned chain is discarded on purpose; only the effect on the
	// original chain matters here.
	_ = chain.Append(mockMiddleware("c2"))
	t1 := &mockTransport{id: "t"}
	want := "t,c1"

	if len(chain.constructors) != 1 {
		t.Errorf("Append does not respect immutability")
	}

	req := httptest.NewRequest(http.MethodGet, "/", nil)
	resp, err := chain.Then(t1).RoundTrip(req)
	if err != nil {
		// Fail here rather than dereferencing a possibly nil response below.
		t.Fatalf("RoundTrip returned an unexpected error: %v", err)
	}
	defer resp.Body.Close()

	b, err := io.ReadAll(resp.Body)
	if err != nil {
		t.Fatalf("reading response body: %v", err)
	}
	if string(b) != want {
		t.Errorf("Wrong constructors. want=%s, got=%s", want, b)
	}
}