diff --git a/authorize/authorize.go b/authorize/authorize.go index d3a2ff805..49839b4c6 100644 --- a/authorize/authorize.go +++ b/authorize/authorize.go @@ -69,7 +69,7 @@ func validateOptions(o *config.Options) error { // newPolicyEvaluator returns an policy evaluator. func newPolicyEvaluator(opts *config.Options, store *evaluator.Store) (*evaluator.Evaluator, error) { metrics.AddPolicyCountCallback("pomerium-authorize", func() int64 { - return int64(len(opts.Policies)) + return int64(len(opts.GetAllPolicies())) }) ctx := context.Background() _, span := trace.StartSpan(ctx, "authorize.newPolicyEvaluator") diff --git a/authorize/evaluator/evaluator.go b/authorize/evaluator/evaluator.go index dcdab1d61..f7e2ed112 100644 --- a/authorize/evaluator/evaluator.go +++ b/authorize/evaluator/evaluator.go @@ -52,7 +52,7 @@ func New(options *config.Options, store *Store) (*Evaluator, error) { e := &Evaluator{ custom: NewCustomEvaluator(store.opaStore), authenticateHost: options.AuthenticateURL.Host, - policies: options.Policies, + policies: options.GetAllPolicies(), } if options.ClientCA != "" { e.clientCA = options.ClientCA @@ -75,7 +75,7 @@ func New(options *config.Options, store *Store) (*Evaluator, error) { } store.UpdateAdmins(options.Administrators) - store.UpdateRoutePolicies(options.Policies) + store.UpdateRoutePolicies(options.GetAllPolicies()) e.rego = rego.New( rego.Store(store.opaStore), diff --git a/authorize/grpc.go b/authorize/grpc.go index 5a3c0db00..fcd033ac5 100644 --- a/authorize/grpc.go +++ b/authorize/grpc.go @@ -245,7 +245,7 @@ func (a *Authorize) getEvaluatorRequestFromCheckRequest(in *envoy_service_auth_v func (a *Authorize) getMatchingPolicy(requestURL *url.URL) *config.Policy { options := a.currentOptions.Load() - for _, p := range options.Policies { + for _, p := range options.GetAllPolicies() { if p.Matches(requestURL) { return &p } diff --git a/config/options.go b/config/options.go index d98ca3bb8..e8e341e91 100644 --- a/config/options.go +++ b/config/options.go @@ -112,6 +112,9 @@ type Options struct { PolicyEnv string `yaml:",omitempty"` PolicyFile string `mapstructure:"policy_file" yaml:"policy_file,omitempty"` + // AdditionalPolicies are any additional policies added to the options. + AdditionalPolicies []Policy `yaml:"-"` + // AuthenticateURL represents the externally accessible http endpoints // used for authentication requests and callbacks AuthenticateURLString string `mapstructure:"authenticate_service_url" yaml:"authenticate_service_url,omitempty"` @@ -336,7 +339,7 @@ func newOptionsFromConfig(configFile string) (*Options, error) { } serviceName := telemetry.ServiceName(o.Services) metrics.AddPolicyCountCallback(serviceName, func() int64 { - return int64(len(o.Policies)) + return int64(len(o.GetAllPolicies())) }) metrics.SetConfigChecksum(serviceName, o.Checksum()) @@ -404,6 +407,12 @@ func (o *Options) parsePolicy() error { return err } } + for i := range o.AdditionalPolicies { + p := &o.AdditionalPolicies[i] + if err := p.Validate(); err != nil { + return err + } + } return nil } @@ -654,7 +663,7 @@ func (o *Options) Validate() error { // assert group membership (except for azure which can be derived from the client // id, secret and provider url) if o.ServiceAccount == "" && o.Provider != "azure" { - for _, p := range o.Policies { + for _, p := range o.GetAllPolicies() { if len(p.AllowedGroups) != 0 { return fmt.Errorf("config: `allowed_groups` requires `idp_service_account`") } @@ -751,6 +760,17 @@ func (o *Options) GetOauthOptions() oauth.Options { } } +// GetAllPolicies gets all the policies in the options. +func (o *Options) GetAllPolicies() []Policy { + if o == nil { + return nil + } + policies := make([]Policy, 0, len(o.Policies)+len(o.AdditionalPolicies)) + policies = append(policies, o.Policies...) + policies = append(policies, o.AdditionalPolicies...) + return policies +} + // Checksum returns the checksum of the current options struct func (o *Options) Checksum() uint64 { return hashutil.MustHash(o) diff --git a/internal/autocert/manager.go b/internal/autocert/manager.go index d32ac267a..d1ac7be19 100644 --- a/internal/autocert/manager.go +++ b/internal/autocert/manager.go @@ -283,12 +283,14 @@ func (mgr *Manager) GetConfig() *config.Config { } func sourceHostnames(cfg *config.Config) []string { - if len(cfg.Options.Policies) == 0 { + policies := cfg.Options.GetAllPolicies() + + if len(policies) == 0 { return nil } dedupe := map[string]struct{}{} - for _, p := range cfg.Options.Policies { + for _, p := range policies { dedupe[p.Source.Hostname()] = struct{}{} } if cfg.Options.AuthenticateURL != nil { diff --git a/internal/controlplane/xds_clusters.go b/internal/controlplane/xds_clusters.go index ec19ee616..28c562a71 100644 --- a/internal/controlplane/xds_clusters.go +++ b/internal/controlplane/xds_clusters.go @@ -75,8 +75,8 @@ func (srv *Server) buildClusters(options *config.Options) ([]*envoy_config_clust } if config.IsProxy(options.Services) { - for i := range options.Policies { - policy := options.Policies[i] + for i, p := range options.GetAllPolicies() { + policy := p if policy.EnvoyOpts == nil { policy.EnvoyOpts = newDefaultEnvoyClusterConfig() } diff --git a/internal/controlplane/xds_listeners.go b/internal/controlplane/xds_listeners.go index aadfc3f74..a3765c061 100644 --- a/internal/controlplane/xds_listeners.go +++ b/internal/controlplane/xds_listeners.go @@ -439,7 +439,7 @@ func getAllRouteableDomains(options *config.Options, addr string) []string { } } if config.IsProxy(options.Services) && addr == options.Addr { - for _, policy := range options.Policies { + for _, policy := range options.GetAllPolicies() { for _, h := range urlutil.GetDomainsForURL(policy.Source.URL) { lookup[h] = struct{}{} } diff --git a/internal/controlplane/xds_routes.go b/internal/controlplane/xds_routes.go index 3bdf54adf..e5b1a3364 100644 --- a/internal/controlplane/xds_routes.go +++ b/internal/controlplane/xds_routes.go @@ -181,8 +181,8 @@ func buildPolicyRoutes(options *config.Options, domain string) []*envoy_config_r var routes []*envoy_config_route_v3.Route responseHeadersToAdd := toEnvoyHeaders(options.Headers) - for i := range options.Policies { - policy := options.Policies[i] + for i, p := range options.GetAllPolicies() { + policy := p if !hostMatchesDomain(policy.Source.URL, domain) { continue } @@ -445,7 +445,7 @@ func setHostRewriteOptions(policy *config.Policy, action *envoy_config_route_v3. } func hasPublicPolicyMatchingURL(options *config.Options, requestURL *url.URL) bool { - for _, policy := range options.Policies { + for _, policy := range options.GetAllPolicies() { if policy.AllowPublicUnauthenticatedAccess && policy.Matches(requestURL) { return true } diff --git a/internal/databroker/config_source.go b/internal/databroker/config_source.go index 537045252..afef1cf7e 100644 --- a/internal/databroker/config_source.go +++ b/internal/databroker/config_source.go @@ -82,7 +82,7 @@ func (src *ConfigSource) rebuild(firstTime bool) { src.runUpdater(cfg) seen := map[uint64]struct{}{} - for _, policy := range cfg.Options.Policies { + for _, policy := range cfg.Options.GetAllPolicies() { seen[policy.RouteID()] = struct{}{} } @@ -128,7 +128,7 @@ func (src *ConfigSource) rebuild(firstTime bool) { } // add the additional policies here since calling `Validate` will reset them. - cfg.Options.Policies = append(cfg.Options.Policies, additionalPolicies...) + cfg.Options.AdditionalPolicies = append(cfg.Options.AdditionalPolicies, additionalPolicies...) src.computedConfig = cfg if !firstTime { diff --git a/internal/databroker/config_source_test.go b/internal/databroker/config_source_test.go index 00ccec0f6..d701acf6c 100644 --- a/internal/databroker/config_source_test.go +++ b/internal/databroker/config_source_test.go @@ -65,7 +65,7 @@ func TestConfigSource(t *testing.T) { assert.NoError(t, ctx.Err()) return case cfg := <-cfgs: - assert.Len(t, cfg.Options.Policies, 0) + assert.Len(t, cfg.Options.AdditionalPolicies, 0) } select { @@ -73,7 +73,7 @@ func TestConfigSource(t *testing.T) { assert.NoError(t, ctx.Err()) return case cfg := <-cfgs: - assert.Len(t, cfg.Options.Policies, 1) + assert.Len(t, cfg.Options.AdditionalPolicies, 1) } } diff --git a/proxy/proxy.go b/proxy/proxy.go index de2c96f91..f8a991d32 100644 --- a/proxy/proxy.go +++ b/proxy/proxy.go @@ -71,7 +71,7 @@ func New(cfg *config.Config) (*Proxy, error) { p.currentRouter.Store(httputil.NewRouter()) metrics.AddPolicyCountCallback("pomerium-proxy", func() int64 { - return int64(len(p.currentOptions.Load().Policies)) + return int64(len(p.currentOptions.Load().GetAllPolicies())) }) return p, nil @@ -94,7 +94,7 @@ func (p *Proxy) OnConfigChange(cfg *config.Config) { } func (p *Proxy) setHandlers(opts *config.Options) { - if len(opts.Policies) == 0 { + if len(opts.GetAllPolicies()) == 0 { log.Warn().Msg("proxy: configuration has no policies") } r := httputil.NewRouter()