// Mirror of https://github.com/pomerium/pomerium.git
// (synced 2025-04-28 18:06:34 +02:00; 343 lines, 9 KiB, Go).
package upstreams
|
|
|
|
import (
|
|
"context"
|
|
"crypto/tls"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"net"
|
|
"net/http"
|
|
"net/http/cookiejar"
|
|
"net/http/httptrace"
|
|
"net/url"
|
|
"sync"
|
|
|
|
"github.com/pomerium/pomerium/internal/testenv"
|
|
"github.com/pomerium/pomerium/internal/testenv/values"
|
|
"github.com/pomerium/pomerium/pkg/telemetry/trace"
|
|
"go.opentelemetry.io/otel/attribute"
|
|
oteltrace "go.opentelemetry.io/otel/trace"
|
|
"golang.org/x/net/http2"
|
|
)
|
|
|
|
type TCPUpstream interface {
|
|
testenv.Upstream
|
|
|
|
Handle(fn func(context.Context, net.Conn) error)
|
|
|
|
Dial(r testenv.Route, fn func(context.Context, net.Conn) error, opts ...RequestOption) error
|
|
}
|
|
|
|
type TCPUpstreamOptions struct {
|
|
CommonUpstreamOptions
|
|
}
|
|
|
|
type TCPUpstreamOption interface {
|
|
applyTCP(*TCPUpstreamOptions)
|
|
}
|
|
|
|
type tcpUpstream struct {
|
|
TCPUpstreamOptions
|
|
testenv.Aggregate
|
|
serverPort values.MutableValue[int]
|
|
serverHandler func(context.Context, net.Conn) error
|
|
|
|
serverTracerProvider values.MutableValue[oteltrace.TracerProvider]
|
|
clientTracerProvider values.MutableValue[oteltrace.TracerProvider]
|
|
clientTracer values.Value[oteltrace.Tracer]
|
|
}
|
|
|
|
func TCP(opts ...TCPUpstreamOption) TCPUpstream {
|
|
options := TCPUpstreamOptions{
|
|
CommonUpstreamOptions: CommonUpstreamOptions{
|
|
displayName: "TCP Upstream",
|
|
},
|
|
}
|
|
for _, op := range opts {
|
|
op.applyTCP(&options)
|
|
}
|
|
up := &tcpUpstream{
|
|
TCPUpstreamOptions: options,
|
|
serverPort: values.Deferred[int](),
|
|
|
|
serverTracerProvider: values.Deferred[oteltrace.TracerProvider](),
|
|
clientTracerProvider: values.Deferred[oteltrace.TracerProvider](),
|
|
}
|
|
up.clientTracer = values.Bind(up.clientTracerProvider, func(tp oteltrace.TracerProvider) oteltrace.Tracer {
|
|
return tp.Tracer(trace.PomeriumCoreTracer)
|
|
})
|
|
up.RecordCaller()
|
|
return up
|
|
}
|
|
|
|
// Dial implements TCPUpstream.
|
|
func (t *tcpUpstream) Dial(r testenv.Route, clientHandler func(context.Context, net.Conn) error, opts ...RequestOption) error {
|
|
options := RequestOptions{
|
|
requestCtx: t.Env().Context(),
|
|
dialProtocol: DialHTTP1,
|
|
}
|
|
options.apply(opts...)
|
|
u, err := url.Parse(r.URL().Value())
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
ctx, span := t.clientTracer.Value().Start(options.requestCtx, "tcpUpstream.Do", oteltrace.WithAttributes(
|
|
attribute.String("protocol", string(options.dialProtocol)),
|
|
attribute.String("url", u.String()),
|
|
))
|
|
if options.path != "" || options.query != nil {
|
|
u = u.ResolveReference(&url.URL{
|
|
Path: options.path,
|
|
RawQuery: options.query.Encode(),
|
|
})
|
|
}
|
|
if options.trace != nil {
|
|
ctx = httptrace.WithClientTrace(ctx, options.trace)
|
|
}
|
|
options.requestCtx = ctx
|
|
defer span.End()
|
|
|
|
var remoteConn *tls.Conn
|
|
remoteWriter := make(chan *io.PipeWriter, 1)
|
|
|
|
connectURL := &url.URL{Scheme: "https", Host: u.Host, Path: u.Path}
|
|
|
|
var getClientFn func(context.Context) *http.Client
|
|
var newRequestFn func(ctx context.Context) (*http.Request, error)
|
|
switch options.dialProtocol {
|
|
case DialHTTP1:
|
|
getClientFn = t.h1Dialer(&options, connectURL, &remoteConn)
|
|
newRequestFn = func(ctx context.Context) (*http.Request, error) {
|
|
req := (&http.Request{
|
|
Method: http.MethodConnect,
|
|
URL: connectURL,
|
|
Host: u.Host,
|
|
}).WithContext(ctx)
|
|
return req, nil
|
|
}
|
|
case DialHTTP2:
|
|
getClientFn = t.h2Dialer(&options, connectURL, &remoteConn, remoteWriter)
|
|
newRequestFn = func(ctx context.Context) (*http.Request, error) {
|
|
req := (&http.Request{
|
|
Method: http.MethodConnect,
|
|
URL: connectURL,
|
|
Host: u.Host,
|
|
Proto: "HTTP/2",
|
|
}).WithContext(ctx)
|
|
return req, nil
|
|
}
|
|
case DialHTTP3:
|
|
panic("not implemented")
|
|
}
|
|
resp, err := doAuthenticatedRequest(options.requestCtx, newRequestFn, getClientFn, &options)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if resp.StatusCode != http.StatusOK {
|
|
resp.Body.Close()
|
|
return errors.New(resp.Status)
|
|
}
|
|
if resp.Request.URL.Path == "/oidc/auth" {
|
|
if options.authenticateAs == "" {
|
|
return errors.New("test bug: unexpected IDP redirect; missing AuthenticateAs option to Dial()")
|
|
}
|
|
return errors.New("internal test bug: unexpected IDP redirect")
|
|
}
|
|
|
|
var w io.WriteCloser = remoteConn
|
|
if options.dialProtocol == DialHTTP2 {
|
|
w = <-remoteWriter
|
|
}
|
|
|
|
conn := NewRWConn(resp.Body, w)
|
|
defer conn.Close()
|
|
return clientHandler(resp.Request.Context(), conn)
|
|
}
|
|
|
|
func (t *tcpUpstream) h1Dialer(
|
|
options *RequestOptions,
|
|
connectURL *url.URL,
|
|
remoteConn **tls.Conn,
|
|
) func(context.Context) *http.Client {
|
|
jar, _ := cookiejar.New(nil)
|
|
return func(context.Context) *http.Client {
|
|
tlsConfig := &tls.Config{
|
|
RootCAs: t.Env().ServerCAs(),
|
|
Certificates: options.clientCerts,
|
|
NextProtos: []string{"http/1.1"},
|
|
}
|
|
client := &http.Client{
|
|
Transport: &http.Transport{
|
|
DisableKeepAlives: true,
|
|
DialTLSContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
|
|
if *remoteConn != nil {
|
|
(*remoteConn).Close()
|
|
*remoteConn = nil
|
|
}
|
|
dialer := &tls.Dialer{
|
|
Config: tlsConfig,
|
|
}
|
|
cc, err := dialer.DialContext(ctx, network, addr)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("%w: %w", ErrRetry, err)
|
|
}
|
|
protocol := cc.(*tls.Conn).ConnectionState().NegotiatedProtocol
|
|
if protocol != "http/1.1" {
|
|
cc.Close()
|
|
return nil, fmt.Errorf("error: unexpected TLS protocol: %s", protocol)
|
|
}
|
|
*remoteConn = cc.(*tls.Conn)
|
|
return cc, nil
|
|
},
|
|
TLSClientConfig: tlsConfig, // important
|
|
},
|
|
CheckRedirect: func(req *http.Request, _ []*http.Request) error {
|
|
if req.URL.String() == connectURL.String() && req.Method == http.MethodGet {
|
|
req.Method = http.MethodConnect
|
|
}
|
|
return nil
|
|
},
|
|
Jar: jar,
|
|
}
|
|
return client
|
|
}
|
|
}
|
|
|
|
func (t *tcpUpstream) h2Dialer(
|
|
options *RequestOptions,
|
|
connectURL *url.URL,
|
|
remoteConn **tls.Conn,
|
|
writer chan<- *io.PipeWriter,
|
|
) func(context.Context) *http.Client {
|
|
jar, _ := cookiejar.New(nil)
|
|
return func(context.Context) *http.Client {
|
|
h1 := &http.Transport{
|
|
ForceAttemptHTTP2: true,
|
|
DisableKeepAlives: true,
|
|
DialTLSContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
|
|
if *remoteConn != nil {
|
|
(*remoteConn).Close()
|
|
*remoteConn = nil
|
|
}
|
|
dialer := &tls.Dialer{
|
|
Config: &tls.Config{
|
|
RootCAs: t.Env().ServerCAs(),
|
|
Certificates: options.clientCerts,
|
|
NextProtos: []string{"h2"},
|
|
},
|
|
}
|
|
cc, err := dialer.DialContext(ctx, network, addr)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("%w: %w", ErrRetry, err)
|
|
}
|
|
protocol := cc.(*tls.Conn).ConnectionState().NegotiatedProtocol
|
|
if protocol != "h2" {
|
|
cc.Close()
|
|
return nil, fmt.Errorf("error: unexpected TLS protocol: %s", protocol)
|
|
}
|
|
*remoteConn = cc.(*tls.Conn)
|
|
|
|
return cc, nil
|
|
},
|
|
TLSClientConfig: &tls.Config{
|
|
RootCAs: t.Env().ServerCAs(),
|
|
Certificates: options.clientCerts,
|
|
NextProtos: []string{"h2"},
|
|
},
|
|
}
|
|
if err := http2.ConfigureTransport(h1); err != nil {
|
|
panic(err)
|
|
}
|
|
client := &http.Client{
|
|
Transport: h1,
|
|
CheckRedirect: func(req *http.Request, _ []*http.Request) error {
|
|
if req.URL.String() == connectURL.String() && req.Method == http.MethodGet {
|
|
pr, pw := io.Pipe()
|
|
req.Method = http.MethodConnect
|
|
req.Body = pr
|
|
req.ContentLength = -1
|
|
writer <- pw
|
|
}
|
|
return nil
|
|
},
|
|
Jar: jar,
|
|
}
|
|
return client
|
|
}
|
|
}
|
|
|
|
// Handle implements TCPUpstream.
|
|
func (t *tcpUpstream) Handle(fn func(context.Context, net.Conn) error) {
|
|
t.serverHandler = fn
|
|
}
|
|
|
|
// Port implements TCPUpstream.
|
|
func (t *tcpUpstream) Addr() values.Value[string] {
|
|
return values.Bind(t.serverPort, func(port int) string {
|
|
return fmt.Sprintf("%s:%d", t.Env().Host(), port)
|
|
})
|
|
}
|
|
|
|
// Route implements TCPUpstream.
|
|
func (t *tcpUpstream) Route() testenv.RouteStub {
|
|
r := &testenv.TCPRoute{}
|
|
r.To(values.Bind(t.serverPort, func(port int) string {
|
|
return fmt.Sprintf("tcp://%s:%d", t.Env().Host(), port)
|
|
}))
|
|
t.Add(r)
|
|
return r
|
|
}
|
|
|
|
// Run implements TCPUpstream.
|
|
func (t *tcpUpstream) Run(ctx context.Context) error {
|
|
ctx, cancel := context.WithCancel(ctx)
|
|
defer cancel()
|
|
listener, err := (&net.ListenConfig{}).Listen(ctx, "tcp", fmt.Sprintf("%s:0", t.Env().Host()))
|
|
if err != nil {
|
|
return err
|
|
}
|
|
context.AfterFunc(ctx, func() {
|
|
listener.Close()
|
|
})
|
|
t.serverPort.Resolve(listener.Addr().(*net.TCPAddr).Port)
|
|
if t.serverTracerProviderOverride != nil {
|
|
t.serverTracerProvider.Resolve(t.serverTracerProviderOverride)
|
|
} else {
|
|
t.serverTracerProvider.Resolve(trace.NewTracerProvider(ctx, t.displayName))
|
|
}
|
|
if t.clientTracerProviderOverride != nil {
|
|
t.clientTracerProvider.Resolve(t.clientTracerProviderOverride)
|
|
} else {
|
|
t.clientTracerProvider.Resolve(trace.NewTracerProvider(ctx, "TCP Client"))
|
|
}
|
|
var wg sync.WaitGroup
|
|
defer wg.Wait()
|
|
|
|
for {
|
|
conn, err := listener.Accept()
|
|
if err != nil {
|
|
if errors.Is(err, net.ErrClosed) {
|
|
cancel()
|
|
return nil
|
|
}
|
|
continue
|
|
}
|
|
wg.Add(1)
|
|
go func() {
|
|
defer wg.Done()
|
|
if err := t.serverHandler(ctx, conn); err != nil {
|
|
if errors.Is(err, io.EOF) {
|
|
return
|
|
}
|
|
panic("server handler error: " + err.Error())
|
|
}
|
|
}()
|
|
}
|
|
}
|
|
|
|
var (
|
|
_ testenv.Upstream = (*tcpUpstream)(nil)
|
|
_ TCPUpstream = (*tcpUpstream)(nil)
|
|
)
|