1 /*
2  * Copyright 2016-2018 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license.
3  */
4 
5 package kotlinx.coroutines
6 
7 import org.junit.*
8 import org.junit.Test
9 import java.lang.IllegalStateException
10 import kotlin.test.*
11 
/**
 * Tests for `ThreadLocal.asContextElement` and the related `isPresent` / `ensurePresent` checks.
 *
 * Each test installs thread-local values into a coroutine context. It then checks which value a coroutine
 * observes as it moves between [Dispatchers.Default], the one-thread [executor] and nested coroutines.
 * It also checks that the calling thread's own value is left unaffected.
 */
@Suppress("RedundantAsync")
class ThreadLocalTest : TestBase() {
    private val stringThreadLocal = ThreadLocal<String?>()
    private val intThreadLocal = ThreadLocal<Int?>()

    // One-thread pool used as a second dispatcher, so coroutines hop between threads; closed in tearDown().
    private val executor = newFixedThreadPoolContext(1, "threadLocalTest")

    /** Mutable box used to check that in-place mutation of a context-installed value stays visible. */
    private data class Counter(var cnt: Int)

    private val myCounterLocal = ThreadLocal<Counter>()

    @After
    fun tearDown() {
        executor.close()
        // Several tests call set() directly on the thread that runs runTest. Clear those values so they do not
        // linger in that thread's ThreadLocal map. This assumes TestBase.runTest runs on the JUnit thread
        // (verify); if it does not, these calls are harmless no-ops.
        stringThreadLocal.remove()
        intThreadLocal.remove()
        myCounterLocal.remove()
    }

    /**
     * A value installed via `asContextElement` is visible and "present" inside the coroutine. It survives a hop
     * to [executor] and never shows up on the calling thread.
     */
    @Test
    fun testThreadLocal() = runTest {
        assertNull(stringThreadLocal.get())
        assertFalse(stringThreadLocal.isPresent())
        val deferred = async(Dispatchers.Default + stringThreadLocal.asContextElement("value")) {
            assertEquals("value", stringThreadLocal.get())
            assertTrue(stringThreadLocal.isPresent())
            withContext(executor) {
                assertTrue(stringThreadLocal.isPresent())
                // intThreadLocal has no element in this context, so ensurePresent() must fail.
                assertFailsWith<IllegalStateException> { intThreadLocal.ensurePresent() }
                assertEquals("value", stringThreadLocal.get())
            }
            assertTrue(stringThreadLocal.isPresent())
            assertEquals("value", stringThreadLocal.get())
        }

        assertNull(stringThreadLocal.get())
        deferred.await()
        assertNull(stringThreadLocal.get())
        assertFalse(stringThreadLocal.isPresent())
    }

    /**
     * A plain `set()` does not make a thread-local "present". The context element's value wins inside the
     * coroutine, and the caller's value (42) is unaffected afterwards.
     */
    @Test
    fun testThreadLocalInitialValue() = runTest {
        intThreadLocal.set(42)
        assertFalse(intThreadLocal.isPresent())
        val deferred = async(Dispatchers.Default + intThreadLocal.asContextElement(239)) {
            assertEquals(239, intThreadLocal.get())
            withContext(executor) {
                intThreadLocal.ensurePresent()
                assertEquals(239, intThreadLocal.get())
            }
            assertEquals(239, intThreadLocal.get())
        }

        deferred.await()
        assertEquals(42, intThreadLocal.get())
    }

    /** Two independent thread-local elements can coexist in one context and both follow the coroutine. */
    @Test
    fun testMultipleThreadLocals() = runTest {
        stringThreadLocal.set("test")
        intThreadLocal.set(314)

        val deferred = async(
            Dispatchers.Default +
                intThreadLocal.asContextElement(value = 239) +
                stringThreadLocal.asContextElement(value = "pew")
        ) {
            assertEquals(239, intThreadLocal.get())
            assertEquals("pew", stringThreadLocal.get())

            withContext(executor) {
                assertEquals(239, intThreadLocal.get())
                assertEquals("pew", stringThreadLocal.get())
                intThreadLocal.ensurePresent()
                stringThreadLocal.ensurePresent()
            }

            assertEquals(239, intThreadLocal.get())
            assertEquals("pew", stringThreadLocal.get())
        }

        deferred.await()
        assertEquals(314, intThreadLocal.get())
        assertEquals("test", stringThreadLocal.get())
    }

