diff --git a/config/envoyconfig/protocols_int_test.go b/config/envoyconfig/protocols_int_test.go index 3ad88180c..f0fe85256 100644 --- a/config/envoyconfig/protocols_int_test.go +++ b/config/envoyconfig/protocols_int_test.go @@ -21,6 +21,7 @@ import ( "github.com/pomerium/pomerium/config" "github.com/pomerium/pomerium/internal/testenv" + "github.com/pomerium/pomerium/internal/testenv/scenarios" "github.com/pomerium/pomerium/internal/testenv/upstreams" "github.com/pomerium/pomerium/internal/testenv/values" "github.com/pomerium/pomerium/pkg/cmd/pomerium" @@ -124,6 +125,46 @@ func TestHTTP(t *testing.T) { }) } +func TestClientCert(t *testing.T) { + env := testenv.New(t) + env.Add(scenarios.DownstreamMTLS(config.MTLSEnforcementRejectConnection)) + + up := upstreams.HTTP(nil) + up.Handle("/foo", func(w http.ResponseWriter, r *http.Request) { + fmt.Fprintln(w, "hello world") + }) + + clientCert := env.NewClientCert() + + route := up.Route(). + From(env.SubdomainURL("http")). + PPL(fmt.Sprintf(`{"allow":{"and":["client_certificate":{"fingerprint":%q}]}}`, clientCert.Fingerprint())) + + env.AddUpstream(up) + env.Start() + + recorder := env.NewLogRecorder() + + resp, err := up.Get(route, upstreams.Path("/foo"), upstreams.ClientCert(clientCert)) + require.NoError(t, err) + + defer resp.Body.Close() + data, err := io.ReadAll(resp.Body) + require.NoError(t, err) + assert.Equal(t, "hello world\n", string(data)) + + recorder.Match([]map[string]any{ + { + "service": "envoy", + "path": "/foo", + "method": "GET", + "message": "http-request", + "response-code-details": "via_upstream", + "client-certificate": clientCert, + }, + }) +} + func TestH2C(t *testing.T) { if testing.Short() { t.SkipNow() diff --git a/internal/controlplane/server.go b/internal/controlplane/server.go index 622585427..dc3305365 100644 --- a/internal/controlplane/server.go +++ b/internal/controlplane/server.go @@ -69,10 +69,17 @@ type Server struct { } // NewServer creates a new Server. Listener ports are chosen by the OS. -func NewServer(ctx context.Context, cfg *config.Config, metricsMgr *config.MetricsManager, eventsMgr *events.Manager) (*Server, error) { +func NewServer( + ctx context.Context, + cfg *config.Config, + metricsMgr *config.MetricsManager, + eventsMgr *events.Manager, + fileMgr *filemgr.Manager, +) (*Server, error) { srv := &Server{ metricsMgr: metricsMgr, EventsMgr: eventsMgr, + filemgr: fileMgr, reproxy: reproxy.New(), haveSetCapacity: map[string]bool{}, updateConfig: make(chan *config.Config, 1), @@ -149,7 +156,6 @@ func NewServer(ctx context.Context, cfg *config.Config, metricsMgr *config.Metri // metrics srv.MetricsRouter.Handle("/metrics", srv.metricsMgr) - srv.filemgr = filemgr.NewManager() srv.filemgr.ClearCache() srv.Builder = envoyconfig.New( diff --git a/internal/controlplane/server_test.go b/internal/controlplane/server_test.go index 9221d2e1b..d863eac8b 100644 --- a/internal/controlplane/server_test.go +++ b/internal/controlplane/server_test.go @@ -12,6 +12,7 @@ import ( "github.com/stretchr/testify/require" "github.com/pomerium/pomerium/config" + "github.com/pomerium/pomerium/config/envoyconfig/filemgr" "github.com/pomerium/pomerium/internal/events" "github.com/pomerium/pomerium/pkg/netutil" ) @@ -38,7 +39,7 @@ func TestServerHTTP(t *testing.T) { cfg.Options.SharedKey = "JDNjY2ITDlARvNaQXjc2Djk+GA6xeCy4KiozmZfdbTs=" src := config.NewStaticSource(cfg) - srv, err := NewServer(ctx, cfg, config.NewMetricsManager(ctx, src), events.New()) + srv, err := NewServer(ctx, cfg, config.NewMetricsManager(ctx, src), events.New(), filemgr.NewManager(filemgr.WithCacheDir(t.TempDir()))) require.NoError(t, err) go srv.Run(ctx) diff --git a/internal/testenv/environment.go b/internal/testenv/environment.go index 286620de8..363723d6e 100644 --- a/internal/testenv/environment.go +++ b/internal/testenv/environment.go @@ -2,10 +2,20 @@ package testenv import ( "context" + "crypto/rand" + "crypto/rsa" + "crypto/sha256" "crypto/tls" "crypto/x509" + "crypto/x509/pkix" + "encoding/asn1" + "encoding/base64" + "encoding/hex" "errors" "fmt" + "math/big" + "net" + "net/url" "os" "path" "path/filepath" @@ -13,13 +23,16 @@ import ( "strconv" "sync" "testing" + "time" "github.com/pomerium/pomerium/config" + "github.com/pomerium/pomerium/config/envoyconfig/filemgr" "github.com/pomerium/pomerium/internal/log" "github.com/pomerium/pomerium/internal/testenv/values" "github.com/pomerium/pomerium/pkg/cmd/pomerium" "github.com/pomerium/pomerium/pkg/health" "github.com/pomerium/pomerium/pkg/netutil" + "github.com/pomerium/pomerium/pkg/slices" "github.com/rs/zerolog" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -33,14 +46,33 @@ type Environment interface { // top-level logger scoped to this environment. It will be canceled when // Stop() is called, or during test cleanup. Context() context.Context + + Assert() *assert.Assertions + Require() *require.Assertions + // TempDir returns a unique temp directory for this context. Calling this // function multiple times returns the same path. TempDir() string + // CACert returns the test environment's root CA certificate and private key. + CACert() *tls.Certificate // ServerCAs returns a new [*x509.CertPool] containing the root CA certificate // used to sign the server cert and other test certificates. ServerCAs() *x509.CertPool - // ServerCert returns the Pomerium server's certificate. + // ServerCert returns the Pomerium server's certificate and private key. ServerCert() *tls.Certificate + // NewClientCert generates a new client certificate signed by the root CA + // certificate. One or more optional templates can be given, which can be + // used to set or override certain parameters when creating a certificate, + // including subject, SANs, or extensions. If more than one template is + // provided, they will be applied in order from left to right. + // + // By default (unless overridden in a template), the certificate will have + // its Common Name set to the file:line string of the call site. Calls to + // NewClientCert() on different lines will have different subjects. If + // multiple certs with the same subject are needed, wrap the call to this + // function in another helper function, or separate calls with commas on the + // same line. + NewClientCert(templateOverrides ...*x509.Certificate) *Certificate // Add adds the given [Modifier] to the environment. All modifiers will be // invoked upon calling Start() to apply individual modifications to the @@ -80,8 +112,22 @@ type Environment interface { NewLogRecorder(opts ...LogRecorderOption) *LogRecorder } +type Certificate tls.Certificate + +func (c *Certificate) Fingerprint() string { + sum := sha256.Sum256(c.Leaf.Raw) + return hex.EncodeToString(sum[:]) +} + +func (c *Certificate) SPKIHash() string { + sum := sha256.Sum256(c.Leaf.RawSubjectPublicKeyInfo) + return base64.StdEncoding.EncodeToString(sum[:]) +} + type environment struct { t testing.TB + assert *assert.Assertions + require *require.Assertions tempDir string domain string ports Ports @@ -125,6 +171,8 @@ func New(t testing.TB) Environment { e := &environment{ t: t, + assert: assert.New(t), + require: require.New(t), tempDir: t.TempDir(), ports: Ports{ http: values.Deferred[int](), @@ -141,13 +189,14 @@ func New(t testing.TB) Environment { copyFile := func(src, dstRel string) { data, err := os.ReadFile(src) require.NoError(t, err) - require.NoError(t, os.WriteFile(filepath.Join(e.tempDir, dstRel), data, 0o666)) + require.NoError(t, os.WriteFile(filepath.Join(e.tempDir, dstRel), data, 0o600)) } certsToCopy := []string{ "trusted.pem", "trusted-key.pem", "ca.pem", + "ca-key.pem", } for _, crt := range certsToCopy { copyFile(filepath.Join(workspaceFolder, "integration/tpl/files", crt), filepath.Join("certs/", filepath.Base(crt))) @@ -174,12 +223,29 @@ func (e *environment) Context() context.Context { return ContextWithEnv(e.ctx, e) } +func (e *environment) Assert() *assert.Assertions { + return e.assert +} + +func (e *environment) Require() *require.Assertions { + return e.require +} + func (e *environment) SubdomainURL(subdomain string) values.Value[string] { return values.Bind(e.ports.http, func(port int) string { return fmt.Sprintf("https://%s.%s:%d", subdomain, e.domain, port) }) } +func (e *environment) CACert() *tls.Certificate { + caCert, err := tls.LoadX509KeyPair( + filepath.Join(e.tempDir, "certs", "ca.pem"), + filepath.Join(e.tempDir, "certs", "ca-key.pem"), + ) + require.NoError(e.t, err) + return &caCert +} + func (e *environment) ServerCAs() *x509.CertPool { pool := x509.NewCertPool() caCert, err := os.ReadFile(filepath.Join(e.tempDir, "certs", "ca.pem")) @@ -213,20 +279,39 @@ func (e *environment) Start() { require.NoError(e.t, err) port0, _ := strconv.Atoi(ports[0]) e.ports.http.Resolve(port0) + cfg.Options.AutocertOptions = config.AutocertOptions{Enable: false} cfg.Options.LogLevel = config.LogLevelInfo cfg.Options.ProxyLogLevel = config.LogLevelInfo cfg.Options.Addr = fmt.Sprintf("127.0.0.1:%d", port0) cfg.Options.CertFile = filepath.Join(e.tempDir, "certs", "trusted.pem") cfg.Options.KeyFile = filepath.Join(e.tempDir, "certs", "trusted-key.pem") + cfg.Options.AccessLogFields = []log.AccessLogField{ + log.AccessLogFieldAuthority, + log.AccessLogFieldDuration, + log.AccessLogFieldForwardedFor, + log.AccessLogFieldIP, + log.AccessLogFieldMethod, + log.AccessLogFieldPath, + log.AccessLogFieldQuery, + log.AccessLogFieldReferer, + log.AccessLogFieldRequestID, + log.AccessLogFieldResponseCode, + log.AccessLogFieldResponseCodeDetails, + log.AccessLogFieldSize, + log.AccessLogFieldUpstreamCluster, + log.AccessLogFieldUserAgent, + log.AccessLogFieldClientCertificate, + } cfg.AllocatePorts(*(*[6]string)(ports[1:])) e.AddTask(TaskFunc(func(ctx context.Context) error { + fileMgr := filemgr.NewManager(filemgr.WithCacheDir(filepath.Join(e.TempDir(), "cache"))) src := config.NewStaticSource(cfg) for _, mod := range e.mods { mod.Value.Modify(cfg) require.NoError(e.t, cfg.Options.Validate(), "invoking modifier resulted in an invalid configuration:\nadded by: "+mod.Caller) } - return pomerium.Run(e.ctx, src) + return pomerium.Run(e.ctx, src, pomerium.WithOverrideFileManager(fileMgr)) })) for i, task := range e.tasks { @@ -238,6 +323,64 @@ func (e *environment) Start() { } } +func (e *environment) NewClientCert(templateOverrides ...*x509.Certificate) *Certificate { + caCert := e.CACert() + + priv, err := rsa.GenerateKey(rand.Reader, 2048) + require.NoError(e.t, err) + + sn, err := rand.Int(rand.Reader, new(big.Int).Lsh(big.NewInt(1), 128)) + require.NoError(e.t, err) + now := time.Now() + tmpl := &x509.Certificate{ + SerialNumber: sn, + Subject: pkix.Name{ + CommonName: getCaller(), + }, + NotBefore: now, + NotAfter: now.Add(12 * time.Hour), + KeyUsage: x509.KeyUsageDigitalSignature, + ExtKeyUsage: []x509.ExtKeyUsage{ + x509.ExtKeyUsageClientAuth, + }, + BasicConstraintsValid: true, + } + for _, override := range templateOverrides { + tmpl.CRLDistributionPoints = slices.Unique(append(tmpl.CRLDistributionPoints, override.CRLDistributionPoints...)) + tmpl.DNSNames = slices.Unique(append(tmpl.DNSNames, override.DNSNames...)) + tmpl.EmailAddresses = slices.Unique(append(tmpl.EmailAddresses, override.EmailAddresses...)) + tmpl.ExtraExtensions = append(tmpl.ExtraExtensions, override.ExtraExtensions...) + tmpl.IPAddresses = slices.UniqueBy(append(tmpl.IPAddresses, override.IPAddresses...), net.IP.String) + tmpl.URIs = slices.UniqueBy(append(tmpl.URIs, override.URIs...), (*url.URL).String) + tmpl.UnknownExtKeyUsage = slices.UniqueBy(append(tmpl.UnknownExtKeyUsage, override.UnknownExtKeyUsage...), asn1.ObjectIdentifier.String) + seq := override.Subject.ToRDNSequence() + tmpl.Subject.FillFromRDNSequence(&seq) + tmpl.KeyUsage |= override.KeyUsage + tmpl.ExtKeyUsage = slices.Unique(append(tmpl.ExtKeyUsage, override.ExtKeyUsage...)) + } + + clientCertDER, err := x509.CreateCertificate(rand.Reader, tmpl, caCert.Leaf, priv.Public(), caCert.PrivateKey) + require.NoError(e.t, err) + + cert, err := x509.ParseCertificate(clientCertDER) + require.NoError(e.t, err) + + clientCert := &tls.Certificate{ + Certificate: [][]byte{cert.Raw, caCert.Leaf.Raw}, + PrivateKey: priv, + Leaf: cert, + } + + _, err = clientCert.Leaf.Verify(x509.VerifyOptions{ + KeyUsages: []x509.ExtKeyUsage{ + x509.ExtKeyUsageClientAuth, + }, + Roots: e.ServerCAs(), + }) + require.NoError(e.t, err, "bug: generated client cert is not valid") + return (*Certificate)(clientCert) +} + func (e *environment) Stop() { e.cleanupOnce.Do(func() { e.cancel(ErrCauseManualStop) diff --git a/internal/testenv/logs.go b/internal/testenv/logs.go index 93f410ddc..71e4456ef 100644 --- a/internal/testenv/logs.go +++ b/internal/testenv/logs.go @@ -4,8 +4,11 @@ import ( "bufio" "bytes" "context" + "crypto/tls" + "crypto/x509" "encoding/json" "fmt" + "reflect" "sync" "testing" "time" @@ -115,9 +118,14 @@ func (lr *LogRecorder) Logs() []map[string]any { // Match stops the log recorder (if it is not already stopped), then asserts // that the given expected logs were captured. The expected logs may contain // partial or complete log entries. By default, logs must only match the fields -// given, and may contain additional fields that will be ignored. For details, -// see [OpenMap] and [ClosedMap]. As a special case, using [json.Number] as the -// expected value will convert the actual value to a string before comparison. +// given, and may contain additional fields that will be ignored. +// +// There are several special-case value types that can be used to customize the +// matching behavior, and/or simplify some common use cases, as follows: +// - [OpenMap] and [ClosedMap] can be used to control matching logic +// - [json.Number] will convert the actual value to a string before comparison +// - [*tls.Certificate] or [*x509.Certificate] will expand to the fields that +// would be logged for this certificate func (lr *LogRecorder) Match(expectedLogs []map[string]any) { lr.collectLogs() var match func(expected, actual map[string]any, open bool) (bool, int) @@ -132,15 +140,50 @@ func (lr *LogRecorder) Match(expectedLogs []map[string]any) { switch actualValue := actualValue.(type) { case map[string]any: - switch value := value.(type) { + switch expectedValue := value.(type) { case ClosedMap: - ok, s := match(value, actualValue, false) + ok, s := match(expectedValue, actualValue, false) score += s * 2 if !ok { return false, score } case OpenMap: - ok, s := match(value, actualValue, true) + ok, s := match(expectedValue, actualValue, true) + score += s + if !ok { + return false, score + } + case *tls.Certificate, *Certificate, *x509.Certificate: + var leaf *x509.Certificate + switch expectedValue := expectedValue.(type) { + case *tls.Certificate: + leaf = expectedValue.Leaf + case *Certificate: + leaf = expectedValue.Leaf + case *x509.Certificate: + leaf = expectedValue + } + + // keep logic consistent with controlplane.populateCertEventDict() + expected := map[string]any{} + if iss := leaf.Issuer.String(); iss != "" { + expected["issuer"] = iss + } + if sub := leaf.Subject.String(); sub != "" { + expected["subject"] = sub + } + sans := []string{} + for _, dnsSAN := range leaf.DNSNames { + sans = append(sans, "DNS:"+dnsSAN) + } + for _, uriSAN := range leaf.URIs { + sans = append(sans, "URI:"+uriSAN.String()) + } + if len(sans) > 0 { + expected["subjectAltName"] = sans + } + + ok, s := match(expected, actualValue, false) score += s if !ok { return false, score @@ -164,7 +207,23 @@ func (lr *LogRecorder) Match(expectedLogs []map[string]any) { } score++ default: - panic(fmt.Sprintf("test bug: add check for type %T in assertMatchingLogs", actualValue)) + // handle slices + if reflect.TypeOf(actualValue).Kind() == reflect.Slice { + if reflect.TypeOf(value) != reflect.TypeOf(actualValue) { + return false, score + } + actualSlice := reflect.ValueOf(actualValue) + expectedSlice := reflect.ValueOf(value) + totalScore := 0 + for i := range min(actualSlice.Len(), expectedSlice.Len()) { + if actualSlice.Index(i).Equal(expectedSlice.Index(i)) { + totalScore++ + } + } + score += totalScore + } else { + panic(fmt.Sprintf("test bug: add check for type %T in assertMatchingLogs", actualValue)) + } } } if !open && len(expected) != len(actual) { diff --git a/internal/testenv/route.go b/internal/testenv/route.go index 26d4fb475..ce1d7e58d 100644 --- a/internal/testenv/route.go +++ b/internal/testenv/route.go @@ -2,9 +2,11 @@ package testenv import ( "net/url" + "strings" "github.com/pomerium/pomerium/config" "github.com/pomerium/pomerium/internal/testenv/values" + "github.com/pomerium/pomerium/pkg/policy/parser" ) // PolicyRoute is a [Route] implementation suitable for most common use cases @@ -54,6 +56,20 @@ func (b *PolicyRoute) Policy(edit func(*config.Policy)) Route { return b } +// PPL implements Route. +func (b *PolicyRoute) PPL(ppl string) Route { + pplPolicy, err := parser.ParseYAML(strings.NewReader(ppl)) + if err != nil { + panic(err) + } + b.edits = append(b.edits, func(p *config.Policy) { + p.Policy = &config.PPLPolicy{ + Policy: pplPolicy, + } + }) + return b +} + // To implements Route. func (b *PolicyRoute) URL() values.Value[string] { return b.from diff --git a/internal/testenv/scenarios/mtls.go b/internal/testenv/scenarios/mtls.go new file mode 100644 index 000000000..b7bb2eb5c --- /dev/null +++ b/internal/testenv/scenarios/mtls.go @@ -0,0 +1,24 @@ +package scenarios + +import ( + "context" + "encoding/base64" + "encoding/pem" + + "github.com/pomerium/pomerium/config" + "github.com/pomerium/pomerium/internal/testenv" +) + +func DownstreamMTLS(mode config.MTLSEnforcement) testenv.Modifier { + return testenv.ModifierFunc(func(ctx context.Context, cfg *config.Config) { + env := testenv.EnvFromContext(ctx) + block := pem.Block{ + Type: "CERTIFICATE", + Bytes: env.CACert().Leaf.Raw, + } + cfg.Options.DownstreamMTLS = config.DownstreamMTLSSettings{ + CA: base64.StdEncoding.EncodeToString(pem.EncodeToMemory(&block)), + Enforcement: mode, + } + }) +} diff --git a/internal/testenv/types.go b/internal/testenv/types.go index 56ccd896b..44c9e2c61 100644 --- a/internal/testenv/types.go +++ b/internal/testenv/types.go @@ -77,10 +77,24 @@ func (m Modifiers) Modify(cfg *config.Config) { } } -type ModifierFunc func(cfg *config.Config) +type modifierFunc struct { + fn func(ctx context.Context, cfg *config.Config) + ctx context.Context +} -func (f ModifierFunc) Modify(cfg *config.Config) { - f(cfg) +// Attach implements Modifier. +func (f *modifierFunc) Attach(ctx context.Context) { + f.ctx = ctx +} + +func (f *modifierFunc) Modify(cfg *config.Config) { + f.fn(f.ctx, cfg) +} + +var _ Modifier = (*modifierFunc)(nil) + +func ModifierFunc(fn func(ctx context.Context, cfg *config.Config)) Modifier { + return &modifierFunc{fn: fn} } // Task represents a background task that can be added to an [Environment] to @@ -116,6 +130,7 @@ type Route interface { URL() values.Value[string] To(toUrl values.Value[string]) Route Policy(edit func(*config.Policy)) Route + PPL(ppl string) Route // add more methods here as they become needed } diff --git a/internal/testenv/upstreams/http.go b/internal/testenv/upstreams/http.go index e5687fede..51fed7919 100644 --- a/internal/testenv/upstreams/http.go +++ b/internal/testenv/upstreams/http.go @@ -77,9 +77,11 @@ func Body(body any) RequestOption { } // ClientCert adds a client certificate to the request. -func ClientCert(cert tls.Certificate) RequestOption { +func ClientCert[T interface { + *testenv.Certificate | *tls.Certificate +}](cert T) RequestOption { return func(o *RequestOptions) { - o.clientCerts = append(o.clientCerts, cert) + o.clientCerts = append(o.clientCerts, *(*tls.Certificate)(cert)) } } @@ -250,6 +252,7 @@ func (h *httpUpstream) Do(method string, r testenv.Route, opts ...RequestOption) } return retry.NewTerminalError(err) } + resp.Body.Close() return nil }, retry.WithMaxInterval(1*time.Second)); err != nil { return nil, err diff --git a/pkg/cmd/pomerium/pomerium.go b/pkg/cmd/pomerium/pomerium.go index d556adbaf..dd525a2b7 100644 --- a/pkg/cmd/pomerium/pomerium.go +++ b/pkg/cmd/pomerium/pomerium.go @@ -16,6 +16,7 @@ import ( "github.com/pomerium/pomerium/authenticate" "github.com/pomerium/pomerium/authorize" "github.com/pomerium/pomerium/config" + "github.com/pomerium/pomerium/config/envoyconfig/filemgr" databroker_service "github.com/pomerium/pomerium/databroker" "github.com/pomerium/pomerium/internal/autocert" "github.com/pomerium/pomerium/internal/controlplane" @@ -30,8 +31,29 @@ import ( "github.com/pomerium/pomerium/proxy" ) +type RunOptions struct { + fileMgr *filemgr.Manager +} + +type RunOption func(*RunOptions) + +func (o *RunOptions) apply(opts ...RunOption) { + for _, op := range opts { + op(o) + } +} + +func WithOverrideFileManager(fileMgr *filemgr.Manager) RunOption { + return func(o *RunOptions) { + o.fileMgr = fileMgr + } +} + // Run runs the main pomerium application. -func Run(ctx context.Context, src config.Source) error { +func Run(ctx context.Context, src config.Source, opts ...RunOption) error { + options := RunOptions{} + options.apply(opts...) + _, _ = maxprocs.Set(maxprocs.Logger(func(s string, i ...any) { log.Debug(context.Background()).Msgf(s, i...) })) evt := log.Info(ctx). @@ -68,10 +90,15 @@ func Run(ctx context.Context, src config.Source) error { eventsMgr := events.New() + fileMgr := options.fileMgr + if fileMgr == nil { + fileMgr = filemgr.NewManager() + } + cfg := src.GetConfig() // setup the control plane - controlPlane, err := controlplane.NewServer(ctx, cfg, metricsMgr, eventsMgr) + controlPlane, err := controlplane.NewServer(ctx, cfg, metricsMgr, eventsMgr, fileMgr) if err != nil { return fmt.Errorf("error creating control plane: %w", err) }