pomerium/internal/testenv/upstreams/tcp.go

343 lines
9 KiB
Go

package upstreams
import (
"context"
"crypto/tls"
"errors"
"fmt"
"io"
"net"
"net/http"
"net/http/cookiejar"
"net/http/httptrace"
"net/url"
"sync"
"github.com/pomerium/pomerium/internal/testenv"
"github.com/pomerium/pomerium/internal/testenv/values"
"github.com/pomerium/pomerium/pkg/telemetry/trace"
"go.opentelemetry.io/otel/attribute"
oteltrace "go.opentelemetry.io/otel/trace"
"golang.org/x/net/http2"
)
type (
	// TCPUpstream is a test upstream that accepts raw TCP connections and can
	// also act as a client, tunneling to a route via an HTTP CONNECT request.
	TCPUpstream interface {
		testenv.Upstream

		// Handle installs the server-side function used for accepted connections.
		Handle(fn func(context.Context, net.Conn) error)
		// Dial opens a tunneled connection through route r and passes it to fn.
		Dial(r testenv.Route, fn func(context.Context, net.Conn) error, opts ...RequestOption) error
	}

	// TCPUpstreamOptions holds the configuration for a TCP upstream.
	TCPUpstreamOptions struct {
		CommonUpstreamOptions
	}

	// TCPUpstreamOption mutates a TCPUpstreamOptions value; see TCP.
	TCPUpstreamOption interface {
		applyTCP(*TCPUpstreamOptions)
	}

	// tcpUpstream is the TCPUpstream implementation returned by TCP.
	tcpUpstream struct {
		TCPUpstreamOptions
		testenv.Aggregate

		serverPort           values.MutableValue[int]
		serverHandler        func(context.Context, net.Conn) error
		serverTracerProvider values.MutableValue[oteltrace.TracerProvider]
		clientTracerProvider values.MutableValue[oteltrace.TracerProvider]
		clientTracer         values.Value[oteltrace.Tracer]
	}
)
// TCP constructs a TCP upstream. The display name defaults to "TCP Upstream"
// and opts are applied on top of that, in order. The listening port and the
// server/client tracer providers start out deferred and are resolved by Run;
// the client tracer is bound to the client tracer provider.
func TCP(opts ...TCPUpstreamOption) TCPUpstream {
	var cfg TCPUpstreamOptions
	cfg.displayName = "TCP Upstream"
	for _, opt := range opts {
		opt.applyTCP(&cfg)
	}

	upstream := new(tcpUpstream)
	upstream.TCPUpstreamOptions = cfg
	upstream.serverPort = values.Deferred[int]()
	upstream.serverTracerProvider = values.Deferred[oteltrace.TracerProvider]()
	upstream.clientTracerProvider = values.Deferred[oteltrace.TracerProvider]()
	upstream.clientTracer = values.Bind(upstream.clientTracerProvider, func(provider oteltrace.TracerProvider) oteltrace.Tracer {
		return provider.Tracer(trace.PomeriumCoreTracer)
	})
	upstream.RecordCaller()
	return upstream
}
// Dial implements TCPUpstream.
//
// It opens a tunnel to route r by sending an HTTP CONNECT request over
// HTTP/1.1 (the default) or HTTP/2, as selected by opts, through
// doAuthenticatedRequest. On success the tunneled stream is wrapped in a
// net.Conn and passed to clientHandler. The connection is closed when
// clientHandler returns, and clientHandler's error is returned unchanged.
//
// Error cases:
//   - a non-200 response is returned as an error carrying resp.Status;
//   - ending up on the IDP's /oidc/auth page is reported as a test bug;
//   - an unknown dial protocol returns an error (HTTP/3 panics as not
//     implemented).
func (t *tcpUpstream) Dial(r testenv.Route, clientHandler func(context.Context, net.Conn) error, opts ...RequestOption) error {
	options := RequestOptions{
		requestCtx:   t.Env().Context(),
		dialProtocol: DialHTTP1,
	}
	options.apply(opts...)
	u, err := url.Parse(r.URL().Value())
	if err != nil {
		return err
	}
	ctx, span := t.clientTracer.Value().Start(options.requestCtx, "tcpUpstream.Do", oteltrace.WithAttributes(
		attribute.String("protocol", string(options.dialProtocol)),
		attribute.String("url", u.String()),
	))
	defer span.End()
	if options.path != "" || options.query != nil {
		u = u.ResolveReference(&url.URL{
			Path:     options.path,
			RawQuery: options.query.Encode(),
		})
	}
	if options.trace != nil {
		ctx = httptrace.WithClientTrace(ctx, options.trace)
	}
	options.requestCtx = ctx

	// The dialers publish where tunnel writes must go.
	// HTTP/1.1 uses the raw TLS connection (remoteConn).
	// HTTP/2 uses the request-body pipe of the final CONNECT (remoteWriter).
	var remoteConn *tls.Conn
	remoteWriter := make(chan *io.PipeWriter, 1)
	connectURL := &url.URL{Scheme: "https", Host: u.Host, Path: u.Path}

	var getClientFn func(context.Context) *http.Client
	var proto string
	switch options.dialProtocol {
	case DialHTTP1:
		getClientFn = t.h1Dialer(&options, connectURL, &remoteConn)
	case DialHTTP2:
		getClientFn = t.h2Dialer(&options, connectURL, &remoteConn, remoteWriter)
		proto = "HTTP/2"
	case DialHTTP3:
		panic("not implemented")
	default:
		// Without this, a nil getClientFn would be handed to
		// doAuthenticatedRequest.
		return fmt.Errorf("unsupported dial protocol: %q", options.dialProtocol)
	}
	newRequestFn := func(ctx context.Context) (*http.Request, error) {
		req := (&http.Request{
			Method: http.MethodConnect,
			URL:    connectURL,
			Host:   u.Host,
			Proto:  proto,
		}).WithContext(ctx)
		return req, nil
	}

	// discardRemoteWriter closes an HTTP/2 body pipe that will not be used,
	// so the CONNECT request body is not left open on error paths.
	discardRemoteWriter := func() {
		select {
		case pw := <-remoteWriter:
			pw.Close()
		default:
		}
	}

	resp, err := doAuthenticatedRequest(options.requestCtx, newRequestFn, getClientFn, &options)
	if err != nil {
		discardRemoteWriter()
		return err
	}
	// abort releases the response body and any unused body pipe before
	// returning err. Every early return below goes through it.
	abort := func(err error) error {
		resp.Body.Close()
		discardRemoteWriter()
		return err
	}
	if resp.StatusCode != http.StatusOK {
		return abort(errors.New(resp.Status))
	}
	if resp.Request.URL.Path == "/oidc/auth" {
		if options.authenticateAs == "" {
			return abort(errors.New("test bug: unexpected IDP redirect; missing AuthenticateAs option to Dial()"))
		}
		return abort(errors.New("internal test bug: unexpected IDP redirect"))
	}

	var w io.WriteCloser
	switch {
	case options.dialProtocol == DialHTTP2:
		// The response exists, so every redirect (and its CheckRedirect
		// callback) has already happened. The body pipe is therefore either
		// buffered in the channel now or was never created. Waiting would
		// block forever in the latter case.
		select {
		case w = <-remoteWriter:
		default:
			return abort(errors.New("internal test bug: HTTP/2 CONNECT has no writable request body (request was not redirected)"))
		}
	case remoteConn != nil:
		w = remoteConn
	default:
		// Guard against wrapping a typed-nil *tls.Conn in an io.WriteCloser.
		return abort(errors.New("internal test bug: no TLS connection to write to"))
	}
	conn := NewRWConn(resp.Body, w)
	defer conn.Close()
	return clientHandler(resp.Request.Context(), conn)
}
// h1Dialer returns a client factory for HTTP/1.1 CONNECT tunnels.
//
// Each client disables keep-alives and restricts ALPN to "http/1.1",
// rejecting any other negotiated protocol. It stores the live TLS connection
// in *remoteConn so the caller can write to it directly. A redirect back to
// connectURL that arrives as a GET is turned back into a CONNECT. All clients
// from one factory share a cookie jar.
func (t *tcpUpstream) h1Dialer(
	options *RequestOptions,
	connectURL *url.URL,
	remoteConn **tls.Conn,
) func(context.Context) *http.Client {
	// cookiejar.New does not return a non-nil error in the current stdlib.
	jar, _ := cookiejar.New(nil)
	return func(context.Context) *http.Client {
		tlsConfig := &tls.Config{
			RootCAs:      t.Env().ServerCAs(),
			Certificates: options.clientCerts,
			NextProtos:   []string{"http/1.1"},
		}
		return &http.Client{
			Transport: &http.Transport{
				DisableKeepAlives: true,
				DialTLSContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
					return t.dialTLSExpectingProtocol(ctx, network, addr, tlsConfig, "http/1.1", remoteConn)
				},
				TLSClientConfig: tlsConfig, // important
			},
			CheckRedirect: func(req *http.Request, _ []*http.Request) error {
				if req.URL.String() == connectURL.String() && req.Method == http.MethodGet {
					req.Method = http.MethodConnect
				}
				return nil
			},
			Jar: jar,
		}
	}
}

