mirror of
https://github.com/pomerium/pomerium.git
synced 2025-04-29 10:26:29 +02:00
config: use getters for certificates (#2001)
* config: use getters for certificates * update log message
This commit is contained in:
parent
36eeff296a
commit
853d2dd478
8 changed files with 101 additions and 51 deletions
|
@ -228,21 +228,20 @@ func (a *Authorize) getEvaluatorRequestFromCheckRequest(
|
|||
|
||||
func (a *Authorize) getDownstreamClientCA(policy *config.Policy) (string, error) {
|
||||
options := a.currentOptions.Load()
|
||||
switch {
|
||||
case policy != nil && policy.TLSDownstreamClientCA != "":
|
||||
|
||||
if policy != nil && policy.TLSDownstreamClientCA != "" {
|
||||
bs, err := base64.StdEncoding.DecodeString(policy.TLSDownstreamClientCA)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return string(bs), nil
|
||||
case options.ClientCA != "":
|
||||
bs, err := base64.StdEncoding.DecodeString(options.ClientCA)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return string(bs), nil
|
||||
}
|
||||
return "", nil
|
||||
|
||||
ca, err := options.GetClientCA()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return string(ca), nil
|
||||
}
|
||||
|
||||
func (a *Authorize) getMatchingPolicy(requestURL url.URL) *config.Policy {
|
||||
|
|
|
@ -23,11 +23,16 @@ func (cfg *Config) Clone() *Config {
|
|||
}
|
||||
|
||||
// AllCertificates returns all the certificates in the config.
|
||||
func (cfg *Config) AllCertificates() []tls.Certificate {
|
||||
func (cfg *Config) AllCertificates() ([]tls.Certificate, error) {
|
||||
optionCertificates, err := cfg.Options.GetCertificates()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var certs []tls.Certificate
|
||||
certs = append(certs, cfg.Options.Certificates...)
|
||||
certs = append(certs, optionCertificates...)
|
||||
certs = append(certs, cfg.AutoCertificates...)
|
||||
return certs
|
||||
return certs, nil
|
||||
}
|
||||
|
||||
// Checksum returns the config checksum.
|
||||
|
|
|
@ -216,8 +216,6 @@ func (src *FileWatcherSource) check(cfg *Config) {
|
|||
|
||||
// update the computed config
|
||||
src.computedConfig = cfg.Clone()
|
||||
src.computedConfig.Options.Certificates = nil
|
||||
_ = src.computedConfig.Options.Validate()
|
||||
|
||||
// trigger a change
|
||||
src.Trigger(src.computedConfig)
|
||||
|
|
|
@ -11,7 +11,6 @@ import (
|
|||
"os"
|
||||
"path/filepath"
|
||||
"reflect"
|
||||
"sort"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
@ -95,8 +94,6 @@ type Options struct {
|
|||
CertFile string `mapstructure:"certificate_file" yaml:"certificate_file,omitempty"`
|
||||
KeyFile string `mapstructure:"certificate_key_file" yaml:"certificate_key_file,omitempty"`
|
||||
|
||||
Certificates []tls.Certificate `mapstructure:"-" yaml:"-"`
|
||||
|
||||
// HttpRedirectAddr, if set, specifies the host and port to run the HTTP
|
||||
// to HTTPS redirect server on. If empty, no redirect server is started.
|
||||
HTTPRedirectAddr string `mapstructure:"http_redirect_addr" yaml:"http_redirect_addr,omitempty"`
|
||||
|
@ -255,8 +252,6 @@ type Options struct {
|
|||
DataBrokerStorageCAFile string `mapstructure:"databroker_storage_ca_file" yaml:"databroker_storage_ca_file,omitempty"`
|
||||
DataBrokerStorageCertSkipVerify bool `mapstructure:"databroker_storage_tls_skip_verify" yaml:"databroker_storage_tls_skip_verify,omitempty"`
|
||||
|
||||
DataBrokerCertificate *tls.Certificate `mapstructure:"-" yaml:"-"`
|
||||
|
||||
// ClientCA is the base64-encoded certificate authority to validate client mTLS certificates against.
|
||||
ClientCA string `mapstructure:"client_ca" yaml:"client_ca,omitempty"`
|
||||
// ClientCAFile points to a file that contains the certificate authority to validate client mTLS certificates against.
|
||||
|
@ -594,39 +589,40 @@ func (o *Options) Validate() error {
|
|||
o.Headers = make(map[string]string)
|
||||
}
|
||||
|
||||
hasCert := false
|
||||
|
||||
if o.Cert != "" || o.Key != "" {
|
||||
cert, err := cryptutil.CertificateFromBase64(o.Cert, o.Key)
|
||||
_, err := cryptutil.CertificateFromBase64(o.Cert, o.Key)
|
||||
if err != nil {
|
||||
return fmt.Errorf("config: bad cert base64 %w", err)
|
||||
}
|
||||
o.Certificates = append(o.Certificates, *cert)
|
||||
hasCert = true
|
||||
}
|
||||
|
||||
for _, c := range o.CertificateFiles {
|
||||
cert, err := cryptutil.CertificateFromBase64(c.CertFile, c.KeyFile)
|
||||
_, err := cryptutil.CertificateFromBase64(c.CertFile, c.KeyFile)
|
||||
if err != nil {
|
||||
cert, err = cryptutil.CertificateFromFile(c.CertFile, c.KeyFile)
|
||||
_, err = cryptutil.CertificateFromFile(c.CertFile, c.KeyFile)
|
||||
}
|
||||
if err != nil {
|
||||
return fmt.Errorf("config: bad cert entry, base64 or file reference invalid. %w", err)
|
||||
}
|
||||
o.Certificates = append(o.Certificates, *cert)
|
||||
hasCert = true
|
||||
}
|
||||
|
||||
if o.CertFile != "" || o.KeyFile != "" {
|
||||
cert, err := cryptutil.CertificateFromFile(o.CertFile, o.KeyFile)
|
||||
_, err := cryptutil.CertificateFromFile(o.CertFile, o.KeyFile)
|
||||
if err != nil {
|
||||
return fmt.Errorf("config: bad cert file %w", err)
|
||||
}
|
||||
o.Certificates = append(o.Certificates, *cert)
|
||||
hasCert = true
|
||||
}
|
||||
|
||||
if o.DataBrokerStorageCertFile != "" || o.DataBrokerStorageCertKeyFile != "" {
|
||||
cert, err := cryptutil.CertificateFromFile(o.DataBrokerStorageCertFile, o.DataBrokerStorageCertKeyFile)
|
||||
_, err := cryptutil.CertificateFromFile(o.DataBrokerStorageCertFile, o.DataBrokerStorageCertKeyFile)
|
||||
if err != nil {
|
||||
return fmt.Errorf("config: bad databroker cert file %w", err)
|
||||
}
|
||||
o.DataBrokerCertificate = cert
|
||||
}
|
||||
|
||||
if o.DataBrokerStorageCAFile != "" {
|
||||
|
@ -642,11 +638,10 @@ func (o *Options) Validate() error {
|
|||
}
|
||||
|
||||
if o.ClientCAFile != "" {
|
||||
bs, err := ioutil.ReadFile(o.ClientCAFile)
|
||||
_, err := ioutil.ReadFile(o.ClientCAFile)
|
||||
if err != nil {
|
||||
return fmt.Errorf("config: bad client ca file: %w", err)
|
||||
}
|
||||
o.ClientCA = base64.StdEncoding.EncodeToString(bs)
|
||||
}
|
||||
|
||||
// if no service account was defined, there should not be any policies that
|
||||
|
@ -670,12 +665,7 @@ func (o *Options) Validate() error {
|
|||
// strip quotes from redirect address (#811)
|
||||
o.HTTPRedirectAddr = strings.Trim(o.HTTPRedirectAddr, `"'`)
|
||||
|
||||
// sort the certificates so we get a consistent hash
|
||||
sort.Slice(o.Certificates, func(i, j int) bool {
|
||||
return compareByteSliceSlice(o.Certificates[i].Certificate, o.Certificates[j].Certificate) < 0
|
||||
})
|
||||
|
||||
if !o.InsecureServer && len(o.Certificates) == 0 && !o.AutocertOptions.Enable {
|
||||
if !o.InsecureServer && !hasCert && !o.AutocertOptions.Enable {
|
||||
return fmt.Errorf("config: server must be run with `autocert`, " +
|
||||
"`insecure_server` or manually provided certificates to start")
|
||||
}
|
||||
|
@ -851,6 +841,57 @@ func (o *Options) GetMetricsBasicAuth() (username, password string, ok bool) {
|
|||
return string(bs[:idx]), string(bs[idx+1:]), true
|
||||
}
|
||||
|
||||
// GetClientCA returns the client certificate authority. If neither client_ca nor client_ca_file is specified nil will
|
||||
// be returned.
|
||||
func (o *Options) GetClientCA() ([]byte, error) {
|
||||
if o.ClientCA != "" {
|
||||
return base64.StdEncoding.DecodeString(o.ClientCA)
|
||||
}
|
||||
if o.ClientCAFile != "" {
|
||||
return ioutil.ReadFile(o.ClientCAFile)
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// GetDataBrokerCertificate gets the optional databroker certificate. This method will return nil if no certificate is
|
||||
// specified.
|
||||
func (o *Options) GetDataBrokerCertificate() (*tls.Certificate, error) {
|
||||
if o.DataBrokerStorageCertFile == "" || o.DataBrokerStorageCertKeyFile == "" {
|
||||
return nil, nil
|
||||
}
|
||||
return cryptutil.CertificateFromFile(o.DataBrokerStorageCertFile, o.DataBrokerStorageCertKeyFile)
|
||||
}
|
||||
|
||||
// GetCertificates gets all the certificates from the options.
|
||||
func (o *Options) GetCertificates() ([]tls.Certificate, error) {
|
||||
var certs []tls.Certificate
|
||||
if o.Cert != "" && o.Key != "" {
|
||||
cert, err := cryptutil.CertificateFromBase64(o.Cert, o.Key)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("config: invalid base64 certificate: %w", err)
|
||||
}
|
||||
certs = append(certs, *cert)
|
||||
}
|
||||
for _, c := range o.CertificateFiles {
|
||||
cert, err := cryptutil.CertificateFromBase64(c.CertFile, c.KeyFile)
|
||||
if err != nil {
|
||||
cert, err = cryptutil.CertificateFromFile(c.CertFile, c.KeyFile)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("config: invalid certificate entry: %w", err)
|
||||
}
|
||||
certs = append(certs, *cert)
|
||||
}
|
||||
if o.CertFile != "" && o.KeyFile != "" {
|
||||
cert, err := cryptutil.CertificateFromFile(o.CertFile, o.KeyFile)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("config: bad cert file %w", err)
|
||||
}
|
||||
certs = append(certs, *cert)
|
||||
}
|
||||
return certs, nil
|
||||
}
|
||||
|
||||
// Checksum returns the checksum of the current options struct
|
||||
func (o *Options) Checksum() uint64 {
|
||||
return hashutil.MustHash(o)
|
||||
|
|
|
@ -33,12 +33,13 @@ func (srv *dataBrokerServer) OnConfigChange(cfg *config.Config) {
|
|||
}
|
||||
|
||||
func (srv *dataBrokerServer) getOptions(cfg *config.Config) []databroker.ServerOption {
|
||||
cert, _ := cfg.Options.GetDataBrokerCertificate()
|
||||
return []databroker.ServerOption{
|
||||
databroker.WithSharedKey(cfg.Options.SharedKey),
|
||||
databroker.WithStorageType(cfg.Options.DataBrokerStorageType),
|
||||
databroker.WithStorageConnectionString(cfg.Options.DataBrokerStorageConnectionString),
|
||||
databroker.WithStorageCAFile(cfg.Options.DataBrokerStorageCAFile),
|
||||
databroker.WithStorageCertificate(cfg.Options.DataBrokerCertificate),
|
||||
databroker.WithStorageCertificate(cert),
|
||||
databroker.WithStorageCertSkipVerify(cfg.Options.DataBrokerStorageCertSkipVerify),
|
||||
}
|
||||
}
|
||||
|
|
|
@ -102,8 +102,12 @@ func (mgr *Manager) getCertMagicConfig(cfg *config.Config) (*certmagic.Config, e
|
|||
mgr.certmagic.MustStaple = cfg.Options.AutocertOptions.MustStaple
|
||||
mgr.certmagic.OnDemand = nil // disable on-demand
|
||||
mgr.certmagic.Storage = &certmagic.FileStorage{Path: cfg.Options.AutocertOptions.Folder}
|
||||
certs, err := cfg.AllCertificates()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// add existing certs to the cache, and staple OCSP
|
||||
for _, cert := range cfg.AllCertificates() {
|
||||
for _, cert := range certs {
|
||||
if err := mgr.certmagic.CacheUnmanagedTLSCertificate(cert, nil); err != nil {
|
||||
return nil, fmt.Errorf("config: failed caching cert: %w", err)
|
||||
}
|
||||
|
|
|
@ -631,7 +631,13 @@ func (srv *Server) buildRouteConfiguration(name string, virtualHosts []*envoy_co
|
|||
}
|
||||
|
||||
func (srv *Server) buildDownstreamTLSContext(cfg *config.Config, domain string) *envoy_extensions_transport_sockets_tls_v3.DownstreamTlsContext {
|
||||
cert, err := cryptutil.GetCertificateForDomain(cfg.AllCertificates(), domain)
|
||||
certs, err := cfg.AllCertificates()
|
||||
if err != nil {
|
||||
log.Warn().Str("domain", domain).Err(err).Msg("failed to get all certificates from config")
|
||||
return nil
|
||||
}
|
||||
|
||||
cert, err := cryptutil.GetCertificateForDomain(certs, domain)
|
||||
if err != nil {
|
||||
log.Warn().Str("domain", domain).Err(err).Msg("failed to get certificate for domain")
|
||||
return nil
|
||||
|
@ -792,7 +798,7 @@ func getDownstreamValidationContext(
|
|||
) *envoy_extensions_transport_sockets_tls_v3.CommonTlsContext_ValidationContext {
|
||||
needsClientCert := false
|
||||
|
||||
if cfg.Options.ClientCA != "" {
|
||||
if ca, _ := cfg.Options.GetClientCA(); len(ca) > 0 {
|
||||
needsClientCert = true
|
||||
}
|
||||
if !needsClientCert {
|
||||
|
|
|
@ -1,7 +1,6 @@
|
|||
package controlplane
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
@ -13,7 +12,6 @@ import (
|
|||
"github.com/pomerium/pomerium/config"
|
||||
"github.com/pomerium/pomerium/internal/controlplane/filemgr"
|
||||
"github.com/pomerium/pomerium/internal/testutil"
|
||||
"github.com/pomerium/pomerium/pkg/cryptutil"
|
||||
)
|
||||
|
||||
const (
|
||||
|
@ -469,11 +467,6 @@ func Test_buildMainHTTPConnectionManagerFilter(t *testing.T) {
|
|||
}
|
||||
|
||||
func Test_buildDownstreamTLSContext(t *testing.T) {
|
||||
certA, err := cryptutil.CertificateFromBase64(aExampleComCert, aExampleComKey)
|
||||
if !assert.NoError(t, err) {
|
||||
return
|
||||
}
|
||||
|
||||
srv, _ := NewServer("TEST", nil)
|
||||
|
||||
cacheDir, _ := os.UserCacheDir()
|
||||
|
@ -482,7 +475,8 @@ func Test_buildDownstreamTLSContext(t *testing.T) {
|
|||
|
||||
t.Run("no-validation", func(t *testing.T) {
|
||||
downstreamTLSContext := srv.buildDownstreamTLSContext(&config.Config{Options: &config.Options{
|
||||
Certificates: []tls.Certificate{*certA},
|
||||
Cert: aExampleComCert,
|
||||
Key: aExampleComKey,
|
||||
}}, "a.example.com")
|
||||
|
||||
testutil.AssertProtoJSONEqual(t, `{
|
||||
|
@ -514,8 +508,9 @@ func Test_buildDownstreamTLSContext(t *testing.T) {
|
|||
})
|
||||
t.Run("client-ca", func(t *testing.T) {
|
||||
downstreamTLSContext := srv.buildDownstreamTLSContext(&config.Config{Options: &config.Options{
|
||||
Certificates: []tls.Certificate{*certA},
|
||||
ClientCA: "TEST",
|
||||
Cert: aExampleComCert,
|
||||
Key: aExampleComKey,
|
||||
ClientCA: "TEST",
|
||||
}}, "a.example.com")
|
||||
|
||||
testutil.AssertProtoJSONEqual(t, `{
|
||||
|
@ -550,7 +545,8 @@ func Test_buildDownstreamTLSContext(t *testing.T) {
|
|||
})
|
||||
t.Run("policy-client-ca", func(t *testing.T) {
|
||||
downstreamTLSContext := srv.buildDownstreamTLSContext(&config.Config{Options: &config.Options{
|
||||
Certificates: []tls.Certificate{*certA},
|
||||
Cert: aExampleComCert,
|
||||
Key: aExampleComKey,
|
||||
Policies: []config.Policy{
|
||||
{
|
||||
Source: &config.StringURL{URL: mustParseURL(t, "https://a.example.com")},
|
||||
|
|
Loading…
Add table
Reference in a new issue