1 // Copyright (c) 2018 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #include "source/opt/combine_access_chains.h"
16 
17 #include <utility>
18 
19 #include "source/opt/constants.h"
20 #include "source/opt/ir_builder.h"
21 #include "source/opt/ir_context.h"
22 
23 namespace spvtools {
24 namespace opt {
25 
Process()26 Pass::Status CombineAccessChains::Process() {
27   bool modified = false;
28 
29   for (auto& function : *get_module()) {
30     modified |= ProcessFunction(function);
31   }
32 
33   return (modified ? Status::SuccessWithChange : Status::SuccessWithoutChange);
34 }
35 
ProcessFunction(Function & function)36 bool CombineAccessChains::ProcessFunction(Function& function) {
37   bool modified = false;
38 
39   cfg()->ForEachBlockInReversePostOrder(
40       function.entry().get(), [&modified, this](BasicBlock* block) {
41         block->ForEachInst([&modified, this](Instruction* inst) {
42           switch (inst->opcode()) {
43             case SpvOpAccessChain:
44             case SpvOpInBoundsAccessChain:
45             case SpvOpPtrAccessChain:
46             case SpvOpInBoundsPtrAccessChain:
47               modified |= CombineAccessChain(inst);
48               break;
49             default:
50               break;
51           }
52         });
53       });
54 
55   return modified;
56 }
57 
GetConstantValue(const analysis::Constant * constant_inst)58 uint32_t CombineAccessChains::GetConstantValue(
59     const analysis::Constant* constant_inst) {
60   if (constant_inst->type()->AsInteger()->width() <= 32) {
61     if (constant_inst->type()->AsInteger()->IsSigned()) {
62       return static_cast<uint32_t>(constant_inst->GetS32());
63     } else {
64       return constant_inst->GetU32();
65     }
66   } else {
67     assert(false);
68     return 0u;
69   }
70 }
71 
GetArrayStride(const Instruction * inst)72 uint32_t CombineAccessChains::GetArrayStride(const Instruction* inst) {
73   uint32_t array_stride = 0;
74   context()->get_decoration_mgr()->WhileEachDecoration(
75       inst->type_id(), SpvDecorationArrayStride,
76       [&array_stride](const Instruction& decoration) {
77         assert(decoration.opcode() != SpvOpDecorateId);
78         if (decoration.opcode() == SpvOpDecorate) {
79           array_stride = decoration.GetSingleWordInOperand(1);
80         } else {
81           array_stride = decoration.GetSingleWordInOperand(2);
82         }
83         return false;
84       });
85   return array_stride;
86 }
87 
GetIndexedType(Instruction * inst)88 const analysis::Type* CombineAccessChains::GetIndexedType(Instruction* inst) {
89   analysis::DefUseManager* def_use_mgr = context()->get_def_use_mgr();
90   analysis::TypeManager* type_mgr = context()->get_type_mgr();
91 
92   Instruction* base_ptr = def_use_mgr->GetDef(inst->GetSingleWordInOperand(0));
93   const analysis::Type* type = type_mgr->GetType(base_ptr->type_id());
94   assert(type->AsPointer());
95   type = type->AsPointer()->pointee_type();
96   std::vector<uint32_t> element_indices;
97   uint32_t starting_index = 1;
98   if (IsPtrAccessChain(inst->opcode())) {
99     // Skip the first index of OpPtrAccessChain as it does not affect type
100     // resolution.
101     starting_index = 2;
102   }
103   for (uint32_t i = starting_index; i < inst->NumInOperands(); ++i) {
104     Instruction* index_inst =
105         def_use_mgr->GetDef(inst->GetSingleWordInOperand(i));
106     const analysis::Constant* index_constant =
107         context()->get_constant_mgr()->GetConstantFromInst(index_inst);
108     if (index_constant) {
109       uint32_t index_value = GetConstantValue(index_constant);
110       element_indices.push_back(index_value);
111     } else {
112       // This index must not matter to resolve the type in valid SPIR-V.
113       element_indices.push_back(0);
114     }
115   }
116   type = type_mgr->GetMemberType(type, element_indices);
117   return type;
118 }
119 
CombineIndices(Instruction * ptr_input,Instruction * inst,std::vector<Operand> * new_operands)120 bool CombineAccessChains::CombineIndices(Instruction* ptr_input,
121                                          Instruction* inst,
122                                          std::vector<Operand>* new_operands) {
123   analysis::DefUseManager* def_use_mgr = context()->get_def_use_mgr();
124   analysis::ConstantManager* constant_mgr = context()->get_constant_mgr();
125 
126   Instruction* last_index_inst = def_use_mgr->GetDef(
127       ptr_input->GetSingleWordInOperand(ptr_input->NumInOperands() - 1));
128   const analysis::Constant* last_index_constant =
129       constant_mgr->GetConstantFromInst(last_index_inst);
130 
131   Instruction* element_inst =
132       def_use_mgr->GetDef(inst->GetSingleWordInOperand(1));
133   const analysis::Constant* element_constant =
134       constant_mgr->GetConstantFromInst(element_inst);
135 
136   // Combine the last index of the AccessChain (|ptr_inst|) with the element
137   // operand of the PtrAccessChain (|inst|).
138   const bool combining_element_operands =
139       IsPtrAccessChain(inst->opcode()) &&
140       IsPtrAccessChain(ptr_input->opcode()) && ptr_input->NumInOperands() == 2;
141   uint32_t new_value_id = 0;
142   const analysis::Type* type = GetIndexedType(ptr_input);
143   if (last_index_constant && element_constant) {
144     // Combine the constants.
145     uint32_t new_value = GetConstantValue(last_index_constant) +
146                          GetConstantValue(element_constant);
147     const analysis::Constant* new_value_constant =
148         constant_mgr->GetConstant(last_index_constant->type(), {new_value});
149     Instruction* new_value_inst =
150         constant_mgr->GetDefiningInstruction(new_value_constant);
151     new_value_id = new_value_inst->result_id();
152   } else if (!type->AsStruct() || combining_element_operands) {
153     // Generate an addition of the two indices.
154     InstructionBuilder builder(
155         context(), inst,
156         IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);
157     Instruction* addition = builder.AddIAdd(last_index_inst->type_id(),
158                                             last_index_inst->result_id(),
159                                             element_inst->result_id());
160     new_value_id = addition->result_id();
161   } else {
162     // Indexing into structs must be constant, so bail out here.
163     return false;
164   }
165   new_operands->push_back({SPV_OPERAND_TYPE_ID, {new_value_id}});
166   return true;
167 }
168 
CreateNewInputOperands(Instruction * ptr_input,Instruction * inst,std::vector<Operand> * new_operands)169 bool CombineAccessChains::CreateNewInputOperands(
170     Instruction* ptr_input, Instruction* inst,
171     std::vector<Operand>* new_operands) {
172   // Start by copying all the input operands of the feeder access chain.
173   for (uint32_t i = 0; i != ptr_input->NumInOperands() - 1; ++i) {
174     new_operands->push_back(ptr_input->GetInOperand(i));
175   }
176 
177   // Deal with the last index of the feeder access chain.
178   if (IsPtrAccessChain(inst->opcode())) {
179     // The last index of the feeder should be combined with the element operand
180     // of |inst|.
181     if (!CombineIndices(ptr_input, inst, new_operands)) return false;
182   } else {
183     // The indices aren't being combined so now add the last index operand of
184     // |ptr_input|.
185     new_operands->push_back(
186         ptr_input->GetInOperand(ptr_input->NumInOperands() - 1));
187   }
188 
189   // Copy the remaining index operands.
190   uint32_t starting_index = IsPtrAccessChain(inst->opcode()) ? 2 : 1;
191   for (uint32_t i = starting_index; i < inst->NumInOperands(); ++i) {
192     new_operands->push_back(inst->GetInOperand(i));
193   }
194 
195   return true;
196 }
197 
CombineAccessChain(Instruction * inst)198 bool CombineAccessChains::CombineAccessChain(Instruction* inst) {
199   assert((inst->opcode() == SpvOpPtrAccessChain ||
200           inst->opcode() == SpvOpAccessChain ||
201           inst->opcode() == SpvOpInBoundsAccessChain ||
202           inst->opcode() == SpvOpInBoundsPtrAccessChain) &&
203          "Wrong opcode. Expected an access chain.");
204 
205   Instruction* ptr_input =
206       context()->get_def_use_mgr()->GetDef(inst->GetSingleWordInOperand(0));
207   if (ptr_input->opcode() != SpvOpAccessChain &&
208       ptr_input->opcode() != SpvOpInBoundsAccessChain &&
209       ptr_input->opcode() != SpvOpPtrAccessChain &&
210       ptr_input->opcode() != SpvOpInBoundsPtrAccessChain) {
211     return false;
212   }
213 
214   if (Has64BitIndices(inst) || Has64BitIndices(ptr_input)) return false;
215 
216   // Handles the following cases:
217   // 1. |ptr_input| is an index-less access chain. Replace the pointer
218   //    in |inst| with |ptr_input|'s pointer.
219   // 2. |inst| is a index-less access chain. Change |inst| to an
220   //    OpCopyObject.
221   // 3. |inst| is not a pointer access chain.
222   //    |inst|'s indices are appended to |ptr_input|'s indices.
223   // 4. |ptr_input| is not pointer access chain.
224   //    |inst| is a pointer access chain.
225   //    |inst|'s element operand is combined with the last index in
226   //    |ptr_input| to form a new operand.
227   // 5. |ptr_input| is a pointer access chain.
228   //    Like the above scenario, |inst|'s element operand is combined
229   //    with |ptr_input|'s last index. This results is either a
230   //    combined element operand or combined regular index.
231 
232   // TODO(alan-baker): Support this properly. Requires analyzing the
233   // size/alignment of the type and converting the stride into an element
234   // index.
235   uint32_t array_stride = GetArrayStride(ptr_input);
236   if (array_stride != 0) return false;
237 
238   if (ptr_input->NumInOperands() == 1) {
239     // The input is effectively a no-op.
240     inst->SetInOperand(0, {ptr_input->GetSingleWordInOperand(0)});
241     context()->AnalyzeUses(inst);
242   } else if (inst->NumInOperands() == 1) {
243     // |inst| is a no-op, change it to a copy. Instruction simplification will
244     // clean it up.
245     inst->SetOpcode(SpvOpCopyObject);
246   } else {
247     std::vector<Operand> new_operands;
248     if (!CreateNewInputOperands(ptr_input, inst, &new_operands)) return false;
249 
250     // Update the instruction.
251     inst->SetOpcode(UpdateOpcode(inst->opcode(), ptr_input->opcode()));
252     inst->SetInOperands(std::move(new_operands));
253     context()->AnalyzeUses(inst);
254   }
255   return true;
256 }
257 
UpdateOpcode(SpvOp base_opcode,SpvOp input_opcode)258 SpvOp CombineAccessChains::UpdateOpcode(SpvOp base_opcode, SpvOp input_opcode) {
259   auto IsInBounds = [](SpvOp opcode) {
260     return opcode == SpvOpInBoundsPtrAccessChain ||
261            opcode == SpvOpInBoundsAccessChain;
262   };
263 
264   if (input_opcode == SpvOpInBoundsPtrAccessChain) {
265     if (!IsInBounds(base_opcode)) return SpvOpPtrAccessChain;
266   } else if (input_opcode == SpvOpInBoundsAccessChain) {
267     if (!IsInBounds(base_opcode)) return SpvOpAccessChain;
268   }
269 
270   return input_opcode;
271 }
272 
IsPtrAccessChain(SpvOp opcode)273 bool CombineAccessChains::IsPtrAccessChain(SpvOp opcode) {
274   return opcode == SpvOpPtrAccessChain || opcode == SpvOpInBoundsPtrAccessChain;
275 }
276 
Has64BitIndices(Instruction * inst)277 bool CombineAccessChains::Has64BitIndices(Instruction* inst) {
278   for (uint32_t i = 1; i < inst->NumInOperands(); ++i) {
279     Instruction* index_inst =
280         context()->get_def_use_mgr()->GetDef(inst->GetSingleWordInOperand(i));
281     const analysis::Type* index_type =
282         context()->get_type_mgr()->GetType(index_inst->type_id());
283     if (!index_type->AsInteger() || index_type->AsInteger()->width() != 32)
284       return true;
285   }
286   return false;
287 }
288 
289 }  // namespace opt
290 }  // namespace spvtools
291