1 // Copyright (c) 2018 Google LLC.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #include "source/val/validate.h"
16 
17 #include <algorithm>
18 #include <string>
19 #include <vector>
20 
21 #include "source/opcode.h"
22 #include "source/spirv_target_env.h"
23 #include "source/val/instruction.h"
24 #include "source/val/validate_scopes.h"
25 #include "source/val/validation_state.h"
26 
27 namespace spvtools {
28 namespace val {
29 namespace {
30 
31 bool AreLayoutCompatibleStructs(ValidationState_t&, const Instruction*,
32                                 const Instruction*);
33 bool HaveLayoutCompatibleMembers(ValidationState_t&, const Instruction*,
34                                  const Instruction*);
35 bool HaveSameLayoutDecorations(ValidationState_t&, const Instruction*,
36                                const Instruction*);
37 bool HasConflictingMemberOffsets(const std::vector<Decoration>&,
38                                  const std::vector<Decoration>&);
39 
IsAllowedTypeOrArrayOfSame(ValidationState_t & _,const Instruction * type,std::initializer_list<uint32_t> allowed)40 bool IsAllowedTypeOrArrayOfSame(ValidationState_t& _, const Instruction* type,
41                                 std::initializer_list<uint32_t> allowed) {
42   if (std::find(allowed.begin(), allowed.end(), type->opcode()) !=
43       allowed.end()) {
44     return true;
45   }
46   if (type->opcode() == SpvOpTypeArray ||
47       type->opcode() == SpvOpTypeRuntimeArray) {
48     auto elem_type = _.FindDef(type->word(2));
49     return std::find(allowed.begin(), allowed.end(), elem_type->opcode()) !=
50            allowed.end();
51   }
52   return false;
53 }
54 
55 // Returns true if the two instructions represent structs that, as far as the
56 // validator can tell, have the exact same data layout.
AreLayoutCompatibleStructs(ValidationState_t & _,const Instruction * type1,const Instruction * type2)57 bool AreLayoutCompatibleStructs(ValidationState_t& _, const Instruction* type1,
58                                 const Instruction* type2) {
59   if (type1->opcode() != SpvOpTypeStruct) {
60     return false;
61   }
62   if (type2->opcode() != SpvOpTypeStruct) {
63     return false;
64   }
65 
66   if (!HaveLayoutCompatibleMembers(_, type1, type2)) return false;
67 
68   return HaveSameLayoutDecorations(_, type1, type2);
69 }
70 
71 // Returns true if the operands to the OpTypeStruct instruction defining the
72 // types are the same or are layout compatible types. |type1| and |type2| must
73 // be OpTypeStruct instructions.
HaveLayoutCompatibleMembers(ValidationState_t & _,const Instruction * type1,const Instruction * type2)74 bool HaveLayoutCompatibleMembers(ValidationState_t& _, const Instruction* type1,
75                                  const Instruction* type2) {
76   assert(type1->opcode() == SpvOpTypeStruct &&
77          "type1 must be an OpTypeStruct instruction.");
78   assert(type2->opcode() == SpvOpTypeStruct &&
79          "type2 must be an OpTypeStruct instruction.");
80   const auto& type1_operands = type1->operands();
81   const auto& type2_operands = type2->operands();
82   if (type1_operands.size() != type2_operands.size()) {
83     return false;
84   }
85 
86   for (size_t operand = 2; operand < type1_operands.size(); ++operand) {
87     if (type1->word(operand) != type2->word(operand)) {
88       auto def1 = _.FindDef(type1->word(operand));
89       auto def2 = _.FindDef(type2->word(operand));
90       if (!AreLayoutCompatibleStructs(_, def1, def2)) {
91         return false;
92       }
93     }
94   }
95   return true;
96 }
97 
98 // Returns true if all decorations that affect the data layout of the struct
99 // (like Offset), are the same for the two types. |type1| and |type2| must be
100 // OpTypeStruct instructions.
HaveSameLayoutDecorations(ValidationState_t & _,const Instruction * type1,const Instruction * type2)101 bool HaveSameLayoutDecorations(ValidationState_t& _, const Instruction* type1,
102                                const Instruction* type2) {
103   assert(type1->opcode() == SpvOpTypeStruct &&
104          "type1 must be an OpTypeStruct instruction.");
105   assert(type2->opcode() == SpvOpTypeStruct &&
106          "type2 must be an OpTypeStruct instruction.");
107   const std::vector<Decoration>& type1_decorations =
108       _.id_decorations(type1->id());
109   const std::vector<Decoration>& type2_decorations =
110       _.id_decorations(type2->id());
111 
112   // TODO: Will have to add other check for arrays an matricies if we want to
113   // handle them.
114   if (HasConflictingMemberOffsets(type1_decorations, type2_decorations)) {
115     return false;
116   }
117 
118   return true;
119 }
120 
HasConflictingMemberOffsets(const std::vector<Decoration> & type1_decorations,const std::vector<Decoration> & type2_decorations)121 bool HasConflictingMemberOffsets(
122     const std::vector<Decoration>& type1_decorations,
123     const std::vector<Decoration>& type2_decorations) {
124   {
125     // We are interested in conflicting decoration.  If a decoration is in one
126     // list but not the other, then we will assume the code is correct.  We are
127     // looking for things we know to be wrong.
128     //
129     // We do not have to traverse type2_decoration because, after traversing
130     // type1_decorations, anything new will not be found in
131     // type1_decoration.  Therefore, it cannot lead to a conflict.
132     for (const Decoration& decoration : type1_decorations) {
133       switch (decoration.dec_type()) {
134         case SpvDecorationOffset: {
135           // Since these affect the layout of the struct, they must be present
136           // in both structs.
137           auto compare = [&decoration](const Decoration& rhs) {
138             if (rhs.dec_type() != SpvDecorationOffset) return false;
139             return decoration.struct_member_index() ==
140                    rhs.struct_member_index();
141           };
142           auto i = std::find_if(type2_decorations.begin(),
143                                 type2_decorations.end(), compare);
144           if (i != type2_decorations.end() &&
145               decoration.params().front() != i->params().front()) {
146             return true;
147           }
148         } break;
149         default:
150           // This decoration does not affect the layout of the structure, so
151           // just moving on.
152           break;
153       }
154     }
155   }
156   return false;
157 }
158 
159 // If |skip_builtin| is true, returns true if |storage| contains bool within
160 // it and no storage that contains the bool is builtin.
161 // If |skip_builtin| is false, returns true if |storage| contains bool within
162 // it.
ContainsInvalidBool(ValidationState_t & _,const Instruction * storage,bool skip_builtin)163 bool ContainsInvalidBool(ValidationState_t& _, const Instruction* storage,
164                          bool skip_builtin) {
165   if (skip_builtin) {
166     for (const Decoration& decoration : _.id_decorations(storage->id())) {
167       if (decoration.dec_type() == SpvDecorationBuiltIn) return false;
168     }
169   }
170 
171   const size_t elem_type_index = 1;
172   uint32_t elem_type_id;
173   Instruction* elem_type;
174 
175   switch (storage->opcode()) {
176     case SpvOpTypeBool:
177       return true;
178     case SpvOpTypeVector:
179     case SpvOpTypeMatrix:
180     case SpvOpTypeArray:
181     case SpvOpTypeRuntimeArray:
182       elem_type_id = storage->GetOperandAs<uint32_t>(elem_type_index);
183       elem_type = _.FindDef(elem_type_id);
184       return ContainsInvalidBool(_, elem_type, skip_builtin);
185     case SpvOpTypeStruct:
186       for (size_t member_type_index = 1;
187            member_type_index < storage->operands().size();
188            ++member_type_index) {
189         auto member_type_id =
190             storage->GetOperandAs<uint32_t>(member_type_index);
191         auto member_type = _.FindDef(member_type_id);
192         if (ContainsInvalidBool(_, member_type, skip_builtin)) return true;
193       }
194     default:
195       break;
196   }
197   return false;
198 }
199 
GetStorageClass(ValidationState_t & _,const Instruction * inst)200 std::pair<SpvStorageClass, SpvStorageClass> GetStorageClass(
201     ValidationState_t& _, const Instruction* inst) {
202   SpvStorageClass dst_sc = SpvStorageClassMax;
203   SpvStorageClass src_sc = SpvStorageClassMax;
204   switch (inst->opcode()) {
205     case SpvOpLoad: {
206       auto load_pointer = _.FindDef(inst->GetOperandAs<uint32_t>(2));
207       auto load_pointer_type = _.FindDef(load_pointer->type_id());
208       dst_sc = load_pointer_type->GetOperandAs<SpvStorageClass>(1);
209       break;
210     }
211     case SpvOpStore: {
212       auto store_pointer = _.FindDef(inst->GetOperandAs<uint32_t>(0));
213       auto store_pointer_type = _.FindDef(store_pointer->type_id());
214       dst_sc = store_pointer_type->GetOperandAs<SpvStorageClass>(1);
215       break;
216     }
217     case SpvOpCopyMemory:
218     case SpvOpCopyMemorySized: {
219       auto dst = _.FindDef(inst->GetOperandAs<uint32_t>(0));
220       auto dst_type = _.FindDef(dst->type_id());
221       dst_sc = dst_type->GetOperandAs<SpvStorageClass>(1);
222       auto src = _.FindDef(inst->GetOperandAs<uint32_t>(1));
223       auto src_type = _.FindDef(src->type_id());
224       src_sc = src_type->GetOperandAs<SpvStorageClass>(1);
225       break;
226     }
227     default:
228       break;
229   }
230 
231   return std::make_pair(dst_sc, src_sc);
232 }
233 
234 // This function is only called for OpLoad, OpStore, OpCopyMemory and
235 // OpCopyMemorySized.
GetMakeAvailableScope(const Instruction * inst,uint32_t mask)236 uint32_t GetMakeAvailableScope(const Instruction* inst, uint32_t mask) {
237   uint32_t offset = 1;
238   if (mask & SpvMemoryAccessAlignedMask) ++offset;
239 
240   uint32_t scope_id = 0;
241   switch (inst->opcode()) {
242     case SpvOpLoad:
243     case SpvOpCopyMemorySized:
244       return inst->GetOperandAs<uint32_t>(3 + offset);
245     case SpvOpStore:
246     case SpvOpCopyMemory:
247       return inst->GetOperandAs<uint32_t>(2 + offset);
248     default:
249       assert(false && "unexpected opcode");
250       break;
251   }
252 
253   return scope_id;
254 }
255 
256 // This function is only called for OpLoad, OpStore, OpCopyMemory and
257 // OpCopyMemorySized.
GetMakeVisibleScope(const Instruction * inst,uint32_t mask)258 uint32_t GetMakeVisibleScope(const Instruction* inst, uint32_t mask) {
259   uint32_t offset = 1;
260   if (mask & SpvMemoryAccessAlignedMask) ++offset;
261   if (mask & SpvMemoryAccessMakePointerAvailableKHRMask) ++offset;
262 
263   uint32_t scope_id = 0;
264   switch (inst->opcode()) {
265     case SpvOpLoad:
266     case SpvOpCopyMemorySized:
267       return inst->GetOperandAs<uint32_t>(3 + offset);
268     case SpvOpStore:
269     case SpvOpCopyMemory:
270       return inst->GetOperandAs<uint32_t>(2 + offset);
271     default:
272       assert(false && "unexpected opcode");
273       break;
274   }
275 
276   return scope_id;
277 }
278 
CheckMemoryAccess(ValidationState_t & _,const Instruction * inst,uint32_t mask)279 spv_result_t CheckMemoryAccess(ValidationState_t& _, const Instruction* inst,
280                                uint32_t mask) {
281   if (mask & SpvMemoryAccessMakePointerAvailableKHRMask) {
282     if (inst->opcode() == SpvOpLoad) {
283       return _.diag(SPV_ERROR_INVALID_ID, inst)
284              << "MakePointerAvailableKHR cannot be used with OpLoad.";
285     }
286 
287     if (!(mask & SpvMemoryAccessNonPrivatePointerKHRMask)) {
288       return _.diag(SPV_ERROR_INVALID_ID, inst)
289              << "NonPrivatePointerKHR must be specified if "
290                 "MakePointerAvailableKHR is specified.";
291     }
292 
293     // Check the associated scope for MakeAvailableKHR.
294     const auto available_scope = GetMakeAvailableScope(inst, mask);
295     if (auto error = ValidateMemoryScope(_, inst, available_scope))
296       return error;
297   }
298 
299   if (mask & SpvMemoryAccessMakePointerVisibleKHRMask) {
300     if (inst->opcode() == SpvOpStore) {
301       return _.diag(SPV_ERROR_INVALID_ID, inst)
302              << "MakePointerVisibleKHR cannot be used with OpStore.";
303     }
304 
305     if (!(mask & SpvMemoryAccessNonPrivatePointerKHRMask)) {
306       return _.diag(SPV_ERROR_INVALID_ID, inst)
307              << "NonPrivatePointerKHR must be specified if "
308                 "MakePointerVisibleKHR is specified.";
309     }
310 
311     // Check the associated scope for MakeVisibleKHR.
312     const auto visible_scope = GetMakeVisibleScope(inst, mask);
313     if (auto error = ValidateMemoryScope(_, inst, visible_scope)) return error;
314   }
315 
316   if (mask & SpvMemoryAccessNonPrivatePointerKHRMask) {
317     SpvStorageClass dst_sc, src_sc;
318     std::tie(dst_sc, src_sc) = GetStorageClass(_, inst);
319     if (dst_sc != SpvStorageClassUniform &&
320         dst_sc != SpvStorageClassWorkgroup &&
321         dst_sc != SpvStorageClassCrossWorkgroup &&
322         dst_sc != SpvStorageClassGeneric && dst_sc != SpvStorageClassImage &&
323         dst_sc != SpvStorageClassStorageBuffer) {
324       return _.diag(SPV_ERROR_INVALID_ID, inst)
325              << "NonPrivatePointerKHR requires a pointer in Uniform, "
326                 "Workgroup, CrossWorkgroup, Generic, Image or StorageBuffer "
327                 "storage classes.";
328     }
329     if (src_sc != SpvStorageClassMax && src_sc != SpvStorageClassUniform &&
330         src_sc != SpvStorageClassWorkgroup &&
331         src_sc != SpvStorageClassCrossWorkgroup &&
332         src_sc != SpvStorageClassGeneric && src_sc != SpvStorageClassImage &&
333         src_sc != SpvStorageClassStorageBuffer) {
334       return _.diag(SPV_ERROR_INVALID_ID, inst)
335              << "NonPrivatePointerKHR requires a pointer in Uniform, "
336                 "Workgroup, CrossWorkgroup, Generic, Image or StorageBuffer "
337                 "storage classes.";
338     }
339   }
340 
341   return SPV_SUCCESS;
342 }
343 
ValidateVariable(ValidationState_t & _,const Instruction * inst)344 spv_result_t ValidateVariable(ValidationState_t& _, const Instruction* inst) {
345   auto result_type = _.FindDef(inst->type_id());
346   if (!result_type || result_type->opcode() != SpvOpTypePointer) {
347     return _.diag(SPV_ERROR_INVALID_ID, inst)
348            << "OpVariable Result Type <id> '" << _.getIdName(inst->type_id())
349            << "' is not a pointer type.";
350   }
351 
352   const auto initializer_index = 3;
353   const auto storage_class_index = 2;
354   if (initializer_index < inst->operands().size()) {
355     const auto initializer_id = inst->GetOperandAs<uint32_t>(initializer_index);
356     const auto initializer = _.FindDef(initializer_id);
357     const auto is_module_scope_var =
358         initializer && (initializer->opcode() == SpvOpVariable) &&
359         (initializer->GetOperandAs<SpvStorageClass>(storage_class_index) !=
360          SpvStorageClassFunction);
361     const auto is_constant =
362         initializer && spvOpcodeIsConstant(initializer->opcode());
363     if (!initializer || !(is_constant || is_module_scope_var)) {
364       return _.diag(SPV_ERROR_INVALID_ID, inst)
365              << "OpVariable Initializer <id> '" << _.getIdName(initializer_id)
366              << "' is not a constant or module-scope variable.";
367     }
368   }
369 
370   const auto storage_class =
371       inst->GetOperandAs<SpvStorageClass>(storage_class_index);
372   if (storage_class != SpvStorageClassWorkgroup &&
373       storage_class != SpvStorageClassCrossWorkgroup &&
374       storage_class != SpvStorageClassPrivate &&
375       storage_class != SpvStorageClassFunction) {
376     const auto storage_index = 2;
377     const auto storage_id = result_type->GetOperandAs<uint32_t>(storage_index);
378     const auto storage = _.FindDef(storage_id);
379     bool storage_input_or_output = storage_class == SpvStorageClassInput ||
380                                    storage_class == SpvStorageClassOutput;
381     bool builtin = false;
382     if (storage_input_or_output) {
383       for (const Decoration& decoration : _.id_decorations(inst->id())) {
384         if (decoration.dec_type() == SpvDecorationBuiltIn) {
385           builtin = true;
386           break;
387         }
388       }
389     }
390     if (!(storage_input_or_output && builtin) &&
391         ContainsInvalidBool(_, storage, storage_input_or_output)) {
392       return _.diag(SPV_ERROR_INVALID_ID, inst)
393              << "If OpTypeBool is stored in conjunction with OpVariable, it "
394              << "can only be used with non-externally visible shader Storage "
395              << "Classes: Workgroup, CrossWorkgroup, Private, and Function";
396     }
397   }
398 
399   // SPIR-V 3.32.8: Check that pointer type and variable type have the same
400   // storage class.
401   const auto result_storage_class_index = 1;
402   const auto result_storage_class =
403       result_type->GetOperandAs<uint32_t>(result_storage_class_index);
404   if (storage_class != result_storage_class) {
405     return _.diag(SPV_ERROR_INVALID_ID, inst)
406            << "From SPIR-V spec, section 3.32.8 on OpVariable:\n"
407            << "Its Storage Class operand must be the same as the Storage Class "
408            << "operand of the result type.";
409   }
410 
411   // Variable pointer related restrictions.
412   auto pointee = _.FindDef(result_type->word(3));
413   if (_.addressing_model() == SpvAddressingModelLogical &&
414       !_.options()->relax_logical_pointer) {
415     // VariablePointersStorageBuffer is implied by VariablePointers.
416     if (pointee->opcode() == SpvOpTypePointer) {
417       if (!_.HasCapability(SpvCapabilityVariablePointersStorageBuffer)) {
418         return _.diag(SPV_ERROR_INVALID_ID, inst) << "In Logical addressing, "
419                                                      "variables may not "
420                                                      "allocate a pointer type";
421       } else if (storage_class != SpvStorageClassFunction &&
422                  storage_class != SpvStorageClassPrivate) {
423         return _.diag(SPV_ERROR_INVALID_ID, inst)
424                << "In Logical addressing with variable pointers, variables "
425                   "that allocate pointers must be in Function or Private "
426                   "storage classes";
427       }
428     }
429   }
430 
431   // Vulkan 14.5.1: Check type of PushConstant variables.
432   // Vulkan 14.5.2: Check type of UniformConstant and Uniform variables.
433   if (spvIsVulkanEnv(_.context()->target_env)) {
434     if (storage_class == SpvStorageClassPushConstant) {
435       if (!IsAllowedTypeOrArrayOfSame(_, pointee, {SpvOpTypeStruct})) {
436         return _.diag(SPV_ERROR_INVALID_ID, inst)
437                << "PushConstant OpVariable <id> '" << _.getIdName(inst->id())
438                << "' has illegal type.\n"
439                << "From Vulkan spec, section 14.5.1:\n"
440                << "Such variables must be typed as OpTypeStruct, "
441                << "or an array of this type";
442       }
443     }
444 
445     if (storage_class == SpvStorageClassUniformConstant) {
446       if (!IsAllowedTypeOrArrayOfSame(
447               _, pointee,
448               {SpvOpTypeImage, SpvOpTypeSampler, SpvOpTypeSampledImage,
449                SpvOpTypeAccelerationStructureNV})) {
450         return _.diag(SPV_ERROR_INVALID_ID, inst)
451                << "UniformConstant OpVariable <id> '" << _.getIdName(inst->id())
452                << "' has illegal type.\n"
453                << "From Vulkan spec, section 14.5.2:\n"
454                << "Variables identified with the UniformConstant storage class "
455                << "are used only as handles to refer to opaque resources. Such "
456                << "variables must be typed as OpTypeImage, OpTypeSampler, "
457                << "OpTypeSampledImage, OpTypeAccelerationStructureNV, "
458                << "or an array of one of these types.";
459       }
460     }
461 
462     if (storage_class == SpvStorageClassUniform) {
463       if (!IsAllowedTypeOrArrayOfSame(_, pointee, {SpvOpTypeStruct})) {
464         return _.diag(SPV_ERROR_INVALID_ID, inst)
465                << "Uniform OpVariable <id> '" << _.getIdName(inst->id())
466                << "' has illegal type.\n"
467                << "From Vulkan spec, section 14.5.2:\n"
468                << "Variables identified with the Uniform storage class are "
469                   "used "
470                << "to access transparent buffer backed resources. Such "
471                   "variables "
472                << "must be typed as OpTypeStruct, or an array of this type";
473       }
474     }
475   }
476 
477   // WebGPU & Vulkan Appendix A: Check that if contains initializer, then
478   // storage class is Output, Private, or Function.
479   if (inst->operands().size() > 3 && storage_class != SpvStorageClassOutput &&
480       storage_class != SpvStorageClassPrivate &&
481       storage_class != SpvStorageClassFunction) {
482     if (spvIsVulkanEnv(_.context()->target_env)) {
483       return _.diag(SPV_ERROR_INVALID_ID, inst)
484              << "OpVariable, <id> '" << _.getIdName(inst->id())
485              << "', has a disallowed initializer & storage class "
486              << "combination.\n"
487              << "From Vulkan spec, Appendix A:\n"
488              << "Variable declarations that include initializers must have "
489              << "one of the following storage classes: Output, Private, or "
490              << "Function";
491     }
492 
493     if (spvIsWebGPUEnv(_.context()->target_env)) {
494       return _.diag(SPV_ERROR_INVALID_ID, inst)
495              << "OpVariable, <id> '" << _.getIdName(inst->id())
496              << "', has a disallowed initializer & storage class "
497              << "combination.\n"
498              << "From WebGPU execution environment spec:\n"
499              << "Variable declarations that include initializers must have "
500              << "one of the following storage classes: Output, Private, or "
501              << "Function";
502     }
503   }
504 
505   return SPV_SUCCESS;
506 }
507 
ValidateLoad(ValidationState_t & _,const Instruction * inst)508 spv_result_t ValidateLoad(ValidationState_t& _, const Instruction* inst) {
509   const auto result_type = _.FindDef(inst->type_id());
510   if (!result_type) {
511     return _.diag(SPV_ERROR_INVALID_ID, inst)
512            << "OpLoad Result Type <id> '" << _.getIdName(inst->type_id())
513            << "' is not defined.";
514   }
515 
516   const bool uses_variable_pointers =
517       _.features().variable_pointers ||
518       _.features().variable_pointers_storage_buffer;
519   const auto pointer_index = 2;
520   const auto pointer_id = inst->GetOperandAs<uint32_t>(pointer_index);
521   const auto pointer = _.FindDef(pointer_id);
522   if (!pointer ||
523       ((_.addressing_model() == SpvAddressingModelLogical) &&
524        ((!uses_variable_pointers &&
525          !spvOpcodeReturnsLogicalPointer(pointer->opcode())) ||
526         (uses_variable_pointers &&
527          !spvOpcodeReturnsLogicalVariablePointer(pointer->opcode()))))) {
528     return _.diag(SPV_ERROR_INVALID_ID, inst)
529            << "OpLoad Pointer <id> '" << _.getIdName(pointer_id)
530            << "' is not a logical pointer.";
531   }
532 
533   const auto pointer_type = _.FindDef(pointer->type_id());
534   if (!pointer_type || pointer_type->opcode() != SpvOpTypePointer) {
535     return _.diag(SPV_ERROR_INVALID_ID, inst)
536            << "OpLoad type for pointer <id> '" << _.getIdName(pointer_id)
537            << "' is not a pointer type.";
538   }
539 
540   const auto pointee_type = _.FindDef(pointer_type->GetOperandAs<uint32_t>(2));
541   if (!pointee_type || result_type->id() != pointee_type->id()) {
542     return _.diag(SPV_ERROR_INVALID_ID, inst)
543            << "OpLoad Result Type <id> '" << _.getIdName(inst->type_id())
544            << "' does not match Pointer <id> '" << _.getIdName(pointer->id())
545            << "'s type.";
546   }
547 
548   if (inst->operands().size() > 3) {
549     if (auto error =
550             CheckMemoryAccess(_, inst, inst->GetOperandAs<uint32_t>(3)))
551       return error;
552   }
553 
554   return SPV_SUCCESS;
555 }
556 
ValidateStore(ValidationState_t & _,const Instruction * inst)557 spv_result_t ValidateStore(ValidationState_t& _, const Instruction* inst) {
558   const bool uses_variable_pointer =
559       _.features().variable_pointers ||
560       _.features().variable_pointers_storage_buffer;
561   const auto pointer_index = 0;
562   const auto pointer_id = inst->GetOperandAs<uint32_t>(pointer_index);
563   const auto pointer = _.FindDef(pointer_id);
564   if (!pointer ||
565       (_.addressing_model() == SpvAddressingModelLogical &&
566        ((!uses_variable_pointer &&
567          !spvOpcodeReturnsLogicalPointer(pointer->opcode())) ||
568         (uses_variable_pointer &&
569          !spvOpcodeReturnsLogicalVariablePointer(pointer->opcode()))))) {
570     return _.diag(SPV_ERROR_INVALID_ID, inst)
571            << "OpStore Pointer <id> '" << _.getIdName(pointer_id)
572            << "' is not a logical pointer.";
573   }
574   const auto pointer_type = _.FindDef(pointer->type_id());
575   if (!pointer_type || pointer_type->opcode() != SpvOpTypePointer) {
576     return _.diag(SPV_ERROR_INVALID_ID, inst)
577            << "OpStore type for pointer <id> '" << _.getIdName(pointer_id)
578            << "' is not a pointer type.";
579   }
580   const auto type_id = pointer_type->GetOperandAs<uint32_t>(2);
581   const auto type = _.FindDef(type_id);
582   if (!type || SpvOpTypeVoid == type->opcode()) {
583     return _.diag(SPV_ERROR_INVALID_ID, inst)
584            << "OpStore Pointer <id> '" << _.getIdName(pointer_id)
585            << "'s type is void.";
586   }
587 
588   // validate storage class
589   {
590     uint32_t data_type;
591     uint32_t storage_class;
592     if (!_.GetPointerTypeInfo(pointer_type->id(), &data_type, &storage_class)) {
593       return _.diag(SPV_ERROR_INVALID_ID, inst)
594              << "OpStore Pointer <id> '" << _.getIdName(pointer_id)
595              << "' is not pointer type";
596     }
597 
598     if (storage_class == SpvStorageClassUniformConstant ||
599         storage_class == SpvStorageClassInput ||
600         storage_class == SpvStorageClassPushConstant) {
601       return _.diag(SPV_ERROR_INVALID_ID, inst)
602              << "OpStore Pointer <id> '" << _.getIdName(pointer_id)
603              << "' storage class is read-only";
604     }
605   }
606 
607   const auto object_index = 1;
608   const auto object_id = inst->GetOperandAs<uint32_t>(object_index);
609   const auto object = _.FindDef(object_id);
610   if (!object || !object->type_id()) {
611     return _.diag(SPV_ERROR_INVALID_ID, inst)
612            << "OpStore Object <id> '" << _.getIdName(object_id)
613            << "' is not an object.";
614   }
615   const auto object_type = _.FindDef(object->type_id());
616   if (!object_type || SpvOpTypeVoid == object_type->opcode()) {
617     return _.diag(SPV_ERROR_INVALID_ID, inst)
618            << "OpStore Object <id> '" << _.getIdName(object_id)
619            << "'s type is void.";
620   }
621 
622   if (type->id() != object_type->id()) {
623     if (!_.options()->relax_struct_store || type->opcode() != SpvOpTypeStruct ||
624         object_type->opcode() != SpvOpTypeStruct) {
625       return _.diag(SPV_ERROR_INVALID_ID, inst)
626              << "OpStore Pointer <id> '" << _.getIdName(pointer_id)
627              << "'s type does not match Object <id> '"
628              << _.getIdName(object->id()) << "'s type.";
629     }
630 
631     // TODO: Check for layout compatible matricies and arrays as well.
632     if (!AreLayoutCompatibleStructs(_, type, object_type)) {
633       return _.diag(SPV_ERROR_INVALID_ID, inst)
634              << "OpStore Pointer <id> '" << _.getIdName(pointer_id)
635              << "'s layout does not match Object <id> '"
636              << _.getIdName(object->id()) << "'s layout.";
637     }
638   }
639 
640   if (inst->operands().size() > 2) {
641     if (auto error =
642             CheckMemoryAccess(_, inst, inst->GetOperandAs<uint32_t>(2)))
643       return error;
644   }
645 
646   return SPV_SUCCESS;
647 }
648 
ValidateCopyMemory(ValidationState_t & _,const Instruction * inst)649 spv_result_t ValidateCopyMemory(ValidationState_t& _, const Instruction* inst) {
650   const auto target_index = 0;
651   const auto target_id = inst->GetOperandAs<uint32_t>(target_index);
652   const auto target = _.FindDef(target_id);
653   if (!target) {
654     return _.diag(SPV_ERROR_INVALID_ID, inst)
655            << "Target operand <id> '" << _.getIdName(target_id)
656            << "' is not defined.";
657   }
658 
659   const auto source_index = 1;
660   const auto source_id = inst->GetOperandAs<uint32_t>(source_index);
661   const auto source = _.FindDef(source_id);
662   if (!source) {
663     return _.diag(SPV_ERROR_INVALID_ID, inst)
664            << "Source operand <id> '" << _.getIdName(source_id)
665            << "' is not defined.";
666   }
667 
668   const auto target_pointer_type = _.FindDef(target->type_id());
669   if (!target_pointer_type ||
670       target_pointer_type->opcode() != SpvOpTypePointer) {
671     return _.diag(SPV_ERROR_INVALID_ID, inst)
672            << "Target operand <id> '" << _.getIdName(target_id)
673            << "' is not a pointer.";
674   }
675 
676   const auto source_pointer_type = _.FindDef(source->type_id());
677   if (!source_pointer_type ||
678       source_pointer_type->opcode() != SpvOpTypePointer) {
679     return _.diag(SPV_ERROR_INVALID_ID, inst)
680            << "Source operand <id> '" << _.getIdName(source_id)
681            << "' is not a pointer.";
682   }
683 
684   if (inst->opcode() == SpvOpCopyMemory) {
685     const auto target_type =
686         _.FindDef(target_pointer_type->GetOperandAs<uint32_t>(2));
687     if (!target_type || target_type->opcode() == SpvOpTypeVoid) {
688       return _.diag(SPV_ERROR_INVALID_ID, inst)
689              << "Target operand <id> '" << _.getIdName(target_id)
690              << "' cannot be a void pointer.";
691     }
692 
693     const auto source_type =
694         _.FindDef(source_pointer_type->GetOperandAs<uint32_t>(2));
695     if (!source_type || source_type->opcode() == SpvOpTypeVoid) {
696       return _.diag(SPV_ERROR_INVALID_ID, inst)
697              << "Source operand <id> '" << _.getIdName(source_id)
698              << "' cannot be a void pointer.";
699     }
700 
701     if (target_type->id() != source_type->id()) {
702       return _.diag(SPV_ERROR_INVALID_ID, inst)
703              << "Target <id> '" << _.getIdName(source_id)
704              << "'s type does not match Source <id> '"
705              << _.getIdName(source_type->id()) << "'s type.";
706     }
707 
708     if (inst->operands().size() > 2) {
709       if (auto error =
710               CheckMemoryAccess(_, inst, inst->GetOperandAs<uint32_t>(2)))
711         return error;
712     }
713   } else {
714     const auto size_id = inst->GetOperandAs<uint32_t>(2);
715     const auto size = _.FindDef(size_id);
716     if (!size) {
717       return _.diag(SPV_ERROR_INVALID_ID, inst)
718              << "Size operand <id> '" << _.getIdName(size_id)
719              << "' is not defined.";
720     }
721 
722     const auto size_type = _.FindDef(size->type_id());
723     if (!_.IsIntScalarType(size_type->id())) {
724       return _.diag(SPV_ERROR_INVALID_ID, inst)
725              << "Size operand <id> '" << _.getIdName(size_id)
726              << "' must be a scalar integer type.";
727     }
728 
729     bool is_zero = true;
730     switch (size->opcode()) {
731       case SpvOpConstantNull:
732         return _.diag(SPV_ERROR_INVALID_ID, inst)
733                << "Size operand <id> '" << _.getIdName(size_id)
734                << "' cannot be a constant zero.";
735       case SpvOpConstant:
736         if (size_type->word(3) == 1 &&
737             size->word(size->words().size() - 1) & 0x80000000) {
738           return _.diag(SPV_ERROR_INVALID_ID, inst)
739                  << "Size operand <id> '" << _.getIdName(size_id)
740                  << "' cannot have the sign bit set to 1.";
741         }
742         for (size_t i = 3; is_zero && i < size->words().size(); ++i) {
743           is_zero &= (size->word(i) == 0);
744         }
745         if (is_zero) {
746           return _.diag(SPV_ERROR_INVALID_ID, inst)
747                  << "Size operand <id> '" << _.getIdName(size_id)
748                  << "' cannot be a constant zero.";
749         }
750         break;
751       default:
752         // Cannot infer any other opcodes.
753         break;
754     }
755 
756     if (inst->operands().size() > 3) {
757       if (auto error =
758               CheckMemoryAccess(_, inst, inst->GetOperandAs<uint32_t>(3)))
759         return error;
760     }
761   }
762   return SPV_SUCCESS;
763 }
764 
ValidateAccessChain(ValidationState_t & _,const Instruction * inst)765 spv_result_t ValidateAccessChain(ValidationState_t& _,
766                                  const Instruction* inst) {
767   std::string instr_name =
768       "Op" + std::string(spvOpcodeString(static_cast<SpvOp>(inst->opcode())));
769 
770   // The result type must be OpTypePointer.
771   auto result_type = _.FindDef(inst->type_id());
772   if (SpvOpTypePointer != result_type->opcode()) {
773     return _.diag(SPV_ERROR_INVALID_ID, inst)
774            << "The Result Type of " << instr_name << " <id> '"
775            << _.getIdName(inst->id()) << "' must be OpTypePointer. Found Op"
776            << spvOpcodeString(static_cast<SpvOp>(result_type->opcode())) << ".";
777   }
778 
779   // Result type is a pointer. Find out what it's pointing to.
780   // This will be used to make sure the indexing results in the same type.
781   // OpTypePointer word 3 is the type being pointed to.
782   const auto result_type_pointee = _.FindDef(result_type->word(3));
783 
784   // Base must be a pointer, pointing to the base of a composite object.
785   const auto base_index = 2;
786   const auto base_id = inst->GetOperandAs<uint32_t>(base_index);
787   const auto base = _.FindDef(base_id);
788   const auto base_type = _.FindDef(base->type_id());
789   if (!base_type || SpvOpTypePointer != base_type->opcode()) {
790     return _.diag(SPV_ERROR_INVALID_ID, inst)
791            << "The Base <id> '" << _.getIdName(base_id) << "' in " << instr_name
792            << " instruction must be a pointer.";
793   }
794 
795   // The result pointer storage class and base pointer storage class must match.
796   // Word 2 of OpTypePointer is the Storage Class.
797   auto result_type_storage_class = result_type->word(2);
798   auto base_type_storage_class = base_type->word(2);
799   if (result_type_storage_class != base_type_storage_class) {
800     return _.diag(SPV_ERROR_INVALID_ID, inst)
801            << "The result pointer storage class and base "
802               "pointer storage class in "
803            << instr_name << " do not match.";
804   }
805 
806   // The type pointed to by OpTypePointer (word 3) must be a composite type.
807   auto type_pointee = _.FindDef(base_type->word(3));
808 
809   // Check Universal Limit (SPIR-V Spec. Section 2.17).
810   // The number of indexes passed to OpAccessChain may not exceed 255
811   // The instruction includes 4 words + N words (for N indexes)
812   size_t num_indexes = inst->words().size() - 4;
813   if (inst->opcode() == SpvOpPtrAccessChain ||
814       inst->opcode() == SpvOpInBoundsPtrAccessChain) {
815     // In pointer access chains, the element operand is required, but not
816     // counted as an index.
817     --num_indexes;
818   }
819   const size_t num_indexes_limit =
820       _.options()->universal_limits_.max_access_chain_indexes;
821   if (num_indexes > num_indexes_limit) {
822     return _.diag(SPV_ERROR_INVALID_ID, inst)
823            << "The number of indexes in " << instr_name << " may not exceed "
824            << num_indexes_limit << ". Found " << num_indexes << " indexes.";
825   }
826   // Indexes walk the type hierarchy to the desired depth, potentially down to
827   // scalar granularity. The first index in Indexes will select the top-level
828   // member/element/component/element of the base composite. All composite
829   // constituents use zero-based numbering, as described by their OpType...
830   // instruction. The second index will apply similarly to that result, and so
831   // on. Once any non-composite type is reached, there must be no remaining
832   // (unused) indexes.
833   auto starting_index = 4;
834   if (inst->opcode() == SpvOpPtrAccessChain ||
835       inst->opcode() == SpvOpInBoundsPtrAccessChain) {
836     ++starting_index;
837   }
838   for (size_t i = starting_index; i < inst->words().size(); ++i) {
839     const uint32_t cur_word = inst->words()[i];
840     // Earlier ID checks ensure that cur_word definition exists.
841     auto cur_word_instr = _.FindDef(cur_word);
842     // The index must be a scalar integer type (See OpAccessChain in the Spec.)
843     auto index_type = _.FindDef(cur_word_instr->type_id());
844     if (!index_type || SpvOpTypeInt != index_type->opcode()) {
845       return _.diag(SPV_ERROR_INVALID_ID, inst)
846              << "Indexes passed to " << instr_name
847              << " must be of type integer.";
848     }
849     switch (type_pointee->opcode()) {
850       case SpvOpTypeMatrix:
851       case SpvOpTypeVector:
852       case SpvOpTypeArray:
853       case SpvOpTypeRuntimeArray: {
854         // In OpTypeMatrix, OpTypeVector, OpTypeArray, and OpTypeRuntimeArray,
855         // word 2 is the Element Type.
856         type_pointee = _.FindDef(type_pointee->word(2));
857         break;
858       }
859       case SpvOpTypeStruct: {
860         // In case of structures, there is an additional constraint on the
861         // index: the index must be an OpConstant.
862         if (SpvOpConstant != cur_word_instr->opcode()) {
863           return _.diag(SPV_ERROR_INVALID_ID, cur_word_instr)
864                  << "The <id> passed to " << instr_name
865                  << " to index into a "
866                     "structure must be an OpConstant.";
867         }
868         // Get the index value from the OpConstant (word 3 of OpConstant).
869         // OpConstant could be a signed integer. But it's okay to treat it as
870         // unsigned because a negative constant int would never be seen as
871         // correct as a struct offset, since structs can't have more than 2
872         // billion members.
873         const uint32_t cur_index = cur_word_instr->word(3);
874         // The index points to the struct member we want, therefore, the index
875         // should be less than the number of struct members.
876         const uint32_t num_struct_members =
877             static_cast<uint32_t>(type_pointee->words().size() - 2);
878         if (cur_index >= num_struct_members) {
879           return _.diag(SPV_ERROR_INVALID_ID, cur_word_instr)
880                  << "Index is out of bounds: " << instr_name
881                  << " can not find index " << cur_index
882                  << " into the structure <id> '"
883                  << _.getIdName(type_pointee->id()) << "'. This structure has "
884                  << num_struct_members << " members. Largest valid index is "
885                  << num_struct_members - 1 << ".";
886         }
887         // Struct members IDs start at word 2 of OpTypeStruct.
888         auto structMemberId = type_pointee->word(cur_index + 2);
889         type_pointee = _.FindDef(structMemberId);
890         break;
891       }
892       default: {
893         // Give an error. reached non-composite type while indexes still remain.
894         return _.diag(SPV_ERROR_INVALID_ID, cur_word_instr)
895                << instr_name
896                << " reached non-composite type while indexes "
897                   "still remain to be traversed.";
898       }
899     }
900   }
901   // At this point, we have fully walked down from the base using the indeces.
902   // The type being pointed to should be the same as the result type.
903   if (type_pointee->id() != result_type_pointee->id()) {
904     return _.diag(SPV_ERROR_INVALID_ID, inst)
905            << instr_name << " result type (Op"
906            << spvOpcodeString(static_cast<SpvOp>(result_type_pointee->opcode()))
907            << ") does not match the type that results from indexing into the "
908               "base "
909               "<id> (Op"
910            << spvOpcodeString(static_cast<SpvOp>(type_pointee->opcode()))
911            << ").";
912   }
913 
914   return SPV_SUCCESS;
915 }
916 
ValidatePtrAccessChain(ValidationState_t & _,const Instruction * inst)917 spv_result_t ValidatePtrAccessChain(ValidationState_t& _,
918                                     const Instruction* inst) {
919   if (_.addressing_model() == SpvAddressingModelLogical) {
920     if (!_.features().variable_pointers &&
921         !_.features().variable_pointers_storage_buffer) {
922       return _.diag(SPV_ERROR_INVALID_DATA, inst)
923              << "Generating variable pointers requires capability "
924              << "VariablePointers or VariablePointersStorageBuffer";
925     }
926   }
927   return ValidateAccessChain(_, inst);
928 }
929 
ValidateArrayLength(ValidationState_t & state,const Instruction * inst)930 spv_result_t ValidateArrayLength(ValidationState_t& state,
931                                  const Instruction* inst) {
932   std::string instr_name =
933       "Op" + std::string(spvOpcodeString(static_cast<SpvOp>(inst->opcode())));
934 
935   // Result type must be a 32-bit unsigned int.
936   auto result_type = state.FindDef(inst->type_id());
937   if (result_type->opcode() != SpvOpTypeInt ||
938       result_type->GetOperandAs<uint32_t>(1) != 32 ||
939       result_type->GetOperandAs<uint32_t>(2) != 0) {
940     return state.diag(SPV_ERROR_INVALID_ID, inst)
941            << "The Result Type of " << instr_name << " <id> '"
942            << state.getIdName(inst->id())
943            << "' must be OpTypeInt with width 32 and signedness 0.";
944   }
945 
946   // The structure that is passed in must be an pointer to a structure, whose
947   // last element is a runtime array.
948   auto pointer = state.FindDef(inst->GetOperandAs<uint32_t>(2));
949   auto pointer_type = state.FindDef(pointer->type_id());
950   if (pointer_type->opcode() != SpvOpTypePointer) {
951     return state.diag(SPV_ERROR_INVALID_ID, inst)
952            << "The Struture's type in " << instr_name << " <id> '"
953            << state.getIdName(inst->id())
954            << "' must be a pointer to an OpTypeStruct.";
955   }
956 
957   auto structure_type = state.FindDef(pointer_type->GetOperandAs<uint32_t>(2));
958   if (structure_type->opcode() != SpvOpTypeStruct) {
959     return state.diag(SPV_ERROR_INVALID_ID, inst)
960            << "The Struture's type in " << instr_name << " <id> '"
961            << state.getIdName(inst->id())
962            << "' must be a pointer to an OpTypeStruct.";
963   }
964 
965   auto num_of_members = structure_type->operands().size() - 1;
966   auto last_member =
967       state.FindDef(structure_type->GetOperandAs<uint32_t>(num_of_members));
968   if (last_member->opcode() != SpvOpTypeRuntimeArray) {
969     return state.diag(SPV_ERROR_INVALID_ID, inst)
970            << "The Struture's last member in " << instr_name << " <id> '"
971            << state.getIdName(inst->id()) << "' must be an OpTypeRuntimeArray.";
972   }
973 
974   // The array member must the the index of the last element (the run time
975   // array).
976   if (inst->GetOperandAs<uint32_t>(3) != num_of_members - 1) {
977     return state.diag(SPV_ERROR_INVALID_ID, inst)
978            << "The array member in " << instr_name << " <id> '"
979            << state.getIdName(inst->id())
980            << "' must be an the last member of the struct.";
981   }
982   return SPV_SUCCESS;
983 }
984 
985 }  // namespace
986 
MemoryPass(ValidationState_t & _,const Instruction * inst)987 spv_result_t MemoryPass(ValidationState_t& _, const Instruction* inst) {
988   switch (inst->opcode()) {
989     case SpvOpVariable:
990       if (auto error = ValidateVariable(_, inst)) return error;
991       break;
992     case SpvOpLoad:
993       if (auto error = ValidateLoad(_, inst)) return error;
994       break;
995     case SpvOpStore:
996       if (auto error = ValidateStore(_, inst)) return error;
997       break;
998     case SpvOpCopyMemory:
999     case SpvOpCopyMemorySized:
1000       if (auto error = ValidateCopyMemory(_, inst)) return error;
1001       break;
1002     case SpvOpPtrAccessChain:
1003       if (auto error = ValidatePtrAccessChain(_, inst)) return error;
1004       break;
1005     case SpvOpAccessChain:
1006     case SpvOpInBoundsAccessChain:
1007     case SpvOpInBoundsPtrAccessChain:
1008       if (auto error = ValidateAccessChain(_, inst)) return error;
1009       break;
1010     case SpvOpArrayLength:
1011       if (auto error = ValidateArrayLength(_, inst)) return error;
1012       break;
1013     case SpvOpImageTexelPointer:
1014     case SpvOpGenericPtrMemSemantics:
1015     default:
1016       break;
1017   }
1018 
1019   return SPV_SUCCESS;
1020 }
1021 }  // namespace val
1022 }  // namespace spvtools
1023