pomerium/config/envoyconfig/protocols_int_test.go
Joe Kralicky 08623ef346
add tests/benchmarks for http1/http2 tcp tunnels and http1 websockets (#5471)
* add tests/benchmarks for http1/http2 tcp tunnels and http1 websockets

testenv:
- add new TCP upstream
- add websocket functions to HTTP upstream
- add https support to mock idp (default on)
- add new debug flags -env.bind-address and -env.use-trace-environ to
  allow changing the default bind address, and enabling otel environment
  based trace config, respectively

* linter pass

---------

Co-authored-by: Denis Mishin <dmishin@pomerium.com>
2025-03-19 18:42:19 -04:00

393 lines
10 KiB
Go

package envoyconfig_test
import (
"context"
"errors"
"fmt"
"io"
"net"
"net/http"
"sync/atomic"
"testing"
"time"
"github.com/gorilla/websocket"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"google.golang.org/grpc/credentials/insecure"
"google.golang.org/grpc/interop"
"google.golang.org/grpc/interop/grpc_testing"
"github.com/pomerium/pomerium/config"
"github.com/pomerium/pomerium/internal/testenv"
"github.com/pomerium/pomerium/internal/testenv/scenarios"
"github.com/pomerium/pomerium/internal/testenv/snippets"
"github.com/pomerium/pomerium/internal/testenv/upstreams"
"github.com/pomerium/pomerium/internal/testenv/values"
)
func TestH2C(t *testing.T) {
env := testenv.New(t)
up := upstreams.GRPC(insecure.NewCredentials())
grpc_testing.RegisterTestServiceServer(up, interop.NewTestServer())
http := up.Route().
From(env.SubdomainURL("grpc-http")).
To(values.Bind(up.Addr(), func(addr string) string {
// override the target protocol to use http://
return fmt.Sprintf("http://%s", addr)
})).
Policy(func(p *config.Policy) { p.AllowPublicUnauthenticatedAccess = true })
h2c := up.Route().
From(env.SubdomainURL("grpc-h2c")).
Policy(func(p *config.Policy) { p.AllowPublicUnauthenticatedAccess = true })
env.AddUpstream(up)
env.Start()
snippets.WaitStartupComplete(env)
t.Run("h2c", func(t *testing.T) {
t.Parallel()
recorder := env.NewLogRecorder()
cc := up.Dial(h2c)
client := grpc_testing.NewTestServiceClient(cc)
_, err := client.EmptyCall(env.Context(), &grpc_testing.Empty{})
require.NoError(t, err)
cc.Close()
recorder.Match([]map[string]any{
{
"service": "envoy",
"path": "/grpc.testing.TestService/EmptyCall",
"message": "http-request",
"response-code-details": "via_upstream",
},
})
})
t.Run("http", func(t *testing.T) {
t.Parallel()
recorder := env.NewLogRecorder()
cc := up.Dial(http)
client := grpc_testing.NewTestServiceClient(cc)
_, err := client.UnaryCall(env.Context(), &grpc_testing.SimpleRequest{})
require.Error(t, err)
cc.Close()
recorder.Match([]map[string]any{
{
"service": "envoy",
"path": "/grpc.testing.TestService/UnaryCall",
"message": "http-request",
"response-code-details": "upstream_reset_before_response_started{protocol_error}",
},
})
})
}
func TestHTTP(t *testing.T) {
env := testenv.New(t)
up := upstreams.HTTP(nil)
up.Handle("/foo", func(w http.ResponseWriter, _ *http.Request) {
fmt.Fprintln(w, "hello world")
})
route := up.Route().
From(env.SubdomainURL("http")).
Policy(func(p *config.Policy) { p.AllowPublicUnauthenticatedAccess = true })
env.AddUpstream(up)
env.Start()
recorder := env.NewLogRecorder()
resp, err := up.Get(route, upstreams.Path("/foo"))
require.NoError(t, err)
defer resp.Body.Close()
data, err := io.ReadAll(resp.Body)
require.NoError(t, err)
assert.Equal(t, "hello world\n", string(data))
recorder.Match([]map[string]any{
{
"service": "envoy",
"path": "/foo",
"method": "GET",
"message": "http-request",
"response-code-details": "via_upstream",
},
})
}
func TestTCPTunnel(t *testing.T) {
env := testenv.New(t, testenv.Debug())
env.Add(scenarios.NewIDP([]*scenarios.User{{Email: "test@example.com"}}))
up := upstreams.TCP()
routeH1 := up.Route().
From(env.SubdomainURL("h1")).
PPL(`{"allow":{"and":["email":{"is":"test@example.com"}]}}`)
routeH2 := up.Route().
From(env.SubdomainURL("h2")).
Policy(func(p *config.Policy) {
p.AllowWebsockets = true
}).
PPL(`{"allow":{"and":["email":{"is":"test@example.com"}]}}`)
up.Handle(func(_ context.Context, c net.Conn) error {
c.SetReadDeadline(time.Now().Add(1 * time.Second))
buf := make([]byte, 8)
n, err := c.Read(buf)
require.NoError(t, err)
assert.Equal(t, string(buf[:n]), "hello")
c.SetWriteDeadline(time.Now().Add(1 * time.Second))
_, err = c.Write([]byte("world"))
require.NoError(t, err)
return nil
})
env.AddUpstream(up)
env.Start()
snippets.WaitStartupComplete(env)
t.Run("http1", func(t *testing.T) {
assert.NoError(t, up.Dial(routeH1, func(_ context.Context, c net.Conn) error {
c.SetWriteDeadline(time.Now().Add(1 * time.Second))
_, err := c.Write([]byte("hello"))
require.NoError(t, err)
buf := make([]byte, 8)
c.SetReadDeadline(time.Now().Add(1 * time.Second))
n, err := c.Read(buf)
require.NoError(t, err)
assert.Equal(t, string(buf[:n]), "world")
return nil
}, upstreams.AuthenticateAs("test@example.com"), upstreams.DialProtocol(upstreams.DialHTTP1)))
})
t.Run("http2", func(t *testing.T) {
assert.NoError(t, up.Dial(routeH2, func(_ context.Context, c net.Conn) error {
c.SetWriteDeadline(time.Now().Add(1 * time.Second))
_, err := c.Write([]byte("hello"))
require.NoError(t, err)
buf := make([]byte, 8)
c.SetReadDeadline(time.Now().Add(1 * time.Second))
n, err := c.Read(buf)
require.NoError(t, err)
assert.Equal(t, string(buf[:n]), "world")
return nil
}, upstreams.AuthenticateAs("test@example.com"), upstreams.DialProtocol(upstreams.DialHTTP2)))
})
}
func BenchmarkHTTP1TCPTunnel(b *testing.B) {
env := testenv.New(b, testenv.Silent())
env.Add(scenarios.NewIDP([]*scenarios.User{{Email: "test@example.com"}}))
up := upstreams.TCP()
h1 := up.Route().
From(env.SubdomainURL("bench-h1")).
PPL(`{"allow":{"and":["email":{"is":"test@example.com"}]}}`)
env.AddUpstream(up)
env.Start()
snippets.WaitStartupComplete(env)
b.Run("http1", func(b *testing.B) {
benchmarkTCP(b, up, h1, tcpBenchmarkParams{
msgLen: 512,
protocol: upstreams.DialHTTP1,
})
})
}
func BenchmarkHTTP2TCPTunnel(b *testing.B) {
env := testenv.New(b, testenv.Silent())
env.Add(scenarios.NewIDP([]*scenarios.User{{Email: "test@example.com"}}))
up := upstreams.TCP()
h2 := up.Route().
From(env.SubdomainURL("bench-h2")).
Policy(func(p *config.Policy) {
p.AllowWebsockets = true
}).
PPL(`{"allow":{"and":["email":{"is":"test@example.com"}]}}`)
env.AddUpstream(up)
env.Start()
snippets.WaitStartupComplete(env)
b.Run("http2", func(b *testing.B) {
benchmarkTCP(b, up, h2, tcpBenchmarkParams{
msgLen: 512,
protocol: upstreams.DialHTTP2,
})
})
}
type tcpBenchmarkParams struct {
msgLen int
protocol upstreams.Protocol
}
func benchmarkTCP(b *testing.B, up upstreams.TCPUpstream, route testenv.Route, params tcpBenchmarkParams) {
sendMsg := func(c net.Conn, buf []byte) error {
c.SetWriteDeadline(time.Now().Add(1 * time.Second))
_, err := c.Write(buf)
if err != nil {
if errors.Is(err, net.ErrClosed) {
return nil
}
}
return err
}
recvMsg := func(c net.Conn, buf []byte) error {
c.SetReadDeadline(time.Now().Add(1 * time.Second))
for read := 0; read != len(buf); {
n, err := c.Read(buf)
read += n
if err != nil {
if errors.Is(err, net.ErrClosed) {
return nil
}
return err
}
}
return nil
}
up.Handle(func(_ context.Context, c net.Conn) error {
for {
buf := make([]byte, params.msgLen)
if err := recvMsg(c, buf[:]); err != nil {
if errors.Is(err, net.ErrClosed) {
return nil
}
return err
}
if err := sendMsg(c, buf[:]); err != nil {
if errors.Is(err, net.ErrClosed) {
return nil
}
return err
}
}
})
var threads atomic.Int32
var requests atomic.Int32
var bytes atomic.Int64
start := time.Now()
b.RunParallel(func(p *testing.PB) {
threads.Add(1)
require.NoError(b, up.Dial(route, func(_ context.Context, c net.Conn) error {
buf := make([]byte, params.msgLen)
for p.Next() {
requests.Add(1)
bytes.Add(int64(params.msgLen))
require.NoError(b, sendMsg(c, buf[:]))
require.NoError(b, recvMsg(c, buf[:]))
}
return nil
}, upstreams.AuthenticateAs("test@example.com"), upstreams.DialProtocol(params.protocol)))
})
duration := time.Since(start)
b.Logf("sent %d requests over %d parallel connections in %s", requests.Load(), threads.Load(), duration)
b.Logf("throughput: %f bytes/s", float64(bytes.Load())/duration.Seconds())
}
func TestHttp1Websocket(t *testing.T) {
env := testenv.New(t)
up := upstreams.HTTP(nil)
up.HandleWS("/ws", websocket.Upgrader{}, func(conn *websocket.Conn) error {
for {
mt, message, err := conn.ReadMessage()
if err != nil {
return err
}
// echo the message back
err = conn.WriteMessage(mt, message)
if err != nil {
return err
}
}
})
route := up.Route().
From(env.SubdomainURL("ws-test")).
Policy(func(p *config.Policy) {
p.AllowPublicUnauthenticatedAccess = true
p.AllowWebsockets = true
})
env.AddUpstream(up)
env.Start()
snippets.WaitStartupComplete(env)
assert.NoError(t, up.DialWS(route, func(conn *websocket.Conn) error {
if err := conn.SetWriteDeadline(time.Now().Add(1 * time.Second)); err != nil {
return err
}
if err := conn.WriteMessage(websocket.TextMessage, []byte("hello world")); err != nil {
return err
}
if err := conn.SetReadDeadline(time.Now().Add(1 * time.Second)); err != nil {
return err
}
mt, bytes, err := conn.ReadMessage()
if err := err; err != nil {
return err
}
assert.Equal(t, websocket.TextMessage, mt)
assert.Equal(t, "hello world", string(bytes))
return nil
}, upstreams.Path("/ws")))
}
func TestClientCert(t *testing.T) {
env := testenv.New(t)
env.Add(scenarios.DownstreamMTLS(config.MTLSEnforcementRejectConnection))
up := upstreams.HTTP(nil)
up.Handle("/foo", func(w http.ResponseWriter, _ *http.Request) {
fmt.Fprintln(w, "hello world")
})
clientCert := env.NewClientCert()
route := up.Route().
From(env.SubdomainURL("http")).
PPL(fmt.Sprintf(`{"allow":{"and":["client_certificate":{"fingerprint":%q}]}}`, clientCert.Fingerprint()))
env.AddUpstream(up)
env.Start()
recorder := env.NewLogRecorder()
resp, err := up.Get(route, upstreams.Path("/foo"), upstreams.ClientCert(clientCert))
require.NoError(t, err)
defer resp.Body.Close()
data, err := io.ReadAll(resp.Body)
require.NoError(t, err)
assert.Equal(t, "hello world\n", string(data))
recorder.Match([]map[string]any{
{
"service": "envoy",
"path": "/foo",
"method": "GET",
"message": "http-request",
"response-code-details": "via_upstream",
"client-certificate": clientCert,
},
})
}