Initial test environment implementation

This commit is contained in:
Joe Kralicky 2024-08-22 17:15:20 -04:00
parent 6591e3f539
commit 79ba9fcf52
No known key found for this signature in database
GPG key ID: 75C4875F34A9FB79
14 changed files with 1085 additions and 14 deletions

View file

@ -18,10 +18,74 @@ import (
"google.golang.org/grpc/status"
"github.com/pomerium/pomerium/config"
"github.com/pomerium/pomerium/internal/testenv"
"github.com/pomerium/pomerium/internal/testenv/upstreams"
"github.com/pomerium/pomerium/internal/testenv/values"
"github.com/pomerium/pomerium/pkg/cmd/pomerium"
"github.com/pomerium/pomerium/pkg/netutil"
)
func TestH2C_v2(t *testing.T) {
env := testenv.New(t)
up := upstreams.GRPC(insecure.NewCredentials())
grpc_testing.RegisterTestServiceServer(up, interop.NewTestServer())
http := up.Route().
From(env.SubdomainURL("grpc-http")).
To(values.Bind(up.Port(), func(port int) string {
// override the target protocol to use http://
return fmt.Sprintf("http://127.0.0.1:%d", port)
})).
Policy(func(p *config.Policy) { p.AllowPublicUnauthenticatedAccess = true })
h2c := up.Route().
From(env.SubdomainURL("grpc-h2c")).
Policy(func(p *config.Policy) { p.AllowPublicUnauthenticatedAccess = true })
env.AddUpstream(up)
env.Start()
t.Run("h2c", func(t *testing.T) {
t.Parallel()
recorder := env.NewLogRecorder()
cc := up.Dial(h2c)
client := grpc_testing.NewTestServiceClient(cc)
_, err := client.EmptyCall(env.Context(), &grpc_testing.Empty{})
require.NoError(t, err)
cc.Close()
recorder.Match([]map[string]any{
{
"service": "envoy",
"path": "/grpc.testing.TestService/EmptyCall",
"message": "http-request",
"response-code-details": "via_upstream",
},
})
})
t.Run("http", func(t *testing.T) {
t.Parallel()
recorder := env.NewLogRecorder()
cc := up.Dial(http)
client := grpc_testing.NewTestServiceClient(cc)
_, err := client.UnaryCall(env.Context(), &grpc_testing.SimpleRequest{})
require.Error(t, err)
cc.Close()
recorder.Match([]map[string]any{
{
"service": "envoy",
"path": "/grpc.testing.TestService/UnaryCall",
"message": "http-request",
"response-code-details": "upstream_reset_before_response_started{protocol_error}",
},
})
})
}
func TestH2C(t *testing.T) {
if testing.Short() {
t.SkipNow()

2
go.mod
View file

@ -50,6 +50,7 @@ require (
github.com/peterbourgon/ff/v3 v3.4.0
github.com/pomerium/csrf v1.7.0
github.com/pomerium/datasource v0.18.2-0.20221108160055-c6134b5ed524
github.com/pomerium/protoutil v0.0.0-20240813175624-47b7ac43ff46
github.com/pomerium/webauthn v0.0.0-20240603205124-0428df511172
github.com/prometheus/client_golang v1.19.1
github.com/prometheus/client_model v0.6.1
@ -177,7 +178,6 @@ require (
github.com/morikuni/aec v1.0.0 // indirect
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect
github.com/onsi/ginkgo v1.16.5 // indirect
github.com/onsi/gomega v1.30.0 // indirect
github.com/opencontainers/go-digest v1.0.0 // indirect
github.com/opencontainers/image-spec v1.1.0 // indirect
github.com/opencontainers/runc v1.1.12 // indirect

6
go.sum
View file

@ -481,8 +481,8 @@ github.com/onsi/ginkgo v1.16.5/go.mod h1:+E8gABHa3K6zRBolWtd+ROzc/U5bkGt0FwiG042
github.com/onsi/gomega v1.4.3/go.mod h1:ex+gbHU/CVuBBDIJjb2X0qEXbFg53c61hWP/1CpauHY=
github.com/onsi/gomega v1.7.1/go.mod h1:XdKZgCCFLUoM/7CFJVPcG8C1xQ1AJ0vpAezJrB7JYyY=
github.com/onsi/gomega v1.10.1/go.mod h1:iN09h71vgCQne3DLsj+A5owkum+a2tYe+TOCB1ybHNo=
github.com/onsi/gomega v1.30.0 h1:hvMK7xYz4D3HapigLTeGdId/NcfQx1VHMJc60ew99+8=
github.com/onsi/gomega v1.30.0/go.mod h1:9sxs+SwGrKI0+PWe4Fxa9tFQQBG5xSsSbMXOI8PPpoQ=
github.com/onsi/gomega v1.34.1 h1:EUMJIKUjM8sKjYbtxQI9A4z2o+rruxnzNvpknOXie6k=
github.com/onsi/gomega v1.34.1/go.mod h1:kU1QgUvBDLXBJq618Xvm2LUX6rSAfRaFRTcdOeDLwwY=
github.com/open-policy-agent/opa v0.67.1 h1:rzy26J6g1X+CKknAcx0Vfbt41KqjuSzx4E0A8DAZf3E=
github.com/open-policy-agent/opa v0.67.1/go.mod h1:aqKlHc8E2VAAylYE9x09zJYr/fYzGX+JKne89UGqFzk=
github.com/opencontainers/go-digest v1.0.0 h1:apOUWs51W5PlhuyGyz9FCeeBIOUDA/6nW8Oi/yOhh5U=
@ -522,6 +522,8 @@ github.com/pomerium/csrf v1.7.0 h1:Qp4t6oyEod3svQtKfJZs589mdUTWKVf7q0PgCKYCshY=
github.com/pomerium/csrf v1.7.0/go.mod h1:hAPZV47mEj2T9xFs+ysbum4l7SF1IdrryYaY6PdoIqw=
github.com/pomerium/datasource v0.18.2-0.20221108160055-c6134b5ed524 h1:3YQY1sb54tEEbr0L73rjHkpLB0IB6qh3zl1+XQbMLis=
github.com/pomerium/datasource v0.18.2-0.20221108160055-c6134b5ed524/go.mod h1:7fGbUYJnU8RcxZJvUvhukOIBv1G7LWDAHMfDxAf5+Y0=
github.com/pomerium/protoutil v0.0.0-20240813175624-47b7ac43ff46 h1:NRTg8JOXCxcIA1lAgD74iYud0rbshbWOB3Ou4+Huil8=
github.com/pomerium/protoutil v0.0.0-20240813175624-47b7ac43ff46/go.mod h1:QqZmx6ZgPxz18va7kqoT4t/0yJtP7YFIDiT/W2n2fZ4=
github.com/pomerium/webauthn v0.0.0-20240603205124-0428df511172 h1:TqoPqRgXSHpn+tEJq6H72iCS5pv66j3rPprThUEZg0E=
github.com/pomerium/webauthn v0.0.0-20240603205124-0428df511172/go.mod h1:kBQ45E9LluzW7FP1Scn3esaiS2WVbvNRLMOTHareZNQ=
github.com/power-devops/perfstat v0.0.0-20240221224432-82ca36839d55 h1:o4JXh1EVt9k/+g42oCprj/FisM4qX9L3sZB3upGN2ZU=

View file

@ -69,7 +69,7 @@ type Server struct {
}
// NewServer creates a new Server. Listener ports are chosen by the OS.
func NewServer(cfg *config.Config, metricsMgr *config.MetricsManager, eventsMgr *events.Manager) (*Server, error) {
func NewServer(ctx context.Context, cfg *config.Config, metricsMgr *config.MetricsManager, eventsMgr *events.Manager) (*Server, error) {
srv := &Server{
metricsMgr: metricsMgr,
EventsMgr: eventsMgr,
@ -80,6 +80,10 @@ func NewServer(cfg *config.Config, metricsMgr *config.MetricsManager, eventsMgr
httpRouter: atomicutil.NewValue(mux.NewRouter()),
}
ctx = log.WithContext(ctx, func(c zerolog.Context) zerolog.Context {
return c.Str("server_name", cfg.Options.Services)
})
var err error
// setup gRPC
@ -96,7 +100,11 @@ func NewServer(cfg *config.Config, metricsMgr *config.MetricsManager, eventsMgr
srv.GRPCServer = grpc.NewServer(
grpc.StatsHandler(telemetry.NewGRPCServerStatsHandler(cfg.Options.Services)),
grpc.ChainUnaryInterceptor(requestid.UnaryServerInterceptor(), ui),
grpc.ChainStreamInterceptor(requestid.StreamServerInterceptor(), si),
grpc.ChainStreamInterceptor(
log.StreamServerInterceptor(log.Ctx(ctx)),
requestid.StreamServerInterceptor(),
si,
),
)
reflection.Register(srv.GRPCServer)
srv.registerAccessLogHandlers()
@ -152,10 +160,6 @@ func NewServer(cfg *config.Config, metricsMgr *config.MetricsManager, eventsMgr
srv.reproxy,
)
ctx := log.WithContext(context.Background(), func(c zerolog.Context) zerolog.Context {
return c.Str("server_name", cfg.Options.Services)
})
res, err := srv.buildDiscoveryResources(ctx)
if err != nil {
return nil, err

View file

@ -38,7 +38,7 @@ func TestServerHTTP(t *testing.T) {
cfg.Options.SharedKey = "JDNjY2ITDlARvNaQXjc2Djk+GA6xeCy4KiozmZfdbTs="
src := config.NewStaticSource(cfg)
srv, err := NewServer(cfg, config.NewMetricsManager(ctx, src), events.New())
srv, err := NewServer(ctx, cfg, config.NewMetricsManager(ctx, src), events.New())
require.NoError(t, err)
go srv.Run(ctx)

View file

@ -96,7 +96,7 @@ func (src *ConfigSource) rebuild(ctx context.Context, firstTime firstTime) {
cfg := src.underlyingConfig.Clone()
// start the updater
src.runUpdater(cfg)
src.runUpdater(ctx, cfg)
now = time.Now()
err := src.buildNewConfigLocked(ctx, cfg)
@ -234,7 +234,7 @@ func (src *ConfigSource) addPolicies(ctx context.Context, cfg *config.Config, po
cfg.Options.AdditionalPolicies = append(cfg.Options.AdditionalPolicies, additionalPolicies...)
}
func (src *ConfigSource) runUpdater(cfg *config.Config) {
func (src *ConfigSource) runUpdater(ctx context.Context, cfg *config.Config) {
sharedKey, _ := cfg.Options.GetSharedKey()
connectionOptions := &grpc.OutboundOptions{
OutboundPort: cfg.OutboundPort,
@ -257,7 +257,6 @@ func (src *ConfigSource) runUpdater(cfg *config.Config) {
src.cancel = nil
}
ctx := context.Background()
ctx, src.cancel = context.WithCancel(ctx)
cc, err := src.outboundGRPCConnection.Get(ctx, connectionOptions)

View file

@ -5,7 +5,9 @@ import (
"net/http"
"time"
"github.com/pomerium/protoutil/streams"
"github.com/rs/zerolog"
"google.golang.org/grpc"
"github.com/pomerium/pomerium/internal/middleware/responsewriter"
"github.com/pomerium/pomerium/pkg/telemetry/requestid"
@ -121,3 +123,11 @@ func HeadersHandler(headers []string) func(next http.Handler) http.Handler {
})
}
}
func StreamServerInterceptor(lg *zerolog.Logger) grpc.StreamServerInterceptor {
return func(srv any, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
s := streams.NewServerStreamWithContext(ss)
s.SetContext(lg.WithContext(s.Ctx))
return handler(srv, s)
}
}

View file

@ -0,0 +1,329 @@
package testenv
import (
"context"
"crypto/tls"
"crypto/x509"
"errors"
"fmt"
"os"
"path"
"path/filepath"
"runtime"
"strconv"
"sync"
"testing"
"github.com/pomerium/pomerium/config"
"github.com/pomerium/pomerium/internal/log"
"github.com/pomerium/pomerium/internal/testenv/values"
"github.com/pomerium/pomerium/pkg/cmd/pomerium"
"github.com/pomerium/pomerium/pkg/health"
"github.com/pomerium/pomerium/pkg/netutil"
"github.com/rs/zerolog"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/sync/errgroup"
)
// Environment is a lightweight integration test fixture that runs Pomerium
// in-process.
type Environment interface {
// Context returns the environment's root context. This context holds a
// top-level logger scoped to this environment. It will be canceled when
// Stop() is called, or during test cleanup.
Context() context.Context
// TempDir returns a unique temp directory for this context. Calling this
// function multiple times returns the same path.
TempDir() string
// ServerCAs returns a new [*x509.CertPool] containing the root CA certificate
// used to sign the server cert and other test certificates.
ServerCAs() *x509.CertPool
// ServerCert returns the Pomerium server's certificate.
ServerCert() *tls.Certificate
// Add adds the given [Modifier] to the environment. All modifiers will be
// invoked upon calling Start() to apply individual modifications to the
// configuration before starting the Pomerium server.
Add(m Modifier)
// AddTask adds the given [Task] to the environment. All tasks will be
// started in separate goroutines upon calling Start(). If any tasks exit
// with an error, the environment will be stopped and the test will fail.
AddTask(r Task)
// AddUpstream adds the given [Upstream] to the environment. This function is
// equivalent to calling both Add() and AddTask() with the upstream, but
// improves readability.
AddUpstream(u Upstream)
// Start starts the test environment, and adds a call to Stop() as a cleanup
// hook to the environment's [testing.T]. All previously added [Modifier]
// instances are invoked in order to build the configuration, and all
// previously added [Task] instances are started in the background.
//
// Calling Start() more than once, Calling Start() after Stop(), or calling
// any of the Add* functions after Start() will panic.
Start()
// Stop stops the test environment. Calling this function more than once has
// no effect. It is usually not necessary to call Stop() directly unless you
// need to stop the test environment before the test is completed.
Stop()
// SubdomainURL returns a string [values.Value] which will contain a complete
// URL for the given subdomain of the server's domain (given by its serving
// certificate), including the 'https://' scheme and random http server port.
// This value will only be resolved some time after Start() is called, and
// can be used as the 'from' value for routes.
SubdomainURL(subdomain string) values.Value[string]
// NewLogRecorder returns a new [*LogRecorder] and starts capturing logs for
// the Pomerium server and Envoy.
NewLogRecorder(opts ...LogRecorderOption) *LogRecorder
}
type environment struct {
t testing.TB
tempDir string
domain string
ports Ports
workspaceFolder string
ctx context.Context
cancel context.CancelCauseFunc
cleanupOnce sync.Once
logWriter *log.MultiWriter
mods []WithCaller[Modifier]
tasks []WithCaller[Task]
taskErrGroup *errgroup.Group
}
func New(t testing.TB) Environment {
if testing.Short() {
t.Helper()
t.Skip("test environment disabled in short mode")
}
workspaceFolder, err := os.Getwd()
require.NoError(t, err)
for {
if _, err := os.Stat(filepath.Join(workspaceFolder, ".git")); err == nil {
break
}
workspaceFolder = filepath.Dir(workspaceFolder)
if workspaceFolder == "/" {
panic("could not find workspace root")
}
}
workspaceFolder, err = filepath.Abs(workspaceFolder)
require.NoError(t, err)
writer := log.NewMultiWriter()
writer.Add(os.Stdout)
logger := zerolog.New(writer).With().Timestamp().Logger().Level(zerolog.DebugLevel)
ctx, cancel := context.WithCancelCause(logger.WithContext(context.Background()))
taskErrGroup, ctx := errgroup.WithContext(ctx)
e := &environment{
t: t,
tempDir: t.TempDir(),
ports: Ports{
http: values.Deferred[int](),
},
workspaceFolder: workspaceFolder,
ctx: ctx,
cancel: cancel,
logWriter: writer,
taskErrGroup: taskErrGroup,
}
health.SetProvider(e)
require.NoError(t, os.Mkdir(filepath.Join(e.tempDir, "certs"), 0o777))
copyFile := func(src, dstRel string) {
data, err := os.ReadFile(src)
require.NoError(t, err)
require.NoError(t, os.WriteFile(filepath.Join(e.tempDir, dstRel), data, 0o666))
}
certsToCopy := []string{
"trusted.pem",
"trusted-key.pem",
"ca.pem",
}
for _, crt := range certsToCopy {
copyFile(filepath.Join(workspaceFolder, "integration/tpl/files", crt), filepath.Join("certs/", filepath.Base(crt)))
}
e.domain = wildcardDomain(e.ServerCert().Leaf.DNSNames)
return e
}
type WithCaller[T any] struct {
Caller string
Value T
}
type Ports struct {
http values.MutableValue[int]
}
func (e *environment) TempDir() string {
return e.tempDir
}
func (e *environment) Context() context.Context {
return ContextWithEnv(e.ctx, e)
}
func (e *environment) SubdomainURL(subdomain string) values.Value[string] {
return values.Bind(e.ports.http, func(port int) string {
return fmt.Sprintf("https://%s.%s:%d", subdomain, e.domain, port)
})
}
func (e *environment) ServerCAs() *x509.CertPool {
pool := x509.NewCertPool()
caCert, err := os.ReadFile(filepath.Join(e.tempDir, "certs", "ca.pem"))
require.NoError(e.t, err)
pool.AppendCertsFromPEM(caCert)
return pool
}
func (e *environment) ServerCert() *tls.Certificate {
serverCert, err := tls.LoadX509KeyPair(
filepath.Join(e.tempDir, "certs", "trusted.pem"),
filepath.Join(e.tempDir, "certs", "trusted-key.pem"),
)
require.NoError(e.t, err)
return &serverCert
}
// Used as the context's cancel cause during normal cleanup
var ErrCauseTestCleanup = errors.New("test cleanup")
// Used as the context's cancel cause when Stop() is called
var ErrCauseManualStop = errors.New("Stop() called")
func (e *environment) Start() {
e.t.Cleanup(e.cleanup)
cfg := &config.Config{
Options: config.NewDefaultOptions(),
}
ports, err := netutil.AllocatePorts(7)
require.NoError(e.t, err)
port0, _ := strconv.Atoi(ports[0])
e.ports.http.Resolve(port0)
cfg.Options.LogLevel = config.LogLevelInfo
cfg.Options.ProxyLogLevel = config.LogLevelInfo
cfg.Options.Addr = fmt.Sprintf("127.0.0.1:%d", port0)
cfg.Options.CertFile = filepath.Join(e.tempDir, "certs", "trusted.pem")
cfg.Options.KeyFile = filepath.Join(e.tempDir, "certs", "trusted-key.pem")
cfg.AllocatePorts(*(*[6]string)(ports[1:]))
e.AddTask(TaskFunc(func(ctx context.Context) error {
src := config.NewStaticSource(cfg)
for _, mod := range e.mods {
mod.Value.Modify(cfg)
require.NoError(e.t, cfg.Options.Validate(), "invoking modifier resulted in an invalid configuration:\nadded by: "+mod.Caller)
}
return pomerium.Run(e.ctx, src)
}))
for i, task := range e.tasks {
log.Ctx(e.ctx).Debug().Str("caller", task.Caller).Msgf("starting task %d", i)
e.taskErrGroup.Go(func() error {
defer log.Ctx(e.ctx).Debug().Str("caller", task.Caller).Msgf("task %d exited", i)
return task.Value.Run(e.ctx)
})
}
}
func (e *environment) Stop() {
e.cleanupOnce.Do(func() {
e.cancel(ErrCauseManualStop)
err := e.taskErrGroup.Wait()
assert.ErrorIs(e.t, err, ErrCauseManualStop)
})
}
func (e *environment) cleanup() {
e.cleanupOnce.Do(func() {
e.cancel(ErrCauseTestCleanup)
err := e.taskErrGroup.Wait()
assert.ErrorIs(e.t, err, ErrCauseTestCleanup)
})
}
func (e *environment) Add(c Modifier) {
e.t.Helper()
for _, mod := range e.mods {
if mod.Value == c {
e.t.Fatalf("test bug: duplicate modifier added\nfirst added by: %s", mod.Caller)
}
}
e.mods = append(e.mods, WithCaller[Modifier]{
Caller: getCaller(),
Value: c,
})
c.Attach(e.Context())
}
func (e *environment) AddTask(r Task) {
e.t.Helper()
for _, task := range e.tasks {
if task.Value == r {
e.t.Fatalf("test bug: duplicate task added\nfirst added by: %s", task.Caller)
}
}
e.tasks = append(e.tasks, WithCaller[Task]{
Caller: getCaller(),
Value: r,
})
}
func (e *environment) AddUpstream(up Upstream) {
e.t.Helper()
e.Add(up)
e.AddTask(up)
}
// ReportError implements health.Provider.
func (e *environment) ReportError(check health.Check, err error, attributes ...health.Attr) {
// note: don't use e.t.Fatal here, it will deadlock
panic(fmt.Sprintf("%s: %v %v", check, err, attributes))
}
// ReportOK implements health.Provider.
func (e *environment) ReportOK(check health.Check, attributes ...health.Attr) {
}
func getCaller(skip ...int) string {
if len(skip) == 0 {
skip = append(skip, 3)
}
callers := make([]uintptr, 8)
runtime.Callers(skip[0], callers)
frames := runtime.CallersFrames(callers)
var caller string
for {
next, ok := frames.Next()
if !ok {
break
}
if path.Base(next.Function) == "testenv.(*environment).AddUpstream" {
continue
}
caller = fmt.Sprintf("%s:%d", next.File, next.Line)
break
}
return caller
}
func wildcardDomain(names []string) string {
for _, name := range names {
if name[0] == '*' {
return name[2:]
}
}
panic("test bug: no wildcard domain in certificate")
}

210
internal/testenv/logs.go Normal file
View file

@ -0,0 +1,210 @@
package testenv
import (
"bufio"
"bytes"
"context"
"encoding/json"
"fmt"
"sync"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// LogRecorder captures logs from the test environment. It can be created at
// any time by calling [Environment.NewLogRecorder], and captures logs until
// one of Close(), Logs(), or Match() is called, which stops recording. See the
// documentation for each method for more details.
type LogRecorder struct {
LogRecorderOptions
t testing.TB
buf *bytes.Buffer
recordedLogs []map[string]any
closeOnce func()
collectLogsOnce sync.Once
}
type LogRecorderOptions struct {
filters []func(map[string]any) bool
}
type LogRecorderOption func(*LogRecorderOptions)
func (o *LogRecorderOptions) apply(opts ...LogRecorderOption) {
for _, op := range opts {
op(o)
}
}
// WithFilters applies one or more filter predicates to the logger. If there
// are filters present, they will be called in order when a log is received,
// and if any filter returns false for a given log, it will be discarded.
func WithFilters(filters ...func(map[string]any) bool) LogRecorderOption {
return func(o *LogRecorderOptions) {
o.filters = filters
}
}
func (e *environment) NewLogRecorder(opts ...LogRecorderOption) *LogRecorder {
options := LogRecorderOptions{}
options.apply(opts...)
lr := &LogRecorder{
LogRecorderOptions: options,
t: e.t,
buf: &bytes.Buffer{},
}
e.logWriter.Add(lr.buf)
lr.closeOnce = sync.OnceFunc(func() {
// wait for envoy access logs, which flush on a 1 second interval
time.Sleep(1100 * time.Millisecond)
e.logWriter.Remove(lr.buf)
})
context.AfterFunc(e.ctx, lr.closeOnce)
return lr
}
type (
// OpenMap is an alias for map[string]any, and can be used to semantically
// represent a map that must contain at least the given entries, but may
// also contain additional entries.
OpenMap = map[string]any
// ClosedMap is a map[string]any that can be used to semantically represent
// a map that must contain the given entries exactly, and no others.
ClosedMap map[string]any
)
// Close stops the log recorder. After calling this method, Logs() or Match()
// can be called to inspect the logs that were captured.
func (lr *LogRecorder) Close() {
lr.closeOnce()
}
func (lr *LogRecorder) collectLogs() {
lr.closeOnce()
lr.collectLogsOnce.Do(func() {
recordedLogs := []map[string]any{}
scan := bufio.NewScanner(lr.buf)
for scan.Scan() {
log := scan.Bytes()
m := map[string]any{}
decoder := json.NewDecoder(bytes.NewReader(log))
decoder.UseNumber()
require.NoError(lr.t, decoder.Decode(&m))
for _, filter := range lr.filters {
if !filter(m) {
continue
}
}
recordedLogs = append(recordedLogs, m)
}
lr.recordedLogs = recordedLogs
})
}
// Logs stops the log recorder (if it is not already stopped), then returns
// the logs that were captured as structured map[string]any objects.
func (lr *LogRecorder) Logs() []map[string]any {
lr.collectLogs()
return lr.recordedLogs
}
// Match stops the log recorder (if it is not already stopped), then asserts
// that the given expected logs were captured. The expected logs may contain
// partial or complete log entries. By default, logs must only match the fields
// given, and may contain additional fields that will be ignored. For details,
// see [OpenMap] and [ClosedMap]. As a special case, using [json.Number] as the
// expected value will convert the actual value to a string before comparison.
func (lr *LogRecorder) Match(expectedLogs []map[string]any) {
lr.collectLogs()
var match func(expected, actual map[string]any, open bool) (bool, int)
match = func(expected, actual map[string]any, open bool) (bool, int) {
score := 0
for key, value := range expected {
actualValue, ok := actual[key]
if !ok {
return false, score
}
score++
switch actualValue := actualValue.(type) {
case map[string]any:
switch value := value.(type) {
case ClosedMap:
ok, s := match(value, actualValue, false)
score += s * 2
if !ok {
return false, score
}
case OpenMap:
ok, s := match(value, actualValue, true)
score += s
if !ok {
return false, score
}
default:
return false, score
}
case string:
switch value := value.(type) {
case string:
if value != actualValue {
return false, score
}
score++
default:
return false, score
}
case json.Number:
if fmt.Sprint(value) != actualValue.String() {
return false, score
}
score++
default:
panic(fmt.Sprintf("test bug: add check for type %T in assertMatchingLogs", actualValue))
}
}
if !open && len(expected) != len(actual) {
return false, score
}
return true, score
}
for _, expectedLog := range expectedLogs {
found := false
highScore, highScoreIdxs := 0, []int{}
for i, actualLog := range lr.recordedLogs {
if ok, score := match(expectedLog, actualLog, true); ok {
found = true
break
} else if score > highScore {
highScore = score
highScoreIdxs = []int{i}
} else if score == highScore {
highScoreIdxs = append(highScoreIdxs, i)
}
}
if len(highScoreIdxs) > 0 {
expectedLogBytes, _ := json.MarshalIndent(expectedLog, "", " ")
if len(highScoreIdxs) == 1 {
actualLogBytes, _ := json.MarshalIndent(lr.recordedLogs[highScoreIdxs[0]], "", " ")
assert.True(lr.t, found, "expected log not found: \n%s\n\nclosest match:\n%s\n",
string(expectedLogBytes), string(actualLogBytes))
} else {
closestMatches := []string{}
for _, i := range highScoreIdxs {
bytes, _ := json.MarshalIndent(lr.recordedLogs[i], "", " ")
closestMatches = append(closestMatches, string(bytes))
}
assert.True(lr.t, found, "expected log not found: \n%s\n\nclosest matches:\n%s\n", string(expectedLogBytes), closestMatches)
}
} else {
expectedLogBytes, _ := json.MarshalIndent(expectedLog, "", " ")
assert.True(lr.t, found, "expected log not found: %s", string(expectedLogBytes))
}
}
}

60
internal/testenv/route.go Normal file
View file

@ -0,0 +1,60 @@
package testenv
import (
"net/url"
"github.com/pomerium/pomerium/config"
"github.com/pomerium/pomerium/internal/testenv/values"
)
// PolicyRoute is a [Route] implementation suitable for most common use cases
// that can be used in implementations of [Upstream].
type PolicyRoute struct {
DefaultAttach
from values.Value[string]
to values.List[string]
edits []func(*config.Policy)
}
// Modify implements Route.
func (b *PolicyRoute) Modify(cfg *config.Config) {
to := make(config.WeightedURLs, 0, len(b.to))
for _, u := range b.to {
u, err := url.Parse(u.Value())
if err != nil {
panic(err)
}
to = append(to, config.WeightedURL{URL: *u})
}
p := config.Policy{
From: b.from.Value(),
To: to,
}
for _, edit := range b.edits {
edit(&p)
}
cfg.Options.Policies = append(cfg.Options.Policies, p)
}
// From implements Route.
func (b *PolicyRoute) From(fromUrl values.Value[string]) Route {
b.from = fromUrl
return b
}
// To implements Route.
func (b *PolicyRoute) To(toUrl values.Value[string]) Route {
b.to = append(b.to, toUrl)
return b
}
// To implements Route.
func (b *PolicyRoute) Policy(edit func(*config.Policy)) Route {
b.edits = append(b.edits, edit)
return b
}
// To implements Route.
func (b *PolicyRoute) URL() values.Value[string] {
return b.from
}

127
internal/testenv/types.go Normal file
View file

@ -0,0 +1,127 @@
package testenv
import (
"context"
"github.com/pomerium/pomerium/config"
"github.com/pomerium/pomerium/internal/testenv/values"
)
type envContextKeyType struct{}
var envContextKey envContextKeyType
func EnvFromContext(ctx context.Context) Environment {
return ctx.Value(envContextKey).(Environment)
}
func ContextWithEnv(ctx context.Context, env Environment) context.Context {
return context.WithValue(ctx, envContextKey, env)
}
// A Modifier is an object whose presence in the test affects the Pomerium
// configuration in some way. When the test environment is started, a
// [*config.Config] is constructed by calling each added Modifier in order.
//
// For additional details, see [Environment.Add] and [Environment.Start].
type Modifier interface {
// Attach is called by an [Environment] (before Modify) to propagate the
// environment's context.
Attach(ctx context.Context)
// Modify is called by an [Environment] to mutate its configuration in some
// way required by this Modifier.
Modify(cfg *config.Config)
}
// DefaultAttach should be embedded in types implementing [Modifier] to
// automatically obtain environment context details and caller information.
type DefaultAttach struct {
env Environment
caller string
}
func (d *DefaultAttach) Env() Environment {
d.CheckAttached()
return d.env
}
func (d *DefaultAttach) Attach(ctx context.Context) {
if d.env != nil {
panic("internal test environment bug: Attach called twice")
}
d.env = EnvFromContext(ctx)
if d.env == nil {
panic("test bug: no environment in context")
}
}
func (d *DefaultAttach) CheckAttached() {
if d.env == nil {
if d.caller != "" {
panic("test bug: missing a call to Add for the object created at: " + d.caller)
}
panic("test bug: not attached (possibly missing a call to Add)")
}
}
func (d *DefaultAttach) RecordCaller() {
d.caller = getCaller(4)
}
type Modifiers []Modifier
func (m Modifiers) Modify(cfg *config.Config) {
for _, mod := range m {
mod.Modify(cfg)
}
}
type ModifierFunc func(cfg *config.Config)
func (f ModifierFunc) Modify(cfg *config.Config) {
f(cfg)
}
// Task represents a background task that can be added to an [Environment] to
// have it run automatically on startup.
//
// For additional details, see [Environment.AddTask] and [Environment.Start].
type Task interface {
Run(ctx context.Context) error
}
type TaskFunc func(ctx context.Context) error
func (f TaskFunc) Run(ctx context.Context) error {
return f(ctx)
}
// Upstream represents an upstream server. It is both a [Task] and a [Modifier]
// and can be added to an environment using [Environment.AddUpstream]. From an
// Upstream instance, new routes can be created (which automatically adds the
// necessary route/policy entries to the config), and used within a test to
// easily make requests to the routes with implementation-specific clients.
type Upstream interface {
Modifier
Task
Port() values.Value[int]
Route() RouteStub
}
// A Route represents a route from a source URL to a destination URL. A route is
// typically created by calling [Upstream.Route].
type Route interface {
Modifier
URL() values.Value[string]
To(toUrl values.Value[string]) Route
Policy(edit func(*config.Policy)) Route
// add more methods here as they become needed
}
// RouteStub represents an incomplete [Route]. Providing a URL by calling its
// From() method will return a [Route], from which further configuration can
// be made.
type RouteStub interface {
From(fromUrl values.Value[string]) Route
}

View file

@ -0,0 +1,146 @@
package upstreams
import (
"context"
"fmt"
"net"
"strings"
"github.com/pomerium/pomerium/config"
"github.com/pomerium/pomerium/internal/testenv"
"github.com/pomerium/pomerium/internal/testenv/values"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials"
)
type Options struct {
serverOpts []grpc.ServerOption
}
type Option func(*Options)
func (o *Options) apply(opts ...Option) {
for _, op := range opts {
op(o)
}
}
func ServerOpts(opt ...grpc.ServerOption) Option {
return func(o *Options) {
o.serverOpts = append(o.serverOpts, opt...)
}
}
// GRPCUpstream represents a GRPC server which can be used as the target for
// one or more Pomerium routes in a test environment.
//
// This upstream implements [grpc.ServiceRegistrar], and can be used similarly
// in the same way as [*grpc.Server] to register services before it is started.
//
// Any [testenv.Route] instances created from this upstream can be referenced
// in the Dial() method to establish a connection to that route.
type GRPCUpstream interface {
testenv.Upstream
grpc.ServiceRegistrar
Dial(r testenv.Route, dialOpts ...grpc.DialOption) *grpc.ClientConn
}
type grpcUpstream struct {
Options
testenv.DefaultAttach
serverPort values.MutableValue[int]
creds credentials.TransportCredentials
routes testenv.Modifiers
services []service
}
var (
_ testenv.Upstream = (*grpcUpstream)(nil)
_ grpc.ServiceRegistrar = (*grpcUpstream)(nil)
)
// GRPC Creates a new GRPC upstream server.
func GRPC(creds credentials.TransportCredentials, opts ...Option) GRPCUpstream {
options := Options{}
options.apply(opts...)
up := &grpcUpstream{
Options: options,
creds: creds,
serverPort: values.Deferred[int](),
}
up.RecordCaller()
return up
}
type service struct {
desc *grpc.ServiceDesc
impl any
}
func (g *grpcUpstream) Port() values.Value[int] {
return g.serverPort
}
// Modify implements testenv.Upstream.
func (g *grpcUpstream) Modify(cfg *config.Config) {
g.routes.Modify(cfg)
}
// RegisterService implements grpc.ServiceRegistrar.
func (g *grpcUpstream) RegisterService(desc *grpc.ServiceDesc, impl any) {
g.services = append(g.services, service{desc, impl})
}
// Route implements testenv.Upstream.
func (g *grpcUpstream) Route() testenv.RouteStub {
r := &testenv.PolicyRoute{}
var protocol string
switch g.creds.Info().SecurityProtocol {
case "insecure":
protocol = "h2c"
default:
protocol = "https"
}
r.To(values.Bind(g.serverPort, func(port int) string {
return fmt.Sprintf("%s://127.0.0.1:%d", protocol, port)
}))
g.routes = append(g.routes, r)
return r
}
// Start implements testenv.Upstream.
func (g *grpcUpstream) Run(ctx context.Context) error {
listener, err := net.Listen("tcp4", "127.0.0.1:0")
if err != nil {
return err
}
g.serverPort.Resolve(listener.Addr().(*net.TCPAddr).Port)
server := grpc.NewServer(append(g.serverOpts, grpc.Creds(g.creds))...)
for _, s := range g.services {
server.RegisterService(s.desc, s.impl)
}
errC := make(chan error, 1)
go func() {
errC <- server.Serve(listener)
}()
select {
case <-ctx.Done():
server.Stop()
return context.Cause(ctx)
case err := <-errC:
return err
}
}
func (g *grpcUpstream) Dial(r testenv.Route, dialOpts ...grpc.DialOption) *grpc.ClientConn {
dialOpts = append(dialOpts,
grpc.WithTransportCredentials(credentials.NewClientTLSFromCert(g.Env().ServerCAs(), "")),
grpc.WithDefaultCallOptions(grpc.WaitForReady(true)),
)
cc, err := grpc.NewClient(strings.TrimPrefix(r.URL().Value(), "https://"), dialOpts...)
if err != nil {
panic(err)
}
return cc
}

View file

@ -0,0 +1,120 @@
package values
import (
"math/rand/v2"
"sync"
)
type value[T any] struct {
f func() T
ready bool
cond *sync.Cond
}
// A Value is a container for a single value of type T, whose initialization is
// performed the first time Value() is called. Subsequent calls will return the
// same value. The Value() function may block until the value is ready on the
// first call. Values are safe to use concurrently.
type Value[T any] interface {
Value() T
}
// MutableValue is the read-write counterpart to [Value], created by calling
// [Deferred] for some type T. Calling Resolve() or ResolveFunc() will set
// the value and unblock any waiting calls to Value().
type MutableValue[T any] interface {
Value[T]
Resolve(value T)
ResolveFunc(fOnce func() T)
}
// Deferred creates a new read-write [MutableValue] for some type T,
// representing a value whose initialization may be deferred to a later time.
// Once the value is available, call [MutableValue.Resolve] or
// [MutableValue.ResolveFunc] to unblock any waiting calls to Value().
func Deferred[T any]() MutableValue[T] {
return &value[T]{
cond: sync.NewCond(&sync.Mutex{}),
}
}
// Const creates a read-only [Value] which will become available immediately
// upon calling Value() for the first time; it will never block.
func Const[T any](t T) Value[T] {
return &value[T]{
f: func() T { return t },
ready: true,
cond: sync.NewCond(&sync.Mutex{}),
}
}
func (p *value[T]) Value() T {
p.cond.L.Lock()
defer p.cond.L.Unlock()
for !p.ready {
p.cond.Wait()
}
return p.f()
}
func (p *value[T]) ResolveFunc(fOnce func() T) {
p.cond.L.Lock()
p.f = sync.OnceValue(fOnce)
p.ready = true
p.cond.L.Unlock()
p.cond.Broadcast()
}
func (p *value[T]) Resolve(value T) {
p.ResolveFunc(func() T { return value })
}
// Bind creates a new [MutableValue] whose ultimate value depends on the result
// of another [Value] that may not yet be available. When Value() is called on
// the result, it will cascade and trigger the full chain of initialization
// functions necessary to produce the final value.
//
// Care should be taken when using this function, as improper use can lead to
// deadlocks and cause values to never become available.
func Bind[T any, U any](dt Value[T], callback func(value T) U) MutableValue[U] {
du := Deferred[U]()
du.ResolveFunc(func() U {
return callback(dt.Value())
})
return du
}
// Bind2 is like [Bind], but can accept two input values. The result will only
// become available once all input values become available.
//
// This function blocks to wait for each input value in sequence, but in a
// random order. Do not rely on the order of evaluation of the input values.
func Bind2[T any, U any, V any](dt Value[T], du Value[U], callback func(value1 T, value2 U) V) MutableValue[V] {
dv := Deferred[V]()
dv.ResolveFunc(func() V {
if rand.IntN(2) == 0 {
return callback(dt.Value(), du.Value())
}
u := du.Value()
t := dt.Value()
return callback(t, u)
})
return dv
}
// List is a container for a slice of [Value] of type T, and is also a [Value]
// itself, for convenience. The Value() function will return a []T containing
// all resolved values for each element in the slice.
//
// A List's Value() function blocks to wait for each element in the slice in
// sequence, but in a random order. Do not rely on the order of evaluation of
// the slice elements.
type List[T any] []Value[T]
func (s List[T]) Value() []T {
values := make([]T, len(s))
for _, i := range rand.Perm(len(values)) {
values[i] = s[i].Value()
}
return values
}

View file

@ -71,7 +71,7 @@ func Run(ctx context.Context, src config.Source) error {
cfg := src.GetConfig()
// setup the control plane
controlPlane, err := controlplane.NewServer(cfg, metricsMgr, eventsMgr)
controlPlane, err := controlplane.NewServer(ctx, cfg, metricsMgr, eventsMgr)
if err != nil {
return fmt.Errorf("error creating control plane: %w", err)
}