mirror of
https://github.com/pomerium/pomerium.git
synced 2025-05-24 22:47:14 +02:00
netutil: improve port allocation (#5485)
This commit is contained in:
parent
fbd1f34110
commit
4b95eda51e
2 changed files with 54 additions and 5 deletions
|
@ -1,21 +1,46 @@
|
|||
// Package netutil contains various functions that help with networking.
|
||||
package netutil
|
||||
|
||||
import "net"
|
||||
import (
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
var (
|
||||
allocatedPortsMu sync.Mutex
|
||||
allocatedPorts = map[string]time.Time{}
|
||||
)
|
||||
|
||||
// AllocatePorts allocates random ports suitable for listening.
|
||||
func AllocatePorts(count int) ([]string, error) {
|
||||
allocatedPortsMu.Lock()
|
||||
defer allocatedPortsMu.Unlock()
|
||||
|
||||
now := time.Now()
|
||||
cooloff := 10 * time.Minute
|
||||
// clear any expired ports
|
||||
for port, tm := range allocatedPorts {
|
||||
if tm.Add(cooloff).Before(now) {
|
||||
delete(allocatedPorts, port)
|
||||
}
|
||||
}
|
||||
|
||||
var ports []string
|
||||
for i := 0; i < count; i++ {
|
||||
for len(ports) < count {
|
||||
li, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
_, port, _ := net.SplitHostPort(li.Addr().String())
|
||||
err = li.Close()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
defer li.Close()
|
||||
|
||||
// if this port has been allocated recently, skip it
|
||||
if _, ok := allocatedPorts[port]; ok {
|
||||
continue
|
||||
}
|
||||
|
||||
allocatedPorts[port] = now
|
||||
ports = append(ports, port)
|
||||
}
|
||||
return ports, nil
|
||||
|
|
24
pkg/netutil/netutil_test.go
Normal file
24
pkg/netutil/netutil_test.go
Normal file
|
@ -0,0 +1,24 @@
|
|||
package netutil_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/minio/minio-go/v7/pkg/set"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/pomerium/pomerium/pkg/netutil"
|
||||
)
|
||||
|
||||
func TestAllocatePorts(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
seen := set.NewStringSet()
|
||||
for i := 0; i < 100; i++ {
|
||||
ports, err := netutil.AllocatePorts(3)
|
||||
assert.NoError(t, err)
|
||||
for _, p := range ports {
|
||||
assert.False(t, seen.Contains(p), "should not re-use ports")
|
||||
seen.Add(p)
|
||||
}
|
||||
}
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue