1 // Copyright (c) 2020 André Perez Maselco
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #include "source/fuzz/transformation_adjust_branch_weights.h"
16 
17 #include "gtest/gtest.h"
18 #include "source/fuzz/fuzzer_util.h"
19 #include "source/fuzz/instruction_descriptor.h"
20 #include "test/fuzz/fuzz_test_util.h"
21 
22 namespace spvtools {
23 namespace fuzz {
24 namespace {
25 
TEST(TransformationAdjustBranchWeightsTest,IsApplicableTest)26 TEST(TransformationAdjustBranchWeightsTest, IsApplicableTest) {
27   std::string shader = R"(
28                OpCapability Shader
29           %1 = OpExtInstImport "GLSL.std.450"
30                OpMemoryModel Logical GLSL450
31                OpEntryPoint Fragment %4 "main" %51 %27
32                OpExecutionMode %4 OriginUpperLeft
33                OpSource ESSL 310
34                OpName %4 "main"
35                OpName %25 "buf"
36                OpMemberName %25 0 "value"
37                OpName %27 ""
38                OpName %51 "color"
39                OpMemberDecorate %25 0 Offset 0
40                OpDecorate %25 Block
41                OpDecorate %27 DescriptorSet 0
42                OpDecorate %27 Binding 0
43                OpDecorate %51 Location 0
44           %2 = OpTypeVoid
45           %3 = OpTypeFunction %2
46           %6 = OpTypeFloat 32
47           %7 = OpTypeVector %6 4
48         %150 = OpTypeVector %6 2
49          %10 = OpConstant %6 0.300000012
50          %11 = OpConstant %6 0.400000006
51          %12 = OpConstant %6 0.5
52          %13 = OpConstant %6 1
53          %14 = OpConstantComposite %7 %10 %11 %12 %13
54          %15 = OpTypeInt 32 1
55          %18 = OpConstant %15 0
56          %25 = OpTypeStruct %6
57          %26 = OpTypePointer Uniform %25
58          %27 = OpVariable %26 Uniform
59          %28 = OpTypePointer Uniform %6
60          %32 = OpTypeBool
61         %103 = OpConstantTrue %32
62          %34 = OpConstant %6 0.100000001
63          %48 = OpConstant %15 1
64          %50 = OpTypePointer Output %7
65          %51 = OpVariable %50 Output
66         %100 = OpTypePointer Function %6
67           %4 = OpFunction %2 None %3
68           %5 = OpLabel
69         %101 = OpVariable %100 Function
70         %102 = OpVariable %100 Function
71                OpBranch %19
72          %19 = OpLabel
73          %60 = OpPhi %7 %14 %5 %58 %20
74          %59 = OpPhi %15 %18 %5 %49 %20
75          %29 = OpAccessChain %28 %27 %18
76          %30 = OpLoad %6 %29
77          %31 = OpConvertFToS %15 %30
78          %33 = OpSLessThan %32 %59 %31
79                OpLoopMerge %21 %20 None
80                OpBranchConditional %33 %20 %21 1 2
81          %20 = OpLabel
82          %39 = OpCompositeExtract %6 %60 0
83          %40 = OpFAdd %6 %39 %34
84          %55 = OpCompositeInsert %7 %40 %60 0
85          %44 = OpCompositeExtract %6 %60 1
86          %45 = OpFSub %6 %44 %34
87          %58 = OpCompositeInsert %7 %45 %55 1
88          %49 = OpIAdd %15 %59 %48
89                OpBranch %19
90          %21 = OpLabel
91                OpStore %51 %60
92                OpSelectionMerge %105 None
93                OpBranchConditional %103 %104 %105
94         %104 = OpLabel
95                OpBranch %105
96         %105 = OpLabel
97                OpReturn
98                OpFunctionEnd
99   )";
100 
101   const auto env = SPV_ENV_UNIVERSAL_1_5;
102   const auto consumer = nullptr;
103   const auto context = BuildModule(env, consumer, shader, kFuzzAssembleOption);
104   spvtools::ValidatorOptions validator_options;
105   ASSERT_TRUE(fuzzerutil::IsValidAndWellFormed(context.get(), validator_options,
106                                                kConsoleMessageConsumer));
107   TransformationContext transformation_context(
108       MakeUnique<FactManager>(context.get()), validator_options);
109   // Tests OpBranchConditional instruction with weigths.
110   auto instruction_descriptor =
111       MakeInstructionDescriptor(33, SpvOpBranchConditional, 0);
112   auto transformation =
113       TransformationAdjustBranchWeights(instruction_descriptor, {0, 1});
114   ASSERT_TRUE(
115       transformation.IsApplicable(context.get(), transformation_context));
116 
117   // Tests the two branch weights equal to 0.
118   instruction_descriptor =
119       MakeInstructionDescriptor(33, SpvOpBranchConditional, 0);
120   transformation =
121       TransformationAdjustBranchWeights(instruction_descriptor, {0, 0});
122 #ifndef NDEBUG
123   ASSERT_DEATH(
124       transformation.IsApplicable(context.get(), transformation_context),
125       "At least one weight must be non-zero");
126 #endif
127 
128   // Tests 32-bit unsigned integer overflow.
129   instruction_descriptor =
130       MakeInstructionDescriptor(33, SpvOpBranchConditional, 0);
131   transformation = TransformationAdjustBranchWeights(instruction_descriptor,
132                                                      {UINT32_MAX, 0});
133   ASSERT_TRUE(
134       transformation.IsApplicable(context.get(), transformation_context));
135 
136   instruction_descriptor =
137       MakeInstructionDescriptor(33, SpvOpBranchConditional, 0);
138   transformation = TransformationAdjustBranchWeights(instruction_descriptor,
139                                                      {1, UINT32_MAX});
140 #ifndef NDEBUG
141   ASSERT_DEATH(
142       transformation.IsApplicable(context.get(), transformation_context),
143       "The sum of the two weights must not be greater than UINT32_MAX");
144 #endif
145 
146   // Tests OpBranchConditional instruction with no weights.
147   instruction_descriptor =
148       MakeInstructionDescriptor(21, SpvOpBranchConditional, 0);
149   transformation =
150       TransformationAdjustBranchWeights(instruction_descriptor, {0, 1});
151   ASSERT_TRUE(
152       transformation.IsApplicable(context.get(), transformation_context));
153 
154   // Tests non-OpBranchConditional instructions.
155   instruction_descriptor = MakeInstructionDescriptor(2, SpvOpTypeVoid, 0);
156   transformation =
157       TransformationAdjustBranchWeights(instruction_descriptor, {5, 6});
158   ASSERT_FALSE(
159       transformation.IsApplicable(context.get(), transformation_context));
160 
161   instruction_descriptor = MakeInstructionDescriptor(20, SpvOpLabel, 0);
162   transformation =
163       TransformationAdjustBranchWeights(instruction_descriptor, {1, 2});
164   ASSERT_FALSE(
165       transformation.IsApplicable(context.get(), transformation_context));
166 
167   instruction_descriptor = MakeInstructionDescriptor(49, SpvOpIAdd, 0);
168   transformation =
169       TransformationAdjustBranchWeights(instruction_descriptor, {1, 2});
170   ASSERT_FALSE(
171       transformation.IsApplicable(context.get(), transformation_context));
172 }
173 
TEST(TransformationAdjustBranchWeightsTest,ApplyTest)174 TEST(TransformationAdjustBranchWeightsTest, ApplyTest) {
175   std::string shader = R"(
176                OpCapability Shader
177           %1 = OpExtInstImport "GLSL.std.450"
178                OpMemoryModel Logical GLSL450
179                OpEntryPoint Fragment %4 "main" %51 %27
180                OpExecutionMode %4 OriginUpperLeft
181                OpSource ESSL 310
182                OpName %4 "main"
183                OpName %25 "buf"
184                OpMemberName %25 0 "value"
185                OpName %27 ""
186                OpName %51 "color"
187                OpMemberDecorate %25 0 Offset 0
188                OpDecorate %25 Block
189                OpDecorate %27 DescriptorSet 0
190                OpDecorate %27 Binding 0
191                OpDecorate %51 Location 0
192           %2 = OpTypeVoid
193           %3 = OpTypeFunction %2
194           %6 = OpTypeFloat 32
195           %7 = OpTypeVector %6 4
196         %150 = OpTypeVector %6 2
197          %10 = OpConstant %6 0.300000012
198          %11 = OpConstant %6 0.400000006
199          %12 = OpConstant %6 0.5
200          %13 = OpConstant %6 1
201          %14 = OpConstantComposite %7 %10 %11 %12 %13
202          %15 = OpTypeInt 32 1
203          %18 = OpConstant %15 0
204          %25 = OpTypeStruct %6
205          %26 = OpTypePointer Uniform %25
206          %27 = OpVariable %26 Uniform
207          %28 = OpTypePointer Uniform %6
208          %32 = OpTypeBool
209         %103 = OpConstantTrue %32
210          %34 = OpConstant %6 0.100000001
211          %48 = OpConstant %15 1
212          %50 = OpTypePointer Output %7
213          %51 = OpVariable %50 Output
214         %100 = OpTypePointer Function %6
215           %4 = OpFunction %2 None %3
216           %5 = OpLabel
217         %101 = OpVariable %100 Function
218         %102 = OpVariable %100 Function
219                OpBranch %19
220          %19 = OpLabel
221          %60 = OpPhi %7 %14 %5 %58 %20
222          %59 = OpPhi %15 %18 %5 %49 %20
223          %29 = OpAccessChain %28 %27 %18
224          %30 = OpLoad %6 %29
225          %31 = OpConvertFToS %15 %30
226          %33 = OpSLessThan %32 %59 %31
227                OpLoopMerge %21 %20 None
228                OpBranchConditional %33 %20 %21 1 2
229          %20 = OpLabel
230          %39 = OpCompositeExtract %6 %60 0
231          %40 = OpFAdd %6 %39 %34
232          %55 = OpCompositeInsert %7 %40 %60 0
233          %44 = OpCompositeExtract %6 %60 1
234          %45 = OpFSub %6 %44 %34
235          %58 = OpCompositeInsert %7 %45 %55 1
236          %49 = OpIAdd %15 %59 %48
237                OpBranch %19
238          %21 = OpLabel
239                OpStore %51 %60
240                OpSelectionMerge %105 None
241                OpBranchConditional %103 %104 %105
242         %104 = OpLabel
243                OpBranch %105
244         %105 = OpLabel
245                OpReturn
246                OpFunctionEnd
247   )";
248 
249   const auto env = SPV_ENV_UNIVERSAL_1_5;
250   const auto consumer = nullptr;
251   const auto context = BuildModule(env, consumer, shader, kFuzzAssembleOption);
252   spvtools::ValidatorOptions validator_options;
253   ASSERT_TRUE(fuzzerutil::IsValidAndWellFormed(context.get(), validator_options,
254                                                kConsoleMessageConsumer));
255   TransformationContext transformation_context(
256       MakeUnique<FactManager>(context.get()), validator_options);
257   auto instruction_descriptor =
258       MakeInstructionDescriptor(33, SpvOpBranchConditional, 0);
259   auto transformation =
260       TransformationAdjustBranchWeights(instruction_descriptor, {5, 6});
261   ApplyAndCheckFreshIds(transformation, context.get(), &transformation_context);
262 
263   instruction_descriptor =
264       MakeInstructionDescriptor(21, SpvOpBranchConditional, 0);
265   transformation =
266       TransformationAdjustBranchWeights(instruction_descriptor, {7, 8});
267   ApplyAndCheckFreshIds(transformation, context.get(), &transformation_context);
268 
269   std::string variant_shader = R"(
270                OpCapability Shader
271           %1 = OpExtInstImport "GLSL.std.450"
272                OpMemoryModel Logical GLSL450
273                OpEntryPoint Fragment %4 "main" %51 %27
274                OpExecutionMode %4 OriginUpperLeft
275                OpSource ESSL 310
276                OpName %4 "main"
277                OpName %25 "buf"
278                OpMemberName %25 0 "value"
279                OpName %27 ""
280                OpName %51 "color"
281                OpMemberDecorate %25 0 Offset 0
282                OpDecorate %25 Block
283                OpDecorate %27 DescriptorSet 0
284                OpDecorate %27 Binding 0
285                OpDecorate %51 Location 0
286           %2 = OpTypeVoid
287           %3 = OpTypeFunction %2
288           %6 = OpTypeFloat 32
289           %7 = OpTypeVector %6 4
290         %150 = OpTypeVector %6 2
291          %10 = OpConstant %6 0.300000012
292          %11 = OpConstant %6 0.400000006
293          %12 = OpConstant %6 0.5
294          %13 = OpConstant %6 1
295          %14 = OpConstantComposite %7 %10 %11 %12 %13
296          %15 = OpTypeInt 32 1
297          %18 = OpConstant %15 0
298          %25 = OpTypeStruct %6
299          %26 = OpTypePointer Uniform %25
300          %27 = OpVariable %26 Uniform
301          %28 = OpTypePointer Uniform %6
302          %32 = OpTypeBool
303         %103 = OpConstantTrue %32
304          %34 = OpConstant %6 0.100000001
305          %48 = OpConstant %15 1
306          %50 = OpTypePointer Output %7
307          %51 = OpVariable %50 Output
308         %100 = OpTypePointer Function %6
309           %4 = OpFunction %2 None %3
310           %5 = OpLabel
311         %101 = OpVariable %100 Function
312         %102 = OpVariable %100 Function
313                OpBranch %19
314          %19 = OpLabel
315          %60 = OpPhi %7 %14 %5 %58 %20
316          %59 = OpPhi %15 %18 %5 %49 %20
317          %29 = OpAccessChain %28 %27 %18
318          %30 = OpLoad %6 %29
319          %31 = OpConvertFToS %15 %30
320          %33 = OpSLessThan %32 %59 %31
321                OpLoopMerge %21 %20 None
322                OpBranchConditional %33 %20 %21 5 6
323          %20 = OpLabel
324          %39 = OpCompositeExtract %6 %60 0
325          %40 = OpFAdd %6 %39 %34
326          %55 = OpCompositeInsert %7 %40 %60 0
327          %44 = OpCompositeExtract %6 %60 1
328          %45 = OpFSub %6 %44 %34
329          %58 = OpCompositeInsert %7 %45 %55 1
330          %49 = OpIAdd %15 %59 %48
331                OpBranch %19
332          %21 = OpLabel
333                OpStore %51 %60
334                OpSelectionMerge %105 None
335                OpBranchConditional %103 %104 %105 7 8
336         %104 = OpLabel
337                OpBranch %105
338         %105 = OpLabel
339                OpReturn
340                OpFunctionEnd
341   )";
342 
343   ASSERT_TRUE(IsEqual(env, variant_shader, context.get()));
344 }
345 
346 }  // namespace
347 }  // namespace fuzz
348 }  // namespace spvtools
349