diff --git a/config/envoyconfig/protocols_int_test.go b/config/envoyconfig/protocols_int_test.go index ef1e82785..8c332b0d9 100644 --- a/config/envoyconfig/protocols_int_test.go +++ b/config/envoyconfig/protocols_int_test.go @@ -1,11 +1,17 @@ package envoyconfig_test import ( + "context" + "errors" "fmt" "io" + "net" "net/http" + "sync/atomic" "testing" + "time" + "github.com/gorilla/websocket" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "google.golang.org/grpc/credentials/insecure" @@ -28,9 +34,9 @@ func TestH2C(t *testing.T) { http := up.Route(). From(env.SubdomainURL("grpc-http")). - To(values.Bind(up.Port(), func(port int) string { + To(values.Bind(up.Addr(), func(addr string) string { // override the target protocol to use http:// - return fmt.Sprintf("http://127.0.0.1:%d", port) + return fmt.Sprintf("http://%s", addr) })). Policy(func(p *config.Policy) { p.AllowPublicUnauthenticatedAccess = true }) @@ -118,6 +124,234 @@ func TestHTTP(t *testing.T) { }) } +func TestTCPTunnel(t *testing.T) { + env := testenv.New(t, testenv.Debug()) + + env.Add(scenarios.NewIDP([]*scenarios.User{{Email: "test@example.com"}})) + up := upstreams.TCP() + routeH1 := up.Route(). + From(env.SubdomainURL("h1")). + PPL(`{"allow":{"and":["email":{"is":"test@example.com"}]}}`) + routeH2 := up.Route(). + From(env.SubdomainURL("h2")). + Policy(func(p *config.Policy) { + p.AllowWebsockets = true + }). + PPL(`{"allow":{"and":["email":{"is":"test@example.com"}]}}`) + + up.Handle(func(_ context.Context, c net.Conn) error { + c.SetReadDeadline(time.Now().Add(1 * time.Second)) + buf := make([]byte, 8) + n, err := c.Read(buf) + require.NoError(t, err) + assert.Equal(t, string(buf[:n]), "hello") + c.SetWriteDeadline(time.Now().Add(1 * time.Second)) + _, err = c.Write([]byte("world")) + require.NoError(t, err) + + return nil + }) + + env.AddUpstream(up) + env.Start() + snippets.WaitStartupComplete(env) + + t.Run("http1", func(t *testing.T) { + assert.NoError(t, up.Dial(routeH1, func(_ context.Context, c net.Conn) error { + c.SetWriteDeadline(time.Now().Add(1 * time.Second)) + _, err := c.Write([]byte("hello")) + require.NoError(t, err) + + buf := make([]byte, 8) + c.SetReadDeadline(time.Now().Add(1 * time.Second)) + n, err := c.Read(buf) + require.NoError(t, err) + + assert.Equal(t, string(buf[:n]), "world") + return nil + }, upstreams.AuthenticateAs("test@example.com"), upstreams.DialProtocol(upstreams.DialHTTP1))) + }) + + t.Run("http2", func(t *testing.T) { + assert.NoError(t, up.Dial(routeH2, func(_ context.Context, c net.Conn) error { + c.SetWriteDeadline(time.Now().Add(1 * time.Second)) + _, err := c.Write([]byte("hello")) + require.NoError(t, err) + + buf := make([]byte, 8) + c.SetReadDeadline(time.Now().Add(1 * time.Second)) + n, err := c.Read(buf) + require.NoError(t, err) + + assert.Equal(t, string(buf[:n]), "world") + return nil + }, upstreams.AuthenticateAs("test@example.com"), upstreams.DialProtocol(upstreams.DialHTTP2))) + }) +} + +func BenchmarkHTTP1TCPTunnel(b *testing.B) { + env := testenv.New(b, testenv.Silent()) + env.Add(scenarios.NewIDP([]*scenarios.User{{Email: "test@example.com"}})) + up := upstreams.TCP() + h1 := up.Route(). + From(env.SubdomainURL("bench-h1")). + PPL(`{"allow":{"and":["email":{"is":"test@example.com"}]}}`) + + env.AddUpstream(up) + env.Start() + snippets.WaitStartupComplete(env) + + b.Run("http1", func(b *testing.B) { + benchmarkTCP(b, up, h1, tcpBenchmarkParams{ + msgLen: 512, + protocol: upstreams.DialHTTP1, + }) + }) +} + +func BenchmarkHTTP2TCPTunnel(b *testing.B) { + env := testenv.New(b, testenv.Silent()) + env.Add(scenarios.NewIDP([]*scenarios.User{{Email: "test@example.com"}})) + up := upstreams.TCP() + + h2 := up.Route(). + From(env.SubdomainURL("bench-h2")). + Policy(func(p *config.Policy) { + p.AllowWebsockets = true + }). + PPL(`{"allow":{"and":["email":{"is":"test@example.com"}]}}`) + + env.AddUpstream(up) + env.Start() + snippets.WaitStartupComplete(env) + + b.Run("http2", func(b *testing.B) { + benchmarkTCP(b, up, h2, tcpBenchmarkParams{ + msgLen: 512, + protocol: upstreams.DialHTTP2, + }) + }) +} + +type tcpBenchmarkParams struct { + msgLen int + protocol upstreams.Protocol +} + +func benchmarkTCP(b *testing.B, up upstreams.TCPUpstream, route testenv.Route, params tcpBenchmarkParams) { + sendMsg := func(c net.Conn, buf []byte) error { + c.SetWriteDeadline(time.Now().Add(1 * time.Second)) + _, err := c.Write(buf) + if err != nil { + if errors.Is(err, net.ErrClosed) { + return nil + } + } + return err + } + recvMsg := func(c net.Conn, buf []byte) error { + c.SetReadDeadline(time.Now().Add(1 * time.Second)) + for read := 0; read != len(buf); { + n, err := c.Read(buf) + read += n + if err != nil { + if errors.Is(err, net.ErrClosed) { + return nil + } + return err + } + } + return nil + } + up.Handle(func(_ context.Context, c net.Conn) error { + for { + buf := make([]byte, params.msgLen) + if err := recvMsg(c, buf[:]); err != nil { + if errors.Is(err, net.ErrClosed) { + return nil + } + return err + } + if err := sendMsg(c, buf[:]); err != nil { + if errors.Is(err, net.ErrClosed) { + return nil + } + return err + } + } + }) + var threads atomic.Int32 + var requests atomic.Int32 + var bytes atomic.Int64 + start := time.Now() + b.RunParallel(func(p *testing.PB) { + threads.Add(1) + require.NoError(b, up.Dial(route, func(_ context.Context, c net.Conn) error { + buf := make([]byte, params.msgLen) + for p.Next() { + requests.Add(1) + bytes.Add(int64(params.msgLen)) + require.NoError(b, sendMsg(c, buf[:])) + require.NoError(b, recvMsg(c, buf[:])) + } + return nil + }, upstreams.AuthenticateAs("test@example.com"), upstreams.DialProtocol(params.protocol))) + }) + duration := time.Since(start) + b.Logf("sent %d requests over %d parallel connections in %s", requests.Load(), threads.Load(), duration) + b.Logf("throughput: %f bytes/s", float64(bytes.Load())/duration.Seconds()) +} + +func TestHttp1Websocket(t *testing.T) { + env := testenv.New(t) + + up := upstreams.HTTP(nil) + up.HandleWS("/ws", websocket.Upgrader{}, func(conn *websocket.Conn) error { + for { + mt, message, err := conn.ReadMessage() + if err != nil { + return err + } + + // echo the message back + err = conn.WriteMessage(mt, message) + if err != nil { + return err + } + } + }) + + route := up.Route(). + From(env.SubdomainURL("ws-test")). + Policy(func(p *config.Policy) { + p.AllowPublicUnauthenticatedAccess = true + p.AllowWebsockets = true + }) + + env.AddUpstream(up) + env.Start() + snippets.WaitStartupComplete(env) + + assert.NoError(t, up.DialWS(route, func(conn *websocket.Conn) error { + if err := conn.SetWriteDeadline(time.Now().Add(1 * time.Second)); err != nil { + return err + } + if err := conn.WriteMessage(websocket.TextMessage, []byte("hello world")); err != nil { + return err + } + if err := conn.SetReadDeadline(time.Now().Add(1 * time.Second)); err != nil { + return err + } + mt, bytes, err := conn.ReadMessage() + if err := err; err != nil { + return err + } + assert.Equal(t, websocket.TextMessage, mt) + assert.Equal(t, "hello world", string(bytes)) + return nil + }, upstreams.Path("/ws"))) +} + func TestClientCert(t *testing.T) { env := testenv.New(t) env.Add(scenarios.DownstreamMTLS(config.MTLSEnforcementRejectConnection)) diff --git a/internal/testenv/environment.go b/internal/testenv/environment.go index c42e6c471..8b16f99ea 100644 --- a/internal/testenv/environment.go +++ b/internal/testenv/environment.go @@ -23,6 +23,7 @@ import ( "os/signal" "path" "path/filepath" + "reflect" "runtime" "strconv" "strings" @@ -34,11 +35,13 @@ import ( "github.com/pomerium/pomerium/config" "github.com/pomerium/pomerium/config/envoyconfig/filemgr" + "github.com/pomerium/pomerium/config/otelconfig" databroker_service "github.com/pomerium/pomerium/databroker" "github.com/pomerium/pomerium/internal/log" "github.com/pomerium/pomerium/internal/telemetry/trace" "github.com/pomerium/pomerium/internal/testenv/envutil" "github.com/pomerium/pomerium/internal/testenv/values" + "github.com/pomerium/pomerium/internal/version" "github.com/pomerium/pomerium/pkg/cmd/pomerium" "github.com/pomerium/pomerium/pkg/envoy" "github.com/pomerium/pomerium/pkg/grpc/databroker" @@ -97,6 +100,7 @@ type Environment interface { AuthenticateURL() values.Value[string] DatabrokerURL() values.Value[string] Ports() Ports + Host() string SharedSecret() []byte CookieSecret() []byte @@ -244,6 +248,8 @@ type EnvironmentOptions struct { forceSilent bool traceDebugFlags trace.DebugFlags traceClient otlptrace.Client + traceConfig *otelconfig.Config + host string } type EnvironmentOption func(*EnvironmentOptions) @@ -300,15 +306,23 @@ func WithTraceClient(traceClient otlptrace.Client) EnvironmentOption { } } +func WithTraceConfig(traceConfig *otelconfig.Config) EnvironmentOption { + return func(o *EnvironmentOptions) { + o.traceConfig = traceConfig + } +} + var setGrpcLoggerOnce sync.Once const defaultTraceDebugFlags = trace.TrackSpanCallers | trace.TrackSpanReferences var ( - flagDebug = flag.Bool("env.debug", false, "enables test environment debug logging (equivalent to Debug() option)") - flagPauseOnFailure = flag.Bool("env.pause-on-failure", false, "enables pausing the test environment on failure (equivalent to PauseOnFailure() option)") - flagSilent = flag.Bool("env.silent", false, "suppresses all test environment output (equivalent to Silent() option)") - flagTraceDebugFlags = flag.String("env.trace-debug-flags", strconv.Itoa(defaultTraceDebugFlags), "trace debug flags (equivalent to TraceDebugFlags() option)") + flagDebug = flag.Bool("env.debug", false, "enables test environment debug logging (equivalent to Debug() option)") + flagPauseOnFailure = flag.Bool("env.pause-on-failure", false, "enables pausing the test environment on failure (equivalent to PauseOnFailure() option)") + flagSilent = flag.Bool("env.silent", false, "suppresses all test environment output (equivalent to Silent() option)") + flagTraceDebugFlags = flag.String("env.trace-debug-flags", strconv.Itoa(defaultTraceDebugFlags), "trace debug flags (equivalent to TraceDebugFlags() option)") + flagBindAddress = flag.String("env.bind-address", "127.0.0.1", "bind address for local services") + flagTraceEnvironConfig = flag.Bool("env.use-trace-environ", false, "if true, will configure a trace client from environment variables if no trace client has been set") ) func New(t testing.TB, opts ...EnvironmentOption) Environment { @@ -323,6 +337,7 @@ func New(t testing.TB, opts ...EnvironmentOption) Environment { forceSilent: *flagSilent, traceDebugFlags: trace.DebugFlags(defaultTraceDebugFlags), traceClient: trace.NoopClient{}, + host: *flagBindAddress, } options.apply(opts...) if testing.Short() { @@ -332,6 +347,17 @@ func New(t testing.TB, opts ...EnvironmentOption) Environment { if addTraceDebugFlags { options.traceDebugFlags |= trace.DebugFlags(defaultTraceDebugFlags) } + if *flagTraceEnvironConfig && options.traceConfig == nil && + (reflect.TypeOf(options.traceClient) == reflect.TypeFor[trace.NoopClient]()) { + cfg := newOtelConfigFromEnv(t) + options.traceConfig = &cfg + client, err := trace.NewTraceClientFromConfig(cfg) + if err != nil { + t.Fatal(err) + } + t.Log("tracing configured from environment") + options.traceClient = client + } trace.UseGlobalPanicTracer() databroker.DebugUseFasterBackoff.Store(true) workspaceFolder, err := os.Getwd() @@ -495,7 +521,7 @@ func (e *environment) AuthenticateURL() values.Value[string] { func (e *environment) DatabrokerURL() values.Value[string] { return values.Bind(e.ports.Outbound, func(port int) string { - return fmt.Sprintf("127.0.0.1:%d", port) + return fmt.Sprintf("%s:%d", e.host, port) }) } @@ -503,6 +529,13 @@ func (e *environment) Ports() Ports { return e.ports } +func (e *environment) Host() string { + if e.host == "" { + return "127.0.0.1" + } + return e.host +} + func (e *environment) CACert() *tls.Certificate { caCert, err := tls.LoadX509KeyPair( filepath.Join(e.tempDir, "certs", "ca.pem"), @@ -571,9 +604,9 @@ func (e *environment) Start() { cfg.Options.Services = "all" cfg.Options.LogLevel = config.LogLevelDebug cfg.Options.ProxyLogLevel = config.LogLevelInfo - cfg.Options.Addr = fmt.Sprintf("127.0.0.1:%d", e.ports.ProxyHTTP.Value()) - cfg.Options.GRPCAddr = fmt.Sprintf("127.0.0.1:%d", e.ports.ProxyGRPC.Value()) - cfg.Options.MetricsAddr = fmt.Sprintf("127.0.0.1:%d", e.ports.ProxyMetrics.Value()) + cfg.Options.Addr = fmt.Sprintf("%s:%d", e.host, e.ports.ProxyHTTP.Value()) + cfg.Options.GRPCAddr = fmt.Sprintf("%s:%d", e.host, e.ports.ProxyGRPC.Value()) + cfg.Options.MetricsAddr = fmt.Sprintf("%s:%d", e.host, e.ports.ProxyMetrics.Value()) cfg.Options.CAFile = filepath.Join(e.tempDir, "certs", "ca.pem") cfg.Options.CertFile = filepath.Join(e.tempDir, "certs", "trusted.pem") cfg.Options.KeyFile = filepath.Join(e.tempDir, "certs", "trusted-key.pem") @@ -598,6 +631,9 @@ func (e *environment) Start() { log.AccessLogFieldUserAgent, log.AccessLogFieldClientCertificate, } + if e.traceConfig != nil { + cfg.Options.Tracing = *e.traceConfig + } e.src = &configSource{cfg: cfg} e.AddTask(TaskFunc(func(ctx context.Context) error { @@ -799,6 +835,7 @@ func (e *environment) Pause() { c := make(chan os.Signal, 1) signal.Notify(c, syscall.SIGINT) <-c + signal.Stop(c) e.t.Log("\x1b[31mctrl+c received, continuing\x1b[0m") } @@ -816,6 +853,7 @@ func (e *environment) cleanup(cancelCause error) { c := make(chan os.Signal, 1) signal.Notify(c, syscall.SIGINT) <-c + signal.Stop(c) e.t.Log("\x1b[31mctrl+c received, continuing\x1b[0m") signal.Stop(c) } @@ -1043,3 +1081,13 @@ func (src *configSource) ModifyConfig(ctx context.Context, m Modifier) { li(ctx, src.cfg) } } + +func newOtelConfigFromEnv(t testing.TB) otelconfig.Config { + f, err := os.CreateTemp("", "tmp-config-*.yaml") + require.NoError(t, err) + defer os.Remove(f.Name()) + f.Close() + cfg, err := config.NewFileOrEnvironmentSource(context.Background(), f.Name(), version.FullVersion()) + require.NoError(t, err) + return cfg.GetConfig().Options.Tracing +} diff --git a/internal/testenv/route.go b/internal/testenv/route.go index 8b002f8bb..0eaea287c 100644 --- a/internal/testenv/route.go +++ b/internal/testenv/route.go @@ -1,6 +1,7 @@ package testenv import ( + "fmt" "net/url" "strings" @@ -74,3 +75,19 @@ func (b *PolicyRoute) PPL(ppl string) Route { func (b *PolicyRoute) URL() values.Value[string] { return b.from } + +type TCPRoute struct { + PolicyRoute +} + +func (b *TCPRoute) From(fromURL values.Value[string]) Route { + b.from = values.Bind(fromURL, func(urlStr string) string { + from, _ := url.Parse(urlStr) + from.Scheme = "tcp+https" + from.Host = fmt.Sprintf("%s:%s", from.Hostname(), from.Port()) + return from.String() + }) + return b +} + +var _ Route = (*TCPRoute)(nil) diff --git a/internal/testenv/scenarios/mock_idp.go b/internal/testenv/scenarios/mock_idp.go index 85babd674..9e07b70e1 100644 --- a/internal/testenv/scenarios/mock_idp.go +++ b/internal/testenv/scenarios/mock_idp.go @@ -6,6 +6,8 @@ import ( "crypto/ecdsa" "crypto/elliptic" "crypto/rand" + "crypto/tls" + "crypto/x509" "encoding/base64" "encoding/hex" "encoding/json" @@ -32,6 +34,7 @@ import ( ) type IDP struct { + IDPOptions id values.Value[string] url values.Value[string] publicJWK jose.JSONWebKey @@ -41,18 +44,56 @@ type IDP struct { userLookup map[string]*User } +type IDPOptions struct { + enableTLS bool +} + +type IDPOption func(*IDPOptions) + +func (o *IDPOptions) apply(opts ...IDPOption) { + for _, op := range opts { + op(o) + } +} + +func WithEnableTLS(enableTLS bool) IDPOption { + return func(o *IDPOptions) { + o.enableTLS = enableTLS + } +} + // Attach implements testenv.Modifier. func (idp *IDP) Attach(ctx context.Context) { env := testenv.EnvFromContext(ctx) - router := upstreams.HTTP(nil, upstreams.WithDisplayName("IDP")) + idpURL := env.SubdomainURL("mock-idp") - idp.url = values.Bind2(env.SubdomainURL("mock-idp"), router.Port(), func(urlStr string, port int) string { + var tlsConfig values.Value[*tls.Config] + if idp.enableTLS { + tlsConfig = values.Bind(idpURL, func(urlStr string) *tls.Config { + u, _ := url.Parse(urlStr) + cert := env.NewServerCert(&x509.Certificate{ + DNSNames: []string{u.Hostname()}, + }) + return &tls.Config{ + RootCAs: env.ServerCAs(), + Certificates: []tls.Certificate{tls.Certificate(*cert)}, + NextProtos: []string{"http/1.1", "h2"}, + } + }) + } + + router := upstreams.HTTP(tlsConfig, upstreams.WithDisplayName("IDP")) + + idp.url = values.Bind2(idpURL, router.Addr(), func(urlStr string, addr string) string { u, _ := url.Parse(urlStr) host, _, _ := net.SplitHostPort(u.Host) + _, port, err := net.SplitHostPort(addr) + if err != nil { + panic("bug: " + err.Error()) + } return u.ResolveReference(&url.URL{ - Scheme: "http", - Host: fmt.Sprintf("%s:%d", host, port), + Host: fmt.Sprintf("%s:%s", host, port), }).String() }) var err error @@ -108,7 +149,12 @@ func (idp *IDP) Modify(cfg *config.Config) { var _ testenv.Modifier = (*IDP)(nil) -func NewIDP(users []*User) *IDP { +func NewIDP(users []*User, opts ...IDPOption) *IDP { + options := IDPOptions{ + enableTLS: true, + } + options.apply(opts...) + privateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) if err != nil { panic(err) @@ -136,6 +182,7 @@ func NewIDP(users []*User) *IDP { userLookup[user.ID] = user } return &IDP{ + IDPOptions: options, publicJWK: publicJWK, signingKey: signingKey, userLookup: userLookup, diff --git a/internal/testenv/scenarios/trace_receiver.go b/internal/testenv/scenarios/trace_receiver.go index 96b0c1b8b..f5d761fbb 100644 --- a/internal/testenv/scenarios/trace_receiver.go +++ b/internal/testenv/scenarios/trace_receiver.go @@ -174,16 +174,16 @@ func (rec *OTLPTraceReceiver) FlushResourceSpans() []*tracev1.ResourceSpans { // GRPCEndpointURL returns a url suitable for use with the environment variable // $OTEL_EXPORTER_OTLP_TRACES_ENDPOINT or with [otlptracegrpc.WithEndpointURL]. func (rec *OTLPTraceReceiver) GRPCEndpointURL() values.Value[string] { - return values.Chain(rec.grpcUpstream, upstreams.GRPCUpstream.Port, func(port int) string { - return fmt.Sprintf("http://127.0.0.1:%d", port) + return values.Chain(rec.grpcUpstream, upstreams.GRPCUpstream.Addr, func(addr string) string { + return fmt.Sprintf("http://%s", addr) }) } // GRPCEndpointURL returns a url suitable for use with the environment variable // $OTEL_EXPORTER_OTLP_TRACES_ENDPOINT or with [otlptracehttp.WithEndpointURL]. func (rec *OTLPTraceReceiver) HTTPEndpointURL() values.Value[string] { - return values.Chain(rec.httpUpstream, upstreams.HTTPUpstream.Port, func(port int) string { - return fmt.Sprintf("http://127.0.0.1:%d/v1/traces", port) + return values.Chain(rec.httpUpstream, upstreams.HTTPUpstream.Addr, func(addr string) string { + return fmt.Sprintf("http://%s/v1/traces", addr) }) } @@ -200,9 +200,9 @@ func (rec *OTLPTraceReceiver) NewGRPCClient(opts ...otlptracegrpc.Option) otlptr func (rec *OTLPTraceReceiver) NewHTTPClient(opts ...otlptracehttp.Option) otlptrace.Client { return &deferredClient{ - client: values.Chain(rec.httpUpstream, upstreams.HTTPUpstream.Port, func(port int) otlptrace.Client { + client: values.Chain(rec.httpUpstream, upstreams.HTTPUpstream.Addr, func(addr string) otlptrace.Client { return otlptracehttp.NewClient(append(opts, - otlptracehttp.WithEndpointURL(fmt.Sprintf("http://127.0.0.1:%d/v1/traces", port)), + otlptracehttp.WithEndpointURL(fmt.Sprintf("http://%s/v1/traces", addr)), otlptracehttp.WithTimeout(1*time.Minute), )...) }), diff --git a/internal/testenv/types.go b/internal/testenv/types.go index 8e863d3eb..959fd9a33 100644 --- a/internal/testenv/types.go +++ b/internal/testenv/types.go @@ -204,7 +204,8 @@ func (f *taskFunc) Run(ctx context.Context) error { type Upstream interface { Modifier Task - Port() values.Value[int] + + Addr() values.Value[string] Route() RouteStub } diff --git a/internal/testenv/upstreams/grpc.go b/internal/testenv/upstreams/grpc.go index 90773d8c1..c6c203586 100644 --- a/internal/testenv/upstreams/grpc.go +++ b/internal/testenv/upstreams/grpc.go @@ -97,8 +97,10 @@ type service struct { impl any } -func (g *grpcUpstream) Port() values.Value[int] { - return g.serverPort +func (g *grpcUpstream) Addr() values.Value[string] { + return values.Bind(g.serverPort, func(port int) string { + return fmt.Sprintf("%s:%d", g.Env().Host(), port) + }) } // RegisterService implements grpc.ServiceRegistrar. @@ -117,7 +119,7 @@ func (g *grpcUpstream) Route() testenv.RouteStub { protocol = "https" } r.To(values.Bind(g.serverPort, func(port int) string { - return fmt.Sprintf("%s://127.0.0.1:%d", protocol, port) + return fmt.Sprintf("%s://%s:%d", protocol, g.Env().Host(), port) })) g.Add(r) return r @@ -125,7 +127,7 @@ func (g *grpcUpstream) Route() testenv.RouteStub { // Start implements testenv.Upstream. func (g *grpcUpstream) Run(ctx context.Context) error { - listener, err := net.Listen("tcp", "127.0.0.1:0") + listener, err := net.Listen("tcp", fmt.Sprintf("%s:0", g.Env().Host())) if err != nil { return err } @@ -187,7 +189,7 @@ func (g *grpcUpstream) Dial(r testenv.Route, dialOpts ...grpc.DialOption) *grpc. } func (g *grpcUpstream) DirectConnect(dialOpts ...grpc.DialOption) *grpc.ClientConn { - cc, err := grpc.NewClient(fmt.Sprintf("127.0.0.1:%d", g.Port().Value()), + cc, err := grpc.NewClient(g.Addr().Value(), append(g.withDefaultDialOpts(dialOpts), grpc.WithTransportCredentials(insecure.NewCredentials()))...) if err != nil { panic(err) diff --git a/internal/testenv/upstreams/http.go b/internal/testenv/upstreams/http.go index fd8d53ffd..8d97d61e3 100644 --- a/internal/testenv/upstreams/http.go +++ b/internal/testenv/upstreams/http.go @@ -1,10 +1,8 @@ package upstreams import ( - "bytes" "context" "crypto/tls" - "encoding/json" "errors" "fmt" "io" @@ -13,14 +11,12 @@ import ( "net/http/cookiejar" "net/http/httptrace" "net/url" - "strconv" - "strings" + "os" "sync" "time" "github.com/gorilla/mux" - "github.com/pomerium/pomerium/integration/forms" - "github.com/pomerium/pomerium/internal/retry" + "github.com/gorilla/websocket" "github.com/pomerium/pomerium/internal/telemetry/trace" "github.com/pomerium/pomerium/internal/testenv" "github.com/pomerium/pomerium/internal/testenv/snippets" @@ -29,9 +25,15 @@ import ( "go.opentelemetry.io/otel/attribute" "go.opentelemetry.io/otel/codes" - semconv "go.opentelemetry.io/otel/semconv/v1.26.0" oteltrace "go.opentelemetry.io/otel/trace" - "google.golang.org/protobuf/proto" +) + +type Protocol string + +const ( + DialHTTP1 Protocol = "http/1.1" + DialHTTP2 Protocol = "h2" + DialHTTP3 Protocol = "h3" ) type RequestOptions struct { @@ -42,12 +44,18 @@ type RequestOptions struct { authenticateAs string body any clientCerts []tls.Certificate - client *http.Client + clientHook func(*http.Client) *http.Client + dialerHook func(*websocket.Dialer, *url.URL) (*websocket.Dialer, *url.URL) + dialProtocol Protocol trace *httptrace.ClientTrace } type RequestOption func(*RequestOptions) +func (ro RequestOption) Format(fmt.State, rune) { + panic("test bug: request option mistakenly passed to assert function") +} + func (o *RequestOptions) apply(opts ...RequestOption) { for _, op := range opts { op(o) @@ -82,9 +90,38 @@ func AuthenticateAs(email string) RequestOption { } } -func Client(c *http.Client) RequestOption { +// ClientHook allows editing or replacing the http client before it is used. +// When any request is about to start, this function will be called with the +// client that would be used to make the request. The returned client will +// be the actual client used for that request. It can be the same as the input +// (with or without modification), or replaced entirely. +// +// Note: the Transport of the client passed to the hook will always be a +// [*Transport]. That transport's underlying transport will always be +// a [*otelhttp.Transport]. +func ClientHook(f func(*http.Client) *http.Client) RequestOption { return func(o *RequestOptions) { - o.client = c + o.clientHook = f + } +} + +// DialerHook allows editing or replacing the websocket dialer before it is +// used. When a websocket request is about to start (using the DialWS method), +// this function will be called with the dialer that would be used, and the +// destination URL (including wss:// scheme, and path if one is present). The +// returned dialer+URL will be the actual dialer+URL used for that request. +// +// If ClientHook is also set, both will be called. The dialer passed to this +// hook will have its TLSClientConfig and Jar fields set from the client. +func DialerHook(f func(*websocket.Dialer, *url.URL) (*websocket.Dialer, *url.URL)) RequestOption { + return func(o *RequestOptions) { + o.dialerHook = f + } +} + +func DialProtocol(protocol Protocol) RequestOption { + return func(o *RequestOptions) { + o.dialProtocol = protocol } } @@ -143,10 +180,12 @@ type HTTPUpstream interface { testenv.Upstream Handle(path string, f func(http.ResponseWriter, *http.Request)) *mux.Route + HandleWS(path string, upgrader websocket.Upgrader, f func(conn *websocket.Conn) error) *mux.Route Get(r testenv.Route, opts ...RequestOption) (*http.Response, error) Post(r testenv.Route, opts ...RequestOption) (*http.Response, error) Do(method string, r testenv.Route, opts ...RequestOption) (*http.Response, error) + DialWS(r testenv.Route, f func(conn *websocket.Conn) error, opts ...RequestOption) error } type httpUpstream struct { @@ -194,8 +233,10 @@ func HTTP(tlsConfig values.Value[*tls.Config], opts ...HTTPUpstreamOption) HTTPU } // Port implements HTTPUpstream. -func (h *httpUpstream) Port() values.Value[int] { - return h.serverPort +func (h *httpUpstream) Addr() values.Value[string] { + return values.Bind(h.serverPort, func(port int) string { + return fmt.Sprintf("%s:%d", h.Env().Host(), port) + }) } // Router implements HTTPUpstream. @@ -203,12 +244,37 @@ func (h *httpUpstream) Handle(path string, f func(http.ResponseWriter, *http.Req return h.router.HandleFunc(path, f) } +// Router implements HTTPUpstream. +func (h *httpUpstream) HandleWS(path string, upgrader websocket.Upgrader, f func(*websocket.Conn) error) *mux.Route { + return h.router.HandleFunc(path, func(w http.ResponseWriter, r *http.Request) { + ctx, span := trace.Continue(r.Context(), "HandleWS") + defer span.End() + c, err := upgrader.Upgrade(w, r.WithContext(ctx), nil) + if err != nil { + span.SetStatus(codes.Error, err.Error()) + w.WriteHeader(http.StatusBadRequest) + _, _ = w.Write([]byte(err.Error())) + return + } + defer c.Close() + + err = f(c) + if err != nil { + if errors.Is(err, io.EOF) { + return + } + span.SetStatus(codes.Error, err.Error()) + fmt.Fprintf(os.Stderr, "websocket error: %s\n", err.Error()) + } + }) +} + // Route implements HTTPUpstream. func (h *httpUpstream) Route() testenv.RouteStub { r := &testenv.PolicyRoute{} protocol := "http" r.To(values.Bind(h.serverPort, func(port int) string { - return fmt.Sprintf("%s://127.0.0.1:%d", protocol, port) + return fmt.Sprintf("%s://%s:%d", protocol, h.Env().Host(), port) })) h.Add(r) return r @@ -216,15 +282,21 @@ func (h *httpUpstream) Route() testenv.RouteStub { // Run implements HTTPUpstream. func (h *httpUpstream) Run(ctx context.Context) error { - listener, err := net.Listen("tcp", "127.0.0.1:0") - if err != nil { - return err + var listener net.Listener + if h.tlsConfig != nil { + var err error + listener, err = tls.Listen("tcp", fmt.Sprintf("%s:0", h.Env().Host()), h.tlsConfig.Value()) + if err != nil { + return err + } + } else { + var err error + listener, err = net.Listen("tcp", fmt.Sprintf("%s:0", h.Env().Host())) + if err != nil { + return err + } } h.serverPort.Resolve(listener.Addr().(*net.TCPAddr).Port) - var tlsConfig *tls.Config - if h.tlsConfig != nil { - tlsConfig = h.tlsConfig.Value() - } if h.serverTracerProviderOverride != nil { h.serverTracerProvider.Resolve(h.serverTracerProviderOverride) } else { @@ -238,8 +310,7 @@ func (h *httpUpstream) Run(ctx context.Context) error { h.router.Use(trace.NewHTTPMiddleware(otelhttp.WithTracerProvider(h.serverTracerProvider.Value()))) server := &http.Server{ - Handler: h.router, - TLSConfig: tlsConfig, + Handler: h.router, // BaseContext: func(net.Listener) context.Context { // return ctx // }, @@ -277,6 +348,53 @@ func (h *httpUpstream) Post(r testenv.Route, opts ...RequestOption) (*http.Respo return h.Do(http.MethodPost, r, opts...) } +type Transport struct { + *otelhttp.Transport + // The underlying http.Transport instance wrapped by the otelhttp.Transport. + Base *http.Transport +} + +var _ http.RoundTripper = Transport{} + +func (h *httpUpstream) newClient(options *RequestOptions) *http.Client { + transport := http.DefaultTransport.(*http.Transport).Clone() + transport.TLSClientConfig = &tls.Config{ + RootCAs: h.Env().ServerCAs(), + Certificates: options.clientCerts, + } + transport.DialTLSContext = nil + c := http.Client{ + Transport: &Transport{ + Transport: otelhttp.NewTransport(transport, + otelhttp.WithTracerProvider(h.clientTracerProvider.Value()), + otelhttp.WithSpanNameFormatter(func(_ string, r *http.Request) string { + return fmt.Sprintf("Client: %s %s", r.Method, r.URL.Path) + }), + ), + Base: transport, + }, + } + c.Jar, _ = cookiejar.New(&cookiejar.Options{}) + return &c +} + +func (h *httpUpstream) getRouteClient(r testenv.Route, options *RequestOptions) *http.Client { + span := oteltrace.SpanFromContext(options.requestCtx) + var cachedClient any + var ok bool + if cachedClient, ok = h.clientCache.Load(r); !ok { + span.AddEvent("creating new http client") + cachedClient, _ = h.clientCache.LoadOrStore(r, h.newClient(options)) + } else { + span.AddEvent("using cached http client") + } + client := cachedClient.(*http.Client) + if options.clientHook != nil { + client = options.clientHook(client) + } + return client +} + // Do implements HTTPUpstream. func (h *httpUpstream) Do(method string, r testenv.Route, opts ...RequestOption) (*http.Response, error) { options := RequestOptions{ @@ -303,141 +421,54 @@ func (h *httpUpstream) Do(method string, r testenv.Route, opts ...RequestOption) options.requestCtx = ctx defer span.End() - newClient := func() *http.Client { - transport := http.DefaultTransport.(*http.Transport).Clone() - transport.TLSClientConfig = &tls.Config{ - RootCAs: h.Env().ServerCAs(), - Certificates: options.clientCerts, - } - transport.DialTLSContext = nil - c := http.Client{ - Transport: otelhttp.NewTransport(transport, - otelhttp.WithTracerProvider(h.clientTracerProvider.Value()), - otelhttp.WithSpanNameFormatter(func(_ string, r *http.Request) string { - return fmt.Sprintf("Client: %s %s", r.Method, r.URL.Path) - }), - ), - } - c.Jar, _ = cookiejar.New(&cookiejar.Options{}) - return &c - } - var client *http.Client - if options.client != nil { - client = options.client - } else { - var cachedClient any - var ok bool - if cachedClient, ok = h.clientCache.Load(r); !ok { - span.AddEvent("creating new http client") - cachedClient, _ = h.clientCache.LoadOrStore(r, newClient()) - } else { - span.AddEvent("using cached http client") - } - client = cachedClient.(*http.Client) - } - - var resp *http.Response - resendCount := 0 - if err := retry.Retry(ctx, "http", func(ctx context.Context) error { - req, err := http.NewRequestWithContext(ctx, method, u.String(), nil) - if err != nil { - return retry.NewTerminalError(err) - } - switch body := options.body.(type) { - case string: - req.Body = io.NopCloser(strings.NewReader(body)) - case []byte: - req.Body = io.NopCloser(bytes.NewReader(body)) - case io.Reader: - req.Body = io.NopCloser(body) - case proto.Message: - buf, err := proto.Marshal(body) - if err != nil { - return retry.NewTerminalError(err) - } - req.Body = io.NopCloser(bytes.NewReader(buf)) - req.Header.Set("Content-Type", "application/octet-stream") - default: - buf, err := json.Marshal(body) - if err != nil { - panic(fmt.Sprintf("unsupported body type: %T", body)) - } - req.Body = io.NopCloser(bytes.NewReader(buf)) - req.Header.Set("Content-Type", "application/json") - case nil: - } - - if options.authenticateAs != "" { - resp, err = authenticateFlow(ctx, client, req, options.authenticateAs) //nolint:bodyclose - } else { - resp, err = client.Do(req) //nolint:bodyclose - } - // retry on connection refused - if err != nil { - span.RecordError(err) - var opErr *net.OpError - if errors.As(err, &opErr) && opErr.Op == "dial" && opErr.Err.Error() == "connect: connection refused" { - span.AddEvent("Retrying on dial error") - return err - } - return retry.NewTerminalError(err) - } - if resp.StatusCode/100 == 5 { - resendCount++ - _, _ = io.ReadAll(resp.Body) - _ = resp.Body.Close() - span.SetAttributes(semconv.HTTPRequestResendCount(resendCount)) - span.AddEvent("Retrying on 5xx error", oteltrace.WithAttributes( - attribute.String("status", resp.Status), - )) - return errors.New(http.StatusText(resp.StatusCode)) - } - span.SetStatus(codes.Ok, "request completed successfully") - return nil - }, - retry.WithInitialInterval(1*time.Millisecond), - retry.WithMaxInterval(100*time.Millisecond), - ); err != nil { - return nil, err - } - return resp, nil + return doAuthenticatedRequest(options.requestCtx, + func(ctx context.Context) (*http.Request, error) { + return http.NewRequestWithContext(ctx, method, u.String(), nil) + }, + func(context.Context) *http.Client { + return h.getRouteClient(r, &options) + }, + &options, + ) } -func authenticateFlow(ctx context.Context, client *http.Client, req *http.Request, email string) (*http.Response, error) { - span := oteltrace.SpanFromContext(ctx) - var res *http.Response - originalHostname := req.URL.Hostname() - res, err := client.Do(req) +func (h *httpUpstream) DialWS(r testenv.Route, f func(conn *websocket.Conn) error, opts ...RequestOption) error { + options := RequestOptions{ + requestCtx: h.Env().Context(), + } + options.apply(opts...) + u, err := url.Parse(r.URL().Value()) if err != nil { - span.RecordError(err) - return nil, err + return err } - location := res.Request.URL - if location.Hostname() == originalHostname { - // already authenticated - span.SetStatus(codes.Ok, "already authenticated") - return res, nil + u.Scheme = "wss" + if options.path != "" || options.query != nil { + u = u.ResolveReference(&url.URL{ + Path: options.path, + RawQuery: options.query.Encode(), + }) } - fs := forms.Parse(res.Body) - _, _ = io.ReadAll(res.Body) - _ = res.Body.Close() - if len(fs) > 0 { - f := fs[0] - f.Inputs["email"] = email - f.Inputs["token_expiration"] = strconv.Itoa(int((time.Hour * 24).Seconds())) - span.AddEvent("submitting form", oteltrace.WithAttributes(attribute.String("location", location.String()))) - formReq, err := f.NewRequestWithContext(ctx, location) - if err != nil { - span.RecordError(err) - return nil, err - } - resp, err := client.Do(formReq) - if err != nil { - span.RecordError(err) - return nil, err - } - span.SetStatus(codes.Ok, "form submitted successfully") - return resp, nil + ctx, span := h.clientTracer.Value().Start(options.requestCtx, "httpUpstream.Dial", oteltrace.WithAttributes( + attribute.String("url", u.String()), + )) + options.requestCtx = ctx + defer span.End() + + client := h.getRouteClient(r, &options) + d := &websocket.Dialer{ + HandshakeTimeout: 10 * time.Second, + TLSClientConfig: client.Transport.(*Transport).Base.TLSClientConfig, + Jar: client.Jar, } - return nil, fmt.Errorf("test bug: expected IDP login form") + if options.dialerHook != nil { + d, u = options.dialerHook(d, u) + } + conn, resp, err := d.DialContext(options.requestCtx, u.String(), nil) + if err != nil { + resp.Body.Close() + return fmt.Errorf("DialContext: %w", err) + } + defer conn.Close() + + return f(conn) } diff --git a/internal/testenv/upstreams/options.go b/internal/testenv/upstreams/options.go index c0bad32b5..3d58548bf 100644 --- a/internal/testenv/upstreams/options.go +++ b/internal/testenv/upstreams/options.go @@ -15,6 +15,7 @@ type CommonUpstreamOptions struct { type CommonUpstreamOption interface { GRPCUpstreamOption HTTPUpstreamOption + TCPUpstreamOption } type commonUpstreamOption func(o *CommonUpstreamOptions) @@ -25,6 +26,9 @@ func (c commonUpstreamOption) applyGRPC(o *GRPCUpstreamOptions) { c(&o.CommonUps // applyHTTP implements CommonUpstreamOption. func (c commonUpstreamOption) applyHTTP(o *HTTPUpstreamOptions) { c(&o.CommonUpstreamOptions) } +// applyTCP implements CommonUpstreamOption. +func (c commonUpstreamOption) applyTCP(o *TCPUpstreamOptions) { c(&o.CommonUpstreamOptions) } + func WithDisplayName(displayName string) CommonUpstreamOption { return commonUpstreamOption(func(o *CommonUpstreamOptions) { o.displayName = displayName diff --git a/internal/testenv/upstreams/tcp.go b/internal/testenv/upstreams/tcp.go new file mode 100644 index 000000000..6eab5b629 --- /dev/null +++ b/internal/testenv/upstreams/tcp.go @@ -0,0 +1,343 @@ +package upstreams + +import ( + "context" + "crypto/tls" + "errors" + "fmt" + "io" + "net" + "net/http" + "net/http/cookiejar" + "net/http/httptrace" + "net/url" + "sync" + + "github.com/pomerium/pomerium/internal/telemetry/trace" + "github.com/pomerium/pomerium/internal/testenv" + "github.com/pomerium/pomerium/internal/testenv/values" + "go.opentelemetry.io/otel/attribute" + oteltrace "go.opentelemetry.io/otel/trace" + "golang.org/x/net/http2" +) + +type TCPUpstream interface { + testenv.Upstream + + Handle(fn func(context.Context, net.Conn) error) + + Dial(r testenv.Route, fn func(context.Context, net.Conn) error, opts ...RequestOption) error +} + +type TCPUpstreamOptions struct { + CommonUpstreamOptions +} + +type TCPUpstreamOption interface { + applyTCP(*TCPUpstreamOptions) +} + +type tcpUpstream struct { + TCPUpstreamOptions + testenv.Aggregate + serverPort values.MutableValue[int] + serverHandler func(context.Context, net.Conn) error + + serverTracerProvider values.MutableValue[oteltrace.TracerProvider] + clientTracerProvider values.MutableValue[oteltrace.TracerProvider] + clientTracer values.Value[oteltrace.Tracer] +} + +func TCP(opts ...TCPUpstreamOption) TCPUpstream { + options := TCPUpstreamOptions{ + CommonUpstreamOptions: CommonUpstreamOptions{ + displayName: "TCP Upstream", + }, + } + for _, op := range opts { + op.applyTCP(&options) + } + up := &tcpUpstream{ + TCPUpstreamOptions: options, + serverPort: values.Deferred[int](), + + serverTracerProvider: values.Deferred[oteltrace.TracerProvider](), + clientTracerProvider: values.Deferred[oteltrace.TracerProvider](), + } + up.clientTracer = values.Bind(up.clientTracerProvider, func(tp oteltrace.TracerProvider) oteltrace.Tracer { + return tp.Tracer(trace.PomeriumCoreTracer) + }) + up.RecordCaller() + return up +} + +// Dial implements TCPUpstream. +func (t *tcpUpstream) Dial(r testenv.Route, clientHandler func(context.Context, net.Conn) error, opts ...RequestOption) error { + options := RequestOptions{ + requestCtx: t.Env().Context(), + dialProtocol: DialHTTP1, + } + options.apply(opts...) + u, err := url.Parse(r.URL().Value()) + if err != nil { + return err + } + + ctx, span := t.clientTracer.Value().Start(options.requestCtx, "tcpUpstream.Do", oteltrace.WithAttributes( + attribute.String("protocol", string(options.dialProtocol)), + attribute.String("url", u.String()), + )) + if options.path != "" || options.query != nil { + u = u.ResolveReference(&url.URL{ + Path: options.path, + RawQuery: options.query.Encode(), + }) + } + if options.trace != nil { + ctx = httptrace.WithClientTrace(ctx, options.trace) + } + options.requestCtx = ctx + defer span.End() + + var remoteConn *tls.Conn + remoteWriter := make(chan *io.PipeWriter, 1) + + connectURL := &url.URL{Scheme: "https", Host: u.Host, Path: u.Path} + + var getClientFn func(context.Context) *http.Client + var newRequestFn func(ctx context.Context) (*http.Request, error) + switch options.dialProtocol { + case DialHTTP1: + getClientFn = t.h1Dialer(&options, connectURL, &remoteConn) + newRequestFn = func(ctx context.Context) (*http.Request, error) { + req := (&http.Request{ + Method: http.MethodConnect, + URL: connectURL, + Host: u.Host, + }).WithContext(ctx) + return req, nil + } + case DialHTTP2: + getClientFn = t.h2Dialer(&options, connectURL, &remoteConn, remoteWriter) + newRequestFn = func(ctx context.Context) (*http.Request, error) { + req := (&http.Request{ + Method: http.MethodConnect, + URL: connectURL, + Host: u.Host, + Proto: "HTTP/2", + }).WithContext(ctx) + return req, nil + } + case DialHTTP3: + panic("not implemented") + } + resp, err := doAuthenticatedRequest(options.requestCtx, newRequestFn, getClientFn, &options) + if err != nil { + return err + } + if resp.StatusCode != http.StatusOK { + resp.Body.Close() + return errors.New(resp.Status) + } + if resp.Request.URL.Path == "/oidc/auth" { + if options.authenticateAs == "" { + return errors.New("test bug: unexpected IDP redirect; missing AuthenticateAs option to Dial()") + } + return errors.New("internal test bug: unexpected IDP redirect") + } + + var w io.WriteCloser = remoteConn + if options.dialProtocol == DialHTTP2 { + w = <-remoteWriter + } + + conn := NewRWConn(resp.Body, w) + defer conn.Close() + return clientHandler(resp.Request.Context(), conn) +} + +func (t *tcpUpstream) h1Dialer( + options *RequestOptions, + connectURL *url.URL, + remoteConn **tls.Conn, +) func(context.Context) *http.Client { + jar, _ := cookiejar.New(nil) + return func(context.Context) *http.Client { + tlsConfig := &tls.Config{ + RootCAs: t.Env().ServerCAs(), + Certificates: options.clientCerts, + NextProtos: []string{"http/1.1"}, + } + client := &http.Client{ + Transport: &http.Transport{ + DisableKeepAlives: true, + DialTLSContext: func(ctx context.Context, network, addr string) (net.Conn, error) { + if *remoteConn != nil { + (*remoteConn).Close() + *remoteConn = nil + } + dialer := &tls.Dialer{ + Config: tlsConfig, + } + cc, err := dialer.DialContext(ctx, network, addr) + if err != nil { + return nil, fmt.Errorf("%w: %w", ErrRetry, err) + } + protocol := cc.(*tls.Conn).ConnectionState().NegotiatedProtocol + if protocol != "http/1.1" { + cc.Close() + return nil, fmt.Errorf("error: unexpected TLS protocol: %s", protocol) + } + *remoteConn = cc.(*tls.Conn) + return cc, nil + }, + TLSClientConfig: tlsConfig, // important + }, + CheckRedirect: func(req *http.Request, _ []*http.Request) error { + if req.URL.String() == connectURL.String() && req.Method == http.MethodGet { + req.Method = http.MethodConnect + } + return nil + }, + Jar: jar, + } + return client + } +} + +func (t *tcpUpstream) h2Dialer( + options *RequestOptions, + connectURL *url.URL, + remoteConn **tls.Conn, + writer chan<- *io.PipeWriter, +) func(context.Context) *http.Client { + jar, _ := cookiejar.New(nil) + return func(context.Context) *http.Client { + h1 := &http.Transport{ + ForceAttemptHTTP2: true, + DisableKeepAlives: true, + DialTLSContext: func(ctx context.Context, network, addr string) (net.Conn, error) { + if *remoteConn != nil { + (*remoteConn).Close() + *remoteConn = nil + } + dialer := &tls.Dialer{ + Config: &tls.Config{ + RootCAs: t.Env().ServerCAs(), + Certificates: options.clientCerts, + NextProtos: []string{"h2"}, + }, + } + cc, err := dialer.DialContext(ctx, network, addr) + if err != nil { + return nil, fmt.Errorf("%w: %w", ErrRetry, err) + } + protocol := cc.(*tls.Conn).ConnectionState().NegotiatedProtocol + if protocol != "h2" { + cc.Close() + return nil, fmt.Errorf("error: unexpected TLS protocol: %s", protocol) + } + *remoteConn = cc.(*tls.Conn) + + return cc, nil + }, + TLSClientConfig: &tls.Config{ + RootCAs: t.Env().ServerCAs(), + Certificates: options.clientCerts, + NextProtos: []string{"h2"}, + }, + } + if err := http2.ConfigureTransport(h1); err != nil { + panic(err) + } + client := &http.Client{ + Transport: h1, + CheckRedirect: func(req *http.Request, _ []*http.Request) error { + if req.URL.String() == connectURL.String() && req.Method == http.MethodGet { + pr, pw := io.Pipe() + req.Method = http.MethodConnect + req.Body = pr + req.ContentLength = -1 + writer <- pw + } + return nil + }, + Jar: jar, + } + return client + } +} + +// Handle implements TCPUpstream. +func (t *tcpUpstream) Handle(fn func(context.Context, net.Conn) error) { + t.serverHandler = fn +} + +// Port implements TCPUpstream. +func (t *tcpUpstream) Addr() values.Value[string] { + return values.Bind(t.serverPort, func(port int) string { + return fmt.Sprintf("%s:%d", t.Env().Host(), port) + }) +} + +// Route implements TCPUpstream. +func (t *tcpUpstream) Route() testenv.RouteStub { + r := &testenv.TCPRoute{} + r.To(values.Bind(t.serverPort, func(port int) string { + return fmt.Sprintf("tcp://%s:%d", t.Env().Host(), port) + })) + t.Add(r) + return r +} + +// Run implements TCPUpstream. +func (t *tcpUpstream) Run(ctx context.Context) error { + ctx, cancel := context.WithCancel(ctx) + defer cancel() + listener, err := (&net.ListenConfig{}).Listen(ctx, "tcp", fmt.Sprintf("%s:0", t.Env().Host())) + if err != nil { + return err + } + context.AfterFunc(ctx, func() { + listener.Close() + }) + t.serverPort.Resolve(listener.Addr().(*net.TCPAddr).Port) + if t.serverTracerProviderOverride != nil { + t.serverTracerProvider.Resolve(t.serverTracerProviderOverride) + } else { + t.serverTracerProvider.Resolve(trace.NewTracerProvider(ctx, t.displayName)) + } + if t.clientTracerProviderOverride != nil { + t.clientTracerProvider.Resolve(t.clientTracerProviderOverride) + } else { + t.clientTracerProvider.Resolve(trace.NewTracerProvider(ctx, "TCP Client")) + } + var wg sync.WaitGroup + defer wg.Wait() + + for { + conn, err := listener.Accept() + if err != nil { + if errors.Is(err, net.ErrClosed) { + cancel() + return nil + } + continue + } + wg.Add(1) + go func() { + defer wg.Done() + if err := t.serverHandler(ctx, conn); err != nil { + if errors.Is(err, io.EOF) { + return + } + panic("server handler error: " + err.Error()) + } + }() + } +} + +var ( + _ testenv.Upstream = (*tcpUpstream)(nil) + _ TCPUpstream = (*tcpUpstream)(nil) +) diff --git a/internal/testenv/upstreams/util.go b/internal/testenv/upstreams/util.go new file mode 100644 index 000000000..cc5b3cc2f --- /dev/null +++ b/internal/testenv/upstreams/util.go @@ -0,0 +1,195 @@ +package upstreams + +import ( + "bytes" + "context" + "encoding/json" + "errors" + "fmt" + "io" + "net" + "net/http" + "strconv" + "strings" + "sync" + "time" + + "github.com/pomerium/pomerium/integration/forms" + "github.com/pomerium/pomerium/internal/retry" + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/codes" + semconv "go.opentelemetry.io/otel/semconv/v1.26.0" + oteltrace "go.opentelemetry.io/otel/trace" + "google.golang.org/protobuf/proto" +) + +var ErrRetry = errors.New("error") + +func doAuthenticatedRequest( + ctx context.Context, + newRequest func(context.Context) (*http.Request, error), + getClient func(context.Context) *http.Client, + options *RequestOptions, +) (*http.Response, error) { + var resp *http.Response + resendCount := 0 + client := getClient(ctx) + + if err := retry.Retry(ctx, "http", func(ctx context.Context) error { + req, err := newRequest(ctx) + if err != nil { + return retry.NewTerminalError(err) + } + + switch body := options.body.(type) { + case string: + req.Body = io.NopCloser(strings.NewReader(body)) + case []byte: + req.Body = io.NopCloser(bytes.NewReader(body)) + case io.Reader: + req.Body = io.NopCloser(body) + case proto.Message: + buf, err := proto.Marshal(body) + if err != nil { + return retry.NewTerminalError(err) + } + req.Body = io.NopCloser(bytes.NewReader(buf)) + req.Header.Set("Content-Type", "application/octet-stream") + default: + buf, err := json.Marshal(body) + if err != nil { + panic(fmt.Sprintf("unsupported body type: %T", body)) + } + req.Body = io.NopCloser(bytes.NewReader(buf)) + req.Header.Set("Content-Type", "application/json") + case nil: + } + + if options.headers != nil && req.Header == nil { + req.Header = http.Header{} + } + for k, v := range options.headers { + req.Header.Add(k, v) + } + + if options.authenticateAs != "" { + resp, err = authenticateFlow(ctx, client, req, options.authenticateAs, true) //nolint:bodyclose + } else { + resp, err = client.Do(req) //nolint:bodyclose + } + // retry on connection refused + span := oteltrace.SpanFromContext(ctx) + if err != nil { + span.RecordError(err) + var opErr *net.OpError + if errors.As(err, &opErr) && opErr.Op == "dial" && opErr.Err.Error() == "connect: connection refused" { + span.AddEvent("Retrying on dial error") + return err + } + return retry.NewTerminalError(err) + } + if resp.StatusCode/100 == 5 { + resendCount++ + _, _ = io.ReadAll(resp.Body) + _ = resp.Body.Close() + span.SetAttributes(semconv.HTTPRequestResendCount(resendCount)) + span.AddEvent("Retrying on 5xx error", oteltrace.WithAttributes( + attribute.String("status", resp.Status), + )) + return errors.New(http.StatusText(resp.StatusCode)) + } + span.SetStatus(codes.Ok, "request completed successfully") + return nil + }, + retry.WithInitialInterval(1*time.Millisecond), + retry.WithMaxInterval(100*time.Millisecond), + ); err != nil { + return nil, err + } + return resp, nil +} + +func authenticateFlow(ctx context.Context, client *http.Client, req *http.Request, email string, checkLocation bool) (*http.Response, error) { + span := oteltrace.SpanFromContext(ctx) + var res *http.Response + originalHostname := req.URL.Hostname() + res, err := client.Do(req) + if err != nil { + span.RecordError(err) + return nil, err + } + location := res.Request.URL + if checkLocation && location.Hostname() == originalHostname { + // already authenticated + span.SetStatus(codes.Ok, "already authenticated") + return res, nil + } + fs := forms.Parse(res.Body) + _, _ = io.ReadAll(res.Body) + _ = res.Body.Close() + if len(fs) > 0 { + f := fs[0] + f.Inputs["email"] = email + f.Inputs["token_expiration"] = strconv.Itoa(int((time.Hour * 24).Seconds())) + span.AddEvent("submitting form", oteltrace.WithAttributes(attribute.String("location", location.String()))) + formReq, err := f.NewRequestWithContext(ctx, location) + if err != nil { + span.RecordError(err) + return nil, err + } + resp, err := client.Do(formReq) + if err != nil { + span.RecordError(err) + return nil, err + } + span.SetStatus(codes.Ok, "form submitted successfully") + return resp, nil + } + return nil, fmt.Errorf("test bug: expected IDP login form") +} + +type rwConn struct { + serverReader io.ReadCloser + serverWriter io.WriteCloser + + net.Conn + remote net.Conn + + closeOnce sync.Once + wg *sync.WaitGroup +} + +func NewRWConn(reader io.ReadCloser, writer io.WriteCloser) net.Conn { + rwc := &rwConn{ + serverReader: reader, + serverWriter: writer, + wg: &sync.WaitGroup{}, + } + + rwc.Conn, rwc.remote = net.Pipe() + rwc.wg.Add(2) + go func() { + defer rwc.wg.Done() + _, _ = io.Copy(rwc.remote, rwc.serverReader) + rwc.remote.Close() + }() + go func() { + defer rwc.wg.Done() + _, _ = io.Copy(rwc.serverWriter, rwc.remote) + rwc.serverWriter.Close() + }() + return rwc +} + +func (rwc *rwConn) Close() error { + var err error + rwc.closeOnce.Do(func() { + readerErr := rwc.serverReader.Close() + localErr := rwc.Conn.Close() + rwc.wg.Wait() + err = errors.Join(localErr, readerErr) + }) + return err +} + +var _ net.Conn = (*rwConn)(nil) diff --git a/pkg/storage/postgres/tracing_test.go b/pkg/storage/postgres/tracing_test.go index 3f00eb9f6..7e966ea30 100644 --- a/pkg/storage/postgres/tracing_test.go +++ b/pkg/storage/postgres/tracing_test.go @@ -48,7 +48,7 @@ func TestQueryTracing(t *testing.T) { snippets.WaitStartupComplete(env) resp, err := up.Get(route, upstreams.AuthenticateAs("user@example.com"), upstreams.Path("/foo")) - assert.NoError(t, err) + require.NoError(t, err) io.ReadAll(resp.Body) resp.Body.Close()