diff --git a/authorize/authorize.go b/authorize/authorize.go index a23403c30..df9d6d630 100644 --- a/authorize/authorize.go +++ b/authorize/authorize.go @@ -5,6 +5,7 @@ package authorize import ( "context" "fmt" + "slices" "sync" "time" @@ -91,7 +92,7 @@ func newPolicyEvaluator( opts *config.Options, store *store.Store, previous *evaluator.Evaluator, ) (*evaluator.Evaluator, error) { metrics.AddPolicyCountCallback("pomerium-authorize", func() int64 { - return int64(len(opts.GetAllPolicies())) + return int64(opts.NumPolicies()) }) ctx := log.WithContext(context.Background(), func(c zerolog.Context) zerolog.Context { return c.Str("service", "authorize") @@ -131,8 +132,9 @@ func newPolicyEvaluator( "authorize: internal error: couldn't build client cert constraints: %w", err) } + allPolicies := slices.Collect(opts.GetAllPolicies()) return evaluator.New(ctx, store, previous, - evaluator.WithPolicies(opts.GetAllPolicies()), + evaluator.WithPolicies(allPolicies), evaluator.WithClientCA(clientCA), evaluator.WithAddDefaultClientCertificateRule(addDefaultClientCertificateRule), evaluator.WithClientCRL(clientCRL), diff --git a/authorize/evaluator/config.go b/authorize/evaluator/config.go index 2c3e138aa..b6a3416b3 100644 --- a/authorize/evaluator/config.go +++ b/authorize/evaluator/config.go @@ -6,7 +6,7 @@ import ( ) type evaluatorConfig struct { - Policies []config.Policy `hash:"-"` + Policies []*config.Policy `hash:"-"` ClientCA []byte ClientCRL []byte AddDefaultClientCertificateRule bool @@ -34,7 +34,7 @@ func getConfig(options ...Option) *evaluatorConfig { } // WithPolicies sets the policies in the config. -func WithPolicies(policies []config.Policy) Option { +func WithPolicies(policies []*config.Policy) Option { return func(cfg *evaluatorConfig) { cfg.Policies = policies } diff --git a/authorize/evaluator/evaluator.go b/authorize/evaluator/evaluator.go index 243093ed6..8e4405f09 100644 --- a/authorize/evaluator/evaluator.go +++ b/authorize/evaluator/evaluator.go @@ -168,7 +168,7 @@ func getOrCreatePolicyEvaluators( continue } builders = append(builders, func(ctx context.Context) (*routeEvaluator, error) { - evaluator, err := NewPolicyEvaluator(ctx, store, &configPolicy, cfg.AddDefaultClientCertificateRule) + evaluator, err := NewPolicyEvaluator(ctx, store, configPolicy, cfg.AddDefaultClientCertificateRule) if err != nil { return nil, fmt.Errorf("authorize: error building evaluator for route id=%s: %w", configPolicy.ID, err) } diff --git a/authorize/evaluator/evaluator_test.go b/authorize/evaluator/evaluator_test.go index 588bd1a33..40d44daf4 100644 --- a/authorize/evaluator/evaluator_test.go +++ b/authorize/evaluator/evaluator_test.go @@ -41,7 +41,7 @@ func TestEvaluator(t *testing.T) { return e.Evaluate(ctx, req) } - policies := []config.Policy{ + policies := []*config.Policy{ { To: config.WeightedURLs{{URL: *mustParseURL("https://to1.example.com")}}, AllowPublicUnauthenticatedAccess: true, @@ -145,14 +145,14 @@ func TestEvaluator(t *testing.T) { WithAddDefaultClientCertificateRule(true)) t.Run("missing", func(t *testing.T) { res, err := eval(t, options, nil, &Request{ - Policy: &policies[0], + Policy: policies[0], }) require.NoError(t, err) assert.Equal(t, NewRuleResult(true, criteria.ReasonClientCertificateRequired), res.Deny) }) t.Run("invalid", func(t *testing.T) { res, err := eval(t, options, nil, &Request{ - Policy: &policies[0], + Policy: policies[0], HTTP: RequestHTTP{ ClientCertificate: ClientCertificateInfo{Presented: true}, }, @@ -162,7 +162,7 @@ func TestEvaluator(t *testing.T) { }) t.Run("valid", func(t *testing.T) { res, err := eval(t, options, nil, &Request{ - Policy: &policies[0], + Policy: policies[0], HTTP: RequestHTTP{ ClientCertificate: validCertInfo, }, @@ -177,14 +177,14 @@ func TestEvaluator(t *testing.T) { options = append(options, WithAddDefaultClientCertificateRule(true)) t.Run("missing", func(t *testing.T) { res, err := eval(t, options, nil, &Request{ - Policy: &policies[10], + Policy: policies[10], }) require.NoError(t, err) assert.Equal(t, NewRuleResult(true, criteria.ReasonClientCertificateRequired), res.Deny) }) t.Run("invalid", func(t *testing.T) { res, err := eval(t, options, nil, &Request{ - Policy: &policies[10], + Policy: policies[10], HTTP: RequestHTTP{ ClientCertificate: ClientCertificateInfo{ Presented: true, @@ -197,7 +197,7 @@ func TestEvaluator(t *testing.T) { }) t.Run("valid", func(t *testing.T) { res, err := eval(t, options, nil, &Request{ - Policy: &policies[10], + Policy: policies[10], HTTP: RequestHTTP{ ClientCertificate: validCertInfo, }, @@ -213,7 +213,7 @@ func TestEvaluator(t *testing.T) { options = append(options, WithClientCA([]byte(testCA))) t.Run("invalid but allowed", func(t *testing.T) { res, err := eval(t, options, nil, &Request{ - Policy: &policies[0], // no explicit deny rule + Policy: policies[0], // no explicit deny rule HTTP: RequestHTTP{ ClientCertificate: ClientCertificateInfo{ Presented: true, @@ -226,7 +226,7 @@ func TestEvaluator(t *testing.T) { }) t.Run("invalid", func(t *testing.T) { res, err := eval(t, options, nil, &Request{ - Policy: &policies[11], // policy has explicit deny rule + Policy: policies[11], // policy has explicit deny rule HTTP: RequestHTTP{ ClientCertificate: ClientCertificateInfo{ Presented: true, @@ -250,7 +250,7 @@ func TestEvaluator(t *testing.T) { Email: "a@example.com", }, }, &Request{ - Policy: &policies[1], + Policy: policies[1], Session: RequestSession{ ID: "session1", }, @@ -274,7 +274,7 @@ func TestEvaluator(t *testing.T) { Email: "a@example.com", }, }, &Request{ - Policy: &policies[2], + Policy: policies[2], Session: RequestSession{ ID: "session1", }, @@ -300,7 +300,7 @@ func TestEvaluator(t *testing.T) { Email: "a@example.com", }, }, &Request{ - Policy: &policies[3], + Policy: policies[3], Session: RequestSession{ ID: "session1", }, @@ -323,7 +323,7 @@ func TestEvaluator(t *testing.T) { Email: "a@example.com", }, }, &Request{ - Policy: &policies[4], + Policy: policies[4], Session: RequestSession{ ID: "session1", }, @@ -346,7 +346,7 @@ func TestEvaluator(t *testing.T) { Email: "b@example.com", }, }, &Request{ - Policy: &policies[3], + Policy: policies[3], Session: RequestSession{ ID: "session1", }, @@ -376,7 +376,7 @@ func TestEvaluator(t *testing.T) { Email: "a@example.com", }, }, &Request{ - Policy: &policies[3], + Policy: policies[3], Session: RequestSession{ ID: "session2", }, @@ -400,7 +400,7 @@ func TestEvaluator(t *testing.T) { Email: "a@example.com", }, }, &Request{ - Policy: &policies[5], + Policy: policies[5], Session: RequestSession{ ID: "session1", }, @@ -423,7 +423,7 @@ func TestEvaluator(t *testing.T) { Email: "a@example.com", }, }, &Request{ - Policy: &policies[6], + Policy: policies[6], Session: RequestSession{ ID: "session1", }, @@ -451,7 +451,7 @@ func TestEvaluator(t *testing.T) { Email: "a@example.com", }, }, &Request{ - Policy: &policies[6], + Policy: policies[6], Session: RequestSession{ ID: "session1", }, @@ -473,7 +473,7 @@ func TestEvaluator(t *testing.T) { Id: "user1", }, }, &Request{ - Policy: &policies[7], + Policy: policies[7], Session: RequestSession{ ID: "session1", }, @@ -509,7 +509,7 @@ func TestEvaluator(t *testing.T) { Id: "user1", }, }, &Request{ - Policy: &policies[8], + Policy: policies[8], Session: RequestSession{ ID: "session1", }, @@ -526,7 +526,7 @@ func TestEvaluator(t *testing.T) { }) t.Run("http method", func(t *testing.T) { res, err := eval(t, options, []proto.Message{}, &Request{ - Policy: &policies[8], + Policy: policies[8], HTTP: NewRequestHTTP( http.MethodGet, *mustParseURL("https://from.example.com/"), @@ -540,7 +540,7 @@ func TestEvaluator(t *testing.T) { }) t.Run("http path", func(t *testing.T) { res, err := eval(t, options, []proto.Message{}, &Request{ - Policy: &policies[9], + Policy: policies[9], HTTP: NewRequestHTTP( "POST", *mustParseURL("https://from.example.com/test"), @@ -559,7 +559,7 @@ func TestPolicyEvaluatorReuse(t *testing.T) { store := store.New() - policies := []config.Policy{ + policies := []*config.Policy{ {To: singleToURL("https://to1.example.com")}, {To: singleToURL("https://to2.example.com")}, {To: singleToURL("https://to3.example.com")}, @@ -600,7 +600,7 @@ func TestPolicyEvaluatorReuse(t *testing.T) { e, err := New(ctx, store, initial, options...) require.NoError(t, err) for i := range policies { - assertPolicyEvaluatorReused(t, e, &policies[i]) + assertPolicyEvaluatorReused(t, e, policies[i]) } }) @@ -608,7 +608,7 @@ func TestPolicyEvaluatorReuse(t *testing.T) { e, err := New(ctx, store, initial, append(options, o)...) require.NoError(t, err) for i := range policies { - assertPolicyEvaluatorUpdated(t, e, &policies[i]) + assertPolicyEvaluatorUpdated(t, e, policies[i]) } } @@ -647,7 +647,7 @@ func TestPolicyEvaluatorReuse(t *testing.T) { // identical, only evaluators for the changed policies should be updated. t.Run("policies changed", func(t *testing.T) { // Make changes to some of the policies. - newPolicies := []config.Policy{ + newPolicies := []*config.Policy{ {To: singleToURL("https://to1.example.com")}, { To: singleToURL("https://to2.example.com"), @@ -662,9 +662,9 @@ func TestPolicyEvaluatorReuse(t *testing.T) { require.NoError(t, err) // Only the first and the third policy evaluators should be reused. - assertPolicyEvaluatorReused(t, e, &newPolicies[0]) - assertPolicyEvaluatorUpdated(t, e, &newPolicies[1]) - assertPolicyEvaluatorReused(t, e, &newPolicies[2]) + assertPolicyEvaluatorReused(t, e, newPolicies[0]) + assertPolicyEvaluatorUpdated(t, e, newPolicies[1]) + assertPolicyEvaluatorReused(t, e, newPolicies[2]) // The last policy shouldn't correspond with any of the initial policy // evaluators. diff --git a/authorize/grpc.go b/authorize/grpc.go index 85aafbb62..ff845e30b 100644 --- a/authorize/grpc.go +++ b/authorize/grpc.go @@ -121,10 +121,10 @@ func (a *Authorize) getEvaluatorRequestFromCheckRequest( func (a *Authorize) getMatchingPolicy(routeID uint64) *config.Policy { options := a.currentOptions.Load() - for _, p := range options.GetAllPolicies() { + for p := range options.GetAllPolicies() { id, _ := p.RouteID() if id == routeID { - return &p + return p } } diff --git a/authorize/internal/store/store.go b/authorize/internal/store/store.go index e773ed4c8..263261758 100644 --- a/authorize/internal/store/store.go +++ b/authorize/internal/store/store.go @@ -48,7 +48,7 @@ func (s *Store) UpdateJWTClaimHeaders(jwtClaimHeaders map[string]string) { } // UpdateRoutePolicies updates the route policies in the store. -func (s *Store) UpdateRoutePolicies(routePolicies []config.Policy) { +func (s *Store) UpdateRoutePolicies(routePolicies []*config.Policy) { s.write("/route_policies", routePolicies) } diff --git a/config/envoyconfig/clusters.go b/config/envoyconfig/clusters.go index 539055ee3..3c3d9ab59 100644 --- a/config/envoyconfig/clusters.go +++ b/config/envoyconfig/clusters.go @@ -112,15 +112,11 @@ func (b *Builder) BuildClusters(ctx context.Context, cfg *config.Config) ([]*env } if config.IsProxy(cfg.Options.Services) { - for i, p := range cfg.Options.GetAllPolicies() { - policy := p - if policy.EnvoyOpts == nil { - policy.EnvoyOpts = newDefaultEnvoyClusterConfig() - } + for policy := range cfg.Options.GetAllPolicies() { if len(policy.To) > 0 { - cluster, err := b.buildPolicyCluster(ctx, cfg, &policy) + cluster, err := b.buildPolicyCluster(ctx, cfg, policy) if err != nil { - return nil, fmt.Errorf("policy #%d: %w", i, err) + return nil, fmt.Errorf("policy %q: %w", policy.String(), err) } clusters = append(clusters, cluster) } @@ -168,8 +164,12 @@ func (b *Builder) buildInternalCluster( } func (b *Builder) buildPolicyCluster(ctx context.Context, cfg *config.Config, policy *config.Policy) (*envoy_config_cluster_v3.Cluster, error) { - cluster := new(envoy_config_cluster_v3.Cluster) - proto.Merge(cluster, policy.EnvoyOpts) + var cluster *envoy_config_cluster_v3.Cluster + if policy.EnvoyOpts != nil { + cluster = proto.Clone(policy.EnvoyOpts).(*envoy_config_cluster_v3.Cluster) + } else { + cluster = newDefaultEnvoyClusterConfig() + } options := cfg.Options diff --git a/config/envoyconfig/listeners.go b/config/envoyconfig/listeners.go index 9447ef23d..1cbf57c9e 100644 --- a/config/envoyconfig/listeners.go +++ b/config/envoyconfig/listeners.go @@ -631,9 +631,7 @@ func clientCABundle(ctx context.Context, cfg *config.Config) []byte { var bundle bytes.Buffer ca, _ := cfg.Options.DownstreamMTLS.GetCA() addCAToBundle(&bundle, ca) - allPolicies := cfg.Options.GetAllPolicies() - for i := range allPolicies { - p := &allPolicies[i] + for p := range cfg.Options.GetAllPolicies() { // We don't need to check TLSDownstreamClientCAFile here because // Policy.Validate() will populate TLSDownstreamClientCA when // TLSDownstreamClientCAFile is set. diff --git a/config/envoyconfig/routes.go b/config/envoyconfig/routes.go index f19e61832..685b04696 100644 --- a/config/envoyconfig/routes.go +++ b/config/envoyconfig/routes.go @@ -177,7 +177,7 @@ func (b *Builder) buildRoutesForPoliciesWithHost( host string, ) ([]*envoy_config_route_v3.Route, error) { var routes []*envoy_config_route_v3.Route - for i, p := range cfg.Options.GetAllPolicies() { + for i, p := range cfg.Options.GetAllPoliciesIndexed() { policy := p fromURL, err := urlutil.ParseAndValidateURL(policy.From) if err != nil { @@ -188,7 +188,7 @@ func (b *Builder) buildRoutesForPoliciesWithHost( continue } - policyRoutes, err := b.buildRoutesForPolicy(cfg, &policy, fmt.Sprintf("policy-%d", i)) + policyRoutes, err := b.buildRoutesForPolicy(cfg, policy, fmt.Sprintf("policy-%d", i)) if err != nil { return nil, err } @@ -202,8 +202,7 @@ func (b *Builder) buildRoutesForPoliciesWithCatchAll( cfg *config.Config, ) ([]*envoy_config_route_v3.Route, error) { var routes []*envoy_config_route_v3.Route - for i, p := range cfg.Options.GetAllPolicies() { - policy := p + for i, policy := range cfg.Options.GetAllPoliciesIndexed() { fromURL, err := urlutil.ParseAndValidateURL(policy.From) if err != nil { return nil, err @@ -213,7 +212,7 @@ func (b *Builder) buildRoutesForPoliciesWithCatchAll( continue } - policyRoutes, err := b.buildRoutesForPolicy(cfg, &policy, fmt.Sprintf("policy-%d", i)) + policyRoutes, err := b.buildRoutesForPolicy(cfg, policy, fmt.Sprintf("policy-%d", i)) if err != nil { return nil, err } diff --git a/config/identity.go b/config/identity.go index 5adfcfbe7..806a698ec 100644 --- a/config/identity.go +++ b/config/identity.go @@ -8,9 +8,8 @@ import ( // GetIdentityProviderForID returns the identity provider associated with the given IDP id. // If none is found the default provider is returned. func (o *Options) GetIdentityProviderForID(idpID string) (*identity.Provider, error) { - for _, p := range o.GetAllPolicies() { - p := p - idp, err := o.GetIdentityProviderForPolicy(&p) + for p := range o.GetAllPolicies() { + idp, err := o.GetIdentityProviderForPolicy(p) if err != nil { return nil, err } @@ -63,10 +62,9 @@ func (o *Options) GetIdentityProviderForRequestURL(requestURL string) (*identity return nil, err } - for _, p := range o.GetAllPolicies() { - p := p + for p := range o.GetAllPolicies() { if p.Matches(*u, o.IsRuntimeFlagSet(RuntimeFlagMatchAnyIncomingPort)) { - return o.GetIdentityProviderForPolicy(&p) + return o.GetIdentityProviderForPolicy(p) } } return o.GetIdentityProviderForPolicy(nil) diff --git a/config/options.go b/config/options.go index ef8abef4c..2e14245a5 100644 --- a/config/options.go +++ b/config/options.go @@ -8,6 +8,7 @@ import ( "encoding/base64" "errors" "fmt" + "iter" "net/http" "net/url" "os" @@ -361,7 +362,7 @@ func newOptionsFromConfig(configFile string) (*Options, error) { } serviceName := telemetry.ServiceName(o.Services) metrics.AddPolicyCountCallback(serviceName, func() int64 { - return int64(len(o.GetAllPolicies())) + return int64(o.NumPolicies()) }) return o, nil @@ -979,15 +980,61 @@ func (o *Options) GetOauthOptions() (oauth.Options, error) { } // GetAllPolicies gets all the policies in the options. -func (o *Options) GetAllPolicies() []Policy { - if o == nil { - return nil +func (o *Options) GetAllPolicies() iter.Seq[*Policy] { + return func(yield func(*Policy) bool) { + if o == nil { + return + } + for i := range len(o.Policies) { + if !yield(&o.Policies[i]) { + return + } + } + for i := range len(o.Routes) { + if !yield(&o.Routes[i]) { + return + } + } + for i := range len(o.AdditionalPolicies) { + if !yield(&o.AdditionalPolicies[i]) { + return + } + } } - policies := make([]Policy, 0, len(o.Policies)+len(o.Routes)+len(o.AdditionalPolicies)) - policies = append(policies, o.Policies...) - policies = append(policies, o.Routes...) - policies = append(policies, o.AdditionalPolicies...) - return policies +} + +// GetAllPolicies gets all the policies in the options. +func (o *Options) GetAllPoliciesIndexed() iter.Seq2[int, *Policy] { + return func(yield func(int, *Policy) bool) { + if o == nil { + return + } + index := 0 + nextIndex := func() int { + i := index + index++ + return i + } + for i := range len(o.Policies) { + if !yield(nextIndex(), &o.Policies[i]) { + return + } + } + for i := range len(o.Routes) { + if !yield(nextIndex(), &o.Routes[i]) { + return + } + } + for i := range len(o.AdditionalPolicies) { + if !yield(nextIndex(), &o.AdditionalPolicies[i]) { + return + } + } + } +} + +func (o *Options) NumPolicies() int { + return len(o.Policies) + len(o.Routes) + len(o.AdditionalPolicies) } // GetMetricsBasicAuth gets the metrics basic auth username and password. @@ -1017,12 +1064,11 @@ func (o *Options) HasAnyDownstreamMTLSClientCA() bool { if len(ca) > 0 { return true } - allPolicies := o.GetAllPolicies() - for i := range allPolicies { + for p := range o.GetAllPolicies() { // We don't need to check TLSDownstreamClientCAFile here because // Policy.Validate() will populate TLSDownstreamClientCA when // TLSDownstreamClientCAFile is set. - if allPolicies[i].TLSDownstreamClientCA != "" { + if p.TLSDownstreamClientCA != "" { return true } } @@ -1273,7 +1319,7 @@ func (o *Options) GetAllRouteableHTTPHosts() ([]string, error) { // policy urls if IsProxy(o.Services) { - for _, policy := range o.GetAllPolicies() { + for policy := range o.GetAllPolicies() { fromURL, err := urlutil.ParseAndValidateURL(policy.From) if err != nil { return nil, err diff --git a/config/policy.go b/config/policy.go index 5182b6e20..b010de581 100644 --- a/config/policy.go +++ b/config/policy.go @@ -599,6 +599,14 @@ func (p *Policy) RouteID() (uint64, error) { return hashutil.Hash(id) } +func (p *Policy) MustRouteID() uint64 { + id, err := p.RouteID() + if err != nil { + panic(err) + } + return id +} + func (p *Policy) String() string { to := "?" if len(p.To) > 0 { diff --git a/internal/autocert/manager.go b/internal/autocert/manager.go index 6a5c0796f..727e21ba0 100644 --- a/internal/autocert/manager.go +++ b/internal/autocert/manager.go @@ -471,14 +471,12 @@ func configureTrustedRoots(acmeMgr *certmagic.ACMEIssuer, opts config.AutocertOp } func sourceHostnames(cfg *config.Config) []string { - policies := cfg.Options.GetAllPolicies() - - if len(policies) == 0 { + if cfg.Options.NumPolicies() == 0 { return nil } dedupe := map[string]struct{}{} - for _, p := range policies { + for p := range cfg.Options.GetAllPolicies() { if u, _ := urlutil.ParseAndValidateURL(p.From); u != nil && !strings.Contains(u.Host, "*") { dedupe[u.Hostname()] = struct{}{} } diff --git a/internal/databroker/config_source.go b/internal/databroker/config_source.go index e60af7128..f86828c41 100644 --- a/internal/databroker/config_source.go +++ b/internal/databroker/config_source.go @@ -199,8 +199,8 @@ func (src *ConfigSource) buildPolicyFromProto(_ context.Context, routepb *config } func (src *ConfigSource) addPolicies(ctx context.Context, cfg *config.Config, policies []*config.Policy) { - seen := make(map[uint64]struct{}) - for _, policy := range cfg.Options.GetAllPolicies() { + seen := make(map[uint64]struct{}, len(policies)+cfg.Options.NumPolicies()) + for policy := range cfg.Options.GetAllPolicies() { id, err := policy.RouteID() if err != nil { log.Ctx(ctx).Err(err).Str("policy", policy.String()).Msg("databroker: error getting route id") @@ -209,7 +209,7 @@ func (src *ConfigSource) addPolicies(ctx context.Context, cfg *config.Config, po seen[id] = struct{}{} } - var additionalPolicies []config.Policy + additionalPolicies := make([]config.Policy, 0, len(policies)) for _, policy := range policies { if policy == nil { continue diff --git a/internal/httputil/reproxy/reproxy.go b/internal/httputil/reproxy/reproxy.go index 60ac5edec..01cafa803 100644 --- a/internal/httputil/reproxy/reproxy.go +++ b/internal/httputil/reproxy/reproxy.go @@ -29,13 +29,13 @@ type Handler struct { mu sync.RWMutex key []byte options *config.Options - policies map[uint64]config.Policy + policies map[uint64]*config.Policy } // New creates a new Handler. func New() *Handler { h := new(Handler) - h.policies = make(map[uint64]config.Policy) + h.policies = make(map[uint64]*config.Policy) return h } @@ -120,7 +120,7 @@ func (h *Handler) Middleware(next http.Handler) http.Handler { h := stdhttputil.NewSingleHostReverseProxy(&dst) h.ErrorLog = stdlog.New(log.Logger(), "", 0) - h.Transport = config.NewPolicyHTTPTransport(options, &policy, disableHTTP2) + h.Transport = config.NewPolicyHTTPTransport(options, policy, disableHTTP2) h.ServeHTTP(w, r) return nil }) @@ -133,8 +133,8 @@ func (h *Handler) Update(ctx context.Context, cfg *config.Config) { h.key, _ = cfg.Options.GetSharedKey() h.options = cfg.Options - h.policies = make(map[uint64]config.Policy) - for _, p := range cfg.Options.GetAllPolicies() { + h.policies = make(map[uint64]*config.Policy, cfg.Options.NumPolicies()) + for p := range cfg.Options.GetAllPolicies() { id, err := p.RouteID() if err != nil { log.Warn(ctx).Err(err).Msg("reproxy: error getting route id") diff --git a/proxy/proxy.go b/proxy/proxy.go index 8d3c78c18..9d9c8f308 100644 --- a/proxy/proxy.go +++ b/proxy/proxy.go @@ -74,7 +74,7 @@ func New(cfg *config.Config) (*Proxy, error) { p.webauthn = webauthn.New(p.getWebauthnState) metrics.AddPolicyCountCallback("pomerium-proxy", func() int64 { - return int64(len(p.currentOptions.Load().GetAllPolicies())) + return int64(p.currentOptions.Load().NumPolicies()) }) return p, nil @@ -103,7 +103,7 @@ func (p *Proxy) OnConfigChange(_ context.Context, cfg *config.Config) { } func (p *Proxy) setHandlers(opts *config.Options) error { - if len(opts.GetAllPolicies()) == 0 { + if opts.NumPolicies() == 0 { log.Warn(context.TODO()).Msg("proxy: configuration has no policies") } r := httputil.NewRouter()