From 4b95eda51e6590151282c320dd637e64ccb01643 Mon Sep 17 00:00:00 2001 From: Caleb Doxsey Date: Wed, 19 Feb 2025 09:45:21 -0700 Subject: [PATCH] netutil: improve port allocation (#5485) --- pkg/netutil/netutil.go | 35 ++++++++++++++++++++++++++++++----- pkg/netutil/netutil_test.go | 24 ++++++++++++++++++++++++ 2 files changed, 54 insertions(+), 5 deletions(-) create mode 100644 pkg/netutil/netutil_test.go diff --git a/pkg/netutil/netutil.go b/pkg/netutil/netutil.go index 5197f12fb..dac866f7d 100644 --- a/pkg/netutil/netutil.go +++ b/pkg/netutil/netutil.go @@ -1,21 +1,46 @@ // Package netutil contains various functions that help with networking. package netutil -import "net" +import ( + "net" + "sync" + "time" +) + +var ( + allocatedPortsMu sync.Mutex + allocatedPorts = map[string]time.Time{} +) // AllocatePorts allocates random ports suitable for listening. func AllocatePorts(count int) ([]string, error) { + allocatedPortsMu.Lock() + defer allocatedPortsMu.Unlock() + + now := time.Now() + cooloff := 10 * time.Minute + // clear any expired ports + for port, tm := range allocatedPorts { + if tm.Add(cooloff).Before(now) { + delete(allocatedPorts, port) + } + } + var ports []string - for i := 0; i < count; i++ { + for len(ports) < count { li, err := net.Listen("tcp", "127.0.0.1:0") if err != nil { return nil, err } _, port, _ := net.SplitHostPort(li.Addr().String()) - err = li.Close() - if err != nil { - return nil, err + defer li.Close() + + // if this port has been allocated recently, skip it + if _, ok := allocatedPorts[port]; ok { + continue } + + allocatedPorts[port] = now ports = append(ports, port) } return ports, nil diff --git a/pkg/netutil/netutil_test.go b/pkg/netutil/netutil_test.go new file mode 100644 index 000000000..ca6716bda --- /dev/null +++ b/pkg/netutil/netutil_test.go @@ -0,0 +1,24 @@ +package netutil_test + +import ( + "testing" + + "github.com/minio/minio-go/v7/pkg/set" + "github.com/stretchr/testify/assert" + + "github.com/pomerium/pomerium/pkg/netutil" +) + +func TestAllocatePorts(t *testing.T) { + t.Parallel() + + seen := set.NewStringSet() + for i := 0; i < 100; i++ { + ports, err := netutil.AllocatePorts(3) + assert.NoError(t, err) + for _, p := range ports { + assert.False(t, seen.Contains(p), "should not re-use ports") + seen.Add(p) + } + } +}