1 // Copyright (c) 2017 Google Inc.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #include <algorithm>
16 #include <cstdarg>
17 #include <iostream>
18 #include <sstream>
19 #include <string>
20 #include <unordered_set>
21 #include <vector>
22 
23 #include "gmock/gmock.h"
24 #include "test/opt/assembly_builder.h"
25 #include "test/opt/pass_fixture.h"
26 #include "test/opt/pass_utils.h"
27 
28 namespace spvtools {
29 namespace opt {
30 namespace {
31 
32 using ::testing::HasSubstr;
33 using ::testing::MatchesRegex;
34 using StrengthReductionBasicTest = PassTest<::testing::Test>;
35 
36 // Test to make sure we replace 5*8.
TEST_F(StrengthReductionBasicTest, BasicReplaceMulBy8) {
  // Minimal vertex shader whose only arithmetic is the unsigned multiply
  // %uint_5 * %uint_8.
  const std::vector<const char*> assembly_lines = {
      // clang-format off
      "OpCapability Shader",
      "%1 = OpExtInstImport \"GLSL.std.450\"",
      "OpMemoryModel Logical GLSL450",
      "OpEntryPoint Vertex %main \"main\"",
      "OpName %main \"main\"",
      "%void = OpTypeVoid",
      "%4 = OpTypeFunction %void",
      "%uint = OpTypeInt 32 0",
      "%uint_5 = OpConstant %uint 5",
      "%uint_8 = OpConstant %uint 8",
      "%main = OpFunction %void None %4",
      "%8 = OpLabel",
      "%9 = OpIMul %uint %uint_5 %uint_8",
      "OpReturn",
      "OpFunctionEnd"
      // clang-format on
  };

  // Run the pass once and split the returned tuple into its two pieces:
  // element 0 is the disassembled module, element 1 is the pass status.
  const auto pass_result = SinglePassRunAndDisassemble<StrengthReductionPass>(
      JoinAllInsts(assembly_lines), /* skip_nop = */ true,
      /* do_validation = */ false);
  const std::string& disassembly = std::get<0>(pass_result);
  const auto status = std::get<1>(pass_result);

  // The multiply must be gone, replaced by a left shift of 5 by 3 (8 == 1 << 3).
  EXPECT_EQ(Pass::Status::SuccessWithChange, status);
  EXPECT_THAT(disassembly, Not(HasSubstr("OpIMul")));
  EXPECT_THAT(disassembly,
              HasSubstr("OpShiftLeftLogical %uint %uint_5 %uint_3"));
}
66 
67 // TODO(dneto): Add Effcee as required dependency, and make this unconditional.
68 // Test to make sure we replace 16*5
69 // Also demonstrate use of Effcee matching.
TEST_F(StrengthReductionBasicTest, BasicReplaceMulBy16) {
  // The module below carries its own expectations: the "; CHECK..." assembly
  // comments are the patterns the transformed output is matched against.
  const std::string annotated_assembly = R"(
               OpCapability Shader
          %1 = OpExtInstImport "GLSL.std.450"
               OpMemoryModel Logical GLSL450
               OpEntryPoint Vertex %main "main"
               OpName %main "main"
       %void = OpTypeVoid
          %4 = OpTypeFunction %void
; We know disassembly will produce %uint here, but
;  CHECK: %uint = OpTypeInt 32 0
;  CHECK-DAG: [[five:%[a-zA-Z_\d]+]] = OpConstant %uint 5

; We have RE2 regular expressions, so \w matches [_a-zA-Z0-9].
; This shows the preferred pattern for matching SPIR-V identifiers.
; (We could have cheated in this case since we know the disassembler will
; generate the 'nice' name of "%uint_4".
;  CHECK-DAG: [[four:%\w+]] = OpConstant %uint 4
       %uint = OpTypeInt 32 0
     %uint_5 = OpConstant %uint 5
    %uint_16 = OpConstant %uint 16
       %main = OpFunction %void None %4
; CHECK: OpLabel
          %8 = OpLabel
; CHECK-NEXT: OpShiftLeftLogical %uint [[five]] [[four]]
; The multiplication disappears.
; CHECK-NOT: OpIMul
          %9 = OpIMul %uint %uint_16 %uint_5
               OpReturn
; CHECK: OpFunctionEnd
               OpFunctionEnd)";

  // NOTE(review): the bare `false` is presumably the same validation switch
  // the sibling tests spell out as do_validation -- confirm against the
  // fixture's declaration of SinglePassRunAndMatch.
  SinglePassRunAndMatch<StrengthReductionPass>(annotated_assembly, false);
}
104 
// Test to make sure we replace a multiplication of 32 by 4.
TEST_F(StrengthReductionBasicTest, BasicTwoPowersOf2) {
  // Both operands of the multiply (32 and 4) are powers of two.  Only one of
  // them may be turned into the shift amount; the other has to survive as the
  // value being shifted.
  // clang-format off
  const std::string assembly = R"(
          OpCapability Shader
     %1 = OpExtInstImport "GLSL.std.450"
          OpMemoryModel Logical GLSL450
          OpEntryPoint Vertex %main "main"
          OpName %main "main"
  %void = OpTypeVoid
     %4 = OpTypeFunction %void
    %int = OpTypeInt 32 1
%int_32 = OpConstant %int 32
 %int_4 = OpConstant %int 4
  %main = OpFunction %void None %4
     %8 = OpLabel
     %9 = OpIMul %int %int_32 %int_4
          OpReturn
          OpFunctionEnd
)";
  // clang-format on
  const auto pass_result = SinglePassRunAndDisassemble<StrengthReductionPass>(
      assembly, /* skip_nop = */ true, /* do_validation = */ false);
  const std::string& disassembly = std::get<0>(pass_result);
  const auto status = std::get<1>(pass_result);

  // Expected shape: %int_4 is kept and shifted left by 5 (32 == 1 << 5).
  EXPECT_EQ(Pass::Status::SuccessWithChange, status);
  EXPECT_THAT(disassembly, Not(HasSubstr("OpIMul")));
  EXPECT_THAT(disassembly,
              HasSubstr("OpShiftLeftLogical %int %int_4 %uint_5"));
}
136 
137 // Test to make sure we don't replace 0*5.
TEST_F(StrengthReductionBasicTest, BasicDontReplace0) {
  // The single multiply in this module has the constant 0 as one operand.
  const std::vector<const char*> assembly_lines = {
      // clang-format off
      "OpCapability Shader",
      "%1 = OpExtInstImport \"GLSL.std.450\"",
      "OpMemoryModel Logical GLSL450",
      "OpEntryPoint Vertex %main \"main\"",
      "OpName %main \"main\"",
      "%void = OpTypeVoid",
      "%4 = OpTypeFunction %void",
      "%int = OpTypeInt 32 1",
      "%int_0 = OpConstant %int 0",
      "%int_5 = OpConstant %int 5",
      "%main = OpFunction %void None %4",
      "%8 = OpLabel",
      "%9 = OpIMul %int %int_0 %int_5",
      "OpReturn",
      "OpFunctionEnd"
      // clang-format on
  };

  const auto pass_result = SinglePassRunAndDisassemble<StrengthReductionPass>(
      JoinAllInsts(assembly_lines), /* skip_nop = */ true,
      /* do_validation = */ false);

  // Only the status is inspected: the pass must report that nothing changed.
  const auto status = std::get<1>(pass_result);
  EXPECT_EQ(Pass::Status::SuccessWithoutChange, status);
}
164 
// Test to make sure we do not replace a multiplication of 5 by 7.
TEST_F(StrengthReductionBasicTest, BasicNoChange) {
  // The single multiply here is %7 * %8, i.e. the signed constants 5 and 7;
  // neither value is a power of two.
  const std::vector<const char*> assembly_lines = {
      // clang-format off
      "OpCapability Shader",
      "%1 = OpExtInstImport \"GLSL.std.450\"",
      "OpMemoryModel Logical GLSL450",
      "OpEntryPoint Vertex %2 \"main\"",
      "OpName %2 \"main\"",
      "%3 = OpTypeVoid",
      "%4 = OpTypeFunction %3",
      "%5 = OpTypeInt 32 1",
      "%6 = OpTypeInt 32 0",
      "%7 = OpConstant %5 5",
      "%8 = OpConstant %5 7",
      "%2 = OpFunction %3 None %4",
      "%9 = OpLabel",
      "%10 = OpIMul %5 %7 %8",
      "OpReturn",
      "OpFunctionEnd",
      // clang-format on
  };

  const auto pass_result = SinglePassRunAndDisassemble<StrengthReductionPass>(
      JoinAllInsts(assembly_lines), /* skip_nop = */ true,
      /* do_validation = */ false);

  // Only the status is inspected: the pass must report that nothing changed.
  const auto status = std::get<1>(pass_result);
  EXPECT_EQ(Pass::Status::SuccessWithoutChange, status);
}
193 
194 // Test to make sure constants and types are reused and not duplicated.
TEST_F(StrengthReductionBasicTest, NoDuplicateConstantsAndTypes) {
  // The module already declares both %uint and %uint_3.  The test requires
  // that, after %uint_8 * %uint_3 has been rewritten, the output still holds
  // a single "OpTypeInt 32 0" and a single "OpConstant %uint 3".
  const std::vector<const char*> text = {
      // clang-format off
               "OpCapability Shader",
          "%1 = OpExtInstImport \"GLSL.std.450\"",
               "OpMemoryModel Logical GLSL450",
               "OpEntryPoint Vertex %main \"main\"",
               "OpName %main \"main\"",
       "%void = OpTypeVoid",
          "%4 = OpTypeFunction %void",
       "%uint = OpTypeInt 32 0",
     "%uint_8 = OpConstant %uint 8",
     "%uint_3 = OpConstant %uint 3",
       "%main = OpFunction %void None %4",
          "%8 = OpLabel",
          "%9 = OpIMul %uint %uint_8 %uint_3",
               "OpReturn",
               "OpFunctionEnd",
      // clang-format on
  };

  auto result = SinglePassRunAndDisassemble<StrengthReductionPass>(
      JoinAllInsts(text), /* skip_nop = */ true, /* do_validation = */ false);

  EXPECT_EQ(Pass::Status::SuccessWithChange, std::get<1>(result));
  const std::string& output = std::get<0>(result);

  // Reports whether |needle| occurs at least twice (non-overlapping) in the
  // disassembly.  A plain substring search is used rather than
  // MatchesRegex(".*X.*X.*") because the disassembly spans several lines and
  // '.' does not match '\n' under every GoogleTest regex backend; on those
  // backends the negated regex check would pass no matter what the pass did.
  const auto appears_twice = [&output](const std::string& needle) {
    const std::string::size_type first = output.find(needle);
    return first != std::string::npos &&
           output.find(needle, first + needle.size()) != std::string::npos;
  };
  EXPECT_FALSE(appears_twice("OpConstant %uint 3")) << output;
  EXPECT_FALSE(appears_twice("OpTypeInt 32 0")) << output;
}
225 
226 // Test to make sure we generate the constants only once
TEST_F(StrengthReductionBasicTest, BasicCreateOneConst) {
  // Two multiplies share the factor 128, so both replacements need the same
  // shift amount (7).  That constant is not declared in the input module.
  const std::vector<const char*> text = {
      // clang-format off
               "OpCapability Shader",
          "%1 = OpExtInstImport \"GLSL.std.450\"",
               "OpMemoryModel Logical GLSL450",
               "OpEntryPoint Vertex %main \"main\"",
               "OpName %main \"main\"",
       "%void = OpTypeVoid",
          "%4 = OpTypeFunction %void",
       "%uint = OpTypeInt 32 0",
     "%uint_5 = OpConstant %uint 5",
     "%uint_9 = OpConstant %uint 9",
   "%uint_128 = OpConstant %uint 128",
       "%main = OpFunction %void None %4",
          "%8 = OpLabel",
          "%9 = OpIMul %uint %uint_5 %uint_128",
         "%10 = OpIMul %uint %uint_9 %uint_128",
               "OpReturn",
               "OpFunctionEnd"
      // clang-format on
  };

  auto result = SinglePassRunAndDisassemble<StrengthReductionPass>(
      JoinAllInsts(text), /* skip_nop = */ true, /* do_validation = */ false);

  EXPECT_EQ(Pass::Status::SuccessWithChange, std::get<1>(result));
  const std::string& output = std::get<0>(result);
  EXPECT_THAT(output, Not(HasSubstr("OpIMul")));
  EXPECT_THAT(output, HasSubstr("OpShiftLeftLogical %uint %uint_5 %uint_7"));
  EXPECT_THAT(output, HasSubstr("OpShiftLeftLogical %uint %uint_9 %uint_7"));

  // The substring checks above only show which ids the shifts reference; they
  // cannot show how many definitions of the constant exist.  Count the
  // definitions explicitly: there must be one, and no second one after it.
  const std::string shift_amount_def = "OpConstant %uint 7";
  const std::string::size_type first_def = output.find(shift_amount_def);
  EXPECT_NE(std::string::npos, first_def) << output;
  EXPECT_EQ(std::string::npos,
            output.find(shift_amount_def, first_def + shift_amount_def.size()))
      << output;
}
259 
260 // Test to make sure we generate the instructions in the correct position and
261 // that the uses get replaced as well.  Here we check that the use in the return
262 // is replaced, we also check that we can replace two OpIMuls when one feeds the
263 // other.
TEST_F(StrengthReductionBasicTest,BasicCheckPositionAndReplacement)264 TEST_F(StrengthReductionBasicTest, BasicCheckPositionAndReplacement) {
265   // This is just the preamble to set up the test.
266   const std::vector<const char*> common_text = {
267       // clang-format off
268                "OpCapability Shader",
269           "%1 = OpExtInstImport \"GLSL.std.450\"",
270                "OpMemoryModel Logical GLSL450",
271                "OpEntryPoint Fragment %main \"main\" %gl_FragColor",
272                "OpExecutionMode %main OriginUpperLeft",
273                "OpName %main \"main\"",
274                "OpName %foo_i1_ \"foo(i1;\"",
275                "OpName %n \"n\"",
276                "OpName %gl_FragColor \"gl_FragColor\"",
277                "OpName %param \"param\"",
278                "OpDecorate %gl_FragColor Location 0",
279        "%void = OpTypeVoid",
280           "%3 = OpTypeFunction %void",
281         "%int = OpTypeInt 32 1",
282 "%_ptr_Function_int = OpTypePointer Function %int",
283           "%8 = OpTypeFunction %int %_ptr_Function_int",
284     "%int_256 = OpConstant %int 256",
285       "%int_2 = OpConstant %int 2",
286       "%float = OpTypeFloat 32",
287     "%v4float = OpTypeVector %float 4",
288 "%_ptr_Output_v4float = OpTypePointer Output %v4float",
289 "%gl_FragColor = OpVariable %_ptr_Output_v4float Output",
290     "%float_1 = OpConstant %float 1",
291      "%int_10 = OpConstant %int 10",
292   "%float_0_375 = OpConstant %float 0.375",
293   "%float_0_75 = OpConstant %float 0.75",
294        "%uint = OpTypeInt 32 0",
295      "%uint_8 = OpConstant %uint 8",
296      "%uint_1 = OpConstant %uint 1",
297        "%main = OpFunction %void None %3",
298           "%5 = OpLabel",
299       "%param = OpVariable %_ptr_Function_int Function",
300                "OpStore %param %int_10",
301          "%26 = OpFunctionCall %int %foo_i1_ %param",
302          "%27 = OpConvertSToF %float %26",
303          "%28 = OpFDiv %float %float_1 %27",
304          "%31 = OpCompositeConstruct %v4float %28 %float_0_375 %float_0_75 %float_1",
305                "OpStore %gl_FragColor %31",
306                "OpReturn",
307                "OpFunctionEnd"
308       // clang-format on
309   };
310 
311   // This is the real test.  The two OpIMul should be replaced.  The expected
312   // output is in |foo_after|.
313   const std::vector<const char*> foo_before = {
314       // clang-format off
315     "%foo_i1_ = OpFunction %int None %8",
316           "%n = OpFunctionParameter %_ptr_Function_int",
317          "%11 = OpLabel",
318          "%12 = OpLoad %int %n",
319          "%14 = OpIMul %int %12 %int_256",
320          "%16 = OpIMul %int %14 %int_2",
321                "OpReturnValue %16",
322                "OpFunctionEnd",
323 
324       // clang-format on
325   };
326 
327   const std::vector<const char*> foo_after = {
328       // clang-format off
329     "%foo_i1_ = OpFunction %int None %8",
330           "%n = OpFunctionParameter %_ptr_Function_int",
331          "%11 = OpLabel",
332          "%12 = OpLoad %int %n",
333          "%33 = OpShiftLeftLogical %int %12 %uint_8",
334          "%34 = OpShiftLeftLogical %int %33 %uint_1",
335                "OpReturnValue %34",
336                "OpFunctionEnd",
337       // clang-format on
338   };
339 
340   SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
341   SinglePassRunAndCheck<StrengthReductionPass>(
342       JoinAllInsts(Concat(common_text, foo_before)),
343       JoinAllInsts(Concat(common_text, foo_after)),
344       /* skip_nop = */ true, /* do_validate = */ true);
345 }
346 
// Test that, when the result of an OpIMul instruction has more than 1 use, and
// the instruction is replaced, all of the uses of the result are replaced with
// the new result.
TEST_F(StrengthReductionBasicTest,BasicTestMultipleReplacements)350 TEST_F(StrengthReductionBasicTest, BasicTestMultipleReplacements) {
351   // This is just the preamble to set up the test.
352   const std::vector<const char*> common_text = {
353       // clang-format off
354                "OpCapability Shader",
355           "%1 = OpExtInstImport \"GLSL.std.450\"",
356                "OpMemoryModel Logical GLSL450",
357                "OpEntryPoint Fragment %main \"main\" %gl_FragColor",
358                "OpExecutionMode %main OriginUpperLeft",
359                "OpName %main \"main\"",
360                "OpName %foo_i1_ \"foo(i1;\"",
361                "OpName %n \"n\"",
362                "OpName %gl_FragColor \"gl_FragColor\"",
363                "OpName %param \"param\"",
364                "OpDecorate %gl_FragColor Location 0",
365        "%void = OpTypeVoid",
366           "%3 = OpTypeFunction %void",
367         "%int = OpTypeInt 32 1",
368 "%_ptr_Function_int = OpTypePointer Function %int",
369           "%8 = OpTypeFunction %int %_ptr_Function_int",
370     "%int_256 = OpConstant %int 256",
371       "%int_2 = OpConstant %int 2",
372       "%float = OpTypeFloat 32",
373     "%v4float = OpTypeVector %float 4",
374 "%_ptr_Output_v4float = OpTypePointer Output %v4float",
375 "%gl_FragColor = OpVariable %_ptr_Output_v4float Output",
376     "%float_1 = OpConstant %float 1",
377      "%int_10 = OpConstant %int 10",
378   "%float_0_375 = OpConstant %float 0.375",
379   "%float_0_75 = OpConstant %float 0.75",
380        "%uint = OpTypeInt 32 0",
381      "%uint_8 = OpConstant %uint 8",
382      "%uint_1 = OpConstant %uint 1",
383        "%main = OpFunction %void None %3",
384           "%5 = OpLabel",
385       "%param = OpVariable %_ptr_Function_int Function",
386                "OpStore %param %int_10",
387          "%26 = OpFunctionCall %int %foo_i1_ %param",
388          "%27 = OpConvertSToF %float %26",
389          "%28 = OpFDiv %float %float_1 %27",
390          "%31 = OpCompositeConstruct %v4float %28 %float_0_375 %float_0_75 %float_1",
391                "OpStore %gl_FragColor %31",
392                "OpReturn",
393                "OpFunctionEnd"
394       // clang-format on
395   };
396 
397   // This is the real test.  The two OpIMul instructions should be replaced.  In
398   // particular, we want to be sure that both uses of %16 are changed to use the
399   // new result.
400   const std::vector<const char*> foo_before = {
401       // clang-format off
402     "%foo_i1_ = OpFunction %int None %8",
403           "%n = OpFunctionParameter %_ptr_Function_int",
404          "%11 = OpLabel",
405          "%12 = OpLoad %int %n",
406          "%14 = OpIMul %int %12 %int_256",
407          "%16 = OpIMul %int %14 %int_2",
408          "%17 = OpIAdd %int %14 %16",
409                "OpReturnValue %17",
410                "OpFunctionEnd",
411 
412       // clang-format on
413   };
414 
415   const std::vector<const char*> foo_after = {
416       // clang-format off
417     "%foo_i1_ = OpFunction %int None %8",
418           "%n = OpFunctionParameter %_ptr_Function_int",
419          "%11 = OpLabel",
420          "%12 = OpLoad %int %n",
421          "%34 = OpShiftLeftLogical %int %12 %uint_8",
422          "%35 = OpShiftLeftLogical %int %34 %uint_1",
423          "%17 = OpIAdd %int %34 %35",
424                "OpReturnValue %17",
425                "OpFunctionEnd",
426       // clang-format on
427   };
428 
429   SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
430   SinglePassRunAndCheck<StrengthReductionPass>(
431       JoinAllInsts(Concat(common_text, foo_before)),
432       JoinAllInsts(Concat(common_text, foo_after)),
433       /* skip_nop = */ true, /* do_validate = */ true);
434 }
435 
436 }  // namespace
437 }  // namespace opt
438 }  // namespace spvtools
439