mirror of
https://github.com/pomerium/pomerium.git
synced 2025-04-28 18:06:34 +02:00
474 lines
13 KiB
Go
474 lines
13 KiB
Go
package upstreams
|
|
|
|
import (
|
|
"context"
|
|
"crypto/tls"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"net"
|
|
"net/http"
|
|
"net/http/cookiejar"
|
|
"net/http/httptrace"
|
|
"net/url"
|
|
"os"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/gorilla/mux"
|
|
"github.com/gorilla/websocket"
|
|
"github.com/pomerium/pomerium/internal/testenv"
|
|
"github.com/pomerium/pomerium/internal/testenv/snippets"
|
|
"github.com/pomerium/pomerium/internal/testenv/values"
|
|
"github.com/pomerium/pomerium/pkg/telemetry/trace"
|
|
"go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp"
|
|
|
|
"go.opentelemetry.io/otel/attribute"
|
|
"go.opentelemetry.io/otel/codes"
|
|
oteltrace "go.opentelemetry.io/otel/trace"
|
|
)
|
|
|
|
// Protocol is the string name of an HTTP protocol version. A value can be
// attached to a request with the DialProtocol option.
type Protocol string
|
|
|
|
const (
|
|
DialHTTP1 Protocol = "http/1.1"
|
|
DialHTTP2 Protocol = "h2"
|
|
DialHTTP3 Protocol = "h3"
|
|
)
|
|
|
|
type RequestOptions struct {
|
|
requestCtx context.Context
|
|
path string
|
|
query url.Values
|
|
headers map[string]string
|
|
authenticateAs string
|
|
body any
|
|
clientCerts []tls.Certificate
|
|
clientHook func(*http.Client) *http.Client
|
|
dialerHook func(*websocket.Dialer, *url.URL) (*websocket.Dialer, *url.URL)
|
|
dialProtocol Protocol
|
|
trace *httptrace.ClientTrace
|
|
}
|
|
|
|
// RequestOption is a function that modifies a RequestOptions value. Options
// are passed to Get, Post, Do, and DialWS to customize a single request.
type RequestOption func(*RequestOptions)
|
|
|
|
func (ro RequestOption) Format(fmt.State, rune) {
|
|
panic("test bug: request option mistakenly passed to assert function")
|
|
}
|
|
|
|
func (o *RequestOptions) apply(opts ...RequestOption) {
|
|
for _, op := range opts {
|
|
op(o)
|
|
}
|
|
}
|
|
|
|
// Path sets the path of the request. If omitted, the request URL will match
|
|
// the route URL exactly.
|
|
func Path(path string) RequestOption {
|
|
return func(o *RequestOptions) {
|
|
o.path = path
|
|
}
|
|
}
|
|
|
|
// Query sets optional query parameters of the request.
|
|
func Query(query url.Values) RequestOption {
|
|
return func(o *RequestOptions) {
|
|
o.query = query
|
|
}
|
|
}
|
|
|
|
// Headers adds optional headers to the request.
|
|
func Headers(headers map[string]string) RequestOption {
|
|
return func(o *RequestOptions) {
|
|
o.headers = headers
|
|
}
|
|
}
|
|
|
|
func AuthenticateAs(email string) RequestOption {
|
|
return func(o *RequestOptions) {
|
|
o.authenticateAs = email
|
|
}
|
|
}
|
|
|
|
// ClientHook allows editing or replacing the http client before it is used.
|
|
// When any request is about to start, this function will be called with the
|
|
// client that would be used to make the request. The returned client will
|
|
// be the actual client used for that request. It can be the same as the input
|
|
// (with or without modification), or replaced entirely.
|
|
//
|
|
// Note: the Transport of the client passed to the hook will always be a
|
|
// [*Transport]. That transport's underlying transport will always be
|
|
// a [*otelhttp.Transport].
|
|
func ClientHook(f func(*http.Client) *http.Client) RequestOption {
|
|
return func(o *RequestOptions) {
|
|
o.clientHook = f
|
|
}
|
|
}
|
|
|
|
// DialerHook allows editing or replacing the websocket dialer before it is
|
|
// used. When a websocket request is about to start (using the DialWS method),
|
|
// this function will be called with the dialer that would be used, and the
|
|
// destination URL (including wss:// scheme, and path if one is present). The
|
|
// returned dialer+URL will be the actual dialer+URL used for that request.
|
|
//
|
|
// If ClientHook is also set, both will be called. The dialer passed to this
|
|
// hook will have its TLSClientConfig and Jar fields set from the client.
|
|
func DialerHook(f func(*websocket.Dialer, *url.URL) (*websocket.Dialer, *url.URL)) RequestOption {
|
|
return func(o *RequestOptions) {
|
|
o.dialerHook = f
|
|
}
|
|
}
|
|
|
|
func DialProtocol(protocol Protocol) RequestOption {
|
|
return func(o *RequestOptions) {
|
|
o.dialProtocol = protocol
|
|
}
|
|
}
|
|
|
|
func Context(ctx context.Context) RequestOption {
|
|
return func(o *RequestOptions) {
|
|
o.requestCtx = ctx
|
|
}
|
|
}
|
|
|
|
func WithClientTrace(ct *httptrace.ClientTrace) RequestOption {
|
|
return func(o *RequestOptions) {
|
|
o.trace = ct
|
|
}
|
|
}
|
|
|
|
// Body sets the body of the request.
|
|
// The argument can be one of the following types:
|
|
// - string
|
|
// - []byte
|
|
// - io.Reader
|
|
// - proto.Message
|
|
// - any json-encodable type
|
|
// If the argument is encoded as json, the Content-Type header will be set to
|
|
// "application/json". If the argument is a proto.Message, the Content-Type
|
|
// header will be set to "application/octet-stream".
|
|
func Body(body any) RequestOption {
|
|
return func(o *RequestOptions) {
|
|
o.body = body
|
|
}
|
|
}
|
|
|
|
// ClientCert adds a client certificate to the request.
|
|
func ClientCert[T interface {
|
|
*testenv.Certificate | *tls.Certificate
|
|
}](cert T) RequestOption {
|
|
return func(o *RequestOptions) {
|
|
o.clientCerts = append(o.clientCerts, *(*tls.Certificate)(cert))
|
|
}
|
|
}
|
|
|
|
type HTTPUpstreamOptions struct {
|
|
CommonUpstreamOptions
|
|
}
|
|
|
|
type HTTPUpstreamOption interface {
|
|
applyHTTP(*HTTPUpstreamOptions)
|
|
}
|
|
|
|
// HTTPUpstream represents a HTTP server which can be used as the target for
|
|
// one or more Pomerium routes in a test environment.
|
|
//
|
|
// The Handle() method can be used to add handlers the server-side HTTP router,
|
|
// while the Get(), Post(), and (generic) Do() methods can be used to make
|
|
// client-side requests.
|
|
type HTTPUpstream interface {
|
|
testenv.Upstream
|
|
|
|
Handle(path string, f func(http.ResponseWriter, *http.Request)) *mux.Route
|
|
HandleWS(path string, upgrader websocket.Upgrader, f func(conn *websocket.Conn) error) *mux.Route
|
|
|
|
Get(r testenv.Route, opts ...RequestOption) (*http.Response, error)
|
|
Post(r testenv.Route, opts ...RequestOption) (*http.Response, error)
|
|
Do(method string, r testenv.Route, opts ...RequestOption) (*http.Response, error)
|
|
DialWS(r testenv.Route, f func(conn *websocket.Conn) error, opts ...RequestOption) error
|
|
}
|
|
|
|
type httpUpstream struct {
|
|
HTTPUpstreamOptions
|
|
testenv.Aggregate
|
|
serverPort values.MutableValue[int]
|
|
tlsConfig values.Value[*tls.Config]
|
|
|
|
clientCache sync.Map // map[testenv.Route]*http.Client
|
|
|
|
router *mux.Router
|
|
serverTracerProvider values.MutableValue[oteltrace.TracerProvider]
|
|
clientTracerProvider values.MutableValue[oteltrace.TracerProvider]
|
|
clientTracer values.Value[oteltrace.Tracer]
|
|
}
|
|
|
|
var (
|
|
_ testenv.Upstream = (*httpUpstream)(nil)
|
|
_ HTTPUpstream = (*httpUpstream)(nil)
|
|
)
|
|
|
|
// HTTP creates a new HTTP upstream server.
|
|
func HTTP(tlsConfig values.Value[*tls.Config], opts ...HTTPUpstreamOption) HTTPUpstream {
|
|
options := HTTPUpstreamOptions{
|
|
CommonUpstreamOptions: CommonUpstreamOptions{
|
|
displayName: "HTTP Upstream",
|
|
},
|
|
}
|
|
for _, op := range opts {
|
|
op.applyHTTP(&options)
|
|
}
|
|
up := &httpUpstream{
|
|
HTTPUpstreamOptions: options,
|
|
serverPort: values.Deferred[int](),
|
|
router: mux.NewRouter(),
|
|
tlsConfig: tlsConfig,
|
|
serverTracerProvider: values.Deferred[oteltrace.TracerProvider](),
|
|
clientTracerProvider: values.Deferred[oteltrace.TracerProvider](),
|
|
}
|
|
up.clientTracer = values.Bind(up.clientTracerProvider, func(tp oteltrace.TracerProvider) oteltrace.Tracer {
|
|
return tp.Tracer(trace.PomeriumCoreTracer)
|
|
})
|
|
up.RecordCaller()
|
|
return up
|
|
}
|
|
|
|
// Port implements HTTPUpstream.
|
|
func (h *httpUpstream) Addr() values.Value[string] {
|
|
return values.Bind(h.serverPort, func(port int) string {
|
|
return fmt.Sprintf("%s:%d", h.Env().Host(), port)
|
|
})
|
|
}
|
|
|
|
// Router implements HTTPUpstream.
|
|
func (h *httpUpstream) Handle(path string, f func(http.ResponseWriter, *http.Request)) *mux.Route {
|
|
return h.router.HandleFunc(path, f)
|
|
}
|
|
|
|
// Router implements HTTPUpstream.
|
|
func (h *httpUpstream) HandleWS(path string, upgrader websocket.Upgrader, f func(*websocket.Conn) error) *mux.Route {
|
|
return h.router.HandleFunc(path, func(w http.ResponseWriter, r *http.Request) {
|
|
ctx, span := trace.Continue(r.Context(), "HandleWS")
|
|
defer span.End()
|
|
c, err := upgrader.Upgrade(w, r.WithContext(ctx), nil)
|
|
if err != nil {
|
|
span.SetStatus(codes.Error, err.Error())
|
|
w.WriteHeader(http.StatusBadRequest)
|
|
_, _ = w.Write([]byte(err.Error()))
|
|
return
|
|
}
|
|
defer c.Close()
|
|
|
|
err = f(c)
|
|
if err != nil {
|
|
if errors.Is(err, io.EOF) {
|
|
return
|
|
}
|
|
span.SetStatus(codes.Error, err.Error())
|
|
fmt.Fprintf(os.Stderr, "websocket error: %s\n", err.Error())
|
|
}
|
|
})
|
|
}
|
|
|
|
// Route implements HTTPUpstream.
|
|
func (h *httpUpstream) Route() testenv.RouteStub {
|
|
r := &testenv.PolicyRoute{}
|
|
protocol := "http"
|
|
r.To(values.Bind(h.serverPort, func(port int) string {
|
|
return fmt.Sprintf("%s://%s:%d", protocol, h.Env().Host(), port)
|
|
}))
|
|
h.Add(r)
|
|
return r
|
|
}
|
|
|
|
// Run implements HTTPUpstream.
|
|
func (h *httpUpstream) Run(ctx context.Context) error {
|
|
var listener net.Listener
|
|
if h.tlsConfig != nil {
|
|
var err error
|
|
listener, err = tls.Listen("tcp", fmt.Sprintf("%s:0", h.Env().Host()), h.tlsConfig.Value())
|
|
if err != nil {
|
|
return err
|
|
}
|
|
} else {
|
|
var err error
|
|
listener, err = net.Listen("tcp", fmt.Sprintf("%s:0", h.Env().Host()))
|
|
if err != nil {
|
|
return err
|
|
}
|
|
}
|
|
h.serverPort.Resolve(listener.Addr().(*net.TCPAddr).Port)
|
|
if h.serverTracerProviderOverride != nil {
|
|
h.serverTracerProvider.Resolve(h.serverTracerProviderOverride)
|
|
} else {
|
|
h.serverTracerProvider.Resolve(trace.NewTracerProvider(ctx, h.displayName))
|
|
}
|
|
if h.clientTracerProviderOverride != nil {
|
|
h.clientTracerProvider.Resolve(h.clientTracerProviderOverride)
|
|
} else {
|
|
h.clientTracerProvider.Resolve(trace.NewTracerProvider(ctx, "HTTP Client"))
|
|
}
|
|
h.router.Use(trace.NewHTTPMiddleware(otelhttp.WithTracerProvider(h.serverTracerProvider.Value())))
|
|
|
|
server := &http.Server{
|
|
Handler: h.router,
|
|
// BaseContext: func(net.Listener) context.Context {
|
|
// return ctx
|
|
// },
|
|
}
|
|
if h.delayShutdown {
|
|
return snippets.RunWithDelayedShutdown(ctx,
|
|
func() error {
|
|
return server.Serve(listener)
|
|
},
|
|
func() {
|
|
_ = server.Shutdown(context.Background())
|
|
},
|
|
)()
|
|
}
|
|
errC := make(chan error, 1)
|
|
go func() {
|
|
errC <- server.Serve(listener)
|
|
}()
|
|
select {
|
|
case <-ctx.Done():
|
|
_ = server.Shutdown(context.Background())
|
|
return context.Cause(ctx)
|
|
case err := <-errC:
|
|
return err
|
|
}
|
|
}
|
|
|
|
// Get implements HTTPUpstream.
|
|
func (h *httpUpstream) Get(r testenv.Route, opts ...RequestOption) (*http.Response, error) {
|
|
return h.Do(http.MethodGet, r, opts...)
|
|
}
|
|
|
|
// Post implements HTTPUpstream.
|
|
func (h *httpUpstream) Post(r testenv.Route, opts ...RequestOption) (*http.Response, error) {
|
|
return h.Do(http.MethodPost, r, opts...)
|
|
}
|
|
|
|
type Transport struct {
|
|
*otelhttp.Transport
|
|
// The underlying http.Transport instance wrapped by the otelhttp.Transport.
|
|
Base *http.Transport
|
|
}
|
|
|
|
// Compile-time check that Transport can be used as an http.RoundTripper.
var _ http.RoundTripper = Transport{}
|
|
|
|
func (h *httpUpstream) newClient(options *RequestOptions) *http.Client {
|
|
transport := http.DefaultTransport.(*http.Transport).Clone()
|
|
transport.TLSClientConfig = &tls.Config{
|
|
RootCAs: h.Env().ServerCAs(),
|
|
Certificates: options.clientCerts,
|
|
}
|
|
transport.DialTLSContext = nil
|
|
c := http.Client{
|
|
Transport: &Transport{
|
|
Transport: otelhttp.NewTransport(transport,
|
|
otelhttp.WithTracerProvider(h.clientTracerProvider.Value()),
|
|
otelhttp.WithSpanNameFormatter(func(_ string, r *http.Request) string {
|
|
return fmt.Sprintf("Client: %s %s", r.Method, r.URL.Path)
|
|
}),
|
|
),
|
|
Base: transport,
|
|
},
|
|
}
|
|
c.Jar, _ = cookiejar.New(&cookiejar.Options{})
|
|
return &c
|
|
}
|
|
|
|
func (h *httpUpstream) getRouteClient(r testenv.Route, options *RequestOptions) *http.Client {
|
|
span := oteltrace.SpanFromContext(options.requestCtx)
|
|
var cachedClient any
|
|
var ok bool
|
|
if cachedClient, ok = h.clientCache.Load(r); !ok {
|
|
span.AddEvent("creating new http client")
|
|
cachedClient, _ = h.clientCache.LoadOrStore(r, h.newClient(options))
|
|
} else {
|
|
span.AddEvent("using cached http client")
|
|
}
|
|
client := cachedClient.(*http.Client)
|
|
if options.clientHook != nil {
|
|
client = options.clientHook(client)
|
|
}
|
|
return client
|
|
}
|
|
|
|
// Do implements HTTPUpstream.
|
|
func (h *httpUpstream) Do(method string, r testenv.Route, opts ...RequestOption) (*http.Response, error) {
|
|
options := RequestOptions{
|
|
requestCtx: h.Env().Context(),
|
|
}
|
|
options.apply(opts...)
|
|
u, err := url.Parse(r.URL().Value())
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if options.path != "" || options.query != nil {
|
|
u = u.ResolveReference(&url.URL{
|
|
Path: options.path,
|
|
RawQuery: options.query.Encode(),
|
|
})
|
|
}
|
|
ctx, span := h.clientTracer.Value().Start(options.requestCtx, "httpUpstream.Do", oteltrace.WithAttributes(
|
|
attribute.String("method", method),
|
|
attribute.String("url", u.String()),
|
|
))
|
|
if options.trace != nil {
|
|
ctx = httptrace.WithClientTrace(ctx, options.trace)
|
|
}
|
|
options.requestCtx = ctx
|
|
defer span.End()
|
|
|
|
return doAuthenticatedRequest(options.requestCtx,
|
|
func(ctx context.Context) (*http.Request, error) {
|
|
return http.NewRequestWithContext(ctx, method, u.String(), nil)
|
|
},
|
|
func(context.Context) *http.Client {
|
|
return h.getRouteClient(r, &options)
|
|
},
|
|
&options,
|
|
)
|
|
}
|
|
|
|
func (h *httpUpstream) DialWS(r testenv.Route, f func(conn *websocket.Conn) error, opts ...RequestOption) error {
|
|
options := RequestOptions{
|
|
requestCtx: h.Env().Context(),
|
|
}
|
|
options.apply(opts...)
|
|
u, err := url.Parse(r.URL().Value())
|
|
if err != nil {
|
|
return err
|
|
}
|
|
u.Scheme = "wss"
|
|
if options.path != "" || options.query != nil {
|
|
u = u.ResolveReference(&url.URL{
|
|
Path: options.path,
|
|
RawQuery: options.query.Encode(),
|
|
})
|
|
}
|
|
ctx, span := h.clientTracer.Value().Start(options.requestCtx, "httpUpstream.Dial", oteltrace.WithAttributes(
|
|
attribute.String("url", u.String()),
|
|
))
|
|
options.requestCtx = ctx
|
|
defer span.End()
|
|
|
|
client := h.getRouteClient(r, &options)
|
|
d := &websocket.Dialer{
|
|
HandshakeTimeout: 10 * time.Second,
|
|
TLSClientConfig: client.Transport.(*Transport).Base.TLSClientConfig,
|
|
Jar: client.Jar,
|
|
}
|
|
if options.dialerHook != nil {
|
|
d, u = options.dialerHook(d, u)
|
|
}
|
|
conn, resp, err := d.DialContext(options.requestCtx, u.String(), nil)
|
|
if err != nil {
|
|
resp.Body.Close()
|
|
return fmt.Errorf("DialContext: %w", err)
|
|
}
|
|
defer conn.Close()
|
|
|
|
return f(conn)
|
|
}
|