mirror of
https://github.com/pomerium/pomerium.git
synced 2025-05-19 03:57:17 +02:00
redis storage backend (#1082)
* pkg/storage: add redis storage backend * pkg/storage/redis: set record create time correctly * pkg/storage/redis: add docs * pkg/storage/redis: run test with redis tag only * pkg/storage/redis: use localhost * pkg/storage/redis: use 127.0.0.1 * pkg/storage/redis: honor REDIS_URL env * .github/workflows: add missing config for redis service * .github/workflows: map redis ports to host * pkg/storage/redis: use proto marshaler instead of json one * pkg/storage/redis: use better implementation By using redis supported datastructure: - Hash for storing record - Sorted set for storing by version - Set for storing deleted ids List operation will be now performed in O(log(N)+M) instead of O(N) like previous implementation. * pkg/storage/redis: add tx to wrap redis transaction * pkg/storage/redis: set record type in New * pkg/storage/redis: make sure tx commands appear in right order * pkg/storage/redis: make deletePermanentAfter as argument * pkg/storage/redis: make sure version is incremented when deleting * pkg/storage/redis: fix linter * pkg/storage/redis: fix cmd construction
This commit is contained in:
parent
858077b3b6
commit
26f099b49d
5 changed files with 372 additions and 0 deletions
24
.github/workflows/test.yaml
vendored
24
.github/workflows/test.yaml
vendored
|
@ -125,3 +125,27 @@ jobs:
|
||||||
|
|
||||||
- name: test
|
- name: test
|
||||||
run: go test -v ./integration/...
|
run: go test -v ./integration/...
|
||||||
|
|
||||||
|
storage-backend-test-redis:
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
services:
|
||||||
|
redis:
|
||||||
|
image: redis
|
||||||
|
options: >-
|
||||||
|
--health-cmd "redis-cli ping"
|
||||||
|
--health-interval 10s
|
||||||
|
--health-timeout 5s
|
||||||
|
--health-retries 5
|
||||||
|
ports:
|
||||||
|
- 6379:6379
|
||||||
|
steps:
|
||||||
|
- name: install go
|
||||||
|
uses: actions/setup-go@v1
|
||||||
|
with:
|
||||||
|
go-version: 1.14.x
|
||||||
|
|
||||||
|
- name: checkout code
|
||||||
|
uses: actions/checkout@v2
|
||||||
|
|
||||||
|
- name: test
|
||||||
|
run: go test -v -tags redis ./pkg/storage/redis/...
|
||||||
|
|
1
go.mod
1
go.mod
|
@ -17,6 +17,7 @@ require (
|
||||||
github.com/go-chi/chi v4.1.2+incompatible
|
github.com/go-chi/chi v4.1.2+incompatible
|
||||||
github.com/golang/mock v1.4.3
|
github.com/golang/mock v1.4.3
|
||||||
github.com/golang/protobuf v1.4.2
|
github.com/golang/protobuf v1.4.2
|
||||||
|
github.com/gomodule/redigo v1.8.2
|
||||||
github.com/google/btree v1.0.0
|
github.com/google/btree v1.0.0
|
||||||
github.com/google/go-cmp v0.5.0
|
github.com/google/go-cmp v0.5.0
|
||||||
github.com/google/go-jsonnet v0.16.0
|
github.com/google/go-jsonnet v0.16.0
|
||||||
|
|
3
go.sum
3
go.sum
|
@ -206,6 +206,9 @@ github.com/golang/protobuf v1.4.1/go.mod h1:U8fpvMrcmy5pZrNK1lt4xCsGvpyWQ/VVv6QD
|
||||||
github.com/golang/protobuf v1.4.2 h1:+Z5KGCizgyZCbGh1KZqA0fcLLkwbsjIzS4aV2v7wJX0=
|
github.com/golang/protobuf v1.4.2 h1:+Z5KGCizgyZCbGh1KZqA0fcLLkwbsjIzS4aV2v7wJX0=
|
||||||
github.com/golang/protobuf v1.4.2/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw735rRwI=
|
github.com/golang/protobuf v1.4.2/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw735rRwI=
|
||||||
github.com/golang/snappy v0.0.0-20180518054509-2e65f85255db/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q=
|
github.com/golang/snappy v0.0.0-20180518054509-2e65f85255db/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q=
|
||||||
|
github.com/gomodule/redigo v1.8.2 h1:H5XSIre1MB5NbPYFp+i1NBbb5qN1W8Y8YAQoAYbkm8k=
|
||||||
|
github.com/gomodule/redigo v1.8.2/go.mod h1:P9dn9mFrCBvWhGE1wpxx6fgq7BAeLBk+UUUzlpkBYO0=
|
||||||
|
github.com/gomodule/redigo/redis v0.0.0-do-not-use h1:J7XIp6Kau0WoyT4JtXHT3Ei0gA1KkSc6bc87j9v9WIo=
|
||||||
github.com/google/btree v0.0.0-20180813153112-4030bb1f1f0c/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ=
|
github.com/google/btree v0.0.0-20180813153112-4030bb1f1f0c/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ=
|
||||||
github.com/google/btree v1.0.0 h1:0udJVsspx3VBr5FwtLhQQtuAsVc79tTq0ocGIPAU6qo=
|
github.com/google/btree v1.0.0 h1:0udJVsspx3VBr5FwtLhQQtuAsVc79tTq0ocGIPAU6qo=
|
||||||
github.com/google/btree v1.0.0/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ=
|
github.com/google/btree v1.0.0/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ=
|
||||||
|
|
246
pkg/storage/redis/redis.go
Normal file
246
pkg/storage/redis/redis.go
Normal file
|
@ -0,0 +1,246 @@
|
||||||
|
// Package redis is the redis database, implements storage.Backend interface.
|
||||||
|
package redis
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"strconv"
|
||||||
|
"sync/atomic"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/golang/protobuf/proto"
|
||||||
|
"github.com/golang/protobuf/ptypes"
|
||||||
|
"github.com/gomodule/redigo/redis"
|
||||||
|
"google.golang.org/protobuf/types/known/anypb"
|
||||||
|
|
||||||
|
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
||||||
|
"github.com/pomerium/pomerium/pkg/storage"
|
||||||
|
)
|
||||||
|
|
||||||
|
var _ storage.Backend = (*DB)(nil)
|
||||||
|
|
||||||
|
// DB wraps redis conn to interact with redis server.
|
||||||
|
type DB struct {
|
||||||
|
pool *redis.Pool
|
||||||
|
deletePermanentlyAfter int64
|
||||||
|
recordType string
|
||||||
|
lastVersion uint64
|
||||||
|
versionSet string
|
||||||
|
deletedSet string
|
||||||
|
}
|
||||||
|
|
||||||
|
// New returns new DB instance.
|
||||||
|
func New(address, recordType string, deletePermanentAfter int64) (*DB, error) {
|
||||||
|
db := &DB{
|
||||||
|
pool: &redis.Pool{
|
||||||
|
Wait: true,
|
||||||
|
DialContext: func(ctx context.Context) (redis.Conn, error) {
|
||||||
|
ctx, cancelFn := context.WithTimeout(ctx, 5*time.Second)
|
||||||
|
defer cancelFn()
|
||||||
|
c, err := redis.DialContext(ctx, "tcp", address)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf(`redis.DialURL(): %w`, err)
|
||||||
|
}
|
||||||
|
return c, nil
|
||||||
|
},
|
||||||
|
TestOnBorrow: func(c redis.Conn, t time.Time) error {
|
||||||
|
if time.Since(t) < time.Minute {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
_, err := c.Do("PING")
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf(`c.Do("PING"): %w`, err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
},
|
||||||
|
deletePermanentlyAfter: deletePermanentAfter,
|
||||||
|
recordType: recordType,
|
||||||
|
versionSet: "version_set",
|
||||||
|
deletedSet: "deleted_set",
|
||||||
|
}
|
||||||
|
return db, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Put sets new record for given id with input data.
|
||||||
|
func (db *DB) Put(ctx context.Context, id string, data *anypb.Any) error {
|
||||||
|
c := db.pool.Get()
|
||||||
|
defer c.Close()
|
||||||
|
record := db.Get(ctx, id)
|
||||||
|
if record == nil {
|
||||||
|
record = new(databroker.Record)
|
||||||
|
record.CreatedAt = ptypes.TimestampNow()
|
||||||
|
}
|
||||||
|
|
||||||
|
record.Data = data
|
||||||
|
record.ModifiedAt = ptypes.TimestampNow()
|
||||||
|
record.Type = db.recordType
|
||||||
|
record.Id = id
|
||||||
|
record.Version = fmt.Sprintf("%012X", atomic.AddUint64(&db.lastVersion, 1))
|
||||||
|
b, err := proto.Marshal(record)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
cmds := []map[string][]interface{}{
|
||||||
|
{"MULTI": nil},
|
||||||
|
{"HSET": {db.recordType, id, string(b)}},
|
||||||
|
{"ZADD": {db.versionSet, db.lastVersion, id}},
|
||||||
|
}
|
||||||
|
if err := db.tx(c, cmds); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get retrieves a record from redis.
|
||||||
|
func (db *DB) Get(_ context.Context, id string) *databroker.Record {
|
||||||
|
c := db.pool.Get()
|
||||||
|
defer c.Close()
|
||||||
|
|
||||||
|
b, err := redis.Bytes(c.Do("HGET", db.recordType, id))
|
||||||
|
if err != nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return db.toPbRecord(b)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetAll retrieves all records from redis.
|
||||||
|
func (db *DB) GetAll(ctx context.Context) []*databroker.Record {
|
||||||
|
return db.getAll(ctx, func(record *databroker.Record) bool { return true })
|
||||||
|
}
|
||||||
|
|
||||||
|
// List retrieves all records since given version.
|
||||||
|
//
|
||||||
|
// "version" is in hex format, invalid version will be treated as 0.
|
||||||
|
func (db *DB) List(ctx context.Context, sinceVersion string) []*databroker.Record {
|
||||||
|
c := db.pool.Get()
|
||||||
|
defer c.Close()
|
||||||
|
|
||||||
|
v, err := strconv.ParseUint(sinceVersion, 16, 64)
|
||||||
|
if err != nil {
|
||||||
|
v = 0
|
||||||
|
}
|
||||||
|
|
||||||
|
ids, err := redis.Strings(c.Do("ZRANGEBYSCORE", db.versionSet, fmt.Sprintf("(%d", v), "+inf"))
|
||||||
|
if err != nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
records := make([]*databroker.Record, 0, len(ids))
|
||||||
|
for _, id := range ids {
|
||||||
|
b, err := redis.Bytes(c.Do("HGET", db.recordType, id))
|
||||||
|
if err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
records = append(records, db.toPbRecord(b))
|
||||||
|
}
|
||||||
|
return records
|
||||||
|
}
|
||||||
|
|
||||||
|
// Delete sets a record DeletedAt field and set its TTL.
|
||||||
|
func (db *DB) Delete(ctx context.Context, id string) error {
|
||||||
|
c := db.pool.Get()
|
||||||
|
defer c.Close()
|
||||||
|
|
||||||
|
r := db.Get(ctx, id)
|
||||||
|
if r == nil {
|
||||||
|
return errors.New("not found")
|
||||||
|
}
|
||||||
|
r.DeletedAt = ptypes.TimestampNow()
|
||||||
|
r.Version = fmt.Sprintf("%012X", atomic.AddUint64(&db.lastVersion, 1))
|
||||||
|
b, err := proto.Marshal(r)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
cmds := []map[string][]interface{}{
|
||||||
|
{"MULTI": nil},
|
||||||
|
{"HSET": {db.recordType, id, string(b)}},
|
||||||
|
{"SADD": {db.deletedSet, id}},
|
||||||
|
{"ZADD": {db.versionSet, db.lastVersion, id}},
|
||||||
|
}
|
||||||
|
if err := db.tx(c, cmds); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ClearDeleted clears all the currently deleted records older than the given cutoff.
|
||||||
|
func (db *DB) ClearDeleted(_ context.Context, cutoff time.Time) {
|
||||||
|
c := db.pool.Get()
|
||||||
|
defer c.Close()
|
||||||
|
|
||||||
|
ids, _ := redis.Strings(c.Do("SMEMBERS", db.deletedSet))
|
||||||
|
for _, id := range ids {
|
||||||
|
b, _ := redis.Bytes(c.Do("HGET", db.recordType, id))
|
||||||
|
record := db.toPbRecord(b)
|
||||||
|
if record == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
ts, _ := ptypes.Timestamp(record.DeletedAt)
|
||||||
|
if ts.Before(cutoff) {
|
||||||
|
cmds := []map[string][]interface{}{
|
||||||
|
{"MULTI": nil},
|
||||||
|
{"HDEL": {db.recordType, id}},
|
||||||
|
{"ZREM": {db.versionSet, id}},
|
||||||
|
{"SREM": {db.deletedSet, id}},
|
||||||
|
}
|
||||||
|
_ = db.tx(c, cmds)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (db *DB) getAll(_ context.Context, filter func(record *databroker.Record) bool) []*databroker.Record {
|
||||||
|
c := db.pool.Get()
|
||||||
|
defer c.Close()
|
||||||
|
iter := 0
|
||||||
|
records := make([]*databroker.Record, 0)
|
||||||
|
for {
|
||||||
|
arr, err := redis.Values(c.Do("HSCAN", db.recordType, iter, "MATCH", "*"))
|
||||||
|
if err != nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
iter, _ = redis.Int(arr[0], nil)
|
||||||
|
pairs, _ := redis.StringMap(arr[1], nil)
|
||||||
|
|
||||||
|
for _, v := range pairs {
|
||||||
|
record := db.toPbRecord([]byte(v))
|
||||||
|
if record == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if filter(record) {
|
||||||
|
records = append(records, record)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if iter == 0 {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return records
|
||||||
|
}
|
||||||
|
|
||||||
|
func (db *DB) toPbRecord(b []byte) *databroker.Record {
|
||||||
|
record := &databroker.Record{}
|
||||||
|
if err := proto.Unmarshal(b, record); err != nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return record
|
||||||
|
}
|
||||||
|
|
||||||
|
func (db *DB) tx(c redis.Conn, commands []map[string][]interface{}) error {
|
||||||
|
for _, m := range commands {
|
||||||
|
for cmd, args := range m {
|
||||||
|
if err := c.Send(cmd, args...); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err := c.Do("EXEC")
|
||||||
|
return err
|
||||||
|
}
|
98
pkg/storage/redis/redis_test.go
Normal file
98
pkg/storage/redis/redis_test.go
Normal file
|
@ -0,0 +1,98 @@
|
||||||
|
// +build redis
|
||||||
|
|
||||||
|
package redis
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/gomodule/redigo/redis"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
"google.golang.org/protobuf/types/known/anypb"
|
||||||
|
)
|
||||||
|
|
||||||
|
func cleanup(c redis.Conn, db *DB, t *testing.T) {
|
||||||
|
require.NoError(t, c.Send("MULTI"))
|
||||||
|
require.NoError(t, c.Send("DEL", db.recordType))
|
||||||
|
require.NoError(t, c.Send("DEL", db.versionSet))
|
||||||
|
require.NoError(t, c.Send("DEL", db.deletedSet))
|
||||||
|
_, err := c.Do("EXEC")
|
||||||
|
require.NoError(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDB(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
address := ":6379"
|
||||||
|
if redisURL := os.Getenv("REDIS_URL"); redisURL != "" {
|
||||||
|
address = redisURL
|
||||||
|
}
|
||||||
|
db, err := New(address, "record_type", int64(time.Hour.Seconds()))
|
||||||
|
require.NoError(t, err)
|
||||||
|
ids := []string{"a", "b", "c"}
|
||||||
|
id := ids[0]
|
||||||
|
c := db.pool.Get()
|
||||||
|
defer c.Close()
|
||||||
|
|
||||||
|
cleanup(c, db, t)
|
||||||
|
|
||||||
|
t.Run("get missing record", func(t *testing.T) {
|
||||||
|
assert.Nil(t, db.Get(ctx, id))
|
||||||
|
})
|
||||||
|
t.Run("get record", func(t *testing.T) {
|
||||||
|
data := new(anypb.Any)
|
||||||
|
assert.NoError(t, db.Put(ctx, id, data))
|
||||||
|
record := db.Get(ctx, id)
|
||||||
|
if assert.NotNil(t, record) {
|
||||||
|
assert.NotNil(t, record.CreatedAt)
|
||||||
|
assert.Equal(t, data, record.Data)
|
||||||
|
assert.Nil(t, record.DeletedAt)
|
||||||
|
assert.Equal(t, "a", record.Id)
|
||||||
|
assert.NotNil(t, record.ModifiedAt)
|
||||||
|
assert.Equal(t, "000000000001", record.Version)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
t.Run("delete record", func(t *testing.T) {
|
||||||
|
assert.NoError(t, db.Delete(ctx, id))
|
||||||
|
record := db.Get(ctx, id)
|
||||||
|
require.NotNil(t, record)
|
||||||
|
assert.NotNil(t, record.DeletedAt)
|
||||||
|
})
|
||||||
|
t.Run("clear deleted", func(t *testing.T) {
|
||||||
|
db.ClearDeleted(ctx, time.Now().Add(time.Second))
|
||||||
|
assert.Nil(t, db.Get(ctx, id))
|
||||||
|
})
|
||||||
|
t.Run("get all", func(t *testing.T) {
|
||||||
|
assert.Len(t, db.GetAll(ctx), 0)
|
||||||
|
data := new(anypb.Any)
|
||||||
|
|
||||||
|
for _, id := range ids {
|
||||||
|
assert.NoError(t, db.Put(ctx, id, data))
|
||||||
|
}
|
||||||
|
assert.Len(t, db.GetAll(ctx), len(ids))
|
||||||
|
for _, id := range ids {
|
||||||
|
_, _ = c.Do("DEL", id)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
t.Run("list", func(t *testing.T) {
|
||||||
|
cleanup(c, db, t)
|
||||||
|
ids := make([]string, 0, 10)
|
||||||
|
for i := 0; i < 10; i++ {
|
||||||
|
id := fmt.Sprintf("%02d", i)
|
||||||
|
ids = append(ids, id)
|
||||||
|
data := new(anypb.Any)
|
||||||
|
assert.NoError(t, db.Put(ctx, id, data))
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.Len(t, db.List(ctx, ""), 10)
|
||||||
|
assert.Len(t, db.List(ctx, "00000000000A"), 5)
|
||||||
|
assert.Len(t, db.List(ctx, "00000000000F"), 0)
|
||||||
|
|
||||||
|
for _, id := range ids {
|
||||||
|
_, _ = c.Do("DEL", id)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
Loading…
Add table
Add a link
Reference in a new issue