config: generate derived certificates instead of self-signed certificates (#3860)

This commit is contained in:
Caleb Doxsey 2023-01-06 12:50:40 -07:00 committed by GitHub
parent 488bcd6f72
commit 3f1a87727f
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 151 additions and 84 deletions

View file

@ -3,12 +3,16 @@ package config
import ( import (
"bytes" "bytes"
"crypto/tls" "crypto/tls"
"crypto/x509"
"encoding/base64" "encoding/base64"
"fmt"
"github.com/pomerium/pomerium/internal/fileutil" "github.com/pomerium/pomerium/internal/fileutil"
"github.com/pomerium/pomerium/internal/hashutil" "github.com/pomerium/pomerium/internal/hashutil"
"github.com/pomerium/pomerium/internal/log"
"github.com/pomerium/pomerium/internal/telemetry/metrics" "github.com/pomerium/pomerium/internal/telemetry/metrics"
"github.com/pomerium/pomerium/pkg/cryptutil" "github.com/pomerium/pomerium/pkg/cryptutil"
"github.com/pomerium/pomerium/pkg/derivecert"
) )
// MetricsScrapeEndpoint defines additional metrics endpoints that would be scraped and exposed by pomerium // MetricsScrapeEndpoint defines additional metrics endpoints that would be scraped and exposed by pomerium
@ -130,7 +134,7 @@ func (cfg *Config) AllocatePorts(ports [6]string) {
// GetTLSClientConfig returns TLS configuration that accounts for additional CA entries // GetTLSClientConfig returns TLS configuration that accounts for additional CA entries
func (cfg *Config) GetTLSClientConfig() (*tls.Config, error) { func (cfg *Config) GetTLSClientConfig() (*tls.Config, error) {
roots, err := cryptutil.GetCertPool(cfg.Options.CA, cfg.Options.CAFile) roots, err := cfg.GetCertificatePool()
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -139,3 +143,79 @@ func (cfg *Config) GetTLSClientConfig() (*tls.Config, error) {
MinVersion: tls.VersionTLS12, MinVersion: tls.VersionTLS12,
}, nil }, nil
} }
// GetCertificateForServerName gets the certificate for the server name. If no certificate is found and there
// is a derived CA one will be generated using that CA. If no derived CA is defined a self-signed certificate
// will be generated.
func (cfg *Config) GetCertificateForServerName(serverName string) (*tls.Certificate, error) {
certificates, err := cfg.AllCertificates()
if err != nil {
return nil, err
}
// first try a direct name match
for i := range certificates {
if cryptutil.MatchesServerName(&certificates[i], serverName) {
return &certificates[i], nil
}
}
log.WarnNoTLSCertificate(serverName)
if cfg.Options.DeriveInternalDomainCert != nil {
sharedKey, err := cfg.Options.GetSharedKey()
if err != nil {
return nil, fmt.Errorf("failed to generate cert, invalid shared key: %w", err)
}
ca, err := derivecert.NewCA(sharedKey)
if err != nil {
return nil, fmt.Errorf("failed to generate cert, invalid derived CA: %w", err)
}
pem, err := ca.NewServerCert([]string{serverName})
if err != nil {
return nil, fmt.Errorf("failed to generate cert, error creating server certificate: %w", err)
}
cert, err := pem.TLS()
if err != nil {
return nil, fmt.Errorf("failed to generate cert, error converting generated certificate into TLS certificate: %w", err)
}
return &cert, nil
}
// finally fall back to a generated, self-signed certificate
return cryptutil.GenerateSelfSignedCertificate(serverName)
}
// GetCertificatePool gets the certificate pool for the config.
func (cfg *Config) GetCertificatePool() (*x509.CertPool, error) {
pool, err := cryptutil.GetCertPool(cfg.Options.CA, cfg.Options.CAFile)
if err != nil {
return nil, err
}
if cfg.Options.DeriveInternalDomainCert != nil {
sharedKey, err := cfg.Options.GetSharedKey()
if err != nil {
return nil, fmt.Errorf("failed to derive CA, invalid shared key: %w", err)
}
ca, err := derivecert.NewCA(sharedKey)
if err != nil {
return nil, fmt.Errorf("failed to derive CA: %w", err)
}
pem, err := ca.PEM()
if err != nil {
return nil, fmt.Errorf("failed to derive CA PEM: %w", err)
}
if !pool.AppendCertsFromPEM(pem.Cert) {
return nil, fmt.Errorf("failed to derive CA PEM, error appending to pool")
}
}
return pool, nil
}

66
config/config_test.go Normal file
View file

@ -0,0 +1,66 @@
package config
import (
"crypto/tls"
"testing"
"github.com/stretchr/testify/assert"
"github.com/pomerium/pomerium/pkg/cryptutil"
)
func TestConfig_GetCertificateForServerName(t *testing.T) {
gen := func(t *testing.T, serverName string) *tls.Certificate {
cert, err := cryptutil.GenerateSelfSignedCertificate(serverName)
if !assert.NoError(t, err, "error generating certificate for: %s", serverName) {
t.FailNow()
}
return cert
}
t.Run("exact match", func(t *testing.T) {
cfg := &Config{Options: NewDefaultOptions(), AutoCertificates: []tls.Certificate{
*gen(t, "a.example.com"),
*gen(t, "b.example.com"),
}}
found, err := cfg.GetCertificateForServerName("b.example.com")
if !assert.NoError(t, err) {
return
}
assert.Equal(t, &cfg.AutoCertificates[1], found)
})
t.Run("wildcard match", func(t *testing.T) {
cfg := &Config{Options: NewDefaultOptions(), AutoCertificates: []tls.Certificate{
*gen(t, "a.example.com"),
*gen(t, "*.example.com"),
}}
found, err := cfg.GetCertificateForServerName("b.example.com")
if !assert.NoError(t, err) {
return
}
assert.Equal(t, &cfg.AutoCertificates[1], found)
})
t.Run("no name match", func(t *testing.T) {
cfg := &Config{Options: NewDefaultOptions(), AutoCertificates: []tls.Certificate{
*gen(t, "a.example.com"),
}}
found, err := cfg.GetCertificateForServerName("b.example.com")
if !assert.NoError(t, err) {
return
}
assert.NotNil(t, found)
assert.NotEqual(t, &cfg.AutoCertificates[0], found)
})
t.Run("generate", func(t *testing.T) {
cfg := &Config{Options: NewDefaultOptions()}
found, err := cfg.GetCertificateForServerName("b.example.com")
if !assert.NoError(t, err) {
return
}
assert.NotNil(t, found)
})
}

View file

@ -512,13 +512,7 @@ func (b *Builder) buildDownstreamTLSContext(ctx context.Context,
cfg *config.Config, cfg *config.Config,
serverName string, serverName string,
) *envoy_extensions_transport_sockets_tls_v3.DownstreamTlsContext { ) *envoy_extensions_transport_sockets_tls_v3.DownstreamTlsContext {
certs, err := cfg.AllCertificates() cert, err := cfg.GetCertificateForServerName(serverName)
if err != nil {
log.Warn(ctx).Str("domain", serverName).Err(err).Msg("failed to get all certificates from config")
return nil
}
cert, err := cryptutil.GetCertificateForServerName(certs, serverName)
if err != nil { if err != nil {
log.Warn(ctx).Str("domain", serverName).Err(err).Msg("failed to get certificate for domain") log.Warn(ctx).Str("domain", serverName).Err(err).Msg("failed to get certificate for domain")
return nil return nil

View file

@ -44,27 +44,10 @@ func GetCertPool(ca, caFile string) (*x509.CertPool, error) {
return rootCAs, nil return rootCAs, nil
} }
// GetCertificateForServerName returns the tls Certificate which matches the given server name.
// It should handle both exact matches and wildcard matches. If none of those match, the first certificate will be used.
// Finally if there are no matching certificates one will be generated.
func GetCertificateForServerName(certificates []tls.Certificate, serverName string) (*tls.Certificate, error) {
// first try a direct name match
for i := range certificates {
if matchesServerName(&certificates[i], serverName) {
return &certificates[i], nil
}
}
log.WarnNoTLSCertificate(serverName)
// finally fall back to a generated, self-signed certificate
return GenerateSelfSignedCertificate(serverName)
}
// HasCertificateForServerName returns true if a TLS certificate matches the given server name. // HasCertificateForServerName returns true if a TLS certificate matches the given server name.
func HasCertificateForServerName(certificates []tls.Certificate, serverName string) bool { func HasCertificateForServerName(certificates []tls.Certificate, serverName string) bool {
for i := range certificates { for i := range certificates {
if matchesServerName(&certificates[i], serverName) { if MatchesServerName(&certificates[i], serverName) {
return true return true
} }
} }
@ -95,7 +78,8 @@ func GetCertificateServerNames(cert *tls.Certificate) []string {
return serverNames return serverNames
} }
func matchesServerName(cert *tls.Certificate, serverName string) bool { // MatchesServerName returns true if the certificate matches the server name.
func MatchesServerName(cert *tls.Certificate, serverName string) bool {
if cert == nil || len(cert.Certificate) == 0 { if cert == nil || len(cert.Certificate) == 0 {
return false return false
} }

View file

@ -1,69 +1,12 @@
package cryptutil package cryptutil
import ( import (
"crypto/tls"
"testing" "testing"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
func TestGetCertificateForServerName(t *testing.T) {
gen := func(t *testing.T, serverName string) *tls.Certificate {
cert, err := GenerateSelfSignedCertificate(serverName)
if !assert.NoError(t, err, "error generating certificate for: %s", serverName) {
t.FailNow()
}
return cert
}
t.Run("exact match", func(t *testing.T) {
certs := []tls.Certificate{
*gen(t, "a.example.com"),
*gen(t, "b.example.com"),
}
found, err := GetCertificateForServerName(certs, "b.example.com")
if !assert.NoError(t, err) {
return
}
assert.Equal(t, &certs[1], found)
})
t.Run("wildcard match", func(t *testing.T) {
certs := []tls.Certificate{
*gen(t, "a.example.com"),
*gen(t, "*.example.com"),
}
found, err := GetCertificateForServerName(certs, "b.example.com")
if !assert.NoError(t, err) {
return
}
assert.Equal(t, &certs[1], found)
})
t.Run("no name match", func(t *testing.T) {
certs := []tls.Certificate{
*gen(t, "a.example.com"),
}
found, err := GetCertificateForServerName(certs, "b.example.com")
if !assert.NoError(t, err) {
return
}
assert.NotNil(t, found)
assert.NotEqual(t, &certs[0], found)
})
t.Run("generate", func(t *testing.T) {
certs := []tls.Certificate{}
found, err := GetCertificateForServerName(certs, "b.example.com")
if !assert.NoError(t, err) {
return
}
assert.NotNil(t, found)
})
}
func TestGetCertificateServerNames(t *testing.T) { func TestGetCertificateServerNames(t *testing.T) {
cert, err := GenerateSelfSignedCertificate("www.example.com") cert, err := GenerateSelfSignedCertificate("www.example.com")
require.NoError(t, err) require.NoError(t, err)