mirror of
https://github.com/pomerium/pomerium.git
synced 2025-06-27 06:58:13 +02:00
Add client certificate utilities
This commit is contained in:
parent
6ed5752fa5
commit
490a301aa4
10 changed files with 355 additions and 20 deletions
|
@ -2,10 +2,20 @@ package testenv
|
|||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"crypto/sha256"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"crypto/x509/pkix"
|
||||
"encoding/asn1"
|
||||
"encoding/base64"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
"math/big"
|
||||
"net"
|
||||
"net/url"
|
||||
"os"
|
||||
"path"
|
||||
"path/filepath"
|
||||
|
@ -13,13 +23,16 @@ import (
|
|||
"strconv"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/pomerium/pomerium/config"
|
||||
"github.com/pomerium/pomerium/config/envoyconfig/filemgr"
|
||||
"github.com/pomerium/pomerium/internal/log"
|
||||
"github.com/pomerium/pomerium/internal/testenv/values"
|
||||
"github.com/pomerium/pomerium/pkg/cmd/pomerium"
|
||||
"github.com/pomerium/pomerium/pkg/health"
|
||||
"github.com/pomerium/pomerium/pkg/netutil"
|
||||
"github.com/pomerium/pomerium/pkg/slices"
|
||||
"github.com/rs/zerolog"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
@ -33,14 +46,33 @@ type Environment interface {
|
|||
// top-level logger scoped to this environment. It will be canceled when
|
||||
// Stop() is called, or during test cleanup.
|
||||
Context() context.Context
|
||||
|
||||
Assert() *assert.Assertions
|
||||
Require() *require.Assertions
|
||||
|
||||
// TempDir returns a unique temp directory for this context. Calling this
|
||||
// function multiple times returns the same path.
|
||||
TempDir() string
|
||||
// CACert returns the test environment's root CA certificate and private key.
|
||||
CACert() *tls.Certificate
|
||||
// ServerCAs returns a new [*x509.CertPool] containing the root CA certificate
|
||||
// used to sign the server cert and other test certificates.
|
||||
ServerCAs() *x509.CertPool
|
||||
// ServerCert returns the Pomerium server's certificate.
|
||||
// ServerCert returns the Pomerium server's certificate and private key.
|
||||
ServerCert() *tls.Certificate
|
||||
// NewClientCert generates a new client certificate signed by the root CA
|
||||
// certificate. One or more optional templates can be given, which can be
|
||||
// used to set or override certain parameters when creating a certificate,
|
||||
// including subject, SANs, or extensions. If more than one template is
|
||||
// provided, they will be applied in order from left to right.
|
||||
//
|
||||
// By default (unless overridden in a template), the certificate will have
|
||||
// its Common Name set to the file:line string of the call site. Calls to
|
||||
// NewClientCert() on different lines will have different subjects. If
|
||||
// multiple certs with the same subject are needed, wrap the call to this
|
||||
// function in another helper function, or separate calls with commas on the
|
||||
// same line.
|
||||
NewClientCert(templateOverrides ...*x509.Certificate) *Certificate
|
||||
|
||||
// Add adds the given [Modifier] to the environment. All modifiers will be
|
||||
// invoked upon calling Start() to apply individual modifications to the
|
||||
|
@ -80,8 +112,22 @@ type Environment interface {
|
|||
NewLogRecorder(opts ...LogRecorderOption) *LogRecorder
|
||||
}
|
||||
|
||||
type Certificate tls.Certificate
|
||||
|
||||
func (c *Certificate) Fingerprint() string {
|
||||
sum := sha256.Sum256(c.Leaf.Raw)
|
||||
return hex.EncodeToString(sum[:])
|
||||
}
|
||||
|
||||
func (c *Certificate) SPKIHash() string {
|
||||
sum := sha256.Sum256(c.Leaf.RawSubjectPublicKeyInfo)
|
||||
return base64.StdEncoding.EncodeToString(sum[:])
|
||||
}
|
||||
|
||||
type environment struct {
|
||||
t testing.TB
|
||||
assert *assert.Assertions
|
||||
require *require.Assertions
|
||||
tempDir string
|
||||
domain string
|
||||
ports Ports
|
||||
|
@ -125,6 +171,8 @@ func New(t testing.TB) Environment {
|
|||
|
||||
e := &environment{
|
||||
t: t,
|
||||
assert: assert.New(t),
|
||||
require: require.New(t),
|
||||
tempDir: t.TempDir(),
|
||||
ports: Ports{
|
||||
http: values.Deferred[int](),
|
||||
|
@ -141,13 +189,14 @@ func New(t testing.TB) Environment {
|
|||
copyFile := func(src, dstRel string) {
|
||||
data, err := os.ReadFile(src)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, os.WriteFile(filepath.Join(e.tempDir, dstRel), data, 0o666))
|
||||
require.NoError(t, os.WriteFile(filepath.Join(e.tempDir, dstRel), data, 0o600))
|
||||
}
|
||||
|
||||
certsToCopy := []string{
|
||||
"trusted.pem",
|
||||
"trusted-key.pem",
|
||||
"ca.pem",
|
||||
"ca-key.pem",
|
||||
}
|
||||
for _, crt := range certsToCopy {
|
||||
copyFile(filepath.Join(workspaceFolder, "integration/tpl/files", crt), filepath.Join("certs/", filepath.Base(crt)))
|
||||
|
@ -174,12 +223,29 @@ func (e *environment) Context() context.Context {
|
|||
return ContextWithEnv(e.ctx, e)
|
||||
}
|
||||
|
||||
func (e *environment) Assert() *assert.Assertions {
|
||||
return e.assert
|
||||
}
|
||||
|
||||
func (e *environment) Require() *require.Assertions {
|
||||
return e.require
|
||||
}
|
||||
|
||||
func (e *environment) SubdomainURL(subdomain string) values.Value[string] {
|
||||
return values.Bind(e.ports.http, func(port int) string {
|
||||
return fmt.Sprintf("https://%s.%s:%d", subdomain, e.domain, port)
|
||||
})
|
||||
}
|
||||
|
||||
func (e *environment) CACert() *tls.Certificate {
|
||||
caCert, err := tls.LoadX509KeyPair(
|
||||
filepath.Join(e.tempDir, "certs", "ca.pem"),
|
||||
filepath.Join(e.tempDir, "certs", "ca-key.pem"),
|
||||
)
|
||||
require.NoError(e.t, err)
|
||||
return &caCert
|
||||
}
|
||||
|
||||
func (e *environment) ServerCAs() *x509.CertPool {
|
||||
pool := x509.NewCertPool()
|
||||
caCert, err := os.ReadFile(filepath.Join(e.tempDir, "certs", "ca.pem"))
|
||||
|
@ -213,20 +279,39 @@ func (e *environment) Start() {
|
|||
require.NoError(e.t, err)
|
||||
port0, _ := strconv.Atoi(ports[0])
|
||||
e.ports.http.Resolve(port0)
|
||||
cfg.Options.AutocertOptions = config.AutocertOptions{Enable: false}
|
||||
cfg.Options.LogLevel = config.LogLevelInfo
|
||||
cfg.Options.ProxyLogLevel = config.LogLevelInfo
|
||||
cfg.Options.Addr = fmt.Sprintf("127.0.0.1:%d", port0)
|
||||
cfg.Options.CertFile = filepath.Join(e.tempDir, "certs", "trusted.pem")
|
||||
cfg.Options.KeyFile = filepath.Join(e.tempDir, "certs", "trusted-key.pem")
|
||||
cfg.Options.AccessLogFields = []log.AccessLogField{
|
||||
log.AccessLogFieldAuthority,
|
||||
log.AccessLogFieldDuration,
|
||||
log.AccessLogFieldForwardedFor,
|
||||
log.AccessLogFieldIP,
|
||||
log.AccessLogFieldMethod,
|
||||
log.AccessLogFieldPath,
|
||||
log.AccessLogFieldQuery,
|
||||
log.AccessLogFieldReferer,
|
||||
log.AccessLogFieldRequestID,
|
||||
log.AccessLogFieldResponseCode,
|
||||
log.AccessLogFieldResponseCodeDetails,
|
||||
log.AccessLogFieldSize,
|
||||
log.AccessLogFieldUpstreamCluster,
|
||||
log.AccessLogFieldUserAgent,
|
||||
log.AccessLogFieldClientCertificate,
|
||||
}
|
||||
cfg.AllocatePorts(*(*[6]string)(ports[1:]))
|
||||
|
||||
e.AddTask(TaskFunc(func(ctx context.Context) error {
|
||||
fileMgr := filemgr.NewManager(filemgr.WithCacheDir(filepath.Join(e.TempDir(), "cache")))
|
||||
src := config.NewStaticSource(cfg)
|
||||
for _, mod := range e.mods {
|
||||
mod.Value.Modify(cfg)
|
||||
require.NoError(e.t, cfg.Options.Validate(), "invoking modifier resulted in an invalid configuration:\nadded by: "+mod.Caller)
|
||||
}
|
||||
return pomerium.Run(e.ctx, src)
|
||||
return pomerium.Run(e.ctx, src, pomerium.WithOverrideFileManager(fileMgr))
|
||||
}))
|
||||
|
||||
for i, task := range e.tasks {
|
||||
|
@ -238,6 +323,64 @@ func (e *environment) Start() {
|
|||
}
|
||||
}
|
||||
|
||||
func (e *environment) NewClientCert(templateOverrides ...*x509.Certificate) *Certificate {
|
||||
caCert := e.CACert()
|
||||
|
||||
priv, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
require.NoError(e.t, err)
|
||||
|
||||
sn, err := rand.Int(rand.Reader, new(big.Int).Lsh(big.NewInt(1), 128))
|
||||
require.NoError(e.t, err)
|
||||
now := time.Now()
|
||||
tmpl := &x509.Certificate{
|
||||
SerialNumber: sn,
|
||||
Subject: pkix.Name{
|
||||
CommonName: getCaller(),
|
||||
},
|
||||
NotBefore: now,
|
||||
NotAfter: now.Add(12 * time.Hour),
|
||||
KeyUsage: x509.KeyUsageDigitalSignature,
|
||||
ExtKeyUsage: []x509.ExtKeyUsage{
|
||||
x509.ExtKeyUsageClientAuth,
|
||||
},
|
||||
BasicConstraintsValid: true,
|
||||
}
|
||||
for _, override := range templateOverrides {
|
||||
tmpl.CRLDistributionPoints = slices.Unique(append(tmpl.CRLDistributionPoints, override.CRLDistributionPoints...))
|
||||
tmpl.DNSNames = slices.Unique(append(tmpl.DNSNames, override.DNSNames...))
|
||||
tmpl.EmailAddresses = slices.Unique(append(tmpl.EmailAddresses, override.EmailAddresses...))
|
||||
tmpl.ExtraExtensions = append(tmpl.ExtraExtensions, override.ExtraExtensions...)
|
||||
tmpl.IPAddresses = slices.UniqueBy(append(tmpl.IPAddresses, override.IPAddresses...), net.IP.String)
|
||||
tmpl.URIs = slices.UniqueBy(append(tmpl.URIs, override.URIs...), (*url.URL).String)
|
||||
tmpl.UnknownExtKeyUsage = slices.UniqueBy(append(tmpl.UnknownExtKeyUsage, override.UnknownExtKeyUsage...), asn1.ObjectIdentifier.String)
|
||||
seq := override.Subject.ToRDNSequence()
|
||||
tmpl.Subject.FillFromRDNSequence(&seq)
|
||||
tmpl.KeyUsage |= override.KeyUsage
|
||||
tmpl.ExtKeyUsage = slices.Unique(append(tmpl.ExtKeyUsage, override.ExtKeyUsage...))
|
||||
}
|
||||
|
||||
clientCertDER, err := x509.CreateCertificate(rand.Reader, tmpl, caCert.Leaf, priv.Public(), caCert.PrivateKey)
|
||||
require.NoError(e.t, err)
|
||||
|
||||
cert, err := x509.ParseCertificate(clientCertDER)
|
||||
require.NoError(e.t, err)
|
||||
|
||||
clientCert := &tls.Certificate{
|
||||
Certificate: [][]byte{cert.Raw, caCert.Leaf.Raw},
|
||||
PrivateKey: priv,
|
||||
Leaf: cert,
|
||||
}
|
||||
|
||||
_, err = clientCert.Leaf.Verify(x509.VerifyOptions{
|
||||
KeyUsages: []x509.ExtKeyUsage{
|
||||
x509.ExtKeyUsageClientAuth,
|
||||
},
|
||||
Roots: e.ServerCAs(),
|
||||
})
|
||||
require.NoError(e.t, err, "bug: generated client cert is not valid")
|
||||
return (*Certificate)(clientCert)
|
||||
}
|
||||
|
||||
func (e *environment) Stop() {
|
||||
e.cleanupOnce.Do(func() {
|
||||
e.cancel(ErrCauseManualStop)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue