1 /*
2  * Copyright (C) 2023 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 package android.platform.test.flag.junit;
18 
19 import static org.junit.Assume.assumeFalse;
20 import static org.junit.Assume.assumeTrue;
21 
22 import android.platform.test.annotations.DisableFlags;
23 import android.platform.test.annotations.EnableFlags;
24 import android.platform.test.annotations.RequiresFlagsDisabled;
25 import android.platform.test.annotations.RequiresFlagsEnabled;
26 import android.platform.test.annotations.UsesFlags;
27 
28 import com.google.common.collect.Sets;
29 
30 import org.junit.Test;
31 import org.junit.runner.Description;
32 
33 import java.lang.annotation.Annotation;
34 import java.lang.annotation.Documented;
35 import java.lang.annotation.Retention;
36 import java.lang.annotation.Target;
37 import java.util.ArrayDeque;
38 import java.util.Collection;
39 import java.util.HashMap;
40 import java.util.HashSet;
41 import java.util.List;
42 import java.util.Map;
43 import java.util.Objects;
44 import java.util.Queue;
45 import java.util.Set;
46 import java.util.stream.Collectors;
47 
48 import javax.annotation.Nonnull;
49 
50 /**
51  * Retrieves feature flag related annotations from a given {@code Description}.
52  *
53  * <p>For each annotation, it trys to get it from the method first, then trys to get it from the
54  * test class if the method has no such annotation.
55  */
56 public class AnnotationsRetriever {
57     private static final Set<Class<?>> KNOWN_UNRELATED_ANNOTATIONS =
58             Set.of(Test.class, Retention.class, Target.class, Documented.class);
59 
AnnotationsRetriever()60     private AnnotationsRetriever() {}
61 
getAllUsedFlagsClasses(Description description)62     public static Set<String> getAllUsedFlagsClasses(Description description) {
63         Set<String> result = new HashSet<>();
64         Set<Object> visited = new HashSet<>();
65         collectUsedFlagsClasses(description, visited, result);
66         return result;
67     }
68 
collectUsedFlagsClasses( Description description, Set<Object> visited, Set<String> result)69     private static void collectUsedFlagsClasses(
70             Description description, Set<Object> visited, Set<String> result) {
71         Class<?> testClass = description.getTestClass();
72         if (testClass != null && !visited.contains(testClass)) {
73             visited.add(testClass);
74             collectUsedFlagsClasses(List.of(testClass.getAnnotations()), result);
75         }
76         collectUsedFlagsClasses(description.getAnnotations(), result);
77         for (Description child : description.getChildren()) {
78             collectUsedFlagsClasses(child, visited, result);
79         }
80     }
81 
collectUsedFlagsClasses( Collection<Annotation> annotations, Set<String> result)82     private static void collectUsedFlagsClasses(
83             Collection<Annotation> annotations, Set<String> result) {
84         result.addAll(getFlagsForAnnotation(sUsesFlags, annotations));
85     }
86 
getAllAnnotationSetFlags(Description description)87     public static Set<String> getAllAnnotationSetFlags(Description description) {
88         Set<String> result = new HashSet<>();
89         Set<Object> visited = new HashSet<>();
90         collectAnnotationSetFlags(description, visited, result);
91         return result;
92     }
93 
collectAnnotationSetFlags( Description description, Set<Object> visited, Set<String> result)94     private static void collectAnnotationSetFlags(
95             Description description, Set<Object> visited, Set<String> result) {
96         Class<?> testClass = description.getTestClass();
97         if (testClass != null && !visited.contains(testClass)) {
98             visited.add(testClass);
99             collectAnnotationSetFlags(List.of(testClass.getAnnotations()), result);
100         }
101         collectAnnotationSetFlags(description.getAnnotations(), result);
102         for (Description child : description.getChildren()) {
103             collectAnnotationSetFlags(child, visited, result);
104         }
105     }
106 
collectAnnotationSetFlags( Collection<Annotation> annotations, Set<String> result)107     private static void collectAnnotationSetFlags(
108             Collection<Annotation> annotations, Set<String> result) {
109         result.addAll(getFlagsForAnnotation(sEnableFlags, annotations));
110         result.addAll(getFlagsForAnnotation(sDisableFlags, annotations));
111     }
112 
113     /** Gets all feature flag related annotations. */
getFlagAnnotations(Description description)114     public static FlagAnnotations getFlagAnnotations(Description description) {
115         final Map<String, Boolean> requiresFlagValues =
116                 getMergedFlagValues(sRequiresFlagsEnabled, sRequiresFlagsDisabled, description);
117         final Map<String, Boolean> setsFlagValues =
118                 getMergedFlagValues(sEnableFlags, sDisableFlags, description);
119 
120         // Assert that no flag is defined in both maps
121         Set<String> inconsistentFlags =
122                 Sets.intersection(requiresFlagValues.keySet(), setsFlagValues.keySet());
123         if (!inconsistentFlags.isEmpty()) {
124             throw new AssertionError(
125                     "The following flags are both required and set: " + inconsistentFlags);
126         }
127 
128         return new FlagAnnotations(requiresFlagValues, setsFlagValues);
129     }
130 
getMergedFlagValues( FlagsAnnotation<? extends Annotation> enabledAnnotation, FlagsAnnotation<? extends Annotation> disabledAnnotation, Description description)131     private static Map<String, Boolean> getMergedFlagValues(
132             FlagsAnnotation<? extends Annotation> enabledAnnotation,
133             FlagsAnnotation<? extends Annotation> disabledAnnotation,
134             Description description) {
135         final Map<String, Boolean> methodFlagValues =
136                 getFlagValues(
137                         description.getMethodName(),
138                         enabledAnnotation,
139                         disabledAnnotation,
140                         description.getAnnotations());
141         Class<?> testClass = description.getTestClass();
142         final Map<String, Boolean> classFlagValues =
143                 testClass == null
144                         ? new HashMap<>()
145                         : getFlagValues(
146                                 testClass.getName(),
147                                 enabledAnnotation,
148                                 disabledAnnotation,
149                                 List.of(testClass.getAnnotations()));
150         Sets.SetView<String> doublyDefinedFlags =
151                 Sets.intersection(classFlagValues.keySet(), methodFlagValues.keySet());
152         if (!doublyDefinedFlags.isEmpty()) {
153             List<String> mismatchedFlags =
154                     doublyDefinedFlags.stream()
155                             .filter(
156                                     flag ->
157                                             !Objects.equals(
158                                                     classFlagValues.get(flag),
159                                                     methodFlagValues.get(flag)))
160                             .collect(Collectors.toList());
161             if (!mismatchedFlags.isEmpty()) {
162                 throw new AssertionError(
163                         "The following flags are required by "
164                                 + description.getMethodName()
165                                 + " and "
166                                 + description.getClassName()
167                                 + " to be both enabled and disabled: "
168                                 + mismatchedFlags);
169             }
170         }
171         // Now override the class values with the method values to produce a merged map
172         classFlagValues.putAll(methodFlagValues);
173         return classFlagValues;
174     }
175 
getFlagValues( @onnull String annotationTarget, @Nonnull FlagsAnnotation<? extends Annotation> enabledAnnotation, @Nonnull FlagsAnnotation<? extends Annotation> disabledAnnotation, @Nonnull Collection<Annotation> annotations)176     private static Map<String, Boolean> getFlagValues(
177             @Nonnull String annotationTarget,
178             @Nonnull FlagsAnnotation<? extends Annotation> enabledAnnotation,
179             @Nonnull FlagsAnnotation<? extends Annotation> disabledAnnotation,
180             @Nonnull Collection<Annotation> annotations) {
181         Set<String> enabledFlags = getFlagsForAnnotation(enabledAnnotation, annotations);
182         Set<String> disabledFlags = getFlagsForAnnotation(disabledAnnotation, annotations);
183         if (enabledFlags.isEmpty() && disabledFlags.isEmpty()) {
184             return new HashMap<>();
185         }
186         Set<String> inconsistentFlags = Sets.intersection(enabledFlags, disabledFlags);
187         if (!inconsistentFlags.isEmpty()) {
188             throw new AssertionError(
189                     "The following flags are required by "
190                             + annotationTarget
191                             + " to be both enabled and disabled: "
192                             + inconsistentFlags);
193         }
194         HashMap<String, Boolean> result = new HashMap<>();
195         for (String enabledFlag : enabledFlags) {
196             result.put(enabledFlag, true);
197         }
198         for (String disabledFlag : disabledFlags) {
199             result.put(disabledFlag, false);
200         }
201         return result;
202     }
203 
204     @Nonnull
getFlagsForAnnotation( FlagsAnnotation<T> flagsAnnotation, Collection<Annotation> annotations)205     private static <T extends Annotation> Set<String> getFlagsForAnnotation(
206             FlagsAnnotation<T> flagsAnnotation, Collection<Annotation> annotations) {
207         Class<T> annotationType = flagsAnnotation.mAnnotationType;
208         Set<String> results = new HashSet<>();
209         Queue<Annotation> annotationQueue = new ArrayDeque<>();
210         Set<Class<? extends Annotation>> visitedAnnotations = new HashSet<>();
211         annotationQueue.addAll(annotations);
212         while (!annotationQueue.isEmpty()) {
213             Annotation annotation = annotationQueue.poll();
214             Class<? extends Annotation> currentAnnotationType = annotation.annotationType();
215             if (currentAnnotationType.equals(annotationType)) {
216                 results.addAll(flagsAnnotation.getFlagsSet((T) annotation));
217             } else if (!KNOWN_UNRELATED_ANNOTATIONS.contains(currentAnnotationType)
218                     && !visitedAnnotations.contains(currentAnnotationType)) {
219                 annotationQueue.addAll(List.of(annotation.annotationType().getAnnotations()));
220                 visitedAnnotations.add(currentAnnotationType);
221             }
222         }
223         return results;
224     }
225 
226     /** Contains all feature flag related annotations. */
227     public static class FlagAnnotations {
228 
229         /** The flag names which have required values, mapped to the value they require */
230         public @Nonnull Map<String, Boolean> mRequiredFlagValues;
231 
232         /** The flag names which have values to be set, mapped to the value they set */
233         public @Nonnull Map<String, Boolean> mSetFlagValues;
234 
FlagAnnotations( @onnull Map<String, Boolean> requiredFlagValues, @Nonnull Map<String, Boolean> setFlagValues)235         FlagAnnotations(
236                 @Nonnull Map<String, Boolean> requiredFlagValues,
237                 @Nonnull Map<String, Boolean> setFlagValues) {
238             mRequiredFlagValues = requiredFlagValues;
239             mSetFlagValues = setFlagValues;
240         }
241 
242         /**
243          * Check that all @RequiresFlagsEnabled and @RequiresFlagsDisabled annotations match the
244          * values from the provider, and if this is not true, throw {@link
245          * org.junit.AssumptionViolatedException}
246          *
247          * @param valueProvider the value provider
248          */
assumeAllRequiredFlagsMatchProvider(IFlagsValueProvider valueProvider)249         public void assumeAllRequiredFlagsMatchProvider(IFlagsValueProvider valueProvider) {
250             for (Map.Entry<String, Boolean> required : mRequiredFlagValues.entrySet()) {
251                 final String flag = required.getKey();
252                 if (required.getValue()) {
253                     assumeTrue(
254                             String.format("Flag %s required to be enabled, but is disabled", flag),
255                             valueProvider.getBoolean(flag));
256                 } else {
257                     assumeFalse(
258                             String.format("Flag %s required to be disabled, but is enabled", flag),
259                             valueProvider.getBoolean(flag));
260                 }
261             }
262         }
263 
264         /**
265          * Check that all @EnableFlags and @DisableFlags annotations match the values contained in
266          * the parameterization (if present), and if this is not true, throw {@link
267          * org.junit.AssumptionViolatedException}
268          *
269          * @param params the parameterization to evaluate against (optional)
270          */
assumeAllSetFlagsMatchParameterization(FlagsParameterization params)271         public void assumeAllSetFlagsMatchParameterization(FlagsParameterization params) {
272             if (params == null) return;
273             for (Map.Entry<String, Boolean> toSet : mSetFlagValues.entrySet()) {
274                 final String flag = toSet.getKey();
275                 final Boolean paramValue = params.mOverrides.get(flag);
276                 if (paramValue == null) continue;
277                 if (toSet.getValue()) {
278                     assumeTrue(
279                             String.format(
280                                     "Flag %s is enabled by annotation but disabled by the current"
281                                             + " FlagsParameterization; skipping test",
282                                     flag),
283                             paramValue);
284                 } else {
285                     assumeFalse(
286                             String.format(
287                                     "Flag %s is disabled by annotation but enabled by the current"
288                                             + " FlagsParameterization; skipping test",
289                                     flag),
290                             paramValue);
291                 }
292             }
293         }
294     }
295 
296     private abstract static class FlagsAnnotation<T extends Annotation> {
297         Class<T> mAnnotationType;
298 
FlagsAnnotation(Class<T> type)299         FlagsAnnotation(Class<T> type) {
300             mAnnotationType = type;
301         }
302 
getFlags(T annotation)303         protected abstract String[] getFlags(T annotation);
304 
305         @Nonnull
getFlagsSet(T annotation)306         Set<String> getFlagsSet(T annotation) {
307             String[] flags = getFlags(annotation);
308             return flags == null ? Set.of() : Set.of(flags);
309         }
310     }
311 
312     private static final FlagsAnnotation<RequiresFlagsEnabled> sRequiresFlagsEnabled =
313             new FlagsAnnotation<>(RequiresFlagsEnabled.class) {
314                 @Override
315                 protected String[] getFlags(RequiresFlagsEnabled annotation) {
316                     return annotation.value();
317                 }
318             };
319     private static final FlagsAnnotation<RequiresFlagsDisabled> sRequiresFlagsDisabled =
320             new FlagsAnnotation<>(RequiresFlagsDisabled.class) {
321                 @Override
322                 protected String[] getFlags(RequiresFlagsDisabled annotation) {
323                     return annotation.value();
324                 }
325             };
326     private static final FlagsAnnotation<EnableFlags> sEnableFlags =
327             new FlagsAnnotation<>(EnableFlags.class) {
328                 @Override
329                 protected String[] getFlags(EnableFlags annotation) {
330                     return annotation.value();
331                 }
332             };
333     private static final FlagsAnnotation<DisableFlags> sDisableFlags =
334             new FlagsAnnotation<>(DisableFlags.class) {
335                 @Override
336                 protected String[] getFlags(DisableFlags annotation) {
337                     return annotation.value();
338                 }
339             };
340     private static final FlagsAnnotation<UsesFlags> sUsesFlags =
341             new FlagsAnnotation<>(UsesFlags.class) {
342                 @Override
343                 protected String[] getFlags(UsesFlags annotation) {
344                     throw new UnsupportedOperationException("call getFlagsSet instead");
345                 }
346 
347                 @Nonnull
348                 Set<String> getFlagsSet(UsesFlags annotation) {
349                     Class<?>[] values = annotation.value();
350                     if (values == null) {
351                         return Set.of();
352                     }
353                     HashSet<String> result = new HashSet<>();
354                     for (Class<?> flagsClass : values) {
355                         result.add(flagsClass.getName());
356                     }
357                     return result;
358                 }
359             };
360 }
361