1 // Copyright (c) 2017 Google Inc.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 // Validates correctness of atomic SPIR-V instructions.
16 
17 #include "source/val/validate.h"
18 
19 #include "source/diagnostic.h"
20 #include "source/opcode.h"
21 #include "source/spirv_target_env.h"
22 #include "source/util/bitutils.h"
23 #include "source/val/instruction.h"
24 #include "source/val/validate_memory_semantics.h"
25 #include "source/val/validate_scopes.h"
26 #include "source/val/validation_state.h"
27 
28 namespace spvtools {
29 namespace val {
30 
31 // Validates correctness of atomic instructions.
AtomicsPass(ValidationState_t & _,const Instruction * inst)32 spv_result_t AtomicsPass(ValidationState_t& _, const Instruction* inst) {
33   const SpvOp opcode = inst->opcode();
34   const uint32_t result_type = inst->type_id();
35 
36   switch (opcode) {
37     case SpvOpAtomicLoad:
38     case SpvOpAtomicStore:
39     case SpvOpAtomicExchange:
40     case SpvOpAtomicCompareExchange:
41     case SpvOpAtomicCompareExchangeWeak:
42     case SpvOpAtomicIIncrement:
43     case SpvOpAtomicIDecrement:
44     case SpvOpAtomicIAdd:
45     case SpvOpAtomicISub:
46     case SpvOpAtomicSMin:
47     case SpvOpAtomicUMin:
48     case SpvOpAtomicSMax:
49     case SpvOpAtomicUMax:
50     case SpvOpAtomicAnd:
51     case SpvOpAtomicOr:
52     case SpvOpAtomicXor:
53     case SpvOpAtomicFlagTestAndSet:
54     case SpvOpAtomicFlagClear: {
55       if (_.HasCapability(SpvCapabilityKernel) &&
56           (opcode == SpvOpAtomicLoad || opcode == SpvOpAtomicExchange ||
57            opcode == SpvOpAtomicCompareExchange)) {
58         if (!_.IsFloatScalarType(result_type) &&
59             !_.IsIntScalarType(result_type)) {
60           return _.diag(SPV_ERROR_INVALID_DATA, inst)
61                  << spvOpcodeString(opcode)
62                  << ": expected Result Type to be int or float scalar type";
63         }
64       } else if (opcode == SpvOpAtomicFlagTestAndSet) {
65         if (!_.IsBoolScalarType(result_type)) {
66           return _.diag(SPV_ERROR_INVALID_DATA, inst)
67                  << spvOpcodeString(opcode)
68                  << ": expected Result Type to be bool scalar type";
69         }
70       } else if (opcode == SpvOpAtomicFlagClear || opcode == SpvOpAtomicStore) {
71         assert(result_type == 0);
72       } else {
73         if (!_.IsIntScalarType(result_type)) {
74           return _.diag(SPV_ERROR_INVALID_DATA, inst)
75                  << spvOpcodeString(opcode)
76                  << ": expected Result Type to be int scalar type";
77         }
78         if (spvIsVulkanEnv(_.context()->target_env) &&
79             _.GetBitWidth(result_type) != 32) {
80           switch (opcode) {
81             case SpvOpAtomicSMin:
82             case SpvOpAtomicUMin:
83             case SpvOpAtomicSMax:
84             case SpvOpAtomicUMax:
85             case SpvOpAtomicAnd:
86             case SpvOpAtomicOr:
87             case SpvOpAtomicXor:
88             case SpvOpAtomicIAdd:
89             case SpvOpAtomicLoad:
90             case SpvOpAtomicStore:
91             case SpvOpAtomicExchange:
92             case SpvOpAtomicCompareExchange: {
93               if (_.GetBitWidth(result_type) == 64 &&
94                   !_.HasCapability(SpvCapabilityInt64Atomics))
95                 return _.diag(SPV_ERROR_INVALID_DATA, inst)
96                        << spvOpcodeString(opcode)
97                        << ": 64-bit atomics require the Int64Atomics "
98                           "capability";
99             } break;
100             default:
101               return _.diag(SPV_ERROR_INVALID_DATA, inst)
102                      << spvOpcodeString(opcode)
103                      << ": according to the Vulkan spec atomic Result Type "
104                         "needs "
105                         "to be a 32-bit int scalar type";
106           }
107         }
108       }
109 
110       uint32_t operand_index =
111           opcode == SpvOpAtomicFlagClear || opcode == SpvOpAtomicStore ? 0 : 2;
112       const uint32_t pointer_type = _.GetOperandTypeId(inst, operand_index++);
113 
114       uint32_t data_type = 0;
115       uint32_t storage_class = 0;
116       if (!_.GetPointerTypeInfo(pointer_type, &data_type, &storage_class)) {
117         return _.diag(SPV_ERROR_INVALID_DATA, inst)
118                << spvOpcodeString(opcode)
119                << ": expected Pointer to be of type OpTypePointer";
120       }
121 
122       switch (storage_class) {
123         case SpvStorageClassUniform:
124         case SpvStorageClassWorkgroup:
125         case SpvStorageClassCrossWorkgroup:
126         case SpvStorageClassGeneric:
127         case SpvStorageClassAtomicCounter:
128         case SpvStorageClassImage:
129         case SpvStorageClassStorageBuffer:
130         case SpvStorageClassPhysicalStorageBufferEXT:
131           break;
132         default:
133           if (spvIsOpenCLEnv(_.context()->target_env)) {
134             if (storage_class != SpvStorageClassFunction) {
135               return _.diag(SPV_ERROR_INVALID_DATA, inst)
136                      << spvOpcodeString(opcode)
137                      << ": expected Pointer Storage Class to be Uniform, "
138                         "Workgroup, CrossWorkgroup, Generic, AtomicCounter, "
139                         "Image, StorageBuffer or Function";
140             }
141           } else {
142             return _.diag(SPV_ERROR_INVALID_DATA, inst)
143                    << spvOpcodeString(opcode)
144                    << ": expected Pointer Storage Class to be Uniform, "
145                       "Workgroup, CrossWorkgroup, Generic, AtomicCounter, "
146                       "Image or StorageBuffer";
147           }
148       }
149 
150       if (opcode == SpvOpAtomicFlagTestAndSet ||
151           opcode == SpvOpAtomicFlagClear) {
152         if (!_.IsIntScalarType(data_type) || _.GetBitWidth(data_type) != 32) {
153           return _.diag(SPV_ERROR_INVALID_DATA, inst)
154                  << spvOpcodeString(opcode)
155                  << ": expected Pointer to point to a value of 32-bit int type";
156         }
157       } else if (opcode == SpvOpAtomicStore) {
158         if (!_.IsFloatScalarType(data_type) && !_.IsIntScalarType(data_type)) {
159           return _.diag(SPV_ERROR_INVALID_DATA, inst)
160                  << spvOpcodeString(opcode)
161                  << ": expected Pointer to be a pointer to int or float "
162                  << "scalar type";
163         }
164       } else {
165         if (data_type != result_type) {
166           return _.diag(SPV_ERROR_INVALID_DATA, inst)
167                  << spvOpcodeString(opcode)
168                  << ": expected Pointer to point to a value of type Result "
169                     "Type";
170         }
171       }
172 
173       auto memory_scope = inst->GetOperandAs<const uint32_t>(operand_index++);
174       if (auto error = ValidateMemoryScope(_, inst, memory_scope)) {
175         return error;
176       }
177 
178       if (auto error = ValidateMemorySemantics(_, inst, operand_index++))
179         return error;
180 
181       if (opcode == SpvOpAtomicCompareExchange ||
182           opcode == SpvOpAtomicCompareExchangeWeak) {
183         if (auto error = ValidateMemorySemantics(_, inst, operand_index++))
184           return error;
185       }
186 
187       if (opcode == SpvOpAtomicStore) {
188         const uint32_t value_type = _.GetOperandTypeId(inst, 3);
189         if (value_type != data_type) {
190           return _.diag(SPV_ERROR_INVALID_DATA, inst)
191                  << spvOpcodeString(opcode)
192                  << ": expected Value type and the type pointed to by "
193                     "Pointer to be the same";
194         }
195       } else if (opcode != SpvOpAtomicLoad && opcode != SpvOpAtomicIIncrement &&
196                  opcode != SpvOpAtomicIDecrement &&
197                  opcode != SpvOpAtomicFlagTestAndSet &&
198                  opcode != SpvOpAtomicFlagClear) {
199         const uint32_t value_type = _.GetOperandTypeId(inst, operand_index++);
200         if (value_type != result_type) {
201           return _.diag(SPV_ERROR_INVALID_DATA, inst)
202                  << spvOpcodeString(opcode)
203                  << ": expected Value to be of type Result Type";
204         }
205       }
206 
207       if (opcode == SpvOpAtomicCompareExchange ||
208           opcode == SpvOpAtomicCompareExchangeWeak) {
209         const uint32_t comparator_type =
210             _.GetOperandTypeId(inst, operand_index++);
211         if (comparator_type != result_type) {
212           return _.diag(SPV_ERROR_INVALID_DATA, inst)
213                  << spvOpcodeString(opcode)
214                  << ": expected Comparator to be of type Result Type";
215         }
216       }
217 
218       break;
219     }
220 
221     default:
222       break;
223   }
224 
225   return SPV_SUCCESS;
226 }
227 
228 }  // namespace val
229 }  // namespace spvtools
230