1 /*
2  * Copyright 2016-2020 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license.
3  */
4 
5 package kotlinx.coroutines
6 
7 import kotlinx.coroutines.internal.*
8 import kotlin.coroutines.*
9 
10 /**
11  * Defines elements in [CoroutineContext] that are installed into thread context
12  * every time the coroutine with this element in the context is resumed on a thread.
13  *
14  * Implementations of this interface define a type [S] of the thread-local state that they need to store on
15  * resume of a coroutine and restore later on suspend. The infrastructure provides the corresponding storage.
16  *
17  * Example usage looks like this:
18  *
19  * ```
20  * // Appends "name" of a coroutine to a current thread name when coroutine is executed
21  * class CoroutineName(val name: String) : ThreadContextElement<String> {
22  *     // declare companion object for a key of this element in coroutine context
23  *     companion object Key : CoroutineContext.Key<CoroutineName>
24  *
25  *     // provide the key of the corresponding context element
26  *     override val key: CoroutineContext.Key<CoroutineName>
27  *         get() = Key
28  *
29  *     // this is invoked before coroutine is resumed on current thread
30  *     override fun updateThreadContext(context: CoroutineContext): String {
31  *         val previousName = Thread.currentThread().name
32  *         Thread.currentThread().name = "$previousName # $name"
33  *         return previousName
34  *     }
35  *
36  *     // this is invoked after coroutine has suspended on current thread
37  *     override fun restoreThreadContext(context: CoroutineContext, oldState: String) {
38  *         Thread.currentThread().name = oldState
39  *     }
40  * }
41  *
42  * // Usage
43  * launch(Dispatchers.Main + CoroutineName("Progress bar coroutine")) { ... }
44  * ```
45  *
46  * Every time this coroutine is resumed on a thread, UI thread name is updated to
47  * "UI thread original name # Progress bar coroutine" and the thread name is restored to the original one when
48  * this coroutine suspends.
49  *
50  * To use [ThreadLocal] variable within the coroutine use [ThreadLocal.asContextElement][asContextElement] function.
51  */
public interface ThreadContextElement<S> : CoroutineContext.Element {
    /**
     * Updates context of the current thread.
     * This function is invoked before the coroutine in the specified [context] is resumed in the current thread
     * when the context of the coroutine contains this element.
     * The result of this function is the old value of the thread-local state that will be passed to [restoreThreadContext].
     * This method should handle its own exceptions and not rethrow them. Thrown exceptions will leave the coroutine
     * whose context is being updated in an undefined state and may crash the application.
     *
     * @param context the coroutine context.
     * @return the old value of the thread-local state, later passed to [restoreThreadContext] as `oldState`.
     */
    public fun updateThreadContext(context: CoroutineContext): S

