mirror of
https://github.com/pomerium/pomerium.git
synced 2025-05-02 20:06:03 +02:00
* Initial test environment implementation * linter pass * wip: update request latency test * bugfixes * Fix logic race in envoy process monitor when canceling context * skip tests using test environment on non-linux
327 lines
8.2 KiB
Go
327 lines
8.2 KiB
Go
package upstreams
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"crypto/tls"
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"net"
|
|
"net/http"
|
|
"net/http/cookiejar"
|
|
"net/url"
|
|
"strconv"
|
|
"strings"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/gorilla/mux"
|
|
"github.com/pomerium/pomerium/integration/forms"
|
|
"github.com/pomerium/pomerium/internal/retry"
|
|
"github.com/pomerium/pomerium/internal/testenv"
|
|
"github.com/pomerium/pomerium/internal/testenv/values"
|
|
"google.golang.org/protobuf/proto"
|
|
)
|
|
|
|
// RequestOptions collects the per-request settings assembled from the
// RequestOption values passed to Get, Post, and Do.
type RequestOptions struct {
	// path, when non-empty, is resolved against the route URL.
	path string
	// query holds the query parameters encoded into the request URL.
	query url.Values
	// headers holds extra headers supplied via the Headers option.
	headers map[string]string
	// authenticateAs, when non-empty, is the email submitted to the login
	// form before the request is completed.
	authenticateAs string
	// body is the request payload; see Body for the accepted types.
	body any
	// clientCerts are the TLS client certificates offered by the
	// automatically created HTTP client.
	clientCerts []tls.Certificate
	// client, when non-nil, is used instead of the automatically created
	// (and cached) HTTP client.
	client *http.Client
}

// RequestOption mutates a RequestOptions value. Options are applied in the
// order they are given, so later options can override earlier ones.
type RequestOption func(*RequestOptions)

