/*
 * Copyright (C) 2022 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

@file:Suppress("JAVA_MODULE_DOES_NOT_EXPORT_PACKAGE")

package android.processor.immutability

import com.sun.tools.javac.code.Symbol
import com.sun.tools.javac.code.Type
import javax.annotation.processing.AbstractProcessor
import javax.annotation.processing.ProcessingEnvironment
import javax.annotation.processing.RoundEnvironment
import javax.lang.model.SourceVersion
import javax.lang.model.element.Element
import javax.lang.model.element.ElementKind
import javax.lang.model.element.Modifier
import javax.lang.model.element.TypeElement
import javax.lang.model.type.TypeKind
import javax.lang.model.type.TypeMirror
import javax.tools.Diagnostic

/** Fully qualified name of the [Immutable] annotation handled by [ImmutabilityProcessor]. */
val IMMUTABLE_ANNOTATION_NAME = Immutable::class.qualifiedName

/**
 * Annotation processor that checks every element annotated with [Immutable], and the types
 * reachable from it, for immutability violations.
 *
 * Reachable types are those found through non-ignored fields, inner classes, instance methods,
 * method return types, type arguments, and supertypes. Each violation is reported as a
 * [Diagnostic.Kind.ERROR] through the processing environment's messager.
 */
class ImmutabilityProcessor : AbstractProcessor() {

    companion object {

        /**
         * Types that are already immutable. Will also ignore subclasses.
         */
        private val IGNORED_SUPER_TYPES = listOf(
            "java.io.File",
            "java.lang.Boolean",
            "java.lang.Byte",
            "java.lang.CharSequence",
            "java.lang.Character",
            "java.lang.Double",
            "java.lang.Float",
            "java.lang.Integer",
            "java.lang.Long",
            "java.lang.Short",
            "java.lang.String",
            "java.lang.Void",
            "java.util.UUID",
            "android.os.Parcelable.Creator",
        )

        /**
         * Types that are already immutable. Must be an exact match, does not include any super
         * or sub classes.
         */
        private val IGNORED_EXACT_TYPES = listOf(
            "java.lang.Class",
            "java.lang.Object",
        )

        /** Instance method names that are never checked, matched by simple name. */
        private val IGNORED_METHODS = listOf(
            "writeToParcel",
        )
    }

    // Erased java.util.Collection / java.util.Map, used to exempt collection types from the
    // interface check in visitType.
    private lateinit var collectionType: TypeMirror
    private lateinit var mapType: TypeMirror

    // Erased forms of IGNORED_SUPER_TYPES / IGNORED_EXACT_TYPES that resolved in this
    // compilation; names that are not on the classpath are dropped.
    private lateinit var ignoredSuperTypes: List<TypeMirror>
    private lateinit var ignoredExactTypes: List<TypeMirror>

    // Types already visited, keyed by the set of policy exceptions in effect when they were
    // visited, so a type is re-checked only when it is reached under a different policy.
    private val seenTypesByPolicy = mutableMapOf<Set<Immutable.Policy.Exception>, Set<Type>>()

    override fun getSupportedSourceVersion() = SourceVersion.latest()!!

    override fun getSupportedAnnotationTypes() = setOf(Immutable::class.qualifiedName)

    override fun init(processingEnv: ProcessingEnvironment) {
        super.init(processingEnv)
        collectionType = checkNotNull(processingEnv.erasedType("java.util.Collection")) {
            "java.util.Collection could not be resolved"
        }
        mapType = checkNotNull(processingEnv.erasedType("java.util.Map")) {
            "java.util.Map could not be resolved"
        }
        ignoredSuperTypes = IGNORED_SUPER_TYPES.mapNotNull { processingEnv.erasedType(it) }
        ignoredExactTypes = IGNORED_EXACT_TYPES.mapNotNull { processingEnv.erasedType(it) }
    }

    /**
     * Visits every element annotated with [Immutable] in this round.
     *
     * @return false if [Immutable] is not among [annotations], true otherwise (claiming it)
     */
    override fun process(
        annotations: MutableSet<out TypeElement>,
        roundEnvironment: RoundEnvironment
    ): Boolean {
        annotations.find {
            it.qualifiedName.toString() == IMMUTABLE_ANNOTATION_NAME
        } ?: return false
        roundEnvironment.getElementsAnnotatedWith(Immutable::class.java)
            .forEach {
                visitClass(
                    parentChain = emptyList(),
                    seenTypesByPolicy = seenTypesByPolicy,
                    elementToPrint = it,
                    classType = it as Symbol.TypeSymbol,
                    parentPolicyExceptions = emptySet()
                )
            }
        return true
    }

    /**
     * Checks [classType]'s fields, inner classes, instance methods and supertypes.
     *
     * Class-level failures are reported against [elementToPrint].
     *
     * @param parentChain names of the enclosing types/members, used as an error-message prefix
     * @param seenTypesByPolicy visited-type cache; mutated to record [classType]
     * @param parentPolicyExceptions policy exceptions inherited from the referencing context;
     *   merged with [classType]'s own [Immutable.Policy] exceptions
     * @return true if any error was encountered at this level or any child level
     */
    private fun visitClass(
        parentChain: List<String>,
        seenTypesByPolicy: MutableMap<Set<Immutable.Policy.Exception>, Set<Type>>,
        elementToPrint: Element,
        classType: Symbol.TypeSymbol,
        parentPolicyExceptions: Set<Immutable.Policy.Exception>,
    ): Boolean {
        if (isIgnored(classType)) return false

        val policyAnnotation = classType.getAnnotation(Immutable.Policy::class.java)
        val newPolicyExceptions = parentPolicyExceptions + policyAnnotation?.exceptions.orEmpty()

        // If already seen this type with the same policies applied, skip it
        val seenTypes = seenTypesByPolicy[newPolicyExceptions]
        val type = classType.asType()
        if (seenTypes?.contains(type) == true) return false
        seenTypesByPolicy[newPolicyExceptions] = seenTypes.orEmpty() + type

        val allowFinalClassesFinalFields =
            newPolicyExceptions.contains(Immutable.Policy.Exception.FINAL_CLASSES_WITH_FINAL_FIELDS)

        val filteredElements = classType.enclosedElements
            .filterNot(::isIgnored)

        val hasFieldError = filteredElements
            .filter { it.getKind() == ElementKind.FIELD }
            .fold(false) { anyError, field ->
                if (field.isStatic) {
                    // Only non-private static fields are checked
                    if (field.isPrivate) return@fold anyError

                    val finalityError = !field.modifiers.contains(Modifier.FINAL)
                    if (finalityError) {
                        printError(parentChain, field, MessageUtils.staticNonFinalFailure())
                    }

                    // Must call visitType first so it doesn't get short circuited by the ||.
                    // The combined result is returned so static field errors are counted in
                    // hasFieldError (previously it was computed and then discarded).
                    // NOTE(review): this passes parentPolicyExceptions, not newPolicyExceptions,
                    // so this class's own @Immutable.Policy is not applied to its static fields
                    // (visitType still reads the field's own policy); confirm that is intended.
                    return@fold visitType(
                        parentChain = parentChain,
                        seenTypesByPolicy = seenTypesByPolicy,
                        symbol = field,
                        type = field.type,
                        parentPolicyExceptions = parentPolicyExceptions
                    ) || anyError || finalityError
                } else {
                    // Instance fields are only allowed when final and the
                    // FINAL_CLASSES_WITH_FINAL_FIELDS exception is in effect
                    val isFinal = field.modifiers.contains(Modifier.FINAL)
                    if (!isFinal || !allowFinalClassesFinalFields) {
                        printError(parentChain, field, MessageUtils.memberNotMethodFailure())
                        return@fold true
                    }

                    return@fold anyError
                }
            }

        // Scan inner classes before methods so that any violations isolated to the file print
        // the error on the class declaration rather than on the method that returns the type.
        // Although it doesn't matter too much either way.
        val hasClassError = filteredElements
            .filter { it.getKind() == ElementKind.CLASS }
            .map { it as Symbol.ClassSymbol }
            .fold(false) { anyError, innerClass ->
                // Must call visitClass first so it doesn't get short circuited by the ||
                visitClass(
                    parentChain,
                    seenTypesByPolicy,
                    innerClass,
                    innerClass,
                    newPolicyExceptions
                ) || anyError
            }

        val newChain = parentChain + "$classType"

        val hasMethodError = filteredElements
            .asSequence()
            .filter { it.getKind() == ElementKind.METHOD }
            .map { it as Symbol.MethodSymbol }
            .filterNot { it.isStatic }
            .filterNot { IGNORED_METHODS.contains(it.name.toString()) }
            .fold(false) { anyError, method ->
                // Must call visitMethod first so it doesn't get short circuited by the ||
                visitMethod(newChain, seenTypesByPolicy, method, newPolicyExceptions) || anyError
            }

        val className = classType.simpleName.toString()
        val isRegularClass = classType.getKind() == ElementKind.CLASS

        var anyError = hasFieldError || hasClassError || hasMethodError

        // If final classes are not considered OR there's a non-field failure, also check for
        // interface/@Immutable, assuming the class is malformed
        if ((isRegularClass && !allowFinalClassesFinalFields) || hasMethodError || hasClassError) {
            if (classType.getAnnotation(Immutable::class.java) == null) {
                printError(
                    parentChain,
                    elementToPrint,
                    MessageUtils.classNotImmutableFailure(className)
                )
                anyError = true
            }

            if (classType.getKind() != ElementKind.INTERFACE) {
                printError(parentChain, elementToPrint, MessageUtils.nonInterfaceClassFailure())
                anyError = true
            }
        }

        // Check all of the super classes, since methods in those classes are also accessible.
        // Their results are not folded into anyError; their errors are still reported through
        // printError.
        (classType as? Symbol.ClassSymbol)?.run {
            (interfaces + superclass).forEach {
                val element = it.asElement() ?: return@forEach
                visitClass(parentChain, seenTypesByPolicy, element, element, newPolicyExceptions)
            }
        }

        if (isRegularClass && !anyError && allowFinalClassesFinalFields &&
            !classType.modifiers.contains(Modifier.FINAL)
        ) {
            printError(parentChain, elementToPrint, MessageUtils.classNotFinalFailure(className))
            return true
        }

        return anyError
    }

