diff --git a/pkg/storage/inmemory/backend_test.go b/pkg/storage/inmemory/backend_test.go index e67f38a5c..f4cfeb7b4 100644 --- a/pkg/storage/inmemory/backend_test.go +++ b/pkg/storage/inmemory/backend_test.go @@ -14,6 +14,7 @@ import ( "google.golang.org/protobuf/types/known/timestamppb" "github.com/pomerium/pomerium/pkg/grpc/databroker" + "github.com/pomerium/pomerium/pkg/storage" ) func TestBackend(t *testing.T) { @@ -142,6 +143,7 @@ func TestStream(t *testing.T) { eg.Go(func() error { for i := 0; i < 10000; i++ { assert.True(t, stream.Next(true)) + assert.Nil(t, stream.Err()) assert.Equal(t, "TYPE", stream.Record().GetType()) assert.Equal(t, fmt.Sprint(i), stream.Record().GetId()) assert.Equal(t, uint64(i+1), stream.Record().GetVersion()) @@ -161,6 +163,35 @@ func TestStream(t *testing.T) { require.NoError(t, eg.Wait()) } +func TestStreamClose(t *testing.T) { + ctx := context.Background() + t.Run("by backend", func(t *testing.T) { + backend := New() + stream, err := backend.Sync(ctx, backend.serverVersion, 0) + require.NoError(t, err) + require.NoError(t, backend.Close()) + assert.False(t, stream.Next(true)) + assert.Equal(t, storage.ErrStreamClosed, stream.Err()) + }) + t.Run("by stream", func(t *testing.T) { + backend := New() + stream, err := backend.Sync(ctx, backend.serverVersion, 0) + require.NoError(t, err) + require.NoError(t, stream.Close()) + assert.False(t, stream.Next(true)) + assert.Equal(t, storage.ErrStreamClosed, stream.Err()) + }) + t.Run("by context", func(t *testing.T) { + ctx, cancel := context.WithCancel(ctx) + backend := New() + stream, err := backend.Sync(ctx, backend.serverVersion, 0) + require.NoError(t, err) + cancel() + assert.False(t, stream.Next(true)) + assert.Equal(t, context.Canceled, stream.Err()) + }) +} + func TestCapacity(t *testing.T) { ctx := context.Background() backend := New() diff --git a/pkg/storage/inmemory/stream.go b/pkg/storage/inmemory/stream.go index 7fd451558..b48cf1961 100644 --- a/pkg/storage/inmemory/stream.go +++ b/pkg/storage/inmemory/stream.go @@ -99,11 +99,20 @@ func (stream *recordStream) Err() error { select { case <-stream.ctx.Done(): return stream.ctx.Err() - case <-stream.closed: - return storage.ErrStreamClosed + default: + } + + select { case <-stream.backend.closed: return storage.ErrStreamClosed default: - return nil } + + select { + case <-stream.closed: + return storage.ErrStreamClosed + default: + } + + return nil }