diff --git a/config/autocert_test.go b/config/autocert_test.go index a63ea6128..9dcfa508c 100644 --- a/config/autocert_test.go +++ b/config/autocert_test.go @@ -63,7 +63,7 @@ func TestAutocertOptions_Validate(t *testing.T) { wantErr bool cleanup func() } - var tests = map[string]func(t *testing.T) test{ + tests := map[string]func(t *testing.T) test{ "ok/custom-ca": func(t *testing.T) test { return test{ fields: fields{ diff --git a/config/certs_test.go b/config/certs_test.go deleted file mode 100644 index c9eaf4818..000000000 --- a/config/certs_test.go +++ /dev/null @@ -1,43 +0,0 @@ -package config - -import ( - "testing" - - "github.com/stretchr/testify/assert" -) - -func TestCertOverlap(t *testing.T) { - testCases := []struct { - names []string - test string - match bool - }{ - {[]string{"aa.bb.cc", "cc.bb.aa"}, "aa.bb.c", false}, - {[]string{"aa.bb.cc"}, "aa.bb.cc", true}, - {[]string{"*.bb.cc"}, "aa.bb.cc", true}, - {[]string{"a1.bb.cc", "a2.bb.cc"}, "*.bb.cc", true}, - {[]string{"*.bb.cc", "a2.bb.cc"}, "*.bb.cc", true}, - {[]string{"*.aa.bb.cc"}, "*.bb.cc", false}, - {[]string{"*.aa.bb.cc"}, "aa.bb.cc", false}, - {[]string{"bb.cc"}, "*.bb.cc", false}, - } - t.Run("match mix mode", func(t *testing.T) { - for _, tc := range testCases { - idx := make(certsIndex) - for _, name := range tc.names { - idx.add(name, certUsageServerAuth|certUsageClientAuth) - } - assert.Equalf(t, tc.match, idx.match(tc.test, certUsageServerAuth), "%v", tc) - } - }) - t.Run("different cert usages never match", func(t *testing.T) { - for _, tc := range testCases { - idx := make(certsIndex) - for _, name := range tc.names { - idx.add(name, certUsageServerAuth) - } - assert.Equalf(t, false, idx.match(tc.test, certUsageClientAuth), "%v", tc) - } - }) - -} diff --git a/config/options.go b/config/options.go index 388c4aff4..6abf17af6 100644 --- a/config/options.go +++ b/config/options.go @@ -4,6 +4,7 @@ import ( "bytes" "context" "crypto/tls" + "crypto/x509" "encoding/base64" "errors" "fmt" @@ -983,6 +984,40 @@ func (o *Options) HasCertificates() bool { return o.Cert != "" || o.Key != "" || len(o.CertificateFiles) > 0 || o.CertFile != "" || o.KeyFile != "" } +// GetX509Certificates gets all the x509 certificates from the options. Invalid certificates are ignored. +func (o *Options) GetX509Certificates() []*x509.Certificate { + var certs []*x509.Certificate + + if o.CertFile != "" { + cert, err := cryptutil.ParsePEMCertificateFromFile(o.CertFile) + if err != nil { + log.Error(context.Background()).Err(err).Str("file", o.CertFile).Msg("invalid cert_file") + } else { + certs = append(certs, cert) + } + } else if o.Cert != "" { + if cert, err := cryptutil.ParsePEMCertificateFromBase64(o.Cert); err != nil { + log.Error(context.Background()).Err(err).Msg("invalid cert") + } else { + certs = append(certs, cert) + } + } + + for _, c := range o.CertificateFiles { + cert, err := cryptutil.ParsePEMCertificateFromBase64(c.CertFile) + if err != nil { + cert, err = cryptutil.ParsePEMCertificateFromFile(c.CertFile) + } + if err != nil { + log.Error(context.Background()).Err(err).Msg("invalid certificate_file") + } else { + certs = append(certs, cert) + } + } + + return certs +} + // GetSharedKey gets the decoded shared key. func (o *Options) GetSharedKey() ([]byte, error) { sharedKey := o.SharedKey @@ -1257,40 +1292,7 @@ func (o *Options) Checksum() uint64 { return hashutil.MustHash(o) } -func (o *Options) indexCerts(ctx context.Context) certsIndex { - idx := make(certsIndex) - - if o.CertFile != "" { - cert, err := cryptutil.ParsePEMCertificateFromFile(o.CertFile) - if err != nil { - log.Error(ctx).Err(err).Str("file", o.CertFile).Msg("parsing local cert: skipped") - } else { - idx.addCert(cert) - } - } else if o.Cert != "" { - if cert, err := cryptutil.ParsePEMCertificateFromBase64(o.Cert); err != nil { - log.Error(ctx).Err(err).Msg("parsing local cert: skipped") - } else { - idx.addCert(cert) - } - } - - for _, c := range o.CertificateFiles { - cert, err := cryptutil.ParsePEMCertificateFromBase64(c.CertFile) - if err != nil { - cert, err = cryptutil.ParsePEMCertificateFromFile(c.CertFile) - } - if err != nil { - log.Error(ctx).Err(err).Msg("parsing local cert: skipped") - } else { - idx.addCert(cert) - } - } - return idx -} - -func (o *Options) applyExternalCerts(ctx context.Context, certs []*config.Settings_Certificate) { - idx := o.indexCerts(ctx) +func (o *Options) applyExternalCerts(ctx context.Context, certsIndex *cryptutil.CertificatesIndex, certs []*config.Settings_Certificate) { for _, c := range certs { cfp := certificateFilePair{} cfp.CertFile = base64.StdEncoding.EncodeToString(c.CertBytes) @@ -1301,7 +1303,7 @@ func (o *Options) applyExternalCerts(ctx context.Context, certs []*config.Settin log.Error(ctx).Err(err).Msg("parsing cert from databroker: skipped") continue } - if overlaps, name := idx.matchCert(cert); overlaps { + if overlaps, name := certsIndex.OverlapsWithExistingCertificate(cert); overlaps { log.Error(ctx).Err(err).Str("domain", name).Msg("overlaps with local certs: skipped") continue } @@ -1311,7 +1313,7 @@ func (o *Options) applyExternalCerts(ctx context.Context, certs []*config.Settin } // ApplySettings modifies the config options using the given protobuf settings. -func (o *Options) ApplySettings(ctx context.Context, settings *config.Settings) { +func (o *Options) ApplySettings(ctx context.Context, certsIndex *cryptutil.CertificatesIndex, settings *config.Settings) { if settings == nil { return } @@ -1325,7 +1327,7 @@ func (o *Options) ApplySettings(ctx context.Context, settings *config.Settings) set(&o.Addr, settings.Address) set(&o.InsecureServer, settings.InsecureServer) set(&o.DNSLookupFamily, settings.DnsLookupFamily) - o.applyExternalCerts(ctx, settings.GetCertificates()) + o.applyExternalCerts(ctx, certsIndex, settings.GetCertificates()) set(&o.HTTPRedirectAddr, settings.HttpRedirectAddr) setDuration(&o.ReadTimeout, settings.TimeoutRead) setDuration(&o.WriteTimeout, settings.TimeoutWrite) diff --git a/config/options_test.go b/config/options_test.go index 42ab36d59..b52ac0954 100644 --- a/config/options_test.go +++ b/config/options_test.go @@ -3,6 +3,7 @@ package config import ( "context" "crypto/tls" + "crypto/x509" "encoding/base64" "encoding/pem" "fmt" @@ -735,13 +736,17 @@ func TestOptions_ApplySettings(t *testing.T) { cert3, err := cryptutil.GenerateCertificate(nil, "not.example.com") require.NoError(t, err) + certsIndex := cryptutil.NewCertificatesIndex() + xc1, _ := x509.ParseCertificate(cert1.Certificate[0]) + certsIndex.Add(xc1) + settings := &config.Settings{ Certificates: []*config.Settings_Certificate{ {CertBytes: encodeCert(cert2)}, {CertBytes: encodeCert(cert3)}, }, } - options.ApplySettings(ctx, settings) + options.ApplySettings(ctx, certsIndex, settings) assert.Len(t, options.CertificateFiles, 2, "should prevent adding duplicate certificates") }) } diff --git a/go.mod b/go.mod index c5382211c..e1f5ed1a2 100644 --- a/go.mod +++ b/go.mod @@ -55,7 +55,7 @@ require ( github.com/rs/zerolog v1.29.1 github.com/shirou/gopsutil/v3 v3.23.4 github.com/spf13/viper v1.15.0 - github.com/stretchr/testify v1.8.3 + github.com/stretchr/testify v1.8.4 github.com/tniswong/go.rfcx v0.0.0-20181019234604-07783c52761f github.com/volatiletech/null/v9 v9.0.0 github.com/yuin/gopher-lua v1.1.0 diff --git a/go.sum b/go.sum index 6dc7fe406..4b85a85ce 100644 --- a/go.sum +++ b/go.sum @@ -658,8 +658,8 @@ github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/ github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= -github.com/stretchr/testify v1.8.3 h1:RP3t2pwF7cMEbC1dqtB6poj3niw/9gnV4Cjg5oW5gtY= -github.com/stretchr/testify v1.8.3/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= +github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= +github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= github.com/stvp/go-udp-testing v0.0.0-20201019212854-469649b16807/go.mod h1:7jxmlfBCDBXRzr0eAQJ48XC1hBu1np4CS5+cHEYfwpc= github.com/subosito/gotenv v1.4.2 h1:X1TuBLAMDFbaTAChgCBLu3DU3UPyELpnF2jjJ2cz/S8= github.com/subosito/gotenv v1.4.2/go.mod h1:ayKnFf/c6rvx/2iiLrJUk1e6plDbT3edrFNGqEflhK0= diff --git a/internal/databroker/config_source.go b/internal/databroker/config_source.go index 68185e282..711a5e5bd 100644 --- a/internal/databroker/config_source.go +++ b/internal/databroker/config_source.go @@ -13,6 +13,7 @@ import ( "github.com/pomerium/pomerium/internal/log" "github.com/pomerium/pomerium/internal/telemetry/metrics" "github.com/pomerium/pomerium/internal/telemetry/trace" + "github.com/pomerium/pomerium/pkg/cryptutil" "github.com/pomerium/pomerium/pkg/grpc" configpb "github.com/pomerium/pomerium/pkg/grpc/config" "github.com/pomerium/pomerium/pkg/grpc/databroker" @@ -98,11 +99,16 @@ func (src *ConfigSource) rebuild(ctx context.Context, firstTime firstTime) { ids := maps.Keys(src.dbConfigs) sort.Strings(ids) + certsIndex := cryptutil.NewCertificatesIndex() + for _, cert := range cfg.Options.GetX509Certificates() { + certsIndex.Add(cert) + } + // add all the config policies to the list for _, id := range ids { cfgpb := src.dbConfigs[id] - cfg.Options.ApplySettings(ctx, cfgpb.Settings) + cfg.Options.ApplySettings(ctx, certsIndex, cfgpb.Settings) var errCount uint64 err := cfg.Options.Validate() diff --git a/internal/databroker/config_source_test.go b/internal/databroker/config_source_test.go index e9b22fcca..335662f59 100644 --- a/internal/databroker/config_source_test.go +++ b/internal/databroker/config_source_test.go @@ -2,22 +2,35 @@ package databroker import ( "context" + "encoding/base64" "net" "net/url" "testing" "time" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "google.golang.org/grpc" "google.golang.org/protobuf/proto" "github.com/pomerium/pomerium/config" + "github.com/pomerium/pomerium/pkg/cryptutil" configpb "github.com/pomerium/pomerium/pkg/grpc/config" "github.com/pomerium/pomerium/pkg/grpc/databroker" "github.com/pomerium/pomerium/pkg/protoutil" ) func TestConfigSource(t *testing.T) { + t.Parallel() + + generateCert := func(name string) ([]byte, []byte) { + cert, err := cryptutil.GenerateCertificate(nil, name) + require.NoError(t, err) + certPEM, keyPEM, err := cryptutil.EncodeCertificate(cert) + require.NoError(t, err) + return certPEM, keyPEM + } + ctx, clearTimeout := context.WithTimeout(context.Background(), 50*time.Second) defer clearTimeout() @@ -45,6 +58,8 @@ func TestConfigSource(t *testing.T) { {URL: *u}, }, AllowedUsers: []string{"foo@bar.com"}, }) + certPEM, keyPEM := generateCert("*.example.com") + base.Cert, base.Key = base64.StdEncoding.EncodeToString(certPEM), base64.StdEncoding.EncodeToString(keyPEM) baseSource := config.NewStaticSource(&config.Config{ OutboundPort: outboundPort, @@ -55,13 +70,17 @@ func TestConfigSource(t *testing.T) { }) cfgs <- src.GetConfig() + route := &configpb.Route{ + From: "https://from.example.com", + To: []string{"https://to.example.com"}, + } + cert := &configpb.Settings_Certificate{} + cert.CertBytes, cert.KeyBytes = generateCert("*.example.com") data := protoutil.NewAny(&configpb.Config{ - Name: "config", - Routes: []*configpb.Route{ - { - From: "https://from.example.com", - To: []string{"https://to.example.com"}, - }, + Name: "config", + Routes: []*configpb.Route{route}, + Settings: &configpb.Settings{ + Certificates: []*configpb.Settings_Certificate{cert}, }, }) _, _ = dataBrokerServer.Put(ctx, &databroker.PutRequest{ @@ -86,6 +105,7 @@ func TestConfigSource(t *testing.T) { return case cfg := <-cfgs: assert.Len(t, cfg.Options.AdditionalPolicies, 1) + assert.Len(t, cfg.Options.CertificateFiles, 0, "ignores overlapping certificate") } baseSource.SetConfig(ctx, &config.Config{ diff --git a/config/certs.go b/pkg/cryptutil/certificates_index.go similarity index 57% rename from config/certs.go rename to pkg/cryptutil/certificates_index.go index 99d57e0b2..bb16d748f 100644 --- a/config/certs.go +++ b/pkg/cryptutil/certificates_index.go @@ -1,4 +1,4 @@ -package config +package cryptutil import ( "crypto/x509" @@ -6,13 +6,70 @@ import ( ) type certUsage byte -type certsIndex map[string]map[string]certUsage const ( certUsageServerAuth = certUsage(1 << iota) certUsageClientAuth ) +// A CertificatesIndex indexes certificates to determine if there is overlap between them. +type CertificatesIndex struct { + index map[string]map[string]certUsage +} + +// NewCertificatesIndex creates a new CertificatesIndex. +func NewCertificatesIndex() *CertificatesIndex { + return &CertificatesIndex{ + index: make(map[string]map[string]certUsage), + } +} + +// Add adds a certificate to the index. +func (c *CertificatesIndex) Add(cert *x509.Certificate) { + usage := getCertUsage(cert) + for _, name := range cert.DNSNames { + c.add(name, usage) + } +} + +// OverlapsWithExistingCertificate returns true if the certificate overlaps with an existing certificate. +func (c *CertificatesIndex) OverlapsWithExistingCertificate(cert *x509.Certificate) (bool, string) { + usage := getCertUsage(cert) + for _, name := range cert.DNSNames { + if c.match(name, usage) { + return true, name + } + } + return false, "" +} + +func (c *CertificatesIndex) add(name string, usage certUsage) { + prefix, suffix := splitDomainName(name) + names := c.index[suffix] + if names == nil { + names = make(map[string]certUsage) + c.index[suffix] = names + } + names[prefix] = usage +} + +func (c *CertificatesIndex) match(name string, usage certUsage) bool { + prefix, suffix := splitDomainName(name) + names := c.index[suffix] + if names == nil { + return false + } + if prefix != "*" { + return names["*"]&usage != 0 || names[prefix]&usage != 0 + } + for _, u := range names { + if u&usage != 0 { + return true + } + } + return false +} + func splitDomainName(name string) (prefix, suffix string) { dot := strings.IndexRune(name, '.') if dot < 0 { @@ -33,47 +90,3 @@ func getCertUsage(cert *x509.Certificate) certUsage { } return usage } - -func (c certsIndex) addCert(cert *x509.Certificate) { - usage := getCertUsage(cert) - for _, name := range cert.DNSNames { - c.add(name, usage) - } -} - -func (c certsIndex) matchCert(cert *x509.Certificate) (bool, string) { - usage := getCertUsage(cert) - for _, name := range cert.DNSNames { - if c.match(name, usage) { - return true, name - } - } - return false, "" -} - -func (c certsIndex) add(name string, usage certUsage) { - prefix, suffix := splitDomainName(name) - names := c[suffix] - if names == nil { - names = make(map[string]certUsage) - c[suffix] = names - } - names[prefix] = usage -} - -func (c certsIndex) match(name string, usage certUsage) bool { - prefix, suffix := splitDomainName(name) - names := c[suffix] - if names == nil { - return false - } - if prefix != "*" { - return names["*"]&usage != 0 || names[prefix]&usage != 0 - } - for _, u := range names { - if u&usage != 0 { - return true - } - } - return false -} diff --git a/pkg/cryptutil/certificates_index_test.go b/pkg/cryptutil/certificates_index_test.go new file mode 100644 index 000000000..26abdeee2 --- /dev/null +++ b/pkg/cryptutil/certificates_index_test.go @@ -0,0 +1,75 @@ +package cryptutil_test + +import ( + "crypto/x509" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/pomerium/pomerium/pkg/cryptutil" + "github.com/pomerium/pomerium/pkg/derivecert" +) + +func TestCertificatesIndex(t *testing.T) { + t.Parallel() + + ca, err := derivecert.NewCA(cryptutil.NewKey()) + require.NoError(t, err) + + mkClientCert := func(domains []string) *x509.Certificate { + pem, err := ca.NewServerCert(domains, func(c *x509.Certificate) { + c.ExtKeyUsage = []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth} + }) + require.NoError(t, err) + + _, cert, err := pem.KeyCert() + require.NoError(t, err) + + return cert + } + mkServerCert := func(domains []string) *x509.Certificate { + pem, err := ca.NewServerCert(domains) + require.NoError(t, err) + + _, cert, err := pem.KeyCert() + require.NoError(t, err) + + return cert + } + + testCases := []struct { + names []string + test string + match bool + }{ + {[]string{"aa.bb.cc", "cc.bb.aa"}, "aa.bb.c", false}, + {[]string{"aa.bb.cc"}, "aa.bb.cc", true}, + {[]string{"*.bb.cc"}, "aa.bb.cc", true}, + {[]string{"a1.bb.cc", "a2.bb.cc"}, "*.bb.cc", true}, + {[]string{"*.bb.cc", "a2.bb.cc"}, "*.bb.cc", true}, + {[]string{"*.aa.bb.cc"}, "*.bb.cc", false}, + {[]string{"*.aa.bb.cc"}, "aa.bb.cc", false}, + {[]string{"bb.cc"}, "*.bb.cc", false}, + } + t.Run("match mix mode", func(t *testing.T) { + for _, tc := range testCases { + idx := cryptutil.NewCertificatesIndex() + idx.Add(mkServerCert(tc.names)) + + cert := mkServerCert([]string{tc.test}) + overlaps, _ := idx.OverlapsWithExistingCertificate(cert) + assert.Equalf(t, tc.match, overlaps, "%v", tc) + } + }) + t.Run("different cert usages never match", func(t *testing.T) { + for _, tc := range testCases { + idx := cryptutil.NewCertificatesIndex() + idx.Add(mkServerCert(tc.names)) + + cert := mkClientCert([]string{tc.test}) + overlaps, _ := idx.OverlapsWithExistingCertificate(cert) + assert.Equalf(t, false, overlaps, "%v", tc) + } + }) +}