mirror of
https://github.com/pomerium/pomerium.git
synced 2025-05-10 15:47:36 +02:00
certs: reject certs from databroker if they conflict with local (#2309)
This commit is contained in:
parent
b162307a96
commit
41a2622736
5 changed files with 225 additions and 16 deletions
79
config/certs.go
Normal file
79
config/certs.go
Normal file
|
@ -0,0 +1,79 @@
|
||||||
|
package config
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/x509"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
type certUsage byte
|
||||||
|
type certsIndex map[string]map[string]certUsage
|
||||||
|
|
||||||
|
const (
|
||||||
|
certUsageServerAuth = certUsage(1 << iota)
|
||||||
|
certUsageClientAuth
|
||||||
|
)
|
||||||
|
|
||||||
|
func splitDomainName(name string) (prefix, suffix string) {
|
||||||
|
dot := strings.IndexRune(name, '.')
|
||||||
|
if dot < 0 {
|
||||||
|
dot = 0 // i.e. `localhost`
|
||||||
|
}
|
||||||
|
return name[0:dot], name[dot:]
|
||||||
|
}
|
||||||
|
|
||||||
|
func getCertUsage(cert *x509.Certificate) certUsage {
|
||||||
|
var usage certUsage
|
||||||
|
for _, ex := range cert.ExtKeyUsage {
|
||||||
|
switch ex {
|
||||||
|
case x509.ExtKeyUsageClientAuth:
|
||||||
|
usage |= certUsageClientAuth
|
||||||
|
case x509.ExtKeyUsageServerAuth:
|
||||||
|
usage |= certUsageServerAuth
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return usage
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c certsIndex) addCert(cert *x509.Certificate) {
|
||||||
|
usage := getCertUsage(cert)
|
||||||
|
for _, name := range cert.DNSNames {
|
||||||
|
c.add(name, usage)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c certsIndex) matchCert(cert *x509.Certificate) (bool, string) {
|
||||||
|
usage := getCertUsage(cert)
|
||||||
|
for _, name := range cert.DNSNames {
|
||||||
|
if c.match(name, usage) {
|
||||||
|
return true, name
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false, ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c certsIndex) add(name string, usage certUsage) {
|
||||||
|
prefix, suffix := splitDomainName(name)
|
||||||
|
names := c[suffix]
|
||||||
|
if names == nil {
|
||||||
|
names = make(map[string]certUsage)
|
||||||
|
c[suffix] = names
|
||||||
|
}
|
||||||
|
names[prefix] = usage
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c certsIndex) match(name string, usage certUsage) bool {
|
||||||
|
prefix, suffix := splitDomainName(name)
|
||||||
|
names := c[suffix]
|
||||||
|
if names == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if prefix != "*" {
|
||||||
|
return names["*"]&usage != 0 || names[prefix]&usage != 0
|
||||||
|
}
|
||||||
|
for _, u := range names {
|
||||||
|
if u&usage != 0 {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
43
config/certs_test.go
Normal file
43
config/certs_test.go
Normal file
|
@ -0,0 +1,43 @@
|
||||||
|
package config
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestCertOverlap(t *testing.T) {
|
||||||
|
testCases := []struct {
|
||||||
|
names []string
|
||||||
|
test string
|
||||||
|
match bool
|
||||||
|
}{
|
||||||
|
{[]string{"aa.bb.cc", "cc.bb.aa"}, "aa.bb.c", false},
|
||||||
|
{[]string{"aa.bb.cc"}, "aa.bb.cc", true},
|
||||||
|
{[]string{"*.bb.cc"}, "aa.bb.cc", true},
|
||||||
|
{[]string{"a1.bb.cc", "a2.bb.cc"}, "*.bb.cc", true},
|
||||||
|
{[]string{"*.bb.cc", "a2.bb.cc"}, "*.bb.cc", true},
|
||||||
|
{[]string{"*.aa.bb.cc"}, "*.bb.cc", false},
|
||||||
|
{[]string{"*.aa.bb.cc"}, "aa.bb.cc", false},
|
||||||
|
{[]string{"bb.cc"}, "*.bb.cc", false},
|
||||||
|
}
|
||||||
|
t.Run("match mix mode", func(t *testing.T) {
|
||||||
|
for _, tc := range testCases {
|
||||||
|
idx := make(certsIndex)
|
||||||
|
for _, name := range tc.names {
|
||||||
|
idx.add(name, certUsageServerAuth|certUsageClientAuth)
|
||||||
|
}
|
||||||
|
assert.Equalf(t, tc.match, idx.match(tc.test, certUsageServerAuth), "%v", tc)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
t.Run("different cert usages never match", func(t *testing.T) {
|
||||||
|
for _, tc := range testCases {
|
||||||
|
idx := make(certsIndex)
|
||||||
|
for _, name := range tc.names {
|
||||||
|
idx.add(name, certUsageServerAuth)
|
||||||
|
}
|
||||||
|
assert.Equalf(t, false, idx.match(tc.test, certUsageClientAuth), "%v", tc)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
}
|
|
@ -990,8 +990,64 @@ func (o *Options) Checksum() uint64 {
|
||||||
return hashutil.MustHash(o)
|
return hashutil.MustHash(o)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (o Options) indexCerts(ctx context.Context) certsIndex {
|
||||||
|
idx := make(certsIndex)
|
||||||
|
|
||||||
|
if o.CertFile != "" {
|
||||||
|
cert, err := cryptutil.ParsePEMCertificateFromFile(o.CertFile)
|
||||||
|
if err != nil {
|
||||||
|
log.Error(ctx).Err(err).Str("file", o.CertFile).Msg("parsing local cert: skipped")
|
||||||
|
} else {
|
||||||
|
idx.addCert(cert)
|
||||||
|
}
|
||||||
|
} else if o.Cert != "" {
|
||||||
|
if data, err := base64.StdEncoding.DecodeString(o.Cert); err != nil {
|
||||||
|
log.Error(ctx).Err(err).Msg("bad base64 for local cert: skipped")
|
||||||
|
} else if cert, err := cryptutil.ParsePEMCertificate(data); err != nil {
|
||||||
|
log.Error(ctx).Err(err).Msg("parsing local cert: skipped")
|
||||||
|
} else {
|
||||||
|
idx.addCert(cert)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, c := range o.CertificateFiles {
|
||||||
|
cert, err := cryptutil.ParsePEMCertificateFromFile(c.CertFile)
|
||||||
|
if err != nil {
|
||||||
|
log.Error(ctx).Err(err).Str("file", c.CertFile).Msg("parsing local cert: skipped")
|
||||||
|
}
|
||||||
|
idx.addCert(cert)
|
||||||
|
}
|
||||||
|
return idx
|
||||||
|
}
|
||||||
|
|
||||||
|
func (o *Options) applyExternalCerts(ctx context.Context, certs []*config.Settings_Certificate) {
|
||||||
|
idx := o.indexCerts(ctx)
|
||||||
|
for _, c := range certs {
|
||||||
|
cert, err := cryptutil.ParsePEMCertificate(c.CertBytes)
|
||||||
|
if err != nil {
|
||||||
|
log.Error(ctx).Err(err).Msg("parsing cert from databroker: skipped")
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if overlaps, name := idx.matchCert(cert); overlaps {
|
||||||
|
log.Error(ctx).Err(err).Str("domain", name).Msg("overlaps with local certs: skipped")
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
cfp := certificateFilePair{
|
||||||
|
CertFile: c.CertFile,
|
||||||
|
KeyFile: c.KeyFile,
|
||||||
|
}
|
||||||
|
if cfp.CertFile == "" {
|
||||||
|
cfp.CertFile = base64.StdEncoding.EncodeToString(c.CertBytes)
|
||||||
|
}
|
||||||
|
if cfp.KeyFile == "" {
|
||||||
|
cfp.KeyFile = base64.StdEncoding.EncodeToString(c.KeyBytes)
|
||||||
|
}
|
||||||
|
o.CertificateFiles = append(o.CertificateFiles, cfp)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// ApplySettings modifies the config options using the given protobuf settings.
|
// ApplySettings modifies the config options using the given protobuf settings.
|
||||||
func (o *Options) ApplySettings(settings *config.Settings) {
|
func (o *Options) ApplySettings(ctx context.Context, settings *config.Settings) {
|
||||||
if settings == nil {
|
if settings == nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -1023,19 +1079,7 @@ func (o *Options) ApplySettings(settings *config.Settings) {
|
||||||
if settings.DnsLookupFamily != nil {
|
if settings.DnsLookupFamily != nil {
|
||||||
o.DNSLookupFamily = settings.GetDnsLookupFamily()
|
o.DNSLookupFamily = settings.GetDnsLookupFamily()
|
||||||
}
|
}
|
||||||
for _, c := range settings.Certificates {
|
o.applyExternalCerts(ctx, settings.GetCertificates())
|
||||||
cfp := certificateFilePair{
|
|
||||||
CertFile: c.CertFile,
|
|
||||||
KeyFile: c.KeyFile,
|
|
||||||
}
|
|
||||||
if cfp.CertFile == "" {
|
|
||||||
cfp.CertFile = base64.StdEncoding.EncodeToString(c.CertBytes)
|
|
||||||
}
|
|
||||||
if cfp.KeyFile == "" {
|
|
||||||
cfp.KeyFile = base64.StdEncoding.EncodeToString(c.KeyBytes)
|
|
||||||
}
|
|
||||||
o.CertificateFiles = append(o.CertificateFiles, cfp)
|
|
||||||
}
|
|
||||||
if settings.HttpRedirectAddr != nil {
|
if settings.HttpRedirectAddr != nil {
|
||||||
o.HTTPRedirectAddr = settings.GetHttpRedirectAddr()
|
o.HTTPRedirectAddr = settings.GetHttpRedirectAddr()
|
||||||
}
|
}
|
||||||
|
|
|
@ -91,7 +91,7 @@ func (src *ConfigSource) rebuild(ctx context.Context, firstTime firstTime) {
|
||||||
|
|
||||||
// add all the config policies to the list
|
// add all the config policies to the list
|
||||||
for id, cfgpb := range src.dbConfigs {
|
for id, cfgpb := range src.dbConfigs {
|
||||||
cfg.Options.ApplySettings(cfgpb.Settings)
|
cfg.Options.ApplySettings(ctx, cfgpb.Settings)
|
||||||
var errCount uint64
|
var errCount uint64
|
||||||
|
|
||||||
err := cfg.Options.Validate()
|
err := cfg.Options.Validate()
|
||||||
|
|
|
@ -11,13 +11,17 @@ import (
|
||||||
"encoding/pem"
|
"encoding/pem"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"io"
|
||||||
"math/big"
|
"math/big"
|
||||||
"net"
|
"net"
|
||||||
"os"
|
"os"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
const crlPemType = "X509 CRL"
|
const (
|
||||||
|
crlPemType = "X509 CRL"
|
||||||
|
maxCertFileSize = 1 << 16
|
||||||
|
)
|
||||||
|
|
||||||
// CertificateFromBase64 returns an X509 pair from a base64 encoded blob.
|
// CertificateFromBase64 returns an X509 pair from a base64 encoded blob.
|
||||||
func CertificateFromBase64(cert, key string) (*tls.Certificate, error) {
|
func CertificateFromBase64(cert, key string) (*tls.Certificate, error) {
|
||||||
|
@ -211,3 +215,42 @@ func GenerateSelfSignedCertificate(domain string) (*tls.Certificate, error) {
|
||||||
|
|
||||||
return &cert, nil
|
return &cert, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ParsePEMCertificate parses PEM encoded certificate block
|
||||||
|
func ParsePEMCertificate(raw []byte) (*x509.Certificate, error) {
|
||||||
|
data := raw
|
||||||
|
for {
|
||||||
|
var block *pem.Block
|
||||||
|
block, data = pem.Decode(data)
|
||||||
|
if block == nil {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
if block.Type != "CERTIFICATE" || len(block.Headers) != 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
cert, err := x509.ParseCertificate(block.Bytes)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("invalid certificate: %w", err)
|
||||||
|
}
|
||||||
|
return cert, nil
|
||||||
|
}
|
||||||
|
return nil, fmt.Errorf("no certificate block found")
|
||||||
|
}
|
||||||
|
|
||||||
|
// ParsePEMCertificateFromFile decodes PEM certificate from file
|
||||||
|
func ParsePEMCertificateFromFile(file string) (*x509.Certificate, error) {
|
||||||
|
fd, err := os.Open(file)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("open file: %w", err)
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
_ = fd.Close()
|
||||||
|
}()
|
||||||
|
raw, err := io.ReadAll(io.LimitReader(fd, maxCertFileSize))
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("read file: %w", err)
|
||||||
|
}
|
||||||
|
return ParsePEMCertificate(raw)
|
||||||
|
}
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue