1 // Copyright (c) 2017 The Khronos Group Inc.
2 // Copyright (c) 2017 Valve Corporation
3 // Copyright (c) 2017 LunarG Inc.
4 // Copyright (c) 2018 Google Inc.
5 //
6 // Licensed under the Apache License, Version 2.0 (the "License");
7 // you may not use this file except in compliance with the License.
8 // You may obtain a copy of the License at
9 //
10 //     http://www.apache.org/licenses/LICENSE-2.0
11 //
12 // Unless required by applicable law or agreed to in writing, software
13 // distributed under the License is distributed on an "AS IS" BASIS,
14 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 // See the License for the specific language governing permissions and
16 // limitations under the License.
17 
18 #include "source/opt/dead_branch_elim_pass.h"
19 
20 #include <list>
21 #include <memory>
22 #include <vector>
23 
24 #include "source/cfa.h"
25 #include "source/opt/ir_context.h"
26 #include "source/opt/iterator.h"
27 #include "source/opt/struct_cfg_analysis.h"
28 #include "source/util/make_unique.h"
29 
30 namespace spvtools {
31 namespace opt {
32 
namespace {

// Indices, within the in-operand list of an OpBranchConditional
// instruction, of the true-branch and false-branch target label ids
// (in-operand 0 is the condition).
const uint32_t kBranchCondTrueLabIdInIdx = 1;
const uint32_t kBranchCondFalseLabIdInIdx = 2;

}  // anonymous namespace
39 
GetConstCondition(uint32_t condId,bool * condVal)40 bool DeadBranchElimPass::GetConstCondition(uint32_t condId, bool* condVal) {
41   bool condIsConst;
42   Instruction* cInst = get_def_use_mgr()->GetDef(condId);
43   switch (cInst->opcode()) {
44     case SpvOpConstantFalse: {
45       *condVal = false;
46       condIsConst = true;
47     } break;
48     case SpvOpConstantTrue: {
49       *condVal = true;
50       condIsConst = true;
51     } break;
52     case SpvOpLogicalNot: {
53       bool negVal;
54       condIsConst =
55           GetConstCondition(cInst->GetSingleWordInOperand(0), &negVal);
56       if (condIsConst) *condVal = !negVal;
57     } break;
58     default: { condIsConst = false; } break;
59   }
60   return condIsConst;
61 }
62 
GetConstInteger(uint32_t selId,uint32_t * selVal)63 bool DeadBranchElimPass::GetConstInteger(uint32_t selId, uint32_t* selVal) {
64   Instruction* sInst = get_def_use_mgr()->GetDef(selId);
65   uint32_t typeId = sInst->type_id();
66   Instruction* typeInst = get_def_use_mgr()->GetDef(typeId);
67   if (!typeInst || (typeInst->opcode() != SpvOpTypeInt)) return false;
68   // TODO(greg-lunarg): Support non-32 bit ints
69   if (typeInst->GetSingleWordInOperand(0) != 32) return false;
70   if (sInst->opcode() == SpvOpConstant) {
71     *selVal = sInst->GetSingleWordInOperand(0);
72     return true;
73   } else if (sInst->opcode() == SpvOpConstantNull) {
74     *selVal = 0;
75     return true;
76   }
77   return false;
78 }
79 
AddBranch(uint32_t labelId,BasicBlock * bp)80 void DeadBranchElimPass::AddBranch(uint32_t labelId, BasicBlock* bp) {
81   assert(get_def_use_mgr()->GetDef(labelId) != nullptr);
82   std::unique_ptr<Instruction> newBranch(
83       new Instruction(context(), SpvOpBranch, 0, 0,
84                       {{spv_operand_type_t::SPV_OPERAND_TYPE_ID, {labelId}}}));
85   context()->AnalyzeDefUse(&*newBranch);
86   context()->set_instr_block(&*newBranch, bp);
87   bp->AddInstruction(std::move(newBranch));
88 }
89 
GetParentBlock(uint32_t id)90 BasicBlock* DeadBranchElimPass::GetParentBlock(uint32_t id) {
91   return context()->get_instr_block(get_def_use_mgr()->GetDef(id));
92 }
93 
// Traverses |func|'s CFG from the entry block, following only the
// statically-taken successor of any terminator whose condition folds to a
// constant.  Every block reached is inserted into |live_blocks|, and each
// foldable OpBranchConditional/OpSwitch is replaced in place with an
// unconditional OpBranch.  Returns true if the function was modified.
bool DeadBranchElimPass::MarkLiveBlocks(
    Function* func, std::unordered_set<BasicBlock*>* live_blocks) {
  StructuredCFGAnalysis* cfgAnalysis = context()->GetStructuredCFGAnalysis();

  // Continue targets of loops seen so far; their terminators are never
  // simplified (see comment before |simplify| below).
  std::unordered_set<BasicBlock*> continues;
  std::vector<BasicBlock*> stack;
  stack.push_back(&*func->begin());
  bool modified = false;
  while (!stack.empty()) {
    BasicBlock* block = stack.back();
    stack.pop_back();

    // Live blocks doubles as visited set.
    if (!live_blocks->insert(block).second) continue;

    uint32_t cont_id = block->ContinueBlockIdIfAny();
    if (cont_id != 0) continues.insert(GetParentBlock(cont_id));

    Instruction* terminator = block->terminator();
    uint32_t live_lab_id = 0;
    // Check if the terminator has a single valid successor.
    if (terminator->opcode() == SpvOpBranchConditional) {
      bool condVal;
      if (GetConstCondition(terminator->GetSingleWordInOperand(0u), &condVal)) {
        live_lab_id = terminator->GetSingleWordInOperand(
            condVal ? kBranchCondTrueLabIdInIdx : kBranchCondFalseLabIdInIdx);
      }
    } else if (terminator->opcode() == SpvOpSwitch) {
      uint32_t sel_val;
      if (GetConstInteger(terminator->GetSingleWordInOperand(0u), &sel_val)) {
        // Search switch operands for selector value, set live_lab_id to
        // corresponding label, use default if not found.
        // In-operand layout: [selector, default label, (literal, label)*].
        // Only 32-bit selectors reach here (GetConstInteger), so each
        // literal occupies a single word.
        uint32_t icnt = 0;
        uint32_t case_val;
        terminator->WhileEachInOperand(
            [&icnt, &case_val, &sel_val, &live_lab_id](const uint32_t* idp) {
              if (icnt == 1) {
                // Start with default label.
                live_lab_id = *idp;
              } else if (icnt > 1) {
                if (icnt % 2 == 0) {
                  // Even in-operand indices (> 1) hold case literals; the
                  // following odd index holds the matching label.
                  case_val = *idp;
                } else {
                  if (case_val == sel_val) {
                    live_lab_id = *idp;
                    return false;
                  }
                }
              }
              ++icnt;
              return true;
            });
      }
    }

    // Don't simplify branches of continue blocks. A path from the continue to
    // the header is required.
    // TODO(alan-baker): They can be simplified iff there remains a path to the
    // backedge. Structured control flow should guarantee one path hits the
    // backedge, but I've removed the requirement for structured control flow
    // from this pass.
    bool simplify = live_lab_id != 0 && !continues.count(block);

    if (simplify) {
      modified = true;
      // Replace with unconditional branch.
      // Remove the merge instruction if it is a selection merge.
      AddBranch(live_lab_id, block);
      context()->KillInst(terminator);
      Instruction* mergeInst = block->GetMergeInst();
      if (mergeInst && mergeInst->opcode() == SpvOpSelectionMerge) {
        // If the taken path can still exit the selection construct via a
        // later conditional branch, move the merge instruction onto that
        // branch instead of deleting it (see
        // FindFirstExitFromSelectionMerge).
        Instruction* first_break = FindFirstExitFromSelectionMerge(
            live_lab_id, mergeInst->GetSingleWordInOperand(0),
            cfgAnalysis->LoopMergeBlock(live_lab_id),
            cfgAnalysis->LoopContinueBlock(live_lab_id));
        if (first_break == nullptr) {
          context()->KillInst(mergeInst);
        } else {
          mergeInst->RemoveFromList();
          first_break->InsertBefore(std::unique_ptr<Instruction>(mergeInst));
          context()->set_instr_block(mergeInst,
                                     context()->get_instr_block(first_break));
        }
      }
      stack.push_back(GetParentBlock(live_lab_id));
    } else {
      // All successors are live.
      const auto* const_block = block;
      const_block->ForEachSuccessorLabel([&stack, this](const uint32_t label) {
        stack.push_back(GetParentBlock(label));
      });
    }
  }

  return modified;
}
190 
MarkUnreachableStructuredTargets(const std::unordered_set<BasicBlock * > & live_blocks,std::unordered_set<BasicBlock * > * unreachable_merges,std::unordered_map<BasicBlock *,BasicBlock * > * unreachable_continues)191 void DeadBranchElimPass::MarkUnreachableStructuredTargets(
192     const std::unordered_set<BasicBlock*>& live_blocks,
193     std::unordered_set<BasicBlock*>* unreachable_merges,
194     std::unordered_map<BasicBlock*, BasicBlock*>* unreachable_continues) {
195   for (auto block : live_blocks) {
196     if (auto merge_id = block->MergeBlockIdIfAny()) {
197       BasicBlock* merge_block = GetParentBlock(merge_id);
198       if (!live_blocks.count(merge_block)) {
199         unreachable_merges->insert(merge_block);
200       }
201       if (auto cont_id = block->ContinueBlockIdIfAny()) {
202         BasicBlock* cont_block = GetParentBlock(cont_id);
203         if (!live_blocks.count(cont_block)) {
204           (*unreachable_continues)[cont_block] = block;
205         }
206       }
207     }
208   }
209 }
210 
// Rewrites the OpPhi instructions of live blocks so that they reference only
// incoming edges that still exist after dead-branch elimination.  The edge
// from an unreachable continue block to its loop header is retained (with an
// OpUndef incoming value) when required to keep the structured backedge.  A
// phi reduced to a single incoming value is replaced by that value and
// deleted.  Returns true if any phi was changed.
bool DeadBranchElimPass::FixPhiNodesInLiveBlocks(
    Function* func, const std::unordered_set<BasicBlock*>& live_blocks,
    const std::unordered_map<BasicBlock*, BasicBlock*>& unreachable_continues) {
  bool modified = false;
  for (auto& block : *func) {
    if (live_blocks.count(&block)) {
      for (auto iter = block.begin(); iter != block.end();) {
        // Phis are clustered at the start of a block; stop at the first
        // non-phi instruction.
        if (iter->opcode() != SpvOpPhi) {
          break;
        }

        bool changed = false;
        bool backedge_added = false;
        Instruction* inst = &*iter;
        std::vector<Operand> operands;
        // Build a complete set of operands (not just input operands). Start
        // with type and result id operands.
        operands.push_back(inst->GetOperand(0u));
        operands.push_back(inst->GetOperand(1u));
        // Iterate through the incoming labels and determine which to keep
        // and/or modify.  If there is an unreachable continue block, there
        // will be an edge from that block to the header.  We need to keep it
        // to maintain the structured control flow.  If the header has more
        // than 2 incoming edges, then the OpPhi must have an entry for that
        // edge.  However, if there is only one other incoming edge, the OpPhi
        // can be eliminated.
        for (uint32_t i = 1; i < inst->NumInOperands(); i += 2) {
          // In-operand i is an incoming label; i-1 is its incoming value.
          BasicBlock* inc = GetParentBlock(inst->GetSingleWordInOperand(i));
          auto cont_iter = unreachable_continues.find(inc);
          if (cont_iter != unreachable_continues.end() &&
              cont_iter->second == &block && inst->NumInOperands() > 4) {
            if (get_def_use_mgr()
                    ->GetDef(inst->GetSingleWordInOperand(i - 1))
                    ->opcode() == SpvOpUndef) {
              // Already undef incoming value, no change necessary.
              operands.push_back(inst->GetInOperand(i - 1));
              operands.push_back(inst->GetInOperand(i));
              backedge_added = true;
            } else {
              // Replace incoming value with undef if this phi exists in the
              // loop header. Otherwise, this edge is not live since the
              // unreachable continue block will be replaced with an
              // unconditional branch to the header only.
              operands.emplace_back(
                  SPV_OPERAND_TYPE_ID,
                  std::initializer_list<uint32_t>{Type2Undef(inst->type_id())});
              operands.push_back(inst->GetInOperand(i));
              changed = true;
              backedge_added = true;
            }
          } else if (live_blocks.count(inc) && inc->IsSuccessor(&block)) {
            // Keep live incoming edge.
            operands.push_back(inst->GetInOperand(i - 1));
            operands.push_back(inst->GetInOperand(i));
          } else {
            // Remove incoming edge.
            changed = true;
          }
        }

        if (changed) {
          modified = true;
          uint32_t continue_id = block.ContinueBlockIdIfAny();
          if (!backedge_added && continue_id != 0 &&
              unreachable_continues.count(GetParentBlock(continue_id)) &&
              operands.size() > 4) {
            // Changed the backedge to branch from the continue block instead
            // of a successor of the continue block. Add an entry to the phi to
            // provide an undef for the continue block. Since the successor of
            // the continue must also be unreachable (dominated by the continue
            // block), any entry for the original backedge has been removed
            // from the phi operands.
            operands.emplace_back(
                SPV_OPERAND_TYPE_ID,
                std::initializer_list<uint32_t>{Type2Undef(inst->type_id())});
            operands.emplace_back(SPV_OPERAND_TYPE_ID,
                                  std::initializer_list<uint32_t>{continue_id});
          }

          // Either replace the phi with a single value or rebuild the phi out
          // of |operands|.
          //
          // We always have type and result id operands. So this phi has a
          // single source if there are two more operands beyond those.
          if (operands.size() == 4) {
            // First input data operands is at index 2.
            uint32_t replId = operands[2u].words[0];
            context()->ReplaceAllUsesWith(inst->result_id(), replId);
            iter = context()->KillInst(&*inst);
          } else {
            // We've rewritten the operands, so first instruct the def/use
            // manager to forget uses in the phi before we replace them. After
            // replacing operands update the def/use manager by re-analyzing
            // the used ids in this phi.
            get_def_use_mgr()->EraseUseRecordsOfOperandIds(inst);
            inst->ReplaceOperands(operands);
            get_def_use_mgr()->AnalyzeInstUse(inst);
            ++iter;
          }
        } else {
          ++iter;
        }
      }
    }
  }

  return modified;
}
319 
// Removes every block of |func| that is not in |live_blocks|, with two
// exceptions required by structured control flow: blocks in
// |unreachable_merges| are kept as a bare label terminated by OpUnreachable,
// and blocks in |unreachable_continues| are kept as a bare label terminated
// by a branch back to their loop header.  Returns true if the function was
// modified.
bool DeadBranchElimPass::EraseDeadBlocks(
    Function* func, const std::unordered_set<BasicBlock*>& live_blocks,
    const std::unordered_set<BasicBlock*>& unreachable_merges,
    const std::unordered_map<BasicBlock*, BasicBlock*>& unreachable_continues) {
  bool modified = false;
  for (auto ebi = func->begin(); ebi != func->end();) {
    if (unreachable_merges.count(&*ebi)) {
      // Skip the rewrite if the block is already exactly a label followed
      // by OpUnreachable.
      if (ebi->begin() != ebi->tail() ||
          ebi->terminator()->opcode() != SpvOpUnreachable) {
        // Make unreachable, but leave the label.
        KillAllInsts(&*ebi, false);
        // Add unreachable terminator.
        ebi->AddInstruction(
            MakeUnique<Instruction>(context(), SpvOpUnreachable, 0, 0,
                                    std::initializer_list<Operand>{}));
        context()->AnalyzeUses(ebi->terminator());
        context()->set_instr_block(ebi->terminator(), &*ebi);
        modified = true;
      }
      ++ebi;
    } else if (unreachable_continues.count(&*ebi)) {
      // |cont_id| is the id of the loop header this continue block maps to.
      uint32_t cont_id = unreachable_continues.find(&*ebi)->second->id();
      // Skip the rewrite if the block is already a label plus a branch to
      // the header.
      if (ebi->begin() != ebi->tail() ||
          ebi->terminator()->opcode() != SpvOpBranch ||
          ebi->terminator()->GetSingleWordInOperand(0u) != cont_id) {
        // Make unreachable, but leave the label.
        KillAllInsts(&*ebi, false);
        // Add unconditional branch to header.
        assert(unreachable_continues.count(&*ebi));
        ebi->AddInstruction(MakeUnique<Instruction>(
            context(), SpvOpBranch, 0, 0,
            std::initializer_list<Operand>{{SPV_OPERAND_TYPE_ID, {cont_id}}}));
        get_def_use_mgr()->AnalyzeInstUse(&*ebi->tail());
        context()->set_instr_block(&*ebi->tail(), &*ebi);
        modified = true;
      }
      ++ebi;
    } else if (!live_blocks.count(&*ebi)) {
      // Kill this block.
      KillAllInsts(&*ebi);
      ebi = ebi.Erase();
      modified = true;
    } else {
      ++ebi;
    }
  }

  return modified;
}
369 
EliminateDeadBranches(Function * func)370 bool DeadBranchElimPass::EliminateDeadBranches(Function* func) {
371   bool modified = false;
372   std::unordered_set<BasicBlock*> live_blocks;
373   modified |= MarkLiveBlocks(func, &live_blocks);
374 
375   std::unordered_set<BasicBlock*> unreachable_merges;
376   std::unordered_map<BasicBlock*, BasicBlock*> unreachable_continues;
377   MarkUnreachableStructuredTargets(live_blocks, &unreachable_merges,
378                                    &unreachable_continues);
379   modified |= FixPhiNodesInLiveBlocks(func, live_blocks, unreachable_continues);
380   modified |= EraseDeadBlocks(func, live_blocks, unreachable_merges,
381                               unreachable_continues);
382 
383   return modified;
384 }
385 
FixBlockOrder()386 void DeadBranchElimPass::FixBlockOrder() {
387   context()->BuildInvalidAnalyses(IRContext::kAnalysisCFG |
388                                   IRContext::kAnalysisDominatorAnalysis);
389   // Reorders blocks according to DFS of dominator tree.
390   ProcessFunction reorder_dominators = [this](Function* function) {
391     DominatorAnalysis* dominators = context()->GetDominatorAnalysis(function);
392     std::vector<BasicBlock*> blocks;
393     for (auto iter = dominators->GetDomTree().begin();
394          iter != dominators->GetDomTree().end(); ++iter) {
395       if (iter->id() != 0) {
396         blocks.push_back(iter->bb_);
397       }
398     }
399     for (uint32_t i = 1; i < blocks.size(); ++i) {
400       function->MoveBasicBlockToAfter(blocks[i]->id(), blocks[i - 1]);
401     }
402     return true;
403   };
404 
405   // Reorders blocks according to structured order.
406   ProcessFunction reorder_structured = [this](Function* function) {
407     std::list<BasicBlock*> order;
408     context()->cfg()->ComputeStructuredOrder(function, &*function->begin(),
409                                              &order);
410     std::vector<BasicBlock*> blocks;
411     for (auto block : order) {
412       blocks.push_back(block);
413     }
414     for (uint32_t i = 1; i < blocks.size(); ++i) {
415       function->MoveBasicBlockToAfter(blocks[i]->id(), blocks[i - 1]);
416     }
417     return true;
418   };
419 
420   // Structured order is more intuitive so use it where possible.
421   if (context()->get_feature_mgr()->HasCapability(SpvCapabilityShader)) {
422     context()->ProcessReachableCallTree(reorder_structured);
423   } else {
424     context()->ProcessReachableCallTree(reorder_dominators);
425   }
426 }
427 
Process()428 Pass::Status DeadBranchElimPass::Process() {
429   // Do not process if module contains OpGroupDecorate. Additional
430   // support required in KillNamesAndDecorates().
431   // TODO(greg-lunarg): Add support for OpGroupDecorate
432   for (auto& ai : get_module()->annotations())
433     if (ai.opcode() == SpvOpGroupDecorate) return Status::SuccessWithoutChange;
434   // Process all entry point functions
435   ProcessFunction pfn = [this](Function* fp) {
436     return EliminateDeadBranches(fp);
437   };
438   bool modified = context()->ProcessReachableCallTree(pfn);
439   if (modified) FixBlockOrder();
440   return modified ? Status::SuccessWithChange : Status::SuccessWithoutChange;
441 }
442 
// Starting at |start_block_id|, walks the control flow inside the selection
// construct whose merge block is |merge_block_id|.  |loop_merge_id| and
// |loop_continue_id| are the merge/continue targets of the innermost
// enclosing loop (0 if none), which also bound the walk.  Returns the first
// branch instruction that can conditionally exit the construct ("break"),
// or nullptr if the walk reaches a construct boundary without finding one.
Instruction* DeadBranchElimPass::FindFirstExitFromSelectionMerge(
    uint32_t start_block_id, uint32_t merge_block_id, uint32_t loop_merge_id,
    uint32_t loop_continue_id) {
  // To find the "first" exit, we follow branches looking for a conditional
  // branch that is not in a nested construct and is not the header of a new
  // construct.  We follow the control flow from |start_block_id| to find the
  // first one.
  while (start_block_id != merge_block_id && start_block_id != loop_merge_id &&
         start_block_id != loop_continue_id) {
    BasicBlock* start_block = context()->get_instr_block(start_block_id);
    Instruction* branch = start_block->terminator();
    uint32_t next_block_id = 0;
    switch (branch->opcode()) {
      case SpvOpBranchConditional:
        // If this block is itself a header, skip over its whole construct
        // by jumping to its merge block.
        next_block_id = start_block->MergeBlockIdIfAny();
        if (next_block_id == 0) {
          // If a possible target is the |loop_merge_id| or |loop_continue_id|,
          // which are not the current merge node, then we continue the search
          // with the other target.
          for (uint32_t i = 1; i < 3; i++) {
            if (branch->GetSingleWordInOperand(i) == loop_merge_id &&
                loop_merge_id != merge_block_id) {
              next_block_id = branch->GetSingleWordInOperand(3 - i);
              break;
            }
            if (branch->GetSingleWordInOperand(i) == loop_continue_id &&
                loop_continue_id != merge_block_id) {
              next_block_id = branch->GetSingleWordInOperand(3 - i);
              break;
            }
          }

          if (next_block_id == 0) {
            // Neither target leaves via the loop boundary: this is the
            // conditional exit we are looking for.
            return branch;
          }
        }
        break;
      case SpvOpSwitch:
        next_block_id = start_block->MergeBlockIdIfAny();
        if (next_block_id == 0) {
          // A switch with no merge instructions can have at most 4 targets:
          //   a. |merge_block_id|
          //   b. |loop_merge_id|
          //   c. |loop_continue_id|
          //   d. 1 block inside the current region.
          //
          // This leads to a number of cases of what to do.
          //
          // 1. Does not jump to a block inside of the current construct.  In
          // this case, there is not conditional break, so we should return
          // |nullptr|.
          //
          // 2. Jumps to |merge_block_id| and a block inside the current
          // construct.  In this case, this branch conditionally break to the
          // end of the current construct, so return the current branch.
          //
          // 3.  Otherwise, this branch may break, but not to the current merge
          // block.  So we continue with the block that is inside the loop.

          bool found_break = false;
          // In-operands at odd indices (>= 1) are the switch target labels.
          for (uint32_t i = 1; i < branch->NumInOperands(); i += 2) {
            uint32_t target = branch->GetSingleWordInOperand(i);
            if (target == merge_block_id) {
              found_break = true;
            } else if (target != loop_merge_id && target != loop_continue_id) {
              next_block_id = branch->GetSingleWordInOperand(i);
            }
          }

          if (next_block_id == 0) {
            // Case 1.
            return nullptr;
          }

          if (found_break) {
            // Case 2.
            return branch;
          }

          // The fall through is case 3.
        }
        break;
      case SpvOpBranch:
        // Need to check if this is the header of a loop nested in the
        // selection construct.
        next_block_id = start_block->MergeBlockIdIfAny();
        if (next_block_id == 0) {
          next_block_id = branch->GetSingleWordInOperand(0);
        }
        break;
      default:
        // Any other terminator (return, unreachable, ...) ends the walk.
        return nullptr;
    }
    start_block_id = next_block_id;
  }
  return nullptr;
}
540 
541 }  // namespace opt
542 }  // namespace spvtools
543