diff --git a/authorize/grpc.go b/authorize/grpc.go index 36ba20bb3..f9e19d1e3 100644 --- a/authorize/grpc.go +++ b/authorize/grpc.go @@ -228,21 +228,20 @@ func (a *Authorize) getEvaluatorRequestFromCheckRequest( func (a *Authorize) getDownstreamClientCA(policy *config.Policy) (string, error) { options := a.currentOptions.Load() - switch { - case policy != nil && policy.TLSDownstreamClientCA != "": + + if policy != nil && policy.TLSDownstreamClientCA != "" { bs, err := base64.StdEncoding.DecodeString(policy.TLSDownstreamClientCA) if err != nil { return "", err } return string(bs), nil - case options.ClientCA != "": - bs, err := base64.StdEncoding.DecodeString(options.ClientCA) - if err != nil { - return "", err - } - return string(bs), nil } - return "", nil + + ca, err := options.GetClientCA() + if err != nil { + return "", err + } + return string(ca), nil } func (a *Authorize) getMatchingPolicy(requestURL url.URL) *config.Policy { diff --git a/config/config.go b/config/config.go index aaabb592f..2f7609085 100644 --- a/config/config.go +++ b/config/config.go @@ -23,11 +23,16 @@ func (cfg *Config) Clone() *Config { } // AllCertificates returns all the certificates in the config. -func (cfg *Config) AllCertificates() []tls.Certificate { +func (cfg *Config) AllCertificates() ([]tls.Certificate, error) { + optionCertificates, err := cfg.Options.GetCertificates() + if err != nil { + return nil, err + } + var certs []tls.Certificate - certs = append(certs, cfg.Options.Certificates...) + certs = append(certs, optionCertificates...) certs = append(certs, cfg.AutoCertificates...) - return certs + return certs, nil } // Checksum returns the config checksum. diff --git a/config/config_source.go b/config/config_source.go index 3bc1edafa..10e5e1757 100644 --- a/config/config_source.go +++ b/config/config_source.go @@ -216,8 +216,6 @@ func (src *FileWatcherSource) check(cfg *Config) { // update the computed config src.computedConfig = cfg.Clone() - src.computedConfig.Options.Certificates = nil - _ = src.computedConfig.Options.Validate() // trigger a change src.Trigger(src.computedConfig) diff --git a/config/options.go b/config/options.go index eac06d0f2..8ca1549f6 100644 --- a/config/options.go +++ b/config/options.go @@ -11,7 +11,6 @@ import ( "os" "path/filepath" "reflect" - "sort" "strings" "sync/atomic" "time" @@ -95,8 +94,6 @@ type Options struct { CertFile string `mapstructure:"certificate_file" yaml:"certificate_file,omitempty"` KeyFile string `mapstructure:"certificate_key_file" yaml:"certificate_key_file,omitempty"` - Certificates []tls.Certificate `mapstructure:"-" yaml:"-"` - // HttpRedirectAddr, if set, specifies the host and port to run the HTTP // to HTTPS redirect server on. If empty, no redirect server is started. HTTPRedirectAddr string `mapstructure:"http_redirect_addr" yaml:"http_redirect_addr,omitempty"` @@ -255,8 +252,6 @@ type Options struct { DataBrokerStorageCAFile string `mapstructure:"databroker_storage_ca_file" yaml:"databroker_storage_ca_file,omitempty"` DataBrokerStorageCertSkipVerify bool `mapstructure:"databroker_storage_tls_skip_verify" yaml:"databroker_storage_tls_skip_verify,omitempty"` - DataBrokerCertificate *tls.Certificate `mapstructure:"-" yaml:"-"` - // ClientCA is the base64-encoded certificate authority to validate client mTLS certificates against. ClientCA string `mapstructure:"client_ca" yaml:"client_ca,omitempty"` // ClientCAFile points to a file that contains the certificate authority to validate client mTLS certificates against. @@ -594,39 +589,40 @@ func (o *Options) Validate() error { o.Headers = make(map[string]string) } + hasCert := false + if o.Cert != "" || o.Key != "" { - cert, err := cryptutil.CertificateFromBase64(o.Cert, o.Key) + _, err := cryptutil.CertificateFromBase64(o.Cert, o.Key) if err != nil { return fmt.Errorf("config: bad cert base64 %w", err) } - o.Certificates = append(o.Certificates, *cert) + hasCert = true } for _, c := range o.CertificateFiles { - cert, err := cryptutil.CertificateFromBase64(c.CertFile, c.KeyFile) + _, err := cryptutil.CertificateFromBase64(c.CertFile, c.KeyFile) if err != nil { - cert, err = cryptutil.CertificateFromFile(c.CertFile, c.KeyFile) + _, err = cryptutil.CertificateFromFile(c.CertFile, c.KeyFile) } if err != nil { return fmt.Errorf("config: bad cert entry, base64 or file reference invalid. %w", err) } - o.Certificates = append(o.Certificates, *cert) + hasCert = true } if o.CertFile != "" || o.KeyFile != "" { - cert, err := cryptutil.CertificateFromFile(o.CertFile, o.KeyFile) + _, err := cryptutil.CertificateFromFile(o.CertFile, o.KeyFile) if err != nil { return fmt.Errorf("config: bad cert file %w", err) } - o.Certificates = append(o.Certificates, *cert) + hasCert = true } if o.DataBrokerStorageCertFile != "" || o.DataBrokerStorageCertKeyFile != "" { - cert, err := cryptutil.CertificateFromFile(o.DataBrokerStorageCertFile, o.DataBrokerStorageCertKeyFile) + _, err := cryptutil.CertificateFromFile(o.DataBrokerStorageCertFile, o.DataBrokerStorageCertKeyFile) if err != nil { return fmt.Errorf("config: bad databroker cert file %w", err) } - o.DataBrokerCertificate = cert } if o.DataBrokerStorageCAFile != "" { @@ -642,11 +638,10 @@ func (o *Options) Validate() error { } if o.ClientCAFile != "" { - bs, err := ioutil.ReadFile(o.ClientCAFile) + _, err := ioutil.ReadFile(o.ClientCAFile) if err != nil { return fmt.Errorf("config: bad client ca file: %w", err) } - o.ClientCA = base64.StdEncoding.EncodeToString(bs) } // if no service account was defined, there should not be any policies that @@ -670,12 +665,7 @@ func (o *Options) Validate() error { // strip quotes from redirect address (#811) o.HTTPRedirectAddr = strings.Trim(o.HTTPRedirectAddr, `"'`) - // sort the certificates so we get a consistent hash - sort.Slice(o.Certificates, func(i, j int) bool { - return compareByteSliceSlice(o.Certificates[i].Certificate, o.Certificates[j].Certificate) < 0 - }) - - if !o.InsecureServer && len(o.Certificates) == 0 && !o.AutocertOptions.Enable { + if !o.InsecureServer && !hasCert && !o.AutocertOptions.Enable { return fmt.Errorf("config: server must be run with `autocert`, " + "`insecure_server` or manually provided certificates to start") } @@ -851,6 +841,57 @@ func (o *Options) GetMetricsBasicAuth() (username, password string, ok bool) { return string(bs[:idx]), string(bs[idx+1:]), true } +// GetClientCA returns the client certificate authority. If neither client_ca nor client_ca_file is specified nil will +// be returned. +func (o *Options) GetClientCA() ([]byte, error) { + if o.ClientCA != "" { + return base64.StdEncoding.DecodeString(o.ClientCA) + } + if o.ClientCAFile != "" { + return ioutil.ReadFile(o.ClientCAFile) + } + return nil, nil +} + +// GetDataBrokerCertificate gets the optional databroker certificate. This method will return nil if no certificate is +// specified. +func (o *Options) GetDataBrokerCertificate() (*tls.Certificate, error) { + if o.DataBrokerStorageCertFile == "" || o.DataBrokerStorageCertKeyFile == "" { + return nil, nil + } + return cryptutil.CertificateFromFile(o.DataBrokerStorageCertFile, o.DataBrokerStorageCertKeyFile) +} + +// GetCertificates gets all the certificates from the options. +func (o *Options) GetCertificates() ([]tls.Certificate, error) { + var certs []tls.Certificate + if o.Cert != "" && o.Key != "" { + cert, err := cryptutil.CertificateFromBase64(o.Cert, o.Key) + if err != nil { + return nil, fmt.Errorf("config: invalid base64 certificate: %w", err) + } + certs = append(certs, *cert) + } + for _, c := range o.CertificateFiles { + cert, err := cryptutil.CertificateFromBase64(c.CertFile, c.KeyFile) + if err != nil { + cert, err = cryptutil.CertificateFromFile(c.CertFile, c.KeyFile) + } + if err != nil { + return nil, fmt.Errorf("config: invalid certificate entry: %w", err) + } + certs = append(certs, *cert) + } + if o.CertFile != "" && o.KeyFile != "" { + cert, err := cryptutil.CertificateFromFile(o.CertFile, o.KeyFile) + if err != nil { + return nil, fmt.Errorf("config: bad cert file %w", err) + } + certs = append(certs, *cert) + } + return certs, nil +} + // Checksum returns the checksum of the current options struct func (o *Options) Checksum() uint64 { return hashutil.MustHash(o) diff --git a/databroker/databroker.go b/databroker/databroker.go index 329f6c3cd..8ecf763ee 100644 --- a/databroker/databroker.go +++ b/databroker/databroker.go @@ -33,12 +33,13 @@ func (srv *dataBrokerServer) OnConfigChange(cfg *config.Config) { } func (srv *dataBrokerServer) getOptions(cfg *config.Config) []databroker.ServerOption { + cert, _ := cfg.Options.GetDataBrokerCertificate() return []databroker.ServerOption{ databroker.WithSharedKey(cfg.Options.SharedKey), databroker.WithStorageType(cfg.Options.DataBrokerStorageType), databroker.WithStorageConnectionString(cfg.Options.DataBrokerStorageConnectionString), databroker.WithStorageCAFile(cfg.Options.DataBrokerStorageCAFile), - databroker.WithStorageCertificate(cfg.Options.DataBrokerCertificate), + databroker.WithStorageCertificate(cert), databroker.WithStorageCertSkipVerify(cfg.Options.DataBrokerStorageCertSkipVerify), } } diff --git a/internal/autocert/manager.go b/internal/autocert/manager.go index 4b5b2ebca..2b40b131b 100644 --- a/internal/autocert/manager.go +++ b/internal/autocert/manager.go @@ -102,8 +102,12 @@ func (mgr *Manager) getCertMagicConfig(cfg *config.Config) (*certmagic.Config, e mgr.certmagic.MustStaple = cfg.Options.AutocertOptions.MustStaple mgr.certmagic.OnDemand = nil // disable on-demand mgr.certmagic.Storage = &certmagic.FileStorage{Path: cfg.Options.AutocertOptions.Folder} + certs, err := cfg.AllCertificates() + if err != nil { + return nil, err + } // add existing certs to the cache, and staple OCSP - for _, cert := range cfg.AllCertificates() { + for _, cert := range certs { if err := mgr.certmagic.CacheUnmanagedTLSCertificate(cert, nil); err != nil { return nil, fmt.Errorf("config: failed caching cert: %w", err) } diff --git a/internal/controlplane/xds_listeners.go b/internal/controlplane/xds_listeners.go index ac43ed88f..aea84066b 100644 --- a/internal/controlplane/xds_listeners.go +++ b/internal/controlplane/xds_listeners.go @@ -631,7 +631,13 @@ func (srv *Server) buildRouteConfiguration(name string, virtualHosts []*envoy_co } func (srv *Server) buildDownstreamTLSContext(cfg *config.Config, domain string) *envoy_extensions_transport_sockets_tls_v3.DownstreamTlsContext { - cert, err := cryptutil.GetCertificateForDomain(cfg.AllCertificates(), domain) + certs, err := cfg.AllCertificates() + if err != nil { + log.Warn().Str("domain", domain).Err(err).Msg("failed to get all certificates from config") + return nil + } + + cert, err := cryptutil.GetCertificateForDomain(certs, domain) if err != nil { log.Warn().Str("domain", domain).Err(err).Msg("failed to get certificate for domain") return nil @@ -792,7 +798,7 @@ func getDownstreamValidationContext( ) *envoy_extensions_transport_sockets_tls_v3.CommonTlsContext_ValidationContext { needsClientCert := false - if cfg.Options.ClientCA != "" { + if ca, _ := cfg.Options.GetClientCA(); len(ca) > 0 { needsClientCert = true } if !needsClientCert { diff --git a/internal/controlplane/xds_listeners_test.go b/internal/controlplane/xds_listeners_test.go index 7f3066559..88e185de8 100644 --- a/internal/controlplane/xds_listeners_test.go +++ b/internal/controlplane/xds_listeners_test.go @@ -1,7 +1,6 @@ package controlplane import ( - "crypto/tls" "os" "path/filepath" "testing" @@ -13,7 +12,6 @@ import ( "github.com/pomerium/pomerium/config" "github.com/pomerium/pomerium/internal/controlplane/filemgr" "github.com/pomerium/pomerium/internal/testutil" - "github.com/pomerium/pomerium/pkg/cryptutil" ) const ( @@ -469,11 +467,6 @@ func Test_buildMainHTTPConnectionManagerFilter(t *testing.T) { } func Test_buildDownstreamTLSContext(t *testing.T) { - certA, err := cryptutil.CertificateFromBase64(aExampleComCert, aExampleComKey) - if !assert.NoError(t, err) { - return - } - srv, _ := NewServer("TEST", nil) cacheDir, _ := os.UserCacheDir() @@ -482,7 +475,8 @@ func Test_buildDownstreamTLSContext(t *testing.T) { t.Run("no-validation", func(t *testing.T) { downstreamTLSContext := srv.buildDownstreamTLSContext(&config.Config{Options: &config.Options{ - Certificates: []tls.Certificate{*certA}, + Cert: aExampleComCert, + Key: aExampleComKey, }}, "a.example.com") testutil.AssertProtoJSONEqual(t, `{ @@ -514,8 +508,9 @@ func Test_buildDownstreamTLSContext(t *testing.T) { }) t.Run("client-ca", func(t *testing.T) { downstreamTLSContext := srv.buildDownstreamTLSContext(&config.Config{Options: &config.Options{ - Certificates: []tls.Certificate{*certA}, - ClientCA: "TEST", + Cert: aExampleComCert, + Key: aExampleComKey, + ClientCA: "TEST", }}, "a.example.com") testutil.AssertProtoJSONEqual(t, `{ @@ -550,7 +545,8 @@ func Test_buildDownstreamTLSContext(t *testing.T) { }) t.Run("policy-client-ca", func(t *testing.T) { downstreamTLSContext := srv.buildDownstreamTLSContext(&config.Config{Options: &config.Options{ - Certificates: []tls.Certificate{*certA}, + Cert: aExampleComCert, + Key: aExampleComKey, Policies: []config.Policy{ { Source: &config.StringURL{URL: mustParseURL(t, "https://a.example.com")},