mirror of
https://github.com/pomerium/pomerium.git
synced 2025-04-29 10:26:29 +02:00
importutil: refactor GenerateRouteNames to allow for protobuf or config routes (#5427)
* importutil: refactor GenerateRouteNames to allow for protobuf or config routes * test via NewPolicyFromProto
This commit is contained in:
parent
e5ede2d167
commit
5ff53ef2b1
3 changed files with 78 additions and 43 deletions
|
@ -869,6 +869,26 @@ func (p *Policy) GetPassIdentityHeaders(options *Options) bool {
|
|||
return false
|
||||
}
|
||||
|
||||
// GetFrom gets the from URL.
|
||||
func (p *Policy) GetFrom() string {
|
||||
return p.From
|
||||
}
|
||||
|
||||
// GetPath gets the path.
|
||||
func (p *Policy) GetPath() string {
|
||||
return p.Path
|
||||
}
|
||||
|
||||
// GetPrefix gets the prefix.
|
||||
func (p *Policy) GetPrefix() string {
|
||||
return p.Prefix
|
||||
}
|
||||
|
||||
// GetRegex gets the regex.
|
||||
func (p *Policy) GetRegex() string {
|
||||
return p.Regex
|
||||
}
|
||||
|
||||
/*
|
||||
SortPolicies sorts policies to match the following SQL order:
|
||||
|
||||
|
|
|
@ -11,7 +11,6 @@ import (
|
|||
"strings"
|
||||
|
||||
"github.com/cespare/xxhash/v2"
|
||||
configpb "github.com/pomerium/pomerium/pkg/grpc/config"
|
||||
)
|
||||
|
||||
func GenerateCertName(cert *x509.Certificate) *string {
|
||||
|
@ -59,18 +58,26 @@ func pickDNSName(names []string) string {
|
|||
return names[0]
|
||||
}
|
||||
|
||||
func GenerateRouteNames(routes []*configpb.Route) []string {
|
||||
type Route interface {
|
||||
comparable
|
||||
GetFrom() string
|
||||
GetPrefix() string
|
||||
GetPath() string
|
||||
GetRegex() string
|
||||
}
|
||||
|
||||
func GenerateRouteNames[T Route](routes []T) []string {
|
||||
out := make([]string, len(routes))
|
||||
prefixes := make([][]string, len(routes))
|
||||
indexes := map[*configpb.Route]int{}
|
||||
trie := newDomainTrie()
|
||||
indexes := map[T]int{}
|
||||
trie := newDomainTrie[T]()
|
||||
for i, route := range routes {
|
||||
trie.Insert(route)
|
||||
indexes[route] = i
|
||||
}
|
||||
trie.Compact()
|
||||
|
||||
trie.Walk(func(parents []string, node *domainTreeNode) {
|
||||
trie.Walk(func(parents []string, node *domainTreeNode[T]) {
|
||||
for subdomain, child := range node.children {
|
||||
for route, name := range differentiateRoutes(subdomain, child.routes) {
|
||||
idx := indexes[route]
|
||||
|
@ -145,13 +152,13 @@ func trimCommonSubdomains(a, b string) (string, string) {
|
|||
return strings.Join(aParts, "-"), strings.Join(bParts, "-")
|
||||
}
|
||||
|
||||
func differentiateRoutes(subdomain string, routes []*configpb.Route) iter.Seq2[*configpb.Route, string] {
|
||||
return func(yield func(*configpb.Route, string) bool) {
|
||||
func differentiateRoutes[T Route](subdomain string, routes []T) iter.Seq2[T, string] {
|
||||
return func(yield func(T, string) bool) {
|
||||
if len(routes) == 1 {
|
||||
yield(routes[0], subdomain)
|
||||
return
|
||||
}
|
||||
names := map[string][]*configpb.Route{}
|
||||
names := map[string][]T{}
|
||||
replacer := strings.NewReplacer(
|
||||
" ", "_",
|
||||
"/", "-",
|
||||
|
@ -180,14 +187,14 @@ func differentiateRoutes(subdomain string, routes []*configpb.Route) iter.Seq2[*
|
|||
// each route will have the same domain, but a unique prefix/path/regex.
|
||||
var name string
|
||||
switch {
|
||||
case route.Prefix != "":
|
||||
name = simplePathName(route.Prefix)
|
||||
case route.GetPrefix() != "":
|
||||
name = simplePathName(route.GetPrefix())
|
||||
prefixCount++
|
||||
case route.Path != "":
|
||||
name = simplePathName(route.Path)
|
||||
case route.GetPath() != "":
|
||||
name = simplePathName(route.GetPath())
|
||||
pathCount++
|
||||
case route.Regex != "":
|
||||
name = regexName(route.Regex)
|
||||
case route.GetRegex() != "":
|
||||
name = regexName(route.GetRegex())
|
||||
}
|
||||
names[name] = append(names[name], route)
|
||||
}
|
||||
|
@ -229,9 +236,9 @@ func differentiateRoutes(subdomain string, routes []*configpb.Route) iter.Seq2[*
|
|||
b.WriteRune('-')
|
||||
b.WriteString(name)
|
||||
}
|
||||
if route.Prefix != "" {
|
||||
if route.GetPrefix() != "" {
|
||||
b.WriteString(prefixSuffix)
|
||||
} else if route.Path != "" {
|
||||
} else if route.GetPath() != "" {
|
||||
b.WriteString(pathSuffix)
|
||||
}
|
||||
|
||||
|
@ -251,58 +258,58 @@ func differentiateRoutes(subdomain string, routes []*configpb.Route) iter.Seq2[*
|
|||
}
|
||||
}
|
||||
|
||||
type domainTreeNode struct {
|
||||
parent *domainTreeNode
|
||||
children map[string]*domainTreeNode
|
||||
routes []*configpb.Route
|
||||
type domainTreeNode[T Route] struct {
|
||||
parent *domainTreeNode[T]
|
||||
children map[string]*domainTreeNode[T]
|
||||
routes []T
|
||||
}
|
||||
|
||||
func (n *domainTreeNode) insert(key string, route *configpb.Route) *domainTreeNode {
|
||||
func (n *domainTreeNode[T]) insert(key string, route T) *domainTreeNode[T] {
|
||||
if existing, ok := n.children[key]; ok {
|
||||
if route != nil {
|
||||
var def T
|
||||
if route != def {
|
||||
existing.routes = append(existing.routes, route)
|
||||
}
|
||||
return existing
|
||||
}
|
||||
node := &domainTreeNode{
|
||||
node := &domainTreeNode[T]{
|
||||
parent: n,
|
||||
children: map[string]*domainTreeNode{},
|
||||
children: map[string]*domainTreeNode[T]{},
|
||||
}
|
||||
if route != nil {
|
||||
var def T
|
||||
if route != def {
|
||||
node.routes = append(node.routes, route)
|
||||
}
|
||||
n.children[key] = node
|
||||
return node
|
||||
}
|
||||
|
||||
type domainTrie struct {
|
||||
root *domainTreeNode
|
||||
type domainTrie[T Route] struct {
|
||||
root *domainTreeNode[T]
|
||||
}
|
||||
|
||||
func newDomainTrie() *domainTrie {
|
||||
t := &domainTrie{
|
||||
root: &domainTreeNode{
|
||||
children: map[string]*domainTreeNode{},
|
||||
func newDomainTrie[T Route]() *domainTrie[T] {
|
||||
t := &domainTrie[T]{
|
||||
root: &domainTreeNode[T]{
|
||||
children: map[string]*domainTreeNode[T]{},
|
||||
},
|
||||
}
|
||||
return t
|
||||
}
|
||||
|
||||
type walkFn = func(parents []string, node *domainTreeNode)
|
||||
|
||||
func (t *domainTrie) Walk(fn walkFn) {
|
||||
func (t *domainTrie[T]) Walk(fn func(parents []string, node *domainTreeNode[T])) {
|
||||
t.root.walk(nil, fn)
|
||||
}
|
||||
|
||||
func (n *domainTreeNode) walk(prefix []string, fn walkFn) {
|
||||
func (n *domainTreeNode[T]) walk(prefix []string, fn func(parents []string, node *domainTreeNode[T])) {
|
||||
for key, child := range n.children {
|
||||
fn(append(prefix, key), child)
|
||||
child.walk(append(prefix, key), fn)
|
||||
}
|
||||
}
|
||||
|
||||
func (t *domainTrie) Insert(route *configpb.Route) {
|
||||
u, _ := url.Parse(route.From)
|
||||
func (t *domainTrie[T]) Insert(route T) {
|
||||
u, _ := url.Parse(route.GetFrom())
|
||||
if u == nil {
|
||||
// ignore invalid urls, they will be assigned generic fallback names
|
||||
return
|
||||
|
@ -311,16 +318,17 @@ func (t *domainTrie) Insert(route *configpb.Route) {
|
|||
slices.Reverse(parts)
|
||||
cur := t.root
|
||||
for _, part := range parts[:len(parts)-1] {
|
||||
cur = cur.insert(part, nil)
|
||||
var def T
|
||||
cur = cur.insert(part, def)
|
||||
}
|
||||
cur.insert(parts[len(parts)-1], route)
|
||||
}
|
||||
|
||||
func (t *domainTrie) Compact() {
|
||||
func (t *domainTrie[T]) Compact() {
|
||||
t.root.compact()
|
||||
}
|
||||
|
||||
func (n *domainTreeNode) compact() {
|
||||
func (n *domainTreeNode[T]) compact() {
|
||||
for _, child := range n.children {
|
||||
child.compact()
|
||||
}
|
||||
|
@ -328,7 +336,7 @@ func (n *domainTreeNode) compact() {
|
|||
return
|
||||
}
|
||||
var firstKey string
|
||||
var firstChild *domainTreeNode
|
||||
var firstChild *domainTreeNode[T]
|
||||
for key, child := range n.children {
|
||||
firstKey, firstChild = key, child
|
||||
break
|
||||
|
@ -340,7 +348,7 @@ func (n *domainTreeNode) compact() {
|
|||
if child == n {
|
||||
delete(n.parent.children, key)
|
||||
n.parent.children[fmt.Sprintf("%s.%s", key, firstKey)] = firstChild
|
||||
*n = domainTreeNode{}
|
||||
*n = domainTreeNode[T]{}
|
||||
break
|
||||
}
|
||||
}
|
||||
|
|
|
@ -8,9 +8,11 @@ import (
|
|||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/pomerium/pomerium/config"
|
||||
configpb "github.com/pomerium/pomerium/pkg/grpc/config"
|
||||
"github.com/pomerium/pomerium/pkg/zero/importutil"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestGenerateCertName(t *testing.T) {
|
||||
|
@ -389,6 +391,11 @@ func TestGenerateRouteNames(t *testing.T) {
|
|||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
assert.Equal(t, tc.expected, importutil.GenerateRouteNames(tc.input))
|
||||
policies := make([]*config.Policy, len(tc.input))
|
||||
for i := range tc.input {
|
||||
policies[i], _ = config.NewPolicyFromProto(tc.input[i])
|
||||
}
|
||||
assert.Equal(t, tc.expected, importutil.GenerateRouteNames(policies))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Add table
Reference in a new issue