envoy: support autocert (#695)

* envoy: support autocert

* envoy: fallback to http host routing if sni fails to match

* update comment

* envoy: renew certs when necessary

* fix tests
This commit is contained in:
Caleb Doxsey 2020-05-13 13:07:04 -06:00 committed by Travis Groth
parent 0c1ac5a575
commit dccec1e646
18 changed files with 689 additions and 391 deletions

View file

@ -5,8 +5,10 @@ import (
"flag" "flag"
"fmt" "fmt"
"net" "net"
"os"
"os/signal"
"sync" "sync"
"time" "syscall"
"github.com/pomerium/pomerium/authenticate" "github.com/pomerium/pomerium/authenticate"
"github.com/pomerium/pomerium/authorize" "github.com/pomerium/pomerium/authorize"
@ -24,7 +26,6 @@ import (
"github.com/pomerium/pomerium/proxy" "github.com/pomerium/pomerium/proxy"
envoy_service_auth_v2 "github.com/envoyproxy/go-control-plane/envoy/service/auth/v2" envoy_service_auth_v2 "github.com/envoyproxy/go-control-plane/envoy/service/auth/v2"
"github.com/fsnotify/fsnotify"
"golang.org/x/sync/errgroup" "golang.org/x/sync/errgroup"
) )
@ -32,12 +33,12 @@ var versionFlag = flag.Bool("version", false, "prints the version")
var configFile = flag.String("config", "", "Specify configuration file location") var configFile = flag.String("config", "", "Specify configuration file location")
func main() { func main() {
if err := run(); err != nil { if err := run(context.Background()); err != nil {
log.Fatal().Err(err).Msg("cmd/pomerium") log.Fatal().Err(err).Msg("cmd/pomerium")
} }
} }
func run() error { func run(ctx context.Context) error {
flag.Parse() flag.Parse()
if *versionFlag { if *versionFlag {
fmt.Println(version.FullVersion()) fmt.Println(version.FullVersion())
@ -51,18 +52,12 @@ func run() error {
log.Info().Str("version", version.FullVersion()).Msg("cmd/pomerium") log.Info().Str("version", version.FullVersion()).Msg("cmd/pomerium")
var wg sync.WaitGroup if err := setupMetrics(opt); err != nil {
if err := setupMetrics(opt, &wg); err != nil {
return err return err
} }
if err := setupTracing(opt); err != nil { if err := setupTracing(opt); err != nil {
return err return err
} }
if err := setupHTTPRedirectServer(opt, &wg); err != nil {
return err
}
ctx := context.Background()
// setup the control plane // setup the control plane
controlPlane, err := controlplane.NewServer() controlPlane, err := controlplane.NewServer()
@ -99,17 +94,19 @@ func run() error {
} }
// start the config change listener // start the config change listener
opt.OnConfigChange(func(e fsnotify.Event) { go config.WatchChanges(*configFile, opt, optionsUpdaters)
log.Info().Str("file", e.Name).Msg("cmd/pomerium: config file changed")
opt = config.HandleConfigUpdate(*configFile, opt, optionsUpdaters) ctx, cancel := context.WithCancel(ctx)
}) go func() {
ch := make(chan os.Signal, 2)
signal.Notify(ch, os.Interrupt)
signal.Notify(ch, syscall.SIGTERM)
<-ch
cancel()
}()
// run everything // run everything
eg, ctx := errgroup.WithContext(ctx) eg, ctx := errgroup.WithContext(ctx)
eg.Go(func() error {
wg.Wait()
return nil
})
eg.Go(func() error { eg.Go(func() error {
return controlPlane.Run(ctx) return controlPlane.Run(ctx)
}) })
@ -172,7 +169,7 @@ func setupCache(opt *config.Options, controlPlane *controlplane.Server) error {
return nil return nil
} }
func setupMetrics(opt *config.Options, wg *sync.WaitGroup) error { func setupMetrics(opt *config.Options) error {
if opt.MetricsAddr != "" { if opt.MetricsAddr != "" {
handler, err := metrics.PrometheusHandler() handler, err := metrics.PrometheusHandler()
if err != nil { if err != nil {
@ -185,11 +182,11 @@ func setupMetrics(opt *config.Options, wg *sync.WaitGroup) error {
Insecure: true, Insecure: true,
Service: "metrics", Service: "metrics",
} }
srv, err := httputil.NewServer(serverOpts, handler, wg) var wg sync.WaitGroup
_, err = httputil.NewServer(serverOpts, handler, &wg)
if err != nil { if err != nil {
return err return err
} }
go httputil.Shutdown(srv)
} }
return nil return nil
} }
@ -222,40 +219,3 @@ func setupTracing(opt *config.Options) error {
} }
return nil return nil
} }
func setupHTTPRedirectServer(opt *config.Options, wg *sync.WaitGroup) error {
if opt.HTTPRedirectAddr != "" {
serverOpts := httputil.ServerOptions{
Addr: opt.HTTPRedirectAddr,
Insecure: true,
Service: "HTTP->HTTPS Redirect",
ReadHeaderTimeout: 5 * time.Second,
ReadTimeout: 5 * time.Second,
WriteTimeout: 5 * time.Second,
IdleTimeout: 5 * time.Second,
}
h := httputil.RedirectHandler()
if opt.AutoCert {
h = opt.AutoCertHandler(h)
}
srv, err := httputil.NewServer(&serverOpts, h, wg)
if err != nil {
return err
}
go httputil.Shutdown(srv)
}
return nil
}
func httpServerOptions(opt *config.Options) *httputil.ServerOptions {
return &httputil.ServerOptions{
Addr: opt.Addr,
TLSConfig: opt.TLSConfig,
Insecure: opt.InsecureServer,
ReadTimeout: opt.ReadTimeout,
WriteTimeout: opt.WriteTimeout,
ReadHeaderTimeout: opt.ReadHeaderTimeout,
IdleTimeout: opt.IdleTimeout,
Service: opt.Services,
}
}

View file

@ -1,36 +1,17 @@
package main package main
import ( import (
"context"
"io/ioutil" "io/ioutil"
"os" "os"
"os/signal" "os/signal"
"sync"
"syscall" "syscall"
"testing" "testing"
"time" "time"
"github.com/google/go-cmp/cmp"
"github.com/pomerium/pomerium/config" "github.com/pomerium/pomerium/config"
"github.com/pomerium/pomerium/internal/httputil"
) )
func Test_httpServerOptions(t *testing.T) {
tests := []struct {
name string
opt *config.Options
want *httputil.ServerOptions
}{
{"simple convert", &config.Options{Addr: ":80"}, &httputil.ServerOptions{Addr: ":80"}},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if diff := cmp.Diff(httpServerOptions(tt.opt), tt.want); diff != "" {
t.Errorf("httpServerOptions() = \n %s", diff)
}
})
}
}
func Test_setupTracing(t *testing.T) { func Test_setupTracing(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
@ -60,41 +41,9 @@ func Test_setupMetrics(t *testing.T) {
c := make(chan os.Signal, 1) c := make(chan os.Signal, 1)
signal.Notify(c, syscall.SIGINT) signal.Notify(c, syscall.SIGINT)
defer signal.Stop(c) defer signal.Stop(c)
var wg sync.WaitGroup setupMetrics(tt.opt)
setupMetrics(tt.opt, &wg)
syscall.Kill(syscall.Getpid(), syscall.SIGINT) syscall.Kill(syscall.Getpid(), syscall.SIGINT)
waitSig(t, c, syscall.SIGINT) waitSig(t, c, syscall.SIGINT)
})
}
}
func Test_setupHTTPRedirectServer(t *testing.T) {
tests := []struct {
name string
opt *config.Options
wantErr bool
}{
{"dont register anything", &config.Options{}, false},
{"good redirect server", &config.Options{HTTPRedirectAddr: "localhost:0"}, false},
{"bad redirect server port", &config.Options{HTTPRedirectAddr: "localhost:-1"}, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
c := make(chan os.Signal, 1)
var wg sync.WaitGroup
signal.Notify(c, syscall.SIGINT)
defer signal.Stop(c)
err := setupHTTPRedirectServer(tt.opt, &wg)
if (err != nil) != tt.wantErr {
t.Errorf("run() error = %v, wantErr %v", err, tt.wantErr)
}
syscall.Kill(syscall.Getpid(), syscall.SIGINT)
waitSig(t, c, syscall.SIGINT)
}) })
} }
} }
@ -274,16 +223,11 @@ func Test_run(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
configFile = &fn configFile = &fn
proc, err := os.FindProcess(os.Getpid())
if err != nil {
t.Fatal(err)
}
go func() {
time.Sleep(time.Millisecond * 500)
proc.Signal(os.Interrupt)
}()
err = run() ctx, clearTimeout := context.WithTimeout(context.Background(), 500*time.Millisecond)
defer clearTimeout()
err = run(ctx)
if (err != nil) != tt.wantErr { if (err != nil) != tt.wantErr {
t.Errorf("run() error = %v, wantErr %v", err, tt.wantErr) t.Errorf("run() error = %v, wantErr %v", err, tt.wantErr)
} }

102
config/autocert.go Normal file
View file

@ -0,0 +1,102 @@
package config
import (
"context"
"fmt"
"net/http"
"sync"
"github.com/caddyserver/certmagic"
"github.com/pomerium/pomerium/internal/log"
)
// AutocertManager manages Let's Encrypt certificates based on configuration options.
var AutocertManager = newAutocertManager()
type autocertManager struct {
mu sync.RWMutex
certmagic *certmagic.Config
acmeMgr *certmagic.ACMEManager
}
func newAutocertManager() *autocertManager {
mgr := &autocertManager{}
return mgr
}
func (mgr *autocertManager) getConfig(options *Options) (*certmagic.Config, error) {
mgr.mu.Lock()
defer mgr.mu.Unlock()
cm := mgr.certmagic
if cm == nil {
cm = certmagic.NewDefault()
}
cm.OnDemand = nil // disable on-demand
cm.Storage = &certmagic.FileStorage{Path: options.AutoCertFolder}
// add existing certs to the cache, and staple OCSP
for _, cert := range options.Certificates {
if err := cm.CacheUnmanagedTLSCertificate(cert, nil); err != nil {
return nil, fmt.Errorf("config: failed caching cert: %w", err)
}
}
acmeMgr := certmagic.NewACMEManager(cm, certmagic.DefaultACME)
acmeMgr.Agreed = true
if options.AutoCertUseStaging {
acmeMgr.CA = certmagic.LetsEncryptStagingCA
}
acmeMgr.DisableTLSALPNChallenge = true
cm.Issuer = acmeMgr
mgr.acmeMgr = acmeMgr
return cm, nil
}
func (mgr *autocertManager) update(options *Options) error {
if !options.AutoCert {
return nil
}
cm, err := mgr.getConfig(options)
if err != nil {
return err
}
for _, domain := range options.sourceHostnames() {
cert, err := cm.CacheManagedCertificate(domain)
if err != nil {
log.Info().Str("domain", domain).Msg("obtaining certificate")
err = cm.ObtainCert(context.Background(), domain, false)
if err != nil {
return fmt.Errorf("config: failed to obtain client certificate: %w", err)
}
cert, err = cm.CacheManagedCertificate(domain)
}
if err == nil && cert.NeedsRenewal(cm) {
log.Info().Str("domain", domain).Msg("renewing certificate")
err = cm.RenewCert(context.Background(), domain, false)
if err != nil {
return fmt.Errorf("config: failed to renew client certificate: %w", err)
}
cert, err = cm.CacheManagedCertificate(domain)
}
if err == nil {
options.Certificates = append(options.Certificates, cert.Certificate)
} else {
log.Error().Err(err).Msg("config: failed to obtain client certificate")
}
}
return nil
}
func (mgr *autocertManager) HandleHTTPChallenge(w http.ResponseWriter, r *http.Request) bool {
mgr.mu.RLock()
acmeMgr := mgr.acmeMgr
mgr.mu.RUnlock()
if acmeMgr == nil {
return false
}
return acmeMgr.HandleHTTPChallenge(w, r)
}

View file

@ -1,15 +1,16 @@
package config package config
import ( import (
"bytes"
"crypto/tls" "crypto/tls"
"encoding/base64" "encoding/base64"
"errors" "errors"
"fmt" "fmt"
"net/http"
"net/url" "net/url"
"os" "os"
"path/filepath" "path/filepath"
"reflect" "reflect"
"sort"
"strings" "strings"
"time" "time"
@ -68,10 +69,6 @@ type Options struct {
// and renewal from LetsEncrypt. Must be used in conjunction with AutoCertFolder. // and renewal from LetsEncrypt. Must be used in conjunction with AutoCertFolder.
AutoCert bool `mapstructure:"autocert" yaml:"autocert,omitempty"` AutoCert bool `mapstructure:"autocert" yaml:"autocert,omitempty"`
// AutoCertHandler is the HTTP challenge handler used in a http-01 acme
// https://letsencrypt.org/docs/challenge-types/#http-01-challenge
AutoCertHandler func(h http.Handler) http.Handler `hash:"ignore"`
// AutoCertFolder specifies the location to store, and load autocert managed // AutoCertFolder specifies the location to store, and load autocert managed
// TLS certificates. // TLS certificates.
// defaults to $XDG_DATA_HOME/pomerium // defaults to $XDG_DATA_HOME/pomerium
@ -83,7 +80,7 @@ type Options struct {
// https://letsencrypt.org/docs/staging-environment/ // https://letsencrypt.org/docs/staging-environment/
AutoCertUseStaging bool `mapstructure:"autocert_use_staging" yaml:"autocert_use_staging,omitempty"` AutoCertUseStaging bool `mapstructure:"autocert_use_staging" yaml:"autocert_use_staging,omitempty"`
Certificates []certificateFilePair `mapstructure:"certificates" yaml:"certificates,omitempty"` CertificateFiles []certificateFilePair `mapstructure:"certificates" yaml:"certificates,omitempty"`
// Cert and Key is the x509 certificate used to create the HTTPS server. // Cert and Key is the x509 certificate used to create the HTTPS server.
Cert string `mapstructure:"certificate" yaml:"certificate,omitempty"` Cert string `mapstructure:"certificate" yaml:"certificate,omitempty"`
@ -93,7 +90,7 @@ type Options struct {
CertFile string `mapstructure:"certificate_file" yaml:"certificate_file,omitempty"` CertFile string `mapstructure:"certificate_file" yaml:"certificate_file,omitempty"`
KeyFile string `mapstructure:"certificate_key_file" yaml:"certificate_key_file,omitempty"` KeyFile string `mapstructure:"certificate_key_file" yaml:"certificate_key_file,omitempty"`
TLSConfig *tls.Config `hash:"ignore"` Certificates []tls.Certificate `yaml:"-"`
// HttpRedirectAddr, if set, specifies the host and port to run the HTTP // HttpRedirectAddr, if set, specifies the host and port to run the HTTP
// to HTTPS redirect server on. If empty, no redirect server is started. // to HTTPS redirect server on. If empty, no redirect server is started.
@ -534,38 +531,42 @@ func (o *Options) Validate() error {
} }
if o.Cert != "" || o.Key != "" { if o.Cert != "" || o.Key != "" {
o.TLSConfig, err = cryptutil.TLSConfigFromBase64(o.TLSConfig, o.Cert, o.Key) cert, err := cryptutil.CertificateFromBase64(o.Cert, o.Key)
if err != nil { if err != nil {
return fmt.Errorf("config: bad cert base64 %w", err) return fmt.Errorf("config: bad cert base64 %w", err)
} }
o.Certificates = append(o.Certificates, *cert)
} }
if len(o.Certificates) != 0 { for _, c := range o.CertificateFiles {
for _, c := range o.Certificates { cert, err := cryptutil.CertificateFromFile(c.CertFile, c.KeyFile)
o.TLSConfig, err = cryptutil.TLSConfigFromFile(o.TLSConfig, c.CertFile, c.KeyFile)
if err != nil { if err != nil {
return fmt.Errorf("config: bad cert file %w", err) return fmt.Errorf("config: bad cert file %w", err)
} }
} o.Certificates = append(o.Certificates, *cert)
} }
if o.CertFile != "" || o.KeyFile != "" { if o.CertFile != "" || o.KeyFile != "" {
o.TLSConfig, err = cryptutil.TLSConfigFromFile(o.TLSConfig, o.CertFile, o.KeyFile) cert, err := cryptutil.CertificateFromFile(o.CertFile, o.KeyFile)
if err != nil { if err != nil {
return fmt.Errorf("config: bad cert file %w", err) return fmt.Errorf("config: bad cert file %w", err)
} }
o.Certificates = append(o.Certificates, *cert)
} }
if o.AutoCert {
o.TLSConfig, o.AutoCertHandler, err = cryptutil.NewAutocert( RedirectAndAutocertServer.update(o)
o.TLSConfig,
o.sourceHostnames(), err = AutocertManager.update(o)
o.AutoCertUseStaging,
o.AutoCertFolder)
if err != nil { if err != nil {
return fmt.Errorf("config: autocert failed %w", err) return fmt.Errorf("config: failed to setup autocert: %w", err)
} }
}
if !o.InsecureServer && o.TLSConfig == nil { // sort the certificates so we get a consistent hash
sort.Slice(o.Certificates, func(i, j int) bool {
return compareByteSliceSlice(o.Certificates[i].Certificate, o.Certificates[j].Certificate) < 0
})
if !o.InsecureServer && len(o.Certificates) == 0 {
return fmt.Errorf("config: server must be run with `autocert`, " + return fmt.Errorf("config: server must be run with `autocert`, " +
"`insecure_server` or manually provided certificates to start") "`insecure_server` or manually provided certificates to start")
} }
@ -576,13 +577,21 @@ func (o *Options) sourceHostnames() []string {
if len(o.Policies) == 0 { if len(o.Policies) == 0 {
return nil return nil
} }
var h []string
dedupe := map[string]struct{}{}
for _, p := range o.Policies { for _, p := range o.Policies {
h = append(h, p.Source.Hostname()) dedupe[p.Source.Hostname()] = struct{}{}
} }
if o.AuthenticateURL != nil { if o.AuthenticateURL != nil {
h = append(h, o.AuthenticateURL.Hostname()) dedupe[o.AuthenticateURL.Hostname()] = struct{}{}
} }
var h []string
for k := range dedupe {
h = append(h, k)
}
sort.Strings(h)
return h return h
} }
@ -601,10 +610,37 @@ func (o *Options) Checksum() uint64 {
return hash return hash
} }
// HandleConfigUpdate takes configuration file, an existing options struct, and // WatchChanges takes a configuration file, an existing options struct, and
// updates each service in the services slice OptionsUpdater with a new set
// of options if any change is detected. It also periodically rechecks if
// any computed properties have changed.
func WatchChanges(configFile string, opt *Options, services []OptionsUpdater) {
onchange := make(chan struct{}, 1)
ticker := time.NewTicker(10 * time.Minute) // force check every 10 minutes
defer ticker.Stop()
opt.OnConfigChange(func(fs fsnotify.Event) {
log.Info().Str("file", fs.Name).Msg("config: file changed")
select {
case onchange <- struct{}{}:
default:
}
})
for {
select {
case <-onchange:
case <-ticker.C:
}
opt = handleConfigUpdate(configFile, opt, services)
}
}
// handleConfigUpdate takes configuration file, an existing options struct, and
// updates each service in the services slice OptionsUpdater with a new set of // updates each service in the services slice OptionsUpdater with a new set of
// options if any change is detected. // options if any change is detected.
func HandleConfigUpdate(configFile string, opt *Options, services []OptionsUpdater) *Options { func handleConfigUpdate(configFile string, opt *Options, services []OptionsUpdater) *Options {
newOpt, err := NewOptionsFromConfig(configFile) newOpt, err := NewOptionsFromConfig(configFile)
if err != nil { if err != nil {
log.Error().Err(err).Msg("config: could not reload configuration") log.Error().Err(err).Msg("config: could not reload configuration")
@ -648,3 +684,31 @@ func dataDir() string {
} }
return filepath.Join(baseDir, "pomerium") return filepath.Join(baseDir, "pomerium")
} }
func compareByteSliceSlice(a, b [][]byte) int {
sz := min(len(a), len(b))
for i := 0; i < sz; i++ {
switch bytes.Compare(a[i], b[i]) {
case -1:
return -1
case 1:
return 1
}
}
switch {
case len(a) < len(b):
return -1
case len(b) < len(a):
return 1
default:
return 0
}
}
func min(x, y int) int {
if x < y {
return x
}
return y
}

View file

@ -416,7 +416,7 @@ func Test_HandleConfigUpdate(t *testing.T) {
os.Setenv(k, v) os.Setenv(k, v)
defer os.Unsetenv(k) defer os.Unsetenv(k)
} }
HandleConfigUpdate("", oldOpts, []OptionsUpdater{tt.service}) handleConfigUpdate("", oldOpts, []OptionsUpdater{tt.service})
if tt.service.Updated != tt.wantUpdate { if tt.service.Updated != tt.wantUpdate {
t.Errorf("Failed to update config on service") t.Errorf("Failed to update config on service")
} }
@ -441,7 +441,7 @@ func TestOptions_sourceHostnames(t *testing.T) {
}{ }{
{"empty", []Policy{}, "", nil}, {"empty", []Policy{}, "", nil},
{"good no authN", []Policy{{From: "https://from.example", To: "https://to.example"}}, "", []string{"from.example"}}, {"good no authN", []Policy{{From: "https://from.example", To: "https://to.example"}}, "", []string{"from.example"}},
{"good with authN", []Policy{{From: "https://from.example", To: "https://to.example"}}, "https://authn.example.com", []string{"from.example", "authn.example.com"}}, {"good with authN", []Policy{{From: "https://from.example", To: "https://to.example"}}, "https://authn.example.com", []string{"authn.example.com", "from.example"}},
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
@ -459,3 +459,66 @@ func TestOptions_sourceHostnames(t *testing.T) {
}) })
} }
} }
func TestCompareByteSliceSlice(t *testing.T) {
type Bytes = [][]byte
tests := []struct {
expect int
a Bytes
b Bytes
}{
{
0,
Bytes{
{0, 1, 2, 3},
},
Bytes{
{0, 1, 2, 3},
},
},
{
-1,
Bytes{
{0, 1, 2, 3},
},
Bytes{
{0, 1, 2, 4},
},
},
{
1,
Bytes{
{0, 1, 2, 4},
},
Bytes{
{0, 1, 2, 3},
},
},
{-1,
Bytes{
{0, 1, 2, 3},
},
Bytes{
{0, 1, 2, 3},
{4, 5, 6, 7},
},
},
{1,
Bytes{
{0, 1, 2, 3},
{4, 5, 6, 7},
},
Bytes{
{0, 1, 2, 3},
},
},
}
for _, tt := range tests {
actual := compareByteSliceSlice(tt.a, tt.b)
if tt.expect != actual {
t.Errorf("expected compare(%v, %v) to be %v but got %v",
tt.a, tt.b, tt.expect, actual)
}
}
}

