core/config: add support for stripping the port for matching routes

This commit is contained in:
Caleb Doxsey 2024-04-24 16:36:48 -06:00
parent 05e077fe04
commit 31e8383250
13 changed files with 57 additions and 33 deletions

View file

@ -163,7 +163,7 @@ func getCheckRequestURL(req *envoy_service_auth_v3.CheckRequest) url.URL {
Scheme: h.GetScheme(),
Host: h.GetHost(),
}
u.Host = urlutil.GetDomainsForURL(&u)[0]
u.Host = urlutil.GetDomainsForURL(&u, false)[0]
// envoy sends the query string as part of the path
path := h.GetPath()
if idx := strings.Index(path, "?"); idx != -1 {

View file

@ -2,6 +2,7 @@ package envoyconfig
import (
"fmt"
"strings"
envoy_config_accesslog_v3 "github.com/envoyproxy/go-control-plane/envoy/config/accesslog/v3"
envoy_config_core_v3 "github.com/envoyproxy/go-control-plane/envoy/config/core/v3"
@ -23,6 +24,15 @@ func (b *Builder) buildVirtualHost(
Domains: []string{host},
}
// if we're stripping the port from incoming requests
// and this host doesn't have a port or wildcard in it
// then we will add :* to match on any port
if options.IsRuntimeFlagSet(config.RuntimeFlagStripFromPort) &&
!strings.Contains(host, "*") &&
!strings.Contains(host, ":") {
vh.Domains = append(vh.Domains, host+":*")
}
// these routes match /.pomerium/... and similar paths
rs, err := b.buildPomeriumHTTPRoutes(options, host)
if err != nil {

View file

@ -669,7 +669,7 @@ func urlsMatchHost(urls []*url.URL, host string) bool {
}
func urlMatchesHost(u *url.URL, host string) bool {
for _, h := range urlutil.GetDomainsForURL(u) {
for _, h := range urlutil.GetDomainsForURL(u, true) {
if h == host {
return true
}

View file

@ -236,8 +236,8 @@ func (b *Builder) buildRoutesForPolicy(
var routes []*envoy_config_route_v3.Route
if strings.Contains(fromURL.Host, "*") {
// we have to match '*.example.com' and '*.example.com:443', so there are two routes
for _, host := range urlutil.GetDomainsForURL(fromURL) {
route, err := b.buildRouteForPolicyAndMatch(cfg, policy, name, mkRouteMatchForHost(policy, host))
for _, host := range urlutil.GetDomainsForURL(fromURL, !cfg.Options.IsRuntimeFlagSet(config.RuntimeFlagStripFromPort)) {
route, err := b.buildRouteForPolicyAndMatch(cfg, policy, name, mkRouteMatchForHost(cfg.Options, policy, host))
if err != nil {
return nil, err
}
@ -507,6 +507,7 @@ func mkRouteMatch(policy *config.Policy) *envoy_config_route_v3.RouteMatch {
}
func mkRouteMatchForHost(
options *config.Options,
policy *config.Policy,
host string,
) *envoy_config_route_v3.RouteMatch {
@ -517,7 +518,7 @@ func mkRouteMatchForHost(
StringMatch: &envoy_type_matcher_v3.StringMatcher{
MatchPattern: &envoy_type_matcher_v3.StringMatcher_SafeRegex{
SafeRegex: &envoy_type_matcher_v3.RegexMatcher{
Regex: config.WildcardToRegex(host),
Regex: config.WildcardToRegex(host, options.IsRuntimeFlagSet(config.RuntimeFlagStripFromPort)),
},
},
},

View file

@ -9,8 +9,8 @@ import (
)
// FromURLMatchesRequestURL returns true if the from URL matches the request URL.
func FromURLMatchesRequestURL(fromURL, requestURL *url.URL) bool {
for _, domain := range urlutil.GetDomainsForURL(fromURL) {
func FromURLMatchesRequestURL(fromURL, requestURL *url.URL, stripPort bool) bool {
for _, domain := range urlutil.GetDomainsForURL(fromURL, true) {
if domain == requestURL.Host {
return true
}
@ -19,7 +19,7 @@ func FromURLMatchesRequestURL(fromURL, requestURL *url.URL) bool {
continue
}
reStr := WildcardToRegex(domain)
reStr := WildcardToRegex(domain, stripPort)
re := regexp.MustCompile(reStr)
if re.MatchString(requestURL.Host) {
return true
@ -29,7 +29,7 @@ func FromURLMatchesRequestURL(fromURL, requestURL *url.URL) bool {
}
// WildcardToRegex converts a wildcard string to a regular expression.
func WildcardToRegex(wildcard string) string {
func WildcardToRegex(wildcard string, stripPort bool) string {
var b strings.Builder
b.WriteByte('^')
for {
@ -42,6 +42,9 @@ func WildcardToRegex(wildcard string) string {
wildcard = wildcard[idx+1:]
}
b.WriteString(regexp.QuoteMeta(wildcard))
if stripPort && !strings.Contains(wildcard, ":") {
b.WriteString("(:(.+))?")
}
b.WriteByte('$')
return b.String()
}

View file

@ -21,10 +21,11 @@ func TestFromURLMatchesRequestURL(t *testing.T) {
{"https://from.example.com", "https://to.example.com/some/path", false},
{"https://*.example.com", "https://from.example.com/some/path", true},
{"https://*.example.com", "https://example.com/some/path", false},
{"https://*.example.com", "https://from.example.com:8443/some/path", true},
} {
fromURL := urlutil.MustParseAndValidateURL(tc.pattern)
requestURL := urlutil.MustParseAndValidateURL(tc.input)
assert.Equal(t, tc.matches, FromURLMatchesRequestURL(&fromURL, &requestURL),
assert.Equal(t, tc.matches, FromURLMatchesRequestURL(&fromURL, &requestURL, true),
"from-url: %s\nrequest-url: %s", tc.pattern, tc.input)
}
}
@ -32,7 +33,7 @@ func TestFromURLMatchesRequestURL(t *testing.T) {
func TestWildcardToRegex(t *testing.T) {
t.Parallel()
re, err := regexp.Compile(WildcardToRegex("*.internal.*.example.com"))
re, err := regexp.Compile(WildcardToRegex("*.internal.*.example.com", true))
assert.NoError(t, err)
assert.True(t, re.MatchString("a.internal.b.example.com"))
}

View file

@ -59,7 +59,7 @@ func (o *Options) GetIdentityProviderForRequestURL(requestURL string) (*identity
for _, p := range o.GetAllPolicies() {
p := p
if p.Matches(*u) {
if p.Matches(*u, o.IsRuntimeFlagSet(RuntimeFlagStripFromPort)) {
return o.GetIdentityProviderForPolicy(&p)
}
}

View file

@ -1216,7 +1216,7 @@ func (o *Options) GetAllRouteableGRPCHosts() ([]string, error) {
return nil, err
}
for _, u := range authorizeURLs {
hosts.Add(urlutil.GetDomainsForURL(u)...)
hosts.Add(urlutil.GetDomainsForURL(u, true)...)
}
} else if IsAuthorize(o.Services) {
authorizeURLs, err := o.GetInternalAuthorizeURLs()
@ -1224,7 +1224,7 @@ func (o *Options) GetAllRouteableGRPCHosts() ([]string, error) {
return nil, err
}
for _, u := range authorizeURLs {
hosts.Add(urlutil.GetDomainsForURL(u)...)
hosts.Add(urlutil.GetDomainsForURL(u, true)...)
}
}
@ -1235,7 +1235,7 @@ func (o *Options) GetAllRouteableGRPCHosts() ([]string, error) {
return nil, err
}
for _, u := range dataBrokerURLs {
hosts.Add(urlutil.GetDomainsForURL(u)...)
hosts.Add(urlutil.GetDomainsForURL(u, true)...)
}
} else if IsDataBroker(o.Services) {
dataBrokerURLs, err := o.GetInternalDataBrokerURLs()
@ -1243,7 +1243,7 @@ func (o *Options) GetAllRouteableGRPCHosts() ([]string, error) {
return nil, err
}
for _, u := range dataBrokerURLs {
hosts.Add(urlutil.GetDomainsForURL(u)...)
hosts.Add(urlutil.GetDomainsForURL(u, true)...)
}
}
@ -1259,7 +1259,7 @@ func (o *Options) GetAllRouteableHTTPHosts() ([]string, error) {
if err != nil {
return nil, err
}
hosts.Add(urlutil.GetDomainsForURL(authenticateURL)...)
hosts.Add(urlutil.GetDomainsForURL(authenticateURL, !o.IsRuntimeFlagSet(RuntimeFlagStripFromPort))...)
}
if o.AuthenticateURLString != "" {
@ -1267,7 +1267,7 @@ func (o *Options) GetAllRouteableHTTPHosts() ([]string, error) {
if err != nil {
return nil, err
}
hosts.Add(urlutil.GetDomainsForURL(authenticateURL)...)
hosts.Add(urlutil.GetDomainsForURL(authenticateURL, !o.IsRuntimeFlagSet(RuntimeFlagStripFromPort))...)
}
}
@ -1279,10 +1279,10 @@ func (o *Options) GetAllRouteableHTTPHosts() ([]string, error) {
return nil, err
}
hosts.Add(urlutil.GetDomainsForURL(fromURL)...)
hosts.Add(urlutil.GetDomainsForURL(fromURL, !o.IsRuntimeFlagSet(RuntimeFlagStripFromPort))...)
if policy.TLSDownstreamServerName != "" {
tlsURL := fromURL.ResolveReference(&url.URL{Host: policy.TLSDownstreamServerName})
hosts.Add(urlutil.GetDomainsForURL(tlsURL)...)
hosts.Add(urlutil.GetDomainsForURL(tlsURL, !o.IsRuntimeFlagSet(RuntimeFlagStripFromPort))...)
}
}
}

View file

@ -606,14 +606,14 @@ func (p *Policy) String() string {
}
// Matches returns true if the policy would match the given URL.
func (p *Policy) Matches(requestURL url.URL) bool {
func (p *Policy) Matches(requestURL url.URL, stripPort bool) bool {
// an invalid from URL should not match anything
fromURL, err := urlutil.ParseAndValidateURL(p.From)
if err != nil {
return false
}
if !FromURLMatchesRequestURL(fromURL, &requestURL) {
if !FromURLMatchesRequestURL(fromURL, &requestURL, stripPort) {
return false
}

View file

@ -245,7 +245,7 @@ func TestPolicy_Matches(t *testing.T) {
}
assert.NoError(t, p.Validate())
assert.False(t, p.Matches(urlutil.MustParseAndValidateURL(`https://www.example.com/foo/bar`)),
assert.False(t, p.Matches(urlutil.MustParseAndValidateURL(`https://www.example.com/foo/bar`), true),
"regex should only match full string")
})
t.Run("issue2952", func(t *testing.T) {
@ -256,7 +256,7 @@ func TestPolicy_Matches(t *testing.T) {
}
assert.NoError(t, p.Validate())
assert.True(t, p.Matches(urlutil.MustParseAndValidateURL(`https://www.example.com/foo/bar/0`)))
assert.True(t, p.Matches(urlutil.MustParseAndValidateURL(`https://www.example.com/foo/bar/0`), true))
})
t.Run("issue2592-test2", func(t *testing.T) {
p := &Policy{
@ -266,8 +266,8 @@ func TestPolicy_Matches(t *testing.T) {
}
assert.NoError(t, p.Validate())
assert.True(t, p.Matches(urlutil.MustParseAndValidateURL(`https://www.example.com/admin/foo`)))
assert.True(t, p.Matches(urlutil.MustParseAndValidateURL(`https://www.example.com/admin/bar`)))
assert.True(t, p.Matches(urlutil.MustParseAndValidateURL(`https://www.example.com/admin/foo`), true))
assert.True(t, p.Matches(urlutil.MustParseAndValidateURL(`https://www.example.com/admin/bar`), true))
})
t.Run("tcp", func(t *testing.T) {
p := &Policy{
@ -276,7 +276,7 @@ func TestPolicy_Matches(t *testing.T) {
}
assert.NoError(t, p.Validate())
assert.True(t, p.Matches(urlutil.MustParseAndValidateURL(`https://tcp.example.com:6379`)))
assert.True(t, p.Matches(urlutil.MustParseAndValidateURL(`https://tcp.example.com:6379`), true))
})
}

View file

@ -2,8 +2,13 @@ package config
import "golang.org/x/exp/maps"
// RuntimeFlagGRPCDatabrokerKeepalive enables gRPC keepalive to the databroker service
var RuntimeFlagGRPCDatabrokerKeepalive = runtimeFlag("grpc_databroker_keepalive", false)
var (
// RuntimeFlagGRPCDatabrokerKeepalive enables gRPC keepalive to the databroker service
RuntimeFlagGRPCDatabrokerKeepalive = runtimeFlag("grpc_databroker_keepalive", false)
// RuntimeFlagStripFromPort enables stripping the port from incoming requests for matching from addresses
RuntimeFlagStripFromPort = runtimeFlag("strip_from_port", true)
)
// RuntimeFlag is a runtime flag that can flip on/off certain features
type RuntimeFlag string

View file

@ -106,9 +106,9 @@ func GetServerNamesForURL(u *url.URL) []string {
// GetDomainsForURL returns the available domains for given url.
//
// For standard HTTP (80)/HTTPS (443) ports, it returns `example.com` and `example.com:<port>`.
// Otherwise, return the URL.Host value.
func GetDomainsForURL(u *url.URL) []string {
// For standard HTTP (80)/HTTPS (443) ports, it returns `example.com` and `example.com:<port>`,
// if includeDefaultPort is set. Otherwise, return the URL.Host value.
func GetDomainsForURL(u *url.URL, includeDefaultPort bool) []string {
if u == nil {
return nil
}
@ -141,6 +141,10 @@ func GetDomainsForURL(u *url.URL) []string {
}
}
if !includeDefaultPort {
return []string{u.Hostname()}
}
// for everything else we return two routes: 'example.com' and 'example.com:443'
return []string{u.Hostname(), net.JoinHostPort(u.Hostname(), defaultPort)}
}

View file

@ -179,7 +179,7 @@ func TestGetDomainsForURL(t *testing.T) {
tc := tc
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
got := GetDomainsForURL(tc.u)
got := GetDomainsForURL(tc.u, true)
if diff := cmp.Diff(got, tc.want); diff != "" {
t.Errorf("GetDomainsForURL() = %v", diff)
}