1 package com.google.inject.internal;
2 
import com.google.inject.internal.CycleDetectingLock.CycleDetectingLockFactory;

import junit.framework.TestCase;

import java.util.concurrent.BrokenBarrierException;
import java.util.concurrent.Callable;
import java.util.concurrent.CyclicBarrier;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import java.util.concurrent.locks.ReentrantLock;
13 
14 public class CycleDetectingLockTest extends TestCase {
15 
16   static final long DEADLOCK_TIMEOUT_SECONDS = 1;
17 
18   /**
19    * Verifies that graph of threads' dependencies is not static and is calculated in runtime using
20    * information about specific locks.
21    *
22    * <pre>
23    *   T1: Waits on S1
24    *   T2: Locks B, sends S1, waits on S2
25    *   T1: Locks A, start locking B, sends S2, waits on S3
26    *   T2: Unlocks B, start locking A, sends S3, finishes locking A, unlocks A
27    *   T1: Finishes locking B, unlocks B, unlocks A
28    * </pre>
29    *
30    * <p>This should succeed, even though T1 was locked on T2 and T2 is locked on T1 when T2 locks
31    * A. Incorrect implementation detects a cycle waiting on S3.
32    */
33 
testSingletonThreadsRuntimeCircularDependency()34   public void testSingletonThreadsRuntimeCircularDependency() throws Exception {
35     final CyclicBarrier signal1 = new CyclicBarrier(2);
36     final CyclicBarrier signal2 = new CyclicBarrier(2);
37     final CyclicBarrier signal3 = new CyclicBarrier(2);
38     CycleDetectingLockFactory<String> lockFactory = new CycleDetectingLockFactory<String>();
39     final CycleDetectingLock<String> lockA =
40         lockFactory.new ReentrantCycleDetectingLock("A", new ReentrantLock() {
41           @Override
42           public void lock() {
43             if (Thread.currentThread().getName().equals("T2")) {
44               try {
45                 signal3.await(DEADLOCK_TIMEOUT_SECONDS, TimeUnit.SECONDS);
46               } catch (Exception e) {
47                 throw new RuntimeException(e);
48               }
49             } else {
50               assertEquals("T1", Thread.currentThread().getName());
51             }
52             super.lock();
53           }
54         });
55     final CycleDetectingLock<String> lockB =
56         lockFactory.new ReentrantCycleDetectingLock("B", new ReentrantLock() {
57           @Override
58           public void lock() {
59             if (Thread.currentThread().getName().equals("T1")) {
60               try {
61                 signal2.await(DEADLOCK_TIMEOUT_SECONDS, TimeUnit.SECONDS);
62                 signal3.await(DEADLOCK_TIMEOUT_SECONDS, TimeUnit.SECONDS);
63               } catch (Exception e) {
64                 throw new RuntimeException(e);
65               }
66             } else {
67               assertEquals("T2", Thread.currentThread().getName());
68             }
69             super.lock();
70           }
71         });
72     Future<Void> firstThreadResult = Executors.newSingleThreadExecutor().submit(
73         new Callable<Void>() {
74           public Void call() throws Exception {
75             Thread.currentThread().setName("T1");
76             signal1.await(DEADLOCK_TIMEOUT_SECONDS, TimeUnit.SECONDS);
77             assertTrue(lockA.lockOrDetectPotentialLocksCycle().isEmpty());
78             assertTrue(lockB.lockOrDetectPotentialLocksCycle().isEmpty());
79             lockB.unlock();
80             lockA.unlock();
81             return null;
82           }
83         });
84     Future<Void> secondThreadResult = Executors.newSingleThreadExecutor().submit(
85         new Callable<Void>() {
86           public Void call() throws Exception {
87             Thread.currentThread().setName("T2");
88             assertTrue(lockB.lockOrDetectPotentialLocksCycle().isEmpty());
89             signal1.await(DEADLOCK_TIMEOUT_SECONDS, TimeUnit.SECONDS);
90             signal2.await(DEADLOCK_TIMEOUT_SECONDS, TimeUnit.SECONDS);
91             lockB.unlock();
92             assertTrue(lockA.lockOrDetectPotentialLocksCycle().isEmpty());
93             lockA.unlock();
94             return null;
95           }
96         });
97 
98     firstThreadResult.get();
99     secondThreadResult.get();
100   }
101 }
102