    /**
     * Checks the return type of an instance [method].
     *
     * Outcomes by return kind:
     * - Primitive, NONE and NULL kinds pass.
     * - `void` fails unless the method is a constructor.
     * - Arrays fail.
     * - Declared types are checked recursively through [visitType].
     * - Every other kind fails.
     *
     * @return true if any error was encountered at this level or any child level
     */
    private fun visitMethod(
        parentChain: List<String>,
        seenTypesByPolicy: MutableMap<Set<Immutable.Policy.Exception>, Set<Type>>,
        method: Symbol.MethodSymbol,
        parentPolicyExceptions: Set<Immutable.Policy.Exception>,
    ): Boolean {
        val returnType = method.returnType
        when (returnType.kind) {
            TypeKind.BOOLEAN,
            TypeKind.BYTE,
            TypeKind.SHORT,
            TypeKind.INT,
            TypeKind.LONG,
            TypeKind.CHAR,
            TypeKind.FLOAT,
            TypeKind.DOUBLE,
            TypeKind.NONE,
            TypeKind.NULL -> {
                // Nothing further to check
            }
            TypeKind.VOID -> {
                if (!method.isConstructor) {
                    printError(parentChain, method, MessageUtils.voidReturnFailure())
                    return true
                }
            }
            TypeKind.ARRAY -> {
                printError(parentChain, method, MessageUtils.arrayFailure())
                return true
            }
            TypeKind.DECLARED -> {
                return visitType(
                    parentChain,
                    seenTypesByPolicy,
                    method,
                    returnType,
                    parentPolicyExceptions
                )
            }
            // ERROR, TYPEVAR, WILDCARD, PACKAGE, EXECUTABLE, OTHER, UNION, INTERSECTION,
            // MODULE (Java 9+), any kind added later, and a null kind all fail the same way
            else -> {
                printError(
                    parentChain, method,
                    MessageUtils.genericTypeKindFailure(typeName = returnType.toString())
                )
                return true
            }
        }

        return false
    }

