1 /*
2  * Copyright 2016-2018 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license.
3  */
4 
5 package kotlinx.coroutines
6 
7 import kotlin.test.*
8 import java.util.concurrent.ExecutorService
9 import java.util.concurrent.Executors
10 import java.util.concurrent.ThreadFactory
11 import java.util.concurrent.atomic.AtomicInteger
12 import kotlin.coroutines.CoroutineContext
13 
/**
 * Checks that a [withTimeoutOrNull] timeout inside a coroutine confined to a single worker thread
 * keeps everything on that worker thread.
 *
 * - The coroutine runs on the thread created by the test's [ThreadFactory].
 * - The [CancellationException] raised by the timeout is observed on that same thread.
 * - The code after `withTimeoutOrNull` also resumes on that same thread.
 *
 * Each test builds its dispatcher from a single-threaded executor, which is stored in [executor]
 * so that [tearDown] can shut it down.
 */
class WithTimeoutOrNullThreadDispatchTest : TestBase() {
    /** Executor backing the dispatcher under test; assigned by each test, shut down after it. */
    var executor: ExecutorService? = null

    /** Shuts down the executor created by the test, if any. */
    @AfterTest
    fun tearDown() {
        val pool = executor ?: return
        pool.shutdown()
    }

    /** Uses the dispatcher obtained from a one-thread scheduled thread pool. */
    @Test
    fun testCancellationDispatchScheduled() {
        checkCancellationDispatch { threadFactory ->
            val pool: ExecutorService = Executors.newScheduledThreadPool(1, threadFactory)
            executor = pool
            pool.asCoroutineDispatcher()
        }
    }

    /** Uses the dispatcher obtained from a plain (non-scheduled) single-thread executor. */
    @Test
    fun testCancellationDispatchNonScheduled() {
        checkCancellationDispatch { threadFactory ->
            val pool: ExecutorService = Executors.newSingleThreadExecutor(threadFactory)
            executor = pool
            pool.asCoroutineDispatcher()
        }
    }

    /**
     * Uses a hand-written dispatcher that overrides only `dispatch` and forwards every block to a
     * single-thread executor.
     *
     * It also checks that at most one scheduled request is in flight at a time (no spurious
     * concurrency). A violation is recorded during the run and reported once the run finishes.
     */
    @Test
    fun testCancellationDispatchCustomNoDelay() {
        var failure: String? = null
        checkCancellationDispatch { threadFactory ->
            val pool: ExecutorService = Executors.newSingleThreadExecutor(threadFactory)
            executor = pool
            // Counts blocks handed to `dispatch` that have not yet started running on the pool.
            val inFlight = AtomicInteger(0)
            object : CoroutineDispatcher() {
                override fun dispatch(context: CoroutineContext, block: Runnable) {
                    val pending = inFlight.incrementAndGet()
                    if (pending > 1) {
                        failure = "Two requests are scheduled concurrently"
                    }
                    pool.execute {
                        inFlight.decrementAndGet()
                        block.run()
                    }
                }
            }
        }
        val message = failure
        if (message != null) error(message)
    }

    /**
     * Runs the shared timeout scenario on the dispatcher returned by [factory].
     *
     * [factory] receives a [ThreadFactory] that remembers the thread it creates. Every thread
     * assertion below compares against that remembered thread. The step order 1..6 is enforced
     * through [expect]/[finish].
     */
    private fun checkCancellationDispatch(factory: (ThreadFactory) -> CoroutineDispatcher) = runBlocking {
        expect(1)
        var workerThread: Thread? = null
        val recordingFactory = ThreadFactory { task ->
            val created = Thread(task)
            workerThread = created
            created
        }
        val dispatcher = factory(recordingFactory)
        withContext(dispatcher) {
            expect(2)
            assertEquals(workerThread, Thread.currentThread())
            // The 100 ms timeout is far shorter than the 1000 ms delay, so the delay must be cancelled.
            val outcome = withTimeoutOrNull(100) {
                try {
                    expect(3)
                    delay(1000)
                    expectUnreached()
                } catch (cancellation: CancellationException) {
                    expect(4)
                    // The cancellation must be delivered on the dispatcher's own thread.
                    assertEquals(workerThread, Thread.currentThread())
                    throw cancellation // propagate so withTimeoutOrNull can complete with null
                }
            }
            assertEquals(workerThread, Thread.currentThread())
            assertEquals(null, outcome)
            expect(5)
        }
        finish(6)
    }
}