config: add support for extended TCP route URLs (#3845)

* config: add support for extended TCP route URLs

* nevermind, add duplicate names
This commit is contained in:
Caleb Doxsey 2022-12-27 12:50:33 -07:00 committed by GitHub
parent 67e12101fa
commit 271b0787a8
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
8 changed files with 182 additions and 51 deletions

View file

@ -151,7 +151,7 @@ func getCheckRequestURL(req *envoy_service_auth_v3.CheckRequest) url.URL {
Scheme: h.GetScheme(), Scheme: h.GetScheme(),
Host: h.GetHost(), Host: h.GetHost(),
} }
u.Host = urlutil.GetDomainsForURL(u)[0] u.Host = urlutil.GetDomainsForURL(&u)[0]
// envoy sends the query string as part of the path // envoy sends the query string as part of the path
path := h.GetPath() path := h.GetPath()
if idx := strings.Index(path, "?"); idx != -1 { if idx := strings.Index(path, "?"); idx != -1 {

View file

@ -620,26 +620,28 @@ func getAllServerNames(cfg *config.Config, addr string) ([]string, error) {
serverNames := sets.NewSorted[string]() serverNames := sets.NewSorted[string]()
serverNames.Add("*") serverNames.Add("*")
routeableHosts, err := getAllRouteableHosts(cfg.Options, addr)
if err != nil {
return nil, err
}
for _, hp := range routeableHosts {
if h, _, err := net.SplitHostPort(hp); err == nil {
serverNames.Add(h)
} else {
serverNames.Add(hp)
}
}
certs, err := cfg.AllCertificates() certs, err := cfg.AllCertificates()
if err != nil { if err != nil {
return nil, err return nil, err
} }
for i := range certs { for i := range certs {
for _, domain := range cryptutil.GetCertificateServerNames(&certs[i]) { serverNames.Add(cryptutil.GetCertificateServerNames(&certs[i])...)
serverNames.Add(domain) }
if addr == cfg.Options.Addr {
sns, err := cfg.Options.GetAllRouteableHTTPServerNames()
if err != nil {
return nil, err
} }
serverNames.Add(sns...)
}
if addr == cfg.Options.GetGRPCAddr() {
sns, err := cfg.Options.GetAllRouteableGRPCServerNames()
if err != nil {
return nil, err
}
serverNames.Add(sns...)
} }
return serverNames.ToSlice(), nil return serverNames.ToSlice(), nil
@ -655,30 +657,12 @@ func urlsMatchHost(urls []*url.URL, host string) bool {
} }
func urlMatchesHost(u *url.URL, host string) bool { func urlMatchesHost(u *url.URL, host string) bool {
if u == nil { for _, h := range urlutil.GetDomainsForURL(u) {
return false if h == host {
return true
}
} }
return false
var defaultPort string
if u.Scheme == "http" {
defaultPort = "80"
} else {
defaultPort = "443"
}
h1, p1, err := net.SplitHostPort(u.Host)
if err != nil {
h1 = u.Host
p1 = defaultPort
}
h2, p2, err := net.SplitHostPort(host)
if err != nil {
h2 = host
p2 = defaultPort
}
return h1 == h2 && p1 == p2
} }
func getPoliciesForServerName(options *config.Options, serverName string) []config.Policy { func getPoliciesForServerName(options *config.Options, serverName string) []config.Policy {

View file

@ -984,6 +984,7 @@ func Test_getAllDomains(t *testing.T) {
{Source: &config.StringURL{URL: mustParseURL(t, "http://a.example.com")}}, {Source: &config.StringURL{URL: mustParseURL(t, "http://a.example.com")}},
{Source: &config.StringURL{URL: mustParseURL(t, "https://b.example.com")}}, {Source: &config.StringURL{URL: mustParseURL(t, "https://b.example.com")}},
{Source: &config.StringURL{URL: mustParseURL(t, "https://c.example.com")}}, {Source: &config.StringURL{URL: mustParseURL(t, "https://c.example.com")}},
{Source: &config.StringURL{URL: mustParseURL(t, "https://d.unknown.example.com")}},
}, },
Cert: base64.StdEncoding.EncodeToString(certPEM), Cert: base64.StdEncoding.EncodeToString(certPEM),
Key: base64.StdEncoding.EncodeToString(keyPEM), Key: base64.StdEncoding.EncodeToString(keyPEM),
@ -1001,6 +1002,8 @@ func Test_getAllDomains(t *testing.T) {
"b.example.com:443", "b.example.com:443",
"c.example.com", "c.example.com",
"c.example.com:443", "c.example.com:443",
"d.unknown.example.com",
"d.unknown.example.com:443",
} }
assert.Equal(t, expect, actual) assert.Equal(t, expect, actual)
}) })
@ -1029,6 +1032,8 @@ func Test_getAllDomains(t *testing.T) {
"c.example.com", "c.example.com",
"c.example.com:443", "c.example.com:443",
"cache.example.com:9001", "cache.example.com:9001",
"d.unknown.example.com",
"d.unknown.example.com:443",
} }
assert.Equal(t, expect, actual) assert.Equal(t, expect, actual)
}) })
@ -1044,6 +1049,7 @@ func Test_getAllDomains(t *testing.T) {
"authenticate.example.com", "authenticate.example.com",
"b.example.com", "b.example.com",
"c.example.com", "c.example.com",
"d.unknown.example.com",
} }
assert.Equal(t, expect, actual) assert.Equal(t, expect, actual)
}) })

