diff --git a/paging/paging-common/src/test/kotlin/androidx/paging/PagingDataDifferTest.kt b/paging/paging-common/src/test/kotlin/androidx/paging/PagingDataDifferTest.kt index 1aab76c6b51bd..8a2f69ccbd6dd 100644 --- a/paging/paging-common/src/test/kotlin/androidx/paging/PagingDataDifferTest.kt +++ b/paging/paging-common/src/test/kotlin/androidx/paging/PagingDataDifferTest.kt @@ -31,7 +31,6 @@ import kotlin.test.assertFalse import kotlin.test.assertNull import kotlin.test.assertTrue import kotlinx.coroutines.CoroutineDispatcher -import kotlinx.coroutines.CoroutineScope import kotlinx.coroutines.ExperimentalCoroutinesApi import kotlinx.coroutines.Job import kotlinx.coroutines.async @@ -1465,7 +1464,7 @@ class PagingDataDifferTest( @Test fun refresh_loadStates() = runTest(initialKey = 50) { differ, loadDispatcher, pagingSources, _, _ -> - val collectLoadStates = differ.collectLoadStates() + val collectLoadStates = launch { differ.collectLoadStates() } // execute queued initial REFRESH loadDispatcher.scheduler.advanceUntilIdle() @@ -1498,7 +1497,7 @@ class PagingDataDifferTest( differ.addLoadStateListener { loadStateCallbacks.add(it) } - val collectLoadStates = differ.collectLoadStates() + val collectLoadStates = launch { differ.collectLoadStates() } // execute initial refresh loadDispatcher.scheduler.advanceUntilIdle() assertThat(differ.snapshot()).containsExactlyElementsIn(0 until 9) @@ -1545,7 +1544,7 @@ class PagingDataDifferTest( @Test fun appendInvalid_loadStates() = runTest { differ, loadDispatcher, pagingSources, _, _ -> - val collectLoadStates = differ.collectLoadStates() + val collectLoadStates = launch { differ.collectLoadStates() } // initial REFRESH loadDispatcher.scheduler.advanceUntilIdle() @@ -1607,7 +1606,7 @@ class PagingDataDifferTest( @Test fun prependInvalid_loadStates() = runTest(initialKey = 50) { differ, loadDispatcher, pagingSources, _, _ -> - val collectLoadStates = differ.collectLoadStates() + val collectLoadStates = launch { differ.collectLoadStates() } // initial REFRESH loadDispatcher.scheduler.advanceUntilIdle() @@ -1661,7 +1660,7 @@ class PagingDataDifferTest( @Test fun refreshInvalid_loadStates() = runTest(initialKey = 50) { differ, loadDispatcher, pagingSources, _, _ -> - val collectLoadStates = differ.collectLoadStates() + val collectLoadStates = launch { differ.collectLoadStates() } // execute queued initial REFRESH load which will return LoadResult.Invalid() pagingSources[0].nextLoadResult = LoadResult.Invalid() @@ -1691,7 +1690,7 @@ class PagingDataDifferTest( @Test fun appendError_retryLoadStates() = runTest { differ, loadDispatcher, pagingSources, _, _ -> - val collectLoadStates = differ.collectLoadStates() + val collectLoadStates = launch { differ.collectLoadStates() } // initial REFRESH loadDispatcher.scheduler.advanceUntilIdle() @@ -1742,7 +1741,7 @@ class PagingDataDifferTest( @Test fun prependError_retryLoadStates() = runTest(initialKey = 50) { differ, loadDispatcher, pagingSources, _, _ -> - val collectLoadStates = differ.collectLoadStates() + val collectLoadStates = launch { differ.collectLoadStates() } // initial REFRESH loadDispatcher.scheduler.advanceUntilIdle() @@ -1784,7 +1783,7 @@ class PagingDataDifferTest( @Test fun refreshError_retryLoadStates() = runTest { differ, loadDispatcher, pagingSources, _, _ -> - val collectLoadStates = differ.collectLoadStates() + val collectLoadStates = launch { differ.collectLoadStates() } // initial load returns LoadResult.Error val exception = Throwable() @@ -1817,7 +1816,7 @@ class PagingDataDifferTest( @Test fun prependError_refreshLoadStates() = runTest(initialKey = 50) { differ, loadDispatcher, pagingSources, _, _ -> - val collectLoadStates = differ.collectLoadStates() + val collectLoadStates = launch { differ.collectLoadStates() } // initial REFRESH loadDispatcher.scheduler.advanceUntilIdle() @@ -1861,7 +1860,7 @@ class PagingDataDifferTest( @Test fun refreshError_refreshLoadStates() = runTest { differ, loadDispatcher, pagingSources, _, _ -> - val collectLoadStates = differ.collectLoadStates() + val collectLoadStates = launch { differ.collectLoadStates() } // the initial load will return LoadResult.Error val exception = Throwable() @@ -1903,7 +1902,7 @@ class PagingDataDifferTest( TestPagingSource(loadDelay = 500, items = emptyList()) } - val collectLoadStates = differ.collectLoadStates() + val collectLoadStates = launch { differ.collectLoadStates() } val job = launch { pager.flow.collectLatest { differ.collectFrom(it) @@ -2000,9 +1999,8 @@ class PagingDataDifferTest( val differ = SimpleDiffer( differCallback = dummyDifferCallback, - coroutineScope = backgroundScope, ) - differ.collectLoadStates() + backgroundScope.launch { differ.collectLoadStates() } val job = launch { pager.collectLatest { @@ -2019,9 +2017,8 @@ class PagingDataDifferTest( // we start a separate differ to recollect on cached Pager.flow val differ2 = SimpleDiffer( differCallback = dummyDifferCallback, - coroutineScope = backgroundScope, ) - differ2.collectLoadStates() + backgroundScope.launch { differ2.collectLoadStates() } val job2 = launch { pager.collectLatest { @@ -2199,7 +2196,7 @@ class PagingDataDifferTest( ).also { pagingSources.add(it) } } ), - block: ( + block: TestScope.( differ: SimpleDiffer, loadDispatcher: TestDispatcher, pagingSources: List, @@ -2209,7 +2206,6 @@ class PagingDataDifferTest( ) = testScope.runTest { val differ = SimpleDiffer( differCallback = dummyDifferCallback, - coroutineScope = this, ) val uiReceivers = mutableListOf() val hintReceivers = mutableListOf() @@ -2345,11 +2341,9 @@ private class TrackableHintReceiverWrapper( } } -@OptIn(ExperimentalCoroutinesApi::class) private class SimpleDiffer( differCallback: DifferCallback, cachedPagingData: PagingData? = null, - val coroutineScope: CoroutineScope = TestScope(UnconfinedTestDispatcher()) ) : PagingDataDiffer(differCallback = differCallback, cachedPagingData = cachedPagingData) { override suspend fun presentNewList( previousList: NullPaddedList, @@ -2371,11 +2365,9 @@ private class SimpleDiffer( return newCombinedLoadStates } - fun collectLoadStates(): Job { - return coroutineScope.launch { - nonNullLoadStateFlow.collect { combinedLoadStates -> - _localLoadStates.add(combinedLoadStates) - } + suspend fun collectLoadStates() { + nonNullLoadStateFlow.collect { combinedLoadStates -> + _localLoadStates.add(combinedLoadStates) } } }