mirror of
https://github.com/pomerium/pomerium.git
synced 2025-04-29 10:26:29 +02:00
options: fix overlapping certificate test (#3492)
This commit is contained in:
parent
250eff3e0a
commit
1afbc6e9c4
3 changed files with 68 additions and 17 deletions
|
@ -1237,7 +1237,7 @@ func (o *Options) Checksum() uint64 {
|
|||
return hashutil.MustHash(o)
|
||||
}
|
||||
|
||||
func (o Options) indexCerts(ctx context.Context) certsIndex {
|
||||
func (o *Options) indexCerts(ctx context.Context) certsIndex {
|
||||
idx := make(certsIndex)
|
||||
|
||||
if o.CertFile != "" {
|
||||
|
@ -1248,9 +1248,7 @@ func (o Options) indexCerts(ctx context.Context) certsIndex {
|
|||
idx.addCert(cert)
|
||||
}
|
||||
} else if o.Cert != "" {
|
||||
if data, err := base64.StdEncoding.DecodeString(o.Cert); err != nil {
|
||||
log.Error(ctx).Err(err).Msg("bad base64 for local cert: skipped")
|
||||
} else if cert, err := cryptutil.ParsePEMCertificate(data); err != nil {
|
||||
if cert, err := cryptutil.ParsePEMCertificateFromBase64(o.Cert); err != nil {
|
||||
log.Error(ctx).Err(err).Msg("parsing local cert: skipped")
|
||||
} else {
|
||||
idx.addCert(cert)
|
||||
|
@ -1258,9 +1256,12 @@ func (o Options) indexCerts(ctx context.Context) certsIndex {
|
|||
}
|
||||
|
||||
for _, c := range o.CertificateFiles {
|
||||
cert, err := cryptutil.ParsePEMCertificateFromFile(c.CertFile)
|
||||
cert, err := cryptutil.ParsePEMCertificateFromBase64(c.CertFile)
|
||||
if err != nil {
|
||||
log.Error(ctx).Err(err).Str("file", c.CertFile).Msg("parsing local cert: skipped")
|
||||
cert, err = cryptutil.ParsePEMCertificateFromFile(c.CertFile)
|
||||
}
|
||||
if err != nil {
|
||||
log.Error(ctx).Err(err).Msg("parsing local cert: skipped")
|
||||
} else {
|
||||
idx.addCert(cert)
|
||||
}
|
||||
|
@ -1271,15 +1272,6 @@ func (o Options) indexCerts(ctx context.Context) certsIndex {
|
|||
func (o *Options) applyExternalCerts(ctx context.Context, certs []*config.Settings_Certificate) {
|
||||
idx := o.indexCerts(ctx)
|
||||
for _, c := range certs {
|
||||
cert, err := cryptutil.ParsePEMCertificate(c.CertBytes)
|
||||
if err != nil {
|
||||
log.Error(ctx).Err(err).Msg("parsing cert from databroker: skipped")
|
||||
continue
|
||||
}
|
||||
if overlaps, name := idx.matchCert(cert); overlaps {
|
||||
log.Error(ctx).Err(err).Str("domain", name).Msg("overlaps with local certs: skipped")
|
||||
continue
|
||||
}
|
||||
cfp := certificateFilePair{
|
||||
CertFile: c.CertFile,
|
||||
KeyFile: c.KeyFile,
|
||||
|
@ -1290,6 +1282,20 @@ func (o *Options) applyExternalCerts(ctx context.Context, certs []*config.Settin
|
|||
if cfp.KeyFile == "" {
|
||||
cfp.KeyFile = base64.StdEncoding.EncodeToString(c.KeyBytes)
|
||||
}
|
||||
|
||||
cert, err := cryptutil.ParsePEMCertificateFromBase64(cfp.CertFile)
|
||||
if err != nil {
|
||||
cert, err = cryptutil.ParsePEMCertificateFromFile(cfp.CertFile)
|
||||
}
|
||||
if err != nil {
|
||||
log.Error(ctx).Err(err).Msg("parsing cert from databroker: skipped")
|
||||
continue
|
||||
}
|
||||
if overlaps, name := idx.matchCert(cert); overlaps {
|
||||
log.Error(ctx).Err(err).Str("domain", name).Msg("overlaps with local certs: skipped")
|
||||
continue
|
||||
}
|
||||
|
||||
o.CertificateFiles = append(o.CertificateFiles, cfp)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -2,7 +2,9 @@ package config
|
|||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"encoding/base64"
|
||||
"encoding/pem"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"os"
|
||||
|
@ -15,6 +17,9 @@ import (
|
|||
"github.com/spf13/viper"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/pomerium/pomerium/pkg/cryptutil"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/config"
|
||||
)
|
||||
|
||||
var cmpOptIgnoreUnexported = cmpopts.IgnoreUnexported(Options{}, Policy{})
|
||||
|
@ -734,6 +739,37 @@ func TestOptions_GetAllRouteableHTTPDomains(t *testing.T) {
|
|||
}, domains)
|
||||
}
|
||||
|
||||
func TestOptions_ApplySettings(t *testing.T) {
|
||||
ctx, clearTimeout := context.WithTimeout(context.Background(), time.Second)
|
||||
defer clearTimeout()
|
||||
|
||||
t.Run("certificates", func(t *testing.T) {
|
||||
options := NewDefaultOptions()
|
||||
cert1, err := cryptutil.GenerateSelfSignedCertificate("example.com")
|
||||
require.NoError(t, err)
|
||||
options.CertificateFiles = append(options.CertificateFiles, certificateFilePair{
|
||||
CertFile: base64.StdEncoding.EncodeToString(encodeCert(cert1)),
|
||||
})
|
||||
cert2, err := cryptutil.GenerateSelfSignedCertificate("example.com")
|
||||
require.NoError(t, err)
|
||||
cert3, err := cryptutil.GenerateSelfSignedCertificate("not.example.com")
|
||||
require.NoError(t, err)
|
||||
|
||||
settings := &config.Settings{
|
||||
Certificates: []*config.Settings_Certificate{
|
||||
{CertBytes: encodeCert(cert2)},
|
||||
{CertBytes: encodeCert(cert3)},
|
||||
},
|
||||
}
|
||||
options.ApplySettings(ctx, settings)
|
||||
assert.Len(t, options.CertificateFiles, 2, "should prevent adding duplicate certificates")
|
||||
})
|
||||
}
|
||||
|
||||
func encodeCert(cert *tls.Certificate) []byte {
|
||||
return pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: cert.Certificate[0]})
|
||||
}
|
||||
|
||||
func mustParseWeightedURLs(t *testing.T, urls ...string) []WeightedURL {
|
||||
wu, err := ParseWeightedUrls(urls...)
|
||||
require.NoError(t, err)
|
||||
|
|
|
@ -219,7 +219,7 @@ func GenerateSelfSignedCertificate(domain string, configure ...func(*x509.Certif
|
|||
return &cert, nil
|
||||
}
|
||||
|
||||
// ParsePEMCertificate parses PEM encoded certificate block
|
||||
// ParsePEMCertificate parses a PEM encoded certificate block.
|
||||
func ParsePEMCertificate(raw []byte) (*x509.Certificate, error) {
|
||||
data := raw
|
||||
for {
|
||||
|
@ -242,7 +242,16 @@ func ParsePEMCertificate(raw []byte) (*x509.Certificate, error) {
|
|||
return nil, fmt.Errorf("no certificate block found")
|
||||
}
|
||||
|
||||
// ParsePEMCertificateFromFile decodes PEM certificate from file
|
||||
// ParsePEMCertificateFromBase64 parses a PEM encoded certificate block from a base64 encoded string.
|
||||
func ParsePEMCertificateFromBase64(encoded string) (*x509.Certificate, error) {
|
||||
raw, err := base64.StdEncoding.DecodeString(encoded)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return ParsePEMCertificate(raw)
|
||||
}
|
||||
|
||||
// ParsePEMCertificateFromFile decodes a PEM certificate from a file.
|
||||
func ParsePEMCertificateFromFile(file string) (*x509.Certificate, error) {
|
||||
fd, err := os.Open(file)
|
||||
if err != nil {
|
||||
|
|
Loading…
Add table
Reference in a new issue