mirror of
https://github.com/pomerium/pomerium.git
synced 2025-08-06 10:21:05 +02:00
Initial test environment implementation
This commit is contained in:
parent
6591e3f539
commit
79ba9fcf52
14 changed files with 1085 additions and 14 deletions
|
@ -18,10 +18,74 @@ import (
|
|||
"google.golang.org/grpc/status"
|
||||
|
||||
"github.com/pomerium/pomerium/config"
|
||||
"github.com/pomerium/pomerium/internal/testenv"
|
||||
"github.com/pomerium/pomerium/internal/testenv/upstreams"
|
||||
"github.com/pomerium/pomerium/internal/testenv/values"
|
||||
"github.com/pomerium/pomerium/pkg/cmd/pomerium"
|
||||
"github.com/pomerium/pomerium/pkg/netutil"
|
||||
)
|
||||
|
||||
func TestH2C_v2(t *testing.T) {
|
||||
env := testenv.New(t)
|
||||
|
||||
up := upstreams.GRPC(insecure.NewCredentials())
|
||||
grpc_testing.RegisterTestServiceServer(up, interop.NewTestServer())
|
||||
|
||||
http := up.Route().
|
||||
From(env.SubdomainURL("grpc-http")).
|
||||
To(values.Bind(up.Port(), func(port int) string {
|
||||
// override the target protocol to use http://
|
||||
return fmt.Sprintf("http://127.0.0.1:%d", port)
|
||||
})).
|
||||
Policy(func(p *config.Policy) { p.AllowPublicUnauthenticatedAccess = true })
|
||||
|
||||
h2c := up.Route().
|
||||
From(env.SubdomainURL("grpc-h2c")).
|
||||
Policy(func(p *config.Policy) { p.AllowPublicUnauthenticatedAccess = true })
|
||||
|
||||
env.AddUpstream(up)
|
||||
env.Start()
|
||||
|
||||
t.Run("h2c", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
recorder := env.NewLogRecorder()
|
||||
|
||||
cc := up.Dial(h2c)
|
||||
client := grpc_testing.NewTestServiceClient(cc)
|
||||
_, err := client.EmptyCall(env.Context(), &grpc_testing.Empty{})
|
||||
require.NoError(t, err)
|
||||
cc.Close()
|
||||
|
||||
recorder.Match([]map[string]any{
|
||||
{
|
||||
"service": "envoy",
|
||||
"path": "/grpc.testing.TestService/EmptyCall",
|
||||
"message": "http-request",
|
||||
"response-code-details": "via_upstream",
|
||||
},
|
||||
})
|
||||
})
|
||||
t.Run("http", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
recorder := env.NewLogRecorder()
|
||||
|
||||
cc := up.Dial(http)
|
||||
client := grpc_testing.NewTestServiceClient(cc)
|
||||
_, err := client.UnaryCall(env.Context(), &grpc_testing.SimpleRequest{})
|
||||
require.Error(t, err)
|
||||
cc.Close()
|
||||
|
||||
recorder.Match([]map[string]any{
|
||||
{
|
||||
"service": "envoy",
|
||||
"path": "/grpc.testing.TestService/UnaryCall",
|
||||
"message": "http-request",
|
||||
"response-code-details": "upstream_reset_before_response_started{protocol_error}",
|
||||
},
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func TestH2C(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.SkipNow()
|
||||
|
|
2
go.mod
2
go.mod
|
@ -50,6 +50,7 @@ require (
|
|||
github.com/peterbourgon/ff/v3 v3.4.0
|
||||
github.com/pomerium/csrf v1.7.0
|
||||
github.com/pomerium/datasource v0.18.2-0.20221108160055-c6134b5ed524
|
||||
github.com/pomerium/protoutil v0.0.0-20240813175624-47b7ac43ff46
|
||||
github.com/pomerium/webauthn v0.0.0-20240603205124-0428df511172
|
||||
github.com/prometheus/client_golang v1.19.1
|
||||
github.com/prometheus/client_model v0.6.1
|
||||
|
@ -177,7 +178,6 @@ require (
|
|||
github.com/morikuni/aec v1.0.0 // indirect
|
||||
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect
|
||||
github.com/onsi/ginkgo v1.16.5 // indirect
|
||||
github.com/onsi/gomega v1.30.0 // indirect
|
||||
github.com/opencontainers/go-digest v1.0.0 // indirect
|
||||
github.com/opencontainers/image-spec v1.1.0 // indirect
|
||||
github.com/opencontainers/runc v1.1.12 // indirect
|
||||
|
|
6
go.sum
6
go.sum
|
@ -481,8 +481,8 @@ github.com/onsi/ginkgo v1.16.5/go.mod h1:+E8gABHa3K6zRBolWtd+ROzc/U5bkGt0FwiG042
|
|||
github.com/onsi/gomega v1.4.3/go.mod h1:ex+gbHU/CVuBBDIJjb2X0qEXbFg53c61hWP/1CpauHY=
|
||||
github.com/onsi/gomega v1.7.1/go.mod h1:XdKZgCCFLUoM/7CFJVPcG8C1xQ1AJ0vpAezJrB7JYyY=
|
||||
github.com/onsi/gomega v1.10.1/go.mod h1:iN09h71vgCQne3DLsj+A5owkum+a2tYe+TOCB1ybHNo=
|
||||
github.com/onsi/gomega v1.30.0 h1:hvMK7xYz4D3HapigLTeGdId/NcfQx1VHMJc60ew99+8=
|
||||
github.com/onsi/gomega v1.30.0/go.mod h1:9sxs+SwGrKI0+PWe4Fxa9tFQQBG5xSsSbMXOI8PPpoQ=
|
||||
github.com/onsi/gomega v1.34.1 h1:EUMJIKUjM8sKjYbtxQI9A4z2o+rruxnzNvpknOXie6k=
|
||||
github.com/onsi/gomega v1.34.1/go.mod h1:kU1QgUvBDLXBJq618Xvm2LUX6rSAfRaFRTcdOeDLwwY=
|
||||
github.com/open-policy-agent/opa v0.67.1 h1:rzy26J6g1X+CKknAcx0Vfbt41KqjuSzx4E0A8DAZf3E=
|
||||
github.com/open-policy-agent/opa v0.67.1/go.mod h1:aqKlHc8E2VAAylYE9x09zJYr/fYzGX+JKne89UGqFzk=
|
||||
github.com/opencontainers/go-digest v1.0.0 h1:apOUWs51W5PlhuyGyz9FCeeBIOUDA/6nW8Oi/yOhh5U=
|
||||
|
@ -522,6 +522,8 @@ github.com/pomerium/csrf v1.7.0 h1:Qp4t6oyEod3svQtKfJZs589mdUTWKVf7q0PgCKYCshY=
|
|||
github.com/pomerium/csrf v1.7.0/go.mod h1:hAPZV47mEj2T9xFs+ysbum4l7SF1IdrryYaY6PdoIqw=
|
||||
github.com/pomerium/datasource v0.18.2-0.20221108160055-c6134b5ed524 h1:3YQY1sb54tEEbr0L73rjHkpLB0IB6qh3zl1+XQbMLis=
|
||||
github.com/pomerium/datasource v0.18.2-0.20221108160055-c6134b5ed524/go.mod h1:7fGbUYJnU8RcxZJvUvhukOIBv1G7LWDAHMfDxAf5+Y0=
|
||||
github.com/pomerium/protoutil v0.0.0-20240813175624-47b7ac43ff46 h1:NRTg8JOXCxcIA1lAgD74iYud0rbshbWOB3Ou4+Huil8=
|
||||
github.com/pomerium/protoutil v0.0.0-20240813175624-47b7ac43ff46/go.mod h1:QqZmx6ZgPxz18va7kqoT4t/0yJtP7YFIDiT/W2n2fZ4=
|
||||
github.com/pomerium/webauthn v0.0.0-20240603205124-0428df511172 h1:TqoPqRgXSHpn+tEJq6H72iCS5pv66j3rPprThUEZg0E=
|
||||
github.com/pomerium/webauthn v0.0.0-20240603205124-0428df511172/go.mod h1:kBQ45E9LluzW7FP1Scn3esaiS2WVbvNRLMOTHareZNQ=
|
||||
github.com/power-devops/perfstat v0.0.0-20240221224432-82ca36839d55 h1:o4JXh1EVt9k/+g42oCprj/FisM4qX9L3sZB3upGN2ZU=
|
||||
|
|
|
@ -69,7 +69,7 @@ type Server struct {
|
|||
}
|
||||
|
||||
// NewServer creates a new Server. Listener ports are chosen by the OS.
|
||||
func NewServer(cfg *config.Config, metricsMgr *config.MetricsManager, eventsMgr *events.Manager) (*Server, error) {
|
||||
func NewServer(ctx context.Context, cfg *config.Config, metricsMgr *config.MetricsManager, eventsMgr *events.Manager) (*Server, error) {
|
||||
srv := &Server{
|
||||
metricsMgr: metricsMgr,
|
||||
EventsMgr: eventsMgr,
|
||||
|
@ -80,6 +80,10 @@ func NewServer(cfg *config.Config, metricsMgr *config.MetricsManager, eventsMgr
|
|||
httpRouter: atomicutil.NewValue(mux.NewRouter()),
|
||||
}
|
||||
|
||||
ctx = log.WithContext(ctx, func(c zerolog.Context) zerolog.Context {
|
||||
return c.Str("server_name", cfg.Options.Services)
|
||||
})
|
||||
|
||||
var err error
|
||||
|
||||
// setup gRPC
|
||||
|
@ -96,7 +100,11 @@ func NewServer(cfg *config.Config, metricsMgr *config.MetricsManager, eventsMgr
|
|||
srv.GRPCServer = grpc.NewServer(
|
||||
grpc.StatsHandler(telemetry.NewGRPCServerStatsHandler(cfg.Options.Services)),
|
||||
grpc.ChainUnaryInterceptor(requestid.UnaryServerInterceptor(), ui),
|
||||
grpc.ChainStreamInterceptor(requestid.StreamServerInterceptor(), si),
|
||||
grpc.ChainStreamInterceptor(
|
||||
log.StreamServerInterceptor(log.Ctx(ctx)),
|
||||
requestid.StreamServerInterceptor(),
|
||||
si,
|
||||
),
|
||||
)
|
||||
reflection.Register(srv.GRPCServer)
|
||||
srv.registerAccessLogHandlers()
|
||||
|
@ -152,10 +160,6 @@ func NewServer(cfg *config.Config, metricsMgr *config.MetricsManager, eventsMgr
|
|||
srv.reproxy,
|
||||
)
|
||||
|
||||
ctx := log.WithContext(context.Background(), func(c zerolog.Context) zerolog.Context {
|
||||
return c.Str("server_name", cfg.Options.Services)
|
||||
})
|
||||
|
||||
res, err := srv.buildDiscoveryResources(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
|
|
@ -38,7 +38,7 @@ func TestServerHTTP(t *testing.T) {
|
|||
cfg.Options.SharedKey = "JDNjY2ITDlARvNaQXjc2Djk+GA6xeCy4KiozmZfdbTs="
|
||||
|
||||
src := config.NewStaticSource(cfg)
|
||||
srv, err := NewServer(cfg, config.NewMetricsManager(ctx, src), events.New())
|
||||
srv, err := NewServer(ctx, cfg, config.NewMetricsManager(ctx, src), events.New())
|
||||
require.NoError(t, err)
|
||||
go srv.Run(ctx)
|
||||
|
||||
|
|
|
@ -96,7 +96,7 @@ func (src *ConfigSource) rebuild(ctx context.Context, firstTime firstTime) {
|
|||
cfg := src.underlyingConfig.Clone()
|
||||
|
||||
// start the updater
|
||||
src.runUpdater(cfg)
|
||||
src.runUpdater(ctx, cfg)
|
||||
|
||||
now = time.Now()
|
||||
err := src.buildNewConfigLocked(ctx, cfg)
|
||||
|
@ -234,7 +234,7 @@ func (src *ConfigSource) addPolicies(ctx context.Context, cfg *config.Config, po
|
|||
cfg.Options.AdditionalPolicies = append(cfg.Options.AdditionalPolicies, additionalPolicies...)
|
||||
}
|
||||
|
||||
func (src *ConfigSource) runUpdater(cfg *config.Config) {
|
||||
func (src *ConfigSource) runUpdater(ctx context.Context, cfg *config.Config) {
|
||||
sharedKey, _ := cfg.Options.GetSharedKey()
|
||||
connectionOptions := &grpc.OutboundOptions{
|
||||
OutboundPort: cfg.OutboundPort,
|
||||
|
@ -257,7 +257,6 @@ func (src *ConfigSource) runUpdater(cfg *config.Config) {
|
|||
src.cancel = nil
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
ctx, src.cancel = context.WithCancel(ctx)
|
||||
|
||||
cc, err := src.outboundGRPCConnection.Get(ctx, connectionOptions)
|
||||
|
|
|
@ -5,7 +5,9 @@ import (
|
|||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/pomerium/protoutil/streams"
|
||||
"github.com/rs/zerolog"
|
||||
"google.golang.org/grpc"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/middleware/responsewriter"
|
||||
"github.com/pomerium/pomerium/pkg/telemetry/requestid"
|
||||
|
@ -121,3 +123,11 @@ func HeadersHandler(headers []string) func(next http.Handler) http.Handler {
|
|||
})
|
||||
}
|
||||
}
|
||||
|
||||
func StreamServerInterceptor(lg *zerolog.Logger) grpc.StreamServerInterceptor {
|
||||
return func(srv any, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
|
||||
s := streams.NewServerStreamWithContext(ss)
|
||||
s.SetContext(lg.WithContext(s.Ctx))
|
||||
return handler(srv, s)
|
||||
}
|
||||
}
|
||||
|
|
329
internal/testenv/environment.go
Normal file
329
internal/testenv/environment.go
Normal file
|
@ -0,0 +1,329 @@
|
|||
package testenv
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"path"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/pomerium/pomerium/config"
|
||||
"github.com/pomerium/pomerium/internal/log"
|
||||
"github.com/pomerium/pomerium/internal/testenv/values"
|
||||
"github.com/pomerium/pomerium/pkg/cmd/pomerium"
|
||||
"github.com/pomerium/pomerium/pkg/health"
|
||||
"github.com/pomerium/pomerium/pkg/netutil"
|
||||
"github.com/rs/zerolog"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/sync/errgroup"
|
||||
)
|
||||
|
||||
// Environment is a lightweight integration test fixture that runs Pomerium
|
||||
// in-process.
|
||||
type Environment interface {
|
||||
// Context returns the environment's root context. This context holds a
|
||||
// top-level logger scoped to this environment. It will be canceled when
|
||||
// Stop() is called, or during test cleanup.
|
||||
Context() context.Context
|
||||
// TempDir returns a unique temp directory for this context. Calling this
|
||||
// function multiple times returns the same path.
|
||||
TempDir() string
|
||||
// ServerCAs returns a new [*x509.CertPool] containing the root CA certificate
|
||||
// used to sign the server cert and other test certificates.
|
||||
ServerCAs() *x509.CertPool
|
||||
// ServerCert returns the Pomerium server's certificate.
|
||||
ServerCert() *tls.Certificate
|
||||
|
||||
// Add adds the given [Modifier] to the environment. All modifiers will be
|
||||
// invoked upon calling Start() to apply individual modifications to the
|
||||
// configuration before starting the Pomerium server.
|
||||
Add(m Modifier)
|
||||
// AddTask adds the given [Task] to the environment. All tasks will be
|
||||
// started in separate goroutines upon calling Start(). If any tasks exit
|
||||
// with an error, the environment will be stopped and the test will fail.
|
||||
AddTask(r Task)
|
||||
// AddUpstream adds the given [Upstream] to the environment. This function is
|
||||
// equivalent to calling both Add() and AddTask() with the upstream, but
|
||||
// improves readability.
|
||||
AddUpstream(u Upstream)
|
||||
|
||||
// Start starts the test environment, and adds a call to Stop() as a cleanup
|
||||
// hook to the environment's [testing.T]. All previously added [Modifier]
|
||||
// instances are invoked in order to build the configuration, and all
|
||||
// previously added [Task] instances are started in the background.
|
||||
//
|
||||
// Calling Start() more than once, Calling Start() after Stop(), or calling
|
||||
// any of the Add* functions after Start() will panic.
|
||||
Start()
|
||||
// Stop stops the test environment. Calling this function more than once has
|
||||
// no effect. It is usually not necessary to call Stop() directly unless you
|
||||
// need to stop the test environment before the test is completed.
|
||||
Stop()
|
||||
|
||||
// SubdomainURL returns a string [values.Value] which will contain a complete
|
||||
// URL for the given subdomain of the server's domain (given by its serving
|
||||
// certificate), including the 'https://' scheme and random http server port.
|
||||
// This value will only be resolved some time after Start() is called, and
|
||||
// can be used as the 'from' value for routes.
|
||||
SubdomainURL(subdomain string) values.Value[string]
|
||||
|
||||
// NewLogRecorder returns a new [*LogRecorder] and starts capturing logs for
|
||||
// the Pomerium server and Envoy.
|
||||
NewLogRecorder(opts ...LogRecorderOption) *LogRecorder
|
||||
}
|
||||
|
||||
type environment struct {
|
||||
t testing.TB
|
||||
tempDir string
|
||||
domain string
|
||||
ports Ports
|
||||
workspaceFolder string
|
||||
|
||||
ctx context.Context
|
||||
cancel context.CancelCauseFunc
|
||||
cleanupOnce sync.Once
|
||||
logWriter *log.MultiWriter
|
||||
|
||||
mods []WithCaller[Modifier]
|
||||
tasks []WithCaller[Task]
|
||||
taskErrGroup *errgroup.Group
|
||||
}
|
||||
|
||||
func New(t testing.TB) Environment {
|
||||
if testing.Short() {
|
||||
t.Helper()
|
||||
t.Skip("test environment disabled in short mode")
|
||||
}
|
||||
workspaceFolder, err := os.Getwd()
|
||||
require.NoError(t, err)
|
||||
for {
|
||||
if _, err := os.Stat(filepath.Join(workspaceFolder, ".git")); err == nil {
|
||||
break
|
||||
}
|
||||
workspaceFolder = filepath.Dir(workspaceFolder)
|
||||
if workspaceFolder == "/" {
|
||||
panic("could not find workspace root")
|
||||
}
|
||||
}
|
||||
workspaceFolder, err = filepath.Abs(workspaceFolder)
|
||||
require.NoError(t, err)
|
||||
|
||||
writer := log.NewMultiWriter()
|
||||
writer.Add(os.Stdout)
|
||||
logger := zerolog.New(writer).With().Timestamp().Logger().Level(zerolog.DebugLevel)
|
||||
|
||||
ctx, cancel := context.WithCancelCause(logger.WithContext(context.Background()))
|
||||
taskErrGroup, ctx := errgroup.WithContext(ctx)
|
||||
|
||||
e := &environment{
|
||||
t: t,
|
||||
tempDir: t.TempDir(),
|
||||
ports: Ports{
|
||||
http: values.Deferred[int](),
|
||||
},
|
||||
workspaceFolder: workspaceFolder,
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
logWriter: writer,
|
||||
taskErrGroup: taskErrGroup,
|
||||
}
|
||||
health.SetProvider(e)
|
||||
|
||||
require.NoError(t, os.Mkdir(filepath.Join(e.tempDir, "certs"), 0o777))
|
||||
copyFile := func(src, dstRel string) {
|
||||
data, err := os.ReadFile(src)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, os.WriteFile(filepath.Join(e.tempDir, dstRel), data, 0o666))
|
||||
}
|
||||
|
||||
certsToCopy := []string{
|
||||
"trusted.pem",
|
||||
"trusted-key.pem",
|
||||
"ca.pem",
|
||||
}
|
||||
for _, crt := range certsToCopy {
|
||||
copyFile(filepath.Join(workspaceFolder, "integration/tpl/files", crt), filepath.Join("certs/", filepath.Base(crt)))
|
||||
}
|
||||
e.domain = wildcardDomain(e.ServerCert().Leaf.DNSNames)
|
||||
|
||||
return e
|
||||
}
|
||||
|
||||
type WithCaller[T any] struct {
|
||||
Caller string
|
||||
Value T
|
||||
}
|
||||
|
||||
type Ports struct {
|
||||
http values.MutableValue[int]
|
||||
}
|
||||
|
||||
func (e *environment) TempDir() string {
|
||||
return e.tempDir
|
||||
}
|
||||
|
||||
func (e *environment) Context() context.Context {
|
||||
return ContextWithEnv(e.ctx, e)
|
||||
}
|
||||
|
||||
func (e *environment) SubdomainURL(subdomain string) values.Value[string] {
|
||||
return values.Bind(e.ports.http, func(port int) string {
|
||||
return fmt.Sprintf("https://%s.%s:%d", subdomain, e.domain, port)
|
||||
})
|
||||
}
|
||||
|
||||
func (e *environment) ServerCAs() *x509.CertPool {
|
||||
pool := x509.NewCertPool()
|
||||
caCert, err := os.ReadFile(filepath.Join(e.tempDir, "certs", "ca.pem"))
|
||||
require.NoError(e.t, err)
|
||||
pool.AppendCertsFromPEM(caCert)
|
||||
return pool
|
||||
}
|
||||
|
||||
func (e *environment) ServerCert() *tls.Certificate {
|
||||
serverCert, err := tls.LoadX509KeyPair(
|
||||
filepath.Join(e.tempDir, "certs", "trusted.pem"),
|
||||
filepath.Join(e.tempDir, "certs", "trusted-key.pem"),
|
||||
)
|
||||
require.NoError(e.t, err)
|
||||
return &serverCert
|
||||
}
|
||||
|
||||
// Used as the context's cancel cause during normal cleanup
|
||||
var ErrCauseTestCleanup = errors.New("test cleanup")
|
||||
|
||||
// Used as the context's cancel cause when Stop() is called
|
||||
var ErrCauseManualStop = errors.New("Stop() called")
|
||||
|
||||
func (e *environment) Start() {
|
||||
e.t.Cleanup(e.cleanup)
|
||||
|
||||
cfg := &config.Config{
|
||||
Options: config.NewDefaultOptions(),
|
||||
}
|
||||
ports, err := netutil.AllocatePorts(7)
|
||||
require.NoError(e.t, err)
|
||||
port0, _ := strconv.Atoi(ports[0])
|
||||
e.ports.http.Resolve(port0)
|
||||
cfg.Options.LogLevel = config.LogLevelInfo
|
||||
cfg.Options.ProxyLogLevel = config.LogLevelInfo
|
||||
cfg.Options.Addr = fmt.Sprintf("127.0.0.1:%d", port0)
|
||||
cfg.Options.CertFile = filepath.Join(e.tempDir, "certs", "trusted.pem")
|
||||
cfg.Options.KeyFile = filepath.Join(e.tempDir, "certs", "trusted-key.pem")
|
||||
cfg.AllocatePorts(*(*[6]string)(ports[1:]))
|
||||
|
||||
e.AddTask(TaskFunc(func(ctx context.Context) error {
|
||||
src := config.NewStaticSource(cfg)
|
||||
for _, mod := range e.mods {
|
||||
mod.Value.Modify(cfg)
|
||||
require.NoError(e.t, cfg.Options.Validate(), "invoking modifier resulted in an invalid configuration:\nadded by: "+mod.Caller)
|
||||
}
|
||||
return pomerium.Run(e.ctx, src)
|
||||
}))
|
||||
|
||||
for i, task := range e.tasks {
|
||||
log.Ctx(e.ctx).Debug().Str("caller", task.Caller).Msgf("starting task %d", i)
|
||||
e.taskErrGroup.Go(func() error {
|
||||
defer log.Ctx(e.ctx).Debug().Str("caller", task.Caller).Msgf("task %d exited", i)
|
||||
return task.Value.Run(e.ctx)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (e *environment) Stop() {
|
||||
e.cleanupOnce.Do(func() {
|
||||
e.cancel(ErrCauseManualStop)
|
||||
err := e.taskErrGroup.Wait()
|
||||
assert.ErrorIs(e.t, err, ErrCauseManualStop)
|
||||
})
|
||||
}
|
||||
|
||||
func (e *environment) cleanup() {
|
||||
e.cleanupOnce.Do(func() {
|
||||
e.cancel(ErrCauseTestCleanup)
|
||||
err := e.taskErrGroup.Wait()
|
||||
assert.ErrorIs(e.t, err, ErrCauseTestCleanup)
|
||||
})
|
||||
}
|
||||
|
||||
func (e *environment) Add(c Modifier) {
|
||||
e.t.Helper()
|
||||
for _, mod := range e.mods {
|
||||
if mod.Value == c {
|
||||
e.t.Fatalf("test bug: duplicate modifier added\nfirst added by: %s", mod.Caller)
|
||||
}
|
||||
}
|
||||
e.mods = append(e.mods, WithCaller[Modifier]{
|
||||
Caller: getCaller(),
|
||||
Value: c,
|
||||
})
|
||||
c.Attach(e.Context())
|
||||
}
|
||||
|
||||
func (e *environment) AddTask(r Task) {
|
||||
e.t.Helper()
|
||||
for _, task := range e.tasks {
|
||||
if task.Value == r {
|
||||
e.t.Fatalf("test bug: duplicate task added\nfirst added by: %s", task.Caller)
|
||||
}
|
||||
}
|
||||
e.tasks = append(e.tasks, WithCaller[Task]{
|
||||
Caller: getCaller(),
|
||||
Value: r,
|
||||
})
|
||||
}
|
||||
|
||||
func (e *environment) AddUpstream(up Upstream) {
|
||||
e.t.Helper()
|
||||
e.Add(up)
|
||||
e.AddTask(up)
|
||||
}
|
||||
|
||||
// ReportError implements health.Provider.
|
||||
func (e *environment) ReportError(check health.Check, err error, attributes ...health.Attr) {
|
||||
// note: don't use e.t.Fatal here, it will deadlock
|
||||
panic(fmt.Sprintf("%s: %v %v", check, err, attributes))
|
||||
}
|
||||
|
||||
// ReportOK implements health.Provider.
|
||||
func (e *environment) ReportOK(check health.Check, attributes ...health.Attr) {
|
||||
}
|
||||
|
||||
func getCaller(skip ...int) string {
|
||||
if len(skip) == 0 {
|
||||
skip = append(skip, 3)
|
||||
}
|
||||
callers := make([]uintptr, 8)
|
||||
runtime.Callers(skip[0], callers)
|
||||
frames := runtime.CallersFrames(callers)
|
||||
var caller string
|
||||
for {
|
||||
next, ok := frames.Next()
|
||||
if !ok {
|
||||
break
|
||||
}
|
||||
if path.Base(next.Function) == "testenv.(*environment).AddUpstream" {
|
||||
continue
|
||||
}
|
||||
caller = fmt.Sprintf("%s:%d", next.File, next.Line)
|
||||
break
|
||||
}
|
||||
return caller
|
||||
}
|
||||
|
||||
func wildcardDomain(names []string) string {
|
||||
for _, name := range names {
|
||||
if name[0] == '*' {
|
||||
return name[2:]
|
||||
}
|
||||
}
|
||||
panic("test bug: no wildcard domain in certificate")
|
||||
}
|
210
internal/testenv/logs.go
Normal file
210
internal/testenv/logs.go
Normal file
|
@ -0,0 +1,210 @@
|
|||
package testenv
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// LogRecorder captures logs from the test environment. It can be created at
|
||||
// any time by calling [Environment.NewLogRecorder], and captures logs until
|
||||
// one of Close(), Logs(), or Match() is called, which stops recording. See the
|
||||
// documentation for each method for more details.
|
||||
type LogRecorder struct {
|
||||
LogRecorderOptions
|
||||
t testing.TB
|
||||
buf *bytes.Buffer
|
||||
recordedLogs []map[string]any
|
||||
|
||||
closeOnce func()
|
||||
collectLogsOnce sync.Once
|
||||
}
|
||||
|
||||
type LogRecorderOptions struct {
|
||||
filters []func(map[string]any) bool
|
||||
}
|
||||
|
||||
type LogRecorderOption func(*LogRecorderOptions)
|
||||
|
||||
func (o *LogRecorderOptions) apply(opts ...LogRecorderOption) {
|
||||
for _, op := range opts {
|
||||
op(o)
|
||||
}
|
||||
}
|
||||
|
||||
// WithFilters applies one or more filter predicates to the logger. If there
|
||||
// are filters present, they will be called in order when a log is received,
|
||||
// and if any filter returns false for a given log, it will be discarded.
|
||||
func WithFilters(filters ...func(map[string]any) bool) LogRecorderOption {
|
||||
return func(o *LogRecorderOptions) {
|
||||
o.filters = filters
|
||||
}
|
||||
}
|
||||
|
||||
func (e *environment) NewLogRecorder(opts ...LogRecorderOption) *LogRecorder {
|
||||
options := LogRecorderOptions{}
|
||||
options.apply(opts...)
|
||||
lr := &LogRecorder{
|
||||
LogRecorderOptions: options,
|
||||
t: e.t,
|
||||
buf: &bytes.Buffer{},
|
||||
}
|
||||
e.logWriter.Add(lr.buf)
|
||||
lr.closeOnce = sync.OnceFunc(func() {
|
||||
// wait for envoy access logs, which flush on a 1 second interval
|
||||
time.Sleep(1100 * time.Millisecond)
|
||||
e.logWriter.Remove(lr.buf)
|
||||
})
|
||||
context.AfterFunc(e.ctx, lr.closeOnce)
|
||||
return lr
|
||||
}
|
||||
|
||||
type (
|
||||
// OpenMap is an alias for map[string]any, and can be used to semantically
|
||||
// represent a map that must contain at least the given entries, but may
|
||||
// also contain additional entries.
|
||||
OpenMap = map[string]any
|
||||
// ClosedMap is a map[string]any that can be used to semantically represent
|
||||
// a map that must contain the given entries exactly, and no others.
|
||||
ClosedMap map[string]any
|
||||
)
|
||||
|
||||
// Close stops the log recorder. After calling this method, Logs() or Match()
|
||||
// can be called to inspect the logs that were captured.
|
||||
func (lr *LogRecorder) Close() {
|
||||
lr.closeOnce()
|
||||
}
|
||||
|
||||
func (lr *LogRecorder) collectLogs() {
|
||||
lr.closeOnce()
|
||||
lr.collectLogsOnce.Do(func() {
|
||||
recordedLogs := []map[string]any{}
|
||||
scan := bufio.NewScanner(lr.buf)
|
||||
for scan.Scan() {
|
||||
log := scan.Bytes()
|
||||
m := map[string]any{}
|
||||
decoder := json.NewDecoder(bytes.NewReader(log))
|
||||
decoder.UseNumber()
|
||||
require.NoError(lr.t, decoder.Decode(&m))
|
||||
for _, filter := range lr.filters {
|
||||
if !filter(m) {
|
||||
continue
|
||||
}
|
||||
}
|
||||
recordedLogs = append(recordedLogs, m)
|
||||
}
|
||||
lr.recordedLogs = recordedLogs
|
||||
})
|
||||
}
|
||||
|
||||
// Logs stops the log recorder (if it is not already stopped), then returns
|
||||
// the logs that were captured as structured map[string]any objects.
|
||||
func (lr *LogRecorder) Logs() []map[string]any {
|
||||
lr.collectLogs()
|
||||
return lr.recordedLogs
|
||||
}
|
||||
|
||||
// Match stops the log recorder (if it is not already stopped), then asserts
|
||||
// that the given expected logs were captured. The expected logs may contain
|
||||
// partial or complete log entries. By default, logs must only match the fields
|
||||
// given, and may contain additional fields that will be ignored. For details,
|
||||
// see [OpenMap] and [ClosedMap]. As a special case, using [json.Number] as the
|
||||
// expected value will convert the actual value to a string before comparison.
|
||||
func (lr *LogRecorder) Match(expectedLogs []map[string]any) {
|
||||
lr.collectLogs()
|
||||
var match func(expected, actual map[string]any, open bool) (bool, int)
|
||||
match = func(expected, actual map[string]any, open bool) (bool, int) {
|
||||
score := 0
|
||||
for key, value := range expected {
|
||||
actualValue, ok := actual[key]
|
||||
if !ok {
|
||||
return false, score
|
||||
}
|
||||
score++
|
||||
|
||||
switch actualValue := actualValue.(type) {
|
||||
case map[string]any:
|
||||
switch value := value.(type) {
|
||||
case ClosedMap:
|
||||
ok, s := match(value, actualValue, false)
|
||||
score += s * 2
|
||||
if !ok {
|
||||
return false, score
|
||||
}
|
||||
case OpenMap:
|
||||
ok, s := match(value, actualValue, true)
|
||||
score += s
|
||||
if !ok {
|
||||
return false, score
|
||||
}
|
||||
default:
|
||||
return false, score
|
||||
}
|
||||
case string:
|
||||
switch value := value.(type) {
|
||||
case string:
|
||||
if value != actualValue {
|
||||
return false, score
|
||||
}
|
||||
score++
|
||||
default:
|
||||
return false, score
|
||||
}
|
||||
case json.Number:
|
||||
if fmt.Sprint(value) != actualValue.String() {
|
||||
return false, score
|
||||
}
|
||||
score++
|
||||
default:
|
||||
panic(fmt.Sprintf("test bug: add check for type %T in assertMatchingLogs", actualValue))
|
||||
}
|
||||
}
|
||||
if !open && len(expected) != len(actual) {
|
||||
return false, score
|
||||
}
|
||||
return true, score
|
||||
}
|
||||
|
||||
for _, expectedLog := range expectedLogs {
|
||||
found := false
|
||||
|
||||
highScore, highScoreIdxs := 0, []int{}
|
||||
for i, actualLog := range lr.recordedLogs {
|
||||
if ok, score := match(expectedLog, actualLog, true); ok {
|
||||
found = true
|
||||
break
|
||||
} else if score > highScore {
|
||||
highScore = score
|
||||
highScoreIdxs = []int{i}
|
||||
} else if score == highScore {
|
||||
highScoreIdxs = append(highScoreIdxs, i)
|
||||
}
|
||||
}
|
||||
if len(highScoreIdxs) > 0 {
|
||||
expectedLogBytes, _ := json.MarshalIndent(expectedLog, "", " ")
|
||||
if len(highScoreIdxs) == 1 {
|
||||
actualLogBytes, _ := json.MarshalIndent(lr.recordedLogs[highScoreIdxs[0]], "", " ")
|
||||
assert.True(lr.t, found, "expected log not found: \n%s\n\nclosest match:\n%s\n",
|
||||
string(expectedLogBytes), string(actualLogBytes))
|
||||
} else {
|
||||
closestMatches := []string{}
|
||||
for _, i := range highScoreIdxs {
|
||||
bytes, _ := json.MarshalIndent(lr.recordedLogs[i], "", " ")
|
||||
closestMatches = append(closestMatches, string(bytes))
|
||||
}
|
||||
assert.True(lr.t, found, "expected log not found: \n%s\n\nclosest matches:\n%s\n", string(expectedLogBytes), closestMatches)
|
||||
}
|
||||
} else {
|
||||
expectedLogBytes, _ := json.MarshalIndent(expectedLog, "", " ")
|
||||
assert.True(lr.t, found, "expected log not found: %s", string(expectedLogBytes))
|
||||
}
|
||||
}
|
||||
}
|
60
internal/testenv/route.go
Normal file
60
internal/testenv/route.go
Normal file
|
@ -0,0 +1,60 @@
|
|||
package testenv
|
||||
|
||||
import (
|
||||
"net/url"
|
||||
|
||||
"github.com/pomerium/pomerium/config"
|
||||
"github.com/pomerium/pomerium/internal/testenv/values"
|
||||
)
|
||||
|
||||
// PolicyRoute is a [Route] implementation suitable for most common use cases
|
||||
// that can be used in implementations of [Upstream].
|
||||
type PolicyRoute struct {
|
||||
DefaultAttach
|
||||
from values.Value[string]
|
||||
to values.List[string]
|
||||
edits []func(*config.Policy)
|
||||
}
|
||||
|
||||
// Modify implements Route.
|
||||
func (b *PolicyRoute) Modify(cfg *config.Config) {
|
||||
to := make(config.WeightedURLs, 0, len(b.to))
|
||||
for _, u := range b.to {
|
||||
u, err := url.Parse(u.Value())
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
to = append(to, config.WeightedURL{URL: *u})
|
||||
}
|
||||
p := config.Policy{
|
||||
From: b.from.Value(),
|
||||
To: to,
|
||||
}
|
||||
for _, edit := range b.edits {
|
||||
edit(&p)
|
||||
}
|
||||
cfg.Options.Policies = append(cfg.Options.Policies, p)
|
||||
}
|
||||
|
||||
// From implements Route.
|
||||
func (b *PolicyRoute) From(fromUrl values.Value[string]) Route {
|
||||
b.from = fromUrl
|
||||
return b
|
||||
}
|
||||
|
||||
// To implements Route.
|
||||
func (b *PolicyRoute) To(toUrl values.Value[string]) Route {
|
||||
b.to = append(b.to, toUrl)
|
||||
return b
|
||||
}
|
||||
|
||||
// To implements Route.
|
||||
func (b *PolicyRoute) Policy(edit func(*config.Policy)) Route {
|
||||
b.edits = append(b.edits, edit)
|
||||
return b
|
||||
}
|
||||
|
||||
// To implements Route.
|
||||
func (b *PolicyRoute) URL() values.Value[string] {
|
||||
return b.from
|
||||
}
|
127
internal/testenv/types.go
Normal file
127
internal/testenv/types.go
Normal file
|
@ -0,0 +1,127 @@
|
|||
package testenv
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/pomerium/pomerium/config"
|
||||
"github.com/pomerium/pomerium/internal/testenv/values"
|
||||
)
|
||||
|
||||
type envContextKeyType struct{}
|
||||
|
||||
var envContextKey envContextKeyType
|
||||
|
||||
func EnvFromContext(ctx context.Context) Environment {
|
||||
return ctx.Value(envContextKey).(Environment)
|
||||
}
|
||||
|
||||
func ContextWithEnv(ctx context.Context, env Environment) context.Context {
|
||||
return context.WithValue(ctx, envContextKey, env)
|
||||
}
|
||||
|
||||
// A Modifier is an object whose presence in the test affects the Pomerium
|
||||
// configuration in some way. When the test environment is started, a
|
||||
// [*config.Config] is constructed by calling each added Modifier in order.
|
||||
//
|
||||
// For additional details, see [Environment.Add] and [Environment.Start].
|
||||
type Modifier interface {
|
||||
// Attach is called by an [Environment] (before Modify) to propagate the
|
||||
// environment's context.
|
||||
Attach(ctx context.Context)
|
||||
|
||||
// Modify is called by an [Environment] to mutate its configuration in some
|
||||
// way required by this Modifier.
|
||||
Modify(cfg *config.Config)
|
||||
}
|
||||
|
||||
// DefaultAttach should be embedded in types implementing [Modifier] to
|
||||
// automatically obtain environment context details and caller information.
|
||||
type DefaultAttach struct {
|
||||
env Environment
|
||||
caller string
|
||||
}
|
||||
|
||||
func (d *DefaultAttach) Env() Environment {
|
||||
d.CheckAttached()
|
||||
return d.env
|
||||
}
|
||||
|
||||
func (d *DefaultAttach) Attach(ctx context.Context) {
|
||||
if d.env != nil {
|
||||
panic("internal test environment bug: Attach called twice")
|
||||
}
|
||||
d.env = EnvFromContext(ctx)
|
||||
if d.env == nil {
|
||||
panic("test bug: no environment in context")
|
||||
}
|
||||
}
|
||||
|
||||
func (d *DefaultAttach) CheckAttached() {
|
||||
if d.env == nil {
|
||||
if d.caller != "" {
|
||||
panic("test bug: missing a call to Add for the object created at: " + d.caller)
|
||||
}
|
||||
panic("test bug: not attached (possibly missing a call to Add)")
|
||||
}
|
||||
}
|
||||
|
||||
func (d *DefaultAttach) RecordCaller() {
|
||||
d.caller = getCaller(4)
|
||||
}
|
||||
|
||||
type Modifiers []Modifier
|
||||
|
||||
func (m Modifiers) Modify(cfg *config.Config) {
|
||||
for _, mod := range m {
|
||||
mod.Modify(cfg)
|
||||
}
|
||||
}
|
||||
|
||||
type ModifierFunc func(cfg *config.Config)
|
||||
|
||||
func (f ModifierFunc) Modify(cfg *config.Config) {
|
||||
f(cfg)
|
||||
}
|
||||
|
||||
// Task represents a background task that can be added to an [Environment] to
|
||||
// have it run automatically on startup.
|
||||
//
|
||||
// For additional details, see [Environment.AddTask] and [Environment.Start].
|
||||
type Task interface {
|
||||
Run(ctx context.Context) error
|
||||
}
|
||||
|
||||
type TaskFunc func(ctx context.Context) error
|
||||
|
||||
func (f TaskFunc) Run(ctx context.Context) error {
|
||||
return f(ctx)
|
||||
}
|
||||
|
||||
// Upstream represents an upstream server. It is both a [Task] and a [Modifier]
|
||||
// and can be added to an environment using [Environment.AddUpstream]. From an
|
||||
// Upstream instance, new routes can be created (which automatically adds the
|
||||
// necessary route/policy entries to the config), and used within a test to
|
||||
// easily make requests to the routes with implementation-specific clients.
|
||||
type Upstream interface {
|
||||
Modifier
|
||||
Task
|
||||
Port() values.Value[int]
|
||||
Route() RouteStub
|
||||
}
|
||||
|
||||
// A Route represents a route from a source URL to a destination URL. A route is
|
||||
// typically created by calling [Upstream.Route].
|
||||
type Route interface {
|
||||
Modifier
|
||||
URL() values.Value[string]
|
||||
To(toUrl values.Value[string]) Route
|
||||
Policy(edit func(*config.Policy)) Route
|
||||
// add more methods here as they become needed
|
||||
}
|
||||
|
||||
// RouteStub represents an incomplete [Route]. Providing a URL by calling its
|
||||
// From() method will return a [Route], from which further configuration can
|
||||
// be made.
|
||||
type RouteStub interface {
|
||||
From(fromUrl values.Value[string]) Route
|
||||
}
|
146
internal/testenv/upstreams/grpc.go
Normal file
146
internal/testenv/upstreams/grpc.go
Normal file
|
@ -0,0 +1,146 @@
|
|||
package upstreams
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"strings"
|
||||
|
||||
"github.com/pomerium/pomerium/config"
|
||||
"github.com/pomerium/pomerium/internal/testenv"
|
||||
"github.com/pomerium/pomerium/internal/testenv/values"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/credentials"
|
||||
)
|
||||
|
||||
type Options struct {
|
||||
serverOpts []grpc.ServerOption
|
||||
}
|
||||
|
||||
type Option func(*Options)
|
||||
|
||||
func (o *Options) apply(opts ...Option) {
|
||||
for _, op := range opts {
|
||||
op(o)
|
||||
}
|
||||
}
|
||||
|
||||
func ServerOpts(opt ...grpc.ServerOption) Option {
|
||||
return func(o *Options) {
|
||||
o.serverOpts = append(o.serverOpts, opt...)
|
||||
}
|
||||
}
|
||||
|
||||
// GRPCUpstream represents a GRPC server which can be used as the target for
|
||||
// one or more Pomerium routes in a test environment.
|
||||
//
|
||||
// This upstream implements [grpc.ServiceRegistrar], and can be used similarly
|
||||
// in the same way as [*grpc.Server] to register services before it is started.
|
||||
//
|
||||
// Any [testenv.Route] instances created from this upstream can be referenced
|
||||
// in the Dial() method to establish a connection to that route.
|
||||
type GRPCUpstream interface {
|
||||
testenv.Upstream
|
||||
grpc.ServiceRegistrar
|
||||
Dial(r testenv.Route, dialOpts ...grpc.DialOption) *grpc.ClientConn
|
||||
}
|
||||
|
||||
type grpcUpstream struct {
|
||||
Options
|
||||
testenv.DefaultAttach
|
||||
serverPort values.MutableValue[int]
|
||||
creds credentials.TransportCredentials
|
||||
|
||||
routes testenv.Modifiers
|
||||
services []service
|
||||
}
|
||||
|
||||
var (
|
||||
_ testenv.Upstream = (*grpcUpstream)(nil)
|
||||
_ grpc.ServiceRegistrar = (*grpcUpstream)(nil)
|
||||
)
|
||||
|
||||
// GRPC Creates a new GRPC upstream server.
|
||||
func GRPC(creds credentials.TransportCredentials, opts ...Option) GRPCUpstream {
|
||||
options := Options{}
|
||||
options.apply(opts...)
|
||||
up := &grpcUpstream{
|
||||
Options: options,
|
||||
creds: creds,
|
||||
serverPort: values.Deferred[int](),
|
||||
}
|
||||
up.RecordCaller()
|
||||
return up
|
||||
}
|
||||
|
||||
type service struct {
|
||||
desc *grpc.ServiceDesc
|
||||
impl any
|
||||
}
|
||||
|
||||
func (g *grpcUpstream) Port() values.Value[int] {
|
||||
return g.serverPort
|
||||
}
|
||||
|
||||
// Modify implements testenv.Upstream.
|
||||
func (g *grpcUpstream) Modify(cfg *config.Config) {
|
||||
g.routes.Modify(cfg)
|
||||
}
|
||||
|
||||
// RegisterService implements grpc.ServiceRegistrar.
|
||||
func (g *grpcUpstream) RegisterService(desc *grpc.ServiceDesc, impl any) {
|
||||
g.services = append(g.services, service{desc, impl})
|
||||
}
|
||||
|
||||
// Route implements testenv.Upstream.
|
||||
func (g *grpcUpstream) Route() testenv.RouteStub {
|
||||
r := &testenv.PolicyRoute{}
|
||||
var protocol string
|
||||
switch g.creds.Info().SecurityProtocol {
|
||||
case "insecure":
|
||||
protocol = "h2c"
|
||||
default:
|
||||
protocol = "https"
|
||||
}
|
||||
r.To(values.Bind(g.serverPort, func(port int) string {
|
||||
return fmt.Sprintf("%s://127.0.0.1:%d", protocol, port)
|
||||
}))
|
||||
g.routes = append(g.routes, r)
|
||||
return r
|
||||
}
|
||||
|
||||
// Start implements testenv.Upstream.
|
||||
func (g *grpcUpstream) Run(ctx context.Context) error {
|
||||
listener, err := net.Listen("tcp4", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
g.serverPort.Resolve(listener.Addr().(*net.TCPAddr).Port)
|
||||
server := grpc.NewServer(append(g.serverOpts, grpc.Creds(g.creds))...)
|
||||
for _, s := range g.services {
|
||||
server.RegisterService(s.desc, s.impl)
|
||||
}
|
||||
errC := make(chan error, 1)
|
||||
go func() {
|
||||
errC <- server.Serve(listener)
|
||||
}()
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
server.Stop()
|
||||
return context.Cause(ctx)
|
||||
case err := <-errC:
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
func (g *grpcUpstream) Dial(r testenv.Route, dialOpts ...grpc.DialOption) *grpc.ClientConn {
|
||||
dialOpts = append(dialOpts,
|
||||
grpc.WithTransportCredentials(credentials.NewClientTLSFromCert(g.Env().ServerCAs(), "")),
|
||||
grpc.WithDefaultCallOptions(grpc.WaitForReady(true)),
|
||||
)
|
||||
cc, err := grpc.NewClient(strings.TrimPrefix(r.URL().Value(), "https://"), dialOpts...)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return cc
|
||||
}
|
120
internal/testenv/values/value.go
Normal file
120
internal/testenv/values/value.go
Normal file
|
@ -0,0 +1,120 @@
|
|||
package values
|
||||
|
||||
import (
|
||||
"math/rand/v2"
|
||||
"sync"
|
||||
)
|
||||
|
||||
type value[T any] struct {
|
||||
f func() T
|
||||
ready bool
|
||||
cond *sync.Cond
|
||||
}
|
||||
|
||||
// A Value is a container for a single value of type T, whose initialization is
|
||||
// performed the first time Value() is called. Subsequent calls will return the
|
||||
// same value. The Value() function may block until the value is ready on the
|
||||
// first call. Values are safe to use concurrently.
|
||||
type Value[T any] interface {
|
||||
Value() T
|
||||
}
|
||||
|
||||
// MutableValue is the read-write counterpart to [Value], created by calling
|
||||
// [Deferred] for some type T. Calling Resolve() or ResolveFunc() will set
|
||||
// the value and unblock any waiting calls to Value().
|
||||
type MutableValue[T any] interface {
|
||||
Value[T]
|
||||
Resolve(value T)
|
||||
ResolveFunc(fOnce func() T)
|
||||
}
|
||||
|
||||
// Deferred creates a new read-write [MutableValue] for some type T,
|
||||
// representing a value whose initialization may be deferred to a later time.
|
||||
// Once the value is available, call [MutableValue.Resolve] or
|
||||
// [MutableValue.ResolveFunc] to unblock any waiting calls to Value().
|
||||
func Deferred[T any]() MutableValue[T] {
|
||||
return &value[T]{
|
||||
cond: sync.NewCond(&sync.Mutex{}),
|
||||
}
|
||||
}
|
||||
|
||||
// Const creates a read-only [Value] which will become available immediately
|
||||
// upon calling Value() for the first time; it will never block.
|
||||
func Const[T any](t T) Value[T] {
|
||||
return &value[T]{
|
||||
f: func() T { return t },
|
||||
ready: true,
|
||||
cond: sync.NewCond(&sync.Mutex{}),
|
||||
}
|
||||
}
|
||||
|
||||
func (p *value[T]) Value() T {
|
||||
p.cond.L.Lock()
|
||||
defer p.cond.L.Unlock()
|
||||
for !p.ready {
|
||||
p.cond.Wait()
|
||||
}
|
||||
return p.f()
|
||||
}
|
||||
|
||||
func (p *value[T]) ResolveFunc(fOnce func() T) {
|
||||
p.cond.L.Lock()
|
||||
p.f = sync.OnceValue(fOnce)
|
||||
p.ready = true
|
||||
p.cond.L.Unlock()
|
||||
p.cond.Broadcast()
|
||||
}
|
||||
|
||||
func (p *value[T]) Resolve(value T) {
|
||||
p.ResolveFunc(func() T { return value })
|
||||
}
|
||||
|
||||
// Bind creates a new [MutableValue] whose ultimate value depends on the result
|
||||
// of another [Value] that may not yet be available. When Value() is called on
|
||||
// the result, it will cascade and trigger the full chain of initialization
|
||||
// functions necessary to produce the final value.
|
||||
//
|
||||
// Care should be taken when using this function, as improper use can lead to
|
||||
// deadlocks and cause values to never become available.
|
||||
func Bind[T any, U any](dt Value[T], callback func(value T) U) MutableValue[U] {
|
||||
du := Deferred[U]()
|
||||
du.ResolveFunc(func() U {
|
||||
return callback(dt.Value())
|
||||
})
|
||||
return du
|
||||
}
|
||||
|
||||
// Bind2 is like [Bind], but can accept two input values. The result will only
|
||||
// become available once all input values become available.
|
||||
//
|
||||
// This function blocks to wait for each input value in sequence, but in a
|
||||
// random order. Do not rely on the order of evaluation of the input values.
|
||||
func Bind2[T any, U any, V any](dt Value[T], du Value[U], callback func(value1 T, value2 U) V) MutableValue[V] {
|
||||
dv := Deferred[V]()
|
||||
dv.ResolveFunc(func() V {
|
||||
if rand.IntN(2) == 0 {
|
||||
return callback(dt.Value(), du.Value())
|
||||
}
|
||||
u := du.Value()
|
||||
t := dt.Value()
|
||||
return callback(t, u)
|
||||
})
|
||||
return dv
|
||||
}
|
||||
|
||||
// List is a container for a slice of [Value] of type T, and is also a [Value]
|
||||
// itself, for convenience. The Value() function will return a []T containing
|
||||
// all resolved values for each element in the slice.
|
||||
//
|
||||
// A List's Value() function blocks to wait for each element in the slice in
|
||||
// sequence, but in a random order. Do not rely on the order of evaluation of
|
||||
// the slice elements.
|
||||
type List[T any] []Value[T]
|
||||
|
||||
func (s List[T]) Value() []T {
|
||||
values := make([]T, len(s))
|
||||
for _, i := range rand.Perm(len(values)) {
|
||||
values[i] = s[i].Value()
|
||||
}
|
||||
return values
|
||||
}
|
|
@ -71,7 +71,7 @@ func Run(ctx context.Context, src config.Source) error {
|
|||
cfg := src.GetConfig()
|
||||
|
||||
// setup the control plane
|
||||
controlPlane, err := controlplane.NewServer(cfg, metricsMgr, eventsMgr)
|
||||
controlPlane, err := controlplane.NewServer(ctx, cfg, metricsMgr, eventsMgr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error creating control plane: %w", err)
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue