mirror of
https://github.com/pomerium/pomerium.git
synced 2025-05-10 23:57:34 +02:00
config: generate derived certificates instead of self-signed certificates (#3860)
This commit is contained in:
parent
488bcd6f72
commit
3f1a87727f
5 changed files with 151 additions and 84 deletions
|
@ -3,12 +3,16 @@ package config
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
|
"crypto/x509"
|
||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
|
"fmt"
|
||||||
|
|
||||||
"github.com/pomerium/pomerium/internal/fileutil"
|
"github.com/pomerium/pomerium/internal/fileutil"
|
||||||
"github.com/pomerium/pomerium/internal/hashutil"
|
"github.com/pomerium/pomerium/internal/hashutil"
|
||||||
|
"github.com/pomerium/pomerium/internal/log"
|
||||||
"github.com/pomerium/pomerium/internal/telemetry/metrics"
|
"github.com/pomerium/pomerium/internal/telemetry/metrics"
|
||||||
"github.com/pomerium/pomerium/pkg/cryptutil"
|
"github.com/pomerium/pomerium/pkg/cryptutil"
|
||||||
|
"github.com/pomerium/pomerium/pkg/derivecert"
|
||||||
)
|
)
|
||||||
|
|
||||||
// MetricsScrapeEndpoint defines additional metrics endpoints that would be scraped and exposed by pomerium
|
// MetricsScrapeEndpoint defines additional metrics endpoints that would be scraped and exposed by pomerium
|
||||||
|
@ -130,7 +134,7 @@ func (cfg *Config) AllocatePorts(ports [6]string) {
|
||||||
|
|
||||||
// GetTLSClientConfig returns TLS configuration that accounts for additional CA entries
|
// GetTLSClientConfig returns TLS configuration that accounts for additional CA entries
|
||||||
func (cfg *Config) GetTLSClientConfig() (*tls.Config, error) {
|
func (cfg *Config) GetTLSClientConfig() (*tls.Config, error) {
|
||||||
roots, err := cryptutil.GetCertPool(cfg.Options.CA, cfg.Options.CAFile)
|
roots, err := cfg.GetCertificatePool()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -139,3 +143,79 @@ func (cfg *Config) GetTLSClientConfig() (*tls.Config, error) {
|
||||||
MinVersion: tls.VersionTLS12,
|
MinVersion: tls.VersionTLS12,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetCertificateForServerName gets the certificate for the server name. If no certificate is found and there
|
||||||
|
// is a derived CA one will be generated using that CA. If no derived CA is defined a self-signed certificate
|
||||||
|
// will be generated.
|
||||||
|
func (cfg *Config) GetCertificateForServerName(serverName string) (*tls.Certificate, error) {
|
||||||
|
certificates, err := cfg.AllCertificates()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// first try a direct name match
|
||||||
|
for i := range certificates {
|
||||||
|
if cryptutil.MatchesServerName(&certificates[i], serverName) {
|
||||||
|
return &certificates[i], nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
log.WarnNoTLSCertificate(serverName)
|
||||||
|
|
||||||
|
if cfg.Options.DeriveInternalDomainCert != nil {
|
||||||
|
sharedKey, err := cfg.Options.GetSharedKey()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to generate cert, invalid shared key: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
ca, err := derivecert.NewCA(sharedKey)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to generate cert, invalid derived CA: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
pem, err := ca.NewServerCert([]string{serverName})
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to generate cert, error creating server certificate: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
cert, err := pem.TLS()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to generate cert, error converting generated certificate into TLS certificate: %w", err)
|
||||||
|
}
|
||||||
|
return &cert, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// finally fall back to a generated, self-signed certificate
|
||||||
|
return cryptutil.GenerateSelfSignedCertificate(serverName)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetCertificatePool gets the certificate pool for the config.
|
||||||
|
func (cfg *Config) GetCertificatePool() (*x509.CertPool, error) {
|
||||||
|
pool, err := cryptutil.GetCertPool(cfg.Options.CA, cfg.Options.CAFile)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if cfg.Options.DeriveInternalDomainCert != nil {
|
||||||
|
sharedKey, err := cfg.Options.GetSharedKey()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to derive CA, invalid shared key: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
ca, err := derivecert.NewCA(sharedKey)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to derive CA: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
pem, err := ca.PEM()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to derive CA PEM: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !pool.AppendCertsFromPEM(pem.Cert) {
|
||||||
|
return nil, fmt.Errorf("failed to derive CA PEM, error appending to pool")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return pool, nil
|
||||||
|
}
|
||||||
|
|
66
config/config_test.go
Normal file
66
config/config_test.go
Normal file
|
@ -0,0 +1,66 @@
|
||||||
|
package config
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/tls"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
|
||||||
|
"github.com/pomerium/pomerium/pkg/cryptutil"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestConfig_GetCertificateForServerName(t *testing.T) {
|
||||||
|
gen := func(t *testing.T, serverName string) *tls.Certificate {
|
||||||
|
cert, err := cryptutil.GenerateSelfSignedCertificate(serverName)
|
||||||
|
if !assert.NoError(t, err, "error generating certificate for: %s", serverName) {
|
||||||
|
t.FailNow()
|
||||||
|
}
|
||||||
|
return cert
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Run("exact match", func(t *testing.T) {
|
||||||
|
cfg := &Config{Options: NewDefaultOptions(), AutoCertificates: []tls.Certificate{
|
||||||
|
*gen(t, "a.example.com"),
|
||||||
|
*gen(t, "b.example.com"),
|
||||||
|
}}
|
||||||
|
|
||||||
|
found, err := cfg.GetCertificateForServerName("b.example.com")
|
||||||
|
if !assert.NoError(t, err) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
assert.Equal(t, &cfg.AutoCertificates[1], found)
|
||||||
|
})
|
||||||
|
t.Run("wildcard match", func(t *testing.T) {
|
||||||
|
cfg := &Config{Options: NewDefaultOptions(), AutoCertificates: []tls.Certificate{
|
||||||
|
*gen(t, "a.example.com"),
|
||||||
|
*gen(t, "*.example.com"),
|
||||||
|
}}
|
||||||
|
|
||||||
|
found, err := cfg.GetCertificateForServerName("b.example.com")
|
||||||
|
if !assert.NoError(t, err) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
assert.Equal(t, &cfg.AutoCertificates[1], found)
|
||||||
|
})
|
||||||
|
t.Run("no name match", func(t *testing.T) {
|
||||||
|
cfg := &Config{Options: NewDefaultOptions(), AutoCertificates: []tls.Certificate{
|
||||||
|
*gen(t, "a.example.com"),
|
||||||
|
}}
|
||||||
|
|
||||||
|
found, err := cfg.GetCertificateForServerName("b.example.com")
|
||||||
|
if !assert.NoError(t, err) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
assert.NotNil(t, found)
|
||||||
|
assert.NotEqual(t, &cfg.AutoCertificates[0], found)
|
||||||
|
})
|
||||||
|
t.Run("generate", func(t *testing.T) {
|
||||||
|
cfg := &Config{Options: NewDefaultOptions()}
|
||||||
|
|
||||||
|
found, err := cfg.GetCertificateForServerName("b.example.com")
|
||||||
|
if !assert.NoError(t, err) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
assert.NotNil(t, found)
|
||||||
|
})
|
||||||
|
}
|
|
@ -512,13 +512,7 @@ func (b *Builder) buildDownstreamTLSContext(ctx context.Context,
|
||||||
cfg *config.Config,
|
cfg *config.Config,
|
||||||
serverName string,
|
serverName string,
|
||||||
) *envoy_extensions_transport_sockets_tls_v3.DownstreamTlsContext {
|
) *envoy_extensions_transport_sockets_tls_v3.DownstreamTlsContext {
|
||||||
certs, err := cfg.AllCertificates()
|
cert, err := cfg.GetCertificateForServerName(serverName)
|
||||||
if err != nil {
|
|
||||||
log.Warn(ctx).Str("domain", serverName).Err(err).Msg("failed to get all certificates from config")
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
cert, err := cryptutil.GetCertificateForServerName(certs, serverName)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Warn(ctx).Str("domain", serverName).Err(err).Msg("failed to get certificate for domain")
|
log.Warn(ctx).Str("domain", serverName).Err(err).Msg("failed to get certificate for domain")
|
||||||
return nil
|
return nil
|
||||||
|
|
|
@ -44,27 +44,10 @@ func GetCertPool(ca, caFile string) (*x509.CertPool, error) {
|
||||||
return rootCAs, nil
|
return rootCAs, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetCertificateForServerName returns the tls Certificate which matches the given server name.
|
|
||||||
// It should handle both exact matches and wildcard matches. If none of those match, the first certificate will be used.
|
|
||||||
// Finally if there are no matching certificates one will be generated.
|
|
||||||
func GetCertificateForServerName(certificates []tls.Certificate, serverName string) (*tls.Certificate, error) {
|
|
||||||
// first try a direct name match
|
|
||||||
for i := range certificates {
|
|
||||||
if matchesServerName(&certificates[i], serverName) {
|
|
||||||
return &certificates[i], nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
log.WarnNoTLSCertificate(serverName)
|
|
||||||
|
|
||||||
// finally fall back to a generated, self-signed certificate
|
|
||||||
return GenerateSelfSignedCertificate(serverName)
|
|
||||||
}
|
|
||||||
|
|
||||||
// HasCertificateForServerName returns true if a TLS certificate matches the given server name.
|
// HasCertificateForServerName returns true if a TLS certificate matches the given server name.
|
||||||
func HasCertificateForServerName(certificates []tls.Certificate, serverName string) bool {
|
func HasCertificateForServerName(certificates []tls.Certificate, serverName string) bool {
|
||||||
for i := range certificates {
|
for i := range certificates {
|
||||||
if matchesServerName(&certificates[i], serverName) {
|
if MatchesServerName(&certificates[i], serverName) {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -95,7 +78,8 @@ func GetCertificateServerNames(cert *tls.Certificate) []string {
|
||||||
return serverNames
|
return serverNames
|
||||||
}
|
}
|
||||||
|
|
||||||
func matchesServerName(cert *tls.Certificate, serverName string) bool {
|
// MatchesServerName returns true if the certificate matches the server name.
|
||||||
|
func MatchesServerName(cert *tls.Certificate, serverName string) bool {
|
||||||
if cert == nil || len(cert.Certificate) == 0 {
|
if cert == nil || len(cert.Certificate) == 0 {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,69 +1,12 @@
|
||||||
package cryptutil
|
package cryptutil
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"crypto/tls"
|
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestGetCertificateForServerName(t *testing.T) {
|
|
||||||
gen := func(t *testing.T, serverName string) *tls.Certificate {
|
|
||||||
cert, err := GenerateSelfSignedCertificate(serverName)
|
|
||||||
if !assert.NoError(t, err, "error generating certificate for: %s", serverName) {
|
|
||||||
t.FailNow()
|
|
||||||
}
|
|
||||||
return cert
|
|
||||||
}
|
|
||||||
|
|
||||||
t.Run("exact match", func(t *testing.T) {
|
|
||||||
certs := []tls.Certificate{
|
|
||||||
*gen(t, "a.example.com"),
|
|
||||||
*gen(t, "b.example.com"),
|
|
||||||
}
|
|
||||||
|
|
||||||
found, err := GetCertificateForServerName(certs, "b.example.com")
|
|
||||||
if !assert.NoError(t, err) {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
assert.Equal(t, &certs[1], found)
|
|
||||||
})
|
|
||||||
t.Run("wildcard match", func(t *testing.T) {
|
|
||||||
certs := []tls.Certificate{
|
|
||||||
*gen(t, "a.example.com"),
|
|
||||||
*gen(t, "*.example.com"),
|
|
||||||
}
|
|
||||||
|
|
||||||
found, err := GetCertificateForServerName(certs, "b.example.com")
|
|
||||||
if !assert.NoError(t, err) {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
assert.Equal(t, &certs[1], found)
|
|
||||||
})
|
|
||||||
t.Run("no name match", func(t *testing.T) {
|
|
||||||
certs := []tls.Certificate{
|
|
||||||
*gen(t, "a.example.com"),
|
|
||||||
}
|
|
||||||
|
|
||||||
found, err := GetCertificateForServerName(certs, "b.example.com")
|
|
||||||
if !assert.NoError(t, err) {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
assert.NotNil(t, found)
|
|
||||||
assert.NotEqual(t, &certs[0], found)
|
|
||||||
})
|
|
||||||
t.Run("generate", func(t *testing.T) {
|
|
||||||
certs := []tls.Certificate{}
|
|
||||||
|
|
||||||
found, err := GetCertificateForServerName(certs, "b.example.com")
|
|
||||||
if !assert.NoError(t, err) {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
assert.NotNil(t, found)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestGetCertificateServerNames(t *testing.T) {
|
func TestGetCertificateServerNames(t *testing.T) {
|
||||||
cert, err := GenerateSelfSignedCertificate("www.example.com")
|
cert, err := GenerateSelfSignedCertificate("www.example.com")
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue