/*
 * Copyright (C) 2016 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package androidx.room.parser

import androidx.room.ColumnInfo
import org.antlr.v4.runtime.ANTLRInputStream
import org.antlr.v4.runtime.BaseErrorListener
import org.antlr.v4.runtime.CommonTokenStream
import org.antlr.v4.runtime.RecognitionException
import org.antlr.v4.runtime.Recognizer
import org.antlr.v4.runtime.tree.ParseTree
import org.antlr.v4.runtime.tree.TerminalNode
import javax.annotation.processing.ProcessingEnvironment
import javax.lang.model.type.TypeKind
import javax.lang.model.type.TypeMirror

/**
 * Visits one parsed SQL statement and collects its [QueryType], its bind parameters and the
 * tables it references (excluding names introduced by `WITH` clauses).
 *
 * The whole tree is visited eagerly from the constructor; call [createParsedQuery] afterwards
 * to obtain the result.
 *
 * @param original the raw SQL text, passed through to the resulting [ParsedQuery]
 * @param syntaxErrors errors collected before visiting, passed through to the [ParsedQuery]
 * @param statement the statement parse tree whose children decide the query type
 * @param forRuntimeQuery passed through as the `runtimeQueryPlaceholder` of the [ParsedQuery]
 */
@Suppress("FunctionName")
class QueryVisitor(
    private val original: String,
    private val syntaxErrors: ArrayList<String>,
    statement: ParseTree,
    private val forRuntimeQuery: Boolean
) : SQLiteBaseVisitor<Void?>() {
    private val bindingExpressions = arrayListOf<TerminalNode>()
    // table name alias mappings
    private val tableNames = mutableSetOf<Table>()
    // unescaped names declared by common table expressions; these are not real tables
    private val withClauseNames = mutableSetOf<String>()
    private val queryType: QueryType

    init {
        // Children are scanned in order; the first one that maps to a known type decides.
        queryType = (0 until statement.childCount).asSequence()
            .map { findQueryType(statement.getChild(it)) }
            .firstOrNull { it != QueryType.UNKNOWN } ?: QueryType.UNKNOWN

        statement.accept(this)
    }

    /**
     * Maps a single child of the statement to the [QueryType] it represents, or
     * [QueryType.UNKNOWN] when the child does not identify a query type.
     */
    private fun findQueryType(statement: ParseTree): QueryType {
        return when (statement) {
            is SQLiteParser.Factored_select_stmtContext,
            is SQLiteParser.Compound_select_stmtContext,
            is SQLiteParser.Select_stmtContext,
            is SQLiteParser.Simple_select_stmtContext ->
                QueryType.SELECT

            is SQLiteParser.Delete_stmt_limitedContext,
            is SQLiteParser.Delete_stmtContext ->
                QueryType.DELETE

            is SQLiteParser.Insert_stmtContext ->
                QueryType.INSERT
            is SQLiteParser.Update_stmtContext,
            is SQLiteParser.Update_stmt_limitedContext ->
                QueryType.UPDATE
            // SQL keywords are case-insensitive; without ignoreCase `explain select ...` would
            // fall through to UNKNOWN here and the wrapped SELECT would become the query type.
            is TerminalNode ->
                if (statement.text.equals("EXPLAIN", ignoreCase = true)) {
                    QueryType.EXPLAIN
                } else {
                    QueryType.UNKNOWN
                }
            else -> QueryType.UNKNOWN
        }
    }

    /** Records every `BIND_PARAMETER` token found in an expression, then keeps visiting. */
    override fun visitExpr(ctx: SQLiteParser.ExprContext): Void? {
        ctx.BIND_PARAMETER()?.let { bindingExpressions.add(it) }
        return super.visitExpr(ctx)
    }

    /**
     * Builds the [ParsedQuery] from what was collected during the visit. Bind parameters are
     * ordered by their position in the token stream.
     */
    fun createParsedQuery(): ParsedQuery {
        return ParsedQuery(
            original = original,
            type = queryType,
            inputs = bindingExpressions.sortedBy { it.sourceInterval.a },
            tables = tableNames,
            syntaxErrors = syntaxErrors,
            runtimeQueryPlaceholder = forRuntimeQuery)
    }

    /** Remembers the (unescaped) name declared by a `WITH` clause before visiting its body. */
    override fun visitCommon_table_expression(
        ctx: SQLiteParser.Common_table_expressionContext
    ): Void? {
        val tableName = ctx.table_name()?.text
        if (tableName != null) {
            withClauseNames.add(unescapeIdentifier(tableName))
        }
        return super.visitCommon_table_expression(ctx)
    }

    /**
     * Records a referenced table together with its alias (the alias defaults to the table
     * name). Names declared by a `WITH` clause are skipped.
     */
    override fun visitTable_or_subquery(ctx: SQLiteParser.Table_or_subqueryContext): Void? {
        val tableName = ctx.table_name()?.text
        if (tableName != null) {
            val unescapedName = unescapeIdentifier(tableName)
            // withClauseNames holds unescaped names, so compare the unescaped form; otherwise
            // `WITH tmp AS (...) SELECT * FROM "tmp"` would report `tmp` as a real table.
            if (unescapedName !in withClauseNames) {
                val tableAlias = ctx.table_alias()?.text
                tableNames.add(Table(
                    unescapedName,
                    unescapeIdentifier(tableAlias ?: tableName)))
            }
        }
        return super.visitTable_or_subquery(ctx)
    }

    /**
     * Trims [text] and repeatedly strips one matching pair of identifier quotes
     * (see [ESCAPE_LITERALS]) from its ends.
     */
    private fun unescapeIdentifier(text: String): String {
        val trimmed = text.trim()
        // A lone quote character both starts and ends with itself; it cannot be stripped.
        if (trimmed.length >= 2) {
            for ((opening, closing) in ESCAPE_LITERALS) {
                if (trimmed.startsWith(opening) && trimmed.endsWith(closing)) {
                    return unescapeIdentifier(trimmed.substring(1, trimmed.length - 1))
                }
            }
        }
        return trimmed
    }

    companion object {
        // (opening, closing) quote pairs SQLite accepts around identifiers; `[...]` is the
        // MS Access / SQL Server style that SQLite also supports.
        private val ESCAPE_LITERALS = listOf(
            "\"" to "\"",
            "'" to "'",
            "`" to "`",
            "[" to "]")
    }
}

/** Entry points for parsing SQL written in Room annotations. */
class SqlParser {
    companion object {
        private val INVALID_IDENTIFIER_CHARS = arrayOf('`', '\"')

        /**
         * Parses [input], which is expected to contain exactly one SQL statement.
         *
         * Syntax errors reported by the ANTLR parser are collected into the result rather than
         * thrown. If no statement is found, the result is an UNKNOWN query whose only error is
         * [ParserErrors.NOT_ONE_QUERY]; if more than one is found, that error is appended and the
         * first statement is visited. Any [RuntimeException] raised while parsing or visiting is
         * turned into an UNKNOWN query carrying the exception message.
         */
        fun parse(input: String): ParsedQuery {
            val inputStream = ANTLRInputStream(input)
            val lexer = SQLiteLexer(inputStream)
            val tokenStream = CommonTokenStream(lexer)
            val parser = SQLiteParser(tokenStream)
            val syntaxErrors = arrayListOf<String>()
            parser.addErrorListener(object : BaseErrorListener() {
                override fun syntaxError(
                    recognizer: Recognizer<*, *>,
                    // nullable: this is a Java callback, so don't let Kotlin's non-null
                    // parameter check turn a null argument into an NPE
                    offendingSymbol: Any?,
                    line: Int,
                    charPositionInLine: Int,
                    msg: String,
                    e: RecognitionException?
                ) {
                    syntaxErrors.add(msg)
                }
            })
            try {
                val parsed = parser.parse()
                val statementList = parsed.sql_stmt_list()
                if (statementList.isEmpty()) {
                    return notOneQuery(input)
                }
                // `children` is a Java field that may be null when the context has no children.
                val statements = statementList.first().children.orEmpty()
                    .filterIsInstance<SQLiteParser.Sql_stmtContext>()
                if (statements.isEmpty()) {
                    // Previously surfaced as a confusing "List is empty." parse error.
                    return notOneQuery(input)
                }
                if (statements.size != 1) {
                    syntaxErrors.add(ParserErrors.NOT_ONE_QUERY)
                }
                return QueryVisitor(
                    original = input,
                    syntaxErrors = syntaxErrors,
                    statement = statements.first(),
                    forRuntimeQuery = false).createParsedQuery()
            } catch (antlrError: RuntimeException) {
                return ParsedQuery(input, QueryType.UNKNOWN, emptyList(), emptySet(),
                    listOf("unknown error while parsing $input : ${antlrError.message}"),
                    false)
            }
        }

        /** An UNKNOWN query whose only error is [ParserErrors.NOT_ONE_QUERY]. */
        private fun notOneQuery(input: String): ParsedQuery {
            return ParsedQuery(input, QueryType.UNKNOWN, emptyList(), emptySet(),
                listOf(ParserErrors.NOT_ONE_QUERY), false)
        }

        /**
         * Returns true when [input] is not blank and contains neither a backtick nor a double
         * quote.
         */
        fun isValidIdentifier(input: String): Boolean =
            input.isNotBlank() && INVALID_IDENTIFIER_CHARS.none { input.contains(it) }

        /**
         * Creates a placeholder [ParsedQuery] for a raw query that reads the given tables. Each
         * table is aliased by its own name, the query type is UNKNOWN, and the result is marked
         * as a runtime-query placeholder.
         */
        fun rawQueryForTables(tableNames: Set<String>): ParsedQuery {
            return ParsedQuery(
                original = "raw query",
                type = QueryType.UNKNOWN,
                inputs = emptyList(),
                tables = tableNames.map { Table(name = it, alias = it) }.toSet(),
                syntaxErrors = emptyList(),
                runtimeQueryPlaceholder = true
            )
        }
    }
}

/** The kind of SQL statement a query was classified as. */
enum class QueryType {
    UNKNOWN,
    SELECT,
    DELETE,
    UPDATE,
    EXPLAIN,
    INSERT;

    companion object {
        // IF you change this, don't forget to update @Query documentation.
        val SUPPORTED = hashSetOf(SELECT, DELETE, UPDATE)
    }
}

/** SQLite column type affinities. */
enum class SQLTypeAffinity {
    NULL,
    TEXT,
    INTEGER,
    REAL,
    BLOB;

    /**
     * Returns the Java types that map onto this affinity: `String` for TEXT; the integral
     * primitives (int, byte, char, long, short) and their boxed forms for INTEGER; double and
     * float plus boxed forms for REAL; `byte[]` for BLOB; and an empty list for NULL.
     */
    fun getTypeMirrors(env: ProcessingEnvironment): List<TypeMirror>? {
        val typeUtils = env.typeUtils
        return when (this) {
            TEXT -> listOf(env.elementUtils.getTypeElement("java.lang.String").asType())
            INTEGER -> withBoxedTypes(env, TypeKind.INT, TypeKind.BYTE, TypeKind.CHAR,
                TypeKind.LONG, TypeKind.SHORT)
            REAL -> withBoxedTypes(env, TypeKind.DOUBLE, TypeKind.FLOAT)
            BLOB -> listOf(typeUtils.getArrayType(
                typeUtils.getPrimitiveType(TypeKind.BYTE)))
            else -> emptyList()
        }
    }

    /** Returns each primitive type in [primitives] followed by its boxed class type. */
    private fun withBoxedTypes(env: ProcessingEnvironment, vararg primitives: TypeKind):
        List<TypeMirror> {
        return primitives.flatMap {
            val primitiveType = env.typeUtils.getPrimitiveType(it)
            listOf(primitiveType, env.typeUtils.boxedClass(primitiveType).asType())
        }
    }

    companion object {
        // converts from ColumnInfo#SQLiteTypeAffinity; returns null for any other value
        fun fromAnnotationValue(value: Int): SQLTypeAffinity? {
            return when (value) {
                ColumnInfo.BLOB -> BLOB
                ColumnInfo.INTEGER -> INTEGER
                ColumnInfo.REAL -> REAL
                ColumnInfo.TEXT -> TEXT
                else -> null
            }
        }
    }
}

/** Collating sequences that can be declared for a column. */
enum class Collate {
    BINARY,
    NOCASE,
    RTRIM,
    LOCALIZED,
    UNICODE;

    companion object {
        /** Maps a `ColumnInfo` collate constant to a [Collate], or null for any other value. */
        fun fromAnnotationValue(value: Int): Collate? {
            return when (value) {
                ColumnInfo.BINARY -> BINARY
                ColumnInfo.NOCASE -> NOCASE
                ColumnInfo.RTRIM -> RTRIM
                ColumnInfo.LOCALIZED -> LOCALIZED
                ColumnInfo.UNICODE -> UNICODE
                else -> null
            }
        }
    }
}