config: update logic for checking overlapping certificates (#4216)

* config: update logic for checking overlapping certificates

* add test

* go mod tidy
This commit is contained in:
Caleb Doxsey 2023-06-01 09:30:46 -06:00 committed by GitHub
parent 3a791542d4
commit baf964f44a
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
10 changed files with 216 additions and 138 deletions

View file

@ -63,7 +63,7 @@ func TestAutocertOptions_Validate(t *testing.T) {
wantErr bool
cleanup func()
}
var tests = map[string]func(t *testing.T) test{
tests := map[string]func(t *testing.T) test{
"ok/custom-ca": func(t *testing.T) test {
return test{
fields: fields{

View file

@ -1,43 +0,0 @@
package config
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestCertOverlap(t *testing.T) {
testCases := []struct {
names []string
test string
match bool
}{
{[]string{"aa.bb.cc", "cc.bb.aa"}, "aa.bb.c", false},
{[]string{"aa.bb.cc"}, "aa.bb.cc", true},
{[]string{"*.bb.cc"}, "aa.bb.cc", true},
{[]string{"a1.bb.cc", "a2.bb.cc"}, "*.bb.cc", true},
{[]string{"*.bb.cc", "a2.bb.cc"}, "*.bb.cc", true},
{[]string{"*.aa.bb.cc"}, "*.bb.cc", false},
{[]string{"*.aa.bb.cc"}, "aa.bb.cc", false},
{[]string{"bb.cc"}, "*.bb.cc", false},
}
t.Run("match mix mode", func(t *testing.T) {
for _, tc := range testCases {
idx := make(certsIndex)
for _, name := range tc.names {
idx.add(name, certUsageServerAuth|certUsageClientAuth)
}
assert.Equalf(t, tc.match, idx.match(tc.test, certUsageServerAuth), "%v", tc)
}
})
t.Run("different cert usages never match", func(t *testing.T) {
for _, tc := range testCases {
idx := make(certsIndex)
for _, name := range tc.names {
idx.add(name, certUsageServerAuth)
}
assert.Equalf(t, false, idx.match(tc.test, certUsageClientAuth), "%v", tc)
}
})
}

View file

@ -4,6 +4,7 @@ import (
"bytes"
"context"
"crypto/tls"
"crypto/x509"
"encoding/base64"
"errors"
"fmt"
@ -983,6 +984,40 @@ func (o *Options) HasCertificates() bool {
return o.Cert != "" || o.Key != "" || len(o.CertificateFiles) > 0 || o.CertFile != "" || o.KeyFile != ""
}
// GetX509Certificates gets all the x509 certificates from the options. Invalid certificates are ignored.
func (o *Options) GetX509Certificates() []*x509.Certificate {
var certs []*x509.Certificate
if o.CertFile != "" {
cert, err := cryptutil.ParsePEMCertificateFromFile(o.CertFile)
if err != nil {
log.Error(context.Background()).Err(err).Str("file", o.CertFile).Msg("invalid cert_file")
} else {
certs = append(certs, cert)
}
} else if o.Cert != "" {
if cert, err := cryptutil.ParsePEMCertificateFromBase64(o.Cert); err != nil {
log.Error(context.Background()).Err(err).Msg("invalid cert")
} else {
certs = append(certs, cert)
}
}
for _, c := range o.CertificateFiles {
cert, err := cryptutil.ParsePEMCertificateFromBase64(c.CertFile)
if err != nil {
cert, err = cryptutil.ParsePEMCertificateFromFile(c.CertFile)
}
if err != nil {
log.Error(context.Background()).Err(err).Msg("invalid certificate_file")
} else {
certs = append(certs, cert)
}
}
return certs
}
// GetSharedKey gets the decoded shared key.
func (o *Options) GetSharedKey() ([]byte, error) {
sharedKey := o.SharedKey
@ -1257,40 +1292,7 @@ func (o *Options) Checksum() uint64 {
return hashutil.MustHash(o)
}
func (o *Options) indexCerts(ctx context.Context) certsIndex {
idx := make(certsIndex)
if o.CertFile != "" {
cert, err := cryptutil.ParsePEMCertificateFromFile(o.CertFile)
if err != nil {
log.Error(ctx).Err(err).Str("file", o.CertFile).Msg("parsing local cert: skipped")
} else {
idx.addCert(cert)
}
} else if o.Cert != "" {
if cert, err := cryptutil.ParsePEMCertificateFromBase64(o.Cert); err != nil {
log.Error(ctx).Err(err).Msg("parsing local cert: skipped")
} else {
idx.addCert(cert)
}
}
for _, c := range o.CertificateFiles {
cert, err := cryptutil.ParsePEMCertificateFromBase64(c.CertFile)
if err != nil {
cert, err = cryptutil.ParsePEMCertificateFromFile(c.CertFile)
}
if err != nil {
log.Error(ctx).Err(err).Msg("parsing local cert: skipped")
} else {
idx.addCert(cert)
}
}
return idx
}
func (o *Options) applyExternalCerts(ctx context.Context, certs []*config.Settings_Certificate) {
idx := o.indexCerts(ctx)
func (o *Options) applyExternalCerts(ctx context.Context, certsIndex *cryptutil.CertificatesIndex, certs []*config.Settings_Certificate) {
for _, c := range certs {
cfp := certificateFilePair{}
cfp.CertFile = base64.StdEncoding.EncodeToString(c.CertBytes)
@ -1301,7 +1303,7 @@ func (o *Options) applyExternalCerts(ctx context.Context, certs []*config.Settin
log.Error(ctx).Err(err).Msg("parsing cert from databroker: skipped")
continue
}
if overlaps, name := idx.matchCert(cert); overlaps {
if overlaps, name := certsIndex.OverlapsWithExistingCertificate(cert); overlaps {
log.Error(ctx).Err(err).Str("domain", name).Msg("overlaps with local certs: skipped")
continue
}
@ -1311,7 +1313,7 @@ func (o *Options) applyExternalCerts(ctx context.Context, certs []*config.Settin
}
// ApplySettings modifies the config options using the given protobuf settings.
func (o *Options) ApplySettings(ctx context.Context, settings *config.Settings) {
func (o *Options) ApplySettings(ctx context.Context, certsIndex *cryptutil.CertificatesIndex, settings *config.Settings) {
if settings == nil {
return
}
@ -1325,7 +1327,7 @@ func (o *Options) ApplySettings(ctx context.Context, settings *config.Settings)
set(&o.Addr, settings.Address)
set(&o.InsecureServer, settings.InsecureServer)
set(&o.DNSLookupFamily, settings.DnsLookupFamily)
o.applyExternalCerts(ctx, settings.GetCertificates())
o.applyExternalCerts(ctx, certsIndex, settings.GetCertificates())
set(&o.HTTPRedirectAddr, settings.HttpRedirectAddr)
setDuration(&o.ReadTimeout, settings.TimeoutRead)
setDuration(&o.WriteTimeout, settings.TimeoutWrite)

View file

@ -3,6 +3,7 @@ package config
import (
"context"
"crypto/tls"
"crypto/x509"
"encoding/base64"
"encoding/pem"
"fmt"
@ -735,13 +736,17 @@ func TestOptions_ApplySettings(t *testing.T) {
cert3, err := cryptutil.GenerateCertificate(nil, "not.example.com")
require.NoError(t, err)
certsIndex := cryptutil.NewCertificatesIndex()
xc1, _ := x509.ParseCertificate(cert1.Certificate[0])
certsIndex.Add(xc1)
settings := &config.Settings{
Certificates: []*config.Settings_Certificate{
{CertBytes: encodeCert(cert2)},
{CertBytes: encodeCert(cert3)},
},
}
options.ApplySettings(ctx, settings)
options.ApplySettings(ctx, certsIndex, settings)
assert.Len(t, options.CertificateFiles, 2, "should prevent adding duplicate certificates")
})
}

2
go.mod
View file

@ -55,7 +55,7 @@ require (
github.com/rs/zerolog v1.29.1
github.com/shirou/gopsutil/v3 v3.23.4
github.com/spf13/viper v1.15.0
github.com/stretchr/testify v1.8.3
github.com/stretchr/testify v1.8.4
github.com/tniswong/go.rfcx v0.0.0-20181019234604-07783c52761f
github.com/volatiletech/null/v9 v9.0.0
github.com/yuin/gopher-lua v1.1.0

4
go.sum
View file

@ -658,8 +658,8 @@ github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
github.com/stretchr/testify v1.8.3 h1:RP3t2pwF7cMEbC1dqtB6poj3niw/9gnV4Cjg5oW5gtY=
github.com/stretchr/testify v1.8.3/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk=
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
github.com/stvp/go-udp-testing v0.0.0-20201019212854-469649b16807/go.mod h1:7jxmlfBCDBXRzr0eAQJ48XC1hBu1np4CS5+cHEYfwpc=
github.com/subosito/gotenv v1.4.2 h1:X1TuBLAMDFbaTAChgCBLu3DU3UPyELpnF2jjJ2cz/S8=
github.com/subosito/gotenv v1.4.2/go.mod h1:ayKnFf/c6rvx/2iiLrJUk1e6plDbT3edrFNGqEflhK0=

View file

@ -13,6 +13,7 @@ import (
"github.com/pomerium/pomerium/internal/log"
"github.com/pomerium/pomerium/internal/telemetry/metrics"
"github.com/pomerium/pomerium/internal/telemetry/trace"
"github.com/pomerium/pomerium/pkg/cryptutil"
"github.com/pomerium/pomerium/pkg/grpc"
configpb "github.com/pomerium/pomerium/pkg/grpc/config"
"github.com/pomerium/pomerium/pkg/grpc/databroker"
@ -98,11 +99,16 @@ func (src *ConfigSource) rebuild(ctx context.Context, firstTime firstTime) {
ids := maps.Keys(src.dbConfigs)
sort.Strings(ids)
certsIndex := cryptutil.NewCertificatesIndex()
for _, cert := range cfg.Options.GetX509Certificates() {
certsIndex.Add(cert)
}
// add all the config policies to the list
for _, id := range ids {
cfgpb := src.dbConfigs[id]
cfg.Options.ApplySettings(ctx, cfgpb.Settings)
cfg.Options.ApplySettings(ctx, certsIndex, cfgpb.Settings)
var errCount uint64
err := cfg.Options.Validate()

View file

@ -2,22 +2,35 @@ package databroker
import (
"context"
"encoding/base64"
"net"
"net/url"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"google.golang.org/grpc"
"google.golang.org/protobuf/proto"
"github.com/pomerium/pomerium/config"
"github.com/pomerium/pomerium/pkg/cryptutil"
configpb "github.com/pomerium/pomerium/pkg/grpc/config"
"github.com/pomerium/pomerium/pkg/grpc/databroker"
"github.com/pomerium/pomerium/pkg/protoutil"
)
func TestConfigSource(t *testing.T) {
t.Parallel()
generateCert := func(name string) ([]byte, []byte) {
cert, err := cryptutil.GenerateCertificate(nil, name)
require.NoError(t, err)
certPEM, keyPEM, err := cryptutil.EncodeCertificate(cert)
require.NoError(t, err)
return certPEM, keyPEM
}
ctx, clearTimeout := context.WithTimeout(context.Background(), 50*time.Second)
defer clearTimeout()
@ -45,6 +58,8 @@ func TestConfigSource(t *testing.T) {
{URL: *u},
}, AllowedUsers: []string{"foo@bar.com"},
})
certPEM, keyPEM := generateCert("*.example.com")
base.Cert, base.Key = base64.StdEncoding.EncodeToString(certPEM), base64.StdEncoding.EncodeToString(keyPEM)
baseSource := config.NewStaticSource(&config.Config{
OutboundPort: outboundPort,
@ -55,13 +70,17 @@ func TestConfigSource(t *testing.T) {
})
cfgs <- src.GetConfig()
route := &configpb.Route{
From: "https://from.example.com",
To: []string{"https://to.example.com"},
}
cert := &configpb.Settings_Certificate{}
cert.CertBytes, cert.KeyBytes = generateCert("*.example.com")
data := protoutil.NewAny(&configpb.Config{
Name: "config",
Routes: []*configpb.Route{
{
From: "https://from.example.com",
To: []string{"https://to.example.com"},
},
Name: "config",
Routes: []*configpb.Route{route},
Settings: &configpb.Settings{
Certificates: []*configpb.Settings_Certificate{cert},
},
})
_, _ = dataBrokerServer.Put(ctx, &databroker.PutRequest{
@ -86,6 +105,7 @@ func TestConfigSource(t *testing.T) {
return
case cfg := <-cfgs:
assert.Len(t, cfg.Options.AdditionalPolicies, 1)
assert.Len(t, cfg.Options.CertificateFiles, 0, "ignores overlapping certificate")
}
baseSource.SetConfig(ctx, &config.Config{

View file

@ -1,4 +1,4 @@
package config
package cryptutil
import (
"crypto/x509"
@ -6,13 +6,70 @@ import (
)
type certUsage byte
type certsIndex map[string]map[string]certUsage
const (
certUsageServerAuth = certUsage(1 << iota)
certUsageClientAuth
)
// A CertificatesIndex indexes certificates to determine if there is overlap between them.
type CertificatesIndex struct {
index map[string]map[string]certUsage
}
// NewCertificatesIndex creates a new CertificatesIndex.
func NewCertificatesIndex() *CertificatesIndex {
return &CertificatesIndex{
index: make(map[string]map[string]certUsage),
}
}
// Add adds a certificate to the index.
func (c *CertificatesIndex) Add(cert *x509.Certificate) {
usage := getCertUsage(cert)
for _, name := range cert.DNSNames {
c.add(name, usage)
}
}
// OverlapsWithExistingCertificate returns true if the certificate overlaps with an existing certificate.
func (c *CertificatesIndex) OverlapsWithExistingCertificate(cert *x509.Certificate) (bool, string) {
usage := getCertUsage(cert)
for _, name := range cert.DNSNames {
if c.match(name, usage) {
return true, name
}
}
return false, ""
}
func (c *CertificatesIndex) add(name string, usage certUsage) {
prefix, suffix := splitDomainName(name)
names := c.index[suffix]
if names == nil {
names = make(map[string]certUsage)
c.index[suffix] = names
}
names[prefix] = usage
}
func (c *CertificatesIndex) match(name string, usage certUsage) bool {
prefix, suffix := splitDomainName(name)
names := c.index[suffix]
if names == nil {
return false
}
if prefix != "*" {
return names["*"]&usage != 0 || names[prefix]&usage != 0
}
for _, u := range names {
if u&usage != 0 {
return true
}
}
return false
}
func splitDomainName(name string) (prefix, suffix string) {
dot := strings.IndexRune(name, '.')
if dot < 0 {
@ -33,47 +90,3 @@ func getCertUsage(cert *x509.Certificate) certUsage {
}
return usage
}
func (c certsIndex) addCert(cert *x509.Certificate) {
usage := getCertUsage(cert)
for _, name := range cert.DNSNames {
c.add(name, usage)
}
}
func (c certsIndex) matchCert(cert *x509.Certificate) (bool, string) {
usage := getCertUsage(cert)
for _, name := range cert.DNSNames {
if c.match(name, usage) {
return true, name
}
}
return false, ""
}
func (c certsIndex) add(name string, usage certUsage) {
prefix, suffix := splitDomainName(name)
names := c[suffix]
if names == nil {
names = make(map[string]certUsage)
c[suffix] = names
}
names[prefix] = usage
}
func (c certsIndex) match(name string, usage certUsage) bool {
prefix, suffix := splitDomainName(name)
names := c[suffix]
if names == nil {
return false
}
if prefix != "*" {
return names["*"]&usage != 0 || names[prefix]&usage != 0
}
for _, u := range names {
if u&usage != 0 {
return true
}
}
return false
}

View file

@ -0,0 +1,75 @@
package cryptutil_test
import (
"crypto/x509"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/pomerium/pomerium/pkg/cryptutil"
"github.com/pomerium/pomerium/pkg/derivecert"
)
func TestCertificatesIndex(t *testing.T) {
t.Parallel()
ca, err := derivecert.NewCA(cryptutil.NewKey())
require.NoError(t, err)
mkClientCert := func(domains []string) *x509.Certificate {
pem, err := ca.NewServerCert(domains, func(c *x509.Certificate) {
c.ExtKeyUsage = []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth}
})
require.NoError(t, err)
_, cert, err := pem.KeyCert()
require.NoError(t, err)
return cert
}
mkServerCert := func(domains []string) *x509.Certificate {
pem, err := ca.NewServerCert(domains)
require.NoError(t, err)
_, cert, err := pem.KeyCert()
require.NoError(t, err)
return cert
}
testCases := []struct {
names []string
test string
match bool
}{
{[]string{"aa.bb.cc", "cc.bb.aa"}, "aa.bb.c", false},
{[]string{"aa.bb.cc"}, "aa.bb.cc", true},
{[]string{"*.bb.cc"}, "aa.bb.cc", true},
{[]string{"a1.bb.cc", "a2.bb.cc"}, "*.bb.cc", true},
{[]string{"*.bb.cc", "a2.bb.cc"}, "*.bb.cc", true},
{[]string{"*.aa.bb.cc"}, "*.bb.cc", false},
{[]string{"*.aa.bb.cc"}, "aa.bb.cc", false},
{[]string{"bb.cc"}, "*.bb.cc", false},
}
t.Run("match mix mode", func(t *testing.T) {
for _, tc := range testCases {
idx := cryptutil.NewCertificatesIndex()
idx.Add(mkServerCert(tc.names))
cert := mkServerCert([]string{tc.test})
overlaps, _ := idx.OverlapsWithExistingCertificate(cert)
assert.Equalf(t, tc.match, overlaps, "%v", tc)
}
})
t.Run("different cert usages never match", func(t *testing.T) {
for _, tc := range testCases {
idx := cryptutil.NewCertificatesIndex()
idx.Add(mkServerCert(tc.names))
cert := mkClientCert([]string{tc.test})
overlaps, _ := idx.OverlapsWithExistingCertificate(cert)
assert.Equalf(t, false, overlaps, "%v", tc)
}
})
}