// Copyright (c) 2018 Google LLC.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "source/opt/loop_unswitch_pass.h"

#include <functional>
#include <list>
#include <memory>
#include <type_traits>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

#include "source/opt/basic_block.h"
#include "source/opt/dominator_tree.h"
#include "source/opt/fold.h"
#include "source/opt/function.h"
#include "source/opt/instruction.h"
#include "source/opt/ir_builder.h"
#include "source/opt/ir_context.h"
#include "source/opt/loop_descriptor.h"

#include "source/opt/loop_utils.h"

namespace spvtools {
namespace opt {
namespace {

static const uint32_t kTypePointerStorageClassInIdx = 0;
static const uint32_t kBranchCondTrueLabIdInIdx = 1;
static const uint32_t kBranchCondFalseLabIdInIdx = 2;

}  // anonymous namespace

namespace {

// This class handles the unswitch procedure for a given loop.
// The unswitch will not happen if:
//  - The loop has any instruction that will prevent it;
//  - The loop invariant condition is not uniform.
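//
// As a rough sketch (pseudocode, not actual SPIR-V), unswitching rewrites
//   loop { if (cond) { A } else { B } }     // cond is loop invariant
// into
//   if (cond) { loop { A } } else { loop { B } }
// by cloning the loop once per possible value of the condition and
// specializing each copy.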
class LoopUnswitch {
 public:
  LoopUnswitch(IRContext* context, Function* function, Loop* loop,
               LoopDescriptor* loop_desc)
      : function_(function),
        loop_(loop),
        loop_desc_(*loop_desc),
        context_(context),
        switch_block_(nullptr) {}

  // Returns true if the loop can be unswitched.
  // It can be unswitched if:
  //  - The loop has no instructions that prevent it (such as barriers);
  //  - The loop has one conditional branch or switch that does not depend on
  //    the loop;
  //  - The loop invariant condition is uniform.
  bool CanUnswitchLoop() {
    if (switch_block_) return true;
    if (!loop_->IsSafeToClone()) return false;

    CFG& cfg = *context_->cfg();

    for (uint32_t bb_id : loop_->GetBlocks()) {
      BasicBlock* bb = cfg.block(bb_id);
      if (bb->terminator()->IsBranch() &&
          bb->terminator()->opcode() != SpvOpBranch) {
        if (IsConditionLoopInvariant(bb->terminator())) {
          switch_block_ = bb;
          break;
        }
      }
    }

    return switch_block_;
  }

  // Returns the iterator to the basic block |bb_to_find|.
  Function::iterator FindBasicBlockPosition(BasicBlock* bb_to_find) {
    Function::iterator it = function_->FindBlock(bb_to_find->id());
    assert(it != function_->end() && "Basic Block not found");
    return it;
  }

  // Creates a new basic block and inserts it into the function at the
  // position |ip|. This function preserves the def/use and instruction to
  // block managers.
  BasicBlock* CreateBasicBlock(Function::iterator ip) {
    analysis::DefUseManager* def_use_mgr = context_->get_def_use_mgr();

    // TODO(1841): Handle id overflow.
    BasicBlock* bb = &*ip.InsertBefore(std::unique_ptr<BasicBlock>(
        new BasicBlock(std::unique_ptr<Instruction>(new Instruction(
            context_, SpvOpLabel, 0, context_->TakeNextId(), {})))));
    bb->SetParent(function_);
    def_use_mgr->AnalyzeInstDef(bb->GetLabelInst());
    context_->set_instr_block(bb->GetLabelInst(), bb);

    return bb;
  }

  // Unswitches |loop_|.
  void PerformUnswitch() {
    assert(CanUnswitchLoop() &&
           "Cannot unswitch if there is no constant condition");
    assert(loop_->GetPreHeaderBlock() && "This loop has no pre-header block");
    assert(loop_->IsLCSSA() && "This loop is not in LCSSA form");

    CFG& cfg = *context_->cfg();
    DominatorTree* dom_tree =
        &context_->GetDominatorAnalysis(function_)->GetDomTree();
    analysis::DefUseManager* def_use_mgr = context_->get_def_use_mgr();
    LoopUtils loop_utils(context_, loop_);

    //////////////////////////////////////////////////////////////////////////////
    // Step 1: Create the if merge block for structured modules.
    //    To do so, the |loop_| merge block will become the if's merge block
    //    and we create a new merge block for the loop. This limits the amount
    //    of duplicated code the structured control flow imposes.
    //    For non-structured programs, the new loop will be connected to
    //    the old loop's exit blocks.
    //////////////////////////////////////////////////////////////////////////////
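    //
    // A rough sketch of the rewiring for a structured loop, where M is the
    // old loop merge block and M' the block created below:
    //   before:  loop exits --> M
    //   after:   loop exits --> M' --> M
    // M' becomes the merge block of |loop_| while M becomes the merge block
    // of the if/switch introduced by the unswitch.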

    // Get the merge block if it exists.
    BasicBlock* if_merge_block = loop_->GetMergeBlock();
    // The merge block is only created if the loop has a unique exit block. We
    // have this guarantee for structured loops; for compute loops it will
    // trivially help maintain both a structured-like form and LCSSA.
    BasicBlock* loop_merge_block =
        if_merge_block
            ? CreateBasicBlock(FindBasicBlockPosition(if_merge_block))
            : nullptr;
    if (loop_merge_block) {
      // Add the instruction and update managers.
      InstructionBuilder builder(
          context_, loop_merge_block,
          IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);
      builder.AddBranch(if_merge_block->id());
      builder.SetInsertPoint(&*loop_merge_block->begin());
      cfg.RegisterBlock(loop_merge_block);
      def_use_mgr->AnalyzeInstDef(loop_merge_block->GetLabelInst());
      // Update CFG.
      if_merge_block->ForEachPhiInst(
          [loop_merge_block, &builder, this](Instruction* phi) {
            Instruction* cloned = phi->Clone(context_);
            builder.AddInstruction(std::unique_ptr<Instruction>(cloned));
            phi->SetInOperand(0, {cloned->result_id()});
            phi->SetInOperand(1, {loop_merge_block->id()});
            for (uint32_t j = phi->NumInOperands() - 1; j > 1; j--)
              phi->RemoveInOperand(j);
          });
      // Copy the predecessor list (will get invalidated otherwise).
      std::vector<uint32_t> preds = cfg.preds(if_merge_block->id());
      for (uint32_t pid : preds) {
        if (pid == loop_merge_block->id()) continue;
        BasicBlock* p_bb = cfg.block(pid);
        p_bb->ForEachSuccessorLabel(
            [if_merge_block, loop_merge_block](uint32_t* id) {
              if (*id == if_merge_block->id()) *id = loop_merge_block->id();
            });
        cfg.AddEdge(pid, loop_merge_block->id());
      }
      cfg.RemoveNonExistingEdges(if_merge_block->id());
      // Update loop descriptor.
      if (Loop* ploop = loop_->GetParent()) {
        ploop->AddBasicBlock(loop_merge_block);
        loop_desc_.SetBasicBlockToLoop(loop_merge_block->id(), ploop);
      }

      // Update the dominator tree.
      DominatorTreeNode* loop_merge_dtn =
          dom_tree->GetOrInsertNode(loop_merge_block);
      DominatorTreeNode* if_merge_block_dtn =
          dom_tree->GetOrInsertNode(if_merge_block);
      loop_merge_dtn->parent_ = if_merge_block_dtn->parent_;
      loop_merge_dtn->children_.push_back(if_merge_block_dtn);
      loop_merge_dtn->parent_->children_.push_back(loop_merge_dtn);
      if_merge_block_dtn->parent_->children_.erase(std::find(
          if_merge_block_dtn->parent_->children_.begin(),
          if_merge_block_dtn->parent_->children_.end(), if_merge_block_dtn));

      loop_->SetMergeBlock(loop_merge_block);
    }

    ////////////////////////////////////////////////////////////////////////////
    // Step 2: Build a new preheader for |loop_|, use the old one
    //         for the constant branch.
    ////////////////////////////////////////////////////////////////////////////
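    //
    // Sketch: the old pre-header P keeps its position and will later host the
    // new conditional branch; a fresh block P' is inserted between P and the
    // loop header H:
    //   before:  P --> H
    //   after:   P --> P' --> H   (P' becomes the |loop_| pre-header)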

    BasicBlock* if_block = loop_->GetPreHeaderBlock();
    // If this preheader is the parent loop header,
    // we need to create a dedicated block for the if.
    BasicBlock* loop_pre_header =
        CreateBasicBlock(++FindBasicBlockPosition(if_block));
    InstructionBuilder(
        context_, loop_pre_header,
        IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping)
        .AddBranch(loop_->GetHeaderBlock()->id());

    if_block->tail()->SetInOperand(0, {loop_pre_header->id()});

    // Update loop descriptor.
    if (Loop* ploop = loop_desc_[if_block]) {
      ploop->AddBasicBlock(loop_pre_header);
      loop_desc_.SetBasicBlockToLoop(loop_pre_header->id(), ploop);
    }

    // Update the CFG.
    cfg.RegisterBlock(loop_pre_header);
    def_use_mgr->AnalyzeInstDef(loop_pre_header->GetLabelInst());
    cfg.AddEdge(if_block->id(), loop_pre_header->id());
    cfg.RemoveNonExistingEdges(loop_->GetHeaderBlock()->id());

    loop_->GetHeaderBlock()->ForEachPhiInst(
        [loop_pre_header, if_block](Instruction* phi) {
          phi->ForEachInId([loop_pre_header, if_block](uint32_t* id) {
            if (*id == if_block->id()) {
              *id = loop_pre_header->id();
            }
          });
        });
    loop_->SetPreHeaderBlock(loop_pre_header);

    // Update the dominator tree.
    DominatorTreeNode* loop_pre_header_dtn =
        dom_tree->GetOrInsertNode(loop_pre_header);
    DominatorTreeNode* if_block_dtn = dom_tree->GetTreeNode(if_block);
    loop_pre_header_dtn->parent_ = if_block_dtn;
    assert(
        if_block_dtn->children_.size() == 1 &&
        "A loop preheader should only have the header block as a child in the "
        "dominator tree");
    loop_pre_header_dtn->children_.push_back(if_block_dtn->children_[0]);
    if_block_dtn->children_.clear();
    if_block_dtn->children_.push_back(loop_pre_header_dtn);

    // Make domination queries valid.
    dom_tree->ResetDFNumbering();

    // Compute an ordered list of basic blocks to clone: loop blocks +
    // pre-header + merge block.
    loop_->ComputeLoopStructuredOrder(&ordered_loop_blocks_, true, true);

    /////////////////////////////
    // Do the actual unswitch: //
    //   - Clone the loop      //
    //   - Connect exits       //
    //   - Specialize the loop //
    /////////////////////////////
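    //
    // Roughly: for an OpBranchConditional we clone the loop once (specialized
    // with the false value) and specialize the original loop with the true
    // value; for an OpSwitch we clone the loop once per case literal and keep
    // the original loop for the default target.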

    Instruction* iv_condition = &*switch_block_->tail();
    SpvOp iv_opcode = iv_condition->opcode();
    Instruction* condition =
        def_use_mgr->GetDef(iv_condition->GetOperand(0).words[0]);

    analysis::ConstantManager* cst_mgr = context_->get_constant_mgr();
    const analysis::Type* cond_type =
        context_->get_type_mgr()->GetType(condition->type_id());

    // Build the list of values for which we need to clone and specialize the
    // loop.
    std::vector<std::pair<Instruction*, BasicBlock*>> constant_branch;
    // Special case for the original loop
    Instruction* original_loop_constant_value;
    BasicBlock* original_loop_target;
    if (iv_opcode == SpvOpBranchConditional) {
      constant_branch.emplace_back(
          cst_mgr->GetDefiningInstruction(cst_mgr->GetConstant(cond_type, {0})),
          nullptr);
      original_loop_constant_value =
          cst_mgr->GetDefiningInstruction(cst_mgr->GetConstant(cond_type, {1}));
    } else {
      // We are looking to take the default branch, so we can't provide a
      // specific value.
      original_loop_constant_value = nullptr;
      for (uint32_t i = 2; i < iv_condition->NumInOperands(); i += 2) {
        constant_branch.emplace_back(
            cst_mgr->GetDefiningInstruction(cst_mgr->GetConstant(
                cond_type, iv_condition->GetInOperand(i).words)),
            nullptr);
      }
    }

    // Get the loop landing pads.
    std::unordered_set<uint32_t> if_merging_blocks;
    std::function<bool(uint32_t)> is_from_original_loop;
    if (loop_->GetHeaderBlock()->GetLoopMergeInst()) {
      if_merging_blocks.insert(if_merge_block->id());
      is_from_original_loop = [this](uint32_t id) {
        return loop_->IsInsideLoop(id) || loop_->GetMergeBlock()->id() == id;
      };
    } else {
      loop_->GetExitBlocks(&if_merging_blocks);
      is_from_original_loop = [this](uint32_t id) {
        return loop_->IsInsideLoop(id);
      };
    }

    for (auto& specialisation_pair : constant_branch) {
      Instruction* specialisation_value = specialisation_pair.first;
      //////////////////////////////////////////////////////////
      // Step 3: Duplicate |loop_|.
      //////////////////////////////////////////////////////////
      LoopUtils::LoopCloningResult clone_result;

      Loop* cloned_loop =
          loop_utils.CloneLoop(&clone_result, ordered_loop_blocks_);
      specialisation_pair.second = cloned_loop->GetPreHeaderBlock();

      ////////////////////////////////////
      // Step 4: Specialize the loop.   //
      ////////////////////////////////////

      {
        std::unordered_set<uint32_t> dead_blocks;
        std::unordered_set<uint32_t> unreachable_merges;
        SimplifyLoop(
            make_range(
                UptrVectorIterator<BasicBlock>(&clone_result.cloned_bb_,
                                               clone_result.cloned_bb_.begin()),
                UptrVectorIterator<BasicBlock>(&clone_result.cloned_bb_,
                                               clone_result.cloned_bb_.end())),
            cloned_loop, condition, specialisation_value, &dead_blocks);

        // We tagged dead blocks; clean the loop nest before we invalidate any
        // basic block.
        cloned_loop =
            CleanLoopNest(cloned_loop, dead_blocks, &unreachable_merges);
        CleanUpCFG(
            UptrVectorIterator<BasicBlock>(&clone_result.cloned_bb_,
                                           clone_result.cloned_bb_.begin()),
            dead_blocks, unreachable_merges);

        ///////////////////////////////////////////////////////////
        // Step 5: Connect convergent edges to the landing pads. //
        ///////////////////////////////////////////////////////////

        for (uint32_t merge_bb_id : if_merging_blocks) {
          BasicBlock* merge = context_->cfg()->block(merge_bb_id);
          // We are in LCSSA so we only care about phi instructions.
          merge->ForEachPhiInst([is_from_original_loop, &dead_blocks,
                                 &clone_result](Instruction* phi) {
            uint32_t num_in_operands = phi->NumInOperands();
            for (uint32_t i = 0; i < num_in_operands; i += 2) {
              uint32_t pred = phi->GetSingleWordInOperand(i + 1);
              if (is_from_original_loop(pred)) {
                pred = clone_result.value_map_.at(pred);
                if (!dead_blocks.count(pred)) {
                  uint32_t incoming_value_id = phi->GetSingleWordInOperand(i);
                  // Not all the incoming values come from the loop.
                  ValueMapTy::iterator new_value =
                      clone_result.value_map_.find(incoming_value_id);
                  if (new_value != clone_result.value_map_.end()) {
                    incoming_value_id = new_value->second;
                  }
                  phi->AddOperand({SPV_OPERAND_TYPE_ID, {incoming_value_id}});
                  phi->AddOperand({SPV_OPERAND_TYPE_ID, {pred}});
                }
              }
            }
          });
        }
      }
      function_->AddBasicBlocks(clone_result.cloned_bb_.begin(),
                                clone_result.cloned_bb_.end(),
                                ++FindBasicBlockPosition(if_block));
    }

    // Same as above but specialize the existing loop
    {
      std::unordered_set<uint32_t> dead_blocks;
      std::unordered_set<uint32_t> unreachable_merges;
      SimplifyLoop(make_range(function_->begin(), function_->end()), loop_,
                   condition, original_loop_constant_value, &dead_blocks);

      for (uint32_t merge_bb_id : if_merging_blocks) {
        BasicBlock* merge = context_->cfg()->block(merge_bb_id);
        // LCSSA, so we only care about phi instructions.
        // If the phi is reduced to a single incoming branch, do not
        // propagate it, to preserve LCSSA.
        PatchPhis(merge, dead_blocks, true);
      }
      if (if_merge_block) {
        bool has_live_pred = false;
        for (uint32_t pid : cfg.preds(if_merge_block->id())) {
          if (!dead_blocks.count(pid)) {
            has_live_pred = true;
            break;
          }
        }
        if (!has_live_pred) unreachable_merges.insert(if_merge_block->id());
      }
      original_loop_target = loop_->GetPreHeaderBlock();
      // We tagged dead blocks, prune the loop descriptor from any dead loops.
      // After this call, |loop_| can be nullptr (i.e. the unswitch killed this
      // loop).
      loop_ = CleanLoopNest(loop_, dead_blocks, &unreachable_merges);

      CleanUpCFG(function_->begin(), dead_blocks, unreachable_merges);
    }

    /////////////////////////////////////
    // Finally: connect the new loops. //
    /////////////////////////////////////

    // Delete the old jump
    context_->KillInst(&*if_block->tail());
    InstructionBuilder builder(context_, if_block);
    if (iv_opcode == SpvOpBranchConditional) {
      assert(constant_branch.size() == 1);
      builder.AddConditionalBranch(
          condition->result_id(), original_loop_target->id(),
          constant_branch[0].second->id(),
          if_merge_block ? if_merge_block->id() : kInvalidId);
    } else {
      std::vector<std::pair<Operand::OperandData, uint32_t>> targets;
      for (auto& t : constant_branch) {
        targets.emplace_back(t.first->GetInOperand(0).words, t.second->id());
      }

      builder.AddSwitch(condition->result_id(), original_loop_target->id(),
                        targets,
                        if_merge_block ? if_merge_block->id() : kInvalidId);
    }

    switch_block_ = nullptr;
    ordered_loop_blocks_.clear();

    context_->InvalidateAnalysesExceptFor(
        IRContext::Analysis::kAnalysisLoopAnalysis);
  }

  // Returns true if the unswitch killed the original |loop_|.
  bool WasLoopKilled() const { return loop_ == nullptr; }

 private:
  using ValueMapTy = std::unordered_map<uint32_t, uint32_t>;
  using BlockMapTy = std::unordered_map<uint32_t, BasicBlock*>;

  Function* function_;
  Loop* loop_;
  LoopDescriptor& loop_desc_;
  IRContext* context_;

  BasicBlock* switch_block_;
  // Map from instruction result ids to whether they are dynamically uniform.
  std::unordered_map<uint32_t, bool> dynamically_uniform_;
  // The loop basic blocks in structured order.
  std::vector<BasicBlock*> ordered_loop_blocks_;

  // Returns the next usable id for the context.
  uint32_t TakeNextId() {
    // TODO(1841): Handle id overflow.
    return context_->TakeNextId();
  }

  // Patches |bb|'s phi instructions by removing incoming values that come
  // from branches that no longer exist or are tagged as dead.
  void PatchPhis(BasicBlock* bb,
                 const std::unordered_set<uint32_t>& dead_blocks,
                 bool preserve_phi) {
    CFG& cfg = *context_->cfg();

    std::vector<Instruction*> phi_to_kill;
    const std::vector<uint32_t>& bb_preds = cfg.preds(bb->id());
    auto is_branch_dead = [&bb_preds, &dead_blocks](uint32_t id) {
      return dead_blocks.count(id) ||
             std::find(bb_preds.begin(), bb_preds.end(), id) == bb_preds.end();
    };
    bb->ForEachPhiInst(
        [&phi_to_kill, &is_branch_dead, preserve_phi, this](Instruction* insn) {
          uint32_t i = 0;
          while (i < insn->NumInOperands()) {
            uint32_t incoming_id = insn->GetSingleWordInOperand(i + 1);
            if (is_branch_dead(incoming_id)) {
              // Remove the incoming block id operand.
              insn->RemoveInOperand(i + 1);
              // Remove the definition id operand.
              insn->RemoveInOperand(i);
              continue;
            }
            i += 2;
          }
          // If there is only 1 remaining edge, propagate the value and
          // kill the instruction.
          if (insn->NumInOperands() == 2 && !preserve_phi) {
            phi_to_kill.push_back(insn);
            context_->ReplaceAllUsesWith(insn->result_id(),
                                         insn->GetSingleWordInOperand(0));
          }
        });
    for (Instruction* insn : phi_to_kill) {
      context_->KillInst(insn);
    }
  }

  // Removes any block that is tagged as dead; if the block is in
  // |unreachable_merges| then all the block's instructions are replaced by an
  // OpUnreachable.
  void CleanUpCFG(UptrVectorIterator<BasicBlock> bb_it,
                  const std::unordered_set<uint32_t>& dead_blocks,
                  const std::unordered_set<uint32_t>& unreachable_merges) {
    CFG& cfg = *context_->cfg();

    while (bb_it != bb_it.End()) {
      BasicBlock& bb = *bb_it;

      if (unreachable_merges.count(bb.id())) {
        if (bb.begin() != bb.tail() ||
            bb.terminator()->opcode() != SpvOpUnreachable) {
          // Make unreachable, but leave the label.
          bb.KillAllInsts(false);
          InstructionBuilder(context_, &bb).AddUnreachable();
          cfg.RemoveNonExistingEdges(bb.id());
        }
        ++bb_it;
      } else if (dead_blocks.count(bb.id())) {
        cfg.ForgetBlock(&bb);
        // Kill this block.
        bb.KillAllInsts(true);
        bb_it = bb_it.Erase();
      } else {
        cfg.RemoveNonExistingEdges(bb.id());
        ++bb_it;
      }
    }
  }

  // Returns true if |c_inst| is a Boolean constant, and sets |cond_val| to
  // the value that |c_inst| represents.
  bool GetConstCondition(const Instruction* c_inst, bool* cond_val) {
    bool cond_is_const;
    switch (c_inst->opcode()) {
      case SpvOpConstantFalse: {
        *cond_val = false;
        cond_is_const = true;
      } break;
      case SpvOpConstantTrue: {
        *cond_val = true;
        cond_is_const = true;
      } break;
      default: { cond_is_const = false; } break;
    }
    return cond_is_const;
  }

  // Simplifies |loop| assuming the instruction |to_version_insn| takes the
  // value |cst_value|. |block_range| is an iterator range returning the loop
  // basic blocks in a structured order (dominators first).
  // The function will ignore basic blocks returned by |block_range| if they
  // do not belong to the loop.
  // The set |dead_blocks| will contain all the dead basic blocks.
  //
  // Requirements:
  //   - |loop| must be in the LCSSA form;
  //   - |cst_value| must be constant or null (to represent the default target
  //   of an OpSwitch).
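  //
  // For example (a sketch): with |cst_value| holding true, a loop-interior
  //   OpBranchConditional %cond %then %else
  // whose %cond is |to_version_insn| is rewritten into
  //   OpBranch %then
  // and any block that loses all its live predecessors (such as %else) is
  // added to |dead_blocks|.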
  void SimplifyLoop(IteratorRange<UptrVectorIterator<BasicBlock>> block_range,
                    Loop* loop, Instruction* to_version_insn,
                    Instruction* cst_value,
                    std::unordered_set<uint32_t>* dead_blocks) {
    CFG& cfg = *context_->cfg();
    analysis::DefUseManager* def_use_mgr = context_->get_def_use_mgr();

    std::function<bool(uint32_t)> ignore_node;
    ignore_node = [loop](uint32_t bb_id) { return !loop->IsInsideLoop(bb_id); };

    std::vector<std::pair<Instruction*, uint32_t>> use_list;
    def_use_mgr->ForEachUse(to_version_insn,
                            [&use_list, &ignore_node, this](
                                Instruction* inst, uint32_t operand_index) {
                              BasicBlock* bb = context_->get_instr_block(inst);

                              if (!bb || ignore_node(bb->id())) {
                                // Out of the loop, the specialization does not
                                // apply any more.
                                return;
                              }
                              use_list.emplace_back(inst, operand_index);
                            });

    // First pass: inject the specialized value into the loop (and only the
    // loop).
    for (auto use : use_list) {
      Instruction* inst = use.first;
      uint32_t operand_index = use.second;
      BasicBlock* bb = context_->get_instr_block(inst);

      // If it is not a branch, simply inject the value.
      if (!inst->IsBranch()) {
        // To also handle switch, cst_value can be nullptr: this case
        // means that we are looking to branch to the default target of
        // the switch. We don't actually know its value so we don't touch
        // it if it is not a switch.
        if (cst_value) {
          inst->SetOperand(operand_index, {cst_value->result_id()});
          def_use_mgr->AnalyzeInstUse(inst);
        }
      }

      // The user is a branch, kill dead branches.
      uint32_t live_target = 0;
      std::unordered_set<uint32_t> dead_branches;
      switch (inst->opcode()) {
        case SpvOpBranchConditional: {
          assert(cst_value && "No constant value to specialize !");
          bool branch_cond = false;
          if (GetConstCondition(cst_value, &branch_cond)) {
            uint32_t true_label =
                inst->GetSingleWordInOperand(kBranchCondTrueLabIdInIdx);
            uint32_t false_label =
                inst->GetSingleWordInOperand(kBranchCondFalseLabIdInIdx);
            live_target = branch_cond ? true_label : false_label;
            uint32_t dead_target = !branch_cond ? true_label : false_label;
            cfg.RemoveEdge(bb->id(), dead_target);
          }
          break;
        }
        case SpvOpSwitch: {
          live_target = inst->GetSingleWordInOperand(1);
          if (cst_value) {
            if (!cst_value->IsConstant()) break;
            const Operand& cst = cst_value->GetInOperand(0);
            for (uint32_t i = 2; i < inst->NumInOperands(); i += 2) {
              const Operand& literal = inst->GetInOperand(i);
              if (literal == cst) {
                live_target = inst->GetSingleWordInOperand(i + 1);
                break;
              }
            }
          }
          for (uint32_t i = 1; i < inst->NumInOperands(); i += 2) {
            uint32_t id = inst->GetSingleWordInOperand(i);
            if (id != live_target) {
              cfg.RemoveEdge(bb->id(), id);
            }
          }
          break;
        }
        default:
          break;
      }
      if (live_target != 0) {
        // Check for the presence of the merge block.
        if (Instruction* merge = bb->GetMergeInst()) context_->KillInst(merge);
        context_->KillInst(&*bb->tail());
        InstructionBuilder builder(context_, bb,
                                   IRContext::kAnalysisDefUse |
                                       IRContext::kAnalysisInstrToBlockMapping);
        builder.AddBranch(live_target);
      }
    }

    // Go through the loop basic blocks and tag all blocks that are obviously
    // dead.
    std::unordered_set<uint32_t> visited;
    for (BasicBlock& bb : block_range) {
      if (ignore_node(bb.id())) continue;
      visited.insert(bb.id());

      // Check if this block is dead; if so, tag it as dead, otherwise patch
      // its phi instructions.
      bool has_live_pred = false;
      for (uint32_t pid : cfg.preds(bb.id())) {
        if (!dead_blocks->count(pid)) {
          has_live_pred = true;
          break;
        }
      }
      if (!has_live_pred) {
        dead_blocks->insert(bb.id());
        const BasicBlock& cbb = bb;
        // Patch the phis for any back-edge.
        cbb.ForEachSuccessorLabel(
            [dead_blocks, &visited, &cfg, this](uint32_t id) {
              if (!visited.count(id) || dead_blocks->count(id)) return;
              BasicBlock* succ = cfg.block(id);
              PatchPhis(succ, *dead_blocks, false);
            });
        continue;
      }
      // Update the phi instructions: some incoming branches have disappeared
      // or will disappear.
      PatchPhis(&bb, *dead_blocks, /* preserve_phi = */ false);
    }
  }

  // Returns true if the header is not reachable or tagged as dead or if we
  // never loop back.
  bool IsLoopDead(BasicBlock* header, BasicBlock* latch,
                  const std::unordered_set<uint32_t>& dead_blocks) {
    if (!header || dead_blocks.count(header->id())) return true;
    if (!latch || dead_blocks.count(latch->id())) return true;
    for (uint32_t pid : context_->cfg()->preds(header->id())) {
      if (!dead_blocks.count(pid)) {
        // Seems reachable.
        return false;
      }
    }
    return true;
  }

  // Cleans the loop nest under |loop| and reflects the changes in the loop
  // descriptor. This will kill all descriptors that represent dead loops.
  // If |loop_| is killed, it will be set to nullptr.
  // Any merge blocks that become unreachable will be added to
  // |unreachable_merges|.
  // The function returns the pointer to |loop| or nullptr if the loop was
  // killed.
  Loop* CleanLoopNest(Loop* loop,
                      const std::unordered_set<uint32_t>& dead_blocks,
                      std::unordered_set<uint32_t>* unreachable_merges) {
    // This maps each dead loop to its nearest live parent (nullptr if it has
    // no parent).
    std::unordered_map<Loop*, Loop*> dead_loops;
    auto get_parent = [&dead_loops](Loop* l) -> Loop* {
      std::unordered_map<Loop*, Loop*>::iterator it = dead_loops.find(l);
      if (it != dead_loops.end()) return it->second;
      return nullptr;
    };

    bool is_main_loop_dead =
        IsLoopDead(loop->GetHeaderBlock(), loop->GetLatchBlock(), dead_blocks);
    if (is_main_loop_dead) {
      if (Instruction* merge = loop->GetHeaderBlock()->GetLoopMergeInst()) {
        context_->KillInst(merge);
      }
      dead_loops[loop] = loop->GetParent();
    } else {
      dead_loops[loop] = loop;
    }

    // For each loop, check if we killed it. If we did, find a suitable parent
    // for its children.
    for (Loop& sub_loop :
         make_range(++TreeDFIterator<Loop>(loop), TreeDFIterator<Loop>())) {
      if (IsLoopDead(sub_loop.GetHeaderBlock(), sub_loop.GetLatchBlock(),
                     dead_blocks)) {
        if (Instruction* merge =
                sub_loop.GetHeaderBlock()->GetLoopMergeInst()) {
          context_->KillInst(merge);
        }
        dead_loops[&sub_loop] = get_parent(&sub_loop);
      } else {
        // The loop is alive; if its merge block is dead, tag that merge block
        // as unreachable.
        if (sub_loop.GetMergeBlock()) {
          uint32_t merge_id = sub_loop.GetMergeBlock()->id();
          if (dead_blocks.count(merge_id)) {
            unreachable_merges->insert(sub_loop.GetMergeBlock()->id());
          }
        }
      }
    }
    if (!is_main_loop_dead) dead_loops.erase(loop);

    // Remove dead blocks from live loops.
    for (uint32_t bb_id : dead_blocks) {
      Loop* l = loop_desc_[bb_id];
      if (l) {
        l->RemoveBasicBlock(bb_id);
        loop_desc_.ForgetBasicBlock(bb_id);
      }
    }

    std::for_each(
        dead_loops.begin(), dead_loops.end(),
        [&loop,
         this](std::unordered_map<Loop*, Loop*>::iterator::reference it) {
          if (it.first == loop) loop = nullptr;
          loop_desc_.RemoveLoop(it.first);
        });

    return loop;
  }

  // Returns true if |var| is dynamically uniform.
  // Note: this is currently approximated as uniform.
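  // For example (informally): a value decorated with Uniform is treated as
  // dynamically uniform; an OpLoad qualifies only if it loads through a
  // pointer in the Uniform or UniformConstant storage class; any other
  // instruction must be a combinator whose operands are all dynamically
  // uniform.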
  bool IsDynamicallyUniform(Instruction* var, const BasicBlock* entry,
                            const DominatorTree& post_dom_tree) {
    assert(post_dom_tree.IsPostDominator());
    analysis::DefUseManager* def_use_mgr = context_->get_def_use_mgr();

    auto it = dynamically_uniform_.find(var->result_id());

    if (it != dynamically_uniform_.end()) return it->second;

    analysis::DecorationManager* dec_mgr = context_->get_decoration_mgr();

    bool& is_uniform = dynamically_uniform_[var->result_id()];
    is_uniform = false;

    dec_mgr->WhileEachDecoration(var->result_id(), SpvDecorationUniform,
                                 [&is_uniform](const Instruction&) {
                                   is_uniform = true;
                                   return false;
                                 });
    if (is_uniform) {
      return is_uniform;
    }

    BasicBlock* parent = context_->get_instr_block(var);
    if (!parent) {
      return is_uniform = true;
    }

    if (!post_dom_tree.Dominates(parent->id(), entry->id())) {
      return is_uniform = false;
    }
    if (var->opcode() == SpvOpLoad) {
      const uint32_t PtrTypeId =
          def_use_mgr->GetDef(var->GetSingleWordInOperand(0))->type_id();
      const Instruction* PtrTypeInst = def_use_mgr->GetDef(PtrTypeId);
      uint32_t storage_class =
          PtrTypeInst->GetSingleWordInOperand(kTypePointerStorageClassInIdx);
      if (storage_class != SpvStorageClassUniform &&
          storage_class != SpvStorageClassUniformConstant) {
        return is_uniform = false;
      }
    } else {
      if (!context_->IsCombinatorInstruction(var)) {
        return is_uniform = false;
      }
    }

    return is_uniform = var->WhileEachInId([entry, &post_dom_tree,
                                            this](const uint32_t* id) {
      return IsDynamicallyUniform(context_->get_def_use_mgr()->GetDef(*id),
                                  entry, post_dom_tree);
    });
  }

  // Returns true if the condition of branch |insn| is loop invariant (defined
  // outside of |loop_|) and dynamically uniform.
  bool IsConditionLoopInvariant(Instruction* insn) {
    assert(insn->IsBranch());
    assert(insn->opcode() != SpvOpBranch);
    analysis::DefUseManager* def_use_mgr = context_->get_def_use_mgr();

    Instruction* condition = def_use_mgr->GetDef(insn->GetOperand(0).words[0]);
    return !loop_->IsInsideLoop(condition) &&
           IsDynamicallyUniform(
               condition, function_->entry().get(),
               context_->GetPostDominatorAnalysis(function_)->GetDomTree());
  }
};

}  // namespace

Pass::Status LoopUnswitchPass::Process() {
  bool modified = false;
  Module* module = context()->module();

  // Process each function in the module
  for (Function& f : *module) {
    modified |= ProcessFunction(&f);
  }

  return modified ? Status::SuccessWithChange : Status::SuccessWithoutChange;
}

bool LoopUnswitchPass::ProcessFunction(Function* f) {
  bool modified = false;
  std::unordered_set<Loop*> processed_loop;

  LoopDescriptor& loop_descriptor = *context()->GetLoopDescriptor(f);

  bool loop_changed = true;
  while (loop_changed) {
    loop_changed = false;
    for (Loop& loop :
         make_range(++TreeDFIterator<Loop>(loop_descriptor.GetDummyRootLoop()),
                    TreeDFIterator<Loop>())) {
      if (processed_loop.count(&loop)) continue;
      processed_loop.insert(&loop);

      LoopUnswitch unswitcher(context(), f, &loop, &loop_descriptor);
      while (!unswitcher.WasLoopKilled() && unswitcher.CanUnswitchLoop()) {
        if (!loop.IsLCSSA()) {
          LoopUtils(context(), &loop).MakeLoopClosedSSA();
        }
        modified = true;
        loop_changed = true;
        unswitcher.PerformUnswitch();
      }
      if (loop_changed) break;
    }
  }

  return modified;
}

}  // namespace opt
}  // namespace spvtools