1 // Copyright (c) 2017 The Khronos Group Inc.
2 // Copyright (c) 2017 Valve Corporation
3 // Copyright (c) 2017 LunarG Inc.
4 //
5 // Licensed under the Apache License, Version 2.0 (the "License");
6 // you may not use this file except in compliance with the License.
7 // You may obtain a copy of the License at
8 //
9 //     http://www.apache.org/licenses/LICENSE-2.0
10 //
11 // Unless required by applicable law or agreed to in writing, software
12 // distributed under the License is distributed on an "AS IS" BASIS,
13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 // See the License for the specific language governing permissions and
15 // limitations under the License.
16 
17 #include "source/opt/inline_pass.h"
18 
19 #include <unordered_set>
20 #include <utility>
21 
22 #include "source/cfa.h"
23 #include "source/util/make_unique.h"
24 
// Operand positions within SPIR-V instructions used by this pass.
namespace {

// OpFunctionCall: callee function id (read with GetSingleWordOperand()).
constexpr int kSpvFunctionCallFunctionId = 2;
// OpFunctionCall: first call argument (read with GetSingleWordOperand()).
constexpr int kSpvFunctionCallArgumentId = 3;
// OpReturnValue: returned value id (read with GetInOperand()).
constexpr int kSpvReturnValueId = 0;
// OpLoopMerge: continue target id (read with GetSingleWordInOperand()).
constexpr int kSpvLoopMergeContinueTargetIdInIdx = 1;

}  // namespace
31 
32 namespace spvtools {
33 namespace opt {
34 
AddPointerToType(uint32_t type_id,SpvStorageClass storage_class)35 uint32_t InlinePass::AddPointerToType(uint32_t type_id,
36                                       SpvStorageClass storage_class) {
37   uint32_t resultId = TakeNextId();
38   std::unique_ptr<Instruction> type_inst(
39       new Instruction(context(), SpvOpTypePointer, 0, resultId,
40                       {{spv_operand_type_t::SPV_OPERAND_TYPE_STORAGE_CLASS,
41                         {uint32_t(storage_class)}},
42                        {spv_operand_type_t::SPV_OPERAND_TYPE_ID, {type_id}}}));
43   context()->AddType(std::move(type_inst));
44   analysis::Type* pointeeTy;
45   std::unique_ptr<analysis::Pointer> pointerTy;
46   std::tie(pointeeTy, pointerTy) =
47       context()->get_type_mgr()->GetTypeAndPointerType(type_id,
48                                                        SpvStorageClassFunction);
49   context()->get_type_mgr()->RegisterType(resultId, *pointerTy);
50   return resultId;
51 }
52 
AddBranch(uint32_t label_id,std::unique_ptr<BasicBlock> * block_ptr)53 void InlinePass::AddBranch(uint32_t label_id,
54                            std::unique_ptr<BasicBlock>* block_ptr) {
55   std::unique_ptr<Instruction> newBranch(
56       new Instruction(context(), SpvOpBranch, 0, 0,
57                       {{spv_operand_type_t::SPV_OPERAND_TYPE_ID, {label_id}}}));
58   (*block_ptr)->AddInstruction(std::move(newBranch));
59 }
60 
AddBranchCond(uint32_t cond_id,uint32_t true_id,uint32_t false_id,std::unique_ptr<BasicBlock> * block_ptr)61 void InlinePass::AddBranchCond(uint32_t cond_id, uint32_t true_id,
62                                uint32_t false_id,
63                                std::unique_ptr<BasicBlock>* block_ptr) {
64   std::unique_ptr<Instruction> newBranch(
65       new Instruction(context(), SpvOpBranchConditional, 0, 0,
66                       {{spv_operand_type_t::SPV_OPERAND_TYPE_ID, {cond_id}},
67                        {spv_operand_type_t::SPV_OPERAND_TYPE_ID, {true_id}},
68                        {spv_operand_type_t::SPV_OPERAND_TYPE_ID, {false_id}}}));
69   (*block_ptr)->AddInstruction(std::move(newBranch));
70 }
71 
AddLoopMerge(uint32_t merge_id,uint32_t continue_id,std::unique_ptr<BasicBlock> * block_ptr)72 void InlinePass::AddLoopMerge(uint32_t merge_id, uint32_t continue_id,
73                               std::unique_ptr<BasicBlock>* block_ptr) {
74   std::unique_ptr<Instruction> newLoopMerge(new Instruction(
75       context(), SpvOpLoopMerge, 0, 0,
76       {{spv_operand_type_t::SPV_OPERAND_TYPE_ID, {merge_id}},
77        {spv_operand_type_t::SPV_OPERAND_TYPE_ID, {continue_id}},
78        {spv_operand_type_t::SPV_OPERAND_TYPE_LOOP_CONTROL, {0}}}));
79   (*block_ptr)->AddInstruction(std::move(newLoopMerge));
80 }
81 
AddStore(uint32_t ptr_id,uint32_t val_id,std::unique_ptr<BasicBlock> * block_ptr)82 void InlinePass::AddStore(uint32_t ptr_id, uint32_t val_id,
83                           std::unique_ptr<BasicBlock>* block_ptr) {
84   std::unique_ptr<Instruction> newStore(
85       new Instruction(context(), SpvOpStore, 0, 0,
86                       {{spv_operand_type_t::SPV_OPERAND_TYPE_ID, {ptr_id}},
87                        {spv_operand_type_t::SPV_OPERAND_TYPE_ID, {val_id}}}));
88   (*block_ptr)->AddInstruction(std::move(newStore));
89 }
90 
AddLoad(uint32_t type_id,uint32_t resultId,uint32_t ptr_id,std::unique_ptr<BasicBlock> * block_ptr)91 void InlinePass::AddLoad(uint32_t type_id, uint32_t resultId, uint32_t ptr_id,
92                          std::unique_ptr<BasicBlock>* block_ptr) {
93   std::unique_ptr<Instruction> newLoad(
94       new Instruction(context(), SpvOpLoad, type_id, resultId,
95                       {{spv_operand_type_t::SPV_OPERAND_TYPE_ID, {ptr_id}}}));
96   (*block_ptr)->AddInstruction(std::move(newLoad));
97 }
98 
NewLabel(uint32_t label_id)99 std::unique_ptr<Instruction> InlinePass::NewLabel(uint32_t label_id) {
100   std::unique_ptr<Instruction> newLabel(
101       new Instruction(context(), SpvOpLabel, 0, label_id, {}));
102   return newLabel;
103 }
104 
GetFalseId()105 uint32_t InlinePass::GetFalseId() {
106   if (false_id_ != 0) return false_id_;
107   false_id_ = get_module()->GetGlobalValue(SpvOpConstantFalse);
108   if (false_id_ != 0) return false_id_;
109   uint32_t boolId = get_module()->GetGlobalValue(SpvOpTypeBool);
110   if (boolId == 0) {
111     boolId = TakeNextId();
112     get_module()->AddGlobalValue(SpvOpTypeBool, boolId, 0);
113   }
114   false_id_ = TakeNextId();
115   get_module()->AddGlobalValue(SpvOpConstantFalse, false_id_, boolId);
116   return false_id_;
117 }
118 
MapParams(Function * calleeFn,BasicBlock::iterator call_inst_itr,std::unordered_map<uint32_t,uint32_t> * callee2caller)119 void InlinePass::MapParams(
120     Function* calleeFn, BasicBlock::iterator call_inst_itr,
121     std::unordered_map<uint32_t, uint32_t>* callee2caller) {
122   int param_idx = 0;
123   calleeFn->ForEachParam([&call_inst_itr, &param_idx,
124                           &callee2caller](const Instruction* cpi) {
125     const uint32_t pid = cpi->result_id();
126     (*callee2caller)[pid] = call_inst_itr->GetSingleWordOperand(
127         kSpvFunctionCallArgumentId + param_idx);
128     ++param_idx;
129   });
130 }
131 
CloneAndMapLocals(Function * calleeFn,std::vector<std::unique_ptr<Instruction>> * new_vars,std::unordered_map<uint32_t,uint32_t> * callee2caller)132 void InlinePass::CloneAndMapLocals(
133     Function* calleeFn, std::vector<std::unique_ptr<Instruction>>* new_vars,
134     std::unordered_map<uint32_t, uint32_t>* callee2caller) {
135   auto callee_block_itr = calleeFn->begin();
136   auto callee_var_itr = callee_block_itr->begin();
137   while (callee_var_itr->opcode() == SpvOp::SpvOpVariable) {
138     std::unique_ptr<Instruction> var_inst(callee_var_itr->Clone(context()));
139     uint32_t newId = TakeNextId();
140     get_decoration_mgr()->CloneDecorations(callee_var_itr->result_id(), newId);
141     var_inst->SetResultId(newId);
142     (*callee2caller)[callee_var_itr->result_id()] = newId;
143     new_vars->push_back(std::move(var_inst));
144     ++callee_var_itr;
145   }
146 }
147 
CreateReturnVar(Function * calleeFn,std::vector<std::unique_ptr<Instruction>> * new_vars)148 uint32_t InlinePass::CreateReturnVar(
149     Function* calleeFn, std::vector<std::unique_ptr<Instruction>>* new_vars) {
150   uint32_t returnVarId = 0;
151   const uint32_t calleeTypeId = calleeFn->type_id();
152   analysis::Type* calleeType = context()->get_type_mgr()->GetType(calleeTypeId);
153   if (calleeType->AsVoid() == nullptr) {
154     // Find or create ptr to callee return type.
155     uint32_t returnVarTypeId = context()->get_type_mgr()->FindPointerToType(
156         calleeTypeId, SpvStorageClassFunction);
157     if (returnVarTypeId == 0)
158       returnVarTypeId = AddPointerToType(calleeTypeId, SpvStorageClassFunction);
159     // Add return var to new function scope variables.
160     returnVarId = TakeNextId();
161     std::unique_ptr<Instruction> var_inst(
162         new Instruction(context(), SpvOpVariable, returnVarTypeId, returnVarId,
163                         {{spv_operand_type_t::SPV_OPERAND_TYPE_STORAGE_CLASS,
164                           {SpvStorageClassFunction}}}));
165     new_vars->push_back(std::move(var_inst));
166   }
167   get_decoration_mgr()->CloneDecorations(calleeFn->result_id(), returnVarId);
168   return returnVarId;
169 }
170 
IsSameBlockOp(const Instruction * inst) const171 bool InlinePass::IsSameBlockOp(const Instruction* inst) const {
172   return inst->opcode() == SpvOpSampledImage || inst->opcode() == SpvOpImage;
173 }
174 
CloneSameBlockOps(std::unique_ptr<Instruction> * inst,std::unordered_map<uint32_t,uint32_t> * postCallSB,std::unordered_map<uint32_t,Instruction * > * preCallSB,std::unique_ptr<BasicBlock> * block_ptr)175 void InlinePass::CloneSameBlockOps(
176     std::unique_ptr<Instruction>* inst,
177     std::unordered_map<uint32_t, uint32_t>* postCallSB,
178     std::unordered_map<uint32_t, Instruction*>* preCallSB,
179     std::unique_ptr<BasicBlock>* block_ptr) {
180   (*inst)->ForEachInId(
181       [&postCallSB, &preCallSB, &block_ptr, this](uint32_t* iid) {
182         const auto mapItr = (*postCallSB).find(*iid);
183         if (mapItr == (*postCallSB).end()) {
184           const auto mapItr2 = (*preCallSB).find(*iid);
185           if (mapItr2 != (*preCallSB).end()) {
186             // Clone pre-call same-block ops, map result id.
187             const Instruction* inInst = mapItr2->second;
188             std::unique_ptr<Instruction> sb_inst(inInst->Clone(context()));
189             CloneSameBlockOps(&sb_inst, postCallSB, preCallSB, block_ptr);
190             const uint32_t rid = sb_inst->result_id();
191             const uint32_t nid = this->TakeNextId();
192             get_decoration_mgr()->CloneDecorations(rid, nid);
193             sb_inst->SetResultId(nid);
194             (*postCallSB)[rid] = nid;
195             *iid = nid;
196             (*block_ptr)->AddInstruction(std::move(sb_inst));
197           }
198         } else {
199           // Reset same-block op operand.
200           *iid = mapItr->second;
201         }
202       });
203 }
204 
// Generates the code that replaces the block containing the function call at
// |call_inst_itr| (the block at |call_block_itr|) with an inlined copy of the
// callee.
//
// On return, |new_blocks| holds the replacement blocks in order and |new_vars|
// holds the function-scope variables they need (clones of the callee's local
// variables plus, for a non-void callee, the return-value variable).  The
// instructions of the calling block before and after the call are moved (not
// copied) out of that block into the first and last new blocks; the call
// instruction itself is left behind in the original block.
//
// Side effects: invalidates the def-use analysis and records every new block
// in |id2block_|.
void InlinePass::GenInlineCode(
    std::vector<std::unique_ptr<BasicBlock>>* new_blocks,
    std::vector<std::unique_ptr<Instruction>>* new_vars,
    BasicBlock::iterator call_inst_itr,
    UptrVectorIterator<BasicBlock> call_block_itr) {
  // Map from all ids in the callee to their equivalent id in the caller
  // as callee instructions are copied into caller.
  std::unordered_map<uint32_t, uint32_t> callee2caller;
  // Pre-call same-block insts
  std::unordered_map<uint32_t, Instruction*> preCallSB;
  // Post-call same-block op ids
  std::unordered_map<uint32_t, uint32_t> postCallSB;

  // Invalidate the def-use chains.  They are not kept up to date while
  // inlining.  However, certain calls try to keep them up-to-date if they are
  // valid.  These operations can fail.
  context()->InvalidateAnalyses(IRContext::kAnalysisDefUse);

  Function* calleeFn = id2function_[call_inst_itr->GetSingleWordOperand(
      kSpvFunctionCallFunctionId)];

  // Check for multiple returns in the callee.
  auto fi = early_return_funcs_.find(calleeFn->result_id());
  const bool earlyReturn = fi != early_return_funcs_.end();

  // Map parameters to actual arguments.
  MapParams(calleeFn, call_inst_itr, &callee2caller);

  // Define caller local variables for all callee variables and create map to
  // them.
  CloneAndMapLocals(calleeFn, new_vars, &callee2caller);

  // Create return var if needed.
  uint32_t returnVarId = CreateReturnVar(calleeFn, new_vars);

  // Create set of callee result ids. Used to detect forward references
  std::unordered_set<uint32_t> callee_result_ids;
  calleeFn->ForEachInst([&callee_result_ids](const Instruction* cpi) {
    const uint32_t rid = cpi->result_id();
    if (rid != 0) callee_result_ids.insert(rid);
  });

  // If the caller is in a single-block loop, and the callee has multiple
  // blocks, then the normal inlining logic will place the OpLoopMerge in
  // the last of several blocks in the loop.  Instead, it should be placed
  // at the end of the first block.  First determine if the caller is in a
  // single block loop.  We'll wait to move the OpLoopMerge until the end
  // of the regular inlining logic, and only if necessary.
  bool caller_is_single_block_loop = false;
  bool caller_is_loop_header = false;
  if (auto* loop_merge = call_block_itr->GetLoopMergeInst()) {
    caller_is_loop_header = true;
    // A single-block loop is one whose continue target is the block itself.
    caller_is_single_block_loop =
        call_block_itr->id() ==
        loop_merge->GetSingleWordInOperand(kSpvLoopMergeContinueTargetIdInIdx);
  }

  bool callee_begins_with_structured_header =
      (*(calleeFn->begin())).GetMergeInst() != nullptr;

  // Clone and map callee code. Copy caller block code to beginning of
  // first block and end of last block.
  bool prevInstWasReturn = false;
  uint32_t singleTripLoopHeaderId = 0;
  uint32_t singleTripLoopContinueId = 0;
  uint32_t returnLabelId = 0;
  bool multiBlocks = false;
  const uint32_t calleeTypeId = calleeFn->type_id();
  // new_blk_ptr is a new basic block in the caller.  New instructions are
  // written to it.  It is created when we encounter the OpLabel
  // of the first callee block.  It is appended to new_blocks only when
  // it is complete.
  std::unique_ptr<BasicBlock> new_blk_ptr;
  // Visit every callee instruction in order and dispatch on its opcode.
  calleeFn->ForEachInst([&new_blocks, &callee2caller, &call_block_itr,
                         &call_inst_itr, &new_blk_ptr, &prevInstWasReturn,
                         &returnLabelId, &returnVarId, caller_is_loop_header,
                         callee_begins_with_structured_header, &calleeTypeId,
                         &multiBlocks, &postCallSB, &preCallSB, earlyReturn,
                         &singleTripLoopHeaderId, &singleTripLoopContinueId,
                         &callee_result_ids, this](const Instruction* cpi) {
    switch (cpi->opcode()) {
      case SpvOpFunction:
      case SpvOpFunctionParameter:
        // Already processed
        break;
      case SpvOpVariable:
        // The variable itself was already cloned by CloneAndMapLocals().  If
        // it has an initializer (second in-operand), emit an explicit store
        // of that initializer to the mapped variable.
        if (cpi->NumInOperands() == 2) {
          assert(callee2caller.count(cpi->result_id()) &&
                 "Expected the variable to have already been mapped.");
          uint32_t new_var_id = callee2caller.at(cpi->result_id());

          // The initializer must be a constant or global value.  No mapped
          // id should be used.
          uint32_t val_id = cpi->GetSingleWordInOperand(1);
          AddStore(new_var_id, val_id, &new_blk_ptr);
        }
        break;
      case SpvOpUnreachable:
      case SpvOpKill: {
        // Generate a return label so that we split the block with the function
        // call. Copy the terminator into the new block.
        if (returnLabelId == 0) returnLabelId = this->TakeNextId();
        std::unique_ptr<Instruction> terminator(
            new Instruction(context(), cpi->opcode(), 0, 0, {}));
        new_blk_ptr->AddInstruction(std::move(terminator));
        break;
      }
      case SpvOpLabel: {
        // If previous instruction was early return, insert branch
        // instruction to return block.
        if (prevInstWasReturn) {
          if (returnLabelId == 0) returnLabelId = this->TakeNextId();
          AddBranch(returnLabelId, &new_blk_ptr);
          prevInstWasReturn = false;
        }
        // Finish current block (if it exists) and get label for next block.
        uint32_t labelId;
        bool firstBlock = false;
        if (new_blk_ptr != nullptr) {
          new_blocks->push_back(std::move(new_blk_ptr));
          // If result id is already mapped, use it, otherwise get a new
          // one.
          const uint32_t rid = cpi->result_id();
          const auto mapItr = callee2caller.find(rid);
          labelId = (mapItr != callee2caller.end()) ? mapItr->second
                                                    : this->TakeNextId();
        } else {
          // First block needs to use label of original block
          // but map callee label in case of phi reference.
          labelId = call_block_itr->id();
          callee2caller[cpi->result_id()] = labelId;
          firstBlock = true;
        }
        // Create first/next block.
        new_blk_ptr = MakeUnique<BasicBlock>(NewLabel(labelId));
        if (firstBlock) {
          // Copy contents of original caller block up to call instruction.
          // Each iteration detaches the current first instruction of the
          // caller block, so the iterator is re-read from begin() each time.
          for (auto cii = call_block_itr->begin(); cii != call_inst_itr;
               cii = call_block_itr->begin()) {
            Instruction* inst = &*cii;
            inst->RemoveFromList();
            std::unique_ptr<Instruction> cp_inst(inst);
            // Remember same-block ops for possible regeneration.
            if (IsSameBlockOp(&*cp_inst)) {
              auto* sb_inst_ptr = cp_inst.get();
              preCallSB[cp_inst->result_id()] = sb_inst_ptr;
            }
            new_blk_ptr->AddInstruction(std::move(cp_inst));
          }
          if (caller_is_loop_header && callee_begins_with_structured_header) {
            // We can't place both the caller's merge instruction and another
            // merge instruction in the same block.  So split the calling block.
            // Insert an unconditional branch to a new guard block.  Later,
            // once we know the ID of the last block,  we will move the caller's
            // OpLoopMerge from the last generated block into the first block.
            // We also wait to avoid invalidating various iterators.
            const auto guard_block_id = this->TakeNextId();
            AddBranch(guard_block_id, &new_blk_ptr);
            new_blocks->push_back(std::move(new_blk_ptr));
            // Start the next block.
            new_blk_ptr = MakeUnique<BasicBlock>(NewLabel(guard_block_id));
            // Reset the mapping of the callee's entry block to point to
            // the guard block.  Do this so we can fix up phis later on to
            // satisfy dominance.
            callee2caller[cpi->result_id()] = guard_block_id;
          }
          // If callee has early return, insert a header block for
          // single-trip loop that will encompass callee code.  Start postheader
          // block.
          //
          // Note: Consider the following combination:
          //  - the caller is a single block loop
          //  - the callee does not begin with a structure header
          //  - the callee has multiple returns.
          // We still need to split the caller block and insert a guard block.
          // But we only need to do it once. We haven't done it yet, but the
          // single-trip loop header will serve the same purpose.
          if (earlyReturn) {
            singleTripLoopHeaderId = this->TakeNextId();
            AddBranch(singleTripLoopHeaderId, &new_blk_ptr);
            new_blocks->push_back(std::move(new_blk_ptr));
            new_blk_ptr =
                MakeUnique<BasicBlock>(NewLabel(singleTripLoopHeaderId));
            returnLabelId = this->TakeNextId();
            singleTripLoopContinueId = this->TakeNextId();
            AddLoopMerge(returnLabelId, singleTripLoopContinueId, &new_blk_ptr);
            uint32_t postHeaderId = this->TakeNextId();
            AddBranch(postHeaderId, &new_blk_ptr);
            new_blocks->push_back(std::move(new_blk_ptr));
            new_blk_ptr = MakeUnique<BasicBlock>(NewLabel(postHeaderId));
            multiBlocks = true;
            // Reset the mapping of the callee's entry block to point to
            // the post-header block.  Do this so we can fix up phis later
            // on to satisfy dominance.
            callee2caller[cpi->result_id()] = postHeaderId;
          }
        } else {
          multiBlocks = true;
        }
      } break;
      case SpvOpReturnValue: {
        // Store return value to return variable.
        assert(returnVarId != 0);
        uint32_t valId = cpi->GetInOperand(kSpvReturnValueId).words[0];
        // Use the caller-side id if the returned value was defined in the
        // callee; otherwise the id is used unchanged.
        const auto mapItr = callee2caller.find(valId);
        if (mapItr != callee2caller.end()) {
          valId = mapItr->second;
        }
        AddStore(returnVarId, valId, &new_blk_ptr);

        // Remember we saw a return; if followed by a label, will need to
        // insert branch.
        prevInstWasReturn = true;
      } break;
      case SpvOpReturn: {
        // Remember we saw a return; if followed by a label, will need to
        // insert branch.
        prevInstWasReturn = true;
      } break;
      case SpvOpFunctionEnd: {
        // If there was an early return, we generated a return label id
        // for it.  Now we have to generate the return block with that Id.
        if (returnLabelId != 0) {
          // If previous instruction was return, insert branch instruction
          // to return block.
          if (prevInstWasReturn) AddBranch(returnLabelId, &new_blk_ptr);
          if (earlyReturn) {
            // If we generated a loop header for the single-trip loop
            // to accommodate early returns, insert the continue
            // target block now, with a false branch back to the loop header.
            new_blocks->push_back(std::move(new_blk_ptr));
            new_blk_ptr =
                MakeUnique<BasicBlock>(NewLabel(singleTripLoopContinueId));
            AddBranchCond(GetFalseId(), singleTripLoopHeaderId, returnLabelId,
                          &new_blk_ptr);
          }
          // Generate the return block.
          new_blocks->push_back(std::move(new_blk_ptr));
          new_blk_ptr = MakeUnique<BasicBlock>(NewLabel(returnLabelId));
          multiBlocks = true;
        }
        // Load return value into result id of call, if it exists.
        if (returnVarId != 0) {
          const uint32_t resId = call_inst_itr->result_id();
          assert(resId != 0);
          AddLoad(calleeTypeId, resId, returnVarId, &new_blk_ptr);
        }
        // Copy remaining instructions from caller block.
        // Each iteration detaches the instruction that follows the call, so
        // the next one is always re-read from the call instruction.
        for (Instruction* inst = call_inst_itr->NextNode(); inst;
             inst = call_inst_itr->NextNode()) {
          inst->RemoveFromList();
          std::unique_ptr<Instruction> cp_inst(inst);
          // If multiple blocks generated, regenerate any same-block
          // instruction that has not been seen in this last block.
          if (multiBlocks) {
            CloneSameBlockOps(&cp_inst, &postCallSB, &preCallSB, &new_blk_ptr);
            // Remember same-block ops in this block.
            if (IsSameBlockOp(&*cp_inst)) {
              const uint32_t rid = cp_inst->result_id();
              postCallSB[rid] = rid;
            }
          }
          new_blk_ptr->AddInstruction(std::move(cp_inst));
        }
        // Finalize inline code.
        new_blocks->push_back(std::move(new_blk_ptr));
      } break;
      default: {
        // Copy callee instruction and remap all input Ids.
        std::unique_ptr<Instruction> cp_inst(cpi->Clone(context()));
        cp_inst->ForEachInId([&callee2caller, &callee_result_ids,
                              this](uint32_t* iid) {
          const auto mapItr = callee2caller.find(*iid);
          if (mapItr != callee2caller.end()) {
            *iid = mapItr->second;
          } else if (callee_result_ids.find(*iid) != callee_result_ids.end()) {
            // Forward reference. Allocate a new id, map it,
            // use it and check for it when remapping result ids
            const uint32_t nid = this->TakeNextId();
            callee2caller[*iid] = nid;
            *iid = nid;
          }
          // Ids that are neither mapped nor callee results are left as is.
        });
        // If result id is non-zero, remap it. If already mapped, use mapped
        // value, else use next id.
        const uint32_t rid = cp_inst->result_id();
        if (rid != 0) {
          const auto mapItr = callee2caller.find(rid);
          uint32_t nid;
          if (mapItr != callee2caller.end()) {
            nid = mapItr->second;
          } else {
            nid = this->TakeNextId();
            callee2caller[rid] = nid;
          }
          cp_inst->SetResultId(nid);
          get_decoration_mgr()->CloneDecorations(rid, nid);
        }
        new_blk_ptr->AddInstruction(std::move(cp_inst));
      } break;
    }
  });

  if (caller_is_loop_header && (new_blocks->size() > 1)) {
    // Move the OpLoopMerge from the last block back to the first, where
    // it belongs.
    auto& first = new_blocks->front();
    auto& last = new_blocks->back();
    assert(first != last);

    // Insert a modified copy of the loop merge into the first block.
    // The loop merge is expected to be the instruction just before the last
    // block's final instruction.
    auto loop_merge_itr = last->tail();
    --loop_merge_itr;
    assert(loop_merge_itr->opcode() == SpvOpLoopMerge);
    std::unique_ptr<Instruction> cp_inst(loop_merge_itr->Clone(context()));
    if (caller_is_single_block_loop) {
      // Also, update its continue target to point to the last block.
      cp_inst->SetInOperand(kSpvLoopMergeContinueTargetIdInIdx, {last->id()});
    }
    first->tail().InsertBefore(std::move(cp_inst));

    // Remove the loop merge from the last block.
    loop_merge_itr->RemoveFromList();
    delete &*loop_merge_itr;
  }

  // Update block map given replacement blocks.
  for (auto& blk : *new_blocks) {
    id2block_[blk->id()] = &*blk;
  }
}
536 
IsInlinableFunctionCall(const Instruction * inst)537 bool InlinePass::IsInlinableFunctionCall(const Instruction* inst) {
538   if (inst->opcode() != SpvOp::SpvOpFunctionCall) return false;
539   const uint32_t calleeFnId =
540       inst->GetSingleWordOperand(kSpvFunctionCallFunctionId);
541   const auto ci = inlinable_.find(calleeFnId);
542   return ci != inlinable_.cend();
543 }
544 
UpdateSucceedingPhis(std::vector<std::unique_ptr<BasicBlock>> & new_blocks)545 void InlinePass::UpdateSucceedingPhis(
546     std::vector<std::unique_ptr<BasicBlock>>& new_blocks) {
547   const auto firstBlk = new_blocks.begin();
548   const auto lastBlk = new_blocks.end() - 1;
549   const uint32_t firstId = (*firstBlk)->id();
550   const uint32_t lastId = (*lastBlk)->id();
551   const BasicBlock& const_last_block = *lastBlk->get();
552   const_last_block.ForEachSuccessorLabel(
553       [&firstId, &lastId, this](const uint32_t succ) {
554         BasicBlock* sbp = this->id2block_[succ];
555         sbp->ForEachPhiInst([&firstId, &lastId](Instruction* phi) {
556           phi->ForEachInId([&firstId, &lastId](uint32_t* id) {
557             if (*id == firstId) *id = lastId;
558           });
559         });
560       });
561 }
562 
HasNoReturnInStructuredConstruct(Function * func)563 bool InlinePass::HasNoReturnInStructuredConstruct(Function* func) {
564   // If control not structured, do not do loop/return analysis
565   // TODO: Analyze returns in non-structured control flow
566   if (!context()->get_feature_mgr()->HasCapability(SpvCapabilityShader))
567     return false;
568   const auto structured_analysis = context()->GetStructuredCFGAnalysis();
569   // Search for returns in structured construct.
570   bool return_in_construct = false;
571   for (auto& blk : *func) {
572     auto terminal_ii = blk.cend();
573     --terminal_ii;
574     if (spvOpcodeIsReturn(terminal_ii->opcode()) &&
575         structured_analysis->ContainingConstruct(blk.id()) != 0) {
576       return_in_construct = true;
577       break;
578     }
579   }
580   return !return_in_construct;
581 }
582 
HasNoReturnInLoop(Function * func)583 bool InlinePass::HasNoReturnInLoop(Function* func) {
584   // If control not structured, do not do loop/return analysis
585   // TODO: Analyze returns in non-structured control flow
586   if (!context()->get_feature_mgr()->HasCapability(SpvCapabilityShader))
587     return false;
588   const auto structured_analysis = context()->GetStructuredCFGAnalysis();
589   // Search for returns in structured construct.
590   bool return_in_loop = false;
591   for (auto& blk : *func) {
592     auto terminal_ii = blk.cend();
593     --terminal_ii;
594     if (spvOpcodeIsReturn(terminal_ii->opcode()) &&
595         structured_analysis->ContainingLoop(blk.id()) != 0) {
596       return_in_loop = true;
597       break;
598     }
599   }
600   return !return_in_loop;
601 }
602 
AnalyzeReturns(Function * func)603 void InlinePass::AnalyzeReturns(Function* func) {
604   if (HasNoReturnInLoop(func)) {
605     no_return_in_loop_.insert(func->result_id());
606     if (!HasNoReturnInStructuredConstruct(func))
607       early_return_funcs_.insert(func->result_id());
608   }
609 }
610 
IsInlinableFunction(Function * func)611 bool InlinePass::IsInlinableFunction(Function* func) {
612   // We can only inline a function if it has blocks.
613   if (func->cbegin() == func->cend()) return false;
614   // Do not inline functions with returns in loops. Currently early return
615   // functions are inlined by wrapping them in a one trip loop and implementing
616   // the returns as a branch to the loop's merge block. However, this can only
617   // done validly if the return was not in a loop in the original function.
618   // Also remember functions with multiple (early) returns.
619   AnalyzeReturns(func);
620   if (no_return_in_loop_.find(func->result_id()) == no_return_in_loop_.cend()) {
621     return false;
622   }
623 
624   if (func->IsRecursive()) {
625     return false;
626   }
627 
628   return true;
629 }
630 
InitializeInline()631 void InlinePass::InitializeInline() {
632   false_id_ = 0;
633 
634   // clear collections
635   id2function_.clear();
636   id2block_.clear();
637   inlinable_.clear();
638   no_return_in_loop_.clear();
639   early_return_funcs_.clear();
640 
641   for (auto& fn : *get_module()) {
642     // Initialize function and block maps.
643     id2function_[fn.result_id()] = &fn;
644     for (auto& blk : fn) {
645       id2block_[blk.id()] = &blk;
646     }
647     // Compute inlinability
648     if (IsInlinableFunction(&fn)) inlinable_.insert(fn.result_id());
649   }
650 }
651 
InlinePass()652 InlinePass::InlinePass() {}
653 
654 }  // namespace opt
655 }  // namespace spvtools
656