1 // Copyright (c) 2018 The Khronos Group Inc.
2 // Copyright (c) 2018 Valve Corporation
3 // Copyright (c) 2018 LunarG Inc.
4 //
5 // Licensed under the Apache License, Version 2.0 (the "License");
6 // you may not use this file except in compliance with the License.
7 // You may obtain a copy of the License at
8 //
9 //     http://www.apache.org/licenses/LICENSE-2.0
10 //
11 // Unless required by applicable law or agreed to in writing, software
12 // distributed under the License is distributed on an "AS IS" BASIS,
13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 // See the License for the specific language governing permissions and
15 // limitations under the License.
16 
17 #include "source/opt/dead_insert_elim_pass.h"
18 
19 #include "source/opt/composite.h"
20 #include "source/opt/ir_context.h"
21 #include "source/opt/iterator.h"
22 #include "spirv/1.2/GLSL.std.450.h"
23 
24 namespace spvtools {
25 namespace opt {
26 
namespace {

// In-operand indices used by this pass when reading type, constant and
// OpCompositeInsert instructions.

// OpTypeVector: operand holding the component count.
constexpr uint32_t kTypeVectorCountInIdx = 1;
// OpTypeMatrix: operand holding the count read as the number of components.
constexpr uint32_t kTypeMatrixCountInIdx = 1;
// OpTypeArray: operand holding the id of the array length.
constexpr uint32_t kTypeArrayLengthIdInIdx = 1;
// Type of the array length constant: operand holding the bit width.
constexpr uint32_t kTypeIntWidthInIdx = 0;
// OpConstant: operand holding the (first word of the) value.
constexpr uint32_t kConstantValueInIdx = 0;
// OpCompositeInsert: operand holding the id of the inserted object.
constexpr uint32_t kInsertObjectIdInIdx = 0;
// OpCompositeInsert: operand holding the id of the composite inserted into.
constexpr uint32_t kInsertCompositeIdInIdx = 1;

}  // anonymous namespace
38 
NumComponents(Instruction * typeInst)39 uint32_t DeadInsertElimPass::NumComponents(Instruction* typeInst) {
40   switch (typeInst->opcode()) {
41     case SpvOpTypeVector: {
42       return typeInst->GetSingleWordInOperand(kTypeVectorCountInIdx);
43     } break;
44     case SpvOpTypeMatrix: {
45       return typeInst->GetSingleWordInOperand(kTypeMatrixCountInIdx);
46     } break;
47     case SpvOpTypeArray: {
48       uint32_t lenId =
49           typeInst->GetSingleWordInOperand(kTypeArrayLengthIdInIdx);
50       Instruction* lenInst = get_def_use_mgr()->GetDef(lenId);
51       if (lenInst->opcode() != SpvOpConstant) return 0;
52       uint32_t lenTypeId = lenInst->type_id();
53       Instruction* lenTypeInst = get_def_use_mgr()->GetDef(lenTypeId);
54       // TODO(greg-lunarg): Support non-32-bit array length
55       if (lenTypeInst->GetSingleWordInOperand(kTypeIntWidthInIdx) != 32)
56         return 0;
57       return lenInst->GetSingleWordInOperand(kConstantValueInIdx);
58     } break;
59     case SpvOpTypeStruct: {
60       return typeInst->NumInOperands();
61     } break;
62     default: { return 0; } break;
63   }
64 }
65 
MarkInsertChain(Instruction * insertChain,std::vector<uint32_t> * pExtIndices,uint32_t extOffset,std::unordered_set<uint32_t> * visited_phis)66 void DeadInsertElimPass::MarkInsertChain(
67     Instruction* insertChain, std::vector<uint32_t>* pExtIndices,
68     uint32_t extOffset, std::unordered_set<uint32_t>* visited_phis) {
69   // Not currently optimizing array inserts.
70   Instruction* typeInst = get_def_use_mgr()->GetDef(insertChain->type_id());
71   if (typeInst->opcode() == SpvOpTypeArray) return;
72   // Insert chains are only composed of inserts and phis
73   if (insertChain->opcode() != SpvOpCompositeInsert &&
74       insertChain->opcode() != SpvOpPhi)
75     return;
76   // If extract indices are empty, mark all subcomponents if type
77   // is constant length.
78   if (pExtIndices == nullptr) {
79     uint32_t cnum = NumComponents(typeInst);
80     if (cnum > 0) {
81       std::vector<uint32_t> extIndices;
82       for (uint32_t i = 0; i < cnum; i++) {
83         extIndices.clear();
84         extIndices.push_back(i);
85         std::unordered_set<uint32_t> sub_visited_phis;
86         MarkInsertChain(insertChain, &extIndices, 0, &sub_visited_phis);
87       }
88       return;
89     }
90   }
91   Instruction* insInst = insertChain;
92   while (insInst->opcode() == SpvOpCompositeInsert) {
93     // If no extract indices, mark insert and inserted object (which might
94     // also be an insert chain) and continue up the chain though the input
95     // composite.
96     //
97     // Note: We mark inserted objects in this function (rather than in
98     // EliminateDeadInsertsOnePass) because in some cases, we can do it
99     // more accurately here.
100     if (pExtIndices == nullptr) {
101       liveInserts_.insert(insInst->result_id());
102       uint32_t objId = insInst->GetSingleWordInOperand(kInsertObjectIdInIdx);
103       std::unordered_set<uint32_t> obj_visited_phis;
104       MarkInsertChain(get_def_use_mgr()->GetDef(objId), nullptr, 0,
105                       &obj_visited_phis);
106     // If extract indices match insert, we are done. Mark insert and
107     // inserted object.
108     } else if (ExtInsMatch(*pExtIndices, insInst, extOffset)) {
109       liveInserts_.insert(insInst->result_id());
110       uint32_t objId = insInst->GetSingleWordInOperand(kInsertObjectIdInIdx);
111       std::unordered_set<uint32_t> obj_visited_phis;
112       MarkInsertChain(get_def_use_mgr()->GetDef(objId), nullptr, 0,
113                       &obj_visited_phis);
114       break;
115     // If non-matching intersection, mark insert
116     } else if (ExtInsConflict(*pExtIndices, insInst, extOffset)) {
117       liveInserts_.insert(insInst->result_id());
118       // If more extract indices than insert, we are done. Use remaining
119       // extract indices to mark inserted object.
120       uint32_t numInsertIndices = insInst->NumInOperands() - 2;
121       if (pExtIndices->size() - extOffset > numInsertIndices) {
122         uint32_t objId = insInst->GetSingleWordInOperand(kInsertObjectIdInIdx);
123         std::unordered_set<uint32_t> obj_visited_phis;
124         MarkInsertChain(get_def_use_mgr()->GetDef(objId), pExtIndices,
125                         extOffset + numInsertIndices, &obj_visited_phis);
126         break;
127       // If fewer extract indices than insert, also mark inserted object and
128       // continue up chain.
129       } else {
130         uint32_t objId = insInst->GetSingleWordInOperand(kInsertObjectIdInIdx);
131         std::unordered_set<uint32_t> obj_visited_phis;
132         MarkInsertChain(get_def_use_mgr()->GetDef(objId), nullptr, 0,
133                         &obj_visited_phis);
134       }
135     }
136     // Get next insert in chain
137     const uint32_t compId =
138         insInst->GetSingleWordInOperand(kInsertCompositeIdInIdx);
139     insInst = get_def_use_mgr()->GetDef(compId);
140   }
141   // If insert chain ended with phi, do recursive call on each operand
142   if (insInst->opcode() != SpvOpPhi) return;
143   // Mark phi visited to prevent potential infinite loop. If phi is already
144   // visited, return to avoid infinite loop.
145   if (visited_phis->count(insInst->result_id()) != 0) return;
146   visited_phis->insert(insInst->result_id());
147 
148   // Phis may have duplicate inputs values for different edges, prune incoming
149   // ids lists before recursing.
150   std::vector<uint32_t> ids;
151   for (uint32_t i = 0; i < insInst->NumInOperands(); i += 2) {
152     ids.push_back(insInst->GetSingleWordInOperand(i));
153   }
154   std::sort(ids.begin(), ids.end());
155   auto new_end = std::unique(ids.begin(), ids.end());
156   for (auto id_iter = ids.begin(); id_iter != new_end; ++id_iter) {
157     Instruction* pi = get_def_use_mgr()->GetDef(*id_iter);
158     MarkInsertChain(pi, pExtIndices, extOffset, visited_phis);
159   }
160 }
161 
EliminateDeadInserts(Function * func)162 bool DeadInsertElimPass::EliminateDeadInserts(Function* func) {
163   bool modified = false;
164   bool lastmodified = true;
165   // Each pass can delete dead instructions, thus potentially revealing
166   // new dead insertions ie insertions with no uses.
167   while (lastmodified) {
168     lastmodified = EliminateDeadInsertsOnePass(func);
169     modified |= lastmodified;
170   }
171   return modified;
172 }
173 
EliminateDeadInsertsOnePass(Function * func)174 bool DeadInsertElimPass::EliminateDeadInsertsOnePass(Function* func) {
175   bool modified = false;
176   liveInserts_.clear();
177   visitedPhis_.clear();
178   // Mark all live inserts
179   for (auto bi = func->begin(); bi != func->end(); ++bi) {
180     for (auto ii = bi->begin(); ii != bi->end(); ++ii) {
181       // Only process Inserts and composite Phis
182       SpvOp op = ii->opcode();
183       Instruction* typeInst = get_def_use_mgr()->GetDef(ii->type_id());
184       if (op != SpvOpCompositeInsert &&
185           (op != SpvOpPhi || !spvOpcodeIsComposite(typeInst->opcode())))
186         continue;
187       // The marking algorithm can be expensive for large arrays and the
188       // efficacy of eliminating dead inserts into arrays is questionable.
189       // Skip optimizing array inserts for now. Just mark them live.
190       // TODO(greg-lunarg): Eliminate dead array inserts
191       if (op == SpvOpCompositeInsert) {
192         if (typeInst->opcode() == SpvOpTypeArray) {
193           liveInserts_.insert(ii->result_id());
194           continue;
195         }
196       }
197       const uint32_t id = ii->result_id();
198       get_def_use_mgr()->ForEachUser(id, [&ii, this](Instruction* user) {
199         if (user->IsOpenCL100DebugInstr()) return;
200         switch (user->opcode()) {
201           case SpvOpCompositeInsert:
202           case SpvOpPhi:
203             // Use by insert or phi does not initiate marking
204             break;
205           case SpvOpCompositeExtract: {
206             // Capture extract indices
207             std::vector<uint32_t> extIndices;
208             uint32_t icnt = 0;
209             user->ForEachInOperand([&icnt, &extIndices](const uint32_t* idp) {
210               if (icnt > 0) extIndices.push_back(*idp);
211               ++icnt;
212             });
213             // Mark all inserts in chain that intersect with extract
214             std::unordered_set<uint32_t> visited_phis;
215             MarkInsertChain(&*ii, &extIndices, 0, &visited_phis);
216           } break;
217           default: {
218             // Mark inserts in chain for all components
219             MarkInsertChain(&*ii, nullptr, 0, nullptr);
220           } break;
221         }
222       });
223     }
224   }
225   // Find and disconnect dead inserts
226   std::vector<Instruction*> dead_instructions;
227   for (auto bi = func->begin(); bi != func->end(); ++bi) {
228     for (auto ii = bi->begin(); ii != bi->end(); ++ii) {
229       if (ii->opcode() != SpvOpCompositeInsert) continue;
230       const uint32_t id = ii->result_id();
231       if (liveInserts_.find(id) != liveInserts_.end()) continue;
232       const uint32_t replId =
233           ii->GetSingleWordInOperand(kInsertCompositeIdInIdx);
234       (void)context()->ReplaceAllUsesWith(id, replId);
235       dead_instructions.push_back(&*ii);
236       modified = true;
237     }
238   }
239   // DCE dead inserts
240   while (!dead_instructions.empty()) {
241     Instruction* inst = dead_instructions.back();
242     dead_instructions.pop_back();
243     DCEInst(inst, [&dead_instructions](Instruction* other_inst) {
244       auto i = std::find(dead_instructions.begin(), dead_instructions.end(),
245                          other_inst);
246       if (i != dead_instructions.end()) {
247         dead_instructions.erase(i);
248       }
249     });
250   }
251   return modified;
252 }
253 
Process()254 Pass::Status DeadInsertElimPass::Process() {
255   // Process all entry point functions.
256   ProcessFunction pfn = [this](Function* fp) {
257     return EliminateDeadInserts(fp);
258   };
259   bool modified = context()->ProcessEntryPointCallTree(pfn);
260   return modified ? Status::SuccessWithChange : Status::SuccessWithoutChange;
261 }
262 
263 }  // namespace opt
264 }  // namespace spvtools
265