New integration test fixtures (#5233)

* Initial test environment implementation

* linter pass

* wip: update request latency test

* bugfixes

* Fix logic race in envoy process monitor when canceling context

* skip tests using test environment on non-linux
Joe Kralicky 2024-11-05 14:31:40 -05:00 committed by GitHub
parent 3d958ff9c5
commit 526e2a58d6
29 changed files with 2972 additions and 101 deletions


@@ -76,3 +76,6 @@ issues:
- text: "G112:"
linters:
- gosec
- text: "G402: TLS MinVersion too low."
linters:
- gosec


@@ -26,9 +26,9 @@ import (
const maxActiveDownstreamConnections = 50000
var (
envoyAdminAddressPath = filepath.Join(os.TempDir(), "pomerium-envoy-admin.sock")
envoyAdminAddressMode = 0o600
envoyAdminClusterName = "pomerium-envoy-admin"
envoyAdminAddressSockName = "pomerium-envoy-admin.sock"
envoyAdminAddressMode = 0o600
envoyAdminClusterName = "pomerium-envoy-admin"
)
// BuildBootstrap builds the bootstrap config.
@@ -95,7 +95,7 @@ func (b *Builder) BuildBootstrapAdmin(cfg *config.Config) (admin *envoy_config_b
admin.Address = &envoy_config_core_v3.Address{
Address: &envoy_config_core_v3.Address_Pipe{
Pipe: &envoy_config_core_v3.Pipe{
Path: envoyAdminAddressPath,
Path: filepath.Join(os.TempDir(), envoyAdminAddressSockName),
Mode: uint32(envoyAdminAddressMode),
},
},
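The new form matters because os.TempDir() consults TMPDIR on every call, whereas the old package-level envoyAdminAddressPath was computed once at package init, before a test could override TMPDIR. A minimal sketch of the resulting test pattern (the test name is hypothetical; the socket name matches the constant above):

package envoyconfig

import (
	"os"
	"path/filepath"
	"testing"
)

func TestAdminSocketPath(t *testing.T) {
	dir := t.TempDir()
	// t.Setenv restores TMPDIR after the test; os.TempDir() re-reads the
	// variable on each call, so the admin socket path now resolves under dir.
	t.Setenv("TMPDIR", dir)

	got := filepath.Join(os.TempDir(), "pomerium-envoy-admin.sock")
	if want := filepath.Join(dir, "pomerium-envoy-admin.sock"); got != want {
		t.Fatalf("admin socket path = %q, want %q", got, want)
	}
}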


@@ -12,6 +12,7 @@ import (
)
func TestBuilder_BuildBootstrapAdmin(t *testing.T) {
t.Setenv("TMPDIR", "/tmp")
b := New("local-grpc", "local-http", "local-metrics", filemgr.NewManager(), nil)
t.Run("valid", func(t *testing.T) {
adminCfg, err := b.BuildBootstrapAdmin(&config.Config{
@@ -25,7 +26,7 @@ func TestBuilder_BuildBootstrapAdmin(t *testing.T) {
"address": {
"pipe": {
"mode": 384,
"path": "`+envoyAdminAddressPath+`"
"path": "/tmp/`+envoyAdminAddressSockName+`"
}
}
}


@@ -2,6 +2,8 @@ package envoyconfig
import (
"context"
"os"
"path/filepath"
envoy_config_cluster_v3 "github.com/envoyproxy/go-control-plane/envoy/config/cluster/v3"
envoy_config_core_v3 "github.com/envoyproxy/go-control-plane/envoy/config/core/v3"
@@ -23,7 +25,8 @@ func (b *Builder) buildEnvoyAdminCluster(_ context.Context, _ *config.Config) (*
Address: &envoy_config_core_v3.Address{
Address: &envoy_config_core_v3.Address_Pipe{
Pipe: &envoy_config_core_v3.Pipe{
Path: envoyAdminAddressPath,
Path: filepath.Join(os.TempDir(), envoyAdminAddressSockName),
Mode: uint32(envoyAdminAddressMode),
},
},
},


@@ -23,11 +23,7 @@ import (
func Test_BuildClusters(t *testing.T) {
// The admin address path is based on os.TempDir(), which will vary from
// system to system, so replace this with a stable location.
originalEnvoyAdminAddressPath := envoyAdminAddressPath
envoyAdminAddressPath = "/tmp/pomerium-envoy-admin.sock"
t.Cleanup(func() {
envoyAdminAddressPath = originalEnvoyAdminAddressPath
})
t.Setenv("TMPDIR", "/tmp")
opts := config.NewDefaultOptions()
ctx := context.Background()


@@ -1,102 +1,159 @@
package envoyconfig_test
import (
"context"
"fmt"
"net"
"io"
"net/http"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/credentials/insecure"
"google.golang.org/grpc/interop"
"google.golang.org/grpc/interop/grpc_testing"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/status"
"github.com/pomerium/pomerium/config"
"github.com/pomerium/pomerium/pkg/cmd/pomerium"
"github.com/pomerium/pomerium/pkg/netutil"
"github.com/pomerium/pomerium/internal/testenv"
"github.com/pomerium/pomerium/internal/testenv/scenarios"
"github.com/pomerium/pomerium/internal/testenv/snippets"
"github.com/pomerium/pomerium/internal/testenv/upstreams"
"github.com/pomerium/pomerium/internal/testenv/values"
)
func TestH2C(t *testing.T) {
if testing.Short() {
t.SkipNow()
}
env := testenv.New(t)
ctx, ca := context.WithCancel(context.Background())
up := upstreams.GRPC(insecure.NewCredentials())
grpc_testing.RegisterTestServiceServer(up, interop.NewTestServer())
opts := config.NewDefaultOptions()
listener, err := (&net.ListenConfig{}).Listen(ctx, "tcp", "127.0.0.1:0")
require.NoError(t, err)
ports, err := netutil.AllocatePorts(7)
require.NoError(t, err)
urls, err := config.ParseWeightedUrls("http://"+listener.Addr().String(), "h2c://"+listener.Addr().String())
require.NoError(t, err)
opts.Addr = fmt.Sprintf("127.0.0.1:%s", ports[0])
opts.Routes = []config.Policy{
{
From: fmt.Sprintf("https://grpc-http.localhost.pomerium.io:%s", ports[0]),
To: urls[:1],
AllowPublicUnauthenticatedAccess: true,
},
{
From: fmt.Sprintf("https://grpc-h2c.localhost.pomerium.io:%s", ports[0]),
To: urls[1:],
AllowPublicUnauthenticatedAccess: true,
},
}
opts.CertFile = "../../integration/tpl/files/trusted.pem"
opts.KeyFile = "../../integration/tpl/files/trusted-key.pem"
cfg := &config.Config{Options: opts}
cfg.AllocatePorts(*(*[6]string)(ports[1:]))
http := up.Route().
From(env.SubdomainURL("grpc-http")).
To(values.Bind(up.Port(), func(port int) string {
// override the target protocol to use http://
return fmt.Sprintf("http://127.0.0.1:%d", port)
})).
Policy(func(p *config.Policy) { p.AllowPublicUnauthenticatedAccess = true })
server := grpc.NewServer(grpc.Creds(insecure.NewCredentials()))
grpc_testing.RegisterTestServiceServer(server, interop.NewTestServer())
go server.Serve(listener)
h2c := up.Route().
From(env.SubdomainURL("grpc-h2c")).
Policy(func(p *config.Policy) { p.AllowPublicUnauthenticatedAccess = true })
errC := make(chan error, 1)
go func() {
errC <- pomerium.Run(ctx, config.NewStaticSource(cfg))
}()
t.Cleanup(func() {
ca()
assert.ErrorIs(t, context.Canceled, <-errC)
})
tlsConfig, err := credentials.NewClientTLSFromFile("../../integration/tpl/files/ca.pem", "")
require.NoError(t, err)
env.AddUpstream(up)
env.Start()
snippets.WaitStartupComplete(env)
t.Run("h2c", func(t *testing.T) {
t.Parallel()
recorder := env.NewLogRecorder()
cc, err := grpc.Dial(fmt.Sprintf("grpc-h2c.localhost.pomerium.io:%s", ports[0]), grpc.WithTransportCredentials(tlsConfig))
require.NoError(t, err)
cc := up.Dial(h2c)
client := grpc_testing.NewTestServiceClient(cc)
var md metadata.MD
_, err = client.EmptyCall(ctx, &grpc_testing.Empty{}, grpc.WaitForReady(true), grpc.Header(&md))
_, err := client.EmptyCall(env.Context(), &grpc_testing.Empty{})
require.NoError(t, err)
cc.Close()
assert.NoError(t, err)
assert.Contains(t, md, "x-envoy-upstream-service-time")
recorder.Match([]map[string]any{
{
"service": "envoy",
"path": "/grpc.testing.TestService/EmptyCall",
"message": "http-request",
"response-code-details": "via_upstream",
},
})
})
t.Run("http", func(t *testing.T) {
t.Parallel()
recorder := env.NewLogRecorder()
cc, err := grpc.Dial(fmt.Sprintf("grpc-http.localhost.pomerium.io:%s", ports[0]), grpc.WithTransportCredentials(tlsConfig))
require.NoError(t, err)
cc := up.Dial(http)
client := grpc_testing.NewTestServiceClient(cc)
var md metadata.MD
_, err = client.EmptyCall(ctx, &grpc_testing.Empty{}, grpc.WaitForReady(true), grpc.Trailer(&md))
_, err := client.UnaryCall(env.Context(), &grpc_testing.SimpleRequest{})
require.Error(t, err)
cc.Close()
stat := status.Convert(err)
assert.NotNil(t, stat)
assert.Equal(t, stat.Code(), codes.Unavailable)
assert.NotContains(t, md, "x-envoy-upstream-service-time")
assert.Contains(t, stat.Message(), "<!DOCTYPE html>")
assert.Contains(t, stat.Message(), "upstream_reset_before_response_started{protocol_error}")
recorder.Match([]map[string]any{
{
"service": "envoy",
"path": "/grpc.testing.TestService/UnaryCall",
"message": "http-request",
"response-code-details": "upstream_reset_before_response_started{protocol_error}",
},
})
})
}
func TestHTTP(t *testing.T) {
env := testenv.New(t)
up := upstreams.HTTP(nil)
up.Handle("/foo", func(w http.ResponseWriter, _ *http.Request) {
fmt.Fprintln(w, "hello world")
})
route := up.Route().
From(env.SubdomainURL("http")).
Policy(func(p *config.Policy) { p.AllowPublicUnauthenticatedAccess = true })
env.AddUpstream(up)
env.Start()
recorder := env.NewLogRecorder()
resp, err := up.Get(route, upstreams.Path("/foo"))
require.NoError(t, err)
defer resp.Body.Close()
data, err := io.ReadAll(resp.Body)
require.NoError(t, err)
assert.Equal(t, "hello world\n", string(data))
recorder.Match([]map[string]any{
{
"service": "envoy",
"path": "/foo",
"method": "GET",
"message": "http-request",
"response-code-details": "via_upstream",
},
})
}
func TestClientCert(t *testing.T) {
env := testenv.New(t)
env.Add(scenarios.DownstreamMTLS(config.MTLSEnforcementRejectConnection))
up := upstreams.HTTP(nil)
up.Handle("/foo", func(w http.ResponseWriter, _ *http.Request) {
fmt.Fprintln(w, "hello world")
})
clientCert := env.NewClientCert()
route := up.Route().
From(env.SubdomainURL("http")).
PPL(fmt.Sprintf(`{"allow":{"and":["client_certificate":{"fingerprint":%q}]}}`, clientCert.Fingerprint()))
env.AddUpstream(up)
env.Start()
recorder := env.NewLogRecorder()
resp, err := up.Get(route, upstreams.Path("/foo"), upstreams.ClientCert(clientCert))
require.NoError(t, err)
defer resp.Body.Close()
data, err := io.ReadAll(resp.Body)
require.NoError(t, err)
assert.Equal(t, "hello world\n", string(data))
recorder.Match([]map[string]any{
{
"service": "envoy",
"path": "/foo",
"method": "GET",
"message": "http-request",
"response-code-details": "via_upstream",
"client-certificate": clientCert,
},
})
}


@@ -280,6 +280,7 @@
"endpoint": {
"address": {
"pipe": {
"mode": 384,
"path": "/tmp/pomerium-envoy-admin.sock"
}
}


@@ -129,6 +129,7 @@ func newManager(
for {
select {
case <-ctx.Done():
cache.Stop()
return
case <-ticker.C:
err := mgr.renewConfigCerts(ctx)


@@ -0,0 +1,54 @@
package benchmarks_test
import (
"fmt"
"testing"
"time"
"github.com/pomerium/pomerium/internal/testenv"
"github.com/pomerium/pomerium/internal/testenv/snippets"
"github.com/pomerium/pomerium/internal/testenv/upstreams"
)
func BenchmarkStartupLatency(b *testing.B) {
for _, n := range []int{1, 10, 100, 1000, 10000} {
b.Run(fmt.Sprintf("routes=%d", n), func(b *testing.B) {
for range b.N {
env := testenv.New(b)
up := upstreams.HTTP(nil)
for i := range n {
up.Route().
From(env.SubdomainURL(fmt.Sprintf("from-%d", i))).
PPL(`{"allow":{"and":[{"accept":"true"}]}}`)
}
env.AddUpstream(up)
env.Start()
snippets.WaitStartupComplete(env, 60*time.Minute)
env.Stop()
}
})
}
}
func BenchmarkAppendRoutes(b *testing.B) {
for _, n := range []int{1, 10, 100, 1000, 10000} {
b.Run(fmt.Sprintf("routes=%d", n), func(b *testing.B) {
for range b.N {
env := testenv.New(b)
up := upstreams.HTTP(nil)
env.AddUpstream(up)
env.Start()
snippets.WaitStartupComplete(env)
for i := range n {
env.Add(up.Route().
From(env.SubdomainURL(fmt.Sprintf("from-%d", i))).
PPL(fmt.Sprintf(`{"allow":{"and":["email":{"is":"user-%d@example.com"}]}}`, i)))
}
env.Stop()
}
})
}
}


@@ -0,0 +1,87 @@
package benchmarks_test
import (
"flag"
"fmt"
"io"
"math/rand/v2"
"net/http"
"testing"
"github.com/pomerium/pomerium/internal/testenv"
"github.com/pomerium/pomerium/internal/testenv/scenarios"
"github.com/pomerium/pomerium/internal/testenv/snippets"
"github.com/pomerium/pomerium/internal/testenv/upstreams"
"github.com/stretchr/testify/assert"
)
var (
numRoutes int
dumpErrLogs bool
)
func init() {
flag.IntVar(&numRoutes, "routes", 100, "number of routes")
flag.BoolVar(&dumpErrLogs, "dump-err-logs", false, "if the test fails, write all captured logs to a file (testdata/<test-name>)")
}
func TestRequestLatency(t *testing.T) {
env := testenv.New(t, testenv.Silent())
users := []*scenarios.User{}
for i := range numRoutes {
users = append(users, &scenarios.User{
Email: fmt.Sprintf("user%d@example.com", i),
FirstName: fmt.Sprintf("Firstname%d", i),
LastName: fmt.Sprintf("Lastname%d", i),
})
}
env.Add(scenarios.NewIDP(users))
up := upstreams.HTTP(nil)
up.Handle("/", func(w http.ResponseWriter, _ *http.Request) {
w.Write([]byte("OK"))
})
routes := make([]testenv.Route, numRoutes)
for i := range numRoutes {
routes[i] = up.Route().
From(env.SubdomainURL(fmt.Sprintf("from-%d", i))).
PPL(fmt.Sprintf(`{"allow":{"and":["email":{"is":"user%d@example.com"}]}}`, i))
}
env.AddUpstream(up)
env.Start()
snippets.WaitStartupComplete(env)
out := testing.Benchmark(func(b *testing.B) {
b.ReportAllocs()
b.RunParallel(func(pb *testing.PB) {
var rec *testenv.LogRecorder
if dumpErrLogs {
rec = env.NewLogRecorder(testenv.WithSkipCloseDelay())
}
for pb.Next() {
idx := rand.IntN(numRoutes)
resp, err := up.Get(routes[idx], upstreams.AuthenticateAs(fmt.Sprintf("user%d@example.com", idx)))
if !assert.NoError(b, err) {
filename := "TestRequestLatency_err.log"
if dumpErrLogs {
rec.DumpToFile(filename)
b.Logf("test logs written to %s", filename)
}
return
}
assert.Equal(b, resp.StatusCode, 200)
body, err := io.ReadAll(resp.Body)
resp.Body.Close()
assert.NoError(b, err)
assert.Equal(b, "OK", string(body))
}
})
})
t.Log(out)
t.Logf("req/s: %f", float64(out.N)/out.T.Seconds())
env.Stop()
}


@@ -74,11 +74,12 @@ func NewServer(
cfg *config.Config,
metricsMgr *config.MetricsManager,
eventsMgr *events.Manager,
fileMgr *filemgr.Manager,
) (*Server, error) {
srv := &Server{
metricsMgr: metricsMgr,
EventsMgr: eventsMgr,
filemgr: filemgr.NewManager(),
filemgr: fileMgr,
reproxy: reproxy.New(),
haveSetCapacity: map[string]bool{},
updateConfig: make(chan *config.Config, 1),


@@ -12,6 +12,7 @@ import (
"github.com/stretchr/testify/require"
"github.com/pomerium/pomerium/config"
"github.com/pomerium/pomerium/config/envoyconfig/filemgr"
"github.com/pomerium/pomerium/internal/events"
"github.com/pomerium/pomerium/pkg/netutil"
)
@@ -38,7 +39,7 @@ func TestServerHTTP(t *testing.T) {
cfg.Options.SharedKey = "JDNjY2ITDlARvNaQXjc2Djk+GA6xeCy4KiozmZfdbTs="
src := config.NewStaticSource(cfg)
srv, err := NewServer(ctx, cfg, config.NewMetricsManager(ctx, src), events.New())
srv, err := NewServer(ctx, cfg, config.NewMetricsManager(ctx, src), events.New(), filemgr.NewManager(filemgr.WithCacheDir(t.TempDir())))
require.NoError(t, err)
go srv.Run(ctx)


@@ -7,4 +7,6 @@ var (
DebugDisableZapLogger atomic.Bool
// Debug option to suppress global warnings
DebugDisableGlobalWarnings atomic.Bool
// Debug option to suppress global (non-warning) messages
DebugDisableGlobalMessages atomic.Bool
)


@@ -0,0 +1,838 @@
package testenv
import (
"bytes"
"context"
"crypto/rand"
"crypto/rsa"
"crypto/sha256"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"encoding/asn1"
"encoding/base64"
"encoding/hex"
"errors"
"fmt"
"io"
"math/big"
"math/bits"
"net"
"net/url"
"os"
"os/signal"
"path"
"path/filepath"
"runtime"
"strconv"
"sync"
"syscall"
"testing"
"time"
"github.com/pomerium/pomerium/config"
"github.com/pomerium/pomerium/config/envoyconfig/filemgr"
"github.com/pomerium/pomerium/internal/log"
"github.com/pomerium/pomerium/internal/testenv/values"
"github.com/pomerium/pomerium/pkg/cmd/pomerium"
"github.com/pomerium/pomerium/pkg/grpc/databroker"
"github.com/pomerium/pomerium/pkg/health"
"github.com/pomerium/pomerium/pkg/netutil"
"github.com/pomerium/pomerium/pkg/slices"
"github.com/rs/zerolog"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/sync/errgroup"
"google.golang.org/grpc/grpclog"
)
// Environment is a lightweight integration test fixture that runs Pomerium
// in-process.
type Environment interface {
// Context returns the environment's root context. This context holds a
// top-level logger scoped to this environment. It will be canceled when
// Stop() is called, or during test cleanup.
Context() context.Context
Assert() *assert.Assertions
Require() *require.Assertions
// TempDir returns a unique temp directory for this context. Calling this
// function multiple times returns the same path.
TempDir() string
// CACert returns the test environment's root CA certificate and private key.
CACert() *tls.Certificate
// ServerCAs returns a new [*x509.CertPool] containing the root CA certificate
// used to sign the server cert and other test certificates.
ServerCAs() *x509.CertPool
// ServerCert returns the Pomerium server's certificate and private key.
ServerCert() *tls.Certificate
// NewClientCert generates a new client certificate signed by the root CA
// certificate. One or more optional templates can be given, which can be
// used to set or override certain parameters when creating a certificate,
// including subject, SANs, or extensions. If more than one template is
// provided, they will be applied in order from left to right.
//
// By default (unless overridden in a template), the certificate will have
// its Common Name set to the file:line string of the call site. Calls to
// NewClientCert() on different lines will have different subjects. If
// multiple certs with the same subject are needed, wrap the call to this
// function in another helper function, or separate calls with commas on the
// same line.
NewClientCert(templateOverrides ...*x509.Certificate) *Certificate
NewServerCert(templateOverrides ...*x509.Certificate) *Certificate
AuthenticateURL() values.Value[string]
DatabrokerURL() values.Value[string]
Ports() Ports
SharedSecret() []byte
CookieSecret() []byte
// Add adds the given [Modifier] to the environment. All modifiers will be
// invoked upon calling Start() to apply individual modifications to the
// configuration before starting the Pomerium server.
Add(m Modifier)
// AddTask adds the given [Task] to the environment. All tasks will be
// started in separate goroutines upon calling Start(). If any tasks exit
// with an error, the environment will be stopped and the test will fail.
AddTask(r Task)
// AddUpstream adds the given [Upstream] to the environment. This function is
// equivalent to calling both Add() and AddTask() with the upstream, but
// improves readability.
AddUpstream(u Upstream)
// Start starts the test environment, and adds a call to Stop() as a cleanup
// hook to the environment's [testing.T]. All previously added [Modifier]
// instances are invoked in order to build the configuration, and all
// previously added [Task] instances are started in the background.
//
// Calling Start() more than once, calling Start() after Stop(), or calling
// any of the Add* functions after Start() will panic.
Start()
// Stop stops the test environment. Calling this function more than once has
// no effect. It is usually not necessary to call Stop() directly unless you
// need to stop the test environment before the test is completed.
Stop()
// SubdomainURL returns a string [values.Value] which will contain a complete
// URL for the given subdomain of the server's domain (given by its serving
// certificate), including the 'https://' scheme and random http server port.
// This value will only be resolved some time after Start() is called, and
// can be used as the 'from' value for routes.
SubdomainURL(subdomain string) values.Value[string]
// NewLogRecorder returns a new [*LogRecorder] and starts capturing logs for
// the Pomerium server and Envoy.
NewLogRecorder(opts ...LogRecorderOption) *LogRecorder
// OnStateChanged registers a callback to be invoked when the environment's
// state changes to the given state. The callback is invoked in a separate
// goroutine.
OnStateChanged(state EnvironmentState, callback func())
}
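// Example usage (illustrative sketch, not part of the original file): a
// typical test constructs an environment, registers an upstream with a route,
// and starts it; Stop() is registered automatically as a cleanup hook. The
// helpers mirror those used by the tests elsewhere in this change.
//
//	func TestExample(t *testing.T) {
//		env := testenv.New(t)
//		up := upstreams.HTTP(nil)
//		up.Route().
//			From(env.SubdomainURL("example")).
//			Policy(func(p *config.Policy) { p.AllowPublicUnauthenticatedAccess = true })
//		env.AddUpstream(up) // equivalent to Add(up) followed by AddTask(up)
//		env.Start()
//	}

// Certificate is a [tls.Certificate] issued by the test environment's root CA;
// it is returned by NewClientCert and NewServerCert, with helpers for the
// leaf certificate's SHA-256 fingerprint and SPKI hash.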
type Certificate tls.Certificate
func (c *Certificate) Fingerprint() string {
sum := sha256.Sum256(c.Leaf.Raw)
return hex.EncodeToString(sum[:])
}
func (c *Certificate) SPKIHash() string {
sum := sha256.Sum256(c.Leaf.RawSubjectPublicKeyInfo)
return base64.StdEncoding.EncodeToString(sum[:])
}
type EnvironmentState uint32
const NotRunning EnvironmentState = 0
const (
Starting EnvironmentState = 1 << iota
Running
Stopping
Stopped
)
func (e EnvironmentState) String() string {
switch e {
case NotRunning:
return "NotRunning"
case Starting:
return "Starting"
case Running:
return "Running"
case Stopping:
return "Stopping"
case Stopped:
return "Stopped"
default:
return fmt.Sprintf("EnvironmentState(%d)", e)
}
}
type environment struct {
EnvironmentOptions
t testing.TB
assert *assert.Assertions
require *require.Assertions
tempDir string
domain string
ports Ports
sharedSecret [32]byte
cookieSecret [32]byte
workspaceFolder string
silent bool
ctx context.Context
cancel context.CancelCauseFunc
cleanupOnce sync.Once
logWriter *log.MultiWriter
mods []WithCaller[Modifier]
tasks []WithCaller[Task]
taskErrGroup *errgroup.Group
stateMu sync.Mutex
state EnvironmentState
stateChangeListeners map[EnvironmentState][]func()
src *configSource
}
type EnvironmentOptions struct {
debug bool
pauseOnFailure bool
forceSilent bool
}
type EnvironmentOption func(*EnvironmentOptions)
func (o *EnvironmentOptions) apply(opts ...EnvironmentOption) {
for _, op := range opts {
op(o)
}
}
func Debug(enable ...bool) EnvironmentOption {
if len(enable) == 0 {
enable = append(enable, true)
}
return func(o *EnvironmentOptions) {
o.debug = enable[0]
}
}
func PauseOnFailure(enable ...bool) EnvironmentOption {
if len(enable) == 0 {
enable = append(enable, true)
}
return func(o *EnvironmentOptions) {
o.pauseOnFailure = enable[0]
}
}
func Silent(silent ...bool) EnvironmentOption {
if len(silent) == 0 {
silent = append(silent, true)
}
return func(o *EnvironmentOptions) {
o.forceSilent = silent[0]
}
}
var setGrpcLoggerOnce sync.Once
func New(t testing.TB, opts ...EnvironmentOption) Environment {
if runtime.GOOS != "linux" {
t.Skip("test environment only supported on linux")
}
options := EnvironmentOptions{}
options.apply(opts...)
if testing.Short() {
t.Helper()
t.Skip("test environment disabled in short mode")
}
databroker.DebugUseFasterBackoff.Store(true)
workspaceFolder, err := os.Getwd()
require.NoError(t, err)
for {
if _, err := os.Stat(filepath.Join(workspaceFolder, ".git")); err == nil {
break
}
workspaceFolder = filepath.Dir(workspaceFolder)
if workspaceFolder == "/" {
panic("could not find workspace root")
}
}
workspaceFolder, err = filepath.Abs(workspaceFolder)
require.NoError(t, err)
writer := log.NewMultiWriter()
silent := options.forceSilent || isSilent(t)
if silent {
// this sets the global zap level to fatal, then resets the global zerolog
// level to debug
log.SetLevel(zerolog.FatalLevel)
zerolog.SetGlobalLevel(zerolog.DebugLevel)
} else {
log.SetLevel(zerolog.InfoLevel)
writer.Add(os.Stdout)
}
log.DebugDisableGlobalWarnings.Store(silent)
log.DebugDisableGlobalMessages.Store(silent)
log.DebugDisableZapLogger.Store(silent)
setGrpcLoggerOnce.Do(func() {
grpclog.SetLoggerV2(grpclog.NewLoggerV2WithVerbosity(io.Discard, io.Discard, io.Discard, 0))
})
logger := zerolog.New(writer).With().Timestamp().Logger().Level(zerolog.DebugLevel)
ctx, cancel := context.WithCancelCause(logger.WithContext(context.Background()))
taskErrGroup, ctx := errgroup.WithContext(ctx)
e := &environment{
EnvironmentOptions: options,
t: t,
assert: assert.New(t),
require: require.New(t),
tempDir: t.TempDir(),
ports: Ports{
ProxyHTTP: values.Deferred[int](),
ProxyGRPC: values.Deferred[int](),
GRPC: values.Deferred[int](),
HTTP: values.Deferred[int](),
Outbound: values.Deferred[int](),
Metrics: values.Deferred[int](),
Debug: values.Deferred[int](),
ALPN: values.Deferred[int](),
},
workspaceFolder: workspaceFolder,
silent: silent,
ctx: ctx,
cancel: cancel,
logWriter: writer,
taskErrGroup: taskErrGroup,
}
_, err = rand.Read(e.sharedSecret[:])
require.NoError(t, err)
_, err = rand.Read(e.cookieSecret[:])
require.NoError(t, err)
health.SetProvider(e)
require.NoError(t, os.Mkdir(filepath.Join(e.tempDir, "certs"), 0o777))
copyFile := func(src, dstRel string) {
data, err := os.ReadFile(src)
require.NoError(t, err)
require.NoError(t, os.WriteFile(filepath.Join(e.tempDir, dstRel), data, 0o600))
}
certsToCopy := []string{
"trusted.pem",
"trusted-key.pem",
"ca.pem",
"ca-key.pem",
}
for _, crt := range certsToCopy {
copyFile(filepath.Join(workspaceFolder, "integration/tpl/files", crt), filepath.Join("certs/", filepath.Base(crt)))
}
e.domain = wildcardDomain(e.ServerCert().Leaf.DNSNames)
return e
}
func (e *environment) debugf(format string, args ...any) {
if !e.debug {
return
}
e.t.Logf("\x1b[34m[debug] "+format+"\x1b[0m", args...)
}
type WithCaller[T any] struct {
Caller string
Value T
}
type Ports struct {
ProxyHTTP values.MutableValue[int]
ProxyGRPC values.MutableValue[int]
GRPC values.MutableValue[int]
HTTP values.MutableValue[int]
Outbound values.MutableValue[int]
Metrics values.MutableValue[int]
Debug values.MutableValue[int]
ALPN values.MutableValue[int]
}
func (e *environment) TempDir() string {
return e.tempDir
}
func (e *environment) Context() context.Context {
return ContextWithEnv(e.ctx, e)
}
func (e *environment) Assert() *assert.Assertions {
return e.assert
}
func (e *environment) Require() *require.Assertions {
return e.require
}
func (e *environment) SubdomainURL(subdomain string) values.Value[string] {
return values.Bind(e.ports.ProxyHTTP, func(port int) string {
return fmt.Sprintf("https://%s.%s:%d", subdomain, e.domain, port)
})
}
func (e *environment) AuthenticateURL() values.Value[string] {
return e.SubdomainURL("authenticate")
}
func (e *environment) DatabrokerURL() values.Value[string] {
return values.Bind(e.ports.Outbound, func(port int) string {
return fmt.Sprintf("127.0.0.1:%d", port)
})
}
func (e *environment) Ports() Ports {
return e.ports
}
func (e *environment) CACert() *tls.Certificate {
caCert, err := tls.LoadX509KeyPair(
filepath.Join(e.tempDir, "certs", "ca.pem"),
filepath.Join(e.tempDir, "certs", "ca-key.pem"),
)
require.NoError(e.t, err)
return &caCert
}
func (e *environment) ServerCAs() *x509.CertPool {
pool := x509.NewCertPool()
caCert, err := os.ReadFile(filepath.Join(e.tempDir, "certs", "ca.pem"))
require.NoError(e.t, err)
pool.AppendCertsFromPEM(caCert)
return pool
}
func (e *environment) ServerCert() *tls.Certificate {
serverCert, err := tls.LoadX509KeyPair(
filepath.Join(e.tempDir, "certs", "trusted.pem"),
filepath.Join(e.tempDir, "certs", "trusted-key.pem"),
)
require.NoError(e.t, err)
return &serverCert
}
// Used as the context's cancel cause during normal cleanup
var ErrCauseTestCleanup = errors.New("test cleanup")
// Used as the context's cancel cause when Stop() is called
var ErrCauseManualStop = errors.New("Stop() called")
func (e *environment) Start() {
e.debugf("Start()")
e.advanceState(Starting)
e.t.Cleanup(e.cleanup)
e.t.Setenv("TMPDIR", e.TempDir())
e.debugf("temp dir: %s", e.TempDir())
cfg := &config.Config{
Options: config.NewDefaultOptions(),
}
ports, err := netutil.AllocatePorts(8)
require.NoError(e.t, err)
atoi := func(str string) int {
p, err := strconv.Atoi(str)
if err != nil {
panic(err)
}
return p
}
e.ports.ProxyHTTP.Resolve(atoi(ports[0]))
e.ports.ProxyGRPC.Resolve(atoi(ports[1]))
e.ports.GRPC.Resolve(atoi(ports[2]))
e.ports.HTTP.Resolve(atoi(ports[3]))
e.ports.Outbound.Resolve(atoi(ports[4]))
e.ports.Metrics.Resolve(atoi(ports[5]))
e.ports.Debug.Resolve(atoi(ports[6]))
e.ports.ALPN.Resolve(atoi(ports[7]))
cfg.AllocatePorts(*(*[6]string)(ports[2:]))
cfg.Options.AutocertOptions = config.AutocertOptions{Enable: false}
cfg.Options.Services = "all"
cfg.Options.LogLevel = config.LogLevelDebug
cfg.Options.ProxyLogLevel = config.LogLevelInfo
cfg.Options.Addr = fmt.Sprintf("127.0.0.1:%d", e.ports.ProxyHTTP.Value())
cfg.Options.GRPCAddr = fmt.Sprintf("127.0.0.1:%d", e.ports.ProxyGRPC.Value())
cfg.Options.CAFile = filepath.Join(e.tempDir, "certs", "ca.pem")
cfg.Options.CertFile = filepath.Join(e.tempDir, "certs", "trusted.pem")
cfg.Options.KeyFile = filepath.Join(e.tempDir, "certs", "trusted-key.pem")
cfg.Options.AuthenticateURLString = e.AuthenticateURL().Value()
cfg.Options.DataBrokerStorageType = "memory"
cfg.Options.SharedKey = base64.StdEncoding.EncodeToString(e.sharedSecret[:])
cfg.Options.CookieSecret = base64.StdEncoding.EncodeToString(e.cookieSecret[:])
cfg.Options.AccessLogFields = []log.AccessLogField{
log.AccessLogFieldAuthority,
log.AccessLogFieldDuration,
log.AccessLogFieldForwardedFor,
log.AccessLogFieldIP,
log.AccessLogFieldMethod,
log.AccessLogFieldPath,
log.AccessLogFieldQuery,
log.AccessLogFieldReferer,
log.AccessLogFieldRequestID,
log.AccessLogFieldResponseCode,
log.AccessLogFieldResponseCodeDetails,
log.AccessLogFieldSize,
log.AccessLogFieldUpstreamCluster,
log.AccessLogFieldUserAgent,
log.AccessLogFieldClientCertificate,
}
e.src = &configSource{cfg: cfg}
e.AddTask(TaskFunc(func(ctx context.Context) error {
fileMgr := filemgr.NewManager(filemgr.WithCacheDir(filepath.Join(e.TempDir(), "cache")))
for _, mod := range e.mods {
mod.Value.Modify(cfg)
require.NoError(e.t, cfg.Options.Validate(), "invoking modifier resulted in an invalid configuration:\nadded by: "+mod.Caller)
}
return pomerium.Run(ctx, e.src, pomerium.WithOverrideFileManager(fileMgr))
}))
for i, task := range e.tasks {
log.Ctx(e.ctx).Debug().Str("caller", task.Caller).Msgf("starting task %d", i)
e.taskErrGroup.Go(func() error {
defer log.Ctx(e.ctx).Debug().Str("caller", task.Caller).Msgf("task %d exited", i)
return task.Value.Run(e.ctx)
})
}
runtime.Gosched()
e.advanceState(Running)
}
func (e *environment) NewClientCert(templateOverrides ...*x509.Certificate) *Certificate {
caCert := e.CACert()
priv, err := rsa.GenerateKey(rand.Reader, 2048)
require.NoError(e.t, err)
sn, err := rand.Int(rand.Reader, new(big.Int).Lsh(big.NewInt(1), 128))
require.NoError(e.t, err)
now := time.Now()
tmpl := &x509.Certificate{
SerialNumber: sn,
Subject: pkix.Name{
CommonName: getCaller(),
},
NotBefore: now,
NotAfter: now.Add(12 * time.Hour),
KeyUsage: x509.KeyUsageDigitalSignature,
ExtKeyUsage: []x509.ExtKeyUsage{
x509.ExtKeyUsageClientAuth,
},
BasicConstraintsValid: true,
}
for _, override := range templateOverrides {
tmpl.CRLDistributionPoints = slices.Unique(append(tmpl.CRLDistributionPoints, override.CRLDistributionPoints...))
tmpl.DNSNames = slices.Unique(append(tmpl.DNSNames, override.DNSNames...))
tmpl.EmailAddresses = slices.Unique(append(tmpl.EmailAddresses, override.EmailAddresses...))
tmpl.ExtraExtensions = append(tmpl.ExtraExtensions, override.ExtraExtensions...)
tmpl.IPAddresses = slices.UniqueBy(append(tmpl.IPAddresses, override.IPAddresses...), net.IP.String)
tmpl.URIs = slices.UniqueBy(append(tmpl.URIs, override.URIs...), (*url.URL).String)
tmpl.UnknownExtKeyUsage = slices.UniqueBy(append(tmpl.UnknownExtKeyUsage, override.UnknownExtKeyUsage...), asn1.ObjectIdentifier.String)
seq := override.Subject.ToRDNSequence()
tmpl.Subject.FillFromRDNSequence(&seq)
tmpl.KeyUsage |= override.KeyUsage
tmpl.ExtKeyUsage = slices.Unique(append(tmpl.ExtKeyUsage, override.ExtKeyUsage...))
}
clientCertDER, err := x509.CreateCertificate(rand.Reader, tmpl, caCert.Leaf, priv.Public(), caCert.PrivateKey)
require.NoError(e.t, err)
cert, err := x509.ParseCertificate(clientCertDER)
require.NoError(e.t, err)
e.debugf("provisioned client certificate for %s", cert.Subject.String())
clientCert := &tls.Certificate{
Certificate: [][]byte{cert.Raw, caCert.Leaf.Raw},
PrivateKey: priv,
Leaf: cert,
}
_, err = clientCert.Leaf.Verify(x509.VerifyOptions{
KeyUsages: []x509.ExtKeyUsage{
x509.ExtKeyUsageClientAuth,
},
Roots: e.ServerCAs(),
})
require.NoError(e.t, err, "bug: generated client cert is not valid")
return (*Certificate)(clientCert)
}
func (e *environment) NewServerCert(templateOverrides ...*x509.Certificate) *Certificate {
caCert := e.CACert()
priv, err := rsa.GenerateKey(rand.Reader, 2048)
require.NoError(e.t, err)
sn, err := rand.Int(rand.Reader, new(big.Int).Lsh(big.NewInt(1), 128))
require.NoError(e.t, err)
now := time.Now()
tmpl := &x509.Certificate{
SerialNumber: sn,
NotBefore: now,
NotAfter: now.Add(12 * time.Hour),
KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment,
ExtKeyUsage: []x509.ExtKeyUsage{
x509.ExtKeyUsageServerAuth,
},
BasicConstraintsValid: true,
}
for _, override := range templateOverrides {
tmpl.DNSNames = slices.Unique(append(tmpl.DNSNames, override.DNSNames...))
tmpl.IPAddresses = slices.UniqueBy(append(tmpl.IPAddresses, override.IPAddresses...), net.IP.String)
}
certDER, err := x509.CreateCertificate(rand.Reader, tmpl, caCert.Leaf, priv.Public(), caCert.PrivateKey)
require.NoError(e.t, err)
cert, err := x509.ParseCertificate(certDER)
require.NoError(e.t, err)
e.debugf("provisioned server certificate for %v", cert.DNSNames)
tlsCert := &tls.Certificate{
Certificate: [][]byte{cert.Raw, caCert.Leaf.Raw},
PrivateKey: priv,
Leaf: cert,
}
_, err = tlsCert.Leaf.Verify(x509.VerifyOptions{Roots: e.ServerCAs()})
require.NoError(e.t, err, "bug: generated client cert is not valid")
return (*Certificate)(tlsCert)
}
func (e *environment) SharedSecret() []byte {
return bytes.Clone(e.sharedSecret[:])
}
func (e *environment) CookieSecret() []byte {
return bytes.Clone(e.cookieSecret[:])
}
func (e *environment) Stop() {
if b, ok := e.t.(*testing.B); ok {
// when calling Stop() manually, ensure we aren't timing this
b.StopTimer()
defer b.StartTimer()
}
e.cleanupOnce.Do(func() {
e.debugf("stop: Stop() called manually")
e.advanceState(Stopping)
e.cancel(ErrCauseManualStop)
err := e.taskErrGroup.Wait()
e.advanceState(Stopped)
e.debugf("stop: done waiting")
assert.ErrorIs(e.t, err, ErrCauseManualStop)
})
}
func (e *environment) cleanup() {
e.cleanupOnce.Do(func() {
e.debugf("stop: test cleanup")
if e.t.Failed() {
if e.pauseOnFailure {
e.t.Log("\x1b[31m*** pausing on test failure; continue with ctrl+c ***\x1b[0m")
c := make(chan os.Signal, 1)
signal.Notify(c, syscall.SIGINT)
<-c
e.t.Log("\x1b[31mctrl+c received, continuing\x1b[0m")
signal.Stop(c)
}
}
e.advanceState(Stopping)
e.cancel(ErrCauseTestCleanup)
err := e.taskErrGroup.Wait()
e.advanceState(Stopped)
e.debugf("stop: done waiting")
assert.ErrorIs(e.t, err, ErrCauseTestCleanup)
})
}
func (e *environment) Add(m Modifier) {
e.t.Helper()
caller := getCaller()
e.debugf("Add: %T from %s", m, caller)
switch e.getState() {
case NotRunning:
for _, mod := range e.mods {
if mod.Value == m {
e.t.Fatalf("test bug: duplicate modifier added\nfirst added by: %s", mod.Caller)
}
}
e.mods = append(e.mods, WithCaller[Modifier]{
Caller: caller,
Value: m,
})
e.debugf("Add: state=NotRunning; calling Attach")
m.Attach(e.Context())
case Starting:
panic("test bug: cannot call Add() before Start() has returned")
case Running:
e.debugf("Add: state=Running; calling ModifyConfig")
e.src.ModifyConfig(e.ctx, m)
case Stopped, Stopping:
panic("test bug: cannot call Add() after Stop()")
default:
panic(fmt.Sprintf("unexpected environment state: %s", e.getState()))
}
}
func (e *environment) AddTask(t Task) {
e.t.Helper()
caller := getCaller()
e.debugf("AddTask: %T from %s", t, caller)
for _, task := range e.tasks {
if task.Value == t {
e.t.Fatalf("test bug: duplicate task added\nfirst added by: %s", task.Caller)
}
}
e.tasks = append(e.tasks, WithCaller[Task]{
Caller: getCaller(),
Value: t,
})
}
func (e *environment) AddUpstream(up Upstream) {
e.t.Helper()
caller := getCaller()
e.debugf("AddUpstream: %T from %s", up, caller)
e.Add(up)
e.AddTask(up)
}
// ReportError implements health.Provider.
func (e *environment) ReportError(check health.Check, err error, attributes ...health.Attr) {
// note: don't use e.t.Fatal here, it will deadlock
panic(fmt.Sprintf("%s: %v %v", check, err, attributes))
}
// ReportOK implements health.Provider.
func (e *environment) ReportOK(_ health.Check, _ ...health.Attr) {}
func (e *environment) advanceState(newState EnvironmentState) {
e.stateMu.Lock()
defer e.stateMu.Unlock()
if newState <= e.state {
panic(fmt.Sprintf("internal test environment bug: changed state to <= current: newState=%s, current=%s", newState, e.state))
}
e.debugf("state %s -> %s", e.state.String(), newState.String())
e.state = newState
e.debugf("notifying %d listeners of state change", len(e.stateChangeListeners[newState]))
for _, listener := range e.stateChangeListeners[newState] {
go listener()
}
}
func (e *environment) getState() EnvironmentState {
e.stateMu.Lock()
defer e.stateMu.Unlock()
return e.state
}
func (e *environment) OnStateChanged(state EnvironmentState, callback func()) {
e.stateMu.Lock()
defer e.stateMu.Unlock()
if e.state&state != 0 {
go callback()
return
}
// add change listeners for all states, if there are multiple bits set
if e.stateChangeListeners == nil {
e.stateChangeListeners = map[EnvironmentState][]func(){}
}
for state > 0 {
stateBit := EnvironmentState(1 << bits.TrailingZeros32(uint32(state)))
state &= (state - 1)
e.stateChangeListeners[stateBit] = append(e.stateChangeListeners[stateBit], callback)
}
}
func getCaller(skip ...int) string {
if len(skip) == 0 {
skip = append(skip, 3)
}
callers := make([]uintptr, 8)
runtime.Callers(skip[0], callers)
frames := runtime.CallersFrames(callers)
var caller string
for {
next, ok := frames.Next()
if !ok {
break
}
if path.Base(next.Function) == "testenv.(*environment).AddUpstream" {
continue
}
caller = fmt.Sprintf("%s:%d", next.File, next.Line)
break
}
return caller
}
func wildcardDomain(names []string) string {
for _, name := range names {
if name[0] == '*' {
return name[2:]
}
}
panic("test bug: no wildcard domain in certificate")
}
func isSilent(t testing.TB) bool {
switch t.(type) {
case *testing.B:
return !slices.Contains(os.Args, "-test.v=true")
default:
return false
}
}
type configSource struct {
mu sync.Mutex
cfg *config.Config
lis []config.ChangeListener
}
var _ config.Source = (*configSource)(nil)
// GetConfig implements config.Source.
func (src *configSource) GetConfig() *config.Config {
src.mu.Lock()
defer src.mu.Unlock()
return src.cfg
}
// OnConfigChange implements config.Source.
func (src *configSource) OnConfigChange(_ context.Context, li config.ChangeListener) {
src.mu.Lock()
defer src.mu.Unlock()
src.lis = append(src.lis, li)
}
// ModifyConfig updates the current configuration by applying a [Modifier].
func (src *configSource) ModifyConfig(ctx context.Context, m Modifier) {
src.mu.Lock()
defer src.mu.Unlock()
m.Modify(src.cfg)
for _, li := range src.lis {
li(ctx, src.cfg)
}
}

internal/testenv/logs.go (new file, 391 lines)

@@ -0,0 +1,391 @@
package testenv
import (
"bufio"
"bytes"
"context"
"crypto/tls"
"crypto/x509"
"encoding/json"
"errors"
"fmt"
"io"
"os"
"reflect"
"sync"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// LogRecorder captures logs from the test environment. It can be created at
// any time by calling [Environment.NewLogRecorder], and captures logs until
// one of Close(), Logs(), or Match() is called, which stops recording. See the
// documentation for each method for more details.
type LogRecorder struct {
LogRecorderOptions
t testing.TB
canceled <-chan struct{}
buf *buffer
recordedLogs []map[string]any
removeGlobalWriterOnce func()
collectLogsOnce sync.Once
}
type LogRecorderOptions struct {
filters []func(map[string]any) bool
skipCloseDelay bool
}
type LogRecorderOption func(*LogRecorderOptions)
func (o *LogRecorderOptions) apply(opts ...LogRecorderOption) {
for _, op := range opts {
op(o)
}
}
// WithFilters applies one or more filter predicates to the logger. If there
// are filters present, they will be called in order when a log is received,
// and if any filter returns false for a given log, it will be discarded.
func WithFilters(filters ...func(map[string]any) bool) LogRecorderOption {
return func(o *LogRecorderOptions) {
o.filters = filters
}
}
// WithSkipCloseDelay skips the 1.1 second delay before closing the recorder.
// This delay is normally required to ensure Envoy access logs are flushed,
// but can be skipped if not required.
func WithSkipCloseDelay() LogRecorderOption {
return func(o *LogRecorderOptions) {
o.skipCloseDelay = true
}
}
type buffer struct {
mu *sync.Mutex
underlying bytes.Buffer
cond *sync.Cond
waiting bool
closed bool
}
func newBuffer() *buffer {
mu := &sync.Mutex{}
return &buffer{
mu: mu,
cond: sync.NewCond(mu),
}
}
// Read implements io.ReadWriteCloser.
func (b *buffer) Read(p []byte) (int, error) {
b.mu.Lock()
defer b.mu.Unlock()
for {
n, err := b.underlying.Read(p)
if errors.Is(err, io.EOF) && !b.closed {
b.waiting = true
b.cond.Wait()
continue
}
return n, err
}
}
// Write implements io.ReadWriteCloser.
func (b *buffer) Write(p []byte) (int, error) {
b.mu.Lock()
defer b.mu.Unlock()
if b.closed {
return 0, io.ErrClosedPipe
}
if b.waiting {
b.waiting = false
defer b.cond.Signal()
}
return b.underlying.Write(p)
}
// Close implements io.ReadWriteCloser.
func (b *buffer) Close() error {
b.mu.Lock()
defer b.mu.Unlock()
b.closed = true
b.cond.Signal()
return nil
}
var _ io.ReadWriteCloser = (*buffer)(nil)
func (e *environment) NewLogRecorder(opts ...LogRecorderOption) *LogRecorder {
options := LogRecorderOptions{}
options.apply(opts...)
lr := &LogRecorder{
LogRecorderOptions: options,
t: e.t,
canceled: e.ctx.Done(),
buf: newBuffer(),
}
e.logWriter.Add(lr.buf)
lr.removeGlobalWriterOnce = sync.OnceFunc(func() {
// wait for envoy access logs, which flush on a 1 second interval
if !lr.skipCloseDelay {
time.Sleep(1100 * time.Millisecond)
}
e.logWriter.Remove(lr.buf)
})
context.AfterFunc(e.ctx, lr.removeGlobalWriterOnce)
return lr
}
type (
// OpenMap is an alias for map[string]any, and can be used to semantically
// represent a map that must contain at least the given entries, but may
// also contain additional entries.
OpenMap = map[string]any
// ClosedMap is a map[string]any that can be used to semantically represent
// a map that must contain the given entries exactly, and no others.
ClosedMap map[string]any
)
// Close stops the log recorder. After calling this method, Logs() or Match()
// can be called to inspect the logs that were captured.
func (lr *LogRecorder) Close() {
lr.removeGlobalWriterOnce()
}
func (lr *LogRecorder) collectLogs(shouldClose bool) {
if shouldClose {
lr.removeGlobalWriterOnce()
lr.buf.Close()
}
lr.collectLogsOnce.Do(func() {
recordedLogs := []map[string]any{}
scan := bufio.NewScanner(lr.buf)
for scan.Scan() {
log := scan.Bytes()
m := map[string]any{}
decoder := json.NewDecoder(bytes.NewReader(log))
decoder.UseNumber()
require.NoError(lr.t, decoder.Decode(&m))
// run every filter; discard the entry if any filter rejects it
keep := true
for _, filter := range lr.filters {
if !filter(m) {
keep = false
}
}
if !keep {
continue
}
recordedLogs = append(recordedLogs, m)
}
lr.recordedLogs = recordedLogs
})
}
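// WaitForMatch blocks until a log entry matching expectedLog is captured,
// using the same matching rules as Match. If an optional timeout is given and
// no matching entry is seen in time, or the environment is stopped first, the
// test fails. The recorder is stopped before returning.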
func (lr *LogRecorder) WaitForMatch(expectedLog map[string]any, timeout ...time.Duration) {
lr.skipCloseDelay = true
found := make(chan struct{})
done := make(chan struct{})
lr.filters = append(lr.filters, func(entry map[string]any) bool {
select {
case <-found:
default:
if matched, _ := match(expectedLog, entry, true); matched {
close(found)
}
}
return true
})
go func() {
defer close(done)
lr.collectLogs(false)
lr.removeGlobalWriterOnce()
}()
if len(timeout) != 0 {
select {
case <-found:
case <-time.After(timeout[0]):
lr.t.Error("timed out waiting for log")
case <-lr.canceled:
lr.t.Error("canceled")
}
} else {
select {
case <-found:
case <-lr.canceled:
lr.t.Error("canceled")
}
}
lr.buf.Close()
<-done
}
// Logs stops the log recorder (if it is not already stopped), then returns
// the logs that were captured as structured map[string]any objects.
func (lr *LogRecorder) Logs() []map[string]any {
lr.collectLogs(true)
return lr.recordedLogs
}
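// DumpToFile stops the log recorder (if it is not already stopped), then
// writes the captured logs to the given file, one JSON object per line.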
func (lr *LogRecorder) DumpToFile(file string) {
lr.collectLogs(true)
f, err := os.Create(file)
require.NoError(lr.t, err)
enc := json.NewEncoder(f)
for _, log := range lr.recordedLogs {
_ = enc.Encode(log)
}
f.Close()
}
// Match stops the log recorder (if it is not already stopped), then asserts
// that the given expected logs were captured. The expected logs may contain
// partial or complete log entries. By default, logs must only match the fields
// given, and may contain additional fields that will be ignored.
//
// There are several special-case value types that can be used to customize the
// matching behavior, and/or simplify some common use cases, as follows:
// - [OpenMap] and [ClosedMap] can be used to control matching logic
// - [json.Number] will convert the actual value to a string before comparison
// - [*tls.Certificate] or [*x509.Certificate] will expand to the fields that
// would be logged for this certificate
func (lr *LogRecorder) Match(expectedLogs []map[string]any) {
lr.collectLogs(true)
for _, expectedLog := range expectedLogs {
found := false
highScore, highScoreIdxs := 0, []int{}
for i, actualLog := range lr.recordedLogs {
if ok, score := match(expectedLog, actualLog, true); ok {
found = true
break
} else if score > highScore {
highScore = score
highScoreIdxs = []int{i}
} else if score == highScore {
highScoreIdxs = append(highScoreIdxs, i)
}
}
if len(highScoreIdxs) > 0 {
expectedLogBytes, _ := json.MarshalIndent(expectedLog, "", " ")
if len(highScoreIdxs) == 1 {
actualLogBytes, _ := json.MarshalIndent(lr.recordedLogs[highScoreIdxs[0]], "", " ")
assert.True(lr.t, found, "expected log not found: \n%s\n\nclosest match:\n%s\n",
string(expectedLogBytes), string(actualLogBytes))
} else {
closestMatches := []string{}
for _, i := range highScoreIdxs {
bytes, _ := json.MarshalIndent(lr.recordedLogs[i], "", " ")
closestMatches = append(closestMatches, string(bytes))
}
assert.True(lr.t, found, "expected log not found: \n%s\n\nclosest matches:\n%s\n", string(expectedLogBytes), closestMatches)
}
} else {
expectedLogBytes, _ := json.MarshalIndent(expectedLog, "", " ")
assert.True(lr.t, found, "expected log not found: %s", string(expectedLogBytes))
}
}
}
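// Example (illustrative sketch, not part of the original file): asserting on
// the access log produced by a request to /foo through a route, using only
// field names exercised by the tests in this change. Fields not listed are
// ignored, since a plain map is treated as an OpenMap; use ClosedMap for an
// exact match, and json.Number for numeric fields.
//
//	recorder := env.NewLogRecorder()
//	// ... issue a request through the route ...
//	recorder.Match([]map[string]any{
//		{
//			"service":               "envoy",
//			"path":                  "/foo",
//			"message":               "http-request",
//			"response-code-details": "via_upstream",
//		},
//	})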
func match(expected, actual map[string]any, open bool) (matched bool, score int) {
for key, value := range expected {
actualValue, ok := actual[key]
if !ok {
return false, score
}
score++
switch actualValue := actualValue.(type) {
case map[string]any:
switch expectedValue := value.(type) {
case ClosedMap:
ok, s := match(expectedValue, actualValue, false)
score += s * 2
if !ok {
return false, score
}
case OpenMap:
ok, s := match(expectedValue, actualValue, true)
score += s
if !ok {
return false, score
}
case *tls.Certificate, *Certificate, *x509.Certificate:
var leaf *x509.Certificate
switch expectedValue := expectedValue.(type) {
case *tls.Certificate:
leaf = expectedValue.Leaf
case *Certificate:
leaf = expectedValue.Leaf
case *x509.Certificate:
leaf = expectedValue
}
// keep logic consistent with controlplane.populateCertEventDict()
expected := map[string]any{}
if iss := leaf.Issuer.String(); iss != "" {
expected["issuer"] = iss
}
if sub := leaf.Subject.String(); sub != "" {
expected["subject"] = sub
}
sans := []string{}
for _, dnsSAN := range leaf.DNSNames {
sans = append(sans, "DNS:"+dnsSAN)
}
for _, uriSAN := range leaf.URIs {
sans = append(sans, "URI:"+uriSAN.String())
}
if len(sans) > 0 {
expected["subjectAltName"] = sans
}
ok, s := match(expected, actualValue, false)
score += s
if !ok {
return false, score
}
default:
return false, score
}
case string:
switch value := value.(type) {
case string:
if value != actualValue {
return false, score
}
score++
default:
return false, score
}
case json.Number:
if fmt.Sprint(value) != actualValue.String() {
return false, score
}
score++
default:
// handle slices
if reflect.TypeOf(actualValue).Kind() == reflect.Slice {
if reflect.TypeOf(value) != reflect.TypeOf(actualValue) {
return false, score
}
actualSlice := reflect.ValueOf(actualValue)
expectedSlice := reflect.ValueOf(value)
totalScore := 0
for i := range min(actualSlice.Len(), expectedSlice.Len()) {
if actualSlice.Index(i).Equal(expectedSlice.Index(i)) {
totalScore++
}
}
score += totalScore
} else {
panic(fmt.Sprintf("test bug: add check for type %T in assertMatchingLogs", actualValue))
}
}
}
if !open && len(expected) != len(actual) {
return false, score
}
return true, score
}

internal/testenv/route.go (new file, 76 lines)

@@ -0,0 +1,76 @@
package testenv
import (
"net/url"
"strings"
"github.com/pomerium/pomerium/config"
"github.com/pomerium/pomerium/internal/testenv/values"
"github.com/pomerium/pomerium/pkg/policy/parser"
)
// PolicyRoute is a [Route] implementation suitable for most common use cases
// that can be used in implementations of [Upstream].
type PolicyRoute struct {
DefaultAttach
from values.Value[string]
to values.List[string]
edits []func(*config.Policy)
}
// Modify implements Route.
func (b *PolicyRoute) Modify(cfg *config.Config) {
to := make(config.WeightedURLs, 0, len(b.to))
for _, u := range b.to {
u, err := url.Parse(u.Value())
if err != nil {
panic(err)
}
to = append(to, config.WeightedURL{URL: *u})
}
p := config.Policy{
From: b.from.Value(),
To: to,
}
for _, edit := range b.edits {
edit(&p)
}
cfg.Options.Policies = append(cfg.Options.Policies, p)
}
// From implements Route.
func (b *PolicyRoute) From(fromURL values.Value[string]) Route {
b.from = fromURL
return b
}
// To implements Route.
func (b *PolicyRoute) To(toURL values.Value[string]) Route {
b.to = append(b.to, toURL)
return b
}
// Policy implements Route.
func (b *PolicyRoute) Policy(edit func(*config.Policy)) Route {
b.edits = append(b.edits, edit)
return b
}
// PPL implements Route.
func (b *PolicyRoute) PPL(ppl string) Route {
pplPolicy, err := parser.ParseYAML(strings.NewReader(ppl))
if err != nil {
panic(err)
}
b.edits = append(b.edits, func(p *config.Policy) {
p.Policy = &config.PPLPolicy{
Policy: pplPolicy,
}
})
return b
}
// URL implements Route.
func (b *PolicyRoute) URL() values.Value[string] {
return b.from
}


@@ -0,0 +1,389 @@
package scenarios
import (
"context"
"crypto"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"encoding/base64"
"encoding/hex"
"encoding/json"
"fmt"
"io"
"net"
"net/http"
"net/url"
"strconv"
"strings"
"time"
"github.com/go-jose/go-jose/v3"
"github.com/go-jose/go-jose/v3/jwt"
"github.com/google/uuid"
"github.com/pomerium/pomerium/config"
"github.com/pomerium/pomerium/internal/encoding"
"github.com/pomerium/pomerium/internal/encoding/jws"
"github.com/pomerium/pomerium/internal/log"
"github.com/pomerium/pomerium/internal/testenv"
"github.com/pomerium/pomerium/internal/testenv/upstreams"
"github.com/pomerium/pomerium/internal/testenv/values"
"github.com/pomerium/pomerium/pkg/grpc/identity"
)
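// IDP is a [testenv.Modifier] that runs a minimal mock OIDC identity provider
// as an in-process upstream and configures Pomerium to use it.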
type IDP struct {
id values.Value[string]
url values.Value[string]
publicJWK jose.JSONWebKey
signingKey jose.SigningKey
stateEncoder encoding.MarshalUnmarshaler
userLookup map[string]*User
}
// Attach implements testenv.Modifier.
func (idp *IDP) Attach(ctx context.Context) {
env := testenv.EnvFromContext(ctx)
router := upstreams.HTTP(nil)
idp.url = values.Bind2(env.SubdomainURL("mock-idp"), router.Port(), func(urlStr string, port int) string {
u, _ := url.Parse(urlStr)
host, _, _ := net.SplitHostPort(u.Host)
return u.ResolveReference(&url.URL{
Scheme: "http",
Host: fmt.Sprintf("%s:%d", host, port),
}).String()
})
var err error
idp.stateEncoder, err = jws.NewHS256Signer(env.SharedSecret())
env.Require().NoError(err)
idp.id = values.Bind2(idp.url, env.AuthenticateURL(), func(idpUrl, authUrl string) string {
provider := identity.Provider{
AuthenticateServiceUrl: authUrl,
ClientId: "CLIENT_ID",
ClientSecret: "CLIENT_SECRET",
Type: "oidc",
Scopes: []string{"openid", "email", "profile"},
Url: idpUrl,
}
return provider.Hash()
})
router.Handle("/.well-known/jwks.json", func(w http.ResponseWriter, _ *http.Request) {
_ = json.NewEncoder(w).Encode(&jose.JSONWebKeySet{
Keys: []jose.JSONWebKey{idp.publicJWK},
})
})
router.Handle("/.well-known/openid-configuration", func(w http.ResponseWriter, r *http.Request) {
log.Ctx(ctx).Debug().Str("method", r.Method).Str("uri", r.RequestURI).Send()
rootURL, _ := url.Parse(idp.url.Value())
_ = json.NewEncoder(w).Encode(map[string]interface{}{
"issuer": rootURL.String(),
"authorization_endpoint": rootURL.ResolveReference(&url.URL{Path: "/oidc/auth"}).String(),
"token_endpoint": rootURL.ResolveReference(&url.URL{Path: "/oidc/token"}).String(),
"jwks_uri": rootURL.ResolveReference(&url.URL{Path: "/.well-known/jwks.json"}).String(),
"userinfo_endpoint": rootURL.ResolveReference(&url.URL{Path: "/oidc/userinfo"}).String(),
"id_token_signing_alg_values_supported": []string{
"ES256",
},
})
})
router.Handle("/oidc/auth", idp.HandleAuth)
router.Handle("/oidc/token", idp.HandleToken)
router.Handle("/oidc/userinfo", idp.HandleUserInfo)
env.AddUpstream(router)
}
// Modify implements testenv.Modifier.
func (idp *IDP) Modify(cfg *config.Config) {
cfg.Options.Provider = "oidc"
cfg.Options.ProviderURL = idp.url.Value()
cfg.Options.ClientID = "CLIENT_ID"
cfg.Options.ClientSecret = "CLIENT_SECRET"
cfg.Options.Scopes = []string{"openid", "email", "profile"}
}
var _ testenv.Modifier = (*IDP)(nil)
func NewIDP(users []*User) *IDP {
privateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
if err != nil {
panic(err)
}
publicKey := &privateKey.PublicKey
signingKey := jose.SigningKey{
Algorithm: jose.ES256,
Key: privateKey,
}
publicJWK := jose.JSONWebKey{
Key: publicKey,
Algorithm: string(jose.ES256),
Use: "sig",
}
thumbprint, err := publicJWK.Thumbprint(crypto.SHA256)
if err != nil {
panic(err)
}
publicJWK.KeyID = hex.EncodeToString(thumbprint)
userLookup := map[string]*User{}
for _, user := range users {
user.ID = uuid.NewString()
userLookup[user.ID] = user
}
return &IDP{
publicJWK: publicJWK,
signingKey: signingKey,
userLookup: userLookup,
}
}
// HandleAuth handles the auth flow for OIDC.
func (idp *IDP) HandleAuth(w http.ResponseWriter, r *http.Request) {
rawRedirectURI := r.FormValue("redirect_uri")
if rawRedirectURI == "" {
http.Error(w, "missing redirect_uri", http.StatusBadRequest)
return
}
redirectURI, err := url.Parse(rawRedirectURI)
if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
rawClientID := r.FormValue("client_id")
if rawClientID == "" {
http.Error(w, "missing client_id", http.StatusBadRequest)
return
}
rawEmail := r.FormValue("email")
if rawEmail != "" {
http.Redirect(w, r, redirectURI.ResolveReference(&url.URL{
RawQuery: (url.Values{
"state": {r.FormValue("state")},
"code": {State{
Email: rawEmail,
ClientID: rawClientID,
}.Encode()},
}).Encode(),
}).String(), http.StatusFound)
return
}
serveHTML(w, `<!doctype html>
<html>
<head>
<title>Login</title>
</head>
<body>
<form method="POST" style="max-width: 200px">
<fieldset>
<legend>Login</legend>
<table>
<tbody>
<tr>
<th><label for="email">Email</label></th>
<td>
<input type="email" name="email" placeholder="email" />
</td>
</tr>
<tr>
<td colspan="2">
<input type="submit" />
</td>
</tr>
</tbody>
</table>
</fieldset>
</form>
</body>
</html>
`)
}
// HandleToken handles the token flow for OIDC.
func (idp *IDP) HandleToken(w http.ResponseWriter, r *http.Request) {
rawCode := r.FormValue("code")
state, err := DecodeState(rawCode)
if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
serveJSON(w, map[string]interface{}{
"access_token": state.Encode(),
"refresh_token": state.Encode(),
"token_type": "Bearer",
"id_token": state.GetIDToken(r, idp.userLookup).Encode(idp.signingKey),
})
}
// HandleUserInfo handles retrieving the user info.
func (idp *IDP) HandleUserInfo(w http.ResponseWriter, r *http.Request) {
authz := r.Header.Get("Authorization")
if authz == "" {
http.Error(w, "missing authorization header", http.StatusUnauthorized)
return
}
if strings.HasPrefix(authz, "Bearer ") {
authz = authz[len("Bearer "):]
} else if strings.HasPrefix(authz, "token ") {
authz = authz[len("token "):]
} else {
http.Error(w, "missing bearer token", http.StatusUnauthorized)
return
}
state, err := DecodeState(authz)
if err != nil {
http.Error(w, err.Error(), http.StatusForbidden)
return
}
serveJSON(w, state.GetUserInfo(idp.userLookup))
}
type RootURLKey struct{}
var rootURLKey RootURLKey
// WithRootURL sets the Root URL in a context.
func WithRootURL(ctx context.Context, rootURL *url.URL) context.Context {
return context.WithValue(ctx, rootURLKey, rootURL)
}
func getRootURL(r *http.Request) *url.URL {
if u, ok := r.Context().Value(rootURLKey).(*url.URL); ok {
return u
}
u := *r.URL
if r.Host != "" {
u.Host = r.Host
}
if u.Scheme == "" {
if r.TLS != nil {
u.Scheme = "https"
} else {
u.Scheme = "http"
}
}
u.Path = ""
return &u
}
func serveHTML(w http.ResponseWriter, html string) {
w.Header().Set("Content-Type", "text/html")
w.Header().Set("Content-Length", strconv.Itoa(len(html)))
w.WriteHeader(http.StatusOK)
_, _ = io.WriteString(w, html)
}
func serveJSON(w http.ResponseWriter, obj interface{}) {
bs, err := json.Marshal(obj)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
_, _ = w.Write(bs)
}
type State struct {
Email string `json:"email"`
ClientID string `json:"client_id"`
}
func DecodeState(rawCode string) (*State, error) {
var state State
bs, _ := base64.URLEncoding.DecodeString(rawCode)
err := json.Unmarshal(bs, &state)
if err != nil {
return nil, err
}
return &state, nil
}
func (state State) Encode() string {
bs, _ := json.Marshal(state)
return base64.URLEncoding.EncodeToString(bs)
}
func (state State) GetIDToken(r *http.Request, users map[string]*User) *IDToken {
token := &IDToken{
UserInfo: state.GetUserInfo(users),
Issuer: getRootURL(r).String(),
Audience: state.ClientID,
Expiry: jwt.NewNumericDate(time.Now().Add(time.Hour * 24 * 365)),
IssuedAt: jwt.NewNumericDate(time.Now()),
}
return token
}
func (state State) GetUserInfo(users map[string]*User) *UserInfo {
userInfo := &UserInfo{
Subject: state.Email,
Email: state.Email,
}
for _, u := range users {
if u.Email == state.Email {
userInfo.Subject = u.ID
userInfo.Name = u.FirstName + " " + u.LastName
userInfo.FamilyName = u.LastName
userInfo.GivenName = u.FirstName
}
}
return userInfo
}
type UserInfo struct {
Subject string `json:"sub"`
Name string `json:"name"`
Email string `json:"email"`
FamilyName string `json:"family_name"`
GivenName string `json:"given_name"`
}
type IDToken struct {
*UserInfo
Issuer string `json:"iss"`
Audience string `json:"aud"`
Expiry *jwt.NumericDate `json:"exp"`
IssuedAt *jwt.NumericDate `json:"iat"`
}
func (token *IDToken) Encode(signingKey jose.SigningKey) string {
sig, err := jose.NewSigner(signingKey, (&jose.SignerOptions{}).WithType("JWT"))
if err != nil {
panic(err)
}
str, err := jwt.Signed(sig).Claims(token).CompactSerialize()
if err != nil {
panic(err)
}
return str
}
type User struct {
ID string
Email string
FirstName string
LastName string
}

View file

@ -0,0 +1,24 @@
package scenarios
import (
"context"
"encoding/base64"
"encoding/pem"
"github.com/pomerium/pomerium/config"
"github.com/pomerium/pomerium/internal/testenv"
)
func DownstreamMTLS(mode config.MTLSEnforcement) testenv.Modifier {
return testenv.ModifierFunc(func(ctx context.Context, cfg *config.Config) {
env := testenv.EnvFromContext(ctx)
block := pem.Block{
Type: "CERTIFICATE",
Bytes: env.CACert().Leaf.Raw,
}
cfg.Options.DownstreamMTLS = config.DownstreamMTLSSettings{
CA: base64.StdEncoding.EncodeToString(pem.EncodeToMemory(&block)),
Enforcement: mode,
}
})
}
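A sketch of how a test might apply this scenario; testenv.New and the exact enforcement-mode constant name are assumptions not shown in this excerpt.
func TestMTLSRejectConnection(t *testing.T) {
	env := testenv.New(t) // assumed constructor
	env.Add(scenarios.DownstreamMTLS(config.MTLSEnforcementRejectConnection)) // constant name assumed
	env.Start()
	// ... issue requests with and without a client certificate and assert on the results ...
}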

View file

@ -0,0 +1,64 @@
package snippets
import (
"bytes"
"context"
"strings"
"text/template"
"github.com/pomerium/pomerium/config"
"github.com/pomerium/pomerium/internal/testenv"
"github.com/pomerium/pomerium/pkg/policy/parser"
)
var SimplePolicyTemplate = PolicyTemplate{
From: "https://from-{{.Idx}}.localhost",
To: "https://to-{{.Idx}}.localhost",
PPL: `{"allow":{"and":["email":{"is":"user-{{.Idx}}@example.com"}]}}`,
}
type PolicyTemplate struct {
From string
To string
PPL string
// Add more fields as needed (be sure to update newPolicyFromTemplate)
}
func TemplateRoutes(n int, tmpl PolicyTemplate) testenv.Modifier {
return testenv.ModifierFunc(func(_ context.Context, cfg *config.Config) {
for i := range n {
cfg.Options.Policies = append(cfg.Options.Policies, newPolicyFromTemplate(i, tmpl))
}
})
}
func newPolicyFromTemplate(i int, pt PolicyTemplate) config.Policy {
eval := func(in string) string {
t := template.New("policy")
tmpl, err := t.Parse(in)
if err != nil {
panic(err)
}
var out bytes.Buffer
if err := tmpl.Execute(&out, struct{ Idx int }{i}); err != nil {
panic(err)
}
return out.String()
}
pplPolicy, err := parser.ParseYAML(strings.NewReader(eval(pt.PPL)))
if err != nil {
panic(err)
}
to, err := config.ParseWeightedUrls(eval(pt.To))
if err != nil {
panic(err)
}
return config.Policy{
From: eval(pt.From),
To: to,
Policy: &config.PPLPolicy{Policy: pplPolicy},
}
}
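A usage sketch (fragment) for adding templated routes to an environment; testenv.New is an assumed constructor.
env := testenv.New(t)
// Expands to routes https://from-0.localhost ... https://from-99.localhost,
// each allowing only the matching user-N@example.com.
env.Add(snippets.TemplateRoutes(100, snippets.SimplePolicyTemplate))
env.Start()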

View file

@ -0,0 +1,35 @@
package snippets
import (
"context"
"time"
"github.com/pomerium/pomerium/internal/testenv"
"github.com/pomerium/pomerium/pkg/grpcutil"
"google.golang.org/grpc"
"google.golang.org/grpc/connectivity"
"google.golang.org/grpc/credentials/insecure"
)
func WaitStartupComplete(env testenv.Environment, timeout ...time.Duration) time.Duration {
start := time.Now()
recorder := env.NewLogRecorder()
if len(timeout) == 0 {
timeout = append(timeout, 1*time.Minute)
}
ctx, ca := context.WithTimeout(env.Context(), timeout[0])
defer ca()
recorder.WaitForMatch(map[string]any{
"syncer_id": "databroker",
"syncer_type": "type.googleapis.com/pomerium.config.Config",
"message": "listening for updates",
}, timeout...)
cc, err := grpc.Dial(env.DatabrokerURL().Value(),
grpc.WithTransportCredentials(insecure.NewCredentials()),
grpc.WithChainUnaryInterceptor(grpcutil.WithUnarySignedJWT(env.SharedSecret)),
grpc.WithChainStreamInterceptor(grpcutil.WithStreamSignedJWT(env.SharedSecret)),
)
env.Require().NoError(err)
env.Require().True(cc.WaitForStateChange(ctx, connectivity.Ready))
return time.Since(start)
}
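A typical call site might look like this sketch (fragment; testenv.New assumed):
env := testenv.New(t)
env.Start()
startupTime := snippets.WaitStartupComplete(env)
t.Logf("startup completed in %s", startupTime)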

206
internal/testenv/types.go Normal file
View file

@ -0,0 +1,206 @@
package testenv
import (
"context"
"github.com/pomerium/pomerium/config"
"github.com/pomerium/pomerium/internal/testenv/values"
)
type envContextKeyType struct{}
var envContextKey envContextKeyType
func EnvFromContext(ctx context.Context) Environment {
return ctx.Value(envContextKey).(Environment)
}
func ContextWithEnv(ctx context.Context, env Environment) context.Context {
return context.WithValue(ctx, envContextKey, env)
}
// A Modifier is an object whose presence in the test affects the Pomerium
// configuration in some way. When the test environment is started, a
// [*config.Config] is constructed by calling each added Modifier in order.
//
// For additional details, see [Environment.Add] and [Environment.Start].
type Modifier interface {
// Attach is called by an [Environment] (before Modify) to propagate the
// environment's context.
Attach(ctx context.Context)
// Modify is called by an [Environment] to mutate its configuration in some
// way required by this Modifier.
Modify(cfg *config.Config)
}
// DefaultAttach should be embedded in types implementing [Modifier] to
// automatically obtain environment context details and caller information.
type DefaultAttach struct {
env Environment
caller string
}
func (d *DefaultAttach) Env() Environment {
d.CheckAttached()
return d.env
}
func (d *DefaultAttach) Attach(ctx context.Context) {
if d.env != nil {
panic("internal test environment bug: Attach called twice")
}
d.env = EnvFromContext(ctx)
if d.env == nil {
panic("test bug: no environment in context")
}
}
func (d *DefaultAttach) CheckAttached() {
if d.env == nil {
if d.caller != "" {
panic("test bug: missing a call to Add for the object created at: " + d.caller)
}
panic("test bug: not attached (possibly missing a call to Add)")
}
}
func (d *DefaultAttach) RecordCaller() {
d.caller = getCaller(4)
}
// Aggregate should be embedded in types implementing [Modifier] when the type
// contains other modifiers. Used as an alternative to [DefaultAttach].
// Embedding this struct will properly keep track of when constituent modifiers
// are added, for validation and caller detection.
//
// Aggregate implements a no-op Modify() by default, but this can be overridden
// to make additional modifications. The aggregate's Modify() is called first.
type Aggregate struct {
env Environment
caller string
modifiers []Modifier
}
func (d *Aggregate) Add(mod Modifier) {
if d.env != nil {
if d.env.(*environment).getState() == NotRunning {
// If the test environment is running, adding to an aggregate is a no-op.
// If the test environment has not been started yet, the aggregate is
// being used as in the following example, which is incorrect:
//
// aggregate.Add(foo)
// env.Add(aggregate)
// aggregate.Add(bar)
// env.Start()
//
// It should instead be used like this:
//
// aggregate.Add(foo)
// aggregate.Add(bar)
// env.Add(aggregate)
// env.Start()
panic("test bug: cannot modify an aggregate that has already been added")
}
return
}
d.modifiers = append(d.modifiers, mod)
}
func (d *Aggregate) Env() Environment {
d.CheckAttached()
return d.env
}
func (d *Aggregate) Attach(ctx context.Context) {
if d.env != nil {
panic("internal test environment bug: Attach called twice")
}
d.env = EnvFromContext(ctx)
if d.env == nil {
panic("test bug: no environment in context")
}
d.env.(*environment).t.Helper()
for _, mod := range d.modifiers {
d.env.Add(mod)
}
}
func (d *Aggregate) Modify(*config.Config) {}
func (d *Aggregate) CheckAttached() {
if d.env == nil {
if d.caller != "" {
panic("test bug: missing a call to Add for the object created at: " + d.caller)
}
panic("test bug: not attached (possibly missing a call to Add)")
}
}
func (d *Aggregate) RecordCaller() {
d.caller = getCaller(4)
}
type modifierFunc struct {
fn func(ctx context.Context, cfg *config.Config)
ctx context.Context
}
// Attach implements Modifier.
func (f *modifierFunc) Attach(ctx context.Context) {
f.ctx = ctx
}
func (f *modifierFunc) Modify(cfg *config.Config) {
f.fn(f.ctx, cfg)
}
var _ Modifier = (*modifierFunc)(nil)
func ModifierFunc(fn func(ctx context.Context, cfg *config.Config)) Modifier {
return &modifierFunc{fn: fn}
}
// Task represents a background task that can be added to an [Environment] to
// have it run automatically on startup.
//
// For additional details, see [Environment.AddTask] and [Environment.Start].
type Task interface {
Run(ctx context.Context) error
}
type TaskFunc func(ctx context.Context) error
func (f TaskFunc) Run(ctx context.Context) error {
return f(ctx)
}
// Upstream represents an upstream server. It is both a [Task] and a [Modifier]
// and can be added to an environment using [Environment.AddUpstream]. From an
// Upstream instance, new routes can be created (which automatically adds the
// necessary route/policy entries to the config), and used within a test to
// easily make requests to the routes with implementation-specific clients.
type Upstream interface {
Modifier
Task
Port() values.Value[int]
Route() RouteStub
}
// A Route represents a route from a source URL to a destination URL. A route is
// typically created by calling [Upstream.Route].
type Route interface {
Modifier
URL() values.Value[string]
To(toURL values.Value[string]) Route
Policy(edit func(*config.Policy)) Route
PPL(ppl string) Route
// add more methods here as they become needed
}
// RouteStub represents an incomplete [Route]. Providing a URL by calling its
// From() method will return a [Route], from which further configuration can
// be made.
type RouteStub interface {
From(fromURL values.Value[string]) Route
}
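To illustrate the two patterns this file defines, ModifierFunc for inline closures and DefaultAttach for reusable named modifiers, here is a brief sketch; the log level option is used purely for illustration and its exact field name is an assumption.
// withLogLevel is an illustrative reusable Modifier built on DefaultAttach.
type withLogLevel struct {
	testenv.DefaultAttach
	level config.LogLevel
}

func (m *withLogLevel) Modify(cfg *config.Config) {
	cfg.Options.LogLevel = m.level // Options.LogLevel assumed for illustration
}

// Usage in a test:
//
//	env.Add(&withLogLevel{level: "debug"})
//
// or, equivalently, as an inline closure:
//
//	env.Add(testenv.ModifierFunc(func(_ context.Context, cfg *config.Config) {
//		cfg.Options.LogLevel = "debug"
//	}))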

View file

@ -0,0 +1,139 @@
package upstreams
import (
"context"
"fmt"
"net"
"strings"
"github.com/pomerium/pomerium/internal/testenv"
"github.com/pomerium/pomerium/internal/testenv/values"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials"
)
type Options struct {
serverOpts []grpc.ServerOption
}
type Option func(*Options)
func (o *Options) apply(opts ...Option) {
for _, op := range opts {
op(o)
}
}
func ServerOpts(opt ...grpc.ServerOption) Option {
return func(o *Options) {
o.serverOpts = append(o.serverOpts, opt...)
}
}
// GRPCUpstream represents a GRPC server which can be used as the target for
// one or more Pomerium routes in a test environment.
//
// This upstream implements [grpc.ServiceRegistrar], and can be used in the
// same way as [*grpc.Server] to register services before it is started.
//
// Any [testenv.Route] instances created from this upstream can be referenced
// in the Dial() method to establish a connection to that route.
type GRPCUpstream interface {
testenv.Upstream
grpc.ServiceRegistrar
Dial(r testenv.Route, dialOpts ...grpc.DialOption) *grpc.ClientConn
}
type grpcUpstream struct {
Options
testenv.Aggregate
serverPort values.MutableValue[int]
creds credentials.TransportCredentials
services []service
}
var (
_ testenv.Upstream = (*grpcUpstream)(nil)
_ grpc.ServiceRegistrar = (*grpcUpstream)(nil)
)
// GRPC creates a new GRPC upstream server.
func GRPC(creds credentials.TransportCredentials, opts ...Option) GRPCUpstream {
options := Options{}
options.apply(opts...)
up := &grpcUpstream{
Options: options,
creds: creds,
serverPort: values.Deferred[int](),
}
up.RecordCaller()
return up
}
type service struct {
desc *grpc.ServiceDesc
impl any
}
func (g *grpcUpstream) Port() values.Value[int] {
return g.serverPort
}
// RegisterService implements grpc.ServiceRegistrar.
func (g *grpcUpstream) RegisterService(desc *grpc.ServiceDesc, impl any) {
g.services = append(g.services, service{desc, impl})
}
// Route implements testenv.Upstream.
func (g *grpcUpstream) Route() testenv.RouteStub {
r := &testenv.PolicyRoute{}
var protocol string
switch g.creds.Info().SecurityProtocol {
case "insecure":
protocol = "h2c"
default:
protocol = "https"
}
r.To(values.Bind(g.serverPort, func(port int) string {
return fmt.Sprintf("%s://127.0.0.1:%d", protocol, port)
}))
g.Add(r)
return r
}
// Run implements testenv.Upstream.
func (g *grpcUpstream) Run(ctx context.Context) error {
listener, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
return err
}
g.serverPort.Resolve(listener.Addr().(*net.TCPAddr).Port)
server := grpc.NewServer(append(g.serverOpts, grpc.Creds(g.creds))...)
for _, s := range g.services {
server.RegisterService(s.desc, s.impl)
}
errC := make(chan error, 1)
go func() {
errC <- server.Serve(listener)
}()
select {
case <-ctx.Done():
server.Stop()
return context.Cause(ctx)
case err := <-errC:
return err
}
}
func (g *grpcUpstream) Dial(r testenv.Route, dialOpts ...grpc.DialOption) *grpc.ClientConn {
dialOpts = append(dialOpts,
grpc.WithTransportCredentials(credentials.NewClientTLSFromCert(g.Env().ServerCAs(), "")),
grpc.WithDefaultCallOptions(grpc.WaitForReady(true)),
)
cc, err := grpc.NewClient(strings.TrimPrefix(r.URL().Value(), "https://"), dialOpts...)
if err != nil {
panic(err)
}
return cc
}
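A sketch of typical usage, loosely following the pattern in the integration tests; testenv.New, env.SubdomainURL, and the exact Environment method signatures are assumptions not confirmed by this excerpt.
package example_test // illustrative

import (
	"testing"

	"google.golang.org/grpc/credentials/insecure"
	"google.golang.org/grpc/interop"
	"google.golang.org/grpc/interop/grpc_testing"

	"github.com/pomerium/pomerium/internal/testenv"
	"github.com/pomerium/pomerium/internal/testenv/upstreams"
)

func TestGRPCRoute(t *testing.T) {
	env := testenv.New(t) // assumed constructor
	up := upstreams.GRPC(insecure.NewCredentials())
	grpc_testing.RegisterTestServiceServer(up, interop.NewTestServer())

	route := up.Route().
		From(env.SubdomainURL("grpc-test")). // assumed helper returning a values.Value[string]
		PPL(`{"allow":{"and":[{"accept":true}]}}`)

	env.AddUpstream(up)
	env.Start()

	cc := up.Dial(route)
	defer cc.Close()
	client := grpc_testing.NewTestServiceClient(cc)
	if _, err := client.EmptyCall(env.Context(), &grpc_testing.Empty{}); err != nil {
		t.Fatal(err)
	}
}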

View file

@ -0,0 +1,327 @@
package upstreams
import (
"bytes"
"context"
"crypto/tls"
"encoding/json"
"errors"
"fmt"
"io"
"net"
"net/http"
"net/http/cookiejar"
"net/url"
"strconv"
"strings"
"sync"
"time"
"github.com/gorilla/mux"
"github.com/pomerium/pomerium/integration/forms"
"github.com/pomerium/pomerium/internal/retry"
"github.com/pomerium/pomerium/internal/testenv"
"github.com/pomerium/pomerium/internal/testenv/values"
"google.golang.org/protobuf/proto"
)
type RequestOptions struct {
path string
query url.Values
headers map[string]string
authenticateAs string
body any
clientCerts []tls.Certificate
client *http.Client
}
type RequestOption func(*RequestOptions)
func (o *RequestOptions) apply(opts ...RequestOption) {
for _, op := range opts {
op(o)
}
}
// Path sets the path of the request. If omitted, the request URL will match
// the route URL exactly.
func Path(path string) RequestOption {
return func(o *RequestOptions) {
o.path = path
}
}
// Query sets optional query parameters of the request.
func Query(query url.Values) RequestOption {
return func(o *RequestOptions) {
o.query = query
}
}
// Headers adds optional headers to the request.
func Headers(headers map[string]string) RequestOption {
return func(o *RequestOptions) {
o.headers = headers
}
}
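// AuthenticateAs performs the mock IDP login flow as the given email address
// before issuing the request.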
func AuthenticateAs(email string) RequestOption {
return func(o *RequestOptions) {
o.authenticateAs = email
}
}
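// Client overrides the default per-route cached HTTP client used to issue the
// request.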
func Client(c *http.Client) RequestOption {
return func(o *RequestOptions) {
o.client = c
}
}
// Body sets the body of the request.
// The argument can be one of the following types:
// - string
// - []byte
// - io.Reader
// - proto.Message
// - any json-encodable type
// If the argument is encoded as json, the Content-Type header will be set to
// "application/json". If the argument is a proto.Message, the Content-Type
// header will be set to "application/octet-stream".
func Body(body any) RequestOption {
return func(o *RequestOptions) {
o.body = body
}
}
// ClientCert adds a client certificate to the request.
func ClientCert[T interface {
*testenv.Certificate | *tls.Certificate
}](cert T) RequestOption {
return func(o *RequestOptions) {
o.clientCerts = append(o.clientCerts, *(*tls.Certificate)(cert))
}
}
// HTTPUpstream represents an HTTP server which can be used as the target for
// one or more Pomerium routes in a test environment.
//
// The Handle() method can be used to add handlers to the server-side HTTP router,
// while the Get(), Post(), and (generic) Do() methods can be used to make
// client-side requests.
type HTTPUpstream interface {
testenv.Upstream
Handle(path string, f func(http.ResponseWriter, *http.Request)) *mux.Route
Get(r testenv.Route, opts ...RequestOption) (*http.Response, error)
Post(r testenv.Route, opts ...RequestOption) (*http.Response, error)
Do(method string, r testenv.Route, opts ...RequestOption) (*http.Response, error)
}
type httpUpstream struct {
testenv.Aggregate
serverPort values.MutableValue[int]
tlsConfig values.Value[*tls.Config]
clientCache sync.Map // map[testenv.Route]*http.Client
router *mux.Router
}
var (
_ testenv.Upstream = (*httpUpstream)(nil)
_ HTTPUpstream = (*httpUpstream)(nil)
)
// HTTP creates a new HTTP upstream server.
func HTTP(tlsConfig values.Value[*tls.Config]) HTTPUpstream {
up := &httpUpstream{
serverPort: values.Deferred[int](),
router: mux.NewRouter(),
tlsConfig: tlsConfig,
}
up.RecordCaller()
return up
}
// Port implements HTTPUpstream.
func (h *httpUpstream) Port() values.Value[int] {
return h.serverPort
}
// Handle implements HTTPUpstream.
func (h *httpUpstream) Handle(path string, f func(http.ResponseWriter, *http.Request)) *mux.Route {
return h.router.HandleFunc(path, f)
}
// Route implements HTTPUpstream.
func (h *httpUpstream) Route() testenv.RouteStub {
r := &testenv.PolicyRoute{}
protocol := "http"
r.To(values.Bind(h.serverPort, func(port int) string {
return fmt.Sprintf("%s://127.0.0.1:%d", protocol, port)
}))
h.Add(r)
return r
}
// Run implements HTTPUpstream.
func (h *httpUpstream) Run(ctx context.Context) error {
listener, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
return err
}
h.serverPort.Resolve(listener.Addr().(*net.TCPAddr).Port)
var tlsConfig *tls.Config
if h.tlsConfig != nil {
tlsConfig = h.tlsConfig.Value()
}
server := &http.Server{
Handler: h.router,
TLSConfig: tlsConfig,
BaseContext: func(net.Listener) context.Context {
return ctx
},
}
errC := make(chan error, 1)
go func() {
errC <- server.Serve(listener)
}()
select {
case <-ctx.Done():
server.Close()
return context.Cause(ctx)
case err := <-errC:
return err
}
}
// Get implements HTTPUpstream.
func (h *httpUpstream) Get(r testenv.Route, opts ...RequestOption) (*http.Response, error) {
return h.Do(http.MethodGet, r, opts...)
}
// Post implements HTTPUpstream.
func (h *httpUpstream) Post(r testenv.Route, opts ...RequestOption) (*http.Response, error) {
return h.Do(http.MethodPost, r, opts...)
}
// Do implements HTTPUpstream.
func (h *httpUpstream) Do(method string, r testenv.Route, opts ...RequestOption) (*http.Response, error) {
options := RequestOptions{}
options.apply(opts...)
u, err := url.Parse(r.URL().Value())
if err != nil {
return nil, err
}
if options.path != "" || options.query != nil {
u = u.ResolveReference(&url.URL{
Path: options.path,
RawQuery: options.query.Encode(),
})
}
req, err := http.NewRequest(method, u.String(), nil)
if err != nil {
return nil, err
}
switch body := options.body.(type) {
case string:
req.Body = io.NopCloser(strings.NewReader(body))
case []byte:
req.Body = io.NopCloser(bytes.NewReader(body))
case io.Reader:
req.Body = io.NopCloser(body)
case proto.Message:
buf, err := proto.Marshal(body)
if err != nil {
return nil, err
}
req.Body = io.NopCloser(bytes.NewReader(buf))
req.Header.Set("Content-Type", "application/octet-stream")
case nil:
default:
buf, err := json.Marshal(body)
if err != nil {
panic(fmt.Sprintf("unsupported body type: %T", body))
}
req.Body = io.NopCloser(bytes.NewReader(buf))
req.Header.Set("Content-Type", "application/json")
}
newClient := func() *http.Client {
c := http.Client{
Transport: &http.Transport{
TLSClientConfig: &tls.Config{
RootCAs: h.Env().ServerCAs(),
Certificates: options.clientCerts,
},
},
}
c.Jar, _ = cookiejar.New(&cookiejar.Options{})
return &c
}
var client *http.Client
if options.client != nil {
client = options.client
} else {
var cachedClient any
var ok bool
if cachedClient, ok = h.clientCache.Load(r); !ok {
cachedClient, _ = h.clientCache.LoadOrStore(r, newClient())
}
client = cachedClient.(*http.Client)
}
var resp *http.Response
if err := retry.Retry(h.Env().Context(), "http", func(ctx context.Context) error {
var err error
if options.authenticateAs != "" {
resp, err = authenticateFlow(ctx, client, req, options.authenticateAs) //nolint:bodyclose
} else {
resp, err = client.Do(req) //nolint:bodyclose
}
// retry on connection refused
if err != nil {
var opErr *net.OpError
if errors.As(err, &opErr) && opErr.Op == "dial" && opErr.Err.Error() == "connect: connection refused" {
return err
}
return retry.NewTerminalError(err)
}
if resp.StatusCode == http.StatusInternalServerError {
return errors.New(http.StatusText(resp.StatusCode))
}
return nil
}, retry.WithMaxInterval(100*time.Millisecond)); err != nil {
return nil, err
}
return resp, nil
}
func authenticateFlow(ctx context.Context, client *http.Client, req *http.Request, email string) (*http.Response, error) {
var res *http.Response
originalHostname := req.URL.Hostname()
res, err := client.Do(req)
if err != nil {
return nil, err
}
location := res.Request.URL
if location.Hostname() == originalHostname {
// already authenticated
return res, err
}
defer res.Body.Close()
fs := forms.Parse(res.Body)
if len(fs) > 0 {
f := fs[0]
f.Inputs["email"] = email
f.Inputs["token_expiration"] = strconv.Itoa(int((time.Hour * 24).Seconds()))
formReq, err := f.NewRequestWithContext(ctx, location)
if err != nil {
return nil, err
}
return client.Do(formReq)
}
return nil, fmt.Errorf("test bug: expected IDP login form")
}
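A sketch of typical usage of the HTTP upstream and its request options; testenv.New, env.SubdomainURL, and the exact Environment method signatures are assumptions not confirmed by this excerpt.
package example_test // illustrative

import (
	"io"
	"net/http"
	"testing"

	"github.com/stretchr/testify/require"

	"github.com/pomerium/pomerium/internal/testenv"
	"github.com/pomerium/pomerium/internal/testenv/upstreams"
)

func TestHTTPRoute(t *testing.T) {
	env := testenv.New(t) // assumed constructor
	up := upstreams.HTTP(nil) // plain HTTP between Pomerium and the upstream
	up.Handle("/hello", func(w http.ResponseWriter, _ *http.Request) {
		_, _ = io.WriteString(w, "hello world")
	})

	route := up.Route().
		From(env.SubdomainURL("hello")). // assumed helper
		PPL(`{"allow":{"and":[{"email":{"is":"user@example.com"}}]}}`)

	env.AddUpstream(up)
	env.Start()

	resp, err := up.Get(route, upstreams.AuthenticateAs("user@example.com"), upstreams.Path("/hello"))
	require.NoError(t, err)
	defer resp.Body.Close()
	body, err := io.ReadAll(resp.Body)
	require.NoError(t, err)
	require.Equal(t, http.StatusOK, resp.StatusCode)
	require.Equal(t, "hello world", string(body))
}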

View file

@ -0,0 +1,120 @@
package values
import (
"math/rand/v2"
"sync"
)
type value[T any] struct {
f func() T
ready bool
cond *sync.Cond
}
// A Value is a container for a single value of type T, whose initialization is
// performed the first time Value() is called. Subsequent calls will return the
// same value. The Value() function may block until the value is ready on the
// first call. Values are safe to use concurrently.
type Value[T any] interface {
Value() T
}
// MutableValue is the read-write counterpart to [Value], created by calling
// [Deferred] for some type T. Calling Resolve() or ResolveFunc() will set
// the value and unblock any waiting calls to Value().
type MutableValue[T any] interface {
Value[T]
Resolve(value T)
ResolveFunc(fOnce func() T)
}
// Deferred creates a new read-write [MutableValue] for some type T,
// representing a value whose initialization may be deferred to a later time.
// Once the value is available, call [MutableValue.Resolve] or
// [MutableValue.ResolveFunc] to unblock any waiting calls to Value().
func Deferred[T any]() MutableValue[T] {
return &value[T]{
cond: sync.NewCond(&sync.Mutex{}),
}
}
// Const creates a read-only [Value] which will become available immediately
// upon calling Value() for the first time; it will never block.
func Const[T any](t T) Value[T] {
return &value[T]{
f: func() T { return t },
ready: true,
cond: sync.NewCond(&sync.Mutex{}),
}
}
func (p *value[T]) Value() T {
p.cond.L.Lock()
defer p.cond.L.Unlock()
for !p.ready {
p.cond.Wait()
}
return p.f()
}
func (p *value[T]) ResolveFunc(fOnce func() T) {
p.cond.L.Lock()
p.f = sync.OnceValue(fOnce)
p.ready = true
p.cond.L.Unlock()
p.cond.Broadcast()
}
func (p *value[T]) Resolve(value T) {
p.ResolveFunc(func() T { return value })
}
// Bind creates a new [MutableValue] whose ultimate value depends on the result
// of another [Value] that may not yet be available. When Value() is called on
// the result, it will cascade and trigger the full chain of initialization
// functions necessary to produce the final value.
//
// Care should be taken when using this function, as improper use can lead to
// deadlocks and cause values to never become available.
func Bind[T any, U any](dt Value[T], callback func(value T) U) Value[U] {
du := Deferred[U]()
du.ResolveFunc(func() U {
return callback(dt.Value())
})
return du
}
// Bind2 is like [Bind], but can accept two input values. The result will only
// become available once all input values become available.
//
// This function blocks to wait for each input value in sequence, but in a
// random order. Do not rely on the order of evaluation of the input values.
func Bind2[T any, U any, V any](dt Value[T], du Value[U], callback func(value1 T, value2 U) V) Value[V] {
dv := Deferred[V]()
dv.ResolveFunc(func() V {
if rand.IntN(2) == 0 { //nolint:gosec
return callback(dt.Value(), du.Value())
}
u := du.Value()
t := dt.Value()
return callback(t, u)
})
return dv
}
// List is a container for a slice of [Value] of type T, and is also a [Value]
// itself, for convenience. The Value() function will return a []T containing
// all resolved values for each element in the slice.
//
// A List's Value() function blocks to wait for each element in the slice in
// sequence, but in a random order. Do not rely on the order of evaluation of
// the slice elements.
type List[T any] []Value[T]
func (s List[T]) Value() []T {
values := make([]T, len(s))
for _, i := range rand.Perm(len(values)) {
values[i] = s[i].Value()
}
return values
}
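A small runnable example of the deferred-value plumbing above, using only functions defined in this file.
package values_test

import (
	"fmt"

	"github.com/pomerium/pomerium/internal/testenv/values"
)

func Example_deferred() {
	port := values.Deferred[int]()
	addr := values.Bind(port, func(p int) string {
		return fmt.Sprintf("https://127.0.0.1:%d", p)
	})
	go port.Resolve(8443)     // e.g. once a listener is actually bound
	fmt.Println(addr.Value()) // blocks until the port resolves
	// Output: https://127.0.0.1:8443
}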

View file

@ -16,6 +16,7 @@ import (
"github.com/pomerium/pomerium/authenticate"
"github.com/pomerium/pomerium/authorize"
"github.com/pomerium/pomerium/config"
"github.com/pomerium/pomerium/config/envoyconfig/filemgr"
databroker_service "github.com/pomerium/pomerium/databroker"
"github.com/pomerium/pomerium/internal/autocert"
"github.com/pomerium/pomerium/internal/controlplane"
@ -30,8 +31,29 @@ import (
"github.com/pomerium/pomerium/proxy"
)
type RunOptions struct {
fileMgr *filemgr.Manager
}
type RunOption func(*RunOptions)
func (o *RunOptions) apply(opts ...RunOption) {
for _, op := range opts {
op(o)
}
}
func WithOverrideFileManager(fileMgr *filemgr.Manager) RunOption {
return func(o *RunOptions) {
o.fileMgr = fileMgr
}
}
// Run runs the main pomerium application.
func Run(ctx context.Context, src config.Source) error {
func Run(ctx context.Context, src config.Source, opts ...RunOption) error {
options := RunOptions{}
options.apply(opts...)
_, _ = maxprocs.Set(maxprocs.Logger(func(s string, i ...any) { log.Ctx(ctx).Debug().Msgf(s, i...) }))
evt := log.Ctx(ctx).Info().
@ -68,10 +90,15 @@ func Run(ctx context.Context, src config.Source) error {
eventsMgr := events.New()
fileMgr := options.fileMgr
if fileMgr == nil {
fileMgr = filemgr.NewManager()
}
cfg := src.GetConfig()
// setup the control plane
controlPlane, err := controlplane.NewServer(ctx, cfg, metricsMgr, eventsMgr)
controlPlane, err := controlplane.NewServer(ctx, cfg, metricsMgr, eventsMgr, fileMgr)
if err != nil {
return fmt.Errorf("error creating control plane: %w", err)
}
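A sketch (fragment) of how a caller such as the test environment could plug in its own file manager via the new option; the filemgr.WithDirectory option name is an assumption.
fileMgr := filemgr.NewManager(filemgr.WithDirectory(t.TempDir())) // option name assumed
err := pomerium.Run(ctx, src, pomerium.WithOverrideFileManager(fileMgr))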

View file

@ -1,7 +1,6 @@
package cryptutil
import (
"context"
"crypto/tls"
"crypto/x509"
"encoding/base64"
@ -15,10 +14,9 @@ import (
// GetCertPool gets a cert pool for the given CA or CAFile.
func GetCertPool(ca, caFile string) (*x509.CertPool, error) {
ctx := context.TODO()
rootCAs, err := x509.SystemCertPool()
if err != nil {
log.Ctx(ctx).Error().Err(err).Msg("pkg/cryptutil: failed getting system cert pool making new one")
log.Error().Err(err).Msg("pkg/cryptutil: failed getting system cert pool making new one")
rootCAs = x509.NewCertPool()
}
if ca == "" && caFile == "" {
@ -40,7 +38,9 @@ func GetCertPool(ca, caFile string) (*x509.CertPool, error) {
if ok := rootCAs.AppendCertsFromPEM(data); !ok {
return nil, fmt.Errorf("failed to append any PEM-encoded certificates")
}
log.Ctx(ctx).Debug().Msg("pkg/cryptutil: added custom certificate authority")
if !log.DebugDisableGlobalMessages.Load() {
log.Debug().Msg("pkg/cryptutil: added custom certificate authority")
}
return rootCAs, nil
}

View file

@ -186,7 +186,25 @@ func (srv *Server) run(ctx context.Context, cfg *config.Config) error {
// monitor the process so we exit if it prematurely exits
var monitorProcessCtx context.Context
monitorProcessCtx, srv.monitorProcessCancel = context.WithCancel(context.WithoutCancel(ctx))
go srv.monitorProcess(monitorProcessCtx, int32(cmd.Process.Pid))
go func() {
pid := cmd.Process.Pid
err := srv.monitorProcess(monitorProcessCtx, int32(pid))
if err != nil && ctx.Err() == nil {
// If the envoy subprocess exits and ctx is not done, issue a fatal error.
// If ctx is done, the server is already exiting, and envoy is expected
// to be stopped along with it.
log.Ctx(ctx).
Fatal().
Int("pid", pid).
Err(err).
Send()
}
log.Ctx(ctx).
Debug().
Int("pid", pid).
Err(ctx.Err()).
Msg("envoy process monitor stopped")
}()
if srv.resourceMonitor != nil {
log.Ctx(ctx).Debug().Str("service", "envoy").Msg("starting resource monitor")
@ -300,7 +318,7 @@ func (srv *Server) handleLogs(ctx context.Context, rc io.ReadCloser) {
}
}
func (srv *Server) monitorProcess(ctx context.Context, pid int32) {
func (srv *Server) monitorProcess(ctx context.Context, pid int32) error {
log.Ctx(ctx).Debug().
Int32("pid", pid).
Msg("envoy: start monitoring subprocess")
@ -311,19 +329,15 @@ func (srv *Server) monitorProcess(ctx context.Context, pid int32) {
for {
exists, err := process.PidExistsWithContext(ctx, pid)
if err != nil {
log.Fatal().Err(err).
Int32("pid", pid).
Msg("envoy: error retrieving subprocess information")
return fmt.Errorf("envoy: error retrieving subprocess information: %w", err)
} else if !exists {
log.Fatal().Err(err).
Int32("pid", pid).
Msg("envoy: subprocess exited")
return errors.New("envoy: subprocess exited")
}
// wait for the next tick
select {
case <-ctx.Done():
return
return nil
case <-ticker.C:
}
}

View file

@ -6,6 +6,7 @@ package envoy
import (
"context"
"os"
"path/filepath"
"strconv"
"sync"
"syscall"
@ -17,7 +18,7 @@ import (
"github.com/pomerium/pomerium/internal/telemetry/metrics"
)
const baseIDPath = "/tmp/pomerium-envoy-base-id"
const baseIDName = "pomerium-envoy-base-id"
var restartEpoch struct {
sync.Mutex
@ -89,7 +90,7 @@ func (srv *Server) prepareRunEnvoyCommand(ctx context.Context, sharedArgs []stri
} else {
args = append(args,
"--use-dynamic-base-id",
"--base-id-path", baseIDPath,
"--base-id-path", filepath.Join(os.TempDir(), baseIDName),
)
restartEpoch.value = 1
}
@ -99,7 +100,7 @@ func (srv *Server) prepareRunEnvoyCommand(ctx context.Context, sharedArgs []stri
}
func readBaseID() (int, bool) {
bs, err := os.ReadFile(baseIDPath)
bs, err := os.ReadFile(filepath.Join(os.TempDir(), baseIDName))
if err != nil {
return 0, false
}

View file

@ -3,6 +3,7 @@ package databroker
import (
"context"
"fmt"
"sync/atomic"
"time"
backoff "github.com/cenkalti/backoff/v4"
@ -71,12 +72,24 @@ type Syncer struct {
id string
}
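// DebugUseFasterBackoff, when set, switches the syncer to a much shorter
// retry schedule; intended for tests.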
var DebugUseFasterBackoff atomic.Bool
// NewSyncer creates a new Syncer.
func NewSyncer(ctx context.Context, id string, handler SyncerHandler, options ...SyncerOption) *Syncer {
closeCtx, closeCtxCancel := context.WithCancel(context.WithoutCancel(ctx))
bo := backoff.NewExponentialBackOff()
bo.MaxElapsedTime = 0
var bo *backoff.ExponentialBackOff
if DebugUseFasterBackoff.Load() {
bo = backoff.NewExponentialBackOff(
backoff.WithInitialInterval(10*time.Millisecond),
backoff.WithMultiplier(1.0),
backoff.WithMaxElapsedTime(100*time.Millisecond),
)
bo.MaxElapsedTime = 0
} else {
bo = backoff.NewExponentialBackOff()
bo.MaxElapsedTime = 0
}
s := &Syncer{
cfg: getSyncerConfig(options...),
handler: handler,