1 package org.testng.junit;
2 
3 import java.util.*;
4 import java.util.regex.Pattern;
5 import org.junit.runner.Description;
6 import org.junit.runner.JUnitCore;
7 import org.junit.runner.Request;
8 import org.junit.runner.Result;
9 import org.junit.runner.manipulation.Filter;
10 import org.junit.runner.notification.Failure;
11 import org.junit.runner.notification.RunListener;
12 import org.testng.*;
13 import org.testng.collections.Lists;
14 import org.testng.internal.ITestResultNotifier;
15 import org.testng.internal.InvokedMethod;
16 import org.testng.internal.TestResult;
17 
18 /**
19  * A JUnit TestRunner that records/triggers all information/events necessary to
20  * TestNG.
21  *
22  * @author Lukas Jungmann
23  */
24 public class JUnit4TestRunner implements IJUnitTestRunner {
25 
26     private ITestResultNotifier m_parentRunner;
27     private List<ITestNGMethod> m_methods = Lists.newArrayList();
28     private List<ITestListener> m_listeners = Lists.newArrayList();
29     private Collection<IInvokedMethodListener> m_invokeListeners = Lists.newArrayList();
30 
JUnit4TestRunner()31     public JUnit4TestRunner() {
32     }
33 
JUnit4TestRunner(ITestResultNotifier tr)34     public JUnit4TestRunner(ITestResultNotifier tr) {
35         m_parentRunner = tr;
36         m_listeners = m_parentRunner.getTestListeners();
37     }
38 
39     /**
40      * Needed from TestRunner in order to figure out what JUnit test methods
41      * were run.
42      *
43      * @return the list of all JUnit test methods run
44      */
45     @Override
getTestMethods()46     public List<ITestNGMethod> getTestMethods() {
47         return m_methods;
48     }
49 
50     @Override
setTestResultNotifier(ITestResultNotifier notifier)51     public void setTestResultNotifier(ITestResultNotifier notifier) {
52         m_parentRunner = notifier;
53         m_listeners = m_parentRunner.getTestListeners();
54     }
55 
setInvokedMethodListeners(Collection<IInvokedMethodListener> listeners)56     public void setInvokedMethodListeners(Collection<IInvokedMethodListener> listeners) {
57         m_invokeListeners = listeners;
58     }
59 
60     /**
61      * A
62      * <code>start</code> implementation that ignores the
63      * <code>TestResult</code>
64      *
65      * @param testClass the JUnit test class
66      */
67     @Override
run(Class testClass, String... methods)68     public void run(Class testClass, String... methods) {
69         start(testClass, methods);
70     }
71 
72     /**
73      * Starts a test run. Analyzes the command line arguments and runs the given
74      * test suite.
75      */
start(final Class testCase, final String... methods)76     public Result start(final Class testCase, final String... methods) {
77         try {
78             JUnitCore core = new JUnitCore();
79             core.addListener(new RL());
80             Request r = Request.aClass(testCase);
81             return core.run(r.filterWith(new Filter() {
82 
83                 @Override
84                 public boolean shouldRun(Description description) {
85                     if (description == null) {
86                         return false;
87                     }
88                     if (methods.length == 0) {
89                         //run everything
90                         return true;
91                     }
92                     for (String m: methods) {
93                         Pattern p = Pattern.compile(m);
94                         if (p.matcher(description.getMethodName()).matches()) {
95                             return true;
96                         }
97                     }
98                     return false;
99                 }
100 
101                 @Override
102                 public String describe() {
103                     return "TestNG method filter";
104                 }
105             }));
106         } catch (Throwable t) {
107             throw new TestNGException("Failure in JUnit mode for class " + testCase.getName(), t);
108         }
109     }
110 
111     private class RL extends RunListener {
112 
113         private Map<Description, ITestResult> runs = new WeakHashMap<>();
114         private List<Description> notified = new LinkedList<>();
115 
116         @Override
117         public void testAssumptionFailure(Failure failure) {
118             notified.add(failure.getDescription());
119             ITestResult tr = runs.get(failure.getDescription());
120             tr.setStatus(TestResult.SKIP);
121             tr.setEndMillis(Calendar.getInstance().getTimeInMillis());
122             tr.setThrowable(failure.getException());
123             m_parentRunner.addSkippedTest(tr.getMethod(), tr);
124             for (ITestListener l : m_listeners) {
125                 l.onTestSkipped(tr);
126             }
127         }
128 
129         @Override
130         public void testFailure(Failure failure) throws Exception {
131             if (isAssumptionFailed(failure)) {
132                 this.testAssumptionFailure(failure);
133                 return;
134             }
135             notified.add(failure.getDescription());
136             ITestResult tr = runs.get(failure.getDescription());
137             tr.setStatus(TestResult.FAILURE);
138             tr.setEndMillis(Calendar.getInstance().getTimeInMillis());
139             tr.setThrowable(failure.getException());
140             m_parentRunner.addFailedTest(tr.getMethod(), tr);
141             for (ITestListener l : m_listeners) {
142                 l.onTestFailure(tr);
143             }
144         }
145 
146         @Override
147         public void testFinished(Description description) throws Exception {
148             ITestResult tr = runs.get(description);
149             if (!notified.contains(description)) {
150                 tr.setStatus(TestResult.SUCCESS);
151                 tr.setEndMillis(Calendar.getInstance().getTimeInMillis());
152                 m_parentRunner.addPassedTest(tr.getMethod(), tr);
153                 for (ITestListener l : m_listeners) {
154                     l.onTestSuccess(tr);
155                 }
156             }
157             m_methods.add(tr.getMethod());
158         }
159 
160         @Override
161         public void testIgnored(Description description) throws Exception {
162             ITestResult tr = createTestResult(description);
163             tr.setStatus(TestResult.SKIP);
164             tr.setEndMillis(tr.getStartMillis());
165             m_parentRunner.addSkippedTest(tr.getMethod(), tr);
166             m_methods.add(tr.getMethod());
167             for (ITestListener l : m_listeners) {
168                 l.onTestSkipped(tr);
169             }
170         }
171 
172         @Override
173         public void testRunFinished(Result result) throws Exception {
174         }
175 
176         @Override
177         public void testRunStarted(Description description) throws Exception {
178         }
179 
180         @Override
181         public void testStarted(Description description) throws Exception {
182             ITestResult tr = createTestResult(description);
183             runs.put(description, tr);
184             for (ITestListener l : m_listeners) {
185                 l.onTestStart(tr);
186             }
187         }
188 
189         private ITestResult createTestResult(Description test) {
190             JUnit4TestClass tc = new JUnit4TestClass(test);
191             JUnitTestMethod tm = new JUnit4TestMethod(tc, test);
192 
193             TestResult tr = new TestResult(tc,
194                     test,
195                     tm,
196                     null,
197                     Calendar.getInstance().getTimeInMillis(),
198                     0,
199                     null);
200 
201             InvokedMethod im = new InvokedMethod(tr.getTestClass(), tr.getMethod(), new Object[0], tr.getStartMillis(), tr);
202             m_parentRunner.addInvokedMethod(im);
203             for (IInvokedMethodListener l: m_invokeListeners) {
204                 l.beforeInvocation(im, tr);
205             }
206             return tr;
207         }
208     }
209 
210     private static boolean isAssumptionFailed(Failure failure) {
211         if (failure == null) {
212             return false;
213         }
214         //noinspection ThrowableResultOfMethodCallIgnored
215         final Throwable exception = failure.getException();
216         //noinspection SimplifiableIfStatement
217         if (exception == null) {
218             return false;
219         }
220         return "org.junit.internal.AssumptionViolatedException".equals(exception.getClass().getCanonicalName());
221     }
222 }
223