mirror of
https://github.com/pomerium/pomerium.git
synced 2025-05-10 07:37:33 +02:00
config: update logic for checking overlapping certificates (#4216)
* config: update logic for checking overlapping certificates * add test * go mod tidy
This commit is contained in:
parent
3a791542d4
commit
baf964f44a
10 changed files with 216 additions and 138 deletions
|
@ -63,7 +63,7 @@ func TestAutocertOptions_Validate(t *testing.T) {
|
||||||
wantErr bool
|
wantErr bool
|
||||||
cleanup func()
|
cleanup func()
|
||||||
}
|
}
|
||||||
var tests = map[string]func(t *testing.T) test{
|
tests := map[string]func(t *testing.T) test{
|
||||||
"ok/custom-ca": func(t *testing.T) test {
|
"ok/custom-ca": func(t *testing.T) test {
|
||||||
return test{
|
return test{
|
||||||
fields: fields{
|
fields: fields{
|
||||||
|
|
|
@ -1,43 +0,0 @@
|
||||||
package config
|
|
||||||
|
|
||||||
import (
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestCertOverlap(t *testing.T) {
|
|
||||||
testCases := []struct {
|
|
||||||
names []string
|
|
||||||
test string
|
|
||||||
match bool
|
|
||||||
}{
|
|
||||||
{[]string{"aa.bb.cc", "cc.bb.aa"}, "aa.bb.c", false},
|
|
||||||
{[]string{"aa.bb.cc"}, "aa.bb.cc", true},
|
|
||||||
{[]string{"*.bb.cc"}, "aa.bb.cc", true},
|
|
||||||
{[]string{"a1.bb.cc", "a2.bb.cc"}, "*.bb.cc", true},
|
|
||||||
{[]string{"*.bb.cc", "a2.bb.cc"}, "*.bb.cc", true},
|
|
||||||
{[]string{"*.aa.bb.cc"}, "*.bb.cc", false},
|
|
||||||
{[]string{"*.aa.bb.cc"}, "aa.bb.cc", false},
|
|
||||||
{[]string{"bb.cc"}, "*.bb.cc", false},
|
|
||||||
}
|
|
||||||
t.Run("match mix mode", func(t *testing.T) {
|
|
||||||
for _, tc := range testCases {
|
|
||||||
idx := make(certsIndex)
|
|
||||||
for _, name := range tc.names {
|
|
||||||
idx.add(name, certUsageServerAuth|certUsageClientAuth)
|
|
||||||
}
|
|
||||||
assert.Equalf(t, tc.match, idx.match(tc.test, certUsageServerAuth), "%v", tc)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
t.Run("different cert usages never match", func(t *testing.T) {
|
|
||||||
for _, tc := range testCases {
|
|
||||||
idx := make(certsIndex)
|
|
||||||
for _, name := range tc.names {
|
|
||||||
idx.add(name, certUsageServerAuth)
|
|
||||||
}
|
|
||||||
assert.Equalf(t, false, idx.match(tc.test, certUsageClientAuth), "%v", tc)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
}
|
|
|
@ -4,6 +4,7 @@ import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
|
"crypto/x509"
|
||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
@ -983,6 +984,40 @@ func (o *Options) HasCertificates() bool {
|
||||||
return o.Cert != "" || o.Key != "" || len(o.CertificateFiles) > 0 || o.CertFile != "" || o.KeyFile != ""
|
return o.Cert != "" || o.Key != "" || len(o.CertificateFiles) > 0 || o.CertFile != "" || o.KeyFile != ""
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetX509Certificates gets all the x509 certificates from the options. Invalid certificates are ignored.
|
||||||
|
func (o *Options) GetX509Certificates() []*x509.Certificate {
|
||||||
|
var certs []*x509.Certificate
|
||||||
|
|
||||||
|
if o.CertFile != "" {
|
||||||
|
cert, err := cryptutil.ParsePEMCertificateFromFile(o.CertFile)
|
||||||
|
if err != nil {
|
||||||
|
log.Error(context.Background()).Err(err).Str("file", o.CertFile).Msg("invalid cert_file")
|
||||||
|
} else {
|
||||||
|
certs = append(certs, cert)
|
||||||
|
}
|
||||||
|
} else if o.Cert != "" {
|
||||||
|
if cert, err := cryptutil.ParsePEMCertificateFromBase64(o.Cert); err != nil {
|
||||||
|
log.Error(context.Background()).Err(err).Msg("invalid cert")
|
||||||
|
} else {
|
||||||
|
certs = append(certs, cert)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, c := range o.CertificateFiles {
|
||||||
|
cert, err := cryptutil.ParsePEMCertificateFromBase64(c.CertFile)
|
||||||
|
if err != nil {
|
||||||
|
cert, err = cryptutil.ParsePEMCertificateFromFile(c.CertFile)
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
log.Error(context.Background()).Err(err).Msg("invalid certificate_file")
|
||||||
|
} else {
|
||||||
|
certs = append(certs, cert)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return certs
|
||||||
|
}
|
||||||
|
|
||||||
// GetSharedKey gets the decoded shared key.
|
// GetSharedKey gets the decoded shared key.
|
||||||
func (o *Options) GetSharedKey() ([]byte, error) {
|
func (o *Options) GetSharedKey() ([]byte, error) {
|
||||||
sharedKey := o.SharedKey
|
sharedKey := o.SharedKey
|
||||||
|
@ -1257,40 +1292,7 @@ func (o *Options) Checksum() uint64 {
|
||||||
return hashutil.MustHash(o)
|
return hashutil.MustHash(o)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (o *Options) indexCerts(ctx context.Context) certsIndex {
|
func (o *Options) applyExternalCerts(ctx context.Context, certsIndex *cryptutil.CertificatesIndex, certs []*config.Settings_Certificate) {
|
||||||
idx := make(certsIndex)
|
|
||||||
|
|
||||||
if o.CertFile != "" {
|
|
||||||
cert, err := cryptutil.ParsePEMCertificateFromFile(o.CertFile)
|
|
||||||
if err != nil {
|
|
||||||
log.Error(ctx).Err(err).Str("file", o.CertFile).Msg("parsing local cert: skipped")
|
|
||||||
} else {
|
|
||||||
idx.addCert(cert)
|
|
||||||
}
|
|
||||||
} else if o.Cert != "" {
|
|
||||||
if cert, err := cryptutil.ParsePEMCertificateFromBase64(o.Cert); err != nil {
|
|
||||||
log.Error(ctx).Err(err).Msg("parsing local cert: skipped")
|
|
||||||
} else {
|
|
||||||
idx.addCert(cert)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, c := range o.CertificateFiles {
|
|
||||||
cert, err := cryptutil.ParsePEMCertificateFromBase64(c.CertFile)
|
|
||||||
if err != nil {
|
|
||||||
cert, err = cryptutil.ParsePEMCertificateFromFile(c.CertFile)
|
|
||||||
}
|
|
||||||
if err != nil {
|
|
||||||
log.Error(ctx).Err(err).Msg("parsing local cert: skipped")
|
|
||||||
} else {
|
|
||||||
idx.addCert(cert)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return idx
|
|
||||||
}
|
|
||||||
|
|
||||||
func (o *Options) applyExternalCerts(ctx context.Context, certs []*config.Settings_Certificate) {
|
|
||||||
idx := o.indexCerts(ctx)
|
|
||||||
for _, c := range certs {
|
for _, c := range certs {
|
||||||
cfp := certificateFilePair{}
|
cfp := certificateFilePair{}
|
||||||
cfp.CertFile = base64.StdEncoding.EncodeToString(c.CertBytes)
|
cfp.CertFile = base64.StdEncoding.EncodeToString(c.CertBytes)
|
||||||
|
@ -1301,7 +1303,7 @@ func (o *Options) applyExternalCerts(ctx context.Context, certs []*config.Settin
|
||||||
log.Error(ctx).Err(err).Msg("parsing cert from databroker: skipped")
|
log.Error(ctx).Err(err).Msg("parsing cert from databroker: skipped")
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
if overlaps, name := idx.matchCert(cert); overlaps {
|
if overlaps, name := certsIndex.OverlapsWithExistingCertificate(cert); overlaps {
|
||||||
log.Error(ctx).Err(err).Str("domain", name).Msg("overlaps with local certs: skipped")
|
log.Error(ctx).Err(err).Str("domain", name).Msg("overlaps with local certs: skipped")
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
@ -1311,7 +1313,7 @@ func (o *Options) applyExternalCerts(ctx context.Context, certs []*config.Settin
|
||||||
}
|
}
|
||||||
|
|
||||||
// ApplySettings modifies the config options using the given protobuf settings.
|
// ApplySettings modifies the config options using the given protobuf settings.
|
||||||
func (o *Options) ApplySettings(ctx context.Context, settings *config.Settings) {
|
func (o *Options) ApplySettings(ctx context.Context, certsIndex *cryptutil.CertificatesIndex, settings *config.Settings) {
|
||||||
if settings == nil {
|
if settings == nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -1325,7 +1327,7 @@ func (o *Options) ApplySettings(ctx context.Context, settings *config.Settings)
|
||||||
set(&o.Addr, settings.Address)
|
set(&o.Addr, settings.Address)
|
||||||
set(&o.InsecureServer, settings.InsecureServer)
|
set(&o.InsecureServer, settings.InsecureServer)
|
||||||
set(&o.DNSLookupFamily, settings.DnsLookupFamily)
|
set(&o.DNSLookupFamily, settings.DnsLookupFamily)
|
||||||
o.applyExternalCerts(ctx, settings.GetCertificates())
|
o.applyExternalCerts(ctx, certsIndex, settings.GetCertificates())
|
||||||
set(&o.HTTPRedirectAddr, settings.HttpRedirectAddr)
|
set(&o.HTTPRedirectAddr, settings.HttpRedirectAddr)
|
||||||
setDuration(&o.ReadTimeout, settings.TimeoutRead)
|
setDuration(&o.ReadTimeout, settings.TimeoutRead)
|
||||||
setDuration(&o.WriteTimeout, settings.TimeoutWrite)
|
setDuration(&o.WriteTimeout, settings.TimeoutWrite)
|
||||||
|
|
|
@ -3,6 +3,7 @@ package config
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
|
"crypto/x509"
|
||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
"encoding/pem"
|
"encoding/pem"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
@ -735,13 +736,17 @@ func TestOptions_ApplySettings(t *testing.T) {
|
||||||
cert3, err := cryptutil.GenerateCertificate(nil, "not.example.com")
|
cert3, err := cryptutil.GenerateCertificate(nil, "not.example.com")
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
certsIndex := cryptutil.NewCertificatesIndex()
|
||||||
|
xc1, _ := x509.ParseCertificate(cert1.Certificate[0])
|
||||||
|
certsIndex.Add(xc1)
|
||||||
|
|
||||||
settings := &config.Settings{
|
settings := &config.Settings{
|
||||||
Certificates: []*config.Settings_Certificate{
|
Certificates: []*config.Settings_Certificate{
|
||||||
{CertBytes: encodeCert(cert2)},
|
{CertBytes: encodeCert(cert2)},
|
||||||
{CertBytes: encodeCert(cert3)},
|
{CertBytes: encodeCert(cert3)},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
options.ApplySettings(ctx, settings)
|
options.ApplySettings(ctx, certsIndex, settings)
|
||||||
assert.Len(t, options.CertificateFiles, 2, "should prevent adding duplicate certificates")
|
assert.Len(t, options.CertificateFiles, 2, "should prevent adding duplicate certificates")
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
2
go.mod
2
go.mod
|
@ -55,7 +55,7 @@ require (
|
||||||
github.com/rs/zerolog v1.29.1
|
github.com/rs/zerolog v1.29.1
|
||||||
github.com/shirou/gopsutil/v3 v3.23.4
|
github.com/shirou/gopsutil/v3 v3.23.4
|
||||||
github.com/spf13/viper v1.15.0
|
github.com/spf13/viper v1.15.0
|
||||||
github.com/stretchr/testify v1.8.3
|
github.com/stretchr/testify v1.8.4
|
||||||
github.com/tniswong/go.rfcx v0.0.0-20181019234604-07783c52761f
|
github.com/tniswong/go.rfcx v0.0.0-20181019234604-07783c52761f
|
||||||
github.com/volatiletech/null/v9 v9.0.0
|
github.com/volatiletech/null/v9 v9.0.0
|
||||||
github.com/yuin/gopher-lua v1.1.0
|
github.com/yuin/gopher-lua v1.1.0
|
||||||
|
|
4
go.sum
4
go.sum
|
@ -658,8 +658,8 @@ github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/
|
||||||
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
|
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
|
||||||
github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
|
github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
|
||||||
github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
|
github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
|
||||||
github.com/stretchr/testify v1.8.3 h1:RP3t2pwF7cMEbC1dqtB6poj3niw/9gnV4Cjg5oW5gtY=
|
github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk=
|
||||||
github.com/stretchr/testify v1.8.3/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
|
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
|
||||||
github.com/stvp/go-udp-testing v0.0.0-20201019212854-469649b16807/go.mod h1:7jxmlfBCDBXRzr0eAQJ48XC1hBu1np4CS5+cHEYfwpc=
|
github.com/stvp/go-udp-testing v0.0.0-20201019212854-469649b16807/go.mod h1:7jxmlfBCDBXRzr0eAQJ48XC1hBu1np4CS5+cHEYfwpc=
|
||||||
github.com/subosito/gotenv v1.4.2 h1:X1TuBLAMDFbaTAChgCBLu3DU3UPyELpnF2jjJ2cz/S8=
|
github.com/subosito/gotenv v1.4.2 h1:X1TuBLAMDFbaTAChgCBLu3DU3UPyELpnF2jjJ2cz/S8=
|
||||||
github.com/subosito/gotenv v1.4.2/go.mod h1:ayKnFf/c6rvx/2iiLrJUk1e6plDbT3edrFNGqEflhK0=
|
github.com/subosito/gotenv v1.4.2/go.mod h1:ayKnFf/c6rvx/2iiLrJUk1e6plDbT3edrFNGqEflhK0=
|
||||||
|
|
|
@ -13,6 +13,7 @@ import (
|
||||||
"github.com/pomerium/pomerium/internal/log"
|
"github.com/pomerium/pomerium/internal/log"
|
||||||
"github.com/pomerium/pomerium/internal/telemetry/metrics"
|
"github.com/pomerium/pomerium/internal/telemetry/metrics"
|
||||||
"github.com/pomerium/pomerium/internal/telemetry/trace"
|
"github.com/pomerium/pomerium/internal/telemetry/trace"
|
||||||
|
"github.com/pomerium/pomerium/pkg/cryptutil"
|
||||||
"github.com/pomerium/pomerium/pkg/grpc"
|
"github.com/pomerium/pomerium/pkg/grpc"
|
||||||
configpb "github.com/pomerium/pomerium/pkg/grpc/config"
|
configpb "github.com/pomerium/pomerium/pkg/grpc/config"
|
||||||
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
||||||
|
@ -98,11 +99,16 @@ func (src *ConfigSource) rebuild(ctx context.Context, firstTime firstTime) {
|
||||||
ids := maps.Keys(src.dbConfigs)
|
ids := maps.Keys(src.dbConfigs)
|
||||||
sort.Strings(ids)
|
sort.Strings(ids)
|
||||||
|
|
||||||
|
certsIndex := cryptutil.NewCertificatesIndex()
|
||||||
|
for _, cert := range cfg.Options.GetX509Certificates() {
|
||||||
|
certsIndex.Add(cert)
|
||||||
|
}
|
||||||
|
|
||||||
// add all the config policies to the list
|
// add all the config policies to the list
|
||||||
for _, id := range ids {
|
for _, id := range ids {
|
||||||
cfgpb := src.dbConfigs[id]
|
cfgpb := src.dbConfigs[id]
|
||||||
|
|
||||||
cfg.Options.ApplySettings(ctx, cfgpb.Settings)
|
cfg.Options.ApplySettings(ctx, certsIndex, cfgpb.Settings)
|
||||||
var errCount uint64
|
var errCount uint64
|
||||||
|
|
||||||
err := cfg.Options.Validate()
|
err := cfg.Options.Validate()
|
||||||
|
|
|
@ -2,22 +2,35 @@ package databroker
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"encoding/base64"
|
||||||
"net"
|
"net"
|
||||||
"net/url"
|
"net/url"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
"google.golang.org/grpc"
|
"google.golang.org/grpc"
|
||||||
"google.golang.org/protobuf/proto"
|
"google.golang.org/protobuf/proto"
|
||||||
|
|
||||||
"github.com/pomerium/pomerium/config"
|
"github.com/pomerium/pomerium/config"
|
||||||
|
"github.com/pomerium/pomerium/pkg/cryptutil"
|
||||||
configpb "github.com/pomerium/pomerium/pkg/grpc/config"
|
configpb "github.com/pomerium/pomerium/pkg/grpc/config"
|
||||||
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
||||||
"github.com/pomerium/pomerium/pkg/protoutil"
|
"github.com/pomerium/pomerium/pkg/protoutil"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestConfigSource(t *testing.T) {
|
func TestConfigSource(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
generateCert := func(name string) ([]byte, []byte) {
|
||||||
|
cert, err := cryptutil.GenerateCertificate(nil, name)
|
||||||
|
require.NoError(t, err)
|
||||||
|
certPEM, keyPEM, err := cryptutil.EncodeCertificate(cert)
|
||||||
|
require.NoError(t, err)
|
||||||
|
return certPEM, keyPEM
|
||||||
|
}
|
||||||
|
|
||||||
ctx, clearTimeout := context.WithTimeout(context.Background(), 50*time.Second)
|
ctx, clearTimeout := context.WithTimeout(context.Background(), 50*time.Second)
|
||||||
defer clearTimeout()
|
defer clearTimeout()
|
||||||
|
|
||||||
|
@ -45,6 +58,8 @@ func TestConfigSource(t *testing.T) {
|
||||||
{URL: *u},
|
{URL: *u},
|
||||||
}, AllowedUsers: []string{"foo@bar.com"},
|
}, AllowedUsers: []string{"foo@bar.com"},
|
||||||
})
|
})
|
||||||
|
certPEM, keyPEM := generateCert("*.example.com")
|
||||||
|
base.Cert, base.Key = base64.StdEncoding.EncodeToString(certPEM), base64.StdEncoding.EncodeToString(keyPEM)
|
||||||
|
|
||||||
baseSource := config.NewStaticSource(&config.Config{
|
baseSource := config.NewStaticSource(&config.Config{
|
||||||
OutboundPort: outboundPort,
|
OutboundPort: outboundPort,
|
||||||
|
@ -55,13 +70,17 @@ func TestConfigSource(t *testing.T) {
|
||||||
})
|
})
|
||||||
cfgs <- src.GetConfig()
|
cfgs <- src.GetConfig()
|
||||||
|
|
||||||
data := protoutil.NewAny(&configpb.Config{
|
route := &configpb.Route{
|
||||||
Name: "config",
|
|
||||||
Routes: []*configpb.Route{
|
|
||||||
{
|
|
||||||
From: "https://from.example.com",
|
From: "https://from.example.com",
|
||||||
To: []string{"https://to.example.com"},
|
To: []string{"https://to.example.com"},
|
||||||
},
|
}
|
||||||
|
cert := &configpb.Settings_Certificate{}
|
||||||
|
cert.CertBytes, cert.KeyBytes = generateCert("*.example.com")
|
||||||
|
data := protoutil.NewAny(&configpb.Config{
|
||||||
|
Name: "config",
|
||||||
|
Routes: []*configpb.Route{route},
|
||||||
|
Settings: &configpb.Settings{
|
||||||
|
Certificates: []*configpb.Settings_Certificate{cert},
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
_, _ = dataBrokerServer.Put(ctx, &databroker.PutRequest{
|
_, _ = dataBrokerServer.Put(ctx, &databroker.PutRequest{
|
||||||
|
@ -86,6 +105,7 @@ func TestConfigSource(t *testing.T) {
|
||||||
return
|
return
|
||||||
case cfg := <-cfgs:
|
case cfg := <-cfgs:
|
||||||
assert.Len(t, cfg.Options.AdditionalPolicies, 1)
|
assert.Len(t, cfg.Options.AdditionalPolicies, 1)
|
||||||
|
assert.Len(t, cfg.Options.CertificateFiles, 0, "ignores overlapping certificate")
|
||||||
}
|
}
|
||||||
|
|
||||||
baseSource.SetConfig(ctx, &config.Config{
|
baseSource.SetConfig(ctx, &config.Config{
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
package config
|
package cryptutil
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"crypto/x509"
|
"crypto/x509"
|
||||||
|
@ -6,13 +6,70 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
type certUsage byte
|
type certUsage byte
|
||||||
type certsIndex map[string]map[string]certUsage
|
|
||||||
|
|
||||||
const (
|
const (
|
||||||
certUsageServerAuth = certUsage(1 << iota)
|
certUsageServerAuth = certUsage(1 << iota)
|
||||||
certUsageClientAuth
|
certUsageClientAuth
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// A CertificatesIndex indexes certificates to determine if there is overlap between them.
|
||||||
|
type CertificatesIndex struct {
|
||||||
|
index map[string]map[string]certUsage
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewCertificatesIndex creates a new CertificatesIndex.
|
||||||
|
func NewCertificatesIndex() *CertificatesIndex {
|
||||||
|
return &CertificatesIndex{
|
||||||
|
index: make(map[string]map[string]certUsage),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add adds a certificate to the index.
|
||||||
|
func (c *CertificatesIndex) Add(cert *x509.Certificate) {
|
||||||
|
usage := getCertUsage(cert)
|
||||||
|
for _, name := range cert.DNSNames {
|
||||||
|
c.add(name, usage)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// OverlapsWithExistingCertificate returns true if the certificate overlaps with an existing certificate.
|
||||||
|
func (c *CertificatesIndex) OverlapsWithExistingCertificate(cert *x509.Certificate) (bool, string) {
|
||||||
|
usage := getCertUsage(cert)
|
||||||
|
for _, name := range cert.DNSNames {
|
||||||
|
if c.match(name, usage) {
|
||||||
|
return true, name
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false, ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *CertificatesIndex) add(name string, usage certUsage) {
|
||||||
|
prefix, suffix := splitDomainName(name)
|
||||||
|
names := c.index[suffix]
|
||||||
|
if names == nil {
|
||||||
|
names = make(map[string]certUsage)
|
||||||
|
c.index[suffix] = names
|
||||||
|
}
|
||||||
|
names[prefix] = usage
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *CertificatesIndex) match(name string, usage certUsage) bool {
|
||||||
|
prefix, suffix := splitDomainName(name)
|
||||||
|
names := c.index[suffix]
|
||||||
|
if names == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if prefix != "*" {
|
||||||
|
return names["*"]&usage != 0 || names[prefix]&usage != 0
|
||||||
|
}
|
||||||
|
for _, u := range names {
|
||||||
|
if u&usage != 0 {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
func splitDomainName(name string) (prefix, suffix string) {
|
func splitDomainName(name string) (prefix, suffix string) {
|
||||||
dot := strings.IndexRune(name, '.')
|
dot := strings.IndexRune(name, '.')
|
||||||
if dot < 0 {
|
if dot < 0 {
|
||||||
|
@ -33,47 +90,3 @@ func getCertUsage(cert *x509.Certificate) certUsage {
|
||||||
}
|
}
|
||||||
return usage
|
return usage
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c certsIndex) addCert(cert *x509.Certificate) {
|
|
||||||
usage := getCertUsage(cert)
|
|
||||||
for _, name := range cert.DNSNames {
|
|
||||||
c.add(name, usage)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c certsIndex) matchCert(cert *x509.Certificate) (bool, string) {
|
|
||||||
usage := getCertUsage(cert)
|
|
||||||
for _, name := range cert.DNSNames {
|
|
||||||
if c.match(name, usage) {
|
|
||||||
return true, name
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return false, ""
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c certsIndex) add(name string, usage certUsage) {
|
|
||||||
prefix, suffix := splitDomainName(name)
|
|
||||||
names := c[suffix]
|
|
||||||
if names == nil {
|
|
||||||
names = make(map[string]certUsage)
|
|
||||||
c[suffix] = names
|
|
||||||
}
|
|
||||||
names[prefix] = usage
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c certsIndex) match(name string, usage certUsage) bool {
|
|
||||||
prefix, suffix := splitDomainName(name)
|
|
||||||
names := c[suffix]
|
|
||||||
if names == nil {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
if prefix != "*" {
|
|
||||||
return names["*"]&usage != 0 || names[prefix]&usage != 0
|
|
||||||
}
|
|
||||||
for _, u := range names {
|
|
||||||
if u&usage != 0 {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
}
|
|
75
pkg/cryptutil/certificates_index_test.go
Normal file
75
pkg/cryptutil/certificates_index_test.go
Normal file
|
@ -0,0 +1,75 @@
|
||||||
|
package cryptutil_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/x509"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
|
"github.com/pomerium/pomerium/pkg/cryptutil"
|
||||||
|
"github.com/pomerium/pomerium/pkg/derivecert"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestCertificatesIndex(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
ca, err := derivecert.NewCA(cryptutil.NewKey())
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
mkClientCert := func(domains []string) *x509.Certificate {
|
||||||
|
pem, err := ca.NewServerCert(domains, func(c *x509.Certificate) {
|
||||||
|
c.ExtKeyUsage = []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth}
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
_, cert, err := pem.KeyCert()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
return cert
|
||||||
|
}
|
||||||
|
mkServerCert := func(domains []string) *x509.Certificate {
|
||||||
|
pem, err := ca.NewServerCert(domains)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
_, cert, err := pem.KeyCert()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
return cert
|
||||||
|
}
|
||||||
|
|
||||||
|
testCases := []struct {
|
||||||
|
names []string
|
||||||
|
test string
|
||||||
|
match bool
|
||||||
|
}{
|
||||||
|
{[]string{"aa.bb.cc", "cc.bb.aa"}, "aa.bb.c", false},
|
||||||
|
{[]string{"aa.bb.cc"}, "aa.bb.cc", true},
|
||||||
|
{[]string{"*.bb.cc"}, "aa.bb.cc", true},
|
||||||
|
{[]string{"a1.bb.cc", "a2.bb.cc"}, "*.bb.cc", true},
|
||||||
|
{[]string{"*.bb.cc", "a2.bb.cc"}, "*.bb.cc", true},
|
||||||
|
{[]string{"*.aa.bb.cc"}, "*.bb.cc", false},
|
||||||
|
{[]string{"*.aa.bb.cc"}, "aa.bb.cc", false},
|
||||||
|
{[]string{"bb.cc"}, "*.bb.cc", false},
|
||||||
|
}
|
||||||
|
t.Run("match mix mode", func(t *testing.T) {
|
||||||
|
for _, tc := range testCases {
|
||||||
|
idx := cryptutil.NewCertificatesIndex()
|
||||||
|
idx.Add(mkServerCert(tc.names))
|
||||||
|
|
||||||
|
cert := mkServerCert([]string{tc.test})
|
||||||
|
overlaps, _ := idx.OverlapsWithExistingCertificate(cert)
|
||||||
|
assert.Equalf(t, tc.match, overlaps, "%v", tc)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
t.Run("different cert usages never match", func(t *testing.T) {
|
||||||
|
for _, tc := range testCases {
|
||||||
|
idx := cryptutil.NewCertificatesIndex()
|
||||||
|
idx.Add(mkServerCert(tc.names))
|
||||||
|
|
||||||
|
cert := mkClientCert([]string{tc.test})
|
||||||
|
overlaps, _ := idx.OverlapsWithExistingCertificate(cert)
|
||||||
|
assert.Equalf(t, false, overlaps, "%v", tc)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
Loading…
Add table
Add a link
Reference in a new issue