mirror of
https://github.com/pomerium/pomerium.git
synced 2025-05-23 14:07:11 +02:00
pkg/storage: change backend interface to return error (#1131)
Since when storage backend like redis can be fault in many cases, the interface should return error for the caller to handle.
This commit is contained in:
parent
90d95b8c10
commit
aedfbc4c71
7 changed files with 108 additions and 69 deletions
|
@ -86,7 +86,7 @@ func (srv *Server) initVersion() {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get version from storage first.
|
// Get version from storage first.
|
||||||
if r := dbServerVersion.Get(context.Background(), serverVersionKey); r != nil {
|
if r, _ := dbServerVersion.Get(context.Background(), serverVersionKey); r != nil {
|
||||||
var sv databroker.ServerVersion
|
var sv databroker.ServerVersion
|
||||||
if err := ptypes.UnmarshalAny(r.GetData(), &sv); err == nil {
|
if err := ptypes.UnmarshalAny(r.GetData(), &sv); err == nil {
|
||||||
srv.log.Debug().Str("server_version", sv.Version).Msg("got db version from DB")
|
srv.log.Debug().Str("server_version", sv.Version).Msg("got db version from DB")
|
||||||
|
@ -137,8 +137,8 @@ func (srv *Server) Get(ctx context.Context, req *databroker.GetRequest) (*databr
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
record := db.Get(ctx, req.GetId())
|
record, err := db.Get(ctx, req.GetId())
|
||||||
if record == nil {
|
if err != nil {
|
||||||
return nil, status.Error(codes.NotFound, "record not found")
|
return nil, status.Error(codes.NotFound, "record not found")
|
||||||
}
|
}
|
||||||
return &databroker.GetResponse{Record: record}, nil
|
return &databroker.GetResponse{Record: record}, nil
|
||||||
|
@ -156,7 +156,10 @@ func (srv *Server) GetAll(ctx context.Context, req *databroker.GetAllRequest) (*
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
records := db.GetAll(ctx)
|
records, err := db.GetAll(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
var recordVersion string
|
var recordVersion string
|
||||||
for _, record := range records {
|
for _, record := range records {
|
||||||
if record.GetVersion() > recordVersion {
|
if record.GetVersion() > recordVersion {
|
||||||
|
@ -188,8 +191,10 @@ func (srv *Server) Set(ctx context.Context, req *databroker.SetRequest) (*databr
|
||||||
if err := db.Put(ctx, req.GetId(), req.GetData()); err != nil {
|
if err := db.Put(ctx, req.GetId(), req.GetData()); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
record := db.Get(ctx, req.GetId())
|
record, err := db.Get(ctx, req.GetId())
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
return &databroker.SetResponse{
|
return &databroker.SetResponse{
|
||||||
Record: record,
|
Record: record,
|
||||||
ServerVersion: srv.version,
|
ServerVersion: srv.version,
|
||||||
|
@ -220,8 +225,7 @@ func (srv *Server) Sync(req *databroker.SyncRequest, stream databroker.DataBroke
|
||||||
ch := srv.onchange.Bind()
|
ch := srv.onchange.Bind()
|
||||||
defer srv.onchange.Unbind(ch)
|
defer srv.onchange.Unbind(ch)
|
||||||
for {
|
for {
|
||||||
updated := db.List(context.Background(), recordVersion)
|
updated, _ := db.List(context.Background(), recordVersion)
|
||||||
|
|
||||||
if len(updated) > 0 {
|
if len(updated) > 0 {
|
||||||
sort.Slice(updated, func(i, j int) bool {
|
sort.Slice(updated, func(i, j int) bool {
|
||||||
return updated[i].Version < updated[j].Version
|
return updated[i].Version < updated[j].Version
|
||||||
|
|
|
@ -40,13 +40,15 @@ func TestServer_initVersion(t *testing.T) {
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
db, err := srv.getDB(recordTypeServerVersion)
|
db, err := srv.getDB(recordTypeServerVersion)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
r := db.Get(ctx, serverVersionKey)
|
r, err := db.Get(ctx, serverVersionKey)
|
||||||
|
assert.Error(t, err)
|
||||||
assert.Nil(t, r)
|
assert.Nil(t, r)
|
||||||
srvVersion := uuid.New().String()
|
srvVersion := uuid.New().String()
|
||||||
srv.version = srvVersion
|
srv.version = srvVersion
|
||||||
srv.initVersion()
|
srv.initVersion()
|
||||||
assert.Equal(t, srvVersion, srv.version)
|
assert.Equal(t, srvVersion, srv.version)
|
||||||
r = db.Get(ctx, serverVersionKey)
|
r, err = db.Get(ctx, serverVersionKey)
|
||||||
|
require.NoError(t, err)
|
||||||
assert.NotNil(t, r)
|
assert.NotNil(t, r)
|
||||||
var sv databroker.ServerVersion
|
var sv databroker.ServerVersion
|
||||||
assert.NoError(t, ptypes.UnmarshalAny(r.GetData(), &sv))
|
assert.NoError(t, ptypes.UnmarshalAny(r.GetData(), &sv))
|
||||||
|
@ -57,13 +59,15 @@ func TestServer_initVersion(t *testing.T) {
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
db, err := srv.getDB(recordTypeServerVersion)
|
db, err := srv.getDB(recordTypeServerVersion)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
r := db.Get(ctx, serverVersionKey)
|
r, err := db.Get(ctx, serverVersionKey)
|
||||||
|
assert.Error(t, err)
|
||||||
assert.Nil(t, r)
|
assert.Nil(t, r)
|
||||||
|
|
||||||
srv.initVersion()
|
srv.initVersion()
|
||||||
srvVersion := srv.version
|
srvVersion := srv.version
|
||||||
|
|
||||||
r = db.Get(ctx, serverVersionKey)
|
r, err = db.Get(ctx, serverVersionKey)
|
||||||
|
require.NoError(t, err)
|
||||||
assert.NotNil(t, r)
|
assert.NotNil(t, r)
|
||||||
var sv databroker.ServerVersion
|
var sv databroker.ServerVersion
|
||||||
assert.NoError(t, ptypes.UnmarshalAny(r.GetData(), &sv))
|
assert.NoError(t, ptypes.UnmarshalAny(r.GetData(), &sv))
|
||||||
|
|
|
@ -3,6 +3,7 @@ package inmemory
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"sync"
|
"sync"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
|
@ -88,26 +89,26 @@ func (db *DB) Delete(_ context.Context, id string) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get gets a record from the db.
|
// Get gets a record from the db.
|
||||||
func (db *DB) Get(_ context.Context, id string) *databroker.Record {
|
func (db *DB) Get(_ context.Context, id string) (*databroker.Record, error) {
|
||||||
record, ok := db.byID.Get(byIDRecord{Record: &databroker.Record{Id: id}}).(byIDRecord)
|
record, ok := db.byID.Get(byIDRecord{Record: &databroker.Record{Id: id}}).(byIDRecord)
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil
|
return nil, errors.New("not found")
|
||||||
}
|
}
|
||||||
return record.Record
|
return record.Record, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetAll gets all the records in the db.
|
// GetAll gets all the records in the db.
|
||||||
func (db *DB) GetAll(_ context.Context) []*databroker.Record {
|
func (db *DB) GetAll(_ context.Context) ([]*databroker.Record, error) {
|
||||||
var records []*databroker.Record
|
var records []*databroker.Record
|
||||||
db.byID.Ascend(func(item btree.Item) bool {
|
db.byID.Ascend(func(item btree.Item) bool {
|
||||||
records = append(records, item.(byIDRecord).Record)
|
records = append(records, item.(byIDRecord).Record)
|
||||||
return true
|
return true
|
||||||
})
|
})
|
||||||
return records
|
return records, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// List lists all the changes since the given version.
|
// List lists all the changes since the given version.
|
||||||
func (db *DB) List(_ context.Context, sinceVersion string) []*databroker.Record {
|
func (db *DB) List(_ context.Context, sinceVersion string) ([]*databroker.Record, error) {
|
||||||
var records []*databroker.Record
|
var records []*databroker.Record
|
||||||
db.byVersion.AscendGreaterOrEqual(byVersionRecord{Record: &databroker.Record{Version: sinceVersion}}, func(i btree.Item) bool {
|
db.byVersion.AscendGreaterOrEqual(byVersionRecord{Record: &databroker.Record{Version: sinceVersion}}, func(i btree.Item) bool {
|
||||||
record := i.(byVersionRecord)
|
record := i.(byVersionRecord)
|
||||||
|
@ -116,7 +117,7 @@ func (db *DB) List(_ context.Context, sinceVersion string) []*databroker.Record
|
||||||
}
|
}
|
||||||
return true
|
return true
|
||||||
})
|
})
|
||||||
return records
|
return records, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Put replaces or inserts a record in the db.
|
// Put replaces or inserts a record in the db.
|
||||||
|
|
|
@ -7,6 +7,7 @@ import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
"google.golang.org/protobuf/types/known/anypb"
|
"google.golang.org/protobuf/types/known/anypb"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -14,12 +15,15 @@ func TestDB(t *testing.T) {
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
db := NewDB("example", 2)
|
db := NewDB("example", 2)
|
||||||
t.Run("get missing record", func(t *testing.T) {
|
t.Run("get missing record", func(t *testing.T) {
|
||||||
assert.Nil(t, db.Get(ctx, "abcd"))
|
record, err := db.Get(ctx, "abcd")
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.Nil(t, record)
|
||||||
})
|
})
|
||||||
t.Run("get record", func(t *testing.T) {
|
t.Run("get record", func(t *testing.T) {
|
||||||
data := new(anypb.Any)
|
data := new(anypb.Any)
|
||||||
assert.NoError(t, db.Put(ctx, "abcd", data))
|
assert.NoError(t, db.Put(ctx, "abcd", data))
|
||||||
record := db.Get(ctx, "abcd")
|
record, err := db.Get(ctx, "abcd")
|
||||||
|
require.NoError(t, err)
|
||||||
if assert.NotNil(t, record) {
|
if assert.NotNil(t, record) {
|
||||||
assert.NotNil(t, record.CreatedAt)
|
assert.NotNil(t, record.CreatedAt)
|
||||||
assert.Equal(t, data, record.Data)
|
assert.Equal(t, data, record.Data)
|
||||||
|
@ -32,21 +36,26 @@ func TestDB(t *testing.T) {
|
||||||
})
|
})
|
||||||
t.Run("delete record", func(t *testing.T) {
|
t.Run("delete record", func(t *testing.T) {
|
||||||
assert.NoError(t, db.Delete(ctx, "abcd"))
|
assert.NoError(t, db.Delete(ctx, "abcd"))
|
||||||
record := db.Get(ctx, "abcd")
|
record, err := db.Get(ctx, "abcd")
|
||||||
|
require.NoError(t, err)
|
||||||
if assert.NotNil(t, record) {
|
if assert.NotNil(t, record) {
|
||||||
assert.NotNil(t, record.DeletedAt)
|
assert.NotNil(t, record.DeletedAt)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
t.Run("clear deleted", func(t *testing.T) {
|
t.Run("clear deleted", func(t *testing.T) {
|
||||||
db.ClearDeleted(ctx, time.Now().Add(time.Second))
|
db.ClearDeleted(ctx, time.Now().Add(time.Second))
|
||||||
assert.Nil(t, db.Get(ctx, "abcd"))
|
record, err := db.Get(ctx, "abcd")
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.Nil(t, record)
|
||||||
})
|
})
|
||||||
t.Run("keep remaining", func(t *testing.T) {
|
t.Run("keep remaining", func(t *testing.T) {
|
||||||
data := new(anypb.Any)
|
data := new(anypb.Any)
|
||||||
assert.NoError(t, db.Put(ctx, "abcd", data))
|
assert.NoError(t, db.Put(ctx, "abcd", data))
|
||||||
assert.NoError(t, db.Delete(ctx, "abcd"))
|
assert.NoError(t, db.Delete(ctx, "abcd"))
|
||||||
db.ClearDeleted(ctx, time.Now().Add(-10*time.Second))
|
db.ClearDeleted(ctx, time.Now().Add(-10*time.Second))
|
||||||
assert.NotNil(t, db.Get(ctx, "abcd"))
|
record, err := db.Get(ctx, "abcd")
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.NotNil(t, record)
|
||||||
db.ClearDeleted(ctx, time.Now().Add(time.Second))
|
db.ClearDeleted(ctx, time.Now().Add(time.Second))
|
||||||
})
|
})
|
||||||
t.Run("list", func(t *testing.T) {
|
t.Run("list", func(t *testing.T) {
|
||||||
|
@ -55,8 +64,14 @@ func TestDB(t *testing.T) {
|
||||||
assert.NoError(t, db.Put(ctx, fmt.Sprintf("%02d", i), data))
|
assert.NoError(t, db.Put(ctx, fmt.Sprintf("%02d", i), data))
|
||||||
}
|
}
|
||||||
|
|
||||||
assert.Len(t, db.List(ctx, ""), 10)
|
records, err := db.List(ctx, "")
|
||||||
assert.Len(t, db.List(ctx, "00000000000A"), 4)
|
require.NoError(t, err)
|
||||||
assert.Len(t, db.List(ctx, "00000000000F"), 0)
|
assert.Len(t, records, 10)
|
||||||
|
records, err = db.List(ctx, "00000000000A")
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Len(t, records, 4)
|
||||||
|
records, err = db.List(ctx, "00000000000F")
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Len(t, records, 0)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
|
@ -3,7 +3,6 @@ package redis
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"errors"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"strconv"
|
"strconv"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
|
@ -70,8 +69,8 @@ func New(address, recordType string, deletePermanentAfter int64) (*DB, error) {
|
||||||
func (db *DB) Put(ctx context.Context, id string, data *anypb.Any) error {
|
func (db *DB) Put(ctx context.Context, id string, data *anypb.Any) error {
|
||||||
c := db.pool.Get()
|
c := db.pool.Get()
|
||||||
defer c.Close()
|
defer c.Close()
|
||||||
record := db.Get(ctx, id)
|
record, err := db.Get(ctx, id)
|
||||||
if record == nil {
|
if err != nil {
|
||||||
record = new(databroker.Record)
|
record = new(databroker.Record)
|
||||||
record.CreatedAt = ptypes.TimestampNow()
|
record.CreatedAt = ptypes.TimestampNow()
|
||||||
}
|
}
|
||||||
|
@ -97,27 +96,27 @@ func (db *DB) Put(ctx context.Context, id string, data *anypb.Any) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get retrieves a record from redis.
|
// Get retrieves a record from redis.
|
||||||
func (db *DB) Get(_ context.Context, id string) *databroker.Record {
|
func (db *DB) Get(_ context.Context, id string) (*databroker.Record, error) {
|
||||||
c := db.pool.Get()
|
c := db.pool.Get()
|
||||||
defer c.Close()
|
defer c.Close()
|
||||||
|
|
||||||
b, err := redis.Bytes(c.Do("HGET", db.recordType, id))
|
b, err := redis.Bytes(c.Do("HGET", db.recordType, id))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return db.toPbRecord(b)
|
return db.toPbRecord(b)
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetAll retrieves all records from redis.
|
// GetAll retrieves all records from redis.
|
||||||
func (db *DB) GetAll(ctx context.Context) []*databroker.Record {
|
func (db *DB) GetAll(ctx context.Context) ([]*databroker.Record, error) {
|
||||||
return db.getAll(ctx, func(record *databroker.Record) bool { return true })
|
return db.getAll(ctx, func(record *databroker.Record) bool { return true })
|
||||||
}
|
}
|
||||||
|
|
||||||
// List retrieves all records since given version.
|
// List retrieves all records since given version.
|
||||||
//
|
//
|
||||||
// "version" is in hex format, invalid version will be treated as 0.
|
// "version" is in hex format, invalid version will be treated as 0.
|
||||||
func (db *DB) List(ctx context.Context, sinceVersion string) []*databroker.Record {
|
func (db *DB) List(ctx context.Context, sinceVersion string) ([]*databroker.Record, error) {
|
||||||
c := db.pool.Get()
|
c := db.pool.Get()
|
||||||
defer c.Close()
|
defer c.Close()
|
||||||
|
|
||||||
|
@ -128,18 +127,22 @@ func (db *DB) List(ctx context.Context, sinceVersion string) []*databroker.Recor
|
||||||
|
|
||||||
ids, err := redis.Strings(c.Do("ZRANGEBYSCORE", db.versionSet, fmt.Sprintf("(%d", v), "+inf"))
|
ids, err := redis.Strings(c.Do("ZRANGEBYSCORE", db.versionSet, fmt.Sprintf("(%d", v), "+inf"))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
records := make([]*databroker.Record, 0, len(ids))
|
pbRecords := make([]*databroker.Record, 0, len(ids))
|
||||||
for _, id := range ids {
|
for _, id := range ids {
|
||||||
b, err := redis.Bytes(c.Do("HGET", db.recordType, id))
|
b, err := redis.Bytes(c.Do("HGET", db.recordType, id))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
continue
|
return nil, err
|
||||||
}
|
}
|
||||||
records = append(records, db.toPbRecord(b))
|
pbRecord, err := db.toPbRecord(b)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
pbRecords = append(pbRecords, pbRecord)
|
||||||
}
|
}
|
||||||
return records
|
return pbRecords, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Delete sets a record DeletedAt field and set its TTL.
|
// Delete sets a record DeletedAt field and set its TTL.
|
||||||
|
@ -147,9 +150,9 @@ func (db *DB) Delete(ctx context.Context, id string) error {
|
||||||
c := db.pool.Get()
|
c := db.pool.Get()
|
||||||
defer c.Close()
|
defer c.Close()
|
||||||
|
|
||||||
r := db.Get(ctx, id)
|
r, err := db.Get(ctx, id)
|
||||||
if r == nil {
|
if err != nil {
|
||||||
return errors.New("not found")
|
return fmt.Errorf("failed to get record: %w", err)
|
||||||
}
|
}
|
||||||
r.DeletedAt = ptypes.TimestampNow()
|
r.DeletedAt = ptypes.TimestampNow()
|
||||||
r.Version = fmt.Sprintf("%012X", atomic.AddUint64(&db.lastVersion, 1))
|
r.Version = fmt.Sprintf("%012X", atomic.AddUint64(&db.lastVersion, 1))
|
||||||
|
@ -177,8 +180,8 @@ func (db *DB) ClearDeleted(_ context.Context, cutoff time.Time) {
|
||||||
ids, _ := redis.Strings(c.Do("SMEMBERS", db.deletedSet))
|
ids, _ := redis.Strings(c.Do("SMEMBERS", db.deletedSet))
|
||||||
for _, id := range ids {
|
for _, id := range ids {
|
||||||
b, _ := redis.Bytes(c.Do("HGET", db.recordType, id))
|
b, _ := redis.Bytes(c.Do("HGET", db.recordType, id))
|
||||||
record := db.toPbRecord(b)
|
record, err := db.toPbRecord(b)
|
||||||
if record == nil {
|
if err != nil {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -195,7 +198,7 @@ func (db *DB) ClearDeleted(_ context.Context, cutoff time.Time) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (db *DB) getAll(_ context.Context, filter func(record *databroker.Record) bool) []*databroker.Record {
|
func (db *DB) getAll(_ context.Context, filter func(record *databroker.Record) bool) ([]*databroker.Record, error) {
|
||||||
c := db.pool.Get()
|
c := db.pool.Get()
|
||||||
defer c.Close()
|
defer c.Close()
|
||||||
iter := 0
|
iter := 0
|
||||||
|
@ -203,16 +206,16 @@ func (db *DB) getAll(_ context.Context, filter func(record *databroker.Record) b
|
||||||
for {
|
for {
|
||||||
arr, err := redis.Values(c.Do("HSCAN", db.recordType, iter, "MATCH", "*"))
|
arr, err := redis.Values(c.Do("HSCAN", db.recordType, iter, "MATCH", "*"))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
iter, _ = redis.Int(arr[0], nil)
|
iter, _ = redis.Int(arr[0], nil)
|
||||||
pairs, _ := redis.StringMap(arr[1], nil)
|
pairs, _ := redis.StringMap(arr[1], nil)
|
||||||
|
|
||||||
for _, v := range pairs {
|
for _, v := range pairs {
|
||||||
record := db.toPbRecord([]byte(v))
|
record, err := db.toPbRecord([]byte(v))
|
||||||
if record == nil {
|
if err != nil {
|
||||||
continue
|
return nil, err
|
||||||
}
|
}
|
||||||
if filter(record) {
|
if filter(record) {
|
||||||
records = append(records, record)
|
records = append(records, record)
|
||||||
|
@ -224,15 +227,15 @@ func (db *DB) getAll(_ context.Context, filter func(record *databroker.Record) b
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return records
|
return records, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (db *DB) toPbRecord(b []byte) *databroker.Record {
|
func (db *DB) toPbRecord(b []byte) (*databroker.Record, error) {
|
||||||
record := &databroker.Record{}
|
record := &databroker.Record{}
|
||||||
if err := proto.Unmarshal(b, record); err != nil {
|
if err := proto.Unmarshal(b, record); err != nil {
|
||||||
return nil
|
return nil, err
|
||||||
}
|
}
|
||||||
return record
|
return record, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (db *DB) tx(c redis.Conn, commands []map[string][]interface{}) error {
|
func (db *DB) tx(c redis.Conn, commands []map[string][]interface{}) error {
|
||||||
|
|
|
@ -40,12 +40,15 @@ func TestDB(t *testing.T) {
|
||||||
cleanup(c, db, t)
|
cleanup(c, db, t)
|
||||||
|
|
||||||
t.Run("get missing record", func(t *testing.T) {
|
t.Run("get missing record", func(t *testing.T) {
|
||||||
assert.Nil(t, db.Get(ctx, id))
|
record, err := db.Get(ctx, id)
|
||||||
|
assert.Error(t, err)
|
||||||
|
assert.Nil(t, record)
|
||||||
})
|
})
|
||||||
t.Run("get record", func(t *testing.T) {
|
t.Run("get record", func(t *testing.T) {
|
||||||
data := new(anypb.Any)
|
data := new(anypb.Any)
|
||||||
assert.NoError(t, db.Put(ctx, id, data))
|
assert.NoError(t, db.Put(ctx, id, data))
|
||||||
record := db.Get(ctx, id)
|
record, err := db.Get(ctx, id)
|
||||||
|
require.NoError(t, err)
|
||||||
if assert.NotNil(t, record) {
|
if assert.NotNil(t, record) {
|
||||||
assert.NotNil(t, record.CreatedAt)
|
assert.NotNil(t, record.CreatedAt)
|
||||||
assert.Equal(t, data, record.Data)
|
assert.Equal(t, data, record.Data)
|
||||||
|
@ -57,22 +60,29 @@ func TestDB(t *testing.T) {
|
||||||
})
|
})
|
||||||
t.Run("delete record", func(t *testing.T) {
|
t.Run("delete record", func(t *testing.T) {
|
||||||
assert.NoError(t, db.Delete(ctx, id))
|
assert.NoError(t, db.Delete(ctx, id))
|
||||||
record := db.Get(ctx, id)
|
record, err := db.Get(ctx, id)
|
||||||
|
require.NoError(t, err)
|
||||||
require.NotNil(t, record)
|
require.NotNil(t, record)
|
||||||
assert.NotNil(t, record.DeletedAt)
|
assert.NotNil(t, record.DeletedAt)
|
||||||
})
|
})
|
||||||
t.Run("clear deleted", func(t *testing.T) {
|
t.Run("clear deleted", func(t *testing.T) {
|
||||||
db.ClearDeleted(ctx, time.Now().Add(time.Second))
|
db.ClearDeleted(ctx, time.Now().Add(time.Second))
|
||||||
assert.Nil(t, db.Get(ctx, id))
|
record, err := db.Get(ctx, id)
|
||||||
|
assert.Error(t, err)
|
||||||
|
assert.Nil(t, record)
|
||||||
})
|
})
|
||||||
t.Run("get all", func(t *testing.T) {
|
t.Run("get all", func(t *testing.T) {
|
||||||
assert.Len(t, db.GetAll(ctx), 0)
|
records, err := db.GetAll(ctx)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Len(t, records, 0)
|
||||||
data := new(anypb.Any)
|
data := new(anypb.Any)
|
||||||
|
|
||||||
for _, id := range ids {
|
for _, id := range ids {
|
||||||
assert.NoError(t, db.Put(ctx, id, data))
|
assert.NoError(t, db.Put(ctx, id, data))
|
||||||
}
|
}
|
||||||
assert.Len(t, db.GetAll(ctx), len(ids))
|
records, err = db.GetAll(ctx)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Len(t, records, len(ids))
|
||||||
for _, id := range ids {
|
for _, id := range ids {
|
||||||
_, _ = c.Do("DEL", id)
|
_, _ = c.Do("DEL", id)
|
||||||
}
|
}
|
||||||
|
@ -87,12 +97,14 @@ func TestDB(t *testing.T) {
|
||||||
assert.NoError(t, db.Put(ctx, id, data))
|
assert.NoError(t, db.Put(ctx, id, data))
|
||||||
}
|
}
|
||||||
|
|
||||||
assert.Len(t, db.List(ctx, ""), 10)
|
records, err := db.List(ctx, "")
|
||||||
assert.Len(t, db.List(ctx, "00000000000A"), 5)
|
assert.NoError(t, err)
|
||||||
assert.Len(t, db.List(ctx, "00000000000F"), 0)
|
assert.Len(t, records, 10)
|
||||||
|
records, err = db.List(ctx, "00000000000A")
|
||||||
for _, id := range ids {
|
assert.NoError(t, err)
|
||||||
_, _ = c.Do("DEL", id)
|
assert.Len(t, records, 5)
|
||||||
}
|
records, err = db.List(ctx, "00000000000F")
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Len(t, records, 0)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
|
@ -16,13 +16,13 @@ type Backend interface {
|
||||||
Put(ctx context.Context, id string, data *anypb.Any) error
|
Put(ctx context.Context, id string, data *anypb.Any) error
|
||||||
|
|
||||||
// Get is used to retrieve a record.
|
// Get is used to retrieve a record.
|
||||||
Get(ctx context.Context, id string) *databroker.Record
|
Get(ctx context.Context, id string) (*databroker.Record, error)
|
||||||
|
|
||||||
// GetAll is used to retrieve all the records.
|
// GetAll is used to retrieve all the records.
|
||||||
GetAll(ctx context.Context) []*databroker.Record
|
GetAll(ctx context.Context) ([]*databroker.Record, error)
|
||||||
|
|
||||||
// List is used to retrieve all the records since a version.
|
// List is used to retrieve all the records since a version.
|
||||||
List(ctx context.Context, sinceVersion string) []*databroker.Record
|
List(ctx context.Context, sinceVersion string) ([]*databroker.Record, error)
|
||||||
|
|
||||||
// Delete is used to mark a record as deleted.
|
// Delete is used to mark a record as deleted.
|
||||||
Delete(ctx context.Context, id string) error
|
Delete(ctx context.Context, id string) error
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue