1 // Copyright (c) 2018 Google Inc.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 #include <algorithm>
16 #include <memory>
17 #include <string>
18 #include <vector>
19
20 #include "effcee/effcee.h"
21 #include "gmock/gmock.h"
22 #include "gtest/gtest.h"
23 #include "source/opt/basic_block.h"
24 #include "source/opt/build_module.h"
25 #include "source/opt/instruction.h"
26 #include "source/opt/ir_builder.h"
27 #include "source/opt/type_manager.h"
28 #include "spirv-tools/libspirv.hpp"
29
30 namespace spvtools {
31 namespace opt {
32 namespace {
33
// Shorthand for the IRContext analysis enumeration.
// NOTE(review): not referenced by any test in this file; consider removing it
// if nothing else depends on it.
using Analysis = IRContext::Analysis;
// Fixture type for the TEST_F cases below; a plain ::testing::Test, so it adds
// no setup or teardown of its own.
using IRBuilderTest = ::testing::Test;
36
Validate(const std::vector<uint32_t> & bin)37 bool Validate(const std::vector<uint32_t>& bin) {
38 spv_target_env target_env = SPV_ENV_UNIVERSAL_1_2;
39 spv_context spvContext = spvContextCreate(target_env);
40 spv_diagnostic diagnostic = nullptr;
41 spv_const_binary_t binary = {bin.data(), bin.size()};
42 spv_result_t error = spvValidate(spvContext, &binary, &diagnostic);
43 if (error != 0) spvDiagnosticPrint(diagnostic);
44 spvDiagnosticDestroy(diagnostic);
45 spvContextDestroy(spvContext);
46 return error == 0;
47 }
48
Match(const std::string & original,IRContext * context,bool do_validation=true)49 void Match(const std::string& original, IRContext* context,
50 bool do_validation = true) {
51 std::vector<uint32_t> bin;
52 context->module()->ToBinary(&bin, true);
53 if (do_validation) {
54 EXPECT_TRUE(Validate(bin));
55 }
56 std::string assembly;
57 SpirvTools tools(SPV_ENV_UNIVERSAL_1_2);
58 EXPECT_TRUE(
59 tools.Disassemble(bin, &assembly, SpirvTools::kDefaultDisassembleOption))
60 << "Disassembling failed for shader:\n"
61 << assembly << std::endl;
62 auto match_result = effcee::Match(assembly, original);
63 EXPECT_EQ(effcee::Result::Status::Ok, match_result.status())
64 << match_result.message() << "\nChecking result:\n"
65 << assembly;
66 }
67
TEST_F(IRBuilderTest,TestInsnAddition)68 TEST_F(IRBuilderTest, TestInsnAddition) {
69 const std::string text = R"(
70 ; CHECK: %18 = OpLabel
71 ; CHECK: OpPhi %int %int_0 %14
72 ; CHECK: OpPhi %bool %16 %14
73 ; CHECK: OpBranch %17
74 OpCapability Shader
75 %1 = OpExtInstImport "GLSL.std.450"
76 OpMemoryModel Logical GLSL450
77 OpEntryPoint Fragment %2 "main" %3
78 OpExecutionMode %2 OriginUpperLeft
79 OpSource GLSL 330
80 OpName %2 "main"
81 OpName %4 "i"
82 OpName %3 "c"
83 OpDecorate %3 Location 0
84 %5 = OpTypeVoid
85 %6 = OpTypeFunction %5
86 %7 = OpTypeInt 32 1
87 %8 = OpTypePointer Function %7
88 %9 = OpConstant %7 0
89 %10 = OpTypeBool
90 %11 = OpTypeFloat 32
91 %12 = OpTypeVector %11 4
92 %13 = OpTypePointer Output %12
93 %3 = OpVariable %13 Output
94 %2 = OpFunction %5 None %6
95 %14 = OpLabel
96 %4 = OpVariable %8 Function
97 OpStore %4 %9
98 %15 = OpLoad %7 %4
99 %16 = OpINotEqual %10 %15 %9
100 OpSelectionMerge %17 None
101 OpBranchConditional %16 %18 %17
102 %18 = OpLabel
103 OpBranch %17
104 %17 = OpLabel
105 OpReturn
106 OpFunctionEnd
107 )";
108
109 {
110 std::unique_ptr<IRContext> context =
111 BuildModule(SPV_ENV_UNIVERSAL_1_2, nullptr, text);
112
113 BasicBlock* bb = context->cfg()->block(18);
114
115 // Build managers.
116 context->get_def_use_mgr();
117 context->get_instr_block(nullptr);
118
119 InstructionBuilder builder(context.get(), &*bb->begin());
120 Instruction* phi1 = builder.AddPhi(7, {9, 14});
121 Instruction* phi2 = builder.AddPhi(10, {16, 14});
122
123 // Make sure the InstructionBuilder did not update the def/use manager.
124 EXPECT_EQ(context->get_def_use_mgr()->GetDef(phi1->result_id()), nullptr);
125 EXPECT_EQ(context->get_def_use_mgr()->GetDef(phi2->result_id()), nullptr);
126 EXPECT_EQ(context->get_instr_block(phi1), nullptr);
127 EXPECT_EQ(context->get_instr_block(phi2), nullptr);
128
129 Match(text, context.get());
130 }
131
132 {
133 std::unique_ptr<IRContext> context =
134 BuildModule(SPV_ENV_UNIVERSAL_1_2, nullptr, text);
135
136 // Build managers.
137 context->get_def_use_mgr();
138 context->get_instr_block(nullptr);
139
140 BasicBlock* bb = context->cfg()->block(18);
141 InstructionBuilder builder(
142 context.get(), &*bb->begin(),
143 IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);
144 Instruction* phi1 = builder.AddPhi(7, {9, 14});
145 Instruction* phi2 = builder.AddPhi(10, {16, 14});
146
147 // Make sure InstructionBuilder updated the def/use manager
148 EXPECT_NE(context->get_def_use_mgr()->GetDef(phi1->result_id()), nullptr);
149 EXPECT_NE(context->get_def_use_mgr()->GetDef(phi2->result_id()), nullptr);
150 EXPECT_NE(context->get_instr_block(phi1), nullptr);
151 EXPECT_NE(context->get_instr_block(phi2), nullptr);
152
153 Match(text, context.get());
154 }
155 }
156
TEST_F(IRBuilderTest,TestCondBranchAddition)157 TEST_F(IRBuilderTest, TestCondBranchAddition) {
158 const std::string text = R"(
159 ; CHECK: %main = OpFunction %void None %6
160 ; CHECK-NEXT: %15 = OpLabel
161 ; CHECK-NEXT: OpSelectionMerge %13 None
162 ; CHECK-NEXT: OpBranchConditional %true %14 %13
163 ; CHECK-NEXT: %14 = OpLabel
164 ; CHECK-NEXT: OpBranch %13
165 ; CHECK-NEXT: %13 = OpLabel
166 ; CHECK-NEXT: OpReturn
167 OpCapability Shader
168 %1 = OpExtInstImport "GLSL.std.450"
169 OpMemoryModel Logical GLSL450
170 OpEntryPoint Fragment %2 "main" %3
171 OpExecutionMode %2 OriginUpperLeft
172 OpSource GLSL 330
173 OpName %2 "main"
174 OpName %4 "i"
175 OpName %3 "c"
176 OpDecorate %3 Location 0
177 %5 = OpTypeVoid
178 %6 = OpTypeFunction %5
179 %7 = OpTypeBool
180 %8 = OpTypePointer Private %7
181 %9 = OpConstantTrue %7
182 %10 = OpTypeFloat 32
183 %11 = OpTypeVector %10 4
184 %12 = OpTypePointer Output %11
185 %3 = OpVariable %12 Output
186 %4 = OpVariable %8 Private
187 %2 = OpFunction %5 None %6
188 %13 = OpLabel
189 OpReturn
190 OpFunctionEnd
191 )";
192
193 {
194 std::unique_ptr<IRContext> context =
195 BuildModule(SPV_ENV_UNIVERSAL_1_2, nullptr, text);
196
197 Function& fn = *context->module()->begin();
198
199 BasicBlock& bb_merge = *fn.begin();
200
201 // TODO(1841): Handle id overflow.
202 fn.begin().InsertBefore(std::unique_ptr<BasicBlock>(
203 new BasicBlock(std::unique_ptr<Instruction>(new Instruction(
204 context.get(), SpvOpLabel, 0, context->TakeNextId(), {})))));
205 BasicBlock& bb_true = *fn.begin();
206 {
207 InstructionBuilder builder(context.get(), &*bb_true.begin());
208 builder.AddBranch(bb_merge.id());
209 }
210
211 // TODO(1841): Handle id overflow.
212 fn.begin().InsertBefore(std::unique_ptr<BasicBlock>(
213 new BasicBlock(std::unique_ptr<Instruction>(new Instruction(
214 context.get(), SpvOpLabel, 0, context->TakeNextId(), {})))));
215 BasicBlock& bb_cond = *fn.begin();
216
217 InstructionBuilder builder(context.get(), &bb_cond);
218 // This also test consecutive instruction insertion: merge selection +
219 // branch.
220 builder.AddConditionalBranch(9, bb_true.id(), bb_merge.id(), bb_merge.id());
221
222 Match(text, context.get());
223 }
224 }
225
TEST_F(IRBuilderTest,AddSelect)226 TEST_F(IRBuilderTest, AddSelect) {
227 const std::string text = R"(
228 ; CHECK: [[bool:%\w+]] = OpTypeBool
229 ; CHECK: [[uint:%\w+]] = OpTypeInt 32 0
230 ; CHECK: [[true:%\w+]] = OpConstantTrue [[bool]]
231 ; CHECK: [[u0:%\w+]] = OpConstant [[uint]] 0
232 ; CHECK: [[u1:%\w+]] = OpConstant [[uint]] 1
233 ; CHECK: OpSelect [[uint]] [[true]] [[u0]] [[u1]]
234 OpCapability Kernel
235 OpCapability Linkage
236 OpMemoryModel Logical OpenCL
237 %1 = OpTypeVoid
238 %2 = OpTypeBool
239 %3 = OpTypeInt 32 0
240 %4 = OpConstantTrue %2
241 %5 = OpConstant %3 0
242 %6 = OpConstant %3 1
243 %7 = OpTypeFunction %1
244 %8 = OpFunction %1 None %7
245 %9 = OpLabel
246 OpReturn
247 OpFunctionEnd
248 )";
249
250 std::unique_ptr<IRContext> context =
251 BuildModule(SPV_ENV_UNIVERSAL_1_2, nullptr, text);
252 EXPECT_NE(nullptr, context);
253
254 InstructionBuilder builder(context.get(),
255 &*context->module()->begin()->begin()->begin());
256 EXPECT_NE(nullptr, builder.AddSelect(3u, 4u, 5u, 6u));
257
258 Match(text, context.get());
259 }
260
TEST_F(IRBuilderTest,AddCompositeConstruct)261 TEST_F(IRBuilderTest, AddCompositeConstruct) {
262 const std::string text = R"(
263 ; CHECK: [[uint:%\w+]] = OpTypeInt
264 ; CHECK: [[u0:%\w+]] = OpConstant [[uint]] 0
265 ; CHECK: [[u1:%\w+]] = OpConstant [[uint]] 1
266 ; CHECK: [[struct:%\w+]] = OpTypeStruct [[uint]] [[uint]] [[uint]] [[uint]]
267 ; CHECK: OpCompositeConstruct [[struct]] [[u0]] [[u1]] [[u1]] [[u0]]
268 OpCapability Kernel
269 OpCapability Linkage
270 OpMemoryModel Logical OpenCL
271 %1 = OpTypeVoid
272 %2 = OpTypeInt 32 0
273 %3 = OpConstant %2 0
274 %4 = OpConstant %2 1
275 %5 = OpTypeStruct %2 %2 %2 %2
276 %6 = OpTypeFunction %1
277 %7 = OpFunction %1 None %6
278 %8 = OpLabel
279 OpReturn
280 OpFunctionEnd
281 )";
282
283 std::unique_ptr<IRContext> context =
284 BuildModule(SPV_ENV_UNIVERSAL_1_2, nullptr, text);
285 EXPECT_NE(nullptr, context);
286
287 InstructionBuilder builder(context.get(),
288 &*context->module()->begin()->begin()->begin());
289 std::vector<uint32_t> ids = {3u, 4u, 4u, 3u};
290 EXPECT_NE(nullptr, builder.AddCompositeConstruct(5u, ids));
291
292 Match(text, context.get());
293 }
294
TEST_F(IRBuilderTest,ConstantAdder)295 TEST_F(IRBuilderTest, ConstantAdder) {
296 const std::string text = R"(
297 ; CHECK: [[uint:%\w+]] = OpTypeInt 32 0
298 ; CHECK: OpConstant [[uint]] 13
299 ; CHECK: [[sint:%\w+]] = OpTypeInt 32 1
300 ; CHECK: OpConstant [[sint]] -1
301 ; CHECK: OpConstant [[uint]] 1
302 ; CHECK: OpConstant [[sint]] 34
303 ; CHECK: OpConstant [[uint]] 0
304 ; CHECK: OpConstant [[sint]] 0
305 OpCapability Shader
306 OpCapability Linkage
307 OpMemoryModel Logical GLSL450
308 %1 = OpTypeVoid
309 %2 = OpTypeFunction %1
310 %3 = OpFunction %1 None %2
311 %4 = OpLabel
312 OpReturn
313 OpFunctionEnd
314 )";
315
316 std::unique_ptr<IRContext> context =
317 BuildModule(SPV_ENV_UNIVERSAL_1_2, nullptr, text);
318 EXPECT_NE(nullptr, context);
319
320 InstructionBuilder builder(context.get(),
321 &*context->module()->begin()->begin()->begin());
322 EXPECT_NE(nullptr, builder.GetUintConstant(13));
323 EXPECT_NE(nullptr, builder.GetSintConstant(-1));
324
325 // Try adding the same constants again to make sure they aren't added.
326 EXPECT_NE(nullptr, builder.GetUintConstant(13));
327 EXPECT_NE(nullptr, builder.GetSintConstant(-1));
328
329 // Try adding different constants to make sure the type is reused.
330 EXPECT_NE(nullptr, builder.GetUintConstant(1));
331 EXPECT_NE(nullptr, builder.GetSintConstant(34));
332
333 // Try adding 0 as both signed and unsigned.
334 EXPECT_NE(nullptr, builder.GetUintConstant(0));
335 EXPECT_NE(nullptr, builder.GetSintConstant(0));
336
337 Match(text, context.get());
338 }
339
TEST_F(IRBuilderTest,ConstantAdderTypeAlreadyExists)340 TEST_F(IRBuilderTest, ConstantAdderTypeAlreadyExists) {
341 const std::string text = R"(
342 ; CHECK: OpConstant %uint 13
343 ; CHECK: OpConstant %int -1
344 ; CHECK: OpConstant %uint 1
345 ; CHECK: OpConstant %int 34
346 ; CHECK: OpConstant %uint 0
347 ; CHECK: OpConstant %int 0
348 OpCapability Shader
349 OpCapability Linkage
350 OpMemoryModel Logical GLSL450
351 %1 = OpTypeVoid
352 %uint = OpTypeInt 32 0
353 %int = OpTypeInt 32 1
354 %4 = OpTypeFunction %1
355 %5 = OpFunction %1 None %4
356 %6 = OpLabel
357 OpReturn
358 OpFunctionEnd
359 )";
360
361 std::unique_ptr<IRContext> context =
362 BuildModule(SPV_ENV_UNIVERSAL_1_2, nullptr, text);
363 EXPECT_NE(nullptr, context);
364
365 InstructionBuilder builder(context.get(),
366 &*context->module()->begin()->begin()->begin());
367 Instruction* const_1 = builder.GetUintConstant(13);
368 Instruction* const_2 = builder.GetSintConstant(-1);
369
370 EXPECT_NE(nullptr, const_1);
371 EXPECT_NE(nullptr, const_2);
372
373 // Try adding the same constants again to make sure they aren't added.
374 EXPECT_EQ(const_1, builder.GetUintConstant(13));
375 EXPECT_EQ(const_2, builder.GetSintConstant(-1));
376
377 Instruction* const_3 = builder.GetUintConstant(1);
378 Instruction* const_4 = builder.GetSintConstant(34);
379
380 // Try adding different constants to make sure the type is reused.
381 EXPECT_NE(nullptr, const_3);
382 EXPECT_NE(nullptr, const_4);
383
384 Instruction* const_5 = builder.GetUintConstant(0);
385 Instruction* const_6 = builder.GetSintConstant(0);
386
387 // Try adding 0 as both signed and unsigned.
388 EXPECT_NE(nullptr, const_5);
389 EXPECT_NE(nullptr, const_6);
390
391 // They have the same value but different types so should be unique.
392 EXPECT_NE(const_5, const_6);
393
394 // Check the types are correct.
395 uint32_t type_id_unsigned = const_1->GetSingleWordOperand(0);
396 uint32_t type_id_signed = const_2->GetSingleWordOperand(0);
397
398 EXPECT_NE(type_id_unsigned, type_id_signed);
399
400 EXPECT_EQ(const_3->GetSingleWordOperand(0), type_id_unsigned);
401 EXPECT_EQ(const_5->GetSingleWordOperand(0), type_id_unsigned);
402
403 EXPECT_EQ(const_4->GetSingleWordOperand(0), type_id_signed);
404 EXPECT_EQ(const_6->GetSingleWordOperand(0), type_id_signed);
405
406 Match(text, context.get());
407 }
408
TEST_F(IRBuilderTest,AccelerationStructureNV)409 TEST_F(IRBuilderTest, AccelerationStructureNV) {
410 const std::string text = R"(
411 ; CHECK: OpTypeAccelerationStructureNV
412 OpCapability Shader
413 OpCapability RayTracingNV
414 OpExtension "SPV_NV_ray_tracing"
415 OpMemoryModel Logical GLSL450
416 OpEntryPoint Fragment %8 "main"
417 OpExecutionMode %8 OriginUpperLeft
418 %1 = OpTypeVoid
419 %2 = OpTypeBool
420 %3 = OpTypeAccelerationStructureNV
421 %7 = OpTypeFunction %1
422 %8 = OpFunction %1 None %7
423 %9 = OpLabel
424 OpReturn
425 OpFunctionEnd
426 )";
427
428 std::unique_ptr<IRContext> context =
429 BuildModule(SPV_ENV_UNIVERSAL_1_2, nullptr, text);
430 EXPECT_NE(nullptr, context);
431
432 InstructionBuilder builder(context.get(),
433 &*context->module()->begin()->begin()->begin());
434 Match(text, context.get());
435 }
436
437 } // namespace
438 } // namespace opt
439 } // namespace spvtools
440