autocert: store certificates separately from config certificates (#1794)

This commit is contained in:
Caleb Doxsey 2021-01-21 13:13:55 -07:00 committed by GitHub
parent 70b4497595
commit c90eda5622
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
8 changed files with 87 additions and 72 deletions

27
config/config.go Normal file
View file

@ -0,0 +1,27 @@
package config
import "crypto/tls"
// Config holds pomerium configuration options.
type Config struct {
Options *Options
AutoCertificates []tls.Certificate
}
// Clone creates a clone of the config.
func (cfg *Config) Clone() *Config {
newOptions := new(Options)
*newOptions = *cfg.Options
return &Config{
Options: newOptions,
AutoCertificates: cfg.AutoCertificates,
}
}
// AllCertificates returns all the certificates in the config.
func (cfg *Config) AllCertificates() []tls.Certificate {
var certs []tls.Certificate
certs = append(certs, cfg.Options.Certificates...)
certs = append(certs, cfg.AutoCertificates...)
return certs
}

View file

@ -11,20 +11,6 @@ import (
"github.com/pomerium/pomerium/internal/fileutil" "github.com/pomerium/pomerium/internal/fileutil"
) )
// Config holds pomerium configuration options.
type Config struct {
Options *Options
}
// Clone creates a clone of the config.
func (cfg *Config) Clone() *Config {
newOptions := new(Options)
*newOptions = *cfg.Options
return &Config{
Options: newOptions,
}
}
// A ChangeListener is called when configuration changes. // A ChangeListener is called when configuration changes.
type ChangeListener = func(*Config) type ChangeListener = func(*Config)

View file

@ -98,19 +98,19 @@ func newManager(ctx context.Context,
return mgr, nil return mgr, nil
} }
func (mgr *Manager) getCertMagicConfig(options *config.Options) (*certmagic.Config, error) { func (mgr *Manager) getCertMagicConfig(cfg *config.Config) (*certmagic.Config, error) {
mgr.certmagic.MustStaple = options.AutocertOptions.MustStaple mgr.certmagic.MustStaple = cfg.Options.AutocertOptions.MustStaple
mgr.certmagic.OnDemand = nil // disable on-demand mgr.certmagic.OnDemand = nil // disable on-demand
mgr.certmagic.Storage = &certmagic.FileStorage{Path: options.AutocertOptions.Folder} mgr.certmagic.Storage = &certmagic.FileStorage{Path: cfg.Options.AutocertOptions.Folder}
// add existing certs to the cache, and staple OCSP // add existing certs to the cache, and staple OCSP
for _, cert := range options.Certificates { for _, cert := range cfg.AllCertificates() {
if err := mgr.certmagic.CacheUnmanagedTLSCertificate(cert, nil); err != nil { if err := mgr.certmagic.CacheUnmanagedTLSCertificate(cert, nil); err != nil {
return nil, fmt.Errorf("config: failed caching cert: %w", err) return nil, fmt.Errorf("config: failed caching cert: %w", err)
} }
} }
acmeMgr := certmagic.NewACMEManager(mgr.certmagic, mgr.acmeTemplate) acmeMgr := certmagic.NewACMEManager(mgr.certmagic, mgr.acmeTemplate)
acmeMgr.Agreed = true acmeMgr.Agreed = true
if options.AutocertOptions.UseStaging { if cfg.Options.AutocertOptions.UseStaging {
acmeMgr.CA = acmeMgr.TestCA acmeMgr.CA = acmeMgr.TestCA
} }
acmeMgr.DisableTLSALPNChallenge = true acmeMgr.DisableTLSALPNChallenge = true
@ -125,7 +125,7 @@ func (mgr *Manager) renewConfigCerts() error {
defer mgr.mu.Unlock() defer mgr.mu.Unlock()
cfg := mgr.config cfg := mgr.config
cm, err := mgr.getCertMagicConfig(cfg.Options) cm, err := mgr.getCertMagicConfig(cfg)
if err != nil { if err != nil {
return err return err
} }
@ -197,7 +197,7 @@ func (mgr *Manager) updateAutocert(cfg *config.Config) error {
return nil return nil
} }
cm, err := mgr.getCertMagicConfig(cfg.Options) cm, err := mgr.getCertMagicConfig(cfg)
if err != nil { if err != nil {
return err return err
} }
@ -219,7 +219,7 @@ func (mgr *Manager) updateAutocert(cfg *config.Config) error {
} }
log.Info().Strs("names", cert.Names).Msg("autocert: added certificate") log.Info().Strs("names", cert.Names).Msg("autocert: added certificate")
cfg.Options.Certificates = append(cfg.Options.Certificates, cert.Certificate) cfg.AutoCertificates = append(cfg.AutoCertificates, cert.Certificate)
} }
return nil return nil

