1 // Copyright (c) 2017 The Khronos Group Inc.
2 // Copyright (c) 2017 Valve Corporation
3 // Copyright (c) 2017 LunarG Inc.
4 //
5 // Licensed under the Apache License, Version 2.0 (the "License");
6 // you may not use this file except in compliance with the License.
7 // You may obtain a copy of the License at
8 //
9 //     http://www.apache.org/licenses/LICENSE-2.0
10 //
11 // Unless required by applicable law or agreed to in writing, software
12 // distributed under the License is distributed on an "AS IS" BASIS,
13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 // See the License for the specific language governing permissions and
15 // limitations under the License.
16 
17 #include "source/opt/local_access_chain_convert_pass.h"
18 
19 #include "ir_builder.h"
20 #include "ir_context.h"
21 #include "iterator.h"
22 
23 namespace spvtools {
24 namespace opt {
25 
namespace {

// In-operand index of the stored value (Object) operand of OpStore.
constexpr uint32_t kStoreValIdInIdx = 1;
// In-operand index of the base pointer of an access chain.
constexpr uint32_t kAccessChainPtrIdInIdx = 0;
// In-operand index of the first literal value word of OpConstant.
constexpr uint32_t kConstantValueInIdx = 0;
// In-operand index of the Width operand of OpTypeInt.
constexpr uint32_t kTypeIntWidthInIdx = 0;

}  // anonymous namespace
34 
BuildAndAppendInst(SpvOp opcode,uint32_t typeId,uint32_t resultId,const std::vector<Operand> & in_opnds,std::vector<std::unique_ptr<Instruction>> * newInsts)35 void LocalAccessChainConvertPass::BuildAndAppendInst(
36     SpvOp opcode, uint32_t typeId, uint32_t resultId,
37     const std::vector<Operand>& in_opnds,
38     std::vector<std::unique_ptr<Instruction>>* newInsts) {
39   std::unique_ptr<Instruction> newInst(
40       new Instruction(context(), opcode, typeId, resultId, in_opnds));
41   get_def_use_mgr()->AnalyzeInstDefUse(&*newInst);
42   newInsts->emplace_back(std::move(newInst));
43 }
44 
BuildAndAppendVarLoad(const Instruction * ptrInst,uint32_t * varId,uint32_t * varPteTypeId,std::vector<std::unique_ptr<Instruction>> * newInsts)45 uint32_t LocalAccessChainConvertPass::BuildAndAppendVarLoad(
46     const Instruction* ptrInst, uint32_t* varId, uint32_t* varPteTypeId,
47     std::vector<std::unique_ptr<Instruction>>* newInsts) {
48   const uint32_t ldResultId = TakeNextId();
49   *varId = ptrInst->GetSingleWordInOperand(kAccessChainPtrIdInIdx);
50   const Instruction* varInst = get_def_use_mgr()->GetDef(*varId);
51   assert(varInst->opcode() == SpvOpVariable);
52   *varPteTypeId = GetPointeeTypeId(varInst);
53   BuildAndAppendInst(SpvOpLoad, *varPteTypeId, ldResultId,
54                      {{spv_operand_type_t::SPV_OPERAND_TYPE_ID, {*varId}}},
55                      newInsts);
56   return ldResultId;
57 }
58 
AppendConstantOperands(const Instruction * ptrInst,std::vector<Operand> * in_opnds)59 void LocalAccessChainConvertPass::AppendConstantOperands(
60     const Instruction* ptrInst, std::vector<Operand>* in_opnds) {
61   uint32_t iidIdx = 0;
62   ptrInst->ForEachInId([&iidIdx, &in_opnds, this](const uint32_t* iid) {
63     if (iidIdx > 0) {
64       const Instruction* cInst = get_def_use_mgr()->GetDef(*iid);
65       uint32_t val = cInst->GetSingleWordInOperand(kConstantValueInIdx);
66       in_opnds->push_back(
67           {spv_operand_type_t::SPV_OPERAND_TYPE_LITERAL_INTEGER, {val}});
68     }
69     ++iidIdx;
70   });
71 }
72 
ReplaceAccessChainLoad(const Instruction * address_inst,Instruction * original_load)73 void LocalAccessChainConvertPass::ReplaceAccessChainLoad(
74     const Instruction* address_inst, Instruction* original_load) {
75   // Build and append load of variable in ptrInst
76   std::vector<std::unique_ptr<Instruction>> new_inst;
77   uint32_t varId;
78   uint32_t varPteTypeId;
79   const uint32_t ldResultId =
80       BuildAndAppendVarLoad(address_inst, &varId, &varPteTypeId, &new_inst);
81   context()->get_decoration_mgr()->CloneDecorations(
82       original_load->result_id(), ldResultId, {SpvDecorationRelaxedPrecision});
83   original_load->InsertBefore(std::move(new_inst));
84 
85   // Rewrite |original_load| into an extract.
86   Instruction::OperandList new_operands;
87 
88   // copy the result id and the type id to the new operand list.
89   new_operands.emplace_back(original_load->GetOperand(0));
90   new_operands.emplace_back(original_load->GetOperand(1));
91 
92   new_operands.emplace_back(
93       Operand({spv_operand_type_t::SPV_OPERAND_TYPE_ID, {ldResultId}}));
94   AppendConstantOperands(address_inst, &new_operands);
95   original_load->SetOpcode(SpvOpCompositeExtract);
96   original_load->ReplaceOperands(new_operands);
97   context()->UpdateDefUse(original_load);
98 }
99 
GenAccessChainStoreReplacement(const Instruction * ptrInst,uint32_t valId,std::vector<std::unique_ptr<Instruction>> * newInsts)100 void LocalAccessChainConvertPass::GenAccessChainStoreReplacement(
101     const Instruction* ptrInst, uint32_t valId,
102     std::vector<std::unique_ptr<Instruction>>* newInsts) {
103   // Build and append load of variable in ptrInst
104   uint32_t varId;
105   uint32_t varPteTypeId;
106   const uint32_t ldResultId =
107       BuildAndAppendVarLoad(ptrInst, &varId, &varPteTypeId, newInsts);
108   context()->get_decoration_mgr()->CloneDecorations(
109       varId, ldResultId, {SpvDecorationRelaxedPrecision});
110 
111   // Build and append Insert
112   const uint32_t insResultId = TakeNextId();
113   std::vector<Operand> ins_in_opnds = {
114       {spv_operand_type_t::SPV_OPERAND_TYPE_ID, {valId}},
115       {spv_operand_type_t::SPV_OPERAND_TYPE_ID, {ldResultId}}};
116   AppendConstantOperands(ptrInst, &ins_in_opnds);
117   BuildAndAppendInst(SpvOpCompositeInsert, varPteTypeId, insResultId,
118                      ins_in_opnds, newInsts);
119 
120   context()->get_decoration_mgr()->CloneDecorations(
121       varId, insResultId, {SpvDecorationRelaxedPrecision});
122 
123   // Build and append Store
124   BuildAndAppendInst(SpvOpStore, 0, 0,
125                      {{spv_operand_type_t::SPV_OPERAND_TYPE_ID, {varId}},
126                       {spv_operand_type_t::SPV_OPERAND_TYPE_ID, {insResultId}}},
127                      newInsts);
128 }
129 
IsConstantIndexAccessChain(const Instruction * acp) const130 bool LocalAccessChainConvertPass::IsConstantIndexAccessChain(
131     const Instruction* acp) const {
132   uint32_t inIdx = 0;
133   return acp->WhileEachInId([&inIdx, this](const uint32_t* tid) {
134     if (inIdx > 0) {
135       Instruction* opInst = get_def_use_mgr()->GetDef(*tid);
136       if (opInst->opcode() != SpvOpConstant) return false;
137     }
138     ++inIdx;
139     return true;
140   });
141 }
142 
HasOnlySupportedRefs(uint32_t ptrId)143 bool LocalAccessChainConvertPass::HasOnlySupportedRefs(uint32_t ptrId) {
144   if (supported_ref_ptrs_.find(ptrId) != supported_ref_ptrs_.end()) return true;
145   if (get_def_use_mgr()->WhileEachUser(ptrId, [this](Instruction* user) {
146         SpvOp op = user->opcode();
147         if (IsNonPtrAccessChain(op) || op == SpvOpCopyObject) {
148           if (!HasOnlySupportedRefs(user->result_id())) {
149             return false;
150           }
151         } else if (op != SpvOpStore && op != SpvOpLoad && op != SpvOpName &&
152                    !IsNonTypeDecorate(op)) {
153           return false;
154         }
155         return true;
156       })) {
157     supported_ref_ptrs_.insert(ptrId);
158     return true;
159   }
160   return false;
161 }
162 
FindTargetVars(Function * func)163 void LocalAccessChainConvertPass::FindTargetVars(Function* func) {
164   for (auto bi = func->begin(); bi != func->end(); ++bi) {
165     for (auto ii = bi->begin(); ii != bi->end(); ++ii) {
166       switch (ii->opcode()) {
167         case SpvOpStore:
168         case SpvOpLoad: {
169           uint32_t varId;
170           Instruction* ptrInst = GetPtr(&*ii, &varId);
171           if (!IsTargetVar(varId)) break;
172           const SpvOp op = ptrInst->opcode();
173           // Rule out variables with non-supported refs eg function calls
174           if (!HasOnlySupportedRefs(varId)) {
175             seen_non_target_vars_.insert(varId);
176             seen_target_vars_.erase(varId);
177             break;
178           }
179           // Rule out variables with nested access chains
180           // TODO(): Convert nested access chains
181           if (IsNonPtrAccessChain(op) && ptrInst->GetSingleWordInOperand(
182                                              kAccessChainPtrIdInIdx) != varId) {
183             seen_non_target_vars_.insert(varId);
184             seen_target_vars_.erase(varId);
185             break;
186           }
187           // Rule out variables accessed with non-constant indices
188           if (!IsConstantIndexAccessChain(ptrInst)) {
189             seen_non_target_vars_.insert(varId);
190             seen_target_vars_.erase(varId);
191             break;
192           }
193         } break;
194         default:
195           break;
196       }
197     }
198   }
199 }
200 
ConvertLocalAccessChains(Function * func)201 bool LocalAccessChainConvertPass::ConvertLocalAccessChains(Function* func) {
202   FindTargetVars(func);
203   // Replace access chains of all targeted variables with equivalent
204   // extract and insert sequences
205   bool modified = false;
206   for (auto bi = func->begin(); bi != func->end(); ++bi) {
207     std::vector<Instruction*> dead_instructions;
208     for (auto ii = bi->begin(); ii != bi->end(); ++ii) {
209       switch (ii->opcode()) {
210         case SpvOpLoad: {
211           uint32_t varId;
212           Instruction* ptrInst = GetPtr(&*ii, &varId);
213           if (!IsNonPtrAccessChain(ptrInst->opcode())) break;
214           if (!IsTargetVar(varId)) break;
215           std::vector<std::unique_ptr<Instruction>> newInsts;
216           ReplaceAccessChainLoad(ptrInst, &*ii);
217           modified = true;
218         } break;
219         case SpvOpStore: {
220           uint32_t varId;
221           Instruction* ptrInst = GetPtr(&*ii, &varId);
222           if (!IsNonPtrAccessChain(ptrInst->opcode())) break;
223           if (!IsTargetVar(varId)) break;
224           std::vector<std::unique_ptr<Instruction>> newInsts;
225           uint32_t valId = ii->GetSingleWordInOperand(kStoreValIdInIdx);
226           GenAccessChainStoreReplacement(ptrInst, valId, &newInsts);
227           dead_instructions.push_back(&*ii);
228           ++ii;
229           ii = ii.InsertBefore(std::move(newInsts));
230           ++ii;
231           ++ii;
232           modified = true;
233         } break;
234         default:
235           break;
236       }
237     }
238 
239     while (!dead_instructions.empty()) {
240       Instruction* inst = dead_instructions.back();
241       dead_instructions.pop_back();
242       DCEInst(inst, [&dead_instructions](Instruction* other_inst) {
243         auto i = std::find(dead_instructions.begin(), dead_instructions.end(),
244                            other_inst);
245         if (i != dead_instructions.end()) {
246           dead_instructions.erase(i);
247         }
248       });
249     }
250   }
251   return modified;
252 }
253 
Initialize()254 void LocalAccessChainConvertPass::Initialize() {
255   // Initialize Target Variable Caches
256   seen_target_vars_.clear();
257   seen_non_target_vars_.clear();
258 
259   // Initialize collections
260   supported_ref_ptrs_.clear();
261 
262   // Initialize extension whitelist
263   InitExtensions();
264 }
265 
AllExtensionsSupported() const266 bool LocalAccessChainConvertPass::AllExtensionsSupported() const {
267   // If any extension not in whitelist, return false
268   for (auto& ei : get_module()->extensions()) {
269     const char* extName =
270         reinterpret_cast<const char*>(&ei.GetInOperand(0).words[0]);
271     if (extensions_whitelist_.find(extName) == extensions_whitelist_.end())
272       return false;
273   }
274   return true;
275 }
276 
ProcessImpl()277 Pass::Status LocalAccessChainConvertPass::ProcessImpl() {
278   // If non-32-bit integer type in module, terminate processing
279   // TODO(): Handle non-32-bit integer constants in access chains
280   for (const Instruction& inst : get_module()->types_values())
281     if (inst.opcode() == SpvOpTypeInt &&
282         inst.GetSingleWordInOperand(kTypeIntWidthInIdx) != 32)
283       return Status::SuccessWithoutChange;
284   // Do not process if module contains OpGroupDecorate. Additional
285   // support required in KillNamesAndDecorates().
286   // TODO(greg-lunarg): Add support for OpGroupDecorate
287   for (auto& ai : get_module()->annotations())
288     if (ai.opcode() == SpvOpGroupDecorate) return Status::SuccessWithoutChange;
289   // Do not process if any disallowed extensions are enabled
290   if (!AllExtensionsSupported()) return Status::SuccessWithoutChange;
291   // Process all entry point functions.
292   ProcessFunction pfn = [this](Function* fp) {
293     return ConvertLocalAccessChains(fp);
294   };
295   bool modified = context()->ProcessEntryPointCallTree(pfn);
296   return modified ? Status::SuccessWithChange : Status::SuccessWithoutChange;
297 }
298 
LocalAccessChainConvertPass()299 LocalAccessChainConvertPass::LocalAccessChainConvertPass() {}
300 
Process()301 Pass::Status LocalAccessChainConvertPass::Process() {
302   Initialize();
303   return ProcessImpl();
304 }
305 
InitExtensions()306 void LocalAccessChainConvertPass::InitExtensions() {
307   extensions_whitelist_.clear();
308   extensions_whitelist_.insert({
309       "SPV_AMD_shader_explicit_vertex_parameter",
310       "SPV_AMD_shader_trinary_minmax",
311       "SPV_AMD_gcn_shader",
312       "SPV_KHR_shader_ballot",
313       "SPV_AMD_shader_ballot",
314       "SPV_AMD_gpu_shader_half_float",
315       "SPV_KHR_shader_draw_parameters",
316       "SPV_KHR_subgroup_vote",
317       "SPV_KHR_16bit_storage",
318       "SPV_KHR_device_group",
319       "SPV_KHR_multiview",
320       "SPV_NVX_multiview_per_view_attributes",
321       "SPV_NV_viewport_array2",
322       "SPV_NV_stereo_view_rendering",
323       "SPV_NV_sample_mask_override_coverage",
324       "SPV_NV_geometry_shader_passthrough",
325       "SPV_AMD_texture_gather_bias_lod",
326       "SPV_KHR_storage_buffer_storage_class",
327       // SPV_KHR_variable_pointers
328       //   Currently do not support extended pointer expressions
329       "SPV_AMD_gpu_shader_int16",
330       "SPV_KHR_post_depth_coverage",
331       "SPV_KHR_shader_atomic_counter_ops",
332       "SPV_EXT_shader_stencil_export",
333       "SPV_EXT_shader_viewport_index_layer",
334       "SPV_AMD_shader_image_load_store_lod",
335       "SPV_AMD_shader_fragment_mask",
336       "SPV_EXT_fragment_fully_covered",
337       "SPV_AMD_gpu_shader_half_float_fetch",
338       "SPV_GOOGLE_decorate_string",
339       "SPV_GOOGLE_hlsl_functionality1",
340       "SPV_NV_shader_subgroup_partitioned",
341       "SPV_EXT_descriptor_indexing",
342       "SPV_NV_fragment_shader_barycentric",
343       "SPV_NV_compute_shader_derivatives",
344       "SPV_NV_shader_image_footprint",
345       "SPV_NV_shading_rate",
346       "SPV_NV_mesh_shader",
347       "SPV_NV_ray_tracing",
348       "SPV_EXT_fragment_invocation_density",
349   });
350 }
351 
352 }  // namespace opt
353 }  // namespace spvtools
354