/*
 * Copyright 2016-2019 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license.
 */

package kotlinx.coroutines

import java.io.*
import java.util.concurrent.*
import java.util.concurrent.locks.*

private const val SHUTDOWN_TIMEOUT = 1000L

/**
 * Runs [block] with a fresh [VirtualTimeSource] installed as the global `timeSource`.
 *
 * [DefaultExecutor] is shut down first, so that it is restarted against the virtual clock. After [block]
 * completes (normally or exceptionally), [DefaultExecutor] is shut down again. Then the virtual source is
 * shut down and the global `timeSource` is reset to `null`.
 *
 * @param log optional stream that receives a `[R|V: Time = ... ms]` line whenever the virtual clock is stepped.
 * @param block the code to run under virtual time.
 */
internal inline fun withVirtualTimeSource(log: PrintStream? = null, block: () -> Unit) {
    DefaultExecutor.shutdown(SHUTDOWN_TIMEOUT) // shutdown execution with old time source (in case it was working)
    val testTimeSource = VirtualTimeSource(log)
    timeSource = testTimeSource
    DefaultExecutor.ensureStarted() // should start with new time source
    try {
        block()
    } finally {
        DefaultExecutor.shutdown(SHUTDOWN_TIMEOUT)
        testTimeSource.shutdown()
        timeSource = null // restore time source
    }
}

/** Sentinel for [ThreadStatus.parkedTill] meaning "this thread is not currently parked". */
private const val NOT_PARKED = -1L

/** Per-thread bookkeeping for threads registered with a [VirtualTimeSource]. */
private class ThreadStatus {
    /** Virtual deadline (nanos) of the thread's current park, or [NOT_PARKED]. */
    @Volatile @JvmField
    var parkedTill = NOT_PARKED

    /** Set by `unpark`; consumed (reset to `false`) when the thread's park loop exits. */
    @Volatile @JvmField
    var permit = false

    /** Nesting depth of `registerTimeLoopThread` calls; only touched under the source's monitor. */
    var registered = 0

    override fun toString(): String = "parkedTill = ${TimeUnit.NANOSECONDS.toMillis(parkedTill)} ms, permit = $permit"
}

/** Upper bound on a single park request, in virtual nanoseconds. */
private const val MAX_WAIT_NANOS = 10_000_000_000L // 10s
/** Amount of real time after which the virtual clock is stepped forward by the same amount. */
private const val REAL_TIME_STEP_NANOS = 200_000_000L // 200 ms
private const val REAL_PARK_NANOS = 10_000_000L // 10 ms -- park for a little to better track real-time

/**
 * A [TimeSource] for tests whose clock is driven virtually rather than by the wall clock.
 *
 * The clock moves forward in two ways:
 * - **Real-time step ("R")**: once more than [REAL_TIME_STEP_NANOS] of real time has passed since the last
 *   step, the clock advances by that amount. The step is capped at the earliest park deadline, but only when
 *   every registered thread is parked without a pending permit. The clock never moves backwards.
 * - **Virtual jump ("V")**: when the thread that created this source is registered, no tracked tasks are
 *   running, and every registered thread is parked without a pending permit, the clock jumps straight to the
 *   earliest park deadline.
 *
 * @param log optional stream that receives a line each time the clock is stepped.
 */
