1 /*
2  * Copyright (C) 2021 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 package com.android.bedstead.harrier;
18 
import android.util.Log;

import androidx.annotation.Nullable;

import com.android.bedstead.harrier.annotations.CalledByHostDrivenTest;
import com.android.bedstead.harrier.annotations.enterprise.CannotSetPolicyTest;
import com.android.bedstead.harrier.annotations.enterprise.EnterprisePolicy;
import com.android.bedstead.harrier.annotations.enterprise.NegativePolicyTest;
import com.android.bedstead.harrier.annotations.enterprise.PositivePolicyTest;
import com.android.bedstead.harrier.annotations.meta.ParameterizedAnnotation;
import com.android.bedstead.harrier.annotations.meta.RepeatingAnnotation;
import com.android.bedstead.harrier.annotations.parameterized.IncludeNone;
import com.android.bedstead.nene.exceptions.NeneException;

import com.google.common.base.Objects;

import org.junit.Test;
import org.junit.rules.TestRule;
import org.junit.runners.BlockJUnit4ClassRunner;
import org.junit.runners.model.FrameworkMethod;
import org.junit.runners.model.InitializationError;
import org.junit.runners.model.TestClass;

import java.lang.annotation.Annotation;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
54 
55 /**
56  * A JUnit test runner for use with Bedstead.
57  */
58 public final class BedsteadJUnit4 extends BlockJUnit4ClassRunner {
59 
60     private static final String BEDSTEAD_PACKAGE_NAME = "com.android.bedstead";
61 
62     // These are annotations which are not included indirectly
63     private static final Set<String> sIgnoredAnnotationPackages = new HashSet<>();
64     static {
65         sIgnoredAnnotationPackages.add("java.lang.annotation");
66         sIgnoredAnnotationPackages.add("com.android.bedstead.harrier.annotations.meta");
67         sIgnoredAnnotationPackages.add("kotlin.*");
68         sIgnoredAnnotationPackages.add("org.junit");
69     }
70 
71     /**
72      * {@link FrameworkMethod} subclass which allows modifying the test name and annotations.
73      */
74     public static final class BedsteadFrameworkMethod extends FrameworkMethod {
75 
76         private final Class<? extends Annotation> mParameterizedAnnotation;
77         private final Map<Class<? extends Annotation>, Annotation> mAnnotationsMap =
78                 new HashMap<>();
79         private Annotation[] mAnnotations;
80 
BedsteadFrameworkMethod(Method method)81         public BedsteadFrameworkMethod(Method method) {
82             this(method, /* parameterizedAnnotation= */ null);
83         }
84 
BedsteadFrameworkMethod(Method method, Annotation parameterizedAnnotation)85         public BedsteadFrameworkMethod(Method method, Annotation parameterizedAnnotation) {
86             super(method);
87             this.mParameterizedAnnotation = (parameterizedAnnotation == null) ? null
88                     : parameterizedAnnotation.annotationType();
89 
90             calculateAnnotations();
91         }
92 
calculateAnnotations()93         private void calculateAnnotations() {
94             List<Annotation> annotations =
95                     new ArrayList<>(Arrays.asList(getDeclaringClass().getAnnotations()));
96             annotations.addAll(Arrays.asList(getMethod().getAnnotations()));
97 
98             parseEnterpriseAnnotations(annotations);
99 
100             resolveRecursiveAnnotations(annotations, mParameterizedAnnotation);
101 
102             this.mAnnotations = annotations.toArray(new Annotation[0]);
103             for (Annotation annotation : annotations) {
104                 mAnnotationsMap.put(annotation.annotationType(), annotation);
105             }
106         }
107 
108         @Override
getName()109         public String getName() {
110             if (mParameterizedAnnotation == null) {
111                 return super.getName();
112             }
113             return super.getName() + "[" + mParameterizedAnnotation.getSimpleName() + "]";
114         }
115 
116         @Override
equals(Object obj)117         public boolean equals(Object obj) {
118             if (!super.equals(obj)) {
119                 return false;
120             }
121 
122             if (!(obj instanceof BedsteadFrameworkMethod)) {
123                 return false;
124             }
125 
126             BedsteadFrameworkMethod other = (BedsteadFrameworkMethod) obj;
127 
128             return Objects.equal(mParameterizedAnnotation, other.mParameterizedAnnotation);
129         }
130 
131         @Override
getAnnotations()132         public Annotation[] getAnnotations() {
133             return mAnnotations;
134         }
135 
136         @Override
getAnnotation(Class<T> annotationType)137         public <T extends Annotation> T getAnnotation(Class<T> annotationType) {
138             return (T) mAnnotationsMap.get(annotationType);
139         }
140     }
141 
142     /**
143      * Resolve annotations recursively.
144      *
145      * @param parameterizedAnnotation The class of the parameterized annotation to expand, if any
146      */
resolveRecursiveAnnotations(List<Annotation> annotations, @Nullable Class<? extends Annotation> parameterizedAnnotation)147     public static void resolveRecursiveAnnotations(List<Annotation> annotations,
148             @Nullable Class<? extends Annotation> parameterizedAnnotation) {
149         int index = 0;
150         while (index < annotations.size()) {
151             Annotation annotation = annotations.get(index);
152             annotations.remove(index);
153             List<Annotation> replacementAnnotations =
154                     getReplacementAnnotations(annotation, parameterizedAnnotation);
155             annotations.addAll(index, replacementAnnotations);
156             index += replacementAnnotations.size();
157         }
158     }
159 
getReplacementAnnotations(Annotation annotation, @Nullable Class<? extends Annotation> parameterizedAnnotation)160     private static List<Annotation> getReplacementAnnotations(Annotation annotation,
161             @Nullable Class<? extends Annotation> parameterizedAnnotation) {
162         List<Annotation> replacementAnnotations = new ArrayList<>();
163 
164         if (annotation.annotationType().getAnnotation(RepeatingAnnotation.class) != null) {
165             try {
166                 Annotation[] annotations =
167                         (Annotation[]) annotation.annotationType()
168                                 .getMethod("value").invoke(annotation);
169                 Collections.addAll(replacementAnnotations, annotations);
170                 return replacementAnnotations;
171             } catch (IllegalAccessException | InvocationTargetException | NoSuchMethodException e) {
172                 throw new NeneException("Error expanding repeated annotations", e);
173             }
174         }
175 
176         if (annotation.annotationType().getAnnotation(ParameterizedAnnotation.class) != null
177                 && !annotation.annotationType().equals(parameterizedAnnotation)) {
178             return replacementAnnotations;
179         }
180 
181         for (Annotation indirectAnnotation : annotation.annotationType().getAnnotations()) {
182             String annotationPackage = indirectAnnotation.annotationType().getPackage().getName();
183             if (shouldSkipAnnotation(annotationPackage)) {
184                 continue;
185             }
186 
187             replacementAnnotations.addAll(getReplacementAnnotations(
188                     indirectAnnotation, parameterizedAnnotation));
189         }
190 
191         replacementAnnotations.add(annotation);
192 
193         return replacementAnnotations;
194     }
195 
shouldSkipAnnotation(String annotationPackage)196     private static boolean shouldSkipAnnotation(String annotationPackage) {
197         for (String ignoredPackage : sIgnoredAnnotationPackages) {
198             if (ignoredPackage.endsWith(".*")) {
199                 if (annotationPackage.startsWith(
200                     ignoredPackage.substring(0, ignoredPackage.length() - 2))) {
201                     return true;
202                 }
203             } else if (annotationPackage.equals(ignoredPackage)) {
204                 return true;
205             }
206         }
207 
208         return false;
209     }
210 
BedsteadJUnit4(Class<?> testClass)211     public BedsteadJUnit4(Class<?> testClass) throws InitializationError {
212         super(testClass);
213     }
214 
215     @Override
computeTestMethods()216     protected List<FrameworkMethod> computeTestMethods() {
217         TestClass testClass = getTestClass();
218 
219         List<FrameworkMethod> basicTests = new ArrayList<>();
220         basicTests.addAll(testClass.getAnnotatedMethods(Test.class));
221         basicTests.addAll(testClass.getAnnotatedMethods(CalledByHostDrivenTest.class));
222 
223         List<FrameworkMethod> modifiedTests = new ArrayList<>();
224 
225         for (FrameworkMethod m : basicTests) {
226             Set<Annotation> parameterizedAnnotations = getParameterizedAnnotations(m);
227 
228             if (parameterizedAnnotations.isEmpty()) {
229                 // Unparameterized, just add the original
230                 modifiedTests.add(new BedsteadFrameworkMethod(m.getMethod()));
231             }
232 
233             for (Annotation annotation : parameterizedAnnotations) {
234                 if (annotation.annotationType().equals(IncludeNone.class)) {
235                     // Special case - does not generate a run
236                     continue;
237                 }
238                 modifiedTests.add(
239                         new BedsteadFrameworkMethod(m.getMethod(), annotation));
240             }
241         }
242 
243         sortMethodsByBedsteadAnnotations(modifiedTests);
244 
245         return modifiedTests;
246     }
247 
248     /**
249      * Sort methods so that methods with identical bedstead annotations are together.
250      *
251      * <p>This will also ensure that all tests methods which are not annotated for bedstead will
252      * run before any tests which are annotated.
253      */
sortMethodsByBedsteadAnnotations(List<FrameworkMethod> modifiedTests)254     private void sortMethodsByBedsteadAnnotations(List<FrameworkMethod> modifiedTests) {
255         List<Annotation> bedsteadAnnotationsSortedByMostCommon =
256                 bedsteadAnnotationsSortedByMostCommon(modifiedTests);
257 
258         modifiedTests.sort((o1, o2) -> {
259             for (Annotation annotation : bedsteadAnnotationsSortedByMostCommon) {
260                 boolean o1HasAnnotation = o1.getAnnotation(annotation.annotationType()) != null;
261                 boolean o2HasAnnotation = o2.getAnnotation(annotation.annotationType()) != null;
262 
263                 if (o1HasAnnotation && !o2HasAnnotation) {
264                     // o1 goes to the end
265                     return 1;
266                 } else if (o2HasAnnotation && !o1HasAnnotation) {
267                     return -1;
268                 }
269             }
270             return 0;
271         });
272     }
273 
bedsteadAnnotationsSortedByMostCommon(List<FrameworkMethod> methods)274     private List<Annotation> bedsteadAnnotationsSortedByMostCommon(List<FrameworkMethod> methods) {
275         Map<Annotation, Integer> annotationCounts = countAnnotations(methods);
276         List<Annotation> annotations = new ArrayList<>(annotationCounts.keySet());
277 
278         annotations.removeIf(
279                 annotation ->
280                         !annotation.annotationType()
281                                 .getCanonicalName().contains(BEDSTEAD_PACKAGE_NAME));
282 
283         annotations.sort(Comparator.comparingInt(annotationCounts::get));
284         Collections.reverse(annotations);
285 
286         return annotations;
287     }
288 
countAnnotations(List<FrameworkMethod> methods)289     private Map<Annotation, Integer> countAnnotations(List<FrameworkMethod> methods) {
290         Map<Annotation, Integer> annotationCounts = new HashMap<>();
291 
292         for (FrameworkMethod method : methods) {
293             for (Annotation annotation : method.getAnnotations()) {
294                 annotationCounts.put(
295                         annotation, annotationCounts.getOrDefault(annotation, 0) + 1);
296             }
297         }
298 
299         return annotationCounts;
300     }
301 
getParameterizedAnnotations(FrameworkMethod method)302     private Set<Annotation> getParameterizedAnnotations(FrameworkMethod method) {
303         Set<Annotation> parameterizedAnnotations = new HashSet<>();
304         List<Annotation> annotations = new ArrayList<>(Arrays.asList(method.getAnnotations()));
305 
306         // TODO(scottjonathan): We're doing this twice... does it matter?
307         parseEnterpriseAnnotations(annotations);
308 
309         for (Annotation annotation : annotations) {
310             if (annotation.annotationType().getAnnotation(ParameterizedAnnotation.class) != null) {
311                 parameterizedAnnotations.add(annotation);
312             }
313         }
314 
315         return parameterizedAnnotations;
316     }
317 
318     /**
319      * Parse enterprise-specific annotations.
320      *
321      * <p>To be used before general annotation processing.
322      */
parseEnterpriseAnnotations(List<Annotation> annotations)323     private static void parseEnterpriseAnnotations(List<Annotation> annotations) {
324         int index = 0;
325         while (index < annotations.size()) {
326             Annotation annotation = annotations.get(index);
327             if (annotation instanceof PositivePolicyTest) {
328                 annotations.remove(index);
329                 Class<?> policy = ((PositivePolicyTest) annotation).policy();
330 
331                 EnterprisePolicy enterprisePolicy =
332                         policy.getAnnotation(EnterprisePolicy.class);
333                 List<Annotation> replacementAnnotations =
334                         Policy.positiveStates(enterprisePolicy);
335 
336                 annotations.addAll(index, replacementAnnotations);
337                 index += replacementAnnotations.size();
338             } else if (annotation instanceof NegativePolicyTest) {
339                 annotations.remove(index);
340                 Class<?> policy = ((NegativePolicyTest) annotation).policy();
341 
342                 EnterprisePolicy enterprisePolicy =
343                         policy.getAnnotation(EnterprisePolicy.class);
344                 List<Annotation> replacementAnnotations =
345                         Policy.negativeStates(enterprisePolicy);
346 
347                 annotations.addAll(index, replacementAnnotations);
348                 index += replacementAnnotations.size();
349             } else if (annotation instanceof CannotSetPolicyTest) {
350                 annotations.remove(index);
351                 Class<?> policy = ((CannotSetPolicyTest) annotation).policy();
352 
353                 EnterprisePolicy enterprisePolicy =
354                         policy.getAnnotation(EnterprisePolicy.class);
355                 List<Annotation> replacementAnnotations =
356                         Policy.cannotSetPolicyStates(enterprisePolicy);
357 
358                 annotations.addAll(index, replacementAnnotations);
359                 index += replacementAnnotations.size();
360             } else {
361                 index++;
362             }
363         }
364     }
365 
366     @Override
classRules()367     protected List<TestRule> classRules() {
368         List<TestRule> rules = super.classRules();
369 
370         for (TestRule rule : rules) {
371             if (rule instanceof DeviceState) {
372                 DeviceState deviceState = (DeviceState) rule;
373 
374                 deviceState.setSkipTestTeardown(true);
375                 deviceState.setUsingBedsteadJUnit4(true);
376 
377                 break;
378             }
379         }
380 
381         return rules;
382     }
383 }
384