1 // Copyright (c) 2017 Google Inc.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 // Validates correctness of derivative SPIR-V instructions.
16 
17 #include "source/val/validate.h"
18 
19 #include <string>
20 
21 #include "source/diagnostic.h"
22 #include "source/opcode.h"
23 #include "source/val/instruction.h"
24 #include "source/val/validation_state.h"
25 
26 namespace spvtools {
27 namespace val {
28 
29 // Validates correctness of derivative instructions.
DerivativesPass(ValidationState_t & _,const Instruction * inst)30 spv_result_t DerivativesPass(ValidationState_t& _, const Instruction* inst) {
31   const SpvOp opcode = inst->opcode();
32   const uint32_t result_type = inst->type_id();
33 
34   switch (opcode) {
35     case SpvOpDPdx:
36     case SpvOpDPdy:
37     case SpvOpFwidth:
38     case SpvOpDPdxFine:
39     case SpvOpDPdyFine:
40     case SpvOpFwidthFine:
41     case SpvOpDPdxCoarse:
42     case SpvOpDPdyCoarse:
43     case SpvOpFwidthCoarse: {
44       if (!_.IsFloatScalarOrVectorType(result_type)) {
45         return _.diag(SPV_ERROR_INVALID_DATA, inst)
46                << "Expected Result Type to be float scalar or vector type: "
47                << spvOpcodeString(opcode);
48       }
49 
50       const uint32_t p_type = _.GetOperandTypeId(inst, 2);
51       if (p_type != result_type) {
52         return _.diag(SPV_ERROR_INVALID_DATA, inst)
53                << "Expected P type and Result Type to be the same: "
54                << spvOpcodeString(opcode);
55       }
56 
57       const spvtools::Extension compute_shader_derivatives_extension =
58           kSPV_NV_compute_shader_derivatives;
59       ExtensionSet exts(1, &compute_shader_derivatives_extension);
60 
61       if (_.HasAnyOfExtensions(exts)) {
62         _.function(inst->function()->id())
63             ->RegisterExecutionModelLimitation([opcode](SpvExecutionModel model,
64                                                         std::string* message) {
65               if (model != SpvExecutionModelFragment &&
66                   model != SpvExecutionModelGLCompute) {
67                 if (message) {
68                   *message =
69                       std::string(
70                           "Derivative instructions require Fragment execution "
71                           "model: ") +
72                       spvOpcodeString(opcode);
73                 }
74                 return false;
75               }
76               return true;
77             });
78       } else {
79         _.function(inst->function()->id())
80             ->RegisterExecutionModelLimitation(
81                 SpvExecutionModelFragment,
82                 std::string(
83                     "Derivative instructions require Fragment execution "
84                     "model: ") +
85                     spvOpcodeString(opcode));
86       }
87       break;
88     }
89 
90     default:
91       break;
92   }
93 
94   return SPV_SUCCESS;
95 }
96 
97 }  // namespace val
98 }  // namespace spvtools
99