mirror of
https://github.com/pomerium/pomerium.git
synced 2025-05-31 18:07:17 +02:00
New integration test fixtures (#5233)
* Initial test environment implementation * linter pass * wip: update request latency test * bugfixes * Fix logic race in envoy process monitor when canceling context * skip tests using test environment on non-linux
This commit is contained in:
parent
3d958ff9c5
commit
526e2a58d6
29 changed files with 2972 additions and 101 deletions
|
@ -76,3 +76,6 @@ issues:
|
||||||
- text: "G112:"
|
- text: "G112:"
|
||||||
linters:
|
linters:
|
||||||
- gosec
|
- gosec
|
||||||
|
- text: "G402: TLS MinVersion too low."
|
||||||
|
linters:
|
||||||
|
- gosec
|
||||||
|
|
|
@ -26,7 +26,7 @@ import (
|
||||||
const maxActiveDownstreamConnections = 50000
|
const maxActiveDownstreamConnections = 50000
|
||||||
|
|
||||||
var (
|
var (
|
||||||
envoyAdminAddressPath = filepath.Join(os.TempDir(), "pomerium-envoy-admin.sock")
|
envoyAdminAddressSockName = "pomerium-envoy-admin.sock"
|
||||||
envoyAdminAddressMode = 0o600
|
envoyAdminAddressMode = 0o600
|
||||||
envoyAdminClusterName = "pomerium-envoy-admin"
|
envoyAdminClusterName = "pomerium-envoy-admin"
|
||||||
)
|
)
|
||||||
|
@ -95,7 +95,7 @@ func (b *Builder) BuildBootstrapAdmin(cfg *config.Config) (admin *envoy_config_b
|
||||||
admin.Address = &envoy_config_core_v3.Address{
|
admin.Address = &envoy_config_core_v3.Address{
|
||||||
Address: &envoy_config_core_v3.Address_Pipe{
|
Address: &envoy_config_core_v3.Address_Pipe{
|
||||||
Pipe: &envoy_config_core_v3.Pipe{
|
Pipe: &envoy_config_core_v3.Pipe{
|
||||||
Path: envoyAdminAddressPath,
|
Path: filepath.Join(os.TempDir(), envoyAdminAddressSockName),
|
||||||
Mode: uint32(envoyAdminAddressMode),
|
Mode: uint32(envoyAdminAddressMode),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
|
@ -12,6 +12,7 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestBuilder_BuildBootstrapAdmin(t *testing.T) {
|
func TestBuilder_BuildBootstrapAdmin(t *testing.T) {
|
||||||
|
t.Setenv("TMPDIR", "/tmp")
|
||||||
b := New("local-grpc", "local-http", "local-metrics", filemgr.NewManager(), nil)
|
b := New("local-grpc", "local-http", "local-metrics", filemgr.NewManager(), nil)
|
||||||
t.Run("valid", func(t *testing.T) {
|
t.Run("valid", func(t *testing.T) {
|
||||||
adminCfg, err := b.BuildBootstrapAdmin(&config.Config{
|
adminCfg, err := b.BuildBootstrapAdmin(&config.Config{
|
||||||
|
@ -25,7 +26,7 @@ func TestBuilder_BuildBootstrapAdmin(t *testing.T) {
|
||||||
"address": {
|
"address": {
|
||||||
"pipe": {
|
"pipe": {
|
||||||
"mode": 384,
|
"mode": 384,
|
||||||
"path": "`+envoyAdminAddressPath+`"
|
"path": "/tmp/`+envoyAdminAddressSockName+`"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -2,6 +2,8 @@ package envoyconfig
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
|
||||||
envoy_config_cluster_v3 "github.com/envoyproxy/go-control-plane/envoy/config/cluster/v3"
|
envoy_config_cluster_v3 "github.com/envoyproxy/go-control-plane/envoy/config/cluster/v3"
|
||||||
envoy_config_core_v3 "github.com/envoyproxy/go-control-plane/envoy/config/core/v3"
|
envoy_config_core_v3 "github.com/envoyproxy/go-control-plane/envoy/config/core/v3"
|
||||||
|
@ -23,7 +25,8 @@ func (b *Builder) buildEnvoyAdminCluster(_ context.Context, _ *config.Config) (*
|
||||||
Address: &envoy_config_core_v3.Address{
|
Address: &envoy_config_core_v3.Address{
|
||||||
Address: &envoy_config_core_v3.Address_Pipe{
|
Address: &envoy_config_core_v3.Address_Pipe{
|
||||||
Pipe: &envoy_config_core_v3.Pipe{
|
Pipe: &envoy_config_core_v3.Pipe{
|
||||||
Path: envoyAdminAddressPath,
|
Path: filepath.Join(os.TempDir(), envoyAdminAddressSockName),
|
||||||
|
Mode: uint32(envoyAdminAddressMode),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
|
@ -23,11 +23,7 @@ import (
|
||||||
func Test_BuildClusters(t *testing.T) {
|
func Test_BuildClusters(t *testing.T) {
|
||||||
// The admin address path is based on os.TempDir(), which will vary from
|
// The admin address path is based on os.TempDir(), which will vary from
|
||||||
// system to system, so replace this with a stable location.
|
// system to system, so replace this with a stable location.
|
||||||
originalEnvoyAdminAddressPath := envoyAdminAddressPath
|
t.Setenv("TMPDIR", "/tmp")
|
||||||
envoyAdminAddressPath = "/tmp/pomerium-envoy-admin.sock"
|
|
||||||
t.Cleanup(func() {
|
|
||||||
envoyAdminAddressPath = originalEnvoyAdminAddressPath
|
|
||||||
})
|
|
||||||
|
|
||||||
opts := config.NewDefaultOptions()
|
opts := config.NewDefaultOptions()
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
|
|
|
@ -1,102 +1,159 @@
|
||||||
package envoyconfig_test
|
package envoyconfig_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"io"
|
||||||
|
"net/http"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
"google.golang.org/grpc"
|
|
||||||
"google.golang.org/grpc/codes"
|
|
||||||
"google.golang.org/grpc/credentials"
|
|
||||||
"google.golang.org/grpc/credentials/insecure"
|
"google.golang.org/grpc/credentials/insecure"
|
||||||
"google.golang.org/grpc/interop"
|
"google.golang.org/grpc/interop"
|
||||||
"google.golang.org/grpc/interop/grpc_testing"
|
"google.golang.org/grpc/interop/grpc_testing"
|
||||||
"google.golang.org/grpc/metadata"
|
|
||||||
"google.golang.org/grpc/status"
|
|
||||||
|
|
||||||
"github.com/pomerium/pomerium/config"
|
"github.com/pomerium/pomerium/config"
|
||||||
"github.com/pomerium/pomerium/pkg/cmd/pomerium"
|
"github.com/pomerium/pomerium/internal/testenv"
|
||||||
"github.com/pomerium/pomerium/pkg/netutil"
|
"github.com/pomerium/pomerium/internal/testenv/scenarios"
|
||||||
|
"github.com/pomerium/pomerium/internal/testenv/snippets"
|
||||||
|
"github.com/pomerium/pomerium/internal/testenv/upstreams"
|
||||||
|
"github.com/pomerium/pomerium/internal/testenv/values"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestH2C(t *testing.T) {
|
func TestH2C(t *testing.T) {
|
||||||
if testing.Short() {
|
env := testenv.New(t)
|
||||||
t.SkipNow()
|
|
||||||
}
|
|
||||||
|
|
||||||
ctx, ca := context.WithCancel(context.Background())
|
up := upstreams.GRPC(insecure.NewCredentials())
|
||||||
|
grpc_testing.RegisterTestServiceServer(up, interop.NewTestServer())
|
||||||
|
|
||||||
opts := config.NewDefaultOptions()
|
http := up.Route().
|
||||||
listener, err := (&net.ListenConfig{}).Listen(ctx, "tcp", "127.0.0.1:0")
|
From(env.SubdomainURL("grpc-http")).
|
||||||
require.NoError(t, err)
|
To(values.Bind(up.Port(), func(port int) string {
|
||||||
ports, err := netutil.AllocatePorts(7)
|
// override the target protocol to use http://
|
||||||
require.NoError(t, err)
|
return fmt.Sprintf("http://127.0.0.1:%d", port)
|
||||||
urls, err := config.ParseWeightedUrls("http://"+listener.Addr().String(), "h2c://"+listener.Addr().String())
|
})).
|
||||||
require.NoError(t, err)
|
Policy(func(p *config.Policy) { p.AllowPublicUnauthenticatedAccess = true })
|
||||||
opts.Addr = fmt.Sprintf("127.0.0.1:%s", ports[0])
|
|
||||||
opts.Routes = []config.Policy{
|
|
||||||
{
|
|
||||||
From: fmt.Sprintf("https://grpc-http.localhost.pomerium.io:%s", ports[0]),
|
|
||||||
To: urls[:1],
|
|
||||||
AllowPublicUnauthenticatedAccess: true,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
From: fmt.Sprintf("https://grpc-h2c.localhost.pomerium.io:%s", ports[0]),
|
|
||||||
To: urls[1:],
|
|
||||||
AllowPublicUnauthenticatedAccess: true,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
opts.CertFile = "../../integration/tpl/files/trusted.pem"
|
|
||||||
opts.KeyFile = "../../integration/tpl/files/trusted-key.pem"
|
|
||||||
cfg := &config.Config{Options: opts}
|
|
||||||
cfg.AllocatePorts(*(*[6]string)(ports[1:]))
|
|
||||||
|
|
||||||
server := grpc.NewServer(grpc.Creds(insecure.NewCredentials()))
|
h2c := up.Route().
|
||||||
grpc_testing.RegisterTestServiceServer(server, interop.NewTestServer())
|
From(env.SubdomainURL("grpc-h2c")).
|
||||||
go server.Serve(listener)
|
Policy(func(p *config.Policy) { p.AllowPublicUnauthenticatedAccess = true })
|
||||||
|
|
||||||
errC := make(chan error, 1)
|
env.AddUpstream(up)
|
||||||
go func() {
|
env.Start()
|
||||||
errC <- pomerium.Run(ctx, config.NewStaticSource(cfg))
|
snippets.WaitStartupComplete(env)
|
||||||
}()
|
|
||||||
|
|
||||||
t.Cleanup(func() {
|
|
||||||
ca()
|
|
||||||
assert.ErrorIs(t, context.Canceled, <-errC)
|
|
||||||
})
|
|
||||||
|
|
||||||
tlsConfig, err := credentials.NewClientTLSFromFile("../../integration/tpl/files/ca.pem", "")
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
t.Run("h2c", func(t *testing.T) {
|
t.Run("h2c", func(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
recorder := env.NewLogRecorder()
|
||||||
|
|
||||||
cc, err := grpc.Dial(fmt.Sprintf("grpc-h2c.localhost.pomerium.io:%s", ports[0]), grpc.WithTransportCredentials(tlsConfig))
|
cc := up.Dial(h2c)
|
||||||
require.NoError(t, err)
|
|
||||||
client := grpc_testing.NewTestServiceClient(cc)
|
client := grpc_testing.NewTestServiceClient(cc)
|
||||||
var md metadata.MD
|
_, err := client.EmptyCall(env.Context(), &grpc_testing.Empty{})
|
||||||
_, err = client.EmptyCall(ctx, &grpc_testing.Empty{}, grpc.WaitForReady(true), grpc.Header(&md))
|
require.NoError(t, err)
|
||||||
cc.Close()
|
cc.Close()
|
||||||
assert.NoError(t, err)
|
|
||||||
assert.Contains(t, md, "x-envoy-upstream-service-time")
|
recorder.Match([]map[string]any{
|
||||||
|
{
|
||||||
|
"service": "envoy",
|
||||||
|
"path": "/grpc.testing.TestService/EmptyCall",
|
||||||
|
"message": "http-request",
|
||||||
|
"response-code-details": "via_upstream",
|
||||||
|
},
|
||||||
|
})
|
||||||
})
|
})
|
||||||
t.Run("http", func(t *testing.T) {
|
t.Run("http", func(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
recorder := env.NewLogRecorder()
|
||||||
|
|
||||||
cc, err := grpc.Dial(fmt.Sprintf("grpc-http.localhost.pomerium.io:%s", ports[0]), grpc.WithTransportCredentials(tlsConfig))
|
cc := up.Dial(http)
|
||||||
require.NoError(t, err)
|
|
||||||
client := grpc_testing.NewTestServiceClient(cc)
|
client := grpc_testing.NewTestServiceClient(cc)
|
||||||
var md metadata.MD
|
_, err := client.UnaryCall(env.Context(), &grpc_testing.SimpleRequest{})
|
||||||
_, err = client.EmptyCall(ctx, &grpc_testing.Empty{}, grpc.WaitForReady(true), grpc.Trailer(&md))
|
require.Error(t, err)
|
||||||
cc.Close()
|
cc.Close()
|
||||||
stat := status.Convert(err)
|
|
||||||
assert.NotNil(t, stat)
|
recorder.Match([]map[string]any{
|
||||||
assert.Equal(t, stat.Code(), codes.Unavailable)
|
{
|
||||||
assert.NotContains(t, md, "x-envoy-upstream-service-time")
|
"service": "envoy",
|
||||||
assert.Contains(t, stat.Message(), "<!DOCTYPE html>")
|
"path": "/grpc.testing.TestService/UnaryCall",
|
||||||
assert.Contains(t, stat.Message(), "upstream_reset_before_response_started{protocol_error}")
|
"message": "http-request",
|
||||||
|
"response-code-details": "upstream_reset_before_response_started{protocol_error}",
|
||||||
|
},
|
||||||
|
})
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHTTP(t *testing.T) {
|
||||||
|
env := testenv.New(t)
|
||||||
|
|
||||||
|
up := upstreams.HTTP(nil)
|
||||||
|
up.Handle("/foo", func(w http.ResponseWriter, _ *http.Request) {
|
||||||
|
fmt.Fprintln(w, "hello world")
|
||||||
|
})
|
||||||
|
|
||||||
|
route := up.Route().
|
||||||
|
From(env.SubdomainURL("http")).
|
||||||
|
Policy(func(p *config.Policy) { p.AllowPublicUnauthenticatedAccess = true })
|
||||||
|
|
||||||
|
env.AddUpstream(up)
|
||||||
|
env.Start()
|
||||||
|
|
||||||
|
recorder := env.NewLogRecorder()
|
||||||
|
|
||||||
|
resp, err := up.Get(route, upstreams.Path("/foo"))
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
defer resp.Body.Close()
|
||||||
|
data, err := io.ReadAll(resp.Body)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, "hello world\n", string(data))
|
||||||
|
|
||||||
|
recorder.Match([]map[string]any{
|
||||||
|
{
|
||||||
|
"service": "envoy",
|
||||||
|
"path": "/foo",
|
||||||
|
"method": "GET",
|
||||||
|
"message": "http-request",
|
||||||
|
"response-code-details": "via_upstream",
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestClientCert(t *testing.T) {
|
||||||
|
env := testenv.New(t)
|
||||||
|
env.Add(scenarios.DownstreamMTLS(config.MTLSEnforcementRejectConnection))
|
||||||
|
|
||||||
|
up := upstreams.HTTP(nil)
|
||||||
|
up.Handle("/foo", func(w http.ResponseWriter, _ *http.Request) {
|
||||||
|
fmt.Fprintln(w, "hello world")
|
||||||
|
})
|
||||||
|
|
||||||
|
clientCert := env.NewClientCert()
|
||||||
|
|
||||||
|
route := up.Route().
|
||||||
|
From(env.SubdomainURL("http")).
|
||||||
|
PPL(fmt.Sprintf(`{"allow":{"and":["client_certificate":{"fingerprint":%q}]}}`, clientCert.Fingerprint()))
|
||||||
|
|
||||||
|
env.AddUpstream(up)
|
||||||
|
env.Start()
|
||||||
|
|
||||||
|
recorder := env.NewLogRecorder()
|
||||||
|
|
||||||
|
resp, err := up.Get(route, upstreams.Path("/foo"), upstreams.ClientCert(clientCert))
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
defer resp.Body.Close()
|
||||||
|
data, err := io.ReadAll(resp.Body)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, "hello world\n", string(data))
|
||||||
|
|
||||||
|
recorder.Match([]map[string]any{
|
||||||
|
{
|
||||||
|
"service": "envoy",
|
||||||
|
"path": "/foo",
|
||||||
|
"method": "GET",
|
||||||
|
"message": "http-request",
|
||||||
|
"response-code-details": "via_upstream",
|
||||||
|
"client-certificate": clientCert,
|
||||||
|
},
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
1
config/envoyconfig/testdata/clusters.json
vendored
1
config/envoyconfig/testdata/clusters.json
vendored
|
@ -280,6 +280,7 @@
|
||||||
"endpoint": {
|
"endpoint": {
|
||||||
"address": {
|
"address": {
|
||||||
"pipe": {
|
"pipe": {
|
||||||
|
"mode": 384,
|
||||||
"path": "/tmp/pomerium-envoy-admin.sock"
|
"path": "/tmp/pomerium-envoy-admin.sock"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -129,6 +129,7 @@ func newManager(
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
|
cache.Stop()
|
||||||
return
|
return
|
||||||
case <-ticker.C:
|
case <-ticker.C:
|
||||||
err := mgr.renewConfigCerts(ctx)
|
err := mgr.renewConfigCerts(ctx)
|
||||||
|
|
54
internal/benchmarks/config_bench_test.go
Normal file
54
internal/benchmarks/config_bench_test.go
Normal file
|
@ -0,0 +1,54 @@
|
||||||
|
package benchmarks_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/pomerium/pomerium/internal/testenv"
|
||||||
|
"github.com/pomerium/pomerium/internal/testenv/snippets"
|
||||||
|
"github.com/pomerium/pomerium/internal/testenv/upstreams"
|
||||||
|
)
|
||||||
|
|
||||||
|
func BenchmarkStartupLatency(b *testing.B) {
|
||||||
|
for _, n := range []int{1, 10, 100, 1000, 10000} {
|
||||||
|
b.Run(fmt.Sprintf("routes=%d", n), func(b *testing.B) {
|
||||||
|
for range b.N {
|
||||||
|
env := testenv.New(b)
|
||||||
|
up := upstreams.HTTP(nil)
|
||||||
|
for i := range n {
|
||||||
|
up.Route().
|
||||||
|
From(env.SubdomainURL(fmt.Sprintf("from-%d", i))).
|
||||||
|
PPL(`{"allow":{"and":[{"accept":"true"}]}}`)
|
||||||
|
}
|
||||||
|
env.AddUpstream(up)
|
||||||
|
|
||||||
|
env.Start()
|
||||||
|
snippets.WaitStartupComplete(env, 60*time.Minute)
|
||||||
|
|
||||||
|
env.Stop()
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkAppendRoutes(b *testing.B) {
|
||||||
|
for _, n := range []int{1, 10, 100, 1000, 10000} {
|
||||||
|
b.Run(fmt.Sprintf("routes=%d", n), func(b *testing.B) {
|
||||||
|
for range b.N {
|
||||||
|
env := testenv.New(b)
|
||||||
|
up := upstreams.HTTP(nil)
|
||||||
|
env.AddUpstream(up)
|
||||||
|
|
||||||
|
env.Start()
|
||||||
|
snippets.WaitStartupComplete(env)
|
||||||
|
for i := range n {
|
||||||
|
env.Add(up.Route().
|
||||||
|
From(env.SubdomainURL(fmt.Sprintf("from-%d", i))).
|
||||||
|
PPL(fmt.Sprintf(`{"allow":{"and":["email":{"is":"user-%d@example.com"}]}}`, i)))
|
||||||
|
}
|
||||||
|
env.Stop()
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
87
internal/benchmarks/latency_bench_test.go
Normal file
87
internal/benchmarks/latency_bench_test.go
Normal file
|
@ -0,0 +1,87 @@
|
||||||
|
package benchmarks_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"flag"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"math/rand/v2"
|
||||||
|
"net/http"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/pomerium/pomerium/internal/testenv"
|
||||||
|
"github.com/pomerium/pomerium/internal/testenv/scenarios"
|
||||||
|
"github.com/pomerium/pomerium/internal/testenv/snippets"
|
||||||
|
"github.com/pomerium/pomerium/internal/testenv/upstreams"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
numRoutes int
|
||||||
|
dumpErrLogs bool
|
||||||
|
)
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
flag.IntVar(&numRoutes, "routes", 100, "number of routes")
|
||||||
|
flag.BoolVar(&dumpErrLogs, "dump-err-logs", false, "if the test fails, write all captured logs to a file (testdata/<test-name>)")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRequestLatency(t *testing.T) {
|
||||||
|
env := testenv.New(t, testenv.Silent())
|
||||||
|
users := []*scenarios.User{}
|
||||||
|
for i := range numRoutes {
|
||||||
|
users = append(users, &scenarios.User{
|
||||||
|
Email: fmt.Sprintf("user%d@example.com", i),
|
||||||
|
FirstName: fmt.Sprintf("Firstname%d", i),
|
||||||
|
LastName: fmt.Sprintf("Lastname%d", i),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
env.Add(scenarios.NewIDP(users))
|
||||||
|
|
||||||
|
up := upstreams.HTTP(nil)
|
||||||
|
up.Handle("/", func(w http.ResponseWriter, _ *http.Request) {
|
||||||
|
w.Write([]byte("OK"))
|
||||||
|
})
|
||||||
|
routes := make([]testenv.Route, numRoutes)
|
||||||
|
for i := range numRoutes {
|
||||||
|
routes[i] = up.Route().
|
||||||
|
From(env.SubdomainURL(fmt.Sprintf("from-%d", i))).
|
||||||
|
PPL(fmt.Sprintf(`{"allow":{"and":["email":{"is":"user%d@example.com"}]}}`, i))
|
||||||
|
}
|
||||||
|
env.AddUpstream(up)
|
||||||
|
|
||||||
|
env.Start()
|
||||||
|
snippets.WaitStartupComplete(env)
|
||||||
|
|
||||||
|
out := testing.Benchmark(func(b *testing.B) {
|
||||||
|
b.ReportAllocs()
|
||||||
|
b.RunParallel(func(pb *testing.PB) {
|
||||||
|
var rec *testenv.LogRecorder
|
||||||
|
if dumpErrLogs {
|
||||||
|
rec = env.NewLogRecorder(testenv.WithSkipCloseDelay())
|
||||||
|
}
|
||||||
|
for pb.Next() {
|
||||||
|
idx := rand.IntN(numRoutes)
|
||||||
|
resp, err := up.Get(routes[idx], upstreams.AuthenticateAs(fmt.Sprintf("user%d@example.com", idx)))
|
||||||
|
if !assert.NoError(b, err) {
|
||||||
|
filename := "TestRequestLatency_err.log"
|
||||||
|
if dumpErrLogs {
|
||||||
|
rec.DumpToFile(filename)
|
||||||
|
b.Logf("test logs written to %s", filename)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.Equal(b, resp.StatusCode, 200)
|
||||||
|
body, err := io.ReadAll(resp.Body)
|
||||||
|
resp.Body.Close()
|
||||||
|
assert.NoError(b, err)
|
||||||
|
assert.Equal(b, "OK", string(body))
|
||||||
|
}
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Log(out)
|
||||||
|
t.Logf("req/s: %f", float64(out.N)/out.T.Seconds())
|
||||||
|
|
||||||
|
env.Stop()
|
||||||
|
}
|
|
@ -74,11 +74,12 @@ func NewServer(
|
||||||
cfg *config.Config,
|
cfg *config.Config,
|
||||||
metricsMgr *config.MetricsManager,
|
metricsMgr *config.MetricsManager,
|
||||||
eventsMgr *events.Manager,
|
eventsMgr *events.Manager,
|
||||||
|
fileMgr *filemgr.Manager,
|
||||||
) (*Server, error) {
|
) (*Server, error) {
|
||||||
srv := &Server{
|
srv := &Server{
|
||||||
metricsMgr: metricsMgr,
|
metricsMgr: metricsMgr,
|
||||||
EventsMgr: eventsMgr,
|
EventsMgr: eventsMgr,
|
||||||
filemgr: filemgr.NewManager(),
|
filemgr: fileMgr,
|
||||||
reproxy: reproxy.New(),
|
reproxy: reproxy.New(),
|
||||||
haveSetCapacity: map[string]bool{},
|
haveSetCapacity: map[string]bool{},
|
||||||
updateConfig: make(chan *config.Config, 1),
|
updateConfig: make(chan *config.Config, 1),
|
||||||
|
|
|
@ -12,6 +12,7 @@ import (
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
"github.com/pomerium/pomerium/config"
|
"github.com/pomerium/pomerium/config"
|
||||||
|
"github.com/pomerium/pomerium/config/envoyconfig/filemgr"
|
||||||
"github.com/pomerium/pomerium/internal/events"
|
"github.com/pomerium/pomerium/internal/events"
|
||||||
"github.com/pomerium/pomerium/pkg/netutil"
|
"github.com/pomerium/pomerium/pkg/netutil"
|
||||||
)
|
)
|
||||||
|
@ -38,7 +39,7 @@ func TestServerHTTP(t *testing.T) {
|
||||||
cfg.Options.SharedKey = "JDNjY2ITDlARvNaQXjc2Djk+GA6xeCy4KiozmZfdbTs="
|
cfg.Options.SharedKey = "JDNjY2ITDlARvNaQXjc2Djk+GA6xeCy4KiozmZfdbTs="
|
||||||
|
|
||||||
src := config.NewStaticSource(cfg)
|
src := config.NewStaticSource(cfg)
|
||||||
srv, err := NewServer(ctx, cfg, config.NewMetricsManager(ctx, src), events.New())
|
srv, err := NewServer(ctx, cfg, config.NewMetricsManager(ctx, src), events.New(), filemgr.NewManager(filemgr.WithCacheDir(t.TempDir())))
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
go srv.Run(ctx)
|
go srv.Run(ctx)
|
||||||
|
|
||||||
|
|
|
@ -7,4 +7,6 @@ var (
|
||||||
DebugDisableZapLogger atomic.Bool
|
DebugDisableZapLogger atomic.Bool
|
||||||
// Debug option to suppress global warnings
|
// Debug option to suppress global warnings
|
||||||
DebugDisableGlobalWarnings atomic.Bool
|
DebugDisableGlobalWarnings atomic.Bool
|
||||||
|
// Debug option to suppress global (non-warning) messages
|
||||||
|
DebugDisableGlobalMessages atomic.Bool
|
||||||
)
|
)
|
||||||
|
|
838
internal/testenv/environment.go
Normal file
838
internal/testenv/environment.go
Normal file
|
@ -0,0 +1,838 @@
|
||||||
|
package testenv
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"crypto/rand"
|
||||||
|
"crypto/rsa"
|
||||||
|
"crypto/sha256"
|
||||||
|
"crypto/tls"
|
||||||
|
"crypto/x509"
|
||||||
|
"crypto/x509/pkix"
|
||||||
|
"encoding/asn1"
|
||||||
|
"encoding/base64"
|
||||||
|
"encoding/hex"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"math/big"
|
||||||
|
"math/bits"
|
||||||
|
"net"
|
||||||
|
"net/url"
|
||||||
|
"os"
|
||||||
|
"os/signal"
|
||||||
|
"path"
|
||||||
|
"path/filepath"
|
||||||
|
"runtime"
|
||||||
|
"strconv"
|
||||||
|
"sync"
|
||||||
|
"syscall"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/pomerium/pomerium/config"
|
||||||
|
"github.com/pomerium/pomerium/config/envoyconfig/filemgr"
|
||||||
|
"github.com/pomerium/pomerium/internal/log"
|
||||||
|
"github.com/pomerium/pomerium/internal/testenv/values"
|
||||||
|
"github.com/pomerium/pomerium/pkg/cmd/pomerium"
|
||||||
|
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
||||||
|
"github.com/pomerium/pomerium/pkg/health"
|
||||||
|
"github.com/pomerium/pomerium/pkg/netutil"
|
||||||
|
"github.com/pomerium/pomerium/pkg/slices"
|
||||||
|
"github.com/rs/zerolog"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
"golang.org/x/sync/errgroup"
|
||||||
|
"google.golang.org/grpc/grpclog"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Environment is a lightweight integration test fixture that runs Pomerium
|
||||||
|
// in-process.
|
||||||
|
type Environment interface {
|
||||||
|
// Context returns the environment's root context. This context holds a
|
||||||
|
// top-level logger scoped to this environment. It will be canceled when
|
||||||
|
// Stop() is called, or during test cleanup.
|
||||||
|
Context() context.Context
|
||||||
|
|
||||||
|
Assert() *assert.Assertions
|
||||||
|
Require() *require.Assertions
|
||||||
|
|
||||||
|
// TempDir returns a unique temp directory for this context. Calling this
|
||||||
|
// function multiple times returns the same path.
|
||||||
|
TempDir() string
|
||||||
|
// CACert returns the test environment's root CA certificate and private key.
|
||||||
|
CACert() *tls.Certificate
|
||||||
|
// ServerCAs returns a new [*x509.CertPool] containing the root CA certificate
|
||||||
|
// used to sign the server cert and other test certificates.
|
||||||
|
ServerCAs() *x509.CertPool
|
||||||
|
// ServerCert returns the Pomerium server's certificate and private key.
|
||||||
|
ServerCert() *tls.Certificate
|
||||||
|
// NewClientCert generates a new client certificate signed by the root CA
|
||||||
|
// certificate. One or more optional templates can be given, which can be
|
||||||
|
// used to set or override certain parameters when creating a certificate,
|
||||||
|
// including subject, SANs, or extensions. If more than one template is
|
||||||
|
// provided, they will be applied in order from left to right.
|
||||||
|
//
|
||||||
|
// By default (unless overridden in a template), the certificate will have
|
||||||
|
// its Common Name set to the file:line string of the call site. Calls to
|
||||||
|
// NewClientCert() on different lines will have different subjects. If
|
||||||
|
// multiple certs with the same subject are needed, wrap the call to this
|
||||||
|
// function in another helper function, or separate calls with commas on the
|
||||||
|
// same line.
|
||||||
|
NewClientCert(templateOverrides ...*x509.Certificate) *Certificate
|
||||||
|
|
||||||
|
NewServerCert(templateOverrides ...*x509.Certificate) *Certificate
|
||||||
|
|
||||||
|
AuthenticateURL() values.Value[string]
|
||||||
|
DatabrokerURL() values.Value[string]
|
||||||
|
Ports() Ports
|
||||||
|
SharedSecret() []byte
|
||||||
|
CookieSecret() []byte
|
||||||
|
|
||||||
|
// Add adds the given [Modifier] to the environment. All modifiers will be
|
||||||
|
// invoked upon calling Start() to apply individual modifications to the
|
||||||
|
// configuration before starting the Pomerium server.
|
||||||
|
Add(m Modifier)
|
||||||
|
// AddTask adds the given [Task] to the environment. All tasks will be
|
||||||
|
// started in separate goroutines upon calling Start(). If any tasks exit
|
||||||
|
// with an error, the environment will be stopped and the test will fail.
|
||||||
|
AddTask(r Task)
|
||||||
|
// AddUpstream adds the given [Upstream] to the environment. This function is
|
||||||
|
// equivalent to calling both Add() and AddTask() with the upstream, but
|
||||||
|
// improves readability.
|
||||||
|
AddUpstream(u Upstream)
|
||||||
|
|
||||||
|
// Start starts the test environment, and adds a call to Stop() as a cleanup
|
||||||
|
// hook to the environment's [testing.T]. All previously added [Modifier]
|
||||||
|
// instances are invoked in order to build the configuration, and all
|
||||||
|
// previously added [Task] instances are started in the background.
|
||||||
|
//
|
||||||
|
// Calling Start() more than once, Calling Start() after Stop(), or calling
|
||||||
|
// any of the Add* functions after Start() will panic.
|
||||||
|
Start()
|
||||||
|
// Stop stops the test environment. Calling this function more than once has
|
||||||
|
// no effect. It is usually not necessary to call Stop() directly unless you
|
||||||
|
// need to stop the test environment before the test is completed.
|
||||||
|
Stop()
|
||||||
|
|
||||||
|
// SubdomainURL returns a string [values.Value] which will contain a complete
|
||||||
|
// URL for the given subdomain of the server's domain (given by its serving
|
||||||
|
// certificate), including the 'https://' scheme and random http server port.
|
||||||
|
// This value will only be resolved some time after Start() is called, and
|
||||||
|
// can be used as the 'from' value for routes.
|
||||||
|
SubdomainURL(subdomain string) values.Value[string]
|
||||||
|
|
||||||
|
// NewLogRecorder returns a new [*LogRecorder] and starts capturing logs for
|
||||||
|
// the Pomerium server and Envoy.
|
||||||
|
NewLogRecorder(opts ...LogRecorderOption) *LogRecorder
|
||||||
|
|
||||||
|
// OnStateChanged registers a callback to be invoked when the environment's
|
||||||
|
// state changes to the given state. The callback is invoked in a separate
|
||||||
|
// goroutine.
|
||||||
|
OnStateChanged(state EnvironmentState, callback func())
|
||||||
|
}
|
||||||
|
|
||||||
|
type Certificate tls.Certificate
|
||||||
|
|
||||||
|
func (c *Certificate) Fingerprint() string {
|
||||||
|
sum := sha256.Sum256(c.Leaf.Raw)
|
||||||
|
return hex.EncodeToString(sum[:])
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Certificate) SPKIHash() string {
|
||||||
|
sum := sha256.Sum256(c.Leaf.RawSubjectPublicKeyInfo)
|
||||||
|
return base64.StdEncoding.EncodeToString(sum[:])
|
||||||
|
}
|
||||||
|
|
||||||
|
type EnvironmentState uint32
|
||||||
|
|
||||||
|
const NotRunning EnvironmentState = 0
|
||||||
|
|
||||||
|
const (
|
||||||
|
Starting EnvironmentState = 1 << iota
|
||||||
|
Running
|
||||||
|
Stopping
|
||||||
|
Stopped
|
||||||
|
)
|
||||||
|
|
||||||
|
func (e EnvironmentState) String() string {
|
||||||
|
switch e {
|
||||||
|
case NotRunning:
|
||||||
|
return "NotRunning"
|
||||||
|
case Starting:
|
||||||
|
return "Starting"
|
||||||
|
case Running:
|
||||||
|
return "Running"
|
||||||
|
case Stopping:
|
||||||
|
return "Stopping"
|
||||||
|
case Stopped:
|
||||||
|
return "Stopped"
|
||||||
|
default:
|
||||||
|
return fmt.Sprintf("EnvironmentState(%d)", e)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type environment struct {
|
||||||
|
EnvironmentOptions
|
||||||
|
t testing.TB
|
||||||
|
assert *assert.Assertions
|
||||||
|
require *require.Assertions
|
||||||
|
tempDir string
|
||||||
|
domain string
|
||||||
|
ports Ports
|
||||||
|
sharedSecret [32]byte
|
||||||
|
cookieSecret [32]byte
|
||||||
|
workspaceFolder string
|
||||||
|
silent bool
|
||||||
|
|
||||||
|
ctx context.Context
|
||||||
|
cancel context.CancelCauseFunc
|
||||||
|
cleanupOnce sync.Once
|
||||||
|
logWriter *log.MultiWriter
|
||||||
|
|
||||||
|
mods []WithCaller[Modifier]
|
||||||
|
tasks []WithCaller[Task]
|
||||||
|
taskErrGroup *errgroup.Group
|
||||||
|
|
||||||
|
stateMu sync.Mutex
|
||||||
|
state EnvironmentState
|
||||||
|
stateChangeListeners map[EnvironmentState][]func()
|
||||||
|
|
||||||
|
src *configSource
|
||||||
|
}
|
||||||
|
|
||||||
|
type EnvironmentOptions struct {
|
||||||
|
debug bool
|
||||||
|
pauseOnFailure bool
|
||||||
|
forceSilent bool
|
||||||
|
}
|
||||||
|
|
||||||
|
type EnvironmentOption func(*EnvironmentOptions)
|
||||||
|
|
||||||
|
func (o *EnvironmentOptions) apply(opts ...EnvironmentOption) {
|
||||||
|
for _, op := range opts {
|
||||||
|
op(o)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func Debug(enable ...bool) EnvironmentOption {
|
||||||
|
if len(enable) == 0 {
|
||||||
|
enable = append(enable, true)
|
||||||
|
}
|
||||||
|
return func(o *EnvironmentOptions) {
|
||||||
|
o.debug = enable[0]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func PauseOnFailure(enable ...bool) EnvironmentOption {
|
||||||
|
if len(enable) == 0 {
|
||||||
|
enable = append(enable, true)
|
||||||
|
}
|
||||||
|
return func(o *EnvironmentOptions) {
|
||||||
|
o.pauseOnFailure = enable[0]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func Silent(silent ...bool) EnvironmentOption {
|
||||||
|
if len(silent) == 0 {
|
||||||
|
silent = append(silent, true)
|
||||||
|
}
|
||||||
|
return func(o *EnvironmentOptions) {
|
||||||
|
o.forceSilent = silent[0]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var setGrpcLoggerOnce sync.Once
|
||||||
|
|
||||||
|
func New(t testing.TB, opts ...EnvironmentOption) Environment {
|
||||||
|
if runtime.GOOS != "linux" {
|
||||||
|
t.Skip("test environment only supported on linux")
|
||||||
|
}
|
||||||
|
options := EnvironmentOptions{}
|
||||||
|
options.apply(opts...)
|
||||||
|
if testing.Short() {
|
||||||
|
t.Helper()
|
||||||
|
t.Skip("test environment disabled in short mode")
|
||||||
|
}
|
||||||
|
databroker.DebugUseFasterBackoff.Store(true)
|
||||||
|
workspaceFolder, err := os.Getwd()
|
||||||
|
require.NoError(t, err)
|
||||||
|
for {
|
||||||
|
if _, err := os.Stat(filepath.Join(workspaceFolder, ".git")); err == nil {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
workspaceFolder = filepath.Dir(workspaceFolder)
|
||||||
|
if workspaceFolder == "/" {
|
||||||
|
panic("could not find workspace root")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
workspaceFolder, err = filepath.Abs(workspaceFolder)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
writer := log.NewMultiWriter()
|
||||||
|
silent := options.forceSilent || isSilent(t)
|
||||||
|
if silent {
|
||||||
|
// this sets the global zap level to fatal, then resets the global zerolog
|
||||||
|
// level to debug
|
||||||
|
log.SetLevel(zerolog.FatalLevel)
|
||||||
|
zerolog.SetGlobalLevel(zerolog.DebugLevel)
|
||||||
|
} else {
|
||||||
|
log.SetLevel(zerolog.InfoLevel)
|
||||||
|
writer.Add(os.Stdout)
|
||||||
|
}
|
||||||
|
log.DebugDisableGlobalWarnings.Store(silent)
|
||||||
|
log.DebugDisableGlobalMessages.Store(silent)
|
||||||
|
log.DebugDisableZapLogger.Store(silent)
|
||||||
|
setGrpcLoggerOnce.Do(func() {
|
||||||
|
grpclog.SetLoggerV2(grpclog.NewLoggerV2WithVerbosity(io.Discard, io.Discard, io.Discard, 0))
|
||||||
|
})
|
||||||
|
logger := zerolog.New(writer).With().Timestamp().Logger().Level(zerolog.DebugLevel)
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancelCause(logger.WithContext(context.Background()))
|
||||||
|
taskErrGroup, ctx := errgroup.WithContext(ctx)
|
||||||
|
|
||||||
|
e := &environment{
|
||||||
|
EnvironmentOptions: options,
|
||||||
|
t: t,
|
||||||
|
assert: assert.New(t),
|
||||||
|
require: require.New(t),
|
||||||
|
tempDir: t.TempDir(),
|
||||||
|
ports: Ports{
|
||||||
|
ProxyHTTP: values.Deferred[int](),
|
||||||
|
ProxyGRPC: values.Deferred[int](),
|
||||||
|
GRPC: values.Deferred[int](),
|
||||||
|
HTTP: values.Deferred[int](),
|
||||||
|
Outbound: values.Deferred[int](),
|
||||||
|
Metrics: values.Deferred[int](),
|
||||||
|
Debug: values.Deferred[int](),
|
||||||
|
ALPN: values.Deferred[int](),
|
||||||
|
},
|
||||||
|
workspaceFolder: workspaceFolder,
|
||||||
|
silent: silent,
|
||||||
|
ctx: ctx,
|
||||||
|
cancel: cancel,
|
||||||
|
logWriter: writer,
|
||||||
|
taskErrGroup: taskErrGroup,
|
||||||
|
}
|
||||||
|
_, err = rand.Read(e.sharedSecret[:])
|
||||||
|
require.NoError(t, err)
|
||||||
|
_, err = rand.Read(e.cookieSecret[:])
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
health.SetProvider(e)
|
||||||
|
|
||||||
|
require.NoError(t, os.Mkdir(filepath.Join(e.tempDir, "certs"), 0o777))
|
||||||
|
copyFile := func(src, dstRel string) {
|
||||||
|
data, err := os.ReadFile(src)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NoError(t, os.WriteFile(filepath.Join(e.tempDir, dstRel), data, 0o600))
|
||||||
|
}
|
||||||
|
|
||||||
|
certsToCopy := []string{
|
||||||
|
"trusted.pem",
|
||||||
|
"trusted-key.pem",
|
||||||
|
"ca.pem",
|
||||||
|
"ca-key.pem",
|
||||||
|
}
|
||||||
|
for _, crt := range certsToCopy {
|
||||||
|
copyFile(filepath.Join(workspaceFolder, "integration/tpl/files", crt), filepath.Join("certs/", filepath.Base(crt)))
|
||||||
|
}
|
||||||
|
e.domain = wildcardDomain(e.ServerCert().Leaf.DNSNames)
|
||||||
|
|
||||||
|
return e
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *environment) debugf(format string, args ...any) {
|
||||||
|
if !e.debug {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
e.t.Logf("\x1b[34m[debug] "+format+"\x1b[0m", args...)
|
||||||
|
}
|
||||||
|
|
||||||
|
type WithCaller[T any] struct {
|
||||||
|
Caller string
|
||||||
|
Value T
|
||||||
|
}
|
||||||
|
|
||||||
|
type Ports struct {
|
||||||
|
ProxyHTTP values.MutableValue[int]
|
||||||
|
ProxyGRPC values.MutableValue[int]
|
||||||
|
GRPC values.MutableValue[int]
|
||||||
|
HTTP values.MutableValue[int]
|
||||||
|
Outbound values.MutableValue[int]
|
||||||
|
Metrics values.MutableValue[int]
|
||||||
|
Debug values.MutableValue[int]
|
||||||
|
ALPN values.MutableValue[int]
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *environment) TempDir() string {
|
||||||
|
return e.tempDir
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *environment) Context() context.Context {
|
||||||
|
return ContextWithEnv(e.ctx, e)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *environment) Assert() *assert.Assertions {
|
||||||
|
return e.assert
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *environment) Require() *require.Assertions {
|
||||||
|
return e.require
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *environment) SubdomainURL(subdomain string) values.Value[string] {
|
||||||
|
return values.Bind(e.ports.ProxyHTTP, func(port int) string {
|
||||||
|
return fmt.Sprintf("https://%s.%s:%d", subdomain, e.domain, port)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *environment) AuthenticateURL() values.Value[string] {
|
||||||
|
return e.SubdomainURL("authenticate")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *environment) DatabrokerURL() values.Value[string] {
|
||||||
|
return values.Bind(e.ports.Outbound, func(port int) string {
|
||||||
|
return fmt.Sprintf("127.0.0.1:%d", port)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *environment) Ports() Ports {
|
||||||
|
return e.ports
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *environment) CACert() *tls.Certificate {
|
||||||
|
caCert, err := tls.LoadX509KeyPair(
|
||||||
|
filepath.Join(e.tempDir, "certs", "ca.pem"),
|
||||||
|
filepath.Join(e.tempDir, "certs", "ca-key.pem"),
|
||||||
|
)
|
||||||
|
require.NoError(e.t, err)
|
||||||
|
return &caCert
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *environment) ServerCAs() *x509.CertPool {
|
||||||
|
pool := x509.NewCertPool()
|
||||||
|
caCert, err := os.ReadFile(filepath.Join(e.tempDir, "certs", "ca.pem"))
|
||||||
|
require.NoError(e.t, err)
|
||||||
|
pool.AppendCertsFromPEM(caCert)
|
||||||
|
return pool
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *environment) ServerCert() *tls.Certificate {
|
||||||
|
serverCert, err := tls.LoadX509KeyPair(
|
||||||
|
filepath.Join(e.tempDir, "certs", "trusted.pem"),
|
||||||
|
filepath.Join(e.tempDir, "certs", "trusted-key.pem"),
|
||||||
|
)
|
||||||
|
require.NoError(e.t, err)
|
||||||
|
return &serverCert
|
||||||
|
}
|
||||||
|
|
||||||
|
// Used as the context's cancel cause during normal cleanup
|
||||||
|
var ErrCauseTestCleanup = errors.New("test cleanup")
|
||||||
|
|
||||||
|
// Used as the context's cancel cause when Stop() is called
|
||||||
|
var ErrCauseManualStop = errors.New("Stop() called")
|
||||||
|
|
||||||
|
func (e *environment) Start() {
|
||||||
|
e.debugf("Start()")
|
||||||
|
e.advanceState(Starting)
|
||||||
|
e.t.Cleanup(e.cleanup)
|
||||||
|
e.t.Setenv("TMPDIR", e.TempDir())
|
||||||
|
e.debugf("temp dir: %s", e.TempDir())
|
||||||
|
|
||||||
|
cfg := &config.Config{
|
||||||
|
Options: config.NewDefaultOptions(),
|
||||||
|
}
|
||||||
|
ports, err := netutil.AllocatePorts(8)
|
||||||
|
require.NoError(e.t, err)
|
||||||
|
atoi := func(str string) int {
|
||||||
|
p, err := strconv.Atoi(str)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
return p
|
||||||
|
}
|
||||||
|
e.ports.ProxyHTTP.Resolve(atoi(ports[0]))
|
||||||
|
e.ports.ProxyGRPC.Resolve(atoi(ports[1]))
|
||||||
|
e.ports.GRPC.Resolve(atoi(ports[2]))
|
||||||
|
e.ports.HTTP.Resolve(atoi(ports[3]))
|
||||||
|
e.ports.Outbound.Resolve(atoi(ports[4]))
|
||||||
|
e.ports.Metrics.Resolve(atoi(ports[5]))
|
||||||
|
e.ports.Debug.Resolve(atoi(ports[6]))
|
||||||
|
e.ports.ALPN.Resolve(atoi(ports[7]))
|
||||||
|
cfg.AllocatePorts(*(*[6]string)(ports[2:]))
|
||||||
|
|
||||||
|
cfg.Options.AutocertOptions = config.AutocertOptions{Enable: false}
|
||||||
|
cfg.Options.Services = "all"
|
||||||
|
cfg.Options.LogLevel = config.LogLevelDebug
|
||||||
|
cfg.Options.ProxyLogLevel = config.LogLevelInfo
|
||||||
|
cfg.Options.Addr = fmt.Sprintf("127.0.0.1:%d", e.ports.ProxyHTTP.Value())
|
||||||
|
cfg.Options.GRPCAddr = fmt.Sprintf("127.0.0.1:%d", e.ports.ProxyGRPC.Value())
|
||||||
|
cfg.Options.CAFile = filepath.Join(e.tempDir, "certs", "ca.pem")
|
||||||
|
cfg.Options.CertFile = filepath.Join(e.tempDir, "certs", "trusted.pem")
|
||||||
|
cfg.Options.KeyFile = filepath.Join(e.tempDir, "certs", "trusted-key.pem")
|
||||||
|
cfg.Options.AuthenticateURLString = e.AuthenticateURL().Value()
|
||||||
|
cfg.Options.DataBrokerStorageType = "memory"
|
||||||
|
cfg.Options.SharedKey = base64.StdEncoding.EncodeToString(e.sharedSecret[:])
|
||||||
|
cfg.Options.CookieSecret = base64.StdEncoding.EncodeToString(e.cookieSecret[:])
|
||||||
|
cfg.Options.AccessLogFields = []log.AccessLogField{
|
||||||
|
log.AccessLogFieldAuthority,
|
||||||
|
log.AccessLogFieldDuration,
|
||||||
|
log.AccessLogFieldForwardedFor,
|
||||||
|
log.AccessLogFieldIP,
|
||||||
|
log.AccessLogFieldMethod,
|
||||||
|
log.AccessLogFieldPath,
|
||||||
|
log.AccessLogFieldQuery,
|
||||||
|
log.AccessLogFieldReferer,
|
||||||
|
log.AccessLogFieldRequestID,
|
||||||
|
log.AccessLogFieldResponseCode,
|
||||||
|
log.AccessLogFieldResponseCodeDetails,
|
||||||
|
log.AccessLogFieldSize,
|
||||||
|
log.AccessLogFieldUpstreamCluster,
|
||||||
|
log.AccessLogFieldUserAgent,
|
||||||
|
log.AccessLogFieldClientCertificate,
|
||||||
|
}
|
||||||
|
|
||||||
|
e.src = &configSource{cfg: cfg}
|
||||||
|
e.AddTask(TaskFunc(func(ctx context.Context) error {
|
||||||
|
fileMgr := filemgr.NewManager(filemgr.WithCacheDir(filepath.Join(e.TempDir(), "cache")))
|
||||||
|
for _, mod := range e.mods {
|
||||||
|
mod.Value.Modify(cfg)
|
||||||
|
require.NoError(e.t, cfg.Options.Validate(), "invoking modifier resulted in an invalid configuration:\nadded by: "+mod.Caller)
|
||||||
|
}
|
||||||
|
return pomerium.Run(ctx, e.src, pomerium.WithOverrideFileManager(fileMgr))
|
||||||
|
}))
|
||||||
|
|
||||||
|
for i, task := range e.tasks {
|
||||||
|
log.Ctx(e.ctx).Debug().Str("caller", task.Caller).Msgf("starting task %d", i)
|
||||||
|
e.taskErrGroup.Go(func() error {
|
||||||
|
defer log.Ctx(e.ctx).Debug().Str("caller", task.Caller).Msgf("task %d exited", i)
|
||||||
|
return task.Value.Run(e.ctx)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
runtime.Gosched()
|
||||||
|
|
||||||
|
e.advanceState(Running)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *environment) NewClientCert(templateOverrides ...*x509.Certificate) *Certificate {
|
||||||
|
caCert := e.CACert()
|
||||||
|
|
||||||
|
priv, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||||
|
require.NoError(e.t, err)
|
||||||
|
|
||||||
|
sn, err := rand.Int(rand.Reader, new(big.Int).Lsh(big.NewInt(1), 128))
|
||||||
|
require.NoError(e.t, err)
|
||||||
|
now := time.Now()
|
||||||
|
tmpl := &x509.Certificate{
|
||||||
|
SerialNumber: sn,
|
||||||
|
Subject: pkix.Name{
|
||||||
|
CommonName: getCaller(),
|
||||||
|
},
|
||||||
|
NotBefore: now,
|
||||||
|
NotAfter: now.Add(12 * time.Hour),
|
||||||
|
KeyUsage: x509.KeyUsageDigitalSignature,
|
||||||
|
ExtKeyUsage: []x509.ExtKeyUsage{
|
||||||
|
x509.ExtKeyUsageClientAuth,
|
||||||
|
},
|
||||||
|
BasicConstraintsValid: true,
|
||||||
|
}
|
||||||
|
for _, override := range templateOverrides {
|
||||||
|
tmpl.CRLDistributionPoints = slices.Unique(append(tmpl.CRLDistributionPoints, override.CRLDistributionPoints...))
|
||||||
|
tmpl.DNSNames = slices.Unique(append(tmpl.DNSNames, override.DNSNames...))
|
||||||
|
tmpl.EmailAddresses = slices.Unique(append(tmpl.EmailAddresses, override.EmailAddresses...))
|
||||||
|
tmpl.ExtraExtensions = append(tmpl.ExtraExtensions, override.ExtraExtensions...)
|
||||||
|
tmpl.IPAddresses = slices.UniqueBy(append(tmpl.IPAddresses, override.IPAddresses...), net.IP.String)
|
||||||
|
tmpl.URIs = slices.UniqueBy(append(tmpl.URIs, override.URIs...), (*url.URL).String)
|
||||||
|
tmpl.UnknownExtKeyUsage = slices.UniqueBy(append(tmpl.UnknownExtKeyUsage, override.UnknownExtKeyUsage...), asn1.ObjectIdentifier.String)
|
||||||
|
seq := override.Subject.ToRDNSequence()
|
||||||
|
tmpl.Subject.FillFromRDNSequence(&seq)
|
||||||
|
tmpl.KeyUsage |= override.KeyUsage
|
||||||
|
tmpl.ExtKeyUsage = slices.Unique(append(tmpl.ExtKeyUsage, override.ExtKeyUsage...))
|
||||||
|
}
|
||||||
|
|
||||||
|
clientCertDER, err := x509.CreateCertificate(rand.Reader, tmpl, caCert.Leaf, priv.Public(), caCert.PrivateKey)
|
||||||
|
require.NoError(e.t, err)
|
||||||
|
|
||||||
|
cert, err := x509.ParseCertificate(clientCertDER)
|
||||||
|
require.NoError(e.t, err)
|
||||||
|
e.debugf("provisioned client certificate for %s", cert.Subject.String())
|
||||||
|
|
||||||
|
clientCert := &tls.Certificate{
|
||||||
|
Certificate: [][]byte{cert.Raw, caCert.Leaf.Raw},
|
||||||
|
PrivateKey: priv,
|
||||||
|
Leaf: cert,
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = clientCert.Leaf.Verify(x509.VerifyOptions{
|
||||||
|
KeyUsages: []x509.ExtKeyUsage{
|
||||||
|
x509.ExtKeyUsageClientAuth,
|
||||||
|
},
|
||||||
|
Roots: e.ServerCAs(),
|
||||||
|
})
|
||||||
|
require.NoError(e.t, err, "bug: generated client cert is not valid")
|
||||||
|
return (*Certificate)(clientCert)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *environment) NewServerCert(templateOverrides ...*x509.Certificate) *Certificate {
|
||||||
|
caCert := e.CACert()
|
||||||
|
|
||||||
|
priv, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||||
|
require.NoError(e.t, err)
|
||||||
|
|
||||||
|
sn, err := rand.Int(rand.Reader, new(big.Int).Lsh(big.NewInt(1), 128))
|
||||||
|
require.NoError(e.t, err)
|
||||||
|
now := time.Now()
|
||||||
|
tmpl := &x509.Certificate{
|
||||||
|
SerialNumber: sn,
|
||||||
|
NotBefore: now,
|
||||||
|
NotAfter: now.Add(12 * time.Hour),
|
||||||
|
KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment,
|
||||||
|
ExtKeyUsage: []x509.ExtKeyUsage{
|
||||||
|
x509.ExtKeyUsageServerAuth,
|
||||||
|
},
|
||||||
|
BasicConstraintsValid: true,
|
||||||
|
}
|
||||||
|
for _, override := range templateOverrides {
|
||||||
|
tmpl.DNSNames = slices.Unique(append(tmpl.DNSNames, override.DNSNames...))
|
||||||
|
tmpl.IPAddresses = slices.UniqueBy(append(tmpl.IPAddresses, override.IPAddresses...), net.IP.String)
|
||||||
|
}
|
||||||
|
certDER, err := x509.CreateCertificate(rand.Reader, tmpl, caCert.Leaf, priv.Public(), caCert.PrivateKey)
|
||||||
|
require.NoError(e.t, err)
|
||||||
|
|
||||||
|
cert, err := x509.ParseCertificate(certDER)
|
||||||
|
require.NoError(e.t, err)
|
||||||
|
e.debugf("provisioned server certificate for %v", cert.DNSNames)
|
||||||
|
|
||||||
|
tlsCert := &tls.Certificate{
|
||||||
|
Certificate: [][]byte{cert.Raw, caCert.Leaf.Raw},
|
||||||
|
PrivateKey: priv,
|
||||||
|
Leaf: cert,
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = tlsCert.Leaf.Verify(x509.VerifyOptions{Roots: e.ServerCAs()})
|
||||||
|
require.NoError(e.t, err, "bug: generated client cert is not valid")
|
||||||
|
return (*Certificate)(tlsCert)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *environment) SharedSecret() []byte {
|
||||||
|
return bytes.Clone(e.sharedSecret[:])
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *environment) CookieSecret() []byte {
|
||||||
|
return bytes.Clone(e.cookieSecret[:])
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *environment) Stop() {
|
||||||
|
if b, ok := e.t.(*testing.B); ok {
|
||||||
|
// when calling Stop() manually, ensure we aren't timing this
|
||||||
|
b.StopTimer()
|
||||||
|
defer b.StartTimer()
|
||||||
|
}
|
||||||
|
e.cleanupOnce.Do(func() {
|
||||||
|
e.debugf("stop: Stop() called manually")
|
||||||
|
e.advanceState(Stopping)
|
||||||
|
e.cancel(ErrCauseManualStop)
|
||||||
|
err := e.taskErrGroup.Wait()
|
||||||
|
e.advanceState(Stopped)
|
||||||
|
e.debugf("stop: done waiting")
|
||||||
|
assert.ErrorIs(e.t, err, ErrCauseManualStop)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *environment) cleanup() {
|
||||||
|
e.cleanupOnce.Do(func() {
|
||||||
|
e.debugf("stop: test cleanup")
|
||||||
|
if e.t.Failed() {
|
||||||
|
if e.pauseOnFailure {
|
||||||
|
e.t.Log("\x1b[31m*** pausing on test failure; continue with ctrl+c ***\x1b[0m")
|
||||||
|
c := make(chan os.Signal, 1)
|
||||||
|
signal.Notify(c, syscall.SIGINT)
|
||||||
|
<-c
|
||||||
|
e.t.Log("\x1b[31mctrl+c received, continuing\x1b[0m")
|
||||||
|
signal.Stop(c)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
e.advanceState(Stopping)
|
||||||
|
e.cancel(ErrCauseTestCleanup)
|
||||||
|
err := e.taskErrGroup.Wait()
|
||||||
|
e.advanceState(Stopped)
|
||||||
|
e.debugf("stop: done waiting")
|
||||||
|
assert.ErrorIs(e.t, err, ErrCauseTestCleanup)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *environment) Add(m Modifier) {
|
||||||
|
e.t.Helper()
|
||||||
|
caller := getCaller()
|
||||||
|
e.debugf("Add: %T from %s", m, caller)
|
||||||
|
switch e.getState() {
|
||||||
|
case NotRunning:
|
||||||
|
for _, mod := range e.mods {
|
||||||
|
if mod.Value == m {
|
||||||
|
e.t.Fatalf("test bug: duplicate modifier added\nfirst added by: %s", mod.Caller)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
e.mods = append(e.mods, WithCaller[Modifier]{
|
||||||
|
Caller: caller,
|
||||||
|
Value: m,
|
||||||
|
})
|
||||||
|
e.debugf("Add: state=NotRunning; calling Attach")
|
||||||
|
m.Attach(e.Context())
|
||||||
|
case Starting:
|
||||||
|
panic("test bug: cannot call Add() before Start() has returned")
|
||||||
|
case Running:
|
||||||
|
e.debugf("Add: state=Running; calling ModifyConfig")
|
||||||
|
e.src.ModifyConfig(e.ctx, m)
|
||||||
|
case Stopped, Stopping:
|
||||||
|
panic("test bug: cannot call Add() after Stop()")
|
||||||
|
default:
|
||||||
|
panic(fmt.Sprintf("unexpected environment state: %s", e.getState()))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *environment) AddTask(t Task) {
|
||||||
|
e.t.Helper()
|
||||||
|
caller := getCaller()
|
||||||
|
e.debugf("AddTask: %T from %s", t, caller)
|
||||||
|
for _, task := range e.tasks {
|
||||||
|
if task.Value == t {
|
||||||
|
e.t.Fatalf("test bug: duplicate task added\nfirst added by: %s", task.Caller)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
e.tasks = append(e.tasks, WithCaller[Task]{
|
||||||
|
Caller: getCaller(),
|
||||||
|
Value: t,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *environment) AddUpstream(up Upstream) {
|
||||||
|
e.t.Helper()
|
||||||
|
caller := getCaller()
|
||||||
|
e.debugf("AddUpstream: %T from %s", up, caller)
|
||||||
|
e.Add(up)
|
||||||
|
e.AddTask(up)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ReportError implements health.Provider.
|
||||||
|
func (e *environment) ReportError(check health.Check, err error, attributes ...health.Attr) {
|
||||||
|
// note: don't use e.t.Fatal here, it will deadlock
|
||||||
|
panic(fmt.Sprintf("%s: %v %v", check, err, attributes))
|
||||||
|
}
|
||||||
|
|
||||||
|
// ReportOK implements health.Provider.
|
||||||
|
func (e *environment) ReportOK(_ health.Check, _ ...health.Attr) {}
|
||||||
|
|
||||||
|
func (e *environment) advanceState(newState EnvironmentState) {
|
||||||
|
e.stateMu.Lock()
|
||||||
|
defer e.stateMu.Unlock()
|
||||||
|
if newState <= e.state {
|
||||||
|
panic(fmt.Sprintf("internal test environment bug: changed state to <= current: newState=%s, current=%s", newState, e.state))
|
||||||
|
}
|
||||||
|
e.debugf("state %s -> %s", e.state.String(), newState.String())
|
||||||
|
e.state = newState
|
||||||
|
e.debugf("notifying %d listeners of state change", len(e.stateChangeListeners[newState]))
|
||||||
|
for _, listener := range e.stateChangeListeners[newState] {
|
||||||
|
go listener()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *environment) getState() EnvironmentState {
|
||||||
|
e.stateMu.Lock()
|
||||||
|
defer e.stateMu.Unlock()
|
||||||
|
return e.state
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *environment) OnStateChanged(state EnvironmentState, callback func()) {
|
||||||
|
e.stateMu.Lock()
|
||||||
|
defer e.stateMu.Unlock()
|
||||||
|
|
||||||
|
if e.state&state != 0 {
|
||||||
|
go callback()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// add change listeners for all states, if there are multiple bits set
|
||||||
|
for state > 0 {
|
||||||
|
stateBit := EnvironmentState(bits.TrailingZeros32(uint32(state)))
|
||||||
|
state &= (state - 1)
|
||||||
|
e.stateChangeListeners[stateBit] = append(e.stateChangeListeners[stateBit], callback)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func getCaller(skip ...int) string {
|
||||||
|
if len(skip) == 0 {
|
||||||
|
skip = append(skip, 3)
|
||||||
|
}
|
||||||
|
callers := make([]uintptr, 8)
|
||||||
|
runtime.Callers(skip[0], callers)
|
||||||
|
frames := runtime.CallersFrames(callers)
|
||||||
|
var caller string
|
||||||
|
for {
|
||||||
|
next, ok := frames.Next()
|
||||||
|
if !ok {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
if path.Base(next.Function) == "testenv.(*environment).AddUpstream" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
caller = fmt.Sprintf("%s:%d", next.File, next.Line)
|
||||||
|
break
|
||||||
|
}
|
||||||
|
return caller
|
||||||
|
}
|
||||||
|
|
||||||
|
func wildcardDomain(names []string) string {
|
||||||
|
for _, name := range names {
|
||||||
|
if name[0] == '*' {
|
||||||
|
return name[2:]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
panic("test bug: no wildcard domain in certificate")
|
||||||
|
}
|
||||||
|
|
||||||
|
func isSilent(t testing.TB) bool {
|
||||||
|
switch t.(type) {
|
||||||
|
case *testing.B:
|
||||||
|
return !slices.Contains(os.Args, "-test.v=true")
|
||||||
|
default:
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type configSource struct {
|
||||||
|
mu sync.Mutex
|
||||||
|
cfg *config.Config
|
||||||
|
lis []config.ChangeListener
|
||||||
|
}
|
||||||
|
|
||||||
|
var _ config.Source = (*configSource)(nil)
|
||||||
|
|
||||||
|
// GetConfig implements config.Source.
|
||||||
|
func (src *configSource) GetConfig() *config.Config {
|
||||||
|
src.mu.Lock()
|
||||||
|
defer src.mu.Unlock()
|
||||||
|
|
||||||
|
return src.cfg
|
||||||
|
}
|
||||||
|
|
||||||
|
// OnConfigChange implements config.Source.
|
||||||
|
func (src *configSource) OnConfigChange(_ context.Context, li config.ChangeListener) {
|
||||||
|
src.mu.Lock()
|
||||||
|
defer src.mu.Unlock()
|
||||||
|
|
||||||
|
src.lis = append(src.lis, li)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ModifyConfig updates the current configuration by applying a [Modifier].
|
||||||
|
func (src *configSource) ModifyConfig(ctx context.Context, m Modifier) {
|
||||||
|
src.mu.Lock()
|
||||||
|
defer src.mu.Unlock()
|
||||||
|
|
||||||
|
m.Modify(src.cfg)
|
||||||
|
for _, li := range src.lis {
|
||||||
|
li(ctx, src.cfg)
|
||||||
|
}
|
||||||
|
}
|
391
internal/testenv/logs.go
Normal file
391
internal/testenv/logs.go
Normal file
|
@ -0,0 +1,391 @@
|
||||||
|
package testenv
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bufio"
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"crypto/tls"
|
||||||
|
"crypto/x509"
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"os"
|
||||||
|
"reflect"
|
||||||
|
"sync"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
// LogRecorder captures logs from the test environment. It can be created at
|
||||||
|
// any time by calling [Environment.NewLogRecorder], and captures logs until
|
||||||
|
// one of Close(), Logs(), or Match() is called, which stops recording. See the
|
||||||
|
// documentation for each method for more details.
|
||||||
|
type LogRecorder struct {
|
||||||
|
LogRecorderOptions
|
||||||
|
t testing.TB
|
||||||
|
canceled <-chan struct{}
|
||||||
|
buf *buffer
|
||||||
|
recordedLogs []map[string]any
|
||||||
|
|
||||||
|
removeGlobalWriterOnce func()
|
||||||
|
collectLogsOnce sync.Once
|
||||||
|
}
|
||||||
|
|
||||||
|
type LogRecorderOptions struct {
|
||||||
|
filters []func(map[string]any) bool
|
||||||
|
skipCloseDelay bool
|
||||||
|
}
|
||||||
|
|
||||||
|
type LogRecorderOption func(*LogRecorderOptions)
|
||||||
|
|
||||||
|
func (o *LogRecorderOptions) apply(opts ...LogRecorderOption) {
|
||||||
|
for _, op := range opts {
|
||||||
|
op(o)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithFilters applies one or more filter predicates to the logger. If there
|
||||||
|
// are filters present, they will be called in order when a log is received,
|
||||||
|
// and if any filter returns false for a given log, it will be discarded.
|
||||||
|
func WithFilters(filters ...func(map[string]any) bool) LogRecorderOption {
|
||||||
|
return func(o *LogRecorderOptions) {
|
||||||
|
o.filters = filters
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithSkipCloseDelay skips the 1.1 second delay before closing the recorder.
|
||||||
|
// This delay is normally required to ensure Envoy access logs are flushed,
|
||||||
|
// but can be skipped if not required.
|
||||||
|
func WithSkipCloseDelay() LogRecorderOption {
|
||||||
|
return func(o *LogRecorderOptions) {
|
||||||
|
o.skipCloseDelay = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type buffer struct {
|
||||||
|
mu *sync.Mutex
|
||||||
|
underlying bytes.Buffer
|
||||||
|
cond *sync.Cond
|
||||||
|
waiting bool
|
||||||
|
closed bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func newBuffer() *buffer {
|
||||||
|
mu := &sync.Mutex{}
|
||||||
|
return &buffer{
|
||||||
|
mu: mu,
|
||||||
|
cond: sync.NewCond(mu),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Read implements io.ReadWriteCloser.
|
||||||
|
func (b *buffer) Read(p []byte) (int, error) {
|
||||||
|
b.mu.Lock()
|
||||||
|
defer b.mu.Unlock()
|
||||||
|
for {
|
||||||
|
n, err := b.underlying.Read(p)
|
||||||
|
if errors.Is(err, io.EOF) && !b.closed {
|
||||||
|
b.waiting = true
|
||||||
|
b.cond.Wait()
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
return n, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Write implements io.ReadWriteCloser.
|
||||||
|
func (b *buffer) Write(p []byte) (int, error) {
|
||||||
|
b.mu.Lock()
|
||||||
|
defer b.mu.Unlock()
|
||||||
|
if b.closed {
|
||||||
|
return 0, io.ErrClosedPipe
|
||||||
|
}
|
||||||
|
if b.waiting {
|
||||||
|
b.waiting = false
|
||||||
|
defer b.cond.Signal()
|
||||||
|
}
|
||||||
|
return b.underlying.Write(p)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close implements io.ReadWriteCloser.
|
||||||
|
func (b *buffer) Close() error {
|
||||||
|
b.mu.Lock()
|
||||||
|
defer b.mu.Unlock()
|
||||||
|
b.closed = true
|
||||||
|
b.cond.Signal()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var _ io.ReadWriteCloser = (*buffer)(nil)
|
||||||
|
|
||||||
|
func (e *environment) NewLogRecorder(opts ...LogRecorderOption) *LogRecorder {
|
||||||
|
options := LogRecorderOptions{}
|
||||||
|
options.apply(opts...)
|
||||||
|
lr := &LogRecorder{
|
||||||
|
LogRecorderOptions: options,
|
||||||
|
t: e.t,
|
||||||
|
canceled: e.ctx.Done(),
|
||||||
|
buf: newBuffer(),
|
||||||
|
}
|
||||||
|
e.logWriter.Add(lr.buf)
|
||||||
|
lr.removeGlobalWriterOnce = sync.OnceFunc(func() {
|
||||||
|
// wait for envoy access logs, which flush on a 1 second interval
|
||||||
|
if !lr.skipCloseDelay {
|
||||||
|
time.Sleep(1100 * time.Millisecond)
|
||||||
|
}
|
||||||
|
e.logWriter.Remove(lr.buf)
|
||||||
|
})
|
||||||
|
context.AfterFunc(e.ctx, lr.removeGlobalWriterOnce)
|
||||||
|
return lr
|
||||||
|
}
|
||||||
|
|
||||||
|
type (
|
||||||
|
// OpenMap is an alias for map[string]any, and can be used to semantically
|
||||||
|
// represent a map that must contain at least the given entries, but may
|
||||||
|
// also contain additional entries.
|
||||||
|
OpenMap = map[string]any
|
||||||
|
// ClosedMap is a map[string]any that can be used to semantically represent
|
||||||
|
// a map that must contain the given entries exactly, and no others.
|
||||||
|
ClosedMap map[string]any
|
||||||
|
)
|
||||||
|
|
||||||
|
// Close stops the log recorder. After calling this method, Logs() or Match()
|
||||||
|
// can be called to inspect the logs that were captured.
|
||||||
|
func (lr *LogRecorder) Close() {
|
||||||
|
lr.removeGlobalWriterOnce()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (lr *LogRecorder) collectLogs(shouldClose bool) {
|
||||||
|
if shouldClose {
|
||||||
|
lr.removeGlobalWriterOnce()
|
||||||
|
lr.buf.Close()
|
||||||
|
}
|
||||||
|
lr.collectLogsOnce.Do(func() {
|
||||||
|
recordedLogs := []map[string]any{}
|
||||||
|
scan := bufio.NewScanner(lr.buf)
|
||||||
|
for scan.Scan() {
|
||||||
|
log := scan.Bytes()
|
||||||
|
m := map[string]any{}
|
||||||
|
decoder := json.NewDecoder(bytes.NewReader(log))
|
||||||
|
decoder.UseNumber()
|
||||||
|
require.NoError(lr.t, decoder.Decode(&m))
|
||||||
|
for _, filter := range lr.filters {
|
||||||
|
if !filter(m) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
recordedLogs = append(recordedLogs, m)
|
||||||
|
}
|
||||||
|
lr.recordedLogs = recordedLogs
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (lr *LogRecorder) WaitForMatch(expectedLog map[string]any, timeout ...time.Duration) {
|
||||||
|
lr.skipCloseDelay = true
|
||||||
|
found := make(chan struct{})
|
||||||
|
done := make(chan struct{})
|
||||||
|
lr.filters = append(lr.filters, func(entry map[string]any) bool {
|
||||||
|
select {
|
||||||
|
case <-found:
|
||||||
|
default:
|
||||||
|
if matched, _ := match(expectedLog, entry, true); matched {
|
||||||
|
close(found)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
})
|
||||||
|
go func() {
|
||||||
|
defer close(done)
|
||||||
|
lr.collectLogs(false)
|
||||||
|
lr.removeGlobalWriterOnce()
|
||||||
|
}()
|
||||||
|
if len(timeout) != 0 {
|
||||||
|
select {
|
||||||
|
case <-found:
|
||||||
|
case <-time.After(timeout[0]):
|
||||||
|
lr.t.Error("timed out waiting for log")
|
||||||
|
case <-lr.canceled:
|
||||||
|
lr.t.Error("canceled")
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
select {
|
||||||
|
case <-found:
|
||||||
|
case <-lr.canceled:
|
||||||
|
lr.t.Error("canceled")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
lr.buf.Close()
|
||||||
|
<-done
|
||||||
|
}
|
||||||
|
|
||||||
|
// Logs stops the log recorder (if it is not already stopped), then returns
|
||||||
|
// the logs that were captured as structured map[string]any objects.
|
||||||
|
func (lr *LogRecorder) Logs() []map[string]any {
|
||||||
|
lr.collectLogs(true)
|
||||||
|
return lr.recordedLogs
|
||||||
|
}
|
||||||
|
|
||||||
|
func (lr *LogRecorder) DumpToFile(file string) {
|
||||||
|
lr.collectLogs(true)
|
||||||
|
f, err := os.Create(file)
|
||||||
|
require.NoError(lr.t, err)
|
||||||
|
enc := json.NewEncoder(f)
|
||||||
|
for _, log := range lr.recordedLogs {
|
||||||
|
_ = enc.Encode(log)
|
||||||
|
}
|
||||||
|
f.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Match stops the log recorder (if it is not already stopped), then asserts
|
||||||
|
// that the given expected logs were captured. The expected logs may contain
|
||||||
|
// partial or complete log entries. By default, logs must only match the fields
|
||||||
|
// given, and may contain additional fields that will be ignored.
|
||||||
|
//
|
||||||
|
// There are several special-case value types that can be used to customize the
|
||||||
|
// matching behavior, and/or simplify some common use cases, as follows:
|
||||||
|
// - [OpenMap] and [ClosedMap] can be used to control matching logic
|
||||||
|
// - [json.Number] will convert the actual value to a string before comparison
|
||||||
|
// - [*tls.Certificate] or [*x509.Certificate] will expand to the fields that
|
||||||
|
// would be logged for this certificate
|
||||||
|
func (lr *LogRecorder) Match(expectedLogs []map[string]any) {
|
||||||
|
lr.collectLogs(true)
|
||||||
|
for _, expectedLog := range expectedLogs {
|
||||||
|
found := false
|
||||||
|
highScore, highScoreIdxs := 0, []int{}
|
||||||
|
for i, actualLog := range lr.recordedLogs {
|
||||||
|
if ok, score := match(expectedLog, actualLog, true); ok {
|
||||||
|
found = true
|
||||||
|
break
|
||||||
|
} else if score > highScore {
|
||||||
|
highScore = score
|
||||||
|
highScoreIdxs = []int{i}
|
||||||
|
} else if score == highScore {
|
||||||
|
highScoreIdxs = append(highScoreIdxs, i)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(highScoreIdxs) > 0 {
|
||||||
|
expectedLogBytes, _ := json.MarshalIndent(expectedLog, "", " ")
|
||||||
|
if len(highScoreIdxs) == 1 {
|
||||||
|
actualLogBytes, _ := json.MarshalIndent(lr.recordedLogs[highScoreIdxs[0]], "", " ")
|
||||||
|
assert.True(lr.t, found, "expected log not found: \n%s\n\nclosest match:\n%s\n",
|
||||||
|
string(expectedLogBytes), string(actualLogBytes))
|
||||||
|
} else {
|
||||||
|
closestMatches := []string{}
|
||||||
|
for _, i := range highScoreIdxs {
|
||||||
|
bytes, _ := json.MarshalIndent(lr.recordedLogs[i], "", " ")
|
||||||
|
closestMatches = append(closestMatches, string(bytes))
|
||||||
|
}
|
||||||
|
assert.True(lr.t, found, "expected log not found: \n%s\n\nclosest matches:\n%s\n", string(expectedLogBytes), closestMatches)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
expectedLogBytes, _ := json.MarshalIndent(expectedLog, "", " ")
|
||||||
|
assert.True(lr.t, found, "expected log not found: %s", string(expectedLogBytes))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func match(expected, actual map[string]any, open bool) (matched bool, score int) {
|
||||||
|
for key, value := range expected {
|
||||||
|
actualValue, ok := actual[key]
|
||||||
|
if !ok {
|
||||||
|
return false, score
|
||||||
|
}
|
||||||
|
score++
|
||||||
|
|
||||||
|
switch actualValue := actualValue.(type) {
|
||||||
|
case map[string]any:
|
||||||
|
switch expectedValue := value.(type) {
|
||||||
|
case ClosedMap:
|
||||||
|
ok, s := match(expectedValue, actualValue, false)
|
||||||
|
score += s * 2
|
||||||
|
if !ok {
|
||||||
|
return false, score
|
||||||
|
}
|
||||||
|
case OpenMap:
|
||||||
|
ok, s := match(expectedValue, actualValue, true)
|
||||||
|
score += s
|
||||||
|
if !ok {
|
||||||
|
return false, score
|
||||||
|
}
|
||||||
|
case *tls.Certificate, *Certificate, *x509.Certificate:
|
||||||
|
var leaf *x509.Certificate
|
||||||
|
switch expectedValue := expectedValue.(type) {
|
||||||
|
case *tls.Certificate:
|
||||||
|
leaf = expectedValue.Leaf
|
||||||
|
case *Certificate:
|
||||||
|
leaf = expectedValue.Leaf
|
||||||
|
case *x509.Certificate:
|
||||||
|
leaf = expectedValue
|
||||||
|
}
|
||||||
|
|
||||||
|
// keep logic consistent with controlplane.populateCertEventDict()
|
||||||
|
expected := map[string]any{}
|
||||||
|
if iss := leaf.Issuer.String(); iss != "" {
|
||||||
|
expected["issuer"] = iss
|
||||||
|
}
|
||||||
|
if sub := leaf.Subject.String(); sub != "" {
|
||||||
|
expected["subject"] = sub
|
||||||
|
}
|
||||||
|
sans := []string{}
|
||||||
|
for _, dnsSAN := range leaf.DNSNames {
|
||||||
|
sans = append(sans, "DNS:"+dnsSAN)
|
||||||
|
}
|
||||||
|
for _, uriSAN := range leaf.URIs {
|
||||||
|
sans = append(sans, "URI:"+uriSAN.String())
|
||||||
|
}
|
||||||
|
if len(sans) > 0 {
|
||||||
|
expected["subjectAltName"] = sans
|
||||||
|
}
|
||||||
|
|
||||||
|
ok, s := match(expected, actualValue, false)
|
||||||
|
score += s
|
||||||
|
if !ok {
|
||||||
|
return false, score
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
return false, score
|
||||||
|
}
|
||||||
|
case string:
|
||||||
|
switch value := value.(type) {
|
||||||
|
case string:
|
||||||
|
if value != actualValue {
|
||||||
|
return false, score
|
||||||
|
}
|
||||||
|
score++
|
||||||
|
default:
|
||||||
|
return false, score
|
||||||
|
}
|
||||||
|
case json.Number:
|
||||||
|
if fmt.Sprint(value) != actualValue.String() {
|
||||||
|
return false, score
|
||||||
|
}
|
||||||
|
score++
|
||||||
|
default:
|
||||||
|
// handle slices
|
||||||
|
if reflect.TypeOf(actualValue).Kind() == reflect.Slice {
|
||||||
|
if reflect.TypeOf(value) != reflect.TypeOf(actualValue) {
|
||||||
|
return false, score
|
||||||
|
}
|
||||||
|
actualSlice := reflect.ValueOf(actualValue)
|
||||||
|
expectedSlice := reflect.ValueOf(value)
|
||||||
|
totalScore := 0
|
||||||
|
for i := range min(actualSlice.Len(), expectedSlice.Len()) {
|
||||||
|
if actualSlice.Index(i).Equal(expectedSlice.Index(i)) {
|
||||||
|
totalScore++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
score += totalScore
|
||||||
|
} else {
|
||||||
|
panic(fmt.Sprintf("test bug: add check for type %T in assertMatchingLogs", actualValue))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !open && len(expected) != len(actual) {
|
||||||
|
return false, score
|
||||||
|
}
|
||||||
|
return true, score
|
||||||
|
}
|
76
internal/testenv/route.go
Normal file
76
internal/testenv/route.go
Normal file
|
@ -0,0 +1,76 @@
|
||||||
|
package testenv
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/url"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/pomerium/pomerium/config"
|
||||||
|
"github.com/pomerium/pomerium/internal/testenv/values"
|
||||||
|
"github.com/pomerium/pomerium/pkg/policy/parser"
|
||||||
|
)
|
||||||
|
|
||||||
|
// PolicyRoute is a [Route] implementation suitable for most common use cases
|
||||||
|
// that can be used in implementations of [Upstream].
|
||||||
|
type PolicyRoute struct {
|
||||||
|
DefaultAttach
|
||||||
|
from values.Value[string]
|
||||||
|
to values.List[string]
|
||||||
|
edits []func(*config.Policy)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Modify implements Route.
|
||||||
|
func (b *PolicyRoute) Modify(cfg *config.Config) {
|
||||||
|
to := make(config.WeightedURLs, 0, len(b.to))
|
||||||
|
for _, u := range b.to {
|
||||||
|
u, err := url.Parse(u.Value())
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
to = append(to, config.WeightedURL{URL: *u})
|
||||||
|
}
|
||||||
|
p := config.Policy{
|
||||||
|
From: b.from.Value(),
|
||||||
|
To: to,
|
||||||
|
}
|
||||||
|
for _, edit := range b.edits {
|
||||||
|
edit(&p)
|
||||||
|
}
|
||||||
|
cfg.Options.Policies = append(cfg.Options.Policies, p)
|
||||||
|
}
|
||||||
|
|
||||||
|
// From implements Route.
|
||||||
|
func (b *PolicyRoute) From(fromURL values.Value[string]) Route {
|
||||||
|
b.from = fromURL
|
||||||
|
return b
|
||||||
|
}
|
||||||
|
|
||||||
|
// To implements Route.
|
||||||
|
func (b *PolicyRoute) To(toURL values.Value[string]) Route {
|
||||||
|
b.to = append(b.to, toURL)
|
||||||
|
return b
|
||||||
|
}
|
||||||
|
|
||||||
|
// To implements Route.
|
||||||
|
func (b *PolicyRoute) Policy(edit func(*config.Policy)) Route {
|
||||||
|
b.edits = append(b.edits, edit)
|
||||||
|
return b
|
||||||
|
}
|
||||||
|
|
||||||
|
// PPL implements Route.
|
||||||
|
func (b *PolicyRoute) PPL(ppl string) Route {
|
||||||
|
pplPolicy, err := parser.ParseYAML(strings.NewReader(ppl))
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
b.edits = append(b.edits, func(p *config.Policy) {
|
||||||
|
p.Policy = &config.PPLPolicy{
|
||||||
|
Policy: pplPolicy,
|
||||||
|
}
|
||||||
|
})
|
||||||
|
return b
|
||||||
|
}
|
||||||
|
|
||||||
|
// To implements Route.
|
||||||
|
func (b *PolicyRoute) URL() values.Value[string] {
|
||||||
|
return b.from
|
||||||
|
}
|
389
internal/testenv/scenarios/mock_idp.go
Normal file
389
internal/testenv/scenarios/mock_idp.go
Normal file
|
@ -0,0 +1,389 @@
|
||||||
|
package scenarios
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto"
|
||||||
|
"crypto/ecdsa"
|
||||||
|
"crypto/elliptic"
|
||||||
|
"crypto/rand"
|
||||||
|
"encoding/base64"
|
||||||
|
"encoding/hex"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/go-jose/go-jose/v3"
|
||||||
|
"github.com/go-jose/go-jose/v3/jwt"
|
||||||
|
"github.com/google/uuid"
|
||||||
|
"github.com/pomerium/pomerium/config"
|
||||||
|
"github.com/pomerium/pomerium/internal/encoding"
|
||||||
|
"github.com/pomerium/pomerium/internal/encoding/jws"
|
||||||
|
"github.com/pomerium/pomerium/internal/log"
|
||||||
|
"github.com/pomerium/pomerium/internal/testenv"
|
||||||
|
"github.com/pomerium/pomerium/internal/testenv/upstreams"
|
||||||
|
"github.com/pomerium/pomerium/internal/testenv/values"
|
||||||
|
"github.com/pomerium/pomerium/pkg/grpc/identity"
|
||||||
|
)
|
||||||
|
|
||||||
|
type IDP struct {
|
||||||
|
id values.Value[string]
|
||||||
|
url values.Value[string]
|
||||||
|
publicJWK jose.JSONWebKey
|
||||||
|
signingKey jose.SigningKey
|
||||||
|
|
||||||
|
stateEncoder encoding.MarshalUnmarshaler
|
||||||
|
userLookup map[string]*User
|
||||||
|
}
|
||||||
|
|
||||||
|
// Attach implements testenv.Modifier.
|
||||||
|
func (idp *IDP) Attach(ctx context.Context) {
|
||||||
|
env := testenv.EnvFromContext(ctx)
|
||||||
|
|
||||||
|
router := upstreams.HTTP(nil)
|
||||||
|
|
||||||
|
idp.url = values.Bind2(env.SubdomainURL("mock-idp"), router.Port(), func(urlStr string, port int) string {
|
||||||
|
u, _ := url.Parse(urlStr)
|
||||||
|
host, _, _ := net.SplitHostPort(u.Host)
|
||||||
|
return u.ResolveReference(&url.URL{
|
||||||
|
Scheme: "http",
|
||||||
|
Host: fmt.Sprintf("%s:%d", host, port),
|
||||||
|
}).String()
|
||||||
|
})
|
||||||
|
var err error
|
||||||
|
idp.stateEncoder, err = jws.NewHS256Signer(env.SharedSecret())
|
||||||
|
env.Require().NoError(err)
|
||||||
|
|
||||||
|
idp.id = values.Bind2(idp.url, env.AuthenticateURL(), func(idpUrl, authUrl string) string {
|
||||||
|
provider := identity.Provider{
|
||||||
|
AuthenticateServiceUrl: authUrl,
|
||||||
|
ClientId: "CLIENT_ID",
|
||||||
|
ClientSecret: "CLIENT_SECRET",
|
||||||
|
Type: "oidc",
|
||||||
|
Scopes: []string{"openid", "email", "profile"},
|
||||||
|
Url: idpUrl,
|
||||||
|
}
|
||||||
|
return provider.Hash()
|
||||||
|
})
|
||||||
|
|
||||||
|
router.Handle("/.well-known/jwks.json", func(w http.ResponseWriter, _ *http.Request) {
|
||||||
|
_ = json.NewEncoder(w).Encode(&jose.JSONWebKeySet{
|
||||||
|
Keys: []jose.JSONWebKey{idp.publicJWK},
|
||||||
|
})
|
||||||
|
})
|
||||||
|
router.Handle("/.well-known/openid-configuration", func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
log.Ctx(ctx).Debug().Str("method", r.Method).Str("uri", r.RequestURI).Send()
|
||||||
|
rootURL, _ := url.Parse(idp.url.Value())
|
||||||
|
_ = json.NewEncoder(w).Encode(map[string]interface{}{
|
||||||
|
"issuer": rootURL.String(),
|
||||||
|
"authorization_endpoint": rootURL.ResolveReference(&url.URL{Path: "/oidc/auth"}).String(),
|
||||||
|
"token_endpoint": rootURL.ResolveReference(&url.URL{Path: "/oidc/token"}).String(),
|
||||||
|
"jwks_uri": rootURL.ResolveReference(&url.URL{Path: "/.well-known/jwks.json"}).String(),
|
||||||
|
"userinfo_endpoint": rootURL.ResolveReference(&url.URL{Path: "/oidc/userinfo"}).String(),
|
||||||
|
"id_token_signing_alg_values_supported": []string{
|
||||||
|
"ES256",
|
||||||
|
},
|
||||||
|
})
|
||||||
|
})
|
||||||
|
router.Handle("/oidc/auth", idp.HandleAuth)
|
||||||
|
router.Handle("/oidc/token", idp.HandleToken)
|
||||||
|
router.Handle("/oidc/userinfo", idp.HandleUserInfo)
|
||||||
|
|
||||||
|
env.AddUpstream(router)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Modify implements testenv.Modifier.
|
||||||
|
func (idp *IDP) Modify(cfg *config.Config) {
|
||||||
|
cfg.Options.Provider = "oidc"
|
||||||
|
cfg.Options.ProviderURL = idp.url.Value()
|
||||||
|
cfg.Options.ClientID = "CLIENT_ID"
|
||||||
|
cfg.Options.ClientSecret = "CLIENT_SECRET"
|
||||||
|
cfg.Options.Scopes = []string{"openid", "email", "profile"}
|
||||||
|
}
|
||||||
|
|
||||||
|
var _ testenv.Modifier = (*IDP)(nil)
|
||||||
|
|
||||||
|
func NewIDP(users []*User) *IDP {
|
||||||
|
privateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
publicKey := &privateKey.PublicKey
|
||||||
|
|
||||||
|
signingKey := jose.SigningKey{
|
||||||
|
Algorithm: jose.ES256,
|
||||||
|
Key: privateKey,
|
||||||
|
}
|
||||||
|
publicJWK := jose.JSONWebKey{
|
||||||
|
Key: publicKey,
|
||||||
|
Algorithm: string(jose.ES256),
|
||||||
|
Use: "sig",
|
||||||
|
}
|
||||||
|
thumbprint, err := publicJWK.Thumbprint(crypto.SHA256)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
publicJWK.KeyID = hex.EncodeToString(thumbprint)
|
||||||
|
|
||||||
|
userLookup := map[string]*User{}
|
||||||
|
for _, user := range users {
|
||||||
|
user.ID = uuid.NewString()
|
||||||
|
userLookup[user.ID] = user
|
||||||
|
}
|
||||||
|
return &IDP{
|
||||||
|
publicJWK: publicJWK,
|
||||||
|
signingKey: signingKey,
|
||||||
|
userLookup: userLookup,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// HandleAuth handles the auth flow for OIDC.
|
||||||
|
func (idp *IDP) HandleAuth(w http.ResponseWriter, r *http.Request) {
|
||||||
|
rawRedirectURI := r.FormValue("redirect_uri")
|
||||||
|
if rawRedirectURI == "" {
|
||||||
|
http.Error(w, "missing redirect_uri", http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
redirectURI, err := url.Parse(rawRedirectURI)
|
||||||
|
if err != nil {
|
||||||
|
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
rawClientID := r.FormValue("client_id")
|
||||||
|
if rawClientID == "" {
|
||||||
|
http.Error(w, "missing client_id", http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
rawEmail := r.FormValue("email")
|
||||||
|
if rawEmail != "" {
|
||||||
|
http.Redirect(w, r, redirectURI.ResolveReference(&url.URL{
|
||||||
|
RawQuery: (url.Values{
|
||||||
|
"state": {r.FormValue("state")},
|
||||||
|
"code": {State{
|
||||||
|
Email: rawEmail,
|
||||||
|
ClientID: rawClientID,
|
||||||
|
}.Encode()},
|
||||||
|
}).Encode(),
|
||||||
|
}).String(), http.StatusFound)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
serveHTML(w, `<!doctype html>
|
||||||
|
<html>
|
||||||
|
<head>
|
||||||
|
<title>Login</title>
|
||||||
|
</head>
|
||||||
|
<body>
|
||||||
|
<form method="POST" style="max-width: 200px">
|
||||||
|
<fieldset>
|
||||||
|
<legend>Login</legend>
|
||||||
|
|
||||||
|
<table>
|
||||||
|
<tbody>
|
||||||
|
<tr>
|
||||||
|
<th><label for="email">Email</label></th>
|
||||||
|
<td>
|
||||||
|
<input type="email" name="email" placeholder="email" />
|
||||||
|
</td>
|
||||||
|
</tr>
|
||||||
|
<tr>
|
||||||
|
<td colspan="2">
|
||||||
|
<input type="submit" />
|
||||||
|
</td>
|
||||||
|
</tr>
|
||||||
|
</tbody>
|
||||||
|
</table>
|
||||||
|
|
||||||
|
</fieldset>
|
||||||
|
</form>
|
||||||
|
</body>
|
||||||
|
</html>
|
||||||
|
`)
|
||||||
|
}
|
||||||
|
|
||||||
|
// HandleToken handles the token flow for OIDC.
|
||||||
|
func (idp *IDP) HandleToken(w http.ResponseWriter, r *http.Request) {
|
||||||
|
rawCode := r.FormValue("code")
|
||||||
|
|
||||||
|
state, err := DecodeState(rawCode)
|
||||||
|
if err != nil {
|
||||||
|
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
serveJSON(w, map[string]interface{}{
|
||||||
|
"access_token": state.Encode(),
|
||||||
|
"refresh_token": state.Encode(),
|
||||||
|
"token_type": "Bearer",
|
||||||
|
"id_token": state.GetIDToken(r, idp.userLookup).Encode(idp.signingKey),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// HandleUserInfo handles retrieving the user info.
|
||||||
|
func (idp *IDP) HandleUserInfo(w http.ResponseWriter, r *http.Request) {
|
||||||
|
authz := r.Header.Get("Authorization")
|
||||||
|
if authz == "" {
|
||||||
|
http.Error(w, "missing authorization header", http.StatusUnauthorized)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if strings.HasPrefix(authz, "Bearer ") {
|
||||||
|
authz = authz[len("Bearer "):]
|
||||||
|
} else if strings.HasPrefix(authz, "token ") {
|
||||||
|
authz = authz[len("token "):]
|
||||||
|
} else {
|
||||||
|
http.Error(w, "missing bearer token", http.StatusUnauthorized)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
state, err := DecodeState(authz)
|
||||||
|
if err != nil {
|
||||||
|
http.Error(w, err.Error(), http.StatusForbidden)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
serveJSON(w, state.GetUserInfo(idp.userLookup))
|
||||||
|
}
|
||||||
|
|
||||||
|
type RootURLKey struct{}
|
||||||
|
|
||||||
|
var rootURLKey RootURLKey
|
||||||
|
|
||||||
|
// WithRootURL sets the Root URL in a context.
|
||||||
|
func WithRootURL(ctx context.Context, rootURL *url.URL) context.Context {
|
||||||
|
return context.WithValue(ctx, rootURLKey, rootURL)
|
||||||
|
}
|
||||||
|
|
||||||
|
func getRootURL(r *http.Request) *url.URL {
|
||||||
|
if u, ok := r.Context().Value(rootURLKey).(*url.URL); ok {
|
||||||
|
return u
|
||||||
|
}
|
||||||
|
|
||||||
|
u := *r.URL
|
||||||
|
if r.Host != "" {
|
||||||
|
u.Host = r.Host
|
||||||
|
}
|
||||||
|
if u.Scheme == "" {
|
||||||
|
if r.TLS != nil {
|
||||||
|
u.Scheme = "https"
|
||||||
|
} else {
|
||||||
|
u.Scheme = "http"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
u.Path = ""
|
||||||
|
return &u
|
||||||
|
}
|
||||||
|
|
||||||
|
func serveHTML(w http.ResponseWriter, html string) {
|
||||||
|
w.Header().Set("Content-Type", "text/html")
|
||||||
|
w.Header().Set("Content-Length", strconv.Itoa(len(html)))
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
_, _ = io.WriteString(w, html)
|
||||||
|
}
|
||||||
|
|
||||||
|
func serveJSON(w http.ResponseWriter, obj interface{}) {
|
||||||
|
bs, err := json.Marshal(obj)
|
||||||
|
if err != nil {
|
||||||
|
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
_, _ = w.Write(bs)
|
||||||
|
}
|
||||||
|
|
||||||
|
type State struct {
|
||||||
|
Email string `json:"email"`
|
||||||
|
ClientID string `json:"client_id"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func DecodeState(rawCode string) (*State, error) {
|
||||||
|
var state State
|
||||||
|
bs, _ := base64.URLEncoding.DecodeString(rawCode)
|
||||||
|
err := json.Unmarshal(bs, &state)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &state, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (state State) Encode() string {
|
||||||
|
bs, _ := json.Marshal(state)
|
||||||
|
return base64.URLEncoding.EncodeToString(bs)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (state State) GetIDToken(r *http.Request, users map[string]*User) *IDToken {
|
||||||
|
token := &IDToken{
|
||||||
|
UserInfo: state.GetUserInfo(users),
|
||||||
|
|
||||||
|
Issuer: getRootURL(r).String(),
|
||||||
|
Audience: state.ClientID,
|
||||||
|
Expiry: jwt.NewNumericDate(time.Now().Add(time.Hour * 24 * 365)),
|
||||||
|
IssuedAt: jwt.NewNumericDate(time.Now()),
|
||||||
|
}
|
||||||
|
return token
|
||||||
|
}
|
||||||
|
|
||||||
|
func (state State) GetUserInfo(users map[string]*User) *UserInfo {
|
||||||
|
userInfo := &UserInfo{
|
||||||
|
Subject: state.Email,
|
||||||
|
Email: state.Email,
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, u := range users {
|
||||||
|
if u.Email == state.Email {
|
||||||
|
userInfo.Subject = u.ID
|
||||||
|
userInfo.Name = u.FirstName + " " + u.LastName
|
||||||
|
userInfo.FamilyName = u.LastName
|
||||||
|
userInfo.GivenName = u.FirstName
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return userInfo
|
||||||
|
}
|
||||||
|
|
||||||
|
type UserInfo struct {
|
||||||
|
Subject string `json:"sub"`
|
||||||
|
Name string `json:"name"`
|
||||||
|
Email string `json:"email"`
|
||||||
|
FamilyName string `json:"family_name"`
|
||||||
|
GivenName string `json:"given_name"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type IDToken struct {
|
||||||
|
*UserInfo
|
||||||
|
|
||||||
|
Issuer string `json:"iss"`
|
||||||
|
Audience string `json:"aud"`
|
||||||
|
Expiry *jwt.NumericDate `json:"exp"`
|
||||||
|
IssuedAt *jwt.NumericDate `json:"iat"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (token *IDToken) Encode(signingKey jose.SigningKey) string {
|
||||||
|
sig, err := jose.NewSigner(signingKey, (&jose.SignerOptions{}).WithType("JWT"))
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
str, err := jwt.Signed(sig).Claims(token).CompactSerialize()
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
return str
|
||||||
|
}
|
||||||
|
|
||||||
|
type User struct {
|
||||||
|
ID string
|
||||||
|
Email string
|
||||||
|
FirstName string
|
||||||
|
LastName string
|
||||||
|
}
|
24
internal/testenv/scenarios/mtls.go
Normal file
24
internal/testenv/scenarios/mtls.go
Normal file
|
@ -0,0 +1,24 @@
|
||||||
|
package scenarios
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/base64"
|
||||||
|
"encoding/pem"
|
||||||
|
|
||||||
|
"github.com/pomerium/pomerium/config"
|
||||||
|
"github.com/pomerium/pomerium/internal/testenv"
|
||||||
|
)
|
||||||
|
|
||||||
|
func DownstreamMTLS(mode config.MTLSEnforcement) testenv.Modifier {
|
||||||
|
return testenv.ModifierFunc(func(ctx context.Context, cfg *config.Config) {
|
||||||
|
env := testenv.EnvFromContext(ctx)
|
||||||
|
block := pem.Block{
|
||||||
|
Type: "CERTIFICATE",
|
||||||
|
Bytes: env.CACert().Leaf.Raw,
|
||||||
|
}
|
||||||
|
cfg.Options.DownstreamMTLS = config.DownstreamMTLSSettings{
|
||||||
|
CA: base64.StdEncoding.EncodeToString(pem.EncodeToMemory(&block)),
|
||||||
|
Enforcement: mode,
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
64
internal/testenv/snippets/routes.go
Normal file
64
internal/testenv/snippets/routes.go
Normal file
|
@ -0,0 +1,64 @@
|
||||||
|
package snippets
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"strings"
|
||||||
|
"text/template"
|
||||||
|
|
||||||
|
"github.com/pomerium/pomerium/config"
|
||||||
|
"github.com/pomerium/pomerium/internal/testenv"
|
||||||
|
"github.com/pomerium/pomerium/pkg/policy/parser"
|
||||||
|
)
|
||||||
|
|
||||||
|
var SimplePolicyTemplate = PolicyTemplate{
|
||||||
|
From: "https://from-{{.Idx}}.localhost",
|
||||||
|
To: "https://to-{{.Idx}}.localhost",
|
||||||
|
PPL: `{"allow":{"and":["email":{"is":"user-{{.Idx}}@example.com"}]}}`,
|
||||||
|
}
|
||||||
|
|
||||||
|
type PolicyTemplate struct {
|
||||||
|
From string
|
||||||
|
To string
|
||||||
|
PPL string
|
||||||
|
|
||||||
|
// Add more fields as needed (be sure to update newPolicyFromTemplate)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TemplateRoutes(n int, tmpl PolicyTemplate) testenv.Modifier {
|
||||||
|
return testenv.ModifierFunc(func(_ context.Context, cfg *config.Config) {
|
||||||
|
for i := range n {
|
||||||
|
cfg.Options.Policies = append(cfg.Options.Policies, newPolicyFromTemplate(i, tmpl))
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func newPolicyFromTemplate(i int, pt PolicyTemplate) config.Policy {
|
||||||
|
eval := func(in string) string {
|
||||||
|
t := template.New("policy")
|
||||||
|
tmpl, err := t.Parse(in)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
var out bytes.Buffer
|
||||||
|
if err := tmpl.Execute(&out, struct{ Idx int }{i}); err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
return out.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
pplPolicy, err := parser.ParseYAML(strings.NewReader(eval(pt.PPL)))
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
to, err := config.ParseWeightedUrls(eval(pt.To))
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
return config.Policy{
|
||||||
|
From: eval(pt.From),
|
||||||
|
To: to,
|
||||||
|
Policy: &config.PPLPolicy{Policy: pplPolicy},
|
||||||
|
}
|
||||||
|
}
|
35
internal/testenv/snippets/wait.go
Normal file
35
internal/testenv/snippets/wait.go
Normal file
|
@ -0,0 +1,35 @@
|
||||||
|
package snippets
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/pomerium/pomerium/internal/testenv"
|
||||||
|
"github.com/pomerium/pomerium/pkg/grpcutil"
|
||||||
|
"google.golang.org/grpc"
|
||||||
|
"google.golang.org/grpc/connectivity"
|
||||||
|
"google.golang.org/grpc/credentials/insecure"
|
||||||
|
)
|
||||||
|
|
||||||
|
func WaitStartupComplete(env testenv.Environment, timeout ...time.Duration) time.Duration {
|
||||||
|
start := time.Now()
|
||||||
|
recorder := env.NewLogRecorder()
|
||||||
|
if len(timeout) == 0 {
|
||||||
|
timeout = append(timeout, 1*time.Minute)
|
||||||
|
}
|
||||||
|
ctx, ca := context.WithTimeout(env.Context(), timeout[0])
|
||||||
|
defer ca()
|
||||||
|
recorder.WaitForMatch(map[string]any{
|
||||||
|
"syncer_id": "databroker",
|
||||||
|
"syncer_type": "type.googleapis.com/pomerium.config.Config",
|
||||||
|
"message": "listening for updates",
|
||||||
|
}, timeout...)
|
||||||
|
cc, err := grpc.Dial(env.DatabrokerURL().Value(),
|
||||||
|
grpc.WithTransportCredentials(insecure.NewCredentials()),
|
||||||
|
grpc.WithChainUnaryInterceptor(grpcutil.WithUnarySignedJWT(env.SharedSecret)),
|
||||||
|
grpc.WithChainStreamInterceptor(grpcutil.WithStreamSignedJWT(env.SharedSecret)),
|
||||||
|
)
|
||||||
|
env.Require().NoError(err)
|
||||||
|
env.Require().True(cc.WaitForStateChange(ctx, connectivity.Ready))
|
||||||
|
return time.Since(start)
|
||||||
|
}
|
206
internal/testenv/types.go
Normal file
206
internal/testenv/types.go
Normal file
|
@ -0,0 +1,206 @@
|
||||||
|
package testenv
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
|
||||||
|
"github.com/pomerium/pomerium/config"
|
||||||
|
"github.com/pomerium/pomerium/internal/testenv/values"
|
||||||
|
)
|
||||||
|
|
||||||
|
type envContextKeyType struct{}
|
||||||
|
|
||||||
|
var envContextKey envContextKeyType
|
||||||
|
|
||||||
|
func EnvFromContext(ctx context.Context) Environment {
|
||||||
|
return ctx.Value(envContextKey).(Environment)
|
||||||
|
}
|
||||||
|
|
||||||
|
func ContextWithEnv(ctx context.Context, env Environment) context.Context {
|
||||||
|
return context.WithValue(ctx, envContextKey, env)
|
||||||
|
}
|
||||||
|
|
||||||
|
// A Modifier is an object whose presence in the test affects the Pomerium
|
||||||
|
// configuration in some way. When the test environment is started, a
|
||||||
|
// [*config.Config] is constructed by calling each added Modifier in order.
|
||||||
|
//
|
||||||
|
// For additional details, see [Environment.Add] and [Environment.Start].
|
||||||
|
type Modifier interface {
|
||||||
|
// Attach is called by an [Environment] (before Modify) to propagate the
|
||||||
|
// environment's context.
|
||||||
|
Attach(ctx context.Context)
|
||||||
|
|
||||||
|
// Modify is called by an [Environment] to mutate its configuration in some
|
||||||
|
// way required by this Modifier.
|
||||||
|
Modify(cfg *config.Config)
|
||||||
|
}
|
||||||
|
|
||||||
|
// DefaultAttach should be embedded in types implementing [Modifier] to
|
||||||
|
// automatically obtain environment context details and caller information.
|
||||||
|
type DefaultAttach struct {
|
||||||
|
env Environment
|
||||||
|
caller string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *DefaultAttach) Env() Environment {
|
||||||
|
d.CheckAttached()
|
||||||
|
return d.env
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *DefaultAttach) Attach(ctx context.Context) {
|
||||||
|
if d.env != nil {
|
||||||
|
panic("internal test environment bug: Attach called twice")
|
||||||
|
}
|
||||||
|
d.env = EnvFromContext(ctx)
|
||||||
|
if d.env == nil {
|
||||||
|
panic("test bug: no environment in context")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *DefaultAttach) CheckAttached() {
|
||||||
|
if d.env == nil {
|
||||||
|
if d.caller != "" {
|
||||||
|
panic("test bug: missing a call to Add for the object created at: " + d.caller)
|
||||||
|
}
|
||||||
|
panic("test bug: not attached (possibly missing a call to Add)")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *DefaultAttach) RecordCaller() {
|
||||||
|
d.caller = getCaller(4)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Aggregate should be embedded in types implementing [Modifier] when the type
|
||||||
|
// contains other modifiers. Used as an alternative to [DefaultAttach].
|
||||||
|
// Embedding this struct will properly keep track of when constituent modifiers
|
||||||
|
// are added, for validation and caller detection.
|
||||||
|
//
|
||||||
|
// Aggregate implements a no-op Modify() by default, but this can be overridden
|
||||||
|
// to make additional modifications. The aggregate's Modify() is called first.
|
||||||
|
type Aggregate struct {
|
||||||
|
env Environment
|
||||||
|
caller string
|
||||||
|
modifiers []Modifier
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *Aggregate) Add(mod Modifier) {
|
||||||
|
if d.env != nil {
|
||||||
|
if d.env.(*environment).getState() == NotRunning {
|
||||||
|
// If the test environment is running, adding to an aggregate is a no-op.
|
||||||
|
// If the test environment has not been started yet, the aggregate is
|
||||||
|
// being used like in the following example, which is incorrect:
|
||||||
|
//
|
||||||
|
// aggregate.Add(foo)
|
||||||
|
// env.Add(aggregate)
|
||||||
|
// aggregate.Add(bar)
|
||||||
|
// env.Start()
|
||||||
|
//
|
||||||
|
// It should instead be used like this:
|
||||||
|
//
|
||||||
|
// aggregate.Add(foo)
|
||||||
|
// aggregate.Add(bar)
|
||||||
|
// env.Add(aggregate)
|
||||||
|
// env.Start()
|
||||||
|
panic("test bug: cannot modify an aggregate that has already been added")
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
d.modifiers = append(d.modifiers, mod)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *Aggregate) Env() Environment {
|
||||||
|
d.CheckAttached()
|
||||||
|
return d.env
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *Aggregate) Attach(ctx context.Context) {
|
||||||
|
if d.env != nil {
|
||||||
|
panic("internal test environment bug: Attach called twice")
|
||||||
|
}
|
||||||
|
d.env = EnvFromContext(ctx)
|
||||||
|
if d.env == nil {
|
||||||
|
panic("test bug: no environment in context")
|
||||||
|
}
|
||||||
|
d.env.(*environment).t.Helper()
|
||||||
|
for _, mod := range d.modifiers {
|
||||||
|
d.env.Add(mod)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *Aggregate) Modify(*config.Config) {}
|
||||||
|
|
||||||
|
func (d *Aggregate) CheckAttached() {
|
||||||
|
if d.env == nil {
|
||||||
|
if d.caller != "" {
|
||||||
|
panic("test bug: missing a call to Add for the object created at: " + d.caller)
|
||||||
|
}
|
||||||
|
panic("test bug: not attached (possibly missing a call to Add)")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *Aggregate) RecordCaller() {
|
||||||
|
d.caller = getCaller(4)
|
||||||
|
}
|
||||||
|
|
||||||
|
type modifierFunc struct {
|
||||||
|
fn func(ctx context.Context, cfg *config.Config)
|
||||||
|
ctx context.Context
|
||||||
|
}
|
||||||
|
|
||||||
|
// Attach implements Modifier.
|
||||||
|
func (f *modifierFunc) Attach(ctx context.Context) {
|
||||||
|
f.ctx = ctx
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *modifierFunc) Modify(cfg *config.Config) {
|
||||||
|
f.fn(f.ctx, cfg)
|
||||||
|
}
|
||||||
|
|
||||||
|
var _ Modifier = (*modifierFunc)(nil)
|
||||||
|
|
||||||
|
func ModifierFunc(fn func(ctx context.Context, cfg *config.Config)) Modifier {
|
||||||
|
return &modifierFunc{fn: fn}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Task represents a background task that can be added to an [Environment] to
|
||||||
|
// have it run automatically on startup.
|
||||||
|
//
|
||||||
|
// For additional details, see [Environment.AddTask] and [Environment.Start].
|
||||||
|
type Task interface {
|
||||||
|
Run(ctx context.Context) error
|
||||||
|
}
|
||||||
|
|
||||||
|
type TaskFunc func(ctx context.Context) error
|
||||||
|
|
||||||
|
func (f TaskFunc) Run(ctx context.Context) error {
|
||||||
|
return f(ctx)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Upstream represents an upstream server. It is both a [Task] and a [Modifier]
|
||||||
|
// and can be added to an environment using [Environment.AddUpstream]. From an
|
||||||
|
// Upstream instance, new routes can be created (which automatically adds the
|
||||||
|
// necessary route/policy entries to the config), and used within a test to
|
||||||
|
// easily make requests to the routes with implementation-specific clients.
|
||||||
|
type Upstream interface {
|
||||||
|
Modifier
|
||||||
|
Task
|
||||||
|
Port() values.Value[int]
|
||||||
|
Route() RouteStub
|
||||||
|
}
|
||||||
|
|
||||||
|
// A Route represents a route from a source URL to a destination URL. A route is
|
||||||
|
// typically created by calling [Upstream.Route].
|
||||||
|
type Route interface {
|
||||||
|
Modifier
|
||||||
|
URL() values.Value[string]
|
||||||
|
To(toURL values.Value[string]) Route
|
||||||
|
Policy(edit func(*config.Policy)) Route
|
||||||
|
PPL(ppl string) Route
|
||||||
|
// add more methods here as they become needed
|
||||||
|
}
|
||||||
|
|
||||||
|
// RouteStub represents an incomplete [Route]. Providing a URL by calling its
|
||||||
|
// From() method will return a [Route], from which further configuration can
|
||||||
|
// be made.
|
||||||
|
type RouteStub interface {
|
||||||
|
From(fromURL values.Value[string]) Route
|
||||||
|
}
|
139
internal/testenv/upstreams/grpc.go
Normal file
139
internal/testenv/upstreams/grpc.go
Normal file
|
@ -0,0 +1,139 @@
|
||||||
|
package upstreams
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/pomerium/pomerium/internal/testenv"
|
||||||
|
"github.com/pomerium/pomerium/internal/testenv/values"
|
||||||
|
"google.golang.org/grpc"
|
||||||
|
"google.golang.org/grpc/credentials"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Options struct {
|
||||||
|
serverOpts []grpc.ServerOption
|
||||||
|
}
|
||||||
|
|
||||||
|
type Option func(*Options)
|
||||||
|
|
||||||
|
func (o *Options) apply(opts ...Option) {
|
||||||
|
for _, op := range opts {
|
||||||
|
op(o)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func ServerOpts(opt ...grpc.ServerOption) Option {
|
||||||
|
return func(o *Options) {
|
||||||
|
o.serverOpts = append(o.serverOpts, opt...)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// GRPCUpstream represents a GRPC server which can be used as the target for
|
||||||
|
// one or more Pomerium routes in a test environment.
|
||||||
|
//
|
||||||
|
// This upstream implements [grpc.ServiceRegistrar], and can be used similarly
|
||||||
|
// in the same way as [*grpc.Server] to register services before it is started.
|
||||||
|
//
|
||||||
|
// Any [testenv.Route] instances created from this upstream can be referenced
|
||||||
|
// in the Dial() method to establish a connection to that route.
|
||||||
|
type GRPCUpstream interface {
|
||||||
|
testenv.Upstream
|
||||||
|
grpc.ServiceRegistrar
|
||||||
|
Dial(r testenv.Route, dialOpts ...grpc.DialOption) *grpc.ClientConn
|
||||||
|
}
|
||||||
|
|
||||||
|
type grpcUpstream struct {
|
||||||
|
Options
|
||||||
|
testenv.Aggregate
|
||||||
|
serverPort values.MutableValue[int]
|
||||||
|
creds credentials.TransportCredentials
|
||||||
|
|
||||||
|
services []service
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
_ testenv.Upstream = (*grpcUpstream)(nil)
|
||||||
|
_ grpc.ServiceRegistrar = (*grpcUpstream)(nil)
|
||||||
|
)
|
||||||
|
|
||||||
|
// GRPC creates a new GRPC upstream server.
|
||||||
|
func GRPC(creds credentials.TransportCredentials, opts ...Option) GRPCUpstream {
|
||||||
|
options := Options{}
|
||||||
|
options.apply(opts...)
|
||||||
|
up := &grpcUpstream{
|
||||||
|
Options: options,
|
||||||
|
creds: creds,
|
||||||
|
serverPort: values.Deferred[int](),
|
||||||
|
}
|
||||||
|
up.RecordCaller()
|
||||||
|
return up
|
||||||
|
}
|
||||||
|
|
||||||
|
type service struct {
|
||||||
|
desc *grpc.ServiceDesc
|
||||||
|
impl any
|
||||||
|
}
|
||||||
|
|
||||||
|
func (g *grpcUpstream) Port() values.Value[int] {
|
||||||
|
return g.serverPort
|
||||||
|
}
|
||||||
|
|
||||||
|
// RegisterService implements grpc.ServiceRegistrar.
|
||||||
|
func (g *grpcUpstream) RegisterService(desc *grpc.ServiceDesc, impl any) {
|
||||||
|
g.services = append(g.services, service{desc, impl})
|
||||||
|
}
|
||||||
|
|
||||||
|
// Route implements testenv.Upstream.
|
||||||
|
func (g *grpcUpstream) Route() testenv.RouteStub {
|
||||||
|
r := &testenv.PolicyRoute{}
|
||||||
|
var protocol string
|
||||||
|
switch g.creds.Info().SecurityProtocol {
|
||||||
|
case "insecure":
|
||||||
|
protocol = "h2c"
|
||||||
|
default:
|
||||||
|
protocol = "https"
|
||||||
|
}
|
||||||
|
r.To(values.Bind(g.serverPort, func(port int) string {
|
||||||
|
return fmt.Sprintf("%s://127.0.0.1:%d", protocol, port)
|
||||||
|
}))
|
||||||
|
g.Add(r)
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
|
||||||
|
// Start implements testenv.Upstream.
|
||||||
|
func (g *grpcUpstream) Run(ctx context.Context) error {
|
||||||
|
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
g.serverPort.Resolve(listener.Addr().(*net.TCPAddr).Port)
|
||||||
|
server := grpc.NewServer(append(g.serverOpts, grpc.Creds(g.creds))...)
|
||||||
|
for _, s := range g.services {
|
||||||
|
server.RegisterService(s.desc, s.impl)
|
||||||
|
}
|
||||||
|
errC := make(chan error, 1)
|
||||||
|
go func() {
|
||||||
|
errC <- server.Serve(listener)
|
||||||
|
}()
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
server.Stop()
|
||||||
|
return context.Cause(ctx)
|
||||||
|
case err := <-errC:
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (g *grpcUpstream) Dial(r testenv.Route, dialOpts ...grpc.DialOption) *grpc.ClientConn {
|
||||||
|
dialOpts = append(dialOpts,
|
||||||
|
grpc.WithTransportCredentials(credentials.NewClientTLSFromCert(g.Env().ServerCAs(), "")),
|
||||||
|
grpc.WithDefaultCallOptions(grpc.WaitForReady(true)),
|
||||||
|
)
|
||||||
|
cc, err := grpc.NewClient(strings.TrimPrefix(r.URL().Value(), "https://"), dialOpts...)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
return cc
|
||||||
|
}
|
327
internal/testenv/upstreams/http.go
Normal file
327
internal/testenv/upstreams/http.go
Normal file
|
@ -0,0 +1,327 @@
|
||||||
|
package upstreams
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"crypto/tls"
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"net/http/cookiejar"
|
||||||
|
"net/url"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/gorilla/mux"
|
||||||
|
"github.com/pomerium/pomerium/integration/forms"
|
||||||
|
"github.com/pomerium/pomerium/internal/retry"
|
||||||
|
"github.com/pomerium/pomerium/internal/testenv"
|
||||||
|
"github.com/pomerium/pomerium/internal/testenv/values"
|
||||||
|
"google.golang.org/protobuf/proto"
|
||||||
|
)
|
||||||
|
|
||||||
|
type RequestOptions struct {
|
||||||
|
path string
|
||||||
|
query url.Values
|
||||||
|
headers map[string]string
|
||||||
|
authenticateAs string
|
||||||
|
body any
|
||||||
|
clientCerts []tls.Certificate
|
||||||
|
client *http.Client
|
||||||
|
}
|
||||||
|
|
||||||
|
type RequestOption func(*RequestOptions)
|
||||||
|
|
||||||
|
func (o *RequestOptions) apply(opts ...RequestOption) {
|
||||||
|
for _, op := range opts {
|
||||||
|
op(o)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Path sets the path of the request. If omitted, the request URL will match
|
||||||
|
// the route URL exactly.
|
||||||
|
func Path(path string) RequestOption {
|
||||||
|
return func(o *RequestOptions) {
|
||||||
|
o.path = path
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Query sets optional query parameters of the request.
|
||||||
|
func Query(query url.Values) RequestOption {
|
||||||
|
return func(o *RequestOptions) {
|
||||||
|
o.query = query
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Headers adds optional headers to the request.
|
||||||
|
func Headers(headers map[string]string) RequestOption {
|
||||||
|
return func(o *RequestOptions) {
|
||||||
|
o.headers = headers
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func AuthenticateAs(email string) RequestOption {
|
||||||
|
return func(o *RequestOptions) {
|
||||||
|
o.authenticateAs = email
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func Client(c *http.Client) RequestOption {
|
||||||
|
return func(o *RequestOptions) {
|
||||||
|
o.client = c
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Body sets the body of the request.
|
||||||
|
// The argument can be one of the following types:
|
||||||
|
// - string
|
||||||
|
// - []byte
|
||||||
|
// - io.Reader
|
||||||
|
// - proto.Message
|
||||||
|
// - any json-encodable type
|
||||||
|
// If the argument is encoded as json, the Content-Type header will be set to
|
||||||
|
// "application/json". If the argument is a proto.Message, the Content-Type
|
||||||
|
// header will be set to "application/octet-stream".
|
||||||
|
func Body(body any) RequestOption {
|
||||||
|
return func(o *RequestOptions) {
|
||||||
|
o.body = body
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ClientCert adds a client certificate to the request.
|
||||||
|
func ClientCert[T interface {
|
||||||
|
*testenv.Certificate | *tls.Certificate
|
||||||
|
}](cert T) RequestOption {
|
||||||
|
return func(o *RequestOptions) {
|
||||||
|
o.clientCerts = append(o.clientCerts, *(*tls.Certificate)(cert))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// HTTPUpstream represents a HTTP server which can be used as the target for
|
||||||
|
// one or more Pomerium routes in a test environment.
|
||||||
|
//
|
||||||
|
// The Handle() method can be used to add handlers the server-side HTTP router,
|
||||||
|
// while the Get(), Post(), and (generic) Do() methods can be used to make
|
||||||
|
// client-side requests.
|
||||||
|
type HTTPUpstream interface {
|
||||||
|
testenv.Upstream
|
||||||
|
|
||||||
|
Handle(path string, f func(http.ResponseWriter, *http.Request)) *mux.Route
|
||||||
|
|
||||||
|
Get(r testenv.Route, opts ...RequestOption) (*http.Response, error)
|
||||||
|
Post(r testenv.Route, opts ...RequestOption) (*http.Response, error)
|
||||||
|
Do(method string, r testenv.Route, opts ...RequestOption) (*http.Response, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
type httpUpstream struct {
|
||||||
|
testenv.Aggregate
|
||||||
|
serverPort values.MutableValue[int]
|
||||||
|
tlsConfig values.Value[*tls.Config]
|
||||||
|
|
||||||
|
clientCache sync.Map // map[testenv.Route]*http.Client
|
||||||
|
|
||||||
|
router *mux.Router
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
_ testenv.Upstream = (*httpUpstream)(nil)
|
||||||
|
_ HTTPUpstream = (*httpUpstream)(nil)
|
||||||
|
)
|
||||||
|
|
||||||
|
// HTTP creates a new HTTP upstream server.
|
||||||
|
func HTTP(tlsConfig values.Value[*tls.Config]) HTTPUpstream {
|
||||||
|
up := &httpUpstream{
|
||||||
|
serverPort: values.Deferred[int](),
|
||||||
|
router: mux.NewRouter(),
|
||||||
|
tlsConfig: tlsConfig,
|
||||||
|
}
|
||||||
|
up.RecordCaller()
|
||||||
|
return up
|
||||||
|
}
|
||||||
|
|
||||||
|
// Port implements HTTPUpstream.
|
||||||
|
func (h *httpUpstream) Port() values.Value[int] {
|
||||||
|
return h.serverPort
|
||||||
|
}
|
||||||
|
|
||||||
|
// Router implements HTTPUpstream.
|
||||||
|
func (h *httpUpstream) Handle(path string, f func(http.ResponseWriter, *http.Request)) *mux.Route {
|
||||||
|
return h.router.HandleFunc(path, f)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Route implements HTTPUpstream.
|
||||||
|
func (h *httpUpstream) Route() testenv.RouteStub {
|
||||||
|
r := &testenv.PolicyRoute{}
|
||||||
|
protocol := "http"
|
||||||
|
r.To(values.Bind(h.serverPort, func(port int) string {
|
||||||
|
return fmt.Sprintf("%s://127.0.0.1:%d", protocol, port)
|
||||||
|
}))
|
||||||
|
h.Add(r)
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
|
||||||
|
// Run implements HTTPUpstream.
|
||||||
|
func (h *httpUpstream) Run(ctx context.Context) error {
|
||||||
|
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
h.serverPort.Resolve(listener.Addr().(*net.TCPAddr).Port)
|
||||||
|
var tlsConfig *tls.Config
|
||||||
|
if h.tlsConfig != nil {
|
||||||
|
tlsConfig = h.tlsConfig.Value()
|
||||||
|
}
|
||||||
|
server := &http.Server{
|
||||||
|
Handler: h.router,
|
||||||
|
TLSConfig: tlsConfig,
|
||||||
|
BaseContext: func(net.Listener) context.Context {
|
||||||
|
return ctx
|
||||||
|
},
|
||||||
|
}
|
||||||
|
errC := make(chan error, 1)
|
||||||
|
go func() {
|
||||||
|
errC <- server.Serve(listener)
|
||||||
|
}()
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
server.Close()
|
||||||
|
return context.Cause(ctx)
|
||||||
|
case err := <-errC:
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get implements HTTPUpstream.
|
||||||
|
func (h *httpUpstream) Get(r testenv.Route, opts ...RequestOption) (*http.Response, error) {
|
||||||
|
return h.Do(http.MethodGet, r, opts...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Post implements HTTPUpstream.
|
||||||
|
func (h *httpUpstream) Post(r testenv.Route, opts ...RequestOption) (*http.Response, error) {
|
||||||
|
return h.Do(http.MethodPost, r, opts...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Do implements HTTPUpstream.
|
||||||
|
func (h *httpUpstream) Do(method string, r testenv.Route, opts ...RequestOption) (*http.Response, error) {
|
||||||
|
options := RequestOptions{}
|
||||||
|
options.apply(opts...)
|
||||||
|
u, err := url.Parse(r.URL().Value())
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if options.path != "" || options.query != nil {
|
||||||
|
u = u.ResolveReference(&url.URL{
|
||||||
|
Path: options.path,
|
||||||
|
RawQuery: options.query.Encode(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
req, err := http.NewRequest(method, u.String(), nil)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
switch body := options.body.(type) {
|
||||||
|
case string:
|
||||||
|
req.Body = io.NopCloser(strings.NewReader(body))
|
||||||
|
case []byte:
|
||||||
|
req.Body = io.NopCloser(bytes.NewReader(body))
|
||||||
|
case io.Reader:
|
||||||
|
req.Body = io.NopCloser(body)
|
||||||
|
case proto.Message:
|
||||||
|
buf, err := proto.Marshal(body)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
req.Body = io.NopCloser(bytes.NewReader(buf))
|
||||||
|
req.Header.Set("Content-Type", "application/octet-stream")
|
||||||
|
default:
|
||||||
|
buf, err := json.Marshal(body)
|
||||||
|
if err != nil {
|
||||||
|
panic(fmt.Sprintf("unsupported body type: %T", body))
|
||||||
|
}
|
||||||
|
req.Body = io.NopCloser(bytes.NewReader(buf))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
case nil:
|
||||||
|
}
|
||||||
|
|
||||||
|
newClient := func() *http.Client {
|
||||||
|
c := http.Client{
|
||||||
|
Transport: &http.Transport{
|
||||||
|
TLSClientConfig: &tls.Config{
|
||||||
|
RootCAs: h.Env().ServerCAs(),
|
||||||
|
Certificates: options.clientCerts,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
c.Jar, _ = cookiejar.New(&cookiejar.Options{})
|
||||||
|
return &c
|
||||||
|
}
|
||||||
|
var client *http.Client
|
||||||
|
if options.client != nil {
|
||||||
|
client = options.client
|
||||||
|
} else {
|
||||||
|
var cachedClient any
|
||||||
|
var ok bool
|
||||||
|
if cachedClient, ok = h.clientCache.Load(r); !ok {
|
||||||
|
cachedClient, _ = h.clientCache.LoadOrStore(r, newClient())
|
||||||
|
}
|
||||||
|
client = cachedClient.(*http.Client)
|
||||||
|
}
|
||||||
|
|
||||||
|
var resp *http.Response
|
||||||
|
if err := retry.Retry(h.Env().Context(), "http", func(ctx context.Context) error {
|
||||||
|
var err error
|
||||||
|
if options.authenticateAs != "" {
|
||||||
|
resp, err = authenticateFlow(ctx, client, req, options.authenticateAs) //nolint:bodyclose
|
||||||
|
} else {
|
||||||
|
resp, err = client.Do(req) //nolint:bodyclose
|
||||||
|
}
|
||||||
|
// retry on connection refused
|
||||||
|
if err != nil {
|
||||||
|
var opErr *net.OpError
|
||||||
|
if errors.As(err, &opErr) && opErr.Op == "dial" && opErr.Err.Error() == "connect: connection refused" {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return retry.NewTerminalError(err)
|
||||||
|
}
|
||||||
|
if resp.StatusCode == http.StatusInternalServerError {
|
||||||
|
return errors.New(http.StatusText(resp.StatusCode))
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}, retry.WithMaxInterval(100*time.Millisecond)); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return resp, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func authenticateFlow(ctx context.Context, client *http.Client, req *http.Request, email string) (*http.Response, error) {
|
||||||
|
var res *http.Response
|
||||||
|
originalHostname := req.URL.Hostname()
|
||||||
|
res, err := client.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
location := res.Request.URL
|
||||||
|
if location.Hostname() == originalHostname {
|
||||||
|
// already authenticated
|
||||||
|
return res, err
|
||||||
|
}
|
||||||
|
defer res.Body.Close()
|
||||||
|
fs := forms.Parse(res.Body)
|
||||||
|
if len(fs) > 0 {
|
||||||
|
f := fs[0]
|
||||||
|
f.Inputs["email"] = email
|
||||||
|
f.Inputs["token_expiration"] = strconv.Itoa(int((time.Hour * 24).Seconds()))
|
||||||
|
formReq, err := f.NewRequestWithContext(ctx, location)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return client.Do(formReq)
|
||||||
|
}
|
||||||
|
return nil, fmt.Errorf("test bug: expected IDP login form")
|
||||||
|
}
|
120
internal/testenv/values/value.go
Normal file
120
internal/testenv/values/value.go
Normal file
|
@ -0,0 +1,120 @@
|
||||||
|
package values
|
||||||
|
|
||||||
|
import (
|
||||||
|
"math/rand/v2"
|
||||||
|
"sync"
|
||||||
|
)
|
||||||
|
|
||||||
|
type value[T any] struct {
|
||||||
|
f func() T
|
||||||
|
ready bool
|
||||||
|
cond *sync.Cond
|
||||||
|
}
|
||||||
|
|
||||||
|
// A Value is a container for a single value of type T, whose initialization is
|
||||||
|
// performed the first time Value() is called. Subsequent calls will return the
|
||||||
|
// same value. The Value() function may block until the value is ready on the
|
||||||
|
// first call. Values are safe to use concurrently.
|
||||||
|
type Value[T any] interface {
|
||||||
|
Value() T
|
||||||
|
}
|
||||||
|
|
||||||
|
// MutableValue is the read-write counterpart to [Value], created by calling
|
||||||
|
// [Deferred] for some type T. Calling Resolve() or ResolveFunc() will set
|
||||||
|
// the value and unblock any waiting calls to Value().
|
||||||
|
type MutableValue[T any] interface {
|
||||||
|
Value[T]
|
||||||
|
Resolve(value T)
|
||||||
|
ResolveFunc(fOnce func() T)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Deferred creates a new read-write [MutableValue] for some type T,
|
||||||
|
// representing a value whose initialization may be deferred to a later time.
|
||||||
|
// Once the value is available, call [MutableValue.Resolve] or
|
||||||
|
// [MutableValue.ResolveFunc] to unblock any waiting calls to Value().
|
||||||
|
func Deferred[T any]() MutableValue[T] {
|
||||||
|
return &value[T]{
|
||||||
|
cond: sync.NewCond(&sync.Mutex{}),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Const creates a read-only [Value] which will become available immediately
|
||||||
|
// upon calling Value() for the first time; it will never block.
|
||||||
|
func Const[T any](t T) Value[T] {
|
||||||
|
return &value[T]{
|
||||||
|
f: func() T { return t },
|
||||||
|
ready: true,
|
||||||
|
cond: sync.NewCond(&sync.Mutex{}),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *value[T]) Value() T {
|
||||||
|
p.cond.L.Lock()
|
||||||
|
defer p.cond.L.Unlock()
|
||||||
|
for !p.ready {
|
||||||
|
p.cond.Wait()
|
||||||
|
}
|
||||||
|
return p.f()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *value[T]) ResolveFunc(fOnce func() T) {
|
||||||
|
p.cond.L.Lock()
|
||||||
|
p.f = sync.OnceValue(fOnce)
|
||||||
|
p.ready = true
|
||||||
|
p.cond.L.Unlock()
|
||||||
|
p.cond.Broadcast()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *value[T]) Resolve(value T) {
|
||||||
|
p.ResolveFunc(func() T { return value })
|
||||||
|
}
|
||||||
|
|
||||||
|
// Bind creates a new [MutableValue] whose ultimate value depends on the result
|
||||||
|
// of another [Value] that may not yet be available. When Value() is called on
|
||||||
|
// the result, it will cascade and trigger the full chain of initialization
|
||||||
|
// functions necessary to produce the final value.
|
||||||
|
//
|
||||||
|
// Care should be taken when using this function, as improper use can lead to
|
||||||
|
// deadlocks and cause values to never become available.
|
||||||
|
func Bind[T any, U any](dt Value[T], callback func(value T) U) Value[U] {
|
||||||
|
du := Deferred[U]()
|
||||||
|
du.ResolveFunc(func() U {
|
||||||
|
return callback(dt.Value())
|
||||||
|
})
|
||||||
|
return du
|
||||||
|
}
|
||||||
|
|
||||||
|
// Bind2 is like [Bind], but can accept two input values. The result will only
|
||||||
|
// become available once all input values become available.
|
||||||
|
//
|
||||||
|
// This function blocks to wait for each input value in sequence, but in a
|
||||||
|
// random order. Do not rely on the order of evaluation of the input values.
|
||||||
|
func Bind2[T any, U any, V any](dt Value[T], du Value[U], callback func(value1 T, value2 U) V) Value[V] {
|
||||||
|
dv := Deferred[V]()
|
||||||
|
dv.ResolveFunc(func() V {
|
||||||
|
if rand.IntN(2) == 0 { //nolint:gosec
|
||||||
|
return callback(dt.Value(), du.Value())
|
||||||
|
}
|
||||||
|
u := du.Value()
|
||||||
|
t := dt.Value()
|
||||||
|
return callback(t, u)
|
||||||
|
})
|
||||||
|
return dv
|
||||||
|
}
|
||||||
|
|
||||||
|
// List is a container for a slice of [Value] of type T, and is also a [Value]
|
||||||
|
// itself, for convenience. The Value() function will return a []T containing
|
||||||
|
// all resolved values for each element in the slice.
|
||||||
|
//
|
||||||
|
// A List's Value() function blocks to wait for each element in the slice in
|
||||||
|
// sequence, but in a random order. Do not rely on the order of evaluation of
|
||||||
|
// the slice elements.
|
||||||
|
type List[T any] []Value[T]
|
||||||
|
|
||||||
|
func (s List[T]) Value() []T {
|
||||||
|
values := make([]T, len(s))
|
||||||
|
for _, i := range rand.Perm(len(values)) {
|
||||||
|
values[i] = s[i].Value()
|
||||||
|
}
|
||||||
|
return values
|
||||||
|
}
|
|
@ -16,6 +16,7 @@ import (
|
||||||
"github.com/pomerium/pomerium/authenticate"
|
"github.com/pomerium/pomerium/authenticate"
|
||||||
"github.com/pomerium/pomerium/authorize"
|
"github.com/pomerium/pomerium/authorize"
|
||||||
"github.com/pomerium/pomerium/config"
|
"github.com/pomerium/pomerium/config"
|
||||||
|
"github.com/pomerium/pomerium/config/envoyconfig/filemgr"
|
||||||
databroker_service "github.com/pomerium/pomerium/databroker"
|
databroker_service "github.com/pomerium/pomerium/databroker"
|
||||||
"github.com/pomerium/pomerium/internal/autocert"
|
"github.com/pomerium/pomerium/internal/autocert"
|
||||||
"github.com/pomerium/pomerium/internal/controlplane"
|
"github.com/pomerium/pomerium/internal/controlplane"
|
||||||
|
@ -30,8 +31,29 @@ import (
|
||||||
"github.com/pomerium/pomerium/proxy"
|
"github.com/pomerium/pomerium/proxy"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type RunOptions struct {
|
||||||
|
fileMgr *filemgr.Manager
|
||||||
|
}
|
||||||
|
|
||||||
|
type RunOption func(*RunOptions)
|
||||||
|
|
||||||
|
func (o *RunOptions) apply(opts ...RunOption) {
|
||||||
|
for _, op := range opts {
|
||||||
|
op(o)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func WithOverrideFileManager(fileMgr *filemgr.Manager) RunOption {
|
||||||
|
return func(o *RunOptions) {
|
||||||
|
o.fileMgr = fileMgr
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Run runs the main pomerium application.
|
// Run runs the main pomerium application.
|
||||||
func Run(ctx context.Context, src config.Source) error {
|
func Run(ctx context.Context, src config.Source, opts ...RunOption) error {
|
||||||
|
options := RunOptions{}
|
||||||
|
options.apply(opts...)
|
||||||
|
|
||||||
_, _ = maxprocs.Set(maxprocs.Logger(func(s string, i ...any) { log.Ctx(ctx).Debug().Msgf(s, i...) }))
|
_, _ = maxprocs.Set(maxprocs.Logger(func(s string, i ...any) { log.Ctx(ctx).Debug().Msgf(s, i...) }))
|
||||||
|
|
||||||
evt := log.Ctx(ctx).Info().
|
evt := log.Ctx(ctx).Info().
|
||||||
|
@ -68,10 +90,15 @@ func Run(ctx context.Context, src config.Source) error {
|
||||||
|
|
||||||
eventsMgr := events.New()
|
eventsMgr := events.New()
|
||||||
|
|
||||||
|
fileMgr := options.fileMgr
|
||||||
|
if fileMgr == nil {
|
||||||
|
fileMgr = filemgr.NewManager()
|
||||||
|
}
|
||||||
|
|
||||||
cfg := src.GetConfig()
|
cfg := src.GetConfig()
|
||||||
|
|
||||||
// setup the control plane
|
// setup the control plane
|
||||||
controlPlane, err := controlplane.NewServer(ctx, cfg, metricsMgr, eventsMgr)
|
controlPlane, err := controlplane.NewServer(ctx, cfg, metricsMgr, eventsMgr, fileMgr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("error creating control plane: %w", err)
|
return fmt.Errorf("error creating control plane: %w", err)
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,7 +1,6 @@
|
||||||
package cryptutil
|
package cryptutil
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"crypto/x509"
|
"crypto/x509"
|
||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
|
@ -15,10 +14,9 @@ import (
|
||||||
|
|
||||||
// GetCertPool gets a cert pool for the given CA or CAFile.
|
// GetCertPool gets a cert pool for the given CA or CAFile.
|
||||||
func GetCertPool(ca, caFile string) (*x509.CertPool, error) {
|
func GetCertPool(ca, caFile string) (*x509.CertPool, error) {
|
||||||
ctx := context.TODO()
|
|
||||||
rootCAs, err := x509.SystemCertPool()
|
rootCAs, err := x509.SystemCertPool()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Ctx(ctx).Error().Err(err).Msg("pkg/cryptutil: failed getting system cert pool making new one")
|
log.Error().Err(err).Msg("pkg/cryptutil: failed getting system cert pool making new one")
|
||||||
rootCAs = x509.NewCertPool()
|
rootCAs = x509.NewCertPool()
|
||||||
}
|
}
|
||||||
if ca == "" && caFile == "" {
|
if ca == "" && caFile == "" {
|
||||||
|
@ -40,7 +38,9 @@ func GetCertPool(ca, caFile string) (*x509.CertPool, error) {
|
||||||
if ok := rootCAs.AppendCertsFromPEM(data); !ok {
|
if ok := rootCAs.AppendCertsFromPEM(data); !ok {
|
||||||
return nil, fmt.Errorf("failed to append any PEM-encoded certificates")
|
return nil, fmt.Errorf("failed to append any PEM-encoded certificates")
|
||||||
}
|
}
|
||||||
log.Ctx(ctx).Debug().Msg("pkg/cryptutil: added custom certificate authority")
|
if !log.DebugDisableGlobalMessages.Load() {
|
||||||
|
log.Debug().Msg("pkg/cryptutil: added custom certificate authority")
|
||||||
|
}
|
||||||
return rootCAs, nil
|
return rootCAs, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -186,7 +186,25 @@ func (srv *Server) run(ctx context.Context, cfg *config.Config) error {
|
||||||
// monitor the process so we exit if it prematurely exits
|
// monitor the process so we exit if it prematurely exits
|
||||||
var monitorProcessCtx context.Context
|
var monitorProcessCtx context.Context
|
||||||
monitorProcessCtx, srv.monitorProcessCancel = context.WithCancel(context.WithoutCancel(ctx))
|
monitorProcessCtx, srv.monitorProcessCancel = context.WithCancel(context.WithoutCancel(ctx))
|
||||||
go srv.monitorProcess(monitorProcessCtx, int32(cmd.Process.Pid))
|
go func() {
|
||||||
|
pid := cmd.Process.Pid
|
||||||
|
err := srv.monitorProcess(monitorProcessCtx, int32(pid))
|
||||||
|
if err != nil && ctx.Err() == nil {
|
||||||
|
// If the envoy subprocess exits and ctx is not done, issue a fatal error.
|
||||||
|
// If ctx is done, the server is already exiting, and envoy is expected
|
||||||
|
// to be stopped along with it.
|
||||||
|
log.Ctx(ctx).
|
||||||
|
Fatal().
|
||||||
|
Int("pid", pid).
|
||||||
|
Err(err).
|
||||||
|
Send()
|
||||||
|
}
|
||||||
|
log.Ctx(ctx).
|
||||||
|
Debug().
|
||||||
|
Int("pid", pid).
|
||||||
|
Err(ctx.Err()).
|
||||||
|
Msg("envoy process monitor stopped")
|
||||||
|
}()
|
||||||
|
|
||||||
if srv.resourceMonitor != nil {
|
if srv.resourceMonitor != nil {
|
||||||
log.Ctx(ctx).Debug().Str("service", "envoy").Msg("starting resource monitor")
|
log.Ctx(ctx).Debug().Str("service", "envoy").Msg("starting resource monitor")
|
||||||
|
@ -300,7 +318,7 @@ func (srv *Server) handleLogs(ctx context.Context, rc io.ReadCloser) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (srv *Server) monitorProcess(ctx context.Context, pid int32) {
|
func (srv *Server) monitorProcess(ctx context.Context, pid int32) error {
|
||||||
log.Ctx(ctx).Debug().
|
log.Ctx(ctx).Debug().
|
||||||
Int32("pid", pid).
|
Int32("pid", pid).
|
||||||
Msg("envoy: start monitoring subprocess")
|
Msg("envoy: start monitoring subprocess")
|
||||||
|
@ -311,19 +329,15 @@ func (srv *Server) monitorProcess(ctx context.Context, pid int32) {
|
||||||
for {
|
for {
|
||||||
exists, err := process.PidExistsWithContext(ctx, pid)
|
exists, err := process.PidExistsWithContext(ctx, pid)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatal().Err(err).
|
return fmt.Errorf("envoy: error retrieving subprocess information: %w", err)
|
||||||
Int32("pid", pid).
|
|
||||||
Msg("envoy: error retrieving subprocess information")
|
|
||||||
} else if !exists {
|
} else if !exists {
|
||||||
log.Fatal().Err(err).
|
return errors.New("envoy: subprocess exited")
|
||||||
Int32("pid", pid).
|
|
||||||
Msg("envoy: subprocess exited")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// wait for the next tick
|
// wait for the next tick
|
||||||
select {
|
select {
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
return
|
return nil
|
||||||
case <-ticker.C:
|
case <-ticker.C:
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -6,6 +6,7 @@ package envoy
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"os"
|
"os"
|
||||||
|
"path/filepath"
|
||||||
"strconv"
|
"strconv"
|
||||||
"sync"
|
"sync"
|
||||||
"syscall"
|
"syscall"
|
||||||
|
@ -17,7 +18,7 @@ import (
|
||||||
"github.com/pomerium/pomerium/internal/telemetry/metrics"
|
"github.com/pomerium/pomerium/internal/telemetry/metrics"
|
||||||
)
|
)
|
||||||
|
|
||||||
const baseIDPath = "/tmp/pomerium-envoy-base-id"
|
const baseIDName = "pomerium-envoy-base-id"
|
||||||
|
|
||||||
var restartEpoch struct {
|
var restartEpoch struct {
|
||||||
sync.Mutex
|
sync.Mutex
|
||||||
|
@ -89,7 +90,7 @@ func (srv *Server) prepareRunEnvoyCommand(ctx context.Context, sharedArgs []stri
|
||||||
} else {
|
} else {
|
||||||
args = append(args,
|
args = append(args,
|
||||||
"--use-dynamic-base-id",
|
"--use-dynamic-base-id",
|
||||||
"--base-id-path", baseIDPath,
|
"--base-id-path", filepath.Join(os.TempDir(), baseIDName),
|
||||||
)
|
)
|
||||||
restartEpoch.value = 1
|
restartEpoch.value = 1
|
||||||
}
|
}
|
||||||
|
@ -99,7 +100,7 @@ func (srv *Server) prepareRunEnvoyCommand(ctx context.Context, sharedArgs []stri
|
||||||
}
|
}
|
||||||
|
|
||||||
func readBaseID() (int, bool) {
|
func readBaseID() (int, bool) {
|
||||||
bs, err := os.ReadFile(baseIDPath)
|
bs, err := os.ReadFile(filepath.Join(os.TempDir(), baseIDName))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, false
|
return 0, false
|
||||||
}
|
}
|
||||||
|
|
|
@ -3,6 +3,7 @@ package databroker
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
backoff "github.com/cenkalti/backoff/v4"
|
backoff "github.com/cenkalti/backoff/v4"
|
||||||
|
@ -71,12 +72,24 @@ type Syncer struct {
|
||||||
id string
|
id string
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var DebugUseFasterBackoff atomic.Bool
|
||||||
|
|
||||||
// NewSyncer creates a new Syncer.
|
// NewSyncer creates a new Syncer.
|
||||||
func NewSyncer(ctx context.Context, id string, handler SyncerHandler, options ...SyncerOption) *Syncer {
|
func NewSyncer(ctx context.Context, id string, handler SyncerHandler, options ...SyncerOption) *Syncer {
|
||||||
closeCtx, closeCtxCancel := context.WithCancel(context.WithoutCancel(ctx))
|
closeCtx, closeCtxCancel := context.WithCancel(context.WithoutCancel(ctx))
|
||||||
|
|
||||||
bo := backoff.NewExponentialBackOff()
|
var bo *backoff.ExponentialBackOff
|
||||||
|
if DebugUseFasterBackoff.Load() {
|
||||||
|
bo = backoff.NewExponentialBackOff(
|
||||||
|
backoff.WithInitialInterval(10*time.Millisecond),
|
||||||
|
backoff.WithMultiplier(1.0),
|
||||||
|
backoff.WithMaxElapsedTime(100*time.Millisecond),
|
||||||
|
)
|
||||||
bo.MaxElapsedTime = 0
|
bo.MaxElapsedTime = 0
|
||||||
|
} else {
|
||||||
|
bo = backoff.NewExponentialBackOff()
|
||||||
|
bo.MaxElapsedTime = 0
|
||||||
|
}
|
||||||
s := &Syncer{
|
s := &Syncer{
|
||||||
cfg: getSyncerConfig(options...),
|
cfg: getSyncerConfig(options...),
|
||||||
handler: handler,
|
handler: handler,
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue