1 /*
<lambda>null2  * Copyright 2017-2018 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license.
3  */
4 
5 @file:Suppress("RedundantVisibilityModifier")
6 
7 package kotlinx.atomicfu
8 
9 import java.util.*
10 import java.util.concurrent.atomic.*
11 import java.util.concurrent.locks.*
12 import kotlin.coroutines.*
13 import kotlin.coroutines.intrinsics.*
14 
// Every intercepted atomic step pauses the current test thread with probability 1/PAUSE_EVERY_N_STEPS.
private const val PAUSE_EVERY_N_STEPS = 1000
// A thread that has not completed an operation for this long is considered stalled;
// also the time budget for all test threads to shut down.
private const val STALL_LIMIT_MS = 15_000L // 15s
// Polling interval while waiting for test threads to terminate.
private const val SHUTDOWN_CHECK_MS = 10L // 10ms

// Status value once the test is done: no new pauses are started and no resumes are performed.
private const val STATUS_DONE = Int.MAX_VALUE

// Park for at most 1ms at a time, just in case an unpark signal is lost.
private const val MAX_PARK_NANOS = 1_000_000L
22 
/**
 * Environment for performing lock-freedom tests for lock-free data structures
 * that are written with [atomic] variables.
 *
 * Test threads are created with [testThread] and then run by [performTest]. Every intercepted
 * atomic update may randomly pause the current test thread (at most one thread is paused at a time).
 * A paused thread is resumed only after all other threads have made progress, so operations that
 * wait for the paused thread make the test stall and fail.
 *
 * @param name name of the test; printed in the log and used as a prefix of test thread names.
 * @param allowSuspendedThreads how many test threads may be suspended (inside a suspending operation)
 *   while still being counted as making progress.
 */
public open class LockFreedomTestEnvironment(
    private val name: String,
    private val allowSuspendedThreads: Int = 0
) {
    private val interceptor = Interceptor()
    private val threads = mutableListOf<TestThread>()
    private val performedOps = LongAdder()
    private val uncaughtException = AtomicReference<Throwable?>()
    private var started = false

    // Updated by test threads in resumeImpl() and read by the main thread in resumeStr(),
    // so it has to be volatile for the progress log to observe up-to-date values.
    @Volatile
    private var performedResumes = 0

    @Volatile
    private var completed = false
    private val onCompletion = mutableListOf<() -> Unit>()

    // Logs uncaught exceptions of test threads and remembers the first one; it stops the
    // main loop of performTest and is rethrown by shutdown.
    private val ueh = Thread.UncaughtExceptionHandler { t, e ->
        synchronized(System.out) {
            println("Uncaught exception in thread $t")
            e.printStackTrace(System.out)
            uncaughtException.compareAndSet(null, e)
        }
    }

    // status < 0             - bitwise inverse (inv) of the paused thread's index
    // status >= 0            - no. of performed resumes so far (==last epoch)
    // status == STATUS_DONE  - done working
    private val status = AtomicInteger()
    // number of threads that made progress since the current pause started
    private val globalPauseProgress = AtomicInteger()
    // threads currently waiting for their suspended operation to be resumed, in order of suspension
    private val suspendedThreads = ArrayList<TestThread>()

    @Volatile
    private var isActive = true

    // ---------- API ----------

    /**
     * Starts lock-freedom test for a given duration in seconds.
     *
     * [progress] is invoked once per second while the test runs (up to `seconds + 1` times)
     * and once more after the test has completed.
     *
     * Can be invoked at most once per instance, and only after at least `2 + allowSuspendedThreads`
     * test threads were created with [testThread].
     *
     * @throws IllegalStateException if a precondition above is violated, progress stalled,
     *   or some threads failed to shut down in time. An uncaught exception of a test thread
     *   is rethrown as is.
     */
    public fun performTest(seconds: Int, progress: () -> Unit = {}) {
        check(isActive) { "Can perform test at most once on this instance" }
        println("=== $name")
        val minThreads = 2 + allowSuspendedThreads
        check(threads.size >= minThreads) { "Must define at least $minThreads test threads" }
        lockAndSetInterceptor(interceptor)
        started = true
        var nextTime = System.currentTimeMillis()
        threads.forEach { thread ->
            thread.setUncaughtExceptionHandler(ueh)
            thread.lastOpTime = nextTime
            thread.start()
        }
        try {
            var second = 0
            while (uncaughtException.get() == null) {
                waitUntil(nextTime)
                println("--- $second: Performed ${performedOps.sum()} operations${resumeStr()}")
                progress()
                checkStalled()
                if (++second > seconds) break
                nextTime += 1000L
            }
        } finally {
            complete()
        }
        println("------ Done with ${performedOps.sum()} operations${resumeStr()}")
        progress()
    }

    /**
     * Finishes the test. It runs the [onCompletion] blocks, signals all threads to stop, and waits
     * up to [STALL_LIMIT_MS] for all non-paused threads to terminate. It always ends with [shutdown].
     *
     * @throws IllegalStateException if some threads had failed to shut down in time.
     */
    private fun complete() {
        val activeNonPausedThreads: MutableMap<TestThread, Array<StackTraceElement>> = mutableMapOf()
        val shutdownDeadline = System.currentTimeMillis() + STALL_LIMIT_MS
        try {
            completed = true
            // perform custom completion blocks. For testing of things like channels, these custom completion
            // blocks close all the channels, so that all suspended coroutines shall get resumed.
            onCompletion.forEach { it() }
            // signal shutdown to all threads (non-paused threads will terminate)
            isActive = false
            // wait for threads to terminate
            while (System.currentTimeMillis() < shutdownDeadline) {
                // Check all threads while shutting down:
                // All terminated threads are considered to make progress for the purpose of resuming stalled ones
                activeNonPausedThreads.clear()
                for (t in threads) {
                    when {
                        !t.isAlive -> t.makeProgress(getPausedEpoch()) // not alive - makes progress
                        t.index.inv() == status.get() -> {} // active, paused -- skip
                        else -> {
                            val stackTrace = t.stackTrace
                            if (t.isAlive) activeNonPausedThreads[t] = stackTrace
                        }
                    }
                }
                if (activeNonPausedThreads.isEmpty()) break
                checkStalled()
                Thread.sleep(SHUTDOWN_CHECK_MS)
            }
            activeNonPausedThreads.forEach { (t, stackTrace) ->
                println("=== $t had failed to shutdown in time")
                stackTrace.forEach { println("\tat $it") }
            }
        } finally {
            shutdown(shutdownDeadline)
        }
        // if no other exception was thrown & we had threads that did not shut down -- still fails
        if (activeNonPausedThreads.isNotEmpty()) error("Some threads had failed to shutdown in time")
    }

    /**
     * Final cleanup. It stops pausing, wakes up the paused thread, joins all threads until
     * [shutdownDeadline], aborts any remaining waits, and resets the interceptor.
     *
     * @throws Throwable the first uncaught exception of a test thread, if any.
     * @throws IllegalStateException if a thread is still alive afterwards.
     */
    private fun shutdown(shutdownDeadline: Long) {
        // Make sure all threads leave their run loop even when complete() failed before clearing
        // this flag (e.g. an onCompletion block threw); otherwise they would keep running forever.
        isActive = false
        // forcefully unpause paused threads to shut them down (if any left)
        val curStatus = status.getAndSet(STATUS_DONE)
        if (curStatus < 0) LockSupport.unpark(threads[curStatus.inv()])
        threads.forEach {
            val remaining = shutdownDeadline - System.currentTimeMillis()
            if (remaining > 0) it.join(remaining)
        }
        // abort waiting threads (if still any left)
        threads.forEach { it.abortWait() }
        // cleanup & be done
        unlockAndResetInterceptor(interceptor)
        uncaughtException.get()?.let { throw it }
        threads.find { it.isAlive }?.let { dumpThreadsError("A thread is still alive: $it")}
    }

    /** Fails the test when some thread has not completed an operation for [STALL_LIMIT_MS]. */
    private fun checkStalled() {
        val stallLimit = System.currentTimeMillis() - STALL_LIMIT_MS
        val stalled = threads.filter { it.lastOpTime < stallLimit }
        if (stalled.isNotEmpty()) dumpThreadsError("Progress stalled in threads ${stalled.map { it.name }}")
    }

    /** Suffix for progress log lines; empty until the first resume happened. */
    private fun resumeStr(): String {
        val resumes = performedResumes
        return if (resumes == 0) "" else " (pause/resumes $resumes)"
    }

    /** Sleeps until the wall clock reaches [nextTime] (in milliseconds). */
    private fun waitUntil(nextTime: Long) {
        while (true) {
            val curTime = System.currentTimeMillis()
            if (curTime >= nextTime) break
            Thread.sleep(nextTime - curTime)
        }
    }

    /** Prints the stack traces of all live test threads and fails with [message]. */
    private fun dumpThreadsError(message: String) : Nothing {
        val traces = threads.associate { it to it.stackTrace }
        println("!!! $message")
        println("=== Dumping live thread stack traces")
        for ((thread, trace) in traces) {
            if (trace.isEmpty()) continue
            println("Thread \"${thread.name}\" ${thread.state}")
            for (t in trace) println("\tat ${t.className}.${t.methodName}(${t.fileName}:${t.lineNumber})")
            println()
        }
        println("===")
        error(message)
    }

    /**
     * Returns true when test was completed.
     * Sets to true before calling [onCompletion] blocks.
     */
    public val isCompleted: Boolean get() = completed

    /**
     * Performs a given block of code on test's completion
     */
    public fun onCompletion(block: () -> Unit) {
        onCompletion += block
    }

    /**
     * Creates a new test thread in this environment that executes a given lock-free [operation]
     * in a loop while this environment [isActive]. Must be invoked before [performTest].
     */
    public fun testThread(name: String? = null, operation: suspend TestThread.() -> Unit): TestThread =
        TestThread(name, operation)

    /**
     * Test thread that repeatedly executes its lock-free operation while the environment is active.
     * Created with [testThread] before the test is started.
     */
    @Suppress("LeakingThis")
    public inner class TestThread internal constructor(
        name: String?,
        private val operation: suspend TestThread.() -> Unit
    ) : Thread(composeThreadName(name)) {
        internal val index: Int

        // time of the last completed operation, used to detect stalled threads
        internal @Volatile var lastOpTime = 0L
        // epoch this thread was (last) paused in; resumeImpl sets status to it on resume
        internal @Volatile var pausedEpoch = -1

        private val random = Random()

        // thread-local stuff
        private var operationEpoch = -1
        private var progressEpoch = -1
        private var sink = 0

        init {
            check(!started)
            index = threads.size
            threads += this
        }

        public override fun run() {
            while (isActive) {
                callOperation()
            }
        }

        /**
         * Use it to insert an arbitrary intermission between lock-free operations.
         * The current operation is counted as completed before [block], and a new one starts after it.
         */
        public inline fun <T> intermission(block: () -> T): T {
            afterLockFreeOperation()
            return try { block() }
                finally { beforeLockFreeOperation() }
        }

        /** Remembers the pause epoch at the start of an operation. */
        @PublishedApi
        internal fun beforeLockFreeOperation() {
            operationEpoch = getPausedEpoch()
        }

        /** Records progress for the epoch the operation started in and counts the operation. */
        @PublishedApi
        internal fun afterLockFreeOperation() {
            makeProgress(operationEpoch)
            lastOpTime = System.currentTimeMillis()
            performedOps.add(1)
        }

        /**
         * Records that this thread made progress during pause [epoch] (at most once per epoch).
         * When all threads but the paused one have made progress, the paused thread is resumed.
         */
        internal fun makeProgress(epoch: Int) {
            if (epoch <= progressEpoch) return
            progressEpoch = epoch
            val total = globalPauseProgress.incrementAndGet()
            if (total >= threads.size - 1) {
                check(total == threads.size - 1)
                check(globalPauseProgress.compareAndSet(threads.size - 1, 0))
                resumeImpl()
            }
        }

        /**
         * Inserts random spin wait between multiple lock-free operations in [operation].
         */
        public fun randomSpinWaitIntermission() {
            intermission {
                if (random.nextInt(100) < 95) return // be quick, no wait 95% of time
                do {
                    val x = random.nextInt(100)
                    repeat(x) { sink += it }
                } while (x >= 90)
            }
        }

        /** Pauses this thread with probability 1/[PAUSE_EVERY_N_STEPS]. */
        internal fun stepImpl() {
            if (random.nextInt(PAUSE_EVERY_N_STEPS) == 0) pauseImpl()
        }

        /**
         * Pauses this thread unless another thread is already paused or the test is done.
         * Parks until the status is changed by a resume or shutdown, re-checking at least
         * every [MAX_PARK_NANOS] nanoseconds.
         */
        internal fun pauseImpl() {
            while (true) {
                val curStatus = status.get()
                if (curStatus < 0 || curStatus == STATUS_DONE) return // some other thread paused or done
                pausedEpoch = curStatus + 1
                val newStatus = index.inv()
                if (status.compareAndSet(curStatus, newStatus)) {
                    while (status.get() == newStatus) LockSupport.parkNanos(MAX_PARK_NANOS) // wait
                    return
                }
            }
        }

        // ----- Lightweight support for suspending operations -----

        /**
         * Runs one invocation of [operation] as a coroutine. If it suspends, this thread waits for
         * its resumptions and runs them itself until the operation completes.
         */
        private fun callOperation() {
            beforeLockFreeOperation()
            beginRunningOperation()
            val result = operation.startCoroutineUninterceptedOrReturn(this, completion)
            when {
                result === Unit -> afterLockFreeOperation() // operation completed w/o suspension -- done
                result === COROUTINE_SUSPENDED -> waitUntilCompletion() // operation had suspended
                else -> error("Unexpected result of operation: $result")
            }
            try {
                doneRunningOperation()
            } catch(e: IllegalStateException) {
                throw IllegalStateException("${e.message}; original start result=$result", e)
            }
        }

        private var runningOperation = false
        private var result: Result<Any?>? = null
        private var continuation: Continuation<Any?>? = null

        /**
         * Resumes the suspended operation on this thread until it completes. A null continuation
         * is set by [completion] and means the operation is done.
         */
        private fun waitUntilCompletion() {
            try {
                while (true) {
                    afterLockFreeOperation()
                    val result: Result<Any?> = waitForResult()
                    val continuation = takeContinuation()
                    if (continuation == null) { // done
                        check(result.getOrThrow() === Unit)
                        return
                    }
                    removeSuspended(this)
                    beforeLockFreeOperation()
                    continuation.resumeWith(result)
                }
            } finally {
                removeSuspended(this)
            }
        }

        private fun beginRunningOperation() {
            runningOperation = true
            result = null
            continuation = null
        }

        @Synchronized
        private fun doneRunningOperation() {
            check(runningOperation) { "Should be running operation" }
            check(result == null && continuation == null) {
                "Callback invoked with result=$result, continuation=$continuation"
            }
            runningOperation = false
        }

        /** Hands a resumption over to this thread and wakes it up from [waitForResult]. */
        @Suppress("PLATFORM_CLASS_MAPPED_TO_KOTLIN")
        @Synchronized
        private fun resumeWith(result: Result<Any?>, continuation: Continuation<Any?>?) {
            check(runningOperation) { "Should be running operation" }
            check(this.result == null && this.continuation == null) {
                "Resumed again with result=$result, continuation=$continuation, when this: result=${this.result}, continuation=${this.continuation}"
            }
            this.result = result
            this.continuation = continuation
            (this as Object).notifyAll()
        }

        /**
         * Blocks until a result is available. While waiting, this thread is registered as suspended;
         * one of the first [allowSuspendedThreads] suspended threads is counted as making progress.
         */
        @Suppress("RESULT_CLASS_IN_RETURN_TYPE", "PLATFORM_CLASS_MAPPED_TO_KOTLIN")
        @Synchronized
        private fun waitForResult(): Result<Any?> {
            while (true) {
                val result = this.result
                if (result != null) return result
                val index = addSuspended(this)
                if (index < allowSuspendedThreads) {
                    // This suspension was permitted, so assume progress is happening while it is suspended
                    makeProgress(getPausedEpoch())
                }
                (this as Object).wait(10) // at most 10 ms
            }
        }

        @Synchronized
        private fun takeContinuation(): Continuation<Any?>? =
            continuation.also {
                this.result = null
                this.continuation = null
            }

        /** Fails a pending wait (if any) so that the thread can finish at the end of the test. */
        @Suppress("PLATFORM_CLASS_MAPPED_TO_KOTLIN")
        @Synchronized
        fun abortWait() {
            this.result = Result.failure(IllegalStateException("Aborted at the end of test"))
            (this as Object).notifyAll()
        }

        // Routes every resumption of the operation's continuations to resumeWith, so that
        // they are run by this thread in waitUntilCompletion instead of the resuming thread.
        private val interceptor: CoroutineContext = object : AbstractCoroutineContextElement(ContinuationInterceptor), ContinuationInterceptor {
            override fun <T> interceptContinuation(continuation: Continuation<T>): Continuation<T> =
                Continuation<T>(this) {
                    @Suppress("UNCHECKED_CAST")
                    resumeWith(it, continuation as Continuation<Any?>)
                }
        }

        // Completion of the operation: stores its final result without a continuation.
        private val completion = Continuation<Unit>(interceptor) {
            resumeWith(it, null)
        }
    }

    // ---------- Implementation ----------

    /** Registers [thread] as suspended (once) and returns its position among suspended threads. */
    @Synchronized
    private fun addSuspended(thread: TestThread): Int {
        val index = suspendedThreads.indexOf(thread)
        if (index >= 0) return index
        suspendedThreads.add(thread)
        return suspendedThreads.size - 1
    }

    @Synchronized
    private fun removeSuspended(thread: TestThread) {
        suspendedThreads.remove(thread)
    }

    /** Returns the epoch of the current pause, or -1 when no thread is paused. */
    private fun getPausedEpoch(): Int {
        while (true) {
            val curStatus = status.get()
            if (curStatus >= 0) return -1 // not paused
            val thread = threads[curStatus.inv()]
            val pausedEpoch = thread.pausedEpoch
            if (curStatus == status.get()) return pausedEpoch
        }
    }

    /** Invoked on every intercepted atomic operation; no-op outside of test threads. */
    internal fun step() {
        val thread = Thread.currentThread() as? TestThread ?: return
        thread.stepImpl()
    }

    /** Resumes the paused thread, moving the status to its paused epoch (unless the test is done). */
    private fun resumeImpl() {
        while (true) {
            val curStatus = status.get()
            if (curStatus == STATUS_DONE) return // done
            check(curStatus < 0)
            val thread = threads[curStatus.inv()]
            performedResumes = thread.pausedEpoch
            if (status.compareAndSet(curStatus, thread.pausedEpoch)) {
                LockSupport.unpark(thread)
                return
            }
        }
    }

    /** Thread name: "<test name>-<threadName>", or "<test name>-<1-based thread number>". */
    private fun composeThreadName(threadName: String?): String {
        if (threadName != null) return "$name-$threadName"
        return name + "-${threads.size + 1}"
    }

    private inner class Interceptor : AtomicOperationInterceptor() {
        override fun <T> beforeUpdate(ref: AtomicRef<T>) = step()
        override fun beforeUpdate(ref: AtomicInt) = step()
        override fun beforeUpdate(ref: AtomicLong) = step()
        override fun <T> afterSet(ref: AtomicRef<T>, newValue: T) = step()
        override fun afterSet(ref: AtomicInt, newValue: Int) = step()
        override fun afterSet(ref: AtomicLong, newValue: Long) = step()
        override fun <T> afterRMW(ref: AtomicRef<T>, oldValue: T, newValue: T) = step()
        override fun afterRMW(ref: AtomicInt, oldValue: Int, newValue: Int) = step()
        override fun afterRMW(ref: AtomicLong, oldValue: Long, newValue: Long) = step()
        override fun toString(): String = "LockFreedomTestEnvironment($name)"
    }
}
471 
/**
 * Manually requests a pause of the on-going lock-free operation right at the call site.
 * Handy for targeted debugging of specific places in code. The request is skipped when another
 * test thread is already paused or the test is done, and the call is a no-op when the current
 * thread is not a [LockFreedomTestEnvironment.TestThread].
 *
 * **Don't use it in production code.**
 */
public fun pauseLockFreeOp() {
    val testThread = Thread.currentThread() as? LockFreedomTestEnvironment.TestThread
    testThread?.pauseImpl()
}