1 /*
2  * Copyright (C) 2016 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 package androidx.room.writer
18 
19 import androidx.room.ext.L
20 import androidx.room.ext.N
21 import androidx.room.ext.RoomTypeNames
22 import androidx.room.ext.SupportDbTypeNames
23 import androidx.room.ext.T
24 import androidx.room.ext.typeName
25 import androidx.room.parser.QueryType
26 import androidx.room.processor.OnConflictProcessor
27 import androidx.room.solver.CodeGenScope
28 import androidx.room.vo.Dao
29 import androidx.room.vo.Entity
30 import androidx.room.vo.InsertionMethod
31 import androidx.room.vo.QueryMethod
32 import androidx.room.vo.RawQueryMethod
33 import androidx.room.vo.ShortcutMethod
34 import androidx.room.vo.TransactionMethod
35 import com.google.auto.common.MoreTypes
36 import com.squareup.javapoet.ClassName
37 import com.squareup.javapoet.CodeBlock
38 import com.squareup.javapoet.FieldSpec
39 import com.squareup.javapoet.MethodSpec
40 import com.squareup.javapoet.ParameterSpec
41 import com.squareup.javapoet.TypeName
42 import com.squareup.javapoet.TypeSpec
43 import me.eugeniomarletti.kotlin.metadata.shadow.load.java.JvmAbi
44 import stripNonJava
45 import javax.annotation.processing.ProcessingEnvironment
46 import javax.lang.model.element.ElementKind
47 import javax.lang.model.element.ExecutableElement
48 import javax.lang.model.element.Modifier.FINAL
49 import javax.lang.model.element.Modifier.PRIVATE
50 import javax.lang.model.element.Modifier.PUBLIC
51 import javax.lang.model.type.DeclaredType
52 import javax.lang.model.type.TypeKind
53 
/**
 * Creates the implementation for a class annotated with Dao.
 *
 * The generated class is named [Dao.implTypeName]. It implements the DAO when the DAO element is
 * an interface and extends it otherwise, keeps the database in the private [dbField], and
 * overrides every select / delete / update @Query method, raw query, insertion, deletion, update
 * and transaction method declared on the DAO.
 */
class DaoWriter(val dao: Dao, val processingEnv: ProcessingEnvironment)
    : ClassWriter(dao.typeName) {
    // The DAO's own type; overridden method signatures are resolved as members of this type.
    private val declaredDao = MoreTypes.asDeclared(dao.element.asType())

    companion object {
        // TODO nothing prevents this from conflicting, we should fix.
        /** The private final `__db` field through which every generated method reaches the db. */
        val dbField: FieldSpec = FieldSpec
                .builder(RoomTypeNames.ROOM_DB, "__db", PRIVATE, FINAL)
                .build()

        /**
         * Turns [typeName] into a fragment used inside generated field names: the simple name of
         * a [ClassName], otherwise the type's string form with `.` replaced by `_` and then passed
         * through `stripNonJava`.
         */
        private fun typeNameToFieldName(typeName: TypeName?): String {
            return if (typeName is ClassName) {
                typeName.simpleName()
            } else {
                typeName.toString().replace('.', '_').stripNonJava()
            }
        }
    }

    /**
     * Builds the DAO implementation: public modifier, super type, the `__db` field, the
     * constructor that initializes all shared adapter/statement fields, and one override per DAO
     * method.
     */
    override fun createTypeSpecBuilder(): TypeSpec.Builder {
        val builder = TypeSpec.classBuilder(dao.implTypeName)
        /**
         * if delete / update query method wants to return modified rows, we need prepared query.
         * in that case, if args are dynamic, we cannot re-use the query, if not, we should re-use
         * it. this requires more work but creates good performance.
         */
        // Split DELETE / UPDATE @Query methods by whether any parameter may expand to multiple
        // bind arguments (a missing adapter is treated as "multiple" to stay on the safe side).
        val groupedDeleteUpdate = dao.queryMethods
                .filter { it.query.type == QueryType.DELETE || it.query.type == QueryType.UPDATE }
                .groupBy { method ->
                    method.parameters.any { param -> param.queryParamAdapter?.isMultiple ?: true }
                }
        // delete / update queries that can be prepared ahead of time and re-used
        val preparedDeleteOrUpdateQueries = groupedDeleteUpdate[false] ?: emptyList()
        // delete / update queries that must be rebuilt on every single call
        val oneOffDeleteOrUpdateQueries = groupedDeleteUpdate[true] ?: emptyList()
        val shortcutMethods = createInsertionMethods() +
                createDeletionMethods() + createUpdateMethods() + createTransactionMethods() +
                createPreparedDeleteOrUpdateQueries(preparedDeleteOrUpdateQueries)

        builder.apply {
            addModifiers(PUBLIC)
            if (dao.element.kind == ElementKind.INTERFACE) {
                addSuperinterface(dao.typeName)
            } else {
                superclass(dao.typeName)
            }
            addField(dbField)
            // When the DAO declares a constructor parameter type, the generated constructor takes
            // that type and forwards it to super(...); otherwise it takes the db field's type.
            val dbParam = ParameterSpec
                    .builder(dao.constructorParamType ?: dbField.type, dbField.name).build()

            addMethod(createConstructor(dbParam, shortcutMethods, dao.constructorParamType != null))

            shortcutMethods.forEach {
                addMethod(it.methodImpl)
            }

            dao.queryMethods.filter { it.query.type == QueryType.SELECT }.forEach { method ->
                addMethod(createSelectMethod(method))
            }
            oneOffDeleteOrUpdateQueries.forEach {
                addMethod(createDeleteOrUpdateQueryMethod(it))
            }
            dao.rawQueryMethods.forEach {
                addMethod(createRawQueryMethod(it))
            }
        }
        return builder
    }

    /**
     * For each DELETE / UPDATE @Query method whose statement can be re-used, creates the shared
     * prepared-statement field, its anonymous initializer, and the method override.
     */
    private fun createPreparedDeleteOrUpdateQueries(
            preparedDeleteQueries: List<QueryMethod>): List<PreparedStmtQuery> {
        return preparedDeleteQueries.map { method ->
            val fieldSpec = getOrCreateField(PreparedStatementField(method))
            val queryWriter = QueryWriter(method)
            val fieldImpl = PreparedStatementWriter(queryWriter)
                    .createAnonymous(this@DaoWriter, dbField)
            val methodBody = createPreparedDeleteQueryMethodBody(method, fieldSpec, queryWriter)
            PreparedStmtQuery(mapOf(PreparedStmtQuery.NO_PARAM_FIELD
                    to (fieldSpec to fieldImpl)), methodBody)
        }
    }

    /**
     * Generates an override that acquires a statement from [preparedStmtField], binds the
     * method's arguments, runs `executeUpdateDelete()` inside a transaction (returning its result
     * when the method returns a value), and releases the statement in `finally`.
     */
    private fun createPreparedDeleteQueryMethodBody(
            method: QueryMethod,
            preparedStmtField: FieldSpec,
            queryWriter: QueryWriter
    ): MethodSpec {
        val scope = CodeGenScope(this)
        val methodBuilder = overrideWithoutAnnotations(method.element, declaredDao).apply {
            val stmtName = scope.getTmpVar("_stmt")
            addStatement("final $T $L = $N.acquire()",
                    SupportDbTypeNames.SQLITE_STMT, stmtName, preparedStmtField)
            addStatement("$N.beginTransaction()", dbField)
            beginControlFlow("try").apply {
                val bindScope = scope.fork()
                queryWriter.bindArgs(stmtName, emptyList(), bindScope)
                addCode(bindScope.builder().build())
                if (method.returnsValue) {
                    val resultVar = scope.getTmpVar("_result")
                    addStatement("final $L $L = $L.executeUpdateDelete()",
                            method.returnType.typeName(), resultVar, stmtName)
                    addStatement("$N.setTransactionSuccessful()", dbField)
                    addStatement("return $L", resultVar)
                } else {
                    addStatement("$L.executeUpdateDelete()", stmtName)
                    addStatement("$N.setTransactionSuccessful()", dbField)
                }
            }
            nextControlFlow("finally").apply {
                addStatement("$N.endTransaction()", dbField)
                addStatement("$N.release($L)", preparedStmtField, stmtName)
            }
            endControlFlow()
        }
        return methodBuilder.build()
    }

    /** Wraps each @Transaction method in an override; these need no shared fields. */
    private fun createTransactionMethods(): List<PreparedStmtQuery> {
        return dao.transactionMethods.map {
            PreparedStmtQuery(emptyMap(), createTransactionMethodBody(it))
        }
    }

    /**
     * Generates an override that opens a transaction, delegates to the DAO's own implementation,
     * marks the transaction successful, returns the delegate's result (if non-void), and ends the
     * transaction in `finally`.
     */
    private fun createTransactionMethodBody(method: TransactionMethod): MethodSpec {
        val scope = CodeGenScope(this)
        val methodBuilder = overrideWithoutAnnotations(method.element, declaredDao).apply {
            addStatement("$N.beginTransaction()", dbField)
            beginControlFlow("try").apply {
                val returnsValue = method.element.returnType.kind != TypeKind.VOID
                val resultVar = if (returnsValue) {
                    scope.getTmpVar("_result")
                } else {
                    null
                }
                addDelegateToSuperStatement(method.element, method.callType, resultVar)
                addStatement("$N.setTransactionSuccessful()", dbField)
                if (returnsValue) {
                    addStatement("return $N", resultVar)
                }
            }
            nextControlFlow("finally").apply {
                addStatement("$N.endTransaction()", dbField)
            }
            endControlFlow()
        }
        return methodBuilder.build()
    }

    /**
     * Emits the call to the DAO's own implementation of [element], optionally assigning the
     * result to a local named [result].
     *
     * @param callType how the original body is reached: `super.m(...)` for a concrete method,
     * `Iface.super.m(...)` for a Java 8 default method, or `Iface.DefaultImpls.m(this, ...)` for
     * a Kotlin interface default method.
     * @param result name of the local to declare for the return value, or null for void methods.
     */
    private fun MethodSpec.Builder.addDelegateToSuperStatement(
            element: ExecutableElement,
            callType: TransactionMethod.CallType,
            result: String?) {
        val params: MutableList<Any> = mutableListOf()
        val format = buildString {
            if (result != null) {
                append("$T $L = ")
                params.add(element.returnType)
                params.add(result)
            }
            when (callType) {
                TransactionMethod.CallType.CONCRETE -> {
                    append("super.$N(")
                    params.add(element.simpleName)
                }
                TransactionMethod.CallType.DEFAULT_JAVA8 -> {
                    append("$N.super.$N(")
                    params.add(element.enclosingElement.simpleName)
                    params.add(element.simpleName)
                }
                TransactionMethod.CallType.DEFAULT_KOTLIN -> {
                    // The DefaultImpls static method receives the instance as its first argument.
                    // No trailing ", " here: a zero-argument method must produce `m(this)`, not
                    // the uncompilable `m(this, )`.
                    append("$N.$N.$N(this")
                    params.add(element.enclosingElement.simpleName)
                    params.add(JvmAbi.DEFAULT_IMPLS_CLASS_NAME)
                    params.add(element.simpleName)
                }
            }
            // In the DefaultImpls form `this` is already in the argument list, so even the first
            // declared parameter needs a separator.
            var needsSeparator = callType == TransactionMethod.CallType.DEFAULT_KOTLIN
            element.parameters.forEach { param ->
                if (needsSeparator) {
                    append(", ")
                }
                needsSeparator = true
                append(L)
                params.add(param.simpleName)
            }
            append(")")
        }
        addStatement(format, *params.toTypedArray())
    }

    /**
     * Creates the public constructor. It stores [dbParam] into `__db` (calling `super(db)` first
     * when [callSuper] is set) and initializes every shared field required by [shortcutMethods].
     * A field shared by several methods is initialized once, using its first occurrence.
     */
    private fun createConstructor(
            dbParam: ParameterSpec,
            shortcutMethods: List<PreparedStmtQuery>,
            callSuper: Boolean): MethodSpec {
        return MethodSpec.constructorBuilder().apply {
            addParameter(dbParam)
            addModifiers(PUBLIC)
            if (callSuper) {
                addStatement("super($N)", dbParam)
            }
            addStatement("this.$N = $N", dbField, dbParam)
            shortcutMethods
                    .flatMap { it.fields.values }
                    .distinctBy { it.first.name }
                    .forEach { (fieldSpec, fieldImpl) ->
                        addStatement("this.$N = $L", fieldSpec, fieldImpl)
                    }
        }.build()
    }

    /** Generates the override for a SELECT @Query method. */
    private fun createSelectMethod(method: QueryMethod): MethodSpec {
        return overrideWithoutAnnotations(method.element, declaredDao).apply {
            addCode(createQueryMethodBody(method))
        }.build()
    }

    /**
     * Generates the override for a @RawQuery method. A String query parameter is wrapped in a
     * RoomSQLiteQuery acquired here and marked releasable. A support-query parameter is copied to
     * a final local and is not released by the generated code.
     */
    private fun createRawQueryMethod(method: RawQueryMethod): MethodSpec {
        return overrideWithoutAnnotations(method.element, declaredDao).apply {
            val scope = CodeGenScope(this@DaoWriter)
            val roomSQLiteQueryVar: String
            val queryParam = method.runtimeQueryParam
            val shouldReleaseQuery: Boolean

            when {
                queryParam?.isString() == true -> {
                    roomSQLiteQueryVar = scope.getTmpVar("_statement")
                    shouldReleaseQuery = true
                    addStatement("$T $L = $T.acquire($L, 0)",
                            RoomTypeNames.ROOM_SQL_QUERY,
                            roomSQLiteQueryVar,
                            RoomTypeNames.ROOM_SQL_QUERY,
                            queryParam.paramName)
                }
                queryParam?.isSupportQuery() == true -> {
                    shouldReleaseQuery = false
                    roomSQLiteQueryVar = scope.getTmpVar("_internalQuery")
                    // move it to a final variable so that the generated code can use it inside
                    // callback blocks in java 7
                    addStatement("final $T $L = $N",
                            queryParam.type,
                            roomSQLiteQueryVar,
                            queryParam.paramName)
                }
                else -> {
                    // The processor has already reported this error. Emit the placeholder as a
                    // string literal ($S) so the generated code still compiles instead of
                    // producing a second, unrelated error.
                    roomSQLiteQueryVar = scope.getTmpVar("_statement")
                    shouldReleaseQuery = false
                    addStatement("$T $L = $T.acquire(\$S, 0)",
                            RoomTypeNames.ROOM_SQL_QUERY,
                            roomSQLiteQueryVar,
                            RoomTypeNames.ROOM_SQL_QUERY,
                            "missing query parameter")
                }
            }
            // When the method returns no value the processor has already reported an error;
            // generating the result conversion would only add one more, so it is skipped.
            if (method.returnsValue) {
                method.queryResultBinder.convertAndReturn(
                        roomSQLiteQueryVar = roomSQLiteQueryVar,
                        canReleaseQuery = shouldReleaseQuery,
                        dbField = dbField,
                        inTransaction = method.inTransaction,
                        scope = scope)
            }
            addCode(scope.builder().build())
        }.build()
    }

    /** Generates the override for a DELETE / UPDATE @Query whose SQL is built per call. */
    private fun createDeleteOrUpdateQueryMethod(method: QueryMethod): MethodSpec {
        return overrideWithoutAnnotations(method.element, declaredDao).apply {
            addCode(createDeleteOrUpdateQueryMethodBody(method))
        }.build()
    }

    /**
     * Groups all insertion methods based on the insert statement they will use then creates all
     * field specs, EntityInsertionAdapterWriter and actual insert methods.
     */
    private fun createInsertionMethods(): List<PreparedStmtQuery> {
        return dao.insertionMethods
                .map { insertionMethod ->
                    val onConflict = OnConflictProcessor.onConflictText(insertionMethod.onConflict)
                    val entities = insertionMethod.entities

                    val fields = entities.mapValues {
                        val spec = getOrCreateField(InsertionMethodField(it.value, onConflict))
                        val impl = EntityInsertionAdapterWriter(it.value, onConflict)
                                .createAnonymous(this@DaoWriter, dbField.name)
                        spec to impl
                    }
                    val methodImpl = overrideWithoutAnnotations(insertionMethod.element,
                            declaredDao).apply {
                        addCode(createInsertionMethodBody(insertionMethod, fields))
                    }.build()
                    PreparedStmtQuery(fields, methodImpl)
                }
    }

    /**
     * Generates the body of an insertion method. Each parameter is passed to its insertion
     * adapter inside a transaction. The result is returned unless the method is INSERT_VOID.
     * Returns an empty block when there are no adapters or the insertion type is unknown.
     */
    private fun createInsertionMethodBody(
            method: InsertionMethod,
            insertionAdapters: Map<String, Pair<FieldSpec, TypeSpec>>
    ): CodeBlock {
        val insertionType = method.insertionType
        if (insertionAdapters.isEmpty() || insertionType == null) {
            return CodeBlock.builder().build()
        }
        val scope = CodeGenScope(this)

        return scope.builder().apply {
            // TODO assert thread
            // TODO collect results
            addStatement("$N.beginTransaction()", dbField)
            val needsReturnType = insertionType != InsertionMethod.Type.INSERT_VOID
            val resultVar = if (needsReturnType) {
                scope.getTmpVar("_result")
            } else {
                null
            }

            beginControlFlow("try").apply {
                method.parameters.forEach { param ->
                    val insertionAdapter = insertionAdapters[param.name]?.first
                    if (needsReturnType) {
                        // if it has more than 1 parameter, we would've already printed the error
                        // so we don't care about re-declaring the variable here
                        addStatement("$T $L = $N.$L($L)",
                                insertionType.returnTypeName, resultVar,
                                insertionAdapter, insertionType.methodName,
                                param.name)
                    } else {
                        addStatement("$N.$L($L)", insertionAdapter, insertionType.methodName,
                                param.name)
                    }
                }
                addStatement("$N.setTransactionSuccessful()", dbField)
                if (needsReturnType) {
                    addStatement("return $L", resultVar)
                }
            }
            nextControlFlow("finally").apply {
                addStatement("$N.endTransaction()", dbField)
            }
            endControlFlow()
        }.build()
    }

    /**
     * Creates an EntityDeletionAdapter field and method override for each deletion method.
     */
    private fun createDeletionMethods(): List<PreparedStmtQuery> {
        return createShortcutMethods(dao.deletionMethods, "deletion", { _, entity ->
            EntityDeletionAdapterWriter(entity)
                    .createAnonymous(this@DaoWriter, dbField.name)
        })
    }

    /**
     * Creates an EntityUpdateAdapter field and method override for each @Update method, honoring
     * the method's conflict strategy.
     */
    private fun createUpdateMethods(): List<PreparedStmtQuery> {
        return createShortcutMethods(dao.updateMethods, "update", { update, entity ->
            val onConflict = OnConflictProcessor.onConflictText(update.onConflictStrategy)
            EntityUpdateAdapterWriter(entity, onConflict)
                    .createAnonymous(this@DaoWriter, dbField.name)
        })
    }

    /**
     * Shared implementation for deletion/update methods. Methods without entities are skipped.
     * Every other method gets one shared adapter field per entity, named with [methodPrefix] and
     * initialized by [implCallback], plus its override.
     */
    private fun <T : ShortcutMethod> createShortcutMethods(
            methods: List<T>, methodPrefix: String,
            implCallback: (T, Entity) -> TypeSpec
    ): List<PreparedStmtQuery> {
        return methods.mapNotNull { method ->
            val entities = method.entities

            if (entities.isEmpty()) {
                null
            } else {
                val fields = entities.mapValues {
                    val spec = getOrCreateField(DeleteOrUpdateAdapterField(it.value, methodPrefix))
                    val impl = implCallback(method, it.value)
                    spec to impl
                }
                val methodSpec = overrideWithoutAnnotations(method.element, declaredDao).apply {
                    addCode(createDeleteOrUpdateMethodBody(method, fields))
                }.build()
                PreparedStmtQuery(fields, methodSpec)
            }
        }
    }

    /**
     * Generates the body of a @Delete / @Update method. Each parameter goes through its adapter
     * inside a transaction. When the method returns a count, the per-parameter results are
     * summed into an int and returned.
     */
    private fun createDeleteOrUpdateMethodBody(
            method: ShortcutMethod,
            adapters: Map<String, Pair<FieldSpec, TypeSpec>>
    ): CodeBlock {
        if (adapters.isEmpty()) {
            return CodeBlock.builder().build()
        }
        val scope = CodeGenScope(this)
        val resultVar = if (method.returnCount) {
            scope.getTmpVar("_total")
        } else {
            null
        }
        return scope.builder().apply {
            if (resultVar != null) {
                addStatement("$T $L = 0", TypeName.INT, resultVar)
            }
            addStatement("$N.beginTransaction()", dbField)
            beginControlFlow("try").apply {
                method.parameters.forEach { param ->
                    val adapter = adapters[param.name]?.first
                    addStatement("$L$N.$L($L)",
                            if (resultVar == null) "" else "$resultVar +=",
                            adapter, param.handleMethodName(), param.name)
                }
                addStatement("$N.setTransactionSuccessful()", dbField)
                if (resultVar != null) {
                    addStatement("return $L", resultVar)
                }
            }
            nextControlFlow("finally").apply {
                addStatement("$N.endTransaction()", dbField)
            }
            endControlFlow()
        }.build()
    }

    /**
     * Body for a DELETE / UPDATE @Query whose statement cannot be re-used. It builds the SQL,
     * compiles a fresh statement, binds the arguments, and executes it inside a transaction.
     */
    private fun createDeleteOrUpdateQueryMethodBody(method: QueryMethod): CodeBlock {
        val queryWriter = QueryWriter(method)
        val scope = CodeGenScope(this)
        val sqlVar = scope.getTmpVar("_sql")
        val stmtVar = scope.getTmpVar("_stmt")
        val listSizeArgs = queryWriter.prepareQuery(sqlVar, scope)
        scope.builder().apply {
            addStatement("$T $L = $N.compileStatement($L)",
                    SupportDbTypeNames.SQLITE_STMT, stmtVar, dbField, sqlVar)
            queryWriter.bindArgs(stmtVar, listSizeArgs, scope)
            addStatement("$N.beginTransaction()", dbField)
            beginControlFlow("try").apply {
                if (method.returnsValue) {
                    val resultVar = scope.getTmpVar("_result")
                    addStatement("final $L $L = $L.executeUpdateDelete()",
                            method.returnType.typeName(), resultVar, stmtVar)
                    addStatement("$N.setTransactionSuccessful()", dbField)
                    addStatement("return $L", resultVar)
                } else {
                    addStatement("$L.executeUpdateDelete()", stmtVar)
                    addStatement("$N.setTransactionSuccessful()", dbField)
                }
            }
            nextControlFlow("finally").apply {
                addStatement("$N.endTransaction()", dbField)
            }
            endControlFlow()
        }
        return scope.builder().build()
    }

    /**
     * Body for a SELECT @Query. It prepares and binds a RoomSQLiteQuery, then delegates result
     * conversion to the method's result binder, which may release the query.
     */
    private fun createQueryMethodBody(method: QueryMethod): CodeBlock {
        val queryWriter = QueryWriter(method)
        val scope = CodeGenScope(this)
        val sqlVar = scope.getTmpVar("_sql")
        val roomSQLiteQueryVar = scope.getTmpVar("_statement")
        queryWriter.prepareReadAndBind(sqlVar, roomSQLiteQueryVar, scope)
        method.queryResultBinder.convertAndReturn(
                roomSQLiteQueryVar = roomSQLiteQueryVar,
                canReleaseQuery = true,
                dbField = dbField,
                inTransaction = method.inTransaction,
                scope = scope)
        return scope.builder().build()
    }

    /**
     * Starts an `@Override` method builder for [elm] as a member of [owner]. It copies the name,
     * modifiers, parameters, varargs flag and return type from JavaPoet's `MethodSpec.overriding`
     * but drops the other annotations that would otherwise be carried over.
     */
    private fun overrideWithoutAnnotations(
            elm: ExecutableElement,
            owner: DeclaredType): MethodSpec.Builder {
        val baseSpec = MethodSpec.overriding(elm, owner, processingEnv.typeUtils).build()
        return MethodSpec.methodBuilder(baseSpec.name).apply {
            addAnnotation(Override::class.java)
            addModifiers(baseSpec.modifiers)
            addParameters(baseSpec.parameters)
            varargs(baseSpec.varargs)
            returns(baseSpec.returnType)
        }
    }

    /**
     * Represents a query statement prepared in Dao implementation.
     *
     * @param fields This map holds all the member fields necessary for this query. The key is the
     * corresponding parameter name in the defining query method. The value is a pair from the field
     * declaration to definition.
     * @param methodImpl The body of the query method implementation.
     */
    data class PreparedStmtQuery(
            val fields: Map<String, Pair<FieldSpec, TypeSpec>>,
            val methodImpl: MethodSpec) {
        companion object {
            // The key to be used in `fields` where the method requires a field that is not
            // associated with any of its parameters
            const val NO_PARAM_FIELD = "-"
        }
    }

    /**
     * Shared insertion-adapter field, keyed by entity type and conflict strategy so that methods
     * inserting the same entity with the same strategy re-use one adapter.
     */
    private class InsertionMethodField(val entity: Entity, val onConflictText: String)
        : SharedFieldSpec(
            "insertionAdapterOf${Companion.typeNameToFieldName(entity.typeName)}",
            RoomTypeNames.INSERTION_ADAPTER) {

        override fun getUniqueKey(): String {
            return "${entity.typeName} $onConflictText"
        }

        override fun prepare(writer: ClassWriter, builder: FieldSpec.Builder) {
            builder.addModifiers(FINAL, PRIVATE)
        }
    }

    /**
     * Shared deletion/update adapter field, keyed by entity type and [methodPrefix]
     * ("deletion" / "update").
     */
    class DeleteOrUpdateAdapterField(val entity: Entity, val methodPrefix: String)
        : SharedFieldSpec(
            "${methodPrefix}AdapterOf${Companion.typeNameToFieldName(entity.typeName)}",
            RoomTypeNames.DELETE_OR_UPDATE_ADAPTER) {
        override fun prepare(writer: ClassWriter, builder: FieldSpec.Builder) {
            builder.addModifiers(PRIVATE, FINAL)
        }

        override fun getUniqueKey(): String {
            return entity.typeName.toString() + methodPrefix
        }
    }

    /**
     * Shared prepared-statement field for a re-usable DELETE / UPDATE @Query, keyed by the
     * original SQL text so identical queries share one statement field.
     */
    class PreparedStatementField(val method: QueryMethod) : SharedFieldSpec(
            "preparedStmtOf${method.name.capitalize()}", RoomTypeNames.SHARED_SQLITE_STMT) {
        override fun prepare(writer: ClassWriter, builder: FieldSpec.Builder) {
            builder.addModifiers(PRIVATE, FINAL)
        }

        override fun getUniqueKey(): String {
            return method.query.original
        }
    }
}
606