package fileutil import ( "cmp" "context" "io" "io/fs" "os" "path/filepath" "sync" "time" "github.com/fsnotify/fsnotify" "github.com/hashicorp/go-set/v3" "github.com/zeebo/xxh3" "github.com/pomerium/pomerium/internal/log" "github.com/pomerium/pomerium/internal/signal" ) const ( pollingInterval = time.Millisecond * 200 ) type watchedFile struct { path string size int64 modTime int64 hash uint64 force bool // indicates that the next check should compute the hash of the file as well } func newWatchedFile(path string) *watchedFile { return &watchedFile{path: path, force: true} } func (wf *watchedFile) check() (changed bool) { fi, _ := os.Stat(wf.path) changed = swap(&wf.size, getFileSize(fi)) || changed changed = swap(&wf.modTime, getFileModTime(fi)) || changed // if the file size or mod time has changed, re-compute the file contents hash if changed || wf.force { changed = swap(&wf.hash, hashFile(wf.path)) wf.force = false } return changed } // A Watcher watches files for changes. type Watcher struct { *signal.Signal cancelCtx context.Context cancel context.CancelFunc mu sync.Mutex notifyWatcher *fsnotify.Watcher filePaths []string files map[string]*watchedFile directoryPaths []string directories map[string]struct{} } // NewWatcher creates a new Watcher. func NewWatcher() *Watcher { w := &Watcher{ Signal: signal.New(), files: map[string]*watchedFile{}, directories: map[string]struct{}{}, } w.cancelCtx, w.cancel = context.WithCancel(context.Background()) var err error w.notifyWatcher, err = fsnotify.NewWatcher() if err != nil { log.Error().Err(err).Msg("fileutil/watcher: file system notifications disabled") } go w.handlePolling() go w.handleNotifications() return w } // Close closes the watcher. func (w *Watcher) Close() error { w.cancel() w.mu.Lock() defer w.mu.Unlock() var err error if w.notifyWatcher != nil { err = w.notifyWatcher.Close() w.notifyWatcher = nil } return err } // Watch updates the watched file paths. func (w *Watcher) Watch(filePaths []string) { w.mu.Lock() defer w.mu.Unlock() w.filePaths = set.TreeSetFrom(filePaths, cmp.Compare[string]).Slice() var dps []string for _, fp := range filePaths { dps = append(dps, filepath.Dir(fp)) } w.directoryPaths = set.TreeSetFrom(dps, cmp.Compare[string]).Slice() w.checkLocked() } func (w *Watcher) handleNotifications() { w.mu.Lock() nw := w.notifyWatcher w.mu.Unlock() if nw == nil { return } for { select { case <-w.cancelCtx.Done(): return case err := <-nw.Errors: log.Debug().Err(err).Msg("fileutil/watcher: filesystem notification error") case evt := <-nw.Events: if evt.Has(fsnotify.Create) || evt.Has(fsnotify.Remove) || evt.Has(fsnotify.Write) { w.mu.Lock() if wf, ok := w.files[evt.Name]; ok { wf.force = true } w.mu.Unlock() // the actual check will be done via the polling interval to debounce } } } } func (w *Watcher) handlePolling() { ticker := time.NewTicker(pollingInterval) defer ticker.Stop() for { w.mu.Lock() w.checkLocked() w.mu.Unlock() select { case <-w.cancelCtx.Done(): return case <-ticker.C: } } } func (w *Watcher) checkLocked() { w.checkDirectoriesLocked() if changedPaths := w.checkFilesLocked(); len(changedPaths) > 0 { log.Ctx(w.cancelCtx).Info().Strs("paths", changedPaths).Msg("fileutil/watcher: file change event") w.Signal.Broadcast(w.cancelCtx) } } func (w *Watcher) checkDirectoriesLocked() { // only watch directories that exist dirs := make([]string, 0, len(w.directoryPaths)) for _, dp := range w.directoryPaths { fi, _ := os.Stat(dp) if fi != nil && fi.IsDir() { dirs = append(dirs, dp) } } updateMap(w.directories, dirs, func(dp string) struct{} { log.Ctx(w.cancelCtx).Debug().Str("path", dp).Msg("fileutil/watcher: watching directory") if w.notifyWatcher != nil { _ = w.notifyWatcher.Add(dp) } return struct{}{} }, func(dp string, _ struct{}) { log.Ctx(w.cancelCtx).Debug().Str("path", dp).Msg("fileutil/watcher: stopped watching directory") if w.notifyWatcher != nil { _ = w.notifyWatcher.Remove(dp) } }) } func (w *Watcher) checkFilesLocked() (changedPaths []string) { updateMap(w.files, w.filePaths, func(fp string) *watchedFile { log.Ctx(w.cancelCtx).Debug().Str("path", fp).Msg("fileutil/watcher: watching file") wf := newWatchedFile(fp) wf.check() return wf }, func(fp string, _ *watchedFile) { log.Ctx(w.cancelCtx).Debug().Str("path", fp).Msg("fileutil/watcher: stopped watching file") }) for fp, wf := range w.files { if wf.check() { changedPaths = append(changedPaths, fp) } } return changedPaths } func getFileSize(fi fs.FileInfo) int64 { if fi == nil { return 0 } return fi.Size() } func getFileModTime(fi fs.FileInfo) int64 { if fi == nil { return 0 } tm := fi.ModTime() // UnixNano on a zero time is undefined, so just always return 0 for that if tm.IsZero() { return 0 } return tm.UnixNano() } func hashFile(path string) uint64 { f, err := os.Open(path) if err != nil { return 0 } h := xxh3.New() _, err = io.Copy(h, f) if err != nil { _ = f.Close() return 0 } err = f.Close() if err != nil { return 0 } return h.Sum64() } func swap[T comparable](dst *T, src T) (changed bool) { if *dst == src { return false } *dst = src return true } func updateMap[TKey comparable, T any]( dst map[TKey]T, keys []TKey, create func(k TKey) T, remove func(k TKey, v T), ) { for _, k := range keys { if _, ok := dst[k]; !ok { dst[k] = create(k) } } s := set.From(keys) for k, v := range dst { if !s.Contains(k) { remove(k, v) delete(dst, k) } } }