add more fuzz tests; fix edge cases

This commit is contained in:
Joe Kralicky 2024-06-24 20:00:24 -04:00
parent 0dadcd0b6a
commit 2adf9d2f45
No known key found for this signature in database
GPG key ID: 75C4875F34A9FB79
5 changed files with 176 additions and 15 deletions

View file

@ -84,6 +84,18 @@ var corpus = []string{
"example.com/search?query=open+source",
"example.com/redirect?to=https://example.org",
"example.com/profile/settings",
"a.example.com/foo/bar",
"b.example.com/foo/bar",
"c.example.com/foo/bar",
"d.example.com/foo/bar",
"a.foo.example.com/foo/bar",
"b.foo.example.com/foo/bar",
"c.foo.example.com/foo/bar",
"d.foo.example.com/foo/bar",
"foo.a.example.com/foo/bar",
"foo.b.example.com/foo/bar",
"foo.c.example.com/foo/bar",
"foo.d.example.com/foo/bar",
"test.com/foo",
"mysite.org/bar/baz",
"localhost:3000/test",
@ -129,7 +141,7 @@ func FuzzGetIdentityProviderForRequestURL(f *testing.F) {
emptyPortMatchesAll := true // todo
type testCase struct {
type policyMatcher struct {
policy *config.Policy
check func(input *url.URL) (bool, error)
}
@ -145,8 +157,8 @@ func FuzzGetIdentityProviderForRequestURL(f *testing.F) {
return true
}
testCases := []testCase{
{
policies := []policyMatcher{
0: {
policy: &config.Policy{From: "https://example.com"},
check: func(input *url.URL) (bool, error) {
if !checkEmptyPort(input) {
@ -156,10 +168,13 @@ func FuzzGetIdentityProviderForRequestURL(f *testing.F) {
if input.Hostname() != "example.com" {
return false, nil
}
if strings.HasPrefix(input.Path, "/api/") {
return false, nil // should match 2 instead
}
return true, nil
},
},
{
1: {
policy: &config.Policy{From: "https://*.foo.example.com", Prefix: "/prefix"},
check: func(u *url.URL) (bool, error) {
if !checkEmptyPort(u) {
@ -182,6 +197,145 @@ func FuzzGetIdentityProviderForRequestURL(f *testing.F) {
return true, nil
},
},
2: {
policy: &config.Policy{From: "https://example.com", Prefix: "/api/"},
check: func(input *url.URL) (bool, error) {
if !checkEmptyPort(input) {
return false, nil
}
if input.Hostname() != "example.com" {
return false, nil
}
if !strings.HasPrefix(input.Path, "/api/") {
return false, nil
}
return true, nil
},
},
3: {
policy: &config.Policy{From: "https://localhost:3000"},
check: func(input *url.URL) (bool, error) {
if input.Host != "localhost:3000" {
return false, nil
}
return true, nil
},
},
4: {
policy: &config.Policy{From: "https://secure.example.com"},
check: func(input *url.URL) (bool, error) {
if !checkEmptyPort(input) {
return false, nil
}
if input.Host != "secure.example.com" {
return false, nil
}
return true, nil
},
},
5: {
policy: &config.Policy{From: "https://example.*"},
check: func(input *url.URL) (bool, error) {
if !checkEmptyPort(input) {
return false, nil
}
parts := strings.Split(input.Hostname(), ".")
if len(parts) != 2 {
return false, nil
}
if parts[0] != "example" || len(parts[1]) == 0 {
return false, nil
}
if parts[1] == "com" {
return false, nil // should match 0 or 2 instead
}
return true, nil
},
},
6: {
policy: &config.Policy{From: "https://*.example.com"},
check: func(input *url.URL) (bool, error) {
if !checkEmptyPort(input) {
return false, nil
}
parts := strings.Split(input.Hostname(), ".")
if len(parts) != 3 {
return false, nil
}
if len(parts[0]) == 0 || parts[1] != "example" || parts[2] != "com" {
return false, nil
}
if parts[0] == "secure" {
return false, nil // should match 4 instead
}
return true, nil
},
},
7: {
policy: &config.Policy{From: "https://foo.*.example.com"},
check: func(input *url.URL) (bool, error) {
if !checkEmptyPort(input) {
return false, nil
}
parts := strings.Split(input.Hostname(), ".")
if len(parts) != 4 {
return false, nil
}
if parts[0] != "foo" || len(parts[1]) == 0 || parts[2] != "example" || parts[3] != "com" {
return false, nil
}
if parts[1] == "foo" {
return false, nil // should match 1 instead (subtle)
}
return true, nil
},
},
8: {
policy: &config.Policy{From: "https://*.*.example.com", Prefix: "/foo/"},
check: func(input *url.URL) (bool, error) {
if !checkEmptyPort(input) {
return false, nil
}
parts := strings.Split(input.Hostname(), ".")
if len(parts) != 4 {
return false, nil
}
if len(parts[0]) == 0 || len(parts[1]) == 0 || parts[2] != "example" || parts[3] != "com" {
return false, nil
}
if !strings.HasPrefix(input.Path, "/foo/") {
return false, nil
}
if parts[0] == "foo" {
return false, nil // should match 7 instead
}
if parts[1] == "foo" {
return false, nil // should match 1 instead
}
return true, nil
},
},
}
options := config.NewDefaultOptions()
@ -197,10 +351,10 @@ func FuzzGetIdentityProviderForRequestURL(f *testing.F) {
options.RuntimeFlags[config.RuntimeFlagMatchAnyIncomingPort] = false
}
for i, tc := range testCases {
for i, tc := range policies {
tc.policy.To = mustParseWeightedURLs(f, fmt.Sprintf("https://to-%d", i))
tc.policy.IDPClientID = fmt.Sprintf("client_id_%d", i)
tc.policy.IDPClientSecret = fmt.Sprintf("client_secret_%d", i)
tc.policy.IDPClientID = fmt.Sprint(i)
tc.policy.IDPClientSecret = fmt.Sprint(i)
options.Policies = append(options.Policies, *tc.policy)
}
require.NoError(f, options.Validate())
@ -214,12 +368,15 @@ func FuzzGetIdentityProviderForRequestURL(f *testing.F) {
if err != nil {
t.SkipNow()
}
if strings.Contains(inputURL.Host, "*") {
t.SkipNow()
}
for i, tc := range testCases {
expected := tc.policy
for i, p := range policies {
expected := p.policy
actualIdp, actualErr := cache.GetIdentityProviderForRequestURL(options, input)
expectedMatch, expectedErr := tc.check(inputURL)
expectedMatch, expectedErr := p.check(inputURL)
actualErrIsNotFound := errors.Is(actualErr, config.ErrNoIdentityProviderFound)
if expectedErr != nil {
@ -235,10 +392,10 @@ func FuzzGetIdentityProviderForRequestURL(f *testing.F) {
}
if expectedMatch {
if actualErrIsNotFound {
t.Fatalf("expected policy %d to match for input %q", i, input)
t.Fatalf("expected input %q to match policy %d, but no policies were matched", input, i)
return
}
assert.Equalf(t, expected.IDPClientID, actualIdp.ClientId, "wrong client id for input %q", input)
assert.Equalf(t, expected.IDPClientID, actualIdp.ClientId, "expected input %q to match policy %s, but instead matched policy %s", input, expected.IDPClientID, actualIdp.ClientId)
} else {
if !actualErrIsNotFound {
assert.NotEqualf(t, expected.IDPClientID, actualIdp.ClientId, "expected policy %d not to match input %q", i, input)