diff --git a/config/config_source.go b/config/config_source.go index b40b9d7c1..c636b59d2 100644 --- a/config/config_source.go +++ b/config/config_source.go @@ -1,9 +1,14 @@ package config import ( + "crypto/sha256" + "encoding/hex" + "io/ioutil" "sync" "github.com/fsnotify/fsnotify" + + "github.com/pomerium/pomerium/internal/fileutil" ) // Config holds pomerium configuration options. @@ -125,3 +130,83 @@ func (src *FileOrEnvironmentSource) GetConfig() *Config { return src.config } + +// FileWatcherSource is a config source which triggers a change any time a file in the options changes. +type FileWatcherSource struct { + underlying Source + watcher *fileutil.Watcher + + mu sync.RWMutex + computedConfig *Config + version string + + ChangeDispatcher +} + +// NewFileWatcherSource creates a new FileWatcherSource. +func NewFileWatcherSource(underlying Source) *FileWatcherSource { + src := &FileWatcherSource{ + underlying: underlying, + watcher: fileutil.NewWatcher(), + } + + ch := src.watcher.Bind() + go func() { + for range ch { + src.check(underlying.GetConfig()) + } + }() + underlying.OnConfigChange(func(cfg *Config) { + src.check(cfg) + }) + src.check(underlying.GetConfig()) + + return src +} + +// GetConfig gets the underlying config. +func (src *FileWatcherSource) GetConfig() *Config { + src.mu.RLock() + defer src.mu.RUnlock() + return src.computedConfig +} + +func (src *FileWatcherSource) check(cfg *Config) { + src.mu.Lock() + defer src.mu.Unlock() + + src.watcher.Clear() + + h := sha256.New() + fs := []string{ + cfg.Options.CAFile, + cfg.Options.CertFile, + cfg.Options.ClientCAFile, + cfg.Options.DataBrokerStorageCAFile, + cfg.Options.DataBrokerStorageCertFile, + cfg.Options.DataBrokerStorageCertKeyFile, + cfg.Options.KeyFile, + cfg.Options.PolicyFile, + } + for _, f := range fs { + _, _ = h.Write([]byte{0}) + bs, err := ioutil.ReadFile(f) + if err == nil { + src.watcher.Add(f) + _, _ = h.Write(bs) + } + } + + version := hex.EncodeToString(h.Sum(nil)) + if src.version != version { + src.version = version + + // update the computed config + src.computedConfig = cfg.Clone() + src.computedConfig.Options.Certificates = nil + _ = src.computedConfig.Options.Validate() + + // trigger a change + src.Trigger(src.computedConfig) + } +} diff --git a/config/config_source_test.go b/config/config_source_test.go new file mode 100644 index 000000000..692a62e50 --- /dev/null +++ b/config/config_source_test.go @@ -0,0 +1,50 @@ +package config + +import ( + "io/ioutil" + "os" + "path/filepath" + "sync" + "testing" + "time" + + "github.com/google/uuid" + "github.com/stretchr/testify/assert" +) + +func TestFileWatcherSource(t *testing.T) { + tmpdir := filepath.Join(os.TempDir(), uuid.New().String()) + err := os.MkdirAll(tmpdir, 0o755) + if !assert.NoError(t, err) { + return + } + + err = ioutil.WriteFile(filepath.Join(tmpdir, "example.txt"), []byte{1, 2, 3, 4}, 0o600) + if !assert.NoError(t, err) { + return + } + + src := NewFileWatcherSource(NewStaticSource(&Config{ + Options: &Options{ + CAFile: filepath.Join(tmpdir, "example.txt"), + }, + })) + var closeOnce sync.Once + ch := make(chan struct{}) + src.OnConfigChange(func(cfg *Config) { + closeOnce.Do(func() { + close(ch) + }) + }) + + err = ioutil.WriteFile(filepath.Join(tmpdir, "example.txt"), []byte{5, 6, 7, 8}, 0o600) + if !assert.NoError(t, err) { + return + } + + select { + case <-ch: + case <-time.After(time.Second): + t.Error("expected OnConfigChange to be fired after modifying a file") + } +} diff --git a/go.mod b/go.mod index c84c1e93b..d65da9d7f 100644 --- a/go.mod +++ b/go.mod @@ -46,6 +46,7 @@ require ( github.com/pquerna/cachecontrol v0.0.0-20180517163645-1555304b9b35 // indirect github.com/prometheus/client_golang v1.9.0 github.com/rakyll/statik v0.1.7 + github.com/rjeczalik/notify v0.9.3-0.20201210012515-e2a77dcc14cf github.com/rs/cors v1.7.0 github.com/rs/zerolog v1.20.0 github.com/skratchdot/open-golang v0.0.0-20200116055534-eef842397966 diff --git a/go.sum b/go.sum index d6e9bc41b..3a42212ce 100644 --- a/go.sum +++ b/go.sum @@ -544,6 +544,8 @@ github.com/rakyll/statik v0.1.7/go.mod h1:AlZONWzMtEnMs7W4e/1LURLiI49pIMmp6V9Ung github.com/rcrowley/go-metrics v0.0.0-20181016184325-3113b8401b8a/go.mod h1:bCqnVzQkZxMG4s8nGwiZ5l3QUCyqpo9Y+/ZMZ9VjZe4= github.com/rcrowley/go-metrics v0.0.0-20200313005456-10cdbea86bc0 h1:MkV+77GLUNo5oJ0jf870itWm3D0Sjh7+Za9gazKc5LQ= github.com/rcrowley/go-metrics v0.0.0-20200313005456-10cdbea86bc0/go.mod h1:bCqnVzQkZxMG4s8nGwiZ5l3QUCyqpo9Y+/ZMZ9VjZe4= +github.com/rjeczalik/notify v0.9.3-0.20201210012515-e2a77dcc14cf h1:MY2fqXPSLfjld10N04fNcSFdR9K/Y57iXxZRFAivHzI= +github.com/rjeczalik/notify v0.9.3-0.20201210012515-e2a77dcc14cf/go.mod h1:aErll2f0sUX9PXZnVNyeiObbmTlk5jnMoCa4QEjJeqM= github.com/rogpeppe/fastuuid v0.0.0-20150106093220-6724a57986af/go.mod h1:XWv6SoW27p1b0cqNHllgS5HIMJraePCO15w5zCzIWYg= github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4= github.com/rs/cors v1.7.0 h1:+88SsELBHx5r+hZ8TCkggzSstaWNbDvThkVK8H6f9ik= @@ -767,6 +769,7 @@ golang.org/x/sys v0.0.0-20180823144017-11551d06cbcc/go.mod h1:STP8DvDyc/dI5b8T5h golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20180926160741-c2ed4eda69e7/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20181026203630-95b1ffbd15a5/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20181107165924-66b7b1311ac8/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20181116152217-5ac8a444bdc5/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= diff --git a/internal/cmd/pomerium/pomerium.go b/internal/cmd/pomerium/pomerium.go index b297b30f2..b24faac6d 100644 --- a/internal/cmd/pomerium/pomerium.go +++ b/internal/cmd/pomerium/pomerium.go @@ -48,6 +48,9 @@ func Run(ctx context.Context, configFile string) error { return err } + // trigger changes when underlying files are changed + src = config.NewFileWatcherSource(src) + // override the default http transport so we can use the custom CA in the TLS client config (#1570) http.DefaultTransport = config.NewHTTPTransport(src) diff --git a/internal/controlplane/filemgr/config.go b/internal/controlplane/filemgr/config.go new file mode 100644 index 000000000..802d9046f --- /dev/null +++ b/internal/controlplane/filemgr/config.go @@ -0,0 +1,35 @@ +package filemgr + +import ( + "os" + "path/filepath" + + "github.com/google/uuid" +) + +type config struct { + cacheDir string +} + +// An Option updates the config. +type Option = func(*config) + +// WithCacheDir returns an Option that sets the cache dir. +func WithCacheDir(cacheDir string) Option { + return func(cfg *config) { + cfg.cacheDir = cacheDir + } +} + +func newConfig(options ...Option) *config { + cfg := new(config) + cacheDir, err := os.UserCacheDir() + if err != nil { + cacheDir = filepath.Join(os.TempDir(), uuid.New().String()) + } + WithCacheDir(filepath.Join(cacheDir, "pomerium", "envoy", "files"))(cfg) + for _, o := range options { + o(cfg) + } + return cfg +} diff --git a/internal/controlplane/filemgr/filemgr.go b/internal/controlplane/filemgr/filemgr.go new file mode 100644 index 000000000..542dffa94 --- /dev/null +++ b/internal/controlplane/filemgr/filemgr.go @@ -0,0 +1,91 @@ +package filemgr + +import ( + "fmt" + "io/ioutil" + "os" + "path/filepath" + + envoy_config_core_v3 "github.com/envoyproxy/go-control-plane/envoy/config/core/v3" + "github.com/martinlindhe/base36" + + "github.com/pomerium/pomerium/internal/log" + "github.com/pomerium/pomerium/pkg/cryptutil" +) + +// A Manager manages files for envoy. +type Manager struct { + cfg *config +} + +// NewManager creates a new Manager. +func NewManager(options ...Option) *Manager { + cfg := newConfig(options...) + return &Manager{ + cfg: cfg, + } +} + +// BytesDataSource returns an envoy config data source based on bytes. +func (mgr *Manager) BytesDataSource(fileName string, data []byte) *envoy_config_core_v3.DataSource { + h := base36.EncodeBytes(cryptutil.Hash("filemgr", data)) + ext := filepath.Ext(fileName) + fileName = fmt.Sprintf("%s-%x%s", fileName[:len(fileName)-len(ext)], h, ext) + + if err := os.MkdirAll(mgr.cfg.cacheDir, 0o700); err != nil { + log.Error().Err(err).Msg("filemgr: error creating cache directory, falling back to inline bytes") + return inlineBytes(data) + } + + filePath := filepath.Join(mgr.cfg.cacheDir, fileName) + if _, err := os.Stat(filePath); os.IsNotExist(err) { + err = ioutil.WriteFile(filePath, data, 0o600) + if err != nil { + log.Error().Err(err).Msg("filemgr: error writing cache file, falling back to inline bytes") + return inlineBytes(data) + } + } else if err != nil { + log.Error().Err(err).Msg("filemgr: error reading cache file, falling back to inline bytes") + return inlineBytes(data) + } + + return inlineFilename(filePath) +} + +// ClearCache clears the file cache. +func (mgr *Manager) ClearCache() { + err := filepath.Walk(mgr.cfg.cacheDir, func(p string, fi os.FileInfo, err error) error { + if err != nil { + return err + } + return os.Remove(p) + }) + if err != nil { + log.Error().Err(err).Msg("failed to clear envoy file cache") + } +} + +// FileDataSource returns an envoy config data source based on a file. +func (mgr *Manager) FileDataSource(filePath string) *envoy_config_core_v3.DataSource { + data, err := ioutil.ReadFile(filePath) + if err != nil { + return inlineFilename(filePath) + } + return mgr.BytesDataSource(filepath.Base(filePath), data) +} + +func inlineBytes(data []byte) *envoy_config_core_v3.DataSource { + return &envoy_config_core_v3.DataSource{ + Specifier: &envoy_config_core_v3.DataSource_InlineBytes{ + InlineBytes: data, + }, + } +} + +func inlineFilename(name string) *envoy_config_core_v3.DataSource { + return &envoy_config_core_v3.DataSource{ + Specifier: &envoy_config_core_v3.DataSource_Filename{ + Filename: name, + }, + } +} diff --git a/internal/controlplane/filemgr/filemgr_test.go b/internal/controlplane/filemgr/filemgr_test.go new file mode 100644 index 000000000..d15a639a5 --- /dev/null +++ b/internal/controlplane/filemgr/filemgr_test.go @@ -0,0 +1,59 @@ +package filemgr + +import ( + "io/ioutil" + "os" + "path/filepath" + "testing" + + envoy_config_core_v3 "github.com/envoyproxy/go-control-plane/envoy/config/core/v3" + "github.com/google/uuid" + "github.com/stretchr/testify/assert" +) + +func Test(t *testing.T) { + dir := filepath.Join(os.TempDir(), uuid.New().String()) + err := os.MkdirAll(dir, 0o755) + if !assert.NoError(t, err) { + return + } + defer func() { + _ = os.RemoveAll(dir) + }() + + t.Run("bytes", func(t *testing.T) { + mgr := NewManager(WithCacheDir(dir)) + ds := mgr.BytesDataSource("test.txt", []byte{1, 2, 3, 4, 5}) + assert.Equal(t, &envoy_config_core_v3.DataSource{ + Specifier: &envoy_config_core_v3.DataSource_Filename{ + Filename: filepath.Join(dir, "test-353354494b53534a5538435652584d594a5759394d43484f38514b34594b4b524b34515339593249344e4238474a5436414b.txt"), + }, + }, ds) + mgr.ClearCache() + }) + + t.Run("file", func(t *testing.T) { + tmpFilePath := filepath.Join(dir, "test.txt") + _ = ioutil.WriteFile(tmpFilePath, []byte("TEST1"), 0o777) + + mgr := NewManager(WithCacheDir(dir)) + + ds := mgr.FileDataSource(tmpFilePath) + assert.Equal(t, &envoy_config_core_v3.DataSource{ + Specifier: &envoy_config_core_v3.DataSource_Filename{ + Filename: filepath.Join(dir, "test-34514f59593332445a5649504230484142544c515057383944383730554833564d32574836354654585954304e424f464336.txt"), + }, + }, ds) + + _ = ioutil.WriteFile(tmpFilePath, []byte("TEST2"), 0o777) + + ds = mgr.FileDataSource(tmpFilePath) + assert.Equal(t, &envoy_config_core_v3.DataSource{ + Specifier: &envoy_config_core_v3.DataSource_Filename{ + Filename: filepath.Join(dir, "test-32564e4457304430393559364b5747373138584f484f5a51334d365758584b47364b555a4c444849513241513457323259.txt"), + }, + }, ds) + + mgr.ClearCache() + }) +} diff --git a/internal/controlplane/server.go b/internal/controlplane/server.go index a60a16d94..d3cb3c7e5 100644 --- a/internal/controlplane/server.go +++ b/internal/controlplane/server.go @@ -15,6 +15,7 @@ import ( "google.golang.org/grpc/reflection" "github.com/pomerium/pomerium/config" + "github.com/pomerium/pomerium/internal/controlplane/filemgr" "github.com/pomerium/pomerium/internal/controlplane/xdsmgr" "github.com/pomerium/pomerium/internal/log" "github.com/pomerium/pomerium/internal/telemetry" @@ -50,6 +51,7 @@ type Server struct { currentConfig atomicVersionedOptions name string xdsmgr *xdsmgr.Manager + filemgr *filemgr.Manager } // NewServer creates a new Server. Listener ports are chosen by the OS. @@ -87,6 +89,9 @@ func NewServer(name string) (*Server, error) { srv.xdsmgr = xdsmgr.NewManager(srv.buildDiscoveryResources()) envoy_service_discovery_v3.RegisterAggregatedDiscoveryServiceServer(srv.GRPCServer, srv.xdsmgr) + srv.filemgr = filemgr.NewManager() + srv.filemgr.ClearCache() + return srv, nil } diff --git a/internal/controlplane/xds.go b/internal/controlplane/xds.go index ea274fa9b..08056601f 100644 --- a/internal/controlplane/xds.go +++ b/internal/controlplane/xds.go @@ -7,14 +7,11 @@ import ( "encoding/hex" "encoding/pem" "fmt" - "io/ioutil" "net" "os" - "path/filepath" "strconv" "sync" - xxhash "github.com/cespare/xxhash/v2" envoy_config_accesslog_v3 "github.com/envoyproxy/go-control-plane/envoy/config/accesslog/v3" envoy_config_core_v3 "github.com/envoyproxy/go-control-plane/envoy/config/core/v3" envoy_extensions_access_loggers_grpc_v3 "github.com/envoyproxy/go-control-plane/envoy/extensions/access_loggers/grpc/v3" @@ -45,7 +42,7 @@ func (srv *Server) buildDiscoveryResources() map[string][]*envoy_service_discove Resource: any, }) } - for _, listener := range buildListeners(&cfg.Options) { + for _, listener := range srv.buildListeners(&cfg.Options) { any, _ := anypb.New(listener) resources[listenerTypeURL] = append(resources[listenerTypeURL], &envoy_service_discovery_v3.Resource{ Name: listener.Name, @@ -116,52 +113,7 @@ func buildAddress(hostport string, defaultPort int) *envoy_config_core_v3.Addres } } -func inlineBytes(bs []byte) *envoy_config_core_v3.DataSource { - return &envoy_config_core_v3.DataSource{ - Specifier: &envoy_config_core_v3.DataSource_InlineBytes{ - InlineBytes: bs, - }, - } -} - -func inlineBytesAsFilename(name string, bs []byte) *envoy_config_core_v3.DataSource { - ext := filepath.Ext(name) - name = fmt.Sprintf("%s-%x%s", name[:len(name)-len(ext)], xxhash.Sum64(bs), ext) - - cacheDir, err := os.UserCacheDir() - if err != nil { - cacheDir = filepath.Join(os.TempDir()) - } - cacheDir = filepath.Join(cacheDir, "pomerium", "envoy", "files") - if err = os.MkdirAll(cacheDir, 0o755); err != nil { - log.Error().Err(err).Msg("error creating cache directory, falling back to inline bytes") - return inlineBytes(bs) - } - - fp := filepath.Join(cacheDir, name) - if _, err = os.Stat(fp); os.IsNotExist(err) { - err = ioutil.WriteFile(fp, bs, 0o600) - if err != nil { - log.Error().Err(err).Msg("error writing cache file, falling back to inline bytes") - return inlineBytes(bs) - } - } else if err != nil { - log.Error().Err(err).Msg("error reading cache file, falling back to inline bytes") - return inlineBytes(bs) - } - - return inlineFilename(fp) -} - -func inlineFilename(name string) *envoy_config_core_v3.DataSource { - return &envoy_config_core_v3.DataSource{ - Specifier: &envoy_config_core_v3.DataSource_Filename{ - Filename: name, - }, - } -} - -func envoyTLSCertificateFromGoTLSCertificate(cert *tls.Certificate) *envoy_extensions_transport_sockets_tls_v3.TlsCertificate { +func (srv *Server) envoyTLSCertificateFromGoTLSCertificate(cert *tls.Certificate) *envoy_extensions_transport_sockets_tls_v3.TlsCertificate { envoyCert := &envoy_extensions_transport_sockets_tls_v3.TlsCertificate{} var chain bytes.Buffer for _, cbs := range cert.Certificate { @@ -170,12 +122,12 @@ func envoyTLSCertificateFromGoTLSCertificate(cert *tls.Certificate) *envoy_exten Bytes: cbs, }) } - envoyCert.CertificateChain = inlineBytesAsFilename("tls-crt.pem", chain.Bytes()) + envoyCert.CertificateChain = srv.filemgr.BytesDataSource("tls-crt.pem", chain.Bytes()) if cert.OCSPStaple != nil { - envoyCert.OcspStaple = inlineBytes(cert.OCSPStaple) + envoyCert.OcspStaple = srv.filemgr.BytesDataSource("ocsp-staple", cert.OCSPStaple) } if bs, err := x509.MarshalPKCS8PrivateKey(cert.PrivateKey); err == nil { - envoyCert.PrivateKey = inlineBytesAsFilename("tls-key.pem", pem.EncodeToMemory( + envoyCert.PrivateKey = srv.filemgr.BytesDataSource("tls-key.pem", pem.EncodeToMemory( &pem.Block{ Type: "PRIVATE KEY", Bytes: bs, @@ -186,7 +138,7 @@ func envoyTLSCertificateFromGoTLSCertificate(cert *tls.Certificate) *envoy_exten } for _, scts := range cert.SignedCertificateTimestamps { envoyCert.SignedCertificateTimestamp = append(envoyCert.SignedCertificateTimestamp, - inlineBytes(scts)) + srv.filemgr.BytesDataSource("signed-certificate-timestamp", scts)) } return envoyCert } diff --git a/internal/controlplane/xds_cluster_test.go b/internal/controlplane/xds_cluster_test.go index 1553b59aa..620bfc86b 100644 --- a/internal/controlplane/xds_cluster_test.go +++ b/internal/controlplane/xds_cluster_test.go @@ -14,10 +14,15 @@ import ( ) func Test_buildPolicyTransportSocket(t *testing.T) { - rootCA, _ := getRootCertificateAuthority() cacheDir, _ := os.UserCacheDir() + customCA := filepath.Join(cacheDir, "pomerium", "envoy", "files", "custom-ca-32484c314b584447463735303142374c31414145374650305a525539554938594d524855353757313942494d473847535231.pem") + + srv, _ := NewServer("TEST") + rootCAPath, _ := getRootCertificateAuthority() + rootCA := srv.filemgr.FileDataSource(rootCAPath).GetFilename() + t.Run("insecure", func(t *testing.T) { - assert.Nil(t, buildPolicyTransportSocket(&config.Policy{ + assert.Nil(t, srv.buildPolicyTransportSocket(&config.Policy{ Destination: mustParseURL("http://example.com"), })) }) @@ -49,7 +54,7 @@ func Test_buildPolicyTransportSocket(t *testing.T) { "sni": "example.com" } } - `, buildPolicyTransportSocket(&config.Policy{ + `, srv.buildPolicyTransportSocket(&config.Policy{ Destination: mustParseURL("https://example.com"), })) }) @@ -81,7 +86,7 @@ func Test_buildPolicyTransportSocket(t *testing.T) { "sni": "use-this-name.example.com" } } - `, buildPolicyTransportSocket(&config.Policy{ + `, srv.buildPolicyTransportSocket(&config.Policy{ Destination: mustParseURL("https://example.com"), TLSServerName: "use-this-name.example.com", })) @@ -115,7 +120,7 @@ func Test_buildPolicyTransportSocket(t *testing.T) { "sni": "example.com" } } - `, buildPolicyTransportSocket(&config.Policy{ + `, srv.buildPolicyTransportSocket(&config.Policy{ Destination: mustParseURL("https://example.com"), TLSSkipVerify: true, })) @@ -141,14 +146,14 @@ func Test_buildPolicyTransportSocket(t *testing.T) { "exact": "example.com" }], "trustedCa": { - "filename": "`+filepath.Join(cacheDir, "pomerium", "envoy", "files", "custom-ca-3aefa6fd5cf2deb4.pem")+`" + "filename": "`+customCA+`" } } }, "sni": "example.com" } } - `, buildPolicyTransportSocket(&config.Policy{ + `, srv.buildPolicyTransportSocket(&config.Policy{ Destination: mustParseURL("https://example.com"), TLSCustomCA: base64.StdEncoding.EncodeToString([]byte{0, 0, 0, 0}), })) @@ -172,10 +177,10 @@ func Test_buildPolicyTransportSocket(t *testing.T) { }, "tlsCertificates": [{ "certificateChain":{ - "filename": "`+filepath.Join(cacheDir, "pomerium", "envoy", "files", "tls-crt-921a8294d2e2ec54.pem")+`" + "filename": "`+filepath.Join(cacheDir, "pomerium", "envoy", "files", "tls-crt-354e49305a5a39414a545530374e58454e48334148524c4e324258463837364355564c4e4532464b54355139495547514a38.pem")+`" }, "privateKey": { - "filename": "`+filepath.Join(cacheDir, "pomerium", "envoy", "files", "tls-key-d5cf35b1e8533e4a.pem")+`" + "filename": "`+filepath.Join(cacheDir, "pomerium", "envoy", "files", "tls-key-3350415a38414e4e4a4655424e55393430474147324651433949384e485341334b5157364f424b4c5856365a545937383735.pem")+`" } }], "validationContext": { @@ -190,7 +195,7 @@ func Test_buildPolicyTransportSocket(t *testing.T) { "sni": "example.com" } } - `, buildPolicyTransportSocket(&config.Policy{ + `, srv.buildPolicyTransportSocket(&config.Policy{ Destination: mustParseURL("https://example.com"), ClientCertificate: clientCert, })) @@ -198,7 +203,9 @@ func Test_buildPolicyTransportSocket(t *testing.T) { } func Test_buildCluster(t *testing.T) { - rootCA, _ := getRootCertificateAuthority() + srv, _ := NewServer("TEST") + rootCAPath, _ := getRootCertificateAuthority() + rootCA := srv.filemgr.FileDataSource(rootCAPath).GetFilename() t.Run("insecure", func(t *testing.T) { cluster := buildCluster("example", mustParseURL("http://example.com"), nil, true, config.GetEnvoyDNSLookupFamily(config.DNSLookupFamilyV4Only)) testutil.AssertProtoJSONEqual(t, ` @@ -232,7 +239,7 @@ func Test_buildCluster(t *testing.T) { }) t.Run("secure", func(t *testing.T) { u := mustParseURL("https://example.com") - transportSocket := buildPolicyTransportSocket(&config.Policy{ + transportSocket := srv.buildPolicyTransportSocket(&config.Policy{ Destination: u, }) cluster := buildCluster("example", u, transportSocket, true, config.GetEnvoyDNSLookupFamily(config.DNSLookupFamilyAuto)) diff --git a/internal/controlplane/xds_clusters.go b/internal/controlplane/xds_clusters.go index 303b97210..84142233b 100644 --- a/internal/controlplane/xds_clusters.go +++ b/internal/controlplane/xds_clusters.go @@ -34,17 +34,17 @@ func (srv *Server) buildClusters(options *config.Options) []*envoy_config_cluste } clusters := []*envoy_config_cluster_v3.Cluster{ - buildInternalCluster(options, "pomerium-control-plane-grpc", grpcURL, true), - buildInternalCluster(options, "pomerium-control-plane-http", httpURL, false), + srv.buildInternalCluster(options, "pomerium-control-plane-grpc", grpcURL, true), + srv.buildInternalCluster(options, "pomerium-control-plane-http", httpURL, false), } - clusters = append(clusters, buildInternalCluster(options, authzURL.Host, authzURL, true)) + clusters = append(clusters, srv.buildInternalCluster(options, authzURL.Host, authzURL, true)) if config.IsProxy(options.Services) { for i := range options.Policies { policy := options.Policies[i] if policy.Destination != nil { - clusters = append(clusters, buildPolicyCluster(options, &policy)) + clusters = append(clusters, srv.buildPolicyCluster(options, &policy)) } } } @@ -52,21 +52,21 @@ func (srv *Server) buildClusters(options *config.Options) []*envoy_config_cluste return clusters } -func buildInternalCluster(options *config.Options, name string, endpoint *url.URL, forceHTTP2 bool) *envoy_config_cluster_v3.Cluster { +func (srv *Server) buildInternalCluster(options *config.Options, name string, endpoint *url.URL, forceHTTP2 bool) *envoy_config_cluster_v3.Cluster { dnsLookupFamily := config.GetEnvoyDNSLookupFamily(options.DNSLookupFamily) - return buildCluster(name, endpoint, buildInternalTransportSocket(options, endpoint), forceHTTP2, dnsLookupFamily) + return buildCluster(name, endpoint, srv.buildInternalTransportSocket(options, endpoint), forceHTTP2, dnsLookupFamily) } -func buildPolicyCluster(options *config.Options, policy *config.Policy) *envoy_config_cluster_v3.Cluster { +func (srv *Server) buildPolicyCluster(options *config.Options, policy *config.Policy) *envoy_config_cluster_v3.Cluster { name := getPolicyName(policy) dnsLookupFamily := config.GetEnvoyDNSLookupFamily(options.DNSLookupFamily) if policy.EnableGoogleCloudServerlessAuthentication { dnsLookupFamily = envoy_config_cluster_v3.Cluster_V4_ONLY } - return buildCluster(name, policy.Destination, buildPolicyTransportSocket(policy), false, dnsLookupFamily) + return buildCluster(name, policy.Destination, srv.buildPolicyTransportSocket(policy), false, dnsLookupFamily) } -func buildInternalTransportSocket(options *config.Options, endpoint *url.URL) *envoy_config_core_v3.TransportSocket { +func (srv *Server) buildInternalTransportSocket(options *config.Options, endpoint *url.URL) *envoy_config_core_v3.TransportSocket { if endpoint.Scheme != "https" { return nil } @@ -82,19 +82,19 @@ func buildInternalTransportSocket(options *config.Options, endpoint *url.URL) *e }}, } if options.CAFile != "" { - validationContext.TrustedCa = inlineFilename(options.CAFile) + validationContext.TrustedCa = srv.filemgr.FileDataSource(options.CAFile) } else if options.CA != "" { bs, err := base64.StdEncoding.DecodeString(options.CA) if err != nil { log.Error().Err(err).Msg("invalid custom CA certificate") } - validationContext.TrustedCa = inlineBytesAsFilename("custom-ca.pem", bs) + validationContext.TrustedCa = srv.filemgr.BytesDataSource("custom-ca.pem", bs) } else { rootCA, err := getRootCertificateAuthority() if err != nil { log.Error().Err(err).Msg("unable to enable certificate verification because no root CAs were found") } else { - validationContext.TrustedCa = inlineFilename(rootCA) + validationContext.TrustedCa = srv.filemgr.FileDataSource(rootCA) } } tlsContext := &envoy_extensions_transport_sockets_tls_v3.UpstreamTlsContext{ @@ -115,7 +115,7 @@ func buildInternalTransportSocket(options *config.Options, endpoint *url.URL) *e } } -func buildPolicyTransportSocket(policy *config.Policy) *envoy_config_core_v3.TransportSocket { +func (srv *Server) buildPolicyTransportSocket(policy *config.Policy) *envoy_config_core_v3.TransportSocket { if policy.Destination == nil || policy.Destination.Scheme != "https" { return nil } @@ -136,14 +136,14 @@ func buildPolicyTransportSocket(policy *config.Policy) *envoy_config_core_v3.Tra }, AlpnProtocols: []string{"http/1.1"}, ValidationContextType: &envoy_extensions_transport_sockets_tls_v3.CommonTlsContext_ValidationContext{ - ValidationContext: buildPolicyValidationContext(policy), + ValidationContext: srv.buildPolicyValidationContext(policy), }, }, Sni: sni, } if policy.ClientCertificate != nil { tlsContext.CommonTlsContext.TlsCertificates = append(tlsContext.CommonTlsContext.TlsCertificates, - envoyTLSCertificateFromGoTLSCertificate(policy.ClientCertificate)) + srv.envoyTLSCertificateFromGoTLSCertificate(policy.ClientCertificate)) } tlsConfig := marshalAny(tlsContext) @@ -155,7 +155,7 @@ func buildPolicyTransportSocket(policy *config.Policy) *envoy_config_core_v3.Tra } } -func buildPolicyValidationContext(policy *config.Policy) *envoy_extensions_transport_sockets_tls_v3.CertificateValidationContext { +func (srv *Server) buildPolicyValidationContext(policy *config.Policy) *envoy_extensions_transport_sockets_tls_v3.CertificateValidationContext { if policy.Destination == nil { return nil } @@ -172,19 +172,19 @@ func buildPolicyValidationContext(policy *config.Policy) *envoy_extensions_trans }}, } if policy.TLSCustomCAFile != "" { - validationContext.TrustedCa = inlineFilename(policy.TLSCustomCAFile) + validationContext.TrustedCa = srv.filemgr.FileDataSource(policy.TLSCustomCAFile) } else if policy.TLSCustomCA != "" { bs, err := base64.StdEncoding.DecodeString(policy.TLSCustomCA) if err != nil { log.Error().Err(err).Msg("invalid custom CA certificate") } - validationContext.TrustedCa = inlineBytesAsFilename("custom-ca.pem", bs) + validationContext.TrustedCa = srv.filemgr.BytesDataSource("custom-ca.pem", bs) } else { rootCA, err := getRootCertificateAuthority() if err != nil { log.Error().Err(err).Msg("unable to enable certificate verification because no root CAs were found") } else { - validationContext.TrustedCa = inlineFilename(rootCA) + validationContext.TrustedCa = srv.filemgr.FileDataSource(rootCA) } } diff --git a/internal/controlplane/xds_listeners.go b/internal/controlplane/xds_listeners.go index 0e67a106e..de50622f2 100644 --- a/internal/controlplane/xds_listeners.go +++ b/internal/controlplane/xds_listeners.go @@ -37,21 +37,21 @@ func init() { }) } -func buildListeners(options *config.Options) []*envoy_config_listener_v3.Listener { +func (srv *Server) buildListeners(options *config.Options) []*envoy_config_listener_v3.Listener { var listeners []*envoy_config_listener_v3.Listener if config.IsAuthenticate(options.Services) || config.IsProxy(options.Services) { - listeners = append(listeners, buildMainListener(options)) + listeners = append(listeners, srv.buildMainListener(options)) } if config.IsAuthorize(options.Services) || config.IsCache(options.Services) { - listeners = append(listeners, buildGRPCListener(options)) + listeners = append(listeners, srv.buildGRPCListener(options)) } return listeners } -func buildMainListener(options *config.Options) *envoy_config_listener_v3.Listener { +func (srv *Server) buildMainListener(options *config.Options) *envoy_config_listener_v3.Listener { if options.InsecureServer { filter := buildMainHTTPConnectionManagerFilter(options, getAllRouteableDomains(options, options.Addr)) @@ -88,7 +88,7 @@ func buildMainListener(options *config.Options) *envoy_config_listener_v3.Listen ServerNames: []string{tlsDomain}, } } - tlsContext := buildDownstreamTLSContext(options, tlsDomain) + tlsContext := srv.buildDownstreamTLSContext(options, tlsDomain) if tlsContext != nil { tlsConfig := marshalAny(tlsContext) filterChain.TransportSocket = &envoy_config_core_v3.TransportSocket{ @@ -250,7 +250,7 @@ func buildMainHTTPConnectionManagerFilter(options *config.Options, domains []str } } -func buildGRPCListener(options *config.Options) *envoy_config_listener_v3.Listener { +func (srv *Server) buildGRPCListener(options *config.Options) *envoy_config_listener_v3.Listener { filter := buildGRPCHTTPConnectionManagerFilter() if options.GRPCInsecure { @@ -285,7 +285,7 @@ func buildGRPCListener(options *config.Options) *envoy_config_listener_v3.Listen ServerNames: []string{tlsDomain}, } } - tlsContext := buildDownstreamTLSContext(options, tlsDomain) + tlsContext := srv.buildDownstreamTLSContext(options, tlsDomain) if tlsContext != nil { tlsConfig := marshalAny(tlsContext) filterChain.TransportSocket = &envoy_config_core_v3.TransportSocket{ @@ -357,7 +357,7 @@ func buildRouteConfiguration(name string, virtualHosts []*envoy_config_route_v3. } } -func buildDownstreamTLSContext(options *config.Options, domain string) *envoy_extensions_transport_sockets_tls_v3.DownstreamTlsContext { +func (srv *Server) buildDownstreamTLSContext(options *config.Options, domain string) *envoy_extensions_transport_sockets_tls_v3.DownstreamTlsContext { cert, err := cryptutil.GetCertificateForDomain(options.Certificates, domain) if err != nil { log.Warn().Str("domain", domain).Err(err).Msg("failed to get certificate for domain") @@ -370,9 +370,9 @@ func buildDownstreamTLSContext(options *config.Options, domain string) *envoy_ex if err != nil { log.Warn().Msg("client_ca does not appear to be a base64 encoded string") } - trustedCA = inlineBytesAsFilename("client-ca", bs) + trustedCA = srv.filemgr.BytesDataSource("client-ca", bs) } else if options.ClientCAFile != "" { - trustedCA = inlineFilename(options.ClientCAFile) + trustedCA = srv.filemgr.FileDataSource(options.ClientCAFile) } var validationContext *envoy_extensions_transport_sockets_tls_v3.CommonTlsContext_ValidationContext @@ -385,7 +385,7 @@ func buildDownstreamTLSContext(options *config.Options, domain string) *envoy_ex } } - envoyCert := envoyTLSCertificateFromGoTLSCertificate(cert) + envoyCert := srv.envoyTLSCertificateFromGoTLSCertificate(cert) return &envoy_extensions_transport_sockets_tls_v3.DownstreamTlsContext{ CommonTlsContext: &envoy_extensions_transport_sockets_tls_v3.CommonTlsContext{ TlsParams: &envoy_extensions_transport_sockets_tls_v3.TlsParameters{ diff --git a/internal/controlplane/xds_listeners_test.go b/internal/controlplane/xds_listeners_test.go index d070de34e..14f06abbf 100644 --- a/internal/controlplane/xds_listeners_test.go +++ b/internal/controlplane/xds_listeners_test.go @@ -385,13 +385,15 @@ func Test_buildDownstreamTLSContext(t *testing.T) { return } - downstreamTLSContext := buildDownstreamTLSContext(&config.Options{ + srv, _ := NewServer("TEST") + + downstreamTLSContext := srv.buildDownstreamTLSContext(&config.Options{ Certificates: []tls.Certificate{*certA}, }, "a.example.com") cacheDir, _ := os.UserCacheDir() - certFileName := filepath.Join(cacheDir, "pomerium", "envoy", "files", "tls-crt-921a8294d2e2ec54.pem") - keyFileName := filepath.Join(cacheDir, "pomerium", "envoy", "files", "tls-key-d5cf35b1e8533e4a.pem") + certFileName := filepath.Join(cacheDir, "pomerium", "envoy", "files", "tls-crt-354e49305a5a39414a545530374e58454e48334148524c4e324258463837364355564c4e4532464b54355139495547514a38.pem") + keyFileName := filepath.Join(cacheDir, "pomerium", "envoy", "files", "tls-key-3350415a38414e4e4a4655424e55393430474147324651433949384e485341334b5157364f424b4c5856365a545937383735.pem") testutil.AssertProtoJSONEqual(t, `{ "commonTlsContext": { diff --git a/internal/fileutil/watcher.go b/internal/fileutil/watcher.go new file mode 100644 index 000000000..9438c28ab --- /dev/null +++ b/internal/fileutil/watcher.go @@ -0,0 +1,66 @@ +package fileutil + +import ( + "sync" + + "github.com/rjeczalik/notify" + + "github.com/pomerium/pomerium/internal/log" + "github.com/pomerium/pomerium/internal/signal" +) + +// A Watcher watches files for changes. +type Watcher struct { + *signal.Signal + mu sync.Mutex + filePaths map[string]chan notify.EventInfo +} + +// NewWatcher creates a new Watcher. +func NewWatcher() *Watcher { + return &Watcher{ + Signal: signal.New(), + filePaths: map[string]chan notify.EventInfo{}, + } +} + +// Add adds a new watch. +func (watcher *Watcher) Add(filePath string) { + watcher.mu.Lock() + defer watcher.mu.Unlock() + + // already watching + if _, ok := watcher.filePaths[filePath]; ok { + return + } + + ch := make(chan notify.EventInfo, 1) + go func() { + for evt := range ch { + log.Info().Str("path", evt.Path()).Str("event", evt.Event().String()).Msg("filemgr: detected file change") + watcher.Signal.Broadcast() + } + }() + err := notify.Watch(filePath, ch, notify.All) + if err != nil { + log.Error().Err(err).Str("path", filePath).Msg("filemgr: error watching file path") + notify.Stop(ch) + close(ch) + return + } + log.Debug().Str("path", filePath).Msg("filemgr: watching file for changes") + + watcher.filePaths[filePath] = ch +} + +// Clear removes all watches. +func (watcher *Watcher) Clear() { + watcher.mu.Lock() + defer watcher.mu.Unlock() + + for filePath, ch := range watcher.filePaths { + notify.Stop(ch) + close(ch) + delete(watcher.filePaths, filePath) + } +} diff --git a/internal/fileutil/watcher_test.go b/internal/fileutil/watcher_test.go new file mode 100644 index 000000000..cff8cc41c --- /dev/null +++ b/internal/fileutil/watcher_test.go @@ -0,0 +1,42 @@ +package fileutil + +import ( + "io/ioutil" + "os" + "path/filepath" + "testing" + "time" + + "github.com/google/uuid" + "github.com/stretchr/testify/assert" +) + +func TestWatcher(t *testing.T) { + tmpdir := filepath.Join(os.TempDir(), uuid.New().String()) + err := os.MkdirAll(tmpdir, 0o755) + if !assert.NoError(t, err) { + return + } + + err = ioutil.WriteFile(filepath.Join(tmpdir, "test1.txt"), []byte{1, 2, 3, 4}, 0o666) + if !assert.NoError(t, err) { + return + } + + w := NewWatcher() + w.Add(filepath.Join(tmpdir, "test1.txt")) + + ch := w.Bind() + defer w.Unbind(ch) + + err = ioutil.WriteFile(filepath.Join(tmpdir, "test1.txt"), []byte{5, 6, 7, 8}, 0o666) + if !assert.NoError(t, err) { + return + } + + select { + case <-ch: + case <-time.After(time.Second): + t.Error("expected change signal when file is modified") + } +}