1 /*
2  * Copyright 2016-2020 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license.
3  */
4 
5 package kotlinx.coroutines.test
6 
7 import kotlinx.coroutines.*
8 import kotlin.coroutines.*
9 
/**
 * Executes a [testBody] inside an immediate execution dispatcher.
 *
 * This is similar to [runBlocking] but it will immediately progress past delays and into [launch] and [async] blocks.
 * You can use this to write tests that execute in the presence of calls to [delay] without causing your test to take
 * extra time.
 *
 * ```
 * @Test
 * fun exampleTest() = runBlockingTest {
 *     val deferred = async {
 *         delay(1_000)
 *         async {
 *             delay(1_000)
 *         }.await()
 *     }
 *
 *     deferred.await() // result available immediately
 * }
 * ```
 *
 * This method requires that all coroutines launched inside [testBody] complete, or are cancelled, as part of the test
 * conditions.
 *
 * Unhandled exceptions thrown by coroutines in the test will be re-thrown at the end of the test.
 *
 * @throws UncompletedCoroutinesError If the [testBody] does not complete (or cancel) all coroutines that it launches
 * (including coroutines suspended on join/await), or if [testBody] itself is still suspended once the test dispatcher
 * has no more work to run (for example, because it waits for work scheduled on a non-test dispatcher).
 *
 * @param context additional context elements. If [context] contains a dispatcher ([ContinuationInterceptor]) or a
 *        [CoroutineExceptionHandler], then they must implement [DelayController] and [UncaughtExceptionCaptor]
 *        respectively; otherwise an [IllegalArgumentException] is thrown. Missing elements are replaced by a
 *        [TestCoroutineDispatcher], a [TestCoroutineExceptionHandler] and a [SupervisorJob].
 * @param testBody The code of the unit-test.
 */
@ExperimentalCoroutinesApi // Since 1.2.1, tentatively till 1.3.0
public fun runBlockingTest(context: CoroutineContext = EmptyCoroutineContext, testBody: suspend TestCoroutineScope.() -> Unit) {
    val (safeContext, dispatcher) = context.checkArguments()
    // Snapshot the children that were already running, so only coroutines started by this test count as leaked.
    val startingJobs = safeContext.activeJobs()
    val scope = TestCoroutineScope(safeContext)
    val deferred = scope.async {
        scope.testBody()
    }
    // Run everything scheduled on the test dispatcher, skipping virtual time past delays.
    dispatcher.advanceUntilIdle()
    // getCompletionExceptionOrNull() throws an opaque IllegalStateException on an unfinished Deferred;
    // report a still-suspended test body as an uncompleted coroutine instead, as documented above.
    if (!deferred.isCompleted) {
        throw UncompletedCoroutinesError(
            "Test body is still suspended after the test dispatcher became idle: $deferred"
        )
    }
    deferred.getCompletionExceptionOrNull()?.let { throw it }
    scope.cleanupTestCoroutines()
    // Report only the jobs that became active during this call, not pre-existing children of the job.
    val leakedJobs = safeContext.activeJobs() - startingJobs
    if (leakedJobs.isNotEmpty()) {
        throw UncompletedCoroutinesError("Test finished with active jobs: $leakedJobs")
    }
}
62 
/**
 * Returns the children of this context's [Job] that are currently active, in the order the job reports them.
 *
 * Throws [IllegalStateException] when the context has no [Job]. In this file it is only called on contexts
 * returned by `checkArguments`, which always contain one.
 */
private fun CoroutineContext.activeJobs(): Set<Job> {
    val parentJob = checkNotNull(get(Job))
    return parentJob.children.filterTo(LinkedHashSet()) { child -> child.isActive }
}
66 
/**
 * Convenience method for calling [runBlockingTest] on an existing [TestCoroutineScope].
 *
 * The scope's `coroutineContext` (its dispatcher, exception handler and any [Job] it contains) is forwarded as the
 * `context` argument of [runBlockingTest]. The same validation and defaults therefore apply, and [block] receives a
 * new [TestCoroutineScope] built on that context.
 */
@ExperimentalCoroutinesApi // Since 1.2.1, tentatively till 1.3.0
public fun TestCoroutineScope.runBlockingTest(block: suspend TestCoroutineScope.() -> Unit) {
    runBlockingTest(context = coroutineContext, testBody = block)
}
74 
/**
 * Convenience method for calling [runBlockingTest] on an existing [TestCoroutineDispatcher].
 *
 * This dispatcher is passed as the whole `context` of [runBlockingTest], so it drives [block]. The exception handler
 * and job are filled in with the defaults [runBlockingTest] uses when they are absent.
 */
@ExperimentalCoroutinesApi // Since 1.2.1, tentatively till 1.3.0
public fun TestCoroutineDispatcher.runBlockingTest(block: suspend TestCoroutineScope.() -> Unit) {
    runBlockingTest(context = this, testBody = block)
}
81 
/**
 * Validates a user-supplied test context and fills in missing elements.
 *
 * - The [ContinuationInterceptor] defaults to a new [TestCoroutineDispatcher]; a supplied one must implement
 *   [DelayController].
 * - The [CoroutineExceptionHandler] defaults to a new [TestCoroutineExceptionHandler]; a supplied one must implement
 *   [UncaughtExceptionCaptor].
 * - The [Job] defaults to a new [SupervisorJob].
 *
 * @return the completed context paired with its dispatcher viewed as a [DelayController].
 * @throws IllegalArgumentException if a supplied dispatcher or exception handler lacks the required interface.
 */
private fun CoroutineContext.checkArguments(): Pair<CoroutineContext, DelayController> {
    // TODO optimize it
    val dispatcher: ContinuationInterceptor = get(ContinuationInterceptor) ?: TestCoroutineDispatcher()
    require(dispatcher is DelayController) { "Dispatcher must implement DelayController: $dispatcher" }

    val exceptionHandler: CoroutineExceptionHandler = get(CoroutineExceptionHandler) ?: TestCoroutineExceptionHandler()
    require(exceptionHandler is UncaughtExceptionCaptor) {
        "coroutineExceptionHandler must implement UncaughtExceptionCaptor: $exceptionHandler"
    }

    val job = get(Job) ?: SupervisorJob()
    // `require` smart-casts `dispatcher` to DelayController, so no explicit cast is needed here.
    val delayController: DelayController = dispatcher
    return Pair(this + dispatcher + exceptionHandler + job, delayController)
}
99