// apply runs each option against o in order. Nil options are skipped so that
// callers can build option lists conditionally without risking a nil function
// call panic.
func (o *RequestOptions) apply(opts ...RequestOption) {
	for _, op := range opts {
		if op == nil {
			continue
		}
		op(o)
	}
}
|
|
|
|
// Path sets the path of the request. If omitted, the request URL will match
|
|
// the route URL exactly.
|
|
func Path(path string) RequestOption {
|
|
return func(o *RequestOptions) {
|
|
o.path = path
|
|
}
|
|
}
|
|
|
|
// Query sets optional query parameters of the request.
|
|
func Query(query url.Values) RequestOption {
|
|
return func(o *RequestOptions) {
|
|
o.query = query
|
|
}
|
|
}
|
|
|
|
// Headers adds optional headers to the request.
|
|
func Headers(headers map[string]string) RequestOption {
|
|
return func(o *RequestOptions) {
|
|
o.headers = headers
|
|
}
|
|
}
|
|
|
|
func AuthenticateAs(email string) RequestOption {
|
|
return func(o *RequestOptions) {
|
|
o.authenticateAs = email
|
|
}
|
|
}
|
|
|
|
func Client(c *http.Client) RequestOption {
|
|
return func(o *RequestOptions) {
|
|
o.client = c
|
|
}
|
|
}
|
|
|
|
// Body sets the body of the request.
|
|
// The argument can be one of the following types:
|
|
// - string
|
|
// - []byte
|
|
// - io.Reader
|
|
// - proto.Message
|
|
// - any json-encodable type
|
|
// If the argument is encoded as json, the Content-Type header will be set to
|
|
// "application/json". If the argument is a proto.Message, the Content-Type
|
|
// header will be set to "application/octet-stream".
|
|
func Body(body any) RequestOption {
|
|
return func(o *RequestOptions) {
|
|
o.body = body
|
|
}
|
|
}
|
|
|
|
// ClientCert adds a client certificate to the request.
|
|
func ClientCert[T interface {
|
|
*testenv.Certificate | *tls.Certificate
|
|
}](cert T) RequestOption {
|
|
return func(o *RequestOptions) {
|
|
o.clientCerts = append(o.clientCerts, *(*tls.Certificate)(cert))
|
|
}
|
|
}
|
|
|
|
// HTTPUpstream represents a HTTP server which can be used as the target for
|
|
// one or more Pomerium routes in a test environment.
|
|
//
|
|
// The Handle() method can be used to add handlers the server-side HTTP router,
|
|
// while the Get(), Post(), and (generic) Do() methods can be used to make
|
|
// client-side requests.
|
|
type HTTPUpstream interface {
|
|
testenv.Upstream
|
|
|
|
Handle(path string, f func(http.ResponseWriter, *http.Request)) *mux.Route
|
|
|
|
Get(r testenv.Route, opts ...RequestOption) (*http.Response, error)
|
|
Post(r testenv.Route, opts ...RequestOption) (*http.Response, error)
|
|
Do(method string, r testenv.Route, opts ...RequestOption) (*http.Response, error)
|
|
}
|
|
|
|
type httpUpstream struct {
|
|
testenv.Aggregate
|
|
serverPort values.MutableValue[int]
|
|
tlsConfig values.Value[*tls.Config]
|
|
|
|
clientCache sync.Map // map[testenv.Route]*http.Client
|
|
|
|
router *mux.Router
|
|
}
|
|
|
|
var (
|
|
_ testenv.Upstream = (*httpUpstream)(nil)
|
|
_ HTTPUpstream = (*httpUpstream)(nil)
|
|
)
|
|
|
|
// HTTP creates a new HTTP upstream server.
|
|
func HTTP(tlsConfig values.Value[*tls.Config]) HTTPUpstream {
|
|
up := &httpUpstream{
|
|
serverPort: values.Deferred[int](),
|
|
router: mux.NewRouter(),
|
|
tlsConfig: tlsConfig,
|
|
}
|
|
up.RecordCaller()
|
|
return up
|
|
}
|
|
|
|
// Port implements HTTPUpstream.
|
|
func (h *httpUpstream) Port() values.Value[int] {
|
|
return h.serverPort
|
|
}
|
|
|
|
// Router implements HTTPUpstream.
|
|
func (h *httpUpstream) Handle(path string, f func(http.ResponseWriter, *http.Request)) *mux.Route {
|
|
return h.router.HandleFunc(path, f)
|
|
}
|
|
|
|
// Route implements HTTPUpstream.
|
|
func (h *httpUpstream) Route() testenv.RouteStub {
|
|
r := &testenv.PolicyRoute{}
|
|
protocol := "http"
|
|
r.To(values.Bind(h.serverPort, func(port int) string {
|
|
return fmt.Sprintf("%s://127.0.0.1:%d", protocol, port)
|
|
}))
|
|
h.Add(r)
|
|
return r
|
|
}
|
|
|
|
// Run implements HTTPUpstream.
|
|
func (h *httpUpstream) Run(ctx context.Context) error {
|
|
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
|
if err != nil {
|
|
return err
|
|
}
|
|
h.serverPort.Resolve(listener.Addr().(*net.TCPAddr).Port)
|
|
var tlsConfig *tls.Config
|
|
if h.tlsConfig != nil {
|
|
tlsConfig = h.tlsConfig.Value()
|
|
}
|
|
server := &http.Server{
|
|
Handler: h.router,
|
|
TLSConfig: tlsConfig,
|
|
BaseContext: func(net.Listener) context.Context {
|
|
return ctx
|
|
},
|
|
}
|
|
errC := make(chan error, 1)
|
|
go func() {
|
|
errC <- server.Serve(listener)
|
|
}()
|
|
select {
|
|
case <-ctx.Done():
|
|
server.Close()
|
|
return context.Cause(ctx)
|
|
case err := <-errC:
|
|
return err
|
|
}
|
|
}
|
|
|
|
// Get implements HTTPUpstream.
|
|
func (h *httpUpstream) Get(r testenv.Route, opts ...RequestOption) (*http.Response, error) {
|
|
return h.Do(http.MethodGet, r, opts...)
|
|
}
|
|
|
|
// Post implements HTTPUpstream.
|
|
func (h *httpUpstream) Post(r testenv.Route, opts ...RequestOption) (*http.Response, error) {
|
|
return h.Do(http.MethodPost, r, opts...)
|
|
}
|
|
|
|
// Do implements HTTPUpstream.
|
|
func (h *httpUpstream) Do(method string, r testenv.Route, opts ...RequestOption) (*http.Response, error) {
|
|
options := RequestOptions{}
|
|
options.apply(opts...)
|
|
u, err := url.Parse(r.URL().Value())
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if options.path != "" || options.query != nil {
|
|
u = u.ResolveReference(&url.URL{
|
|
Path: options.path,
|
|
RawQuery: options.query.Encode(),
|
|
})
|
|
}
|
|
req, err := http.NewRequest(method, u.String(), nil)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
switch body := options.body.(type) {
|
|
case string:
|
|
req.Body = io.NopCloser(strings.NewReader(body))
|
|
case []byte:
|
|
req.Body = io.NopCloser(bytes.NewReader(body))
|
|
case io.Reader:
|
|
req.Body = io.NopCloser(body)
|
|
case proto.Message:
|
|
buf, err := proto.Marshal(body)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
req.Body = io.NopCloser(bytes.NewReader(buf))
|
|
req.Header.Set("Content-Type", "application/octet-stream")
|
|
default:
|
|
buf, err := json.Marshal(body)
|
|
if err != nil {
|
|
panic(fmt.Sprintf("unsupported body type: %T", body))
|
|
}
|
|
req.Body = io.NopCloser(bytes.NewReader(buf))
|
|
req.Header.Set("Content-Type", "application/json")
|
|
case nil:
|
|
}
|
|
|
|
newClient := func() *http.Client {
|
|
c := http.Client{
|
|
Transport: &http.Transport{
|
|
TLSClientConfig: &tls.Config{
|
|
RootCAs: h.Env().ServerCAs(),
|
|
Certificates: options.clientCerts,
|
|
},
|
|
},
|
|
}
|
|
c.Jar, _ = cookiejar.New(&cookiejar.Options{})
|
|
return &c
|
|
}
|
|
var client *http.Client
|
|
if options.client != nil {
|
|
client = options.client
|
|
} else {
|
|
var cachedClient any
|
|
var ok bool
|
|
if cachedClient, ok = h.clientCache.Load(r); !ok {
|
|
cachedClient, _ = h.clientCache.LoadOrStore(r, newClient())
|
|
}
|
|
client = cachedClient.(*http.Client)
|
|
}
|
|
|
|
var resp *http.Response
|
|
if err := retry.Retry(h.Env().Context(), "http", func(ctx context.Context) error {
|
|
var err error
|
|
if options.authenticateAs != "" {
|
|
resp, err = authenticateFlow(ctx, client, req, options.authenticateAs) //nolint:bodyclose
|
|
} else {
|
|
resp, err = client.Do(req) //nolint:bodyclose
|
|
}
|
|
// retry on connection refused
|
|
if err != nil {
|
|
var opErr *net.OpError
|
|
if errors.As(err, &opErr) && opErr.Op == "dial" && opErr.Err.Error() == "connect: connection refused" {
|
|
return err
|
|
}
|
|
return retry.NewTerminalError(err)
|
|
}
|
|
if resp.StatusCode == http.StatusInternalServerError {
|
|
return errors.New(http.StatusText(resp.StatusCode))
|
|
}
|
|
return nil
|
|
}, retry.WithMaxInterval(100*time.Millisecond)); err != nil {
|
|
return nil, err
|
|
}
|
|
return resp, nil
|
|
}
|
|
|
|
func authenticateFlow(ctx context.Context, client *http.Client, req *http.Request, email string) (*http.Response, error) {
|
|
var res *http.Response
|
|
originalHostname := req.URL.Hostname()
|
|
res, err := client.Do(req)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
location := res.Request.URL
|
|
if location.Hostname() == originalHostname {
|
|
// already authenticated
|
|
return res, err
|
|
}
|
|
defer res.Body.Close()
|
|
fs := forms.Parse(res.Body)
|
|
if len(fs) > 0 {
|
|
f := fs[0]
|
|
f.Inputs["email"] = email
|
|
f.Inputs["token_expiration"] = strconv.Itoa(int((time.Hour * 24).Seconds()))
|
|
formReq, err := f.NewRequestWithContext(ctx, location)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return client.Do(formReq)
|
|
}
|
|
return nil, fmt.Errorf("test bug: expected IDP login form")
|
|
}
|