1 package org.junit.runners.model;
2 
3 import static java.lang.reflect.Modifier.isStatic;
4 import static org.junit.internal.MethodSorter.NAME_ASCENDING;
5 
6 import java.lang.annotation.Annotation;
7 import java.lang.reflect.Constructor;
8 import java.lang.reflect.Field;
9 import java.lang.reflect.Method;
10 import java.lang.reflect.Modifier;
11 import java.util.ArrayList;
12 import java.util.Arrays;
13 import java.util.Collections;
14 import java.util.Comparator;
15 import java.util.LinkedHashMap;
16 import java.util.LinkedHashSet;
17 import java.util.List;
18 import java.util.Map;
19 import java.util.Set;
20 
21 import org.junit.Assert;
22 import org.junit.Before;
23 import org.junit.BeforeClass;
24 import org.junit.internal.MethodSorter;
25 
26 /**
27  * Wraps a class to be run, providing method validation and annotation searching
28  *
29  * @since 4.5
30  */
31 public class TestClass implements Annotatable {
32     private static final FieldComparator FIELD_COMPARATOR = new FieldComparator();
33     private static final MethodComparator METHOD_COMPARATOR = new MethodComparator();
34 
35     private final Class<?> clazz;
36     private final Map<Class<? extends Annotation>, List<FrameworkMethod>> methodsForAnnotations;
37     private final Map<Class<? extends Annotation>, List<FrameworkField>> fieldsForAnnotations;
38 
39     /**
40      * Creates a {@code TestClass} wrapping {@code clazz}. Each time this
41      * constructor executes, the class is scanned for annotations, which can be
42      * an expensive process (we hope in future JDK's it will not be.) Therefore,
43      * try to share instances of {@code TestClass} where possible.
44      */
TestClass(Class<?> clazz)45     public TestClass(Class<?> clazz) {
46         this.clazz = clazz;
47         if (clazz != null && clazz.getConstructors().length > 1) {
48             throw new IllegalArgumentException(
49                     "Test class can only have one constructor");
50         }
51 
52         Map<Class<? extends Annotation>, List<FrameworkMethod>> methodsForAnnotations =
53                 new LinkedHashMap<Class<? extends Annotation>, List<FrameworkMethod>>();
54         Map<Class<? extends Annotation>, List<FrameworkField>> fieldsForAnnotations =
55                 new LinkedHashMap<Class<? extends Annotation>, List<FrameworkField>>();
56 
57         scanAnnotatedMembers(methodsForAnnotations, fieldsForAnnotations);
58 
59         this.methodsForAnnotations = makeDeeplyUnmodifiable(methodsForAnnotations);
60         this.fieldsForAnnotations = makeDeeplyUnmodifiable(fieldsForAnnotations);
61     }
62 
scanAnnotatedMembers(Map<Class<? extends Annotation>, List<FrameworkMethod>> methodsForAnnotations, Map<Class<? extends Annotation>, List<FrameworkField>> fieldsForAnnotations)63     protected void scanAnnotatedMembers(Map<Class<? extends Annotation>, List<FrameworkMethod>> methodsForAnnotations, Map<Class<? extends Annotation>, List<FrameworkField>> fieldsForAnnotations) {
64         for (Class<?> eachClass : getSuperClasses(clazz)) {
65             for (Method eachMethod : MethodSorter.getDeclaredMethods(eachClass)) {
66                 addToAnnotationLists(new FrameworkMethod(eachMethod), methodsForAnnotations);
67             }
68             // ensuring fields are sorted to make sure that entries are inserted
69             // and read from fieldForAnnotations in a deterministic order
70             for (Field eachField : getSortedDeclaredFields(eachClass)) {
71                 addToAnnotationLists(new FrameworkField(eachField), fieldsForAnnotations);
72             }
73         }
74     }
75 
getSortedDeclaredFields(Class<?> clazz)76     private static Field[] getSortedDeclaredFields(Class<?> clazz) {
77         Field[] declaredFields = clazz.getDeclaredFields();
78         Arrays.sort(declaredFields, FIELD_COMPARATOR);
79         return declaredFields;
80     }
81 
addToAnnotationLists(T member, Map<Class<? extends Annotation>, List<T>> map)82     protected static <T extends FrameworkMember<T>> void addToAnnotationLists(T member,
83             Map<Class<? extends Annotation>, List<T>> map) {
84         for (Annotation each : member.getAnnotations()) {
85             Class<? extends Annotation> type = each.annotationType();
86             List<T> members = getAnnotatedMembers(map, type, true);
87             if (member.isShadowedBy(members)) {
88                 return;
89             }
90             if (runsTopToBottom(type)) {
91                 members.add(0, member);
92             } else {
93                 members.add(member);
94             }
95         }
96     }
97 
98     private static <T extends FrameworkMember<T>> Map<Class<? extends Annotation>, List<T>>
makeDeeplyUnmodifiable(Map<Class<? extends Annotation>, List<T>> source)99             makeDeeplyUnmodifiable(Map<Class<? extends Annotation>, List<T>> source) {
100         LinkedHashMap<Class<? extends Annotation>, List<T>> copy =
101                 new LinkedHashMap<Class<? extends Annotation>, List<T>>();
102         for (Map.Entry<Class<? extends Annotation>, List<T>> entry : source.entrySet()) {
103             copy.put(entry.getKey(), Collections.unmodifiableList(entry.getValue()));
104         }
105         return Collections.unmodifiableMap(copy);
106     }
107 
108     /**
109      * Returns, efficiently, all the non-overridden methods in this class and
110      * its superclasses that are annotated}.
111      *
112      * @since 4.12
113      */
getAnnotatedMethods()114     public List<FrameworkMethod> getAnnotatedMethods() {
115         List<FrameworkMethod> methods = collectValues(methodsForAnnotations);
116         Collections.sort(methods, METHOD_COMPARATOR);
117         return methods;
118     }
119 
120     /**
121      * Returns, efficiently, all the non-overridden methods in this class and
122      * its superclasses that are annotated with {@code annotationClass}.
123      */
getAnnotatedMethods( Class<? extends Annotation> annotationClass)124     public List<FrameworkMethod> getAnnotatedMethods(
125             Class<? extends Annotation> annotationClass) {
126         return Collections.unmodifiableList(getAnnotatedMembers(methodsForAnnotations, annotationClass, false));
127     }
128 
129     /**
130      * Returns, efficiently, all the non-overridden fields in this class and its
131      * superclasses that are annotated.
132      *
133      * @since 4.12
134      */
getAnnotatedFields()135     public List<FrameworkField> getAnnotatedFields() {
136         return collectValues(fieldsForAnnotations);
137     }
138 
139     /**
140      * Returns, efficiently, all the non-overridden fields in this class and its
141      * superclasses that are annotated with {@code annotationClass}.
142      */
getAnnotatedFields( Class<? extends Annotation> annotationClass)143     public List<FrameworkField> getAnnotatedFields(
144             Class<? extends Annotation> annotationClass) {
145         return Collections.unmodifiableList(getAnnotatedMembers(fieldsForAnnotations, annotationClass, false));
146     }
147 
collectValues(Map<?, List<T>> map)148     private <T> List<T> collectValues(Map<?, List<T>> map) {
149         Set<T> values = new LinkedHashSet<T>();
150         for (List<T> additionalValues : map.values()) {
151             values.addAll(additionalValues);
152         }
153         return new ArrayList<T>(values);
154     }
155 
getAnnotatedMembers(Map<Class<? extends Annotation>, List<T>> map, Class<? extends Annotation> type, boolean fillIfAbsent)156     private static <T> List<T> getAnnotatedMembers(Map<Class<? extends Annotation>, List<T>> map,
157             Class<? extends Annotation> type, boolean fillIfAbsent) {
158         if (!map.containsKey(type) && fillIfAbsent) {
159             map.put(type, new ArrayList<T>());
160         }
161         List<T> members = map.get(type);
162         return members == null ? Collections.<T>emptyList() : members;
163     }
164 
runsTopToBottom(Class<? extends Annotation> annotation)165     private static boolean runsTopToBottom(Class<? extends Annotation> annotation) {
166         return annotation.equals(Before.class)
167                 || annotation.equals(BeforeClass.class);
168     }
169 
getSuperClasses(Class<?> testClass)170     private static List<Class<?>> getSuperClasses(Class<?> testClass) {
171         ArrayList<Class<?>> results = new ArrayList<Class<?>>();
172         Class<?> current = testClass;
173         while (current != null) {
174             results.add(current);
175             current = current.getSuperclass();
176         }
177         return results;
178     }
179 
180     /**
181      * Returns the underlying Java class.
182      */
getJavaClass()183     public Class<?> getJavaClass() {
184         return clazz;
185     }
186 
187     /**
188      * Returns the class's name.
189      */
getName()190     public String getName() {
191         if (clazz == null) {
192             return "null";
193         }
194         return clazz.getName();
195     }
196 
197     /**
198      * Returns the only public constructor in the class, or throws an {@code
199      * AssertionError} if there are more or less than one.
200      */
201 
getOnlyConstructor()202     public Constructor<?> getOnlyConstructor() {
203         Constructor<?>[] constructors = clazz.getConstructors();
204         Assert.assertEquals(1, constructors.length);
205         return constructors[0];
206     }
207 
208     /**
209      * Returns the annotations on this class
210      */
getAnnotations()211     public Annotation[] getAnnotations() {
212         if (clazz == null) {
213             return new Annotation[0];
214         }
215         return clazz.getAnnotations();
216     }
217 
getAnnotation(Class<T> annotationType)218     public <T extends Annotation> T getAnnotation(Class<T> annotationType) {
219         if (clazz == null) {
220             return null;
221         }
222         return clazz.getAnnotation(annotationType);
223     }
224 
getAnnotatedFieldValues(Object test, Class<? extends Annotation> annotationClass, Class<T> valueClass)225     public <T> List<T> getAnnotatedFieldValues(Object test,
226             Class<? extends Annotation> annotationClass, Class<T> valueClass) {
227         List<T> results = new ArrayList<T>();
228         for (FrameworkField each : getAnnotatedFields(annotationClass)) {
229             try {
230                 Object fieldValue = each.get(test);
231                 if (valueClass.isInstance(fieldValue)) {
232                     results.add(valueClass.cast(fieldValue));
233                 }
234             } catch (IllegalAccessException e) {
235                 throw new RuntimeException(
236                         "How did getFields return a field we couldn't access?", e);
237             }
238         }
239         return results;
240     }
241 
getAnnotatedMethodValues(Object test, Class<? extends Annotation> annotationClass, Class<T> valueClass)242     public <T> List<T> getAnnotatedMethodValues(Object test,
243             Class<? extends Annotation> annotationClass, Class<T> valueClass) {
244         List<T> results = new ArrayList<T>();
245         for (FrameworkMethod each : getAnnotatedMethods(annotationClass)) {
246             try {
247                 /*
248                  * A method annotated with @Rule may return a @TestRule or a @MethodRule,
249                  * we cannot call the method to check whether the return type matches our
250                  * expectation i.e. subclass of valueClass. If we do that then the method
251                  * will be invoked twice and we do not want to do that. So we first check
252                  * whether return type matches our expectation and only then call the method
253                  * to fetch the MethodRule
254                  */
255                 if (valueClass.isAssignableFrom(each.getReturnType())) {
256                     Object fieldValue = each.invokeExplosively(test);
257                     results.add(valueClass.cast(fieldValue));
258                 }
259             } catch (Throwable e) {
260                 throw new RuntimeException(
261                         "Exception in " + each.getName(), e);
262             }
263         }
264         return results;
265     }
266 
isPublic()267     public boolean isPublic() {
268         return Modifier.isPublic(clazz.getModifiers());
269     }
270 
isANonStaticInnerClass()271     public boolean isANonStaticInnerClass() {
272         return clazz.isMemberClass() && !isStatic(clazz.getModifiers());
273     }
274 
275     @Override
hashCode()276     public int hashCode() {
277         return (clazz == null) ? 0 : clazz.hashCode();
278     }
279 
280     @Override
equals(Object obj)281     public boolean equals(Object obj) {
282         if (this == obj) {
283             return true;
284         }
285         if (obj == null) {
286             return false;
287         }
288         if (getClass() != obj.getClass()) {
289             return false;
290         }
291         TestClass other = (TestClass) obj;
292         return clazz == other.clazz;
293     }
294 
295     /**
296      * Compares two fields by its name.
297      */
298     private static class FieldComparator implements Comparator<Field> {
compare(Field left, Field right)299         public int compare(Field left, Field right) {
300             return left.getName().compareTo(right.getName());
301         }
302     }
303 
304     /**
305      * Compares two methods by its name.
306      */
307     private static class MethodComparator implements
308             Comparator<FrameworkMethod> {
compare(FrameworkMethod left, FrameworkMethod right)309         public int compare(FrameworkMethod left, FrameworkMethod right) {
310             return NAME_ASCENDING.compare(left.getMethod(), right.getMethod());
311         }
312     }
313 }
314