1 /*
 * Copyright (C) 2022 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 @file:Suppress("JAVA_MODULE_DOES_NOT_EXPORT_PACKAGE")
18 
19 package android.processor.immutability
20 
21 import com.sun.tools.javac.code.Symbol
22 import com.sun.tools.javac.code.Type
23 import javax.annotation.processing.AbstractProcessor
24 import javax.annotation.processing.ProcessingEnvironment
25 import javax.annotation.processing.RoundEnvironment
26 import javax.lang.model.SourceVersion
27 import javax.lang.model.element.Element
28 import javax.lang.model.element.ElementKind
29 import javax.lang.model.element.Modifier
30 import javax.lang.model.element.TypeElement
31 import javax.lang.model.type.TypeKind
32 import javax.lang.model.type.TypeMirror
33 import javax.tools.Diagnostic
34 
/**
 * Fully-qualified name of the [Immutable] annotation, used to recognize it among the annotations
 * handed to the processor in a round. Nullable because [kotlin.reflect.KClass.qualifiedName] is.
 */
val IMMUTABLE_ANNOTATION_NAME: String? = Immutable::class.qualifiedName
36 
/**
 * Annotation processor that validates elements annotated with [Immutable].
 *
 * Starting from each annotated element, it walks fields, nested classes, non-static methods and
 * supertypes. It reports a compiler error for:
 * - non-final, non-private static fields;
 * - instance fields, unless the [Immutable.Policy.Exception.FINAL_CLASSES_WITH_FINAL_FIELDS]
 *   policy is in effect and the field is final;
 * - `void` or array method returns;
 * - returned types that are not interfaces, unless the policy above allows them.
 *
 * Anything annotated with [Immutable.Ignore], or matching [IGNORED_SUPER_TYPES] /
 * [IGNORED_EXACT_TYPES], is skipped.
 */
class ImmutabilityProcessor : AbstractProcessor() {

    companion object {

        /**
         * Types that are already immutable. Will also ignore subclasses.
         */
        private val IGNORED_SUPER_TYPES = listOf(
            "java.io.File",
            "java.lang.Boolean",
            "java.lang.Byte",
            "java.lang.CharSequence",
            "java.lang.Character",
            "java.lang.Double",
            "java.lang.Float",
            "java.lang.Integer",
            "java.lang.Long",
            "java.lang.Short",
            "java.lang.String",
            "java.lang.Void",
            "java.util.UUID",
            "android.os.Parcelable.Creator",
        )

        /**
         * Types that are already immutable. Must be an exact match, does not include any super
         * or sub classes.
         */
        private val IGNORED_EXACT_TYPES = listOf(
            "java.lang.Class",
            "java.lang.Object",
        )

        /** Names of methods that are skipped when checking a class's methods. */
        private val IGNORED_METHODS = listOf(
            "writeToParcel",
        )
    }

    private lateinit var collectionType: TypeMirror
    private lateinit var mapType: TypeMirror

    private lateinit var ignoredSuperTypes: List<TypeMirror>
    private lateinit var ignoredExactTypes: List<TypeMirror>

    /**
     * Types already visited, keyed by the set of policy exceptions in effect at the time. A type
     * reached again under a different set of exceptions is visited again.
     */
    private val seenTypesByPolicy =
        mutableMapOf<Set<Immutable.Policy.Exception>, MutableSet<Type>>()

    override fun getSupportedSourceVersion(): SourceVersion = SourceVersion.latest()

    override fun getSupportedAnnotationTypes() = setOf(IMMUTABLE_ANNOTATION_NAME)

    override fun init(processingEnv: ProcessingEnvironment) {
        super.init(processingEnv)
        collectionType = checkNotNull(processingEnv.erasedType("java.util.Collection")) {
            "Unable to resolve java.util.Collection"
        }
        mapType = checkNotNull(processingEnv.erasedType("java.util.Map")) {
            "Unable to resolve java.util.Map"
        }
        // Ignored types that cannot be resolved in this compilation are dropped, not fatal
        ignoredSuperTypes = IGNORED_SUPER_TYPES.mapNotNull { processingEnv.erasedType(it) }
        ignoredExactTypes = IGNORED_EXACT_TYPES.mapNotNull { processingEnv.erasedType(it) }
    }

    /**
     * Checks every element annotated with [Immutable] in this round.
     *
     * @return false when [Immutable] is not among [annotations]; otherwise true, claiming it
     */
    override fun process(
        annotations: MutableSet<out TypeElement>,
        roundEnvironment: RoundEnvironment
    ): Boolean {
        if (annotations.none { it.qualifiedName.toString() == IMMUTABLE_ANNOTATION_NAME }) {
            return false
        }
        roundEnvironment.getElementsAnnotatedWith(Immutable::class.java)
            .forEach {
                visitClass(
                    parentChain = emptyList(),
                    seenTypesByPolicy = seenTypesByPolicy,
                    elementToPrint = it,
                    classType = it as Symbol.TypeSymbol,
                    parentPolicyExceptions = emptySet()
                )
            }
        return true
    }

    /**
     * Checks one class or interface and, recursively, its fields, nested classes (kind CLASS),
     * non-static methods (except [IGNORED_METHODS]) and supertypes.
     *
     * @param parentChain names of the types leading here; prefixed to each error message
     * @param seenTypesByPolicy visited types per policy set; updated in place to stop re-visits
     * @param elementToPrint element that class-level errors are reported against
     * @param classType the type being checked
     * @param parentPolicyExceptions exceptions inherited from the caller; combined with any
     *   [Immutable.Policy] on [classType]
     * @return true if any error was encountered at this level or any child level
     */
    private fun visitClass(
        parentChain: List<String>,
        seenTypesByPolicy: MutableMap<Set<Immutable.Policy.Exception>, MutableSet<Type>>,
        elementToPrint: Element,
        classType: Symbol.TypeSymbol,
        parentPolicyExceptions: Set<Immutable.Policy.Exception>,
    ): Boolean {
        if (isIgnored(classType)) return false

        val policyAnnotation = classType.getAnnotation(Immutable.Policy::class.java)
        val newPolicyExceptions = parentPolicyExceptions + policyAnnotation?.exceptions.orEmpty()

        // If already seen this type with the same policies applied, skip it
        val type = classType.asType()
        val firstVisit = seenTypesByPolicy.getOrPut(newPolicyExceptions) { mutableSetOf() }.add(type)
        if (!firstVisit) return false

        val allowFinalClassesFinalFields =
            newPolicyExceptions.contains(Immutable.Policy.Exception.FINAL_CLASSES_WITH_FINAL_FIELDS)

        val filteredElements = classType.enclosedElements
            .filterNot(::isIgnored)

        val hasFieldError = filteredElements
            .filter { it.getKind() == ElementKind.FIELD }
            .fold(false) { anyError, field ->
                if (field.isStatic) {
                    // Private static fields are not checked
                    if (field.isPrivate) return@fold anyError

                    val finalityError = !field.modifiers.contains(Modifier.FINAL)
                    if (finalityError) {
                        printError(parentChain, field, MessageUtils.staticNonFinalFailure())
                    }

                    // Evaluated separately so it is never short circuited by the ||.
                    // NOTE(review): this passes parentPolicyExceptions, whereas nested classes and
                    // methods receive newPolicyExceptions — confirm this is intentional.
                    val typeError = visitType(
                        parentChain = parentChain,
                        seenTypesByPolicy = seenTypesByPolicy,
                        symbol = field,
                        type = field.type,
                        parentPolicyExceptions = parentPolicyExceptions
                    )
                    typeError || anyError || finalityError
                } else {
                    val isFinal = field.modifiers.contains(Modifier.FINAL)
                    if (!isFinal || !allowFinalClassesFinalFields) {
                        printError(parentChain, field, MessageUtils.memberNotMethodFailure())
                        true
                    } else {
                        anyError
                    }
                }
            }

        // Scan inner classes before methods so that any violations isolated to the file prints
        // the error on the class declaration rather than on the method that returns the type.
        // Although it doesn't matter too much either way.
        val hasClassError = filteredElements
            .filter { it.getKind() == ElementKind.CLASS }
            .map { it as Symbol.ClassSymbol }
            .fold(false) { anyError, innerClass ->
                // Must call visitClass first so it doesn't get short circuited by the ||
                visitClass(
                    parentChain,
                    seenTypesByPolicy,
                    innerClass,
                    innerClass,
                    newPolicyExceptions
                ) || anyError
            }

        val newChain = parentChain + "$classType"

        val hasMethodError = filteredElements
            .asSequence()
            .filter { it.getKind() == ElementKind.METHOD }
            .map { it as Symbol.MethodSymbol }
            .filterNot { it.isStatic }
            .filterNot { IGNORED_METHODS.contains(it.name.toString()) }
            .fold(false) { anyError, method ->
                // Must call visitMethod first so it doesn't get short circuited by the ||
                visitMethod(newChain, seenTypesByPolicy, method, newPolicyExceptions) || anyError
            }

        val className = classType.simpleName.toString()
        val isRegularClass = classType.getKind() == ElementKind.CLASS

        var anyError = hasFieldError || hasClassError || hasMethodError

        // If final classes are not considered OR there's a non-field failure, also check for
        // interface/@Immutable, assuming the class is malformed
        if ((isRegularClass && !allowFinalClassesFinalFields) || hasMethodError || hasClassError) {
            if (classType.getAnnotation(Immutable::class.java) == null) {
                printError(
                    parentChain,
                    elementToPrint,
                    MessageUtils.classNotImmutableFailure(className)
                )
                anyError = true
            }

            if (classType.getKind() != ElementKind.INTERFACE) {
                printError(parentChain, elementToPrint, MessageUtils.nonInterfaceClassFailure())
                anyError = true
            }
        }

        // Check all of the super classes, since methods in those classes are also accessible.
        // NOTE(review): errors found here are printed but not folded into anyError; confirm
        // whether they should also suppress the not-final check below and affect the result.
        (classType as? Symbol.ClassSymbol)?.run {
            (interfaces + superclass).forEach {
                // Supertypes without a symbol (e.g. the superclass of an interface) are skipped
                val element = it.asElement() ?: return@forEach
                visitClass(parentChain, seenTypesByPolicy, element, element, newPolicyExceptions)
            }
        }

        // Under the final-classes policy, an otherwise clean regular class must itself be final
        if (isRegularClass && !anyError && allowFinalClassesFinalFields &&
            !classType.modifiers.contains(Modifier.FINAL)
        ) {
            printError(parentChain, elementToPrint, MessageUtils.classNotFinalFailure(className))
            return true
        }

        return anyError
    }

    /**
     * Checks a method's return type: primitives pass, `void` and arrays fail, declared types are
     * checked recursively via [visitType], and any other type kind is reported as unsupported.
     *
     * @return true if any error was encountered at this level or any child level
     */
    private fun visitMethod(
        parentChain: List<String>,
        seenTypesByPolicy: MutableMap<Set<Immutable.Policy.Exception>, MutableSet<Type>>,
        method: Symbol.MethodSymbol,
        parentPolicyExceptions: Set<Immutable.Policy.Exception>,
    ): Boolean {
        val returnType = method.returnType
        return when (returnType.kind) {
            TypeKind.BOOLEAN,
            TypeKind.BYTE,
            TypeKind.SHORT,
            TypeKind.INT,
            TypeKind.LONG,
            TypeKind.CHAR,
            TypeKind.FLOAT,
            TypeKind.DOUBLE,
            TypeKind.NONE,
            TypeKind.NULL -> false
            TypeKind.VOID -> {
                if (method.isConstructor) {
                    false
                } else {
                    printError(parentChain, method, MessageUtils.voidReturnFailure())
                    true
                }
            }
            TypeKind.ARRAY -> {
                printError(parentChain, method, MessageUtils.arrayFailure())
                true
            }
            TypeKind.DECLARED -> visitType(
                parentChain,
                seenTypesByPolicy,
                method,
                returnType,
                parentPolicyExceptions
            )
            // ERROR, TYPEVAR, WILDCARD, PACKAGE, EXECUTABLE, OTHER, UNION, INTERSECTION, any newer
            // kinds, and a null kind are all unsupported
            else -> {
                printError(
                    parentChain, method,
                    MessageUtils.genericTypeKindFailure(typeName = returnType.toString())
                )
                true
            }
        }
    }

    /**
     * Checks a type referenced by [symbol] (a field type, return type, or type argument).
     *
     * Collections and maps skip the interface check and have their type arguments checked
     * instead. Other types must be interfaces unless the final-classes policy is in effect, in
     * which case the type's class is checked via [visitClass].
     *
     * @param nonInterfaceClassFailure builds the message reported when [type] is not an interface
     * @return true if any error was encountered at this level or any child level
     */
    private fun visitType(
        parentChain: List<String>,
        seenTypesByPolicy: MutableMap<Set<Immutable.Policy.Exception>, MutableSet<Type>>,
        symbol: Symbol,
        type: Type,
        parentPolicyExceptions: Set<Immutable.Policy.Exception>,
        nonInterfaceClassFailure: () -> String = { MessageUtils.nonInterfaceReturnFailure() },
    ): Boolean {
        // Skip if the symbol being considered is itself ignored
        if (isIgnored(symbol)) return false

        // Skip if the type being checked, like for a typeArg or return type, is ignored
        if (isIgnored(type)) return false

        // Skip if that typeArg is itself ignored when inspected at the class header level
        if (isIgnored(type.asElement())) return false

        // Primitives pass; the only primitive-or-void type left at this point is void
        if (type.isPrimitive) return false
        if (type.isPrimitiveOrVoid) {
            printError(parentChain, symbol, MessageUtils.voidReturnFailure())
            return true
        }

        val policyAnnotation = symbol.getAnnotation(Immutable.Policy::class.java)
        val newPolicyExceptions = parentPolicyExceptions + policyAnnotation?.exceptions.orEmpty()

        // Collection (and Map) types are ignored for the interface check as they have immutability
        // enforced through a runtime exception which must be verified in a separate runtime test
        val isMap = processingEnv.typeUtils.isAssignable(type, mapType)
        val isCollection = processingEnv.typeUtils.isAssignable(type, collectionType)
        if (!isCollection && !isMap) {
            if (!type.isInterface && !newPolicyExceptions
                    .contains(Immutable.Policy.Exception.FINAL_CLASSES_WITH_FINAL_FIELDS)
            ) {
                printError(parentChain, symbol, nonInterfaceClassFailure())
                return true
            }

            // Types.asElement returns null for types with no corresponding element (such as
            // arrays or wildcard type arguments); report those rather than failing the cast
            val typeSymbol = processingEnv.typeUtils.asElement(type) as? Symbol.TypeSymbol
            if (typeSymbol == null) {
                val message = if (type.kind == TypeKind.ARRAY) {
                    MessageUtils.arrayFailure()
                } else {
                    MessageUtils.genericTypeKindFailure(typeName = type.toString())
                }
                printError(parentChain, symbol, message)
                return true
            }

            return visitClass(
                parentChain, seenTypesByPolicy, symbol,
                typeSymbol,
                newPolicyExceptions,
            )
        }

        var anyError = false

        type.typeArguments.forEachIndexed { index, typeArg ->
            if (isIgnored(typeArg.asElement())) return@forEachIndexed

            val argError =
                visitType(parentChain, seenTypesByPolicy, symbol, typeArg, newPolicyExceptions) {
                    MessageUtils.nonInterfaceReturnFailure(
                        prefix = when {
                            !isMap -> ""
                            index == 0 -> "Key " + typeArg.asElement().simpleName
                            else -> "Value " + typeArg.asElement().simpleName
                        }, index = index
                    )
                }
            anyError = anyError || argError
        }

        return anyError
    }

    /**
     * Reports a compiler error on [element]. The message is prefixed with [parentChain] plus the
     * element's simple name, joined by ", ".
     */
    private fun printError(
        parentChain: List<String>,
        element: Element,
        message: String,
    ) = processingEnv.messager.printMessage(
        Diagnostic.Kind.ERROR,
        parentChain.plus(element.simpleName).joinToString() + "\n\t " + message,
        element,
    )

    /** Resolves [typeName] and returns its erasure, or null if the type cannot be found. */
    private fun ProcessingEnvironment.erasedType(typeName: String) =
        elementUtils.getTypeElement(typeName)?.asType()?.let(typeUtils::erasure)

    /** True if [type] carries [Immutable.Ignore], or matches an ignored super or exact type. */
    private fun isIgnored(type: Type) =
        (type.getAnnotation(Immutable.Ignore::class.java) != null)
                || (ignoredSuperTypes.any { type.isAssignable(it) })
                || (ignoredExactTypes.any { type.isSameType(it) })

    /**
     * True if [symbol] should be skipped. Only method symbols are compared against
     * [IGNORED_SUPER_TYPES]; other symbol kinds are ignored only via [Immutable.Ignore] or an
     * exact type match.
     */
    private fun isIgnored(symbol: Symbol) = when {
        // Anything annotated as @Ignore is always ignored
        symbol.getAnnotation(Immutable.Ignore::class.java) != null -> true
        // Then ignore exact types, regardless of what kind they are
        ignoredExactTypes.any { symbol.type.isSameType(it) } -> true
        // Non-method symbols (e.g. fields, classes) are not matched against ignored super types
        symbol.getKind() != ElementKind.METHOD -> false
        // Finally, check for any ignored super types
        else -> ignoredSuperTypes.any { symbol.type.isAssignable(it) }
    }

    /** [javax.lang.model.util.Types.isAssignable], treating any exception as "not assignable". */
    private fun TypeMirror.isAssignable(type: TypeMirror) = try {
        processingEnv.typeUtils.isAssignable(this, type)
    } catch (ignored: Exception) {
        false
    }

    /** [javax.lang.model.util.Types.isSameType], treating any exception as "not the same". */
    private fun TypeMirror.isSameType(type: TypeMirror) = try {
        processingEnv.typeUtils.isSameType(this, type)
    } catch (ignored: Exception) {
        false
    }
}
425