Code cleanup; simplify SharedURL

This commit is contained in:
Joe Kralicky 2024-08-23 13:52:07 -04:00
parent bb225eb845
commit 459c73d461
No known key found for this signature in database
GPG key ID: 75C4875F34A9FB79
8 changed files with 104 additions and 47 deletions

View file

@ -631,7 +631,7 @@ func addCAToBundle(bundle *bytes.Buffer, ca []byte) {
}
}
func getAllRouteableHosts(options *config.Options, addr string) ([]string, map[string][]config.IndexedPolicy, error) {
func getAllRouteableHosts(options *config.Options, addr string) (*sets.Sorted[string], map[string][]config.IndexedPolicy, error) {
allHosts := sets.NewSorted[string]()
var policiesByHost map[string][]config.IndexedPolicy
@ -658,7 +658,7 @@ func getAllRouteableHosts(options *config.Options, addr string) ([]string, map[s
allHosts.Add(hosts...)
}
return allHosts.ToSlice(), policiesByHost, nil
return allHosts, policiesByHost, nil
}
func (b *Builder) urlsMatchHost(urls []*url.URL, host string) bool {

View file

@ -8,6 +8,7 @@ import (
"os"
"path/filepath"
"runtime"
"slices"
"testing"
"text/template"
@ -491,7 +492,7 @@ func Test_getAllDomains(t *testing.T) {
"d.unknown.example.com",
"d.unknown.example.com:443",
}
assert.Equal(t, expect, actual)
assert.Equal(t, expect, slices.Collect(actual.All()))
})
t.Run("grpc", func(t *testing.T) {
actual, _, err := getAllRouteableHosts(options, "127.0.0.1:9001")
@ -500,7 +501,7 @@ func Test_getAllDomains(t *testing.T) {
"authorize.example.com:9001",
"cache.example.com:9001",
}
assert.Equal(t, expect, actual)
assert.Equal(t, expect, slices.Collect(actual.All()))
})
t.Run("both", func(t *testing.T) {
newOptions := *options
@ -523,7 +524,7 @@ func Test_getAllDomains(t *testing.T) {
"d.unknown.example.com",
"d.unknown.example.com:443",
}
assert.Equal(t, expect, actual)
assert.Equal(t, expect, slices.Collect(actual.All()))
})
})
@ -534,7 +535,7 @@ func Test_getAllDomains(t *testing.T) {
}
actual, _, err := getAllRouteableHosts(options, ":443")
require.NoError(t, err)
assert.Equal(t, []string{"a.example.com"}, actual)
assert.Equal(t, []string{"a.example.com"}, slices.Collect(actual.All()))
})
}

View file

