diff --git a/config/config.go b/config/config.go index 38d220253..b2dad2b5f 100644 --- a/config/config.go +++ b/config/config.go @@ -189,8 +189,13 @@ func (cfg *Config) GetCertificateForServerName(serverName string) (*tls.Certific return &cert, nil } + sharedKey, err := cfg.Options.GetSharedKey() + if err != nil { + return nil, fmt.Errorf("failed to generate cert, invalid shared key: %w", err) + } + // finally fall back to a generated, self-signed certificate - return cryptutil.GenerateSelfSignedCertificate(serverName) + return cryptutil.GenerateCertificate(sharedKey, serverName) } // WillHaveCertificateForServerName returns true if there will be a certificate for the given server name. diff --git a/config/config_test.go b/config/config_test.go index f4fdad40e..7073a13dc 100644 --- a/config/config_test.go +++ b/config/config_test.go @@ -11,7 +11,7 @@ import ( func TestConfig_GetCertificateForServerName(t *testing.T) { gen := func(t *testing.T, serverName string) *tls.Certificate { - cert, err := cryptutil.GenerateSelfSignedCertificate(serverName) + cert, err := cryptutil.GenerateCertificate(nil, serverName) if !assert.NoError(t, err, "error generating certificate for: %s", serverName) { t.FailNow() } diff --git a/config/envoyconfig/listeners_test.go b/config/envoyconfig/listeners_test.go index a8b67e909..3f67f0818 100644 --- a/config/envoyconfig/listeners_test.go +++ b/config/envoyconfig/listeners_test.go @@ -211,7 +211,7 @@ func Test_buildDownstreamTLSContext(t *testing.T) { } func Test_getAllDomains(t *testing.T) { - cert, err := cryptutil.GenerateSelfSignedCertificate("*.unknown.example.com") + cert, err := cryptutil.GenerateCertificate(nil, "*.unknown.example.com") require.NoError(t, err) certPEM, keyPEM, err := cryptutil.EncodeCertificate(cert) require.NoError(t, err) diff --git a/config/envoyconfig/tls_test.go b/config/envoyconfig/tls_test.go index a266e5e9b..c41f58696 100644 --- a/config/envoyconfig/tls_test.go +++ b/config/envoyconfig/tls_test.go @@ -46,7 +46,7 @@ func TestBuildSubjectNameIndication(t *testing.T) { } func TestValidateCertificate(t *testing.T) { - cert, err := cryptutil.GenerateSelfSignedCertificate("example.com", func(tpl *x509.Certificate) { + cert, err := cryptutil.GenerateCertificate(nil, "example.com", func(tpl *x509.Certificate) { // set the must staple flag on the cert tpl.ExtraExtensions = append(tpl.ExtraExtensions, pkix.Extension{ Id: oidMustStaple, diff --git a/config/options_test.go b/config/options_test.go index 1c2c350bf..70b627d60 100644 --- a/config/options_test.go +++ b/config/options_test.go @@ -722,14 +722,14 @@ func TestOptions_ApplySettings(t *testing.T) { t.Run("certificates", func(t *testing.T) { options := NewDefaultOptions() - cert1, err := cryptutil.GenerateSelfSignedCertificate("example.com") + cert1, err := cryptutil.GenerateCertificate(nil, "example.com") require.NoError(t, err) options.CertificateFiles = append(options.CertificateFiles, certificateFilePair{ CertFile: base64.StdEncoding.EncodeToString(encodeCert(cert1)), }) - cert2, err := cryptutil.GenerateSelfSignedCertificate("example.com") + cert2, err := cryptutil.GenerateCertificate(nil, "example.com") require.NoError(t, err) - cert3, err := cryptutil.GenerateSelfSignedCertificate("not.example.com") + cert3, err := cryptutil.GenerateCertificate(nil, "not.example.com") require.NoError(t, err) settings := &config.Settings{ diff --git a/pkg/cryptutil/certificates.go b/pkg/cryptutil/certificates.go index d35fb07f6..af0c53407 100644 --- a/pkg/cryptutil/certificates.go +++ b/pkg/cryptutil/certificates.go @@ -2,8 +2,6 @@ package cryptutil import ( "crypto/ecdsa" - "crypto/rand" - "crypto/rsa" "crypto/tls" "crypto/x509" "crypto/x509/pkix" @@ -12,10 +10,9 @@ import ( "errors" "fmt" "io" - "math/big" - "net" "os" - "time" + + "github.com/pomerium/pomerium/pkg/derivecert" ) const ( @@ -160,63 +157,24 @@ func EncodePrivateKey(key *ecdsa.PrivateKey) ([]byte, error) { return pem.EncodeToMemory(keyBlock), nil } -// GenerateSelfSignedCertificate generates a self-signed TLS certificate. -// -// mostly copied from https://golang.org/src/crypto/tls/generate_cert.go -func GenerateSelfSignedCertificate(domain string, configure ...func(*x509.Certificate)) (*tls.Certificate, error) { - privateKey, err := rsa.GenerateKey(rand.Reader, 2048) +// GenerateCertificate generates a TLS certificate derived from a shared key. +func GenerateCertificate(sharedKey []byte, domain string, configure ...func(*x509.Certificate)) (*tls.Certificate, error) { + ca, err := derivecert.NewCA(sharedKey) if err != nil { - return nil, fmt.Errorf("failed to generate private key: %w", err) + return nil, fmt.Errorf("cryptutil: failed to generate certificate, error deriving CA: %w", err) } - serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128) - serialNumber, err := rand.Int(rand.Reader, serialNumberLimit) + pem, err := ca.NewServerCert([]string{domain}, configure...) if err != nil { - return nil, fmt.Errorf("failed to generate serial number: %w", err) + return nil, fmt.Errorf("cryptutil: failed to generate certificate, error creating server certificate: %w", err) } - template := &x509.Certificate{ - SerialNumber: serialNumber, - Subject: pkix.Name{ - Organization: []string{"Pomerium"}, - }, - NotBefore: time.Now().Add(-time.Minute * 10), - NotAfter: time.Now().Add(time.Hour * 24 * 365), - KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature, - ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, - BasicConstraintsValid: true, - } - if ip := net.ParseIP(domain); ip != nil { - template.IPAddresses = append(template.IPAddresses, ip) - } else { - template.DNSNames = append(template.DNSNames, domain) - } - for _, f := range configure { - f(template) - } - - publicKeyBytes, err := x509.CreateCertificate(rand.Reader, - template, template, - privateKey.Public(), privateKey, - ) + tlsCert, err := pem.TLS() if err != nil { - return nil, fmt.Errorf("failed to create certificate: %w", err) + return nil, fmt.Errorf("cryptutil: failed to generate certificate, error converting server certificate to TLS certificate: %w", err) } - privateKeyBytes, err := x509.MarshalPKCS8PrivateKey(privateKey) - if err != nil { - return nil, fmt.Errorf("failed to marshal private key: %w", err) - } - - cert, err := tls.X509KeyPair( - pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: publicKeyBytes}), - pem.EncodeToMemory(&pem.Block{Type: "PRIVATE KEY", Bytes: privateKeyBytes}), - ) - if err != nil { - return nil, fmt.Errorf("failed to convert x509 bytes into tls certificate: %w", err) - } - - return &cert, nil + return &tlsCert, nil } // EncodeCertificate encodes a TLS certificate into PEM compatible byte slices. diff --git a/pkg/cryptutil/tls_test.go b/pkg/cryptutil/tls_test.go index 36a86e0c3..0ca167f1a 100644 --- a/pkg/cryptutil/tls_test.go +++ b/pkg/cryptutil/tls_test.go @@ -8,7 +8,7 @@ import ( ) func TestGetCertificateServerNames(t *testing.T) { - cert, err := GenerateSelfSignedCertificate("www.example.com") + cert, err := GenerateCertificate(nil, "www.example.com") require.NoError(t, err) assert.Equal(t, []string{"www.example.com"}, GetCertificateServerNames(cert)) } diff --git a/pkg/derivecert/ca.go b/pkg/derivecert/ca.go index 9f1128a2e..257098eab 100644 --- a/pkg/derivecert/ca.go +++ b/pkg/derivecert/ca.go @@ -83,7 +83,7 @@ func CAFromPEM(p PEM) (*CA, string, error) { } // NewServerCert generates certificate for the given domain name(s) -func (ca *CA) NewServerCert(domains []string) (*PEM, error) { +func (ca *CA) NewServerCert(domains []string, configure ...func(*x509.Certificate)) (*PEM, error) { key, err := deriveKey(newReader(readerTypeServerPrivateKey, ca.psk, domains...)) if err != nil { return nil, fmt.Errorf("derive key: %w", err) @@ -93,6 +93,9 @@ func (ca *CA) NewServerCert(domains []string) (*PEM, error) { if err != nil { return nil, fmt.Errorf("cert template: %w", err) } + for _, f := range configure { + f(tmpl) + } cert, err := x509.CreateCertificate( newReader(readerTypeServerCertificate, ca.psk, domains...),