From aedfbc4c715b7ce0c2e584de57dff790c5e6f35e Mon Sep 17 00:00:00 2001 From: Cuong Manh Le Date: Fri, 24 Jul 2020 09:02:37 +0700 Subject: [PATCH] pkg/storage: change backend interface to return error (#1131) Since when storage backend like redis can be fault in many cases, the interface should return error for the caller to handle. --- internal/databroker/server.go | 20 ++++++---- internal/databroker/server_test.go | 12 ++++-- pkg/storage/inmemory/inmemory.go | 15 ++++---- pkg/storage/inmemory/inmemory_test.go | 31 +++++++++++---- pkg/storage/redis/redis.go | 55 ++++++++++++++------------- pkg/storage/redis/redis_test.go | 38 +++++++++++------- pkg/storage/storage.go | 6 +-- 7 files changed, 108 insertions(+), 69 deletions(-) diff --git a/internal/databroker/server.go b/internal/databroker/server.go index 56d6b8082..3e1356765 100644 --- a/internal/databroker/server.go +++ b/internal/databroker/server.go @@ -86,7 +86,7 @@ func (srv *Server) initVersion() { } // Get version from storage first. - if r := dbServerVersion.Get(context.Background(), serverVersionKey); r != nil { + if r, _ := dbServerVersion.Get(context.Background(), serverVersionKey); r != nil { var sv databroker.ServerVersion if err := ptypes.UnmarshalAny(r.GetData(), &sv); err == nil { srv.log.Debug().Str("server_version", sv.Version).Msg("got db version from DB") @@ -137,8 +137,8 @@ func (srv *Server) Get(ctx context.Context, req *databroker.GetRequest) (*databr if err != nil { return nil, err } - record := db.Get(ctx, req.GetId()) - if record == nil { + record, err := db.Get(ctx, req.GetId()) + if err != nil { return nil, status.Error(codes.NotFound, "record not found") } return &databroker.GetResponse{Record: record}, nil @@ -156,7 +156,10 @@ func (srv *Server) GetAll(ctx context.Context, req *databroker.GetAllRequest) (* if err != nil { return nil, err } - records := db.GetAll(ctx) + records, err := db.GetAll(ctx) + if err != nil { + return nil, err + } var recordVersion string for _, record := range records { if record.GetVersion() > recordVersion { @@ -188,8 +191,10 @@ func (srv *Server) Set(ctx context.Context, req *databroker.SetRequest) (*databr if err := db.Put(ctx, req.GetId(), req.GetData()); err != nil { return nil, err } - record := db.Get(ctx, req.GetId()) - + record, err := db.Get(ctx, req.GetId()) + if err != nil { + return nil, err + } return &databroker.SetResponse{ Record: record, ServerVersion: srv.version, @@ -220,8 +225,7 @@ func (srv *Server) Sync(req *databroker.SyncRequest, stream databroker.DataBroke ch := srv.onchange.Bind() defer srv.onchange.Unbind(ch) for { - updated := db.List(context.Background(), recordVersion) - + updated, _ := db.List(context.Background(), recordVersion) if len(updated) > 0 { sort.Slice(updated, func(i, j int) bool { return updated[i].Version < updated[j].Version diff --git a/internal/databroker/server_test.go b/internal/databroker/server_test.go index 7eca2fca1..8a013e2f7 100644 --- a/internal/databroker/server_test.go +++ b/internal/databroker/server_test.go @@ -40,13 +40,15 @@ func TestServer_initVersion(t *testing.T) { ctx := context.Background() db, err := srv.getDB(recordTypeServerVersion) require.NoError(t, err) - r := db.Get(ctx, serverVersionKey) + r, err := db.Get(ctx, serverVersionKey) + assert.Error(t, err) assert.Nil(t, r) srvVersion := uuid.New().String() srv.version = srvVersion srv.initVersion() assert.Equal(t, srvVersion, srv.version) - r = db.Get(ctx, serverVersionKey) + r, err = db.Get(ctx, serverVersionKey) + require.NoError(t, err) assert.NotNil(t, r) var sv databroker.ServerVersion assert.NoError(t, ptypes.UnmarshalAny(r.GetData(), &sv)) @@ -57,13 +59,15 @@ func TestServer_initVersion(t *testing.T) { ctx := context.Background() db, err := srv.getDB(recordTypeServerVersion) require.NoError(t, err) - r := db.Get(ctx, serverVersionKey) + r, err := db.Get(ctx, serverVersionKey) + assert.Error(t, err) assert.Nil(t, r) srv.initVersion() srvVersion := srv.version - r = db.Get(ctx, serverVersionKey) + r, err = db.Get(ctx, serverVersionKey) + require.NoError(t, err) assert.NotNil(t, r) var sv databroker.ServerVersion assert.NoError(t, ptypes.UnmarshalAny(r.GetData(), &sv)) diff --git a/pkg/storage/inmemory/inmemory.go b/pkg/storage/inmemory/inmemory.go index 941624950..76010a354 100644 --- a/pkg/storage/inmemory/inmemory.go +++ b/pkg/storage/inmemory/inmemory.go @@ -3,6 +3,7 @@ package inmemory import ( "context" + "errors" "fmt" "sync" "sync/atomic" @@ -88,26 +89,26 @@ func (db *DB) Delete(_ context.Context, id string) error { } // Get gets a record from the db. -func (db *DB) Get(_ context.Context, id string) *databroker.Record { +func (db *DB) Get(_ context.Context, id string) (*databroker.Record, error) { record, ok := db.byID.Get(byIDRecord{Record: &databroker.Record{Id: id}}).(byIDRecord) if !ok { - return nil + return nil, errors.New("not found") } - return record.Record + return record.Record, nil } // GetAll gets all the records in the db. -func (db *DB) GetAll(_ context.Context) []*databroker.Record { +func (db *DB) GetAll(_ context.Context) ([]*databroker.Record, error) { var records []*databroker.Record db.byID.Ascend(func(item btree.Item) bool { records = append(records, item.(byIDRecord).Record) return true }) - return records + return records, nil } // List lists all the changes since the given version. -func (db *DB) List(_ context.Context, sinceVersion string) []*databroker.Record { +func (db *DB) List(_ context.Context, sinceVersion string) ([]*databroker.Record, error) { var records []*databroker.Record db.byVersion.AscendGreaterOrEqual(byVersionRecord{Record: &databroker.Record{Version: sinceVersion}}, func(i btree.Item) bool { record := i.(byVersionRecord) @@ -116,7 +117,7 @@ func (db *DB) List(_ context.Context, sinceVersion string) []*databroker.Record } return true }) - return records + return records, nil } // Put replaces or inserts a record in the db. diff --git a/pkg/storage/inmemory/inmemory_test.go b/pkg/storage/inmemory/inmemory_test.go index 0b9eda112..78c970838 100644 --- a/pkg/storage/inmemory/inmemory_test.go +++ b/pkg/storage/inmemory/inmemory_test.go @@ -7,6 +7,7 @@ import ( "time" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "google.golang.org/protobuf/types/known/anypb" ) @@ -14,12 +15,15 @@ func TestDB(t *testing.T) { ctx := context.Background() db := NewDB("example", 2) t.Run("get missing record", func(t *testing.T) { - assert.Nil(t, db.Get(ctx, "abcd")) + record, err := db.Get(ctx, "abcd") + require.Error(t, err) + assert.Nil(t, record) }) t.Run("get record", func(t *testing.T) { data := new(anypb.Any) assert.NoError(t, db.Put(ctx, "abcd", data)) - record := db.Get(ctx, "abcd") + record, err := db.Get(ctx, "abcd") + require.NoError(t, err) if assert.NotNil(t, record) { assert.NotNil(t, record.CreatedAt) assert.Equal(t, data, record.Data) @@ -32,21 +36,26 @@ func TestDB(t *testing.T) { }) t.Run("delete record", func(t *testing.T) { assert.NoError(t, db.Delete(ctx, "abcd")) - record := db.Get(ctx, "abcd") + record, err := db.Get(ctx, "abcd") + require.NoError(t, err) if assert.NotNil(t, record) { assert.NotNil(t, record.DeletedAt) } }) t.Run("clear deleted", func(t *testing.T) { db.ClearDeleted(ctx, time.Now().Add(time.Second)) - assert.Nil(t, db.Get(ctx, "abcd")) + record, err := db.Get(ctx, "abcd") + require.Error(t, err) + assert.Nil(t, record) }) t.Run("keep remaining", func(t *testing.T) { data := new(anypb.Any) assert.NoError(t, db.Put(ctx, "abcd", data)) assert.NoError(t, db.Delete(ctx, "abcd")) db.ClearDeleted(ctx, time.Now().Add(-10*time.Second)) - assert.NotNil(t, db.Get(ctx, "abcd")) + record, err := db.Get(ctx, "abcd") + require.NoError(t, err) + assert.NotNil(t, record) db.ClearDeleted(ctx, time.Now().Add(time.Second)) }) t.Run("list", func(t *testing.T) { @@ -55,8 +64,14 @@ func TestDB(t *testing.T) { assert.NoError(t, db.Put(ctx, fmt.Sprintf("%02d", i), data)) } - assert.Len(t, db.List(ctx, ""), 10) - assert.Len(t, db.List(ctx, "00000000000A"), 4) - assert.Len(t, db.List(ctx, "00000000000F"), 0) + records, err := db.List(ctx, "") + require.NoError(t, err) + assert.Len(t, records, 10) + records, err = db.List(ctx, "00000000000A") + require.NoError(t, err) + assert.Len(t, records, 4) + records, err = db.List(ctx, "00000000000F") + require.NoError(t, err) + assert.Len(t, records, 0) }) } diff --git a/pkg/storage/redis/redis.go b/pkg/storage/redis/redis.go index 5f786acd9..561966e33 100644 --- a/pkg/storage/redis/redis.go +++ b/pkg/storage/redis/redis.go @@ -3,7 +3,6 @@ package redis import ( "context" - "errors" "fmt" "strconv" "sync/atomic" @@ -70,8 +69,8 @@ func New(address, recordType string, deletePermanentAfter int64) (*DB, error) { func (db *DB) Put(ctx context.Context, id string, data *anypb.Any) error { c := db.pool.Get() defer c.Close() - record := db.Get(ctx, id) - if record == nil { + record, err := db.Get(ctx, id) + if err != nil { record = new(databroker.Record) record.CreatedAt = ptypes.TimestampNow() } @@ -97,27 +96,27 @@ func (db *DB) Put(ctx context.Context, id string, data *anypb.Any) error { } // Get retrieves a record from redis. -func (db *DB) Get(_ context.Context, id string) *databroker.Record { +func (db *DB) Get(_ context.Context, id string) (*databroker.Record, error) { c := db.pool.Get() defer c.Close() b, err := redis.Bytes(c.Do("HGET", db.recordType, id)) if err != nil { - return nil + return nil, err } return db.toPbRecord(b) } // GetAll retrieves all records from redis. -func (db *DB) GetAll(ctx context.Context) []*databroker.Record { +func (db *DB) GetAll(ctx context.Context) ([]*databroker.Record, error) { return db.getAll(ctx, func(record *databroker.Record) bool { return true }) } // List retrieves all records since given version. // // "version" is in hex format, invalid version will be treated as 0. -func (db *DB) List(ctx context.Context, sinceVersion string) []*databroker.Record { +func (db *DB) List(ctx context.Context, sinceVersion string) ([]*databroker.Record, error) { c := db.pool.Get() defer c.Close() @@ -128,18 +127,22 @@ func (db *DB) List(ctx context.Context, sinceVersion string) []*databroker.Recor ids, err := redis.Strings(c.Do("ZRANGEBYSCORE", db.versionSet, fmt.Sprintf("(%d", v), "+inf")) if err != nil { - return nil + return nil, err } - records := make([]*databroker.Record, 0, len(ids)) + pbRecords := make([]*databroker.Record, 0, len(ids)) for _, id := range ids { b, err := redis.Bytes(c.Do("HGET", db.recordType, id)) if err != nil { - continue + return nil, err } - records = append(records, db.toPbRecord(b)) + pbRecord, err := db.toPbRecord(b) + if err != nil { + return nil, err + } + pbRecords = append(pbRecords, pbRecord) } - return records + return pbRecords, nil } // Delete sets a record DeletedAt field and set its TTL. @@ -147,9 +150,9 @@ func (db *DB) Delete(ctx context.Context, id string) error { c := db.pool.Get() defer c.Close() - r := db.Get(ctx, id) - if r == nil { - return errors.New("not found") + r, err := db.Get(ctx, id) + if err != nil { + return fmt.Errorf("failed to get record: %w", err) } r.DeletedAt = ptypes.TimestampNow() r.Version = fmt.Sprintf("%012X", atomic.AddUint64(&db.lastVersion, 1)) @@ -177,8 +180,8 @@ func (db *DB) ClearDeleted(_ context.Context, cutoff time.Time) { ids, _ := redis.Strings(c.Do("SMEMBERS", db.deletedSet)) for _, id := range ids { b, _ := redis.Bytes(c.Do("HGET", db.recordType, id)) - record := db.toPbRecord(b) - if record == nil { + record, err := db.toPbRecord(b) + if err != nil { continue } @@ -195,7 +198,7 @@ func (db *DB) ClearDeleted(_ context.Context, cutoff time.Time) { } } -func (db *DB) getAll(_ context.Context, filter func(record *databroker.Record) bool) []*databroker.Record { +func (db *DB) getAll(_ context.Context, filter func(record *databroker.Record) bool) ([]*databroker.Record, error) { c := db.pool.Get() defer c.Close() iter := 0 @@ -203,16 +206,16 @@ func (db *DB) getAll(_ context.Context, filter func(record *databroker.Record) b for { arr, err := redis.Values(c.Do("HSCAN", db.recordType, iter, "MATCH", "*")) if err != nil { - return nil + return nil, err } iter, _ = redis.Int(arr[0], nil) pairs, _ := redis.StringMap(arr[1], nil) for _, v := range pairs { - record := db.toPbRecord([]byte(v)) - if record == nil { - continue + record, err := db.toPbRecord([]byte(v)) + if err != nil { + return nil, err } if filter(record) { records = append(records, record) @@ -224,15 +227,15 @@ func (db *DB) getAll(_ context.Context, filter func(record *databroker.Record) b } } - return records + return records, nil } -func (db *DB) toPbRecord(b []byte) *databroker.Record { +func (db *DB) toPbRecord(b []byte) (*databroker.Record, error) { record := &databroker.Record{} if err := proto.Unmarshal(b, record); err != nil { - return nil + return nil, err } - return record + return record, nil } func (db *DB) tx(c redis.Conn, commands []map[string][]interface{}) error { diff --git a/pkg/storage/redis/redis_test.go b/pkg/storage/redis/redis_test.go index 90322ac9a..05173a3c5 100644 --- a/pkg/storage/redis/redis_test.go +++ b/pkg/storage/redis/redis_test.go @@ -40,12 +40,15 @@ func TestDB(t *testing.T) { cleanup(c, db, t) t.Run("get missing record", func(t *testing.T) { - assert.Nil(t, db.Get(ctx, id)) + record, err := db.Get(ctx, id) + assert.Error(t, err) + assert.Nil(t, record) }) t.Run("get record", func(t *testing.T) { data := new(anypb.Any) assert.NoError(t, db.Put(ctx, id, data)) - record := db.Get(ctx, id) + record, err := db.Get(ctx, id) + require.NoError(t, err) if assert.NotNil(t, record) { assert.NotNil(t, record.CreatedAt) assert.Equal(t, data, record.Data) @@ -57,22 +60,29 @@ func TestDB(t *testing.T) { }) t.Run("delete record", func(t *testing.T) { assert.NoError(t, db.Delete(ctx, id)) - record := db.Get(ctx, id) + record, err := db.Get(ctx, id) + require.NoError(t, err) require.NotNil(t, record) assert.NotNil(t, record.DeletedAt) }) t.Run("clear deleted", func(t *testing.T) { db.ClearDeleted(ctx, time.Now().Add(time.Second)) - assert.Nil(t, db.Get(ctx, id)) + record, err := db.Get(ctx, id) + assert.Error(t, err) + assert.Nil(t, record) }) t.Run("get all", func(t *testing.T) { - assert.Len(t, db.GetAll(ctx), 0) + records, err := db.GetAll(ctx) + assert.NoError(t, err) + assert.Len(t, records, 0) data := new(anypb.Any) for _, id := range ids { assert.NoError(t, db.Put(ctx, id, data)) } - assert.Len(t, db.GetAll(ctx), len(ids)) + records, err = db.GetAll(ctx) + assert.NoError(t, err) + assert.Len(t, records, len(ids)) for _, id := range ids { _, _ = c.Do("DEL", id) } @@ -87,12 +97,14 @@ func TestDB(t *testing.T) { assert.NoError(t, db.Put(ctx, id, data)) } - assert.Len(t, db.List(ctx, ""), 10) - assert.Len(t, db.List(ctx, "00000000000A"), 5) - assert.Len(t, db.List(ctx, "00000000000F"), 0) - - for _, id := range ids { - _, _ = c.Do("DEL", id) - } + records, err := db.List(ctx, "") + assert.NoError(t, err) + assert.Len(t, records, 10) + records, err = db.List(ctx, "00000000000A") + assert.NoError(t, err) + assert.Len(t, records, 5) + records, err = db.List(ctx, "00000000000F") + assert.NoError(t, err) + assert.Len(t, records, 0) }) } diff --git a/pkg/storage/storage.go b/pkg/storage/storage.go index 5c07eee6b..97e7b2d99 100644 --- a/pkg/storage/storage.go +++ b/pkg/storage/storage.go @@ -16,13 +16,13 @@ type Backend interface { Put(ctx context.Context, id string, data *anypb.Any) error // Get is used to retrieve a record. - Get(ctx context.Context, id string) *databroker.Record + Get(ctx context.Context, id string) (*databroker.Record, error) // GetAll is used to retrieve all the records. - GetAll(ctx context.Context) []*databroker.Record + GetAll(ctx context.Context) ([]*databroker.Record, error) // List is used to retrieve all the records since a version. - List(ctx context.Context, sinceVersion string) []*databroker.Record + List(ctx context.Context, sinceVersion string) ([]*databroker.Record, error) // Delete is used to mark a record as deleted. Delete(ctx context.Context, id string) error