1 #include "llvm/ADT/APFloat.h"
2 #include "llvm/ADT/STLExtras.h"
3 #include "llvm/IR/BasicBlock.h"
4 #include "llvm/IR/Constants.h"
5 #include "llvm/IR/DerivedTypes.h"
6 #include "llvm/IR/Function.h"
7 #include "llvm/IR/Instructions.h"
8 #include "llvm/IR/IRBuilder.h"
9 #include "llvm/IR/LLVMContext.h"
10 #include "llvm/IR/LegacyPassManager.h"
11 #include "llvm/IR/Module.h"
12 #include "llvm/IR/Type.h"
13 #include "llvm/IR/Verifier.h"
14 #include "llvm/Support/TargetSelect.h"
15 #include "llvm/Target/TargetMachine.h"
16 #include "llvm/Transforms/Scalar.h"
17 #include "llvm/Transforms/Scalar/GVN.h"
18 #include "../include/KaleidoscopeJIT.h"
19 #include <cassert>
20 #include <cctype>
21 #include <cstdint>
22 #include <cstdio>
23 #include <cstdlib>
24 #include <map>
25 #include <memory>
26 #include <string>
27 #include <vector>
28
29 using namespace llvm;
30 using namespace llvm::orc;
31
32 //===----------------------------------------------------------------------===//
33 // Lexer
34 //===----------------------------------------------------------------------===//
35
// Token codes produced by the lexer. A character the lexer does not recognize
// is returned as its own value in [0-255]; every recognized token kind uses
// one of the negative codes below.
enum Token {
  // End of input.
  tok_eof = -1,

  // Commands.
  tok_def = -2,
  tok_extern = -3,

  // Primary expressions.
  tok_identifier = -4,
  tok_number = -5,

  // Control-flow keywords.
  tok_if = -6,
  tok_then = -7,
  tok_else = -8,
  tok_for = -9,
  tok_in = -10,

