From a5a0cf4ba8b0b310b201fd393768e818ef768fc6 Mon Sep 17 00:00:00 2001 From: Joe Kralicky Date: Thu, 20 Jun 2024 19:28:53 -0400 Subject: [PATCH] implement port+prefix matching --- config/identity.go | 59 ++++++++++++-------------- config/identity_benchmark_test.go | 45 +++++++++++++++++--- go.mod | 6 +-- go.sum | 4 +- internal/urlutil/url.go | 53 +++++++++++++++++++++++ pkg/cryptutil/kek.go | 3 +- pkg/cryptutil/token.go | 3 +- pkg/grpc/device/device.go | 2 +- pkg/grpc/identity/identity.go | 3 +- pkg/telemetry/requestid/requestid.go | 3 +- pkg/webauthnutil/credential_storage.go | 2 +- 11 files changed, 131 insertions(+), 52 deletions(-) diff --git a/config/identity.go b/config/identity.go index 94513f660..a9821d62d 100644 --- a/config/identity.go +++ b/config/identity.go @@ -3,11 +3,11 @@ package config import ( "context" "fmt" - "net" "slices" "strings" art "github.com/kralicky/go-adaptive-radix-tree" + "github.com/pomerium/pomerium/internal/urlutil" "github.com/pomerium/pomerium/pkg/grpc/identity" ) @@ -74,41 +74,38 @@ func (o *Options) GetIdentityProviderForRequestURL(ctx context.Context, requestU } type PolicyCache struct { - domainTree art.Tree[*domainNode] - matchPorts bool + domainTree art.Tree[domainNode] } func NewPolicyCache(options *Options) (*PolicyCache, error) { - tree := art.New[*domainNode]() - shouldMatchPorts := !options.IsRuntimeFlagSet(RuntimeFlagMatchAnyIncomingPort) + tree := art.New[domainNode]() + emptyPortMatchesAny := options.IsRuntimeFlagSet(RuntimeFlagMatchAnyIncomingPort) for _, policy := range options.GetAllPolicies() { u, err := urlutil.ParseAndValidateURL(policy.From) if err != nil { return nil, err } - domains := urlutil.GetDomainsForURL(u, shouldMatchPorts) - for _, domain := range domains { - host, port, err := net.SplitHostPort(domain) // todo: this is not optimal - if err != nil { - host, port = domain, "" - } + + urlutil.AllDomainsForURL(u, !emptyPortMatchesAny)(func(host, port string) bool { domainKey := radixKeyForHostPort(host, port) - tree.Update(domainKey, newDomainNode, func(dn **domainNode) { + tree.Update(art.Key(domainKey), func() domainNode { + return domainNode{policiesByPrefix: art.New[Policy]()} + }, func(dn *domainNode) { if policy.Prefix != "" { - (*dn).policiesByPrefix.Insert(art.Key(policy.Prefix), policy) + dn.policiesByPrefix.Insert(art.Key(policy.Prefix), policy) } else if policy.Path != "" { - (*dn).policiesByPrefix.Insert(art.Key(policy.Path), policy) + dn.policiesByPrefix.Insert(art.Key(policy.Path), policy) } else if policy.compiledRegex != nil { - (*dn).policiesByRegex = append((*dn).policiesByRegex, policy) + dn.policiesByRegex = append(dn.policiesByRegex, policy) } else { - (*dn).policiesNoPathMatching = append((*dn).policiesNoPathMatching, policy) + dn.policiesNoPathMatching = append(dn.policiesNoPathMatching, policy) } }) - } + return true + }) } return &PolicyCache{ domainTree: tree, - matchPorts: shouldMatchPorts, }, nil } @@ -119,20 +116,20 @@ func (pc *PolicyCache) GetIdentityProviderForRequestURL(ctx context.Context, o * } domainKey := radixKeyForHostPort(u.Hostname(), u.Port()) - domain, ok := pc.domainTree.Resolve(domainKey, wildcardResolver) + domain, ok := pc.domainTree.Resolve(art.Key(domainKey), wildcardResolver) if !ok { return nil, fmt.Errorf("no identity provider found for request URL %s", requestURL) } var policy *Policy - if len(u.Path) == 0 || (len(u.Path) == 1 && u.Path[0] == '/') && len(domain.policiesNoPathMatching) > 0 { + if (len(u.Path) == 0 || (len(u.Path) == 1 && u.Path[0] == '/')) && + len(domain.policiesNoPathMatching) > 0 { policy = &domain.policiesNoPathMatching[0] } else { if domain.policiesByPrefix.Size() > 0 { - pathKey := art.Key(u.Path) - actualKey, val, found := domain.policiesByPrefix.SearchNearest(pathKey) + actualKey, val, found := domain.policiesByPrefix.SearchNearest(art.Key(u.Path)) if found { // check for prefix match or exact match - if c := actualKey.Compare(pathKey); c < 0 { + if c := actualKey.Compare(art.Key(u.Path)); c < 0 { if val.Prefix != "" && strings.HasPrefix(u.Path, val.Prefix) { policy = &val } @@ -144,9 +141,10 @@ func (pc *PolicyCache) GetIdentityProviderForRequestURL(ctx context.Context, o * } } if policy == nil { - for _, p := range domain.policiesByRegex { + for i := range len(domain.policiesByRegex) { + p := &domain.policiesByRegex[i] if p.compiledRegex.MatchString(u.Path) { - policy = &p + policy = p break } } @@ -165,13 +163,10 @@ type domainNode struct { policiesNoPathMatching []Policy } -func newDomainNode() *domainNode { - return &domainNode{ - policiesByPrefix: art.New[Policy](), +func radixKeyForHostPort(host, port string) string { + if port == "" { + port = "*" } -} - -func radixKeyForHostPort(host, port string) art.Key { parts := strings.Split(host, ".") sb := strings.Builder{} sb.WriteString(port) @@ -179,7 +174,7 @@ func radixKeyForHostPort(host, port string) art.Key { sb.WriteByte('.') sb.WriteString(parts[i]) } - return art.Key(sb.String()) + return sb.String() } func wildcardResolver(key art.Key, conflictIndex int) (art.Key, int) { diff --git a/config/identity_benchmark_test.go b/config/identity_benchmark_test.go index 1c1dda52b..2174cf209 100644 --- a/config/identity_benchmark_test.go +++ b/config/identity_benchmark_test.go @@ -7,9 +7,10 @@ import ( "strings" "testing" + "github.com/stretchr/testify/require" + "github.com/pomerium/pomerium/config" "github.com/pomerium/pomerium/pkg/cryptutil" - "github.com/stretchr/testify/require" ) func BenchmarkGetIdentityProviderForRequestURL_Old(b *testing.B) { @@ -19,6 +20,7 @@ func BenchmarkGetIdentityProviderForRequestURL_Old(b *testing.B) { options := config.NewDefaultOptions() sharedKey := cryptutil.NewKey() options.SharedKey = base64.StdEncoding.EncodeToString(sharedKey) + options.InsecureServer = true options.Provider = "oidc" options.ProviderURL = "https://oidc.example.com" options.ClientID = "client_id" @@ -58,6 +60,7 @@ var bench = func(fill func(i int, p *config.Policy) string, numPolicies int) fun options := config.NewDefaultOptions() sharedKey := cryptutil.NewKey() options.SharedKey = base64.StdEncoding.EncodeToString(sharedKey) + options.InsecureServer = true options.Provider = "oidc" options.ProviderURL = "https://oidc.example.com" options.ClientID = "client_id" @@ -80,7 +83,9 @@ var bench = func(fill func(i int, p *config.Policy) string, numPolicies int) fun b.ResetTimer() for i := range b.N { - reqUrl := strings.Replace(allUrls[i%numPolicies], "*", fmt.Sprint(i), 1) + // replace all *s in the url with a number, which is valid for both + // hostname segments and ports. + reqUrl := strings.ReplaceAll(allUrls[i%numPolicies], "*", fmt.Sprint(i)) idp, err := cache.GetIdentityProviderForRequestURL(context.Background(), options, reqUrl) require.NoError(b, err) require.Equal(b, fmt.Sprintf("client_id_%d", i%numPolicies), idp.ClientId) @@ -101,6 +106,23 @@ func BenchmarkGetIdentityProviderForRequestURL_New_DomainMatchOnly(b *testing.B) b.Run("5000 policies (domain matching only)", bench(domainMatchingOnly, 5000)) } +func BenchmarkGetIdentityProviderForRequestURL_New_DomainPortMatchOnly(b *testing.B) { + domainPortMatchingOnly := func(i int, p *config.Policy) string { + p.From = fmt.Sprintf("https://*.foo.bar.test-%d.example.com", i) + if i%5 == 0 { + p.From += ":9999" + } else if i%2 == 0 { + p.From += ":443" + } + return p.From + } + + b.Run("5 policies (domain+port matching only)", bench(domainPortMatchingOnly, 5)) + b.Run("50 policies (domain+port matching only)", bench(domainPortMatchingOnly, 50)) + b.Run("500 policies (domain+port matching only)", bench(domainPortMatchingOnly, 500)) + b.Run("5000 policies (domain+port matching only)", bench(domainPortMatchingOnly, 5000)) +} + func BenchmarkGetIdentityProviderForRequestURL_New_PathMatchOnly(b *testing.B) { pathMatchingOnly := func(i int, p *config.Policy) string { p.From = "https://example.com" @@ -113,6 +135,18 @@ func BenchmarkGetIdentityProviderForRequestURL_New_PathMatchOnly(b *testing.B) { b.Run("5000 policies (path matching only)", bench(pathMatchingOnly, 5000)) } +func BenchmarkGetIdentityProviderForRequestURL_New_PrefixMatchOnly(b *testing.B) { + prefixMatchingOnly := func(i int, p *config.Policy) string { + p.From = "https://example.com" + p.Prefix = fmt.Sprintf("/foo/bar/%d/", i) + return p.From + p.Prefix + "/subpath" + } + b.Run("5 policies (prefix matching only)", bench(prefixMatchingOnly, 5)) + b.Run("50 policies (prefix matching only)", bench(prefixMatchingOnly, 50)) + b.Run("500 policies (prefix matching only)", bench(prefixMatchingOnly, 500)) + b.Run("5000 policies (prefix matching only)", bench(prefixMatchingOnly, 5000)) +} + func BenchmarkGetIdentityProviderForRequestURL_New_DomainAndPathMatching(b *testing.B) { combinedMatching := func(numPathsPerDomain int) func(i int, p *config.Policy) string { // returns a sequence of policies (ex: numPathsPerDomain=3) @@ -159,12 +193,13 @@ func BenchmarkGetIdentityProviderForRequestURL_New_DomainAndPrefixMatching(b *te domain := fmt.Sprintf("https://*.foo.bar.test-%d.example.com", i/numPathsPerDomain) pathIdx := i % numPathsPerDomain var prefix strings.Builder - for i := range pathIdx + 1 { - fmt.Fprintf(&prefix, "/%d", i) + for j := 0; j <= pathIdx; j++ { + fmt.Fprintf(&prefix, "/%d", j) } + prefix.WriteString("/") p.From = domain p.Prefix = prefix.String() - return domain + p.Prefix + return domain + p.Prefix + "subpath" } } diff --git a/go.mod b/go.mod index 20c66e6d6..375e8830b 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,6 @@ module github.com/pomerium/pomerium -go 1.22.2 +go 1.22.4 require ( cloud.google.com/go/storage v1.41.0 @@ -10,6 +10,7 @@ require ( github.com/CAFxX/httpcompression v0.0.9 github.com/DataDog/opencensus-go-exporter-datadog v0.0.0-20200406135749-5c268882acf0 github.com/VictoriaMetrics/fastcache v1.12.2 + github.com/akamensky/base58 v0.0.0-20210829145138-ce8bf8802e8f github.com/aws/aws-sdk-go-v2 v1.27.0 github.com/aws/aws-sdk-go-v2/config v1.27.16 github.com/aws/aws-sdk-go-v2/service/s3 v1.54.3 @@ -36,7 +37,7 @@ require ( github.com/hashicorp/golang-lru/v2 v2.0.7 github.com/jackc/pgx/v5 v5.6.0 github.com/klauspost/compress v1.17.8 - github.com/kralicky/go-adaptive-radix-tree v0.0.0-20240619012453-a8f80032ba31 + github.com/kralicky/go-adaptive-radix-tree v0.0.0-20240620232421-9773ec5394e9 github.com/martinlindhe/base36 v1.1.1 github.com/mholt/acmez/v2 v2.0.1 github.com/minio/minio-go/v7 v7.0.70 @@ -105,7 +106,6 @@ require ( github.com/Nvveen/Gotty v0.0.0-20120604004816-cd527374f1e5 // indirect github.com/OneOfOne/xxhash v1.2.8 // indirect github.com/agnivade/levenshtein v1.1.1 // indirect - github.com/akamensky/base58 v0.0.0-20210829145138-ce8bf8802e8f // indirect github.com/andybalholm/brotli v1.0.5 // indirect github.com/apapsch/go-jsonmerge/v2 v2.0.0 // indirect github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.2 // indirect diff --git a/go.sum b/go.sum index 5824fbbeb..765f448c8 100644 --- a/go.sum +++ b/go.sum @@ -416,8 +416,8 @@ github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= -github.com/kralicky/go-adaptive-radix-tree v0.0.0-20240619012453-a8f80032ba31 h1:cOZUoNuv9OuCZKepddeIW87ScJWwveLfLcJMaii6YCA= -github.com/kralicky/go-adaptive-radix-tree v0.0.0-20240619012453-a8f80032ba31/go.mod h1:oJwexVSshEat0E3evyKOH6QzN8GFWrhLvEoh8GiJzss= +github.com/kralicky/go-adaptive-radix-tree v0.0.0-20240620232421-9773ec5394e9 h1:+RqdMg3IHwQ2YC+07pxc9hyHbwXngNRja70UTMbH2Wg= +github.com/kralicky/go-adaptive-radix-tree v0.0.0-20240620232421-9773ec5394e9/go.mod h1:oJwexVSshEat0E3evyKOH6QzN8GFWrhLvEoh8GiJzss= github.com/lib/pq v1.10.7 h1:p7ZhMD+KsSRozJr34udlUrhboJwWAgCg34+/ZZNvZZw= github.com/lib/pq v1.10.7/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= github.com/libdns/libdns v0.2.2 h1:O6ws7bAfRPaBsgAYt8MDe2HcNBGC29hkZ9MX2eUSX3s= diff --git a/internal/urlutil/url.go b/internal/urlutil/url.go index 95b60d7c6..329b91b58 100644 --- a/internal/urlutil/url.go +++ b/internal/urlutil/url.go @@ -149,6 +149,59 @@ func GetDomainsForURL(u *url.URL, includeDefaultPort bool) []string { return []string{u.Hostname(), net.JoinHostPort(u.Hostname(), defaultPort)} } +func AllDomainsForURL(u *url.URL, includeDefaultPort bool) func(yield func(string, string) bool) { + return func(yield func(string, string) bool) { + if u == nil { + return + } + + // tcp+https://ssh.example.com:22 + // => ssh.example.com:22 + // tcp+https://proxy.example.com/ssh.example.com:22 + // => ssh.example.com:22 + if strings.HasPrefix(u.Scheme, "tcp+") { + hosts := strings.Split(strings.TrimPrefix(u.Path, "/"), "/") + // if there are no domains in the path part of the URL, use the host + // otherwise use the path parts of the URL as the hosts + for _, hostport := range append(hosts, u.Host) { + if host, port, err := net.SplitHostPort(hostport); err == nil { + yield(host, port) + } else { + yield(hostport, "") + } + } + return + } + + var defaultPort string + if u.Scheme == "http" { + defaultPort = "80" + } else { + defaultPort = "443" + } + + // for hosts like 'example.com:1234' we only return one route + if host, port, err := net.SplitHostPort(u.Host); err == nil { + if port != defaultPort { + yield(host, port) + return + } + } + + hostname := u.Hostname() + if !includeDefaultPort { + yield(hostname, "") + return + } + + // for everything else we return two routes: 'example.com' and 'example.com:443' + if !yield(hostname, "") { + return + } + yield(hostname, defaultPort) + } +} + // Join joins elements of a URL with '/'. func Join(elements ...string) string { var builder strings.Builder diff --git a/pkg/cryptutil/kek.go b/pkg/cryptutil/kek.go index 2f29df5e2..33bd405c1 100644 --- a/pkg/cryptutil/kek.go +++ b/pkg/cryptutil/kek.go @@ -4,10 +4,9 @@ import ( "crypto/rand" "fmt" + "github.com/akamensky/base58" "golang.org/x/crypto/curve25519" "golang.org/x/crypto/nacl/box" - - "github.com/akamensky/base58" ) // A KeyEncryptionKey (KEK) is used to implement *envelope encryption*, similar to how data is stored at rest with diff --git a/pkg/cryptutil/token.go b/pkg/cryptutil/token.go index 1c6dc806e..1bf48bedc 100644 --- a/pkg/cryptutil/token.go +++ b/pkg/cryptutil/token.go @@ -7,9 +7,8 @@ import ( "errors" "time" - "github.com/google/uuid" - "github.com/akamensky/base58" + "github.com/google/uuid" ) // TokenLength is the length of a token. diff --git a/pkg/grpc/device/device.go b/pkg/grpc/device/device.go index bc6744017..5dc99d007 100644 --- a/pkg/grpc/device/device.go +++ b/pkg/grpc/device/device.go @@ -5,11 +5,11 @@ import ( "context" "fmt" + "github.com/akamensky/base58" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" "google.golang.org/protobuf/types/known/timestamppb" - "github.com/akamensky/base58" "github.com/pomerium/pomerium/pkg/grpc/databroker" "github.com/pomerium/pomerium/pkg/protoutil" ) diff --git a/pkg/grpc/identity/identity.go b/pkg/grpc/identity/identity.go index 065e4c9c1..8e1214706 100644 --- a/pkg/grpc/identity/identity.go +++ b/pkg/grpc/identity/identity.go @@ -4,9 +4,8 @@ package identity import ( "crypto/sha256" - "google.golang.org/protobuf/proto" - "github.com/akamensky/base58" + "google.golang.org/protobuf/proto" ) // Clone clones the Provider. diff --git a/pkg/telemetry/requestid/requestid.go b/pkg/telemetry/requestid/requestid.go index 91868c0a4..760fb92cd 100644 --- a/pkg/telemetry/requestid/requestid.go +++ b/pkg/telemetry/requestid/requestid.go @@ -4,9 +4,8 @@ package requestid import ( "context" - "github.com/google/uuid" - "github.com/akamensky/base58" + "github.com/google/uuid" ) const headerName = "x-request-id" diff --git a/pkg/webauthnutil/credential_storage.go b/pkg/webauthnutil/credential_storage.go index b3f1cbf41..8080cf5e6 100644 --- a/pkg/webauthnutil/credential_storage.go +++ b/pkg/webauthnutil/credential_storage.go @@ -3,10 +3,10 @@ package webauthnutil import ( "context" + "github.com/akamensky/base58" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" - "github.com/akamensky/base58" "github.com/pomerium/pomerium/pkg/grpc/databroker" "github.com/pomerium/pomerium/pkg/grpc/device" "github.com/pomerium/webauthn"