/*
 * Copyright 2016-2019 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license.
 */
4 
5 package kotlinx.coroutines
6 
7 import kotlinx.coroutines.internal.*
8 import kotlinx.coroutines.scheduling.*
9 import org.junit.*
10 import java.lang.Math.*
11 import java.util.*
12 import java.util.concurrent.atomic.*
13 import kotlin.coroutines.*
14 import kotlin.math.*
15 import kotlin.test.*
16 
private val VERBOSE = systemProp("test.verbose", false)

/**
 * Is `true` when running in a nightly stress test mode, as requested via the `stressTest` system property.
 * The property value is parsed with `toBoolean()`; when the property is not set at all, this is `false`.
 */
public actual val isStressTest: Boolean = System.getProperty("stressTest").let { prop -> prop != null && prop.toBoolean() }

/**
 * Square root of [stressTestMultiplier]: `5` in stress test mode and `1` otherwise.
 */
public val stressTestMultiplierSqrt: Int = when {
    isStressTest -> 5
    else -> 1
}

/**
 * Multiply various constants in stress tests by this factor, so that they run longer during nightly stress test.
 * Defined as the square of [stressTestMultiplierSqrt].
 */
public actual val stressTestMultiplier: Int = stressTestMultiplierSqrt.let { root -> root * root }

/**
 * Cube root of [stressTestMultiplier], converted back to an integer with `roundToInt()`.
 */
public val stressTestMultiplierCbrt: Int = stressTestMultiplier.toDouble().let { value -> cbrt(value) }.roundToInt()
32 
/**
 * Base class for tests, so that tests for predictable scheduling of actions in multiple coroutines sharing a single
 * thread can be written. Use it like this:
 *
 * ```
 * class MyTest : TestBase() {
 *    @Test
 *    fun testSomething() = runBlocking { // run in the context of the main thread
 *        expect(1) // initiate action counter
 *        launch { // use the context of the main thread
 *           expect(3) // the body of this coroutine is going to be executed in the 3rd step
 *        }
 *        expect(2) // launch only schedules the coroutine for later execution, so this line is executed second
 *        yield() // yield main thread to the launched job
 *        finish(4) // fourth step is the last one. `finish` must be invoked or test fails
 *    }
 * }
 * ```
 *
 * Problems detected through [error], [check], [expect], [finish] and [reset] are also recorded internally, and the
 * first recorded problem is rethrown by the `@After` hook, so a test cannot pass even if the thrown exception was
 * swallowed by the code under test.
 */
public actual open class TestBase actual constructor() {
    // Number of actions performed so far; advanced by expect()
    private val actionIndex = AtomicInteger()
    // Becomes true once finish() has been invoked
    private val finished = AtomicBoolean()
    // First error detected during the test; rethrown at the end of onCompletion()
    private val error = AtomicReference<Throwable>()

    // Shutdown sequence
    private lateinit var threadsBefore: Set<Thread>
    private val uncaughtExceptions = Collections.synchronizedList(ArrayList<Throwable>())
    private var originalUncaughtExceptionHandler: Thread.UncaughtExceptionHandler? = null
    private val SHUTDOWN_TIMEOUT = 1_000L // 1s at most to wait per thread

    /**
     * Throws [IllegalStateException] like `error` in stdlib, but also ensures that the test will not
     * complete successfully even if this exception is consumed somewhere in the test.
     */
    @Suppress("ACTUAL_FUNCTION_WITH_DEFAULT_ARGUMENTS")
    public actual fun error(message: Any, cause: Throwable? = null): Nothing {
        throw makeError(message, cause)
    }

    /**
     * Returns `true` if an error has already been recorded for the current test.
     */
    public fun hasError() = error.get() != null

    // Creates an IllegalStateException and records it as the test error (only the first one is kept).
    private fun makeError(message: Any, cause: Throwable? = null): IllegalStateException =
        IllegalStateException(message.toString(), cause).also {
            setError(it)
        }

    // Records the exception as the test error unless an earlier error has already been recorded.
    private fun setError(exception: Throwable) {
        error.compareAndSet(null, exception)
    }

    // Records the cause and prints it to stdout together with the stack trace of the detection site.
    private fun printError(message: String, cause: Throwable) {
        setError(cause)
        println("$message: $cause")
        cause.printStackTrace(System.out)
        println("--- Detected at ---")
        Throwable().printStackTrace(System.out)
    }

    /**
     * Throws [IllegalStateException] when `value` is false like `check` in stdlib, but also ensures that the
     * test will not complete successfully even if this exception is consumed somewhere in the test.
     */
    public inline fun check(value: Boolean, lazyMessage: () -> Any) {
        if (!value) error(lazyMessage())
    }

    /**
     * Asserts that this invocation is `index`-th in the execution sequence (counting from one).
     */
    public actual fun expect(index: Int) {
        val wasIndex = actionIndex.incrementAndGet()
        if (VERBOSE) println("expect($index), wasIndex=$wasIndex")
        check(index == wasIndex) { "Expecting action index $index but it is actually $wasIndex" }
    }

    /**
     * Asserts that this line is never executed.
     */
    public actual fun expectUnreached() {
        error("Should not be reached, current action index is ${actionIndex.get()}")
    }

    /**
     * Asserts that this is the last action in the test. It must be invoked by any test that used [expect].
     */
    public actual fun finish(index: Int) {
        expect(index)
        check(!finished.getAndSet(true)) { "Should call 'finish(...)' at most once" }
    }

    /**
     * Asserts that [finish] was invoked.
     *
     * @throws IllegalArgumentException if [finish] has not been called yet.
     */
    public actual fun ensureFinished() {
        require(finished.get()) { "finish(...) should be called prior to this check" }
    }

    /**
     * Resets the action counter so that a new [expect]/[finish] sequence can start.
     * Fails if actions were performed but [finish] was not invoked.
     */
    public actual fun reset() {
        check(actionIndex.get() == 0 || finished.get()) { "Expecting that 'finish(...)' was invoked, but it was not" }
        actionIndex.set(0)
        finished.set(false)
    }

    /**
     * Switches the pools to private instances, snapshots the live threads and installs a default
     * uncaught exception handler that collects every uncaught exception.
     */
    @Before
    fun before() {
        initPoolsBeforeTest()
        threadsBefore = currentThreads()
        originalUncaughtExceptionHandler = Thread.getDefaultUncaughtExceptionHandler()
        Thread.setDefaultUncaughtExceptionHandler { t, e ->
            println("Exception in thread $t: $e") // The same message as in default handler
            e.printStackTrace()
            uncaughtExceptions.add(e)
        }
    }

    /**
     * Performs the whole cleanup sequence and only then rethrows the first error recorded during the test.
     */
    @After
    fun onCompletion() {
        // onCompletion should not throw exceptions before it finishes all cleanup, so that other tests always
        // start in a clear, restored state
        if (actionIndex.get() != 0 && !finished.get()) {
            makeError("Expecting that 'finish(${actionIndex.get() + 1})' was invoked, but it was not")
        }
        // Shutdown all thread pools; a failure is recorded instead of thrown, so the remaining cleanup still runs
        runCatching {
            shutdownPoolsAfterTest()
        }.onFailure {
            setError(it)
        }
        // Check that there are no leftover threads
        runCatching {
            checkTestThreads(threadsBefore)
        }.onFailure {
            setError(it)
        }
        // Restore original uncaught exception handler
        Thread.setDefaultUncaughtExceptionHandler(originalUncaughtExceptionHandler)
        if (uncaughtExceptions.isNotEmpty()) {
            makeError("Expected no uncaught exceptions, but got $uncaughtExceptions")
        }
        // The very last action -- throw error if any was detected
        error.get()?.let { throw it }
    }

    /**
     * Makes [CommonPool] and [DefaultScheduler] use private instances for the duration of the test.
     */
    fun initPoolsBeforeTest() {
        CommonPool.usePrivatePool()
        DefaultScheduler.usePrivateScheduler()
    }

    /**
     * Shuts down the pools (waiting up to [SHUTDOWN_TIMEOUT] ms each) and restores the shared instances.
     */
    fun shutdownPoolsAfterTest() {
        CommonPool.shutdown(SHUTDOWN_TIMEOUT)
        DefaultScheduler.shutdown(SHUTDOWN_TIMEOUT)
        DefaultExecutor.shutdown(SHUTDOWN_TIMEOUT)
        CommonPool.restore()
        DefaultScheduler.restore()
    }

    /**
     * Runs [block] via [runBlocking] with a [CoroutineExceptionHandler] that verifies unhandled exceptions.
     *
     * @param expected when non-null, [block] must fail with an exception accepted by this predicate; an exception it
     *   rejects, or normal completion of [block], fails the test. When `null`, any exception from [block] is rethrown.
     * @param unhandled predicates matched, in delivery order, against the non-cancellation exceptions that reach the
     *   exception handler; receiving more or fewer such exceptions than there are predicates fails the test.
     */
    @Suppress("ACTUAL_WITHOUT_EXPECT", "ACTUAL_FUNCTION_WITH_DEFAULT_ARGUMENTS")
    public actual fun runTest(
        expected: ((Throwable) -> Boolean)? = null,
        unhandled: List<(Throwable) -> Boolean> = emptyList(),
        block: suspend CoroutineScope.() -> Unit
    ) {
        // Atomic: the handler may be invoked from whatever thread the failing coroutine runs on
        val unhandledCount = AtomicInteger()
        val thrown: Throwable? = try {
            runBlocking(block = block, context = CoroutineExceptionHandler { _, e ->
                if (e is CancellationException) return@CoroutineExceptionHandler // are ignored
                val count = unhandledCount.incrementAndGet()
                when {
                    count > unhandled.size ->
                        printError("Too many unhandled exceptions $count, expected ${unhandled.size}, got: $e", e)
                    !unhandled[count - 1](e) ->
                        printError("Unhandled exception was unexpected: $e", e)
                }
            })
            null
        } catch (e: Throwable) {
            if (expected == null) throw e
            if (!expected(e)) error("Unexpected exception: $e", e)
            e
        }
        if (thrown == null && expected != null) error("Exception was expected but none produced")
        val handledCount = unhandledCount.get()
        if (handledCount < unhandled.size)
            error("Too few unhandled exceptions $handledCount, expected ${unhandled.size}")
    }

    /**
     * Runs [block] and asserts that it throws an exception of type [T], which is returned.
     */
    protected inline fun <reified T: Throwable> assertFailsWith(block: () -> Unit): T {
        val result = runCatching(block)
        val exception = result.exceptionOrNull()
        assertTrue(exception is T, "Expected ${T::class}, but had $result")
        return exception as T
    }

    /**
     * Returns the [ContinuationInterceptor] of the calling coroutine; throws if the context has none.
     */
    protected suspend fun currentDispatcher() = coroutineContext[ContinuationInterceptor]!!
}
226