1 /*
2  * Copyright (c) 2015 PLUMgrid, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include <set>
18 #include <algorithm>
19 #include <sstream>
20 
21 #include <llvm/IR/BasicBlock.h>
22 #include <llvm/IR/CallingConv.h>
23 #include <llvm/IR/CFG.h>
24 #include <llvm/IR/Constants.h>
25 #include <llvm/IR/DerivedTypes.h>
26 #include <llvm/IR/Function.h>
27 #include <llvm/IR/GlobalVariable.h>
28 #include <llvm/IR/InlineAsm.h>
29 #include <llvm/IR/Instructions.h>
30 #include <llvm/IR/IRPrintingPasses.h>
31 #include <llvm/IR/IRBuilder.h>
32 #include <llvm/IR/LLVMContext.h>
33 #include <llvm/IR/Module.h>
34 
35 #include "bcc_exception.h"
36 #include "codegen_llvm.h"
37 #include "file_desc.h"
38 #include "lexer.h"
39 #include "libbpf.h"
40 #include "linux/bpf.h"
41 #include "table_storage.h"
42 #include "type_helper.h"
43 
44 namespace ebpf {
45 namespace cc {
46 
47 using namespace llvm;
48 
49 using std::for_each;
50 using std::make_tuple;
51 using std::map;
52 using std::pair;
53 using std::set;
54 using std::string;
55 using std::stringstream;
56 using std::to_string;
57 using std::vector;
58 
59 // can't forward declare IRBuilder in .h file (template with default
60 // parameters), so cast it instead :(
61 #define B (*((IRBuilder<> *)this->b_))
62 
63 // Helper class to push/pop the insert block
64 class BlockStack {
65  public:
BlockStack(CodegenLLVM * cc,BasicBlock * bb)66   explicit BlockStack(CodegenLLVM *cc, BasicBlock *bb)
67     : old_bb_(cc->b_->GetInsertBlock()), cc_(cc) {
68     cc_->b_->SetInsertPoint(bb);
69   }
~BlockStack()70   ~BlockStack() {
71     if (old_bb_)
72       cc_->b_->SetInsertPoint(old_bb_);
73     else
74       cc_->b_->ClearInsertionPoint();
75   }
76  private:
77   BasicBlock *old_bb_;
78   CodegenLLVM *cc_;
79 };
80 
81 // Helper class to push/pop switch statement insert block
82 class SwitchStack {
83  public:
SwitchStack(CodegenLLVM * cc,SwitchInst * sw)84   explicit SwitchStack(CodegenLLVM *cc, SwitchInst *sw)
85     : old_sw_(cc->cur_switch_), cc_(cc) {
86     cc_->cur_switch_ = sw;
87   }
~SwitchStack()88   ~SwitchStack() {
89     cc_->cur_switch_ = old_sw_;
90   }
91  private:
92   SwitchInst *old_sw_;
93   CodegenLLVM *cc_;
94 };
95 
CodegenLLVM(llvm::Module * mod,Scopes * scopes,Scopes * proto_scopes)96 CodegenLLVM::CodegenLLVM(llvm::Module *mod, Scopes *scopes, Scopes *proto_scopes)
97   : out_(stdout), mod_(mod), indent_(0), tmp_reg_index_(0), scopes_(scopes),
98     proto_scopes_(proto_scopes), expr_(nullptr) {
99   b_ = new IRBuilder<>(ctx());
100 }
~CodegenLLVM()101 CodegenLLVM::~CodegenLLVM() {
102   delete b_;
103 }
104 
105 template <typename... Args>
emit(const char * fmt,Args &&...params)106 void CodegenLLVM::emit(const char *fmt, Args&&... params) {
107   //fprintf(out_, fmt, std::forward<Args>(params)...);
108   //fflush(out_);
109 }
emit(const char * s)110 void CodegenLLVM::emit(const char *s) {
111   //fprintf(out_, "%s", s);
112   //fflush(out_);
113 }
114 
visit_block_stmt_node(BlockStmtNode * n)115 StatusTuple CodegenLLVM::visit_block_stmt_node(BlockStmtNode *n) {
116 
117   // enter scope
118   if (n->scope_)
119     scopes_->push_var(n->scope_);
120 
121   if (!n->stmts_.empty()) {
122     for (auto it = n->stmts_.begin(); it != n->stmts_.end(); ++it)
123       TRY2((*it)->accept(this));
124   }
125   // exit scope
126   if (n->scope_)
127     scopes_->pop_var();
128 
129   return StatusTuple(0);
130 }
131 
visit_if_stmt_node(IfStmtNode * n)132 StatusTuple CodegenLLVM::visit_if_stmt_node(IfStmtNode *n) {
133   Function *parent = B.GetInsertBlock()->getParent();
134   BasicBlock *label_then = BasicBlock::Create(ctx(), "if.then", parent);
135   BasicBlock *label_else = n->false_block_ ? BasicBlock::Create(ctx(), "if.else", parent) : nullptr;
136   BasicBlock *label_end = BasicBlock::Create(ctx(), "if.end", parent);
137 
138   TRY2(n->cond_->accept(this));
139   Value *is_not_null = B.CreateIsNotNull(pop_expr());
140 
141   if (n->false_block_)
142     B.CreateCondBr(is_not_null, label_then, label_else);
143   else
144     B.CreateCondBr(is_not_null, label_then, label_end);
145 
146   {
147     BlockStack bstack(this, label_then);
148     TRY2(n->true_block_->accept(this));
149     if (!B.GetInsertBlock()->getTerminator())
150       B.CreateBr(label_end);
151   }
152 
153   if (n->false_block_) {
154     BlockStack bstack(this, label_else);
155     TRY2(n->false_block_->accept(this));
156     if (!B.GetInsertBlock()->getTerminator())
157       B.CreateBr(label_end);
158   }
159 
160   B.SetInsertPoint(label_end);
161 
162   return StatusTuple(0);
163 }
164 
visit_onvalid_stmt_node(OnValidStmtNode * n)165 StatusTuple CodegenLLVM::visit_onvalid_stmt_node(OnValidStmtNode *n) {
166   TRY2(n->cond_->accept(this));
167 
168   Value *is_null = B.CreateIsNotNull(pop_expr());
169 
170   Function *parent = B.GetInsertBlock()->getParent();
171   BasicBlock *label_then = BasicBlock::Create(ctx(), "onvalid.then", parent);
172   BasicBlock *label_else = n->else_block_ ? BasicBlock::Create(ctx(), "onvalid.else", parent) : nullptr;
173   BasicBlock *label_end = BasicBlock::Create(ctx(), "onvalid.end", parent);
174 
175   if (n->else_block_)
176     B.CreateCondBr(is_null, label_then, label_else);
177   else
178     B.CreateCondBr(is_null, label_then, label_end);
179 
180   {
181     BlockStack bstack(this, label_then);
182     TRY2(n->block_->accept(this));
183     if (!B.GetInsertBlock()->getTerminator())
184       B.CreateBr(label_end);
185   }
186 
187   if (n->else_block_) {
188     BlockStack bstack(this, label_else);
189     TRY2(n->else_block_->accept(this));
190     if (!B.GetInsertBlock()->getTerminator())
191       B.CreateBr(label_end);
192   }
193 
194   B.SetInsertPoint(label_end);
195   return StatusTuple(0);
196 }
197 
visit_switch_stmt_node(SwitchStmtNode * n)198 StatusTuple CodegenLLVM::visit_switch_stmt_node(SwitchStmtNode *n) {
199   Function *parent = B.GetInsertBlock()->getParent();
200   BasicBlock *label_default = BasicBlock::Create(ctx(), "switch.default", parent);
201   BasicBlock *label_end = BasicBlock::Create(ctx(), "switch.end", parent);
202   // switch (cond)
203   TRY2(n->cond_->accept(this));
204   SwitchInst *switch_inst = B.CreateSwitch(pop_expr(), label_default);
205   B.SetInsertPoint(label_end);
206   {
207     // case 1..N
208     SwitchStack sstack(this, switch_inst);
209     TRY2(n->block_->accept(this));
210   }
211   // if other cases are terminal, erase the end label
212   if (pred_empty(label_end)) {
213     B.SetInsertPoint(resolve_label("DONE"));
214     label_end->eraseFromParent();
215   }
216   return StatusTuple(0);
217 }
218 
visit_case_stmt_node(CaseStmtNode * n)219 StatusTuple CodegenLLVM::visit_case_stmt_node(CaseStmtNode *n) {
220   if (!cur_switch_) return mkstatus_(n, "no valid switch instruction");
221   Function *parent = B.GetInsertBlock()->getParent();
222   BasicBlock *label_end = B.GetInsertBlock();
223   BasicBlock *dest;
224   if (n->value_) {
225     TRY2(n->value_->accept(this));
226     dest = BasicBlock::Create(ctx(), "switch.case", parent);
227     Value *cond = B.CreateIntCast(pop_expr(), cur_switch_->getCondition()->getType(), false);
228     cur_switch_->addCase(cast<ConstantInt>(cond), dest);
229   } else {
230     dest = cur_switch_->getDefaultDest();
231   }
232   {
233     BlockStack bstack(this, dest);
234     TRY2(n->block_->accept(this));
235     // if no trailing goto, fall to end
236     if (!B.GetInsertBlock()->getTerminator())
237       B.CreateBr(label_end);
238   }
239   return StatusTuple(0);
240 }
241 
visit_ident_expr_node(IdentExprNode * n)242 StatusTuple CodegenLLVM::visit_ident_expr_node(IdentExprNode *n) {
243   if (!n->decl_)
244     return mkstatus_(n, "variable lookup failed: %s", n->name_.c_str());
245   if (n->decl_->is_pointer()) {
246     if (n->sub_name_.size()) {
247       if (n->bitop_) {
248         // ident is holding a host endian number, don't use dext
249         if (n->is_lhs()) {
250           emit("%s%s->%s", n->decl_->scope_id(), n->c_str(), n->sub_name_.c_str());
251         } else {
252           emit("(((%s%s->%s) >> %d) & (((%s)1 << %d) - 1))", n->decl_->scope_id(), n->c_str(), n->sub_name_.c_str(),
253               n->bitop_->bit_offset_, bits_to_uint(n->bitop_->bit_width_ + 1), n->bitop_->bit_width_);
254         }
255         return mkstatus_(n, "unsupported");
256       } else {
257         if (n->struct_type_->id_->name_ == "_Packet" && n->sub_name_.substr(0, 3) == "arg") {
258           // convert arg1~arg8 into args[0]~args[7] assuming type_check verified the range already
259           auto arg_num = stoi(n->sub_name_.substr(3, 3));
260           if (arg_num < 5) {
261             emit("%s%s->args_lo[%d]", n->decl_->scope_id(), n->c_str(), arg_num - 1);
262           } else {
263             emit("%s%s->args_hi[%d]", n->decl_->scope_id(), n->c_str(), arg_num - 5);
264           }
265           return mkstatus_(n, "unsupported");
266         } else {
267           emit("%s%s->%s", n->decl_->scope_id(), n->c_str(), n->sub_name_.c_str());
268           auto it = vars_.find(n->decl_);
269           if (it == vars_.end()) return mkstatus_(n, "Cannot locate variable %s in vars_ table", n->c_str());
270           LoadInst *load_1 = B.CreateLoad(it->second);
271           vector<Value *> indices({B.getInt32(0), B.getInt32(n->sub_decl_->slot_)});
272           expr_ = B.CreateInBoundsGEP(load_1, indices);
273           if (!n->is_lhs())
274             expr_ = B.CreateLoad(pop_expr());
275         }
276       }
277     } else {
278       auto it = vars_.find(n->decl_);
279       if (it == vars_.end()) return mkstatus_(n, "Cannot locate variable %s in vars_ table", n->c_str());
280       expr_ = n->is_lhs() ? it->second : (Value *)B.CreateLoad(it->second);
281     }
282   } else {
283     if (n->sub_name_.size()) {
284       emit("%s%s.%s", n->decl_->scope_id(), n->c_str(), n->sub_name_.c_str());
285       auto it = vars_.find(n->decl_);
286       if (it == vars_.end()) return mkstatus_(n, "Cannot locate variable %s in vars_ table", n->c_str());
287       vector<Value *> indices({const_int(0), const_int(n->sub_decl_->slot_, 32)});
288       expr_ = B.CreateGEP(nullptr, it->second, indices);
289       if (!n->is_lhs())
290         expr_ = B.CreateLoad(pop_expr());
291     } else {
292       if (n->bitop_) {
293         // ident is holding a host endian number, don't use dext
294         if (n->is_lhs())
295           return mkstatus_(n, "illegal: ident %s is a left-hand-side type", n->name_.c_str());
296         if (n->decl_->is_struct())
297           return mkstatus_(n, "illegal: can only take bitop of a struct subfield");
298         emit("(((%s%s) >> %d) & (((%s)1 << %d) - 1))", n->decl_->scope_id(), n->c_str(),
299              n->bitop_->bit_offset_, bits_to_uint(n->bitop_->bit_width_ + 1), n->bitop_->bit_width_);
300       } else {
301         emit("%s%s", n->decl_->scope_id(), n->c_str());
302         auto it = vars_.find(n->decl_);
303         if (it == vars_.end()) return mkstatus_(n, "Cannot locate variable %s in vars_ table", n->c_str());
304         if (n->is_lhs() || n->decl_->is_struct())
305           expr_ = it->second;
306         else
307           expr_ = B.CreateLoad(it->second);
308       }
309     }
310   }
311   return StatusTuple(0);
312 }
313 
visit_assign_expr_node(AssignExprNode * n)314 StatusTuple CodegenLLVM::visit_assign_expr_node(AssignExprNode *n) {
315   if (n->bitop_) {
316     TRY2(n->lhs_->accept(this));
317     emit(" = (");
318     TRY2(n->lhs_->accept(this));
319     emit(" & ~((((%s)1 << %d) - 1) << %d)) | (", bits_to_uint(n->lhs_->bit_width_),
320          n->bitop_->bit_width_, n->bitop_->bit_offset_);
321     TRY2(n->rhs_->accept(this));
322     emit(" << %d)", n->bitop_->bit_offset_);
323     return mkstatus_(n, "unsupported");
324   } else {
325     if (n->lhs_->flags_[ExprNode::PROTO]) {
326       // auto f = n->lhs_->struct_type_->field(n->id_->sub_name_);
327       // emit("bpf_dins(%s%s + %zu, %zu, %zu, ", n->id_->decl_->scope_id(), n->id_->c_str(),
328       //      f->bit_offset_ >> 3, f->bit_offset_ & 0x7, f->bit_width_);
329       // TRY2(n->rhs_->accept(this));
330       // emit(")");
331       return mkstatus_(n, "unsupported");
332     } else {
333       TRY2(n->rhs_->accept(this));
334       if (n->lhs_->is_pkt()) {
335         TRY2(n->lhs_->accept(this));
336       } else {
337         Value *rhs = pop_expr();
338         TRY2(n->lhs_->accept(this));
339         Value *lhs = pop_expr();
340         if (!n->rhs_->is_ref())
341           rhs = B.CreateIntCast(rhs, cast<PointerType>(lhs->getType())->getElementType(), false);
342         B.CreateStore(rhs, lhs);
343       }
344     }
345   }
346   return StatusTuple(0);
347 }
348 
lookup_var(Node * n,const string & name,Scopes::VarScope * scope,VariableDeclStmtNode ** decl,Value ** mem) const349 StatusTuple CodegenLLVM::lookup_var(Node *n, const string &name, Scopes::VarScope *scope,
350                                     VariableDeclStmtNode **decl, Value **mem) const {
351   *decl = scope->lookup(name, SCOPE_GLOBAL);
352   if (!*decl) return mkstatus_(n, "cannot find %s variable", name.c_str());
353   auto it = vars_.find(*decl);
354   if (it == vars_.end()) return mkstatus_(n, "unable to find %s memory location", name.c_str());
355   *mem = it->second;
356   return StatusTuple(0);
357 }
358 
// Lowers a protocol-header field access (header name in n->id_->name_, field
// in n->id_->sub_name_), optionally narrowed by a bit-range, against `skb`.
//
// Looks up the header struct in proto_scopes_, plus the `skb` variable and the
// `$<header>` variable. The `$<header>` value is added to the field's byte
// offset, so it presumably holds the header's byte offset within the packet.
// For a known field:
//   - reference use (n->is_ref()): expr_ = header offset + field byte offset;
//   - store (n->is_lhs()): bpf_dins_pkt(skb, off, bit_off & 7, width, value),
//     where value is popped from expr_ (left there by the assignment);
//   - load: expr_ = bpf_dext_pkt(skb, off, bit_off & 7, width).
// Unknown fields report "unsupported".
// NOTE(review): if the header struct itself is not found (p is null), this
// returns success without setting expr_. Confirm callers cannot reach that.
StatusTuple CodegenLLVM::visit_packet_expr_node(PacketExprNode *n) {
  auto p = proto_scopes_->top_struct()->lookup(n->id_->name_, true);
  VariableDeclStmtNode *offset_decl, *skb_decl;
  Value *offset_mem, *skb_mem;
  TRY2(lookup_var(n, "skb", scopes_->current_var(), &skb_decl, &skb_mem));
  TRY2(lookup_var(n, "$" + n->id_->name_, scopes_->current_var(), &offset_decl, &offset_mem));

  if (p) {
    auto f = p->field(n->id_->sub_name_);
    if (f) {
      size_t bit_offset = f->bit_offset_;
      size_t bit_width = f->bit_width_;
      if (n->bitop_) {
        // The bit-range is given relative to the field's low end. Shift the
        // start accordingly and clamp the width to what remains of the field.
        bit_offset += f->bit_width_ - (n->bitop_->bit_offset_ + n->bitop_->bit_width_);
        bit_width = std::min(bit_width - n->bitop_->bit_offset_, n->bitop_->bit_width_);
      }
      if (n->is_ref()) {
        // e.g.: @ip.hchecksum, return offset of the header within packet
        LoadInst *offset_ptr = B.CreateLoad(offset_mem);
        Value *skb_hdr_offset = B.CreateAdd(offset_ptr, B.getInt64(bit_offset >> 3));
        expr_ = B.CreateIntCast(skb_hdr_offset, B.getInt64Ty(), false);
      } else if (n->is_lhs()) {
        emit("bpf_dins_pkt(pkt, %s + %zu, %zu, %zu, ", n->id_->c_str(), bit_offset >> 3, bit_offset & 0x7, bit_width);
        Function *store_fn = mod_->getFunction("bpf_dins_pkt");
        if (!store_fn) return mkstatus_(n, "unable to find function bpf_dins_pkt");
        LoadInst *skb_ptr = B.CreateLoad(skb_mem);
        Value *skb_ptr8 = B.CreateBitCast(skb_ptr, B.getInt8PtrTy());
        LoadInst *offset_ptr = B.CreateLoad(offset_mem);
        Value *skb_hdr_offset = B.CreateAdd(offset_ptr, B.getInt64(bit_offset >> 3));
        // Value being stored: the rhs pending in expr_, widened to u64.
        Value *rhs = B.CreateIntCast(pop_expr(), B.getInt64Ty(), false);
        B.CreateCall(store_fn, vector<Value *>({skb_ptr8, skb_hdr_offset, B.getInt64(bit_offset & 0x7),
                                               B.getInt64(bit_width), rhs}));
      } else {
        emit("bpf_dext_pkt(pkt, %s + %zu, %zu, %zu)", n->id_->c_str(), bit_offset >> 3, bit_offset & 0x7, bit_width);
        Function *load_fn = mod_->getFunction("bpf_dext_pkt");
        if (!load_fn) return mkstatus_(n, "unable to find function bpf_dext_pkt");
        LoadInst *skb_ptr = B.CreateLoad(skb_mem);
        Value *skb_ptr8 = B.CreateBitCast(skb_ptr, B.getInt8PtrTy());
        LoadInst *offset_ptr = B.CreateLoad(offset_mem);
        Value *skb_hdr_offset = B.CreateAdd(offset_ptr, B.getInt64(bit_offset >> 3));
        expr_ = B.CreateCall(load_fn, vector<Value *>({skb_ptr8, skb_hdr_offset,
                                                      B.getInt64(bit_offset & 0x7), B.getInt64(bit_width)}));
        // this generates extra trunc insns whereas the bpf.load fns already
        // trunc the values internally in the bpf interpreter
        //expr_ = B.CreateTrunc(pop_expr(), B.getIntNTy(bit_width));
      }
    } else {
      emit("pkt->start + pkt->offset + %s", n->id_->c_str());
      return mkstatus_(n, "unsupported");
    }
  }
  return StatusTuple(0);
}
412 
visit_integer_expr_node(IntegerExprNode * n)413 StatusTuple CodegenLLVM::visit_integer_expr_node(IntegerExprNode *n) {
414   APInt val;
415   StringRef(n->val_).getAsInteger(0, val);
416   expr_ = ConstantInt::get(mod_->getContext(), val);
417   if (n->bits_)
418     expr_ = B.CreateIntCast(expr_, B.getIntNTy(n->bits_), false);
419   return StatusTuple(0);
420 }
421 
visit_string_expr_node(StringExprNode * n)422 StatusTuple CodegenLLVM::visit_string_expr_node(StringExprNode *n) {
423   if (n->is_lhs()) return mkstatus_(n, "cannot assign to a string");
424 
425   Value *global = B.CreateGlobalString(n->val_);
426   Value *ptr = make_alloca(resolve_entry_stack(), B.getInt8Ty(), "",
427                            B.getInt64(n->val_.size() + 1));
428 #if LLVM_MAJOR_VERSION >= 7
429   B.CreateMemCpy(ptr, 1, global, 1, n->val_.size() + 1);
430 #else
431   B.CreateMemCpy(ptr, global, n->val_.size() + 1, 1);
432 #endif
433   expr_ = ptr;
434 
435   return StatusTuple(0);
436 }
437 
emit_short_circuit_and(BinopExprNode * n)438 StatusTuple CodegenLLVM::emit_short_circuit_and(BinopExprNode *n) {
439   Function *parent = B.GetInsertBlock()->getParent();
440   BasicBlock *label_start = B.GetInsertBlock();
441   BasicBlock *label_then = BasicBlock::Create(ctx(), "and.then", parent);
442   BasicBlock *label_end = BasicBlock::Create(ctx(), "and.end", parent);
443 
444   TRY2(n->lhs_->accept(this));
445   Value *neq_zero = B.CreateICmpNE(pop_expr(), B.getIntN(n->lhs_->bit_width_, 0));
446   B.CreateCondBr(neq_zero, label_then, label_end);
447 
448   {
449     BlockStack bstack(this, label_then);
450     TRY2(n->rhs_->accept(this));
451     expr_ = B.CreateICmpNE(pop_expr(), B.getIntN(n->rhs_->bit_width_, 0));
452     B.CreateBr(label_end);
453   }
454 
455   B.SetInsertPoint(label_end);
456 
457   PHINode *phi = B.CreatePHI(B.getInt1Ty(), 2);
458   phi->addIncoming(B.getFalse(), label_start);
459   phi->addIncoming(pop_expr(), label_then);
460   expr_ = phi;
461 
462   return StatusTuple(0);
463 }
464 
emit_short_circuit_or(BinopExprNode * n)465 StatusTuple CodegenLLVM::emit_short_circuit_or(BinopExprNode *n) {
466   Function *parent = B.GetInsertBlock()->getParent();
467   BasicBlock *label_start = B.GetInsertBlock();
468   BasicBlock *label_then = BasicBlock::Create(ctx(), "or.then", parent);
469   BasicBlock *label_end = BasicBlock::Create(ctx(), "or.end", parent);
470 
471   TRY2(n->lhs_->accept(this));
472   Value *neq_zero = B.CreateICmpNE(pop_expr(), B.getIntN(n->lhs_->bit_width_, 0));
473   B.CreateCondBr(neq_zero, label_end, label_then);
474 
475   {
476     BlockStack bstack(this, label_then);
477     TRY2(n->rhs_->accept(this));
478     expr_ = B.CreateICmpNE(pop_expr(), B.getIntN(n->rhs_->bit_width_, 0));
479     B.CreateBr(label_end);
480   }
481 
482   B.SetInsertPoint(label_end);
483 
484   PHINode *phi = B.CreatePHI(B.getInt1Ty(), 2);
485   phi->addIncoming(B.getTrue(), label_start);
486   phi->addIncoming(pop_expr(), label_then);
487   expr_ = phi;
488 
489   return StatusTuple(0);
490 }
491 
visit_binop_expr_node(BinopExprNode * n)492 StatusTuple CodegenLLVM::visit_binop_expr_node(BinopExprNode *n) {
493   if (n->op_ == Tok::TAND)
494     return emit_short_circuit_and(n);
495   if (n->op_ == Tok::TOR)
496     return emit_short_circuit_or(n);
497 
498   TRY2(n->lhs_->accept(this));
499   Value *lhs = pop_expr();
500   TRY2(n->rhs_->accept(this));
501   Value *rhs = B.CreateIntCast(pop_expr(), lhs->getType(), false);
502   switch (n->op_) {
503     case Tok::TCEQ: expr_ = B.CreateICmpEQ(lhs, rhs); break;
504     case Tok::TCNE: expr_ = B.CreateICmpNE(lhs, rhs); break;
505     case Tok::TXOR: expr_ = B.CreateXor(lhs, rhs); break;
506     case Tok::TMOD: expr_ = B.CreateURem(lhs, rhs); break;
507     case Tok::TCLT: expr_ = B.CreateICmpULT(lhs, rhs); break;
508     case Tok::TCLE: expr_ = B.CreateICmpULE(lhs, rhs); break;
509     case Tok::TCGT: expr_ = B.CreateICmpUGT(lhs, rhs); break;
510     case Tok::TCGE: expr_ = B.CreateICmpUGE(lhs, rhs); break;
511     case Tok::TPLUS: expr_ = B.CreateAdd(lhs, rhs); break;
512     case Tok::TMINUS: expr_ = B.CreateSub(lhs, rhs); break;
513     case Tok::TLAND: expr_ = B.CreateAnd(lhs, rhs); break;
514     case Tok::TLOR: expr_ = B.CreateOr(lhs, rhs); break;
515     default: return mkstatus_(n, "unsupported binary operator");
516   }
517   return StatusTuple(0);
518 }
519 
visit_unop_expr_node(UnopExprNode * n)520 StatusTuple CodegenLLVM::visit_unop_expr_node(UnopExprNode *n) {
521   TRY2(n->expr_->accept(this));
522   switch (n->op_) {
523     case Tok::TNOT: expr_ = B.CreateNot(pop_expr()); break;
524     case Tok::TCMPL: expr_ = B.CreateNeg(pop_expr()); break;
525     default: {}
526   }
527   return StatusTuple(0);
528 }
529 
visit_bitop_expr_node(BitopExprNode * n)530 StatusTuple CodegenLLVM::visit_bitop_expr_node(BitopExprNode *n) {
531   return StatusTuple(0);
532 }
533 
visit_goto_expr_node(GotoExprNode * n)534 StatusTuple CodegenLLVM::visit_goto_expr_node(GotoExprNode *n) {
535   if (n->id_->name_ == "DONE") {
536     return mkstatus_(n, "use return statement instead");
537   }
538   string jump_label;
539   // when dealing with multistates, goto statements may be overridden
540   auto rewrite_it = proto_rewrites_.find(n->id_->full_name());
541   auto default_it = proto_rewrites_.find("");
542   if (rewrite_it != proto_rewrites_.end()) {
543     jump_label = rewrite_it->second;
544   } else if (default_it != proto_rewrites_.end()) {
545     jump_label = default_it->second;
546   } else {
547     auto state = scopes_->current_state()->lookup(n->id_->full_name(), false);
548     if (state) {
549       jump_label = state->scoped_name();
550       if (n->is_continue_) {
551         jump_label += "_continue";
552       }
553     } else {
554       state = scopes_->current_state()->lookup("EOP", false);
555       if (state) {
556         jump_label = state->scoped_name();
557       }
558     }
559   }
560   B.CreateBr(resolve_label(jump_label));
561   return StatusTuple(0);
562 }
563 
visit_return_expr_node(ReturnExprNode * n)564 StatusTuple CodegenLLVM::visit_return_expr_node(ReturnExprNode *n) {
565   TRY2(n->expr_->accept(this));
566   Function *parent = B.GetInsertBlock()->getParent();
567   Value *cast_1 = B.CreateIntCast(pop_expr(), parent->getReturnType(), true);
568   B.CreateStore(cast_1, retval_);
569   B.CreateBr(resolve_label("DONE"));
570   return StatusTuple(0);
571 }
572 
emit_table_lookup(MethodCallExprNode * n)573 StatusTuple CodegenLLVM::emit_table_lookup(MethodCallExprNode *n) {
574   TableDeclStmtNode* table = scopes_->top_table()->lookup(n->id_->name_);
575   IdentExprNode* arg0 = static_cast<IdentExprNode*>(n->args_.at(0).get());
576   IdentExprNode* arg1;
577   StructVariableDeclStmtNode* arg1_type;
578 
579   auto table_fd_it = table_fds_.find(table);
580   if (table_fd_it == table_fds_.end())
581     return mkstatus_(n, "unable to find table %s in table_fds_", n->id_->c_str());
582 
583   Function *pseudo_fn = mod_->getFunction("llvm.bpf.pseudo");
584   if (!pseudo_fn) return mkstatus_(n, "pseudo fd loader doesn't exist");
585   Function *lookup_fn = mod_->getFunction("bpf_map_lookup_elem_");
586   if (!lookup_fn) return mkstatus_(n, "bpf_map_lookup_elem_ undefined");
587 
588   CallInst *pseudo_call = B.CreateCall(pseudo_fn, vector<Value *>({B.getInt64(BPF_PSEUDO_MAP_FD),
589                                                                   B.getInt64(table_fd_it->second)}));
590   Value *pseudo_map_fd = pseudo_call;
591 
592   TRY2(arg0->accept(this));
593   Value *key_ptr = B.CreateBitCast(pop_expr(), B.getInt8PtrTy());
594 
595   expr_ = B.CreateCall(lookup_fn, vector<Value *>({pseudo_map_fd, key_ptr}));
596 
597   if (table->type_id()->name_ == "FIXED_MATCH" || table->type_id()->name_ == "INDEXED") {
598     if (n->args_.size() == 2) {
599       arg1 = static_cast<IdentExprNode*>(n->args_.at(1).get());
600       arg1_type = static_cast<StructVariableDeclStmtNode*>(arg1->decl_);
601       if (table->leaf_id()->name_ != arg1_type->struct_id_->name_) {
602         return mkstatus_(n, "lookup pointer type mismatch %s != %s", table->leaf_id()->c_str(),
603                         arg1_type->struct_id_->c_str());
604       }
605       auto it = vars_.find(arg1_type);
606       if (it == vars_.end()) return mkstatus_(n, "Cannot locate variable %s in vars_ table", n->id_->c_str());
607       expr_ = B.CreateBitCast(pop_expr(), cast<PointerType>(it->second->getType())->getElementType());
608       B.CreateStore(pop_expr(), it->second);
609     }
610   } else {
611     return mkstatus_(n, "lookup in table type %s unsupported", table->type_id()->c_str());
612   }
613   return StatusTuple(0);
614 }
615 
emit_table_update(MethodCallExprNode * n)616 StatusTuple CodegenLLVM::emit_table_update(MethodCallExprNode *n) {
617   TableDeclStmtNode* table = scopes_->top_table()->lookup(n->id_->name_);
618   IdentExprNode* arg0 = static_cast<IdentExprNode*>(n->args_.at(0).get());
619   IdentExprNode* arg1 = static_cast<IdentExprNode*>(n->args_.at(1).get());
620 
621   auto table_fd_it = table_fds_.find(table);
622   if (table_fd_it == table_fds_.end())
623     return mkstatus_(n, "unable to find table %s in table_fds_", n->id_->c_str());
624   Function *pseudo_fn = mod_->getFunction("llvm.bpf.pseudo");
625   if (!pseudo_fn) return mkstatus_(n, "pseudo fd loader doesn't exist");
626   Function *update_fn = mod_->getFunction("bpf_map_update_elem_");
627   if (!update_fn) return mkstatus_(n, "bpf_map_update_elem_ undefined");
628 
629   CallInst *pseudo_call = B.CreateCall(pseudo_fn, vector<Value *>({B.getInt64(BPF_PSEUDO_MAP_FD),
630                                         B.getInt64(table_fd_it->second)}));
631   Value *pseudo_map_fd = pseudo_call;
632 
633   TRY2(arg0->accept(this));
634   Value *key_ptr = B.CreateBitCast(pop_expr(), B.getInt8PtrTy());
635 
636   if (table->type_id()->name_ == "FIXED_MATCH" || table->type_id()->name_ == "INDEXED") {
637     TRY2(arg1->accept(this));
638     Value *value_ptr = B.CreateBitCast(pop_expr(), B.getInt8PtrTy());
639 
640     expr_ = B.CreateCall(update_fn, vector<Value *>({pseudo_map_fd, key_ptr, value_ptr, B.getInt64(BPF_ANY)}));
641   } else {
642     return mkstatus_(n, "unsupported");
643   }
644   return StatusTuple(0);
645 }
646 
emit_table_delete(MethodCallExprNode * n)647 StatusTuple CodegenLLVM::emit_table_delete(MethodCallExprNode *n) {
648   TableDeclStmtNode* table = scopes_->top_table()->lookup(n->id_->name_);
649   IdentExprNode* arg0 = static_cast<IdentExprNode*>(n->args_.at(0).get());
650 
651   auto table_fd_it = table_fds_.find(table);
652   if (table_fd_it == table_fds_.end())
653     return mkstatus_(n, "unable to find table %s in table_fds_", n->id_->c_str());
654   Function *pseudo_fn = mod_->getFunction("llvm.bpf.pseudo");
655   if (!pseudo_fn) return mkstatus_(n, "pseudo fd loader doesn't exist");
656   Function *update_fn = mod_->getFunction("bpf_map_update_elem_");
657   if (!update_fn) return mkstatus_(n, "bpf_map_update_elem_ undefined");
658 
659   CallInst *pseudo_call = B.CreateCall(pseudo_fn, vector<Value *>({B.getInt64(BPF_PSEUDO_MAP_FD),
660                                         B.getInt64(table_fd_it->second)}));
661   Value *pseudo_map_fd = pseudo_call;
662 
663   TRY2(arg0->accept(this));
664   Value *key_ptr = B.CreateBitCast(pop_expr(), B.getInt8PtrTy());
665 
666   if (table->type_id()->name_ == "FIXED_MATCH" || table->type_id()->name_ == "INDEXED") {
667     expr_ = B.CreateCall(update_fn, vector<Value *>({pseudo_map_fd, key_ptr}));
668   } else {
669     return mkstatus_(n, "unsupported");
670   }
671   return StatusTuple(0);
672 }
673 
emit_log(MethodCallExprNode * n)674 StatusTuple CodegenLLVM::emit_log(MethodCallExprNode *n) {
675   vector<Value *> args;
676   auto arg = n->args_.begin();
677   TRY2((*arg)->accept(this));
678   args.push_back(pop_expr());
679   args.push_back(B.getInt64(((*arg)->bit_width_ >> 3) + 1));
680   ++arg;
681   for (; arg != n->args_.end(); ++arg) {
682     TRY2((*arg)->accept(this));
683     args.push_back(pop_expr());
684   }
685 
686   // int bpf_trace_printk(fmt, sizeof(fmt), ...)
687   FunctionType *printk_fn_type = FunctionType::get(B.getInt32Ty(), vector<Type *>({B.getInt8PtrTy(), B.getInt64Ty()}), true);
688   Value *printk_fn = B.CreateIntToPtr(B.getInt64(BPF_FUNC_trace_printk),
689                                          PointerType::getUnqual(printk_fn_type));
690 
691   expr_ = B.CreateCall(printk_fn, args);
692   return StatusTuple(0);
693 }
694 
emit_packet_rewrite_field(MethodCallExprNode * n)695 StatusTuple CodegenLLVM::emit_packet_rewrite_field(MethodCallExprNode *n) {
696   TRY2(n->args_[1]->accept(this));
697   TRY2(n->args_[0]->accept(this));
698   return StatusTuple(0);
699 }
700 
emit_atomic_add(MethodCallExprNode * n)701 StatusTuple CodegenLLVM::emit_atomic_add(MethodCallExprNode *n) {
702   TRY2(n->args_[0]->accept(this));
703   Value *lhs = B.CreateBitCast(pop_expr(), Type::getInt64PtrTy(ctx()));
704   TRY2(n->args_[1]->accept(this));
705   Value *rhs = B.CreateSExt(pop_expr(), B.getInt64Ty());
706   AtomicRMWInst *atomic_inst = B.CreateAtomicRMW(
707       AtomicRMWInst::Add, lhs, rhs, AtomicOrdering::SequentiallyConsistent);
708   atomic_inst->setVolatile(false);
709   return StatusTuple(0);
710 }
711 
emit_incr_cksum(MethodCallExprNode * n,size_t sz)712 StatusTuple CodegenLLVM::emit_incr_cksum(MethodCallExprNode *n, size_t sz) {
713   Value *is_pseudo;
714   string csum_fn_str;
715   if (n->args_.size() == 4) {
716     TRY2(n->args_[3]->accept(this));
717     is_pseudo = B.CreateIntCast(B.CreateIsNotNull(pop_expr()), B.getInt64Ty(), false);
718     csum_fn_str = "bpf_l4_csum_replace_";
719   } else {
720     is_pseudo = B.getInt64(0);
721     csum_fn_str = "bpf_l3_csum_replace_";
722   }
723 
724   TRY2(n->args_[2]->accept(this));
725   Value *new_val = B.CreateZExt(pop_expr(), B.getInt64Ty());
726   TRY2(n->args_[1]->accept(this));
727   Value *old_val = B.CreateZExt(pop_expr(), B.getInt64Ty());
728   TRY2(n->args_[0]->accept(this));
729   Value *offset = B.CreateZExt(pop_expr(), B.getInt64Ty());
730 
731   Function *csum_fn = mod_->getFunction(csum_fn_str);
732   if (!csum_fn) return mkstatus_(n, "Undefined built-in %s", csum_fn_str.c_str());
733 
734   // flags = (is_pseudo << 4) | sizeof(old_val)
735   Value *flags_lower = B.getInt64(sz ? sz : bits_to_size(n->args_[1]->bit_width_));
736   Value *flags_upper = B.CreateShl(is_pseudo, B.getInt64(4));
737   Value *flags = B.CreateOr(flags_upper, flags_lower);
738 
739   VariableDeclStmtNode *skb_decl;
740   Value *skb_mem;
741   TRY2(lookup_var(n, "skb", scopes_->current_var(), &skb_decl, &skb_mem));
742   LoadInst *skb_ptr = B.CreateLoad(skb_mem);
743   Value *skb_ptr8 = B.CreateBitCast(skb_ptr, B.getInt8PtrTy());
744 
745   expr_ = B.CreateCall(csum_fn, vector<Value *>({skb_ptr8, offset, old_val, new_val, flags}));
746   return StatusTuple(0);
747 }
748 
emit_get_usec_time(MethodCallExprNode * n)749 StatusTuple CodegenLLVM::emit_get_usec_time(MethodCallExprNode *n) {
750   return StatusTuple(0);
751 }
752 
visit_method_call_expr_node(MethodCallExprNode * n)753 StatusTuple CodegenLLVM::visit_method_call_expr_node(MethodCallExprNode *n) {
754   if (n->id_->sub_name_.size()) {
755     if (n->id_->sub_name_ == "lookup") {
756       TRY2(emit_table_lookup(n));
757     } else if (n->id_->sub_name_ == "update") {
758       TRY2(emit_table_update(n));
759     } else if (n->id_->sub_name_ == "delete") {
760       TRY2(emit_table_delete(n));
761     } else if (n->id_->sub_name_ == "rewrite_field" && n->id_->name_ == "pkt") {
762       TRY2(emit_packet_rewrite_field(n));
763     }
764   } else if (n->id_->name_ == "atomic_add") {
765     TRY2(emit_atomic_add(n));
766   } else if (n->id_->name_ == "log") {
767     TRY2(emit_log(n));
768   } else if (n->id_->name_ == "incr_cksum") {
769     TRY2(emit_incr_cksum(n));
770   } else if (n->id_->name_ == "get_usec_time") {
771     TRY2(emit_get_usec_time(n));
772   } else {
773     return mkstatus_(n, "unsupported");
774   }
775   TRY2(n->block_->accept(this));
776   return StatusTuple(0);
777 }
778 
779 /* result = lookup(key)
780  * if (!result) {
781  *   update(key, {0}, BPF_NOEXIST)
782  *   result = lookup(key)
783  * }
784  */
visit_table_index_expr_node(TableIndexExprNode * n)785 StatusTuple CodegenLLVM::visit_table_index_expr_node(TableIndexExprNode *n) {
786   auto table_fd_it = table_fds_.find(n->table_);
787   if (table_fd_it == table_fds_.end())
788     return mkstatus_(n, "unable to find table %s in table_fds_", n->id_->c_str());
789 
790   Function *pseudo_fn = mod_->getFunction("llvm.bpf.pseudo");
791   if (!pseudo_fn) return mkstatus_(n, "pseudo fd loader doesn't exist");
792   Function *update_fn = mod_->getFunction("bpf_map_update_elem_");
793   if (!update_fn) return mkstatus_(n, "bpf_map_update_elem_ undefined");
794   Function *lookup_fn = mod_->getFunction("bpf_map_lookup_elem_");
795   if (!lookup_fn) return mkstatus_(n, "bpf_map_lookup_elem_ undefined");
796   StructType *leaf_type;
797   TRY2(lookup_struct_type(n->table_->leaf_type_, &leaf_type));
798   PointerType *leaf_ptype = PointerType::getUnqual(leaf_type);
799 
800   CallInst *pseudo_call = B.CreateCall(pseudo_fn, vector<Value *>({B.getInt64(BPF_PSEUDO_MAP_FD),
801                                         B.getInt64(table_fd_it->second)}));
802   Value *pseudo_map_fd = pseudo_call;
803 
804   TRY2(n->index_->accept(this));
805   Value *key_ptr = B.CreateBitCast(pop_expr(), B.getInt8PtrTy());
806 
807   // result = lookup(key)
808   Value *lookup1 = B.CreateBitCast(B.CreateCall(lookup_fn, vector<Value *>({pseudo_map_fd, key_ptr})), leaf_ptype);
809 
810   Value *result = nullptr;
811   if (n->table_->policy_id()->name_ == "AUTO") {
812     Function *parent = B.GetInsertBlock()->getParent();
813     BasicBlock *label_start = B.GetInsertBlock();
814     BasicBlock *label_then = BasicBlock::Create(ctx(), n->id_->name_ + "[].then", parent);
815     BasicBlock *label_end = BasicBlock::Create(ctx(), n->id_->name_ + "[].end", parent);
816 
817     Value *eq_zero = B.CreateIsNull(lookup1);
818     B.CreateCondBr(eq_zero, label_then, label_end);
819 
820     B.SetInsertPoint(label_then);
821     // var Leaf leaf {0}
822     Value *leaf_ptr = B.CreateBitCast(
823         make_alloca(resolve_entry_stack(), leaf_type), B.getInt8PtrTy());
824     B.CreateMemSet(leaf_ptr, B.getInt8(0), B.getInt64(n->table_->leaf_id()->bit_width_ >> 3), 1);
825     // update(key, leaf)
826     B.CreateCall(update_fn, vector<Value *>({pseudo_map_fd, key_ptr, leaf_ptr, B.getInt64(BPF_NOEXIST)}));
827 
828     // result = lookup(key)
829     Value *lookup2 = B.CreateBitCast(B.CreateCall(lookup_fn, vector<Value *>({pseudo_map_fd, key_ptr})), leaf_ptype);
830     B.CreateBr(label_end);
831 
832     B.SetInsertPoint(label_end);
833 
834     PHINode *phi = B.CreatePHI(leaf_ptype, 2);
835     phi->addIncoming(lookup1, label_start);
836     phi->addIncoming(lookup2, label_then);
837     result = phi;
838   } else if (n->table_->policy_id()->name_ == "NONE") {
839     result = lookup1;
840   }
841 
842   if (n->is_lhs()) {
843     if (n->sub_decl_) {
844       Type *ptr_type = PointerType::getUnqual(B.getIntNTy(n->sub_decl_->bit_width_));
845       // u64 *errval -> uN *errval
846       Value *err_cast = B.CreateBitCast(errval_, ptr_type);
847       // if valid then &field, else &errval
848       Function *parent = B.GetInsertBlock()->getParent();
849       BasicBlock *label_start = B.GetInsertBlock();
850       BasicBlock *label_then = BasicBlock::Create(ctx(), n->id_->name_ + "[]field.then", parent);
851       BasicBlock *label_end = BasicBlock::Create(ctx(), n->id_->name_ + "[]field.end", parent);
852 
853       if (1) {
854         // the PHI implementation of this doesn't load, maybe eBPF limitation?
855         B.CreateCondBr(B.CreateIsNull(result), label_then, label_end);
856         B.SetInsertPoint(label_then);
857         B.CreateStore(B.getInt32(2), retval_);
858         B.CreateBr(resolve_label("DONE"));
859 
860         B.SetInsertPoint(label_end);
861         vector<Value *> indices({B.getInt32(0), B.getInt32(n->sub_decl_->slot_)});
862         expr_ = B.CreateInBoundsGEP(result, indices);
863       } else {
864         B.CreateCondBr(B.CreateIsNotNull(result), label_then, label_end);
865 
866         B.SetInsertPoint(label_then);
867         vector<Value *> indices({B.getInt32(0), B.getInt32(n->sub_decl_->slot_)});
868         Value *field = B.CreateInBoundsGEP(result, indices);
869         B.CreateBr(label_end);
870 
871         B.SetInsertPoint(label_end);
872         PHINode *phi = B.CreatePHI(ptr_type, 2);
873         phi->addIncoming(err_cast, label_start);
874         phi->addIncoming(field, label_then);
875         expr_ = phi;
876       }
877     } else {
878       return mkstatus_(n, "unsupported");
879     }
880   } else {
881     expr_ = result;
882   }
883   return StatusTuple(0);
884 }
885 
886 /// on_match
visit_match_decl_stmt_node(MatchDeclStmtNode * n)887 StatusTuple CodegenLLVM::visit_match_decl_stmt_node(MatchDeclStmtNode *n) {
888   if (n->formals_.size() != 1)
889     return mkstatus_(n, "on_match expected 1 arguments, %zu given", n->formals_.size());
890   StructVariableDeclStmtNode* leaf_n = static_cast<StructVariableDeclStmtNode*>(n->formals_.at(0).get());
891   if (!leaf_n)
892     return mkstatus_(n, "invalid parameter type");
893   // lookup result variable
894   auto result_decl = scopes_->current_var()->lookup("_result", false);
895   if (!result_decl) return mkstatus_(n, "unable to find _result built-in");
896   auto result = vars_.find(result_decl);
897   if (result == vars_.end()) return mkstatus_(n, "unable to find memory for _result built-in");
898   vars_[leaf_n] = result->second;
899 
900   Value *load_1 = B.CreateLoad(result->second);
901   Value *is_null = B.CreateIsNotNull(load_1);
902 
903   Function *parent = B.GetInsertBlock()->getParent();
904   BasicBlock *label_then = BasicBlock::Create(ctx(), "onvalid.then", parent);
905   BasicBlock *label_end = BasicBlock::Create(ctx(), "onvalid.end", parent);
906   B.CreateCondBr(is_null, label_then, label_end);
907 
908   {
909     BlockStack bstack(this, label_then);
910     TRY2(n->block_->accept(this));
911     if (!B.GetInsertBlock()->getTerminator())
912       B.CreateBr(label_end);
913   }
914 
915   B.SetInsertPoint(label_end);
916   return StatusTuple(0);
917 }
918 
919 /// on_miss
visit_miss_decl_stmt_node(MissDeclStmtNode * n)920 StatusTuple CodegenLLVM::visit_miss_decl_stmt_node(MissDeclStmtNode *n) {
921   if (n->formals_.size() != 0)
922     return mkstatus_(n, "on_match expected 0 arguments, %zu given", n->formals_.size());
923   auto result_decl = scopes_->current_var()->lookup("_result", false);
924   if (!result_decl) return mkstatus_(n, "unable to find _result built-in");
925   auto result = vars_.find(result_decl);
926   if (result == vars_.end()) return mkstatus_(n, "unable to find memory for _result built-in");
927 
928   Value *load_1 = B.CreateLoad(result->second);
929   Value *is_null = B.CreateIsNull(load_1);
930 
931   Function *parent = B.GetInsertBlock()->getParent();
932   BasicBlock *label_then = BasicBlock::Create(ctx(), "onvalid.then", parent);
933   BasicBlock *label_end = BasicBlock::Create(ctx(), "onvalid.end", parent);
934   B.CreateCondBr(is_null, label_then, label_end);
935 
936   {
937     BlockStack bstack(this, label_then);
938     TRY2(n->block_->accept(this));
939     if (!B.GetInsertBlock()->getTerminator())
940       B.CreateBr(label_end);
941   }
942 
943   B.SetInsertPoint(label_end);
944   return StatusTuple(0);
945 }
946 
visit_failure_decl_stmt_node(FailureDeclStmtNode * n)947 StatusTuple CodegenLLVM::visit_failure_decl_stmt_node(FailureDeclStmtNode *n) {
948   return mkstatus_(n, "unsupported");
949 }
950 
visit_expr_stmt_node(ExprStmtNode * n)951 StatusTuple CodegenLLVM::visit_expr_stmt_node(ExprStmtNode *n) {
952   TRY2(n->expr_->accept(this));
953   expr_ = nullptr;
954   return StatusTuple(0);
955 }
956 
visit_struct_variable_decl_stmt_node(StructVariableDeclStmtNode * n)957 StatusTuple CodegenLLVM::visit_struct_variable_decl_stmt_node(StructVariableDeclStmtNode *n) {
958   if (n->struct_id_->name_ == "" || n->struct_id_->name_[0] == '_') {
959     return StatusTuple(0);
960   }
961 
962   StructType *stype;
963   StructDeclStmtNode *decl;
964   TRY2(lookup_struct_type(n, &stype, &decl));
965 
966   Type *ptr_stype = n->is_pointer() ? PointerType::getUnqual(stype) : (PointerType *)stype;
967   AllocaInst *ptr_a = make_alloca(resolve_entry_stack(), ptr_stype);
968   vars_[n] = ptr_a;
969 
970   if (n->struct_id_->scope_name_ == "proto") {
971     if (n->is_pointer()) {
972       ConstantPointerNull *const_null = ConstantPointerNull::get(cast<PointerType>(ptr_stype));
973       B.CreateStore(const_null, ptr_a);
974     } else {
975       return mkstatus_(n, "unsupported");
976       // string var = n->scope_id() + n->id_->name_;
977       // /* zero initialize array to be filled in with packet header */
978       // emit("uint64_t __%s[%zu] = {}; uint8_t *%s = (uint8_t*)__%s;",
979       //      var.c_str(), ((decl->bit_width_ >> 3) + 7) >> 3, var.c_str(), var.c_str());
980       // for (auto it = n->init_.begin(); it != n->init_.end(); ++it) {
981       //   auto asn = static_cast<AssignExprNode*>(it->get());
982       //   if (auto f = decl->field(asn->id_->sub_name_)) {
983       //     size_t bit_offset = f->bit_offset_;
984       //     size_t bit_width = f->bit_width_;
985       //     if (asn->bitop_) {
986       //       bit_offset += f->bit_width_ - (asn->bitop_->bit_offset_ + asn->bitop_->bit_width_);
987       //       bit_width = std::min(bit_width - asn->bitop_->bit_offset_, asn->bitop_->bit_width_);
988       //     }
989       //     emit(" bpf_dins(%s + %zu, %zu, %zu, ", var.c_str(), bit_offset >> 3, bit_offset & 0x7, bit_width);
990       //     TRY2(asn->rhs_->accept(this));
991       //     emit(");");
992       //   }
993       // }
994     }
995   } else {
996     if (n->is_pointer()) {
997       if (n->id_->name_ == "_result") {
998         // special case for capturing the return value of a previous method call
999         Value *cast_1 = B.CreateBitCast(pop_expr(), ptr_stype);
1000         B.CreateStore(cast_1, ptr_a);
1001       } else {
1002         ConstantPointerNull *const_null = ConstantPointerNull::get(cast<PointerType>(ptr_stype));
1003         B.CreateStore(const_null, ptr_a);
1004       }
1005     } else {
1006       B.CreateMemSet(ptr_a, B.getInt8(0), B.getInt64(decl->bit_width_ >> 3), 1);
1007       if (!n->init_.empty()) {
1008         for (auto it = n->init_.begin(); it != n->init_.end(); ++it)
1009           TRY2((*it)->accept(this));
1010       }
1011     }
1012   }
1013   return StatusTuple(0);
1014 }
1015 
visit_integer_variable_decl_stmt_node(IntegerVariableDeclStmtNode * n)1016 StatusTuple CodegenLLVM::visit_integer_variable_decl_stmt_node(IntegerVariableDeclStmtNode *n) {
1017   if (!B.GetInsertBlock())
1018     return StatusTuple(0);
1019 
1020   // uintX var = init
1021   AllocaInst *ptr_a = make_alloca(resolve_entry_stack(),
1022                                   B.getIntNTy(n->bit_width_), n->id_->name_);
1023   vars_[n] = ptr_a;
1024 
1025   // todo
1026   if (!n->init_.empty())
1027     TRY2(n->init_[0]->accept(this));
1028   return StatusTuple(0);
1029 }
1030 
visit_struct_decl_stmt_node(StructDeclStmtNode * n)1031 StatusTuple CodegenLLVM::visit_struct_decl_stmt_node(StructDeclStmtNode *n) {
1032   ++indent_;
1033   StructType *struct_type = StructType::create(ctx(), "_struct." + n->id_->name_);
1034   vector<Type *> fields;
1035   for (auto it = n->stmts_.begin(); it != n->stmts_.end(); ++it)
1036     fields.push_back(B.getIntNTy((*it)->bit_width_));
1037   struct_type->setBody(fields, n->is_packed());
1038   structs_[n] = struct_type;
1039   return StatusTuple(0);
1040 }
1041 
visit_parser_state_stmt_node(ParserStateStmtNode * n)1042 StatusTuple CodegenLLVM::visit_parser_state_stmt_node(ParserStateStmtNode *n) {
1043   string jump_label = n->scoped_name() + "_continue";
1044   BasicBlock *label_entry = resolve_label(jump_label);
1045   B.SetInsertPoint(label_entry);
1046   if (n->next_state_)
1047     TRY2(n->next_state_->accept(this));
1048   return StatusTuple(0);
1049 }
1050 
visit_state_decl_stmt_node(StateDeclStmtNode * n)1051 StatusTuple CodegenLLVM::visit_state_decl_stmt_node(StateDeclStmtNode *n) {
1052   if (!n->id_)
1053     return StatusTuple(0);
1054   string jump_label = n->scoped_name();
1055   BasicBlock *label_entry = resolve_label(jump_label);
1056   B.SetInsertPoint(label_entry);
1057 
1058   auto it = n->subs_.begin();
1059 
1060   scopes_->push_state(it->scope_);
1061 
1062   for (auto in = n->init_.begin(); in != n->init_.end(); ++in)
1063     TRY2((*in)->accept(this));
1064 
1065   if (n->subs_.size() == 1 && it->id_->name_ == "") {
1066     // this is not a multistate protocol, emit everything and finish
1067     TRY2(it->block_->accept(this));
1068     if (n->parser_) {
1069       B.CreateBr(resolve_label(jump_label + "_continue"));
1070       TRY2(n->parser_->accept(this));
1071     }
1072   } else {
1073     return mkstatus_(n, "unsupported");
1074   }
1075 
1076   scopes_->pop_state();
1077   return StatusTuple(0);
1078 }
1079 
visit_table_decl_stmt_node(TableDeclStmtNode * n)1080 StatusTuple CodegenLLVM::visit_table_decl_stmt_node(TableDeclStmtNode *n) {
1081   if (n->table_type_->name_ == "Table"
1082       || n->table_type_->name_ == "SharedTable") {
1083     if (n->templates_.size() != 4)
1084       return mkstatus_(n, "%s expected 4 arguments, %zu given", n->table_type_->c_str(), n->templates_.size());
1085     auto key = scopes_->top_struct()->lookup(n->key_id()->name_, /*search_local*/true);
1086     if (!key) return mkstatus_(n, "cannot find key %s", n->key_id()->name_.c_str());
1087     auto leaf = scopes_->top_struct()->lookup(n->leaf_id()->name_, /*search_local*/true);
1088     if (!leaf) return mkstatus_(n, "cannot find leaf %s", n->leaf_id()->name_.c_str());
1089 
1090     bpf_map_type map_type = BPF_MAP_TYPE_UNSPEC;
1091     if (n->type_id()->name_ == "FIXED_MATCH")
1092       map_type = BPF_MAP_TYPE_HASH;
1093     else if (n->type_id()->name_ == "INDEXED")
1094       map_type = BPF_MAP_TYPE_ARRAY;
1095     else
1096       return mkstatus_(n, "Table type %s not implemented", n->type_id()->name_.c_str());
1097 
1098     StructType *key_stype, *leaf_stype;
1099     TRY2(lookup_struct_type(n->key_type_, &key_stype));
1100     TRY2(lookup_struct_type(n->leaf_type_, &leaf_stype));
1101     StructType *decl_struct = mod_->getTypeByName("_struct." + n->id_->name_);
1102     if (!decl_struct)
1103       decl_struct = StructType::create(ctx(), "_struct." + n->id_->name_);
1104     if (decl_struct->isOpaque())
1105       decl_struct->setBody(vector<Type *>({key_stype, leaf_stype}), /*isPacked=*/false);
1106     GlobalVariable *decl_gvar = new GlobalVariable(*mod_, decl_struct, false,
1107                                                    GlobalValue::ExternalLinkage, 0, n->id_->name_);
1108     decl_gvar->setSection("maps");
1109     tables_[n] = decl_gvar;
1110 
1111     int map_fd = bpf_create_map(map_type, n->id_->name_.c_str(),
1112                                 key->bit_width_ / 8, leaf->bit_width_ / 8,
1113                                 n->size_, 0);
1114     if (map_fd >= 0)
1115       table_fds_[n] = map_fd;
1116   } else {
1117     return mkstatus_(n, "Table %s not implemented", n->table_type_->name_.c_str());
1118   }
1119   return StatusTuple(0);
1120 }
1121 
lookup_struct_type(StructDeclStmtNode * decl,StructType ** stype) const1122 StatusTuple CodegenLLVM::lookup_struct_type(StructDeclStmtNode *decl, StructType **stype) const {
1123   auto struct_it = structs_.find(decl);
1124   if (struct_it == structs_.end())
1125     return mkstatus_(decl, "could not find IR for type %s", decl->id_->c_str());
1126   *stype = struct_it->second;
1127 
1128   return StatusTuple(0);
1129 }
1130 
lookup_struct_type(VariableDeclStmtNode * n,StructType ** stype,StructDeclStmtNode ** decl) const1131 StatusTuple CodegenLLVM::lookup_struct_type(VariableDeclStmtNode *n, StructType **stype,
1132                                             StructDeclStmtNode **decl) const {
1133   if (!n->is_struct())
1134     return mkstatus_(n, "attempt to search for struct with a non-struct type %s", n->id_->c_str());
1135 
1136   auto var = (StructVariableDeclStmtNode *)n;
1137   StructDeclStmtNode *type;
1138   if (var->struct_id_->scope_name_ == "proto")
1139     type = proto_scopes_->top_struct()->lookup(var->struct_id_->name_, true);
1140   else
1141     type = scopes_->top_struct()->lookup(var->struct_id_->name_, true);
1142 
1143   if (!type) return mkstatus_(n, "could not find type %s", var->struct_id_->c_str());
1144 
1145   TRY2(lookup_struct_type(type, stype));
1146 
1147   if (decl)
1148     *decl = type;
1149 
1150   return StatusTuple(0);
1151 }
1152 
visit_func_decl_stmt_node(FuncDeclStmtNode * n)1153 StatusTuple CodegenLLVM::visit_func_decl_stmt_node(FuncDeclStmtNode *n) {
1154   if (n->formals_.size() != 1)
1155     return mkstatus_(n, "Functions must have exactly 1 argument, %zd given", n->formals_.size());
1156 
1157   vector<Type *> formals;
1158   for (auto it = n->formals_.begin(); it != n->formals_.end(); ++it) {
1159     VariableDeclStmtNode *formal = it->get();
1160     if (formal->is_struct()) {
1161       StructType *stype;
1162       //TRY2(lookup_struct_type(formal, &stype));
1163       auto var = (StructVariableDeclStmtNode *)formal;
1164       stype = mod_->getTypeByName("_struct." + var->struct_id_->name_);
1165       if (!stype) return mkstatus_(n, "could not find type %s", var->struct_id_->c_str());
1166       formals.push_back(PointerType::getUnqual(stype));
1167     } else {
1168       formals.push_back(B.getIntNTy(formal->bit_width_));
1169     }
1170   }
1171   FunctionType *fn_type = FunctionType::get(B.getInt32Ty(), formals, /*isVarArg=*/false);
1172 
1173   Function *fn = mod_->getFunction(n->id_->name_);
1174   if (fn) return mkstatus_(n, "Function %s already defined", n->id_->c_str());
1175   fn = Function::Create(fn_type, GlobalValue::ExternalLinkage, n->id_->name_, mod_);
1176   fn->setCallingConv(CallingConv::C);
1177   fn->addFnAttr(Attribute::NoUnwind);
1178   fn->setSection(BPF_FN_PREFIX + n->id_->name_);
1179 
1180   BasicBlock *label_entry = BasicBlock::Create(ctx(), "entry", fn);
1181   B.SetInsertPoint(label_entry);
1182   string scoped_entry_label = to_string((uintptr_t)fn) + "::entry";
1183   labels_[scoped_entry_label] = label_entry;
1184   BasicBlock *label_return = resolve_label("DONE");
1185   retval_ = make_alloca(label_entry, fn->getReturnType(), "ret");
1186   B.CreateStore(B.getInt32(0), retval_);
1187   errval_ = make_alloca(label_entry, B.getInt64Ty(), "err");
1188   B.CreateStore(B.getInt64(0), errval_);
1189 
1190   auto formal = n->formals_.begin();
1191   for (auto arg = fn->arg_begin(); arg != fn->arg_end(); ++arg, ++formal) {
1192     TRY2((*formal)->accept(this));
1193     Value *ptr = vars_[formal->get()];
1194     if (!ptr) return mkstatus_(n, "cannot locate memory location for arg %s", (*formal)->id_->c_str());
1195     B.CreateStore(&*arg, ptr);
1196 
1197     // Type *ptype;
1198     // if ((*formal)->is_struct()) {
1199     //   StructType *type;
1200     //   TRY2(lookup_struct_type(formal->get(), &type));
1201     //   ptype = PointerType::getUnqual(type);
1202     // } else {
1203     //   ptype = PointerType::getUnqual(B.getIntNTy((*formal)->bit_width_));
1204     // }
1205 
1206     // arg->setName((*formal)->id_->name_);
1207     // AllocaInst *ptr = make_alloca(label_entry, ptype, (*formal)->id_->name_);
1208     // B.CreateStore(arg, ptr);
1209     // vars_[formal->get()] = ptr;
1210   }
1211 
1212   // visit function scoped variables
1213   {
1214     scopes_->push_state(n->scope_);
1215 
1216     for (auto it = scopes_->current_var()->obegin(); it != scopes_->current_var()->oend(); ++it)
1217       TRY2((*it)->accept(this));
1218 
1219     TRY2(n->block_->accept(this));
1220 
1221     scopes_->pop_state();
1222     if (!B.GetInsertBlock()->getTerminator())
1223       B.CreateBr(resolve_label("DONE"));
1224 
1225     // always return something
1226     B.SetInsertPoint(label_return);
1227     B.CreateRet(B.CreateLoad(retval_));
1228   }
1229 
1230   return StatusTuple(0);
1231 }
1232 
visit(Node * root,TableStorage & ts,const string & id,const string & maps_ns)1233 StatusTuple CodegenLLVM::visit(Node *root, TableStorage &ts, const string &id,
1234                                const string &maps_ns) {
1235   scopes_->set_current(scopes_->top_state());
1236   scopes_->set_current(scopes_->top_var());
1237 
1238   TRY2(print_header());
1239 
1240   for (auto it = scopes_->top_table()->obegin(); it != scopes_->top_table()->oend(); ++it)
1241     TRY2((*it)->accept(this));
1242 
1243   for (auto it = scopes_->top_func()->obegin(); it != scopes_->top_func()->oend(); ++it)
1244     TRY2((*it)->accept(this));
1245   //TRY2(print_parser());
1246 
1247   for (auto table : tables_) {
1248     bpf_map_type map_type = BPF_MAP_TYPE_UNSPEC;
1249     if (table.first->type_id()->name_ == "FIXED_MATCH")
1250       map_type = BPF_MAP_TYPE_HASH;
1251     else if (table.first->type_id()->name_ == "INDEXED")
1252       map_type = BPF_MAP_TYPE_ARRAY;
1253     ts.Insert(Path({id, table.first->id_->name_}),
1254               {
1255                   table.first->id_->name_, FileDesc(table_fds_[table.first]), map_type,
1256                   table.first->key_type_->bit_width_ >> 3, table.first->leaf_type_->bit_width_ >> 3,
1257                   table.first->size_, 0,
1258               });
1259   }
1260   return StatusTuple(0);
1261 }
1262 
print_header()1263 StatusTuple CodegenLLVM::print_header() {
1264 
1265   GlobalVariable *gvar_license = new GlobalVariable(*mod_, ArrayType::get(Type::getInt8Ty(ctx()), 4),
1266                                                     false, GlobalValue::ExternalLinkage, 0, "_license");
1267   gvar_license->setSection("license");
1268   gvar_license->setInitializer(ConstantDataArray::getString(ctx(), "GPL", true));
1269 
1270   Function *pseudo_fn = mod_->getFunction("llvm.bpf.pseudo");
1271   if (!pseudo_fn) {
1272     pseudo_fn = Function::Create(
1273         FunctionType::get(B.getInt64Ty(), vector<Type *>({B.getInt64Ty(), B.getInt64Ty()}), false),
1274         GlobalValue::ExternalLinkage, "llvm.bpf.pseudo", mod_);
1275   }
1276 
1277   // declare structures
1278   for (auto it = scopes_->top_struct()->obegin(); it != scopes_->top_struct()->oend(); ++it) {
1279     if ((*it)->id_->name_ == "_Packet")
1280       continue;
1281     TRY2((*it)->accept(this));
1282   }
1283   for (auto it = proto_scopes_->top_struct()->obegin(); it != proto_scopes_->top_struct()->oend(); ++it) {
1284     if ((*it)->id_->name_ == "_Packet")
1285       continue;
1286     TRY2((*it)->accept(this));
1287   }
1288   return StatusTuple(0);
1289 }
1290 
get_table_fd(const string & name) const1291 int CodegenLLVM::get_table_fd(const string &name) const {
1292   TableDeclStmtNode *table = scopes_->top_table()->lookup(name);
1293   if (!table)
1294     return -1;
1295 
1296   auto table_fd_it = table_fds_.find(table);
1297   if (table_fd_it == table_fds_.end())
1298     return -1;
1299 
1300   return table_fd_it->second;
1301 }
1302 
ctx() const1303 LLVMContext & CodegenLLVM::ctx() const {
1304   return mod_->getContext();
1305 }
1306 
const_int(uint64_t val,unsigned bits,bool is_signed)1307 Constant * CodegenLLVM::const_int(uint64_t val, unsigned bits, bool is_signed) {
1308   return ConstantInt::get(ctx(), APInt(bits, val, is_signed));
1309 }
1310 
pop_expr()1311 Value * CodegenLLVM::pop_expr() {
1312   Value *ret = expr_;
1313   expr_ = nullptr;
1314   return ret;
1315 }
1316 
resolve_label(const string & label)1317 BasicBlock * CodegenLLVM::resolve_label(const string &label) {
1318   Function *parent = B.GetInsertBlock()->getParent();
1319   string scoped_label = to_string((uintptr_t)parent) + "::" + label;
1320   auto it = labels_.find(scoped_label);
1321   if (it != labels_.end()) return it->second;
1322   BasicBlock *label_new = BasicBlock::Create(ctx(), label, parent);
1323   labels_[scoped_label] = label_new;
1324   return label_new;
1325 }
1326 
resolve_entry_stack()1327 Instruction * CodegenLLVM::resolve_entry_stack() {
1328   BasicBlock *label_entry = resolve_label("entry");
1329   return &label_entry->back();
1330 }
1331 
make_alloca(Instruction * Inst,Type * Ty,const string & name,Value * ArraySize)1332 AllocaInst *CodegenLLVM::make_alloca(Instruction *Inst, Type *Ty,
1333                                      const string &name, Value *ArraySize) {
1334   IRBuilderBase::InsertPoint ip = B.saveIP();
1335   B.SetInsertPoint(Inst);
1336   AllocaInst *a = B.CreateAlloca(Ty, ArraySize, name);
1337   B.restoreIP(ip);
1338   return a;
1339 }
1340 
make_alloca(BasicBlock * BB,Type * Ty,const string & name,Value * ArraySize)1341 AllocaInst *CodegenLLVM::make_alloca(BasicBlock *BB, Type *Ty,
1342                                      const string &name, Value *ArraySize) {
1343   IRBuilderBase::InsertPoint ip = B.saveIP();
1344   B.SetInsertPoint(BB);
1345   AllocaInst *a = B.CreateAlloca(Ty, ArraySize, name);
1346   B.restoreIP(ip);
1347   return a;
1348 }
1349 
1350 }  // namespace cc
1351 }  // namespace ebpf
1352