config: use getters for certificates (#2001)

* config: use getters for certificates

* update log message
This commit is contained in:
Caleb Doxsey 2021-03-23 08:02:50 -06:00 committed by GitHub
parent 36eeff296a
commit 853d2dd478
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
8 changed files with 101 additions and 51 deletions

View file

@ -11,7 +11,6 @@ import (
"os"
"path/filepath"
"reflect"
"sort"
"strings"
"sync/atomic"
"time"
@ -95,8 +94,6 @@ type Options struct {
CertFile string `mapstructure:"certificate_file" yaml:"certificate_file,omitempty"`
KeyFile string `mapstructure:"certificate_key_file" yaml:"certificate_key_file,omitempty"`
Certificates []tls.Certificate `mapstructure:"-" yaml:"-"`
// HttpRedirectAddr, if set, specifies the host and port to run the HTTP
// to HTTPS redirect server on. If empty, no redirect server is started.
HTTPRedirectAddr string `mapstructure:"http_redirect_addr" yaml:"http_redirect_addr,omitempty"`
@ -255,8 +252,6 @@ type Options struct {
DataBrokerStorageCAFile string `mapstructure:"databroker_storage_ca_file" yaml:"databroker_storage_ca_file,omitempty"`
DataBrokerStorageCertSkipVerify bool `mapstructure:"databroker_storage_tls_skip_verify" yaml:"databroker_storage_tls_skip_verify,omitempty"`
DataBrokerCertificate *tls.Certificate `mapstructure:"-" yaml:"-"`
// ClientCA is the base64-encoded certificate authority to validate client mTLS certificates against.
ClientCA string `mapstructure:"client_ca" yaml:"client_ca,omitempty"`
// ClientCAFile points to a file that contains the certificate authority to validate client mTLS certificates against.
@ -594,39 +589,40 @@ func (o *Options) Validate() error {
o.Headers = make(map[string]string)
}
hasCert := false
if o.Cert != "" || o.Key != "" {
cert, err := cryptutil.CertificateFromBase64(o.Cert, o.Key)
_, err := cryptutil.CertificateFromBase64(o.Cert, o.Key)
if err != nil {
return fmt.Errorf("config: bad cert base64 %w", err)
}
o.Certificates = append(o.Certificates, *cert)
hasCert = true
}
for _, c := range o.CertificateFiles {
cert, err := cryptutil.CertificateFromBase64(c.CertFile, c.KeyFile)
_, err := cryptutil.CertificateFromBase64(c.CertFile, c.KeyFile)
if err != nil {
cert, err = cryptutil.CertificateFromFile(c.CertFile, c.KeyFile)
_, err = cryptutil.CertificateFromFile(c.CertFile, c.KeyFile)
}
if err != nil {
return fmt.Errorf("config: bad cert entry, base64 or file reference invalid. %w", err)
}
o.Certificates = append(o.Certificates, *cert)
hasCert = true
}
if o.CertFile != "" || o.KeyFile != "" {
cert, err := cryptutil.CertificateFromFile(o.CertFile, o.KeyFile)
_, err := cryptutil.CertificateFromFile(o.CertFile, o.KeyFile)
if err != nil {
return fmt.Errorf("config: bad cert file %w", err)
}
o.Certificates = append(o.Certificates, *cert)
hasCert = true
}
if o.DataBrokerStorageCertFile != "" || o.DataBrokerStorageCertKeyFile != "" {
cert, err := cryptutil.CertificateFromFile(o.DataBrokerStorageCertFile, o.DataBrokerStorageCertKeyFile)
_, err := cryptutil.CertificateFromFile(o.DataBrokerStorageCertFile, o.DataBrokerStorageCertKeyFile)
if err != nil {
return fmt.Errorf("config: bad databroker cert file %w", err)
}
o.DataBrokerCertificate = cert
}
if o.DataBrokerStorageCAFile != "" {
@ -642,11 +638,10 @@ func (o *Options) Validate() error {
}
if o.ClientCAFile != "" {
bs, err := ioutil.ReadFile(o.ClientCAFile)
_, err := ioutil.ReadFile(o.ClientCAFile)
if err != nil {
return fmt.Errorf("config: bad client ca file: %w", err)
}
o.ClientCA = base64.StdEncoding.EncodeToString(bs)
}
// if no service account was defined, there should not be any policies that
@ -670,12 +665,7 @@ func (o *Options) Validate() error {
// strip quotes from redirect address (#811)
o.HTTPRedirectAddr = strings.Trim(o.HTTPRedirectAddr, `"'`)
// sort the certificates so we get a consistent hash
sort.Slice(o.Certificates, func(i, j int) bool {
return compareByteSliceSlice(o.Certificates[i].Certificate, o.Certificates[j].Certificate) < 0
})
if !o.InsecureServer && len(o.Certificates) == 0 && !o.AutocertOptions.Enable {
if !o.InsecureServer && !hasCert && !o.AutocertOptions.Enable {
return fmt.Errorf("config: server must be run with `autocert`, " +
"`insecure_server` or manually provided certificates to start")
}
@ -851,6 +841,57 @@ func (o *Options) GetMetricsBasicAuth() (username, password string, ok bool) {
return string(bs[:idx]), string(bs[idx+1:]), true
}
// GetClientCA returns the client certificate authority. If neither client_ca nor client_ca_file is specified nil will
// be returned.
func (o *Options) GetClientCA() ([]byte, error) {
if o.ClientCA != "" {
return base64.StdEncoding.DecodeString(o.ClientCA)
}
if o.ClientCAFile != "" {
return ioutil.ReadFile(o.ClientCAFile)
}
return nil, nil
}
// GetDataBrokerCertificate gets the optional databroker certificate. This method will return nil if no certificate is
// specified.
func (o *Options) GetDataBrokerCertificate() (*tls.Certificate, error) {
if o.DataBrokerStorageCertFile == "" || o.DataBrokerStorageCertKeyFile == "" {
return nil, nil
}
return cryptutil.CertificateFromFile(o.DataBrokerStorageCertFile, o.DataBrokerStorageCertKeyFile)
}
// GetCertificates gets all the certificates from the options.
func (o *Options) GetCertificates() ([]tls.Certificate, error) {
var certs []tls.Certificate
if o.Cert != "" && o.Key != "" {
cert, err := cryptutil.CertificateFromBase64(o.Cert, o.Key)
if err != nil {
return nil, fmt.Errorf("config: invalid base64 certificate: %w", err)
}
certs = append(certs, *cert)
}
for _, c := range o.CertificateFiles {
cert, err := cryptutil.CertificateFromBase64(c.CertFile, c.KeyFile)
if err != nil {
cert, err = cryptutil.CertificateFromFile(c.CertFile, c.KeyFile)
}
if err != nil {
return nil, fmt.Errorf("config: invalid certificate entry: %w", err)
}
certs = append(certs, *cert)
}
if o.CertFile != "" && o.KeyFile != "" {
cert, err := cryptutil.CertificateFromFile(o.CertFile, o.KeyFile)
if err != nil {
return nil, fmt.Errorf("config: bad cert file %w", err)
}
certs = append(certs, *cert)
}
return certs, nil
}
// Checksum returns the checksum of the current options struct
func (o *Options) Checksum() uint64 {
return hashutil.MustHash(o)