config: detect underlying file changes (#1775)

* wip

* cleanup

* add test

* use uuid for temp dir, derive root CA path from filemgr for tests

* fix comment

* fix double close

* use latest notify
This commit is contained in:
Caleb Doxsey 2021-01-14 18:06:02 -07:00 committed by GitHub
parent c99994bed8
commit 10912add67
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
16 changed files with 500 additions and 99 deletions

View file

@ -1,9 +1,14 @@
package config
import (
"crypto/sha256"
"encoding/hex"
"io/ioutil"
"sync"
"github.com/fsnotify/fsnotify"
"github.com/pomerium/pomerium/internal/fileutil"
)
// Config holds pomerium configuration options.
@ -125,3 +130,83 @@ func (src *FileOrEnvironmentSource) GetConfig() *Config {
return src.config
}
// FileWatcherSource is a config source which triggers a change any time a file in the options changes.
type FileWatcherSource struct {
underlying Source
watcher *fileutil.Watcher
mu sync.RWMutex
computedConfig *Config
version string
ChangeDispatcher
}
// NewFileWatcherSource creates a new FileWatcherSource.
func NewFileWatcherSource(underlying Source) *FileWatcherSource {
src := &FileWatcherSource{
underlying: underlying,
watcher: fileutil.NewWatcher(),
}
ch := src.watcher.Bind()
go func() {
for range ch {
src.check(underlying.GetConfig())
}
}()
underlying.OnConfigChange(func(cfg *Config) {
src.check(cfg)
})
src.check(underlying.GetConfig())
return src
}
// GetConfig gets the underlying config.
func (src *FileWatcherSource) GetConfig() *Config {
src.mu.RLock()
defer src.mu.RUnlock()
return src.computedConfig
}
func (src *FileWatcherSource) check(cfg *Config) {
src.mu.Lock()
defer src.mu.Unlock()
src.watcher.Clear()
h := sha256.New()
fs := []string{
cfg.Options.CAFile,
cfg.Options.CertFile,
cfg.Options.ClientCAFile,
cfg.Options.DataBrokerStorageCAFile,
cfg.Options.DataBrokerStorageCertFile,
cfg.Options.DataBrokerStorageCertKeyFile,
cfg.Options.KeyFile,
cfg.Options.PolicyFile,
}
for _, f := range fs {
_, _ = h.Write([]byte{0})
bs, err := ioutil.ReadFile(f)
if err == nil {
src.watcher.Add(f)
_, _ = h.Write(bs)
}
}
version := hex.EncodeToString(h.Sum(nil))
if src.version != version {
src.version = version
// update the computed config
src.computedConfig = cfg.Clone()
src.computedConfig.Options.Certificates = nil
_ = src.computedConfig.Options.Validate()
// trigger a change
src.Trigger(src.computedConfig)
}
}

View file

@ -0,0 +1,50 @@
package config
import (
"io/ioutil"
"os"
"path/filepath"
"sync"
"testing"
"time"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
)
func TestFileWatcherSource(t *testing.T) {
tmpdir := filepath.Join(os.TempDir(), uuid.New().String())
err := os.MkdirAll(tmpdir, 0o755)
if !assert.NoError(t, err) {
return
}
err = ioutil.WriteFile(filepath.Join(tmpdir, "example.txt"), []byte{1, 2, 3, 4}, 0o600)
if !assert.NoError(t, err) {
return
}
src := NewFileWatcherSource(NewStaticSource(&Config{
Options: &Options{
CAFile: filepath.Join(tmpdir, "example.txt"),
},
}))
var closeOnce sync.Once
ch := make(chan struct{})
src.OnConfigChange(func(cfg *Config) {
closeOnce.Do(func() {
close(ch)
})
})
err = ioutil.WriteFile(filepath.Join(tmpdir, "example.txt"), []byte{5, 6, 7, 8}, 0o600)
if !assert.NoError(t, err) {
return
}
select {
case <-ch:
case <-time.After(time.Second):
t.Error("expected OnConfigChange to be fired after modifying a file")
}
}

1
go.mod
View file

@ -46,6 +46,7 @@ require (
github.com/pquerna/cachecontrol v0.0.0-20180517163645-1555304b9b35 // indirect
github.com/prometheus/client_golang v1.9.0
github.com/rakyll/statik v0.1.7
github.com/rjeczalik/notify v0.9.3-0.20201210012515-e2a77dcc14cf
github.com/rs/cors v1.7.0
github.com/rs/zerolog v1.20.0
github.com/skratchdot/open-golang v0.0.0-20200116055534-eef842397966

3
go.sum
View file

@ -544,6 +544,8 @@ github.com/rakyll/statik v0.1.7/go.mod h1:AlZONWzMtEnMs7W4e/1LURLiI49pIMmp6V9Ung
github.com/rcrowley/go-metrics v0.0.0-20181016184325-3113b8401b8a/go.mod h1:bCqnVzQkZxMG4s8nGwiZ5l3QUCyqpo9Y+/ZMZ9VjZe4=
github.com/rcrowley/go-metrics v0.0.0-20200313005456-10cdbea86bc0 h1:MkV+77GLUNo5oJ0jf870itWm3D0Sjh7+Za9gazKc5LQ=
github.com/rcrowley/go-metrics v0.0.0-20200313005456-10cdbea86bc0/go.mod h1:bCqnVzQkZxMG4s8nGwiZ5l3QUCyqpo9Y+/ZMZ9VjZe4=
github.com/rjeczalik/notify v0.9.3-0.20201210012515-e2a77dcc14cf h1:MY2fqXPSLfjld10N04fNcSFdR9K/Y57iXxZRFAivHzI=
github.com/rjeczalik/notify v0.9.3-0.20201210012515-e2a77dcc14cf/go.mod h1:aErll2f0sUX9PXZnVNyeiObbmTlk5jnMoCa4QEjJeqM=
github.com/rogpeppe/fastuuid v0.0.0-20150106093220-6724a57986af/go.mod h1:XWv6SoW27p1b0cqNHllgS5HIMJraePCO15w5zCzIWYg=
github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4=
github.com/rs/cors v1.7.0 h1:+88SsELBHx5r+hZ8TCkggzSstaWNbDvThkVK8H6f9ik=
@ -767,6 +769,7 @@ golang.org/x/sys v0.0.0-20180823144017-11551d06cbcc/go.mod h1:STP8DvDyc/dI5b8T5h
golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20180926160741-c2ed4eda69e7/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20181026203630-95b1ffbd15a5/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20181107165924-66b7b1311ac8/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20181116152217-5ac8a444bdc5/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=

View file

@ -48,6 +48,9 @@ func Run(ctx context.Context, configFile string) error {
return err
}
// trigger changes when underlying files are changed
src = config.NewFileWatcherSource(src)
// override the default http transport so we can use the custom CA in the TLS client config (#1570)
http.DefaultTransport = config.NewHTTPTransport(src)

View file

@ -0,0 +1,35 @@
package filemgr
import (
"os"
"path/filepath"
"github.com/google/uuid"
)
type config struct {
cacheDir string
}
// An Option updates the config.
type Option = func(*config)
// WithCacheDir returns an Option that sets the cache dir.
func WithCacheDir(cacheDir string) Option {
return func(cfg *config) {
cfg.cacheDir = cacheDir
}
}
func newConfig(options ...Option) *config {
cfg := new(config)
cacheDir, err := os.UserCacheDir()
if err != nil {
cacheDir = filepath.Join(os.TempDir(), uuid.New().String())
}
WithCacheDir(filepath.Join(cacheDir, "pomerium", "envoy", "files"))(cfg)
for _, o := range options {
o(cfg)
}
return cfg
}

View file

@ -0,0 +1,91 @@
package filemgr
import (
"fmt"
"io/ioutil"
"os"
"path/filepath"
envoy_config_core_v3 "github.com/envoyproxy/go-control-plane/envoy/config/core/v3"
"github.com/martinlindhe/base36"
"github.com/pomerium/pomerium/internal/log"
"github.com/pomerium/pomerium/pkg/cryptutil"
)
// A Manager manages files for envoy.
type Manager struct {
cfg *config
}
// NewManager creates a new Manager.
func NewManager(options ...Option) *Manager {
cfg := newConfig(options...)
return &Manager{
cfg: cfg,
}
}
// BytesDataSource returns an envoy config data source based on bytes.
func (mgr *Manager) BytesDataSource(fileName string, data []byte) *envoy_config_core_v3.DataSource {
h := base36.EncodeBytes(cryptutil.Hash("filemgr", data))
ext := filepath.Ext(fileName)
fileName = fmt.Sprintf("%s-%x%s", fileName[:len(fileName)-len(ext)], h, ext)
if err := os.MkdirAll(mgr.cfg.cacheDir, 0o700); err != nil {
log.Error().Err(err).Msg("filemgr: error creating cache directory, falling back to inline bytes")
return inlineBytes(data)
}
filePath := filepath.Join(mgr.cfg.cacheDir, fileName)
if _, err := os.Stat(filePath); os.IsNotExist(err) {
err = ioutil.WriteFile(filePath, data, 0o600)
if err != nil {
log.Error().Err(err).Msg("filemgr: error writing cache file, falling back to inline bytes")
return inlineBytes(data)
}
} else if err != nil {
log.Error().Err(err).Msg("filemgr: error reading cache file, falling back to inline bytes")
return inlineBytes(data)
}
return inlineFilename(filePath)
}
// ClearCache clears the file cache.
func (mgr *Manager) ClearCache() {
err := filepath.Walk(mgr.cfg.cacheDir, func(p string, fi os.FileInfo, err error) error {
if err != nil {
return err
}
return os.Remove(p)
})
if err != nil {
log.Error().Err(err).Msg("failed to clear envoy file cache")
}
}
// FileDataSource returns an envoy config data source based on a file.
func (mgr *Manager) FileDataSource(filePath string) *envoy_config_core_v3.DataSource {
data, err := ioutil.ReadFile(filePath)
if err != nil {
return inlineFilename(filePath)
}
return mgr.BytesDataSource(filepath.Base(filePath), data)
}
func inlineBytes(data []byte) *envoy_config_core_v3.DataSource {
return &envoy_config_core_v3.DataSource{
Specifier: &envoy_config_core_v3.DataSource_InlineBytes{
InlineBytes: data,
},
}
}
func inlineFilename(name string) *envoy_config_core_v3.DataSource {
return &envoy_config_core_v3.DataSource{
Specifier: &envoy_config_core_v3.DataSource_Filename{
Filename: name,
},
}
}

View file

@ -0,0 +1,59 @@
package filemgr
import (
"io/ioutil"
"os"
"path/filepath"
"testing"
envoy_config_core_v3 "github.com/envoyproxy/go-control-plane/envoy/config/core/v3"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
)
func Test(t *testing.T) {
dir := filepath.Join(os.TempDir(), uuid.New().String())
err := os.MkdirAll(dir, 0o755)
if !assert.NoError(t, err) {
return
}
defer func() {
_ = os.RemoveAll(dir)
}()
t.Run("bytes", func(t *testing.T) {
mgr := NewManager(WithCacheDir(dir))
ds := mgr.BytesDataSource("test.txt", []byte{1, 2, 3, 4, 5})
assert.Equal(t, &envoy_config_core_v3.DataSource{
Specifier: &envoy_config_core_v3.DataSource_Filename{
Filename: filepath.Join(dir, "test-353354494b53534a5538435652584d594a5759394d43484f38514b34594b4b524b34515339593249344e4238474a5436414b.txt"),
},
}, ds)
mgr.ClearCache()
})
t.Run("file", func(t *testing.T) {
tmpFilePath := filepath.Join(dir, "test.txt")
_ = ioutil.WriteFile(tmpFilePath, []byte("TEST1"), 0o777)
mgr := NewManager(WithCacheDir(dir))
ds := mgr.FileDataSource(tmpFilePath)
assert.Equal(t, &envoy_config_core_v3.DataSource{
Specifier: &envoy_config_core_v3.DataSource_Filename{
Filename: filepath.Join(dir, "test-34514f59593332445a5649504230484142544c515057383944383730554833564d32574836354654585954304e424f464336.txt"),
},
}, ds)
_ = ioutil.WriteFile(tmpFilePath, []byte("TEST2"), 0o777)
ds = mgr.FileDataSource(tmpFilePath)
assert.Equal(t, &envoy_config_core_v3.DataSource{
Specifier: &envoy_config_core_v3.DataSource_Filename{
Filename: filepath.Join(dir, "test-32564e4457304430393559364b5747373138584f484f5a51334d365758584b47364b555a4c444849513241513457323259.txt"),
},
}, ds)
mgr.ClearCache()
})
}

View file

@ -15,6 +15,7 @@ import (
"google.golang.org/grpc/reflection"
"github.com/pomerium/pomerium/config"
"github.com/pomerium/pomerium/internal/controlplane/filemgr"
"github.com/pomerium/pomerium/internal/controlplane/xdsmgr"
"github.com/pomerium/pomerium/internal/log"
"github.com/pomerium/pomerium/internal/telemetry"
@ -50,6 +51,7 @@ type Server struct {
currentConfig atomicVersionedOptions
name string
xdsmgr *xdsmgr.Manager
filemgr *filemgr.Manager
}
// NewServer creates a new Server. Listener ports are chosen by the OS.
@ -87,6 +89,9 @@ func NewServer(name string) (*Server, error) {
srv.xdsmgr = xdsmgr.NewManager(srv.buildDiscoveryResources())
envoy_service_discovery_v3.RegisterAggregatedDiscoveryServiceServer(srv.GRPCServer, srv.xdsmgr)
srv.filemgr = filemgr.NewManager()
srv.filemgr.ClearCache()
return srv, nil
}

View file

@ -7,14 +7,11 @@ import (
"encoding/hex"
"encoding/pem"
"fmt"
"io/ioutil"
"net"
"os"
"path/filepath"
"strconv"
"sync"
xxhash "github.com/cespare/xxhash/v2"
envoy_config_accesslog_v3 "github.com/envoyproxy/go-control-plane/envoy/config/accesslog/v3"
envoy_config_core_v3 "github.com/envoyproxy/go-control-plane/envoy/config/core/v3"
envoy_extensions_access_loggers_grpc_v3 "github.com/envoyproxy/go-control-plane/envoy/extensions/access_loggers/grpc/v3"
@ -45,7 +42,7 @@ func (srv *Server) buildDiscoveryResources() map[string][]*envoy_service_discove
Resource: any,
})
}
for _, listener := range buildListeners(&cfg.Options) {
for _, listener := range srv.buildListeners(&cfg.Options) {
any, _ := anypb.New(listener)
resources[listenerTypeURL] = append(resources[listenerTypeURL], &envoy_service_discovery_v3.Resource{
Name: listener.Name,
@ -116,52 +113,7 @@ func buildAddress(hostport string, defaultPort int) *envoy_config_core_v3.Addres
}
}
func inlineBytes(bs []byte) *envoy_config_core_v3.DataSource {
return &envoy_config_core_v3.DataSource{
Specifier: &envoy_config_core_v3.DataSource_InlineBytes{
InlineBytes: bs,
},
}
}
func inlineBytesAsFilename(name string, bs []byte) *envoy_config_core_v3.DataSource {
ext := filepath.Ext(name)
name = fmt.Sprintf("%s-%x%s", name[:len(name)-len(ext)], xxhash.Sum64(bs), ext)
cacheDir, err := os.UserCacheDir()
if err != nil {
cacheDir = filepath.Join(os.TempDir())
}
cacheDir = filepath.Join(cacheDir, "pomerium", "envoy", "files")
if err = os.MkdirAll(cacheDir, 0o755); err != nil {
log.Error().Err(err).Msg("error creating cache directory, falling back to inline bytes")
return inlineBytes(bs)
}
fp := filepath.Join(cacheDir, name)
if _, err = os.Stat(fp); os.IsNotExist(err) {
err = ioutil.WriteFile(fp, bs, 0o600)
if err != nil {
log.Error().Err(err).Msg("error writing cache file, falling back to inline bytes")
return inlineBytes(bs)
}
} else if err != nil {
log.Error().Err(err).Msg("error reading cache file, falling back to inline bytes")
return inlineBytes(bs)
}
return inlineFilename(fp)
}
func inlineFilename(name string) *envoy_config_core_v3.DataSource {
return &envoy_config_core_v3.DataSource{
Specifier: &envoy_config_core_v3.DataSource_Filename{
Filename: name,
},
}
}
func envoyTLSCertificateFromGoTLSCertificate(cert *tls.Certificate) *envoy_extensions_transport_sockets_tls_v3.TlsCertificate {
func (srv *Server) envoyTLSCertificateFromGoTLSCertificate(cert *tls.Certificate) *envoy_extensions_transport_sockets_tls_v3.TlsCertificate {
envoyCert := &envoy_extensions_transport_sockets_tls_v3.TlsCertificate{}
var chain bytes.Buffer
for _, cbs := range cert.Certificate {
@ -170,12 +122,12 @@ func envoyTLSCertificateFromGoTLSCertificate(cert *tls.Certificate) *envoy_exten
Bytes: cbs,
})
}
envoyCert.CertificateChain = inlineBytesAsFilename("tls-crt.pem", chain.Bytes())
envoyCert.CertificateChain = srv.filemgr.BytesDataSource("tls-crt.pem", chain.Bytes())
if cert.OCSPStaple != nil {
envoyCert.OcspStaple = inlineBytes(cert.OCSPStaple)
envoyCert.OcspStaple = srv.filemgr.BytesDataSource("ocsp-staple", cert.OCSPStaple)
}
if bs, err := x509.MarshalPKCS8PrivateKey(cert.PrivateKey); err == nil {
envoyCert.PrivateKey = inlineBytesAsFilename("tls-key.pem", pem.EncodeToMemory(
envoyCert.PrivateKey = srv.filemgr.BytesDataSource("tls-key.pem", pem.EncodeToMemory(
&pem.Block{
Type: "PRIVATE KEY",
Bytes: bs,
@ -186,7 +138,7 @@ func envoyTLSCertificateFromGoTLSCertificate(cert *tls.Certificate) *envoy_exten
}
for _, scts := range cert.SignedCertificateTimestamps {
envoyCert.SignedCertificateTimestamp = append(envoyCert.SignedCertificateTimestamp,
inlineBytes(scts))
srv.filemgr.BytesDataSource("signed-certificate-timestamp", scts))
}
return envoyCert
}

View file

@ -14,10 +14,15 @@ import (
)
func Test_buildPolicyTransportSocket(t *testing.T) {
rootCA, _ := getRootCertificateAuthority()
cacheDir, _ := os.UserCacheDir()
customCA := filepath.Join(cacheDir, "pomerium", "envoy", "files", "custom-ca-32484c314b584447463735303142374c31414145374650305a525539554938594d524855353757313942494d473847535231.pem")
srv, _ := NewServer("TEST")
rootCAPath, _ := getRootCertificateAuthority()
rootCA := srv.filemgr.FileDataSource(rootCAPath).GetFilename()
t.Run("insecure", func(t *testing.T) {
assert.Nil(t, buildPolicyTransportSocket(&config.Policy{
assert.Nil(t, srv.buildPolicyTransportSocket(&config.Policy{
Destination: mustParseURL("http://example.com"),
}))
})
@ -49,7 +54,7 @@ func Test_buildPolicyTransportSocket(t *testing.T) {
"sni": "example.com"
}
}
`, buildPolicyTransportSocket(&config.Policy{
`, srv.buildPolicyTransportSocket(&config.Policy{
Destination: mustParseURL("https://example.com"),
}))
})
@ -81,7 +86,7 @@ func Test_buildPolicyTransportSocket(t *testing.T) {
"sni": "use-this-name.example.com"
}
}
`, buildPolicyTransportSocket(&config.Policy{
`, srv.buildPolicyTransportSocket(&config.Policy{
Destination: mustParseURL("https://example.com"),
TLSServerName: "use-this-name.example.com",
}))
@ -115,7 +120,7 @@ func Test_buildPolicyTransportSocket(t *testing.T) {
"sni": "example.com"
}
}
`, buildPolicyTransportSocket(&config.Policy{
`, srv.buildPolicyTransportSocket(&config.Policy{
Destination: mustParseURL("https://example.com"),
TLSSkipVerify: true,
}))
@ -141,14 +146,14 @@ func Test_buildPolicyTransportSocket(t *testing.T) {
"exact": "example.com"
}],
"trustedCa": {
"filename": "`+filepath.Join(cacheDir, "pomerium", "envoy", "files", "custom-ca-3aefa6fd5cf2deb4.pem")+`"
"filename": "`+customCA+`"
}
}
},
"sni": "example.com"
}
}
`, buildPolicyTransportSocket(&config.Policy{
`, srv.buildPolicyTransportSocket(&config.Policy{
Destination: mustParseURL("https://example.com"),
TLSCustomCA: base64.StdEncoding.EncodeToString([]byte{0, 0, 0, 0}),
}))
@ -172,10 +177,10 @@ func Test_buildPolicyTransportSocket(t *testing.T) {
},
"tlsCertificates": [{
"certificateChain":{
"filename": "`+filepath.Join(cacheDir, "pomerium", "envoy", "files", "tls-crt-921a8294d2e2ec54.pem")+`"
"filename": "`+filepath.Join(cacheDir, "pomerium", "envoy", "files", "tls-crt-354e49305a5a39414a545530374e58454e48334148524c4e324258463837364355564c4e4532464b54355139495547514a38.pem")+`"
},
"privateKey": {
"filename": "`+filepath.Join(cacheDir, "pomerium", "envoy", "files", "tls-key-d5cf35b1e8533e4a.pem")+`"
"filename": "`+filepath.Join(cacheDir, "pomerium", "envoy", "files", "tls-key-3350415a38414e4e4a4655424e55393430474147324651433949384e485341334b5157364f424b4c5856365a545937383735.pem")+`"
}
}],
"validationContext": {
@ -190,7 +195,7 @@ func Test_buildPolicyTransportSocket(t *testing.T) {
"sni": "example.com"
}
}
`, buildPolicyTransportSocket(&config.Policy{
`, srv.buildPolicyTransportSocket(&config.Policy{
Destination: mustParseURL("https://example.com"),
ClientCertificate: clientCert,
}))
@ -198,7 +203,9 @@ func Test_buildPolicyTransportSocket(t *testing.T) {
}
func Test_buildCluster(t *testing.T) {
rootCA, _ := getRootCertificateAuthority()
srv, _ := NewServer("TEST")
rootCAPath, _ := getRootCertificateAuthority()
rootCA := srv.filemgr.FileDataSource(rootCAPath).GetFilename()
t.Run("insecure", func(t *testing.T) {
cluster := buildCluster("example", mustParseURL("http://example.com"), nil, true, config.GetEnvoyDNSLookupFamily(config.DNSLookupFamilyV4Only))
testutil.AssertProtoJSONEqual(t, `
@ -232,7 +239,7 @@ func Test_buildCluster(t *testing.T) {
})
t.Run("secure", func(t *testing.T) {
u := mustParseURL("https://example.com")
transportSocket := buildPolicyTransportSocket(&config.Policy{
transportSocket := srv.buildPolicyTransportSocket(&config.Policy{
Destination: u,
})
cluster := buildCluster("example", u, transportSocket, true, config.GetEnvoyDNSLookupFamily(config.DNSLookupFamilyAuto))

View file

@ -34,17 +34,17 @@ func (srv *Server) buildClusters(options *config.Options) []*envoy_config_cluste
}
clusters := []*envoy_config_cluster_v3.Cluster{
buildInternalCluster(options, "pomerium-control-plane-grpc", grpcURL, true),
buildInternalCluster(options, "pomerium-control-plane-http", httpURL, false),
srv.buildInternalCluster(options, "pomerium-control-plane-grpc", grpcURL, true),
srv.buildInternalCluster(options, "pomerium-control-plane-http", httpURL, false),
}
clusters = append(clusters, buildInternalCluster(options, authzURL.Host, authzURL, true))
clusters = append(clusters, srv.buildInternalCluster(options, authzURL.Host, authzURL, true))
if config.IsProxy(options.Services) {
for i := range options.Policies {
policy := options.Policies[i]
if policy.Destination != nil {
clusters = append(clusters, buildPolicyCluster(options, &policy))
clusters = append(clusters, srv.buildPolicyCluster(options, &policy))
}
}
}
@ -52,21 +52,21 @@ func (srv *Server) buildClusters(options *config.Options) []*envoy_config_cluste
return clusters
}
func buildInternalCluster(options *config.Options, name string, endpoint *url.URL, forceHTTP2 bool) *envoy_config_cluster_v3.Cluster {
func (srv *Server) buildInternalCluster(options *config.Options, name string, endpoint *url.URL, forceHTTP2 bool) *envoy_config_cluster_v3.Cluster {
dnsLookupFamily := config.GetEnvoyDNSLookupFamily(options.DNSLookupFamily)
return buildCluster(name, endpoint, buildInternalTransportSocket(options, endpoint), forceHTTP2, dnsLookupFamily)
return buildCluster(name, endpoint, srv.buildInternalTransportSocket(options, endpoint), forceHTTP2, dnsLookupFamily)
}
func buildPolicyCluster(options *config.Options, policy *config.Policy) *envoy_config_cluster_v3.Cluster {
func (srv *Server) buildPolicyCluster(options *config.Options, policy *config.Policy) *envoy_config_cluster_v3.Cluster {
name := getPolicyName(policy)
dnsLookupFamily := config.GetEnvoyDNSLookupFamily(options.DNSLookupFamily)
if policy.EnableGoogleCloudServerlessAuthentication {
dnsLookupFamily = envoy_config_cluster_v3.Cluster_V4_ONLY
}
return buildCluster(name, policy.Destination, buildPolicyTransportSocket(policy), false, dnsLookupFamily)
return buildCluster(name, policy.Destination, srv.buildPolicyTransportSocket(policy), false, dnsLookupFamily)
}
func buildInternalTransportSocket(options *config.Options, endpoint *url.URL) *envoy_config_core_v3.TransportSocket {
func (srv *Server) buildInternalTransportSocket(options *config.Options, endpoint *url.URL) *envoy_config_core_v3.TransportSocket {
if endpoint.Scheme != "https" {
return nil
}
@ -82,19 +82,19 @@ func buildInternalTransportSocket(options *config.Options, endpoint *url.URL) *e
}},
}
if options.CAFile != "" {
validationContext.TrustedCa = inlineFilename(options.CAFile)
validationContext.TrustedCa = srv.filemgr.FileDataSource(options.CAFile)
} else if options.CA != "" {
bs, err := base64.StdEncoding.DecodeString(options.CA)
if err != nil {
log.Error().Err(err).Msg("invalid custom CA certificate")
}
validationContext.TrustedCa = inlineBytesAsFilename("custom-ca.pem", bs)
validationContext.TrustedCa = srv.filemgr.BytesDataSource("custom-ca.pem", bs)
} else {
rootCA, err := getRootCertificateAuthority()
if err != nil {
log.Error().Err(err).Msg("unable to enable certificate verification because no root CAs were found")
} else {
validationContext.TrustedCa = inlineFilename(rootCA)
validationContext.TrustedCa = srv.filemgr.FileDataSource(rootCA)
}
}
tlsContext := &envoy_extensions_transport_sockets_tls_v3.UpstreamTlsContext{
@ -115,7 +115,7 @@ func buildInternalTransportSocket(options *config.Options, endpoint *url.URL) *e
}
}
func buildPolicyTransportSocket(policy *config.Policy) *envoy_config_core_v3.TransportSocket {
func (srv *Server) buildPolicyTransportSocket(policy *config.Policy) *envoy_config_core_v3.TransportSocket {
if policy.Destination == nil || policy.Destination.Scheme != "https" {
return nil
}
@ -136,14 +136,14 @@ func buildPolicyTransportSocket(policy *config.Policy) *envoy_config_core_v3.Tra
},
AlpnProtocols: []string{"http/1.1"},
ValidationContextType: &envoy_extensions_transport_sockets_tls_v3.CommonTlsContext_ValidationContext{
ValidationContext: buildPolicyValidationContext(policy),
ValidationContext: srv.buildPolicyValidationContext(policy),
},
},
Sni: sni,
}
if policy.ClientCertificate != nil {
tlsContext.CommonTlsContext.TlsCertificates = append(tlsContext.CommonTlsContext.TlsCertificates,
envoyTLSCertificateFromGoTLSCertificate(policy.ClientCertificate))
srv.envoyTLSCertificateFromGoTLSCertificate(policy.ClientCertificate))
}
tlsConfig := marshalAny(tlsContext)
@ -155,7 +155,7 @@ func buildPolicyTransportSocket(policy *config.Policy) *envoy_config_core_v3.Tra
}
}
func buildPolicyValidationContext(policy *config.Policy) *envoy_extensions_transport_sockets_tls_v3.CertificateValidationContext {
func (srv *Server) buildPolicyValidationContext(policy *config.Policy) *envoy_extensions_transport_sockets_tls_v3.CertificateValidationContext {
if policy.Destination == nil {
return nil
}
@ -172,19 +172,19 @@ func buildPolicyValidationContext(policy *config.Policy) *envoy_extensions_trans
}},
}
if policy.TLSCustomCAFile != "" {
validationContext.TrustedCa = inlineFilename(policy.TLSCustomCAFile)
validationContext.TrustedCa = srv.filemgr.FileDataSource(policy.TLSCustomCAFile)
} else if policy.TLSCustomCA != "" {
bs, err := base64.StdEncoding.DecodeString(policy.TLSCustomCA)
if err != nil {
log.Error().Err(err).Msg("invalid custom CA certificate")
}
validationContext.TrustedCa = inlineBytesAsFilename("custom-ca.pem", bs)
validationContext.TrustedCa = srv.filemgr.BytesDataSource("custom-ca.pem", bs)
} else {
rootCA, err := getRootCertificateAuthority()
if err != nil {
log.Error().Err(err).Msg("unable to enable certificate verification because no root CAs were found")
} else {
validationContext.TrustedCa = inlineFilename(rootCA)
validationContext.TrustedCa = srv.filemgr.FileDataSource(rootCA)
}
}

View file

@ -37,21 +37,21 @@ func init() {
})
}
func buildListeners(options *config.Options) []*envoy_config_listener_v3.Listener {
func (srv *Server) buildListeners(options *config.Options) []*envoy_config_listener_v3.Listener {
var listeners []*envoy_config_listener_v3.Listener
if config.IsAuthenticate(options.Services) || config.IsProxy(options.Services) {
listeners = append(listeners, buildMainListener(options))
listeners = append(listeners, srv.buildMainListener(options))
}
if config.IsAuthorize(options.Services) || config.IsCache(options.Services) {
listeners = append(listeners, buildGRPCListener(options))
listeners = append(listeners, srv.buildGRPCListener(options))
}
return listeners
}
func buildMainListener(options *config.Options) *envoy_config_listener_v3.Listener {
func (srv *Server) buildMainListener(options *config.Options) *envoy_config_listener_v3.Listener {
if options.InsecureServer {
filter := buildMainHTTPConnectionManagerFilter(options,
getAllRouteableDomains(options, options.Addr))
@ -88,7 +88,7 @@ func buildMainListener(options *config.Options) *envoy_config_listener_v3.Listen
ServerNames: []string{tlsDomain},
}
}
tlsContext := buildDownstreamTLSContext(options, tlsDomain)
tlsContext := srv.buildDownstreamTLSContext(options, tlsDomain)
if tlsContext != nil {
tlsConfig := marshalAny(tlsContext)
filterChain.TransportSocket = &envoy_config_core_v3.TransportSocket{
@ -250,7 +250,7 @@ func buildMainHTTPConnectionManagerFilter(options *config.Options, domains []str
}
}
func buildGRPCListener(options *config.Options) *envoy_config_listener_v3.Listener {
func (srv *Server) buildGRPCListener(options *config.Options) *envoy_config_listener_v3.Listener {
filter := buildGRPCHTTPConnectionManagerFilter()
if options.GRPCInsecure {
@ -285,7 +285,7 @@ func buildGRPCListener(options *config.Options) *envoy_config_listener_v3.Listen
ServerNames: []string{tlsDomain},
}
}
tlsContext := buildDownstreamTLSContext(options, tlsDomain)
tlsContext := srv.buildDownstreamTLSContext(options, tlsDomain)
if tlsContext != nil {
tlsConfig := marshalAny(tlsContext)
filterChain.TransportSocket = &envoy_config_core_v3.TransportSocket{
@ -357,7 +357,7 @@ func buildRouteConfiguration(name string, virtualHosts []*envoy_config_route_v3.
}
}
func buildDownstreamTLSContext(options *config.Options, domain string) *envoy_extensions_transport_sockets_tls_v3.DownstreamTlsContext {
func (srv *Server) buildDownstreamTLSContext(options *config.Options, domain string) *envoy_extensions_transport_sockets_tls_v3.DownstreamTlsContext {
cert, err := cryptutil.GetCertificateForDomain(options.Certificates, domain)
if err != nil {
log.Warn().Str("domain", domain).Err(err).Msg("failed to get certificate for domain")
@ -370,9 +370,9 @@ func buildDownstreamTLSContext(options *config.Options, domain string) *envoy_ex
if err != nil {
log.Warn().Msg("client_ca does not appear to be a base64 encoded string")
}
trustedCA = inlineBytesAsFilename("client-ca", bs)
trustedCA = srv.filemgr.BytesDataSource("client-ca", bs)
} else if options.ClientCAFile != "" {
trustedCA = inlineFilename(options.ClientCAFile)
trustedCA = srv.filemgr.FileDataSource(options.ClientCAFile)
}
var validationContext *envoy_extensions_transport_sockets_tls_v3.CommonTlsContext_ValidationContext
@ -385,7 +385,7 @@ func buildDownstreamTLSContext(options *config.Options, domain string) *envoy_ex
}
}
envoyCert := envoyTLSCertificateFromGoTLSCertificate(cert)
envoyCert := srv.envoyTLSCertificateFromGoTLSCertificate(cert)
return &envoy_extensions_transport_sockets_tls_v3.DownstreamTlsContext{
CommonTlsContext: &envoy_extensions_transport_sockets_tls_v3.CommonTlsContext{
TlsParams: &envoy_extensions_transport_sockets_tls_v3.TlsParameters{

View file

@ -385,13 +385,15 @@ func Test_buildDownstreamTLSContext(t *testing.T) {
return
}
downstreamTLSContext := buildDownstreamTLSContext(&config.Options{
srv, _ := NewServer("TEST")
downstreamTLSContext := srv.buildDownstreamTLSContext(&config.Options{
Certificates: []tls.Certificate{*certA},
}, "a.example.com")
cacheDir, _ := os.UserCacheDir()
certFileName := filepath.Join(cacheDir, "pomerium", "envoy", "files", "tls-crt-921a8294d2e2ec54.pem")
keyFileName := filepath.Join(cacheDir, "pomerium", "envoy", "files", "tls-key-d5cf35b1e8533e4a.pem")
certFileName := filepath.Join(cacheDir, "pomerium", "envoy", "files", "tls-crt-354e49305a5a39414a545530374e58454e48334148524c4e324258463837364355564c4e4532464b54355139495547514a38.pem")
keyFileName := filepath.Join(cacheDir, "pomerium", "envoy", "files", "tls-key-3350415a38414e4e4a4655424e55393430474147324651433949384e485341334b5157364f424b4c5856365a545937383735.pem")
testutil.AssertProtoJSONEqual(t, `{
"commonTlsContext": {

View file

@ -0,0 +1,66 @@
package fileutil
import (
"sync"
"github.com/rjeczalik/notify"
"github.com/pomerium/pomerium/internal/log"
"github.com/pomerium/pomerium/internal/signal"
)
// A Watcher watches files for changes.
type Watcher struct {
*signal.Signal
mu sync.Mutex
filePaths map[string]chan notify.EventInfo
}
// NewWatcher creates a new Watcher.
func NewWatcher() *Watcher {
return &Watcher{
Signal: signal.New(),
filePaths: map[string]chan notify.EventInfo{},
}
}
// Add adds a new watch.
func (watcher *Watcher) Add(filePath string) {
watcher.mu.Lock()
defer watcher.mu.Unlock()
// already watching
if _, ok := watcher.filePaths[filePath]; ok {
return
}
ch := make(chan notify.EventInfo, 1)
go func() {
for evt := range ch {
log.Info().Str("path", evt.Path()).Str("event", evt.Event().String()).Msg("filemgr: detected file change")
watcher.Signal.Broadcast()
}
}()
err := notify.Watch(filePath, ch, notify.All)
if err != nil {
log.Error().Err(err).Str("path", filePath).Msg("filemgr: error watching file path")
notify.Stop(ch)
close(ch)
return
}
log.Debug().Str("path", filePath).Msg("filemgr: watching file for changes")
watcher.filePaths[filePath] = ch
}
// Clear removes all watches.
func (watcher *Watcher) Clear() {
watcher.mu.Lock()
defer watcher.mu.Unlock()
for filePath, ch := range watcher.filePaths {
notify.Stop(ch)
close(ch)
delete(watcher.filePaths, filePath)
}
}

View file

@ -0,0 +1,42 @@
package fileutil
import (
"io/ioutil"
"os"
"path/filepath"
"testing"
"time"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
)
func TestWatcher(t *testing.T) {
tmpdir := filepath.Join(os.TempDir(), uuid.New().String())
err := os.MkdirAll(tmpdir, 0o755)
if !assert.NoError(t, err) {
return
}
err = ioutil.WriteFile(filepath.Join(tmpdir, "test1.txt"), []byte{1, 2, 3, 4}, 0o666)
if !assert.NoError(t, err) {
return
}
w := NewWatcher()
w.Add(filepath.Join(tmpdir, "test1.txt"))
ch := w.Bind()
defer w.Unbind(ch)
err = ioutil.WriteFile(filepath.Join(tmpdir, "test1.txt"), []byte{5, 6, 7, 8}, 0o666)
if !assert.NoError(t, err) {
return
}
select {
case <-ch:
case <-time.After(time.Second):
t.Error("expected change signal when file is modified")
}
}