diff --git a/config/options.go b/config/options.go index fc707e629..e65e3d035 100644 --- a/config/options.go +++ b/config/options.go @@ -1237,7 +1237,7 @@ func (o *Options) Checksum() uint64 { return hashutil.MustHash(o) } -func (o Options) indexCerts(ctx context.Context) certsIndex { +func (o *Options) indexCerts(ctx context.Context) certsIndex { idx := make(certsIndex) if o.CertFile != "" { @@ -1248,9 +1248,7 @@ func (o Options) indexCerts(ctx context.Context) certsIndex { idx.addCert(cert) } } else if o.Cert != "" { - if data, err := base64.StdEncoding.DecodeString(o.Cert); err != nil { - log.Error(ctx).Err(err).Msg("bad base64 for local cert: skipped") - } else if cert, err := cryptutil.ParsePEMCertificate(data); err != nil { + if cert, err := cryptutil.ParsePEMCertificateFromBase64(o.Cert); err != nil { log.Error(ctx).Err(err).Msg("parsing local cert: skipped") } else { idx.addCert(cert) @@ -1258,9 +1256,12 @@ func (o Options) indexCerts(ctx context.Context) certsIndex { } for _, c := range o.CertificateFiles { - cert, err := cryptutil.ParsePEMCertificateFromFile(c.CertFile) + cert, err := cryptutil.ParsePEMCertificateFromBase64(c.CertFile) if err != nil { - log.Error(ctx).Err(err).Str("file", c.CertFile).Msg("parsing local cert: skipped") + cert, err = cryptutil.ParsePEMCertificateFromFile(c.CertFile) + } + if err != nil { + log.Error(ctx).Err(err).Msg("parsing local cert: skipped") } else { idx.addCert(cert) } @@ -1271,15 +1272,6 @@ func (o Options) indexCerts(ctx context.Context) certsIndex { func (o *Options) applyExternalCerts(ctx context.Context, certs []*config.Settings_Certificate) { idx := o.indexCerts(ctx) for _, c := range certs { - cert, err := cryptutil.ParsePEMCertificate(c.CertBytes) - if err != nil { - log.Error(ctx).Err(err).Msg("parsing cert from databroker: skipped") - continue - } - if overlaps, name := idx.matchCert(cert); overlaps { - log.Error(ctx).Err(err).Str("domain", name).Msg("overlaps with local certs: skipped") - continue - } cfp := certificateFilePair{ CertFile: c.CertFile, KeyFile: c.KeyFile, @@ -1290,6 +1282,20 @@ func (o *Options) applyExternalCerts(ctx context.Context, certs []*config.Settin if cfp.KeyFile == "" { cfp.KeyFile = base64.StdEncoding.EncodeToString(c.KeyBytes) } + + cert, err := cryptutil.ParsePEMCertificateFromBase64(cfp.CertFile) + if err != nil { + cert, err = cryptutil.ParsePEMCertificateFromFile(cfp.CertFile) + } + if err != nil { + log.Error(ctx).Err(err).Msg("parsing cert from databroker: skipped") + continue + } + if overlaps, name := idx.matchCert(cert); overlaps { + log.Error(ctx).Err(err).Str("domain", name).Msg("overlaps with local certs: skipped") + continue + } + o.CertificateFiles = append(o.CertificateFiles, cfp) } } diff --git a/config/options_test.go b/config/options_test.go index 6091856e3..4cdd3be3c 100644 --- a/config/options_test.go +++ b/config/options_test.go @@ -2,7 +2,9 @@ package config import ( "context" + "crypto/tls" "encoding/base64" + "encoding/pem" "fmt" "net/url" "os" @@ -15,6 +17,9 @@ import ( "github.com/spf13/viper" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + + "github.com/pomerium/pomerium/pkg/cryptutil" + "github.com/pomerium/pomerium/pkg/grpc/config" ) var cmpOptIgnoreUnexported = cmpopts.IgnoreUnexported(Options{}, Policy{}) @@ -734,6 +739,37 @@ func TestOptions_GetAllRouteableHTTPDomains(t *testing.T) { }, domains) } +func TestOptions_ApplySettings(t *testing.T) { + ctx, clearTimeout := context.WithTimeout(context.Background(), time.Second) + defer clearTimeout() + + t.Run("certificates", func(t *testing.T) { + options := NewDefaultOptions() + cert1, err := cryptutil.GenerateSelfSignedCertificate("example.com") + require.NoError(t, err) + options.CertificateFiles = append(options.CertificateFiles, certificateFilePair{ + CertFile: base64.StdEncoding.EncodeToString(encodeCert(cert1)), + }) + cert2, err := cryptutil.GenerateSelfSignedCertificate("example.com") + require.NoError(t, err) + cert3, err := cryptutil.GenerateSelfSignedCertificate("not.example.com") + require.NoError(t, err) + + settings := &config.Settings{ + Certificates: []*config.Settings_Certificate{ + {CertBytes: encodeCert(cert2)}, + {CertBytes: encodeCert(cert3)}, + }, + } + options.ApplySettings(ctx, settings) + assert.Len(t, options.CertificateFiles, 2, "should prevent adding duplicate certificates") + }) +} + +func encodeCert(cert *tls.Certificate) []byte { + return pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: cert.Certificate[0]}) +} + func mustParseWeightedURLs(t *testing.T, urls ...string) []WeightedURL { wu, err := ParseWeightedUrls(urls...) require.NoError(t, err) diff --git a/pkg/cryptutil/certificates.go b/pkg/cryptutil/certificates.go index 841de58b2..e383a8f25 100644 --- a/pkg/cryptutil/certificates.go +++ b/pkg/cryptutil/certificates.go @@ -219,7 +219,7 @@ func GenerateSelfSignedCertificate(domain string, configure ...func(*x509.Certif return &cert, nil } -// ParsePEMCertificate parses PEM encoded certificate block +// ParsePEMCertificate parses a PEM encoded certificate block. func ParsePEMCertificate(raw []byte) (*x509.Certificate, error) { data := raw for { @@ -242,7 +242,16 @@ func ParsePEMCertificate(raw []byte) (*x509.Certificate, error) { return nil, fmt.Errorf("no certificate block found") } -// ParsePEMCertificateFromFile decodes PEM certificate from file +// ParsePEMCertificateFromBase64 parses a PEM encoded certificate block from a base64 encoded string. +func ParsePEMCertificateFromBase64(encoded string) (*x509.Certificate, error) { + raw, err := base64.StdEncoding.DecodeString(encoded) + if err != nil { + return nil, err + } + return ParsePEMCertificate(raw) +} + +// ParsePEMCertificateFromFile decodes a PEM certificate from a file. func ParsePEMCertificateFromFile(file string) (*x509.Certificate, error) { fd, err := os.Open(file) if err != nil {