diff --git a/pkg/envoy/resource_monitor_linux.go b/pkg/envoy/resource_monitor_linux.go index 743df077c..3b126459b 100644 --- a/pkg/envoy/resource_monitor_linux.go +++ b/pkg/envoy/resource_monitor_linux.go @@ -125,23 +125,48 @@ const ( metricCgroupMemorySaturation = "cgroup_memory_saturation" ) +type ResourceMonitorOptions struct { + driver CgroupDriver +} + +type ResourceMonitorOption func(*ResourceMonitorOptions) + +func (o *ResourceMonitorOptions) apply(opts ...ResourceMonitorOption) { + for _, op := range opts { + op(o) + } +} + +// WithCgroupDriver overrides the cgroup driver used for the resource monitor. +// If unset, it will be chosen automatically. +func WithCgroupDriver(driver CgroupDriver) ResourceMonitorOption { + return func(o *ResourceMonitorOptions) { + o.driver = driver + } +} + // NewSharedResourceMonitor creates a new ResourceMonitor suitable for running // envoy in the same cgroup as the parent process. It reports the cgroup's // memory saturation to envoy as an injected resource. This allows envoy to // react to actual memory pressure in the cgroup, taking into account memory // usage from pomerium itself. -func NewSharedResourceMonitor(tempDir string) (ResourceMonitor, error) { - driver, err := SystemCgroupDriver() - if err != nil { - return nil, err +func NewSharedResourceMonitor(tempDir string, opts ...ResourceMonitorOption) (ResourceMonitor, error) { + options := ResourceMonitorOptions{} + options.apply(opts...) + if options.driver == nil { + var err error + options.driver, err = DetectCgroupDriver() + if err != nil { + return nil, err + } } recordActionThresholds() - selfCgroup, err := driver.CgroupForPid(os.Getpid()) + selfCgroup, err := options.driver.CgroupForPid(os.Getpid()) if err != nil { return nil, fmt.Errorf("failed to look up cgroup: %w", err) } - if err := driver.Validate(selfCgroup); err != nil { + if err := options.driver.Validate(selfCgroup); err != nil { return nil, fmt.Errorf("cgroup not valid for resource monitoring: %w", err) } @@ -150,9 +175,9 @@ func NewSharedResourceMonitor(tempDir string) (ResourceMonitor, error) { } s := &sharedResourceMonitor{ - driver: driver, - cgroup: selfCgroup, - tempDir: filepath.Join(tempDir, "resource_monitor"), + ResourceMonitorOptions: options, + cgroup: selfCgroup, + tempDir: filepath.Join(tempDir, "resource_monitor"), } if err := s.writeMetricFile(groupMemory, metricCgroupMemorySaturation, "0", 0o644); err != nil { @@ -162,7 +187,7 @@ func NewSharedResourceMonitor(tempDir string) (ResourceMonitor, error) { } type sharedResourceMonitor struct { - driver CgroupDriver + ResourceMonitorOptions cgroup string tempDir string } @@ -226,6 +251,12 @@ func (s *sharedResourceMonitor) ApplyBootstrapConfig(bootstrap *envoy_config_boo } } +var ( + monitorInitialTickDuration = 1 * time.Second + monitorMaxTickDuration = 5 * time.Second + monitorMinTickDuration = 250 * time.Millisecond +) + func (s *sharedResourceMonitor) Run(ctx context.Context, envoyPid int) error { envoyCgroup, err := s.driver.CgroupForPid(envoyPid) if err != nil { @@ -237,9 +268,14 @@ func (s *sharedResourceMonitor) Run(ctx context.Context, envoyPid int) error { log.Info(ctx).Str("service", "envoy").Str("cgroup", s.cgroup).Msg("starting resource monitor") limitWatcher := &memoryLimitWatcher{ - limitFilePath: s.driver.Path(s.cgroup, MemoryLimitPath), + limitFilePath: "/" + s.driver.Path(s.cgroup, MemoryLimitPath), } - if err := limitWatcher.Watch(ctx); err != nil { + lwCtx, lwCancel := context.WithCancel(ctx) + defer func() { + lwCancel() + limitWatcher.Wait() + 
}() + if err := limitWatcher.Watch(lwCtx); err != nil { return fmt.Errorf("failed to start watch on cgroup memory limit: %w", err) } @@ -256,10 +292,8 @@ func (s *sharedResourceMonitor) Run(ctx context.Context, envoyPid int) error { // taking disruptive actions for as long as possible. // the envoy default interval for the builtin heap monitor is 1s - initialTickDuration := 1 * time.Second - maxTickDuration := 5 * time.Second - minTickDuration := 250 * time.Millisecond - tick := time.NewTimer(initialTickDuration) + + tick := time.NewTimer(monitorInitialTickDuration) var lastValue string for { select { @@ -278,7 +312,7 @@ func (s *sharedResourceMonitor) Run(ctx context.Context, envoyPid int) error { } saturationStr := fmt.Sprintf("%.6f", saturation) - nextInterval := (maxTickDuration - (time.Duration(float64(maxTickDuration-minTickDuration) * saturation))). + nextInterval := (monitorMaxTickDuration - (time.Duration(float64(monitorMaxTickDuration-monitorMinTickDuration) * saturation))). Round(time.Millisecond) if saturationStr != lastValue { @@ -348,6 +382,7 @@ func (s *sharedResourceMonitor) writeMetricFile(group, name, data string, mode f } type cgroupV2Driver struct { + fs fs.FS root string } @@ -363,8 +398,8 @@ func (d *cgroupV2Driver) Path(cgroup string, kind CgroupFilePath) string { return "" } -func (*cgroupV2Driver) CgroupForPid(pid int) (string, error) { - data, err := os.ReadFile(fmt.Sprintf("/proc/%d/cgroup", pid)) +func (d *cgroupV2Driver) CgroupForPid(pid int) (string, error) { + data, err := fs.ReadFile(d.fs, fmt.Sprintf("proc/%d/cgroup", pid)) if err != nil { return "", err } @@ -373,7 +408,7 @@ func (*cgroupV2Driver) CgroupForPid(pid int) (string, error) { // MemoryUsage implements CgroupDriver. func (d *cgroupV2Driver) MemoryUsage(cgroup string) (uint64, error) { - current, err := os.ReadFile(d.Path(cgroup, MemoryUsagePath)) + current, err := fs.ReadFile(d.fs, d.Path(cgroup, MemoryUsagePath)) if err != nil { return 0, err } @@ -382,19 +417,20 @@ func (d *cgroupV2Driver) MemoryUsage(cgroup string) (uint64, error) { // MemoryLimit implements CgroupDriver. func (d *cgroupV2Driver) MemoryLimit(cgroup string) (uint64, error) { - max, err := os.ReadFile(d.Path(cgroup, MemoryLimitPath)) + data, err := fs.ReadFile(d.fs, d.Path(cgroup, MemoryLimitPath)) if err != nil { return 0, err } - if string(max) == "max" { + max := strings.TrimSpace(string(data)) + if max == "max" { return 0, nil } - return strconv.ParseUint(strings.TrimSpace(string(max)), 10, 64) + return strconv.ParseUint(max, 10, 64) } // Validate implements CgroupDriver. 
func (d *cgroupV2Driver) Validate(cgroup string) error { - if typ, err := os.ReadFile(filepath.Join(d.root, cgroup, "cgroup.type")); err != nil { + if typ, err := fs.ReadFile(d.fs, filepath.Join(d.root, cgroup, "cgroup.type")); err != nil { return err } else if strings.TrimSpace(string(typ)) != "domain" { return errors.New("not a domain cgroup") @@ -416,7 +452,7 @@ func (d *cgroupV2Driver) Validate(cgroup string) error { } func (d *cgroupV2Driver) enabledControllers(cgroup string) ([]string, error) { - data, err := os.ReadFile(filepath.Join(d.root, cgroup, "cgroup.controllers")) + data, err := fs.ReadFile(d.fs, filepath.Join(d.root, cgroup, "cgroup.controllers")) if err != nil { return nil, err } @@ -424,7 +460,7 @@ func (d *cgroupV2Driver) enabledControllers(cgroup string) ([]string, error) { } func (d *cgroupV2Driver) enabledSubtreeControllers(cgroup string) ([]string, error) { - data, err := os.ReadFile(filepath.Join(d.root, cgroup, "cgroup.subtree_control")) + data, err := fs.ReadFile(d.fs, filepath.Join(d.root, cgroup, "cgroup.subtree_control")) if err != nil { return nil, err } @@ -434,6 +470,7 @@ func (d *cgroupV2Driver) enabledSubtreeControllers(cgroup string) ([]string, err var _ CgroupDriver = (*cgroupV2Driver)(nil) type cgroupV1Driver struct { + fs fs.FS root string } @@ -450,7 +487,7 @@ func (d *cgroupV1Driver) Path(cgroup string, kind CgroupFilePath) string { } func (d *cgroupV1Driver) CgroupForPid(pid int) (string, error) { - data, err := os.ReadFile(fmt.Sprintf("/proc/%d/cgroup", pid)) + data, err := fs.ReadFile(d.fs, fmt.Sprintf("proc/%d/cgroup", pid)) if err != nil { return "", err } @@ -459,7 +496,7 @@ func (d *cgroupV1Driver) CgroupForPid(pid int) (string, error) { return "", err } - mountinfo, err := os.ReadFile(fmt.Sprintf("/proc/%d/mountinfo", pid)) + mountinfo, err := fs.ReadFile(d.fs, fmt.Sprintf("proc/%d/mountinfo", pid)) if err != nil { return "", err } @@ -469,11 +506,12 @@ func (d *cgroupV1Driver) CgroupForPid(pid int) (string, error) { if len(line) < 5 { continue } - // entries 3 and 4 contain the cgroup path and the mountpoint, respectively. + + // Entries 3 and 4 contain the root path and the mountpoint, respectively. // each resource will contain a separate mountpoint for the same path, so - // we can just pick the first one. - if line[3] == name { - mountpoint, err := filepath.Rel(d.root, filepath.Dir(line[4])) + // we can just pick one. + if line[4] == fmt.Sprintf("/%s/memory", d.root) { + mountpoint, err := filepath.Rel(line[3], name) if err != nil { return "", err } @@ -485,7 +523,7 @@ func (d *cgroupV1Driver) CgroupForPid(pid int) (string, error) { // MemoryUsage implements CgroupDriver. func (d *cgroupV1Driver) MemoryUsage(cgroup string) (uint64, error) { - current, err := os.ReadFile(d.Path(cgroup, MemoryUsagePath)) + current, err := fs.ReadFile(d.fs, d.Path(cgroup, MemoryUsagePath)) if err != nil { return 0, err } @@ -494,22 +532,23 @@ func (d *cgroupV1Driver) MemoryUsage(cgroup string) (uint64, error) { // MemoryLimit implements CgroupDriver. func (d *cgroupV1Driver) MemoryLimit(cgroup string) (uint64, error) { - max, err := os.ReadFile(d.Path(cgroup, MemoryLimitPath)) + data, err := fs.ReadFile(d.fs, d.Path(cgroup, MemoryLimitPath)) if err != nil { return 0, err } - if string(max) == "max" { + max := strings.TrimSpace(string(data)) + if max == "max" { return 0, nil } - return strconv.ParseUint(strings.TrimSpace(string(max)), 10, 64) + return strconv.ParseUint(max, 10, 64) } // Validate implements CgroupDriver. 
func (d *cgroupV1Driver) Validate(cgroup string) error { memoryPath := filepath.Join(d.root, "memory", cgroup) - info, err := os.Stat(memoryPath) + info, err := fs.Stat(d.fs, memoryPath) if err != nil { - if errors.Is(err, os.ErrNotExist) { + if errors.Is(err, fs.ErrNotExist) { return errors.New("memory controller not enabled") } return fmt.Errorf("failed to stat cgroup: %w", err) @@ -522,7 +561,7 @@ func (d *cgroupV1Driver) Validate(cgroup string) error { var _ CgroupDriver = (*cgroupV1Driver)(nil) -func SystemCgroupDriver() (CgroupDriver, error) { +func DetectCgroupDriver() (CgroupDriver, error) { const cgv2Magic = 0x63677270 fsType := func(path string) (int64, error) { @@ -538,6 +577,7 @@ func SystemCgroupDriver() (CgroupDriver, error) { return stat.Type, nil } } + osFs := os.DirFS("/") // fast path: cgroup2 only t, err := fsType("/sys/fs/cgroup") @@ -545,36 +585,18 @@ func SystemCgroupDriver() (CgroupDriver, error) { return nil, err } if t == cgv2Magic { - return &cgroupV2Driver{root: "/sys/fs/cgroup"}, nil + return &cgroupV2Driver{root: "sys/fs/cgroup", fs: osFs}, nil } - // find the unified mountpoint, or fall back to v1 - mounts, err := os.ReadFile("/proc/self/mounts") + // find the hybrid mountpoint, or fall back to v1 + mountpoint, isV2, err := findMountpoint(osFs) if err != nil { return nil, err } - scanner := bufio.NewScanner(bytes.NewReader(mounts)) - var cgv1Root string - for scanner.Scan() { - line := strings.Fields(scanner.Text()) - if len(line) < 3 { - continue - } - switch line[2] { - case "cgroup2": - return &cgroupV2Driver{root: line[1]}, nil - case "cgroup": - if cgv1Root == "" { - cgv1Root = filepath.Dir(line[1]) - } - } + if isV2 { + return &cgroupV2Driver{root: mountpoint, fs: osFs}, nil } - - if cgv1Root != "" { - return &cgroupV1Driver{root: cgv1Root}, nil - } - - return nil, errors.New("no cgroup mount found") + return &cgroupV1Driver{root: mountpoint, fs: osFs}, nil } func parseCgroupName(contents []byte) (string, error) { @@ -588,6 +610,31 @@ func parseCgroupName(contents []byte) (string, error) { return "", errors.New("cgroup not found") } +func findMountpoint(fsys fs.FS) (mountpoint string, isV2 bool, err error) { + mounts, err := fs.ReadFile(fsys, fmt.Sprintf("proc/%d/mountinfo", os.Getpid())) + if err != nil { + return "", false, err + } + scanner := bufio.NewScanner(bytes.NewReader(mounts)) + var cgv1Root string + for scanner.Scan() { + line := strings.Fields(scanner.Text()) + fsType := line[slices.Index(line, "-")+1] + switch fsType { + case "cgroup2": + return line[4][1:], true, nil + case "cgroup": + if cgv1Root == "" { + cgv1Root = filepath.Dir(line[4][1:]) + } + } + } + if cgv1Root == "" { + return "", false, errors.New("no cgroup mount found") + } + return cgv1Root, false, nil +} + func marshalAny(msg proto.Message) *anypb.Any { data := new(anypb.Any) _ = anypb.MarshalFrom(data, msg, proto.MarshalOptions{ @@ -600,23 +647,25 @@ func marshalAny(msg proto.Message) *anypb.Any { type memoryLimitWatcher struct { limitFilePath string - value atomic.Int64 + value atomic.Uint64 + + watches sync.WaitGroup } -func (w *memoryLimitWatcher) Value() int64 { +func (w *memoryLimitWatcher) Value() uint64 { return w.value.Load() } -func (w *memoryLimitWatcher) readValue() (int64, error) { +func (w *memoryLimitWatcher) readValue() (uint64, error) { data, err := os.ReadFile(w.limitFilePath) if err != nil { return 0, err } - if string(data) == "max" { - // no limit set + max := strings.TrimSpace(string(data)) + if max == "max" { return 0, nil } - return 
strconv.ParseInt(strings.TrimSpace(string(data)), 10, 64) + return strconv.ParseUint(max, 10, 64) } func (w *memoryLimitWatcher) Watch(ctx context.Context) error { @@ -624,7 +673,7 @@ func (w *memoryLimitWatcher) Watch(ctx context.Context) error { if err != nil { return err } - closeWatch := sync.OnceFunc(func() { + closeInotify := sync.OnceFunc(func() { log.Debug(ctx).Msg("stopping memory limit watcher") for { if err := syscall.Close(fd); !errors.Is(err, syscall.EINTR) { @@ -632,10 +681,19 @@ func (w *memoryLimitWatcher) Watch(ctx context.Context) error { } } }) - if _, err := syscall.InotifyAddWatch(fd, w.limitFilePath, syscall.IN_MODIFY); err != nil { - closeWatch() - return err + log.Debug(ctx).Str("file", w.limitFilePath).Msg("starting watch") + wd, err := syscall.InotifyAddWatch(fd, w.limitFilePath, syscall.IN_MODIFY) + if err != nil { + closeInotify() + return fmt.Errorf("failed to watch %s: %w", w.limitFilePath, err) } + w.watches.Add(1) + closeWatch := sync.OnceFunc(func() { + log.Debug(ctx).Str("file", w.limitFilePath).Msg("stopping watch") + syscall.InotifyRmWatch(fd, uint32(wd)) + closeInotify() + w.watches.Done() + }) // perform the initial read synchronously and only after setting up the watch v, err := w.readValue() @@ -644,21 +702,20 @@ func (w *memoryLimitWatcher) Watch(ctx context.Context) error { return err } w.value.Store(v) - log.Debug(ctx).Int64("bytes", v).Msg("current memory limit") + log.Debug(ctx).Uint64("bytes", v).Msg("current memory limit") context.AfterFunc(ctx, closeWatch) // to unblock syscall.Read below go func() { defer closeWatch() var buf [syscall.SizeofInotifyEvent]byte - for { + for ctx.Err() == nil { v, err := w.readValue() if err != nil { - return - } - if prev := w.value.Swap(v); prev != v { + log.Error(ctx).Err(err).Msg("error reading memory limit") + } else if prev := w.value.Swap(v); prev != v { log.Debug(ctx). - Int64("prev", prev). - Int64("current", v). + Uint64("prev", prev). + Uint64("current", v). Msg("memory limit updated") } _, err = syscall.Read(fd, buf[:]) @@ -673,3 +730,17 @@ func (w *memoryLimitWatcher) Watch(ctx context.Context) error { return nil } + +// Wait blocks until all watches have been closed. +// +// Example use: +// +// ctx, ca := context.WithCancel(context.Background()) +// w := &memoryLimitWatcher{...} +// w.Watch(ctx) +// ... 
+// ca() +// w.Wait() // blocks until the previous watch is closed +func (w *memoryLimitWatcher) Wait() { + w.watches.Wait() +} diff --git a/pkg/envoy/resource_monitor_test.go b/pkg/envoy/resource_monitor_test.go new file mode 100644 index 000000000..090d0df55 --- /dev/null +++ b/pkg/envoy/resource_monitor_test.go @@ -0,0 +1,787 @@ +//go:build linux + +package envoy + +import ( + "context" + "fmt" + "io/fs" + "maps" + "os" + "path" + "path/filepath" + "testing" + "testing/fstest" + "time" + + "github.com/pomerium/pomerium/config" + "github.com/pomerium/pomerium/config/envoyconfig" + "github.com/pomerium/pomerium/config/envoyconfig/filemgr" + "github.com/pomerium/pomerium/internal/testutil" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +var ( + file = func(data string, mode fs.FileMode) *fstest.MapFile { + // ensure the data always ends with a \n + if data != "" && data[len(data)-1] != '\n' { + data += "\n" + } + return &fstest.MapFile{Data: []byte(data), Mode: mode} + } + + v2Fs = fstest.MapFS{ + "sys/fs/cgroup/test/cgroup.type": file("domain", 0o644), + "sys/fs/cgroup/test/cgroup.controllers": file("memory", 0o444), + "sys/fs/cgroup/test/cgroup.subtree_control": file("", 0o644), + "sys/fs/cgroup/test/memory.current": file("100", 0o644), + "sys/fs/cgroup/test/memory.max": file("200", 0o644), + + "proc/1/cgroup": file("0::/test\n", 0o444), + "proc/2/cgroup": file("0::/test2 (deleted)\n", 0o444), + + "proc/1/mountinfo": file(` +24 30 0:22 / /proc rw,nosuid,nodev,noexec,relatime shared:5 - proc proc rw +25 30 0:23 / /sys rw,nosuid,nodev,noexec,relatime shared:6 - sysfs sys rw +33 25 0:28 / /sys/fs/cgroup rw,nosuid,nodev,noexec,relatime shared:9 - cgroup2 cgroup2 rw,nsdelegate,memory_recursiveprot +`[1:], 0o444), + } + + v1Fs = fstest.MapFS{ + "sys/fs/cgroup/memory/test/memory.usage_in_bytes": file("100", 0o644), + "sys/fs/cgroup/memory/test/memory.limit_in_bytes": file("200", 0o644), + + "proc/1/cgroup": file(` +1:memory:/test +0::/test +`[1:], 0o444), + "proc/1/mountinfo": file(` +26 31 0:24 / /sys rw,nosuid,nodev,noexec,relatime shared:7 - sysfs sysfs rw +27 31 0:5 / /proc rw,nosuid,nodev,noexec,relatime shared:14 - proc proc rw +31 1 252:1 / / rw,relatime shared:1 - ext4 /dev/vda1 rw,errors=remount-ro +35 26 0:29 / /sys/fs/cgroup ro,nosuid,nodev,noexec shared:9 - tmpfs tmpfs ro,mode=755 +40 35 0:34 / /sys/fs/cgroup/memory rw,nosuid,nodev,noexec,relatime shared:15 - cgroup cgroup rw,memory +`[1:], 0o444), + } + + v1ContainerFs = fstest.MapFS{ + "sys/fs/cgroup/memory/test/memory.usage_in_bytes": file("100", 0o644), + "sys/fs/cgroup/memory/test/memory.limit_in_bytes": file("200", 0o644), + + "proc/1/cgroup": file(` +1:memory:/test +0::/test +`[1:], 0o444), + "proc/1/mountinfo": file(` +1574 1573 0:138 / /proc rw,nosuid,nodev,noexec,relatime - proc proc rw +1578 1573 0:133 / /sys ro,nosuid,nodev,noexec,relatime - sysfs sysfs ro +1579 1578 0:141 / /sys/fs/cgroup rw,nosuid,nodev,noexec,relatime - tmpfs tmpfs rw,mode=755 +1586 1579 0:39 /test /sys/fs/cgroup/memory ro,nosuid,nodev,noexec,relatime master:20 - cgroup cgroup rw,memory +1311 1574 0:138 /sys /proc/sys ro,nosuid,nodev,noexec,relatime - proc proc rw +`[1:], 0o444), + } + + hybridFs = fstest.MapFS{ + "sys/fs/cgroup/memory/test/memory.usage_in_bytes": file("100", 0o644), + "sys/fs/cgroup/memory/test/memory.limit_in_bytes": file("200", 0o644), + "sys/fs/cgroup/unified/test/cgroup.type": file("domain", 0o644), + "sys/fs/cgroup/unified/test/cgroup.controllers": file("memory", 0o444), + 
"sys/fs/cgroup/unified/test/cgroup.subtree_control": file("", 0o644), + "sys/fs/cgroup/unified/test/memory.current": file("100", 0o644), + "sys/fs/cgroup/unified/test/memory.max": file("200", 0o644), + + "proc/1/cgroup": file(` +1:memory:/test +0::/test +`[1:], 0o444), + "proc/1/mountinfo": file(` +26 31 0:24 / /sys rw,nosuid,nodev,noexec,relatime shared:7 - sysfs sysfs rw +27 31 0:5 / /proc rw,nosuid,nodev,noexec,relatime shared:14 - proc proc rw +35 26 0:29 / /sys/fs/cgroup ro,nosuid,nodev,noexec shared:9 - tmpfs tmpfs ro,mode=755 +36 35 0:30 / /sys/fs/cgroup/unified rw,nosuid,nodev,noexec,relatime shared:10 - cgroup2 cgroup2 rw,nsdelegate +46 35 0:40 / /sys/fs/cgroup/memory rw,nosuid,nodev,noexec,relatime shared:21 - cgroup cgroup rw,memory +`[1:], 0o444), + } + + with = func(dest, src fstest.MapFS) fstest.MapFS { + dest = maps.Clone(dest) + for k, v := range src { + dest[k] = v + } + return dest + } + + without = func(fs fstest.MapFS, keys ...string) fstest.MapFS { + fs = maps.Clone(fs) + for _, k := range keys { + delete(fs, k) + } + return fs + } +) + +func TestCgroupV2Driver(t *testing.T) { + d := cgroupV2Driver{ + root: "sys/fs/cgroup", + fs: v2Fs, + } + t.Run("Path", func(t *testing.T) { + assert.Equal(t, "sys/fs/cgroup", d.Path("test", RootPath)) + assert.Equal(t, "sys/fs/cgroup/test/memory.current", d.Path("test", MemoryUsagePath)) + assert.Equal(t, "sys/fs/cgroup/test/memory.max", d.Path("test", MemoryLimitPath)) + assert.Equal(t, "", d.Path("test", CgroupFilePath(0xF00))) + }) + + t.Run("CgroupForPid", func(t *testing.T) { + cgroup, err := d.CgroupForPid(1) + assert.NoError(t, err) + assert.Equal(t, "/test", cgroup) + + cgroup, err = d.CgroupForPid(2) + assert.NoError(t, err) + assert.Equal(t, "/test2", cgroup) + + _, err = d.CgroupForPid(12345) + assert.Error(t, err) + }) + + t.Run("MemoryUsage", func(t *testing.T) { + cases := []struct { + fs fstest.MapFS + err string + usage uint64 + }{ + 0: { + fs: v2Fs, + usage: 100, + }, + 1: { + fs: with(v2Fs, fstest.MapFS{ + "sys/fs/cgroup/test/memory.current": file("invalid", 0o644), + }), + err: "strconv.ParseUint: parsing \"invalid\": invalid syntax", + }, + 2: { + fs: without(v2Fs, "sys/fs/cgroup/test/memory.current"), + err: "open sys/fs/cgroup/test/memory.current: file does not exist", + }, + } + + for i, c := range cases { + t.Run(fmt.Sprintf("case %d", i), func(t *testing.T) { + driver := cgroupV2Driver{ + root: "sys/fs/cgroup", + fs: c.fs, + } + usage, err := driver.MemoryUsage("test") + if c.err == "" { + assert.NoError(t, err) + assert.Equal(t, c.usage, usage) + } else { + assert.EqualError(t, err, c.err) + } + }) + } + }) + + t.Run("MemoryLimit", func(t *testing.T) { + cases := []struct { + fs fstest.MapFS + err string + limit uint64 + }{ + 0: { + fs: v2Fs, + limit: 200, + }, + 1: { + fs: with(v2Fs, fstest.MapFS{ + "sys/fs/cgroup/test/memory.max": file("max", 0o644), + }), + limit: 0, + }, + 2: { + fs: without(v2Fs, "sys/fs/cgroup/test/memory.max"), + err: "open sys/fs/cgroup/test/memory.max: file does not exist", + }, + 3: { + fs: with(v2Fs, fstest.MapFS{ + "sys/fs/cgroup/test/memory.max": file("invalid", 0o644), + }), + err: "strconv.ParseUint: parsing \"invalid\": invalid syntax", + }, + } + + for i, c := range cases { + t.Run(fmt.Sprintf("case %d", i), func(t *testing.T) { + driver := cgroupV2Driver{ + root: "sys/fs/cgroup", + fs: c.fs, + } + limit, err := driver.MemoryLimit("test") + if c.err == "" { + assert.NoError(t, err) + assert.Equal(t, c.limit, limit) + } else { + assert.EqualError(t, err, c.err) + } + }) + } + 
}) + + t.Run("Validate", func(t *testing.T) { + cases := []struct { + fs fstest.MapFS + root string // optional + err string + }{ + 0: {fs: v2Fs}, + 1: {fs: hybridFs, root: "sys/fs/cgroup/unified"}, + 2: { + fs: with(v2Fs, fstest.MapFS{ + "sys/fs/cgroup/test/cgroup.type": file("threaded", 0o644), + }), + err: "not a domain cgroup", + }, + 3: { + fs: with(v2Fs, fstest.MapFS{ + "sys/fs/cgroup/test/cgroup.subtree_control": file("cpu", 0o644), + }), + err: "not a leaf cgroup", + }, + 4: { + fs: with(v2Fs, fstest.MapFS{ + "sys/fs/cgroup/test/cgroup.controllers": file("cpu io", 0o444), + }), + err: "memory controller not enabled", + }, + 5: { + fs: without(v2Fs, "sys/fs/cgroup/test/cgroup.controllers"), + err: "open sys/fs/cgroup/test/cgroup.controllers: file does not exist", + }, + 6: { + fs: without(v2Fs, "sys/fs/cgroup/test/cgroup.type"), + err: "open sys/fs/cgroup/test/cgroup.type: file does not exist", + }, + 7: { + fs: without(v2Fs, "sys/fs/cgroup/test/cgroup.subtree_control"), + err: "open sys/fs/cgroup/test/cgroup.subtree_control: file does not exist", + }, + } + + for i, c := range cases { + t.Run(fmt.Sprintf("case %d", i), func(t *testing.T) { + driver := cgroupV2Driver{ + root: "sys/fs/cgroup", + fs: c.fs, + } + if c.root != "" { + driver.root = c.root + } + err := driver.Validate("test") + if c.err == "" { + assert.NoError(t, err) + } else { + assert.EqualError(t, err, c.err) + } + }) + } + }) +} + +func TestCgroupV1Driver(t *testing.T) { + d := cgroupV1Driver{ + root: "sys/fs/cgroup", + fs: v1Fs, + } + t.Run("Path", func(t *testing.T) { + assert.Equal(t, "sys/fs/cgroup", d.Path("test", RootPath)) + assert.Equal(t, "sys/fs/cgroup/memory/test/memory.usage_in_bytes", d.Path("test", MemoryUsagePath)) + assert.Equal(t, "sys/fs/cgroup/memory/test/memory.limit_in_bytes", d.Path("test", MemoryLimitPath)) + assert.Equal(t, "", d.Path("test", CgroupFilePath(0xF00))) + }) + + t.Run("CgroupForPid", func(t *testing.T) { + cgroup, err := d.CgroupForPid(1) + assert.NoError(t, err) + assert.Equal(t, "/test", cgroup) + + _, err = d.CgroupForPid(12345) + assert.Error(t, err) + }) + + t.Run("MemoryUsage", func(t *testing.T) { + cases := []struct { + fs fstest.MapFS + err string + usage uint64 + }{ + 0: { + fs: v1Fs, + usage: 100, + }, + 1: { + fs: with(v1Fs, fstest.MapFS{ + "sys/fs/cgroup/memory/test/memory.usage_in_bytes": file("invalid", 0o644), + }), + err: "strconv.ParseUint: parsing \"invalid\": invalid syntax", + }, + 2: { + fs: without(v1Fs, "sys/fs/cgroup/memory/test/memory.usage_in_bytes"), + err: "open sys/fs/cgroup/memory/test/memory.usage_in_bytes: file does not exist", + }, + } + + for i, c := range cases { + t.Run(fmt.Sprintf("case %d", i), func(t *testing.T) { + driver := cgroupV1Driver{ + root: "sys/fs/cgroup", + fs: c.fs, + } + usage, err := driver.MemoryUsage("test") + if c.err == "" { + assert.NoError(t, err) + assert.Equal(t, c.usage, usage) + } else { + assert.EqualError(t, err, c.err) + } + }) + } + }) + + t.Run("MemoryLimit", func(t *testing.T) { + cases := []struct { + fs fstest.MapFS + err string + limit uint64 + }{ + 0: { + fs: v1Fs, + limit: 200, + }, + 1: { + fs: with(v1Fs, fstest.MapFS{ + "sys/fs/cgroup/memory/test/memory.limit_in_bytes": file("max", 0o644), + }), + limit: 0, + }, + 2: { + fs: with(v1Fs, fstest.MapFS{ + "sys/fs/cgroup/memory/test/memory.limit_in_bytes": file("invalid", 0o644), + }), + err: "strconv.ParseUint: parsing \"invalid\": invalid syntax", + }, + 3: { + fs: without(v1Fs, "sys/fs/cgroup/memory/test/memory.limit_in_bytes"), + err: "open 
sys/fs/cgroup/memory/test/memory.limit_in_bytes: file does not exist", + }, + } + + for i, c := range cases { + t.Run(fmt.Sprintf("case %d", i), func(t *testing.T) { + driver := cgroupV1Driver{ + root: "sys/fs/cgroup", + fs: c.fs, + } + limit, err := driver.MemoryLimit("test") + if c.err == "" { + assert.NoError(t, err) + assert.Equal(t, c.limit, limit) + } else { + assert.EqualError(t, err, c.err) + } + }) + } + }) + + t.Run("Validate", func(t *testing.T) { + cases := []struct { + fs fstest.MapFS + err string + }{ + 0: {fs: v1Fs}, + 1: {fs: v1ContainerFs}, + 2: {fs: hybridFs}, + 3: { + fs: without(v1Fs, + "sys/fs/cgroup/memory/test/memory.usage_in_bytes", + "sys/fs/cgroup/memory/test/memory.limit_in_bytes", + ), + err: "memory controller not enabled", + }, + } + + for i, c := range cases { + t.Run(fmt.Sprintf("case %d", i), func(t *testing.T) { + driver := cgroupV1Driver{ + root: "sys/fs/cgroup", + fs: c.fs, + } + err := driver.Validate("test") + if c.err == "" { + assert.NoError(t, err) + } else { + assert.EqualError(t, err, c.err) + } + }) + } + }) + + t.Run("Container FS", func(t *testing.T) { + driver := cgroupV1Driver{ + root: "sys/fs/cgroup", + fs: v1ContainerFs, + } + cgroup, err := driver.CgroupForPid(1) + assert.NoError(t, err) + assert.Equal(t, "/", cgroup) + }) + + t.Run("Hybrid FS", func(t *testing.T) { + driver := cgroupV1Driver{ + root: "sys/fs/cgroup", + fs: hybridFs, + } + cgroup, err := driver.CgroupForPid(1) + assert.NoError(t, err) + assert.Equal(t, "/test", cgroup) + + driver2 := cgroupV2Driver{ + root: "sys/fs/cgroup/unified", + fs: hybridFs, + } + cgroup, err = driver2.CgroupForPid(1) + assert.NoError(t, err) + assert.Equal(t, "/test", cgroup) + }) +} + +func TestFindMountpoint(t *testing.T) { + withActualPid := func(fs fstest.MapFS) fstest.MapFS { + fs = maps.Clone(fs) + fs[fmt.Sprintf("proc/%d/cgroup", os.Getpid())] = fs["proc/1/cgroup"] + fs[fmt.Sprintf("proc/%d/mountinfo", os.Getpid())] = fs["proc/1/mountinfo"] + return fs + } + cases := []struct { + fsys fs.FS + + mountpoint string + isV2 bool + err string + }{ + 0: { + fsys: withActualPid(v2Fs), + mountpoint: "sys/fs/cgroup", + isV2: true, + }, + 1: { + fsys: withActualPid(v1Fs), + mountpoint: "sys/fs/cgroup", + isV2: false, + }, + 2: { + fsys: withActualPid(hybridFs), + mountpoint: "sys/fs/cgroup/unified", + isV2: true, + }, + } + + for i, c := range cases { + t.Run(fmt.Sprintf("case %d", i), func(t *testing.T) { + mountpoint, isV2, err := findMountpoint(c.fsys) + if c.err == "" { + assert.NoError(t, err) + assert.Equal(t, c.mountpoint, mountpoint) + assert.Equal(t, c.isV2, isV2) + } else { + assert.EqualError(t, err, c.err) + } + }) + } +} + +type hybridTestFS struct { + base fstest.MapFS + tempDir string +} + +var _ fs.FS = (*hybridTestFS)(nil) + +func (fs *hybridTestFS) Open(name string) (fs.File, error) { + switch base := path.Base(name); base { + case "memory.current", "memory.max": + return os.Open(filepath.Join(fs.tempDir, ".fs", base)) + } + return fs.base.Open(name) +} + +func (fs *hybridTestFS) ReadFile(name string) ([]byte, error) { + switch base := path.Base(name); base { + case "memory.current", "memory.max": + return os.ReadFile(filepath.Join(fs.tempDir, ".fs", base)) + } + return fs.base.ReadFile(name) +} + +func (fs *hybridTestFS) Stat(name string) (fs.FileInfo, error) { + switch base := path.Base(name); base { + case "memory.current", "memory.max": + return os.Stat(filepath.Join(fs.tempDir, ".fs", base)) + } + return fs.base.Stat(name) +} + +type pathOverrideDriver struct { + CgroupDriver + 
overrides map[CgroupFilePath]string +} + +var _ CgroupDriver = (*pathOverrideDriver)(nil) + +func (d *pathOverrideDriver) Path(name string, path CgroupFilePath) string { + if override, ok := d.overrides[path]; ok { + return override + } + return d.CgroupDriver.Path(name, path) +} + +func TestSharedResourceMonitor(t *testing.T) { + // set shorter intervals for testing + var prevInitial, prevMin, prevMax time.Duration + monitorInitialTickDuration, prevInitial = 0, monitorInitialTickDuration + monitorMaxTickDuration, prevMax = 100*time.Millisecond, monitorMaxTickDuration + monitorMinTickDuration, prevMin = 10*time.Millisecond, monitorMinTickDuration + t.Cleanup(func() { + monitorInitialTickDuration = prevInitial + monitorMaxTickDuration = prevMax + monitorMinTickDuration = prevMin + }) + + testEnvoyPid := 99 + tempDir := t.TempDir() + require.NoError(t, os.Mkdir(filepath.Join(tempDir, ".fs"), 0o777)) + + testMemoryCurrentPath := filepath.Join(tempDir, ".fs/memory.current") + testMemoryMaxPath := filepath.Join(tempDir, ".fs/memory.max") + + updateMemoryCurrent := func(value string) { + t.Log("updating memory.current to", value) + f, err := os.OpenFile(testMemoryCurrentPath, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0o644) + require.NoError(t, err) + f.WriteString(value) + require.NoError(t, f.Close()) + } + + updateMemoryMax := func(value string) { + t.Log("updating memory.max to", value) + f, err := os.OpenFile(testMemoryMaxPath, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0o644) + require.NoError(t, err) + f.WriteString(value) + require.NoError(t, f.Close()) + } + + updateMemoryCurrent("100") + updateMemoryMax("200") + + driver := &pathOverrideDriver{ + CgroupDriver: &cgroupV2Driver{ + root: "sys/fs/cgroup", + fs: &hybridTestFS{ + base: with(v2Fs, fstest.MapFS{ + fmt.Sprintf("proc/%d/cgroup", os.Getpid()): v2Fs["proc/1/cgroup"], + fmt.Sprintf("proc/%d/mountinfo", os.Getpid()): v2Fs["proc/1/mountinfo"], + fmt.Sprintf("proc/%d/cgroup", testEnvoyPid): v2Fs["proc/1/cgroup"], + fmt.Sprintf("proc/%d/mountinfo", testEnvoyPid): v2Fs["proc/1/mountinfo"], + }), + tempDir: tempDir, + }, + }, + overrides: map[CgroupFilePath]string{ + MemoryUsagePath: testMemoryCurrentPath, + MemoryLimitPath: testMemoryMaxPath, + }, + } + + monitor, err := NewSharedResourceMonitor(tempDir, WithCgroupDriver(driver)) + require.NoError(t, err) + + readMemorySaturation := func(t assert.TestingT) string { + f, err := os.ReadFile(filepath.Join(tempDir, "resource_monitor/memory/cgroup_memory_saturation")) + assert.NoError(t, err) + return string(f) + } + + assert.Equal(t, "0", readMemorySaturation(t)) + + ctx, ca := context.WithCancel(context.Background()) + + errC := make(chan error) + go func() { + defer close(errC) + errC <- monitor.Run(ctx, testEnvoyPid) + }() + + timeout := 1 * time.Second + interval := 10 * time.Millisecond + assert.EventuallyWithT(t, func(c *assert.CollectT) { + assert.Equal(c, "0.500000", readMemorySaturation(c)) + }, timeout, interval) + + updateMemoryCurrent("150") + assert.EventuallyWithT(t, func(c *assert.CollectT) { + assert.Equal(c, "0.750000", readMemorySaturation(c)) + }, timeout, interval) + + updateMemoryMax("300") + assert.EventuallyWithT(t, func(c *assert.CollectT) { + assert.Equal(c, "0.500000", readMemorySaturation(c)) + }, timeout, interval) + + updateMemoryMax("max") + assert.EventuallyWithT(t, func(c *assert.CollectT) { + assert.Equal(c, "0.000000", readMemorySaturation(c)) + }, timeout, interval) + + updateMemoryMax("150") + assert.EventuallyWithT(t, func(c *assert.CollectT) { + assert.Equal(c, 
"1.000000", readMemorySaturation(c)) + }, timeout, interval) + + ca() + assert.ErrorIs(t, <-errC, context.Canceled) +} + +func TestBootstrapConfig(t *testing.T) { + b := envoyconfig.New("localhost:1111", "localhost:2222", "localhost:3333", filemgr.NewManager(), nil) + testEnvoyPid := 99 + tempDir := t.TempDir() + monitor, err := NewSharedResourceMonitor(tempDir, WithCgroupDriver(&cgroupV2Driver{ + root: "sys/fs/cgroup", + fs: &hybridTestFS{ + base: with(v2Fs, fstest.MapFS{ + fmt.Sprintf("proc/%d/cgroup", os.Getpid()): v2Fs["proc/1/cgroup"], + fmt.Sprintf("proc/%d/mountinfo", os.Getpid()): v2Fs["proc/1/mountinfo"], + fmt.Sprintf("proc/%d/cgroup", testEnvoyPid): v2Fs["proc/1/cgroup"], + fmt.Sprintf("proc/%d/mountinfo", testEnvoyPid): v2Fs["proc/1/mountinfo"], + }), + tempDir: tempDir, + }, + })) + require.NoError(t, err) + + bootstrap, err := b.BuildBootstrap(context.Background(), &config.Config{ + Options: &config.Options{ + EnvoyAdminAddress: "localhost:9901", + }, + }, false) + assert.NoError(t, err) + + monitor.ApplyBootstrapConfig(bootstrap) + + testutil.AssertProtoJSONEqual(t, fmt.Sprintf(` + { + "actions": [ + { + "name": "envoy.overload_actions.shrink_heap", + "triggers": [ + { + "name": "envoy.resource_monitors.injected_resource", + "threshold": { + "value": 0.9 + } + } + ] + }, + { + "name": "envoy.overload_actions.reduce_timeouts", + "triggers": [ + { + "name": "envoy.resource_monitors.injected_resource", + "scaled": { + "saturationThreshold": 0.95, + "scalingThreshold": 0.85 + } + } + ], + "typedConfig": { + "@type": "type.googleapis.com/envoy.config.overload.v3.ScaleTimersOverloadActionConfig", + "timerScaleFactors": [ + { + "minScale": { + "value": 50 + }, + "timer": "HTTP_DOWNSTREAM_CONNECTION_IDLE" + } + ] + } + }, + { + "name": "envoy.overload_actions.reset_high_memory_stream", + "triggers": [ + { + "name": "envoy.resource_monitors.injected_resource", + "scaled": { + "saturationThreshold": 0.98, + "scalingThreshold": 0.9 + } + } + ] + }, + { + "name": "envoy.overload_actions.stop_accepting_connections", + "triggers": [ + { + "name": "envoy.resource_monitors.injected_resource", + "threshold": { + "value": 0.95 + } + } + ] + }, + { + "name": "envoy.overload_actions.disable_http_keepalive", + "triggers": [ + { + "name": "envoy.resource_monitors.injected_resource", + "threshold": { + "value": 0.98 + } + } + ] + }, + { + "name": "envoy.overload_actions.stop_accepting_requests", + "triggers": [ + { + "name": "envoy.resource_monitors.injected_resource", + "threshold": { + "value": 0.99 + } + } + ] + } + ], + "bufferFactoryConfig": { + "minimumAccountToTrackPowerOfTwo": 20 + }, + "resourceMonitors": [ + { + "name": "envoy.resource_monitors.global_downstream_max_connections", + "typedConfig": { + "@type": "type.googleapis.com/envoy.extensions.resource_monitors.downstream_connections.v3.DownstreamConnectionsConfig", + "maxActiveDownstreamConnections": "50000" + } + }, + { + "name": "envoy.resource_monitors.injected_resource", + "typedConfig": { + "@type": "type.googleapis.com/envoy.extensions.resource_monitors.injected_resource.v3.InjectedResourceConfig", + "filename": "%s/resource_monitor/memory/cgroup_memory_saturation" + } + } + ] + } + `, tempDir), bootstrap.OverloadManager) +}