    /**
     * Checks a [type] referenced by [symbol] (a field type, method return type or type argument).
     *
     * Collection and Map types skip the interface check; their type arguments are checked
     * recursively instead.
     *
     * @param nonInterfaceClassFailure builds the message used when [type] is a non-interface
     *   class and FINAL_CLASSES_WITH_FINAL_FIELDS is not in effect
     * @return true if any error was encountered at this level or any child level
     */
    private fun visitType(
        parentChain: List<String>,
        seenTypesByPolicy: MutableMap<Set<Immutable.Policy.Exception>, Set<Type>>,
        symbol: Symbol,
        type: Type,
        parentPolicyExceptions: Set<Immutable.Policy.Exception>,
        nonInterfaceClassFailure: () -> String = { MessageUtils.nonInterfaceReturnFailure() },
    ): Boolean {
        // Skip if the symbol being considered is itself ignored
        if (isIgnored(symbol)) return false

        // Skip if the type being checked, like for a typeArg or return type, is ignored
        if (isIgnored(type)) return false

        // Skip if that typeArg is itself ignored when inspected at the class header level
        if (isIgnored(type.asElement())) return false

        if (type.isPrimitive) return false
        if (type.isPrimitiveOrVoid) {
            printError(parentChain, symbol, MessageUtils.voidReturnFailure())
            return true
        }

        val policyAnnotation = symbol.getAnnotation(Immutable.Policy::class.java)
        val newPolicyExceptions = parentPolicyExceptions + policyAnnotation?.exceptions.orEmpty()

        // Collection (and Map) types are ignored for the interface check as they have immutability
        // enforced through a runtime exception which must be verified in a separate runtime test
        val isMap = processingEnv.typeUtils.isAssignable(type, mapType)
        if (!processingEnv.typeUtils.isAssignable(type, collectionType) && !isMap) {
            if (!type.isInterface && !newPolicyExceptions
                    .contains(Immutable.Policy.Exception.FINAL_CLASSES_WITH_FINAL_FIELDS)
            ) {
                printError(parentChain, symbol, nonInterfaceClassFailure())
                return true
            } else {
                return visitClass(
                    parentChain, seenTypesByPolicy, symbol,
                    processingEnv.typeUtils.asElement(type) as Symbol.TypeSymbol,
                    newPolicyExceptions,
                )
            }
        }

        var anyError = false

        type.typeArguments.forEachIndexed { index, typeArg ->
            if (isIgnored(typeArg.asElement())) return@forEachIndexed

            val argError =
                visitType(parentChain, seenTypesByPolicy, symbol, typeArg, newPolicyExceptions) {
                    MessageUtils.nonInterfaceReturnFailure(
                        prefix = when {
                            !isMap -> ""
                            index == 0 -> "Key " + typeArg.asElement().simpleName
                            else -> "Value " + typeArg.asElement().simpleName
                        }, index = index
                    )
                }
            anyError = anyError || argError
        }

        return anyError
    }

    /**
     * Reports [message] as a compiler error on [element].
     *
     * The message is prefixed with [parentChain] plus the element's simple name.
     */
    private fun printError(
        parentChain: List<String>,
        element: Element,
        message: String,
    ) = processingEnv.messager.printMessage(
        Diagnostic.Kind.ERROR,
        parentChain.plus(element.simpleName).joinToString() + "\n\t " + message,
        element,
    )

    /** Erased type of the class named [typeName], or null if it cannot be resolved. */
    private fun ProcessingEnvironment.erasedType(typeName: String) =
        elementUtils.getTypeElement(typeName)?.asType()?.let(typeUtils::erasure)

    /**
     * True if [type] should be skipped.
     *
     * A type is skipped when it carries [Immutable.Ignore], is assignable to an ignored super
     * type, or exactly matches an ignored type.
     */
    private fun isIgnored(type: Type) =
        (type.getAnnotation(Immutable.Ignore::class.java) != null)
                || (ignoredSuperTypes.any { type.isAssignable(it) })
                || (ignoredExactTypes.any { type.isSameType(it) })

    /** True if [symbol] should be skipped by every visitor. */
    private fun isIgnored(symbol: Symbol) = when {
        // Anything annotated as @Ignore is always ignored
        symbol.getAnnotation(Immutable.Ignore::class.java) != null -> true
        // Then ignore exact types, regardless of what kind they are
        ignoredExactTypes.any { symbol.type.isSameType(it) } -> true
        // Then only allow methods through, since other types (fields) are usually a failure
        symbol.getKind() != ElementKind.METHOD -> false
        // Finally, check for any ignored super types
        else -> ignoredSuperTypes.any { symbol.type.isAssignable(it) }
    }

    /** [javax.lang.model.util.Types.isAssignable], treating any thrown exception as false. */
    private fun TypeMirror.isAssignable(type: TypeMirror) = try {
        processingEnv.typeUtils.isAssignable(this, type)
    } catch (ignored: Exception) {
        false
    }

    /** [javax.lang.model.util.Types.isSameType], treating any thrown exception as false. */
    private fun TypeMirror.isSameType(type: TypeMirror) = try {
        processingEnv.typeUtils.isSameType(this, type)
    } catch (ignored: Exception) {
        false
    }
}