1 /*
2  * Copyright 2016-2019 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license.
3  */
4 
5 package kotlinx.coroutines
6 
7 import java.io.*
8 import java.util.concurrent.*
9 import java.util.concurrent.locks.*
10 
// Timeout handed to DefaultExecutor.shutdown (presumably milliseconds -- confirm against DefaultExecutor).
private const val SHUTDOWN_TIMEOUT = 1000L

/**
 * Runs [block] while a freshly created [VirtualTimeSource] is installed as the global `timeSource`.
 *
 * The default executor is shut down before the swap and started again right after it. Once [block]
 * completes (normally or with an exception) the executor and the virtual source are shut down and
 * `timeSource` is reset to `null`.
 *
 * @param log optional stream that is passed on to [VirtualTimeSource]; `null` by default.
 * @param block the code to execute under virtual time.
 */
internal inline fun withVirtualTimeSource(log: PrintStream? = null, block: () -> Unit) {
    // Stop the executor that may still be running with the previous time source.
    DefaultExecutor.shutdown(SHUTDOWN_TIMEOUT)
    val virtualSource = VirtualTimeSource(log).also { timeSource = it }
    // The executor should now start with the source installed above.
    DefaultExecutor.ensureStarted()
    try {
        block()
    } finally {
        DefaultExecutor.shutdown(SHUTDOWN_TIMEOUT)
        virtualSource.shutdown()
        // Restore the time source.
        timeSource = null
    }
}
26 
// Sentinel kept in ThreadStatus.parkedTill while the thread is not parked.
private const val NOT_PARKED = -1L

/**
 * Per-thread bookkeeping record stored by [VirtualTimeSource] for every registered thread.
 */
private class ThreadStatus {
    /**
     * Virtual-time deadline, in nanoseconds, until which the thread is parked,
     * or [NOT_PARKED] when it is not parked.
     */
    @Volatile @JvmField
    var parkedTill = NOT_PARKED

    /** Set by `unpark`; consumed (reset to `false`) when `parkNanos` returns. */
    @Volatile @JvmField
    var permit = false

    /**
     * Number of outstanding `registerTimeLoopThread` calls for the thread.
     * Not volatile: it is only touched from the `@Synchronized` register/unregister methods.
     */
    var registered = 0

    /**
     * Debug representation. The [NOT_PARKED] sentinel is printed by name instead of being converted
     * to milliseconds, where it would show up as a misleading `0 ms`.
     */
    override fun toString(): String {
        val deadline = parkedTill // read the volatile field once so the check and the output agree
        val parked = if (deadline == NOT_PARKED) "NOT_PARKED" else "${TimeUnit.NANOSECONDS.toMillis(deadline)} ms"
        return "parkedTill = $parked, permit = $permit"
    }
}
37 
private const val MAX_WAIT_NANOS = 10_000_000_000L // 10s -- upper bound applied to every virtual park
private const val REAL_TIME_STEP_NANOS = 200_000_000L // 200 ms
private const val REAL_PARK_NANOS = 10_000_000L // 10 ms -- park for a little to better track real-time

/**
 * A [TimeSource] whose clock is virtual: it starts at zero and only moves inside [checkAdvanceTime].
 *
 * Time moves in two ways:
 *  - "R" (real) step: whenever more than [REAL_TIME_STEP_NANOS] of real time elapsed since the last
 *    checkpoint, virtual time advances by one step, but never past the earliest parked deadline
 *    when every registered thread is parked.
 *  - "V" (virtual) jump: when the thread that created this instance is registered, no tasks are
 *    tracked, and every registered thread is parked without a pending permit, virtual time jumps
 *    straight to the earliest parked deadline.
 *
 * Mutable state other than the volatile fields is guarded by the instance monitor (`@Synchronized`).
 *
 * @param log optional stream that receives one line per time change; nothing is logged when `null`.
 */
@Suppress("PLATFORM_CLASS_MAPPED_TO_KOTLIN")
internal class VirtualTimeSource(
    private val log: PrintStream?
) : TimeSource {
    // The thread that constructed this instance; virtual jumps require it to be registered.
    private val mainThread: Thread = Thread.currentThread()

    // Real-time (System.nanoTime) reading taken at the last "R" step.
    private var checkpointNanos: Long = System.nanoTime()

    @Volatile
    private var isShutdown = false

    // Current virtual time in nanoseconds.
    @Volatile
    private var time: Long = 0

    // Number of tasks registered via trackTask() and not yet released via unTrackTask().
    private var trackedTasks = 0

    private val threads = ConcurrentHashMap<Thread, ThreadStatus>()

    override fun currentTimeMillis(): Long = TimeUnit.NANOSECONDS.toMillis(time)
    override fun nanoTime(): Long = time

    /**
     * Tracks a task immediately and returns a wrapper around [block] that un-tracks it once
     * [block] has run, whether it completed normally or threw.
     */
    override fun wrapTask(block: Runnable): Runnable {
        trackTask()
        return Runnable {
            try { block.run() }
            finally { unTrackTask() }
        }
    }

    @Synchronized
    override fun trackTask() {
        trackedTasks++
    }

    @Synchronized
    override fun unTrackTask() {
        assert(trackedTasks > 0)
        trackedTasks--
    }

    /** Registers the calling thread; registrations nest and are counted per thread. */
    @Synchronized
    override fun registerTimeLoopThread() {
        val status = threads.getOrPut(Thread.currentThread()) { ThreadStatus() }
        status.registered++
    }

    /**
     * Releases one registration of the calling thread. When the last one is released the thread is
     * removed and all waiters are woken up (this is what lets [shutdown] finish).
     *
     * @throws IllegalStateException if the calling thread is not registered.
     */
    @Synchronized
    override fun unregisterTimeLoopThread() {
        val currentThread = Thread.currentThread()
        val status = checkNotNull(threads[currentThread]) {
            "Thread $currentThread is not registered with this time source"
        }
        if (--status.registered == 0) {
            threads.remove(currentThread)
            wakeupAll()
        }
    }

    /**
     * Parks the calling (registered) thread in virtual time for [nanos], capped at [MAX_WAIT_NANOS].
     * Returns when the virtual deadline is reached, a permit was granted via [unpark], or this
     * source has been shut down. Non-positive [nanos] returns immediately.
     *
     * @throws IllegalStateException if the calling thread is not registered.
     */
    override fun parkNanos(blocker: Any, nanos: Long) {
        if (nanos <= 0) return
        val currentThread = Thread.currentThread()
        val status = checkNotNull(threads[currentThread]) {
            "Thread $currentThread is not registered with this time source"
        }
        assert(status.parkedTill == NOT_PARKED)
        status.parkedTill = time + nanos.coerceAtMost(MAX_WAIT_NANOS)
        while (true) {
            checkAdvanceTime()
            if (isShutdown || time >= status.parkedTill || status.permit) {
                status.parkedTill = NOT_PARKED
                status.permit = false
                break
            }
            // Really park only briefly so that the loop keeps re-checking whether time can advance.
            LockSupport.parkNanos(blocker, REAL_PARK_NANOS)
        }
    }

    /** Grants a permit to [thread] and wakes it up; does nothing for an unregistered thread. */
    override fun unpark(thread: Thread) {
        val status = threads[thread] ?: return
        status.permit = true
        LockSupport.unpark(thread)
    }

    /** Advances virtual time if one of the two conditions described on the class holds. */
    @Synchronized
    private fun checkAdvanceTime() {
        if (isShutdown) return
        val realNanos = System.nanoTime()
        // Compare the difference, not absolute values: System.nanoTime() readings may overflow,
        // and only their difference is meaningful.
        if (realNanos - checkpointNanos > REAL_TIME_STEP_NANOS) {
            checkpointNanos = realNanos
            val earliest = minParkedTill()
            // A negative value means some thread is running, so there is no deadline to cap at.
            time = (time + REAL_TIME_STEP_NANOS).coerceAtMost(if (earliest < 0) Long.MAX_VALUE else earliest)
            logTime("R")
            wakeupAll()
            return
        }
        if (threads[mainThread] == null) return
        if (trackedTasks != 0) return
        val earliest = minParkedTill()
        if (earliest <= time) return
        time = earliest
        logTime("V")
        wakeupAll()
    }

    private fun logTime(s: String) {
        log?.println("[$s: Time = ${TimeUnit.NANOSECONDS.toMillis(time)} ms]")
    }

    /**
     * Returns the smallest parked deadline over all registered threads, where a thread that holds a
     * permit counts as [NOT_PARKED]. Returns [NOT_PARKED] when no thread is registered.
     *
     * Written as an explicit loop so that it does not depend on `Iterable.min()`, whose
     * return type and empty-collection behaviour differ between Kotlin stdlib versions.
     */
    private fun minParkedTill(): Long {
        var smallest = Long.MAX_VALUE
        var sawThread = false
        for (status in threads.values) {
            val deadline = if (status.permit) NOT_PARKED else status.parkedTill
            if (deadline < smallest) smallest = deadline
            sawThread = true
        }
        return if (sawThread) smallest else NOT_PARKED
    }

    /** Marks this source as shut down, wakes every thread, and blocks until all have unregistered. */
    @Synchronized
    fun shutdown() {
        isShutdown = true
        wakeupAll()
        while (!threads.isEmpty()) (this as Object).wait()
    }

    // Must be called while holding this instance's monitor (notifyAll requires it).
    private fun wakeupAll() {
        threads.keys.forEach { LockSupport.unpark(it) }
        (this as Object).notifyAll()
    }
}
159