1 // Copyright (c) 2017 Google Inc.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 // Validates correctness of derivative SPIR-V instructions.
16
17 #include "source/val/validate.h"
18
19 #include <string>
20
21 #include "source/diagnostic.h"
22 #include "source/opcode.h"
23 #include "source/val/instruction.h"
24 #include "source/val/validation_state.h"
25
26 namespace spvtools {
27 namespace val {
28
29 // Validates correctness of derivative instructions.
DerivativesPass(ValidationState_t & _,const Instruction * inst)30 spv_result_t DerivativesPass(ValidationState_t& _, const Instruction* inst) {
31 const SpvOp opcode = inst->opcode();
32 const uint32_t result_type = inst->type_id();
33
34 switch (opcode) {
35 case SpvOpDPdx:
36 case SpvOpDPdy:
37 case SpvOpFwidth:
38 case SpvOpDPdxFine:
39 case SpvOpDPdyFine:
40 case SpvOpFwidthFine:
41 case SpvOpDPdxCoarse:
42 case SpvOpDPdyCoarse:
43 case SpvOpFwidthCoarse: {
44 if (!_.IsFloatScalarOrVectorType(result_type)) {
45 return _.diag(SPV_ERROR_INVALID_DATA, inst)
46 << "Expected Result Type to be float scalar or vector type: "
47 << spvOpcodeString(opcode);
48 }
49
50 const uint32_t p_type = _.GetOperandTypeId(inst, 2);
51 if (p_type != result_type) {
52 return _.diag(SPV_ERROR_INVALID_DATA, inst)
53 << "Expected P type and Result Type to be the same: "
54 << spvOpcodeString(opcode);
55 }
56
57 const spvtools::Extension compute_shader_derivatives_extension =
58 kSPV_NV_compute_shader_derivatives;
59 ExtensionSet exts(1, &compute_shader_derivatives_extension);
60
61 if (_.HasAnyOfExtensions(exts)) {
62 _.function(inst->function()->id())
63 ->RegisterExecutionModelLimitation([opcode](SpvExecutionModel model,
64 std::string* message) {
65 if (model != SpvExecutionModelFragment &&
66 model != SpvExecutionModelGLCompute) {
67 if (message) {
68 *message =
69 std::string(
70 "Derivative instructions require Fragment execution "
71 "model: ") +
72 spvOpcodeString(opcode);
73 }
74 return false;
75 }
76 return true;
77 });
78 } else {
79 _.function(inst->function()->id())
80 ->RegisterExecutionModelLimitation(
81 SpvExecutionModelFragment,
82 std::string(
83 "Derivative instructions require Fragment execution "
84 "model: ") +
85 spvOpcodeString(opcode));
86 }
87 break;
88 }
89
90 default:
91 break;
92 }
93
94 return SPV_SUCCESS;
95 }
96
97 } // namespace val
98 } // namespace spvtools
99