mirror of
https://github.com/pomerium/pomerium.git
synced 2025-04-30 02:46:30 +02:00
* chore(deps): bump github.com/golangci/golangci-lint Bumps [github.com/golangci/golangci-lint](https://github.com/golangci/golangci-lint) from 1.50.1 to 1.51.2. - [Release notes](https://github.com/golangci/golangci-lint/releases) - [Changelog](https://github.com/golangci/golangci-lint/blob/master/CHANGELOG.md) - [Commits](https://github.com/golangci/golangci-lint/compare/v1.50.1...v1.51.2) --- updated-dependencies: - dependency-name: github.com/golangci/golangci-lint dependency-type: direct:production update-type: version-update:semver-minor ... Signed-off-by: dependabot[bot] <support@github.com> * fix lint --------- Signed-off-by: dependabot[bot] <support@github.com> Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Co-authored-by: Caleb Doxsey <cdoxsey@pomerium.com>
537 lines
15 KiB
Go
537 lines
15 KiB
Go
// Package redis implements the storage.Backend interface for redis.
|
|
package redis
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"fmt"
|
|
"sort"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/cenkalti/backoff/v4"
|
|
"github.com/go-redis/redis/v8"
|
|
"google.golang.org/protobuf/proto"
|
|
"google.golang.org/protobuf/types/known/timestamppb"
|
|
|
|
"github.com/pomerium/pomerium/internal/log"
|
|
"github.com/pomerium/pomerium/internal/redisutil"
|
|
"github.com/pomerium/pomerium/internal/signal"
|
|
"github.com/pomerium/pomerium/internal/telemetry/metrics"
|
|
"github.com/pomerium/pomerium/internal/telemetry/trace"
|
|
"github.com/pomerium/pomerium/pkg/cryptutil"
|
|
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
|
"github.com/pomerium/pomerium/pkg/storage"
|
|
)
|
|
|
|
const (
|
|
maxTransactionRetries = 100
|
|
watchPollInterval = 30 * time.Second
|
|
|
|
// we rely on transactions in redis, so all redis-cluster keys need to be
|
|
// on the same node. Using a `hash tag` gives us this capability.
|
|
serverVersionKey = redisutil.KeyPrefix + "server_version"
|
|
lastVersionKey = redisutil.KeyPrefix + "last_version"
|
|
lastVersionChKey = redisutil.KeyPrefix + "last_version_ch"
|
|
recordHashKey = redisutil.KeyPrefix + "records"
|
|
recordTypesSetKey = redisutil.KeyPrefix + "record_types"
|
|
changesSetKey = redisutil.KeyPrefix + "changes"
|
|
optionsKey = redisutil.KeyPrefix + "options"
|
|
|
|
recordTypeChangesKeyTpl = redisutil.KeyPrefix + "changes.%s"
|
|
leaseKeyTpl = redisutil.KeyPrefix + "lease.%s"
|
|
)
|
|
|
|
// custom errors
|
|
var (
|
|
ErrExceededMaxRetries = errors.New("redis: transaction reached maximum number of retries")
|
|
)
|
|
|
|
// Backend implements the storage.Backend on top of redis.
|
|
//
|
|
// What's stored:
|
|
//
|
|
// - last_version: an integer recordVersion number
|
|
// - last_version_ch: a PubSub channel for recordVersion number updates
|
|
// - records: a Hash of records. The hash key is {recordType}/{recordID}, the hash value the protobuf record.
|
|
// - changes: a Sorted Set of all the changes. The score is the recordVersion number, the member the protobuf record.
|
|
// - options: a Hash of options. The hash key is {recordType}, the hash value the protobuf options.
|
|
// - changes.{recordType}: a Sorted Set of the changes for a record type. The score is the current time,
|
|
// the value the record id.
|
|
//
|
|
// Records stored in these keys are typically encrypted.
|
|
type Backend struct {
|
|
cfg *config
|
|
|
|
client redis.UniversalClient
|
|
onChange *signal.Signal
|
|
|
|
closeOnce sync.Once
|
|
closed chan struct{}
|
|
}
|
|
|
|
// New creates a new redis storage backend.
|
|
func New(rawURL string, options ...Option) (*Backend, error) {
|
|
ctx := context.TODO()
|
|
cfg := getConfig(options...)
|
|
backend := &Backend{
|
|
cfg: cfg,
|
|
closed: make(chan struct{}),
|
|
onChange: signal.New(),
|
|
}
|
|
var err error
|
|
backend.client, err = redisutil.NewClientFromURL(rawURL, backend.cfg.tls)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
metrics.AddRedisMetrics(backend.client.PoolStats)
|
|
go backend.listenForVersionChanges(ctx)
|
|
if cfg.expiry != 0 {
|
|
go func() {
|
|
ticker := time.NewTicker(time.Minute)
|
|
defer ticker.Stop()
|
|
for {
|
|
select {
|
|
case <-backend.closed:
|
|
return
|
|
case <-ticker.C:
|
|
}
|
|
|
|
backend.removeChangesBefore(ctx, time.Now().Add(-cfg.expiry))
|
|
}
|
|
}()
|
|
}
|
|
return backend, nil
|
|
}
|
|
|
|
// Close closes the underlying redis connection and any watchers.
|
|
func (backend *Backend) Close() error {
|
|
var err error
|
|
backend.closeOnce.Do(func() {
|
|
err = backend.client.Close()
|
|
close(backend.closed)
|
|
})
|
|
return err
|
|
}
|
|
|
|
// Get gets a record from redis.
|
|
func (backend *Backend) Get(ctx context.Context, recordType, id string) (_ *databroker.Record, err error) {
|
|
_, span := trace.StartSpan(ctx, "databroker.redis.Get")
|
|
defer span.End()
|
|
defer func(start time.Time) { recordOperation(ctx, start, "get", err) }(time.Now())
|
|
|
|
key, field := getHashKey(recordType, id)
|
|
cmd := backend.client.HGet(ctx, key, field)
|
|
raw, err := cmd.Result()
|
|
if err == redis.Nil {
|
|
return nil, storage.ErrNotFound
|
|
} else if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
var record databroker.Record
|
|
err = proto.Unmarshal([]byte(raw), &record)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return &record, nil
|
|
}
|
|
|
|
// GetOptions gets the options for the given record type.
|
|
func (backend *Backend) GetOptions(ctx context.Context, recordType string) (*databroker.Options, error) {
|
|
raw, err := backend.client.HGet(ctx, optionsKey, recordType).Result()
|
|
if err == redis.Nil {
|
|
// treat no options as an empty set of options
|
|
return new(databroker.Options), nil
|
|
} else if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
var options databroker.Options
|
|
err = proto.Unmarshal([]byte(raw), &options)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return &options, nil
|
|
}
|
|
|
|
// Lease acquires or renews a lease.
|
|
func (backend *Backend) Lease(ctx context.Context, leaseName, leaseID string, ttl time.Duration) (bool, error) {
|
|
acquired := false
|
|
key := getLeaseKey(leaseName)
|
|
err := backend.client.Watch(ctx, func(tx *redis.Tx) error {
|
|
currentID, err := tx.Get(ctx, key).Result()
|
|
if errors.Is(err, redis.Nil) {
|
|
// lease hasn't been set yet
|
|
} else if err != nil {
|
|
return err
|
|
} else if leaseID != currentID {
|
|
// lease has already been taken
|
|
return nil
|
|
}
|
|
|
|
_, err = tx.Pipelined(ctx, func(p redis.Pipeliner) error {
|
|
if ttl <= 0 {
|
|
p.Del(ctx, key)
|
|
} else {
|
|
p.Set(ctx, key, leaseID, ttl)
|
|
}
|
|
return nil
|
|
})
|
|
if err != nil {
|
|
return err
|
|
}
|
|
acquired = ttl > 0
|
|
return nil
|
|
}, key)
|
|
// if the transaction failed someone else must've acquired the lease
|
|
if errors.Is(err, redis.TxFailedErr) {
|
|
acquired = false
|
|
err = nil
|
|
}
|
|
return acquired, err
|
|
}
|
|
|
|
// ListTypes lists all the known record types.
|
|
func (backend *Backend) ListTypes(ctx context.Context) (types []string, err error) {
|
|
ctx, span := trace.StartSpan(ctx, "databroker.redis.ListTypes")
|
|
defer span.End()
|
|
defer func(start time.Time) { recordOperation(ctx, start, "listTypes", err) }(time.Now())
|
|
|
|
cmd := backend.client.SMembers(ctx, recordTypesSetKey)
|
|
types, err = cmd.Result()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
sort.Strings(types)
|
|
return types, nil
|
|
}
|
|
|
|
// Put puts a record into redis.
|
|
func (backend *Backend) Put(ctx context.Context, records []*databroker.Record) (serverVersion uint64, err error) {
|
|
ctx, span := trace.StartSpan(ctx, "databroker.redis.Put")
|
|
defer span.End()
|
|
defer func(start time.Time) { recordOperation(ctx, start, "put", err) }(time.Now())
|
|
|
|
serverVersion, err = backend.getOrCreateServerVersion(ctx)
|
|
if err != nil {
|
|
return serverVersion, err
|
|
}
|
|
|
|
err = backend.put(ctx, records)
|
|
if err != nil {
|
|
return serverVersion, err
|
|
}
|
|
|
|
recordTypes := map[string]struct{}{}
|
|
for _, record := range records {
|
|
recordTypes[record.GetType()] = struct{}{}
|
|
}
|
|
for recordType := range recordTypes {
|
|
err = backend.enforceOptions(ctx, recordType)
|
|
if err != nil {
|
|
return serverVersion, err
|
|
}
|
|
}
|
|
|
|
return serverVersion, nil
|
|
}
|
|
|
|
// SetOptions sets the options for the given record type.
|
|
func (backend *Backend) SetOptions(ctx context.Context, recordType string, options *databroker.Options) error {
|
|
ctx, span := trace.StartSpan(ctx, "databroker.redis.SetOptions")
|
|
defer span.End()
|
|
|
|
bs, err := proto.Marshal(options)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
// update the options in the hash set
|
|
err = backend.client.HSet(ctx, optionsKey, recordType, bs).Err()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
// possibly re-enforce options
|
|
err = backend.enforceOptions(ctx, recordType)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// Sync returns a record stream of any records changed after the specified recordVersion.
|
|
func (backend *Backend) Sync(
|
|
ctx context.Context,
|
|
recordType string,
|
|
serverVersion, recordVersion uint64,
|
|
) (storage.RecordStream, error) {
|
|
return newSyncRecordStream(ctx, backend, recordType, serverVersion, recordVersion), nil
|
|
}
|
|
|
|
// SyncLatest returns a record stream of all the records. Some records may be returned twice if the are updated while the
|
|
// stream is streaming.
|
|
func (backend *Backend) SyncLatest(
|
|
ctx context.Context,
|
|
recordType string,
|
|
expr storage.FilterExpression,
|
|
) (serverVersion, recordVersion uint64, stream storage.RecordStream, err error) {
|
|
serverVersion, err = backend.getOrCreateServerVersion(ctx)
|
|
if err != nil {
|
|
return serverVersion, recordVersion, nil, err
|
|
}
|
|
|
|
recordVersion, err = backend.client.Get(ctx, lastVersionKey).Uint64()
|
|
if errors.Is(err, redis.Nil) {
|
|
// this happens if there are no records
|
|
} else if err != nil {
|
|
return serverVersion, recordVersion, nil, err
|
|
}
|
|
|
|
stream, err = newSyncLatestRecordStream(ctx, backend, recordType, expr)
|
|
return serverVersion, recordVersion, stream, err
|
|
}
|
|
|
|
func (backend *Backend) put(ctx context.Context, records []*databroker.Record) error {
|
|
return backend.incrementVersion(ctx,
|
|
func(tx *redis.Tx, version uint64) error {
|
|
for i, record := range records {
|
|
record.ModifiedAt = timestamppb.Now()
|
|
record.Version = version + uint64(i)
|
|
}
|
|
return nil
|
|
},
|
|
func(p redis.Pipeliner, version uint64) error {
|
|
for i, record := range records {
|
|
bs, err := proto.Marshal(record)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
key, field := getHashKey(record.GetType(), record.GetId())
|
|
if record.DeletedAt != nil {
|
|
p.HDel(ctx, key, field)
|
|
} else {
|
|
p.HSet(ctx, key, field, bs)
|
|
p.ZAdd(ctx, getRecordTypeChangesKey(record.GetType()), &redis.Z{
|
|
Score: float64(record.GetModifiedAt().GetSeconds()) + float64(i)/float64(len(records)),
|
|
Member: record.GetId(),
|
|
})
|
|
}
|
|
p.ZAdd(ctx, changesSetKey, &redis.Z{
|
|
Score: float64(version) + float64(i),
|
|
Member: bs,
|
|
})
|
|
p.SAdd(ctx, recordTypesSetKey, record.GetType())
|
|
}
|
|
return nil
|
|
})
|
|
}
|
|
|
|
// enforceOptions enforces the options for the given record type.
|
|
func (backend *Backend) enforceOptions(ctx context.Context, recordType string) error {
|
|
ctx, span := trace.StartSpan(ctx, "databroker.redis.enforceOptions")
|
|
defer span.End()
|
|
|
|
options, err := backend.GetOptions(ctx, recordType)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
// nothing to do if capacity isn't set
|
|
if options.Capacity == nil {
|
|
return nil
|
|
}
|
|
|
|
key := getRecordTypeChangesKey(recordType)
|
|
|
|
// find oldest records that exceed the capacity
|
|
recordIDs, err := backend.client.ZRevRange(ctx, key, int64(*options.Capacity), -1).Result()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
// for each record, delete it
|
|
for _, recordID := range recordIDs {
|
|
record, err := backend.Get(ctx, recordType, recordID)
|
|
if err == nil {
|
|
// mark the record as deleted and re-submit
|
|
record.DeletedAt = timestamppb.Now()
|
|
err = backend.put(ctx, []*databroker.Record{record})
|
|
if err != nil {
|
|
return err
|
|
}
|
|
} else if errors.Is(err, storage.ErrNotFound) {
|
|
// ignore
|
|
} else if err != nil {
|
|
return err
|
|
}
|
|
|
|
// remove the member from the collection
|
|
_, err = backend.client.ZRem(ctx, key, recordID).Result()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// incrementVersion increments the last recordVersion key, runs the code in `query`, then attempts to commit the code in
|
|
// `commit`. If the last recordVersion changes in the interim, we will retry the transaction.
|
|
func (backend *Backend) incrementVersion(ctx context.Context,
|
|
query func(tx *redis.Tx, recordVersion uint64) error,
|
|
commit func(p redis.Pipeliner, recordVersion uint64) error,
|
|
) error {
|
|
// code is modeled on https://pkg.go.dev/github.com/go-redis/redis/v8#example-Client.Watch
|
|
txf := func(tx *redis.Tx) error {
|
|
version, err := tx.Get(ctx, lastVersionKey).Uint64()
|
|
if errors.Is(err, redis.Nil) {
|
|
version = 0
|
|
} else if err != nil {
|
|
return err
|
|
}
|
|
version++
|
|
|
|
err = query(tx, version)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
// the `commit` code is run in a transaction so that the EXEC cmd will run for the original redis watch
|
|
_, err = tx.TxPipelined(ctx, func(p redis.Pipeliner) error {
|
|
err := commit(p, version)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
p.Set(ctx, lastVersionKey, version, 0)
|
|
p.Publish(ctx, lastVersionChKey, version)
|
|
return nil
|
|
})
|
|
return err
|
|
}
|
|
|
|
bo := backoff.NewExponentialBackOff()
|
|
bo.MaxElapsedTime = 0
|
|
for i := 0; i < maxTransactionRetries; i++ {
|
|
err := backend.client.Watch(ctx, txf, lastVersionKey)
|
|
if errors.Is(err, redis.TxFailedErr) {
|
|
select {
|
|
case <-ctx.Done():
|
|
return ctx.Err()
|
|
case <-time.After(bo.NextBackOff()):
|
|
}
|
|
continue // retry
|
|
} else if err != nil {
|
|
return err
|
|
}
|
|
|
|
return nil // tx was successful
|
|
}
|
|
|
|
return ErrExceededMaxRetries
|
|
}
|
|
|
|
func (backend *Backend) listenForVersionChanges(ctx context.Context) {
|
|
ctx, cancel := context.WithCancel(ctx)
|
|
go func() {
|
|
<-backend.closed
|
|
cancel()
|
|
}()
|
|
|
|
bo := backoff.NewExponentialBackOff()
|
|
bo.MaxElapsedTime = 0
|
|
|
|
outer:
|
|
for {
|
|
pubsub := backend.client.Subscribe(ctx, lastVersionChKey)
|
|
for {
|
|
msg, err := pubsub.Receive(ctx)
|
|
if err != nil {
|
|
_ = pubsub.Close()
|
|
select {
|
|
case <-ctx.Done():
|
|
return
|
|
case <-time.After(bo.NextBackOff()):
|
|
}
|
|
continue outer
|
|
}
|
|
bo.Reset()
|
|
|
|
switch msg.(type) {
|
|
case *redis.Message:
|
|
backend.onChange.Broadcast(ctx)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
func (backend *Backend) removeChangesBefore(ctx context.Context, cutoff time.Time) {
|
|
for {
|
|
cmd := backend.client.ZRangeByScore(ctx, changesSetKey, &redis.ZRangeBy{
|
|
Min: "-inf",
|
|
Max: "+inf",
|
|
Offset: 0,
|
|
Count: 1,
|
|
})
|
|
results, err := cmd.Result()
|
|
if err != nil {
|
|
log.Error(ctx).Err(err).Msg("redis: error retrieving changes for expiration")
|
|
return
|
|
}
|
|
|
|
// nothing left to do
|
|
if len(results) == 0 {
|
|
return
|
|
}
|
|
|
|
var record databroker.Record
|
|
err = proto.Unmarshal([]byte(results[0]), &record)
|
|
if err != nil {
|
|
log.Warn(ctx).Err(err).Msg("redis: invalid record detected")
|
|
record.ModifiedAt = timestamppb.New(cutoff.Add(-time.Second)) // set the modified so will delete it
|
|
}
|
|
|
|
// if the record's modified timestamp is after the cutoff, we're all done, so break
|
|
if record.GetModifiedAt().AsTime().After(cutoff) {
|
|
break
|
|
}
|
|
|
|
// remove the record
|
|
err = backend.client.ZRem(ctx, changesSetKey, results[0]).Err()
|
|
if err != nil {
|
|
log.Error(ctx).Err(err).Msg("redis: error removing member")
|
|
return
|
|
}
|
|
}
|
|
}
|
|
|
|
func (backend *Backend) getOrCreateServerVersion(ctx context.Context) (serverVersion uint64, err error) {
|
|
serverVersion, err = backend.client.Get(ctx, serverVersionKey).Uint64()
|
|
// if the server version hasn't been set yet, set it to a random value and immediately retrieve it
|
|
// this should properly handle a data race by only setting the key if it doesn't already exist
|
|
if errors.Is(err, redis.Nil) {
|
|
_, _ = backend.client.SetNX(ctx, serverVersionKey, cryptutil.NewRandomUInt64(), 0).Result()
|
|
serverVersion, err = backend.client.Get(ctx, serverVersionKey).Uint64()
|
|
}
|
|
if err != nil {
|
|
return 0, fmt.Errorf("redis: error retrieving server version: %w", err)
|
|
}
|
|
return serverVersion, err
|
|
}
|
|
|
|
func getLeaseKey(leaseName string) string {
|
|
return fmt.Sprintf(leaseKeyTpl, leaseName)
|
|
}
|
|
|
|
func getRecordTypeChangesKey(recordType string) string {
|
|
return fmt.Sprintf(recordTypeChangesKeyTpl, recordType)
|
|
}
|
|
|
|
func getHashKey(recordType, id string) (key, field string) {
|
|
return recordHashKey, fmt.Sprintf("%s/%s", recordType, id)
|
|
}
|