1 // Copyright (c) 2017 The Khronos Group Inc.
2 // Copyright (c) 2017 Valve Corporation
3 // Copyright (c) 2017 LunarG Inc.
4 //
5 // Licensed under the Apache License, Version 2.0 (the "License");
6 // you may not use this file except in compliance with the License.
7 // You may obtain a copy of the License at
8 //
9 // http://www.apache.org/licenses/LICENSE-2.0
10 //
11 // Unless required by applicable law or agreed to in writing, software
12 // distributed under the License is distributed on an "AS IS" BASIS,
13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 // See the License for the specific language governing permissions and
15 // limitations under the License.
16
17 #include "source/opt/local_access_chain_convert_pass.h"
18
19 #include "ir_builder.h"
20 #include "ir_context.h"
21 #include "iterator.h"
22
23 namespace spvtools {
24 namespace opt {
25
namespace {

// In-operand positions used by this pass when decoding instructions.

// Position of the stored value id among the in-operands of a store.
constexpr uint32_t kStoreValIdInIdx = 1;
// Position of the base pointer id among the in-operands of an access chain.
constexpr uint32_t kAccessChainPtrIdInIdx = 0;
// Position of the literal value word among the in-operands of a constant.
constexpr uint32_t kConstantValueInIdx = 0;
// Position of the bit width among the in-operands of an integer type.
constexpr uint32_t kTypeIntWidthInIdx = 0;

}  // anonymous namespace
34
BuildAndAppendInst(SpvOp opcode,uint32_t typeId,uint32_t resultId,const std::vector<Operand> & in_opnds,std::vector<std::unique_ptr<Instruction>> * newInsts)35 void LocalAccessChainConvertPass::BuildAndAppendInst(
36 SpvOp opcode, uint32_t typeId, uint32_t resultId,
37 const std::vector<Operand>& in_opnds,
38 std::vector<std::unique_ptr<Instruction>>* newInsts) {
39 std::unique_ptr<Instruction> newInst(
40 new Instruction(context(), opcode, typeId, resultId, in_opnds));
41 get_def_use_mgr()->AnalyzeInstDefUse(&*newInst);
42 newInsts->emplace_back(std::move(newInst));
43 }
44
BuildAndAppendVarLoad(const Instruction * ptrInst,uint32_t * varId,uint32_t * varPteTypeId,std::vector<std::unique_ptr<Instruction>> * newInsts)45 uint32_t LocalAccessChainConvertPass::BuildAndAppendVarLoad(
46 const Instruction* ptrInst, uint32_t* varId, uint32_t* varPteTypeId,
47 std::vector<std::unique_ptr<Instruction>>* newInsts) {
48 const uint32_t ldResultId = TakeNextId();
49 *varId = ptrInst->GetSingleWordInOperand(kAccessChainPtrIdInIdx);
50 const Instruction* varInst = get_def_use_mgr()->GetDef(*varId);
51 assert(varInst->opcode() == SpvOpVariable);
52 *varPteTypeId = GetPointeeTypeId(varInst);
53 BuildAndAppendInst(SpvOpLoad, *varPteTypeId, ldResultId,
54 {{spv_operand_type_t::SPV_OPERAND_TYPE_ID, {*varId}}},
55 newInsts);
56 return ldResultId;
57 }
58
AppendConstantOperands(const Instruction * ptrInst,std::vector<Operand> * in_opnds)59 void LocalAccessChainConvertPass::AppendConstantOperands(
60 const Instruction* ptrInst, std::vector<Operand>* in_opnds) {
61 uint32_t iidIdx = 0;
62 ptrInst->ForEachInId([&iidIdx, &in_opnds, this](const uint32_t* iid) {
63 if (iidIdx > 0) {
64 const Instruction* cInst = get_def_use_mgr()->GetDef(*iid);
65 uint32_t val = cInst->GetSingleWordInOperand(kConstantValueInIdx);
66 in_opnds->push_back(
67 {spv_operand_type_t::SPV_OPERAND_TYPE_LITERAL_INTEGER, {val}});
68 }
69 ++iidIdx;
70 });
71 }
72
ReplaceAccessChainLoad(const Instruction * address_inst,Instruction * original_load)73 void LocalAccessChainConvertPass::ReplaceAccessChainLoad(
74 const Instruction* address_inst, Instruction* original_load) {
75 // Build and append load of variable in ptrInst
76 std::vector<std::unique_ptr<Instruction>> new_inst;
77 uint32_t varId;
78 uint32_t varPteTypeId;
79 const uint32_t ldResultId =
80 BuildAndAppendVarLoad(address_inst, &varId, &varPteTypeId, &new_inst);
81 context()->get_decoration_mgr()->CloneDecorations(
82 original_load->result_id(), ldResultId, {SpvDecorationRelaxedPrecision});
83 original_load->InsertBefore(std::move(new_inst));
84
85 // Rewrite |original_load| into an extract.
86 Instruction::OperandList new_operands;
87
88 // copy the result id and the type id to the new operand list.
89 new_operands.emplace_back(original_load->GetOperand(0));
90 new_operands.emplace_back(original_load->GetOperand(1));
91
92 new_operands.emplace_back(
93 Operand({spv_operand_type_t::SPV_OPERAND_TYPE_ID, {ldResultId}}));
94 AppendConstantOperands(address_inst, &new_operands);
95 original_load->SetOpcode(SpvOpCompositeExtract);
96 original_load->ReplaceOperands(new_operands);
97 context()->UpdateDefUse(original_load);
98 }
99
GenAccessChainStoreReplacement(const Instruction * ptrInst,uint32_t valId,std::vector<std::unique_ptr<Instruction>> * newInsts)100 void LocalAccessChainConvertPass::GenAccessChainStoreReplacement(
101 const Instruction* ptrInst, uint32_t valId,
102 std::vector<std::unique_ptr<Instruction>>* newInsts) {
103 // Build and append load of variable in ptrInst
104 uint32_t varId;
105 uint32_t varPteTypeId;
106 const uint32_t ldResultId =
107 BuildAndAppendVarLoad(ptrInst, &varId, &varPteTypeId, newInsts);
108 context()->get_decoration_mgr()->CloneDecorations(
109 varId, ldResultId, {SpvDecorationRelaxedPrecision});
110
111 // Build and append Insert
112 const uint32_t insResultId = TakeNextId();
113 std::vector<Operand> ins_in_opnds = {
114 {spv_operand_type_t::SPV_OPERAND_TYPE_ID, {valId}},
115 {spv_operand_type_t::SPV_OPERAND_TYPE_ID, {ldResultId}}};
116 AppendConstantOperands(ptrInst, &ins_in_opnds);
117 BuildAndAppendInst(SpvOpCompositeInsert, varPteTypeId, insResultId,
118 ins_in_opnds, newInsts);
119
120 context()->get_decoration_mgr()->CloneDecorations(
121 varId, insResultId, {SpvDecorationRelaxedPrecision});
122
123 // Build and append Store
124 BuildAndAppendInst(SpvOpStore, 0, 0,
125 {{spv_operand_type_t::SPV_OPERAND_TYPE_ID, {varId}},
126 {spv_operand_type_t::SPV_OPERAND_TYPE_ID, {insResultId}}},
127 newInsts);
128 }
129
IsConstantIndexAccessChain(const Instruction * acp) const130 bool LocalAccessChainConvertPass::IsConstantIndexAccessChain(
131 const Instruction* acp) const {
132 uint32_t inIdx = 0;
133 return acp->WhileEachInId([&inIdx, this](const uint32_t* tid) {
134 if (inIdx > 0) {
135 Instruction* opInst = get_def_use_mgr()->GetDef(*tid);
136 if (opInst->opcode() != SpvOpConstant) return false;
137 }
138 ++inIdx;
139 return true;
140 });
141 }
142
HasOnlySupportedRefs(uint32_t ptrId)143 bool LocalAccessChainConvertPass::HasOnlySupportedRefs(uint32_t ptrId) {
144 if (supported_ref_ptrs_.find(ptrId) != supported_ref_ptrs_.end()) return true;
145 if (get_def_use_mgr()->WhileEachUser(ptrId, [this](Instruction* user) {
146 SpvOp op = user->opcode();
147 if (IsNonPtrAccessChain(op) || op == SpvOpCopyObject) {
148 if (!HasOnlySupportedRefs(user->result_id())) {
149 return false;
150 }
151 } else if (op != SpvOpStore && op != SpvOpLoad && op != SpvOpName &&
152 !IsNonTypeDecorate(op)) {
153 return false;
154 }
155 return true;
156 })) {
157 supported_ref_ptrs_.insert(ptrId);
158 return true;
159 }
160 return false;
161 }
162
FindTargetVars(Function * func)163 void LocalAccessChainConvertPass::FindTargetVars(Function* func) {
164 for (auto bi = func->begin(); bi != func->end(); ++bi) {
165 for (auto ii = bi->begin(); ii != bi->end(); ++ii) {
166 switch (ii->opcode()) {
167 case SpvOpStore:
168 case SpvOpLoad: {
169 uint32_t varId;
170 Instruction* ptrInst = GetPtr(&*ii, &varId);
171 if (!IsTargetVar(varId)) break;
172 const SpvOp op = ptrInst->opcode();
173 // Rule out variables with non-supported refs eg function calls
174 if (!HasOnlySupportedRefs(varId)) {
175 seen_non_target_vars_.insert(varId);
176 seen_target_vars_.erase(varId);
177 break;
178 }
179 // Rule out variables with nested access chains
180 // TODO(): Convert nested access chains
181 if (IsNonPtrAccessChain(op) && ptrInst->GetSingleWordInOperand(
182 kAccessChainPtrIdInIdx) != varId) {
183 seen_non_target_vars_.insert(varId);
184 seen_target_vars_.erase(varId);
185 break;
186 }
187 // Rule out variables accessed with non-constant indices
188 if (!IsConstantIndexAccessChain(ptrInst)) {
189 seen_non_target_vars_.insert(varId);
190 seen_target_vars_.erase(varId);
191 break;
192 }
193 } break;
194 default:
195 break;
196 }
197 }
198 }
199 }
200
ConvertLocalAccessChains(Function * func)201 bool LocalAccessChainConvertPass::ConvertLocalAccessChains(Function* func) {
202 FindTargetVars(func);
203 // Replace access chains of all targeted variables with equivalent
204 // extract and insert sequences
205 bool modified = false;
206 for (auto bi = func->begin(); bi != func->end(); ++bi) {
207 std::vector<Instruction*> dead_instructions;
208 for (auto ii = bi->begin(); ii != bi->end(); ++ii) {
209 switch (ii->opcode()) {
210 case SpvOpLoad: {
211 uint32_t varId;
212 Instruction* ptrInst = GetPtr(&*ii, &varId);
213 if (!IsNonPtrAccessChain(ptrInst->opcode())) break;
214 if (!IsTargetVar(varId)) break;
215 std::vector<std::unique_ptr<Instruction>> newInsts;
216 ReplaceAccessChainLoad(ptrInst, &*ii);
217 modified = true;
218 } break;
219 case SpvOpStore: {
220 uint32_t varId;
221 Instruction* ptrInst = GetPtr(&*ii, &varId);
222 if (!IsNonPtrAccessChain(ptrInst->opcode())) break;
223 if (!IsTargetVar(varId)) break;
224 std::vector<std::unique_ptr<Instruction>> newInsts;
225 uint32_t valId = ii->GetSingleWordInOperand(kStoreValIdInIdx);
226 GenAccessChainStoreReplacement(ptrInst, valId, &newInsts);
227 dead_instructions.push_back(&*ii);
228 ++ii;
229 ii = ii.InsertBefore(std::move(newInsts));
230 ++ii;
231 ++ii;
232 modified = true;
233 } break;
234 default:
235 break;
236 }
237 }
238
239 while (!dead_instructions.empty()) {
240 Instruction* inst = dead_instructions.back();
241 dead_instructions.pop_back();
242 DCEInst(inst, [&dead_instructions](Instruction* other_inst) {
243 auto i = std::find(dead_instructions.begin(), dead_instructions.end(),
244 other_inst);
245 if (i != dead_instructions.end()) {
246 dead_instructions.erase(i);
247 }
248 });
249 }
250 }
251 return modified;
252 }
253
Initialize()254 void LocalAccessChainConvertPass::Initialize() {
255 // Initialize Target Variable Caches
256 seen_target_vars_.clear();
257 seen_non_target_vars_.clear();
258
259 // Initialize collections
260 supported_ref_ptrs_.clear();
261
262 // Initialize extension whitelist
263 InitExtensions();
264 }
265
AllExtensionsSupported() const266 bool LocalAccessChainConvertPass::AllExtensionsSupported() const {
267 // If any extension not in whitelist, return false
268 for (auto& ei : get_module()->extensions()) {
269 const char* extName =
270 reinterpret_cast<const char*>(&ei.GetInOperand(0).words[0]);
271 if (extensions_whitelist_.find(extName) == extensions_whitelist_.end())
272 return false;
273 }
274 return true;
275 }
276
ProcessImpl()277 Pass::Status LocalAccessChainConvertPass::ProcessImpl() {
278 // If non-32-bit integer type in module, terminate processing
279 // TODO(): Handle non-32-bit integer constants in access chains
280 for (const Instruction& inst : get_module()->types_values())
281 if (inst.opcode() == SpvOpTypeInt &&
282 inst.GetSingleWordInOperand(kTypeIntWidthInIdx) != 32)
283 return Status::SuccessWithoutChange;
284 // Do not process if module contains OpGroupDecorate. Additional
285 // support required in KillNamesAndDecorates().
286 // TODO(greg-lunarg): Add support for OpGroupDecorate
287 for (auto& ai : get_module()->annotations())
288 if (ai.opcode() == SpvOpGroupDecorate) return Status::SuccessWithoutChange;
289 // Do not process if any disallowed extensions are enabled
290 if (!AllExtensionsSupported()) return Status::SuccessWithoutChange;
291 // Process all entry point functions.
292 ProcessFunction pfn = [this](Function* fp) {
293 return ConvertLocalAccessChains(fp);
294 };
295 bool modified = context()->ProcessEntryPointCallTree(pfn);
296 return modified ? Status::SuccessWithChange : Status::SuccessWithoutChange;
297 }
298
LocalAccessChainConvertPass()299 LocalAccessChainConvertPass::LocalAccessChainConvertPass() {}
300
Process()301 Pass::Status LocalAccessChainConvertPass::Process() {
302 Initialize();
303 return ProcessImpl();
304 }
305
InitExtensions()306 void LocalAccessChainConvertPass::InitExtensions() {
307 extensions_whitelist_.clear();
308 extensions_whitelist_.insert({
309 "SPV_AMD_shader_explicit_vertex_parameter",
310 "SPV_AMD_shader_trinary_minmax",
311 "SPV_AMD_gcn_shader",
312 "SPV_KHR_shader_ballot",
313 "SPV_AMD_shader_ballot",
314 "SPV_AMD_gpu_shader_half_float",
315 "SPV_KHR_shader_draw_parameters",
316 "SPV_KHR_subgroup_vote",
317 "SPV_KHR_16bit_storage",
318 "SPV_KHR_device_group",
319 "SPV_KHR_multiview",
320 "SPV_NVX_multiview_per_view_attributes",
321 "SPV_NV_viewport_array2",
322 "SPV_NV_stereo_view_rendering",
323 "SPV_NV_sample_mask_override_coverage",
324 "SPV_NV_geometry_shader_passthrough",
325 "SPV_AMD_texture_gather_bias_lod",
326 "SPV_KHR_storage_buffer_storage_class",
327 // SPV_KHR_variable_pointers
328 // Currently do not support extended pointer expressions
329 "SPV_AMD_gpu_shader_int16",
330 "SPV_KHR_post_depth_coverage",
331 "SPV_KHR_shader_atomic_counter_ops",
332 "SPV_EXT_shader_stencil_export",
333 "SPV_EXT_shader_viewport_index_layer",
334 "SPV_AMD_shader_image_load_store_lod",
335 "SPV_AMD_shader_fragment_mask",
336 "SPV_EXT_fragment_fully_covered",
337 "SPV_AMD_gpu_shader_half_float_fetch",
338 "SPV_GOOGLE_decorate_string",
339 "SPV_GOOGLE_hlsl_functionality1",
340 "SPV_NV_shader_subgroup_partitioned",
341 "SPV_EXT_descriptor_indexing",
342 "SPV_NV_fragment_shader_barycentric",
343 "SPV_NV_compute_shader_derivatives",
344 "SPV_NV_shader_image_footprint",
345 "SPV_NV_shading_rate",
346 "SPV_NV_mesh_shader",
347 "SPV_NV_ray_tracing",
348 "SPV_EXT_fragment_invocation_density",
349 });
350 }
351
352 } // namespace opt
353 } // namespace spvtools
354