mirror of
https://github.com/pomerium/pomerium.git
synced 2025-04-29 18:36:30 +02:00
envoy: support autocert (#695)
* envoy: support autocert * envoy: fallback to http host routing if sni fails to match * update comment * envoy: renew certs when necessary * fix tests
This commit is contained in:
parent
0c1ac5a575
commit
dccec1e646
18 changed files with 689 additions and 391 deletions
|
@ -5,8 +5,10 @@ import (
|
|||
"flag"
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
"os/signal"
|
||||
"sync"
|
||||
"time"
|
||||
"syscall"
|
||||
|
||||
"github.com/pomerium/pomerium/authenticate"
|
||||
"github.com/pomerium/pomerium/authorize"
|
||||
|
@ -24,7 +26,6 @@ import (
|
|||
"github.com/pomerium/pomerium/proxy"
|
||||
|
||||
envoy_service_auth_v2 "github.com/envoyproxy/go-control-plane/envoy/service/auth/v2"
|
||||
"github.com/fsnotify/fsnotify"
|
||||
"golang.org/x/sync/errgroup"
|
||||
)
|
||||
|
||||
|
@ -32,12 +33,12 @@ var versionFlag = flag.Bool("version", false, "prints the version")
|
|||
var configFile = flag.String("config", "", "Specify configuration file location")
|
||||
|
||||
func main() {
|
||||
if err := run(); err != nil {
|
||||
if err := run(context.Background()); err != nil {
|
||||
log.Fatal().Err(err).Msg("cmd/pomerium")
|
||||
}
|
||||
}
|
||||
|
||||
func run() error {
|
||||
func run(ctx context.Context) error {
|
||||
flag.Parse()
|
||||
if *versionFlag {
|
||||
fmt.Println(version.FullVersion())
|
||||
|
@ -51,18 +52,12 @@ func run() error {
|
|||
|
||||
log.Info().Str("version", version.FullVersion()).Msg("cmd/pomerium")
|
||||
|
||||
var wg sync.WaitGroup
|
||||
if err := setupMetrics(opt, &wg); err != nil {
|
||||
if err := setupMetrics(opt); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := setupTracing(opt); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := setupHTTPRedirectServer(opt, &wg); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// setup the control plane
|
||||
controlPlane, err := controlplane.NewServer()
|
||||
|
@ -99,17 +94,19 @@ func run() error {
|
|||
}
|
||||
|
||||
// start the config change listener
|
||||
opt.OnConfigChange(func(e fsnotify.Event) {
|
||||
log.Info().Str("file", e.Name).Msg("cmd/pomerium: config file changed")
|
||||
opt = config.HandleConfigUpdate(*configFile, opt, optionsUpdaters)
|
||||
})
|
||||
go config.WatchChanges(*configFile, opt, optionsUpdaters)
|
||||
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
go func() {
|
||||
ch := make(chan os.Signal, 2)
|
||||
signal.Notify(ch, os.Interrupt)
|
||||
signal.Notify(ch, syscall.SIGTERM)
|
||||
<-ch
|
||||
cancel()
|
||||
}()
|
||||
|
||||
// run everything
|
||||
eg, ctx := errgroup.WithContext(ctx)
|
||||
eg.Go(func() error {
|
||||
wg.Wait()
|
||||
return nil
|
||||
})
|
||||
eg.Go(func() error {
|
||||
return controlPlane.Run(ctx)
|
||||
})
|
||||
|
@ -172,7 +169,7 @@ func setupCache(opt *config.Options, controlPlane *controlplane.Server) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func setupMetrics(opt *config.Options, wg *sync.WaitGroup) error {
|
||||
func setupMetrics(opt *config.Options) error {
|
||||
if opt.MetricsAddr != "" {
|
||||
handler, err := metrics.PrometheusHandler()
|
||||
if err != nil {
|
||||
|
@ -185,11 +182,11 @@ func setupMetrics(opt *config.Options, wg *sync.WaitGroup) error {
|
|||
Insecure: true,
|
||||
Service: "metrics",
|
||||
}
|
||||
srv, err := httputil.NewServer(serverOpts, handler, wg)
|
||||
var wg sync.WaitGroup
|
||||
_, err = httputil.NewServer(serverOpts, handler, &wg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
go httputil.Shutdown(srv)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
@ -222,40 +219,3 @@ func setupTracing(opt *config.Options) error {
|
|||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func setupHTTPRedirectServer(opt *config.Options, wg *sync.WaitGroup) error {
|
||||
if opt.HTTPRedirectAddr != "" {
|
||||
serverOpts := httputil.ServerOptions{
|
||||
Addr: opt.HTTPRedirectAddr,
|
||||
Insecure: true,
|
||||
Service: "HTTP->HTTPS Redirect",
|
||||
ReadHeaderTimeout: 5 * time.Second,
|
||||
ReadTimeout: 5 * time.Second,
|
||||
WriteTimeout: 5 * time.Second,
|
||||
IdleTimeout: 5 * time.Second,
|
||||
}
|
||||
h := httputil.RedirectHandler()
|
||||
if opt.AutoCert {
|
||||
h = opt.AutoCertHandler(h)
|
||||
}
|
||||
srv, err := httputil.NewServer(&serverOpts, h, wg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
go httputil.Shutdown(srv)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func httpServerOptions(opt *config.Options) *httputil.ServerOptions {
|
||||
return &httputil.ServerOptions{
|
||||
Addr: opt.Addr,
|
||||
TLSConfig: opt.TLSConfig,
|
||||
Insecure: opt.InsecureServer,
|
||||
ReadTimeout: opt.ReadTimeout,
|
||||
WriteTimeout: opt.WriteTimeout,
|
||||
ReadHeaderTimeout: opt.ReadHeaderTimeout,
|
||||
IdleTimeout: opt.IdleTimeout,
|
||||
Service: opt.Services,
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,36 +1,17 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io/ioutil"
|
||||
"os"
|
||||
"os/signal"
|
||||
"sync"
|
||||
"syscall"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/pomerium/pomerium/config"
|
||||
"github.com/pomerium/pomerium/internal/httputil"
|
||||
)
|
||||
|
||||
func Test_httpServerOptions(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
opt *config.Options
|
||||
want *httputil.ServerOptions
|
||||
}{
|
||||
{"simple convert", &config.Options{Addr: ":80"}, &httputil.ServerOptions{Addr: ":80"}},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if diff := cmp.Diff(httpServerOptions(tt.opt), tt.want); diff != "" {
|
||||
t.Errorf("httpServerOptions() = \n %s", diff)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_setupTracing(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
|
@ -60,41 +41,9 @@ func Test_setupMetrics(t *testing.T) {
|
|||
c := make(chan os.Signal, 1)
|
||||
signal.Notify(c, syscall.SIGINT)
|
||||
defer signal.Stop(c)
|
||||
var wg sync.WaitGroup
|
||||
|
||||
setupMetrics(tt.opt, &wg)
|
||||
setupMetrics(tt.opt)
|
||||
syscall.Kill(syscall.Getpid(), syscall.SIGINT)
|
||||
waitSig(t, c, syscall.SIGINT)
|
||||
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_setupHTTPRedirectServer(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
opt *config.Options
|
||||
wantErr bool
|
||||
}{
|
||||
{"dont register anything", &config.Options{}, false},
|
||||
{"good redirect server", &config.Options{HTTPRedirectAddr: "localhost:0"}, false},
|
||||
{"bad redirect server port", &config.Options{HTTPRedirectAddr: "localhost:-1"}, true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
c := make(chan os.Signal, 1)
|
||||
var wg sync.WaitGroup
|
||||
|
||||
signal.Notify(c, syscall.SIGINT)
|
||||
defer signal.Stop(c)
|
||||
err := setupHTTPRedirectServer(tt.opt, &wg)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("run() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
|
||||
syscall.Kill(syscall.Getpid(), syscall.SIGINT)
|
||||
waitSig(t, c, syscall.SIGINT)
|
||||
|
||||
})
|
||||
}
|
||||
}
|
||||
|
@ -274,16 +223,11 @@ func Test_run(t *testing.T) {
|
|||
t.Fatal(err)
|
||||
}
|
||||
configFile = &fn
|
||||
proc, err := os.FindProcess(os.Getpid())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
go func() {
|
||||
time.Sleep(time.Millisecond * 500)
|
||||
proc.Signal(os.Interrupt)
|
||||
}()
|
||||
|
||||
err = run()
|
||||
ctx, clearTimeout := context.WithTimeout(context.Background(), 500*time.Millisecond)
|
||||
defer clearTimeout()
|
||||
|
||||
err = run(ctx)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("run() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
|
|
102
config/autocert.go
Normal file
102
config/autocert.go
Normal file
|
@ -0,0 +1,102 @@
|
|||
package config
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"sync"
|
||||
|
||||
"github.com/caddyserver/certmagic"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/log"
|
||||
)
|
||||
|
||||
// AutocertManager manages Let's Encrypt certificates based on configuration options.
|
||||
var AutocertManager = newAutocertManager()
|
||||
|
||||
type autocertManager struct {
|
||||
mu sync.RWMutex
|
||||
certmagic *certmagic.Config
|
||||
acmeMgr *certmagic.ACMEManager
|
||||
}
|
||||
|
||||
func newAutocertManager() *autocertManager {
|
||||
mgr := &autocertManager{}
|
||||
return mgr
|
||||
}
|
||||
|
||||
func (mgr *autocertManager) getConfig(options *Options) (*certmagic.Config, error) {
|
||||
mgr.mu.Lock()
|
||||
defer mgr.mu.Unlock()
|
||||
|
||||
cm := mgr.certmagic
|
||||
if cm == nil {
|
||||
cm = certmagic.NewDefault()
|
||||
}
|
||||
|
||||
cm.OnDemand = nil // disable on-demand
|
||||
cm.Storage = &certmagic.FileStorage{Path: options.AutoCertFolder}
|
||||
// add existing certs to the cache, and staple OCSP
|
||||
for _, cert := range options.Certificates {
|
||||
if err := cm.CacheUnmanagedTLSCertificate(cert, nil); err != nil {
|
||||
return nil, fmt.Errorf("config: failed caching cert: %w", err)
|
||||
}
|
||||
}
|
||||
acmeMgr := certmagic.NewACMEManager(cm, certmagic.DefaultACME)
|
||||
acmeMgr.Agreed = true
|
||||
if options.AutoCertUseStaging {
|
||||
acmeMgr.CA = certmagic.LetsEncryptStagingCA
|
||||
}
|
||||
acmeMgr.DisableTLSALPNChallenge = true
|
||||
cm.Issuer = acmeMgr
|
||||
mgr.acmeMgr = acmeMgr
|
||||
|
||||
return cm, nil
|
||||
}
|
||||
|
||||
func (mgr *autocertManager) update(options *Options) error {
|
||||
if !options.AutoCert {
|
||||
return nil
|
||||
}
|
||||
|
||||
cm, err := mgr.getConfig(options)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, domain := range options.sourceHostnames() {
|
||||
cert, err := cm.CacheManagedCertificate(domain)
|
||||
if err != nil {
|
||||
log.Info().Str("domain", domain).Msg("obtaining certificate")
|
||||
err = cm.ObtainCert(context.Background(), domain, false)
|
||||
if err != nil {
|
||||
return fmt.Errorf("config: failed to obtain client certificate: %w", err)
|
||||
}
|
||||
cert, err = cm.CacheManagedCertificate(domain)
|
||||
}
|
||||
if err == nil && cert.NeedsRenewal(cm) {
|
||||
log.Info().Str("domain", domain).Msg("renewing certificate")
|
||||
err = cm.RenewCert(context.Background(), domain, false)
|
||||
if err != nil {
|
||||
return fmt.Errorf("config: failed to renew client certificate: %w", err)
|
||||
}
|
||||
cert, err = cm.CacheManagedCertificate(domain)
|
||||
}
|
||||
if err == nil {
|
||||
options.Certificates = append(options.Certificates, cert.Certificate)
|
||||
} else {
|
||||
log.Error().Err(err).Msg("config: failed to obtain client certificate")
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (mgr *autocertManager) HandleHTTPChallenge(w http.ResponseWriter, r *http.Request) bool {
|
||||
mgr.mu.RLock()
|
||||
acmeMgr := mgr.acmeMgr
|
||||
mgr.mu.RUnlock()
|
||||
if acmeMgr == nil {
|
||||
return false
|
||||
}
|
||||
return acmeMgr.HandleHTTPChallenge(w, r)
|
||||
}
|
|
@ -1,15 +1,16 @@
|
|||
package config
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/tls"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"reflect"
|
||||
"sort"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
|
@ -68,10 +69,6 @@ type Options struct {
|
|||
// and renewal from LetsEncrypt. Must be used in conjunction with AutoCertFolder.
|
||||
AutoCert bool `mapstructure:"autocert" yaml:"autocert,omitempty"`
|
||||
|
||||
// AutoCertHandler is the HTTP challenge handler used in a http-01 acme
|
||||
// https://letsencrypt.org/docs/challenge-types/#http-01-challenge
|
||||
AutoCertHandler func(h http.Handler) http.Handler `hash:"ignore"`
|
||||
|
||||
// AutoCertFolder specifies the location to store, and load autocert managed
|
||||
// TLS certificates.
|
||||
// defaults to $XDG_DATA_HOME/pomerium
|
||||
|
@ -83,7 +80,7 @@ type Options struct {
|
|||
// https://letsencrypt.org/docs/staging-environment/
|
||||
AutoCertUseStaging bool `mapstructure:"autocert_use_staging" yaml:"autocert_use_staging,omitempty"`
|
||||
|
||||
Certificates []certificateFilePair `mapstructure:"certificates" yaml:"certificates,omitempty"`
|
||||
CertificateFiles []certificateFilePair `mapstructure:"certificates" yaml:"certificates,omitempty"`
|
||||
|
||||
// Cert and Key is the x509 certificate used to create the HTTPS server.
|
||||
Cert string `mapstructure:"certificate" yaml:"certificate,omitempty"`
|
||||
|
@ -93,7 +90,7 @@ type Options struct {
|
|||
CertFile string `mapstructure:"certificate_file" yaml:"certificate_file,omitempty"`
|
||||
KeyFile string `mapstructure:"certificate_key_file" yaml:"certificate_key_file,omitempty"`
|
||||
|
||||
TLSConfig *tls.Config `hash:"ignore"`
|
||||
Certificates []tls.Certificate `yaml:"-"`
|
||||
|
||||
// HttpRedirectAddr, if set, specifies the host and port to run the HTTP
|
||||
// to HTTPS redirect server on. If empty, no redirect server is started.
|
||||
|
@ -534,38 +531,42 @@ func (o *Options) Validate() error {
|
|||
}
|
||||
|
||||
if o.Cert != "" || o.Key != "" {
|
||||
o.TLSConfig, err = cryptutil.TLSConfigFromBase64(o.TLSConfig, o.Cert, o.Key)
|
||||
cert, err := cryptutil.CertificateFromBase64(o.Cert, o.Key)
|
||||
if err != nil {
|
||||
return fmt.Errorf("config: bad cert base64 %w", err)
|
||||
}
|
||||
o.Certificates = append(o.Certificates, *cert)
|
||||
}
|
||||
|
||||
if len(o.Certificates) != 0 {
|
||||
for _, c := range o.Certificates {
|
||||
o.TLSConfig, err = cryptutil.TLSConfigFromFile(o.TLSConfig, c.CertFile, c.KeyFile)
|
||||
for _, c := range o.CertificateFiles {
|
||||
cert, err := cryptutil.CertificateFromFile(c.CertFile, c.KeyFile)
|
||||
if err != nil {
|
||||
return fmt.Errorf("config: bad cert file %w", err)
|
||||
}
|
||||
}
|
||||
o.Certificates = append(o.Certificates, *cert)
|
||||
}
|
||||
|
||||
if o.CertFile != "" || o.KeyFile != "" {
|
||||
o.TLSConfig, err = cryptutil.TLSConfigFromFile(o.TLSConfig, o.CertFile, o.KeyFile)
|
||||
cert, err := cryptutil.CertificateFromFile(o.CertFile, o.KeyFile)
|
||||
if err != nil {
|
||||
return fmt.Errorf("config: bad cert file %w", err)
|
||||
}
|
||||
o.Certificates = append(o.Certificates, *cert)
|
||||
}
|
||||
if o.AutoCert {
|
||||
o.TLSConfig, o.AutoCertHandler, err = cryptutil.NewAutocert(
|
||||
o.TLSConfig,
|
||||
o.sourceHostnames(),
|
||||
o.AutoCertUseStaging,
|
||||
o.AutoCertFolder)
|
||||
|
||||
RedirectAndAutocertServer.update(o)
|
||||
|
||||
err = AutocertManager.update(o)
|
||||
if err != nil {
|
||||
return fmt.Errorf("config: autocert failed %w", err)
|
||||
return fmt.Errorf("config: failed to setup autocert: %w", err)
|
||||
}
|
||||
}
|
||||
if !o.InsecureServer && o.TLSConfig == nil {
|
||||
|
||||
// sort the certificates so we get a consistent hash
|
||||
sort.Slice(o.Certificates, func(i, j int) bool {
|
||||
return compareByteSliceSlice(o.Certificates[i].Certificate, o.Certificates[j].Certificate) < 0
|
||||
})
|
||||
|
||||
if !o.InsecureServer && len(o.Certificates) == 0 {
|
||||
return fmt.Errorf("config: server must be run with `autocert`, " +
|
||||
"`insecure_server` or manually provided certificates to start")
|
||||
}
|
||||
|
@ -576,13 +577,21 @@ func (o *Options) sourceHostnames() []string {
|
|||
if len(o.Policies) == 0 {
|
||||
return nil
|
||||
}
|
||||
var h []string
|
||||
|
||||
dedupe := map[string]struct{}{}
|
||||
for _, p := range o.Policies {
|
||||
h = append(h, p.Source.Hostname())
|
||||
dedupe[p.Source.Hostname()] = struct{}{}
|
||||
}
|
||||
if o.AuthenticateURL != nil {
|
||||
h = append(h, o.AuthenticateURL.Hostname())
|
||||
dedupe[o.AuthenticateURL.Hostname()] = struct{}{}
|
||||
}
|
||||
|
||||
var h []string
|
||||
for k := range dedupe {
|
||||
h = append(h, k)
|
||||
}
|
||||
sort.Strings(h)
|
||||
|
||||
return h
|
||||
}
|
||||
|
||||
|
@ -601,10 +610,37 @@ func (o *Options) Checksum() uint64 {
|
|||
return hash
|
||||
}
|
||||
|
||||
// HandleConfigUpdate takes configuration file, an existing options struct, and
|
||||
// WatchChanges takes a configuration file, an existing options struct, and
|
||||
// updates each service in the services slice OptionsUpdater with a new set
|
||||
// of options if any change is detected. It also periodically rechecks if
|
||||
// any computed properties have changed.
|
||||
func WatchChanges(configFile string, opt *Options, services []OptionsUpdater) {
|
||||
onchange := make(chan struct{}, 1)
|
||||
ticker := time.NewTicker(10 * time.Minute) // force check every 10 minutes
|
||||
defer ticker.Stop()
|
||||
|
||||
opt.OnConfigChange(func(fs fsnotify.Event) {
|
||||
log.Info().Str("file", fs.Name).Msg("config: file changed")
|
||||
select {
|
||||
case onchange <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
})
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-onchange:
|
||||
case <-ticker.C:
|
||||
}
|
||||
|
||||
opt = handleConfigUpdate(configFile, opt, services)
|
||||
}
|
||||
}
|
||||
|
||||
// handleConfigUpdate takes configuration file, an existing options struct, and
|
||||
// updates each service in the services slice OptionsUpdater with a new set of
|
||||
// options if any change is detected.
|
||||
func HandleConfigUpdate(configFile string, opt *Options, services []OptionsUpdater) *Options {
|
||||
func handleConfigUpdate(configFile string, opt *Options, services []OptionsUpdater) *Options {
|
||||
newOpt, err := NewOptionsFromConfig(configFile)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("config: could not reload configuration")
|
||||
|
@ -648,3 +684,31 @@ func dataDir() string {
|
|||
}
|
||||
return filepath.Join(baseDir, "pomerium")
|
||||
}
|
||||
|
||||
func compareByteSliceSlice(a, b [][]byte) int {
|
||||
sz := min(len(a), len(b))
|
||||
for i := 0; i < sz; i++ {
|
||||
switch bytes.Compare(a[i], b[i]) {
|
||||
case -1:
|
||||
return -1
|
||||
case 1:
|
||||
return 1
|
||||
}
|
||||
}
|
||||
|
||||
switch {
|
||||
case len(a) < len(b):
|
||||
return -1
|
||||
case len(b) < len(a):
|
||||
return 1
|
||||
default:
|
||||
return 0
|
||||
}
|
||||
}
|
||||
|
||||
func min(x, y int) int {
|
||||
if x < y {
|
||||
return x
|
||||
}
|
||||
return y
|
||||
}
|
||||
|
|
|
@ -416,7 +416,7 @@ func Test_HandleConfigUpdate(t *testing.T) {
|
|||
os.Setenv(k, v)
|
||||
defer os.Unsetenv(k)
|
||||
}
|
||||
HandleConfigUpdate("", oldOpts, []OptionsUpdater{tt.service})
|
||||
handleConfigUpdate("", oldOpts, []OptionsUpdater{tt.service})
|
||||
if tt.service.Updated != tt.wantUpdate {
|
||||
t.Errorf("Failed to update config on service")
|
||||
}
|
||||
|
@ -441,7 +441,7 @@ func TestOptions_sourceHostnames(t *testing.T) {
|
|||
}{
|
||||
{"empty", []Policy{}, "", nil},
|
||||
{"good no authN", []Policy{{From: "https://from.example", To: "https://to.example"}}, "", []string{"from.example"}},
|
||||
{"good with authN", []Policy{{From: "https://from.example", To: "https://to.example"}}, "https://authn.example.com", []string{"from.example", "authn.example.com"}},
|
||||
{"good with authN", []Policy{{From: "https://from.example", To: "https://to.example"}}, "https://authn.example.com", []string{"authn.example.com", "from.example"}},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
|
@ -459,3 +459,66 @@ func TestOptions_sourceHostnames(t *testing.T) {
|
|||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCompareByteSliceSlice(t *testing.T) {
|
||||
type Bytes = [][]byte
|
||||
|
||||
tests := []struct {
|
||||
expect int
|
||||
a Bytes
|
||||
b Bytes
|
||||
}{
|
||||
{
|
||||
0,
|
||||
Bytes{
|
||||
{0, 1, 2, 3},
|
||||
},
|
||||
Bytes{
|
||||
{0, 1, 2, 3},
|
||||
},
|
||||
},
|
||||
{
|
||||
-1,
|
||||
Bytes{
|
||||
{0, 1, 2, 3},
|
||||
},
|
||||
Bytes{
|
||||
{0, 1, 2, 4},
|
||||
},
|
||||
},
|
||||
{
|
||||
1,
|
||||
Bytes{
|
||||
{0, 1, 2, 4},
|
||||
},
|
||||
Bytes{
|
||||
{0, 1, 2, 3},
|
||||
},
|
||||
},
|
||||
{-1,
|
||||
Bytes{
|
||||
{0, 1, 2, 3},
|
||||
},
|
||||
Bytes{
|
||||
{0, 1, 2, 3},
|
||||
{4, 5, 6, 7},
|
||||
},
|
||||
},
|
||||
{1,
|
||||
Bytes{
|
||||
{0, 1, 2, 3},
|
||||
{4, 5, 6, 7},
|
||||
},
|
||||
Bytes{
|
||||
{0, 1, 2, 3},
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
actual := compareByteSliceSlice(tt.a, tt.b)
|
||||
if tt.expect != actual {
|
||||
t.Errorf("expected compare(%v, %v) to be %v but got %v",
|
||||
tt.a, tt.b, tt.expect, actual)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -115,7 +115,7 @@ func (p *Policy) Validate() error {
|
|||
}
|
||||
|
||||
if p.TLSClientCert != "" && p.TLSClientKey != "" {
|
||||
p.ClientCertificate, err = cryptutil.CertifcateFromBase64(p.TLSClientCert, p.TLSClientKey)
|
||||
p.ClientCertificate, err = cryptutil.CertificateFromBase64(p.TLSClientCert, p.TLSClientKey)
|
||||
if err != nil {
|
||||
return fmt.Errorf("config: couldn't decode client cert %w", err)
|
||||
}
|
||||
|
|
60
config/redirect.go
Normal file
60
config/redirect.go
Normal file
|
@ -0,0 +1,60 @@
|
|||
package config
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"sync"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/httputil"
|
||||
"github.com/pomerium/pomerium/internal/log"
|
||||
)
|
||||
|
||||
// RedirectAndAutocertServer is an HTTP server which handles redirecting to HTTPS and autocerts.
|
||||
var RedirectAndAutocertServer = newRedirectAndAutoCertServer()
|
||||
|
||||
type redirectAndAutoCertServer struct {
|
||||
mu sync.Mutex
|
||||
srv *http.Server
|
||||
}
|
||||
|
||||
func newRedirectAndAutoCertServer() *redirectAndAutoCertServer {
|
||||
return &redirectAndAutoCertServer{}
|
||||
}
|
||||
|
||||
func (srv *redirectAndAutoCertServer) update(options *Options) {
|
||||
srv.mu.Lock()
|
||||
defer srv.mu.Unlock()
|
||||
|
||||
if srv.srv != nil {
|
||||
// nothing to do if the address hasn't changed
|
||||
if srv.srv.Addr == options.HTTPRedirectAddr {
|
||||
return
|
||||
}
|
||||
// close immediately, don't care about the error
|
||||
_ = srv.srv.Close()
|
||||
srv.srv = nil
|
||||
}
|
||||
|
||||
if options.HTTPRedirectAddr == "" {
|
||||
return
|
||||
}
|
||||
|
||||
redirect := httputil.RedirectHandler()
|
||||
|
||||
hsrv := &http.Server{
|
||||
Addr: options.HTTPRedirectAddr,
|
||||
Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if AutocertManager.HandleHTTPChallenge(w, r) {
|
||||
return
|
||||
}
|
||||
redirect.ServeHTTP(w, r)
|
||||
}),
|
||||
}
|
||||
go func() {
|
||||
log.Info().Str("addr", hsrv.Addr).Msg("starting http redirect server")
|
||||
err := hsrv.ListenAndServe()
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("failed to run http redirect server")
|
||||
}
|
||||
}()
|
||||
srv.srv = hsrv
|
||||
}
|
1
go.mod
1
go.mod
|
@ -10,7 +10,6 @@ require (
|
|||
github.com/coreos/go-oidc v2.2.1+incompatible
|
||||
github.com/envoyproxy/go-control-plane v0.9.5
|
||||
github.com/fsnotify/fsnotify v1.4.9
|
||||
github.com/go-acme/lego/v3 v3.4.0
|
||||
github.com/go-redis/redis/v7 v7.2.0
|
||||
github.com/golang/groupcache v0.0.0-20200121045136-8c9f03a8e57e
|
||||
github.com/golang/mock v1.4.3
|
||||
|
|
|
@ -96,7 +96,7 @@ func (srv *Server) streamAggregatedResourcesProcessStep(
|
|||
for typeURL, version := range versions {
|
||||
// the versions are different, so the envoy config needs to be updated
|
||||
if version != fmt.Sprint(current.version) {
|
||||
res, err := srv.buildDiscoveryResponse(fmt.Sprint(current.version), typeURL, current.Options)
|
||||
res, err := srv.buildDiscoveryResponse(fmt.Sprint(current.version), typeURL, ¤t.Options)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
|
@ -3,8 +3,6 @@ package controlplane
|
|||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
|
||||
"github.com/pomerium/pomerium/config"
|
||||
|
@ -19,7 +17,7 @@ import (
|
|||
"google.golang.org/grpc/status"
|
||||
)
|
||||
|
||||
func (srv *Server) buildDiscoveryResponse(version string, typeURL string, options config.Options) (*envoy_service_discovery_v3.DiscoveryResponse, error) {
|
||||
func (srv *Server) buildDiscoveryResponse(version string, typeURL string, options *config.Options) (*envoy_service_discovery_v3.DiscoveryResponse, error) {
|
||||
switch typeURL {
|
||||
case "type.googleapis.com/envoy.config.listener.v3.Listener":
|
||||
listeners := srv.buildListeners(options)
|
||||
|
@ -56,7 +54,7 @@ func (srv *Server) buildDiscoveryResponse(version string, typeURL string, option
|
|||
}
|
||||
}
|
||||
|
||||
func (srv *Server) buildAccessLogs(options config.Options) []*envoy_config_accesslog_v3.AccessLog {
|
||||
func (srv *Server) buildAccessLogs(options *config.Options) []*envoy_config_accesslog_v3.AccessLog {
|
||||
lvl := options.ProxyLogLevel
|
||||
if lvl == "" {
|
||||
lvl = options.LogLevel
|
||||
|
@ -112,10 +110,10 @@ func buildAddress(hostport string, defaultPort int) *envoy_config_core_v3.Addres
|
|||
}
|
||||
}
|
||||
|
||||
func getAbsoluteFilePath(filename string) string {
|
||||
if filepath.IsAbs(filename) {
|
||||
return filename
|
||||
func inlineBytes(bs []byte) *envoy_config_core_v3.DataSource {
|
||||
return &envoy_config_core_v3.DataSource{
|
||||
Specifier: &envoy_config_core_v3.DataSource_InlineBytes{
|
||||
InlineBytes: bs,
|
||||
},
|
||||
}
|
||||
wd, _ := os.Getwd()
|
||||
return filepath.Join(wd, filename)
|
||||
}
|
||||
|
|
|
@ -15,7 +15,7 @@ import (
|
|||
"github.com/pomerium/pomerium/internal/urlutil"
|
||||
)
|
||||
|
||||
func (srv *Server) buildClusters(options config.Options) []*envoy_config_cluster_v3.Cluster {
|
||||
func (srv *Server) buildClusters(options *config.Options) []*envoy_config_cluster_v3.Cluster {
|
||||
grpcURL := &url.URL{
|
||||
Scheme: "grpc",
|
||||
Host: srv.GRPCListener.Addr().String(),
|
||||
|
|
|
@ -1,7 +1,9 @@
|
|||
package controlplane
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"bytes"
|
||||
"crypto/x509"
|
||||
"encoding/pem"
|
||||
"sort"
|
||||
|
||||
envoy_config_core_v3 "github.com/envoyproxy/go-control-plane/envoy/config/core/v3"
|
||||
|
@ -14,8 +16,11 @@ import (
|
|||
envoy_type_v3 "github.com/envoyproxy/go-control-plane/envoy/type/v3"
|
||||
"github.com/golang/protobuf/ptypes"
|
||||
"github.com/golang/protobuf/ptypes/any"
|
||||
"google.golang.org/protobuf/types/known/emptypb"
|
||||
|
||||
"github.com/pomerium/pomerium/config"
|
||||
"github.com/pomerium/pomerium/internal/cryptutil"
|
||||
"github.com/pomerium/pomerium/internal/log"
|
||||
"github.com/pomerium/pomerium/internal/urlutil"
|
||||
)
|
||||
|
||||
|
@ -29,11 +34,11 @@ func init() {
|
|||
})
|
||||
}
|
||||
|
||||
func (srv *Server) buildListeners(options config.Options) []*envoy_config_listener_v3.Listener {
|
||||
func (srv *Server) buildListeners(options *config.Options) []*envoy_config_listener_v3.Listener {
|
||||
var listeners []*envoy_config_listener_v3.Listener
|
||||
|
||||
if config.IsAuthenticate(options.Services) || config.IsProxy(options.Services) {
|
||||
listeners = append(listeners, srv.buildHTTPListener(options))
|
||||
listeners = append(listeners, srv.buildMainListener(options))
|
||||
}
|
||||
|
||||
if config.IsAuthorize(options.Services) || config.IsCache(options.Services) {
|
||||
|
@ -43,22 +48,77 @@ func (srv *Server) buildListeners(options config.Options) []*envoy_config_listen
|
|||
return listeners
|
||||
}
|
||||
|
||||
func (srv *Server) buildHTTPListener(options config.Options) *envoy_config_listener_v3.Listener {
|
||||
defaultPort := 80
|
||||
var transportSocket *envoy_config_core_v3.TransportSocket
|
||||
if !options.InsecureServer {
|
||||
defaultPort = 443
|
||||
tlsConfig, _ := ptypes.MarshalAny(srv.buildDownstreamTLSContext(options))
|
||||
transportSocket = &envoy_config_core_v3.TransportSocket{
|
||||
func (srv *Server) buildMainListener(options *config.Options) *envoy_config_listener_v3.Listener {
|
||||
if options.InsecureServer {
|
||||
filter := srv.buildMainHTTPConnectionManagerFilter(options,
|
||||
srv.getAllRouteableDomains(options, options.Addr))
|
||||
|
||||
return &envoy_config_listener_v3.Listener{
|
||||
Name: "http-ingress",
|
||||
Address: buildAddress(options.Addr, 80),
|
||||
FilterChains: []*envoy_config_listener_v3.FilterChain{{
|
||||
Filters: []*envoy_config_listener_v3.Filter{
|
||||
filter,
|
||||
},
|
||||
}},
|
||||
}
|
||||
}
|
||||
|
||||
tlsInspectorCfg, _ := ptypes.MarshalAny(new(emptypb.Empty))
|
||||
li := &envoy_config_listener_v3.Listener{
|
||||
Name: "https-ingress",
|
||||
Address: buildAddress(options.Addr, 443),
|
||||
ListenerFilters: []*envoy_config_listener_v3.ListenerFilter{{
|
||||
Name: "envoy.filters.listener.tls_inspector",
|
||||
ConfigType: &envoy_config_listener_v3.ListenerFilter_TypedConfig{
|
||||
TypedConfig: tlsInspectorCfg,
|
||||
},
|
||||
}},
|
||||
FilterChains: srv.buildFilterChains(options, options.Addr,
|
||||
func(tlsDomain string, httpDomains []string) *envoy_config_listener_v3.FilterChain {
|
||||
filter := srv.buildMainHTTPConnectionManagerFilter(options, httpDomains)
|
||||
filterChain := &envoy_config_listener_v3.FilterChain{
|
||||
Filters: []*envoy_config_listener_v3.Filter{filter},
|
||||
}
|
||||
if tlsDomain != "*" {
|
||||
filterChain.FilterChainMatch = &envoy_config_listener_v3.FilterChainMatch{
|
||||
ServerNames: []string{tlsDomain},
|
||||
}
|
||||
}
|
||||
tlsContext := srv.buildDownstreamTLSContext(options, tlsDomain)
|
||||
if tlsContext != nil {
|
||||
tlsConfig, _ := ptypes.MarshalAny(tlsContext)
|
||||
filterChain.TransportSocket = &envoy_config_core_v3.TransportSocket{
|
||||
Name: "tls",
|
||||
ConfigType: &envoy_config_core_v3.TransportSocket_TypedConfig{
|
||||
TypedConfig: tlsConfig,
|
||||
},
|
||||
}
|
||||
}
|
||||
return filterChain
|
||||
}),
|
||||
}
|
||||
return li
|
||||
}
|
||||
|
||||
func (srv *Server) buildFilterChains(
|
||||
options *config.Options, addr string,
|
||||
callback func(tlsDomain string, httpDomains []string) *envoy_config_listener_v3.FilterChain,
|
||||
) []*envoy_config_listener_v3.FilterChain {
|
||||
allDomains := srv.getAllRouteableDomains(options, addr)
|
||||
var chains []*envoy_config_listener_v3.FilterChain
|
||||
for _, domain := range allDomains {
|
||||
// first we match on SNI
|
||||
chains = append(chains, callback(domain, []string{domain}))
|
||||
}
|
||||
// if there are no SNI matches we match on HTTP host
|
||||
chains = append(chains, callback("*", allDomains))
|
||||
return chains
|
||||
}
|
||||
|
||||
func (srv *Server) buildMainHTTPConnectionManagerFilter(options *config.Options, domains []string) *envoy_config_listener_v3.Filter {
|
||||
var virtualHosts []*envoy_config_route_v3.VirtualHost
|
||||
for _, domain := range srv.getAllRouteableDomains(options, options.Addr) {
|
||||
for _, domain := range domains {
|
||||
vh := &envoy_config_route_v3.VirtualHost{
|
||||
Name: domain,
|
||||
Domains: []string{domain},
|
||||
|
@ -142,38 +202,66 @@ func (srv *Server) buildHTTPListener(options config.Options) *envoy_config_liste
|
|||
AccessLog: srv.buildAccessLogs(options),
|
||||
})
|
||||
|
||||
li := &envoy_config_listener_v3.Listener{
|
||||
Name: "http-ingress",
|
||||
Address: buildAddress(options.Addr, defaultPort),
|
||||
FilterChains: []*envoy_config_listener_v3.FilterChain{{
|
||||
Filters: []*envoy_config_listener_v3.Filter{
|
||||
{
|
||||
return &envoy_config_listener_v3.Filter{
|
||||
Name: "envoy.filters.network.http_connection_manager",
|
||||
ConfigType: &envoy_config_listener_v3.Filter_TypedConfig{
|
||||
TypedConfig: tc,
|
||||
},
|
||||
},
|
||||
},
|
||||
TransportSocket: transportSocket,
|
||||
}},
|
||||
}
|
||||
return li
|
||||
}
|
||||
|
||||
func (srv *Server) buildGRPCListener(options config.Options) *envoy_config_listener_v3.Listener {
|
||||
defaultPort := 80
|
||||
var transportSocket *envoy_config_core_v3.TransportSocket
|
||||
if !options.GRPCInsecure {
|
||||
defaultPort = 443
|
||||
tlsConfig, _ := ptypes.MarshalAny(srv.buildDownstreamTLSContext(options))
|
||||
transportSocket = &envoy_config_core_v3.TransportSocket{
|
||||
func (srv *Server) buildGRPCListener(options *config.Options) *envoy_config_listener_v3.Listener {
|
||||
filter := srv.buildGRPCHTTPConnectionManagerFilter()
|
||||
|
||||
if options.GRPCInsecure {
|
||||
return &envoy_config_listener_v3.Listener{
|
||||
Name: "grpc-ingress",
|
||||
Address: buildAddress(options.GRPCAddr, 80),
|
||||
FilterChains: []*envoy_config_listener_v3.FilterChain{{
|
||||
Filters: []*envoy_config_listener_v3.Filter{
|
||||
filter,
|
||||
},
|
||||
}},
|
||||
}
|
||||
}
|
||||
|
||||
tlsInspectorCfg, _ := ptypes.MarshalAny(new(emptypb.Empty))
|
||||
li := &envoy_config_listener_v3.Listener{
|
||||
Name: "grpc-ingress",
|
||||
Address: buildAddress(options.GRPCAddr, 443),
|
||||
ListenerFilters: []*envoy_config_listener_v3.ListenerFilter{{
|
||||
Name: "envoy.filters.listener.tls_inspector",
|
||||
ConfigType: &envoy_config_listener_v3.ListenerFilter_TypedConfig{
|
||||
TypedConfig: tlsInspectorCfg,
|
||||
},
|
||||
}},
|
||||
FilterChains: srv.buildFilterChains(options, options.Addr,
|
||||
func(tlsDomain string, httpDomains []string) *envoy_config_listener_v3.FilterChain {
|
||||
filterChain := &envoy_config_listener_v3.FilterChain{
|
||||
Filters: []*envoy_config_listener_v3.Filter{filter},
|
||||
}
|
||||
if tlsDomain != "*" {
|
||||
filterChain.FilterChainMatch = &envoy_config_listener_v3.FilterChainMatch{
|
||||
ServerNames: []string{tlsDomain},
|
||||
}
|
||||
}
|
||||
tlsContext := srv.buildDownstreamTLSContext(options, tlsDomain)
|
||||
if tlsContext != nil {
|
||||
tlsConfig, _ := ptypes.MarshalAny(tlsContext)
|
||||
filterChain.TransportSocket = &envoy_config_core_v3.TransportSocket{
|
||||
Name: "tls",
|
||||
ConfigType: &envoy_config_core_v3.TransportSocket_TypedConfig{
|
||||
TypedConfig: tlsConfig,
|
||||
},
|
||||
}
|
||||
}
|
||||
return filterChain
|
||||
}),
|
||||
}
|
||||
return li
|
||||
}
|
||||
|
||||
func (srv *Server) buildGRPCHTTPConnectionManagerFilter() *envoy_config_listener_v3.Filter {
|
||||
tc, _ := ptypes.MarshalAny(&envoy_http_connection_manager.HttpConnectionManager{
|
||||
CodecType: envoy_http_connection_manager.HttpConnectionManager_AUTO,
|
||||
StatPrefix: "grpc_ingress",
|
||||
|
@ -191,7 +279,9 @@ func (srv *Server) buildGRPCListener(options config.Options) *envoy_config_liste
|
|||
},
|
||||
Action: &envoy_config_route_v3.Route_Route{
|
||||
Route: &envoy_config_route_v3.RouteAction{
|
||||
ClusterSpecifier: &envoy_config_route_v3.RouteAction_Cluster{Cluster: "pomerium-control-plane-grpc"},
|
||||
ClusterSpecifier: &envoy_config_route_v3.RouteAction_Cluster{
|
||||
Cluster: "pomerium-control-plane-grpc",
|
||||
},
|
||||
},
|
||||
},
|
||||
}},
|
||||
|
@ -202,64 +292,57 @@ func (srv *Server) buildGRPCListener(options config.Options) *envoy_config_liste
|
|||
Name: "envoy.filters.http.router",
|
||||
}},
|
||||
})
|
||||
|
||||
return &envoy_config_listener_v3.Listener{
|
||||
Name: "grpc-ingress",
|
||||
Address: buildAddress(options.GRPCAddr, defaultPort),
|
||||
FilterChains: []*envoy_config_listener_v3.FilterChain{{
|
||||
Filters: []*envoy_config_listener_v3.Filter{{
|
||||
return &envoy_config_listener_v3.Filter{
|
||||
Name: "envoy.filters.network.http_connection_manager",
|
||||
ConfigType: &envoy_config_listener_v3.Filter_TypedConfig{
|
||||
TypedConfig: tc,
|
||||
},
|
||||
}},
|
||||
TransportSocket: transportSocket,
|
||||
}},
|
||||
}
|
||||
}
|
||||
|
||||
func (srv *Server) buildDownstreamTLSContext(options config.Options) *envoy_extensions_transport_sockets_tls_v3.DownstreamTlsContext {
|
||||
var cert envoy_extensions_transport_sockets_tls_v3.TlsCertificate
|
||||
if options.Cert != "" {
|
||||
bs, _ := base64.StdEncoding.DecodeString(options.Cert)
|
||||
cert.CertificateChain = &envoy_config_core_v3.DataSource{
|
||||
Specifier: &envoy_config_core_v3.DataSource_InlineBytes{
|
||||
InlineBytes: bs,
|
||||
},
|
||||
func (srv *Server) buildDownstreamTLSContext(options *config.Options, domain string) *envoy_extensions_transport_sockets_tls_v3.DownstreamTlsContext {
|
||||
cert, err := cryptutil.GetCertificateForDomain(options.Certificates, domain)
|
||||
if err != nil {
|
||||
log.Warn().Str("domain", domain).Err(err).Msg("failed to get certificate for domain")
|
||||
return nil
|
||||
}
|
||||
|
||||
envoyCert := &envoy_extensions_transport_sockets_tls_v3.TlsCertificate{}
|
||||
var chain bytes.Buffer
|
||||
for _, cbs := range cert.Certificate {
|
||||
_ = pem.Encode(&chain, &pem.Block{
|
||||
Type: "CERTIFICATE",
|
||||
Bytes: cbs,
|
||||
})
|
||||
}
|
||||
envoyCert.CertificateChain = inlineBytes(chain.Bytes())
|
||||
if cert.OCSPStaple != nil {
|
||||
envoyCert.OcspStaple = inlineBytes(cert.OCSPStaple)
|
||||
}
|
||||
if bs, err := x509.MarshalPKCS8PrivateKey(cert.PrivateKey); err == nil {
|
||||
envoyCert.PrivateKey = inlineBytes(pem.EncodeToMemory(
|
||||
&pem.Block{
|
||||
Type: "PRIVATE KEY",
|
||||
Bytes: bs,
|
||||
},
|
||||
))
|
||||
} else {
|
||||
cert.CertificateChain = &envoy_config_core_v3.DataSource{
|
||||
Specifier: &envoy_config_core_v3.DataSource_Filename{
|
||||
Filename: getAbsoluteFilePath(options.CertFile),
|
||||
},
|
||||
}
|
||||
}
|
||||
if options.Key != "" {
|
||||
bs, _ := base64.StdEncoding.DecodeString(options.Key)
|
||||
cert.PrivateKey = &envoy_config_core_v3.DataSource{
|
||||
Specifier: &envoy_config_core_v3.DataSource_InlineBytes{
|
||||
InlineBytes: bs,
|
||||
},
|
||||
}
|
||||
} else {
|
||||
cert.PrivateKey = &envoy_config_core_v3.DataSource{
|
||||
Specifier: &envoy_config_core_v3.DataSource_Filename{
|
||||
Filename: getAbsoluteFilePath(options.KeyFile),
|
||||
},
|
||||
log.Warn().Err(err).Msg("failed to marshal private key for tls config")
|
||||
}
|
||||
for _, scts := range cert.SignedCertificateTimestamps {
|
||||
envoyCert.SignedCertificateTimestamp = append(envoyCert.SignedCertificateTimestamp,
|
||||
inlineBytes(scts))
|
||||
}
|
||||
|
||||
return &envoy_extensions_transport_sockets_tls_v3.DownstreamTlsContext{
|
||||
CommonTlsContext: &envoy_extensions_transport_sockets_tls_v3.CommonTlsContext{
|
||||
TlsCertificates: []*envoy_extensions_transport_sockets_tls_v3.TlsCertificate{
|
||||
&cert,
|
||||
},
|
||||
TlsCertificates: []*envoy_extensions_transport_sockets_tls_v3.TlsCertificate{envoyCert},
|
||||
AlpnProtocols: []string{"h2", "http/1.1"},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (srv *Server) getAllRouteableDomains(options config.Options, addr string) []string {
|
||||
func (srv *Server) getAllRouteableDomains(options *config.Options, addr string) []string {
|
||||
lookup := map[string]struct{}{}
|
||||
if config.IsAuthenticate(options.Services) && addr == options.Addr {
|
||||
lookup[urlutil.StripPort(options.AuthenticateURL.Host)] = struct{}{}
|
||||
|
|
|
@ -37,7 +37,7 @@ func (srv *Server) buildGRPCRoutes() []*envoy_config_route_v3.Route {
|
|||
}}
|
||||
}
|
||||
|
||||
func (srv *Server) buildPomeriumHTTPRoutes(options config.Options, domain string) []*envoy_config_route_v3.Route {
|
||||
func (srv *Server) buildPomeriumHTTPRoutes(options *config.Options, domain string) []*envoy_config_route_v3.Route {
|
||||
routes := []*envoy_config_route_v3.Route{
|
||||
srv.buildControlPlanePathRoute("/ping"),
|
||||
srv.buildControlPlanePathRoute("/healthz"),
|
||||
|
@ -95,7 +95,7 @@ func (srv *Server) buildControlPlanePrefixRoute(prefix string) *envoy_config_rou
|
|||
}
|
||||
}
|
||||
|
||||
func (srv *Server) buildPolicyRoutes(options config.Options, domain string) []*envoy_config_route_v3.Route {
|
||||
func (srv *Server) buildPolicyRoutes(options *config.Options, domain string) []*envoy_config_route_v3.Route {
|
||||
var routes []*envoy_config_route_v3.Route
|
||||
for i, policy := range options.Policies {
|
||||
if policy.Source.Hostname() != domain {
|
||||
|
|
|
@ -2,17 +2,23 @@ package cryptutil
|
|||
|
||||
import (
|
||||
"crypto/ecdsa"
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"crypto/x509/pkix"
|
||||
"encoding/base64"
|
||||
"encoding/pem"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"math/big"
|
||||
"net"
|
||||
"time"
|
||||
)
|
||||
|
||||
// CertifcateFromBase64 returns an X509 pair from a base64 encoded blob.
|
||||
func CertifcateFromBase64(cert, key string) (*tls.Certificate, error) {
|
||||
// CertificateFromBase64 returns an X509 pair from a base64 encoded blob.
|
||||
func CertificateFromBase64(cert, key string) (*tls.Certificate, error) {
|
||||
decodedCert, err := base64.StdEncoding.DecodeString(cert)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to decode certificate cert %v: %w", decodedCert, err)
|
||||
|
@ -135,3 +141,59 @@ func EncodePrivateKey(key *ecdsa.PrivateKey) ([]byte, error) {
|
|||
|
||||
return pem.EncodeToMemory(keyBlock), nil
|
||||
}
|
||||
|
||||
// GenerateSelfSignedCertificate generates a self-signed TLS certificate.
|
||||
//
|
||||
// mostly copied from https://golang.org/src/crypto/tls/generate_cert.go
|
||||
func GenerateSelfSignedCertificate(domain string) (*tls.Certificate, error) {
|
||||
privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to geneate private key: %w", err)
|
||||
}
|
||||
|
||||
serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128)
|
||||
serialNumber, err := rand.Int(rand.Reader, serialNumberLimit)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to generate serial number: %w", err)
|
||||
}
|
||||
|
||||
template := x509.Certificate{
|
||||
SerialNumber: serialNumber,
|
||||
Subject: pkix.Name{
|
||||
Organization: []string{"Pomerium"},
|
||||
},
|
||||
NotBefore: time.Now().Add(-time.Minute * 10),
|
||||
NotAfter: time.Now().Add(time.Hour * 24 * 365),
|
||||
KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
|
||||
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
|
||||
BasicConstraintsValid: true,
|
||||
}
|
||||
if ip := net.ParseIP(domain); ip != nil {
|
||||
template.IPAddresses = append(template.IPAddresses, ip)
|
||||
} else {
|
||||
template.DNSNames = append(template.DNSNames, domain)
|
||||
}
|
||||
|
||||
publicKeyBytes, err := x509.CreateCertificate(rand.Reader,
|
||||
&template, &template,
|
||||
privateKey.Public(), privateKey,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create certificate: %w", err)
|
||||
}
|
||||
|
||||
privateKeyBytes, err := x509.MarshalPKCS8PrivateKey(privateKey)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal private key: %w", err)
|
||||
}
|
||||
|
||||
cert, err := tls.X509KeyPair(
|
||||
pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: publicKeyBytes}),
|
||||
pem.EncodeToMemory(&pem.Block{Type: "PRIVATE KEY", Bytes: privateKeyBytes}),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to convert x509 bytes into tls certificate: %w", err)
|
||||
}
|
||||
|
||||
return &cert, nil
|
||||
}
|
||||
|
|
|
@ -55,9 +55,9 @@ func TestCertifcateFromBase64(t *testing.T) {
|
|||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
_, err := CertifcateFromBase64(tt.cert, tt.key)
|
||||
_, err := CertificateFromBase64(tt.cert, tt.key)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("CertifcateFromBase64() error = %v, wantErr %v", err, tt.wantErr)
|
||||
t.Errorf("CertificateFromBase64() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
})
|
||||
|
|
|
@ -1,103 +1,51 @@
|
|||
package cryptutil
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"crypto/x509"
|
||||
|
||||
"github.com/caddyserver/certmagic"
|
||||
"github.com/go-acme/lego/v3/challenge/tlsalpn01"
|
||||
)
|
||||
|
||||
// NewAutocert automatically retrieves public certificates from the free
|
||||
// certificate authority Let's Encrypt using HTTP-01 and TLS-ALPN-01 challenges.
|
||||
// To complete the challenges, the server must be accessible from the internet
|
||||
// by port 80 or 443 .
|
||||
//
|
||||
// https://letsencrypt.org/docs/challenge-types/#http-01-challenge
|
||||
// https://letsencrypt.org/docs/challenge-types/#tls-alpn-01
|
||||
func NewAutocert(tlsConfig *tls.Config, hostnames []string, useStaging bool, path string) (*tls.Config, func(h http.Handler) http.Handler, error) {
|
||||
certmagic.DefaultACME.Agreed = true
|
||||
if useStaging {
|
||||
certmagic.DefaultACME.CA = certmagic.LetsEncryptStagingCA
|
||||
// GetCertificateForDomain returns the tls Certificate which matches the given domain name.
|
||||
// It should handle both exact matches and wildcard matches. If none of those match, the first certificate will be used.
|
||||
// Finally if there are no matching certificates one will be generated.
|
||||
func GetCertificateForDomain(certificates []tls.Certificate, domain string) (*tls.Certificate, error) {
|
||||
// first try a direct name match
|
||||
for _, cert := range certificates {
|
||||
if matchesDomain(&cert, domain) {
|
||||
return &cert, nil
|
||||
}
|
||||
cm := certmagic.NewDefault()
|
||||
|
||||
tlsConfig = newTLSConfigIfEmpty(tlsConfig)
|
||||
// add existing certs to the cache, and staple OCSP
|
||||
for _, cert := range tlsConfig.Certificates {
|
||||
if err := cm.CacheUnmanagedTLSCertificate(cert, nil); err != nil {
|
||||
return nil, nil, fmt.Errorf("cryptutil: failed caching cert: %w", err)
|
||||
}
|
||||
}
|
||||
cm.Storage = &certmagic.FileStorage{Path: path}
|
||||
acmeConfig := certmagic.NewACMEManager(cm, certmagic.DefaultACME)
|
||||
cm.Issuer = acmeConfig
|
||||
// todo(bdd) : add cancellation context?
|
||||
if err := cm.ManageAsync(context.TODO(), hostnames); err != nil {
|
||||
return nil, nil, fmt.Errorf("cryptutil: sync failed: %w", err)
|
||||
}
|
||||
|
||||
tlsConfig.GetCertificate = cm.GetCertificate
|
||||
tlsConfig.NextProtos = append(tlsConfig.NextProtos, tlsalpn01.ACMETLS1Protocol)
|
||||
tlsConfig.BuildNameToCertificate()
|
||||
return tlsConfig, acmeConfig.HTTPChallengeHandler, nil
|
||||
// next use the first cert
|
||||
if len(certificates) > 0 {
|
||||
return &certificates[0], nil
|
||||
}
|
||||
|
||||
// finally fall back to a generated, self-signed certificate
|
||||
return GenerateSelfSignedCertificate(domain)
|
||||
}
|
||||
|
||||
// TLSConfigFromBase64 returns an tls configuration from a base64 encoded blob.
|
||||
func TLSConfigFromBase64(tlsConfig *tls.Config, cert, key string) (*tls.Config, error) {
|
||||
tlsConfig = newTLSConfigIfEmpty(tlsConfig)
|
||||
c, err := CertifcateFromBase64(cert, key)
|
||||
func matchesDomain(cert *tls.Certificate, domain string) bool {
|
||||
if cert == nil || len(cert.Certificate) == 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
xcert, err := x509.ParseCertificate(cert.Certificate[0])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return false
|
||||
}
|
||||
tlsConfig.Certificates = append(tlsConfig.Certificates, *c)
|
||||
tlsConfig.BuildNameToCertificate()
|
||||
return tlsConfig, nil
|
||||
}
|
||||
|
||||
// TLSConfigFromFile returns an tls configuration from a certificate and
|
||||
// key file .
|
||||
func TLSConfigFromFile(tlsConfig *tls.Config, cert, key string) (*tls.Config, error) {
|
||||
tlsConfig = newTLSConfigIfEmpty(tlsConfig)
|
||||
c, err := CertificateFromFile(cert, key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
if certmagic.MatchWildcard(domain, xcert.Subject.CommonName) {
|
||||
return true
|
||||
}
|
||||
tlsConfig.Certificates = append(tlsConfig.Certificates, *c)
|
||||
tlsConfig.BuildNameToCertificate()
|
||||
return tlsConfig, nil
|
||||
}
|
||||
|
||||
// newTLSConfigIfEmpty returns an opinionated TLS configuration if config is nil.
|
||||
// See :
|
||||
// https://wiki.mozilla.org/Security/Server_Side_TLS#Recommended_configurations
|
||||
// https://blog.cloudflare.com/exposing-go-on-the-internet/
|
||||
// https://github.com/ssllabs/research/wiki/SSL-and-TLS-Deployment-Best-Practices
|
||||
// https://github.com/golang/go/blob/df91b8044dbe790c69c16058330f545be069cc1f/src/crypto/tls/common.go#L919
|
||||
func newTLSConfigIfEmpty(tlsConfig *tls.Config) *tls.Config {
|
||||
if tlsConfig != nil {
|
||||
return tlsConfig
|
||||
for _, san := range xcert.DNSNames {
|
||||
if certmagic.MatchWildcard(domain, san) {
|
||||
return true
|
||||
}
|
||||
return &tls.Config{
|
||||
MinVersion: tls.VersionTLS12,
|
||||
// Prioritize cipher suites sped up by AES-NI (AES-GCM)
|
||||
CipherSuites: []uint16{
|
||||
tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384,
|
||||
tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
|
||||
tls.TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305,
|
||||
tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384,
|
||||
tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
|
||||
tls.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305,
|
||||
},
|
||||
PreferServerCipherSuites: true,
|
||||
// Use curves which have assembly implementations
|
||||
CurvePreferences: []tls.CurveID{
|
||||
tls.X25519,
|
||||
tls.CurveP256,
|
||||
},
|
||||
// HTTP/2 must be enabled manually when using http.Serve
|
||||
NextProtos: []string{"h2", "http/1.1"},
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
|
|
@ -3,47 +3,61 @@ package cryptutil
|
|||
import (
|
||||
"crypto/tls"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestTLSConfigFromBase64(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
cert string
|
||||
key string
|
||||
wantErr bool
|
||||
}{
|
||||
{"good",
|
||||
"LS0tLS1CRUdJTiBDRVJUSUZJQ0FURS0tLS0tCk1JSUVJVENDQWdtZ0F3SUJBZ0lSQVBqTEJxS1lwcWU0ekhQc0dWdFR6T0F3RFFZSktvWklodmNOQVFFTEJRQXcKRWpFUU1BNEdBMVVFQXhNSFoyOXZaQzFqWVRBZUZ3MHhPVEE0TVRBeE9EUTVOREJhRncweU1UQXlNVEF4TnpRdwpNREZhTUJNeEVUQVBCZ05WQkFNVENIQnZiV1Z5YVhWdE1JSUJJakFOQmdrcWhraUc5dzBCQVFFRkFBT0NBUThBCk1JSUJDZ0tDQVFFQTY3S2pxbVFZR3EwTVZ0QUNWcGVDbVhtaW5sUWJEUEdMbXNaQVVFd3VlSFFucnQzV3R2cEQKT202QWxhSk1VblcrSHU1NWpqb2thbEtlVmpUS21nWUdicVV6VkRvTWJQRGFIZWtsdGRCVE1HbE9VRnNQNFVKUwpEck80emROK3pvNDI4VFgyUG5HMkZDZFZLR3k0UEU4aWxIYldMY3I4NzFZalY1MWZ3OENMRFg5UFpKTnU4NjFDCkY3VjlpRUptNnNTZlFsbW5oTjhqMytXelZiUFFOeTFXc1I3aTllOWo2M0VxS3QyMlE5T1hMK1dBY0tza29JU20KQ05WUlVBalU4WVJWY2dRSkIrelEzNEFRUGx6ME9wNU8vUU4vTWVkamFGOHdMUytpdi96dmlTOGNxUGJ4bzZzTApxNkZOVGx0ay9Ra3hlQ2VLS1RRZS8za1BZdlFBZG5sNjVRSURBUUFCbzNFd2J6QU9CZ05WSFE4QkFmOEVCQU1DCkE3Z3dIUVlEVlIwbEJCWXdGQVlJS3dZQkJRVUhBd0VHQ0NzR0FRVUZCd01DTUIwR0ExVWREZ1FXQkJRQ1FYbWIKc0hpcS9UQlZUZVhoQ0dpNjhrVy9DakFmQmdOVkhTTUVHREFXZ0JSNTRKQ3pMRlg0T0RTQ1J0dWNBUGZOdVhWegpuREFOQmdrcWhraUc5dzBCQVFzRkFBT0NBZ0VBcm9XL2trMllleFN5NEhaQXFLNDVZaGQ5ay9QVTFiaDlFK1BRCk5jZFgzTUdEY2NDRUFkc1k4dll3NVE1cnhuMGFzcSt3VGFCcGxoYS9rMi9VVW9IQ1RqUVp1Mk94dEF3UTdPaWIKVE1tMEorU3NWT3d4YnFQTW9rK1RqVE16NFdXaFFUTzVwRmNoZDZXZXNCVHlJNzJ0aG1jcDd1c2NLU2h3YktIegpQY2h1QTQ4SzhPdi96WkxmZnduQVNZb3VCczJjd1ZiRDI3ZXZOMzdoMGFzR1BrR1VXdm1PSDduTHNVeTh3TTdqCkNGL3NwMmJmTC9OYVdNclJnTHZBMGZMS2pwWTQrVEpPbkVxQmxPcCsrbHlJTEZMcC9qMHNybjRNUnlKK0t6UTEKR1RPakVtQ1QvVEFtOS9XSThSL0FlYjcwTjEzTytYNEtaOUJHaDAxTzN3T1Vqd3BZZ3lxSnNoRnNRUG50VmMrSQpKQmF4M2VQU3NicUcwTFkzcHdHUkpRNmMrd1lxdGk2Y0tNTjliYlRkMDhCNUk1N1RRTHhNcUoycTFnWmw1R1VUCmVFZGNWRXltMnZmd0NPd0lrbGNBbThxTm5kZGZKV1FabE5VaHNOVWFBMkVINnlDeXdaZm9aak9hSDEwTXowV20KeTNpZ2NSZFQ3Mi9NR2VkZk93MlV0MVVvRFZmdEcxcysrditUQ1lpNmpUQU05dkZPckJ4UGlOeGFkUENHR2NZZAowakZIc2FWOGFPV1dQQjZBQ1JteHdDVDdRTnRTczM2MlpIOUlFWWR4Q00yMDUrZmluVHhkOUcwSmVRRTd2Kyt6CldoeWo2ZmJBWUIxM2wvN1hkRnpNSW5BOGxpekdrVHB2RHMxeTBCUzlwV3ppYmhqbVFoZGZIejdCZGpGTHVvc2wKZzlNZE5sND0KLS0tLS1FTkQgQ0VSVElGSUNBVEUtLS0tLQo=",
|
||||
"LS0tLS1CRUdJTiBSU0EgUFJJVkFURSBLRVktLS0tLQpNSUlFcGdJQkFBS0NBUUVBNjdLanFtUVlHcTBNVnRBQ1ZwZUNtWG1pbmxRYkRQR0xtc1pBVUV3dWVIUW5ydDNXCnR2cERPbTZBbGFKTVVuVytIdTU1ampva2FsS2VWalRLbWdZR2JxVXpWRG9NYlBEYUhla2x0ZEJUTUdsT1VGc1AKNFVKU0RyTzR6ZE4rem80MjhUWDJQbkcyRkNkVktHeTRQRThpbEhiV0xjcjg3MVlqVjUxZnc4Q0xEWDlQWkpOdQo4NjFDRjdWOWlFSm02c1NmUWxtbmhOOGozK1d6VmJQUU55MVdzUjdpOWU5ajYzRXFLdDIyUTlPWEwrV0FjS3NrCm9JU21DTlZSVUFqVThZUlZjZ1FKQit6UTM0QVFQbHowT3A1Ty9RTi9NZWRqYUY4d0xTK2l2L3p2aVM4Y3FQYngKbzZzTHE2Rk5UbHRrL1FreGVDZUtLVFFlLzNrUFl2UUFkbmw2NVFJREFRQUJBb0lCQVFEQVQ0eXN2V2pSY3pxcgpKcU9SeGFPQTJEY3dXazJML1JXOFhtQWhaRmRTWHV2MkNQbGxhTU1yelBmTG41WUlmaHQzSDNzODZnSEdZc3pnClo4aWJiYWtYNUdFQ0t5N3lRSDZuZ3hFS3pRVGpiampBNWR3S0h0UFhQUnJmamQ1Y2FMczVpcDcxaWxCWEYxU3IKWERIaXUycnFtaC9kVTArWGRMLzNmK2VnVDl6bFQ5YzRyUm84dnZueWNYejFyMnVhRVZ2VExsWHVsb2NpeEVrcgoySjlTMmxveWFUb2tFTnNlMDNpSVdaWnpNNElZcVowOGJOeG9IWCszQXVlWExIUStzRkRKMlhaVVdLSkZHMHUyClp3R2w3YlZpRTFQNXdiQUdtZzJDeDVCN1MrdGQyUEpSV3Frb2VxY3F2RVdCc3RFL1FEcDFpVThCOHpiQXd0Y3IKZHc5TXZ6Q2hBb0dCQVBObzRWMjF6MGp6MWdEb2tlTVN5d3JnL2E4RkJSM2R2Y0xZbWV5VXkybmd3eHVucnFsdwo2U2IrOWdrOGovcXEvc3VQSDhVdzNqSHNKYXdGSnNvTkVqNCt2b1ZSM3UrbE5sTEw5b21rMXBoU0dNdVp0b3huCm5nbUxVbkJUMGI1M3BURkJ5WGsveE5CbElreWdBNlg5T2MreW5na3RqNlRyVnMxUERTdnVJY0s1QW9HQkFQZmoKcEUzR2F6cVFSemx6TjRvTHZmQWJBdktCZ1lPaFNnemxsK0ZLZkhzYWJGNkdudFd1dWVhY1FIWFpYZTA1c2tLcApXN2xYQ3dqQU1iUXI3QmdlazcrOSszZElwL1RnYmZCYnN3Syt6Vng3Z2doeWMrdytXRWExaHByWTZ6YXdxdkFaCkhRU2lMUEd1UGp5WXBQa1E2ZFdEczNmWHJGZ1dlTmd4SkhTZkdaT05Bb0dCQUt5WTF3MUM2U3Y2c3VuTC8vNTcKQ2Z5NTAwaXlqNUZBOWRqZkRDNWt4K1JZMnlDV0ExVGsybjZyVmJ6dzg4czBTeDMrYS9IQW1CM2dMRXBSRU5NKwo5NHVwcENFWEQ3VHdlcGUxUnlrTStKbmp4TzlDSE41c2J2U25sUnBQWlMvZzJRTVhlZ3grK2trbkhXNG1ITkFyCndqMlRrMXBBczFXbkJ0TG9WaGVyY01jSkFvR0JBSTYwSGdJb0Y5SysvRUcyY21LbUg5SDV1dGlnZFU2eHEwK0IKWE0zMWMzUHE0amdJaDZlN3pvbFRxa2d0dWtTMjBraE45dC9ibkI2TmhnK1N1WGVwSXFWZldVUnlMejVwZE9ESgo2V1BMTTYzcDdCR3cwY3RPbU1NYi9VRm5Yd0U4OHlzRlNnOUF6VjdVVUQvU0lDYkI5ZHRVMWh4SHJJK0pZRWdWCkFrZWd6N2lCQW9HQkFJRncrQVFJZUIwM01UL0lCbGswNENQTDJEak0rNDhoVGRRdjgwMDBIQU9mUWJrMEVZUDEKQ2FLR3RDbTg2MXpBZjBzcS81REtZQ0l6OS9HUzNYRk00Qm1rRk9nY1NXVENPNmZmTGdLM3FmQzN4WDJudlpIOQpYZGNKTDQrZndhY0x4c2JJKzhhUWNOVHRtb3pkUjEzQnNmUmIrSGpUL2o3dkdrYlFnSkhCT0syegotLS0tLUVORCBSU0EgUFJJVkFURSBLRVktLS0tLQo=",
|
||||
false},
|
||||
{"bad cert",
|
||||
"!=",
|
||||
"LS0tLS1CRUdJTiBSU0EgUFJJVkFURSBLRVktLS0tLQpNSUlFcGdJQkFBS0NBUUVBNjdLanFtUVlHcTBNVnRBQ1ZwZUNtWG1pbmxRYkRQR0xtc1pBVUV3dWVIUW5ydDNXCnR2cERPbTZBbGFKTVVuVytIdTU1ampva2FsS2VWalRLbWdZR2JxVXpWRG9NYlBEYUhla2x0ZEJUTUdsT1VGc1AKNFVKU0RyTzR6ZE4rem80MjhUWDJQbkcyRkNkVktHeTRQRThpbEhiV0xjcjg3MVlqVjUxZnc4Q0xEWDlQWkpOdQo4NjFDRjdWOWlFSm02c1NmUWxtbmhOOGozK1d6VmJQUU55MVdzUjdpOWU5ajYzRXFLdDIyUTlPWEwrV0FjS3NrCm9JU21DTlZSVUFqVThZUlZjZ1FKQit6UTM0QVFQbHowT3A1Ty9RTi9NZWRqYUY4d0xTK2l2L3p2aVM4Y3FQYngKbzZzTHE2Rk5UbHRrL1FreGVDZUtLVFFlLzNrUFl2UUFkbmw2NVFJREFRQUJBb0lCQVFEQVQ0eXN2V2pSY3pxcgpKcU9SeGFPQTJEY3dXazJML1JXOFhtQWhaRmRTWHV2MkNQbGxhTU1yelBmTG41WUlmaHQzSDNzODZnSEdZc3pnClo4aWJiYWtYNUdFQ0t5N3lRSDZuZ3hFS3pRVGpiampBNWR3S0h0UFhQUnJmamQ1Y2FMczVpcDcxaWxCWEYxU3IKWERIaXUycnFtaC9kVTArWGRMLzNmK2VnVDl6bFQ5YzRyUm84dnZueWNYejFyMnVhRVZ2VExsWHVsb2NpeEVrcgoySjlTMmxveWFUb2tFTnNlMDNpSVdaWnpNNElZcVowOGJOeG9IWCszQXVlWExIUStzRkRKMlhaVVdLSkZHMHUyClp3R2w3YlZpRTFQNXdiQUdtZzJDeDVCN1MrdGQyUEpSV3Frb2VxY3F2RVdCc3RFL1FEcDFpVThCOHpiQXd0Y3IKZHc5TXZ6Q2hBb0dCQVBObzRWMjF6MGp6MWdEb2tlTVN5d3JnL2E4RkJSM2R2Y0xZbWV5VXkybmd3eHVucnFsdwo2U2IrOWdrOGovcXEvc3VQSDhVdzNqSHNKYXdGSnNvTkVqNCt2b1ZSM3UrbE5sTEw5b21rMXBoU0dNdVp0b3huCm5nbUxVbkJUMGI1M3BURkJ5WGsveE5CbElreWdBNlg5T2MreW5na3RqNlRyVnMxUERTdnVJY0s1QW9HQkFQZmoKcEUzR2F6cVFSemx6TjRvTHZmQWJBdktCZ1lPaFNnemxsK0ZLZkhzYWJGNkdudFd1dWVhY1FIWFpYZTA1c2tLcApXN2xYQ3dqQU1iUXI3QmdlazcrOSszZElwL1RnYmZCYnN3Syt6Vng3Z2doeWMrdytXRWExaHByWTZ6YXdxdkFaCkhRU2lMUEd1UGp5WXBQa1E2ZFdEczNmWHJGZ1dlTmd4SkhTZkdaT05Bb0dCQUt5WTF3MUM2U3Y2c3VuTC8vNTcKQ2Z5NTAwaXlqNUZBOWRqZkRDNWt4K1JZMnlDV0ExVGsybjZyVmJ6dzg4czBTeDMrYS9IQW1CM2dMRXBSRU5NKwo5NHVwcENFWEQ3VHdlcGUxUnlrTStKbmp4TzlDSE41c2J2U25sUnBQWlMvZzJRTVhlZ3grK2trbkhXNG1ITkFyCndqMlRrMXBBczFXbkJ0TG9WaGVyY01jSkFvR0JBSTYwSGdJb0Y5SysvRUcyY21LbUg5SDV1dGlnZFU2eHEwK0IKWE0zMWMzUHE0amdJaDZlN3pvbFRxa2d0dWtTMjBraE45dC9ibkI2TmhnK1N1WGVwSXFWZldVUnlMejVwZE9ESgo2V1BMTTYzcDdCR3cwY3RPbU1NYi9VRm5Yd0U4OHlzRlNnOUF6VjdVVUQvU0lDYkI5ZHRVMWh4SHJJK0pZRWdWCkFrZWd6N2lCQW9HQkFJRncrQVFJZUIwM01UL0lCbGswNENQTDJEak0rNDhoVGRRdjgwMDBIQU9mUWJrMEVZUDEKQ2FLR3RDbTg2MXpBZjBzcS81REtZQ0l6OS9HUzNYRk00Qm1rRk9nY1NXVENPNmZmTGdLM3FmQzN4WDJudlpIOQpYZGNKTDQrZndhY0x4c2JJKzhhUWNOVHRtb3pkUjEzQnNmUmIrSGpUL2o3dkdrYlFnSkhCT0syegotLS0tLUVORCBSU0EgUFJJVkFURSBLRVktLS0tLQo=",
|
||||
true},
|
||||
{"bad key",
|
||||
"LS0tLS1CRUdJTiBDRVJUSUZJQ0FURS0tLS0tCk1JSUVJVENDQWdtZ0F3SUJBZ0lSQVBqTEJxS1lwcWU0ekhQc0dWdFR6T0F3RFFZSktvWklodmNOQVFFTEJRQXcKRWpFUU1BNEdBMVVFQXhNSFoyOXZaQzFqWVRBZUZ3MHhPVEE0TVRBeE9EUTVOREJhRncweU1UQXlNVEF4TnpRdwpNREZhTUJNeEVUQVBCZ05WQkFNVENIQnZiV1Z5YVhWdE1JSUJJakFOQmdrcWhraUc5dzBCQVFFRkFBT0NBUThBCk1JSUJDZ0tDQVFFQTY3S2pxbVFZR3EwTVZ0QUNWcGVDbVhtaW5sUWJEUEdMbXNaQVVFd3VlSFFucnQzV3R2cEQKT202QWxhSk1VblcrSHU1NWpqb2thbEtlVmpUS21nWUdicVV6VkRvTWJQRGFIZWtsdGRCVE1HbE9VRnNQNFVKUwpEck80emROK3pvNDI4VFgyUG5HMkZDZFZLR3k0UEU4aWxIYldMY3I4NzFZalY1MWZ3OENMRFg5UFpKTnU4NjFDCkY3VjlpRUptNnNTZlFsbW5oTjhqMytXelZiUFFOeTFXc1I3aTllOWo2M0VxS3QyMlE5T1hMK1dBY0tza29JU20KQ05WUlVBalU4WVJWY2dRSkIrelEzNEFRUGx6ME9wNU8vUU4vTWVkamFGOHdMUytpdi96dmlTOGNxUGJ4bzZzTApxNkZOVGx0ay9Ra3hlQ2VLS1RRZS8za1BZdlFBZG5sNjVRSURBUUFCbzNFd2J6QU9CZ05WSFE4QkFmOEVCQU1DCkE3Z3dIUVlEVlIwbEJCWXdGQVlJS3dZQkJRVUhBd0VHQ0NzR0FRVUZCd01DTUIwR0ExVWREZ1FXQkJRQ1FYbWIKc0hpcS9UQlZUZVhoQ0dpNjhrVy9DakFmQmdOVkhTTUVHREFXZ0JSNTRKQ3pMRlg0T0RTQ1J0dWNBUGZOdVhWegpuREFOQmdrcWhraUc5dzBCQVFzRkFBT0NBZ0VBcm9XL2trMllleFN5NEhaQXFLNDVZaGQ5ay9QVTFiaDlFK1BRCk5jZFgzTUdEY2NDRUFkc1k4dll3NVE1cnhuMGFzcSt3VGFCcGxoYS9rMi9VVW9IQ1RqUVp1Mk94dEF3UTdPaWIKVE1tMEorU3NWT3d4YnFQTW9rK1RqVE16NFdXaFFUTzVwRmNoZDZXZXNCVHlJNzJ0aG1jcDd1c2NLU2h3YktIegpQY2h1QTQ4SzhPdi96WkxmZnduQVNZb3VCczJjd1ZiRDI3ZXZOMzdoMGFzR1BrR1VXdm1PSDduTHNVeTh3TTdqCkNGL3NwMmJmTC9OYVdNclJnTHZBMGZMS2pwWTQrVEpPbkVxQmxPcCsrbHlJTEZMcC9qMHNybjRNUnlKK0t6UTEKR1RPakVtQ1QvVEFtOS9XSThSL0FlYjcwTjEzTytYNEtaOUJHaDAxTzN3T1Vqd3BZZ3lxSnNoRnNRUG50VmMrSQpKQmF4M2VQU3NicUcwTFkzcHdHUkpRNmMrd1lxdGk2Y0tNTjliYlRkMDhCNUk1N1RRTHhNcUoycTFnWmw1R1VUCmVFZGNWRXltMnZmd0NPd0lrbGNBbThxTm5kZGZKV1FabE5VaHNOVWFBMkVINnlDeXdaZm9aak9hSDEwTXowV20KeTNpZ2NSZFQ3Mi9NR2VkZk93MlV0MVVvRFZmdEcxcysrditUQ1lpNmpUQU05dkZPckJ4UGlOeGFkUENHR2NZZAowakZIc2FWOGFPV1dQQjZBQ1JteHdDVDdRTnRTczM2MlpIOUlFWWR4Q00yMDUrZmluVHhkOUcwSmVRRTd2Kyt6CldoeWo2ZmJBWUIxM2wvN1hkRnpNSW5BOGxpekdrVHB2RHMxeTBCUzlwV3ppYmhqbVFoZGZIejdCZGpGTHVvc2wKZzlNZE5sND0KLS0tLS1FTkQgQ0VSVElGSUNBVEUtLS0tLQo=",
|
||||
"!=",
|
||||
true},
|
||||
func TestGetCertificateForDomain(t *testing.T) {
|
||||
gen := func(t *testing.T, domain string) *tls.Certificate {
|
||||
cert, err := GenerateSelfSignedCertificate(domain)
|
||||
if !assert.NoError(t, err, "error generating certificate for: %s", domain) {
|
||||
t.FailNow()
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
_, err := TLSConfigFromBase64(nil, tt.cert, tt.key)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("TLSConfigFromBase64() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return cert
|
||||
}
|
||||
|
||||
t.Run("exact match", func(t *testing.T) {
|
||||
certs := []tls.Certificate{
|
||||
*gen(t, "a.example.com"),
|
||||
*gen(t, "b.example.com"),
|
||||
}
|
||||
|
||||
found, err := GetCertificateForDomain(certs, "b.example.com")
|
||||
if !assert.NoError(t, err) {
|
||||
return
|
||||
}
|
||||
assert.Equal(t, &certs[1], found)
|
||||
})
|
||||
t.Run("wildcard match", func(t *testing.T) {
|
||||
certs := []tls.Certificate{
|
||||
*gen(t, "a.example.com"),
|
||||
*gen(t, "*.example.com"),
|
||||
}
|
||||
}
|
||||
|
||||
func TestTLSConfigFromFile(t *testing.T) {
|
||||
cfg, err := TLSConfigFromFile(nil, "testdata/example-cert.pem", "testdata/example-key.pem")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
found, err := GetCertificateForDomain(certs, "b.example.com")
|
||||
if !assert.NoError(t, err) {
|
||||
return
|
||||
}
|
||||
listener, err := tls.Listen("tcp", ":0", cfg)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
assert.Equal(t, &certs[1], found)
|
||||
})
|
||||
t.Run("no name match", func(t *testing.T) {
|
||||
certs := []tls.Certificate{
|
||||
*gen(t, "a.example.com"),
|
||||
}
|
||||
_ = listener
|
||||
|
||||
found, err := GetCertificateForDomain(certs, "b.example.com")
|
||||
if !assert.NoError(t, err) {
|
||||
return
|
||||
}
|
||||
assert.Equal(t, &certs[0], found)
|
||||
})
|
||||
t.Run("generate", func(t *testing.T) {
|
||||
certs := []tls.Certificate{}
|
||||
|
||||
found, err := GetCertificateForDomain(certs, "b.example.com")
|
||||
if !assert.NoError(t, err) {
|
||||
return
|
||||
}
|
||||
assert.NotNil(t, found)
|
||||
})
|
||||
}
|
||||
|
|
|
@ -15,8 +15,9 @@ import (
|
|||
"strings"
|
||||
|
||||
"github.com/natefinch/atomic"
|
||||
"github.com/pomerium/pomerium/internal/log"
|
||||
"github.com/rs/zerolog"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/log"
|
||||
)
|
||||
|
||||
const (
|
||||
|
|
Loading…
Add table
Reference in a new issue