cryptutil: fix normalize pem with certificate cycles (#5646)

## Summary
If a certificate was its own authority it would result in `NormalizePEM`
going into an infinite loop. This PR updates the code to avoid cycles
using a set.

## Related issues
-
[ENG-2423](https://linear.app/pomerium/issue/ENG-2423/enterprise-console-updatekeypair-check-is-too-restrictive)


## Checklist

- [x] reference any related issues
- [x] updated unit tests
- [ ] add appropriate label (`enhancement`, `bug`, `breaking`,
`dependencies`, `ci`)
- [x] ready for review
This commit is contained in:
Caleb Doxsey 2025-06-09 11:30:05 -06:00 committed by GitHub
parent 4988aea751
commit 5a8597b57b
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 82 additions and 35 deletions

View file

@ -7,8 +7,72 @@ import (
"encoding/pem"
"iter"
"slices"
"github.com/hashicorp/go-set/v3"
)
type signedCertificateIndex struct {
idToAuthorityKey map[int]string
subjectKeyToID map[string]int
}
func newSignedCertificateIndex() *signedCertificateIndex {
return &signedCertificateIndex{
idToAuthorityKey: make(map[int]string),
subjectKeyToID: make(map[string]int),
}
}
func (idx *signedCertificateIndex) addPEM(id int, data []byte) {
block, _ := pem.Decode(data)
if block == nil {
return
}
if block.Type != "CERTIFICATE" {
return
}
cert, err := x509.ParseCertificate(block.Bytes)
if err != nil {
return
}
if len(cert.AuthorityKeyId) > 0 {
idx.idToAuthorityKey[id] = string(cert.AuthorityKeyId)
}
if len(cert.SubjectKeyId) > 0 {
idx.subjectKeyToID[string(cert.SubjectKeyId)] = id
}
}
func (idx *signedCertificateIndex) depthMap() map[int]int {
depth := make(map[int]int)
for _, id := range idx.subjectKeyToID {
// use a set to avoid cycles
seen := set.From([]int{id})
for {
depth[id]++
authorityKey, ok := idx.idToAuthorityKey[id]
if !ok {
break
}
id, ok = idx.subjectKeyToID[authorityKey]
if !ok {
break
}
if seen.Contains(id) {
break
}
seen.Insert(id)
}
}
return depth
}
// NormalizePEM takes PEM-encoded data and normalizes it.
//
// If the PEM data contains multiple certificates, signing certificates
@ -31,45 +95,13 @@ func NormalizePEM(data []byte) []byte {
// build a lookup table for subject keys and authority keys
// a certificate with an authority key set to the subject key
// of another certificate should appear before that certificate
idToAuthorityKey := map[int]string{}
subjectKeyToID := map[string]int{}
idx := newSignedCertificateIndex()
for _, segment := range segments {
block, _ := pem.Decode(segment.Data)
if block == nil {
continue
}
if block.Type != "CERTIFICATE" {
continue
}
cert, err := x509.ParseCertificate(block.Bytes)
if err != nil {
continue
}
if len(cert.AuthorityKeyId) > 0 {
idToAuthorityKey[segment.ID] = string(cert.AuthorityKeyId)
}
if len(cert.SubjectKeyId) > 0 {
subjectKeyToID[string(cert.SubjectKeyId)] = segment.ID
}
idx.addPEM(segment.ID, segment.Data)
}
// calculate the depth of each certificate, deeper certificates will appear last
depth := make([]int, len(segments))
for i := range segments {
id := segments[i].ID
for {
authorityKey, ok := idToAuthorityKey[id]
if !ok {
break
}
id, ok = subjectKeyToID[authorityKey]
if !ok {
break
}
depth[id]++
}
}
depth := idx.depthMap()
// sort the segments
slices.SortStableFunc(segments, func(x, y Segment) int {

View file

@ -14,6 +14,17 @@ import (
func TestNormalizePEM(t *testing.T) {
t.Parallel()
cycleCert := []byte(`-----BEGIN CERTIFICATE-----
MIIBqTCCAU6gAwIBAgIUX5ybxP/LMyet/jBir4cx1ZkhGV0wCgYIKoZIzj0EAwIw
GTEXMBUGA1UEAwwOZXhhbXBsZS1jZXJ0LTIwHhcNMjQwNTE2MjEzMjI5WhcNMjUw
NTE2MjEzMjI5WjAZMRcwFQYDVQQDDA5leGFtcGxlLWNlcnQtMjBZMBMGByqGSM49
AgEGCCqGSM49AwEHA0IABLSs3wwhUyip81aiP6aEW0JY44tZqYDqYpJxxIPjC0ce
2QOYaXEMw6YlgJR3jt/oP+bFP9cCGojcD+p0hJW2DzOjdDByMB0GA1UdDgQWBBRE
31UkR4OdgMmxoj1V1D5+MjbeRTAfBgNVHSMEGDAWgBRE31UkR4OdgMmxoj1V1D5+
MjbeRTAPBgNVHRMBAf8EBTADAQH/MB8GA1UdEQQYMBaCDmxvY2FsaG9zdDo1NDMy
hwR/AAABMAoGCCqGSM49BAMCA0kAMEYCIQDHwY1oj3TBZdDtTk+E7RqczOkv3SoO
XKxuqSKG0OIoNAIhANRdc+x57QSUmul0S+MxNh6g17qw1ncfnp/62pA4nRWC
-----END CERTIFICATE-----`)
rootCA, intermediateCA, cert := testutil.GenerateCertificateChain(t)
for _, tc := range []struct {
@ -54,6 +65,10 @@ func TestNormalizePEM(t *testing.T) {
input: slices.Concat([]byte("BEFORE\n"), intermediateCA.PublicPEM, []byte("BETWEEN\n"), cert.PublicPEM, []byte("AFTER\n")),
expect: slices.Concat([]byte("BETWEEN\n"), cert.PublicPEM, []byte("AFTER\n"), []byte("BEFORE\n"), intermediateCA.PublicPEM),
},
{
input: cycleCert,
expect: append(cycleCert, '\n'),
},
} {
actual := cryptutil.NormalizePEM(tc.input)
assert.Equal(t, string(tc.expect), string(actual))