core/config: add support for stripping the port for matching routes

This commit is contained in:
Caleb Doxsey 2024-04-24 16:36:48 -06:00
parent 05e077fe04
commit 31e8383250
13 changed files with 57 additions and 33 deletions

View file

@ -163,7 +163,7 @@ func getCheckRequestURL(req *envoy_service_auth_v3.CheckRequest) url.URL {
Scheme: h.GetScheme(), Scheme: h.GetScheme(),
Host: h.GetHost(), Host: h.GetHost(),
} }
u.Host = urlutil.GetDomainsForURL(&u)[0] u.Host = urlutil.GetDomainsForURL(&u, false)[0]
// envoy sends the query string as part of the path // envoy sends the query string as part of the path
path := h.GetPath() path := h.GetPath()
if idx := strings.Index(path, "?"); idx != -1 { if idx := strings.Index(path, "?"); idx != -1 {

View file

@ -2,6 +2,7 @@ package envoyconfig
import ( import (
"fmt" "fmt"
"strings"
envoy_config_accesslog_v3 "github.com/envoyproxy/go-control-plane/envoy/config/accesslog/v3" envoy_config_accesslog_v3 "github.com/envoyproxy/go-control-plane/envoy/config/accesslog/v3"
envoy_config_core_v3 "github.com/envoyproxy/go-control-plane/envoy/config/core/v3" envoy_config_core_v3 "github.com/envoyproxy/go-control-plane/envoy/config/core/v3"
@ -23,6 +24,15 @@ func (b *Builder) buildVirtualHost(
Domains: []string{host}, Domains: []string{host},
} }
// if we're stripping the port from incoming requests
// and this host doesn't have a port or wildcard in it
// then we will add :* to match on any port
if options.IsRuntimeFlagSet(config.RuntimeFlagStripFromPort) &&
!strings.Contains(host, "*") &&
!strings.Contains(host, ":") {
vh.Domains = append(vh.Domains, host+":*")
}
// these routes match /.pomerium/... and similar paths // these routes match /.pomerium/... and similar paths
rs, err := b.buildPomeriumHTTPRoutes(options, host) rs, err := b.buildPomeriumHTTPRoutes(options, host)
if err != nil { if err != nil {

View file

@ -669,7 +669,7 @@ func urlsMatchHost(urls []*url.URL, host string) bool {
} }
func urlMatchesHost(u *url.URL, host string) bool { func urlMatchesHost(u *url.URL, host string) bool {
for _, h := range urlutil.GetDomainsForURL(u) { for _, h := range urlutil.GetDomainsForURL(u, true) {
if h == host { if h == host {
return true return true
} }

View file

@ -236,8 +236,8 @@ func (b *Builder) buildRoutesForPolicy(
var routes []*envoy_config_route_v3.Route var routes []*envoy_config_route_v3.Route
if strings.Contains(fromURL.Host, "*") { if strings.Contains(fromURL.Host, "*") {
// we have to match '*.example.com' and '*.example.com:443', so there are two routes // we have to match '*.example.com' and '*.example.com:443', so there are two routes
for _, host := range urlutil.GetDomainsForURL(fromURL) { for _, host := range urlutil.GetDomainsForURL(fromURL, !cfg.Options.IsRuntimeFlagSet(config.RuntimeFlagStripFromPort)) {
route, err := b.buildRouteForPolicyAndMatch(cfg, policy, name, mkRouteMatchForHost(policy, host)) route, err := b.buildRouteForPolicyAndMatch(cfg, policy, name, mkRouteMatchForHost(cfg.Options, policy, host))
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -507,6 +507,7 @@ func mkRouteMatch(policy *config.Policy) *envoy_config_route_v3.RouteMatch {
} }
func mkRouteMatchForHost( func mkRouteMatchForHost(
options *config.Options,
policy *config.Policy, policy *config.Policy,
host string, host string,
) *envoy_config_route_v3.RouteMatch { ) *envoy_config_route_v3.RouteMatch {
@ -517,7 +518,7 @@ func mkRouteMatchForHost(
StringMatch: &envoy_type_matcher_v3.StringMatcher{ StringMatch: &envoy_type_matcher_v3.StringMatcher{
MatchPattern: &envoy_type_matcher_v3.StringMatcher_SafeRegex{ MatchPattern: &envoy_type_matcher_v3.StringMatcher_SafeRegex{
SafeRegex: &envoy_type_matcher_v3.RegexMatcher{ SafeRegex: &envoy_type_matcher_v3.RegexMatcher{
Regex: config.WildcardToRegex(host), Regex: config.WildcardToRegex(host, options.IsRuntimeFlagSet(config.RuntimeFlagStripFromPort)),
}, },
}, },
}, },

View file

@ -9,8 +9,8 @@ import (
) )
// FromURLMatchesRequestURL returns true if the from URL matches the request URL. // FromURLMatchesRequestURL returns true if the from URL matches the request URL.
func FromURLMatchesRequestURL(fromURL, requestURL *url.URL) bool { func FromURLMatchesRequestURL(fromURL, requestURL *url.URL, stripPort bool) bool {
for _, domain := range urlutil.GetDomainsForURL(fromURL) { for _, domain := range urlutil.GetDomainsForURL(fromURL, true) {
if domain == requestURL.Host { if domain == requestURL.Host {
return true return true
} }
@ -19,7 +19,7 @@ func FromURLMatchesRequestURL(fromURL, requestURL *url.URL) bool {
continue continue
} }
reStr := WildcardToRegex(domain) reStr := WildcardToRegex(domain, stripPort)
re := regexp.MustCompile(reStr) re := regexp.MustCompile(reStr)
if re.MatchString(requestURL.Host) { if re.MatchString(requestURL.Host) {
return true return true
@ -29,7 +29,7 @@ func FromURLMatchesRequestURL(fromURL, requestURL *url.URL) bool {
} }
// WildcardToRegex converts a wildcard string to a regular expression. // WildcardToRegex converts a wildcard string to a regular expression.
func WildcardToRegex(wildcard string) string { func WildcardToRegex(wildcard string, stripPort bool) string {
var b strings.Builder var b strings.Builder
b.WriteByte('^') b.WriteByte('^')
for { for {
@ -42,6 +42,9 @@ func WildcardToRegex(wildcard string) string {
wildcard = wildcard[idx+1:] wildcard = wildcard[idx+1:]
} }
b.WriteString(regexp.QuoteMeta(wildcard)) b.WriteString(regexp.QuoteMeta(wildcard))
if stripPort && !strings.Contains(wildcard, ":") {
b.WriteString("(:(.+))?")
}
b.WriteByte('$') b.WriteByte('$')
return b.String() return b.String()
} }

View file

@ -21,10 +21,11 @@ func TestFromURLMatchesRequestURL(t *testing.T) {
{"https://from.example.com", "https://to.example.com/some/path", false}, {"https://from.example.com", "https://to.example.com/some/path", false},
{"https://*.example.com", "https://from.example.com/some/path", true}, {"https://*.example.com", "https://from.example.com/some/path", true},
{"https://*.example.com", "https://example.com/some/path", false}, {"https://*.example.com", "https://example.com/some/path", false},
{"https://*.example.com", "https://from.example.com:8443/some/path", true},
} { } {
fromURL := urlutil.MustParseAndValidateURL(tc.pattern) fromURL := urlutil.MustParseAndValidateURL(tc.pattern)
requestURL := urlutil.MustParseAndValidateURL(tc.input) requestURL := urlutil.MustParseAndValidateURL(tc.input)
assert.Equal(t, tc.matches, FromURLMatchesRequestURL(&fromURL, &requestURL), assert.Equal(t, tc.matches, FromURLMatchesRequestURL(&fromURL, &requestURL, true),
"from-url: %s\nrequest-url: %s", tc.pattern, tc.input) "from-url: %s\nrequest-url: %s", tc.pattern, tc.input)
} }
} }
@ -32,7 +33,7 @@ func TestFromURLMatchesRequestURL(t *testing.T) {
func TestWildcardToRegex(t *testing.T) { func TestWildcardToRegex(t *testing.T) {
t.Parallel() t.Parallel()
re, err := regexp.Compile(WildcardToRegex("*.internal.*.example.com")) re, err := regexp.Compile(WildcardToRegex("*.internal.*.example.com", true))
assert.NoError(t, err) assert.NoError(t, err)
assert.True(t, re.MatchString("a.internal.b.example.com")) assert.True(t, re.MatchString("a.internal.b.example.com"))
} }

View file

@ -59,7 +59,7 @@ func (o *Options) GetIdentityProviderForRequestURL(requestURL string) (*identity
for _, p := range o.GetAllPolicies() { for _, p := range o.GetAllPolicies() {
p := p p := p
if p.Matches(*u) { if p.Matches(*u, o.IsRuntimeFlagSet(RuntimeFlagStripFromPort)) {
return o.GetIdentityProviderForPolicy(&p) return o.GetIdentityProviderForPolicy(&p)
} }
} }

View file

@ -1216,7 +1216,7 @@ func (o *Options) GetAllRouteableGRPCHosts() ([]string, error) {
return nil, err return nil, err
} }
for _, u := range authorizeURLs { for _, u := range authorizeURLs {
hosts.Add(urlutil.GetDomainsForURL(u)...) hosts.Add(urlutil.GetDomainsForURL(u, true)...)
} }
} else if IsAuthorize(o.Services) { } else if IsAuthorize(o.Services) {
authorizeURLs, err := o.GetInternalAuthorizeURLs() authorizeURLs, err := o.GetInternalAuthorizeURLs()
@ -1224,7 +1224,7 @@ func (o *Options) GetAllRouteableGRPCHosts() ([]string, error) {
return nil, err return nil, err
} }
for _, u := range authorizeURLs { for _, u := range authorizeURLs {
hosts.Add(urlutil.GetDomainsForURL(u)...) hosts.Add(urlutil.GetDomainsForURL(u, true)...)
} }
} }
@ -1235,7 +1235,7 @@ func (o *Options) GetAllRouteableGRPCHosts() ([]string, error) {
return nil, err return nil, err
} }
for _, u := range dataBrokerURLs { for _, u := range dataBrokerURLs {
hosts.Add(urlutil.GetDomainsForURL(u)...) hosts.Add(urlutil.GetDomainsForURL(u, true)...)
} }
} else if IsDataBroker(o.Services) { } else if IsDataBroker(o.Services) {
dataBrokerURLs, err := o.GetInternalDataBrokerURLs() dataBrokerURLs, err := o.GetInternalDataBrokerURLs()
@ -1243,7 +1243,7 @@ func (o *Options) GetAllRouteableGRPCHosts() ([]string, error) {
return nil, err return nil, err
} }
for _, u := range dataBrokerURLs { for _, u := range dataBrokerURLs {
hosts.Add(urlutil.GetDomainsForURL(u)...) hosts.Add(urlutil.GetDomainsForURL(u, true)...)
} }
} }
@ -1259,7 +1259,7 @@ func (o *Options) GetAllRouteableHTTPHosts() ([]string, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
hosts.Add(urlutil.GetDomainsForURL(authenticateURL)...) hosts.Add(urlutil.GetDomainsForURL(authenticateURL, !o.IsRuntimeFlagSet(RuntimeFlagStripFromPort))...)
} }
if o.AuthenticateURLString != "" { if o.AuthenticateURLString != "" {
@ -1267,7 +1267,7 @@ func (o *Options) GetAllRouteableHTTPHosts() ([]string, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
hosts.Add(urlutil.GetDomainsForURL(authenticateURL)...) hosts.Add(urlutil.GetDomainsForURL(authenticateURL, !o.IsRuntimeFlagSet(RuntimeFlagStripFromPort))...)
} }
} }
@ -1279,10 +1279,10 @@ func (o *Options) GetAllRouteableHTTPHosts() ([]string, error) {
return nil, err return nil, err
} }
hosts.Add(urlutil.GetDomainsForURL(fromURL)...) hosts.Add(urlutil.GetDomainsForURL(fromURL, !o.IsRuntimeFlagSet(RuntimeFlagStripFromPort))...)
if policy.TLSDownstreamServerName != "" { if policy.TLSDownstreamServerName != "" {
tlsURL := fromURL.ResolveReference(&url.URL{Host: policy.TLSDownstreamServerName}) tlsURL := fromURL.ResolveReference(&url.URL{Host: policy.TLSDownstreamServerName})
hosts.Add(urlutil.GetDomainsForURL(tlsURL)...) hosts.Add(urlutil.GetDomainsForURL(tlsURL, !o.IsRuntimeFlagSet(RuntimeFlagStripFromPort))...)
} }
} }
} }

View file

@ -606,14 +606,14 @@ func (p *Policy) String() string {
} }
// Matches returns true if the policy would match the given URL. // Matches returns true if the policy would match the given URL.
func (p *Policy) Matches(requestURL url.URL) bool { func (p *Policy) Matches(requestURL url.URL, stripPort bool) bool {
// an invalid from URL should not match anything // an invalid from URL should not match anything
fromURL, err := urlutil.ParseAndValidateURL(p.From) fromURL, err := urlutil.ParseAndValidateURL(p.From)
if err != nil { if err != nil {
return false return false
} }
if !FromURLMatchesRequestURL(fromURL, &requestURL) { if !FromURLMatchesRequestURL(fromURL, &requestURL, stripPort) {
return false return false
} }

View file

@ -245,7 +245,7 @@ func TestPolicy_Matches(t *testing.T) {
} }
assert.NoError(t, p.Validate()) assert.NoError(t, p.Validate())
assert.False(t, p.Matches(urlutil.MustParseAndValidateURL(`https://www.example.com/foo/bar`)), assert.False(t, p.Matches(urlutil.MustParseAndValidateURL(`https://www.example.com/foo/bar`), true),
"regex should only match full string") "regex should only match full string")
}) })
t.Run("issue2952", func(t *testing.T) { t.Run("issue2952", func(t *testing.T) {
@ -256,7 +256,7 @@ func TestPolicy_Matches(t *testing.T) {
} }
assert.NoError(t, p.Validate()) assert.NoError(t, p.Validate())
assert.True(t, p.Matches(urlutil.MustParseAndValidateURL(`https://www.example.com/foo/bar/0`))) assert.True(t, p.Matches(urlutil.MustParseAndValidateURL(`https://www.example.com/foo/bar/0`), true))
}) })
t.Run("issue2592-test2", func(t *testing.T) { t.Run("issue2592-test2", func(t *testing.T) {
p := &Policy{ p := &Policy{
@ -266,8 +266,8 @@ func TestPolicy_Matches(t *testing.T) {
} }
assert.NoError(t, p.Validate()) assert.NoError(t, p.Validate())
assert.True(t, p.Matches(urlutil.MustParseAndValidateURL(`https://www.example.com/admin/foo`))) assert.True(t, p.Matches(urlutil.MustParseAndValidateURL(`https://www.example.com/admin/foo`), true))
assert.True(t, p.Matches(urlutil.MustParseAndValidateURL(`https://www.example.com/admin/bar`))) assert.True(t, p.Matches(urlutil.MustParseAndValidateURL(`https://www.example.com/admin/bar`), true))
}) })
t.Run("tcp", func(t *testing.T) { t.Run("tcp", func(t *testing.T) {
p := &Policy{ p := &Policy{
@ -276,7 +276,7 @@ func TestPolicy_Matches(t *testing.T) {
} }
assert.NoError(t, p.Validate()) assert.NoError(t, p.Validate())
assert.True(t, p.Matches(urlutil.MustParseAndValidateURL(`https://tcp.example.com:6379`))) assert.True(t, p.Matches(urlutil.MustParseAndValidateURL(`https://tcp.example.com:6379`), true))
}) })
} }

