diff --git a/go.mod b/go.mod index 17ef1a35f..9f082cd63 100644 --- a/go.mod +++ b/go.mod @@ -19,6 +19,7 @@ require ( github.com/envoyproxy/go-control-plane/envoy v1.32.3 github.com/envoyproxy/protoc-gen-validate v1.1.0 github.com/exaring/otelpgx v0.8.0 + github.com/fsnotify/fsnotify v1.8.0 github.com/go-chi/chi/v5 v5.2.0 github.com/go-jose/go-jose/v3 v3.0.3 github.com/google/btree v1.1.3 @@ -96,7 +97,6 @@ require ( google.golang.org/grpc v1.69.2 google.golang.org/protobuf v1.36.2 gopkg.in/yaml.v3 v3.0.1 - namespacelabs.dev/go-filenotify v0.0.0-20220511192020-53ea11be7eaa sigs.k8s.io/yaml v1.4.0 ) @@ -147,7 +147,6 @@ require ( github.com/dustin/go-humanize v1.0.1 // indirect github.com/envoyproxy/go-control-plane/ratelimit v0.1.0 // indirect github.com/felixge/httpsnoop v1.0.4 // indirect - github.com/fsnotify/fsnotify v1.8.0 // indirect github.com/fxamacker/cbor/v2 v2.6.0 // indirect github.com/go-ini/ini v1.67.0 // indirect github.com/go-jose/go-jose/v4 v4.0.2 // indirect diff --git a/go.sum b/go.sum index a976b2553..608aba0cd 100644 --- a/go.sum +++ b/go.sum @@ -220,7 +220,6 @@ github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHk github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0= github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo= github.com/fsnotify/fsnotify v1.4.9/go.mod h1:znqG4EE+3YCdAaPaxE2ZRY/06pZUdp0tY4IgpuI1SZQ= -github.com/fsnotify/fsnotify v1.5.4/go.mod h1:OVB6XrOHzAwXMpEM7uPOzcehqUV2UqJxmVXmkdnm1bU= github.com/fsnotify/fsnotify v1.8.0 h1:dAwr6QBTBZIkG8roQaJjGof0pp0EeF+tNV7YBP3F/8M= github.com/fsnotify/fsnotify v1.8.0/go.mod h1:8jBTzvmWwFyi3Pb8djgCCO5IBqzKJ/Jwo8TRcHyHii0= github.com/fxamacker/cbor/v2 v2.6.0 h1:sU6J2usfADwWlYDAFhZBQ6TnLFBHxgesMrQfQgk1tWA= @@ -902,7 +901,6 @@ golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.0.0-20210616094352-59db8d763f22/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20211216021012-1d35b9e2eb4e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220114195835-da31bd327af9/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20220412211240-33da011f77ad/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220708085239-5a0f0661e09d/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= @@ -1113,8 +1111,6 @@ honnef.co/go/tools v0.0.0-20190523083050-ea95bdfd59fc/go.mod h1:rf3lG4BRIbNafJWh honnef.co/go/tools v0.0.1-2019.2.3/go.mod h1:a3bituU0lyd329TUQxRnasdCoJDkEUEAqEt0JzvZhAg= honnef.co/go/tools v0.0.1-2020.1.3/go.mod h1:X/FiERA/W4tHapMX5mGpAtMSVEeEUOyHaw9vFzvIQ3k= honnef.co/go/tools v0.0.1-2020.1.4/go.mod h1:X/FiERA/W4tHapMX5mGpAtMSVEeEUOyHaw9vFzvIQ3k= -namespacelabs.dev/go-filenotify v0.0.0-20220511192020-53ea11be7eaa h1:jj2kjs0Hvufj40wuhMAzoZUOwrwMDFg1gHZ49RiIv9w= -namespacelabs.dev/go-filenotify v0.0.0-20220511192020-53ea11be7eaa/go.mod h1:e8NJRaInXRRm1+KPA6EkGEzdLJAgEvVSIKiLzpP97nI= rsc.io/binaryregexp v0.2.0/go.mod h1:qTv7/COck+e2FymRvadv62gMdZztPaShugOCi3I+8D8= rsc.io/quote/v3 v3.1.0/go.mod h1:yEA65RcK8LyAZtP9Kv3t0HmxON59tX3rD+tICJqUlj0= rsc.io/sampler v1.3.0/go.mod h1:T1hPZKmBbMNahiBKFy5HrXp6adAjACjK9JXDnKaTXpA= diff --git a/internal/fileutil/watcher.go b/internal/fileutil/watcher.go index d5601417c..38a8c60b0 100644 --- a/internal/fileutil/watcher.go +++ b/internal/fileutil/watcher.go @@ -1,15 +1,53 @@ package fileutil import ( + "cmp" "context" + "io" + "io/fs" + "os" + "path/filepath" "sync" + "time" - "namespacelabs.dev/go-filenotify" + "github.com/fsnotify/fsnotify" + "github.com/hashicorp/go-set/v3" + "github.com/zeebo/xxh3" "github.com/pomerium/pomerium/internal/log" "github.com/pomerium/pomerium/internal/signal" ) +const ( + pollingInterval = time.Millisecond * 200 +) + +type watchedFile struct { + path string + size int64 + modTime int64 + hash uint64 + force bool // indicates that the next check should compute the hash of the file as well +} + +func newWatchedFile(path string) *watchedFile { + return &watchedFile{path: path, force: true} +} + +func (wf *watchedFile) check() (changed bool) { + fi, _ := os.Stat(wf.path) + changed = swap(&wf.size, getFileSize(fi)) || changed + changed = swap(&wf.modTime, getFileModTime(fi)) || changed + + // if the file size or mod time has changed, re-compute the file contents hash + if changed || wf.force { + changed = swap(&wf.hash, hashFile(wf.path)) + wf.force = false + } + + return changed +} + // A Watcher watches files for changes. type Watcher struct { *signal.Signal @@ -18,17 +56,31 @@ type Watcher struct { cancel context.CancelFunc mu sync.Mutex - watching map[string]struct{} - pollingWatcher filenotify.FileWatcher + notifyWatcher *fsnotify.Watcher + filePaths []string + files map[string]*watchedFile + directoryPaths []string + directories map[string]struct{} } // NewWatcher creates a new Watcher. func NewWatcher() *Watcher { w := &Watcher{ - Signal: signal.New(), - watching: make(map[string]struct{}), + Signal: signal.New(), + files: map[string]*watchedFile{}, + directories: map[string]struct{}{}, } w.cancelCtx, w.cancel = context.WithCancel(context.Background()) + + var err error + w.notifyWatcher, err = fsnotify.NewWatcher() + if err != nil { + log.Error().Err(err).Msg("fileutil/watcher: file system notifications disabled") + } + + go w.handlePolling() + go w.handleNotifications() + return w } @@ -40,9 +92,9 @@ func (w *Watcher) Close() error { defer w.mu.Unlock() var err error - if w.pollingWatcher != nil { - err = w.pollingWatcher.Close() - w.pollingWatcher = nil + if w.notifyWatcher != nil { + err = w.notifyWatcher.Close() + w.notifyWatcher = nil } return err @@ -53,71 +105,180 @@ func (w *Watcher) Watch(filePaths []string) { w.mu.Lock() defer w.mu.Unlock() - w.initLocked() - - var add []string - seen := map[string]struct{}{} - for _, filePath := range filePaths { - if _, ok := w.watching[filePath]; !ok { - add = append(add, filePath) - } - seen[filePath] = struct{}{} + fps := set.NewTreeSet(cmp.Compare[string]) + for _, fp := range filePaths { + fps.Insert(fp) } + w.filePaths = fps.Slice() - var remove []string - for filePath := range w.watching { - if _, ok := seen[filePath]; !ok { - remove = append(remove, filePath) - } + dps := set.NewTreeSet(cmp.Compare[string]) + for _, fp := range filePaths { + dps.Insert(filepath.Dir(fp)) } + w.directoryPaths = dps.Slice() - for _, filePath := range add { - w.watching[filePath] = struct{}{} - - if w.pollingWatcher != nil { - err := w.pollingWatcher.Add(filePath) - if err != nil { - log.Error().Err(err).Str("file", filePath).Msg("fileutil/watcher: failed to add file to polling-based file watcher") - } - } - } - - for _, filePath := range remove { - delete(w.watching, filePath) - - if w.pollingWatcher != nil { - err := w.pollingWatcher.Remove(filePath) - if err != nil { - log.Error().Err(err).Str("file", filePath).Msg("fileutil/watcher: failed to remove file from polling-based file watcher") - } - } - } + w.checkLocked() } -func (w *Watcher) initLocked() { - if w.pollingWatcher != nil { +func (w *Watcher) handleNotifications() { + if w.notifyWatcher == nil { return } - if w.pollingWatcher == nil { - w.pollingWatcher = filenotify.NewPollingWatcher(nil) + for { + select { + case <-w.cancelCtx.Done(): + return + case err := <-w.notifyWatcher.Errors: + log.Debug().Err(err).Msg("fileutil/watcher: filesystem notification error") + case evt := <-w.notifyWatcher.Events: + if evt.Has(fsnotify.Create) || evt.Has(fsnotify.Remove) || evt.Has(fsnotify.Write) { + w.mu.Lock() + if wf, ok := w.files[evt.Name]; ok { + wf.force = true + } + w.mu.Unlock() + } + } + } +} + +func (w *Watcher) handlePolling() { + ticker := time.NewTicker(pollingInterval) + defer ticker.Stop() + + for { + w.mu.Lock() + w.checkLocked() + w.mu.Unlock() + + select { + case <-w.cancelCtx.Done(): + return + case <-ticker.C: + } + } +} + +func (w *Watcher) checkLocked() { + w.checkDirectoriesLocked() + if changedPaths := w.checkFilesLocked(); len(changedPaths) > 0 { + log.Ctx(w.cancelCtx).Info().Strs("paths", changedPaths).Msg("fileutil/watcher: file change event") + w.Signal.Broadcast(w.cancelCtx) + } +} + +func (w *Watcher) checkDirectoriesLocked() { + // only watch directories that exist + dirs := make([]string, 0, len(w.directoryPaths)) + for _, dp := range w.directoryPaths { + fi, _ := os.Stat(dp) + if fi != nil && fi.IsDir() { + dirs = append(dirs, dp) + } } - errors := w.pollingWatcher.Errors() - events := w.pollingWatcher.Events() - - // log errors - go func() { - for err := range errors { - log.Error().Err(err).Msg("fileutil/watcher: file notification error") - } - }() - - // handle events - go func() { - for evt := range events { - log.Info().Str("name", evt.Name).Str("op", evt.Op.String()).Msg("fileutil/watcher: file notification event") - w.Broadcast(w.cancelCtx) - } - }() + updateMap(w.directories, dirs, + func(dp string) struct{} { + log.Ctx(w.cancelCtx).Debug().Str("path", dp).Msg("fileutil/watcher: watching directory") + if w.notifyWatcher != nil { + _ = w.notifyWatcher.Add(dp) + } + return struct{}{} + }, + func(dp string, _ struct{}) { + log.Ctx(w.cancelCtx).Debug().Str("path", dp).Msg("fileutil/watcher: stopped watching directory") + if w.notifyWatcher != nil { + _ = w.notifyWatcher.Remove(dp) + } + }) +} + +func (w *Watcher) checkFilesLocked() (changedPaths []string) { + updateMap(w.files, w.filePaths, + func(fp string) *watchedFile { + log.Ctx(w.cancelCtx).Debug().Str("path", fp).Msg("fileutil/watcher: watching file") + wf := newWatchedFile(fp) + wf.check() + return wf + }, + func(fp string, _ *watchedFile) { + log.Ctx(w.cancelCtx).Debug().Str("path", fp).Msg("fileutil/watcher: stopped watching file") + }) + + for fp, wf := range w.files { + if wf.check() { + changedPaths = append(changedPaths, fp) + } + } + + return changedPaths +} + +func getFileSize(fi fs.FileInfo) int64 { + if fi == nil { + return 0 + } + return fi.Size() +} + +func getFileModTime(fi fs.FileInfo) int64 { + if fi == nil { + return 0 + } + tm := fi.ModTime() + // UnixNano on a zero time is undefined, so just always return 0 for that + if tm.IsZero() { + return 0 + } + return tm.UnixNano() +} + +func hashFile(path string) uint64 { + f, err := os.Open(path) + if err != nil { + return 0 + } + + h := xxh3.New() + _, err = io.Copy(h, f) + if err != nil { + _ = f.Close() + return 0 + } + + err = f.Close() + if err != nil { + return 0 + } + + return h.Sum64() +} + +func swap[T comparable](dst *T, src T) (changed bool) { + if *dst == src { + return false + } + *dst = src + return true +} + +func updateMap[TKey comparable, T any]( + dst map[TKey]T, + keys []TKey, + create func(k TKey) T, + remove func(k TKey, v T), +) { + for _, k := range keys { + if _, ok := dst[k]; !ok { + dst[k] = create(k) + } + } + s := set.From(keys) + for k, v := range dst { + if !s.Contains(k) { + remove(k, v) + delete(dst, k) + } + } } diff --git a/internal/fileutil/watcher_test.go b/internal/fileutil/watcher_test.go index 4d6bb06b1..08dde6520 100644 --- a/internal/fileutil/watcher_test.go +++ b/internal/fileutil/watcher_test.go @@ -88,6 +88,56 @@ func TestWatcher_FileRemoval(t *testing.T) { expectChange(t, ch) } +func TestWatcher_FileModification(t *testing.T) { + t.Parallel() + + tmpdir := t.TempDir() + nm := filepath.Join(tmpdir, "test1.txt") + now := time.Now() + + require.NoError(t, os.WriteFile(nm, []byte{1, 2, 3, 4}, 0o666)) + require.NoError(t, os.Chtimes(nm, now, now)) + + w := NewWatcher() + defer w.Close() + w.Watch([]string{nm}) + + ch := w.Bind() + t.Cleanup(func() { w.Unbind(ch) }) + + require.NoError(t, os.WriteFile(nm, []byte{5, 6, 7, 8}, 0o666)) + require.NoError(t, os.Chtimes(nm, now, now)) + + expectChange(t, ch) +} + +func TestWatcher_UnWatch(t *testing.T) { + t.Parallel() + + tmpdir := t.TempDir() + nm := filepath.Join(tmpdir, "test1.txt") + now := time.Now() + + require.NoError(t, os.WriteFile(nm, []byte{1, 2, 3}, 0o666)) + require.NoError(t, os.Chtimes(nm, now, now)) + + w := NewWatcher() + defer w.Close() + + ch := w.Bind() + t.Cleanup(func() { w.Unbind(ch) }) + + w.Watch([]string{nm}) + require.NoError(t, os.WriteFile(nm, []byte{4, 5, 6}, 0o666)) + require.NoError(t, os.Chtimes(nm, now, now)) + expectChange(t, ch) + + w.Watch(nil) + require.NoError(t, os.WriteFile(nm, []byte{7, 8, 9}, 0o666)) + require.NoError(t, os.Chtimes(nm, now, now)) + expectNoChange(t, ch) +} + func expectChange(t *testing.T, ch chan context.Context) { t.Helper() @@ -95,9 +145,19 @@ func expectChange(t *testing.T, ch chan context.Context) { select { case <-ch: cnt++ - case <-time.After(10 * time.Second): - } - if cnt == 0 { - t.Error("expected change signal") + case <-time.After(2 * pollingInterval): } + assert.Greater(t, cnt, 0, "should signal a change") +} + +func expectNoChange(t *testing.T, ch chan context.Context) { + t.Helper() + + cnt := 0 + select { + case <-ch: + cnt++ + case <-time.After(2 * pollingInterval): + } + assert.Equal(t, 0, cnt, "should not signal a change") }