mirror of
https://github.com/pomerium/pomerium.git
synced 2025-07-20 18:18:08 +02:00
## Summary If a certificate was its own authority it would result in `NormalizePEM` going into an infinite loop. This PR updates the code to avoid cycles using a set. ## Related issues - [ENG-2423](https://linear.app/pomerium/issue/ENG-2423/enterprise-console-updatekeypair-check-is-too-restrictive) ## Checklist - [x] reference any related issues - [x] updated unit tests - [ ] add appropriate label (`enhancement`, `bug`, `breaking`, `dependencies`, `ci`) - [x] ready for review
166 lines
3.5 KiB
Go
166 lines
3.5 KiB
Go
package cryptutil
|
|
|
|
import (
|
|
"bytes"
|
|
"cmp"
|
|
"crypto/x509"
|
|
"encoding/pem"
|
|
"iter"
|
|
"slices"
|
|
|
|
"github.com/hashicorp/go-set/v3"
|
|
)
|
|
|
|
type signedCertificateIndex struct {
|
|
idToAuthorityKey map[int]string
|
|
subjectKeyToID map[string]int
|
|
}
|
|
|
|
func newSignedCertificateIndex() *signedCertificateIndex {
|
|
return &signedCertificateIndex{
|
|
idToAuthorityKey: make(map[int]string),
|
|
subjectKeyToID: make(map[string]int),
|
|
}
|
|
}
|
|
|
|
func (idx *signedCertificateIndex) addPEM(id int, data []byte) {
|
|
block, _ := pem.Decode(data)
|
|
if block == nil {
|
|
return
|
|
}
|
|
|
|
if block.Type != "CERTIFICATE" {
|
|
return
|
|
}
|
|
|
|
cert, err := x509.ParseCertificate(block.Bytes)
|
|
if err != nil {
|
|
return
|
|
}
|
|
|
|
if len(cert.AuthorityKeyId) > 0 {
|
|
idx.idToAuthorityKey[id] = string(cert.AuthorityKeyId)
|
|
}
|
|
if len(cert.SubjectKeyId) > 0 {
|
|
idx.subjectKeyToID[string(cert.SubjectKeyId)] = id
|
|
}
|
|
}
|
|
|
|
func (idx *signedCertificateIndex) depthMap() map[int]int {
|
|
depth := make(map[int]int)
|
|
for _, id := range idx.subjectKeyToID {
|
|
// use a set to avoid cycles
|
|
seen := set.From([]int{id})
|
|
for {
|
|
depth[id]++
|
|
|
|
authorityKey, ok := idx.idToAuthorityKey[id]
|
|
if !ok {
|
|
break
|
|
}
|
|
|
|
id, ok = idx.subjectKeyToID[authorityKey]
|
|
if !ok {
|
|
break
|
|
}
|
|
|
|
if seen.Contains(id) {
|
|
break
|
|
}
|
|
seen.Insert(id)
|
|
}
|
|
}
|
|
return depth
|
|
}
|
|
|
|
// NormalizePEM takes PEM-encoded data and normalizes it.
|
|
//
|
|
// If the PEM data contains multiple certificates, signing certificates
|
|
// will be moved after the things they sign.
|
|
func NormalizePEM(data []byte) []byte {
|
|
// make sure the file has a trailing newline
|
|
if len(data) > 0 && !bytes.HasSuffix(data, []byte{'\n'}) {
|
|
data = append(data, '\n')
|
|
}
|
|
|
|
type Segment struct {
|
|
ID int
|
|
Data []byte
|
|
}
|
|
var segments []Segment
|
|
for block := range iteratePEM(data) {
|
|
segments = append(segments, Segment{ID: len(segments), Data: block})
|
|
}
|
|
|
|
// build a lookup table for subject keys and authority keys
|
|
// a certificate with an authority key set to the subject key
|
|
// of another certificate should appear before that certificate
|
|
idx := newSignedCertificateIndex()
|
|
for _, segment := range segments {
|
|
idx.addPEM(segment.ID, segment.Data)
|
|
}
|
|
|
|
// calculate the depth of each certificate, deeper certificates will appear last
|
|
depth := idx.depthMap()
|
|
|
|
// sort the segments
|
|
slices.SortStableFunc(segments, func(x, y Segment) int {
|
|
return cmp.Compare(depth[x.ID], depth[y.ID])
|
|
})
|
|
|
|
// join the segments back together
|
|
var buf bytes.Buffer
|
|
for _, segment := range segments {
|
|
buf.Write(segment.Data)
|
|
}
|
|
return buf.Bytes()
|
|
}
|
|
|
|
var (
|
|
pemBegin = []byte("-----BEGIN ")
|
|
pemEnd = []byte("-----END ")
|
|
)
|
|
|
|
// splitPEM attempts to split a slice of bytes into a single pem block
|
|
// followed by the rest of the data. The pem block may contain extra
|
|
// text before the BEGIN but won't contain more than one pem block.
|
|
func splitPEM(data []byte) (before, after []byte) {
|
|
idx1 := bytes.Index(data, pemBegin)
|
|
if idx1 < 0 {
|
|
return data, nil
|
|
}
|
|
|
|
idx2 := bytes.IndexByte(data[idx1+len(pemBegin):], '\n')
|
|
if idx2 < 0 {
|
|
return data, nil
|
|
}
|
|
idx2 += idx1 + len(pemBegin)
|
|
|
|
idx3 := bytes.Index(data[idx2+1:], pemEnd)
|
|
if idx3 < 0 {
|
|
return data, nil
|
|
}
|
|
idx3 += idx2 + 1
|
|
|
|
idx4 := bytes.IndexByte(data[idx3+len(pemEnd):], '\n')
|
|
if idx4 < 0 {
|
|
return data, nil
|
|
}
|
|
idx4 += idx3 + len(pemEnd)
|
|
|
|
return data[:idx4+1], data[idx4+1:]
|
|
}
|
|
|
|
// iteratePEM iterates over all the raw PEM blocks
|
|
func iteratePEM(data []byte) iter.Seq[[]byte] {
|
|
return func(yield func([]byte) bool) {
|
|
rest := data
|
|
for len(rest) > 0 {
|
|
before, after := splitPEM(rest)
|
|
if !yield(before) {
|
|
return
|
|
}
|
|
rest = after
|
|
}
|
|
}
|
|
}
|