mirror of
https://github.com/pomerium/pomerium.git
synced 2025-04-29 18:36:30 +02:00
change Policy.Matches to accept a URL pointer (#5360)
This commit is contained in:
parent
9cd5fe4e25
commit
177f789e63
5 changed files with 9 additions and 9 deletions
|
@ -25,7 +25,7 @@ func TestFromURLMatchesRequestURL(t *testing.T) {
|
||||||
} {
|
} {
|
||||||
fromURL := urlutil.MustParseAndValidateURL(tc.pattern)
|
fromURL := urlutil.MustParseAndValidateURL(tc.pattern)
|
||||||
requestURL := urlutil.MustParseAndValidateURL(tc.input)
|
requestURL := urlutil.MustParseAndValidateURL(tc.input)
|
||||||
assert.Equal(t, tc.matches, FromURLMatchesRequestURL(&fromURL, &requestURL, true),
|
assert.Equal(t, tc.matches, FromURLMatchesRequestURL(fromURL, requestURL, true),
|
||||||
"from-url: %s\nrequest-url: %s", tc.pattern, tc.input)
|
"from-url: %s\nrequest-url: %s", tc.pattern, tc.input)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -63,7 +63,7 @@ func (o *Options) GetIdentityProviderForRequestURL(requestURL string) (*identity
|
||||||
}
|
}
|
||||||
|
|
||||||
for p := range o.GetAllPolicies() {
|
for p := range o.GetAllPolicies() {
|
||||||
if p.Matches(*u, o.IsRuntimeFlagSet(RuntimeFlagMatchAnyIncomingPort)) {
|
if p.Matches(u, o.IsRuntimeFlagSet(RuntimeFlagMatchAnyIncomingPort)) {
|
||||||
return o.GetIdentityProviderForPolicy(p)
|
return o.GetIdentityProviderForPolicy(p)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -735,14 +735,14 @@ func (p *Policy) String() string {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Matches returns true if the policy would match the given URL.
|
// Matches returns true if the policy would match the given URL.
|
||||||
func (p *Policy) Matches(requestURL url.URL, stripPort bool) bool {
|
func (p *Policy) Matches(requestURL *url.URL, stripPort bool) bool {
|
||||||
// an invalid from URL should not match anything
|
// an invalid from URL should not match anything
|
||||||
fromURL, err := urlutil.ParseAndValidateURL(p.From)
|
fromURL, err := urlutil.ParseAndValidateURL(p.From)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
if !FromURLMatchesRequestURL(fromURL, &requestURL, stripPort) {
|
if !FromURLMatchesRequestURL(fromURL, requestURL, stripPort) {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -83,7 +83,7 @@ func TestSignInURL(t *testing.T) {
|
||||||
authenticateURL := MustParseAndValidateURL("https://authenticate.example.com")
|
authenticateURL := MustParseAndValidateURL("https://authenticate.example.com")
|
||||||
redirectURL := MustParseAndValidateURL("https://redirect.example.com")
|
redirectURL := MustParseAndValidateURL("https://redirect.example.com")
|
||||||
|
|
||||||
rawSignInURL, err := SignInURL(k1, k2.PublicKey(), &authenticateURL, &redirectURL, "IDP-1")
|
rawSignInURL, err := SignInURL(k1, k2.PublicKey(), authenticateURL, redirectURL, "IDP-1")
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
signInURL, err := ParseAndValidateURL(rawSignInURL)
|
signInURL, err := ParseAndValidateURL(rawSignInURL)
|
||||||
|
@ -107,7 +107,7 @@ func TestSignOutURL(t *testing.T) {
|
||||||
}).Encode(), nil)
|
}).Encode(), nil)
|
||||||
authenticateURL := MustParseAndValidateURL("https://authenticate.example.com")
|
authenticateURL := MustParseAndValidateURL("https://authenticate.example.com")
|
||||||
|
|
||||||
rawSignOutURL := SignOutURL(r, &authenticateURL, []byte("TEST"))
|
rawSignOutURL := SignOutURL(r, authenticateURL, []byte("TEST"))
|
||||||
signOutURL, err := ParseAndValidateURL(rawSignOutURL)
|
signOutURL, err := ParseAndValidateURL(rawSignOutURL)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
|
|
@ -54,12 +54,12 @@ func ParseAndValidateURL(rawurl string) (*url.URL, error) {
|
||||||
|
|
||||||
// MustParseAndValidateURL parses the URL via ParseAndValidateURL but panics if there is an error.
|
// MustParseAndValidateURL parses the URL via ParseAndValidateURL but panics if there is an error.
|
||||||
// (useful for testing)
|
// (useful for testing)
|
||||||
func MustParseAndValidateURL(rawURL string) url.URL {
|
func MustParseAndValidateURL(rawURL string) *url.URL {
|
||||||
u, err := ParseAndValidateURL(rawURL)
|
u, err := ParseAndValidateURL(rawURL)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic(err)
|
panic(err)
|
||||||
}
|
}
|
||||||
return *u
|
return u
|
||||||
}
|
}
|
||||||
|
|
||||||
// ValidateURL wraps standard library's default url.Parse because
|
// ValidateURL wraps standard library's default url.Parse because
|
||||||
|
@ -187,6 +187,6 @@ func GetExternalRequest(internalURL, externalURL *url.URL, r *http.Request) *htt
|
||||||
}
|
}
|
||||||
|
|
||||||
// MatchesServerName returnes true if the url's host matches the given server name.
|
// MatchesServerName returnes true if the url's host matches the given server name.
|
||||||
func MatchesServerName(u url.URL, serverName string) bool {
|
func MatchesServerName(u *url.URL, serverName string) bool {
|
||||||
return certmagic.MatchWildcard(u.Hostname(), serverName)
|
return certmagic.MatchWildcard(u.Hostname(), serverName)
|
||||||
}
|
}
|
||||||
|
|
Loading…
Add table
Reference in a new issue