1 // Copyright (c) 2019 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #include "source/reduce/remove_function_reduction_opportunity_finder.h"
16 
17 #include "source/opt/build_module.h"
18 #include "source/reduce/reduction_opportunity.h"
19 #include "test/reduce/reduce_test_util.h"
20 
21 namespace spvtools {
22 namespace reduce {
23 namespace {
24 
25 // Helper to count the number of functions in the module.
26 // Remove if there turns out to be a more direct way to do this.
count_functions(opt::IRContext * context)27 uint32_t count_functions(opt::IRContext* context) {
28   uint32_t result = 0;
29   for (auto& function : *context->module()) {
30     (void)(function);
31     ++result;
32   }
33   return result;
34 }
35 
// The module has three functions:
// - the entry point %4;
// - %6, which is called only from %8;
// - %8, which is never called.
// The first query yields exactly one opportunity, and applying it leaves %4
// and %6. A second query then yields one opportunity, and applying it leaves
// only the entry point.
TEST(RemoveFunctionTest, BasicCheck) {
  const std::string shader = R"(
               OpCapability Shader
          %1 = OpExtInstImport "GLSL.std.450"
               OpMemoryModel Logical GLSL450
               OpEntryPoint Fragment %4 "main"
               OpExecutionMode %4 OriginUpperLeft
               OpSource ESSL 310
          %2 = OpTypeVoid
          %3 = OpTypeFunction %2
          %4 = OpFunction %2 None %3
          %5 = OpLabel
               OpReturn
               OpFunctionEnd
          %6 = OpFunction %2 None %3
          %7 = OpLabel
               OpReturn
               OpFunctionEnd
          %8 = OpFunction %2 None %3
          %9 = OpLabel
         %10 = OpFunctionCall %2 %6
               OpReturn
               OpFunctionEnd
  )";

  const auto target_env = SPV_ENV_UNIVERSAL_1_3;
  const auto ir_context =
      BuildModule(target_env, nullptr, shader, kReduceAssembleOption);

  ASSERT_EQ(3, count_functions(ir_context.get()));

  RemoveFunctionReductionOpportunityFinder finder;

  // First round: a single opportunity, which removes %8.
  auto first_round = finder.GetAvailableOpportunities(ir_context.get(), 0);
  ASSERT_EQ(1, first_round.size());

  ASSERT_TRUE(first_round[0]->PreconditionHolds());
  first_round[0]->TryToApply();

  ASSERT_EQ(2, count_functions(ir_context.get()));

  const std::string expected_after_first_round = R"(
               OpCapability Shader
          %1 = OpExtInstImport "GLSL.std.450"
               OpMemoryModel Logical GLSL450
               OpEntryPoint Fragment %4 "main"
               OpExecutionMode %4 OriginUpperLeft
               OpSource ESSL 310
          %2 = OpTypeVoid
          %3 = OpTypeFunction %2
          %4 = OpFunction %2 None %3
          %5 = OpLabel
               OpReturn
               OpFunctionEnd
          %6 = OpFunction %2 None %3
          %7 = OpLabel
               OpReturn
               OpFunctionEnd
  )";

  CheckEqual(target_env, expected_after_first_round, ir_context.get());

  // Second round: with its only caller gone, %6 is now the one opportunity.
  auto second_round = finder.GetAvailableOpportunities(ir_context.get(), 0);

  ASSERT_EQ(1, second_round.size());

  ASSERT_TRUE(second_round[0]->PreconditionHolds());
  second_round[0]->TryToApply();

  ASSERT_EQ(1, count_functions(ir_context.get()));

  const std::string expected_after_second_round = R"(
               OpCapability Shader
          %1 = OpExtInstImport "GLSL.std.450"
               OpMemoryModel Logical GLSL450
               OpEntryPoint Fragment %4 "main"
               OpExecutionMode %4 OriginUpperLeft
               OpSource ESSL 310
          %2 = OpTypeVoid
          %3 = OpTypeFunction %2
          %4 = OpFunction %2 None %3
          %5 = OpLabel
               OpReturn
               OpFunctionEnd
  )";

  CheckEqual(target_env, expected_after_second_round, ir_context.get());
}
126 
// The entry point %4 calls %8, and %8 calls %6. Every function is therefore
// either the entry point or has a caller, so the finder must report nothing.
TEST(RemoveFunctionTest, NothingToRemove) {
  const std::string shader = R"(
               OpCapability Shader
          %1 = OpExtInstImport "GLSL.std.450"
               OpMemoryModel Logical GLSL450
               OpEntryPoint Fragment %4 "main"
               OpExecutionMode %4 OriginUpperLeft
               OpSource ESSL 310
          %2 = OpTypeVoid
          %3 = OpTypeFunction %2
          %4 = OpFunction %2 None %3
          %5 = OpLabel
         %11 = OpFunctionCall %2 %8
               OpReturn
               OpFunctionEnd
          %6 = OpFunction %2 None %3
          %7 = OpLabel
               OpReturn
               OpFunctionEnd
          %8 = OpFunction %2 None %3
          %9 = OpLabel
         %10 = OpFunctionCall %2 %6
               OpReturn
               OpFunctionEnd
  )";

  const auto target_env = SPV_ENV_UNIVERSAL_1_3;
  const auto ir_context =
      BuildModule(target_env, nullptr, shader, kReduceAssembleOption);

  RemoveFunctionReductionOpportunityFinder finder;
  const auto opportunities =
      finder.GetAvailableOpportunities(ir_context.get(), 0);
  ASSERT_TRUE(opportunities.empty());
}
162 
// Neither %6 nor %8 is called, so both are reported in a single query. The
// second opportunity's precondition must still hold once the first has been
// applied, and applying both leaves only the entry point %4.
TEST(RemoveFunctionTest, TwoRemovableFunctions) {
  const std::string shader = R"(
               OpCapability Shader
          %1 = OpExtInstImport "GLSL.std.450"
               OpMemoryModel Logical GLSL450
               OpEntryPoint Fragment %4 "main"
               OpExecutionMode %4 OriginUpperLeft
               OpSource ESSL 310
          %2 = OpTypeVoid
          %3 = OpTypeFunction %2
          %4 = OpFunction %2 None %3
          %5 = OpLabel
               OpReturn
               OpFunctionEnd
          %6 = OpFunction %2 None %3
          %7 = OpLabel
               OpReturn
               OpFunctionEnd
          %8 = OpFunction %2 None %3
          %9 = OpLabel
               OpReturn
               OpFunctionEnd
  )";

  const auto target_env = SPV_ENV_UNIVERSAL_1_3;
  const auto ir_context =
      BuildModule(target_env, nullptr, shader, kReduceAssembleOption);

  ASSERT_EQ(3, count_functions(ir_context.get()));

  RemoveFunctionReductionOpportunityFinder finder;
  auto opportunities = finder.GetAvailableOpportunities(ir_context.get(), 0);
  ASSERT_EQ(2, opportunities.size());

  // Apply the opportunities one at a time; each should remove one function.
  ASSERT_TRUE(opportunities[0]->PreconditionHolds());
  opportunities[0]->TryToApply();
  ASSERT_EQ(2, count_functions(ir_context.get()));

  ASSERT_TRUE(opportunities[1]->PreconditionHolds());
  opportunities[1]->TryToApply();
  ASSERT_EQ(1, count_functions(ir_context.get()));

  const std::string expected = R"(
               OpCapability Shader
          %1 = OpExtInstImport "GLSL.std.450"
               OpMemoryModel Logical GLSL450
               OpEntryPoint Fragment %4 "main"
               OpExecutionMode %4 OriginUpperLeft
               OpSource ESSL 310
          %2 = OpTypeVoid
          %3 = OpTypeFunction %2
          %4 = OpFunction %2 None %3
          %5 = OpLabel
               OpReturn
               OpFunctionEnd
  )";

  CheckEqual(target_env, expected, ir_context.get());
}
223 
// %6 and %8 are never called, but each is the target of an OpName. This test
// expects those debug-name references to stop the finder from offering
// either function for removal.
TEST(RemoveFunctionTest, NoRemovalsDueToOpName) {
  const std::string shader = R"(
               OpCapability Shader
          %1 = OpExtInstImport "GLSL.std.450"
               OpMemoryModel Logical GLSL450
               OpEntryPoint Fragment %4 "main"
               OpExecutionMode %4 OriginUpperLeft
               OpSource ESSL 310
               OpName %4 "main"
               OpName %6 "foo("
               OpName %8 "bar("
          %2 = OpTypeVoid
          %3 = OpTypeFunction %2
          %4 = OpFunction %2 None %3
          %5 = OpLabel
               OpReturn
               OpFunctionEnd
          %6 = OpFunction %2 None %3
          %7 = OpLabel
               OpReturn
               OpFunctionEnd
          %8 = OpFunction %2 None %3
          %9 = OpLabel
               OpReturn
               OpFunctionEnd
  )";

  const auto target_env = SPV_ENV_UNIVERSAL_1_3;
  const auto ir_context =
      BuildModule(target_env, nullptr, shader, kReduceAssembleOption);

  RemoveFunctionReductionOpportunityFinder finder;
  const auto opportunities =
      finder.GetAvailableOpportunities(ir_context.get(), 0);
  ASSERT_TRUE(opportunities.empty());
}
260 
TEST(RemoveFunctionTest, NoRemovalDueToLinkageDecoration) {
  // %2 is the only non-entry-point function. It must be kept because it is
  // the target of a LinkageAttributes decoration, so the finder should
  // report no removal opportunities at all.
  const std::string shader = R"(
               OpCapability Shader
               OpCapability Linkage
               OpMemoryModel Logical GLSL450
               OpEntryPoint Fragment %1 "main"
               OpName %1 "main"
               OpDecorate %2 LinkageAttributes "ExportedFunc" Export
          %4 = OpTypeVoid
          %5 = OpTypeFunction %4
          %1 = OpFunction %4 None %5
          %6 = OpLabel
               OpReturn
               OpFunctionEnd
          %2 = OpFunction %4 None %5
          %7 = OpLabel
               OpReturn
               OpFunctionEnd
  )";

  const auto target_env = SPV_ENV_UNIVERSAL_1_3;
  const auto ir_context =
      BuildModule(target_env, nullptr, shader, kReduceAssembleOption);

  RemoveFunctionReductionOpportunityFinder finder;
  const auto opportunities =
      finder.GetAvailableOpportunities(ir_context.get(), 0);
  ASSERT_TRUE(opportunities.empty());
}
292 
293 }  // namespace
294 }  // namespace reduce
295 }  // namespace spvtools
296