mirror of
https://github.com/pomerium/pomerium.git
synced 2025-04-29 18:36:30 +02:00
config: detect underlying file changes (#1775)
* wip * cleanup * add test * use uuid for temp dir, derive root CA path from filemgr for tests * fix comment * fix double close * use latest notify
This commit is contained in:
parent
c99994bed8
commit
10912add67
16 changed files with 500 additions and 99 deletions
|
@ -1,9 +1,14 @@
|
|||
package config
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"io/ioutil"
|
||||
"sync"
|
||||
|
||||
"github.com/fsnotify/fsnotify"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/fileutil"
|
||||
)
|
||||
|
||||
// Config holds pomerium configuration options.
|
||||
|
@ -125,3 +130,83 @@ func (src *FileOrEnvironmentSource) GetConfig() *Config {
|
|||
|
||||
return src.config
|
||||
}
|
||||
|
||||
// FileWatcherSource is a config source which triggers a change any time a file in the options changes.
|
||||
type FileWatcherSource struct {
|
||||
underlying Source
|
||||
watcher *fileutil.Watcher
|
||||
|
||||
mu sync.RWMutex
|
||||
computedConfig *Config
|
||||
version string
|
||||
|
||||
ChangeDispatcher
|
||||
}
|
||||
|
||||
// NewFileWatcherSource creates a new FileWatcherSource.
|
||||
func NewFileWatcherSource(underlying Source) *FileWatcherSource {
|
||||
src := &FileWatcherSource{
|
||||
underlying: underlying,
|
||||
watcher: fileutil.NewWatcher(),
|
||||
}
|
||||
|
||||
ch := src.watcher.Bind()
|
||||
go func() {
|
||||
for range ch {
|
||||
src.check(underlying.GetConfig())
|
||||
}
|
||||
}()
|
||||
underlying.OnConfigChange(func(cfg *Config) {
|
||||
src.check(cfg)
|
||||
})
|
||||
src.check(underlying.GetConfig())
|
||||
|
||||
return src
|
||||
}
|
||||
|
||||
// GetConfig gets the underlying config.
|
||||
func (src *FileWatcherSource) GetConfig() *Config {
|
||||
src.mu.RLock()
|
||||
defer src.mu.RUnlock()
|
||||
return src.computedConfig
|
||||
}
|
||||
|
||||
func (src *FileWatcherSource) check(cfg *Config) {
|
||||
src.mu.Lock()
|
||||
defer src.mu.Unlock()
|
||||
|
||||
src.watcher.Clear()
|
||||
|
||||
h := sha256.New()
|
||||
fs := []string{
|
||||
cfg.Options.CAFile,
|
||||
cfg.Options.CertFile,
|
||||
cfg.Options.ClientCAFile,
|
||||
cfg.Options.DataBrokerStorageCAFile,
|
||||
cfg.Options.DataBrokerStorageCertFile,
|
||||
cfg.Options.DataBrokerStorageCertKeyFile,
|
||||
cfg.Options.KeyFile,
|
||||
cfg.Options.PolicyFile,
|
||||
}
|
||||
for _, f := range fs {
|
||||
_, _ = h.Write([]byte{0})
|
||||
bs, err := ioutil.ReadFile(f)
|
||||
if err == nil {
|
||||
src.watcher.Add(f)
|
||||
_, _ = h.Write(bs)
|
||||
}
|
||||
}
|
||||
|
||||
version := hex.EncodeToString(h.Sum(nil))
|
||||
if src.version != version {
|
||||
src.version = version
|
||||
|
||||
// update the computed config
|
||||
src.computedConfig = cfg.Clone()
|
||||
src.computedConfig.Options.Certificates = nil
|
||||
_ = src.computedConfig.Options.Validate()
|
||||
|
||||
// trigger a change
|
||||
src.Trigger(src.computedConfig)
|
||||
}
|
||||
}
|
||||
|
|
50
config/config_source_test.go
Normal file
50
config/config_source_test.go
Normal file
|
@ -0,0 +1,50 @@
|
|||
package config
|
||||
|
||||
import (
|
||||
"io/ioutil"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestFileWatcherSource(t *testing.T) {
|
||||
tmpdir := filepath.Join(os.TempDir(), uuid.New().String())
|
||||
err := os.MkdirAll(tmpdir, 0o755)
|
||||
if !assert.NoError(t, err) {
|
||||
return
|
||||
}
|
||||
|
||||
err = ioutil.WriteFile(filepath.Join(tmpdir, "example.txt"), []byte{1, 2, 3, 4}, 0o600)
|
||||
if !assert.NoError(t, err) {
|
||||
return
|
||||
}
|
||||
|
||||
src := NewFileWatcherSource(NewStaticSource(&Config{
|
||||
Options: &Options{
|
||||
CAFile: filepath.Join(tmpdir, "example.txt"),
|
||||
},
|
||||
}))
|
||||
var closeOnce sync.Once
|
||||
ch := make(chan struct{})
|
||||
src.OnConfigChange(func(cfg *Config) {
|
||||
closeOnce.Do(func() {
|
||||
close(ch)
|
||||
})
|
||||
})
|
||||
|
||||
err = ioutil.WriteFile(filepath.Join(tmpdir, "example.txt"), []byte{5, 6, 7, 8}, 0o600)
|
||||
if !assert.NoError(t, err) {
|
||||
return
|
||||
}
|
||||
|
||||
select {
|
||||
case <-ch:
|
||||
case <-time.After(time.Second):
|
||||
t.Error("expected OnConfigChange to be fired after modifying a file")
|
||||
}
|
||||
}
|
1
go.mod
1
go.mod
|
@ -46,6 +46,7 @@ require (
|
|||
github.com/pquerna/cachecontrol v0.0.0-20180517163645-1555304b9b35 // indirect
|
||||
github.com/prometheus/client_golang v1.9.0
|
||||
github.com/rakyll/statik v0.1.7
|
||||
github.com/rjeczalik/notify v0.9.3-0.20201210012515-e2a77dcc14cf
|
||||
github.com/rs/cors v1.7.0
|
||||
github.com/rs/zerolog v1.20.0
|
||||
github.com/skratchdot/open-golang v0.0.0-20200116055534-eef842397966
|
||||
|
|
3
go.sum
3
go.sum
|
@ -544,6 +544,8 @@ github.com/rakyll/statik v0.1.7/go.mod h1:AlZONWzMtEnMs7W4e/1LURLiI49pIMmp6V9Ung
|
|||
github.com/rcrowley/go-metrics v0.0.0-20181016184325-3113b8401b8a/go.mod h1:bCqnVzQkZxMG4s8nGwiZ5l3QUCyqpo9Y+/ZMZ9VjZe4=
|
||||
github.com/rcrowley/go-metrics v0.0.0-20200313005456-10cdbea86bc0 h1:MkV+77GLUNo5oJ0jf870itWm3D0Sjh7+Za9gazKc5LQ=
|
||||
github.com/rcrowley/go-metrics v0.0.0-20200313005456-10cdbea86bc0/go.mod h1:bCqnVzQkZxMG4s8nGwiZ5l3QUCyqpo9Y+/ZMZ9VjZe4=
|
||||
github.com/rjeczalik/notify v0.9.3-0.20201210012515-e2a77dcc14cf h1:MY2fqXPSLfjld10N04fNcSFdR9K/Y57iXxZRFAivHzI=
|
||||
github.com/rjeczalik/notify v0.9.3-0.20201210012515-e2a77dcc14cf/go.mod h1:aErll2f0sUX9PXZnVNyeiObbmTlk5jnMoCa4QEjJeqM=
|
||||
github.com/rogpeppe/fastuuid v0.0.0-20150106093220-6724a57986af/go.mod h1:XWv6SoW27p1b0cqNHllgS5HIMJraePCO15w5zCzIWYg=
|
||||
github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4=
|
||||
github.com/rs/cors v1.7.0 h1:+88SsELBHx5r+hZ8TCkggzSstaWNbDvThkVK8H6f9ik=
|
||||
|
@ -767,6 +769,7 @@ golang.org/x/sys v0.0.0-20180823144017-11551d06cbcc/go.mod h1:STP8DvDyc/dI5b8T5h
|
|||
golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||
golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||
golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||
golang.org/x/sys v0.0.0-20180926160741-c2ed4eda69e7/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||
golang.org/x/sys v0.0.0-20181026203630-95b1ffbd15a5/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||
golang.org/x/sys v0.0.0-20181107165924-66b7b1311ac8/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||
golang.org/x/sys v0.0.0-20181116152217-5ac8a444bdc5/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||
|
|
|
@ -48,6 +48,9 @@ func Run(ctx context.Context, configFile string) error {
|
|||
return err
|
||||
}
|
||||
|
||||
// trigger changes when underlying files are changed
|
||||
src = config.NewFileWatcherSource(src)
|
||||
|
||||
// override the default http transport so we can use the custom CA in the TLS client config (#1570)
|
||||
http.DefaultTransport = config.NewHTTPTransport(src)
|
||||
|
||||
|
|
35
internal/controlplane/filemgr/config.go
Normal file
35
internal/controlplane/filemgr/config.go
Normal file
|
@ -0,0 +1,35 @@
|
|||
package filemgr
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
type config struct {
|
||||
cacheDir string
|
||||
}
|
||||
|
||||
// An Option updates the config.
|
||||
type Option = func(*config)
|
||||
|
||||
// WithCacheDir returns an Option that sets the cache dir.
|
||||
func WithCacheDir(cacheDir string) Option {
|
||||
return func(cfg *config) {
|
||||
cfg.cacheDir = cacheDir
|
||||
}
|
||||
}
|
||||
|
||||
func newConfig(options ...Option) *config {
|
||||
cfg := new(config)
|
||||
cacheDir, err := os.UserCacheDir()
|
||||
if err != nil {
|
||||
cacheDir = filepath.Join(os.TempDir(), uuid.New().String())
|
||||
}
|
||||
WithCacheDir(filepath.Join(cacheDir, "pomerium", "envoy", "files"))(cfg)
|
||||
for _, o := range options {
|
||||
o(cfg)
|
||||
}
|
||||
return cfg
|
||||
}
|
91
internal/controlplane/filemgr/filemgr.go
Normal file
91
internal/controlplane/filemgr/filemgr.go
Normal file
|
@ -0,0 +1,91 @@
|
|||
package filemgr
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
envoy_config_core_v3 "github.com/envoyproxy/go-control-plane/envoy/config/core/v3"
|
||||
"github.com/martinlindhe/base36"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/log"
|
||||
"github.com/pomerium/pomerium/pkg/cryptutil"
|
||||
)
|
||||
|
||||
// A Manager manages files for envoy.
|
||||
type Manager struct {
|
||||
cfg *config
|
||||
}
|
||||
|
||||
// NewManager creates a new Manager.
|
||||
func NewManager(options ...Option) *Manager {
|
||||
cfg := newConfig(options...)
|
||||
return &Manager{
|
||||
cfg: cfg,
|
||||
}
|
||||
}
|
||||
|
||||
// BytesDataSource returns an envoy config data source based on bytes.
|
||||
func (mgr *Manager) BytesDataSource(fileName string, data []byte) *envoy_config_core_v3.DataSource {
|
||||
h := base36.EncodeBytes(cryptutil.Hash("filemgr", data))
|
||||
ext := filepath.Ext(fileName)
|
||||
fileName = fmt.Sprintf("%s-%x%s", fileName[:len(fileName)-len(ext)], h, ext)
|
||||
|
||||
if err := os.MkdirAll(mgr.cfg.cacheDir, 0o700); err != nil {
|
||||
log.Error().Err(err).Msg("filemgr: error creating cache directory, falling back to inline bytes")
|
||||
return inlineBytes(data)
|
||||
}
|
||||
|
||||
filePath := filepath.Join(mgr.cfg.cacheDir, fileName)
|
||||
if _, err := os.Stat(filePath); os.IsNotExist(err) {
|
||||
err = ioutil.WriteFile(filePath, data, 0o600)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("filemgr: error writing cache file, falling back to inline bytes")
|
||||
return inlineBytes(data)
|
||||
}
|
||||
} else if err != nil {
|
||||
log.Error().Err(err).Msg("filemgr: error reading cache file, falling back to inline bytes")
|
||||
return inlineBytes(data)
|
||||
}
|
||||
|
||||
return inlineFilename(filePath)
|
||||
}
|
||||
|
||||
// ClearCache clears the file cache.
|
||||
func (mgr *Manager) ClearCache() {
|
||||
err := filepath.Walk(mgr.cfg.cacheDir, func(p string, fi os.FileInfo, err error) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return os.Remove(p)
|
||||
})
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("failed to clear envoy file cache")
|
||||
}
|
||||
}
|
||||
|
||||
// FileDataSource returns an envoy config data source based on a file.
|
||||
func (mgr *Manager) FileDataSource(filePath string) *envoy_config_core_v3.DataSource {
|
||||
data, err := ioutil.ReadFile(filePath)
|
||||
if err != nil {
|
||||
return inlineFilename(filePath)
|
||||
}
|
||||
return mgr.BytesDataSource(filepath.Base(filePath), data)
|
||||
}
|
||||
|
||||
func inlineBytes(data []byte) *envoy_config_core_v3.DataSource {
|
||||
return &envoy_config_core_v3.DataSource{
|
||||
Specifier: &envoy_config_core_v3.DataSource_InlineBytes{
|
||||
InlineBytes: data,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func inlineFilename(name string) *envoy_config_core_v3.DataSource {
|
||||
return &envoy_config_core_v3.DataSource{
|
||||
Specifier: &envoy_config_core_v3.DataSource_Filename{
|
||||
Filename: name,
|
||||
},
|
||||
}
|
||||
}
|
59
internal/controlplane/filemgr/filemgr_test.go
Normal file
59
internal/controlplane/filemgr/filemgr_test.go
Normal file
|
@ -0,0 +1,59 @@
|
|||
package filemgr
|
||||
|
||||
import (
|
||||
"io/ioutil"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
envoy_config_core_v3 "github.com/envoyproxy/go-control-plane/envoy/config/core/v3"
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func Test(t *testing.T) {
|
||||
dir := filepath.Join(os.TempDir(), uuid.New().String())
|
||||
err := os.MkdirAll(dir, 0o755)
|
||||
if !assert.NoError(t, err) {
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
_ = os.RemoveAll(dir)
|
||||
}()
|
||||
|
||||
t.Run("bytes", func(t *testing.T) {
|
||||
mgr := NewManager(WithCacheDir(dir))
|
||||
ds := mgr.BytesDataSource("test.txt", []byte{1, 2, 3, 4, 5})
|
||||
assert.Equal(t, &envoy_config_core_v3.DataSource{
|
||||
Specifier: &envoy_config_core_v3.DataSource_Filename{
|
||||
Filename: filepath.Join(dir, "test-353354494b53534a5538435652584d594a5759394d43484f38514b34594b4b524b34515339593249344e4238474a5436414b.txt"),
|
||||
},
|
||||
}, ds)
|
||||
mgr.ClearCache()
|
||||
})
|
||||
|
||||
t.Run("file", func(t *testing.T) {
|
||||
tmpFilePath := filepath.Join(dir, "test.txt")
|
||||
_ = ioutil.WriteFile(tmpFilePath, []byte("TEST1"), 0o777)
|
||||
|
||||
mgr := NewManager(WithCacheDir(dir))
|
||||
|
||||
ds := mgr.FileDataSource(tmpFilePath)
|
||||
assert.Equal(t, &envoy_config_core_v3.DataSource{
|
||||
Specifier: &envoy_config_core_v3.DataSource_Filename{
|
||||
Filename: filepath.Join(dir, "test-34514f59593332445a5649504230484142544c515057383944383730554833564d32574836354654585954304e424f464336.txt"),
|
||||
},
|
||||
}, ds)
|
||||
|
||||
_ = ioutil.WriteFile(tmpFilePath, []byte("TEST2"), 0o777)
|
||||
|
||||
ds = mgr.FileDataSource(tmpFilePath)
|
||||
assert.Equal(t, &envoy_config_core_v3.DataSource{
|
||||
Specifier: &envoy_config_core_v3.DataSource_Filename{
|
||||
Filename: filepath.Join(dir, "test-32564e4457304430393559364b5747373138584f484f5a51334d365758584b47364b555a4c444849513241513457323259.txt"),
|
||||
},
|
||||
}, ds)
|
||||
|
||||
mgr.ClearCache()
|
||||
})
|
||||
}
|
|
@ -15,6 +15,7 @@ import (
|
|||
"google.golang.org/grpc/reflection"
|
||||
|
||||
"github.com/pomerium/pomerium/config"
|
||||
"github.com/pomerium/pomerium/internal/controlplane/filemgr"
|
||||
"github.com/pomerium/pomerium/internal/controlplane/xdsmgr"
|
||||
"github.com/pomerium/pomerium/internal/log"
|
||||
"github.com/pomerium/pomerium/internal/telemetry"
|
||||
|
@ -50,6 +51,7 @@ type Server struct {
|
|||
currentConfig atomicVersionedOptions
|
||||
name string
|
||||
xdsmgr *xdsmgr.Manager
|
||||
filemgr *filemgr.Manager
|
||||
}
|
||||
|
||||
// NewServer creates a new Server. Listener ports are chosen by the OS.
|
||||
|
@ -87,6 +89,9 @@ func NewServer(name string) (*Server, error) {
|
|||
srv.xdsmgr = xdsmgr.NewManager(srv.buildDiscoveryResources())
|
||||
envoy_service_discovery_v3.RegisterAggregatedDiscoveryServiceServer(srv.GRPCServer, srv.xdsmgr)
|
||||
|
||||
srv.filemgr = filemgr.NewManager()
|
||||
srv.filemgr.ClearCache()
|
||||
|
||||
return srv, nil
|
||||
}
|
||||
|
||||
|
|
|
@ -7,14 +7,11 @@ import (
|
|||
"encoding/hex"
|
||||
"encoding/pem"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"sync"
|
||||
|
||||
xxhash "github.com/cespare/xxhash/v2"
|
||||
envoy_config_accesslog_v3 "github.com/envoyproxy/go-control-plane/envoy/config/accesslog/v3"
|
||||
envoy_config_core_v3 "github.com/envoyproxy/go-control-plane/envoy/config/core/v3"
|
||||
envoy_extensions_access_loggers_grpc_v3 "github.com/envoyproxy/go-control-plane/envoy/extensions/access_loggers/grpc/v3"
|
||||
|
@ -45,7 +42,7 @@ func (srv *Server) buildDiscoveryResources() map[string][]*envoy_service_discove
|
|||
Resource: any,
|
||||
})
|
||||
}
|
||||
for _, listener := range buildListeners(&cfg.Options) {
|
||||
for _, listener := range srv.buildListeners(&cfg.Options) {
|
||||
any, _ := anypb.New(listener)
|
||||
resources[listenerTypeURL] = append(resources[listenerTypeURL], &envoy_service_discovery_v3.Resource{
|
||||
Name: listener.Name,
|
||||
|
@ -116,52 +113,7 @@ func buildAddress(hostport string, defaultPort int) *envoy_config_core_v3.Addres
|
|||
}
|
||||
}
|
||||
|
||||
func inlineBytes(bs []byte) *envoy_config_core_v3.DataSource {
|
||||
return &envoy_config_core_v3.DataSource{
|
||||
Specifier: &envoy_config_core_v3.DataSource_InlineBytes{
|
||||
InlineBytes: bs,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func inlineBytesAsFilename(name string, bs []byte) *envoy_config_core_v3.DataSource {
|
||||
ext := filepath.Ext(name)
|
||||
name = fmt.Sprintf("%s-%x%s", name[:len(name)-len(ext)], xxhash.Sum64(bs), ext)
|
||||
|
||||
cacheDir, err := os.UserCacheDir()
|
||||
if err != nil {
|
||||
cacheDir = filepath.Join(os.TempDir())
|
||||
}
|
||||
cacheDir = filepath.Join(cacheDir, "pomerium", "envoy", "files")
|
||||
if err = os.MkdirAll(cacheDir, 0o755); err != nil {
|
||||
log.Error().Err(err).Msg("error creating cache directory, falling back to inline bytes")
|
||||
return inlineBytes(bs)
|
||||
}
|
||||
|
||||
fp := filepath.Join(cacheDir, name)
|
||||
if _, err = os.Stat(fp); os.IsNotExist(err) {
|
||||
err = ioutil.WriteFile(fp, bs, 0o600)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("error writing cache file, falling back to inline bytes")
|
||||
return inlineBytes(bs)
|
||||
}
|
||||
} else if err != nil {
|
||||
log.Error().Err(err).Msg("error reading cache file, falling back to inline bytes")
|
||||
return inlineBytes(bs)
|
||||
}
|
||||
|
||||
return inlineFilename(fp)
|
||||
}
|
||||
|
||||
func inlineFilename(name string) *envoy_config_core_v3.DataSource {
|
||||
return &envoy_config_core_v3.DataSource{
|
||||
Specifier: &envoy_config_core_v3.DataSource_Filename{
|
||||
Filename: name,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func envoyTLSCertificateFromGoTLSCertificate(cert *tls.Certificate) *envoy_extensions_transport_sockets_tls_v3.TlsCertificate {
|
||||
func (srv *Server) envoyTLSCertificateFromGoTLSCertificate(cert *tls.Certificate) *envoy_extensions_transport_sockets_tls_v3.TlsCertificate {
|
||||
envoyCert := &envoy_extensions_transport_sockets_tls_v3.TlsCertificate{}
|
||||
var chain bytes.Buffer
|
||||
for _, cbs := range cert.Certificate {
|
||||
|
@ -170,12 +122,12 @@ func envoyTLSCertificateFromGoTLSCertificate(cert *tls.Certificate) *envoy_exten
|
|||
Bytes: cbs,
|
||||
})
|
||||
}
|
||||
envoyCert.CertificateChain = inlineBytesAsFilename("tls-crt.pem", chain.Bytes())
|
||||
envoyCert.CertificateChain = srv.filemgr.BytesDataSource("tls-crt.pem", chain.Bytes())
|
||||
if cert.OCSPStaple != nil {
|
||||
envoyCert.OcspStaple = inlineBytes(cert.OCSPStaple)
|
||||
envoyCert.OcspStaple = srv.filemgr.BytesDataSource("ocsp-staple", cert.OCSPStaple)
|
||||
}
|
||||
if bs, err := x509.MarshalPKCS8PrivateKey(cert.PrivateKey); err == nil {
|
||||
envoyCert.PrivateKey = inlineBytesAsFilename("tls-key.pem", pem.EncodeToMemory(
|
||||
envoyCert.PrivateKey = srv.filemgr.BytesDataSource("tls-key.pem", pem.EncodeToMemory(
|
||||
&pem.Block{
|
||||
Type: "PRIVATE KEY",
|
||||
Bytes: bs,
|
||||
|
@ -186,7 +138,7 @@ func envoyTLSCertificateFromGoTLSCertificate(cert *tls.Certificate) *envoy_exten
|
|||
}
|
||||
for _, scts := range cert.SignedCertificateTimestamps {
|
||||
envoyCert.SignedCertificateTimestamp = append(envoyCert.SignedCertificateTimestamp,
|
||||
inlineBytes(scts))
|
||||
srv.filemgr.BytesDataSource("signed-certificate-timestamp", scts))
|
||||
}
|
||||
return envoyCert
|
||||
}
|
||||
|
|
|
@ -14,10 +14,15 @@ import (
|
|||
)
|
||||
|
||||
func Test_buildPolicyTransportSocket(t *testing.T) {
|
||||
rootCA, _ := getRootCertificateAuthority()
|
||||
cacheDir, _ := os.UserCacheDir()
|
||||
customCA := filepath.Join(cacheDir, "pomerium", "envoy", "files", "custom-ca-32484c314b584447463735303142374c31414145374650305a525539554938594d524855353757313942494d473847535231.pem")
|
||||
|
||||
srv, _ := NewServer("TEST")
|
||||
rootCAPath, _ := getRootCertificateAuthority()
|
||||
rootCA := srv.filemgr.FileDataSource(rootCAPath).GetFilename()
|
||||
|
||||
t.Run("insecure", func(t *testing.T) {
|
||||
assert.Nil(t, buildPolicyTransportSocket(&config.Policy{
|
||||
assert.Nil(t, srv.buildPolicyTransportSocket(&config.Policy{
|
||||
Destination: mustParseURL("http://example.com"),
|
||||
}))
|
||||
})
|
||||
|
@ -49,7 +54,7 @@ func Test_buildPolicyTransportSocket(t *testing.T) {
|
|||
"sni": "example.com"
|
||||
}
|
||||
}
|
||||
`, buildPolicyTransportSocket(&config.Policy{
|
||||
`, srv.buildPolicyTransportSocket(&config.Policy{
|
||||
Destination: mustParseURL("https://example.com"),
|
||||
}))
|
||||
})
|
||||
|
@ -81,7 +86,7 @@ func Test_buildPolicyTransportSocket(t *testing.T) {
|
|||
"sni": "use-this-name.example.com"
|
||||
}
|
||||
}
|
||||
`, buildPolicyTransportSocket(&config.Policy{
|
||||
`, srv.buildPolicyTransportSocket(&config.Policy{
|
||||
Destination: mustParseURL("https://example.com"),
|
||||
TLSServerName: "use-this-name.example.com",
|
||||
}))
|
||||
|
@ -115,7 +120,7 @@ func Test_buildPolicyTransportSocket(t *testing.T) {
|
|||
"sni": "example.com"
|
||||
}
|
||||
}
|
||||
`, buildPolicyTransportSocket(&config.Policy{
|
||||
`, srv.buildPolicyTransportSocket(&config.Policy{
|
||||
Destination: mustParseURL("https://example.com"),
|
||||
TLSSkipVerify: true,
|
||||
}))
|
||||
|
@ -141,14 +146,14 @@ func Test_buildPolicyTransportSocket(t *testing.T) {
|
|||
"exact": "example.com"
|
||||
}],
|
||||
"trustedCa": {
|
||||
"filename": "`+filepath.Join(cacheDir, "pomerium", "envoy", "files", "custom-ca-3aefa6fd5cf2deb4.pem")+`"
|
||||
"filename": "`+customCA+`"
|
||||
}
|
||||
}
|
||||
},
|
||||
"sni": "example.com"
|
||||
}
|
||||
}
|
||||
`, buildPolicyTransportSocket(&config.Policy{
|
||||
`, srv.buildPolicyTransportSocket(&config.Policy{
|
||||
Destination: mustParseURL("https://example.com"),
|
||||
TLSCustomCA: base64.StdEncoding.EncodeToString([]byte{0, 0, 0, 0}),
|
||||
}))
|
||||
|
@ -172,10 +177,10 @@ func Test_buildPolicyTransportSocket(t *testing.T) {
|
|||
},
|
||||
"tlsCertificates": [{
|
||||
"certificateChain":{
|
||||
"filename": "`+filepath.Join(cacheDir, "pomerium", "envoy", "files", "tls-crt-921a8294d2e2ec54.pem")+`"
|
||||
"filename": "`+filepath.Join(cacheDir, "pomerium", "envoy", "files", "tls-crt-354e49305a5a39414a545530374e58454e48334148524c4e324258463837364355564c4e4532464b54355139495547514a38.pem")+`"
|
||||
},
|
||||
"privateKey": {
|
||||
"filename": "`+filepath.Join(cacheDir, "pomerium", "envoy", "files", "tls-key-d5cf35b1e8533e4a.pem")+`"
|
||||
"filename": "`+filepath.Join(cacheDir, "pomerium", "envoy", "files", "tls-key-3350415a38414e4e4a4655424e55393430474147324651433949384e485341334b5157364f424b4c5856365a545937383735.pem")+`"
|
||||
}
|
||||
}],
|
||||
"validationContext": {
|
||||
|
@ -190,7 +195,7 @@ func Test_buildPolicyTransportSocket(t *testing.T) {
|
|||
"sni": "example.com"
|
||||
}
|
||||
}
|
||||
`, buildPolicyTransportSocket(&config.Policy{
|
||||
`, srv.buildPolicyTransportSocket(&config.Policy{
|
||||
Destination: mustParseURL("https://example.com"),
|
||||
ClientCertificate: clientCert,
|
||||
}))
|
||||
|
@ -198,7 +203,9 @@ func Test_buildPolicyTransportSocket(t *testing.T) {
|
|||
}
|
||||
|
||||
func Test_buildCluster(t *testing.T) {
|
||||
rootCA, _ := getRootCertificateAuthority()
|
||||
srv, _ := NewServer("TEST")
|
||||
rootCAPath, _ := getRootCertificateAuthority()
|
||||
rootCA := srv.filemgr.FileDataSource(rootCAPath).GetFilename()
|
||||
t.Run("insecure", func(t *testing.T) {
|
||||
cluster := buildCluster("example", mustParseURL("http://example.com"), nil, true, config.GetEnvoyDNSLookupFamily(config.DNSLookupFamilyV4Only))
|
||||
testutil.AssertProtoJSONEqual(t, `
|
||||
|
@ -232,7 +239,7 @@ func Test_buildCluster(t *testing.T) {
|
|||
})
|
||||
t.Run("secure", func(t *testing.T) {
|
||||
u := mustParseURL("https://example.com")
|
||||
transportSocket := buildPolicyTransportSocket(&config.Policy{
|
||||
transportSocket := srv.buildPolicyTransportSocket(&config.Policy{
|
||||
Destination: u,
|
||||
})
|
||||
cluster := buildCluster("example", u, transportSocket, true, config.GetEnvoyDNSLookupFamily(config.DNSLookupFamilyAuto))
|
||||
|
|
|
@ -34,17 +34,17 @@ func (srv *Server) buildClusters(options *config.Options) []*envoy_config_cluste
|
|||
}
|
||||
|
||||
clusters := []*envoy_config_cluster_v3.Cluster{
|
||||
buildInternalCluster(options, "pomerium-control-plane-grpc", grpcURL, true),
|
||||
buildInternalCluster(options, "pomerium-control-plane-http", httpURL, false),
|
||||
srv.buildInternalCluster(options, "pomerium-control-plane-grpc", grpcURL, true),
|
||||
srv.buildInternalCluster(options, "pomerium-control-plane-http", httpURL, false),
|
||||
}
|
||||
|
||||
clusters = append(clusters, buildInternalCluster(options, authzURL.Host, authzURL, true))
|
||||
clusters = append(clusters, srv.buildInternalCluster(options, authzURL.Host, authzURL, true))
|
||||
|
||||
if config.IsProxy(options.Services) {
|
||||
for i := range options.Policies {
|
||||
policy := options.Policies[i]
|
||||
if policy.Destination != nil {
|
||||
clusters = append(clusters, buildPolicyCluster(options, &policy))
|
||||
clusters = append(clusters, srv.buildPolicyCluster(options, &policy))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -52,21 +52,21 @@ func (srv *Server) buildClusters(options *config.Options) []*envoy_config_cluste
|
|||
return clusters
|
||||
}
|
||||
|
||||
func buildInternalCluster(options *config.Options, name string, endpoint *url.URL, forceHTTP2 bool) *envoy_config_cluster_v3.Cluster {
|
||||
func (srv *Server) buildInternalCluster(options *config.Options, name string, endpoint *url.URL, forceHTTP2 bool) *envoy_config_cluster_v3.Cluster {
|
||||
dnsLookupFamily := config.GetEnvoyDNSLookupFamily(options.DNSLookupFamily)
|
||||
return buildCluster(name, endpoint, buildInternalTransportSocket(options, endpoint), forceHTTP2, dnsLookupFamily)
|
||||
return buildCluster(name, endpoint, srv.buildInternalTransportSocket(options, endpoint), forceHTTP2, dnsLookupFamily)
|
||||
}
|
||||
|
||||
func buildPolicyCluster(options *config.Options, policy *config.Policy) *envoy_config_cluster_v3.Cluster {
|
||||
func (srv *Server) buildPolicyCluster(options *config.Options, policy *config.Policy) *envoy_config_cluster_v3.Cluster {
|
||||
name := getPolicyName(policy)
|
||||
dnsLookupFamily := config.GetEnvoyDNSLookupFamily(options.DNSLookupFamily)
|
||||
if policy.EnableGoogleCloudServerlessAuthentication {
|
||||
dnsLookupFamily = envoy_config_cluster_v3.Cluster_V4_ONLY
|
||||
}
|
||||
return buildCluster(name, policy.Destination, buildPolicyTransportSocket(policy), false, dnsLookupFamily)
|
||||
return buildCluster(name, policy.Destination, srv.buildPolicyTransportSocket(policy), false, dnsLookupFamily)
|
||||
}
|
||||
|
||||
func buildInternalTransportSocket(options *config.Options, endpoint *url.URL) *envoy_config_core_v3.TransportSocket {
|
||||
func (srv *Server) buildInternalTransportSocket(options *config.Options, endpoint *url.URL) *envoy_config_core_v3.TransportSocket {
|
||||
if endpoint.Scheme != "https" {
|
||||
return nil
|
||||
}
|
||||
|
@ -82,19 +82,19 @@ func buildInternalTransportSocket(options *config.Options, endpoint *url.URL) *e
|
|||
}},
|
||||
}
|
||||
if options.CAFile != "" {
|
||||
validationContext.TrustedCa = inlineFilename(options.CAFile)
|
||||
validationContext.TrustedCa = srv.filemgr.FileDataSource(options.CAFile)
|
||||
} else if options.CA != "" {
|
||||
bs, err := base64.StdEncoding.DecodeString(options.CA)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("invalid custom CA certificate")
|
||||
}
|
||||
validationContext.TrustedCa = inlineBytesAsFilename("custom-ca.pem", bs)
|
||||
validationContext.TrustedCa = srv.filemgr.BytesDataSource("custom-ca.pem", bs)
|
||||
} else {
|
||||
rootCA, err := getRootCertificateAuthority()
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("unable to enable certificate verification because no root CAs were found")
|
||||
} else {
|
||||
validationContext.TrustedCa = inlineFilename(rootCA)
|
||||
validationContext.TrustedCa = srv.filemgr.FileDataSource(rootCA)
|
||||
}
|
||||
}
|
||||
tlsContext := &envoy_extensions_transport_sockets_tls_v3.UpstreamTlsContext{
|
||||
|
@ -115,7 +115,7 @@ func buildInternalTransportSocket(options *config.Options, endpoint *url.URL) *e
|
|||
}
|
||||
}
|
||||
|
||||
func buildPolicyTransportSocket(policy *config.Policy) *envoy_config_core_v3.TransportSocket {
|
||||
func (srv *Server) buildPolicyTransportSocket(policy *config.Policy) *envoy_config_core_v3.TransportSocket {
|
||||
if policy.Destination == nil || policy.Destination.Scheme != "https" {
|
||||
return nil
|
||||
}
|
||||
|
@ -136,14 +136,14 @@ func buildPolicyTransportSocket(policy *config.Policy) *envoy_config_core_v3.Tra
|
|||
},
|
||||
AlpnProtocols: []string{"http/1.1"},
|
||||
ValidationContextType: &envoy_extensions_transport_sockets_tls_v3.CommonTlsContext_ValidationContext{
|
||||
ValidationContext: buildPolicyValidationContext(policy),
|
||||
ValidationContext: srv.buildPolicyValidationContext(policy),
|
||||
},
|
||||
},
|
||||
Sni: sni,
|
||||
}
|
||||
if policy.ClientCertificate != nil {
|
||||
tlsContext.CommonTlsContext.TlsCertificates = append(tlsContext.CommonTlsContext.TlsCertificates,
|
||||
envoyTLSCertificateFromGoTLSCertificate(policy.ClientCertificate))
|
||||
srv.envoyTLSCertificateFromGoTLSCertificate(policy.ClientCertificate))
|
||||
}
|
||||
|
||||
tlsConfig := marshalAny(tlsContext)
|
||||
|
@ -155,7 +155,7 @@ func buildPolicyTransportSocket(policy *config.Policy) *envoy_config_core_v3.Tra
|
|||
}
|
||||
}
|
||||
|
||||
func buildPolicyValidationContext(policy *config.Policy) *envoy_extensions_transport_sockets_tls_v3.CertificateValidationContext {
|
||||
func (srv *Server) buildPolicyValidationContext(policy *config.Policy) *envoy_extensions_transport_sockets_tls_v3.CertificateValidationContext {
|
||||
if policy.Destination == nil {
|
||||
return nil
|
||||
}
|
||||
|
@ -172,19 +172,19 @@ func buildPolicyValidationContext(policy *config.Policy) *envoy_extensions_trans
|
|||
}},
|
||||
}
|
||||
if policy.TLSCustomCAFile != "" {
|
||||
validationContext.TrustedCa = inlineFilename(policy.TLSCustomCAFile)
|
||||
validationContext.TrustedCa = srv.filemgr.FileDataSource(policy.TLSCustomCAFile)
|
||||
} else if policy.TLSCustomCA != "" {
|
||||
bs, err := base64.StdEncoding.DecodeString(policy.TLSCustomCA)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("invalid custom CA certificate")
|
||||
}
|
||||
validationContext.TrustedCa = inlineBytesAsFilename("custom-ca.pem", bs)
|
||||
validationContext.TrustedCa = srv.filemgr.BytesDataSource("custom-ca.pem", bs)
|
||||
} else {
|
||||
rootCA, err := getRootCertificateAuthority()
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("unable to enable certificate verification because no root CAs were found")
|
||||
} else {
|
||||
validationContext.TrustedCa = inlineFilename(rootCA)
|
||||
validationContext.TrustedCa = srv.filemgr.FileDataSource(rootCA)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -37,21 +37,21 @@ func init() {
|
|||
})
|
||||
}
|
||||
|
||||
func buildListeners(options *config.Options) []*envoy_config_listener_v3.Listener {
|
||||
func (srv *Server) buildListeners(options *config.Options) []*envoy_config_listener_v3.Listener {
|
||||
var listeners []*envoy_config_listener_v3.Listener
|
||||
|
||||
if config.IsAuthenticate(options.Services) || config.IsProxy(options.Services) {
|
||||
listeners = append(listeners, buildMainListener(options))
|
||||
listeners = append(listeners, srv.buildMainListener(options))
|
||||
}
|
||||
|
||||
if config.IsAuthorize(options.Services) || config.IsCache(options.Services) {
|
||||
listeners = append(listeners, buildGRPCListener(options))
|
||||
listeners = append(listeners, srv.buildGRPCListener(options))
|
||||
}
|
||||
|
||||
return listeners
|
||||
}
|
||||
|
||||
func buildMainListener(options *config.Options) *envoy_config_listener_v3.Listener {
|
||||
func (srv *Server) buildMainListener(options *config.Options) *envoy_config_listener_v3.Listener {
|
||||
if options.InsecureServer {
|
||||
filter := buildMainHTTPConnectionManagerFilter(options,
|
||||
getAllRouteableDomains(options, options.Addr))
|
||||
|
@ -88,7 +88,7 @@ func buildMainListener(options *config.Options) *envoy_config_listener_v3.Listen
|
|||
ServerNames: []string{tlsDomain},
|
||||
}
|
||||
}
|
||||
tlsContext := buildDownstreamTLSContext(options, tlsDomain)
|
||||
tlsContext := srv.buildDownstreamTLSContext(options, tlsDomain)
|
||||
if tlsContext != nil {
|
||||
tlsConfig := marshalAny(tlsContext)
|
||||
filterChain.TransportSocket = &envoy_config_core_v3.TransportSocket{
|
||||
|
@ -250,7 +250,7 @@ func buildMainHTTPConnectionManagerFilter(options *config.Options, domains []str
|
|||
}
|
||||
}
|
||||
|
||||
func buildGRPCListener(options *config.Options) *envoy_config_listener_v3.Listener {
|
||||
func (srv *Server) buildGRPCListener(options *config.Options) *envoy_config_listener_v3.Listener {
|
||||
filter := buildGRPCHTTPConnectionManagerFilter()
|
||||
|
||||
if options.GRPCInsecure {
|
||||
|
@ -285,7 +285,7 @@ func buildGRPCListener(options *config.Options) *envoy_config_listener_v3.Listen
|
|||
ServerNames: []string{tlsDomain},
|
||||
}
|
||||
}
|
||||
tlsContext := buildDownstreamTLSContext(options, tlsDomain)
|
||||
tlsContext := srv.buildDownstreamTLSContext(options, tlsDomain)
|
||||
if tlsContext != nil {
|
||||
tlsConfig := marshalAny(tlsContext)
|
||||
filterChain.TransportSocket = &envoy_config_core_v3.TransportSocket{
|
||||
|
@ -357,7 +357,7 @@ func buildRouteConfiguration(name string, virtualHosts []*envoy_config_route_v3.
|
|||
}
|
||||
}
|
||||
|
||||
func buildDownstreamTLSContext(options *config.Options, domain string) *envoy_extensions_transport_sockets_tls_v3.DownstreamTlsContext {
|
||||
func (srv *Server) buildDownstreamTLSContext(options *config.Options, domain string) *envoy_extensions_transport_sockets_tls_v3.DownstreamTlsContext {
|
||||
cert, err := cryptutil.GetCertificateForDomain(options.Certificates, domain)
|
||||
if err != nil {
|
||||
log.Warn().Str("domain", domain).Err(err).Msg("failed to get certificate for domain")
|
||||
|
@ -370,9 +370,9 @@ func buildDownstreamTLSContext(options *config.Options, domain string) *envoy_ex
|
|||
if err != nil {
|
||||
log.Warn().Msg("client_ca does not appear to be a base64 encoded string")
|
||||
}
|
||||
trustedCA = inlineBytesAsFilename("client-ca", bs)
|
||||
trustedCA = srv.filemgr.BytesDataSource("client-ca", bs)
|
||||
} else if options.ClientCAFile != "" {
|
||||
trustedCA = inlineFilename(options.ClientCAFile)
|
||||
trustedCA = srv.filemgr.FileDataSource(options.ClientCAFile)
|
||||
}
|
||||
|
||||
var validationContext *envoy_extensions_transport_sockets_tls_v3.CommonTlsContext_ValidationContext
|
||||
|
@ -385,7 +385,7 @@ func buildDownstreamTLSContext(options *config.Options, domain string) *envoy_ex
|
|||
}
|
||||
}
|
||||
|
||||
envoyCert := envoyTLSCertificateFromGoTLSCertificate(cert)
|
||||
envoyCert := srv.envoyTLSCertificateFromGoTLSCertificate(cert)
|
||||
return &envoy_extensions_transport_sockets_tls_v3.DownstreamTlsContext{
|
||||
CommonTlsContext: &envoy_extensions_transport_sockets_tls_v3.CommonTlsContext{
|
||||
TlsParams: &envoy_extensions_transport_sockets_tls_v3.TlsParameters{
|
||||
|
|
|
@ -385,13 +385,15 @@ func Test_buildDownstreamTLSContext(t *testing.T) {
|
|||
return
|
||||
}
|
||||
|
||||
downstreamTLSContext := buildDownstreamTLSContext(&config.Options{
|
||||
srv, _ := NewServer("TEST")
|
||||
|
||||
downstreamTLSContext := srv.buildDownstreamTLSContext(&config.Options{
|
||||
Certificates: []tls.Certificate{*certA},
|
||||
}, "a.example.com")
|
||||
|
||||
cacheDir, _ := os.UserCacheDir()
|
||||
certFileName := filepath.Join(cacheDir, "pomerium", "envoy", "files", "tls-crt-921a8294d2e2ec54.pem")
|
||||
keyFileName := filepath.Join(cacheDir, "pomerium", "envoy", "files", "tls-key-d5cf35b1e8533e4a.pem")
|
||||
certFileName := filepath.Join(cacheDir, "pomerium", "envoy", "files", "tls-crt-354e49305a5a39414a545530374e58454e48334148524c4e324258463837364355564c4e4532464b54355139495547514a38.pem")
|
||||
keyFileName := filepath.Join(cacheDir, "pomerium", "envoy", "files", "tls-key-3350415a38414e4e4a4655424e55393430474147324651433949384e485341334b5157364f424b4c5856365a545937383735.pem")
|
||||
|
||||
testutil.AssertProtoJSONEqual(t, `{
|
||||
"commonTlsContext": {
|
||||
|
|
66
internal/fileutil/watcher.go
Normal file
66
internal/fileutil/watcher.go
Normal file
|
@ -0,0 +1,66 @@
|
|||
package fileutil
|
||||
|
||||
import (
|
||||
"sync"
|
||||
|
||||
"github.com/rjeczalik/notify"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/log"
|
||||
"github.com/pomerium/pomerium/internal/signal"
|
||||
)
|
||||
|
||||
// A Watcher watches files for changes.
|
||||
type Watcher struct {
|
||||
*signal.Signal
|
||||
mu sync.Mutex
|
||||
filePaths map[string]chan notify.EventInfo
|
||||
}
|
||||
|
||||
// NewWatcher creates a new Watcher.
|
||||
func NewWatcher() *Watcher {
|
||||
return &Watcher{
|
||||
Signal: signal.New(),
|
||||
filePaths: map[string]chan notify.EventInfo{},
|
||||
}
|
||||
}
|
||||
|
||||
// Add adds a new watch.
|
||||
func (watcher *Watcher) Add(filePath string) {
|
||||
watcher.mu.Lock()
|
||||
defer watcher.mu.Unlock()
|
||||
|
||||
// already watching
|
||||
if _, ok := watcher.filePaths[filePath]; ok {
|
||||
return
|
||||
}
|
||||
|
||||
ch := make(chan notify.EventInfo, 1)
|
||||
go func() {
|
||||
for evt := range ch {
|
||||
log.Info().Str("path", evt.Path()).Str("event", evt.Event().String()).Msg("filemgr: detected file change")
|
||||
watcher.Signal.Broadcast()
|
||||
}
|
||||
}()
|
||||
err := notify.Watch(filePath, ch, notify.All)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Str("path", filePath).Msg("filemgr: error watching file path")
|
||||
notify.Stop(ch)
|
||||
close(ch)
|
||||
return
|
||||
}
|
||||
log.Debug().Str("path", filePath).Msg("filemgr: watching file for changes")
|
||||
|
||||
watcher.filePaths[filePath] = ch
|
||||
}
|
||||
|
||||
// Clear removes all watches.
|
||||
func (watcher *Watcher) Clear() {
|
||||
watcher.mu.Lock()
|
||||
defer watcher.mu.Unlock()
|
||||
|
||||
for filePath, ch := range watcher.filePaths {
|
||||
notify.Stop(ch)
|
||||
close(ch)
|
||||
delete(watcher.filePaths, filePath)
|
||||
}
|
||||
}
|
42
internal/fileutil/watcher_test.go
Normal file
42
internal/fileutil/watcher_test.go
Normal file
|
@ -0,0 +1,42 @@
|
|||
package fileutil
|
||||
|
||||
import (
|
||||
"io/ioutil"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestWatcher(t *testing.T) {
|
||||
tmpdir := filepath.Join(os.TempDir(), uuid.New().String())
|
||||
err := os.MkdirAll(tmpdir, 0o755)
|
||||
if !assert.NoError(t, err) {
|
||||
return
|
||||
}
|
||||
|
||||
err = ioutil.WriteFile(filepath.Join(tmpdir, "test1.txt"), []byte{1, 2, 3, 4}, 0o666)
|
||||
if !assert.NoError(t, err) {
|
||||
return
|
||||
}
|
||||
|
||||
w := NewWatcher()
|
||||
w.Add(filepath.Join(tmpdir, "test1.txt"))
|
||||
|
||||
ch := w.Bind()
|
||||
defer w.Unbind(ch)
|
||||
|
||||
err = ioutil.WriteFile(filepath.Join(tmpdir, "test1.txt"), []byte{5, 6, 7, 8}, 0o666)
|
||||
if !assert.NoError(t, err) {
|
||||
return
|
||||
}
|
||||
|
||||
select {
|
||||
case <-ch:
|
||||
case <-time.After(time.Second):
|
||||
t.Error("expected change signal when file is modified")
|
||||
}
|
||||
}
|
Loading…
Add table
Reference in a new issue