diff --git a/.golangci.yml b/.golangci.yml index c4df769ad..7e6ef272d 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -76,3 +76,6 @@ issues: - text: "G112:" linters: - gosec + - text: "G402: TLS MinVersion too low." + linters: + - gosec diff --git a/config/envoyconfig/bootstrap.go b/config/envoyconfig/bootstrap.go index 370c54a25..d1dbbd8f9 100644 --- a/config/envoyconfig/bootstrap.go +++ b/config/envoyconfig/bootstrap.go @@ -26,9 +26,9 @@ import ( const maxActiveDownstreamConnections = 50000 var ( - envoyAdminAddressPath = filepath.Join(os.TempDir(), "pomerium-envoy-admin.sock") - envoyAdminAddressMode = 0o600 - envoyAdminClusterName = "pomerium-envoy-admin" + envoyAdminAddressSockName = "pomerium-envoy-admin.sock" + envoyAdminAddressMode = 0o600 + envoyAdminClusterName = "pomerium-envoy-admin" ) // BuildBootstrap builds the bootstrap config. @@ -95,7 +95,7 @@ func (b *Builder) BuildBootstrapAdmin(cfg *config.Config) (admin *envoy_config_b admin.Address = &envoy_config_core_v3.Address{ Address: &envoy_config_core_v3.Address_Pipe{ Pipe: &envoy_config_core_v3.Pipe{ - Path: envoyAdminAddressPath, + Path: filepath.Join(os.TempDir(), envoyAdminAddressSockName), Mode: uint32(envoyAdminAddressMode), }, }, diff --git a/config/envoyconfig/bootstrap_test.go b/config/envoyconfig/bootstrap_test.go index ef8cfcc4d..bdceb7945 100644 --- a/config/envoyconfig/bootstrap_test.go +++ b/config/envoyconfig/bootstrap_test.go @@ -12,6 +12,7 @@ import ( ) func TestBuilder_BuildBootstrapAdmin(t *testing.T) { + t.Setenv("TMPDIR", "/tmp") b := New("local-grpc", "local-http", "local-metrics", filemgr.NewManager(), nil) t.Run("valid", func(t *testing.T) { adminCfg, err := b.BuildBootstrapAdmin(&config.Config{ @@ -25,7 +26,7 @@ func TestBuilder_BuildBootstrapAdmin(t *testing.T) { "address": { "pipe": { "mode": 384, - "path": "`+envoyAdminAddressPath+`" + "path": "/tmp/`+envoyAdminAddressSockName+`" } } } diff --git a/config/envoyconfig/clusters_envoy_admin.go 
b/config/envoyconfig/clusters_envoy_admin.go index 4cde20f92..e0144c1e1 100644 --- a/config/envoyconfig/clusters_envoy_admin.go +++ b/config/envoyconfig/clusters_envoy_admin.go @@ -2,6 +2,8 @@ package envoyconfig import ( "context" + "os" + "path/filepath" envoy_config_cluster_v3 "github.com/envoyproxy/go-control-plane/envoy/config/cluster/v3" envoy_config_core_v3 "github.com/envoyproxy/go-control-plane/envoy/config/core/v3" @@ -23,7 +25,8 @@ func (b *Builder) buildEnvoyAdminCluster(_ context.Context, _ *config.Config) (* Address: &envoy_config_core_v3.Address{ Address: &envoy_config_core_v3.Address_Pipe{ Pipe: &envoy_config_core_v3.Pipe{ - Path: envoyAdminAddressPath, + Path: filepath.Join(os.TempDir(), envoyAdminAddressSockName), + Mode: uint32(envoyAdminAddressMode), }, }, }, diff --git a/config/envoyconfig/clusters_test.go b/config/envoyconfig/clusters_test.go index 1dcbbf7df..a4d9e309e 100644 --- a/config/envoyconfig/clusters_test.go +++ b/config/envoyconfig/clusters_test.go @@ -23,11 +23,7 @@ import ( func Test_BuildClusters(t *testing.T) { // The admin address path is based on os.TempDir(), which will vary from // system to system, so replace this with a stable location. 
- originalEnvoyAdminAddressPath := envoyAdminAddressPath - envoyAdminAddressPath = "/tmp/pomerium-envoy-admin.sock" - t.Cleanup(func() { - envoyAdminAddressPath = originalEnvoyAdminAddressPath - }) + t.Setenv("TMPDIR", "/tmp") opts := config.NewDefaultOptions() ctx := context.Background() diff --git a/config/envoyconfig/protocols_int_test.go b/config/envoyconfig/protocols_int_test.go index 19f2833a6..ef1e82785 100644 --- a/config/envoyconfig/protocols_int_test.go +++ b/config/envoyconfig/protocols_int_test.go @@ -1,102 +1,159 @@ package envoyconfig_test import ( - "context" "fmt" - "net" + "io" + "net/http" "testing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "google.golang.org/grpc" - "google.golang.org/grpc/codes" - "google.golang.org/grpc/credentials" "google.golang.org/grpc/credentials/insecure" "google.golang.org/grpc/interop" "google.golang.org/grpc/interop/grpc_testing" - "google.golang.org/grpc/metadata" - "google.golang.org/grpc/status" "github.com/pomerium/pomerium/config" - "github.com/pomerium/pomerium/pkg/cmd/pomerium" - "github.com/pomerium/pomerium/pkg/netutil" + "github.com/pomerium/pomerium/internal/testenv" + "github.com/pomerium/pomerium/internal/testenv/scenarios" + "github.com/pomerium/pomerium/internal/testenv/snippets" + "github.com/pomerium/pomerium/internal/testenv/upstreams" + "github.com/pomerium/pomerium/internal/testenv/values" ) func TestH2C(t *testing.T) { - if testing.Short() { - t.SkipNow() - } + env := testenv.New(t) - ctx, ca := context.WithCancel(context.Background()) + up := upstreams.GRPC(insecure.NewCredentials()) + grpc_testing.RegisterTestServiceServer(up, interop.NewTestServer()) - opts := config.NewDefaultOptions() - listener, err := (&net.ListenConfig{}).Listen(ctx, "tcp", "127.0.0.1:0") - require.NoError(t, err) - ports, err := netutil.AllocatePorts(7) - require.NoError(t, err) - urls, err := config.ParseWeightedUrls("http://"+listener.Addr().String(), "h2c://"+listener.Addr().String()) 
- require.NoError(t, err) - opts.Addr = fmt.Sprintf("127.0.0.1:%s", ports[0]) - opts.Routes = []config.Policy{ - { - From: fmt.Sprintf("https://grpc-http.localhost.pomerium.io:%s", ports[0]), - To: urls[:1], - AllowPublicUnauthenticatedAccess: true, - }, - { - From: fmt.Sprintf("https://grpc-h2c.localhost.pomerium.io:%s", ports[0]), - To: urls[1:], - AllowPublicUnauthenticatedAccess: true, - }, - } - opts.CertFile = "../../integration/tpl/files/trusted.pem" - opts.KeyFile = "../../integration/tpl/files/trusted-key.pem" - cfg := &config.Config{Options: opts} - cfg.AllocatePorts(*(*[6]string)(ports[1:])) + httpRoute := up.Route(). + From(env.SubdomainURL("grpc-http")). + To(values.Bind(up.Port(), func(port int) string { + // override the target protocol to use http:// + return fmt.Sprintf("http://127.0.0.1:%d", port) + })). + Policy(func(p *config.Policy) { p.AllowPublicUnauthenticatedAccess = true }) - server := grpc.NewServer(grpc.Creds(insecure.NewCredentials())) - grpc_testing.RegisterTestServiceServer(server, interop.NewTestServer()) - go server.Serve(listener) + h2c := up.Route(). + From(env.SubdomainURL("grpc-h2c")). 
+ Policy(func(p *config.Policy) { p.AllowPublicUnauthenticatedAccess = true }) - errC := make(chan error, 1) - go func() { - errC <- pomerium.Run(ctx, config.NewStaticSource(cfg)) - }() - - t.Cleanup(func() { - ca() - assert.ErrorIs(t, context.Canceled, <-errC) - }) - - tlsConfig, err := credentials.NewClientTLSFromFile("../../integration/tpl/files/ca.pem", "") - require.NoError(t, err) + env.AddUpstream(up) + env.Start() + snippets.WaitStartupComplete(env) t.Run("h2c", func(t *testing.T) { t.Parallel() + recorder := env.NewLogRecorder() - cc, err := grpc.Dial(fmt.Sprintf("grpc-h2c.localhost.pomerium.io:%s", ports[0]), grpc.WithTransportCredentials(tlsConfig)) - require.NoError(t, err) + cc := up.Dial(h2c) client := grpc_testing.NewTestServiceClient(cc) - var md metadata.MD - _, err = client.EmptyCall(ctx, &grpc_testing.Empty{}, grpc.WaitForReady(true), grpc.Header(&md)) + _, err := client.EmptyCall(env.Context(), &grpc_testing.Empty{}) + require.NoError(t, err) cc.Close() - assert.NoError(t, err) - assert.Contains(t, md, "x-envoy-upstream-service-time") + + recorder.Match([]map[string]any{ + { + "service": "envoy", + "path": "/grpc.testing.TestService/EmptyCall", + "message": "http-request", + "response-code-details": "via_upstream", + }, + }) }) t.Run("http", func(t *testing.T) { t.Parallel() + recorder := env.NewLogRecorder() - cc, err := grpc.Dial(fmt.Sprintf("grpc-http.localhost.pomerium.io:%s", ports[0]), grpc.WithTransportCredentials(tlsConfig)) - require.NoError(t, err) + cc := up.Dial(httpRoute) client := grpc_testing.NewTestServiceClient(cc) - var md metadata.MD - _, err = client.EmptyCall(ctx, &grpc_testing.Empty{}, grpc.WaitForReady(true), grpc.Trailer(&md)) + _, err := client.UnaryCall(env.Context(), &grpc_testing.SimpleRequest{}) + require.Error(t, err) cc.Close() - stat := status.Convert(err) - assert.NotNil(t, stat) - assert.Equal(t, stat.Code(), codes.Unavailable) - assert.NotContains(t, md, "x-envoy-upstream-service-time") - assert.Contains(t, 
stat.Message(), "") - assert.Contains(t, stat.Message(), "upstream_reset_before_response_started{protocol_error}") + + recorder.Match([]map[string]any{ + { + "service": "envoy", + "path": "/grpc.testing.TestService/UnaryCall", + "message": "http-request", + "response-code-details": "upstream_reset_before_response_started{protocol_error}", + }, + }) + }) +} + +func TestHTTP(t *testing.T) { + env := testenv.New(t) + + up := upstreams.HTTP(nil) + up.Handle("/foo", func(w http.ResponseWriter, _ *http.Request) { + fmt.Fprintln(w, "hello world") + }) + + route := up.Route(). + From(env.SubdomainURL("http")). + Policy(func(p *config.Policy) { p.AllowPublicUnauthenticatedAccess = true }) + + env.AddUpstream(up) + env.Start() + + recorder := env.NewLogRecorder() + + resp, err := up.Get(route, upstreams.Path("/foo")) + require.NoError(t, err) + + defer resp.Body.Close() + data, err := io.ReadAll(resp.Body) + require.NoError(t, err) + assert.Equal(t, "hello world\n", string(data)) + + recorder.Match([]map[string]any{ + { + "service": "envoy", + "path": "/foo", + "method": "GET", + "message": "http-request", + "response-code-details": "via_upstream", + }, + }) +} + +func TestClientCert(t *testing.T) { + env := testenv.New(t) + env.Add(scenarios.DownstreamMTLS(config.MTLSEnforcementRejectConnection)) + + up := upstreams.HTTP(nil) + up.Handle("/foo", func(w http.ResponseWriter, _ *http.Request) { + fmt.Fprintln(w, "hello world") + }) + + clientCert := env.NewClientCert() + + route := up.Route(). + From(env.SubdomainURL("http")). 
+ PPL(fmt.Sprintf(`{"allow":{"and":[{"client_certificate":{"fingerprint":%q}}]}}`, clientCert.Fingerprint())) + + env.AddUpstream(up) + env.Start() + + recorder := env.NewLogRecorder() + + resp, err := up.Get(route, upstreams.Path("/foo"), upstreams.ClientCert(clientCert)) + require.NoError(t, err) + + defer resp.Body.Close() + data, err := io.ReadAll(resp.Body) + require.NoError(t, err) + assert.Equal(t, "hello world\n", string(data)) + + recorder.Match([]map[string]any{ + { + "service": "envoy", + "path": "/foo", + "method": "GET", + "message": "http-request", + "response-code-details": "via_upstream", + "client-certificate": clientCert, + }, }) } diff --git a/config/envoyconfig/testdata/clusters.json b/config/envoyconfig/testdata/clusters.json index 4e8a159f6..e42f5b5ee 100644 --- a/config/envoyconfig/testdata/clusters.json +++ b/config/envoyconfig/testdata/clusters.json @@ -280,6 +280,7 @@ "endpoint": { "address": { "pipe": { + "mode": 384, "path": "/tmp/pomerium-envoy-admin.sock" } } diff --git a/internal/autocert/manager.go b/internal/autocert/manager.go index 429f2173e..8cd21e7da 100644 --- a/internal/autocert/manager.go +++ b/internal/autocert/manager.go @@ -129,6 +129,7 @@ func newManager( for { select { case <-ctx.Done(): + cache.Stop() return case <-ticker.C: err := mgr.renewConfigCerts(ctx) diff --git a/internal/benchmarks/config_bench_test.go b/internal/benchmarks/config_bench_test.go new file mode 100644 index 000000000..b554c2e42 --- /dev/null +++ b/internal/benchmarks/config_bench_test.go @@ -0,0 +1,54 @@ +package benchmarks_test + +import ( + "fmt" + "testing" + "time" + + "github.com/pomerium/pomerium/internal/testenv" + "github.com/pomerium/pomerium/internal/testenv/snippets" + "github.com/pomerium/pomerium/internal/testenv/upstreams" +) + +func BenchmarkStartupLatency(b *testing.B) { + for _, n := range []int{1, 10, 100, 1000, 10000} { + b.Run(fmt.Sprintf("routes=%d", n), func(b *testing.B) { + for range b.N { + env := testenv.New(b) + up := 
upstreams.HTTP(nil) + for i := range n { + up.Route(). + From(env.SubdomainURL(fmt.Sprintf("from-%d", i))). + PPL(`{"allow":{"and":[{"accept":"true"}]}}`) + } + env.AddUpstream(up) + + env.Start() + snippets.WaitStartupComplete(env, 60*time.Minute) + + env.Stop() + } + }) + } +} + +func BenchmarkAppendRoutes(b *testing.B) { + for _, n := range []int{1, 10, 100, 1000, 10000} { + b.Run(fmt.Sprintf("routes=%d", n), func(b *testing.B) { + for range b.N { + env := testenv.New(b) + up := upstreams.HTTP(nil) + env.AddUpstream(up) + + env.Start() + snippets.WaitStartupComplete(env) + for i := range n { + env.Add(up.Route(). + From(env.SubdomainURL(fmt.Sprintf("from-%d", i))). + PPL(fmt.Sprintf(`{"allow":{"and":[{"email":{"is":"user-%d@example.com"}}]}}`, i))) + } + env.Stop() + } + }) + } +} diff --git a/internal/benchmarks/latency_bench_test.go b/internal/benchmarks/latency_bench_test.go new file mode 100644 index 000000000..f936d7d8d --- /dev/null +++ b/internal/benchmarks/latency_bench_test.go @@ -0,0 +1,87 @@ +package benchmarks_test + +import ( + "flag" + "fmt" + "io" + "math/rand/v2" + "net/http" + "testing" + + "github.com/pomerium/pomerium/internal/testenv" + "github.com/pomerium/pomerium/internal/testenv/scenarios" + "github.com/pomerium/pomerium/internal/testenv/snippets" + "github.com/pomerium/pomerium/internal/testenv/upstreams" + "github.com/stretchr/testify/assert" +) + +var ( + numRoutes int + dumpErrLogs bool +) + +func init() { + flag.IntVar(&numRoutes, "routes", 100, "number of routes") + flag.BoolVar(&dumpErrLogs, "dump-err-logs", false, "if the test fails, write all captured logs to a file (testdata/)") +} + +func TestRequestLatency(t *testing.T) { + env := testenv.New(t, testenv.Silent()) + users := make([]*scenarios.User, 0, numRoutes) + for i := range numRoutes { + users = append(users, &scenarios.User{ + Email: fmt.Sprintf("user%d@example.com", i), + FirstName: fmt.Sprintf("Firstname%d", i), + LastName: fmt.Sprintf("Lastname%d", i), + }) + } + 
env.Add(scenarios.NewIDP(users)) + + up := upstreams.HTTP(nil) + up.Handle("/", func(w http.ResponseWriter, _ *http.Request) { + w.Write([]byte("OK")) + }) + routes := make([]testenv.Route, numRoutes) + for i := range numRoutes { + routes[i] = up.Route(). + From(env.SubdomainURL(fmt.Sprintf("from-%d", i))). + PPL(fmt.Sprintf(`{"allow":{"and":[{"email":{"is":"user%d@example.com"}}]}}`, i)) + } + env.AddUpstream(up) + + env.Start() + snippets.WaitStartupComplete(env) + + out := testing.Benchmark(func(b *testing.B) { + b.ReportAllocs() + b.RunParallel(func(pb *testing.PB) { + var rec *testenv.LogRecorder + if dumpErrLogs { + rec = env.NewLogRecorder(testenv.WithSkipCloseDelay()) + } + for pb.Next() { + idx := rand.IntN(numRoutes) + resp, err := up.Get(routes[idx], upstreams.AuthenticateAs(fmt.Sprintf("user%d@example.com", idx))) + if !assert.NoError(b, err) { + filename := "TestRequestLatency_err.log" + if dumpErrLogs { + rec.DumpToFile(filename) + b.Logf("test logs written to %s", filename) + } + return + } + + assert.Equal(b, http.StatusOK, resp.StatusCode) + body, err := io.ReadAll(resp.Body) + resp.Body.Close() + assert.NoError(b, err) + assert.Equal(b, "OK", string(body)) + } + }) + }) + + t.Log(out) + t.Logf("req/s: %f", float64(out.N)/out.T.Seconds()) + + env.Stop() +} diff --git a/internal/controlplane/server.go b/internal/controlplane/server.go index cfd74effc..587d19318 100644 --- a/internal/controlplane/server.go +++ b/internal/controlplane/server.go @@ -74,11 +74,12 @@ func NewServer( cfg *config.Config, metricsMgr *config.MetricsManager, eventsMgr *events.Manager, + fileMgr *filemgr.Manager, ) (*Server, error) { srv := &Server{ metricsMgr: metricsMgr, EventsMgr: eventsMgr, - filemgr: filemgr.NewManager(), + filemgr: fileMgr, reproxy: reproxy.New(), haveSetCapacity: map[string]bool{}, updateConfig: make(chan *config.Config, 1), diff --git a/internal/controlplane/server_test.go b/internal/controlplane/server_test.go index f7ae90155..8f85b4cd4 100644 --- 
a/internal/controlplane/server_test.go +++ b/internal/controlplane/server_test.go @@ -12,6 +12,7 @@ import ( "github.com/stretchr/testify/require" "github.com/pomerium/pomerium/config" + "github.com/pomerium/pomerium/config/envoyconfig/filemgr" "github.com/pomerium/pomerium/internal/events" "github.com/pomerium/pomerium/pkg/netutil" ) @@ -38,7 +39,7 @@ func TestServerHTTP(t *testing.T) { cfg.Options.SharedKey = "JDNjY2ITDlARvNaQXjc2Djk+GA6xeCy4KiozmZfdbTs=" src := config.NewStaticSource(cfg) - srv, err := NewServer(ctx, cfg, config.NewMetricsManager(ctx, src), events.New()) + srv, err := NewServer(ctx, cfg, config.NewMetricsManager(ctx, src), events.New(), filemgr.NewManager(filemgr.WithCacheDir(t.TempDir()))) require.NoError(t, err) go srv.Run(ctx) diff --git a/internal/log/debug.go b/internal/log/debug.go index 1d4f38bb7..0afa5bc56 100644 --- a/internal/log/debug.go +++ b/internal/log/debug.go @@ -7,4 +7,6 @@ var ( DebugDisableZapLogger atomic.Bool // Debug option to suppress global warnings DebugDisableGlobalWarnings atomic.Bool + // Debug option to suppress global (non-warning) messages + DebugDisableGlobalMessages atomic.Bool ) diff --git a/internal/testenv/environment.go b/internal/testenv/environment.go new file mode 100644 index 000000000..2147d1912 --- /dev/null +++ b/internal/testenv/environment.go @@ -0,0 +1,838 @@ +package testenv + +import ( + "bytes" + "context" + "crypto/rand" + "crypto/rsa" + "crypto/sha256" + "crypto/tls" + "crypto/x509" + "crypto/x509/pkix" + "encoding/asn1" + "encoding/base64" + "encoding/hex" + "errors" + "fmt" + "io" + "math/big" + "math/bits" + "net" + "net/url" + "os" + "os/signal" + "path" + "path/filepath" + "runtime" + "strconv" + "sync" + "syscall" + "testing" + "time" + + "github.com/pomerium/pomerium/config" + "github.com/pomerium/pomerium/config/envoyconfig/filemgr" + "github.com/pomerium/pomerium/internal/log" + "github.com/pomerium/pomerium/internal/testenv/values" + "github.com/pomerium/pomerium/pkg/cmd/pomerium" + 
"github.com/pomerium/pomerium/pkg/grpc/databroker" + "github.com/pomerium/pomerium/pkg/health" + "github.com/pomerium/pomerium/pkg/netutil" + "github.com/pomerium/pomerium/pkg/slices" + "github.com/rs/zerolog" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "golang.org/x/sync/errgroup" + "google.golang.org/grpc/grpclog" +) + +// Environment is a lightweight integration test fixture that runs Pomerium +// in-process. +type Environment interface { + // Context returns the environment's root context. This context holds a + // top-level logger scoped to this environment. It will be canceled when + // Stop() is called, or during test cleanup. + Context() context.Context + + Assert() *assert.Assertions + Require() *require.Assertions + + // TempDir returns a unique temp directory for this context. Calling this + // function multiple times returns the same path. + TempDir() string + // CACert returns the test environment's root CA certificate and private key. + CACert() *tls.Certificate + // ServerCAs returns a new [*x509.CertPool] containing the root CA certificate + // used to sign the server cert and other test certificates. + ServerCAs() *x509.CertPool + // ServerCert returns the Pomerium server's certificate and private key. + ServerCert() *tls.Certificate + // NewClientCert generates a new client certificate signed by the root CA + // certificate. One or more optional templates can be given, which can be + // used to set or override certain parameters when creating a certificate, + // including subject, SANs, or extensions. If more than one template is + // provided, they will be applied in order from left to right. + // + // By default (unless overridden in a template), the certificate will have + // its Common Name set to the file:line string of the call site. Calls to + // NewClientCert() on different lines will have different subjects. 
If + // multiple certs with the same subject are needed, wrap the call to this + // function in another helper function, or separate calls with commas on the + // same line. + NewClientCert(templateOverrides ...*x509.Certificate) *Certificate + + NewServerCert(templateOverrides ...*x509.Certificate) *Certificate + + AuthenticateURL() values.Value[string] + DatabrokerURL() values.Value[string] + Ports() Ports + SharedSecret() []byte + CookieSecret() []byte + + // Add adds the given [Modifier] to the environment. All modifiers will be + // invoked upon calling Start() to apply individual modifications to the + // configuration before starting the Pomerium server. + Add(m Modifier) + // AddTask adds the given [Task] to the environment. All tasks will be + // started in separate goroutines upon calling Start(). If any tasks exit + // with an error, the environment will be stopped and the test will fail. + AddTask(r Task) + // AddUpstream adds the given [Upstream] to the environment. This function is + // equivalent to calling both Add() and AddTask() with the upstream, but + // improves readability. + AddUpstream(u Upstream) + + // Start starts the test environment, and adds a call to Stop() as a cleanup + // hook to the environment's [testing.T]. All previously added [Modifier] + // instances are invoked in order to build the configuration, and all + // previously added [Task] instances are started in the background. + // + // Calling Start() more than once, calling Start() after Stop(), or calling + // any of the Add* functions after Start() will panic. + Start() + // Stop stops the test environment. Calling this function more than once has + // no effect. It is usually not necessary to call Stop() directly unless you + // need to stop the test environment before the test is completed. 
+ Stop() + + // SubdomainURL returns a string [values.Value] which will contain a complete + // URL for the given subdomain of the server's domain (given by its serving + // certificate), including the 'https://' scheme and random http server port. + // This value will only be resolved some time after Start() is called, and + // can be used as the 'from' value for routes. + SubdomainURL(subdomain string) values.Value[string] + + // NewLogRecorder returns a new [*LogRecorder] and starts capturing logs for + // the Pomerium server and Envoy. + NewLogRecorder(opts ...LogRecorderOption) *LogRecorder + + // OnStateChanged registers a callback to be invoked when the environment's + // state changes to the given state. The callback is invoked in a separate + // goroutine. + OnStateChanged(state EnvironmentState, callback func()) +} + +type Certificate tls.Certificate + +func (c *Certificate) Fingerprint() string { + sum := sha256.Sum256(c.Leaf.Raw) + return hex.EncodeToString(sum[:]) +} + +func (c *Certificate) SPKIHash() string { + sum := sha256.Sum256(c.Leaf.RawSubjectPublicKeyInfo) + return base64.StdEncoding.EncodeToString(sum[:]) +} + +type EnvironmentState uint32 + +const NotRunning EnvironmentState = 0 + +const ( + Starting EnvironmentState = 1 << iota + Running + Stopping + Stopped +) + +func (e EnvironmentState) String() string { + switch e { + case NotRunning: + return "NotRunning" + case Starting: + return "Starting" + case Running: + return "Running" + case Stopping: + return "Stopping" + case Stopped: + return "Stopped" + default: + return fmt.Sprintf("EnvironmentState(%d)", e) + } +} + +type environment struct { + EnvironmentOptions + t testing.TB + assert *assert.Assertions + require *require.Assertions + tempDir string + domain string + ports Ports + sharedSecret [32]byte + cookieSecret [32]byte + workspaceFolder string + silent bool + + ctx context.Context + cancel context.CancelCauseFunc + cleanupOnce sync.Once + logWriter *log.MultiWriter + + mods 
[]WithCaller[Modifier] + tasks []WithCaller[Task] + taskErrGroup *errgroup.Group + + stateMu sync.Mutex + state EnvironmentState + stateChangeListeners map[EnvironmentState][]func() + + src *configSource +} + +type EnvironmentOptions struct { + debug bool + pauseOnFailure bool + forceSilent bool +} + +type EnvironmentOption func(*EnvironmentOptions) + +func (o *EnvironmentOptions) apply(opts ...EnvironmentOption) { + for _, op := range opts { + op(o) + } +} + +func Debug(enable ...bool) EnvironmentOption { + if len(enable) == 0 { + enable = append(enable, true) + } + return func(o *EnvironmentOptions) { + o.debug = enable[0] + } +} + +func PauseOnFailure(enable ...bool) EnvironmentOption { + if len(enable) == 0 { + enable = append(enable, true) + } + return func(o *EnvironmentOptions) { + o.pauseOnFailure = enable[0] + } +} + +func Silent(silent ...bool) EnvironmentOption { + if len(silent) == 0 { + silent = append(silent, true) + } + return func(o *EnvironmentOptions) { + o.forceSilent = silent[0] + } +} + +var setGrpcLoggerOnce sync.Once + +func New(t testing.TB, opts ...EnvironmentOption) Environment { + if runtime.GOOS != "linux" { + t.Skip("test environment only supported on linux") + } + options := EnvironmentOptions{} + options.apply(opts...) 
+ if testing.Short() { + t.Helper() + t.Skip("test environment disabled in short mode") + } + databroker.DebugUseFasterBackoff.Store(true) + workspaceFolder, err := os.Getwd() + require.NoError(t, err) + for { + if _, err := os.Stat(filepath.Join(workspaceFolder, ".git")); err == nil { + break + } + workspaceFolder = filepath.Dir(workspaceFolder) + if workspaceFolder == "/" { + panic("could not find workspace root") + } + } + workspaceFolder, err = filepath.Abs(workspaceFolder) + require.NoError(t, err) + + writer := log.NewMultiWriter() + silent := options.forceSilent || isSilent(t) + if silent { + // this sets the global zap level to fatal, then resets the global zerolog + // level to debug + log.SetLevel(zerolog.FatalLevel) + zerolog.SetGlobalLevel(zerolog.DebugLevel) + } else { + log.SetLevel(zerolog.InfoLevel) + writer.Add(os.Stdout) + } + log.DebugDisableGlobalWarnings.Store(silent) + log.DebugDisableGlobalMessages.Store(silent) + log.DebugDisableZapLogger.Store(silent) + setGrpcLoggerOnce.Do(func() { + grpclog.SetLoggerV2(grpclog.NewLoggerV2WithVerbosity(io.Discard, io.Discard, io.Discard, 0)) + }) + logger := zerolog.New(writer).With().Timestamp().Logger().Level(zerolog.DebugLevel) + + ctx, cancel := context.WithCancelCause(logger.WithContext(context.Background())) + taskErrGroup, ctx := errgroup.WithContext(ctx) + + e := &environment{ + EnvironmentOptions: options, + t: t, + assert: assert.New(t), + require: require.New(t), + tempDir: t.TempDir(), + ports: Ports{ + ProxyHTTP: values.Deferred[int](), + ProxyGRPC: values.Deferred[int](), + GRPC: values.Deferred[int](), + HTTP: values.Deferred[int](), + Outbound: values.Deferred[int](), + Metrics: values.Deferred[int](), + Debug: values.Deferred[int](), + ALPN: values.Deferred[int](), + }, + workspaceFolder: workspaceFolder, + silent: silent, + ctx: ctx, + cancel: cancel, + logWriter: writer, + taskErrGroup: taskErrGroup, + } + _, err = rand.Read(e.sharedSecret[:]) + require.NoError(t, err) + _, err = 
rand.Read(e.cookieSecret[:]) + require.NoError(t, err) + + health.SetProvider(e) + + require.NoError(t, os.Mkdir(filepath.Join(e.tempDir, "certs"), 0o777)) + copyFile := func(src, dstRel string) { + data, err := os.ReadFile(src) + require.NoError(t, err) + require.NoError(t, os.WriteFile(filepath.Join(e.tempDir, dstRel), data, 0o600)) + } + + certsToCopy := []string{ + "trusted.pem", + "trusted-key.pem", + "ca.pem", + "ca-key.pem", + } + for _, crt := range certsToCopy { + copyFile(filepath.Join(workspaceFolder, "integration/tpl/files", crt), filepath.Join("certs/", filepath.Base(crt))) + } + e.domain = wildcardDomain(e.ServerCert().Leaf.DNSNames) + + return e +} + +func (e *environment) debugf(format string, args ...any) { + if !e.debug { + return + } + + e.t.Logf("\x1b[34m[debug] "+format+"\x1b[0m", args...) +} + +type WithCaller[T any] struct { + Caller string + Value T +} + +type Ports struct { + ProxyHTTP values.MutableValue[int] + ProxyGRPC values.MutableValue[int] + GRPC values.MutableValue[int] + HTTP values.MutableValue[int] + Outbound values.MutableValue[int] + Metrics values.MutableValue[int] + Debug values.MutableValue[int] + ALPN values.MutableValue[int] +} + +func (e *environment) TempDir() string { + return e.tempDir +} + +func (e *environment) Context() context.Context { + return ContextWithEnv(e.ctx, e) +} + +func (e *environment) Assert() *assert.Assertions { + return e.assert +} + +func (e *environment) Require() *require.Assertions { + return e.require +} + +func (e *environment) SubdomainURL(subdomain string) values.Value[string] { + return values.Bind(e.ports.ProxyHTTP, func(port int) string { + return fmt.Sprintf("https://%s.%s:%d", subdomain, e.domain, port) + }) +} + +func (e *environment) AuthenticateURL() values.Value[string] { + return e.SubdomainURL("authenticate") +} + +func (e *environment) DatabrokerURL() values.Value[string] { + return values.Bind(e.ports.Outbound, func(port int) string { + return fmt.Sprintf("127.0.0.1:%d", port) + 
}) +} + +func (e *environment) Ports() Ports { + return e.ports +} + +func (e *environment) CACert() *tls.Certificate { + caCert, err := tls.LoadX509KeyPair( + filepath.Join(e.tempDir, "certs", "ca.pem"), + filepath.Join(e.tempDir, "certs", "ca-key.pem"), + ) + require.NoError(e.t, err) + return &caCert +} + +func (e *environment) ServerCAs() *x509.CertPool { + pool := x509.NewCertPool() + caCert, err := os.ReadFile(filepath.Join(e.tempDir, "certs", "ca.pem")) + require.NoError(e.t, err) + pool.AppendCertsFromPEM(caCert) + return pool +} + +func (e *environment) ServerCert() *tls.Certificate { + serverCert, err := tls.LoadX509KeyPair( + filepath.Join(e.tempDir, "certs", "trusted.pem"), + filepath.Join(e.tempDir, "certs", "trusted-key.pem"), + ) + require.NoError(e.t, err) + return &serverCert +} + +// Used as the context's cancel cause during normal cleanup +var ErrCauseTestCleanup = errors.New("test cleanup") + +// Used as the context's cancel cause when Stop() is called +var ErrCauseManualStop = errors.New("Stop() called") + +func (e *environment) Start() { + e.debugf("Start()") + e.advanceState(Starting) + e.t.Cleanup(e.cleanup) + e.t.Setenv("TMPDIR", e.TempDir()) + e.debugf("temp dir: %s", e.TempDir()) + + cfg := &config.Config{ + Options: config.NewDefaultOptions(), + } + ports, err := netutil.AllocatePorts(8) + require.NoError(e.t, err) + atoi := func(str string) int { + p, err := strconv.Atoi(str) + if err != nil { + panic(err) + } + return p + } + e.ports.ProxyHTTP.Resolve(atoi(ports[0])) + e.ports.ProxyGRPC.Resolve(atoi(ports[1])) + e.ports.GRPC.Resolve(atoi(ports[2])) + e.ports.HTTP.Resolve(atoi(ports[3])) + e.ports.Outbound.Resolve(atoi(ports[4])) + e.ports.Metrics.Resolve(atoi(ports[5])) + e.ports.Debug.Resolve(atoi(ports[6])) + e.ports.ALPN.Resolve(atoi(ports[7])) + cfg.AllocatePorts(*(*[6]string)(ports[2:])) + + cfg.Options.AutocertOptions = config.AutocertOptions{Enable: false} + cfg.Options.Services = "all" + cfg.Options.LogLevel = 
config.LogLevelDebug + cfg.Options.ProxyLogLevel = config.LogLevelInfo + cfg.Options.Addr = fmt.Sprintf("127.0.0.1:%d", e.ports.ProxyHTTP.Value()) + cfg.Options.GRPCAddr = fmt.Sprintf("127.0.0.1:%d", e.ports.ProxyGRPC.Value()) + cfg.Options.CAFile = filepath.Join(e.tempDir, "certs", "ca.pem") + cfg.Options.CertFile = filepath.Join(e.tempDir, "certs", "trusted.pem") + cfg.Options.KeyFile = filepath.Join(e.tempDir, "certs", "trusted-key.pem") + cfg.Options.AuthenticateURLString = e.AuthenticateURL().Value() + cfg.Options.DataBrokerStorageType = "memory" + cfg.Options.SharedKey = base64.StdEncoding.EncodeToString(e.sharedSecret[:]) + cfg.Options.CookieSecret = base64.StdEncoding.EncodeToString(e.cookieSecret[:]) + cfg.Options.AccessLogFields = []log.AccessLogField{ + log.AccessLogFieldAuthority, + log.AccessLogFieldDuration, + log.AccessLogFieldForwardedFor, + log.AccessLogFieldIP, + log.AccessLogFieldMethod, + log.AccessLogFieldPath, + log.AccessLogFieldQuery, + log.AccessLogFieldReferer, + log.AccessLogFieldRequestID, + log.AccessLogFieldResponseCode, + log.AccessLogFieldResponseCodeDetails, + log.AccessLogFieldSize, + log.AccessLogFieldUpstreamCluster, + log.AccessLogFieldUserAgent, + log.AccessLogFieldClientCertificate, + } + + e.src = &configSource{cfg: cfg} + e.AddTask(TaskFunc(func(ctx context.Context) error { + fileMgr := filemgr.NewManager(filemgr.WithCacheDir(filepath.Join(e.TempDir(), "cache"))) + for _, mod := range e.mods { + mod.Value.Modify(cfg) + require.NoError(e.t, cfg.Options.Validate(), "invoking modifier resulted in an invalid configuration:\nadded by: "+mod.Caller) + } + return pomerium.Run(ctx, e.src, pomerium.WithOverrideFileManager(fileMgr)) + })) + + for i, task := range e.tasks { + log.Ctx(e.ctx).Debug().Str("caller", task.Caller).Msgf("starting task %d", i) + e.taskErrGroup.Go(func() error { + defer log.Ctx(e.ctx).Debug().Str("caller", task.Caller).Msgf("task %d exited", i) + return task.Value.Run(e.ctx) + }) + } + + runtime.Gosched() + + 
e.advanceState(Running) +} + +func (e *environment) NewClientCert(templateOverrides ...*x509.Certificate) *Certificate { + caCert := e.CACert() + + priv, err := rsa.GenerateKey(rand.Reader, 2048) + require.NoError(e.t, err) + + sn, err := rand.Int(rand.Reader, new(big.Int).Lsh(big.NewInt(1), 128)) + require.NoError(e.t, err) + now := time.Now() + tmpl := &x509.Certificate{ + SerialNumber: sn, + Subject: pkix.Name{ + CommonName: getCaller(), + }, + NotBefore: now, + NotAfter: now.Add(12 * time.Hour), + KeyUsage: x509.KeyUsageDigitalSignature, + ExtKeyUsage: []x509.ExtKeyUsage{ + x509.ExtKeyUsageClientAuth, + }, + BasicConstraintsValid: true, + } + for _, override := range templateOverrides { + tmpl.CRLDistributionPoints = slices.Unique(append(tmpl.CRLDistributionPoints, override.CRLDistributionPoints...)) + tmpl.DNSNames = slices.Unique(append(tmpl.DNSNames, override.DNSNames...)) + tmpl.EmailAddresses = slices.Unique(append(tmpl.EmailAddresses, override.EmailAddresses...)) + tmpl.ExtraExtensions = append(tmpl.ExtraExtensions, override.ExtraExtensions...) 
+ tmpl.IPAddresses = slices.UniqueBy(append(tmpl.IPAddresses, override.IPAddresses...), net.IP.String) + tmpl.URIs = slices.UniqueBy(append(tmpl.URIs, override.URIs...), (*url.URL).String) + tmpl.UnknownExtKeyUsage = slices.UniqueBy(append(tmpl.UnknownExtKeyUsage, override.UnknownExtKeyUsage...), asn1.ObjectIdentifier.String) + seq := override.Subject.ToRDNSequence() + tmpl.Subject.FillFromRDNSequence(&seq) + tmpl.KeyUsage |= override.KeyUsage + tmpl.ExtKeyUsage = slices.Unique(append(tmpl.ExtKeyUsage, override.ExtKeyUsage...)) + } + + clientCertDER, err := x509.CreateCertificate(rand.Reader, tmpl, caCert.Leaf, priv.Public(), caCert.PrivateKey) + require.NoError(e.t, err) + + cert, err := x509.ParseCertificate(clientCertDER) + require.NoError(e.t, err) + e.debugf("provisioned client certificate for %s", cert.Subject.String()) + + clientCert := &tls.Certificate{ + Certificate: [][]byte{cert.Raw, caCert.Leaf.Raw}, + PrivateKey: priv, + Leaf: cert, + } + + _, err = clientCert.Leaf.Verify(x509.VerifyOptions{ + KeyUsages: []x509.ExtKeyUsage{ + x509.ExtKeyUsageClientAuth, + }, + Roots: e.ServerCAs(), + }) + require.NoError(e.t, err, "bug: generated client cert is not valid") + return (*Certificate)(clientCert) +} + +func (e *environment) NewServerCert(templateOverrides ...*x509.Certificate) *Certificate { + caCert := e.CACert() + + priv, err := rsa.GenerateKey(rand.Reader, 2048) + require.NoError(e.t, err) + + sn, err := rand.Int(rand.Reader, new(big.Int).Lsh(big.NewInt(1), 128)) + require.NoError(e.t, err) + now := time.Now() + tmpl := &x509.Certificate{ + SerialNumber: sn, + NotBefore: now, + NotAfter: now.Add(12 * time.Hour), + KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment, + ExtKeyUsage: []x509.ExtKeyUsage{ + x509.ExtKeyUsageServerAuth, + }, + BasicConstraintsValid: true, + } + for _, override := range templateOverrides { + tmpl.DNSNames = slices.Unique(append(tmpl.DNSNames, override.DNSNames...)) + tmpl.IPAddresses = 
slices.UniqueBy(append(tmpl.IPAddresses, override.IPAddresses...), net.IP.String) + } + certDER, err := x509.CreateCertificate(rand.Reader, tmpl, caCert.Leaf, priv.Public(), caCert.PrivateKey) + require.NoError(e.t, err) + + cert, err := x509.ParseCertificate(certDER) + require.NoError(e.t, err) + e.debugf("provisioned server certificate for %v", cert.DNSNames) + + tlsCert := &tls.Certificate{ + Certificate: [][]byte{cert.Raw, caCert.Leaf.Raw}, + PrivateKey: priv, + Leaf: cert, + } + + _, err = tlsCert.Leaf.Verify(x509.VerifyOptions{Roots: e.ServerCAs()}) + require.NoError(e.t, err, "bug: generated client cert is not valid") + return (*Certificate)(tlsCert) +} + +func (e *environment) SharedSecret() []byte { + return bytes.Clone(e.sharedSecret[:]) +} + +func (e *environment) CookieSecret() []byte { + return bytes.Clone(e.cookieSecret[:]) +} + +func (e *environment) Stop() { + if b, ok := e.t.(*testing.B); ok { + // when calling Stop() manually, ensure we aren't timing this + b.StopTimer() + defer b.StartTimer() + } + e.cleanupOnce.Do(func() { + e.debugf("stop: Stop() called manually") + e.advanceState(Stopping) + e.cancel(ErrCauseManualStop) + err := e.taskErrGroup.Wait() + e.advanceState(Stopped) + e.debugf("stop: done waiting") + assert.ErrorIs(e.t, err, ErrCauseManualStop) + }) +} + +func (e *environment) cleanup() { + e.cleanupOnce.Do(func() { + e.debugf("stop: test cleanup") + if e.t.Failed() { + if e.pauseOnFailure { + e.t.Log("\x1b[31m*** pausing on test failure; continue with ctrl+c ***\x1b[0m") + c := make(chan os.Signal, 1) + signal.Notify(c, syscall.SIGINT) + <-c + e.t.Log("\x1b[31mctrl+c received, continuing\x1b[0m") + signal.Stop(c) + } + } + e.advanceState(Stopping) + e.cancel(ErrCauseTestCleanup) + err := e.taskErrGroup.Wait() + e.advanceState(Stopped) + e.debugf("stop: done waiting") + assert.ErrorIs(e.t, err, ErrCauseTestCleanup) + }) +} + +func (e *environment) Add(m Modifier) { + e.t.Helper() + caller := getCaller() + e.debugf("Add: %T from %s", 
m, caller) + switch e.getState() { + case NotRunning: + for _, mod := range e.mods { + if mod.Value == m { + e.t.Fatalf("test bug: duplicate modifier added\nfirst added by: %s", mod.Caller) + } + } + e.mods = append(e.mods, WithCaller[Modifier]{ + Caller: caller, + Value: m, + }) + e.debugf("Add: state=NotRunning; calling Attach") + m.Attach(e.Context()) + case Starting: + panic("test bug: cannot call Add() before Start() has returned") + case Running: + e.debugf("Add: state=Running; calling ModifyConfig") + e.src.ModifyConfig(e.ctx, m) + case Stopped, Stopping: + panic("test bug: cannot call Add() after Stop()") + default: + panic(fmt.Sprintf("unexpected environment state: %s", e.getState())) + } +} + +func (e *environment) AddTask(t Task) { + e.t.Helper() + caller := getCaller() + e.debugf("AddTask: %T from %s", t, caller) + for _, task := range e.tasks { + if task.Value == t { + e.t.Fatalf("test bug: duplicate task added\nfirst added by: %s", task.Caller) + } + } + e.tasks = append(e.tasks, WithCaller[Task]{ + Caller: getCaller(), + Value: t, + }) +} + +func (e *environment) AddUpstream(up Upstream) { + e.t.Helper() + caller := getCaller() + e.debugf("AddUpstream: %T from %s", up, caller) + e.Add(up) + e.AddTask(up) +} + +// ReportError implements health.Provider. +func (e *environment) ReportError(check health.Check, err error, attributes ...health.Attr) { + // note: don't use e.t.Fatal here, it will deadlock + panic(fmt.Sprintf("%s: %v %v", check, err, attributes)) +} + +// ReportOK implements health.Provider. 
+func (e *environment) ReportOK(_ health.Check, _ ...health.Attr) {} + +func (e *environment) advanceState(newState EnvironmentState) { + e.stateMu.Lock() + defer e.stateMu.Unlock() + if newState <= e.state { + panic(fmt.Sprintf("internal test environment bug: changed state to <= current: newState=%s, current=%s", newState, e.state)) + } + e.debugf("state %s -> %s", e.state.String(), newState.String()) + e.state = newState + e.debugf("notifying %d listeners of state change", len(e.stateChangeListeners[newState])) + for _, listener := range e.stateChangeListeners[newState] { + go listener() + } +} + +func (e *environment) getState() EnvironmentState { + e.stateMu.Lock() + defer e.stateMu.Unlock() + return e.state +} + +func (e *environment) OnStateChanged(state EnvironmentState, callback func()) { + e.stateMu.Lock() + defer e.stateMu.Unlock() + + if e.state&state != 0 { + go callback() + return + } + + // add change listeners for all states, if there are multiple bits set + for state > 0 { + stateBit := EnvironmentState(bits.TrailingZeros32(uint32(state))) + state &= (state - 1) + e.stateChangeListeners[stateBit] = append(e.stateChangeListeners[stateBit], callback) + } +} + +func getCaller(skip ...int) string { + if len(skip) == 0 { + skip = append(skip, 3) + } + callers := make([]uintptr, 8) + runtime.Callers(skip[0], callers) + frames := runtime.CallersFrames(callers) + var caller string + for { + next, ok := frames.Next() + if !ok { + break + } + if path.Base(next.Function) == "testenv.(*environment).AddUpstream" { + continue + } + caller = fmt.Sprintf("%s:%d", next.File, next.Line) + break + } + return caller +} + +func wildcardDomain(names []string) string { + for _, name := range names { + if name[0] == '*' { + return name[2:] + } + } + panic("test bug: no wildcard domain in certificate") +} + +func isSilent(t testing.TB) bool { + switch t.(type) { + case *testing.B: + return !slices.Contains(os.Args, "-test.v=true") + default: + return false + } +} + +type 
configSource struct { + mu sync.Mutex + cfg *config.Config + lis []config.ChangeListener +} + +var _ config.Source = (*configSource)(nil) + +// GetConfig implements config.Source. +func (src *configSource) GetConfig() *config.Config { + src.mu.Lock() + defer src.mu.Unlock() + + return src.cfg +} + +// OnConfigChange implements config.Source. +func (src *configSource) OnConfigChange(_ context.Context, li config.ChangeListener) { + src.mu.Lock() + defer src.mu.Unlock() + + src.lis = append(src.lis, li) +} + +// ModifyConfig updates the current configuration by applying a [Modifier]. +func (src *configSource) ModifyConfig(ctx context.Context, m Modifier) { + src.mu.Lock() + defer src.mu.Unlock() + + m.Modify(src.cfg) + for _, li := range src.lis { + li(ctx, src.cfg) + } +} diff --git a/internal/testenv/logs.go b/internal/testenv/logs.go new file mode 100644 index 000000000..73789f10d --- /dev/null +++ b/internal/testenv/logs.go @@ -0,0 +1,391 @@ +package testenv + +import ( + "bufio" + "bytes" + "context" + "crypto/tls" + "crypto/x509" + "encoding/json" + "errors" + "fmt" + "io" + "os" + "reflect" + "sync" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// LogRecorder captures logs from the test environment. It can be created at +// any time by calling [Environment.NewLogRecorder], and captures logs until +// one of Close(), Logs(), or Match() is called, which stops recording. See the +// documentation for each method for more details. 
type LogRecorder struct {
	LogRecorderOptions

	// t receives assertion failures raised while decoding or matching logs.
	t testing.TB
	// canceled is closed when the owning environment's context is done.
	canceled <-chan struct{}
	// buf accumulates raw log output until it is collected.
	buf *buffer
	// recordedLogs holds the decoded entries once collection has run.
	recordedLogs []map[string]any

	removeGlobalWriterOnce func()
	collectLogsOnce        sync.Once
}

// LogRecorderOptions holds the settings that LogRecorderOption functions
// adjust.
type LogRecorderOptions struct {
	filters        []func(map[string]any) bool
	skipCloseDelay bool
}

// LogRecorderOption mutates a LogRecorderOptions value.
type LogRecorderOption func(*LogRecorderOptions)

// apply runs every option against o, in the order given.
func (o *LogRecorderOptions) apply(opts ...LogRecorderOption) {
	for i := range opts {
		opts[i](o)
	}
}

// WithFilters installs the given filter predicates on the recorder, replacing
// any filters configured earlier. Filters are invoked in order for every
// received log entry; an entry is meant to be discarded as soon as one filter
// returns false for it.
func WithFilters(filters ...func(map[string]any) bool) LogRecorderOption {
	setFilters := func(o *LogRecorderOptions) {
		o.filters = filters
	}
	return setFilters
}

// WithSkipCloseDelay disables the 1.1 second pause that normally precedes
// closing the recorder. The pause exists so that Envoy access logs have time
// to flush; skip it when those logs are not needed.
func WithSkipCloseDelay() LogRecorderOption {
	skip := func(o *LogRecorderOptions) {
		o.skipCloseDelay = true
	}
	return skip
}

// buffer is a bytes.Buffer guarded by a mutex whose Read blocks (instead of
// returning io.EOF) while the buffer is empty and not yet closed.
type buffer struct {
	mu         *sync.Mutex
	underlying bytes.Buffer
	cond       *sync.Cond
	waiting    bool
	closed     bool
}

// newBuffer returns an empty, open buffer whose condition variable shares the
// buffer's mutex.
func newBuffer() *buffer {
	lock := new(sync.Mutex)
	b := &buffer{mu: lock}
	b.cond = sync.NewCond(lock)
	return b
}

// Read implements io.ReadWriteCloser.
func (b *buffer) Read(p []byte) (int, error) {
	b.mu.Lock()
	defer b.mu.Unlock()
	for {
		n, err := b.underlying.Read(p)
		drained := errors.Is(err, io.EOF)
		if !drained || b.closed {
			return n, err
		}
		// Nothing buffered yet and the writer is still open: park until a
		// Write or Close wakes us up, then retry.
		b.waiting = true
		b.cond.Wait()
	}
}

// Write implements io.ReadWriteCloser.
func (b *buffer) Write(p []byte) (int, error) {
	b.mu.Lock()
	defer b.mu.Unlock()
	switch {
	case b.closed:
		return 0, io.ErrClosedPipe
	case b.waiting:
		// A reader is parked; wake it once the data has been appended.
		b.waiting = false
		defer b.cond.Signal()
	}
	return b.underlying.Write(p)
}

// Close implements io.ReadWriteCloser.
+func (b *buffer) Close() error { + b.mu.Lock() + defer b.mu.Unlock() + b.closed = true + b.cond.Signal() + return nil +} + +var _ io.ReadWriteCloser = (*buffer)(nil) + +func (e *environment) NewLogRecorder(opts ...LogRecorderOption) *LogRecorder { + options := LogRecorderOptions{} + options.apply(opts...) + lr := &LogRecorder{ + LogRecorderOptions: options, + t: e.t, + canceled: e.ctx.Done(), + buf: newBuffer(), + } + e.logWriter.Add(lr.buf) + lr.removeGlobalWriterOnce = sync.OnceFunc(func() { + // wait for envoy access logs, which flush on a 1 second interval + if !lr.skipCloseDelay { + time.Sleep(1100 * time.Millisecond) + } + e.logWriter.Remove(lr.buf) + }) + context.AfterFunc(e.ctx, lr.removeGlobalWriterOnce) + return lr +} + +type ( + // OpenMap is an alias for map[string]any, and can be used to semantically + // represent a map that must contain at least the given entries, but may + // also contain additional entries. + OpenMap = map[string]any + // ClosedMap is a map[string]any that can be used to semantically represent + // a map that must contain the given entries exactly, and no others. + ClosedMap map[string]any +) + +// Close stops the log recorder. After calling this method, Logs() or Match() +// can be called to inspect the logs that were captured. 
+func (lr *LogRecorder) Close() { + lr.removeGlobalWriterOnce() +} + +func (lr *LogRecorder) collectLogs(shouldClose bool) { + if shouldClose { + lr.removeGlobalWriterOnce() + lr.buf.Close() + } + lr.collectLogsOnce.Do(func() { + recordedLogs := []map[string]any{} + scan := bufio.NewScanner(lr.buf) + for scan.Scan() { + log := scan.Bytes() + m := map[string]any{} + decoder := json.NewDecoder(bytes.NewReader(log)) + decoder.UseNumber() + require.NoError(lr.t, decoder.Decode(&m)) + for _, filter := range lr.filters { + if !filter(m) { + continue + } + } + recordedLogs = append(recordedLogs, m) + } + lr.recordedLogs = recordedLogs + }) +} + +func (lr *LogRecorder) WaitForMatch(expectedLog map[string]any, timeout ...time.Duration) { + lr.skipCloseDelay = true + found := make(chan struct{}) + done := make(chan struct{}) + lr.filters = append(lr.filters, func(entry map[string]any) bool { + select { + case <-found: + default: + if matched, _ := match(expectedLog, entry, true); matched { + close(found) + } + } + return true + }) + go func() { + defer close(done) + lr.collectLogs(false) + lr.removeGlobalWriterOnce() + }() + if len(timeout) != 0 { + select { + case <-found: + case <-time.After(timeout[0]): + lr.t.Error("timed out waiting for log") + case <-lr.canceled: + lr.t.Error("canceled") + } + } else { + select { + case <-found: + case <-lr.canceled: + lr.t.Error("canceled") + } + } + lr.buf.Close() + <-done +} + +// Logs stops the log recorder (if it is not already stopped), then returns +// the logs that were captured as structured map[string]any objects. 
+func (lr *LogRecorder) Logs() []map[string]any { + lr.collectLogs(true) + return lr.recordedLogs +} + +func (lr *LogRecorder) DumpToFile(file string) { + lr.collectLogs(true) + f, err := os.Create(file) + require.NoError(lr.t, err) + enc := json.NewEncoder(f) + for _, log := range lr.recordedLogs { + _ = enc.Encode(log) + } + f.Close() +} + +// Match stops the log recorder (if it is not already stopped), then asserts +// that the given expected logs were captured. The expected logs may contain +// partial or complete log entries. By default, logs must only match the fields +// given, and may contain additional fields that will be ignored. +// +// There are several special-case value types that can be used to customize the +// matching behavior, and/or simplify some common use cases, as follows: +// - [OpenMap] and [ClosedMap] can be used to control matching logic +// - [json.Number] will convert the actual value to a string before comparison +// - [*tls.Certificate] or [*x509.Certificate] will expand to the fields that +// would be logged for this certificate +func (lr *LogRecorder) Match(expectedLogs []map[string]any) { + lr.collectLogs(true) + for _, expectedLog := range expectedLogs { + found := false + highScore, highScoreIdxs := 0, []int{} + for i, actualLog := range lr.recordedLogs { + if ok, score := match(expectedLog, actualLog, true); ok { + found = true + break + } else if score > highScore { + highScore = score + highScoreIdxs = []int{i} + } else if score == highScore { + highScoreIdxs = append(highScoreIdxs, i) + } + } + if len(highScoreIdxs) > 0 { + expectedLogBytes, _ := json.MarshalIndent(expectedLog, "", " ") + if len(highScoreIdxs) == 1 { + actualLogBytes, _ := json.MarshalIndent(lr.recordedLogs[highScoreIdxs[0]], "", " ") + assert.True(lr.t, found, "expected log not found: \n%s\n\nclosest match:\n%s\n", + string(expectedLogBytes), string(actualLogBytes)) + } else { + closestMatches := []string{} + for _, i := range highScoreIdxs { + bytes, _ := 
json.MarshalIndent(lr.recordedLogs[i], "", " ") + closestMatches = append(closestMatches, string(bytes)) + } + assert.True(lr.t, found, "expected log not found: \n%s\n\nclosest matches:\n%s\n", string(expectedLogBytes), closestMatches) + } + } else { + expectedLogBytes, _ := json.MarshalIndent(expectedLog, "", " ") + assert.True(lr.t, found, "expected log not found: %s", string(expectedLogBytes)) + } + } +} + +func match(expected, actual map[string]any, open bool) (matched bool, score int) { + for key, value := range expected { + actualValue, ok := actual[key] + if !ok { + return false, score + } + score++ + + switch actualValue := actualValue.(type) { + case map[string]any: + switch expectedValue := value.(type) { + case ClosedMap: + ok, s := match(expectedValue, actualValue, false) + score += s * 2 + if !ok { + return false, score + } + case OpenMap: + ok, s := match(expectedValue, actualValue, true) + score += s + if !ok { + return false, score + } + case *tls.Certificate, *Certificate, *x509.Certificate: + var leaf *x509.Certificate + switch expectedValue := expectedValue.(type) { + case *tls.Certificate: + leaf = expectedValue.Leaf + case *Certificate: + leaf = expectedValue.Leaf + case *x509.Certificate: + leaf = expectedValue + } + + // keep logic consistent with controlplane.populateCertEventDict() + expected := map[string]any{} + if iss := leaf.Issuer.String(); iss != "" { + expected["issuer"] = iss + } + if sub := leaf.Subject.String(); sub != "" { + expected["subject"] = sub + } + sans := []string{} + for _, dnsSAN := range leaf.DNSNames { + sans = append(sans, "DNS:"+dnsSAN) + } + for _, uriSAN := range leaf.URIs { + sans = append(sans, "URI:"+uriSAN.String()) + } + if len(sans) > 0 { + expected["subjectAltName"] = sans + } + + ok, s := match(expected, actualValue, false) + score += s + if !ok { + return false, score + } + default: + return false, score + } + case string: + switch value := value.(type) { + case string: + if value != actualValue { + 
return false, score + } + score++ + default: + return false, score + } + case json.Number: + if fmt.Sprint(value) != actualValue.String() { + return false, score + } + score++ + default: + // handle slices + if reflect.TypeOf(actualValue).Kind() == reflect.Slice { + if reflect.TypeOf(value) != reflect.TypeOf(actualValue) { + return false, score + } + actualSlice := reflect.ValueOf(actualValue) + expectedSlice := reflect.ValueOf(value) + totalScore := 0 + for i := range min(actualSlice.Len(), expectedSlice.Len()) { + if actualSlice.Index(i).Equal(expectedSlice.Index(i)) { + totalScore++ + } + } + score += totalScore + } else { + panic(fmt.Sprintf("test bug: add check for type %T in assertMatchingLogs", actualValue)) + } + } + } + if !open && len(expected) != len(actual) { + return false, score + } + return true, score +} diff --git a/internal/testenv/route.go b/internal/testenv/route.go new file mode 100644 index 000000000..8b002f8bb --- /dev/null +++ b/internal/testenv/route.go @@ -0,0 +1,76 @@ +package testenv + +import ( + "net/url" + "strings" + + "github.com/pomerium/pomerium/config" + "github.com/pomerium/pomerium/internal/testenv/values" + "github.com/pomerium/pomerium/pkg/policy/parser" +) + +// PolicyRoute is a [Route] implementation suitable for most common use cases +// that can be used in implementations of [Upstream]. +type PolicyRoute struct { + DefaultAttach + from values.Value[string] + to values.List[string] + edits []func(*config.Policy) +} + +// Modify implements Route. +func (b *PolicyRoute) Modify(cfg *config.Config) { + to := make(config.WeightedURLs, 0, len(b.to)) + for _, u := range b.to { + u, err := url.Parse(u.Value()) + if err != nil { + panic(err) + } + to = append(to, config.WeightedURL{URL: *u}) + } + p := config.Policy{ + From: b.from.Value(), + To: to, + } + for _, edit := range b.edits { + edit(&p) + } + cfg.Options.Policies = append(cfg.Options.Policies, p) +} + +// From implements Route. 
+func (b *PolicyRoute) From(fromURL values.Value[string]) Route { + b.from = fromURL + return b +} + +// To implements Route. +func (b *PolicyRoute) To(toURL values.Value[string]) Route { + b.to = append(b.to, toURL) + return b +} + +// Policy implements Route. +func (b *PolicyRoute) Policy(edit func(*config.Policy)) Route { + b.edits = append(b.edits, edit) + return b +} + +// PPL implements Route. +func (b *PolicyRoute) PPL(ppl string) Route { + pplPolicy, err := parser.ParseYAML(strings.NewReader(ppl)) + if err != nil { + panic(err) + } + b.edits = append(b.edits, func(p *config.Policy) { + p.Policy = &config.PPLPolicy{ + Policy: pplPolicy, + } + }) + return b +} + +// URL implements Route. +func (b *PolicyRoute) URL() values.Value[string] { + return b.from +} diff --git a/internal/testenv/scenarios/mock_idp.go b/internal/testenv/scenarios/mock_idp.go new file mode 100644 index 000000000..e6d6e539e --- /dev/null +++ b/internal/testenv/scenarios/mock_idp.go @@ -0,0 +1,389 @@ +package scenarios + +import ( + "context" + "crypto" + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "encoding/base64" + "encoding/hex" + "encoding/json" + "fmt" + "io" + "net" + "net/http" + "net/url" + "strconv" + "strings" + "time" + + "github.com/go-jose/go-jose/v3" + "github.com/go-jose/go-jose/v3/jwt" + "github.com/google/uuid" + "github.com/pomerium/pomerium/config" + "github.com/pomerium/pomerium/internal/encoding" + "github.com/pomerium/pomerium/internal/encoding/jws" + "github.com/pomerium/pomerium/internal/log" + "github.com/pomerium/pomerium/internal/testenv" + "github.com/pomerium/pomerium/internal/testenv/upstreams" + "github.com/pomerium/pomerium/internal/testenv/values" + "github.com/pomerium/pomerium/pkg/grpc/identity" +) + +// IDP is a mock OIDC identity provider for the test environment; it is used as a testenv.Modifier. +type IDP struct { + id values.Value[string] + url values.Value[string] + publicJWK jose.JSONWebKey + signingKey jose.SigningKey + + stateEncoder encoding.MarshalUnmarshaler + userLookup map[string]*User +} + +// Attach implements testenv.Modifier.
+func (idp *IDP) Attach(ctx context.Context) { + env := testenv.EnvFromContext(ctx) + + router := upstreams.HTTP(nil) + + idp.url = values.Bind2(env.SubdomainURL("mock-idp"), router.Port(), func(urlStr string, port int) string { + u, _ := url.Parse(urlStr) + host, _, _ := net.SplitHostPort(u.Host) + return u.ResolveReference(&url.URL{ + Scheme: "http", + Host: fmt.Sprintf("%s:%d", host, port), + }).String() + }) + var err error + idp.stateEncoder, err = jws.NewHS256Signer(env.SharedSecret()) + env.Require().NoError(err) + + idp.id = values.Bind2(idp.url, env.AuthenticateURL(), func(idpUrl, authUrl string) string { + provider := identity.Provider{ + AuthenticateServiceUrl: authUrl, + ClientId: "CLIENT_ID", + ClientSecret: "CLIENT_SECRET", + Type: "oidc", + Scopes: []string{"openid", "email", "profile"}, + Url: idpUrl, + } + return provider.Hash() + }) + + router.Handle("/.well-known/jwks.json", func(w http.ResponseWriter, _ *http.Request) { + _ = json.NewEncoder(w).Encode(&jose.JSONWebKeySet{ + Keys: []jose.JSONWebKey{idp.publicJWK}, + }) + }) + router.Handle("/.well-known/openid-configuration", func(w http.ResponseWriter, r *http.Request) { + log.Ctx(ctx).Debug().Str("method", r.Method).Str("uri", r.RequestURI).Send() + rootURL, _ := url.Parse(idp.url.Value()) + _ = json.NewEncoder(w).Encode(map[string]interface{}{ + "issuer": rootURL.String(), + "authorization_endpoint": rootURL.ResolveReference(&url.URL{Path: "/oidc/auth"}).String(), + "token_endpoint": rootURL.ResolveReference(&url.URL{Path: "/oidc/token"}).String(), + "jwks_uri": rootURL.ResolveReference(&url.URL{Path: "/.well-known/jwks.json"}).String(), + "userinfo_endpoint": rootURL.ResolveReference(&url.URL{Path: "/oidc/userinfo"}).String(), + "id_token_signing_alg_values_supported": []string{ + "ES256", + }, + }) + }) + router.Handle("/oidc/auth", idp.HandleAuth) + router.Handle("/oidc/token", idp.HandleToken) + router.Handle("/oidc/userinfo", idp.HandleUserInfo) + + env.AddUpstream(router) +} + +// Modify 
implements testenv.Modifier. +func (idp *IDP) Modify(cfg *config.Config) { + cfg.Options.Provider = "oidc" + cfg.Options.ProviderURL = idp.url.Value() + cfg.Options.ClientID = "CLIENT_ID" + cfg.Options.ClientSecret = "CLIENT_SECRET" + cfg.Options.Scopes = []string{"openid", "email", "profile"} +} + +var _ testenv.Modifier = (*IDP)(nil) + +func NewIDP(users []*User) *IDP { + privateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + if err != nil { + panic(err) + } + publicKey := &privateKey.PublicKey + + signingKey := jose.SigningKey{ + Algorithm: jose.ES256, + Key: privateKey, + } + publicJWK := jose.JSONWebKey{ + Key: publicKey, + Algorithm: string(jose.ES256), + Use: "sig", + } + thumbprint, err := publicJWK.Thumbprint(crypto.SHA256) + if err != nil { + panic(err) + } + publicJWK.KeyID = hex.EncodeToString(thumbprint) + + userLookup := map[string]*User{} + for _, user := range users { + user.ID = uuid.NewString() + userLookup[user.ID] = user + } + return &IDP{ + publicJWK: publicJWK, + signingKey: signingKey, + userLookup: userLookup, + } +} + +// HandleAuth handles the auth flow for OIDC. +func (idp *IDP) HandleAuth(w http.ResponseWriter, r *http.Request) { + rawRedirectURI := r.FormValue("redirect_uri") + if rawRedirectURI == "" { + http.Error(w, "missing redirect_uri", http.StatusBadRequest) + return + } + + redirectURI, err := url.Parse(rawRedirectURI) + if err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + + rawClientID := r.FormValue("client_id") + if rawClientID == "" { + http.Error(w, "missing client_id", http.StatusBadRequest) + return + } + + rawEmail := r.FormValue("email") + if rawEmail != "" { + http.Redirect(w, r, redirectURI.ResolveReference(&url.URL{ + RawQuery: (url.Values{ + "state": {r.FormValue("state")}, + "code": {State{ + Email: rawEmail, + ClientID: rawClientID, + }.Encode()}, + }).Encode(), + }).String(), http.StatusFound) + return + } + + serveHTML(w, ` + + + Login + + +
+
+ Login + + + + + + + + + + + +
+ +
+ +
+ +
+
+ + + `) +} + +// HandleToken handles the token flow for OIDC. +func (idp *IDP) HandleToken(w http.ResponseWriter, r *http.Request) { + rawCode := r.FormValue("code") + + state, err := DecodeState(rawCode) + if err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + + serveJSON(w, map[string]interface{}{ + "access_token": state.Encode(), + "refresh_token": state.Encode(), + "token_type": "Bearer", + "id_token": state.GetIDToken(r, idp.userLookup).Encode(idp.signingKey), + }) +} + +// HandleUserInfo handles retrieving the user info. +func (idp *IDP) HandleUserInfo(w http.ResponseWriter, r *http.Request) { + authz := r.Header.Get("Authorization") + if authz == "" { + http.Error(w, "missing authorization header", http.StatusUnauthorized) + return + } + + if strings.HasPrefix(authz, "Bearer ") { + authz = authz[len("Bearer "):] + } else if strings.HasPrefix(authz, "token ") { + authz = authz[len("token "):] + } else { + http.Error(w, "missing bearer token", http.StatusUnauthorized) + return + } + + state, err := DecodeState(authz) + if err != nil { + http.Error(w, err.Error(), http.StatusForbidden) + return + } + + serveJSON(w, state.GetUserInfo(idp.userLookup)) +} + +type RootURLKey struct{} + +var rootURLKey RootURLKey + +// WithRootURL sets the Root URL in a context. 
+func WithRootURL(ctx context.Context, rootURL *url.URL) context.Context { + return context.WithValue(ctx, rootURLKey, rootURL) +} + +func getRootURL(r *http.Request) *url.URL { + if u, ok := r.Context().Value(rootURLKey).(*url.URL); ok { + return u + } + + u := *r.URL + if r.Host != "" { + u.Host = r.Host + } + if u.Scheme == "" { + if r.TLS != nil { + u.Scheme = "https" + } else { + u.Scheme = "http" + } + } + u.Path = "" + return &u +} + +func serveHTML(w http.ResponseWriter, html string) { + w.Header().Set("Content-Type", "text/html") + w.Header().Set("Content-Length", strconv.Itoa(len(html))) + w.WriteHeader(http.StatusOK) + _, _ = io.WriteString(w, html) +} + +func serveJSON(w http.ResponseWriter, obj interface{}) { + bs, err := json.Marshal(obj) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _, _ = w.Write(bs) +} + +type State struct { + Email string `json:"email"` + ClientID string `json:"client_id"` +} + +func DecodeState(rawCode string) (*State, error) { + var state State + bs, _ := base64.URLEncoding.DecodeString(rawCode) + err := json.Unmarshal(bs, &state) + if err != nil { + return nil, err + } + return &state, nil +} + +func (state State) Encode() string { + bs, _ := json.Marshal(state) + return base64.URLEncoding.EncodeToString(bs) +} + +func (state State) GetIDToken(r *http.Request, users map[string]*User) *IDToken { + token := &IDToken{ + UserInfo: state.GetUserInfo(users), + + Issuer: getRootURL(r).String(), + Audience: state.ClientID, + Expiry: jwt.NewNumericDate(time.Now().Add(time.Hour * 24 * 365)), + IssuedAt: jwt.NewNumericDate(time.Now()), + } + return token +} + +func (state State) GetUserInfo(users map[string]*User) *UserInfo { + userInfo := &UserInfo{ + Subject: state.Email, + Email: state.Email, + } + + for _, u := range users { + if u.Email == state.Email { + userInfo.Subject = u.ID + userInfo.Name = 
u.FirstName + " " + u.LastName + userInfo.FamilyName = u.LastName + userInfo.GivenName = u.FirstName + } + } + + return userInfo +} + +type UserInfo struct { + Subject string `json:"sub"` + Name string `json:"name"` + Email string `json:"email"` + FamilyName string `json:"family_name"` + GivenName string `json:"given_name"` +} + +type IDToken struct { + *UserInfo + + Issuer string `json:"iss"` + Audience string `json:"aud"` + Expiry *jwt.NumericDate `json:"exp"` + IssuedAt *jwt.NumericDate `json:"iat"` +} + +func (token *IDToken) Encode(signingKey jose.SigningKey) string { + sig, err := jose.NewSigner(signingKey, (&jose.SignerOptions{}).WithType("JWT")) + if err != nil { + panic(err) + } + + str, err := jwt.Signed(sig).Claims(token).CompactSerialize() + if err != nil { + panic(err) + } + return str +} + +type User struct { + ID string + Email string + FirstName string + LastName string +} diff --git a/internal/testenv/scenarios/mtls.go b/internal/testenv/scenarios/mtls.go new file mode 100644 index 000000000..b7bb2eb5c --- /dev/null +++ b/internal/testenv/scenarios/mtls.go @@ -0,0 +1,24 @@ +package scenarios + +import ( + "context" + "encoding/base64" + "encoding/pem" + + "github.com/pomerium/pomerium/config" + "github.com/pomerium/pomerium/internal/testenv" +) + +func DownstreamMTLS(mode config.MTLSEnforcement) testenv.Modifier { + return testenv.ModifierFunc(func(ctx context.Context, cfg *config.Config) { + env := testenv.EnvFromContext(ctx) + block := pem.Block{ + Type: "CERTIFICATE", + Bytes: env.CACert().Leaf.Raw, + } + cfg.Options.DownstreamMTLS = config.DownstreamMTLSSettings{ + CA: base64.StdEncoding.EncodeToString(pem.EncodeToMemory(&block)), + Enforcement: mode, + } + }) +} diff --git a/internal/testenv/snippets/routes.go b/internal/testenv/snippets/routes.go new file mode 100644 index 000000000..664d21bd6 --- /dev/null +++ b/internal/testenv/snippets/routes.go @@ -0,0 +1,64 @@ +package snippets + +import ( + "bytes" + "context" + "strings" + 
"text/template" + + "github.com/pomerium/pomerium/config" + "github.com/pomerium/pomerium/internal/testenv" + "github.com/pomerium/pomerium/pkg/policy/parser" +) + +var SimplePolicyTemplate = PolicyTemplate{ + From: "https://from-{{.Idx}}.localhost", + To: "https://to-{{.Idx}}.localhost", + PPL: `{"allow":{"and":["email":{"is":"user-{{.Idx}}@example.com"}]}}`, +} + +type PolicyTemplate struct { + From string + To string + PPL string + + // Add more fields as needed (be sure to update newPolicyFromTemplate) +} + +func TemplateRoutes(n int, tmpl PolicyTemplate) testenv.Modifier { + return testenv.ModifierFunc(func(_ context.Context, cfg *config.Config) { + for i := range n { + cfg.Options.Policies = append(cfg.Options.Policies, newPolicyFromTemplate(i, tmpl)) + } + }) +} + +func newPolicyFromTemplate(i int, pt PolicyTemplate) config.Policy { + eval := func(in string) string { + t := template.New("policy") + tmpl, err := t.Parse(in) + if err != nil { + panic(err) + } + var out bytes.Buffer + if err := tmpl.Execute(&out, struct{ Idx int }{i}); err != nil { + panic(err) + } + return out.String() + } + + pplPolicy, err := parser.ParseYAML(strings.NewReader(eval(pt.PPL))) + if err != nil { + panic(err) + } + + to, err := config.ParseWeightedUrls(eval(pt.To)) + if err != nil { + panic(err) + } + return config.Policy{ + From: eval(pt.From), + To: to, + Policy: &config.PPLPolicy{Policy: pplPolicy}, + } +} diff --git a/internal/testenv/snippets/wait.go b/internal/testenv/snippets/wait.go new file mode 100644 index 000000000..4342f9c42 --- /dev/null +++ b/internal/testenv/snippets/wait.go @@ -0,0 +1,35 @@ +package snippets + +import ( + "context" + "time" + + "github.com/pomerium/pomerium/internal/testenv" + "github.com/pomerium/pomerium/pkg/grpcutil" + "google.golang.org/grpc" + "google.golang.org/grpc/connectivity" + "google.golang.org/grpc/credentials/insecure" +) + +func WaitStartupComplete(env testenv.Environment, timeout ...time.Duration) time.Duration { + start := 
time.Now() + recorder := env.NewLogRecorder() + if len(timeout) == 0 { + timeout = append(timeout, 1*time.Minute) + } + ctx, ca := context.WithTimeout(env.Context(), timeout[0]) + defer ca() + recorder.WaitForMatch(map[string]any{ + "syncer_id": "databroker", + "syncer_type": "type.googleapis.com/pomerium.config.Config", + "message": "listening for updates", + }, timeout...) + cc, err := grpc.Dial(env.DatabrokerURL().Value(), + grpc.WithTransportCredentials(insecure.NewCredentials()), + grpc.WithChainUnaryInterceptor(grpcutil.WithUnarySignedJWT(env.SharedSecret)), + grpc.WithChainStreamInterceptor(grpcutil.WithStreamSignedJWT(env.SharedSecret)), + ) + env.Require().NoError(err) + env.Require().True(cc.WaitForStateChange(ctx, connectivity.Ready)) + return time.Since(start) +} diff --git a/internal/testenv/types.go b/internal/testenv/types.go new file mode 100644 index 000000000..014114f22 --- /dev/null +++ b/internal/testenv/types.go @@ -0,0 +1,206 @@ +package testenv + +import ( + "context" + + "github.com/pomerium/pomerium/config" + "github.com/pomerium/pomerium/internal/testenv/values" +) + +type envContextKeyType struct{} + +var envContextKey envContextKeyType + +func EnvFromContext(ctx context.Context) Environment { + return ctx.Value(envContextKey).(Environment) +} + +func ContextWithEnv(ctx context.Context, env Environment) context.Context { + return context.WithValue(ctx, envContextKey, env) +} + +// A Modifier is an object whose presence in the test affects the Pomerium +// configuration in some way. When the test environment is started, a +// [*config.Config] is constructed by calling each added Modifier in order. +// +// For additional details, see [Environment.Add] and [Environment.Start]. +type Modifier interface { + // Attach is called by an [Environment] (before Modify) to propagate the + // environment's context. 
+ Attach(ctx context.Context) + + // Modify is called by an [Environment] to mutate its configuration in some + // way required by this Modifier. + Modify(cfg *config.Config) +} + +// DefaultAttach should be embedded in types implementing [Modifier] to +// automatically obtain environment context details and caller information. +type DefaultAttach struct { + env Environment + caller string +} + +func (d *DefaultAttach) Env() Environment { + d.CheckAttached() + return d.env +} + +func (d *DefaultAttach) Attach(ctx context.Context) { + if d.env != nil { + panic("internal test environment bug: Attach called twice") + } + d.env = EnvFromContext(ctx) + if d.env == nil { + panic("test bug: no environment in context") + } +} + +func (d *DefaultAttach) CheckAttached() { + if d.env == nil { + if d.caller != "" { + panic("test bug: missing a call to Add for the object created at: " + d.caller) + } + panic("test bug: not attached (possibly missing a call to Add)") + } +} + +func (d *DefaultAttach) RecordCaller() { + d.caller = getCaller(4) +} + +// Aggregate should be embedded in types implementing [Modifier] when the type +// contains other modifiers. Used as an alternative to [DefaultAttach]. +// Embedding this struct will properly keep track of when constituent modifiers +// are added, for validation and caller detection. +// +// Aggregate implements a no-op Modify() by default, but this can be overridden +// to make additional modifications. The aggregate's Modify() is called first. +type Aggregate struct { + env Environment + caller string + modifiers []Modifier +} + +func (d *Aggregate) Add(mod Modifier) { + if d.env != nil { + if d.env.(*environment).getState() == NotRunning { + // If the test environment is running, adding to an aggregate is a no-op. 
+ // If the test environment has not been started yet, the aggregate is + // being used like in the following example, which is incorrect: + // + // aggregate.Add(foo) + // env.Add(aggregate) + // aggregate.Add(bar) + // env.Start() + // + // It should instead be used like this: + // + // aggregate.Add(foo) + // aggregate.Add(bar) + // env.Add(aggregate) + // env.Start() + panic("test bug: cannot modify an aggregate that has already been added") + } + return + } + d.modifiers = append(d.modifiers, mod) +} + +func (d *Aggregate) Env() Environment { + d.CheckAttached() + return d.env +} + +func (d *Aggregate) Attach(ctx context.Context) { + if d.env != nil { + panic("internal test environment bug: Attach called twice") + } + d.env = EnvFromContext(ctx) + if d.env == nil { + panic("test bug: no environment in context") + } + d.env.(*environment).t.Helper() + for _, mod := range d.modifiers { + d.env.Add(mod) + } +} + +func (d *Aggregate) Modify(*config.Config) {} + +func (d *Aggregate) CheckAttached() { + if d.env == nil { + if d.caller != "" { + panic("test bug: missing a call to Add for the object created at: " + d.caller) + } + panic("test bug: not attached (possibly missing a call to Add)") + } +} + +func (d *Aggregate) RecordCaller() { + d.caller = getCaller(4) +} + +type modifierFunc struct { + fn func(ctx context.Context, cfg *config.Config) + ctx context.Context +} + +// Attach implements Modifier. +func (f *modifierFunc) Attach(ctx context.Context) { + f.ctx = ctx +} + +func (f *modifierFunc) Modify(cfg *config.Config) { + f.fn(f.ctx, cfg) +} + +var _ Modifier = (*modifierFunc)(nil) + +func ModifierFunc(fn func(ctx context.Context, cfg *config.Config)) Modifier { + return &modifierFunc{fn: fn} +} + +// Task represents a background task that can be added to an [Environment] to +// have it run automatically on startup. +// +// For additional details, see [Environment.AddTask] and [Environment.Start]. 
+type Task interface { + Run(ctx context.Context) error +} + +type TaskFunc func(ctx context.Context) error + +func (f TaskFunc) Run(ctx context.Context) error { + return f(ctx) +} + +// Upstream represents an upstream server. It is both a [Task] and a [Modifier] +// and can be added to an environment using [Environment.AddUpstream]. From an +// Upstream instance, new routes can be created (which automatically adds the +// necessary route/policy entries to the config), and used within a test to +// easily make requests to the routes with implementation-specific clients. +type Upstream interface { + Modifier + Task + Port() values.Value[int] + Route() RouteStub +} + +// A Route represents a route from a source URL to a destination URL. A route is +// typically created by calling [Upstream.Route]. +type Route interface { + Modifier + URL() values.Value[string] + To(toURL values.Value[string]) Route + Policy(edit func(*config.Policy)) Route + PPL(ppl string) Route + // add more methods here as they become needed +} + +// RouteStub represents an incomplete [Route]. Providing a URL by calling its +// From() method will return a [Route], from which further configuration can +// be made. 
+type RouteStub interface { + From(fromURL values.Value[string]) Route +} diff --git a/internal/testenv/upstreams/grpc.go b/internal/testenv/upstreams/grpc.go new file mode 100644 index 000000000..f46801392 --- /dev/null +++ b/internal/testenv/upstreams/grpc.go @@ -0,0 +1,139 @@ +package upstreams + +import ( + "context" + "fmt" + "net" + "strings" + + "github.com/pomerium/pomerium/internal/testenv" + "github.com/pomerium/pomerium/internal/testenv/values" + "google.golang.org/grpc" + "google.golang.org/grpc/credentials" +) + +type Options struct { + serverOpts []grpc.ServerOption +} + +type Option func(*Options) + +func (o *Options) apply(opts ...Option) { + for _, op := range opts { + op(o) + } +} + +func ServerOpts(opt ...grpc.ServerOption) Option { + return func(o *Options) { + o.serverOpts = append(o.serverOpts, opt...) + } +} + +// GRPCUpstream represents a GRPC server which can be used as the target for +// one or more Pomerium routes in a test environment. +// +// This upstream implements [grpc.ServiceRegistrar], and can be used similarly +// in the same way as [*grpc.Server] to register services before it is started. +// +// Any [testenv.Route] instances created from this upstream can be referenced +// in the Dial() method to establish a connection to that route. +type GRPCUpstream interface { + testenv.Upstream + grpc.ServiceRegistrar + Dial(r testenv.Route, dialOpts ...grpc.DialOption) *grpc.ClientConn +} + +type grpcUpstream struct { + Options + testenv.Aggregate + serverPort values.MutableValue[int] + creds credentials.TransportCredentials + + services []service +} + +var ( + _ testenv.Upstream = (*grpcUpstream)(nil) + _ grpc.ServiceRegistrar = (*grpcUpstream)(nil) +) + +// GRPC creates a new GRPC upstream server. +func GRPC(creds credentials.TransportCredentials, opts ...Option) GRPCUpstream { + options := Options{} + options.apply(opts...) 
+ up := &grpcUpstream{ + Options: options, + creds: creds, + serverPort: values.Deferred[int](), + } + up.RecordCaller() + return up +} + +type service struct { + desc *grpc.ServiceDesc + impl any +} + +func (g *grpcUpstream) Port() values.Value[int] { + return g.serverPort +} + +// RegisterService implements grpc.ServiceRegistrar. +func (g *grpcUpstream) RegisterService(desc *grpc.ServiceDesc, impl any) { + g.services = append(g.services, service{desc, impl}) +} + +// Route implements testenv.Upstream. +func (g *grpcUpstream) Route() testenv.RouteStub { + r := &testenv.PolicyRoute{} + var protocol string + switch g.creds.Info().SecurityProtocol { + case "insecure": + protocol = "h2c" + default: + protocol = "https" + } + r.To(values.Bind(g.serverPort, func(port int) string { + return fmt.Sprintf("%s://127.0.0.1:%d", protocol, port) + })) + g.Add(r) + return r +} + +// Start implements testenv.Upstream. +func (g *grpcUpstream) Run(ctx context.Context) error { + listener, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + return err + } + g.serverPort.Resolve(listener.Addr().(*net.TCPAddr).Port) + server := grpc.NewServer(append(g.serverOpts, grpc.Creds(g.creds))...) + for _, s := range g.services { + server.RegisterService(s.desc, s.impl) + } + errC := make(chan error, 1) + go func() { + errC <- server.Serve(listener) + }() + select { + case <-ctx.Done(): + server.Stop() + return context.Cause(ctx) + case err := <-errC: + return err + } +} + +func (g *grpcUpstream) Dial(r testenv.Route, dialOpts ...grpc.DialOption) *grpc.ClientConn { + dialOpts = append(dialOpts, + grpc.WithTransportCredentials(credentials.NewClientTLSFromCert(g.Env().ServerCAs(), "")), + grpc.WithDefaultCallOptions(grpc.WaitForReady(true)), + ) + cc, err := grpc.NewClient(strings.TrimPrefix(r.URL().Value(), "https://"), dialOpts...) 
+ if err != nil { + panic(err) + } + return cc +} diff --git a/internal/testenv/upstreams/http.go b/internal/testenv/upstreams/http.go new file mode 100644 index 000000000..0138c18f3 --- /dev/null +++ b/internal/testenv/upstreams/http.go @@ -0,0 +1,327 @@ +package upstreams + +import ( + "bytes" + "context" + "crypto/tls" + "encoding/json" + "errors" + "fmt" + "io" + "net" + "net/http" + "net/http/cookiejar" + "net/url" + "strconv" + "strings" + "sync" + "time" + + "github.com/gorilla/mux" + "github.com/pomerium/pomerium/integration/forms" + "github.com/pomerium/pomerium/internal/retry" + "github.com/pomerium/pomerium/internal/testenv" + "github.com/pomerium/pomerium/internal/testenv/values" + "google.golang.org/protobuf/proto" +) + +type RequestOptions struct { + path string + query url.Values + headers map[string]string + authenticateAs string + body any + clientCerts []tls.Certificate + client *http.Client +} + +type RequestOption func(*RequestOptions) + +func (o *RequestOptions) apply(opts ...RequestOption) { + for _, op := range opts { + op(o) + } +} + +// Path sets the path of the request. If omitted, the request URL will match +// the route URL exactly. +func Path(path string) RequestOption { + return func(o *RequestOptions) { + o.path = path + } +} + +// Query sets optional query parameters of the request. +func Query(query url.Values) RequestOption { + return func(o *RequestOptions) { + o.query = query + } +} + +// Headers adds optional headers to the request. +func Headers(headers map[string]string) RequestOption { + return func(o *RequestOptions) { + o.headers = headers + } +} + +func AuthenticateAs(email string) RequestOption { + return func(o *RequestOptions) { + o.authenticateAs = email + } +} + +func Client(c *http.Client) RequestOption { + return func(o *RequestOptions) { + o.client = c + } +} + +// Body sets the body of the request. 
+// The argument can be one of the following types: +// - string +// - []byte +// - io.Reader +// - proto.Message +// - any json-encodable type +// If the argument is encoded as json, the Content-Type header will be set to +// "application/json". If the argument is a proto.Message, the Content-Type +// header will be set to "application/octet-stream". +func Body(body any) RequestOption { + return func(o *RequestOptions) { + o.body = body + } +} + +// ClientCert adds a client certificate to the request. +func ClientCert[T interface { + *testenv.Certificate | *tls.Certificate +}](cert T) RequestOption { + return func(o *RequestOptions) { + o.clientCerts = append(o.clientCerts, *(*tls.Certificate)(cert)) + } +} + +// HTTPUpstream represents a HTTP server which can be used as the target for +// one or more Pomerium routes in a test environment. +// +// The Handle() method can be used to add handlers the server-side HTTP router, +// while the Get(), Post(), and (generic) Do() methods can be used to make +// client-side requests. +type HTTPUpstream interface { + testenv.Upstream + + Handle(path string, f func(http.ResponseWriter, *http.Request)) *mux.Route + + Get(r testenv.Route, opts ...RequestOption) (*http.Response, error) + Post(r testenv.Route, opts ...RequestOption) (*http.Response, error) + Do(method string, r testenv.Route, opts ...RequestOption) (*http.Response, error) +} + +type httpUpstream struct { + testenv.Aggregate + serverPort values.MutableValue[int] + tlsConfig values.Value[*tls.Config] + + clientCache sync.Map // map[testenv.Route]*http.Client + + router *mux.Router +} + +var ( + _ testenv.Upstream = (*httpUpstream)(nil) + _ HTTPUpstream = (*httpUpstream)(nil) +) + +// HTTP creates a new HTTP upstream server. 
+func HTTP(tlsConfig values.Value[*tls.Config]) HTTPUpstream { + up := &httpUpstream{ + serverPort: values.Deferred[int](), + router: mux.NewRouter(), + tlsConfig: tlsConfig, + } + up.RecordCaller() + return up +} + +// Port implements HTTPUpstream. +func (h *httpUpstream) Port() values.Value[int] { + return h.serverPort +} + +// Router implements HTTPUpstream. +func (h *httpUpstream) Handle(path string, f func(http.ResponseWriter, *http.Request)) *mux.Route { + return h.router.HandleFunc(path, f) +} + +// Route implements HTTPUpstream. +func (h *httpUpstream) Route() testenv.RouteStub { + r := &testenv.PolicyRoute{} + protocol := "http" + r.To(values.Bind(h.serverPort, func(port int) string { + return fmt.Sprintf("%s://127.0.0.1:%d", protocol, port) + })) + h.Add(r) + return r +} + +// Run implements HTTPUpstream. +func (h *httpUpstream) Run(ctx context.Context) error { + listener, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + return err + } + h.serverPort.Resolve(listener.Addr().(*net.TCPAddr).Port) + var tlsConfig *tls.Config + if h.tlsConfig != nil { + tlsConfig = h.tlsConfig.Value() + } + server := &http.Server{ + Handler: h.router, + TLSConfig: tlsConfig, + BaseContext: func(net.Listener) context.Context { + return ctx + }, + } + errC := make(chan error, 1) + go func() { + errC <- server.Serve(listener) + }() + select { + case <-ctx.Done(): + server.Close() + return context.Cause(ctx) + case err := <-errC: + return err + } +} + +// Get implements HTTPUpstream. +func (h *httpUpstream) Get(r testenv.Route, opts ...RequestOption) (*http.Response, error) { + return h.Do(http.MethodGet, r, opts...) +} + +// Post implements HTTPUpstream. +func (h *httpUpstream) Post(r testenv.Route, opts ...RequestOption) (*http.Response, error) { + return h.Do(http.MethodPost, r, opts...) +} + +// Do implements HTTPUpstream. 
+func (h *httpUpstream) Do(method string, r testenv.Route, opts ...RequestOption) (*http.Response, error) { + options := RequestOptions{} + options.apply(opts...) + u, err := url.Parse(r.URL().Value()) + if err != nil { + return nil, err + } + if options.path != "" || options.query != nil { + u = u.ResolveReference(&url.URL{ + Path: options.path, + RawQuery: options.query.Encode(), + }) + } + req, err := http.NewRequest(method, u.String(), nil) + if err != nil { + return nil, err + } + switch body := options.body.(type) { + case string: + req.Body = io.NopCloser(strings.NewReader(body)) + case []byte: + req.Body = io.NopCloser(bytes.NewReader(body)) + case io.Reader: + req.Body = io.NopCloser(body) + case proto.Message: + buf, err := proto.Marshal(body) + if err != nil { + return nil, err + } + req.Body = io.NopCloser(bytes.NewReader(buf)) + req.Header.Set("Content-Type", "application/octet-stream") + default: + buf, err := json.Marshal(body) + if err != nil { + panic(fmt.Sprintf("unsupported body type: %T", body)) + } + req.Body = io.NopCloser(bytes.NewReader(buf)) + req.Header.Set("Content-Type", "application/json") + case nil: + } + + newClient := func() *http.Client { + c := http.Client{ + Transport: &http.Transport{ + TLSClientConfig: &tls.Config{ + RootCAs: h.Env().ServerCAs(), + Certificates: options.clientCerts, + }, + }, + } + c.Jar, _ = cookiejar.New(&cookiejar.Options{}) + return &c + } + var client *http.Client + if options.client != nil { + client = options.client + } else { + var cachedClient any + var ok bool + if cachedClient, ok = h.clientCache.Load(r); !ok { + cachedClient, _ = h.clientCache.LoadOrStore(r, newClient()) + } + client = cachedClient.(*http.Client) + } + + var resp *http.Response + if err := retry.Retry(h.Env().Context(), "http", func(ctx context.Context) error { + var err error + if options.authenticateAs != "" { + resp, err = authenticateFlow(ctx, client, req, options.authenticateAs) //nolint:bodyclose + } else { + resp, err = 
client.Do(req) //nolint:bodyclose + } + // retry on connection refused + if err != nil { + var opErr *net.OpError + if errors.As(err, &opErr) && opErr.Op == "dial" && opErr.Err.Error() == "connect: connection refused" { + return err + } + return retry.NewTerminalError(err) + } + if resp.StatusCode == http.StatusInternalServerError { + return errors.New(http.StatusText(resp.StatusCode)) + } + return nil + }, retry.WithMaxInterval(100*time.Millisecond)); err != nil { + return nil, err + } + return resp, nil +} + +func authenticateFlow(ctx context.Context, client *http.Client, req *http.Request, email string) (*http.Response, error) { + var res *http.Response + originalHostname := req.URL.Hostname() + res, err := client.Do(req) + if err != nil { + return nil, err + } + + location := res.Request.URL + if location.Hostname() == originalHostname { + // already authenticated + return res, err + } + defer res.Body.Close() + fs := forms.Parse(res.Body) + if len(fs) > 0 { + f := fs[0] + f.Inputs["email"] = email + f.Inputs["token_expiration"] = strconv.Itoa(int((time.Hour * 24).Seconds())) + formReq, err := f.NewRequestWithContext(ctx, location) + if err != nil { + return nil, err + } + return client.Do(formReq) + } + return nil, fmt.Errorf("test bug: expected IDP login form") +} diff --git a/internal/testenv/values/value.go b/internal/testenv/values/value.go new file mode 100644 index 000000000..03d9ac8f2 --- /dev/null +++ b/internal/testenv/values/value.go @@ -0,0 +1,120 @@ +package values + +import ( + "math/rand/v2" + "sync" +) + +type value[T any] struct { + f func() T + ready bool + cond *sync.Cond +} + +// A Value is a container for a single value of type T, whose initialization is +// performed the first time Value() is called. Subsequent calls will return the +// same value. The Value() function may block until the value is ready on the +// first call. Values are safe to use concurrently. 
+type Value[T any] interface { + Value() T +} + +// MutableValue is the read-write counterpart to [Value], created by calling +// [Deferred] for some type T. Calling Resolve() or ResolveFunc() will set +// the value and unblock any waiting calls to Value(). +type MutableValue[T any] interface { + Value[T] + Resolve(value T) + ResolveFunc(fOnce func() T) +} + +// Deferred creates a new read-write [MutableValue] for some type T, +// representing a value whose initialization may be deferred to a later time. +// Once the value is available, call [MutableValue.Resolve] or +// [MutableValue.ResolveFunc] to unblock any waiting calls to Value(). +func Deferred[T any]() MutableValue[T] { + return &value[T]{ + cond: sync.NewCond(&sync.Mutex{}), + } +} + +// Const creates a read-only [Value] which will become available immediately +// upon calling Value() for the first time; it will never block. +func Const[T any](t T) Value[T] { + return &value[T]{ + f: func() T { return t }, + ready: true, + cond: sync.NewCond(&sync.Mutex{}), + } +} + +func (p *value[T]) Value() T { + p.cond.L.Lock() + defer p.cond.L.Unlock() + for !p.ready { + p.cond.Wait() + } + return p.f() +} + +func (p *value[T]) ResolveFunc(fOnce func() T) { + p.cond.L.Lock() + p.f = sync.OnceValue(fOnce) + p.ready = true + p.cond.L.Unlock() + p.cond.Broadcast() +} + +func (p *value[T]) Resolve(value T) { + p.ResolveFunc(func() T { return value }) +} + +// Bind creates a new [MutableValue] whose ultimate value depends on the result +// of another [Value] that may not yet be available. When Value() is called on +// the result, it will cascade and trigger the full chain of initialization +// functions necessary to produce the final value. +// +// Care should be taken when using this function, as improper use can lead to +// deadlocks and cause values to never become available. 
+func Bind[T any, U any](dt Value[T], callback func(value T) U) Value[U] { + du := Deferred[U]() + du.ResolveFunc(func() U { + return callback(dt.Value()) + }) + return du +} + +// Bind2 is like [Bind], but can accept two input values. The result will only +// become available once all input values become available. +// +// This function blocks to wait for each input value in sequence, but in a +// random order. Do not rely on the order of evaluation of the input values. +func Bind2[T any, U any, V any](dt Value[T], du Value[U], callback func(value1 T, value2 U) V) Value[V] { + dv := Deferred[V]() + dv.ResolveFunc(func() V { + if rand.IntN(2) == 0 { //nolint:gosec + return callback(dt.Value(), du.Value()) + } + u := du.Value() + t := dt.Value() + return callback(t, u) + }) + return dv +} + +// List is a container for a slice of [Value] of type T, and is also a [Value] +// itself, for convenience. The Value() function will return a []T containing +// all resolved values for each element in the slice. +// +// A List's Value() function blocks to wait for each element in the slice in +// sequence, but in a random order. Do not rely on the order of evaluation of +// the slice elements. 
+type List[T any] []Value[T] + +func (s List[T]) Value() []T { + values := make([]T, len(s)) + for _, i := range rand.Perm(len(values)) { + values[i] = s[i].Value() + } + return values +} diff --git a/pkg/cmd/pomerium/pomerium.go b/pkg/cmd/pomerium/pomerium.go index 489c4ed55..d70b3959d 100644 --- a/pkg/cmd/pomerium/pomerium.go +++ b/pkg/cmd/pomerium/pomerium.go @@ -16,6 +16,7 @@ import ( "github.com/pomerium/pomerium/authenticate" "github.com/pomerium/pomerium/authorize" "github.com/pomerium/pomerium/config" + "github.com/pomerium/pomerium/config/envoyconfig/filemgr" databroker_service "github.com/pomerium/pomerium/databroker" "github.com/pomerium/pomerium/internal/autocert" "github.com/pomerium/pomerium/internal/controlplane" @@ -30,8 +31,29 @@ import ( "github.com/pomerium/pomerium/proxy" ) +type RunOptions struct { + fileMgr *filemgr.Manager +} + +type RunOption func(*RunOptions) + +func (o *RunOptions) apply(opts ...RunOption) { + for _, op := range opts { + op(o) + } +} + +func WithOverrideFileManager(fileMgr *filemgr.Manager) RunOption { + return func(o *RunOptions) { + o.fileMgr = fileMgr + } +} + // Run runs the main pomerium application. -func Run(ctx context.Context, src config.Source) error { +func Run(ctx context.Context, src config.Source, opts ...RunOption) error { + options := RunOptions{} + options.apply(opts...) + _, _ = maxprocs.Set(maxprocs.Logger(func(s string, i ...any) { log.Ctx(ctx).Debug().Msgf(s, i...) })) evt := log.Ctx(ctx).Info(). 
@@ -68,10 +90,15 @@ func Run(ctx context.Context, src config.Source) error { eventsMgr := events.New() + fileMgr := options.fileMgr + if fileMgr == nil { + fileMgr = filemgr.NewManager() + } + cfg := src.GetConfig() // setup the control plane - controlPlane, err := controlplane.NewServer(ctx, cfg, metricsMgr, eventsMgr) + controlPlane, err := controlplane.NewServer(ctx, cfg, metricsMgr, eventsMgr, fileMgr) if err != nil { return fmt.Errorf("error creating control plane: %w", err) } diff --git a/pkg/cryptutil/tls.go b/pkg/cryptutil/tls.go index 7900e9b69..0d8075686 100644 --- a/pkg/cryptutil/tls.go +++ b/pkg/cryptutil/tls.go @@ -1,7 +1,6 @@ package cryptutil import ( - "context" "crypto/tls" "crypto/x509" "encoding/base64" @@ -15,10 +14,9 @@ import ( // GetCertPool gets a cert pool for the given CA or CAFile. func GetCertPool(ca, caFile string) (*x509.CertPool, error) { - ctx := context.TODO() rootCAs, err := x509.SystemCertPool() if err != nil { - log.Ctx(ctx).Error().Err(err).Msg("pkg/cryptutil: failed getting system cert pool making new one") + log.Error().Err(err).Msg("pkg/cryptutil: failed getting system cert pool making new one") rootCAs = x509.NewCertPool() } if ca == "" && caFile == "" { @@ -40,7 +38,9 @@ func GetCertPool(ca, caFile string) (*x509.CertPool, error) { if ok := rootCAs.AppendCertsFromPEM(data); !ok { return nil, fmt.Errorf("failed to append any PEM-encoded certificates") } - log.Ctx(ctx).Debug().Msg("pkg/cryptutil: added custom certificate authority") + if !log.DebugDisableGlobalMessages.Load() { + log.Debug().Msg("pkg/cryptutil: added custom certificate authority") + } return rootCAs, nil } diff --git a/pkg/envoy/envoy.go b/pkg/envoy/envoy.go index 0980b0c82..66cf71ae5 100644 --- a/pkg/envoy/envoy.go +++ b/pkg/envoy/envoy.go @@ -186,7 +186,25 @@ func (srv *Server) run(ctx context.Context, cfg *config.Config) error { // monitor the process so we exit if it prematurely exits var monitorProcessCtx context.Context monitorProcessCtx, 
srv.monitorProcessCancel = context.WithCancel(context.WithoutCancel(ctx)) - go srv.monitorProcess(monitorProcessCtx, int32(cmd.Process.Pid)) + go func() { + pid := cmd.Process.Pid + err := srv.monitorProcess(monitorProcessCtx, int32(pid)) + if err != nil && ctx.Err() == nil { + // If the envoy subprocess exits and ctx is not done, issue a fatal error. + // If ctx is done, the server is already exiting, and envoy is expected + // to be stopped along with it. + log.Ctx(ctx). + Fatal(). + Int("pid", pid). + Err(err). + Send() + } + log.Ctx(ctx). + Debug(). + Int("pid", pid). + Err(ctx.Err()). + Msg("envoy process monitor stopped") + }() if srv.resourceMonitor != nil { log.Ctx(ctx).Debug().Str("service", "envoy").Msg("starting resource monitor") @@ -300,7 +318,7 @@ func (srv *Server) handleLogs(ctx context.Context, rc io.ReadCloser) { } } -func (srv *Server) monitorProcess(ctx context.Context, pid int32) { +func (srv *Server) monitorProcess(ctx context.Context, pid int32) error { log.Ctx(ctx).Debug(). Int32("pid", pid). Msg("envoy: start monitoring subprocess") @@ -311,19 +329,15 @@ func (srv *Server) monitorProcess(ctx context.Context, pid int32) { for { exists, err := process.PidExistsWithContext(ctx, pid) if err != nil { - log.Fatal().Err(err). - Int32("pid", pid). - Msg("envoy: error retrieving subprocess information") + return fmt.Errorf("envoy: error retrieving subprocess information: %w", err) } else if !exists { - log.Fatal().Err(err). - Int32("pid", pid). 
- Msg("envoy: subprocess exited") + return errors.New("envoy: subprocess exited") } // wait for the next tick select { case <-ctx.Done(): - return + return nil case <-ticker.C: } } diff --git a/pkg/envoy/envoy_linux.go b/pkg/envoy/envoy_linux.go index 4f2fe3644..88c3a13e5 100644 --- a/pkg/envoy/envoy_linux.go +++ b/pkg/envoy/envoy_linux.go @@ -6,6 +6,7 @@ package envoy import ( "context" "os" + "path/filepath" "strconv" "sync" "syscall" @@ -17,7 +18,7 @@ import ( "github.com/pomerium/pomerium/internal/telemetry/metrics" ) -const baseIDPath = "/tmp/pomerium-envoy-base-id" +const baseIDName = "pomerium-envoy-base-id" var restartEpoch struct { sync.Mutex @@ -89,7 +90,7 @@ func (srv *Server) prepareRunEnvoyCommand(ctx context.Context, sharedArgs []stri } else { args = append(args, "--use-dynamic-base-id", - "--base-id-path", baseIDPath, + "--base-id-path", filepath.Join(os.TempDir(), baseIDName), ) restartEpoch.value = 1 } @@ -99,7 +100,7 @@ func (srv *Server) prepareRunEnvoyCommand(ctx context.Context, sharedArgs []stri } func readBaseID() (int, bool) { - bs, err := os.ReadFile(baseIDPath) + bs, err := os.ReadFile(filepath.Join(os.TempDir(), baseIDName)) if err != nil { return 0, false } diff --git a/pkg/grpc/databroker/syncer.go b/pkg/grpc/databroker/syncer.go index 769fd3425..4391024fd 100644 --- a/pkg/grpc/databroker/syncer.go +++ b/pkg/grpc/databroker/syncer.go @@ -3,6 +3,7 @@ package databroker import ( "context" "fmt" + "sync/atomic" "time" backoff "github.com/cenkalti/backoff/v4" @@ -71,12 +72,24 @@ type Syncer struct { id string } +var DebugUseFasterBackoff atomic.Bool + // NewSyncer creates a new Syncer. 
func NewSyncer(ctx context.Context, id string, handler SyncerHandler, options ...SyncerOption) *Syncer { closeCtx, closeCtxCancel := context.WithCancel(context.WithoutCancel(ctx)) - bo := backoff.NewExponentialBackOff() - bo.MaxElapsedTime = 0 + var bo *backoff.ExponentialBackOff + if DebugUseFasterBackoff.Load() { + bo = backoff.NewExponentialBackOff( + backoff.WithInitialInterval(10*time.Millisecond), + backoff.WithMultiplier(1.0), + backoff.WithMaxElapsedTime(100*time.Millisecond), + ) + bo.MaxElapsedTime = 0 + } else { + bo = backoff.NewExponentialBackOff() + bo.MaxElapsedTime = 0 + } s := &Syncer{ cfg: getSyncerConfig(options...), handler: handler,