    /**
     * A nested element for the same thread-local overrides the outer one only within its own scope.
     * A GlobalScope coroutine does not inherit the element at all.
     */
    @Test
    fun testConflictingThreadLocals() = runTest {
        intThreadLocal.set(42)

        val deferred = GlobalScope.async(intThreadLocal.asContextElement(1)) {
            assertEquals(1, intThreadLocal.get())

            withContext(executor + intThreadLocal.asContextElement(42)) {
                assertEquals(42, intThreadLocal.get())
            }

            assertEquals(1, intThreadLocal.get())

            val child = async(intThreadLocal.asContextElement(53)) {
                assertEquals(53, intThreadLocal.get())
            }

            child.await()
            assertEquals(1, intThreadLocal.get())

            // GlobalScope does not inherit this coroutine's context, so no value is installed on the executor thread.
            val unrelated = GlobalScope.async(executor) {
                assertNull(intThreadLocal.get())
            }

            unrelated.await()
            assertEquals(1, intThreadLocal.get())
        }

        deferred.await()
        assertEquals(42, intThreadLocal.get())
    }

    /**
     * A direct `set()` inside the coroutine is not recorded in the context element. After the coroutine is
     * dispatched again, the thread-local reverts to the element's value ("initial").
     */
    @Test
    fun testThreadLocalModification() = runTest {
        stringThreadLocal.set("main")

        val deferred = async(Dispatchers.Default + stringThreadLocal.asContextElement("initial")) {
            assertEquals("initial", stringThreadLocal.get())

            // Not reflected in the context element, so this value is not carried across dispatches.
            stringThreadLocal.set("overridden")

            withContext(executor + stringThreadLocal.asContextElement("ctx")) {
                assertEquals("ctx", stringThreadLocal.get())
            }

            val child = async(stringThreadLocal.asContextElement("async")) {
                assertEquals("async", stringThreadLocal.get())
            }

            child.await()
            // The element's value is back in effect; "overridden" was lost.
            assertEquals("initial", stringThreadLocal.get())
        }

        deferred.await()
        assertFalse(stringThreadLocal.isPresent())
        assertEquals("main", stringThreadLocal.get())
    }

    /**
     * Mutating the object held by a context element (rather than replacing it) is visible afterwards. The same
     * [Counter] instance keeps being installed, while nested elements with their own instances do not touch it.
     */
    @Test
    fun testThreadLocalModificationMutableBox() = runTest {
        myCounterLocal.set(Counter(42))

        val deferred = async(Dispatchers.Default + myCounterLocal.asContextElement(Counter(0))) {
            assertEquals(0, myCounterLocal.get().cnt)

            // Mutate the boxed value in place.
            myCounterLocal.get().cnt = 71

            withContext(executor + myCounterLocal.asContextElement(Counter(-1))) {
                assertEquals(-1, myCounterLocal.get().cnt)
                ++myCounterLocal.get().cnt
            }

            val child = async(myCounterLocal.asContextElement(Counter(31))) {
                assertEquals(31, myCounterLocal.get().cnt)
                ++myCounterLocal.get().cnt
            }

            child.await()
            assertEquals(71, myCounterLocal.get().cnt)
        }

        deferred.await()
        assertEquals(42, myCounterLocal.get().cnt)
    }

    /**
     * Switching to a separate dispatcher with `async` or `withContext` installs whatever element the new context
     * carries. A GlobalScope coroutine without an element sees no value.
     */
    @Test
    fun testWithContext() = runTest {
        expect(1)
        newSingleThreadContext("withContext").use { dispatcher ->
            val data = 42
            GlobalScope.async(Dispatchers.Default + intThreadLocal.asContextElement(data)) {

                assertEquals(data, intThreadLocal.get())
                expect(2)

                GlobalScope.async(dispatcher + intThreadLocal.asContextElement(31)) {
                    assertEquals(31, intThreadLocal.get())
                    expect(3)
                }.await()

                withContext(dispatcher + intThreadLocal.asContextElement(2)) {
                    assertEquals(2, intThreadLocal.get())
                    expect(4)
                }

                GlobalScope.async(dispatcher) {
                    assertNull(intThreadLocal.get())
                    expect(5)
                }.await()

                expect(6)
            }.await()
        }

        finish(7)
    }

    /**
     * Without an element, a GlobalScope coroutine does not see the caller's value. The no-argument
     * `asContextElement()` carries the caller's current value (42) into the coroutine.
     */
    @Test
    fun testScope() = runTest {
        intThreadLocal.set(42)
        val mainThread = Thread.currentThread()
        GlobalScope.async {
            assertNull(intThreadLocal.get())
            assertNotSame(mainThread, Thread.currentThread())
        }.await()

        GlobalScope.async(intThreadLocal.asContextElement()) {
            assertEquals(42, intThreadLocal.get())
            assertNotSame(mainThread, Thread.currentThread())
        }.await()
    }

    /** `ensurePresent()` throws IllegalStateException when no context element exists for the thread-local. */
    @Test
    fun testMissingThreadLocal() = runTest {
        assertFailsWith<IllegalStateException> { stringThreadLocal.ensurePresent() }
        assertFailsWith<IllegalStateException> { intThreadLocal.ensurePresent() }
    }
}
234