Add client certificate utilities

This commit is contained in:
Joe Kralicky 2024-08-26 20:00:18 -04:00
parent 6ed5752fa5
commit 490a301aa4
No known key found for this signature in database
GPG key ID: 75C4875F34A9FB79
10 changed files with 355 additions and 20 deletions

View file

@ -21,6 +21,7 @@ import (
"github.com/pomerium/pomerium/config"
"github.com/pomerium/pomerium/internal/testenv"
"github.com/pomerium/pomerium/internal/testenv/scenarios"
"github.com/pomerium/pomerium/internal/testenv/upstreams"
"github.com/pomerium/pomerium/internal/testenv/values"
"github.com/pomerium/pomerium/pkg/cmd/pomerium"
@ -124,6 +125,46 @@ func TestHTTP(t *testing.T) {
})
}
func TestClientCert(t *testing.T) {
env := testenv.New(t)
env.Add(scenarios.DownstreamMTLS(config.MTLSEnforcementRejectConnection))
up := upstreams.HTTP(nil)
up.Handle("/foo", func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintln(w, "hello world")
})
clientCert := env.NewClientCert()
route := up.Route().
From(env.SubdomainURL("http")).
PPL(fmt.Sprintf(`{"allow":{"and":["client_certificate":{"fingerprint":%q}]}}`, clientCert.Fingerprint()))
env.AddUpstream(up)
env.Start()
recorder := env.NewLogRecorder()
resp, err := up.Get(route, upstreams.Path("/foo"), upstreams.ClientCert(clientCert))
require.NoError(t, err)
defer resp.Body.Close()
data, err := io.ReadAll(resp.Body)
require.NoError(t, err)
assert.Equal(t, "hello world\n", string(data))
recorder.Match([]map[string]any{
{
"service": "envoy",
"path": "/foo",
"method": "GET",
"message": "http-request",
"response-code-details": "via_upstream",
"client-certificate": clientCert,
},
})
}
func TestH2C(t *testing.T) {
if testing.Short() {
t.SkipNow()

View file

@ -69,10 +69,17 @@ type Server struct {
}
// NewServer creates a new Server. Listener ports are chosen by the OS.
func NewServer(ctx context.Context, cfg *config.Config, metricsMgr *config.MetricsManager, eventsMgr *events.Manager) (*Server, error) {
func NewServer(
ctx context.Context,
cfg *config.Config,
metricsMgr *config.MetricsManager,
eventsMgr *events.Manager,
fileMgr *filemgr.Manager,
) (*Server, error) {
srv := &Server{
metricsMgr: metricsMgr,
EventsMgr: eventsMgr,
filemgr: fileMgr,
reproxy: reproxy.New(),
haveSetCapacity: map[string]bool{},
updateConfig: make(chan *config.Config, 1),
@ -149,7 +156,6 @@ func NewServer(ctx context.Context, cfg *config.Config, metricsMgr *config.Metri
// metrics
srv.MetricsRouter.Handle("/metrics", srv.metricsMgr)
srv.filemgr = filemgr.NewManager()
srv.filemgr.ClearCache()
srv.Builder = envoyconfig.New(

View file

@ -12,6 +12,7 @@ import (
"github.com/stretchr/testify/require"
"github.com/pomerium/pomerium/config"
"github.com/pomerium/pomerium/config/envoyconfig/filemgr"
"github.com/pomerium/pomerium/internal/events"
"github.com/pomerium/pomerium/pkg/netutil"
)
@ -38,7 +39,7 @@ func TestServerHTTP(t *testing.T) {
cfg.Options.SharedKey = "JDNjY2ITDlARvNaQXjc2Djk+GA6xeCy4KiozmZfdbTs="
src := config.NewStaticSource(cfg)
srv, err := NewServer(ctx, cfg, config.NewMetricsManager(ctx, src), events.New())
srv, err := NewServer(ctx, cfg, config.NewMetricsManager(ctx, src), events.New(), filemgr.NewManager(filemgr.WithCacheDir(t.TempDir())))
require.NoError(t, err)
go srv.Run(ctx)

View file

@ -2,10 +2,20 @@ package testenv
import (
"context"
"crypto/rand"
"crypto/rsa"
"crypto/sha256"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"encoding/asn1"
"encoding/base64"
"encoding/hex"
"errors"
"fmt"
"math/big"
"net"
"net/url"
"os"
"path"
"path/filepath"
@ -13,13 +23,16 @@ import (
"strconv"
"sync"
"testing"
"time"
"github.com/pomerium/pomerium/config"
"github.com/pomerium/pomerium/config/envoyconfig/filemgr"
"github.com/pomerium/pomerium/internal/log"
"github.com/pomerium/pomerium/internal/testenv/values"
"github.com/pomerium/pomerium/pkg/cmd/pomerium"
"github.com/pomerium/pomerium/pkg/health"
"github.com/pomerium/pomerium/pkg/netutil"
"github.com/pomerium/pomerium/pkg/slices"
"github.com/rs/zerolog"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
@ -33,14 +46,33 @@ type Environment interface {
// top-level logger scoped to this environment. It will be canceled when
// Stop() is called, or during test cleanup.
Context() context.Context
Assert() *assert.Assertions
Require() *require.Assertions
// TempDir returns a unique temp directory for this context. Calling this
// function multiple times returns the same path.
TempDir() string
// CACert returns the test environment's root CA certificate and private key.
CACert() *tls.Certificate
// ServerCAs returns a new [*x509.CertPool] containing the root CA certificate
// used to sign the server cert and other test certificates.
ServerCAs() *x509.CertPool
// ServerCert returns the Pomerium server's certificate.
// ServerCert returns the Pomerium server's certificate and private key.
ServerCert() *tls.Certificate
// NewClientCert generates a new client certificate signed by the root CA
// certificate. One or more optional templates can be given, which can be
// used to set or override certain parameters when creating a certificate,
// including subject, SANs, or extensions. If more than one template is
// provided, they will be applied in order from left to right.
//
// By default (unless overridden in a template), the certificate will have
// its Common Name set to the file:line string of the call site. Calls to
// NewClientCert() on different lines will have different subjects. If
// multiple certs with the same subject are needed, wrap the call to this
// function in another helper function, or separate calls with commas on the
// same line.
NewClientCert(templateOverrides ...*x509.Certificate) *Certificate
// Add adds the given [Modifier] to the environment. All modifiers will be
// invoked upon calling Start() to apply individual modifications to the
@ -80,8 +112,22 @@ type Environment interface {
NewLogRecorder(opts ...LogRecorderOption) *LogRecorder
}
type Certificate tls.Certificate
func (c *Certificate) Fingerprint() string {
sum := sha256.Sum256(c.Leaf.Raw)
return hex.EncodeToString(sum[:])
}
func (c *Certificate) SPKIHash() string {
sum := sha256.Sum256(c.Leaf.RawSubjectPublicKeyInfo)
return base64.StdEncoding.EncodeToString(sum[:])
}
type environment struct {
t testing.TB
assert *assert.Assertions
require *require.Assertions
tempDir string
domain string
ports Ports
@ -125,6 +171,8 @@ func New(t testing.TB) Environment {
e := &environment{
t: t,
assert: assert.New(t),
require: require.New(t),
tempDir: t.TempDir(),
ports: Ports{
http: values.Deferred[int](),
@ -141,13 +189,14 @@ func New(t testing.TB) Environment {
copyFile := func(src, dstRel string) {
data, err := os.ReadFile(src)
require.NoError(t, err)
require.NoError(t, os.WriteFile(filepath.Join(e.tempDir, dstRel), data, 0o666))
require.NoError(t, os.WriteFile(filepath.Join(e.tempDir, dstRel), data, 0o600))
}
certsToCopy := []string{
"trusted.pem",
"trusted-key.pem",
"ca.pem",
"ca-key.pem",
}
for _, crt := range certsToCopy {
copyFile(filepath.Join(workspaceFolder, "integration/tpl/files", crt), filepath.Join("certs/", filepath.Base(crt)))
@ -174,12 +223,29 @@ func (e *environment) Context() context.Context {
return ContextWithEnv(e.ctx, e)
}
func (e *environment) Assert() *assert.Assertions {
return e.assert
}
func (e *environment) Require() *require.Assertions {
return e.require
}
func (e *environment) SubdomainURL(subdomain string) values.Value[string] {
return values.Bind(e.ports.http, func(port int) string {
return fmt.Sprintf("https://%s.%s:%d", subdomain, e.domain, port)
})
}
func (e *environment) CACert() *tls.Certificate {
caCert, err := tls.LoadX509KeyPair(
filepath.Join(e.tempDir, "certs", "ca.pem"),
filepath.Join(e.tempDir, "certs", "ca-key.pem"),
)
require.NoError(e.t, err)
return &caCert
}
func (e *environment) ServerCAs() *x509.CertPool {
pool := x509.NewCertPool()
caCert, err := os.ReadFile(filepath.Join(e.tempDir, "certs", "ca.pem"))
@ -213,20 +279,39 @@ func (e *environment) Start() {
require.NoError(e.t, err)
port0, _ := strconv.Atoi(ports[0])
e.ports.http.Resolve(port0)
cfg.Options.AutocertOptions = config.AutocertOptions{Enable: false}
cfg.Options.LogLevel = config.LogLevelInfo
cfg.Options.ProxyLogLevel = config.LogLevelInfo
cfg.Options.Addr = fmt.Sprintf("127.0.0.1:%d", port0)
cfg.Options.CertFile = filepath.Join(e.tempDir, "certs", "trusted.pem")
cfg.Options.KeyFile = filepath.Join(e.tempDir, "certs", "trusted-key.pem")
cfg.Options.AccessLogFields = []log.AccessLogField{
log.AccessLogFieldAuthority,
log.AccessLogFieldDuration,
log.AccessLogFieldForwardedFor,
log.AccessLogFieldIP,
log.AccessLogFieldMethod,
log.AccessLogFieldPath,
log.AccessLogFieldQuery,
log.AccessLogFieldReferer,
log.AccessLogFieldRequestID,
log.AccessLogFieldResponseCode,
log.AccessLogFieldResponseCodeDetails,
log.AccessLogFieldSize,
log.AccessLogFieldUpstreamCluster,
log.AccessLogFieldUserAgent,
log.AccessLogFieldClientCertificate,
}
cfg.AllocatePorts(*(*[6]string)(ports[1:]))
e.AddTask(TaskFunc(func(ctx context.Context) error {
fileMgr := filemgr.NewManager(filemgr.WithCacheDir(filepath.Join(e.TempDir(), "cache")))
src := config.NewStaticSource(cfg)
for _, mod := range e.mods {
mod.Value.Modify(cfg)
require.NoError(e.t, cfg.Options.Validate(), "invoking modifier resulted in an invalid configuration:\nadded by: "+mod.Caller)
}
return pomerium.Run(e.ctx, src)
return pomerium.Run(e.ctx, src, pomerium.WithOverrideFileManager(fileMgr))
}))
for i, task := range e.tasks {
@ -238,6 +323,64 @@ func (e *environment) Start() {
}
}
func (e *environment) NewClientCert(templateOverrides ...*x509.Certificate) *Certificate {
caCert := e.CACert()
priv, err := rsa.GenerateKey(rand.Reader, 2048)
require.NoError(e.t, err)
sn, err := rand.Int(rand.Reader, new(big.Int).Lsh(big.NewInt(1), 128))
require.NoError(e.t, err)
now := time.Now()
tmpl := &x509.Certificate{
SerialNumber: sn,
Subject: pkix.Name{
CommonName: getCaller(),
},
NotBefore: now,
NotAfter: now.Add(12 * time.Hour),
KeyUsage: x509.KeyUsageDigitalSignature,
ExtKeyUsage: []x509.ExtKeyUsage{
x509.ExtKeyUsageClientAuth,
},
BasicConstraintsValid: true,
}
for _, override := range templateOverrides {
tmpl.CRLDistributionPoints = slices.Unique(append(tmpl.CRLDistributionPoints, override.CRLDistributionPoints...))
tmpl.DNSNames = slices.Unique(append(tmpl.DNSNames, override.DNSNames...))
tmpl.EmailAddresses = slices.Unique(append(tmpl.EmailAddresses, override.EmailAddresses...))
tmpl.ExtraExtensions = append(tmpl.ExtraExtensions, override.ExtraExtensions...)
tmpl.IPAddresses = slices.UniqueBy(append(tmpl.IPAddresses, override.IPAddresses...), net.IP.String)
tmpl.URIs = slices.UniqueBy(append(tmpl.URIs, override.URIs...), (*url.URL).String)
tmpl.UnknownExtKeyUsage = slices.UniqueBy(append(tmpl.UnknownExtKeyUsage, override.UnknownExtKeyUsage...), asn1.ObjectIdentifier.String)
seq := override.Subject.ToRDNSequence()
tmpl.Subject.FillFromRDNSequence(&seq)
tmpl.KeyUsage |= override.KeyUsage
tmpl.ExtKeyUsage = slices.Unique(append(tmpl.ExtKeyUsage, override.ExtKeyUsage...))
}
clientCertDER, err := x509.CreateCertificate(rand.Reader, tmpl, caCert.Leaf, priv.Public(), caCert.PrivateKey)
require.NoError(e.t, err)
cert, err := x509.ParseCertificate(clientCertDER)
require.NoError(e.t, err)
clientCert := &tls.Certificate{
Certificate: [][]byte{cert.Raw, caCert.Leaf.Raw},
PrivateKey: priv,
Leaf: cert,
}
_, err = clientCert.Leaf.Verify(x509.VerifyOptions{
KeyUsages: []x509.ExtKeyUsage{
x509.ExtKeyUsageClientAuth,
},
Roots: e.ServerCAs(),
})
require.NoError(e.t, err, "bug: generated client cert is not valid")
return (*Certificate)(clientCert)
}
func (e *environment) Stop() {
e.cleanupOnce.Do(func() {
e.cancel(ErrCauseManualStop)

View file

@ -4,8 +4,11 @@ import (
"bufio"
"bytes"
"context"
"crypto/tls"
"crypto/x509"
"encoding/json"
"fmt"
"reflect"
"sync"
"testing"
"time"
@ -115,9 +118,14 @@ func (lr *LogRecorder) Logs() []map[string]any {
// Match stops the log recorder (if it is not already stopped), then asserts
// that the given expected logs were captured. The expected logs may contain
// partial or complete log entries. By default, logs must only match the fields
// given, and may contain additional fields that will be ignored. For details,
// see [OpenMap] and [ClosedMap]. As a special case, using [json.Number] as the
// expected value will convert the actual value to a string before comparison.
// given, and may contain additional fields that will be ignored.
//
// There are several special-case value types that can be used to customize the
// matching behavior, and/or simplify some common use cases, as follows:
// - [OpenMap] and [ClosedMap] can be used to control matching logic
// - [json.Number] will convert the actual value to a string before comparison
// - [*tls.Certificate] or [*x509.Certificate] will expand to the fields that
// would be logged for this certificate
func (lr *LogRecorder) Match(expectedLogs []map[string]any) {
lr.collectLogs()
var match func(expected, actual map[string]any, open bool) (bool, int)
@ -132,15 +140,50 @@ func (lr *LogRecorder) Match(expectedLogs []map[string]any) {
switch actualValue := actualValue.(type) {
case map[string]any:
switch value := value.(type) {
switch expectedValue := value.(type) {
case ClosedMap:
ok, s := match(value, actualValue, false)
ok, s := match(expectedValue, actualValue, false)
score += s * 2
if !ok {
return false, score
}
case OpenMap:
ok, s := match(value, actualValue, true)
ok, s := match(expectedValue, actualValue, true)
score += s
if !ok {
return false, score
}
case *tls.Certificate, *Certificate, *x509.Certificate:
var leaf *x509.Certificate
switch expectedValue := expectedValue.(type) {
case *tls.Certificate:
leaf = expectedValue.Leaf
case *Certificate:
leaf = expectedValue.Leaf
case *x509.Certificate:
leaf = expectedValue
}
// keep logic consistent with controlplane.populateCertEventDict()
expected := map[string]any{}
if iss := leaf.Issuer.String(); iss != "" {
expected["issuer"] = iss
}
if sub := leaf.Subject.String(); sub != "" {
expected["subject"] = sub
}
sans := []string{}
for _, dnsSAN := range leaf.DNSNames {
sans = append(sans, "DNS:"+dnsSAN)
}
for _, uriSAN := range leaf.URIs {
sans = append(sans, "URI:"+uriSAN.String())
}
if len(sans) > 0 {
expected["subjectAltName"] = sans
}
ok, s := match(expected, actualValue, false)
score += s
if !ok {
return false, score
@ -164,7 +207,23 @@ func (lr *LogRecorder) Match(expectedLogs []map[string]any) {
}
score++
default:
panic(fmt.Sprintf("test bug: add check for type %T in assertMatchingLogs", actualValue))
// handle slices
if reflect.TypeOf(actualValue).Kind() == reflect.Slice {
if reflect.TypeOf(value) != reflect.TypeOf(actualValue) {
return false, score
}
actualSlice := reflect.ValueOf(actualValue)
expectedSlice := reflect.ValueOf(value)
totalScore := 0
for i := range min(actualSlice.Len(), expectedSlice.Len()) {
if actualSlice.Index(i).Equal(expectedSlice.Index(i)) {
totalScore++
}
}
score += totalScore
} else {
panic(fmt.Sprintf("test bug: add check for type %T in assertMatchingLogs", actualValue))
}
}
}
if !open && len(expected) != len(actual) {

View file

@ -2,9 +2,11 @@ package testenv
import (
"net/url"
"strings"
"github.com/pomerium/pomerium/config"
"github.com/pomerium/pomerium/internal/testenv/values"
"github.com/pomerium/pomerium/pkg/policy/parser"
)
// PolicyRoute is a [Route] implementation suitable for most common use cases
@ -54,6 +56,20 @@ func (b *PolicyRoute) Policy(edit func(*config.Policy)) Route {
return b
}
// PPL implements Route.
func (b *PolicyRoute) PPL(ppl string) Route {
pplPolicy, err := parser.ParseYAML(strings.NewReader(ppl))
if err != nil {
panic(err)
}
b.edits = append(b.edits, func(p *config.Policy) {
p.Policy = &config.PPLPolicy{
Policy: pplPolicy,
}
})
return b
}
// To implements Route.
func (b *PolicyRoute) URL() values.Value[string] {
return b.from

View file

@ -0,0 +1,24 @@
package scenarios
import (
"context"
"encoding/base64"
"encoding/pem"
"github.com/pomerium/pomerium/config"
"github.com/pomerium/pomerium/internal/testenv"
)
func DownstreamMTLS(mode config.MTLSEnforcement) testenv.Modifier {
return testenv.ModifierFunc(func(ctx context.Context, cfg *config.Config) {
env := testenv.EnvFromContext(ctx)
block := pem.Block{
Type: "CERTIFICATE",
Bytes: env.CACert().Leaf.Raw,
}
cfg.Options.DownstreamMTLS = config.DownstreamMTLSSettings{
CA: base64.StdEncoding.EncodeToString(pem.EncodeToMemory(&block)),
Enforcement: mode,
}
})
}

View file

@ -77,10 +77,24 @@ func (m Modifiers) Modify(cfg *config.Config) {
}
}
type ModifierFunc func(cfg *config.Config)
type modifierFunc struct {
fn func(ctx context.Context, cfg *config.Config)
ctx context.Context
}
func (f ModifierFunc) Modify(cfg *config.Config) {
f(cfg)
// Attach implements Modifier.
func (f *modifierFunc) Attach(ctx context.Context) {
f.ctx = ctx
}
func (f *modifierFunc) Modify(cfg *config.Config) {
f.fn(f.ctx, cfg)
}
var _ Modifier = (*modifierFunc)(nil)
func ModifierFunc(fn func(ctx context.Context, cfg *config.Config)) Modifier {
return &modifierFunc{fn: fn}
}
// Task represents a background task that can be added to an [Environment] to
@ -116,6 +130,7 @@ type Route interface {
URL() values.Value[string]
To(toUrl values.Value[string]) Route
Policy(edit func(*config.Policy)) Route
PPL(ppl string) Route
// add more methods here as they become needed
}

View file

@ -77,9 +77,11 @@ func Body(body any) RequestOption {
}
// ClientCert adds a client certificate to the request.
func ClientCert(cert tls.Certificate) RequestOption {
func ClientCert[T interface {
*testenv.Certificate | *tls.Certificate
}](cert T) RequestOption {
return func(o *RequestOptions) {
o.clientCerts = append(o.clientCerts, cert)
o.clientCerts = append(o.clientCerts, *(*tls.Certificate)(cert))
}
}
@ -250,6 +252,7 @@ func (h *httpUpstream) Do(method string, r testenv.Route, opts ...RequestOption)
}
return retry.NewTerminalError(err)
}
resp.Body.Close()
return nil
}, retry.WithMaxInterval(1*time.Second)); err != nil {
return nil, err

View file

@ -16,6 +16,7 @@ import (
"github.com/pomerium/pomerium/authenticate"
"github.com/pomerium/pomerium/authorize"
"github.com/pomerium/pomerium/config"
"github.com/pomerium/pomerium/config/envoyconfig/filemgr"
databroker_service "github.com/pomerium/pomerium/databroker"
"github.com/pomerium/pomerium/internal/autocert"
"github.com/pomerium/pomerium/internal/controlplane"
@ -30,8 +31,29 @@ import (
"github.com/pomerium/pomerium/proxy"
)
type RunOptions struct {
fileMgr *filemgr.Manager
}
type RunOption func(*RunOptions)
func (o *RunOptions) apply(opts ...RunOption) {
for _, op := range opts {
op(o)
}
}
func WithOverrideFileManager(fileMgr *filemgr.Manager) RunOption {
return func(o *RunOptions) {
o.fileMgr = fileMgr
}
}
// Run runs the main pomerium application.
func Run(ctx context.Context, src config.Source) error {
func Run(ctx context.Context, src config.Source, opts ...RunOption) error {
options := RunOptions{}
options.apply(opts...)
_, _ = maxprocs.Set(maxprocs.Logger(func(s string, i ...any) { log.Debug(context.Background()).Msgf(s, i...) }))
evt := log.Info(ctx).
@ -68,10 +90,15 @@ func Run(ctx context.Context, src config.Source) error {
eventsMgr := events.New()
fileMgr := options.fileMgr
if fileMgr == nil {
fileMgr = filemgr.NewManager()
}
cfg := src.GetConfig()
// setup the control plane
controlPlane, err := controlplane.NewServer(ctx, cfg, metricsMgr, eventsMgr)
controlPlane, err := controlplane.NewServer(ctx, cfg, metricsMgr, eventsMgr, fileMgr)
if err != nil {
return fmt.Errorf("error creating control plane: %w", err)
}