diff --git a/go.mod b/go.mod index 27479a3f9..d72f84a26 100644 --- a/go.mod +++ b/go.mod @@ -17,6 +17,7 @@ require ( github.com/docker/docker v20.10.19+incompatible github.com/envoyproxy/go-control-plane v0.10.3-0.20220819153403-8a9be01c9575 github.com/envoyproxy/protoc-gen-validate v0.6.13 + github.com/fsnotify/fsnotify v1.5.4 github.com/go-chi/chi/v5 v5.0.7 github.com/go-jose/go-jose/v3 v3.0.0 github.com/go-redis/redis/v8 v8.11.5 @@ -50,7 +51,6 @@ require ( github.com/prometheus/client_model v0.2.0 github.com/prometheus/common v0.37.0 github.com/prometheus/procfs v0.8.0 - github.com/rjeczalik/notify v0.9.3-0.20201210012515-e2a77dcc14cf github.com/rs/cors v1.8.2 github.com/rs/zerolog v1.28.0 github.com/shirou/gopsutil/v3 v3.22.9 @@ -74,6 +74,7 @@ require ( google.golang.org/protobuf v1.28.1 gopkg.in/auth0.v5 v5.21.1 gopkg.in/yaml.v3 v3.0.1 + namespacelabs.dev/go-filenotify v0.0.0-20220511192020-53ea11be7eaa sigs.k8s.io/yaml v1.3.0 ) @@ -125,7 +126,6 @@ require ( github.com/fatih/structtag v1.2.0 // indirect github.com/felixge/httpsnoop v1.0.2 // indirect github.com/firefart/nonamedreturns v1.0.4 // indirect - github.com/fsnotify/fsnotify v1.5.4 // indirect github.com/fxamacker/cbor/v2 v2.3.0 // indirect github.com/fzipp/gocyclo v0.6.0 // indirect github.com/ghodss/yaml v1.0.0 // indirect diff --git a/go.sum b/go.sum index 8f3a55ea1..aae3169e5 100644 --- a/go.sum +++ b/go.sum @@ -952,8 +952,6 @@ github.com/rcrowley/go-metrics v0.0.0-20201227073835-cf1acfcdf475 h1:N/ElC8H3+5X github.com/rcrowley/go-metrics v0.0.0-20201227073835-cf1acfcdf475/go.mod h1:bCqnVzQkZxMG4s8nGwiZ5l3QUCyqpo9Y+/ZMZ9VjZe4= github.com/rivo/uniseg v0.2.0 h1:S1pD9weZBuJdFmowNwbpi7BJ8TNftyUImj/0WQi72jY= github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc= -github.com/rjeczalik/notify v0.9.3-0.20201210012515-e2a77dcc14cf h1:MY2fqXPSLfjld10N04fNcSFdR9K/Y57iXxZRFAivHzI= -github.com/rjeczalik/notify v0.9.3-0.20201210012515-e2a77dcc14cf/go.mod h1:aErll2f0sUX9PXZnVNyeiObbmTlk5jnMoCa4QEjJeqM= github.com/rogpeppe/fastuuid v0.0.0-20150106093220-6724a57986af/go.mod h1:XWv6SoW27p1b0cqNHllgS5HIMJraePCO15w5zCzIWYg= github.com/rogpeppe/fastuuid v1.2.0/go.mod h1:jVj6XXZzXRy/MSR5jhDC/2q6DgLz+nrA6LYCDYWNEvQ= github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4= @@ -1365,7 +1363,6 @@ golang.org/x/sync v0.0.0-20220929204114-8fcdb60fdcc0/go.mod h1:RxMgew5VJxzue5/jJ golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= -golang.org/x/sys v0.0.0-20180926160741-c2ed4eda69e7/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20181107165924-66b7b1311ac8/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20181116152217-5ac8a444bdc5/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20181205085412-a5c9d58dba9a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= @@ -1862,6 +1859,8 @@ mvdan.cc/lint v0.0.0-20170908181259-adc824a0674b h1:DxJ5nJdkhDlLok9K6qO+5290kphD mvdan.cc/lint v0.0.0-20170908181259-adc824a0674b/go.mod h1:2odslEg/xrtNQqCYg2/jCoyKnw3vv5biOc3JnIcYfL4= mvdan.cc/unparam v0.0.0-20220706161116-678bad134442 h1:seuXWbRB1qPrS3NQnHmFKLJLtskWyueeIzmLXghMGgk= mvdan.cc/unparam v0.0.0-20220706161116-678bad134442/go.mod h1:F/Cxw/6mVrNKqrR2YjFf5CaW0Bw4RL8RfbEf4GRggJk= +namespacelabs.dev/go-filenotify v0.0.0-20220511192020-53ea11be7eaa h1:jj2kjs0Hvufj40wuhMAzoZUOwrwMDFg1gHZ49RiIv9w= +namespacelabs.dev/go-filenotify v0.0.0-20220511192020-53ea11be7eaa/go.mod h1:e8NJRaInXRRm1+KPA6EkGEzdLJAgEvVSIKiLzpP97nI= rsc.io/binaryregexp v0.2.0/go.mod h1:qTv7/COck+e2FymRvadv62gMdZztPaShugOCi3I+8D8= rsc.io/quote/v3 v3.1.0/go.mod h1:yEA65RcK8LyAZtP9Kv3t0HmxON59tX3rD+tICJqUlj0= rsc.io/sampler v1.3.0/go.mod h1:T1hPZKmBbMNahiBKFy5HrXp6adAjACjK9JXDnKaTXpA= diff --git a/internal/chanutil/batch.go b/internal/chanutil/batch.go new file mode 100644 index 000000000..688cd1947 --- /dev/null +++ b/internal/chanutil/batch.go @@ -0,0 +1,76 @@ +package chanutil + +import "time" + +const ( + defaultBatchMaxSize = 1024 + defaultBatchMaxWait = time.Millisecond * 300 +) + +type batchConfig struct { + maxSize int + maxWait time.Duration +} + +// A BatchOption customizes a batch operation. +type BatchOption func(cfg *batchConfig) + +// WithBatchMaxSize sets the maximum batch size for a Batch operation. +func WithBatchMaxSize(maxSize int) BatchOption { + return func(cfg *batchConfig) { + cfg.maxSize = maxSize + } +} + +// WithBatchMaxWait sets the maximum wait duration for a Batch operation. +func WithBatchMaxWait(maxWait time.Duration) BatchOption { + return func(cfg *batchConfig) { + cfg.maxWait = maxWait + } +} + +// Batch returns a new channel that consumes all the items from `in` and batches them together. +func Batch[T any](in <-chan T, options ...BatchOption) <-chan []T { + cfg := new(batchConfig) + WithBatchMaxSize(defaultBatchMaxSize)(cfg) + WithBatchMaxWait(defaultBatchMaxWait)(cfg) + for _, option := range options { + option(cfg) + } + + out := make(chan []T) + go func() { + var buf []T + var timer <-chan time.Time + for { + if in == nil && timer == nil { + close(out) + return + } + + select { + case item, ok := <-in: + if !ok { + in = nil + timer = time.After(0) + continue + } + buf = append(buf, item) + if timer == nil { + timer = time.After(cfg.maxWait) + } + case <-timer: + timer = nil + for len(buf) > 0 { + batch := buf + if len(batch) > cfg.maxSize { + batch = batch[:cfg.maxSize] + } + buf = buf[len(batch):] + out <- batch + } + } + } + }() + return out +} diff --git a/internal/chanutil/batch_test.go b/internal/chanutil/batch_test.go new file mode 100644 index 000000000..81200d87b --- /dev/null +++ b/internal/chanutil/batch_test.go @@ -0,0 +1,25 @@ +package chanutil + +import ( + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestBatch(t *testing.T) { + ch1 := make(chan int) + go func() { + for _, i := range []int{1, 2, 3, 4, 5, 6, 7, 8, 9, 10} { + ch1 <- i + } + close(ch1) + }() + + ch2 := Batch(ch1, WithBatchMaxWait(time.Millisecond*10), WithBatchMaxSize(3)) + assert.Equal(t, []int{1, 2, 3}, <-ch2) + assert.Equal(t, []int{4, 5, 6}, <-ch2) + assert.Equal(t, []int{7, 8, 9}, <-ch2) + assert.Equal(t, []int{10}, <-ch2) + assert.Equal(t, []int(nil), <-ch2) +} diff --git a/internal/chanutil/chanutil.go b/internal/chanutil/chanutil.go new file mode 100644 index 000000000..f5c89a717 --- /dev/null +++ b/internal/chanutil/chanutil.go @@ -0,0 +1,2 @@ +// Package chanutil implements methods for working with channels. +package chanutil diff --git a/internal/chanutil/merge.go b/internal/chanutil/merge.go new file mode 100644 index 000000000..c2e0ba7fd --- /dev/null +++ b/internal/chanutil/merge.go @@ -0,0 +1,46 @@ +package chanutil + +// Merge merges multiple channels together. +func Merge[T any](ins ...<-chan T) <-chan T { + switch len(ins) { + case 0: + return nil + case 1: + return ins[0] + case 2: + default: + return Merge( + Merge(ins[:len(ins)/2]...), + Merge(ins[len(ins)/2:]...), + ) + } + + in1, in2 := ins[0], ins[1] + out := make(chan T) + go func() { + for { + if in1 == nil && in2 == nil { + close(out) + return + } + + select { + case item, ok := <-in1: + if !ok { + in1 = nil + continue + } + + out <- item + case item, ok := <-in2: + if !ok { + in2 = nil + continue + } + + out <- item + } + } + }() + return out +} diff --git a/internal/chanutil/merge_test.go b/internal/chanutil/merge_test.go new file mode 100644 index 000000000..3e1e30407 --- /dev/null +++ b/internal/chanutil/merge_test.go @@ -0,0 +1,37 @@ +package chanutil + +import ( + "sort" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestMerge(t *testing.T) { + ch1, ch2, ch3 := make(chan int), make(chan int), make(chan int) + go func() { + for _, i := range []int{1, 2, 3} { + ch1 <- i + } + close(ch1) + }() + go func() { + for _, i := range []int{4, 5, 6} { + ch2 <- i + } + close(ch2) + }() + go func() { + for _, i := range []int{7, 8, 9} { + ch3 <- i + } + close(ch3) + }() + out := Merge(ch1, ch2, ch3) + var tmp []int + for item := range out { + tmp = append(tmp, item) + } + sort.Ints(tmp) + assert.Equal(t, []int{1, 2, 3, 4, 5, 6, 7, 8, 9}, tmp) +} diff --git a/internal/fileutil/watcher.go b/internal/fileutil/watcher.go index d2acf7c2f..c1f6ee34e 100644 --- a/internal/fileutil/watcher.go +++ b/internal/fileutil/watcher.go @@ -4,9 +4,11 @@ import ( "context" "sync" - "github.com/rjeczalik/notify" + "github.com/fsnotify/fsnotify" "github.com/rs/zerolog" + "namespacelabs.dev/go-filenotify" + "github.com/pomerium/pomerium/internal/chanutil" "github.com/pomerium/pomerium/internal/log" "github.com/pomerium/pomerium/internal/signal" ) @@ -14,15 +16,18 @@ import ( // A Watcher watches files for changes. type Watcher struct { *signal.Signal - mu sync.Mutex - filePaths map[string]chan notify.EventInfo + + mu sync.Mutex + watching map[string]struct{} + eventWatcher filenotify.FileWatcher + pollingWatcher filenotify.FileWatcher } // NewWatcher creates a new Watcher. func NewWatcher() *Watcher { return &Watcher{ - Signal: signal.New(), - filePaths: map[string]chan notify.EventInfo{}, + Signal: signal.New(), + watching: make(map[string]struct{}), } } @@ -31,32 +36,27 @@ func (watcher *Watcher) Add(filePath string) { watcher.mu.Lock() defer watcher.mu.Unlock() - ctx := log.WithContext(context.TODO(), func(c zerolog.Context) zerolog.Context { + // already watching + if _, ok := watcher.watching[filePath]; ok { + return + } + + ctx := log.WithContext(context.Background(), func(c zerolog.Context) zerolog.Context { return c.Str("watch_file", filePath) }) + watcher.initLocked(ctx) - // already watching - if _, ok := watcher.filePaths[filePath]; ok { - return - } - - ch := make(chan notify.EventInfo, 1) - go func() { - for evt := range ch { - log.Info(ctx).Str("event", evt.Event().String()).Msg("filemgr: detected file change") - watcher.Signal.Broadcast(ctx) + if watcher.eventWatcher != nil { + if err := watcher.eventWatcher.Add(filePath); err != nil { + log.Error(ctx).Msg("fileutil/watcher: failed to watch file with event-based file watcher") } - }() - err := notify.Watch(filePath, ch, notify.All) - if err != nil { - log.Error(ctx).Err(err).Msg("filemgr: error watching file path") - notify.Stop(ch) - close(ch) - return } - log.Debug(ctx).Msg("filemgr: watching file for changes") - watcher.filePaths[filePath] = ch + if watcher.pollingWatcher != nil { + if err := watcher.pollingWatcher.Add(filePath); err != nil { + log.Error(ctx).Msg("fileutil/watcher: failed to watch file with polling-based file watcher") + } + } } // Clear removes all watches. @@ -64,9 +64,57 @@ func (watcher *Watcher) Clear() { watcher.mu.Lock() defer watcher.mu.Unlock() - for filePath, ch := range watcher.filePaths { - notify.Stop(ch) - close(ch) - delete(watcher.filePaths, filePath) + if w := watcher.eventWatcher; w != nil { + _ = watcher.pollingWatcher.Close() + watcher.eventWatcher = nil } + + if w := watcher.pollingWatcher; w != nil { + _ = watcher.pollingWatcher.Close() + watcher.pollingWatcher = nil + } + + watcher.watching = make(map[string]struct{}) +} + +func (watcher *Watcher) initLocked(ctx context.Context) { + if watcher.eventWatcher != nil || watcher.pollingWatcher != nil { + return + } + + if watcher.eventWatcher == nil { + var err error + watcher.eventWatcher, err = filenotify.NewEventWatcher() + if err != nil { + log.Error(ctx).Msg("fileutil/watcher: failed to create event-based file watcher") + } + } + if watcher.pollingWatcher == nil { + watcher.pollingWatcher = filenotify.NewPollingWatcher(nil) + } + + var errors <-chan error = watcher.pollingWatcher.Errors() //nolint + var events <-chan fsnotify.Event = watcher.pollingWatcher.Events() //nolint + + if watcher.eventWatcher != nil { + errors = chanutil.Merge(errors, watcher.eventWatcher.Errors()) + events = chanutil.Merge(events, watcher.eventWatcher.Events()) + } + + // log errors + go func() { + for err := range errors { + log.Error(ctx).Err(err).Msg("fileutil/watcher: file notification error") + } + }() + + // handle events + go func() { + for evts := range chanutil.Batch(events) { + for _, evt := range evts { + log.Info(ctx).Str("name", evt.Name).Str("op", evt.Op.String()).Msg("fileutil/watcher: file notification event") + } + watcher.Broadcast(ctx) + } + }() } diff --git a/internal/fileutil/watcher_test.go b/internal/fileutil/watcher_test.go index 1319ecc48..ec6017b9f 100644 --- a/internal/fileutil/watcher_test.go +++ b/internal/fileutil/watcher_test.go @@ -23,6 +23,7 @@ func TestWatcher(t *testing.T) { } w := NewWatcher() + defer w.Clear() w.Add(filepath.Join(tmpdir, "test1.txt")) ch := w.Bind() @@ -39,3 +40,50 @@ func TestWatcher(t *testing.T) { t.Error("expected change signal when file is modified") } } + +func TestWatcherSymlink(t *testing.T) { + t.Parallel() + + tmpdir := filepath.Join(os.TempDir(), uuid.New().String()) + err := os.MkdirAll(tmpdir, 0o755) + if !assert.NoError(t, err) { + return + } + t.Cleanup(func() { os.RemoveAll(tmpdir) }) + + err = os.WriteFile(filepath.Join(tmpdir, "test1.txt"), []byte{1, 2, 3, 4}, 0o666) + if !assert.NoError(t, err) { + return + } + + err = os.WriteFile(filepath.Join(tmpdir, "test2.txt"), []byte{5, 6, 7, 8}, 0o666) + if !assert.NoError(t, err) { + return + } + + assert.NoError(t, os.Symlink(filepath.Join(tmpdir, "test1.txt"), filepath.Join(tmpdir, "symlink1.txt"))) + + w := NewWatcher() + defer w.Clear() + w.Add(filepath.Join(tmpdir, "symlink1.txt")) + + ch := w.Bind() + t.Cleanup(func() { w.Unbind(ch) }) + + assert.NoError(t, os.WriteFile(filepath.Join(tmpdir, "test1.txt"), []byte{9, 10, 11}, 0o666)) + + select { + case <-ch: + case <-time.After(time.Second): + t.Error("expected change signal when underlying file is modified") + } + + assert.NoError(t, os.Symlink(filepath.Join(tmpdir, "test2.txt"), filepath.Join(tmpdir, "symlink2.txt"))) + assert.NoError(t, os.Rename(filepath.Join(tmpdir, "symlink2.txt"), filepath.Join(tmpdir, "symlink1.txt"))) + + select { + case <-ch: + case <-time.After(10 * time.Second): + t.Error("expected change signal when symlink is changed") + } +}