1 // Copyright (c) 2017 Google Inc.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #include "source/opt/cfg.h"
16 
17 #include <memory>
18 #include <utility>
19 
20 #include "source/cfa.h"
21 #include "source/opt/ir_builder.h"
22 #include "source/opt/ir_context.h"
23 #include "source/opt/module.h"
24 
25 namespace spvtools {
26 namespace opt {
27 namespace {
28 
29 using cbb_ptr = const opt::BasicBlock*;
30 
31 // Universal Limit of ResultID + 1
32 const int kMaxResultId = 0x400000;
33 
34 }  // namespace
35 
CFG(Module * module)36 CFG::CFG(Module* module)
37     : module_(module),
38       pseudo_entry_block_(std::unique_ptr<Instruction>(
39           new Instruction(module->context(), SpvOpLabel, 0, 0, {}))),
40       pseudo_exit_block_(std::unique_ptr<Instruction>(new Instruction(
41           module->context(), SpvOpLabel, 0, kMaxResultId, {}))) {
42   for (auto& fn : *module) {
43     for (auto& blk : fn) {
44       RegisterBlock(&blk);
45     }
46   }
47 }
48 
AddEdges(BasicBlock * blk)49 void CFG::AddEdges(BasicBlock* blk) {
50   uint32_t blk_id = blk->id();
51   // Force the creation of an entry, not all basic block have predecessors
52   // (such as the entry blocks and some unreachables).
53   label2preds_[blk_id];
54   const auto* const_blk = blk;
55   const_blk->ForEachSuccessorLabel(
56       [blk_id, this](const uint32_t succ_id) { AddEdge(blk_id, succ_id); });
57 }
58 
RemoveNonExistingEdges(uint32_t blk_id)59 void CFG::RemoveNonExistingEdges(uint32_t blk_id) {
60   std::vector<uint32_t> updated_pred_list;
61   for (uint32_t id : preds(blk_id)) {
62     const BasicBlock* pred_blk = block(id);
63     bool has_branch = false;
64     pred_blk->ForEachSuccessorLabel([&has_branch, blk_id](uint32_t succ) {
65       if (succ == blk_id) {
66         has_branch = true;
67       }
68     });
69     if (has_branch) updated_pred_list.push_back(id);
70   }
71 
72   label2preds_.at(blk_id) = std::move(updated_pred_list);
73 }
74 
ComputeStructuredOrder(Function * func,BasicBlock * root,std::list<BasicBlock * > * order)75 void CFG::ComputeStructuredOrder(Function* func, BasicBlock* root,
76                                  std::list<BasicBlock*>* order) {
77   assert(module_->context()->get_feature_mgr()->HasCapability(
78              SpvCapabilityShader) &&
79          "This only works on structured control flow");
80 
81   // Compute structured successors and do DFS.
82   ComputeStructuredSuccessors(func);
83   auto ignore_block = [](cbb_ptr) {};
84   auto ignore_edge = [](cbb_ptr, cbb_ptr) {};
85   auto get_structured_successors = [this](const BasicBlock* b) {
86     return &(block2structured_succs_[b]);
87   };
88 
89   // TODO(greg-lunarg): Get rid of const_cast by making moving const
90   // out of the cfa.h prototypes and into the invoking code.
91   auto post_order = [&](cbb_ptr b) {
92     order->push_front(const_cast<BasicBlock*>(b));
93   };
94   CFA<BasicBlock>::DepthFirstTraversal(root, get_structured_successors,
95                                        ignore_block, post_order, ignore_edge);
96 }
97 
ForEachBlockInPostOrder(BasicBlock * bb,const std::function<void (BasicBlock *)> & f)98 void CFG::ForEachBlockInPostOrder(BasicBlock* bb,
99                                   const std::function<void(BasicBlock*)>& f) {
100   std::vector<BasicBlock*> po;
101   std::unordered_set<BasicBlock*> seen;
102   ComputePostOrderTraversal(bb, &po, &seen);
103 
104   for (BasicBlock* current_bb : po) {
105     if (!IsPseudoExitBlock(current_bb) && !IsPseudoEntryBlock(current_bb)) {
106       f(current_bb);
107     }
108   }
109 }
110 
ForEachBlockInReversePostOrder(BasicBlock * bb,const std::function<void (BasicBlock *)> & f)111 void CFG::ForEachBlockInReversePostOrder(
112     BasicBlock* bb, const std::function<void(BasicBlock*)>& f) {
113   std::vector<BasicBlock*> po;
114   std::unordered_set<BasicBlock*> seen;
115   ComputePostOrderTraversal(bb, &po, &seen);
116 
117   for (auto current_bb = po.rbegin(); current_bb != po.rend(); ++current_bb) {
118     if (!IsPseudoExitBlock(*current_bb) && !IsPseudoEntryBlock(*current_bb)) {
119       f(*current_bb);
120     }
121   }
122 }
123 
ComputeStructuredSuccessors(Function * func)124 void CFG::ComputeStructuredSuccessors(Function* func) {
125   block2structured_succs_.clear();
126   for (auto& blk : *func) {
127     // If no predecessors in function, make successor to pseudo entry.
128     if (label2preds_[blk.id()].size() == 0)
129       block2structured_succs_[&pseudo_entry_block_].push_back(&blk);
130 
131     // If header, make merge block first successor and continue block second
132     // successor if there is one.
133     uint32_t mbid = blk.MergeBlockIdIfAny();
134     if (mbid != 0) {
135       block2structured_succs_[&blk].push_back(block(mbid));
136       uint32_t cbid = blk.ContinueBlockIdIfAny();
137       if (cbid != 0) {
138         block2structured_succs_[&blk].push_back(block(cbid));
139       }
140     }
141 
142     // Add true successors.
143     const auto& const_blk = blk;
144     const_blk.ForEachSuccessorLabel([&blk, this](const uint32_t sbid) {
145       block2structured_succs_[&blk].push_back(block(sbid));
146     });
147   }
148 }
149 
ComputePostOrderTraversal(BasicBlock * bb,std::vector<BasicBlock * > * order,std::unordered_set<BasicBlock * > * seen)150 void CFG::ComputePostOrderTraversal(BasicBlock* bb,
151                                     std::vector<BasicBlock*>* order,
152                                     std::unordered_set<BasicBlock*>* seen) {
153   seen->insert(bb);
154   static_cast<const BasicBlock*>(bb)->ForEachSuccessorLabel(
155       [&order, &seen, this](const uint32_t sbid) {
156         BasicBlock* succ_bb = id2block_[sbid];
157         if (!seen->count(succ_bb)) {
158           ComputePostOrderTraversal(succ_bb, order, seen);
159         }
160       });
161   order->push_back(bb);
162 }
163 
SplitLoopHeader(BasicBlock * bb)164 BasicBlock* CFG::SplitLoopHeader(BasicBlock* bb) {
165   assert(bb->GetLoopMergeInst() && "Expecting bb to be the header of a loop.");
166 
167   Function* fn = bb->GetParent();
168   IRContext* context = module_->context();
169 
170   // Get the new header id up front.  If we are out of ids, then we cannot split
171   // the loop.
172   uint32_t new_header_id = context->TakeNextId();
173   if (new_header_id == 0) {
174     return nullptr;
175   }
176 
177   // Find the insertion point for the new bb.
178   Function::iterator header_it = std::find_if(
179       fn->begin(), fn->end(),
180       [bb](BasicBlock& block_in_func) { return &block_in_func == bb; });
181   assert(header_it != fn->end());
182 
183   const std::vector<uint32_t>& pred = preds(bb->id());
184   // Find the back edge
185   BasicBlock* latch_block = nullptr;
186   Function::iterator latch_block_iter = header_it;
187   while (++latch_block_iter != fn->end()) {
188     // If blocks are in the proper order, then the only branch that appears
189     // after the header is the latch.
190     if (std::find(pred.begin(), pred.end(), latch_block_iter->id()) !=
191         pred.end()) {
192       break;
193     }
194   }
195   assert(latch_block_iter != fn->end() && "Could not find the latch.");
196   latch_block = &*latch_block_iter;
197 
198   RemoveSuccessorEdges(bb);
199 
200   // Create the new header bb basic bb.
201   // Leave the phi instructions behind.
202   auto iter = bb->begin();
203   while (iter->opcode() == SpvOpPhi) {
204     ++iter;
205   }
206 
207   BasicBlock* new_header = bb->SplitBasicBlock(context, new_header_id, iter);
208   context->AnalyzeDefUse(new_header->GetLabelInst());
209 
210   // Update cfg
211   RegisterBlock(new_header);
212 
213   // Update bb mappings.
214   context->set_instr_block(new_header->GetLabelInst(), new_header);
215   new_header->ForEachInst([new_header, context](Instruction* inst) {
216     context->set_instr_block(inst, new_header);
217   });
218 
219   // Adjust the OpPhi instructions as needed.
220   bb->ForEachPhiInst([latch_block, bb, new_header, context](Instruction* phi) {
221     std::vector<uint32_t> preheader_phi_ops;
222     std::vector<Operand> header_phi_ops;
223 
224     // Identify where the original inputs to original OpPhi belong: header or
225     // preheader.
226     for (uint32_t i = 0; i < phi->NumInOperands(); i += 2) {
227       uint32_t def_id = phi->GetSingleWordInOperand(i);
228       uint32_t branch_id = phi->GetSingleWordInOperand(i + 1);
229       if (branch_id == latch_block->id()) {
230         header_phi_ops.push_back({SPV_OPERAND_TYPE_ID, {def_id}});
231         header_phi_ops.push_back({SPV_OPERAND_TYPE_ID, {branch_id}});
232       } else {
233         preheader_phi_ops.push_back(def_id);
234         preheader_phi_ops.push_back(branch_id);
235       }
236     }
237 
238     // Create a phi instruction if and only if the preheader_phi_ops has more
239     // than one pair.
240     if (preheader_phi_ops.size() > 2) {
241       InstructionBuilder builder(
242           context, &*bb->begin(),
243           IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);
244 
245       Instruction* new_phi = builder.AddPhi(phi->type_id(), preheader_phi_ops);
246 
247       // Add the OpPhi to the header bb.
248       header_phi_ops.push_back({SPV_OPERAND_TYPE_ID, {new_phi->result_id()}});
249       header_phi_ops.push_back({SPV_OPERAND_TYPE_ID, {bb->id()}});
250     } else {
251       // An OpPhi with a single entry is just a copy.  In this case use the same
252       // instruction in the new header.
253       header_phi_ops.push_back({SPV_OPERAND_TYPE_ID, {preheader_phi_ops[0]}});
254       header_phi_ops.push_back({SPV_OPERAND_TYPE_ID, {bb->id()}});
255     }
256 
257     phi->RemoveFromList();
258     std::unique_ptr<Instruction> phi_owner(phi);
259     phi->SetInOperands(std::move(header_phi_ops));
260     new_header->begin()->InsertBefore(std::move(phi_owner));
261     context->set_instr_block(phi, new_header);
262     context->AnalyzeUses(phi);
263   });
264 
265   // Add a branch to the new header.
266   InstructionBuilder branch_builder(
267       context, bb,
268       IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);
269   bb->AddInstruction(
270       MakeUnique<Instruction>(context, SpvOpBranch, 0, 0,
271                               std::initializer_list<Operand>{
272                                   {SPV_OPERAND_TYPE_ID, {new_header->id()}}}));
273   context->AnalyzeUses(bb->terminator());
274   context->set_instr_block(bb->terminator(), bb);
275   label2preds_[new_header->id()].push_back(bb->id());
276 
277   // Update the latch to branch to the new header.
278   latch_block->ForEachSuccessorLabel([bb, new_header_id](uint32_t* id) {
279     if (*id == bb->id()) {
280       *id = new_header_id;
281     }
282   });
283   Instruction* latch_branch = latch_block->terminator();
284   context->AnalyzeUses(latch_branch);
285   label2preds_[new_header->id()].push_back(latch_block->id());
286 
287   auto& block_preds = label2preds_[bb->id()];
288   auto latch_pos =
289       std::find(block_preds.begin(), block_preds.end(), latch_block->id());
290   assert(latch_pos != block_preds.end() && "The cfg was invalid.");
291   block_preds.erase(latch_pos);
292 
293   // Update the loop descriptors
294   if (context->AreAnalysesValid(IRContext::kAnalysisLoopAnalysis)) {
295     LoopDescriptor* loop_desc = context->GetLoopDescriptor(bb->GetParent());
296     Loop* loop = (*loop_desc)[bb->id()];
297 
298     loop->AddBasicBlock(new_header_id);
299     loop->SetHeaderBlock(new_header);
300     loop_desc->SetBasicBlockToLoop(new_header_id, loop);
301 
302     loop->RemoveBasicBlock(bb->id());
303     loop->SetPreHeaderBlock(bb);
304 
305     Loop* parent_loop = loop->GetParent();
306     if (parent_loop != nullptr) {
307       parent_loop->AddBasicBlock(bb->id());
308       loop_desc->SetBasicBlockToLoop(bb->id(), parent_loop);
309     } else {
310       loop_desc->SetBasicBlockToLoop(bb->id(), nullptr);
311     }
312   }
313   return new_header;
314 }
315 
316 }  // namespace opt
317 }  // namespace spvtools
318