// h2Dialer returns a client factory for HTTP/2 CONNECT tunnels.
//
// Each client restricts ALPN to "h2", rejects any other negotiated protocol,
// and stores the live TLS connection in *remoteConn. When a redirect leads
// back to connectURL as a GET, the request is turned back into a streaming
// CONNECT whose body is an io.Pipe. The pipe's writer is sent on writer so the
// caller can write into the tunnel. writer must be buffered with room for one
// value (Dial uses a buffer of 1). A newer pipe replaces (and closes) one that
// was never consumed.
func (t *tcpUpstream) h2Dialer(
	options *RequestOptions,
	connectURL *url.URL,
	remoteConn **tls.Conn,
	writer chan *io.PipeWriter,
) func(context.Context) *http.Client {
	// cookiejar.New does not return a non-nil error in the current stdlib.
	jar, _ := cookiejar.New(nil)
	// newTLSConfig builds an h2-only config. The dialer and TLSClientConfig
	// get separate instances because http2.ConfigureTransport may rewrite
	// TLSClientConfig.NextProtos (e.g. adding "http/1.1"). That change must not
	// leak into the config used for the actual dial.
	newTLSConfig := func() *tls.Config {
		return &tls.Config{
			RootCAs:      t.Env().ServerCAs(),
			Certificates: options.clientCerts,
			NextProtos:   []string{"h2"},
		}
	}
	return func(context.Context) *http.Client {
		transport := &http.Transport{
			ForceAttemptHTTP2: true,
			DisableKeepAlives: true,
			DialTLSContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
				return t.dialTLSExpectingProtocol(ctx, network, addr, newTLSConfig(), "h2", remoteConn)
			},
			TLSClientConfig: newTLSConfig(),
		}
		if err := http2.ConfigureTransport(transport); err != nil {
			panic(err)
		}
		return &http.Client{
			Transport: transport,
			CheckRedirect: func(req *http.Request, _ []*http.Request) error {
				if req.URL.String() != connectURL.String() || req.Method != http.MethodGet {
					return nil
				}
				pr, pw := io.Pipe()
				req.Method = http.MethodConnect
				req.Body = pr
				req.ContentLength = -1
				// Drop a pipe left over from an earlier attempt instead of
				// blocking forever on the single-slot channel.
				select {
				case stale := <-writer:
					stale.Close()
				default:
				}
				writer <- pw
				return nil
			},
			Jar: jar,
		}
	}
}

