diff --git a/config/identity_test.go b/config/identity_test.go index dff3d99db..99a6a4c9b 100644 --- a/config/identity_test.go +++ b/config/identity_test.go @@ -84,6 +84,18 @@ var corpus = []string{ "example.com/search?query=open+source", "example.com/redirect?to=https://example.org", "example.com/profile/settings", + "a.example.com/foo/bar", + "b.example.com/foo/bar", + "c.example.com/foo/bar", + "d.example.com/foo/bar", + "a.foo.example.com/foo/bar", + "b.foo.example.com/foo/bar", + "c.foo.example.com/foo/bar", + "d.foo.example.com/foo/bar", + "foo.a.example.com/foo/bar", + "foo.b.example.com/foo/bar", + "foo.c.example.com/foo/bar", + "foo.d.example.com/foo/bar", "test.com/foo", "mysite.org/bar/baz", "localhost:3000/test", @@ -129,7 +141,7 @@ func FuzzGetIdentityProviderForRequestURL(f *testing.F) { emptyPortMatchesAll := true // todo - type testCase struct { + type policyMatcher struct { policy *config.Policy check func(input *url.URL) (bool, error) } @@ -145,8 +157,8 @@ func FuzzGetIdentityProviderForRequestURL(f *testing.F) { return true } - testCases := []testCase{ - { + policies := []policyMatcher{ + 0: { policy: &config.Policy{From: "https://example.com"}, check: func(input *url.URL) (bool, error) { if !checkEmptyPort(input) { @@ -156,10 +168,13 @@ func FuzzGetIdentityProviderForRequestURL(f *testing.F) { if input.Hostname() != "example.com" { return false, nil } + if strings.HasPrefix(input.Path, "/api/") { + return false, nil // should match 2 instead + } return true, nil }, }, - { + 1: { policy: &config.Policy{From: "https://*.foo.example.com", Prefix: "/prefix"}, check: func(u *url.URL) (bool, error) { if !checkEmptyPort(u) { @@ -182,6 +197,145 @@ func FuzzGetIdentityProviderForRequestURL(f *testing.F) { return true, nil }, }, + 2: { + policy: &config.Policy{From: "https://example.com", Prefix: "/api/"}, + check: func(input *url.URL) (bool, error) { + if !checkEmptyPort(input) { + return false, nil + } + + if input.Hostname() != "example.com" { + return false, nil + } + + if !strings.HasPrefix(input.Path, "/api/") { + return false, nil + } + + return true, nil + }, + }, + 3: { + policy: &config.Policy{From: "https://localhost:3000"}, + check: func(input *url.URL) (bool, error) { + if input.Host != "localhost:3000" { + return false, nil + } + return true, nil + }, + }, + 4: { + policy: &config.Policy{From: "https://secure.example.com"}, + check: func(input *url.URL) (bool, error) { + if !checkEmptyPort(input) { + return false, nil + } + + if input.Host != "secure.example.com" { + return false, nil + } + + return true, nil + }, + }, + 5: { + policy: &config.Policy{From: "https://example.*"}, + check: func(input *url.URL) (bool, error) { + if !checkEmptyPort(input) { + return false, nil + } + + parts := strings.Split(input.Hostname(), ".") + if len(parts) != 2 { + return false, nil + } + + if parts[0] != "example" || len(parts[1]) == 0 { + return false, nil + } + + if parts[1] == "com" { + return false, nil // should match 0 or 2 instead + } + + return true, nil + }, + }, + 6: { + policy: &config.Policy{From: "https://*.example.com"}, + check: func(input *url.URL) (bool, error) { + if !checkEmptyPort(input) { + return false, nil + } + + parts := strings.Split(input.Hostname(), ".") + if len(parts) != 3 { + return false, nil + } + + if len(parts[0]) == 0 || parts[1] != "example" || parts[2] != "com" { + return false, nil + } + + if parts[0] == "secure" { + return false, nil // should match 4 instead + } + + return true, nil + }, + }, + 7: { + policy: &config.Policy{From: "https://foo.*.example.com"}, + check: func(input *url.URL) (bool, error) { + if !checkEmptyPort(input) { + return false, nil + } + + parts := strings.Split(input.Hostname(), ".") + if len(parts) != 4 { + return false, nil + } + + if parts[0] != "foo" || len(parts[1]) == 0 || parts[2] != "example" || parts[3] != "com" { + return false, nil + } + + if parts[1] == "foo" { + return false, nil // should match 1 instead (subtle) + } + + return true, nil + }, + }, + 8: { + policy: &config.Policy{From: "https://*.*.example.com", Prefix: "/foo/"}, + check: func(input *url.URL) (bool, error) { + if !checkEmptyPort(input) { + return false, nil + } + + parts := strings.Split(input.Hostname(), ".") + if len(parts) != 4 { + return false, nil + } + + if len(parts[0]) == 0 || len(parts[1]) == 0 || parts[2] != "example" || parts[3] != "com" { + return false, nil + } + + if !strings.HasPrefix(input.Path, "/foo/") { + return false, nil + } + + if parts[0] == "foo" { + return false, nil // should match 7 instead + } + if parts[1] == "foo" { + return false, nil // should match 1 instead + } + return true, nil + }, + }, } options := config.NewDefaultOptions() @@ -197,10 +351,10 @@ func FuzzGetIdentityProviderForRequestURL(f *testing.F) { options.RuntimeFlags[config.RuntimeFlagMatchAnyIncomingPort] = false } - for i, tc := range testCases { + for i, tc := range policies { tc.policy.To = mustParseWeightedURLs(f, fmt.Sprintf("https://to-%d", i)) - tc.policy.IDPClientID = fmt.Sprintf("client_id_%d", i) - tc.policy.IDPClientSecret = fmt.Sprintf("client_secret_%d", i) + tc.policy.IDPClientID = fmt.Sprint(i) + tc.policy.IDPClientSecret = fmt.Sprint(i) options.Policies = append(options.Policies, *tc.policy) } require.NoError(f, options.Validate()) @@ -214,12 +368,15 @@ func FuzzGetIdentityProviderForRequestURL(f *testing.F) { if err != nil { t.SkipNow() } + if strings.Contains(inputURL.Host, "*") { + t.SkipNow() + } - for i, tc := range testCases { - expected := tc.policy + for i, p := range policies { + expected := p.policy actualIdp, actualErr := cache.GetIdentityProviderForRequestURL(options, input) - expectedMatch, expectedErr := tc.check(inputURL) + expectedMatch, expectedErr := p.check(inputURL) actualErrIsNotFound := errors.Is(actualErr, config.ErrNoIdentityProviderFound) if expectedErr != nil { @@ -235,10 +392,10 @@ func FuzzGetIdentityProviderForRequestURL(f *testing.F) { } if expectedMatch { if actualErrIsNotFound { - t.Fatalf("expected policy %d to match for input %q", i, input) + t.Fatalf("expected input %q to match policy %d, but no policies were matched", input, i) return } - assert.Equalf(t, expected.IDPClientID, actualIdp.ClientId, "wrong client id for input %q", input) + assert.Equalf(t, expected.IDPClientID, actualIdp.ClientId, "expected input %q to match policy %s, but instead matched policy %s", input, expected.IDPClientID, actualIdp.ClientId) } else { if !actualErrIsNotFound { assert.NotEqualf(t, expected.IDPClientID, actualIdp.ClientId, "expected policy %d not to match input %q", i, input) diff --git a/config/testdata/fuzz/FuzzGetIdentityProviderForRequestURL/02753d36e89e21ad b/config/testdata/fuzz/FuzzGetIdentityProviderForRequestURL/02753d36e89e21ad new file mode 100644 index 000000000..5ca7e0fdb --- /dev/null +++ b/config/testdata/fuzz/FuzzGetIdentityProviderForRequestURL/02753d36e89e21ad @@ -0,0 +1,2 @@ +go test fuzz v1 +string("example.c*000000") diff --git a/config/testdata/fuzz/FuzzGetIdentityProviderForRequestURL/a6e2cece43ad1b90 b/config/testdata/fuzz/FuzzGetIdentityProviderForRequestURL/a6e2cece43ad1b90 new file mode 100644 index 000000000..4680e7c95 --- /dev/null +++ b/config/testdata/fuzz/FuzzGetIdentityProviderForRequestURL/a6e2cece43ad1b90 @@ -0,0 +1,2 @@ +go test fuzz v1 +string("f.example.com") diff --git a/go.mod b/go.mod index d7dfecbb3..a609945e8 100644 --- a/go.mod +++ b/go.mod @@ -37,7 +37,7 @@ require ( github.com/hashicorp/golang-lru/v2 v2.0.7 github.com/jackc/pgx/v5 v5.6.0 github.com/klauspost/compress v1.17.8 - github.com/kralicky/go-adaptive-radix-tree v0.0.0-20240621232446-e019a6c4b8d7 + github.com/kralicky/go-adaptive-radix-tree v0.0.0-20240624235931-330eb762e74c github.com/martinlindhe/base36 v1.1.1 github.com/mholt/acmez/v2 v2.0.1 github.com/minio/minio-go/v7 v7.0.70 diff --git a/go.sum b/go.sum index 082ee2a99..f3f4d1b1e 100644 --- a/go.sum +++ b/go.sum @@ -416,8 +416,8 @@ github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= -github.com/kralicky/go-adaptive-radix-tree v0.0.0-20240621232446-e019a6c4b8d7 h1:OaN7Vhy9SgWnaRC7RhnCyJMQP8EGba34ACTMOvXfvR4= -github.com/kralicky/go-adaptive-radix-tree v0.0.0-20240621232446-e019a6c4b8d7/go.mod h1:oJwexVSshEat0E3evyKOH6QzN8GFWrhLvEoh8GiJzss= +github.com/kralicky/go-adaptive-radix-tree v0.0.0-20240624235931-330eb762e74c h1:TRkEV8M5PhQU55WI49FKTszEIpFlwZ1wfxcACCRT7SE= +github.com/kralicky/go-adaptive-radix-tree v0.0.0-20240624235931-330eb762e74c/go.mod h1:oJwexVSshEat0E3evyKOH6QzN8GFWrhLvEoh8GiJzss= github.com/lib/pq v1.10.7 h1:p7ZhMD+KsSRozJr34udlUrhboJwWAgCg34+/ZZNvZZw= github.com/lib/pq v1.10.7/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= github.com/libdns/libdns v0.2.2 h1:O6ws7bAfRPaBsgAYt8MDe2HcNBGC29hkZ9MX2eUSX3s=