View file

@ -115,7 +115,7 @@ func (p *Policy) Validate() error {
} }
if p.TLSClientCert != "" && p.TLSClientKey != "" { if p.TLSClientCert != "" && p.TLSClientKey != "" {
p.ClientCertificate, err = cryptutil.CertifcateFromBase64(p.TLSClientCert, p.TLSClientKey) p.ClientCertificate, err = cryptutil.CertificateFromBase64(p.TLSClientCert, p.TLSClientKey)
if err != nil { if err != nil {
return fmt.Errorf("config: couldn't decode client cert %w", err) return fmt.Errorf("config: couldn't decode client cert %w", err)
} }

60
config/redirect.go Normal file
View file

@ -0,0 +1,60 @@
package config
import (
"net/http"
"sync"
"github.com/pomerium/pomerium/internal/httputil"
"github.com/pomerium/pomerium/internal/log"
)
// RedirectAndAutocertServer is an HTTP server which handles redirecting to HTTPS and autocerts.
var RedirectAndAutocertServer = newRedirectAndAutoCertServer()
type redirectAndAutoCertServer struct {
mu sync.Mutex
srv *http.Server
}
func newRedirectAndAutoCertServer() *redirectAndAutoCertServer {
return &redirectAndAutoCertServer{}
}
func (srv *redirectAndAutoCertServer) update(options *Options) {
srv.mu.Lock()
defer srv.mu.Unlock()
if srv.srv != nil {
// nothing to do if the address hasn't changed
if srv.srv.Addr == options.HTTPRedirectAddr {
return
}
// close immediately, don't care about the error
_ = srv.srv.Close()
srv.srv = nil
}
if options.HTTPRedirectAddr == "" {
return
}
redirect := httputil.RedirectHandler()
hsrv := &http.Server{
Addr: options.HTTPRedirectAddr,
Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if AutocertManager.HandleHTTPChallenge(w, r) {
return
}
redirect.ServeHTTP(w, r)
}),
}
go func() {
log.Info().Str("addr", hsrv.Addr).Msg("starting http redirect server")
err := hsrv.ListenAndServe()
if err != nil {
log.Error().Err(err).Msg("failed to run http redirect server")
}
}()
srv.srv = hsrv
}

1
go.mod
View file

@ -10,7 +10,6 @@ require (
github.com/coreos/go-oidc v2.2.1+incompatible github.com/coreos/go-oidc v2.2.1+incompatible
github.com/envoyproxy/go-control-plane v0.9.5 github.com/envoyproxy/go-control-plane v0.9.5
github.com/fsnotify/fsnotify v1.4.9 github.com/fsnotify/fsnotify v1.4.9
github.com/go-acme/lego/v3 v3.4.0
github.com/go-redis/redis/v7 v7.2.0 github.com/go-redis/redis/v7 v7.2.0
github.com/golang/groupcache v0.0.0-20200121045136-8c9f03a8e57e github.com/golang/groupcache v0.0.0-20200121045136-8c9f03a8e57e
github.com/golang/mock v1.4.3 github.com/golang/mock v1.4.3

View file

@ -96,7 +96,7 @@ func (srv *Server) streamAggregatedResourcesProcessStep(
for typeURL, version := range versions { for typeURL, version := range versions {
// the versions are different, so the envoy config needs to be updated // the versions are different, so the envoy config needs to be updated
if version != fmt.Sprint(current.version) { if version != fmt.Sprint(current.version) {
res, err := srv.buildDiscoveryResponse(fmt.Sprint(current.version), typeURL, current.Options) res, err := srv.buildDiscoveryResponse(fmt.Sprint(current.version), typeURL, &current.Options)
if err != nil { if err != nil {
return err return err
} }

View file

@ -3,8 +3,6 @@ package controlplane
import ( import (
"fmt" "fmt"
"net" "net"
"os"
"path/filepath"
"strconv" "strconv"
"github.com/pomerium/pomerium/config" "github.com/pomerium/pomerium/config"
@ -19,7 +17,7 @@ import (
"google.golang.org/grpc/status" "google.golang.org/grpc/status"
) )
func (srv *Server) buildDiscoveryResponse(version string, typeURL string, options config.Options) (*envoy_service_discovery_v3.DiscoveryResponse, error) { func (srv *Server) buildDiscoveryResponse(version string, typeURL string, options *config.Options) (*envoy_service_discovery_v3.DiscoveryResponse, error) {
switch typeURL { switch typeURL {
case "type.googleapis.com/envoy.config.listener.v3.Listener": case "type.googleapis.com/envoy.config.listener.v3.Listener":
listeners := srv.buildListeners(options) listeners := srv.buildListeners(options)
@ -56,7 +54,7 @@ func (srv *Server) buildDiscoveryResponse(version string, typeURL string, option
} }
} }
func (srv *Server) buildAccessLogs(options config.Options) []*envoy_config_accesslog_v3.AccessLog { func (srv *Server) buildAccessLogs(options *config.Options) []*envoy_config_accesslog_v3.AccessLog {
lvl := options.ProxyLogLevel lvl := options.ProxyLogLevel
if lvl == "" { if lvl == "" {
lvl = options.LogLevel lvl = options.LogLevel
@ -112,10 +110,10 @@ func buildAddress(hostport string, defaultPort int) *envoy_config_core_v3.Addres
} }
} }
func getAbsoluteFilePath(filename string) string { func inlineBytes(bs []byte) *envoy_config_core_v3.DataSource {
if filepath.IsAbs(filename) { return &envoy_config_core_v3.DataSource{
return filename Specifier: &envoy_config_core_v3.DataSource_InlineBytes{
InlineBytes: bs,
},
} }
wd, _ := os.Getwd()
return filepath.Join(wd, filename)
} }

View file

@ -15,7 +15,7 @@ import (
"github.com/pomerium/pomerium/internal/urlutil" "github.com/pomerium/pomerium/internal/urlutil"
) )
func (srv *Server) buildClusters(options config.Options) []*envoy_config_cluster_v3.Cluster { func (srv *Server) buildClusters(options *config.Options) []*envoy_config_cluster_v3.Cluster {
grpcURL := &url.URL{ grpcURL := &url.URL{
Scheme: "grpc", Scheme: "grpc",
Host: srv.GRPCListener.Addr().String(), Host: srv.GRPCListener.Addr().String(),

View file

@ -1,7 +1,9 @@
package controlplane package controlplane
import ( import (
"encoding/base64" "bytes"
"crypto/x509"
"encoding/pem"
"sort" "sort"
envoy_config_core_v3 "github.com/envoyproxy/go-control-plane/envoy/config/core/v3" envoy_config_core_v3 "github.com/envoyproxy/go-control-plane/envoy/config/core/v3"
@ -14,8 +16,11 @@ import (
envoy_type_v3 "github.com/envoyproxy/go-control-plane/envoy/type/v3" envoy_type_v3 "github.com/envoyproxy/go-control-plane/envoy/type/v3"
"github.com/golang/protobuf/ptypes" "github.com/golang/protobuf/ptypes"
"github.com/golang/protobuf/ptypes/any" "github.com/golang/protobuf/ptypes/any"
"google.golang.org/protobuf/types/known/emptypb"
"github.com/pomerium/pomerium/config" "github.com/pomerium/pomerium/config"
"github.com/pomerium/pomerium/internal/cryptutil"
"github.com/pomerium/pomerium/internal/log"
"github.com/pomerium/pomerium/internal/urlutil" "github.com/pomerium/pomerium/internal/urlutil"
) )
@ -29,11 +34,11 @@ func init() {
}) })
} }
func (srv *Server) buildListeners(options config.Options) []*envoy_config_listener_v3.Listener { func (srv *Server) buildListeners(options *config.Options) []*envoy_config_listener_v3.Listener {
var listeners []*envoy_config_listener_v3.Listener var listeners []*envoy_config_listener_v3.Listener
if config.IsAuthenticate(options.Services) || config.IsProxy(options.Services) { if config.IsAuthenticate(options.Services) || config.IsProxy(options.Services) {
listeners = append(listeners, srv.buildHTTPListener(options)) listeners = append(listeners, srv.buildMainListener(options))
} }
if config.IsAuthorize(options.Services) || config.IsCache(options.Services) { if config.IsAuthorize(options.Services) || config.IsCache(options.Services) {
@ -43,22 +48,77 @@ func (srv *Server) buildListeners(options config.Options) []*envoy_config_listen
return listeners return listeners
} }
func (srv *Server) buildHTTPListener(options config.Options) *envoy_config_listener_v3.Listener { func (srv *Server) buildMainListener(options *config.Options) *envoy_config_listener_v3.Listener {
defaultPort := 80 if options.InsecureServer {
var transportSocket *envoy_config_core_v3.TransportSocket filter := srv.buildMainHTTPConnectionManagerFilter(options,
if !options.InsecureServer { srv.getAllRouteableDomains(options, options.Addr))
defaultPort = 443
tlsConfig, _ := ptypes.MarshalAny(srv.buildDownstreamTLSContext(options)) return &envoy_config_listener_v3.Listener{
transportSocket = &envoy_config_core_v3.TransportSocket{ Name: "http-ingress",
Address: buildAddress(options.Addr, 80),
FilterChains: []*envoy_config_listener_v3.FilterChain{{
Filters: []*envoy_config_listener_v3.Filter{
filter,
},
}},
}
}
tlsInspectorCfg, _ := ptypes.MarshalAny(new(emptypb.Empty))
li := &envoy_config_listener_v3.Listener{
Name: "https-ingress",
Address: buildAddress(options.Addr, 443),
ListenerFilters: []*envoy_config_listener_v3.ListenerFilter{{
Name: "envoy.filters.listener.tls_inspector",
ConfigType: &envoy_config_listener_v3.ListenerFilter_TypedConfig{
TypedConfig: tlsInspectorCfg,
},
}},
FilterChains: srv.buildFilterChains(options, options.Addr,
func(tlsDomain string, httpDomains []string) *envoy_config_listener_v3.FilterChain {
filter := srv.buildMainHTTPConnectionManagerFilter(options, httpDomains)
filterChain := &envoy_config_listener_v3.FilterChain{
Filters: []*envoy_config_listener_v3.Filter{filter},
}
if tlsDomain != "*" {
filterChain.FilterChainMatch = &envoy_config_listener_v3.FilterChainMatch{
ServerNames: []string{tlsDomain},
}
}
tlsContext := srv.buildDownstreamTLSContext(options, tlsDomain)
if tlsContext != nil {
tlsConfig, _ := ptypes.MarshalAny(tlsContext)
filterChain.TransportSocket = &envoy_config_core_v3.TransportSocket{
Name: "tls", Name: "tls",
ConfigType: &envoy_config_core_v3.TransportSocket_TypedConfig{ ConfigType: &envoy_config_core_v3.TransportSocket_TypedConfig{
TypedConfig: tlsConfig, TypedConfig: tlsConfig,
}, },
} }
} }
return filterChain
}),
}
return li
}
func (srv *Server) buildFilterChains(
options *config.Options, addr string,
callback func(tlsDomain string, httpDomains []string) *envoy_config_listener_v3.FilterChain,
) []*envoy_config_listener_v3.FilterChain {
allDomains := srv.getAllRouteableDomains(options, addr)
var chains []*envoy_config_listener_v3.FilterChain
for _, domain := range allDomains {
// first we match on SNI
chains = append(chains, callback(domain, []string{domain}))
}
// if there are no SNI matches we match on HTTP host
chains = append(chains, callback("*", allDomains))
return chains
}
func (srv *Server) buildMainHTTPConnectionManagerFilter(options *config.Options, domains []string) *envoy_config_listener_v3.Filter {
var virtualHosts []*envoy_config_route_v3.VirtualHost var virtualHosts []*envoy_config_route_v3.VirtualHost
for _, domain := range srv.getAllRouteableDomains(options, options.Addr) { for _, domain := range domains {
vh := &envoy_config_route_v3.VirtualHost{ vh := &envoy_config_route_v3.VirtualHost{
Name: domain, Name: domain,
Domains: []string{domain}, Domains: []string{domain},
@ -142,38 +202,66 @@ func (srv *Server) buildHTTPListener(options config.Options) *envoy_config_liste
AccessLog: srv.buildAccessLogs(options), AccessLog: srv.buildAccessLogs(options),
}) })
li := &envoy_config_listener_v3.Listener{ return &envoy_config_listener_v3.Filter{
Name: "http-ingress",
Address: buildAddress(options.Addr, defaultPort),
FilterChains: []*envoy_config_listener_v3.FilterChain{{
Filters: []*envoy_config_listener_v3.Filter{
{
Name: "envoy.filters.network.http_connection_manager", Name: "envoy.filters.network.http_connection_manager",
ConfigType: &envoy_config_listener_v3.Filter_TypedConfig{ ConfigType: &envoy_config_listener_v3.Filter_TypedConfig{
TypedConfig: tc, TypedConfig: tc,
}, },
},
},
TransportSocket: transportSocket,
}},
} }
return li
} }
func (srv *Server) buildGRPCListener(options config.Options) *envoy_config_listener_v3.Listener { func (srv *Server) buildGRPCListener(options *config.Options) *envoy_config_listener_v3.Listener {
defaultPort := 80 filter := srv.buildGRPCHTTPConnectionManagerFilter()
var transportSocket *envoy_config_core_v3.TransportSocket
if !options.GRPCInsecure { if options.GRPCInsecure {
defaultPort = 443 return &envoy_config_listener_v3.Listener{
tlsConfig, _ := ptypes.MarshalAny(srv.buildDownstreamTLSContext(options)) Name: "grpc-ingress",
transportSocket = &envoy_config_core_v3.TransportSocket{ Address: buildAddress(options.GRPCAddr, 80),
FilterChains: []*envoy_config_listener_v3.FilterChain{{
Filters: []*envoy_config_listener_v3.Filter{
filter,
},
}},
}
}
tlsInspectorCfg, _ := ptypes.MarshalAny(new(emptypb.Empty))
li := &envoy_config_listener_v3.Listener{
Name: "grpc-ingress",
Address: buildAddress(options.GRPCAddr, 443),
ListenerFilters: []*envoy_config_listener_v3.ListenerFilter{{
Name: "envoy.filters.listener.tls_inspector",
ConfigType: &envoy_config_listener_v3.ListenerFilter_TypedConfig{
TypedConfig: tlsInspectorCfg,
},
}},
FilterChains: srv.buildFilterChains(options, options.Addr,
func(tlsDomain string, httpDomains []string) *envoy_config_listener_v3.FilterChain {
filterChain := &envoy_config_listener_v3.FilterChain{
Filters: []*envoy_config_listener_v3.Filter{filter},
}
if tlsDomain != "*" {
filterChain.FilterChainMatch = &envoy_config_listener_v3.FilterChainMatch{
ServerNames: []string{tlsDomain},
}
}
tlsContext := srv.buildDownstreamTLSContext(options, tlsDomain)
if tlsContext != nil {
tlsConfig, _ := ptypes.MarshalAny(tlsContext)
filterChain.TransportSocket = &envoy_config_core_v3.TransportSocket{
Name: "tls", Name: "tls",
ConfigType: &envoy_config_core_v3.TransportSocket_TypedConfig{ ConfigType: &envoy_config_core_v3.TransportSocket_TypedConfig{
TypedConfig: tlsConfig, TypedConfig: tlsConfig,
}, },
} }
} }
return filterChain
}),
}
return li
}
func (srv *Server) buildGRPCHTTPConnectionManagerFilter() *envoy_config_listener_v3.Filter {
tc, _ := ptypes.MarshalAny(&envoy_http_connection_manager.HttpConnectionManager{ tc, _ := ptypes.MarshalAny(&envoy_http_connection_manager.HttpConnectionManager{
CodecType: envoy_http_connection_manager.HttpConnectionManager_AUTO, CodecType: envoy_http_connection_manager.HttpConnectionManager_AUTO,
StatPrefix: "grpc_ingress", StatPrefix: "grpc_ingress",
@ -191,7 +279,9 @@ func (srv *Server) buildGRPCListener(options config.Options) *envoy_config_liste
}, },
Action: &envoy_config_route_v3.Route_Route{ Action: &envoy_config_route_v3.Route_Route{
Route: &envoy_config_route_v3.RouteAction{ Route: &envoy_config_route_v3.RouteAction{
ClusterSpecifier: &envoy_config_route_v3.RouteAction_Cluster{Cluster: "pomerium-control-plane-grpc"}, ClusterSpecifier: &envoy_config_route_v3.RouteAction_Cluster{
Cluster: "pomerium-control-plane-grpc",
},
}, },
}, },
}}, }},
@ -202,64 +292,57 @@ func (srv *Server) buildGRPCListener(options config.Options) *envoy_config_liste
Name: "envoy.filters.http.router", Name: "envoy.filters.http.router",
}}, }},
}) })
return &envoy_config_listener_v3.Filter{
return &envoy_config_listener_v3.Listener{
Name: "grpc-ingress",
Address: buildAddress(options.GRPCAddr, defaultPort),
FilterChains: []*envoy_config_listener_v3.FilterChain{{
Filters: []*envoy_config_listener_v3.Filter{{
Name: "envoy.filters.network.http_connection_manager", Name: "envoy.filters.network.http_connection_manager",
ConfigType: &envoy_config_listener_v3.Filter_TypedConfig{ ConfigType: &envoy_config_listener_v3.Filter_TypedConfig{
TypedConfig: tc, TypedConfig: tc,
}, },
}},
TransportSocket: transportSocket,
}},
} }
} }
func (srv *Server) buildDownstreamTLSContext(options config.Options) *envoy_extensions_transport_sockets_tls_v3.DownstreamTlsContext { func (srv *Server) buildDownstreamTLSContext(options *config.Options, domain string) *envoy_extensions_transport_sockets_tls_v3.DownstreamTlsContext {
var cert envoy_extensions_transport_sockets_tls_v3.TlsCertificate cert, err := cryptutil.GetCertificateForDomain(options.Certificates, domain)
if options.Cert != "" { if err != nil {
bs, _ := base64.StdEncoding.DecodeString(options.Cert) log.Warn().Str("domain", domain).Err(err).Msg("failed to get certificate for domain")
cert.CertificateChain = &envoy_config_core_v3.DataSource{ return nil
Specifier: &envoy_config_core_v3.DataSource_InlineBytes{
InlineBytes: bs,
},
} }
envoyCert := &envoy_extensions_transport_sockets_tls_v3.TlsCertificate{}
var chain bytes.Buffer
for _, cbs := range cert.Certificate {
_ = pem.Encode(&chain, &pem.Block{
Type: "CERTIFICATE",
Bytes: cbs,
})
}
envoyCert.CertificateChain = inlineBytes(chain.Bytes())
if cert.OCSPStaple != nil {
envoyCert.OcspStaple = inlineBytes(cert.OCSPStaple)
}
if bs, err := x509.MarshalPKCS8PrivateKey(cert.PrivateKey); err == nil {
envoyCert.PrivateKey = inlineBytes(pem.EncodeToMemory(
&pem.Block{
Type: "PRIVATE KEY",
Bytes: bs,
},
))
} else { } else {
cert.CertificateChain = &envoy_config_core_v3.DataSource{ log.Warn().Err(err).Msg("failed to marshal private key for tls config")
Specifier: &envoy_config_core_v3.DataSource_Filename{
Filename: getAbsoluteFilePath(options.CertFile),
},
}
}
if options.Key != "" {
bs, _ := base64.StdEncoding.DecodeString(options.Key)
cert.PrivateKey = &envoy_config_core_v3.DataSource{
Specifier: &envoy_config_core_v3.DataSource_InlineBytes{
InlineBytes: bs,
},
}
} else {
cert.PrivateKey = &envoy_config_core_v3.DataSource{
Specifier: &envoy_config_core_v3.DataSource_Filename{
Filename: getAbsoluteFilePath(options.KeyFile),
},
} }
for _, scts := range cert.SignedCertificateTimestamps {
envoyCert.SignedCertificateTimestamp = append(envoyCert.SignedCertificateTimestamp,
inlineBytes(scts))
} }
return &envoy_extensions_transport_sockets_tls_v3.DownstreamTlsContext{ return &envoy_extensions_transport_sockets_tls_v3.DownstreamTlsContext{
CommonTlsContext: &envoy_extensions_transport_sockets_tls_v3.CommonTlsContext{ CommonTlsContext: &envoy_extensions_transport_sockets_tls_v3.CommonTlsContext{
TlsCertificates: []*envoy_extensions_transport_sockets_tls_v3.TlsCertificate{ TlsCertificates: []*envoy_extensions_transport_sockets_tls_v3.TlsCertificate{envoyCert},
&cert,
},
AlpnProtocols: []string{"h2", "http/1.1"}, AlpnProtocols: []string{"h2", "http/1.1"},
}, },
} }
} }
func (srv *Server) getAllRouteableDomains(options config.Options, addr string) []string { func (srv *Server) getAllRouteableDomains(options *config.Options, addr string) []string {
lookup := map[string]struct{}{} lookup := map[string]struct{}{}
if config.IsAuthenticate(options.Services) && addr == options.Addr { if config.IsAuthenticate(options.Services) && addr == options.Addr {
lookup[urlutil.StripPort(options.AuthenticateURL.Host)] = struct{}{} lookup[urlutil.StripPort(options.AuthenticateURL.Host)] = struct{}{}

View file

@ -37,7 +37,7 @@ func (srv *Server) buildGRPCRoutes() []*envoy_config_route_v3.Route {
}} }}
} }
func (srv *Server) buildPomeriumHTTPRoutes(options config.Options, domain string) []*envoy_config_route_v3.Route { func (srv *Server) buildPomeriumHTTPRoutes(options *config.Options, domain string) []*envoy_config_route_v3.Route {
routes := []*envoy_config_route_v3.Route{ routes := []*envoy_config_route_v3.Route{
srv.buildControlPlanePathRoute("/ping"), srv.buildControlPlanePathRoute("/ping"),
srv.buildControlPlanePathRoute("/healthz"), srv.buildControlPlanePathRoute("/healthz"),
@ -95,7 +95,7 @@ func (srv *Server) buildControlPlanePrefixRoute(prefix string) *envoy_config_rou
} }
} }
func (srv *Server) buildPolicyRoutes(options config.Options, domain string) []*envoy_config_route_v3.Route { func (srv *Server) buildPolicyRoutes(options *config.Options, domain string) []*envoy_config_route_v3.Route {
var routes []*envoy_config_route_v3.Route var routes []*envoy_config_route_v3.Route
for i, policy := range options.Policies { for i, policy := range options.Policies {
if policy.Source.Hostname() != domain { if policy.Source.Hostname() != domain {

View file

@ -2,17 +2,23 @@ package cryptutil
import ( import (
"crypto/ecdsa" "crypto/ecdsa"
"crypto/rand"
"crypto/rsa"
"crypto/tls" "crypto/tls"
"crypto/x509" "crypto/x509"
"crypto/x509/pkix"
"encoding/base64" "encoding/base64"
"encoding/pem" "encoding/pem"
"errors" "errors"
"fmt" "fmt"
"io/ioutil" "io/ioutil"
"math/big"
"net"
"time"
) )
// CertifcateFromBase64 returns an X509 pair from a base64 encoded blob. // CertificateFromBase64 returns an X509 pair from a base64 encoded blob.
func CertifcateFromBase64(cert, key string) (*tls.Certificate, error) { func CertificateFromBase64(cert, key string) (*tls.Certificate, error) {
decodedCert, err := base64.StdEncoding.DecodeString(cert) decodedCert, err := base64.StdEncoding.DecodeString(cert)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to decode certificate cert %v: %w", decodedCert, err) return nil, fmt.Errorf("failed to decode certificate cert %v: %w", decodedCert, err)
@ -135,3 +141,59 @@ func EncodePrivateKey(key *ecdsa.PrivateKey) ([]byte, error) {
return pem.EncodeToMemory(keyBlock), nil return pem.EncodeToMemory(keyBlock), nil
} }
// GenerateSelfSignedCertificate generates a self-signed TLS certificate.
//
// mostly copied from https://golang.org/src/crypto/tls/generate_cert.go
func GenerateSelfSignedCertificate(domain string) (*tls.Certificate, error) {
privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
return nil, fmt.Errorf("failed to geneate private key: %w", err)
}
serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128)
serialNumber, err := rand.Int(rand.Reader, serialNumberLimit)
if err != nil {
return nil, fmt.Errorf("failed to generate serial number: %w", err)
}
template := x509.Certificate{
SerialNumber: serialNumber,
Subject: pkix.Name{
Organization: []string{"Pomerium"},
},
NotBefore: time.Now().Add(-time.Minute * 10),
NotAfter: time.Now().Add(time.Hour * 24 * 365),
KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
BasicConstraintsValid: true,
}
if ip := net.ParseIP(domain); ip != nil {
template.IPAddresses = append(template.IPAddresses, ip)
} else {
template.DNSNames = append(template.DNSNames, domain)
}
publicKeyBytes, err := x509.CreateCertificate(rand.Reader,
&template, &template,
privateKey.Public(), privateKey,
)
if err != nil {
return nil, fmt.Errorf("failed to create certificate: %w", err)
}
privateKeyBytes, err := x509.MarshalPKCS8PrivateKey(privateKey)
if err != nil {
return nil, fmt.Errorf("failed to marshal private key: %w", err)
}
cert, err := tls.X509KeyPair(
pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: publicKeyBytes}),
pem.EncodeToMemory(&pem.Block{Type: "PRIVATE KEY", Bytes: privateKeyBytes}),
)
if err != nil {
return nil, fmt.Errorf("failed to convert x509 bytes into tls certificate: %w", err)
}
return &cert, nil
}

View file

@ -55,9 +55,9 @@ func TestCertifcateFromBase64(t *testing.T) {
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
_, err := CertifcateFromBase64(tt.cert, tt.key) _, err := CertificateFromBase64(tt.cert, tt.key)
if (err != nil) != tt.wantErr { if (err != nil) != tt.wantErr {
t.Errorf("CertifcateFromBase64() error = %v, wantErr %v", err, tt.wantErr) t.Errorf("CertificateFromBase64() error = %v, wantErr %v", err, tt.wantErr)
return return
} }
}) })

