pomerium/internal/tcptunnel/tcptunnel_test.go
Caleb Doxsey 77ae17d23b
tcptunnel: force the use of HTTP/1.1 during ALPN (#2593)
* tcptunnel: force the use of HTTP/1.1 during ALPN

* remove unused code
2021-09-13 13:53:19 -06:00

112 lines
2.1 KiB
Go

package tcptunnel
import (
"bufio"
"bytes"
"context"
"crypto/tls"
"io"
"net"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
"github.com/stretchr/testify/assert"
)
// TestTunnel verifies that Run sends a CONNECT request for the configured
// destination host to the configured proxy, and that data read from the
// caller's reader is delivered to the backend the proxy connects to.
func TestTunnel(t *testing.T) {
	ctx, clearTimeout := context.WithTimeout(context.Background(), time.Second*10)
	defer clearTimeout()

	backend, err := net.Listen("tcp", "127.0.0.1:0")
	if !assert.NoError(t, err) {
		return
	}
	defer func() { _ = backend.Close() }()

	// Each backend connection hands its first line to the test goroutine
	// rather than asserting on it directly. An assertion from a goroutine that
	// is still running after the test returns panics, and one that never runs
	// would let the test pass even if nothing reached the backend.
	received := make(chan string, 1)
	go func() {
		for {
			conn, err := backend.Accept()
			if err != nil {
				return // the listener is closed when the test finishes
			}
			go func() {
				defer func() { _ = conn.Close() }()
				ln, _, _ := bufio.NewReader(conn).ReadLine()
				select {
				case received <- string(ln):
				case <-ctx.Done():
				}
			}()
		}
	}()

	// srv acts as a minimal HTTP CONNECT proxy. It checks the request, dials
	// the backend (standing in for example.com:9999), and splices the hijacked
	// client connection to it.
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if !assert.Equal(t, "CONNECT", r.Method) || !assert.Equal(t, "example.com:9999", r.RequestURI) {
			// Without an explicit error status, net/http would reply 200 and
			// the tunnel would believe the CONNECT succeeded.
			http.Error(w, "unexpected request", http.StatusBadRequest)
			return
		}

		// Dial before answering so a failure reaches the tunnel as an error
		// status instead of a 200 followed by a dead connection.
		out, err := net.Dial("tcp", backend.Addr().String())
		if !assert.NoError(t, err) {
			http.Error(w, err.Error(), http.StatusBadGateway)
			return
		}
		defer func() { _ = out.Close() }()

		w.WriteHeader(http.StatusOK)
		in, brw, err := w.(http.Hijacker).Hijack()
		if !assert.NoError(t, err) {
			return
		}
		defer func() { _ = in.Close() }()

		errc := make(chan error, 2)
		go func() {
			_, err := io.Copy(in, out)
			errc <- err
		}()
		go func() {
			// NOTE(review): deBuffer is defined elsewhere in this package;
			// presumably it replays bytes already buffered in brw.Reader
			// before reading from in. TODO confirm.
			_, err := io.Copy(out, deBuffer(brw.Reader, in))
			errc <- err
		}()
		// Stop once either direction finishes. The deferred Closes then
		// unblock the remaining copy.
		<-errc
	}))
	defer srv.Close()

	var buf bytes.Buffer
	tun := New(
		WithDestinationHost("example.com:9999"),
		WithProxyHost(srv.Listener.Addr().String()))
	err = tun.Run(ctx, readWriter{strings.NewReader("HELLO WORLD\n"), &buf})
	if !assert.NoError(t, err) {
		return
	}

	select {
	case ln := <-received:
		assert.Equal(t, "HELLO WORLD", ln)
	case <-ctx.Done():
		t.Fatal("timed out waiting for the backend to receive the tunneled data")
	}
}
// readWriter combines an independent io.Reader and io.Writer into one
// io.ReadWriter. TestTunnel uses it to give Run a reader holding the data to
// send and a buffer that collects whatever comes back.
//
// The embedded fields are positional in composite literals, so their order
// (Reader first, then Writer) must not change.
type readWriter struct {
	io.Reader // Read calls are served by this value
	io.Writer // Write calls are forwarded to this value
}

// Compile-time check that readWriter satisfies io.ReadWriter.
var _ io.ReadWriter = readWriter{}
// TestForceHTTP1 verifies that the TLS configuration produced by WithTLSConfig
// negotiates HTTP/1.1 via ALPN, even against a server that prefers HTTP/2.
func TestForceHTTP1(t *testing.T) {
	tunnel := New(WithTLSConfig(&tls.Config{
		InsecureSkipVerify: true,
	}))

	// httptest.NewTLSServer advertises only "http/1.1", which would make this
	// test pass whatever the tunnel's config offered. Advertise h2 first.
	srv := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// Report the protocol in the response body rather than through a
		// shared variable. A shared variable would be written here and read by
		// the test goroutine with no synchronization.
		_, _ = io.WriteString(w, r.Proto)
	}))
	srv.TLS = &tls.Config{NextProtos: []string{"h2", "http/1.1"}}
	srv.StartTLS()
	defer srv.Close()

	// Check the ALPN result directly on a raw handshake.
	conn, err := tls.Dial("tcp", srv.Listener.Addr().String(), tunnel.cfg.tlsConfig)
	if !assert.NoError(t, err) {
		return
	}
	assert.Equal(t, "http/1.1", conn.ConnectionState().NegotiatedProtocol)
	_ = conn.Close()

	// Also check that a real request over this config is served as HTTP/1.1.
	transport := &http.Transport{
		TLSClientConfig: tunnel.cfg.tlsConfig,
	}
	defer transport.CloseIdleConnections()
	client := &http.Client{
		Transport: transport,
		Timeout:   10 * time.Second,
	}
	res, err := client.Get(srv.URL)
	if !assert.NoError(t, err) {
		return
	}
	defer func() { _ = res.Body.Close() }()
	body, err := io.ReadAll(res.Body)
	if !assert.NoError(t, err) {
		return
	}
	assert.Equal(t, "HTTP/1.1", string(body))
}