@Suppress("PLATFORM_CLASS_MAPPED_TO_KOTLIN")
internal class VirtualTimeSource(
    private val log: PrintStream?
) : TimeSource {
    private val mainThread: Thread = Thread.currentThread()

    /** Real `System.nanoTime()` at the last real-time step; only accessed under this object's monitor. */
    private var checkpointNanos: Long = System.nanoTime()

    @Volatile
    private var isShutdown = false

    /** Current virtual time in nanoseconds; written only under this object's monitor. */
    @Volatile
    private var time: Long = 0

    /** Number of tasks in flight (see [trackTask]/[unTrackTask]); guarded by this object's monitor. */
    private var trackedTasks = 0

    private val threads = ConcurrentHashMap<Thread, ThreadStatus>()

    override fun currentTimeMillis(): Long = TimeUnit.NANOSECONDS.toMillis(time)

    override fun nanoTime(): Long = time

    /** Tracks [block] for its whole lifetime, so virtual jumps are suppressed until it has run. */
    override fun wrapTask(block: Runnable): Runnable {
        trackTask()
        return Runnable {
            try {
                block.run()
            } finally {
                unTrackTask()
            }
        }
    }

    @Synchronized
    override fun trackTask() {
        trackedTasks++
    }

    @Synchronized
    override fun unTrackTask() {
        assert(trackedTasks > 0)
        trackedTasks--
    }

    /** Registers the current thread (re-entrantly) so that its parks participate in time advancement. */
    @Synchronized
    override fun registerTimeLoopThread() {
        val status = threads.getOrPut(Thread.currentThread()) { ThreadStatus() }
        status.registered++
    }

    /** Undoes one [registerTimeLoopThread]; the last one removes the thread and wakes any waiter. */
    @Synchronized
    override fun unregisterTimeLoopThread() {
        val currentThread = Thread.currentThread()
        val status = checkNotNull(threads[currentThread]) { "Thread $currentThread is not registered" }
        if (--status.registered == 0) {
            threads.remove(currentThread)
            wakeupAll() // lets a pending shutdown() re-check whether all threads are gone
        }
    }

    /**
     * Parks the current (registered) thread for [nanos] virtual nanoseconds, capped at [MAX_WAIT_NANOS].
     * Returns early on [unpark] or [shutdown].
     */
    override fun parkNanos(blocker: Any, nanos: Long) {
        if (nanos <= 0) return
        val current = Thread.currentThread()
        val status = checkNotNull(threads[current]) { "Thread $current is not registered" }
        assert(status.parkedTill == NOT_PARKED)
        status.parkedTill = time + nanos.coerceAtMost(MAX_WAIT_NANOS)
        while (true) {
            checkAdvanceTime()
            if (isShutdown || time >= status.parkedTill || status.permit) {
                status.parkedTill = NOT_PARKED
                status.permit = false
                break
            }
            // Short real park, so that real-time steps are noticed promptly.
            LockSupport.parkNanos(blocker, REAL_PARK_NANOS)
        }
    }

    override fun unpark(thread: Thread) {
        val status = threads[thread] ?: return
        status.permit = true
        LockSupport.unpark(thread)
    }

    @Synchronized
    private fun checkAdvanceTime() {
        if (isShutdown) return
        val realNanos = System.nanoTime()
        if (realNanos > checkpointNanos + REAL_TIME_STEP_NANOS) {
            checkpointNanos = realNanos
            val minParkedTill = minParkedTill()
            val limit = if (minParkedTill < 0) Long.MAX_VALUE else minParkedTill
            // Clamp from below as well: a thread's deadline may already lie in the past. Either it was computed
            // from an older `time`, or a woken thread has not yet cleared it. Capping at such a deadline alone
            // would move the clock backwards.
            time = (time + REAL_TIME_STEP_NANOS).coerceAtMost(limit).coerceAtLeast(time)
            logTime("R")
            wakeupAll()
            return
        }
        // A virtual jump requires the main thread to take part and no tracked task to be running.
        if (threads[mainThread] == null) return
        if (trackedTasks != 0) return
        val minParkedTill = minParkedTill()
        // NOT_PARKED (-1) means some thread is running or holds a permit; also never jump backwards.
        if (minParkedTill <= time) return
        time = minParkedTill
        logTime("V")
        wakeupAll()
    }

    private fun logTime(s: String) {
        log?.println("[$s: Time = ${TimeUnit.NANOSECONDS.toMillis(time)} ms]")
    }

    /**
     * Returns the earliest deadline among registered threads. A thread that is not parked, or that holds a
     * permit, counts as [NOT_PARKED], which is smaller than any deadline. Returns [NOT_PARKED] when no
     * threads are registered.
     */
    private fun minParkedTill(): Long {
        var min: Long? = null
        for (status in threads.values) {
            val till = if (status.permit) NOT_PARKED else status.parkedTill
            if (min == null || till < min) min = till
        }
        return min ?: NOT_PARKED
    }

    /** Wakes every parked thread and blocks until all registered threads have unregistered. */
    @Synchronized
    fun shutdown() {
        isShutdown = true
        wakeupAll()
        while (threads.isNotEmpty()) (this as Object).wait()
    }

    /** Unparks all registered threads and notifies monitor waiters; callers must hold this object's monitor. */
    private fun wakeupAll() {
        threads.keys.forEach { LockSupport.unpark(it) }
        (this as Object).notifyAll()
    }
}