1 // Copyright (c) 2019 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #include "source/reduce/merge_blocks_reduction_opportunity_finder.h"
16 
17 #include "source/opt/build_module.h"
18 #include "source/reduce/reduction_opportunity.h"
19 #include "test/reduce/reduce_test_util.h"
20 
21 namespace spvtools {
22 namespace reduce {
23 namespace {
24 
TEST(MergeBlocksReductionPassTest,BasicCheck)25 TEST(MergeBlocksReductionPassTest, BasicCheck) {
26   std::string shader = R"(
27                OpCapability Shader
28           %1 = OpExtInstImport "GLSL.std.450"
29                OpMemoryModel Logical GLSL450
30                OpEntryPoint Fragment %4 "main"
31                OpExecutionMode %4 OriginUpperLeft
32                OpSource ESSL 310
33                OpName %4 "main"
34                OpName %8 "x"
35           %2 = OpTypeVoid
36           %3 = OpTypeFunction %2
37           %6 = OpTypeInt 32 1
38           %7 = OpTypePointer Function %6
39           %9 = OpConstant %6 1
40          %10 = OpConstant %6 2
41          %11 = OpConstant %6 3
42          %12 = OpConstant %6 4
43           %4 = OpFunction %2 None %3
44           %5 = OpLabel
45           %8 = OpVariable %7 Function
46                OpBranch %13
47          %13 = OpLabel
48                OpStore %8 %9
49                OpBranch %14
50          %14 = OpLabel
51                OpStore %8 %10
52                OpBranch %15
53          %15 = OpLabel
54                OpStore %8 %11
55                OpBranch %16
56          %16 = OpLabel
57                OpStore %8 %12
58                OpBranch %17
59          %17 = OpLabel
60                OpReturn
61                OpFunctionEnd
62   )";
63   const auto env = SPV_ENV_UNIVERSAL_1_3;
64   const auto consumer = nullptr;
65   const auto context =
66       BuildModule(env, consumer, shader, kReduceAssembleOption);
67   const auto ops =
68       MergeBlocksReductionOpportunityFinder().GetAvailableOpportunities(
69           context.get(), 0);
70   ASSERT_EQ(5, ops.size());
71 
72   // Try order 3, 0, 2, 4, 1
73 
74   ASSERT_TRUE(ops[3]->PreconditionHolds());
75   ops[3]->TryToApply();
76 
77   std::string after_op_3 = R"(
78                OpCapability Shader
79           %1 = OpExtInstImport "GLSL.std.450"
80                OpMemoryModel Logical GLSL450
81                OpEntryPoint Fragment %4 "main"
82                OpExecutionMode %4 OriginUpperLeft
83                OpSource ESSL 310
84                OpName %4 "main"
85                OpName %8 "x"
86           %2 = OpTypeVoid
87           %3 = OpTypeFunction %2
88           %6 = OpTypeInt 32 1
89           %7 = OpTypePointer Function %6
90           %9 = OpConstant %6 1
91          %10 = OpConstant %6 2
92          %11 = OpConstant %6 3
93          %12 = OpConstant %6 4
94           %4 = OpFunction %2 None %3
95           %5 = OpLabel
96           %8 = OpVariable %7 Function
97                OpBranch %13
98          %13 = OpLabel
99                OpStore %8 %9
100                OpBranch %14
101          %14 = OpLabel
102                OpStore %8 %10
103                OpBranch %15
104          %15 = OpLabel
105                OpStore %8 %11
106                OpStore %8 %12
107                OpBranch %17
108          %17 = OpLabel
109                OpReturn
110                OpFunctionEnd
111   )";
112 
113   CheckEqual(env, after_op_3, context.get());
114 
115   ASSERT_TRUE(ops[0]->PreconditionHolds());
116   ops[0]->TryToApply();
117 
118   std::string after_op_0 = R"(
119                OpCapability Shader
120           %1 = OpExtInstImport "GLSL.std.450"
121                OpMemoryModel Logical GLSL450
122                OpEntryPoint Fragment %4 "main"
123                OpExecutionMode %4 OriginUpperLeft
124                OpSource ESSL 310
125                OpName %4 "main"
126                OpName %8 "x"
127           %2 = OpTypeVoid
128           %3 = OpTypeFunction %2
129           %6 = OpTypeInt 32 1
130           %7 = OpTypePointer Function %6
131           %9 = OpConstant %6 1
132          %10 = OpConstant %6 2
133          %11 = OpConstant %6 3
134          %12 = OpConstant %6 4
135           %4 = OpFunction %2 None %3
136           %5 = OpLabel
137           %8 = OpVariable %7 Function
138                OpStore %8 %9
139                OpBranch %14
140          %14 = OpLabel
141                OpStore %8 %10
142                OpBranch %15
143          %15 = OpLabel
144                OpStore %8 %11
145                OpStore %8 %12
146                OpBranch %17
147          %17 = OpLabel
148                OpReturn
149                OpFunctionEnd
150   )";
151 
152   CheckEqual(env, after_op_0, context.get());
153 
154   ASSERT_TRUE(ops[2]->PreconditionHolds());
155   ops[2]->TryToApply();
156 
157   std::string after_op_2 = R"(
158                OpCapability Shader
159           %1 = OpExtInstImport "GLSL.std.450"
160                OpMemoryModel Logical GLSL450
161                OpEntryPoint Fragment %4 "main"
162                OpExecutionMode %4 OriginUpperLeft
163                OpSource ESSL 310
164                OpName %4 "main"
165                OpName %8 "x"
166           %2 = OpTypeVoid
167           %3 = OpTypeFunction %2
168           %6 = OpTypeInt 32 1
169           %7 = OpTypePointer Function %6
170           %9 = OpConstant %6 1
171          %10 = OpConstant %6 2
172          %11 = OpConstant %6 3
173          %12 = OpConstant %6 4
174           %4 = OpFunction %2 None %3
175           %5 = OpLabel
176           %8 = OpVariable %7 Function
177                OpStore %8 %9
178                OpBranch %14
179          %14 = OpLabel
180                OpStore %8 %10
181                OpStore %8 %11
182                OpStore %8 %12
183                OpBranch %17
184          %17 = OpLabel
185                OpReturn
186                OpFunctionEnd
187   )";
188 
189   CheckEqual(env, after_op_2, context.get());
190 
191   ASSERT_TRUE(ops[4]->PreconditionHolds());
192   ops[4]->TryToApply();
193 
194   std::string after_op_4 = R"(
195                OpCapability Shader
196           %1 = OpExtInstImport "GLSL.std.450"
197                OpMemoryModel Logical GLSL450
198                OpEntryPoint Fragment %4 "main"
199                OpExecutionMode %4 OriginUpperLeft
200                OpSource ESSL 310
201                OpName %4 "main"
202                OpName %8 "x"
203           %2 = OpTypeVoid
204           %3 = OpTypeFunction %2
205           %6 = OpTypeInt 32 1
206           %7 = OpTypePointer Function %6
207           %9 = OpConstant %6 1
208          %10 = OpConstant %6 2
209          %11 = OpConstant %6 3
210          %12 = OpConstant %6 4
211           %4 = OpFunction %2 None %3
212           %5 = OpLabel
213           %8 = OpVariable %7 Function
214                OpStore %8 %9
215                OpBranch %14
216          %14 = OpLabel
217                OpStore %8 %10
218                OpStore %8 %11
219                OpStore %8 %12
220                OpReturn
221                OpFunctionEnd
222   )";
223 
224   CheckEqual(env, after_op_4, context.get());
225 
226   ASSERT_TRUE(ops[1]->PreconditionHolds());
227   ops[1]->TryToApply();
228 
229   std::string after_op_1 = R"(
230                OpCapability Shader
231           %1 = OpExtInstImport "GLSL.std.450"
232                OpMemoryModel Logical GLSL450
233                OpEntryPoint Fragment %4 "main"
234                OpExecutionMode %4 OriginUpperLeft
235                OpSource ESSL 310
236                OpName %4 "main"
237                OpName %8 "x"
238           %2 = OpTypeVoid
239           %3 = OpTypeFunction %2
240           %6 = OpTypeInt 32 1
241           %7 = OpTypePointer Function %6
242           %9 = OpConstant %6 1
243          %10 = OpConstant %6 2
244          %11 = OpConstant %6 3
245          %12 = OpConstant %6 4
246           %4 = OpFunction %2 None %3
247           %5 = OpLabel
248           %8 = OpVariable %7 Function
249                OpStore %8 %9
250                OpStore %8 %10
251                OpStore %8 %11
252                OpStore %8 %12
253                OpReturn
254                OpFunctionEnd
255   )";
256 
257   CheckEqual(env, after_op_1, context.get());
258 }
259 
TEST(MergeBlocksReductionPassTest,Loops)260 TEST(MergeBlocksReductionPassTest, Loops) {
261   std::string shader = R"(
262                OpCapability Shader
263           %1 = OpExtInstImport "GLSL.std.450"
264                OpMemoryModel Logical GLSL450
265                OpEntryPoint Fragment %4 "main"
266                OpExecutionMode %4 OriginUpperLeft
267                OpSource ESSL 310
268                OpName %4 "main"
269                OpName %8 "x"
270                OpName %10 "i"
271                OpName %29 "i"
272           %2 = OpTypeVoid
273           %3 = OpTypeFunction %2
274           %6 = OpTypeInt 32 1
275           %7 = OpTypePointer Function %6
276           %9 = OpConstant %6 1
277          %11 = OpConstant %6 0
278          %18 = OpConstant %6 10
279          %19 = OpTypeBool
280           %4 = OpFunction %2 None %3
281           %5 = OpLabel
282           %8 = OpVariable %7 Function
283          %10 = OpVariable %7 Function
284          %29 = OpVariable %7 Function
285                OpStore %8 %9
286                OpBranch %45
287          %45 = OpLabel
288                OpStore %10 %11
289                OpBranch %12
290          %12 = OpLabel
291                OpLoopMerge %14 %15 None
292                OpBranch %16
293          %16 = OpLabel
294          %17 = OpLoad %6 %10
295                OpBranch %46
296          %46 = OpLabel
297          %20 = OpSLessThan %19 %17 %18
298                OpBranchConditional %20 %13 %14
299          %13 = OpLabel
300          %21 = OpLoad %6 %10
301                OpBranch %47
302          %47 = OpLabel
303          %22 = OpLoad %6 %8
304          %23 = OpIAdd %6 %22 %21
305                OpStore %8 %23
306          %24 = OpLoad %6 %10
307          %25 = OpLoad %6 %8
308          %26 = OpIAdd %6 %25 %24
309                OpStore %8 %26
310                OpBranch %48
311          %48 = OpLabel
312                OpBranch %15
313          %15 = OpLabel
314          %27 = OpLoad %6 %10
315          %28 = OpIAdd %6 %27 %9
316                OpStore %10 %28
317                OpBranch %12
318          %14 = OpLabel
319                OpStore %29 %11
320                OpBranch %49
321          %49 = OpLabel
322                OpBranch %30
323          %30 = OpLabel
324                OpLoopMerge %32 %33 None
325                OpBranch %34
326          %34 = OpLabel
327          %35 = OpLoad %6 %29
328          %36 = OpSLessThan %19 %35 %18
329                OpBranch %50
330          %50 = OpLabel
331                OpBranchConditional %36 %31 %32
332          %31 = OpLabel
333          %37 = OpLoad %6 %29
334          %38 = OpLoad %6 %8
335          %39 = OpIAdd %6 %38 %37
336                OpStore %8 %39
337          %40 = OpLoad %6 %29
338          %41 = OpLoad %6 %8
339          %42 = OpIAdd %6 %41 %40
340                OpStore %8 %42
341                OpBranch %33
342          %33 = OpLabel
343          %43 = OpLoad %6 %29
344          %44 = OpIAdd %6 %43 %9
345                OpBranch %51
346          %51 = OpLabel
347                OpStore %29 %44
348                OpBranch %30
349          %32 = OpLabel
350                OpReturn
351                OpFunctionEnd
352   )";
353   const auto env = SPV_ENV_UNIVERSAL_1_3;
354   const auto consumer = nullptr;
355   const auto context =
356       BuildModule(env, consumer, shader, kReduceAssembleOption);
357   const auto ops =
358       MergeBlocksReductionOpportunityFinder().GetAvailableOpportunities(
359           context.get(), 0);
360   ASSERT_EQ(11, ops.size());
361 
362   for (auto& ri : ops) {
363     ASSERT_TRUE(ri->PreconditionHolds());
364     ri->TryToApply();
365   }
366 
367   std::string after = R"(
368                OpCapability Shader
369           %1 = OpExtInstImport "GLSL.std.450"
370                OpMemoryModel Logical GLSL450
371                OpEntryPoint Fragment %4 "main"
372                OpExecutionMode %4 OriginUpperLeft
373                OpSource ESSL 310
374                OpName %4 "main"
375                OpName %8 "x"
376                OpName %10 "i"
377                OpName %29 "i"
378           %2 = OpTypeVoid
379           %3 = OpTypeFunction %2
380           %6 = OpTypeInt 32 1
381           %7 = OpTypePointer Function %6
382           %9 = OpConstant %6 1
383          %11 = OpConstant %6 0
384          %18 = OpConstant %6 10
385          %19 = OpTypeBool
386           %4 = OpFunction %2 None %3
387           %5 = OpLabel
388           %8 = OpVariable %7 Function
389          %10 = OpVariable %7 Function
390          %29 = OpVariable %7 Function
391                OpStore %8 %9
392                OpStore %10 %11
393                OpBranch %12
394          %12 = OpLabel
395          %17 = OpLoad %6 %10
396          %20 = OpSLessThan %19 %17 %18
397                OpLoopMerge %14 %13 None
398                OpBranchConditional %20 %13 %14
399          %13 = OpLabel
400          %21 = OpLoad %6 %10
401          %22 = OpLoad %6 %8
402          %23 = OpIAdd %6 %22 %21
403                OpStore %8 %23
404          %24 = OpLoad %6 %10
405          %25 = OpLoad %6 %8
406          %26 = OpIAdd %6 %25 %24
407                OpStore %8 %26
408          %27 = OpLoad %6 %10
409          %28 = OpIAdd %6 %27 %9
410                OpStore %10 %28
411                OpBranch %12
412          %14 = OpLabel
413                OpStore %29 %11
414                OpBranch %30
415          %30 = OpLabel
416          %35 = OpLoad %6 %29
417          %36 = OpSLessThan %19 %35 %18
418                OpLoopMerge %32 %31 None
419                OpBranchConditional %36 %31 %32
420          %31 = OpLabel
421          %37 = OpLoad %6 %29
422          %38 = OpLoad %6 %8
423          %39 = OpIAdd %6 %38 %37
424                OpStore %8 %39
425          %40 = OpLoad %6 %29
426          %41 = OpLoad %6 %8
427          %42 = OpIAdd %6 %41 %40
428                OpStore %8 %42
429          %43 = OpLoad %6 %29
430          %44 = OpIAdd %6 %43 %9
431                OpStore %29 %44
432                OpBranch %30
433          %32 = OpLabel
434                OpReturn
435                OpFunctionEnd
436   )";
437 
438   CheckEqual(env, after, context.get());
439 }
440 
TEST(MergeBlocksReductionPassTest,MergeWithOpPhi)441 TEST(MergeBlocksReductionPassTest, MergeWithOpPhi) {
442   std::string shader = R"(
443                OpCapability Shader
444           %1 = OpExtInstImport "GLSL.std.450"
445                OpMemoryModel Logical GLSL450
446                OpEntryPoint Fragment %4 "main"
447                OpExecutionMode %4 OriginUpperLeft
448                OpSource ESSL 310
449                OpName %4 "main"
450                OpName %8 "x"
451                OpName %10 "y"
452           %2 = OpTypeVoid
453           %3 = OpTypeFunction %2
454           %6 = OpTypeInt 32 1
455           %7 = OpTypePointer Function %6
456           %9 = OpConstant %6 1
457           %4 = OpFunction %2 None %3
458           %5 = OpLabel
459           %8 = OpVariable %7 Function
460          %10 = OpVariable %7 Function
461                OpStore %8 %9
462          %11 = OpLoad %6 %8
463                OpBranch %12
464          %12 = OpLabel
465          %13 = OpPhi %6 %11 %5
466                OpStore %10 %13
467                OpReturn
468                OpFunctionEnd
469   )";
470 
471   const auto env = SPV_ENV_UNIVERSAL_1_3;
472   const auto consumer = nullptr;
473   const auto context =
474       BuildModule(env, consumer, shader, kReduceAssembleOption);
475   const auto ops =
476       MergeBlocksReductionOpportunityFinder().GetAvailableOpportunities(
477           context.get(), 0);
478   ASSERT_EQ(1, ops.size());
479 
480   ASSERT_TRUE(ops[0]->PreconditionHolds());
481   ops[0]->TryToApply();
482 
483   std::string after = R"(
484                OpCapability Shader
485           %1 = OpExtInstImport "GLSL.std.450"
486                OpMemoryModel Logical GLSL450
487                OpEntryPoint Fragment %4 "main"
488                OpExecutionMode %4 OriginUpperLeft
489                OpSource ESSL 310
490                OpName %4 "main"
491                OpName %8 "x"
492                OpName %10 "y"
493           %2 = OpTypeVoid
494           %3 = OpTypeFunction %2
495           %6 = OpTypeInt 32 1
496           %7 = OpTypePointer Function %6
497           %9 = OpConstant %6 1
498           %4 = OpFunction %2 None %3
499           %5 = OpLabel
500           %8 = OpVariable %7 Function
501          %10 = OpVariable %7 Function
502                OpStore %8 %9
503          %11 = OpLoad %6 %8
504                OpStore %10 %11
505                OpReturn
506                OpFunctionEnd
507   )";
508 
509   CheckEqual(env, after, context.get());
510 }
511 
MergeBlocksReductionPassTest_LoopReturn_Helper(bool reverse)512 void MergeBlocksReductionPassTest_LoopReturn_Helper(bool reverse) {
513   // A merge block opportunity stores a block that can be merged with its
514   // predecessor.
515   // Given blocks A -> B -> C:
516   // This test demonstrates how merging B->C can invalidate
517   // the opportunity of merging A->B, and vice-versa. E.g.
518   // B->C are merged: B is now terminated with OpReturn.
519   // A->B can now no longer be merged because A is a loop header, which
520   // cannot be terminated with OpReturn.
521 
522   std::string shader = R"(
523                OpCapability Shader
524           %1 = OpExtInstImport "GLSL.std.450"
525                OpMemoryModel Logical GLSL450
526                OpEntryPoint Fragment %2 "main"
527                OpExecutionMode %2 OriginUpperLeft
528                OpSource ESSL 310
529                OpName %2 "main"
530           %3 = OpTypeVoid
531           %4 = OpTypeFunction %3
532           %5 = OpTypeInt 32 1
533           %6 = OpTypePointer Function %5
534           %7 = OpTypeBool
535           %8 = OpConstantFalse %7
536           %2 = OpFunction %3 None %4
537           %9 = OpLabel
538                OpBranch %10
539          %10 = OpLabel                   ; A (loop header)
540                OpLoopMerge %13 %12 None
541                OpBranch %11
542          %12 = OpLabel                   ; (unreachable continue block)
543                OpBranch %10
544          %11 = OpLabel                   ; B
545                OpBranch %15
546          %15 = OpLabel                   ; C
547                OpReturn
548          %13 = OpLabel                   ; (unreachable merge block)
549                OpReturn
550                OpFunctionEnd
551   )";
552   const auto env = SPV_ENV_UNIVERSAL_1_3;
553   const auto consumer = nullptr;
554   const auto context =
555       BuildModule(env, consumer, shader, kReduceAssembleOption);
556   ASSERT_NE(context.get(), nullptr);
557   auto opportunities =
558       MergeBlocksReductionOpportunityFinder().GetAvailableOpportunities(
559           context.get(), 0);
560 
561   // A->B and B->C
562   ASSERT_EQ(opportunities.size(), 2);
563 
564   // Test applying opportunities in both orders.
565   if (reverse) {
566     std::reverse(opportunities.begin(), opportunities.end());
567   }
568 
569   size_t num_applied = 0;
570   for (auto& ri : opportunities) {
571     if (ri->PreconditionHolds()) {
572       ri->TryToApply();
573       ++num_applied;
574     }
575   }
576 
577   // Only 1 opportunity can be applied, as both disable each other.
578   ASSERT_EQ(num_applied, 1);
579 
580   std::string after = R"(
581                OpCapability Shader
582           %1 = OpExtInstImport "GLSL.std.450"
583                OpMemoryModel Logical GLSL450
584                OpEntryPoint Fragment %2 "main"
585                OpExecutionMode %2 OriginUpperLeft
586                OpSource ESSL 310
587                OpName %2 "main"
588           %3 = OpTypeVoid
589           %4 = OpTypeFunction %3
590           %5 = OpTypeInt 32 1
591           %6 = OpTypePointer Function %5
592           %7 = OpTypeBool
593           %8 = OpConstantFalse %7
594           %2 = OpFunction %3 None %4
595           %9 = OpLabel
596                OpBranch %10
597          %10 = OpLabel                   ; A-B (loop header)
598                OpLoopMerge %13 %12 None
599                OpBranch %15
600          %12 = OpLabel                   ; (unreachable continue block)
601                OpBranch %10
602          %15 = OpLabel                   ; C
603                OpReturn
604          %13 = OpLabel                   ; (unreachable merge block)
605                OpReturn
606                OpFunctionEnd
607   )";
608 
609   // The only difference is the labels.
610   std::string after_reversed = R"(
611                OpCapability Shader
612           %1 = OpExtInstImport "GLSL.std.450"
613                OpMemoryModel Logical GLSL450
614                OpEntryPoint Fragment %2 "main"
615                OpExecutionMode %2 OriginUpperLeft
616                OpSource ESSL 310
617                OpName %2 "main"
618           %3 = OpTypeVoid
619           %4 = OpTypeFunction %3
620           %5 = OpTypeInt 32 1
621           %6 = OpTypePointer Function %5
622           %7 = OpTypeBool
623           %8 = OpConstantFalse %7
624           %2 = OpFunction %3 None %4
625           %9 = OpLabel
626                OpBranch %10
627          %10 = OpLabel                   ; A (loop header)
628                OpLoopMerge %13 %12 None
629                OpBranch %11
630          %12 = OpLabel                   ; (unreachable continue block)
631                OpBranch %10
632          %11 = OpLabel                   ; B-C
633                OpReturn
634          %13 = OpLabel                   ; (unreachable merge block)
635                OpReturn
636                OpFunctionEnd
637   )";
638 
639   CheckEqual(env, reverse ? after_reversed : after, context.get());
640 }
641 
TEST(MergeBlocksReductionPassTest,LoopReturn)642 TEST(MergeBlocksReductionPassTest, LoopReturn) {
643   MergeBlocksReductionPassTest_LoopReturn_Helper(false);
644 }
645 
TEST(MergeBlocksReductionPassTest,LoopReturnReverse)646 TEST(MergeBlocksReductionPassTest, LoopReturnReverse) {
647   MergeBlocksReductionPassTest_LoopReturn_Helper(true);
648 }
649 
650 }  // namespace
651 }  // namespace reduce
652 }  // namespace spvtools
653