mirror of
https://github.com/pomerium/pomerium.git
synced 2025-04-28 09:56:31 +02:00
112 lines
2.1 KiB
Go
112 lines
2.1 KiB
Go
package tcptunnel
|
|
|
|
import (
|
|
"bufio"
|
|
"bytes"
|
|
"context"
|
|
"crypto/tls"
|
|
"io"
|
|
"net"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"strings"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/stretchr/testify/assert"
|
|
)
|
|
|
|
func TestTunnel(t *testing.T) {
|
|
ctx, clearTimeout := context.WithTimeout(context.Background(), time.Second*10)
|
|
defer clearTimeout()
|
|
|
|
backend, err := net.Listen("tcp", "127.0.0.1:0")
|
|
if !assert.NoError(t, err) {
|
|
return
|
|
}
|
|
defer func() { _ = backend.Close() }()
|
|
|
|
go func() {
|
|
for {
|
|
conn, err := backend.Accept()
|
|
if err != nil {
|
|
return
|
|
}
|
|
go func() {
|
|
defer func() { _ = conn.Close() }()
|
|
|
|
ln, _, _ := bufio.NewReader(conn).ReadLine()
|
|
assert.Equal(t, "HELLO WORLD", string(ln))
|
|
}()
|
|
}
|
|
}()
|
|
|
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
if !assert.Equal(t, "CONNECT", r.Method) {
|
|
return
|
|
}
|
|
if !assert.Equal(t, "example.com:9999", r.RequestURI) {
|
|
return
|
|
}
|
|
|
|
w.WriteHeader(200)
|
|
|
|
in, brw, err := w.(http.Hijacker).Hijack()
|
|
if !assert.NoError(t, err) {
|
|
return
|
|
}
|
|
defer func() { _ = in.Close() }()
|
|
|
|
out, err := net.Dial("tcp", backend.Addr().String())
|
|
if !assert.NoError(t, err) {
|
|
return
|
|
}
|
|
defer func() { _ = out.Close() }()
|
|
|
|
errc := make(chan error, 2)
|
|
go func() {
|
|
_, err := io.Copy(in, out)
|
|
errc <- err
|
|
}()
|
|
go func() {
|
|
_, err := io.Copy(out, deBuffer(brw.Reader, in))
|
|
errc <- err
|
|
}()
|
|
<-errc
|
|
}))
|
|
defer srv.Close()
|
|
|
|
var buf bytes.Buffer
|
|
tun := New(
|
|
WithDestinationHost("example.com:9999"),
|
|
WithProxyHost(srv.Listener.Addr().String()))
|
|
err = tun.Run(ctx, readWriter{strings.NewReader("HELLO WORLD\n"), &buf})
|
|
if !assert.NoError(t, err) {
|
|
return
|
|
}
|
|
}
|
|
|
|
// readWriter pairs an independent reader and writer into a single value. The
// embedded fields promote their Read and Write methods, so the pair can be
// passed wherever one read/write stream is expected (e.g. tunnel.Run above,
// fed from a strings.Reader and a bytes.Buffer).
type readWriter struct {
	io.Reader // source of Read calls
	io.Writer // destination of Write calls
}

// Compile-time guarantee that readWriter satisfies io.ReadWriter.
var _ io.ReadWriter = readWriter{}
|
|
|
|
func TestForceHTTP1(t *testing.T) {
|
|
tunnel := New(WithTLSConfig(&tls.Config{
|
|
InsecureSkipVerify: true,
|
|
}))
|
|
|
|
var protocol string
|
|
srv := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
protocol = r.Proto
|
|
}))
|
|
|
|
client := &http.Client{
|
|
Transport: &http.Transport{
|
|
TLSClientConfig: tunnel.cfg.tlsConfig,
|
|
},
|
|
}
|
|
_, _ = client.Get(srv.URL)
|
|
|
|
assert.Equal(t, "HTTP/1.1", protocol)
|
|
}
|