View file

@ -1,103 +1,51 @@
package cryptutil package cryptutil
import ( import (
"context"
"crypto/tls" "crypto/tls"
"fmt" "crypto/x509"
"net/http"
"github.com/caddyserver/certmagic" "github.com/caddyserver/certmagic"
"github.com/go-acme/lego/v3/challenge/tlsalpn01"
) )
// NewAutocert automatically retrieves public certificates from the free // GetCertificateForDomain returns the tls Certificate which matches the given domain name.
// certificate authority Let's Encrypt using HTTP-01 and TLS-ALPN-01 challenges. // It should handle both exact matches and wildcard matches. If none of those match, the first certificate will be used.
// To complete the challenges, the server must be accessible from the internet // Finally if there are no matching certificates one will be generated.
// by port 80 or 443 . func GetCertificateForDomain(certificates []tls.Certificate, domain string) (*tls.Certificate, error) {
// // first try a direct name match
// https://letsencrypt.org/docs/challenge-types/#http-01-challenge for _, cert := range certificates {
// https://letsencrypt.org/docs/challenge-types/#tls-alpn-01 if matchesDomain(&cert, domain) {
func NewAutocert(tlsConfig *tls.Config, hostnames []string, useStaging bool, path string) (*tls.Config, func(h http.Handler) http.Handler, error) { return &cert, nil
certmagic.DefaultACME.Agreed = true
if useStaging {
certmagic.DefaultACME.CA = certmagic.LetsEncryptStagingCA
} }
cm := certmagic.NewDefault()
tlsConfig = newTLSConfigIfEmpty(tlsConfig)
// add existing certs to the cache, and staple OCSP
for _, cert := range tlsConfig.Certificates {
if err := cm.CacheUnmanagedTLSCertificate(cert, nil); err != nil {
return nil, nil, fmt.Errorf("cryptutil: failed caching cert: %w", err)
}
}
cm.Storage = &certmagic.FileStorage{Path: path}
acmeConfig := certmagic.NewACMEManager(cm, certmagic.DefaultACME)
cm.Issuer = acmeConfig
// todo(bdd) : add cancellation context?
if err := cm.ManageAsync(context.TODO(), hostnames); err != nil {
return nil, nil, fmt.Errorf("cryptutil: sync failed: %w", err)
} }
tlsConfig.GetCertificate = cm.GetCertificate // next use the first cert
tlsConfig.NextProtos = append(tlsConfig.NextProtos, tlsalpn01.ACMETLS1Protocol) if len(certificates) > 0 {
tlsConfig.BuildNameToCertificate() return &certificates[0], nil
return tlsConfig, acmeConfig.HTTPChallengeHandler, nil }
// finally fall back to a generated, self-signed certificate
return GenerateSelfSignedCertificate(domain)
} }
// TLSConfigFromBase64 returns an tls configuration from a base64 encoded blob. func matchesDomain(cert *tls.Certificate, domain string) bool {
func TLSConfigFromBase64(tlsConfig *tls.Config, cert, key string) (*tls.Config, error) { if cert == nil || len(cert.Certificate) == 0 {
tlsConfig = newTLSConfigIfEmpty(tlsConfig) return false
c, err := CertifcateFromBase64(cert, key) }
xcert, err := x509.ParseCertificate(cert.Certificate[0])
if err != nil { if err != nil {
return nil, err return false
} }
tlsConfig.Certificates = append(tlsConfig.Certificates, *c)
tlsConfig.BuildNameToCertificate()
return tlsConfig, nil
}
// TLSConfigFromFile returns an tls configuration from a certificate and if certmagic.MatchWildcard(domain, xcert.Subject.CommonName) {
// key file . return true
func TLSConfigFromFile(tlsConfig *tls.Config, cert, key string) (*tls.Config, error) {
tlsConfig = newTLSConfigIfEmpty(tlsConfig)
c, err := CertificateFromFile(cert, key)
if err != nil {
return nil, err
} }
tlsConfig.Certificates = append(tlsConfig.Certificates, *c)
tlsConfig.BuildNameToCertificate()
return tlsConfig, nil
}
// newTLSConfigIfEmpty returns an opinionated TLS configuration if config is nil. for _, san := range xcert.DNSNames {
// See : if certmagic.MatchWildcard(domain, san) {
// https://wiki.mozilla.org/Security/Server_Side_TLS#Recommended_configurations return true
// https://blog.cloudflare.com/exposing-go-on-the-internet/
// https://github.com/ssllabs/research/wiki/SSL-and-TLS-Deployment-Best-Practices
// https://github.com/golang/go/blob/df91b8044dbe790c69c16058330f545be069cc1f/src/crypto/tls/common.go#L919
func newTLSConfigIfEmpty(tlsConfig *tls.Config) *tls.Config {
if tlsConfig != nil {
return tlsConfig
} }
return &tls.Config{
MinVersion: tls.VersionTLS12,
// Prioritize cipher suites sped up by AES-NI (AES-GCM)
CipherSuites: []uint16{
tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384,
tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
tls.TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305,
tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384,
tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
tls.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305,
},
PreferServerCipherSuites: true,
// Use curves which have assembly implementations
CurvePreferences: []tls.CurveID{
tls.X25519,
tls.CurveP256,
},
// HTTP/2 must be enabled manually when using http.Serve
NextProtos: []string{"h2", "http/1.1"},
} }
return false
} }

