diff --git a/config/config_source.go b/config/config_source.go index 9c20534de..8c6c88418 100644 --- a/config/config_source.go +++ b/config/config_source.go @@ -2,11 +2,12 @@ package config import ( "context" - "crypto/sha256" "fmt" + "io" "os" "sync" + "github.com/cespare/xxhash/v2" "github.com/google/uuid" "github.com/rs/zerolog" @@ -15,6 +16,7 @@ import ( "github.com/pomerium/pomerium/internal/log" "github.com/pomerium/pomerium/internal/telemetry/metrics" "github.com/pomerium/pomerium/pkg/netutil" + "github.com/pomerium/pomerium/pkg/slices" ) // A ChangeListener is called when configuration changes. @@ -131,7 +133,9 @@ func NewFileOrEnvironmentSource( watcher: fileutil.NewWatcher(), config: cfg, } - src.watcher.Add(configFile) + if configFile != "" { + src.watcher.Watch(ctx, []string{configFile}) + } ch := src.watcher.Bind() go func() { for range ch { @@ -179,29 +183,32 @@ type FileWatcherSource struct { underlying Source watcher *fileutil.Watcher - mu sync.RWMutex - computedConfig *Config + mu sync.RWMutex + hash uint64 + cfg *Config ChangeDispatcher } // NewFileWatcherSource creates a new FileWatcherSource -func NewFileWatcherSource(underlying Source) *FileWatcherSource { +func NewFileWatcherSource(ctx context.Context, underlying Source) *FileWatcherSource { + cfg := underlying.GetConfig() src := &FileWatcherSource{ underlying: underlying, watcher: fileutil.NewWatcher(), + cfg: cfg, } ch := src.watcher.Bind() go func() { for range ch { - src.check(context.TODO(), underlying.GetConfig()) + src.onFileChange(ctx) } }() - underlying.OnConfigChange(context.TODO(), func(ctx context.Context, cfg *Config) { - src.check(ctx, cfg) + underlying.OnConfigChange(ctx, func(ctx context.Context, cfg *Config) { + src.onConfigChange(ctx, cfg) }) - src.check(context.TODO(), underlying.GetConfig()) + src.onConfigChange(ctx, cfg) return src } @@ -210,20 +217,56 @@ func NewFileWatcherSource(underlying Source) *FileWatcherSource { func (src *FileWatcherSource) GetConfig() *Config { src.mu.RLock() defer src.mu.RUnlock() - return src.computedConfig + + return src.cfg } -func (src *FileWatcherSource) check(ctx context.Context, cfg *Config) { - if cfg == nil || cfg.Options == nil { - return - } +func (src *FileWatcherSource) onConfigChange(ctx context.Context, cfg *Config) { + // update the file watcher with paths from the config + src.watcher.Watch(ctx, getAllConfigFilePaths(cfg)) src.mu.Lock() defer src.mu.Unlock() - src.watcher.Clear() + // store the config and trigger an update + src.cfg = cfg.Clone() + src.hash = getAllConfigFilePathsHash(src.cfg) + log.Info(ctx).Uint64("hash", src.hash).Msg("config/filewatchersource: underlying config change, triggering update") + src.Trigger(ctx, src.cfg) +} - h := sha256.New() +func (src *FileWatcherSource) onFileChange(ctx context.Context) { + src.mu.Lock() + defer src.mu.Unlock() + + hash := getAllConfigFilePathsHash(src.cfg) + + if hash == src.hash { + log.Info(ctx).Uint64("hash", src.hash).Msg("config/filewatchersource: no change detected") + } else { + // if the hash changed, trigger an update + // the actual config will be identical + src.hash = hash + log.Info(ctx).Uint64("hash", src.hash).Msg("config/filewatchersource: change detected, triggering update") + src.Trigger(ctx, src.cfg) + } +} + +func getAllConfigFilePathsHash(cfg *Config) uint64 { + // read all the config files and build a hash from their contents + h := xxhash.New() + for _, f := range getAllConfigFilePaths(cfg) { + _, _ = h.Write([]byte{0}) + f, err := os.Open(f) + if err == nil { + _, _ = io.Copy(h, f) + _ = f.Close() + } + } + return h.Sum64() +} + +func getAllConfigFilePaths(cfg *Config) []string { fs := []string{ cfg.Options.CAFile, cfg.Options.CertFile, @@ -258,18 +301,9 @@ func (src *FileWatcherSource) check(ctx context.Context, cfg *Config) { ) } - for _, f := range fs { - _, _ = h.Write([]byte{0}) - bs, err := os.ReadFile(f) - if err == nil { - src.watcher.Add(f) - _, _ = h.Write(bs) - } - } + fs = slices.Filter(fs, func(s string) bool { + return s != "" + }) - // update the computed config - src.computedConfig = cfg.Clone() - - // trigger a change - src.Trigger(ctx, src.computedConfig) + return fs } diff --git a/config/config_source_test.go b/config/config_source_test.go index 47d26a741..0ef9ed20e 100644 --- a/config/config_source_test.go +++ b/config/config_source_test.go @@ -16,12 +16,12 @@ func TestFileWatcherSource(t *testing.T) { tmpdir := t.TempDir() - err := os.WriteFile(filepath.Join(tmpdir, "example.txt"), []byte{1, 2, 3, 4}, 0o600) + err := os.WriteFile(filepath.Join(tmpdir, "example.txt"), []byte{1}, 0o600) if !assert.NoError(t, err) { return } - err = os.WriteFile(filepath.Join(tmpdir, "kubernetes-example.txt"), []byte{1, 2, 3, 4}, 0o600) + err = os.WriteFile(filepath.Join(tmpdir, "kubernetes-example.txt"), []byte{2}, 0o600) if !assert.NoError(t, err) { return } @@ -35,7 +35,7 @@ func TestFileWatcherSource(t *testing.T) { }, }) - src := NewFileWatcherSource(ssrc) + src := NewFileWatcherSource(ctx, ssrc) var closeOnce sync.Once ch := make(chan struct{}) src.OnConfigChange(context.Background(), func(ctx context.Context, cfg *Config) { @@ -44,7 +44,7 @@ func TestFileWatcherSource(t *testing.T) { }) }) - err = os.WriteFile(filepath.Join(tmpdir, "example.txt"), []byte{5, 6, 7, 8}, 0o600) + err = os.WriteFile(filepath.Join(tmpdir, "example.txt"), []byte{1, 2}, 0o600) if !assert.NoError(t, err) { return } @@ -55,7 +55,7 @@ func TestFileWatcherSource(t *testing.T) { t.Error("expected OnConfigChange to be fired after modifying a file") } - err = os.WriteFile(filepath.Join(tmpdir, "kubernetes-example.txt"), []byte{5, 6, 7, 8}, 0o600) + err = os.WriteFile(filepath.Join(tmpdir, "kubernetes-example.txt"), []byte{2, 3}, 0o600) if !assert.NoError(t, err) { return } diff --git a/go.mod b/go.mod index 6b4036ad1..ee28d0910 100644 --- a/go.mod +++ b/go.mod @@ -21,7 +21,6 @@ require ( github.com/docker/docker v24.0.7+incompatible github.com/envoyproxy/go-control-plane v0.11.1 github.com/envoyproxy/protoc-gen-validate v1.0.2 - github.com/fsnotify/fsnotify v1.6.0 github.com/go-chi/chi/v5 v5.0.10 github.com/go-jose/go-jose/v3 v3.0.0 github.com/go-redis/redis/v8 v8.11.5 @@ -116,6 +115,7 @@ require ( github.com/docker/go-connections v0.4.0 // indirect github.com/docker/go-units v0.5.0 // indirect github.com/dustin/go-humanize v1.0.1 // indirect + github.com/fsnotify/fsnotify v1.6.0 // indirect github.com/fxamacker/cbor/v2 v2.4.0 // indirect github.com/go-ini/ini v1.67.0 // indirect github.com/go-kit/log v0.2.1 // indirect diff --git a/internal/fileutil/watcher.go b/internal/fileutil/watcher.go index c1f6ee34e..ea1d3c350 100644 --- a/internal/fileutil/watcher.go +++ b/internal/fileutil/watcher.go @@ -4,11 +4,8 @@ import ( "context" "sync" - "github.com/fsnotify/fsnotify" - "github.com/rs/zerolog" "namespacelabs.dev/go-filenotify" - "github.com/pomerium/pomerium/internal/chanutil" "github.com/pomerium/pomerium/internal/log" "github.com/pomerium/pomerium/internal/signal" ) @@ -19,7 +16,6 @@ type Watcher struct { mu sync.Mutex watching map[string]struct{} - eventWatcher filenotify.FileWatcher pollingWatcher filenotify.FileWatcher } @@ -31,75 +27,63 @@ func NewWatcher() *Watcher { } } -// Add adds a new watch. -func (watcher *Watcher) Add(filePath string) { +// Watch updates the watched file paths. +func (watcher *Watcher) Watch(ctx context.Context, filePaths []string) { watcher.mu.Lock() defer watcher.mu.Unlock() - // already watching - if _, ok := watcher.watching[filePath]; ok { - return - } - - ctx := log.WithContext(context.Background(), func(c zerolog.Context) zerolog.Context { - return c.Str("watch_file", filePath) - }) watcher.initLocked(ctx) - if watcher.eventWatcher != nil { - if err := watcher.eventWatcher.Add(filePath); err != nil { - log.Error(ctx).Msg("fileutil/watcher: failed to watch file with event-based file watcher") + var add []string + seen := map[string]struct{}{} + for _, filePath := range filePaths { + if _, ok := watcher.watching[filePath]; !ok { + add = append(add, filePath) + } + seen[filePath] = struct{}{} + } + + var remove []string + for filePath := range watcher.watching { + if _, ok := seen[filePath]; !ok { + remove = append(remove, filePath) } } - if watcher.pollingWatcher != nil { - if err := watcher.pollingWatcher.Add(filePath); err != nil { - log.Error(ctx).Msg("fileutil/watcher: failed to watch file with polling-based file watcher") + for _, filePath := range add { + watcher.watching[filePath] = struct{}{} + + if watcher.pollingWatcher != nil { + err := watcher.pollingWatcher.Add(filePath) + if err != nil { + log.Error(ctx).Err(err).Str("file", filePath).Msg("fileutil/watcher: failed to add file to polling-based file watcher") + } } } -} -// Clear removes all watches. -func (watcher *Watcher) Clear() { - watcher.mu.Lock() - defer watcher.mu.Unlock() + for _, filePath := range remove { + delete(watcher.watching, filePath) - if w := watcher.eventWatcher; w != nil { - _ = watcher.pollingWatcher.Close() - watcher.eventWatcher = nil + if watcher.pollingWatcher != nil { + err := watcher.pollingWatcher.Remove(filePath) + if err != nil { + log.Error(ctx).Err(err).Str("file", filePath).Msg("fileutil/watcher: failed to remove file from polling-based file watcher") + } + } } - - if w := watcher.pollingWatcher; w != nil { - _ = watcher.pollingWatcher.Close() - watcher.pollingWatcher = nil - } - - watcher.watching = make(map[string]struct{}) } func (watcher *Watcher) initLocked(ctx context.Context) { - if watcher.eventWatcher != nil || watcher.pollingWatcher != nil { + if watcher.pollingWatcher != nil { return } - if watcher.eventWatcher == nil { - var err error - watcher.eventWatcher, err = filenotify.NewEventWatcher() - if err != nil { - log.Error(ctx).Msg("fileutil/watcher: failed to create event-based file watcher") - } - } if watcher.pollingWatcher == nil { watcher.pollingWatcher = filenotify.NewPollingWatcher(nil) } - var errors <-chan error = watcher.pollingWatcher.Errors() //nolint - var events <-chan fsnotify.Event = watcher.pollingWatcher.Events() //nolint - - if watcher.eventWatcher != nil { - errors = chanutil.Merge(errors, watcher.eventWatcher.Errors()) - events = chanutil.Merge(events, watcher.eventWatcher.Events()) - } + errors := watcher.pollingWatcher.Errors() + events := watcher.pollingWatcher.Events() // log errors go func() { @@ -110,10 +94,8 @@ func (watcher *Watcher) initLocked(ctx context.Context) { // handle events go func() { - for evts := range chanutil.Batch(events) { - for _, evt := range evts { - log.Info(ctx).Str("name", evt.Name).Str("op", evt.Op.String()).Msg("fileutil/watcher: file notification event") - } + for evt := range events { + log.Info(ctx).Str("name", evt.Name).Str("op", evt.Op.String()).Msg("fileutil/watcher: file notification event") watcher.Broadcast(ctx) } }() diff --git a/internal/fileutil/watcher_test.go b/internal/fileutil/watcher_test.go index fd64f02f2..49f3ee446 100644 --- a/internal/fileutil/watcher_test.go +++ b/internal/fileutil/watcher_test.go @@ -1,79 +1,94 @@ package fileutil import ( + "context" "os" "path/filepath" "testing" "time" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestWatcher(t *testing.T) { tmpdir := t.TempDir() - err := os.WriteFile(filepath.Join(tmpdir, "test1.txt"), []byte{1, 2, 3, 4}, 0o666) - if !assert.NoError(t, err) { - return - } + err := os.WriteFile(filepath.Join(tmpdir, "test1.txt"), []byte{1}, 0o666) + require.NoError(t, err) w := NewWatcher() - defer w.Clear() - w.Add(filepath.Join(tmpdir, "test1.txt")) - - ch := w.Bind() - defer w.Unbind(ch) - - err = os.WriteFile(filepath.Join(tmpdir, "test1.txt"), []byte{5, 6, 7, 8}, 0o666) - if !assert.NoError(t, err) { - return - } - - select { - case <-ch: - case <-time.After(time.Second): - t.Error("expected change signal when file is modified") - } -} - -func TestWatcherSymlink(t *testing.T) { - t.Parallel() - - tmpdir := t.TempDir() - - err := os.WriteFile(filepath.Join(tmpdir, "test1.txt"), []byte{1, 2, 3, 4}, 0o666) - if !assert.NoError(t, err) { - return - } - - err = os.WriteFile(filepath.Join(tmpdir, "test2.txt"), []byte{5, 6, 7, 8}, 0o666) - if !assert.NoError(t, err) { - return - } - - assert.NoError(t, os.Symlink(filepath.Join(tmpdir, "test1.txt"), filepath.Join(tmpdir, "symlink1.txt"))) - - w := NewWatcher() - defer w.Clear() - w.Add(filepath.Join(tmpdir, "symlink1.txt")) + w.Watch(context.Background(), []string{filepath.Join(tmpdir, "test1.txt")}) ch := w.Bind() t.Cleanup(func() { w.Unbind(ch) }) - assert.NoError(t, os.WriteFile(filepath.Join(tmpdir, "test1.txt"), []byte{9, 10, 11}, 0o666)) + err = os.WriteFile(filepath.Join(tmpdir, "test1.txt"), []byte{1, 2}, 0o666) + require.NoError(t, err) - select { - case <-ch: - case <-time.After(time.Second): - t.Error("expected change signal when underlying file is modified") - } + expectChange(t, ch) +} + +func TestWatcherSymlink(t *testing.T) { + tmpdir := t.TempDir() + + err := os.WriteFile(filepath.Join(tmpdir, "test1.txt"), []byte{1}, 0o666) + require.NoError(t, err) + + err = os.WriteFile(filepath.Join(tmpdir, "test2.txt"), []byte{1, 2}, 0o666) + require.NoError(t, err) + + assert.NoError(t, os.Symlink(filepath.Join(tmpdir, "test1.txt"), filepath.Join(tmpdir, "symlink1.txt"))) + + w := NewWatcher() + w.Watch(context.Background(), []string{filepath.Join(tmpdir, "symlink1.txt")}) + + ch := w.Bind() + t.Cleanup(func() { w.Unbind(ch) }) + + assert.NoError(t, os.WriteFile(filepath.Join(tmpdir, "test1.txt"), []byte{1, 2, 3}, 0o666)) + + expectChange(t, ch) assert.NoError(t, os.Symlink(filepath.Join(tmpdir, "test2.txt"), filepath.Join(tmpdir, "symlink2.txt"))) assert.NoError(t, os.Rename(filepath.Join(tmpdir, "symlink2.txt"), filepath.Join(tmpdir, "symlink1.txt"))) + expectChange(t, ch) +} + +func TestWatcher_FileRemoval(t *testing.T) { + tmpdir := t.TempDir() + + err := os.WriteFile(filepath.Join(tmpdir, "test1.txt"), []byte{1}, 0o666) + require.NoError(t, err) + + w := NewWatcher() + w.Watch(context.Background(), []string{filepath.Join(tmpdir, "test1.txt")}) + + ch := w.Bind() + t.Cleanup(func() { w.Unbind(ch) }) + + err = os.Remove(filepath.Join(tmpdir, "test1.txt")) + require.NoError(t, err) + + expectChange(t, ch) + + err = os.WriteFile(filepath.Join(tmpdir, "test1.txt"), []byte{1, 2}, 0o666) + require.NoError(t, err) + + expectChange(t, ch) +} + +func expectChange(t *testing.T, ch chan context.Context) { + t.Helper() + + cnt := 0 select { case <-ch: + cnt++ case <-time.After(10 * time.Second): - t.Error("expected change signal when symlink is changed") + } + if cnt == 0 { + t.Error("expected change signal") } } diff --git a/pkg/cmd/pomerium/pomerium.go b/pkg/cmd/pomerium/pomerium.go index c732463c3..de54968c5 100644 --- a/pkg/cmd/pomerium/pomerium.go +++ b/pkg/cmd/pomerium/pomerium.go @@ -45,7 +45,7 @@ func Run(ctx context.Context, src config.Source) error { defer logMgr.Close() // trigger changes when underlying files are changed - src = config.NewFileWatcherSource(src) + src = config.NewFileWatcherSource(ctx, src) src, err = autocert.New(src) if err != nil {