mirror of
https://github.com/pomerium/pomerium.git
synced 2025-04-28 18:06:34 +02:00
add tests/benchmarks for http1/http2 tcp tunnels and http1 websockets (#5471)
* add tests/benchmarks for http1/http2 tcp tunnels and http1 websockets testenv: - add new TCP upstream - add websocket functions to HTTP upstream - add https support to mock idp (default on) - add new debug flags -env.bind-address and -env.use-trace-environ to allow changing the default bind address, and enabling otel environment based trace config, respectively * linter pass --------- Co-authored-by: Denis Mishin <dmishin@pomerium.com>
This commit is contained in:
parent
d6b02441b3
commit
08623ef346
12 changed files with 1104 additions and 182 deletions
|
@ -1,11 +1,17 @@
|
|||
package envoyconfig_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"google.golang.org/grpc/credentials/insecure"
|
||||
|
@ -28,9 +34,9 @@ func TestH2C(t *testing.T) {
|
|||
|
||||
http := up.Route().
|
||||
From(env.SubdomainURL("grpc-http")).
|
||||
To(values.Bind(up.Port(), func(port int) string {
|
||||
To(values.Bind(up.Addr(), func(addr string) string {
|
||||
// override the target protocol to use http://
|
||||
return fmt.Sprintf("http://127.0.0.1:%d", port)
|
||||
return fmt.Sprintf("http://%s", addr)
|
||||
})).
|
||||
Policy(func(p *config.Policy) { p.AllowPublicUnauthenticatedAccess = true })
|
||||
|
||||
|
@ -118,6 +124,234 @@ func TestHTTP(t *testing.T) {
|
|||
})
|
||||
}
|
||||
|
||||
func TestTCPTunnel(t *testing.T) {
|
||||
env := testenv.New(t, testenv.Debug())
|
||||
|
||||
env.Add(scenarios.NewIDP([]*scenarios.User{{Email: "test@example.com"}}))
|
||||
up := upstreams.TCP()
|
||||
routeH1 := up.Route().
|
||||
From(env.SubdomainURL("h1")).
|
||||
PPL(`{"allow":{"and":["email":{"is":"test@example.com"}]}}`)
|
||||
routeH2 := up.Route().
|
||||
From(env.SubdomainURL("h2")).
|
||||
Policy(func(p *config.Policy) {
|
||||
p.AllowWebsockets = true
|
||||
}).
|
||||
PPL(`{"allow":{"and":["email":{"is":"test@example.com"}]}}`)
|
||||
|
||||
up.Handle(func(_ context.Context, c net.Conn) error {
|
||||
c.SetReadDeadline(time.Now().Add(1 * time.Second))
|
||||
buf := make([]byte, 8)
|
||||
n, err := c.Read(buf)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, string(buf[:n]), "hello")
|
||||
c.SetWriteDeadline(time.Now().Add(1 * time.Second))
|
||||
_, err = c.Write([]byte("world"))
|
||||
require.NoError(t, err)
|
||||
|
||||
return nil
|
||||
})
|
||||
|
||||
env.AddUpstream(up)
|
||||
env.Start()
|
||||
snippets.WaitStartupComplete(env)
|
||||
|
||||
t.Run("http1", func(t *testing.T) {
|
||||
assert.NoError(t, up.Dial(routeH1, func(_ context.Context, c net.Conn) error {
|
||||
c.SetWriteDeadline(time.Now().Add(1 * time.Second))
|
||||
_, err := c.Write([]byte("hello"))
|
||||
require.NoError(t, err)
|
||||
|
||||
buf := make([]byte, 8)
|
||||
c.SetReadDeadline(time.Now().Add(1 * time.Second))
|
||||
n, err := c.Read(buf)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, string(buf[:n]), "world")
|
||||
return nil
|
||||
}, upstreams.AuthenticateAs("test@example.com"), upstreams.DialProtocol(upstreams.DialHTTP1)))
|
||||
})
|
||||
|
||||
t.Run("http2", func(t *testing.T) {
|
||||
assert.NoError(t, up.Dial(routeH2, func(_ context.Context, c net.Conn) error {
|
||||
c.SetWriteDeadline(time.Now().Add(1 * time.Second))
|
||||
_, err := c.Write([]byte("hello"))
|
||||
require.NoError(t, err)
|
||||
|
||||
buf := make([]byte, 8)
|
||||
c.SetReadDeadline(time.Now().Add(1 * time.Second))
|
||||
n, err := c.Read(buf)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, string(buf[:n]), "world")
|
||||
return nil
|
||||
}, upstreams.AuthenticateAs("test@example.com"), upstreams.DialProtocol(upstreams.DialHTTP2)))
|
||||
})
|
||||
}
|
||||
|
||||
func BenchmarkHTTP1TCPTunnel(b *testing.B) {
|
||||
env := testenv.New(b, testenv.Silent())
|
||||
env.Add(scenarios.NewIDP([]*scenarios.User{{Email: "test@example.com"}}))
|
||||
up := upstreams.TCP()
|
||||
h1 := up.Route().
|
||||
From(env.SubdomainURL("bench-h1")).
|
||||
PPL(`{"allow":{"and":["email":{"is":"test@example.com"}]}}`)
|
||||
|
||||
env.AddUpstream(up)
|
||||
env.Start()
|
||||
snippets.WaitStartupComplete(env)
|
||||
|
||||
b.Run("http1", func(b *testing.B) {
|
||||
benchmarkTCP(b, up, h1, tcpBenchmarkParams{
|
||||
msgLen: 512,
|
||||
protocol: upstreams.DialHTTP1,
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func BenchmarkHTTP2TCPTunnel(b *testing.B) {
|
||||
env := testenv.New(b, testenv.Silent())
|
||||
env.Add(scenarios.NewIDP([]*scenarios.User{{Email: "test@example.com"}}))
|
||||
up := upstreams.TCP()
|
||||
|
||||
h2 := up.Route().
|
||||
From(env.SubdomainURL("bench-h2")).
|
||||
Policy(func(p *config.Policy) {
|
||||
p.AllowWebsockets = true
|
||||
}).
|
||||
PPL(`{"allow":{"and":["email":{"is":"test@example.com"}]}}`)
|
||||
|
||||
env.AddUpstream(up)
|
||||
env.Start()
|
||||
snippets.WaitStartupComplete(env)
|
||||
|
||||
b.Run("http2", func(b *testing.B) {
|
||||
benchmarkTCP(b, up, h2, tcpBenchmarkParams{
|
||||
msgLen: 512,
|
||||
protocol: upstreams.DialHTTP2,
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
type tcpBenchmarkParams struct {
|
||||
msgLen int
|
||||
protocol upstreams.Protocol
|
||||
}
|
||||
|
||||
func benchmarkTCP(b *testing.B, up upstreams.TCPUpstream, route testenv.Route, params tcpBenchmarkParams) {
|
||||
sendMsg := func(c net.Conn, buf []byte) error {
|
||||
c.SetWriteDeadline(time.Now().Add(1 * time.Second))
|
||||
_, err := c.Write(buf)
|
||||
if err != nil {
|
||||
if errors.Is(err, net.ErrClosed) {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
return err
|
||||
}
|
||||
recvMsg := func(c net.Conn, buf []byte) error {
|
||||
c.SetReadDeadline(time.Now().Add(1 * time.Second))
|
||||
for read := 0; read != len(buf); {
|
||||
n, err := c.Read(buf)
|
||||
read += n
|
||||
if err != nil {
|
||||
if errors.Is(err, net.ErrClosed) {
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
up.Handle(func(_ context.Context, c net.Conn) error {
|
||||
for {
|
||||
buf := make([]byte, params.msgLen)
|
||||
if err := recvMsg(c, buf[:]); err != nil {
|
||||
if errors.Is(err, net.ErrClosed) {
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
if err := sendMsg(c, buf[:]); err != nil {
|
||||
if errors.Is(err, net.ErrClosed) {
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
}
|
||||
})
|
||||
var threads atomic.Int32
|
||||
var requests atomic.Int32
|
||||
var bytes atomic.Int64
|
||||
start := time.Now()
|
||||
b.RunParallel(func(p *testing.PB) {
|
||||
threads.Add(1)
|
||||
require.NoError(b, up.Dial(route, func(_ context.Context, c net.Conn) error {
|
||||
buf := make([]byte, params.msgLen)
|
||||
for p.Next() {
|
||||
requests.Add(1)
|
||||
bytes.Add(int64(params.msgLen))
|
||||
require.NoError(b, sendMsg(c, buf[:]))
|
||||
require.NoError(b, recvMsg(c, buf[:]))
|
||||
}
|
||||
return nil
|
||||
}, upstreams.AuthenticateAs("test@example.com"), upstreams.DialProtocol(params.protocol)))
|
||||
})
|
||||
duration := time.Since(start)
|
||||
b.Logf("sent %d requests over %d parallel connections in %s", requests.Load(), threads.Load(), duration)
|
||||
b.Logf("throughput: %f bytes/s", float64(bytes.Load())/duration.Seconds())
|
||||
}
|
||||
|
||||
func TestHttp1Websocket(t *testing.T) {
|
||||
env := testenv.New(t)
|
||||
|
||||
up := upstreams.HTTP(nil)
|
||||
up.HandleWS("/ws", websocket.Upgrader{}, func(conn *websocket.Conn) error {
|
||||
for {
|
||||
mt, message, err := conn.ReadMessage()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// echo the message back
|
||||
err = conn.WriteMessage(mt, message)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
route := up.Route().
|
||||
From(env.SubdomainURL("ws-test")).
|
||||
Policy(func(p *config.Policy) {
|
||||
p.AllowPublicUnauthenticatedAccess = true
|
||||
p.AllowWebsockets = true
|
||||
})
|
||||
|
||||
env.AddUpstream(up)
|
||||
env.Start()
|
||||
snippets.WaitStartupComplete(env)
|
||||
|
||||
assert.NoError(t, up.DialWS(route, func(conn *websocket.Conn) error {
|
||||
if err := conn.SetWriteDeadline(time.Now().Add(1 * time.Second)); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := conn.WriteMessage(websocket.TextMessage, []byte("hello world")); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := conn.SetReadDeadline(time.Now().Add(1 * time.Second)); err != nil {
|
||||
return err
|
||||
}
|
||||
mt, bytes, err := conn.ReadMessage()
|
||||
if err := err; err != nil {
|
||||
return err
|
||||
}
|
||||
assert.Equal(t, websocket.TextMessage, mt)
|
||||
assert.Equal(t, "hello world", string(bytes))
|
||||
return nil
|
||||
}, upstreams.Path("/ws")))
|
||||
}
|
||||
|
||||
func TestClientCert(t *testing.T) {
|
||||
env := testenv.New(t)
|
||||
env.Add(scenarios.DownstreamMTLS(config.MTLSEnforcementRejectConnection))
|
||||
|
|
|
@ -23,6 +23,7 @@ import (
|
|||
"os/signal"
|
||||
"path"
|
||||
"path/filepath"
|
||||
"reflect"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
@ -34,11 +35,13 @@ import (
|
|||
|
||||
"github.com/pomerium/pomerium/config"
|
||||
"github.com/pomerium/pomerium/config/envoyconfig/filemgr"
|
||||
"github.com/pomerium/pomerium/config/otelconfig"
|
||||
databroker_service "github.com/pomerium/pomerium/databroker"
|
||||
"github.com/pomerium/pomerium/internal/log"
|
||||
"github.com/pomerium/pomerium/internal/telemetry/trace"
|
||||
"github.com/pomerium/pomerium/internal/testenv/envutil"
|
||||
"github.com/pomerium/pomerium/internal/testenv/values"
|
||||
"github.com/pomerium/pomerium/internal/version"
|
||||
"github.com/pomerium/pomerium/pkg/cmd/pomerium"
|
||||
"github.com/pomerium/pomerium/pkg/envoy"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
||||
|
@ -97,6 +100,7 @@ type Environment interface {
|
|||
AuthenticateURL() values.Value[string]
|
||||
DatabrokerURL() values.Value[string]
|
||||
Ports() Ports
|
||||
Host() string
|
||||
SharedSecret() []byte
|
||||
CookieSecret() []byte
|
||||
|
||||
|
@ -244,6 +248,8 @@ type EnvironmentOptions struct {
|
|||
forceSilent bool
|
||||
traceDebugFlags trace.DebugFlags
|
||||
traceClient otlptrace.Client
|
||||
traceConfig *otelconfig.Config
|
||||
host string
|
||||
}
|
||||
|
||||
type EnvironmentOption func(*EnvironmentOptions)
|
||||
|
@ -300,15 +306,23 @@ func WithTraceClient(traceClient otlptrace.Client) EnvironmentOption {
|
|||
}
|
||||
}
|
||||
|
||||
func WithTraceConfig(traceConfig *otelconfig.Config) EnvironmentOption {
|
||||
return func(o *EnvironmentOptions) {
|
||||
o.traceConfig = traceConfig
|
||||
}
|
||||
}
|
||||
|
||||
var setGrpcLoggerOnce sync.Once
|
||||
|
||||
const defaultTraceDebugFlags = trace.TrackSpanCallers | trace.TrackSpanReferences
|
||||
|
||||
var (
|
||||
flagDebug = flag.Bool("env.debug", false, "enables test environment debug logging (equivalent to Debug() option)")
|
||||
flagPauseOnFailure = flag.Bool("env.pause-on-failure", false, "enables pausing the test environment on failure (equivalent to PauseOnFailure() option)")
|
||||
flagSilent = flag.Bool("env.silent", false, "suppresses all test environment output (equivalent to Silent() option)")
|
||||
flagTraceDebugFlags = flag.String("env.trace-debug-flags", strconv.Itoa(defaultTraceDebugFlags), "trace debug flags (equivalent to TraceDebugFlags() option)")
|
||||
flagDebug = flag.Bool("env.debug", false, "enables test environment debug logging (equivalent to Debug() option)")
|
||||
flagPauseOnFailure = flag.Bool("env.pause-on-failure", false, "enables pausing the test environment on failure (equivalent to PauseOnFailure() option)")
|
||||
flagSilent = flag.Bool("env.silent", false, "suppresses all test environment output (equivalent to Silent() option)")
|
||||
flagTraceDebugFlags = flag.String("env.trace-debug-flags", strconv.Itoa(defaultTraceDebugFlags), "trace debug flags (equivalent to TraceDebugFlags() option)")
|
||||
flagBindAddress = flag.String("env.bind-address", "127.0.0.1", "bind address for local services")
|
||||
flagTraceEnvironConfig = flag.Bool("env.use-trace-environ", false, "if true, will configure a trace client from environment variables if no trace client has been set")
|
||||
)
|
||||
|
||||
func New(t testing.TB, opts ...EnvironmentOption) Environment {
|
||||
|
@ -323,6 +337,7 @@ func New(t testing.TB, opts ...EnvironmentOption) Environment {
|
|||
forceSilent: *flagSilent,
|
||||
traceDebugFlags: trace.DebugFlags(defaultTraceDebugFlags),
|
||||
traceClient: trace.NoopClient{},
|
||||
host: *flagBindAddress,
|
||||
}
|
||||
options.apply(opts...)
|
||||
if testing.Short() {
|
||||
|
@ -332,6 +347,17 @@ func New(t testing.TB, opts ...EnvironmentOption) Environment {
|
|||
if addTraceDebugFlags {
|
||||
options.traceDebugFlags |= trace.DebugFlags(defaultTraceDebugFlags)
|
||||
}
|
||||
if *flagTraceEnvironConfig && options.traceConfig == nil &&
|
||||
(reflect.TypeOf(options.traceClient) == reflect.TypeFor[trace.NoopClient]()) {
|
||||
cfg := newOtelConfigFromEnv(t)
|
||||
options.traceConfig = &cfg
|
||||
client, err := trace.NewTraceClientFromConfig(cfg)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Log("tracing configured from environment")
|
||||
options.traceClient = client
|
||||
}
|
||||
trace.UseGlobalPanicTracer()
|
||||
databroker.DebugUseFasterBackoff.Store(true)
|
||||
workspaceFolder, err := os.Getwd()
|
||||
|
@ -495,7 +521,7 @@ func (e *environment) AuthenticateURL() values.Value[string] {
|
|||
|
||||
func (e *environment) DatabrokerURL() values.Value[string] {
|
||||
return values.Bind(e.ports.Outbound, func(port int) string {
|
||||
return fmt.Sprintf("127.0.0.1:%d", port)
|
||||
return fmt.Sprintf("%s:%d", e.host, port)
|
||||
})
|
||||
}
|
||||
|
||||
|
@ -503,6 +529,13 @@ func (e *environment) Ports() Ports {
|
|||
return e.ports
|
||||
}
|
||||
|
||||
func (e *environment) Host() string {
|
||||
if e.host == "" {
|
||||
return "127.0.0.1"
|
||||
}
|
||||
return e.host
|
||||
}
|
||||
|
||||
func (e *environment) CACert() *tls.Certificate {
|
||||
caCert, err := tls.LoadX509KeyPair(
|
||||
filepath.Join(e.tempDir, "certs", "ca.pem"),
|
||||
|
@ -571,9 +604,9 @@ func (e *environment) Start() {
|
|||
cfg.Options.Services = "all"
|
||||
cfg.Options.LogLevel = config.LogLevelDebug
|
||||
cfg.Options.ProxyLogLevel = config.LogLevelInfo
|
||||
cfg.Options.Addr = fmt.Sprintf("127.0.0.1:%d", e.ports.ProxyHTTP.Value())
|
||||
cfg.Options.GRPCAddr = fmt.Sprintf("127.0.0.1:%d", e.ports.ProxyGRPC.Value())
|
||||
cfg.Options.MetricsAddr = fmt.Sprintf("127.0.0.1:%d", e.ports.ProxyMetrics.Value())
|
||||
cfg.Options.Addr = fmt.Sprintf("%s:%d", e.host, e.ports.ProxyHTTP.Value())
|
||||
cfg.Options.GRPCAddr = fmt.Sprintf("%s:%d", e.host, e.ports.ProxyGRPC.Value())
|
||||
cfg.Options.MetricsAddr = fmt.Sprintf("%s:%d", e.host, e.ports.ProxyMetrics.Value())
|
||||
cfg.Options.CAFile = filepath.Join(e.tempDir, "certs", "ca.pem")
|
||||
cfg.Options.CertFile = filepath.Join(e.tempDir, "certs", "trusted.pem")
|
||||
cfg.Options.KeyFile = filepath.Join(e.tempDir, "certs", "trusted-key.pem")
|
||||
|
@ -598,6 +631,9 @@ func (e *environment) Start() {
|
|||
log.AccessLogFieldUserAgent,
|
||||
log.AccessLogFieldClientCertificate,
|
||||
}
|
||||
if e.traceConfig != nil {
|
||||
cfg.Options.Tracing = *e.traceConfig
|
||||
}
|
||||
|
||||
e.src = &configSource{cfg: cfg}
|
||||
e.AddTask(TaskFunc(func(ctx context.Context) error {
|
||||
|
@ -799,6 +835,7 @@ func (e *environment) Pause() {
|
|||
c := make(chan os.Signal, 1)
|
||||
signal.Notify(c, syscall.SIGINT)
|
||||
<-c
|
||||
signal.Stop(c)
|
||||
e.t.Log("\x1b[31mctrl+c received, continuing\x1b[0m")
|
||||
}
|
||||
|
||||
|
@ -816,6 +853,7 @@ func (e *environment) cleanup(cancelCause error) {
|
|||
c := make(chan os.Signal, 1)
|
||||
signal.Notify(c, syscall.SIGINT)
|
||||
<-c
|
||||
signal.Stop(c)
|
||||
e.t.Log("\x1b[31mctrl+c received, continuing\x1b[0m")
|
||||
signal.Stop(c)
|
||||
}
|
||||
|
@ -1043,3 +1081,13 @@ func (src *configSource) ModifyConfig(ctx context.Context, m Modifier) {
|
|||
li(ctx, src.cfg)
|
||||
}
|
||||
}
|
||||
|
||||
func newOtelConfigFromEnv(t testing.TB) otelconfig.Config {
|
||||
f, err := os.CreateTemp("", "tmp-config-*.yaml")
|
||||
require.NoError(t, err)
|
||||
defer os.Remove(f.Name())
|
||||
f.Close()
|
||||
cfg, err := config.NewFileOrEnvironmentSource(context.Background(), f.Name(), version.FullVersion())
|
||||
require.NoError(t, err)
|
||||
return cfg.GetConfig().Options.Tracing
|
||||
}
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package testenv
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/url"
|
||||
"strings"
|
||||
|
||||
|
@ -74,3 +75,19 @@ func (b *PolicyRoute) PPL(ppl string) Route {
|
|||
func (b *PolicyRoute) URL() values.Value[string] {
|
||||
return b.from
|
||||
}
|
||||
|
||||
type TCPRoute struct {
|
||||
PolicyRoute
|
||||
}
|
||||
|
||||
func (b *TCPRoute) From(fromURL values.Value[string]) Route {
|
||||
b.from = values.Bind(fromURL, func(urlStr string) string {
|
||||
from, _ := url.Parse(urlStr)
|
||||
from.Scheme = "tcp+https"
|
||||
from.Host = fmt.Sprintf("%s:%s", from.Hostname(), from.Port())
|
||||
return from.String()
|
||||
})
|
||||
return b
|
||||
}
|
||||
|
||||
var _ Route = (*TCPRoute)(nil)
|
||||
|
|
|
@ -6,6 +6,8 @@ import (
|
|||
"crypto/ecdsa"
|
||||
"crypto/elliptic"
|
||||
"crypto/rand"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"encoding/base64"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
|
@ -32,6 +34,7 @@ import (
|
|||
)
|
||||
|
||||
type IDP struct {
|
||||
IDPOptions
|
||||
id values.Value[string]
|
||||
url values.Value[string]
|
||||
publicJWK jose.JSONWebKey
|
||||
|
@ -41,18 +44,56 @@ type IDP struct {
|
|||
userLookup map[string]*User
|
||||
}
|
||||
|
||||
type IDPOptions struct {
|
||||
enableTLS bool
|
||||
}
|
||||
|
||||
type IDPOption func(*IDPOptions)
|
||||
|
||||
func (o *IDPOptions) apply(opts ...IDPOption) {
|
||||
for _, op := range opts {
|
||||
op(o)
|
||||
}
|
||||
}
|
||||
|
||||
func WithEnableTLS(enableTLS bool) IDPOption {
|
||||
return func(o *IDPOptions) {
|
||||
o.enableTLS = enableTLS
|
||||
}
|
||||
}
|
||||
|
||||
// Attach implements testenv.Modifier.
|
||||
func (idp *IDP) Attach(ctx context.Context) {
|
||||
env := testenv.EnvFromContext(ctx)
|
||||
|
||||
router := upstreams.HTTP(nil, upstreams.WithDisplayName("IDP"))
|
||||
idpURL := env.SubdomainURL("mock-idp")
|
||||
|
||||
idp.url = values.Bind2(env.SubdomainURL("mock-idp"), router.Port(), func(urlStr string, port int) string {
|
||||
var tlsConfig values.Value[*tls.Config]
|
||||
if idp.enableTLS {
|
||||
tlsConfig = values.Bind(idpURL, func(urlStr string) *tls.Config {
|
||||
u, _ := url.Parse(urlStr)
|
||||
cert := env.NewServerCert(&x509.Certificate{
|
||||
DNSNames: []string{u.Hostname()},
|
||||
})
|
||||
return &tls.Config{
|
||||
RootCAs: env.ServerCAs(),
|
||||
Certificates: []tls.Certificate{tls.Certificate(*cert)},
|
||||
NextProtos: []string{"http/1.1", "h2"},
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
router := upstreams.HTTP(tlsConfig, upstreams.WithDisplayName("IDP"))
|
||||
|
||||
idp.url = values.Bind2(idpURL, router.Addr(), func(urlStr string, addr string) string {
|
||||
u, _ := url.Parse(urlStr)
|
||||
host, _, _ := net.SplitHostPort(u.Host)
|
||||
_, port, err := net.SplitHostPort(addr)
|
||||
if err != nil {
|
||||
panic("bug: " + err.Error())
|
||||
}
|
||||
return u.ResolveReference(&url.URL{
|
||||
Scheme: "http",
|
||||
Host: fmt.Sprintf("%s:%d", host, port),
|
||||
Host: fmt.Sprintf("%s:%s", host, port),
|
||||
}).String()
|
||||
})
|
||||
var err error
|
||||
|
@ -108,7 +149,12 @@ func (idp *IDP) Modify(cfg *config.Config) {
|
|||
|
||||
var _ testenv.Modifier = (*IDP)(nil)
|
||||
|
||||
func NewIDP(users []*User) *IDP {
|
||||
func NewIDP(users []*User, opts ...IDPOption) *IDP {
|
||||
options := IDPOptions{
|
||||
enableTLS: true,
|
||||
}
|
||||
options.apply(opts...)
|
||||
|
||||
privateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
|
@ -136,6 +182,7 @@ func NewIDP(users []*User) *IDP {
|
|||
userLookup[user.ID] = user
|
||||
}
|
||||
return &IDP{
|
||||
IDPOptions: options,
|
||||
publicJWK: publicJWK,
|
||||
signingKey: signingKey,
|
||||
userLookup: userLookup,
|
||||
|
|
|
@ -174,16 +174,16 @@ func (rec *OTLPTraceReceiver) FlushResourceSpans() []*tracev1.ResourceSpans {
|
|||
// GRPCEndpointURL returns a url suitable for use with the environment variable
|
||||
// $OTEL_EXPORTER_OTLP_TRACES_ENDPOINT or with [otlptracegrpc.WithEndpointURL].
|
||||
func (rec *OTLPTraceReceiver) GRPCEndpointURL() values.Value[string] {
|
||||
return values.Chain(rec.grpcUpstream, upstreams.GRPCUpstream.Port, func(port int) string {
|
||||
return fmt.Sprintf("http://127.0.0.1:%d", port)
|
||||
return values.Chain(rec.grpcUpstream, upstreams.GRPCUpstream.Addr, func(addr string) string {
|
||||
return fmt.Sprintf("http://%s", addr)
|
||||
})
|
||||
}
|
||||
|
||||
// GRPCEndpointURL returns a url suitable for use with the environment variable
|
||||
// $OTEL_EXPORTER_OTLP_TRACES_ENDPOINT or with [otlptracehttp.WithEndpointURL].
|
||||
func (rec *OTLPTraceReceiver) HTTPEndpointURL() values.Value[string] {
|
||||
return values.Chain(rec.httpUpstream, upstreams.HTTPUpstream.Port, func(port int) string {
|
||||
return fmt.Sprintf("http://127.0.0.1:%d/v1/traces", port)
|
||||
return values.Chain(rec.httpUpstream, upstreams.HTTPUpstream.Addr, func(addr string) string {
|
||||
return fmt.Sprintf("http://%s/v1/traces", addr)
|
||||
})
|
||||
}
|
||||
|
||||
|
@ -200,9 +200,9 @@ func (rec *OTLPTraceReceiver) NewGRPCClient(opts ...otlptracegrpc.Option) otlptr
|
|||
|
||||
func (rec *OTLPTraceReceiver) NewHTTPClient(opts ...otlptracehttp.Option) otlptrace.Client {
|
||||
return &deferredClient{
|
||||
client: values.Chain(rec.httpUpstream, upstreams.HTTPUpstream.Port, func(port int) otlptrace.Client {
|
||||
client: values.Chain(rec.httpUpstream, upstreams.HTTPUpstream.Addr, func(addr string) otlptrace.Client {
|
||||
return otlptracehttp.NewClient(append(opts,
|
||||
otlptracehttp.WithEndpointURL(fmt.Sprintf("http://127.0.0.1:%d/v1/traces", port)),
|
||||
otlptracehttp.WithEndpointURL(fmt.Sprintf("http://%s/v1/traces", addr)),
|
||||
otlptracehttp.WithTimeout(1*time.Minute),
|
||||
)...)
|
||||
}),
|
||||
|
|
|
@ -204,7 +204,8 @@ func (f *taskFunc) Run(ctx context.Context) error {
|
|||
type Upstream interface {
|
||||
Modifier
|
||||
Task
|
||||
Port() values.Value[int]
|
||||
|
||||
Addr() values.Value[string]
|
||||
Route() RouteStub
|
||||
}
|
||||
|
||||
|
|
|
@ -97,8 +97,10 @@ type service struct {
|
|||
impl any
|
||||
}
|
||||
|
||||
func (g *grpcUpstream) Port() values.Value[int] {
|
||||
return g.serverPort
|
||||
func (g *grpcUpstream) Addr() values.Value[string] {
|
||||
return values.Bind(g.serverPort, func(port int) string {
|
||||
return fmt.Sprintf("%s:%d", g.Env().Host(), port)
|
||||
})
|
||||
}
|
||||
|
||||
// RegisterService implements grpc.ServiceRegistrar.
|
||||
|
@ -117,7 +119,7 @@ func (g *grpcUpstream) Route() testenv.RouteStub {
|
|||
protocol = "https"
|
||||
}
|
||||
r.To(values.Bind(g.serverPort, func(port int) string {
|
||||
return fmt.Sprintf("%s://127.0.0.1:%d", protocol, port)
|
||||
return fmt.Sprintf("%s://%s:%d", protocol, g.Env().Host(), port)
|
||||
}))
|
||||
g.Add(r)
|
||||
return r
|
||||
|
@ -125,7 +127,7 @@ func (g *grpcUpstream) Route() testenv.RouteStub {
|
|||
|
||||
// Start implements testenv.Upstream.
|
||||
func (g *grpcUpstream) Run(ctx context.Context) error {
|
||||
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
listener, err := net.Listen("tcp", fmt.Sprintf("%s:0", g.Env().Host()))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -187,7 +189,7 @@ func (g *grpcUpstream) Dial(r testenv.Route, dialOpts ...grpc.DialOption) *grpc.
|
|||
}
|
||||
|
||||
func (g *grpcUpstream) DirectConnect(dialOpts ...grpc.DialOption) *grpc.ClientConn {
|
||||
cc, err := grpc.NewClient(fmt.Sprintf("127.0.0.1:%d", g.Port().Value()),
|
||||
cc, err := grpc.NewClient(g.Addr().Value(),
|
||||
append(g.withDefaultDialOpts(dialOpts), grpc.WithTransportCredentials(insecure.NewCredentials()))...)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
|
|
|
@ -1,10 +1,8 @@
|
|||
package upstreams
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
|
@ -13,14 +11,12 @@ import (
|
|||
"net/http/cookiejar"
|
||||
"net/http/httptrace"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
"os"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/pomerium/pomerium/integration/forms"
|
||||
"github.com/pomerium/pomerium/internal/retry"
|
||||
"github.com/gorilla/websocket"
|
||||
"github.com/pomerium/pomerium/internal/telemetry/trace"
|
||||
"github.com/pomerium/pomerium/internal/testenv"
|
||||
"github.com/pomerium/pomerium/internal/testenv/snippets"
|
||||
|
@ -29,9 +25,15 @@ import (
|
|||
|
||||
"go.opentelemetry.io/otel/attribute"
|
||||
"go.opentelemetry.io/otel/codes"
|
||||
semconv "go.opentelemetry.io/otel/semconv/v1.26.0"
|
||||
oteltrace "go.opentelemetry.io/otel/trace"
|
||||
"google.golang.org/protobuf/proto"
|
||||
)
|
||||
|
||||
type Protocol string
|
||||
|
||||
const (
|
||||
DialHTTP1 Protocol = "http/1.1"
|
||||
DialHTTP2 Protocol = "h2"
|
||||
DialHTTP3 Protocol = "h3"
|
||||
)
|
||||
|
||||
type RequestOptions struct {
|
||||
|
@ -42,12 +44,18 @@ type RequestOptions struct {
|
|||
authenticateAs string
|
||||
body any
|
||||
clientCerts []tls.Certificate
|
||||
client *http.Client
|
||||
clientHook func(*http.Client) *http.Client
|
||||
dialerHook func(*websocket.Dialer, *url.URL) (*websocket.Dialer, *url.URL)
|
||||
dialProtocol Protocol
|
||||
trace *httptrace.ClientTrace
|
||||
}
|
||||
|
||||
type RequestOption func(*RequestOptions)
|
||||
|
||||
func (ro RequestOption) Format(fmt.State, rune) {
|
||||
panic("test bug: request option mistakenly passed to assert function")
|
||||
}
|
||||
|
||||
func (o *RequestOptions) apply(opts ...RequestOption) {
|
||||
for _, op := range opts {
|
||||
op(o)
|
||||
|
@ -82,9 +90,38 @@ func AuthenticateAs(email string) RequestOption {
|
|||
}
|
||||
}
|
||||
|
||||
func Client(c *http.Client) RequestOption {
|
||||
// ClientHook allows editing or replacing the http client before it is used.
|
||||
// When any request is about to start, this function will be called with the
|
||||
// client that would be used to make the request. The returned client will
|
||||
// be the actual client used for that request. It can be the same as the input
|
||||
// (with or without modification), or replaced entirely.
|
||||
//
|
||||
// Note: the Transport of the client passed to the hook will always be a
|
||||
// [*Transport]. That transport's underlying transport will always be
|
||||
// a [*otelhttp.Transport].
|
||||
func ClientHook(f func(*http.Client) *http.Client) RequestOption {
|
||||
return func(o *RequestOptions) {
|
||||
o.client = c
|
||||
o.clientHook = f
|
||||
}
|
||||
}
|
||||
|
||||
// DialerHook allows editing or replacing the websocket dialer before it is
|
||||
// used. When a websocket request is about to start (using the DialWS method),
|
||||
// this function will be called with the dialer that would be used, and the
|
||||
// destination URL (including wss:// scheme, and path if one is present). The
|
||||
// returned dialer+URL will be the actual dialer+URL used for that request.
|
||||
//
|
||||
// If ClientHook is also set, both will be called. The dialer passed to this
|
||||
// hook will have its TLSClientConfig and Jar fields set from the client.
|
||||
func DialerHook(f func(*websocket.Dialer, *url.URL) (*websocket.Dialer, *url.URL)) RequestOption {
|
||||
return func(o *RequestOptions) {
|
||||
o.dialerHook = f
|
||||
}
|
||||
}
|
||||
|
||||
func DialProtocol(protocol Protocol) RequestOption {
|
||||
return func(o *RequestOptions) {
|
||||
o.dialProtocol = protocol
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -143,10 +180,12 @@ type HTTPUpstream interface {
|
|||
testenv.Upstream
|
||||
|
||||
Handle(path string, f func(http.ResponseWriter, *http.Request)) *mux.Route
|
||||
HandleWS(path string, upgrader websocket.Upgrader, f func(conn *websocket.Conn) error) *mux.Route
|
||||
|
||||
Get(r testenv.Route, opts ...RequestOption) (*http.Response, error)
|
||||
Post(r testenv.Route, opts ...RequestOption) (*http.Response, error)
|
||||
Do(method string, r testenv.Route, opts ...RequestOption) (*http.Response, error)
|
||||
DialWS(r testenv.Route, f func(conn *websocket.Conn) error, opts ...RequestOption) error
|
||||
}
|
||||
|
||||
type httpUpstream struct {
|
||||
|
@ -194,8 +233,10 @@ func HTTP(tlsConfig values.Value[*tls.Config], opts ...HTTPUpstreamOption) HTTPU
|
|||
}
|
||||
|
||||
// Port implements HTTPUpstream.
|
||||
func (h *httpUpstream) Port() values.Value[int] {
|
||||
return h.serverPort
|
||||
func (h *httpUpstream) Addr() values.Value[string] {
|
||||
return values.Bind(h.serverPort, func(port int) string {
|
||||
return fmt.Sprintf("%s:%d", h.Env().Host(), port)
|
||||
})
|
||||
}
|
||||
|
||||
// Router implements HTTPUpstream.
|
||||
|
@ -203,12 +244,37 @@ func (h *httpUpstream) Handle(path string, f func(http.ResponseWriter, *http.Req
|
|||
return h.router.HandleFunc(path, f)
|
||||
}
|
||||
|
||||
// Router implements HTTPUpstream.
|
||||
func (h *httpUpstream) HandleWS(path string, upgrader websocket.Upgrader, f func(*websocket.Conn) error) *mux.Route {
|
||||
return h.router.HandleFunc(path, func(w http.ResponseWriter, r *http.Request) {
|
||||
ctx, span := trace.Continue(r.Context(), "HandleWS")
|
||||
defer span.End()
|
||||
c, err := upgrader.Upgrade(w, r.WithContext(ctx), nil)
|
||||
if err != nil {
|
||||
span.SetStatus(codes.Error, err.Error())
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
_, _ = w.Write([]byte(err.Error()))
|
||||
return
|
||||
}
|
||||
defer c.Close()
|
||||
|
||||
err = f(c)
|
||||
if err != nil {
|
||||
if errors.Is(err, io.EOF) {
|
||||
return
|
||||
}
|
||||
span.SetStatus(codes.Error, err.Error())
|
||||
fmt.Fprintf(os.Stderr, "websocket error: %s\n", err.Error())
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// Route implements HTTPUpstream.
|
||||
func (h *httpUpstream) Route() testenv.RouteStub {
|
||||
r := &testenv.PolicyRoute{}
|
||||
protocol := "http"
|
||||
r.To(values.Bind(h.serverPort, func(port int) string {
|
||||
return fmt.Sprintf("%s://127.0.0.1:%d", protocol, port)
|
||||
return fmt.Sprintf("%s://%s:%d", protocol, h.Env().Host(), port)
|
||||
}))
|
||||
h.Add(r)
|
||||
return r
|
||||
|
@ -216,15 +282,21 @@ func (h *httpUpstream) Route() testenv.RouteStub {
|
|||
|
||||
// Run implements HTTPUpstream.
|
||||
func (h *httpUpstream) Run(ctx context.Context) error {
|
||||
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
return err
|
||||
var listener net.Listener
|
||||
if h.tlsConfig != nil {
|
||||
var err error
|
||||
listener, err = tls.Listen("tcp", fmt.Sprintf("%s:0", h.Env().Host()), h.tlsConfig.Value())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
var err error
|
||||
listener, err = net.Listen("tcp", fmt.Sprintf("%s:0", h.Env().Host()))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
h.serverPort.Resolve(listener.Addr().(*net.TCPAddr).Port)
|
||||
var tlsConfig *tls.Config
|
||||
if h.tlsConfig != nil {
|
||||
tlsConfig = h.tlsConfig.Value()
|
||||
}
|
||||
if h.serverTracerProviderOverride != nil {
|
||||
h.serverTracerProvider.Resolve(h.serverTracerProviderOverride)
|
||||
} else {
|
||||
|
@ -238,8 +310,7 @@ func (h *httpUpstream) Run(ctx context.Context) error {
|
|||
h.router.Use(trace.NewHTTPMiddleware(otelhttp.WithTracerProvider(h.serverTracerProvider.Value())))
|
||||
|
||||
server := &http.Server{
|
||||
Handler: h.router,
|
||||
TLSConfig: tlsConfig,
|
||||
Handler: h.router,
|
||||
// BaseContext: func(net.Listener) context.Context {
|
||||
// return ctx
|
||||
// },
|
||||
|
@ -277,6 +348,53 @@ func (h *httpUpstream) Post(r testenv.Route, opts ...RequestOption) (*http.Respo
|
|||
return h.Do(http.MethodPost, r, opts...)
|
||||
}
|
||||
|
||||
type Transport struct {
|
||||
*otelhttp.Transport
|
||||
// The underlying http.Transport instance wrapped by the otelhttp.Transport.
|
||||
Base *http.Transport
|
||||
}
|
||||
|
||||
var _ http.RoundTripper = Transport{}
|
||||
|
||||
func (h *httpUpstream) newClient(options *RequestOptions) *http.Client {
|
||||
transport := http.DefaultTransport.(*http.Transport).Clone()
|
||||
transport.TLSClientConfig = &tls.Config{
|
||||
RootCAs: h.Env().ServerCAs(),
|
||||
Certificates: options.clientCerts,
|
||||
}
|
||||
transport.DialTLSContext = nil
|
||||
c := http.Client{
|
||||
Transport: &Transport{
|
||||
Transport: otelhttp.NewTransport(transport,
|
||||
otelhttp.WithTracerProvider(h.clientTracerProvider.Value()),
|
||||
otelhttp.WithSpanNameFormatter(func(_ string, r *http.Request) string {
|
||||
return fmt.Sprintf("Client: %s %s", r.Method, r.URL.Path)
|
||||
}),
|
||||
),
|
||||
Base: transport,
|
||||
},
|
||||
}
|
||||
c.Jar, _ = cookiejar.New(&cookiejar.Options{})
|
||||
return &c
|
||||
}
|
||||
|
||||
func (h *httpUpstream) getRouteClient(r testenv.Route, options *RequestOptions) *http.Client {
|
||||
span := oteltrace.SpanFromContext(options.requestCtx)
|
||||
var cachedClient any
|
||||
var ok bool
|
||||
if cachedClient, ok = h.clientCache.Load(r); !ok {
|
||||
span.AddEvent("creating new http client")
|
||||
cachedClient, _ = h.clientCache.LoadOrStore(r, h.newClient(options))
|
||||
} else {
|
||||
span.AddEvent("using cached http client")
|
||||
}
|
||||
client := cachedClient.(*http.Client)
|
||||
if options.clientHook != nil {
|
||||
client = options.clientHook(client)
|
||||
}
|
||||
return client
|
||||
}
|
||||
|
||||
// Do implements HTTPUpstream.
|
||||
func (h *httpUpstream) Do(method string, r testenv.Route, opts ...RequestOption) (*http.Response, error) {
|
||||
options := RequestOptions{
|
||||
|
@ -303,141 +421,54 @@ func (h *httpUpstream) Do(method string, r testenv.Route, opts ...RequestOption)
|
|||
options.requestCtx = ctx
|
||||
defer span.End()
|
||||
|
||||
newClient := func() *http.Client {
|
||||
transport := http.DefaultTransport.(*http.Transport).Clone()
|
||||
transport.TLSClientConfig = &tls.Config{
|
||||
RootCAs: h.Env().ServerCAs(),
|
||||
Certificates: options.clientCerts,
|
||||
}
|
||||
transport.DialTLSContext = nil
|
||||
c := http.Client{
|
||||
Transport: otelhttp.NewTransport(transport,
|
||||
otelhttp.WithTracerProvider(h.clientTracerProvider.Value()),
|
||||
otelhttp.WithSpanNameFormatter(func(_ string, r *http.Request) string {
|
||||
return fmt.Sprintf("Client: %s %s", r.Method, r.URL.Path)
|
||||
}),
|
||||
),
|
||||
}
|
||||
c.Jar, _ = cookiejar.New(&cookiejar.Options{})
|
||||
return &c
|
||||
}
|
||||
var client *http.Client
|
||||
if options.client != nil {
|
||||
client = options.client
|
||||
} else {
|
||||
var cachedClient any
|
||||
var ok bool
|
||||
if cachedClient, ok = h.clientCache.Load(r); !ok {
|
||||
span.AddEvent("creating new http client")
|
||||
cachedClient, _ = h.clientCache.LoadOrStore(r, newClient())
|
||||
} else {
|
||||
span.AddEvent("using cached http client")
|
||||
}
|
||||
client = cachedClient.(*http.Client)
|
||||
}
|
||||
|
||||
var resp *http.Response
|
||||
resendCount := 0
|
||||
if err := retry.Retry(ctx, "http", func(ctx context.Context) error {
|
||||
req, err := http.NewRequestWithContext(ctx, method, u.String(), nil)
|
||||
if err != nil {
|
||||
return retry.NewTerminalError(err)
|
||||
}
|
||||
switch body := options.body.(type) {
|
||||
case string:
|
||||
req.Body = io.NopCloser(strings.NewReader(body))
|
||||
case []byte:
|
||||
req.Body = io.NopCloser(bytes.NewReader(body))
|
||||
case io.Reader:
|
||||
req.Body = io.NopCloser(body)
|
||||
case proto.Message:
|
||||
buf, err := proto.Marshal(body)
|
||||
if err != nil {
|
||||
return retry.NewTerminalError(err)
|
||||
}
|
||||
req.Body = io.NopCloser(bytes.NewReader(buf))
|
||||
req.Header.Set("Content-Type", "application/octet-stream")
|
||||
default:
|
||||
buf, err := json.Marshal(body)
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("unsupported body type: %T", body))
|
||||
}
|
||||
req.Body = io.NopCloser(bytes.NewReader(buf))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
case nil:
|
||||
}
|
||||
|
||||
if options.authenticateAs != "" {
|
||||
resp, err = authenticateFlow(ctx, client, req, options.authenticateAs) //nolint:bodyclose
|
||||
} else {
|
||||
resp, err = client.Do(req) //nolint:bodyclose
|
||||
}
|
||||
// retry on connection refused
|
||||
if err != nil {
|
||||
span.RecordError(err)
|
||||
var opErr *net.OpError
|
||||
if errors.As(err, &opErr) && opErr.Op == "dial" && opErr.Err.Error() == "connect: connection refused" {
|
||||
span.AddEvent("Retrying on dial error")
|
||||
return err
|
||||
}
|
||||
return retry.NewTerminalError(err)
|
||||
}
|
||||
if resp.StatusCode/100 == 5 {
|
||||
resendCount++
|
||||
_, _ = io.ReadAll(resp.Body)
|
||||
_ = resp.Body.Close()
|
||||
span.SetAttributes(semconv.HTTPRequestResendCount(resendCount))
|
||||
span.AddEvent("Retrying on 5xx error", oteltrace.WithAttributes(
|
||||
attribute.String("status", resp.Status),
|
||||
))
|
||||
return errors.New(http.StatusText(resp.StatusCode))
|
||||
}
|
||||
span.SetStatus(codes.Ok, "request completed successfully")
|
||||
return nil
|
||||
},
|
||||
retry.WithInitialInterval(1*time.Millisecond),
|
||||
retry.WithMaxInterval(100*time.Millisecond),
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return resp, nil
|
||||
return doAuthenticatedRequest(options.requestCtx,
|
||||
func(ctx context.Context) (*http.Request, error) {
|
||||
return http.NewRequestWithContext(ctx, method, u.String(), nil)
|
||||
},
|
||||
func(context.Context) *http.Client {
|
||||
return h.getRouteClient(r, &options)
|
||||
},
|
||||
&options,
|
||||
)
|
||||
}
|
||||
|
||||
func authenticateFlow(ctx context.Context, client *http.Client, req *http.Request, email string) (*http.Response, error) {
|
||||
span := oteltrace.SpanFromContext(ctx)
|
||||
var res *http.Response
|
||||
originalHostname := req.URL.Hostname()
|
||||
res, err := client.Do(req)
|
||||
func (h *httpUpstream) DialWS(r testenv.Route, f func(conn *websocket.Conn) error, opts ...RequestOption) error {
|
||||
options := RequestOptions{
|
||||
requestCtx: h.Env().Context(),
|
||||
}
|
||||
options.apply(opts...)
|
||||
u, err := url.Parse(r.URL().Value())
|
||||
if err != nil {
|
||||
span.RecordError(err)
|
||||
return nil, err
|
||||
return err
|
||||
}
|
||||
location := res.Request.URL
|
||||
if location.Hostname() == originalHostname {
|
||||
// already authenticated
|
||||
span.SetStatus(codes.Ok, "already authenticated")
|
||||
return res, nil
|
||||
u.Scheme = "wss"
|
||||
if options.path != "" || options.query != nil {
|
||||
u = u.ResolveReference(&url.URL{
|
||||
Path: options.path,
|
||||
RawQuery: options.query.Encode(),
|
||||
})
|
||||
}
|
||||
fs := forms.Parse(res.Body)
|
||||
_, _ = io.ReadAll(res.Body)
|
||||
_ = res.Body.Close()
|
||||
if len(fs) > 0 {
|
||||
f := fs[0]
|
||||
f.Inputs["email"] = email
|
||||
f.Inputs["token_expiration"] = strconv.Itoa(int((time.Hour * 24).Seconds()))
|
||||
span.AddEvent("submitting form", oteltrace.WithAttributes(attribute.String("location", location.String())))
|
||||
formReq, err := f.NewRequestWithContext(ctx, location)
|
||||
if err != nil {
|
||||
span.RecordError(err)
|
||||
return nil, err
|
||||
}
|
||||
resp, err := client.Do(formReq)
|
||||
if err != nil {
|
||||
span.RecordError(err)
|
||||
return nil, err
|
||||
}
|
||||
span.SetStatus(codes.Ok, "form submitted successfully")
|
||||
return resp, nil
|
||||
ctx, span := h.clientTracer.Value().Start(options.requestCtx, "httpUpstream.Dial", oteltrace.WithAttributes(
|
||||
attribute.String("url", u.String()),
|
||||
))
|
||||
options.requestCtx = ctx
|
||||
defer span.End()
|
||||
|
||||
client := h.getRouteClient(r, &options)
|
||||
d := &websocket.Dialer{
|
||||
HandshakeTimeout: 10 * time.Second,
|
||||
TLSClientConfig: client.Transport.(*Transport).Base.TLSClientConfig,
|
||||
Jar: client.Jar,
|
||||
}
|
||||
return nil, fmt.Errorf("test bug: expected IDP login form")
|
||||
if options.dialerHook != nil {
|
||||
d, u = options.dialerHook(d, u)
|
||||
}
|
||||
conn, resp, err := d.DialContext(options.requestCtx, u.String(), nil)
|
||||
if err != nil {
|
||||
resp.Body.Close()
|
||||
return fmt.Errorf("DialContext: %w", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
return f(conn)
|
||||
}
|
||||
|
|
|
@ -15,6 +15,7 @@ type CommonUpstreamOptions struct {
|
|||
type CommonUpstreamOption interface {
|
||||
GRPCUpstreamOption
|
||||
HTTPUpstreamOption
|
||||
TCPUpstreamOption
|
||||
}
|
||||
|
||||
type commonUpstreamOption func(o *CommonUpstreamOptions)
|
||||
|
@ -25,6 +26,9 @@ func (c commonUpstreamOption) applyGRPC(o *GRPCUpstreamOptions) { c(&o.CommonUps
|
|||
// applyHTTP implements CommonUpstreamOption.
|
||||
func (c commonUpstreamOption) applyHTTP(o *HTTPUpstreamOptions) { c(&o.CommonUpstreamOptions) }
|
||||
|
||||
// applyTCP implements CommonUpstreamOption.
|
||||
func (c commonUpstreamOption) applyTCP(o *TCPUpstreamOptions) { c(&o.CommonUpstreamOptions) }
|
||||
|
||||
func WithDisplayName(displayName string) CommonUpstreamOption {
|
||||
return commonUpstreamOption(func(o *CommonUpstreamOptions) {
|
||||
o.displayName = displayName
|
||||
|
|
343
internal/testenv/upstreams/tcp.go
Normal file
343
internal/testenv/upstreams/tcp.go
Normal file
|
@ -0,0 +1,343 @@
|
|||
package upstreams
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/cookiejar"
|
||||
"net/http/httptrace"
|
||||
"net/url"
|
||||
"sync"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/telemetry/trace"
|
||||
"github.com/pomerium/pomerium/internal/testenv"
|
||||
"github.com/pomerium/pomerium/internal/testenv/values"
|
||||
"go.opentelemetry.io/otel/attribute"
|
||||
oteltrace "go.opentelemetry.io/otel/trace"
|
||||
"golang.org/x/net/http2"
|
||||
)
|
||||
|
||||
type TCPUpstream interface {
|
||||
testenv.Upstream
|
||||
|
||||
Handle(fn func(context.Context, net.Conn) error)
|
||||
|
||||
Dial(r testenv.Route, fn func(context.Context, net.Conn) error, opts ...RequestOption) error
|
||||
}
|
||||
|
||||
type TCPUpstreamOptions struct {
|
||||
CommonUpstreamOptions
|
||||
}
|
||||
|
||||
type TCPUpstreamOption interface {
|
||||
applyTCP(*TCPUpstreamOptions)
|
||||
}
|
||||
|
||||
type tcpUpstream struct {
|
||||
TCPUpstreamOptions
|
||||
testenv.Aggregate
|
||||
serverPort values.MutableValue[int]
|
||||
serverHandler func(context.Context, net.Conn) error
|
||||
|
||||
serverTracerProvider values.MutableValue[oteltrace.TracerProvider]
|
||||
clientTracerProvider values.MutableValue[oteltrace.TracerProvider]
|
||||
clientTracer values.Value[oteltrace.Tracer]
|
||||
}
|
||||
|
||||
func TCP(opts ...TCPUpstreamOption) TCPUpstream {
|
||||
options := TCPUpstreamOptions{
|
||||
CommonUpstreamOptions: CommonUpstreamOptions{
|
||||
displayName: "TCP Upstream",
|
||||
},
|
||||
}
|
||||
for _, op := range opts {
|
||||
op.applyTCP(&options)
|
||||
}
|
||||
up := &tcpUpstream{
|
||||
TCPUpstreamOptions: options,
|
||||
serverPort: values.Deferred[int](),
|
||||
|
||||
serverTracerProvider: values.Deferred[oteltrace.TracerProvider](),
|
||||
clientTracerProvider: values.Deferred[oteltrace.TracerProvider](),
|
||||
}
|
||||
up.clientTracer = values.Bind(up.clientTracerProvider, func(tp oteltrace.TracerProvider) oteltrace.Tracer {
|
||||
return tp.Tracer(trace.PomeriumCoreTracer)
|
||||
})
|
||||
up.RecordCaller()
|
||||
return up
|
||||
}
|
||||
|
||||
// Dial implements TCPUpstream.
|
||||
func (t *tcpUpstream) Dial(r testenv.Route, clientHandler func(context.Context, net.Conn) error, opts ...RequestOption) error {
|
||||
options := RequestOptions{
|
||||
requestCtx: t.Env().Context(),
|
||||
dialProtocol: DialHTTP1,
|
||||
}
|
||||
options.apply(opts...)
|
||||
u, err := url.Parse(r.URL().Value())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
ctx, span := t.clientTracer.Value().Start(options.requestCtx, "tcpUpstream.Do", oteltrace.WithAttributes(
|
||||
attribute.String("protocol", string(options.dialProtocol)),
|
||||
attribute.String("url", u.String()),
|
||||
))
|
||||
if options.path != "" || options.query != nil {
|
||||
u = u.ResolveReference(&url.URL{
|
||||
Path: options.path,
|
||||
RawQuery: options.query.Encode(),
|
||||
})
|
||||
}
|
||||
if options.trace != nil {
|
||||
ctx = httptrace.WithClientTrace(ctx, options.trace)
|
||||
}
|
||||
options.requestCtx = ctx
|
||||
defer span.End()
|
||||
|
||||
var remoteConn *tls.Conn
|
||||
remoteWriter := make(chan *io.PipeWriter, 1)
|
||||
|
||||
connectURL := &url.URL{Scheme: "https", Host: u.Host, Path: u.Path}
|
||||
|
||||
var getClientFn func(context.Context) *http.Client
|
||||
var newRequestFn func(ctx context.Context) (*http.Request, error)
|
||||
switch options.dialProtocol {
|
||||
case DialHTTP1:
|
||||
getClientFn = t.h1Dialer(&options, connectURL, &remoteConn)
|
||||
newRequestFn = func(ctx context.Context) (*http.Request, error) {
|
||||
req := (&http.Request{
|
||||
Method: http.MethodConnect,
|
||||
URL: connectURL,
|
||||
Host: u.Host,
|
||||
}).WithContext(ctx)
|
||||
return req, nil
|
||||
}
|
||||
case DialHTTP2:
|
||||
getClientFn = t.h2Dialer(&options, connectURL, &remoteConn, remoteWriter)
|
||||
newRequestFn = func(ctx context.Context) (*http.Request, error) {
|
||||
req := (&http.Request{
|
||||
Method: http.MethodConnect,
|
||||
URL: connectURL,
|
||||
Host: u.Host,
|
||||
Proto: "HTTP/2",
|
||||
}).WithContext(ctx)
|
||||
return req, nil
|
||||
}
|
||||
case DialHTTP3:
|
||||
panic("not implemented")
|
||||
}
|
||||
resp, err := doAuthenticatedRequest(options.requestCtx, newRequestFn, getClientFn, &options)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
resp.Body.Close()
|
||||
return errors.New(resp.Status)
|
||||
}
|
||||
if resp.Request.URL.Path == "/oidc/auth" {
|
||||
if options.authenticateAs == "" {
|
||||
return errors.New("test bug: unexpected IDP redirect; missing AuthenticateAs option to Dial()")
|
||||
}
|
||||
return errors.New("internal test bug: unexpected IDP redirect")
|
||||
}
|
||||
|
||||
var w io.WriteCloser = remoteConn
|
||||
if options.dialProtocol == DialHTTP2 {
|
||||
w = <-remoteWriter
|
||||
}
|
||||
|
||||
conn := NewRWConn(resp.Body, w)
|
||||
defer conn.Close()
|
||||
return clientHandler(resp.Request.Context(), conn)
|
||||
}
|
||||
|
||||
func (t *tcpUpstream) h1Dialer(
|
||||
options *RequestOptions,
|
||||
connectURL *url.URL,
|
||||
remoteConn **tls.Conn,
|
||||
) func(context.Context) *http.Client {
|
||||
jar, _ := cookiejar.New(nil)
|
||||
return func(context.Context) *http.Client {
|
||||
tlsConfig := &tls.Config{
|
||||
RootCAs: t.Env().ServerCAs(),
|
||||
Certificates: options.clientCerts,
|
||||
NextProtos: []string{"http/1.1"},
|
||||
}
|
||||
client := &http.Client{
|
||||
Transport: &http.Transport{
|
||||
DisableKeepAlives: true,
|
||||
DialTLSContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
if *remoteConn != nil {
|
||||
(*remoteConn).Close()
|
||||
*remoteConn = nil
|
||||
}
|
||||
dialer := &tls.Dialer{
|
||||
Config: tlsConfig,
|
||||
}
|
||||
cc, err := dialer.DialContext(ctx, network, addr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("%w: %w", ErrRetry, err)
|
||||
}
|
||||
protocol := cc.(*tls.Conn).ConnectionState().NegotiatedProtocol
|
||||
if protocol != "http/1.1" {
|
||||
cc.Close()
|
||||
return nil, fmt.Errorf("error: unexpected TLS protocol: %s", protocol)
|
||||
}
|
||||
*remoteConn = cc.(*tls.Conn)
|
||||
return cc, nil
|
||||
},
|
||||
TLSClientConfig: tlsConfig, // important
|
||||
},
|
||||
CheckRedirect: func(req *http.Request, _ []*http.Request) error {
|
||||
if req.URL.String() == connectURL.String() && req.Method == http.MethodGet {
|
||||
req.Method = http.MethodConnect
|
||||
}
|
||||
return nil
|
||||
},
|
||||
Jar: jar,
|
||||
}
|
||||
return client
|
||||
}
|
||||
}
|
||||
|
||||
func (t *tcpUpstream) h2Dialer(
|
||||
options *RequestOptions,
|
||||
connectURL *url.URL,
|
||||
remoteConn **tls.Conn,
|
||||
writer chan<- *io.PipeWriter,
|
||||
) func(context.Context) *http.Client {
|
||||
jar, _ := cookiejar.New(nil)
|
||||
return func(context.Context) *http.Client {
|
||||
h1 := &http.Transport{
|
||||
ForceAttemptHTTP2: true,
|
||||
DisableKeepAlives: true,
|
||||
DialTLSContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
if *remoteConn != nil {
|
||||
(*remoteConn).Close()
|
||||
*remoteConn = nil
|
||||
}
|
||||
dialer := &tls.Dialer{
|
||||
Config: &tls.Config{
|
||||
RootCAs: t.Env().ServerCAs(),
|
||||
Certificates: options.clientCerts,
|
||||
NextProtos: []string{"h2"},
|
||||
},
|
||||
}
|
||||
cc, err := dialer.DialContext(ctx, network, addr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("%w: %w", ErrRetry, err)
|
||||
}
|
||||
protocol := cc.(*tls.Conn).ConnectionState().NegotiatedProtocol
|
||||
if protocol != "h2" {
|
||||
cc.Close()
|
||||
return nil, fmt.Errorf("error: unexpected TLS protocol: %s", protocol)
|
||||
}
|
||||
*remoteConn = cc.(*tls.Conn)
|
||||
|
||||
return cc, nil
|
||||
},
|
||||
TLSClientConfig: &tls.Config{
|
||||
RootCAs: t.Env().ServerCAs(),
|
||||
Certificates: options.clientCerts,
|
||||
NextProtos: []string{"h2"},
|
||||
},
|
||||
}
|
||||
if err := http2.ConfigureTransport(h1); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
client := &http.Client{
|
||||
Transport: h1,
|
||||
CheckRedirect: func(req *http.Request, _ []*http.Request) error {
|
||||
if req.URL.String() == connectURL.String() && req.Method == http.MethodGet {
|
||||
pr, pw := io.Pipe()
|
||||
req.Method = http.MethodConnect
|
||||
req.Body = pr
|
||||
req.ContentLength = -1
|
||||
writer <- pw
|
||||
}
|
||||
return nil
|
||||
},
|
||||
Jar: jar,
|
||||
}
|
||||
return client
|
||||
}
|
||||
}
|
||||
|
||||
// Handle implements TCPUpstream.
|
||||
func (t *tcpUpstream) Handle(fn func(context.Context, net.Conn) error) {
|
||||
t.serverHandler = fn
|
||||
}
|
||||
|
||||
// Port implements TCPUpstream.
|
||||
func (t *tcpUpstream) Addr() values.Value[string] {
|
||||
return values.Bind(t.serverPort, func(port int) string {
|
||||
return fmt.Sprintf("%s:%d", t.Env().Host(), port)
|
||||
})
|
||||
}
|
||||
|
||||
// Route implements TCPUpstream.
|
||||
func (t *tcpUpstream) Route() testenv.RouteStub {
|
||||
r := &testenv.TCPRoute{}
|
||||
r.To(values.Bind(t.serverPort, func(port int) string {
|
||||
return fmt.Sprintf("tcp://%s:%d", t.Env().Host(), port)
|
||||
}))
|
||||
t.Add(r)
|
||||
return r
|
||||
}
|
||||
|
||||
// Run implements TCPUpstream.
|
||||
func (t *tcpUpstream) Run(ctx context.Context) error {
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
defer cancel()
|
||||
listener, err := (&net.ListenConfig{}).Listen(ctx, "tcp", fmt.Sprintf("%s:0", t.Env().Host()))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
context.AfterFunc(ctx, func() {
|
||||
listener.Close()
|
||||
})
|
||||
t.serverPort.Resolve(listener.Addr().(*net.TCPAddr).Port)
|
||||
if t.serverTracerProviderOverride != nil {
|
||||
t.serverTracerProvider.Resolve(t.serverTracerProviderOverride)
|
||||
} else {
|
||||
t.serverTracerProvider.Resolve(trace.NewTracerProvider(ctx, t.displayName))
|
||||
}
|
||||
if t.clientTracerProviderOverride != nil {
|
||||
t.clientTracerProvider.Resolve(t.clientTracerProviderOverride)
|
||||
} else {
|
||||
t.clientTracerProvider.Resolve(trace.NewTracerProvider(ctx, "TCP Client"))
|
||||
}
|
||||
var wg sync.WaitGroup
|
||||
defer wg.Wait()
|
||||
|
||||
for {
|
||||
conn, err := listener.Accept()
|
||||
if err != nil {
|
||||
if errors.Is(err, net.ErrClosed) {
|
||||
cancel()
|
||||
return nil
|
||||
}
|
||||
continue
|
||||
}
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
if err := t.serverHandler(ctx, conn); err != nil {
|
||||
if errors.Is(err, io.EOF) {
|
||||
return
|
||||
}
|
||||
panic("server handler error: " + err.Error())
|
||||
}
|
||||
}()
|
||||
}
|
||||
}
|
||||
|
||||
var (
|
||||
_ testenv.Upstream = (*tcpUpstream)(nil)
|
||||
_ TCPUpstream = (*tcpUpstream)(nil)
|
||||
)
|
195
internal/testenv/upstreams/util.go
Normal file
195
internal/testenv/upstreams/util.go
Normal file
|
@ -0,0 +1,195 @@
|
|||
package upstreams
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/pomerium/pomerium/integration/forms"
|
||||
"github.com/pomerium/pomerium/internal/retry"
|
||||
"go.opentelemetry.io/otel/attribute"
|
||||
"go.opentelemetry.io/otel/codes"
|
||||
semconv "go.opentelemetry.io/otel/semconv/v1.26.0"
|
||||
oteltrace "go.opentelemetry.io/otel/trace"
|
||||
"google.golang.org/protobuf/proto"
|
||||
)
|
||||
|
||||
var ErrRetry = errors.New("error")
|
||||
|
||||
func doAuthenticatedRequest(
|
||||
ctx context.Context,
|
||||
newRequest func(context.Context) (*http.Request, error),
|
||||
getClient func(context.Context) *http.Client,
|
||||
options *RequestOptions,
|
||||
) (*http.Response, error) {
|
||||
var resp *http.Response
|
||||
resendCount := 0
|
||||
client := getClient(ctx)
|
||||
|
||||
if err := retry.Retry(ctx, "http", func(ctx context.Context) error {
|
||||
req, err := newRequest(ctx)
|
||||
if err != nil {
|
||||
return retry.NewTerminalError(err)
|
||||
}
|
||||
|
||||
switch body := options.body.(type) {
|
||||
case string:
|
||||
req.Body = io.NopCloser(strings.NewReader(body))
|
||||
case []byte:
|
||||
req.Body = io.NopCloser(bytes.NewReader(body))
|
||||
case io.Reader:
|
||||
req.Body = io.NopCloser(body)
|
||||
case proto.Message:
|
||||
buf, err := proto.Marshal(body)
|
||||
if err != nil {
|
||||
return retry.NewTerminalError(err)
|
||||
}
|
||||
req.Body = io.NopCloser(bytes.NewReader(buf))
|
||||
req.Header.Set("Content-Type", "application/octet-stream")
|
||||
default:
|
||||
buf, err := json.Marshal(body)
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("unsupported body type: %T", body))
|
||||
}
|
||||
req.Body = io.NopCloser(bytes.NewReader(buf))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
case nil:
|
||||
}
|
||||
|
||||
if options.headers != nil && req.Header == nil {
|
||||
req.Header = http.Header{}
|
||||
}
|
||||
for k, v := range options.headers {
|
||||
req.Header.Add(k, v)
|
||||
}
|
||||
|
||||
if options.authenticateAs != "" {
|
||||
resp, err = authenticateFlow(ctx, client, req, options.authenticateAs, true) //nolint:bodyclose
|
||||
} else {
|
||||
resp, err = client.Do(req) //nolint:bodyclose
|
||||
}
|
||||
// retry on connection refused
|
||||
span := oteltrace.SpanFromContext(ctx)
|
||||
if err != nil {
|
||||
span.RecordError(err)
|
||||
var opErr *net.OpError
|
||||
if errors.As(err, &opErr) && opErr.Op == "dial" && opErr.Err.Error() == "connect: connection refused" {
|
||||
span.AddEvent("Retrying on dial error")
|
||||
return err
|
||||
}
|
||||
return retry.NewTerminalError(err)
|
||||
}
|
||||
if resp.StatusCode/100 == 5 {
|
||||
resendCount++
|
||||
_, _ = io.ReadAll(resp.Body)
|
||||
_ = resp.Body.Close()
|
||||
span.SetAttributes(semconv.HTTPRequestResendCount(resendCount))
|
||||
span.AddEvent("Retrying on 5xx error", oteltrace.WithAttributes(
|
||||
attribute.String("status", resp.Status),
|
||||
))
|
||||
return errors.New(http.StatusText(resp.StatusCode))
|
||||
}
|
||||
span.SetStatus(codes.Ok, "request completed successfully")
|
||||
return nil
|
||||
},
|
||||
retry.WithInitialInterval(1*time.Millisecond),
|
||||
retry.WithMaxInterval(100*time.Millisecond),
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
func authenticateFlow(ctx context.Context, client *http.Client, req *http.Request, email string, checkLocation bool) (*http.Response, error) {
|
||||
span := oteltrace.SpanFromContext(ctx)
|
||||
var res *http.Response
|
||||
originalHostname := req.URL.Hostname()
|
||||
res, err := client.Do(req)
|
||||
if err != nil {
|
||||
span.RecordError(err)
|
||||
return nil, err
|
||||
}
|
||||
location := res.Request.URL
|
||||
if checkLocation && location.Hostname() == originalHostname {
|
||||
// already authenticated
|
||||
span.SetStatus(codes.Ok, "already authenticated")
|
||||
return res, nil
|
||||
}
|
||||
fs := forms.Parse(res.Body)
|
||||
_, _ = io.ReadAll(res.Body)
|
||||
_ = res.Body.Close()
|
||||
if len(fs) > 0 {
|
||||
f := fs[0]
|
||||
f.Inputs["email"] = email
|
||||
f.Inputs["token_expiration"] = strconv.Itoa(int((time.Hour * 24).Seconds()))
|
||||
span.AddEvent("submitting form", oteltrace.WithAttributes(attribute.String("location", location.String())))
|
||||
formReq, err := f.NewRequestWithContext(ctx, location)
|
||||
if err != nil {
|
||||
span.RecordError(err)
|
||||
return nil, err
|
||||
}
|
||||
resp, err := client.Do(formReq)
|
||||
if err != nil {
|
||||
span.RecordError(err)
|
||||
return nil, err
|
||||
}
|
||||
span.SetStatus(codes.Ok, "form submitted successfully")
|
||||
return resp, nil
|
||||
}
|
||||
return nil, fmt.Errorf("test bug: expected IDP login form")
|
||||
}
|
||||
|
||||
type rwConn struct {
|
||||
serverReader io.ReadCloser
|
||||
serverWriter io.WriteCloser
|
||||
|
||||
net.Conn
|
||||
remote net.Conn
|
||||
|
||||
closeOnce sync.Once
|
||||
wg *sync.WaitGroup
|
||||
}
|
||||
|
||||
func NewRWConn(reader io.ReadCloser, writer io.WriteCloser) net.Conn {
|
||||
rwc := &rwConn{
|
||||
serverReader: reader,
|
||||
serverWriter: writer,
|
||||
wg: &sync.WaitGroup{},
|
||||
}
|
||||
|
||||
rwc.Conn, rwc.remote = net.Pipe()
|
||||
rwc.wg.Add(2)
|
||||
go func() {
|
||||
defer rwc.wg.Done()
|
||||
_, _ = io.Copy(rwc.remote, rwc.serverReader)
|
||||
rwc.remote.Close()
|
||||
}()
|
||||
go func() {
|
||||
defer rwc.wg.Done()
|
||||
_, _ = io.Copy(rwc.serverWriter, rwc.remote)
|
||||
rwc.serverWriter.Close()
|
||||
}()
|
||||
return rwc
|
||||
}
|
||||
|
||||
func (rwc *rwConn) Close() error {
|
||||
var err error
|
||||
rwc.closeOnce.Do(func() {
|
||||
readerErr := rwc.serverReader.Close()
|
||||
localErr := rwc.Conn.Close()
|
||||
rwc.wg.Wait()
|
||||
err = errors.Join(localErr, readerErr)
|
||||
})
|
||||
return err
|
||||
}
|
||||
|
||||
var _ net.Conn = (*rwConn)(nil)
|
|
@ -48,7 +48,7 @@ func TestQueryTracing(t *testing.T) {
|
|||
snippets.WaitStartupComplete(env)
|
||||
|
||||
resp, err := up.Get(route, upstreams.AuthenticateAs("user@example.com"), upstreams.Path("/foo"))
|
||||
assert.NoError(t, err)
|
||||
require.NoError(t, err)
|
||||
io.ReadAll(resp.Body)
|
||||
resp.Body.Close()
|
||||
|
||||
|
|
Loading…
Add table
Reference in a new issue