View file

@ -1026,7 +1026,7 @@ func (o *Options) GetAllRouteableGRPCHosts() ([]string, error) {
return nil, err return nil, err
} }
for _, u := range authorizeURLs { for _, u := range authorizeURLs {
hosts.Add(urlutil.GetDomainsForURL(*u)...) hosts.Add(urlutil.GetDomainsForURL(u)...)
} }
} else if IsAuthorize(o.Services) { } else if IsAuthorize(o.Services) {
authorizeURLs, err := o.GetInternalAuthorizeURLs() authorizeURLs, err := o.GetInternalAuthorizeURLs()
@ -1034,7 +1034,7 @@ func (o *Options) GetAllRouteableGRPCHosts() ([]string, error) {
return nil, err return nil, err
} }
for _, u := range authorizeURLs { for _, u := range authorizeURLs {
hosts.Add(urlutil.GetDomainsForURL(*u)...) hosts.Add(urlutil.GetDomainsForURL(u)...)
} }
} }
@ -1045,7 +1045,7 @@ func (o *Options) GetAllRouteableGRPCHosts() ([]string, error) {
return nil, err return nil, err
} }
for _, u := range dataBrokerURLs { for _, u := range dataBrokerURLs {
hosts.Add(urlutil.GetDomainsForURL(*u)...) hosts.Add(urlutil.GetDomainsForURL(u)...)
} }
} else if IsDataBroker(o.Services) { } else if IsDataBroker(o.Services) {
dataBrokerURLs, err := o.GetInternalDataBrokerURLs() dataBrokerURLs, err := o.GetInternalDataBrokerURLs()
@ -1053,7 +1053,52 @@ func (o *Options) GetAllRouteableGRPCHosts() ([]string, error) {
return nil, err return nil, err
} }
for _, u := range dataBrokerURLs { for _, u := range dataBrokerURLs {
hosts.Add(urlutil.GetDomainsForURL(*u)...) hosts.Add(urlutil.GetDomainsForURL(u)...)
}
}
return hosts.ToSlice(), nil
}
// GetAllRouteableGRPCServerNames returns all the possible gRPC server names handled by the Pomerium options.
func (o *Options) GetAllRouteableGRPCServerNames() ([]string, error) {
hosts := sets.NewSorted[string]()
// authorize urls
if IsAll(o.Services) {
authorizeURLs, err := o.GetAuthorizeURLs()
if err != nil {
return nil, err
}
for _, u := range authorizeURLs {
hosts.Add(urlutil.GetServerNamesForURL(u)...)
}
} else if IsAuthorize(o.Services) {
authorizeURLs, err := o.GetInternalAuthorizeURLs()
if err != nil {
return nil, err
}
for _, u := range authorizeURLs {
hosts.Add(urlutil.GetServerNamesForURL(u)...)
}
}
// databroker urls
if IsAll(o.Services) {
dataBrokerURLs, err := o.GetDataBrokerURLs()
if err != nil {
return nil, err
}
for _, u := range dataBrokerURLs {
hosts.Add(urlutil.GetServerNamesForURL(u)...)
}
} else if IsDataBroker(o.Services) {
dataBrokerURLs, err := o.GetInternalDataBrokerURLs()
if err != nil {
return nil, err
}
for _, u := range dataBrokerURLs {
hosts.Add(urlutil.GetServerNamesForURL(u)...)
} }
} }
@ -1068,22 +1113,22 @@ func (o *Options) GetAllRouteableHTTPHosts() ([]string, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
hosts.Add(urlutil.GetDomainsForURL(*authenticateURL)...) hosts.Add(urlutil.GetDomainsForURL(authenticateURL)...)
authenticateURL, err = o.GetAuthenticateURL() authenticateURL, err = o.GetAuthenticateURL()
if err != nil { if err != nil {
return nil, err return nil, err
} }
hosts.Add(urlutil.GetDomainsForURL(*authenticateURL)...) hosts.Add(urlutil.GetDomainsForURL(authenticateURL)...)
} }
// policy urls // policy urls
if IsProxy(o.Services) { if IsProxy(o.Services) {
for _, policy := range o.GetAllPolicies() { for _, policy := range o.GetAllPolicies() {
hosts.Add(urlutil.GetDomainsForURL(*policy.Source.URL)...) hosts.Add(urlutil.GetDomainsForURL(policy.Source.URL)...)
if policy.TLSDownstreamServerName != "" { if policy.TLSDownstreamServerName != "" {
tlsURL := policy.Source.URL.ResolveReference(&url.URL{Host: policy.TLSDownstreamServerName}) tlsURL := policy.Source.URL.ResolveReference(&url.URL{Host: policy.TLSDownstreamServerName})
hosts.Add(urlutil.GetDomainsForURL(*tlsURL)...) hosts.Add(urlutil.GetDomainsForURL(tlsURL)...)
} }
} }
} }
@ -1091,6 +1136,37 @@ func (o *Options) GetAllRouteableHTTPHosts() ([]string, error) {
return hosts.ToSlice(), nil return hosts.ToSlice(), nil
} }
// GetAllRouteableHTTPServerNames returns all the possible HTTP server names handled by the Pomerium options.
func (o *Options) GetAllRouteableHTTPServerNames() ([]string, error) {
serverNames := sets.NewSorted[string]()
if IsAuthenticate(o.Services) {
authenticateURL, err := o.GetInternalAuthenticateURL()
if err != nil {
return nil, err
}
serverNames.Add(urlutil.GetServerNamesForURL(authenticateURL)...)
authenticateURL, err = o.GetAuthenticateURL()
if err != nil {
return nil, err
}
serverNames.Add(urlutil.GetServerNamesForURL(authenticateURL)...)
}
// policy urls
if IsProxy(o.Services) {
for _, policy := range o.GetAllPolicies() {
serverNames.Add(urlutil.GetServerNamesForURL(policy.Source.URL)...)
if policy.TLSDownstreamServerName != "" {
tlsURL := policy.Source.URL.ResolveReference(&url.URL{Host: policy.TLSDownstreamServerName})
serverNames.Add(urlutil.GetServerNamesForURL(tlsURL)...)
}
}
}
return serverNames.ToSlice(), nil
}
// GetClientSecret gets the client secret. // GetClientSecret gets the client secret.
func (o *Options) GetClientSecret() (string, error) { func (o *Options) GetClientSecret() (string, error) {
if o == nil { if o == nil {

View file

@ -599,7 +599,12 @@ func (p *Policy) Matches(requestURL url.URL) bool {
return false return false
} }
if p.Source.Host != requestURL.Host { // make sure one of the host domains matches the incoming url
found := false
for _, host := range urlutil.GetDomainsForURL(p.Source.URL) {
found = found || host == requestURL.Host
}
if !found {
return false return false
} }

View file

@ -269,4 +269,13 @@ func TestPolicy_Matches(t *testing.T) {
assert.True(t, p.Matches(urlutil.MustParseAndValidateURL(`https://www.example.com/admin/foo`))) assert.True(t, p.Matches(urlutil.MustParseAndValidateURL(`https://www.example.com/admin/foo`)))
assert.True(t, p.Matches(urlutil.MustParseAndValidateURL(`https://www.example.com/admin/bar`))) assert.True(t, p.Matches(urlutil.MustParseAndValidateURL(`https://www.example.com/admin/bar`)))
}) })
t.Run("tcp", func(t *testing.T) {
p := &Policy{
From: "tcp+https://proxy.example.com/redis.example.com:6379",
To: mustParseWeightedURLs(t, "tcp://localhost:6379"),
}
assert.NoError(t, p.Validate())
assert.True(t, p.Matches(urlutil.MustParseAndValidateURL(`https://redis.example.com:6379`)))
})
} }

