mirror of
https://github.com/pomerium/pomerium.git
synced 2025-08-02 08:19:23 +02:00
core/config: add support for stripping the port for matching routes
This commit is contained in:
parent
05e077fe04
commit
31e8383250
13 changed files with 57 additions and 33 deletions
|
@ -163,7 +163,7 @@ func getCheckRequestURL(req *envoy_service_auth_v3.CheckRequest) url.URL {
|
|||
Scheme: h.GetScheme(),
|
||||
Host: h.GetHost(),
|
||||
}
|
||||
u.Host = urlutil.GetDomainsForURL(&u)[0]
|
||||
u.Host = urlutil.GetDomainsForURL(&u, false)[0]
|
||||
// envoy sends the query string as part of the path
|
||||
path := h.GetPath()
|
||||
if idx := strings.Index(path, "?"); idx != -1 {
|
||||
|
|
|
@ -2,6 +2,7 @@ package envoyconfig
|
|||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
envoy_config_accesslog_v3 "github.com/envoyproxy/go-control-plane/envoy/config/accesslog/v3"
|
||||
envoy_config_core_v3 "github.com/envoyproxy/go-control-plane/envoy/config/core/v3"
|
||||
|
@ -23,6 +24,15 @@ func (b *Builder) buildVirtualHost(
|
|||
Domains: []string{host},
|
||||
}
|
||||
|
||||
// if we're stripping the port from incoming requests
|
||||
// and this host doesn't have a port or wildcard in it
|
||||
// then we will add :* to match on any port
|
||||
if options.IsRuntimeFlagSet(config.RuntimeFlagStripFromPort) &&
|
||||
!strings.Contains(host, "*") &&
|
||||
!strings.Contains(host, ":") {
|
||||
vh.Domains = append(vh.Domains, host+":*")
|
||||
}
|
||||
|
||||
// these routes match /.pomerium/... and similar paths
|
||||
rs, err := b.buildPomeriumHTTPRoutes(options, host)
|
||||
if err != nil {
|
||||
|
|
|
@ -669,7 +669,7 @@ func urlsMatchHost(urls []*url.URL, host string) bool {
|
|||
}
|
||||
|
||||
func urlMatchesHost(u *url.URL, host string) bool {
|
||||
for _, h := range urlutil.GetDomainsForURL(u) {
|
||||
for _, h := range urlutil.GetDomainsForURL(u, true) {
|
||||
if h == host {
|
||||
return true
|
||||
}
|
||||
|
|
|
@ -236,8 +236,8 @@ func (b *Builder) buildRoutesForPolicy(
|
|||
var routes []*envoy_config_route_v3.Route
|
||||
if strings.Contains(fromURL.Host, "*") {
|
||||
// we have to match '*.example.com' and '*.example.com:443', so there are two routes
|
||||
for _, host := range urlutil.GetDomainsForURL(fromURL) {
|
||||
route, err := b.buildRouteForPolicyAndMatch(cfg, policy, name, mkRouteMatchForHost(policy, host))
|
||||
for _, host := range urlutil.GetDomainsForURL(fromURL, !cfg.Options.IsRuntimeFlagSet(config.RuntimeFlagStripFromPort)) {
|
||||
route, err := b.buildRouteForPolicyAndMatch(cfg, policy, name, mkRouteMatchForHost(cfg.Options, policy, host))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -507,6 +507,7 @@ func mkRouteMatch(policy *config.Policy) *envoy_config_route_v3.RouteMatch {
|
|||
}
|
||||
|
||||
func mkRouteMatchForHost(
|
||||
options *config.Options,
|
||||
policy *config.Policy,
|
||||
host string,
|
||||
) *envoy_config_route_v3.RouteMatch {
|
||||
|
@ -517,7 +518,7 @@ func mkRouteMatchForHost(
|
|||
StringMatch: &envoy_type_matcher_v3.StringMatcher{
|
||||
MatchPattern: &envoy_type_matcher_v3.StringMatcher_SafeRegex{
|
||||
SafeRegex: &envoy_type_matcher_v3.RegexMatcher{
|
||||
Regex: config.WildcardToRegex(host),
|
||||
Regex: config.WildcardToRegex(host, options.IsRuntimeFlagSet(config.RuntimeFlagStripFromPort)),
|
||||
},
|
||||
},
|
||||
},
|
||||
|
|
|
@ -9,8 +9,8 @@ import (
|
|||
)
|
||||
|
||||
// FromURLMatchesRequestURL returns true if the from URL matches the request URL.
|
||||
func FromURLMatchesRequestURL(fromURL, requestURL *url.URL) bool {
|
||||
for _, domain := range urlutil.GetDomainsForURL(fromURL) {
|
||||
func FromURLMatchesRequestURL(fromURL, requestURL *url.URL, stripPort bool) bool {
|
||||
for _, domain := range urlutil.GetDomainsForURL(fromURL, true) {
|
||||
if domain == requestURL.Host {
|
||||
return true
|
||||
}
|
||||
|
@ -19,7 +19,7 @@ func FromURLMatchesRequestURL(fromURL, requestURL *url.URL) bool {
|
|||
continue
|
||||
}
|
||||
|
||||
reStr := WildcardToRegex(domain)
|
||||
reStr := WildcardToRegex(domain, stripPort)
|
||||
re := regexp.MustCompile(reStr)
|
||||
if re.MatchString(requestURL.Host) {
|
||||
return true
|
||||
|
@ -29,7 +29,7 @@ func FromURLMatchesRequestURL(fromURL, requestURL *url.URL) bool {
|
|||
}
|
||||
|
||||
// WildcardToRegex converts a wildcard string to a regular expression.
|
||||
func WildcardToRegex(wildcard string) string {
|
||||
func WildcardToRegex(wildcard string, stripPort bool) string {
|
||||
var b strings.Builder
|
||||
b.WriteByte('^')
|
||||
for {
|
||||
|
@ -42,6 +42,9 @@ func WildcardToRegex(wildcard string) string {
|
|||
wildcard = wildcard[idx+1:]
|
||||
}
|
||||
b.WriteString(regexp.QuoteMeta(wildcard))
|
||||
if stripPort && !strings.Contains(wildcard, ":") {
|
||||
b.WriteString("(:(.+))?")
|
||||
}
|
||||
b.WriteByte('$')
|
||||
return b.String()
|
||||
}
|
||||
|
|
|
@ -21,10 +21,11 @@ func TestFromURLMatchesRequestURL(t *testing.T) {
|
|||
{"https://from.example.com", "https://to.example.com/some/path", false},
|
||||
{"https://*.example.com", "https://from.example.com/some/path", true},
|
||||
{"https://*.example.com", "https://example.com/some/path", false},
|
||||
{"https://*.example.com", "https://from.example.com:8443/some/path", true},
|
||||
} {
|
||||
fromURL := urlutil.MustParseAndValidateURL(tc.pattern)
|
||||
requestURL := urlutil.MustParseAndValidateURL(tc.input)
|
||||
assert.Equal(t, tc.matches, FromURLMatchesRequestURL(&fromURL, &requestURL),
|
||||
assert.Equal(t, tc.matches, FromURLMatchesRequestURL(&fromURL, &requestURL, true),
|
||||
"from-url: %s\nrequest-url: %s", tc.pattern, tc.input)
|
||||
}
|
||||
}
|
||||
|
@ -32,7 +33,7 @@ func TestFromURLMatchesRequestURL(t *testing.T) {
|
|||
func TestWildcardToRegex(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
re, err := regexp.Compile(WildcardToRegex("*.internal.*.example.com"))
|
||||
re, err := regexp.Compile(WildcardToRegex("*.internal.*.example.com", true))
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, re.MatchString("a.internal.b.example.com"))
|
||||
}
|
||||
|
|
|
@ -59,7 +59,7 @@ func (o *Options) GetIdentityProviderForRequestURL(requestURL string) (*identity
|
|||
|
||||
for _, p := range o.GetAllPolicies() {
|
||||
p := p
|
||||
if p.Matches(*u) {
|
||||
if p.Matches(*u, o.IsRuntimeFlagSet(RuntimeFlagStripFromPort)) {
|
||||
return o.GetIdentityProviderForPolicy(&p)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1216,7 +1216,7 @@ func (o *Options) GetAllRouteableGRPCHosts() ([]string, error) {
|
|||
return nil, err
|
||||
}
|
||||
for _, u := range authorizeURLs {
|
||||
hosts.Add(urlutil.GetDomainsForURL(u)...)
|
||||
hosts.Add(urlutil.GetDomainsForURL(u, true)...)
|
||||
}
|
||||
} else if IsAuthorize(o.Services) {
|
||||
authorizeURLs, err := o.GetInternalAuthorizeURLs()
|
||||
|
@ -1224,7 +1224,7 @@ func (o *Options) GetAllRouteableGRPCHosts() ([]string, error) {
|
|||
return nil, err
|
||||
}
|
||||
for _, u := range authorizeURLs {
|
||||
hosts.Add(urlutil.GetDomainsForURL(u)...)
|
||||
hosts.Add(urlutil.GetDomainsForURL(u, true)...)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -1235,7 +1235,7 @@ func (o *Options) GetAllRouteableGRPCHosts() ([]string, error) {
|
|||
return nil, err
|
||||
}
|
||||
for _, u := range dataBrokerURLs {
|
||||
hosts.Add(urlutil.GetDomainsForURL(u)...)
|
||||
hosts.Add(urlutil.GetDomainsForURL(u, true)...)
|
||||
}
|
||||
} else if IsDataBroker(o.Services) {
|
||||
dataBrokerURLs, err := o.GetInternalDataBrokerURLs()
|
||||
|
@ -1243,7 +1243,7 @@ func (o *Options) GetAllRouteableGRPCHosts() ([]string, error) {
|
|||
return nil, err
|
||||
}
|
||||
for _, u := range dataBrokerURLs {
|
||||
hosts.Add(urlutil.GetDomainsForURL(u)...)
|
||||
hosts.Add(urlutil.GetDomainsForURL(u, true)...)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -1259,7 +1259,7 @@ func (o *Options) GetAllRouteableHTTPHosts() ([]string, error) {
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
hosts.Add(urlutil.GetDomainsForURL(authenticateURL)...)
|
||||
hosts.Add(urlutil.GetDomainsForURL(authenticateURL, !o.IsRuntimeFlagSet(RuntimeFlagStripFromPort))...)
|
||||
}
|
||||
|
||||
if o.AuthenticateURLString != "" {
|
||||
|
@ -1267,7 +1267,7 @@ func (o *Options) GetAllRouteableHTTPHosts() ([]string, error) {
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
hosts.Add(urlutil.GetDomainsForURL(authenticateURL)...)
|
||||
hosts.Add(urlutil.GetDomainsForURL(authenticateURL, !o.IsRuntimeFlagSet(RuntimeFlagStripFromPort))...)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -1279,10 +1279,10 @@ func (o *Options) GetAllRouteableHTTPHosts() ([]string, error) {
|
|||
return nil, err
|
||||
}
|
||||
|
||||
hosts.Add(urlutil.GetDomainsForURL(fromURL)...)
|
||||
hosts.Add(urlutil.GetDomainsForURL(fromURL, !o.IsRuntimeFlagSet(RuntimeFlagStripFromPort))...)
|
||||
if policy.TLSDownstreamServerName != "" {
|
||||
tlsURL := fromURL.ResolveReference(&url.URL{Host: policy.TLSDownstreamServerName})
|
||||
hosts.Add(urlutil.GetDomainsForURL(tlsURL)...)
|
||||
hosts.Add(urlutil.GetDomainsForURL(tlsURL, !o.IsRuntimeFlagSet(RuntimeFlagStripFromPort))...)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -606,14 +606,14 @@ func (p *Policy) String() string {
|
|||
}
|
||||
|
||||
// Matches returns true if the policy would match the given URL.
|
||||
func (p *Policy) Matches(requestURL url.URL) bool {
|
||||
func (p *Policy) Matches(requestURL url.URL, stripPort bool) bool {
|
||||
// an invalid from URL should not match anything
|
||||
fromURL, err := urlutil.ParseAndValidateURL(p.From)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
if !FromURLMatchesRequestURL(fromURL, &requestURL) {
|
||||
if !FromURLMatchesRequestURL(fromURL, &requestURL, stripPort) {
|
||||
return false
|
||||
}
|
||||
|
||||
|
|
|
@ -245,7 +245,7 @@ func TestPolicy_Matches(t *testing.T) {
|
|||
}
|
||||
assert.NoError(t, p.Validate())
|
||||
|
||||
assert.False(t, p.Matches(urlutil.MustParseAndValidateURL(`https://www.example.com/foo/bar`)),
|
||||
assert.False(t, p.Matches(urlutil.MustParseAndValidateURL(`https://www.example.com/foo/bar`), true),
|
||||
"regex should only match full string")
|
||||
})
|
||||
t.Run("issue2952", func(t *testing.T) {
|
||||
|
@ -256,7 +256,7 @@ func TestPolicy_Matches(t *testing.T) {
|
|||
}
|
||||
assert.NoError(t, p.Validate())
|
||||
|
||||
assert.True(t, p.Matches(urlutil.MustParseAndValidateURL(`https://www.example.com/foo/bar/0`)))
|
||||
assert.True(t, p.Matches(urlutil.MustParseAndValidateURL(`https://www.example.com/foo/bar/0`), true))
|
||||
})
|
||||
t.Run("issue2592-test2", func(t *testing.T) {
|
||||
p := &Policy{
|
||||
|
@ -266,8 +266,8 @@ func TestPolicy_Matches(t *testing.T) {
|
|||
}
|
||||
assert.NoError(t, p.Validate())
|
||||
|
||||
assert.True(t, p.Matches(urlutil.MustParseAndValidateURL(`https://www.example.com/admin/foo`)))
|
||||
assert.True(t, p.Matches(urlutil.MustParseAndValidateURL(`https://www.example.com/admin/bar`)))
|
||||
assert.True(t, p.Matches(urlutil.MustParseAndValidateURL(`https://www.example.com/admin/foo`), true))
|
||||
assert.True(t, p.Matches(urlutil.MustParseAndValidateURL(`https://www.example.com/admin/bar`), true))
|
||||
})
|
||||
t.Run("tcp", func(t *testing.T) {
|
||||
p := &Policy{
|
||||
|
@ -276,7 +276,7 @@ func TestPolicy_Matches(t *testing.T) {
|
|||
}
|
||||
assert.NoError(t, p.Validate())
|
||||
|
||||
assert.True(t, p.Matches(urlutil.MustParseAndValidateURL(`https://tcp.example.com:6379`)))
|
||||
assert.True(t, p.Matches(urlutil.MustParseAndValidateURL(`https://tcp.example.com:6379`), true))
|
||||
})
|
||||
}
|
||||
|
||||
|
|
|
@ -2,8 +2,13 @@ package config
|
|||
|
||||
import "golang.org/x/exp/maps"
|
||||
|
||||
// RuntimeFlagGRPCDatabrokerKeepalive enables gRPC keepalive to the databroker service
|
||||
var RuntimeFlagGRPCDatabrokerKeepalive = runtimeFlag("grpc_databroker_keepalive", false)
|
||||
var (
|
||||
// RuntimeFlagGRPCDatabrokerKeepalive enables gRPC keepalive to the databroker service
|
||||
RuntimeFlagGRPCDatabrokerKeepalive = runtimeFlag("grpc_databroker_keepalive", false)
|
||||
|
||||
// RuntimeFlagStripFromPort enables stripping the port from incoming requests for matching from addresses
|
||||
RuntimeFlagStripFromPort = runtimeFlag("strip_from_port", true)
|
||||
)
|
||||
|
||||
// RuntimeFlag is a runtime flag that can flip on/off certain features
|
||||
type RuntimeFlag string
|
||||
|
|
|
@ -106,9 +106,9 @@ func GetServerNamesForURL(u *url.URL) []string {
|
|||
|
||||
// GetDomainsForURL returns the available domains for given url.
|
||||
//
|
||||
// For standard HTTP (80)/HTTPS (443) ports, it returns `example.com` and `example.com:<port>`.
|
||||
// Otherwise, return the URL.Host value.
|
||||
func GetDomainsForURL(u *url.URL) []string {
|
||||
// For standard HTTP (80)/HTTPS (443) ports, it returns `example.com` and `example.com:<port>`,
|
||||
// if includeDefaultPort is set. Otherwise, return the URL.Host value.
|
||||
func GetDomainsForURL(u *url.URL, includeDefaultPort bool) []string {
|
||||
if u == nil {
|
||||
return nil
|
||||
}
|
||||
|
@ -141,6 +141,10 @@ func GetDomainsForURL(u *url.URL) []string {
|
|||
}
|
||||
}
|
||||
|
||||
if !includeDefaultPort {
|
||||
return []string{u.Hostname()}
|
||||
}
|
||||
|
||||
// for everything else we return two routes: 'example.com' and 'example.com:443'
|
||||
return []string{u.Hostname(), net.JoinHostPort(u.Hostname(), defaultPort)}
|
||||
}
|
||||
|
|
|
@ -179,7 +179,7 @@ func TestGetDomainsForURL(t *testing.T) {
|
|||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
got := GetDomainsForURL(tc.u)
|
||||
got := GetDomainsForURL(tc.u, true)
|
||||
if diff := cmp.Diff(got, tc.want); diff != "" {
|
||||
t.Errorf("GetDomainsForURL() = %v", diff)
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue