pomerium/config/identity.go

192 lines
5.2 KiB
Go

package config
import (
"fmt"
"slices"
"strings"
art "github.com/kralicky/go-adaptive-radix-tree"
"github.com/pomerium/pomerium/internal/urlutil"
"github.com/pomerium/pomerium/pkg/grpc/identity"
)
// GetIdentityProviderForID returns the identity provider whose Id equals idpID.
// Every policy's provider is computed in configuration order and the first one
// with a matching Id wins. If no policy yields that Id, the default provider
// (the one built without any policy overrides) is returned. Any error from
// building a provider is returned immediately.
func (o *Options) GetIdentityProviderForID(idpID string) (*identity.Provider, error) {
	for _, policy := range o.GetAllPolicies() {
		// The pointer is only used for the duration of this call, so taking the
		// address of the per-iteration variable is safe.
		provider, err := o.GetIdentityProviderForPolicy(&policy)
		switch {
		case err != nil:
			return nil, err
		case provider.GetId() == idpID:
			return provider, nil
		}
	}
	// No policy-specific provider matched; fall back to the default provider.
	return o.GetIdentityProviderForPolicy(nil)
}
// GetIdentityProviderForPolicy builds the identity provider for the given policy.
// The provider starts from the global options (client ID, client secret,
// provider type, scopes, provider URL, request params). A non-nil policy may
// override the client ID and/or client secret with its non-empty IDPClientID /
// IDPClientSecret. If policy is nil, or overrides nothing, the default provider
// is returned. The provider's Id is set to its Hash(). An error is returned
// only when the global client secret cannot be obtained.
func (o *Options) GetIdentityProviderForPolicy(policy *Policy) (*identity.Provider, error) {
	secret, err := o.GetClientSecret()
	if err != nil {
		return nil, err
	}

	clientID := o.ClientID
	if policy != nil && policy.IDPClientID != "" {
		clientID = policy.IDPClientID
	}
	if policy != nil && policy.IDPClientSecret != "" {
		secret = policy.IDPClientSecret
	}

	provider := &identity.Provider{
		ClientId:      clientID,
		ClientSecret:  secret,
		Type:          o.Provider,
		Scopes:        o.Scopes,
		Url:           o.ProviderURL,
		RequestParams: o.RequestParams,
	}
	provider.Id = provider.Hash()
	return provider, nil
}
// GetIdentityProviderForRequestURL gets the identity provider associated with the given request URL.
// Policies are checked in configuration order with Policy.Matches. The first
// match determines the provider; when nothing matches, the default provider is
// returned. An error is returned if requestURL fails to parse or validate.
func (o *Options) GetIdentityProviderForRequestURL(requestURL string) (*identity.Provider, error) {
	u, err := urlutil.ParseAndValidateURL(requestURL)
	if err != nil {
		return nil, err
	}

	// The runtime flag does not change while iterating, so read it once.
	matchAnyPort := o.IsRuntimeFlagSet(RuntimeFlagMatchAnyIncomingPort)
	for _, policy := range o.GetAllPolicies() {
		if !policy.Matches(*u, matchAnyPort) {
			continue
		}
		return o.GetIdentityProviderForPolicy(&policy)
	}
	return o.GetIdentityProviderForPolicy(nil)
}
// PolicyCache indexes the configured policies by the host and port of their
// From URLs. This lets the policy, and therefore the identity provider, for a
// request URL be found without scanning every policy. Build one with
// NewPolicyCache.
type PolicyCache struct {
	// domainTree is keyed by radixKeyForHostPort(host, port): the port (or "*"
	// when empty) followed by the host's dot-separated labels in reverse order.
	domainTree art.Tree[domainNode]
}
// NewPolicyCache builds a PolicyCache from every policy in options.
//
// Each policy's From URL is expanded into all of its host/port combinations
// via urlutil.AllDomainsForURL. Ports are required unless the
// RuntimeFlagMatchAnyIncomingPort flag is set. The policy is then recorded
// under each combination's radix key, in one of three buckets:
//   - keyed by Prefix (or by Path when Prefix is empty) in policiesByPrefix,
//   - appended to policiesByRegex when it only has a compiled regex,
//   - appended to policiesNoPathMatching when it has no path constraint.
//
// An error is returned if any policy's From URL fails to parse or validate.
func NewPolicyCache(options *Options) (*PolicyCache, error) {
	domains := art.New[domainNode]()
	matchAnyPort := options.IsRuntimeFlagSet(RuntimeFlagMatchAnyIncomingPort)

	// newNode creates the node stored for a host/port key seen for the first time.
	newNode := func() domainNode {
		return domainNode{policiesByPrefix: art.New[Policy]()}
	}

	for _, policy := range options.GetAllPolicies() {
		fromURL, err := urlutil.ParseAndValidateURL(policy.From)
		if err != nil {
			return nil, err
		}

		// addPolicy files the current policy into the bucket matching how it
		// constrains request paths.
		addPolicy := func(dn *domainNode) {
			switch {
			case policy.Prefix != "":
				dn.policiesByPrefix.Insert(art.Key(policy.Prefix), policy)
			case policy.Path != "":
				dn.policiesByPrefix.Insert(art.Key(policy.Path), policy)
			case policy.compiledRegex != nil:
				dn.policiesByRegex = append(dn.policiesByRegex, policy)
			default:
				dn.policiesNoPathMatching = append(dn.policiesNoPathMatching, policy)
			}
		}

		urlutil.AllDomainsForURL(fromURL, !matchAnyPort)(func(host, port string) bool {
			domains.Update(art.Key(radixKeyForHostPort(host, port)), newNode, addPolicy)
			return true
		})
	}

	return &PolicyCache{domainTree: domains}, nil
}
// GetIdentityProviderForRequestURL resolves the identity provider for
// requestURL using the cached policy index.
//
// The request's host and port select a domain node; wildcard hosts are
// resolved via wildcardResolver. Within that node a policy is chosen in
// this order:
//
//  1. on the root path ("" or "/"), the first policy with no path matching;
//  2. a policy whose Prefix is a prefix of the path, or whose Path/Prefix
//     equals the path, found via a nearest-key search of the prefix tree;
//  3. the first policy whose compiled regex matches the path;
//  4. otherwise, the first policy with no path matching, because such a
//     policy applies to every path on the domain.
//
// The identity provider for the chosen policy is returned. An error is
// returned when requestURL is invalid or no policy on the domain matches.
func (pc *PolicyCache) GetIdentityProviderForRequestURL(o *Options, requestURL string) (*identity.Provider, error) {
	u, err := urlutil.ParseAndValidateURL(requestURL)
	if err != nil {
		return nil, err
	}

	domainKey := radixKeyForHostPort(u.Hostname(), u.Port())
	domain, ok := pc.domainTree.Resolve(art.Key(domainKey), wildcardResolver)
	if !ok {
		return nil, fmt.Errorf("no identity provider found for request URL %s", requestURL)
	}

	var policy *Policy

	// On the root path a catch-all policy takes priority over path-specific ones.
	isRootPath := u.Path == "" || u.Path == "/"
	if isRootPath && len(domain.policiesNoPathMatching) > 0 {
		policy = &domain.policiesNoPathMatching[0]
	}

	if policy == nil && domain.policiesByPrefix.Size() > 0 {
		// NOTE(review): only the nearest key is examined, so a shorter matching
		// prefix may be missed when a longer non-matching key sorts nearer.
		actualKey, val, found := domain.policiesByPrefix.SearchNearest(art.Key(u.Path))
		if found {
			switch c := actualKey.Compare(art.Key(u.Path)); {
			case c < 0:
				// The key sorts before the path; it only matches as a prefix.
				if val.Prefix != "" && strings.HasPrefix(u.Path, val.Prefix) {
					policy = &val
				}
			case c == 0:
				// Exact key match: valid for both exact-path and prefix policies.
				if val.Path != "" || val.Prefix != "" {
					policy = &val
				}
			}
		}
	}

	if policy == nil {
		for i := range domain.policiesByRegex {
			if p := &domain.policiesByRegex[i]; p.compiledRegex.MatchString(u.Path) {
				policy = p
				break
			}
		}
	}

	// A policy with no path constraint matches any path, so it is the fallback
	// for non-root paths that no path-specific policy claimed. Previously such
	// requests failed with "no identity provider found".
	if policy == nil && len(domain.policiesNoPathMatching) > 0 {
		policy = &domain.policiesNoPathMatching[0]
	}

	if policy == nil {
		return nil, fmt.Errorf("no identity provider found for request URL %s", requestURL)
	}
	return o.GetIdentityProviderForPolicy(policy)
}
// domainNode holds the policies stored under a single host/port key of
// PolicyCache.domainTree, bucketed by how each policy constrains request paths.
type domainNode struct {
	// policiesByPrefix is keyed by the policy's Prefix, or by its Path when no
	// Prefix is set. NOTE(review): a later policy inserted with the same key
	// presumably replaces the earlier one — confirm art.Tree.Insert semantics.
	policiesByPrefix art.Tree[Policy]
	// policiesByRegex holds policies that have a compiled regex but no
	// Prefix/Path, in configuration order.
	policiesByRegex []Policy
	// policiesNoPathMatching holds policies with no path constraint, in
	// configuration order; lookups use the first entry.
	policiesNoPathMatching []Policy
}
// radixKeyForHostPort builds the radix-tree key for a host/port pair.
// The key is the port ("*" when port is empty), then a ".", then the host's
// dot-separated labels in reverse order. For example, ("www.example.com",
// "443") becomes "443.com.example.www". Reversing the labels puts the most
// general part of the domain first, so related hosts share key prefixes.
func radixKeyForHostPort(host, port string) string {
	if port == "" {
		port = "*"
	}
	labels := strings.Split(host, ".")
	slices.Reverse(labels)
	return port + "." + strings.Join(labels, ".")
}
// wildcardResolver is the conflict resolver passed to art.Tree.Resolve when
// looking up domain keys.
//
// If the byte at conflictIndex begins an ordinary label (anything other than
// '*' or '.'), it returns the key "*" and the index where that label ends:
// the next '.' at or after conflictIndex, or len(key) if there is none.
// Otherwise it returns (nil, -1). It also returns (nil, -1) when conflictIndex
// is past the end of key.
func wildcardResolver(key art.Key, conflictIndex int) (art.Key, int) {
	if conflictIndex >= len(key) {
		return nil, -1
	}
	switch key[conflictIndex] {
	case '*', '.':
		return nil, -1
	}
	labelEnd := len(key)
	if offset := slices.Index(key[conflictIndex:], '.'); offset != -1 {
		labelEnd = conflictIndex + offset
	}
	return art.Key("*"), labelEnd
}