package upstreams

import (
	"context"
	"crypto/tls"
	"errors"
	"fmt"
	"io"
	"net"
	"net/http"
	"net/http/cookiejar"
	"net/http/httptrace"
	"net/url"
	"sync"

	"github.com/pomerium/pomerium/internal/testenv"
	"github.com/pomerium/pomerium/internal/testenv/values"
	"github.com/pomerium/pomerium/pkg/telemetry/trace"
	"go.opentelemetry.io/otel/attribute"
	oteltrace "go.opentelemetry.io/otel/trace"
	"golang.org/x/net/http2"
)

// TCPUpstream is a test upstream that serves raw TCP connections and can
// dial itself through a route by tunneling over an HTTP CONNECT request.
type TCPUpstream interface {
	testenv.Upstream

	// Handle sets the function invoked for every accepted server-side connection.
	Handle(fn func(context.Context, net.Conn) error)
	// Dial opens a CONNECT tunnel to route r and runs fn on the tunneled connection.
	Dial(r testenv.Route, fn func(context.Context, net.Conn) error, opts ...RequestOption) error
}

// TCPUpstreamOptions holds the configuration of a TCP upstream.
type TCPUpstreamOptions struct {
	CommonUpstreamOptions
}

// TCPUpstreamOption configures a TCPUpstream created by TCP.
type TCPUpstreamOption interface {
	applyTCP(*TCPUpstreamOptions)
}

type tcpUpstream struct {
	TCPUpstreamOptions
	testenv.Aggregate

	serverPort           values.MutableValue[int]
	serverHandler        func(context.Context, net.Conn) error
	serverTracerProvider values.MutableValue[oteltrace.TracerProvider]
	clientTracerProvider values.MutableValue[oteltrace.TracerProvider]
	clientTracer         values.Value[oteltrace.Tracer]
}

// TCP creates a new TCP upstream configured by opts. The display name
// defaults to "TCP Upstream". The server port and the tracer providers are
// deferred values that are resolved when Run is called.
func TCP(opts ...TCPUpstreamOption) TCPUpstream {
	options := TCPUpstreamOptions{
		CommonUpstreamOptions: CommonUpstreamOptions{
			displayName: "TCP Upstream",
		},
	}
	for _, op := range opts {
		op.applyTCP(&options)
	}
	up := &tcpUpstream{
		TCPUpstreamOptions:   options,
		serverPort:           values.Deferred[int](),
		serverTracerProvider: values.Deferred[oteltrace.TracerProvider](),
		clientTracerProvider: values.Deferred[oteltrace.TracerProvider](),
	}
	// The client tracer is bound to the client tracer provider, which Run resolves.
	up.clientTracer = values.Bind(up.clientTracerProvider, func(tp oteltrace.TracerProvider) oteltrace.Tracer {
		return tp.Tracer(trace.PomeriumCoreTracer)
	})
	up.RecordCaller()
	return up
}

