From 177f789e63eb1001db3dc56c66ed407a17e4b5c6 Mon Sep 17 00:00:00 2001 From: Joe Kralicky Date: Thu, 7 Nov 2024 14:55:44 -0500 Subject: [PATCH] change Policy.Matches to accept a URL pointer (#5360) --- config/from_test.go | 2 +- config/identity.go | 2 +- config/policy.go | 4 ++-- internal/urlutil/known_test.go | 4 ++-- internal/urlutil/url.go | 6 +++--- 5 files changed, 9 insertions(+), 9 deletions(-) diff --git a/config/from_test.go b/config/from_test.go index 323e58f4c..e59d126c2 100644 --- a/config/from_test.go +++ b/config/from_test.go @@ -25,7 +25,7 @@ func TestFromURLMatchesRequestURL(t *testing.T) { } { fromURL := urlutil.MustParseAndValidateURL(tc.pattern) requestURL := urlutil.MustParseAndValidateURL(tc.input) - assert.Equal(t, tc.matches, FromURLMatchesRequestURL(&fromURL, &requestURL, true), + assert.Equal(t, tc.matches, FromURLMatchesRequestURL(fromURL, requestURL, true), "from-url: %s\nrequest-url: %s", tc.pattern, tc.input) } } diff --git a/config/identity.go b/config/identity.go index 806a698ec..6eeef76db 100644 --- a/config/identity.go +++ b/config/identity.go @@ -63,7 +63,7 @@ func (o *Options) GetIdentityProviderForRequestURL(requestURL string) (*identity } for p := range o.GetAllPolicies() { - if p.Matches(*u, o.IsRuntimeFlagSet(RuntimeFlagMatchAnyIncomingPort)) { + if p.Matches(u, o.IsRuntimeFlagSet(RuntimeFlagMatchAnyIncomingPort)) { return o.GetIdentityProviderForPolicy(p) } } diff --git a/config/policy.go b/config/policy.go index b67d1f314..1ee903242 100644 --- a/config/policy.go +++ b/config/policy.go @@ -735,14 +735,14 @@ func (p *Policy) String() string { } // Matches returns true if the policy would match the given URL. -func (p *Policy) Matches(requestURL url.URL, stripPort bool) bool { +func (p *Policy) Matches(requestURL *url.URL, stripPort bool) bool { // an invalid from URL should not match anything fromURL, err := urlutil.ParseAndValidateURL(p.From) if err != nil { return false } - if !FromURLMatchesRequestURL(fromURL, &requestURL, stripPort) { + if !FromURLMatchesRequestURL(fromURL, requestURL, stripPort) { return false } diff --git a/internal/urlutil/known_test.go b/internal/urlutil/known_test.go index d128daffd..b53d6a2b3 100644 --- a/internal/urlutil/known_test.go +++ b/internal/urlutil/known_test.go @@ -83,7 +83,7 @@ func TestSignInURL(t *testing.T) { authenticateURL := MustParseAndValidateURL("https://authenticate.example.com") redirectURL := MustParseAndValidateURL("https://redirect.example.com") - rawSignInURL, err := SignInURL(k1, k2.PublicKey(), &authenticateURL, &redirectURL, "IDP-1") + rawSignInURL, err := SignInURL(k1, k2.PublicKey(), authenticateURL, redirectURL, "IDP-1") require.NoError(t, err) signInURL, err := ParseAndValidateURL(rawSignInURL) @@ -107,7 +107,7 @@ func TestSignOutURL(t *testing.T) { }).Encode(), nil) authenticateURL := MustParseAndValidateURL("https://authenticate.example.com") - rawSignOutURL := SignOutURL(r, &authenticateURL, []byte("TEST")) + rawSignOutURL := SignOutURL(r, authenticateURL, []byte("TEST")) signOutURL, err := ParseAndValidateURL(rawSignOutURL) require.NoError(t, err) diff --git a/internal/urlutil/url.go b/internal/urlutil/url.go index 95b60d7c6..8070589ac 100644 --- a/internal/urlutil/url.go +++ b/internal/urlutil/url.go @@ -54,12 +54,12 @@ func ParseAndValidateURL(rawurl string) (*url.URL, error) { // MustParseAndValidateURL parses the URL via ParseAndValidateURL but panics if there is an error. // (useful for testing) -func MustParseAndValidateURL(rawURL string) url.URL { +func MustParseAndValidateURL(rawURL string) *url.URL { u, err := ParseAndValidateURL(rawURL) if err != nil { panic(err) } - return *u + return u } // ValidateURL wraps standard library's default url.Parse because @@ -187,6 +187,6 @@ func GetExternalRequest(internalURL, externalURL *url.URL, r *http.Request) *htt } // MatchesServerName returnes true if the url's host matches the given server name. -func MatchesServerName(u url.URL, serverName string) bool { +func MatchesServerName(u *url.URL, serverName string) bool { return certmagic.MatchWildcard(u.Hostname(), serverName) }