1 /*
2  * Copyright (C) 2023 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 import java.lang.invoke.MethodHandles;
18 import java.lang.invoke.VarHandle;
19 import java.lang.reflect.Field;
20 import java.util.concurrent.atomic.AtomicReference;
21 import sun.misc.Unsafe;
22 
23 class Main {
main(String args[])24     public static void main(String args[]) throws Exception {
25         // Stress-test read-modify-write operations in adjacent memory locations.
26         // This is intended to uncover bugs triggered by spurious CAS failures on
27         // architectures where such spurious failures can happen. Bug: 218453177
28         $noinline$testVarHandleBytes();
29         $noinline$testVarHandleInts();
30         $noinline$testVarHandleLongs();
31         $noinline$testVarHandleReferences();
32         $noinline$testUnsafeInts();
33         $noinline$testUnsafeLongs();
34         $noinline$testUnsafeReferences();
35 
36         // Stress-test read-modify-write operations on the same memory locations.
37         // This is intended to uncover bugs with false-positive comparison in CAS.
38         $noinline$testAtomicReference();
39     }
40 
$noinline$testVarHandleBytes()41     public static void $noinline$testVarHandleBytes() throws Exception {
42         // Prepare `VarHandle` objects.
43         VarHandle[] vhs = new VarHandle[] {
44                 MethodHandles.lookup().findVarHandle(FourBytes.class, "b1", byte.class),
45                 MethodHandles.lookup().findVarHandle(FourBytes.class, "b2", byte.class),
46                 MethodHandles.lookup().findVarHandle(FourBytes.class, "b3", byte.class),
47                 MethodHandles.lookup().findVarHandle(FourBytes.class, "b4", byte.class)
48         };
49         // Prepare threads.
50         final FourBytes fourBytes = new FourBytes();
51         final StopFlag stopFlag = new StopFlag();
52         Thread[] threads = new Thread[4];
53         for (int i = 0; i != 4; ++i) {
54             final VarHandle vh = vhs[i];
55             threads[i] = new Thread() {
56                 public void run() {
57                     byte value = 0;
58                     while (!stopFlag.stop) {
59                         byte nextValue = (byte) (value + 1);
60                         boolean success = vh.compareAndSet(fourBytes, value, nextValue);
61                         assertTrue(success);
62                         value = nextValue;
63                     }
64                 }
65             };
66         }
67         // Start threads.
68         for (int i = 0; i != 4; ++i) {
69             threads[i].start();
70         }
71         // Let the threads run for 5s.
72         Thread.sleep(5000);
73         // Stop threads.
74         stopFlag.stop = true;
75         for (int i = 0; i != 4; ++i) {
76             threads[i].join();
77         }
78     }
79 
$noinline$testVarHandleInts()80     public static void $noinline$testVarHandleInts() throws Exception {
81         // Prepare `VarHandle` objects.
82         VarHandle[] vhs = new VarHandle[] {
83                 MethodHandles.lookup().findVarHandle(FourInts.class, "i1", int.class),
84                 MethodHandles.lookup().findVarHandle(FourInts.class, "i2", int.class),
85                 MethodHandles.lookup().findVarHandle(FourInts.class, "i3", int.class),
86                 MethodHandles.lookup().findVarHandle(FourInts.class, "i4", int.class)
87         };
88         // Prepare threads.
89         final FourInts fourInts = new FourInts();
90         final StopFlag stopFlag = new StopFlag();
91         Thread[] threads = new Thread[4];
92         for (int i = 0; i != 4; ++i) {
93             final VarHandle vh = vhs[i];
94             threads[i] = new Thread() {
95                 public void run() {
96                     int value = 0;
97                     while (!stopFlag.stop) {
98                         int nextValue = value + 1;
99                         boolean success = vh.compareAndSet(fourInts, value, nextValue);
100                         assertTrue(success);
101                         value = nextValue;
102                     }
103                 }
104             };
105         }
106         // Start threads.
107         for (int i = 0; i != 4; ++i) {
108             threads[i].start();
109         }
110         // Let the threads run for 5s.
111         Thread.sleep(5000);
112         // Stop threads.
113         stopFlag.stop = true;
114         for (int i = 0; i != 4; ++i) {
115             threads[i].join();
116         }
117     }
118 
$noinline$testVarHandleLongs()119     public static void $noinline$testVarHandleLongs() throws Exception {
120         // Prepare `VarHandle` objects.
121         VarHandle[] vhs = new VarHandle[] {
122                 MethodHandles.lookup().findVarHandle(FourLongs.class, "l1", long.class),
123                 MethodHandles.lookup().findVarHandle(FourLongs.class, "l2", long.class),
124                 MethodHandles.lookup().findVarHandle(FourLongs.class, "l3", long.class),
125                 MethodHandles.lookup().findVarHandle(FourLongs.class, "l4", long.class)
126         };
127         // Prepare threads.
128         final FourLongs fourLongs = new FourLongs();
129         final StopFlag stopFlag = new StopFlag();
130         Thread[] threads = new Thread[4];
131         for (int i = 0; i != 4; ++i) {
132             final VarHandle vh = vhs[i];
133             threads[i] = new Thread() {
134                 public void run() {
135                     long value = 0;
136                     while (!stopFlag.stop) {
137                         long nextValue = value + 1L;
138                         boolean success = vh.compareAndSet(fourLongs, value, nextValue);
139                         assertTrue(success);
140                         value = nextValue;
141                     }
142                 }
143             };
144         }
145         // Start threads.
146         for (int i = 0; i != 4; ++i) {
147             threads[i].start();
148         }
149         // Let the threads run for 5s.
150         Thread.sleep(5000);
151         // Stop threads.
152         stopFlag.stop = true;
153         for (int i = 0; i != 4; ++i) {
154             threads[i].join();
155         }
156     }
157 
$noinline$testVarHandleReferences()158     public static void $noinline$testVarHandleReferences() throws Exception {
159         // Prepare `VarHandle` objects.
160         VarHandle[] vhs = new VarHandle[] {
161                 MethodHandles.lookup().findVarHandle(FourReferences.class, "r1", Object.class),
162                 MethodHandles.lookup().findVarHandle(FourReferences.class, "r2", Object.class),
163                 MethodHandles.lookup().findVarHandle(FourReferences.class, "r3", Object.class),
164                 MethodHandles.lookup().findVarHandle(FourReferences.class, "r4", Object.class)
165         };
166         // Prepare threads.
167         final FourReferences fourReferences = new FourReferences();
168         Object[] values = new Object[] {
169                 null,
170                 new Object(),
171                 new Object(),
172                 new Object()
173         };
174         final StopFlag stopFlag = new StopFlag();
175         Thread[] threads = new Thread[4];
176         for (int i = 0; i != 4; ++i) {
177             final VarHandle vh = vhs[i];
178             threads[i] = new Thread() {
179                 public void run() {
180                     int index = 0;
181                     while (!stopFlag.stop) {
182                         Object value = values[index];
183                         index = (index + 1) & 3;
184                         Object nextValue = values[index];
185                         boolean success = vh.compareAndSet(fourReferences, value, nextValue);
186                         assertTrue(success);
187                     }
188                 }
189             };
190         }
191         // Start threads.
192         for (int i = 0; i != 4; ++i) {
193             threads[i].start();
194         }
195         // Allocate memory to trigger some GCs
196         for (int i = 0; i != 640 * 1024; ++i) {
197             $noinline$allocateAtLeast1KiB();
198         }
199         // Stop threads.
200         stopFlag.stop = true;
201         for (int i = 0; i != 4; ++i) {
202             threads[i].join();
203         }
204     }
205 
$noinline$testUnsafeInts()206     public static void $noinline$testUnsafeInts() throws Exception {
207         // Prepare Unsafe offsets.
208         final Unsafe unsafe = getUnsafe();
209         long[] offsets = new long[] {
210                 unsafe.objectFieldOffset(FourInts.class.getField("i1")),
211                 unsafe.objectFieldOffset(FourInts.class.getField("i2")),
212                 unsafe.objectFieldOffset(FourInts.class.getField("i3")),
213                 unsafe.objectFieldOffset(FourInts.class.getField("i4"))
214         };
215         // Prepare threads.
216         final FourInts fourInts = new FourInts();
217         final StopFlag stopFlag = new StopFlag();
218         Thread[] threads = new Thread[4];
219         for (int i = 0; i != 4; ++i) {
220             final long offset = offsets[i];
221             threads[i] = new Thread() {
222                 public void run() {
223                     int value = 0;
224                     while (!stopFlag.stop) {
225                         int nextValue = value + 1;
226                         boolean success = unsafe.compareAndSwapInt(
227                                 fourInts, offset, value, nextValue);
228                         assertTrue(success);
229                         value = nextValue;
230                     }
231                 }
232             };
233         }
234         // Start threads.
235         for (int i = 0; i != 4; ++i) {
236             threads[i].start();
237         }
238         // Let the threads run for 5s.
239         Thread.sleep(5000);
240         // Stop threads.
241         stopFlag.stop = true;
242         for (int i = 0; i != 4; ++i) {
243             threads[i].join();
244         }
245     }
246 
$noinline$testUnsafeLongs()247     public static void $noinline$testUnsafeLongs() throws Exception {
248         // Prepare Unsafe offsets.
249         final Unsafe unsafe = getUnsafe();
250         long[] offsets = new long[] {
251                 unsafe.objectFieldOffset(FourLongs.class.getField("l1")),
252                 unsafe.objectFieldOffset(FourLongs.class.getField("l2")),
253                 unsafe.objectFieldOffset(FourLongs.class.getField("l3")),
254                 unsafe.objectFieldOffset(FourLongs.class.getField("l4"))
255         };
256         // Prepare threads.
257         final FourLongs fourLongs = new FourLongs();
258         final StopFlag stopFlag = new StopFlag();
259         Thread[] threads = new Thread[4];
260         for (int i = 0; i != 4; ++i) {
261             final long offset = offsets[i];
262             threads[i] = new Thread() {
263                 public void run() {
264                     long value = 0;
265                     while (!stopFlag.stop) {
266                         long nextValue = value + 1L;
267                         boolean success = unsafe.compareAndSwapLong(
268                                 fourLongs, offset, value, nextValue);
269                         assertTrue(success);
270                         value = nextValue;
271                     }
272                 }
273             };
274         }
275         // Start threads.
276         for (int i = 0; i != 4; ++i) {
277             threads[i].start();
278         }
279         // Let the threads run for 5s.
280         Thread.sleep(5000);
281         // Stop threads.
282         stopFlag.stop = true;
283         for (int i = 0; i != 4; ++i) {
284             threads[i].join();
285         }
286     }
287 
$noinline$testUnsafeReferences()288     public static void $noinline$testUnsafeReferences() throws Exception {
289         // Prepare Unsafe offsets.
290         // D8 rewrites the bytecode with a workaround for CAS bug. To test the raw
291         // `Unsafe.compareAndSwapObject()` call, we implement the call in smali
292         // and wrap it in an indirect call.
293         final UnsafeDispatch unsafeDispatch =
294                 (UnsafeDispatch) Class.forName("UnsafeWrapper").newInstance();
295         final Unsafe unsafe = getUnsafe();
296         long[] offsets = new long[] {
297                 unsafe.objectFieldOffset(FourReferences.class.getField("r1")),
298                 unsafe.objectFieldOffset(FourReferences.class.getField("r2")),
299                 unsafe.objectFieldOffset(FourReferences.class.getField("r3")),
300                 unsafe.objectFieldOffset(FourReferences.class.getField("r4"))
301         };
302         // Prepare threads.
303         final FourReferences fourReferences = new FourReferences();
304         Object[] values = new Object[] {
305                 null,
306                 new Object(),
307                 new Object(),
308                 new Object()
309         };
310         final StopFlag stopFlag = new StopFlag();
311         Thread[] threads = new Thread[4];
312         for (int i = 0; i != 4; ++i) {
313             final long offset = offsets[i];
314             threads[i] = new Thread() {
315                 public void run() {
316                     int index = 0;
317                     while (!stopFlag.stop) {
318                         Object value = values[index];
319                         index = (index + 1) & 3;
320                         Object nextValue = values[index];
321                         boolean success = unsafeDispatch.compareAndSwapObject(
322                                 unsafe, fourReferences, offset, value, nextValue);
323                         assertTrue(success);
324                     }
325                 }
326             };
327         }
328         // Start threads.
329         for (int i = 0; i != 4; ++i) {
330             threads[i].start();
331         }
332         // Allocate memory to trigger some GCs
333         for (int i = 0; i != 640 * 1024; ++i) {
334             $noinline$allocateAtLeast1KiB();
335         }
336         // Stop threads.
337         stopFlag.stop = true;
338         for (int i = 0; i != 4; ++i) {
339             threads[i].join();
340         }
341     }
342 
343     // Instead of using a `VarHandle` directly, this test uses `AtomicReference` which is
344     // implemented using a `VarHandle`. This is because the normal `VarHandle` checks are
345     // done without read barrier which makes them likely to fail and take the slow-path to
346     // the runtime while the GC is marking (which is the case we're most interested in).
347     // The `AtomicReference` uses a boot-image `VarHandle` which is optimized to avoid
348     // those checks, making it more likely to hit bugs in the raw RMW operation.
$noinline$testAtomicReference()349     public static void $noinline$testAtomicReference() throws Exception {
350         // Prepare `AtomicReference` object.
351         // D8 rewrites the bytecode with a workaround for CAS bug. To test the raw
352         // `AtomicReference.compareAndSet()` call, we implement the call in smali
353         // and wrap it in an indirect call.
354         final AtomicReferenceDispatch atomicReferenceDispatch =
355                 (AtomicReferenceDispatch) Class.forName("AtomicReferenceWrapper").newInstance();
356         final AtomicReference aref = new AtomicReference(null);
357         // Prepare threads.
358         final Object[] objects = new Object[] {
359                 null,
360                 new Object(),
361                 new Object(),
362                 new Object()
363         };
364         final StopFlag stopFlag = new StopFlag();
365         Thread[] threads = new Thread[4];
366         for (int i = 0; i != 4; ++i) {
367             if (i == 0) {
368                 threads[i] = new Thread() {
369                     public void run() {
370                         int index = 0;
371                         Object value = objects[index];
372                         while (!stopFlag.stop) {
373                             index = (index + 1) & 3;
374                             Object nextValue = objects[index];
375                             boolean success = atomicReferenceDispatch.compareAndSet(
376                                     aref, value, nextValue);
377                             assertTrue(success);
378                             value = nextValue;
379                         }
380                     }
381                 };
382             } else {
383                 final Object value = objects[i];
384                 assertTrue(value != null);
385                 threads[i] = new Thread() {
386                     public void run() {
387                         // This thread is trying to overwrite a value with the same value.
388                         // For a false-positive in CAS compare, it would actually change
389                         // the value and cause the thread `threads[0]` to fail.
390                         assertTrue(value != null);
391                         while (!stopFlag.stop) {
392                             // Do not check the return value.
393                             atomicReferenceDispatch.compareAndSet(aref, value, value);
394                         }
395                     }
396                 };
397             }
398         };
399         // Start threads.
400         for (int i = 0; i != 4; ++i) {
401             threads[i].start();
402         }
403         // Allocate memory to trigger some GCs
404         for (int i = 0; i != 640 * 1024; ++i) {
405             $noinline$allocateAtLeast1KiB();
406         }
407         // Stop threads.
408         stopFlag.stop = true;
409         for (int i = 0; i != 4; ++i) {
410             threads[i].join();
411         }
412     }
413 
assertTrue(boolean value)414     public static void assertTrue(boolean value) {
415         if (!value) {
416             throw new Error("Assertion failed!");
417         }
418     }
419 
getUnsafe()420     public static Unsafe getUnsafe() throws Exception {
421         Class<?> unsafeClass = Class.forName("sun.misc.Unsafe");
422         Field f = unsafeClass.getDeclaredField("theUnsafe");
423         f.setAccessible(true);
424         return (Unsafe) f.get(null);
425     }
426 
$noinline$allocateAtLeast1KiB()427     public static void $noinline$allocateAtLeast1KiB() {
428         // Give GC more work by allocating Object arrays.
429         memory[allocationIndex] = new Object[1024 / 4];
430         ++allocationIndex;
431         if (allocationIndex == memory.length) {
432             allocationIndex = 0;
433         }
434     }
435 
436     // We shall retain some allocated memory and release old allocations
437     // so that the GC has something to do.
438     public static Object[] memory = new Object[1024];
439     public static int allocationIndex = 0;
440 }
441 
/**
 * Shared flag that the stress-test main thread raises to tell its worker threads to finish
 * their loops.
 */