// Dial implements TCPUpstream.
//
// It sends an HTTP CONNECT request to the route's URL over HTTP/1.1 (the
// default) or HTTP/2, depending on the dial protocol option. It then calls
// clientHandler with a net.Conn that reads from the CONNECT response body and
// writes to the tunnel. For HTTP/1.1, writes go to the raw TLS connection.
// For HTTP/2, writes go to the request-body pipe supplied by h2Dialer. The
// tunnel is closed when clientHandler returns, and its error is returned.
func (t *tcpUpstream) Dial(r testenv.Route, clientHandler func(context.Context, net.Conn) error, opts ...RequestOption) error {
	options := RequestOptions{
		requestCtx:   t.Env().Context(),
		dialProtocol: DialHTTP1,
	}
	options.apply(opts...)

	u, err := url.Parse(r.URL().Value())
	if err != nil {
		return err
	}
	ctx, span := t.clientTracer.Value().Start(options.requestCtx, "tcpUpstream.Do", oteltrace.WithAttributes(
		attribute.String("protocol", string(options.dialProtocol)),
		attribute.String("url", u.String()),
	))
	defer span.End()

	if options.path != "" || options.query != nil {
		u = u.ResolveReference(&url.URL{
			Path:     options.path,
			RawQuery: options.query.Encode(),
		})
	}
	if options.trace != nil {
		ctx = httptrace.WithClientTrace(ctx, options.trace)
	}
	options.requestCtx = ctx

	// remoteConn is filled in by the dialer's DialTLSContext.
	var remoteConn *tls.Conn
	// remoteWriter receives the HTTP/2 request-body pipe from h2Dialer's CheckRedirect.
	remoteWriter := make(chan *io.PipeWriter, 1)
	connectURL := &url.URL{Scheme: "https", Host: u.Host, Path: u.Path}

	var getClientFn func(context.Context) *http.Client
	var newRequestFn func(ctx context.Context) (*http.Request, error)
	switch options.dialProtocol {
	case DialHTTP1:
		getClientFn = t.h1Dialer(&options, connectURL, &remoteConn)
		newRequestFn = func(ctx context.Context) (*http.Request, error) {
			req := (&http.Request{
				Method: http.MethodConnect,
				URL:    connectURL,
				Host:   u.Host,
			}).WithContext(ctx)
			return req, nil
		}
	case DialHTTP2:
		getClientFn = t.h2Dialer(&options, connectURL, &remoteConn, remoteWriter)
		newRequestFn = func(ctx context.Context) (*http.Request, error) {
			req := (&http.Request{
				Method: http.MethodConnect,
				URL:    connectURL,
				Host:   u.Host,
				Proto:  "HTTP/2",
			}).WithContext(ctx)
			return req, nil
		}
	case DialHTTP3:
		panic("not implemented")
	default:
		// Without this, getClientFn/newRequestFn would stay nil and the
		// request helper would panic on a nil function call.
		return fmt.Errorf("unsupported dial protocol: %q", options.dialProtocol)
	}

	resp, err := doAuthenticatedRequest(options.requestCtx, newRequestFn, getClientFn, &options)
	if err != nil {
		return err
	}
	if resp.StatusCode != http.StatusOK {
		resp.Body.Close()
		return errors.New(resp.Status)
	}
	if resp.Request.URL.Path == "/oidc/auth" {
		// Release the response before reporting the unexpected redirect.
		resp.Body.Close()
		if options.authenticateAs == "" {
			return errors.New("test bug: unexpected IDP redirect; missing AuthenticateAs option to Dial()")
		}
		return errors.New("internal test bug: unexpected IDP redirect")
	}

	var w io.WriteCloser = remoteConn
	if options.dialProtocol == DialHTTP2 {
		// The write side of an HTTP/2 tunnel is the request-body pipe handed
		// over by h2Dialer. Do not block forever if none was ever supplied;
		// give up when the request context ends.
		select {
		case w = <-remoteWriter:
		case <-options.requestCtx.Done():
			resp.Body.Close()
			return options.requestCtx.Err()
		}
	}
	conn := NewRWConn(resp.Body, w)
	defer conn.Close()
	return clientHandler(resp.Request.Context(), conn)
}

// h1Dialer returns a factory for HTTP/1.1 clients used to send the CONNECT
// request.
//
// Each client dials a fresh TLS connection, closing any previous one stored
// in *remoteConn. The connection must negotiate "http/1.1" via ALPN and is
// stored in *remoteConn, so Dial can write tunnel data directly to it. Dial
// failures are wrapped with ErrRetry. A redirect that points back to
// connectURL as a GET is turned back into a CONNECT, presumably at the end of
// an authentication redirect chain. All clients share one cookie jar.
func (t *tcpUpstream) h1Dialer(
	options *RequestOptions,
	connectURL *url.URL,
	remoteConn **tls.Conn,
) func(context.Context) *http.Client {
	// cookiejar.New never returns an error when given nil options.
	jar, _ := cookiejar.New(nil)
	return func(context.Context) *http.Client {
		tlsConfig := &tls.Config{
			RootCAs:      t.Env().ServerCAs(),
			Certificates: options.clientCerts,
			NextProtos:   []string{"http/1.1"},
		}
		client := &http.Client{
			Transport: &http.Transport{
				DisableKeepAlives: true,
				DialTLSContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
					if *remoteConn != nil {
						(*remoteConn).Close()
						*remoteConn = nil
					}
					dialer := &tls.Dialer{
						Config: tlsConfig,
					}
					cc, err := dialer.DialContext(ctx, network, addr)
					if err != nil {
						return nil, fmt.Errorf("%w: %w", ErrRetry, err)
					}
					protocol := cc.(*tls.Conn).ConnectionState().NegotiatedProtocol
					if protocol != "http/1.1" {
						cc.Close()
						return nil, fmt.Errorf("error: unexpected TLS protocol: %s", protocol)
					}
					*remoteConn = cc.(*tls.Conn)
					return cc, nil
				},
				TLSClientConfig: tlsConfig, // important
			},
			CheckRedirect: func(req *http.Request, _ []*http.Request) error {
				if req.URL.String() == connectURL.String() && req.Method == http.MethodGet {
					req.Method = http.MethodConnect
				}
				return nil
			},
			Jar: jar,
		}
		return client
	}
}

