diff --git a/internal/databroker/server.go b/internal/databroker/server.go index 76c8c9c4b..cc496d1d7 100644 --- a/internal/databroker/server.go +++ b/internal/databroker/server.go @@ -145,7 +145,7 @@ func (srv *Server) Query(ctx context.Context, req *databroker.QueryRequest) (*da return nil, err } - _, stream, err := db.SyncLatest(ctx) + _, _, stream, err := db.SyncLatest(ctx, req.GetType(), nil) if err != nil { return nil, err } @@ -332,17 +332,13 @@ func (srv *Server) SyncLatest(req *databroker.SyncLatestRequest, stream databrok return err } - serverVersion, recordStream, err := backend.SyncLatest(ctx) + serverVersion, recordVersion, recordStream, err := backend.SyncLatest(ctx, req.GetType(), nil) if err != nil { return err } - recordVersion := uint64(0) for recordStream.Next(false) { record := recordStream.Record() - if record.GetVersion() > recordVersion { - recordVersion = record.GetVersion() - } if req.GetType() == "" || req.GetType() == record.GetType() { err = stream.Send(&databroker.SyncLatestResponse{ Response: &databroker.SyncLatestResponse_Record{ diff --git a/pkg/storage/encrypted.go b/pkg/storage/encrypted.go index cbca8f1cd..d70a0ea83 100644 --- a/pkg/storage/encrypted.go +++ b/pkg/storage/encrypted.go @@ -130,12 +130,16 @@ func (e *encryptedBackend) Sync(ctx context.Context, serverVersion, recordVersio }, nil } -func (e *encryptedBackend) SyncLatest(ctx context.Context) (serverVersion uint64, stream RecordStream, err error) { - serverVersion, stream, err = e.underlying.SyncLatest(ctx) +func (e *encryptedBackend) SyncLatest( + ctx context.Context, + recordType string, + filter FilterExpression, +) (serverVersion, recordVersion uint64, stream RecordStream, err error) { + serverVersion, recordVersion, stream, err = e.underlying.SyncLatest(ctx, recordType, filter) if err != nil { - return serverVersion, nil, err + return serverVersion, recordVersion, nil, err } - return serverVersion, &encryptedRecordStream{ + return serverVersion, recordVersion, &encryptedRecordStream{ underlying: stream, backend: e, }, nil diff --git a/pkg/storage/inmemory/backend.go b/pkg/storage/inmemory/backend.go index 49bd94673..e92b98fb3 100644 --- a/pkg/storage/inmemory/backend.go +++ b/pkg/storage/inmemory/backend.go @@ -255,12 +255,18 @@ func (backend *Backend) Sync(ctx context.Context, serverVersion, recordVersion u } // SyncLatest returns a record stream for all the records. -func (backend *Backend) SyncLatest(ctx context.Context) (serverVersion uint64, stream storage.RecordStream, err error) { +func (backend *Backend) SyncLatest( + ctx context.Context, + recordType string, + expr storage.FilterExpression, +) (serverVersion, recordVersion uint64, stream storage.RecordStream, err error) { backend.mu.RLock() - currentServerVersion := backend.serverVersion + serverVersion = backend.serverVersion + recordVersion = backend.lastVersion backend.mu.RUnlock() - return currentServerVersion, newSyncLatestRecordStream(ctx, backend), nil + stream, err = newSyncLatestRecordStream(ctx, backend, recordType, expr) + return serverVersion, recordVersion, stream, err } func (backend *Backend) recordChange(record *databroker.Record) { diff --git a/pkg/storage/inmemory/backend_test.go b/pkg/storage/inmemory/backend_test.go index dacd76b54..c3433d1a4 100644 --- a/pkg/storage/inmemory/backend_test.go +++ b/pkg/storage/inmemory/backend_test.go @@ -210,7 +210,7 @@ func TestCapacity(t *testing.T) { require.NoError(t, err) } - _, stream, err := backend.SyncLatest(ctx) + _, _, stream, err := backend.SyncLatest(ctx, "EXAMPLE", nil) require.NoError(t, err) records, err := storage.RecordStreamToList(stream) diff --git a/pkg/storage/inmemory/stream.go b/pkg/storage/inmemory/stream.go index 8bef957b2..8007a47c7 100644 --- a/pkg/storage/inmemory/stream.go +++ b/pkg/storage/inmemory/stream.go @@ -10,17 +10,35 @@ import ( func newSyncLatestRecordStream( ctx context.Context, backend *Backend, -) storage.RecordStream { + recordType string, + expr storage.FilterExpression, +) (storage.RecordStream, error) { + filter, err := storage.RecordStreamFilterFromFilterExpression(expr) + if err != nil { + return nil, err + } + if recordType != "" { + filter = filter.And(func(record *databroker.Record) (keep bool) { + return record.GetType() == recordType + }) + } + var ready []*databroker.Record - return storage.NewRecordStream(ctx, backend.closed, []storage.RecordStreamGenerator{ - func(ctx context.Context, block bool) (*databroker.Record, error) { - backend.mu.RLock() - for _, co := range backend.lookup { - ready = append(ready, co.List()...) + generator := func(ctx context.Context, block bool) (*databroker.Record, error) { + backend.mu.RLock() + for _, co := range backend.lookup { + for _, record := range co.List() { + if filter(record) { + ready = append(ready, record) + } } - backend.mu.RUnlock() - return nil, storage.ErrStreamDone - }, + } + backend.mu.RUnlock() + return nil, storage.ErrStreamDone + } + + return storage.NewRecordStream(ctx, backend.closed, []storage.RecordStreamGenerator{ + generator, func(ctx context.Context, block bool) (*databroker.Record, error) { if len(ready) == 0 { return nil, storage.ErrStreamDone @@ -30,7 +48,7 @@ func newSyncLatestRecordStream( ready = ready[1:] return dup(record), nil }, - }, nil) + }, nil), nil } func newSyncRecordStream( diff --git a/pkg/storage/redis/redis.go b/pkg/storage/redis/redis.go index 425163c86..b03fc58d3 100644 --- a/pkg/storage/redis/redis.go +++ b/pkg/storage/redis/redis.go @@ -254,12 +254,26 @@ func (backend *Backend) Sync(ctx context.Context, serverVersion, recordVersion u // SyncLatest returns a record stream of all the records. Some records may be returned twice if the are updated while the // stream is streaming. -func (backend *Backend) SyncLatest(ctx context.Context) (serverVersion uint64, stream storage.RecordStream, err error) { +func (backend *Backend) SyncLatest( + ctx context.Context, + recordType string, + expr storage.FilterExpression, +) (serverVersion, recordVersion uint64, stream storage.RecordStream, err error) { serverVersion, err = backend.getOrCreateServerVersion(ctx) if err != nil { - return 0, nil, err + return serverVersion, recordVersion, nil, err } - return serverVersion, newSyncLatestRecordStream(ctx, backend), nil + + recordVersion, err = backend.client.Get(ctx, lastVersionKey).Uint64() + if errors.Is(err, redis.Nil) { + // this happens if there are no records + err = nil + } else if err != nil { + return serverVersion, recordVersion, nil, err + } + + stream, err = newSyncLatestRecordStream(ctx, backend, recordType, expr) + return serverVersion, recordVersion, stream, err } func (backend *Backend) put(ctx context.Context, records []*databroker.Record) error { diff --git a/pkg/storage/redis/redis_test.go b/pkg/storage/redis/redis_test.go index f4f9acbd4..3e81e0470 100644 --- a/pkg/storage/redis/redis_test.go +++ b/pkg/storage/redis/redis_test.go @@ -240,7 +240,7 @@ func TestCapacity(t *testing.T) { require.NoError(t, err) } - _, stream, err := backend.SyncLatest(ctx) + _, _, stream, err := backend.SyncLatest(ctx, "EXAMPLE", nil) require.NoError(t, err) defer stream.Close() diff --git a/pkg/storage/redis/stream.go b/pkg/storage/redis/stream.go index 69ca9429e..b85355af2 100644 --- a/pkg/storage/redis/stream.go +++ b/pkg/storage/redis/stream.go @@ -63,23 +63,23 @@ func newSyncRecordStream( func newSyncLatestRecordStream( ctx context.Context, backend *Backend, -) storage.RecordStream { - var recordVersion, cursor uint64 + recordType string, + expr storage.FilterExpression, +) (storage.RecordStream, error) { + filter, err := storage.RecordStreamFilterFromFilterExpression(expr) + if err != nil { + return nil, err + } + if recordType != "" { + filter = filter.And(func(record *databroker.Record) (keep bool) { + return record.GetType() == recordType + }) + } + + var cursor uint64 scannedOnce := false var scannedRecords []*databroker.Record - return storage.NewRecordStream(ctx, backend.closed, []storage.RecordStreamGenerator{ - // 1. get the current record version - func(ctx context.Context, block bool) (*databroker.Record, error) { - var err error - recordVersion, err = backend.client.Get(ctx, lastVersionKey).Uint64() - if errors.Is(err, redis.Nil) { - // this happens if there are no records - } else if err != nil { - return nil, err - } - return nil, storage.ErrStreamDone - }, - // 2. stream all the records + generator := storage.FilteredRecordStreamGenerator( func(ctx context.Context, block bool) (*databroker.Record, error) { for { if len(scannedRecords) > 0 { @@ -102,11 +102,12 @@ func newSyncLatestRecordStream( scannedOnce = true } }, - // 3. stream any records which have been updated in the interim - func(ctx context.Context, block bool) (*databroker.Record, error) { - return nextChangedRecord(ctx, backend, &recordVersion) - }, - }, nil) + filter, + ) + + return storage.NewRecordStream(ctx, backend.closed, []storage.RecordStreamGenerator{ + generator, + }, nil), nil } func nextScannedRecords(ctx context.Context, backend *Backend, cursor *uint64) ([]*databroker.Record, error) { diff --git a/pkg/storage/storage.go b/pkg/storage/storage.go index 2ed3e353c..24df7df9a 100644 --- a/pkg/storage/storage.go +++ b/pkg/storage/storage.go @@ -40,7 +40,7 @@ type Backend interface { // Sync syncs record changes after the specified version. Sync(ctx context.Context, serverVersion, recordVersion uint64) (RecordStream, error) // SyncLatest syncs all the records. - SyncLatest(ctx context.Context) (serverVersion uint64, stream RecordStream, err error) + SyncLatest(ctx context.Context, recordType string, filter FilterExpression) (serverVersion, recordVersion uint64, stream RecordStream, err error) } // MatchAny searches any data with a query.