    /**
     * Restores context of the current thread.
     * This function is invoked after the coroutine in the specified [context] is suspended in the current thread
     * if [updateThreadContext] was previously invoked on resume of this coroutine.
     * The value of [oldState] is the result of the previous invocation of [updateThreadContext] and it should
     * be restored in the thread-local state by this function.
     * This method should handle its own exceptions and not rethrow them. Thrown exceptions will leave the coroutine
     * whose context is being updated in an undefined state and may crash the application.
     *
     * @param context the coroutine context.
     * @param oldState the value returned by the previous invocation of [updateThreadContext].
     */
    public fun restoreThreadContext(context: CoroutineContext, oldState: S)
}
79 
80 /**
81  * Wraps [ThreadLocal] into [ThreadContextElement]. The resulting [ThreadContextElement]
 * maintains the given [value] of the given [ThreadLocal] for coroutine regardless of the actual thread it is resumed on.
83  * By default [ThreadLocal.get] is used as a value for the thread-local variable, but it can be overridden with [value] parameter.
84  * Beware that context element **does not track** modifications of the thread-local and accessing thread-local from coroutine
85  * without the corresponding context element returns **undefined** value. See the examples for a detailed description.
86  *
87  *
88  * Example usage:
89  * ```
90  * val myThreadLocal = ThreadLocal<String?>()
91  * ...
92  * println(myThreadLocal.get()) // Prints "null"
93  * launch(Dispatchers.Default + myThreadLocal.asContextElement(value = "foo")) {
94  *   println(myThreadLocal.get()) // Prints "foo"
95  *   withContext(Dispatchers.Main) {
96  *     println(myThreadLocal.get()) // Prints "foo", but it's on UI thread
97  *   }
98  * }
99  * println(myThreadLocal.get()) // Prints "null"
100  * ```
101  *
102  * The context element does not track modifications of the thread-local variable, for example:
103  *
104  * ```
105  * myThreadLocal.set("main")
106  * withContext(Dispatchers.Main) {
107  *   println(myThreadLocal.get()) // Prints "main"
108  *   myThreadLocal.set("UI")
109  * }
110  * println(myThreadLocal.get()) // Prints "main", not "UI"
111  * ```
112  *
113  * Use `withContext` to update the corresponding thread-local variable to a different value, for example:
114  * ```
115  * withContext(myThreadLocal.asContextElement("foo")) {
116  *     println(myThreadLocal.get()) // Prints "foo"
117  * }
118  * ```
119  *
120  * Accessing the thread-local without corresponding context element leads to undefined value:
121  * ```
122  * val tl = ThreadLocal.withInitial { "initial" }
123  *
124  * runBlocking {
125  *   println(tl.get()) // Will print "initial"
126  *   // Change context
127  *   withContext(tl.asContextElement("modified")) {
128  *     println(tl.get()) // Will print "modified"
129  *   }
130  *   // Context is changed again
131  *    println(tl.get()) // <- WARN: can print either "modified" or "initial"
132  * }
133  * ```
134  * to fix this behaviour use `runBlocking(tl.asContextElement())`
135  */
public fun <T> ThreadLocal<T>.asContextElement(value: T = get()): ThreadContextElement<T> {
    // The receiver thread-local is paired with the value the element should carry for it.
    val threadLocal = this
    return ThreadLocalElement(value, threadLocal)
}
138 
139 /**
140  * Return `true` when current thread local is present in the coroutine context, `false` otherwise.
141  * Thread local can be present in the context only if it was added via [asContextElement] to the context.
142  *
143  * Example of usage:
144  * ```
145  * suspend fun processRequest() {
146  *   if (traceCurrentRequestThreadLocal.isPresent()) { // Probabilistic tracing
147  *      // Do some heavy-weight tracing
148  *   }
149  *   // Process request regularly
150  * }
151  * ```
152  */
public suspend inline fun ThreadLocal<*>.isPresent(): Boolean {
    // Look up the element keyed by this thread-local in the current coroutine context.
    val element = coroutineContext[ThreadLocalKey(this)]
    return element !== null
}
154 
155 /**
156  * Checks whether current thread local is present in the coroutine context and throws [IllegalStateException] if it is not.
157  * It is a good practice to validate that thread local is present in the context, especially in large code-bases,
 * to avoid stale thread-local values and to have strict invariants.
159  *
160  * E.g. one may use the following method to enforce proper use of the thread locals with coroutines:
161  * ```
162  * public suspend inline fun <T> ThreadLocal<T>.getSafely(): T {
163  *   ensurePresent()
164  *   return get()
165  * }
166  *
167  * // Usage
168  * withContext(...) {
169  *   val value = threadLocal.getSafely() // Fail-fast in case of improper context
170  * }
171  * ```
172  */
public suspend inline fun ThreadLocal<*>.ensurePresent() {
    // Fail fast with IllegalStateException when this thread-local has no element in the context.
    if (!isPresent()) {
        error("ThreadLocal $this is missing from context $coroutineContext")
    }
}
175