1 // Copyright (c) 2020 André Perez Maselco
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #include "source/fuzz/transformation_swap_commutable_operands.h"
16 
17 #include "gtest/gtest.h"
18 #include "source/fuzz/fuzzer_util.h"
19 #include "source/fuzz/instruction_descriptor.h"
20 #include "test/fuzz/fuzz_test_util.h"
21 
22 namespace spvtools {
23 namespace fuzz {
24 namespace {
25 
TEST(TransformationSwapCommutableOperandsTest,IsApplicableTest)26 TEST(TransformationSwapCommutableOperandsTest, IsApplicableTest) {
27   std::string shader = R"(
28                OpCapability Shader
29           %1 = OpExtInstImport "GLSL.std.450"
30                OpMemoryModel Logical GLSL450
31                OpEntryPoint Fragment %4 "main"
32                OpExecutionMode %4 OriginUpperLeft
33                OpSource ESSL 310
34                OpName %4 "main"
35           %2 = OpTypeVoid
36           %3 = OpTypeFunction %2
37           %6 = OpTypeInt 32 1
38           %7 = OpTypeInt 32 0
39           %8 = OpConstant %7 2
40           %9 = OpTypeArray %6 %8
41          %10 = OpTypePointer Function %9
42          %12 = OpConstant %6 1
43          %13 = OpConstant %6 2
44          %14 = OpConstantComposite %9 %12 %13
45          %15 = OpTypePointer Function %6
46          %17 = OpConstant %6 0
47          %29 = OpTypeFloat 32
48          %30 = OpTypeArray %29 %8
49          %31 = OpTypePointer Function %30
50          %33 = OpConstant %29 1
51          %34 = OpConstant %29 2
52          %35 = OpConstantComposite %30 %33 %34
53          %36 = OpTypePointer Function %29
54          %49 = OpTypeVector %29 3
55          %50 = OpTypeArray %49 %8
56          %51 = OpTypePointer Function %50
57          %53 = OpConstant %29 3
58          %54 = OpConstantComposite %49 %33 %34 %53
59          %55 = OpConstant %29 4
60          %56 = OpConstant %29 5
61          %57 = OpConstant %29 6
62          %58 = OpConstantComposite %49 %55 %56 %57
63          %59 = OpConstantComposite %50 %54 %58
64          %61 = OpTypePointer Function %49
65           %4 = OpFunction %2 None %3
66           %5 = OpLabel
67          %11 = OpVariable %10 Function
68          %16 = OpVariable %15 Function
69          %23 = OpVariable %15 Function
70          %32 = OpVariable %31 Function
71          %37 = OpVariable %36 Function
72          %43 = OpVariable %36 Function
73          %52 = OpVariable %51 Function
74          %60 = OpVariable %36 Function
75                OpStore %11 %14
76          %18 = OpAccessChain %15 %11 %17
77          %19 = OpLoad %6 %18
78          %20 = OpAccessChain %15 %11 %12
79          %21 = OpLoad %6 %20
80          %22 = OpIAdd %6 %19 %21
81                OpStore %16 %22
82          %24 = OpAccessChain %15 %11 %17
83          %25 = OpLoad %6 %24
84          %26 = OpAccessChain %15 %11 %12
85          %27 = OpLoad %6 %26
86          %28 = OpIMul %6 %25 %27
87                OpStore %23 %28
88                OpStore %32 %35
89          %38 = OpAccessChain %36 %32 %17
90          %39 = OpLoad %29 %38
91          %40 = OpAccessChain %36 %32 %12
92          %41 = OpLoad %29 %40
93          %42 = OpFAdd %29 %39 %41
94                OpStore %37 %42
95          %44 = OpAccessChain %36 %32 %17
96          %45 = OpLoad %29 %44
97          %46 = OpAccessChain %36 %32 %12
98          %47 = OpLoad %29 %46
99          %48 = OpFMul %29 %45 %47
100                OpStore %43 %48
101                OpStore %52 %59
102          %62 = OpAccessChain %61 %52 %17
103          %63 = OpLoad %49 %62
104          %64 = OpAccessChain %61 %52 %12
105          %65 = OpLoad %49 %64
106          %66 = OpDot %29 %63 %65
107                OpStore %60 %66
108                OpReturn
109                OpFunctionEnd
110   )";
111 
112   const auto env = SPV_ENV_UNIVERSAL_1_5;
113   const auto consumer = nullptr;
114   const auto context = BuildModule(env, consumer, shader, kFuzzAssembleOption);
115   spvtools::ValidatorOptions validator_options;
116   ASSERT_TRUE(fuzzerutil::IsValidAndWellFormed(context.get(), validator_options,
117                                                kConsoleMessageConsumer));
118   TransformationContext transformation_context(
119       MakeUnique<FactManager>(context.get()), validator_options);
120   // Tests existing commutative instructions
121   auto instructionDescriptor = MakeInstructionDescriptor(22, SpvOpIAdd, 0);
122   auto transformation =
123       TransformationSwapCommutableOperands(instructionDescriptor);
124   ASSERT_TRUE(
125       transformation.IsApplicable(context.get(), transformation_context));
126 
127   instructionDescriptor = MakeInstructionDescriptor(28, SpvOpIMul, 0);
128   transformation = TransformationSwapCommutableOperands(instructionDescriptor);
129   ASSERT_TRUE(
130       transformation.IsApplicable(context.get(), transformation_context));
131 
132   instructionDescriptor = MakeInstructionDescriptor(42, SpvOpFAdd, 0);
133   transformation = TransformationSwapCommutableOperands(instructionDescriptor);
134   ASSERT_TRUE(
135       transformation.IsApplicable(context.get(), transformation_context));
136 
137   instructionDescriptor = MakeInstructionDescriptor(48, SpvOpFMul, 0);
138   transformation = TransformationSwapCommutableOperands(instructionDescriptor);
139   ASSERT_TRUE(
140       transformation.IsApplicable(context.get(), transformation_context));
141 
142   instructionDescriptor = MakeInstructionDescriptor(66, SpvOpDot, 0);
143   transformation = TransformationSwapCommutableOperands(instructionDescriptor);
144   ASSERT_TRUE(
145       transformation.IsApplicable(context.get(), transformation_context));
146 
147   // Tests existing non-commutative instructions
148   instructionDescriptor = MakeInstructionDescriptor(1, SpvOpExtInstImport, 0);
149   transformation = TransformationSwapCommutableOperands(instructionDescriptor);
150   ASSERT_FALSE(
151       transformation.IsApplicable(context.get(), transformation_context));
152 
153   instructionDescriptor = MakeInstructionDescriptor(5, SpvOpLabel, 0);
154   transformation = TransformationSwapCommutableOperands(instructionDescriptor);
155   ASSERT_FALSE(
156       transformation.IsApplicable(context.get(), transformation_context));
157 
158   instructionDescriptor = MakeInstructionDescriptor(8, SpvOpConstant, 0);
159   transformation = TransformationSwapCommutableOperands(instructionDescriptor);
160   ASSERT_FALSE(
161       transformation.IsApplicable(context.get(), transformation_context));
162 
163   instructionDescriptor = MakeInstructionDescriptor(11, SpvOpVariable, 0);
164   transformation = TransformationSwapCommutableOperands(instructionDescriptor);
165   ASSERT_FALSE(
166       transformation.IsApplicable(context.get(), transformation_context));
167 
168   instructionDescriptor =
169       MakeInstructionDescriptor(14, SpvOpConstantComposite, 0);
170   transformation = TransformationSwapCommutableOperands(instructionDescriptor);
171   ASSERT_FALSE(
172       transformation.IsApplicable(context.get(), transformation_context));
173 
174   // Tests the base instruction id not existing
175   instructionDescriptor = MakeInstructionDescriptor(67, SpvOpIAddCarry, 0);
176   transformation = TransformationSwapCommutableOperands(instructionDescriptor);
177   ASSERT_FALSE(
178       transformation.IsApplicable(context.get(), transformation_context));
179 
180   instructionDescriptor = MakeInstructionDescriptor(68, SpvOpIEqual, 0);
181   transformation = TransformationSwapCommutableOperands(instructionDescriptor);
182   ASSERT_FALSE(
183       transformation.IsApplicable(context.get(), transformation_context));
184 
185   instructionDescriptor = MakeInstructionDescriptor(69, SpvOpINotEqual, 0);
186   transformation = TransformationSwapCommutableOperands(instructionDescriptor);
187   ASSERT_FALSE(
188       transformation.IsApplicable(context.get(), transformation_context));
189 
190   instructionDescriptor = MakeInstructionDescriptor(70, SpvOpFOrdEqual, 0);
191   transformation = TransformationSwapCommutableOperands(instructionDescriptor);
192   ASSERT_FALSE(
193       transformation.IsApplicable(context.get(), transformation_context));
194 
195   instructionDescriptor = MakeInstructionDescriptor(71, SpvOpPtrEqual, 0);
196   transformation = TransformationSwapCommutableOperands(instructionDescriptor);
197   ASSERT_FALSE(
198       transformation.IsApplicable(context.get(), transformation_context));
199 
200   // Tests there being no instruction with the desired opcode after the base
201   // instruction id
202   instructionDescriptor = MakeInstructionDescriptor(24, SpvOpIAdd, 0);
203   transformation = TransformationSwapCommutableOperands(instructionDescriptor);
204   ASSERT_FALSE(
205       transformation.IsApplicable(context.get(), transformation_context));
206 
207   instructionDescriptor = MakeInstructionDescriptor(38, SpvOpIMul, 0);
208   transformation = TransformationSwapCommutableOperands(instructionDescriptor);
209   ASSERT_FALSE(
210       transformation.IsApplicable(context.get(), transformation_context));
211 
212   instructionDescriptor = MakeInstructionDescriptor(45, SpvOpFAdd, 0);
213   transformation = TransformationSwapCommutableOperands(instructionDescriptor);
214   ASSERT_FALSE(
215       transformation.IsApplicable(context.get(), transformation_context));
216 
217   instructionDescriptor = MakeInstructionDescriptor(66, SpvOpFMul, 0);
218   transformation = TransformationSwapCommutableOperands(instructionDescriptor);
219   ASSERT_FALSE(
220       transformation.IsApplicable(context.get(), transformation_context));
221 
222   // Tests there being an instruction with the desired opcode after the base
223   // instruction id, but the skip count associated with the instruction
224   // descriptor being so high.
225   instructionDescriptor = MakeInstructionDescriptor(11, SpvOpIAdd, 100);
226   transformation = TransformationSwapCommutableOperands(instructionDescriptor);
227   ASSERT_FALSE(
228       transformation.IsApplicable(context.get(), transformation_context));
229 
230   instructionDescriptor = MakeInstructionDescriptor(16, SpvOpIMul, 100);
231   transformation = TransformationSwapCommutableOperands(instructionDescriptor);
232   ASSERT_FALSE(
233       transformation.IsApplicable(context.get(), transformation_context));
234 
235   instructionDescriptor = MakeInstructionDescriptor(23, SpvOpFAdd, 100);
236   transformation = TransformationSwapCommutableOperands(instructionDescriptor);
237   ASSERT_FALSE(
238       transformation.IsApplicable(context.get(), transformation_context));
239 
240   instructionDescriptor = MakeInstructionDescriptor(32, SpvOpFMul, 100);
241   transformation = TransformationSwapCommutableOperands(instructionDescriptor);
242   ASSERT_FALSE(
243       transformation.IsApplicable(context.get(), transformation_context));
244 
245   instructionDescriptor = MakeInstructionDescriptor(37, SpvOpDot, 100);
246   transformation = TransformationSwapCommutableOperands(instructionDescriptor);
247   ASSERT_FALSE(
248       transformation.IsApplicable(context.get(), transformation_context));
249 }
250 
TEST(TransformationSwapCommutableOperandsTest,ApplyTest)251 TEST(TransformationSwapCommutableOperandsTest, ApplyTest) {
252   std::string shader = R"(
253                OpCapability Shader
254           %1 = OpExtInstImport "GLSL.std.450"
255                OpMemoryModel Logical GLSL450
256                OpEntryPoint Fragment %4 "main"
257                OpExecutionMode %4 OriginUpperLeft
258                OpSource ESSL 310
259                OpName %4 "main"
260           %2 = OpTypeVoid
261           %3 = OpTypeFunction %2
262           %6 = OpTypeInt 32 1
263           %7 = OpTypeInt 32 0
264           %8 = OpConstant %7 2
265           %9 = OpTypeArray %6 %8
266          %10 = OpTypePointer Function %9
267          %12 = OpConstant %6 1
268          %13 = OpConstant %6 2
269          %14 = OpConstantComposite %9 %12 %13
270          %15 = OpTypePointer Function %6
271          %17 = OpConstant %6 0
272          %29 = OpTypeFloat 32
273          %30 = OpTypeArray %29 %8
274          %31 = OpTypePointer Function %30
275          %33 = OpConstant %29 1
276          %34 = OpConstant %29 2
277          %35 = OpConstantComposite %30 %33 %34
278          %36 = OpTypePointer Function %29
279          %49 = OpTypeVector %29 3
280          %50 = OpTypeArray %49 %8
281          %51 = OpTypePointer Function %50
282          %53 = OpConstant %29 3
283          %54 = OpConstantComposite %49 %33 %34 %53
284          %55 = OpConstant %29 4
285          %56 = OpConstant %29 5
286          %57 = OpConstant %29 6
287          %58 = OpConstantComposite %49 %55 %56 %57
288          %59 = OpConstantComposite %50 %54 %58
289          %61 = OpTypePointer Function %49
290           %4 = OpFunction %2 None %3
291           %5 = OpLabel
292          %11 = OpVariable %10 Function
293          %16 = OpVariable %15 Function
294          %23 = OpVariable %15 Function
295          %32 = OpVariable %31 Function
296          %37 = OpVariable %36 Function
297          %43 = OpVariable %36 Function
298          %52 = OpVariable %51 Function
299          %60 = OpVariable %36 Function
300                OpStore %11 %14
301          %18 = OpAccessChain %15 %11 %17
302          %19 = OpLoad %6 %18
303          %20 = OpAccessChain %15 %11 %12
304          %21 = OpLoad %6 %20
305          %22 = OpIAdd %6 %19 %21
306                OpStore %16 %22
307          %24 = OpAccessChain %15 %11 %17
308          %25 = OpLoad %6 %24
309          %26 = OpAccessChain %15 %11 %12
310          %27 = OpLoad %6 %26
311          %28 = OpIMul %6 %25 %27
312                OpStore %23 %28
313                OpStore %32 %35
314          %38 = OpAccessChain %36 %32 %17
315          %39 = OpLoad %29 %38
316          %40 = OpAccessChain %36 %32 %12
317          %41 = OpLoad %29 %40
318          %42 = OpFAdd %29 %39 %41
319                OpStore %37 %42
320          %44 = OpAccessChain %36 %32 %17
321          %45 = OpLoad %29 %44
322          %46 = OpAccessChain %36 %32 %12
323          %47 = OpLoad %29 %46
324          %48 = OpFMul %29 %45 %47
325                OpStore %43 %48
326                OpStore %52 %59
327          %62 = OpAccessChain %61 %52 %17
328          %63 = OpLoad %49 %62
329          %64 = OpAccessChain %61 %52 %12
330          %65 = OpLoad %49 %64
331          %66 = OpDot %29 %63 %65
332                OpStore %60 %66
333                OpReturn
334                OpFunctionEnd
335   )";
336 
337   const auto env = SPV_ENV_UNIVERSAL_1_5;
338   const auto consumer = nullptr;
339   const auto context = BuildModule(env, consumer, shader, kFuzzAssembleOption);
340   spvtools::ValidatorOptions validator_options;
341   ASSERT_TRUE(fuzzerutil::IsValidAndWellFormed(context.get(), validator_options,
342                                                kConsoleMessageConsumer));
343   TransformationContext transformation_context(
344       MakeUnique<FactManager>(context.get()), validator_options);
345   auto instructionDescriptor = MakeInstructionDescriptor(22, SpvOpIAdd, 0);
346   auto transformation =
347       TransformationSwapCommutableOperands(instructionDescriptor);
348   ApplyAndCheckFreshIds(transformation, context.get(), &transformation_context);
349 
350   instructionDescriptor = MakeInstructionDescriptor(28, SpvOpIMul, 0);
351   transformation = TransformationSwapCommutableOperands(instructionDescriptor);
352   ApplyAndCheckFreshIds(transformation, context.get(), &transformation_context);
353 
354   instructionDescriptor = MakeInstructionDescriptor(42, SpvOpFAdd, 0);
355   transformation = TransformationSwapCommutableOperands(instructionDescriptor);
356   ApplyAndCheckFreshIds(transformation, context.get(), &transformation_context);
357 
358   instructionDescriptor = MakeInstructionDescriptor(48, SpvOpFMul, 0);
359   transformation = TransformationSwapCommutableOperands(instructionDescriptor);
360   ApplyAndCheckFreshIds(transformation, context.get(), &transformation_context);
361 
362   instructionDescriptor = MakeInstructionDescriptor(66, SpvOpDot, 0);
363   transformation = TransformationSwapCommutableOperands(instructionDescriptor);
364   ApplyAndCheckFreshIds(transformation, context.get(), &transformation_context);
365 
366   std::string variantShader = R"(
367                OpCapability Shader
368           %1 = OpExtInstImport "GLSL.std.450"
369                OpMemoryModel Logical GLSL450
370                OpEntryPoint Fragment %4 "main"
371                OpExecutionMode %4 OriginUpperLeft
372                OpSource ESSL 310
373                OpName %4 "main"
374           %2 = OpTypeVoid
375           %3 = OpTypeFunction %2
376           %6 = OpTypeInt 32 1
377           %7 = OpTypeInt 32 0
378           %8 = OpConstant %7 2
379           %9 = OpTypeArray %6 %8
380          %10 = OpTypePointer Function %9
381          %12 = OpConstant %6 1
382          %13 = OpConstant %6 2
383          %14 = OpConstantComposite %9 %12 %13
384          %15 = OpTypePointer Function %6
385          %17 = OpConstant %6 0
386          %29 = OpTypeFloat 32
387          %30 = OpTypeArray %29 %8
388          %31 = OpTypePointer Function %30
389          %33 = OpConstant %29 1
390          %34 = OpConstant %29 2
391          %35 = OpConstantComposite %30 %33 %34
392          %36 = OpTypePointer Function %29
393          %49 = OpTypeVector %29 3
394          %50 = OpTypeArray %49 %8
395          %51 = OpTypePointer Function %50
396          %53 = OpConstant %29 3
397          %54 = OpConstantComposite %49 %33 %34 %53
398          %55 = OpConstant %29 4
399          %56 = OpConstant %29 5
400          %57 = OpConstant %29 6
401          %58 = OpConstantComposite %49 %55 %56 %57
402          %59 = OpConstantComposite %50 %54 %58
403          %61 = OpTypePointer Function %49
404           %4 = OpFunction %2 None %3
405           %5 = OpLabel
406          %11 = OpVariable %10 Function
407          %16 = OpVariable %15 Function
408          %23 = OpVariable %15 Function
409          %32 = OpVariable %31 Function
410          %37 = OpVariable %36 Function
411          %43 = OpVariable %36 Function
412          %52 = OpVariable %51 Function
413          %60 = OpVariable %36 Function
414                OpStore %11 %14
415          %18 = OpAccessChain %15 %11 %17
416          %19 = OpLoad %6 %18
417          %20 = OpAccessChain %15 %11 %12
418          %21 = OpLoad %6 %20
419          %22 = OpIAdd %6 %21 %19
420                OpStore %16 %22
421          %24 = OpAccessChain %15 %11 %17
422          %25 = OpLoad %6 %24
423          %26 = OpAccessChain %15 %11 %12
424          %27 = OpLoad %6 %26
425          %28 = OpIMul %6 %27 %25
426                OpStore %23 %28
427                OpStore %32 %35
428          %38 = OpAccessChain %36 %32 %17
429          %39 = OpLoad %29 %38
430          %40 = OpAccessChain %36 %32 %12
431          %41 = OpLoad %29 %40
432          %42 = OpFAdd %29 %41 %39
433                OpStore %37 %42
434          %44 = OpAccessChain %36 %32 %17
435          %45 = OpLoad %29 %44
436          %46 = OpAccessChain %36 %32 %12
437          %47 = OpLoad %29 %46
438          %48 = OpFMul %29 %47 %45
439                OpStore %43 %48
440                OpStore %52 %59
441          %62 = OpAccessChain %61 %52 %17
442          %63 = OpLoad %49 %62
443          %64 = OpAccessChain %61 %52 %12
444          %65 = OpLoad %49 %64
445          %66 = OpDot %29 %65 %63
446                OpStore %60 %66
447                OpReturn
448                OpFunctionEnd
449   )";
450 
451   ASSERT_TRUE(IsEqual(env, variantShader, context.get()));
452 }
453 
454 }  // namespace
455 }  // namespace fuzz
456 }  // namespace spvtools
457