mirror of
https://github.com/pomerium/pomerium.git
synced 2025-04-30 02:46:30 +02:00
401 lines
11 KiB
Go
401 lines
11 KiB
Go
// Package redis implements the storage.Backend interface for redis.
|
|
package redis
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"fmt"
|
|
"strconv"
|
|
"sync"
|
|
"time"
|
|
|
|
redis "github.com/go-redis/redis/v8"
|
|
"github.com/golang/protobuf/proto"
|
|
"google.golang.org/protobuf/types/known/anypb"
|
|
"google.golang.org/protobuf/types/known/timestamppb"
|
|
|
|
"github.com/pomerium/pomerium/config"
|
|
"github.com/pomerium/pomerium/internal/log"
|
|
"github.com/pomerium/pomerium/internal/signal"
|
|
"github.com/pomerium/pomerium/internal/telemetry/metrics"
|
|
"github.com/pomerium/pomerium/internal/telemetry/trace"
|
|
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
|
)
|
|
|
|
// Name of the storage backend.
|
|
const Name = config.StorageRedisName
|
|
|
|
const (
|
|
maxTransactionRetries = 100
|
|
watchPollInterval = 30 * time.Second
|
|
)
|
|
|
|
// custom errors
|
|
var (
|
|
ErrExceededMaxRetries = errors.New("redis: transaction reached maximum number of retries")
|
|
)
|
|
|
|
// DB implements the storage.Backend on top of redis.
|
|
type DB struct {
|
|
cfg *dbConfig
|
|
|
|
client *redis.Client
|
|
|
|
closeOnce sync.Once
|
|
closed chan struct{}
|
|
}
|
|
|
|
// New creates a new redis storage backend.
|
|
func New(rawURL string, options ...Option) (*DB, error) {
|
|
db := &DB{
|
|
cfg: getConfig(options...),
|
|
closed: make(chan struct{}),
|
|
}
|
|
opts, err := redis.ParseURL(rawURL)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
// when using TLS, the TLS config will not be set to nil, in which case we replace it with our own
|
|
if opts.TLSConfig != nil {
|
|
opts.TLSConfig = db.cfg.tls
|
|
}
|
|
db.client = redis.NewClient(opts)
|
|
metrics.AddRedisMetrics(db.client.PoolStats)
|
|
return db, nil
|
|
}
|
|
|
|
// ClearDeleted clears all the deleted records older than the cutoff time.
|
|
func (db *DB) ClearDeleted(ctx context.Context, cutoff time.Time) {
|
|
var err error
|
|
|
|
_, span := trace.StartSpan(ctx, "databroker.redis.ClearDeleted")
|
|
defer span.End()
|
|
defer func(start time.Time) { recordOperation(ctx, start, "clear_deleted", err) }(time.Now())
|
|
|
|
ids, _ := db.client.SMembers(ctx, formatDeletedSetKey(db.cfg.recordType)).Result()
|
|
records, _ := redisGetRecords(ctx, db.client, db.cfg.recordType, ids)
|
|
_, err = db.client.Pipelined(ctx, func(p redis.Pipeliner) error {
|
|
for _, record := range records {
|
|
if record.GetDeletedAt().AsTime().Before(cutoff) {
|
|
p.HDel(ctx, formatRecordsKey(db.cfg.recordType), record.GetId())
|
|
p.ZRem(ctx, formatVersionSetKey(db.cfg.recordType), record.GetId())
|
|
p.SRem(ctx, formatDeletedSetKey(db.cfg.recordType), record.GetId())
|
|
}
|
|
}
|
|
return nil
|
|
})
|
|
}
|
|
|
|
// Close closes the underlying redis connection and any watchers.
|
|
func (db *DB) Close() error {
|
|
var err error
|
|
db.closeOnce.Do(func() {
|
|
err = db.client.Close()
|
|
close(db.closed)
|
|
})
|
|
return err
|
|
}
|
|
|
|
// Delete marks a record as deleted.
|
|
func (db *DB) Delete(ctx context.Context, id string) (err error) {
|
|
_, span := trace.StartSpan(ctx, "databroker.redis.Delete")
|
|
defer span.End()
|
|
defer func(start time.Time) { recordOperation(ctx, start, "delete", err) }(time.Now())
|
|
|
|
var record *databroker.Record
|
|
err = db.incrementVersion(ctx,
|
|
func(tx *redis.Tx, version int64) error {
|
|
var err error
|
|
record, err = redisGetRecord(ctx, tx, db.cfg.recordType, id)
|
|
if errors.Is(err, redis.Nil) {
|
|
// nothing to do, as the record doesn't exist
|
|
return nil
|
|
} else if err != nil {
|
|
return err
|
|
}
|
|
|
|
// mark it as deleted
|
|
record.DeletedAt = timestamppb.Now()
|
|
record.Version = formatVersion(version)
|
|
|
|
return nil
|
|
},
|
|
func(p redis.Pipeliner, version int64) error {
|
|
err := redisSetRecord(ctx, p, db.cfg.recordType, record)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
// add it to the collection of deleted entries
|
|
p.SAdd(ctx, formatDeletedSetKey(db.cfg.recordType), record.GetId())
|
|
return nil
|
|
})
|
|
return err
|
|
}
|
|
|
|
// Get gets a record.
|
|
func (db *DB) Get(ctx context.Context, id string) (record *databroker.Record, err error) {
|
|
_, span := trace.StartSpan(ctx, "databroker.redis.Get")
|
|
defer span.End()
|
|
defer func(start time.Time) { recordOperation(ctx, start, "get", err) }(time.Now())
|
|
|
|
record, err = redisGetRecord(ctx, db.client, db.cfg.recordType, id)
|
|
return record, err
|
|
}
|
|
|
|
// List lists all the records changed since the sinceVersion. Records are sorted in version order.
|
|
func (db *DB) List(ctx context.Context, sinceVersion string) (records []*databroker.Record, err error) {
|
|
_, span := trace.StartSpan(ctx, "databroker.redis.List")
|
|
defer span.End()
|
|
defer func(start time.Time) { recordOperation(ctx, start, "list", err) }(time.Now())
|
|
|
|
var ids []string
|
|
ids, err = redisListIDsSince(ctx, db.client, db.cfg.recordType, sinceVersion)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
records, err = redisGetRecords(ctx, db.client, db.cfg.recordType, ids)
|
|
return records, err
|
|
}
|
|
|
|
// Put updates a record.
|
|
func (db *DB) Put(ctx context.Context, id string, data *anypb.Any) (err error) {
|
|
_, span := trace.StartSpan(ctx, "databroker.redis.Put")
|
|
defer span.End()
|
|
defer func(start time.Time) { recordOperation(ctx, start, "put", err) }(time.Now())
|
|
|
|
var record *databroker.Record
|
|
err = db.incrementVersion(ctx,
|
|
func(tx *redis.Tx, version int64) error {
|
|
var err error
|
|
record, err = redisGetRecord(ctx, db.client, db.cfg.recordType, id)
|
|
if errors.Is(err, redis.Nil) {
|
|
record = new(databroker.Record)
|
|
record.CreatedAt = timestamppb.Now()
|
|
} else if err != nil {
|
|
return err
|
|
}
|
|
|
|
record.ModifiedAt = timestamppb.Now()
|
|
record.Type = db.cfg.recordType
|
|
record.Id = id
|
|
record.Data = data
|
|
record.Version = formatVersion(version)
|
|
|
|
return nil
|
|
},
|
|
func(p redis.Pipeliner, version int64) error {
|
|
return redisSetRecord(ctx, p, db.cfg.recordType, record)
|
|
})
|
|
return err
|
|
}
|
|
|
|
// Watch returns a channel that is signaled any time the last version is incremented (ie on Put/Delete).
|
|
func (db *DB) Watch(ctx context.Context) <-chan struct{} {
|
|
s := signal.New()
|
|
ch := s.Bind()
|
|
go func() {
|
|
defer s.Unbind(ch)
|
|
defer close(ch)
|
|
|
|
// force a check
|
|
poll := time.NewTicker(watchPollInterval)
|
|
defer poll.Stop()
|
|
|
|
// use pub/sub for quicker notify
|
|
pubsub := db.client.Subscribe(ctx, formatLastVersionChannelKey(db.cfg.recordType))
|
|
defer func() { _ = pubsub.Close() }()
|
|
pubsubCh := pubsub.Channel()
|
|
|
|
var lastVersion int64
|
|
|
|
for {
|
|
v, err := redisGetLastVersion(ctx, db.client, db.cfg.recordType)
|
|
if err != nil {
|
|
log.Error().Err(err).Msg("redis: error retrieving last version")
|
|
} else if v != lastVersion {
|
|
// don't broadcast the first time
|
|
if lastVersion != 0 {
|
|
s.Broadcast()
|
|
}
|
|
lastVersion = v
|
|
}
|
|
|
|
select {
|
|
case <-ctx.Done():
|
|
return
|
|
case <-db.closed:
|
|
return
|
|
case <-poll.C:
|
|
case <-pubsubCh:
|
|
// re-check
|
|
}
|
|
}
|
|
}()
|
|
return ch
|
|
}
|
|
|
|
// incrementVersion increments the last version key, runs the code in `query`, then attempts to commit the code in
|
|
// `commit`. If the last version changes in the interim, we will retry the transaction.
|
|
func (db *DB) incrementVersion(ctx context.Context,
|
|
query func(tx *redis.Tx, version int64) error,
|
|
commit func(p redis.Pipeliner, version int64) error,
|
|
) error {
|
|
// code is modeled on https://pkg.go.dev/github.com/go-redis/redis/v8#example-Client.Watch
|
|
txf := func(tx *redis.Tx) error {
|
|
version, err := redisGetLastVersion(ctx, tx, db.cfg.recordType)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
version++
|
|
|
|
err = query(tx, version)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
// the `commit` code is run in a transaction so that the EXEC cmd will run for the original redis watch
|
|
_, err = tx.TxPipelined(ctx, func(p redis.Pipeliner) error {
|
|
err := commit(p, version)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
p.Set(ctx, formatLastVersionKey(db.cfg.recordType), version, 0)
|
|
p.Publish(ctx, formatLastVersionChannelKey(db.cfg.recordType), version)
|
|
return nil
|
|
})
|
|
return err
|
|
}
|
|
|
|
for i := 0; i < maxTransactionRetries; i++ {
|
|
err := db.client.Watch(ctx, txf, formatLastVersionKey(db.cfg.recordType))
|
|
if errors.Is(err, redis.TxFailedErr) {
|
|
continue // retry
|
|
} else if err != nil {
|
|
return err
|
|
}
|
|
|
|
return nil // tx was successful
|
|
}
|
|
|
|
return ErrExceededMaxRetries
|
|
}
|
|
|
|
func redisGetLastVersion(ctx context.Context, c redis.Cmdable, recordType string) (int64, error) {
|
|
version, err := c.Get(ctx, formatLastVersionKey(recordType)).Int64()
|
|
if errors.Is(err, redis.Nil) {
|
|
version = 0
|
|
} else if err != nil {
|
|
return 0, err
|
|
}
|
|
return version, nil
|
|
}
|
|
|
|
func redisGetRecord(ctx context.Context, c redis.Cmdable, recordType string, id string) (*databroker.Record, error) {
|
|
records, err := redisGetRecords(ctx, c, recordType, []string{id})
|
|
if err != nil {
|
|
return nil, err
|
|
} else if len(records) < 1 {
|
|
return nil, redis.Nil
|
|
}
|
|
return records[0], nil
|
|
}
|
|
|
|
func redisGetRecords(ctx context.Context, c redis.Cmdable, recordType string, ids []string) ([]*databroker.Record, error) {
|
|
if len(ids) == 0 {
|
|
return nil, nil
|
|
}
|
|
|
|
results, err := c.HMGet(ctx, formatRecordsKey(recordType), ids...).Result()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
records := make([]*databroker.Record, 0, len(results))
|
|
for _, result := range results {
|
|
// results are returned as either nil or a string
|
|
if result == nil {
|
|
continue
|
|
}
|
|
rawstr, ok := result.(string)
|
|
if !ok {
|
|
continue
|
|
}
|
|
var record databroker.Record
|
|
err := proto.Unmarshal([]byte(rawstr), &record)
|
|
if err != nil {
|
|
continue
|
|
}
|
|
records = append(records, &record)
|
|
}
|
|
return records, nil
|
|
}
|
|
|
|
func redisListIDsSince(ctx context.Context,
|
|
c redis.Cmdable, recordType string,
|
|
sinceVersion string,
|
|
) ([]string, error) {
|
|
v, err := strconv.ParseInt(sinceVersion, 16, 64)
|
|
if err != nil {
|
|
v = 0
|
|
}
|
|
rng := &redis.ZRangeBy{
|
|
Min: fmt.Sprintf("(%d", v),
|
|
Max: "+inf",
|
|
}
|
|
return c.ZRangeByScore(ctx, formatVersionSetKey(recordType), rng).Result()
|
|
}
|
|
|
|
func redisSetRecord(ctx context.Context, p redis.Pipeliner, recordType string, record *databroker.Record) error {
|
|
v, err := strconv.ParseInt(record.GetVersion(), 16, 64)
|
|
if err != nil {
|
|
v = 0
|
|
}
|
|
|
|
raw, err := proto.Marshal(record)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
// store the record in the hash
|
|
p.HSet(ctx, formatRecordsKey(recordType), record.GetId(), string(raw))
|
|
// set its score for sorting by version
|
|
p.ZAdd(ctx, formatVersionSetKey(recordType), &redis.Z{
|
|
Score: float64(v),
|
|
Member: record.GetId(),
|
|
})
|
|
|
|
return nil
|
|
}
|
|
|
|
// formatDeletedSetKey returns the key of the set that holds the ids of the
// deleted records of recordType: the record type followed by "_deleted_set".
func formatDeletedSetKey(recordType string) string {
	const suffix = "_deleted_set"
	return fmt.Sprintf("%s%s", recordType, suffix)
}
|
|
|
|
// formatLastVersionChannelKey returns the name of the pub/sub channel on
// which new versions of recordType are published: the record type followed
// by "_last_version_ch".
func formatLastVersionChannelKey(recordType string) string {
	const suffix = "_last_version_ch"
	return fmt.Sprintf("%s%s", recordType, suffix)
}
|
|
|
|
// formatLastVersionKey returns the key under which the last version of
// recordType is stored: the record type followed by "_last_version".
func formatLastVersionKey(recordType string) string {
	const suffix = "_last_version"
	return fmt.Sprintf("%s%s", recordType, suffix)
}
|
|
|
|
// formatRecordsKey returns the key of the hash that stores the records of
// recordType. The key is the record type itself, unchanged.
func formatRecordsKey(recordType string) string {
	key := recordType
	return key
}
|
|
|
|
// formatVersion renders version as a decimal string, zero-padded to a
// minimum width of 12 characters.
//
// NOTE(review): redisSetRecord and redisListIDsSince parse version strings
// with base 16 while this function emits base 10. Both sides of every
// comparison go through the same parse, so ordering stays consistent, but
// the sort scores are not the numeric versions — confirm this is intended
// before changing either side.
func formatVersion(version int64) string {
	const width = 12
	return fmt.Sprintf("%0*d", width, version)
}
|
|
|
|
// formatVersionSetKey returns the key of the sorted set that orders the
// record ids of recordType by version: the record type followed by
// "_version_set".
func formatVersionSetKey(recordType string) string {
	const suffix = "_version_set"
	return fmt.Sprintf("%s%s", recordType, suffix)
}
|
|
|
|
func recordOperation(ctx context.Context, startTime time.Time, operation string, err error) {
|
|
metrics.RecordStorageOperation(ctx, &metrics.StorageOperationTags{
|
|
Operation: operation,
|
|
Error: err,
|
|
Backend: Name,
|
|
}, time.Since(startTime))
|
|
}
|