pkg/storage: change backend interface to return error (#1131)

Since when storage backend like redis can be fault in many cases, the
interface should return error for the caller to handle.
This commit is contained in:
Cuong Manh Le 2020-07-24 09:02:37 +07:00 committed by GitHub
parent 90d95b8c10
commit aedfbc4c71
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
7 changed files with 108 additions and 69 deletions

View file

@ -86,7 +86,7 @@ func (srv *Server) initVersion() {
} }
// Get version from storage first. // Get version from storage first.
if r := dbServerVersion.Get(context.Background(), serverVersionKey); r != nil { if r, _ := dbServerVersion.Get(context.Background(), serverVersionKey); r != nil {
var sv databroker.ServerVersion var sv databroker.ServerVersion
if err := ptypes.UnmarshalAny(r.GetData(), &sv); err == nil { if err := ptypes.UnmarshalAny(r.GetData(), &sv); err == nil {
srv.log.Debug().Str("server_version", sv.Version).Msg("got db version from DB") srv.log.Debug().Str("server_version", sv.Version).Msg("got db version from DB")
@ -137,8 +137,8 @@ func (srv *Server) Get(ctx context.Context, req *databroker.GetRequest) (*databr
if err != nil { if err != nil {
return nil, err return nil, err
} }
record := db.Get(ctx, req.GetId()) record, err := db.Get(ctx, req.GetId())
if record == nil { if err != nil {
return nil, status.Error(codes.NotFound, "record not found") return nil, status.Error(codes.NotFound, "record not found")
} }
return &databroker.GetResponse{Record: record}, nil return &databroker.GetResponse{Record: record}, nil
@ -156,7 +156,10 @@ func (srv *Server) GetAll(ctx context.Context, req *databroker.GetAllRequest) (*
if err != nil { if err != nil {
return nil, err return nil, err
} }
records := db.GetAll(ctx) records, err := db.GetAll(ctx)
if err != nil {
return nil, err
}
var recordVersion string var recordVersion string
for _, record := range records { for _, record := range records {
if record.GetVersion() > recordVersion { if record.GetVersion() > recordVersion {
@ -188,8 +191,10 @@ func (srv *Server) Set(ctx context.Context, req *databroker.SetRequest) (*databr
if err := db.Put(ctx, req.GetId(), req.GetData()); err != nil { if err := db.Put(ctx, req.GetId(), req.GetData()); err != nil {
return nil, err return nil, err
} }
record := db.Get(ctx, req.GetId()) record, err := db.Get(ctx, req.GetId())
if err != nil {
return nil, err
}
return &databroker.SetResponse{ return &databroker.SetResponse{
Record: record, Record: record,
ServerVersion: srv.version, ServerVersion: srv.version,
@ -220,8 +225,7 @@ func (srv *Server) Sync(req *databroker.SyncRequest, stream databroker.DataBroke
ch := srv.onchange.Bind() ch := srv.onchange.Bind()
defer srv.onchange.Unbind(ch) defer srv.onchange.Unbind(ch)
for { for {
updated := db.List(context.Background(), recordVersion) updated, _ := db.List(context.Background(), recordVersion)
if len(updated) > 0 { if len(updated) > 0 {
sort.Slice(updated, func(i, j int) bool { sort.Slice(updated, func(i, j int) bool {
return updated[i].Version < updated[j].Version return updated[i].Version < updated[j].Version

View file

@ -40,13 +40,15 @@ func TestServer_initVersion(t *testing.T) {
ctx := context.Background() ctx := context.Background()
db, err := srv.getDB(recordTypeServerVersion) db, err := srv.getDB(recordTypeServerVersion)
require.NoError(t, err) require.NoError(t, err)
r := db.Get(ctx, serverVersionKey) r, err := db.Get(ctx, serverVersionKey)
assert.Error(t, err)
assert.Nil(t, r) assert.Nil(t, r)
srvVersion := uuid.New().String() srvVersion := uuid.New().String()
srv.version = srvVersion srv.version = srvVersion
srv.initVersion() srv.initVersion()
assert.Equal(t, srvVersion, srv.version) assert.Equal(t, srvVersion, srv.version)
r = db.Get(ctx, serverVersionKey) r, err = db.Get(ctx, serverVersionKey)
require.NoError(t, err)
assert.NotNil(t, r) assert.NotNil(t, r)
var sv databroker.ServerVersion var sv databroker.ServerVersion
assert.NoError(t, ptypes.UnmarshalAny(r.GetData(), &sv)) assert.NoError(t, ptypes.UnmarshalAny(r.GetData(), &sv))
@ -57,13 +59,15 @@ func TestServer_initVersion(t *testing.T) {
ctx := context.Background() ctx := context.Background()
db, err := srv.getDB(recordTypeServerVersion) db, err := srv.getDB(recordTypeServerVersion)
require.NoError(t, err) require.NoError(t, err)
r := db.Get(ctx, serverVersionKey) r, err := db.Get(ctx, serverVersionKey)
assert.Error(t, err)
assert.Nil(t, r) assert.Nil(t, r)
srv.initVersion() srv.initVersion()
srvVersion := srv.version srvVersion := srv.version
r = db.Get(ctx, serverVersionKey) r, err = db.Get(ctx, serverVersionKey)
require.NoError(t, err)
assert.NotNil(t, r) assert.NotNil(t, r)
var sv databroker.ServerVersion var sv databroker.ServerVersion
assert.NoError(t, ptypes.UnmarshalAny(r.GetData(), &sv)) assert.NoError(t, ptypes.UnmarshalAny(r.GetData(), &sv))

View file

@ -3,6 +3,7 @@ package inmemory
import ( import (
"context" "context"
"errors"
"fmt" "fmt"
"sync" "sync"
"sync/atomic" "sync/atomic"
@ -88,26 +89,26 @@ func (db *DB) Delete(_ context.Context, id string) error {
} }
// Get gets a record from the db. // Get gets a record from the db.
func (db *DB) Get(_ context.Context, id string) *databroker.Record { func (db *DB) Get(_ context.Context, id string) (*databroker.Record, error) {
record, ok := db.byID.Get(byIDRecord{Record: &databroker.Record{Id: id}}).(byIDRecord) record, ok := db.byID.Get(byIDRecord{Record: &databroker.Record{Id: id}}).(byIDRecord)
if !ok { if !ok {
return nil return nil, errors.New("not found")
} }
return record.Record return record.Record, nil
} }
// GetAll gets all the records in the db. // GetAll gets all the records in the db.
func (db *DB) GetAll(_ context.Context) []*databroker.Record { func (db *DB) GetAll(_ context.Context) ([]*databroker.Record, error) {
var records []*databroker.Record var records []*databroker.Record
db.byID.Ascend(func(item btree.Item) bool { db.byID.Ascend(func(item btree.Item) bool {
records = append(records, item.(byIDRecord).Record) records = append(records, item.(byIDRecord).Record)
return true return true
}) })
return records return records, nil
} }
// List lists all the changes since the given version. // List lists all the changes since the given version.
func (db *DB) List(_ context.Context, sinceVersion string) []*databroker.Record { func (db *DB) List(_ context.Context, sinceVersion string) ([]*databroker.Record, error) {
var records []*databroker.Record var records []*databroker.Record
db.byVersion.AscendGreaterOrEqual(byVersionRecord{Record: &databroker.Record{Version: sinceVersion}}, func(i btree.Item) bool { db.byVersion.AscendGreaterOrEqual(byVersionRecord{Record: &databroker.Record{Version: sinceVersion}}, func(i btree.Item) bool {
record := i.(byVersionRecord) record := i.(byVersionRecord)
@ -116,7 +117,7 @@ func (db *DB) List(_ context.Context, sinceVersion string) []*databroker.Record
} }
return true return true
}) })
return records return records, nil
} }
// Put replaces or inserts a record in the db. // Put replaces or inserts a record in the db.

View file

@ -7,6 +7,7 @@ import (
"time" "time"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"google.golang.org/protobuf/types/known/anypb" "google.golang.org/protobuf/types/known/anypb"
) )
@ -14,12 +15,15 @@ func TestDB(t *testing.T) {
ctx := context.Background() ctx := context.Background()
db := NewDB("example", 2) db := NewDB("example", 2)
t.Run("get missing record", func(t *testing.T) { t.Run("get missing record", func(t *testing.T) {
assert.Nil(t, db.Get(ctx, "abcd")) record, err := db.Get(ctx, "abcd")
require.Error(t, err)
assert.Nil(t, record)
}) })
t.Run("get record", func(t *testing.T) { t.Run("get record", func(t *testing.T) {
data := new(anypb.Any) data := new(anypb.Any)
assert.NoError(t, db.Put(ctx, "abcd", data)) assert.NoError(t, db.Put(ctx, "abcd", data))
record := db.Get(ctx, "abcd") record, err := db.Get(ctx, "abcd")
require.NoError(t, err)
if assert.NotNil(t, record) { if assert.NotNil(t, record) {
assert.NotNil(t, record.CreatedAt) assert.NotNil(t, record.CreatedAt)
assert.Equal(t, data, record.Data) assert.Equal(t, data, record.Data)
@ -32,21 +36,26 @@ func TestDB(t *testing.T) {
}) })
t.Run("delete record", func(t *testing.T) { t.Run("delete record", func(t *testing.T) {
assert.NoError(t, db.Delete(ctx, "abcd")) assert.NoError(t, db.Delete(ctx, "abcd"))
record := db.Get(ctx, "abcd") record, err := db.Get(ctx, "abcd")
require.NoError(t, err)
if assert.NotNil(t, record) { if assert.NotNil(t, record) {
assert.NotNil(t, record.DeletedAt) assert.NotNil(t, record.DeletedAt)
} }
}) })
t.Run("clear deleted", func(t *testing.T) { t.Run("clear deleted", func(t *testing.T) {
db.ClearDeleted(ctx, time.Now().Add(time.Second)) db.ClearDeleted(ctx, time.Now().Add(time.Second))
assert.Nil(t, db.Get(ctx, "abcd")) record, err := db.Get(ctx, "abcd")
require.Error(t, err)
assert.Nil(t, record)
}) })
t.Run("keep remaining", func(t *testing.T) { t.Run("keep remaining", func(t *testing.T) {
data := new(anypb.Any) data := new(anypb.Any)
assert.NoError(t, db.Put(ctx, "abcd", data)) assert.NoError(t, db.Put(ctx, "abcd", data))
assert.NoError(t, db.Delete(ctx, "abcd")) assert.NoError(t, db.Delete(ctx, "abcd"))
db.ClearDeleted(ctx, time.Now().Add(-10*time.Second)) db.ClearDeleted(ctx, time.Now().Add(-10*time.Second))
assert.NotNil(t, db.Get(ctx, "abcd")) record, err := db.Get(ctx, "abcd")
require.NoError(t, err)
assert.NotNil(t, record)
db.ClearDeleted(ctx, time.Now().Add(time.Second)) db.ClearDeleted(ctx, time.Now().Add(time.Second))
}) })
t.Run("list", func(t *testing.T) { t.Run("list", func(t *testing.T) {
@ -55,8 +64,14 @@ func TestDB(t *testing.T) {
assert.NoError(t, db.Put(ctx, fmt.Sprintf("%02d", i), data)) assert.NoError(t, db.Put(ctx, fmt.Sprintf("%02d", i), data))
} }
assert.Len(t, db.List(ctx, ""), 10) records, err := db.List(ctx, "")
assert.Len(t, db.List(ctx, "00000000000A"), 4) require.NoError(t, err)
assert.Len(t, db.List(ctx, "00000000000F"), 0) assert.Len(t, records, 10)
records, err = db.List(ctx, "00000000000A")
require.NoError(t, err)
assert.Len(t, records, 4)
records, err = db.List(ctx, "00000000000F")
require.NoError(t, err)
assert.Len(t, records, 0)
}) })
} }

View file

@ -3,7 +3,6 @@ package redis
import ( import (
"context" "context"
"errors"
"fmt" "fmt"
"strconv" "strconv"
"sync/atomic" "sync/atomic"
@ -70,8 +69,8 @@ func New(address, recordType string, deletePermanentAfter int64) (*DB, error) {
func (db *DB) Put(ctx context.Context, id string, data *anypb.Any) error { func (db *DB) Put(ctx context.Context, id string, data *anypb.Any) error {
c := db.pool.Get() c := db.pool.Get()
defer c.Close() defer c.Close()
record := db.Get(ctx, id) record, err := db.Get(ctx, id)
if record == nil { if err != nil {
record = new(databroker.Record) record = new(databroker.Record)
record.CreatedAt = ptypes.TimestampNow() record.CreatedAt = ptypes.TimestampNow()
} }
@ -97,27 +96,27 @@ func (db *DB) Put(ctx context.Context, id string, data *anypb.Any) error {
} }
// Get retrieves a record from redis. // Get retrieves a record from redis.
func (db *DB) Get(_ context.Context, id string) *databroker.Record { func (db *DB) Get(_ context.Context, id string) (*databroker.Record, error) {
c := db.pool.Get() c := db.pool.Get()
defer c.Close() defer c.Close()
b, err := redis.Bytes(c.Do("HGET", db.recordType, id)) b, err := redis.Bytes(c.Do("HGET", db.recordType, id))
if err != nil { if err != nil {
return nil return nil, err
} }
return db.toPbRecord(b) return db.toPbRecord(b)
} }
// GetAll retrieves all records from redis. // GetAll retrieves all records from redis.
func (db *DB) GetAll(ctx context.Context) []*databroker.Record { func (db *DB) GetAll(ctx context.Context) ([]*databroker.Record, error) {
return db.getAll(ctx, func(record *databroker.Record) bool { return true }) return db.getAll(ctx, func(record *databroker.Record) bool { return true })
} }
// List retrieves all records since given version. // List retrieves all records since given version.
// //
// "version" is in hex format, invalid version will be treated as 0. // "version" is in hex format, invalid version will be treated as 0.
func (db *DB) List(ctx context.Context, sinceVersion string) []*databroker.Record { func (db *DB) List(ctx context.Context, sinceVersion string) ([]*databroker.Record, error) {
c := db.pool.Get() c := db.pool.Get()
defer c.Close() defer c.Close()
@ -128,18 +127,22 @@ func (db *DB) List(ctx context.Context, sinceVersion string) []*databroker.Recor
ids, err := redis.Strings(c.Do("ZRANGEBYSCORE", db.versionSet, fmt.Sprintf("(%d", v), "+inf")) ids, err := redis.Strings(c.Do("ZRANGEBYSCORE", db.versionSet, fmt.Sprintf("(%d", v), "+inf"))
if err != nil { if err != nil {
return nil return nil, err
} }
records := make([]*databroker.Record, 0, len(ids)) pbRecords := make([]*databroker.Record, 0, len(ids))
for _, id := range ids { for _, id := range ids {
b, err := redis.Bytes(c.Do("HGET", db.recordType, id)) b, err := redis.Bytes(c.Do("HGET", db.recordType, id))
if err != nil { if err != nil {
continue return nil, err
} }
records = append(records, db.toPbRecord(b)) pbRecord, err := db.toPbRecord(b)
if err != nil {
return nil, err
}
pbRecords = append(pbRecords, pbRecord)
} }
return records return pbRecords, nil
} }
// Delete sets a record DeletedAt field and set its TTL. // Delete sets a record DeletedAt field and set its TTL.
@ -147,9 +150,9 @@ func (db *DB) Delete(ctx context.Context, id string) error {
c := db.pool.Get() c := db.pool.Get()
defer c.Close() defer c.Close()
r := db.Get(ctx, id) r, err := db.Get(ctx, id)
if r == nil { if err != nil {
return errors.New("not found") return fmt.Errorf("failed to get record: %w", err)
} }
r.DeletedAt = ptypes.TimestampNow() r.DeletedAt = ptypes.TimestampNow()
r.Version = fmt.Sprintf("%012X", atomic.AddUint64(&db.lastVersion, 1)) r.Version = fmt.Sprintf("%012X", atomic.AddUint64(&db.lastVersion, 1))
@ -177,8 +180,8 @@ func (db *DB) ClearDeleted(_ context.Context, cutoff time.Time) {
ids, _ := redis.Strings(c.Do("SMEMBERS", db.deletedSet)) ids, _ := redis.Strings(c.Do("SMEMBERS", db.deletedSet))
for _, id := range ids { for _, id := range ids {
b, _ := redis.Bytes(c.Do("HGET", db.recordType, id)) b, _ := redis.Bytes(c.Do("HGET", db.recordType, id))
record := db.toPbRecord(b) record, err := db.toPbRecord(b)
if record == nil { if err != nil {
continue continue
} }
@ -195,7 +198,7 @@ func (db *DB) ClearDeleted(_ context.Context, cutoff time.Time) {
} }
} }
func (db *DB) getAll(_ context.Context, filter func(record *databroker.Record) bool) []*databroker.Record { func (db *DB) getAll(_ context.Context, filter func(record *databroker.Record) bool) ([]*databroker.Record, error) {
c := db.pool.Get() c := db.pool.Get()
defer c.Close() defer c.Close()
iter := 0 iter := 0
@ -203,16 +206,16 @@ func (db *DB) getAll(_ context.Context, filter func(record *databroker.Record) b
for { for {
arr, err := redis.Values(c.Do("HSCAN", db.recordType, iter, "MATCH", "*")) arr, err := redis.Values(c.Do("HSCAN", db.recordType, iter, "MATCH", "*"))
if err != nil { if err != nil {
return nil return nil, err
} }
iter, _ = redis.Int(arr[0], nil) iter, _ = redis.Int(arr[0], nil)
pairs, _ := redis.StringMap(arr[1], nil) pairs, _ := redis.StringMap(arr[1], nil)
for _, v := range pairs { for _, v := range pairs {
record := db.toPbRecord([]byte(v)) record, err := db.toPbRecord([]byte(v))
if record == nil { if err != nil {
continue return nil, err
} }
if filter(record) { if filter(record) {
records = append(records, record) records = append(records, record)
@ -224,15 +227,15 @@ func (db *DB) getAll(_ context.Context, filter func(record *databroker.Record) b
} }
} }
return records return records, nil
} }
func (db *DB) toPbRecord(b []byte) *databroker.Record { func (db *DB) toPbRecord(b []byte) (*databroker.Record, error) {
record := &databroker.Record{} record := &databroker.Record{}
if err := proto.Unmarshal(b, record); err != nil { if err := proto.Unmarshal(b, record); err != nil {
return nil return nil, err
} }
return record return record, nil
} }
func (db *DB) tx(c redis.Conn, commands []map[string][]interface{}) error { func (db *DB) tx(c redis.Conn, commands []map[string][]interface{}) error {

View file

@ -40,12 +40,15 @@ func TestDB(t *testing.T) {
cleanup(c, db, t) cleanup(c, db, t)
t.Run("get missing record", func(t *testing.T) { t.Run("get missing record", func(t *testing.T) {
assert.Nil(t, db.Get(ctx, id)) record, err := db.Get(ctx, id)
assert.Error(t, err)
assert.Nil(t, record)
}) })
t.Run("get record", func(t *testing.T) { t.Run("get record", func(t *testing.T) {
data := new(anypb.Any) data := new(anypb.Any)
assert.NoError(t, db.Put(ctx, id, data)) assert.NoError(t, db.Put(ctx, id, data))
record := db.Get(ctx, id) record, err := db.Get(ctx, id)
require.NoError(t, err)
if assert.NotNil(t, record) { if assert.NotNil(t, record) {
assert.NotNil(t, record.CreatedAt) assert.NotNil(t, record.CreatedAt)
assert.Equal(t, data, record.Data) assert.Equal(t, data, record.Data)
@ -57,22 +60,29 @@ func TestDB(t *testing.T) {
}) })
t.Run("delete record", func(t *testing.T) { t.Run("delete record", func(t *testing.T) {
assert.NoError(t, db.Delete(ctx, id)) assert.NoError(t, db.Delete(ctx, id))
record := db.Get(ctx, id) record, err := db.Get(ctx, id)
require.NoError(t, err)
require.NotNil(t, record) require.NotNil(t, record)
assert.NotNil(t, record.DeletedAt) assert.NotNil(t, record.DeletedAt)
}) })
t.Run("clear deleted", func(t *testing.T) { t.Run("clear deleted", func(t *testing.T) {
db.ClearDeleted(ctx, time.Now().Add(time.Second)) db.ClearDeleted(ctx, time.Now().Add(time.Second))
assert.Nil(t, db.Get(ctx, id)) record, err := db.Get(ctx, id)
assert.Error(t, err)
assert.Nil(t, record)
}) })
t.Run("get all", func(t *testing.T) { t.Run("get all", func(t *testing.T) {
assert.Len(t, db.GetAll(ctx), 0) records, err := db.GetAll(ctx)
assert.NoError(t, err)
assert.Len(t, records, 0)
data := new(anypb.Any) data := new(anypb.Any)
for _, id := range ids { for _, id := range ids {
assert.NoError(t, db.Put(ctx, id, data)) assert.NoError(t, db.Put(ctx, id, data))
} }
assert.Len(t, db.GetAll(ctx), len(ids)) records, err = db.GetAll(ctx)
assert.NoError(t, err)
assert.Len(t, records, len(ids))
for _, id := range ids { for _, id := range ids {
_, _ = c.Do("DEL", id) _, _ = c.Do("DEL", id)
} }
@ -87,12 +97,14 @@ func TestDB(t *testing.T) {
assert.NoError(t, db.Put(ctx, id, data)) assert.NoError(t, db.Put(ctx, id, data))
} }
assert.Len(t, db.List(ctx, ""), 10) records, err := db.List(ctx, "")
assert.Len(t, db.List(ctx, "00000000000A"), 5) assert.NoError(t, err)
assert.Len(t, db.List(ctx, "00000000000F"), 0) assert.Len(t, records, 10)
records, err = db.List(ctx, "00000000000A")
for _, id := range ids { assert.NoError(t, err)
_, _ = c.Do("DEL", id) assert.Len(t, records, 5)
} records, err = db.List(ctx, "00000000000F")
assert.NoError(t, err)
assert.Len(t, records, 0)
}) })
} }

View file

@ -16,13 +16,13 @@ type Backend interface {
Put(ctx context.Context, id string, data *anypb.Any) error Put(ctx context.Context, id string, data *anypb.Any) error
// Get is used to retrieve a record. // Get is used to retrieve a record.
Get(ctx context.Context, id string) *databroker.Record Get(ctx context.Context, id string) (*databroker.Record, error)
// GetAll is used to retrieve all the records. // GetAll is used to retrieve all the records.
GetAll(ctx context.Context) []*databroker.Record GetAll(ctx context.Context) ([]*databroker.Record, error)
// List is used to retrieve all the records since a version. // List is used to retrieve all the records since a version.
List(ctx context.Context, sinceVersion string) []*databroker.Record List(ctx context.Context, sinceVersion string) ([]*databroker.Record, error)
// Delete is used to mark a record as deleted. // Delete is used to mark a record as deleted.
Delete(ctx context.Context, id string) error Delete(ctx context.Context, id string) error