Optimize policy iterators (#5184)

* Optimize policy iterators (go1.23)

This modifies (*Options).GetAllPolicies() to use a go 1.23 iterator
instead of copying all policies on every call, which can be extremely
expensive. All existing usages of this function were updated as
necessary.

Additionally, a new (*Options).NumPolicies() method was added which
quickly computes the number of policies that would be given by
GetAllPolicies(), since there were several usages where only the
number of policies was needed.

* Fix race condition when assigning default envoy opts to a policy
This commit is contained in:
Joe Kralicky 2024-08-20 12:35:10 -04:00 committed by GitHub
parent 3961098681
commit 56ba07e53e
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
16 changed files with 136 additions and 87 deletions

View file

@ -5,6 +5,7 @@ package authorize
import ( import (
"context" "context"
"fmt" "fmt"
"slices"
"sync" "sync"
"time" "time"
@ -91,7 +92,7 @@ func newPolicyEvaluator(
opts *config.Options, store *store.Store, previous *evaluator.Evaluator, opts *config.Options, store *store.Store, previous *evaluator.Evaluator,
) (*evaluator.Evaluator, error) { ) (*evaluator.Evaluator, error) {
metrics.AddPolicyCountCallback("pomerium-authorize", func() int64 { metrics.AddPolicyCountCallback("pomerium-authorize", func() int64 {
return int64(len(opts.GetAllPolicies())) return int64(opts.NumPolicies())
}) })
ctx := log.WithContext(context.Background(), func(c zerolog.Context) zerolog.Context { ctx := log.WithContext(context.Background(), func(c zerolog.Context) zerolog.Context {
return c.Str("service", "authorize") return c.Str("service", "authorize")
@ -131,8 +132,9 @@ func newPolicyEvaluator(
"authorize: internal error: couldn't build client cert constraints: %w", err) "authorize: internal error: couldn't build client cert constraints: %w", err)
} }
allPolicies := slices.Collect(opts.GetAllPolicies())
return evaluator.New(ctx, store, previous, return evaluator.New(ctx, store, previous,
evaluator.WithPolicies(opts.GetAllPolicies()), evaluator.WithPolicies(allPolicies),
evaluator.WithClientCA(clientCA), evaluator.WithClientCA(clientCA),
evaluator.WithAddDefaultClientCertificateRule(addDefaultClientCertificateRule), evaluator.WithAddDefaultClientCertificateRule(addDefaultClientCertificateRule),
evaluator.WithClientCRL(clientCRL), evaluator.WithClientCRL(clientCRL),

View file

@ -6,7 +6,7 @@ import (
) )
type evaluatorConfig struct { type evaluatorConfig struct {
Policies []config.Policy `hash:"-"` Policies []*config.Policy `hash:"-"`
ClientCA []byte ClientCA []byte
ClientCRL []byte ClientCRL []byte
AddDefaultClientCertificateRule bool AddDefaultClientCertificateRule bool
@ -34,7 +34,7 @@ func getConfig(options ...Option) *evaluatorConfig {
} }
// WithPolicies sets the policies in the config. // WithPolicies sets the policies in the config.
func WithPolicies(policies []config.Policy) Option { func WithPolicies(policies []*config.Policy) Option {
return func(cfg *evaluatorConfig) { return func(cfg *evaluatorConfig) {
cfg.Policies = policies cfg.Policies = policies
} }

View file

@ -168,7 +168,7 @@ func getOrCreatePolicyEvaluators(
continue continue
} }
builders = append(builders, func(ctx context.Context) (*routeEvaluator, error) { builders = append(builders, func(ctx context.Context) (*routeEvaluator, error) {
evaluator, err := NewPolicyEvaluator(ctx, store, &configPolicy, cfg.AddDefaultClientCertificateRule) evaluator, err := NewPolicyEvaluator(ctx, store, configPolicy, cfg.AddDefaultClientCertificateRule)
if err != nil { if err != nil {
return nil, fmt.Errorf("authorize: error building evaluator for route id=%s: %w", configPolicy.ID, err) return nil, fmt.Errorf("authorize: error building evaluator for route id=%s: %w", configPolicy.ID, err)
} }

View file

@ -41,7 +41,7 @@ func TestEvaluator(t *testing.T) {
return e.Evaluate(ctx, req) return e.Evaluate(ctx, req)
} }
policies := []config.Policy{ policies := []*config.Policy{
{ {
To: config.WeightedURLs{{URL: *mustParseURL("https://to1.example.com")}}, To: config.WeightedURLs{{URL: *mustParseURL("https://to1.example.com")}},
AllowPublicUnauthenticatedAccess: true, AllowPublicUnauthenticatedAccess: true,
@ -145,14 +145,14 @@ func TestEvaluator(t *testing.T) {
WithAddDefaultClientCertificateRule(true)) WithAddDefaultClientCertificateRule(true))
t.Run("missing", func(t *testing.T) { t.Run("missing", func(t *testing.T) {
res, err := eval(t, options, nil, &Request{ res, err := eval(t, options, nil, &Request{
Policy: &policies[0], Policy: policies[0],
}) })
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, NewRuleResult(true, criteria.ReasonClientCertificateRequired), res.Deny) assert.Equal(t, NewRuleResult(true, criteria.ReasonClientCertificateRequired), res.Deny)
}) })
t.Run("invalid", func(t *testing.T) { t.Run("invalid", func(t *testing.T) {
res, err := eval(t, options, nil, &Request{ res, err := eval(t, options, nil, &Request{
Policy: &policies[0], Policy: policies[0],
HTTP: RequestHTTP{ HTTP: RequestHTTP{
ClientCertificate: ClientCertificateInfo{Presented: true}, ClientCertificate: ClientCertificateInfo{Presented: true},
}, },
@ -162,7 +162,7 @@ func TestEvaluator(t *testing.T) {
}) })
t.Run("valid", func(t *testing.T) { t.Run("valid", func(t *testing.T) {
res, err := eval(t, options, nil, &Request{ res, err := eval(t, options, nil, &Request{
Policy: &policies[0], Policy: policies[0],
HTTP: RequestHTTP{ HTTP: RequestHTTP{
ClientCertificate: validCertInfo, ClientCertificate: validCertInfo,
}, },
@ -177,14 +177,14 @@ func TestEvaluator(t *testing.T) {
options = append(options, WithAddDefaultClientCertificateRule(true)) options = append(options, WithAddDefaultClientCertificateRule(true))
t.Run("missing", func(t *testing.T) { t.Run("missing", func(t *testing.T) {
res, err := eval(t, options, nil, &Request{ res, err := eval(t, options, nil, &Request{
Policy: &policies[10], Policy: policies[10],
}) })
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, NewRuleResult(true, criteria.ReasonClientCertificateRequired), res.Deny) assert.Equal(t, NewRuleResult(true, criteria.ReasonClientCertificateRequired), res.Deny)
}) })
t.Run("invalid", func(t *testing.T) { t.Run("invalid", func(t *testing.T) {
res, err := eval(t, options, nil, &Request{ res, err := eval(t, options, nil, &Request{
Policy: &policies[10], Policy: policies[10],
HTTP: RequestHTTP{ HTTP: RequestHTTP{
ClientCertificate: ClientCertificateInfo{ ClientCertificate: ClientCertificateInfo{
Presented: true, Presented: true,
@ -197,7 +197,7 @@ func TestEvaluator(t *testing.T) {
}) })
t.Run("valid", func(t *testing.T) { t.Run("valid", func(t *testing.T) {
res, err := eval(t, options, nil, &Request{ res, err := eval(t, options, nil, &Request{
Policy: &policies[10], Policy: policies[10],
HTTP: RequestHTTP{ HTTP: RequestHTTP{
ClientCertificate: validCertInfo, ClientCertificate: validCertInfo,
}, },
@ -213,7 +213,7 @@ func TestEvaluator(t *testing.T) {
options = append(options, WithClientCA([]byte(testCA))) options = append(options, WithClientCA([]byte(testCA)))
t.Run("invalid but allowed", func(t *testing.T) { t.Run("invalid but allowed", func(t *testing.T) {
res, err := eval(t, options, nil, &Request{ res, err := eval(t, options, nil, &Request{
Policy: &policies[0], // no explicit deny rule Policy: policies[0], // no explicit deny rule
HTTP: RequestHTTP{ HTTP: RequestHTTP{
ClientCertificate: ClientCertificateInfo{ ClientCertificate: ClientCertificateInfo{
Presented: true, Presented: true,
@ -226,7 +226,7 @@ func TestEvaluator(t *testing.T) {
}) })
t.Run("invalid", func(t *testing.T) { t.Run("invalid", func(t *testing.T) {
res, err := eval(t, options, nil, &Request{ res, err := eval(t, options, nil, &Request{
Policy: &policies[11], // policy has explicit deny rule Policy: policies[11], // policy has explicit deny rule
HTTP: RequestHTTP{ HTTP: RequestHTTP{
ClientCertificate: ClientCertificateInfo{ ClientCertificate: ClientCertificateInfo{
Presented: true, Presented: true,
@ -250,7 +250,7 @@ func TestEvaluator(t *testing.T) {
Email: "a@example.com", Email: "a@example.com",
}, },
}, &Request{ }, &Request{
Policy: &policies[1], Policy: policies[1],
Session: RequestSession{ Session: RequestSession{
ID: "session1", ID: "session1",
}, },
@ -274,7 +274,7 @@ func TestEvaluator(t *testing.T) {
Email: "a@example.com", Email: "a@example.com",
}, },
}, &Request{ }, &Request{
Policy: &policies[2], Policy: policies[2],
Session: RequestSession{ Session: RequestSession{
ID: "session1", ID: "session1",
}, },
@ -300,7 +300,7 @@ func TestEvaluator(t *testing.T) {
Email: "a@example.com", Email: "a@example.com",
}, },
}, &Request{ }, &Request{
Policy: &policies[3], Policy: policies[3],
Session: RequestSession{ Session: RequestSession{
ID: "session1", ID: "session1",
}, },
@ -323,7 +323,7 @@ func TestEvaluator(t *testing.T) {
Email: "a@example.com", Email: "a@example.com",
}, },
}, &Request{ }, &Request{
Policy: &policies[4], Policy: policies[4],
Session: RequestSession{ Session: RequestSession{
ID: "session1", ID: "session1",
}, },
@ -346,7 +346,7 @@ func TestEvaluator(t *testing.T) {
Email: "b@example.com", Email: "b@example.com",
}, },
}, &Request{ }, &Request{
Policy: &policies[3], Policy: policies[3],
Session: RequestSession{ Session: RequestSession{
ID: "session1", ID: "session1",
}, },
@ -376,7 +376,7 @@ func TestEvaluator(t *testing.T) {
Email: "a@example.com", Email: "a@example.com",
}, },
}, &Request{ }, &Request{
Policy: &policies[3], Policy: policies[3],
Session: RequestSession{ Session: RequestSession{
ID: "session2", ID: "session2",
}, },
@ -400,7 +400,7 @@ func TestEvaluator(t *testing.T) {
Email: "a@example.com", Email: "a@example.com",
}, },
}, &Request{ }, &Request{
Policy: &policies[5], Policy: policies[5],
Session: RequestSession{ Session: RequestSession{
ID: "session1", ID: "session1",
}, },
@ -423,7 +423,7 @@ func TestEvaluator(t *testing.T) {
Email: "a@example.com", Email: "a@example.com",
}, },
}, &Request{ }, &Request{
Policy: &policies[6], Policy: policies[6],
Session: RequestSession{ Session: RequestSession{
ID: "session1", ID: "session1",
}, },
@ -451,7 +451,7 @@ func TestEvaluator(t *testing.T) {
Email: "a@example.com", Email: "a@example.com",
}, },
}, &Request{ }, &Request{
Policy: &policies[6], Policy: policies[6],
Session: RequestSession{ Session: RequestSession{
ID: "session1", ID: "session1",
}, },
@ -473,7 +473,7 @@ func TestEvaluator(t *testing.T) {
Id: "user1", Id: "user1",
}, },
}, &Request{ }, &Request{
Policy: &policies[7], Policy: policies[7],
Session: RequestSession{ Session: RequestSession{
ID: "session1", ID: "session1",
}, },
@ -509,7 +509,7 @@ func TestEvaluator(t *testing.T) {
Id: "user1", Id: "user1",
}, },
}, &Request{ }, &Request{
Policy: &policies[8], Policy: policies[8],
Session: RequestSession{ Session: RequestSession{
ID: "session1", ID: "session1",
}, },
@ -526,7 +526,7 @@ func TestEvaluator(t *testing.T) {
}) })
t.Run("http method", func(t *testing.T) { t.Run("http method", func(t *testing.T) {
res, err := eval(t, options, []proto.Message{}, &Request{ res, err := eval(t, options, []proto.Message{}, &Request{
Policy: &policies[8], Policy: policies[8],
HTTP: NewRequestHTTP( HTTP: NewRequestHTTP(
http.MethodGet, http.MethodGet,
*mustParseURL("https://from.example.com/"), *mustParseURL("https://from.example.com/"),
@ -540,7 +540,7 @@ func TestEvaluator(t *testing.T) {
}) })
t.Run("http path", func(t *testing.T) { t.Run("http path", func(t *testing.T) {
res, err := eval(t, options, []proto.Message{}, &Request{ res, err := eval(t, options, []proto.Message{}, &Request{
Policy: &policies[9], Policy: policies[9],
HTTP: NewRequestHTTP( HTTP: NewRequestHTTP(
"POST", "POST",
*mustParseURL("https://from.example.com/test"), *mustParseURL("https://from.example.com/test"),
@ -559,7 +559,7 @@ func TestPolicyEvaluatorReuse(t *testing.T) {
store := store.New() store := store.New()
policies := []config.Policy{ policies := []*config.Policy{
{To: singleToURL("https://to1.example.com")}, {To: singleToURL("https://to1.example.com")},
{To: singleToURL("https://to2.example.com")}, {To: singleToURL("https://to2.example.com")},
{To: singleToURL("https://to3.example.com")}, {To: singleToURL("https://to3.example.com")},
@ -600,7 +600,7 @@ func TestPolicyEvaluatorReuse(t *testing.T) {
e, err := New(ctx, store, initial, options...) e, err := New(ctx, store, initial, options...)
require.NoError(t, err) require.NoError(t, err)
for i := range policies { for i := range policies {
assertPolicyEvaluatorReused(t, e, &policies[i]) assertPolicyEvaluatorReused(t, e, policies[i])
} }
}) })
@ -608,7 +608,7 @@ func TestPolicyEvaluatorReuse(t *testing.T) {
e, err := New(ctx, store, initial, append(options, o)...) e, err := New(ctx, store, initial, append(options, o)...)
require.NoError(t, err) require.NoError(t, err)
for i := range policies { for i := range policies {
assertPolicyEvaluatorUpdated(t, e, &policies[i]) assertPolicyEvaluatorUpdated(t, e, policies[i])
} }
} }
@ -647,7 +647,7 @@ func TestPolicyEvaluatorReuse(t *testing.T) {
// identical, only evaluators for the changed policies should be updated. // identical, only evaluators for the changed policies should be updated.
t.Run("policies changed", func(t *testing.T) { t.Run("policies changed", func(t *testing.T) {
// Make changes to some of the policies. // Make changes to some of the policies.
newPolicies := []config.Policy{ newPolicies := []*config.Policy{
{To: singleToURL("https://to1.example.com")}, {To: singleToURL("https://to1.example.com")},
{ {
To: singleToURL("https://to2.example.com"), To: singleToURL("https://to2.example.com"),
@ -662,9 +662,9 @@ func TestPolicyEvaluatorReuse(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
// Only the first and the third policy evaluators should be reused. // Only the first and the third policy evaluators should be reused.
assertPolicyEvaluatorReused(t, e, &newPolicies[0]) assertPolicyEvaluatorReused(t, e, newPolicies[0])
assertPolicyEvaluatorUpdated(t, e, &newPolicies[1]) assertPolicyEvaluatorUpdated(t, e, newPolicies[1])
assertPolicyEvaluatorReused(t, e, &newPolicies[2]) assertPolicyEvaluatorReused(t, e, newPolicies[2])
// The last policy shouldn't correspond with any of the initial policy // The last policy shouldn't correspond with any of the initial policy
// evaluators. // evaluators.

View file

@ -121,10 +121,10 @@ func (a *Authorize) getEvaluatorRequestFromCheckRequest(
func (a *Authorize) getMatchingPolicy(routeID uint64) *config.Policy { func (a *Authorize) getMatchingPolicy(routeID uint64) *config.Policy {
options := a.currentOptions.Load() options := a.currentOptions.Load()
for _, p := range options.GetAllPolicies() { for p := range options.GetAllPolicies() {
id, _ := p.RouteID() id, _ := p.RouteID()
if id == routeID { if id == routeID {
return &p return p
} }
} }

View file

@ -48,7 +48,7 @@ func (s *Store) UpdateJWTClaimHeaders(jwtClaimHeaders map[string]string) {
} }
// UpdateRoutePolicies updates the route policies in the store. // UpdateRoutePolicies updates the route policies in the store.
func (s *Store) UpdateRoutePolicies(routePolicies []config.Policy) { func (s *Store) UpdateRoutePolicies(routePolicies []*config.Policy) {
s.write("/route_policies", routePolicies) s.write("/route_policies", routePolicies)
} }

View file

@ -112,15 +112,11 @@ func (b *Builder) BuildClusters(ctx context.Context, cfg *config.Config) ([]*env
} }
if config.IsProxy(cfg.Options.Services) { if config.IsProxy(cfg.Options.Services) {
for i, p := range cfg.Options.GetAllPolicies() { for policy := range cfg.Options.GetAllPolicies() {
policy := p
if policy.EnvoyOpts == nil {
policy.EnvoyOpts = newDefaultEnvoyClusterConfig()
}
if len(policy.To) > 0 { if len(policy.To) > 0 {
cluster, err := b.buildPolicyCluster(ctx, cfg, &policy) cluster, err := b.buildPolicyCluster(ctx, cfg, policy)
if err != nil { if err != nil {
return nil, fmt.Errorf("policy #%d: %w", i, err) return nil, fmt.Errorf("policy %q: %w", policy.String(), err)
} }
clusters = append(clusters, cluster) clusters = append(clusters, cluster)
} }
@ -168,8 +164,12 @@ func (b *Builder) buildInternalCluster(
} }
func (b *Builder) buildPolicyCluster(ctx context.Context, cfg *config.Config, policy *config.Policy) (*envoy_config_cluster_v3.Cluster, error) { func (b *Builder) buildPolicyCluster(ctx context.Context, cfg *config.Config, policy *config.Policy) (*envoy_config_cluster_v3.Cluster, error) {
cluster := new(envoy_config_cluster_v3.Cluster) var cluster *envoy_config_cluster_v3.Cluster
proto.Merge(cluster, policy.EnvoyOpts) if policy.EnvoyOpts != nil {
cluster = proto.Clone(policy.EnvoyOpts).(*envoy_config_cluster_v3.Cluster)
} else {
cluster = newDefaultEnvoyClusterConfig()
}
options := cfg.Options options := cfg.Options

View file

@ -631,9 +631,7 @@ func clientCABundle(ctx context.Context, cfg *config.Config) []byte {
var bundle bytes.Buffer var bundle bytes.Buffer
ca, _ := cfg.Options.DownstreamMTLS.GetCA() ca, _ := cfg.Options.DownstreamMTLS.GetCA()
addCAToBundle(&bundle, ca) addCAToBundle(&bundle, ca)
allPolicies := cfg.Options.GetAllPolicies() for p := range cfg.Options.GetAllPolicies() {
for i := range allPolicies {
p := &allPolicies[i]
// We don't need to check TLSDownstreamClientCAFile here because // We don't need to check TLSDownstreamClientCAFile here because
// Policy.Validate() will populate TLSDownstreamClientCA when // Policy.Validate() will populate TLSDownstreamClientCA when
// TLSDownstreamClientCAFile is set. // TLSDownstreamClientCAFile is set.

View file

@ -177,7 +177,7 @@ func (b *Builder) buildRoutesForPoliciesWithHost(
host string, host string,
) ([]*envoy_config_route_v3.Route, error) { ) ([]*envoy_config_route_v3.Route, error) {
var routes []*envoy_config_route_v3.Route var routes []*envoy_config_route_v3.Route
for i, p := range cfg.Options.GetAllPolicies() { for i, p := range cfg.Options.GetAllPoliciesIndexed() {
policy := p policy := p
fromURL, err := urlutil.ParseAndValidateURL(policy.From) fromURL, err := urlutil.ParseAndValidateURL(policy.From)
if err != nil { if err != nil {
@ -188,7 +188,7 @@ func (b *Builder) buildRoutesForPoliciesWithHost(
continue continue
} }
policyRoutes, err := b.buildRoutesForPolicy(cfg, &policy, fmt.Sprintf("policy-%d", i)) policyRoutes, err := b.buildRoutesForPolicy(cfg, policy, fmt.Sprintf("policy-%d", i))
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -202,8 +202,7 @@ func (b *Builder) buildRoutesForPoliciesWithCatchAll(
cfg *config.Config, cfg *config.Config,
) ([]*envoy_config_route_v3.Route, error) { ) ([]*envoy_config_route_v3.Route, error) {
var routes []*envoy_config_route_v3.Route var routes []*envoy_config_route_v3.Route
for i, p := range cfg.Options.GetAllPolicies() { for i, policy := range cfg.Options.GetAllPoliciesIndexed() {
policy := p
fromURL, err := urlutil.ParseAndValidateURL(policy.From) fromURL, err := urlutil.ParseAndValidateURL(policy.From)
if err != nil { if err != nil {
return nil, err return nil, err
@ -213,7 +212,7 @@ func (b *Builder) buildRoutesForPoliciesWithCatchAll(
continue continue
} }
policyRoutes, err := b.buildRoutesForPolicy(cfg, &policy, fmt.Sprintf("policy-%d", i)) policyRoutes, err := b.buildRoutesForPolicy(cfg, policy, fmt.Sprintf("policy-%d", i))
if err != nil { if err != nil {
return nil, err return nil, err
} }

View file

@ -8,9 +8,8 @@ import (
// GetIdentityProviderForID returns the identity provider associated with the given IDP id. // GetIdentityProviderForID returns the identity provider associated with the given IDP id.
// If none is found the default provider is returned. // If none is found the default provider is returned.
func (o *Options) GetIdentityProviderForID(idpID string) (*identity.Provider, error) { func (o *Options) GetIdentityProviderForID(idpID string) (*identity.Provider, error) {
for _, p := range o.GetAllPolicies() { for p := range o.GetAllPolicies() {
p := p idp, err := o.GetIdentityProviderForPolicy(p)
idp, err := o.GetIdentityProviderForPolicy(&p)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -63,10 +62,9 @@ func (o *Options) GetIdentityProviderForRequestURL(requestURL string) (*identity
return nil, err return nil, err
} }
for _, p := range o.GetAllPolicies() { for p := range o.GetAllPolicies() {
p := p
if p.Matches(*u, o.IsRuntimeFlagSet(RuntimeFlagMatchAnyIncomingPort)) { if p.Matches(*u, o.IsRuntimeFlagSet(RuntimeFlagMatchAnyIncomingPort)) {
return o.GetIdentityProviderForPolicy(&p) return o.GetIdentityProviderForPolicy(p)
} }
} }
return o.GetIdentityProviderForPolicy(nil) return o.GetIdentityProviderForPolicy(nil)

View file

@ -8,6 +8,7 @@ import (
"encoding/base64" "encoding/base64"
"errors" "errors"
"fmt" "fmt"
"iter"
"net/http" "net/http"
"net/url" "net/url"
"os" "os"
@ -361,7 +362,7 @@ func newOptionsFromConfig(configFile string) (*Options, error) {
} }
serviceName := telemetry.ServiceName(o.Services) serviceName := telemetry.ServiceName(o.Services)
metrics.AddPolicyCountCallback(serviceName, func() int64 { metrics.AddPolicyCountCallback(serviceName, func() int64 {
return int64(len(o.GetAllPolicies())) return int64(o.NumPolicies())
}) })
return o, nil return o, nil
@ -979,15 +980,61 @@ func (o *Options) GetOauthOptions() (oauth.Options, error) {
} }
// GetAllPolicies gets all the policies in the options. // GetAllPolicies gets all the policies in the options.
func (o *Options) GetAllPolicies() []Policy { func (o *Options) GetAllPolicies() iter.Seq[*Policy] {
if o == nil { return func(yield func(*Policy) bool) {
return nil if o == nil {
return
}
for i := range len(o.Policies) {
if !yield(&o.Policies[i]) {
return
}
}
for i := range len(o.Routes) {
if !yield(&o.Routes[i]) {
return
}
}
for i := range len(o.AdditionalPolicies) {
if !yield(&o.AdditionalPolicies[i]) {
return
}
}
} }
policies := make([]Policy, 0, len(o.Policies)+len(o.Routes)+len(o.AdditionalPolicies)) }
policies = append(policies, o.Policies...)
policies = append(policies, o.Routes...) // GetAllPolicies gets all the policies in the options.
policies = append(policies, o.AdditionalPolicies...) func (o *Options) GetAllPoliciesIndexed() iter.Seq2[int, *Policy] {
return policies return func(yield func(int, *Policy) bool) {
if o == nil {
return
}
index := 0
nextIndex := func() int {
i := index
index++
return i
}
for i := range len(o.Policies) {
if !yield(nextIndex(), &o.Policies[i]) {
return
}
}
for i := range len(o.Routes) {
if !yield(nextIndex(), &o.Routes[i]) {
return
}
}
for i := range len(o.AdditionalPolicies) {
if !yield(nextIndex(), &o.AdditionalPolicies[i]) {
return
}
}
}
}
func (o *Options) NumPolicies() int {
return len(o.Policies) + len(o.Routes) + len(o.AdditionalPolicies)
} }
// GetMetricsBasicAuth gets the metrics basic auth username and password. // GetMetricsBasicAuth gets the metrics basic auth username and password.
@ -1017,12 +1064,11 @@ func (o *Options) HasAnyDownstreamMTLSClientCA() bool {
if len(ca) > 0 { if len(ca) > 0 {
return true return true
} }
allPolicies := o.GetAllPolicies() for p := range o.GetAllPolicies() {
for i := range allPolicies {
// We don't need to check TLSDownstreamClientCAFile here because // We don't need to check TLSDownstreamClientCAFile here because
// Policy.Validate() will populate TLSDownstreamClientCA when // Policy.Validate() will populate TLSDownstreamClientCA when
// TLSDownstreamClientCAFile is set. // TLSDownstreamClientCAFile is set.
if allPolicies[i].TLSDownstreamClientCA != "" { if p.TLSDownstreamClientCA != "" {
return true return true
} }
} }
@ -1273,7 +1319,7 @@ func (o *Options) GetAllRouteableHTTPHosts() ([]string, error) {
// policy urls // policy urls
if IsProxy(o.Services) { if IsProxy(o.Services) {
for _, policy := range o.GetAllPolicies() { for policy := range o.GetAllPolicies() {
fromURL, err := urlutil.ParseAndValidateURL(policy.From) fromURL, err := urlutil.ParseAndValidateURL(policy.From)
if err != nil { if err != nil {
return nil, err return nil, err

View file

@ -599,6 +599,14 @@ func (p *Policy) RouteID() (uint64, error) {
return hashutil.Hash(id) return hashutil.Hash(id)
} }
func (p *Policy) MustRouteID() uint64 {
id, err := p.RouteID()
if err != nil {
panic(err)
}
return id
}
func (p *Policy) String() string { func (p *Policy) String() string {
to := "?" to := "?"
if len(p.To) > 0 { if len(p.To) > 0 {

View file

@ -471,14 +471,12 @@ func configureTrustedRoots(acmeMgr *certmagic.ACMEIssuer, opts config.AutocertOp
} }
func sourceHostnames(cfg *config.Config) []string { func sourceHostnames(cfg *config.Config) []string {
policies := cfg.Options.GetAllPolicies() if cfg.Options.NumPolicies() == 0 {
if len(policies) == 0 {
return nil return nil
} }
dedupe := map[string]struct{}{} dedupe := map[string]struct{}{}
for _, p := range policies { for p := range cfg.Options.GetAllPolicies() {
if u, _ := urlutil.ParseAndValidateURL(p.From); u != nil && !strings.Contains(u.Host, "*") { if u, _ := urlutil.ParseAndValidateURL(p.From); u != nil && !strings.Contains(u.Host, "*") {
dedupe[u.Hostname()] = struct{}{} dedupe[u.Hostname()] = struct{}{}
} }

View file

@ -199,8 +199,8 @@ func (src *ConfigSource) buildPolicyFromProto(_ context.Context, routepb *config
} }
func (src *ConfigSource) addPolicies(ctx context.Context, cfg *config.Config, policies []*config.Policy) { func (src *ConfigSource) addPolicies(ctx context.Context, cfg *config.Config, policies []*config.Policy) {
seen := make(map[uint64]struct{}) seen := make(map[uint64]struct{}, len(policies)+cfg.Options.NumPolicies())
for _, policy := range cfg.Options.GetAllPolicies() { for policy := range cfg.Options.GetAllPolicies() {
id, err := policy.RouteID() id, err := policy.RouteID()
if err != nil { if err != nil {
log.Ctx(ctx).Err(err).Str("policy", policy.String()).Msg("databroker: error getting route id") log.Ctx(ctx).Err(err).Str("policy", policy.String()).Msg("databroker: error getting route id")
@ -209,7 +209,7 @@ func (src *ConfigSource) addPolicies(ctx context.Context, cfg *config.Config, po
seen[id] = struct{}{} seen[id] = struct{}{}
} }
var additionalPolicies []config.Policy additionalPolicies := make([]config.Policy, 0, len(policies))
for _, policy := range policies { for _, policy := range policies {
if policy == nil { if policy == nil {
continue continue

View file

@ -29,13 +29,13 @@ type Handler struct {
mu sync.RWMutex mu sync.RWMutex
key []byte key []byte
options *config.Options options *config.Options
policies map[uint64]config.Policy policies map[uint64]*config.Policy
} }
// New creates a new Handler. // New creates a new Handler.
func New() *Handler { func New() *Handler {
h := new(Handler) h := new(Handler)
h.policies = make(map[uint64]config.Policy) h.policies = make(map[uint64]*config.Policy)
return h return h
} }
@ -120,7 +120,7 @@ func (h *Handler) Middleware(next http.Handler) http.Handler {
h := stdhttputil.NewSingleHostReverseProxy(&dst) h := stdhttputil.NewSingleHostReverseProxy(&dst)
h.ErrorLog = stdlog.New(log.Logger(), "", 0) h.ErrorLog = stdlog.New(log.Logger(), "", 0)
h.Transport = config.NewPolicyHTTPTransport(options, &policy, disableHTTP2) h.Transport = config.NewPolicyHTTPTransport(options, policy, disableHTTP2)
h.ServeHTTP(w, r) h.ServeHTTP(w, r)
return nil return nil
}) })
@ -133,8 +133,8 @@ func (h *Handler) Update(ctx context.Context, cfg *config.Config) {
h.key, _ = cfg.Options.GetSharedKey() h.key, _ = cfg.Options.GetSharedKey()
h.options = cfg.Options h.options = cfg.Options
h.policies = make(map[uint64]config.Policy) h.policies = make(map[uint64]*config.Policy, cfg.Options.NumPolicies())
for _, p := range cfg.Options.GetAllPolicies() { for p := range cfg.Options.GetAllPolicies() {
id, err := p.RouteID() id, err := p.RouteID()
if err != nil { if err != nil {
log.Warn(ctx).Err(err).Msg("reproxy: error getting route id") log.Warn(ctx).Err(err).Msg("reproxy: error getting route id")

View file

@ -74,7 +74,7 @@ func New(cfg *config.Config) (*Proxy, error) {
p.webauthn = webauthn.New(p.getWebauthnState) p.webauthn = webauthn.New(p.getWebauthnState)
metrics.AddPolicyCountCallback("pomerium-proxy", func() int64 { metrics.AddPolicyCountCallback("pomerium-proxy", func() int64 {
return int64(len(p.currentOptions.Load().GetAllPolicies())) return int64(p.currentOptions.Load().NumPolicies())
}) })
return p, nil return p, nil
@ -103,7 +103,7 @@ func (p *Proxy) OnConfigChange(_ context.Context, cfg *config.Config) {
} }
func (p *Proxy) setHandlers(opts *config.Options) error { func (p *Proxy) setHandlers(opts *config.Options) error {
if len(opts.GetAllPolicies()) == 0 { if opts.NumPolicies() == 0 {
log.Warn(context.TODO()).Msg("proxy: configuration has no policies") log.Warn(context.TODO()).Msg("proxy: configuration has no policies")
} }
r := httputil.NewRouter() r := httputil.NewRouter()