Add client certificate utilities

This commit is contained in:
Joe Kralicky 2024-08-26 20:00:18 -04:00
parent 6ed5752fa5
commit 490a301aa4
No known key found for this signature in database
GPG key ID: 75C4875F34A9FB79
10 changed files with 355 additions and 20 deletions

View file

@ -2,10 +2,20 @@ package testenv
import (
"context"
"crypto/rand"
"crypto/rsa"
"crypto/sha256"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"encoding/asn1"
"encoding/base64"
"encoding/hex"
"errors"
"fmt"
"math/big"
"net"
"net/url"
"os"
"path"
"path/filepath"
@ -13,13 +23,16 @@ import (
"strconv"
"sync"
"testing"
"time"
"github.com/pomerium/pomerium/config"
"github.com/pomerium/pomerium/config/envoyconfig/filemgr"
"github.com/pomerium/pomerium/internal/log"
"github.com/pomerium/pomerium/internal/testenv/values"
"github.com/pomerium/pomerium/pkg/cmd/pomerium"
"github.com/pomerium/pomerium/pkg/health"
"github.com/pomerium/pomerium/pkg/netutil"
"github.com/pomerium/pomerium/pkg/slices"
"github.com/rs/zerolog"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
@ -33,14 +46,33 @@ type Environment interface {
// top-level logger scoped to this environment. It will be canceled when
// Stop() is called, or during test cleanup.
Context() context.Context
Assert() *assert.Assertions
Require() *require.Assertions
// TempDir returns a unique temp directory for this context. Calling this
// function multiple times returns the same path.
TempDir() string
// CACert returns the test environment's root CA certificate and private key.
CACert() *tls.Certificate
// ServerCAs returns a new [*x509.CertPool] containing the root CA certificate
// used to sign the server cert and other test certificates.
ServerCAs() *x509.CertPool
// ServerCert returns the Pomerium server's certificate.
// ServerCert returns the Pomerium server's certificate and private key.
ServerCert() *tls.Certificate
// NewClientCert generates a new client certificate signed by the root CA
// certificate. One or more optional templates can be given, which can be
// used to set or override certain parameters when creating a certificate,
// including subject, SANs, or extensions. If more than one template is
// provided, they will be applied in order from left to right.
//
// By default (unless overridden in a template), the certificate will have
// its Common Name set to the file:line string of the call site. Calls to
// NewClientCert() on different lines will have different subjects. If
// multiple certs with the same subject are needed, wrap the call to this
// function in another helper function, or separate calls with commas on the
// same line.
NewClientCert(templateOverrides ...*x509.Certificate) *Certificate
// Add adds the given [Modifier] to the environment. All modifiers will be
// invoked upon calling Start() to apply individual modifications to the
@ -80,8 +112,22 @@ type Environment interface {
NewLogRecorder(opts ...LogRecorderOption) *LogRecorder
}
type Certificate tls.Certificate
func (c *Certificate) Fingerprint() string {
sum := sha256.Sum256(c.Leaf.Raw)
return hex.EncodeToString(sum[:])
}
func (c *Certificate) SPKIHash() string {
sum := sha256.Sum256(c.Leaf.RawSubjectPublicKeyInfo)
return base64.StdEncoding.EncodeToString(sum[:])
}
type environment struct {
t testing.TB
assert *assert.Assertions
require *require.Assertions
tempDir string
domain string
ports Ports
@ -125,6 +171,8 @@ func New(t testing.TB) Environment {
e := &environment{
t: t,
assert: assert.New(t),
require: require.New(t),
tempDir: t.TempDir(),
ports: Ports{
http: values.Deferred[int](),
@ -141,13 +189,14 @@ func New(t testing.TB) Environment {
copyFile := func(src, dstRel string) {
data, err := os.ReadFile(src)
require.NoError(t, err)
require.NoError(t, os.WriteFile(filepath.Join(e.tempDir, dstRel), data, 0o666))
require.NoError(t, os.WriteFile(filepath.Join(e.tempDir, dstRel), data, 0o600))
}
certsToCopy := []string{
"trusted.pem",
"trusted-key.pem",
"ca.pem",
"ca-key.pem",
}
for _, crt := range certsToCopy {
copyFile(filepath.Join(workspaceFolder, "integration/tpl/files", crt), filepath.Join("certs/", filepath.Base(crt)))
@ -174,12 +223,29 @@ func (e *environment) Context() context.Context {
return ContextWithEnv(e.ctx, e)
}
func (e *environment) Assert() *assert.Assertions {
return e.assert
}
func (e *environment) Require() *require.Assertions {
return e.require
}
func (e *environment) SubdomainURL(subdomain string) values.Value[string] {
return values.Bind(e.ports.http, func(port int) string {
return fmt.Sprintf("https://%s.%s:%d", subdomain, e.domain, port)
})
}
func (e *environment) CACert() *tls.Certificate {
caCert, err := tls.LoadX509KeyPair(
filepath.Join(e.tempDir, "certs", "ca.pem"),
filepath.Join(e.tempDir, "certs", "ca-key.pem"),
)
require.NoError(e.t, err)
return &caCert
}
func (e *environment) ServerCAs() *x509.CertPool {
pool := x509.NewCertPool()
caCert, err := os.ReadFile(filepath.Join(e.tempDir, "certs", "ca.pem"))
@ -213,20 +279,39 @@ func (e *environment) Start() {
require.NoError(e.t, err)
port0, _ := strconv.Atoi(ports[0])
e.ports.http.Resolve(port0)
cfg.Options.AutocertOptions = config.AutocertOptions{Enable: false}
cfg.Options.LogLevel = config.LogLevelInfo
cfg.Options.ProxyLogLevel = config.LogLevelInfo
cfg.Options.Addr = fmt.Sprintf("127.0.0.1:%d", port0)
cfg.Options.CertFile = filepath.Join(e.tempDir, "certs", "trusted.pem")
cfg.Options.KeyFile = filepath.Join(e.tempDir, "certs", "trusted-key.pem")
cfg.Options.AccessLogFields = []log.AccessLogField{
log.AccessLogFieldAuthority,
log.AccessLogFieldDuration,
log.AccessLogFieldForwardedFor,
log.AccessLogFieldIP,
log.AccessLogFieldMethod,
log.AccessLogFieldPath,
log.AccessLogFieldQuery,
log.AccessLogFieldReferer,
log.AccessLogFieldRequestID,
log.AccessLogFieldResponseCode,
log.AccessLogFieldResponseCodeDetails,
log.AccessLogFieldSize,
log.AccessLogFieldUpstreamCluster,
log.AccessLogFieldUserAgent,
log.AccessLogFieldClientCertificate,
}
cfg.AllocatePorts(*(*[6]string)(ports[1:]))
e.AddTask(TaskFunc(func(ctx context.Context) error {
fileMgr := filemgr.NewManager(filemgr.WithCacheDir(filepath.Join(e.TempDir(), "cache")))
src := config.NewStaticSource(cfg)
for _, mod := range e.mods {
mod.Value.Modify(cfg)
require.NoError(e.t, cfg.Options.Validate(), "invoking modifier resulted in an invalid configuration:\nadded by: "+mod.Caller)
}
return pomerium.Run(e.ctx, src)
return pomerium.Run(e.ctx, src, pomerium.WithOverrideFileManager(fileMgr))
}))
for i, task := range e.tasks {
@ -238,6 +323,64 @@ func (e *environment) Start() {
}
}
func (e *environment) NewClientCert(templateOverrides ...*x509.Certificate) *Certificate {
caCert := e.CACert()
priv, err := rsa.GenerateKey(rand.Reader, 2048)
require.NoError(e.t, err)
sn, err := rand.Int(rand.Reader, new(big.Int).Lsh(big.NewInt(1), 128))
require.NoError(e.t, err)
now := time.Now()
tmpl := &x509.Certificate{
SerialNumber: sn,
Subject: pkix.Name{
CommonName: getCaller(),
},
NotBefore: now,
NotAfter: now.Add(12 * time.Hour),
KeyUsage: x509.KeyUsageDigitalSignature,
ExtKeyUsage: []x509.ExtKeyUsage{
x509.ExtKeyUsageClientAuth,
},
BasicConstraintsValid: true,
}
for _, override := range templateOverrides {
tmpl.CRLDistributionPoints = slices.Unique(append(tmpl.CRLDistributionPoints, override.CRLDistributionPoints...))
tmpl.DNSNames = slices.Unique(append(tmpl.DNSNames, override.DNSNames...))
tmpl.EmailAddresses = slices.Unique(append(tmpl.EmailAddresses, override.EmailAddresses...))
tmpl.ExtraExtensions = append(tmpl.ExtraExtensions, override.ExtraExtensions...)
tmpl.IPAddresses = slices.UniqueBy(append(tmpl.IPAddresses, override.IPAddresses...), net.IP.String)
tmpl.URIs = slices.UniqueBy(append(tmpl.URIs, override.URIs...), (*url.URL).String)
tmpl.UnknownExtKeyUsage = slices.UniqueBy(append(tmpl.UnknownExtKeyUsage, override.UnknownExtKeyUsage...), asn1.ObjectIdentifier.String)
seq := override.Subject.ToRDNSequence()
tmpl.Subject.FillFromRDNSequence(&seq)
tmpl.KeyUsage |= override.KeyUsage
tmpl.ExtKeyUsage = slices.Unique(append(tmpl.ExtKeyUsage, override.ExtKeyUsage...))
}
clientCertDER, err := x509.CreateCertificate(rand.Reader, tmpl, caCert.Leaf, priv.Public(), caCert.PrivateKey)
require.NoError(e.t, err)
cert, err := x509.ParseCertificate(clientCertDER)
require.NoError(e.t, err)
clientCert := &tls.Certificate{
Certificate: [][]byte{cert.Raw, caCert.Leaf.Raw},
PrivateKey: priv,
Leaf: cert,
}
_, err = clientCert.Leaf.Verify(x509.VerifyOptions{
KeyUsages: []x509.ExtKeyUsage{
x509.ExtKeyUsageClientAuth,
},
Roots: e.ServerCAs(),
})
require.NoError(e.t, err, "bug: generated client cert is not valid")
return (*Certificate)(clientCert)
}
func (e *environment) Stop() {
e.cleanupOnce.Do(func() {
e.cancel(ErrCauseManualStop)