mirror of
https://github.com/pomerium/pomerium.git
synced 2025-08-03 08:50:42 +02:00
add tests/benchmarks for http1/http2 tcp tunnels and http1 websockets (#5471)
* add tests/benchmarks for http1/http2 tcp tunnels and http1 websockets testenv: - add new TCP upstream - add websocket functions to HTTP upstream - add https support to mock idp (default on) - add new debug flags -env.bind-address and -env.use-trace-environ to allow changing the default bind address, and enabling otel environment based trace config, respectively * linter pass --------- Co-authored-by: Denis Mishin <dmishin@pomerium.com>
This commit is contained in:
parent
d6b02441b3
commit
08623ef346
12 changed files with 1104 additions and 182 deletions
343
internal/testenv/upstreams/tcp.go
Normal file
343
internal/testenv/upstreams/tcp.go
Normal file
|
@ -0,0 +1,343 @@
|
|||
package upstreams
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/cookiejar"
|
||||
"net/http/httptrace"
|
||||
"net/url"
|
||||
"sync"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/telemetry/trace"
|
||||
"github.com/pomerium/pomerium/internal/testenv"
|
||||
"github.com/pomerium/pomerium/internal/testenv/values"
|
||||
"go.opentelemetry.io/otel/attribute"
|
||||
oteltrace "go.opentelemetry.io/otel/trace"
|
||||
"golang.org/x/net/http2"
|
||||
)
|
||||
|
||||
type TCPUpstream interface {
|
||||
testenv.Upstream
|
||||
|
||||
Handle(fn func(context.Context, net.Conn) error)
|
||||
|
||||
Dial(r testenv.Route, fn func(context.Context, net.Conn) error, opts ...RequestOption) error
|
||||
}
|
||||
|
||||
type TCPUpstreamOptions struct {
|
||||
CommonUpstreamOptions
|
||||
}
|
||||
|
||||
type TCPUpstreamOption interface {
|
||||
applyTCP(*TCPUpstreamOptions)
|
||||
}
|
||||
|
||||
type tcpUpstream struct {
|
||||
TCPUpstreamOptions
|
||||
testenv.Aggregate
|
||||
serverPort values.MutableValue[int]
|
||||
serverHandler func(context.Context, net.Conn) error
|
||||
|
||||
serverTracerProvider values.MutableValue[oteltrace.TracerProvider]
|
||||
clientTracerProvider values.MutableValue[oteltrace.TracerProvider]
|
||||
clientTracer values.Value[oteltrace.Tracer]
|
||||
}
|
||||
|
||||
func TCP(opts ...TCPUpstreamOption) TCPUpstream {
|
||||
options := TCPUpstreamOptions{
|
||||
CommonUpstreamOptions: CommonUpstreamOptions{
|
||||
displayName: "TCP Upstream",
|
||||
},
|
||||
}
|
||||
for _, op := range opts {
|
||||
op.applyTCP(&options)
|
||||
}
|
||||
up := &tcpUpstream{
|
||||
TCPUpstreamOptions: options,
|
||||
serverPort: values.Deferred[int](),
|
||||
|
||||
serverTracerProvider: values.Deferred[oteltrace.TracerProvider](),
|
||||
clientTracerProvider: values.Deferred[oteltrace.TracerProvider](),
|
||||
}
|
||||
up.clientTracer = values.Bind(up.clientTracerProvider, func(tp oteltrace.TracerProvider) oteltrace.Tracer {
|
||||
return tp.Tracer(trace.PomeriumCoreTracer)
|
||||
})
|
||||
up.RecordCaller()
|
||||
return up
|
||||
}
|
||||
|
||||
// Dial implements TCPUpstream.
|
||||
func (t *tcpUpstream) Dial(r testenv.Route, clientHandler func(context.Context, net.Conn) error, opts ...RequestOption) error {
|
||||
options := RequestOptions{
|
||||
requestCtx: t.Env().Context(),
|
||||
dialProtocol: DialHTTP1,
|
||||
}
|
||||
options.apply(opts...)
|
||||
u, err := url.Parse(r.URL().Value())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
ctx, span := t.clientTracer.Value().Start(options.requestCtx, "tcpUpstream.Do", oteltrace.WithAttributes(
|
||||
attribute.String("protocol", string(options.dialProtocol)),
|
||||
attribute.String("url", u.String()),
|
||||
))
|
||||
if options.path != "" || options.query != nil {
|
||||
u = u.ResolveReference(&url.URL{
|
||||
Path: options.path,
|
||||
RawQuery: options.query.Encode(),
|
||||
})
|
||||
}
|
||||
if options.trace != nil {
|
||||
ctx = httptrace.WithClientTrace(ctx, options.trace)
|
||||
}
|
||||
options.requestCtx = ctx
|
||||
defer span.End()
|
||||
|
||||
var remoteConn *tls.Conn
|
||||
remoteWriter := make(chan *io.PipeWriter, 1)
|
||||
|
||||
connectURL := &url.URL{Scheme: "https", Host: u.Host, Path: u.Path}
|
||||
|
||||
var getClientFn func(context.Context) *http.Client
|
||||
var newRequestFn func(ctx context.Context) (*http.Request, error)
|
||||
switch options.dialProtocol {
|
||||
case DialHTTP1:
|
||||
getClientFn = t.h1Dialer(&options, connectURL, &remoteConn)
|
||||
newRequestFn = func(ctx context.Context) (*http.Request, error) {
|
||||
req := (&http.Request{
|
||||
Method: http.MethodConnect,
|
||||
URL: connectURL,
|
||||
Host: u.Host,
|
||||
}).WithContext(ctx)
|
||||
return req, nil
|
||||
}
|
||||
case DialHTTP2:
|
||||
getClientFn = t.h2Dialer(&options, connectURL, &remoteConn, remoteWriter)
|
||||
newRequestFn = func(ctx context.Context) (*http.Request, error) {
|
||||
req := (&http.Request{
|
||||
Method: http.MethodConnect,
|
||||
URL: connectURL,
|
||||
Host: u.Host,
|
||||
Proto: "HTTP/2",
|
||||
}).WithContext(ctx)
|
||||
return req, nil
|
||||
}
|
||||
case DialHTTP3:
|
||||
panic("not implemented")
|
||||
}
|
||||
resp, err := doAuthenticatedRequest(options.requestCtx, newRequestFn, getClientFn, &options)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
resp.Body.Close()
|
||||
return errors.New(resp.Status)
|
||||
}
|
||||
if resp.Request.URL.Path == "/oidc/auth" {
|
||||
if options.authenticateAs == "" {
|
||||
return errors.New("test bug: unexpected IDP redirect; missing AuthenticateAs option to Dial()")
|
||||
}
|
||||
return errors.New("internal test bug: unexpected IDP redirect")
|
||||
}
|
||||
|
||||
var w io.WriteCloser = remoteConn
|
||||
if options.dialProtocol == DialHTTP2 {
|
||||
w = <-remoteWriter
|
||||
}
|
||||
|
||||
conn := NewRWConn(resp.Body, w)
|
||||
defer conn.Close()
|
||||
return clientHandler(resp.Request.Context(), conn)
|
||||
}
|
||||
|
||||
func (t *tcpUpstream) h1Dialer(
|
||||
options *RequestOptions,
|
||||
connectURL *url.URL,
|
||||
remoteConn **tls.Conn,
|
||||
) func(context.Context) *http.Client {
|
||||
jar, _ := cookiejar.New(nil)
|
||||
return func(context.Context) *http.Client {
|
||||
tlsConfig := &tls.Config{
|
||||
RootCAs: t.Env().ServerCAs(),
|
||||
Certificates: options.clientCerts,
|
||||
NextProtos: []string{"http/1.1"},
|
||||
}
|
||||
client := &http.Client{
|
||||
Transport: &http.Transport{
|
||||
DisableKeepAlives: true,
|
||||
DialTLSContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
if *remoteConn != nil {
|
||||
(*remoteConn).Close()
|
||||
*remoteConn = nil
|
||||
}
|
||||
dialer := &tls.Dialer{
|
||||
Config: tlsConfig,
|
||||
}
|
||||
cc, err := dialer.DialContext(ctx, network, addr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("%w: %w", ErrRetry, err)
|
||||
}
|
||||
protocol := cc.(*tls.Conn).ConnectionState().NegotiatedProtocol
|
||||
if protocol != "http/1.1" {
|
||||
cc.Close()
|
||||
return nil, fmt.Errorf("error: unexpected TLS protocol: %s", protocol)
|
||||
}
|
||||
*remoteConn = cc.(*tls.Conn)
|
||||
return cc, nil
|
||||
},
|
||||
TLSClientConfig: tlsConfig, // important
|
||||
},
|
||||
CheckRedirect: func(req *http.Request, _ []*http.Request) error {
|
||||
if req.URL.String() == connectURL.String() && req.Method == http.MethodGet {
|
||||
req.Method = http.MethodConnect
|
||||
}
|
||||
return nil
|
||||
},
|
||||
Jar: jar,
|
||||
}
|
||||
return client
|
||||
}
|
||||
}
|
||||
|
||||
func (t *tcpUpstream) h2Dialer(
|
||||
options *RequestOptions,
|
||||
connectURL *url.URL,
|
||||
remoteConn **tls.Conn,
|
||||
writer chan<- *io.PipeWriter,
|
||||
) func(context.Context) *http.Client {
|
||||
jar, _ := cookiejar.New(nil)
|
||||
return func(context.Context) *http.Client {
|
||||
h1 := &http.Transport{
|
||||
ForceAttemptHTTP2: true,
|
||||
DisableKeepAlives: true,
|
||||
DialTLSContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
if *remoteConn != nil {
|
||||
(*remoteConn).Close()
|
||||
*remoteConn = nil
|
||||
}
|
||||
dialer := &tls.Dialer{
|
||||
Config: &tls.Config{
|
||||
RootCAs: t.Env().ServerCAs(),
|
||||
Certificates: options.clientCerts,
|
||||
NextProtos: []string{"h2"},
|
||||
},
|
||||
}
|
||||
cc, err := dialer.DialContext(ctx, network, addr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("%w: %w", ErrRetry, err)
|
||||
}
|
||||
protocol := cc.(*tls.Conn).ConnectionState().NegotiatedProtocol
|
||||
if protocol != "h2" {
|
||||
cc.Close()
|
||||
return nil, fmt.Errorf("error: unexpected TLS protocol: %s", protocol)
|
||||
}
|
||||
*remoteConn = cc.(*tls.Conn)
|
||||
|
||||
return cc, nil
|
||||
},
|
||||
TLSClientConfig: &tls.Config{
|
||||
RootCAs: t.Env().ServerCAs(),
|
||||
Certificates: options.clientCerts,
|
||||
NextProtos: []string{"h2"},
|
||||
},
|
||||
}
|
||||
if err := http2.ConfigureTransport(h1); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
client := &http.Client{
|
||||
Transport: h1,
|
||||
CheckRedirect: func(req *http.Request, _ []*http.Request) error {
|
||||
if req.URL.String() == connectURL.String() && req.Method == http.MethodGet {
|
||||
pr, pw := io.Pipe()
|
||||
req.Method = http.MethodConnect
|
||||
req.Body = pr
|
||||
req.ContentLength = -1
|
||||
writer <- pw
|
||||
}
|
||||
return nil
|
||||
},
|
||||
Jar: jar,
|
||||
}
|
||||
return client
|
||||
}
|
||||
}
|
||||
|
||||
// Handle implements TCPUpstream.
|
||||
func (t *tcpUpstream) Handle(fn func(context.Context, net.Conn) error) {
|
||||
t.serverHandler = fn
|
||||
}
|
||||
|
||||
// Port implements TCPUpstream.
|
||||
func (t *tcpUpstream) Addr() values.Value[string] {
|
||||
return values.Bind(t.serverPort, func(port int) string {
|
||||
return fmt.Sprintf("%s:%d", t.Env().Host(), port)
|
||||
})
|
||||
}
|
||||
|
||||
// Route implements TCPUpstream.
|
||||
func (t *tcpUpstream) Route() testenv.RouteStub {
|
||||
r := &testenv.TCPRoute{}
|
||||
r.To(values.Bind(t.serverPort, func(port int) string {
|
||||
return fmt.Sprintf("tcp://%s:%d", t.Env().Host(), port)
|
||||
}))
|
||||
t.Add(r)
|
||||
return r
|
||||
}
|
||||
|
||||
// Run implements TCPUpstream.
|
||||
func (t *tcpUpstream) Run(ctx context.Context) error {
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
defer cancel()
|
||||
listener, err := (&net.ListenConfig{}).Listen(ctx, "tcp", fmt.Sprintf("%s:0", t.Env().Host()))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
context.AfterFunc(ctx, func() {
|
||||
listener.Close()
|
||||
})
|
||||
t.serverPort.Resolve(listener.Addr().(*net.TCPAddr).Port)
|
||||
if t.serverTracerProviderOverride != nil {
|
||||
t.serverTracerProvider.Resolve(t.serverTracerProviderOverride)
|
||||
} else {
|
||||
t.serverTracerProvider.Resolve(trace.NewTracerProvider(ctx, t.displayName))
|
||||
}
|
||||
if t.clientTracerProviderOverride != nil {
|
||||
t.clientTracerProvider.Resolve(t.clientTracerProviderOverride)
|
||||
} else {
|
||||
t.clientTracerProvider.Resolve(trace.NewTracerProvider(ctx, "TCP Client"))
|
||||
}
|
||||
var wg sync.WaitGroup
|
||||
defer wg.Wait()
|
||||
|
||||
for {
|
||||
conn, err := listener.Accept()
|
||||
if err != nil {
|
||||
if errors.Is(err, net.ErrClosed) {
|
||||
cancel()
|
||||
return nil
|
||||
}
|
||||
continue
|
||||
}
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
if err := t.serverHandler(ctx, conn); err != nil {
|
||||
if errors.Is(err, io.EOF) {
|
||||
return
|
||||
}
|
||||
panic("server handler error: " + err.Error())
|
||||
}
|
||||
}()
|
||||
}
|
||||
}
|
||||
|
||||
var (
|
||||
_ testenv.Upstream = (*tcpUpstream)(nil)
|
||||
_ TCPUpstream = (*tcpUpstream)(nil)
|
||||
)
|
Loading…
Add table
Add a link
Reference in a new issue