class StopFlag {
    // Written by one thread and polled by others, hence `volatile`.
    public volatile boolean stop;

    StopFlag() {
        stop = false;
    }
}
445 
/**
 * Four public {@code byte} fields intended to sit in adjacent memory locations, so that
 * concurrent CAS operations on different fields stress neighbouring bytes.
 */
class FourBytes {
    public byte b1;
    public byte b2;
    public byte b3;
    public byte b4;

    FourBytes() {
        b1 = (byte) 0;
        b2 = (byte) 0;
        b3 = (byte) 0;
        b4 = (byte) 0;
    }
}
452 
/**
 * Four public {@code int} fields intended to sit in adjacent memory locations, so that
 * concurrent CAS operations on different fields stress neighbouring words.
 */
class FourInts {
    public int i1;
    public int i2;
    public int i3;
    public int i4;

    FourInts() {
        i1 = 0;
        i2 = 0;
        i3 = 0;
        i4 = 0;
    }
}
459 
/**
 * Four public {@code long} fields intended to sit in adjacent memory locations, so that
 * concurrent CAS operations on different fields stress neighbouring double-words.
 */
class FourLongs {
    public long l1;
    public long l2;
    public long l3;
    public long l4;

    FourLongs() {
        l1 = 0L;
        l2 = 0L;
        l3 = 0L;
        l4 = 0L;
    }
}
466 
/**
 * Four public {@code Object} fields intended to sit in adjacent memory locations, so that
 * concurrent reference CAS operations on different fields stress neighbouring slots.
 */
