From 9631d9ff1c6a93f6fdc08dd5b76b5f926dc30a48 Mon Sep 17 00:00:00 2001 From: Caleb Doxsey Date: Fri, 6 Jun 2025 12:37:02 -0600 Subject: [PATCH] cryptutil: add a function to normalize PEM files so that leaf certificates appear first (#5642) ## Summary Go requires that the first certificate in a bundle be the one associated with a private key: > LoadX509KeyPair reads and parses a public/private key pair from a pair of files. The files must contain PEM encoded data. The certificate file may contain intermediate certificates following the leaf certificate to form a certificate chain. On successful return, Certificate.Leaf will be populated. I don't think Go is unusual in this regard, but to make the code more tolerant, add a new `NormalizePEM` function which will take raw PEM data and rewrite it so that leaf certificates appear first. This will be used in zero and the enterprise console. ## Related issues - [ENG-2433](https://linear.app/pomerium/issue/ENG-2423/enterprise-console-updatekeypair-check-is-too-restrictive) ## Checklist - [x] reference any related issues - [x] updated unit tests - [x] add appropriate label (`enhancement`, `bug`, `breaking`, `dependencies`, `ci`) - [x] ready for review --- internal/testutil/tls.go | 101 +++++++++++++++++++++++++++++ pkg/cryptutil/pem.go | 129 ++++++++++++++++++++++++++++++++++++++ pkg/cryptutil/pem_test.go | 55 ++++++++++++++++ 3 files changed, 285 insertions(+) create mode 100644 internal/testutil/tls.go create mode 100644 pkg/cryptutil/pem.go create mode 100644 pkg/cryptutil/pem_test.go diff --git a/internal/testutil/tls.go b/internal/testutil/tls.go new file mode 100644 index 000000000..2e8c7c81b --- /dev/null +++ b/internal/testutil/tls.go @@ -0,0 +1,101 @@ +package testutil + +import ( + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "crypto/x509" + "crypto/x509/pkix" + "encoding/pem" + "math/big" + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +// A Certificate is the public and private certificate details. +type Certificate struct { + X509 *x509.Certificate + PublicDER []byte + PublicPEM []byte + PrivateKey *ecdsa.PrivateKey + PrivateKeyDER []byte + PrivateKeyPEM []byte +} + +// GenerateCertificateChain generates a root certificate authority, an intermediate certificate authority and a certificate. +func GenerateCertificateChain(tb testing.TB) (rootCA, intermediateCA, cert Certificate) { + tb.Helper() + + var err error + + rootCA.PrivateKey, err = ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + require.NoError(tb, err) + intermediateCA.PrivateKey, err = ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + require.NoError(tb, err) + cert.PrivateKey, err = ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + require.NoError(tb, err) + + notAfter := time.Now().Add(3650 * 24 * time.Hour) + rootCATemplate := &x509.Certificate{ + SerialNumber: big.NewInt(0x1000), + Subject: pkix.Name{ + CommonName: "Root CA", + }, + NotAfter: notAfter, + KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign, + BasicConstraintsValid: true, + IsCA: true, + } + intermediateCATemplate := &x509.Certificate{ + SerialNumber: big.NewInt(0x1001), + Subject: pkix.Name{ + CommonName: "Intermediate CA", + }, + NotAfter: notAfter, + KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, + BasicConstraintsValid: true, + IsCA: true, + } + certTemplate := &x509.Certificate{ + SerialNumber: big.NewInt(0x1002), + Subject: pkix.Name{ + CommonName: "Certificate", + }, + NotAfter: notAfter, + KeyUsage: x509.KeyUsageDigitalSignature, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, + BasicConstraintsValid: true, + } + + rootCA.PrivateKeyDER, err = x509.MarshalPKCS8PrivateKey(rootCA.PrivateKey) + require.NoError(tb, err) + rootCA.PublicDER, err = x509.CreateCertificate(rand.Reader, rootCATemplate, rootCATemplate, rootCA.PrivateKey.Public(), rootCA.PrivateKey) + require.NoError(tb, err) + rootCA.X509, err = x509.ParseCertificate(rootCA.PublicDER) + require.NoError(tb, err) + rootCA.PublicPEM = pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: rootCA.PublicDER, Headers: map[string]string{"name": "root certificate"}}) + + rootCA.PrivateKeyPEM = pem.EncodeToMemory(&pem.Block{Type: "PRIVATE KEY", Bytes: rootCA.PrivateKeyDER}) + intermediateCA.PrivateKeyDER, err = x509.MarshalPKCS8PrivateKey(intermediateCA.PrivateKey) + require.NoError(tb, err) + intermediateCA.PrivateKeyPEM = pem.EncodeToMemory(&pem.Block{Type: "PRIVATE KEY", Bytes: intermediateCA.PrivateKeyDER, Headers: map[string]string{"name": "intermediate key"}}) + intermediateCA.PublicDER, err = x509.CreateCertificate(rand.Reader, intermediateCATemplate, rootCA.X509, intermediateCA.PrivateKey.Public(), rootCA.PrivateKey) + require.NoError(tb, err) + intermediateCA.X509, err = x509.ParseCertificate(intermediateCA.PublicDER) + require.NoError(tb, err) + intermediateCA.PublicPEM = pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: intermediateCA.PublicDER, Headers: map[string]string{"name": "intermediate certificate"}}) + + cert.PrivateKeyDER, err = x509.MarshalPKCS8PrivateKey(cert.PrivateKey) + require.NoError(tb, err) + cert.PrivateKeyPEM = pem.EncodeToMemory(&pem.Block{Type: "PRIVATE KEY", Bytes: cert.PrivateKeyDER, Headers: map[string]string{"name": "key"}}) + cert.PublicDER, err = x509.CreateCertificate(rand.Reader, certTemplate, intermediateCA.X509, cert.PrivateKey.Public(), intermediateCA.PrivateKey) + require.NoError(tb, err) + cert.X509, err = x509.ParseCertificate(cert.PublicDER) + require.NoError(tb, err) + cert.PublicPEM = pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: cert.PublicDER, Headers: map[string]string{"name": "certificate"}}) + + return rootCA, intermediateCA, cert +} diff --git a/pkg/cryptutil/pem.go b/pkg/cryptutil/pem.go new file mode 100644 index 000000000..00d94b7b6 --- /dev/null +++ b/pkg/cryptutil/pem.go @@ -0,0 +1,129 @@ +package cryptutil + +import ( + "bytes" + "cmp" + "crypto/x509" + "encoding/pem" + "iter" + "slices" +) + +// NormalizePEM takes PEM-encoded data and normalizes it. +// +// If the PEM data contains multiple certificates, signing certificates +// will be moved after the things they sign. +func NormalizePEM(data []byte) []byte { + type Segment struct { + ID int + Data []byte + } + var segments []Segment + for block := range iteratePEM(data) { + segments = append(segments, Segment{ID: len(segments), Data: block}) + } + + // build a lookup table for subject keys and authority keys + // a certificate with an authority key set to the subject key + // of another certificate should appear before that certificate + idToAuthorityKey := map[int]string{} + subjectKeyToID := map[string]int{} + for _, segment := range segments { + block, _ := pem.Decode(segment.Data) + if block == nil { + continue + } + if block.Type != "CERTIFICATE" { + continue + } + cert, err := x509.ParseCertificate(block.Bytes) + if err != nil { + continue + } + if len(cert.AuthorityKeyId) > 0 { + idToAuthorityKey[segment.ID] = string(cert.AuthorityKeyId) + } + if len(cert.SubjectKeyId) > 0 { + subjectKeyToID[string(cert.SubjectKeyId)] = segment.ID + } + } + + // calculate the depth of each certificate, deeper certificates will appear last + depth := make([]int, len(segments)) + for i := range segments { + id := segments[i].ID + for { + authorityKey, ok := idToAuthorityKey[id] + if !ok { + break + } + + id, ok = subjectKeyToID[authorityKey] + if !ok { + break + } + depth[id]++ + } + } + + // sort the segments + slices.SortStableFunc(segments, func(x, y Segment) int { + return cmp.Compare(depth[x.ID], depth[y.ID]) + }) + + // join the segments back together + var buf bytes.Buffer + for _, segment := range segments { + buf.Write(segment.Data) + } + return buf.Bytes() +} + +var ( + pemBegin = []byte("-----BEGIN ") + pemEnd = []byte("-----END ") +) + +// splitPEM attempts to split a slice of bytes into a single pem block +// followed by the rest of the data. The pem block may contain extra +// text before the BEGIN but won't contain more than one pem block. +func splitPEM(data []byte) (before, after []byte) { + idx1 := bytes.Index(data, pemBegin) + if idx1 < 0 { + return data, nil + } + + idx2 := bytes.IndexByte(data[idx1+len(pemBegin):], '\n') + if idx2 < 0 { + return data, nil + } + idx2 += idx1 + len(pemBegin) + + idx3 := bytes.Index(data[idx2+1:], pemEnd) + if idx3 < 0 { + return data, nil + } + idx3 += idx2 + 1 + + idx4 := bytes.IndexByte(data[idx3+len(pemEnd):], '\n') + if idx4 < 0 { + return data, nil + } + idx4 += idx3 + len(pemEnd) + + return data[:idx4+1], data[idx4+1:] +} + +// iteratePEM iterates over all the raw PEM blocks +func iteratePEM(data []byte) iter.Seq[[]byte] { + return func(yield func([]byte) bool) { + rest := data + for len(rest) > 0 { + before, after := splitPEM(rest) + if !yield(before) { + return + } + rest = after + } + } +} diff --git a/pkg/cryptutil/pem_test.go b/pkg/cryptutil/pem_test.go new file mode 100644 index 000000000..30288503a --- /dev/null +++ b/pkg/cryptutil/pem_test.go @@ -0,0 +1,55 @@ +package cryptutil_test + +import ( + "slices" + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/pomerium/pomerium/internal/testutil" + "github.com/pomerium/pomerium/pkg/cryptutil" +) + +func TestNormalizePEM(t *testing.T) { + t.Parallel() + + rootCA, intermediateCA, cert := testutil.GenerateCertificateChain(t) + + for _, tc := range []struct { + input []byte + expect []byte + }{ + { + input: slices.Concat(rootCA.PublicPEM, intermediateCA.PublicPEM, cert.PublicPEM, cert.PrivateKeyPEM), + expect: slices.Concat(cert.PublicPEM, cert.PrivateKeyPEM, intermediateCA.PublicPEM, rootCA.PublicPEM), + }, + { + input: slices.Concat(cert.PublicPEM, cert.PrivateKeyPEM, intermediateCA.PublicPEM, rootCA.PublicPEM), + expect: slices.Concat(cert.PublicPEM, cert.PrivateKeyPEM, intermediateCA.PublicPEM, rootCA.PublicPEM), + }, + { + input: nil, + expect: nil, + }, + { + input: []byte("\n\n\nNON PEM DATA\n\n\n"), + expect: []byte("\n\n\nNON PEM DATA\n\n\n"), + }, + { + input: rootCA.PublicPEM, + expect: rootCA.PublicPEM, + }, + { + input: slices.Concat(rootCA.PublicPEM, intermediateCA.PublicPEM, cert.PublicPEM, cert.PrivateKeyPEM), + expect: slices.Concat(cert.PublicPEM, cert.PrivateKeyPEM, intermediateCA.PublicPEM, rootCA.PublicPEM), + }, + { + // looks a bit weird, but the text before a block gets moved with it + input: slices.Concat([]byte("BEFORE\n"), intermediateCA.PublicPEM, []byte("BETWEEN\n"), cert.PublicPEM, []byte("AFTER\n")), + expect: slices.Concat([]byte("BETWEEN\n"), cert.PublicPEM, []byte("AFTER\n"), []byte("BEFORE\n"), intermediateCA.PublicPEM), + }, + } { + actual := cryptutil.NormalizePEM(tc.input) + assert.Equal(t, string(tc.expect), string(actual)) + } +}