1 package com.google.inject.internal;
2 
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ListMultimap;
import com.google.common.collect.Multimaps;
import com.google.inject.internal.CycleDetectingLock.CycleDetectingLockFactory;
import com.google.inject.internal.CycleDetectingLock.CycleDetectingLockFactory.ReentrantCycleDetectingLock;
import java.util.Collection;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.CyclicBarrier;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.FutureTask;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.locks.ReentrantLock;
import junit.framework.TestCase;
18 
19 public class CycleDetectingLockTest extends TestCase {
20 
21   static final long DEADLOCK_TIMEOUT_SECONDS = 1;
22 
23   /**
24    * Verifies that graph of threads' dependencies is not static and is calculated in runtime using
25    * information about specific locks.
26    *
27    * <pre>
28    *   T1: Waits on S1
29    *   T2: Locks B, sends S1, waits on S2
30    *   T1: Locks A, start locking B, sends S2, waits on S3
31    *   T2: Unlocks B, start locking A, sends S3, finishes locking A, unlocks A
32    *   T1: Finishes locking B, unlocks B, unlocks A
33    * </pre>
34    *
35    * <p>This should succeed, even though T1 was locked on T2 and T2 is locked on T1 when T2 locks A.
36    * Incorrect implementation detects a cycle waiting on S3.
37    */
38 
testSingletonThreadsRuntimeCircularDependency()39   public void testSingletonThreadsRuntimeCircularDependency() throws Exception {
40     final CyclicBarrier signal1 = new CyclicBarrier(2);
41     final CyclicBarrier signal2 = new CyclicBarrier(2);
42     final CyclicBarrier signal3 = new CyclicBarrier(2);
43     final CycleDetectingLockFactory<String> lockFactory = new CycleDetectingLockFactory<>();
44     final CycleDetectingLock<String> lockA =
45         new ReentrantCycleDetectingLock<String>(
46             lockFactory,
47             "A",
48             new ReentrantLock() {
49               @Override
50               public void lock() {
51                 if (Thread.currentThread().getName().equals("T2")) {
52                   try {
53                     signal3.await(DEADLOCK_TIMEOUT_SECONDS, TimeUnit.SECONDS);
54                   } catch (Exception e) {
55                     throw new RuntimeException(e);
56                   }
57                 } else {
58                   assertEquals("T1", Thread.currentThread().getName());
59                 }
60                 super.lock();
61               }
62             });
63     final CycleDetectingLock<String> lockB =
64         new ReentrantCycleDetectingLock<String>(
65             lockFactory,
66             "B",
67             new ReentrantLock() {
68               @Override
69               public void lock() {
70                 if (Thread.currentThread().getName().equals("T1")) {
71                   try {
72                     signal2.await(DEADLOCK_TIMEOUT_SECONDS, TimeUnit.SECONDS);
73                     signal3.await(DEADLOCK_TIMEOUT_SECONDS, TimeUnit.SECONDS);
74                   } catch (Exception e) {
75                     throw new RuntimeException(e);
76                   }
77                 } else {
78                   assertEquals("T2", Thread.currentThread().getName());
79                 }
80                 super.lock();
81               }
82             });
83     Future<Void> firstThreadResult =
84         Executors.newSingleThreadExecutor()
85             .submit(
86                 new Callable<Void>() {
87                   @Override
88                   public Void call() throws Exception {
89                     Thread.currentThread().setName("T1");
90                     signal1.await(DEADLOCK_TIMEOUT_SECONDS, TimeUnit.SECONDS);
91                     assertTrue(lockA.lockOrDetectPotentialLocksCycle().isEmpty());
92                     assertTrue(lockB.lockOrDetectPotentialLocksCycle().isEmpty());
93                     lockB.unlock();
94                     lockA.unlock();
95                     return null;
96                   }
97                 });
98     Future<Void> secondThreadResult =
99         Executors.newSingleThreadExecutor()
100             .submit(
101                 new Callable<Void>() {
102                   @Override
103                   public Void call() throws Exception {
104                     Thread.currentThread().setName("T2");
105                     assertTrue(lockB.lockOrDetectPotentialLocksCycle().isEmpty());
106                     signal1.await(DEADLOCK_TIMEOUT_SECONDS, TimeUnit.SECONDS);
107                     signal2.await(DEADLOCK_TIMEOUT_SECONDS, TimeUnit.SECONDS);
108                     lockB.unlock();
109                     assertTrue(lockA.lockOrDetectPotentialLocksCycle().isEmpty());
110                     lockA.unlock();
111                     return null;
112                   }
113                 });
114 
115     firstThreadResult.get(DEADLOCK_TIMEOUT_SECONDS * 3, TimeUnit.SECONDS);
116     secondThreadResult.get(DEADLOCK_TIMEOUT_SECONDS * 3, TimeUnit.SECONDS);
117   }
118 
119   /**
120    * Verifies that factories do not deadlock each other.
121    *
122    * <pre>
123    *   Thread A: lock a lock A (factory A)
124    *   Thread B: lock a lock B (factory B)
125    *   Thread A: lock a lock B (factory B)
126    *   Thread B: lock a lock A (factory A)
127    * </pre>
128    *
129    * <p>This should succeed even though from the point of view of each individual factory there are
130    * no deadlocks to detect.
131    */
132 
testCycleDetectingLockFactoriesDoNotDeadlock()133   public void testCycleDetectingLockFactoriesDoNotDeadlock() throws Exception {
134     final CycleDetectingLockFactory<String> factoryA = new CycleDetectingLockFactory<>();
135     final CycleDetectingLock<String> lockA = factoryA.create("A");
136     final CycleDetectingLockFactory<String> factoryB = new CycleDetectingLockFactory<>();
137     final CycleDetectingLock<String> lockB = factoryB.create("B");
138     final CyclicBarrier eachThreadAcquiredFirstLock = new CyclicBarrier(2);
139     Future<Boolean> threadA =
140         Executors.newSingleThreadExecutor()
141             .submit(
142                 new Callable<Boolean>() {
143                   @Override
144                   public Boolean call() throws Exception {
145                     Thread.currentThread().setName("A");
146                     assertTrue(lockA.lockOrDetectPotentialLocksCycle().isEmpty());
147                     eachThreadAcquiredFirstLock.await(DEADLOCK_TIMEOUT_SECONDS, TimeUnit.SECONDS);
148                     boolean isEmpty = lockB.lockOrDetectPotentialLocksCycle().isEmpty();
149                     if (isEmpty) {
150                       lockB.unlock();
151                     }
152                     lockA.unlock();
153                     return isEmpty;
154                   }
155                 });
156     Future<Boolean> threadB =
157         Executors.newSingleThreadExecutor()
158             .submit(
159                 new Callable<Boolean>() {
160                   @Override
161                   public Boolean call() throws Exception {
162                     Thread.currentThread().setName("B");
163                     assertTrue(lockB.lockOrDetectPotentialLocksCycle().isEmpty());
164                     eachThreadAcquiredFirstLock.await(DEADLOCK_TIMEOUT_SECONDS, TimeUnit.SECONDS);
165                     boolean isEmpty = lockA.lockOrDetectPotentialLocksCycle().isEmpty();
166                     if (isEmpty) {
167                       lockA.unlock();
168                     }
169                     lockB.unlock();
170                     return isEmpty;
171                   }
172                 });
173 
174     boolean deadlockADetected = threadA.get(DEADLOCK_TIMEOUT_SECONDS * 2, TimeUnit.SECONDS);
175     boolean deadlockBDetected = threadB.get(DEADLOCK_TIMEOUT_SECONDS * 2, TimeUnit.SECONDS);
176 
177     assertTrue("Deadlock should get detected", deadlockADetected || deadlockBDetected);
178     assertTrue("One deadlock should get detected", deadlockADetected != deadlockBDetected);
179   }
180 
181   /**
182    * Verifies that factories deadlocks report the correct cycles.
183    *
184    * <pre>
185    *   Thread 1: takes locks a, b
186    *   Thread 2: takes locks b, c
187    *   Thread 3: takes locks c, a
188    * </pre>
189    *
190    * <p>In order to ensure a deadlock, each thread will wait on a barrier right after grabbing the
191    * first lock.
192    */
193 
testCycleReporting()194   public void testCycleReporting() throws Exception {
195     final CycleDetectingLockFactory<String> factory = new CycleDetectingLockFactory<>();
196     final CycleDetectingLock<String> lockA = factory.create("a");
197     final CycleDetectingLock<String> lockB = factory.create("b");
198     final CycleDetectingLock<String> lockC = factory.create("c");
199     final CyclicBarrier barrier = new CyclicBarrier(3);
200     ImmutableList<Future<ListMultimap<Thread, String>>> futures =
201         ImmutableList.of(
202             grabLocksInThread(lockA, lockB, barrier),
203             grabLocksInThread(lockB, lockC, barrier),
204             grabLocksInThread(lockC, lockA, barrier));
205 
206     // At least one of the threads will report a lock cycle, it is possible that they all will, but
207     // there is no guarantee, so we just scan for the first thread that reported a cycle
208     ListMultimap<Thread, String> cycle = null;
209     for (Future<ListMultimap<Thread, String>> future : futures) {
210       ListMultimap<Thread, String> value =
211           future.get(DEADLOCK_TIMEOUT_SECONDS * 3, TimeUnit.SECONDS);
212       if (!value.isEmpty()) {
213         cycle = value;
214         break;
215       }
216     }
217     // We don't really care about the keys in the multimap, but we want to make sure that all locks
218     // were reported in the right order.
219     assertEquals(6, cycle.size());
220     Collection<List<String>> edges = Multimaps.asMap(cycle).values();
221     assertTrue(edges.contains(ImmutableList.of("a", "b")));
222     assertTrue(edges.contains(ImmutableList.of("b", "c")));
223     assertTrue(edges.contains(ImmutableList.of("c", "a")));
224   }
225 
grabLocksInThread( final CycleDetectingLock<T> lock1, final CycleDetectingLock<T> lock2, final CyclicBarrier barrier)226   private static <T> Future<ListMultimap<Thread, T>> grabLocksInThread(
227       final CycleDetectingLock<T> lock1,
228       final CycleDetectingLock<T> lock2,
229       final CyclicBarrier barrier) {
230     FutureTask<ListMultimap<Thread, T>> future =
231         new FutureTask<ListMultimap<Thread, T>>(
232             new Callable<ListMultimap<Thread, T>>() {
233               @Override
234               public ListMultimap<Thread, T> call() throws Exception {
235                 assertTrue(lock1.lockOrDetectPotentialLocksCycle().isEmpty());
236                 barrier.await();
237                 ListMultimap<Thread, T> cycle = lock2.lockOrDetectPotentialLocksCycle();
238                 if (cycle == null) {
239                   lock2.unlock();
240                 }
241                 lock1.unlock();
242                 return cycle;
243               }
244             });
245     Thread thread = new Thread(future);
246     thread.start();
247     return future;
248   }
249 }
250