class FourReferences {
    public Object r1;
    public Object r2;
    public Object r3;
    public Object r4;

    FourReferences() {
        r1 = null;
        r2 = null;
        r3 = null;
        r4 = null;
    }
}
473 
/**
 * Indirection for {@code Unsafe.compareAndSwapObject()} calls. A concrete subclass is loaded
 * reflectively by name, so the call it makes is not part of the calling Java code.
 */
abstract class UnsafeDispatch {
    /**
     * Performs a compare-and-swap of the reference field at {@code offset} in {@code obj}.
     *
     * @param unsafe the {@code Unsafe} instance to perform the operation with
     * @param obj the object holding the field
     * @param offset the field offset, as from {@code Unsafe.objectFieldOffset}
     * @param expected the value the field is expected to hold
     * @param newValue the value to store if the field holds {@code expected}
     * @return whether the swap happened, as reported by the implementation
     */
    public abstract boolean compareAndSwapObject(
            Unsafe unsafe, Object obj, long offset, Object expected, Object newValue);
}
478 
/**
 * Indirection for {@code AtomicReference.compareAndSet()} calls. A concrete subclass is loaded
 * reflectively by name, so the call it makes is not part of the calling Java code.
 */
abstract class AtomicReferenceDispatch {
    /**
     * Performs a compare-and-set on {@code aref}.
     *
     * <p>The parameter type erases to the raw {@code AtomicReference}, so existing raw-typed
     * implementations still override this method.
     *
     * @param aref the reference to update
     * @param expected the value {@code aref} is expected to hold
     * @param newValue the value to store if {@code aref} holds {@code expected}
     * @return whether the update happened, as reported by the implementation
     */
    public abstract boolean compareAndSet(
            AtomicReference<Object> aref, Object expected, Object newValue);
}
482