1 // Copyright (c) 2018 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #include "upgrade_memory_model.h"
16 
17 #include <utility>
18 
19 #include "source/opt/ir_builder.h"
20 #include "source/opt/ir_context.h"
21 #include "source/util/make_unique.h"
22 
23 namespace spvtools {
24 namespace opt {
25 
Process()26 Pass::Status UpgradeMemoryModel::Process() {
27   // Only update Logical GLSL450 to Logical VulkanKHR.
28   Instruction* memory_model = get_module()->GetMemoryModel();
29   if (memory_model->GetSingleWordInOperand(0u) != SpvAddressingModelLogical ||
30       memory_model->GetSingleWordInOperand(1u) != SpvMemoryModelGLSL450) {
31     return Pass::Status::SuccessWithoutChange;
32   }
33 
34   UpgradeMemoryModelInstruction();
35   UpgradeInstructions();
36   CleanupDecorations();
37   UpgradeBarriers();
38   UpgradeMemoryScope();
39 
40   return Pass::Status::SuccessWithChange;
41 }
42 
UpgradeMemoryModelInstruction()43 void UpgradeMemoryModel::UpgradeMemoryModelInstruction() {
44   // Overall changes necessary:
45   // 1. Add the OpExtension.
46   // 2. Add the OpCapability.
47   // 3. Modify the memory model.
48   Instruction* memory_model = get_module()->GetMemoryModel();
49   get_module()->AddCapability(MakeUnique<Instruction>(
50       context(), SpvOpCapability, 0, 0,
51       std::initializer_list<Operand>{
52           {SPV_OPERAND_TYPE_CAPABILITY, {SpvCapabilityVulkanMemoryModelKHR}}}));
53   const std::string extension = "SPV_KHR_vulkan_memory_model";
54   std::vector<uint32_t> words(extension.size() / 4 + 1, 0);
55   char* dst = reinterpret_cast<char*>(words.data());
56   strncpy(dst, extension.c_str(), extension.size());
57   get_module()->AddExtension(
58       MakeUnique<Instruction>(context(), SpvOpExtension, 0, 0,
59                               std::initializer_list<Operand>{
60                                   {SPV_OPERAND_TYPE_LITERAL_STRING, words}}));
61   memory_model->SetInOperand(1u, {SpvMemoryModelVulkanKHR});
62 }
63 
UpgradeInstructions()64 void UpgradeMemoryModel::UpgradeInstructions() {
65   // Coherent and Volatile decorations are deprecated. Remove them and replace
66   // with flags on the memory/image operations. The decorations can occur on
67   // OpVariable, OpFunctionParameter (of pointer type) and OpStructType (member
68   // decoration). Trace from the decoration target(s) to the final memory/image
69   // instructions. Additionally, Workgroup storage class variables and function
70   // parameters are implicitly coherent in GLSL450.
71 
72   // Upgrade modf and frexp first since they generate new stores.
73   for (auto& func : *get_module()) {
74     func.ForEachInst([this](Instruction* inst) {
75       if (inst->opcode() == SpvOpExtInst) {
76         auto ext_inst = inst->GetSingleWordInOperand(1u);
77         if (ext_inst == GLSLstd450Modf || ext_inst == GLSLstd450Frexp) {
78           auto import =
79               get_def_use_mgr()->GetDef(inst->GetSingleWordInOperand(0u));
80           if (reinterpret_cast<char*>(import->GetInOperand(0u).words.data()) ==
81               std::string("GLSL.std.450")) {
82             UpgradeExtInst(inst);
83           }
84         }
85       }
86     });
87   }
88   for (auto& func : *get_module()) {
89     func.ForEachInst([this](Instruction* inst) {
90       bool is_coherent = false;
91       bool is_volatile = false;
92       bool src_coherent = false;
93       bool src_volatile = false;
94       bool dst_coherent = false;
95       bool dst_volatile = false;
96       SpvScope scope = SpvScopeQueueFamilyKHR;
97       SpvScope src_scope = SpvScopeQueueFamilyKHR;
98       SpvScope dst_scope = SpvScopeQueueFamilyKHR;
99       switch (inst->opcode()) {
100         case SpvOpLoad:
101         case SpvOpStore:
102           std::tie(is_coherent, is_volatile, scope) =
103               GetInstructionAttributes(inst->GetSingleWordInOperand(0u));
104           break;
105         case SpvOpImageRead:
106         case SpvOpImageSparseRead:
107         case SpvOpImageWrite:
108           std::tie(is_coherent, is_volatile, scope) =
109               GetInstructionAttributes(inst->GetSingleWordInOperand(0u));
110           break;
111         case SpvOpCopyMemory:
112         case SpvOpCopyMemorySized:
113           std::tie(dst_coherent, dst_volatile, dst_scope) =
114               GetInstructionAttributes(inst->GetSingleWordInOperand(0u));
115           std::tie(src_coherent, src_volatile, src_scope) =
116               GetInstructionAttributes(inst->GetSingleWordInOperand(1u));
117           break;
118         default:
119           break;
120       }
121 
122       switch (inst->opcode()) {
123         case SpvOpLoad:
124           UpgradeFlags(inst, 1u, is_coherent, is_volatile, kVisibility,
125                        kMemory);
126           break;
127         case SpvOpStore:
128           UpgradeFlags(inst, 2u, is_coherent, is_volatile, kAvailability,
129                        kMemory);
130           break;
131         case SpvOpCopyMemory:
132           UpgradeFlags(inst, 2u, dst_coherent, dst_volatile, kAvailability,
133                        kMemory);
134           UpgradeFlags(inst, 2u, src_coherent, src_volatile, kVisibility,
135                        kMemory);
136           break;
137         case SpvOpCopyMemorySized:
138           UpgradeFlags(inst, 3u, dst_coherent, dst_volatile, kAvailability,
139                        kMemory);
140           UpgradeFlags(inst, 3u, src_coherent, src_volatile, kVisibility,
141                        kMemory);
142           break;
143         case SpvOpImageRead:
144         case SpvOpImageSparseRead:
145           UpgradeFlags(inst, 2u, is_coherent, is_volatile, kVisibility, kImage);
146           break;
147         case SpvOpImageWrite:
148           UpgradeFlags(inst, 3u, is_coherent, is_volatile, kAvailability,
149                        kImage);
150           break;
151         default:
152           break;
153       }
154 
155       // |is_coherent| is never used for the same instructions as
156       // |src_coherent| and |dst_coherent|.
157       if (is_coherent) {
158         inst->AddOperand(
159             {SPV_OPERAND_TYPE_SCOPE_ID, {GetScopeConstant(scope)}});
160       }
161       // According to SPV_KHR_vulkan_memory_model, if both available and
162       // visible flags are used the first scope operand is for availability
163       // (writes) and the second is for visibility (reads).
164       if (dst_coherent) {
165         inst->AddOperand(
166             {SPV_OPERAND_TYPE_SCOPE_ID, {GetScopeConstant(dst_scope)}});
167       }
168       if (src_coherent) {
169         inst->AddOperand(
170             {SPV_OPERAND_TYPE_SCOPE_ID, {GetScopeConstant(src_scope)}});
171       }
172     });
173   }
174 }
175 
GetInstructionAttributes(uint32_t id)176 std::tuple<bool, bool, SpvScope> UpgradeMemoryModel::GetInstructionAttributes(
177     uint32_t id) {
178   // |id| is a pointer used in a memory/image instruction. Need to determine if
179   // that pointer points to volatile or coherent memory. Workgroup storage
180   // class is implicitly coherent and cannot be decorated with volatile, so
181   // short circuit that case.
182   Instruction* inst = context()->get_def_use_mgr()->GetDef(id);
183   analysis::Type* type = context()->get_type_mgr()->GetType(inst->type_id());
184   if (type->AsPointer() &&
185       type->AsPointer()->storage_class() == SpvStorageClassWorkgroup) {
186     return std::make_tuple(true, false, SpvScopeWorkgroup);
187   }
188 
189   bool is_coherent = false;
190   bool is_volatile = false;
191   std::unordered_set<uint32_t> visited;
192   std::tie(is_coherent, is_volatile) =
193       TraceInstruction(context()->get_def_use_mgr()->GetDef(id),
194                        std::vector<uint32_t>(), &visited);
195 
196   return std::make_tuple(is_coherent, is_volatile, SpvScopeQueueFamilyKHR);
197 }
198 
TraceInstruction(Instruction * inst,std::vector<uint32_t> indices,std::unordered_set<uint32_t> * visited)199 std::pair<bool, bool> UpgradeMemoryModel::TraceInstruction(
200     Instruction* inst, std::vector<uint32_t> indices,
201     std::unordered_set<uint32_t>* visited) {
202   auto iter = cache_.find(std::make_pair(inst->result_id(), indices));
203   if (iter != cache_.end()) {
204     return iter->second;
205   }
206 
207   if (!visited->insert(inst->result_id()).second) {
208     return std::make_pair(false, false);
209   }
210 
211   // Initialize the cache before |indices| is (potentially) modified.
212   auto& cached_result = cache_[std::make_pair(inst->result_id(), indices)];
213   cached_result.first = false;
214   cached_result.second = false;
215 
216   bool is_coherent = false;
217   bool is_volatile = false;
218   switch (inst->opcode()) {
219     case SpvOpVariable:
220     case SpvOpFunctionParameter:
221       is_coherent |= HasDecoration(inst, 0, SpvDecorationCoherent);
222       is_volatile |= HasDecoration(inst, 0, SpvDecorationVolatile);
223       if (!is_coherent || !is_volatile) {
224         bool type_coherent = false;
225         bool type_volatile = false;
226         std::tie(type_coherent, type_volatile) =
227             CheckType(inst->type_id(), indices);
228         is_coherent |= type_coherent;
229         is_volatile |= type_volatile;
230       }
231       break;
232     case SpvOpAccessChain:
233     case SpvOpInBoundsAccessChain:
234       // Store indices in reverse order.
235       for (uint32_t i = inst->NumInOperands() - 1; i > 0; --i) {
236         indices.push_back(inst->GetSingleWordInOperand(i));
237       }
238       break;
239     case SpvOpPtrAccessChain:
240       // Store indices in reverse order. Skip the |Element| operand.
241       for (uint32_t i = inst->NumInOperands() - 1; i > 1; --i) {
242         indices.push_back(inst->GetSingleWordInOperand(i));
243       }
244       break;
245     default:
246       break;
247   }
248 
249   // No point searching further.
250   if (is_coherent && is_volatile) {
251     cached_result.first = true;
252     cached_result.second = true;
253     return std::make_pair(true, true);
254   }
255 
256   // Variables and function parameters are sources. Continue searching until we
257   // reach them.
258   if (inst->opcode() != SpvOpVariable &&
259       inst->opcode() != SpvOpFunctionParameter) {
260     inst->ForEachInId([this, &is_coherent, &is_volatile, &indices,
261                        &visited](const uint32_t* id_ptr) {
262       Instruction* op_inst = context()->get_def_use_mgr()->GetDef(*id_ptr);
263       const analysis::Type* type =
264           context()->get_type_mgr()->GetType(op_inst->type_id());
265       if (type &&
266           (type->AsPointer() || type->AsImage() || type->AsSampledImage())) {
267         bool operand_coherent = false;
268         bool operand_volatile = false;
269         std::tie(operand_coherent, operand_volatile) =
270             TraceInstruction(op_inst, indices, visited);
271         is_coherent |= operand_coherent;
272         is_volatile |= operand_volatile;
273       }
274     });
275   }
276 
277   cached_result.first = is_coherent;
278   cached_result.second = is_volatile;
279   return std::make_pair(is_coherent, is_volatile);
280 }
281 
CheckType(uint32_t type_id,const std::vector<uint32_t> & indices)282 std::pair<bool, bool> UpgradeMemoryModel::CheckType(
283     uint32_t type_id, const std::vector<uint32_t>& indices) {
284   bool is_coherent = false;
285   bool is_volatile = false;
286   Instruction* type_inst = context()->get_def_use_mgr()->GetDef(type_id);
287   assert(type_inst->opcode() == SpvOpTypePointer);
288   Instruction* element_inst = context()->get_def_use_mgr()->GetDef(
289       type_inst->GetSingleWordInOperand(1u));
290   for (int i = (int)indices.size() - 1; i >= 0; --i) {
291     if (is_coherent && is_volatile) break;
292 
293     if (element_inst->opcode() == SpvOpTypePointer) {
294       element_inst = context()->get_def_use_mgr()->GetDef(
295           element_inst->GetSingleWordInOperand(1u));
296     } else if (element_inst->opcode() == SpvOpTypeStruct) {
297       uint32_t index = indices.at(i);
298       Instruction* index_inst = context()->get_def_use_mgr()->GetDef(index);
299       assert(index_inst->opcode() == SpvOpConstant);
300       uint64_t value = GetIndexValue(index_inst);
301       is_coherent |= HasDecoration(element_inst, static_cast<uint32_t>(value),
302                                    SpvDecorationCoherent);
303       is_volatile |= HasDecoration(element_inst, static_cast<uint32_t>(value),
304                                    SpvDecorationVolatile);
305       element_inst = context()->get_def_use_mgr()->GetDef(
306           element_inst->GetSingleWordInOperand(static_cast<uint32_t>(value)));
307     } else {
308       assert(spvOpcodeIsComposite(element_inst->opcode()));
309       element_inst = context()->get_def_use_mgr()->GetDef(
310           element_inst->GetSingleWordInOperand(1u));
311     }
312   }
313 
314   if (!is_coherent || !is_volatile) {
315     bool remaining_coherent = false;
316     bool remaining_volatile = false;
317     std::tie(remaining_coherent, remaining_volatile) =
318         CheckAllTypes(element_inst);
319     is_coherent |= remaining_coherent;
320     is_volatile |= remaining_volatile;
321   }
322 
323   return std::make_pair(is_coherent, is_volatile);
324 }
325 
CheckAllTypes(const Instruction * inst)326 std::pair<bool, bool> UpgradeMemoryModel::CheckAllTypes(
327     const Instruction* inst) {
328   std::unordered_set<const Instruction*> visited;
329   std::vector<const Instruction*> stack;
330   stack.push_back(inst);
331 
332   bool is_coherent = false;
333   bool is_volatile = false;
334   while (!stack.empty()) {
335     const Instruction* def = stack.back();
336     stack.pop_back();
337 
338     if (!visited.insert(def).second) continue;
339 
340     if (def->opcode() == SpvOpTypeStruct) {
341       // Any member decorated with coherent and/or volatile is enough to have
342       // the related operation be flagged as coherent and/or volatile.
343       is_coherent |= HasDecoration(def, std::numeric_limits<uint32_t>::max(),
344                                    SpvDecorationCoherent);
345       is_volatile |= HasDecoration(def, std::numeric_limits<uint32_t>::max(),
346                                    SpvDecorationVolatile);
347       if (is_coherent && is_volatile)
348         return std::make_pair(is_coherent, is_volatile);
349 
350       // Check the subtypes.
351       for (uint32_t i = 0; i < def->NumInOperands(); ++i) {
352         stack.push_back(context()->get_def_use_mgr()->GetDef(
353             def->GetSingleWordInOperand(i)));
354       }
355     } else if (spvOpcodeIsComposite(def->opcode())) {
356       stack.push_back(context()->get_def_use_mgr()->GetDef(
357           def->GetSingleWordInOperand(0u)));
358     } else if (def->opcode() == SpvOpTypePointer) {
359       stack.push_back(context()->get_def_use_mgr()->GetDef(
360           def->GetSingleWordInOperand(1u)));
361     }
362   }
363 
364   return std::make_pair(is_coherent, is_volatile);
365 }
366 
GetIndexValue(Instruction * index_inst)367 uint64_t UpgradeMemoryModel::GetIndexValue(Instruction* index_inst) {
368   const analysis::Constant* index_constant =
369       context()->get_constant_mgr()->GetConstantFromInst(index_inst);
370   assert(index_constant->AsIntConstant());
371   if (index_constant->type()->AsInteger()->IsSigned()) {
372     if (index_constant->type()->AsInteger()->width() == 32) {
373       return index_constant->GetS32();
374     } else {
375       return index_constant->GetS64();
376     }
377   } else {
378     if (index_constant->type()->AsInteger()->width() == 32) {
379       return index_constant->GetU32();
380     } else {
381       return index_constant->GetU64();
382     }
383   }
384 }
385 
HasDecoration(const Instruction * inst,uint32_t value,SpvDecoration decoration)386 bool UpgradeMemoryModel::HasDecoration(const Instruction* inst, uint32_t value,
387                                        SpvDecoration decoration) {
388   // If the iteration was terminated early then an appropriate decoration was
389   // found.
390   return !context()->get_decoration_mgr()->WhileEachDecoration(
391       inst->result_id(), decoration, [value](const Instruction& i) {
392         if (i.opcode() == SpvOpDecorate || i.opcode() == SpvOpDecorateId) {
393           return false;
394         } else if (i.opcode() == SpvOpMemberDecorate) {
395           if (value == i.GetSingleWordInOperand(1u) ||
396               value == std::numeric_limits<uint32_t>::max())
397             return false;
398         }
399 
400         return true;
401       });
402 }
403 
UpgradeFlags(Instruction * inst,uint32_t in_operand,bool is_coherent,bool is_volatile,OperationType operation_type,InstructionType inst_type)404 void UpgradeMemoryModel::UpgradeFlags(Instruction* inst, uint32_t in_operand,
405                                       bool is_coherent, bool is_volatile,
406                                       OperationType operation_type,
407                                       InstructionType inst_type) {
408   if (!is_coherent && !is_volatile) return;
409 
410   uint32_t flags = 0;
411   if (inst->NumInOperands() > in_operand) {
412     flags |= inst->GetSingleWordInOperand(in_operand);
413   }
414   if (is_coherent) {
415     if (inst_type == kMemory) {
416       flags |= SpvMemoryAccessNonPrivatePointerKHRMask;
417       if (operation_type == kVisibility) {
418         flags |= SpvMemoryAccessMakePointerVisibleKHRMask;
419       } else {
420         flags |= SpvMemoryAccessMakePointerAvailableKHRMask;
421       }
422     } else {
423       flags |= SpvImageOperandsNonPrivateTexelKHRMask;
424       if (operation_type == kVisibility) {
425         flags |= SpvImageOperandsMakeTexelVisibleKHRMask;
426       } else {
427         flags |= SpvImageOperandsMakeTexelAvailableKHRMask;
428       }
429     }
430   }
431 
432   if (is_volatile) {
433     if (inst_type == kMemory) {
434       flags |= SpvMemoryAccessVolatileMask;
435     } else {
436       flags |= SpvImageOperandsVolatileTexelKHRMask;
437     }
438   }
439 
440   if (inst->NumInOperands() > in_operand) {
441     inst->SetInOperand(in_operand, {flags});
442   } else if (inst_type == kMemory) {
443     inst->AddOperand({SPV_OPERAND_TYPE_OPTIONAL_MEMORY_ACCESS, {flags}});
444   } else {
445     inst->AddOperand({SPV_OPERAND_TYPE_OPTIONAL_IMAGE, {flags}});
446   }
447 }
448 
GetScopeConstant(SpvScope scope)449 uint32_t UpgradeMemoryModel::GetScopeConstant(SpvScope scope) {
450   analysis::Integer int_ty(32, false);
451   uint32_t int_id = context()->get_type_mgr()->GetTypeInstruction(&int_ty);
452   const analysis::Constant* constant =
453       context()->get_constant_mgr()->GetConstant(
454           context()->get_type_mgr()->GetType(int_id),
455           {static_cast<uint32_t>(scope)});
456   return context()
457       ->get_constant_mgr()
458       ->GetDefiningInstruction(constant)
459       ->result_id();
460 }
461 
CleanupDecorations()462 void UpgradeMemoryModel::CleanupDecorations() {
463   // All of the volatile and coherent decorations have been dealt with, so now
464   // we can just remove them.
465   get_module()->ForEachInst([this](Instruction* inst) {
466     if (inst->result_id() != 0) {
467       context()->get_decoration_mgr()->RemoveDecorationsFrom(
468           inst->result_id(), [](const Instruction& dec) {
469             switch (dec.opcode()) {
470               case SpvOpDecorate:
471               case SpvOpDecorateId:
472                 if (dec.GetSingleWordInOperand(1u) == SpvDecorationCoherent ||
473                     dec.GetSingleWordInOperand(1u) == SpvDecorationVolatile)
474                   return true;
475                 break;
476               case SpvOpMemberDecorate:
477                 if (dec.GetSingleWordInOperand(2u) == SpvDecorationCoherent ||
478                     dec.GetSingleWordInOperand(2u) == SpvDecorationVolatile)
479                   return true;
480                 break;
481               default:
482                 break;
483             }
484             return false;
485           });
486     }
487   });
488 }
489 
UpgradeBarriers()490 void UpgradeMemoryModel::UpgradeBarriers() {
491   std::vector<Instruction*> barriers;
492   // Collects all the control barriers in |function|. Returns true if the
493   // function operates on the Output storage class.
494   ProcessFunction CollectBarriers = [this, &barriers](Function* function) {
495     bool operates_on_output = false;
496     for (auto& block : *function) {
497       block.ForEachInst([this, &barriers,
498                          &operates_on_output](Instruction* inst) {
499         if (inst->opcode() == SpvOpControlBarrier) {
500           barriers.push_back(inst);
501         } else if (!operates_on_output) {
502           // This instruction operates on output storage class if it is a
503           // pointer to output type or any input operand is a pointer to output
504           // type.
505           analysis::Type* type =
506               context()->get_type_mgr()->GetType(inst->type_id());
507           if (type && type->AsPointer() &&
508               type->AsPointer()->storage_class() == SpvStorageClassOutput) {
509             operates_on_output = true;
510             return;
511           }
512           inst->ForEachInId([this, &operates_on_output](uint32_t* id_ptr) {
513             Instruction* op_inst =
514                 context()->get_def_use_mgr()->GetDef(*id_ptr);
515             analysis::Type* op_type =
516                 context()->get_type_mgr()->GetType(op_inst->type_id());
517             if (op_type && op_type->AsPointer() &&
518                 op_type->AsPointer()->storage_class() == SpvStorageClassOutput)
519               operates_on_output = true;
520           });
521         }
522       });
523     }
524     return operates_on_output;
525   };
526 
527   std::queue<uint32_t> roots;
528   for (auto& e : get_module()->entry_points())
529     if (e.GetSingleWordInOperand(0u) == SpvExecutionModelTessellationControl) {
530       roots.push(e.GetSingleWordInOperand(1u));
531       if (context()->ProcessCallTreeFromRoots(CollectBarriers, &roots)) {
532         for (auto barrier : barriers) {
533           // Add OutputMemoryKHR to the semantics of the barriers.
534           uint32_t semantics_id = barrier->GetSingleWordInOperand(2u);
535           Instruction* semantics_inst =
536               context()->get_def_use_mgr()->GetDef(semantics_id);
537           analysis::Type* semantics_type =
538               context()->get_type_mgr()->GetType(semantics_inst->type_id());
539           uint64_t semantics_value = GetIndexValue(semantics_inst);
540           const analysis::Constant* constant =
541               context()->get_constant_mgr()->GetConstant(
542                   semantics_type, {static_cast<uint32_t>(semantics_value) |
543                                    SpvMemorySemanticsOutputMemoryKHRMask});
544           barrier->SetInOperand(2u, {context()
545                                          ->get_constant_mgr()
546                                          ->GetDefiningInstruction(constant)
547                                          ->result_id()});
548         }
549       }
550       barriers.clear();
551     }
552 }
553 
UpgradeMemoryScope()554 void UpgradeMemoryModel::UpgradeMemoryScope() {
555   get_module()->ForEachInst([this](Instruction* inst) {
556     // Don't need to handle all the operations that take a scope.
557     // * Group operations can only be subgroup
558     // * Non-uniform can only be workgroup or subgroup
559     // * Named barriers are not supported by Vulkan
560     // * Workgroup ops (e.g. async_copy) have at most workgroup scope.
561     if (spvOpcodeIsAtomicOp(inst->opcode())) {
562       if (IsDeviceScope(inst->GetSingleWordInOperand(1))) {
563         inst->SetInOperand(1, {GetScopeConstant(SpvScopeQueueFamilyKHR)});
564       }
565     } else if (inst->opcode() == SpvOpControlBarrier) {
566       if (IsDeviceScope(inst->GetSingleWordInOperand(1))) {
567         inst->SetInOperand(1, {GetScopeConstant(SpvScopeQueueFamilyKHR)});
568       }
569     } else if (inst->opcode() == SpvOpMemoryBarrier) {
570       if (IsDeviceScope(inst->GetSingleWordInOperand(0))) {
571         inst->SetInOperand(0, {GetScopeConstant(SpvScopeQueueFamilyKHR)});
572       }
573     }
574   });
575 }
576 
IsDeviceScope(uint32_t scope_id)577 bool UpgradeMemoryModel::IsDeviceScope(uint32_t scope_id) {
578   const analysis::Constant* constant =
579       context()->get_constant_mgr()->FindDeclaredConstant(scope_id);
580   assert(constant && "Memory scope must be a constant");
581 
582   const analysis::Integer* type = constant->type()->AsInteger();
583   assert(type);
584   assert(type->width() == 32 || type->width() == 64);
585   if (type->width() == 32) {
586     if (type->IsSigned())
587       return static_cast<uint32_t>(constant->GetS32()) == SpvScopeDevice;
588     else
589       return static_cast<uint32_t>(constant->GetU32()) == SpvScopeDevice;
590   } else {
591     if (type->IsSigned())
592       return static_cast<uint32_t>(constant->GetS64()) == SpvScopeDevice;
593     else
594       return static_cast<uint32_t>(constant->GetU64()) == SpvScopeDevice;
595   }
596 
597   assert(false);
598   return false;
599 }
600 
UpgradeExtInst(Instruction * ext_inst)601 void UpgradeMemoryModel::UpgradeExtInst(Instruction* ext_inst) {
602   const bool is_modf = ext_inst->GetSingleWordInOperand(1u) == GLSLstd450Modf;
603   auto ptr_id = ext_inst->GetSingleWordInOperand(3u);
604   auto ptr_type_id = get_def_use_mgr()->GetDef(ptr_id)->type_id();
605   auto pointee_type_id =
606       get_def_use_mgr()->GetDef(ptr_type_id)->GetSingleWordInOperand(1u);
607   auto element_type_id = ext_inst->type_id();
608   std::vector<const analysis::Type*> element_types(2);
609   element_types[0] = context()->get_type_mgr()->GetType(element_type_id);
610   element_types[1] = context()->get_type_mgr()->GetType(pointee_type_id);
611   analysis::Struct struct_type(element_types);
612   uint32_t struct_id =
613       context()->get_type_mgr()->GetTypeInstruction(&struct_type);
614   // Change the operation
615   GLSLstd450 new_op = is_modf ? GLSLstd450ModfStruct : GLSLstd450FrexpStruct;
616   ext_inst->SetOperand(3u, {static_cast<uint32_t>(new_op)});
617   // Remove the pointer argument
618   ext_inst->RemoveOperand(5u);
619   // Set the type id to the new struct.
620   ext_inst->SetResultType(struct_id);
621 
622   // The result is now a struct of the original result. The zero'th element is
623   // old result and should replace the old result. The one'th element needs to
624   // be stored via a new instruction.
625   auto where = ext_inst->NextNode();
626   InstructionBuilder builder(
627       context(), where,
628       IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);
629   auto extract_0 =
630       builder.AddCompositeExtract(element_type_id, ext_inst->result_id(), {0});
631   context()->ReplaceAllUsesWith(ext_inst->result_id(), extract_0->result_id());
632   // The extract's input was just changed to itself, so fix that.
633   extract_0->SetInOperand(0u, {ext_inst->result_id()});
634   auto extract_1 =
635       builder.AddCompositeExtract(pointee_type_id, ext_inst->result_id(), {1});
636   builder.AddStore(ptr_id, extract_1->result_id());
637 }
638 
639 }  // namespace opt
640 }  // namespace spvtools
641