mirror of
https://github.com/pomerium/pomerium.git
synced 2025-05-23 14:07:11 +02:00
pkg/storage: change backend interface to return error (#1131)
Since when storage backend like redis can be fault in many cases, the interface should return error for the caller to handle.
This commit is contained in:
parent
90d95b8c10
commit
aedfbc4c71
7 changed files with 108 additions and 69 deletions
|
@ -86,7 +86,7 @@ func (srv *Server) initVersion() {
|
|||
}
|
||||
|
||||
// Get version from storage first.
|
||||
if r := dbServerVersion.Get(context.Background(), serverVersionKey); r != nil {
|
||||
if r, _ := dbServerVersion.Get(context.Background(), serverVersionKey); r != nil {
|
||||
var sv databroker.ServerVersion
|
||||
if err := ptypes.UnmarshalAny(r.GetData(), &sv); err == nil {
|
||||
srv.log.Debug().Str("server_version", sv.Version).Msg("got db version from DB")
|
||||
|
@ -137,8 +137,8 @@ func (srv *Server) Get(ctx context.Context, req *databroker.GetRequest) (*databr
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
record := db.Get(ctx, req.GetId())
|
||||
if record == nil {
|
||||
record, err := db.Get(ctx, req.GetId())
|
||||
if err != nil {
|
||||
return nil, status.Error(codes.NotFound, "record not found")
|
||||
}
|
||||
return &databroker.GetResponse{Record: record}, nil
|
||||
|
@ -156,7 +156,10 @@ func (srv *Server) GetAll(ctx context.Context, req *databroker.GetAllRequest) (*
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
records := db.GetAll(ctx)
|
||||
records, err := db.GetAll(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var recordVersion string
|
||||
for _, record := range records {
|
||||
if record.GetVersion() > recordVersion {
|
||||
|
@ -188,8 +191,10 @@ func (srv *Server) Set(ctx context.Context, req *databroker.SetRequest) (*databr
|
|||
if err := db.Put(ctx, req.GetId(), req.GetData()); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
record := db.Get(ctx, req.GetId())
|
||||
|
||||
record, err := db.Get(ctx, req.GetId())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &databroker.SetResponse{
|
||||
Record: record,
|
||||
ServerVersion: srv.version,
|
||||
|
@ -220,8 +225,7 @@ func (srv *Server) Sync(req *databroker.SyncRequest, stream databroker.DataBroke
|
|||
ch := srv.onchange.Bind()
|
||||
defer srv.onchange.Unbind(ch)
|
||||
for {
|
||||
updated := db.List(context.Background(), recordVersion)
|
||||
|
||||
updated, _ := db.List(context.Background(), recordVersion)
|
||||
if len(updated) > 0 {
|
||||
sort.Slice(updated, func(i, j int) bool {
|
||||
return updated[i].Version < updated[j].Version
|
||||
|
|
|
@ -40,13 +40,15 @@ func TestServer_initVersion(t *testing.T) {
|
|||
ctx := context.Background()
|
||||
db, err := srv.getDB(recordTypeServerVersion)
|
||||
require.NoError(t, err)
|
||||
r := db.Get(ctx, serverVersionKey)
|
||||
r, err := db.Get(ctx, serverVersionKey)
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, r)
|
||||
srvVersion := uuid.New().String()
|
||||
srv.version = srvVersion
|
||||
srv.initVersion()
|
||||
assert.Equal(t, srvVersion, srv.version)
|
||||
r = db.Get(ctx, serverVersionKey)
|
||||
r, err = db.Get(ctx, serverVersionKey)
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, r)
|
||||
var sv databroker.ServerVersion
|
||||
assert.NoError(t, ptypes.UnmarshalAny(r.GetData(), &sv))
|
||||
|
@ -57,13 +59,15 @@ func TestServer_initVersion(t *testing.T) {
|
|||
ctx := context.Background()
|
||||
db, err := srv.getDB(recordTypeServerVersion)
|
||||
require.NoError(t, err)
|
||||
r := db.Get(ctx, serverVersionKey)
|
||||
r, err := db.Get(ctx, serverVersionKey)
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, r)
|
||||
|
||||
srv.initVersion()
|
||||
srvVersion := srv.version
|
||||
|
||||
r = db.Get(ctx, serverVersionKey)
|
||||
r, err = db.Get(ctx, serverVersionKey)
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, r)
|
||||
var sv databroker.ServerVersion
|
||||
assert.NoError(t, ptypes.UnmarshalAny(r.GetData(), &sv))
|
||||
|
|
|
@ -3,6 +3,7 @@ package inmemory
|
|||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
|
@ -88,26 +89,26 @@ func (db *DB) Delete(_ context.Context, id string) error {
|
|||
}
|
||||
|
||||
// Get gets a record from the db.
|
||||
func (db *DB) Get(_ context.Context, id string) *databroker.Record {
|
||||
func (db *DB) Get(_ context.Context, id string) (*databroker.Record, error) {
|
||||
record, ok := db.byID.Get(byIDRecord{Record: &databroker.Record{Id: id}}).(byIDRecord)
|
||||
if !ok {
|
||||
return nil
|
||||
return nil, errors.New("not found")
|
||||
}
|
||||
return record.Record
|
||||
return record.Record, nil
|
||||
}
|
||||
|
||||
// GetAll gets all the records in the db.
|
||||
func (db *DB) GetAll(_ context.Context) []*databroker.Record {
|
||||
func (db *DB) GetAll(_ context.Context) ([]*databroker.Record, error) {
|
||||
var records []*databroker.Record
|
||||
db.byID.Ascend(func(item btree.Item) bool {
|
||||
records = append(records, item.(byIDRecord).Record)
|
||||
return true
|
||||
})
|
||||
return records
|
||||
return records, nil
|
||||
}
|
||||
|
||||
// List lists all the changes since the given version.
|
||||
func (db *DB) List(_ context.Context, sinceVersion string) []*databroker.Record {
|
||||
func (db *DB) List(_ context.Context, sinceVersion string) ([]*databroker.Record, error) {
|
||||
var records []*databroker.Record
|
||||
db.byVersion.AscendGreaterOrEqual(byVersionRecord{Record: &databroker.Record{Version: sinceVersion}}, func(i btree.Item) bool {
|
||||
record := i.(byVersionRecord)
|
||||
|
@ -116,7 +117,7 @@ func (db *DB) List(_ context.Context, sinceVersion string) []*databroker.Record
|
|||
}
|
||||
return true
|
||||
})
|
||||
return records
|
||||
return records, nil
|
||||
}
|
||||
|
||||
// Put replaces or inserts a record in the db.
|
||||
|
|
|
@ -7,6 +7,7 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"google.golang.org/protobuf/types/known/anypb"
|
||||
)
|
||||
|
||||
|
@ -14,12 +15,15 @@ func TestDB(t *testing.T) {
|
|||
ctx := context.Background()
|
||||
db := NewDB("example", 2)
|
||||
t.Run("get missing record", func(t *testing.T) {
|
||||
assert.Nil(t, db.Get(ctx, "abcd"))
|
||||
record, err := db.Get(ctx, "abcd")
|
||||
require.Error(t, err)
|
||||
assert.Nil(t, record)
|
||||
})
|
||||
t.Run("get record", func(t *testing.T) {
|
||||
data := new(anypb.Any)
|
||||
assert.NoError(t, db.Put(ctx, "abcd", data))
|
||||
record := db.Get(ctx, "abcd")
|
||||
record, err := db.Get(ctx, "abcd")
|
||||
require.NoError(t, err)
|
||||
if assert.NotNil(t, record) {
|
||||
assert.NotNil(t, record.CreatedAt)
|
||||
assert.Equal(t, data, record.Data)
|
||||
|
@ -32,21 +36,26 @@ func TestDB(t *testing.T) {
|
|||
})
|
||||
t.Run("delete record", func(t *testing.T) {
|
||||
assert.NoError(t, db.Delete(ctx, "abcd"))
|
||||
record := db.Get(ctx, "abcd")
|
||||
record, err := db.Get(ctx, "abcd")
|
||||
require.NoError(t, err)
|
||||
if assert.NotNil(t, record) {
|
||||
assert.NotNil(t, record.DeletedAt)
|
||||
}
|
||||
})
|
||||
t.Run("clear deleted", func(t *testing.T) {
|
||||
db.ClearDeleted(ctx, time.Now().Add(time.Second))
|
||||
assert.Nil(t, db.Get(ctx, "abcd"))
|
||||
record, err := db.Get(ctx, "abcd")
|
||||
require.Error(t, err)
|
||||
assert.Nil(t, record)
|
||||
})
|
||||
t.Run("keep remaining", func(t *testing.T) {
|
||||
data := new(anypb.Any)
|
||||
assert.NoError(t, db.Put(ctx, "abcd", data))
|
||||
assert.NoError(t, db.Delete(ctx, "abcd"))
|
||||
db.ClearDeleted(ctx, time.Now().Add(-10*time.Second))
|
||||
assert.NotNil(t, db.Get(ctx, "abcd"))
|
||||
record, err := db.Get(ctx, "abcd")
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, record)
|
||||
db.ClearDeleted(ctx, time.Now().Add(time.Second))
|
||||
})
|
||||
t.Run("list", func(t *testing.T) {
|
||||
|
@ -55,8 +64,14 @@ func TestDB(t *testing.T) {
|
|||
assert.NoError(t, db.Put(ctx, fmt.Sprintf("%02d", i), data))
|
||||
}
|
||||
|
||||
assert.Len(t, db.List(ctx, ""), 10)
|
||||
assert.Len(t, db.List(ctx, "00000000000A"), 4)
|
||||
assert.Len(t, db.List(ctx, "00000000000F"), 0)
|
||||
records, err := db.List(ctx, "")
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, records, 10)
|
||||
records, err = db.List(ctx, "00000000000A")
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, records, 4)
|
||||
records, err = db.List(ctx, "00000000000F")
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, records, 0)
|
||||
})
|
||||
}
|
||||
|
|
|
@ -3,7 +3,6 @@ package redis
|
|||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"sync/atomic"
|
||||
|
@ -70,8 +69,8 @@ func New(address, recordType string, deletePermanentAfter int64) (*DB, error) {
|
|||
func (db *DB) Put(ctx context.Context, id string, data *anypb.Any) error {
|
||||
c := db.pool.Get()
|
||||
defer c.Close()
|
||||
record := db.Get(ctx, id)
|
||||
if record == nil {
|
||||
record, err := db.Get(ctx, id)
|
||||
if err != nil {
|
||||
record = new(databroker.Record)
|
||||
record.CreatedAt = ptypes.TimestampNow()
|
||||
}
|
||||
|
@ -97,27 +96,27 @@ func (db *DB) Put(ctx context.Context, id string, data *anypb.Any) error {
|
|||
}
|
||||
|
||||
// Get retrieves a record from redis.
|
||||
func (db *DB) Get(_ context.Context, id string) *databroker.Record {
|
||||
func (db *DB) Get(_ context.Context, id string) (*databroker.Record, error) {
|
||||
c := db.pool.Get()
|
||||
defer c.Close()
|
||||
|
||||
b, err := redis.Bytes(c.Do("HGET", db.recordType, id))
|
||||
if err != nil {
|
||||
return nil
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return db.toPbRecord(b)
|
||||
}
|
||||
|
||||
// GetAll retrieves all records from redis.
|
||||
func (db *DB) GetAll(ctx context.Context) []*databroker.Record {
|
||||
func (db *DB) GetAll(ctx context.Context) ([]*databroker.Record, error) {
|
||||
return db.getAll(ctx, func(record *databroker.Record) bool { return true })
|
||||
}
|
||||
|
||||
// List retrieves all records since given version.
|
||||
//
|
||||
// "version" is in hex format, invalid version will be treated as 0.
|
||||
func (db *DB) List(ctx context.Context, sinceVersion string) []*databroker.Record {
|
||||
func (db *DB) List(ctx context.Context, sinceVersion string) ([]*databroker.Record, error) {
|
||||
c := db.pool.Get()
|
||||
defer c.Close()
|
||||
|
||||
|
@ -128,18 +127,22 @@ func (db *DB) List(ctx context.Context, sinceVersion string) []*databroker.Recor
|
|||
|
||||
ids, err := redis.Strings(c.Do("ZRANGEBYSCORE", db.versionSet, fmt.Sprintf("(%d", v), "+inf"))
|
||||
if err != nil {
|
||||
return nil
|
||||
return nil, err
|
||||
}
|
||||
|
||||
records := make([]*databroker.Record, 0, len(ids))
|
||||
pbRecords := make([]*databroker.Record, 0, len(ids))
|
||||
for _, id := range ids {
|
||||
b, err := redis.Bytes(c.Do("HGET", db.recordType, id))
|
||||
if err != nil {
|
||||
continue
|
||||
return nil, err
|
||||
}
|
||||
records = append(records, db.toPbRecord(b))
|
||||
pbRecord, err := db.toPbRecord(b)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return records
|
||||
pbRecords = append(pbRecords, pbRecord)
|
||||
}
|
||||
return pbRecords, nil
|
||||
}
|
||||
|
||||
// Delete sets a record DeletedAt field and set its TTL.
|
||||
|
@ -147,9 +150,9 @@ func (db *DB) Delete(ctx context.Context, id string) error {
|
|||
c := db.pool.Get()
|
||||
defer c.Close()
|
||||
|
||||
r := db.Get(ctx, id)
|
||||
if r == nil {
|
||||
return errors.New("not found")
|
||||
r, err := db.Get(ctx, id)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get record: %w", err)
|
||||
}
|
||||
r.DeletedAt = ptypes.TimestampNow()
|
||||
r.Version = fmt.Sprintf("%012X", atomic.AddUint64(&db.lastVersion, 1))
|
||||
|
@ -177,8 +180,8 @@ func (db *DB) ClearDeleted(_ context.Context, cutoff time.Time) {
|
|||
ids, _ := redis.Strings(c.Do("SMEMBERS", db.deletedSet))
|
||||
for _, id := range ids {
|
||||
b, _ := redis.Bytes(c.Do("HGET", db.recordType, id))
|
||||
record := db.toPbRecord(b)
|
||||
if record == nil {
|
||||
record, err := db.toPbRecord(b)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
|
@ -195,7 +198,7 @@ func (db *DB) ClearDeleted(_ context.Context, cutoff time.Time) {
|
|||
}
|
||||
}
|
||||
|
||||
func (db *DB) getAll(_ context.Context, filter func(record *databroker.Record) bool) []*databroker.Record {
|
||||
func (db *DB) getAll(_ context.Context, filter func(record *databroker.Record) bool) ([]*databroker.Record, error) {
|
||||
c := db.pool.Get()
|
||||
defer c.Close()
|
||||
iter := 0
|
||||
|
@ -203,16 +206,16 @@ func (db *DB) getAll(_ context.Context, filter func(record *databroker.Record) b
|
|||
for {
|
||||
arr, err := redis.Values(c.Do("HSCAN", db.recordType, iter, "MATCH", "*"))
|
||||
if err != nil {
|
||||
return nil
|
||||
return nil, err
|
||||
}
|
||||
|
||||
iter, _ = redis.Int(arr[0], nil)
|
||||
pairs, _ := redis.StringMap(arr[1], nil)
|
||||
|
||||
for _, v := range pairs {
|
||||
record := db.toPbRecord([]byte(v))
|
||||
if record == nil {
|
||||
continue
|
||||
record, err := db.toPbRecord([]byte(v))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if filter(record) {
|
||||
records = append(records, record)
|
||||
|
@ -224,15 +227,15 @@ func (db *DB) getAll(_ context.Context, filter func(record *databroker.Record) b
|
|||
}
|
||||
}
|
||||
|
||||
return records
|
||||
return records, nil
|
||||
}
|
||||
|
||||
func (db *DB) toPbRecord(b []byte) *databroker.Record {
|
||||
func (db *DB) toPbRecord(b []byte) (*databroker.Record, error) {
|
||||
record := &databroker.Record{}
|
||||
if err := proto.Unmarshal(b, record); err != nil {
|
||||
return nil
|
||||
return nil, err
|
||||
}
|
||||
return record
|
||||
return record, nil
|
||||
}
|
||||
|
||||
func (db *DB) tx(c redis.Conn, commands []map[string][]interface{}) error {
|
||||
|
|
|
@ -40,12 +40,15 @@ func TestDB(t *testing.T) {
|
|||
cleanup(c, db, t)
|
||||
|
||||
t.Run("get missing record", func(t *testing.T) {
|
||||
assert.Nil(t, db.Get(ctx, id))
|
||||
record, err := db.Get(ctx, id)
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, record)
|
||||
})
|
||||
t.Run("get record", func(t *testing.T) {
|
||||
data := new(anypb.Any)
|
||||
assert.NoError(t, db.Put(ctx, id, data))
|
||||
record := db.Get(ctx, id)
|
||||
record, err := db.Get(ctx, id)
|
||||
require.NoError(t, err)
|
||||
if assert.NotNil(t, record) {
|
||||
assert.NotNil(t, record.CreatedAt)
|
||||
assert.Equal(t, data, record.Data)
|
||||
|
@ -57,22 +60,29 @@ func TestDB(t *testing.T) {
|
|||
})
|
||||
t.Run("delete record", func(t *testing.T) {
|
||||
assert.NoError(t, db.Delete(ctx, id))
|
||||
record := db.Get(ctx, id)
|
||||
record, err := db.Get(ctx, id)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, record)
|
||||
assert.NotNil(t, record.DeletedAt)
|
||||
})
|
||||
t.Run("clear deleted", func(t *testing.T) {
|
||||
db.ClearDeleted(ctx, time.Now().Add(time.Second))
|
||||
assert.Nil(t, db.Get(ctx, id))
|
||||
record, err := db.Get(ctx, id)
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, record)
|
||||
})
|
||||
t.Run("get all", func(t *testing.T) {
|
||||
assert.Len(t, db.GetAll(ctx), 0)
|
||||
records, err := db.GetAll(ctx)
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, records, 0)
|
||||
data := new(anypb.Any)
|
||||
|
||||
for _, id := range ids {
|
||||
assert.NoError(t, db.Put(ctx, id, data))
|
||||
}
|
||||
assert.Len(t, db.GetAll(ctx), len(ids))
|
||||
records, err = db.GetAll(ctx)
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, records, len(ids))
|
||||
for _, id := range ids {
|
||||
_, _ = c.Do("DEL", id)
|
||||
}
|
||||
|
@ -87,12 +97,14 @@ func TestDB(t *testing.T) {
|
|||
assert.NoError(t, db.Put(ctx, id, data))
|
||||
}
|
||||
|
||||
assert.Len(t, db.List(ctx, ""), 10)
|
||||
assert.Len(t, db.List(ctx, "00000000000A"), 5)
|
||||
assert.Len(t, db.List(ctx, "00000000000F"), 0)
|
||||
|
||||
for _, id := range ids {
|
||||
_, _ = c.Do("DEL", id)
|
||||
}
|
||||
records, err := db.List(ctx, "")
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, records, 10)
|
||||
records, err = db.List(ctx, "00000000000A")
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, records, 5)
|
||||
records, err = db.List(ctx, "00000000000F")
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, records, 0)
|
||||
})
|
||||
}
|
||||
|
|
|
@ -16,13 +16,13 @@ type Backend interface {
|
|||
Put(ctx context.Context, id string, data *anypb.Any) error
|
||||
|
||||
// Get is used to retrieve a record.
|
||||
Get(ctx context.Context, id string) *databroker.Record
|
||||
Get(ctx context.Context, id string) (*databroker.Record, error)
|
||||
|
||||
// GetAll is used to retrieve all the records.
|
||||
GetAll(ctx context.Context) []*databroker.Record
|
||||
GetAll(ctx context.Context) ([]*databroker.Record, error)
|
||||
|
||||
// List is used to retrieve all the records since a version.
|
||||
List(ctx context.Context, sinceVersion string) []*databroker.Record
|
||||
List(ctx context.Context, sinceVersion string) ([]*databroker.Record, error)
|
||||
|
||||
// Delete is used to mark a record as deleted.
|
||||
Delete(ctx context.Context, id string) error
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue