mirror of
https://github.com/pomerium/pomerium.git
synced 2025-08-03 08:50:42 +02:00
Code cleanup; simplify SharedURL
This commit is contained in:
parent
bb225eb845
commit
459c73d461
8 changed files with 104 additions and 47 deletions
|
@ -631,7 +631,7 @@ func addCAToBundle(bundle *bytes.Buffer, ca []byte) {
|
|||
}
|
||||
}
|
||||
|
||||
func getAllRouteableHosts(options *config.Options, addr string) ([]string, map[string][]config.IndexedPolicy, error) {
|
||||
func getAllRouteableHosts(options *config.Options, addr string) (*sets.Sorted[string], map[string][]config.IndexedPolicy, error) {
|
||||
allHosts := sets.NewSorted[string]()
|
||||
var policiesByHost map[string][]config.IndexedPolicy
|
||||
|
||||
|
@ -658,7 +658,7 @@ func getAllRouteableHosts(options *config.Options, addr string) ([]string, map[s
|
|||
allHosts.Add(hosts...)
|
||||
}
|
||||
|
||||
return allHosts.ToSlice(), policiesByHost, nil
|
||||
return allHosts, policiesByHost, nil
|
||||
}
|
||||
|
||||
func (b *Builder) urlsMatchHost(urls []*url.URL, host string) bool {
|
||||
|
|
|
@ -8,6 +8,7 @@ import (
|
|||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"slices"
|
||||
"testing"
|
||||
"text/template"
|
||||
|
||||
|
@ -491,7 +492,7 @@ func Test_getAllDomains(t *testing.T) {
|
|||
"d.unknown.example.com",
|
||||
"d.unknown.example.com:443",
|
||||
}
|
||||
assert.Equal(t, expect, actual)
|
||||
assert.Equal(t, expect, slices.Collect(actual.All()))
|
||||
})
|
||||
t.Run("grpc", func(t *testing.T) {
|
||||
actual, _, err := getAllRouteableHosts(options, "127.0.0.1:9001")
|
||||
|
@ -500,7 +501,7 @@ func Test_getAllDomains(t *testing.T) {
|
|||
"authorize.example.com:9001",
|
||||
"cache.example.com:9001",
|
||||
}
|
||||
assert.Equal(t, expect, actual)
|
||||
assert.Equal(t, expect, slices.Collect(actual.All()))
|
||||
})
|
||||
t.Run("both", func(t *testing.T) {
|
||||
newOptions := *options
|
||||
|
@ -523,7 +524,7 @@ func Test_getAllDomains(t *testing.T) {
|
|||
"d.unknown.example.com",
|
||||
"d.unknown.example.com:443",
|
||||
}
|
||||
assert.Equal(t, expect, actual)
|
||||
assert.Equal(t, expect, slices.Collect(actual.All()))
|
||||
})
|
||||
})
|
||||
|
||||
|
@ -534,7 +535,7 @@ func Test_getAllDomains(t *testing.T) {
|
|||
}
|
||||
actual, _, err := getAllRouteableHosts(options, ":443")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, []string{"a.example.com"}, actual)
|
||||
assert.Equal(t, []string{"a.example.com"}, slices.Collect(actual.All()))
|
||||
})
|
||||
}
|
||||
|
||||
|
|
|
@ -52,20 +52,26 @@ func (b *Builder) buildMainRouteConfiguration() (*envoy_config_route_v3.RouteCon
|
|||
return nil, err
|
||||
}
|
||||
|
||||
allHosts, policiesByHost, err := getAllRouteableHosts(b.cfg.Options, b.cfg.Options.Addr)
|
||||
hosts, policiesByHost, err := getAllRouteableHosts(b.cfg.Options, b.cfg.Options.Addr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var virtualHosts []*envoy_config_route_v3.VirtualHost
|
||||
virtualHosts := make([]*envoy_config_route_v3.VirtualHost, 0, hosts.Size())
|
||||
catchallVirtualHost, err := b.buildVirtualHost("catch-all", "*")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
seenCatchallPolicies := map[int]struct{}{}
|
||||
|
||||
isProxy := config.IsProxy(b.cfg.Options.Services)
|
||||
for _, host := range allHosts {
|
||||
if isProxy && strings.Contains(host, "*") {
|
||||
isAuthorize := config.IsAuthorize(b.cfg.Options.Services)
|
||||
isDatabroker := config.IsDataBroker(b.cfg.Options.Services)
|
||||
isGRPCServiceDomain := b.cfg.Options.Addr == b.cfg.Options.GetGRPCAddr()
|
||||
|
||||
for host := range hosts.All() {
|
||||
if isProxy && strings.ContainsRune(host, '*') {
|
||||
// Group policies containing wildcards into a separate virtual host
|
||||
for _, policy := range policiesByHost[host] {
|
||||
if _, ok := seenCatchallPolicies[policy.Index]; ok {
|
||||
continue
|
||||
|
@ -85,10 +91,10 @@ func (b *Builder) buildMainRouteConfiguration() (*envoy_config_route_v3.RouteCon
|
|||
return nil, err
|
||||
}
|
||||
|
||||
if b.cfg.Options.Addr == b.cfg.Options.GetGRPCAddr() {
|
||||
if isGRPCServiceDomain {
|
||||
// if this is a gRPC service domain and we're supposed to handle that, add those routes
|
||||
if (config.IsAuthorize(b.cfg.Options.Services) && b.urlsMatchHost(authorizeURLs, host)) ||
|
||||
(config.IsDataBroker(b.cfg.Options.Services) && b.urlsMatchHost(dataBrokerURLs, host)) {
|
||||
if (isAuthorize && b.urlsMatchHost(authorizeURLs, host)) ||
|
||||
(isDatabroker && b.urlsMatchHost(dataBrokerURLs, host)) {
|
||||
rs, err := b.buildGRPCRoutes()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
|
|
@ -185,10 +185,11 @@ func (b *Builder) buildRoutesForPolicy(
|
|||
return nil, err
|
||||
}
|
||||
|
||||
matchAnyIncomingPort := b.cfg.Options.IsRuntimeFlagSet(config.RuntimeFlagMatchAnyIncomingPort)
|
||||
var routes []*envoy_config_route_v3.Route
|
||||
if strings.Contains(fromURL.Host, "*") {
|
||||
// we have to match '*.example.com' and '*.example.com:443', so there are two routes
|
||||
for host := range urlutil.AllDomainsForURL(fromURL.URL, !b.cfg.Options.IsRuntimeFlagSet(config.RuntimeFlagMatchAnyIncomingPort)) {
|
||||
for host := range urlutil.AllDomainsForURL(fromURL.URL, !matchAnyIncomingPort) {
|
||||
route, err := b.buildRouteForPolicyAndMatch(policy, name, mkRouteMatchForHost(b.cfg.Options, policy, host))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
|
|
@ -1343,32 +1343,29 @@ func (o *Options) GetAllRouteablePolicyHTTPHosts() (map[string][]IndexedPolicy,
|
|||
return nil, nil
|
||||
}
|
||||
|
||||
matchAnyIncomingPort := o.IsRuntimeFlagSet(RuntimeFlagMatchAnyIncomingPort)
|
||||
mult := 1
|
||||
if o.IsRuntimeFlagSet(RuntimeFlagMatchAnyIncomingPort) {
|
||||
if matchAnyIncomingPort {
|
||||
mult = 2
|
||||
}
|
||||
hosts := make(map[string][]IndexedPolicy, o.NumPolicies()*mult)
|
||||
|
||||
var retErr error
|
||||
for i, policy := range o.GetAllPoliciesIndexed() {
|
||||
fromURL, err := urlutil.ParseAndValidateURL(policy.From)
|
||||
fromURL, err := urlutil.ParseAndValidateSharedURL(policy.From)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for _, domain := range urlutil.GetDomainsForURL(fromURL, !o.IsRuntimeFlagSet(RuntimeFlagMatchAnyIncomingPort)) {
|
||||
for _, domain := range urlutil.GetDomainsForURL(fromURL.URL, !matchAnyIncomingPort) {
|
||||
hosts[domain] = append(hosts[domain], IndexedPolicy{Policy: policy, Index: i})
|
||||
}
|
||||
if policy.TLSDownstreamServerName != "" {
|
||||
tlsURL := fromURL.ResolveReference(&url.URL{Host: policy.TLSDownstreamServerName})
|
||||
for _, domain := range urlutil.GetDomainsForURL(tlsURL, !o.IsRuntimeFlagSet(RuntimeFlagMatchAnyIncomingPort)) {
|
||||
for _, domain := range urlutil.GetDomainsForURL(tlsURL, !matchAnyIncomingPort) {
|
||||
hosts[domain] = append(hosts[domain], IndexedPolicy{Policy: policy, Index: i})
|
||||
}
|
||||
}
|
||||
}
|
||||
if retErr != nil {
|
||||
return nil, retErr
|
||||
}
|
||||
return hosts, nil
|
||||
}
|
||||
|
||||
|
|
|
@ -1,6 +1,8 @@
|
|||
package sets
|
||||
|
||||
import (
|
||||
"iter"
|
||||
|
||||
"github.com/google/btree"
|
||||
"golang.org/x/exp/constraints"
|
||||
)
|
||||
|
@ -65,3 +67,9 @@ func (s *Sorted[T]) ToSlice() []T {
|
|||
})
|
||||
return arr
|
||||
}
|
||||
|
||||
func (s *Sorted[T]) All() iter.Seq[T] {
|
||||
return func(yield func(T) bool) {
|
||||
s.b.Ascend(yield)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -10,11 +10,9 @@ import (
|
|||
"net/url"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/caddyserver/certmagic"
|
||||
"golang.org/x/sync/singleflight"
|
||||
)
|
||||
|
||||
const (
|
||||
|
@ -59,18 +57,10 @@ func ParseAndValidateURL(rawurl string) (*url.URL, error) {
|
|||
|
||||
type SharedURL struct {
|
||||
*url.URL
|
||||
initDone uint32
|
||||
hostname func() string
|
||||
}
|
||||
|
||||
func (s *SharedURL) lazyInit() {
|
||||
if atomic.CompareAndSwapUint32(&s.initDone, 0, 1) {
|
||||
s.hostname = sync.OnceValue(s.URL.Hostname)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *SharedURL) Hostname() string {
|
||||
s.lazyInit()
|
||||
return s.hostname()
|
||||
}
|
||||
|
||||
|
@ -83,27 +73,21 @@ func (s *SharedURL) Mutable() *url.URL {
|
|||
return &u
|
||||
}
|
||||
|
||||
var (
|
||||
urlCache sync.Map // map[string]*url.URL
|
||||
sf singleflight.Group
|
||||
)
|
||||
var urlCache sync.Map // map[string]*SharedURL
|
||||
|
||||
func ParseAndValidateSharedURL(rawurl string) (*SharedURL, error) {
|
||||
if v, ok := urlCache.Load(rawurl); ok {
|
||||
return &SharedURL{URL: v.(*url.URL)}, nil
|
||||
}
|
||||
v, err, _ := sf.Do(rawurl, func() (any, error) {
|
||||
url, err := ParseAndValidateURL(rawurl)
|
||||
shared, ok := urlCache.Load(rawurl)
|
||||
if !ok {
|
||||
u, err := ParseAndValidateURL(rawurl)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
urlCache.Store(rawurl, url)
|
||||
return url, nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
shared, _ = urlCache.LoadOrStore(rawurl, &SharedURL{
|
||||
URL: u,
|
||||
hostname: sync.OnceValue(u.Hostname),
|
||||
})
|
||||
}
|
||||
return &SharedURL{URL: v.(*url.URL)}, nil
|
||||
return shared.(*SharedURL), nil
|
||||
}
|
||||
|
||||
// MustParseAndValidateURL parses the URL via ParseAndValidateURL but panics if there is an error.
|
||||
|
|
|
@ -1,9 +1,14 @@
|
|||
package urlutil
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
mathrand "math/rand/v2"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"reflect"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
|
@ -199,3 +204,58 @@ func TestMatchesServerName(t *testing.T) {
|
|||
assert.True(t, MatchesServerName(MustParseAndValidateURL("https://domain.example.com"), "*.example.com"))
|
||||
})
|
||||
}
|
||||
|
||||
func BenchmarkSharedURL(b *testing.B) {
|
||||
randomURL := func() string {
|
||||
randBytes := make([]byte, 32)
|
||||
rand.Read(randBytes)
|
||||
randStr := base64.RawURLEncoding.EncodeToString(randBytes)
|
||||
return fmt.Sprintf("https://%s.example.com/foo/bar?baz=1#fragment", randStr)
|
||||
}
|
||||
|
||||
for _, tc := range []struct {
|
||||
name string
|
||||
fn func(string)
|
||||
}{
|
||||
{"Normal", func(rawurl string) { ParseAndValidateURL(rawurl) }},
|
||||
{"Shared", func(rawurl string) { ParseAndValidateSharedURL(rawurl) }},
|
||||
} {
|
||||
b.Run(tc.name, func(b *testing.B) {
|
||||
b.Run("Same URL", func(b *testing.B) {
|
||||
u := randomURL()
|
||||
b.ResetTimer()
|
||||
b.RunParallel(func(p *testing.PB) {
|
||||
for p.Next() {
|
||||
tc.fn(u)
|
||||
}
|
||||
})
|
||||
})
|
||||
b.Run("Unique URLs", func(b *testing.B) {
|
||||
urls := make([]string, b.N)
|
||||
for i := range urls {
|
||||
urls[i] = randomURL()
|
||||
}
|
||||
var atomicIndex atomic.Int32
|
||||
atomicIndex.Store(-1)
|
||||
b.ResetTimer()
|
||||
b.RunParallel(func(p *testing.PB) {
|
||||
for p.Next() {
|
||||
tc.fn(randomURL())
|
||||
}
|
||||
})
|
||||
})
|
||||
b.Run("Random URLs from set", func(b *testing.B) {
|
||||
urls := make([]string, 10000)
|
||||
for i := range urls {
|
||||
urls[i] = randomURL()
|
||||
}
|
||||
b.ResetTimer()
|
||||
b.RunParallel(func(p *testing.PB) {
|
||||
for p.Next() {
|
||||
tc.fn(urls[mathrand.IntN(len(urls))])
|
||||
}
|
||||
})
|
||||
})
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue