1 // Copyright (c) 2018 Google LLC.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 #include "source/opt/copy_prop_arrays.h"
16
17 #include <utility>
18
19 #include "source/opt/ir_builder.h"
20
21 namespace spvtools {
22 namespace opt {
namespace {

// In-operand indices (operands after the optional result type / result id)
// for the instructions this pass inspects.
constexpr uint32_t kLoadPointerInOperand = 0;
constexpr uint32_t kStorePointerInOperand = 0;
constexpr uint32_t kStoreObjectInOperand = 1;
constexpr uint32_t kCompositeExtractObjectInOperand = 0;

// In-operand indices for OpTypePointer.
constexpr uint32_t kTypePointerStorageClassInIdx = 0;
constexpr uint32_t kTypePointerPointeeInIdx = 1;

}  // namespace
33
Process()34 Pass::Status CopyPropagateArrays::Process() {
35 bool modified = false;
36 for (Function& function : *get_module()) {
37 BasicBlock* entry_bb = &*function.begin();
38
39 for (auto var_inst = entry_bb->begin(); var_inst->opcode() == SpvOpVariable;
40 ++var_inst) {
41 if (!IsPointerToArrayType(var_inst->type_id())) {
42 continue;
43 }
44
45 // Find the only store to the entire memory location, if it exists.
46 Instruction* store_inst = FindStoreInstruction(&*var_inst);
47
48 if (!store_inst) {
49 continue;
50 }
51
52 std::unique_ptr<MemoryObject> source_object =
53 FindSourceObjectIfPossible(&*var_inst, store_inst);
54
55 if (source_object != nullptr) {
56 if (CanUpdateUses(&*var_inst, source_object->GetPointerTypeId(this))) {
57 modified = true;
58 PropagateObject(&*var_inst, source_object.get(), store_inst);
59 }
60 }
61 }
62 }
63 return (modified ? Status::SuccessWithChange : Status::SuccessWithoutChange);
64 }
65
66 std::unique_ptr<CopyPropagateArrays::MemoryObject>
FindSourceObjectIfPossible(Instruction * var_inst,Instruction * store_inst)67 CopyPropagateArrays::FindSourceObjectIfPossible(Instruction* var_inst,
68 Instruction* store_inst) {
69 assert(var_inst->opcode() == SpvOpVariable && "Expecting a variable.");
70
71 // Check that the variable is a composite object where |store_inst|
72 // dominates all of its loads.
73 if (!store_inst) {
74 return nullptr;
75 }
76
77 // Look at the loads to ensure they are dominated by the store.
78 if (!HasValidReferencesOnly(var_inst, store_inst)) {
79 return nullptr;
80 }
81
82 // If so, look at the store to see if it is the copy of an object.
83 std::unique_ptr<MemoryObject> source = GetSourceObjectIfAny(
84 store_inst->GetSingleWordInOperand(kStoreObjectInOperand));
85
86 if (!source) {
87 return nullptr;
88 }
89
90 // Ensure that |source| does not change between the point at which it is
91 // loaded, and the position in which |var_inst| is loaded.
92 //
93 // For now we will go with the easy to implement approach, and check that the
94 // entire variable (not just the specific component) is never written to.
95
96 if (!HasNoStores(source->GetVariable())) {
97 return nullptr;
98 }
99 return source;
100 }
101
// Returns the unique OpStore whose pointer operand is |var_inst|.  Returns
// nullptr when there is no such store, or when there is more than one.
Instruction* CopyPropagateArrays::FindStoreInstruction(
    const Instruction* var_inst) const {
  const uint32_t var_id = var_inst->result_id();
  Instruction* unique_store = nullptr;
  bool found_several = false;

  get_def_use_mgr()->WhileEachUser(var_inst, [&](Instruction* user) {
    const bool stores_to_var =
        user->opcode() == SpvOpStore &&
        user->GetSingleWordInOperand(kStorePointerInOperand) == var_id;
    if (!stores_to_var) {
      return true;
    }
    if (unique_store != nullptr) {
      // A second store: there is no single store, so stop the traversal.
      found_several = true;
      return false;
    }
    unique_store = user;
    return true;
  });

  return found_several ? nullptr : unique_store;
}
121
PropagateObject(Instruction * var_inst,MemoryObject * source,Instruction * insertion_point)122 void CopyPropagateArrays::PropagateObject(Instruction* var_inst,
123 MemoryObject* source,
124 Instruction* insertion_point) {
125 assert(var_inst->opcode() == SpvOpVariable &&
126 "This function propagates variables.");
127
128 Instruction* new_access_chain = BuildNewAccessChain(insertion_point, source);
129 context()->KillNamesAndDecorates(var_inst);
130 UpdateUses(var_inst, new_access_chain);
131 }
132
BuildNewAccessChain(Instruction * insertion_point,CopyPropagateArrays::MemoryObject * source) const133 Instruction* CopyPropagateArrays::BuildNewAccessChain(
134 Instruction* insertion_point,
135 CopyPropagateArrays::MemoryObject* source) const {
136 InstructionBuilder builder(
137 context(), insertion_point,
138 IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);
139
140 if (source->AccessChain().size() == 0) {
141 return source->GetVariable();
142 }
143
144 return builder.AddAccessChain(source->GetPointerTypeId(this),
145 source->GetVariable()->result_id(),
146 source->AccessChain());
147 }
148
HasNoStores(Instruction * ptr_inst)149 bool CopyPropagateArrays::HasNoStores(Instruction* ptr_inst) {
150 return get_def_use_mgr()->WhileEachUser(ptr_inst, [this](Instruction* use) {
151 if (use->opcode() == SpvOpLoad) {
152 return true;
153 } else if (use->opcode() == SpvOpAccessChain) {
154 return HasNoStores(use);
155 } else if (use->IsDecoration() || use->opcode() == SpvOpName) {
156 return true;
157 } else if (use->opcode() == SpvOpStore) {
158 return false;
159 } else if (use->opcode() == SpvOpImageTexelPointer) {
160 return true;
161 }
162 // Some other instruction. Be conservative.
163 return false;
164 });
165 }
166
HasValidReferencesOnly(Instruction * ptr_inst,Instruction * store_inst)167 bool CopyPropagateArrays::HasValidReferencesOnly(Instruction* ptr_inst,
168 Instruction* store_inst) {
169 BasicBlock* store_block = context()->get_instr_block(store_inst);
170 DominatorAnalysis* dominator_analysis =
171 context()->GetDominatorAnalysis(store_block->GetParent());
172
173 return get_def_use_mgr()->WhileEachUser(
174 ptr_inst,
175 [this, store_inst, dominator_analysis, ptr_inst](Instruction* use) {
176 if (use->opcode() == SpvOpLoad ||
177 use->opcode() == SpvOpImageTexelPointer) {
178 // TODO: If there are many load in the same BB as |store_inst| the
179 // time to do the multiple traverses can add up. Consider collecting
180 // those loads and doing a single traversal.
181 return dominator_analysis->Dominates(store_inst, use);
182 } else if (use->opcode() == SpvOpAccessChain) {
183 return HasValidReferencesOnly(use, store_inst);
184 } else if (use->IsDecoration() || use->opcode() == SpvOpName) {
185 return true;
186 } else if (use->opcode() == SpvOpStore) {
187 // If we are storing to part of the object it is not an candidate.
188 return ptr_inst->opcode() == SpvOpVariable &&
189 store_inst->GetSingleWordInOperand(kStorePointerInOperand) ==
190 ptr_inst->result_id();
191 }
192 // Some other instruction. Be conservative.
193 return false;
194 });
195 }
196
197 std::unique_ptr<CopyPropagateArrays::MemoryObject>
GetSourceObjectIfAny(uint32_t result)198 CopyPropagateArrays::GetSourceObjectIfAny(uint32_t result) {
199 Instruction* result_inst = context()->get_def_use_mgr()->GetDef(result);
200
201 switch (result_inst->opcode()) {
202 case SpvOpLoad:
203 return BuildMemoryObjectFromLoad(result_inst);
204 case SpvOpCompositeExtract:
205 return BuildMemoryObjectFromExtract(result_inst);
206 case SpvOpCompositeConstruct:
207 return BuildMemoryObjectFromCompositeConstruct(result_inst);
208 case SpvOpCopyObject:
209 return GetSourceObjectIfAny(result_inst->GetSingleWordInOperand(0));
210 case SpvOpCompositeInsert:
211 return BuildMemoryObjectFromInsert(result_inst);
212 default:
213 return nullptr;
214 }
215 }
216
217 std::unique_ptr<CopyPropagateArrays::MemoryObject>
BuildMemoryObjectFromLoad(Instruction * load_inst)218 CopyPropagateArrays::BuildMemoryObjectFromLoad(Instruction* load_inst) {
219 std::vector<uint32_t> components_in_reverse;
220 analysis::DefUseManager* def_use_mgr = context()->get_def_use_mgr();
221
222 Instruction* current_inst = def_use_mgr->GetDef(
223 load_inst->GetSingleWordInOperand(kLoadPointerInOperand));
224
225 // Build the access chain for the memory object by collecting the indices used
226 // in the OpAccessChain instructions. If we find a variable index, then
227 // return |nullptr| because we cannot know for sure which memory location is
228 // used.
229 //
230 // It is built in reverse order because the different |OpAccessChain|
231 // instructions are visited in reverse order from which they are applied.
232 while (current_inst->opcode() == SpvOpAccessChain) {
233 for (uint32_t i = current_inst->NumInOperands() - 1; i >= 1; --i) {
234 uint32_t element_index_id = current_inst->GetSingleWordInOperand(i);
235 components_in_reverse.push_back(element_index_id);
236 }
237 current_inst = def_use_mgr->GetDef(current_inst->GetSingleWordInOperand(0));
238 }
239
240 // If the address in the load is not constructed from an |OpVariable|
241 // instruction followed by a series of |OpAccessChain| instructions, then
242 // return |nullptr| because we cannot identify the owner or access chain
243 // exactly.
244 if (current_inst->opcode() != SpvOpVariable) {
245 return nullptr;
246 }
247
248 // Build the memory object. Use |rbegin| and |rend| to put the access chain
249 // back in the correct order.
250 return std::unique_ptr<CopyPropagateArrays::MemoryObject>(
251 new MemoryObject(current_inst, components_in_reverse.rbegin(),
252 components_in_reverse.rend()));
253 }
254
255 std::unique_ptr<CopyPropagateArrays::MemoryObject>
BuildMemoryObjectFromExtract(Instruction * extract_inst)256 CopyPropagateArrays::BuildMemoryObjectFromExtract(Instruction* extract_inst) {
257 assert(extract_inst->opcode() == SpvOpCompositeExtract &&
258 "Expecting an OpCompositeExtract instruction.");
259 analysis::ConstantManager* const_mgr = context()->get_constant_mgr();
260
261 std::unique_ptr<MemoryObject> result = GetSourceObjectIfAny(
262 extract_inst->GetSingleWordInOperand(kCompositeExtractObjectInOperand));
263
264 if (result) {
265 analysis::Integer int_type(32, false);
266 const analysis::Type* uint32_type =
267 context()->get_type_mgr()->GetRegisteredType(&int_type);
268
269 std::vector<uint32_t> components;
270 // Convert the indices in the extract instruction to a series of ids that
271 // can be used by the |OpAccessChain| instruction.
272 for (uint32_t i = 1; i < extract_inst->NumInOperands(); ++i) {
273 uint32_t index = extract_inst->GetSingleWordInOperand(i);
274 const analysis::Constant* index_const =
275 const_mgr->GetConstant(uint32_type, {index});
276 components.push_back(
277 const_mgr->GetDefiningInstruction(index_const)->result_id());
278 }
279 result->GetMember(components);
280 return result;
281 }
282 return nullptr;
283 }
284
285 std::unique_ptr<CopyPropagateArrays::MemoryObject>
BuildMemoryObjectFromCompositeConstruct(Instruction * conststruct_inst)286 CopyPropagateArrays::BuildMemoryObjectFromCompositeConstruct(
287 Instruction* conststruct_inst) {
288 assert(conststruct_inst->opcode() == SpvOpCompositeConstruct &&
289 "Expecting an OpCompositeConstruct instruction.");
290
291 // If every operand in the instruction are part of the same memory object, and
292 // are being combined in the same order, then the result is the same as the
293 // parent.
294
295 std::unique_ptr<MemoryObject> memory_object =
296 GetSourceObjectIfAny(conststruct_inst->GetSingleWordInOperand(0));
297
298 if (!memory_object) {
299 return nullptr;
300 }
301
302 if (!memory_object->IsMember()) {
303 return nullptr;
304 }
305
306 analysis::ConstantManager* const_mgr = context()->get_constant_mgr();
307 const analysis::Constant* last_access =
308 const_mgr->FindDeclaredConstant(memory_object->AccessChain().back());
309 if (!last_access ||
310 (!last_access->AsIntConstant() && !last_access->AsNullConstant())) {
311 return nullptr;
312 }
313
314 if (last_access->GetU32() != 0) {
315 return nullptr;
316 }
317
318 memory_object->GetParent();
319
320 if (memory_object->GetNumberOfMembers() !=
321 conststruct_inst->NumInOperands()) {
322 return nullptr;
323 }
324
325 for (uint32_t i = 1; i < conststruct_inst->NumInOperands(); ++i) {
326 std::unique_ptr<MemoryObject> member_object =
327 GetSourceObjectIfAny(conststruct_inst->GetSingleWordInOperand(i));
328
329 if (!member_object) {
330 return nullptr;
331 }
332
333 if (!member_object->IsMember()) {
334 return nullptr;
335 }
336
337 if (!memory_object->Contains(member_object.get())) {
338 return nullptr;
339 }
340
341 last_access =
342 const_mgr->FindDeclaredConstant(member_object->AccessChain().back());
343 if (!last_access || !last_access->AsIntConstant()) {
344 return nullptr;
345 }
346
347 if (last_access->GetU32() != i) {
348 return nullptr;
349 }
350 }
351 return memory_object;
352 }
353
354 std::unique_ptr<CopyPropagateArrays::MemoryObject>
BuildMemoryObjectFromInsert(Instruction * insert_inst)355 CopyPropagateArrays::BuildMemoryObjectFromInsert(Instruction* insert_inst) {
356 assert(insert_inst->opcode() == SpvOpCompositeInsert &&
357 "Expecting an OpCompositeInsert instruction.");
358
359 analysis::DefUseManager* def_use_mgr = context()->get_def_use_mgr();
360 analysis::TypeManager* type_mgr = context()->get_type_mgr();
361 analysis::ConstantManager* const_mgr = context()->get_constant_mgr();
362 const analysis::Type* result_type = type_mgr->GetType(insert_inst->type_id());
363
364 uint32_t number_of_elements = 0;
365 if (const analysis::Struct* struct_type = result_type->AsStruct()) {
366 number_of_elements =
367 static_cast<uint32_t>(struct_type->element_types().size());
368 } else if (const analysis::Array* array_type = result_type->AsArray()) {
369 const analysis::Constant* length_const =
370 const_mgr->FindDeclaredConstant(array_type->LengthId());
371 assert(length_const->AsIntConstant());
372 number_of_elements = length_const->AsIntConstant()->GetU32();
373 } else if (const analysis::Vector* vector_type = result_type->AsVector()) {
374 number_of_elements = vector_type->element_count();
375 } else if (const analysis::Matrix* matrix_type = result_type->AsMatrix()) {
376 number_of_elements = matrix_type->element_count();
377 }
378
379 if (number_of_elements == 0) {
380 return nullptr;
381 }
382
383 if (insert_inst->NumInOperands() != 3) {
384 return nullptr;
385 }
386
387 if (insert_inst->GetSingleWordInOperand(2) != number_of_elements - 1) {
388 return nullptr;
389 }
390
391 std::unique_ptr<MemoryObject> memory_object =
392 GetSourceObjectIfAny(insert_inst->GetSingleWordInOperand(0));
393
394 if (!memory_object) {
395 return nullptr;
396 }
397
398 if (!memory_object->IsMember()) {
399 return nullptr;
400 }
401
402 const analysis::Constant* last_access =
403 const_mgr->FindDeclaredConstant(memory_object->AccessChain().back());
404 if (!last_access || !last_access->AsIntConstant()) {
405 return nullptr;
406 }
407
408 if (last_access->GetU32() != number_of_elements - 1) {
409 return nullptr;
410 }
411
412 memory_object->GetParent();
413
414 Instruction* current_insert =
415 def_use_mgr->GetDef(insert_inst->GetSingleWordInOperand(1));
416 for (uint32_t i = number_of_elements - 1; i > 0; --i) {
417 if (current_insert->opcode() != SpvOpCompositeInsert) {
418 return nullptr;
419 }
420
421 if (current_insert->NumInOperands() != 3) {
422 return nullptr;
423 }
424
425 if (current_insert->GetSingleWordInOperand(2) != i - 1) {
426 return nullptr;
427 }
428
429 std::unique_ptr<MemoryObject> current_memory_object =
430 GetSourceObjectIfAny(current_insert->GetSingleWordInOperand(0));
431
432 if (!current_memory_object) {
433 return nullptr;
434 }
435
436 if (!current_memory_object->IsMember()) {
437 return nullptr;
438 }
439
440 if (memory_object->AccessChain().size() + 1 !=
441 current_memory_object->AccessChain().size()) {
442 return nullptr;
443 }
444
445 if (!memory_object->Contains(current_memory_object.get())) {
446 return nullptr;
447 }
448
449 const analysis::Constant* current_last_access =
450 const_mgr->FindDeclaredConstant(
451 current_memory_object->AccessChain().back());
452 if (!current_last_access || !current_last_access->AsIntConstant()) {
453 return nullptr;
454 }
455
456 if (current_last_access->GetU32() != i - 1) {
457 return nullptr;
458 }
459 current_insert =
460 def_use_mgr->GetDef(current_insert->GetSingleWordInOperand(1));
461 }
462
463 return memory_object;
464 }
465
IsPointerToArrayType(uint32_t type_id)466 bool CopyPropagateArrays::IsPointerToArrayType(uint32_t type_id) {
467 analysis::TypeManager* type_mgr = context()->get_type_mgr();
468 analysis::Pointer* pointer_type = type_mgr->GetType(type_id)->AsPointer();
469 if (pointer_type) {
470 return pointer_type->pointee_type()->kind() == analysis::Type::kArray ||
471 pointer_type->pointee_type()->kind() == analysis::Type::kImage;
472 }
473 return false;
474 }
475
CanUpdateUses(Instruction * original_ptr_inst,uint32_t type_id)476 bool CopyPropagateArrays::CanUpdateUses(Instruction* original_ptr_inst,
477 uint32_t type_id) {
478 analysis::TypeManager* type_mgr = context()->get_type_mgr();
479 analysis::ConstantManager* const_mgr = context()->get_constant_mgr();
480 analysis::DefUseManager* def_use_mgr = context()->get_def_use_mgr();
481
482 analysis::Type* type = type_mgr->GetType(type_id);
483 if (type->AsRuntimeArray()) {
484 return false;
485 }
486
487 if (!type->AsStruct() && !type->AsArray() && !type->AsPointer()) {
488 // If the type is not an aggregate, then the desired type must be the
489 // same as the current type. No work to do, and we can do that.
490 return true;
491 }
492
493 return def_use_mgr->WhileEachUse(original_ptr_inst, [this, type_mgr,
494 const_mgr,
495 type](Instruction* use,
496 uint32_t) {
497 switch (use->opcode()) {
498 case SpvOpLoad: {
499 analysis::Pointer* pointer_type = type->AsPointer();
500 uint32_t new_type_id = type_mgr->GetId(pointer_type->pointee_type());
501
502 if (new_type_id != use->type_id()) {
503 return CanUpdateUses(use, new_type_id);
504 }
505 return true;
506 }
507 case SpvOpAccessChain: {
508 analysis::Pointer* pointer_type = type->AsPointer();
509 const analysis::Type* pointee_type = pointer_type->pointee_type();
510
511 std::vector<uint32_t> access_chain;
512 for (uint32_t i = 1; i < use->NumInOperands(); ++i) {
513 const analysis::Constant* index_const =
514 const_mgr->FindDeclaredConstant(use->GetSingleWordInOperand(i));
515 if (index_const) {
516 access_chain.push_back(index_const->AsIntConstant()->GetU32());
517 } else {
518 // Variable index means the type is a type where every element
519 // is the same type. Use element 0 to get the type.
520 access_chain.push_back(0);
521 }
522 }
523
524 const analysis::Type* new_pointee_type =
525 type_mgr->GetMemberType(pointee_type, access_chain);
526 analysis::Pointer pointerTy(new_pointee_type,
527 pointer_type->storage_class());
528 uint32_t new_pointer_type_id =
529 context()->get_type_mgr()->GetTypeInstruction(&pointerTy);
530
531 if (new_pointer_type_id != use->type_id()) {
532 return CanUpdateUses(use, new_pointer_type_id);
533 }
534 return true;
535 }
536 case SpvOpCompositeExtract: {
537 std::vector<uint32_t> access_chain;
538 for (uint32_t i = 1; i < use->NumInOperands(); ++i) {
539 access_chain.push_back(use->GetSingleWordInOperand(i));
540 }
541
542 const analysis::Type* new_type =
543 type_mgr->GetMemberType(type, access_chain);
544 uint32_t new_type_id = type_mgr->GetTypeInstruction(new_type);
545
546 if (new_type_id != use->type_id()) {
547 return CanUpdateUses(use, new_type_id);
548 }
549 return true;
550 }
551 case SpvOpStore:
552 // If needed, we can create an element-by-element copy to change the
553 // type of the value being stored. This way we can always handled
554 // stores.
555 return true;
556 case SpvOpImageTexelPointer:
557 case SpvOpName:
558 return true;
559 default:
560 return use->IsDecoration();
561 }
562 });
563 }
UpdateUses(Instruction * original_ptr_inst,Instruction * new_ptr_inst)564 void CopyPropagateArrays::UpdateUses(Instruction* original_ptr_inst,
565 Instruction* new_ptr_inst) {
566 // TODO (s-perron): Keep the def-use manager up to date. Not done now because
567 // it can cause problems for the |ForEachUse| traversals. Can be use by
568 // keeping a list of instructions that need updating, and then updating them
569 // in |PropagateObject|.
570
571 analysis::TypeManager* type_mgr = context()->get_type_mgr();
572 analysis::ConstantManager* const_mgr = context()->get_constant_mgr();
573 analysis::DefUseManager* def_use_mgr = context()->get_def_use_mgr();
574
575 std::vector<std::pair<Instruction*, uint32_t> > uses;
576 def_use_mgr->ForEachUse(original_ptr_inst,
577 [&uses](Instruction* use, uint32_t index) {
578 uses.push_back({use, index});
579 });
580
581 for (auto pair : uses) {
582 Instruction* use = pair.first;
583 uint32_t index = pair.second;
584 switch (use->opcode()) {
585 case SpvOpLoad: {
586 // Replace the actual use.
587 context()->ForgetUses(use);
588 use->SetOperand(index, {new_ptr_inst->result_id()});
589
590 // Update the type.
591 Instruction* pointer_type_inst =
592 def_use_mgr->GetDef(new_ptr_inst->type_id());
593 uint32_t new_type_id =
594 pointer_type_inst->GetSingleWordInOperand(kTypePointerPointeeInIdx);
595 if (new_type_id != use->type_id()) {
596 use->SetResultType(new_type_id);
597 context()->AnalyzeUses(use);
598 UpdateUses(use, use);
599 } else {
600 context()->AnalyzeUses(use);
601 }
602 } break;
603 case SpvOpAccessChain: {
604 // Update the actual use.
605 context()->ForgetUses(use);
606 use->SetOperand(index, {new_ptr_inst->result_id()});
607
608 // Convert the ids on the OpAccessChain to indices that can be used to
609 // get the specific member.
610 std::vector<uint32_t> access_chain;
611 for (uint32_t i = 1; i < use->NumInOperands(); ++i) {
612 const analysis::Constant* index_const =
613 const_mgr->FindDeclaredConstant(use->GetSingleWordInOperand(i));
614 if (index_const) {
615 access_chain.push_back(index_const->AsIntConstant()->GetU32());
616 } else {
617 // Variable index means the type is an type where every element
618 // is the same type. Use element 0 to get the type.
619 access_chain.push_back(0);
620 }
621 }
622
623 Instruction* pointer_type_inst =
624 get_def_use_mgr()->GetDef(new_ptr_inst->type_id());
625
626 uint32_t new_pointee_type_id = GetMemberTypeId(
627 pointer_type_inst->GetSingleWordInOperand(kTypePointerPointeeInIdx),
628 access_chain);
629
630 SpvStorageClass storage_class = static_cast<SpvStorageClass>(
631 pointer_type_inst->GetSingleWordInOperand(
632 kTypePointerStorageClassInIdx));
633
634 uint32_t new_pointer_type_id =
635 type_mgr->FindPointerToType(new_pointee_type_id, storage_class);
636
637 if (new_pointer_type_id != use->type_id()) {
638 use->SetResultType(new_pointer_type_id);
639 context()->AnalyzeUses(use);
640 UpdateUses(use, use);
641 } else {
642 context()->AnalyzeUses(use);
643 }
644 } break;
645 case SpvOpCompositeExtract: {
646 // Update the actual use.
647 context()->ForgetUses(use);
648 use->SetOperand(index, {new_ptr_inst->result_id()});
649
650 uint32_t new_type_id = new_ptr_inst->type_id();
651 std::vector<uint32_t> access_chain;
652 for (uint32_t i = 1; i < use->NumInOperands(); ++i) {
653 access_chain.push_back(use->GetSingleWordInOperand(i));
654 }
655
656 new_type_id = GetMemberTypeId(new_type_id, access_chain);
657
658 if (new_type_id != use->type_id()) {
659 use->SetResultType(new_type_id);
660 context()->AnalyzeUses(use);
661 UpdateUses(use, use);
662 } else {
663 context()->AnalyzeUses(use);
664 }
665 } break;
666 case SpvOpStore:
667 // If the use is the pointer, then it is the single store to that
668 // variable. We do not want to replace it. Instead, it will become
669 // dead after all of the loads are removed, and ADCE will get rid of it.
670 //
671 // If the use is the object being stored, we will create a copy of the
672 // object turning it into the correct type. The copy is done by
673 // decomposing the object into the base type, which must be the same,
674 // and then rebuilding them.
675 if (index == 1) {
676 Instruction* target_pointer = def_use_mgr->GetDef(
677 use->GetSingleWordInOperand(kStorePointerInOperand));
678 Instruction* pointer_type =
679 def_use_mgr->GetDef(target_pointer->type_id());
680 uint32_t pointee_type_id =
681 pointer_type->GetSingleWordInOperand(kTypePointerPointeeInIdx);
682 uint32_t copy = GenerateCopy(original_ptr_inst, pointee_type_id, use);
683
684 context()->ForgetUses(use);
685 use->SetInOperand(index, {copy});
686 context()->AnalyzeUses(use);
687 }
688 break;
689 case SpvOpImageTexelPointer:
690 // We treat an OpImageTexelPointer as a load. The result type should
691 // always have the Image storage class, and should not need to be
692 // updated.
693
694 // Replace the actual use.
695 context()->ForgetUses(use);
696 use->SetOperand(index, {new_ptr_inst->result_id()});
697 context()->AnalyzeUses(use);
698 break;
699 default:
700 assert(false && "Don't know how to rewrite instruction");
701 break;
702 }
703 }
704 }
705
GenerateCopy(Instruction * object_inst,uint32_t new_type_id,Instruction * insertion_position)706 uint32_t CopyPropagateArrays::GenerateCopy(Instruction* object_inst,
707 uint32_t new_type_id,
708 Instruction* insertion_position) {
709 analysis::TypeManager* type_mgr = context()->get_type_mgr();
710 analysis::ConstantManager* const_mgr = context()->get_constant_mgr();
711
712 uint32_t original_type_id = object_inst->type_id();
713 if (original_type_id == new_type_id) {
714 return object_inst->result_id();
715 }
716
717 InstructionBuilder ir_builder(
718 context(), insertion_position,
719 IRContext::kAnalysisInstrToBlockMapping | IRContext::kAnalysisDefUse);
720
721 analysis::Type* original_type = type_mgr->GetType(original_type_id);
722 analysis::Type* new_type = type_mgr->GetType(new_type_id);
723
724 if (const analysis::Array* original_array_type = original_type->AsArray()) {
725 uint32_t original_element_type_id =
726 type_mgr->GetId(original_array_type->element_type());
727
728 analysis::Array* new_array_type = new_type->AsArray();
729 assert(new_array_type != nullptr && "Can't copy an array to a non-array.");
730 uint32_t new_element_type_id =
731 type_mgr->GetId(new_array_type->element_type());
732
733 std::vector<uint32_t> element_ids;
734 const analysis::Constant* length_const =
735 const_mgr->FindDeclaredConstant(original_array_type->LengthId());
736 assert(length_const->AsIntConstant());
737 uint32_t array_length = length_const->AsIntConstant()->GetU32();
738 for (uint32_t i = 0; i < array_length; i++) {
739 Instruction* extract = ir_builder.AddCompositeExtract(
740 original_element_type_id, object_inst->result_id(), {i});
741 element_ids.push_back(
742 GenerateCopy(extract, new_element_type_id, insertion_position));
743 }
744
745 return ir_builder.AddCompositeConstruct(new_type_id, element_ids)
746 ->result_id();
747 } else if (const analysis::Struct* original_struct_type =
748 original_type->AsStruct()) {
749 analysis::Struct* new_struct_type = new_type->AsStruct();
750
751 const std::vector<const analysis::Type*>& original_types =
752 original_struct_type->element_types();
753 const std::vector<const analysis::Type*>& new_types =
754 new_struct_type->element_types();
755 std::vector<uint32_t> element_ids;
756 for (uint32_t i = 0; i < original_types.size(); i++) {
757 Instruction* extract = ir_builder.AddCompositeExtract(
758 type_mgr->GetId(original_types[i]), object_inst->result_id(), {i});
759 element_ids.push_back(GenerateCopy(extract, type_mgr->GetId(new_types[i]),
760 insertion_position));
761 }
762 return ir_builder.AddCompositeConstruct(new_type_id, element_ids)
763 ->result_id();
764 } else {
765 // If we do not have an aggregate type, then we have a problem. Either we
766 // found multiple instances of the same type, or we are copying to an
767 // incompatible type. Either way the code is illegal.
768 assert(false &&
769 "Don't know how to copy this type. Code is likely illegal.");
770 }
771 return 0;
772 }
773
GetMemberTypeId(uint32_t id,const std::vector<uint32_t> & access_chain) const774 uint32_t CopyPropagateArrays::GetMemberTypeId(
775 uint32_t id, const std::vector<uint32_t>& access_chain) const {
776 for (uint32_t element_index : access_chain) {
777 Instruction* type_inst = get_def_use_mgr()->GetDef(id);
778 switch (type_inst->opcode()) {
779 case SpvOpTypeArray:
780 case SpvOpTypeRuntimeArray:
781 case SpvOpTypeMatrix:
782 case SpvOpTypeVector:
783 id = type_inst->GetSingleWordInOperand(0);
784 break;
785 case SpvOpTypeStruct:
786 id = type_inst->GetSingleWordInOperand(element_index);
787 break;
788 default:
789 break;
790 }
791 assert(id != 0 &&
792 "Tried to extract from an object where it cannot be done.");
793 }
794 return id;
795 }
796
GetMember(const std::vector<uint32_t> & access_chain)797 void CopyPropagateArrays::MemoryObject::GetMember(
798 const std::vector<uint32_t>& access_chain) {
799 access_chain_.insert(access_chain_.end(), access_chain.begin(),
800 access_chain.end());
801 }
802
GetNumberOfMembers()803 uint32_t CopyPropagateArrays::MemoryObject::GetNumberOfMembers() {
804 IRContext* context = variable_inst_->context();
805 analysis::TypeManager* type_mgr = context->get_type_mgr();
806
807 const analysis::Type* type = type_mgr->GetType(variable_inst_->type_id());
808 type = type->AsPointer()->pointee_type();
809
810 std::vector<uint32_t> access_indices = GetAccessIds();
811 type = type_mgr->GetMemberType(type, access_indices);
812
813 if (const analysis::Struct* struct_type = type->AsStruct()) {
814 return static_cast<uint32_t>(struct_type->element_types().size());
815 } else if (const analysis::Array* array_type = type->AsArray()) {
816 const analysis::Constant* length_const =
817 context->get_constant_mgr()->FindDeclaredConstant(
818 array_type->LengthId());
819 assert(length_const->AsIntConstant());
820 return length_const->AsIntConstant()->GetU32();
821 } else if (const analysis::Vector* vector_type = type->AsVector()) {
822 return vector_type->element_count();
823 } else if (const analysis::Matrix* matrix_type = type->AsMatrix()) {
824 return matrix_type->element_count();
825 } else {
826 return 0;
827 }
828 }
829
830 template <class iterator>
MemoryObject(Instruction * var_inst,iterator begin,iterator end)831 CopyPropagateArrays::MemoryObject::MemoryObject(Instruction* var_inst,
832 iterator begin, iterator end)
833 : variable_inst_(var_inst), access_chain_(begin, end) {}
834
GetAccessIds() const835 std::vector<uint32_t> CopyPropagateArrays::MemoryObject::GetAccessIds() const {
836 analysis::ConstantManager* const_mgr =
837 variable_inst_->context()->get_constant_mgr();
838
839 std::vector<uint32_t> access_indices;
840 for (uint32_t id : AccessChain()) {
841 const analysis::Constant* element_index_const =
842 const_mgr->FindDeclaredConstant(id);
843 if (!element_index_const) {
844 access_indices.push_back(0);
845 } else {
846 assert(element_index_const->AsIntConstant());
847 access_indices.push_back(element_index_const->AsIntConstant()->GetU32());
848 }
849 }
850 return access_indices;
851 }
852
Contains(CopyPropagateArrays::MemoryObject * other)853 bool CopyPropagateArrays::MemoryObject::Contains(
854 CopyPropagateArrays::MemoryObject* other) {
855 if (this->GetVariable() != other->GetVariable()) {
856 return false;
857 }
858
859 if (AccessChain().size() > other->AccessChain().size()) {
860 return false;
861 }
862
863 for (uint32_t i = 0; i < AccessChain().size(); i++) {
864 if (AccessChain()[i] != other->AccessChain()[i]) {
865 return false;
866 }
867 }
868 return true;
869 }
870
871 } // namespace opt
872 } // namespace spvtools
873