mirror of
https://github.com/pomerium/pomerium.git
synced 2025-08-03 00:40:25 +02:00
improve certificate matching performance (#4186)
This commit is contained in:
parent
d6e1f3af14
commit
e3b2b3994c
6 changed files with 24 additions and 133 deletions
|
@ -12,7 +12,6 @@ import (
|
|||
"github.com/pomerium/pomerium/internal/fileutil"
|
||||
"github.com/pomerium/pomerium/internal/hashutil"
|
||||
"github.com/pomerium/pomerium/internal/httputil"
|
||||
"github.com/pomerium/pomerium/internal/log"
|
||||
"github.com/pomerium/pomerium/internal/telemetry/metrics"
|
||||
"github.com/pomerium/pomerium/internal/urlutil"
|
||||
"github.com/pomerium/pomerium/pkg/cryptutil"
|
||||
|
@ -149,24 +148,9 @@ func (cfg *Config) GetTLSClientConfig() (*tls.Config, error) {
|
|||
}, nil
|
||||
}
|
||||
|
||||
// GetCertificateForServerName gets the certificate for the server name. If no certificate is found and there
|
||||
// is a derived CA one will be generated using that CA. If no derived CA is defined a self-signed certificate
|
||||
// will be generated.
|
||||
func (cfg *Config) GetCertificateForServerName(serverName string) (*tls.Certificate, error) {
|
||||
certificates, err := cfg.AllCertificates()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// first try a direct name match
|
||||
for i := range certificates {
|
||||
if cryptutil.MatchesServerName(&certificates[i], serverName) {
|
||||
return &certificates[i], nil
|
||||
}
|
||||
}
|
||||
|
||||
log.WarnNoTLSCertificate(serverName)
|
||||
|
||||
// GenerateCatchAllCertificate generates a catch-all certificate. If no derived CA is defined a
|
||||
// self-signed certificate will be generated.
|
||||
func (cfg *Config) GenerateCatchAllCertificate() (*tls.Certificate, error) {
|
||||
if cfg.Options.DeriveInternalDomainCert != nil {
|
||||
sharedKey, err := cfg.Options.GetSharedKey()
|
||||
if err != nil {
|
||||
|
@ -178,7 +162,7 @@ func (cfg *Config) GetCertificateForServerName(serverName string) (*tls.Certific
|
|||
return nil, fmt.Errorf("failed to generate cert, invalid derived CA: %w", err)
|
||||
}
|
||||
|
||||
pem, err := ca.NewServerCert([]string{serverName})
|
||||
pem, err := ca.NewServerCert([]string{"*"})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to generate cert, error creating server certificate: %w", err)
|
||||
}
|
||||
|
@ -196,7 +180,7 @@ func (cfg *Config) GetCertificateForServerName(serverName string) (*tls.Certific
|
|||
}
|
||||
|
||||
// finally fall back to a generated, self-signed certificate
|
||||
return cryptutil.GenerateCertificate(sharedKey, serverName)
|
||||
return cryptutil.GenerateCertificate(sharedKey, "*")
|
||||
}
|
||||
|
||||
// WillHaveCertificateForServerName returns true if there will be a certificate for the given server name.
|
||||
|
|
|
@ -1,94 +0,0 @@
|
|||
package config
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"testing"
|
||||
|
||||
"github.com/golang/protobuf/proto"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/pomerium/pomerium/pkg/cryptutil"
|
||||
)
|
||||
|
||||
func TestConfig_GetCertificateForServerName(t *testing.T) {
|
||||
gen := func(t *testing.T, serverName string) *tls.Certificate {
|
||||
cert, err := cryptutil.GenerateCertificate(nil, serverName)
|
||||
if !assert.NoError(t, err, "error generating certificate for: %s", serverName) {
|
||||
t.FailNow()
|
||||
}
|
||||
return cert
|
||||
}
|
||||
|
||||
t.Run("exact match", func(t *testing.T) {
|
||||
cfg := &Config{Options: NewDefaultOptions(), AutoCertificates: []tls.Certificate{
|
||||
*gen(t, "a.example.com"),
|
||||
*gen(t, "b.example.com"),
|
||||
}}
|
||||
|
||||
found, err := cfg.GetCertificateForServerName("b.example.com")
|
||||
if !assert.NoError(t, err) {
|
||||
return
|
||||
}
|
||||
assert.Equal(t, &cfg.AutoCertificates[1], found)
|
||||
})
|
||||
t.Run("wildcard match", func(t *testing.T) {
|
||||
cfg := &Config{Options: NewDefaultOptions(), AutoCertificates: []tls.Certificate{
|
||||
*gen(t, "a.example.com"),
|
||||
*gen(t, "*.example.com"),
|
||||
}}
|
||||
|
||||
found, err := cfg.GetCertificateForServerName("b.example.com")
|
||||
if !assert.NoError(t, err) {
|
||||
return
|
||||
}
|
||||
assert.Equal(t, &cfg.AutoCertificates[1], found)
|
||||
})
|
||||
t.Run("no name match", func(t *testing.T) {
|
||||
cfg := &Config{Options: NewDefaultOptions(), AutoCertificates: []tls.Certificate{
|
||||
*gen(t, "a.example.com"),
|
||||
}}
|
||||
|
||||
found, err := cfg.GetCertificateForServerName("b.example.com")
|
||||
if !assert.NoError(t, err) {
|
||||
return
|
||||
}
|
||||
assert.NotNil(t, found)
|
||||
assert.NotEqual(t, &cfg.AutoCertificates[0], found)
|
||||
})
|
||||
t.Run("generate", func(t *testing.T) {
|
||||
cfg := &Config{Options: NewDefaultOptions()}
|
||||
|
||||
found, err := cfg.GetCertificateForServerName("b.example.com")
|
||||
if !assert.NoError(t, err) {
|
||||
return
|
||||
}
|
||||
assert.NotNil(t, found)
|
||||
})
|
||||
t.Run("generate for specific name", func(t *testing.T) {
|
||||
cfg := &Config{Options: NewDefaultOptions()}
|
||||
cfg.Options.DeriveInternalDomainCert = proto.String("databroker.int.example.com")
|
||||
|
||||
ok, err := cfg.WillHaveCertificateForServerName("databroker.int.example.com")
|
||||
require.NoError(t, err)
|
||||
assert.True(t, ok)
|
||||
|
||||
found, err := cfg.GetCertificateForServerName("databroker.int.example.com")
|
||||
require.NoError(t, err)
|
||||
assert.True(t, cryptutil.MatchesServerName(found, "databroker.int.example.com"))
|
||||
|
||||
certPool, err := cfg.GetCertificatePool()
|
||||
require.NoError(t, err)
|
||||
|
||||
xc, err := x509.ParseCertificate(found.Certificate[0])
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = xc.Verify(x509.VerifyOptions{
|
||||
DNSName: "databroker.int.example.com",
|
||||
KeyUsages: []x509.ExtKeyUsage{x509.ExtKeyUsageAny},
|
||||
Roots: certPool,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
})
|
||||
}
|
|
@ -109,7 +109,8 @@ func getAllCertificates(cfg *config.Config) ([]tls.Certificate, error) {
|
|||
if err != nil {
|
||||
return nil, fmt.Errorf("error collecting all certificates: %w", err)
|
||||
}
|
||||
wc, err := cfg.GetCertificateForServerName("*")
|
||||
|
||||
wc, err := cfg.GenerateCatchAllCertificate()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error getting wildcard certificate: %w", err)
|
||||
}
|
||||
|
|
|
@ -78,7 +78,7 @@ func (b *Builder) buildMainRouteConfiguration(
|
|||
|
||||
// if we're the proxy, add all the policy routes
|
||||
if config.IsProxy(cfg.Options.Services) {
|
||||
rs, err := b.buildRoutesForPoliciesWithHost(cfg, host)
|
||||
rs, err := b.buildRoutesForPoliciesWithHost(cfg, certs, host)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -95,7 +95,7 @@ func (b *Builder) buildMainRouteConfiguration(
|
|||
return nil, err
|
||||
}
|
||||
if config.IsProxy(cfg.Options.Services) {
|
||||
rs, err := b.buildRoutesForPoliciesWithCatchAll(cfg)
|
||||
rs, err := b.buildRoutesForPoliciesWithCatchAll(cfg, certs)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package envoyconfig
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/url"
|
||||
|
@ -193,6 +194,7 @@ func getClusterStatsName(policy *config.Policy) string {
|
|||
|
||||
func (b *Builder) buildRoutesForPoliciesWithHost(
|
||||
cfg *config.Config,
|
||||
certs []tls.Certificate,
|
||||
host string,
|
||||
) ([]*envoy_config_route_v3.Route, error) {
|
||||
var routes []*envoy_config_route_v3.Route
|
||||
|
@ -207,7 +209,7 @@ func (b *Builder) buildRoutesForPoliciesWithHost(
|
|||
continue
|
||||
}
|
||||
|
||||
policyRoutes, err := b.buildRoutesForPolicy(cfg, &policy, fmt.Sprintf("policy-%d", i))
|
||||
policyRoutes, err := b.buildRoutesForPolicy(cfg, certs, &policy, fmt.Sprintf("policy-%d", i))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -219,6 +221,7 @@ func (b *Builder) buildRoutesForPoliciesWithHost(
|
|||
|
||||
func (b *Builder) buildRoutesForPoliciesWithCatchAll(
|
||||
cfg *config.Config,
|
||||
certs []tls.Certificate,
|
||||
) ([]*envoy_config_route_v3.Route, error) {
|
||||
var routes []*envoy_config_route_v3.Route
|
||||
for i, p := range cfg.Options.GetAllPolicies() {
|
||||
|
@ -232,7 +235,7 @@ func (b *Builder) buildRoutesForPoliciesWithCatchAll(
|
|||
continue
|
||||
}
|
||||
|
||||
policyRoutes, err := b.buildRoutesForPolicy(cfg, &policy, fmt.Sprintf("policy-%d", i))
|
||||
policyRoutes, err := b.buildRoutesForPolicy(cfg, certs, &policy, fmt.Sprintf("policy-%d", i))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -244,6 +247,7 @@ func (b *Builder) buildRoutesForPoliciesWithCatchAll(
|
|||
|
||||
func (b *Builder) buildRoutesForPolicy(
|
||||
cfg *config.Config,
|
||||
certs []tls.Certificate,
|
||||
policy *config.Policy,
|
||||
name string,
|
||||
) ([]*envoy_config_route_v3.Route, error) {
|
||||
|
@ -256,14 +260,14 @@ func (b *Builder) buildRoutesForPolicy(
|
|||
if strings.Contains(fromURL.Host, "*") {
|
||||
// we have to match '*.example.com' and '*.example.com:443', so there are two routes
|
||||
for _, host := range urlutil.GetDomainsForURL(fromURL) {
|
||||
route, err := b.buildRouteForPolicyAndMatch(cfg, policy, name, mkRouteMatchForHost(policy, host))
|
||||
route, err := b.buildRouteForPolicyAndMatch(cfg, certs, policy, name, mkRouteMatchForHost(policy, host))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
routes = append(routes, route)
|
||||
}
|
||||
} else {
|
||||
route, err := b.buildRouteForPolicyAndMatch(cfg, policy, name, mkRouteMatch(policy))
|
||||
route, err := b.buildRouteForPolicyAndMatch(cfg, certs, policy, name, mkRouteMatch(policy))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -274,6 +278,7 @@ func (b *Builder) buildRoutesForPolicy(
|
|||
|
||||
func (b *Builder) buildRouteForPolicyAndMatch(
|
||||
cfg *config.Config,
|
||||
certs []tls.Certificate,
|
||||
policy *config.Policy,
|
||||
name string,
|
||||
match *envoy_config_route_v3.RouteMatch,
|
||||
|
@ -283,11 +288,6 @@ func (b *Builder) buildRouteForPolicyAndMatch(
|
|||
return nil, err
|
||||
}
|
||||
|
||||
certs, err := getAllCertificates(cfg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
requireStrictTransportSecurity := cryptutil.HasCertificateForServerName(certs, fromURL.Hostname())
|
||||
|
||||
route := &envoy_config_route_v3.Route{
|
||||
|
|
|
@ -307,7 +307,7 @@ func TestTimeouts(t *testing.T) {
|
|||
AllowWebsockets: tc.allowWebsockets,
|
||||
},
|
||||
},
|
||||
}}, "example.com")
|
||||
}}, nil, "example.com")
|
||||
if !assert.NoError(t, err, "%v", tc) || !assert.Len(t, routes, 1, tc) || !assert.NotNil(t, routes[0].GetRoute(), "%v", tc) {
|
||||
continue
|
||||
}
|
||||
|
@ -412,7 +412,7 @@ func Test_buildPolicyRoutes(t *testing.T) {
|
|||
UpstreamTimeout: &ten,
|
||||
},
|
||||
},
|
||||
}}, "example.com")
|
||||
}}, nil, "example.com")
|
||||
require.NoError(t, err)
|
||||
|
||||
testutil.AssertProtoJSONEqual(t, `
|
||||
|
@ -918,7 +918,7 @@ func Test_buildPolicyRoutes(t *testing.T) {
|
|||
PassIdentityHeaders: true,
|
||||
},
|
||||
},
|
||||
}}, "authenticate.example.com")
|
||||
}}, nil, "authenticate.example.com")
|
||||
require.NoError(t, err)
|
||||
|
||||
testutil.AssertProtoJSONEqual(t, `
|
||||
|
@ -1005,7 +1005,7 @@ func Test_buildPolicyRoutes(t *testing.T) {
|
|||
UpstreamTimeout: &ten,
|
||||
},
|
||||
},
|
||||
}}, "example.com:22")
|
||||
}}, nil, "example.com:22")
|
||||
require.NoError(t, err)
|
||||
|
||||
testutil.AssertProtoJSONEqual(t, `
|
||||
|
@ -1151,7 +1151,7 @@ func Test_buildPolicyRoutes(t *testing.T) {
|
|||
From: "https://from.example.com",
|
||||
},
|
||||
},
|
||||
}}, "from.example.com")
|
||||
}}, nil, "from.example.com")
|
||||
require.NoError(t, err)
|
||||
|
||||
testutil.AssertProtoJSONEqual(t, `
|
||||
|
@ -1272,7 +1272,7 @@ func Test_buildPolicyRoutesRewrite(t *testing.T) {
|
|||
HostPathRegexRewriteSubstitution: "\\1",
|
||||
},
|
||||
},
|
||||
}}, "example.com")
|
||||
}}, nil, "example.com")
|
||||
require.NoError(t, err)
|
||||
|
||||
testutil.AssertProtoJSONEqual(t, `
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue