pomerium/config/identity_test.go
2024-06-24 20:06:38 -04:00

407 lines
11 KiB
Go

package config_test
import (
"encoding/base64"
"errors"
"fmt"
"net/url"
"strings"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/pomerium/pomerium/config"
"github.com/pomerium/pomerium/internal/urlutil"
"github.com/pomerium/pomerium/pkg/cryptutil"
)
// corpus is the seed input set for FuzzGetIdentityProviderForRequestURL.
//
// NB: we omit the https:// prefix in test urls so the fuzzer doesn't consider it
// part of the input; the fuzz target prepends it before parsing. In the real
// system we can assume any actual request urls have the valid https scheme
// (otherwise they wouldn't get here in the first place).
//
// Every entry is unique: duplicate seeds only re-run identical inputs.
var corpus = []string{
	"example.com",
	"example.com/",
	"example.com/foo",
	"example.com/bar/baz",
	"example.com/hello/world",
	"example.com/foo/bar/baz/qux",
	"example.com/search?q=test",
	"example.com/search?q=test&page=2",
	"example.com/filter?category=books&price=low",
	"example.com/api/v1/users?id=123",
	"example.com/page#section1",
	"example.com/docs#installation",
	"example.com/#top",
	"example.com/articles#readmore",
	"example.com/%20space",
	"example.com/~user/profile",
	"example.com/emoji/😊",
	"example.com/path%2Fwith%2Fencoded%2Fcharacters",
	"example.com:8080",
	"example.com:443",
	"localhost:3000",
	"127.0.0.1:8443",
	"user:pass@example.com",
	"user:pass@example.com/secure",
	"[2001:db8::1]",
	"[2001:db8::1]:8080/path",
	"例子.测试",
	"उदाहरण.परीक्षा",
	"example.com/" + strings.Repeat("a", 2048),
	"example.com/path/" + strings.Repeat("b", 1024),
	"example.com/?",
	"example.com/?=",
	"example.com/?q=%C3%A9",
	"example.com/space%20in%20path",
	"example.com/multiple?query=params&another=test",
	"example.com/case/SENSITIVE",
	"example.com/a/b/c/d/e/f/g/h/i",
	"example.com/nested/path/with/many/levels",
	"example.com/foo/bar;param=value",
	"example.com/bar?foo=bar&baz=qux",
	"example.com/path with spaces",
	"example.com/#fragment",
	"example.com:1234/path",
	"example.com/path/to/resource?query=param#fragment",
	"example.com/file.tar.gz",
	"example.com/some%20path/?query=with%20space",
	"example.com/api/resource/123?format=json",
	"example.com/api/v2/resource/456?fields=foo,bar,baz",
	"example.com/user/1/profile",
	"example.com/posts/2023/06/21/title",
	"example.com/item/123/edit",
	"example.com/items?filter=active",
	"example.com/login?redirect=/home",
	"example.com/register?next=/dashboard",
	"example.com/download/file.zip",
	"example.com/upload/image.jpg",
	"example.com/api/data?page=2&sort=asc",
	"example.com/products?category=electronics",
	"example.com/categories/?name=books",
	"example.com/search?query=open+source",
	"example.com/redirect?to=https://example.org",
	"example.com/profile/settings",
	"a.example.com/foo/bar",
	"b.example.com/foo/bar",
	"c.example.com/foo/bar",
	"d.example.com/foo/bar",
	"a.foo.example.com/foo/bar",
	"b.foo.example.com/foo/bar",
	"c.foo.example.com/foo/bar",
	"d.foo.example.com/foo/bar",
	"foo.a.example.com/foo/bar",
	"foo.b.example.com/foo/bar",
	"foo.c.example.com/foo/bar",
	"foo.d.example.com/foo/bar",
	"test.com/foo",
	"mysite.org/bar/baz",
	"localhost:3000/test",
	"dev.local/hello/world",
	"subdomain.example.com/api/v1/resource",
	"example.net/foo/bar/baz",
	"mysite.dev/hello/world",
	"myapp.local/api/resource/123",
	"example.edu/search?q=test",
	"example.co.uk/filter?category=books&price=low",
	"example.org/page#section1",
	"example.net/docs#installation",
	"example.io/#top",
	"example.biz/articles#readmore",
	"example.info/%20space",
	"example.tv/~user/profile",
	"example.xyz/emoji/😊",
	"example.online/path%2Fwith%2Fencoded%2Fcharacters",
	"example.shop:8080",
	"example.club:443",
	"localhost:8080/test/path",
	"127.0.0.1:3000/test",
	"user:pass@mysite.com",
	"user:pass@secure.example.com",
	"[2001:db8::1234]/path",
	"[2001:db8::5678]:8080/path/to/resource",
	"xn--fsq.com/emoji/😊",
	"xn--g6w251d.com/space%20path",
	"xn--80akhbyknj4f.com/%20space",
	"xn--d1acufc.xn--p1ai/",
	"xn--80asehdb/xn--b1afkhbi.com/",
	"xn--e1afmkfd.xn--p1ai",
	"xn--fct.xn--q9jyb4c/xn--node%20path",
	"xn--fiq228c.xn--kpry57d/%20encoded%2Fcharacters",
	"xn--fsq6x.com/path%20with%20spaces",
	"xn--g2xx48c/xn--g2xx48c/#fragment",
}
// FuzzGetIdentityProviderForRequestURL checks that the policy cache's
// GetIdentityProviderForRequestURL picks the same policy as a hand-written
// reference matcher for every fuzzed request URL.
//
// Each policy below is paired with a check function that decides, without
// using the cache, whether that policy should be the one selected for a URL.
// The checks are written to be mutually exclusive. Each policy's index is
// stored in its IDPClientID, so the returned IdP identifies the policy that
// matched. For each input and each policy the test asserts:
//   - check true  => the cache returned that policy's IdP;
//   - check false => the cache did not return that policy's IdP.
func FuzzGetIdentityProviderForRequestURL(f *testing.F) {
	for _, input := range corpus {
		f.Add(input)
	}

	// emptyPortMatchesAll describes the expected routing mode. When true, a
	// policy whose From URL has no explicit port should match requests on any
	// port. The same value is written to RuntimeFlagMatchAnyIncomingPort below
	// so the cache and the reference checks agree.
	emptyPortMatchesAll := true // todo: also exercise the disabled mode

	type policyMatcher struct {
		policy *config.Policy
		// check reports whether policy is the one that should be selected for input.
		check func(input *url.URL) (bool, error)
	}

	// checkEmptyPort reports whether u's port is acceptable for a policy whose
	// From URL has no explicit port.
	checkEmptyPort := func(u *url.URL) bool {
		if emptyPortMatchesAll {
			return true
		}
		// Flag disabled: only the implicit or default https port may match.
		port := u.Port()
		return port == "" || port == "443"
	}

	policies := []policyMatcher{
		0: {
			policy: &config.Policy{From: "https://example.com"},
			check: func(input *url.URL) (bool, error) {
				if !checkEmptyPort(input) {
					return false, nil
				}
				if input.Hostname() != "example.com" {
					return false, nil
				}
				if strings.HasPrefix(input.Path, "/api/") {
					return false, nil // should match 2 instead
				}
				return true, nil
			},
		},
		1: {
			policy: &config.Policy{From: "https://*.foo.example.com", Prefix: "/prefix"},
			check: func(u *url.URL) (bool, error) {
				if !checkEmptyPort(u) {
					return false, nil
				}
				parts := strings.Split(u.Hostname(), ".")
				if len(parts) != 4 {
					return false, nil
				}
				if len(parts[0]) == 0 {
					return false, nil
				}
				if parts[1] != "foo" || parts[2] != "example" || parts[3] != "com" {
					return false, nil
				}
				if !strings.HasPrefix(u.Path, "/prefix") {
					return false, nil
				}
				return true, nil
			},
		},
		2: {
			policy: &config.Policy{From: "https://example.com", Prefix: "/api/"},
			check: func(input *url.URL) (bool, error) {
				if !checkEmptyPort(input) {
					return false, nil
				}
				if input.Hostname() != "example.com" {
					return false, nil
				}
				if !strings.HasPrefix(input.Path, "/api/") {
					return false, nil
				}
				return true, nil
			},
		},
		3: {
			policy: &config.Policy{From: "https://localhost:3000"},
			check: func(input *url.URL) (bool, error) {
				// From has an explicit port, so host and port must match exactly.
				if input.Host != "localhost:3000" {
					return false, nil
				}
				return true, nil
			},
		},
		4: {
			policy: &config.Policy{From: "https://secure.example.com"},
			check: func(input *url.URL) (bool, error) {
				if !checkEmptyPort(input) {
					return false, nil
				}
				// Compare the hostname only: the port is governed by
				// checkEmptyPort, just as in the other port-less policies.
				if input.Hostname() != "secure.example.com" {
					return false, nil
				}
				return true, nil
			},
		},
		5: {
			policy: &config.Policy{From: "https://example.*"},
			check: func(input *url.URL) (bool, error) {
				if !checkEmptyPort(input) {
					return false, nil
				}
				parts := strings.Split(input.Hostname(), ".")
				if len(parts) != 2 {
					return false, nil
				}
				if parts[0] != "example" || len(parts[1]) == 0 {
					return false, nil
				}
				if parts[1] == "com" {
					return false, nil // should match 0 or 2 instead
				}
				return true, nil
			},
		},
		6: {
			policy: &config.Policy{From: "https://*.example.com"},
			check: func(input *url.URL) (bool, error) {
				if !checkEmptyPort(input) {
					return false, nil
				}
				parts := strings.Split(input.Hostname(), ".")
				if len(parts) != 3 {
					return false, nil
				}
				if len(parts[0]) == 0 || parts[1] != "example" || parts[2] != "com" {
					return false, nil
				}
				if parts[0] == "secure" {
					return false, nil // should match 4 instead
				}
				return true, nil
			},
		},
		7: {
			policy: &config.Policy{From: "https://foo.*.example.com"},
			check: func(input *url.URL) (bool, error) {
				if !checkEmptyPort(input) {
					return false, nil
				}
				parts := strings.Split(input.Hostname(), ".")
				if len(parts) != 4 {
					return false, nil
				}
				if parts[0] != "foo" || len(parts[1]) == 0 || parts[2] != "example" || parts[3] != "com" {
					return false, nil
				}
				if parts[1] == "foo" {
					return false, nil // should match 1 instead (subtle)
				}
				return true, nil
			},
		},
		8: {
			policy: &config.Policy{From: "https://*.*.example.com", Prefix: "/foo/"},
			check: func(input *url.URL) (bool, error) {
				if !checkEmptyPort(input) {
					return false, nil
				}
				parts := strings.Split(input.Hostname(), ".")
				if len(parts) != 4 {
					return false, nil
				}
				if len(parts[0]) == 0 || len(parts[1]) == 0 || parts[2] != "example" || parts[3] != "com" {
					return false, nil
				}
				if !strings.HasPrefix(input.Path, "/foo/") {
					return false, nil
				}
				if parts[0] == "foo" {
					return false, nil // should match 7 instead
				}
				if parts[1] == "foo" {
					return false, nil // should match 1 instead
				}
				return true, nil
			},
		},
	}

	options := config.NewDefaultOptions()
	sharedKey := cryptutil.NewKey()
	options.SharedKey = base64.StdEncoding.EncodeToString(sharedKey)
	options.InsecureServer = true
	options.Provider = "oidc"
	options.ProviderURL = "https://oidc.example.com"
	options.ClientID = "client_id"
	options.ClientSecret = "client_secret"
	options.RuntimeFlags = config.DefaultRuntimeFlags()
	// Keep the cache's port-matching mode in sync with checkEmptyPort above.
	options.RuntimeFlags[config.RuntimeFlagMatchAnyIncomingPort] = emptyPortMatchesAll

	for i, tc := range policies {
		tc.policy.To = mustParseWeightedURLs(f, fmt.Sprintf("https://to-%d", i))
		// The client ID/secret encode the policy index so a returned IdP
		// identifies which policy was matched.
		tc.policy.IDPClientID = fmt.Sprint(i)
		tc.policy.IDPClientSecret = fmt.Sprint(i)
		options.Policies = append(options.Policies, *tc.policy)
	}
	require.NoError(f, options.Validate())

	cache, err := config.NewPolicyCache(options)
	require.NoError(f, err)

	f.Fuzz(func(t *testing.T, input string) {
		input = "https://" + input // see note at the top of this file
		inputURL, parseErr := urlutil.ParseAndValidateURL(input)
		if parseErr != nil {
			t.SkipNow()
		}
		if strings.Contains(inputURL.Host, "*") {
			t.SkipNow()
		}

		// The lookup does not depend on the policy being checked, so do it
		// once per input instead of once per policy.
		actualIdp, actualErr := cache.GetIdentityProviderForRequestURL(options, input)
		actualErrIsNotFound := errors.Is(actualErr, config.ErrNoIdentityProviderFound)

		for i, p := range policies {
			expected := p.policy
			expectedMatch, expectedErr := p.check(inputURL)

			if expectedErr != nil {
				if actualErrIsNotFound {
					t.Fatalf("expected invalid input %q to be rejected", input)
				}
				assert.Errorf(t, actualErr, "expected an error for input %q", input)
				continue
			}

			if actualErr != nil && !actualErrIsNotFound {
				t.Fatalf("unexpected error for input %q: %v", input, actualErr)
			}

			switch {
			case expectedMatch && actualErrIsNotFound:
				t.Fatalf("expected input %q to match policy %d, but no policies were matched", input, i)
			case expectedMatch:
				assert.Equalf(t, expected.IDPClientID, actualIdp.ClientId,
					"expected input %q to match policy %s, but instead matched policy %s",
					input, expected.IDPClientID, actualIdp.ClientId)
			case !actualErrIsNotFound:
				assert.NotEqualf(t, expected.IDPClientID, actualIdp.ClientId,
					"expected policy %d not to match input %q", i, input)
			}
		}
	})
}