From 6dea25e0a2e00be0ad1d3f7292ad141d63831d2f Mon Sep 17 00:00:00 2001 From: Denis Mishin Date: Wed, 1 Nov 2023 20:26:09 -0400 Subject: [PATCH] concurrently build config --- config/options.go | 8 +- internal/databroker/config_source.go | 214 +++++++++++++--------- internal/databroker/config_source_test.go | 2 +- pkg/cmd/pomerium/pomerium.go | 2 +- pkg/slices/sync.go | 31 ++++ 5 files changed, 158 insertions(+), 99 deletions(-) create mode 100644 pkg/slices/sync.go diff --git a/config/options.go b/config/options.go index 69423238a..8fd1df227 100644 --- a/config/options.go +++ b/config/options.go @@ -300,8 +300,6 @@ type Options struct { AuditKey *PublicKeyEncryptionKeyOptions `mapstructure:"audit_key"` BrandingOptions httputil.BrandingOptions - - DisableValidation bool } type certificateFilePair struct { @@ -462,7 +460,7 @@ func (o *Options) parsePolicy() error { } } for i := range o.AdditionalPolicies { - p := &o.AdditionalPolicies[i] + p := o.AdditionalPolicies[i] if err := p.Validate(); err != nil { return err } @@ -579,10 +577,6 @@ func bindEnvsRecursive(t reflect.Type, v *viper.Viper, keyPrefix, envPrefix stri // Validate ensures the Options fields are valid, and hydrated. func (o *Options) Validate() error { - if o.DisableValidation { - return nil - } - ctx := context.TODO() if !IsValidService(o.Services) { return fmt.Errorf("config: %s is an invalid service type", o.Services) diff --git a/internal/databroker/config_source.go b/internal/databroker/config_source.go index 08b72b61b..8688e4115 100644 --- a/internal/databroker/config_source.go +++ b/internal/databroker/config_source.go @@ -2,11 +2,14 @@ package databroker import ( "context" + "fmt" + "runtime" "sort" "sync" "time" "golang.org/x/exp/maps" + "golang.org/x/sync/errgroup" "github.com/pomerium/pomerium/config" "github.com/pomerium/pomerium/internal/hashutil" @@ -18,6 +21,7 @@ import ( configpb "github.com/pomerium/pomerium/pkg/grpc/config" "github.com/pomerium/pomerium/pkg/grpc/databroker" "github.com/pomerium/pomerium/pkg/grpcutil" + "github.com/pomerium/pomerium/pkg/slices" ) // ConfigSource provides a new Config source that decorates an underlying config with @@ -30,6 +34,7 @@ type ConfigSource struct { dbConfigs map[string]dbConfig updaterHash uint64 cancel func() + enableValidation bool config.ChangeDispatcher } @@ -39,9 +44,18 @@ type dbConfig struct { version uint64 } +// EnableConfigValidation is a type that can be used to enable config validation. +type EnableConfigValidation bool + // NewConfigSource creates a new ConfigSource. -func NewConfigSource(ctx context.Context, underlying config.Source, listeners ...config.ChangeListener) *ConfigSource { +func NewConfigSource( + ctx context.Context, + underlying config.Source, + enableValidation EnableConfigValidation, + listeners ...config.ChangeListener, +) *ConfigSource { src := &ConfigSource{ + enableValidation: bool(enableValidation), dbConfigs: map[string]dbConfig{}, outboundGRPCConnection: new(grpc.CachedOutboundGRPClientConn), } @@ -74,104 +88,23 @@ func (src *ConfigSource) rebuild(ctx context.Context, firstTime firstTime) { _, span := trace.StartSpan(ctx, "databroker.config_source.rebuild") defer span.End() - log.Info(ctx).Msg("databroker: rebuilding configuration") - + now := time.Now() src.mu.Lock() defer src.mu.Unlock() + log.Info(ctx).Str("lock-wait", time.Since(now).String()).Msg("databroker: rebuilding configuration") cfg := src.underlyingConfig.Clone() // start the updater src.runUpdater(cfg) - seen := map[uint64]string{} - for _, policy := range cfg.Options.GetAllPolicies() { - id, err := policy.RouteID() - if err != nil { - log.Warn(ctx).Err(err). - Str("policy", policy.String()). - Msg("databroker: invalid policy config, ignoring") - return - } - seen[id] = "" + now = time.Now() + err := src.buildNewConfigLocked(ctx, cfg) + if err != nil { + log.Error(ctx).Err(err).Msg("databroker: failed to build new config") + return } - - var additionalPolicies []config.Policy - - ids := maps.Keys(src.dbConfigs) - sort.Strings(ids) - - var certsIndex *cryptutil.CertificatesIndex - if !cfg.Options.DisableValidation { - certsIndex = cryptutil.NewCertificatesIndex() - for _, cert := range cfg.Options.GetX509Certificates() { - certsIndex.Add(cert) - } - } - - // add all the config policies to the list - for _, id := range ids { - cfgpb := src.dbConfigs[id] - - cfg.Options.ApplySettings(ctx, certsIndex, cfgpb.Settings) - var errCount uint64 - - err := cfg.Options.Validate() - if err != nil { - metrics.SetDBConfigRejected(ctx, cfg.Options.Services, id, cfgpb.version, err) - return - } - - for _, routepb := range cfgpb.GetRoutes() { - policy, err := config.NewPolicyFromProto(routepb) - if err != nil { - errCount++ - log.Warn(ctx).Err(err). - Str("db_config_id", id). - Msg("databroker: error converting protobuf into policy") - continue - } - - err = policy.Validate() - if err != nil { - errCount++ - log.Warn(ctx).Err(err). - Str("db_config_id", id). - Str("policy", policy.String()). - Msg("databroker: invalid policy, ignoring") - continue - } - - routeID, err := policy.RouteID() - if err != nil { - errCount++ - log.Warn(ctx).Err(err). - Str("db_config_id", id). - Str("policy", policy.String()). - Msg("databroker: cannot establish policy route ID, ignoring") - continue - } - - if _, ok := seen[routeID]; ok { - errCount++ - log.Warn(ctx).Err(err). - Str("db_config_id", id). - Str("seen-in", seen[routeID]). - Str("policy", policy.String()). - Msg("databroker: duplicate policy detected, ignoring") - continue - } - seen[routeID] = id - - additionalPolicies = append(additionalPolicies, *policy) - } - metrics.SetDBConfigInfo(ctx, cfg.Options.Services, id, cfgpb.version, int64(errCount)) - } - - // add the additional policies here since calling `Validate` will reset them. - cfg.Options.AdditionalPolicies = append(cfg.Options.AdditionalPolicies, additionalPolicies...) - - log.Info(ctx).Msg("databroker: built new config") + log.Info(ctx).Str("elapsed", time.Since(now).String()).Msg("databroker: built new config") src.computedConfig = cfg if !firstTime { @@ -181,6 +114,107 @@ func (src *ConfigSource) rebuild(ctx context.Context, firstTime firstTime) { metrics.SetConfigInfo(ctx, cfg.Options.Services, "databroker", cfg.Checksum(), true) } +func (src *ConfigSource) buildNewConfigLocked(ctx context.Context, cfg *config.Config) error { + eg, ctx := errgroup.WithContext(ctx) + eg.SetLimit(runtime.NumCPU()/2 + 1) + eg.Go(func() error { + src.applySettingsLocked(ctx, cfg) + err := cfg.Options.Validate() + if err != nil { + return fmt.Errorf("validating settings: %w", err) + } + return nil + }) + policies := slices.NewSafeSlice[*config.Policy]() + for _, cfgpb := range src.dbConfigs { + for _, routepb := range cfgpb.GetRoutes() { + routepb := routepb + eg.Go(func() error { + policy, err := src.buildPolicyFromProto(routepb) + if err != nil { + log.Ctx(ctx).Err(err).Msg("databroker: error building policy from protobuf") + return nil + } + policies.Append(policy) + return nil + }) + } + } + err := eg.Wait() + if err != nil { + return err + } + + src.addPolicies(ctx, cfg, policies.Get()) + return nil +} + +func (src *ConfigSource) applySettingsLocked(ctx context.Context, cfg *config.Config) { + ids := maps.Keys(src.dbConfigs) + sort.Strings(ids) + + var certsIndex *cryptutil.CertificatesIndex + if src.enableValidation { + certsIndex = cryptutil.NewCertificatesIndex() + for _, cert := range cfg.Options.GetX509Certificates() { + certsIndex.Add(cert) + } + } + + for i := 0; i < len(ids) && ctx.Err() == nil; i++ { + cfgpb := src.dbConfigs[ids[i]] + cfg.Options.ApplySettings(ctx, certsIndex, cfgpb.Settings) + } +} + +func (src *ConfigSource) buildPolicyFromProto(routepb *configpb.Route) (*config.Policy, error) { + policy, err := config.NewPolicyFromProto(routepb) + if err != nil { + return nil, fmt.Errorf("error building policy from protobuf: %w", err) + } + + if !src.enableValidation { + return policy, nil + } + + err = policy.Validate() + if err != nil { + return nil, fmt.Errorf("error validating policy: %w", err) + } + + return policy, nil +} + +func (src *ConfigSource) addPolicies(ctx context.Context, cfg *config.Config, policies []*config.Policy) { + seen := make(map[uint64]struct{}) + for _, policy := range cfg.Options.GetAllPolicies() { + id, err := policy.RouteID() + if err != nil { + log.Ctx(ctx).Err(err).Str("policy", policy.String()).Msg("databroker: error getting route id") + continue + } + seen[id] = struct{}{} + } + + var additionalPolicies []config.Policy + for _, policy := range policies { + id, err := policy.RouteID() + if err != nil { + log.Ctx(ctx).Err(err).Str("policy", policy.String()).Msg("databroker: error getting route id") + continue + } + if _, ok := seen[id]; ok { + log.Ctx(ctx).Debug().Str("policy", policy.String()).Msg("databroker: policy already exists") + continue + } + additionalPolicies = append(additionalPolicies, *policy) + seen[id] = struct{}{} + } + + // add the additional policies here since calling `Validate` will reset them. + cfg.Options.AdditionalPolicies = append(cfg.Options.AdditionalPolicies, additionalPolicies...) +} + func (src *ConfigSource) runUpdater(cfg *config.Config) { sharedKey, _ := cfg.Options.GetSharedKey() connectionOptions := &grpc.OutboundOptions{ diff --git a/internal/databroker/config_source_test.go b/internal/databroker/config_source_test.go index 335662f59..063fd94be 100644 --- a/internal/databroker/config_source_test.go +++ b/internal/databroker/config_source_test.go @@ -65,7 +65,7 @@ func TestConfigSource(t *testing.T) { OutboundPort: outboundPort, Options: base, }) - src := NewConfigSource(ctx, baseSource, func(_ context.Context, cfg *config.Config) { + src := NewConfigSource(ctx, baseSource, EnableConfigValidation(true), func(_ context.Context, cfg *config.Config) { cfgs <- cfg }) cfgs <- src.GetConfig() diff --git a/pkg/cmd/pomerium/pomerium.go b/pkg/cmd/pomerium/pomerium.go index c732463c3..e933cd473 100644 --- a/pkg/cmd/pomerium/pomerium.go +++ b/pkg/cmd/pomerium/pomerium.go @@ -40,7 +40,7 @@ func Run(ctx context.Context, src config.Source) error { if err != nil { return err } - src = databroker.NewConfigSource(ctx, src) + src = databroker.NewConfigSource(ctx, src, databroker.EnableConfigValidation(true)) logMgr := config.NewLogManager(ctx, src) defer logMgr.Close() diff --git a/pkg/slices/sync.go b/pkg/slices/sync.go new file mode 100644 index 000000000..ed5ed07c7 --- /dev/null +++ b/pkg/slices/sync.go @@ -0,0 +1,31 @@ +package slices + +import "sync" + +// SafeSlice is a thread safe slice. +type SafeSlice[E any] struct { + mu sync.RWMutex + slice []E +} + +// NewSafeSlice creates a new SafeSlice. +func NewSafeSlice[E any]() *SafeSlice[E] { + return &SafeSlice[E]{} +} + +// Append appends e to the slice. +func (s *SafeSlice[E]) Append(e E) { + s.mu.Lock() + s.slice = append(s.slice, e) + s.mu.Unlock() +} + +// Get gets the slice. +func (s *SafeSlice[E]) Get() []E { + s.mu.RLock() + defer s.mu.RUnlock() + + c := make([]E, len(s.slice)) + copy(c, s.slice) + return c +}