1 // Copyright (c) 2018 Google LLC.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 #include "source/opt/loop_unswitch_pass.h"
16
17 #include <functional>
18 #include <list>
19 #include <memory>
20 #include <type_traits>
21 #include <unordered_map>
22 #include <unordered_set>
23 #include <utility>
24 #include <vector>
25
26 #include "source/opt/basic_block.h"
27 #include "source/opt/dominator_tree.h"
28 #include "source/opt/fold.h"
29 #include "source/opt/function.h"
30 #include "source/opt/instruction.h"
31 #include "source/opt/ir_builder.h"
32 #include "source/opt/ir_context.h"
33 #include "source/opt/loop_descriptor.h"
34
35 #include "source/opt/loop_utils.h"
36
37 namespace spvtools {
38 namespace opt {
39 namespace {
40
// In-operand index of the storage class in an OpTypePointer instruction.
static const uint32_t kTypePointerStorageClassInIdx = 0;
// In-operand indices of the true/false target labels in an
// OpBranchConditional instruction.
static const uint32_t kBranchCondTrueLabIdInIdx = 1;
static const uint32_t kBranchCondFalseLabIdInIdx = 2;
44
45 } // anonymous namespace
46
47 namespace {
48
49 // This class handle the unswitch procedure for a given loop.
50 // The unswitch will not happen if:
51 // - The loop has any instruction that will prevent it;
52 // - The loop invariant condition is not uniform.
53 class LoopUnswitch {
54 public:
LoopUnswitch(IRContext * context,Function * function,Loop * loop,LoopDescriptor * loop_desc)55 LoopUnswitch(IRContext* context, Function* function, Loop* loop,
56 LoopDescriptor* loop_desc)
57 : function_(function),
58 loop_(loop),
59 loop_desc_(*loop_desc),
60 context_(context),
61 switch_block_(nullptr) {}
62
63 // Returns true if the loop can be unswitched.
64 // Can be unswitch if:
65 // - The loop has no instructions that prevents it (such as barrier);
66 // - The loop has one conditional branch or switch that do not depends on the
67 // loop;
68 // - The loop invariant condition is uniform;
CanUnswitchLoop()69 bool CanUnswitchLoop() {
70 if (switch_block_) return true;
71 if (loop_->IsSafeToClone()) return false;
72
73 CFG& cfg = *context_->cfg();
74
75 for (uint32_t bb_id : loop_->GetBlocks()) {
76 BasicBlock* bb = cfg.block(bb_id);
77 if (bb->terminator()->IsBranch() &&
78 bb->terminator()->opcode() != SpvOpBranch) {
79 if (IsConditionLoopInvariant(bb->terminator())) {
80 switch_block_ = bb;
81 break;
82 }
83 }
84 }
85
86 return switch_block_;
87 }
88
89 // Return the iterator to the basic block |bb|.
FindBasicBlockPosition(BasicBlock * bb_to_find)90 Function::iterator FindBasicBlockPosition(BasicBlock* bb_to_find) {
91 Function::iterator it = function_->FindBlock(bb_to_find->id());
92 assert(it != function_->end() && "Basic Block not found");
93 return it;
94 }
95
96 // Creates a new basic block and insert it into the function |fn| at the
97 // position |ip|. This function preserves the def/use and instr to block
98 // managers.
CreateBasicBlock(Function::iterator ip)99 BasicBlock* CreateBasicBlock(Function::iterator ip) {
100 analysis::DefUseManager* def_use_mgr = context_->get_def_use_mgr();
101
102 // TODO(1841): Handle id overflow.
103 BasicBlock* bb = &*ip.InsertBefore(std::unique_ptr<BasicBlock>(
104 new BasicBlock(std::unique_ptr<Instruction>(new Instruction(
105 context_, SpvOpLabel, 0, context_->TakeNextId(), {})))));
106 bb->SetParent(function_);
107 def_use_mgr->AnalyzeInstDef(bb->GetLabelInst());
108 context_->set_instr_block(bb->GetLabelInst(), bb);
109
110 return bb;
111 }
112
  // Unswitches |loop_|: hoists the loop-invariant branch in |switch_block_|
  // out of the loop and specializes one clone of the loop per branch target.
  // Requires CanUnswitchLoop(), a pre-header, and LCSSA form.
  void PerformUnswitch() {
    assert(CanUnswitchLoop() &&
           "Cannot unswitch if there is not constant condition");
    assert(loop_->GetPreHeaderBlock() && "This loop has no pre-header block");
    assert(loop_->IsLCSSA() && "This loop is not in LCSSA form");

    CFG& cfg = *context_->cfg();
    DominatorTree* dom_tree =
        &context_->GetDominatorAnalysis(function_)->GetDomTree();
    analysis::DefUseManager* def_use_mgr = context_->get_def_use_mgr();
    LoopUtils loop_utils(context_, loop_);

    ////////////////////////////////////////////////////////////////////////////
    // Step 1: Create the if merge block for structured modules.
    //    To do so, the |loop_| merge block will become the if's one and we
    //    create a merge for the loop. This will limit the amount of duplicated
    //    code the structured control flow imposes.
    //    For non structured program, the new loop will be connected to
    //    the old loop's exit blocks.
    ////////////////////////////////////////////////////////////////////////////

    // Get the merge block if it exists.
    BasicBlock* if_merge_block = loop_->GetMergeBlock();
    // The merge block is only created if the loop has a unique exit block. We
    // have this guarantee for structured loops, for compute loop it will
    // trivially help maintain both a structured-like form and LCSAA.
    BasicBlock* loop_merge_block =
        if_merge_block
            ? CreateBasicBlock(FindBasicBlockPosition(if_merge_block))
            : nullptr;
    if (loop_merge_block) {
      // Add the instruction and update managers.
      InstructionBuilder builder(
          context_, loop_merge_block,
          IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);
      builder.AddBranch(if_merge_block->id());
      builder.SetInsertPoint(&*loop_merge_block->begin());
      cfg.RegisterBlock(loop_merge_block);
      def_use_mgr->AnalyzeInstDef(loop_merge_block->GetLabelInst());
      // Update CFG.
      if_merge_block->ForEachPhiInst(
          [loop_merge_block, &builder, this](Instruction* phi) {
            // Clone each phi of the old merge block into the new loop merge
            // block; the original phi keeps only the clone as its incoming
            // value, coming from |loop_merge_block|.
            Instruction* cloned = phi->Clone(context_);
            builder.AddInstruction(std::unique_ptr<Instruction>(cloned));
            phi->SetInOperand(0, {cloned->result_id()});
            phi->SetInOperand(1, {loop_merge_block->id()});
            for (uint32_t j = phi->NumInOperands() - 1; j > 1; j--)
              phi->RemoveInOperand(j);
          });
      // Copy the predecessor list (will get invalidated otherwise).
      std::vector<uint32_t> preds = cfg.preds(if_merge_block->id());
      for (uint32_t pid : preds) {
        if (pid == loop_merge_block->id()) continue;
        BasicBlock* p_bb = cfg.block(pid);
        p_bb->ForEachSuccessorLabel(
            [if_merge_block, loop_merge_block](uint32_t* id) {
              if (*id == if_merge_block->id()) *id = loop_merge_block->id();
            });
        cfg.AddEdge(pid, loop_merge_block->id());
      }
      cfg.RemoveNonExistingEdges(if_merge_block->id());
      // Update loop descriptor.
      if (Loop* ploop = loop_->GetParent()) {
        ploop->AddBasicBlock(loop_merge_block);
        loop_desc_.SetBasicBlockToLoop(loop_merge_block->id(), ploop);
      }

      // Update the dominator tree: |loop_merge_block| takes
      // |if_merge_block|'s place and becomes its immediate dominator.
      DominatorTreeNode* loop_merge_dtn =
          dom_tree->GetOrInsertNode(loop_merge_block);
      DominatorTreeNode* if_merge_block_dtn =
          dom_tree->GetOrInsertNode(if_merge_block);
      loop_merge_dtn->parent_ = if_merge_block_dtn->parent_;
      loop_merge_dtn->children_.push_back(if_merge_block_dtn);
      loop_merge_dtn->parent_->children_.push_back(loop_merge_dtn);
      if_merge_block_dtn->parent_->children_.erase(std::find(
          if_merge_block_dtn->parent_->children_.begin(),
          if_merge_block_dtn->parent_->children_.end(), if_merge_block_dtn));

      loop_->SetMergeBlock(loop_merge_block);
    }

    ////////////////////////////////////////////////////////////////////////////
    // Step 2: Build a new preheader for |loop_|, use the old one
    //         for the constant branch.
    ////////////////////////////////////////////////////////////////////////////

    BasicBlock* if_block = loop_->GetPreHeaderBlock();
    // If this preheader is the parent loop header,
    // we need to create a dedicated block for the if.
    BasicBlock* loop_pre_header =
        CreateBasicBlock(++FindBasicBlockPosition(if_block));
    InstructionBuilder(
        context_, loop_pre_header,
        IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping)
        .AddBranch(loop_->GetHeaderBlock()->id());

    // The old preheader (now the if block) jumps to the new preheader.
    if_block->tail()->SetInOperand(0, {loop_pre_header->id()});

    // Update loop descriptor.
    if (Loop* ploop = loop_desc_[if_block]) {
      ploop->AddBasicBlock(loop_pre_header);
      loop_desc_.SetBasicBlockToLoop(loop_pre_header->id(), ploop);
    }

    // Update the CFG.
    cfg.RegisterBlock(loop_pre_header);
    def_use_mgr->AnalyzeInstDef(loop_pre_header->GetLabelInst());
    cfg.AddEdge(if_block->id(), loop_pre_header->id());
    cfg.RemoveNonExistingEdges(loop_->GetHeaderBlock()->id());

    // Header phis now receive their value from the new preheader.
    loop_->GetHeaderBlock()->ForEachPhiInst(
        [loop_pre_header, if_block](Instruction* phi) {
          phi->ForEachInId([loop_pre_header, if_block](uint32_t* id) {
            if (*id == if_block->id()) {
              *id = loop_pre_header->id();
            }
          });
        });
    loop_->SetPreHeaderBlock(loop_pre_header);

    // Update the dominator tree.
    DominatorTreeNode* loop_pre_header_dtn =
        dom_tree->GetOrInsertNode(loop_pre_header);
    DominatorTreeNode* if_block_dtn = dom_tree->GetTreeNode(if_block);
    loop_pre_header_dtn->parent_ = if_block_dtn;
    assert(
        if_block_dtn->children_.size() == 1 &&
        "A loop preheader should only have the header block as a child in the "
        "dominator tree");
    loop_pre_header_dtn->children_.push_back(if_block_dtn->children_[0]);
    if_block_dtn->children_.clear();
    if_block_dtn->children_.push_back(loop_pre_header_dtn);

    // Make domination queries valid.
    dom_tree->ResetDFNumbering();

    // Compute an ordered list of basic block to clone: loop blocks + pre-header
    // + merge block.
    loop_->ComputeLoopStructuredOrder(&ordered_loop_blocks_, true, true);

    /////////////////////////////
    // Do the actual unswitch: //
    //   - Clone the loop      //
    //   - Connect exits       //
    //   - Specialize the loop //
    /////////////////////////////

    Instruction* iv_condition = &*switch_block_->tail();
    SpvOp iv_opcode = iv_condition->opcode();
    Instruction* condition =
        def_use_mgr->GetDef(iv_condition->GetOperand(0).words[0]);

    analysis::ConstantManager* cst_mgr = context_->get_constant_mgr();
    const analysis::Type* cond_type =
        context_->get_type_mgr()->GetType(condition->type_id());

    // Build the list of value for which we need to clone and specialize the
    // loop.
    std::vector<std::pair<Instruction*, BasicBlock*>> constant_branch;
    // Special case for the original loop
    Instruction* original_loop_constant_value;
    BasicBlock* original_loop_target;
    if (iv_opcode == SpvOpBranchConditional) {
      // One clone specialized for "false"; the original loop is specialized
      // for "true".
      constant_branch.emplace_back(
          cst_mgr->GetDefiningInstruction(cst_mgr->GetConstant(cond_type, {0})),
          nullptr);
      original_loop_constant_value =
          cst_mgr->GetDefiningInstruction(cst_mgr->GetConstant(cond_type, {1}));
    } else {
      // We are looking to take the default branch, so we can't provide a
      // specific value.
      original_loop_constant_value = nullptr;
      for (uint32_t i = 2; i < iv_condition->NumInOperands(); i += 2) {
        constant_branch.emplace_back(
            cst_mgr->GetDefiningInstruction(cst_mgr->GetConstant(
                cond_type, iv_condition->GetInOperand(i).words)),
            nullptr);
      }
    }

    // Get the loop landing pads.
    std::unordered_set<uint32_t> if_merging_blocks;
    std::function<bool(uint32_t)> is_from_original_loop;
    if (loop_->GetHeaderBlock()->GetLoopMergeInst()) {
      if_merging_blocks.insert(if_merge_block->id());
      is_from_original_loop = [this](uint32_t id) {
        return loop_->IsInsideLoop(id) || loop_->GetMergeBlock()->id() == id;
      };
    } else {
      loop_->GetExitBlocks(&if_merging_blocks);
      is_from_original_loop = [this](uint32_t id) {
        return loop_->IsInsideLoop(id);
      };
    }

    for (auto& specialisation_pair : constant_branch) {
      Instruction* specialisation_value = specialisation_pair.first;
      //////////////////////////////////////////////////////////
      // Step 3: Duplicate |loop_|.
      //////////////////////////////////////////////////////////
      LoopUtils::LoopCloningResult clone_result;

      Loop* cloned_loop =
          loop_utils.CloneLoop(&clone_result, ordered_loop_blocks_);
      specialisation_pair.second = cloned_loop->GetPreHeaderBlock();

      ////////////////////////////////////
      // Step 4: Specialize the loop.   //
      ////////////////////////////////////

      {
        std::unordered_set<uint32_t> dead_blocks;
        std::unordered_set<uint32_t> unreachable_merges;
        SimplifyLoop(
            make_range(
                UptrVectorIterator<BasicBlock>(&clone_result.cloned_bb_,
                                               clone_result.cloned_bb_.begin()),
                UptrVectorIterator<BasicBlock>(&clone_result.cloned_bb_,
                                               clone_result.cloned_bb_.end())),
            cloned_loop, condition, specialisation_value, &dead_blocks);

        // We tagged dead blocks, create the loop before we invalidate any basic
        // block.
        cloned_loop =
            CleanLoopNest(cloned_loop, dead_blocks, &unreachable_merges);
        CleanUpCFG(
            UptrVectorIterator<BasicBlock>(&clone_result.cloned_bb_,
                                           clone_result.cloned_bb_.begin()),
            dead_blocks, unreachable_merges);

        ///////////////////////////////////////////////////////////
        // Step 5: Connect convergent edges to the landing pads. //
        ///////////////////////////////////////////////////////////

        for (uint32_t merge_bb_id : if_merging_blocks) {
          BasicBlock* merge = context_->cfg()->block(merge_bb_id);
          // We are in LCSSA so we only care about phi instructions.
          merge->ForEachPhiInst([is_from_original_loop, &dead_blocks,
                                 &clone_result](Instruction* phi) {
            uint32_t num_in_operands = phi->NumInOperands();
            for (uint32_t i = 0; i < num_in_operands; i += 2) {
              uint32_t pred = phi->GetSingleWordInOperand(i + 1);
              if (is_from_original_loop(pred)) {
                pred = clone_result.value_map_.at(pred);
                if (!dead_blocks.count(pred)) {
                  uint32_t incoming_value_id = phi->GetSingleWordInOperand(i);
                  // Not all the incoming value are coming from the loop.
                  ValueMapTy::iterator new_value =
                      clone_result.value_map_.find(incoming_value_id);
                  if (new_value != clone_result.value_map_.end()) {
                    incoming_value_id = new_value->second;
                  }
                  phi->AddOperand({SPV_OPERAND_TYPE_ID, {incoming_value_id}});
                  phi->AddOperand({SPV_OPERAND_TYPE_ID, {pred}});
                }
              }
            }
          });
        }
      }
      // Splice the cloned blocks into the function right after the if block.
      function_->AddBasicBlocks(clone_result.cloned_bb_.begin(),
                                clone_result.cloned_bb_.end(),
                                ++FindBasicBlockPosition(if_block));
    }

    // Same as above but specialize the existing loop
    {
      std::unordered_set<uint32_t> dead_blocks;
      std::unordered_set<uint32_t> unreachable_merges;
      SimplifyLoop(make_range(function_->begin(), function_->end()), loop_,
                   condition, original_loop_constant_value, &dead_blocks);

      for (uint32_t merge_bb_id : if_merging_blocks) {
        BasicBlock* merge = context_->cfg()->block(merge_bb_id);
        // LCSSA, so we only care about phi instructions.
        // If we the phi is reduced to a single incoming branch, do not
        // propagate it to preserve LCSSA.
        PatchPhis(merge, dead_blocks, true);
      }
      if (if_merge_block) {
        bool has_live_pred = false;
        for (uint32_t pid : cfg.preds(if_merge_block->id())) {
          if (!dead_blocks.count(pid)) {
            has_live_pred = true;
            break;
          }
        }
        if (!has_live_pred) unreachable_merges.insert(if_merge_block->id());
      }
      original_loop_target = loop_->GetPreHeaderBlock();
      // We tagged dead blocks, prune the loop descriptor from any dead loops.
      // After this call, |loop_| can be nullptr (i.e. the unswitch killed this
      // loop).
      loop_ = CleanLoopNest(loop_, dead_blocks, &unreachable_merges);

      CleanUpCFG(function_->begin(), dead_blocks, unreachable_merges);
    }

    /////////////////////////////////////
    // Finally: connect the new loops. //
    /////////////////////////////////////

    // Delete the old jump
    context_->KillInst(&*if_block->tail());
    InstructionBuilder builder(context_, if_block);
    if (iv_opcode == SpvOpBranchConditional) {
      assert(constant_branch.size() == 1);
      builder.AddConditionalBranch(
          condition->result_id(), original_loop_target->id(),
          constant_branch[0].second->id(),
          if_merge_block ? if_merge_block->id() : kInvalidId);
    } else {
      // Rebuild the switch: default goes to the original loop, each literal
      // goes to its specialized clone.
      std::vector<std::pair<Operand::OperandData, uint32_t>> targets;
      for (auto& t : constant_branch) {
        targets.emplace_back(t.first->GetInOperand(0).words, t.second->id());
      }

      builder.AddSwitch(condition->result_id(), original_loop_target->id(),
                        targets,
                        if_merge_block ? if_merge_block->id() : kInvalidId);
    }

    switch_block_ = nullptr;
    ordered_loop_blocks_.clear();

    context_->InvalidateAnalysesExceptFor(
        IRContext::Analysis::kAnalysisLoopAnalysis);
  }
443
444 // Returns true if the unswitch killed the original |loop_|.
WasLoopKilled() const445 bool WasLoopKilled() const { return loop_ == nullptr; }
446
 private:
  // Maps original ids to their clone's ids (see LoopUtils::LoopCloningResult).
  using ValueMapTy = std::unordered_map<uint32_t, uint32_t>;
  using BlockMapTy = std::unordered_map<uint32_t, BasicBlock*>;

  // Function containing |loop_|.
  Function* function_;
  // The loop being unswitched; nulled out if the unswitch kills it.
  Loop* loop_;
  // Loop descriptor of |function_|, kept in sync with the transformation.
  LoopDescriptor& loop_desc_;
  IRContext* context_;

  // Block holding the loop-invariant conditional branch or switch found by
  // CanUnswitchLoop(); nullptr until one is found.
  BasicBlock* switch_block_;
  // Map between instructions and if they are dynamically uniform.
  std::unordered_map<uint32_t, bool> dynamically_uniform_;
  // The loop basic blocks in structured order.
  std::vector<BasicBlock*> ordered_loop_blocks_;
461
462 // Returns the next usable id for the context.
TakeNextId()463 uint32_t TakeNextId() {
464 // TODO(1841): Handle id overflow.
465 return context_->TakeNextId();
466 }
467
  // Patches |bb|'s phi instruction by removing incoming value from unexisting
  // or tagged as dead branches.
  // If |preserve_phi| is false, a phi left with a single incoming edge is
  // replaced by its remaining value and killed.
  void PatchPhis(BasicBlock* bb,
                 const std::unordered_set<uint32_t>& dead_blocks,
                 bool preserve_phi) {
    CFG& cfg = *context_->cfg();

    std::vector<Instruction*> phi_to_kill;
    const std::vector<uint32_t>& bb_preds = cfg.preds(bb->id());
    // An incoming edge is dead if its source is tagged dead or is no longer a
    // CFG predecessor of |bb|.
    auto is_branch_dead = [&bb_preds, &dead_blocks](uint32_t id) {
      return dead_blocks.count(id) ||
             std::find(bb_preds.begin(), bb_preds.end(), id) == bb_preds.end();
    };
    bb->ForEachPhiInst(
        [&phi_to_kill, &is_branch_dead, preserve_phi, this](Instruction* insn) {
          // Walk (value, pred) operand pairs, dropping dead ones in place.
          uint32_t i = 0;
          while (i < insn->NumInOperands()) {
            uint32_t incoming_id = insn->GetSingleWordInOperand(i + 1);
            if (is_branch_dead(incoming_id)) {
              // Remove the incoming block id operand.
              insn->RemoveInOperand(i + 1);
              // Remove the definition id operand.
              insn->RemoveInOperand(i);
              continue;
            }
            i += 2;
          }
          // If there is only 1 remaining edge, propagate the value and
          // kill the instruction.
          if (insn->NumInOperands() == 2 && !preserve_phi) {
            phi_to_kill.push_back(insn);
            context_->ReplaceAllUsesWith(insn->result_id(),
                                         insn->GetSingleWordInOperand(0));
          }
        });
    // Kill after the traversal so ForEachPhiInst is not invalidated.
    for (Instruction* insn : phi_to_kill) {
      context_->KillInst(insn);
    }
  }
507
  // Removes any block that is tagged as dead, if the block is in
  // |unreachable_merges| then all block's instructions are replaced by a
  // OpUnreachable. Iteration starts at |bb_it| and runs to the end of its
  // container.
  void CleanUpCFG(UptrVectorIterator<BasicBlock> bb_it,
                  const std::unordered_set<uint32_t>& dead_blocks,
                  const std::unordered_set<uint32_t>& unreachable_merges) {
    CFG& cfg = *context_->cfg();

    while (bb_it != bb_it.End()) {
      BasicBlock& bb = *bb_it;

      if (unreachable_merges.count(bb.id())) {
        // Required merge blocks are kept but reduced to label +
        // OpUnreachable (skip if already in that shape).
        if (bb.begin() != bb.tail() ||
            bb.terminator()->opcode() != SpvOpUnreachable) {
          // Make unreachable, but leave the label.
          bb.KillAllInsts(false);
          InstructionBuilder(context_, &bb).AddUnreachable();
          cfg.RemoveNonExistingEdges(bb.id());
        }
        ++bb_it;
      } else if (dead_blocks.count(bb.id())) {
        cfg.ForgetBlock(&bb);
        // Kill this block.
        bb.KillAllInsts(true);
        bb_it = bb_it.Erase();
      } else {
        // Live block: just drop edges to blocks that no longer exist.
        cfg.RemoveNonExistingEdges(bb.id());
        ++bb_it;
      }
    }
  }
539
540 // Return true if |c_inst| is a Boolean constant and set |cond_val| with the
541 // value that |c_inst|
GetConstCondition(const Instruction * c_inst,bool * cond_val)542 bool GetConstCondition(const Instruction* c_inst, bool* cond_val) {
543 bool cond_is_const;
544 switch (c_inst->opcode()) {
545 case SpvOpConstantFalse: {
546 *cond_val = false;
547 cond_is_const = true;
548 } break;
549 case SpvOpConstantTrue: {
550 *cond_val = true;
551 cond_is_const = true;
552 } break;
553 default: { cond_is_const = false; } break;
554 }
555 return cond_is_const;
556 }
557
  // Simplifies |loop| assuming the instruction |to_version_insn| takes the
  // value |cst_value|. |block_range| is an iterator range returning the loop
  // basic blocks in a structured order (dominator first).
  // The function will ignore basic blocks returned by |block_range| if they
  // does not belong to the loop.
  // The set |dead_blocks| will contain all the dead basic blocks.
  //
  // Requirements:
  //   - |loop| must be in the LCSSA form;
  //   - |cst_value| must be constant or null (to represent the default target
  //     of an OpSwitch).
  void SimplifyLoop(IteratorRange<UptrVectorIterator<BasicBlock>> block_range,
                    Loop* loop, Instruction* to_version_insn,
                    Instruction* cst_value,
                    std::unordered_set<uint32_t>* dead_blocks) {
    CFG& cfg = *context_->cfg();
    analysis::DefUseManager* def_use_mgr = context_->get_def_use_mgr();

    std::function<bool(uint32_t)> ignore_node;
    ignore_node = [loop](uint32_t bb_id) { return !loop->IsInsideLoop(bb_id); };

    // Collect in-loop uses first; rewriting while iterating the def-use
    // manager would invalidate the traversal.
    std::vector<std::pair<Instruction*, uint32_t>> use_list;
    def_use_mgr->ForEachUse(to_version_insn,
                            [&use_list, &ignore_node, this](
                                Instruction* inst, uint32_t operand_index) {
                              BasicBlock* bb = context_->get_instr_block(inst);

                              if (!bb || ignore_node(bb->id())) {
                                // Out of the loop, the specialization does not
                                // apply any more.
                                return;
                              }
                              use_list.emplace_back(inst, operand_index);
                            });

    // First pass: inject the specialized value into the loop (and only the
    // loop).
    for (auto use : use_list) {
      Instruction* inst = use.first;
      uint32_t operand_index = use.second;
      BasicBlock* bb = context_->get_instr_block(inst);

      // If it is not a branch, simply inject the value.
      if (!inst->IsBranch()) {
        // To also handle switch, cst_value can be nullptr: this case
        // means that we are looking to branch to the default target of
        // the switch. We don't actually know its value so we don't touch
        // it if it not a switch.
        if (cst_value) {
          inst->SetOperand(operand_index, {cst_value->result_id()});
          def_use_mgr->AnalyzeInstUse(inst);
        }
      }

      // The user is a branch, kill dead branches.
      // (Non-branch users fall through to the switch's default case, which
      // does nothing.)
      uint32_t live_target = 0;
      std::unordered_set<uint32_t> dead_branches;
      switch (inst->opcode()) {
        case SpvOpBranchConditional: {
          assert(cst_value && "No constant value to specialize !");
          bool branch_cond = false;
          if (GetConstCondition(cst_value, &branch_cond)) {
            uint32_t true_label =
                inst->GetSingleWordInOperand(kBranchCondTrueLabIdInIdx);
            uint32_t false_label =
                inst->GetSingleWordInOperand(kBranchCondFalseLabIdInIdx);
            live_target = branch_cond ? true_label : false_label;
            uint32_t dead_target = !branch_cond ? true_label : false_label;
            cfg.RemoveEdge(bb->id(), dead_target);
          }
          break;
        }
        case SpvOpSwitch: {
          // Default target (first label in-operand) unless a literal matches
          // |cst_value|.
          live_target = inst->GetSingleWordInOperand(1);
          if (cst_value) {
            if (!cst_value->IsConstant()) break;
            const Operand& cst = cst_value->GetInOperand(0);
            for (uint32_t i = 2; i < inst->NumInOperands(); i += 2) {
              const Operand& literal = inst->GetInOperand(i);
              if (literal == cst) {
                live_target = inst->GetSingleWordInOperand(i + 1);
                break;
              }
            }
          }
          // Drop every CFG edge except the live one.
          for (uint32_t i = 1; i < inst->NumInOperands(); i += 2) {
            uint32_t id = inst->GetSingleWordInOperand(i);
            if (id != live_target) {
              cfg.RemoveEdge(bb->id(), id);
            }
          }
        }
        default:
          break;
      }
      if (live_target != 0) {
        // Check for the presence of the merge block.
        if (Instruction* merge = bb->GetMergeInst()) context_->KillInst(merge);
        context_->KillInst(&*bb->tail());
        InstructionBuilder builder(context_, bb,
                                   IRContext::kAnalysisDefUse |
                                       IRContext::kAnalysisInstrToBlockMapping);
        builder.AddBranch(live_target);
      }
    }

    // Go through the loop basic block and tag all blocks that are obviously
    // dead.
    std::unordered_set<uint32_t> visited;
    for (BasicBlock& bb : block_range) {
      if (ignore_node(bb.id())) continue;
      visited.insert(bb.id());

      // Check if this block is dead, if so tag it as dead otherwise patch phi
      // instructions.
      bool has_live_pred = false;
      for (uint32_t pid : cfg.preds(bb.id())) {
        if (!dead_blocks->count(pid)) {
          has_live_pred = true;
          break;
        }
      }
      if (!has_live_pred) {
        dead_blocks->insert(bb.id());
        const BasicBlock& cbb = bb;
        // Patch the phis for any back-edge.
        cbb.ForEachSuccessorLabel(
            [dead_blocks, &visited, &cfg, this](uint32_t id) {
              if (!visited.count(id) || dead_blocks->count(id)) return;
              BasicBlock* succ = cfg.block(id);
              PatchPhis(succ, *dead_blocks, false);
            });
        continue;
      }
      // Update the phi instructions, some incoming branch have/will disappear.
      PatchPhis(&bb, *dead_blocks, /* preserve_phi = */ false);
    }
  }
696
697 // Returns true if the header is not reachable or tagged as dead or if we
698 // never loop back.
IsLoopDead(BasicBlock * header,BasicBlock * latch,const std::unordered_set<uint32_t> & dead_blocks)699 bool IsLoopDead(BasicBlock* header, BasicBlock* latch,
700 const std::unordered_set<uint32_t>& dead_blocks) {
701 if (!header || dead_blocks.count(header->id())) return true;
702 if (!latch || dead_blocks.count(latch->id())) return true;
703 for (uint32_t pid : context_->cfg()->preds(header->id())) {
704 if (!dead_blocks.count(pid)) {
705 // Seems reachable.
706 return false;
707 }
708 }
709 return true;
710 }
711
  // Cleans the loop nest under |loop| and reflect changes to the loop
  // descriptor. This will kill all descriptors that represent dead loops.
  // If |loop_| is killed, it will be set to nullptr.
  // Any merge blocks that become unreachable will be added to
  // |unreachable_merges|.
  // The function returns the pointer to |loop| or nullptr if the loop was
  // killed.
  Loop* CleanLoopNest(Loop* loop,
                      const std::unordered_set<uint32_t>& dead_blocks,
                      std::unordered_set<uint32_t>* unreachable_merges) {
    // This represent the pair of dead loop and nearest alive parent (nullptr if
    // no parent).
    std::unordered_map<Loop*, Loop*> dead_loops;
    auto get_parent = [&dead_loops](Loop* l) -> Loop* {
      std::unordered_map<Loop*, Loop*>::iterator it = dead_loops.find(l);
      if (it != dead_loops.end()) return it->second;
      return nullptr;
    };

    bool is_main_loop_dead =
        IsLoopDead(loop->GetHeaderBlock(), loop->GetLatchBlock(), dead_blocks);
    if (is_main_loop_dead) {
      // Kill the OpLoopMerge of a dead loop's header.
      if (Instruction* merge = loop->GetHeaderBlock()->GetLoopMergeInst()) {
        context_->KillInst(merge);
      }
      dead_loops[loop] = loop->GetParent();
    } else {
      // Self-mapping sentinel: erased again below before the removal pass.
      dead_loops[loop] = loop;
    }

    // For each loop, check if we killed it. If we did, find a suitable parent
    // for its children.
    for (Loop& sub_loop :
         make_range(++TreeDFIterator<Loop>(loop), TreeDFIterator<Loop>())) {
      if (IsLoopDead(sub_loop.GetHeaderBlock(), sub_loop.GetLatchBlock(),
                     dead_blocks)) {
        if (Instruction* merge =
                sub_loop.GetHeaderBlock()->GetLoopMergeInst()) {
          context_->KillInst(merge);
        }
        dead_loops[&sub_loop] = get_parent(&sub_loop);
      } else {
        // The loop is alive, check if its merge block is dead, if it is, tag it
        // as required.
        if (sub_loop.GetMergeBlock()) {
          uint32_t merge_id = sub_loop.GetMergeBlock()->id();
          if (dead_blocks.count(merge_id)) {
            unreachable_merges->insert(sub_loop.GetMergeBlock()->id());
          }
        }
      }
    }
    if (!is_main_loop_dead) dead_loops.erase(loop);

    // Remove dead blocks from live loops.
    for (uint32_t bb_id : dead_blocks) {
      Loop* l = loop_desc_[bb_id];
      if (l) {
        l->RemoveBasicBlock(bb_id);
        loop_desc_.ForgetBasicBlock(bb_id);
      }
    }

    // Drop every dead loop from the descriptor; null out |loop| if it was one
    // of them.
    std::for_each(
        dead_loops.begin(), dead_loops.end(),
        [&loop,
         this](std::unordered_map<Loop*, Loop*>::iterator::reference it) {
          if (it.first == loop) loop = nullptr;
          loop_desc_.RemoveLoop(it.first);
        });

    return loop;
  }
785
  // Returns true if |var| is dynamically uniform.
  // Note: this is currently approximated as uniform.
  bool IsDynamicallyUniform(Instruction* var, const BasicBlock* entry,
                            const DominatorTree& post_dom_tree) {
    assert(post_dom_tree.IsPostDominator());
    analysis::DefUseManager* def_use_mgr = context_->get_def_use_mgr();

    // Memoized result from a previous query?
    auto it = dynamically_uniform_.find(var->result_id());

    if (it != dynamically_uniform_.end()) return it->second;

    analysis::DecorationManager* dec_mgr = context_->get_decoration_mgr();

    // Insert the cache entry up front (defaulting to false) so recursive
    // queries on the same id terminate.
    bool& is_uniform = dynamically_uniform_[var->result_id()];
    is_uniform = false;

    // An explicit Uniform decoration settles the question.
    dec_mgr->WhileEachDecoration(var->result_id(), SpvDecorationUniform,
                                 [&is_uniform](const Instruction&) {
                                   is_uniform = true;
                                   return false;
                                 });
    if (is_uniform) {
      return is_uniform;
    }

    BasicBlock* parent = context_->get_instr_block(var);
    if (!parent) {
      // Not in any block — presumably a module-scope instruction such as a
      // constant; treated as uniform.
      return is_uniform = true;
    }

    // The defining block must post-dominate the entry, i.e. be executed by
    // every invocation.
    if (!post_dom_tree.Dominates(parent->id(), entry->id())) {
      return is_uniform = false;
    }
    if (var->opcode() == SpvOpLoad) {
      // A load is only uniform when it reads Uniform/UniformConstant storage.
      const uint32_t PtrTypeId =
          def_use_mgr->GetDef(var->GetSingleWordInOperand(0))->type_id();
      const Instruction* PtrTypeInst = def_use_mgr->GetDef(PtrTypeId);
      uint32_t storage_class =
          PtrTypeInst->GetSingleWordInOperand(kTypePointerStorageClassInIdx);
      if (storage_class != SpvStorageClassUniform &&
          storage_class != SpvStorageClassUniformConstant) {
        return is_uniform = false;
      }
    } else {
      if (!context_->IsCombinatorInstruction(var)) {
        return is_uniform = false;
      }
    }

    // Uniform only if every input operand is dynamically uniform too.
    return is_uniform = var->WhileEachInId([entry, &post_dom_tree,
                                            this](const uint32_t* id) {
      return IsDynamicallyUniform(context_->get_def_use_mgr()->GetDef(*id),
                                  entry, post_dom_tree);
    });
  }
841
842 // Returns true if |insn| is constant and dynamically uniform within the loop.
IsConditionLoopInvariant(Instruction * insn)843 bool IsConditionLoopInvariant(Instruction* insn) {
844 assert(insn->IsBranch());
845 assert(insn->opcode() != SpvOpBranch);
846 analysis::DefUseManager* def_use_mgr = context_->get_def_use_mgr();
847
848 Instruction* condition = def_use_mgr->GetDef(insn->GetOperand(0).words[0]);
849 return !loop_->IsInsideLoop(condition) &&
850 IsDynamicallyUniform(
851 condition, function_->entry().get(),
852 context_->GetPostDominatorAnalysis(function_)->GetDomTree());
853 }
854 };
855
856 } // namespace
857
Process()858 Pass::Status LoopUnswitchPass::Process() {
859 bool modified = false;
860 Module* module = context()->module();
861
862 // Process each function in the module
863 for (Function& f : *module) {
864 modified |= ProcessFunction(&f);
865 }
866
867 return modified ? Status::SuccessWithChange : Status::SuccessWithoutChange;
868 }
869
// Unswitches every candidate loop of |f| until a fixed point is reached.
// Returns true if the function was modified.
bool LoopUnswitchPass::ProcessFunction(Function* f) {
  bool modified = false;
  // Loops already handled; prevents reprocessing after a restart below.
  std::unordered_set<Loop*> processed_loop;

  LoopDescriptor& loop_descriptor = *context()->GetLoopDescriptor(f);

  bool loop_changed = true;
  while (loop_changed) {
    loop_changed = false;
    // Walk the loop nest depth-first, skipping the dummy root loop.
    for (Loop& loop :
         make_range(++TreeDFIterator<Loop>(loop_descriptor.GetDummyRootLoop()),
                    TreeDFIterator<Loop>())) {
      if (processed_loop.count(&loop)) continue;
      processed_loop.insert(&loop);

      LoopUnswitch unswitcher(context(), f, &loop, &loop_descriptor);
      // Keep unswitching this loop while a candidate branch remains and the
      // loop itself survives the transformation.
      while (!unswitcher.WasLoopKilled() && unswitcher.CanUnswitchLoop()) {
        if (!loop.IsLCSSA()) {
          LoopUtils(context(), &loop).MakeLoopClosedSSA();
        }
        modified = true;
        loop_changed = true;
        unswitcher.PerformUnswitch();
      }
      // The unswitch changed the loop nest: restart the traversal.
      if (loop_changed) break;
    }
  }

  return modified;
}
900
901 } // namespace opt
902 } // namespace spvtools
903