implement new redis storage backend with go-redis package (#1649)

This commit is contained in:
Caleb Doxsey 2020-12-10 12:21:31 -07:00 committed by GitHub
parent 2e8b842aed
commit 3b634de550
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
10 changed files with 383 additions and 406 deletions

View file

@ -1,404 +1,393 @@
// Package redis is the redis database, implements storage.Backend interface.
// Package redis implements the storage.Backend interface for redis.
package redis
import (
"context"
"crypto/tls"
"errors"
"fmt"
"net"
"strconv"
"sync"
"time"
"github.com/cenkalti/backoff/v4"
redis "github.com/go-redis/redis/v8"
"github.com/golang/protobuf/proto"
"github.com/golang/protobuf/ptypes"
"github.com/gomodule/redigo/redis"
"google.golang.org/protobuf/types/known/anypb"
"google.golang.org/protobuf/types/known/timestamppb"
"github.com/pomerium/pomerium/config"
"github.com/pomerium/pomerium/internal/log"
"github.com/pomerium/pomerium/internal/signal"
"github.com/pomerium/pomerium/internal/telemetry/metrics"
"github.com/pomerium/pomerium/internal/telemetry/trace"
"github.com/pomerium/pomerium/pkg/grpc/databroker"
"github.com/pomerium/pomerium/pkg/storage"
)
// Name is the storage type name for redis backend.
// Name of the storage backend.
const Name = config.StorageRedisName
var _ storage.Backend = (*DB)(nil)
const (
maxTransactionRetries = 100
watchPollInterval = 30 * time.Second
)
// DB wraps redis conn to interact with redis server.
// custom errors
var (
ErrExceededMaxRetries = errors.New("redis: transaction reached maximum number of retries")
)
// DB implements the storage.Backend on top of redis.
type DB struct {
pool *redis.Pool
recordType string
lastVersionKey string
lastVersionChannelKey string
versionSet string
deletedSet string
tlsConfig *tls.Config
notifyChMu sync.Mutex
cfg *dbConfig
client *redis.Client
closeOnce sync.Once
closed chan struct{}
}
// New returns new DB instance.
func New(rawURL, recordType string, opts ...Option) (*DB, error) {
// New creates a new redis storage backend.
func New(rawURL string, options ...Option) (*DB, error) {
db := &DB{
recordType: recordType,
versionSet: recordType + "_version_set",
deletedSet: recordType + "_deleted_set",
lastVersionKey: recordType + "_last_version",
lastVersionChannelKey: recordType + "_last_version_ch",
closed: make(chan struct{}),
cfg: getConfig(options...),
closed: make(chan struct{}),
}
for _, o := range opts {
o(db)
opts, err := redis.ParseURL(rawURL)
if err != nil {
return nil, err
}
db.pool = &redis.Pool{
Wait: true,
Dial: func() (redis.Conn, error) {
c, err := redis.DialURL(rawURL, redis.DialTLSConfig(db.tlsConfig))
if err != nil {
return nil, fmt.Errorf(`redis.DialURL(): %w`, err)
}
return c, nil
},
TestOnBorrow: func(c redis.Conn, t time.Time) error {
if time.Since(t) < time.Minute {
return nil
}
_, err := c.Do("PING")
if err != nil {
return fmt.Errorf(`c.Do("PING"): %w`, err)
}
return nil
},
MaxIdle: 64,
IdleTimeout: time.Minute,
MaxActive: 128,
// when using TLS, the TLS config will not be set to nil, in which case we replace it with our own
if opts.TLSConfig != nil {
opts.TLSConfig = db.cfg.tls
}
metrics.AddRedisMetrics(db.pool.Stats)
db.client = redis.NewClient(opts)
return db, nil
}
// Close closes the redis db connection.
// ClearDeleted clears all the deleted records older than the cutoff time.
func (db *DB) ClearDeleted(ctx context.Context, cutoff time.Time) {
var err error
_, span := trace.StartSpan(ctx, "databroker.redis.ClearDeleted")
defer span.End()
defer func(start time.Time) { recordOperation(ctx, start, "clear_deleted", err) }(time.Now())
ids, _ := db.client.SMembers(ctx, formatDeletedSetKey(db.cfg.recordType)).Result()
records, _ := redisGetRecords(ctx, db.client, db.cfg.recordType, ids)
_, err = db.client.Pipelined(ctx, func(p redis.Pipeliner) error {
for _, record := range records {
if record.GetDeletedAt().AsTime().Before(cutoff) {
p.HDel(ctx, formatRecordsKey(db.cfg.recordType), record.GetId())
p.ZRem(ctx, formatVersionSetKey(db.cfg.recordType), record.GetId())
p.SRem(ctx, formatDeletedSetKey(db.cfg.recordType), record.GetId())
}
}
return nil
})
}
// Close closes the underlying redis connection and any watchers.
func (db *DB) Close() error {
var err error
db.closeOnce.Do(func() {
err = db.client.Close()
close(db.closed)
})
return nil
return err
}
// Put sets new record for given id with input data.
func (db *DB) Put(ctx context.Context, id string, data *anypb.Any) (err error) {
c := db.pool.Get()
_, span := trace.StartSpan(ctx, "databroker.redis.Put")
// Delete marks a record as deleted.
func (db *DB) Delete(ctx context.Context, id string) (err error) {
_, span := trace.StartSpan(ctx, "databroker.redis.Delete")
defer span.End()
defer recordOperation(ctx, time.Now(), "put", err)
defer c.Close()
defer func(start time.Time) { recordOperation(ctx, start, "delete", err) }(time.Now())
record, err := db.Get(ctx, id)
if err != nil {
record = new(databroker.Record)
record.CreatedAt = ptypes.TimestampNow()
}
var record *databroker.Record
err = db.incrementVersion(ctx,
func(tx *redis.Tx, version int64) error {
var err error
record, err = redisGetRecord(ctx, tx, db.cfg.recordType, id)
if errors.Is(err, redis.Nil) {
// nothing to do, as the record doesn't exist
return nil
} else if err != nil {
return err
}
lastVersion, err := redis.Int64(c.Do("INCR", db.lastVersionKey))
if err != nil {
return err
}
record.Data = data
record.ModifiedAt = ptypes.TimestampNow()
record.Type = db.recordType
record.Id = id
record.Version = fmt.Sprintf("%012X", lastVersion)
b, err := proto.Marshal(record)
if err != nil {
return err
}
cmds := []map[string][]interface{}{
{"MULTI": nil},
{"HSET": {db.recordType, id, string(b)}},
{"ZADD": {db.versionSet, lastVersion, id}},
{"PUBLISH": {db.lastVersionChannelKey, lastVersion}},
}
if err := db.tx(c, cmds); err != nil {
return err
}
return nil
// mark it as deleted
record.DeletedAt = timestamppb.Now()
return nil
},
func(p redis.Pipeliner, version int64) error {
err := redisSetRecord(ctx, p, db.cfg.recordType, record)
if err != nil {
return err
}
// add it to the collection of deleted entries
p.SAdd(ctx, formatDeletedSetKey(db.cfg.recordType), record.GetId())
return nil
})
return err
}
// Get retrieves a record from redis.
func (db *DB) Get(ctx context.Context, id string) (rec *databroker.Record, err error) {
c := db.pool.Get()
// Get gets a record.
func (db *DB) Get(ctx context.Context, id string) (record *databroker.Record, err error) {
_, span := trace.StartSpan(ctx, "databroker.redis.Get")
defer span.End()
defer recordOperation(ctx, time.Now(), "get", err)
defer c.Close()
defer func(start time.Time) { recordOperation(ctx, start, "get", err) }(time.Now())
b, err := redis.Bytes(c.Do("HGET", db.recordType, id))
record, err = redisGetRecord(ctx, db.client, db.cfg.recordType, id)
return record, err
}
// List lists all the records changed since the sinceVersion. Records are sorted in version order.
func (db *DB) List(ctx context.Context, sinceVersion string) (records []*databroker.Record, err error) {
_, span := trace.StartSpan(ctx, "databroker.redis.List")
defer span.End()
defer func(start time.Time) { recordOperation(ctx, start, "list", err) }(time.Now())
var ids []string
ids, err = redisListIDsSince(ctx, db.client, db.cfg.recordType, sinceVersion)
if err != nil {
return nil, err
}
records, err = redisGetRecords(ctx, db.client, db.cfg.recordType, ids)
return records, err
}
// Put updates a record.
func (db *DB) Put(ctx context.Context, id string, data *anypb.Any) (err error) {
_, span := trace.StartSpan(ctx, "databroker.redis.Put")
defer span.End()
defer func(start time.Time) { recordOperation(ctx, start, "put", err) }(time.Now())
var record *databroker.Record
err = db.incrementVersion(ctx,
func(tx *redis.Tx, version int64) error {
var err error
record, err = redisGetRecord(ctx, db.client, db.cfg.recordType, id)
if errors.Is(err, redis.Nil) {
record = new(databroker.Record)
record.CreatedAt = timestamppb.Now()
} else if err != nil {
return err
}
record.ModifiedAt = timestamppb.Now()
record.Type = db.cfg.recordType
record.Id = id
record.Data = data
record.Version = formatVersion(version)
return nil
},
func(p redis.Pipeliner, version int64) error {
return redisSetRecord(ctx, p, db.cfg.recordType, record)
})
return err
}
// Watch returns a channel that is signaled any time the last version is incremented (ie on Put/Delete).
func (db *DB) Watch(ctx context.Context) <-chan struct{} {
s := signal.New()
ch := s.Bind()
go func() {
defer s.Unbind(ch)
defer close(ch)
// force a check
poll := time.NewTicker(watchPollInterval)
defer poll.Stop()
// use pub/sub for quicker notify
pubsub := db.client.Subscribe(ctx, formatLastVersionChannelKey(db.cfg.recordType))
defer func() { _ = pubsub.Close() }()
pubsubCh := pubsub.Channel()
var lastVersion int64
for {
v, err := redisGetLastVersion(ctx, db.client, db.cfg.recordType)
if err != nil {
log.Error().Err(err).Msg("redis: error retrieving last version")
} else if v != lastVersion {
// don't broadcast the first time
if lastVersion != 0 {
s.Broadcast()
}
lastVersion = v
}
select {
case <-ctx.Done():
return
case <-db.closed:
return
case <-poll.C:
case <-pubsubCh:
// re-check
}
}
}()
return ch
}
// incrementVersion increments the last version key, runs the code in `query`, then attempts to commit the code in
// `commit`. If the last version changes in the interim, we will retry the transaction.
func (db *DB) incrementVersion(ctx context.Context,
query func(tx *redis.Tx, version int64) error,
commit func(p redis.Pipeliner, version int64) error,
) error {
// code is modeled on https://pkg.go.dev/github.com/go-redis/redis/v8#example-Client.Watch
txf := func(tx *redis.Tx) error {
version, err := redisGetLastVersion(ctx, tx, db.cfg.recordType)
if err != nil {
return err
}
version++
err = query(tx, version)
if err != nil {
return err
}
// the `commit` code is run in a transaction so that the EXEC cmd will run for the original redis watch
_, err = tx.TxPipelined(ctx, func(p redis.Pipeliner) error {
err := commit(p, version)
if err != nil {
return err
}
p.Set(ctx, formatLastVersionKey(db.cfg.recordType), version, 0)
p.Publish(ctx, formatLastVersionChannelKey(db.cfg.recordType), version)
return nil
})
return err
}
for i := 0; i < maxTransactionRetries; i++ {
err := db.client.Watch(ctx, txf, formatLastVersionKey(db.cfg.recordType))
if errors.Is(err, redis.TxFailedErr) {
continue // retry
} else if err != nil {
return err
}
return nil // tx was successful
}
return ErrExceededMaxRetries
}
func redisGetLastVersion(ctx context.Context, c redis.Cmdable, recordType string) (int64, error) {
version, err := c.Get(ctx, formatLastVersionKey(recordType)).Int64()
if errors.Is(err, redis.Nil) {
version = 0
} else if err != nil {
return 0, err
}
return version, nil
}
func redisGetRecord(ctx context.Context, c redis.Cmdable, recordType string, id string) (*databroker.Record, error) {
records, err := redisGetRecords(ctx, c, recordType, []string{id})
if err != nil {
return nil, err
} else if len(records) < 1 {
return nil, redis.Nil
}
return records[0], nil
}
func redisGetRecords(ctx context.Context, c redis.Cmdable, recordType string, ids []string) ([]*databroker.Record, error) {
if len(ids) == 0 {
return nil, nil
}
results, err := c.HMGet(ctx, formatRecordsKey(recordType), ids...).Result()
if err != nil {
return nil, err
}
return db.toPbRecord(b)
records := make([]*databroker.Record, 0, len(results))
for _, result := range results {
// results are returned as either nil or a string
if result == nil {
continue
}
rawstr, ok := result.(string)
if !ok {
continue
}
var record databroker.Record
err := proto.Unmarshal([]byte(rawstr), &record)
if err != nil {
continue
}
records = append(records, &record)
}
return records, nil
}
// GetAll retrieves all records from redis.
func (db *DB) GetAll(ctx context.Context) (recs []*databroker.Record, err error) {
_, span := trace.StartSpan(ctx, "databroker.redis.GetAll")
defer span.End()
defer recordOperation(ctx, time.Now(), "get_all", err)
return db.getAll(ctx, func(record *databroker.Record) bool { return true })
func redisListIDsSince(ctx context.Context,
c redis.Cmdable, recordType string,
sinceVersion string,
) ([]string, error) {
v, err := strconv.ParseInt(sinceVersion, 16, 64)
if err != nil {
v = 0
}
rng := &redis.ZRangeBy{
Min: fmt.Sprintf("(%d", v),
Max: "+inf",
}
return c.ZRangeByScore(ctx, formatVersionSetKey(recordType), rng).Result()
}
// List retrieves all records since given version.
//
// "version" is in hex format, invalid version will be treated as 0.
func (db *DB) List(ctx context.Context, sinceVersion string) (rec []*databroker.Record, err error) {
c := db.pool.Get()
_, span := trace.StartSpan(ctx, "databroker.redis.List")
defer span.End()
defer recordOperation(ctx, time.Now(), "list", err)
defer c.Close()
v, err := strconv.ParseUint(sinceVersion, 16, 64)
func redisSetRecord(ctx context.Context, p redis.Pipeliner, recordType string, record *databroker.Record) error {
v, err := strconv.ParseInt(record.GetVersion(), 16, 64)
if err != nil {
v = 0
}
ids, err := redis.Strings(c.Do("ZRANGEBYSCORE", db.versionSet, fmt.Sprintf("(%d", v), "+inf"))
if err != nil {
return nil, err
}
pbRecords := make([]*databroker.Record, 0, len(ids))
for _, id := range ids {
b, err := redis.Bytes(c.Do("HGET", db.recordType, id))
if err != nil {
return nil, err
}
pbRecord, err := db.toPbRecord(b)
if err != nil {
return nil, err
}
pbRecords = append(pbRecords, pbRecord)
}
return pbRecords, nil
}
// Delete sets a record DeletedAt field and set its TTL.
func (db *DB) Delete(ctx context.Context, id string) (err error) {
c := db.pool.Get()
_, span := trace.StartSpan(ctx, "databroker.redis.Delete")
defer span.End()
defer recordOperation(ctx, time.Now(), "delete", err)
defer c.Close()
r, err := db.Get(ctx, id)
if err != nil {
return fmt.Errorf("failed to get record: %w", err)
}
lastVersion, err := redis.Int64(c.Do("INCR", db.lastVersionKey))
raw, err := proto.Marshal(record)
if err != nil {
return err
}
r.DeletedAt = ptypes.TimestampNow()
r.Version = fmt.Sprintf("%012X", lastVersion)
b, err := proto.Marshal(r)
if err != nil {
return err
}
cmds := []map[string][]interface{}{
{"MULTI": nil},
{"HSET": {db.recordType, id, string(b)}},
{"SADD": {db.deletedSet, id}},
{"ZADD": {db.versionSet, lastVersion, id}},
{"PUBLISH": {db.lastVersionChannelKey, lastVersion}},
}
if err := db.tx(c, cmds); err != nil {
return err
}
// store the record in the hash
p.HSet(ctx, formatRecordsKey(recordType), record.GetId(), string(raw))
// set its score for sorting by version
p.ZAdd(ctx, formatVersionSetKey(recordType), &redis.Z{
Score: float64(v),
Member: record.GetId(),
})
return nil
}
// ClearDeleted clears all the currently deleted records older than the given cutoff.
func (db *DB) ClearDeleted(ctx context.Context, cutoff time.Time) {
c := db.pool.Get()
_, span := trace.StartSpan(ctx, "databroker.redis.ClearDeleted")
defer span.End()
var opErr error
defer func(startTime time.Time) {
recordOperation(ctx, startTime, "clear_deleted", opErr)
}(time.Now())
defer c.Close()
ids, _ := redis.Strings(c.Do("SMEMBERS", db.deletedSet))
for _, id := range ids {
b, _ := redis.Bytes(c.Do("HGET", db.recordType, id))
record, err := db.toPbRecord(b)
if err != nil {
continue
}
ts, _ := ptypes.Timestamp(record.DeletedAt)
if ts.Before(cutoff) {
cmds := []map[string][]interface{}{
{"MULTI": nil},
{"HDEL": {db.recordType, id}},
{"ZREM": {db.versionSet, id}},
{"SREM": {db.deletedSet, id}},
}
opErr = db.tx(c, cmds)
}
}
func formatDeletedSetKey(recordType string) string {
return fmt.Sprintf("%s_deleted_set", recordType)
}
// doNotifyLoop receives event from redis and send signal to the channel.
func (db *DB) doNotifyLoop(ctx context.Context, ch chan struct{}) {
eb := backoff.NewExponentialBackOff()
psConn := db.pool.Get()
psc := redis.PubSubConn{Conn: psConn}
defer func(psc *redis.PubSubConn) {
psc.Conn.Close()
}(&psc)
if err := psc.Subscribe(db.lastVersionChannelKey); err != nil {
log.Error().Err(err).Msg("failed to subscribe to version set channel")
return
}
for {
select {
case <-db.closed:
return
case <-ctx.Done():
return
default:
}
switch v := psc.Receive().(type) {
case redis.Message:
log.Debug().Str("action", string(v.Data)).Msg("got redis message")
recordOperation(ctx, time.Now(), "sub_received", nil)
select {
case <-db.closed:
return
case <-ctx.Done():
log.Warn().Err(ctx.Err()).Msg("context done, stop receive from redis channel")
return
default:
db.notifyChMu.Lock()
select {
case <-db.closed:
db.notifyChMu.Unlock()
return
case <-ctx.Done():
db.notifyChMu.Unlock()
log.Warn().Err(ctx.Err()).Msg("context done while holding notify lock, stop receive from redis channel")
return
case ch <- struct{}{}:
}
db.notifyChMu.Unlock()
}
case error:
log.Warn().Err(v).Msg("failed to receive from redis channel")
recordOperation(ctx, time.Now(), "sub_received", v)
if _, ok := v.(net.Error); ok {
return
}
time.Sleep(eb.NextBackOff())
log.Warn().Msg("retry with new connection")
_ = psc.Conn.Close()
psc.Conn = db.pool.Get()
_ = psc.Subscribe(db.lastVersionChannelKey)
}
}
func formatLastVersionChannelKey(recordType string) string {
return fmt.Sprintf("%s_last_version_ch", recordType)
}
// watch runs the doNotifyLoop. It returns when ctx was done or doNotifyLoop exits.
func (db *DB) watch(ctx context.Context, ch chan struct{}) {
defer func() {
db.notifyChMu.Lock()
close(ch)
db.notifyChMu.Unlock()
}()
done := make(chan struct{})
go func() {
defer close(done)
db.doNotifyLoop(ctx, ch)
}()
select {
case <-db.closed:
case <-ctx.Done():
case <-done:
}
func formatLastVersionKey(recordType string) string {
return fmt.Sprintf("%s_last_version", recordType)
}
// Watch returns a channel to the caller, when there is a change to the version set,
// sending message to the channel to notify the caller.
func (db *DB) Watch(ctx context.Context) <-chan struct{} {
ch := make(chan struct{})
go db.watch(ctx, ch)
return ch
func formatRecordsKey(recordType string) string {
return recordType
}
func (db *DB) getAll(_ context.Context, filter func(record *databroker.Record) bool) ([]*databroker.Record, error) {
c := db.pool.Get()
defer c.Close()
iter := 0
records := make([]*databroker.Record, 0)
for {
arr, err := redis.Values(c.Do("HSCAN", db.recordType, iter, "MATCH", "*"))
if err != nil {
return nil, err
}
iter, _ = redis.Int(arr[0], nil)
pairs, _ := redis.StringMap(arr[1], nil)
for _, v := range pairs {
record, err := db.toPbRecord([]byte(v))
if err != nil {
return nil, err
}
if filter(record) {
records = append(records, record)
}
}
if iter == 0 {
break
}
}
return records, nil
func formatVersion(version int64) string {
return fmt.Sprintf("%012d", version)
}
func (db *DB) toPbRecord(b []byte) (*databroker.Record, error) {
record := &databroker.Record{}
if err := proto.Unmarshal(b, record); err != nil {
return nil, err
}
return record, nil
}
func (db *DB) tx(c redis.Conn, commands []map[string][]interface{}) error {
for _, m := range commands {
for cmd, args := range m {
if err := c.Send(cmd, args...); err != nil {
return err
}
}
}
_, err := c.Do("EXEC")
return err
func formatVersionSetKey(recordType string) string {
return fmt.Sprintf("%s_version_set", recordType)
}
func recordOperation(ctx context.Context, startTime time.Time, operation string, err error) {