1 // Copyright (c) 2017 The Khronos Group Inc.
2 // Copyright (c) 2017 Valve Corporation
3 // Copyright (c) 2017 LunarG Inc.
4 // Copyright (c) 2018 Google Inc.
5 //
6 // Licensed under the Apache License, Version 2.0 (the "License");
7 // you may not use this file except in compliance with the License.
8 // You may obtain a copy of the License at
9 //
10 // http://www.apache.org/licenses/LICENSE-2.0
11 //
12 // Unless required by applicable law or agreed to in writing, software
13 // distributed under the License is distributed on an "AS IS" BASIS,
14 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 // See the License for the specific language governing permissions and
16 // limitations under the License.
17
18 #include "source/opt/dead_branch_elim_pass.h"
19
20 #include <list>
21 #include <memory>
22 #include <vector>
23
24 #include "source/cfa.h"
25 #include "source/opt/ir_context.h"
26 #include "source/opt/iterator.h"
27 #include "source/opt/struct_cfg_analysis.h"
28 #include "source/util/make_unique.h"
29
30 namespace spvtools {
31 namespace opt {
32
namespace {

// In-operand indices of the true/false target labels of an
// OpBranchConditional instruction (in-operand 0 is the condition id).
const uint32_t kBranchCondTrueLabIdInIdx = 1;
const uint32_t kBranchCondFalseLabIdInIdx = 2;

}  // anonymous namespace
39
GetConstCondition(uint32_t condId,bool * condVal)40 bool DeadBranchElimPass::GetConstCondition(uint32_t condId, bool* condVal) {
41 bool condIsConst;
42 Instruction* cInst = get_def_use_mgr()->GetDef(condId);
43 switch (cInst->opcode()) {
44 case SpvOpConstantFalse: {
45 *condVal = false;
46 condIsConst = true;
47 } break;
48 case SpvOpConstantTrue: {
49 *condVal = true;
50 condIsConst = true;
51 } break;
52 case SpvOpLogicalNot: {
53 bool negVal;
54 condIsConst =
55 GetConstCondition(cInst->GetSingleWordInOperand(0), &negVal);
56 if (condIsConst) *condVal = !negVal;
57 } break;
58 default: { condIsConst = false; } break;
59 }
60 return condIsConst;
61 }
62
GetConstInteger(uint32_t selId,uint32_t * selVal)63 bool DeadBranchElimPass::GetConstInteger(uint32_t selId, uint32_t* selVal) {
64 Instruction* sInst = get_def_use_mgr()->GetDef(selId);
65 uint32_t typeId = sInst->type_id();
66 Instruction* typeInst = get_def_use_mgr()->GetDef(typeId);
67 if (!typeInst || (typeInst->opcode() != SpvOpTypeInt)) return false;
68 // TODO(greg-lunarg): Support non-32 bit ints
69 if (typeInst->GetSingleWordInOperand(0) != 32) return false;
70 if (sInst->opcode() == SpvOpConstant) {
71 *selVal = sInst->GetSingleWordInOperand(0);
72 return true;
73 } else if (sInst->opcode() == SpvOpConstantNull) {
74 *selVal = 0;
75 return true;
76 }
77 return false;
78 }
79
AddBranch(uint32_t labelId,BasicBlock * bp)80 void DeadBranchElimPass::AddBranch(uint32_t labelId, BasicBlock* bp) {
81 assert(get_def_use_mgr()->GetDef(labelId) != nullptr);
82 std::unique_ptr<Instruction> newBranch(
83 new Instruction(context(), SpvOpBranch, 0, 0,
84 {{spv_operand_type_t::SPV_OPERAND_TYPE_ID, {labelId}}}));
85 context()->AnalyzeDefUse(&*newBranch);
86 context()->set_instr_block(&*newBranch, bp);
87 bp->AddInstruction(std::move(newBranch));
88 }
89
GetParentBlock(uint32_t id)90 BasicBlock* DeadBranchElimPass::GetParentBlock(uint32_t id) {
91 return context()->get_instr_block(get_def_use_mgr()->GetDef(id));
92 }
93
// Traverses the CFG of |func| from its entry block, folding conditional
// branches and switches whose condition/selector is a compile-time constant
// into unconditional branches. Every block reached is added to
// |live_blocks|. Returns true if any terminator was replaced.
bool DeadBranchElimPass::MarkLiveBlocks(
    Function* func, std::unordered_set<BasicBlock*>* live_blocks) {
  StructuredCFGAnalysis* cfgAnalysis = context()->GetStructuredCFGAnalysis();

  // Blocks acting as the continue target of some loop seen so far; their
  // branches must not be simplified (see comment below).
  std::unordered_set<BasicBlock*> continues;
  std::vector<BasicBlock*> stack;
  stack.push_back(&*func->begin());
  bool modified = false;
  while (!stack.empty()) {
    BasicBlock* block = stack.back();
    stack.pop_back();

    // Live blocks doubles as visited set.
    if (!live_blocks->insert(block).second) continue;

    uint32_t cont_id = block->ContinueBlockIdIfAny();
    if (cont_id != 0) continues.insert(GetParentBlock(cont_id));

    Instruction* terminator = block->terminator();
    // |live_lab_id| stays 0 unless the terminator is statically decidable.
    uint32_t live_lab_id = 0;
    // Check if the terminator has a single valid successor.
    if (terminator->opcode() == SpvOpBranchConditional) {
      bool condVal;
      if (GetConstCondition(terminator->GetSingleWordInOperand(0u), &condVal)) {
        live_lab_id = terminator->GetSingleWordInOperand(
            condVal ? kBranchCondTrueLabIdInIdx : kBranchCondFalseLabIdInIdx);
      }
    } else if (terminator->opcode() == SpvOpSwitch) {
      uint32_t sel_val;
      if (GetConstInteger(terminator->GetSingleWordInOperand(0u), &sel_val)) {
        // Search switch operands for selector value, set live_lab_id to
        // corresponding label, use default if not found.
        // In-operand layout: selector, default label, then
        // (literal, label) pairs.
        uint32_t icnt = 0;
        uint32_t case_val;
        terminator->WhileEachInOperand(
            [&icnt, &case_val, &sel_val, &live_lab_id](const uint32_t* idp) {
              if (icnt == 1) {
                // Start with default label.
                live_lab_id = *idp;
              } else if (icnt > 1) {
                if (icnt % 2 == 0) {
                  // Even in-operand indices (>= 2) hold case literals.
                  case_val = *idp;
                } else {
                  // Odd indices hold the label for the preceding literal.
                  if (case_val == sel_val) {
                    live_lab_id = *idp;
                    return false;  // Matching case found; stop iterating.
                  }
                }
              }
              ++icnt;
              return true;
            });
      }
    }

    // Don't simplify branches of continue blocks. A path from the continue to
    // the header is required.
    // TODO(alan-baker): They can be simplified iff there remains a path to the
    // backedge. Structured control flow should guarantee one path hits the
    // backedge, but I've removed the requirement for structured control flow
    // from this pass.
    bool simplify = live_lab_id != 0 && !continues.count(block);

    if (simplify) {
      modified = true;
      // Replace with unconditional branch.
      // Remove the merge instruction if it is a selection merge.
      AddBranch(live_lab_id, block);
      context()->KillInst(terminator);
      Instruction* mergeInst = block->GetMergeInst();
      if (mergeInst && mergeInst->opcode() == SpvOpSelectionMerge) {
        // If a later conditional exit from this construct still needs the
        // merge instruction, move it there instead of deleting it.
        Instruction* first_break = FindFirstExitFromSelectionMerge(
            live_lab_id, mergeInst->GetSingleWordInOperand(0),
            cfgAnalysis->LoopMergeBlock(live_lab_id),
            cfgAnalysis->LoopContinueBlock(live_lab_id));
        if (first_break == nullptr) {
          context()->KillInst(mergeInst);
        } else {
          mergeInst->RemoveFromList();
          first_break->InsertBefore(std::unique_ptr<Instruction>(mergeInst));
          context()->set_instr_block(mergeInst,
                                     context()->get_instr_block(first_break));
        }
      }
      // Only the statically-taken successor remains live.
      stack.push_back(GetParentBlock(live_lab_id));
    } else {
      // All successors are live.
      const auto* const_block = block;
      const_block->ForEachSuccessorLabel([&stack, this](const uint32_t label) {
        stack.push_back(GetParentBlock(label));
      });
    }
  }

  return modified;
}
190
MarkUnreachableStructuredTargets(const std::unordered_set<BasicBlock * > & live_blocks,std::unordered_set<BasicBlock * > * unreachable_merges,std::unordered_map<BasicBlock *,BasicBlock * > * unreachable_continues)191 void DeadBranchElimPass::MarkUnreachableStructuredTargets(
192 const std::unordered_set<BasicBlock*>& live_blocks,
193 std::unordered_set<BasicBlock*>* unreachable_merges,
194 std::unordered_map<BasicBlock*, BasicBlock*>* unreachable_continues) {
195 for (auto block : live_blocks) {
196 if (auto merge_id = block->MergeBlockIdIfAny()) {
197 BasicBlock* merge_block = GetParentBlock(merge_id);
198 if (!live_blocks.count(merge_block)) {
199 unreachable_merges->insert(merge_block);
200 }
201 if (auto cont_id = block->ContinueBlockIdIfAny()) {
202 BasicBlock* cont_block = GetParentBlock(cont_id);
203 if (!live_blocks.count(cont_block)) {
204 (*unreachable_continues)[cont_block] = block;
205 }
206 }
207 }
208 }
209 }
210
// Rewrites OpPhi instructions in live blocks of |func| so that they only
// reference live incoming edges. An edge from an unreachable continue block
// into its loop header is preserved (with an undef value) because structured
// control flow requires the backedge. A phi left with a single incoming
// value is replaced by that value. Returns true if any phi was changed.
bool DeadBranchElimPass::FixPhiNodesInLiveBlocks(
    Function* func, const std::unordered_set<BasicBlock*>& live_blocks,
    const std::unordered_map<BasicBlock*, BasicBlock*>& unreachable_continues) {
  bool modified = false;
  for (auto& block : *func) {
    if (live_blocks.count(&block)) {
      for (auto iter = block.begin(); iter != block.end();) {
        // Phis must appear first in a block, so stop at the first non-phi.
        if (iter->opcode() != SpvOpPhi) {
          break;
        }

        bool changed = false;
        bool backedge_added = false;
        Instruction* inst = &*iter;
        std::vector<Operand> operands;
        // Build a complete set of operands (not just input operands). Start
        // with type and result id operands.
        operands.push_back(inst->GetOperand(0u));
        operands.push_back(inst->GetOperand(1u));
        // Iterate through the incoming labels and determine which to keep
        // and/or modify. If there is an unreachable continue block, there will
        // be an edge from that block to the header. We need to keep it to
        // maintain the structured control flow. If the header has more than 2
        // incoming edges, then the OpPhi must have an entry for that edge.
        // However, if there is only one other incoming edge, the OpPhi can be
        // eliminated.
        for (uint32_t i = 1; i < inst->NumInOperands(); i += 2) {
          // In-operand |i| is an incoming label; |i - 1| is its value.
          BasicBlock* inc = GetParentBlock(inst->GetSingleWordInOperand(i));
          auto cont_iter = unreachable_continues.find(inc);
          if (cont_iter != unreachable_continues.end() &&
              cont_iter->second == &block && inst->NumInOperands() > 4) {
            if (get_def_use_mgr()
                    ->GetDef(inst->GetSingleWordInOperand(i - 1))
                    ->opcode() == SpvOpUndef) {
              // Already undef incoming value, no change necessary.
              operands.push_back(inst->GetInOperand(i - 1));
              operands.push_back(inst->GetInOperand(i));
              backedge_added = true;
            } else {
              // Replace incoming value with undef if this phi exists in the
              // loop header. Otherwise, this edge is not live since the
              // unreachable continue block will be replaced with an
              // unconditional branch to the header only.
              operands.emplace_back(
                  SPV_OPERAND_TYPE_ID,
                  std::initializer_list<uint32_t>{Type2Undef(inst->type_id())});
              operands.push_back(inst->GetInOperand(i));
              changed = true;
              backedge_added = true;
            }
          } else if (live_blocks.count(inc) && inc->IsSuccessor(&block)) {
            // Keep live incoming edge.
            operands.push_back(inst->GetInOperand(i - 1));
            operands.push_back(inst->GetInOperand(i));
          } else {
            // Remove incoming edge.
            changed = true;
          }
        }

        if (changed) {
          modified = true;
          uint32_t continue_id = block.ContinueBlockIdIfAny();
          if (!backedge_added && continue_id != 0 &&
              unreachable_continues.count(GetParentBlock(continue_id)) &&
              operands.size() > 4) {
            // Changed the backedge to branch from the continue block instead
            // of a successor of the continue block. Add an entry to the phi to
            // provide an undef for the continue block. Since the successor of
            // the continue must also be unreachable (dominated by the continue
            // block), any entry for the original backedge has been removed
            // from the phi operands.
            operands.emplace_back(
                SPV_OPERAND_TYPE_ID,
                std::initializer_list<uint32_t>{Type2Undef(inst->type_id())});
            operands.emplace_back(SPV_OPERAND_TYPE_ID,
                                  std::initializer_list<uint32_t>{continue_id});
          }

          // Either replace the phi with a single value or rebuild the phi out
          // of |operands|.
          //
          // We always have type and result id operands. So this phi has a
          // single source if there are two more operands beyond those.
          if (operands.size() == 4) {
            // First input data operands is at index 2.
            uint32_t replId = operands[2u].words[0];
            context()->ReplaceAllUsesWith(inst->result_id(), replId);
            iter = context()->KillInst(&*inst);
          } else {
            // We've rewritten the operands, so first instruct the def/use
            // manager to forget uses in the phi before we replace them. After
            // replacing operands update the def/use manager by re-analyzing
            // the used ids in this phi.
            get_def_use_mgr()->EraseUseRecordsOfOperandIds(inst);
            inst->ReplaceOperands(operands);
            get_def_use_mgr()->AnalyzeInstUse(inst);
            ++iter;
          }
        } else {
          ++iter;
        }
      }
    }
  }

  return modified;
}
319
// Erases blocks of |func| not in |live_blocks|, with two structured-CFG
// exceptions: blocks in |unreachable_merges| are kept but reduced to a label
// plus OpUnreachable, and blocks in |unreachable_continues| are kept as a
// label plus an unconditional branch back to their loop header. Returns true
// if anything changed.
bool DeadBranchElimPass::EraseDeadBlocks(
    Function* func, const std::unordered_set<BasicBlock*>& live_blocks,
    const std::unordered_set<BasicBlock*>& unreachable_merges,
    const std::unordered_map<BasicBlock*, BasicBlock*>& unreachable_continues) {
  bool modified = false;
  for (auto ebi = func->begin(); ebi != func->end();) {
    if (unreachable_merges.count(&*ebi)) {
      // Skip the rewrite if the block is already just a label followed by
      // OpUnreachable.
      if (ebi->begin() != ebi->tail() ||
          ebi->terminator()->opcode() != SpvOpUnreachable) {
        // Make unreachable, but leave the label.
        KillAllInsts(&*ebi, false);
        // Add unreachable terminator.
        ebi->AddInstruction(
            MakeUnique<Instruction>(context(), SpvOpUnreachable, 0, 0,
                                    std::initializer_list<Operand>{}));
        context()->AnalyzeUses(ebi->terminator());
        context()->set_instr_block(ebi->terminator(), &*ebi);
        modified = true;
      }
      ++ebi;
    } else if (unreachable_continues.count(&*ebi)) {
      // |cont_id| is the id of the loop header this continue block serves.
      uint32_t cont_id = unreachable_continues.find(&*ebi)->second->id();
      // Skip the rewrite if the block already branches to the header.
      if (ebi->begin() != ebi->tail() ||
          ebi->terminator()->opcode() != SpvOpBranch ||
          ebi->terminator()->GetSingleWordInOperand(0u) != cont_id) {
        // Make unreachable, but leave the label.
        KillAllInsts(&*ebi, false);
        // Add unconditional branch to header.
        assert(unreachable_continues.count(&*ebi));
        ebi->AddInstruction(MakeUnique<Instruction>(
            context(), SpvOpBranch, 0, 0,
            std::initializer_list<Operand>{{SPV_OPERAND_TYPE_ID, {cont_id}}}));
        get_def_use_mgr()->AnalyzeInstUse(&*ebi->tail());
        context()->set_instr_block(&*ebi->tail(), &*ebi);
        modified = true;
      }
      ++ebi;
    } else if (!live_blocks.count(&*ebi)) {
      // Kill this block.
      KillAllInsts(&*ebi);
      ebi = ebi.Erase();
      modified = true;
    } else {
      ++ebi;
    }
  }

  return modified;
}
369
EliminateDeadBranches(Function * func)370 bool DeadBranchElimPass::EliminateDeadBranches(Function* func) {
371 bool modified = false;
372 std::unordered_set<BasicBlock*> live_blocks;
373 modified |= MarkLiveBlocks(func, &live_blocks);
374
375 std::unordered_set<BasicBlock*> unreachable_merges;
376 std::unordered_map<BasicBlock*, BasicBlock*> unreachable_continues;
377 MarkUnreachableStructuredTargets(live_blocks, &unreachable_merges,
378 &unreachable_continues);
379 modified |= FixPhiNodesInLiveBlocks(func, live_blocks, unreachable_continues);
380 modified |= EraseDeadBlocks(func, live_blocks, unreachable_merges,
381 unreachable_continues);
382
383 return modified;
384 }
385
FixBlockOrder()386 void DeadBranchElimPass::FixBlockOrder() {
387 context()->BuildInvalidAnalyses(IRContext::kAnalysisCFG |
388 IRContext::kAnalysisDominatorAnalysis);
389 // Reorders blocks according to DFS of dominator tree.
390 ProcessFunction reorder_dominators = [this](Function* function) {
391 DominatorAnalysis* dominators = context()->GetDominatorAnalysis(function);
392 std::vector<BasicBlock*> blocks;
393 for (auto iter = dominators->GetDomTree().begin();
394 iter != dominators->GetDomTree().end(); ++iter) {
395 if (iter->id() != 0) {
396 blocks.push_back(iter->bb_);
397 }
398 }
399 for (uint32_t i = 1; i < blocks.size(); ++i) {
400 function->MoveBasicBlockToAfter(blocks[i]->id(), blocks[i - 1]);
401 }
402 return true;
403 };
404
405 // Reorders blocks according to structured order.
406 ProcessFunction reorder_structured = [this](Function* function) {
407 std::list<BasicBlock*> order;
408 context()->cfg()->ComputeStructuredOrder(function, &*function->begin(),
409 &order);
410 std::vector<BasicBlock*> blocks;
411 for (auto block : order) {
412 blocks.push_back(block);
413 }
414 for (uint32_t i = 1; i < blocks.size(); ++i) {
415 function->MoveBasicBlockToAfter(blocks[i]->id(), blocks[i - 1]);
416 }
417 return true;
418 };
419
420 // Structured order is more intuitive so use it where possible.
421 if (context()->get_feature_mgr()->HasCapability(SpvCapabilityShader)) {
422 context()->ProcessReachableCallTree(reorder_structured);
423 } else {
424 context()->ProcessReachableCallTree(reorder_dominators);
425 }
426 }
427
Process()428 Pass::Status DeadBranchElimPass::Process() {
429 // Do not process if module contains OpGroupDecorate. Additional
430 // support required in KillNamesAndDecorates().
431 // TODO(greg-lunarg): Add support for OpGroupDecorate
432 for (auto& ai : get_module()->annotations())
433 if (ai.opcode() == SpvOpGroupDecorate) return Status::SuccessWithoutChange;
434 // Process all entry point functions
435 ProcessFunction pfn = [this](Function* fp) {
436 return EliminateDeadBranches(fp);
437 };
438 bool modified = context()->ProcessReachableCallTree(pfn);
439 if (modified) FixBlockOrder();
440 return modified ? Status::SuccessWithChange : Status::SuccessWithoutChange;
441 }
442
// Follows the control flow from |start_block_id| inside the selection
// construct whose merge block is |merge_block_id|, and returns the first
// conditional branch or switch that exits the construct (a conditional
// "break"), or nullptr if there is none. |loop_merge_id| and
// |loop_continue_id| are the merge/continue targets of the enclosing loop
// (0 if none); branches to them leave via the loop, not via this selection
// construct, so they are treated as boundaries rather than followed.
Instruction* DeadBranchElimPass::FindFirstExitFromSelectionMerge(
    uint32_t start_block_id, uint32_t merge_block_id, uint32_t loop_merge_id,
    uint32_t loop_continue_id) {
  // To find the "first" exit, we follow branches looking for a conditional
  // branch that is not in a nested construct and is not the header of a new
  // construct. We follow the control flow from |start_block_id| to find the
  // first one.
  while (start_block_id != merge_block_id && start_block_id != loop_merge_id &&
         start_block_id != loop_continue_id) {
    BasicBlock* start_block = context()->get_instr_block(start_block_id);
    Instruction* branch = start_block->terminator();
    uint32_t next_block_id = 0;
    switch (branch->opcode()) {
      case SpvOpBranchConditional:
        // If |start_block| is itself a header, skip over its nested
        // construct by continuing at that construct's merge block.
        next_block_id = start_block->MergeBlockIdIfAny();
        if (next_block_id == 0) {
          // If a possible target is the |loop_merge_id| or |loop_continue_id|,
          // which are not the current merge node, then we continue the search
          // with the other target.
          for (uint32_t i = 1; i < 3; i++) {
            if (branch->GetSingleWordInOperand(i) == loop_merge_id &&
                loop_merge_id != merge_block_id) {
              next_block_id = branch->GetSingleWordInOperand(3 - i);
              break;
            }
            if (branch->GetSingleWordInOperand(i) == loop_continue_id &&
                loop_continue_id != merge_block_id) {
              next_block_id = branch->GetSingleWordInOperand(3 - i);
              break;
            }
          }

          if (next_block_id == 0) {
            // Neither target exits via the loop, so this conditional branch
            // is the first exit from the selection construct.
            return branch;
          }
        }
        break;
      case SpvOpSwitch:
        next_block_id = start_block->MergeBlockIdIfAny();
        if (next_block_id == 0) {
          // A switch with no merge instructions can have at most 4 targets:
          //   a. |merge_block_id|
          //   b. |loop_merge_id|
          //   c. |loop_continue_id|
          //   d. 1 block inside the current region.
          //
          // This leads to a number of cases of what to do.
          //
          // 1. Does not jump to a block inside of the current construct. In
          //    this case, there is no conditional break, so we should return
          //    |nullptr|.
          //
          // 2. Jumps to |merge_block_id| and a block inside the current
          //    construct. In this case, this branch conditionally breaks to
          //    the end of the current construct, so return the current branch.
          //
          // 3. Otherwise, this branch may break, but not to the current merge
          //    block. So we continue with the block that is inside the loop.

          bool found_break = false;
          for (uint32_t i = 1; i < branch->NumInOperands(); i += 2) {
            uint32_t target = branch->GetSingleWordInOperand(i);
            if (target == merge_block_id) {
              found_break = true;
            } else if (target != loop_merge_id && target != loop_continue_id) {
              next_block_id = branch->GetSingleWordInOperand(i);
            }
          }

          if (next_block_id == 0) {
            // Case 1.
            return nullptr;
          }

          if (found_break) {
            // Case 2.
            return branch;
          }

          // The fall through is case 3.
        }
        break;
      case SpvOpBranch:
        // Need to check if this is the header of a loop nested in the
        // selection construct.
        next_block_id = start_block->MergeBlockIdIfAny();
        if (next_block_id == 0) {
          next_block_id = branch->GetSingleWordInOperand(0);
        }
        break;
      default:
        // Any other terminator (e.g. return, unreachable) ends the search.
        return nullptr;
    }
    start_block_id = next_block_id;
  }
  return nullptr;
}
540
541 } // namespace opt
542 } // namespace spvtools
543