1 #include "llvm/ADT/APFloat.h"
2 #include "llvm/ADT/STLExtras.h"
3 #include "llvm/IR/BasicBlock.h"
4 #include "llvm/IR/Constants.h"
5 #include "llvm/IR/DerivedTypes.h"
6 #include "llvm/IR/Function.h"
7 #include "llvm/IR/IRBuilder.h"
8 #include "llvm/IR/LLVMContext.h"
9 #include "llvm/IR/LegacyPassManager.h"
10 #include "llvm/IR/Module.h"
11 #include "llvm/IR/Type.h"
12 #include "llvm/IR/Verifier.h"
13 #include "llvm/Support/TargetSelect.h"
14 #include "llvm/Target/TargetMachine.h"
15 #include "llvm/Transforms/Scalar.h"
16 #include "llvm/Transforms/Scalar/GVN.h"
17 #include "../include/KaleidoscopeJIT.h"
18 #include <cassert>
19 #include <cctype>
20 #include <cstdint>
21 #include <cstdio>
22 #include <cstdlib>
23 #include <map>
24 #include <memory>
25 #include <string>
26 #include <vector>
27 
28 using namespace llvm;
29 using namespace llvm::orc;
30 
31 //===----------------------------------------------------------------------===//
32 // Lexer
33 //===----------------------------------------------------------------------===//
34 
// Token codes produced by gettok(). A character the lexer does not recognize
// is returned as its own value (the getchar() range, 0-255); the known token
// kinds below are negative so they can never collide with a character value.
enum Token {
  // End of input; gettok() keeps returning this once EOF is reached.
  tok_eof = -1,

  // commands (the keywords "def" and "extern")
  tok_def = -2,
  tok_extern = -3,

  // primary
  tok_identifier = -4, // spelling is left in IdentifierStr
  tok_number = -5      // value is left in NumVal
};

// Side channels written by gettok(); only meaningful immediately after it
// returns the corresponding token kind.
static std::string IdentifierStr; // Filled in if tok_identifier
static double NumVal;             // Filled in if tok_number
51 
52 /// gettok - Return the next token from standard input.
gettok()53 static int gettok() {
54   static int LastChar = ' ';
55 
56   // Skip any whitespace.
57   while (isspace(LastChar))
58     LastChar = getchar();
59 
60   if (isalpha(LastChar)) { // identifier: [a-zA-Z][a-zA-Z0-9]*
61     IdentifierStr = LastChar;
62     while (isalnum((LastChar = getchar())))
63       IdentifierStr += LastChar;
64 
65     if (IdentifierStr == "def")
66       return tok_def;
67     if (IdentifierStr == "extern")
68       return tok_extern;
69     return tok_identifier;
70   }
71 
72   if (isdigit(LastChar) || LastChar == '.') { // Number: [0-9.]+
73     std::string NumStr;
74     do {
75       NumStr += LastChar;
76       LastChar = getchar();
77     } while (isdigit(LastChar) || LastChar == '.');
78 
79     NumVal = strtod(NumStr.c_str(), nullptr);
80     return tok_number;
81   }
82 
83   if (LastChar == '#') {
84     // Comment until end of line.
85     do
86       LastChar = getchar();
87     while (LastChar != EOF && LastChar != '\n' && LastChar != '\r');
88 
89     if (LastChar != EOF)
90       return gettok();
91   }
92 
93   // Check for end of file.  Don't eat the EOF.
94   if (LastChar == EOF)
95     return tok_eof;
96 
97   // Otherwise, just return the character as its ascii value.
98   int ThisChar = LastChar;
99   LastChar = getchar();
100   return ThisChar;
101 }
102 
103 //===----------------------------------------------------------------------===//
104 // Abstract Syntax Tree (aka Parse Tree)
105 //===----------------------------------------------------------------------===//
106 namespace {
107 /// ExprAST - Base class for all expression nodes.
108 class ExprAST {
109 public:
~ExprAST()110   virtual ~ExprAST() {}
111   virtual Value *codegen() = 0;
112 };
113 
114 /// NumberExprAST - Expression class for numeric literals like "1.0".
115 class NumberExprAST : public ExprAST {
116   double Val;
117 
118 public:
NumberExprAST(double Val)119   NumberExprAST(double Val) : Val(Val) {}
120   Value *codegen() override;
121 };
122 
123 /// VariableExprAST - Expression class for referencing a variable, like "a".
124 class VariableExprAST : public ExprAST {
125   std::string Name;
126 
127 public:
VariableExprAST(const std::string & Name)128   VariableExprAST(const std::string &Name) : Name(Name) {}
129   Value *codegen() override;
130 };
131 
132 /// BinaryExprAST - Expression class for a binary operator.
133 class BinaryExprAST : public ExprAST {
134   char Op;
135   std::unique_ptr<ExprAST> LHS, RHS;
136 
137 public:
BinaryExprAST(char Op,std::unique_ptr<ExprAST> LHS,std::unique_ptr<ExprAST> RHS)138   BinaryExprAST(char Op, std::unique_ptr<ExprAST> LHS,
139                 std::unique_ptr<ExprAST> RHS)
140       : Op(Op), LHS(std::move(LHS)), RHS(std::move(RHS)) {}
141   Value *codegen() override;
142 };
143 
144 /// CallExprAST - Expression class for function calls.
145 class CallExprAST : public ExprAST {
146   std::string Callee;
147   std::vector<std::unique_ptr<ExprAST>> Args;
148 
149 public:
CallExprAST(const std::string & Callee,std::vector<std::unique_ptr<ExprAST>> Args)150   CallExprAST(const std::string &Callee,
151               std::vector<std::unique_ptr<ExprAST>> Args)
152       : Callee(Callee), Args(std::move(Args)) {}
153   Value *codegen() override;
154 };
155 
156 /// PrototypeAST - This class represents the "prototype" for a function,
157 /// which captures its name, and its argument names (thus implicitly the number
158 /// of arguments the function takes).
159 class PrototypeAST {
160   std::string Name;
161   std::vector<std::string> Args;
162 
163 public:
PrototypeAST(const std::string & Name,std::vector<std::string> Args)164   PrototypeAST(const std::string &Name, std::vector<std::string> Args)
165       : Name(Name), Args(std::move(Args)) {}
166   Function *codegen();
getName() const167   const std::string &getName() const { return Name; }
168 };
169 
170 /// FunctionAST - This class represents a function definition itself.
171 class FunctionAST {
172   std::unique_ptr<PrototypeAST> Proto;
173   std::unique_ptr<ExprAST> Body;
174 
175 public:
FunctionAST(std::unique_ptr<PrototypeAST> Proto,std::unique_ptr<ExprAST> Body)176   FunctionAST(std::unique_ptr<PrototypeAST> Proto,
177               std::unique_ptr<ExprAST> Body)
178       : Proto(std::move(Proto)), Body(std::move(Body)) {}
179   Function *codegen();
180 };
181 } // end anonymous namespace
182 
183 //===----------------------------------------------------------------------===//
184 // Parser
185 //===----------------------------------------------------------------------===//
186 
187 /// CurTok/getNextToken - Provide a simple token buffer.  CurTok is the current
188 /// token the parser is looking at.  getNextToken reads another token from the
189 /// lexer and updates CurTok with its results.
190 static int CurTok;
getNextToken()191 static int getNextToken() { return CurTok = gettok(); }
192 
193 /// BinopPrecedence - This holds the precedence for each binary operator that is
194 /// defined.
195 static std::map<char, int> BinopPrecedence;
196 
197 /// GetTokPrecedence - Get the precedence of the pending binary operator token.
GetTokPrecedence()198 static int GetTokPrecedence() {
199   if (!isascii(CurTok))
200     return -1;
201 
202   // Make sure it's a declared binop.
203   int TokPrec = BinopPrecedence[CurTok];
204   if (TokPrec <= 0)
205     return -1;
206   return TokPrec;
207 }
208 
209 /// LogError* - These are little helper functions for error handling.
LogError(const char * Str)210 std::unique_ptr<ExprAST> LogError(const char *Str) {
211   fprintf(stderr, "Error: %s\n", Str);
212   return nullptr;
213 }
214 
LogErrorP(const char * Str)215 std::unique_ptr<PrototypeAST> LogErrorP(const char *Str) {
216   LogError(Str);
217   return nullptr;
218 }
219 
220 static std::unique_ptr<ExprAST> ParseExpression();
221 
222 /// numberexpr ::= number
ParseNumberExpr()223 static std::unique_ptr<ExprAST> ParseNumberExpr() {
224   auto Result = llvm::make_unique<NumberExprAST>(NumVal);
225   getNextToken(); // consume the number
226   return std::move(Result);
227 }
228 
229 /// parenexpr ::= '(' expression ')'
ParseParenExpr()230 static std::unique_ptr<ExprAST> ParseParenExpr() {
231   getNextToken(); // eat (.
232   auto V = ParseExpression();
233   if (!V)
234     return nullptr;
235 
236   if (CurTok != ')')
237     return LogError("expected ')'");
238   getNextToken(); // eat ).
239   return V;
240 }
241 
242 /// identifierexpr
243 ///   ::= identifier
244 ///   ::= identifier '(' expression* ')'
ParseIdentifierExpr()245 static std::unique_ptr<ExprAST> ParseIdentifierExpr() {
246   std::string IdName = IdentifierStr;
247 
248   getNextToken(); // eat identifier.
249 
250   if (CurTok != '(') // Simple variable ref.
251     return llvm::make_unique<VariableExprAST>(IdName);
252 
253   // Call.
254   getNextToken(); // eat (
255   std::vector<std::unique_ptr<ExprAST>> Args;
256   if (CurTok != ')') {
257     while (true) {
258       if (auto Arg = ParseExpression())
259         Args.push_back(std::move(Arg));
260       else
261         return nullptr;
262 
263       if (CurTok == ')')
264         break;
265 
266       if (CurTok != ',')
267         return LogError("Expected ')' or ',' in argument list");
268       getNextToken();
269     }
270   }
271 
272   // Eat the ')'.
273   getNextToken();
274 
275   return llvm::make_unique<CallExprAST>(IdName, std::move(Args));
276 }
277 
278 /// primary
279 ///   ::= identifierexpr
280 ///   ::= numberexpr
281 ///   ::= parenexpr
ParsePrimary()282 static std::unique_ptr<ExprAST> ParsePrimary() {
283   switch (CurTok) {
284   default:
285     return LogError("unknown token when expecting an expression");
286   case tok_identifier:
287     return ParseIdentifierExpr();
288   case tok_number:
289     return ParseNumberExpr();
290   case '(':
291     return ParseParenExpr();
292   }
293 }
294 
295 /// binoprhs
296 ///   ::= ('+' primary)*
ParseBinOpRHS(int ExprPrec,std::unique_ptr<ExprAST> LHS)297 static std::unique_ptr<ExprAST> ParseBinOpRHS(int ExprPrec,
298                                               std::unique_ptr<ExprAST> LHS) {
299   // If this is a binop, find its precedence.
300   while (true) {
301     int TokPrec = GetTokPrecedence();
302 
303     // If this is a binop that binds at least as tightly as the current binop,
304     // consume it, otherwise we are done.
305     if (TokPrec < ExprPrec)
306       return LHS;
307 
308     // Okay, we know this is a binop.
309     int BinOp = CurTok;
310     getNextToken(); // eat binop
311 
312     // Parse the primary expression after the binary operator.
313     auto RHS = ParsePrimary();
314     if (!RHS)
315       return nullptr;
316 
317     // If BinOp binds less tightly with RHS than the operator after RHS, let
318     // the pending operator take RHS as its LHS.
319     int NextPrec = GetTokPrecedence();
320     if (TokPrec < NextPrec) {
321       RHS = ParseBinOpRHS(TokPrec + 1, std::move(RHS));
322       if (!RHS)
323         return nullptr;
324     }
325 
326     // Merge LHS/RHS.
327     LHS =
328         llvm::make_unique<BinaryExprAST>(BinOp, std::move(LHS), std::move(RHS));
329   }
330 }
331 
332 /// expression
333 ///   ::= primary binoprhs
334 ///
ParseExpression()335 static std::unique_ptr<ExprAST> ParseExpression() {
336   auto LHS = ParsePrimary();
337   if (!LHS)
338     return nullptr;
339 
340   return ParseBinOpRHS(0, std::move(LHS));
341 }
342 
343 /// prototype
344 ///   ::= id '(' id* ')'
ParsePrototype()345 static std::unique_ptr<PrototypeAST> ParsePrototype() {
346   if (CurTok != tok_identifier)
347     return LogErrorP("Expected function name in prototype");
348 
349   std::string FnName = IdentifierStr;
350   getNextToken();
351 
352   if (CurTok != '(')
353     return LogErrorP("Expected '(' in prototype");
354 
355   std::vector<std::string> ArgNames;
356   while (getNextToken() == tok_identifier)
357     ArgNames.push_back(IdentifierStr);
358   if (CurTok != ')')
359     return LogErrorP("Expected ')' in prototype");
360 
361   // success.
362   getNextToken(); // eat ')'.
363 
364   return llvm::make_unique<PrototypeAST>(FnName, std::move(ArgNames));
365 }
366 
367 /// definition ::= 'def' prototype expression
ParseDefinition()368 static std::unique_ptr<FunctionAST> ParseDefinition() {
369   getNextToken(); // eat def.
370   auto Proto = ParsePrototype();
371   if (!Proto)
372     return nullptr;
373 
374   if (auto E = ParseExpression())
375     return llvm::make_unique<FunctionAST>(std::move(Proto), std::move(E));
376   return nullptr;
377 }
378 
379 /// toplevelexpr ::= expression
ParseTopLevelExpr()380 static std::unique_ptr<FunctionAST> ParseTopLevelExpr() {
381   if (auto E = ParseExpression()) {
382     // Make an anonymous proto.
383     auto Proto = llvm::make_unique<PrototypeAST>("__anon_expr",
384                                                  std::vector<std::string>());
385     return llvm::make_unique<FunctionAST>(std::move(Proto), std::move(E));
386   }
387   return nullptr;
388 }
389 
390 /// external ::= 'extern' prototype
ParseExtern()391 static std::unique_ptr<PrototypeAST> ParseExtern() {
392   getNextToken(); // eat extern.
393   return ParsePrototype();
394 }
395 
396 //===----------------------------------------------------------------------===//
397 // Code Generation
398 //===----------------------------------------------------------------------===//
399 
// Context passed to every LLVM module, type and constant created in this file.
static LLVMContext TheContext;
// Shared IR builder; codegen methods position it and emit instructions.
// Must be declared after TheContext, which it is constructed from.
static IRBuilder<> Builder(TheContext);
// Module currently receiving generated code. After a module is handed to the
// JIT, InitializeModuleAndPassManager() replaces it with a fresh one.
static std::unique_ptr<Module> TheModule;
// Argument name -> IR value for the function being generated; cleared and
// refilled at the start of each FunctionAST::codegen().
static std::map<std::string, Value *> NamedValues;
// Function-level optimization pipeline attached to TheModule.
static std::unique_ptr<legacy::FunctionPassManager> TheFPM;
// JIT that receives finished modules and resolves symbols in them.
static std::unique_ptr<KaleidoscopeJIT> TheJIT;
// Most recent prototype seen for each function name; getFunction() uses it to
// re-emit a declaration into the current module when needed.
static std::map<std::string, std::unique_ptr<PrototypeAST>> FunctionProtos;
407 
LogErrorV(const char * Str)408 Value *LogErrorV(const char *Str) {
409   LogError(Str);
410   return nullptr;
411 }
412 
getFunction(std::string Name)413 Function *getFunction(std::string Name) {
414   // First, see if the function has already been added to the current module.
415   if (auto *F = TheModule->getFunction(Name))
416     return F;
417 
418   // If not, check whether we can codegen the declaration from some existing
419   // prototype.
420   auto FI = FunctionProtos.find(Name);
421   if (FI != FunctionProtos.end())
422     return FI->second->codegen();
423 
424   // If no existing prototype exists, return null.
425   return nullptr;
426 }
427 
codegen()428 Value *NumberExprAST::codegen() {
429   return ConstantFP::get(TheContext, APFloat(Val));
430 }
431 
codegen()432 Value *VariableExprAST::codegen() {
433   // Look this variable up in the function.
434   Value *V = NamedValues[Name];
435   if (!V)
436     return LogErrorV("Unknown variable name");
437   return V;
438 }
439 
codegen()440 Value *BinaryExprAST::codegen() {
441   Value *L = LHS->codegen();
442   Value *R = RHS->codegen();
443   if (!L || !R)
444     return nullptr;
445 
446   switch (Op) {
447   case '+':
448     return Builder.CreateFAdd(L, R, "addtmp");
449   case '-':
450     return Builder.CreateFSub(L, R, "subtmp");
451   case '*':
452     return Builder.CreateFMul(L, R, "multmp");
453   case '<':
454     L = Builder.CreateFCmpULT(L, R, "cmptmp");
455     // Convert bool 0/1 to double 0.0 or 1.0
456     return Builder.CreateUIToFP(L, Type::getDoubleTy(TheContext), "booltmp");
457   default:
458     return LogErrorV("invalid binary operator");
459   }
460 }
461 
codegen()462 Value *CallExprAST::codegen() {
463   // Look up the name in the global module table.
464   Function *CalleeF = getFunction(Callee);
465   if (!CalleeF)
466     return LogErrorV("Unknown function referenced");
467 
468   // If argument mismatch error.
469   if (CalleeF->arg_size() != Args.size())
470     return LogErrorV("Incorrect # arguments passed");
471 
472   std::vector<Value *> ArgsV;
473   for (unsigned i = 0, e = Args.size(); i != e; ++i) {
474     ArgsV.push_back(Args[i]->codegen());
475     if (!ArgsV.back())
476       return nullptr;
477   }
478 
479   return Builder.CreateCall(CalleeF, ArgsV, "calltmp");
480 }
481 
codegen()482 Function *PrototypeAST::codegen() {
483   // Make the function type:  double(double,double) etc.
484   std::vector<Type *> Doubles(Args.size(), Type::getDoubleTy(TheContext));
485   FunctionType *FT =
486       FunctionType::get(Type::getDoubleTy(TheContext), Doubles, false);
487 
488   Function *F =
489       Function::Create(FT, Function::ExternalLinkage, Name, TheModule.get());
490 
491   // Set names for all arguments.
492   unsigned Idx = 0;
493   for (auto &Arg : F->args())
494     Arg.setName(Args[Idx++]);
495 
496   return F;
497 }
498 
codegen()499 Function *FunctionAST::codegen() {
500   // Transfer ownership of the prototype to the FunctionProtos map, but keep a
501   // reference to it for use below.
502   auto &P = *Proto;
503   FunctionProtos[Proto->getName()] = std::move(Proto);
504   Function *TheFunction = getFunction(P.getName());
505   if (!TheFunction)
506     return nullptr;
507 
508   // Create a new basic block to start insertion into.
509   BasicBlock *BB = BasicBlock::Create(TheContext, "entry", TheFunction);
510   Builder.SetInsertPoint(BB);
511 
512   // Record the function arguments in the NamedValues map.
513   NamedValues.clear();
514   for (auto &Arg : TheFunction->args())
515     NamedValues[Arg.getName()] = &Arg;
516 
517   if (Value *RetVal = Body->codegen()) {
518     // Finish off the function.
519     Builder.CreateRet(RetVal);
520 
521     // Validate the generated code, checking for consistency.
522     verifyFunction(*TheFunction);
523 
524     // Run the optimizer on the function.
525     TheFPM->run(*TheFunction);
526 
527     return TheFunction;
528   }
529 
530   // Error reading body, remove function.
531   TheFunction->eraseFromParent();
532   return nullptr;
533 }
534 
535 //===----------------------------------------------------------------------===//
536 // Top-Level parsing and JIT Driver
537 //===----------------------------------------------------------------------===//
538 
InitializeModuleAndPassManager()539 static void InitializeModuleAndPassManager() {
540   // Open a new module.
541   TheModule = llvm::make_unique<Module>("my cool jit", TheContext);
542   TheModule->setDataLayout(TheJIT->getTargetMachine().createDataLayout());
543 
544   // Create a new pass manager attached to it.
545   TheFPM = llvm::make_unique<legacy::FunctionPassManager>(TheModule.get());
546 
547   // Do simple "peephole" optimizations and bit-twiddling optzns.
548   TheFPM->add(createInstructionCombiningPass());
549   // Reassociate expressions.
550   TheFPM->add(createReassociatePass());
551   // Eliminate Common SubExpressions.
552   TheFPM->add(createGVNPass());
553   // Simplify the control flow graph (deleting unreachable blocks, etc).
554   TheFPM->add(createCFGSimplificationPass());
555 
556   TheFPM->doInitialization();
557 }
558 
HandleDefinition()559 static void HandleDefinition() {
560   if (auto FnAST = ParseDefinition()) {
561     if (auto *FnIR = FnAST->codegen()) {
562       fprintf(stderr, "Read function definition:");
563       FnIR->dump();
564       TheJIT->addModule(std::move(TheModule));
565       InitializeModuleAndPassManager();
566     }
567   } else {
568     // Skip token for error recovery.
569     getNextToken();
570   }
571 }
572 
HandleExtern()573 static void HandleExtern() {
574   if (auto ProtoAST = ParseExtern()) {
575     if (auto *FnIR = ProtoAST->codegen()) {
576       fprintf(stderr, "Read extern: ");
577       FnIR->dump();
578       FunctionProtos[ProtoAST->getName()] = std::move(ProtoAST);
579     }
580   } else {
581     // Skip token for error recovery.
582     getNextToken();
583   }
584 }
585 
HandleTopLevelExpression()586 static void HandleTopLevelExpression() {
587   // Evaluate a top-level expression into an anonymous function.
588   if (auto FnAST = ParseTopLevelExpr()) {
589     if (FnAST->codegen()) {
590       // JIT the module containing the anonymous expression, keeping a handle so
591       // we can free it later.
592       auto H = TheJIT->addModule(std::move(TheModule));
593       InitializeModuleAndPassManager();
594 
595       // Search the JIT for the __anon_expr symbol.
596       auto ExprSymbol = TheJIT->findSymbol("__anon_expr");
597       assert(ExprSymbol && "Function not found");
598 
599       // Get the symbol's address and cast it to the right type (takes no
600       // arguments, returns a double) so we can call it as a native function.
601       double (*FP)() = (double (*)())(intptr_t)ExprSymbol.getAddress();
602       fprintf(stderr, "Evaluated to %f\n", FP());
603 
604       // Delete the anonymous expression module from the JIT.
605       TheJIT->removeModule(H);
606     }
607   } else {
608     // Skip token for error recovery.
609     getNextToken();
610   }
611 }
612 
613 /// top ::= definition | external | expression | ';'
MainLoop()614 static void MainLoop() {
615   while (true) {
616     fprintf(stderr, "ready> ");
617     switch (CurTok) {
618     case tok_eof:
619       return;
620     case ';': // ignore top-level semicolons.
621       getNextToken();
622       break;
623     case tok_def:
624       HandleDefinition();
625       break;
626     case tok_extern:
627       HandleExtern();
628       break;
629     default:
630       HandleTopLevelExpression();
631       break;
632     }
633   }
634 }
635 
636 //===----------------------------------------------------------------------===//
637 // "Library" functions that can be "extern'd" from user code.
638 //===----------------------------------------------------------------------===//
639 
/// putchard - putchar that takes a double and returns 0.
/// Writes X, converted to char, to stderr. extern "C" so JIT-compiled code can
/// call it by its unmangled name.
extern "C" double putchard(double X) {
  const char Ch = static_cast<char>(X);
  fputc(Ch, stderr);
  return 0.0;
}
645 
/// printd - print X to stderr in "%f\n" format and return 0.
/// extern "C" so JIT-compiled code can call it by its unmangled name.
extern "C" double printd(double X) {
  const char *const Format = "%f\n";
  fprintf(stderr, Format, X);
  return 0.0;
}
651 
652 //===----------------------------------------------------------------------===//
653 // Main driver code.
654 //===----------------------------------------------------------------------===//
655 
main()656 int main() {
657   InitializeNativeTarget();
658   InitializeNativeTargetAsmPrinter();
659   InitializeNativeTargetAsmParser();
660 
661   // Install standard binary operators.
662   // 1 is lowest precedence.
663   BinopPrecedence['<'] = 10;
664   BinopPrecedence['+'] = 20;
665   BinopPrecedence['-'] = 20;
666   BinopPrecedence['*'] = 40; // highest.
667 
668   // Prime the first token.
669   fprintf(stderr, "ready> ");
670   getNextToken();
671 
672   TheJIT = llvm::make_unique<KaleidoscopeJIT>();
673 
674   InitializeModuleAndPassManager();
675 
676   // Run the main "interpreter loop" now.
677   MainLoop();
678 
679   return 0;
680 }
681