1 // Copyright (c) 2018 Google LLC.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #include "source/opt/copy_prop_arrays.h"
16 
17 #include <utility>
18 
19 #include "source/opt/ir_builder.h"
20 
21 namespace spvtools {
22 namespace opt {
namespace {

// In-operand indices (operands following any result type / result id) of the
// instructions this pass reads and rewrites.

// OpLoad: the pointer that is read.
constexpr uint32_t kLoadPointerInOperand = 0;

// OpStore: the pointer that is written, followed by the object written to it.
constexpr uint32_t kStorePointerInOperand = 0;
constexpr uint32_t kStoreObjectInOperand = 1;

// OpCompositeExtract: the composite the literal indices are applied to.
constexpr uint32_t kCompositeExtractObjectInOperand = 0;

// OpTypePointer: the storage class, followed by the pointee type id.
constexpr uint32_t kTypePointerStorageClassInIdx = 0;
constexpr uint32_t kTypePointerPointeeInIdx = 1;

}  // namespace
33 
Process()34 Pass::Status CopyPropagateArrays::Process() {
35   bool modified = false;
36   for (Function& function : *get_module()) {
37     BasicBlock* entry_bb = &*function.begin();
38 
39     for (auto var_inst = entry_bb->begin(); var_inst->opcode() == SpvOpVariable;
40          ++var_inst) {
41       if (!IsPointerToArrayType(var_inst->type_id())) {
42         continue;
43       }
44 
45       // Find the only store to the entire memory location, if it exists.
46       Instruction* store_inst = FindStoreInstruction(&*var_inst);
47 
48       if (!store_inst) {
49         continue;
50       }
51 
52       std::unique_ptr<MemoryObject> source_object =
53           FindSourceObjectIfPossible(&*var_inst, store_inst);
54 
55       if (source_object != nullptr) {
56         if (CanUpdateUses(&*var_inst, source_object->GetPointerTypeId(this))) {
57           modified = true;
58           PropagateObject(&*var_inst, source_object.get(), store_inst);
59         }
60       }
61     }
62   }
63   return (modified ? Status::SuccessWithChange : Status::SuccessWithoutChange);
64 }
65 
66 std::unique_ptr<CopyPropagateArrays::MemoryObject>
FindSourceObjectIfPossible(Instruction * var_inst,Instruction * store_inst)67 CopyPropagateArrays::FindSourceObjectIfPossible(Instruction* var_inst,
68                                                 Instruction* store_inst) {
69   assert(var_inst->opcode() == SpvOpVariable && "Expecting a variable.");
70 
71   // Check that the variable is a composite object where |store_inst|
72   // dominates all of its loads.
73   if (!store_inst) {
74     return nullptr;
75   }
76 
77   // Look at the loads to ensure they are dominated by the store.
78   if (!HasValidReferencesOnly(var_inst, store_inst)) {
79     return nullptr;
80   }
81 
82   // If so, look at the store to see if it is the copy of an object.
83   std::unique_ptr<MemoryObject> source = GetSourceObjectIfAny(
84       store_inst->GetSingleWordInOperand(kStoreObjectInOperand));
85 
86   if (!source) {
87     return nullptr;
88   }
89 
90   // Ensure that |source| does not change between the point at which it is
91   // loaded, and the position in which |var_inst| is loaded.
92   //
93   // For now we will go with the easy to implement approach, and check that the
94   // entire variable (not just the specific component) is never written to.
95 
96   if (!HasNoStores(source->GetVariable())) {
97     return nullptr;
98   }
99   return source;
100 }
101 
// Returns the OpStore whose pointer operand is |var_inst| itself, provided
// there is exactly one such store.  Returns nullptr when there is none or when
// a second one is found.
Instruction* CopyPropagateArrays::FindStoreInstruction(
    const Instruction* var_inst) const {
  const uint32_t var_id = var_inst->result_id();
  Instruction* unique_store = nullptr;
  bool found_second_store = false;

  get_def_use_mgr()->WhileEachUser(
      var_inst,
      [var_id, &unique_store, &found_second_store](Instruction* user) {
        const bool writes_var =
            user->opcode() == SpvOpStore &&
            user->GetSingleWordInOperand(kStorePointerInOperand) == var_id;
        if (!writes_var) {
          return true;
        }
        if (unique_store != nullptr) {
          // A second whole-variable store: no unique store exists.
          found_second_store = true;
          return false;
        }
        unique_store = user;
        return true;
      });

  return found_second_store ? nullptr : unique_store;
}
121 
PropagateObject(Instruction * var_inst,MemoryObject * source,Instruction * insertion_point)122 void CopyPropagateArrays::PropagateObject(Instruction* var_inst,
123                                           MemoryObject* source,
124                                           Instruction* insertion_point) {
125   assert(var_inst->opcode() == SpvOpVariable &&
126          "This function propagates variables.");
127 
128   Instruction* new_access_chain = BuildNewAccessChain(insertion_point, source);
129   context()->KillNamesAndDecorates(var_inst);
130   UpdateUses(var_inst, new_access_chain);
131 }
132 
BuildNewAccessChain(Instruction * insertion_point,CopyPropagateArrays::MemoryObject * source) const133 Instruction* CopyPropagateArrays::BuildNewAccessChain(
134     Instruction* insertion_point,
135     CopyPropagateArrays::MemoryObject* source) const {
136   InstructionBuilder builder(
137       context(), insertion_point,
138       IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);
139 
140   if (source->AccessChain().size() == 0) {
141     return source->GetVariable();
142   }
143 
144   return builder.AddAccessChain(source->GetPointerTypeId(this),
145                                 source->GetVariable()->result_id(),
146                                 source->AccessChain());
147 }
148 
HasNoStores(Instruction * ptr_inst)149 bool CopyPropagateArrays::HasNoStores(Instruction* ptr_inst) {
150   return get_def_use_mgr()->WhileEachUser(ptr_inst, [this](Instruction* use) {
151     if (use->opcode() == SpvOpLoad) {
152       return true;
153     } else if (use->opcode() == SpvOpAccessChain) {
154       return HasNoStores(use);
155     } else if (use->IsDecoration() || use->opcode() == SpvOpName) {
156       return true;
157     } else if (use->opcode() == SpvOpStore) {
158       return false;
159     } else if (use->opcode() == SpvOpImageTexelPointer) {
160       return true;
161     }
162     // Some other instruction.  Be conservative.
163     return false;
164   });
165 }
166 
HasValidReferencesOnly(Instruction * ptr_inst,Instruction * store_inst)167 bool CopyPropagateArrays::HasValidReferencesOnly(Instruction* ptr_inst,
168                                                  Instruction* store_inst) {
169   BasicBlock* store_block = context()->get_instr_block(store_inst);
170   DominatorAnalysis* dominator_analysis =
171       context()->GetDominatorAnalysis(store_block->GetParent());
172 
173   return get_def_use_mgr()->WhileEachUser(
174       ptr_inst,
175       [this, store_inst, dominator_analysis, ptr_inst](Instruction* use) {
176         if (use->opcode() == SpvOpLoad ||
177             use->opcode() == SpvOpImageTexelPointer) {
178           // TODO: If there are many load in the same BB as |store_inst| the
179           // time to do the multiple traverses can add up.  Consider collecting
180           // those loads and doing a single traversal.
181           return dominator_analysis->Dominates(store_inst, use);
182         } else if (use->opcode() == SpvOpAccessChain) {
183           return HasValidReferencesOnly(use, store_inst);
184         } else if (use->IsDecoration() || use->opcode() == SpvOpName) {
185           return true;
186         } else if (use->opcode() == SpvOpStore) {
187           // If we are storing to part of the object it is not an candidate.
188           return ptr_inst->opcode() == SpvOpVariable &&
189                  store_inst->GetSingleWordInOperand(kStorePointerInOperand) ==
190                      ptr_inst->result_id();
191         }
192         // Some other instruction.  Be conservative.
193         return false;
194       });
195 }
196 
197 std::unique_ptr<CopyPropagateArrays::MemoryObject>
GetSourceObjectIfAny(uint32_t result)198 CopyPropagateArrays::GetSourceObjectIfAny(uint32_t result) {
199   Instruction* result_inst = context()->get_def_use_mgr()->GetDef(result);
200 
201   switch (result_inst->opcode()) {
202     case SpvOpLoad:
203       return BuildMemoryObjectFromLoad(result_inst);
204     case SpvOpCompositeExtract:
205       return BuildMemoryObjectFromExtract(result_inst);
206     case SpvOpCompositeConstruct:
207       return BuildMemoryObjectFromCompositeConstruct(result_inst);
208     case SpvOpCopyObject:
209       return GetSourceObjectIfAny(result_inst->GetSingleWordInOperand(0));
210     case SpvOpCompositeInsert:
211       return BuildMemoryObjectFromInsert(result_inst);
212     default:
213       return nullptr;
214   }
215 }
216 
217 std::unique_ptr<CopyPropagateArrays::MemoryObject>
BuildMemoryObjectFromLoad(Instruction * load_inst)218 CopyPropagateArrays::BuildMemoryObjectFromLoad(Instruction* load_inst) {
219   std::vector<uint32_t> components_in_reverse;
220   analysis::DefUseManager* def_use_mgr = context()->get_def_use_mgr();
221 
222   Instruction* current_inst = def_use_mgr->GetDef(
223       load_inst->GetSingleWordInOperand(kLoadPointerInOperand));
224 
225   // Build the access chain for the memory object by collecting the indices used
226   // in the OpAccessChain instructions.  If we find a variable index, then
227   // return |nullptr| because we cannot know for sure which memory location is
228   // used.
229   //
230   // It is built in reverse order because the different |OpAccessChain|
231   // instructions are visited in reverse order from which they are applied.
232   while (current_inst->opcode() == SpvOpAccessChain) {
233     for (uint32_t i = current_inst->NumInOperands() - 1; i >= 1; --i) {
234       uint32_t element_index_id = current_inst->GetSingleWordInOperand(i);
235       components_in_reverse.push_back(element_index_id);
236     }
237     current_inst = def_use_mgr->GetDef(current_inst->GetSingleWordInOperand(0));
238   }
239 
240   // If the address in the load is not constructed from an |OpVariable|
241   // instruction followed by a series of |OpAccessChain| instructions, then
242   // return |nullptr| because we cannot identify the owner or access chain
243   // exactly.
244   if (current_inst->opcode() != SpvOpVariable) {
245     return nullptr;
246   }
247 
248   // Build the memory object.  Use |rbegin| and |rend| to put the access chain
249   // back in the correct order.
250   return std::unique_ptr<CopyPropagateArrays::MemoryObject>(
251       new MemoryObject(current_inst, components_in_reverse.rbegin(),
252                        components_in_reverse.rend()));
253 }
254 
255 std::unique_ptr<CopyPropagateArrays::MemoryObject>
BuildMemoryObjectFromExtract(Instruction * extract_inst)256 CopyPropagateArrays::BuildMemoryObjectFromExtract(Instruction* extract_inst) {
257   assert(extract_inst->opcode() == SpvOpCompositeExtract &&
258          "Expecting an OpCompositeExtract instruction.");
259   analysis::ConstantManager* const_mgr = context()->get_constant_mgr();
260 
261   std::unique_ptr<MemoryObject> result = GetSourceObjectIfAny(
262       extract_inst->GetSingleWordInOperand(kCompositeExtractObjectInOperand));
263 
264   if (result) {
265     analysis::Integer int_type(32, false);
266     const analysis::Type* uint32_type =
267         context()->get_type_mgr()->GetRegisteredType(&int_type);
268 
269     std::vector<uint32_t> components;
270     // Convert the indices in the extract instruction to a series of ids that
271     // can be used by the |OpAccessChain| instruction.
272     for (uint32_t i = 1; i < extract_inst->NumInOperands(); ++i) {
273       uint32_t index = extract_inst->GetSingleWordInOperand(i);
274       const analysis::Constant* index_const =
275           const_mgr->GetConstant(uint32_type, {index});
276       components.push_back(
277           const_mgr->GetDefiningInstruction(index_const)->result_id());
278     }
279     result->GetMember(components);
280     return result;
281   }
282   return nullptr;
283 }
284 
285 std::unique_ptr<CopyPropagateArrays::MemoryObject>
BuildMemoryObjectFromCompositeConstruct(Instruction * conststruct_inst)286 CopyPropagateArrays::BuildMemoryObjectFromCompositeConstruct(
287     Instruction* conststruct_inst) {
288   assert(conststruct_inst->opcode() == SpvOpCompositeConstruct &&
289          "Expecting an OpCompositeConstruct instruction.");
290 
291   // If every operand in the instruction are part of the same memory object, and
292   // are being combined in the same order, then the result is the same as the
293   // parent.
294 
295   std::unique_ptr<MemoryObject> memory_object =
296       GetSourceObjectIfAny(conststruct_inst->GetSingleWordInOperand(0));
297 
298   if (!memory_object) {
299     return nullptr;
300   }
301 
302   if (!memory_object->IsMember()) {
303     return nullptr;
304   }
305 
306   analysis::ConstantManager* const_mgr = context()->get_constant_mgr();
307   const analysis::Constant* last_access =
308       const_mgr->FindDeclaredConstant(memory_object->AccessChain().back());
309   if (!last_access ||
310       (!last_access->AsIntConstant() && !last_access->AsNullConstant())) {
311     return nullptr;
312   }
313 
314   if (last_access->GetU32() != 0) {
315     return nullptr;
316   }
317 
318   memory_object->GetParent();
319 
320   if (memory_object->GetNumberOfMembers() !=
321       conststruct_inst->NumInOperands()) {
322     return nullptr;
323   }
324 
325   for (uint32_t i = 1; i < conststruct_inst->NumInOperands(); ++i) {
326     std::unique_ptr<MemoryObject> member_object =
327         GetSourceObjectIfAny(conststruct_inst->GetSingleWordInOperand(i));
328 
329     if (!member_object) {
330       return nullptr;
331     }
332 
333     if (!member_object->IsMember()) {
334       return nullptr;
335     }
336 
337     if (!memory_object->Contains(member_object.get())) {
338       return nullptr;
339     }
340 
341     last_access =
342         const_mgr->FindDeclaredConstant(member_object->AccessChain().back());
343     if (!last_access || !last_access->AsIntConstant()) {
344       return nullptr;
345     }
346 
347     if (last_access->GetU32() != i) {
348       return nullptr;
349     }
350   }
351   return memory_object;
352 }
353 
354 std::unique_ptr<CopyPropagateArrays::MemoryObject>
BuildMemoryObjectFromInsert(Instruction * insert_inst)355 CopyPropagateArrays::BuildMemoryObjectFromInsert(Instruction* insert_inst) {
356   assert(insert_inst->opcode() == SpvOpCompositeInsert &&
357          "Expecting an OpCompositeInsert instruction.");
358 
359   analysis::DefUseManager* def_use_mgr = context()->get_def_use_mgr();
360   analysis::TypeManager* type_mgr = context()->get_type_mgr();
361   analysis::ConstantManager* const_mgr = context()->get_constant_mgr();
362   const analysis::Type* result_type = type_mgr->GetType(insert_inst->type_id());
363 
364   uint32_t number_of_elements = 0;
365   if (const analysis::Struct* struct_type = result_type->AsStruct()) {
366     number_of_elements =
367         static_cast<uint32_t>(struct_type->element_types().size());
368   } else if (const analysis::Array* array_type = result_type->AsArray()) {
369     const analysis::Constant* length_const =
370         const_mgr->FindDeclaredConstant(array_type->LengthId());
371     assert(length_const->AsIntConstant());
372     number_of_elements = length_const->AsIntConstant()->GetU32();
373   } else if (const analysis::Vector* vector_type = result_type->AsVector()) {
374     number_of_elements = vector_type->element_count();
375   } else if (const analysis::Matrix* matrix_type = result_type->AsMatrix()) {
376     number_of_elements = matrix_type->element_count();
377   }
378 
379   if (number_of_elements == 0) {
380     return nullptr;
381   }
382 
383   if (insert_inst->NumInOperands() != 3) {
384     return nullptr;
385   }
386 
387   if (insert_inst->GetSingleWordInOperand(2) != number_of_elements - 1) {
388     return nullptr;
389   }
390 
391   std::unique_ptr<MemoryObject> memory_object =
392       GetSourceObjectIfAny(insert_inst->GetSingleWordInOperand(0));
393 
394   if (!memory_object) {
395     return nullptr;
396   }
397 
398   if (!memory_object->IsMember()) {
399     return nullptr;
400   }
401 
402   const analysis::Constant* last_access =
403       const_mgr->FindDeclaredConstant(memory_object->AccessChain().back());
404   if (!last_access || !last_access->AsIntConstant()) {
405     return nullptr;
406   }
407 
408   if (last_access->GetU32() != number_of_elements - 1) {
409     return nullptr;
410   }
411 
412   memory_object->GetParent();
413 
414   Instruction* current_insert =
415       def_use_mgr->GetDef(insert_inst->GetSingleWordInOperand(1));
416   for (uint32_t i = number_of_elements - 1; i > 0; --i) {
417     if (current_insert->opcode() != SpvOpCompositeInsert) {
418       return nullptr;
419     }
420 
421     if (current_insert->NumInOperands() != 3) {
422       return nullptr;
423     }
424 
425     if (current_insert->GetSingleWordInOperand(2) != i - 1) {
426       return nullptr;
427     }
428 
429     std::unique_ptr<MemoryObject> current_memory_object =
430         GetSourceObjectIfAny(current_insert->GetSingleWordInOperand(0));
431 
432     if (!current_memory_object) {
433       return nullptr;
434     }
435 
436     if (!current_memory_object->IsMember()) {
437       return nullptr;
438     }
439 
440     if (memory_object->AccessChain().size() + 1 !=
441         current_memory_object->AccessChain().size()) {
442       return nullptr;
443     }
444 
445     if (!memory_object->Contains(current_memory_object.get())) {
446       return nullptr;
447     }
448 
449     const analysis::Constant* current_last_access =
450         const_mgr->FindDeclaredConstant(
451             current_memory_object->AccessChain().back());
452     if (!current_last_access || !current_last_access->AsIntConstant()) {
453       return nullptr;
454     }
455 
456     if (current_last_access->GetU32() != i - 1) {
457       return nullptr;
458     }
459     current_insert =
460         def_use_mgr->GetDef(current_insert->GetSingleWordInOperand(1));
461   }
462 
463   return memory_object;
464 }
465 
IsPointerToArrayType(uint32_t type_id)466 bool CopyPropagateArrays::IsPointerToArrayType(uint32_t type_id) {
467   analysis::TypeManager* type_mgr = context()->get_type_mgr();
468   analysis::Pointer* pointer_type = type_mgr->GetType(type_id)->AsPointer();
469   if (pointer_type) {
470     return pointer_type->pointee_type()->kind() == analysis::Type::kArray ||
471            pointer_type->pointee_type()->kind() == analysis::Type::kImage;
472   }
473   return false;
474 }
475 
CanUpdateUses(Instruction * original_ptr_inst,uint32_t type_id)476 bool CopyPropagateArrays::CanUpdateUses(Instruction* original_ptr_inst,
477                                         uint32_t type_id) {
478   analysis::TypeManager* type_mgr = context()->get_type_mgr();
479   analysis::ConstantManager* const_mgr = context()->get_constant_mgr();
480   analysis::DefUseManager* def_use_mgr = context()->get_def_use_mgr();
481 
482   analysis::Type* type = type_mgr->GetType(type_id);
483   if (type->AsRuntimeArray()) {
484     return false;
485   }
486 
487   if (!type->AsStruct() && !type->AsArray() && !type->AsPointer()) {
488     // If the type is not an aggregate, then the desired type must be the
489     // same as the current type.  No work to do, and we can do that.
490     return true;
491   }
492 
493   return def_use_mgr->WhileEachUse(original_ptr_inst, [this, type_mgr,
494                                                        const_mgr,
495                                                        type](Instruction* use,
496                                                              uint32_t) {
497     switch (use->opcode()) {
498       case SpvOpLoad: {
499         analysis::Pointer* pointer_type = type->AsPointer();
500         uint32_t new_type_id = type_mgr->GetId(pointer_type->pointee_type());
501 
502         if (new_type_id != use->type_id()) {
503           return CanUpdateUses(use, new_type_id);
504         }
505         return true;
506       }
507       case SpvOpAccessChain: {
508         analysis::Pointer* pointer_type = type->AsPointer();
509         const analysis::Type* pointee_type = pointer_type->pointee_type();
510 
511         std::vector<uint32_t> access_chain;
512         for (uint32_t i = 1; i < use->NumInOperands(); ++i) {
513           const analysis::Constant* index_const =
514               const_mgr->FindDeclaredConstant(use->GetSingleWordInOperand(i));
515           if (index_const) {
516             access_chain.push_back(index_const->AsIntConstant()->GetU32());
517           } else {
518             // Variable index means the type is a type where every element
519             // is the same type.  Use element 0 to get the type.
520             access_chain.push_back(0);
521           }
522         }
523 
524         const analysis::Type* new_pointee_type =
525             type_mgr->GetMemberType(pointee_type, access_chain);
526         analysis::Pointer pointerTy(new_pointee_type,
527                                     pointer_type->storage_class());
528         uint32_t new_pointer_type_id =
529             context()->get_type_mgr()->GetTypeInstruction(&pointerTy);
530 
531         if (new_pointer_type_id != use->type_id()) {
532           return CanUpdateUses(use, new_pointer_type_id);
533         }
534         return true;
535       }
536       case SpvOpCompositeExtract: {
537         std::vector<uint32_t> access_chain;
538         for (uint32_t i = 1; i < use->NumInOperands(); ++i) {
539           access_chain.push_back(use->GetSingleWordInOperand(i));
540         }
541 
542         const analysis::Type* new_type =
543             type_mgr->GetMemberType(type, access_chain);
544         uint32_t new_type_id = type_mgr->GetTypeInstruction(new_type);
545 
546         if (new_type_id != use->type_id()) {
547           return CanUpdateUses(use, new_type_id);
548         }
549         return true;
550       }
551       case SpvOpStore:
552         // If needed, we can create an element-by-element copy to change the
553         // type of the value being stored.  This way we can always handled
554         // stores.
555         return true;
556       case SpvOpImageTexelPointer:
557       case SpvOpName:
558         return true;
559       default:
560         return use->IsDecoration();
561     }
562   });
563 }
UpdateUses(Instruction * original_ptr_inst,Instruction * new_ptr_inst)564 void CopyPropagateArrays::UpdateUses(Instruction* original_ptr_inst,
565                                      Instruction* new_ptr_inst) {
566   // TODO (s-perron): Keep the def-use manager up to date.  Not done now because
567   // it can cause problems for the |ForEachUse| traversals.  Can be use by
568   // keeping a list of instructions that need updating, and then updating them
569   // in |PropagateObject|.
570 
571   analysis::TypeManager* type_mgr = context()->get_type_mgr();
572   analysis::ConstantManager* const_mgr = context()->get_constant_mgr();
573   analysis::DefUseManager* def_use_mgr = context()->get_def_use_mgr();
574 
575   std::vector<std::pair<Instruction*, uint32_t> > uses;
576   def_use_mgr->ForEachUse(original_ptr_inst,
577                           [&uses](Instruction* use, uint32_t index) {
578                             uses.push_back({use, index});
579                           });
580 
581   for (auto pair : uses) {
582     Instruction* use = pair.first;
583     uint32_t index = pair.second;
584     switch (use->opcode()) {
585       case SpvOpLoad: {
586         // Replace the actual use.
587         context()->ForgetUses(use);
588         use->SetOperand(index, {new_ptr_inst->result_id()});
589 
590         // Update the type.
591         Instruction* pointer_type_inst =
592             def_use_mgr->GetDef(new_ptr_inst->type_id());
593         uint32_t new_type_id =
594             pointer_type_inst->GetSingleWordInOperand(kTypePointerPointeeInIdx);
595         if (new_type_id != use->type_id()) {
596           use->SetResultType(new_type_id);
597           context()->AnalyzeUses(use);
598           UpdateUses(use, use);
599         } else {
600           context()->AnalyzeUses(use);
601         }
602       } break;
603       case SpvOpAccessChain: {
604         // Update the actual use.
605         context()->ForgetUses(use);
606         use->SetOperand(index, {new_ptr_inst->result_id()});
607 
608         // Convert the ids on the OpAccessChain to indices that can be used to
609         // get the specific member.
610         std::vector<uint32_t> access_chain;
611         for (uint32_t i = 1; i < use->NumInOperands(); ++i) {
612           const analysis::Constant* index_const =
613               const_mgr->FindDeclaredConstant(use->GetSingleWordInOperand(i));
614           if (index_const) {
615             access_chain.push_back(index_const->AsIntConstant()->GetU32());
616           } else {
617             // Variable index means the type is an type where every element
618             // is the same type.  Use element 0 to get the type.
619             access_chain.push_back(0);
620           }
621         }
622 
623         Instruction* pointer_type_inst =
624             get_def_use_mgr()->GetDef(new_ptr_inst->type_id());
625 
626         uint32_t new_pointee_type_id = GetMemberTypeId(
627             pointer_type_inst->GetSingleWordInOperand(kTypePointerPointeeInIdx),
628             access_chain);
629 
630         SpvStorageClass storage_class = static_cast<SpvStorageClass>(
631             pointer_type_inst->GetSingleWordInOperand(
632                 kTypePointerStorageClassInIdx));
633 
634         uint32_t new_pointer_type_id =
635             type_mgr->FindPointerToType(new_pointee_type_id, storage_class);
636 
637         if (new_pointer_type_id != use->type_id()) {
638           use->SetResultType(new_pointer_type_id);
639           context()->AnalyzeUses(use);
640           UpdateUses(use, use);
641         } else {
642           context()->AnalyzeUses(use);
643         }
644       } break;
645       case SpvOpCompositeExtract: {
646         // Update the actual use.
647         context()->ForgetUses(use);
648         use->SetOperand(index, {new_ptr_inst->result_id()});
649 
650         uint32_t new_type_id = new_ptr_inst->type_id();
651         std::vector<uint32_t> access_chain;
652         for (uint32_t i = 1; i < use->NumInOperands(); ++i) {
653           access_chain.push_back(use->GetSingleWordInOperand(i));
654         }
655 
656         new_type_id = GetMemberTypeId(new_type_id, access_chain);
657 
658         if (new_type_id != use->type_id()) {
659           use->SetResultType(new_type_id);
660           context()->AnalyzeUses(use);
661           UpdateUses(use, use);
662         } else {
663           context()->AnalyzeUses(use);
664         }
665       } break;
666       case SpvOpStore:
667         // If the use is the pointer, then it is the single store to that
668         // variable.  We do not want to replace it.  Instead, it will become
669         // dead after all of the loads are removed, and ADCE will get rid of it.
670         //
671         // If the use is the object being stored, we will create a copy of the
672         // object turning it into the correct type. The copy is done by
673         // decomposing the object into the base type, which must be the same,
674         // and then rebuilding them.
675         if (index == 1) {
676           Instruction* target_pointer = def_use_mgr->GetDef(
677               use->GetSingleWordInOperand(kStorePointerInOperand));
678           Instruction* pointer_type =
679               def_use_mgr->GetDef(target_pointer->type_id());
680           uint32_t pointee_type_id =
681               pointer_type->GetSingleWordInOperand(kTypePointerPointeeInIdx);
682           uint32_t copy = GenerateCopy(original_ptr_inst, pointee_type_id, use);
683 
684           context()->ForgetUses(use);
685           use->SetInOperand(index, {copy});
686           context()->AnalyzeUses(use);
687         }
688         break;
689       case SpvOpImageTexelPointer:
690         // We treat an OpImageTexelPointer as a load.  The result type should
691         // always have the Image storage class, and should not need to be
692         // updated.
693 
694         // Replace the actual use.
695         context()->ForgetUses(use);
696         use->SetOperand(index, {new_ptr_inst->result_id()});
697         context()->AnalyzeUses(use);
698         break;
699       default:
700         assert(false && "Don't know how to rewrite instruction");
701         break;
702     }
703   }
704 }
705 
GenerateCopy(Instruction * object_inst,uint32_t new_type_id,Instruction * insertion_position)706 uint32_t CopyPropagateArrays::GenerateCopy(Instruction* object_inst,
707                                            uint32_t new_type_id,
708                                            Instruction* insertion_position) {
709   analysis::TypeManager* type_mgr = context()->get_type_mgr();
710   analysis::ConstantManager* const_mgr = context()->get_constant_mgr();
711 
712   uint32_t original_type_id = object_inst->type_id();
713   if (original_type_id == new_type_id) {
714     return object_inst->result_id();
715   }
716 
717   InstructionBuilder ir_builder(
718       context(), insertion_position,
719       IRContext::kAnalysisInstrToBlockMapping | IRContext::kAnalysisDefUse);
720 
721   analysis::Type* original_type = type_mgr->GetType(original_type_id);
722   analysis::Type* new_type = type_mgr->GetType(new_type_id);
723 
724   if (const analysis::Array* original_array_type = original_type->AsArray()) {
725     uint32_t original_element_type_id =
726         type_mgr->GetId(original_array_type->element_type());
727 
728     analysis::Array* new_array_type = new_type->AsArray();
729     assert(new_array_type != nullptr && "Can't copy an array to a non-array.");
730     uint32_t new_element_type_id =
731         type_mgr->GetId(new_array_type->element_type());
732 
733     std::vector<uint32_t> element_ids;
734     const analysis::Constant* length_const =
735         const_mgr->FindDeclaredConstant(original_array_type->LengthId());
736     assert(length_const->AsIntConstant());
737     uint32_t array_length = length_const->AsIntConstant()->GetU32();
738     for (uint32_t i = 0; i < array_length; i++) {
739       Instruction* extract = ir_builder.AddCompositeExtract(
740           original_element_type_id, object_inst->result_id(), {i});
741       element_ids.push_back(
742           GenerateCopy(extract, new_element_type_id, insertion_position));
743     }
744 
745     return ir_builder.AddCompositeConstruct(new_type_id, element_ids)
746         ->result_id();
747   } else if (const analysis::Struct* original_struct_type =
748                  original_type->AsStruct()) {
749     analysis::Struct* new_struct_type = new_type->AsStruct();
750 
751     const std::vector<const analysis::Type*>& original_types =
752         original_struct_type->element_types();
753     const std::vector<const analysis::Type*>& new_types =
754         new_struct_type->element_types();
755     std::vector<uint32_t> element_ids;
756     for (uint32_t i = 0; i < original_types.size(); i++) {
757       Instruction* extract = ir_builder.AddCompositeExtract(
758           type_mgr->GetId(original_types[i]), object_inst->result_id(), {i});
759       element_ids.push_back(GenerateCopy(extract, type_mgr->GetId(new_types[i]),
760                                          insertion_position));
761     }
762     return ir_builder.AddCompositeConstruct(new_type_id, element_ids)
763         ->result_id();
764   } else {
765     // If we do not have an aggregate type, then we have a problem.  Either we
766     // found multiple instances of the same type, or we are copying to an
767     // incompatible type.  Either way the code is illegal.
768     assert(false &&
769            "Don't know how to copy this type.  Code is likely illegal.");
770   }
771   return 0;
772 }
773 
GetMemberTypeId(uint32_t id,const std::vector<uint32_t> & access_chain) const774 uint32_t CopyPropagateArrays::GetMemberTypeId(
775     uint32_t id, const std::vector<uint32_t>& access_chain) const {
776   for (uint32_t element_index : access_chain) {
777     Instruction* type_inst = get_def_use_mgr()->GetDef(id);
778     switch (type_inst->opcode()) {
779       case SpvOpTypeArray:
780       case SpvOpTypeRuntimeArray:
781       case SpvOpTypeMatrix:
782       case SpvOpTypeVector:
783         id = type_inst->GetSingleWordInOperand(0);
784         break;
785       case SpvOpTypeStruct:
786         id = type_inst->GetSingleWordInOperand(element_index);
787         break;
788       default:
789         break;
790     }
791     assert(id != 0 &&
792            "Tried to extract from an object where it cannot be done.");
793   }
794   return id;
795 }
796 
GetMember(const std::vector<uint32_t> & access_chain)797 void CopyPropagateArrays::MemoryObject::GetMember(
798     const std::vector<uint32_t>& access_chain) {
799   access_chain_.insert(access_chain_.end(), access_chain.begin(),
800                        access_chain.end());
801 }
802 
GetNumberOfMembers()803 uint32_t CopyPropagateArrays::MemoryObject::GetNumberOfMembers() {
804   IRContext* context = variable_inst_->context();
805   analysis::TypeManager* type_mgr = context->get_type_mgr();
806 
807   const analysis::Type* type = type_mgr->GetType(variable_inst_->type_id());
808   type = type->AsPointer()->pointee_type();
809 
810   std::vector<uint32_t> access_indices = GetAccessIds();
811   type = type_mgr->GetMemberType(type, access_indices);
812 
813   if (const analysis::Struct* struct_type = type->AsStruct()) {
814     return static_cast<uint32_t>(struct_type->element_types().size());
815   } else if (const analysis::Array* array_type = type->AsArray()) {
816     const analysis::Constant* length_const =
817         context->get_constant_mgr()->FindDeclaredConstant(
818             array_type->LengthId());
819     assert(length_const->AsIntConstant());
820     return length_const->AsIntConstant()->GetU32();
821   } else if (const analysis::Vector* vector_type = type->AsVector()) {
822     return vector_type->element_count();
823   } else if (const analysis::Matrix* matrix_type = type->AsMatrix()) {
824     return matrix_type->element_count();
825   } else {
826     return 0;
827   }
828 }
829 
830 template <class iterator>
MemoryObject(Instruction * var_inst,iterator begin,iterator end)831 CopyPropagateArrays::MemoryObject::MemoryObject(Instruction* var_inst,
832                                                 iterator begin, iterator end)
833     : variable_inst_(var_inst), access_chain_(begin, end) {}
834 
GetAccessIds() const835 std::vector<uint32_t> CopyPropagateArrays::MemoryObject::GetAccessIds() const {
836   analysis::ConstantManager* const_mgr =
837       variable_inst_->context()->get_constant_mgr();
838 
839   std::vector<uint32_t> access_indices;
840   for (uint32_t id : AccessChain()) {
841     const analysis::Constant* element_index_const =
842         const_mgr->FindDeclaredConstant(id);
843     if (!element_index_const) {
844       access_indices.push_back(0);
845     } else {
846       assert(element_index_const->AsIntConstant());
847       access_indices.push_back(element_index_const->AsIntConstant()->GetU32());
848     }
849   }
850   return access_indices;
851 }
852 
Contains(CopyPropagateArrays::MemoryObject * other)853 bool CopyPropagateArrays::MemoryObject::Contains(
854     CopyPropagateArrays::MemoryObject* other) {
855   if (this->GetVariable() != other->GetVariable()) {
856     return false;
857   }
858 
859   if (AccessChain().size() > other->AccessChain().size()) {
860     return false;
861   }
862 
863   for (uint32_t i = 0; i < AccessChain().size(); i++) {
864     if (AccessChain()[i] != other->AccessChain()[i]) {
865       return false;
866     }
867   }
868   return true;
869 }
870 
871 }  // namespace opt
872 }  // namespace spvtools
873