@ -52,20 +52,26 @@ func (b *Builder) buildMainRouteConfiguration() (*envoy_config_route_v3.RouteCon
return nil, err
}
allHosts, policiesByHost, err := getAllRouteableHosts(b.cfg.Options, b.cfg.Options.Addr)
hosts, policiesByHost, err := getAllRouteableHosts(b.cfg.Options, b.cfg.Options.Addr)
if err != nil {
return nil, err
}
var virtualHosts []*envoy_config_route_v3.VirtualHost
virtualHosts := make([]*envoy_config_route_v3.VirtualHost, 0, hosts.Size())
catchallVirtualHost, err := b.buildVirtualHost("catch-all", "*")
if err != nil {
return nil, err
}
seenCatchallPolicies := map[int]struct{}{}
isProxy := config.IsProxy(b.cfg.Options.Services)
for _, host := range allHosts {
if isProxy && strings.Contains(host, "*") {
isAuthorize := config.IsAuthorize(b.cfg.Options.Services)
isDatabroker := config.IsDataBroker(b.cfg.Options.Services)
isGRPCServiceDomain := b.cfg.Options.Addr == b.cfg.Options.GetGRPCAddr()
for host := range hosts.All() {
if isProxy && strings.ContainsRune(host, '*') {
// Group policies containing wildcards into a separate virtual host
for _, policy := range policiesByHost[host] {
if _, ok := seenCatchallPolicies[policy.Index]; ok {
continue
@ -85,10 +91,10 @@ func (b *Builder) buildMainRouteConfiguration() (*envoy_config_route_v3.RouteCon
return nil, err
}
if b.cfg.Options.Addr == b.cfg.Options.GetGRPCAddr() {
if isGRPCServiceDomain {
// if this is a gRPC service domain and we're supposed to handle that, add those routes
if (config.IsAuthorize(b.cfg.Options.Services) && b.urlsMatchHost(authorizeURLs, host)) ||
(config.IsDataBroker(b.cfg.Options.Services) && b.urlsMatchHost(dataBrokerURLs, host)) {
if (isAuthorize && b.urlsMatchHost(authorizeURLs, host)) ||
(isDatabroker && b.urlsMatchHost(dataBrokerURLs, host)) {
rs, err := b.buildGRPCRoutes()
if err != nil {
return nil, err

View file

@ -185,10 +185,11 @@ func (b *Builder) buildRoutesForPolicy(
return nil, err
}
matchAnyIncomingPort := b.cfg.Options.IsRuntimeFlagSet(config.RuntimeFlagMatchAnyIncomingPort)
var routes []*envoy_config_route_v3.Route
if strings.Contains(fromURL.Host, "*") {
// we have to match '*.example.com' and '*.example.com:443', so there are two routes
for host := range urlutil.AllDomainsForURL(fromURL.URL, !b.cfg.Options.IsRuntimeFlagSet(config.RuntimeFlagMatchAnyIncomingPort)) {
for host := range urlutil.AllDomainsForURL(fromURL.URL, !matchAnyIncomingPort) {
route, err := b.buildRouteForPolicyAndMatch(policy, name, mkRouteMatchForHost(b.cfg.Options, policy, host))
if err != nil {
return nil, err

View file

@ -1343,32 +1343,29 @@ func (o *Options) GetAllRouteablePolicyHTTPHosts() (map[string][]IndexedPolicy,
return nil, nil
}
matchAnyIncomingPort := o.IsRuntimeFlagSet(RuntimeFlagMatchAnyIncomingPort)
mult := 1
if o.IsRuntimeFlagSet(RuntimeFlagMatchAnyIncomingPort) {
if matchAnyIncomingPort {
mult = 2
}
hosts := make(map[string][]IndexedPolicy, o.NumPolicies()*mult)
var retErr error
for i, policy := range o.GetAllPoliciesIndexed() {
fromURL, err := urlutil.ParseAndValidateURL(policy.From)
fromURL, err := urlutil.ParseAndValidateSharedURL(policy.From)
if err != nil {
return nil, err
}
for _, domain := range urlutil.GetDomainsForURL(fromURL, !o.IsRuntimeFlagSet(RuntimeFlagMatchAnyIncomingPort)) {
for _, domain := range urlutil.GetDomainsForURL(fromURL.URL, !matchAnyIncomingPort) {
hosts[domain] = append(hosts[domain], IndexedPolicy{Policy: policy, Index: i})
}
if policy.TLSDownstreamServerName != "" {
tlsURL := fromURL.ResolveReference(&url.URL{Host: policy.TLSDownstreamServerName})
for _, domain := range urlutil.GetDomainsForURL(tlsURL, !o.IsRuntimeFlagSet(RuntimeFlagMatchAnyIncomingPort)) {
for _, domain := range urlutil.GetDomainsForURL(tlsURL, !matchAnyIncomingPort) {
hosts[domain] = append(hosts[domain], IndexedPolicy{Policy: policy, Index: i})
}
}
}
if retErr != nil {
return nil, retErr
}
return hosts, nil
}

View file

@ -1,6 +1,8 @@
package sets
import (
"iter"
"github.com/google/btree"
"golang.org/x/exp/constraints"
)
@ -65,3 +67,9 @@ func (s *Sorted[T]) ToSlice() []T {
})
return arr
}
// All returns an iterator over the elements of the set. Elements are produced
// by the underlying btree's Ascend traversal, and iteration stops as soon as
// the consumer's yield function returns false.
func (s *Sorted[T]) All() iter.Seq[T] {
	return func(yield func(T) bool) {
		s.b.Ascend(func(item T) bool {
			// Forward the consumer's decision so an early break ends the traversal.
			return yield(item)
		})
	}
}

View file

@ -10,11 +10,9 @@ import (
"net/url"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/caddyserver/certmagic"
"golang.org/x/sync/singleflight"
)
const (
@ -59,18 +57,10 @@ func ParseAndValidateURL(rawurl string) (*url.URL, error) {
type SharedURL struct {
*url.URL
initDone uint32
hostname func() string
}
func (s *SharedURL) lazyInit() {
if atomic.CompareAndSwapUint32(&s.initDone, 0, 1) {
s.hostname = sync.OnceValue(s.URL.Hostname)
}
}
func (s *SharedURL) Hostname() string {
s.lazyInit()
return s.hostname()
}
@ -83,27 +73,21 @@ func (s *SharedURL) Mutable() *url.URL {
return &u
}
var (
urlCache sync.Map // map[string]*url.URL
sf singleflight.Group
)
var urlCache sync.Map // map[string]*SharedURL
func ParseAndValidateSharedURL(rawurl string) (*SharedURL, error) {
if v, ok := urlCache.Load(rawurl); ok {
return &SharedURL{URL: v.(*url.URL)}, nil
}
v, err, _ := sf.Do(rawurl, func() (any, error) {
url, err := ParseAndValidateURL(rawurl)
shared, ok := urlCache.Load(rawurl)
if !ok {
u, err := ParseAndValidateURL(rawurl)
if err != nil {
return nil, err
}
urlCache.Store(rawurl, url)
return url, nil
})
if err != nil {
return nil, err
shared, _ = urlCache.LoadOrStore(rawurl, &SharedURL{
URL: u,
hostname: sync.OnceValue(u.Hostname),
})
}
return &SharedURL{URL: v.(*url.URL)}, nil
return shared.(*SharedURL), nil
}
// MustParseAndValidateURL parses the URL via ParseAndValidateURL but panics if there is an error.

View file

@ -1,9 +1,14 @@
package urlutil
import (
"crypto/rand"
"encoding/base64"
"fmt"
mathrand "math/rand/v2"
"net/http"
"net/url"
"reflect"
"sync/atomic"
"testing"
"github.com/google/go-cmp/cmp"
@ -199,3 +204,58 @@ func TestMatchesServerName(t *testing.T) {
assert.True(t, MatchesServerName(MustParseAndValidateURL("https://domain.example.com"), "*.example.com"))
})
}
// BenchmarkSharedURL compares ParseAndValidateURL ("Normal") with
// ParseAndValidateSharedURL ("Shared") under parallel load for three input
// patterns:
//
//   - "Same URL": every iteration parses the same URL.
//   - "Unique URLs": every iteration parses a URL not used by any other iteration.
//   - "Random URLs from set": every iteration parses a URL picked at random
//     from a fixed pool of 10000.
//
// All URLs are generated before the timer is reset so that only the parse
// call itself is measured.
func BenchmarkSharedURL(b *testing.B) {
	randomURL := func() string {
		randBytes := make([]byte, 32)
		if _, err := rand.Read(randBytes); err != nil {
			// Setup-only helper; without entropy the benchmark inputs are meaningless.
			panic(err)
		}
		randStr := base64.RawURLEncoding.EncodeToString(randBytes)
		return fmt.Sprintf("https://%s.example.com/foo/bar?baz=1#fragment", randStr)
	}
	for _, tc := range []struct {
		name string
		fn   func(string)
	}{
		{"Normal", func(rawurl string) { ParseAndValidateURL(rawurl) }},
		{"Shared", func(rawurl string) { ParseAndValidateSharedURL(rawurl) }},
	} {
		b.Run(tc.name, func(b *testing.B) {
			b.Run("Same URL", func(b *testing.B) {
				u := randomURL()
				b.ResetTimer()
				b.RunParallel(func(p *testing.PB) {
					for p.Next() {
						tc.fn(u)
					}
				})
			})
			b.Run("Unique URLs", func(b *testing.B) {
				// Pre-generate one URL per iteration so that URL generation
				// (crypto/rand, base64, Sprintf) stays out of the timed region.
				urls := make([]string, b.N)
				for i := range urls {
					urls[i] = randomURL()
				}
				// Shared cursor handing each iteration its own slot in urls.
				var next atomic.Int64
				b.ResetTimer()
				b.RunParallel(func(p *testing.PB) {
					for p.Next() {
						// RunParallel distributes b.N iterations among its
						// goroutines, so the cursor stays within len(urls).
						tc.fn(urls[next.Add(1)-1])
					}
				})
			})
			b.Run("Random URLs from set", func(b *testing.B) {
				urls := make([]string, 10000)
				for i := range urls {
					urls[i] = randomURL()
				}
				b.ResetTimer()
				b.RunParallel(func(p *testing.PB) {
					for p.Next() {
						tc.fn(urls[mathrand.IntN(len(urls))])
					}
				})
			})
		})
	}
}