diff --git a/internal/fileutil/watcher_test.go b/internal/fileutil/watcher_test.go index a012ab514..e3001b9ed 100644 --- a/internal/fileutil/watcher_test.go +++ b/internal/fileutil/watcher_test.go @@ -25,18 +25,14 @@ func TestWatcher(t *testing.T) { w.Watch(context.Background(), []string{filepath.Join(tmpdir, "test1.txt")}) ch := w.Bind() - defer w.Unbind(ch) + t.Cleanup(func() { w.Unbind(ch) }) err = os.WriteFile(filepath.Join(tmpdir, "test1.txt"), []byte{5, 6, 7, 8}, 0o666) if !assert.NoError(t, err) { return } - select { - case <-ch: - case <-time.After(time.Second): - t.Error("expected change signal when file is modified") - } + expectChange(t, ch) } func TestWatcherSymlink(t *testing.T) { @@ -64,20 +60,12 @@ func TestWatcherSymlink(t *testing.T) { assert.NoError(t, os.WriteFile(filepath.Join(tmpdir, "test1.txt"), []byte{9, 10, 11}, 0o666)) - select { - case <-ch: - case <-time.After(time.Second): - t.Error("expected change signal when underlying file is modified") - } + expectChange(t, ch) assert.NoError(t, os.Symlink(filepath.Join(tmpdir, "test2.txt"), filepath.Join(tmpdir, "symlink2.txt"))) assert.NoError(t, os.Rename(filepath.Join(tmpdir, "symlink2.txt"), filepath.Join(tmpdir, "symlink1.txt"))) - select { - case <-ch: - case <-time.After(10 * time.Second): - t.Error("expected change signal when symlink is changed") - } + expectChange(t, ch) } func TestWatcher_FileRemoval(t *testing.T) { @@ -92,31 +80,33 @@ func TestWatcher_FileRemoval(t *testing.T) { w.Watch(context.Background(), []string{filepath.Join(tmpdir, "test1.txt")}) ch := w.Bind() - defer w.Unbind(ch) + t.Cleanup(func() { w.Unbind(ch) }) err = os.Remove(filepath.Join(tmpdir, "test1.txt")) require.NoError(t, err) - expectChange := func() { - cnt := 0 - loop: - for { - select { - case <-ch: - cnt++ - case <-time.After(time.Second): - break loop - } - } - if cnt == 0 { - t.Error("expected change signal") - } - } - - expectChange() + expectChange(t, ch) err = os.WriteFile(filepath.Join(tmpdir, "test1.txt"), []byte{5, 6, 7, 8}, 0o666) require.NoError(t, err) - expectChange() + expectChange(t, ch) +} + +func expectChange(t *testing.T, ch chan context.Context) { + t.Helper() + + cnt := 0 +loop: + for { + select { + case <-ch: + cnt++ + case <-time.After(time.Second): + break loop + } + } + if cnt == 0 { + t.Error("expected change signal") + } }