mirror of
https://github.com/pomerium/pomerium.git
synced 2025-05-29 17:07:24 +02:00
New integration test fixtures (#5233)
* Initial test environment implementation * linter pass * wip: update request latency test * bugfixes * Fix logic race in envoy process monitor when canceling context * skip tests using test environment on non-linux
This commit is contained in:
parent
3d958ff9c5
commit
526e2a58d6
29 changed files with 2972 additions and 101 deletions
|
@ -76,3 +76,6 @@ issues:
|
|||
- text: "G112:"
|
||||
linters:
|
||||
- gosec
|
||||
- text: "G402: TLS MinVersion too low."
|
||||
linters:
|
||||
- gosec
|
||||
|
|
|
@ -26,7 +26,7 @@ import (
|
|||
const maxActiveDownstreamConnections = 50000
|
||||
|
||||
var (
|
||||
envoyAdminAddressPath = filepath.Join(os.TempDir(), "pomerium-envoy-admin.sock")
|
||||
envoyAdminAddressSockName = "pomerium-envoy-admin.sock"
|
||||
envoyAdminAddressMode = 0o600
|
||||
envoyAdminClusterName = "pomerium-envoy-admin"
|
||||
)
|
||||
|
@ -95,7 +95,7 @@ func (b *Builder) BuildBootstrapAdmin(cfg *config.Config) (admin *envoy_config_b
|
|||
admin.Address = &envoy_config_core_v3.Address{
|
||||
Address: &envoy_config_core_v3.Address_Pipe{
|
||||
Pipe: &envoy_config_core_v3.Pipe{
|
||||
Path: envoyAdminAddressPath,
|
||||
Path: filepath.Join(os.TempDir(), envoyAdminAddressSockName),
|
||||
Mode: uint32(envoyAdminAddressMode),
|
||||
},
|
||||
},
|
||||
|
|
|
@ -12,6 +12,7 @@ import (
|
|||
)
|
||||
|
||||
func TestBuilder_BuildBootstrapAdmin(t *testing.T) {
|
||||
t.Setenv("TMPDIR", "/tmp")
|
||||
b := New("local-grpc", "local-http", "local-metrics", filemgr.NewManager(), nil)
|
||||
t.Run("valid", func(t *testing.T) {
|
||||
adminCfg, err := b.BuildBootstrapAdmin(&config.Config{
|
||||
|
@ -25,7 +26,7 @@ func TestBuilder_BuildBootstrapAdmin(t *testing.T) {
|
|||
"address": {
|
||||
"pipe": {
|
||||
"mode": 384,
|
||||
"path": "`+envoyAdminAddressPath+`"
|
||||
"path": "/tmp/`+envoyAdminAddressSockName+`"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -2,6 +2,8 @@ package envoyconfig
|
|||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
envoy_config_cluster_v3 "github.com/envoyproxy/go-control-plane/envoy/config/cluster/v3"
|
||||
envoy_config_core_v3 "github.com/envoyproxy/go-control-plane/envoy/config/core/v3"
|
||||
|
@ -23,7 +25,8 @@ func (b *Builder) buildEnvoyAdminCluster(_ context.Context, _ *config.Config) (*
|
|||
Address: &envoy_config_core_v3.Address{
|
||||
Address: &envoy_config_core_v3.Address_Pipe{
|
||||
Pipe: &envoy_config_core_v3.Pipe{
|
||||
Path: envoyAdminAddressPath,
|
||||
Path: filepath.Join(os.TempDir(), envoyAdminAddressSockName),
|
||||
Mode: uint32(envoyAdminAddressMode),
|
||||
},
|
||||
},
|
||||
},
|
||||
|
|
|
@ -23,11 +23,7 @@ import (
|
|||
func Test_BuildClusters(t *testing.T) {
|
||||
// The admin address path is based on os.TempDir(), which will vary from
|
||||
// system to system, so replace this with a stable location.
|
||||
originalEnvoyAdminAddressPath := envoyAdminAddressPath
|
||||
envoyAdminAddressPath = "/tmp/pomerium-envoy-admin.sock"
|
||||
t.Cleanup(func() {
|
||||
envoyAdminAddressPath = originalEnvoyAdminAddressPath
|
||||
})
|
||||
t.Setenv("TMPDIR", "/tmp")
|
||||
|
||||
opts := config.NewDefaultOptions()
|
||||
ctx := context.Background()
|
||||
|
|
|
@ -1,102 +1,159 @@
|
|||
package envoyconfig_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"io"
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/credentials"
|
||||
"google.golang.org/grpc/credentials/insecure"
|
||||
"google.golang.org/grpc/interop"
|
||||
"google.golang.org/grpc/interop/grpc_testing"
|
||||
"google.golang.org/grpc/metadata"
|
||||
"google.golang.org/grpc/status"
|
||||
|
||||
"github.com/pomerium/pomerium/config"
|
||||
"github.com/pomerium/pomerium/pkg/cmd/pomerium"
|
||||
"github.com/pomerium/pomerium/pkg/netutil"
|
||||
"github.com/pomerium/pomerium/internal/testenv"
|
||||
"github.com/pomerium/pomerium/internal/testenv/scenarios"
|
||||
"github.com/pomerium/pomerium/internal/testenv/snippets"
|
||||
"github.com/pomerium/pomerium/internal/testenv/upstreams"
|
||||
"github.com/pomerium/pomerium/internal/testenv/values"
|
||||
)
|
||||
|
||||
func TestH2C(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.SkipNow()
|
||||
}
|
||||
env := testenv.New(t)
|
||||
|
||||
ctx, ca := context.WithCancel(context.Background())
|
||||
up := upstreams.GRPC(insecure.NewCredentials())
|
||||
grpc_testing.RegisterTestServiceServer(up, interop.NewTestServer())
|
||||
|
||||
opts := config.NewDefaultOptions()
|
||||
listener, err := (&net.ListenConfig{}).Listen(ctx, "tcp", "127.0.0.1:0")
|
||||
require.NoError(t, err)
|
||||
ports, err := netutil.AllocatePorts(7)
|
||||
require.NoError(t, err)
|
||||
urls, err := config.ParseWeightedUrls("http://"+listener.Addr().String(), "h2c://"+listener.Addr().String())
|
||||
require.NoError(t, err)
|
||||
opts.Addr = fmt.Sprintf("127.0.0.1:%s", ports[0])
|
||||
opts.Routes = []config.Policy{
|
||||
{
|
||||
From: fmt.Sprintf("https://grpc-http.localhost.pomerium.io:%s", ports[0]),
|
||||
To: urls[:1],
|
||||
AllowPublicUnauthenticatedAccess: true,
|
||||
},
|
||||
{
|
||||
From: fmt.Sprintf("https://grpc-h2c.localhost.pomerium.io:%s", ports[0]),
|
||||
To: urls[1:],
|
||||
AllowPublicUnauthenticatedAccess: true,
|
||||
},
|
||||
}
|
||||
opts.CertFile = "../../integration/tpl/files/trusted.pem"
|
||||
opts.KeyFile = "../../integration/tpl/files/trusted-key.pem"
|
||||
cfg := &config.Config{Options: opts}
|
||||
cfg.AllocatePorts(*(*[6]string)(ports[1:]))
|
||||
http := up.Route().
|
||||
From(env.SubdomainURL("grpc-http")).
|
||||
To(values.Bind(up.Port(), func(port int) string {
|
||||
// override the target protocol to use http://
|
||||
return fmt.Sprintf("http://127.0.0.1:%d", port)
|
||||
})).
|
||||
Policy(func(p *config.Policy) { p.AllowPublicUnauthenticatedAccess = true })
|
||||
|
||||
server := grpc.NewServer(grpc.Creds(insecure.NewCredentials()))
|
||||
grpc_testing.RegisterTestServiceServer(server, interop.NewTestServer())
|
||||
go server.Serve(listener)
|
||||
h2c := up.Route().
|
||||
From(env.SubdomainURL("grpc-h2c")).
|
||||
Policy(func(p *config.Policy) { p.AllowPublicUnauthenticatedAccess = true })
|
||||
|
||||
errC := make(chan error, 1)
|
||||
go func() {
|
||||
errC <- pomerium.Run(ctx, config.NewStaticSource(cfg))
|
||||
}()
|
||||
|
||||
t.Cleanup(func() {
|
||||
ca()
|
||||
assert.ErrorIs(t, context.Canceled, <-errC)
|
||||
})
|
||||
|
||||
tlsConfig, err := credentials.NewClientTLSFromFile("../../integration/tpl/files/ca.pem", "")
|
||||
require.NoError(t, err)
|
||||
env.AddUpstream(up)
|
||||
env.Start()
|
||||
snippets.WaitStartupComplete(env)
|
||||
|
||||
t.Run("h2c", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
recorder := env.NewLogRecorder()
|
||||
|
||||
cc, err := grpc.Dial(fmt.Sprintf("grpc-h2c.localhost.pomerium.io:%s", ports[0]), grpc.WithTransportCredentials(tlsConfig))
|
||||
require.NoError(t, err)
|
||||
cc := up.Dial(h2c)
|
||||
client := grpc_testing.NewTestServiceClient(cc)
|
||||
var md metadata.MD
|
||||
_, err = client.EmptyCall(ctx, &grpc_testing.Empty{}, grpc.WaitForReady(true), grpc.Header(&md))
|
||||
_, err := client.EmptyCall(env.Context(), &grpc_testing.Empty{})
|
||||
require.NoError(t, err)
|
||||
cc.Close()
|
||||
assert.NoError(t, err)
|
||||
assert.Contains(t, md, "x-envoy-upstream-service-time")
|
||||
|
||||
recorder.Match([]map[string]any{
|
||||
{
|
||||
"service": "envoy",
|
||||
"path": "/grpc.testing.TestService/EmptyCall",
|
||||
"message": "http-request",
|
||||
"response-code-details": "via_upstream",
|
||||
},
|
||||
})
|
||||
})
|
||||
t.Run("http", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
recorder := env.NewLogRecorder()
|
||||
|
||||
cc, err := grpc.Dial(fmt.Sprintf("grpc-http.localhost.pomerium.io:%s", ports[0]), grpc.WithTransportCredentials(tlsConfig))
|
||||
require.NoError(t, err)
|
||||
cc := up.Dial(http)
|
||||
client := grpc_testing.NewTestServiceClient(cc)
|
||||
var md metadata.MD
|
||||
_, err = client.EmptyCall(ctx, &grpc_testing.Empty{}, grpc.WaitForReady(true), grpc.Trailer(&md))
|
||||
_, err := client.UnaryCall(env.Context(), &grpc_testing.SimpleRequest{})
|
||||
require.Error(t, err)
|
||||
cc.Close()
|
||||
stat := status.Convert(err)
|
||||
assert.NotNil(t, stat)
|
||||
assert.Equal(t, stat.Code(), codes.Unavailable)
|
||||
assert.NotContains(t, md, "x-envoy-upstream-service-time")
|
||||
assert.Contains(t, stat.Message(), "<!DOCTYPE html>")
|
||||
assert.Contains(t, stat.Message(), "upstream_reset_before_response_started{protocol_error}")
|
||||
|
||||
recorder.Match([]map[string]any{
|
||||
{
|
||||
"service": "envoy",
|
||||
"path": "/grpc.testing.TestService/UnaryCall",
|
||||
"message": "http-request",
|
||||
"response-code-details": "upstream_reset_before_response_started{protocol_error}",
|
||||
},
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func TestHTTP(t *testing.T) {
|
||||
env := testenv.New(t)
|
||||
|
||||
up := upstreams.HTTP(nil)
|
||||
up.Handle("/foo", func(w http.ResponseWriter, _ *http.Request) {
|
||||
fmt.Fprintln(w, "hello world")
|
||||
})
|
||||
|
||||
route := up.Route().
|
||||
From(env.SubdomainURL("http")).
|
||||
Policy(func(p *config.Policy) { p.AllowPublicUnauthenticatedAccess = true })
|
||||
|
||||
env.AddUpstream(up)
|
||||
env.Start()
|
||||
|
||||
recorder := env.NewLogRecorder()
|
||||
|
||||
resp, err := up.Get(route, upstreams.Path("/foo"))
|
||||
require.NoError(t, err)
|
||||
|
||||
defer resp.Body.Close()
|
||||
data, err := io.ReadAll(resp.Body)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "hello world\n", string(data))
|
||||
|
||||
recorder.Match([]map[string]any{
|
||||
{
|
||||
"service": "envoy",
|
||||
"path": "/foo",
|
||||
"method": "GET",
|
||||
"message": "http-request",
|
||||
"response-code-details": "via_upstream",
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
func TestClientCert(t *testing.T) {
|
||||
env := testenv.New(t)
|
||||
env.Add(scenarios.DownstreamMTLS(config.MTLSEnforcementRejectConnection))
|
||||
|
||||
up := upstreams.HTTP(nil)
|
||||
up.Handle("/foo", func(w http.ResponseWriter, _ *http.Request) {
|
||||
fmt.Fprintln(w, "hello world")
|
||||
})
|
||||
|
||||
clientCert := env.NewClientCert()
|
||||
|
||||
route := up.Route().
|
||||
From(env.SubdomainURL("http")).
|
||||
PPL(fmt.Sprintf(`{"allow":{"and":["client_certificate":{"fingerprint":%q}]}}`, clientCert.Fingerprint()))
|
||||
|
||||
env.AddUpstream(up)
|
||||
env.Start()
|
||||
|
||||
recorder := env.NewLogRecorder()
|
||||
|
||||
resp, err := up.Get(route, upstreams.Path("/foo"), upstreams.ClientCert(clientCert))
|
||||
require.NoError(t, err)
|
||||
|
||||
defer resp.Body.Close()
|
||||
data, err := io.ReadAll(resp.Body)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "hello world\n", string(data))
|
||||
|
||||
recorder.Match([]map[string]any{
|
||||
{
|
||||
"service": "envoy",
|
||||
"path": "/foo",
|
||||
"method": "GET",
|
||||
"message": "http-request",
|
||||
"response-code-details": "via_upstream",
|
||||
"client-certificate": clientCert,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
|
1
config/envoyconfig/testdata/clusters.json
vendored
1
config/envoyconfig/testdata/clusters.json
vendored
|
@ -280,6 +280,7 @@
|
|||
"endpoint": {
|
||||
"address": {
|
||||
"pipe": {
|
||||
"mode": 384,
|
||||
"path": "/tmp/pomerium-envoy-admin.sock"
|
||||
}
|
||||
}
|
||||
|
|
|
@ -129,6 +129,7 @@ func newManager(
|
|||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
cache.Stop()
|
||||
return
|
||||
case <-ticker.C:
|
||||
err := mgr.renewConfigCerts(ctx)
|
||||
|
|
54
internal/benchmarks/config_bench_test.go
Normal file
54
internal/benchmarks/config_bench_test.go
Normal file
|
@ -0,0 +1,54 @@
|
|||
package benchmarks_test
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/testenv"
|
||||
"github.com/pomerium/pomerium/internal/testenv/snippets"
|
||||
"github.com/pomerium/pomerium/internal/testenv/upstreams"
|
||||
)
|
||||
|
||||
func BenchmarkStartupLatency(b *testing.B) {
|
||||
for _, n := range []int{1, 10, 100, 1000, 10000} {
|
||||
b.Run(fmt.Sprintf("routes=%d", n), func(b *testing.B) {
|
||||
for range b.N {
|
||||
env := testenv.New(b)
|
||||
up := upstreams.HTTP(nil)
|
||||
for i := range n {
|
||||
up.Route().
|
||||
From(env.SubdomainURL(fmt.Sprintf("from-%d", i))).
|
||||
PPL(`{"allow":{"and":[{"accept":"true"}]}}`)
|
||||
}
|
||||
env.AddUpstream(up)
|
||||
|
||||
env.Start()
|
||||
snippets.WaitStartupComplete(env, 60*time.Minute)
|
||||
|
||||
env.Stop()
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkAppendRoutes(b *testing.B) {
|
||||
for _, n := range []int{1, 10, 100, 1000, 10000} {
|
||||
b.Run(fmt.Sprintf("routes=%d", n), func(b *testing.B) {
|
||||
for range b.N {
|
||||
env := testenv.New(b)
|
||||
up := upstreams.HTTP(nil)
|
||||
env.AddUpstream(up)
|
||||
|
||||
env.Start()
|
||||
snippets.WaitStartupComplete(env)
|
||||
for i := range n {
|
||||
env.Add(up.Route().
|
||||
From(env.SubdomainURL(fmt.Sprintf("from-%d", i))).
|
||||
PPL(fmt.Sprintf(`{"allow":{"and":["email":{"is":"user-%d@example.com"}]}}`, i)))
|
||||
}
|
||||
env.Stop()
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
87
internal/benchmarks/latency_bench_test.go
Normal file
87
internal/benchmarks/latency_bench_test.go
Normal file
|
@ -0,0 +1,87 @@
|
|||
package benchmarks_test
|
||||
|
||||
import (
|
||||
"flag"
|
||||
"fmt"
|
||||
"io"
|
||||
"math/rand/v2"
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/testenv"
|
||||
"github.com/pomerium/pomerium/internal/testenv/scenarios"
|
||||
"github.com/pomerium/pomerium/internal/testenv/snippets"
|
||||
"github.com/pomerium/pomerium/internal/testenv/upstreams"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
var (
|
||||
numRoutes int
|
||||
dumpErrLogs bool
|
||||
)
|
||||
|
||||
func init() {
|
||||
flag.IntVar(&numRoutes, "routes", 100, "number of routes")
|
||||
flag.BoolVar(&dumpErrLogs, "dump-err-logs", false, "if the test fails, write all captured logs to a file (testdata/<test-name>)")
|
||||
}
|
||||
|
||||
func TestRequestLatency(t *testing.T) {
|
||||
env := testenv.New(t, testenv.Silent())
|
||||
users := []*scenarios.User{}
|
||||
for i := range numRoutes {
|
||||
users = append(users, &scenarios.User{
|
||||
Email: fmt.Sprintf("user%d@example.com", i),
|
||||
FirstName: fmt.Sprintf("Firstname%d", i),
|
||||
LastName: fmt.Sprintf("Lastname%d", i),
|
||||
})
|
||||
}
|
||||
env.Add(scenarios.NewIDP(users))
|
||||
|
||||
up := upstreams.HTTP(nil)
|
||||
up.Handle("/", func(w http.ResponseWriter, _ *http.Request) {
|
||||
w.Write([]byte("OK"))
|
||||
})
|
||||
routes := make([]testenv.Route, numRoutes)
|
||||
for i := range numRoutes {
|
||||
routes[i] = up.Route().
|
||||
From(env.SubdomainURL(fmt.Sprintf("from-%d", i))).
|
||||
PPL(fmt.Sprintf(`{"allow":{"and":["email":{"is":"user%d@example.com"}]}}`, i))
|
||||
}
|
||||
env.AddUpstream(up)
|
||||
|
||||
env.Start()
|
||||
snippets.WaitStartupComplete(env)
|
||||
|
||||
out := testing.Benchmark(func(b *testing.B) {
|
||||
b.ReportAllocs()
|
||||
b.RunParallel(func(pb *testing.PB) {
|
||||
var rec *testenv.LogRecorder
|
||||
if dumpErrLogs {
|
||||
rec = env.NewLogRecorder(testenv.WithSkipCloseDelay())
|
||||
}
|
||||
for pb.Next() {
|
||||
idx := rand.IntN(numRoutes)
|
||||
resp, err := up.Get(routes[idx], upstreams.AuthenticateAs(fmt.Sprintf("user%d@example.com", idx)))
|
||||
if !assert.NoError(b, err) {
|
||||
filename := "TestRequestLatency_err.log"
|
||||
if dumpErrLogs {
|
||||
rec.DumpToFile(filename)
|
||||
b.Logf("test logs written to %s", filename)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
assert.Equal(b, resp.StatusCode, 200)
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
resp.Body.Close()
|
||||
assert.NoError(b, err)
|
||||
assert.Equal(b, "OK", string(body))
|
||||
}
|
||||
})
|
||||
})
|
||||
|
||||
t.Log(out)
|
||||
t.Logf("req/s: %f", float64(out.N)/out.T.Seconds())
|
||||
|
||||
env.Stop()
|
||||
}
|
|
@ -74,11 +74,12 @@ func NewServer(
|
|||
cfg *config.Config,
|
||||
metricsMgr *config.MetricsManager,
|
||||
eventsMgr *events.Manager,
|
||||
fileMgr *filemgr.Manager,
|
||||
) (*Server, error) {
|
||||
srv := &Server{
|
||||
metricsMgr: metricsMgr,
|
||||
EventsMgr: eventsMgr,
|
||||
filemgr: filemgr.NewManager(),
|
||||
filemgr: fileMgr,
|
||||
reproxy: reproxy.New(),
|
||||
haveSetCapacity: map[string]bool{},
|
||||
updateConfig: make(chan *config.Config, 1),
|
||||
|
|
|
@ -12,6 +12,7 @@ import (
|
|||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/pomerium/pomerium/config"
|
||||
"github.com/pomerium/pomerium/config/envoyconfig/filemgr"
|
||||
"github.com/pomerium/pomerium/internal/events"
|
||||
"github.com/pomerium/pomerium/pkg/netutil"
|
||||
)
|
||||
|
@ -38,7 +39,7 @@ func TestServerHTTP(t *testing.T) {
|
|||
cfg.Options.SharedKey = "JDNjY2ITDlARvNaQXjc2Djk+GA6xeCy4KiozmZfdbTs="
|
||||
|
||||
src := config.NewStaticSource(cfg)
|
||||
srv, err := NewServer(ctx, cfg, config.NewMetricsManager(ctx, src), events.New())
|
||||
srv, err := NewServer(ctx, cfg, config.NewMetricsManager(ctx, src), events.New(), filemgr.NewManager(filemgr.WithCacheDir(t.TempDir())))
|
||||
require.NoError(t, err)
|
||||
go srv.Run(ctx)
|
||||
|
||||
|
|
|
@ -7,4 +7,6 @@ var (
|
|||
DebugDisableZapLogger atomic.Bool
|
||||
// Debug option to suppress global warnings
|
||||
DebugDisableGlobalWarnings atomic.Bool
|
||||
// Debug option to suppress global (non-warning) messages
|
||||
DebugDisableGlobalMessages atomic.Bool
|
||||
)
|
||||
|
|
838
internal/testenv/environment.go
Normal file
838
internal/testenv/environment.go
Normal file
|
@ -0,0 +1,838 @@
|
|||
package testenv
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"crypto/sha256"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"crypto/x509/pkix"
|
||||
"encoding/asn1"
|
||||
"encoding/base64"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"math/big"
|
||||
"math/bits"
|
||||
"net"
|
||||
"net/url"
|
||||
"os"
|
||||
"os/signal"
|
||||
"path"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"sync"
|
||||
"syscall"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/pomerium/pomerium/config"
|
||||
"github.com/pomerium/pomerium/config/envoyconfig/filemgr"
|
||||
"github.com/pomerium/pomerium/internal/log"
|
||||
"github.com/pomerium/pomerium/internal/testenv/values"
|
||||
"github.com/pomerium/pomerium/pkg/cmd/pomerium"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
||||
"github.com/pomerium/pomerium/pkg/health"
|
||||
"github.com/pomerium/pomerium/pkg/netutil"
|
||||
"github.com/pomerium/pomerium/pkg/slices"
|
||||
"github.com/rs/zerolog"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/sync/errgroup"
|
||||
"google.golang.org/grpc/grpclog"
|
||||
)
|
||||
|
||||
// Environment is a lightweight integration test fixture that runs Pomerium
|
||||
// in-process.
|
||||
type Environment interface {
|
||||
// Context returns the environment's root context. This context holds a
|
||||
// top-level logger scoped to this environment. It will be canceled when
|
||||
// Stop() is called, or during test cleanup.
|
||||
Context() context.Context
|
||||
|
||||
Assert() *assert.Assertions
|
||||
Require() *require.Assertions
|
||||
|
||||
// TempDir returns a unique temp directory for this context. Calling this
|
||||
// function multiple times returns the same path.
|
||||
TempDir() string
|
||||
// CACert returns the test environment's root CA certificate and private key.
|
||||
CACert() *tls.Certificate
|
||||
// ServerCAs returns a new [*x509.CertPool] containing the root CA certificate
|
||||
// used to sign the server cert and other test certificates.
|
||||
ServerCAs() *x509.CertPool
|
||||
// ServerCert returns the Pomerium server's certificate and private key.
|
||||
ServerCert() *tls.Certificate
|
||||
// NewClientCert generates a new client certificate signed by the root CA
|
||||
// certificate. One or more optional templates can be given, which can be
|
||||
// used to set or override certain parameters when creating a certificate,
|
||||
// including subject, SANs, or extensions. If more than one template is
|
||||
// provided, they will be applied in order from left to right.
|
||||
//
|
||||
// By default (unless overridden in a template), the certificate will have
|
||||
// its Common Name set to the file:line string of the call site. Calls to
|
||||
// NewClientCert() on different lines will have different subjects. If
|
||||
// multiple certs with the same subject are needed, wrap the call to this
|
||||
// function in another helper function, or separate calls with commas on the
|
||||
// same line.
|
||||
NewClientCert(templateOverrides ...*x509.Certificate) *Certificate
|
||||
|
||||
NewServerCert(templateOverrides ...*x509.Certificate) *Certificate
|
||||
|
||||
AuthenticateURL() values.Value[string]
|
||||
DatabrokerURL() values.Value[string]
|
||||
Ports() Ports
|
||||
SharedSecret() []byte
|
||||
CookieSecret() []byte
|
||||
|
||||
// Add adds the given [Modifier] to the environment. All modifiers will be
|
||||
// invoked upon calling Start() to apply individual modifications to the
|
||||
// configuration before starting the Pomerium server.
|
||||
Add(m Modifier)
|
||||
// AddTask adds the given [Task] to the environment. All tasks will be
|
||||
// started in separate goroutines upon calling Start(). If any tasks exit
|
||||
// with an error, the environment will be stopped and the test will fail.
|
||||
AddTask(r Task)
|
||||
// AddUpstream adds the given [Upstream] to the environment. This function is
|
||||
// equivalent to calling both Add() and AddTask() with the upstream, but
|
||||
// improves readability.
|
||||
AddUpstream(u Upstream)
|
||||
|
||||
// Start starts the test environment, and adds a call to Stop() as a cleanup
|
||||
// hook to the environment's [testing.T]. All previously added [Modifier]
|
||||
// instances are invoked in order to build the configuration, and all
|
||||
// previously added [Task] instances are started in the background.
|
||||
//
|
||||
// Calling Start() more than once, Calling Start() after Stop(), or calling
|
||||
// any of the Add* functions after Start() will panic.
|
||||
Start()
|
||||
// Stop stops the test environment. Calling this function more than once has
|
||||
// no effect. It is usually not necessary to call Stop() directly unless you
|
||||
// need to stop the test environment before the test is completed.
|
||||
Stop()
|
||||
|
||||
// SubdomainURL returns a string [values.Value] which will contain a complete
|
||||
// URL for the given subdomain of the server's domain (given by its serving
|
||||
// certificate), including the 'https://' scheme and random http server port.
|
||||
// This value will only be resolved some time after Start() is called, and
|
||||
// can be used as the 'from' value for routes.
|
||||
SubdomainURL(subdomain string) values.Value[string]
|
||||
|
||||
// NewLogRecorder returns a new [*LogRecorder] and starts capturing logs for
|
||||
// the Pomerium server and Envoy.
|
||||
NewLogRecorder(opts ...LogRecorderOption) *LogRecorder
|
||||
|
||||
// OnStateChanged registers a callback to be invoked when the environment's
|
||||
// state changes to the given state. The callback is invoked in a separate
|
||||
// goroutine.
|
||||
OnStateChanged(state EnvironmentState, callback func())
|
||||
}
|
||||
|
||||
type Certificate tls.Certificate
|
||||
|
||||
func (c *Certificate) Fingerprint() string {
|
||||
sum := sha256.Sum256(c.Leaf.Raw)
|
||||
return hex.EncodeToString(sum[:])
|
||||
}
|
||||
|
||||
func (c *Certificate) SPKIHash() string {
|
||||
sum := sha256.Sum256(c.Leaf.RawSubjectPublicKeyInfo)
|
||||
return base64.StdEncoding.EncodeToString(sum[:])
|
||||
}
|
||||
|
||||
type EnvironmentState uint32
|
||||
|
||||
const NotRunning EnvironmentState = 0
|
||||
|
||||
const (
|
||||
Starting EnvironmentState = 1 << iota
|
||||
Running
|
||||
Stopping
|
||||
Stopped
|
||||
)
|
||||
|
||||
func (e EnvironmentState) String() string {
|
||||
switch e {
|
||||
case NotRunning:
|
||||
return "NotRunning"
|
||||
case Starting:
|
||||
return "Starting"
|
||||
case Running:
|
||||
return "Running"
|
||||
case Stopping:
|
||||
return "Stopping"
|
||||
case Stopped:
|
||||
return "Stopped"
|
||||
default:
|
||||
return fmt.Sprintf("EnvironmentState(%d)", e)
|
||||
}
|
||||
}
|
||||
|
||||
type environment struct {
|
||||
EnvironmentOptions
|
||||
t testing.TB
|
||||
assert *assert.Assertions
|
||||
require *require.Assertions
|
||||
tempDir string
|
||||
domain string
|
||||
ports Ports
|
||||
sharedSecret [32]byte
|
||||
cookieSecret [32]byte
|
||||
workspaceFolder string
|
||||
silent bool
|
||||
|
||||
ctx context.Context
|
||||
cancel context.CancelCauseFunc
|
||||
cleanupOnce sync.Once
|
||||
logWriter *log.MultiWriter
|
||||
|
||||
mods []WithCaller[Modifier]
|
||||
tasks []WithCaller[Task]
|
||||
taskErrGroup *errgroup.Group
|
||||
|
||||
stateMu sync.Mutex
|
||||
state EnvironmentState
|
||||
stateChangeListeners map[EnvironmentState][]func()
|
||||
|
||||
src *configSource
|
||||
}
|
||||
|
||||
type EnvironmentOptions struct {
|
||||
debug bool
|
||||
pauseOnFailure bool
|
||||
forceSilent bool
|
||||
}
|
||||
|
||||
type EnvironmentOption func(*EnvironmentOptions)
|
||||
|
||||
func (o *EnvironmentOptions) apply(opts ...EnvironmentOption) {
|
||||
for _, op := range opts {
|
||||
op(o)
|
||||
}
|
||||
}
|
||||
|
||||
func Debug(enable ...bool) EnvironmentOption {
|
||||
if len(enable) == 0 {
|
||||
enable = append(enable, true)
|
||||
}
|
||||
return func(o *EnvironmentOptions) {
|
||||
o.debug = enable[0]
|
||||
}
|
||||
}
|
||||
|
||||
func PauseOnFailure(enable ...bool) EnvironmentOption {
|
||||
if len(enable) == 0 {
|
||||
enable = append(enable, true)
|
||||
}
|
||||
return func(o *EnvironmentOptions) {
|
||||
o.pauseOnFailure = enable[0]
|
||||
}
|
||||
}
|
||||
|
||||
func Silent(silent ...bool) EnvironmentOption {
|
||||
if len(silent) == 0 {
|
||||
silent = append(silent, true)
|
||||
}
|
||||
return func(o *EnvironmentOptions) {
|
||||
o.forceSilent = silent[0]
|
||||
}
|
||||
}
|
||||
|
||||
var setGrpcLoggerOnce sync.Once
|
||||
|
||||
func New(t testing.TB, opts ...EnvironmentOption) Environment {
|
||||
if runtime.GOOS != "linux" {
|
||||
t.Skip("test environment only supported on linux")
|
||||
}
|
||||
options := EnvironmentOptions{}
|
||||
options.apply(opts...)
|
||||
if testing.Short() {
|
||||
t.Helper()
|
||||
t.Skip("test environment disabled in short mode")
|
||||
}
|
||||
databroker.DebugUseFasterBackoff.Store(true)
|
||||
workspaceFolder, err := os.Getwd()
|
||||
require.NoError(t, err)
|
||||
for {
|
||||
if _, err := os.Stat(filepath.Join(workspaceFolder, ".git")); err == nil {
|
||||
break
|
||||
}
|
||||
workspaceFolder = filepath.Dir(workspaceFolder)
|
||||
if workspaceFolder == "/" {
|
||||
panic("could not find workspace root")
|
||||
}
|
||||
}
|
||||
workspaceFolder, err = filepath.Abs(workspaceFolder)
|
||||
require.NoError(t, err)
|
||||
|
||||
writer := log.NewMultiWriter()
|
||||
silent := options.forceSilent || isSilent(t)
|
||||
if silent {
|
||||
// this sets the global zap level to fatal, then resets the global zerolog
|
||||
// level to debug
|
||||
log.SetLevel(zerolog.FatalLevel)
|
||||
zerolog.SetGlobalLevel(zerolog.DebugLevel)
|
||||
} else {
|
||||
log.SetLevel(zerolog.InfoLevel)
|
||||
writer.Add(os.Stdout)
|
||||
}
|
||||
log.DebugDisableGlobalWarnings.Store(silent)
|
||||
log.DebugDisableGlobalMessages.Store(silent)
|
||||
log.DebugDisableZapLogger.Store(silent)
|
||||
setGrpcLoggerOnce.Do(func() {
|
||||
grpclog.SetLoggerV2(grpclog.NewLoggerV2WithVerbosity(io.Discard, io.Discard, io.Discard, 0))
|
||||
})
|
||||
logger := zerolog.New(writer).With().Timestamp().Logger().Level(zerolog.DebugLevel)
|
||||
|
||||
ctx, cancel := context.WithCancelCause(logger.WithContext(context.Background()))
|
||||
taskErrGroup, ctx := errgroup.WithContext(ctx)
|
||||
|
||||
e := &environment{
|
||||
EnvironmentOptions: options,
|
||||
t: t,
|
||||
assert: assert.New(t),
|
||||
require: require.New(t),
|
||||
tempDir: t.TempDir(),
|
||||
ports: Ports{
|
||||
ProxyHTTP: values.Deferred[int](),
|
||||
ProxyGRPC: values.Deferred[int](),
|
||||
GRPC: values.Deferred[int](),
|
||||
HTTP: values.Deferred[int](),
|
||||
Outbound: values.Deferred[int](),
|
||||
Metrics: values.Deferred[int](),
|
||||
Debug: values.Deferred[int](),
|
||||
ALPN: values.Deferred[int](),
|
||||
},
|
||||
workspaceFolder: workspaceFolder,
|
||||
silent: silent,
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
logWriter: writer,
|
||||
taskErrGroup: taskErrGroup,
|
||||
}
|
||||
_, err = rand.Read(e.sharedSecret[:])
|
||||
require.NoError(t, err)
|
||||
_, err = rand.Read(e.cookieSecret[:])
|
||||
require.NoError(t, err)
|
||||
|
||||
health.SetProvider(e)
|
||||
|
||||
require.NoError(t, os.Mkdir(filepath.Join(e.tempDir, "certs"), 0o777))
|
||||
copyFile := func(src, dstRel string) {
|
||||
data, err := os.ReadFile(src)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, os.WriteFile(filepath.Join(e.tempDir, dstRel), data, 0o600))
|
||||
}
|
||||
|
||||
certsToCopy := []string{
|
||||
"trusted.pem",
|
||||
"trusted-key.pem",
|
||||
"ca.pem",
|
||||
"ca-key.pem",
|
||||
}
|
||||
for _, crt := range certsToCopy {
|
||||
copyFile(filepath.Join(workspaceFolder, "integration/tpl/files", crt), filepath.Join("certs/", filepath.Base(crt)))
|
||||
}
|
||||
e.domain = wildcardDomain(e.ServerCert().Leaf.DNSNames)
|
||||
|
||||
return e
|
||||
}
|
||||
|
||||
func (e *environment) debugf(format string, args ...any) {
|
||||
if !e.debug {
|
||||
return
|
||||
}
|
||||
|
||||
e.t.Logf("\x1b[34m[debug] "+format+"\x1b[0m", args...)
|
||||
}
|
||||
|
||||
type WithCaller[T any] struct {
|
||||
Caller string
|
||||
Value T
|
||||
}
|
||||
|
||||
type Ports struct {
|
||||
ProxyHTTP values.MutableValue[int]
|
||||
ProxyGRPC values.MutableValue[int]
|
||||
GRPC values.MutableValue[int]
|
||||
HTTP values.MutableValue[int]
|
||||
Outbound values.MutableValue[int]
|
||||
Metrics values.MutableValue[int]
|
||||
Debug values.MutableValue[int]
|
||||
ALPN values.MutableValue[int]
|
||||
}
|
||||
|
||||
func (e *environment) TempDir() string {
|
||||
return e.tempDir
|
||||
}
|
||||
|
||||
func (e *environment) Context() context.Context {
|
||||
return ContextWithEnv(e.ctx, e)
|
||||
}
|
||||
|
||||
func (e *environment) Assert() *assert.Assertions {
|
||||
return e.assert
|
||||
}
|
||||
|
||||
func (e *environment) Require() *require.Assertions {
|
||||
return e.require
|
||||
}
|
||||
|
||||
func (e *environment) SubdomainURL(subdomain string) values.Value[string] {
|
||||
return values.Bind(e.ports.ProxyHTTP, func(port int) string {
|
||||
return fmt.Sprintf("https://%s.%s:%d", subdomain, e.domain, port)
|
||||
})
|
||||
}
|
||||
|
||||
func (e *environment) AuthenticateURL() values.Value[string] {
|
||||
return e.SubdomainURL("authenticate")
|
||||
}
|
||||
|
||||
func (e *environment) DatabrokerURL() values.Value[string] {
|
||||
return values.Bind(e.ports.Outbound, func(port int) string {
|
||||
return fmt.Sprintf("127.0.0.1:%d", port)
|
||||
})
|
||||
}
|
||||
|
||||
func (e *environment) Ports() Ports {
|
||||
return e.ports
|
||||
}
|
||||
|
||||
func (e *environment) CACert() *tls.Certificate {
|
||||
caCert, err := tls.LoadX509KeyPair(
|
||||
filepath.Join(e.tempDir, "certs", "ca.pem"),
|
||||
filepath.Join(e.tempDir, "certs", "ca-key.pem"),
|
||||
)
|
||||
require.NoError(e.t, err)
|
||||
return &caCert
|
||||
}
|
||||
|
||||
func (e *environment) ServerCAs() *x509.CertPool {
|
||||
pool := x509.NewCertPool()
|
||||
caCert, err := os.ReadFile(filepath.Join(e.tempDir, "certs", "ca.pem"))
|
||||
require.NoError(e.t, err)
|
||||
pool.AppendCertsFromPEM(caCert)
|
||||
return pool
|
||||
}
|
||||
|
||||
func (e *environment) ServerCert() *tls.Certificate {
|
||||
serverCert, err := tls.LoadX509KeyPair(
|
||||
filepath.Join(e.tempDir, "certs", "trusted.pem"),
|
||||
filepath.Join(e.tempDir, "certs", "trusted-key.pem"),
|
||||
)
|
||||
require.NoError(e.t, err)
|
||||
return &serverCert
|
||||
}
|
||||
|
||||
// Used as the context's cancel cause during normal cleanup
|
||||
var ErrCauseTestCleanup = errors.New("test cleanup")
|
||||
|
||||
// Used as the context's cancel cause when Stop() is called
|
||||
var ErrCauseManualStop = errors.New("Stop() called")
|
||||
|
||||
func (e *environment) Start() {
|
||||
e.debugf("Start()")
|
||||
e.advanceState(Starting)
|
||||
e.t.Cleanup(e.cleanup)
|
||||
e.t.Setenv("TMPDIR", e.TempDir())
|
||||
e.debugf("temp dir: %s", e.TempDir())
|
||||
|
||||
cfg := &config.Config{
|
||||
Options: config.NewDefaultOptions(),
|
||||
}
|
||||
ports, err := netutil.AllocatePorts(8)
|
||||
require.NoError(e.t, err)
|
||||
atoi := func(str string) int {
|
||||
p, err := strconv.Atoi(str)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return p
|
||||
}
|
||||
e.ports.ProxyHTTP.Resolve(atoi(ports[0]))
|
||||
e.ports.ProxyGRPC.Resolve(atoi(ports[1]))
|
||||
e.ports.GRPC.Resolve(atoi(ports[2]))
|
||||
e.ports.HTTP.Resolve(atoi(ports[3]))
|
||||
e.ports.Outbound.Resolve(atoi(ports[4]))
|
||||
e.ports.Metrics.Resolve(atoi(ports[5]))
|
||||
e.ports.Debug.Resolve(atoi(ports[6]))
|
||||
e.ports.ALPN.Resolve(atoi(ports[7]))
|
||||
cfg.AllocatePorts(*(*[6]string)(ports[2:]))
|
||||
|
||||
cfg.Options.AutocertOptions = config.AutocertOptions{Enable: false}
|
||||
cfg.Options.Services = "all"
|
||||
cfg.Options.LogLevel = config.LogLevelDebug
|
||||
cfg.Options.ProxyLogLevel = config.LogLevelInfo
|
||||
cfg.Options.Addr = fmt.Sprintf("127.0.0.1:%d", e.ports.ProxyHTTP.Value())
|
||||
cfg.Options.GRPCAddr = fmt.Sprintf("127.0.0.1:%d", e.ports.ProxyGRPC.Value())
|
||||
cfg.Options.CAFile = filepath.Join(e.tempDir, "certs", "ca.pem")
|
||||
cfg.Options.CertFile = filepath.Join(e.tempDir, "certs", "trusted.pem")
|
||||
cfg.Options.KeyFile = filepath.Join(e.tempDir, "certs", "trusted-key.pem")
|
||||
cfg.Options.AuthenticateURLString = e.AuthenticateURL().Value()
|
||||
cfg.Options.DataBrokerStorageType = "memory"
|
||||
cfg.Options.SharedKey = base64.StdEncoding.EncodeToString(e.sharedSecret[:])
|
||||
cfg.Options.CookieSecret = base64.StdEncoding.EncodeToString(e.cookieSecret[:])
|
||||
cfg.Options.AccessLogFields = []log.AccessLogField{
|
||||
log.AccessLogFieldAuthority,
|
||||
log.AccessLogFieldDuration,
|
||||
log.AccessLogFieldForwardedFor,
|
||||
log.AccessLogFieldIP,
|
||||
log.AccessLogFieldMethod,
|
||||
log.AccessLogFieldPath,
|
||||
log.AccessLogFieldQuery,
|
||||
log.AccessLogFieldReferer,
|
||||
log.AccessLogFieldRequestID,
|
||||
log.AccessLogFieldResponseCode,
|
||||
log.AccessLogFieldResponseCodeDetails,
|
||||
log.AccessLogFieldSize,
|
||||
log.AccessLogFieldUpstreamCluster,
|
||||
log.AccessLogFieldUserAgent,
|
||||
log.AccessLogFieldClientCertificate,
|
||||
}
|
||||
|
||||
e.src = &configSource{cfg: cfg}
|
||||
e.AddTask(TaskFunc(func(ctx context.Context) error {
|
||||
fileMgr := filemgr.NewManager(filemgr.WithCacheDir(filepath.Join(e.TempDir(), "cache")))
|
||||
for _, mod := range e.mods {
|
||||
mod.Value.Modify(cfg)
|
||||
require.NoError(e.t, cfg.Options.Validate(), "invoking modifier resulted in an invalid configuration:\nadded by: "+mod.Caller)
|
||||
}
|
||||
return pomerium.Run(ctx, e.src, pomerium.WithOverrideFileManager(fileMgr))
|
||||
}))
|
||||
|
||||
for i, task := range e.tasks {
|
||||
log.Ctx(e.ctx).Debug().Str("caller", task.Caller).Msgf("starting task %d", i)
|
||||
e.taskErrGroup.Go(func() error {
|
||||
defer log.Ctx(e.ctx).Debug().Str("caller", task.Caller).Msgf("task %d exited", i)
|
||||
return task.Value.Run(e.ctx)
|
||||
})
|
||||
}
|
||||
|
||||
runtime.Gosched()
|
||||
|
||||
e.advanceState(Running)
|
||||
}
|
||||
|
||||
func (e *environment) NewClientCert(templateOverrides ...*x509.Certificate) *Certificate {
|
||||
caCert := e.CACert()
|
||||
|
||||
priv, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
require.NoError(e.t, err)
|
||||
|
||||
sn, err := rand.Int(rand.Reader, new(big.Int).Lsh(big.NewInt(1), 128))
|
||||
require.NoError(e.t, err)
|
||||
now := time.Now()
|
||||
tmpl := &x509.Certificate{
|
||||
SerialNumber: sn,
|
||||
Subject: pkix.Name{
|
||||
CommonName: getCaller(),
|
||||
},
|
||||
NotBefore: now,
|
||||
NotAfter: now.Add(12 * time.Hour),
|
||||
KeyUsage: x509.KeyUsageDigitalSignature,
|
||||
ExtKeyUsage: []x509.ExtKeyUsage{
|
||||
x509.ExtKeyUsageClientAuth,
|
||||
},
|
||||
BasicConstraintsValid: true,
|
||||
}
|
||||
for _, override := range templateOverrides {
|
||||
tmpl.CRLDistributionPoints = slices.Unique(append(tmpl.CRLDistributionPoints, override.CRLDistributionPoints...))
|
||||
tmpl.DNSNames = slices.Unique(append(tmpl.DNSNames, override.DNSNames...))
|
||||
tmpl.EmailAddresses = slices.Unique(append(tmpl.EmailAddresses, override.EmailAddresses...))
|
||||
tmpl.ExtraExtensions = append(tmpl.ExtraExtensions, override.ExtraExtensions...)
|
||||
tmpl.IPAddresses = slices.UniqueBy(append(tmpl.IPAddresses, override.IPAddresses...), net.IP.String)
|
||||
tmpl.URIs = slices.UniqueBy(append(tmpl.URIs, override.URIs...), (*url.URL).String)
|
||||
tmpl.UnknownExtKeyUsage = slices.UniqueBy(append(tmpl.UnknownExtKeyUsage, override.UnknownExtKeyUsage...), asn1.ObjectIdentifier.String)
|
||||
seq := override.Subject.ToRDNSequence()
|
||||
tmpl.Subject.FillFromRDNSequence(&seq)
|
||||
tmpl.KeyUsage |= override.KeyUsage
|
||||
tmpl.ExtKeyUsage = slices.Unique(append(tmpl.ExtKeyUsage, override.ExtKeyUsage...))
|
||||
}
|
||||
|
||||
clientCertDER, err := x509.CreateCertificate(rand.Reader, tmpl, caCert.Leaf, priv.Public(), caCert.PrivateKey)
|
||||
require.NoError(e.t, err)
|
||||
|
||||
cert, err := x509.ParseCertificate(clientCertDER)
|
||||
require.NoError(e.t, err)
|
||||
e.debugf("provisioned client certificate for %s", cert.Subject.String())
|
||||
|
||||
clientCert := &tls.Certificate{
|
||||
Certificate: [][]byte{cert.Raw, caCert.Leaf.Raw},
|
||||
PrivateKey: priv,
|
||||
Leaf: cert,
|
||||
}
|
||||
|
||||
_, err = clientCert.Leaf.Verify(x509.VerifyOptions{
|
||||
KeyUsages: []x509.ExtKeyUsage{
|
||||
x509.ExtKeyUsageClientAuth,
|
||||
},
|
||||
Roots: e.ServerCAs(),
|
||||
})
|
||||
require.NoError(e.t, err, "bug: generated client cert is not valid")
|
||||
return (*Certificate)(clientCert)
|
||||
}
|
||||
|
||||
func (e *environment) NewServerCert(templateOverrides ...*x509.Certificate) *Certificate {
|
||||
caCert := e.CACert()
|
||||
|
||||
priv, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
require.NoError(e.t, err)
|
||||
|
||||
sn, err := rand.Int(rand.Reader, new(big.Int).Lsh(big.NewInt(1), 128))
|
||||
require.NoError(e.t, err)
|
||||
now := time.Now()
|
||||
tmpl := &x509.Certificate{
|
||||
SerialNumber: sn,
|
||||
NotBefore: now,
|
||||
NotAfter: now.Add(12 * time.Hour),
|
||||
KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment,
|
||||
ExtKeyUsage: []x509.ExtKeyUsage{
|
||||
x509.ExtKeyUsageServerAuth,
|
||||
},
|
||||
BasicConstraintsValid: true,
|
||||
}
|
||||
for _, override := range templateOverrides {
|
||||
tmpl.DNSNames = slices.Unique(append(tmpl.DNSNames, override.DNSNames...))
|
||||
tmpl.IPAddresses = slices.UniqueBy(append(tmpl.IPAddresses, override.IPAddresses...), net.IP.String)
|
||||
}
|
||||
certDER, err := x509.CreateCertificate(rand.Reader, tmpl, caCert.Leaf, priv.Public(), caCert.PrivateKey)
|
||||
require.NoError(e.t, err)
|
||||
|
||||
cert, err := x509.ParseCertificate(certDER)
|
||||
require.NoError(e.t, err)
|
||||
e.debugf("provisioned server certificate for %v", cert.DNSNames)
|
||||
|
||||
tlsCert := &tls.Certificate{
|
||||
Certificate: [][]byte{cert.Raw, caCert.Leaf.Raw},
|
||||
PrivateKey: priv,
|
||||
Leaf: cert,
|
||||
}
|
||||
|
||||
_, err = tlsCert.Leaf.Verify(x509.VerifyOptions{Roots: e.ServerCAs()})
|
||||
require.NoError(e.t, err, "bug: generated client cert is not valid")
|
||||
return (*Certificate)(tlsCert)
|
||||
}
|
||||
|
||||
func (e *environment) SharedSecret() []byte {
|
||||
return bytes.Clone(e.sharedSecret[:])
|
||||
}
|
||||
|
||||
func (e *environment) CookieSecret() []byte {
|
||||
return bytes.Clone(e.cookieSecret[:])
|
||||
}
|
||||
|
||||
func (e *environment) Stop() {
|
||||
if b, ok := e.t.(*testing.B); ok {
|
||||
// when calling Stop() manually, ensure we aren't timing this
|
||||
b.StopTimer()
|
||||
defer b.StartTimer()
|
||||
}
|
||||
e.cleanupOnce.Do(func() {
|
||||
e.debugf("stop: Stop() called manually")
|
||||
e.advanceState(Stopping)
|
||||
e.cancel(ErrCauseManualStop)
|
||||
err := e.taskErrGroup.Wait()
|
||||
e.advanceState(Stopped)
|
||||
e.debugf("stop: done waiting")
|
||||
assert.ErrorIs(e.t, err, ErrCauseManualStop)
|
||||
})
|
||||
}
|
||||
|
||||
func (e *environment) cleanup() {
|
||||
e.cleanupOnce.Do(func() {
|
||||
e.debugf("stop: test cleanup")
|
||||
if e.t.Failed() {
|
||||
if e.pauseOnFailure {
|
||||
e.t.Log("\x1b[31m*** pausing on test failure; continue with ctrl+c ***\x1b[0m")
|
||||
c := make(chan os.Signal, 1)
|
||||
signal.Notify(c, syscall.SIGINT)
|
||||
<-c
|
||||
e.t.Log("\x1b[31mctrl+c received, continuing\x1b[0m")
|
||||
signal.Stop(c)
|
||||
}
|
||||
}
|
||||
e.advanceState(Stopping)
|
||||
e.cancel(ErrCauseTestCleanup)
|
||||
err := e.taskErrGroup.Wait()
|
||||
e.advanceState(Stopped)
|
||||
e.debugf("stop: done waiting")
|
||||
assert.ErrorIs(e.t, err, ErrCauseTestCleanup)
|
||||
})
|
||||
}
|
||||
|
||||
func (e *environment) Add(m Modifier) {
|
||||
e.t.Helper()
|
||||
caller := getCaller()
|
||||
e.debugf("Add: %T from %s", m, caller)
|
||||
switch e.getState() {
|
||||
case NotRunning:
|
||||
for _, mod := range e.mods {
|
||||
if mod.Value == m {
|
||||
e.t.Fatalf("test bug: duplicate modifier added\nfirst added by: %s", mod.Caller)
|
||||
}
|
||||
}
|
||||
e.mods = append(e.mods, WithCaller[Modifier]{
|
||||
Caller: caller,
|
||||
Value: m,
|
||||
})
|
||||
e.debugf("Add: state=NotRunning; calling Attach")
|
||||
m.Attach(e.Context())
|
||||
case Starting:
|
||||
panic("test bug: cannot call Add() before Start() has returned")
|
||||
case Running:
|
||||
e.debugf("Add: state=Running; calling ModifyConfig")
|
||||
e.src.ModifyConfig(e.ctx, m)
|
||||
case Stopped, Stopping:
|
||||
panic("test bug: cannot call Add() after Stop()")
|
||||
default:
|
||||
panic(fmt.Sprintf("unexpected environment state: %s", e.getState()))
|
||||
}
|
||||
}
|
||||
|
||||
func (e *environment) AddTask(t Task) {
|
||||
e.t.Helper()
|
||||
caller := getCaller()
|
||||
e.debugf("AddTask: %T from %s", t, caller)
|
||||
for _, task := range e.tasks {
|
||||
if task.Value == t {
|
||||
e.t.Fatalf("test bug: duplicate task added\nfirst added by: %s", task.Caller)
|
||||
}
|
||||
}
|
||||
e.tasks = append(e.tasks, WithCaller[Task]{
|
||||
Caller: getCaller(),
|
||||
Value: t,
|
||||
})
|
||||
}
|
||||
|
||||
func (e *environment) AddUpstream(up Upstream) {
|
||||
e.t.Helper()
|
||||
caller := getCaller()
|
||||
e.debugf("AddUpstream: %T from %s", up, caller)
|
||||
e.Add(up)
|
||||
e.AddTask(up)
|
||||
}
|
||||
|
||||
// ReportError implements health.Provider.
|
||||
func (e *environment) ReportError(check health.Check, err error, attributes ...health.Attr) {
|
||||
// note: don't use e.t.Fatal here, it will deadlock
|
||||
panic(fmt.Sprintf("%s: %v %v", check, err, attributes))
|
||||
}
|
||||
|
||||
// ReportOK implements health.Provider.
|
||||
func (e *environment) ReportOK(_ health.Check, _ ...health.Attr) {}
|
||||
|
||||
func (e *environment) advanceState(newState EnvironmentState) {
|
||||
e.stateMu.Lock()
|
||||
defer e.stateMu.Unlock()
|
||||
if newState <= e.state {
|
||||
panic(fmt.Sprintf("internal test environment bug: changed state to <= current: newState=%s, current=%s", newState, e.state))
|
||||
}
|
||||
e.debugf("state %s -> %s", e.state.String(), newState.String())
|
||||
e.state = newState
|
||||
e.debugf("notifying %d listeners of state change", len(e.stateChangeListeners[newState]))
|
||||
for _, listener := range e.stateChangeListeners[newState] {
|
||||
go listener()
|
||||
}
|
||||
}
|
||||
|
||||
func (e *environment) getState() EnvironmentState {
|
||||
e.stateMu.Lock()
|
||||
defer e.stateMu.Unlock()
|
||||
return e.state
|
||||
}
|
||||
|
||||
func (e *environment) OnStateChanged(state EnvironmentState, callback func()) {
|
||||
e.stateMu.Lock()
|
||||
defer e.stateMu.Unlock()
|
||||
|
||||
if e.state&state != 0 {
|
||||
go callback()
|
||||
return
|
||||
}
|
||||
|
||||
// add change listeners for all states, if there are multiple bits set
|
||||
for state > 0 {
|
||||
stateBit := EnvironmentState(bits.TrailingZeros32(uint32(state)))
|
||||
state &= (state - 1)
|
||||
e.stateChangeListeners[stateBit] = append(e.stateChangeListeners[stateBit], callback)
|
||||
}
|
||||
}
|
||||
|
||||
func getCaller(skip ...int) string {
|
||||
if len(skip) == 0 {
|
||||
skip = append(skip, 3)
|
||||
}
|
||||
callers := make([]uintptr, 8)
|
||||
runtime.Callers(skip[0], callers)
|
||||
frames := runtime.CallersFrames(callers)
|
||||
var caller string
|
||||
for {
|
||||
next, ok := frames.Next()
|
||||
if !ok {
|
||||
break
|
||||
}
|
||||
if path.Base(next.Function) == "testenv.(*environment).AddUpstream" {
|
||||
continue
|
||||
}
|
||||
caller = fmt.Sprintf("%s:%d", next.File, next.Line)
|
||||
break
|
||||
}
|
||||
return caller
|
||||
}
|
||||
|
||||
func wildcardDomain(names []string) string {
|
||||
for _, name := range names {
|
||||
if name[0] == '*' {
|
||||
return name[2:]
|
||||
}
|
||||
}
|
||||
panic("test bug: no wildcard domain in certificate")
|
||||
}
|
||||
|
||||
func isSilent(t testing.TB) bool {
|
||||
switch t.(type) {
|
||||
case *testing.B:
|
||||
return !slices.Contains(os.Args, "-test.v=true")
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
type configSource struct {
|
||||
mu sync.Mutex
|
||||
cfg *config.Config
|
||||
lis []config.ChangeListener
|
||||
}
|
||||
|
||||
var _ config.Source = (*configSource)(nil)
|
||||
|
||||
// GetConfig implements config.Source.
|
||||
func (src *configSource) GetConfig() *config.Config {
|
||||
src.mu.Lock()
|
||||
defer src.mu.Unlock()
|
||||
|
||||
return src.cfg
|
||||
}
|
||||
|
||||
// OnConfigChange implements config.Source.
|
||||
func (src *configSource) OnConfigChange(_ context.Context, li config.ChangeListener) {
|
||||
src.mu.Lock()
|
||||
defer src.mu.Unlock()
|
||||
|
||||
src.lis = append(src.lis, li)
|
||||
}
|
||||
|
||||
// ModifyConfig updates the current configuration by applying a [Modifier].
|
||||
func (src *configSource) ModifyConfig(ctx context.Context, m Modifier) {
|
||||
src.mu.Lock()
|
||||
defer src.mu.Unlock()
|
||||
|
||||
m.Modify(src.cfg)
|
||||
for _, li := range src.lis {
|
||||
li(ctx, src.cfg)
|
||||
}
|
||||
}
|
391
internal/testenv/logs.go
Normal file
391
internal/testenv/logs.go
Normal file
|
@ -0,0 +1,391 @@
|
|||
package testenv
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"reflect"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// LogRecorder captures logs from the test environment. It can be created at
|
||||
// any time by calling [Environment.NewLogRecorder], and captures logs until
|
||||
// one of Close(), Logs(), or Match() is called, which stops recording. See the
|
||||
// documentation for each method for more details.
|
||||
type LogRecorder struct {
|
||||
LogRecorderOptions
|
||||
t testing.TB
|
||||
canceled <-chan struct{}
|
||||
buf *buffer
|
||||
recordedLogs []map[string]any
|
||||
|
||||
removeGlobalWriterOnce func()
|
||||
collectLogsOnce sync.Once
|
||||
}
|
||||
|
||||
type LogRecorderOptions struct {
|
||||
filters []func(map[string]any) bool
|
||||
skipCloseDelay bool
|
||||
}
|
||||
|
||||
type LogRecorderOption func(*LogRecorderOptions)
|
||||
|
||||
func (o *LogRecorderOptions) apply(opts ...LogRecorderOption) {
|
||||
for _, op := range opts {
|
||||
op(o)
|
||||
}
|
||||
}
|
||||
|
||||
// WithFilters applies one or more filter predicates to the logger. If there
|
||||
// are filters present, they will be called in order when a log is received,
|
||||
// and if any filter returns false for a given log, it will be discarded.
|
||||
func WithFilters(filters ...func(map[string]any) bool) LogRecorderOption {
|
||||
return func(o *LogRecorderOptions) {
|
||||
o.filters = filters
|
||||
}
|
||||
}
|
||||
|
||||
// WithSkipCloseDelay skips the 1.1 second delay before closing the recorder.
|
||||
// This delay is normally required to ensure Envoy access logs are flushed,
|
||||
// but can be skipped if not required.
|
||||
func WithSkipCloseDelay() LogRecorderOption {
|
||||
return func(o *LogRecorderOptions) {
|
||||
o.skipCloseDelay = true
|
||||
}
|
||||
}
|
||||
|
||||
type buffer struct {
|
||||
mu *sync.Mutex
|
||||
underlying bytes.Buffer
|
||||
cond *sync.Cond
|
||||
waiting bool
|
||||
closed bool
|
||||
}
|
||||
|
||||
func newBuffer() *buffer {
|
||||
mu := &sync.Mutex{}
|
||||
return &buffer{
|
||||
mu: mu,
|
||||
cond: sync.NewCond(mu),
|
||||
}
|
||||
}
|
||||
|
||||
// Read implements io.ReadWriteCloser.
|
||||
func (b *buffer) Read(p []byte) (int, error) {
|
||||
b.mu.Lock()
|
||||
defer b.mu.Unlock()
|
||||
for {
|
||||
n, err := b.underlying.Read(p)
|
||||
if errors.Is(err, io.EOF) && !b.closed {
|
||||
b.waiting = true
|
||||
b.cond.Wait()
|
||||
continue
|
||||
}
|
||||
return n, err
|
||||
}
|
||||
}
|
||||
|
||||
// Write implements io.ReadWriteCloser.
|
||||
func (b *buffer) Write(p []byte) (int, error) {
|
||||
b.mu.Lock()
|
||||
defer b.mu.Unlock()
|
||||
if b.closed {
|
||||
return 0, io.ErrClosedPipe
|
||||
}
|
||||
if b.waiting {
|
||||
b.waiting = false
|
||||
defer b.cond.Signal()
|
||||
}
|
||||
return b.underlying.Write(p)
|
||||
}
|
||||
|
||||
// Close implements io.ReadWriteCloser.
|
||||
func (b *buffer) Close() error {
|
||||
b.mu.Lock()
|
||||
defer b.mu.Unlock()
|
||||
b.closed = true
|
||||
b.cond.Signal()
|
||||
return nil
|
||||
}
|
||||
|
||||
var _ io.ReadWriteCloser = (*buffer)(nil)
|
||||
|
||||
func (e *environment) NewLogRecorder(opts ...LogRecorderOption) *LogRecorder {
|
||||
options := LogRecorderOptions{}
|
||||
options.apply(opts...)
|
||||
lr := &LogRecorder{
|
||||
LogRecorderOptions: options,
|
||||
t: e.t,
|
||||
canceled: e.ctx.Done(),
|
||||
buf: newBuffer(),
|
||||
}
|
||||
e.logWriter.Add(lr.buf)
|
||||
lr.removeGlobalWriterOnce = sync.OnceFunc(func() {
|
||||
// wait for envoy access logs, which flush on a 1 second interval
|
||||
if !lr.skipCloseDelay {
|
||||
time.Sleep(1100 * time.Millisecond)
|
||||
}
|
||||
e.logWriter.Remove(lr.buf)
|
||||
})
|
||||
context.AfterFunc(e.ctx, lr.removeGlobalWriterOnce)
|
||||
return lr
|
||||
}
|
||||
|
||||
type (
|
||||
// OpenMap is an alias for map[string]any, and can be used to semantically
|
||||
// represent a map that must contain at least the given entries, but may
|
||||
// also contain additional entries.
|
||||
OpenMap = map[string]any
|
||||
// ClosedMap is a map[string]any that can be used to semantically represent
|
||||
// a map that must contain the given entries exactly, and no others.
|
||||
ClosedMap map[string]any
|
||||
)
|
||||
|
||||
// Close stops the log recorder. After calling this method, Logs() or Match()
|
||||
// can be called to inspect the logs that were captured.
|
||||
func (lr *LogRecorder) Close() {
|
||||
lr.removeGlobalWriterOnce()
|
||||
}
|
||||
|
||||
func (lr *LogRecorder) collectLogs(shouldClose bool) {
|
||||
if shouldClose {
|
||||
lr.removeGlobalWriterOnce()
|
||||
lr.buf.Close()
|
||||
}
|
||||
lr.collectLogsOnce.Do(func() {
|
||||
recordedLogs := []map[string]any{}
|
||||
scan := bufio.NewScanner(lr.buf)
|
||||
for scan.Scan() {
|
||||
log := scan.Bytes()
|
||||
m := map[string]any{}
|
||||
decoder := json.NewDecoder(bytes.NewReader(log))
|
||||
decoder.UseNumber()
|
||||
require.NoError(lr.t, decoder.Decode(&m))
|
||||
for _, filter := range lr.filters {
|
||||
if !filter(m) {
|
||||
continue
|
||||
}
|
||||
}
|
||||
recordedLogs = append(recordedLogs, m)
|
||||
}
|
||||
lr.recordedLogs = recordedLogs
|
||||
})
|
||||
}
|
||||
|
||||
func (lr *LogRecorder) WaitForMatch(expectedLog map[string]any, timeout ...time.Duration) {
|
||||
lr.skipCloseDelay = true
|
||||
found := make(chan struct{})
|
||||
done := make(chan struct{})
|
||||
lr.filters = append(lr.filters, func(entry map[string]any) bool {
|
||||
select {
|
||||
case <-found:
|
||||
default:
|
||||
if matched, _ := match(expectedLog, entry, true); matched {
|
||||
close(found)
|
||||
}
|
||||
}
|
||||
return true
|
||||
})
|
||||
go func() {
|
||||
defer close(done)
|
||||
lr.collectLogs(false)
|
||||
lr.removeGlobalWriterOnce()
|
||||
}()
|
||||
if len(timeout) != 0 {
|
||||
select {
|
||||
case <-found:
|
||||
case <-time.After(timeout[0]):
|
||||
lr.t.Error("timed out waiting for log")
|
||||
case <-lr.canceled:
|
||||
lr.t.Error("canceled")
|
||||
}
|
||||
} else {
|
||||
select {
|
||||
case <-found:
|
||||
case <-lr.canceled:
|
||||
lr.t.Error("canceled")
|
||||
}
|
||||
}
|
||||
lr.buf.Close()
|
||||
<-done
|
||||
}
|
||||
|
||||
// Logs stops the log recorder (if it is not already stopped), then returns
|
||||
// the logs that were captured as structured map[string]any objects.
|
||||
func (lr *LogRecorder) Logs() []map[string]any {
|
||||
lr.collectLogs(true)
|
||||
return lr.recordedLogs
|
||||
}
|
||||
|
||||
func (lr *LogRecorder) DumpToFile(file string) {
|
||||
lr.collectLogs(true)
|
||||
f, err := os.Create(file)
|
||||
require.NoError(lr.t, err)
|
||||
enc := json.NewEncoder(f)
|
||||
for _, log := range lr.recordedLogs {
|
||||
_ = enc.Encode(log)
|
||||
}
|
||||
f.Close()
|
||||
}
|
||||
|
||||
// Match stops the log recorder (if it is not already stopped), then asserts
|
||||
// that the given expected logs were captured. The expected logs may contain
|
||||
// partial or complete log entries. By default, logs must only match the fields
|
||||
// given, and may contain additional fields that will be ignored.
|
||||
//
|
||||
// There are several special-case value types that can be used to customize the
|
||||
// matching behavior, and/or simplify some common use cases, as follows:
|
||||
// - [OpenMap] and [ClosedMap] can be used to control matching logic
|
||||
// - [json.Number] will convert the actual value to a string before comparison
|
||||
// - [*tls.Certificate] or [*x509.Certificate] will expand to the fields that
|
||||
// would be logged for this certificate
|
||||
func (lr *LogRecorder) Match(expectedLogs []map[string]any) {
|
||||
lr.collectLogs(true)
|
||||
for _, expectedLog := range expectedLogs {
|
||||
found := false
|
||||
highScore, highScoreIdxs := 0, []int{}
|
||||
for i, actualLog := range lr.recordedLogs {
|
||||
if ok, score := match(expectedLog, actualLog, true); ok {
|
||||
found = true
|
||||
break
|
||||
} else if score > highScore {
|
||||
highScore = score
|
||||
highScoreIdxs = []int{i}
|
||||
} else if score == highScore {
|
||||
highScoreIdxs = append(highScoreIdxs, i)
|
||||
}
|
||||
}
|
||||
if len(highScoreIdxs) > 0 {
|
||||
expectedLogBytes, _ := json.MarshalIndent(expectedLog, "", " ")
|
||||
if len(highScoreIdxs) == 1 {
|
||||
actualLogBytes, _ := json.MarshalIndent(lr.recordedLogs[highScoreIdxs[0]], "", " ")
|
||||
assert.True(lr.t, found, "expected log not found: \n%s\n\nclosest match:\n%s\n",
|
||||
string(expectedLogBytes), string(actualLogBytes))
|
||||
} else {
|
||||
closestMatches := []string{}
|
||||
for _, i := range highScoreIdxs {
|
||||
bytes, _ := json.MarshalIndent(lr.recordedLogs[i], "", " ")
|
||||
closestMatches = append(closestMatches, string(bytes))
|
||||
}
|
||||
assert.True(lr.t, found, "expected log not found: \n%s\n\nclosest matches:\n%s\n", string(expectedLogBytes), closestMatches)
|
||||
}
|
||||
} else {
|
||||
expectedLogBytes, _ := json.MarshalIndent(expectedLog, "", " ")
|
||||
assert.True(lr.t, found, "expected log not found: %s", string(expectedLogBytes))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func match(expected, actual map[string]any, open bool) (matched bool, score int) {
|
||||
for key, value := range expected {
|
||||
actualValue, ok := actual[key]
|
||||
if !ok {
|
||||
return false, score
|
||||
}
|
||||
score++
|
||||
|
||||
switch actualValue := actualValue.(type) {
|
||||
case map[string]any:
|
||||
switch expectedValue := value.(type) {
|
||||
case ClosedMap:
|
||||
ok, s := match(expectedValue, actualValue, false)
|
||||
score += s * 2
|
||||
if !ok {
|
||||
return false, score
|
||||
}
|
||||
case OpenMap:
|
||||
ok, s := match(expectedValue, actualValue, true)
|
||||
score += s
|
||||
if !ok {
|
||||
return false, score
|
||||
}
|
||||
case *tls.Certificate, *Certificate, *x509.Certificate:
|
||||
var leaf *x509.Certificate
|
||||
switch expectedValue := expectedValue.(type) {
|
||||
case *tls.Certificate:
|
||||
leaf = expectedValue.Leaf
|
||||
case *Certificate:
|
||||
leaf = expectedValue.Leaf
|
||||
case *x509.Certificate:
|
||||
leaf = expectedValue
|
||||
}
|
||||
|
||||
// keep logic consistent with controlplane.populateCertEventDict()
|
||||
expected := map[string]any{}
|
||||
if iss := leaf.Issuer.String(); iss != "" {
|
||||
expected["issuer"] = iss
|
||||
}
|
||||
if sub := leaf.Subject.String(); sub != "" {
|
||||
expected["subject"] = sub
|
||||
}
|
||||
sans := []string{}
|
||||
for _, dnsSAN := range leaf.DNSNames {
|
||||
sans = append(sans, "DNS:"+dnsSAN)
|
||||
}
|
||||
for _, uriSAN := range leaf.URIs {
|
||||
sans = append(sans, "URI:"+uriSAN.String())
|
||||
}
|
||||
if len(sans) > 0 {
|
||||
expected["subjectAltName"] = sans
|
||||
}
|
||||
|
||||
ok, s := match(expected, actualValue, false)
|
||||
score += s
|
||||
if !ok {
|
||||
return false, score
|
||||
}
|
||||
default:
|
||||
return false, score
|
||||
}
|
||||
case string:
|
||||
switch value := value.(type) {
|
||||
case string:
|
||||
if value != actualValue {
|
||||
return false, score
|
||||
}
|
||||
score++
|
||||
default:
|
||||
return false, score
|
||||
}
|
||||
case json.Number:
|
||||
if fmt.Sprint(value) != actualValue.String() {
|
||||
return false, score
|
||||
}
|
||||
score++
|
||||
default:
|
||||
// handle slices
|
||||
if reflect.TypeOf(actualValue).Kind() == reflect.Slice {
|
||||
if reflect.TypeOf(value) != reflect.TypeOf(actualValue) {
|
||||
return false, score
|
||||
}
|
||||
actualSlice := reflect.ValueOf(actualValue)
|
||||
expectedSlice := reflect.ValueOf(value)
|
||||
totalScore := 0
|
||||
for i := range min(actualSlice.Len(), expectedSlice.Len()) {
|
||||
if actualSlice.Index(i).Equal(expectedSlice.Index(i)) {
|
||||
totalScore++
|
||||
}
|
||||
}
|
||||
score += totalScore
|
||||
} else {
|
||||
panic(fmt.Sprintf("test bug: add check for type %T in assertMatchingLogs", actualValue))
|
||||
}
|
||||
}
|
||||
}
|
||||
if !open && len(expected) != len(actual) {
|
||||
return false, score
|
||||
}
|
||||
return true, score
|
||||
}
|
76
internal/testenv/route.go
Normal file
76
internal/testenv/route.go
Normal file
|
@ -0,0 +1,76 @@
|
|||
package testenv
|
||||
|
||||
import (
|
||||
"net/url"
|
||||
"strings"
|
||||
|
||||
"github.com/pomerium/pomerium/config"
|
||||
"github.com/pomerium/pomerium/internal/testenv/values"
|
||||
"github.com/pomerium/pomerium/pkg/policy/parser"
|
||||
)
|
||||
|
||||
// PolicyRoute is a [Route] implementation suitable for most common use cases
|
||||
// that can be used in implementations of [Upstream].
|
||||
type PolicyRoute struct {
|
||||
DefaultAttach
|
||||
from values.Value[string]
|
||||
to values.List[string]
|
||||
edits []func(*config.Policy)
|
||||
}
|
||||
|
||||
// Modify implements Route.
|
||||
func (b *PolicyRoute) Modify(cfg *config.Config) {
|
||||
to := make(config.WeightedURLs, 0, len(b.to))
|
||||
for _, u := range b.to {
|
||||
u, err := url.Parse(u.Value())
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
to = append(to, config.WeightedURL{URL: *u})
|
||||
}
|
||||
p := config.Policy{
|
||||
From: b.from.Value(),
|
||||
To: to,
|
||||
}
|
||||
for _, edit := range b.edits {
|
||||
edit(&p)
|
||||
}
|
||||
cfg.Options.Policies = append(cfg.Options.Policies, p)
|
||||
}
|
||||
|
||||
// From implements Route.
|
||||
func (b *PolicyRoute) From(fromURL values.Value[string]) Route {
|
||||
b.from = fromURL
|
||||
return b
|
||||
}
|
||||
|
||||
// To implements Route.
|
||||
func (b *PolicyRoute) To(toURL values.Value[string]) Route {
|
||||
b.to = append(b.to, toURL)
|
||||
return b
|
||||
}
|
||||
|
||||
// To implements Route.
|
||||
func (b *PolicyRoute) Policy(edit func(*config.Policy)) Route {
|
||||
b.edits = append(b.edits, edit)
|
||||
return b
|
||||
}
|
||||
|
||||
// PPL implements Route.
|
||||
func (b *PolicyRoute) PPL(ppl string) Route {
|
||||
pplPolicy, err := parser.ParseYAML(strings.NewReader(ppl))
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
b.edits = append(b.edits, func(p *config.Policy) {
|
||||
p.Policy = &config.PPLPolicy{
|
||||
Policy: pplPolicy,
|
||||
}
|
||||
})
|
||||
return b
|
||||
}
|
||||
|
||||
// To implements Route.
|
||||
func (b *PolicyRoute) URL() values.Value[string] {
|
||||
return b.from
|
||||
}
|
389
internal/testenv/scenarios/mock_idp.go
Normal file
389
internal/testenv/scenarios/mock_idp.go
Normal file
|
@ -0,0 +1,389 @@
|
|||
package scenarios
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto"
|
||||
"crypto/ecdsa"
|
||||
"crypto/elliptic"
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/go-jose/go-jose/v3"
|
||||
"github.com/go-jose/go-jose/v3/jwt"
|
||||
"github.com/google/uuid"
|
||||
"github.com/pomerium/pomerium/config"
|
||||
"github.com/pomerium/pomerium/internal/encoding"
|
||||
"github.com/pomerium/pomerium/internal/encoding/jws"
|
||||
"github.com/pomerium/pomerium/internal/log"
|
||||
"github.com/pomerium/pomerium/internal/testenv"
|
||||
"github.com/pomerium/pomerium/internal/testenv/upstreams"
|
||||
"github.com/pomerium/pomerium/internal/testenv/values"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/identity"
|
||||
)
|
||||
|
||||
type IDP struct {
|
||||
id values.Value[string]
|
||||
url values.Value[string]
|
||||
publicJWK jose.JSONWebKey
|
||||
signingKey jose.SigningKey
|
||||
|
||||
stateEncoder encoding.MarshalUnmarshaler
|
||||
userLookup map[string]*User
|
||||
}
|
||||
|
||||
// Attach implements testenv.Modifier.
|
||||
func (idp *IDP) Attach(ctx context.Context) {
|
||||
env := testenv.EnvFromContext(ctx)
|
||||
|
||||
router := upstreams.HTTP(nil)
|
||||
|
||||
idp.url = values.Bind2(env.SubdomainURL("mock-idp"), router.Port(), func(urlStr string, port int) string {
|
||||
u, _ := url.Parse(urlStr)
|
||||
host, _, _ := net.SplitHostPort(u.Host)
|
||||
return u.ResolveReference(&url.URL{
|
||||
Scheme: "http",
|
||||
Host: fmt.Sprintf("%s:%d", host, port),
|
||||
}).String()
|
||||
})
|
||||
var err error
|
||||
idp.stateEncoder, err = jws.NewHS256Signer(env.SharedSecret())
|
||||
env.Require().NoError(err)
|
||||
|
||||
idp.id = values.Bind2(idp.url, env.AuthenticateURL(), func(idpUrl, authUrl string) string {
|
||||
provider := identity.Provider{
|
||||
AuthenticateServiceUrl: authUrl,
|
||||
ClientId: "CLIENT_ID",
|
||||
ClientSecret: "CLIENT_SECRET",
|
||||
Type: "oidc",
|
||||
Scopes: []string{"openid", "email", "profile"},
|
||||
Url: idpUrl,
|
||||
}
|
||||
return provider.Hash()
|
||||
})
|
||||
|
||||
router.Handle("/.well-known/jwks.json", func(w http.ResponseWriter, _ *http.Request) {
|
||||
_ = json.NewEncoder(w).Encode(&jose.JSONWebKeySet{
|
||||
Keys: []jose.JSONWebKey{idp.publicJWK},
|
||||
})
|
||||
})
|
||||
router.Handle("/.well-known/openid-configuration", func(w http.ResponseWriter, r *http.Request) {
|
||||
log.Ctx(ctx).Debug().Str("method", r.Method).Str("uri", r.RequestURI).Send()
|
||||
rootURL, _ := url.Parse(idp.url.Value())
|
||||
_ = json.NewEncoder(w).Encode(map[string]interface{}{
|
||||
"issuer": rootURL.String(),
|
||||
"authorization_endpoint": rootURL.ResolveReference(&url.URL{Path: "/oidc/auth"}).String(),
|
||||
"token_endpoint": rootURL.ResolveReference(&url.URL{Path: "/oidc/token"}).String(),
|
||||
"jwks_uri": rootURL.ResolveReference(&url.URL{Path: "/.well-known/jwks.json"}).String(),
|
||||
"userinfo_endpoint": rootURL.ResolveReference(&url.URL{Path: "/oidc/userinfo"}).String(),
|
||||
"id_token_signing_alg_values_supported": []string{
|
||||
"ES256",
|
||||
},
|
||||
})
|
||||
})
|
||||
router.Handle("/oidc/auth", idp.HandleAuth)
|
||||
router.Handle("/oidc/token", idp.HandleToken)
|
||||
router.Handle("/oidc/userinfo", idp.HandleUserInfo)
|
||||
|
||||
env.AddUpstream(router)
|
||||
}
|
||||
|
||||
// Modify implements testenv.Modifier.
|
||||
func (idp *IDP) Modify(cfg *config.Config) {
|
||||
cfg.Options.Provider = "oidc"
|
||||
cfg.Options.ProviderURL = idp.url.Value()
|
||||
cfg.Options.ClientID = "CLIENT_ID"
|
||||
cfg.Options.ClientSecret = "CLIENT_SECRET"
|
||||
cfg.Options.Scopes = []string{"openid", "email", "profile"}
|
||||
}
|
||||
|
||||
var _ testenv.Modifier = (*IDP)(nil)
|
||||
|
||||
func NewIDP(users []*User) *IDP {
|
||||
privateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
publicKey := &privateKey.PublicKey
|
||||
|
||||
signingKey := jose.SigningKey{
|
||||
Algorithm: jose.ES256,
|
||||
Key: privateKey,
|
||||
}
|
||||
publicJWK := jose.JSONWebKey{
|
||||
Key: publicKey,
|
||||
Algorithm: string(jose.ES256),
|
||||
Use: "sig",
|
||||
}
|
||||
thumbprint, err := publicJWK.Thumbprint(crypto.SHA256)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
publicJWK.KeyID = hex.EncodeToString(thumbprint)
|
||||
|
||||
userLookup := map[string]*User{}
|
||||
for _, user := range users {
|
||||
user.ID = uuid.NewString()
|
||||
userLookup[user.ID] = user
|
||||
}
|
||||
return &IDP{
|
||||
publicJWK: publicJWK,
|
||||
signingKey: signingKey,
|
||||
userLookup: userLookup,
|
||||
}
|
||||
}
|
||||
|
||||
// HandleAuth handles the auth flow for OIDC.
|
||||
func (idp *IDP) HandleAuth(w http.ResponseWriter, r *http.Request) {
|
||||
rawRedirectURI := r.FormValue("redirect_uri")
|
||||
if rawRedirectURI == "" {
|
||||
http.Error(w, "missing redirect_uri", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
redirectURI, err := url.Parse(rawRedirectURI)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
rawClientID := r.FormValue("client_id")
|
||||
if rawClientID == "" {
|
||||
http.Error(w, "missing client_id", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
rawEmail := r.FormValue("email")
|
||||
if rawEmail != "" {
|
||||
http.Redirect(w, r, redirectURI.ResolveReference(&url.URL{
|
||||
RawQuery: (url.Values{
|
||||
"state": {r.FormValue("state")},
|
||||
"code": {State{
|
||||
Email: rawEmail,
|
||||
ClientID: rawClientID,
|
||||
}.Encode()},
|
||||
}).Encode(),
|
||||
}).String(), http.StatusFound)
|
||||
return
|
||||
}
|
||||
|
||||
serveHTML(w, `<!doctype html>
|
||||
<html>
|
||||
<head>
|
||||
<title>Login</title>
|
||||
</head>
|
||||
<body>
|
||||
<form method="POST" style="max-width: 200px">
|
||||
<fieldset>
|
||||
<legend>Login</legend>
|
||||
|
||||
<table>
|
||||
<tbody>
|
||||
<tr>
|
||||
<th><label for="email">Email</label></th>
|
||||
<td>
|
||||
<input type="email" name="email" placeholder="email" />
|
||||
</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td colspan="2">
|
||||
<input type="submit" />
|
||||
</td>
|
||||
</tr>
|
||||
</tbody>
|
||||
</table>
|
||||
|
||||
</fieldset>
|
||||
</form>
|
||||
</body>
|
||||
</html>
|
||||
`)
|
||||
}
|
||||
|
||||
// HandleToken handles the token flow for OIDC.
|
||||
func (idp *IDP) HandleToken(w http.ResponseWriter, r *http.Request) {
|
||||
rawCode := r.FormValue("code")
|
||||
|
||||
state, err := DecodeState(rawCode)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
serveJSON(w, map[string]interface{}{
|
||||
"access_token": state.Encode(),
|
||||
"refresh_token": state.Encode(),
|
||||
"token_type": "Bearer",
|
||||
"id_token": state.GetIDToken(r, idp.userLookup).Encode(idp.signingKey),
|
||||
})
|
||||
}
|
||||
|
||||
// HandleUserInfo handles retrieving the user info.
|
||||
func (idp *IDP) HandleUserInfo(w http.ResponseWriter, r *http.Request) {
|
||||
authz := r.Header.Get("Authorization")
|
||||
if authz == "" {
|
||||
http.Error(w, "missing authorization header", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
if strings.HasPrefix(authz, "Bearer ") {
|
||||
authz = authz[len("Bearer "):]
|
||||
} else if strings.HasPrefix(authz, "token ") {
|
||||
authz = authz[len("token "):]
|
||||
} else {
|
||||
http.Error(w, "missing bearer token", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
state, err := DecodeState(authz)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
|
||||
serveJSON(w, state.GetUserInfo(idp.userLookup))
|
||||
}
|
||||
|
||||
type RootURLKey struct{}
|
||||
|
||||
var rootURLKey RootURLKey
|
||||
|
||||
// WithRootURL sets the Root URL in a context.
|
||||
func WithRootURL(ctx context.Context, rootURL *url.URL) context.Context {
|
||||
return context.WithValue(ctx, rootURLKey, rootURL)
|
||||
}
|
||||
|
||||
func getRootURL(r *http.Request) *url.URL {
|
||||
if u, ok := r.Context().Value(rootURLKey).(*url.URL); ok {
|
||||
return u
|
||||
}
|
||||
|
||||
u := *r.URL
|
||||
if r.Host != "" {
|
||||
u.Host = r.Host
|
||||
}
|
||||
if u.Scheme == "" {
|
||||
if r.TLS != nil {
|
||||
u.Scheme = "https"
|
||||
} else {
|
||||
u.Scheme = "http"
|
||||
}
|
||||
}
|
||||
u.Path = ""
|
||||
return &u
|
||||
}
|
||||
|
||||
func serveHTML(w http.ResponseWriter, html string) {
|
||||
w.Header().Set("Content-Type", "text/html")
|
||||
w.Header().Set("Content-Length", strconv.Itoa(len(html)))
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = io.WriteString(w, html)
|
||||
}
|
||||
|
||||
func serveJSON(w http.ResponseWriter, obj interface{}) {
|
||||
bs, err := json.Marshal(obj)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write(bs)
|
||||
}
|
||||
|
||||
type State struct {
|
||||
Email string `json:"email"`
|
||||
ClientID string `json:"client_id"`
|
||||
}
|
||||
|
||||
func DecodeState(rawCode string) (*State, error) {
|
||||
var state State
|
||||
bs, _ := base64.URLEncoding.DecodeString(rawCode)
|
||||
err := json.Unmarshal(bs, &state)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &state, nil
|
||||
}
|
||||
|
||||
func (state State) Encode() string {
|
||||
bs, _ := json.Marshal(state)
|
||||
return base64.URLEncoding.EncodeToString(bs)
|
||||
}
|
||||
|
||||
func (state State) GetIDToken(r *http.Request, users map[string]*User) *IDToken {
|
||||
token := &IDToken{
|
||||
UserInfo: state.GetUserInfo(users),
|
||||
|
||||
Issuer: getRootURL(r).String(),
|
||||
Audience: state.ClientID,
|
||||
Expiry: jwt.NewNumericDate(time.Now().Add(time.Hour * 24 * 365)),
|
||||
IssuedAt: jwt.NewNumericDate(time.Now()),
|
||||
}
|
||||
return token
|
||||
}
|
||||
|
||||
func (state State) GetUserInfo(users map[string]*User) *UserInfo {
|
||||
userInfo := &UserInfo{
|
||||
Subject: state.Email,
|
||||
Email: state.Email,
|
||||
}
|
||||
|
||||
for _, u := range users {
|
||||
if u.Email == state.Email {
|
||||
userInfo.Subject = u.ID
|
||||
userInfo.Name = u.FirstName + " " + u.LastName
|
||||
userInfo.FamilyName = u.LastName
|
||||
userInfo.GivenName = u.FirstName
|
||||
}
|
||||
}
|
||||
|
||||
return userInfo
|
||||
}
|
||||
|
||||
type UserInfo struct {
|
||||
Subject string `json:"sub"`
|
||||
Name string `json:"name"`
|
||||
Email string `json:"email"`
|
||||
FamilyName string `json:"family_name"`
|
||||
GivenName string `json:"given_name"`
|
||||
}
|
||||
|
||||
type IDToken struct {
|
||||
*UserInfo
|
||||
|
||||
Issuer string `json:"iss"`
|
||||
Audience string `json:"aud"`
|
||||
Expiry *jwt.NumericDate `json:"exp"`
|
||||
IssuedAt *jwt.NumericDate `json:"iat"`
|
||||
}
|
||||
|
||||
func (token *IDToken) Encode(signingKey jose.SigningKey) string {
|
||||
sig, err := jose.NewSigner(signingKey, (&jose.SignerOptions{}).WithType("JWT"))
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
str, err := jwt.Signed(sig).Claims(token).CompactSerialize()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return str
|
||||
}
|
||||
|
||||
type User struct {
|
||||
ID string
|
||||
Email string
|
||||
FirstName string
|
||||
LastName string
|
||||
}
|
24
internal/testenv/scenarios/mtls.go
Normal file
24
internal/testenv/scenarios/mtls.go
Normal file
|
@ -0,0 +1,24 @@
|
|||
package scenarios
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"encoding/pem"
|
||||
|
||||
"github.com/pomerium/pomerium/config"
|
||||
"github.com/pomerium/pomerium/internal/testenv"
|
||||
)
|
||||
|
||||
func DownstreamMTLS(mode config.MTLSEnforcement) testenv.Modifier {
|
||||
return testenv.ModifierFunc(func(ctx context.Context, cfg *config.Config) {
|
||||
env := testenv.EnvFromContext(ctx)
|
||||
block := pem.Block{
|
||||
Type: "CERTIFICATE",
|
||||
Bytes: env.CACert().Leaf.Raw,
|
||||
}
|
||||
cfg.Options.DownstreamMTLS = config.DownstreamMTLSSettings{
|
||||
CA: base64.StdEncoding.EncodeToString(pem.EncodeToMemory(&block)),
|
||||
Enforcement: mode,
|
||||
}
|
||||
})
|
||||
}
|
64
internal/testenv/snippets/routes.go
Normal file
64
internal/testenv/snippets/routes.go
Normal file
|
@ -0,0 +1,64 @@
|
|||
package snippets
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"strings"
|
||||
"text/template"
|
||||
|
||||
"github.com/pomerium/pomerium/config"
|
||||
"github.com/pomerium/pomerium/internal/testenv"
|
||||
"github.com/pomerium/pomerium/pkg/policy/parser"
|
||||
)
|
||||
|
||||
var SimplePolicyTemplate = PolicyTemplate{
|
||||
From: "https://from-{{.Idx}}.localhost",
|
||||
To: "https://to-{{.Idx}}.localhost",
|
||||
PPL: `{"allow":{"and":["email":{"is":"user-{{.Idx}}@example.com"}]}}`,
|
||||
}
|
||||
|
||||
type PolicyTemplate struct {
|
||||
From string
|
||||
To string
|
||||
PPL string
|
||||
|
||||
// Add more fields as needed (be sure to update newPolicyFromTemplate)
|
||||
}
|
||||
|
||||
func TemplateRoutes(n int, tmpl PolicyTemplate) testenv.Modifier {
|
||||
return testenv.ModifierFunc(func(_ context.Context, cfg *config.Config) {
|
||||
for i := range n {
|
||||
cfg.Options.Policies = append(cfg.Options.Policies, newPolicyFromTemplate(i, tmpl))
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func newPolicyFromTemplate(i int, pt PolicyTemplate) config.Policy {
|
||||
eval := func(in string) string {
|
||||
t := template.New("policy")
|
||||
tmpl, err := t.Parse(in)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
var out bytes.Buffer
|
||||
if err := tmpl.Execute(&out, struct{ Idx int }{i}); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return out.String()
|
||||
}
|
||||
|
||||
pplPolicy, err := parser.ParseYAML(strings.NewReader(eval(pt.PPL)))
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
to, err := config.ParseWeightedUrls(eval(pt.To))
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return config.Policy{
|
||||
From: eval(pt.From),
|
||||
To: to,
|
||||
Policy: &config.PPLPolicy{Policy: pplPolicy},
|
||||
}
|
||||
}
|
35
internal/testenv/snippets/wait.go
Normal file
35
internal/testenv/snippets/wait.go
Normal file
|
@ -0,0 +1,35 @@
|
|||
package snippets
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/testenv"
|
||||
"github.com/pomerium/pomerium/pkg/grpcutil"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/connectivity"
|
||||
"google.golang.org/grpc/credentials/insecure"
|
||||
)
|
||||
|
||||
func WaitStartupComplete(env testenv.Environment, timeout ...time.Duration) time.Duration {
|
||||
start := time.Now()
|
||||
recorder := env.NewLogRecorder()
|
||||
if len(timeout) == 0 {
|
||||
timeout = append(timeout, 1*time.Minute)
|
||||
}
|
||||
ctx, ca := context.WithTimeout(env.Context(), timeout[0])
|
||||
defer ca()
|
||||
recorder.WaitForMatch(map[string]any{
|
||||
"syncer_id": "databroker",
|
||||
"syncer_type": "type.googleapis.com/pomerium.config.Config",
|
||||
"message": "listening for updates",
|
||||
}, timeout...)
|
||||
cc, err := grpc.Dial(env.DatabrokerURL().Value(),
|
||||
grpc.WithTransportCredentials(insecure.NewCredentials()),
|
||||
grpc.WithChainUnaryInterceptor(grpcutil.WithUnarySignedJWT(env.SharedSecret)),
|
||||
grpc.WithChainStreamInterceptor(grpcutil.WithStreamSignedJWT(env.SharedSecret)),
|
||||
)
|
||||
env.Require().NoError(err)
|
||||
env.Require().True(cc.WaitForStateChange(ctx, connectivity.Ready))
|
||||
return time.Since(start)
|
||||
}
|
206
internal/testenv/types.go
Normal file
206
internal/testenv/types.go
Normal file
|
@ -0,0 +1,206 @@
|
|||
package testenv
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/pomerium/pomerium/config"
|
||||
"github.com/pomerium/pomerium/internal/testenv/values"
|
||||
)
|
||||
|
||||
type envContextKeyType struct{}
|
||||
|
||||
var envContextKey envContextKeyType
|
||||
|
||||
func EnvFromContext(ctx context.Context) Environment {
|
||||
return ctx.Value(envContextKey).(Environment)
|
||||
}
|
||||
|
||||
func ContextWithEnv(ctx context.Context, env Environment) context.Context {
|
||||
return context.WithValue(ctx, envContextKey, env)
|
||||
}
|
||||
|
||||
// A Modifier is an object whose presence in the test affects the Pomerium
|
||||
// configuration in some way. When the test environment is started, a
|
||||
// [*config.Config] is constructed by calling each added Modifier in order.
|
||||
//
|
||||
// For additional details, see [Environment.Add] and [Environment.Start].
|
||||
type Modifier interface {
|
||||
// Attach is called by an [Environment] (before Modify) to propagate the
|
||||
// environment's context.
|
||||
Attach(ctx context.Context)
|
||||
|
||||
// Modify is called by an [Environment] to mutate its configuration in some
|
||||
// way required by this Modifier.
|
||||
Modify(cfg *config.Config)
|
||||
}
|
||||
|
||||
// DefaultAttach should be embedded in types implementing [Modifier] to
|
||||
// automatically obtain environment context details and caller information.
|
||||
type DefaultAttach struct {
|
||||
env Environment
|
||||
caller string
|
||||
}
|
||||
|
||||
func (d *DefaultAttach) Env() Environment {
|
||||
d.CheckAttached()
|
||||
return d.env
|
||||
}
|
||||
|
||||
func (d *DefaultAttach) Attach(ctx context.Context) {
|
||||
if d.env != nil {
|
||||
panic("internal test environment bug: Attach called twice")
|
||||
}
|
||||
d.env = EnvFromContext(ctx)
|
||||
if d.env == nil {
|
||||
panic("test bug: no environment in context")
|
||||
}
|
||||
}
|
||||
|
||||
func (d *DefaultAttach) CheckAttached() {
|
||||
if d.env == nil {
|
||||
if d.caller != "" {
|
||||
panic("test bug: missing a call to Add for the object created at: " + d.caller)
|
||||
}
|
||||
panic("test bug: not attached (possibly missing a call to Add)")
|
||||
}
|
||||
}
|
||||
|
||||
func (d *DefaultAttach) RecordCaller() {
|
||||
d.caller = getCaller(4)
|
||||
}
|
||||
|
||||
// Aggregate should be embedded in types implementing [Modifier] when the type
|
||||
// contains other modifiers. Used as an alternative to [DefaultAttach].
|
||||
// Embedding this struct will properly keep track of when constituent modifiers
|
||||
// are added, for validation and caller detection.
|
||||
//
|
||||
// Aggregate implements a no-op Modify() by default, but this can be overridden
|
||||
// to make additional modifications. The aggregate's Modify() is called first.
|
||||
type Aggregate struct {
|
||||
env Environment
|
||||
caller string
|
||||
modifiers []Modifier
|
||||
}
|
||||
|
||||
func (d *Aggregate) Add(mod Modifier) {
|
||||
if d.env != nil {
|
||||
if d.env.(*environment).getState() == NotRunning {
|
||||
// If the test environment is running, adding to an aggregate is a no-op.
|
||||
// If the test environment has not been started yet, the aggregate is
|
||||
// being used like in the following example, which is incorrect:
|
||||
//
|
||||
// aggregate.Add(foo)
|
||||
// env.Add(aggregate)
|
||||
// aggregate.Add(bar)
|
||||
// env.Start()
|
||||
//
|
||||
// It should instead be used like this:
|
||||
//
|
||||
// aggregate.Add(foo)
|
||||
// aggregate.Add(bar)
|
||||
// env.Add(aggregate)
|
||||
// env.Start()
|
||||
panic("test bug: cannot modify an aggregate that has already been added")
|
||||
}
|
||||
return
|
||||
}
|
||||
d.modifiers = append(d.modifiers, mod)
|
||||
}
|
||||
|
||||
func (d *Aggregate) Env() Environment {
|
||||
d.CheckAttached()
|
||||
return d.env
|
||||
}
|
||||
|
||||
func (d *Aggregate) Attach(ctx context.Context) {
|
||||
if d.env != nil {
|
||||
panic("internal test environment bug: Attach called twice")
|
||||
}
|
||||
d.env = EnvFromContext(ctx)
|
||||
if d.env == nil {
|
||||
panic("test bug: no environment in context")
|
||||
}
|
||||
d.env.(*environment).t.Helper()
|
||||
for _, mod := range d.modifiers {
|
||||
d.env.Add(mod)
|
||||
}
|
||||
}
|
||||
|
||||
func (d *Aggregate) Modify(*config.Config) {}
|
||||
|
||||
func (d *Aggregate) CheckAttached() {
|
||||
if d.env == nil {
|
||||
if d.caller != "" {
|
||||
panic("test bug: missing a call to Add for the object created at: " + d.caller)
|
||||
}
|
||||
panic("test bug: not attached (possibly missing a call to Add)")
|
||||
}
|
||||
}
|
||||
|
||||
func (d *Aggregate) RecordCaller() {
|
||||
d.caller = getCaller(4)
|
||||
}
|
||||
|
||||
type modifierFunc struct {
|
||||
fn func(ctx context.Context, cfg *config.Config)
|
||||
ctx context.Context
|
||||
}
|
||||
|
||||
// Attach implements Modifier.
|
||||
func (f *modifierFunc) Attach(ctx context.Context) {
|
||||
f.ctx = ctx
|
||||
}
|
||||
|
||||
func (f *modifierFunc) Modify(cfg *config.Config) {
|
||||
f.fn(f.ctx, cfg)
|
||||
}
|
||||
|
||||
var _ Modifier = (*modifierFunc)(nil)
|
||||
|
||||
func ModifierFunc(fn func(ctx context.Context, cfg *config.Config)) Modifier {
|
||||
return &modifierFunc{fn: fn}
|
||||
}
|
||||
|
||||
// Task represents a background task that can be added to an [Environment] to
|
||||
// have it run automatically on startup.
|
||||
//
|
||||
// For additional details, see [Environment.AddTask] and [Environment.Start].
|
||||
type Task interface {
|
||||
Run(ctx context.Context) error
|
||||
}
|
||||
|
||||
type TaskFunc func(ctx context.Context) error
|
||||
|
||||
func (f TaskFunc) Run(ctx context.Context) error {
|
||||
return f(ctx)
|
||||
}
|
||||
|
||||
// Upstream represents an upstream server. It is both a [Task] and a [Modifier]
|
||||
// and can be added to an environment using [Environment.AddUpstream]. From an
|
||||
// Upstream instance, new routes can be created (which automatically adds the
|
||||
// necessary route/policy entries to the config), and used within a test to
|
||||
// easily make requests to the routes with implementation-specific clients.
|
||||
type Upstream interface {
|
||||
Modifier
|
||||
Task
|
||||
Port() values.Value[int]
|
||||
Route() RouteStub
|
||||
}
|
||||
|
||||
// A Route represents a route from a source URL to a destination URL. A route is
|
||||
// typically created by calling [Upstream.Route].
|
||||
type Route interface {
|
||||
Modifier
|
||||
URL() values.Value[string]
|
||||
To(toURL values.Value[string]) Route
|
||||
Policy(edit func(*config.Policy)) Route
|
||||
PPL(ppl string) Route
|
||||
// add more methods here as they become needed
|
||||
}
|
||||
|
||||
// RouteStub represents an incomplete [Route]. Providing a URL by calling its
|
||||
// From() method will return a [Route], from which further configuration can
|
||||
// be made.
|
||||
type RouteStub interface {
|
||||
From(fromURL values.Value[string]) Route
|
||||
}
|
139
internal/testenv/upstreams/grpc.go
Normal file
139
internal/testenv/upstreams/grpc.go
Normal file
|
@ -0,0 +1,139 @@
|
|||
package upstreams
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"strings"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/testenv"
|
||||
"github.com/pomerium/pomerium/internal/testenv/values"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/credentials"
|
||||
)
|
||||
|
||||
type Options struct {
|
||||
serverOpts []grpc.ServerOption
|
||||
}
|
||||
|
||||
type Option func(*Options)
|
||||
|
||||
func (o *Options) apply(opts ...Option) {
|
||||
for _, op := range opts {
|
||||
op(o)
|
||||
}
|
||||
}
|
||||
|
||||
func ServerOpts(opt ...grpc.ServerOption) Option {
|
||||
return func(o *Options) {
|
||||
o.serverOpts = append(o.serverOpts, opt...)
|
||||
}
|
||||
}
|
||||
|
||||
// GRPCUpstream represents a GRPC server which can be used as the target for
|
||||
// one or more Pomerium routes in a test environment.
|
||||
//
|
||||
// This upstream implements [grpc.ServiceRegistrar], and can be used similarly
|
||||
// in the same way as [*grpc.Server] to register services before it is started.
|
||||
//
|
||||
// Any [testenv.Route] instances created from this upstream can be referenced
|
||||
// in the Dial() method to establish a connection to that route.
|
||||
type GRPCUpstream interface {
|
||||
testenv.Upstream
|
||||
grpc.ServiceRegistrar
|
||||
Dial(r testenv.Route, dialOpts ...grpc.DialOption) *grpc.ClientConn
|
||||
}
|
||||
|
||||
type grpcUpstream struct {
|
||||
Options
|
||||
testenv.Aggregate
|
||||
serverPort values.MutableValue[int]
|
||||
creds credentials.TransportCredentials
|
||||
|
||||
services []service
|
||||
}
|
||||
|
||||
var (
|
||||
_ testenv.Upstream = (*grpcUpstream)(nil)
|
||||
_ grpc.ServiceRegistrar = (*grpcUpstream)(nil)
|
||||
)
|
||||
|
||||
// GRPC creates a new GRPC upstream server.
|
||||
func GRPC(creds credentials.TransportCredentials, opts ...Option) GRPCUpstream {
|
||||
options := Options{}
|
||||
options.apply(opts...)
|
||||
up := &grpcUpstream{
|
||||
Options: options,
|
||||
creds: creds,
|
||||
serverPort: values.Deferred[int](),
|
||||
}
|
||||
up.RecordCaller()
|
||||
return up
|
||||
}
|
||||
|
||||
type service struct {
|
||||
desc *grpc.ServiceDesc
|
||||
impl any
|
||||
}
|
||||
|
||||
func (g *grpcUpstream) Port() values.Value[int] {
|
||||
return g.serverPort
|
||||
}
|
||||
|
||||
// RegisterService implements grpc.ServiceRegistrar.
|
||||
func (g *grpcUpstream) RegisterService(desc *grpc.ServiceDesc, impl any) {
|
||||
g.services = append(g.services, service{desc, impl})
|
||||
}
|
||||
|
||||
// Route implements testenv.Upstream.
|
||||
func (g *grpcUpstream) Route() testenv.RouteStub {
|
||||
r := &testenv.PolicyRoute{}
|
||||
var protocol string
|
||||
switch g.creds.Info().SecurityProtocol {
|
||||
case "insecure":
|
||||
protocol = "h2c"
|
||||
default:
|
||||
protocol = "https"
|
||||
}
|
||||
r.To(values.Bind(g.serverPort, func(port int) string {
|
||||
return fmt.Sprintf("%s://127.0.0.1:%d", protocol, port)
|
||||
}))
|
||||
g.Add(r)
|
||||
return r
|
||||
}
|
||||
|
||||
// Start implements testenv.Upstream.
|
||||
func (g *grpcUpstream) Run(ctx context.Context) error {
|
||||
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
g.serverPort.Resolve(listener.Addr().(*net.TCPAddr).Port)
|
||||
server := grpc.NewServer(append(g.serverOpts, grpc.Creds(g.creds))...)
|
||||
for _, s := range g.services {
|
||||
server.RegisterService(s.desc, s.impl)
|
||||
}
|
||||
errC := make(chan error, 1)
|
||||
go func() {
|
||||
errC <- server.Serve(listener)
|
||||
}()
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
server.Stop()
|
||||
return context.Cause(ctx)
|
||||
case err := <-errC:
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
func (g *grpcUpstream) Dial(r testenv.Route, dialOpts ...grpc.DialOption) *grpc.ClientConn {
|
||||
dialOpts = append(dialOpts,
|
||||
grpc.WithTransportCredentials(credentials.NewClientTLSFromCert(g.Env().ServerCAs(), "")),
|
||||
grpc.WithDefaultCallOptions(grpc.WaitForReady(true)),
|
||||
)
|
||||
cc, err := grpc.NewClient(strings.TrimPrefix(r.URL().Value(), "https://"), dialOpts...)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return cc
|
||||
}
|
327
internal/testenv/upstreams/http.go
Normal file
327
internal/testenv/upstreams/http.go
Normal file
|
@ -0,0 +1,327 @@
|
|||
package upstreams
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/cookiejar"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/pomerium/pomerium/integration/forms"
|
||||
"github.com/pomerium/pomerium/internal/retry"
|
||||
"github.com/pomerium/pomerium/internal/testenv"
|
||||
"github.com/pomerium/pomerium/internal/testenv/values"
|
||||
"google.golang.org/protobuf/proto"
|
||||
)
|
||||
|
||||
type RequestOptions struct {
|
||||
path string
|
||||
query url.Values
|
||||
headers map[string]string
|
||||
authenticateAs string
|
||||
body any
|
||||
clientCerts []tls.Certificate
|
||||
client *http.Client
|
||||
}
|
||||
|
||||
type RequestOption func(*RequestOptions)
|
||||
|
||||
func (o *RequestOptions) apply(opts ...RequestOption) {
|
||||
for _, op := range opts {
|
||||
op(o)
|
||||
}
|
||||
}
|
||||
|
||||
// Path sets the path of the request. If omitted, the request URL will match
|
||||
// the route URL exactly.
|
||||
func Path(path string) RequestOption {
|
||||
return func(o *RequestOptions) {
|
||||
o.path = path
|
||||
}
|
||||
}
|
||||
|
||||
// Query sets optional query parameters of the request.
|
||||
func Query(query url.Values) RequestOption {
|
||||
return func(o *RequestOptions) {
|
||||
o.query = query
|
||||
}
|
||||
}
|
||||
|
||||
// Headers adds optional headers to the request.
|
||||
func Headers(headers map[string]string) RequestOption {
|
||||
return func(o *RequestOptions) {
|
||||
o.headers = headers
|
||||
}
|
||||
}
|
||||
|
||||
func AuthenticateAs(email string) RequestOption {
|
||||
return func(o *RequestOptions) {
|
||||
o.authenticateAs = email
|
||||
}
|
||||
}
|
||||
|
||||
func Client(c *http.Client) RequestOption {
|
||||
return func(o *RequestOptions) {
|
||||
o.client = c
|
||||
}
|
||||
}
|
||||
|
||||
// Body sets the body of the request.
|
||||
// The argument can be one of the following types:
|
||||
// - string
|
||||
// - []byte
|
||||
// - io.Reader
|
||||
// - proto.Message
|
||||
// - any json-encodable type
|
||||
// If the argument is encoded as json, the Content-Type header will be set to
|
||||
// "application/json". If the argument is a proto.Message, the Content-Type
|
||||
// header will be set to "application/octet-stream".
|
||||
func Body(body any) RequestOption {
|
||||
return func(o *RequestOptions) {
|
||||
o.body = body
|
||||
}
|
||||
}
|
||||
|
||||
// ClientCert adds a client certificate to the request.
|
||||
func ClientCert[T interface {
|
||||
*testenv.Certificate | *tls.Certificate
|
||||
}](cert T) RequestOption {
|
||||
return func(o *RequestOptions) {
|
||||
o.clientCerts = append(o.clientCerts, *(*tls.Certificate)(cert))
|
||||
}
|
||||
}
|
||||
|
||||
// HTTPUpstream represents a HTTP server which can be used as the target for
|
||||
// one or more Pomerium routes in a test environment.
|
||||
//
|
||||
// The Handle() method can be used to add handlers the server-side HTTP router,
|
||||
// while the Get(), Post(), and (generic) Do() methods can be used to make
|
||||
// client-side requests.
|
||||
type HTTPUpstream interface {
|
||||
testenv.Upstream
|
||||
|
||||
Handle(path string, f func(http.ResponseWriter, *http.Request)) *mux.Route
|
||||
|
||||
Get(r testenv.Route, opts ...RequestOption) (*http.Response, error)
|
||||
Post(r testenv.Route, opts ...RequestOption) (*http.Response, error)
|
||||
Do(method string, r testenv.Route, opts ...RequestOption) (*http.Response, error)
|
||||
}
|
||||
|
||||
type httpUpstream struct {
|
||||
testenv.Aggregate
|
||||
serverPort values.MutableValue[int]
|
||||
tlsConfig values.Value[*tls.Config]
|
||||
|
||||
clientCache sync.Map // map[testenv.Route]*http.Client
|
||||
|
||||
router *mux.Router
|
||||
}
|
||||
|
||||
var (
|
||||
_ testenv.Upstream = (*httpUpstream)(nil)
|
||||
_ HTTPUpstream = (*httpUpstream)(nil)
|
||||
)
|
||||
|
||||
// HTTP creates a new HTTP upstream server.
|
||||
func HTTP(tlsConfig values.Value[*tls.Config]) HTTPUpstream {
|
||||
up := &httpUpstream{
|
||||
serverPort: values.Deferred[int](),
|
||||
router: mux.NewRouter(),
|
||||
tlsConfig: tlsConfig,
|
||||
}
|
||||
up.RecordCaller()
|
||||
return up
|
||||
}
|
||||
|
||||
// Port implements HTTPUpstream.
|
||||
func (h *httpUpstream) Port() values.Value[int] {
|
||||
return h.serverPort
|
||||
}
|
||||
|
||||
// Router implements HTTPUpstream.
|
||||
func (h *httpUpstream) Handle(path string, f func(http.ResponseWriter, *http.Request)) *mux.Route {
|
||||
return h.router.HandleFunc(path, f)
|
||||
}
|
||||
|
||||
// Route implements HTTPUpstream.
|
||||
func (h *httpUpstream) Route() testenv.RouteStub {
|
||||
r := &testenv.PolicyRoute{}
|
||||
protocol := "http"
|
||||
r.To(values.Bind(h.serverPort, func(port int) string {
|
||||
return fmt.Sprintf("%s://127.0.0.1:%d", protocol, port)
|
||||
}))
|
||||
h.Add(r)
|
||||
return r
|
||||
}
|
||||
|
||||
// Run implements HTTPUpstream.
|
||||
func (h *httpUpstream) Run(ctx context.Context) error {
|
||||
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
h.serverPort.Resolve(listener.Addr().(*net.TCPAddr).Port)
|
||||
var tlsConfig *tls.Config
|
||||
if h.tlsConfig != nil {
|
||||
tlsConfig = h.tlsConfig.Value()
|
||||
}
|
||||
server := &http.Server{
|
||||
Handler: h.router,
|
||||
TLSConfig: tlsConfig,
|
||||
BaseContext: func(net.Listener) context.Context {
|
||||
return ctx
|
||||
},
|
||||
}
|
||||
errC := make(chan error, 1)
|
||||
go func() {
|
||||
errC <- server.Serve(listener)
|
||||
}()
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
server.Close()
|
||||
return context.Cause(ctx)
|
||||
case err := <-errC:
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// Get implements HTTPUpstream.
|
||||
func (h *httpUpstream) Get(r testenv.Route, opts ...RequestOption) (*http.Response, error) {
|
||||
return h.Do(http.MethodGet, r, opts...)
|
||||
}
|
||||
|
||||
// Post implements HTTPUpstream.
|
||||
func (h *httpUpstream) Post(r testenv.Route, opts ...RequestOption) (*http.Response, error) {
|
||||
return h.Do(http.MethodPost, r, opts...)
|
||||
}
|
||||
|
||||
// Do implements HTTPUpstream.
|
||||
func (h *httpUpstream) Do(method string, r testenv.Route, opts ...RequestOption) (*http.Response, error) {
|
||||
options := RequestOptions{}
|
||||
options.apply(opts...)
|
||||
u, err := url.Parse(r.URL().Value())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if options.path != "" || options.query != nil {
|
||||
u = u.ResolveReference(&url.URL{
|
||||
Path: options.path,
|
||||
RawQuery: options.query.Encode(),
|
||||
})
|
||||
}
|
||||
req, err := http.NewRequest(method, u.String(), nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
switch body := options.body.(type) {
|
||||
case string:
|
||||
req.Body = io.NopCloser(strings.NewReader(body))
|
||||
case []byte:
|
||||
req.Body = io.NopCloser(bytes.NewReader(body))
|
||||
case io.Reader:
|
||||
req.Body = io.NopCloser(body)
|
||||
case proto.Message:
|
||||
buf, err := proto.Marshal(body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Body = io.NopCloser(bytes.NewReader(buf))
|
||||
req.Header.Set("Content-Type", "application/octet-stream")
|
||||
default:
|
||||
buf, err := json.Marshal(body)
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("unsupported body type: %T", body))
|
||||
}
|
||||
req.Body = io.NopCloser(bytes.NewReader(buf))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
case nil:
|
||||
}
|
||||
|
||||
newClient := func() *http.Client {
|
||||
c := http.Client{
|
||||
Transport: &http.Transport{
|
||||
TLSClientConfig: &tls.Config{
|
||||
RootCAs: h.Env().ServerCAs(),
|
||||
Certificates: options.clientCerts,
|
||||
},
|
||||
},
|
||||
}
|
||||
c.Jar, _ = cookiejar.New(&cookiejar.Options{})
|
||||
return &c
|
||||
}
|
||||
var client *http.Client
|
||||
if options.client != nil {
|
||||
client = options.client
|
||||
} else {
|
||||
var cachedClient any
|
||||
var ok bool
|
||||
if cachedClient, ok = h.clientCache.Load(r); !ok {
|
||||
cachedClient, _ = h.clientCache.LoadOrStore(r, newClient())
|
||||
}
|
||||
client = cachedClient.(*http.Client)
|
||||
}
|
||||
|
||||
var resp *http.Response
|
||||
if err := retry.Retry(h.Env().Context(), "http", func(ctx context.Context) error {
|
||||
var err error
|
||||
if options.authenticateAs != "" {
|
||||
resp, err = authenticateFlow(ctx, client, req, options.authenticateAs) //nolint:bodyclose
|
||||
} else {
|
||||
resp, err = client.Do(req) //nolint:bodyclose
|
||||
}
|
||||
// retry on connection refused
|
||||
if err != nil {
|
||||
var opErr *net.OpError
|
||||
if errors.As(err, &opErr) && opErr.Op == "dial" && opErr.Err.Error() == "connect: connection refused" {
|
||||
return err
|
||||
}
|
||||
return retry.NewTerminalError(err)
|
||||
}
|
||||
if resp.StatusCode == http.StatusInternalServerError {
|
||||
return errors.New(http.StatusText(resp.StatusCode))
|
||||
}
|
||||
return nil
|
||||
}, retry.WithMaxInterval(100*time.Millisecond)); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
func authenticateFlow(ctx context.Context, client *http.Client, req *http.Request, email string) (*http.Response, error) {
|
||||
var res *http.Response
|
||||
originalHostname := req.URL.Hostname()
|
||||
res, err := client.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
location := res.Request.URL
|
||||
if location.Hostname() == originalHostname {
|
||||
// already authenticated
|
||||
return res, err
|
||||
}
|
||||
defer res.Body.Close()
|
||||
fs := forms.Parse(res.Body)
|
||||
if len(fs) > 0 {
|
||||
f := fs[0]
|
||||
f.Inputs["email"] = email
|
||||
f.Inputs["token_expiration"] = strconv.Itoa(int((time.Hour * 24).Seconds()))
|
||||
formReq, err := f.NewRequestWithContext(ctx, location)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return client.Do(formReq)
|
||||
}
|
||||
return nil, fmt.Errorf("test bug: expected IDP login form")
|
||||
}
|
120
internal/testenv/values/value.go
Normal file
120
internal/testenv/values/value.go
Normal file
|
@ -0,0 +1,120 @@
|
|||
package values
|
||||
|
||||
import (
|
||||
"math/rand/v2"
|
||||
"sync"
|
||||
)
|
||||
|
||||
type value[T any] struct {
|
||||
f func() T
|
||||
ready bool
|
||||
cond *sync.Cond
|
||||
}
|
||||
|
||||
// A Value is a container for a single value of type T, whose initialization is
|
||||
// performed the first time Value() is called. Subsequent calls will return the
|
||||
// same value. The Value() function may block until the value is ready on the
|
||||
// first call. Values are safe to use concurrently.
|
||||
type Value[T any] interface {
|
||||
Value() T
|
||||
}
|
||||
|
||||
// MutableValue is the read-write counterpart to [Value], created by calling
|
||||
// [Deferred] for some type T. Calling Resolve() or ResolveFunc() will set
|
||||
// the value and unblock any waiting calls to Value().
|
||||
type MutableValue[T any] interface {
|
||||
Value[T]
|
||||
Resolve(value T)
|
||||
ResolveFunc(fOnce func() T)
|
||||
}
|
||||
|
||||
// Deferred creates a new read-write [MutableValue] for some type T,
|
||||
// representing a value whose initialization may be deferred to a later time.
|
||||
// Once the value is available, call [MutableValue.Resolve] or
|
||||
// [MutableValue.ResolveFunc] to unblock any waiting calls to Value().
|
||||
func Deferred[T any]() MutableValue[T] {
|
||||
return &value[T]{
|
||||
cond: sync.NewCond(&sync.Mutex{}),
|
||||
}
|
||||
}
|
||||
|
||||
// Const creates a read-only [Value] which will become available immediately
|
||||
// upon calling Value() for the first time; it will never block.
|
||||
func Const[T any](t T) Value[T] {
|
||||
return &value[T]{
|
||||
f: func() T { return t },
|
||||
ready: true,
|
||||
cond: sync.NewCond(&sync.Mutex{}),
|
||||
}
|
||||
}
|
||||
|
||||
func (p *value[T]) Value() T {
|
||||
p.cond.L.Lock()
|
||||
defer p.cond.L.Unlock()
|
||||
for !p.ready {
|
||||
p.cond.Wait()
|
||||
}
|
||||
return p.f()
|
||||
}
|
||||
|
||||
func (p *value[T]) ResolveFunc(fOnce func() T) {
|
||||
p.cond.L.Lock()
|
||||
p.f = sync.OnceValue(fOnce)
|
||||
p.ready = true
|
||||
p.cond.L.Unlock()
|
||||
p.cond.Broadcast()
|
||||
}
|
||||
|
||||
func (p *value[T]) Resolve(value T) {
|
||||
p.ResolveFunc(func() T { return value })
|
||||
}
|
||||
|
||||
// Bind creates a new [MutableValue] whose ultimate value depends on the result
|
||||
// of another [Value] that may not yet be available. When Value() is called on
|
||||
// the result, it will cascade and trigger the full chain of initialization
|
||||
// functions necessary to produce the final value.
|
||||
//
|
||||
// Care should be taken when using this function, as improper use can lead to
|
||||
// deadlocks and cause values to never become available.
|
||||
func Bind[T any, U any](dt Value[T], callback func(value T) U) Value[U] {
|
||||
du := Deferred[U]()
|
||||
du.ResolveFunc(func() U {
|
||||
return callback(dt.Value())
|
||||
})
|
||||
return du
|
||||
}
|
||||
|
||||
// Bind2 is like [Bind], but can accept two input values. The result will only
|
||||
// become available once all input values become available.
|
||||
//
|
||||
// This function blocks to wait for each input value in sequence, but in a
|
||||
// random order. Do not rely on the order of evaluation of the input values.
|
||||
func Bind2[T any, U any, V any](dt Value[T], du Value[U], callback func(value1 T, value2 U) V) Value[V] {
|
||||
dv := Deferred[V]()
|
||||
dv.ResolveFunc(func() V {
|
||||
if rand.IntN(2) == 0 { //nolint:gosec
|
||||
return callback(dt.Value(), du.Value())
|
||||
}
|
||||
u := du.Value()
|
||||
t := dt.Value()
|
||||
return callback(t, u)
|
||||
})
|
||||
return dv
|
||||
}
|
||||
|
||||
// List is a container for a slice of [Value] of type T, and is also a [Value]
|
||||
// itself, for convenience. The Value() function will return a []T containing
|
||||
// all resolved values for each element in the slice.
|
||||
//
|
||||
// A List's Value() function blocks to wait for each element in the slice in
|
||||
// sequence, but in a random order. Do not rely on the order of evaluation of
|
||||
// the slice elements.
|
||||
type List[T any] []Value[T]
|
||||
|
||||
func (s List[T]) Value() []T {
|
||||
values := make([]T, len(s))
|
||||
for _, i := range rand.Perm(len(values)) {
|
||||
values[i] = s[i].Value()
|
||||
}
|
||||
return values
|
||||
}
|
|
@ -16,6 +16,7 @@ import (
|
|||
"github.com/pomerium/pomerium/authenticate"
|
||||
"github.com/pomerium/pomerium/authorize"
|
||||
"github.com/pomerium/pomerium/config"
|
||||
"github.com/pomerium/pomerium/config/envoyconfig/filemgr"
|
||||
databroker_service "github.com/pomerium/pomerium/databroker"
|
||||
"github.com/pomerium/pomerium/internal/autocert"
|
||||
"github.com/pomerium/pomerium/internal/controlplane"
|
||||
|
@ -30,8 +31,29 @@ import (
|
|||
"github.com/pomerium/pomerium/proxy"
|
||||
)
|
||||
|
||||
type RunOptions struct {
|
||||
fileMgr *filemgr.Manager
|
||||
}
|
||||
|
||||
type RunOption func(*RunOptions)
|
||||
|
||||
func (o *RunOptions) apply(opts ...RunOption) {
|
||||
for _, op := range opts {
|
||||
op(o)
|
||||
}
|
||||
}
|
||||
|
||||
func WithOverrideFileManager(fileMgr *filemgr.Manager) RunOption {
|
||||
return func(o *RunOptions) {
|
||||
o.fileMgr = fileMgr
|
||||
}
|
||||
}
|
||||
|
||||
// Run runs the main pomerium application.
|
||||
func Run(ctx context.Context, src config.Source) error {
|
||||
func Run(ctx context.Context, src config.Source, opts ...RunOption) error {
|
||||
options := RunOptions{}
|
||||
options.apply(opts...)
|
||||
|
||||
_, _ = maxprocs.Set(maxprocs.Logger(func(s string, i ...any) { log.Ctx(ctx).Debug().Msgf(s, i...) }))
|
||||
|
||||
evt := log.Ctx(ctx).Info().
|
||||
|
@ -68,10 +90,15 @@ func Run(ctx context.Context, src config.Source) error {
|
|||
|
||||
eventsMgr := events.New()
|
||||
|
||||
fileMgr := options.fileMgr
|
||||
if fileMgr == nil {
|
||||
fileMgr = filemgr.NewManager()
|
||||
}
|
||||
|
||||
cfg := src.GetConfig()
|
||||
|
||||
// setup the control plane
|
||||
controlPlane, err := controlplane.NewServer(ctx, cfg, metricsMgr, eventsMgr)
|
||||
controlPlane, err := controlplane.NewServer(ctx, cfg, metricsMgr, eventsMgr, fileMgr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error creating control plane: %w", err)
|
||||
}
|
||||
|
|
|
@ -1,7 +1,6 @@
|
|||
package cryptutil
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"encoding/base64"
|
||||
|
@ -15,10 +14,9 @@ import (
|
|||
|
||||
// GetCertPool gets a cert pool for the given CA or CAFile.
|
||||
func GetCertPool(ca, caFile string) (*x509.CertPool, error) {
|
||||
ctx := context.TODO()
|
||||
rootCAs, err := x509.SystemCertPool()
|
||||
if err != nil {
|
||||
log.Ctx(ctx).Error().Err(err).Msg("pkg/cryptutil: failed getting system cert pool making new one")
|
||||
log.Error().Err(err).Msg("pkg/cryptutil: failed getting system cert pool making new one")
|
||||
rootCAs = x509.NewCertPool()
|
||||
}
|
||||
if ca == "" && caFile == "" {
|
||||
|
@ -40,7 +38,9 @@ func GetCertPool(ca, caFile string) (*x509.CertPool, error) {
|
|||
if ok := rootCAs.AppendCertsFromPEM(data); !ok {
|
||||
return nil, fmt.Errorf("failed to append any PEM-encoded certificates")
|
||||
}
|
||||
log.Ctx(ctx).Debug().Msg("pkg/cryptutil: added custom certificate authority")
|
||||
if !log.DebugDisableGlobalMessages.Load() {
|
||||
log.Debug().Msg("pkg/cryptutil: added custom certificate authority")
|
||||
}
|
||||
return rootCAs, nil
|
||||
}
|
||||
|
||||
|
|
|
@ -186,7 +186,25 @@ func (srv *Server) run(ctx context.Context, cfg *config.Config) error {
|
|||
// monitor the process so we exit if it prematurely exits
|
||||
var monitorProcessCtx context.Context
|
||||
monitorProcessCtx, srv.monitorProcessCancel = context.WithCancel(context.WithoutCancel(ctx))
|
||||
go srv.monitorProcess(monitorProcessCtx, int32(cmd.Process.Pid))
|
||||
go func() {
|
||||
pid := cmd.Process.Pid
|
||||
err := srv.monitorProcess(monitorProcessCtx, int32(pid))
|
||||
if err != nil && ctx.Err() == nil {
|
||||
// If the envoy subprocess exits and ctx is not done, issue a fatal error.
|
||||
// If ctx is done, the server is already exiting, and envoy is expected
|
||||
// to be stopped along with it.
|
||||
log.Ctx(ctx).
|
||||
Fatal().
|
||||
Int("pid", pid).
|
||||
Err(err).
|
||||
Send()
|
||||
}
|
||||
log.Ctx(ctx).
|
||||
Debug().
|
||||
Int("pid", pid).
|
||||
Err(ctx.Err()).
|
||||
Msg("envoy process monitor stopped")
|
||||
}()
|
||||
|
||||
if srv.resourceMonitor != nil {
|
||||
log.Ctx(ctx).Debug().Str("service", "envoy").Msg("starting resource monitor")
|
||||
|
@ -300,7 +318,7 @@ func (srv *Server) handleLogs(ctx context.Context, rc io.ReadCloser) {
|
|||
}
|
||||
}
|
||||
|
||||
func (srv *Server) monitorProcess(ctx context.Context, pid int32) {
|
||||
func (srv *Server) monitorProcess(ctx context.Context, pid int32) error {
|
||||
log.Ctx(ctx).Debug().
|
||||
Int32("pid", pid).
|
||||
Msg("envoy: start monitoring subprocess")
|
||||
|
@ -311,19 +329,15 @@ func (srv *Server) monitorProcess(ctx context.Context, pid int32) {
|
|||
for {
|
||||
exists, err := process.PidExistsWithContext(ctx, pid)
|
||||
if err != nil {
|
||||
log.Fatal().Err(err).
|
||||
Int32("pid", pid).
|
||||
Msg("envoy: error retrieving subprocess information")
|
||||
return fmt.Errorf("envoy: error retrieving subprocess information: %w", err)
|
||||
} else if !exists {
|
||||
log.Fatal().Err(err).
|
||||
Int32("pid", pid).
|
||||
Msg("envoy: subprocess exited")
|
||||
return errors.New("envoy: subprocess exited")
|
||||
}
|
||||
|
||||
// wait for the next tick
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
return nil
|
||||
case <-ticker.C:
|
||||
}
|
||||
}
|
||||
|
|
|
@ -6,6 +6,7 @@ package envoy
|
|||
import (
|
||||
"context"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"sync"
|
||||
"syscall"
|
||||
|
@ -17,7 +18,7 @@ import (
|
|||
"github.com/pomerium/pomerium/internal/telemetry/metrics"
|
||||
)
|
||||
|
||||
const baseIDPath = "/tmp/pomerium-envoy-base-id"
|
||||
const baseIDName = "pomerium-envoy-base-id"
|
||||
|
||||
var restartEpoch struct {
|
||||
sync.Mutex
|
||||
|
@ -89,7 +90,7 @@ func (srv *Server) prepareRunEnvoyCommand(ctx context.Context, sharedArgs []stri
|
|||
} else {
|
||||
args = append(args,
|
||||
"--use-dynamic-base-id",
|
||||
"--base-id-path", baseIDPath,
|
||||
"--base-id-path", filepath.Join(os.TempDir(), baseIDName),
|
||||
)
|
||||
restartEpoch.value = 1
|
||||
}
|
||||
|
@ -99,7 +100,7 @@ func (srv *Server) prepareRunEnvoyCommand(ctx context.Context, sharedArgs []stri
|
|||
}
|
||||
|
||||
func readBaseID() (int, bool) {
|
||||
bs, err := os.ReadFile(baseIDPath)
|
||||
bs, err := os.ReadFile(filepath.Join(os.TempDir(), baseIDName))
|
||||
if err != nil {
|
||||
return 0, false
|
||||
}
|
||||
|
|
|
@ -3,6 +3,7 @@ package databroker
|
|||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
backoff "github.com/cenkalti/backoff/v4"
|
||||
|
@ -71,12 +72,24 @@ type Syncer struct {
|
|||
id string
|
||||
}
|
||||
|
||||
var DebugUseFasterBackoff atomic.Bool
|
||||
|
||||
// NewSyncer creates a new Syncer.
|
||||
func NewSyncer(ctx context.Context, id string, handler SyncerHandler, options ...SyncerOption) *Syncer {
|
||||
closeCtx, closeCtxCancel := context.WithCancel(context.WithoutCancel(ctx))
|
||||
|
||||
bo := backoff.NewExponentialBackOff()
|
||||
var bo *backoff.ExponentialBackOff
|
||||
if DebugUseFasterBackoff.Load() {
|
||||
bo = backoff.NewExponentialBackOff(
|
||||
backoff.WithInitialInterval(10*time.Millisecond),
|
||||
backoff.WithMultiplier(1.0),
|
||||
backoff.WithMaxElapsedTime(100*time.Millisecond),
|
||||
)
|
||||
bo.MaxElapsedTime = 0
|
||||
} else {
|
||||
bo = backoff.NewExponentialBackOff()
|
||||
bo.MaxElapsedTime = 0
|
||||
}
|
||||
s := &Syncer{
|
||||
cfg: getSyncerConfig(options...),
|
||||
handler: handler,
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue