1 #include "llvm/ADT/APFloat.h"
2 #include "llvm/ADT/STLExtras.h"
3 #include "llvm/IR/BasicBlock.h"
4 #include "llvm/IR/Constants.h"
5 #include "llvm/IR/DerivedTypes.h"
6 #include "llvm/IR/Function.h"
7 #include "llvm/IR/Instructions.h"
8 #include "llvm/IR/IRBuilder.h"
9 #include "llvm/IR/LLVMContext.h"
10 #include "llvm/IR/LegacyPassManager.h"
11 #include "llvm/IR/Module.h"
12 #include "llvm/IR/Type.h"
13 #include "llvm/IR/Verifier.h"
14 #include "llvm/Support/TargetSelect.h"
15 #include "llvm/Target/TargetMachine.h"
16 #include "llvm/Transforms/Scalar.h"
17 #include "llvm/Transforms/Scalar/GVN.h"
18 #include "../include/KaleidoscopeJIT.h"
19 #include <cassert>
20 #include <cctype>
21 #include <cstdint>
22 #include <cstdio>
23 #include <cstdlib>
24 #include <map>
25 #include <memory>
26 #include <string>
27 #include <vector>
28 
29 using namespace llvm;
30 using namespace llvm::orc;
31 
32 //===----------------------------------------------------------------------===//
33 // Lexer
34 //===----------------------------------------------------------------------===//
35 
// Token codes produced by the lexer.  A character the lexer does not
// recognise is returned as its own value in [0-255]; every recognised token
// kind is one of the negative enumerators below.
enum Token {
  // End of input.
  tok_eof = -1,

  // Top-level command keywords.
  tok_def = -2,
  tok_extern = -3,

  // Primary expression tokens.
  tok_identifier = -4,
  tok_number = -5,

  // Control-flow keywords.
  tok_if = -6,
  tok_then = -7,
  tok_else = -8,
  tok_for = -9,
  tok_in = -10,