  // Keywords introducing user-defined operators.
  tok_binary = -11,
  tok_unary = -12
};
60
61 static std::string IdentifierStr; // Filled in if tok_identifier
62 static double NumVal; // Filled in if tok_number
63
64 /// gettok - Return the next token from standard input.
gettok()65 static int gettok() {
66 static int LastChar = ' ';
67
68 // Skip any whitespace.
69 while (isspace(LastChar))
70 LastChar = getchar();
71
72 if (isalpha(LastChar)) { // identifier: [a-zA-Z][a-zA-Z0-9]*
73 IdentifierStr = LastChar;
74 while (isalnum((LastChar = getchar())))
75 IdentifierStr += LastChar;
76
77 if (IdentifierStr == "def")
78 return tok_def;
79 if (IdentifierStr == "extern")
80 return tok_extern;
81 if (IdentifierStr == "if")
82 return tok_if;
83 if (IdentifierStr == "then")
84 return tok_then;
85 if (IdentifierStr == "else")
86 return tok_else;
87 if (IdentifierStr == "for")
88 return tok_for;
89 if (IdentifierStr == "in")
90 return tok_in;
91 if (IdentifierStr == "binary")
92 return tok_binary;
93 if (IdentifierStr == "unary")
94 return tok_unary;
95 return tok_identifier;
96 }
97
98 if (isdigit(LastChar) || LastChar == '.') { // Number: [0-9.]+
99 std::string NumStr;
100 do {
101 NumStr += LastChar;
102 LastChar = getchar();
103 } while (isdigit(LastChar) || LastChar == '.');
104
105 NumVal = strtod(NumStr.c_str(), nullptr);
106 return tok_number;
107 }
108
109 if (LastChar == '#') {
110 // Comment until end of line.
111 do
112 LastChar = getchar();
113 while (LastChar != EOF && LastChar != '\n' && LastChar != '\r');
114
115 if (LastChar != EOF)
116 return gettok();
117 }
118
119 // Check for end of file. Don't eat the EOF.
120 if (LastChar == EOF)
121 return tok_eof;
122
123 // Otherwise, just return the character as its ascii value.
124 int ThisChar = LastChar;
125 LastChar = getchar();
126 return ThisChar;
127 }
128
129 //===----------------------------------------------------------------------===//
130 // Abstract Syntax Tree (aka Parse Tree)
131 //===----------------------------------------------------------------------===//
132 namespace {
133 /// ExprAST - Base class for all expression nodes.
134 class ExprAST {
135 public:
~ExprAST()136 virtual ~ExprAST() {}
137 virtual Value *codegen() = 0;
138 };
139
140 /// NumberExprAST - Expression class for numeric literals like "1.0".
141 class NumberExprAST : public ExprAST {
142 double Val;
143
144 public:
NumberExprAST(double Val)145 NumberExprAST(double Val) : Val(Val) {}
146 Value *codegen() override;
147 };
148
149 /// VariableExprAST - Expression class for referencing a variable, like "a".
150 class VariableExprAST : public ExprAST {
151 std::string Name;
152
153 public:
VariableExprAST(const std::string & Name)154 VariableExprAST(const std::string &Name) : Name(Name) {}
155 Value *codegen() override;
156 };
157
158 /// UnaryExprAST - Expression class for a unary operator.
159 class UnaryExprAST : public ExprAST {
160 char Opcode;
161 std::unique_ptr<ExprAST> Operand;
162
163 public:
UnaryExprAST(char Opcode,std::unique_ptr<ExprAST> Operand)164 UnaryExprAST(char Opcode, std::unique_ptr<ExprAST> Operand)
165 : Opcode(Opcode), Operand(std::move(Operand)) {}
166 Value *codegen() override;
167 };
168
169 /// BinaryExprAST - Expression class for a binary operator.
170 class BinaryExprAST : public ExprAST {
171 char Op;
172 std::unique_ptr<ExprAST> LHS, RHS;
173
174 public:
BinaryExprAST(char Op,std::unique_ptr<ExprAST> LHS,std::unique_ptr<ExprAST> RHS)175 BinaryExprAST(char Op, std::unique_ptr<ExprAST> LHS,
176 std::unique_ptr<ExprAST> RHS)
177 : Op(Op), LHS(std::move(LHS)), RHS(std::move(RHS)) {}
178 Value *codegen() override;
179 };
180
181 /// CallExprAST - Expression class for function calls.
182 class CallExprAST : public ExprAST {
183 std::string Callee;
184 std::vector<std::unique_ptr<ExprAST>> Args;
185
186 public:
CallExprAST(const std::string & Callee,std::vector<std::unique_ptr<ExprAST>> Args)187 CallExprAST(const std::string &Callee,
188 std::vector<std::unique_ptr<ExprAST>> Args)
189 : Callee(Callee), Args(std::move(Args)) {}
190 Value *codegen() override;
191 };
192
193 /// IfExprAST - Expression class for if/then/else.
194 class IfExprAST : public ExprAST {
195 std::unique_ptr<ExprAST> Cond, Then, Else;
196
197 public:
IfExprAST(std::unique_ptr<ExprAST> Cond,std::unique_ptr<ExprAST> Then,std::unique_ptr<ExprAST> Else)198 IfExprAST(std::unique_ptr<ExprAST> Cond, std::unique_ptr<ExprAST> Then,
199 std::unique_ptr<ExprAST> Else)
200 : Cond(std::move(Cond)), Then(std::move(Then)), Else(std::move(Else)) {}
201 Value *codegen() override;
202 };
203
204 /// ForExprAST - Expression class for for/in.
205 class ForExprAST : public ExprAST {
206 std::string VarName;
207 std::unique_ptr<ExprAST> Start, End, Step, Body;
208
209 public:
ForExprAST(const std::string & VarName,std::unique_ptr<ExprAST> Start,std::unique_ptr<ExprAST> End,std::unique_ptr<ExprAST> Step,std::unique_ptr<ExprAST> Body)210 ForExprAST(const std::string &VarName, std::unique_ptr<ExprAST> Start,
211 std::unique_ptr<ExprAST> End, std::unique_ptr<ExprAST> Step,
212 std::unique_ptr<ExprAST> Body)
213 : VarName(VarName), Start(std::move(Start)), End(std::move(End)),
214 Step(std::move(Step)), Body(std::move(Body)) {}
215 Value *codegen() override;
216 };
217
218 /// PrototypeAST - This class represents the "prototype" for a function,
219 /// which captures its name, and its argument names (thus implicitly the number
220 /// of arguments the function takes), as well as if it is an operator.
221 class PrototypeAST {
222 std::string Name;
223 std::vector<std::string> Args;
224 bool IsOperator;
225 unsigned Precedence; // Precedence if a binary op.
226
227 public:
PrototypeAST(const std::string & Name,std::vector<std::string> Args,bool IsOperator=false,unsigned Prec=0)228 PrototypeAST(const std::string &Name, std::vector<std::string> Args,
229 bool IsOperator = false, unsigned Prec = 0)
230 : Name(Name), Args(std::move(Args)), IsOperator(IsOperator),
231 Precedence(Prec) {}
232 Function *codegen();
getName() const233 const std::string &getName() const { return Name; }
234
isUnaryOp() const235 bool isUnaryOp() const { return IsOperator && Args.size() == 1; }
isBinaryOp() const236 bool isBinaryOp() const { return IsOperator && Args.size() == 2; }
237
getOperatorName() const238 char getOperatorName() const {
239 assert(isUnaryOp() || isBinaryOp());
240 return Name[Name.size() - 1];
241 }
242
getBinaryPrecedence() const243 unsigned getBinaryPrecedence() const { return Precedence; }
244 };
245
246 /// FunctionAST - This class represents a function definition itself.
247 class FunctionAST {
248 std::unique_ptr<PrototypeAST> Proto;
249 std::unique_ptr<ExprAST> Body;
250
251 public:
FunctionAST(std::unique_ptr<PrototypeAST> Proto,std::unique_ptr<ExprAST> Body)252 FunctionAST(std::unique_ptr<PrototypeAST> Proto,
253 std::unique_ptr<ExprAST> Body)
254 : Proto(std::move(Proto)), Body(std::move(Body)) {}
255 Function *codegen();
256 };
257 } // end anonymous namespace
258
259 //===----------------------------------------------------------------------===//
260 // Parser
261 //===----------------------------------------------------------------------===//
262
263 /// CurTok/getNextToken - Provide a simple token buffer. CurTok is the current
264 /// token the parser is looking at. getNextToken reads another token from the
265 /// lexer and updates CurTok with its results.
266 static int CurTok;
getNextToken()267 static int getNextToken() { return CurTok = gettok(); }
268
269 /// BinopPrecedence - This holds the precedence for each binary operator that is
270 /// defined.
271 static std::map<char, int> BinopPrecedence;
272
273 /// GetTokPrecedence - Get the precedence of the pending binary operator token.
GetTokPrecedence()274 static int GetTokPrecedence() {
275 if (!isascii(CurTok))
276 return -1;
277
278 // Make sure it's a declared binop.
279 int TokPrec = BinopPrecedence[CurTok];
280 if (TokPrec <= 0)
281 return -1;
282 return TokPrec;
283 }
284
285 /// Error* - These are little helper functions for error handling.
LogError(const char * Str)286 std::unique_ptr<ExprAST> LogError(const char *Str) {
287 fprintf(stderr, "Error: %s\n", Str);
288 return nullptr;
289 }
290
LogErrorP(const char * Str)291 std::unique_ptr<PrototypeAST> LogErrorP(const char *Str) {
292 LogError(Str);
293 return nullptr;
294 }
295
296 static std::unique_ptr<ExprAST> ParseExpression();
297
298 /// numberexpr ::= number
ParseNumberExpr()299 static std::unique_ptr<ExprAST> ParseNumberExpr() {
300 auto Result = llvm::make_unique<NumberExprAST>(NumVal);
301 getNextToken(); // consume the number
302 return std::move(Result);
303 }
304
305 /// parenexpr ::= '(' expression ')'
ParseParenExpr()306 static std::unique_ptr<ExprAST> ParseParenExpr() {
307 getNextToken(); // eat (.
308 auto V = ParseExpression();
309 if (!V)
310 return nullptr;
311
312 if (CurTok != ')')
313 return LogError("expected ')'");
314 getNextToken(); // eat ).
315 return V;
316 }
317
318 /// identifierexpr
319 /// ::= identifier
320 /// ::= identifier '(' expression* ')'
ParseIdentifierExpr()321 static std::unique_ptr<ExprAST> ParseIdentifierExpr() {
322 std::string IdName = IdentifierStr;
323
324 getNextToken(); // eat identifier.
325
326 if (CurTok != '(') // Simple variable ref.
327 return llvm::make_unique<VariableExprAST>(IdName);
328
329 // Call.
330 getNextToken(); // eat (
331 std::vector<std::unique_ptr<ExprAST>> Args;
332 if (CurTok != ')') {
333 while (true) {
334 if (auto Arg = ParseExpression())
335 Args.push_back(std::move(Arg));
336 else
337 return nullptr;
338
339 if (CurTok == ')')
340 break;
341
342 if (CurTok != ',')
343 return LogError("Expected ')' or ',' in argument list");
344 getNextToken();
345 }
346 }
347
348 // Eat the ')'.
349 getNextToken();
350
351 return llvm::make_unique<CallExprAST>(IdName, std::move(Args));
352 }
353
354 /// ifexpr ::= 'if' expression 'then' expression 'else' expression
ParseIfExpr()355 static std::unique_ptr<ExprAST> ParseIfExpr() {
356 getNextToken(); // eat the if.
357
358 // condition.
359 auto Cond = ParseExpression();
360 if (!Cond)
361 return nullptr;
362
363 if (CurTok != tok_then)
364 return LogError("expected then");
365 getNextToken(); // eat the then
366
367 auto Then = ParseExpression();
368 if (!Then)
369 return nullptr;
370
371 if (CurTok != tok_else)
372 return LogError("expected else");
373
374 getNextToken();
375
376 auto Else = ParseExpression();
377 if (!Else)
378 return nullptr;
379
380 return llvm::make_unique<IfExprAST>(std::move(Cond), std::move(Then),
381 std::move(Else));
382 }
383
384 /// forexpr ::= 'for' identifier '=' expr ',' expr (',' expr)? 'in' expression
ParseForExpr()385 static std::unique_ptr<ExprAST> ParseForExpr() {
386 getNextToken(); // eat the for.
387
388 if (CurTok != tok_identifier)
389 return LogError("expected identifier after for");
390
391 std::string IdName = IdentifierStr;
392 getNextToken(); // eat identifier.
393
394 if (CurTok != '=')
395 return LogError("expected '=' after for");
396 getNextToken(); // eat '='.
397
398 auto Start = ParseExpression();
399 if (!Start)
400 return nullptr;
401 if (CurTok != ',')
402 return LogError("expected ',' after for start value");
403 getNextToken();
404
405 auto End = ParseExpression();
406 if (!End)
407 return nullptr;
408
409 // The step value is optional.
410 std::unique_ptr<ExprAST> Step;
411 if (CurTok == ',') {
412 getNextToken();
413 Step = ParseExpression();
414 if (!Step)
415 return nullptr;
416 }
417
418 if (CurTok != tok_in)
419 return LogError("expected 'in' after for");
420 getNextToken(); // eat 'in'.
421
422 auto Body = ParseExpression();
423 if (!Body)
424 return nullptr;
425
426 return llvm::make_unique<ForExprAST>(IdName, std::move(Start), std::move(End),
427 std::move(Step), std::move(Body));
428 }
429
430 /// primary
431 /// ::= identifierexpr
432 /// ::= numberexpr
433 /// ::= parenexpr
434 /// ::= ifexpr
435 /// ::= forexpr
ParsePrimary()436 static std::unique_ptr<ExprAST> ParsePrimary() {
437 switch (CurTok) {
438 default:
439 return LogError("unknown token when expecting an expression");
440 case tok_identifier:
441 return ParseIdentifierExpr();
442 case tok_number:
443 return ParseNumberExpr();
444 case '(':
445 return ParseParenExpr();
446 case tok_if:
447 return ParseIfExpr();
448 case tok_for:
449 return ParseForExpr();
450 }
451 }
452
453 /// unary
454 /// ::= primary
455 /// ::= '!' unary
ParseUnary()456 static std::unique_ptr<ExprAST> ParseUnary() {
457 // If the current token is not an operator, it must be a primary expr.
458 if (!isascii(CurTok) || CurTok == '(' || CurTok == ',')
459 return ParsePrimary();
460
461 // If this is a unary operator, read it.
462 int Opc = CurTok;
463 getNextToken();
464 if (auto Operand = ParseUnary())
465 return llvm::make_unique<UnaryExprAST>(Opc, std::move(Operand));
466 return nullptr;
467 }
468
469 /// binoprhs
470 /// ::= ('+' unary)*
ParseBinOpRHS(int ExprPrec,std::unique_ptr<ExprAST> LHS)471 static std::unique_ptr<ExprAST> ParseBinOpRHS(int ExprPrec,
472 std::unique_ptr<ExprAST> LHS) {
473 // If this is a binop, find its precedence.
474 while (true) {
475 int TokPrec = GetTokPrecedence();
476
477 // If this is a binop that binds at least as tightly as the current binop,
478 // consume it, otherwise we are done.
479 if (TokPrec < ExprPrec)
480 return LHS;
481
482 // Okay, we know this is a binop.
483 int BinOp = CurTok;
484 getNextToken(); // eat binop
485
486 // Parse the unary expression after the binary operator.
487 auto RHS = ParseUnary();
488 if (!RHS)
489 return nullptr;
490
491 // If BinOp binds less tightly with RHS than the operator after RHS, let
492 // the pending operator take RHS as its LHS.
493 int NextPrec = GetTokPrecedence();
494 if (TokPrec < NextPrec) {
495 RHS = ParseBinOpRHS(TokPrec + 1, std::move(RHS));
496 if (!RHS)
497 return nullptr;
498 }
499
500 // Merge LHS/RHS.
501 LHS =
502 llvm::make_unique<BinaryExprAST>(BinOp, std::move(LHS), std::move(RHS));
503 }
504 }
505
506 /// expression
507 /// ::= unary binoprhs
508 ///
ParseExpression()509 static std::unique_ptr<ExprAST> ParseExpression() {
510 auto LHS = ParseUnary();
511 if (!LHS)
512 return nullptr;
513
514 return ParseBinOpRHS(0, std::move(LHS));
515 }
516
517 /// prototype
518 /// ::= id '(' id* ')'
519 /// ::= binary LETTER number? (id, id)
520 /// ::= unary LETTER (id)
ParsePrototype()521 static std::unique_ptr<PrototypeAST> ParsePrototype() {
522 std::string FnName;
523
524 unsigned Kind = 0; // 0 = identifier, 1 = unary, 2 = binary.
525 unsigned BinaryPrecedence = 30;
526
527 switch (CurTok) {
528 default:
529 return LogErrorP("Expected function name in prototype");
530 case tok_identifier:
531 FnName = IdentifierStr;
532 Kind = 0;
533 getNextToken();
534 break;
535 case tok_unary:
536 getNextToken();
537 if (!isascii(CurTok))
538 return LogErrorP("Expected unary operator");
539 FnName = "unary";
540 FnName += (char)CurTok;
541 Kind = 1;
542 getNextToken();
543 break;
544 case tok_binary:
545 getNextToken();
546 if (!isascii(CurTok))
547 return LogErrorP("Expected binary operator");
548 FnName = "binary";
549 FnName += (char)CurTok;
550 Kind = 2;
551 getNextToken();
552
553 // Read the precedence if present.
554 if (CurTok == tok_number) {
555 if (NumVal < 1 || NumVal > 100)
556 return LogErrorP("Invalid precedecnce: must be 1..100");
557 BinaryPrecedence = (unsigned)NumVal;
558 getNextToken();
559 }
560 break;
561 }
562
563 if (CurTok != '(')
564 return LogErrorP("Expected '(' in prototype");
565
566 std::vector<std::string> ArgNames;
567 while (getNextToken() == tok_identifier)
568 ArgNames.push_back(IdentifierStr);
569 if (CurTok != ')')
570 return LogErrorP("Expected ')' in prototype");
571
572 // success.
573 getNextToken(); // eat ')'.
574
575 // Verify right number of names for operator.
576 if (Kind && ArgNames.size() != Kind)
577 return LogErrorP("Invalid number of operands for operator");
578
579 return llvm::make_unique<PrototypeAST>(FnName, ArgNames, Kind != 0,
580 BinaryPrecedence);
581 }
582
583 /// definition ::= 'def' prototype expression
ParseDefinition()584 static std::unique_ptr<FunctionAST> ParseDefinition() {
585 getNextToken(); // eat def.
586 auto Proto = ParsePrototype();
587 if (!Proto)
588 return nullptr;
589
590 if (auto E = ParseExpression())
591 return llvm::make_unique<FunctionAST>(std::move(Proto), std::move(E));
592 return nullptr;
593 }
594
595 /// toplevelexpr ::= expression
ParseTopLevelExpr()596 static std::unique_ptr<FunctionAST> ParseTopLevelExpr() {
597 if (auto E = ParseExpression()) {
598 // Make an anonymous proto.
599 auto Proto = llvm::make_unique<PrototypeAST>("__anon_expr",
600 std::vector<std::string>());
601 return llvm::make_unique<FunctionAST>(std::move(Proto), std::move(E));
602 }
603 return nullptr;
604 }
605
606 /// external ::= 'extern' prototype
ParseExtern()607 static std::unique_ptr<PrototypeAST> ParseExtern() {
608 getNextToken(); // eat extern.
609 return ParsePrototype();
610 }
611
612 //===----------------------------------------------------------------------===//
613 // Code Generation
614 //===----------------------------------------------------------------------===//
615
616 static LLVMContext TheContext;
617 static IRBuilder<> Builder(TheContext);
618 static std::unique_ptr<Module> TheModule;
619 static std::map<std::string, Value *> NamedValues;
620 static std::unique_ptr<legacy::FunctionPassManager> TheFPM;
621 static std::unique_ptr<KaleidoscopeJIT> TheJIT;
622 static std::map<std::string, std::unique_ptr<PrototypeAST>> FunctionProtos;
623
LogErrorV(const char * Str)624 Value *LogErrorV(const char *Str) {
625 LogError(Str);
626 return nullptr;
627 }
628
getFunction(std::string Name)629 Function *getFunction(std::string Name) {
630 // First, see if the function has already been added to the current module.
631 if (auto *F = TheModule->getFunction(Name))
632 return F;
633
634 // If not, check whether we can codegen the declaration from some existing
635 // prototype.
636 auto FI = FunctionProtos.find(Name);
637 if (FI != FunctionProtos.end())
638 return FI->second->codegen();
639
640 // If no existing prototype exists, return null.
641 return nullptr;
642 }
643
codegen()644 Value *NumberExprAST::codegen() {
645 return ConstantFP::get(TheContext, APFloat(Val));
646 }
647
codegen()648 Value *VariableExprAST::codegen() {
649 // Look this variable up in the function.
650 Value *V = NamedValues[Name];
651 if (!V)
652 return LogErrorV("Unknown variable name");
653 return V;
654 }
655
codegen()656 Value *UnaryExprAST::codegen() {
657 Value *OperandV = Operand->codegen();
658 if (!OperandV)
659 return nullptr;
660
661 Function *F = getFunction(std::string("unary") + Opcode);
662 if (!F)
663 return LogErrorV("Unknown unary operator");
664
665 return Builder.CreateCall(F, OperandV, "unop");
666 }
667
codegen()668 Value *BinaryExprAST::codegen() {
669 Value *L = LHS->codegen();
670 Value *R = RHS->codegen();
671 if (!L || !R)
672 return nullptr;
673
674 switch (Op) {
675 case '+':
676 return Builder.CreateFAdd(L, R, "addtmp");
677 case '-':
678 return Builder.CreateFSub(L, R, "subtmp");
679 case '*':
680 return Builder.CreateFMul(L, R, "multmp");
681 case '<':
682 L = Builder.CreateFCmpULT(L, R, "cmptmp");
683 // Convert bool 0/1 to double 0.0 or 1.0
684 return Builder.CreateUIToFP(L, Type::getDoubleTy(TheContext), "booltmp");
685 default:
686 break;
687 }
688
689 // If it wasn't a builtin binary operator, it must be a user defined one. Emit
690 // a call to it.
691 Function *F = getFunction(std::string("binary") + Op);
692 assert(F && "binary operator not found!");
693
694 Value *Ops[] = {L, R};
695 return Builder.CreateCall(F, Ops, "binop");
696 }
697
codegen()698 Value *CallExprAST::codegen() {
699 // Look up the name in the global module table.
700 Function *CalleeF = getFunction(Callee);
701 if (!CalleeF)
702 return LogErrorV("Unknown function referenced");
703
704 // If argument mismatch error.
705 if (CalleeF->arg_size() != Args.size())
706 return LogErrorV("Incorrect # arguments passed");
707
708 std::vector<Value *> ArgsV;
709 for (unsigned i = 0, e = Args.size(); i != e; ++i) {
710 ArgsV.push_back(Args[i]->codegen());
711 if (!ArgsV.back())
712 return nullptr;
713 }
714
715 return Builder.CreateCall(CalleeF, ArgsV, "calltmp");
716 }
717
codegen()718 Value *IfExprAST::codegen() {
719 Value *CondV = Cond->codegen();
720 if (!CondV)
721 return nullptr;
722
723 // Convert condition to a bool by comparing equal to 0.0.
724 CondV = Builder.CreateFCmpONE(
725 CondV, ConstantFP::get(TheContext, APFloat(0.0)), "ifcond");
726
727 Function *TheFunction = Builder.GetInsertBlock()->getParent();
728
729 // Create blocks for the then and else cases. Insert the 'then' block at the
730 // end of the function.
731 BasicBlock *ThenBB = BasicBlock::Create(TheContext, "then", TheFunction);
732 BasicBlock *ElseBB = BasicBlock::Create(TheContext, "else");
733 BasicBlock *MergeBB = BasicBlock::Create(TheContext, "ifcont");
734
735 Builder.CreateCondBr(CondV, ThenBB, ElseBB);
736
737 // Emit then value.
738 Builder.SetInsertPoint(ThenBB);
739
740 Value *ThenV = Then->codegen();
741 if (!ThenV)
742 return nullptr;
743
744 Builder.CreateBr(MergeBB);
745 // Codegen of 'Then' can change the current block, update ThenBB for the PHI.
746 ThenBB = Builder.GetInsertBlock();
747
748 // Emit else block.
749 TheFunction->getBasicBlockList().push_back(ElseBB);
750 Builder.SetInsertPoint(ElseBB);
751
752 Value *ElseV = Else->codegen();
753 if (!ElseV)
754 return nullptr;
755
756 Builder.CreateBr(MergeBB);
757 // Codegen of 'Else' can change the current block, update ElseBB for the PHI.
758 ElseBB = Builder.GetInsertBlock();
759
760 // Emit merge block.
761 TheFunction->getBasicBlockList().push_back(MergeBB);
762 Builder.SetInsertPoint(MergeBB);
763 PHINode *PN = Builder.CreatePHI(Type::getDoubleTy(TheContext), 2, "iftmp");
764
765 PN->addIncoming(ThenV, ThenBB);
766 PN->addIncoming(ElseV, ElseBB);
767 return PN;
768 }
769
770 // Output for-loop as:
771 // ...
772 // start = startexpr
773 // goto loop
774 // loop:
775 // variable = phi [start, loopheader], [nextvariable, loopend]
776 // ...
777 // bodyexpr
778 // ...
779 // loopend:
780 // step = stepexpr
781 // nextvariable = variable + step
782 // endcond = endexpr
783 // br endcond, loop, endloop
784 // outloop:
codegen()785 Value *ForExprAST::codegen() {
786 // Emit the start code first, without 'variable' in scope.
787 Value *StartVal = Start->codegen();
788 if (!StartVal)
789 return nullptr;
790
791 // Make the new basic block for the loop header, inserting after current
792 // block.
793 Function *TheFunction = Builder.GetInsertBlock()->getParent();
794 BasicBlock *PreheaderBB = Builder.GetInsertBlock();
795 BasicBlock *LoopBB = BasicBlock::Create(TheContext, "loop", TheFunction);
796
797 // Insert an explicit fall through from the current block to the LoopBB.
798 Builder.CreateBr(LoopBB);
799
800 // Start insertion in LoopBB.
801 Builder.SetInsertPoint(LoopBB);
802
803 // Start the PHI node with an entry for Start.
804 PHINode *Variable =
805 Builder.CreatePHI(Type::getDoubleTy(TheContext), 2, VarName);
806 Variable->addIncoming(StartVal, PreheaderBB);
807
808 // Within the loop, the variable is defined equal to the PHI node. If it
809 // shadows an existing variable, we have to restore it, so save it now.
810 Value *OldVal = NamedValues[VarName];
811 NamedValues[VarName] = Variable;
812
813 // Emit the body of the loop. This, like any other expr, can change the
814 // current BB. Note that we ignore the value computed by the body, but don't
815 // allow an error.
816 if (!Body->codegen())
817 return nullptr;
818
819 // Emit the step value.
820 Value *StepVal = nullptr;
821 if (Step) {
822 StepVal = Step->codegen();
823 if (!StepVal)
824 return nullptr;
825 } else {
826 // If not specified, use 1.0.
827 StepVal = ConstantFP::get(TheContext, APFloat(1.0));
828 }
829
830 Value *NextVar = Builder.CreateFAdd(Variable, StepVal, "nextvar");
831
832 // Compute the end condition.
833 Value *EndCond = End->codegen();
834 if (!EndCond)
835 return nullptr;
836
837 // Convert condition to a bool by comparing equal to 0.0.
838 EndCond = Builder.CreateFCmpONE(
839 EndCond, ConstantFP::get(TheContext, APFloat(0.0)), "loopcond");
840
841 // Create the "after loop" block and insert it.
842 BasicBlock *LoopEndBB = Builder.GetInsertBlock();
843 BasicBlock *AfterBB =
844 BasicBlock::Create(TheContext, "afterloop", TheFunction);
845
846 // Insert the conditional branch into the end of LoopEndBB.
847 Builder.CreateCondBr(EndCond, LoopBB, AfterBB);
848
849 // Any new code will be inserted in AfterBB.
850 Builder.SetInsertPoint(AfterBB);
851
852 // Add a new entry to the PHI node for the backedge.
853 Variable->addIncoming(NextVar, LoopEndBB);
854
855 // Restore the unshadowed variable.
856 if (OldVal)
857 NamedValues[VarName] = OldVal;
858 else
859 NamedValues.erase(VarName);
860
861 // for expr always returns 0.0.
862 return Constant::getNullValue(Type::getDoubleTy(TheContext));
863 }
864
codegen()865 Function *PrototypeAST::codegen() {
866 // Make the function type: double(double,double) etc.
867 std::vector<Type *> Doubles(Args.size(), Type::getDoubleTy(TheContext));
868 FunctionType *FT =
869 FunctionType::get(Type::getDoubleTy(TheContext), Doubles, false);
870
871 Function *F =
872 Function::Create(FT, Function::ExternalLinkage, Name, TheModule.get());
873
874 // Set names for all arguments.
875 unsigned Idx = 0;
876 for (auto &Arg : F->args())
877 Arg.setName(Args[Idx++]);
878
879 return F;
880 }
881
codegen()882 Function *FunctionAST::codegen() {
883 // Transfer ownership of the prototype to the FunctionProtos map, but keep a
884 // reference to it for use below.
885 auto &P = *Proto;
886 FunctionProtos[Proto->getName()] = std::move(Proto);
887 Function *TheFunction = getFunction(P.getName());
888 if (!TheFunction)
889 return nullptr;
890
891 // If this is an operator, install it.
892 if (P.isBinaryOp())
893 BinopPrecedence[P.getOperatorName()] = P.getBinaryPrecedence();
894
895 // Create a new basic block to start insertion into.
896 BasicBlock *BB = BasicBlock::Create(TheContext, "entry", TheFunction);
897 Builder.SetInsertPoint(BB);
898
899 // Record the function arguments in the NamedValues map.
900 NamedValues.clear();
901 for (auto &Arg : TheFunction->args())
902 NamedValues[Arg.getName()] = &Arg;
903
904 if (Value *RetVal = Body->codegen()) {
905 // Finish off the function.
906 Builder.CreateRet(RetVal);
907
908 // Validate the generated code, checking for consistency.
909 verifyFunction(*TheFunction);
910
911 // Run the optimizer on the function.
912 TheFPM->run(*TheFunction);
913
914 return TheFunction;
915 }
916
917 // Error reading body, remove function.
918 TheFunction->eraseFromParent();
919
920 if (P.isBinaryOp())
921 BinopPrecedence.erase(Proto->getOperatorName());
922 return nullptr;
923 }
924
925 //===----------------------------------------------------------------------===//
926 // Top-Level parsing and JIT Driver
927 //===----------------------------------------------------------------------===//
928
InitializeModuleAndPassManager()929 static void InitializeModuleAndPassManager() {
930 // Open a new module.
931 TheModule = llvm::make_unique<Module>("my cool jit", TheContext);
932 TheModule->setDataLayout(TheJIT->getTargetMachine().createDataLayout());
933
934 // Create a new pass manager attached to it.
935 TheFPM = llvm::make_unique<legacy::FunctionPassManager>(TheModule.get());
936
937 // Do simple "peephole" optimizations and bit-twiddling optzns.
938 TheFPM->add(createInstructionCombiningPass());
939 // Reassociate expressions.
940 TheFPM->add(createReassociatePass());
941 // Eliminate Common SubExpressions.
942 TheFPM->add(createGVNPass());
943 // Simplify the control flow graph (deleting unreachable blocks, etc).
944 TheFPM->add(createCFGSimplificationPass());
945
946 TheFPM->doInitialization();
947 }
948
HandleDefinition()949 static void HandleDefinition() {
950 if (auto FnAST = ParseDefinition()) {
951 if (auto *FnIR = FnAST->codegen()) {
952 fprintf(stderr, "Read function definition:");
953 FnIR->dump();
954 TheJIT->addModule(std::move(TheModule));
955 InitializeModuleAndPassManager();
956 }
957 } else {
958 // Skip token for error recovery.
959 getNextToken();
960 }
961 }
962
HandleExtern()963 static void HandleExtern() {
964 if (auto ProtoAST = ParseExtern()) {
965 if (auto *FnIR = ProtoAST->codegen()) {
966 fprintf(stderr, "Read extern: ");
967 FnIR->dump();
968 FunctionProtos[ProtoAST->getName()] = std::move(ProtoAST);
969 }
970 } else {
971 // Skip token for error recovery.
972 getNextToken();
973 }
974 }
975
HandleTopLevelExpression()976 static void HandleTopLevelExpression() {
977 // Evaluate a top-level expression into an anonymous function.
978 if (auto FnAST = ParseTopLevelExpr()) {
979 if (FnAST->codegen()) {
980 // JIT the module containing the anonymous expression, keeping a handle so
981 // we can free it later.
982 auto H = TheJIT->addModule(std::move(TheModule));
983 InitializeModuleAndPassManager();
984
985 // Search the JIT for the __anon_expr symbol.
986 auto ExprSymbol = TheJIT->findSymbol("__anon_expr");
987 assert(ExprSymbol && "Function not found");
988
989 // Get the symbol's address and cast it to the right type (takes no
990 // arguments, returns a double) so we can call it as a native function.
991 double (*FP)() = (double (*)())(intptr_t)ExprSymbol.getAddress();
992 fprintf(stderr, "Evaluated to %f\n", FP());
993
994 // Delete the anonymous expression module from the JIT.
995 TheJIT->removeModule(H);
996 }
997 } else {
998 // Skip token for error recovery.
999 getNextToken();
1000 }
1001 }
1002
1003 /// top ::= definition | external | expression | ';'
MainLoop()1004 static void MainLoop() {
1005 while (true) {
1006 fprintf(stderr, "ready> ");
1007 switch (CurTok) {
1008 case tok_eof:
1009 return;
1010 case ';': // ignore top-level semicolons.
1011 getNextToken();
1012 break;
1013 case tok_def:
1014 HandleDefinition();
1015 break;
1016 case tok_extern:
1017 HandleExtern();
1018 break;
1019 default:
1020 HandleTopLevelExpression();
1021 break;
1022 }
1023 }
1024 }
1025
1026 //===----------------------------------------------------------------------===//
1027 // "Library" functions that can be "extern'd" from user code.
1028 //===----------------------------------------------------------------------===//
1029
/// putchard - Write X, converted to a char, to stderr. Always returns 0.
extern "C" double putchard(double X) {
  const char Ch = static_cast<char>(X);
  fputc(Ch, stderr);
  return 0;
}
1035
/// printd - Print X to stderr using the format "%f\n". Always returns 0.
extern "C" double printd(double X) {
  fprintf(stderr, "%f\n", X);
  const double Result = 0;
  return Result;
}
1041
1042 //===----------------------------------------------------------------------===//
1043 // Main driver code.
1044 //===----------------------------------------------------------------------===//
1045
main()1046 int main() {
1047 InitializeNativeTarget();
1048 InitializeNativeTargetAsmPrinter();
1049 InitializeNativeTargetAsmParser();
1050
1051 // Install standard binary operators.
1052 // 1 is lowest precedence.
1053 BinopPrecedence['<'] = 10;
1054 BinopPrecedence['+'] = 20;
1055 BinopPrecedence['-'] = 20;
1056 BinopPrecedence['*'] = 40; // highest.
1057
1058 // Prime the first token.
1059 fprintf(stderr, "ready> ");
1060 getNextToken();
1061
1062 TheJIT = llvm::make_unique<KaleidoscopeJIT>();
1063
1064 InitializeModuleAndPassManager();
1065
1066 // Run the main "interpreter loop" now.
1067 MainLoop();
1068
1069 return 0;
1070 }
1071