View file

@ -94,13 +94,37 @@ func GetAbsoluteURL(r *http.Request) *url.URL {
return u return u
} }
// GetServerNamesForURL returns the TLS server names for the given URL. The server name is the
// URL hostname.
func GetServerNamesForURL(u *url.URL) []string {
if u == nil {
return nil
}
return []string{u.Hostname()}
}
// GetDomainsForURL returns the available domains for given url. // GetDomainsForURL returns the available domains for given url.
// //
// For standard HTTP (80)/HTTPS (443) ports, it returns `example.com` and `example.com:<port>`. // For standard HTTP (80)/HTTPS (443) ports, it returns `example.com` and `example.com:<port>`.
// Otherwise, return the URL.Host value. // Otherwise, return the URL.Host value.
func GetDomainsForURL(u url.URL) []string { func GetDomainsForURL(u *url.URL) []string {
if IsTCP(&u) { if u == nil {
return []string{u.Host} return nil
}
// tcp+https://ssh.example.com:22
// => ssh.example.com:22
// tcp+https://proxy.example.com/ssh.example.com:22
// => ssh.example.com:22
if strings.HasPrefix(u.Scheme, "tcp+") {
hosts := strings.Split(u.Path, "/")[1:]
// if there are no domains in the path part of the URL, use the host
if len(hosts) == 0 {
return []string{u.Host}
}
// otherwise use the path parts of the URL as the hosts
return hosts
} }
var defaultPort string var defaultPort string

View file

@ -136,6 +136,31 @@ func TestGetAbsoluteURL(t *testing.T) {
} }
} }
func TestGetServerNamesForURL(t *testing.T) {
t.Parallel()
for _, tc := range []struct {
name string
u *url.URL
want []string
}{
{"http", &url.URL{Scheme: "http", Host: "example.com"}, []string{"example.com"}},
{"http scheme with host contain 443", &url.URL{Scheme: "http", Host: "example.com:443"}, []string{"example.com"}},
{"https", &url.URL{Scheme: "https", Host: "example.com"}, []string{"example.com"}},
{"Host contains other port", &url.URL{Scheme: "https", Host: "example.com:1234"}, []string{"example.com"}},
{"tcp", &url.URL{Scheme: "tcp+https", Host: "example.com:1234"}, []string{"example.com"}},
{"tcp with path", &url.URL{Scheme: "tcp+https", Host: "proxy.example.com", Path: "/ssh.example.com:1234"}, []string{"proxy.example.com"}},
} {
tc := tc
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
got := GetServerNamesForURL(tc.u)
if diff := cmp.Diff(got, tc.want); diff != "" {
t.Errorf("GetServerNamesForURL() = %v", diff)
}
})
}
}
func TestGetDomainsForURL(t *testing.T) { func TestGetDomainsForURL(t *testing.T) {
t.Parallel() t.Parallel()
tests := []struct { tests := []struct {
@ -147,12 +172,14 @@ func TestGetDomainsForURL(t *testing.T) {
{"http scheme with host contain 443", &url.URL{Scheme: "http", Host: "example.com:443"}, []string{"example.com:443"}}, {"http scheme with host contain 443", &url.URL{Scheme: "http", Host: "example.com:443"}, []string{"example.com:443"}},
{"https", &url.URL{Scheme: "https", Host: "example.com"}, []string{"example.com", "example.com:443"}}, {"https", &url.URL{Scheme: "https", Host: "example.com"}, []string{"example.com", "example.com:443"}},
{"Host contains other port", &url.URL{Scheme: "https", Host: "example.com:1234"}, []string{"example.com:1234"}}, {"Host contains other port", &url.URL{Scheme: "https", Host: "example.com:1234"}, []string{"example.com:1234"}},
{"tcp", &url.URL{Scheme: "tcp+https", Host: "example.com:1234"}, []string{"example.com:1234"}},
{"tcp with path", &url.URL{Scheme: "tcp+https", Host: "proxy.example.com", Path: "/ssh.example.com:1234"}, []string{"ssh.example.com:1234"}},
} }
for _, tc := range tests { for _, tc := range tests {
tc := tc tc := tc
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
t.Parallel() t.Parallel()
got := GetDomainsForURL(*tc.u) got := GetDomainsForURL(tc.u)
if diff := cmp.Diff(got, tc.want); diff != "" { if diff := cmp.Diff(got, tc.want); diff != "" {
t.Errorf("GetDomainsForURL() = %v", diff) t.Errorf("GetDomainsForURL() = %v", diff)
} }