// h2Dialer is the HTTP/2 counterpart of h1Dialer. Its clients' TLS
// connections must negotiate "h2".
//
// Raw bytes written to the TLS connection would bypass HTTP/2 framing, so
// the tunnel's write side is a request body instead. When a redirect points
// back to connectURL as a GET, it is rewritten into a CONNECT whose body is
// an io.Pipe, and the pipe's writer is sent on writer for Dial to use.
func (t *tcpUpstream) h2Dialer(
	options *RequestOptions,
	connectURL *url.URL,
	remoteConn **tls.Conn,
	writer chan<- *io.PipeWriter,
) func(context.Context) *http.Client {
	// cookiejar.New never returns an error when given nil options.
	jar, _ := cookiejar.New(nil)
	return func(context.Context) *http.Client {
		h1 := &http.Transport{
			ForceAttemptHTTP2: true,
			DisableKeepAlives: true,
			DialTLSContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
				if *remoteConn != nil {
					(*remoteConn).Close()
					*remoteConn = nil
				}
				// NOTE: this config is deliberately separate from TLSClientConfig
				// below. http2.ConfigureTransport may add "http/1.1" to that
				// config's NextProtos (verify against x/net/http2), but this
				// handshake must offer only "h2".
				dialer := &tls.Dialer{
					Config: &tls.Config{
						RootCAs:      t.Env().ServerCAs(),
						Certificates: options.clientCerts,
						NextProtos:   []string{"h2"},
					},
				}
				cc, err := dialer.DialContext(ctx, network, addr)
				if err != nil {
					return nil, fmt.Errorf("%w: %w", ErrRetry, err)
				}
				protocol := cc.(*tls.Conn).ConnectionState().NegotiatedProtocol
				if protocol != "h2" {
					cc.Close()
					return nil, fmt.Errorf("error: unexpected TLS protocol: %s", protocol)
				}
				*remoteConn = cc.(*tls.Conn)
				return cc, nil
			},
			TLSClientConfig: &tls.Config{
				RootCAs:      t.Env().ServerCAs(),
				Certificates: options.clientCerts,
				NextProtos:   []string{"h2"},
			},
		}
		if err := http2.ConfigureTransport(h1); err != nil {
			panic(err)
		}
		client := &http.Client{
			Transport: h1,
			CheckRedirect: func(req *http.Request, _ []*http.Request) error {
				if req.URL.String() == connectURL.String() && req.Method == http.MethodGet {
					// Stream the tunnel's outbound data as the CONNECT request body.
					pr, pw := io.Pipe()
					req.Method = http.MethodConnect
					req.Body = pr
					req.ContentLength = -1
					writer <- pw
				}
				return nil
			},
			Jar: jar,
		}
		return client
	}
}

// Handle implements TCPUpstream. It sets the function Run invokes for each
// accepted connection. The assignment is not synchronized, so call Handle
// before connections start arriving.
func (t *tcpUpstream) Handle(fn func(context.Context, net.Conn) error) {
	t.serverHandler = fn
}

// Addr returns the "host:port" address of the upstream's listener. The value
// resolves once Run has bound the listener and resolved the server port.
func (t *tcpUpstream) Addr() values.Value[string] {
	return values.Bind(t.serverPort, func(port int) string {
		return fmt.Sprintf("%s:%d", t.Env().Host(), port)
	})
}

// Route creates a TCP route whose destination is "tcp://host:port" of this
// upstream's listener. It registers the route via t.Add and returns it.
func (t *tcpUpstream) Route() testenv.RouteStub {
	r := &testenv.TCPRoute{}
	r.To(values.Bind(t.serverPort, func(port int) string {
		return fmt.Sprintf("tcp://%s:%d", t.Env().Host(), port)
	}))
	t.Add(r)
	return r
}
func (t *tcpUpstream) Run(ctx context.Context) error { ctx, cancel := context.WithCancel(ctx) defer cancel() listener, err := (&net.ListenConfig{}).Listen(ctx, "tcp", fmt.Sprintf("%s:0", t.Env().Host())) if err != nil { return err } context.AfterFunc(ctx, func() { listener.Close() }) t.serverPort.Resolve(listener.Addr().(*net.TCPAddr).Port) if t.serverTracerProviderOverride != nil { t.serverTracerProvider.Resolve(t.serverTracerProviderOverride) } else { t.serverTracerProvider.Resolve(trace.NewTracerProvider(ctx, t.displayName)) } if t.clientTracerProviderOverride != nil { t.clientTracerProvider.Resolve(t.clientTracerProviderOverride) } else { t.clientTracerProvider.Resolve(trace.NewTracerProvider(ctx, "TCP Client")) } var wg sync.WaitGroup defer wg.Wait() for { conn, err := listener.Accept() if err != nil { if errors.Is(err, net.ErrClosed) { cancel() return nil } continue } wg.Add(1) go func() { defer wg.Done() if err := t.serverHandler(ctx, conn); err != nil { if errors.Is(err, io.EOF) { return } panic("server handler error: " + err.Error()) } }() } } var ( _ testenv.Upstream = (*tcpUpstream)(nil) _ TCPUpstream = (*tcpUpstream)(nil) )