View file

@ -3,47 +3,61 @@ package cryptutil
import ( import (
"crypto/tls" "crypto/tls"
"testing" "testing"
"github.com/stretchr/testify/assert"
) )
func TestTLSConfigFromBase64(t *testing.T) { func TestGetCertificateForDomain(t *testing.T) {
tests := []struct { gen := func(t *testing.T, domain string) *tls.Certificate {
name string cert, err := GenerateSelfSignedCertificate(domain)
cert string if !assert.NoError(t, err, "error generating certificate for: %s", domain) {
key string t.FailNow()
wantErr bool
}{
{"good",
"LS0tLS1CRUdJTiBDRVJUSUZJQ0FURS0tLS0tCk1JSUVJVENDQWdtZ0F3SUJBZ0lSQVBqTEJxS1lwcWU0ekhQc0dWdFR6T0F3RFFZSktvWklodmNOQVFFTEJRQXcKRWpFUU1BNEdBMVVFQXhNSFoyOXZaQzFqWVRBZUZ3MHhPVEE0TVRBeE9EUTVOREJhRncweU1UQXlNVEF4TnpRdwpNREZhTUJNeEVUQVBCZ05WQkFNVENIQnZiV1Z5YVhWdE1JSUJJakFOQmdrcWhraUc5dzBCQVFFRkFBT0NBUThBCk1JSUJDZ0tDQVFFQTY3S2pxbVFZR3EwTVZ0QUNWcGVDbVhtaW5sUWJEUEdMbXNaQVVFd3VlSFFucnQzV3R2cEQKT202QWxhSk1VblcrSHU1NWpqb2thbEtlVmpUS21nWUdicVV6VkRvTWJQRGFIZWtsdGRCVE1HbE9VRnNQNFVKUwpEck80emROK3pvNDI4VFgyUG5HMkZDZFZLR3k0UEU4aWxIYldMY3I4NzFZalY1MWZ3OENMRFg5UFpKTnU4NjFDCkY3VjlpRUptNnNTZlFsbW5oTjhqMytXelZiUFFOeTFXc1I3aTllOWo2M0VxS3QyMlE5T1hMK1dBY0tza29JU20KQ05WUlVBalU4WVJWY2dRSkIrelEzNEFRUGx6ME9wNU8vUU4vTWVkamFGOHdMUytpdi96dmlTOGNxUGJ4bzZzTApxNkZOVGx0ay9Ra3hlQ2VLS1RRZS8za1BZdlFBZG5sNjVRSURBUUFCbzNFd2J6QU9CZ05WSFE4QkFmOEVCQU1DCkE3Z3dIUVlEVlIwbEJCWXdGQVlJS3dZQkJRVUhBd0VHQ0NzR0FRVUZCd01DTUIwR0ExVWREZ1FXQkJRQ1FYbWIKc0hpcS9UQlZUZVhoQ0dpNjhrVy9DakFmQmdOVkhTTUVHREFXZ0JSNTRKQ3pMRlg0T0RTQ1J0dWNBUGZOdVhWegpuREFOQmdrcWhraUc5dzBCQVFzRkFBT0NBZ0VBcm9XL2trMllleFN5NEhaQXFLNDVZaGQ5ay9QVTFiaDlFK1BRCk5jZFgzTUdEY2NDRUFkc1k4dll3NVE1cnhuMGFzcSt3VGFCcGxoYS9rMi9VVW9IQ1RqUVp1Mk94dEF3UTdPaWIKVE1tMEorU3NWT3d4YnFQTW9rK1RqVE16NFdXaFFUTzVwRmNoZDZXZXNCVHlJNzJ0aG1jcDd1c2NLU2h3YktIegpQY2h1QTQ4SzhPdi96WkxmZnduQVNZb3VCczJjd1ZiRDI3ZXZOMzdoMGFzR1BrR1VXdm1PSDduTHNVeTh3TTdqCkNGL3NwMmJmTC9OYVdNclJnTHZBMGZMS2pwWTQrVEpPbkVxQmxPcCsrbHlJTEZMcC9qMHNybjRNUnlKK0t6UTEKR1RPakVtQ1QvVEFtOS9XSThSL0FlYjcwTjEzTytYNEtaOUJHaDAxTzN3T1Vqd3BZZ3lxSnNoRnNRUG50VmMrSQpKQmF4M2VQU3NicUcwTFkzcHdHUkpRNmMrd1lxdGk2Y0tNTjliYlRkMDhCNUk1N1RRTHhNcUoycTFnWmw1R1VUCmVFZGNWRXltMnZmd0NPd0lrbGNBbThxTm5kZGZKV1FabE5VaHNOVWFBMkVINnlDeXdaZm9aak9hSDEwTXowV20KeTNpZ2NSZFQ3Mi9NR2VkZk93MlV0MVVvRFZmdEcxcysrditUQ1lpNmpUQU05dkZPckJ4UGlOeGFkUENHR2NZZAowakZIc2FWOGFPV1dQQjZBQ1JteHdDVDdRTnRTczM2MlpIOUlFWWR4Q00yMDUrZmluVHhkOUcwSmVRRTd2Kyt6CldoeWo2ZmJBWUIxM2wvN1hkRnpNSW5BOGxpekdrVHB2RHMxeTBCUzlwV3ppYmhqbVFoZGZIejdCZGpGTHVvc2wKZzlNZE5sND0KLS0tLS1FTkQgQ0VSVElGSUNBVEUtLS0tLQo=",
"LS0tLS1CRUdJTiBSU0EgUFJJVkFURSBLRVktLS0tLQpNSUlFcGdJQkFBS0NBUUVBNjdLanFtUVlHcTBNVnRBQ1ZwZUNtWG1pbmxRYkRQR0xtc1pBVUV3dWVIUW5ydDNXCnR2cERPbTZBbGFKTVVuVytIdTU1ampva2FsS2VWalRLbWdZR2JxVXpWRG9NYlBEYUhla2x0ZEJUTUdsT1VGc1AKNFVKU0RyTzR6ZE4rem80MjhUWDJQbkcyRkNkVktHeTRQRThpbEhiV0xjcjg3MVlqVjUxZnc4Q0xEWDlQWkpOdQo4NjFDRjdWOWlFSm02c1NmUWxtbmhOOGozK1d6VmJQUU55MVdzUjdpOWU5ajYzRXFLdDIyUTlPWEwrV0FjS3NrCm9JU21DTlZSVUFqVThZUlZjZ1FKQit6UTM0QVFQbHowT3A1Ty9RTi9NZWRqYUY4d0xTK2l2L3p2aVM4Y3FQYngKbzZzTHE2Rk5UbHRrL1FreGVDZUtLVFFlLzNrUFl2UUFkbmw2NVFJREFRQUJBb0lCQVFEQVQ0eXN2V2pSY3pxcgpKcU9SeGFPQTJEY3dXazJML1JXOFhtQWhaRmRTWHV2MkNQbGxhTU1yelBmTG41WUlmaHQzSDNzODZnSEdZc3pnClo4aWJiYWtYNUdFQ0t5N3lRSDZuZ3hFS3pRVGpiampBNWR3S0h0UFhQUnJmamQ1Y2FMczVpcDcxaWxCWEYxU3IKWERIaXUycnFtaC9kVTArWGRMLzNmK2VnVDl6bFQ5YzRyUm84dnZueWNYejFyMnVhRVZ2VExsWHVsb2NpeEVrcgoySjlTMmxveWFUb2tFTnNlMDNpSVdaWnpNNElZcVowOGJOeG9IWCszQXVlWExIUStzRkRKMlhaVVdLSkZHMHUyClp3R2w3YlZpRTFQNXdiQUdtZzJDeDVCN1MrdGQyUEpSV3Frb2VxY3F2RVdCc3RFL1FEcDFpVThCOHpiQXd0Y3IKZHc5TXZ6Q2hBb0dCQVBObzRWMjF6MGp6MWdEb2tlTVN5d3JnL2E4RkJSM2R2Y0xZbWV5VXkybmd3eHVucnFsdwo2U2IrOWdrOGovcXEvc3VQSDhVdzNqSHNKYXdGSnNvTkVqNCt2b1ZSM3UrbE5sTEw5b21rMXBoU0dNdVp0b3huCm5nbUxVbkJUMGI1M3BURkJ5WGsveE5CbElreWdBNlg5T2MreW5na3RqNlRyVnMxUERTdnVJY0s1QW9HQkFQZmoKcEUzR2F6cVFSemx6TjRvTHZmQWJBdktCZ1lPaFNnemxsK0ZLZkhzYWJGNkdudFd1dWVhY1FIWFpYZTA1c2tLcApXN2xYQ3dqQU1iUXI3QmdlazcrOSszZElwL1RnYmZCYnN3Syt6Vng3Z2doeWMrdytXRWExaHByWTZ6YXdxdkFaCkhRU2lMUEd1UGp5WXBQa1E2ZFdEczNmWHJGZ1dlTmd4SkhTZkdaT05Bb0dCQUt5WTF3MUM2U3Y2c3VuTC8vNTcKQ2Z5NTAwaXlqNUZBOWRqZkRDNWt4K1JZMnlDV0ExVGsybjZyVmJ6dzg4czBTeDMrYS9IQW1CM2dMRXBSRU5NKwo5NHVwcENFWEQ3VHdlcGUxUnlrTStKbmp4TzlDSE41c2J2U25sUnBQWlMvZzJRTVhlZ3grK2trbkhXNG1ITkFyCndqMlRrMXBBczFXbkJ0TG9WaGVyY01jSkFvR0JBSTYwSGdJb0Y5SysvRUcyY21LbUg5SDV1dGlnZFU2eHEwK0IKWE0zMWMzUHE0amdJaDZlN3pvbFRxa2d0dWtTMjBraE45dC9ibkI2TmhnK1N1WGVwSXFWZldVUnlMejVwZE9ESgo2V1BMTTYzcDdCR3cwY3RPbU1NYi9VRm5Yd0U4OHlzRlNnOUF6VjdVVUQvU0lDYkI5ZHRVMWh4SHJJK0pZRWdWCkFrZWd6N2lCQW9HQkFJRncrQVFJZUIwM01UL0lCbGswNENQTDJEak0rNDhoVGRRdjgwMDBIQU9mUWJrMEVZUDEKQ2FLR3RDbTg2MXpBZjBzcS81REtZQ0l6OS9HUzNYRk00Qm1rRk9nY1NXVENPNmZmTGdLM3FmQzN4WDJudlpIOQpYZGNKTDQrZndhY0x4c2JJKzhhUWNOVHRtb3pkUjEzQnNmUmIrSGpUL2o3dkdrYlFnSkhCT0syegotLS0tLUVORCBSU0EgUFJJVkFURSBLRVktLS0tLQo=",
false},
{"bad cert",
"!=",
"LS0tLS1CRUdJTiBSU0EgUFJJVkFURSBLRVktLS0tLQpNSUlFcGdJQkFBS0NBUUVBNjdLanFtUVlHcTBNVnRBQ1ZwZUNtWG1pbmxRYkRQR0xtc1pBVUV3dWVIUW5ydDNXCnR2cERPbTZBbGFKTVVuVytIdTU1ampva2FsS2VWalRLbWdZR2JxVXpWRG9NYlBEYUhla2x0ZEJUTUdsT1VGc1AKNFVKU0RyTzR6ZE4rem80MjhUWDJQbkcyRkNkVktHeTRQRThpbEhiV0xjcjg3MVlqVjUxZnc4Q0xEWDlQWkpOdQo4NjFDRjdWOWlFSm02c1NmUWxtbmhOOGozK1d6VmJQUU55MVdzUjdpOWU5ajYzRXFLdDIyUTlPWEwrV0FjS3NrCm9JU21DTlZSVUFqVThZUlZjZ1FKQit6UTM0QVFQbHowT3A1Ty9RTi9NZWRqYUY4d0xTK2l2L3p2aVM4Y3FQYngKbzZzTHE2Rk5UbHRrL1FreGVDZUtLVFFlLzNrUFl2UUFkbmw2NVFJREFRQUJBb0lCQVFEQVQ0eXN2V2pSY3pxcgpKcU9SeGFPQTJEY3dXazJML1JXOFhtQWhaRmRTWHV2MkNQbGxhTU1yelBmTG41WUlmaHQzSDNzODZnSEdZc3pnClo4aWJiYWtYNUdFQ0t5N3lRSDZuZ3hFS3pRVGpiampBNWR3S0h0UFhQUnJmamQ1Y2FMczVpcDcxaWxCWEYxU3IKWERIaXUycnFtaC9kVTArWGRMLzNmK2VnVDl6bFQ5YzRyUm84dnZueWNYejFyMnVhRVZ2VExsWHVsb2NpeEVrcgoySjlTMmxveWFUb2tFTnNlMDNpSVdaWnpNNElZcVowOGJOeG9IWCszQXVlWExIUStzRkRKMlhaVVdLSkZHMHUyClp3R2w3YlZpRTFQNXdiQUdtZzJDeDVCN1MrdGQyUEpSV3Frb2VxY3F2RVdCc3RFL1FEcDFpVThCOHpiQXd0Y3IKZHc5TXZ6Q2hBb0dCQVBObzRWMjF6MGp6MWdEb2tlTVN5d3JnL2E4RkJSM2R2Y0xZbWV5VXkybmd3eHVucnFsdwo2U2IrOWdrOGovcXEvc3VQSDhVdzNqSHNKYXdGSnNvTkVqNCt2b1ZSM3UrbE5sTEw5b21rMXBoU0dNdVp0b3huCm5nbUxVbkJUMGI1M3BURkJ5WGsveE5CbElreWdBNlg5T2MreW5na3RqNlRyVnMxUERTdnVJY0s1QW9HQkFQZmoKcEUzR2F6cVFSemx6TjRvTHZmQWJBdktCZ1lPaFNnemxsK0ZLZkhzYWJGNkdudFd1dWVhY1FIWFpYZTA1c2tLcApXN2xYQ3dqQU1iUXI3QmdlazcrOSszZElwL1RnYmZCYnN3Syt6Vng3Z2doeWMrdytXRWExaHByWTZ6YXdxdkFaCkhRU2lMUEd1UGp5WXBQa1E2ZFdEczNmWHJGZ1dlTmd4SkhTZkdaT05Bb0dCQUt5WTF3MUM2U3Y2c3VuTC8vNTcKQ2Z5NTAwaXlqNUZBOWRqZkRDNWt4K1JZMnlDV0ExVGsybjZyVmJ6dzg4czBTeDMrYS9IQW1CM2dMRXBSRU5NKwo5NHVwcENFWEQ3VHdlcGUxUnlrTStKbmp4TzlDSE41c2J2U25sUnBQWlMvZzJRTVhlZ3grK2trbkhXNG1ITkFyCndqMlRrMXBBczFXbkJ0TG9WaGVyY01jSkFvR0JBSTYwSGdJb0Y5SysvRUcyY21LbUg5SDV1dGlnZFU2eHEwK0IKWE0zMWMzUHE0amdJaDZlN3pvbFRxa2d0dWtTMjBraE45dC9ibkI2TmhnK1N1WGVwSXFWZldVUnlMejVwZE9ESgo2V1BMTTYzcDdCR3cwY3RPbU1NYi9VRm5Yd0U4OHlzRlNnOUF6VjdVVUQvU0lDYkI5ZHRVMWh4SHJJK0pZRWdWCkFrZWd6N2lCQW9HQkFJRncrQVFJZUIwM01UL0lCbGswNENQTDJEak0rNDhoVGRRdjgwMDBIQU9mUWJrMEVZUDEKQ2FLR3RDbTg2MXpBZjBzcS81REtZQ0l6OS9HUzNYRk00Qm1rRk9nY1NXVENPNmZmTGdLM3FmQzN4WDJudlpIOQpYZGNKTDQrZndhY0x4c2JJKzhhUWNOVHRtb3pkUjEzQnNmUmIrSGpUL2o3dkdrYlFnSkhCT0syegotLS0tLUVORCBSU0EgUFJJVkFURSBLRVktLS0tLQo=",
true},
{"bad key",
"LS0tLS1CRUdJTiBDRVJUSUZJQ0FURS0tLS0tCk1JSUVJVENDQWdtZ0F3SUJBZ0lSQVBqTEJxS1lwcWU0ekhQc0dWdFR6T0F3RFFZSktvWklodmNOQVFFTEJRQXcKRWpFUU1BNEdBMVVFQXhNSFoyOXZaQzFqWVRBZUZ3MHhPVEE0TVRBeE9EUTVOREJhRncweU1UQXlNVEF4TnpRdwpNREZhTUJNeEVUQVBCZ05WQkFNVENIQnZiV1Z5YVhWdE1JSUJJakFOQmdrcWhraUc5dzBCQVFFRkFBT0NBUThBCk1JSUJDZ0tDQVFFQTY3S2pxbVFZR3EwTVZ0QUNWcGVDbVhtaW5sUWJEUEdMbXNaQVVFd3VlSFFucnQzV3R2cEQKT202QWxhSk1VblcrSHU1NWpqb2thbEtlVmpUS21nWUdicVV6VkRvTWJQRGFIZWtsdGRCVE1HbE9VRnNQNFVKUwpEck80emROK3pvNDI4VFgyUG5HMkZDZFZLR3k0UEU4aWxIYldMY3I4NzFZalY1MWZ3OENMRFg5UFpKTnU4NjFDCkY3VjlpRUptNnNTZlFsbW5oTjhqMytXelZiUFFOeTFXc1I3aTllOWo2M0VxS3QyMlE5T1hMK1dBY0tza29JU20KQ05WUlVBalU4WVJWY2dRSkIrelEzNEFRUGx6ME9wNU8vUU4vTWVkamFGOHdMUytpdi96dmlTOGNxUGJ4bzZzTApxNkZOVGx0ay9Ra3hlQ2VLS1RRZS8za1BZdlFBZG5sNjVRSURBUUFCbzNFd2J6QU9CZ05WSFE4QkFmOEVCQU1DCkE3Z3dIUVlEVlIwbEJCWXdGQVlJS3dZQkJRVUhBd0VHQ0NzR0FRVUZCd01DTUIwR0ExVWREZ1FXQkJRQ1FYbWIKc0hpcS9UQlZUZVhoQ0dpNjhrVy9DakFmQmdOVkhTTUVHREFXZ0JSNTRKQ3pMRlg0T0RTQ1J0dWNBUGZOdVhWegpuREFOQmdrcWhraUc5dzBCQVFzRkFBT0NBZ0VBcm9XL2trMllleFN5NEhaQXFLNDVZaGQ5ay9QVTFiaDlFK1BRCk5jZFgzTUdEY2NDRUFkc1k4dll3NVE1cnhuMGFzcSt3VGFCcGxoYS9rMi9VVW9IQ1RqUVp1Mk94dEF3UTdPaWIKVE1tMEorU3NWT3d4YnFQTW9rK1RqVE16NFdXaFFUTzVwRmNoZDZXZXNCVHlJNzJ0aG1jcDd1c2NLU2h3YktIegpQY2h1QTQ4SzhPdi96WkxmZnduQVNZb3VCczJjd1ZiRDI3ZXZOMzdoMGFzR1BrR1VXdm1PSDduTHNVeTh3TTdqCkNGL3NwMmJmTC9OYVdNclJnTHZBMGZMS2pwWTQrVEpPbkVxQmxPcCsrbHlJTEZMcC9qMHNybjRNUnlKK0t6UTEKR1RPakVtQ1QvVEFtOS9XSThSL0FlYjcwTjEzTytYNEtaOUJHaDAxTzN3T1Vqd3BZZ3lxSnNoRnNRUG50VmMrSQpKQmF4M2VQU3NicUcwTFkzcHdHUkpRNmMrd1lxdGk2Y0tNTjliYlRkMDhCNUk1N1RRTHhNcUoycTFnWmw1R1VUCmVFZGNWRXltMnZmd0NPd0lrbGNBbThxTm5kZGZKV1FabE5VaHNOVWFBMkVINnlDeXdaZm9aak9hSDEwTXowV20KeTNpZ2NSZFQ3Mi9NR2VkZk93MlV0MVVvRFZmdEcxcysrditUQ1lpNmpUQU05dkZPckJ4UGlOeGFkUENHR2NZZAowakZIc2FWOGFPV1dQQjZBQ1JteHdDVDdRTnRTczM2MlpIOUlFWWR4Q00yMDUrZmluVHhkOUcwSmVRRTd2Kyt6CldoeWo2ZmJBWUIxM2wvN1hkRnpNSW5BOGxpekdrVHB2RHMxeTBCUzlwV3ppYmhqbVFoZGZIejdCZGpGTHVvc2wKZzlNZE5sND0KLS0tLS1FTkQgQ0VSVElGSUNBVEUtLS0tLQo=",
"!=",
true},
} }
for _, tt := range tests { return cert
t.Run(tt.name, func(t *testing.T) { }
_, err := TLSConfigFromBase64(nil, tt.cert, tt.key)
if (err != nil) != tt.wantErr { t.Run("exact match", func(t *testing.T) {
t.Errorf("TLSConfigFromBase64() error = %v, wantErr %v", err, tt.wantErr) certs := []tls.Certificate{
*gen(t, "a.example.com"),
*gen(t, "b.example.com"),
}
found, err := GetCertificateForDomain(certs, "b.example.com")
if !assert.NoError(t, err) {
return return
} }
assert.Equal(t, &certs[1], found)
}) })
t.Run("wildcard match", func(t *testing.T) {
certs := []tls.Certificate{
*gen(t, "a.example.com"),
*gen(t, "*.example.com"),
} }
}
func TestTLSConfigFromFile(t *testing.T) { found, err := GetCertificateForDomain(certs, "b.example.com")
cfg, err := TLSConfigFromFile(nil, "testdata/example-cert.pem", "testdata/example-key.pem") if !assert.NoError(t, err) {
if err != nil { return
t.Fatal(err)
} }
listener, err := tls.Listen("tcp", ":0", cfg) assert.Equal(t, &certs[1], found)
if err != nil { })
t.Fatal(err) t.Run("no name match", func(t *testing.T) {
certs := []tls.Certificate{
*gen(t, "a.example.com"),
} }
_ = listener
found, err := GetCertificateForDomain(certs, "b.example.com")
if !assert.NoError(t, err) {
return
}
assert.Equal(t, &certs[0], found)
})
t.Run("generate", func(t *testing.T) {
certs := []tls.Certificate{}
found, err := GetCertificateForDomain(certs, "b.example.com")
if !assert.NoError(t, err) {
return
}
assert.NotNil(t, found)
})
} }

View file

@ -15,8 +15,9 @@ import (
"strings" "strings"
"github.com/natefinch/atomic" "github.com/natefinch/atomic"
"github.com/pomerium/pomerium/internal/log"
"github.com/rs/zerolog" "github.com/rs/zerolog"
"github.com/pomerium/pomerium/internal/log"
) )
const ( const (