mirror of
https://github.com/pomerium/pomerium.git
synced 2025-08-03 00:40:25 +02:00
config: add support for extended TCP route URLs (#3845)
* config: add support for extended TCP route URLs * nevermind, add duplicate names
This commit is contained in:
parent
67e12101fa
commit
271b0787a8
8 changed files with 182 additions and 51 deletions
|
@ -151,7 +151,7 @@ func getCheckRequestURL(req *envoy_service_auth_v3.CheckRequest) url.URL {
|
|||
Scheme: h.GetScheme(),
|
||||
Host: h.GetHost(),
|
||||
}
|
||||
u.Host = urlutil.GetDomainsForURL(u)[0]
|
||||
u.Host = urlutil.GetDomainsForURL(&u)[0]
|
||||
// envoy sends the query string as part of the path
|
||||
path := h.GetPath()
|
||||
if idx := strings.Index(path, "?"); idx != -1 {
|
||||
|
|
|
@ -620,26 +620,28 @@ func getAllServerNames(cfg *config.Config, addr string) ([]string, error) {
|
|||
serverNames := sets.NewSorted[string]()
|
||||
serverNames.Add("*")
|
||||
|
||||
routeableHosts, err := getAllRouteableHosts(cfg.Options, addr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for _, hp := range routeableHosts {
|
||||
if h, _, err := net.SplitHostPort(hp); err == nil {
|
||||
serverNames.Add(h)
|
||||
} else {
|
||||
serverNames.Add(hp)
|
||||
}
|
||||
}
|
||||
|
||||
certs, err := cfg.AllCertificates()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for i := range certs {
|
||||
for _, domain := range cryptutil.GetCertificateServerNames(&certs[i]) {
|
||||
serverNames.Add(domain)
|
||||
serverNames.Add(cryptutil.GetCertificateServerNames(&certs[i])...)
|
||||
}
|
||||
|
||||
if addr == cfg.Options.Addr {
|
||||
sns, err := cfg.Options.GetAllRouteableHTTPServerNames()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
serverNames.Add(sns...)
|
||||
}
|
||||
|
||||
if addr == cfg.Options.GetGRPCAddr() {
|
||||
sns, err := cfg.Options.GetAllRouteableGRPCServerNames()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
serverNames.Add(sns...)
|
||||
}
|
||||
|
||||
return serverNames.ToSlice(), nil
|
||||
|
@ -655,30 +657,12 @@ func urlsMatchHost(urls []*url.URL, host string) bool {
|
|||
}
|
||||
|
||||
func urlMatchesHost(u *url.URL, host string) bool {
|
||||
if u == nil {
|
||||
return false
|
||||
for _, h := range urlutil.GetDomainsForURL(u) {
|
||||
if h == host {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
var defaultPort string
|
||||
if u.Scheme == "http" {
|
||||
defaultPort = "80"
|
||||
} else {
|
||||
defaultPort = "443"
|
||||
}
|
||||
|
||||
h1, p1, err := net.SplitHostPort(u.Host)
|
||||
if err != nil {
|
||||
h1 = u.Host
|
||||
p1 = defaultPort
|
||||
}
|
||||
|
||||
h2, p2, err := net.SplitHostPort(host)
|
||||
if err != nil {
|
||||
h2 = host
|
||||
p2 = defaultPort
|
||||
}
|
||||
|
||||
return h1 == h2 && p1 == p2
|
||||
return false
|
||||
}
|
||||
|
||||
func getPoliciesForServerName(options *config.Options, serverName string) []config.Policy {
|
||||
|
|
|
@ -984,6 +984,7 @@ func Test_getAllDomains(t *testing.T) {
|
|||
{Source: &config.StringURL{URL: mustParseURL(t, "http://a.example.com")}},
|
||||
{Source: &config.StringURL{URL: mustParseURL(t, "https://b.example.com")}},
|
||||
{Source: &config.StringURL{URL: mustParseURL(t, "https://c.example.com")}},
|
||||
{Source: &config.StringURL{URL: mustParseURL(t, "https://d.unknown.example.com")}},
|
||||
},
|
||||
Cert: base64.StdEncoding.EncodeToString(certPEM),
|
||||
Key: base64.StdEncoding.EncodeToString(keyPEM),
|
||||
|
@ -1001,6 +1002,8 @@ func Test_getAllDomains(t *testing.T) {
|
|||
"b.example.com:443",
|
||||
"c.example.com",
|
||||
"c.example.com:443",
|
||||
"d.unknown.example.com",
|
||||
"d.unknown.example.com:443",
|
||||
}
|
||||
assert.Equal(t, expect, actual)
|
||||
})
|
||||
|
@ -1029,6 +1032,8 @@ func Test_getAllDomains(t *testing.T) {
|
|||
"c.example.com",
|
||||
"c.example.com:443",
|
||||
"cache.example.com:9001",
|
||||
"d.unknown.example.com",
|
||||
"d.unknown.example.com:443",
|
||||
}
|
||||
assert.Equal(t, expect, actual)
|
||||
})
|
||||
|
@ -1044,6 +1049,7 @@ func Test_getAllDomains(t *testing.T) {
|
|||
"authenticate.example.com",
|
||||
"b.example.com",
|
||||
"c.example.com",
|
||||
"d.unknown.example.com",
|
||||
}
|
||||
assert.Equal(t, expect, actual)
|
||||
})
|
||||
|
|
|
@ -1026,7 +1026,7 @@ func (o *Options) GetAllRouteableGRPCHosts() ([]string, error) {
|
|||
return nil, err
|
||||
}
|
||||
for _, u := range authorizeURLs {
|
||||
hosts.Add(urlutil.GetDomainsForURL(*u)...)
|
||||
hosts.Add(urlutil.GetDomainsForURL(u)...)
|
||||
}
|
||||
} else if IsAuthorize(o.Services) {
|
||||
authorizeURLs, err := o.GetInternalAuthorizeURLs()
|
||||
|
@ -1034,7 +1034,7 @@ func (o *Options) GetAllRouteableGRPCHosts() ([]string, error) {
|
|||
return nil, err
|
||||
}
|
||||
for _, u := range authorizeURLs {
|
||||
hosts.Add(urlutil.GetDomainsForURL(*u)...)
|
||||
hosts.Add(urlutil.GetDomainsForURL(u)...)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -1045,7 +1045,7 @@ func (o *Options) GetAllRouteableGRPCHosts() ([]string, error) {
|
|||
return nil, err
|
||||
}
|
||||
for _, u := range dataBrokerURLs {
|
||||
hosts.Add(urlutil.GetDomainsForURL(*u)...)
|
||||
hosts.Add(urlutil.GetDomainsForURL(u)...)
|
||||
}
|
||||
} else if IsDataBroker(o.Services) {
|
||||
dataBrokerURLs, err := o.GetInternalDataBrokerURLs()
|
||||
|
@ -1053,7 +1053,52 @@ func (o *Options) GetAllRouteableGRPCHosts() ([]string, error) {
|
|||
return nil, err
|
||||
}
|
||||
for _, u := range dataBrokerURLs {
|
||||
hosts.Add(urlutil.GetDomainsForURL(*u)...)
|
||||
hosts.Add(urlutil.GetDomainsForURL(u)...)
|
||||
}
|
||||
}
|
||||
|
||||
return hosts.ToSlice(), nil
|
||||
}
|
||||
|
||||
// GetAllRouteableGRPCServerNames returns all the possible gRPC server names handled by the Pomerium options.
|
||||
func (o *Options) GetAllRouteableGRPCServerNames() ([]string, error) {
|
||||
hosts := sets.NewSorted[string]()
|
||||
|
||||
// authorize urls
|
||||
if IsAll(o.Services) {
|
||||
authorizeURLs, err := o.GetAuthorizeURLs()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for _, u := range authorizeURLs {
|
||||
hosts.Add(urlutil.GetServerNamesForURL(u)...)
|
||||
}
|
||||
} else if IsAuthorize(o.Services) {
|
||||
authorizeURLs, err := o.GetInternalAuthorizeURLs()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for _, u := range authorizeURLs {
|
||||
hosts.Add(urlutil.GetServerNamesForURL(u)...)
|
||||
}
|
||||
}
|
||||
|
||||
// databroker urls
|
||||
if IsAll(o.Services) {
|
||||
dataBrokerURLs, err := o.GetDataBrokerURLs()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for _, u := range dataBrokerURLs {
|
||||
hosts.Add(urlutil.GetServerNamesForURL(u)...)
|
||||
}
|
||||
} else if IsDataBroker(o.Services) {
|
||||
dataBrokerURLs, err := o.GetInternalDataBrokerURLs()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for _, u := range dataBrokerURLs {
|
||||
hosts.Add(urlutil.GetServerNamesForURL(u)...)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -1068,22 +1113,22 @@ func (o *Options) GetAllRouteableHTTPHosts() ([]string, error) {
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
hosts.Add(urlutil.GetDomainsForURL(*authenticateURL)...)
|
||||
hosts.Add(urlutil.GetDomainsForURL(authenticateURL)...)
|
||||
|
||||
authenticateURL, err = o.GetAuthenticateURL()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
hosts.Add(urlutil.GetDomainsForURL(*authenticateURL)...)
|
||||
hosts.Add(urlutil.GetDomainsForURL(authenticateURL)...)
|
||||
}
|
||||
|
||||
// policy urls
|
||||
if IsProxy(o.Services) {
|
||||
for _, policy := range o.GetAllPolicies() {
|
||||
hosts.Add(urlutil.GetDomainsForURL(*policy.Source.URL)...)
|
||||
hosts.Add(urlutil.GetDomainsForURL(policy.Source.URL)...)
|
||||
if policy.TLSDownstreamServerName != "" {
|
||||
tlsURL := policy.Source.URL.ResolveReference(&url.URL{Host: policy.TLSDownstreamServerName})
|
||||
hosts.Add(urlutil.GetDomainsForURL(*tlsURL)...)
|
||||
hosts.Add(urlutil.GetDomainsForURL(tlsURL)...)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -1091,6 +1136,37 @@ func (o *Options) GetAllRouteableHTTPHosts() ([]string, error) {
|
|||
return hosts.ToSlice(), nil
|
||||
}
|
||||
|
||||
// GetAllRouteableHTTPServerNames returns all the possible HTTP server names handled by the Pomerium options.
|
||||
func (o *Options) GetAllRouteableHTTPServerNames() ([]string, error) {
|
||||
serverNames := sets.NewSorted[string]()
|
||||
if IsAuthenticate(o.Services) {
|
||||
authenticateURL, err := o.GetInternalAuthenticateURL()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
serverNames.Add(urlutil.GetServerNamesForURL(authenticateURL)...)
|
||||
|
||||
authenticateURL, err = o.GetAuthenticateURL()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
serverNames.Add(urlutil.GetServerNamesForURL(authenticateURL)...)
|
||||
}
|
||||
|
||||
// policy urls
|
||||
if IsProxy(o.Services) {
|
||||
for _, policy := range o.GetAllPolicies() {
|
||||
serverNames.Add(urlutil.GetServerNamesForURL(policy.Source.URL)...)
|
||||
if policy.TLSDownstreamServerName != "" {
|
||||
tlsURL := policy.Source.URL.ResolveReference(&url.URL{Host: policy.TLSDownstreamServerName})
|
||||
serverNames.Add(urlutil.GetServerNamesForURL(tlsURL)...)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return serverNames.ToSlice(), nil
|
||||
}
|
||||
|
||||
// GetClientSecret gets the client secret.
|
||||
func (o *Options) GetClientSecret() (string, error) {
|
||||
if o == nil {
|
||||
|
|
|
@ -599,7 +599,12 @@ func (p *Policy) Matches(requestURL url.URL) bool {
|
|||
return false
|
||||
}
|
||||
|
||||
if p.Source.Host != requestURL.Host {
|
||||
// make sure one of the host domains matches the incoming url
|
||||
found := false
|
||||
for _, host := range urlutil.GetDomainsForURL(p.Source.URL) {
|
||||
found = found || host == requestURL.Host
|
||||
}
|
||||
if !found {
|
||||
return false
|
||||
}
|
||||
|
||||
|
|
|
@ -269,4 +269,13 @@ func TestPolicy_Matches(t *testing.T) {
|
|||
assert.True(t, p.Matches(urlutil.MustParseAndValidateURL(`https://www.example.com/admin/foo`)))
|
||||
assert.True(t, p.Matches(urlutil.MustParseAndValidateURL(`https://www.example.com/admin/bar`)))
|
||||
})
|
||||
t.Run("tcp", func(t *testing.T) {
|
||||
p := &Policy{
|
||||
From: "tcp+https://proxy.example.com/redis.example.com:6379",
|
||||
To: mustParseWeightedURLs(t, "tcp://localhost:6379"),
|
||||
}
|
||||
assert.NoError(t, p.Validate())
|
||||
|
||||
assert.True(t, p.Matches(urlutil.MustParseAndValidateURL(`https://redis.example.com:6379`)))
|
||||
})
|
||||
}
|
||||
|
|
|
@ -94,13 +94,37 @@ func GetAbsoluteURL(r *http.Request) *url.URL {
|
|||
return u
|
||||
}
|
||||
|
||||
// GetServerNamesForURL returns the TLS server names for the given URL. The server name is the
|
||||
// URL hostname.
|
||||
func GetServerNamesForURL(u *url.URL) []string {
|
||||
if u == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return []string{u.Hostname()}
|
||||
}
|
||||
|
||||
// GetDomainsForURL returns the available domains for given url.
|
||||
//
|
||||
// For standard HTTP (80)/HTTPS (443) ports, it returns `example.com` and `example.com:<port>`.
|
||||
// Otherwise, return the URL.Host value.
|
||||
func GetDomainsForURL(u url.URL) []string {
|
||||
if IsTCP(&u) {
|
||||
return []string{u.Host}
|
||||
func GetDomainsForURL(u *url.URL) []string {
|
||||
if u == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// tcp+https://ssh.example.com:22
|
||||
// => ssh.example.com:22
|
||||
// tcp+https://proxy.example.com/ssh.example.com:22
|
||||
// => ssh.example.com:22
|
||||
if strings.HasPrefix(u.Scheme, "tcp+") {
|
||||
hosts := strings.Split(u.Path, "/")[1:]
|
||||
// if there are no domains in the path part of the URL, use the host
|
||||
if len(hosts) == 0 {
|
||||
return []string{u.Host}
|
||||
}
|
||||
// otherwise use the path parts of the URL as the hosts
|
||||
return hosts
|
||||
}
|
||||
|
||||
var defaultPort string
|
||||
|
|
|
@ -136,6 +136,31 @@ func TestGetAbsoluteURL(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestGetServerNamesForURL(t *testing.T) {
|
||||
t.Parallel()
|
||||
for _, tc := range []struct {
|
||||
name string
|
||||
u *url.URL
|
||||
want []string
|
||||
}{
|
||||
{"http", &url.URL{Scheme: "http", Host: "example.com"}, []string{"example.com"}},
|
||||
{"http scheme with host contain 443", &url.URL{Scheme: "http", Host: "example.com:443"}, []string{"example.com"}},
|
||||
{"https", &url.URL{Scheme: "https", Host: "example.com"}, []string{"example.com"}},
|
||||
{"Host contains other port", &url.URL{Scheme: "https", Host: "example.com:1234"}, []string{"example.com"}},
|
||||
{"tcp", &url.URL{Scheme: "tcp+https", Host: "example.com:1234"}, []string{"example.com"}},
|
||||
{"tcp with path", &url.URL{Scheme: "tcp+https", Host: "proxy.example.com", Path: "/ssh.example.com:1234"}, []string{"proxy.example.com"}},
|
||||
} {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
got := GetServerNamesForURL(tc.u)
|
||||
if diff := cmp.Diff(got, tc.want); diff != "" {
|
||||
t.Errorf("GetServerNamesForURL() = %v", diff)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetDomainsForURL(t *testing.T) {
|
||||
t.Parallel()
|
||||
tests := []struct {
|
||||
|
@ -147,12 +172,14 @@ func TestGetDomainsForURL(t *testing.T) {
|
|||
{"http scheme with host contain 443", &url.URL{Scheme: "http", Host: "example.com:443"}, []string{"example.com:443"}},
|
||||
{"https", &url.URL{Scheme: "https", Host: "example.com"}, []string{"example.com", "example.com:443"}},
|
||||
{"Host contains other port", &url.URL{Scheme: "https", Host: "example.com:1234"}, []string{"example.com:1234"}},
|
||||
{"tcp", &url.URL{Scheme: "tcp+https", Host: "example.com:1234"}, []string{"example.com:1234"}},
|
||||
{"tcp with path", &url.URL{Scheme: "tcp+https", Host: "proxy.example.com", Path: "/ssh.example.com:1234"}, []string{"ssh.example.com:1234"}},
|
||||
}
|
||||
for _, tc := range tests {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
got := GetDomainsForURL(*tc.u)
|
||||
got := GetDomainsForURL(tc.u)
|
||||
if diff := cmp.Diff(got, tc.want); diff != "" {
|
||||
t.Errorf("GetDomainsForURL() = %v", diff)
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue