1 //===- Parser.h - Toy Language Parser -------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements the parser for the Toy language. It processes the Token 10 // provided by the Lexer and returns an AST. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #ifndef MLIR_TUTORIAL_TOY_PARSER_H 15 #define MLIR_TUTORIAL_TOY_PARSER_H 16 17 #include "toy/AST.h" 18 #include "toy/Lexer.h" 19 20 #include "llvm/ADT/Optional.h" 21 #include "llvm/ADT/STLExtras.h" 22 #include "llvm/ADT/StringExtras.h" 23 #include "llvm/Support/raw_ostream.h" 24 25 #include <map> 26 #include <utility> 27 #include <vector> 28 29 namespace toy { 30 31 /// This is a simple recursive parser for the Toy language. It produces a well 32 /// formed AST from a stream of Token supplied by the Lexer. No semantic checks 33 /// or symbol resolution is performed. For example, variables are referenced by 34 /// string and the code could reference an undeclared variable and the parsing 35 /// succeeds. 36 class Parser { 37 public: 38 /// Create a Parser for the supplied lexer. Parser(Lexer & lexer)39 Parser(Lexer &lexer) : lexer(lexer) {} 40 41 /// Parse a full Module. A module is a list of function definitions. parseModule()42 std::unique_ptr<ModuleAST> parseModule() { 43 lexer.getNextToken(); // prime the lexer 44 45 // Parse functions one at a time and accumulate in this vector. 46 std::vector<FunctionAST> functions; 47 while (auto f = parseDefinition()) { 48 functions.push_back(std::move(*f)); 49 if (lexer.getCurToken() == tok_eof) 50 break; 51 } 52 // If we didn't reach EOF, there was an error during parsing 53 if (lexer.getCurToken() != tok_eof) 54 return parseError<ModuleAST>("nothing", "at end of module"); 55 56 return std::make_unique<ModuleAST>(std::move(functions)); 57 } 58 59 private: 60 Lexer &lexer; 61 62 /// Parse a return statement. 63 /// return :== return ; | return expr ; parseReturn()64 std::unique_ptr<ReturnExprAST> parseReturn() { 65 auto loc = lexer.getLastLocation(); 66 lexer.consume(tok_return); 67 68 // return takes an optional argument 69 llvm::Optional<std::unique_ptr<ExprAST>> expr; 70 if (lexer.getCurToken() != ';') { 71 expr = parseExpression(); 72 if (!expr) 73 return nullptr; 74 } 75 return std::make_unique<ReturnExprAST>(std::move(loc), std::move(expr)); 76 } 77 78 /// Parse a literal number. 79 /// numberexpr ::= number parseNumberExpr()80 std::unique_ptr<ExprAST> parseNumberExpr() { 81 auto loc = lexer.getLastLocation(); 82 auto result = 83 std::make_unique<NumberExprAST>(std::move(loc), lexer.getValue()); 84 lexer.consume(tok_number); 85 return std::move(result); 86 } 87 88 /// Parse a literal array expression. 89 /// tensorLiteral ::= [ literalList ] | number 90 /// literalList ::= tensorLiteral | tensorLiteral, literalList parseTensorLiteralExpr()91 std::unique_ptr<ExprAST> parseTensorLiteralExpr() { 92 auto loc = lexer.getLastLocation(); 93 lexer.consume(Token('[')); 94 95 // Hold the list of values at this nesting level. 96 std::vector<std::unique_ptr<ExprAST>> values; 97 // Hold the dimensions for all the nesting inside this level. 98 std::vector<int64_t> dims; 99 do { 100 // We can have either another nested array or a number literal. 101 if (lexer.getCurToken() == '[') { 102 values.push_back(parseTensorLiteralExpr()); 103 if (!values.back()) 104 return nullptr; // parse error in the nested array. 105 } else { 106 if (lexer.getCurToken() != tok_number) 107 return parseError<ExprAST>("<num> or [", "in literal expression"); 108 values.push_back(parseNumberExpr()); 109 } 110 111 // End of this list on ']' 112 if (lexer.getCurToken() == ']') 113 break; 114 115 // Elements are separated by a comma. 116 if (lexer.getCurToken() != ',') 117 return parseError<ExprAST>("] or ,", "in literal expression"); 118 119 lexer.getNextToken(); // eat , 120 } while (true); 121 if (values.empty()) 122 return parseError<ExprAST>("<something>", "to fill literal expression"); 123 lexer.getNextToken(); // eat ] 124 125 /// Fill in the dimensions now. First the current nesting level: 126 dims.push_back(values.size()); 127 128 /// If there is any nested array, process all of them and ensure that 129 /// dimensions are uniform. 130 if (llvm::any_of(values, [](std::unique_ptr<ExprAST> &expr) { 131 return llvm::isa<LiteralExprAST>(expr.get()); 132 })) { 133 auto *firstLiteral = llvm::dyn_cast<LiteralExprAST>(values.front().get()); 134 if (!firstLiteral) 135 return parseError<ExprAST>("uniform well-nested dimensions", 136 "inside literal expression"); 137 138 // Append the nested dimensions to the current level 139 auto firstDims = firstLiteral->getDims(); 140 dims.insert(dims.end(), firstDims.begin(), firstDims.end()); 141 142 // Sanity check that shape is uniform across all elements of the list. 143 for (auto &expr : values) { 144 auto *exprLiteral = llvm::cast<LiteralExprAST>(expr.get()); 145 if (!exprLiteral) 146 return parseError<ExprAST>("uniform well-nested dimensions", 147 "inside literal expression"); 148 if (exprLiteral->getDims() != firstDims) 149 return parseError<ExprAST>("uniform well-nested dimensions", 150 "inside literal expression"); 151 } 152 } 153 return std::make_unique<LiteralExprAST>(std::move(loc), std::move(values), 154 std::move(dims)); 155 } 156 157 /// parenexpr ::= '(' expression ')' parseParenExpr()158 std::unique_ptr<ExprAST> parseParenExpr() { 159 lexer.getNextToken(); // eat (. 160 auto v = parseExpression(); 161 if (!v) 162 return nullptr; 163 164 if (lexer.getCurToken() != ')') 165 return parseError<ExprAST>(")", "to close expression with parentheses"); 166 lexer.consume(Token(')')); 167 return v; 168 } 169 170 /// identifierexpr 171 /// ::= identifier 172 /// ::= identifier '(' expression ')' parseIdentifierExpr()173 std::unique_ptr<ExprAST> parseIdentifierExpr() { 174 std::string name(lexer.getId()); 175 176 auto loc = lexer.getLastLocation(); 177 lexer.getNextToken(); // eat identifier. 178 179 if (lexer.getCurToken() != '(') // Simple variable ref. 180 return std::make_unique<VariableExprAST>(std::move(loc), name); 181 182 // This is a function call. 183 lexer.consume(Token('(')); 184 std::vector<std::unique_ptr<ExprAST>> args; 185 if (lexer.getCurToken() != ')') { 186 while (true) { 187 if (auto arg = parseExpression()) 188 args.push_back(std::move(arg)); 189 else 190 return nullptr; 191 192 if (lexer.getCurToken() == ')') 193 break; 194 195 if (lexer.getCurToken() != ',') 196 return parseError<ExprAST>(", or )", "in argument list"); 197 lexer.getNextToken(); 198 } 199 } 200 lexer.consume(Token(')')); 201 202 // It can be a builtin call to print 203 if (name == "print") { 204 if (args.size() != 1) 205 return parseError<ExprAST>("<single arg>", "as argument to print()"); 206 207 return std::make_unique<PrintExprAST>(std::move(loc), std::move(args[0])); 208 } 209 210 // Call to a user-defined function 211 return std::make_unique<CallExprAST>(std::move(loc), name, std::move(args)); 212 } 213 214 /// primary 215 /// ::= identifierexpr 216 /// ::= numberexpr 217 /// ::= parenexpr 218 /// ::= tensorliteral parsePrimary()219 std::unique_ptr<ExprAST> parsePrimary() { 220 switch (lexer.getCurToken()) { 221 default: 222 llvm::errs() << "unknown token '" << lexer.getCurToken() 223 << "' when expecting an expression\n"; 224 return nullptr; 225 case tok_identifier: 226 return parseIdentifierExpr(); 227 case tok_number: 228 return parseNumberExpr(); 229 case '(': 230 return parseParenExpr(); 231 case '[': 232 return parseTensorLiteralExpr(); 233 case ';': 234 return nullptr; 235 case '}': 236 return nullptr; 237 } 238 } 239 240 /// Recursively parse the right hand side of a binary expression, the ExprPrec 241 /// argument indicates the precedence of the current binary operator. 242 /// 243 /// binoprhs ::= ('+' primary)* parseBinOpRHS(int exprPrec,std::unique_ptr<ExprAST> lhs)244 std::unique_ptr<ExprAST> parseBinOpRHS(int exprPrec, 245 std::unique_ptr<ExprAST> lhs) { 246 // If this is a binop, find its precedence. 247 while (true) { 248 int tokPrec = getTokPrecedence(); 249 250 // If this is a binop that binds at least as tightly as the current binop, 251 // consume it, otherwise we are done. 252 if (tokPrec < exprPrec) 253 return lhs; 254 255 // Okay, we know this is a binop. 256 int binOp = lexer.getCurToken(); 257 lexer.consume(Token(binOp)); 258 auto loc = lexer.getLastLocation(); 259 260 // Parse the primary expression after the binary operator. 261 auto rhs = parsePrimary(); 262 if (!rhs) 263 return parseError<ExprAST>("expression", "to complete binary operator"); 264 265 // If BinOp binds less tightly with rhs than the operator after rhs, let 266 // the pending operator take rhs as its lhs. 267 int nextPrec = getTokPrecedence(); 268 if (tokPrec < nextPrec) { 269 rhs = parseBinOpRHS(tokPrec + 1, std::move(rhs)); 270 if (!rhs) 271 return nullptr; 272 } 273 274 // Merge lhs/RHS. 275 lhs = std::make_unique<BinaryExprAST>(std::move(loc), binOp, 276 std::move(lhs), std::move(rhs)); 277 } 278 } 279 280 /// expression::= primary binop rhs parseExpression()281 std::unique_ptr<ExprAST> parseExpression() { 282 auto lhs = parsePrimary(); 283 if (!lhs) 284 return nullptr; 285 286 return parseBinOpRHS(0, std::move(lhs)); 287 } 288 289 /// type ::= < shape_list > 290 /// shape_list ::= num | num , shape_list parseType()291 std::unique_ptr<VarType> parseType() { 292 if (lexer.getCurToken() != '<') 293 return parseError<VarType>("<", "to begin type"); 294 lexer.getNextToken(); // eat < 295 296 auto type = std::make_unique<VarType>(); 297 298 while (lexer.getCurToken() == tok_number) { 299 type->shape.push_back(lexer.getValue()); 300 lexer.getNextToken(); 301 if (lexer.getCurToken() == ',') 302 lexer.getNextToken(); 303 } 304 305 if (lexer.getCurToken() != '>') 306 return parseError<VarType>(">", "to end type"); 307 lexer.getNextToken(); // eat > 308 return type; 309 } 310 311 /// Parse a variable declaration, it starts with a `var` keyword followed by 312 /// and identifier and an optional type (shape specification) before the 313 /// initializer. 314 /// decl ::= var identifier [ type ] = expr parseDeclaration()315 std::unique_ptr<VarDeclExprAST> parseDeclaration() { 316 if (lexer.getCurToken() != tok_var) 317 return parseError<VarDeclExprAST>("var", "to begin declaration"); 318 auto loc = lexer.getLastLocation(); 319 lexer.getNextToken(); // eat var 320 321 if (lexer.getCurToken() != tok_identifier) 322 return parseError<VarDeclExprAST>("identified", 323 "after 'var' declaration"); 324 std::string id(lexer.getId()); 325 lexer.getNextToken(); // eat id 326 327 std::unique_ptr<VarType> type; // Type is optional, it can be inferred 328 if (lexer.getCurToken() == '<') { 329 type = parseType(); 330 if (!type) 331 return nullptr; 332 } 333 334 if (!type) 335 type = std::make_unique<VarType>(); 336 lexer.consume(Token('=')); 337 auto expr = parseExpression(); 338 return std::make_unique<VarDeclExprAST>(std::move(loc), std::move(id), 339 std::move(*type), std::move(expr)); 340 } 341 342 /// Parse a block: a list of expression separated by semicolons and wrapped in 343 /// curly braces. 344 /// 345 /// block ::= { expression_list } 346 /// expression_list ::= block_expr ; expression_list 347 /// block_expr ::= decl | "return" | expr parseBlock()348 std::unique_ptr<ExprASTList> parseBlock() { 349 if (lexer.getCurToken() != '{') 350 return parseError<ExprASTList>("{", "to begin block"); 351 lexer.consume(Token('{')); 352 353 auto exprList = std::make_unique<ExprASTList>(); 354 355 // Ignore empty expressions: swallow sequences of semicolons. 356 while (lexer.getCurToken() == ';') 357 lexer.consume(Token(';')); 358 359 while (lexer.getCurToken() != '}' && lexer.getCurToken() != tok_eof) { 360 if (lexer.getCurToken() == tok_var) { 361 // Variable declaration 362 auto varDecl = parseDeclaration(); 363 if (!varDecl) 364 return nullptr; 365 exprList->push_back(std::move(varDecl)); 366 } else if (lexer.getCurToken() == tok_return) { 367 // Return statement 368 auto ret = parseReturn(); 369 if (!ret) 370 return nullptr; 371 exprList->push_back(std::move(ret)); 372 } else { 373 // General expression 374 auto expr = parseExpression(); 375 if (!expr) 376 return nullptr; 377 exprList->push_back(std::move(expr)); 378 } 379 // Ensure that elements are separated by a semicolon. 380 if (lexer.getCurToken() != ';') 381 return parseError<ExprASTList>(";", "after expression"); 382 383 // Ignore empty expressions: swallow sequences of semicolons. 384 while (lexer.getCurToken() == ';') 385 lexer.consume(Token(';')); 386 } 387 388 if (lexer.getCurToken() != '}') 389 return parseError<ExprASTList>("}", "to close block"); 390 391 lexer.consume(Token('}')); 392 return exprList; 393 } 394 395 /// prototype ::= def id '(' decl_list ')' 396 /// decl_list ::= identifier | identifier, decl_list parsePrototype()397 std::unique_ptr<PrototypeAST> parsePrototype() { 398 auto loc = lexer.getLastLocation(); 399 400 if (lexer.getCurToken() != tok_def) 401 return parseError<PrototypeAST>("def", "in prototype"); 402 lexer.consume(tok_def); 403 404 if (lexer.getCurToken() != tok_identifier) 405 return parseError<PrototypeAST>("function name", "in prototype"); 406 407 std::string fnName(lexer.getId()); 408 lexer.consume(tok_identifier); 409 410 if (lexer.getCurToken() != '(') 411 return parseError<PrototypeAST>("(", "in prototype"); 412 lexer.consume(Token('(')); 413 414 std::vector<std::unique_ptr<VariableExprAST>> args; 415 if (lexer.getCurToken() != ')') { 416 do { 417 std::string name(lexer.getId()); 418 auto loc = lexer.getLastLocation(); 419 lexer.consume(tok_identifier); 420 auto decl = std::make_unique<VariableExprAST>(std::move(loc), name); 421 args.push_back(std::move(decl)); 422 if (lexer.getCurToken() != ',') 423 break; 424 lexer.consume(Token(',')); 425 if (lexer.getCurToken() != tok_identifier) 426 return parseError<PrototypeAST>( 427 "identifier", "after ',' in function parameter list"); 428 } while (true); 429 } 430 if (lexer.getCurToken() != ')') 431 return parseError<PrototypeAST>(")", "to end function prototype"); 432 433 // success. 434 lexer.consume(Token(')')); 435 return std::make_unique<PrototypeAST>(std::move(loc), fnName, 436 std::move(args)); 437 } 438 439 /// Parse a function definition, we expect a prototype initiated with the 440 /// `def` keyword, followed by a block containing a list of expressions. 441 /// 442 /// definition ::= prototype block parseDefinition()443 std::unique_ptr<FunctionAST> parseDefinition() { 444 auto proto = parsePrototype(); 445 if (!proto) 446 return nullptr; 447 448 if (auto block = parseBlock()) 449 return std::make_unique<FunctionAST>(std::move(proto), std::move(block)); 450 return nullptr; 451 } 452 453 /// Get the precedence of the pending binary operator token. getTokPrecedence()454 int getTokPrecedence() { 455 if (!isascii(lexer.getCurToken())) 456 return -1; 457 458 // 1 is lowest precedence. 459 switch (static_cast<char>(lexer.getCurToken())) { 460 case '-': 461 return 20; 462 case '+': 463 return 20; 464 case '*': 465 return 40; 466 default: 467 return -1; 468 } 469 } 470 471 /// Helper function to signal errors while parsing, it takes an argument 472 /// indicating the expected token and another argument giving more context. 473 /// Location is retrieved from the lexer to enrich the error message. 474 template <typename R, typename T, typename U = const char *> 475 std::unique_ptr<R> parseError(T &&expected, U &&context = "") { 476 auto curToken = lexer.getCurToken(); 477 llvm::errs() << "Parse error (" << lexer.getLastLocation().line << ", " 478 << lexer.getLastLocation().col << "): expected '" << expected 479 << "' " << context << " but has Token " << curToken; 480 if (isprint(curToken)) 481 llvm::errs() << " '" << (char)curToken << "'"; 482 llvm::errs() << "\n"; 483 return nullptr; 484 } 485 }; 486 487 } // namespace toy 488 489 #endif // MLIR_TUTORIAL_TOY_PARSER_H 490