mirror of
https://github.com/pomerium/pomerium.git
synced 2025-07-22 19:18:01 +02:00
implement port+prefix matching
This commit is contained in:
parent
e18c04216e
commit
a5a0cf4ba8
11 changed files with 131 additions and 52 deletions
|
@ -3,11 +3,11 @@ package config
|
|||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"slices"
|
||||
"strings"
|
||||
|
||||
art "github.com/kralicky/go-adaptive-radix-tree"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/urlutil"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/identity"
|
||||
)
|
||||
|
@ -74,41 +74,38 @@ func (o *Options) GetIdentityProviderForRequestURL(ctx context.Context, requestU
|
|||
}
|
||||
|
||||
type PolicyCache struct {
|
||||
domainTree art.Tree[*domainNode]
|
||||
matchPorts bool
|
||||
domainTree art.Tree[domainNode]
|
||||
}
|
||||
|
||||
func NewPolicyCache(options *Options) (*PolicyCache, error) {
|
||||
tree := art.New[*domainNode]()
|
||||
shouldMatchPorts := !options.IsRuntimeFlagSet(RuntimeFlagMatchAnyIncomingPort)
|
||||
tree := art.New[domainNode]()
|
||||
emptyPortMatchesAny := options.IsRuntimeFlagSet(RuntimeFlagMatchAnyIncomingPort)
|
||||
for _, policy := range options.GetAllPolicies() {
|
||||
u, err := urlutil.ParseAndValidateURL(policy.From)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
domains := urlutil.GetDomainsForURL(u, shouldMatchPorts)
|
||||
for _, domain := range domains {
|
||||
host, port, err := net.SplitHostPort(domain) // todo: this is not optimal
|
||||
if err != nil {
|
||||
host, port = domain, ""
|
||||
}
|
||||
|
||||
urlutil.AllDomainsForURL(u, !emptyPortMatchesAny)(func(host, port string) bool {
|
||||
domainKey := radixKeyForHostPort(host, port)
|
||||
tree.Update(domainKey, newDomainNode, func(dn **domainNode) {
|
||||
tree.Update(art.Key(domainKey), func() domainNode {
|
||||
return domainNode{policiesByPrefix: art.New[Policy]()}
|
||||
}, func(dn *domainNode) {
|
||||
if policy.Prefix != "" {
|
||||
(*dn).policiesByPrefix.Insert(art.Key(policy.Prefix), policy)
|
||||
dn.policiesByPrefix.Insert(art.Key(policy.Prefix), policy)
|
||||
} else if policy.Path != "" {
|
||||
(*dn).policiesByPrefix.Insert(art.Key(policy.Path), policy)
|
||||
dn.policiesByPrefix.Insert(art.Key(policy.Path), policy)
|
||||
} else if policy.compiledRegex != nil {
|
||||
(*dn).policiesByRegex = append((*dn).policiesByRegex, policy)
|
||||
dn.policiesByRegex = append(dn.policiesByRegex, policy)
|
||||
} else {
|
||||
(*dn).policiesNoPathMatching = append((*dn).policiesNoPathMatching, policy)
|
||||
dn.policiesNoPathMatching = append(dn.policiesNoPathMatching, policy)
|
||||
}
|
||||
})
|
||||
}
|
||||
return true
|
||||
})
|
||||
}
|
||||
return &PolicyCache{
|
||||
domainTree: tree,
|
||||
matchPorts: shouldMatchPorts,
|
||||
}, nil
|
||||
}
|
||||
|
||||
|
@ -119,20 +116,20 @@ func (pc *PolicyCache) GetIdentityProviderForRequestURL(ctx context.Context, o *
|
|||
}
|
||||
|
||||
domainKey := radixKeyForHostPort(u.Hostname(), u.Port())
|
||||
domain, ok := pc.domainTree.Resolve(domainKey, wildcardResolver)
|
||||
domain, ok := pc.domainTree.Resolve(art.Key(domainKey), wildcardResolver)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("no identity provider found for request URL %s", requestURL)
|
||||
}
|
||||
var policy *Policy
|
||||
if len(u.Path) == 0 || (len(u.Path) == 1 && u.Path[0] == '/') && len(domain.policiesNoPathMatching) > 0 {
|
||||
if (len(u.Path) == 0 || (len(u.Path) == 1 && u.Path[0] == '/')) &&
|
||||
len(domain.policiesNoPathMatching) > 0 {
|
||||
policy = &domain.policiesNoPathMatching[0]
|
||||
} else {
|
||||
if domain.policiesByPrefix.Size() > 0 {
|
||||
pathKey := art.Key(u.Path)
|
||||
actualKey, val, found := domain.policiesByPrefix.SearchNearest(pathKey)
|
||||
actualKey, val, found := domain.policiesByPrefix.SearchNearest(art.Key(u.Path))
|
||||
if found {
|
||||
// check for prefix match or exact match
|
||||
if c := actualKey.Compare(pathKey); c < 0 {
|
||||
if c := actualKey.Compare(art.Key(u.Path)); c < 0 {
|
||||
if val.Prefix != "" && strings.HasPrefix(u.Path, val.Prefix) {
|
||||
policy = &val
|
||||
}
|
||||
|
@ -144,9 +141,10 @@ func (pc *PolicyCache) GetIdentityProviderForRequestURL(ctx context.Context, o *
|
|||
}
|
||||
}
|
||||
if policy == nil {
|
||||
for _, p := range domain.policiesByRegex {
|
||||
for i := range len(domain.policiesByRegex) {
|
||||
p := &domain.policiesByRegex[i]
|
||||
if p.compiledRegex.MatchString(u.Path) {
|
||||
policy = &p
|
||||
policy = p
|
||||
break
|
||||
}
|
||||
}
|
||||
|
@ -165,13 +163,10 @@ type domainNode struct {
|
|||
policiesNoPathMatching []Policy
|
||||
}
|
||||
|
||||
func newDomainNode() *domainNode {
|
||||
return &domainNode{
|
||||
policiesByPrefix: art.New[Policy](),
|
||||
func radixKeyForHostPort(host, port string) string {
|
||||
if port == "" {
|
||||
port = "*"
|
||||
}
|
||||
}
|
||||
|
||||
func radixKeyForHostPort(host, port string) art.Key {
|
||||
parts := strings.Split(host, ".")
|
||||
sb := strings.Builder{}
|
||||
sb.WriteString(port)
|
||||
|
@ -179,7 +174,7 @@ func radixKeyForHostPort(host, port string) art.Key {
|
|||
sb.WriteByte('.')
|
||||
sb.WriteString(parts[i])
|
||||
}
|
||||
return art.Key(sb.String())
|
||||
return sb.String()
|
||||
}
|
||||
|
||||
func wildcardResolver(key art.Key, conflictIndex int) (art.Key, int) {
|
||||
|
|
|
@ -7,9 +7,10 @@ import (
|
|||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/pomerium/pomerium/config"
|
||||
"github.com/pomerium/pomerium/pkg/cryptutil"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func BenchmarkGetIdentityProviderForRequestURL_Old(b *testing.B) {
|
||||
|
@ -19,6 +20,7 @@ func BenchmarkGetIdentityProviderForRequestURL_Old(b *testing.B) {
|
|||
options := config.NewDefaultOptions()
|
||||
sharedKey := cryptutil.NewKey()
|
||||
options.SharedKey = base64.StdEncoding.EncodeToString(sharedKey)
|
||||
options.InsecureServer = true
|
||||
options.Provider = "oidc"
|
||||
options.ProviderURL = "https://oidc.example.com"
|
||||
options.ClientID = "client_id"
|
||||
|
@ -58,6 +60,7 @@ var bench = func(fill func(i int, p *config.Policy) string, numPolicies int) fun
|
|||
options := config.NewDefaultOptions()
|
||||
sharedKey := cryptutil.NewKey()
|
||||
options.SharedKey = base64.StdEncoding.EncodeToString(sharedKey)
|
||||
options.InsecureServer = true
|
||||
options.Provider = "oidc"
|
||||
options.ProviderURL = "https://oidc.example.com"
|
||||
options.ClientID = "client_id"
|
||||
|
@ -80,7 +83,9 @@ var bench = func(fill func(i int, p *config.Policy) string, numPolicies int) fun
|
|||
|
||||
b.ResetTimer()
|
||||
for i := range b.N {
|
||||
reqUrl := strings.Replace(allUrls[i%numPolicies], "*", fmt.Sprint(i), 1)
|
||||
// replace all *s in the url with a number, which is valid for both
|
||||
// hostname segments and ports.
|
||||
reqUrl := strings.ReplaceAll(allUrls[i%numPolicies], "*", fmt.Sprint(i))
|
||||
idp, err := cache.GetIdentityProviderForRequestURL(context.Background(), options, reqUrl)
|
||||
require.NoError(b, err)
|
||||
require.Equal(b, fmt.Sprintf("client_id_%d", i%numPolicies), idp.ClientId)
|
||||
|
@ -101,6 +106,23 @@ func BenchmarkGetIdentityProviderForRequestURL_New_DomainMatchOnly(b *testing.B)
|
|||
b.Run("5000 policies (domain matching only)", bench(domainMatchingOnly, 5000))
|
||||
}
|
||||
|
||||
func BenchmarkGetIdentityProviderForRequestURL_New_DomainPortMatchOnly(b *testing.B) {
|
||||
domainPortMatchingOnly := func(i int, p *config.Policy) string {
|
||||
p.From = fmt.Sprintf("https://*.foo.bar.test-%d.example.com", i)
|
||||
if i%5 == 0 {
|
||||
p.From += ":9999"
|
||||
} else if i%2 == 0 {
|
||||
p.From += ":443"
|
||||
}
|
||||
return p.From
|
||||
}
|
||||
|
||||
b.Run("5 policies (domain+port matching only)", bench(domainPortMatchingOnly, 5))
|
||||
b.Run("50 policies (domain+port matching only)", bench(domainPortMatchingOnly, 50))
|
||||
b.Run("500 policies (domain+port matching only)", bench(domainPortMatchingOnly, 500))
|
||||
b.Run("5000 policies (domain+port matching only)", bench(domainPortMatchingOnly, 5000))
|
||||
}
|
||||
|
||||
func BenchmarkGetIdentityProviderForRequestURL_New_PathMatchOnly(b *testing.B) {
|
||||
pathMatchingOnly := func(i int, p *config.Policy) string {
|
||||
p.From = "https://example.com"
|
||||
|
@ -113,6 +135,18 @@ func BenchmarkGetIdentityProviderForRequestURL_New_PathMatchOnly(b *testing.B) {
|
|||
b.Run("5000 policies (path matching only)", bench(pathMatchingOnly, 5000))
|
||||
}
|
||||
|
||||
func BenchmarkGetIdentityProviderForRequestURL_New_PrefixMatchOnly(b *testing.B) {
|
||||
prefixMatchingOnly := func(i int, p *config.Policy) string {
|
||||
p.From = "https://example.com"
|
||||
p.Prefix = fmt.Sprintf("/foo/bar/%d/", i)
|
||||
return p.From + p.Prefix + "/subpath"
|
||||
}
|
||||
b.Run("5 policies (prefix matching only)", bench(prefixMatchingOnly, 5))
|
||||
b.Run("50 policies (prefix matching only)", bench(prefixMatchingOnly, 50))
|
||||
b.Run("500 policies (prefix matching only)", bench(prefixMatchingOnly, 500))
|
||||
b.Run("5000 policies (prefix matching only)", bench(prefixMatchingOnly, 5000))
|
||||
}
|
||||
|
||||
func BenchmarkGetIdentityProviderForRequestURL_New_DomainAndPathMatching(b *testing.B) {
|
||||
combinedMatching := func(numPathsPerDomain int) func(i int, p *config.Policy) string {
|
||||
// returns a sequence of policies (ex: numPathsPerDomain=3)
|
||||
|
@ -159,12 +193,13 @@ func BenchmarkGetIdentityProviderForRequestURL_New_DomainAndPrefixMatching(b *te
|
|||
domain := fmt.Sprintf("https://*.foo.bar.test-%d.example.com", i/numPathsPerDomain)
|
||||
pathIdx := i % numPathsPerDomain
|
||||
var prefix strings.Builder
|
||||
for i := range pathIdx + 1 {
|
||||
fmt.Fprintf(&prefix, "/%d", i)
|
||||
for j := 0; j <= pathIdx; j++ {
|
||||
fmt.Fprintf(&prefix, "/%d", j)
|
||||
}
|
||||
prefix.WriteString("/")
|
||||
p.From = domain
|
||||
p.Prefix = prefix.String()
|
||||
return domain + p.Prefix
|
||||
return domain + p.Prefix + "subpath"
|
||||
}
|
||||
}
|
||||
|
||||
|
|
6
go.mod
6
go.mod
|
@ -1,6 +1,6 @@
|
|||
module github.com/pomerium/pomerium
|
||||
|
||||
go 1.22.2
|
||||
go 1.22.4
|
||||
|
||||
require (
|
||||
cloud.google.com/go/storage v1.41.0
|
||||
|
@ -10,6 +10,7 @@ require (
|
|||
github.com/CAFxX/httpcompression v0.0.9
|
||||
github.com/DataDog/opencensus-go-exporter-datadog v0.0.0-20200406135749-5c268882acf0
|
||||
github.com/VictoriaMetrics/fastcache v1.12.2
|
||||
github.com/akamensky/base58 v0.0.0-20210829145138-ce8bf8802e8f
|
||||
github.com/aws/aws-sdk-go-v2 v1.27.0
|
||||
github.com/aws/aws-sdk-go-v2/config v1.27.16
|
||||
github.com/aws/aws-sdk-go-v2/service/s3 v1.54.3
|
||||
|
@ -36,7 +37,7 @@ require (
|
|||
github.com/hashicorp/golang-lru/v2 v2.0.7
|
||||
github.com/jackc/pgx/v5 v5.6.0
|
||||
github.com/klauspost/compress v1.17.8
|
||||
github.com/kralicky/go-adaptive-radix-tree v0.0.0-20240619012453-a8f80032ba31
|
||||
github.com/kralicky/go-adaptive-radix-tree v0.0.0-20240620232421-9773ec5394e9
|
||||
github.com/martinlindhe/base36 v1.1.1
|
||||
github.com/mholt/acmez/v2 v2.0.1
|
||||
github.com/minio/minio-go/v7 v7.0.70
|
||||
|
@ -105,7 +106,6 @@ require (
|
|||
github.com/Nvveen/Gotty v0.0.0-20120604004816-cd527374f1e5 // indirect
|
||||
github.com/OneOfOne/xxhash v1.2.8 // indirect
|
||||
github.com/agnivade/levenshtein v1.1.1 // indirect
|
||||
github.com/akamensky/base58 v0.0.0-20210829145138-ce8bf8802e8f // indirect
|
||||
github.com/andybalholm/brotli v1.0.5 // indirect
|
||||
github.com/apapsch/go-jsonmerge/v2 v2.0.0 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.2 // indirect
|
||||
|
|
4
go.sum
4
go.sum
|
@ -416,8 +416,8 @@ github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
|
|||
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
|
||||
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
|
||||
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
|
||||
github.com/kralicky/go-adaptive-radix-tree v0.0.0-20240619012453-a8f80032ba31 h1:cOZUoNuv9OuCZKepddeIW87ScJWwveLfLcJMaii6YCA=
|
||||
github.com/kralicky/go-adaptive-radix-tree v0.0.0-20240619012453-a8f80032ba31/go.mod h1:oJwexVSshEat0E3evyKOH6QzN8GFWrhLvEoh8GiJzss=
|
||||
github.com/kralicky/go-adaptive-radix-tree v0.0.0-20240620232421-9773ec5394e9 h1:+RqdMg3IHwQ2YC+07pxc9hyHbwXngNRja70UTMbH2Wg=
|
||||
github.com/kralicky/go-adaptive-radix-tree v0.0.0-20240620232421-9773ec5394e9/go.mod h1:oJwexVSshEat0E3evyKOH6QzN8GFWrhLvEoh8GiJzss=
|
||||
github.com/lib/pq v1.10.7 h1:p7ZhMD+KsSRozJr34udlUrhboJwWAgCg34+/ZZNvZZw=
|
||||
github.com/lib/pq v1.10.7/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o=
|
||||
github.com/libdns/libdns v0.2.2 h1:O6ws7bAfRPaBsgAYt8MDe2HcNBGC29hkZ9MX2eUSX3s=
|
||||
|
|
|
@ -149,6 +149,59 @@ func GetDomainsForURL(u *url.URL, includeDefaultPort bool) []string {
|
|||
return []string{u.Hostname(), net.JoinHostPort(u.Hostname(), defaultPort)}
|
||||
}
|
||||
|
||||
func AllDomainsForURL(u *url.URL, includeDefaultPort bool) func(yield func(string, string) bool) {
|
||||
return func(yield func(string, string) bool) {
|
||||
if u == nil {
|
||||
return
|
||||
}
|
||||
|
||||
// tcp+https://ssh.example.com:22
|
||||
// => ssh.example.com:22
|
||||
// tcp+https://proxy.example.com/ssh.example.com:22
|
||||
// => ssh.example.com:22
|
||||
if strings.HasPrefix(u.Scheme, "tcp+") {
|
||||
hosts := strings.Split(strings.TrimPrefix(u.Path, "/"), "/")
|
||||
// if there are no domains in the path part of the URL, use the host
|
||||
// otherwise use the path parts of the URL as the hosts
|
||||
for _, hostport := range append(hosts, u.Host) {
|
||||
if host, port, err := net.SplitHostPort(hostport); err == nil {
|
||||
yield(host, port)
|
||||
} else {
|
||||
yield(hostport, "")
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
var defaultPort string
|
||||
if u.Scheme == "http" {
|
||||
defaultPort = "80"
|
||||
} else {
|
||||
defaultPort = "443"
|
||||
}
|
||||
|
||||
// for hosts like 'example.com:1234' we only return one route
|
||||
if host, port, err := net.SplitHostPort(u.Host); err == nil {
|
||||
if port != defaultPort {
|
||||
yield(host, port)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
hostname := u.Hostname()
|
||||
if !includeDefaultPort {
|
||||
yield(hostname, "")
|
||||
return
|
||||
}
|
||||
|
||||
// for everything else we return two routes: 'example.com' and 'example.com:443'
|
||||
if !yield(hostname, "") {
|
||||
return
|
||||
}
|
||||
yield(hostname, defaultPort)
|
||||
}
|
||||
}
|
||||
|
||||
// Join joins elements of a URL with '/'.
|
||||
func Join(elements ...string) string {
|
||||
var builder strings.Builder
|
||||
|
|
|
@ -4,10 +4,9 @@ import (
|
|||
"crypto/rand"
|
||||
"fmt"
|
||||
|
||||
"github.com/akamensky/base58"
|
||||
"golang.org/x/crypto/curve25519"
|
||||
"golang.org/x/crypto/nacl/box"
|
||||
|
||||
"github.com/akamensky/base58"
|
||||
)
|
||||
|
||||
// A KeyEncryptionKey (KEK) is used to implement *envelope encryption*, similar to how data is stored at rest with
|
||||
|
|
|
@ -7,9 +7,8 @@ import (
|
|||
"errors"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
|
||||
"github.com/akamensky/base58"
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
// TokenLength is the length of a token.
|
||||
|
|
|
@ -5,11 +5,11 @@ import (
|
|||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/akamensky/base58"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
|
||||
"github.com/akamensky/base58"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
||||
"github.com/pomerium/pomerium/pkg/protoutil"
|
||||
)
|
||||
|
|
|
@ -4,9 +4,8 @@ package identity
|
|||
import (
|
||||
"crypto/sha256"
|
||||
|
||||
"google.golang.org/protobuf/proto"
|
||||
|
||||
"github.com/akamensky/base58"
|
||||
"google.golang.org/protobuf/proto"
|
||||
)
|
||||
|
||||
// Clone clones the Provider.
|
||||
|
|
|
@ -4,9 +4,8 @@ package requestid
|
|||
import (
|
||||
"context"
|
||||
|
||||
"github.com/google/uuid"
|
||||
|
||||
"github.com/akamensky/base58"
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
const headerName = "x-request-id"
|
||||
|
|
|
@ -3,10 +3,10 @@ package webauthnutil
|
|||
import (
|
||||
"context"
|
||||
|
||||
"github.com/akamensky/base58"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
|
||||
"github.com/akamensky/base58"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/device"
|
||||
"github.com/pomerium/webauthn"
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue