#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/LegacyPassManager.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Type.h"
#include "llvm/IR/Verifier.h"
#include "llvm/Support/TargetSelect.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/Target/TargetMachine.h"
#include "llvm/Transforms/Scalar.h"
#include "llvm/Transforms/Scalar/GVN.h"
#include "../include/KaleidoscopeJIT.h"
#include <cassert>
#include <cctype>
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <map>
#include <memory>
#include <string>
#include <vector>
28 
29 using namespace llvm;
30 using namespace llvm::orc;
31 
32 //===----------------------------------------------------------------------===//
33 // Lexer
34 //===----------------------------------------------------------------------===//
35 
// The lexer returns tokens [0-255] if it is an unknown character, otherwise one
// of these for known things.
enum Token {
  tok_eof = -1,

  // commands
  tok_def = -2,
  tok_extern = -3,

  // primary
  tok_identifier = -4,
  tok_number = -5,

  // control
  tok_if = -6,
  tok_then = -7,
  tok_else = -8,
  tok_for = -9,
  tok_in = -10
};

static std::string IdentifierStr; // Filled in if tok_identifier
static double NumVal;             // Filled in if tok_number

/// gettok - Return the next token from standard input.
///
/// One character of lookahead is kept in LastChar between calls. Identifiers
/// and keywords leave their spelling in IdentifierStr; numbers leave their
/// value in NumVal. Comments ('#' to end of line) are skipped by looping rather
/// than by recursing, so a long run of comment lines cannot exhaust the stack.
static int gettok() {
  static int LastChar = ' ';

  while (true) {
    // Skip any whitespace.
    while (isspace(LastChar))
      LastChar = getchar();

    if (isalpha(LastChar)) { // identifier: [a-zA-Z][a-zA-Z0-9]*
      IdentifierStr = LastChar;
      while (isalnum((LastChar = getchar())))
        IdentifierStr += LastChar;

      // Reserved words and the token each one maps to.
      static const struct {
        const char *Spelling;
        int Tok;
      } Keywords[] = {{"def", tok_def},   {"extern", tok_extern},
                      {"if", tok_if},     {"then", tok_then},
                      {"else", tok_else}, {"for", tok_for},
                      {"in", tok_in}};
      for (const auto &KW : Keywords)
        if (IdentifierStr == KW.Spelling)
          return KW.Tok;
      return tok_identifier;
    }

    if (isdigit(LastChar) || LastChar == '.') { // Number: [0-9.]+
      std::string NumStr;
      do {
        NumStr += LastChar;
        LastChar = getchar();
      } while (isdigit(LastChar) || LastChar == '.');

      // NOTE: a malformed literal such as "1.2.3" is still accepted; strtod
      // converts the longest valid prefix and the remainder is ignored.
      NumVal = strtod(NumStr.c_str(), nullptr);
      return tok_number;
    }

    if (LastChar == '#') {
      // Comment until end of line.
      do
        LastChar = getchar();
      while (LastChar != EOF && LastChar != '\n' && LastChar != '\r');

      // Start over on the next line (iteration instead of recursion).
      if (LastChar != EOF)
        continue;
    }

    // Check for end of file.  Don't eat the EOF.
    if (LastChar == EOF)
      return tok_eof;

    // Otherwise, just return the character as its ascii value.
    int ThisChar = LastChar;
    LastChar = getchar();
    return ThisChar;
  }
}
120 
121 //===----------------------------------------------------------------------===//
122 // Abstract Syntax Tree (aka Parse Tree)
123 //===----------------------------------------------------------------------===//
124 namespace {
125 /// ExprAST - Base class for all expression nodes.
126 class ExprAST {
127 public:
~ExprAST()128   virtual ~ExprAST() {}
129   virtual Value *codegen() = 0;
130 };
131 
132 /// NumberExprAST - Expression class for numeric literals like "1.0".
133 class NumberExprAST : public ExprAST {
134   double Val;
135 
136 public:
NumberExprAST(double Val)137   NumberExprAST(double Val) : Val(Val) {}
138   Value *codegen() override;
139 };
140 
141 /// VariableExprAST - Expression class for referencing a variable, like "a".
142 class VariableExprAST : public ExprAST {
143   std::string Name;
144 
145 public:
VariableExprAST(const std::string & Name)146   VariableExprAST(const std::string &Name) : Name(Name) {}
147   Value *codegen() override;
148 };
149 
150 /// BinaryExprAST - Expression class for a binary operator.
151 class BinaryExprAST : public ExprAST {
152   char Op;
153   std::unique_ptr<ExprAST> LHS, RHS;
154 
155 public:
BinaryExprAST(char Op,std::unique_ptr<ExprAST> LHS,std::unique_ptr<ExprAST> RHS)156   BinaryExprAST(char Op, std::unique_ptr<ExprAST> LHS,
157                 std::unique_ptr<ExprAST> RHS)
158       : Op(Op), LHS(std::move(LHS)), RHS(std::move(RHS)) {}
159   Value *codegen() override;
160 };
161 
162 /// CallExprAST - Expression class for function calls.
163 class CallExprAST : public ExprAST {
164   std::string Callee;
165   std::vector<std::unique_ptr<ExprAST>> Args;
166 
167 public:
CallExprAST(const std::string & Callee,std::vector<std::unique_ptr<ExprAST>> Args)168   CallExprAST(const std::string &Callee,
169               std::vector<std::unique_ptr<ExprAST>> Args)
170       : Callee(Callee), Args(std::move(Args)) {}
171   Value *codegen() override;
172 };
173 
174 /// IfExprAST - Expression class for if/then/else.
175 class IfExprAST : public ExprAST {
176   std::unique_ptr<ExprAST> Cond, Then, Else;
177 
178 public:
IfExprAST(std::unique_ptr<ExprAST> Cond,std::unique_ptr<ExprAST> Then,std::unique_ptr<ExprAST> Else)179   IfExprAST(std::unique_ptr<ExprAST> Cond, std::unique_ptr<ExprAST> Then,
180             std::unique_ptr<ExprAST> Else)
181       : Cond(std::move(Cond)), Then(std::move(Then)), Else(std::move(Else)) {}
182   Value *codegen() override;
183 };
184 
185 /// ForExprAST - Expression class for for/in.
186 class ForExprAST : public ExprAST {
187   std::string VarName;
188   std::unique_ptr<ExprAST> Start, End, Step, Body;
189 
190 public:
ForExprAST(const std::string & VarName,std::unique_ptr<ExprAST> Start,std::unique_ptr<ExprAST> End,std::unique_ptr<ExprAST> Step,std::unique_ptr<ExprAST> Body)191   ForExprAST(const std::string &VarName, std::unique_ptr<ExprAST> Start,
192              std::unique_ptr<ExprAST> End, std::unique_ptr<ExprAST> Step,
193              std::unique_ptr<ExprAST> Body)
194       : VarName(VarName), Start(std::move(Start)), End(std::move(End)),
195         Step(std::move(Step)), Body(std::move(Body)) {}
196   Value *codegen() override;
197 };
198 
199 /// PrototypeAST - This class represents the "prototype" for a function,
200 /// which captures its name, and its argument names (thus implicitly the number
201 /// of arguments the function takes).
202 class PrototypeAST {
203   std::string Name;
204   std::vector<std::string> Args;
205 
206 public:
PrototypeAST(const std::string & Name,std::vector<std::string> Args)207   PrototypeAST(const std::string &Name, std::vector<std::string> Args)
208       : Name(Name), Args(std::move(Args)) {}
209   Function *codegen();
getName() const210   const std::string &getName() const { return Name; }
211 };
212 
213 /// FunctionAST - This class represents a function definition itself.
214 class FunctionAST {
215   std::unique_ptr<PrototypeAST> Proto;
216   std::unique_ptr<ExprAST> Body;
217 
218 public:
FunctionAST(std::unique_ptr<PrototypeAST> Proto,std::unique_ptr<ExprAST> Body)219   FunctionAST(std::unique_ptr<PrototypeAST> Proto,
220               std::unique_ptr<ExprAST> Body)
221       : Proto(std::move(Proto)), Body(std::move(Body)) {}
222   Function *codegen();
223 };
224 } // end anonymous namespace
225 
226 //===----------------------------------------------------------------------===//
227 // Parser
228 //===----------------------------------------------------------------------===//
229 
230 /// CurTok/getNextToken - Provide a simple token buffer.  CurTok is the current
231 /// token the parser is looking at.  getNextToken reads another token from the
232 /// lexer and updates CurTok with its results.
233 static int CurTok;
getNextToken()234 static int getNextToken() { return CurTok = gettok(); }
235 
236 /// BinopPrecedence - This holds the precedence for each binary operator that is
237 /// defined.
238 static std::map<char, int> BinopPrecedence;
239 
240 /// GetTokPrecedence - Get the precedence of the pending binary operator token.
GetTokPrecedence()241 static int GetTokPrecedence() {
242   if (!isascii(CurTok))
243     return -1;
244 
245   // Make sure it's a declared binop.
246   int TokPrec = BinopPrecedence[CurTok];
247   if (TokPrec <= 0)
248     return -1;
249   return TokPrec;
250 }
251 
252 /// LogError* - These are little helper functions for error handling.
LogError(const char * Str)253 std::unique_ptr<ExprAST> LogError(const char *Str) {
254   fprintf(stderr, "Error: %s\n", Str);
255   return nullptr;
256 }
257 
LogErrorP(const char * Str)258 std::unique_ptr<PrototypeAST> LogErrorP(const char *Str) {
259   LogError(Str);
260   return nullptr;
261 }
262 
263 static std::unique_ptr<ExprAST> ParseExpression();
264 
265 /// numberexpr ::= number
ParseNumberExpr()266 static std::unique_ptr<ExprAST> ParseNumberExpr() {
267   auto Result = llvm::make_unique<NumberExprAST>(NumVal);
268   getNextToken(); // consume the number
269   return std::move(Result);
270 }
271 
272 /// parenexpr ::= '(' expression ')'
ParseParenExpr()273 static std::unique_ptr<ExprAST> ParseParenExpr() {
274   getNextToken(); // eat (.
275   auto V = ParseExpression();
276   if (!V)
277     return nullptr;
278 
279   if (CurTok != ')')
280     return LogError("expected ')'");
281   getNextToken(); // eat ).
282   return V;
283 }
284 
285 /// identifierexpr
286 ///   ::= identifier
287 ///   ::= identifier '(' expression* ')'
ParseIdentifierExpr()288 static std::unique_ptr<ExprAST> ParseIdentifierExpr() {
289   std::string IdName = IdentifierStr;
290 
291   getNextToken(); // eat identifier.
292 
293   if (CurTok != '(') // Simple variable ref.
294     return llvm::make_unique<VariableExprAST>(IdName);
295 
296   // Call.
297   getNextToken(); // eat (
298   std::vector<std::unique_ptr<ExprAST>> Args;
299   if (CurTok != ')') {
300     while (true) {
301       if (auto Arg = ParseExpression())
302         Args.push_back(std::move(Arg));
303       else
304         return nullptr;
305 
306       if (CurTok == ')')
307         break;
308 
309       if (CurTok != ',')
310         return LogError("Expected ')' or ',' in argument list");
311       getNextToken();
312     }
313   }
314 
315   // Eat the ')'.
316   getNextToken();
317 
318   return llvm::make_unique<CallExprAST>(IdName, std::move(Args));
319 }
320 
321 /// ifexpr ::= 'if' expression 'then' expression 'else' expression
ParseIfExpr()322 static std::unique_ptr<ExprAST> ParseIfExpr() {
323   getNextToken(); // eat the if.
324 
325   // condition.
326   auto Cond = ParseExpression();
327   if (!Cond)
328     return nullptr;
329 
330   if (CurTok != tok_then)
331     return LogError("expected then");
332   getNextToken(); // eat the then
333 
334   auto Then = ParseExpression();
335   if (!Then)
336     return nullptr;
337 
338   if (CurTok != tok_else)
339     return LogError("expected else");
340 
341   getNextToken();
342 
343   auto Else = ParseExpression();
344   if (!Else)
345     return nullptr;
346 
347   return llvm::make_unique<IfExprAST>(std::move(Cond), std::move(Then),
348                                       std::move(Else));
349 }
350 
351 /// forexpr ::= 'for' identifier '=' expr ',' expr (',' expr)? 'in' expression
ParseForExpr()352 static std::unique_ptr<ExprAST> ParseForExpr() {
353   getNextToken(); // eat the for.
354 
355   if (CurTok != tok_identifier)
356     return LogError("expected identifier after for");
357 
358   std::string IdName = IdentifierStr;
359   getNextToken(); // eat identifier.
360 
361   if (CurTok != '=')
362     return LogError("expected '=' after for");
363   getNextToken(); // eat '='.
364 
365   auto Start = ParseExpression();
366   if (!Start)
367     return nullptr;
368   if (CurTok != ',')
369     return LogError("expected ',' after for start value");
370   getNextToken();
371 
372   auto End = ParseExpression();
373   if (!End)
374     return nullptr;
375 
376   // The step value is optional.
377   std::unique_ptr<ExprAST> Step;
378   if (CurTok == ',') {
379     getNextToken();
380     Step = ParseExpression();
381     if (!Step)
382       return nullptr;
383   }
384 
385   if (CurTok != tok_in)
386     return LogError("expected 'in' after for");
387   getNextToken(); // eat 'in'.
388 
389   auto Body = ParseExpression();
390   if (!Body)
391     return nullptr;
392 
393   return llvm::make_unique<ForExprAST>(IdName, std::move(Start), std::move(End),
394                                        std::move(Step), std::move(Body));
395 }
396 
397 /// primary
398 ///   ::= identifierexpr
399 ///   ::= numberexpr
400 ///   ::= parenexpr
401 ///   ::= ifexpr
402 ///   ::= forexpr
ParsePrimary()403 static std::unique_ptr<ExprAST> ParsePrimary() {
404   switch (CurTok) {
405   default:
406     return LogError("unknown token when expecting an expression");
407   case tok_identifier:
408     return ParseIdentifierExpr();
409   case tok_number:
410     return ParseNumberExpr();
411   case '(':
412     return ParseParenExpr();
413   case tok_if:
414     return ParseIfExpr();
415   case tok_for:
416     return ParseForExpr();
417   }
418 }
419 
420 /// binoprhs
421 ///   ::= ('+' primary)*
ParseBinOpRHS(int ExprPrec,std::unique_ptr<ExprAST> LHS)422 static std::unique_ptr<ExprAST> ParseBinOpRHS(int ExprPrec,
423                                               std::unique_ptr<ExprAST> LHS) {
424   // If this is a binop, find its precedence.
425   while (true) {
426     int TokPrec = GetTokPrecedence();
427 
428     // If this is a binop that binds at least as tightly as the current binop,
429     // consume it, otherwise we are done.
430     if (TokPrec < ExprPrec)
431       return LHS;
432 
433     // Okay, we know this is a binop.
434     int BinOp = CurTok;
435     getNextToken(); // eat binop
436 
437     // Parse the primary expression after the binary operator.
438     auto RHS = ParsePrimary();
439     if (!RHS)
440       return nullptr;
441 
442     // If BinOp binds less tightly with RHS than the operator after RHS, let
443     // the pending operator take RHS as its LHS.
444     int NextPrec = GetTokPrecedence();
445     if (TokPrec < NextPrec) {
446       RHS = ParseBinOpRHS(TokPrec + 1, std::move(RHS));
447       if (!RHS)
448         return nullptr;
449     }
450 
451     // Merge LHS/RHS.
452     LHS =
453         llvm::make_unique<BinaryExprAST>(BinOp, std::move(LHS), std::move(RHS));
454   }
455 }
456 
457 /// expression
458 ///   ::= primary binoprhs
459 ///
ParseExpression()460 static std::unique_ptr<ExprAST> ParseExpression() {
461   auto LHS = ParsePrimary();
462   if (!LHS)
463     return nullptr;
464 
465   return ParseBinOpRHS(0, std::move(LHS));
466 }
467 
468 /// prototype
469 ///   ::= id '(' id* ')'
ParsePrototype()470 static std::unique_ptr<PrototypeAST> ParsePrototype() {
471   if (CurTok != tok_identifier)
472     return LogErrorP("Expected function name in prototype");
473 
474   std::string FnName = IdentifierStr;
475   getNextToken();
476 
477   if (CurTok != '(')
478     return LogErrorP("Expected '(' in prototype");
479 
480   std::vector<std::string> ArgNames;
481   while (getNextToken() == tok_identifier)
482     ArgNames.push_back(IdentifierStr);
483   if (CurTok != ')')
484     return LogErrorP("Expected ')' in prototype");
485 
486   // success.
487   getNextToken(); // eat ')'.
488 
489   return llvm::make_unique<PrototypeAST>(FnName, std::move(ArgNames));
490 }
491 
492 /// definition ::= 'def' prototype expression
ParseDefinition()493 static std::unique_ptr<FunctionAST> ParseDefinition() {
494   getNextToken(); // eat def.
495   auto Proto = ParsePrototype();
496   if (!Proto)
497     return nullptr;
498 
499   if (auto E = ParseExpression())
500     return llvm::make_unique<FunctionAST>(std::move(Proto), std::move(E));
501   return nullptr;
502 }
503 
504 /// toplevelexpr ::= expression
ParseTopLevelExpr()505 static std::unique_ptr<FunctionAST> ParseTopLevelExpr() {
506   if (auto E = ParseExpression()) {
507     // Make an anonymous proto.
508     auto Proto = llvm::make_unique<PrototypeAST>("__anon_expr",
509                                                  std::vector<std::string>());
510     return llvm::make_unique<FunctionAST>(std::move(Proto), std::move(E));
511   }
512   return nullptr;
513 }
514 
515 /// external ::= 'extern' prototype
ParseExtern()516 static std::unique_ptr<PrototypeAST> ParseExtern() {
517   getNextToken(); // eat extern.
518   return ParsePrototype();
519 }
520 
521 //===----------------------------------------------------------------------===//
522 // Code Generation
523 //===----------------------------------------------------------------------===//
524 
// Context passed to the types, constants, basic blocks and modules created
// during code generation.
static LLVMContext TheContext;
// IR builder shared by all codegen routines, which move its insertion point.
// It is constructed from TheContext, so it must stay declared after it
// (statics in one translation unit initialize in declaration order).
static IRBuilder<> Builder(TheContext);
// Module currently receiving generated code. It is handed to the JIT and
// replaced by InitializeModuleAndPassManager.
static std::unique_ptr<Module> TheModule;
// Names visible in the function being generated (its arguments and any
// enclosing for-loop variables), mapped to their IR values.
static std::map<std::string, Value *> NamedValues;
// Function pass pipeline attached to TheModule; run on each completed function.
static std::unique_ptr<legacy::FunctionPassManager> TheFPM;
// JIT to which finished modules are added.
static std::unique_ptr<KaleidoscopeJIT> TheJIT;
// Latest prototype seen for each function name (from 'def' or 'extern'). Used
// by getFunction to re-declare functions in a fresh module.
static std::map<std::string, std::unique_ptr<PrototypeAST>> FunctionProtos;
532 
LogErrorV(const char * Str)533 Value *LogErrorV(const char *Str) {
534   LogError(Str);
535   return nullptr;
536 }
537 
getFunction(std::string Name)538 Function *getFunction(std::string Name) {
539   // First, see if the function has already been added to the current module.
540   if (auto *F = TheModule->getFunction(Name))
541     return F;
542 
543   // If not, check whether we can codegen the declaration from some existing
544   // prototype.
545   auto FI = FunctionProtos.find(Name);
546   if (FI != FunctionProtos.end())
547     return FI->second->codegen();
548 
549   // If no existing prototype exists, return null.
550   return nullptr;
551 }
552 
codegen()553 Value *NumberExprAST::codegen() {
554   return ConstantFP::get(TheContext, APFloat(Val));
555 }
556 
codegen()557 Value *VariableExprAST::codegen() {
558   // Look this variable up in the function.
559   Value *V = NamedValues[Name];
560   if (!V)
561     return LogErrorV("Unknown variable name");
562   return V;
563 }
564 
codegen()565 Value *BinaryExprAST::codegen() {
566   Value *L = LHS->codegen();
567   Value *R = RHS->codegen();
568   if (!L || !R)
569     return nullptr;
570 
571   switch (Op) {
572   case '+':
573     return Builder.CreateFAdd(L, R, "addtmp");
574   case '-':
575     return Builder.CreateFSub(L, R, "subtmp");
576   case '*':
577     return Builder.CreateFMul(L, R, "multmp");
578   case '<':
579     L = Builder.CreateFCmpULT(L, R, "cmptmp");
580     // Convert bool 0/1 to double 0.0 or 1.0
581     return Builder.CreateUIToFP(L, Type::getDoubleTy(TheContext), "booltmp");
582   default:
583     return LogErrorV("invalid binary operator");
584   }
585 }
586 
codegen()587 Value *CallExprAST::codegen() {
588   // Look up the name in the global module table.
589   Function *CalleeF = getFunction(Callee);
590   if (!CalleeF)
591     return LogErrorV("Unknown function referenced");
592 
593   // If argument mismatch error.
594   if (CalleeF->arg_size() != Args.size())
595     return LogErrorV("Incorrect # arguments passed");
596 
597   std::vector<Value *> ArgsV;
598   for (unsigned i = 0, e = Args.size(); i != e; ++i) {
599     ArgsV.push_back(Args[i]->codegen());
600     if (!ArgsV.back())
601       return nullptr;
602   }
603 
604   return Builder.CreateCall(CalleeF, ArgsV, "calltmp");
605 }
606 
codegen()607 Value *IfExprAST::codegen() {
608   Value *CondV = Cond->codegen();
609   if (!CondV)
610     return nullptr;
611 
612   // Convert condition to a bool by comparing equal to 0.0.
613   CondV = Builder.CreateFCmpONE(
614       CondV, ConstantFP::get(TheContext, APFloat(0.0)), "ifcond");
615 
616   Function *TheFunction = Builder.GetInsertBlock()->getParent();
617 
618   // Create blocks for the then and else cases.  Insert the 'then' block at the
619   // end of the function.
620   BasicBlock *ThenBB = BasicBlock::Create(TheContext, "then", TheFunction);
621   BasicBlock *ElseBB = BasicBlock::Create(TheContext, "else");
622   BasicBlock *MergeBB = BasicBlock::Create(TheContext, "ifcont");
623 
624   Builder.CreateCondBr(CondV, ThenBB, ElseBB);
625 
626   // Emit then value.
627   Builder.SetInsertPoint(ThenBB);
628 
629   Value *ThenV = Then->codegen();
630   if (!ThenV)
631     return nullptr;
632 
633   Builder.CreateBr(MergeBB);
634   // Codegen of 'Then' can change the current block, update ThenBB for the PHI.
635   ThenBB = Builder.GetInsertBlock();
636 
637   // Emit else block.
638   TheFunction->getBasicBlockList().push_back(ElseBB);
639   Builder.SetInsertPoint(ElseBB);
640 
641   Value *ElseV = Else->codegen();
642   if (!ElseV)
643     return nullptr;
644 
645   Builder.CreateBr(MergeBB);
646   // Codegen of 'Else' can change the current block, update ElseBB for the PHI.
647   ElseBB = Builder.GetInsertBlock();
648 
649   // Emit merge block.
650   TheFunction->getBasicBlockList().push_back(MergeBB);
651   Builder.SetInsertPoint(MergeBB);
652   PHINode *PN = Builder.CreatePHI(Type::getDoubleTy(TheContext), 2, "iftmp");
653 
654   PN->addIncoming(ThenV, ThenBB);
655   PN->addIncoming(ElseV, ElseBB);
656   return PN;
657 }
658 
659 // Output for-loop as:
660 //   ...
661 //   start = startexpr
662 //   goto loop
663 // loop:
664 //   variable = phi [start, loopheader], [nextvariable, loopend]
665 //   ...
666 //   bodyexpr
667 //   ...
668 // loopend:
669 //   step = stepexpr
670 //   nextvariable = variable + step
671 //   endcond = endexpr
672 //   br endcond, loop, endloop
673 // outloop:
codegen()674 Value *ForExprAST::codegen() {
675   // Emit the start code first, without 'variable' in scope.
676   Value *StartVal = Start->codegen();
677   if (!StartVal)
678     return nullptr;
679 
680   // Make the new basic block for the loop header, inserting after current
681   // block.
682   Function *TheFunction = Builder.GetInsertBlock()->getParent();
683   BasicBlock *PreheaderBB = Builder.GetInsertBlock();
684   BasicBlock *LoopBB = BasicBlock::Create(TheContext, "loop", TheFunction);
685 
686   // Insert an explicit fall through from the current block to the LoopBB.
687   Builder.CreateBr(LoopBB);
688 
689   // Start insertion in LoopBB.
690   Builder.SetInsertPoint(LoopBB);
691 
692   // Start the PHI node with an entry for Start.
693   PHINode *Variable =
694       Builder.CreatePHI(Type::getDoubleTy(TheContext), 2, VarName);
695   Variable->addIncoming(StartVal, PreheaderBB);
696 
697   // Within the loop, the variable is defined equal to the PHI node.  If it
698   // shadows an existing variable, we have to restore it, so save it now.
699   Value *OldVal = NamedValues[VarName];
700   NamedValues[VarName] = Variable;
701 
702   // Emit the body of the loop.  This, like any other expr, can change the
703   // current BB.  Note that we ignore the value computed by the body, but don't
704   // allow an error.
705   if (!Body->codegen())
706     return nullptr;
707 
708   // Emit the step value.
709   Value *StepVal = nullptr;
710   if (Step) {
711     StepVal = Step->codegen();
712     if (!StepVal)
713       return nullptr;
714   } else {
715     // If not specified, use 1.0.
716     StepVal = ConstantFP::get(TheContext, APFloat(1.0));
717   }
718 
719   Value *NextVar = Builder.CreateFAdd(Variable, StepVal, "nextvar");
720 
721   // Compute the end condition.
722   Value *EndCond = End->codegen();
723   if (!EndCond)
724     return nullptr;
725 
726   // Convert condition to a bool by comparing equal to 0.0.
727   EndCond = Builder.CreateFCmpONE(
728       EndCond, ConstantFP::get(TheContext, APFloat(0.0)), "loopcond");
729 
730   // Create the "after loop" block and insert it.
731   BasicBlock *LoopEndBB = Builder.GetInsertBlock();
732   BasicBlock *AfterBB =
733       BasicBlock::Create(TheContext, "afterloop", TheFunction);
734 
735   // Insert the conditional branch into the end of LoopEndBB.
736   Builder.CreateCondBr(EndCond, LoopBB, AfterBB);
737 
738   // Any new code will be inserted in AfterBB.
739   Builder.SetInsertPoint(AfterBB);
740 
741   // Add a new entry to the PHI node for the backedge.
742   Variable->addIncoming(NextVar, LoopEndBB);
743 
744   // Restore the unshadowed variable.
745   if (OldVal)
746     NamedValues[VarName] = OldVal;
747   else
748     NamedValues.erase(VarName);
749 
750   // for expr always returns 0.0.
751   return Constant::getNullValue(Type::getDoubleTy(TheContext));
752 }
753 
codegen()754 Function *PrototypeAST::codegen() {
755   // Make the function type:  double(double,double) etc.
756   std::vector<Type *> Doubles(Args.size(), Type::getDoubleTy(TheContext));
757   FunctionType *FT =
758       FunctionType::get(Type::getDoubleTy(TheContext), Doubles, false);
759 
760   Function *F =
761       Function::Create(FT, Function::ExternalLinkage, Name, TheModule.get());
762 
763   // Set names for all arguments.
764   unsigned Idx = 0;
765   for (auto &Arg : F->args())
766     Arg.setName(Args[Idx++]);
767 
768   return F;
769 }
770 
codegen()771 Function *FunctionAST::codegen() {
772   // Transfer ownership of the prototype to the FunctionProtos map, but keep a
773   // reference to it for use below.
774   auto &P = *Proto;
775   FunctionProtos[Proto->getName()] = std::move(Proto);
776   Function *TheFunction = getFunction(P.getName());
777   if (!TheFunction)
778     return nullptr;
779 
780   // Create a new basic block to start insertion into.
781   BasicBlock *BB = BasicBlock::Create(TheContext, "entry", TheFunction);
782   Builder.SetInsertPoint(BB);
783 
784   // Record the function arguments in the NamedValues map.
785   NamedValues.clear();
786   for (auto &Arg : TheFunction->args())
787     NamedValues[Arg.getName()] = &Arg;
788 
789   if (Value *RetVal = Body->codegen()) {
790     // Finish off the function.
791     Builder.CreateRet(RetVal);
792 
793     // Validate the generated code, checking for consistency.
794     verifyFunction(*TheFunction);
795 
796     // Run the optimizer on the function.
797     TheFPM->run(*TheFunction);
798 
799     return TheFunction;
800   }
801 
802   // Error reading body, remove function.
803   TheFunction->eraseFromParent();
804   return nullptr;
805 }
806 
807 //===----------------------------------------------------------------------===//
808 // Top-Level parsing and JIT Driver
809 //===----------------------------------------------------------------------===//
810 
InitializeModuleAndPassManager()811 static void InitializeModuleAndPassManager() {
812   // Open a new module.
813   TheModule = llvm::make_unique<Module>("my cool jit", TheContext);
814   TheModule->setDataLayout(TheJIT->getTargetMachine().createDataLayout());
815 
816   // Create a new pass manager attached to it.
817   TheFPM = llvm::make_unique<legacy::FunctionPassManager>(TheModule.get());
818 
819   // Do simple "peephole" optimizations and bit-twiddling optzns.
820   TheFPM->add(createInstructionCombiningPass());
821   // Reassociate expressions.
822   TheFPM->add(createReassociatePass());
823   // Eliminate Common SubExpressions.
824   TheFPM->add(createGVNPass());
825   // Simplify the control flow graph (deleting unreachable blocks, etc).
826   TheFPM->add(createCFGSimplificationPass());
827 
828   TheFPM->doInitialization();
829 }
830 
HandleDefinition()831 static void HandleDefinition() {
832   if (auto FnAST = ParseDefinition()) {
833     if (auto *FnIR = FnAST->codegen()) {
834       fprintf(stderr, "Read function definition:");
835       FnIR->dump();
836       TheJIT->addModule(std::move(TheModule));
837       InitializeModuleAndPassManager();
838     }
839   } else {
840     // Skip token for error recovery.
841     getNextToken();
842   }
843 }
844 
HandleExtern()845 static void HandleExtern() {
846   if (auto ProtoAST = ParseExtern()) {
847     if (auto *FnIR = ProtoAST->codegen()) {
848       fprintf(stderr, "Read extern: ");
849       FnIR->dump();
850       FunctionProtos[ProtoAST->getName()] = std::move(ProtoAST);
851     }
852   } else {
853     // Skip token for error recovery.
854     getNextToken();
855   }
856 }
857 
HandleTopLevelExpression()858 static void HandleTopLevelExpression() {
859   // Evaluate a top-level expression into an anonymous function.
860   if (auto FnAST = ParseTopLevelExpr()) {
861     if (FnAST->codegen()) {
862       // JIT the module containing the anonymous expression, keeping a handle so
863       // we can free it later.
864       auto H = TheJIT->addModule(std::move(TheModule));
865       InitializeModuleAndPassManager();
866 
867       // Search the JIT for the __anon_expr symbol.
868       auto ExprSymbol = TheJIT->findSymbol("__anon_expr");
869       assert(ExprSymbol && "Function not found");
870 
871       // Get the symbol's address and cast it to the right type (takes no
872       // arguments, returns a double) so we can call it as a native function.
873       double (*FP)() = (double (*)())(intptr_t)ExprSymbol.getAddress();
874       fprintf(stderr, "Evaluated to %f\n", FP());
875 
876       // Delete the anonymous expression module from the JIT.
877       TheJIT->removeModule(H);
878     }
879   } else {
880     // Skip token for error recovery.
881     getNextToken();
882   }
883 }
884 
885 /// top ::= definition | external | expression | ';'
MainLoop()886 static void MainLoop() {
887   while (true) {
888     fprintf(stderr, "ready> ");
889     switch (CurTok) {
890     case tok_eof:
891       return;
892     case ';': // ignore top-level semicolons.
893       getNextToken();
894       break;
895     case tok_def:
896       HandleDefinition();
897       break;
898     case tok_extern:
899       HandleExtern();
900       break;
901     default:
902       HandleTopLevelExpression();
903       break;
904     }
905   }
906 }
907 
908 //===----------------------------------------------------------------------===//
909 // "Library" functions that can be "extern'd" from user code.
910 //===----------------------------------------------------------------------===//
911 
/// putchard - putchar that takes a double and returns 0.
///
/// The value is truncated toward zero and written to stderr as one byte (as
/// fputc does with any int). Values outside (-129, 256), including NaN and
/// infinities, produce no output: converting such a double to an integer type
/// is undefined behaviour, which the original `(char)X` cast triggered.
extern "C" double putchard(double X) {
  // Written as a negated range test so that NaN (which fails every
  // comparison) is also rejected.
  if (!(X > -129.0 && X < 256.0))
    return 0;
  fputc(static_cast<int>(X), stderr);
  return 0;
}
917 
/// printd - Write X to stderr using the "%f\n" format; always yields 0.
extern "C" double printd(double X) {
  std::fprintf(stderr, "%f\n", X);
  return 0.0;
}
923 
924 //===----------------------------------------------------------------------===//
925 // Main driver code.
926 //===----------------------------------------------------------------------===//
927 
main()928 int main() {
929   InitializeNativeTarget();
930   InitializeNativeTargetAsmPrinter();
931   InitializeNativeTargetAsmParser();
932 
933   // Install standard binary operators.
934   // 1 is lowest precedence.
935   BinopPrecedence['<'] = 10;
936   BinopPrecedence['+'] = 20;
937   BinopPrecedence['-'] = 20;
938   BinopPrecedence['*'] = 40; // highest.
939 
940   // Prime the first token.
941   fprintf(stderr, "ready> ");
942   getNextToken();
943 
944   TheJIT = llvm::make_unique<KaleidoscopeJIT>();
945 
946   InitializeModuleAndPassManager();
947 
948   // Run the main "interpreter loop" now.
949   MainLoop();
950 
951   return 0;
952 }
953