  // Keywords that introduce user-defined operators.
  tok_binary = -11,
  tok_unary = -12,
};
60 
// Lexer outputs shared with the parser: gettok() writes these, and the parse
// routines read them while the corresponding token is still the current one.
static std::string IdentifierStr; // Filled in if tok_identifier
static double NumVal;             // Filled in if tok_number
63 
64 /// gettok - Return the next token from standard input.
gettok()65 static int gettok() {
66   static int LastChar = ' ';
67 
68   // Skip any whitespace.
69   while (isspace(LastChar))
70     LastChar = getchar();
71 
72   if (isalpha(LastChar)) { // identifier: [a-zA-Z][a-zA-Z0-9]*
73     IdentifierStr = LastChar;
74     while (isalnum((LastChar = getchar())))
75       IdentifierStr += LastChar;
76 
77     if (IdentifierStr == "def")
78       return tok_def;
79     if (IdentifierStr == "extern")
80       return tok_extern;
81     if (IdentifierStr == "if")
82       return tok_if;
83     if (IdentifierStr == "then")
84       return tok_then;
85     if (IdentifierStr == "else")
86       return tok_else;
87     if (IdentifierStr == "for")
88       return tok_for;
89     if (IdentifierStr == "in")
90       return tok_in;
91     if (IdentifierStr == "binary")
92       return tok_binary;
93     if (IdentifierStr == "unary")
94       return tok_unary;
95     return tok_identifier;
96   }
97 
98   if (isdigit(LastChar) || LastChar == '.') { // Number: [0-9.]+
99     std::string NumStr;
100     do {
101       NumStr += LastChar;
102       LastChar = getchar();
103     } while (isdigit(LastChar) || LastChar == '.');
104 
105     NumVal = strtod(NumStr.c_str(), nullptr);
106     return tok_number;
107   }
108 
109   if (LastChar == '#') {
110     // Comment until end of line.
111     do
112       LastChar = getchar();
113     while (LastChar != EOF && LastChar != '\n' && LastChar != '\r');
114 
115     if (LastChar != EOF)
116       return gettok();
117   }
118 
119   // Check for end of file.  Don't eat the EOF.
120   if (LastChar == EOF)
121     return tok_eof;
122 
123   // Otherwise, just return the character as its ascii value.
124   int ThisChar = LastChar;
125   LastChar = getchar();
126   return ThisChar;
127 }
128 
129 //===----------------------------------------------------------------------===//
130 // Abstract Syntax Tree (aka Parse Tree)
131 //===----------------------------------------------------------------------===//
132 namespace {
133 /// ExprAST - Base class for all expression nodes.
134 class ExprAST {
135 public:
~ExprAST()136   virtual ~ExprAST() {}
137   virtual Value *codegen() = 0;
138 };
139 
140 /// NumberExprAST - Expression class for numeric literals like "1.0".
141 class NumberExprAST : public ExprAST {
142   double Val;
143 
144 public:
NumberExprAST(double Val)145   NumberExprAST(double Val) : Val(Val) {}
146   Value *codegen() override;
147 };
148 
149 /// VariableExprAST - Expression class for referencing a variable, like "a".
150 class VariableExprAST : public ExprAST {
151   std::string Name;
152 
153 public:
VariableExprAST(const std::string & Name)154   VariableExprAST(const std::string &Name) : Name(Name) {}
155   Value *codegen() override;
156 };
157 
158 /// UnaryExprAST - Expression class for a unary operator.
159 class UnaryExprAST : public ExprAST {
160   char Opcode;
161   std::unique_ptr<ExprAST> Operand;
162 
163 public:
UnaryExprAST(char Opcode,std::unique_ptr<ExprAST> Operand)164   UnaryExprAST(char Opcode, std::unique_ptr<ExprAST> Operand)
165       : Opcode(Opcode), Operand(std::move(Operand)) {}
166   Value *codegen() override;
167 };
168 
169 /// BinaryExprAST - Expression class for a binary operator.
170 class BinaryExprAST : public ExprAST {
171   char Op;
172   std::unique_ptr<ExprAST> LHS, RHS;
173 
174 public:
BinaryExprAST(char Op,std::unique_ptr<ExprAST> LHS,std::unique_ptr<ExprAST> RHS)175   BinaryExprAST(char Op, std::unique_ptr<ExprAST> LHS,
176                 std::unique_ptr<ExprAST> RHS)
177       : Op(Op), LHS(std::move(LHS)), RHS(std::move(RHS)) {}
178   Value *codegen() override;
179 };
180 
181 /// CallExprAST - Expression class for function calls.
182 class CallExprAST : public ExprAST {
183   std::string Callee;
184   std::vector<std::unique_ptr<ExprAST>> Args;
185 
186 public:
CallExprAST(const std::string & Callee,std::vector<std::unique_ptr<ExprAST>> Args)187   CallExprAST(const std::string &Callee,
188               std::vector<std::unique_ptr<ExprAST>> Args)
189       : Callee(Callee), Args(std::move(Args)) {}
190   Value *codegen() override;
191 };
192 
193 /// IfExprAST - Expression class for if/then/else.
194 class IfExprAST : public ExprAST {
195   std::unique_ptr<ExprAST> Cond, Then, Else;
196 
197 public:
IfExprAST(std::unique_ptr<ExprAST> Cond,std::unique_ptr<ExprAST> Then,std::unique_ptr<ExprAST> Else)198   IfExprAST(std::unique_ptr<ExprAST> Cond, std::unique_ptr<ExprAST> Then,
199             std::unique_ptr<ExprAST> Else)
200       : Cond(std::move(Cond)), Then(std::move(Then)), Else(std::move(Else)) {}
201   Value *codegen() override;
202 };
203 
204 /// ForExprAST - Expression class for for/in.
205 class ForExprAST : public ExprAST {
206   std::string VarName;
207   std::unique_ptr<ExprAST> Start, End, Step, Body;
208 
209 public:
ForExprAST(const std::string & VarName,std::unique_ptr<ExprAST> Start,std::unique_ptr<ExprAST> End,std::unique_ptr<ExprAST> Step,std::unique_ptr<ExprAST> Body)210   ForExprAST(const std::string &VarName, std::unique_ptr<ExprAST> Start,
211              std::unique_ptr<ExprAST> End, std::unique_ptr<ExprAST> Step,
212              std::unique_ptr<ExprAST> Body)
213       : VarName(VarName), Start(std::move(Start)), End(std::move(End)),
214         Step(std::move(Step)), Body(std::move(Body)) {}
215   Value *codegen() override;
216 };
217 
218 /// PrototypeAST - This class represents the "prototype" for a function,
219 /// which captures its name, and its argument names (thus implicitly the number
220 /// of arguments the function takes), as well as if it is an operator.
221 class PrototypeAST {
222   std::string Name;
223   std::vector<std::string> Args;
224   bool IsOperator;
225   unsigned Precedence; // Precedence if a binary op.
226 
227 public:
PrototypeAST(const std::string & Name,std::vector<std::string> Args,bool IsOperator=false,unsigned Prec=0)228   PrototypeAST(const std::string &Name, std::vector<std::string> Args,
229                bool IsOperator = false, unsigned Prec = 0)
230       : Name(Name), Args(std::move(Args)), IsOperator(IsOperator),
231         Precedence(Prec) {}
232   Function *codegen();
getName() const233   const std::string &getName() const { return Name; }
234 
isUnaryOp() const235   bool isUnaryOp() const { return IsOperator && Args.size() == 1; }
isBinaryOp() const236   bool isBinaryOp() const { return IsOperator && Args.size() == 2; }
237 
getOperatorName() const238   char getOperatorName() const {
239     assert(isUnaryOp() || isBinaryOp());
240     return Name[Name.size() - 1];
241   }
242 
getBinaryPrecedence() const243   unsigned getBinaryPrecedence() const { return Precedence; }
244 };
245 
246 /// FunctionAST - This class represents a function definition itself.
247 class FunctionAST {
248   std::unique_ptr<PrototypeAST> Proto;
249   std::unique_ptr<ExprAST> Body;
250 
251 public:
FunctionAST(std::unique_ptr<PrototypeAST> Proto,std::unique_ptr<ExprAST> Body)252   FunctionAST(std::unique_ptr<PrototypeAST> Proto,
253               std::unique_ptr<ExprAST> Body)
254       : Proto(std::move(Proto)), Body(std::move(Body)) {}
255   Function *codegen();
256 };
257 } // end anonymous namespace
258 
259 //===----------------------------------------------------------------------===//
260 // Parser
261 //===----------------------------------------------------------------------===//
262 
263 /// CurTok/getNextToken - Provide a simple token buffer.  CurTok is the current
264 /// token the parser is looking at.  getNextToken reads another token from the
265 /// lexer and updates CurTok with its results.
266 static int CurTok;
getNextToken()267 static int getNextToken() { return CurTok = gettok(); }
268 
/// BinopPrecedence - Maps each defined binary operator character to its
/// precedence.  main() installs the built-in operators, and
/// FunctionAST::codegen() adds an entry for each user-defined binary operator.
static std::map<char, int> BinopPrecedence;
272 
273 /// GetTokPrecedence - Get the precedence of the pending binary operator token.
GetTokPrecedence()274 static int GetTokPrecedence() {
275   if (!isascii(CurTok))
276     return -1;
277 
278   // Make sure it's a declared binop.
279   int TokPrec = BinopPrecedence[CurTok];
280   if (TokPrec <= 0)
281     return -1;
282   return TokPrec;
283 }
284 
285 /// Error* - These are little helper functions for error handling.
LogError(const char * Str)286 std::unique_ptr<ExprAST> LogError(const char *Str) {
287   fprintf(stderr, "Error: %s\n", Str);
288   return nullptr;
289 }
290 
LogErrorP(const char * Str)291 std::unique_ptr<PrototypeAST> LogErrorP(const char *Str) {
292   LogError(Str);
293   return nullptr;
294 }
295 
// Forward declaration: the primary-expression parsers below call
// ParseExpression(), which is defined after them.
static std::unique_ptr<ExprAST> ParseExpression();
297 
298 /// numberexpr ::= number
ParseNumberExpr()299 static std::unique_ptr<ExprAST> ParseNumberExpr() {
300   auto Result = llvm::make_unique<NumberExprAST>(NumVal);
301   getNextToken(); // consume the number
302   return std::move(Result);
303 }
304 
305 /// parenexpr ::= '(' expression ')'
ParseParenExpr()306 static std::unique_ptr<ExprAST> ParseParenExpr() {
307   getNextToken(); // eat (.
308   auto V = ParseExpression();
309   if (!V)
310     return nullptr;
311 
312   if (CurTok != ')')
313     return LogError("expected ')'");
314   getNextToken(); // eat ).
315   return V;
316 }
317 
318 /// identifierexpr
319 ///   ::= identifier
320 ///   ::= identifier '(' expression* ')'
ParseIdentifierExpr()321 static std::unique_ptr<ExprAST> ParseIdentifierExpr() {
322   std::string IdName = IdentifierStr;
323 
324   getNextToken(); // eat identifier.
325 
326   if (CurTok != '(') // Simple variable ref.
327     return llvm::make_unique<VariableExprAST>(IdName);
328 
329   // Call.
330   getNextToken(); // eat (
331   std::vector<std::unique_ptr<ExprAST>> Args;
332   if (CurTok != ')') {
333     while (true) {
334       if (auto Arg = ParseExpression())
335         Args.push_back(std::move(Arg));
336       else
337         return nullptr;
338 
339       if (CurTok == ')')
340         break;
341 
342       if (CurTok != ',')
343         return LogError("Expected ')' or ',' in argument list");
344       getNextToken();
345     }
346   }
347 
348   // Eat the ')'.
349   getNextToken();
350 
351   return llvm::make_unique<CallExprAST>(IdName, std::move(Args));
352 }
353 
354 /// ifexpr ::= 'if' expression 'then' expression 'else' expression
ParseIfExpr()355 static std::unique_ptr<ExprAST> ParseIfExpr() {
356   getNextToken(); // eat the if.
357 
358   // condition.
359   auto Cond = ParseExpression();
360   if (!Cond)
361     return nullptr;
362 
363   if (CurTok != tok_then)
364     return LogError("expected then");
365   getNextToken(); // eat the then
366 
367   auto Then = ParseExpression();
368   if (!Then)
369     return nullptr;
370 
371   if (CurTok != tok_else)
372     return LogError("expected else");
373 
374   getNextToken();
375 
376   auto Else = ParseExpression();
377   if (!Else)
378     return nullptr;
379 
380   return llvm::make_unique<IfExprAST>(std::move(Cond), std::move(Then),
381                                       std::move(Else));
382 }
383 
384 /// forexpr ::= 'for' identifier '=' expr ',' expr (',' expr)? 'in' expression
ParseForExpr()385 static std::unique_ptr<ExprAST> ParseForExpr() {
386   getNextToken(); // eat the for.
387 
388   if (CurTok != tok_identifier)
389     return LogError("expected identifier after for");
390 
391   std::string IdName = IdentifierStr;
392   getNextToken(); // eat identifier.
393 
394   if (CurTok != '=')
395     return LogError("expected '=' after for");
396   getNextToken(); // eat '='.
397 
398   auto Start = ParseExpression();
399   if (!Start)
400     return nullptr;
401   if (CurTok != ',')
402     return LogError("expected ',' after for start value");
403   getNextToken();
404 
405   auto End = ParseExpression();
406   if (!End)
407     return nullptr;
408 
409   // The step value is optional.
410   std::unique_ptr<ExprAST> Step;
411   if (CurTok == ',') {
412     getNextToken();
413     Step = ParseExpression();
414     if (!Step)
415       return nullptr;
416   }
417 
418   if (CurTok != tok_in)
419     return LogError("expected 'in' after for");
420   getNextToken(); // eat 'in'.
421 
422   auto Body = ParseExpression();
423   if (!Body)
424     return nullptr;
425 
426   return llvm::make_unique<ForExprAST>(IdName, std::move(Start), std::move(End),
427                                        std::move(Step), std::move(Body));
428 }
429 
430 /// primary
431 ///   ::= identifierexpr
432 ///   ::= numberexpr
433 ///   ::= parenexpr
434 ///   ::= ifexpr
435 ///   ::= forexpr
ParsePrimary()436 static std::unique_ptr<ExprAST> ParsePrimary() {
437   switch (CurTok) {
438   default:
439     return LogError("unknown token when expecting an expression");
440   case tok_identifier:
441     return ParseIdentifierExpr();
442   case tok_number:
443     return ParseNumberExpr();
444   case '(':
445     return ParseParenExpr();
446   case tok_if:
447     return ParseIfExpr();
448   case tok_for:
449     return ParseForExpr();
450   }
451 }
452 
453 /// unary
454 ///   ::= primary
455 ///   ::= '!' unary
ParseUnary()456 static std::unique_ptr<ExprAST> ParseUnary() {
457   // If the current token is not an operator, it must be a primary expr.
458   if (!isascii(CurTok) || CurTok == '(' || CurTok == ',')
459     return ParsePrimary();
460 
461   // If this is a unary operator, read it.
462   int Opc = CurTok;
463   getNextToken();
464   if (auto Operand = ParseUnary())
465     return llvm::make_unique<UnaryExprAST>(Opc, std::move(Operand));
466   return nullptr;
467 }
468 
469 /// binoprhs
470 ///   ::= ('+' unary)*
ParseBinOpRHS(int ExprPrec,std::unique_ptr<ExprAST> LHS)471 static std::unique_ptr<ExprAST> ParseBinOpRHS(int ExprPrec,
472                                               std::unique_ptr<ExprAST> LHS) {
473   // If this is a binop, find its precedence.
474   while (true) {
475     int TokPrec = GetTokPrecedence();
476 
477     // If this is a binop that binds at least as tightly as the current binop,
478     // consume it, otherwise we are done.
479     if (TokPrec < ExprPrec)
480       return LHS;
481 
482     // Okay, we know this is a binop.
483     int BinOp = CurTok;
484     getNextToken(); // eat binop
485 
486     // Parse the unary expression after the binary operator.
487     auto RHS = ParseUnary();
488     if (!RHS)
489       return nullptr;
490 
491     // If BinOp binds less tightly with RHS than the operator after RHS, let
492     // the pending operator take RHS as its LHS.
493     int NextPrec = GetTokPrecedence();
494     if (TokPrec < NextPrec) {
495       RHS = ParseBinOpRHS(TokPrec + 1, std::move(RHS));
496       if (!RHS)
497         return nullptr;
498     }
499 
500     // Merge LHS/RHS.
501     LHS =
502         llvm::make_unique<BinaryExprAST>(BinOp, std::move(LHS), std::move(RHS));
503   }
504 }
505 
506 /// expression
507 ///   ::= unary binoprhs
508 ///
ParseExpression()509 static std::unique_ptr<ExprAST> ParseExpression() {
510   auto LHS = ParseUnary();
511   if (!LHS)
512     return nullptr;
513 
514   return ParseBinOpRHS(0, std::move(LHS));
515 }
516 
517 /// prototype
518 ///   ::= id '(' id* ')'
519 ///   ::= binary LETTER number? (id, id)
520 ///   ::= unary LETTER (id)
ParsePrototype()521 static std::unique_ptr<PrototypeAST> ParsePrototype() {
522   std::string FnName;
523 
524   unsigned Kind = 0; // 0 = identifier, 1 = unary, 2 = binary.
525   unsigned BinaryPrecedence = 30;
526 
527   switch (CurTok) {
528   default:
529     return LogErrorP("Expected function name in prototype");
530   case tok_identifier:
531     FnName = IdentifierStr;
532     Kind = 0;
533     getNextToken();
534     break;
535   case tok_unary:
536     getNextToken();
537     if (!isascii(CurTok))
538       return LogErrorP("Expected unary operator");
539     FnName = "unary";
540     FnName += (char)CurTok;
541     Kind = 1;
542     getNextToken();
543     break;
544   case tok_binary:
545     getNextToken();
546     if (!isascii(CurTok))
547       return LogErrorP("Expected binary operator");
548     FnName = "binary";
549     FnName += (char)CurTok;
550     Kind = 2;
551     getNextToken();
552 
553     // Read the precedence if present.
554     if (CurTok == tok_number) {
555       if (NumVal < 1 || NumVal > 100)
556         return LogErrorP("Invalid precedecnce: must be 1..100");
557       BinaryPrecedence = (unsigned)NumVal;
558       getNextToken();
559     }
560     break;
561   }
562 
563   if (CurTok != '(')
564     return LogErrorP("Expected '(' in prototype");
565 
566   std::vector<std::string> ArgNames;
567   while (getNextToken() == tok_identifier)
568     ArgNames.push_back(IdentifierStr);
569   if (CurTok != ')')
570     return LogErrorP("Expected ')' in prototype");
571 
572   // success.
573   getNextToken(); // eat ')'.
574 
575   // Verify right number of names for operator.
576   if (Kind && ArgNames.size() != Kind)
577     return LogErrorP("Invalid number of operands for operator");
578 
579   return llvm::make_unique<PrototypeAST>(FnName, ArgNames, Kind != 0,
580                                          BinaryPrecedence);
581 }
582 
583 /// definition ::= 'def' prototype expression
ParseDefinition()584 static std::unique_ptr<FunctionAST> ParseDefinition() {
585   getNextToken(); // eat def.
586   auto Proto = ParsePrototype();
587   if (!Proto)
588     return nullptr;
589 
590   if (auto E = ParseExpression())
591     return llvm::make_unique<FunctionAST>(std::move(Proto), std::move(E));
592   return nullptr;
593 }
594 
595 /// toplevelexpr ::= expression
ParseTopLevelExpr()596 static std::unique_ptr<FunctionAST> ParseTopLevelExpr() {
597   if (auto E = ParseExpression()) {
598     // Make an anonymous proto.
599     auto Proto = llvm::make_unique<PrototypeAST>("__anon_expr",
600                                                  std::vector<std::string>());
601     return llvm::make_unique<FunctionAST>(std::move(Proto), std::move(E));
602   }
603   return nullptr;
604 }
605 
606 /// external ::= 'extern' prototype
ParseExtern()607 static std::unique_ptr<PrototypeAST> ParseExtern() {
608   getNextToken(); // eat extern.
609   return ParsePrototype();
610 }
611 
612 //===----------------------------------------------------------------------===//
613 // Code Generation
614 //===----------------------------------------------------------------------===//
615 
// State shared by every codegen() routine and by the driver below.

// Context handed to every LLVM type, constant, block and module created here.
static LLVMContext TheContext;
// Builder through which all instructions are emitted.
static IRBuilder<> Builder(TheContext);
// Module currently receiving functions; the Handle* routines move it into the
// JIT and then call InitializeModuleAndPassManager() to open a new one.
static std::unique_ptr<Module> TheModule;
// Names visible in the function being generated, mapped to their values.
// FunctionAST::codegen() clears it at the start of each function.
static std::map<std::string, Value *> NamedValues;
// Per-function optimisation pipeline, rebuilt together with TheModule.
static std::unique_ptr<legacy::FunctionPassManager> TheFPM;
static std::unique_ptr<KaleidoscopeJIT> TheJIT;
// Prototype last recorded for each function name; getFunction() uses it to
// re-declare a function that is not present in the current module.
static std::map<std::string, std::unique_ptr<PrototypeAST>> FunctionProtos;
623 
LogErrorV(const char * Str)624 Value *LogErrorV(const char *Str) {
625   LogError(Str);
626   return nullptr;
627 }
628 
getFunction(std::string Name)629 Function *getFunction(std::string Name) {
630   // First, see if the function has already been added to the current module.
631   if (auto *F = TheModule->getFunction(Name))
632     return F;
633 
634   // If not, check whether we can codegen the declaration from some existing
635   // prototype.
636   auto FI = FunctionProtos.find(Name);
637   if (FI != FunctionProtos.end())
638     return FI->second->codegen();
639 
640   // If no existing prototype exists, return null.
641   return nullptr;
642 }
643 
codegen()644 Value *NumberExprAST::codegen() {
645   return ConstantFP::get(TheContext, APFloat(Val));
646 }
647 
codegen()648 Value *VariableExprAST::codegen() {
649   // Look this variable up in the function.
650   Value *V = NamedValues[Name];
651   if (!V)
652     return LogErrorV("Unknown variable name");
653   return V;
654 }
655 
codegen()656 Value *UnaryExprAST::codegen() {
657   Value *OperandV = Operand->codegen();
658   if (!OperandV)
659     return nullptr;
660 
661   Function *F = getFunction(std::string("unary") + Opcode);
662   if (!F)
663     return LogErrorV("Unknown unary operator");
664 
665   return Builder.CreateCall(F, OperandV, "unop");
666 }
667 
codegen()668 Value *BinaryExprAST::codegen() {
669   Value *L = LHS->codegen();
670   Value *R = RHS->codegen();
671   if (!L || !R)
672     return nullptr;
673 
674   switch (Op) {
675   case '+':
676     return Builder.CreateFAdd(L, R, "addtmp");
677   case '-':
678     return Builder.CreateFSub(L, R, "subtmp");
679   case '*':
680     return Builder.CreateFMul(L, R, "multmp");
681   case '<':
682     L = Builder.CreateFCmpULT(L, R, "cmptmp");
683     // Convert bool 0/1 to double 0.0 or 1.0
684     return Builder.CreateUIToFP(L, Type::getDoubleTy(TheContext), "booltmp");
685   default:
686     break;
687   }
688 
689   // If it wasn't a builtin binary operator, it must be a user defined one. Emit
690   // a call to it.
691   Function *F = getFunction(std::string("binary") + Op);
692   assert(F && "binary operator not found!");
693 
694   Value *Ops[] = {L, R};
695   return Builder.CreateCall(F, Ops, "binop");
696 }
697 
codegen()698 Value *CallExprAST::codegen() {
699   // Look up the name in the global module table.
700   Function *CalleeF = getFunction(Callee);
701   if (!CalleeF)
702     return LogErrorV("Unknown function referenced");
703 
704   // If argument mismatch error.
705   if (CalleeF->arg_size() != Args.size())
706     return LogErrorV("Incorrect # arguments passed");
707 
708   std::vector<Value *> ArgsV;
709   for (unsigned i = 0, e = Args.size(); i != e; ++i) {
710     ArgsV.push_back(Args[i]->codegen());
711     if (!ArgsV.back())
712       return nullptr;
713   }
714 
715   return Builder.CreateCall(CalleeF, ArgsV, "calltmp");
716 }
717 
/// Emit IR for if/then/else: branch on the condition, generate each arm in its
/// own basic block, and join the two results with a PHI node in "ifcont".
/// Returns the PHI node, or nullptr if any sub-expression fails to codegen.
Value *IfExprAST::codegen() {
  Value *CondV = Cond->codegen();
  if (!CondV)
    return nullptr;

  // Convert condition to a bool: true when it compares not-equal to 0.0.
  CondV = Builder.CreateFCmpONE(
      CondV, ConstantFP::get(TheContext, APFloat(0.0)), "ifcond");

  Function *TheFunction = Builder.GetInsertBlock()->getParent();

  // Create blocks for the then and else cases.  Only the 'then' block is
  // attached to the function here; 'else' and 'ifcont' are created detached
  // and appended further down, once the preceding arm has been emitted.
  BasicBlock *ThenBB = BasicBlock::Create(TheContext, "then", TheFunction);
  BasicBlock *ElseBB = BasicBlock::Create(TheContext, "else");
  BasicBlock *MergeBB = BasicBlock::Create(TheContext, "ifcont");

  Builder.CreateCondBr(CondV, ThenBB, ElseBB);

  // Emit then value.
  Builder.SetInsertPoint(ThenBB);

  Value *ThenV = Then->codegen();
  if (!ThenV)
    return nullptr;

  Builder.CreateBr(MergeBB);
  // Codegen of 'Then' can change the current block, update ThenBB for the PHI.
  ThenBB = Builder.GetInsertBlock();

  // Emit else block.
  TheFunction->getBasicBlockList().push_back(ElseBB);
  Builder.SetInsertPoint(ElseBB);

  Value *ElseV = Else->codegen();
  if (!ElseV)
    return nullptr;

  Builder.CreateBr(MergeBB);
  // Codegen of 'Else' can change the current block, update ElseBB for the PHI.
  ElseBB = Builder.GetInsertBlock();

  // Emit merge block.
  TheFunction->getBasicBlockList().push_back(MergeBB);
  Builder.SetInsertPoint(MergeBB);
  PHINode *PN = Builder.CreatePHI(Type::getDoubleTy(TheContext), 2, "iftmp");

  // One incoming value per arm, each paired with the block that arm ended in.
  PN->addIncoming(ThenV, ThenBB);
  PN->addIncoming(ElseV, ElseBB);
  return PN;
}
769 
// Output for-loop as:
//   ...
//   start = startexpr
//   goto loop
// loop:
//   variable = phi [start, preheader], [nextvariable, loopend]
//   ...
//   bodyexpr
//   ...
// loopend:
//   step = stepexpr
//   nextvariable = variable + step
//   endcond = endexpr
//   br endcond, loop, afterloop
// afterloop:
//
// Returns a constant 0.0, or nullptr if any sub-expression fails to codegen.
// NOTE: the early error returns inside the loop do not restore a shadowed
// NamedValues entry; only the normal exit path does.
Value *ForExprAST::codegen() {
  // Emit the start code first, without 'variable' in scope.
  Value *StartVal = Start->codegen();
  if (!StartVal)
    return nullptr;

  // Make the new basic block for the loop header, inserting after current
  // block.
  Function *TheFunction = Builder.GetInsertBlock()->getParent();
  BasicBlock *PreheaderBB = Builder.GetInsertBlock();
  BasicBlock *LoopBB = BasicBlock::Create(TheContext, "loop", TheFunction);

  // Insert an explicit fall through from the current block to the LoopBB.
  Builder.CreateBr(LoopBB);

  // Start insertion in LoopBB.
  Builder.SetInsertPoint(LoopBB);

  // Start the PHI node with an entry for Start.
  PHINode *Variable =
      Builder.CreatePHI(Type::getDoubleTy(TheContext), 2, VarName);
  Variable->addIncoming(StartVal, PreheaderBB);

  // Within the loop, the variable is defined equal to the PHI node.  If it
  // shadows an existing variable, we have to restore it, so save it now.
  Value *OldVal = NamedValues[VarName];
  NamedValues[VarName] = Variable;

  // Emit the body of the loop.  This, like any other expr, can change the
  // current BB.  Note that we ignore the value computed by the body, but don't
  // allow an error.
  if (!Body->codegen())
    return nullptr;

  // Emit the step value.
  Value *StepVal = nullptr;
  if (Step) {
    StepVal = Step->codegen();
    if (!StepVal)
      return nullptr;
  } else {
    // If not specified, use 1.0.
    StepVal = ConstantFP::get(TheContext, APFloat(1.0));
  }

  Value *NextVar = Builder.CreateFAdd(Variable, StepVal, "nextvar");

  // Compute the end condition.
  Value *EndCond = End->codegen();
  if (!EndCond)
    return nullptr;

  // Convert condition to a bool: true when it compares not-equal to 0.0.
  EndCond = Builder.CreateFCmpONE(
      EndCond, ConstantFP::get(TheContext, APFloat(0.0)), "loopcond");

  // Create the "after loop" block and insert it.
  BasicBlock *LoopEndBB = Builder.GetInsertBlock();
  BasicBlock *AfterBB =
      BasicBlock::Create(TheContext, "afterloop", TheFunction);

  // Insert the conditional branch into the end of LoopEndBB.
  Builder.CreateCondBr(EndCond, LoopBB, AfterBB);

  // Any new code will be inserted in AfterBB.
  Builder.SetInsertPoint(AfterBB);

  // Add a new entry to the PHI node for the backedge.
  Variable->addIncoming(NextVar, LoopEndBB);

  // Restore the unshadowed variable.
  if (OldVal)
    NamedValues[VarName] = OldVal;
  else
    NamedValues.erase(VarName);

  // for expr always returns 0.0.
  return Constant::getNullValue(Type::getDoubleTy(TheContext));
}
864 
codegen()865 Function *PrototypeAST::codegen() {
866   // Make the function type:  double(double,double) etc.
867   std::vector<Type *> Doubles(Args.size(), Type::getDoubleTy(TheContext));
868   FunctionType *FT =
869       FunctionType::get(Type::getDoubleTy(TheContext), Doubles, false);
870 
871   Function *F =
872       Function::Create(FT, Function::ExternalLinkage, Name, TheModule.get());
873 
874   // Set names for all arguments.
875   unsigned Idx = 0;
876   for (auto &Arg : F->args())
877     Arg.setName(Args[Idx++]);
878 
879   return F;
880 }
881 
codegen()882 Function *FunctionAST::codegen() {
883   // Transfer ownership of the prototype to the FunctionProtos map, but keep a
884   // reference to it for use below.
885   auto &P = *Proto;
886   FunctionProtos[Proto->getName()] = std::move(Proto);
887   Function *TheFunction = getFunction(P.getName());
888   if (!TheFunction)
889     return nullptr;
890 
891   // If this is an operator, install it.
892   if (P.isBinaryOp())
893     BinopPrecedence[P.getOperatorName()] = P.getBinaryPrecedence();
894 
895   // Create a new basic block to start insertion into.
896   BasicBlock *BB = BasicBlock::Create(TheContext, "entry", TheFunction);
897   Builder.SetInsertPoint(BB);
898 
899   // Record the function arguments in the NamedValues map.
900   NamedValues.clear();
901   for (auto &Arg : TheFunction->args())
902     NamedValues[Arg.getName()] = &Arg;
903 
904   if (Value *RetVal = Body->codegen()) {
905     // Finish off the function.
906     Builder.CreateRet(RetVal);
907 
908     // Validate the generated code, checking for consistency.
909     verifyFunction(*TheFunction);
910 
911     // Run the optimizer on the function.
912     TheFPM->run(*TheFunction);
913 
914     return TheFunction;
915   }
916 
917   // Error reading body, remove function.
918   TheFunction->eraseFromParent();
919 
920   if (P.isBinaryOp())
921     BinopPrecedence.erase(Proto->getOperatorName());
922   return nullptr;
923 }
924 
925 //===----------------------------------------------------------------------===//
926 // Top-Level parsing and JIT Driver
927 //===----------------------------------------------------------------------===//
928 
InitializeModuleAndPassManager()929 static void InitializeModuleAndPassManager() {
930   // Open a new module.
931   TheModule = llvm::make_unique<Module>("my cool jit", TheContext);
932   TheModule->setDataLayout(TheJIT->getTargetMachine().createDataLayout());
933 
934   // Create a new pass manager attached to it.
935   TheFPM = llvm::make_unique<legacy::FunctionPassManager>(TheModule.get());
936 
937   // Do simple "peephole" optimizations and bit-twiddling optzns.
938   TheFPM->add(createInstructionCombiningPass());
939   // Reassociate expressions.
940   TheFPM->add(createReassociatePass());
941   // Eliminate Common SubExpressions.
942   TheFPM->add(createGVNPass());
943   // Simplify the control flow graph (deleting unreachable blocks, etc).
944   TheFPM->add(createCFGSimplificationPass());
945 
946   TheFPM->doInitialization();
947 }
948 
HandleDefinition()949 static void HandleDefinition() {
950   if (auto FnAST = ParseDefinition()) {
951     if (auto *FnIR = FnAST->codegen()) {
952       fprintf(stderr, "Read function definition:");
953       FnIR->dump();
954       TheJIT->addModule(std::move(TheModule));
955       InitializeModuleAndPassManager();
956     }
957   } else {
958     // Skip token for error recovery.
959     getNextToken();
960   }
961 }
962 
HandleExtern()963 static void HandleExtern() {
964   if (auto ProtoAST = ParseExtern()) {
965     if (auto *FnIR = ProtoAST->codegen()) {
966       fprintf(stderr, "Read extern: ");
967       FnIR->dump();
968       FunctionProtos[ProtoAST->getName()] = std::move(ProtoAST);
969     }
970   } else {
971     // Skip token for error recovery.
972     getNextToken();
973   }
974 }
975 
HandleTopLevelExpression()976 static void HandleTopLevelExpression() {
977   // Evaluate a top-level expression into an anonymous function.
978   if (auto FnAST = ParseTopLevelExpr()) {
979     if (FnAST->codegen()) {
980       // JIT the module containing the anonymous expression, keeping a handle so
981       // we can free it later.
982       auto H = TheJIT->addModule(std::move(TheModule));
983       InitializeModuleAndPassManager();
984 
985       // Search the JIT for the __anon_expr symbol.
986       auto ExprSymbol = TheJIT->findSymbol("__anon_expr");
987       assert(ExprSymbol && "Function not found");
988 
989       // Get the symbol's address and cast it to the right type (takes no
990       // arguments, returns a double) so we can call it as a native function.
991       double (*FP)() = (double (*)())(intptr_t)ExprSymbol.getAddress();
992       fprintf(stderr, "Evaluated to %f\n", FP());
993 
994       // Delete the anonymous expression module from the JIT.
995       TheJIT->removeModule(H);
996     }
997   } else {
998     // Skip token for error recovery.
999     getNextToken();
1000   }
1001 }
1002 
1003 /// top ::= definition | external | expression | ';'
MainLoop()1004 static void MainLoop() {
1005   while (true) {
1006     fprintf(stderr, "ready> ");
1007     switch (CurTok) {
1008     case tok_eof:
1009       return;
1010     case ';': // ignore top-level semicolons.
1011       getNextToken();
1012       break;
1013     case tok_def:
1014       HandleDefinition();
1015       break;
1016     case tok_extern:
1017       HandleExtern();
1018       break;
1019     default:
1020       HandleTopLevelExpression();
1021       break;
1022     }
1023   }
1024 }
1025 
1026 //===----------------------------------------------------------------------===//
1027 // "Library" functions that can be "extern'd" from user code.
1028 //===----------------------------------------------------------------------===//
1029 
/// putchard - Write X, converted to a char, to stderr.  Always returns 0.
extern "C" double putchard(double X) {
  const char Ch = static_cast<char>(X);
  fputc(Ch, stderr);
  return 0.0;
}
1035 
/// printd - Print X to stderr in "%f" format followed by a newline.  Always
/// returns 0.
extern "C" double printd(double X) {
  static const char Format[] = "%f\n";
  fprintf(stderr, Format, X);
  return 0.0;
}
1041 
1042 //===----------------------------------------------------------------------===//
1043 // Main driver code.
1044 //===----------------------------------------------------------------------===//
1045 
main()1046 int main() {
1047   InitializeNativeTarget();
1048   InitializeNativeTargetAsmPrinter();
1049   InitializeNativeTargetAsmParser();
1050 
1051   // Install standard binary operators.
1052   // 1 is lowest precedence.
1053   BinopPrecedence['<'] = 10;
1054   BinopPrecedence['+'] = 20;
1055   BinopPrecedence['-'] = 20;
1056   BinopPrecedence['*'] = 40; // highest.
1057 
1058   // Prime the first token.
1059   fprintf(stderr, "ready> ");
1060   getNextToken();
1061 
1062   TheJIT = llvm::make_unique<KaleidoscopeJIT>();
1063 
1064   InitializeModuleAndPassManager();
1065 
1066   // Run the main "interpreter loop" now.
1067   MainLoop();
1068 
1069   return 0;
1070 }
1071