mirror of
https://github.com/pomerium/pomerium.git
synced 2025-08-03 08:50:42 +02:00
add tests/benchmarks for http1/http2 tcp tunnels and http1 websockets (#5471)
* add tests/benchmarks for http1/http2 tcp tunnels and http1 websockets testenv: - add new TCP upstream - add websocket functions to HTTP upstream - add https support to mock idp (default on) - add new debug flags -env.bind-address and -env.use-trace-environ to allow changing the default bind address, and enabling otel environment based trace config, respectively * linter pass --------- Co-authored-by: Denis Mishin <dmishin@pomerium.com>
This commit is contained in:
parent
d6b02441b3
commit
08623ef346
12 changed files with 1104 additions and 182 deletions
|
@ -1,10 +1,8 @@
|
|||
package upstreams
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
|
@ -13,14 +11,12 @@ import (
|
|||
"net/http/cookiejar"
|
||||
"net/http/httptrace"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
"os"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/pomerium/pomerium/integration/forms"
|
||||
"github.com/pomerium/pomerium/internal/retry"
|
||||
"github.com/gorilla/websocket"
|
||||
"github.com/pomerium/pomerium/internal/telemetry/trace"
|
||||
"github.com/pomerium/pomerium/internal/testenv"
|
||||
"github.com/pomerium/pomerium/internal/testenv/snippets"
|
||||
|
@ -29,9 +25,15 @@ import (
|
|||
|
||||
"go.opentelemetry.io/otel/attribute"
|
||||
"go.opentelemetry.io/otel/codes"
|
||||
semconv "go.opentelemetry.io/otel/semconv/v1.26.0"
|
||||
oteltrace "go.opentelemetry.io/otel/trace"
|
||||
"google.golang.org/protobuf/proto"
|
||||
)
|
||||
|
||||
type Protocol string
|
||||
|
||||
const (
|
||||
DialHTTP1 Protocol = "http/1.1"
|
||||
DialHTTP2 Protocol = "h2"
|
||||
DialHTTP3 Protocol = "h3"
|
||||
)
|
||||
|
||||
type RequestOptions struct {
|
||||
|
@ -42,12 +44,18 @@ type RequestOptions struct {
|
|||
authenticateAs string
|
||||
body any
|
||||
clientCerts []tls.Certificate
|
||||
client *http.Client
|
||||
clientHook func(*http.Client) *http.Client
|
||||
dialerHook func(*websocket.Dialer, *url.URL) (*websocket.Dialer, *url.URL)
|
||||
dialProtocol Protocol
|
||||
trace *httptrace.ClientTrace
|
||||
}
|
||||
|
||||
type RequestOption func(*RequestOptions)
|
||||
|
||||
func (ro RequestOption) Format(fmt.State, rune) {
|
||||
panic("test bug: request option mistakenly passed to assert function")
|
||||
}
|
||||
|
||||
func (o *RequestOptions) apply(opts ...RequestOption) {
|
||||
for _, op := range opts {
|
||||
op(o)
|
||||
|
@ -82,9 +90,38 @@ func AuthenticateAs(email string) RequestOption {
|
|||
}
|
||||
}
|
||||
|
||||
func Client(c *http.Client) RequestOption {
|
||||
// ClientHook allows editing or replacing the http client before it is used.
|
||||
// When any request is about to start, this function will be called with the
|
||||
// client that would be used to make the request. The returned client will
|
||||
// be the actual client used for that request. It can be the same as the input
|
||||
// (with or without modification), or replaced entirely.
|
||||
//
|
||||
// Note: the Transport of the client passed to the hook will always be a
|
||||
// [*Transport]. That transport's underlying transport will always be
|
||||
// a [*otelhttp.Transport].
|
||||
func ClientHook(f func(*http.Client) *http.Client) RequestOption {
|
||||
return func(o *RequestOptions) {
|
||||
o.client = c
|
||||
o.clientHook = f
|
||||
}
|
||||
}
|
||||
|
||||
// DialerHook allows editing or replacing the websocket dialer before it is
|
||||
// used. When a websocket request is about to start (using the DialWS method),
|
||||
// this function will be called with the dialer that would be used, and the
|
||||
// destination URL (including wss:// scheme, and path if one is present). The
|
||||
// returned dialer+URL will be the actual dialer+URL used for that request.
|
||||
//
|
||||
// If ClientHook is also set, both will be called. The dialer passed to this
|
||||
// hook will have its TLSClientConfig and Jar fields set from the client.
|
||||
func DialerHook(f func(*websocket.Dialer, *url.URL) (*websocket.Dialer, *url.URL)) RequestOption {
|
||||
return func(o *RequestOptions) {
|
||||
o.dialerHook = f
|
||||
}
|
||||
}
|
||||
|
||||
func DialProtocol(protocol Protocol) RequestOption {
|
||||
return func(o *RequestOptions) {
|
||||
o.dialProtocol = protocol
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -143,10 +180,12 @@ type HTTPUpstream interface {
|
|||
testenv.Upstream
|
||||
|
||||
Handle(path string, f func(http.ResponseWriter, *http.Request)) *mux.Route
|
||||
HandleWS(path string, upgrader websocket.Upgrader, f func(conn *websocket.Conn) error) *mux.Route
|
||||
|
||||
Get(r testenv.Route, opts ...RequestOption) (*http.Response, error)
|
||||
Post(r testenv.Route, opts ...RequestOption) (*http.Response, error)
|
||||
Do(method string, r testenv.Route, opts ...RequestOption) (*http.Response, error)
|
||||
DialWS(r testenv.Route, f func(conn *websocket.Conn) error, opts ...RequestOption) error
|
||||
}
|
||||
|
||||
type httpUpstream struct {
|
||||
|
@ -194,8 +233,10 @@ func HTTP(tlsConfig values.Value[*tls.Config], opts ...HTTPUpstreamOption) HTTPU
|
|||
}
|
||||
|
||||
// Port implements HTTPUpstream.
|
||||
func (h *httpUpstream) Port() values.Value[int] {
|
||||
return h.serverPort
|
||||
func (h *httpUpstream) Addr() values.Value[string] {
|
||||
return values.Bind(h.serverPort, func(port int) string {
|
||||
return fmt.Sprintf("%s:%d", h.Env().Host(), port)
|
||||
})
|
||||
}
|
||||
|
||||
// Router implements HTTPUpstream.
|
||||
|
@ -203,12 +244,37 @@ func (h *httpUpstream) Handle(path string, f func(http.ResponseWriter, *http.Req
|
|||
return h.router.HandleFunc(path, f)
|
||||
}
|
||||
|
||||
// Router implements HTTPUpstream.
|
||||
func (h *httpUpstream) HandleWS(path string, upgrader websocket.Upgrader, f func(*websocket.Conn) error) *mux.Route {
|
||||
return h.router.HandleFunc(path, func(w http.ResponseWriter, r *http.Request) {
|
||||
ctx, span := trace.Continue(r.Context(), "HandleWS")
|
||||
defer span.End()
|
||||
c, err := upgrader.Upgrade(w, r.WithContext(ctx), nil)
|
||||
if err != nil {
|
||||
span.SetStatus(codes.Error, err.Error())
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
_, _ = w.Write([]byte(err.Error()))
|
||||
return
|
||||
}
|
||||
defer c.Close()
|
||||
|
||||
err = f(c)
|
||||
if err != nil {
|
||||
if errors.Is(err, io.EOF) {
|
||||
return
|
||||
}
|
||||
span.SetStatus(codes.Error, err.Error())
|
||||
fmt.Fprintf(os.Stderr, "websocket error: %s\n", err.Error())
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// Route implements HTTPUpstream.
|
||||
func (h *httpUpstream) Route() testenv.RouteStub {
|
||||
r := &testenv.PolicyRoute{}
|
||||
protocol := "http"
|
||||
r.To(values.Bind(h.serverPort, func(port int) string {
|
||||
return fmt.Sprintf("%s://127.0.0.1:%d", protocol, port)
|
||||
return fmt.Sprintf("%s://%s:%d", protocol, h.Env().Host(), port)
|
||||
}))
|
||||
h.Add(r)
|
||||
return r
|
||||
|
@ -216,15 +282,21 @@ func (h *httpUpstream) Route() testenv.RouteStub {
|
|||
|
||||
// Run implements HTTPUpstream.
|
||||
func (h *httpUpstream) Run(ctx context.Context) error {
|
||||
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
return err
|
||||
var listener net.Listener
|
||||
if h.tlsConfig != nil {
|
||||
var err error
|
||||
listener, err = tls.Listen("tcp", fmt.Sprintf("%s:0", h.Env().Host()), h.tlsConfig.Value())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
var err error
|
||||
listener, err = net.Listen("tcp", fmt.Sprintf("%s:0", h.Env().Host()))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
h.serverPort.Resolve(listener.Addr().(*net.TCPAddr).Port)
|
||||
var tlsConfig *tls.Config
|
||||
if h.tlsConfig != nil {
|
||||
tlsConfig = h.tlsConfig.Value()
|
||||
}
|
||||
if h.serverTracerProviderOverride != nil {
|
||||
h.serverTracerProvider.Resolve(h.serverTracerProviderOverride)
|
||||
} else {
|
||||
|
@ -238,8 +310,7 @@ func (h *httpUpstream) Run(ctx context.Context) error {
|
|||
h.router.Use(trace.NewHTTPMiddleware(otelhttp.WithTracerProvider(h.serverTracerProvider.Value())))
|
||||
|
||||
server := &http.Server{
|
||||
Handler: h.router,
|
||||
TLSConfig: tlsConfig,
|
||||
Handler: h.router,
|
||||
// BaseContext: func(net.Listener) context.Context {
|
||||
// return ctx
|
||||
// },
|
||||
|
@ -277,6 +348,53 @@ func (h *httpUpstream) Post(r testenv.Route, opts ...RequestOption) (*http.Respo
|
|||
return h.Do(http.MethodPost, r, opts...)
|
||||
}
|
||||
|
||||
type Transport struct {
|
||||
*otelhttp.Transport
|
||||
// The underlying http.Transport instance wrapped by the otelhttp.Transport.
|
||||
Base *http.Transport
|
||||
}
|
||||
|
||||
var _ http.RoundTripper = Transport{}
|
||||
|
||||
func (h *httpUpstream) newClient(options *RequestOptions) *http.Client {
|
||||
transport := http.DefaultTransport.(*http.Transport).Clone()
|
||||
transport.TLSClientConfig = &tls.Config{
|
||||
RootCAs: h.Env().ServerCAs(),
|
||||
Certificates: options.clientCerts,
|
||||
}
|
||||
transport.DialTLSContext = nil
|
||||
c := http.Client{
|
||||
Transport: &Transport{
|
||||
Transport: otelhttp.NewTransport(transport,
|
||||
otelhttp.WithTracerProvider(h.clientTracerProvider.Value()),
|
||||
otelhttp.WithSpanNameFormatter(func(_ string, r *http.Request) string {
|
||||
return fmt.Sprintf("Client: %s %s", r.Method, r.URL.Path)
|
||||
}),
|
||||
),
|
||||
Base: transport,
|
||||
},
|
||||
}
|
||||
c.Jar, _ = cookiejar.New(&cookiejar.Options{})
|
||||
return &c
|
||||
}
|
||||
|
||||
func (h *httpUpstream) getRouteClient(r testenv.Route, options *RequestOptions) *http.Client {
|
||||
span := oteltrace.SpanFromContext(options.requestCtx)
|
||||
var cachedClient any
|
||||
var ok bool
|
||||
if cachedClient, ok = h.clientCache.Load(r); !ok {
|
||||
span.AddEvent("creating new http client")
|
||||
cachedClient, _ = h.clientCache.LoadOrStore(r, h.newClient(options))
|
||||
} else {
|
||||
span.AddEvent("using cached http client")
|
||||
}
|
||||
client := cachedClient.(*http.Client)
|
||||
if options.clientHook != nil {
|
||||
client = options.clientHook(client)
|
||||
}
|
||||
return client
|
||||
}
|
||||
|
||||
// Do implements HTTPUpstream.
|
||||
func (h *httpUpstream) Do(method string, r testenv.Route, opts ...RequestOption) (*http.Response, error) {
|
||||
options := RequestOptions{
|
||||
|
@ -303,141 +421,54 @@ func (h *httpUpstream) Do(method string, r testenv.Route, opts ...RequestOption)
|
|||
options.requestCtx = ctx
|
||||
defer span.End()
|
||||
|
||||
newClient := func() *http.Client {
|
||||
transport := http.DefaultTransport.(*http.Transport).Clone()
|
||||
transport.TLSClientConfig = &tls.Config{
|
||||
RootCAs: h.Env().ServerCAs(),
|
||||
Certificates: options.clientCerts,
|
||||
}
|
||||
transport.DialTLSContext = nil
|
||||
c := http.Client{
|
||||
Transport: otelhttp.NewTransport(transport,
|
||||
otelhttp.WithTracerProvider(h.clientTracerProvider.Value()),
|
||||
otelhttp.WithSpanNameFormatter(func(_ string, r *http.Request) string {
|
||||
return fmt.Sprintf("Client: %s %s", r.Method, r.URL.Path)
|
||||
}),
|
||||
),
|
||||
}
|
||||
c.Jar, _ = cookiejar.New(&cookiejar.Options{})
|
||||
return &c
|
||||
}
|
||||
var client *http.Client
|
||||
if options.client != nil {
|
||||
client = options.client
|
||||
} else {
|
||||
var cachedClient any
|
||||
var ok bool
|
||||
if cachedClient, ok = h.clientCache.Load(r); !ok {
|
||||
span.AddEvent("creating new http client")
|
||||
cachedClient, _ = h.clientCache.LoadOrStore(r, newClient())
|
||||
} else {
|
||||
span.AddEvent("using cached http client")
|
||||
}
|
||||
client = cachedClient.(*http.Client)
|
||||
}
|
||||
|
||||
var resp *http.Response
|
||||
resendCount := 0
|
||||
if err := retry.Retry(ctx, "http", func(ctx context.Context) error {
|
||||
req, err := http.NewRequestWithContext(ctx, method, u.String(), nil)
|
||||
if err != nil {
|
||||
return retry.NewTerminalError(err)
|
||||
}
|
||||
switch body := options.body.(type) {
|
||||
case string:
|
||||
req.Body = io.NopCloser(strings.NewReader(body))
|
||||
case []byte:
|
||||
req.Body = io.NopCloser(bytes.NewReader(body))
|
||||
case io.Reader:
|
||||
req.Body = io.NopCloser(body)
|
||||
case proto.Message:
|
||||
buf, err := proto.Marshal(body)
|
||||
if err != nil {
|
||||
return retry.NewTerminalError(err)
|
||||
}
|
||||
req.Body = io.NopCloser(bytes.NewReader(buf))
|
||||
req.Header.Set("Content-Type", "application/octet-stream")
|
||||
default:
|
||||
buf, err := json.Marshal(body)
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("unsupported body type: %T", body))
|
||||
}
|
||||
req.Body = io.NopCloser(bytes.NewReader(buf))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
case nil:
|
||||
}
|
||||
|
||||
if options.authenticateAs != "" {
|
||||
resp, err = authenticateFlow(ctx, client, req, options.authenticateAs) //nolint:bodyclose
|
||||
} else {
|
||||
resp, err = client.Do(req) //nolint:bodyclose
|
||||
}
|
||||
// retry on connection refused
|
||||
if err != nil {
|
||||
span.RecordError(err)
|
||||
var opErr *net.OpError
|
||||
if errors.As(err, &opErr) && opErr.Op == "dial" && opErr.Err.Error() == "connect: connection refused" {
|
||||
span.AddEvent("Retrying on dial error")
|
||||
return err
|
||||
}
|
||||
return retry.NewTerminalError(err)
|
||||
}
|
||||
if resp.StatusCode/100 == 5 {
|
||||
resendCount++
|
||||
_, _ = io.ReadAll(resp.Body)
|
||||
_ = resp.Body.Close()
|
||||
span.SetAttributes(semconv.HTTPRequestResendCount(resendCount))
|
||||
span.AddEvent("Retrying on 5xx error", oteltrace.WithAttributes(
|
||||
attribute.String("status", resp.Status),
|
||||
))
|
||||
return errors.New(http.StatusText(resp.StatusCode))
|
||||
}
|
||||
span.SetStatus(codes.Ok, "request completed successfully")
|
||||
return nil
|
||||
},
|
||||
retry.WithInitialInterval(1*time.Millisecond),
|
||||
retry.WithMaxInterval(100*time.Millisecond),
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return resp, nil
|
||||
return doAuthenticatedRequest(options.requestCtx,
|
||||
func(ctx context.Context) (*http.Request, error) {
|
||||
return http.NewRequestWithContext(ctx, method, u.String(), nil)
|
||||
},
|
||||
func(context.Context) *http.Client {
|
||||
return h.getRouteClient(r, &options)
|
||||
},
|
||||
&options,
|
||||
)
|
||||
}
|
||||
|
||||
func authenticateFlow(ctx context.Context, client *http.Client, req *http.Request, email string) (*http.Response, error) {
|
||||
span := oteltrace.SpanFromContext(ctx)
|
||||
var res *http.Response
|
||||
originalHostname := req.URL.Hostname()
|
||||
res, err := client.Do(req)
|
||||
func (h *httpUpstream) DialWS(r testenv.Route, f func(conn *websocket.Conn) error, opts ...RequestOption) error {
|
||||
options := RequestOptions{
|
||||
requestCtx: h.Env().Context(),
|
||||
}
|
||||
options.apply(opts...)
|
||||
u, err := url.Parse(r.URL().Value())
|
||||
if err != nil {
|
||||
span.RecordError(err)
|
||||
return nil, err
|
||||
return err
|
||||
}
|
||||
location := res.Request.URL
|
||||
if location.Hostname() == originalHostname {
|
||||
// already authenticated
|
||||
span.SetStatus(codes.Ok, "already authenticated")
|
||||
return res, nil
|
||||
u.Scheme = "wss"
|
||||
if options.path != "" || options.query != nil {
|
||||
u = u.ResolveReference(&url.URL{
|
||||
Path: options.path,
|
||||
RawQuery: options.query.Encode(),
|
||||
})
|
||||
}
|
||||
fs := forms.Parse(res.Body)
|
||||
_, _ = io.ReadAll(res.Body)
|
||||
_ = res.Body.Close()
|
||||
if len(fs) > 0 {
|
||||
f := fs[0]
|
||||
f.Inputs["email"] = email
|
||||
f.Inputs["token_expiration"] = strconv.Itoa(int((time.Hour * 24).Seconds()))
|
||||
span.AddEvent("submitting form", oteltrace.WithAttributes(attribute.String("location", location.String())))
|
||||
formReq, err := f.NewRequestWithContext(ctx, location)
|
||||
if err != nil {
|
||||
span.RecordError(err)
|
||||
return nil, err
|
||||
}
|
||||
resp, err := client.Do(formReq)
|
||||
if err != nil {
|
||||
span.RecordError(err)
|
||||
return nil, err
|
||||
}
|
||||
span.SetStatus(codes.Ok, "form submitted successfully")
|
||||
return resp, nil
|
||||
ctx, span := h.clientTracer.Value().Start(options.requestCtx, "httpUpstream.Dial", oteltrace.WithAttributes(
|
||||
attribute.String("url", u.String()),
|
||||
))
|
||||
options.requestCtx = ctx
|
||||
defer span.End()
|
||||
|
||||
client := h.getRouteClient(r, &options)
|
||||
d := &websocket.Dialer{
|
||||
HandshakeTimeout: 10 * time.Second,
|
||||
TLSClientConfig: client.Transport.(*Transport).Base.TLSClientConfig,
|
||||
Jar: client.Jar,
|
||||
}
|
||||
return nil, fmt.Errorf("test bug: expected IDP login form")
|
||||
if options.dialerHook != nil {
|
||||
d, u = options.dialerHook(d, u)
|
||||
}
|
||||
conn, resp, err := d.DialContext(options.requestCtx, u.String(), nil)
|
||||
if err != nil {
|
||||
resp.Body.Close()
|
||||
return fmt.Errorf("DialContext: %w", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
return f(conn)
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue