1 /* 2 * Copyright (C) 2023 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 package android.platform.test.flag.junit; 18 19 import static org.junit.Assume.assumeFalse; 20 import static org.junit.Assume.assumeTrue; 21 22 import android.platform.test.annotations.DisableFlags; 23 import android.platform.test.annotations.EnableFlags; 24 import android.platform.test.annotations.RequiresFlagsDisabled; 25 import android.platform.test.annotations.RequiresFlagsEnabled; 26 import android.platform.test.annotations.UsesFlags; 27 28 import com.google.common.collect.Sets; 29 30 import org.junit.Test; 31 import org.junit.runner.Description; 32 33 import java.lang.annotation.Annotation; 34 import java.lang.annotation.Documented; 35 import java.lang.annotation.Retention; 36 import java.lang.annotation.Target; 37 import java.util.ArrayDeque; 38 import java.util.Collection; 39 import java.util.HashMap; 40 import java.util.HashSet; 41 import java.util.List; 42 import java.util.Map; 43 import java.util.Objects; 44 import java.util.Queue; 45 import java.util.Set; 46 import java.util.stream.Collectors; 47 48 import javax.annotation.Nonnull; 49 50 /** 51 * Retrieves feature flag related annotations from a given {@code Description}. 52 * 53 * <p>For each annotation, it tries to get it from the method first, then tries to get it from the 54 * test class if the method has no such annotation.
55 */ 56 public class AnnotationsRetriever { 57 private static final Set<Class<?>> KNOWN_UNRELATED_ANNOTATIONS = 58 Set.of(Test.class, Retention.class, Target.class, Documented.class); 59 AnnotationsRetriever()60 private AnnotationsRetriever() {} 61 getAllUsedFlagsClasses(Description description)62 public static Set<String> getAllUsedFlagsClasses(Description description) { 63 Set<String> result = new HashSet<>(); 64 Set<Object> visited = new HashSet<>(); 65 collectUsedFlagsClasses(description, visited, result); 66 return result; 67 } 68 collectUsedFlagsClasses( Description description, Set<Object> visited, Set<String> result)69 private static void collectUsedFlagsClasses( 70 Description description, Set<Object> visited, Set<String> result) { 71 Class<?> testClass = description.getTestClass(); 72 if (testClass != null && !visited.contains(testClass)) { 73 visited.add(testClass); 74 collectUsedFlagsClasses(List.of(testClass.getAnnotations()), result); 75 } 76 collectUsedFlagsClasses(description.getAnnotations(), result); 77 for (Description child : description.getChildren()) { 78 collectUsedFlagsClasses(child, visited, result); 79 } 80 } 81 collectUsedFlagsClasses( Collection<Annotation> annotations, Set<String> result)82 private static void collectUsedFlagsClasses( 83 Collection<Annotation> annotations, Set<String> result) { 84 result.addAll(getFlagsForAnnotation(sUsesFlags, annotations)); 85 } 86 getAllAnnotationSetFlags(Description description)87 public static Set<String> getAllAnnotationSetFlags(Description description) { 88 Set<String> result = new HashSet<>(); 89 Set<Object> visited = new HashSet<>(); 90 collectAnnotationSetFlags(description, visited, result); 91 return result; 92 } 93 collectAnnotationSetFlags( Description description, Set<Object> visited, Set<String> result)94 private static void collectAnnotationSetFlags( 95 Description description, Set<Object> visited, Set<String> result) { 96 Class<?> testClass = description.getTestClass(); 97 if 
(testClass != null && !visited.contains(testClass)) { 98 visited.add(testClass); 99 collectAnnotationSetFlags(List.of(testClass.getAnnotations()), result); 100 } 101 collectAnnotationSetFlags(description.getAnnotations(), result); 102 for (Description child : description.getChildren()) { 103 collectAnnotationSetFlags(child, visited, result); 104 } 105 } 106 collectAnnotationSetFlags( Collection<Annotation> annotations, Set<String> result)107 private static void collectAnnotationSetFlags( 108 Collection<Annotation> annotations, Set<String> result) { 109 result.addAll(getFlagsForAnnotation(sEnableFlags, annotations)); 110 result.addAll(getFlagsForAnnotation(sDisableFlags, annotations)); 111 } 112 113 /** Gets all feature flag related annotations. */ getFlagAnnotations(Description description)114 public static FlagAnnotations getFlagAnnotations(Description description) { 115 final Map<String, Boolean> requiresFlagValues = 116 getMergedFlagValues(sRequiresFlagsEnabled, sRequiresFlagsDisabled, description); 117 final Map<String, Boolean> setsFlagValues = 118 getMergedFlagValues(sEnableFlags, sDisableFlags, description); 119 120 // Assert that no flag is defined in both maps 121 Set<String> inconsistentFlags = 122 Sets.intersection(requiresFlagValues.keySet(), setsFlagValues.keySet()); 123 if (!inconsistentFlags.isEmpty()) { 124 throw new AssertionError( 125 "The following flags are both required and set: " + inconsistentFlags); 126 } 127 128 return new FlagAnnotations(requiresFlagValues, setsFlagValues); 129 } 130 getMergedFlagValues( FlagsAnnotation<? extends Annotation> enabledAnnotation, FlagsAnnotation<? extends Annotation> disabledAnnotation, Description description)131 private static Map<String, Boolean> getMergedFlagValues( 132 FlagsAnnotation<? extends Annotation> enabledAnnotation, 133 FlagsAnnotation<? 
extends Annotation> disabledAnnotation, 134 Description description) { 135 final Map<String, Boolean> methodFlagValues = 136 getFlagValues( 137 description.getMethodName(), 138 enabledAnnotation, 139 disabledAnnotation, 140 description.getAnnotations()); 141 Class<?> testClass = description.getTestClass(); 142 final Map<String, Boolean> classFlagValues = 143 testClass == null 144 ? new HashMap<>() 145 : getFlagValues( 146 testClass.getName(), 147 enabledAnnotation, 148 disabledAnnotation, 149 List.of(testClass.getAnnotations())); 150 Sets.SetView<String> doublyDefinedFlags = 151 Sets.intersection(classFlagValues.keySet(), methodFlagValues.keySet()); 152 if (!doublyDefinedFlags.isEmpty()) { 153 List<String> mismatchedFlags = 154 doublyDefinedFlags.stream() 155 .filter( 156 flag -> 157 !Objects.equals( 158 classFlagValues.get(flag), 159 methodFlagValues.get(flag))) 160 .collect(Collectors.toList()); 161 if (!mismatchedFlags.isEmpty()) { 162 throw new AssertionError( 163 "The following flags are required by " 164 + description.getMethodName() 165 + " and " 166 + description.getClassName() 167 + " to be both enabled and disabled: " 168 + mismatchedFlags); 169 } 170 } 171 // Now override the class values with the method values to produce a merged map 172 classFlagValues.putAll(methodFlagValues); 173 return classFlagValues; 174 } 175 getFlagValues( @onnull String annotationTarget, @Nonnull FlagsAnnotation<? extends Annotation> enabledAnnotation, @Nonnull FlagsAnnotation<? extends Annotation> disabledAnnotation, @Nonnull Collection<Annotation> annotations)176 private static Map<String, Boolean> getFlagValues( 177 @Nonnull String annotationTarget, 178 @Nonnull FlagsAnnotation<? extends Annotation> enabledAnnotation, 179 @Nonnull FlagsAnnotation<? 
extends Annotation> disabledAnnotation, 180 @Nonnull Collection<Annotation> annotations) { 181 Set<String> enabledFlags = getFlagsForAnnotation(enabledAnnotation, annotations); 182 Set<String> disabledFlags = getFlagsForAnnotation(disabledAnnotation, annotations); 183 if (enabledFlags.isEmpty() && disabledFlags.isEmpty()) { 184 return new HashMap<>(); 185 } 186 Set<String> inconsistentFlags = Sets.intersection(enabledFlags, disabledFlags); 187 if (!inconsistentFlags.isEmpty()) { 188 throw new AssertionError( 189 "The following flags are required by " 190 + annotationTarget 191 + " to be both enabled and disabled: " 192 + inconsistentFlags); 193 } 194 HashMap<String, Boolean> result = new HashMap<>(); 195 for (String enabledFlag : enabledFlags) { 196 result.put(enabledFlag, true); 197 } 198 for (String disabledFlag : disabledFlags) { 199 result.put(disabledFlag, false); 200 } 201 return result; 202 } 203 204 @Nonnull getFlagsForAnnotation( FlagsAnnotation<T> flagsAnnotation, Collection<Annotation> annotations)205 private static <T extends Annotation> Set<String> getFlagsForAnnotation( 206 FlagsAnnotation<T> flagsAnnotation, Collection<Annotation> annotations) { 207 Class<T> annotationType = flagsAnnotation.mAnnotationType; 208 Set<String> results = new HashSet<>(); 209 Queue<Annotation> annotationQueue = new ArrayDeque<>(); 210 Set<Class<? extends Annotation>> visitedAnnotations = new HashSet<>(); 211 annotationQueue.addAll(annotations); 212 while (!annotationQueue.isEmpty()) { 213 Annotation annotation = annotationQueue.poll(); 214 Class<? 
extends Annotation> currentAnnotationType = annotation.annotationType(); 215 if (currentAnnotationType.equals(annotationType)) { 216 results.addAll(flagsAnnotation.getFlagsSet((T) annotation)); 217 } else if (!KNOWN_UNRELATED_ANNOTATIONS.contains(currentAnnotationType) 218 && !visitedAnnotations.contains(currentAnnotationType)) { 219 annotationQueue.addAll(List.of(annotation.annotationType().getAnnotations())); 220 visitedAnnotations.add(currentAnnotationType); 221 } 222 } 223 return results; 224 } 225 226 /** Contains all feature flag related annotations. */ 227 public static class FlagAnnotations { 228 229 /** The flag names which have required values, mapped to the value they require */ 230 public @Nonnull Map<String, Boolean> mRequiredFlagValues; 231 232 /** The flag names which have values to be set, mapped to the value they set */ 233 public @Nonnull Map<String, Boolean> mSetFlagValues; 234 FlagAnnotations( @onnull Map<String, Boolean> requiredFlagValues, @Nonnull Map<String, Boolean> setFlagValues)235 FlagAnnotations( 236 @Nonnull Map<String, Boolean> requiredFlagValues, 237 @Nonnull Map<String, Boolean> setFlagValues) { 238 mRequiredFlagValues = requiredFlagValues; 239 mSetFlagValues = setFlagValues; 240 } 241 242 /** 243 * Check that all @RequiresFlagsEnabled and @RequiresFlagsDisabled annotations match the 244 * values from the provider, and if this is not true, throw {@link 245 * org.junit.AssumptionViolatedException} 246 * 247 * @param valueProvider the value provider 248 */ assumeAllRequiredFlagsMatchProvider(IFlagsValueProvider valueProvider)249 public void assumeAllRequiredFlagsMatchProvider(IFlagsValueProvider valueProvider) { 250 for (Map.Entry<String, Boolean> required : mRequiredFlagValues.entrySet()) { 251 final String flag = required.getKey(); 252 if (required.getValue()) { 253 assumeTrue( 254 String.format("Flag %s required to be enabled, but is disabled", flag), 255 valueProvider.getBoolean(flag)); 256 } else { 257 assumeFalse( 258 
String.format("Flag %s required to be disabled, but is enabled", flag), 259 valueProvider.getBoolean(flag)); 260 } 261 } 262 } 263 264 /** 265 * Check that all @EnableFlags and @DisableFlags annotations match the values contained in 266 * the parameterization (if present), and if this is not true, throw {@link 267 * org.junit.AssumptionViolatedException} 268 * 269 * @param params the parameterization to evaluate against (optional) 270 */ assumeAllSetFlagsMatchParameterization(FlagsParameterization params)271 public void assumeAllSetFlagsMatchParameterization(FlagsParameterization params) { 272 if (params == null) return; 273 for (Map.Entry<String, Boolean> toSet : mSetFlagValues.entrySet()) { 274 final String flag = toSet.getKey(); 275 final Boolean paramValue = params.mOverrides.get(flag); 276 if (paramValue == null) continue; 277 if (toSet.getValue()) { 278 assumeTrue( 279 String.format( 280 "Flag %s is enabled by annotation but disabled by the current" 281 + " FlagsParameterization; skipping test", 282 flag), 283 paramValue); 284 } else { 285 assumeFalse( 286 String.format( 287 "Flag %s is disabled by annotation but enabled by the current" 288 + " FlagsParameterization; skipping test", 289 flag), 290 paramValue); 291 } 292 } 293 } 294 } 295 296 private abstract static class FlagsAnnotation<T extends Annotation> { 297 Class<T> mAnnotationType; 298 FlagsAnnotation(Class<T> type)299 FlagsAnnotation(Class<T> type) { 300 mAnnotationType = type; 301 } 302 getFlags(T annotation)303 protected abstract String[] getFlags(T annotation); 304 305 @Nonnull getFlagsSet(T annotation)306 Set<String> getFlagsSet(T annotation) { 307 String[] flags = getFlags(annotation); 308 return flags == null ? 
Set.of() : Set.of(flags); 309 } 310 } 311 312 private static final FlagsAnnotation<RequiresFlagsEnabled> sRequiresFlagsEnabled = 313 new FlagsAnnotation<>(RequiresFlagsEnabled.class) { 314 @Override 315 protected String[] getFlags(RequiresFlagsEnabled annotation) { 316 return annotation.value(); 317 } 318 }; 319 private static final FlagsAnnotation<RequiresFlagsDisabled> sRequiresFlagsDisabled = 320 new FlagsAnnotation<>(RequiresFlagsDisabled.class) { 321 @Override 322 protected String[] getFlags(RequiresFlagsDisabled annotation) { 323 return annotation.value(); 324 } 325 }; 326 private static final FlagsAnnotation<EnableFlags> sEnableFlags = 327 new FlagsAnnotation<>(EnableFlags.class) { 328 @Override 329 protected String[] getFlags(EnableFlags annotation) { 330 return annotation.value(); 331 } 332 }; 333 private static final FlagsAnnotation<DisableFlags> sDisableFlags = 334 new FlagsAnnotation<>(DisableFlags.class) { 335 @Override 336 protected String[] getFlags(DisableFlags annotation) { 337 return annotation.value(); 338 } 339 }; 340 private static final FlagsAnnotation<UsesFlags> sUsesFlags = 341 new FlagsAnnotation<>(UsesFlags.class) { 342 @Override 343 protected String[] getFlags(UsesFlags annotation) { 344 throw new UnsupportedOperationException("call getFlagsSet instead"); 345 } 346 347 @Nonnull 348 Set<String> getFlagsSet(UsesFlags annotation) { 349 Class<?>[] values = annotation.value(); 350 if (values == null) { 351 return Set.of(); 352 } 353 HashSet<String> result = new HashSet<>(); 354 for (Class<?> flagsClass : values) { 355 result.add(flagsClass.getName()); 356 } 357 return result; 358 } 359 }; 360 } 361