pomerium/internal/testenv/environment.go
2025-03-17 11:59:02 -07:00

1045 lines
32 KiB
Go

package testenv
import (
"bytes"
"context"
"crypto/rand"
"crypto/rsa"
"crypto/sha256"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"encoding/asn1"
"encoding/base64"
"encoding/hex"
"errors"
"flag"
"fmt"
"io"
"math/big"
"net"
"net/url"
"os"
"os/signal"
"path"
"path/filepath"
"runtime"
"strconv"
"strings"
"sync"
"sync/atomic"
"syscall"
"testing"
"time"
"github.com/pomerium/pomerium/config"
"github.com/pomerium/pomerium/config/envoyconfig/filemgr"
databroker_service "github.com/pomerium/pomerium/databroker"
"github.com/pomerium/pomerium/internal/log"
"github.com/pomerium/pomerium/internal/telemetry/trace"
"github.com/pomerium/pomerium/internal/testenv/envutil"
"github.com/pomerium/pomerium/internal/testenv/values"
"github.com/pomerium/pomerium/pkg/cmd/pomerium"
"github.com/pomerium/pomerium/pkg/envoy"
"github.com/pomerium/pomerium/pkg/grpc/databroker"
"github.com/pomerium/pomerium/pkg/health"
"github.com/pomerium/pomerium/pkg/identity/manager"
"github.com/pomerium/pomerium/pkg/netutil"
"github.com/pomerium/pomerium/pkg/slices"
"github.com/rs/zerolog"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/exporters/otlp/otlptrace"
oteltrace "go.opentelemetry.io/otel/trace"
"golang.org/x/sync/errgroup"
"google.golang.org/grpc/grpclog"
)
// Environment is a lightweight integration test fixture that runs Pomerium
// in-process.
// Environment is a lightweight integration test fixture that runs Pomerium
// in-process.
type Environment interface {
// Context returns the environment's root context. This context holds a
// top-level logger scoped to this environment. It will be canceled when
// Stop() is called, or during test cleanup.
Context() context.Context
// Tracer returns the tracer created for this environment.
Tracer() oteltrace.Tracer
// Assert returns assertion helpers bound to the environment's test.
Assert() *assert.Assertions
// Require returns fatal assertion helpers bound to the environment's test.
Require() *require.Assertions
// TempDir returns a unique temp directory for this context. Calling this
// function multiple times returns the same path.
TempDir() string
// CACert returns the test environment's root CA certificate and private key.
CACert() *tls.Certificate
// ServerCAs returns a new [*x509.CertPool] containing the root CA certificate
// used to sign the server cert and other test certificates.
ServerCAs() *x509.CertPool
// ServerCert returns the Pomerium server's certificate and private key.
ServerCert() *tls.Certificate
// NewClientCert generates a new client certificate signed by the root CA
// certificate. One or more optional templates can be given, which can be
// used to set or override certain parameters when creating a certificate,
// including subject, SANs, or extensions. If more than one template is
// provided, they will be applied in order from left to right.
//
// By default (unless overridden in a template), the certificate will have
// its Common Name set to the file:line string of the call site. Calls to
// NewClientCert() on different lines will have different subjects. If
// multiple certs with the same subject are needed, wrap the call to this
// function in another helper function, or separate calls with commas on the
// same line.
NewClientCert(templateOverrides ...*x509.Certificate) *Certificate
// NewServerCert generates a new server certificate signed by the root CA
// certificate. Only the DNS names and IP addresses of the given templates
// are applied to the generated certificate.
NewServerCert(templateOverrides ...*x509.Certificate) *Certificate
// AuthenticateURL returns the URL of the "authenticate" subdomain (see
// SubdomainURL).
AuthenticateURL() values.Value[string]
// DatabrokerURL returns a "127.0.0.1:<port>" address (no scheme) using the
// environment's outbound port. It resolves once ports have been allocated
// in Start().
DatabrokerURL() values.Value[string]
// Ports returns the set of ports used by the environment. The individual
// values are resolved in Start().
Ports() Ports
// SharedSecret returns a copy of the randomly generated 32-byte shared secret.
SharedSecret() []byte
// CookieSecret returns a copy of the randomly generated 32-byte cookie secret.
CookieSecret() []byte
// Add adds the given [Modifier] to the environment. All modifiers will be
// invoked upon calling Start() to apply individual modifications to the
// configuration before starting the Pomerium server.
Add(m Modifier)
// AddTask adds the given [Task] to the environment. All tasks will be
// started in separate goroutines upon calling Start(). If any tasks exit
// with an error, the environment will be stopped and the test will fail.
AddTask(r Task)
// AddUpstream adds the given [Upstream] to the environment. This function is
// equivalent to calling both Add() and AddTask() with the upstream, but
// improves readability.
AddUpstream(u Upstream)
// Start starts the test environment, and adds a call to Stop() as a cleanup
// hook to the environment's [testing.T]. All previously added [Modifier]
// instances are invoked in order to build the configuration, and all
// previously added [Task] instances are started in the background.
//
// Calling Start() more than once or calling Start() after Stop() will
// panic. Calling Add() while Start() is still in progress, or after Stop(),
// will also panic.
Start()
// Stop stops the test environment. Calling this function more than once has
// no effect. It is usually not necessary to call Stop() directly unless you
// need to stop the test environment before the test is completed.
Stop()
// Pause will block and wait until SIGINT is received, then continue. This
// has the same effect as if the test failed and the PauseOnFailure option was
// given, but can be called at any time.
Pause()
// SubdomainURL returns a string [values.Value] which will contain a complete
// URL for the given subdomain of the server's domain (given by its serving
// certificate), including the 'https://' scheme and random http server port.
// This value will only be resolved some time after Start() is called, and
// can be used as the 'from' value for routes.
SubdomainURL(subdomain string) values.Value[string]
// NewLogRecorder returns a new [*LogRecorder] and starts capturing logs for
// the Pomerium server and Envoy.
NewLogRecorder(opts ...LogRecorderOption) *LogRecorder
// GetState returns the current state of the test environment.
GetState() EnvironmentState
// OnStateChanged registers a callback to be invoked when the environment's
// state changes to the given state. Each callback is invoked in a separate
// goroutine, but the test environment will wait for all callbacks to return
// before continuing, after triggering the state change.
//
// Calling the returned stop function will prevent the callback from being
// run. Returns true if it stopped the callback from being run, or false if
// it already ran or is currently running.
//
// If the environment is already in the given state, the callback will be run
// in a separate goroutine immediately and the returned stop function will
// have no effect. A callback run in this way will prevent the state from
// advancing until the callback returns.
//
// State changes are triggered in the following places:
// - NotRunning->Starting: in Start(), as the first operation
// - Starting->Running: in Start(), just before returning
// - Running->Stopping: in Stop(), just before the env context is canceled
// - Stopping->Stopped: in Stop(), after all tasks have completed
OnStateChanged(state EnvironmentState, callback func()) (stop func() bool)
}
// Certificate is a [tls.Certificate] with helpers for computing digests of
// its parsed leaf certificate.
type Certificate tls.Certificate

