Add unit tests

Joe Kralicky 2024-05-10 15:02:50 -04:00
parent 560d2b7ebc
commit 8444bd26b8
2 changed files with 938 additions and 80 deletions

@@ -125,23 +125,48 @@ const (
metricCgroupMemorySaturation = "cgroup_memory_saturation"
)
type ResourceMonitorOptions struct {
driver CgroupDriver
}
type ResourceMonitorOption func(*ResourceMonitorOptions)
func (o *ResourceMonitorOptions) apply(opts ...ResourceMonitorOption) {
for _, op := range opts {
op(o)
}
}
// WithCgroupDriver overrides the cgroup driver used for the resource monitor.
// If unset, it will be chosen automatically.
func WithCgroupDriver(driver CgroupDriver) ResourceMonitorOption {
return func(o *ResourceMonitorOptions) {
o.driver = driver
}
}
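
This functional-options plumbing exists so the new tests can inject a driver backed by a fake filesystem instead of auto-detecting the host's cgroup setup. A minimal sketch of the intended use (the test name and fake setup are illustrative assumptions, not the actual tests from this commit; imports are `testing` and `testing/fstest`):

```go
func TestWithCgroupDriver(t *testing.T) {
	var opts ResourceMonitorOptions
	opts.apply(WithCgroupDriver(&cgroupV2Driver{
		root: "sys/fs/cgroup",
		fs:   fstest.MapFS{}, // stands in for the host filesystem
	}))
	if opts.driver == nil {
		t.Fatal("expected driver to be set")
	}
}
```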
// NewSharedResourceMonitor creates a new ResourceMonitor suitable for running
// envoy in the same cgroup as the parent process. It reports the cgroup's
// memory saturation to envoy as an injected resource. This allows envoy to
// react to actual memory pressure in the cgroup, taking into account memory
// usage from pomerium itself.
func NewSharedResourceMonitor(tempDir string) (ResourceMonitor, error) {
driver, err := SystemCgroupDriver()
if err != nil {
return nil, err
func NewSharedResourceMonitor(tempDir string, opts ...ResourceMonitorOption) (ResourceMonitor, error) {
options := ResourceMonitorOptions{}
options.apply(opts...)
if options.driver == nil {
var err error
options.driver, err = DetectCgroupDriver()
if err != nil {
return nil, err
}
}
recordActionThresholds()
selfCgroup, err := driver.CgroupForPid(os.Getpid())
selfCgroup, err := options.driver.CgroupForPid(os.Getpid())
if err != nil {
return nil, fmt.Errorf("failed to look up cgroup: %w", err)
}
if err := driver.Validate(selfCgroup); err != nil {
if err := options.driver.Validate(selfCgroup); err != nil {
return nil, fmt.Errorf("cgroup not valid for resource monitoring: %w", err)
}
@@ -150,9 +175,9 @@ func NewSharedResourceMonitor(tempDir string) (ResourceMonitor, error) {
}
s := &sharedResourceMonitor{
driver: driver,
cgroup: selfCgroup,
tempDir: filepath.Join(tempDir, "resource_monitor"),
ResourceMonitorOptions: options,
cgroup: selfCgroup,
tempDir: filepath.Join(tempDir, "resource_monitor"),
}
if err := s.writeMetricFile(groupMemory, metricCgroupMemorySaturation, "0", 0o644); err != nil {
@@ -162,7 +187,7 @@ func NewSharedResourceMonitor(tempDir string) (ResourceMonitor, error) {
}
type sharedResourceMonitor struct {
driver CgroupDriver
ResourceMonitorOptions
cgroup string
tempDir string
}
@@ -226,6 +251,12 @@ func (s *sharedResourceMonitor) ApplyBootstrapConfig(bootstrap *envoy_config_boo
}
}
var (
monitorInitialTickDuration = 1 * time.Second
monitorMaxTickDuration = 5 * time.Second
monitorMinTickDuration = 250 * time.Millisecond
)
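
Hoisting the tick durations from locals in Run to package variables presumably lets the tests shorten the polling loop. A hedged sketch of how a test might exploit that (illustrative, not the commit's actual tests):

```go
func TestMonitorTicksQuickly(t *testing.T) {
	old := monitorMaxTickDuration
	monitorMaxTickDuration = 10 * time.Millisecond
	t.Cleanup(func() { monitorMaxTickDuration = old }) // restore for other tests
	// ... drive Run with a context that cancels after a few ticks ...
}
```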
func (s *sharedResourceMonitor) Run(ctx context.Context, envoyPid int) error {
envoyCgroup, err := s.driver.CgroupForPid(envoyPid)
if err != nil {
@@ -237,9 +268,14 @@ func (s *sharedResourceMonitor) Run(ctx context.Context, envoyPid int) error {
log.Info(ctx).Str("service", "envoy").Str("cgroup", s.cgroup).Msg("starting resource monitor")
limitWatcher := &memoryLimitWatcher{
limitFilePath: s.driver.Path(s.cgroup, MemoryLimitPath),
limitFilePath: "/" + s.driver.Path(s.cgroup, MemoryLimitPath),
}
if err := limitWatcher.Watch(ctx); err != nil {
lwCtx, lwCancel := context.WithCancel(ctx)
defer func() {
lwCancel()
limitWatcher.Wait()
}()
if err := limitWatcher.Watch(lwCtx); err != nil {
return fmt.Errorf("failed to start watch on cgroup memory limit: %w", err)
}
@@ -256,10 +292,8 @@ func (s *sharedResourceMonitor) Run(ctx context.Context, envoyPid int) error {
// taking disruptive actions for as long as possible.
// the envoy default interval for the builtin heap monitor is 1s
initialTickDuration := 1 * time.Second
maxTickDuration := 5 * time.Second
minTickDuration := 250 * time.Millisecond
tick := time.NewTimer(initialTickDuration)
tick := time.NewTimer(monitorInitialTickDuration)
var lastValue string
for {
select {
@@ -278,7 +312,7 @@ func (s *sharedResourceMonitor) Run(ctx context.Context, envoyPid int) error {
}
saturationStr := fmt.Sprintf("%.6f", saturation)
nextInterval := (maxTickDuration - (time.Duration(float64(maxTickDuration-minTickDuration) * saturation))).
nextInterval := (monitorMaxTickDuration - (time.Duration(float64(monitorMaxTickDuration-monitorMinTickDuration) * saturation))).
Round(time.Millisecond)
if saturationStr != lastValue {
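
The interval scales linearly between the two bounds: at zero saturation the monitor polls every monitorMaxTickDuration (5s), and at full saturation every monitorMinTickDuration (250ms). A standalone sketch of the same arithmetic:

```go
package main

import (
	"fmt"
	"time"
)

func main() {
	const (
		maxTick = 5 * time.Second
		minTick = 250 * time.Millisecond
	)
	for _, saturation := range []float64{0, 0.5, 1} {
		next := (maxTick - time.Duration(float64(maxTick-minTick)*saturation)).
			Round(time.Millisecond)
		fmt.Printf("saturation=%.1f next=%v\n", saturation, next)
	}
	// saturation=0.0 next=5s
	// saturation=0.5 next=2.625s
	// saturation=1.0 next=250ms
}
```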
@@ -348,6 +382,7 @@ func (s *sharedResourceMonitor) writeMetricFile(group, name, data string, mode f
}
type cgroupV2Driver struct {
fs fs.FS
root string
}
@@ -363,8 +398,8 @@ func (d *cgroupV2Driver) Path(cgroup string, kind CgroupFilePath) string {
return ""
}
func (*cgroupV2Driver) CgroupForPid(pid int) (string, error) {
data, err := os.ReadFile(fmt.Sprintf("/proc/%d/cgroup", pid))
func (d *cgroupV2Driver) CgroupForPid(pid int) (string, error) {
data, err := fs.ReadFile(d.fs, fmt.Sprintf("proc/%d/cgroup", pid))
if err != nil {
return "", err
}
@@ -373,7 +408,7 @@ func (*cgroupV2Driver) CgroupForPid(pid int) (string, error) {
// MemoryUsage implements CgroupDriver.
func (d *cgroupV2Driver) MemoryUsage(cgroup string) (uint64, error) {
current, err := os.ReadFile(d.Path(cgroup, MemoryUsagePath))
current, err := fs.ReadFile(d.fs, d.Path(cgroup, MemoryUsagePath))
if err != nil {
return 0, err
}
@@ -382,19 +417,20 @@ func (d *cgroupV2Driver) MemoryUsage(cgroup string) (uint64, error) {
// MemoryLimit implements CgroupDriver.
func (d *cgroupV2Driver) MemoryLimit(cgroup string) (uint64, error) {
max, err := os.ReadFile(d.Path(cgroup, MemoryLimitPath))
data, err := fs.ReadFile(d.fs, d.Path(cgroup, MemoryLimitPath))
if err != nil {
return 0, err
}
if string(max) == "max" {
max := strings.TrimSpace(string(data))
if max == "max" {
return 0, nil
}
return strconv.ParseUint(strings.TrimSpace(string(max)), 10, 64)
return strconv.ParseUint(max, 10, 64)
}
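
Reading through an injected fs.FS instead of os.ReadFile is what makes the drivers unit-testable: fstest.MapFS can stand in for /sys/fs/cgroup. Note that the paths lose their leading slash because fs.FS paths must be unrooted. A hedged sketch, assuming Path resolves to root/<cgroup>/memory.max for the v2 driver (not the commit's actual tests):

```go
func TestMemoryLimitMax(t *testing.T) {
	d := &cgroupV2Driver{
		root: "sys/fs/cgroup",
		fs: fstest.MapFS{
			"sys/fs/cgroup/test/memory.max": &fstest.MapFile{Data: []byte("max\n")},
		},
	}
	limit, err := d.MemoryLimit("test")
	if err != nil {
		t.Fatal(err)
	}
	if limit != 0 {
		t.Errorf(`expected "max" to map to 0 (no limit), got %d`, limit)
	}
}
```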
// Validate implements CgroupDriver.
func (d *cgroupV2Driver) Validate(cgroup string) error {
if typ, err := os.ReadFile(filepath.Join(d.root, cgroup, "cgroup.type")); err != nil {
if typ, err := fs.ReadFile(d.fs, filepath.Join(d.root, cgroup, "cgroup.type")); err != nil {
return err
} else if strings.TrimSpace(string(typ)) != "domain" {
return errors.New("not a domain cgroup")
@@ -416,7 +452,7 @@ func (d *cgroupV2Driver) Validate(cgroup string) error {
}
func (d *cgroupV2Driver) enabledControllers(cgroup string) ([]string, error) {
data, err := os.ReadFile(filepath.Join(d.root, cgroup, "cgroup.controllers"))
data, err := fs.ReadFile(d.fs, filepath.Join(d.root, cgroup, "cgroup.controllers"))
if err != nil {
return nil, err
}
@@ -424,7 +460,7 @@ func (d *cgroupV2Driver) enabledControllers(cgroup string) ([]string, error) {
}
func (d *cgroupV2Driver) enabledSubtreeControllers(cgroup string) ([]string, error) {
data, err := os.ReadFile(filepath.Join(d.root, cgroup, "cgroup.subtree_control"))
data, err := fs.ReadFile(d.fs, filepath.Join(d.root, cgroup, "cgroup.subtree_control"))
if err != nil {
return nil, err
}
@@ -434,6 +470,7 @@ func (d *cgroupV2Driver) enabledSubtreeControllers(cgroup string) ([]string, err
var _ CgroupDriver = (*cgroupV2Driver)(nil)
type cgroupV1Driver struct {
fs fs.FS
root string
}
@@ -450,7 +487,7 @@ func (d *cgroupV1Driver) Path(cgroup string, kind CgroupFilePath) string {
}
func (d *cgroupV1Driver) CgroupForPid(pid int) (string, error) {
data, err := os.ReadFile(fmt.Sprintf("/proc/%d/cgroup", pid))
data, err := fs.ReadFile(d.fs, fmt.Sprintf("proc/%d/cgroup", pid))
if err != nil {
return "", err
}
@@ -459,7 +496,7 @@ func (d *cgroupV1Driver) CgroupForPid(pid int) (string, error) {
return "", err
}
mountinfo, err := os.ReadFile(fmt.Sprintf("/proc/%d/mountinfo", pid))
mountinfo, err := fs.ReadFile(d.fs, fmt.Sprintf("proc/%d/mountinfo", pid))
if err != nil {
return "", err
}
@@ -469,11 +506,12 @@ func (d *cgroupV1Driver) CgroupForPid(pid int) (string, error) {
if len(line) < 5 {
continue
}
// entries 3 and 4 contain the cgroup path and the mountpoint, respectively.
// Entries 3 and 4 contain the root path and the mountpoint, respectively.
// each resource will contain a separate mountpoint for the same path, so
// we can just pick the first one.
if line[3] == name {
mountpoint, err := filepath.Rel(d.root, filepath.Dir(line[4]))
// we can just pick one.
if line[4] == fmt.Sprintf("/%s/memory", d.root) {
mountpoint, err := filepath.Rel(line[3], name)
if err != nil {
return "", err
}
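
The matching logic above hinges on the mountinfo layout documented in proc(5): field 3 (zero-indexed) is the root of the mount within its filesystem, and field 4 is the mount point. A standalone illustration with a typical v1 memory-controller entry (example line paraphrased from proc(5), not taken from this commit):

```go
package main

import (
	"fmt"
	"strings"
)

func main() {
	entry := "39 30 0:34 / /sys/fs/cgroup/memory rw,nosuid shared:18 - cgroup cgroup rw,memory"
	line := strings.Fields(entry)
	fmt.Println(line[3]) // "/" (root of the mount inside the hierarchy)
	fmt.Println(line[4]) // "/sys/fs/cgroup/memory" (mount point, matched against "/<root>/memory")
}
```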
@@ -485,7 +523,7 @@ func (d *cgroupV1Driver) CgroupForPid(pid int) (string, error) {
// MemoryUsage implements CgroupDriver.
func (d *cgroupV1Driver) MemoryUsage(cgroup string) (uint64, error) {
current, err := os.ReadFile(d.Path(cgroup, MemoryUsagePath))
current, err := fs.ReadFile(d.fs, d.Path(cgroup, MemoryUsagePath))
if err != nil {
return 0, err
}
@@ -494,22 +532,23 @@ func (d *cgroupV1Driver) MemoryUsage(cgroup string) (uint64, error) {
// MemoryLimit implements CgroupDriver.
func (d *cgroupV1Driver) MemoryLimit(cgroup string) (uint64, error) {
max, err := os.ReadFile(d.Path(cgroup, MemoryLimitPath))
data, err := fs.ReadFile(d.fs, d.Path(cgroup, MemoryLimitPath))
if err != nil {
return 0, err
}
if string(max) == "max" {
max := strings.TrimSpace(string(data))
if max == "max" {
return 0, nil
}
return strconv.ParseUint(strings.TrimSpace(string(max)), 10, 64)
return strconv.ParseUint(max, 10, 64)
}
// Validate implements CgroupDriver.
func (d *cgroupV1Driver) Validate(cgroup string) error {
memoryPath := filepath.Join(d.root, "memory", cgroup)
info, err := os.Stat(memoryPath)
info, err := fs.Stat(d.fs, memoryPath)
if err != nil {
if errors.Is(err, os.ErrNotExist) {
if errors.Is(err, fs.ErrNotExist) {
return errors.New("memory controller not enabled")
}
return fmt.Errorf("failed to stat cgroup: %w", err)
@@ -522,7 +561,7 @@ func (d *cgroupV1Driver) Validate(cgroup string) error {
var _ CgroupDriver = (*cgroupV1Driver)(nil)
func SystemCgroupDriver() (CgroupDriver, error) {
func DetectCgroupDriver() (CgroupDriver, error) {
const cgv2Magic = 0x63677270
fsType := func(path string) (int64, error) {
@@ -538,6 +577,7 @@ func SystemCgroupDriver() (CgroupDriver, error) {
return stat.Type, nil
}
}
osFs := os.DirFS("/")
// fast path: cgroup2 only
t, err := fsType("/sys/fs/cgroup")
@@ -545,36 +585,18 @@ func SystemCgroupDriver() (CgroupDriver, error) {
return nil, err
}
if t == cgv2Magic {
return &cgroupV2Driver{root: "/sys/fs/cgroup"}, nil
return &cgroupV2Driver{root: "sys/fs/cgroup", fs: osFs}, nil
}
// find the unified mountpoint, or fall back to v1
mounts, err := os.ReadFile("/proc/self/mounts")
// find the hybrid mountpoint, or fall back to v1
mountpoint, isV2, err := findMountpoint(osFs)
if err != nil {
return nil, err
}
scanner := bufio.NewScanner(bytes.NewReader(mounts))
var cgv1Root string
for scanner.Scan() {
line := strings.Fields(scanner.Text())
if len(line) < 3 {
continue
}
switch line[2] {
case "cgroup2":
return &cgroupV2Driver{root: line[1]}, nil
case "cgroup":
if cgv1Root == "" {
cgv1Root = filepath.Dir(line[1])
}
}
if isV2 {
return &cgroupV2Driver{root: mountpoint, fs: osFs}, nil
}
if cgv1Root != "" {
return &cgroupV1Driver{root: cgv1Root}, nil
}
return nil, errors.New("no cgroup mount found")
return &cgroupV1Driver{root: mountpoint, fs: osFs}, nil
}
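
The fast path checks the filesystem magic of /sys/fs/cgroup: 0x63677270 is CGROUP2_SUPER_MAGIC from linux/magic.h (the bytes spell "cgrp" in ASCII). A minimal standalone sketch of that detection, assuming Linux:

```go
package main

import (
	"fmt"
	"syscall"
)

func main() {
	const cgv2Magic = 0x63677270 // CGROUP2_SUPER_MAGIC, "cgrp" in ASCII
	var stat syscall.Statfs_t
	if err := syscall.Statfs("/sys/fs/cgroup", &stat); err != nil {
		panic(err)
	}
	fmt.Println("unified cgroup2 hierarchy:", stat.Type == cgv2Magic)
}
```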
func parseCgroupName(contents []byte) (string, error) {
@@ -588,6 +610,31 @@ func parseCgroupName(contents []byte) (string, error) {
return "", errors.New("cgroup not found")
}
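
For context, /proc/<pid>/cgroup on a v2 host contains a single unified entry of the form 0::<path>, while v1 hosts list one <id>:<controllers>:<path> entry per hierarchy. A sketch of parsing the v2 form (this illustrates the file format, not necessarily the commit's parser):

```go
package main

import (
	"fmt"
	"strings"
)

func main() {
	contents := "0::/user.slice/session-1.scope\n" // typical v2 /proc/self/cgroup
	for _, line := range strings.Split(strings.TrimSpace(contents), "\n") {
		parts := strings.SplitN(line, ":", 3)
		if len(parts) == 3 && parts[0] == "0" && parts[1] == "" {
			fmt.Println(parts[2]) // /user.slice/session-1.scope
		}
	}
}
```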
func findMountpoint(fsys fs.FS) (mountpoint string, isV2 bool, err error) {
mounts, err := fs.ReadFile(fsys, fmt.Sprintf("proc/%d/mountinfo", os.Getpid()))
if err != nil {
return "", false, err
}
scanner := bufio.NewScanner(bytes.NewReader(mounts))
var cgv1Root string
for scanner.Scan() {
line := strings.Fields(scanner.Text())
fsType := line[slices.Index(line, "-")+1]
switch fsType {
case "cgroup2":
return line[4][1:], true, nil
case "cgroup":
if cgv1Root == "" {
cgv1Root = filepath.Dir(line[4][1:])
}
}
}
if cgv1Root == "" {
return "", false, errors.New("no cgroup mount found")
}
return cgv1Root, false, nil
}
func marshalAny(msg proto.Message) *anypb.Any {
data := new(anypb.Any)
_ = anypb.MarshalFrom(data, msg, proto.MarshalOptions{
@@ -600,23 +647,25 @@ func marshalAny(msg proto.Message) *anypb.Any {
type memoryLimitWatcher struct {
limitFilePath string
value atomic.Int64
value atomic.Uint64
watches sync.WaitGroup
}
func (w *memoryLimitWatcher) Value() int64 {
func (w *memoryLimitWatcher) Value() uint64 {
return w.value.Load()
}
func (w *memoryLimitWatcher) readValue() (int64, error) {
func (w *memoryLimitWatcher) readValue() (uint64, error) {
data, err := os.ReadFile(w.limitFilePath)
if err != nil {
return 0, err
}
if string(data) == "max" {
// no limit set
max := strings.TrimSpace(string(data))
if max == "max" {
return 0, nil
}
return strconv.ParseInt(strings.TrimSpace(string(data)), 10, 64)
return strconv.ParseUint(max, 10, 64)
}
func (w *memoryLimitWatcher) Watch(ctx context.Context) error {
@@ -624,7 +673,7 @@ func (w *memoryLimitWatcher) Watch(ctx context.Context) error {
if err != nil {
return err
}
closeWatch := sync.OnceFunc(func() {
closeInotify := sync.OnceFunc(func() {
log.Debug(ctx).Msg("stopping memory limit watcher")
for {
if err := syscall.Close(fd); !errors.Is(err, syscall.EINTR) {
@@ -632,10 +681,19 @@ func (w *memoryLimitWatcher) Watch(ctx context.Context) error {
}
}
})
if _, err := syscall.InotifyAddWatch(fd, w.limitFilePath, syscall.IN_MODIFY); err != nil {
closeWatch()
return err
log.Debug(ctx).Str("file", w.limitFilePath).Msg("starting watch")
wd, err := syscall.InotifyAddWatch(fd, w.limitFilePath, syscall.IN_MODIFY)
if err != nil {
closeInotify()
return fmt.Errorf("failed to watch %s: %w", w.limitFilePath, err)
}
w.watches.Add(1)
closeWatch := sync.OnceFunc(func() {
log.Debug(ctx).Str("file", w.limitFilePath).Msg("stopping watch")
syscall.InotifyRmWatch(fd, uint32(wd))
closeInotify()
w.watches.Done()
})
// perform the initial read synchronously and only after setting up the watch
v, err := w.readValue()
@@ -644,21 +702,20 @@ func (w *memoryLimitWatcher) Watch(ctx context.Context) error {
return err
}
w.value.Store(v)
log.Debug(ctx).Int64("bytes", v).Msg("current memory limit")
log.Debug(ctx).Uint64("bytes", v).Msg("current memory limit")
context.AfterFunc(ctx, closeWatch) // to unblock syscall.Read below
go func() {
defer closeWatch()
var buf [syscall.SizeofInotifyEvent]byte
for {
for ctx.Err() == nil {
v, err := w.readValue()
if err != nil {
return
}
if prev := w.value.Swap(v); prev != v {
log.Error(ctx).Err(err).Msg("error reading memory limit")
} else if prev := w.value.Swap(v); prev != v {
log.Debug(ctx).
Int64("prev", prev).
Int64("current", v).
Uint64("prev", prev).
Uint64("current", v).
Msg("memory limit updated")
}
_, err = syscall.Read(fd, buf[:])
@@ -673,3 +730,17 @@ func (w *memoryLimitWatcher) Watch(ctx context.Context) error {
return nil
}
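
The goroutine above blocks in syscall.Read, which knows nothing about the context; context.AfterFunc closes the inotify descriptor on cancellation, forcing the blocked Read to return so the loop can exit. A minimal sketch of the same unblocking pattern, using a pipe in place of inotify:

```go
package main

import (
	"context"
	"fmt"
	"os"
	"time"
)

func main() {
	r, w, _ := os.Pipe()
	defer w.Close()
	ctx, cancel := context.WithCancel(context.Background())
	context.AfterFunc(ctx, func() { r.Close() }) // closing the fd unblocks the pending Read
	go func() {
		time.Sleep(100 * time.Millisecond)
		cancel()
	}()
	buf := make([]byte, 1)
	_, err := r.Read(buf) // blocks until cancel() closes the read end
	fmt.Println("read returned:", err)
}
```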
// Wait blocks until all watches have been closed.
//
// Example use:
//
// ctx, ca := context.WithCancel(context.Background())
// w := &memoryLimitWatcher{...}
// w.Watch(ctx)
// ...
// ca()
// w.Wait() // blocks until the previous watch is closed
func (w *memoryLimitWatcher) Wait() {
w.watches.Wait()
}