View file

@ -171,12 +171,12 @@ func TestConfig(t *testing.T) {
var certs []tls.Certificate var certs []tls.Certificate
for i := 0; i < 10; i++ { for i := 0; i < 10; i++ {
cfg := mgr.GetConfig() cfg := mgr.GetConfig()
assert.LessOrEqual(t, len(cfg.Options.Certificates), 1) assert.LessOrEqual(t, len(cfg.AutoCertificates), 1)
if len(cfg.Options.Certificates) == 1 && certs == nil { if len(cfg.AutoCertificates) == 1 && certs == nil {
certs = cfg.Options.Certificates certs = cfg.AutoCertificates
} }
if !cmp.Equal(certs, cfg.Options.Certificates) { if !cmp.Equal(certs, cfg.AutoCertificates) {
return return
} }

View file

@ -24,21 +24,21 @@ import (
"github.com/pomerium/pomerium/pkg/grpcutil" "github.com/pomerium/pomerium/pkg/grpcutil"
) )
type versionedOptions struct { type versionedConfig struct {
config.Options *config.Config
version int64 version int64
} }
type atomicVersionedOptions struct { type atomicVersionedConfig struct {
value atomic.Value value atomic.Value
} }
func (avo *atomicVersionedOptions) Load() versionedOptions { func (avo *atomicVersionedConfig) Load() versionedConfig {
return avo.value.Load().(versionedOptions) return avo.value.Load().(versionedConfig)
} }
func (avo *atomicVersionedOptions) Store(options versionedOptions) { func (avo *atomicVersionedConfig) Store(cfg versionedConfig) {
avo.value.Store(options) avo.value.Store(cfg)
} }
// A Server is the control-plane gRPC and HTTP servers. // A Server is the control-plane gRPC and HTTP servers.
@ -48,7 +48,7 @@ type Server struct {
HTTPListener net.Listener HTTPListener net.Listener
HTTPRouter *mux.Router HTTPRouter *mux.Router
currentConfig atomicVersionedOptions currentConfig atomicVersionedConfig
name string name string
xdsmgr *xdsmgr.Manager xdsmgr *xdsmgr.Manager
filemgr *filemgr.Manager filemgr *filemgr.Manager
@ -57,7 +57,9 @@ type Server struct {
// NewServer creates a new Server. Listener ports are chosen by the OS. // NewServer creates a new Server. Listener ports are chosen by the OS.
func NewServer(name string) (*Server, error) { func NewServer(name string) (*Server, error) {
srv := &Server{} srv := &Server{}
srv.currentConfig.Store(versionedOptions{}) srv.currentConfig.Store(versionedConfig{
Config: &config.Config{Options: &config.Options{}},
})
var err error var err error
@ -158,8 +160,8 @@ func (srv *Server) Run(ctx context.Context) error {
// OnConfigChange updates the pomerium config options. // OnConfigChange updates the pomerium config options.
func (srv *Server) OnConfigChange(cfg *config.Config) { func (srv *Server) OnConfigChange(cfg *config.Config) {
prev := srv.currentConfig.Load() prev := srv.currentConfig.Load()
srv.currentConfig.Store(versionedOptions{ srv.currentConfig.Store(versionedConfig{
Options: *cfg.Options, Config: cfg,
version: prev.version + 1, version: prev.version + 1,
}) })
srv.xdsmgr.Update(srv.buildDiscoveryResources()) srv.xdsmgr.Update(srv.buildDiscoveryResources())

View file

@ -34,7 +34,7 @@ const (
func (srv *Server) buildDiscoveryResources() map[string][]*envoy_service_discovery_v3.Resource { func (srv *Server) buildDiscoveryResources() map[string][]*envoy_service_discovery_v3.Resource {
resources := map[string][]*envoy_service_discovery_v3.Resource{} resources := map[string][]*envoy_service_discovery_v3.Resource{}
cfg := srv.currentConfig.Load() cfg := srv.currentConfig.Load()
for _, cluster := range srv.buildClusters(&cfg.Options) { for _, cluster := range srv.buildClusters(cfg.Options) {
any, _ := anypb.New(cluster) any, _ := anypb.New(cluster)
resources[clusterTypeURL] = append(resources[clusterTypeURL], &envoy_service_discovery_v3.Resource{ resources[clusterTypeURL] = append(resources[clusterTypeURL], &envoy_service_discovery_v3.Resource{
Name: cluster.Name, Name: cluster.Name,
@ -42,7 +42,7 @@ func (srv *Server) buildDiscoveryResources() map[string][]*envoy_service_discove
Resource: any, Resource: any,
}) })
} }
for _, listener := range srv.buildListeners(&cfg.Options) { for _, listener := range srv.buildListeners(cfg.Config) {
any, _ := anypb.New(listener) any, _ := anypb.New(listener)
resources[listenerTypeURL] = append(resources[listenerTypeURL], &envoy_service_discovery_v3.Resource{ resources[listenerTypeURL] = append(resources[listenerTypeURL], &envoy_service_discovery_v3.Resource{
Name: listener.Name, Name: listener.Name,

View file

@ -38,23 +38,23 @@ func init() {
}) })
} }
func (srv *Server) buildListeners(options *config.Options) []*envoy_config_listener_v3.Listener { func (srv *Server) buildListeners(cfg *config.Config) []*envoy_config_listener_v3.Listener {
var listeners []*envoy_config_listener_v3.Listener var listeners []*envoy_config_listener_v3.Listener
if config.IsAuthenticate(options.Services) || config.IsProxy(options.Services) { if config.IsAuthenticate(cfg.Options.Services) || config.IsProxy(cfg.Options.Services) {
listeners = append(listeners, srv.buildMainListener(options)) listeners = append(listeners, srv.buildMainListener(cfg))
} }
if config.IsAuthorize(options.Services) || config.IsDataBroker(options.Services) { if config.IsAuthorize(cfg.Options.Services) || config.IsDataBroker(cfg.Options.Services) {
listeners = append(listeners, srv.buildGRPCListener(options)) listeners = append(listeners, srv.buildGRPCListener(cfg))
} }
return listeners return listeners
} }
func (srv *Server) buildMainListener(options *config.Options) *envoy_config_listener_v3.Listener { func (srv *Server) buildMainListener(cfg *config.Config) *envoy_config_listener_v3.Listener {
listenerFilters := []*envoy_config_listener_v3.ListenerFilter{} listenerFilters := []*envoy_config_listener_v3.ListenerFilter{}
if options.UseProxyProtocol { if cfg.Options.UseProxyProtocol {
proxyCfg := marshalAny(&envoy_extensions_filters_listener_proxy_protocol_v3.ProxyProtocol{}) proxyCfg := marshalAny(&envoy_extensions_filters_listener_proxy_protocol_v3.ProxyProtocol{})
listenerFilters = append(listenerFilters, &envoy_config_listener_v3.ListenerFilter{ listenerFilters = append(listenerFilters, &envoy_config_listener_v3.ListenerFilter{
Name: "envoy.filters.listener.proxy_protocol", Name: "envoy.filters.listener.proxy_protocol",
@ -64,13 +64,13 @@ func (srv *Server) buildMainListener(options *config.Options) *envoy_config_list
}) })
} }
if options.InsecureServer { if cfg.Options.InsecureServer {
filter := buildMainHTTPConnectionManagerFilter(options, filter := buildMainHTTPConnectionManagerFilter(cfg.Options,
getAllRouteableDomains(options, options.Addr)) getAllRouteableDomains(cfg.Options, cfg.Options.Addr))
return &envoy_config_listener_v3.Listener{ return &envoy_config_listener_v3.Listener{
Name: "http-ingress", Name: "http-ingress",
Address: buildAddress(options.Addr, 80), Address: buildAddress(cfg.Options.Addr, 80),
ListenerFilters: listenerFilters, ListenerFilters: listenerFilters,
FilterChains: []*envoy_config_listener_v3.FilterChain{{ FilterChains: []*envoy_config_listener_v3.FilterChain{{
Filters: []*envoy_config_listener_v3.Filter{ Filters: []*envoy_config_listener_v3.Filter{
@ -90,11 +90,11 @@ func (srv *Server) buildMainListener(options *config.Options) *envoy_config_list
li := &envoy_config_listener_v3.Listener{ li := &envoy_config_listener_v3.Listener{
Name: "https-ingress", Name: "https-ingress",
Address: buildAddress(options.Addr, 443), Address: buildAddress(cfg.Options.Addr, 443),
ListenerFilters: listenerFilters, ListenerFilters: listenerFilters,
FilterChains: buildFilterChains(options, options.Addr, FilterChains: buildFilterChains(cfg.Options, cfg.Options.Addr,
func(tlsDomain string, httpDomains []string) *envoy_config_listener_v3.FilterChain { func(tlsDomain string, httpDomains []string) *envoy_config_listener_v3.FilterChain {
filter := buildMainHTTPConnectionManagerFilter(options, httpDomains) filter := buildMainHTTPConnectionManagerFilter(cfg.Options, httpDomains)
filterChain := &envoy_config_listener_v3.FilterChain{ filterChain := &envoy_config_listener_v3.FilterChain{
Filters: []*envoy_config_listener_v3.Filter{filter}, Filters: []*envoy_config_listener_v3.Filter{filter},
} }
@ -103,7 +103,7 @@ func (srv *Server) buildMainListener(options *config.Options) *envoy_config_list
ServerNames: []string{tlsDomain}, ServerNames: []string{tlsDomain},
} }
} }
tlsContext := srv.buildDownstreamTLSContext(options, tlsDomain) tlsContext := srv.buildDownstreamTLSContext(cfg, tlsDomain)
if tlsContext != nil { if tlsContext != nil {
tlsConfig := marshalAny(tlsContext) tlsConfig := marshalAny(tlsContext)
filterChain.TransportSocket = &envoy_config_core_v3.TransportSocket{ filterChain.TransportSocket = &envoy_config_core_v3.TransportSocket{
@ -265,13 +265,13 @@ func buildMainHTTPConnectionManagerFilter(options *config.Options, domains []str
} }
} }
func (srv *Server) buildGRPCListener(options *config.Options) *envoy_config_listener_v3.Listener { func (srv *Server) buildGRPCListener(cfg *config.Config) *envoy_config_listener_v3.Listener {
filter := buildGRPCHTTPConnectionManagerFilter() filter := buildGRPCHTTPConnectionManagerFilter()
if options.GRPCInsecure { if cfg.Options.GRPCInsecure {
return &envoy_config_listener_v3.Listener{ return &envoy_config_listener_v3.Listener{
Name: "grpc-ingress", Name: "grpc-ingress",
Address: buildAddress(options.GRPCAddr, 80), Address: buildAddress(cfg.Options.GRPCAddr, 80),
FilterChains: []*envoy_config_listener_v3.FilterChain{{ FilterChains: []*envoy_config_listener_v3.FilterChain{{
Filters: []*envoy_config_listener_v3.Filter{ Filters: []*envoy_config_listener_v3.Filter{
filter, filter,
@ -283,14 +283,14 @@ func (srv *Server) buildGRPCListener(options *config.Options) *envoy_config_list
tlsInspectorCfg := marshalAny(new(emptypb.Empty)) tlsInspectorCfg := marshalAny(new(emptypb.Empty))
li := &envoy_config_listener_v3.Listener{ li := &envoy_config_listener_v3.Listener{
Name: "grpc-ingress", Name: "grpc-ingress",
Address: buildAddress(options.GRPCAddr, 443), Address: buildAddress(cfg.Options.GRPCAddr, 443),
ListenerFilters: []*envoy_config_listener_v3.ListenerFilter{{ ListenerFilters: []*envoy_config_listener_v3.ListenerFilter{{
Name: "envoy.filters.listener.tls_inspector", Name: "envoy.filters.listener.tls_inspector",
ConfigType: &envoy_config_listener_v3.ListenerFilter_TypedConfig{ ConfigType: &envoy_config_listener_v3.ListenerFilter_TypedConfig{
TypedConfig: tlsInspectorCfg, TypedConfig: tlsInspectorCfg,
}, },
}}, }},
FilterChains: buildFilterChains(options, options.Addr, FilterChains: buildFilterChains(cfg.Options, cfg.Options.Addr,
func(tlsDomain string, httpDomains []string) *envoy_config_listener_v3.FilterChain { func(tlsDomain string, httpDomains []string) *envoy_config_listener_v3.FilterChain {
filterChain := &envoy_config_listener_v3.FilterChain{ filterChain := &envoy_config_listener_v3.FilterChain{
Filters: []*envoy_config_listener_v3.Filter{filter}, Filters: []*envoy_config_listener_v3.Filter{filter},
@ -300,7 +300,7 @@ func (srv *Server) buildGRPCListener(options *config.Options) *envoy_config_list
ServerNames: []string{tlsDomain}, ServerNames: []string{tlsDomain},
} }
} }
tlsContext := srv.buildDownstreamTLSContext(options, tlsDomain) tlsContext := srv.buildDownstreamTLSContext(cfg, tlsDomain)
if tlsContext != nil { if tlsContext != nil {
tlsConfig := marshalAny(tlsContext) tlsConfig := marshalAny(tlsContext)
filterChain.TransportSocket = &envoy_config_core_v3.TransportSocket{ filterChain.TransportSocket = &envoy_config_core_v3.TransportSocket{
@ -372,22 +372,22 @@ func buildRouteConfiguration(name string, virtualHosts []*envoy_config_route_v3.
} }
} }
func (srv *Server) buildDownstreamTLSContext(options *config.Options, domain string) *envoy_extensions_transport_sockets_tls_v3.DownstreamTlsContext { func (srv *Server) buildDownstreamTLSContext(cfg *config.Config, domain string) *envoy_extensions_transport_sockets_tls_v3.DownstreamTlsContext {
cert, err := cryptutil.GetCertificateForDomain(options.Certificates, domain) cert, err := cryptutil.GetCertificateForDomain(cfg.AllCertificates(), domain)
if err != nil { if err != nil {
log.Warn().Str("domain", domain).Err(err).Msg("failed to get certificate for domain") log.Warn().Str("domain", domain).Err(err).Msg("failed to get certificate for domain")
return nil return nil
} }
var trustedCA *envoy_config_core_v3.DataSource var trustedCA *envoy_config_core_v3.DataSource
if options.ClientCA != "" { if cfg.Options.ClientCA != "" {
bs, err := base64.StdEncoding.DecodeString(options.ClientCA) bs, err := base64.StdEncoding.DecodeString(cfg.Options.ClientCA)
if err != nil { if err != nil {
log.Warn().Msg("client_ca does not appear to be a base64 encoded string") log.Warn().Msg("client_ca does not appear to be a base64 encoded string")
} }
trustedCA = srv.filemgr.BytesDataSource("client-ca", bs) trustedCA = srv.filemgr.BytesDataSource("client-ca", bs)
} else if options.ClientCAFile != "" { } else if cfg.Options.ClientCAFile != "" {
trustedCA = srv.filemgr.FileDataSource(options.ClientCAFile) trustedCA = srv.filemgr.FileDataSource(cfg.Options.ClientCAFile)
} }
var validationContext *envoy_extensions_transport_sockets_tls_v3.CommonTlsContext_ValidationContext var validationContext *envoy_extensions_transport_sockets_tls_v3.CommonTlsContext_ValidationContext

View file

@ -390,9 +390,9 @@ func Test_buildDownstreamTLSContext(t *testing.T) {
srv, _ := NewServer("TEST") srv, _ := NewServer("TEST")
downstreamTLSContext := srv.buildDownstreamTLSContext(&config.Options{ downstreamTLSContext := srv.buildDownstreamTLSContext(&config.Config{Options: &config.Options{
Certificates: []tls.Certificate{*certA}, Certificates: []tls.Certificate{*certA},
}, "a.example.com") }}, "a.example.com")
cacheDir, _ := os.UserCacheDir() cacheDir, _ := os.UserCacheDir()
certFileName := filepath.Join(cacheDir, "pomerium", "envoy", "files", "tls-crt-354e49305a5a39414a545530374e58454e48334148524c4e324258463837364355564c4e4532464b54355139495547514a38.pem") certFileName := filepath.Join(cacheDir, "pomerium", "envoy", "files", "tls-crt-354e49305a5a39414a545530374e58454e48334148524c4e324258463837364355564c4e4532464b54355139495547514a38.pem")
@ -509,10 +509,10 @@ func Test_requireProxyProtocol(t *testing.T) {
filemgr: filemgr.NewManager(), filemgr: filemgr.NewManager(),
} }
t.Run("required", func(t *testing.T) { t.Run("required", func(t *testing.T) {
li := srv.buildMainListener(&config.Options{ li := srv.buildMainListener(&config.Config{Options: &config.Options{
UseProxyProtocol: true, UseProxyProtocol: true,
InsecureServer: true, InsecureServer: true,
}) }})
testutil.AssertProtoJSONEqual(t, `[ testutil.AssertProtoJSONEqual(t, `[
{ {
"name": "envoy.filters.listener.proxy_protocol", "name": "envoy.filters.listener.proxy_protocol",
@ -523,10 +523,10 @@ func Test_requireProxyProtocol(t *testing.T) {
]`, li.GetListenerFilters()) ]`, li.GetListenerFilters())
}) })
t.Run("not required", func(t *testing.T) { t.Run("not required", func(t *testing.T) {
li := srv.buildMainListener(&config.Options{ li := srv.buildMainListener(&config.Config{Options: &config.Options{
UseProxyProtocol: false, UseProxyProtocol: false,
InsecureServer: true, InsecureServer: true,
}) }})
assert.Len(t, li.GetListenerFilters(), 0) assert.Len(t, li.GetListenerFilters(), 0)
}) })
} }