// Fingerprint returns the hex-encoded SHA-256 digest of the leaf
// certificate's raw bytes. The Leaf field must be populated.
func (c *Certificate) Fingerprint() string {
	hasher := sha256.New()
	hasher.Write(c.Leaf.Raw)
	return hex.EncodeToString(hasher.Sum(nil))
}

// SPKIHash returns the standard base64 encoding of the SHA-256 digest of the
// leaf certificate's raw SubjectPublicKeyInfo. The Leaf field must be
// populated.
func (c *Certificate) SPKIHash() string {
	hasher := sha256.New()
	hasher.Write(c.Leaf.RawSubjectPublicKeyInfo)
	return base64.StdEncoding.EncodeToString(hasher.Sum(nil))
}
// EnvironmentState describes where a test environment is in its lifecycle.
// States only ever advance in the order in which they are declared.
type EnvironmentState uint32

const (
	NotRunning EnvironmentState = iota
	Starting
	Running
	Stopping
	Stopped
)

// String returns the name of the state, or "EnvironmentState(N)" for values
// that do not correspond to a declared state.
func (e EnvironmentState) String() string {
	names := [...]string{
		NotRunning: "NotRunning",
		Starting:   "Starting",
		Running:    "Running",
		Stopping:   "Stopping",
		Stopped:    "Stopped",
	}
	if e < EnvironmentState(len(names)) {
		return names[e]
	}
	return fmt.Sprintf("EnvironmentState(%d)", e)
}
// environment is the implementation of [Environment] returned by New.
type environment struct {
EnvironmentOptions
// t is the test (or benchmark) this environment belongs to.
t testing.TB
assert *assert.Assertions
require *require.Assertions
// tempDir is the environment's temp directory; New creates a "certs"
// subdirectory in it.
tempDir string
// domain is the wildcard domain taken from the server certificate's DNS names.
domain string
ports Ports
// sharedSecret and cookieSecret are filled with random bytes in New.
sharedSecret [32]byte
cookieSecret [32]byte
// workspaceFolder is the closest ancestor of the working directory that
// contains a ".git" entry.
workspaceFolder string
silent bool
// ctx is the root context; cancel cancels it with a cause.
ctx context.Context
cancel context.CancelCauseFunc
// cleanupOnce ensures cleanup runs at most once, whether triggered by
// Stop() or by test cleanup.
cleanupOnce sync.Once
logWriter *log.MultiWriter
tracerProvider oteltrace.TracerProvider
tracer oteltrace.Tracer
// rootSpan is the span started in New, named after the test.
rootSpan oteltrace.Span
// mods and tasks record what was registered, together with the call site.
mods []WithCaller[Modifier]
tasks []WithCaller[Task]
taskErrGroup *errgroup.Group
// stateMu guards state and stateChangeListeners.
stateMu sync.Mutex
state EnvironmentState
stateChangeListeners map[EnvironmentState][]func()
// stateChangeBlockers tracks callbacks registered for the current state;
// the state does not advance until they have returned.
stateChangeBlockers sync.WaitGroup
// src is the config source handed to the Pomerium server; it is set in Start().
src *configSource
}
// EnvironmentOptions holds the settings that can be changed by passing
// [EnvironmentOption] values to New.
type EnvironmentOptions struct {
// debug enables debug messages written via the test's Logf.
debug bool
// pauseOnFailure makes cleanup wait for SIGINT if the test has failed.
pauseOnFailure bool
// forceSilent suppresses log output regardless of the kind of test.
forceSilent bool
// traceDebugFlags are passed to the trace context created in New.
traceDebugFlags trace.DebugFlags
// traceClient is the OTLP client used by the trace context created in New.
traceClient otlptrace.Client
}
// EnvironmentOption mutates an [EnvironmentOptions] value.
type EnvironmentOption func(*EnvironmentOptions)

// apply runs the given options against o, in the order given.
func (o *EnvironmentOptions) apply(opts ...EnvironmentOption) {
	for i := range opts {
		opts[i](o)
	}
}
// Debug returns an option that turns debug logging on or off. With no
// argument debug logging is enabled; otherwise the first argument decides.
func Debug(enable ...bool) EnvironmentOption {
	enabled := true
	if len(enable) > 0 {
		enabled = enable[0]
	}
	return func(options *EnvironmentOptions) {
		options.debug = enabled
	}
}
// PauseOnFailure returns an option that controls whether the environment
// waits for SIGINT during cleanup when the test has failed. With no argument
// pausing is enabled; otherwise the first argument decides.
func PauseOnFailure(enable ...bool) EnvironmentOption {
	enabled := true
	if len(enable) > 0 {
		enabled = enable[0]
	}
	return func(options *EnvironmentOptions) {
		options.pauseOnFailure = enabled
	}
}
// Silent returns an option that forces the environment's output to be
// suppressed. With no argument output is suppressed; otherwise the first
// argument decides.
func Silent(silent ...bool) EnvironmentOption {
	force := true
	if len(silent) > 0 {
		force = silent[0]
	}
	return func(options *EnvironmentOptions) {
		options.forceSilent = force
	}
}
// StandardTraceDebugFlags is a predefined combination of trace debug flags:
// span caller tracking, warnings for incomplete spans, incomplete traces and
// unresolved references, and logging of trace IDs and all spans on warnings.
// It can be passed to WithTraceDebugFlags.
const StandardTraceDebugFlags = trace.TrackSpanCallers |
trace.WarnOnIncompleteSpans |
trace.WarnOnIncompleteTraces |
trace.WarnOnUnresolvedReferences |
trace.LogTraceIDsOnWarn |
trace.LogAllSpansOnWarn
// WithTraceDebugFlags returns an option that sets the trace debug flags used
// by the environment, replacing the value taken from the command-line flag.
func WithTraceDebugFlags(flags trace.DebugFlags) EnvironmentOption {
return func(o *EnvironmentOptions) {
o.traceDebugFlags = flags
}
}
// WithTraceClient returns an option that sets the OTLP trace client used by
// the environment. When not given, New uses trace.NoopClient.
func WithTraceClient(traceClient otlptrace.Client) EnvironmentOption {
return func(o *EnvironmentOptions) {
o.traceClient = traceClient
}
}
// setGrpcLoggerOnce ensures the global gRPC logger is replaced only once per
// process, no matter how many environments are created.
var setGrpcLoggerOnce sync.Once
// defaultTraceDebugFlags is the default value of the env.trace-debug-flags flag.
const defaultTraceDebugFlags = trace.TrackSpanCallers | trace.TrackSpanReferences
// Command-line flags providing defaults for the environment options. Options
// passed to New are applied on top of these values.
var (
flagDebug = flag.Bool("env.debug", false, "enables test environment debug logging (equivalent to Debug() option)")
flagPauseOnFailure = flag.Bool("env.pause-on-failure", false, "enables pausing the test environment on failure (equivalent to PauseOnFailure() option)")
flagSilent = flag.Bool("env.silent", false, "suppresses all test environment output (equivalent to Silent() option)")
// The value is an integer; if it is prefixed with "+", New ORs it into the
// flags selected by options instead of only using it as the default.
flagTraceDebugFlags = flag.String("env.trace-debug-flags", strconv.Itoa(defaultTraceDebugFlags), "trace debug flags (equivalent to TraceDebugFlags() option)")
)
// New creates a new test [Environment] bound to t.
//
// Defaults for the options come from the env.* command-line flags, and the
// given opts are applied on top of them. In short mode the test is skipped.
// New locates the workspace root (the closest ancestor of the working
// directory containing ".git"), configures global logging and tracing, sets
// itself as the global health provider, generates random shared and cookie
// secrets, and copies the test certificates from integration/tpl/files into
// the "certs" subdirectory of the environment's temp directory.
//
// New panics if the env.trace-debug-flags flag is malformed or if the
// workspace root cannot be found; other setup failures fail the test.
func New(t testing.TB, opts ...EnvironmentOption) Environment {
	// A leading "+" means the flag value is ORed into the final debug flags
	// (after options are applied) rather than only acting as the default.
	addTraceDebugFlags := strings.HasPrefix(*flagTraceDebugFlags, "+")
	flagTraceFlags, err := strconv.Atoi(strings.TrimPrefix(*flagTraceDebugFlags, "+"))
	if err != nil {
		panic("malformed value for --env.trace-debug-flags: " + err.Error())
	}
	options := EnvironmentOptions{
		debug:           *flagDebug,
		pauseOnFailure:  *flagPauseOnFailure,
		forceSilent:     *flagSilent,
		traceDebugFlags: trace.DebugFlags(flagTraceFlags),
		traceClient:     trace.NoopClient{},
	}
	options.apply(opts...)
	if testing.Short() {
		t.Helper()
		t.Skip("test environment disabled in short mode")
	}
	if addTraceDebugFlags {
		options.traceDebugFlags |= trace.DebugFlags(flagTraceFlags)
	}
	trace.UseGlobalPanicTracer()
	databroker.DebugUseFasterBackoff.Store(true)

	// Walk up from the working directory until a directory containing ".git"
	// is found. The walk stops when the parent of a directory is the
	// directory itself, which is true for any filesystem root.
	workspaceFolder, err := os.Getwd()
	require.NoError(t, err)
	for {
		if _, err := os.Stat(filepath.Join(workspaceFolder, ".git")); err == nil {
			break
		}
		parent := filepath.Dir(workspaceFolder)
		if parent == workspaceFolder {
			panic("could not find workspace root")
		}
		workspaceFolder = parent
	}
	workspaceFolder, err = filepath.Abs(workspaceFolder)
	require.NoError(t, err)

	writer := log.NewMultiWriter()
	silent := options.forceSilent || isSilent(t)
	if silent {
		// this sets the global zap level to fatal, then resets the global zerolog
		// level to debug
		log.SetLevel(zerolog.FatalLevel)
		zerolog.SetGlobalLevel(zerolog.DebugLevel)
	} else {
		log.SetLevel(zerolog.InfoLevel)
		writer.Add(os.Stdout)
	}
	log.DebugDisableGlobalWarnings.Store(silent)
	log.DebugDisableGlobalMessages.Store(silent)
	log.DebugDisableZapLogger.Store(silent)
	setGrpcLoggerOnce.Do(func() {
		grpclog.SetLoggerV2(grpclog.NewLoggerV2WithVerbosity(io.Discard, io.Discard, io.Discard, 0))
	})
	logger := zerolog.New(writer).With().Timestamp().Logger().Level(zerolog.DebugLevel)

	ctx := trace.Options{
		DebugFlags: options.traceDebugFlags,
	}.NewContext(logger.WithContext(context.Background()), options.traceClient)
	tracerProvider := trace.NewTracerProvider(ctx, "Test Environment")
	tracer := tracerProvider.Tracer(trace.PomeriumCoreTracer)
	ctx, span := tracer.Start(ctx, t.Name(), oteltrace.WithNewRoot())
	ctx, cancel := context.WithCancelCause(ctx)
	taskErrGroup, ctx := errgroup.WithContext(ctx)

	e := &environment{
		EnvironmentOptions: options,
		t:                  t,
		assert:             assert.New(t),
		require:            require.New(t),
		tempDir:            tempDir(t),
		ports: Ports{
			ProxyHTTP:    values.Deferred[int](),
			ProxyGRPC:    values.Deferred[int](),
			ProxyMetrics: values.Deferred[int](),
			GRPC:         values.Deferred[int](),
			HTTP:         values.Deferred[int](),
			Outbound:     values.Deferred[int](),
			Metrics:      values.Deferred[int](),
			Debug:        values.Deferred[int](),
			ALPN:         values.Deferred[int](),
		},
		workspaceFolder:      workspaceFolder,
		silent:               silent,
		ctx:                  ctx,
		cancel:               cancel,
		tracerProvider:       tracerProvider,
		tracer:               tracer,
		logWriter:            writer,
		taskErrGroup:         taskErrGroup,
		stateChangeListeners: make(map[EnvironmentState][]func()),
		rootSpan:             span,
	}
	_, err = rand.Read(e.sharedSecret[:])
	require.NoError(t, err)
	_, err = rand.Read(e.cookieSecret[:])
	require.NoError(t, err)
	health.SetProvider(e)

	// Copy the test certificates into the temp dir so that the server and the
	// certificate helpers all read them from the same place.
	require.NoError(t, os.Mkdir(filepath.Join(e.tempDir, "certs"), 0o777))
	copyFile := func(src, dstRel string) {
		data, err := os.ReadFile(src)
		require.NoError(t, err)
		require.NoError(t, os.WriteFile(filepath.Join(e.tempDir, dstRel), data, 0o600))
	}
	certsToCopy := []string{
		"trusted.pem",
		"trusted-key.pem",
		"ca.pem",
		"ca-key.pem",
	}
	for _, crt := range certsToCopy {
		copyFile(filepath.Join(workspaceFolder, "integration/tpl/files", crt), filepath.Join("certs/", filepath.Base(crt)))
	}
	e.domain = wildcardDomain(e.ServerCert().Leaf.DNSNames)
	return e
}
// debugf records a formatted debug message. The message is added as an event
// on the environment's root span while that span is recording, and is also
// written to the test log (in blue, prefixed with the test name) when the
// debug option is enabled.
func (e *environment) debugf(format string, args ...any) {
	e.t.Helper()
	recording := e.rootSpan.IsRecording()
	if !recording && !e.debug {
		return
	}
	msg := fmt.Sprintf(format, args...)
	if recording {
		e.rootSpan.AddEvent(msg)
	}
	if e.debug {
		// The test name and message are passed as arguments rather than spliced
		// into the format string, so a '%' in either cannot be misread as a verb.
		e.t.Logf("\x1b[34mDEBUG [%s] %s\x1b[0m", e.t.Name(), msg)
	}
}
// WithCaller pairs a value with a string identifying the call site that
// registered it. Add and AddTask fill Caller with a "file:line" string.
type WithCaller[T any] struct {
Caller string
Value T
}
// Ports holds the ports used by a test environment. New creates each field
// as a deferred value, and Start() resolves all of them after allocating
// ports.
type Ports struct {
// ProxyHTTP, ProxyGRPC and ProxyMetrics are used for the server's Addr,
// GRPCAddr and MetricsAddr options respectively.
ProxyHTTP values.MutableValue[int]
ProxyGRPC values.MutableValue[int]
ProxyMetrics values.MutableValue[int]
// The remaining ports are handed to the configuration as a group via
// config.AllocatePorts in Start().
GRPC values.MutableValue[int]
HTTP values.MutableValue[int]
Outbound values.MutableValue[int]
Metrics values.MutableValue[int]
Debug values.MutableValue[int]
ALPN values.MutableValue[int]
}
// TempDir implements Environment. It returns the temp directory chosen when
// the environment was created.
func (e *environment) TempDir() string {
return e.tempDir
}
// Context implements Environment. The returned context is the environment's
// root context with the environment itself attached via ContextWithEnv.
func (e *environment) Context() context.Context {
return ContextWithEnv(e.ctx, e)
}
// Tracer implements Environment.
func (e *environment) Tracer() oteltrace.Tracer {
return e.tracer
}
// Assert implements Environment.
func (e *environment) Assert() *assert.Assertions {
return e.assert
}
// Require implements Environment.
func (e *environment) Require() *require.Assertions {
return e.require
}
// SubdomainURL implements Environment. The returned value is bound to the
// proxy HTTP port and yields "https://<subdomain>.<domain>:<port>" once that
// port is available.
func (e *environment) SubdomainURL(subdomain string) values.Value[string] {
	buildURL := func(port int) string {
		return "https://" + subdomain + "." + e.domain + ":" + strconv.Itoa(port)
	}
	return values.Bind(e.ports.ProxyHTTP, buildURL)
}
// AuthenticateURL implements Environment. It is the URL of the
// "authenticate" subdomain.
func (e *environment) AuthenticateURL() values.Value[string] {
return e.SubdomainURL("authenticate")
}
// DatabrokerURL implements Environment. The returned value is bound to the
// outbound port and yields "127.0.0.1:<port>" (without a scheme) once that
// port is available.
func (e *environment) DatabrokerURL() values.Value[string] {
	buildAddr := func(port int) string {
		return "127.0.0.1:" + strconv.Itoa(port)
	}
	return values.Bind(e.ports.Outbound, buildAddr)
}
// Ports implements Environment. The struct is returned by value.
func (e *environment) Ports() Ports {
return e.ports
}
// CACert implements Environment. The CA key pair is loaded from the "certs"
// subdirectory of the temp dir on every call; a load failure fails the test.
func (e *environment) CACert() *tls.Certificate {
	certsDir := filepath.Join(e.tempDir, "certs")
	certFile := filepath.Join(certsDir, "ca.pem")
	keyFile := filepath.Join(certsDir, "ca-key.pem")
	pair, err := tls.LoadX509KeyPair(certFile, keyFile)
	require.NoError(e.t, err)
	return &pair
}
// ServerCAs implements Environment. It returns a new pool containing the
// root CA certificate read from the "certs" subdirectory of the temp dir.
// The test fails if the file cannot be read or if no certificate can be
// parsed from it, so callers never receive an empty pool.
func (e *environment) ServerCAs() *x509.CertPool {
	caPEM, err := os.ReadFile(filepath.Join(e.tempDir, "certs", "ca.pem"))
	require.NoError(e.t, err)
	pool := x509.NewCertPool()
	require.True(e.t, pool.AppendCertsFromPEM(caPEM), "no certificates could be parsed from ca.pem")
	return pool
}
// ServerCert implements Environment. The server key pair is loaded from the
// "certs" subdirectory of the temp dir on every call; a load failure fails
// the test.
func (e *environment) ServerCert() *tls.Certificate {
	certsDir := filepath.Join(e.tempDir, "certs")
	certFile := filepath.Join(certsDir, "trusted.pem")
	keyFile := filepath.Join(certsDir, "trusted-key.pem")
	pair, err := tls.LoadX509KeyPair(certFile, keyFile)
	require.NoError(e.t, err)
	return &pair
}
// ErrCauseTestCleanup is used as the context's cancel cause when the
// environment is stopped by the test's cleanup hook.
var ErrCauseTestCleanup = errors.New("test cleanup")
// ErrCauseManualStop is used as the context's cancel cause when Stop() is
// called; the actual cause wraps it together with the caller's file and line.
var ErrCauseManualStop = errors.New("Stop() called")
// Start implements Environment. It advances the state to Starting, registers
// the cleanup hook, allocates ports and builds the base configuration, adds
// the internal task that runs the Pomerium server, starts every registered
// task in the task error group, and finally advances the state to Running.
func (e *environment) Start() {
_, span := e.tracer.Start(e.Context(), "Start")
defer span.End()
e.debugf("Start()")
// Panics if the environment has already been started or stopped.
e.advanceState(Starting)
e.t.Cleanup(e.onTestCleanup)
e.t.Setenv("TMPDIR", e.TempDir())
e.debugf("temp dir: %s", e.TempDir())
cfg := &config.Config{
Options: config.NewDefaultOptions(),
}
// Nine ports: three for the proxy listeners, six for the config.
ports, err := netutil.AllocatePorts(9)
require.NoError(e.t, err)
atoi := func(str string) int {
p, err := strconv.Atoi(str)
if err != nil {
panic(err)
}
return p
}
e.ports.ProxyHTTP.Resolve(atoi(ports[0]))
e.ports.ProxyGRPC.Resolve(atoi(ports[1]))
e.ports.ProxyMetrics.Resolve(atoi(ports[2]))
e.ports.GRPC.Resolve(atoi(ports[3]))
e.ports.HTTP.Resolve(atoi(ports[4]))
e.ports.Outbound.Resolve(atoi(ports[5]))
e.ports.Metrics.Resolve(atoi(ports[6]))
e.ports.Debug.Resolve(atoi(ports[7]))
e.ports.ALPN.Resolve(atoi(ports[8]))
// The last six ports (indices 3-8) are handed to the config as a group.
cfg.AllocatePorts(*(*[6]string)(ports[3:]))
cfg.Options.AutocertOptions = config.AutocertOptions{Enable: false}
cfg.Options.Services = "all"
cfg.Options.LogLevel = config.LogLevelDebug
cfg.Options.ProxyLogLevel = config.LogLevelInfo
cfg.Options.Addr = fmt.Sprintf("127.0.0.1:%d", e.ports.ProxyHTTP.Value())
cfg.Options.GRPCAddr = fmt.Sprintf("127.0.0.1:%d", e.ports.ProxyGRPC.Value())
cfg.Options.MetricsAddr = fmt.Sprintf("127.0.0.1:%d", e.ports.ProxyMetrics.Value())
cfg.Options.CAFile = filepath.Join(e.tempDir, "certs", "ca.pem")
cfg.Options.CertFile = filepath.Join(e.tempDir, "certs", "trusted.pem")
cfg.Options.KeyFile = filepath.Join(e.tempDir, "certs", "trusted-key.pem")
cfg.Options.AuthenticateURLString = e.AuthenticateURL().Value()
cfg.Options.DataBrokerStorageType = "memory"
cfg.Options.SharedKey = base64.StdEncoding.EncodeToString(e.sharedSecret[:])
cfg.Options.CookieSecret = base64.StdEncoding.EncodeToString(e.cookieSecret[:])
cfg.Options.AccessLogFields = []log.AccessLogField{
log.AccessLogFieldAuthority,
log.AccessLogFieldDuration,
log.AccessLogFieldForwardedFor,
log.AccessLogFieldIP,
log.AccessLogFieldMethod,
log.AccessLogFieldPath,
log.AccessLogFieldQuery,
log.AccessLogFieldReferer,
log.AccessLogFieldRequestID,
log.AccessLogFieldResponseCode,
log.AccessLogFieldResponseCodeDetails,
log.AccessLogFieldSize,
log.AccessLogFieldUpstreamCluster,
log.AccessLogFieldUserAgent,
log.AccessLogFieldClientCertificate,
}
e.src = &configSource{cfg: cfg}
// The Pomerium server runs as a task of its own. It is appended after any
// tasks added by the test, so it is the last task to be started below.
e.AddTask(TaskFunc(func(ctx context.Context) error {
fileMgr := filemgr.NewManager(filemgr.WithCacheDir(filepath.Join(e.TempDir(), "cache")))
// Apply the registered modifiers in order, validating the configuration
// after each one so a failure can be attributed to the modifier's caller.
for _, mod := range e.mods {
mod.Value.Modify(cfg)
require.NoError(e.t, cfg.Options.Validate(), "invoking modifier resulted in an invalid configuration:\nadded by: "+mod.Caller)
}
opts := []pomerium.Option{
pomerium.WithOverrideFileManager(fileMgr),
pomerium.WithEnvoyServerOptions(envoy.WithExitGracePeriod(30 * time.Second)),
pomerium.WithDataBrokerServerOptions(
databroker_service.WithManagerOptions(manager.WithLeaseTTL(1 * time.Second)),
),
}
envoyBinaryPath := filepath.Join(e.workspaceFolder, fmt.Sprintf("pkg/envoy/files/envoy-%s-%s", runtime.GOOS, runtime.GOARCH))
if envutil.EnvoyProfilerAvailable(envoyBinaryPath) {
e.debugf("envoy profiling available")
envVars := []string{}
pprofCmdLog := "=> go run github.com/google/pprof@latest -symbolize=local -ignore='TCMalloc|^tcmalloc::|^msync$|stacktrace_generic_fp' -http=: %s %s"
if path := envutil.ProfileOutputPath("cpuprofile"); path != "" {
// The envoy profile is written next to the requested profile, with an
// "envoy_" prefix on the file name.
dir, base := filepath.Split(path)
path = filepath.Join(dir, "envoy_"+base)
envVars = append(envVars, fmt.Sprintf("CPUPROFILE=%s", path))
e.t.Cleanup(func() {
e.debugf("View envoy cpu profile:")
e.debugf(pprofCmdLog, envoyBinaryPath, path)
})
}
if path := envutil.ProfileOutputPath("memprofile"); path != "" {
dir, base := filepath.Split(path)
path = filepath.Join(dir, "envoy_"+base)
envVars = append(envVars, fmt.Sprintf("HEAPPROFILE=%s", path))
e.t.Cleanup(func() {
if err := envutil.CollectEnvoyHeapProfiles(path); err != nil {
e.t.Logf("error collecting envoy heap profiles: %s", err)
}
e.debugf("View envoy heap profile:")
// Shadows the outer envoyBinaryPath; the expression is the same.
envoyBinaryPath := filepath.Join(e.workspaceFolder, fmt.Sprintf("pkg/envoy/files/envoy-%s-%s", runtime.GOOS, runtime.GOARCH))
e.debugf(pprofCmdLog, envoyBinaryPath, path)
})
}
if len(envVars) > 0 {
e.debugf("adding envoy env vars: %v\n", envVars)
opts = append(opts, pomerium.WithEnvoyServerOptions(envoy.WithExtraEnvVars(envVars...)))
}
} else {
e.debugf("envoy profiling not available")
}
pom := pomerium.New(opts...)
// Shut the server down when the environment enters the Stopping state,
// which happens before the environment's context is canceled.
e.OnStateChanged(Stopping, func() {
if err := pom.Shutdown(ctx); err != nil {
log.Ctx(ctx).Err(err).Msg("error shutting down pomerium server")
} else {
e.debugf("pomerium server shut down without error")
}
})
require.NoError(e.t, pom.Start(ctx, e.tracerProvider, e.src))
return pom.Wait()
}))
// Start every registered task in the task error group.
for i, task := range e.tasks {
log.Ctx(e.Context()).Debug().Str("caller", task.Caller).Msgf("starting task %d", i)
e.taskErrGroup.Go(func() error {
defer log.Ctx(e.Context()).Debug().Str("caller", task.Caller).Msgf("task %d exited", i)
return task.Value.Run(e.Context())
})
}
// Yield once, presumably to give the new task goroutines a chance to run
// before the state changes.
runtime.Gosched()
e.advanceState(Running)
}
// NewClientCert implements Environment. It generates a 2048-bit RSA key and
// a client-auth certificate valid for 12 hours, signed by the environment's
// CA. Each override template is merged into the certificate template in
// order: list-valued fields are appended (de-duplicated where applicable),
// subject attributes are filled in from the override, and key usages are
// ORed together. The resulting certificate is verified against the CA pool
// before it is returned; any failure fails the test.
func (e *environment) NewClientCert(templateOverrides ...*x509.Certificate) *Certificate {
	ca := e.CACert()

	key, err := rsa.GenerateKey(rand.Reader, 2048)
	require.NoError(e.t, err)

	// Random serial number in [0, 2^128).
	serialLimit := new(big.Int).Lsh(big.NewInt(1), 128)
	serial, err := rand.Int(rand.Reader, serialLimit)
	require.NoError(e.t, err)

	issuedAt := time.Now()
	tmpl := &x509.Certificate{
		SerialNumber:          serial,
		Subject:               pkix.Name{CommonName: getCaller()},
		NotBefore:             issuedAt,
		NotAfter:              issuedAt.Add(12 * time.Hour),
		KeyUsage:              x509.KeyUsageDigitalSignature,
		ExtKeyUsage:           []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth},
		BasicConstraintsValid: true,
	}
	for _, o := range templateOverrides {
		tmpl.CRLDistributionPoints = slices.Unique(append(tmpl.CRLDistributionPoints, o.CRLDistributionPoints...))
		tmpl.DNSNames = slices.Unique(append(tmpl.DNSNames, o.DNSNames...))
		tmpl.EmailAddresses = slices.Unique(append(tmpl.EmailAddresses, o.EmailAddresses...))
		tmpl.ExtraExtensions = append(tmpl.ExtraExtensions, o.ExtraExtensions...)
		tmpl.IPAddresses = slices.UniqueBy(append(tmpl.IPAddresses, o.IPAddresses...), net.IP.String)
		tmpl.URIs = slices.UniqueBy(append(tmpl.URIs, o.URIs...), (*url.URL).String)
		tmpl.UnknownExtKeyUsage = slices.UniqueBy(append(tmpl.UnknownExtKeyUsage, o.UnknownExtKeyUsage...), asn1.ObjectIdentifier.String)
		rdns := o.Subject.ToRDNSequence()
		tmpl.Subject.FillFromRDNSequence(&rdns)
		tmpl.KeyUsage |= o.KeyUsage
		tmpl.ExtKeyUsage = slices.Unique(append(tmpl.ExtKeyUsage, o.ExtKeyUsage...))
	}

	der, err := x509.CreateCertificate(rand.Reader, tmpl, ca.Leaf, key.Public(), ca.PrivateKey)
	require.NoError(e.t, err)
	leaf, err := x509.ParseCertificate(der)
	require.NoError(e.t, err)
	e.debugf("provisioned client certificate for %s", leaf.Subject.String())

	// The chain is the new leaf followed by the CA certificate.
	result := &tls.Certificate{
		Certificate: [][]byte{leaf.Raw, ca.Leaf.Raw},
		PrivateKey:  key,
		Leaf:        leaf,
	}
	verifyOpts := x509.VerifyOptions{
		KeyUsages: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth},
		Roots:     e.ServerCAs(),
	}
	_, err = result.Leaf.Verify(verifyOpts)
	require.NoError(e.t, err, "bug: generated client cert is not valid")
	return (*Certificate)(result)
}
// NewServerCert implements Environment. It generates a 2048-bit RSA key and
// a server-auth certificate valid for 12 hours, signed by the environment's
// CA. Only the DNSNames and IPAddresses of each override template are merged
// into the certificate (de-duplicated); other template fields are ignored.
// The resulting certificate is verified against the CA pool before it is
// returned; any failure fails the test.
func (e *environment) NewServerCert(templateOverrides ...*x509.Certificate) *Certificate {
	caCert := e.CACert()
	priv, err := rsa.GenerateKey(rand.Reader, 2048)
	require.NoError(e.t, err)
	// Random serial number in [0, 2^128).
	sn, err := rand.Int(rand.Reader, new(big.Int).Lsh(big.NewInt(1), 128))
	require.NoError(e.t, err)
	now := time.Now()
	tmpl := &x509.Certificate{
		SerialNumber: sn,
		NotBefore:    now,
		NotAfter:     now.Add(12 * time.Hour),
		KeyUsage:     x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment,
		ExtKeyUsage: []x509.ExtKeyUsage{
			x509.ExtKeyUsageServerAuth,
		},
		BasicConstraintsValid: true,
	}
	for _, override := range templateOverrides {
		tmpl.DNSNames = slices.Unique(append(tmpl.DNSNames, override.DNSNames...))
		tmpl.IPAddresses = slices.UniqueBy(append(tmpl.IPAddresses, override.IPAddresses...), net.IP.String)
	}
	certDER, err := x509.CreateCertificate(rand.Reader, tmpl, caCert.Leaf, priv.Public(), caCert.PrivateKey)
	require.NoError(e.t, err)
	cert, err := x509.ParseCertificate(certDER)
	require.NoError(e.t, err)
	e.debugf("provisioned server certificate for %v", cert.DNSNames)
	// The chain is the new leaf followed by the CA certificate.
	tlsCert := &tls.Certificate{
		Certificate: [][]byte{cert.Raw, caCert.Leaf.Raw},
		PrivateKey:  priv,
		Leaf:        cert,
	}
	_, err = tlsCert.Leaf.Verify(x509.VerifyOptions{Roots: e.ServerCAs()})
	require.NoError(e.t, err, "bug: generated server cert is not valid")
	return (*Certificate)(tlsCert)
}
// SharedSecret implements Environment. A copy is returned, so callers cannot
// modify the environment's secret.
func (e *environment) SharedSecret() []byte {
return bytes.Clone(e.sharedSecret[:])
}
// CookieSecret implements Environment. A copy is returned, so callers cannot
// modify the environment's secret.
func (e *environment) CookieSecret() []byte {
return bytes.Clone(e.cookieSecret[:])
}
// Stop implements Environment. Cleanup runs at most once; its cancel cause
// wraps ErrCauseManualStop and records the file and line of the caller of
// Stop(). For benchmarks, the timer is stopped while Stop() runs and started
// again when it returns.
func (e *environment) Stop() {
	if bench, isBenchmark := e.t.(*testing.B); isBenchmark {
		// when calling Stop() manually, ensure we aren't timing this
		bench.StopTimer()
		defer bench.StartTimer()
	}
	_, callerFile, callerLine, _ := runtime.Caller(1)
	e.cleanupOnce.Do(func() {
		cause := fmt.Errorf("%w (caller: %s:%d)", ErrCauseManualStop, callerFile, callerLine)
		e.cleanup(cause)
	})
}
// Pause implements Environment. It blocks until SIGINT is received, logging
// a message before and after. The signal registration is removed before
// returning, so that a later ctrl+c is no longer captured by this call.
func (e *environment) Pause() {
	e.t.Log("\x1b[31m*** test manually paused; continue with ctrl+c ***\x1b[0m")
	c := make(chan os.Signal, 1)
	signal.Notify(c, syscall.SIGINT)
	defer signal.Stop(c)
	<-c
	e.t.Log("\x1b[31mctrl+c received, continuing\x1b[0m")
}
// onTestCleanup is the cleanup hook registered with the test in Start(). It
// runs cleanup with ErrCauseTestCleanup unless cleanup has already run (for
// example because Stop() was called).
func (e *environment) onTestCleanup() {
e.cleanupOnce.Do(func() {
e.cleanup(ErrCauseTestCleanup)
})
}
// cleanup shuts the environment down. If the test has failed and
// pauseOnFailure is set, it first blocks until SIGINT is received. It then
// advances to Stopping, cancels the root context with cancelCause, waits for
// all tasks, ends the root span, shuts down the trace context and advances to
// Stopped. Errors collected along the way are expected to wrap cancelCause;
// anything else is reported as a test failure.
func (e *environment) cleanup(cancelCause error) {
	e.debugf("stop: %s", cancelCause.Error())

	if e.t.Failed() && e.pauseOnFailure {
		e.t.Log("\x1b[31m*** pausing on test failure; continue with ctrl+c ***\x1b[0m")
		sigCh := make(chan os.Signal, 1)
		signal.Notify(sigCh, syscall.SIGINT)
		<-sigCh
		e.t.Log("\x1b[31mctrl+c received, continuing\x1b[0m")
		signal.Stop(sigCh)
	}

	e.advanceState(Stopping)
	e.cancel(cancelCause)

	var collected []error
	if waitErr := e.taskErrGroup.Wait(); waitErr != nil {
		collected = append(collected, fmt.Errorf("error waiting for tasks: %w", waitErr))
	}
	e.rootSpan.End()
	if shutdownErr := trace.ShutdownContext(e.Context()); shutdownErr != nil {
		collected = append(collected, fmt.Errorf("error shutting down trace context: %w", shutdownErr))
	}

	e.advanceState(Stopped)
	// Wait for any additional callbacks created during stopped callbacks
	// (for consistency, we consider the stopped state to "end" here)
	e.stateChangeBlockers.Wait()
	e.debugf("stop: done")

	// The joined error can be nil if e.g. the only task is the internal
	// pomerium task, which returns a nil error if it exits cleanly
	if joined := errors.Join(collected...); joined != nil {
		assert.ErrorIs(e.t, joined, cancelCause)
	}
}
// Add implements Environment. What happens depends on the current state:
//   - NotRunning: the modifier is recorded together with its call site and
//     attached to the environment's context; adding the same modifier twice
//     fails the test.
//   - Running: the modifier is applied to the live configuration.
//   - Starting, Stopping, Stopped: panics.
func (e *environment) Add(m Modifier) {
	e.t.Helper()
	caller := getCaller()
	e.debugf("Add: %T from %s", m, caller)

	switch e.GetState() {
	case Starting:
		panic("test bug: cannot call Add() before Start() has returned")
	case Stopping, Stopped:
		panic("test bug: cannot call Add() after Stop()")
	case Running:
		e.debugf("Add: state=Running; calling ModifyConfig")
		e.src.ModifyConfig(e.Context(), m)
	case NotRunning:
		for _, existing := range e.mods {
			if existing.Value == m {
				e.t.Fatalf("test bug: duplicate modifier added\nfirst added by: %s", existing.Caller)
			}
		}
		entry := WithCaller[Modifier]{Caller: caller, Value: m}
		e.mods = append(e.mods, entry)
		e.debugf("Add: state=NotRunning; calling Attach")
		m.Attach(e.Context())
	default:
		panic(fmt.Sprintf("unexpected environment state: %s", e.GetState()))
	}
}
// AddTask implements Environment. The task is recorded together with its
// call site; registered tasks are started by Start(). Adding the same task
// twice fails the test.
func (e *environment) AddTask(t Task) {
	e.t.Helper()
	// Resolve the call site once; it is used both for logging and for the
	// recorded entry.
	caller := getCaller()
	e.debugf("AddTask: %T from %s", t, caller)
	for _, task := range e.tasks {
		if task.Value == t {
			e.t.Fatalf("test bug: duplicate task added\nfirst added by: %s", task.Caller)
		}
	}
	e.tasks = append(e.tasks, WithCaller[Task]{
		Caller: caller,
		Value:  t,
	})
}
// AddUpstream implements Environment. The upstream is registered both as a
// modifier and as a task. getCaller skips this method's stack frame, so both
// registrations are attributed to the caller of AddUpstream.
func (e *environment) AddUpstream(up Upstream) {
e.t.Helper()
caller := getCaller()
e.debugf("AddUpstream: %T from %s", up, caller)
e.Add(up)
e.AddTask(up)
}
// ReportError implements health.Provider. Any reported health error panics
// with a message containing the check, the error and the attributes.
func (e *environment) ReportError(check health.Check, err error, attributes ...health.Attr) {
// note: don't use e.t.Fatal here, it will deadlock
panic(fmt.Sprintf("%s: %v %v", check, err, attributes))
}
// ReportOK implements health.Provider. Successful checks are ignored.
func (e *environment) ReportOK(_ health.Check, _ ...health.Attr) {}
// advanceState moves the environment to newState and notifies the listeners
// registered for it. It panics if newState is not strictly greater than the
// current state.
//
// Before changing the state it waits for callbacks that were registered for
// the current state via OnStateChanged. Each listener for the new state is
// run in its own goroutine, and advanceState returns only after all of them
// have returned.
//
// NOTE(review): stateMu is held for the whole call, including while the
// listeners run, so a listener that calls GetState() or OnStateChanged()
// would appear to block forever - confirm listeners never do this.
func (e *environment) advanceState(newState EnvironmentState) {
e.stateMu.Lock()
defer e.stateMu.Unlock()
if newState <= e.state {
panic(fmt.Sprintf("internal test environment bug: changed state to <= current: newState=%s, current=%s", newState, e.state))
}
// Let callbacks registered for the current state finish first.
e.stateChangeBlockers.Wait()
e.debugf("state %s -> %s", e.state.String(), newState.String())
e.state = newState
if len(e.stateChangeListeners[newState]) > 0 {
e.debugf("notifying %d listeners of state change", len(e.stateChangeListeners[newState]))
var wg sync.WaitGroup
for _, listener := range e.stateChangeListeners[newState] {
wg.Add(1)
go func() {
// Each callback gets its own span, tagged with the new state.
_, span := e.tracer.Start(e.Context(), "State Change Callback")
span.SetAttributes(attribute.String("state", newState.String()))
defer span.End()
defer wg.Done()
listener()
}()
}
wg.Wait()
e.debugf("done notifying state change listeners")
}
}
// GetState implements Environment. The state is read while holding stateMu.
func (e *environment) GetState() EnvironmentState {
e.stateMu.Lock()
defer e.stateMu.Unlock()
return e.state
}
// OnStateChanged implements Environment.
//
//   - If state is earlier than the current state, it panics.
//   - If state equals the current state, the callback is started immediately
//     in its own goroutine and blocks the next state change until it returns;
//     the returned function always reports false.
//   - Otherwise the callback is stored and run when the environment reaches
//     state. The returned function prevents the callback from running and
//     reports whether it did so (false if the callback already started).
func (e *environment) OnStateChanged(state EnvironmentState, callback func()) (cancel func() bool) {
	e.stateMu.Lock()
	defer e.stateMu.Unlock()
	_, file, line, _ := runtime.Caller(1)
	switch {
	case state < e.state:
		panic(fmt.Sprintf("test bug: OnStateChanged called with state %s which is < current state (%s)", state, e.state))
	case state == e.state:
		e.stateChangeBlockers.Add(1)
		e.debugf("invoking callback for current state (state: %s, caller: %s:%d)", state.String(), file, line)
		go func() {
			defer e.stateChangeBlockers.Done()
			callback()
		}()
		return func() bool { return false }
	default:
		// A single flag is shared by the listener and the stop function;
		// whichever flips it first wins, so the callback either runs or is
		// stopped, never both.
		canceled := &atomic.Bool{}
		e.stateChangeListeners[state] = append(e.stateChangeListeners[state], func() {
			if canceled.CompareAndSwap(false, true) {
				e.debugf("invoking state change callback (caller: %s:%d)", file, line)
				callback()
			}
		})
		return func() bool {
			e.debugf("stopped state change callback (state: %s, caller: %s:%d)", state.String(), file, line)
			return canceled.CompareAndSwap(false, true)
		}
	}
}
// getCaller returns the "file:line" of a calling stack frame. The optional
// skip value is passed to [runtime.Callers]; the default of 3 selects the
// caller of the function that called getCaller. Frames belonging to
// (*environment).AddUpstream are passed over, so that registrations made
// through it are attributed to its caller. An empty string is returned if no
// suitable frame is found.
func getCaller(skip ...int) string {
	depth := 3
	if len(skip) > 0 {
		depth = skip[0]
	}
	pcs := make([]uintptr, 8)
	n := runtime.Callers(depth, pcs)
	if n == 0 {
		return ""
	}
	// Only the entries actually written by runtime.Callers are valid.
	frames := runtime.CallersFrames(pcs[:n])
	for {
		// "more" reports whether another frame follows; the frame returned
		// alongside it is still valid, so it is examined before checking more.
		frame, more := frames.Next()
		if frame.PC != 0 && path.Base(frame.Function) != "testenv.(*environment).AddUpstream" {
			return fmt.Sprintf("%s:%d", frame.File, frame.Line)
		}
		if !more {
			return ""
		}
	}
}
// wildcardDomain returns the domain of the first wildcard entry in names,
// i.e. the part after the leading "*." (for example "example.com" for
// "*.example.com"). Empty names and names that are not of the form
// "*.<domain>" are ignored. It panics if names contains no wildcard entry.
func wildcardDomain(names []string) string {
	for _, name := range names {
		if domain, ok := strings.CutPrefix(name, "*."); ok && domain != "" {
			return domain
		}
	}
	panic("test bug: no wildcard domain in certificate")
}
// isSilent reports whether output should be suppressed by default for t.
// Only benchmarks are silenced, and only when the process arguments do not
// contain the exact argument "-test.v=true".
func isSilent(t testing.TB) bool {
	if _, isBenchmark := t.(*testing.B); !isBenchmark {
		return false
	}
	return !slices.Contains(os.Args, "-test.v=true")
}
// configSource is an in-memory [config.Source] holding a single
// configuration and the listeners to notify when it is modified.
type configSource struct {
// mu guards cfg and lis.
mu sync.Mutex
cfg *config.Config
lis []config.ChangeListener
}
// Compile-time check that configSource implements config.Source.
var _ config.Source = (*configSource)(nil)
// GetConfig implements config.Source.
// The stored pointer is returned as-is; no copy is made.
func (src *configSource) GetConfig() *config.Config {
src.mu.Lock()
defer src.mu.Unlock()
return src.cfg
}
// OnConfigChange implements config.Source.
// The listener is appended to the list notified by ModifyConfig; the context
// argument is unused.
func (src *configSource) OnConfigChange(_ context.Context, li config.ChangeListener) {
src.mu.Lock()
defer src.mu.Unlock()
src.lis = append(src.lis, li)
}
// ModifyConfig updates the current configuration by applying a [Modifier].
// The modifier changes the stored configuration in place, after which every
// registered listener is called, in registration order, with the updated
// configuration.
//
// NOTE(review): listeners are invoked while mu is held, so a listener that
// calls GetConfig or OnConfigChange on this source would appear to deadlock -
// confirm listeners never do this.
func (src *configSource) ModifyConfig(ctx context.Context, m Modifier) {
src.mu.Lock()
defer src.mu.Unlock()
m.Modify(src.cfg)
for _, li := range src.lis {
li(ctx, src.cfg)
}
}