// dialTLSExpectingProtocol dials addr over TLS with config. It fails unless
// ALPN negotiated wantProto. Any connection previously stored in *remoteConn
// is closed first, and the new connection is stored there on success. Dial
// failures are wrapped with ErrRetry; a protocol mismatch is not.
func (t *tcpUpstream) dialTLSExpectingProtocol(
	ctx context.Context,
	network, addr string,
	config *tls.Config,
	wantProto string,
	remoteConn **tls.Conn,
) (net.Conn, error) {
	if *remoteConn != nil {
		(*remoteConn).Close()
		*remoteConn = nil
	}
	cc, err := (&tls.Dialer{Config: config}).DialContext(ctx, network, addr)
	if err != nil {
		return nil, fmt.Errorf("%w: %w", ErrRetry, err)
	}
	// tls.Dialer.DialContext returns a *tls.Conn on success.
	tlsConn := cc.(*tls.Conn)
	if protocol := tlsConn.ConnectionState().NegotiatedProtocol; protocol != wantProto {
		tlsConn.Close()
		return nil, fmt.Errorf("error: unexpected TLS protocol: %s", protocol)
	}
	*remoteConn = tlsConn
	return tlsConn, nil
}
// Handle implements TCPUpstream. Run invokes fn in its own goroutine for each
// accepted connection.
func (t *tcpUpstream) Handle(fn func(context.Context, net.Conn) error) {
	t.serverHandler = fn
}

// Addr implements TCPUpstream. The value resolves to "host:port" for the
// listener once Run has bound it. net.JoinHostPort brackets IPv6 hosts.
func (t *tcpUpstream) Addr() values.Value[string] {
	return values.Bind(t.serverPort, func(port int) string {
		return net.JoinHostPort(t.Env().Host(), fmt.Sprint(port))
	})
}

// Route implements TCPUpstream. It creates a TCP route whose destination is
// tcp://host:port of this upstream's listener, adds it to the upstream's
// aggregate, and returns it for further configuration.
func (t *tcpUpstream) Route() testenv.RouteStub {
	r := &testenv.TCPRoute{}
	r.To(values.Bind(t.serverPort, func(port int) string {
		return "tcp://" + net.JoinHostPort(t.Env().Host(), fmt.Sprint(port))
	}))
	t.Add(r)
	return r
}
// Run implements TCPUpstream.
//
// It listens on an ephemeral port on the environment host and resolves the
// server port and tracer providers (using the overrides when set). It then
// serves each accepted connection with the function installed by Handle, in
// its own goroutine. Run returns nil once the listener is closed (which
// happens when ctx is cancelled), after all in-flight handlers have returned.
//
// Each connection is closed when its handler returns. It is also closed when
// ctx is cancelled, so handlers blocked on I/O can exit. Handler errors panic
// as test bugs, except io.EOF and any error returned after ctx was cancelled
// (expected during shutdown).
func (t *tcpUpstream) Run(ctx context.Context) error {
	ctx, cancel := context.WithCancel(ctx)
	defer cancel()
	listener, err := (&net.ListenConfig{}).Listen(ctx, "tcp", net.JoinHostPort(t.Env().Host(), "0"))
	if err != nil {
		return err
	}
	context.AfterFunc(ctx, func() {
		listener.Close()
	})
	t.serverPort.Resolve(listener.Addr().(*net.TCPAddr).Port)
	if t.serverTracerProviderOverride != nil {
		t.serverTracerProvider.Resolve(t.serverTracerProviderOverride)
	} else {
		t.serverTracerProvider.Resolve(trace.NewTracerProvider(ctx, t.displayName))
	}
	if t.clientTracerProviderOverride != nil {
		t.clientTracerProvider.Resolve(t.clientTracerProviderOverride)
	} else {
		t.clientTracerProvider.Resolve(trace.NewTracerProvider(ctx, "TCP Client"))
	}

	var wg sync.WaitGroup
	defer wg.Wait()
	for {
		conn, err := listener.Accept()
		if err != nil {
			if errors.Is(err, net.ErrClosed) {
				// Deferred calls run LIFO, so wg.Wait runs before the deferred
				// cancel. Cancel now so handlers observe shutdown.
				cancel()
				return nil
			}
			// Other Accept errors are treated as transient.
			continue
		}
		wg.Add(1)
		go func() {
			defer wg.Done()
			defer conn.Close()
			// Unblock a handler stuck on I/O when the upstream shuts down.
			stop := context.AfterFunc(ctx, func() { conn.Close() })
			defer stop()
			if err := t.serverHandler(ctx, conn); err != nil {
				if errors.Is(err, io.EOF) || ctx.Err() != nil {
					return
				}
				panic("server handler error: " + err.Error())
			}
		}()
	}
}
// Compile-time checks that *tcpUpstream satisfies testenv.Upstream and
// TCPUpstream.
var _ testenv.Upstream = (*tcpUpstream)(nil)

var _ TCPUpstream = (*tcpUpstream)(nil)