pomerium/internal/tripper/chain_test.go
bobby f837c92741
dev: update linter (#1728)
- gofumpt everything
- fix TLS MinVersion to be at least 1.2
- add octal syntax
- remove newlines
- fix potential decompression bomb in ecjson
- remove implicit memory aliasing in for loops.

Signed-off-by: Bobby DeSimone <bobbydesimone@gmail.com>
2020-12-30 09:02:57 -08:00

121 lines
2.7 KiB
Go

package tripper
import (
"fmt"
"io/ioutil"
"net/http"
"net/http/httptest"
"testing"
)
// mockTransport is a terminal http.RoundTripper used as the innermost element
// of a chain under test. It ignores the incoming request and answers with a
// response whose body is exactly its id.
type mockTransport struct {
	// id is written verbatim as the response body.
	id string
}

// RoundTrip satisfies http.RoundTripper. It builds a response through an
// httptest.ResponseRecorder with m.id as the body and always reports a nil
// error; the recorder's write result is discarded.
func (m *mockTransport) RoundTrip(_ *http.Request) (*http.Response, error) {
	rec := httptest.NewRecorder()
	_, _ = rec.WriteString(m.id)
	res := rec.Result()
	return res, nil
}
// mockMiddleware appends the id into the response body as
// the call stack unwinds.
//
// If your chain is c1->c2->t, it should return 't,c2,c1'.
//
// Errors from the wrapped RoundTripper, or from reading its response body,
// are returned (prefixed with this middleware's id) instead of being
// discarded, so a broken chain fails the test with a message rather than a
// nil-pointer panic. The inner response body is closed once it has been
// copied into the new response.
func mockMiddleware(id string) func(next http.RoundTripper) http.RoundTripper {
	return func(next http.RoundTripper) http.RoundTripper {
		return RoundTripperFunc(func(r *http.Request) (*http.Response, error) {
			resp, err := next.RoundTrip(r)
			if err != nil {
				return nil, fmt.Errorf("mock middleware %s: %w", id, err)
			}
			defer resp.Body.Close()

			body, err := ioutil.ReadAll(resp.Body)
			if err != nil {
				return nil, fmt.Errorf("mock middleware %s: reading body: %w", id, err)
			}

			// Re-emit the inner body and tag it with this middleware's id so
			// the final body records the order in which middlewares returned.
			mockResp := httptest.NewRecorder()
			mockResp.Write(body)
			mockResp.WriteString("," + id)
			return mockResp.Result(), nil
		})
	}
}
// TestNew checks that NewChain keeps both constructors and that Then applies
// them so that, with middlewares c1, c2 and transport t, the response body is
// "t,c2,c1".
//
// RoundTrip and body-read errors abort the test with t.Fatalf instead of
// being ignored, and the response body is closed when the test returns.
func TestNew(t *testing.T) {
	m1 := mockMiddleware("c1")
	m2 := mockMiddleware("c2")
	t1 := &mockTransport{id: "t"}
	want := "t,c2,c1"

	chain := NewChain(m1, m2)
	resp, err := chain.Then(t1).
		RoundTrip(httptest.NewRequest(http.MethodGet, "/", nil))
	if err != nil {
		t.Fatalf("RoundTrip: %v", err)
	}
	defer resp.Body.Close()

	if len(chain.constructors) != 2 {
		t.Error("Wrong number of constructors in chain")
	}

	b, err := ioutil.ReadAll(resp.Body)
	if err != nil {
		t.Fatalf("reading response body: %v", err)
	}
	if string(b) != want {
		t.Errorf("Wrong constructors. want=%s, got=%s", want, b)
	}
}
// TestThenNoMiddleware checks that an empty chain leaves the transport's
// response untouched: the body is just the transport's id, "t".
//
// RoundTrip and body-read errors abort the test with t.Fatalf instead of
// being ignored, and the response body is closed when the test returns.
func TestThenNoMiddleware(t *testing.T) {
	chain := NewChain()
	t1 := &mockTransport{id: "t"}
	want := "t"

	resp, err := chain.Then(t1).
		RoundTrip(httptest.NewRequest(http.MethodGet, "/", nil))
	if err != nil {
		t.Fatalf("RoundTrip: %v", err)
	}
	defer resp.Body.Close()

	b, err := ioutil.ReadAll(resp.Body)
	if err != nil {
		t.Fatalf("reading response body: %v", err)
	}
	if string(b) != want {
		t.Errorf("Wrong constructors. want=%s, got=%s", want, b)
	}
}
// TestNilThen verifies that calling Then(nil) on an empty chain yields
// http.DefaultTransport itself.
func TestNilThen(t *testing.T) {
	got := NewChain().Then(nil)
	if got == http.DefaultTransport {
		return
	}
	t.Error("Then does not treat nil as DefaultTransport")
}
// TestAppend checks that Append adds a constructor after the existing ones.
// A chain built from c1 and then extended with c2 must hold two constructors
// and, wrapped around transport t, produce the body "t,c2,c1".
//
// RoundTrip and body-read errors abort the test with t.Fatalf instead of
// being ignored, and the response body is closed when the test returns.
func TestAppend(t *testing.T) {
	chain := NewChain(mockMiddleware("c1"))
	if len(chain.constructors) != 1 {
		t.Error("Wrong number of constructors in chain")
	}
	chain = chain.Append(mockMiddleware("c2"))

	t1 := &mockTransport{id: "t"}
	want := "t,c2,c1"

	resp, err := chain.Then(t1).
		RoundTrip(httptest.NewRequest(http.MethodGet, "/", nil))
	if err != nil {
		t.Fatalf("RoundTrip: %v", err)
	}
	defer resp.Body.Close()

	if len(chain.constructors) != 2 {
		t.Error("Wrong number of constructors in chain")
	}

	b, err := ioutil.ReadAll(resp.Body)
	if err != nil {
		t.Fatalf("reading response body: %v", err)
	}
	if string(b) != want {
		t.Errorf("Wrong constructors. want=%s, got=%s", want, b)
	}
}
// TestAppendImmutability checks that Append does not modify its receiver.
// The result of Append is deliberately discarded; the original chain must
// still hold a single constructor and produce the body "t,c1".
//
// RoundTrip and body-read errors abort the test with t.Fatalf instead of
// being ignored, and the response body is closed when the test returns.
func TestAppendImmutability(t *testing.T) {
	chain := NewChain(mockMiddleware("c1"))
	chain.Append(mockMiddleware("c2")) // result intentionally dropped

	t1 := &mockTransport{id: "t"}
	want := "t,c1"

	if len(chain.constructors) != 1 {
		t.Error("Append does not respect immutability")
	}

	resp, err := chain.Then(t1).
		RoundTrip(httptest.NewRequest(http.MethodGet, "/", nil))
	if err != nil {
		t.Fatalf("RoundTrip: %v", err)
	}
	defer resp.Body.Close()

	b, err := ioutil.ReadAll(resp.Body)
	if err != nil {
		t.Fatalf("reading response body: %v", err)
	}
	if string(b) != want {
		t.Errorf("Wrong constructors. want=%s, got=%s", want, b)
	}
}