1 // Copyright (c) 2018 Google LLC.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 #include "source/val/validate.h"
16
17 #include <algorithm>
18
19 #include "source/opcode.h"
20 #include "source/val/instruction.h"
21 #include "source/val/validation_state.h"
22
23 namespace spvtools {
24 namespace val {
25 namespace {
26
ValidateFunction(ValidationState_t & _,const Instruction * inst)27 spv_result_t ValidateFunction(ValidationState_t& _, const Instruction* inst) {
28 const auto function_type_id = inst->GetOperandAs<uint32_t>(3);
29 const auto function_type = _.FindDef(function_type_id);
30 if (!function_type || SpvOpTypeFunction != function_type->opcode()) {
31 return _.diag(SPV_ERROR_INVALID_ID, inst)
32 << "OpFunction Function Type <id> '" << _.getIdName(function_type_id)
33 << "' is not a function type.";
34 }
35
36 const auto return_id = function_type->GetOperandAs<uint32_t>(1);
37 if (return_id != inst->type_id()) {
38 return _.diag(SPV_ERROR_INVALID_ID, inst)
39 << "OpFunction Result Type <id> '" << _.getIdName(inst->type_id())
40 << "' does not match the Function Type's return type <id> '"
41 << _.getIdName(return_id) << "'.";
42 }
43
44 for (auto& pair : inst->uses()) {
45 const auto* use = pair.first;
46 const std::vector<SpvOp> acceptable = {
47 SpvOpFunctionCall,
48 SpvOpEntryPoint,
49 SpvOpEnqueueKernel,
50 SpvOpGetKernelNDrangeSubGroupCount,
51 SpvOpGetKernelNDrangeMaxSubGroupSize,
52 SpvOpGetKernelWorkGroupSize,
53 SpvOpGetKernelPreferredWorkGroupSizeMultiple,
54 SpvOpGetKernelLocalSizeForSubgroupCount,
55 SpvOpGetKernelMaxNumSubgroups};
56 if (std::find(acceptable.begin(), acceptable.end(), use->opcode()) ==
57 acceptable.end()) {
58 return _.diag(SPV_ERROR_INVALID_ID, use)
59 << "Invalid use of function result id " << _.getIdName(inst->id())
60 << ".";
61 }
62 }
63
64 return SPV_SUCCESS;
65 }
66
ValidateFunctionParameter(ValidationState_t & _,const Instruction * inst)67 spv_result_t ValidateFunctionParameter(ValidationState_t& _,
68 const Instruction* inst) {
69 // NOTE: Find OpFunction & ensure OpFunctionParameter is not out of place.
70 size_t param_index = 0;
71 size_t inst_num = inst->LineNum() - 1;
72 if (inst_num == 0) {
73 return _.diag(SPV_ERROR_INVALID_LAYOUT, inst)
74 << "Function parameter cannot be the first instruction.";
75 }
76
77 auto func_inst = &_.ordered_instructions()[inst_num];
78 while (--inst_num) {
79 func_inst = &_.ordered_instructions()[inst_num];
80 if (func_inst->opcode() == SpvOpFunction) {
81 break;
82 } else if (func_inst->opcode() == SpvOpFunctionParameter) {
83 ++param_index;
84 }
85 }
86
87 if (func_inst->opcode() != SpvOpFunction) {
88 return _.diag(SPV_ERROR_INVALID_LAYOUT, inst)
89 << "Function parameter must be preceded by a function.";
90 }
91
92 const auto function_type_id = func_inst->GetOperandAs<uint32_t>(3);
93 const auto function_type = _.FindDef(function_type_id);
94 if (!function_type) {
95 return _.diag(SPV_ERROR_INVALID_ID, func_inst)
96 << "Missing function type definition.";
97 }
98 if (param_index >= function_type->words().size() - 3) {
99 return _.diag(SPV_ERROR_INVALID_ID, inst)
100 << "Too many OpFunctionParameters for " << func_inst->id()
101 << ": expected " << function_type->words().size() - 3
102 << " based on the function's type";
103 }
104
105 const auto param_type =
106 _.FindDef(function_type->GetOperandAs<uint32_t>(param_index + 2));
107 if (!param_type || inst->type_id() != param_type->id()) {
108 return _.diag(SPV_ERROR_INVALID_ID, inst)
109 << "OpFunctionParameter Result Type <id> '"
110 << _.getIdName(inst->type_id())
111 << "' does not match the OpTypeFunction parameter "
112 "type of the same index.";
113 }
114
115 // Validate that PhysicalStorageBufferEXT have one of Restrict, Aliased,
116 // RestrictPointerEXT, or AliasedPointerEXT.
117 auto param_nonarray_type_id = param_type->id();
118 while (_.GetIdOpcode(param_nonarray_type_id) == SpvOpTypeArray) {
119 param_nonarray_type_id =
120 _.FindDef(param_nonarray_type_id)->GetOperandAs<uint32_t>(1u);
121 }
122 if (_.GetIdOpcode(param_nonarray_type_id) == SpvOpTypePointer) {
123 auto param_nonarray_type = _.FindDef(param_nonarray_type_id);
124 if (param_nonarray_type->GetOperandAs<uint32_t>(1u) ==
125 SpvStorageClassPhysicalStorageBufferEXT) {
126 // check for Aliased or Restrict
127 const auto& decorations = _.id_decorations(inst->id());
128
129 bool foundAliased = std::any_of(
130 decorations.begin(), decorations.end(), [](const Decoration& d) {
131 return SpvDecorationAliased == d.dec_type();
132 });
133
134 bool foundRestrict = std::any_of(
135 decorations.begin(), decorations.end(), [](const Decoration& d) {
136 return SpvDecorationRestrict == d.dec_type();
137 });
138
139 if (!foundAliased && !foundRestrict) {
140 return _.diag(SPV_ERROR_INVALID_ID, inst)
141 << "OpFunctionParameter " << inst->id()
142 << ": expected Aliased or Restrict for PhysicalStorageBufferEXT "
143 "pointer.";
144 }
145 if (foundAliased && foundRestrict) {
146 return _.diag(SPV_ERROR_INVALID_ID, inst)
147 << "OpFunctionParameter " << inst->id()
148 << ": can't specify both Aliased and Restrict for "
149 "PhysicalStorageBufferEXT pointer.";
150 }
151 } else {
152 const auto pointee_type_id =
153 param_nonarray_type->GetOperandAs<uint32_t>(2);
154 const auto pointee_type = _.FindDef(pointee_type_id);
155 if (SpvOpTypePointer == pointee_type->opcode() &&
156 pointee_type->GetOperandAs<uint32_t>(1u) ==
157 SpvStorageClassPhysicalStorageBufferEXT) {
158 // check for AliasedPointerEXT/RestrictPointerEXT
159 const auto& decorations = _.id_decorations(inst->id());
160
161 bool foundAliased = std::any_of(
162 decorations.begin(), decorations.end(), [](const Decoration& d) {
163 return SpvDecorationAliasedPointerEXT == d.dec_type();
164 });
165
166 bool foundRestrict = std::any_of(
167 decorations.begin(), decorations.end(), [](const Decoration& d) {
168 return SpvDecorationRestrictPointerEXT == d.dec_type();
169 });
170
171 if (!foundAliased && !foundRestrict) {
172 return _.diag(SPV_ERROR_INVALID_ID, inst)
173 << "OpFunctionParameter " << inst->id()
174 << ": expected AliasedPointerEXT or RestrictPointerEXT for "
175 "PhysicalStorageBufferEXT pointer.";
176 }
177 if (foundAliased && foundRestrict) {
178 return _.diag(SPV_ERROR_INVALID_ID, inst)
179 << "OpFunctionParameter " << inst->id()
180 << ": can't specify both AliasedPointerEXT and "
181 "RestrictPointerEXT for PhysicalStorageBufferEXT pointer.";
182 }
183 }
184 }
185 }
186
187 return SPV_SUCCESS;
188 }
189
ValidateFunctionCall(ValidationState_t & _,const Instruction * inst)190 spv_result_t ValidateFunctionCall(ValidationState_t& _,
191 const Instruction* inst) {
192 const auto function_id = inst->GetOperandAs<uint32_t>(2);
193 const auto function = _.FindDef(function_id);
194 if (!function || SpvOpFunction != function->opcode()) {
195 return _.diag(SPV_ERROR_INVALID_ID, inst)
196 << "OpFunctionCall Function <id> '" << _.getIdName(function_id)
197 << "' is not a function.";
198 }
199
200 auto return_type = _.FindDef(function->type_id());
201 if (!return_type || return_type->id() != inst->type_id()) {
202 return _.diag(SPV_ERROR_INVALID_ID, inst)
203 << "OpFunctionCall Result Type <id> '"
204 << _.getIdName(inst->type_id())
205 << "'s type does not match Function <id> '"
206 << _.getIdName(return_type->id()) << "'s return type.";
207 }
208
209 const auto function_type_id = function->GetOperandAs<uint32_t>(3);
210 const auto function_type = _.FindDef(function_type_id);
211 if (!function_type || function_type->opcode() != SpvOpTypeFunction) {
212 return _.diag(SPV_ERROR_INVALID_ID, inst)
213 << "Missing function type definition.";
214 }
215
216 const auto function_call_arg_count = inst->words().size() - 4;
217 const auto function_param_count = function_type->words().size() - 3;
218 if (function_param_count != function_call_arg_count) {
219 return _.diag(SPV_ERROR_INVALID_ID, inst)
220 << "OpFunctionCall Function <id>'s parameter count does not match "
221 "the argument count.";
222 }
223
224 for (size_t argument_index = 3, param_index = 2;
225 argument_index < inst->operands().size();
226 argument_index++, param_index++) {
227 const auto argument_id = inst->GetOperandAs<uint32_t>(argument_index);
228 const auto argument = _.FindDef(argument_id);
229 if (!argument) {
230 return _.diag(SPV_ERROR_INVALID_ID, inst)
231 << "Missing argument " << argument_index - 3 << " definition.";
232 }
233
234 const auto argument_type = _.FindDef(argument->type_id());
235 if (!argument_type) {
236 return _.diag(SPV_ERROR_INVALID_ID, inst)
237 << "Missing argument " << argument_index - 3
238 << " type definition.";
239 }
240
241 const auto parameter_type_id =
242 function_type->GetOperandAs<uint32_t>(param_index);
243 const auto parameter_type = _.FindDef(parameter_type_id);
244 if (!parameter_type || argument_type->id() != parameter_type->id()) {
245 return _.diag(SPV_ERROR_INVALID_ID, inst)
246 << "OpFunctionCall Argument <id> '" << _.getIdName(argument_id)
247 << "'s type does not match Function <id> '"
248 << _.getIdName(parameter_type_id) << "'s parameter type.";
249 }
250 }
251 return SPV_SUCCESS;
252 }
253
254 } // namespace
255
FunctionPass(ValidationState_t & _,const Instruction * inst)256 spv_result_t FunctionPass(ValidationState_t& _, const Instruction* inst) {
257 switch (inst->opcode()) {
258 case SpvOpFunction:
259 if (auto error = ValidateFunction(_, inst)) return error;
260 break;
261 case SpvOpFunctionParameter:
262 if (auto error = ValidateFunctionParameter(_, inst)) return error;
263 break;
264 case SpvOpFunctionCall:
265 if (auto error = ValidateFunctionCall(_, inst)) return error;
266 break;
267 default:
268 break;
269 }
270
271 return SPV_SUCCESS;
272 }
273
274 } // namespace val
275 } // namespace spvtools
276