1 // Copyright (c) 2017 Google Inc.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #include <memory>
16 #include <string>
17 #include <vector>
18 
19 #include "gmock/gmock.h"
20 #include "source/opt/loop_descriptor.h"
21 #include "source/opt/pass.h"
22 #include "test/opt/assembly_builder.h"
23 #include "test/opt/function_utils.h"
24 #include "test/opt/pass_fixture.h"
25 #include "test/opt/pass_utils.h"
26 
27 namespace spvtools {
28 namespace opt {
29 namespace {
30 
31 using ::testing::UnorderedElementsAre;
32 using PassClassTest = PassTest<::testing::Test>;
33 
34 /*
35 Generated from the following GLSL
36 #version 330 core
37 layout(location = 0) out vec4 c;
38 void main() {
39   int i = 0;
40   for(; i < 10; ++i) {
41   }
42 }
43 */
TEST_F(PassClassTest,BasicVisitFromEntryPoint)44 TEST_F(PassClassTest, BasicVisitFromEntryPoint) {
45   const std::string text = R"(
46                 OpCapability Shader
47           %1 = OpExtInstImport "GLSL.std.450"
48                OpMemoryModel Logical GLSL450
49                OpEntryPoint Fragment %2 "main" %3
50                OpExecutionMode %2 OriginUpperLeft
51                OpSource GLSL 330
52                OpName %2 "main"
53                OpName %5 "i"
54                OpName %3 "c"
55                OpDecorate %3 Location 0
56           %6 = OpTypeVoid
57           %7 = OpTypeFunction %6
58           %8 = OpTypeInt 32 1
59           %9 = OpTypePointer Function %8
60          %10 = OpConstant %8 0
61          %11 = OpConstant %8 10
62          %12 = OpTypeBool
63          %13 = OpConstant %8 1
64          %14 = OpTypeFloat 32
65          %15 = OpTypeVector %14 4
66          %16 = OpTypePointer Output %15
67           %3 = OpVariable %16 Output
68           %2 = OpFunction %6 None %7
69          %17 = OpLabel
70           %5 = OpVariable %9 Function
71                OpStore %5 %10
72                OpBranch %18
73          %18 = OpLabel
74                OpLoopMerge %19 %20 None
75                OpBranch %21
76          %21 = OpLabel
77          %22 = OpLoad %8 %5
78          %23 = OpSLessThan %12 %22 %11
79                OpBranchConditional %23 %24 %19
80          %24 = OpLabel
81                OpBranch %20
82          %20 = OpLabel
83          %25 = OpLoad %8 %5
84          %26 = OpIAdd %8 %25 %13
85                OpStore %5 %26
86                OpBranch %18
87          %19 = OpLabel
88                OpReturn
89                OpFunctionEnd
90   )";
91   // clang-format on
92   std::unique_ptr<IRContext> context =
93       BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text,
94                   SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
95   Module* module = context->module();
96   EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n"
97                              << text << std::endl;
98   const Function* f = spvtest::GetFunction(module, 2);
99   LoopDescriptor& ld = *context->GetLoopDescriptor(f);
100 
101   EXPECT_EQ(ld.NumLoops(), 1u);
102 
103   Loop& loop = ld.GetLoopByIndex(0);
104   EXPECT_EQ(loop.GetHeaderBlock(), spvtest::GetBasicBlock(f, 18));
105   EXPECT_EQ(loop.GetLatchBlock(), spvtest::GetBasicBlock(f, 20));
106   EXPECT_EQ(loop.GetMergeBlock(), spvtest::GetBasicBlock(f, 19));
107 
108   EXPECT_FALSE(loop.HasNestedLoops());
109   EXPECT_FALSE(loop.IsNested());
110   EXPECT_EQ(loop.GetDepth(), 1u);
111 }
112 
113 /*
114 Generated from the following GLSL:
115 #version 330 core
116 layout(location = 0) out vec4 c;
117 void main() {
118   for(int i = 0; i < 10; ++i) {}
119   for(int i = 0; i < 10; ++i) {}
120 }
121 
122 But it was "hacked" to make the first loop merge block the second loop header.
123 */
TEST_F(PassClassTest,LoopWithNoPreHeader)124 TEST_F(PassClassTest, LoopWithNoPreHeader) {
125   const std::string text = R"(
126                OpCapability Shader
127           %1 = OpExtInstImport "GLSL.std.450"
128                OpMemoryModel Logical GLSL450
129                OpEntryPoint Fragment %2 "main" %3
130                OpExecutionMode %2 OriginUpperLeft
131                OpSource GLSL 330
132                OpName %2 "main"
133                OpName %4 "i"
134                OpName %5 "i"
135                OpName %3 "c"
136                OpDecorate %3 Location 0
137           %6 = OpTypeVoid
138           %7 = OpTypeFunction %6
139           %8 = OpTypeInt 32 1
140           %9 = OpTypePointer Function %8
141          %10 = OpConstant %8 0
142          %11 = OpConstant %8 10
143          %12 = OpTypeBool
144          %13 = OpConstant %8 1
145          %14 = OpTypeFloat 32
146          %15 = OpTypeVector %14 4
147          %16 = OpTypePointer Output %15
148           %3 = OpVariable %16 Output
149           %2 = OpFunction %6 None %7
150          %17 = OpLabel
151           %4 = OpVariable %9 Function
152           %5 = OpVariable %9 Function
153                OpStore %4 %10
154                OpStore %5 %10
155                OpBranch %18
156          %18 = OpLabel
157                OpLoopMerge %27 %20 None
158                OpBranch %21
159          %21 = OpLabel
160          %22 = OpLoad %8 %4
161          %23 = OpSLessThan %12 %22 %11
162                OpBranchConditional %23 %24 %27
163          %24 = OpLabel
164                OpBranch %20
165          %20 = OpLabel
166          %25 = OpLoad %8 %4
167          %26 = OpIAdd %8 %25 %13
168                OpStore %4 %26
169                OpBranch %18
170          %27 = OpLabel
171                OpLoopMerge %28 %29 None
172                OpBranch %30
173          %30 = OpLabel
174          %31 = OpLoad %8 %5
175          %32 = OpSLessThan %12 %31 %11
176                OpBranchConditional %32 %33 %28
177          %33 = OpLabel
178                OpBranch %29
179          %29 = OpLabel
180          %34 = OpLoad %8 %5
181          %35 = OpIAdd %8 %34 %13
182                OpStore %5 %35
183                OpBranch %27
184          %28 = OpLabel
185                OpReturn
186                OpFunctionEnd
187   )";
188   // clang-format on
189   std::unique_ptr<IRContext> context =
190       BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text,
191                   SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
192   Module* module = context->module();
193   EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n"
194                              << text << std::endl;
195   const Function* f = spvtest::GetFunction(module, 2);
196   LoopDescriptor& ld = *context->GetLoopDescriptor(f);
197 
198   EXPECT_EQ(ld.NumLoops(), 2u);
199 
200   Loop* loop = ld[27];
201   EXPECT_EQ(loop->GetPreHeaderBlock(), nullptr);
202   EXPECT_NE(loop->GetOrCreatePreHeaderBlock(), nullptr);
203 }
204 
205 /*
206 Generated from the following GLSL + --eliminate-local-multi-store
207 
208 #version 330 core
209 in vec4 c;
210 void main() {
211   int i = 0;
212   bool cond = c[0] == 0;
213   for (; i < 10; i++) {
214     if (cond) {
215       return;
216     }
217     else {
218       return;
219     }
220   }
221   bool cond2 = i == 9;
222 }
223 */
TEST_F(PassClassTest,NoLoop)224 TEST_F(PassClassTest, NoLoop) {
225   const std::string text = R"(; SPIR-V
226 ; Version: 1.0
227 ; Generator: Khronos Glslang Reference Front End; 3
228 ; Bound: 47
229 ; Schema: 0
230                OpCapability Shader
231           %1 = OpExtInstImport "GLSL.std.450"
232                OpMemoryModel Logical GLSL450
233                OpEntryPoint Fragment %4 "main" %16
234                OpExecutionMode %4 OriginUpperLeft
235                OpSource GLSL 330
236                OpName %4 "main"
237                OpName %16 "c"
238                OpDecorate %16 Location 0
239           %2 = OpTypeVoid
240           %3 = OpTypeFunction %2
241           %6 = OpTypeInt 32 1
242           %7 = OpTypePointer Function %6
243           %9 = OpConstant %6 0
244          %10 = OpTypeBool
245          %11 = OpTypePointer Function %10
246          %13 = OpTypeFloat 32
247          %14 = OpTypeVector %13 4
248          %15 = OpTypePointer Input %14
249          %16 = OpVariable %15 Input
250          %17 = OpTypeInt 32 0
251          %18 = OpConstant %17 0
252          %19 = OpTypePointer Input %13
253          %22 = OpConstant %13 0
254          %30 = OpConstant %6 10
255          %39 = OpConstant %6 1
256          %46 = OpUndef %6
257           %4 = OpFunction %2 None %3
258           %5 = OpLabel
259          %20 = OpAccessChain %19 %16 %18
260          %21 = OpLoad %13 %20
261          %23 = OpFOrdEqual %10 %21 %22
262                OpBranch %24
263          %24 = OpLabel
264          %45 = OpPhi %6 %9 %5 %40 %27
265                OpLoopMerge %26 %27 None
266                OpBranch %28
267          %28 = OpLabel
268          %31 = OpSLessThan %10 %45 %30
269                OpBranchConditional %31 %25 %26
270          %25 = OpLabel
271                OpSelectionMerge %34 None
272                OpBranchConditional %23 %33 %36
273          %33 = OpLabel
274                OpReturn
275          %36 = OpLabel
276                OpReturn
277          %34 = OpLabel
278                OpBranch %27
279          %27 = OpLabel
280          %40 = OpIAdd %6 %46 %39
281                OpBranch %24
282          %26 = OpLabel
283                OpReturn
284                OpFunctionEnd
285   )";
286 
287   std::unique_ptr<IRContext> context =
288       BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text,
289                   SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
290   Module* module = context->module();
291   EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n"
292                              << text << std::endl;
293   const Function* f = spvtest::GetFunction(module, 4);
294   LoopDescriptor ld{context.get(), f};
295 
296   EXPECT_EQ(ld.NumLoops(), 0u);
297 }
298 
299 /*
300 Generated from following GLSL with latch block artificially inserted to be
301 seperate from continue.
302 #version 430
303 void main(void) {
304     float x[10];
305     for (int i = 0; i < 10; ++i) {
306       x[i] = i;
307     }
308 }
309 */
TEST_F(PassClassTest,LoopLatchNotContinue)310 TEST_F(PassClassTest, LoopLatchNotContinue) {
311   const std::string text = R"(OpCapability Shader
312           %1 = OpExtInstImport "GLSL.std.450"
313                OpMemoryModel Logical GLSL450
314                OpEntryPoint Fragment %2 "main"
315                OpExecutionMode %2 OriginUpperLeft
316                OpSource GLSL 430
317                OpName %2 "main"
318                OpName %3 "i"
319                OpName %4 "x"
320           %5 = OpTypeVoid
321           %6 = OpTypeFunction %5
322           %7 = OpTypeInt 32 1
323           %8 = OpTypePointer Function %7
324           %9 = OpConstant %7 0
325          %10 = OpConstant %7 10
326          %11 = OpTypeBool
327          %12 = OpTypeFloat 32
328          %13 = OpTypeInt 32 0
329          %14 = OpConstant %13 10
330          %15 = OpTypeArray %12 %14
331          %16 = OpTypePointer Function %15
332          %17 = OpTypePointer Function %12
333          %18 = OpConstant %7 1
334           %2 = OpFunction %5 None %6
335          %19 = OpLabel
336           %3 = OpVariable %8 Function
337           %4 = OpVariable %16 Function
338                OpStore %3 %9
339                OpBranch %20
340          %20 = OpLabel
341          %21 = OpPhi %7 %9 %19 %22 %30
342                OpLoopMerge %24 %23 None
343                OpBranch %25
344          %25 = OpLabel
345          %26 = OpSLessThan %11 %21 %10
346                OpBranchConditional %26 %27 %24
347          %27 = OpLabel
348          %28 = OpConvertSToF %12 %21
349          %29 = OpAccessChain %17 %4 %21
350                OpStore %29 %28
351                OpBranch %23
352          %23 = OpLabel
353          %22 = OpIAdd %7 %21 %18
354                OpStore %3 %22
355                OpBranch %30
356          %30 = OpLabel
357                OpBranch %20
358          %24 = OpLabel
359                OpReturn
360                OpFunctionEnd
361   )";
362 
363   std::unique_ptr<IRContext> context =
364       BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text,
365                   SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
366   Module* module = context->module();
367   EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n"
368                              << text << std::endl;
369   const Function* f = spvtest::GetFunction(module, 2);
370   LoopDescriptor ld{context.get(), f};
371 
372   EXPECT_EQ(ld.NumLoops(), 1u);
373 
374   Loop& loop = ld.GetLoopByIndex(0u);
375 
376   EXPECT_NE(loop.GetLatchBlock(), loop.GetContinueBlock());
377 
378   EXPECT_EQ(loop.GetContinueBlock()->id(), 23u);
379   EXPECT_EQ(loop.GetLatchBlock()->id(), 30u);
380 }
381 
382 }  // namespace
383 }  // namespace opt
384 }  // namespace spvtools
385