1 // Copyright (c) 2018 Google Inc.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #include <algorithm>
16 #include <memory>
17 #include <string>
18 #include <vector>
19 
20 #include "effcee/effcee.h"
21 #include "gmock/gmock.h"
22 #include "gtest/gtest.h"
23 #include "source/opt/basic_block.h"
24 #include "source/opt/build_module.h"
25 #include "source/opt/instruction.h"
26 #include "source/opt/ir_builder.h"
27 #include "source/opt/type_manager.h"
28 #include "spirv-tools/libspirv.hpp"
29 
30 namespace spvtools {
31 namespace opt {
32 namespace {
33 
34 using Analysis = IRContext::Analysis;
35 using IRBuilderTest = ::testing::Test;
36 
Validate(const std::vector<uint32_t> & bin)37 bool Validate(const std::vector<uint32_t>& bin) {
38   spv_target_env target_env = SPV_ENV_UNIVERSAL_1_2;
39   spv_context spvContext = spvContextCreate(target_env);
40   spv_diagnostic diagnostic = nullptr;
41   spv_const_binary_t binary = {bin.data(), bin.size()};
42   spv_result_t error = spvValidate(spvContext, &binary, &diagnostic);
43   if (error != 0) spvDiagnosticPrint(diagnostic);
44   spvDiagnosticDestroy(diagnostic);
45   spvContextDestroy(spvContext);
46   return error == 0;
47 }
48 
Match(const std::string & original,IRContext * context,bool do_validation=true)49 void Match(const std::string& original, IRContext* context,
50            bool do_validation = true) {
51   std::vector<uint32_t> bin;
52   context->module()->ToBinary(&bin, true);
53   if (do_validation) {
54     EXPECT_TRUE(Validate(bin));
55   }
56   std::string assembly;
57   SpirvTools tools(SPV_ENV_UNIVERSAL_1_2);
58   EXPECT_TRUE(
59       tools.Disassemble(bin, &assembly, SpirvTools::kDefaultDisassembleOption))
60       << "Disassembling failed for shader:\n"
61       << assembly << std::endl;
62   auto match_result = effcee::Match(assembly, original);
63   EXPECT_EQ(effcee::Result::Status::Ok, match_result.status())
64       << match_result.message() << "\nChecking result:\n"
65       << assembly;
66 }
67 
TEST_F(IRBuilderTest, TestInsnAddition) {
  const std::string text = R"(
; CHECK: %18 = OpLabel
; CHECK: OpPhi %int %int_0 %14
; CHECK: OpPhi %bool %16 %14
; CHECK: OpBranch %17
               OpCapability Shader
          %1 = OpExtInstImport "GLSL.std.450"
               OpMemoryModel Logical GLSL450
               OpEntryPoint Fragment %2 "main" %3
               OpExecutionMode %2 OriginUpperLeft
               OpSource GLSL 330
               OpName %2 "main"
               OpName %4 "i"
               OpName %3 "c"
               OpDecorate %3 Location 0
          %5 = OpTypeVoid
          %6 = OpTypeFunction %5
          %7 = OpTypeInt 32 1
          %8 = OpTypePointer Function %7
          %9 = OpConstant %7 0
         %10 = OpTypeBool
         %11 = OpTypeFloat 32
         %12 = OpTypeVector %11 4
         %13 = OpTypePointer Output %12
          %3 = OpVariable %13 Output
          %2 = OpFunction %5 None %6
         %14 = OpLabel
          %4 = OpVariable %8 Function
               OpStore %4 %9
         %15 = OpLoad %7 %4
         %16 = OpINotEqual %10 %15 %9
               OpSelectionMerge %17 None
               OpBranchConditional %16 %18 %17
         %18 = OpLabel
               OpBranch %17
         %17 = OpLabel
               OpReturn
               OpFunctionEnd
)";

  // Builds a fresh module from |text| and inserts two OpPhi instructions in
  // front of the first non-label instruction of block %18.  It then checks
  // whether the def/use manager and the instruction-to-block mapping know
  // about the new phis.
  //   |preserved|      - analyses the InstructionBuilder is asked to keep
  //                      up to date while inserting.
  //   |expect_tracked| - whether both analyses should see the new phis.
  auto insert_phis_and_check = [&text](Analysis preserved,
                                       bool expect_tracked) {
    std::unique_ptr<IRContext> ctx =
        BuildModule(SPV_ENV_UNIVERSAL_1_2, nullptr, text);

    BasicBlock* target = ctx->cfg()->block(18);

    // Force both analyses to exist before anything is inserted, so the
    // checks below observe whether the builder updated them.
    ctx->get_def_use_mgr();
    ctx->get_instr_block(nullptr);

    InstructionBuilder builder(ctx.get(), &*target->begin(), preserved);
    Instruction* int_phi = builder.AddPhi(7, {9, 14});
    Instruction* bool_phi = builder.AddPhi(10, {16, 14});

    for (Instruction* phi : {int_phi, bool_phi}) {
      EXPECT_EQ(expect_tracked,
                ctx->get_def_use_mgr()->GetDef(phi->result_id()) != nullptr);
      EXPECT_EQ(expect_tracked, ctx->get_instr_block(phi) != nullptr);
    }

    Match(text, ctx.get());
  };

  // With no preserved analyses the builder must leave both analyses as-is.
  insert_phis_and_check(IRContext::kAnalysisNone,
                        /* expect_tracked = */ false);

  // When asked to preserve them, the builder must register the new phis.
  insert_phis_and_check(
      IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping,
      /* expect_tracked = */ true);
}
156 
// Prepends two blocks to a single-block function and fills them through the
// InstructionBuilder: a "true" block that branches to the original block, and
// a header block ending in OpSelectionMerge + OpBranchConditional.
TEST_F(IRBuilderTest, TestCondBranchAddition) {
  const std::string text = R"(
; CHECK: %main = OpFunction %void None %6
; CHECK-NEXT: %15 = OpLabel
; CHECK-NEXT: OpSelectionMerge %13 None
; CHECK-NEXT: OpBranchConditional %true %14 %13
; CHECK-NEXT: %14 = OpLabel
; CHECK-NEXT: OpBranch %13
; CHECK-NEXT: %13 = OpLabel
; CHECK-NEXT: OpReturn
               OpCapability Shader
          %1 = OpExtInstImport "GLSL.std.450"
               OpMemoryModel Logical GLSL450
               OpEntryPoint Fragment %2 "main" %3
               OpExecutionMode %2 OriginUpperLeft
               OpSource GLSL 330
               OpName %2 "main"
               OpName %4 "i"
               OpName %3 "c"
               OpDecorate %3 Location 0
          %5 = OpTypeVoid
          %6 = OpTypeFunction %5
          %7 = OpTypeBool
          %8 = OpTypePointer Private %7
          %9 = OpConstantTrue %7
         %10 = OpTypeFloat 32
         %11 = OpTypeVector %10 4
         %12 = OpTypePointer Output %11
          %3 = OpVariable %12 Output
          %4 = OpVariable %8 Private
          %2 = OpFunction %5 None %6
         %13 = OpLabel
               OpReturn
               OpFunctionEnd
)";

  std::unique_ptr<IRContext> context =
      BuildModule(SPV_ENV_UNIVERSAL_1_2, nullptr, text);
  // Everything below dereferences |context|; stop if parsing failed.
  ASSERT_NE(nullptr, context);

  Function& fn = *context->module()->begin();

  BasicBlock& bb_merge = *fn.begin();

  // TODO(1841): Handle id overflow.
  fn.begin().InsertBefore(std::unique_ptr<BasicBlock>(
      new BasicBlock(std::unique_ptr<Instruction>(new Instruction(
          context.get(), SpvOpLabel, 0, context->TakeNextId(), {})))));
  BasicBlock& bb_true = *fn.begin();
  {
    // bb_true only holds its label, so bb_true.begin() == bb_true.end().
    // Use the block-based constructor, which appends at the end of the block,
    // instead of dereferencing that iterator (same form as for bb_cond below).
    InstructionBuilder builder(context.get(), &bb_true);
    builder.AddBranch(bb_merge.id());
  }

  // TODO(1841): Handle id overflow.
  fn.begin().InsertBefore(std::unique_ptr<BasicBlock>(
      new BasicBlock(std::unique_ptr<Instruction>(new Instruction(
          context.get(), SpvOpLabel, 0, context->TakeNextId(), {})))));
  BasicBlock& bb_cond = *fn.begin();

  InstructionBuilder builder(context.get(), &bb_cond);
  // This also tests consecutive instruction insertion: merge selection +
  // branch.
  builder.AddConditionalBranch(9, bb_true.id(), bb_merge.id(), bb_merge.id());

  Match(text, context.get());
}
225 
// Checks that InstructionBuilder::AddSelect emits an OpSelect with the given
// result type, condition and operands.
TEST_F(IRBuilderTest, AddSelect) {
  const std::string text = R"(
; CHECK: [[bool:%\w+]] = OpTypeBool
; CHECK: [[uint:%\w+]] = OpTypeInt 32 0
; CHECK: [[true:%\w+]] = OpConstantTrue [[bool]]
; CHECK: [[u0:%\w+]] = OpConstant [[uint]] 0
; CHECK: [[u1:%\w+]] = OpConstant [[uint]] 1
; CHECK: OpSelect [[uint]] [[true]] [[u0]] [[u1]]
OpCapability Kernel
OpCapability Linkage
OpMemoryModel Logical OpenCL
%1 = OpTypeVoid
%2 = OpTypeBool
%3 = OpTypeInt 32 0
%4 = OpConstantTrue %2
%5 = OpConstant %3 0
%6 = OpConstant %3 1
%7 = OpTypeFunction %1
%8 = OpFunction %1 None %7
%9 = OpLabel
OpReturn
OpFunctionEnd
)";

  std::unique_ptr<IRContext> context =
      BuildModule(SPV_ENV_UNIVERSAL_1_2, nullptr, text);
  // The builder below dereferences |context|, so a parse failure must stop
  // the test here instead of crashing it.
  ASSERT_NE(nullptr, context);

  InstructionBuilder builder(context.get(),
                             &*context->module()->begin()->begin()->begin());
  EXPECT_NE(nullptr, builder.AddSelect(3u, 4u, 5u, 6u));

  Match(text, context.get());
}
260 
// Checks that InstructionBuilder::AddCompositeConstruct emits an
// OpCompositeConstruct whose constituents keep the given order.
TEST_F(IRBuilderTest, AddCompositeConstruct) {
  const std::string text = R"(
; CHECK: [[uint:%\w+]] = OpTypeInt
; CHECK: [[u0:%\w+]] = OpConstant [[uint]] 0
; CHECK: [[u1:%\w+]] = OpConstant [[uint]] 1
; CHECK: [[struct:%\w+]] = OpTypeStruct [[uint]] [[uint]] [[uint]] [[uint]]
; CHECK: OpCompositeConstruct [[struct]] [[u0]] [[u1]] [[u1]] [[u0]]
OpCapability Kernel
OpCapability Linkage
OpMemoryModel Logical OpenCL
%1 = OpTypeVoid
%2 = OpTypeInt 32 0
%3 = OpConstant %2 0
%4 = OpConstant %2 1
%5 = OpTypeStruct %2 %2 %2 %2
%6 = OpTypeFunction %1
%7 = OpFunction %1 None %6
%8 = OpLabel
OpReturn
OpFunctionEnd
)";

  std::unique_ptr<IRContext> context =
      BuildModule(SPV_ENV_UNIVERSAL_1_2, nullptr, text);
  // The builder below dereferences |context|, so a parse failure must stop
  // the test here instead of crashing it.
  ASSERT_NE(nullptr, context);

  InstructionBuilder builder(context.get(),
                             &*context->module()->begin()->begin()->begin());
  std::vector<uint32_t> ids = {3u, 4u, 4u, 3u};
  EXPECT_NE(nullptr, builder.AddCompositeConstruct(5u, ids));

  Match(text, context.get());
}
294 
// Checks that GetUintConstant/GetSintConstant create the needed integer types
// and constants on demand, and return the existing constant on repeated
// requests.
TEST_F(IRBuilderTest, ConstantAdder) {
  const std::string text = R"(
; CHECK: [[uint:%\w+]] = OpTypeInt 32 0
; CHECK: OpConstant [[uint]] 13
; CHECK: [[sint:%\w+]] = OpTypeInt 32 1
; CHECK: OpConstant [[sint]] -1
; CHECK: OpConstant [[uint]] 1
; CHECK: OpConstant [[sint]] 34
; CHECK: OpConstant [[uint]] 0
; CHECK: OpConstant [[sint]] 0
OpCapability Shader
OpCapability Linkage
OpMemoryModel Logical GLSL450
%1 = OpTypeVoid
%2 = OpTypeFunction %1
%3 = OpFunction %1 None %2
%4 = OpLabel
OpReturn
OpFunctionEnd
)";

  std::unique_ptr<IRContext> context =
      BuildModule(SPV_ENV_UNIVERSAL_1_2, nullptr, text);
  // The builder below dereferences |context|, so a parse failure must stop
  // the test here instead of crashing it.
  ASSERT_NE(nullptr, context);

  InstructionBuilder builder(context.get(),
                             &*context->module()->begin()->begin()->begin());
  Instruction* uint_13 = builder.GetUintConstant(13);
  Instruction* sint_minus_1 = builder.GetSintConstant(-1);
  EXPECT_NE(nullptr, uint_13);
  EXPECT_NE(nullptr, sint_minus_1);

  // Requesting the same constants again must hand back the existing
  // instructions rather than add new ones.  The CHECK lines alone cannot
  // detect a duplicate, because CHECK is allowed to skip extra lines.
  EXPECT_EQ(uint_13, builder.GetUintConstant(13));
  EXPECT_EQ(sint_minus_1, builder.GetSintConstant(-1));

  // Try adding different constants to make sure the type is reused.
  EXPECT_NE(nullptr, builder.GetUintConstant(1));
  EXPECT_NE(nullptr, builder.GetSintConstant(34));

  // Try adding 0 as both signed and unsigned: same value, different types,
  // so two distinct instructions are expected.
  Instruction* uint_0 = builder.GetUintConstant(0);
  Instruction* sint_0 = builder.GetSintConstant(0);
  EXPECT_NE(nullptr, uint_0);
  EXPECT_NE(nullptr, sint_0);
  EXPECT_NE(uint_0, sint_0);

  Match(text, context.get());
}
339 
// Same as ConstantAdder, but the 32-bit signed and unsigned integer types are
// already declared in the module.  The new constants must reuse them.
TEST_F(IRBuilderTest, ConstantAdderTypeAlreadyExists) {
  const std::string text = R"(
; CHECK: OpConstant %uint 13
; CHECK: OpConstant %int -1
; CHECK: OpConstant %uint 1
; CHECK: OpConstant %int 34
; CHECK: OpConstant %uint 0
; CHECK: OpConstant %int 0
OpCapability Shader
OpCapability Linkage
OpMemoryModel Logical GLSL450
%1 = OpTypeVoid
%uint = OpTypeInt 32 0
%int = OpTypeInt 32 1
%4 = OpTypeFunction %1
%5 = OpFunction %1 None %4
%6 = OpLabel
OpReturn
OpFunctionEnd
)";

  std::unique_ptr<IRContext> context =
      BuildModule(SPV_ENV_UNIVERSAL_1_2, nullptr, text);
  // The builder below dereferences |context|, so a parse failure must stop
  // the test here instead of crashing it.
  ASSERT_NE(nullptr, context);

  InstructionBuilder builder(context.get(),
                             &*context->module()->begin()->begin()->begin());
  Instruction* const_1 = builder.GetUintConstant(13);
  Instruction* const_2 = builder.GetSintConstant(-1);

  // Every constant is dereferenced further down (GetSingleWordOperand), so
  // null results are fatal rather than merely reported.
  ASSERT_NE(nullptr, const_1);
  ASSERT_NE(nullptr, const_2);

  // Try adding the same constants again to make sure they aren't added.
  EXPECT_EQ(const_1, builder.GetUintConstant(13));
  EXPECT_EQ(const_2, builder.GetSintConstant(-1));

  Instruction* const_3 = builder.GetUintConstant(1);
  Instruction* const_4 = builder.GetSintConstant(34);

  // Try adding different constants to make sure the type is reused.
  ASSERT_NE(nullptr, const_3);
  ASSERT_NE(nullptr, const_4);

  Instruction* const_5 = builder.GetUintConstant(0);
  Instruction* const_6 = builder.GetSintConstant(0);

  // Try adding 0 as both signed and unsigned.
  ASSERT_NE(nullptr, const_5);
  ASSERT_NE(nullptr, const_6);

  // They have the same value but different types so should be unique.
  EXPECT_NE(const_5, const_6);

  // Check the types are correct: operand 0 of an OpConstant is its result
  // type id.
  uint32_t type_id_unsigned = const_1->GetSingleWordOperand(0);
  uint32_t type_id_signed = const_2->GetSingleWordOperand(0);

  EXPECT_NE(type_id_unsigned, type_id_signed);

  EXPECT_EQ(const_3->GetSingleWordOperand(0), type_id_unsigned);
  EXPECT_EQ(const_5->GetSingleWordOperand(0), type_id_unsigned);

  EXPECT_EQ(const_4->GetSingleWordOperand(0), type_id_signed);
  EXPECT_EQ(const_6->GetSingleWordOperand(0), type_id_signed);

  Match(text, context.get());
}
408 
// Round-trips a module declaring OpTypeAccelerationStructureNV through
// BuildModule, InstructionBuilder construction, validation and disassembly.
TEST_F(IRBuilderTest, AccelerationStructureNV) {
  const std::string text = R"(
; CHECK: OpTypeAccelerationStructureNV
OpCapability Shader
OpCapability RayTracingNV
OpExtension "SPV_NV_ray_tracing"
OpMemoryModel Logical GLSL450
OpEntryPoint Fragment %8 "main"
OpExecutionMode %8 OriginUpperLeft
%1 = OpTypeVoid
%2 = OpTypeBool
%3 = OpTypeAccelerationStructureNV
%7 = OpTypeFunction %1
%8 = OpFunction %1 None %7
%9 = OpLabel
OpReturn
OpFunctionEnd
)";

  std::unique_ptr<IRContext> context =
      BuildModule(SPV_ENV_UNIVERSAL_1_2, nullptr, text);
  // The builder below dereferences |context|, so a parse failure must stop
  // the test here instead of crashing it.
  ASSERT_NE(nullptr, context);

  // NOTE(review): |builder| is never used after construction; presumably
  // constructing it on this module is part of what is being exercised —
  // confirm before removing.
  InstructionBuilder builder(context.get(),
                             &*context->module()->begin()->begin()->begin());
  Match(text, context.get());
}
436 
437 }  // namespace
438 }  // namespace opt
439 }  // namespace spvtools
440