pkg/storage: change backend interface to return error (#1131)

Since when storage backend like redis can be fault in many cases, the
interface should return error for the caller to handle.
This commit is contained in:
Cuong Manh Le 2020-07-24 09:02:37 +07:00 committed by GitHub
parent 90d95b8c10
commit aedfbc4c71
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
7 changed files with 108 additions and 69 deletions

View file

@ -86,7 +86,7 @@ func (srv *Server) initVersion() {
}
// Get version from storage first.
if r := dbServerVersion.Get(context.Background(), serverVersionKey); r != nil {
if r, _ := dbServerVersion.Get(context.Background(), serverVersionKey); r != nil {
var sv databroker.ServerVersion
if err := ptypes.UnmarshalAny(r.GetData(), &sv); err == nil {
srv.log.Debug().Str("server_version", sv.Version).Msg("got db version from DB")
@ -137,8 +137,8 @@ func (srv *Server) Get(ctx context.Context, req *databroker.GetRequest) (*databr
if err != nil {
return nil, err
}
record := db.Get(ctx, req.GetId())
if record == nil {
record, err := db.Get(ctx, req.GetId())
if err != nil {
return nil, status.Error(codes.NotFound, "record not found")
}
return &databroker.GetResponse{Record: record}, nil
@ -156,7 +156,10 @@ func (srv *Server) GetAll(ctx context.Context, req *databroker.GetAllRequest) (*
if err != nil {
return nil, err
}
records := db.GetAll(ctx)
records, err := db.GetAll(ctx)
if err != nil {
return nil, err
}
var recordVersion string
for _, record := range records {
if record.GetVersion() > recordVersion {
@ -188,8 +191,10 @@ func (srv *Server) Set(ctx context.Context, req *databroker.SetRequest) (*databr
if err := db.Put(ctx, req.GetId(), req.GetData()); err != nil {
return nil, err
}
record := db.Get(ctx, req.GetId())
record, err := db.Get(ctx, req.GetId())
if err != nil {
return nil, err
}
return &databroker.SetResponse{
Record: record,
ServerVersion: srv.version,
@ -220,8 +225,7 @@ func (srv *Server) Sync(req *databroker.SyncRequest, stream databroker.DataBroke
ch := srv.onchange.Bind()
defer srv.onchange.Unbind(ch)
for {
updated := db.List(context.Background(), recordVersion)
updated, _ := db.List(context.Background(), recordVersion)
if len(updated) > 0 {
sort.Slice(updated, func(i, j int) bool {
return updated[i].Version < updated[j].Version

View file

@ -40,13 +40,15 @@ func TestServer_initVersion(t *testing.T) {
ctx := context.Background()
db, err := srv.getDB(recordTypeServerVersion)
require.NoError(t, err)
r := db.Get(ctx, serverVersionKey)
r, err := db.Get(ctx, serverVersionKey)
assert.Error(t, err)
assert.Nil(t, r)
srvVersion := uuid.New().String()
srv.version = srvVersion
srv.initVersion()
assert.Equal(t, srvVersion, srv.version)
r = db.Get(ctx, serverVersionKey)
r, err = db.Get(ctx, serverVersionKey)
require.NoError(t, err)
assert.NotNil(t, r)
var sv databroker.ServerVersion
assert.NoError(t, ptypes.UnmarshalAny(r.GetData(), &sv))
@ -57,13 +59,15 @@ func TestServer_initVersion(t *testing.T) {
ctx := context.Background()
db, err := srv.getDB(recordTypeServerVersion)
require.NoError(t, err)
r := db.Get(ctx, serverVersionKey)
r, err := db.Get(ctx, serverVersionKey)
assert.Error(t, err)
assert.Nil(t, r)
srv.initVersion()
srvVersion := srv.version
r = db.Get(ctx, serverVersionKey)
r, err = db.Get(ctx, serverVersionKey)
require.NoError(t, err)
assert.NotNil(t, r)
var sv databroker.ServerVersion
assert.NoError(t, ptypes.UnmarshalAny(r.GetData(), &sv))

View file

@ -3,6 +3,7 @@ package inmemory
import (
"context"
"errors"
"fmt"
"sync"
"sync/atomic"
@ -88,26 +89,26 @@ func (db *DB) Delete(_ context.Context, id string) error {
}
// Get gets a record from the db.
func (db *DB) Get(_ context.Context, id string) *databroker.Record {
func (db *DB) Get(_ context.Context, id string) (*databroker.Record, error) {
record, ok := db.byID.Get(byIDRecord{Record: &databroker.Record{Id: id}}).(byIDRecord)
if !ok {
return nil
return nil, errors.New("not found")
}
return record.Record
return record.Record, nil
}
// GetAll gets all the records in the db.
func (db *DB) GetAll(_ context.Context) []*databroker.Record {
func (db *DB) GetAll(_ context.Context) ([]*databroker.Record, error) {
var records []*databroker.Record
db.byID.Ascend(func(item btree.Item) bool {
records = append(records, item.(byIDRecord).Record)
return true
})
return records
return records, nil
}
// List lists all the changes since the given version.
func (db *DB) List(_ context.Context, sinceVersion string) []*databroker.Record {
func (db *DB) List(_ context.Context, sinceVersion string) ([]*databroker.Record, error) {
var records []*databroker.Record
db.byVersion.AscendGreaterOrEqual(byVersionRecord{Record: &databroker.Record{Version: sinceVersion}}, func(i btree.Item) bool {
record := i.(byVersionRecord)
@ -116,7 +117,7 @@ func (db *DB) List(_ context.Context, sinceVersion string) []*databroker.Record
}
return true
})
return records
return records, nil
}
// Put replaces or inserts a record in the db.

View file

@ -7,6 +7,7 @@ import (
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"google.golang.org/protobuf/types/known/anypb"
)
@ -14,12 +15,15 @@ func TestDB(t *testing.T) {
ctx := context.Background()
db := NewDB("example", 2)
t.Run("get missing record", func(t *testing.T) {
assert.Nil(t, db.Get(ctx, "abcd"))
record, err := db.Get(ctx, "abcd")
require.Error(t, err)
assert.Nil(t, record)
})
t.Run("get record", func(t *testing.T) {
data := new(anypb.Any)
assert.NoError(t, db.Put(ctx, "abcd", data))
record := db.Get(ctx, "abcd")
record, err := db.Get(ctx, "abcd")
require.NoError(t, err)
if assert.NotNil(t, record) {
assert.NotNil(t, record.CreatedAt)
assert.Equal(t, data, record.Data)
@ -32,21 +36,26 @@ func TestDB(t *testing.T) {
})
t.Run("delete record", func(t *testing.T) {
assert.NoError(t, db.Delete(ctx, "abcd"))
record := db.Get(ctx, "abcd")
record, err := db.Get(ctx, "abcd")
require.NoError(t, err)
if assert.NotNil(t, record) {
assert.NotNil(t, record.DeletedAt)
}
})
t.Run("clear deleted", func(t *testing.T) {
db.ClearDeleted(ctx, time.Now().Add(time.Second))
assert.Nil(t, db.Get(ctx, "abcd"))
record, err := db.Get(ctx, "abcd")
require.Error(t, err)
assert.Nil(t, record)
})
t.Run("keep remaining", func(t *testing.T) {
data := new(anypb.Any)
assert.NoError(t, db.Put(ctx, "abcd", data))
assert.NoError(t, db.Delete(ctx, "abcd"))
db.ClearDeleted(ctx, time.Now().Add(-10*time.Second))
assert.NotNil(t, db.Get(ctx, "abcd"))
record, err := db.Get(ctx, "abcd")
require.NoError(t, err)
assert.NotNil(t, record)
db.ClearDeleted(ctx, time.Now().Add(time.Second))
})
t.Run("list", func(t *testing.T) {
@ -55,8 +64,14 @@ func TestDB(t *testing.T) {
assert.NoError(t, db.Put(ctx, fmt.Sprintf("%02d", i), data))
}
assert.Len(t, db.List(ctx, ""), 10)
assert.Len(t, db.List(ctx, "00000000000A"), 4)
assert.Len(t, db.List(ctx, "00000000000F"), 0)
records, err := db.List(ctx, "")
require.NoError(t, err)
assert.Len(t, records, 10)
records, err = db.List(ctx, "00000000000A")
require.NoError(t, err)
assert.Len(t, records, 4)
records, err = db.List(ctx, "00000000000F")
require.NoError(t, err)
assert.Len(t, records, 0)
})
}

View file

@ -3,7 +3,6 @@ package redis
import (
"context"
"errors"
"fmt"
"strconv"
"sync/atomic"
@ -70,8 +69,8 @@ func New(address, recordType string, deletePermanentAfter int64) (*DB, error) {
func (db *DB) Put(ctx context.Context, id string, data *anypb.Any) error {
c := db.pool.Get()
defer c.Close()
record := db.Get(ctx, id)
if record == nil {
record, err := db.Get(ctx, id)
if err != nil {
record = new(databroker.Record)
record.CreatedAt = ptypes.TimestampNow()
}
@ -97,27 +96,27 @@ func (db *DB) Put(ctx context.Context, id string, data *anypb.Any) error {
}
// Get retrieves a record from redis.
func (db *DB) Get(_ context.Context, id string) *databroker.Record {
func (db *DB) Get(_ context.Context, id string) (*databroker.Record, error) {
c := db.pool.Get()
defer c.Close()
b, err := redis.Bytes(c.Do("HGET", db.recordType, id))
if err != nil {
return nil
return nil, err
}
return db.toPbRecord(b)
}
// GetAll retrieves all records from redis.
func (db *DB) GetAll(ctx context.Context) []*databroker.Record {
func (db *DB) GetAll(ctx context.Context) ([]*databroker.Record, error) {
return db.getAll(ctx, func(record *databroker.Record) bool { return true })
}
// List retrieves all records since given version.
//
// "version" is in hex format, invalid version will be treated as 0.
func (db *DB) List(ctx context.Context, sinceVersion string) []*databroker.Record {
func (db *DB) List(ctx context.Context, sinceVersion string) ([]*databroker.Record, error) {
c := db.pool.Get()
defer c.Close()
@ -128,18 +127,22 @@ func (db *DB) List(ctx context.Context, sinceVersion string) []*databroker.Recor
ids, err := redis.Strings(c.Do("ZRANGEBYSCORE", db.versionSet, fmt.Sprintf("(%d", v), "+inf"))
if err != nil {
return nil
return nil, err
}
records := make([]*databroker.Record, 0, len(ids))
pbRecords := make([]*databroker.Record, 0, len(ids))
for _, id := range ids {
b, err := redis.Bytes(c.Do("HGET", db.recordType, id))
if err != nil {
continue
return nil, err
}
records = append(records, db.toPbRecord(b))
pbRecord, err := db.toPbRecord(b)
if err != nil {
return nil, err
}
pbRecords = append(pbRecords, pbRecord)
}
return records
return pbRecords, nil
}
// Delete sets a record DeletedAt field and set its TTL.
@ -147,9 +150,9 @@ func (db *DB) Delete(ctx context.Context, id string) error {
c := db.pool.Get()
defer c.Close()
r := db.Get(ctx, id)
if r == nil {
return errors.New("not found")
r, err := db.Get(ctx, id)
if err != nil {
return fmt.Errorf("failed to get record: %w", err)
}
r.DeletedAt = ptypes.TimestampNow()
r.Version = fmt.Sprintf("%012X", atomic.AddUint64(&db.lastVersion, 1))
@ -177,8 +180,8 @@ func (db *DB) ClearDeleted(_ context.Context, cutoff time.Time) {
ids, _ := redis.Strings(c.Do("SMEMBERS", db.deletedSet))
for _, id := range ids {
b, _ := redis.Bytes(c.Do("HGET", db.recordType, id))
record := db.toPbRecord(b)
if record == nil {
record, err := db.toPbRecord(b)
if err != nil {
continue
}
@ -195,7 +198,7 @@ func (db *DB) ClearDeleted(_ context.Context, cutoff time.Time) {
}
}
func (db *DB) getAll(_ context.Context, filter func(record *databroker.Record) bool) []*databroker.Record {
func (db *DB) getAll(_ context.Context, filter func(record *databroker.Record) bool) ([]*databroker.Record, error) {
c := db.pool.Get()
defer c.Close()
iter := 0
@ -203,16 +206,16 @@ func (db *DB) getAll(_ context.Context, filter func(record *databroker.Record) b
for {
arr, err := redis.Values(c.Do("HSCAN", db.recordType, iter, "MATCH", "*"))
if err != nil {
return nil
return nil, err
}
iter, _ = redis.Int(arr[0], nil)
pairs, _ := redis.StringMap(arr[1], nil)
for _, v := range pairs {
record := db.toPbRecord([]byte(v))
if record == nil {
continue
record, err := db.toPbRecord([]byte(v))
if err != nil {
return nil, err
}
if filter(record) {
records = append(records, record)
@ -224,15 +227,15 @@ func (db *DB) getAll(_ context.Context, filter func(record *databroker.Record) b
}
}
return records
return records, nil
}
func (db *DB) toPbRecord(b []byte) *databroker.Record {
func (db *DB) toPbRecord(b []byte) (*databroker.Record, error) {
record := &databroker.Record{}
if err := proto.Unmarshal(b, record); err != nil {
return nil
return nil, err
}
return record
return record, nil
}
func (db *DB) tx(c redis.Conn, commands []map[string][]interface{}) error {

View file

@ -40,12 +40,15 @@ func TestDB(t *testing.T) {
cleanup(c, db, t)
t.Run("get missing record", func(t *testing.T) {
assert.Nil(t, db.Get(ctx, id))
record, err := db.Get(ctx, id)
assert.Error(t, err)
assert.Nil(t, record)
})
t.Run("get record", func(t *testing.T) {
data := new(anypb.Any)
assert.NoError(t, db.Put(ctx, id, data))
record := db.Get(ctx, id)
record, err := db.Get(ctx, id)
require.NoError(t, err)
if assert.NotNil(t, record) {
assert.NotNil(t, record.CreatedAt)
assert.Equal(t, data, record.Data)
@ -57,22 +60,29 @@ func TestDB(t *testing.T) {
})
t.Run("delete record", func(t *testing.T) {
assert.NoError(t, db.Delete(ctx, id))
record := db.Get(ctx, id)
record, err := db.Get(ctx, id)
require.NoError(t, err)
require.NotNil(t, record)
assert.NotNil(t, record.DeletedAt)
})
t.Run("clear deleted", func(t *testing.T) {
db.ClearDeleted(ctx, time.Now().Add(time.Second))
assert.Nil(t, db.Get(ctx, id))
record, err := db.Get(ctx, id)
assert.Error(t, err)
assert.Nil(t, record)
})
t.Run("get all", func(t *testing.T) {
assert.Len(t, db.GetAll(ctx), 0)
records, err := db.GetAll(ctx)
assert.NoError(t, err)
assert.Len(t, records, 0)
data := new(anypb.Any)
for _, id := range ids {
assert.NoError(t, db.Put(ctx, id, data))
}
assert.Len(t, db.GetAll(ctx), len(ids))
records, err = db.GetAll(ctx)
assert.NoError(t, err)
assert.Len(t, records, len(ids))
for _, id := range ids {
_, _ = c.Do("DEL", id)
}
@ -87,12 +97,14 @@ func TestDB(t *testing.T) {
assert.NoError(t, db.Put(ctx, id, data))
}
assert.Len(t, db.List(ctx, ""), 10)
assert.Len(t, db.List(ctx, "00000000000A"), 5)
assert.Len(t, db.List(ctx, "00000000000F"), 0)
for _, id := range ids {
_, _ = c.Do("DEL", id)
}
records, err := db.List(ctx, "")
assert.NoError(t, err)
assert.Len(t, records, 10)
records, err = db.List(ctx, "00000000000A")
assert.NoError(t, err)
assert.Len(t, records, 5)
records, err = db.List(ctx, "00000000000F")
assert.NoError(t, err)
assert.Len(t, records, 0)
})
}

View file

@ -16,13 +16,13 @@ type Backend interface {
Put(ctx context.Context, id string, data *anypb.Any) error
// Get is used to retrieve a record.
Get(ctx context.Context, id string) *databroker.Record
Get(ctx context.Context, id string) (*databroker.Record, error)
// GetAll is used to retrieve all the records.
GetAll(ctx context.Context) []*databroker.Record
GetAll(ctx context.Context) ([]*databroker.Record, error)
// List is used to retrieve all the records since a version.
List(ctx context.Context, sinceVersion string) []*databroker.Record
List(ctx context.Context, sinceVersion string) ([]*databroker.Record, error)
// Delete is used to mark a record as deleted.
Delete(ctx context.Context, id string) error