add tests/benchmarks for http1/http2 tcp tunnels and http1 websockets (#5471)

* add tests/benchmarks for http1/http2 tcp tunnels and http1 websockets

testenv:
- add new TCP upstream
- add websocket functions to HTTP upstream
- add https support to mock idp (default on)
- add new debug flags -env.bind-address and -env.use-trace-environ to
  allow changing the default bind address, and enabling otel environment
  based trace config, respectively

* linter pass

---------

Co-authored-by: Denis Mishin <dmishin@pomerium.com>
This commit is contained in:
Joe Kralicky 2025-03-19 18:42:19 -04:00 committed by GitHub
parent d6b02441b3
commit 08623ef346
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
12 changed files with 1104 additions and 182 deletions

View file

@ -97,8 +97,10 @@ type service struct {
impl any
}
func (g *grpcUpstream) Port() values.Value[int] {
return g.serverPort
func (g *grpcUpstream) Addr() values.Value[string] {
return values.Bind(g.serverPort, func(port int) string {
return fmt.Sprintf("%s:%d", g.Env().Host(), port)
})
}
// RegisterService implements grpc.ServiceRegistrar.
@ -117,7 +119,7 @@ func (g *grpcUpstream) Route() testenv.RouteStub {
protocol = "https"
}
r.To(values.Bind(g.serverPort, func(port int) string {
return fmt.Sprintf("%s://127.0.0.1:%d", protocol, port)
return fmt.Sprintf("%s://%s:%d", protocol, g.Env().Host(), port)
}))
g.Add(r)
return r
@ -125,7 +127,7 @@ func (g *grpcUpstream) Route() testenv.RouteStub {
// Start implements testenv.Upstream.
func (g *grpcUpstream) Run(ctx context.Context) error {
listener, err := net.Listen("tcp", "127.0.0.1:0")
listener, err := net.Listen("tcp", fmt.Sprintf("%s:0", g.Env().Host()))
if err != nil {
return err
}
@ -187,7 +189,7 @@ func (g *grpcUpstream) Dial(r testenv.Route, dialOpts ...grpc.DialOption) *grpc.
}
func (g *grpcUpstream) DirectConnect(dialOpts ...grpc.DialOption) *grpc.ClientConn {
cc, err := grpc.NewClient(fmt.Sprintf("127.0.0.1:%d", g.Port().Value()),
cc, err := grpc.NewClient(g.Addr().Value(),
append(g.withDefaultDialOpts(dialOpts), grpc.WithTransportCredentials(insecure.NewCredentials()))...)
if err != nil {
panic(err)

View file

@ -1,10 +1,8 @@
package upstreams
import (
"bytes"
"context"
"crypto/tls"
"encoding/json"
"errors"
"fmt"
"io"
@ -13,14 +11,12 @@ import (
"net/http/cookiejar"
"net/http/httptrace"
"net/url"
"strconv"
"strings"
"os"
"sync"
"time"
"github.com/gorilla/mux"
"github.com/pomerium/pomerium/integration/forms"
"github.com/pomerium/pomerium/internal/retry"
"github.com/gorilla/websocket"
"github.com/pomerium/pomerium/internal/telemetry/trace"
"github.com/pomerium/pomerium/internal/testenv"
"github.com/pomerium/pomerium/internal/testenv/snippets"
@ -29,9 +25,15 @@ import (
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/codes"
semconv "go.opentelemetry.io/otel/semconv/v1.26.0"
oteltrace "go.opentelemetry.io/otel/trace"
"google.golang.org/protobuf/proto"
)
type Protocol string
const (
DialHTTP1 Protocol = "http/1.1"
DialHTTP2 Protocol = "h2"
DialHTTP3 Protocol = "h3"
)
type RequestOptions struct {
@ -42,12 +44,18 @@ type RequestOptions struct {
authenticateAs string
body any
clientCerts []tls.Certificate
client *http.Client
clientHook func(*http.Client) *http.Client
dialerHook func(*websocket.Dialer, *url.URL) (*websocket.Dialer, *url.URL)
dialProtocol Protocol
trace *httptrace.ClientTrace
}
type RequestOption func(*RequestOptions)
func (ro RequestOption) Format(fmt.State, rune) {
panic("test bug: request option mistakenly passed to assert function")
}
func (o *RequestOptions) apply(opts ...RequestOption) {
for _, op := range opts {
op(o)
@ -82,9 +90,38 @@ func AuthenticateAs(email string) RequestOption {
}
}
func Client(c *http.Client) RequestOption {
// ClientHook allows editing or replacing the http client before it is used.
// When any request is about to start, this function will be called with the
// client that would be used to make the request. The returned client will
// be the actual client used for that request. It can be the same as the input
// (with or without modification), or replaced entirely.
//
// Note: the Transport of the client passed to the hook will always be a
// [*Transport]. That transport's underlying transport will always be
// a [*otelhttp.Transport].
func ClientHook(f func(*http.Client) *http.Client) RequestOption {
return func(o *RequestOptions) {
o.client = c
o.clientHook = f
}
}
// DialerHook allows editing or replacing the websocket dialer before it is
// used. When a websocket request is about to start (using the DialWS method),
// this function will be called with the dialer that would be used, and the
// destination URL (including wss:// scheme, and path if one is present). The
// returned dialer+URL will be the actual dialer+URL used for that request.
//
// If ClientHook is also set, both will be called. The dialer passed to this
// hook will have its TLSClientConfig and Jar fields set from the client.
func DialerHook(f func(*websocket.Dialer, *url.URL) (*websocket.Dialer, *url.URL)) RequestOption {
return func(o *RequestOptions) {
o.dialerHook = f
}
}
func DialProtocol(protocol Protocol) RequestOption {
return func(o *RequestOptions) {
o.dialProtocol = protocol
}
}
@ -143,10 +180,12 @@ type HTTPUpstream interface {
testenv.Upstream
Handle(path string, f func(http.ResponseWriter, *http.Request)) *mux.Route
HandleWS(path string, upgrader websocket.Upgrader, f func(conn *websocket.Conn) error) *mux.Route
Get(r testenv.Route, opts ...RequestOption) (*http.Response, error)
Post(r testenv.Route, opts ...RequestOption) (*http.Response, error)
Do(method string, r testenv.Route, opts ...RequestOption) (*http.Response, error)
DialWS(r testenv.Route, f func(conn *websocket.Conn) error, opts ...RequestOption) error
}
type httpUpstream struct {
@ -194,8 +233,10 @@ func HTTP(tlsConfig values.Value[*tls.Config], opts ...HTTPUpstreamOption) HTTPU
}
// Port implements HTTPUpstream.
func (h *httpUpstream) Port() values.Value[int] {
return h.serverPort
func (h *httpUpstream) Addr() values.Value[string] {
return values.Bind(h.serverPort, func(port int) string {
return fmt.Sprintf("%s:%d", h.Env().Host(), port)
})
}
// Router implements HTTPUpstream.
@ -203,12 +244,37 @@ func (h *httpUpstream) Handle(path string, f func(http.ResponseWriter, *http.Req
return h.router.HandleFunc(path, f)
}
// Router implements HTTPUpstream.
func (h *httpUpstream) HandleWS(path string, upgrader websocket.Upgrader, f func(*websocket.Conn) error) *mux.Route {
return h.router.HandleFunc(path, func(w http.ResponseWriter, r *http.Request) {
ctx, span := trace.Continue(r.Context(), "HandleWS")
defer span.End()
c, err := upgrader.Upgrade(w, r.WithContext(ctx), nil)
if err != nil {
span.SetStatus(codes.Error, err.Error())
w.WriteHeader(http.StatusBadRequest)
_, _ = w.Write([]byte(err.Error()))
return
}
defer c.Close()
err = f(c)
if err != nil {
if errors.Is(err, io.EOF) {
return
}
span.SetStatus(codes.Error, err.Error())
fmt.Fprintf(os.Stderr, "websocket error: %s\n", err.Error())
}
})
}
// Route implements HTTPUpstream.
func (h *httpUpstream) Route() testenv.RouteStub {
r := &testenv.PolicyRoute{}
protocol := "http"
r.To(values.Bind(h.serverPort, func(port int) string {
return fmt.Sprintf("%s://127.0.0.1:%d", protocol, port)
return fmt.Sprintf("%s://%s:%d", protocol, h.Env().Host(), port)
}))
h.Add(r)
return r
@ -216,15 +282,21 @@ func (h *httpUpstream) Route() testenv.RouteStub {
// Run implements HTTPUpstream.
func (h *httpUpstream) Run(ctx context.Context) error {
listener, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
return err
var listener net.Listener
if h.tlsConfig != nil {
var err error
listener, err = tls.Listen("tcp", fmt.Sprintf("%s:0", h.Env().Host()), h.tlsConfig.Value())
if err != nil {
return err
}
} else {
var err error
listener, err = net.Listen("tcp", fmt.Sprintf("%s:0", h.Env().Host()))
if err != nil {
return err
}
}
h.serverPort.Resolve(listener.Addr().(*net.TCPAddr).Port)
var tlsConfig *tls.Config
if h.tlsConfig != nil {
tlsConfig = h.tlsConfig.Value()
}
if h.serverTracerProviderOverride != nil {
h.serverTracerProvider.Resolve(h.serverTracerProviderOverride)
} else {
@ -238,8 +310,7 @@ func (h *httpUpstream) Run(ctx context.Context) error {
h.router.Use(trace.NewHTTPMiddleware(otelhttp.WithTracerProvider(h.serverTracerProvider.Value())))
server := &http.Server{
Handler: h.router,
TLSConfig: tlsConfig,
Handler: h.router,
// BaseContext: func(net.Listener) context.Context {
// return ctx
// },
@ -277,6 +348,53 @@ func (h *httpUpstream) Post(r testenv.Route, opts ...RequestOption) (*http.Respo
return h.Do(http.MethodPost, r, opts...)
}
type Transport struct {
*otelhttp.Transport
// The underlying http.Transport instance wrapped by the otelhttp.Transport.
Base *http.Transport
}
var _ http.RoundTripper = Transport{}
func (h *httpUpstream) newClient(options *RequestOptions) *http.Client {
transport := http.DefaultTransport.(*http.Transport).Clone()
transport.TLSClientConfig = &tls.Config{
RootCAs: h.Env().ServerCAs(),
Certificates: options.clientCerts,
}
transport.DialTLSContext = nil
c := http.Client{
Transport: &Transport{
Transport: otelhttp.NewTransport(transport,
otelhttp.WithTracerProvider(h.clientTracerProvider.Value()),
otelhttp.WithSpanNameFormatter(func(_ string, r *http.Request) string {
return fmt.Sprintf("Client: %s %s", r.Method, r.URL.Path)
}),
),
Base: transport,
},
}
c.Jar, _ = cookiejar.New(&cookiejar.Options{})
return &c
}
func (h *httpUpstream) getRouteClient(r testenv.Route, options *RequestOptions) *http.Client {
span := oteltrace.SpanFromContext(options.requestCtx)
var cachedClient any
var ok bool
if cachedClient, ok = h.clientCache.Load(r); !ok {
span.AddEvent("creating new http client")
cachedClient, _ = h.clientCache.LoadOrStore(r, h.newClient(options))
} else {
span.AddEvent("using cached http client")
}
client := cachedClient.(*http.Client)
if options.clientHook != nil {
client = options.clientHook(client)
}
return client
}
// Do implements HTTPUpstream.
func (h *httpUpstream) Do(method string, r testenv.Route, opts ...RequestOption) (*http.Response, error) {
options := RequestOptions{
@ -303,141 +421,54 @@ func (h *httpUpstream) Do(method string, r testenv.Route, opts ...RequestOption)
options.requestCtx = ctx
defer span.End()
newClient := func() *http.Client {
transport := http.DefaultTransport.(*http.Transport).Clone()
transport.TLSClientConfig = &tls.Config{
RootCAs: h.Env().ServerCAs(),
Certificates: options.clientCerts,
}
transport.DialTLSContext = nil
c := http.Client{
Transport: otelhttp.NewTransport(transport,
otelhttp.WithTracerProvider(h.clientTracerProvider.Value()),
otelhttp.WithSpanNameFormatter(func(_ string, r *http.Request) string {
return fmt.Sprintf("Client: %s %s", r.Method, r.URL.Path)
}),
),
}
c.Jar, _ = cookiejar.New(&cookiejar.Options{})
return &c
}
var client *http.Client
if options.client != nil {
client = options.client
} else {
var cachedClient any
var ok bool
if cachedClient, ok = h.clientCache.Load(r); !ok {
span.AddEvent("creating new http client")
cachedClient, _ = h.clientCache.LoadOrStore(r, newClient())
} else {
span.AddEvent("using cached http client")
}
client = cachedClient.(*http.Client)
}
var resp *http.Response
resendCount := 0
if err := retry.Retry(ctx, "http", func(ctx context.Context) error {
req, err := http.NewRequestWithContext(ctx, method, u.String(), nil)
if err != nil {
return retry.NewTerminalError(err)
}
switch body := options.body.(type) {
case string:
req.Body = io.NopCloser(strings.NewReader(body))
case []byte:
req.Body = io.NopCloser(bytes.NewReader(body))
case io.Reader:
req.Body = io.NopCloser(body)
case proto.Message:
buf, err := proto.Marshal(body)
if err != nil {
return retry.NewTerminalError(err)
}
req.Body = io.NopCloser(bytes.NewReader(buf))
req.Header.Set("Content-Type", "application/octet-stream")
default:
buf, err := json.Marshal(body)
if err != nil {
panic(fmt.Sprintf("unsupported body type: %T", body))
}
req.Body = io.NopCloser(bytes.NewReader(buf))
req.Header.Set("Content-Type", "application/json")
case nil:
}
if options.authenticateAs != "" {
resp, err = authenticateFlow(ctx, client, req, options.authenticateAs) //nolint:bodyclose
} else {
resp, err = client.Do(req) //nolint:bodyclose
}
// retry on connection refused
if err != nil {
span.RecordError(err)
var opErr *net.OpError
if errors.As(err, &opErr) && opErr.Op == "dial" && opErr.Err.Error() == "connect: connection refused" {
span.AddEvent("Retrying on dial error")
return err
}
return retry.NewTerminalError(err)
}
if resp.StatusCode/100 == 5 {
resendCount++
_, _ = io.ReadAll(resp.Body)
_ = resp.Body.Close()
span.SetAttributes(semconv.HTTPRequestResendCount(resendCount))
span.AddEvent("Retrying on 5xx error", oteltrace.WithAttributes(
attribute.String("status", resp.Status),
))
return errors.New(http.StatusText(resp.StatusCode))
}
span.SetStatus(codes.Ok, "request completed successfully")
return nil
},
retry.WithInitialInterval(1*time.Millisecond),
retry.WithMaxInterval(100*time.Millisecond),
); err != nil {
return nil, err
}
return resp, nil
return doAuthenticatedRequest(options.requestCtx,
func(ctx context.Context) (*http.Request, error) {
return http.NewRequestWithContext(ctx, method, u.String(), nil)
},
func(context.Context) *http.Client {
return h.getRouteClient(r, &options)
},
&options,
)
}
func authenticateFlow(ctx context.Context, client *http.Client, req *http.Request, email string) (*http.Response, error) {
span := oteltrace.SpanFromContext(ctx)
var res *http.Response
originalHostname := req.URL.Hostname()
res, err := client.Do(req)
func (h *httpUpstream) DialWS(r testenv.Route, f func(conn *websocket.Conn) error, opts ...RequestOption) error {
options := RequestOptions{
requestCtx: h.Env().Context(),
}
options.apply(opts...)
u, err := url.Parse(r.URL().Value())
if err != nil {
span.RecordError(err)
return nil, err
return err
}
location := res.Request.URL
if location.Hostname() == originalHostname {
// already authenticated
span.SetStatus(codes.Ok, "already authenticated")
return res, nil
u.Scheme = "wss"
if options.path != "" || options.query != nil {
u = u.ResolveReference(&url.URL{
Path: options.path,
RawQuery: options.query.Encode(),
})
}
fs := forms.Parse(res.Body)
_, _ = io.ReadAll(res.Body)
_ = res.Body.Close()
if len(fs) > 0 {
f := fs[0]
f.Inputs["email"] = email
f.Inputs["token_expiration"] = strconv.Itoa(int((time.Hour * 24).Seconds()))
span.AddEvent("submitting form", oteltrace.WithAttributes(attribute.String("location", location.String())))
formReq, err := f.NewRequestWithContext(ctx, location)
if err != nil {
span.RecordError(err)
return nil, err
}
resp, err := client.Do(formReq)
if err != nil {
span.RecordError(err)
return nil, err
}
span.SetStatus(codes.Ok, "form submitted successfully")
return resp, nil
ctx, span := h.clientTracer.Value().Start(options.requestCtx, "httpUpstream.Dial", oteltrace.WithAttributes(
attribute.String("url", u.String()),
))
options.requestCtx = ctx
defer span.End()
client := h.getRouteClient(r, &options)
d := &websocket.Dialer{
HandshakeTimeout: 10 * time.Second,
TLSClientConfig: client.Transport.(*Transport).Base.TLSClientConfig,
Jar: client.Jar,
}
return nil, fmt.Errorf("test bug: expected IDP login form")
if options.dialerHook != nil {
d, u = options.dialerHook(d, u)
}
conn, resp, err := d.DialContext(options.requestCtx, u.String(), nil)
if err != nil {
resp.Body.Close()
return fmt.Errorf("DialContext: %w", err)
}
defer conn.Close()
return f(conn)
}

View file

@ -15,6 +15,7 @@ type CommonUpstreamOptions struct {
type CommonUpstreamOption interface {
GRPCUpstreamOption
HTTPUpstreamOption
TCPUpstreamOption
}
type commonUpstreamOption func(o *CommonUpstreamOptions)
@ -25,6 +26,9 @@ func (c commonUpstreamOption) applyGRPC(o *GRPCUpstreamOptions) { c(&o.CommonUps
// applyHTTP implements CommonUpstreamOption.
func (c commonUpstreamOption) applyHTTP(o *HTTPUpstreamOptions) { c(&o.CommonUpstreamOptions) }
// applyTCP implements CommonUpstreamOption.
func (c commonUpstreamOption) applyTCP(o *TCPUpstreamOptions) { c(&o.CommonUpstreamOptions) }
func WithDisplayName(displayName string) CommonUpstreamOption {
return commonUpstreamOption(func(o *CommonUpstreamOptions) {
o.displayName = displayName

View file

@ -0,0 +1,343 @@
package upstreams
import (
"context"
"crypto/tls"
"errors"
"fmt"
"io"
"net"
"net/http"
"net/http/cookiejar"
"net/http/httptrace"
"net/url"
"sync"
"github.com/pomerium/pomerium/internal/telemetry/trace"
"github.com/pomerium/pomerium/internal/testenv"
"github.com/pomerium/pomerium/internal/testenv/values"
"go.opentelemetry.io/otel/attribute"
oteltrace "go.opentelemetry.io/otel/trace"
"golang.org/x/net/http2"
)
type TCPUpstream interface {
testenv.Upstream
Handle(fn func(context.Context, net.Conn) error)
Dial(r testenv.Route, fn func(context.Context, net.Conn) error, opts ...RequestOption) error
}
type TCPUpstreamOptions struct {
CommonUpstreamOptions
}
type TCPUpstreamOption interface {
applyTCP(*TCPUpstreamOptions)
}
type tcpUpstream struct {
TCPUpstreamOptions
testenv.Aggregate
serverPort values.MutableValue[int]
serverHandler func(context.Context, net.Conn) error
serverTracerProvider values.MutableValue[oteltrace.TracerProvider]
clientTracerProvider values.MutableValue[oteltrace.TracerProvider]
clientTracer values.Value[oteltrace.Tracer]
}
func TCP(opts ...TCPUpstreamOption) TCPUpstream {
options := TCPUpstreamOptions{
CommonUpstreamOptions: CommonUpstreamOptions{
displayName: "TCP Upstream",
},
}
for _, op := range opts {
op.applyTCP(&options)
}
up := &tcpUpstream{
TCPUpstreamOptions: options,
serverPort: values.Deferred[int](),
serverTracerProvider: values.Deferred[oteltrace.TracerProvider](),
clientTracerProvider: values.Deferred[oteltrace.TracerProvider](),
}
up.clientTracer = values.Bind(up.clientTracerProvider, func(tp oteltrace.TracerProvider) oteltrace.Tracer {
return tp.Tracer(trace.PomeriumCoreTracer)
})
up.RecordCaller()
return up
}
// Dial implements TCPUpstream.
func (t *tcpUpstream) Dial(r testenv.Route, clientHandler func(context.Context, net.Conn) error, opts ...RequestOption) error {
options := RequestOptions{
requestCtx: t.Env().Context(),
dialProtocol: DialHTTP1,
}
options.apply(opts...)
u, err := url.Parse(r.URL().Value())
if err != nil {
return err
}
ctx, span := t.clientTracer.Value().Start(options.requestCtx, "tcpUpstream.Do", oteltrace.WithAttributes(
attribute.String("protocol", string(options.dialProtocol)),
attribute.String("url", u.String()),
))
if options.path != "" || options.query != nil {
u = u.ResolveReference(&url.URL{
Path: options.path,
RawQuery: options.query.Encode(),
})
}
if options.trace != nil {
ctx = httptrace.WithClientTrace(ctx, options.trace)
}
options.requestCtx = ctx
defer span.End()
var remoteConn *tls.Conn
remoteWriter := make(chan *io.PipeWriter, 1)
connectURL := &url.URL{Scheme: "https", Host: u.Host, Path: u.Path}
var getClientFn func(context.Context) *http.Client
var newRequestFn func(ctx context.Context) (*http.Request, error)
switch options.dialProtocol {
case DialHTTP1:
getClientFn = t.h1Dialer(&options, connectURL, &remoteConn)
newRequestFn = func(ctx context.Context) (*http.Request, error) {
req := (&http.Request{
Method: http.MethodConnect,
URL: connectURL,
Host: u.Host,
}).WithContext(ctx)
return req, nil
}
case DialHTTP2:
getClientFn = t.h2Dialer(&options, connectURL, &remoteConn, remoteWriter)
newRequestFn = func(ctx context.Context) (*http.Request, error) {
req := (&http.Request{
Method: http.MethodConnect,
URL: connectURL,
Host: u.Host,
Proto: "HTTP/2",
}).WithContext(ctx)
return req, nil
}
case DialHTTP3:
panic("not implemented")
}
resp, err := doAuthenticatedRequest(options.requestCtx, newRequestFn, getClientFn, &options)
if err != nil {
return err
}
if resp.StatusCode != http.StatusOK {
resp.Body.Close()
return errors.New(resp.Status)
}
if resp.Request.URL.Path == "/oidc/auth" {
if options.authenticateAs == "" {
return errors.New("test bug: unexpected IDP redirect; missing AuthenticateAs option to Dial()")
}
return errors.New("internal test bug: unexpected IDP redirect")
}
var w io.WriteCloser = remoteConn
if options.dialProtocol == DialHTTP2 {
w = <-remoteWriter
}
conn := NewRWConn(resp.Body, w)
defer conn.Close()
return clientHandler(resp.Request.Context(), conn)
}
func (t *tcpUpstream) h1Dialer(
options *RequestOptions,
connectURL *url.URL,
remoteConn **tls.Conn,
) func(context.Context) *http.Client {
jar, _ := cookiejar.New(nil)
return func(context.Context) *http.Client {
tlsConfig := &tls.Config{
RootCAs: t.Env().ServerCAs(),
Certificates: options.clientCerts,
NextProtos: []string{"http/1.1"},
}
client := &http.Client{
Transport: &http.Transport{
DisableKeepAlives: true,
DialTLSContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
if *remoteConn != nil {
(*remoteConn).Close()
*remoteConn = nil
}
dialer := &tls.Dialer{
Config: tlsConfig,
}
cc, err := dialer.DialContext(ctx, network, addr)
if err != nil {
return nil, fmt.Errorf("%w: %w", ErrRetry, err)
}
protocol := cc.(*tls.Conn).ConnectionState().NegotiatedProtocol
if protocol != "http/1.1" {
cc.Close()
return nil, fmt.Errorf("error: unexpected TLS protocol: %s", protocol)
}
*remoteConn = cc.(*tls.Conn)
return cc, nil
},
TLSClientConfig: tlsConfig, // important
},
CheckRedirect: func(req *http.Request, _ []*http.Request) error {
if req.URL.String() == connectURL.String() && req.Method == http.MethodGet {
req.Method = http.MethodConnect
}
return nil
},
Jar: jar,
}
return client
}
}
func (t *tcpUpstream) h2Dialer(
options *RequestOptions,
connectURL *url.URL,
remoteConn **tls.Conn,
writer chan<- *io.PipeWriter,
) func(context.Context) *http.Client {
jar, _ := cookiejar.New(nil)
return func(context.Context) *http.Client {
h1 := &http.Transport{
ForceAttemptHTTP2: true,
DisableKeepAlives: true,
DialTLSContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
if *remoteConn != nil {
(*remoteConn).Close()
*remoteConn = nil
}
dialer := &tls.Dialer{
Config: &tls.Config{
RootCAs: t.Env().ServerCAs(),
Certificates: options.clientCerts,
NextProtos: []string{"h2"},
},
}
cc, err := dialer.DialContext(ctx, network, addr)
if err != nil {
return nil, fmt.Errorf("%w: %w", ErrRetry, err)
}
protocol := cc.(*tls.Conn).ConnectionState().NegotiatedProtocol
if protocol != "h2" {
cc.Close()
return nil, fmt.Errorf("error: unexpected TLS protocol: %s", protocol)
}
*remoteConn = cc.(*tls.Conn)
return cc, nil
},
TLSClientConfig: &tls.Config{
RootCAs: t.Env().ServerCAs(),
Certificates: options.clientCerts,
NextProtos: []string{"h2"},
},
}
if err := http2.ConfigureTransport(h1); err != nil {
panic(err)
}
client := &http.Client{
Transport: h1,
CheckRedirect: func(req *http.Request, _ []*http.Request) error {
if req.URL.String() == connectURL.String() && req.Method == http.MethodGet {
pr, pw := io.Pipe()
req.Method = http.MethodConnect
req.Body = pr
req.ContentLength = -1
writer <- pw
}
return nil
},
Jar: jar,
}
return client
}
}
// Handle implements TCPUpstream.
func (t *tcpUpstream) Handle(fn func(context.Context, net.Conn) error) {
t.serverHandler = fn
}
// Port implements TCPUpstream.
func (t *tcpUpstream) Addr() values.Value[string] {
return values.Bind(t.serverPort, func(port int) string {
return fmt.Sprintf("%s:%d", t.Env().Host(), port)
})
}
// Route implements TCPUpstream.
func (t *tcpUpstream) Route() testenv.RouteStub {
r := &testenv.TCPRoute{}
r.To(values.Bind(t.serverPort, func(port int) string {
return fmt.Sprintf("tcp://%s:%d", t.Env().Host(), port)
}))
t.Add(r)
return r
}
// Run implements TCPUpstream.
func (t *tcpUpstream) Run(ctx context.Context) error {
ctx, cancel := context.WithCancel(ctx)
defer cancel()
listener, err := (&net.ListenConfig{}).Listen(ctx, "tcp", fmt.Sprintf("%s:0", t.Env().Host()))
if err != nil {
return err
}
context.AfterFunc(ctx, func() {
listener.Close()
})
t.serverPort.Resolve(listener.Addr().(*net.TCPAddr).Port)
if t.serverTracerProviderOverride != nil {
t.serverTracerProvider.Resolve(t.serverTracerProviderOverride)
} else {
t.serverTracerProvider.Resolve(trace.NewTracerProvider(ctx, t.displayName))
}
if t.clientTracerProviderOverride != nil {
t.clientTracerProvider.Resolve(t.clientTracerProviderOverride)
} else {
t.clientTracerProvider.Resolve(trace.NewTracerProvider(ctx, "TCP Client"))
}
var wg sync.WaitGroup
defer wg.Wait()
for {
conn, err := listener.Accept()
if err != nil {
if errors.Is(err, net.ErrClosed) {
cancel()
return nil
}
continue
}
wg.Add(1)
go func() {
defer wg.Done()
if err := t.serverHandler(ctx, conn); err != nil {
if errors.Is(err, io.EOF) {
return
}
panic("server handler error: " + err.Error())
}
}()
}
}
var (
_ testenv.Upstream = (*tcpUpstream)(nil)
_ TCPUpstream = (*tcpUpstream)(nil)
)

View file

@ -0,0 +1,195 @@
package upstreams
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net"
"net/http"
"strconv"
"strings"
"sync"
"time"
"github.com/pomerium/pomerium/integration/forms"
"github.com/pomerium/pomerium/internal/retry"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/codes"
semconv "go.opentelemetry.io/otel/semconv/v1.26.0"
oteltrace "go.opentelemetry.io/otel/trace"
"google.golang.org/protobuf/proto"
)
var ErrRetry = errors.New("error")
func doAuthenticatedRequest(
ctx context.Context,
newRequest func(context.Context) (*http.Request, error),
getClient func(context.Context) *http.Client,
options *RequestOptions,
) (*http.Response, error) {
var resp *http.Response
resendCount := 0
client := getClient(ctx)
if err := retry.Retry(ctx, "http", func(ctx context.Context) error {
req, err := newRequest(ctx)
if err != nil {
return retry.NewTerminalError(err)
}
switch body := options.body.(type) {
case string:
req.Body = io.NopCloser(strings.NewReader(body))
case []byte:
req.Body = io.NopCloser(bytes.NewReader(body))
case io.Reader:
req.Body = io.NopCloser(body)
case proto.Message:
buf, err := proto.Marshal(body)
if err != nil {
return retry.NewTerminalError(err)
}
req.Body = io.NopCloser(bytes.NewReader(buf))
req.Header.Set("Content-Type", "application/octet-stream")
default:
buf, err := json.Marshal(body)
if err != nil {
panic(fmt.Sprintf("unsupported body type: %T", body))
}
req.Body = io.NopCloser(bytes.NewReader(buf))
req.Header.Set("Content-Type", "application/json")
case nil:
}
if options.headers != nil && req.Header == nil {
req.Header = http.Header{}
}
for k, v := range options.headers {
req.Header.Add(k, v)
}
if options.authenticateAs != "" {
resp, err = authenticateFlow(ctx, client, req, options.authenticateAs, true) //nolint:bodyclose
} else {
resp, err = client.Do(req) //nolint:bodyclose
}
// retry on connection refused
span := oteltrace.SpanFromContext(ctx)
if err != nil {
span.RecordError(err)
var opErr *net.OpError
if errors.As(err, &opErr) && opErr.Op == "dial" && opErr.Err.Error() == "connect: connection refused" {
span.AddEvent("Retrying on dial error")
return err
}
return retry.NewTerminalError(err)
}
if resp.StatusCode/100 == 5 {
resendCount++
_, _ = io.ReadAll(resp.Body)
_ = resp.Body.Close()
span.SetAttributes(semconv.HTTPRequestResendCount(resendCount))
span.AddEvent("Retrying on 5xx error", oteltrace.WithAttributes(
attribute.String("status", resp.Status),
))
return errors.New(http.StatusText(resp.StatusCode))
}
span.SetStatus(codes.Ok, "request completed successfully")
return nil
},
retry.WithInitialInterval(1*time.Millisecond),
retry.WithMaxInterval(100*time.Millisecond),
); err != nil {
return nil, err
}
return resp, nil
}
func authenticateFlow(ctx context.Context, client *http.Client, req *http.Request, email string, checkLocation bool) (*http.Response, error) {
span := oteltrace.SpanFromContext(ctx)
var res *http.Response
originalHostname := req.URL.Hostname()
res, err := client.Do(req)
if err != nil {
span.RecordError(err)
return nil, err
}
location := res.Request.URL
if checkLocation && location.Hostname() == originalHostname {
// already authenticated
span.SetStatus(codes.Ok, "already authenticated")
return res, nil
}
fs := forms.Parse(res.Body)
_, _ = io.ReadAll(res.Body)
_ = res.Body.Close()
if len(fs) > 0 {
f := fs[0]
f.Inputs["email"] = email
f.Inputs["token_expiration"] = strconv.Itoa(int((time.Hour * 24).Seconds()))
span.AddEvent("submitting form", oteltrace.WithAttributes(attribute.String("location", location.String())))
formReq, err := f.NewRequestWithContext(ctx, location)
if err != nil {
span.RecordError(err)
return nil, err
}
resp, err := client.Do(formReq)
if err != nil {
span.RecordError(err)
return nil, err
}
span.SetStatus(codes.Ok, "form submitted successfully")
return resp, nil
}
return nil, fmt.Errorf("test bug: expected IDP login form")
}
type rwConn struct {
serverReader io.ReadCloser
serverWriter io.WriteCloser
net.Conn
remote net.Conn
closeOnce sync.Once
wg *sync.WaitGroup
}
func NewRWConn(reader io.ReadCloser, writer io.WriteCloser) net.Conn {
rwc := &rwConn{
serverReader: reader,
serverWriter: writer,
wg: &sync.WaitGroup{},
}
rwc.Conn, rwc.remote = net.Pipe()
rwc.wg.Add(2)
go func() {
defer rwc.wg.Done()
_, _ = io.Copy(rwc.remote, rwc.serverReader)
rwc.remote.Close()
}()
go func() {
defer rwc.wg.Done()
_, _ = io.Copy(rwc.serverWriter, rwc.remote)
rwc.serverWriter.Close()
}()
return rwc
}
func (rwc *rwConn) Close() error {
var err error
rwc.closeOnce.Do(func() {
readerErr := rwc.serverReader.Close()
localErr := rwc.Conn.Close()
rwc.wg.Wait()
err = errors.Join(localErr, readerErr)
})
return err
}
var _ net.Conn = (*rwConn)(nil)