1 /*
2  * Copyright (C) 2016 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 package androidx.room.parser
18 
19 import androidx.room.ColumnInfo
20 import org.antlr.v4.runtime.ANTLRInputStream
21 import org.antlr.v4.runtime.BaseErrorListener
22 import org.antlr.v4.runtime.CommonTokenStream
23 import org.antlr.v4.runtime.RecognitionException
24 import org.antlr.v4.runtime.Recognizer
25 import org.antlr.v4.runtime.tree.ParseTree
26 import org.antlr.v4.runtime.tree.TerminalNode
27 import javax.annotation.processing.ProcessingEnvironment
28 import javax.lang.model.type.TypeKind
29 import javax.lang.model.type.TypeMirror
30 
/**
 * Walks the parse tree of a single SQL statement and collects what is needed to build a
 * [ParsedQuery]: the query type, the bind parameters and the tables that are referenced.
 *
 * The tree is visited eagerly from the constructor, so [createParsedQuery] can be called as soon
 * as the instance exists.
 *
 * @param original the SQL text the statement was parsed from; passed through to [ParsedQuery]
 * @param syntaxErrors errors gathered while parsing; passed through to [ParsedQuery]
 * @param statement root of the statement that should be inspected
 * @param forRuntimeQuery forwarded as `runtimeQueryPlaceholder` of the created [ParsedQuery]
 */
@Suppress("FunctionName")
class QueryVisitor(
        private val original: String,
        private val syntaxErrors: ArrayList<String>,
        statement: ParseTree,
        private val forRuntimeQuery: Boolean
) : SQLiteBaseVisitor<Void?>() {
    // bind parameters in the order they were visited; sorted by position on output
    private val bindingExpressions = arrayListOf<TerminalNode>()
    // table name alias mappings
    private val tableNames = mutableSetOf<Table>()
    // unescaped names introduced by WITH clauses; these are not real tables
    private val withClauseNames = mutableSetOf<String>()
    private val queryType: QueryType

    init {
        // The first direct child that maps to a known type decides the type of the whole query.
        queryType = (0 until statement.childCount)
                .map { findQueryType(statement.getChild(it)) }
                .firstOrNull { it != QueryType.UNKNOWN }
                ?: QueryType.UNKNOWN

        statement.accept(this)
    }

    /**
     * Maps a direct child of the statement to the [QueryType] it stands for, or
     * [QueryType.UNKNOWN] when the node is not one of the recognized statement kinds.
     */
    private fun findQueryType(statement: ParseTree): QueryType = when (statement) {
        is SQLiteParser.Factored_select_stmtContext,
        is SQLiteParser.Compound_select_stmtContext,
        is SQLiteParser.Select_stmtContext,
        is SQLiteParser.Simple_select_stmtContext ->
            QueryType.SELECT

        is SQLiteParser.Delete_stmt_limitedContext,
        is SQLiteParser.Delete_stmtContext ->
            QueryType.DELETE

        is SQLiteParser.Insert_stmtContext ->
            QueryType.INSERT

        is SQLiteParser.Update_stmtContext,
        is SQLiteParser.Update_stmt_limitedContext ->
            QueryType.UPDATE

        is TerminalNode ->
            if (statement.text == "EXPLAIN") QueryType.EXPLAIN else QueryType.UNKNOWN

        else -> QueryType.UNKNOWN
    }

    /** Records the bind parameter of [ctx], if it has one, then continues into its children. */
    override fun visitExpr(ctx: SQLiteParser.ExprContext): Void? {
        ctx.BIND_PARAMETER()?.let { bindingExpressions.add(it) }
        return super.visitExpr(ctx)
    }

    /**
     * Builds the [ParsedQuery] from everything collected during the visit. Bind parameters are
     * ordered by the start of their source interval.
     */
    fun createParsedQuery(): ParsedQuery {
        return ParsedQuery(
                original = original,
                type = queryType,
                inputs = bindingExpressions.sortedBy { it.sourceInterval.a },
                tables = tableNames,
                syntaxErrors = syntaxErrors,
                runtimeQueryPlaceholder = forRuntimeQuery)
    }

    /** Remembers the (unescaped) name declared by a WITH clause so it is not reported as a table. */
    override fun visitCommon_table_expression(
            ctx: SQLiteParser.Common_table_expressionContext): Void? {
        ctx.table_name()?.text?.let { withClauseNames.add(unescapeIdentifier(it)) }
        return super.visitCommon_table_expression(ctx)
    }

    /**
     * Records the table referenced by [ctx] together with its alias (the table name itself when
     * no alias is given), unless the name was declared by a WITH clause.
     */
    override fun visitTable_or_subquery(ctx: SQLiteParser.Table_or_subqueryContext): Void? {
        val rawName = ctx.table_name()?.text
        if (rawName != null) {
            // withClauseNames holds unescaped names, so the lookup has to use the unescaped name
            // as well; otherwise a quoted reference to a WITH clause would be taken for a table.
            val tableName = unescapeIdentifier(rawName)
            if (tableName !in withClauseNames) {
                val alias = ctx.table_alias()?.text?.let { unescapeIdentifier(it) } ?: tableName
                tableNames.add(Table(tableName, alias))
            }
        }
        return super.visitTable_or_subquery(ctx)
    }

    /**
     * Trims [text] and strips surrounding pairs of `"`, `'` or backticks, repeatedly, until the
     * text is no longer wrapped in one of them.
     */
    private fun unescapeIdentifier(text: String): String {
        val trimmed = text.trim()
        // A single quote character both starts and ends with itself; require two characters so
        // that substring() below never gets an end index smaller than its start index.
        val quoted = trimmed.length >= 2 &&
                ESCAPE_LITERALS.any { trimmed.startsWith(it) && trimmed.endsWith(it) }
        return if (quoted) {
            unescapeIdentifier(trimmed.substring(1, trimmed.length - 1))
        } else {
            trimmed
        }
    }

    companion object {
        // characters that may wrap an identifier and are removed by unescapeIdentifier
        private val ESCAPE_LITERALS = listOf("\"", "'", "`")
    }
}
131 
/**
 * Entry points for turning SQL text into a [ParsedQuery].
 */
class SqlParser {
    companion object {
        private val INVALID_IDENTIFIER_CHARS = arrayOf('`', '\"')

        /**
         * Parses [input] and returns a [ParsedQuery] describing it.
         *
         * Problems are reported through the `syntaxErrors` of the returned query rather than by
         * throwing: syntax errors reported by the parser are collected,
         * [ParserErrors.NOT_ONE_QUERY] is added when [input] does not contain exactly one
         * statement, and any [RuntimeException] raised while parsing is converted into an
         * "unknown error" entry.
         */
        fun parse(input: String): ParsedQuery {
            val inputStream = ANTLRInputStream(input)
            val lexer = SQLiteLexer(inputStream)
            val tokenStream = CommonTokenStream(lexer)
            val parser = SQLiteParser(tokenStream)
            val syntaxErrors = arrayListOf<String>()
            parser.addErrorListener(object : BaseErrorListener() {
                override fun syntaxError(
                        recognizer: Recognizer<*, *>, offendingSymbol: Any?,
                        line: Int, charPositionInLine: Int, msg: String,
                        e: RecognitionException?) {
                    // offendingSymbol is unused and declared nullable so that a listener call
                    // without a symbol cannot fail on Kotlin's parameter null check.
                    syntaxErrors.add(msg)
                }
            })
            try {
                val parsed = parser.parse()
                val statementList = parsed.sql_stmt_list()
                if (statementList.isEmpty()) {
                    return ParsedQuery(input, QueryType.UNKNOWN, emptyList(), emptySet(),
                            listOf(ParserErrors.NOT_ONE_QUERY), false)
                }
                // children comes from Java and may be null for a node without children
                val statements = statementList.first().children.orEmpty()
                        .filterIsInstance<SQLiteParser.Sql_stmtContext>()
                if (statements.size != 1) {
                    syntaxErrors.add(ParserErrors.NOT_ONE_QUERY)
                }
                // Without any statement there is nothing to visit; report the errors collected
                // so far instead of failing on first() and ending up in the catch block below.
                val statement = statements.firstOrNull()
                        ?: return ParsedQuery(input, QueryType.UNKNOWN, emptyList(), emptySet(),
                                syntaxErrors, false)
                return QueryVisitor(
                        original = input,
                        syntaxErrors = syntaxErrors,
                        statement = statement,
                        forRuntimeQuery = false).createParsedQuery()
            } catch (antlrError: RuntimeException) {
                return ParsedQuery(input, QueryType.UNKNOWN, emptyList(), emptySet(),
                        listOf("unknown error while parsing $input : ${antlrError.message}"),
                        false)
            }
        }

        /**
         * Returns `true` when [input] is not blank and contains neither a backtick nor a double
         * quote.
         */
        fun isValidIdentifier(input: String): Boolean =
                input.isNotBlank() && INVALID_IDENTIFIER_CHARS.none { input.contains(it) }

        /**
         * Creates a placeholder query for raw queries that reads the given tables. Each table is
         * registered with its own name as alias; the query type is [QueryType.UNKNOWN] and the
         * result is flagged as a runtime query placeholder.
         */
        fun rawQueryForTables(tableNames: Set<String>): ParsedQuery {
            return ParsedQuery(
                    original = "raw query",
                    type = QueryType.UNKNOWN,
                    inputs = emptyList(),
                    tables = tableNames.map { Table(name = it, alias = it) }.toSet(),
                    syntaxErrors = emptyList(),
                    runtimeQueryPlaceholder = true
            )
        }
    }
}
193 
/**
 * Kinds of SQL statements the parser distinguishes. [UNKNOWN] is used for everything that could
 * not be classified.
 */
enum class QueryType {
    /** Statement kind could not be determined. */
    UNKNOWN,
    /** A `SELECT` statement. */
    SELECT,
    /** A `DELETE` statement. */
    DELETE,
    /** An `UPDATE` statement. */
    UPDATE,
    /** A statement starting with `EXPLAIN`. */
    EXPLAIN,
    /** An `INSERT` statement. */
    INSERT;

    companion object {
        /**
         * Query types listed as supported.
         *
         * IF you change this, don't forget to update @Query documentation.
         */
        val SUPPORTED: HashSet<QueryType> = hashSetOf(SELECT, DELETE, UPDATE)
    }
}
207 
/**
 * Type affinities of a column, with helpers to find the Java types that match an affinity and to
 * convert the integer constants declared on [ColumnInfo].
 */
enum class SQLTypeAffinity {
    NULL,
    TEXT,
    INTEGER,
    REAL,
    BLOB;

    /**
     * Returns the Java types that correspond to this affinity:
     * `java.lang.String` for [TEXT]; `int`, `byte`, `char`, `long`, `short` and their boxed
     * classes for [INTEGER]; `double`, `float` and their boxed classes for [REAL]; `byte[]` for
     * [BLOB]; and an empty list for [NULL].
     *
     * @param env processing environment used to look up the type mirrors
     */
    fun getTypeMirrors(env: ProcessingEnvironment): List<TypeMirror>? {
        val types = env.typeUtils
        return when (this) {
            NULL -> emptyList()
            TEXT -> {
                val stringElement = env.elementUtils.getTypeElement("java.lang.String")
                listOf(stringElement.asType())
            }
            INTEGER -> withBoxedTypes(
                    env,
                    TypeKind.INT, TypeKind.BYTE, TypeKind.CHAR, TypeKind.LONG, TypeKind.SHORT)
            REAL -> withBoxedTypes(env, TypeKind.DOUBLE, TypeKind.FLOAT)
            BLOB -> {
                val byteType = types.getPrimitiveType(TypeKind.BYTE)
                listOf(types.getArrayType(byteType))
            }
        }
    }

    /**
     * For every kind in [primitives], yields the primitive type immediately followed by the type
     * of its boxed class, keeping the order of [primitives].
     */
    private fun withBoxedTypes(env: ProcessingEnvironment, vararg primitives: TypeKind):
            List<TypeMirror> {
        val types = env.typeUtils
        val mirrors = ArrayList<TypeMirror>(primitives.size * 2)
        for (kind in primitives) {
            val primitive = types.getPrimitiveType(kind)
            mirrors.add(primitive)
            mirrors.add(types.boxedClass(primitive).asType())
        }
        return mirrors
    }

    companion object {
        /**
         * Converts a value of ColumnInfo#SQLiteTypeAffinity into the matching entry.
         *
         * @return the affinity for [value], or `null` when [value] is none of
         * [ColumnInfo.BLOB], [ColumnInfo.INTEGER], [ColumnInfo.REAL] or [ColumnInfo.TEXT]
         */
        fun fromAnnotationValue(value: Int): SQLTypeAffinity? = when (value) {
            ColumnInfo.BLOB -> BLOB
            ColumnInfo.INTEGER -> INTEGER
            ColumnInfo.REAL -> REAL
            ColumnInfo.TEXT -> TEXT
            else -> null
        }
    }
}
249 
/**
 * Collation options for a column. [fromAnnotationValue] converts the integer constants declared
 * on [ColumnInfo] into these entries.
 */
enum class Collate {
    BINARY,
    NOCASE,
    RTRIM,
    LOCALIZED,
    UNICODE;

    companion object {
        /**
         * Converts a collation constant from [ColumnInfo] into the matching entry.
         *
         * @return the entry for [value], or `null` when [value] is not one of the
         * [ColumnInfo] constants handled here
         */
        fun fromAnnotationValue(value: Int): Collate? = when (value) {
            ColumnInfo.BINARY -> BINARY
            ColumnInfo.NOCASE -> NOCASE
            ColumnInfo.RTRIM -> RTRIM
            ColumnInfo.LOCALIZED -> LOCALIZED
            ColumnInfo.UNICODE -> UNICODE
            else -> null
        }
    }
}
270