diff --git a/config/config.go b/config/config.go index 5be5cfe89..8f6ba6a7e 100644 --- a/config/config.go +++ b/config/config.go @@ -3,12 +3,16 @@ package config import ( "bytes" "crypto/tls" + "crypto/x509" "encoding/base64" + "fmt" "github.com/pomerium/pomerium/internal/fileutil" "github.com/pomerium/pomerium/internal/hashutil" + "github.com/pomerium/pomerium/internal/log" "github.com/pomerium/pomerium/internal/telemetry/metrics" "github.com/pomerium/pomerium/pkg/cryptutil" + "github.com/pomerium/pomerium/pkg/derivecert" ) // MetricsScrapeEndpoint defines additional metrics endpoints that would be scraped and exposed by pomerium @@ -130,7 +134,7 @@ func (cfg *Config) AllocatePorts(ports [6]string) { // GetTLSClientConfig returns TLS configuration that accounts for additional CA entries func (cfg *Config) GetTLSClientConfig() (*tls.Config, error) { - roots, err := cryptutil.GetCertPool(cfg.Options.CA, cfg.Options.CAFile) + roots, err := cfg.GetCertificatePool() if err != nil { return nil, err } @@ -139,3 +143,79 @@ func (cfg *Config) GetTLSClientConfig() (*tls.Config, error) { MinVersion: tls.VersionTLS12, }, nil } + +// GetCertificateForServerName gets the certificate for the server name. If no certificate is found and there +// is a derived CA one will be generated using that CA. If no derived CA is defined a self-signed certificate +// will be generated. +func (cfg *Config) GetCertificateForServerName(serverName string) (*tls.Certificate, error) { + certificates, err := cfg.AllCertificates() + if err != nil { + return nil, err + } + + // first try a direct name match + for i := range certificates { + if cryptutil.MatchesServerName(&certificates[i], serverName) { + return &certificates[i], nil + } + } + + log.WarnNoTLSCertificate(serverName) + + if cfg.Options.DeriveInternalDomainCert != nil { + sharedKey, err := cfg.Options.GetSharedKey() + if err != nil { + return nil, fmt.Errorf("failed to generate cert, invalid shared key: %w", err) + } + + ca, err := derivecert.NewCA(sharedKey) + if err != nil { + return nil, fmt.Errorf("failed to generate cert, invalid derived CA: %w", err) + } + + pem, err := ca.NewServerCert([]string{serverName}) + if err != nil { + return nil, fmt.Errorf("failed to generate cert, error creating server certificate: %w", err) + } + + cert, err := pem.TLS() + if err != nil { + return nil, fmt.Errorf("failed to generate cert, error converting generated certificate into TLS certificate: %w", err) + } + return &cert, nil + } + + // finally fall back to a generated, self-signed certificate + return cryptutil.GenerateSelfSignedCertificate(serverName) +} + +// GetCertificatePool gets the certificate pool for the config. +func (cfg *Config) GetCertificatePool() (*x509.CertPool, error) { + pool, err := cryptutil.GetCertPool(cfg.Options.CA, cfg.Options.CAFile) + if err != nil { + return nil, err + } + + if cfg.Options.DeriveInternalDomainCert != nil { + sharedKey, err := cfg.Options.GetSharedKey() + if err != nil { + return nil, fmt.Errorf("failed to derive CA, invalid shared key: %w", err) + } + + ca, err := derivecert.NewCA(sharedKey) + if err != nil { + return nil, fmt.Errorf("failed to derive CA: %w", err) + } + + pem, err := ca.PEM() + if err != nil { + return nil, fmt.Errorf("failed to derive CA PEM: %w", err) + } + + if !pool.AppendCertsFromPEM(pem.Cert) { + return nil, fmt.Errorf("failed to derive CA PEM, error appending to pool") + } + } + + return pool, nil +} diff --git a/config/config_test.go b/config/config_test.go new file mode 100644 index 000000000..f4fdad40e --- /dev/null +++ b/config/config_test.go @@ -0,0 +1,66 @@ +package config + +import ( + "crypto/tls" + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/pomerium/pomerium/pkg/cryptutil" +) + +func TestConfig_GetCertificateForServerName(t *testing.T) { + gen := func(t *testing.T, serverName string) *tls.Certificate { + cert, err := cryptutil.GenerateSelfSignedCertificate(serverName) + if !assert.NoError(t, err, "error generating certificate for: %s", serverName) { + t.FailNow() + } + return cert + } + + t.Run("exact match", func(t *testing.T) { + cfg := &Config{Options: NewDefaultOptions(), AutoCertificates: []tls.Certificate{ + *gen(t, "a.example.com"), + *gen(t, "b.example.com"), + }} + + found, err := cfg.GetCertificateForServerName("b.example.com") + if !assert.NoError(t, err) { + return + } + assert.Equal(t, &cfg.AutoCertificates[1], found) + }) + t.Run("wildcard match", func(t *testing.T) { + cfg := &Config{Options: NewDefaultOptions(), AutoCertificates: []tls.Certificate{ + *gen(t, "a.example.com"), + *gen(t, "*.example.com"), + }} + + found, err := cfg.GetCertificateForServerName("b.example.com") + if !assert.NoError(t, err) { + return + } + assert.Equal(t, &cfg.AutoCertificates[1], found) + }) + t.Run("no name match", func(t *testing.T) { + cfg := &Config{Options: NewDefaultOptions(), AutoCertificates: []tls.Certificate{ + *gen(t, "a.example.com"), + }} + + found, err := cfg.GetCertificateForServerName("b.example.com") + if !assert.NoError(t, err) { + return + } + assert.NotNil(t, found) + assert.NotEqual(t, &cfg.AutoCertificates[0], found) + }) + t.Run("generate", func(t *testing.T) { + cfg := &Config{Options: NewDefaultOptions()} + + found, err := cfg.GetCertificateForServerName("b.example.com") + if !assert.NoError(t, err) { + return + } + assert.NotNil(t, found) + }) +} diff --git a/config/envoyconfig/listeners.go b/config/envoyconfig/listeners.go index a390762c0..0afdd4f3b 100644 --- a/config/envoyconfig/listeners.go +++ b/config/envoyconfig/listeners.go @@ -512,13 +512,7 @@ func (b *Builder) buildDownstreamTLSContext(ctx context.Context, cfg *config.Config, serverName string, ) *envoy_extensions_transport_sockets_tls_v3.DownstreamTlsContext { - certs, err := cfg.AllCertificates() - if err != nil { - log.Warn(ctx).Str("domain", serverName).Err(err).Msg("failed to get all certificates from config") - return nil - } - - cert, err := cryptutil.GetCertificateForServerName(certs, serverName) + cert, err := cfg.GetCertificateForServerName(serverName) if err != nil { log.Warn(ctx).Str("domain", serverName).Err(err).Msg("failed to get certificate for domain") return nil diff --git a/pkg/cryptutil/tls.go b/pkg/cryptutil/tls.go index 395f9d1f5..d8308b649 100644 --- a/pkg/cryptutil/tls.go +++ b/pkg/cryptutil/tls.go @@ -44,27 +44,10 @@ func GetCertPool(ca, caFile string) (*x509.CertPool, error) { return rootCAs, nil } -// GetCertificateForServerName returns the tls Certificate which matches the given server name. -// It should handle both exact matches and wildcard matches. If none of those match, the first certificate will be used. -// Finally if there are no matching certificates one will be generated. -func GetCertificateForServerName(certificates []tls.Certificate, serverName string) (*tls.Certificate, error) { - // first try a direct name match - for i := range certificates { - if matchesServerName(&certificates[i], serverName) { - return &certificates[i], nil - } - } - - log.WarnNoTLSCertificate(serverName) - - // finally fall back to a generated, self-signed certificate - return GenerateSelfSignedCertificate(serverName) -} - // HasCertificateForServerName returns true if a TLS certificate matches the given server name. func HasCertificateForServerName(certificates []tls.Certificate, serverName string) bool { for i := range certificates { - if matchesServerName(&certificates[i], serverName) { + if MatchesServerName(&certificates[i], serverName) { return true } } @@ -95,7 +78,8 @@ func GetCertificateServerNames(cert *tls.Certificate) []string { return serverNames } -func matchesServerName(cert *tls.Certificate, serverName string) bool { +// MatchesServerName returns true if the certificate matches the server name. +func MatchesServerName(cert *tls.Certificate, serverName string) bool { if cert == nil || len(cert.Certificate) == 0 { return false } diff --git a/pkg/cryptutil/tls_test.go b/pkg/cryptutil/tls_test.go index a6bbdebec..36a86e0c3 100644 --- a/pkg/cryptutil/tls_test.go +++ b/pkg/cryptutil/tls_test.go @@ -1,69 +1,12 @@ package cryptutil import ( - "crypto/tls" "testing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) -func TestGetCertificateForServerName(t *testing.T) { - gen := func(t *testing.T, serverName string) *tls.Certificate { - cert, err := GenerateSelfSignedCertificate(serverName) - if !assert.NoError(t, err, "error generating certificate for: %s", serverName) { - t.FailNow() - } - return cert - } - - t.Run("exact match", func(t *testing.T) { - certs := []tls.Certificate{ - *gen(t, "a.example.com"), - *gen(t, "b.example.com"), - } - - found, err := GetCertificateForServerName(certs, "b.example.com") - if !assert.NoError(t, err) { - return - } - assert.Equal(t, &certs[1], found) - }) - t.Run("wildcard match", func(t *testing.T) { - certs := []tls.Certificate{ - *gen(t, "a.example.com"), - *gen(t, "*.example.com"), - } - - found, err := GetCertificateForServerName(certs, "b.example.com") - if !assert.NoError(t, err) { - return - } - assert.Equal(t, &certs[1], found) - }) - t.Run("no name match", func(t *testing.T) { - certs := []tls.Certificate{ - *gen(t, "a.example.com"), - } - - found, err := GetCertificateForServerName(certs, "b.example.com") - if !assert.NoError(t, err) { - return - } - assert.NotNil(t, found) - assert.NotEqual(t, &certs[0], found) - }) - t.Run("generate", func(t *testing.T) { - certs := []tls.Certificate{} - - found, err := GetCertificateForServerName(certs, "b.example.com") - if !assert.NoError(t, err) { - return - } - assert.NotNil(t, found) - }) -} - func TestGetCertificateServerNames(t *testing.T) { cert, err := GenerateSelfSignedCertificate("www.example.com") require.NoError(t, err)