From 6e766233c7d5d0ee329ca3934a62f960cae4c253 Mon Sep 17 00:00:00 2001 From: Denis Mishin Date: Tue, 20 Aug 2024 22:13:45 -0400 Subject: [PATCH] zero/health-checks: fix early checks sometimes missing (#5229) * zero/health-checks: fix early checks sometimes missing * rm closure * fix test --- internal/zero/controller/controller.go | 13 ++-- internal/zero/telemetry/telemetry.go | 8 +- pkg/health/deduplicate.go | 103 ++++++++++++++++++------- pkg/health/deduplicate_test.go | 28 ++++--- pkg/health/provider.go | 37 +-------- 5 files changed, 109 insertions(+), 80 deletions(-) diff --git a/internal/zero/controller/controller.go b/internal/zero/controller/controller.go index 5fc4f1250..3e4f69125 100644 --- a/internal/zero/controller/controller.go +++ b/internal/zero/controller/controller.go @@ -148,7 +148,7 @@ func (c *controller) runZeroControlLoop(ctx context.Context) error { if err != nil { return fmt.Errorf("init telemetry: %w", err) } - defer c.shutdownTelemetry(ctx, tm) + defer c.shutdownWithTimeout(ctx, "telemetry", tm.Shutdown) eg, ctx := errgroup.WithContext(ctx) eg.Go(func() error { return tm.Run(ctx) }) @@ -165,13 +165,16 @@ func (c *controller) runZeroControlLoop(ctx context.Context) error { return eg.Wait() } -func (c *controller) shutdownTelemetry(ctx context.Context, tm *telemetry.Telemetry) { - ctx, cancel := context.WithTimeout(ctx, c.cfg.shutdownTimeout) +func (c *controller) shutdownWithTimeout(ctx context.Context, name string, fn func(context.Context) error) { + ctx, cancel := context.WithTimeout(context.WithoutCancel(ctx), c.cfg.shutdownTimeout) defer cancel() - err := tm.Shutdown(ctx) + log.Ctx(ctx).Debug().Str("timeout", c.cfg.shutdownTimeout.String()).Msgf("shutting down %s ...", name) + err := fn(ctx) if err != nil { - log.Ctx(ctx).Error().Err(err).Msg("error shutting down telemetry") + log.Ctx(ctx).Error().Err(err).Msgf("error shutting down %s", name) + } else { + log.Ctx(ctx).Debug().Msgf("%s shutdown complete", name) } } diff --git a/internal/zero/telemetry/telemetry.go b/internal/zero/telemetry/telemetry.go index 5f4d5de75..c69a49962 100644 --- a/internal/zero/telemetry/telemetry.go +++ b/internal/zero/telemetry/telemetry.go @@ -73,12 +73,14 @@ func (srv *Telemetry) Run(ctx context.Context) error { }) eg, ctx := errgroup.WithContext(ctx) + eg.Go(func() error { return srv.reporter.Run(ctx) }) + eg.Go(func() error { return srv.handleRequests(ctx) }) eg.Go(func() error { health.SetProvider(srv.reporter) - defer health.SetProvider(nil) - return srv.reporter.Run(ctx) + <-ctx.Done() + health.SetProvider(nil) + return nil }) - eg.Go(func() error { return srv.handleRequests(ctx) }) return eg.Wait() } diff --git a/pkg/health/deduplicate.go b/pkg/health/deduplicate.go index 24e6a025e..b90caae5f 100644 --- a/pkg/health/deduplicate.go +++ b/pkg/health/deduplicate.go @@ -5,18 +5,19 @@ import ( "sync" ) -var _ Provider = (*deduplicator)(nil) +var _ Provider = (*Deduplicator)(nil) -// deduplicator is a health check provider that deduplicates health check reports +// Deduplicator is a health check provider that deduplicates health check reports // i.e. it only reports a health check if the status or attributes have changed -type deduplicator struct { - seen sync.Map +type Deduplicator struct { + lock sync.Mutex + records map[Check]*record provider Provider } type record struct { attr map[string]string - err *string + err error } func newOKRecord(attrs []Attr) *record { @@ -24,11 +25,10 @@ func newOKRecord(attrs []Attr) *record { } func newErrorRecord(err error, attrs []Attr) *record { - errTxt := err.Error() - return newRecord(&errTxt, attrs) + return newRecord(err, attrs) } -func newRecord(err *string, attrs []Attr) *record { +func newRecord(err error, attrs []Attr) *record { r := &record{err: err, attr: make(map[string]string)} for _, a := range attrs { r.attr[a.Key] = a.Value @@ -36,44 +36,87 @@ func newRecord(err *string, attrs []Attr) *record { return r } +func (r *record) Attr() []Attr { + attrs := make([]Attr, 0, len(r.attr)) + for k, v := range r.attr { + attrs = append(attrs, Attr{Key: k, Value: v}) + } + return attrs +} + func (r *record) Equals(other *record) bool { - return r.equalError(other) && + return equalError(r.err, other.err) && maps.Equal(r.attr, other.attr) } -func (r *record) equalError(other *record) bool { - if r.err == nil || other.err == nil { - return r.err == other.err +func equalError(a, b error) bool { + if a == nil || b == nil { + return a == b //nolint:errorlint } - return *r.err == *other.err + return a.Error() == b.Error() } -func NewDeduplicator(provider Provider) Provider { - return &deduplicator{provider: provider} +func report(p Provider, check Check, err error, attrs ...Attr) { + if err != nil { + p.ReportError(check, err, attrs...) + } else { + p.ReportOK(check, attrs...) + } } -func (d *deduplicator) swap(check Check, next *record) *record { - prev, there := d.seen.Swap(check, next) - if !there { - return nil +func NewDeduplicator() *Deduplicator { + return &Deduplicator{ + records: make(map[Check]*record), + provider: &noopProvider{}, } - return prev.(*record) +} + +func (d *Deduplicator) SetProvider(p Provider) { + if p == nil { + p = &noopProvider{} + } + records := d.setProvider(p) + for check, record := range records { + report(p, check, record.err, record.Attr()...) + } +} + +func (d *Deduplicator) setProvider(p Provider) map[Check]*record { + d.lock.Lock() + defer d.lock.Unlock() + + d.provider = p + return maps.Clone(d.records) +} + +func (d *Deduplicator) swap(check Check, next *record) (provider Provider, changed bool) { + d.lock.Lock() + defer d.lock.Unlock() + + prev := d.records[check] + d.records[check] = next + changed = prev == nil || !next.Equals(prev) + return d.provider, changed } // ReportError implements the Provider interface -func (d *deduplicator) ReportError(check Check, err error, attrs ...Attr) { - cur := newErrorRecord(err, attrs) - prev := d.swap(check, cur) - if prev == nil || !cur.Equals(prev) { - d.provider.ReportError(check, err, attrs...) +func (d *Deduplicator) ReportError(check Check, err error, attrs ...Attr) { + provider, changed := d.swap(check, newErrorRecord(err, attrs)) + if changed { + provider.ReportError(check, err, attrs...) } } // ReportOK implements the Provider interface -func (d *deduplicator) ReportOK(check Check, attrs ...Attr) { - cur := newOKRecord(attrs) - prev := d.swap(check, cur) - if prev == nil || !cur.Equals(prev) { - d.provider.ReportOK(check, attrs...) +func (d *Deduplicator) ReportOK(check Check, attrs ...Attr) { + provider, changed := d.swap(check, newOKRecord(attrs)) + if changed { + provider.ReportOK(check, attrs...) } } + +type noopProvider struct{} + +func (n *noopProvider) ReportOK(Check, ...Attr) {} + +func (n *noopProvider) ReportError(Check, error, ...Attr) {} diff --git a/pkg/health/deduplicate_test.go b/pkg/health/deduplicate_test.go index 173bf634c..324634226 100644 --- a/pkg/health/deduplicate_test.go +++ b/pkg/health/deduplicate_test.go @@ -14,28 +14,38 @@ import ( func TestDeduplicate(t *testing.T) { t.Parallel() - p := NewMockProvider(gomock.NewController(t)) - dp := health.NewDeduplicator(p) + p1 := NewMockProvider(gomock.NewController(t)) + dp := health.NewDeduplicator() + dp.SetProvider(p1) - check1, check2 := health.Check("check-1"), health.Check("check-2") - p.EXPECT().ReportOK(check1).Times(1) - p.EXPECT().ReportOK(check2).Times(1) + check1, check2, check3 := health.Check("check-1"), health.Check("check-2"), health.Check("check-3") + p1.EXPECT().ReportOK(check1).Times(1) + p1.EXPECT().ReportOK(check2).Times(1) + p1.EXPECT().ReportError(check3, errors.New("error-3")).Times(1) dp.ReportOK(check1) dp.ReportOK(check2) dp.ReportOK(check1) + dp.ReportError(check3, errors.New("error-3")) - p.EXPECT().ReportError(check1, gomock.Any()).Times(1) + p1.EXPECT().ReportError(check1, errors.New("error")).Times(1) dp.ReportError(check1, errors.New("error")) dp.ReportError(check1, errors.New("error")) - p.EXPECT().ReportOK(check1).Times(1) + p1.EXPECT().ReportOK(check1).Times(1) dp.ReportOK(check1) - p.EXPECT().ReportOK(check1, health.StrAttr("k1", "v1")).Times(2) - p.EXPECT().ReportOK(check1, health.StrAttr("k1", "v2")).Times(1) + p1.EXPECT().ReportOK(check1, health.StrAttr("k1", "v1")).Times(2) + p1.EXPECT().ReportOK(check1, health.StrAttr("k1", "v2")).Times(1) dp.ReportOK(check1, health.StrAttr("k1", "v1")) dp.ReportOK(check1, health.StrAttr("k1", "v2")) dp.ReportOK(check1, health.StrAttr("k1", "v1")) + + // after setting new provider, current state should be reported + p2 := NewMockProvider(gomock.NewController(t)) + p2.EXPECT().ReportOK(check1, health.StrAttr("k1", "v1")).Times(1) + p2.EXPECT().ReportOK(check2).Times(1) + p2.EXPECT().ReportError(check3, errors.New("error-3")).Times(1) + dp.SetProvider(p2) } func TestDefault(t *testing.T) { diff --git a/pkg/health/provider.go b/pkg/health/provider.go index 2907abbe4..44106275f 100644 --- a/pkg/health/provider.go +++ b/pkg/health/provider.go @@ -2,7 +2,6 @@ package health import ( "errors" - "sync" ) // Attr is a key-value pair that can be attached to a health check @@ -26,10 +25,7 @@ func ErrorAttr(err error) Attr { // ReportOK reports that a check was successful func ReportOK(check Check, attributes ...Attr) { - p := defaultProvider.Load() - if p != nil { - p.ReportOK(check, attributes...) - } + provider.ReportOK(check, attributes...) } var ErrInternalError = errors.New("internal error") @@ -41,10 +37,7 @@ func ReportInternalError(check Check, err error, attributes ...Attr) { // ReportError reports that a check failed func ReportError(check Check, err error, attributes ...Attr) { - p := defaultProvider.Load() - if p != nil { - p.ReportError(check, err, attributes...) - } + provider.ReportError(check, err, attributes...) } // Provider is the interface that must be implemented by a health check reporter @@ -55,29 +48,7 @@ type Provider interface { // SetProvider sets the health check provider func SetProvider(p Provider) { - if p != nil { - p = NewDeduplicator(p) - } - defaultProvider.Store(p) + provider.SetProvider(p) } -type providerStore struct { - sync.RWMutex - provider Provider -} - -func (p *providerStore) Load() Provider { - p.RLock() - defer p.RUnlock() - - return p.provider -} - -func (p *providerStore) Store(provider Provider) { - p.Lock() - defer p.Unlock() - - p.provider = provider -} - -var defaultProvider providerStore +var provider = NewDeduplicator()