1 // Copyright (c) 2018 Google LLC.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #include "source/val/validate.h"
16 
17 #include <algorithm>
18 
19 #include "source/opcode.h"
20 #include "source/val/instruction.h"
21 #include "source/val/validation_state.h"
22 
23 namespace spvtools {
24 namespace val {
25 namespace {
26 
ValidateFunction(ValidationState_t & _,const Instruction * inst)27 spv_result_t ValidateFunction(ValidationState_t& _, const Instruction* inst) {
28   const auto function_type_id = inst->GetOperandAs<uint32_t>(3);
29   const auto function_type = _.FindDef(function_type_id);
30   if (!function_type || SpvOpTypeFunction != function_type->opcode()) {
31     return _.diag(SPV_ERROR_INVALID_ID, inst)
32            << "OpFunction Function Type <id> '" << _.getIdName(function_type_id)
33            << "' is not a function type.";
34   }
35 
36   const auto return_id = function_type->GetOperandAs<uint32_t>(1);
37   if (return_id != inst->type_id()) {
38     return _.diag(SPV_ERROR_INVALID_ID, inst)
39            << "OpFunction Result Type <id> '" << _.getIdName(inst->type_id())
40            << "' does not match the Function Type's return type <id> '"
41            << _.getIdName(return_id) << "'.";
42   }
43 
44   for (auto& pair : inst->uses()) {
45     const auto* use = pair.first;
46     const std::vector<SpvOp> acceptable = {
47         SpvOpFunctionCall,
48         SpvOpEntryPoint,
49         SpvOpEnqueueKernel,
50         SpvOpGetKernelNDrangeSubGroupCount,
51         SpvOpGetKernelNDrangeMaxSubGroupSize,
52         SpvOpGetKernelWorkGroupSize,
53         SpvOpGetKernelPreferredWorkGroupSizeMultiple,
54         SpvOpGetKernelLocalSizeForSubgroupCount,
55         SpvOpGetKernelMaxNumSubgroups};
56     if (std::find(acceptable.begin(), acceptable.end(), use->opcode()) ==
57         acceptable.end()) {
58       return _.diag(SPV_ERROR_INVALID_ID, use)
59              << "Invalid use of function result id " << _.getIdName(inst->id())
60              << ".";
61     }
62   }
63 
64   return SPV_SUCCESS;
65 }
66 
ValidateFunctionParameter(ValidationState_t & _,const Instruction * inst)67 spv_result_t ValidateFunctionParameter(ValidationState_t& _,
68                                        const Instruction* inst) {
69   // NOTE: Find OpFunction & ensure OpFunctionParameter is not out of place.
70   size_t param_index = 0;
71   size_t inst_num = inst->LineNum() - 1;
72   if (inst_num == 0) {
73     return _.diag(SPV_ERROR_INVALID_LAYOUT, inst)
74            << "Function parameter cannot be the first instruction.";
75   }
76 
77   auto func_inst = &_.ordered_instructions()[inst_num];
78   while (--inst_num) {
79     func_inst = &_.ordered_instructions()[inst_num];
80     if (func_inst->opcode() == SpvOpFunction) {
81       break;
82     } else if (func_inst->opcode() == SpvOpFunctionParameter) {
83       ++param_index;
84     }
85   }
86 
87   if (func_inst->opcode() != SpvOpFunction) {
88     return _.diag(SPV_ERROR_INVALID_LAYOUT, inst)
89            << "Function parameter must be preceded by a function.";
90   }
91 
92   const auto function_type_id = func_inst->GetOperandAs<uint32_t>(3);
93   const auto function_type = _.FindDef(function_type_id);
94   if (!function_type) {
95     return _.diag(SPV_ERROR_INVALID_ID, func_inst)
96            << "Missing function type definition.";
97   }
98   if (param_index >= function_type->words().size() - 3) {
99     return _.diag(SPV_ERROR_INVALID_ID, inst)
100            << "Too many OpFunctionParameters for " << func_inst->id()
101            << ": expected " << function_type->words().size() - 3
102            << " based on the function's type";
103   }
104 
105   const auto param_type =
106       _.FindDef(function_type->GetOperandAs<uint32_t>(param_index + 2));
107   if (!param_type || inst->type_id() != param_type->id()) {
108     return _.diag(SPV_ERROR_INVALID_ID, inst)
109            << "OpFunctionParameter Result Type <id> '"
110            << _.getIdName(inst->type_id())
111            << "' does not match the OpTypeFunction parameter "
112               "type of the same index.";
113   }
114 
115   // Validate that PhysicalStorageBufferEXT have one of Restrict, Aliased,
116   // RestrictPointerEXT, or AliasedPointerEXT.
117   auto param_nonarray_type_id = param_type->id();
118   while (_.GetIdOpcode(param_nonarray_type_id) == SpvOpTypeArray) {
119     param_nonarray_type_id =
120         _.FindDef(param_nonarray_type_id)->GetOperandAs<uint32_t>(1u);
121   }
122   if (_.GetIdOpcode(param_nonarray_type_id) == SpvOpTypePointer) {
123     auto param_nonarray_type = _.FindDef(param_nonarray_type_id);
124     if (param_nonarray_type->GetOperandAs<uint32_t>(1u) ==
125         SpvStorageClassPhysicalStorageBufferEXT) {
126       // check for Aliased or Restrict
127       const auto& decorations = _.id_decorations(inst->id());
128 
129       bool foundAliased = std::any_of(
130           decorations.begin(), decorations.end(), [](const Decoration& d) {
131             return SpvDecorationAliased == d.dec_type();
132           });
133 
134       bool foundRestrict = std::any_of(
135           decorations.begin(), decorations.end(), [](const Decoration& d) {
136             return SpvDecorationRestrict == d.dec_type();
137           });
138 
139       if (!foundAliased && !foundRestrict) {
140         return _.diag(SPV_ERROR_INVALID_ID, inst)
141                << "OpFunctionParameter " << inst->id()
142                << ": expected Aliased or Restrict for PhysicalStorageBufferEXT "
143                   "pointer.";
144       }
145       if (foundAliased && foundRestrict) {
146         return _.diag(SPV_ERROR_INVALID_ID, inst)
147                << "OpFunctionParameter " << inst->id()
148                << ": can't specify both Aliased and Restrict for "
149                   "PhysicalStorageBufferEXT pointer.";
150       }
151     } else {
152       const auto pointee_type_id =
153           param_nonarray_type->GetOperandAs<uint32_t>(2);
154       const auto pointee_type = _.FindDef(pointee_type_id);
155       if (SpvOpTypePointer == pointee_type->opcode() &&
156           pointee_type->GetOperandAs<uint32_t>(1u) ==
157               SpvStorageClassPhysicalStorageBufferEXT) {
158         // check for AliasedPointerEXT/RestrictPointerEXT
159         const auto& decorations = _.id_decorations(inst->id());
160 
161         bool foundAliased = std::any_of(
162             decorations.begin(), decorations.end(), [](const Decoration& d) {
163               return SpvDecorationAliasedPointerEXT == d.dec_type();
164             });
165 
166         bool foundRestrict = std::any_of(
167             decorations.begin(), decorations.end(), [](const Decoration& d) {
168               return SpvDecorationRestrictPointerEXT == d.dec_type();
169             });
170 
171         if (!foundAliased && !foundRestrict) {
172           return _.diag(SPV_ERROR_INVALID_ID, inst)
173                  << "OpFunctionParameter " << inst->id()
174                  << ": expected AliasedPointerEXT or RestrictPointerEXT for "
175                     "PhysicalStorageBufferEXT pointer.";
176         }
177         if (foundAliased && foundRestrict) {
178           return _.diag(SPV_ERROR_INVALID_ID, inst)
179                  << "OpFunctionParameter " << inst->id()
180                  << ": can't specify both AliasedPointerEXT and "
181                     "RestrictPointerEXT for PhysicalStorageBufferEXT pointer.";
182         }
183       }
184     }
185   }
186 
187   return SPV_SUCCESS;
188 }
189 
ValidateFunctionCall(ValidationState_t & _,const Instruction * inst)190 spv_result_t ValidateFunctionCall(ValidationState_t& _,
191                                   const Instruction* inst) {
192   const auto function_id = inst->GetOperandAs<uint32_t>(2);
193   const auto function = _.FindDef(function_id);
194   if (!function || SpvOpFunction != function->opcode()) {
195     return _.diag(SPV_ERROR_INVALID_ID, inst)
196            << "OpFunctionCall Function <id> '" << _.getIdName(function_id)
197            << "' is not a function.";
198   }
199 
200   auto return_type = _.FindDef(function->type_id());
201   if (!return_type || return_type->id() != inst->type_id()) {
202     return _.diag(SPV_ERROR_INVALID_ID, inst)
203            << "OpFunctionCall Result Type <id> '"
204            << _.getIdName(inst->type_id())
205            << "'s type does not match Function <id> '"
206            << _.getIdName(return_type->id()) << "'s return type.";
207   }
208 
209   const auto function_type_id = function->GetOperandAs<uint32_t>(3);
210   const auto function_type = _.FindDef(function_type_id);
211   if (!function_type || function_type->opcode() != SpvOpTypeFunction) {
212     return _.diag(SPV_ERROR_INVALID_ID, inst)
213            << "Missing function type definition.";
214   }
215 
216   const auto function_call_arg_count = inst->words().size() - 4;
217   const auto function_param_count = function_type->words().size() - 3;
218   if (function_param_count != function_call_arg_count) {
219     return _.diag(SPV_ERROR_INVALID_ID, inst)
220            << "OpFunctionCall Function <id>'s parameter count does not match "
221               "the argument count.";
222   }
223 
224   for (size_t argument_index = 3, param_index = 2;
225        argument_index < inst->operands().size();
226        argument_index++, param_index++) {
227     const auto argument_id = inst->GetOperandAs<uint32_t>(argument_index);
228     const auto argument = _.FindDef(argument_id);
229     if (!argument) {
230       return _.diag(SPV_ERROR_INVALID_ID, inst)
231              << "Missing argument " << argument_index - 3 << " definition.";
232     }
233 
234     const auto argument_type = _.FindDef(argument->type_id());
235     if (!argument_type) {
236       return _.diag(SPV_ERROR_INVALID_ID, inst)
237              << "Missing argument " << argument_index - 3
238              << " type definition.";
239     }
240 
241     const auto parameter_type_id =
242         function_type->GetOperandAs<uint32_t>(param_index);
243     const auto parameter_type = _.FindDef(parameter_type_id);
244     if (!parameter_type || argument_type->id() != parameter_type->id()) {
245       return _.diag(SPV_ERROR_INVALID_ID, inst)
246              << "OpFunctionCall Argument <id> '" << _.getIdName(argument_id)
247              << "'s type does not match Function <id> '"
248              << _.getIdName(parameter_type_id) << "'s parameter type.";
249     }
250   }
251   return SPV_SUCCESS;
252 }
253 
254 }  // namespace
255 
FunctionPass(ValidationState_t & _,const Instruction * inst)256 spv_result_t FunctionPass(ValidationState_t& _, const Instruction* inst) {
257   switch (inst->opcode()) {
258     case SpvOpFunction:
259       if (auto error = ValidateFunction(_, inst)) return error;
260       break;
261     case SpvOpFunctionParameter:
262       if (auto error = ValidateFunctionParameter(_, inst)) return error;
263       break;
264     case SpvOpFunctionCall:
265       if (auto error = ValidateFunctionCall(_, inst)) return error;
266       break;
267     default:
268       break;
269   }
270 
271   return SPV_SUCCESS;
272 }
273 
274 }  // namespace val
275 }  // namespace spvtools
276