diff --git a/pkg/storage/redis/redis.go b/pkg/storage/redis/redis.go index 7fc133340..bd6a0ab5f 100644 --- a/pkg/storage/redis/redis.go +++ b/pkg/storage/redis/redis.go @@ -7,6 +7,7 @@ import ( "fmt" "net" "strconv" + "sync" "time" "github.com/cenkalti/backoff/v4" @@ -38,6 +39,7 @@ type DB struct { versionSet string deletedSet string tlsConfig *tls.Config + notifyChMu sync.Mutex } // New returns new DB instance. @@ -243,9 +245,23 @@ func (db *DB) ClearDeleted(ctx context.Context, cutoff time.Time) { } // doNotifyLoop receives event from redis and send signal to the channel. -func (db *DB) doNotifyLoop(ctx context.Context, ch chan struct{}, psc *redis.PubSubConn) { +func (db *DB) doNotifyLoop(ctx context.Context, ch chan struct{}) { eb := backoff.NewExponentialBackOff() + psConn := db.pool.Get() + psc := redis.PubSubConn{Conn: psConn} + defer func(psc *redis.PubSubConn) { + psc.Conn.Close() + }(&psc) + + if err := db.subscribeRedisChannel(&psc); err != nil { + log.Error().Err(err).Msg("failed to subscribe to version set channel") + return + } for { + select { + case <-ctx.Done(): + default: + } switch v := psc.Receive().(type) { case redis.Message: log.Debug().Str("action", string(v.Data)).Msg("got redis message") @@ -253,11 +269,21 @@ func (db *DB) doNotifyLoop(ctx context.Context, ch chan struct{}, psc *redis.Pub if string(v.Data) != watchAction { continue } + select { case <-ctx.Done(): log.Warn().Err(ctx.Err()).Msg("context done, stop receive from redis channel") return - case ch <- struct{}{}: + default: + db.notifyChMu.Lock() + select { + case <-ctx.Done(): + db.notifyChMu.Unlock() + log.Warn().Err(ctx.Err()).Msg("context done while holding notify lock, stop receive from redis channel") + return + case ch <- struct{}{}: + } + db.notifyChMu.Unlock() } case error: log.Warn().Err(v).Msg("failed to receive from redis channel") @@ -269,28 +295,22 @@ func (db *DB) doNotifyLoop(ctx context.Context, ch chan struct{}, psc *redis.Pub log.Warn().Msg("retry with new connection") _ = psc.Conn.Close() psc.Conn = db.pool.Get() - _ = db.subscribeRedisChannel(psc) + _ = db.subscribeRedisChannel(&psc) } } } // watch runs the doNotifyLoop. It returns when ctx was done or doNotifyLoop exits. func (db *DB) watch(ctx context.Context, ch chan struct{}) { - psConn := db.pool.Get() - psc := redis.PubSubConn{Conn: psConn} - defer func(psc *redis.PubSubConn) { - psc.Conn.Close() - }(&psc) - - if err := db.subscribeRedisChannel(&psc); err != nil { - log.Error().Err(err).Msg("failed to subscribe to version set channel") - return - } - + defer func() { + db.notifyChMu.Lock() + close(ch) + db.notifyChMu.Unlock() + }() done := make(chan struct{}) go func() { defer close(done) - db.doNotifyLoop(ctx, ch, &psc) + db.doNotifyLoop(ctx, ch) }() select { case <-ctx.Done(): @@ -308,10 +328,6 @@ func (db *DB) Watch(ctx context.Context) chan struct{} { ch := make(chan struct{}) go func() { c := db.pool.Get() - defer func() { - close(ch) - }() - // Setup notifications, we only care about changes to db.version_set. if _, err := c.Do("CONFIG", "SET", "notify-keyspace-events", "Kz"); err != nil { log.Error().Err(err).Msg("failed to setup redis notification")