From dccec1e646b32a725cb8eae26df06c56a08aa2d6 Mon Sep 17 00:00:00 2001 From: Caleb Doxsey Date: Wed, 13 May 2020 13:07:04 -0600 Subject: [PATCH] envoy: support autocert (#695) * envoy: support autocert * envoy: fallback to http host routing if sni fails to match * update comment * envoy: renew certs when necessary * fix tests --- cmd/pomerium/main.go | 78 ++----- cmd/pomerium/main_test.go | 68 +----- config/autocert.go | 102 +++++++++ config/options.go | 128 ++++++++--- config/options_test.go | 67 +++++- config/policy.go | 2 +- config/redirect.go | 60 ++++++ go.mod | 1 - internal/controlplane/grpc_xds.go | 2 +- internal/controlplane/xds.go | 16 +- internal/controlplane/xds_clusters.go | 2 +- internal/controlplane/xds_listeners.go | 273 +++++++++++++++--------- internal/controlplane/xds_routes.go | 4 +- internal/cryptutil/certificates.go | 66 +++++- internal/cryptutil/certificates_test.go | 4 +- internal/cryptutil/tls.go | 112 +++------- internal/cryptutil/tls_test.go | 92 ++++---- internal/envoy/envoy.go | 3 +- 18 files changed, 689 insertions(+), 391 deletions(-) create mode 100644 config/autocert.go create mode 100644 config/redirect.go diff --git a/cmd/pomerium/main.go b/cmd/pomerium/main.go index 62584949c..6b7c49b10 100644 --- a/cmd/pomerium/main.go +++ b/cmd/pomerium/main.go @@ -5,8 +5,10 @@ import ( "flag" "fmt" "net" + "os" + "os/signal" "sync" - "time" + "syscall" "github.com/pomerium/pomerium/authenticate" "github.com/pomerium/pomerium/authorize" @@ -24,7 +26,6 @@ import ( "github.com/pomerium/pomerium/proxy" envoy_service_auth_v2 "github.com/envoyproxy/go-control-plane/envoy/service/auth/v2" - "github.com/fsnotify/fsnotify" "golang.org/x/sync/errgroup" ) @@ -32,12 +33,12 @@ var versionFlag = flag.Bool("version", false, "prints the version") var configFile = flag.String("config", "", "Specify configuration file location") func main() { - if err := run(); err != nil { + if err := run(context.Background()); err != nil { log.Fatal().Err(err).Msg("cmd/pomerium") } } -func run() error { +func run(ctx context.Context) error { flag.Parse() if *versionFlag { fmt.Println(version.FullVersion()) @@ -51,18 +52,12 @@ func run() error { log.Info().Str("version", version.FullVersion()).Msg("cmd/pomerium") - var wg sync.WaitGroup - if err := setupMetrics(opt, &wg); err != nil { + if err := setupMetrics(opt); err != nil { return err } if err := setupTracing(opt); err != nil { return err } - if err := setupHTTPRedirectServer(opt, &wg); err != nil { - return err - } - - ctx := context.Background() // setup the control plane controlPlane, err := controlplane.NewServer() @@ -99,17 +94,19 @@ func run() error { } // start the config change listener - opt.OnConfigChange(func(e fsnotify.Event) { - log.Info().Str("file", e.Name).Msg("cmd/pomerium: config file changed") - opt = config.HandleConfigUpdate(*configFile, opt, optionsUpdaters) - }) + go config.WatchChanges(*configFile, opt, optionsUpdaters) + + ctx, cancel := context.WithCancel(ctx) + go func() { + ch := make(chan os.Signal, 2) + signal.Notify(ch, os.Interrupt) + signal.Notify(ch, syscall.SIGTERM) + <-ch + cancel() + }() // run everything eg, ctx := errgroup.WithContext(ctx) - eg.Go(func() error { - wg.Wait() - return nil - }) eg.Go(func() error { return controlPlane.Run(ctx) }) @@ -172,7 +169,7 @@ func setupCache(opt *config.Options, controlPlane *controlplane.Server) error { return nil } -func setupMetrics(opt *config.Options, wg *sync.WaitGroup) error { +func setupMetrics(opt *config.Options) error { if opt.MetricsAddr != "" { handler, err := metrics.PrometheusHandler() if err != nil { @@ -185,11 +182,11 @@ func setupMetrics(opt *config.Options, wg *sync.WaitGroup) error { Insecure: true, Service: "metrics", } - srv, err := httputil.NewServer(serverOpts, handler, wg) + var wg sync.WaitGroup + _, err = httputil.NewServer(serverOpts, handler, &wg) if err != nil { return err } - go httputil.Shutdown(srv) } return nil } @@ -222,40 +219,3 @@ func setupTracing(opt *config.Options) error { } return nil } - -func setupHTTPRedirectServer(opt *config.Options, wg *sync.WaitGroup) error { - if opt.HTTPRedirectAddr != "" { - serverOpts := httputil.ServerOptions{ - Addr: opt.HTTPRedirectAddr, - Insecure: true, - Service: "HTTP->HTTPS Redirect", - ReadHeaderTimeout: 5 * time.Second, - ReadTimeout: 5 * time.Second, - WriteTimeout: 5 * time.Second, - IdleTimeout: 5 * time.Second, - } - h := httputil.RedirectHandler() - if opt.AutoCert { - h = opt.AutoCertHandler(h) - } - srv, err := httputil.NewServer(&serverOpts, h, wg) - if err != nil { - return err - } - go httputil.Shutdown(srv) - } - return nil -} - -func httpServerOptions(opt *config.Options) *httputil.ServerOptions { - return &httputil.ServerOptions{ - Addr: opt.Addr, - TLSConfig: opt.TLSConfig, - Insecure: opt.InsecureServer, - ReadTimeout: opt.ReadTimeout, - WriteTimeout: opt.WriteTimeout, - ReadHeaderTimeout: opt.ReadHeaderTimeout, - IdleTimeout: opt.IdleTimeout, - Service: opt.Services, - } -} diff --git a/cmd/pomerium/main_test.go b/cmd/pomerium/main_test.go index c60117353..21c317dfa 100644 --- a/cmd/pomerium/main_test.go +++ b/cmd/pomerium/main_test.go @@ -1,36 +1,17 @@ package main import ( + "context" "io/ioutil" "os" "os/signal" - "sync" "syscall" "testing" "time" - "github.com/google/go-cmp/cmp" "github.com/pomerium/pomerium/config" - "github.com/pomerium/pomerium/internal/httputil" ) -func Test_httpServerOptions(t *testing.T) { - tests := []struct { - name string - opt *config.Options - want *httputil.ServerOptions - }{ - {"simple convert", &config.Options{Addr: ":80"}, &httputil.ServerOptions{Addr: ":80"}}, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - if diff := cmp.Diff(httpServerOptions(tt.opt), tt.want); diff != "" { - t.Errorf("httpServerOptions() = \n %s", diff) - } - }) - } -} - func Test_setupTracing(t *testing.T) { tests := []struct { name string @@ -60,41 +41,9 @@ func Test_setupMetrics(t *testing.T) { c := make(chan os.Signal, 1) signal.Notify(c, syscall.SIGINT) defer signal.Stop(c) - var wg sync.WaitGroup - - setupMetrics(tt.opt, &wg) + setupMetrics(tt.opt) syscall.Kill(syscall.Getpid(), syscall.SIGINT) waitSig(t, c, syscall.SIGINT) - - }) - } -} - -func Test_setupHTTPRedirectServer(t *testing.T) { - tests := []struct { - name string - opt *config.Options - wantErr bool - }{ - {"dont register anything", &config.Options{}, false}, - {"good redirect server", &config.Options{HTTPRedirectAddr: "localhost:0"}, false}, - {"bad redirect server port", &config.Options{HTTPRedirectAddr: "localhost:-1"}, true}, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - c := make(chan os.Signal, 1) - var wg sync.WaitGroup - - signal.Notify(c, syscall.SIGINT) - defer signal.Stop(c) - err := setupHTTPRedirectServer(tt.opt, &wg) - if (err != nil) != tt.wantErr { - t.Errorf("run() error = %v, wantErr %v", err, tt.wantErr) - } - - syscall.Kill(syscall.Getpid(), syscall.SIGINT) - waitSig(t, c, syscall.SIGINT) - }) } } @@ -274,16 +223,11 @@ func Test_run(t *testing.T) { t.Fatal(err) } configFile = &fn - proc, err := os.FindProcess(os.Getpid()) - if err != nil { - t.Fatal(err) - } - go func() { - time.Sleep(time.Millisecond * 500) - proc.Signal(os.Interrupt) - }() - err = run() + ctx, clearTimeout := context.WithTimeout(context.Background(), 500*time.Millisecond) + defer clearTimeout() + + err = run(ctx) if (err != nil) != tt.wantErr { t.Errorf("run() error = %v, wantErr %v", err, tt.wantErr) } diff --git a/config/autocert.go b/config/autocert.go new file mode 100644 index 000000000..ec18225ec --- /dev/null +++ b/config/autocert.go @@ -0,0 +1,102 @@ +package config + +import ( + "context" + "fmt" + "net/http" + "sync" + + "github.com/caddyserver/certmagic" + + "github.com/pomerium/pomerium/internal/log" +) + +// AutocertManager manages Let's Encrypt certificates based on configuration options. +var AutocertManager = newAutocertManager() + +type autocertManager struct { + mu sync.RWMutex + certmagic *certmagic.Config + acmeMgr *certmagic.ACMEManager +} + +func newAutocertManager() *autocertManager { + mgr := &autocertManager{} + return mgr +} + +func (mgr *autocertManager) getConfig(options *Options) (*certmagic.Config, error) { + mgr.mu.Lock() + defer mgr.mu.Unlock() + + cm := mgr.certmagic + if cm == nil { + cm = certmagic.NewDefault() + } + + cm.OnDemand = nil // disable on-demand + cm.Storage = &certmagic.FileStorage{Path: options.AutoCertFolder} + // add existing certs to the cache, and staple OCSP + for _, cert := range options.Certificates { + if err := cm.CacheUnmanagedTLSCertificate(cert, nil); err != nil { + return nil, fmt.Errorf("config: failed caching cert: %w", err) + } + } + acmeMgr := certmagic.NewACMEManager(cm, certmagic.DefaultACME) + acmeMgr.Agreed = true + if options.AutoCertUseStaging { + acmeMgr.CA = certmagic.LetsEncryptStagingCA + } + acmeMgr.DisableTLSALPNChallenge = true + cm.Issuer = acmeMgr + mgr.acmeMgr = acmeMgr + + return cm, nil +} + +func (mgr *autocertManager) update(options *Options) error { + if !options.AutoCert { + return nil + } + + cm, err := mgr.getConfig(options) + if err != nil { + return err + } + + for _, domain := range options.sourceHostnames() { + cert, err := cm.CacheManagedCertificate(domain) + if err != nil { + log.Info().Str("domain", domain).Msg("obtaining certificate") + err = cm.ObtainCert(context.Background(), domain, false) + if err != nil { + return fmt.Errorf("config: failed to obtain client certificate: %w", err) + } + cert, err = cm.CacheManagedCertificate(domain) + } + if err == nil && cert.NeedsRenewal(cm) { + log.Info().Str("domain", domain).Msg("renewing certificate") + err = cm.RenewCert(context.Background(), domain, false) + if err != nil { + return fmt.Errorf("config: failed to renew client certificate: %w", err) + } + cert, err = cm.CacheManagedCertificate(domain) + } + if err == nil { + options.Certificates = append(options.Certificates, cert.Certificate) + } else { + log.Error().Err(err).Msg("config: failed to obtain client certificate") + } + } + return nil +} + +func (mgr *autocertManager) HandleHTTPChallenge(w http.ResponseWriter, r *http.Request) bool { + mgr.mu.RLock() + acmeMgr := mgr.acmeMgr + mgr.mu.RUnlock() + if acmeMgr == nil { + return false + } + return acmeMgr.HandleHTTPChallenge(w, r) +} diff --git a/config/options.go b/config/options.go index 609e60e47..88628a3d1 100644 --- a/config/options.go +++ b/config/options.go @@ -1,15 +1,16 @@ package config import ( + "bytes" "crypto/tls" "encoding/base64" "errors" "fmt" - "net/http" "net/url" "os" "path/filepath" "reflect" + "sort" "strings" "time" @@ -68,10 +69,6 @@ type Options struct { // and renewal from LetsEncrypt. Must be used in conjunction with AutoCertFolder. AutoCert bool `mapstructure:"autocert" yaml:"autocert,omitempty"` - // AutoCertHandler is the HTTP challenge handler used in a http-01 acme - // https://letsencrypt.org/docs/challenge-types/#http-01-challenge - AutoCertHandler func(h http.Handler) http.Handler `hash:"ignore"` - // AutoCertFolder specifies the location to store, and load autocert managed // TLS certificates. // defaults to $XDG_DATA_HOME/pomerium @@ -83,7 +80,7 @@ type Options struct { // https://letsencrypt.org/docs/staging-environment/ AutoCertUseStaging bool `mapstructure:"autocert_use_staging" yaml:"autocert_use_staging,omitempty"` - Certificates []certificateFilePair `mapstructure:"certificates" yaml:"certificates,omitempty"` + CertificateFiles []certificateFilePair `mapstructure:"certificates" yaml:"certificates,omitempty"` // Cert and Key is the x509 certificate used to create the HTTPS server. Cert string `mapstructure:"certificate" yaml:"certificate,omitempty"` @@ -93,7 +90,7 @@ type Options struct { CertFile string `mapstructure:"certificate_file" yaml:"certificate_file,omitempty"` KeyFile string `mapstructure:"certificate_key_file" yaml:"certificate_key_file,omitempty"` - TLSConfig *tls.Config `hash:"ignore"` + Certificates []tls.Certificate `yaml:"-"` // HttpRedirectAddr, if set, specifies the host and port to run the HTTP // to HTTPS redirect server on. If empty, no redirect server is started. @@ -534,38 +531,42 @@ func (o *Options) Validate() error { } if o.Cert != "" || o.Key != "" { - o.TLSConfig, err = cryptutil.TLSConfigFromBase64(o.TLSConfig, o.Cert, o.Key) + cert, err := cryptutil.CertificateFromBase64(o.Cert, o.Key) if err != nil { return fmt.Errorf("config: bad cert base64 %w", err) } + o.Certificates = append(o.Certificates, *cert) } - if len(o.Certificates) != 0 { - for _, c := range o.Certificates { - o.TLSConfig, err = cryptutil.TLSConfigFromFile(o.TLSConfig, c.CertFile, c.KeyFile) - if err != nil { - return fmt.Errorf("config: bad cert file %w", err) - } - } - } - - if o.CertFile != "" || o.KeyFile != "" { - o.TLSConfig, err = cryptutil.TLSConfigFromFile(o.TLSConfig, o.CertFile, o.KeyFile) + for _, c := range o.CertificateFiles { + cert, err := cryptutil.CertificateFromFile(c.CertFile, c.KeyFile) if err != nil { return fmt.Errorf("config: bad cert file %w", err) } + o.Certificates = append(o.Certificates, *cert) } - if o.AutoCert { - o.TLSConfig, o.AutoCertHandler, err = cryptutil.NewAutocert( - o.TLSConfig, - o.sourceHostnames(), - o.AutoCertUseStaging, - o.AutoCertFolder) + + if o.CertFile != "" || o.KeyFile != "" { + cert, err := cryptutil.CertificateFromFile(o.CertFile, o.KeyFile) if err != nil { - return fmt.Errorf("config: autocert failed %w", err) + return fmt.Errorf("config: bad cert file %w", err) } + o.Certificates = append(o.Certificates, *cert) } - if !o.InsecureServer && o.TLSConfig == nil { + + RedirectAndAutocertServer.update(o) + + err = AutocertManager.update(o) + if err != nil { + return fmt.Errorf("config: failed to setup autocert: %w", err) + } + + // sort the certificates so we get a consistent hash + sort.Slice(o.Certificates, func(i, j int) bool { + return compareByteSliceSlice(o.Certificates[i].Certificate, o.Certificates[j].Certificate) < 0 + }) + + if !o.InsecureServer && len(o.Certificates) == 0 { return fmt.Errorf("config: server must be run with `autocert`, " + "`insecure_server` or manually provided certificates to start") } @@ -576,13 +577,21 @@ func (o *Options) sourceHostnames() []string { if len(o.Policies) == 0 { return nil } - var h []string + + dedupe := map[string]struct{}{} for _, p := range o.Policies { - h = append(h, p.Source.Hostname()) + dedupe[p.Source.Hostname()] = struct{}{} } if o.AuthenticateURL != nil { - h = append(h, o.AuthenticateURL.Hostname()) + dedupe[o.AuthenticateURL.Hostname()] = struct{}{} } + + var h []string + for k := range dedupe { + h = append(h, k) + } + sort.Strings(h) + return h } @@ -601,10 +610,37 @@ func (o *Options) Checksum() uint64 { return hash } -// HandleConfigUpdate takes configuration file, an existing options struct, and +// WatchChanges takes a configuration file, an existing options struct, and +// updates each service in the services slice OptionsUpdater with a new set +// of options if any change is detected. It also periodically rechecks if +// any computed properties have changed. +func WatchChanges(configFile string, opt *Options, services []OptionsUpdater) { + onchange := make(chan struct{}, 1) + ticker := time.NewTicker(10 * time.Minute) // force check every 10 minutes + defer ticker.Stop() + + opt.OnConfigChange(func(fs fsnotify.Event) { + log.Info().Str("file", fs.Name).Msg("config: file changed") + select { + case onchange <- struct{}{}: + default: + } + }) + + for { + select { + case <-onchange: + case <-ticker.C: + } + + opt = handleConfigUpdate(configFile, opt, services) + } +} + +// handleConfigUpdate takes configuration file, an existing options struct, and // updates each service in the services slice OptionsUpdater with a new set of // options if any change is detected. -func HandleConfigUpdate(configFile string, opt *Options, services []OptionsUpdater) *Options { +func handleConfigUpdate(configFile string, opt *Options, services []OptionsUpdater) *Options { newOpt, err := NewOptionsFromConfig(configFile) if err != nil { log.Error().Err(err).Msg("config: could not reload configuration") @@ -648,3 +684,31 @@ func dataDir() string { } return filepath.Join(baseDir, "pomerium") } + +func compareByteSliceSlice(a, b [][]byte) int { + sz := min(len(a), len(b)) + for i := 0; i < sz; i++ { + switch bytes.Compare(a[i], b[i]) { + case -1: + return -1 + case 1: + return 1 + } + } + + switch { + case len(a) < len(b): + return -1 + case len(b) < len(a): + return 1 + default: + return 0 + } +} + +func min(x, y int) int { + if x < y { + return x + } + return y +} diff --git a/config/options_test.go b/config/options_test.go index 0b6d6daac..a6ae9d147 100644 --- a/config/options_test.go +++ b/config/options_test.go @@ -416,7 +416,7 @@ func Test_HandleConfigUpdate(t *testing.T) { os.Setenv(k, v) defer os.Unsetenv(k) } - HandleConfigUpdate("", oldOpts, []OptionsUpdater{tt.service}) + handleConfigUpdate("", oldOpts, []OptionsUpdater{tt.service}) if tt.service.Updated != tt.wantUpdate { t.Errorf("Failed to update config on service") } @@ -441,7 +441,7 @@ func TestOptions_sourceHostnames(t *testing.T) { }{ {"empty", []Policy{}, "", nil}, {"good no authN", []Policy{{From: "https://from.example", To: "https://to.example"}}, "", []string{"from.example"}}, - {"good with authN", []Policy{{From: "https://from.example", To: "https://to.example"}}, "https://authn.example.com", []string{"from.example", "authn.example.com"}}, + {"good with authN", []Policy{{From: "https://from.example", To: "https://to.example"}}, "https://authn.example.com", []string{"authn.example.com", "from.example"}}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -459,3 +459,66 @@ func TestOptions_sourceHostnames(t *testing.T) { }) } } + +func TestCompareByteSliceSlice(t *testing.T) { + type Bytes = [][]byte + + tests := []struct { + expect int + a Bytes + b Bytes + }{ + { + 0, + Bytes{ + {0, 1, 2, 3}, + }, + Bytes{ + {0, 1, 2, 3}, + }, + }, + { + -1, + Bytes{ + {0, 1, 2, 3}, + }, + Bytes{ + {0, 1, 2, 4}, + }, + }, + { + 1, + Bytes{ + {0, 1, 2, 4}, + }, + Bytes{ + {0, 1, 2, 3}, + }, + }, + {-1, + Bytes{ + {0, 1, 2, 3}, + }, + Bytes{ + {0, 1, 2, 3}, + {4, 5, 6, 7}, + }, + }, + {1, + Bytes{ + {0, 1, 2, 3}, + {4, 5, 6, 7}, + }, + Bytes{ + {0, 1, 2, 3}, + }, + }, + } + for _, tt := range tests { + actual := compareByteSliceSlice(tt.a, tt.b) + if tt.expect != actual { + t.Errorf("expected compare(%v, %v) to be %v but got %v", + tt.a, tt.b, tt.expect, actual) + } + } +} diff --git a/config/policy.go b/config/policy.go index f8573fce4..6f23cd3e8 100644 --- a/config/policy.go +++ b/config/policy.go @@ -115,7 +115,7 @@ func (p *Policy) Validate() error { } if p.TLSClientCert != "" && p.TLSClientKey != "" { - p.ClientCertificate, err = cryptutil.CertifcateFromBase64(p.TLSClientCert, p.TLSClientKey) + p.ClientCertificate, err = cryptutil.CertificateFromBase64(p.TLSClientCert, p.TLSClientKey) if err != nil { return fmt.Errorf("config: couldn't decode client cert %w", err) } diff --git a/config/redirect.go b/config/redirect.go new file mode 100644 index 000000000..34afe4399 --- /dev/null +++ b/config/redirect.go @@ -0,0 +1,60 @@ +package config + +import ( + "net/http" + "sync" + + "github.com/pomerium/pomerium/internal/httputil" + "github.com/pomerium/pomerium/internal/log" +) + +// RedirectAndAutocertServer is an HTTP server which handles redirecting to HTTPS and autocerts. +var RedirectAndAutocertServer = newRedirectAndAutoCertServer() + +type redirectAndAutoCertServer struct { + mu sync.Mutex + srv *http.Server +} + +func newRedirectAndAutoCertServer() *redirectAndAutoCertServer { + return &redirectAndAutoCertServer{} +} + +func (srv *redirectAndAutoCertServer) update(options *Options) { + srv.mu.Lock() + defer srv.mu.Unlock() + + if srv.srv != nil { + // nothing to do if the address hasn't changed + if srv.srv.Addr == options.HTTPRedirectAddr { + return + } + // close immediately, don't care about the error + _ = srv.srv.Close() + srv.srv = nil + } + + if options.HTTPRedirectAddr == "" { + return + } + + redirect := httputil.RedirectHandler() + + hsrv := &http.Server{ + Addr: options.HTTPRedirectAddr, + Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if AutocertManager.HandleHTTPChallenge(w, r) { + return + } + redirect.ServeHTTP(w, r) + }), + } + go func() { + log.Info().Str("addr", hsrv.Addr).Msg("starting http redirect server") + err := hsrv.ListenAndServe() + if err != nil { + log.Error().Err(err).Msg("failed to run http redirect server") + } + }() + srv.srv = hsrv +} diff --git a/go.mod b/go.mod index 2ebf0415b..1797a56e6 100644 --- a/go.mod +++ b/go.mod @@ -10,7 +10,6 @@ require ( github.com/coreos/go-oidc v2.2.1+incompatible github.com/envoyproxy/go-control-plane v0.9.5 github.com/fsnotify/fsnotify v1.4.9 - github.com/go-acme/lego/v3 v3.4.0 github.com/go-redis/redis/v7 v7.2.0 github.com/golang/groupcache v0.0.0-20200121045136-8c9f03a8e57e github.com/golang/mock v1.4.3 diff --git a/internal/controlplane/grpc_xds.go b/internal/controlplane/grpc_xds.go index edcb3bbbc..1c2907b5a 100644 --- a/internal/controlplane/grpc_xds.go +++ b/internal/controlplane/grpc_xds.go @@ -96,7 +96,7 @@ func (srv *Server) streamAggregatedResourcesProcessStep( for typeURL, version := range versions { // the versions are different, so the envoy config needs to be updated if version != fmt.Sprint(current.version) { - res, err := srv.buildDiscoveryResponse(fmt.Sprint(current.version), typeURL, current.Options) + res, err := srv.buildDiscoveryResponse(fmt.Sprint(current.version), typeURL, ¤t.Options) if err != nil { return err } diff --git a/internal/controlplane/xds.go b/internal/controlplane/xds.go index 8f9a9fb34..98956e1be 100644 --- a/internal/controlplane/xds.go +++ b/internal/controlplane/xds.go @@ -3,8 +3,6 @@ package controlplane import ( "fmt" "net" - "os" - "path/filepath" "strconv" "github.com/pomerium/pomerium/config" @@ -19,7 +17,7 @@ import ( "google.golang.org/grpc/status" ) -func (srv *Server) buildDiscoveryResponse(version string, typeURL string, options config.Options) (*envoy_service_discovery_v3.DiscoveryResponse, error) { +func (srv *Server) buildDiscoveryResponse(version string, typeURL string, options *config.Options) (*envoy_service_discovery_v3.DiscoveryResponse, error) { switch typeURL { case "type.googleapis.com/envoy.config.listener.v3.Listener": listeners := srv.buildListeners(options) @@ -56,7 +54,7 @@ func (srv *Server) buildDiscoveryResponse(version string, typeURL string, option } } -func (srv *Server) buildAccessLogs(options config.Options) []*envoy_config_accesslog_v3.AccessLog { +func (srv *Server) buildAccessLogs(options *config.Options) []*envoy_config_accesslog_v3.AccessLog { lvl := options.ProxyLogLevel if lvl == "" { lvl = options.LogLevel @@ -112,10 +110,10 @@ func buildAddress(hostport string, defaultPort int) *envoy_config_core_v3.Addres } } -func getAbsoluteFilePath(filename string) string { - if filepath.IsAbs(filename) { - return filename +func inlineBytes(bs []byte) *envoy_config_core_v3.DataSource { + return &envoy_config_core_v3.DataSource{ + Specifier: &envoy_config_core_v3.DataSource_InlineBytes{ + InlineBytes: bs, + }, } - wd, _ := os.Getwd() - return filepath.Join(wd, filename) } diff --git a/internal/controlplane/xds_clusters.go b/internal/controlplane/xds_clusters.go index 392f7f10a..94fbc8a66 100644 --- a/internal/controlplane/xds_clusters.go +++ b/internal/controlplane/xds_clusters.go @@ -15,7 +15,7 @@ import ( "github.com/pomerium/pomerium/internal/urlutil" ) -func (srv *Server) buildClusters(options config.Options) []*envoy_config_cluster_v3.Cluster { +func (srv *Server) buildClusters(options *config.Options) []*envoy_config_cluster_v3.Cluster { grpcURL := &url.URL{ Scheme: "grpc", Host: srv.GRPCListener.Addr().String(), diff --git a/internal/controlplane/xds_listeners.go b/internal/controlplane/xds_listeners.go index 641773174..e7ab2b409 100644 --- a/internal/controlplane/xds_listeners.go +++ b/internal/controlplane/xds_listeners.go @@ -1,7 +1,9 @@ package controlplane import ( - "encoding/base64" + "bytes" + "crypto/x509" + "encoding/pem" "sort" envoy_config_core_v3 "github.com/envoyproxy/go-control-plane/envoy/config/core/v3" @@ -14,8 +16,11 @@ import ( envoy_type_v3 "github.com/envoyproxy/go-control-plane/envoy/type/v3" "github.com/golang/protobuf/ptypes" "github.com/golang/protobuf/ptypes/any" + "google.golang.org/protobuf/types/known/emptypb" "github.com/pomerium/pomerium/config" + "github.com/pomerium/pomerium/internal/cryptutil" + "github.com/pomerium/pomerium/internal/log" "github.com/pomerium/pomerium/internal/urlutil" ) @@ -29,11 +34,11 @@ func init() { }) } -func (srv *Server) buildListeners(options config.Options) []*envoy_config_listener_v3.Listener { +func (srv *Server) buildListeners(options *config.Options) []*envoy_config_listener_v3.Listener { var listeners []*envoy_config_listener_v3.Listener if config.IsAuthenticate(options.Services) || config.IsProxy(options.Services) { - listeners = append(listeners, srv.buildHTTPListener(options)) + listeners = append(listeners, srv.buildMainListener(options)) } if config.IsAuthorize(options.Services) || config.IsCache(options.Services) { @@ -43,22 +48,77 @@ func (srv *Server) buildListeners(options config.Options) []*envoy_config_listen return listeners } -func (srv *Server) buildHTTPListener(options config.Options) *envoy_config_listener_v3.Listener { - defaultPort := 80 - var transportSocket *envoy_config_core_v3.TransportSocket - if !options.InsecureServer { - defaultPort = 443 - tlsConfig, _ := ptypes.MarshalAny(srv.buildDownstreamTLSContext(options)) - transportSocket = &envoy_config_core_v3.TransportSocket{ - Name: "tls", - ConfigType: &envoy_config_core_v3.TransportSocket_TypedConfig{ - TypedConfig: tlsConfig, - }, +func (srv *Server) buildMainListener(options *config.Options) *envoy_config_listener_v3.Listener { + if options.InsecureServer { + filter := srv.buildMainHTTPConnectionManagerFilter(options, + srv.getAllRouteableDomains(options, options.Addr)) + + return &envoy_config_listener_v3.Listener{ + Name: "http-ingress", + Address: buildAddress(options.Addr, 80), + FilterChains: []*envoy_config_listener_v3.FilterChain{{ + Filters: []*envoy_config_listener_v3.Filter{ + filter, + }, + }}, } } + tlsInspectorCfg, _ := ptypes.MarshalAny(new(emptypb.Empty)) + li := &envoy_config_listener_v3.Listener{ + Name: "https-ingress", + Address: buildAddress(options.Addr, 443), + ListenerFilters: []*envoy_config_listener_v3.ListenerFilter{{ + Name: "envoy.filters.listener.tls_inspector", + ConfigType: &envoy_config_listener_v3.ListenerFilter_TypedConfig{ + TypedConfig: tlsInspectorCfg, + }, + }}, + FilterChains: srv.buildFilterChains(options, options.Addr, + func(tlsDomain string, httpDomains []string) *envoy_config_listener_v3.FilterChain { + filter := srv.buildMainHTTPConnectionManagerFilter(options, httpDomains) + filterChain := &envoy_config_listener_v3.FilterChain{ + Filters: []*envoy_config_listener_v3.Filter{filter}, + } + if tlsDomain != "*" { + filterChain.FilterChainMatch = &envoy_config_listener_v3.FilterChainMatch{ + ServerNames: []string{tlsDomain}, + } + } + tlsContext := srv.buildDownstreamTLSContext(options, tlsDomain) + if tlsContext != nil { + tlsConfig, _ := ptypes.MarshalAny(tlsContext) + filterChain.TransportSocket = &envoy_config_core_v3.TransportSocket{ + Name: "tls", + ConfigType: &envoy_config_core_v3.TransportSocket_TypedConfig{ + TypedConfig: tlsConfig, + }, + } + } + return filterChain + }), + } + return li +} + +func (srv *Server) buildFilterChains( + options *config.Options, addr string, + callback func(tlsDomain string, httpDomains []string) *envoy_config_listener_v3.FilterChain, +) []*envoy_config_listener_v3.FilterChain { + allDomains := srv.getAllRouteableDomains(options, addr) + var chains []*envoy_config_listener_v3.FilterChain + for _, domain := range allDomains { + // first we match on SNI + chains = append(chains, callback(domain, []string{domain})) + } + // if there are no SNI matches we match on HTTP host + chains = append(chains, callback("*", allDomains)) + return chains +} + +func (srv *Server) buildMainHTTPConnectionManagerFilter(options *config.Options, domains []string) *envoy_config_listener_v3.Filter { var virtualHosts []*envoy_config_route_v3.VirtualHost - for _, domain := range srv.getAllRouteableDomains(options, options.Addr) { + for _, domain := range domains { vh := &envoy_config_route_v3.VirtualHost{ Name: domain, Domains: []string{domain}, @@ -142,38 +202,66 @@ func (srv *Server) buildHTTPListener(options config.Options) *envoy_config_liste AccessLog: srv.buildAccessLogs(options), }) - li := &envoy_config_listener_v3.Listener{ - Name: "http-ingress", - Address: buildAddress(options.Addr, defaultPort), - FilterChains: []*envoy_config_listener_v3.FilterChain{{ - Filters: []*envoy_config_listener_v3.Filter{ - { - Name: "envoy.filters.network.http_connection_manager", - ConfigType: &envoy_config_listener_v3.Filter_TypedConfig{ - TypedConfig: tc, - }, + return &envoy_config_listener_v3.Filter{ + Name: "envoy.filters.network.http_connection_manager", + ConfigType: &envoy_config_listener_v3.Filter_TypedConfig{ + TypedConfig: tc, + }, + } +} + +func (srv *Server) buildGRPCListener(options *config.Options) *envoy_config_listener_v3.Listener { + filter := srv.buildGRPCHTTPConnectionManagerFilter() + + if options.GRPCInsecure { + return &envoy_config_listener_v3.Listener{ + Name: "grpc-ingress", + Address: buildAddress(options.GRPCAddr, 80), + FilterChains: []*envoy_config_listener_v3.FilterChain{{ + Filters: []*envoy_config_listener_v3.Filter{ + filter, }, + }}, + } + } + + tlsInspectorCfg, _ := ptypes.MarshalAny(new(emptypb.Empty)) + li := &envoy_config_listener_v3.Listener{ + Name: "grpc-ingress", + Address: buildAddress(options.GRPCAddr, 443), + ListenerFilters: []*envoy_config_listener_v3.ListenerFilter{{ + Name: "envoy.filters.listener.tls_inspector", + ConfigType: &envoy_config_listener_v3.ListenerFilter_TypedConfig{ + TypedConfig: tlsInspectorCfg, }, - TransportSocket: transportSocket, }}, + FilterChains: srv.buildFilterChains(options, options.Addr, + func(tlsDomain string, httpDomains []string) *envoy_config_listener_v3.FilterChain { + filterChain := &envoy_config_listener_v3.FilterChain{ + Filters: []*envoy_config_listener_v3.Filter{filter}, + } + if tlsDomain != "*" { + filterChain.FilterChainMatch = &envoy_config_listener_v3.FilterChainMatch{ + ServerNames: []string{tlsDomain}, + } + } + tlsContext := srv.buildDownstreamTLSContext(options, tlsDomain) + if tlsContext != nil { + tlsConfig, _ := ptypes.MarshalAny(tlsContext) + filterChain.TransportSocket = &envoy_config_core_v3.TransportSocket{ + Name: "tls", + ConfigType: &envoy_config_core_v3.TransportSocket_TypedConfig{ + TypedConfig: tlsConfig, + }, + } + } + return filterChain + }), } return li } -func (srv *Server) buildGRPCListener(options config.Options) *envoy_config_listener_v3.Listener { - defaultPort := 80 - var transportSocket *envoy_config_core_v3.TransportSocket - if !options.GRPCInsecure { - defaultPort = 443 - tlsConfig, _ := ptypes.MarshalAny(srv.buildDownstreamTLSContext(options)) - transportSocket = &envoy_config_core_v3.TransportSocket{ - Name: "tls", - ConfigType: &envoy_config_core_v3.TransportSocket_TypedConfig{ - TypedConfig: tlsConfig, - }, - } - } - +func (srv *Server) buildGRPCHTTPConnectionManagerFilter() *envoy_config_listener_v3.Filter { tc, _ := ptypes.MarshalAny(&envoy_http_connection_manager.HttpConnectionManager{ CodecType: envoy_http_connection_manager.HttpConnectionManager_AUTO, StatPrefix: "grpc_ingress", @@ -191,7 +279,9 @@ func (srv *Server) buildGRPCListener(options config.Options) *envoy_config_liste }, Action: &envoy_config_route_v3.Route_Route{ Route: &envoy_config_route_v3.RouteAction{ - ClusterSpecifier: &envoy_config_route_v3.RouteAction_Cluster{Cluster: "pomerium-control-plane-grpc"}, + ClusterSpecifier: &envoy_config_route_v3.RouteAction_Cluster{ + Cluster: "pomerium-control-plane-grpc", + }, }, }, }}, @@ -202,64 +292,57 @@ func (srv *Server) buildGRPCListener(options config.Options) *envoy_config_liste Name: "envoy.filters.http.router", }}, }) - - return &envoy_config_listener_v3.Listener{ - Name: "grpc-ingress", - Address: buildAddress(options.GRPCAddr, defaultPort), - FilterChains: []*envoy_config_listener_v3.FilterChain{{ - Filters: []*envoy_config_listener_v3.Filter{{ - Name: "envoy.filters.network.http_connection_manager", - ConfigType: &envoy_config_listener_v3.Filter_TypedConfig{ - TypedConfig: tc, - }, - }}, - TransportSocket: transportSocket, - }}, - } -} - -func (srv *Server) buildDownstreamTLSContext(options config.Options) *envoy_extensions_transport_sockets_tls_v3.DownstreamTlsContext { - var cert envoy_extensions_transport_sockets_tls_v3.TlsCertificate - if options.Cert != "" { - bs, _ := base64.StdEncoding.DecodeString(options.Cert) - cert.CertificateChain = &envoy_config_core_v3.DataSource{ - Specifier: &envoy_config_core_v3.DataSource_InlineBytes{ - InlineBytes: bs, - }, - } - } else { - cert.CertificateChain = &envoy_config_core_v3.DataSource{ - Specifier: &envoy_config_core_v3.DataSource_Filename{ - Filename: getAbsoluteFilePath(options.CertFile), - }, - } - } - if options.Key != "" { - bs, _ := base64.StdEncoding.DecodeString(options.Key) - cert.PrivateKey = &envoy_config_core_v3.DataSource{ - Specifier: &envoy_config_core_v3.DataSource_InlineBytes{ - InlineBytes: bs, - }, - } - } else { - cert.PrivateKey = &envoy_config_core_v3.DataSource{ - Specifier: &envoy_config_core_v3.DataSource_Filename{ - Filename: getAbsoluteFilePath(options.KeyFile), - }, - } - } - - return &envoy_extensions_transport_sockets_tls_v3.DownstreamTlsContext{ - CommonTlsContext: &envoy_extensions_transport_sockets_tls_v3.CommonTlsContext{ - TlsCertificates: []*envoy_extensions_transport_sockets_tls_v3.TlsCertificate{ - &cert, - }, - AlpnProtocols: []string{"h2", "http/1.1"}, + return &envoy_config_listener_v3.Filter{ + Name: "envoy.filters.network.http_connection_manager", + ConfigType: &envoy_config_listener_v3.Filter_TypedConfig{ + TypedConfig: tc, }, } } -func (srv *Server) getAllRouteableDomains(options config.Options, addr string) []string { +func (srv *Server) buildDownstreamTLSContext(options *config.Options, domain string) *envoy_extensions_transport_sockets_tls_v3.DownstreamTlsContext { + cert, err := cryptutil.GetCertificateForDomain(options.Certificates, domain) + if err != nil { + log.Warn().Str("domain", domain).Err(err).Msg("failed to get certificate for domain") + return nil + } + + envoyCert := &envoy_extensions_transport_sockets_tls_v3.TlsCertificate{} + var chain bytes.Buffer + for _, cbs := range cert.Certificate { + _ = pem.Encode(&chain, &pem.Block{ + Type: "CERTIFICATE", + Bytes: cbs, + }) + } + envoyCert.CertificateChain = inlineBytes(chain.Bytes()) + if cert.OCSPStaple != nil { + envoyCert.OcspStaple = inlineBytes(cert.OCSPStaple) + } + if bs, err := x509.MarshalPKCS8PrivateKey(cert.PrivateKey); err == nil { + envoyCert.PrivateKey = inlineBytes(pem.EncodeToMemory( + &pem.Block{ + Type: "PRIVATE KEY", + Bytes: bs, + }, + )) + } else { + log.Warn().Err(err).Msg("failed to marshal private key for tls config") + } + for _, scts := range cert.SignedCertificateTimestamps { + envoyCert.SignedCertificateTimestamp = append(envoyCert.SignedCertificateTimestamp, + inlineBytes(scts)) + } + + return &envoy_extensions_transport_sockets_tls_v3.DownstreamTlsContext{ + CommonTlsContext: &envoy_extensions_transport_sockets_tls_v3.CommonTlsContext{ + TlsCertificates: []*envoy_extensions_transport_sockets_tls_v3.TlsCertificate{envoyCert}, + AlpnProtocols: []string{"h2", "http/1.1"}, + }, + } +} + +func (srv *Server) getAllRouteableDomains(options *config.Options, addr string) []string { lookup := map[string]struct{}{} if config.IsAuthenticate(options.Services) && addr == options.Addr { lookup[urlutil.StripPort(options.AuthenticateURL.Host)] = struct{}{} diff --git a/internal/controlplane/xds_routes.go b/internal/controlplane/xds_routes.go index 986bc0bc8..166daf5fe 100644 --- a/internal/controlplane/xds_routes.go +++ b/internal/controlplane/xds_routes.go @@ -37,7 +37,7 @@ func (srv *Server) buildGRPCRoutes() []*envoy_config_route_v3.Route { }} } -func (srv *Server) buildPomeriumHTTPRoutes(options config.Options, domain string) []*envoy_config_route_v3.Route { +func (srv *Server) buildPomeriumHTTPRoutes(options *config.Options, domain string) []*envoy_config_route_v3.Route { routes := []*envoy_config_route_v3.Route{ srv.buildControlPlanePathRoute("/ping"), srv.buildControlPlanePathRoute("/healthz"), @@ -95,7 +95,7 @@ func (srv *Server) buildControlPlanePrefixRoute(prefix string) *envoy_config_rou } } -func (srv *Server) buildPolicyRoutes(options config.Options, domain string) []*envoy_config_route_v3.Route { +func (srv *Server) buildPolicyRoutes(options *config.Options, domain string) []*envoy_config_route_v3.Route { var routes []*envoy_config_route_v3.Route for i, policy := range options.Policies { if policy.Source.Hostname() != domain { diff --git a/internal/cryptutil/certificates.go b/internal/cryptutil/certificates.go index 6db4d245d..b682c5857 100644 --- a/internal/cryptutil/certificates.go +++ b/internal/cryptutil/certificates.go @@ -2,17 +2,23 @@ package cryptutil import ( "crypto/ecdsa" + "crypto/rand" + "crypto/rsa" "crypto/tls" "crypto/x509" + "crypto/x509/pkix" "encoding/base64" "encoding/pem" "errors" "fmt" "io/ioutil" + "math/big" + "net" + "time" ) -// CertifcateFromBase64 returns an X509 pair from a base64 encoded blob. -func CertifcateFromBase64(cert, key string) (*tls.Certificate, error) { +// CertificateFromBase64 returns an X509 pair from a base64 encoded blob. +func CertificateFromBase64(cert, key string) (*tls.Certificate, error) { decodedCert, err := base64.StdEncoding.DecodeString(cert) if err != nil { return nil, fmt.Errorf("failed to decode certificate cert %v: %w", decodedCert, err) @@ -135,3 +141,59 @@ func EncodePrivateKey(key *ecdsa.PrivateKey) ([]byte, error) { return pem.EncodeToMemory(keyBlock), nil } + +// GenerateSelfSignedCertificate generates a self-signed TLS certificate. +// +// mostly copied from https://golang.org/src/crypto/tls/generate_cert.go +func GenerateSelfSignedCertificate(domain string) (*tls.Certificate, error) { + privateKey, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + return nil, fmt.Errorf("failed to geneate private key: %w", err) + } + + serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128) + serialNumber, err := rand.Int(rand.Reader, serialNumberLimit) + if err != nil { + return nil, fmt.Errorf("failed to generate serial number: %w", err) + } + + template := x509.Certificate{ + SerialNumber: serialNumber, + Subject: pkix.Name{ + Organization: []string{"Pomerium"}, + }, + NotBefore: time.Now().Add(-time.Minute * 10), + NotAfter: time.Now().Add(time.Hour * 24 * 365), + KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, + BasicConstraintsValid: true, + } + if ip := net.ParseIP(domain); ip != nil { + template.IPAddresses = append(template.IPAddresses, ip) + } else { + template.DNSNames = append(template.DNSNames, domain) + } + + publicKeyBytes, err := x509.CreateCertificate(rand.Reader, + &template, &template, + privateKey.Public(), privateKey, + ) + if err != nil { + return nil, fmt.Errorf("failed to create certificate: %w", err) + } + + privateKeyBytes, err := x509.MarshalPKCS8PrivateKey(privateKey) + if err != nil { + return nil, fmt.Errorf("failed to marshal private key: %w", err) + } + + cert, err := tls.X509KeyPair( + pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: publicKeyBytes}), + pem.EncodeToMemory(&pem.Block{Type: "PRIVATE KEY", Bytes: privateKeyBytes}), + ) + if err != nil { + return nil, fmt.Errorf("failed to convert x509 bytes into tls certificate: %w", err) + } + + return &cert, nil +} diff --git a/internal/cryptutil/certificates_test.go b/internal/cryptutil/certificates_test.go index 6d764df00..efcb38ec4 100644 --- a/internal/cryptutil/certificates_test.go +++ b/internal/cryptutil/certificates_test.go @@ -55,9 +55,9 @@ func TestCertifcateFromBase64(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - _, err := CertifcateFromBase64(tt.cert, tt.key) + _, err := CertificateFromBase64(tt.cert, tt.key) if (err != nil) != tt.wantErr { - t.Errorf("CertifcateFromBase64() error = %v, wantErr %v", err, tt.wantErr) + t.Errorf("CertificateFromBase64() error = %v, wantErr %v", err, tt.wantErr) return } }) diff --git a/internal/cryptutil/tls.go b/internal/cryptutil/tls.go index a93dbf13b..4b705df99 100644 --- a/internal/cryptutil/tls.go +++ b/internal/cryptutil/tls.go @@ -1,103 +1,51 @@ package cryptutil import ( - "context" "crypto/tls" - "fmt" - "net/http" + "crypto/x509" "github.com/caddyserver/certmagic" - "github.com/go-acme/lego/v3/challenge/tlsalpn01" ) -// NewAutocert automatically retrieves public certificates from the free -// certificate authority Let's Encrypt using HTTP-01 and TLS-ALPN-01 challenges. -// To complete the challenges, the server must be accessible from the internet -// by port 80 or 443 . -// -// https://letsencrypt.org/docs/challenge-types/#http-01-challenge -// https://letsencrypt.org/docs/challenge-types/#tls-alpn-01 -func NewAutocert(tlsConfig *tls.Config, hostnames []string, useStaging bool, path string) (*tls.Config, func(h http.Handler) http.Handler, error) { - certmagic.DefaultACME.Agreed = true - if useStaging { - certmagic.DefaultACME.CA = certmagic.LetsEncryptStagingCA - } - cm := certmagic.NewDefault() - - tlsConfig = newTLSConfigIfEmpty(tlsConfig) - // add existing certs to the cache, and staple OCSP - for _, cert := range tlsConfig.Certificates { - if err := cm.CacheUnmanagedTLSCertificate(cert, nil); err != nil { - return nil, nil, fmt.Errorf("cryptutil: failed caching cert: %w", err) +// GetCertificateForDomain returns the tls Certificate which matches the given domain name. +// It should handle both exact matches and wildcard matches. If none of those match, the first certificate will be used. +// Finally if there are no matching certificates one will be generated. +func GetCertificateForDomain(certificates []tls.Certificate, domain string) (*tls.Certificate, error) { + // first try a direct name match + for _, cert := range certificates { + if matchesDomain(&cert, domain) { + return &cert, nil } } - cm.Storage = &certmagic.FileStorage{Path: path} - acmeConfig := certmagic.NewACMEManager(cm, certmagic.DefaultACME) - cm.Issuer = acmeConfig - // todo(bdd) : add cancellation context? - if err := cm.ManageAsync(context.TODO(), hostnames); err != nil { - return nil, nil, fmt.Errorf("cryptutil: sync failed: %w", err) + + // next use the first cert + if len(certificates) > 0 { + return &certificates[0], nil } - tlsConfig.GetCertificate = cm.GetCertificate - tlsConfig.NextProtos = append(tlsConfig.NextProtos, tlsalpn01.ACMETLS1Protocol) - tlsConfig.BuildNameToCertificate() - return tlsConfig, acmeConfig.HTTPChallengeHandler, nil + // finally fall back to a generated, self-signed certificate + return GenerateSelfSignedCertificate(domain) } -// TLSConfigFromBase64 returns an tls configuration from a base64 encoded blob. -func TLSConfigFromBase64(tlsConfig *tls.Config, cert, key string) (*tls.Config, error) { - tlsConfig = newTLSConfigIfEmpty(tlsConfig) - c, err := CertifcateFromBase64(cert, key) +func matchesDomain(cert *tls.Certificate, domain string) bool { + if cert == nil || len(cert.Certificate) == 0 { + return false + } + + xcert, err := x509.ParseCertificate(cert.Certificate[0]) if err != nil { - return nil, err + return false } - tlsConfig.Certificates = append(tlsConfig.Certificates, *c) - tlsConfig.BuildNameToCertificate() - return tlsConfig, nil -} -// TLSConfigFromFile returns an tls configuration from a certificate and -// key file . -func TLSConfigFromFile(tlsConfig *tls.Config, cert, key string) (*tls.Config, error) { - tlsConfig = newTLSConfigIfEmpty(tlsConfig) - c, err := CertificateFromFile(cert, key) - if err != nil { - return nil, err + if certmagic.MatchWildcard(domain, xcert.Subject.CommonName) { + return true } - tlsConfig.Certificates = append(tlsConfig.Certificates, *c) - tlsConfig.BuildNameToCertificate() - return tlsConfig, nil -} -// newTLSConfigIfEmpty returns an opinionated TLS configuration if config is nil. -// See : -// https://wiki.mozilla.org/Security/Server_Side_TLS#Recommended_configurations -// https://blog.cloudflare.com/exposing-go-on-the-internet/ -// https://github.com/ssllabs/research/wiki/SSL-and-TLS-Deployment-Best-Practices -// https://github.com/golang/go/blob/df91b8044dbe790c69c16058330f545be069cc1f/src/crypto/tls/common.go#L919 -func newTLSConfigIfEmpty(tlsConfig *tls.Config) *tls.Config { - if tlsConfig != nil { - return tlsConfig - } - return &tls.Config{ - MinVersion: tls.VersionTLS12, - // Prioritize cipher suites sped up by AES-NI (AES-GCM) - CipherSuites: []uint16{ - tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, - tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, - tls.TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305, - tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384, - tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, - tls.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305, - }, - PreferServerCipherSuites: true, - // Use curves which have assembly implementations - CurvePreferences: []tls.CurveID{ - tls.X25519, - tls.CurveP256, - }, - // HTTP/2 must be enabled manually when using http.Serve - NextProtos: []string{"h2", "http/1.1"}, + for _, san := range xcert.DNSNames { + if certmagic.MatchWildcard(domain, san) { + return true + } } + + return false } diff --git a/internal/cryptutil/tls_test.go b/internal/cryptutil/tls_test.go index 83c0887a9..8e92be0f1 100644 --- a/internal/cryptutil/tls_test.go +++ b/internal/cryptutil/tls_test.go @@ -3,47 +3,61 @@ package cryptutil import ( "crypto/tls" "testing" + + "github.com/stretchr/testify/assert" ) -func TestTLSConfigFromBase64(t *testing.T) { - tests := []struct { - name string - cert string - key string - wantErr bool - }{ - {"good", - "LS0tLS1CRUdJTiBDRVJUSUZJQ0FURS0tLS0tCk1JSUVJVENDQWdtZ0F3SUJBZ0lSQVBqTEJxS1lwcWU0ekhQc0dWdFR6T0F3RFFZSktvWklodmNOQVFFTEJRQXcKRWpFUU1BNEdBMVVFQXhNSFoyOXZaQzFqWVRBZUZ3MHhPVEE0TVRBeE9EUTVOREJhRncweU1UQXlNVEF4TnpRdwpNREZhTUJNeEVUQVBCZ05WQkFNVENIQnZiV1Z5YVhWdE1JSUJJakFOQmdrcWhraUc5dzBCQVFFRkFBT0NBUThBCk1JSUJDZ0tDQVFFQTY3S2pxbVFZR3EwTVZ0QUNWcGVDbVhtaW5sUWJEUEdMbXNaQVVFd3VlSFFucnQzV3R2cEQKT202QWxhSk1VblcrSHU1NWpqb2thbEtlVmpUS21nWUdicVV6VkRvTWJQRGFIZWtsdGRCVE1HbE9VRnNQNFVKUwpEck80emROK3pvNDI4VFgyUG5HMkZDZFZLR3k0UEU4aWxIYldMY3I4NzFZalY1MWZ3OENMRFg5UFpKTnU4NjFDCkY3VjlpRUptNnNTZlFsbW5oTjhqMytXelZiUFFOeTFXc1I3aTllOWo2M0VxS3QyMlE5T1hMK1dBY0tza29JU20KQ05WUlVBalU4WVJWY2dRSkIrelEzNEFRUGx6ME9wNU8vUU4vTWVkamFGOHdMUytpdi96dmlTOGNxUGJ4bzZzTApxNkZOVGx0ay9Ra3hlQ2VLS1RRZS8za1BZdlFBZG5sNjVRSURBUUFCbzNFd2J6QU9CZ05WSFE4QkFmOEVCQU1DCkE3Z3dIUVlEVlIwbEJCWXdGQVlJS3dZQkJRVUhBd0VHQ0NzR0FRVUZCd01DTUIwR0ExVWREZ1FXQkJRQ1FYbWIKc0hpcS9UQlZUZVhoQ0dpNjhrVy9DakFmQmdOVkhTTUVHREFXZ0JSNTRKQ3pMRlg0T0RTQ1J0dWNBUGZOdVhWegpuREFOQmdrcWhraUc5dzBCQVFzRkFBT0NBZ0VBcm9XL2trMllleFN5NEhaQXFLNDVZaGQ5ay9QVTFiaDlFK1BRCk5jZFgzTUdEY2NDRUFkc1k4dll3NVE1cnhuMGFzcSt3VGFCcGxoYS9rMi9VVW9IQ1RqUVp1Mk94dEF3UTdPaWIKVE1tMEorU3NWT3d4YnFQTW9rK1RqVE16NFdXaFFUTzVwRmNoZDZXZXNCVHlJNzJ0aG1jcDd1c2NLU2h3YktIegpQY2h1QTQ4SzhPdi96WkxmZnduQVNZb3VCczJjd1ZiRDI3ZXZOMzdoMGFzR1BrR1VXdm1PSDduTHNVeTh3TTdqCkNGL3NwMmJmTC9OYVdNclJnTHZBMGZMS2pwWTQrVEpPbkVxQmxPcCsrbHlJTEZMcC9qMHNybjRNUnlKK0t6UTEKR1RPakVtQ1QvVEFtOS9XSThSL0FlYjcwTjEzTytYNEtaOUJHaDAxTzN3T1Vqd3BZZ3lxSnNoRnNRUG50VmMrSQpKQmF4M2VQU3NicUcwTFkzcHdHUkpRNmMrd1lxdGk2Y0tNTjliYlRkMDhCNUk1N1RRTHhNcUoycTFnWmw1R1VUCmVFZGNWRXltMnZmd0NPd0lrbGNBbThxTm5kZGZKV1FabE5VaHNOVWFBMkVINnlDeXdaZm9aak9hSDEwTXowV20KeTNpZ2NSZFQ3Mi9NR2VkZk93MlV0MVVvRFZmdEcxcysrditUQ1lpNmpUQU05dkZPckJ4UGlOeGFkUENHR2NZZAowakZIc2FWOGFPV1dQQjZBQ1JteHdDVDdRTnRTczM2MlpIOUlFWWR4Q00yMDUrZmluVHhkOUcwSmVRRTd2Kyt6CldoeWo2ZmJBWUIxM2wvN1hkRnpNSW5BOGxpekdrVHB2RHMxeTBCUzlwV3ppYmhqbVFoZGZIejdCZGpGTHVvc2wKZzlNZE5sND0KLS0tLS1FTkQgQ0VSVElGSUNBVEUtLS0tLQo=", - "LS0tLS1CRUdJTiBSU0EgUFJJVkFURSBLRVktLS0tLQpNSUlFcGdJQkFBS0NBUUVBNjdLanFtUVlHcTBNVnRBQ1ZwZUNtWG1pbmxRYkRQR0xtc1pBVUV3dWVIUW5ydDNXCnR2cERPbTZBbGFKTVVuVytIdTU1ampva2FsS2VWalRLbWdZR2JxVXpWRG9NYlBEYUhla2x0ZEJUTUdsT1VGc1AKNFVKU0RyTzR6ZE4rem80MjhUWDJQbkcyRkNkVktHeTRQRThpbEhiV0xjcjg3MVlqVjUxZnc4Q0xEWDlQWkpOdQo4NjFDRjdWOWlFSm02c1NmUWxtbmhOOGozK1d6VmJQUU55MVdzUjdpOWU5ajYzRXFLdDIyUTlPWEwrV0FjS3NrCm9JU21DTlZSVUFqVThZUlZjZ1FKQit6UTM0QVFQbHowT3A1Ty9RTi9NZWRqYUY4d0xTK2l2L3p2aVM4Y3FQYngKbzZzTHE2Rk5UbHRrL1FreGVDZUtLVFFlLzNrUFl2UUFkbmw2NVFJREFRQUJBb0lCQVFEQVQ0eXN2V2pSY3pxcgpKcU9SeGFPQTJEY3dXazJML1JXOFhtQWhaRmRTWHV2MkNQbGxhTU1yelBmTG41WUlmaHQzSDNzODZnSEdZc3pnClo4aWJiYWtYNUdFQ0t5N3lRSDZuZ3hFS3pRVGpiampBNWR3S0h0UFhQUnJmamQ1Y2FMczVpcDcxaWxCWEYxU3IKWERIaXUycnFtaC9kVTArWGRMLzNmK2VnVDl6bFQ5YzRyUm84dnZueWNYejFyMnVhRVZ2VExsWHVsb2NpeEVrcgoySjlTMmxveWFUb2tFTnNlMDNpSVdaWnpNNElZcVowOGJOeG9IWCszQXVlWExIUStzRkRKMlhaVVdLSkZHMHUyClp3R2w3YlZpRTFQNXdiQUdtZzJDeDVCN1MrdGQyUEpSV3Frb2VxY3F2RVdCc3RFL1FEcDFpVThCOHpiQXd0Y3IKZHc5TXZ6Q2hBb0dCQVBObzRWMjF6MGp6MWdEb2tlTVN5d3JnL2E4RkJSM2R2Y0xZbWV5VXkybmd3eHVucnFsdwo2U2IrOWdrOGovcXEvc3VQSDhVdzNqSHNKYXdGSnNvTkVqNCt2b1ZSM3UrbE5sTEw5b21rMXBoU0dNdVp0b3huCm5nbUxVbkJUMGI1M3BURkJ5WGsveE5CbElreWdBNlg5T2MreW5na3RqNlRyVnMxUERTdnVJY0s1QW9HQkFQZmoKcEUzR2F6cVFSemx6TjRvTHZmQWJBdktCZ1lPaFNnemxsK0ZLZkhzYWJGNkdudFd1dWVhY1FIWFpYZTA1c2tLcApXN2xYQ3dqQU1iUXI3QmdlazcrOSszZElwL1RnYmZCYnN3Syt6Vng3Z2doeWMrdytXRWExaHByWTZ6YXdxdkFaCkhRU2lMUEd1UGp5WXBQa1E2ZFdEczNmWHJGZ1dlTmd4SkhTZkdaT05Bb0dCQUt5WTF3MUM2U3Y2c3VuTC8vNTcKQ2Z5NTAwaXlqNUZBOWRqZkRDNWt4K1JZMnlDV0ExVGsybjZyVmJ6dzg4czBTeDMrYS9IQW1CM2dMRXBSRU5NKwo5NHVwcENFWEQ3VHdlcGUxUnlrTStKbmp4TzlDSE41c2J2U25sUnBQWlMvZzJRTVhlZ3grK2trbkhXNG1ITkFyCndqMlRrMXBBczFXbkJ0TG9WaGVyY01jSkFvR0JBSTYwSGdJb0Y5SysvRUcyY21LbUg5SDV1dGlnZFU2eHEwK0IKWE0zMWMzUHE0amdJaDZlN3pvbFRxa2d0dWtTMjBraE45dC9ibkI2TmhnK1N1WGVwSXFWZldVUnlMejVwZE9ESgo2V1BMTTYzcDdCR3cwY3RPbU1NYi9VRm5Yd0U4OHlzRlNnOUF6VjdVVUQvU0lDYkI5ZHRVMWh4SHJJK0pZRWdWCkFrZWd6N2lCQW9HQkFJRncrQVFJZUIwM01UL0lCbGswNENQTDJEak0rNDhoVGRRdjgwMDBIQU9mUWJrMEVZUDEKQ2FLR3RDbTg2MXpBZjBzcS81REtZQ0l6OS9HUzNYRk00Qm1rRk9nY1NXVENPNmZmTGdLM3FmQzN4WDJudlpIOQpYZGNKTDQrZndhY0x4c2JJKzhhUWNOVHRtb3pkUjEzQnNmUmIrSGpUL2o3dkdrYlFnSkhCT0syegotLS0tLUVORCBSU0EgUFJJVkFURSBLRVktLS0tLQo=", - false}, - {"bad cert", - "!=", - "LS0tLS1CRUdJTiBSU0EgUFJJVkFURSBLRVktLS0tLQpNSUlFcGdJQkFBS0NBUUVBNjdLanFtUVlHcTBNVnRBQ1ZwZUNtWG1pbmxRYkRQR0xtc1pBVUV3dWVIUW5ydDNXCnR2cERPbTZBbGFKTVVuVytIdTU1ampva2FsS2VWalRLbWdZR2JxVXpWRG9NYlBEYUhla2x0ZEJUTUdsT1VGc1AKNFVKU0RyTzR6ZE4rem80MjhUWDJQbkcyRkNkVktHeTRQRThpbEhiV0xjcjg3MVlqVjUxZnc4Q0xEWDlQWkpOdQo4NjFDRjdWOWlFSm02c1NmUWxtbmhOOGozK1d6VmJQUU55MVdzUjdpOWU5ajYzRXFLdDIyUTlPWEwrV0FjS3NrCm9JU21DTlZSVUFqVThZUlZjZ1FKQit6UTM0QVFQbHowT3A1Ty9RTi9NZWRqYUY4d0xTK2l2L3p2aVM4Y3FQYngKbzZzTHE2Rk5UbHRrL1FreGVDZUtLVFFlLzNrUFl2UUFkbmw2NVFJREFRQUJBb0lCQVFEQVQ0eXN2V2pSY3pxcgpKcU9SeGFPQTJEY3dXazJML1JXOFhtQWhaRmRTWHV2MkNQbGxhTU1yelBmTG41WUlmaHQzSDNzODZnSEdZc3pnClo4aWJiYWtYNUdFQ0t5N3lRSDZuZ3hFS3pRVGpiampBNWR3S0h0UFhQUnJmamQ1Y2FMczVpcDcxaWxCWEYxU3IKWERIaXUycnFtaC9kVTArWGRMLzNmK2VnVDl6bFQ5YzRyUm84dnZueWNYejFyMnVhRVZ2VExsWHVsb2NpeEVrcgoySjlTMmxveWFUb2tFTnNlMDNpSVdaWnpNNElZcVowOGJOeG9IWCszQXVlWExIUStzRkRKMlhaVVdLSkZHMHUyClp3R2w3YlZpRTFQNXdiQUdtZzJDeDVCN1MrdGQyUEpSV3Frb2VxY3F2RVdCc3RFL1FEcDFpVThCOHpiQXd0Y3IKZHc5TXZ6Q2hBb0dCQVBObzRWMjF6MGp6MWdEb2tlTVN5d3JnL2E4RkJSM2R2Y0xZbWV5VXkybmd3eHVucnFsdwo2U2IrOWdrOGovcXEvc3VQSDhVdzNqSHNKYXdGSnNvTkVqNCt2b1ZSM3UrbE5sTEw5b21rMXBoU0dNdVp0b3huCm5nbUxVbkJUMGI1M3BURkJ5WGsveE5CbElreWdBNlg5T2MreW5na3RqNlRyVnMxUERTdnVJY0s1QW9HQkFQZmoKcEUzR2F6cVFSemx6TjRvTHZmQWJBdktCZ1lPaFNnemxsK0ZLZkhzYWJGNkdudFd1dWVhY1FIWFpYZTA1c2tLcApXN2xYQ3dqQU1iUXI3QmdlazcrOSszZElwL1RnYmZCYnN3Syt6Vng3Z2doeWMrdytXRWExaHByWTZ6YXdxdkFaCkhRU2lMUEd1UGp5WXBQa1E2ZFdEczNmWHJGZ1dlTmd4SkhTZkdaT05Bb0dCQUt5WTF3MUM2U3Y2c3VuTC8vNTcKQ2Z5NTAwaXlqNUZBOWRqZkRDNWt4K1JZMnlDV0ExVGsybjZyVmJ6dzg4czBTeDMrYS9IQW1CM2dMRXBSRU5NKwo5NHVwcENFWEQ3VHdlcGUxUnlrTStKbmp4TzlDSE41c2J2U25sUnBQWlMvZzJRTVhlZ3grK2trbkhXNG1ITkFyCndqMlRrMXBBczFXbkJ0TG9WaGVyY01jSkFvR0JBSTYwSGdJb0Y5SysvRUcyY21LbUg5SDV1dGlnZFU2eHEwK0IKWE0zMWMzUHE0amdJaDZlN3pvbFRxa2d0dWtTMjBraE45dC9ibkI2TmhnK1N1WGVwSXFWZldVUnlMejVwZE9ESgo2V1BMTTYzcDdCR3cwY3RPbU1NYi9VRm5Yd0U4OHlzRlNnOUF6VjdVVUQvU0lDYkI5ZHRVMWh4SHJJK0pZRWdWCkFrZWd6N2lCQW9HQkFJRncrQVFJZUIwM01UL0lCbGswNENQTDJEak0rNDhoVGRRdjgwMDBIQU9mUWJrMEVZUDEKQ2FLR3RDbTg2MXpBZjBzcS81REtZQ0l6OS9HUzNYRk00Qm1rRk9nY1NXVENPNmZmTGdLM3FmQzN4WDJudlpIOQpYZGNKTDQrZndhY0x4c2JJKzhhUWNOVHRtb3pkUjEzQnNmUmIrSGpUL2o3dkdrYlFnSkhCT0syegotLS0tLUVORCBSU0EgUFJJVkFURSBLRVktLS0tLQo=", - true}, - {"bad key", - "LS0tLS1CRUdJTiBDRVJUSUZJQ0FURS0tLS0tCk1JSUVJVENDQWdtZ0F3SUJBZ0lSQVBqTEJxS1lwcWU0ekhQc0dWdFR6T0F3RFFZSktvWklodmNOQVFFTEJRQXcKRWpFUU1BNEdBMVVFQXhNSFoyOXZaQzFqWVRBZUZ3MHhPVEE0TVRBeE9EUTVOREJhRncweU1UQXlNVEF4TnpRdwpNREZhTUJNeEVUQVBCZ05WQkFNVENIQnZiV1Z5YVhWdE1JSUJJakFOQmdrcWhraUc5dzBCQVFFRkFBT0NBUThBCk1JSUJDZ0tDQVFFQTY3S2pxbVFZR3EwTVZ0QUNWcGVDbVhtaW5sUWJEUEdMbXNaQVVFd3VlSFFucnQzV3R2cEQKT202QWxhSk1VblcrSHU1NWpqb2thbEtlVmpUS21nWUdicVV6VkRvTWJQRGFIZWtsdGRCVE1HbE9VRnNQNFVKUwpEck80emROK3pvNDI4VFgyUG5HMkZDZFZLR3k0UEU4aWxIYldMY3I4NzFZalY1MWZ3OENMRFg5UFpKTnU4NjFDCkY3VjlpRUptNnNTZlFsbW5oTjhqMytXelZiUFFOeTFXc1I3aTllOWo2M0VxS3QyMlE5T1hMK1dBY0tza29JU20KQ05WUlVBalU4WVJWY2dRSkIrelEzNEFRUGx6ME9wNU8vUU4vTWVkamFGOHdMUytpdi96dmlTOGNxUGJ4bzZzTApxNkZOVGx0ay9Ra3hlQ2VLS1RRZS8za1BZdlFBZG5sNjVRSURBUUFCbzNFd2J6QU9CZ05WSFE4QkFmOEVCQU1DCkE3Z3dIUVlEVlIwbEJCWXdGQVlJS3dZQkJRVUhBd0VHQ0NzR0FRVUZCd01DTUIwR0ExVWREZ1FXQkJRQ1FYbWIKc0hpcS9UQlZUZVhoQ0dpNjhrVy9DakFmQmdOVkhTTUVHREFXZ0JSNTRKQ3pMRlg0T0RTQ1J0dWNBUGZOdVhWegpuREFOQmdrcWhraUc5dzBCQVFzRkFBT0NBZ0VBcm9XL2trMllleFN5NEhaQXFLNDVZaGQ5ay9QVTFiaDlFK1BRCk5jZFgzTUdEY2NDRUFkc1k4dll3NVE1cnhuMGFzcSt3VGFCcGxoYS9rMi9VVW9IQ1RqUVp1Mk94dEF3UTdPaWIKVE1tMEorU3NWT3d4YnFQTW9rK1RqVE16NFdXaFFUTzVwRmNoZDZXZXNCVHlJNzJ0aG1jcDd1c2NLU2h3YktIegpQY2h1QTQ4SzhPdi96WkxmZnduQVNZb3VCczJjd1ZiRDI3ZXZOMzdoMGFzR1BrR1VXdm1PSDduTHNVeTh3TTdqCkNGL3NwMmJmTC9OYVdNclJnTHZBMGZMS2pwWTQrVEpPbkVxQmxPcCsrbHlJTEZMcC9qMHNybjRNUnlKK0t6UTEKR1RPakVtQ1QvVEFtOS9XSThSL0FlYjcwTjEzTytYNEtaOUJHaDAxTzN3T1Vqd3BZZ3lxSnNoRnNRUG50VmMrSQpKQmF4M2VQU3NicUcwTFkzcHdHUkpRNmMrd1lxdGk2Y0tNTjliYlRkMDhCNUk1N1RRTHhNcUoycTFnWmw1R1VUCmVFZGNWRXltMnZmd0NPd0lrbGNBbThxTm5kZGZKV1FabE5VaHNOVWFBMkVINnlDeXdaZm9aak9hSDEwTXowV20KeTNpZ2NSZFQ3Mi9NR2VkZk93MlV0MVVvRFZmdEcxcysrditUQ1lpNmpUQU05dkZPckJ4UGlOeGFkUENHR2NZZAowakZIc2FWOGFPV1dQQjZBQ1JteHdDVDdRTnRTczM2MlpIOUlFWWR4Q00yMDUrZmluVHhkOUcwSmVRRTd2Kyt6CldoeWo2ZmJBWUIxM2wvN1hkRnpNSW5BOGxpekdrVHB2RHMxeTBCUzlwV3ppYmhqbVFoZGZIejdCZGpGTHVvc2wKZzlNZE5sND0KLS0tLS1FTkQgQ0VSVElGSUNBVEUtLS0tLQo=", - "!=", - true}, +func TestGetCertificateForDomain(t *testing.T) { + gen := func(t *testing.T, domain string) *tls.Certificate { + cert, err := GenerateSelfSignedCertificate(domain) + if !assert.NoError(t, err, "error generating certificate for: %s", domain) { + t.FailNow() + } + return cert } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - _, err := TLSConfigFromBase64(nil, tt.cert, tt.key) - if (err != nil) != tt.wantErr { - t.Errorf("TLSConfigFromBase64() error = %v, wantErr %v", err, tt.wantErr) - return - } - }) - } -} -func TestTLSConfigFromFile(t *testing.T) { - cfg, err := TLSConfigFromFile(nil, "testdata/example-cert.pem", "testdata/example-key.pem") - if err != nil { - t.Fatal(err) - } - listener, err := tls.Listen("tcp", ":0", cfg) - if err != nil { - t.Fatal(err) - } - _ = listener + t.Run("exact match", func(t *testing.T) { + certs := []tls.Certificate{ + *gen(t, "a.example.com"), + *gen(t, "b.example.com"), + } + + found, err := GetCertificateForDomain(certs, "b.example.com") + if !assert.NoError(t, err) { + return + } + assert.Equal(t, &certs[1], found) + }) + t.Run("wildcard match", func(t *testing.T) { + certs := []tls.Certificate{ + *gen(t, "a.example.com"), + *gen(t, "*.example.com"), + } + + found, err := GetCertificateForDomain(certs, "b.example.com") + if !assert.NoError(t, err) { + return + } + assert.Equal(t, &certs[1], found) + }) + t.Run("no name match", func(t *testing.T) { + certs := []tls.Certificate{ + *gen(t, "a.example.com"), + } + + found, err := GetCertificateForDomain(certs, "b.example.com") + if !assert.NoError(t, err) { + return + } + assert.Equal(t, &certs[0], found) + }) + t.Run("generate", func(t *testing.T) { + certs := []tls.Certificate{} + + found, err := GetCertificateForDomain(certs, "b.example.com") + if !assert.NoError(t, err) { + return + } + assert.NotNil(t, found) + }) } diff --git a/internal/envoy/envoy.go b/internal/envoy/envoy.go index 9794b2c45..041d3d98a 100644 --- a/internal/envoy/envoy.go +++ b/internal/envoy/envoy.go @@ -15,8 +15,9 @@ import ( "strings" "github.com/natefinch/atomic" - "github.com/pomerium/pomerium/internal/log" "github.com/rs/zerolog" + + "github.com/pomerium/pomerium/internal/log" ) const (