View file

@ -2,8 +2,13 @@ package config
import "golang.org/x/exp/maps" import "golang.org/x/exp/maps"
// RuntimeFlagGRPCDatabrokerKeepalive enables gRPC keepalive to the databroker service var (
var RuntimeFlagGRPCDatabrokerKeepalive = runtimeFlag("grpc_databroker_keepalive", false) // RuntimeFlagGRPCDatabrokerKeepalive enables gRPC keepalive to the databroker service
RuntimeFlagGRPCDatabrokerKeepalive = runtimeFlag("grpc_databroker_keepalive", false)
// RuntimeFlagStripFromPort enables stripping the port from incoming requests for matching from addresses
RuntimeFlagStripFromPort = runtimeFlag("strip_from_port", true)
)
// RuntimeFlag is a runtime flag that can flip on/off certain features // RuntimeFlag is a runtime flag that can flip on/off certain features
type RuntimeFlag string type RuntimeFlag string

View file

@ -106,9 +106,9 @@ func GetServerNamesForURL(u *url.URL) []string {
// GetDomainsForURL returns the available domains for given url. // GetDomainsForURL returns the available domains for given url.
// //
// For standard HTTP (80)/HTTPS (443) ports, it returns `example.com` and `example.com:<port>`. // For standard HTTP (80)/HTTPS (443) ports, it returns `example.com` and `example.com:<port>`,
// Otherwise, return the URL.Host value. // if includeDefaultPort is set. Otherwise, return the URL.Host value.
func GetDomainsForURL(u *url.URL) []string { func GetDomainsForURL(u *url.URL, includeDefaultPort bool) []string {
if u == nil { if u == nil {
return nil return nil
} }
@ -141,6 +141,10 @@ func GetDomainsForURL(u *url.URL) []string {
} }
} }
if !includeDefaultPort {
return []string{u.Hostname()}
}
// for everything else we return two routes: 'example.com' and 'example.com:443' // for everything else we return two routes: 'example.com' and 'example.com:443'
return []string{u.Hostname(), net.JoinHostPort(u.Hostname(), defaultPort)} return []string{u.Hostname(), net.JoinHostPort(u.Hostname(), defaultPort)}
} }

View file

@ -179,7 +179,7 @@ func TestGetDomainsForURL(t *testing.T) {
tc := tc tc := tc
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
t.Parallel() t.Parallel()
got := GetDomainsForURL(tc.u) got := GetDomainsForURL(tc.u, true)
if diff := cmp.Diff(got, tc.want); diff != "" { if diff := cmp.Diff(got, tc.want); diff != "" {
t.Errorf("GetDomainsForURL() = %v", diff) t.Errorf("GetDomainsForURL() = %v", diff)
} }