mirror of
https://github.com/pomerium/pomerium.git
synced 2025-06-25 14:08:09 +02:00
Add client certificate utilities
This commit is contained in:
parent
6ed5752fa5
commit
490a301aa4
10 changed files with 355 additions and 20 deletions
|
@ -21,6 +21,7 @@ import (
|
||||||
|
|
||||||
"github.com/pomerium/pomerium/config"
|
"github.com/pomerium/pomerium/config"
|
||||||
"github.com/pomerium/pomerium/internal/testenv"
|
"github.com/pomerium/pomerium/internal/testenv"
|
||||||
|
"github.com/pomerium/pomerium/internal/testenv/scenarios"
|
||||||
"github.com/pomerium/pomerium/internal/testenv/upstreams"
|
"github.com/pomerium/pomerium/internal/testenv/upstreams"
|
||||||
"github.com/pomerium/pomerium/internal/testenv/values"
|
"github.com/pomerium/pomerium/internal/testenv/values"
|
||||||
"github.com/pomerium/pomerium/pkg/cmd/pomerium"
|
"github.com/pomerium/pomerium/pkg/cmd/pomerium"
|
||||||
|
@ -124,6 +125,46 @@ func TestHTTP(t *testing.T) {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestClientCert(t *testing.T) {
|
||||||
|
env := testenv.New(t)
|
||||||
|
env.Add(scenarios.DownstreamMTLS(config.MTLSEnforcementRejectConnection))
|
||||||
|
|
||||||
|
up := upstreams.HTTP(nil)
|
||||||
|
up.Handle("/foo", func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
fmt.Fprintln(w, "hello world")
|
||||||
|
})
|
||||||
|
|
||||||
|
clientCert := env.NewClientCert()
|
||||||
|
|
||||||
|
route := up.Route().
|
||||||
|
From(env.SubdomainURL("http")).
|
||||||
|
PPL(fmt.Sprintf(`{"allow":{"and":["client_certificate":{"fingerprint":%q}]}}`, clientCert.Fingerprint()))
|
||||||
|
|
||||||
|
env.AddUpstream(up)
|
||||||
|
env.Start()
|
||||||
|
|
||||||
|
recorder := env.NewLogRecorder()
|
||||||
|
|
||||||
|
resp, err := up.Get(route, upstreams.Path("/foo"), upstreams.ClientCert(clientCert))
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
defer resp.Body.Close()
|
||||||
|
data, err := io.ReadAll(resp.Body)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, "hello world\n", string(data))
|
||||||
|
|
||||||
|
recorder.Match([]map[string]any{
|
||||||
|
{
|
||||||
|
"service": "envoy",
|
||||||
|
"path": "/foo",
|
||||||
|
"method": "GET",
|
||||||
|
"message": "http-request",
|
||||||
|
"response-code-details": "via_upstream",
|
||||||
|
"client-certificate": clientCert,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
func TestH2C(t *testing.T) {
|
func TestH2C(t *testing.T) {
|
||||||
if testing.Short() {
|
if testing.Short() {
|
||||||
t.SkipNow()
|
t.SkipNow()
|
||||||
|
|
|
@ -69,10 +69,17 @@ type Server struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewServer creates a new Server. Listener ports are chosen by the OS.
|
// NewServer creates a new Server. Listener ports are chosen by the OS.
|
||||||
func NewServer(ctx context.Context, cfg *config.Config, metricsMgr *config.MetricsManager, eventsMgr *events.Manager) (*Server, error) {
|
func NewServer(
|
||||||
|
ctx context.Context,
|
||||||
|
cfg *config.Config,
|
||||||
|
metricsMgr *config.MetricsManager,
|
||||||
|
eventsMgr *events.Manager,
|
||||||
|
fileMgr *filemgr.Manager,
|
||||||
|
) (*Server, error) {
|
||||||
srv := &Server{
|
srv := &Server{
|
||||||
metricsMgr: metricsMgr,
|
metricsMgr: metricsMgr,
|
||||||
EventsMgr: eventsMgr,
|
EventsMgr: eventsMgr,
|
||||||
|
filemgr: fileMgr,
|
||||||
reproxy: reproxy.New(),
|
reproxy: reproxy.New(),
|
||||||
haveSetCapacity: map[string]bool{},
|
haveSetCapacity: map[string]bool{},
|
||||||
updateConfig: make(chan *config.Config, 1),
|
updateConfig: make(chan *config.Config, 1),
|
||||||
|
@ -149,7 +156,6 @@ func NewServer(ctx context.Context, cfg *config.Config, metricsMgr *config.Metri
|
||||||
// metrics
|
// metrics
|
||||||
srv.MetricsRouter.Handle("/metrics", srv.metricsMgr)
|
srv.MetricsRouter.Handle("/metrics", srv.metricsMgr)
|
||||||
|
|
||||||
srv.filemgr = filemgr.NewManager()
|
|
||||||
srv.filemgr.ClearCache()
|
srv.filemgr.ClearCache()
|
||||||
|
|
||||||
srv.Builder = envoyconfig.New(
|
srv.Builder = envoyconfig.New(
|
||||||
|
|
|
@ -12,6 +12,7 @@ import (
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
"github.com/pomerium/pomerium/config"
|
"github.com/pomerium/pomerium/config"
|
||||||
|
"github.com/pomerium/pomerium/config/envoyconfig/filemgr"
|
||||||
"github.com/pomerium/pomerium/internal/events"
|
"github.com/pomerium/pomerium/internal/events"
|
||||||
"github.com/pomerium/pomerium/pkg/netutil"
|
"github.com/pomerium/pomerium/pkg/netutil"
|
||||||
)
|
)
|
||||||
|
@ -38,7 +39,7 @@ func TestServerHTTP(t *testing.T) {
|
||||||
cfg.Options.SharedKey = "JDNjY2ITDlARvNaQXjc2Djk+GA6xeCy4KiozmZfdbTs="
|
cfg.Options.SharedKey = "JDNjY2ITDlARvNaQXjc2Djk+GA6xeCy4KiozmZfdbTs="
|
||||||
|
|
||||||
src := config.NewStaticSource(cfg)
|
src := config.NewStaticSource(cfg)
|
||||||
srv, err := NewServer(ctx, cfg, config.NewMetricsManager(ctx, src), events.New())
|
srv, err := NewServer(ctx, cfg, config.NewMetricsManager(ctx, src), events.New(), filemgr.NewManager(filemgr.WithCacheDir(t.TempDir())))
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
go srv.Run(ctx)
|
go srv.Run(ctx)
|
||||||
|
|
||||||
|
|
|
@ -2,10 +2,20 @@ package testenv
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"crypto/rand"
|
||||||
|
"crypto/rsa"
|
||||||
|
"crypto/sha256"
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"crypto/x509"
|
"crypto/x509"
|
||||||
|
"crypto/x509/pkix"
|
||||||
|
"encoding/asn1"
|
||||||
|
"encoding/base64"
|
||||||
|
"encoding/hex"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"math/big"
|
||||||
|
"net"
|
||||||
|
"net/url"
|
||||||
"os"
|
"os"
|
||||||
"path"
|
"path"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
|
@ -13,13 +23,16 @@ import (
|
||||||
"strconv"
|
"strconv"
|
||||||
"sync"
|
"sync"
|
||||||
"testing"
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/pomerium/pomerium/config"
|
"github.com/pomerium/pomerium/config"
|
||||||
|
"github.com/pomerium/pomerium/config/envoyconfig/filemgr"
|
||||||
"github.com/pomerium/pomerium/internal/log"
|
"github.com/pomerium/pomerium/internal/log"
|
||||||
"github.com/pomerium/pomerium/internal/testenv/values"
|
"github.com/pomerium/pomerium/internal/testenv/values"
|
||||||
"github.com/pomerium/pomerium/pkg/cmd/pomerium"
|
"github.com/pomerium/pomerium/pkg/cmd/pomerium"
|
||||||
"github.com/pomerium/pomerium/pkg/health"
|
"github.com/pomerium/pomerium/pkg/health"
|
||||||
"github.com/pomerium/pomerium/pkg/netutil"
|
"github.com/pomerium/pomerium/pkg/netutil"
|
||||||
|
"github.com/pomerium/pomerium/pkg/slices"
|
||||||
"github.com/rs/zerolog"
|
"github.com/rs/zerolog"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
|
@ -33,14 +46,33 @@ type Environment interface {
|
||||||
// top-level logger scoped to this environment. It will be canceled when
|
// top-level logger scoped to this environment. It will be canceled when
|
||||||
// Stop() is called, or during test cleanup.
|
// Stop() is called, or during test cleanup.
|
||||||
Context() context.Context
|
Context() context.Context
|
||||||
|
|
||||||
|
Assert() *assert.Assertions
|
||||||
|
Require() *require.Assertions
|
||||||
|
|
||||||
// TempDir returns a unique temp directory for this context. Calling this
|
// TempDir returns a unique temp directory for this context. Calling this
|
||||||
// function multiple times returns the same path.
|
// function multiple times returns the same path.
|
||||||
TempDir() string
|
TempDir() string
|
||||||
|
// CACert returns the test environment's root CA certificate and private key.
|
||||||
|
CACert() *tls.Certificate
|
||||||
// ServerCAs returns a new [*x509.CertPool] containing the root CA certificate
|
// ServerCAs returns a new [*x509.CertPool] containing the root CA certificate
|
||||||
// used to sign the server cert and other test certificates.
|
// used to sign the server cert and other test certificates.
|
||||||
ServerCAs() *x509.CertPool
|
ServerCAs() *x509.CertPool
|
||||||
// ServerCert returns the Pomerium server's certificate.
|
// ServerCert returns the Pomerium server's certificate and private key.
|
||||||
ServerCert() *tls.Certificate
|
ServerCert() *tls.Certificate
|
||||||
|
// NewClientCert generates a new client certificate signed by the root CA
|
||||||
|
// certificate. One or more optional templates can be given, which can be
|
||||||
|
// used to set or override certain parameters when creating a certificate,
|
||||||
|
// including subject, SANs, or extensions. If more than one template is
|
||||||
|
// provided, they will be applied in order from left to right.
|
||||||
|
//
|
||||||
|
// By default (unless overridden in a template), the certificate will have
|
||||||
|
// its Common Name set to the file:line string of the call site. Calls to
|
||||||
|
// NewClientCert() on different lines will have different subjects. If
|
||||||
|
// multiple certs with the same subject are needed, wrap the call to this
|
||||||
|
// function in another helper function, or separate calls with commas on the
|
||||||
|
// same line.
|
||||||
|
NewClientCert(templateOverrides ...*x509.Certificate) *Certificate
|
||||||
|
|
||||||
// Add adds the given [Modifier] to the environment. All modifiers will be
|
// Add adds the given [Modifier] to the environment. All modifiers will be
|
||||||
// invoked upon calling Start() to apply individual modifications to the
|
// invoked upon calling Start() to apply individual modifications to the
|
||||||
|
@ -80,8 +112,22 @@ type Environment interface {
|
||||||
NewLogRecorder(opts ...LogRecorderOption) *LogRecorder
|
NewLogRecorder(opts ...LogRecorderOption) *LogRecorder
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type Certificate tls.Certificate
|
||||||
|
|
||||||
|
func (c *Certificate) Fingerprint() string {
|
||||||
|
sum := sha256.Sum256(c.Leaf.Raw)
|
||||||
|
return hex.EncodeToString(sum[:])
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Certificate) SPKIHash() string {
|
||||||
|
sum := sha256.Sum256(c.Leaf.RawSubjectPublicKeyInfo)
|
||||||
|
return base64.StdEncoding.EncodeToString(sum[:])
|
||||||
|
}
|
||||||
|
|
||||||
type environment struct {
|
type environment struct {
|
||||||
t testing.TB
|
t testing.TB
|
||||||
|
assert *assert.Assertions
|
||||||
|
require *require.Assertions
|
||||||
tempDir string
|
tempDir string
|
||||||
domain string
|
domain string
|
||||||
ports Ports
|
ports Ports
|
||||||
|
@ -125,6 +171,8 @@ func New(t testing.TB) Environment {
|
||||||
|
|
||||||
e := &environment{
|
e := &environment{
|
||||||
t: t,
|
t: t,
|
||||||
|
assert: assert.New(t),
|
||||||
|
require: require.New(t),
|
||||||
tempDir: t.TempDir(),
|
tempDir: t.TempDir(),
|
||||||
ports: Ports{
|
ports: Ports{
|
||||||
http: values.Deferred[int](),
|
http: values.Deferred[int](),
|
||||||
|
@ -141,13 +189,14 @@ func New(t testing.TB) Environment {
|
||||||
copyFile := func(src, dstRel string) {
|
copyFile := func(src, dstRel string) {
|
||||||
data, err := os.ReadFile(src)
|
data, err := os.ReadFile(src)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.NoError(t, os.WriteFile(filepath.Join(e.tempDir, dstRel), data, 0o666))
|
require.NoError(t, os.WriteFile(filepath.Join(e.tempDir, dstRel), data, 0o600))
|
||||||
}
|
}
|
||||||
|
|
||||||
certsToCopy := []string{
|
certsToCopy := []string{
|
||||||
"trusted.pem",
|
"trusted.pem",
|
||||||
"trusted-key.pem",
|
"trusted-key.pem",
|
||||||
"ca.pem",
|
"ca.pem",
|
||||||
|
"ca-key.pem",
|
||||||
}
|
}
|
||||||
for _, crt := range certsToCopy {
|
for _, crt := range certsToCopy {
|
||||||
copyFile(filepath.Join(workspaceFolder, "integration/tpl/files", crt), filepath.Join("certs/", filepath.Base(crt)))
|
copyFile(filepath.Join(workspaceFolder, "integration/tpl/files", crt), filepath.Join("certs/", filepath.Base(crt)))
|
||||||
|
@ -174,12 +223,29 @@ func (e *environment) Context() context.Context {
|
||||||
return ContextWithEnv(e.ctx, e)
|
return ContextWithEnv(e.ctx, e)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (e *environment) Assert() *assert.Assertions {
|
||||||
|
return e.assert
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *environment) Require() *require.Assertions {
|
||||||
|
return e.require
|
||||||
|
}
|
||||||
|
|
||||||
func (e *environment) SubdomainURL(subdomain string) values.Value[string] {
|
func (e *environment) SubdomainURL(subdomain string) values.Value[string] {
|
||||||
return values.Bind(e.ports.http, func(port int) string {
|
return values.Bind(e.ports.http, func(port int) string {
|
||||||
return fmt.Sprintf("https://%s.%s:%d", subdomain, e.domain, port)
|
return fmt.Sprintf("https://%s.%s:%d", subdomain, e.domain, port)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (e *environment) CACert() *tls.Certificate {
|
||||||
|
caCert, err := tls.LoadX509KeyPair(
|
||||||
|
filepath.Join(e.tempDir, "certs", "ca.pem"),
|
||||||
|
filepath.Join(e.tempDir, "certs", "ca-key.pem"),
|
||||||
|
)
|
||||||
|
require.NoError(e.t, err)
|
||||||
|
return &caCert
|
||||||
|
}
|
||||||
|
|
||||||
func (e *environment) ServerCAs() *x509.CertPool {
|
func (e *environment) ServerCAs() *x509.CertPool {
|
||||||
pool := x509.NewCertPool()
|
pool := x509.NewCertPool()
|
||||||
caCert, err := os.ReadFile(filepath.Join(e.tempDir, "certs", "ca.pem"))
|
caCert, err := os.ReadFile(filepath.Join(e.tempDir, "certs", "ca.pem"))
|
||||||
|
@ -213,20 +279,39 @@ func (e *environment) Start() {
|
||||||
require.NoError(e.t, err)
|
require.NoError(e.t, err)
|
||||||
port0, _ := strconv.Atoi(ports[0])
|
port0, _ := strconv.Atoi(ports[0])
|
||||||
e.ports.http.Resolve(port0)
|
e.ports.http.Resolve(port0)
|
||||||
|
cfg.Options.AutocertOptions = config.AutocertOptions{Enable: false}
|
||||||
cfg.Options.LogLevel = config.LogLevelInfo
|
cfg.Options.LogLevel = config.LogLevelInfo
|
||||||
cfg.Options.ProxyLogLevel = config.LogLevelInfo
|
cfg.Options.ProxyLogLevel = config.LogLevelInfo
|
||||||
cfg.Options.Addr = fmt.Sprintf("127.0.0.1:%d", port0)
|
cfg.Options.Addr = fmt.Sprintf("127.0.0.1:%d", port0)
|
||||||
cfg.Options.CertFile = filepath.Join(e.tempDir, "certs", "trusted.pem")
|
cfg.Options.CertFile = filepath.Join(e.tempDir, "certs", "trusted.pem")
|
||||||
cfg.Options.KeyFile = filepath.Join(e.tempDir, "certs", "trusted-key.pem")
|
cfg.Options.KeyFile = filepath.Join(e.tempDir, "certs", "trusted-key.pem")
|
||||||
|
cfg.Options.AccessLogFields = []log.AccessLogField{
|
||||||
|
log.AccessLogFieldAuthority,
|
||||||
|
log.AccessLogFieldDuration,
|
||||||
|
log.AccessLogFieldForwardedFor,
|
||||||
|
log.AccessLogFieldIP,
|
||||||
|
log.AccessLogFieldMethod,
|
||||||
|
log.AccessLogFieldPath,
|
||||||
|
log.AccessLogFieldQuery,
|
||||||
|
log.AccessLogFieldReferer,
|
||||||
|
log.AccessLogFieldRequestID,
|
||||||
|
log.AccessLogFieldResponseCode,
|
||||||
|
log.AccessLogFieldResponseCodeDetails,
|
||||||
|
log.AccessLogFieldSize,
|
||||||
|
log.AccessLogFieldUpstreamCluster,
|
||||||
|
log.AccessLogFieldUserAgent,
|
||||||
|
log.AccessLogFieldClientCertificate,
|
||||||
|
}
|
||||||
cfg.AllocatePorts(*(*[6]string)(ports[1:]))
|
cfg.AllocatePorts(*(*[6]string)(ports[1:]))
|
||||||
|
|
||||||
e.AddTask(TaskFunc(func(ctx context.Context) error {
|
e.AddTask(TaskFunc(func(ctx context.Context) error {
|
||||||
|
fileMgr := filemgr.NewManager(filemgr.WithCacheDir(filepath.Join(e.TempDir(), "cache")))
|
||||||
src := config.NewStaticSource(cfg)
|
src := config.NewStaticSource(cfg)
|
||||||
for _, mod := range e.mods {
|
for _, mod := range e.mods {
|
||||||
mod.Value.Modify(cfg)
|
mod.Value.Modify(cfg)
|
||||||
require.NoError(e.t, cfg.Options.Validate(), "invoking modifier resulted in an invalid configuration:\nadded by: "+mod.Caller)
|
require.NoError(e.t, cfg.Options.Validate(), "invoking modifier resulted in an invalid configuration:\nadded by: "+mod.Caller)
|
||||||
}
|
}
|
||||||
return pomerium.Run(e.ctx, src)
|
return pomerium.Run(e.ctx, src, pomerium.WithOverrideFileManager(fileMgr))
|
||||||
}))
|
}))
|
||||||
|
|
||||||
for i, task := range e.tasks {
|
for i, task := range e.tasks {
|
||||||
|
@ -238,6 +323,64 @@ func (e *environment) Start() {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (e *environment) NewClientCert(templateOverrides ...*x509.Certificate) *Certificate {
|
||||||
|
caCert := e.CACert()
|
||||||
|
|
||||||
|
priv, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||||
|
require.NoError(e.t, err)
|
||||||
|
|
||||||
|
sn, err := rand.Int(rand.Reader, new(big.Int).Lsh(big.NewInt(1), 128))
|
||||||
|
require.NoError(e.t, err)
|
||||||
|
now := time.Now()
|
||||||
|
tmpl := &x509.Certificate{
|
||||||
|
SerialNumber: sn,
|
||||||
|
Subject: pkix.Name{
|
||||||
|
CommonName: getCaller(),
|
||||||
|
},
|
||||||
|
NotBefore: now,
|
||||||
|
NotAfter: now.Add(12 * time.Hour),
|
||||||
|
KeyUsage: x509.KeyUsageDigitalSignature,
|
||||||
|
ExtKeyUsage: []x509.ExtKeyUsage{
|
||||||
|
x509.ExtKeyUsageClientAuth,
|
||||||
|
},
|
||||||
|
BasicConstraintsValid: true,
|
||||||
|
}
|
||||||
|
for _, override := range templateOverrides {
|
||||||
|
tmpl.CRLDistributionPoints = slices.Unique(append(tmpl.CRLDistributionPoints, override.CRLDistributionPoints...))
|
||||||
|
tmpl.DNSNames = slices.Unique(append(tmpl.DNSNames, override.DNSNames...))
|
||||||
|
tmpl.EmailAddresses = slices.Unique(append(tmpl.EmailAddresses, override.EmailAddresses...))
|
||||||
|
tmpl.ExtraExtensions = append(tmpl.ExtraExtensions, override.ExtraExtensions...)
|
||||||
|
tmpl.IPAddresses = slices.UniqueBy(append(tmpl.IPAddresses, override.IPAddresses...), net.IP.String)
|
||||||
|
tmpl.URIs = slices.UniqueBy(append(tmpl.URIs, override.URIs...), (*url.URL).String)
|
||||||
|
tmpl.UnknownExtKeyUsage = slices.UniqueBy(append(tmpl.UnknownExtKeyUsage, override.UnknownExtKeyUsage...), asn1.ObjectIdentifier.String)
|
||||||
|
seq := override.Subject.ToRDNSequence()
|
||||||
|
tmpl.Subject.FillFromRDNSequence(&seq)
|
||||||
|
tmpl.KeyUsage |= override.KeyUsage
|
||||||
|
tmpl.ExtKeyUsage = slices.Unique(append(tmpl.ExtKeyUsage, override.ExtKeyUsage...))
|
||||||
|
}
|
||||||
|
|
||||||
|
clientCertDER, err := x509.CreateCertificate(rand.Reader, tmpl, caCert.Leaf, priv.Public(), caCert.PrivateKey)
|
||||||
|
require.NoError(e.t, err)
|
||||||
|
|
||||||
|
cert, err := x509.ParseCertificate(clientCertDER)
|
||||||
|
require.NoError(e.t, err)
|
||||||
|
|
||||||
|
clientCert := &tls.Certificate{
|
||||||
|
Certificate: [][]byte{cert.Raw, caCert.Leaf.Raw},
|
||||||
|
PrivateKey: priv,
|
||||||
|
Leaf: cert,
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = clientCert.Leaf.Verify(x509.VerifyOptions{
|
||||||
|
KeyUsages: []x509.ExtKeyUsage{
|
||||||
|
x509.ExtKeyUsageClientAuth,
|
||||||
|
},
|
||||||
|
Roots: e.ServerCAs(),
|
||||||
|
})
|
||||||
|
require.NoError(e.t, err, "bug: generated client cert is not valid")
|
||||||
|
return (*Certificate)(clientCert)
|
||||||
|
}
|
||||||
|
|
||||||
func (e *environment) Stop() {
|
func (e *environment) Stop() {
|
||||||
e.cleanupOnce.Do(func() {
|
e.cleanupOnce.Do(func() {
|
||||||
e.cancel(ErrCauseManualStop)
|
e.cancel(ErrCauseManualStop)
|
||||||
|
|
|
@ -4,8 +4,11 @@ import (
|
||||||
"bufio"
|
"bufio"
|
||||||
"bytes"
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
|
"crypto/tls"
|
||||||
|
"crypto/x509"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"reflect"
|
||||||
"sync"
|
"sync"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
@ -115,9 +118,14 @@ func (lr *LogRecorder) Logs() []map[string]any {
|
||||||
// Match stops the log recorder (if it is not already stopped), then asserts
|
// Match stops the log recorder (if it is not already stopped), then asserts
|
||||||
// that the given expected logs were captured. The expected logs may contain
|
// that the given expected logs were captured. The expected logs may contain
|
||||||
// partial or complete log entries. By default, logs must only match the fields
|
// partial or complete log entries. By default, logs must only match the fields
|
||||||
// given, and may contain additional fields that will be ignored. For details,
|
// given, and may contain additional fields that will be ignored.
|
||||||
// see [OpenMap] and [ClosedMap]. As a special case, using [json.Number] as the
|
//
|
||||||
// expected value will convert the actual value to a string before comparison.
|
// There are several special-case value types that can be used to customize the
|
||||||
|
// matching behavior, and/or simplify some common use cases, as follows:
|
||||||
|
// - [OpenMap] and [ClosedMap] can be used to control matching logic
|
||||||
|
// - [json.Number] will convert the actual value to a string before comparison
|
||||||
|
// - [*tls.Certificate] or [*x509.Certificate] will expand to the fields that
|
||||||
|
// would be logged for this certificate
|
||||||
func (lr *LogRecorder) Match(expectedLogs []map[string]any) {
|
func (lr *LogRecorder) Match(expectedLogs []map[string]any) {
|
||||||
lr.collectLogs()
|
lr.collectLogs()
|
||||||
var match func(expected, actual map[string]any, open bool) (bool, int)
|
var match func(expected, actual map[string]any, open bool) (bool, int)
|
||||||
|
@ -132,15 +140,50 @@ func (lr *LogRecorder) Match(expectedLogs []map[string]any) {
|
||||||
|
|
||||||
switch actualValue := actualValue.(type) {
|
switch actualValue := actualValue.(type) {
|
||||||
case map[string]any:
|
case map[string]any:
|
||||||
switch value := value.(type) {
|
switch expectedValue := value.(type) {
|
||||||
case ClosedMap:
|
case ClosedMap:
|
||||||
ok, s := match(value, actualValue, false)
|
ok, s := match(expectedValue, actualValue, false)
|
||||||
score += s * 2
|
score += s * 2
|
||||||
if !ok {
|
if !ok {
|
||||||
return false, score
|
return false, score
|
||||||
}
|
}
|
||||||
case OpenMap:
|
case OpenMap:
|
||||||
ok, s := match(value, actualValue, true)
|
ok, s := match(expectedValue, actualValue, true)
|
||||||
|
score += s
|
||||||
|
if !ok {
|
||||||
|
return false, score
|
||||||
|
}
|
||||||
|
case *tls.Certificate, *Certificate, *x509.Certificate:
|
||||||
|
var leaf *x509.Certificate
|
||||||
|
switch expectedValue := expectedValue.(type) {
|
||||||
|
case *tls.Certificate:
|
||||||
|
leaf = expectedValue.Leaf
|
||||||
|
case *Certificate:
|
||||||
|
leaf = expectedValue.Leaf
|
||||||
|
case *x509.Certificate:
|
||||||
|
leaf = expectedValue
|
||||||
|
}
|
||||||
|
|
||||||
|
// keep logic consistent with controlplane.populateCertEventDict()
|
||||||
|
expected := map[string]any{}
|
||||||
|
if iss := leaf.Issuer.String(); iss != "" {
|
||||||
|
expected["issuer"] = iss
|
||||||
|
}
|
||||||
|
if sub := leaf.Subject.String(); sub != "" {
|
||||||
|
expected["subject"] = sub
|
||||||
|
}
|
||||||
|
sans := []string{}
|
||||||
|
for _, dnsSAN := range leaf.DNSNames {
|
||||||
|
sans = append(sans, "DNS:"+dnsSAN)
|
||||||
|
}
|
||||||
|
for _, uriSAN := range leaf.URIs {
|
||||||
|
sans = append(sans, "URI:"+uriSAN.String())
|
||||||
|
}
|
||||||
|
if len(sans) > 0 {
|
||||||
|
expected["subjectAltName"] = sans
|
||||||
|
}
|
||||||
|
|
||||||
|
ok, s := match(expected, actualValue, false)
|
||||||
score += s
|
score += s
|
||||||
if !ok {
|
if !ok {
|
||||||
return false, score
|
return false, score
|
||||||
|
@ -164,9 +207,25 @@ func (lr *LogRecorder) Match(expectedLogs []map[string]any) {
|
||||||
}
|
}
|
||||||
score++
|
score++
|
||||||
default:
|
default:
|
||||||
|
// handle slices
|
||||||
|
if reflect.TypeOf(actualValue).Kind() == reflect.Slice {
|
||||||
|
if reflect.TypeOf(value) != reflect.TypeOf(actualValue) {
|
||||||
|
return false, score
|
||||||
|
}
|
||||||
|
actualSlice := reflect.ValueOf(actualValue)
|
||||||
|
expectedSlice := reflect.ValueOf(value)
|
||||||
|
totalScore := 0
|
||||||
|
for i := range min(actualSlice.Len(), expectedSlice.Len()) {
|
||||||
|
if actualSlice.Index(i).Equal(expectedSlice.Index(i)) {
|
||||||
|
totalScore++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
score += totalScore
|
||||||
|
} else {
|
||||||
panic(fmt.Sprintf("test bug: add check for type %T in assertMatchingLogs", actualValue))
|
panic(fmt.Sprintf("test bug: add check for type %T in assertMatchingLogs", actualValue))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
if !open && len(expected) != len(actual) {
|
if !open && len(expected) != len(actual) {
|
||||||
return false, score
|
return false, score
|
||||||
}
|
}
|
||||||
|
|
|
@ -2,9 +2,11 @@ package testenv
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"net/url"
|
"net/url"
|
||||||
|
"strings"
|
||||||
|
|
||||||
"github.com/pomerium/pomerium/config"
|
"github.com/pomerium/pomerium/config"
|
||||||
"github.com/pomerium/pomerium/internal/testenv/values"
|
"github.com/pomerium/pomerium/internal/testenv/values"
|
||||||
|
"github.com/pomerium/pomerium/pkg/policy/parser"
|
||||||
)
|
)
|
||||||
|
|
||||||
// PolicyRoute is a [Route] implementation suitable for most common use cases
|
// PolicyRoute is a [Route] implementation suitable for most common use cases
|
||||||
|
@ -54,6 +56,20 @@ func (b *PolicyRoute) Policy(edit func(*config.Policy)) Route {
|
||||||
return b
|
return b
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// PPL implements Route.
|
||||||
|
func (b *PolicyRoute) PPL(ppl string) Route {
|
||||||
|
pplPolicy, err := parser.ParseYAML(strings.NewReader(ppl))
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
b.edits = append(b.edits, func(p *config.Policy) {
|
||||||
|
p.Policy = &config.PPLPolicy{
|
||||||
|
Policy: pplPolicy,
|
||||||
|
}
|
||||||
|
})
|
||||||
|
return b
|
||||||
|
}
|
||||||
|
|
||||||
// To implements Route.
|
// To implements Route.
|
||||||
func (b *PolicyRoute) URL() values.Value[string] {
|
func (b *PolicyRoute) URL() values.Value[string] {
|
||||||
return b.from
|
return b.from
|
||||||
|
|
24
internal/testenv/scenarios/mtls.go
Normal file
24
internal/testenv/scenarios/mtls.go
Normal file
|
@ -0,0 +1,24 @@
|
||||||
|
package scenarios
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/base64"
|
||||||
|
"encoding/pem"
|
||||||
|
|
||||||
|
"github.com/pomerium/pomerium/config"
|
||||||
|
"github.com/pomerium/pomerium/internal/testenv"
|
||||||
|
)
|
||||||
|
|
||||||
|
func DownstreamMTLS(mode config.MTLSEnforcement) testenv.Modifier {
|
||||||
|
return testenv.ModifierFunc(func(ctx context.Context, cfg *config.Config) {
|
||||||
|
env := testenv.EnvFromContext(ctx)
|
||||||
|
block := pem.Block{
|
||||||
|
Type: "CERTIFICATE",
|
||||||
|
Bytes: env.CACert().Leaf.Raw,
|
||||||
|
}
|
||||||
|
cfg.Options.DownstreamMTLS = config.DownstreamMTLSSettings{
|
||||||
|
CA: base64.StdEncoding.EncodeToString(pem.EncodeToMemory(&block)),
|
||||||
|
Enforcement: mode,
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
|
@ -77,10 +77,24 @@ func (m Modifiers) Modify(cfg *config.Config) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
type ModifierFunc func(cfg *config.Config)
|
type modifierFunc struct {
|
||||||
|
fn func(ctx context.Context, cfg *config.Config)
|
||||||
|
ctx context.Context
|
||||||
|
}
|
||||||
|
|
||||||
func (f ModifierFunc) Modify(cfg *config.Config) {
|
// Attach implements Modifier.
|
||||||
f(cfg)
|
func (f *modifierFunc) Attach(ctx context.Context) {
|
||||||
|
f.ctx = ctx
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *modifierFunc) Modify(cfg *config.Config) {
|
||||||
|
f.fn(f.ctx, cfg)
|
||||||
|
}
|
||||||
|
|
||||||
|
var _ Modifier = (*modifierFunc)(nil)
|
||||||
|
|
||||||
|
func ModifierFunc(fn func(ctx context.Context, cfg *config.Config)) Modifier {
|
||||||
|
return &modifierFunc{fn: fn}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Task represents a background task that can be added to an [Environment] to
|
// Task represents a background task that can be added to an [Environment] to
|
||||||
|
@ -116,6 +130,7 @@ type Route interface {
|
||||||
URL() values.Value[string]
|
URL() values.Value[string]
|
||||||
To(toUrl values.Value[string]) Route
|
To(toUrl values.Value[string]) Route
|
||||||
Policy(edit func(*config.Policy)) Route
|
Policy(edit func(*config.Policy)) Route
|
||||||
|
PPL(ppl string) Route
|
||||||
// add more methods here as they become needed
|
// add more methods here as they become needed
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -77,9 +77,11 @@ func Body(body any) RequestOption {
|
||||||
}
|
}
|
||||||
|
|
||||||
// ClientCert adds a client certificate to the request.
|
// ClientCert adds a client certificate to the request.
|
||||||
func ClientCert(cert tls.Certificate) RequestOption {
|
func ClientCert[T interface {
|
||||||
|
*testenv.Certificate | *tls.Certificate
|
||||||
|
}](cert T) RequestOption {
|
||||||
return func(o *RequestOptions) {
|
return func(o *RequestOptions) {
|
||||||
o.clientCerts = append(o.clientCerts, cert)
|
o.clientCerts = append(o.clientCerts, *(*tls.Certificate)(cert))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -250,6 +252,7 @@ func (h *httpUpstream) Do(method string, r testenv.Route, opts ...RequestOption)
|
||||||
}
|
}
|
||||||
return retry.NewTerminalError(err)
|
return retry.NewTerminalError(err)
|
||||||
}
|
}
|
||||||
|
resp.Body.Close()
|
||||||
return nil
|
return nil
|
||||||
}, retry.WithMaxInterval(1*time.Second)); err != nil {
|
}, retry.WithMaxInterval(1*time.Second)); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
|
|
@ -16,6 +16,7 @@ import (
|
||||||
"github.com/pomerium/pomerium/authenticate"
|
"github.com/pomerium/pomerium/authenticate"
|
||||||
"github.com/pomerium/pomerium/authorize"
|
"github.com/pomerium/pomerium/authorize"
|
||||||
"github.com/pomerium/pomerium/config"
|
"github.com/pomerium/pomerium/config"
|
||||||
|
"github.com/pomerium/pomerium/config/envoyconfig/filemgr"
|
||||||
databroker_service "github.com/pomerium/pomerium/databroker"
|
databroker_service "github.com/pomerium/pomerium/databroker"
|
||||||
"github.com/pomerium/pomerium/internal/autocert"
|
"github.com/pomerium/pomerium/internal/autocert"
|
||||||
"github.com/pomerium/pomerium/internal/controlplane"
|
"github.com/pomerium/pomerium/internal/controlplane"
|
||||||
|
@ -30,8 +31,29 @@ import (
|
||||||
"github.com/pomerium/pomerium/proxy"
|
"github.com/pomerium/pomerium/proxy"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type RunOptions struct {
|
||||||
|
fileMgr *filemgr.Manager
|
||||||
|
}
|
||||||
|
|
||||||
|
type RunOption func(*RunOptions)
|
||||||
|
|
||||||
|
func (o *RunOptions) apply(opts ...RunOption) {
|
||||||
|
for _, op := range opts {
|
||||||
|
op(o)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func WithOverrideFileManager(fileMgr *filemgr.Manager) RunOption {
|
||||||
|
return func(o *RunOptions) {
|
||||||
|
o.fileMgr = fileMgr
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Run runs the main pomerium application.
|
// Run runs the main pomerium application.
|
||||||
func Run(ctx context.Context, src config.Source) error {
|
func Run(ctx context.Context, src config.Source, opts ...RunOption) error {
|
||||||
|
options := RunOptions{}
|
||||||
|
options.apply(opts...)
|
||||||
|
|
||||||
_, _ = maxprocs.Set(maxprocs.Logger(func(s string, i ...any) { log.Debug(context.Background()).Msgf(s, i...) }))
|
_, _ = maxprocs.Set(maxprocs.Logger(func(s string, i ...any) { log.Debug(context.Background()).Msgf(s, i...) }))
|
||||||
|
|
||||||
evt := log.Info(ctx).
|
evt := log.Info(ctx).
|
||||||
|
@ -68,10 +90,15 @@ func Run(ctx context.Context, src config.Source) error {
|
||||||
|
|
||||||
eventsMgr := events.New()
|
eventsMgr := events.New()
|
||||||
|
|
||||||
|
fileMgr := options.fileMgr
|
||||||
|
if fileMgr == nil {
|
||||||
|
fileMgr = filemgr.NewManager()
|
||||||
|
}
|
||||||
|
|
||||||
cfg := src.GetConfig()
|
cfg := src.GetConfig()
|
||||||
|
|
||||||
// setup the control plane
|
// setup the control plane
|
||||||
controlPlane, err := controlplane.NewServer(ctx, cfg, metricsMgr, eventsMgr)
|
controlPlane, err := controlplane.NewServer(ctx, cfg, metricsMgr, eventsMgr, fileMgr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("error creating control plane: %w", err)
|
return fmt.Errorf("error creating control plane: %w", err)
|
||||||
}
|
}
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue