1 // Copyright (c) 2016 Google Inc.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #include <memory>
16 #include <string>
17 #include <unordered_map>
18 #include <unordered_set>
19 #include <utility>
20 #include <vector>
21 
22 #include "gmock/gmock.h"
23 #include "gtest/gtest.h"
24 #include "source/opt/build_module.h"
25 #include "source/opt/def_use_manager.h"
26 #include "source/opt/ir_context.h"
27 #include "source/opt/module.h"
28 #include "spirv-tools/libspirv.hpp"
29 #include "test/opt/pass_fixture.h"
30 #include "test/opt/pass_utils.h"
31 
32 namespace spvtools {
33 namespace opt {
34 namespace analysis {
35 namespace {
36 
37 using ::testing::Contains;
38 using ::testing::UnorderedElementsAre;
39 using ::testing::UnorderedElementsAreArray;
40 
41 // Returns the number of uses of |id|.
NumUses(const std::unique_ptr<IRContext> & context,uint32_t id)42 uint32_t NumUses(const std::unique_ptr<IRContext>& context, uint32_t id) {
43   uint32_t count = 0;
44   context->get_def_use_mgr()->ForEachUse(
45       id, [&count](Instruction*, uint32_t) { ++count; });
46   return count;
47 }
48 
49 // Returns the opcode of each use of |id|.
50 //
51 // If |id| is used multiple times in a single instruction, that instruction's
52 // opcode will appear a corresponding number of times.
GetUseOpcodes(const std::unique_ptr<IRContext> & context,uint32_t id)53 std::vector<SpvOp> GetUseOpcodes(const std::unique_ptr<IRContext>& context,
54                                  uint32_t id) {
55   std::vector<SpvOp> opcodes;
56   context->get_def_use_mgr()->ForEachUse(
57       id, [&opcodes](Instruction* user, uint32_t) {
58         opcodes.push_back(user->opcode());
59       });
60   return opcodes;
61 }
62 
63 // Disassembles the given |inst| and returns the disassembly.
DisassembleInst(Instruction * inst)64 std::string DisassembleInst(Instruction* inst) {
65   SpirvTools tools(SPV_ENV_UNIVERSAL_1_1);
66 
67   std::vector<uint32_t> binary;
68   // We need this to generate the necessary header in the binary.
69   tools.Assemble("", &binary);
70   inst->ToBinaryWithoutAttachedDebugInsts(&binary);
71 
72   std::string text;
73   // We'll need to check the underlying id numbers.
74   // So turn off friendly names for ids.
75   tools.Disassemble(binary, &text, SPV_BINARY_TO_TEXT_OPTION_NO_HEADER);
76   while (!text.empty() && text.back() == '\n') text.pop_back();
77   return text;
78 }
79 
80 // A struct for holding expected id defs and uses.
81 struct InstDefUse {
82   using IdInstPair = std::pair<uint32_t, std::string>;
83   using IdInstsPair = std::pair<uint32_t, std::vector<std::string>>;
84 
85   // Ids and their corresponding def instructions.
86   std::vector<IdInstPair> defs;
87   // Ids and their corresponding use instructions.
88   std::vector<IdInstsPair> uses;
89 };
90 
91 // Checks that the |actual_defs| and |actual_uses| are in accord with
92 // |expected_defs_uses|.
CheckDef(const InstDefUse & expected_defs_uses,const DefUseManager::IdToDefMap & actual_defs)93 void CheckDef(const InstDefUse& expected_defs_uses,
94               const DefUseManager::IdToDefMap& actual_defs) {
95   // Check defs.
96   ASSERT_EQ(expected_defs_uses.defs.size(), actual_defs.size());
97   for (uint32_t i = 0; i < expected_defs_uses.defs.size(); ++i) {
98     const auto id = expected_defs_uses.defs[i].first;
99     const auto expected_def = expected_defs_uses.defs[i].second;
100     ASSERT_EQ(1u, actual_defs.count(id)) << "expected to def id [" << id << "]";
101     auto def = actual_defs.at(id);
102     if (def->opcode() != SpvOpConstant) {
103       // Constants don't disassemble properly without a full context.
104       EXPECT_EQ(expected_def, DisassembleInst(actual_defs.at(id)));
105     }
106   }
107 }
108 
109 using UserMap = std::unordered_map<uint32_t, std::vector<Instruction*>>;
110 
111 // Creates a mapping of all definitions to their users (except OpConstant).
112 //
113 // OpConstants are skipped because they cannot be disassembled in isolation.
BuildAllUsers(const DefUseManager * mgr,uint32_t idBound)114 UserMap BuildAllUsers(const DefUseManager* mgr, uint32_t idBound) {
115   UserMap userMap;
116   for (uint32_t id = 0; id != idBound; ++id) {
117     if (mgr->GetDef(id)) {
118       mgr->ForEachUser(id, [id, &userMap](Instruction* user) {
119         if (user->opcode() != SpvOpConstant) {
120           userMap[id].push_back(user);
121         }
122       });
123     }
124   }
125   return userMap;
126 }
127 
128 // Constants don't disassemble properly without a full context, so skip them as
129 // checks.
CheckUse(const InstDefUse & expected_defs_uses,const DefUseManager * mgr,uint32_t idBound)130 void CheckUse(const InstDefUse& expected_defs_uses, const DefUseManager* mgr,
131               uint32_t idBound) {
132   UserMap actual_uses = BuildAllUsers(mgr, idBound);
133   // Check uses.
134   ASSERT_EQ(expected_defs_uses.uses.size(), actual_uses.size());
135   for (uint32_t i = 0; i < expected_defs_uses.uses.size(); ++i) {
136     const auto id = expected_defs_uses.uses[i].first;
137     const auto& expected_uses = expected_defs_uses.uses[i].second;
138 
139     ASSERT_EQ(1u, actual_uses.count(id)) << "expected to use id [" << id << "]";
140     const auto& uses = actual_uses.at(id);
141 
142     ASSERT_EQ(expected_uses.size(), uses.size())
143         << "id [" << id << "] # uses: expected: " << expected_uses.size()
144         << " actual: " << uses.size();
145 
146     std::vector<std::string> actual_uses_disassembled;
147     for (const auto actual_use : uses) {
148       actual_uses_disassembled.emplace_back(DisassembleInst(actual_use));
149     }
150     EXPECT_THAT(actual_uses_disassembled,
151                 UnorderedElementsAreArray(expected_uses));
152   }
153 }
154 
155 // The following test case mimics how LLVM handles induction variables.
156 // But, yeah, it's not very readable. However, we only care about the id
157 // defs and uses. So, no need to make sure this is valid OpPhi construct.
158 const char kOpPhiTestFunction[] =
159     " %1 = OpTypeVoid "
160     " %6 = OpTypeInt 32 0 "
161     "%10 = OpTypeFloat 32 "
162     "%16 = OpTypeBool "
163     " %3 = OpTypeFunction %1 "
164     " %8 = OpConstant %6 0 "
165     "%18 = OpConstant %6 1 "
166     "%12 = OpConstant %10 1.0 "
167     " %2 = OpFunction %1 None %3 "
168     " %4 = OpLabel "
169     "      OpBranch %5 "
170 
171     " %5 = OpLabel "
172     " %7 = OpPhi %6 %8 %4 %9 %5 "
173     "%11 = OpPhi %10 %12 %4 %13 %5 "
174     " %9 = OpIAdd %6 %7 %8 "
175     "%13 = OpFAdd %10 %11 %12 "
176     "%17 = OpSLessThan %16 %7 %18 "
177     "      OpLoopMerge %19 %5 None "
178     "      OpBranchConditional %17 %5 %19 "
179 
180     "%19 = OpLabel "
181     "      OpReturn "
182     "      OpFunctionEnd";
183 
184 struct ParseDefUseCase {
185   const char* text;
186   InstDefUse du;
187 };
188 
189 using ParseDefUseTest = ::testing::TestWithParam<ParseDefUseCase>;
190 
TEST_P(ParseDefUseTest,Case)191 TEST_P(ParseDefUseTest, Case) {
192   const auto& tc = GetParam();
193 
194   // Build module.
195   const std::vector<const char*> text = {tc.text};
196   std::unique_ptr<IRContext> context =
197       BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, JoinAllInsts(text),
198                   SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
199   ASSERT_NE(nullptr, context);
200 
201   // Analyze def and use.
202   DefUseManager manager(context->module());
203 
204   CheckDef(tc.du, manager.id_to_defs());
205   CheckUse(tc.du, &manager, context->module()->IdBound());
206 }
207 
208 // clang-format off
209 INSTANTIATE_TEST_CASE_P(
210     TestCase, ParseDefUseTest,
211     ::testing::ValuesIn(std::vector<ParseDefUseCase>{
212         {"", {{}, {}}},                              // no instruction
213         {"OpMemoryModel Logical GLSL450", {{}, {}}}, // no def and use
214         { // single def, no use
215           "%1 = OpString \"wow\"",
216           {
217             {{1, "%1 = OpString \"wow\""}}, // defs
218             {}                              // uses
219           }
220         },
221         { // multiple def, no use
222           "%1 = OpString \"hello\" "
223           "%2 = OpString \"world\" "
224           "%3 = OpTypeVoid",
225           {
226             {  // defs
227               {1, "%1 = OpString \"hello\""},
228               {2, "%2 = OpString \"world\""},
229               {3, "%3 = OpTypeVoid"},
230             },
231             {} // uses
232           }
233         },
234         { // multiple def, multiple use
235           "%1 = OpTypeBool "
236           "%2 = OpTypeVector %1 3 "
237           "%3 = OpTypeMatrix %2 3",
238           {
239             { // defs
240               {1, "%1 = OpTypeBool"},
241               {2, "%2 = OpTypeVector %1 3"},
242               {3, "%3 = OpTypeMatrix %2 3"},
243             },
244             { // uses
245               {1, {"%2 = OpTypeVector %1 3"}},
246               {2, {"%3 = OpTypeMatrix %2 3"}},
247             }
248           }
249         },
250         { // multiple use of the same id
251           "%1 = OpTypeBool "
252           "%2 = OpTypeVector %1 2 "
253           "%3 = OpTypeVector %1 3 "
254           "%4 = OpTypeVector %1 4",
255           {
256             { // defs
257               {1, "%1 = OpTypeBool"},
258               {2, "%2 = OpTypeVector %1 2"},
259               {3, "%3 = OpTypeVector %1 3"},
260               {4, "%4 = OpTypeVector %1 4"},
261             },
262             { // uses
263               {1,
264                 {
265                   "%2 = OpTypeVector %1 2",
266                   "%3 = OpTypeVector %1 3",
267                   "%4 = OpTypeVector %1 4",
268                 }
269               },
270             }
271           }
272         },
273         { // labels
274           "%1 = OpTypeVoid "
275           "%2 = OpTypeBool "
276           "%3 = OpTypeFunction %1 "
277           "%4 = OpConstantTrue %2 "
278           "%5 = OpFunction %1 None %3 "
279 
280           "%6 = OpLabel "
281           "OpBranchConditional %4 %7 %8 "
282 
283           "%7 = OpLabel "
284           "OpBranch %7 "
285 
286           "%8 = OpLabel "
287           "OpReturn "
288 
289           "OpFunctionEnd",
290           {
291             { // defs
292               {1, "%1 = OpTypeVoid"},
293               {2, "%2 = OpTypeBool"},
294               {3, "%3 = OpTypeFunction %1"},
295               {4, "%4 = OpConstantTrue %2"},
296               {5, "%5 = OpFunction %1 None %3"},
297               {6, "%6 = OpLabel"},
298               {7, "%7 = OpLabel"},
299               {8, "%8 = OpLabel"},
300             },
301             { // uses
302               {1, {
303                     "%3 = OpTypeFunction %1",
304                     "%5 = OpFunction %1 None %3",
305                   }
306               },
307               {2, {"%4 = OpConstantTrue %2"}},
308               {3, {"%5 = OpFunction %1 None %3"}},
309               {4, {"OpBranchConditional %4 %7 %8"}},
310               {7,
311                 {
312                   "OpBranchConditional %4 %7 %8",
313                   "OpBranch %7",
314                 }
315               },
316               {8, {"OpBranchConditional %4 %7 %8"}},
317             }
318           }
319         },
320         { // cross function
321           "%1 = OpTypeBool "
322           "%3 = OpTypeFunction %1 "
323           "%2 = OpFunction %1 None %3 "
324 
325           "%4 = OpLabel "
326           "%5 = OpVariable %1 Function "
327           "%6 = OpFunctionCall %1 %2 %5 "
328           "OpReturnValue %6 "
329 
330           "OpFunctionEnd",
331           {
332             { // defs
333               {1, "%1 = OpTypeBool"},
334               {2, "%2 = OpFunction %1 None %3"},
335               {3, "%3 = OpTypeFunction %1"},
336               {4, "%4 = OpLabel"},
337               {5, "%5 = OpVariable %1 Function"},
338               {6, "%6 = OpFunctionCall %1 %2 %5"},
339             },
340             { // uses
341               {1,
342                 {
343                   "%2 = OpFunction %1 None %3",
344                   "%3 = OpTypeFunction %1",
345                   "%5 = OpVariable %1 Function",
346                   "%6 = OpFunctionCall %1 %2 %5",
347                 }
348               },
349               {2, {"%6 = OpFunctionCall %1 %2 %5"}},
350               {3, {"%2 = OpFunction %1 None %3"}},
351               {5, {"%6 = OpFunctionCall %1 %2 %5"}},
352               {6, {"OpReturnValue %6"}},
353             }
354           }
355         },
356         { // selection merge and loop merge
357           "%1 = OpTypeVoid "
358           "%3 = OpTypeFunction %1 "
359           "%10 = OpTypeBool "
360           "%8 = OpConstantTrue %10 "
361           "%2 = OpFunction %1 None %3 "
362 
363           "%4 = OpLabel "
364           "OpLoopMerge %5 %4 None "
365           "OpBranch %6 "
366 
367           "%5 = OpLabel "
368           "OpReturn "
369 
370           "%6 = OpLabel "
371           "OpSelectionMerge %7 None "
372           "OpBranchConditional %8 %9 %7 "
373 
374           "%7 = OpLabel "
375           "OpReturn "
376 
377           "%9 = OpLabel "
378           "OpReturn "
379 
380           "OpFunctionEnd",
381           {
382             { // defs
383               {1, "%1 = OpTypeVoid"},
384               {2, "%2 = OpFunction %1 None %3"},
385               {3, "%3 = OpTypeFunction %1"},
386               {4, "%4 = OpLabel"},
387               {5, "%5 = OpLabel"},
388               {6, "%6 = OpLabel"},
389               {7, "%7 = OpLabel"},
390               {8, "%8 = OpConstantTrue %10"},
391               {9, "%9 = OpLabel"},
392               {10, "%10 = OpTypeBool"},
393             },
394             { // uses
395               {1,
396                 {
397                   "%2 = OpFunction %1 None %3",
398                   "%3 = OpTypeFunction %1",
399                 }
400               },
401               {3, {"%2 = OpFunction %1 None %3"}},
402               {4, {"OpLoopMerge %5 %4 None"}},
403               {5, {"OpLoopMerge %5 %4 None"}},
404               {6, {"OpBranch %6"}},
405               {7,
406                 {
407                   "OpSelectionMerge %7 None",
408                   "OpBranchConditional %8 %9 %7",
409                 }
410               },
411               {8, {"OpBranchConditional %8 %9 %7"}},
412               {9, {"OpBranchConditional %8 %9 %7"}},
413               {10, {"%8 = OpConstantTrue %10"}},
414             }
415           }
416         },
417         { // Forward reference
418           "OpDecorate %1 Block "
419           "OpTypeForwardPointer %2 Input "
420           "%3 = OpTypeInt 32 0 "
421           "%1 = OpTypeStruct %3 "
422           "%2 = OpTypePointer Input %3",
423           {
424             { // defs
425               {1, "%1 = OpTypeStruct %3"},
426               {2, "%2 = OpTypePointer Input %3"},
427               {3, "%3 = OpTypeInt 32 0"},
428             },
429             { // uses
430               {1, {"OpDecorate %1 Block"}},
431               {2, {"OpTypeForwardPointer %2 Input"}},
432               {3,
433                 {
434                   "%1 = OpTypeStruct %3",
435                   "%2 = OpTypePointer Input %3",
436                 }
437               }
438             },
439           },
440         },
441         { // OpPhi
442           kOpPhiTestFunction,
443           {
444             { // defs
445               {1, "%1 = OpTypeVoid"},
446               {2, "%2 = OpFunction %1 None %3"},
447               {3, "%3 = OpTypeFunction %1"},
448               {4, "%4 = OpLabel"},
449               {5, "%5 = OpLabel"},
450               {6, "%6 = OpTypeInt 32 0"},
451               {7, "%7 = OpPhi %6 %8 %4 %9 %5"},
452               {8, "%8 = OpConstant %6 0"},
453               {9, "%9 = OpIAdd %6 %7 %8"},
454               {10, "%10 = OpTypeFloat 32"},
455               {11, "%11 = OpPhi %10 %12 %4 %13 %5"},
456               {12, "%12 = OpConstant %10 1.0"},
457               {13, "%13 = OpFAdd %10 %11 %12"},
458               {16, "%16 = OpTypeBool"},
459               {17, "%17 = OpSLessThan %16 %7 %18"},
460               {18, "%18 = OpConstant %6 1"},
461               {19, "%19 = OpLabel"},
462             },
463             { // uses
464               {1,
465                 {
466                   "%2 = OpFunction %1 None %3",
467                   "%3 = OpTypeFunction %1",
468                 }
469               },
470               {3, {"%2 = OpFunction %1 None %3"}},
471               {4,
472                 {
473                   "%7 = OpPhi %6 %8 %4 %9 %5",
474                   "%11 = OpPhi %10 %12 %4 %13 %5",
475                 }
476               },
477               {5,
478                 {
479                   "OpBranch %5",
480                   "%7 = OpPhi %6 %8 %4 %9 %5",
481                   "%11 = OpPhi %10 %12 %4 %13 %5",
482                   "OpLoopMerge %19 %5 None",
483                   "OpBranchConditional %17 %5 %19",
484                 }
485               },
486               {6,
487                 {
488                   // Can't check constants properly
489                   // "%8 = OpConstant %6 0",
490                   // "%18 = OpConstant %6 1",
491                   "%7 = OpPhi %6 %8 %4 %9 %5",
492                   "%9 = OpIAdd %6 %7 %8",
493                 }
494               },
495               {7,
496                 {
497                   "%9 = OpIAdd %6 %7 %8",
498                   "%17 = OpSLessThan %16 %7 %18",
499                 }
500               },
501               {8,
502                 {
503                   "%7 = OpPhi %6 %8 %4 %9 %5",
504                   "%9 = OpIAdd %6 %7 %8",
505                 }
506               },
507               {9, {"%7 = OpPhi %6 %8 %4 %9 %5"}},
508               {10,
509                 {
510                   // "%12 = OpConstant %10 1.0",
511                   "%11 = OpPhi %10 %12 %4 %13 %5",
512                   "%13 = OpFAdd %10 %11 %12",
513                 }
514               },
515               {11, {"%13 = OpFAdd %10 %11 %12"}},
516               {12,
517                 {
518                   "%11 = OpPhi %10 %12 %4 %13 %5",
519                   "%13 = OpFAdd %10 %11 %12",
520                 }
521               },
522               {13, {"%11 = OpPhi %10 %12 %4 %13 %5"}},
523               {16, {"%17 = OpSLessThan %16 %7 %18"}},
524               {17, {"OpBranchConditional %17 %5 %19"}},
525               {18, {"%17 = OpSLessThan %16 %7 %18"}},
526               {19,
527                 {
528                   "OpLoopMerge %19 %5 None",
529                   "OpBranchConditional %17 %5 %19",
530                 }
531               },
532             },
533           },
534         },
535         { // OpPhi defining and referencing the same id.
536           "%1 = OpTypeBool "
537           "%3 = OpTypeFunction %1 "
538           "%2 = OpConstantTrue %1 "
539           "%4 = OpFunction %1 None %3 "
540           "%6 = OpLabel "
541           "     OpBranch %7 "
542           "%7 = OpLabel "
543           "%8 = OpPhi %1   %8 %7   %2 %6 " // both defines and uses %8
544           "     OpBranch %7 "
545           "     OpFunctionEnd",
546           {
547             { // defs
548               {1, "%1 = OpTypeBool"},
549               {2, "%2 = OpConstantTrue %1"},
550               {3, "%3 = OpTypeFunction %1"},
551               {4, "%4 = OpFunction %1 None %3"},
552               {6, "%6 = OpLabel"},
553               {7, "%7 = OpLabel"},
554               {8, "%8 = OpPhi %1 %8 %7 %2 %6"},
555             },
556             { // uses
557               {1,
558                 {
559                   "%2 = OpConstantTrue %1",
560                   "%3 = OpTypeFunction %1",
561                   "%4 = OpFunction %1 None %3",
562                   "%8 = OpPhi %1 %8 %7 %2 %6",
563                 }
564               },
565               {2, {"%8 = OpPhi %1 %8 %7 %2 %6"}},
566               {3, {"%4 = OpFunction %1 None %3"}},
567               {6, {"%8 = OpPhi %1 %8 %7 %2 %6"}},
568               {7,
569                 {
570                   "OpBranch %7",
571                   "%8 = OpPhi %1 %8 %7 %2 %6",
572                   "OpBranch %7",
573                 }
574               },
575               {8, {"%8 = OpPhi %1 %8 %7 %2 %6"}},
576             },
577           },
578         },
579     })
580 );
581 // clang-format on
582 
583 struct ReplaceUseCase {
584   const char* before;
585   std::vector<std::pair<uint32_t, uint32_t>> candidates;
586   const char* after;
587   InstDefUse du;
588 };
589 
590 using ReplaceUseTest = ::testing::TestWithParam<ReplaceUseCase>;
591 
592 // Disassembles the given |module| and returns the disassembly.
DisassembleModule(Module * module)593 std::string DisassembleModule(Module* module) {
594   SpirvTools tools(SPV_ENV_UNIVERSAL_1_1);
595 
596   std::vector<uint32_t> binary;
597   module->ToBinary(&binary, /* skip_nop = */ false);
598 
599   std::string text;
600   // We'll need to check the underlying id numbers.
601   // So turn off friendly names for ids.
602   tools.Disassemble(binary, &text, SPV_BINARY_TO_TEXT_OPTION_NO_HEADER);
603   while (!text.empty() && text.back() == '\n') text.pop_back();
604   return text;
605 }
606 
TEST_P(ReplaceUseTest,Case)607 TEST_P(ReplaceUseTest, Case) {
608   const auto& tc = GetParam();
609 
610   // Build module.
611   const std::vector<const char*> text = {tc.before};
612   std::unique_ptr<IRContext> context =
613       BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, JoinAllInsts(text),
614                   SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
615   ASSERT_NE(nullptr, context);
616 
617   // Force a re-build of def-use manager.
618   context->InvalidateAnalyses(IRContext::Analysis::kAnalysisDefUse);
619   (void)context->get_def_use_mgr();
620 
621   // Do the substitution.
622   for (const auto& candidate : tc.candidates) {
623     context->ReplaceAllUsesWith(candidate.first, candidate.second);
624   }
625 
626   EXPECT_EQ(tc.after, DisassembleModule(context->module()));
627   CheckDef(tc.du, context->get_def_use_mgr()->id_to_defs());
628   CheckUse(tc.du, context->get_def_use_mgr(), context->module()->IdBound());
629 }
630 
631 // clang-format off
632 INSTANTIATE_TEST_CASE_P(
633     TestCase, ReplaceUseTest,
634     ::testing::ValuesIn(std::vector<ReplaceUseCase>{
635       { // no use, no replace request
636         "", {}, "", {},
637       },
638       { // replace one use
639         "%1 = OpTypeBool "
640         "%2 = OpTypeVector %1 3 "
641         "%3 = OpTypeInt 32 0 ",
642         {{1, 3}},
643         "%1 = OpTypeBool\n"
644         "%2 = OpTypeVector %3 3\n"
645         "%3 = OpTypeInt 32 0",
646         {
647           { // defs
648             {1, "%1 = OpTypeBool"},
649             {2, "%2 = OpTypeVector %3 3"},
650             {3, "%3 = OpTypeInt 32 0"},
651           },
652           { // uses
653             {3, {"%2 = OpTypeVector %3 3"}},
654           },
655         },
656       },
657       { // replace and then replace back
658         "%1 = OpTypeBool "
659         "%2 = OpTypeVector %1 3 "
660         "%3 = OpTypeInt 32 0",
661         {{1, 3}, {3, 1}},
662         "%1 = OpTypeBool\n"
663         "%2 = OpTypeVector %1 3\n"
664         "%3 = OpTypeInt 32 0",
665         {
666           { // defs
667             {1, "%1 = OpTypeBool"},
668             {2, "%2 = OpTypeVector %1 3"},
669             {3, "%3 = OpTypeInt 32 0"},
670           },
671           { // uses
672             {1, {"%2 = OpTypeVector %1 3"}},
673           },
674         },
675       },
676       { // replace with the same id
677         "%1 = OpTypeBool "
678         "%2 = OpTypeVector %1 3",
679         {{1, 1}, {2, 2}, {3, 3}},
680         "%1 = OpTypeBool\n"
681         "%2 = OpTypeVector %1 3",
682         {
683           { // defs
684             {1, "%1 = OpTypeBool"},
685             {2, "%2 = OpTypeVector %1 3"},
686           },
687           { // uses
688             {1, {"%2 = OpTypeVector %1 3"}},
689           },
690         },
691       },
692       { // replace in sequence
693         "%1 = OpTypeBool "
694         "%2 = OpTypeVector %1 3 "
695         "%3 = OpTypeInt 32 0 "
696         "%4 = OpTypeInt 32 1 ",
697         {{1, 3}, {3, 4}},
698         "%1 = OpTypeBool\n"
699         "%2 = OpTypeVector %4 3\n"
700         "%3 = OpTypeInt 32 0\n"
701         "%4 = OpTypeInt 32 1",
702         {
703           { // defs
704             {1, "%1 = OpTypeBool"},
705             {2, "%2 = OpTypeVector %4 3"},
706             {3, "%3 = OpTypeInt 32 0"},
707             {4, "%4 = OpTypeInt 32 1"},
708           },
709           { // uses
710             {4, {"%2 = OpTypeVector %4 3"}},
711           },
712         },
713       },
714       { // replace multiple uses
715         "%1 = OpTypeBool "
716         "%2 = OpTypeVector %1 2 "
717         "%3 = OpTypeVector %1 3 "
718         "%4 = OpTypeVector %1 4 "
719         "%5 = OpTypeMatrix %2 2 "
720         "%6 = OpTypeMatrix %3 3 "
721         "%7 = OpTypeMatrix %4 4 "
722         "%8 = OpTypeInt 32 0 "
723         "%9 = OpTypeInt 32 1 "
724         "%10 = OpTypeInt 64 0",
725         {{1, 8}, {2, 9}, {4, 10}},
726         "%1 = OpTypeBool\n"
727         "%2 = OpTypeVector %8 2\n"
728         "%3 = OpTypeVector %8 3\n"
729         "%4 = OpTypeVector %8 4\n"
730         "%5 = OpTypeMatrix %9 2\n"
731         "%6 = OpTypeMatrix %3 3\n"
732         "%7 = OpTypeMatrix %10 4\n"
733         "%8 = OpTypeInt 32 0\n"
734         "%9 = OpTypeInt 32 1\n"
735         "%10 = OpTypeInt 64 0",
736         {
737           { // defs
738             {1, "%1 = OpTypeBool"},
739             {2, "%2 = OpTypeVector %8 2"},
740             {3, "%3 = OpTypeVector %8 3"},
741             {4, "%4 = OpTypeVector %8 4"},
742             {5, "%5 = OpTypeMatrix %9 2"},
743             {6, "%6 = OpTypeMatrix %3 3"},
744             {7, "%7 = OpTypeMatrix %10 4"},
745             {8, "%8 = OpTypeInt 32 0"},
746             {9, "%9 = OpTypeInt 32 1"},
747             {10, "%10 = OpTypeInt 64 0"},
748           },
749           { // uses
750             {8,
751               {
752                 "%2 = OpTypeVector %8 2",
753                 "%3 = OpTypeVector %8 3",
754                 "%4 = OpTypeVector %8 4",
755               }
756             },
757             {9, {"%5 = OpTypeMatrix %9 2"}},
758             {3, {"%6 = OpTypeMatrix %3 3"}},
759             {10, {"%7 = OpTypeMatrix %10 4"}},
760           },
761         },
762       },
763       { // OpPhi.
764         kOpPhiTestFunction,
765         // replace one id used by OpPhi, replace one id generated by OpPhi
766         {{9, 13}, {11, 9}},
767          "%1 = OpTypeVoid\n"
768          "%6 = OpTypeInt 32 0\n"
769          "%10 = OpTypeFloat 32\n"
770          "%16 = OpTypeBool\n"
771          "%3 = OpTypeFunction %1\n"
772          "%8 = OpConstant %6 0\n"
773          "%18 = OpConstant %6 1\n"
774          "%12 = OpConstant %10 1\n"
775          "%2 = OpFunction %1 None %3\n"
776          "%4 = OpLabel\n"
777                "OpBranch %5\n"
778 
779          "%5 = OpLabel\n"
780          "%7 = OpPhi %6 %8 %4 %13 %5\n" // %9 -> %13
781         "%11 = OpPhi %10 %12 %4 %13 %5\n"
782          "%9 = OpIAdd %6 %7 %8\n"
783         "%13 = OpFAdd %10 %9 %12\n"       // %11 -> %9
784         "%17 = OpSLessThan %16 %7 %18\n"
785               "OpLoopMerge %19 %5 None\n"
786               "OpBranchConditional %17 %5 %19\n"
787 
788         "%19 = OpLabel\n"
789               "OpReturn\n"
790               "OpFunctionEnd",
791         {
792           { // defs.
793             {1, "%1 = OpTypeVoid"},
794             {2, "%2 = OpFunction %1 None %3"},
795             {3, "%3 = OpTypeFunction %1"},
796             {4, "%4 = OpLabel"},
797             {5, "%5 = OpLabel"},
798             {6, "%6 = OpTypeInt 32 0"},
799             {7, "%7 = OpPhi %6 %8 %4 %13 %5"},
800             {8, "%8 = OpConstant %6 0"},
801             {9, "%9 = OpIAdd %6 %7 %8"},
802             {10, "%10 = OpTypeFloat 32"},
803             {11, "%11 = OpPhi %10 %12 %4 %13 %5"},
804             {12, "%12 = OpConstant %10 1.0"},
805             {13, "%13 = OpFAdd %10 %9 %12"},
806             {16, "%16 = OpTypeBool"},
807             {17, "%17 = OpSLessThan %16 %7 %18"},
808             {18, "%18 = OpConstant %6 1"},
809             {19, "%19 = OpLabel"},
810           },
811           { // uses
812             {1,
813               {
814                 "%2 = OpFunction %1 None %3",
815                 "%3 = OpTypeFunction %1",
816               }
817             },
818             {3, {"%2 = OpFunction %1 None %3"}},
819             {4,
820               {
821                 "%7 = OpPhi %6 %8 %4 %13 %5",
822                 "%11 = OpPhi %10 %12 %4 %13 %5",
823               }
824             },
825             {5,
826               {
827                 "OpBranch %5",
828                 "%7 = OpPhi %6 %8 %4 %13 %5",
829                 "%11 = OpPhi %10 %12 %4 %13 %5",
830                 "OpLoopMerge %19 %5 None",
831                 "OpBranchConditional %17 %5 %19",
832               }
833             },
834             {6,
835               {
836                 // Can't properly check constants
837                 // "%8 = OpConstant %6 0",
838                 // "%18 = OpConstant %6 1",
839                 "%7 = OpPhi %6 %8 %4 %13 %5",
840                 "%9 = OpIAdd %6 %7 %8"
841               }
842             },
843             {7,
844               {
845                 "%9 = OpIAdd %6 %7 %8",
846                 "%17 = OpSLessThan %16 %7 %18",
847               }
848             },
849             {8,
850               {
851                 "%7 = OpPhi %6 %8 %4 %13 %5",
852                 "%9 = OpIAdd %6 %7 %8",
853               }
854             },
855             {9, {"%13 = OpFAdd %10 %9 %12"}}, // uses of %9 changed from %7 to %13
856             {10,
857               {
858                 "%11 = OpPhi %10 %12 %4 %13 %5",
859                 // "%12 = OpConstant %10 1",
860                 "%13 = OpFAdd %10 %9 %12"
861               }
862             },
863             // no more uses of %11
864             {12,
865               {
866                 "%11 = OpPhi %10 %12 %4 %13 %5",
867                 "%13 = OpFAdd %10 %9 %12"
868               }
869             },
870             {13, {
871                    "%7 = OpPhi %6 %8 %4 %13 %5",
872                    "%11 = OpPhi %10 %12 %4 %13 %5",
873                  }
874             },
875             {16, {"%17 = OpSLessThan %16 %7 %18"}},
876             {17, {"OpBranchConditional %17 %5 %19"}},
877             {18, {"%17 = OpSLessThan %16 %7 %18"}},
878             {19,
879               {
880                 "OpLoopMerge %19 %5 None",
881                 "OpBranchConditional %17 %5 %19",
882               }
883             },
884           },
885         },
886       },
887       { // OpPhi defining and referencing the same id.
888         "%1 = OpTypeBool "
889         "%3 = OpTypeFunction %1 "
890         "%2 = OpConstantTrue %1 "
891 
892         "%4 = OpFunction %3 None %1 "
893         "%6 = OpLabel "
894         "     OpBranch %7 "
895         "%7 = OpLabel "
896         "%8 = OpPhi %1   %8 %7   %2 %6 " // both defines and uses %8
897         "     OpBranch %7 "
898         "     OpFunctionEnd",
899         {{8, 2}},
900         "%1 = OpTypeBool\n"
901         "%3 = OpTypeFunction %1\n"
902         "%2 = OpConstantTrue %1\n"
903 
904         "%4 = OpFunction %3 None %1\n"
905         "%6 = OpLabel\n"
906              "OpBranch %7\n"
907         "%7 = OpLabel\n"
908         "%8 = OpPhi %1 %2 %7 %2 %6\n" // use of %8 changed to %2
909              "OpBranch %7\n"
910              "OpFunctionEnd",
911         {
912           { // defs
913             {1, "%1 = OpTypeBool"},
914             {2, "%2 = OpConstantTrue %1"},
915             {3, "%3 = OpTypeFunction %1"},
916             {4, "%4 = OpFunction %3 None %1"},
917             {6, "%6 = OpLabel"},
918             {7, "%7 = OpLabel"},
919             {8, "%8 = OpPhi %1 %2 %7 %2 %6"},
920           },
921           { // uses
922             {1,
923               {
924                 "%2 = OpConstantTrue %1",
925                 "%3 = OpTypeFunction %1",
926                 "%4 = OpFunction %3 None %1",
927                 "%8 = OpPhi %1 %2 %7 %2 %6",
928               }
929             },
930             {2,
931               {
932                 // Only checking users
933                 "%8 = OpPhi %1 %2 %7 %2 %6",
934               }
935             },
936             {3, {"%4 = OpFunction %3 None %1"}},
937             {6, {"%8 = OpPhi %1 %2 %7 %2 %6"}},
938             {7,
939               {
940                 "OpBranch %7",
941                 "%8 = OpPhi %1 %2 %7 %2 %6",
942                 "OpBranch %7",
943               }
944             },
945             // {8, {"%8 = OpPhi %1 %8 %7 %2 %6"}},
946           },
947         },
948       },
949     })
950 );
951 // clang-format on
952 
953 struct KillDefCase {
954   const char* before;
955   std::vector<uint32_t> ids_to_kill;
956   const char* after;
957   InstDefUse du;
958 };
959 
960 using KillDefTest = ::testing::TestWithParam<KillDefCase>;
961 
TEST_P(KillDefTest,Case)962 TEST_P(KillDefTest, Case) {
963   const auto& tc = GetParam();
964 
965   // Build module.
966   const std::vector<const char*> text = {tc.before};
967   std::unique_ptr<IRContext> context =
968       BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, JoinAllInsts(text),
969                   SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
970   ASSERT_NE(nullptr, context);
971 
972   // Analyze def and use.
973   DefUseManager manager(context->module());
974 
975   // Do the substitution.
976   for (const auto id : tc.ids_to_kill) context->KillDef(id);
977 
978   EXPECT_EQ(tc.after, DisassembleModule(context->module()));
979   CheckDef(tc.du, context->get_def_use_mgr()->id_to_defs());
980   CheckUse(tc.du, context->get_def_use_mgr(), context->module()->IdBound());
981 }
982 
983 // clang-format off
984 INSTANTIATE_TEST_CASE_P(
985     TestCase, KillDefTest,
986     ::testing::ValuesIn(std::vector<KillDefCase>{
987       { // no def, no use, no kill
988         "", {}, "", {}
989       },
990       { // kill nothing
991         "%1 = OpTypeBool "
992         "%2 = OpTypeVector %1 2 "
993         "%3 = OpTypeVector %1 3 ",
994         {},
995         "%1 = OpTypeBool\n"
996         "%2 = OpTypeVector %1 2\n"
997         "%3 = OpTypeVector %1 3",
998         {
999           { // defs
1000             {1, "%1 = OpTypeBool"},
1001             {2, "%2 = OpTypeVector %1 2"},
1002             {3, "%3 = OpTypeVector %1 3"},
1003           },
1004           { // uses
1005             {1,
1006               {
1007                 "%2 = OpTypeVector %1 2",
1008                 "%3 = OpTypeVector %1 3",
1009               }
1010             },
1011           },
1012         },
1013       },
1014       { // kill id used, kill id not used, kill id not defined
1015         "%1 = OpTypeBool "
1016         "%2 = OpTypeVector %1 2 "
1017         "%3 = OpTypeVector %1 3 "
1018         "%4 = OpTypeVector %1 4 "
1019         "%5 = OpTypeMatrix %3 3 "
1020         "%6 = OpTypeMatrix %2 3",
1021         {1, 3, 5, 10}, // ids to kill
1022         "%2 = OpTypeVector %1 2\n"
1023         "%4 = OpTypeVector %1 4\n"
1024         "%6 = OpTypeMatrix %2 3",
1025         {
1026           { // defs
1027             {2, "%2 = OpTypeVector %1 2"},
1028             {4, "%4 = OpTypeVector %1 4"},
1029             {6, "%6 = OpTypeMatrix %2 3"},
1030           },
1031           { // uses. %1 and %3 are both killed, so no uses
1032             // recorded for them anymore.
1033             {2, {"%6 = OpTypeMatrix %2 3"}},
1034           }
1035         },
1036       },
1037       { // OpPhi.
1038         kOpPhiTestFunction,
1039         {9, 11}, // kill one id used by OpPhi, kill one id generated by OpPhi
1040          "%1 = OpTypeVoid\n"
1041          "%6 = OpTypeInt 32 0\n"
1042          "%10 = OpTypeFloat 32\n"
1043          "%16 = OpTypeBool\n"
1044          "%3 = OpTypeFunction %1\n"
1045          "%8 = OpConstant %6 0\n"
1046          "%18 = OpConstant %6 1\n"
1047          "%12 = OpConstant %10 1\n"
1048          "%2 = OpFunction %1 None %3\n"
1049          "%4 = OpLabel\n"
1050                "OpBranch %5\n"
1051 
1052          "%5 = OpLabel\n"
1053          "%7 = OpPhi %6 %8 %4 %9 %5\n"
1054         "%13 = OpFAdd %10 %11 %12\n"
1055         "%17 = OpSLessThan %16 %7 %18\n"
1056               "OpLoopMerge %19 %5 None\n"
1057               "OpBranchConditional %17 %5 %19\n"
1058 
1059         "%19 = OpLabel\n"
1060               "OpReturn\n"
1061               "OpFunctionEnd",
1062         {
1063           { // defs. %9 & %11 are killed.
1064             {1, "%1 = OpTypeVoid"},
1065             {2, "%2 = OpFunction %1 None %3"},
1066             {3, "%3 = OpTypeFunction %1"},
1067             {4, "%4 = OpLabel"},
1068             {5, "%5 = OpLabel"},
1069             {6, "%6 = OpTypeInt 32 0"},
1070             {7, "%7 = OpPhi %6 %8 %4 %9 %5"},
1071             {8, "%8 = OpConstant %6 0"},
1072             {10, "%10 = OpTypeFloat 32"},
1073             {12, "%12 = OpConstant %10 1.0"},
1074             {13, "%13 = OpFAdd %10 %11 %12"},
1075             {16, "%16 = OpTypeBool"},
1076             {17, "%17 = OpSLessThan %16 %7 %18"},
1077             {18, "%18 = OpConstant %6 1"},
1078             {19, "%19 = OpLabel"},
1079           },
1080           { // uses
1081             {1,
1082               {
1083                 "%2 = OpFunction %1 None %3",
1084                 "%3 = OpTypeFunction %1",
1085               }
1086             },
1087             {3, {"%2 = OpFunction %1 None %3"}},
1088             {4,
1089               {
1090                 "%7 = OpPhi %6 %8 %4 %9 %5",
1091                 // "%11 = OpPhi %10 %12 %4 %13 %5",
1092               }
1093             },
1094             {5,
1095               {
1096                 "OpBranch %5",
1097                 "%7 = OpPhi %6 %8 %4 %9 %5",
1098                 // "%11 = OpPhi %10 %12 %4 %13 %5",
1099                 "OpLoopMerge %19 %5 None",
1100                 "OpBranchConditional %17 %5 %19",
1101               }
1102             },
1103             {6,
1104               {
1105                 // Can't properly check constants
1106                 // "%8 = OpConstant %6 0",
1107                 // "%18 = OpConstant %6 1",
1108                 "%7 = OpPhi %6 %8 %4 %9 %5",
1109                 // "%9 = OpIAdd %6 %7 %8"
1110               }
1111             },
1112             {7, {"%17 = OpSLessThan %16 %7 %18"}},
1113             {8,
1114               {
1115                 "%7 = OpPhi %6 %8 %4 %9 %5",
1116                 // "%9 = OpIAdd %6 %7 %8",
1117               }
1118             },
1119             // {9, {"%7 = OpPhi %6 %8 %4 %13 %5"}},
1120             {10,
1121               {
1122                 // "%11 = OpPhi %10 %12 %4 %13 %5",
1123                 // "%12 = OpConstant %10 1",
1124                 "%13 = OpFAdd %10 %11 %12"
1125               }
1126             },
1127             // {11, {"%13 = OpFAdd %10 %11 %12"}},
1128             {12,
1129               {
1130                 // "%11 = OpPhi %10 %12 %4 %13 %5",
1131                 "%13 = OpFAdd %10 %11 %12"
1132               }
1133             },
1134             // {13, {"%11 = OpPhi %10 %12 %4 %13 %5"}},
1135             {16, {"%17 = OpSLessThan %16 %7 %18"}},
1136             {17, {"OpBranchConditional %17 %5 %19"}},
1137             {18, {"%17 = OpSLessThan %16 %7 %18"}},
1138             {19,
1139               {
1140                 "OpLoopMerge %19 %5 None",
1141                 "OpBranchConditional %17 %5 %19",
1142               }
1143             },
1144           },
1145         },
1146       },
1147       { // OpPhi defining and referencing the same id.
1148         "%1 = OpTypeBool "
1149         "%3 = OpTypeFunction %1 "
1150         "%2 = OpConstantTrue %1 "
1151         "%4 = OpFunction %3 None %1 "
1152         "%6 = OpLabel "
1153         "     OpBranch %7 "
1154         "%7 = OpLabel "
1155         "%8 = OpPhi %1   %8 %7   %2 %6 " // both defines and uses %8
1156         "     OpBranch %7 "
1157         "     OpFunctionEnd",
1158         {8},
1159         "%1 = OpTypeBool\n"
1160         "%3 = OpTypeFunction %1\n"
1161         "%2 = OpConstantTrue %1\n"
1162 
1163         "%4 = OpFunction %3 None %1\n"
1164         "%6 = OpLabel\n"
1165              "OpBranch %7\n"
1166         "%7 = OpLabel\n"
1167              "OpBranch %7\n"
1168              "OpFunctionEnd",
1169         {
1170           { // defs
1171             {1, "%1 = OpTypeBool"},
1172             {2, "%2 = OpConstantTrue %1"},
1173             {3, "%3 = OpTypeFunction %1"},
1174             {4, "%4 = OpFunction %3 None %1"},
1175             {6, "%6 = OpLabel"},
1176             {7, "%7 = OpLabel"},
1177             // {8, "%8 = OpPhi %1 %8 %7 %2 %6"},
1178           },
1179           { // uses
1180             {1,
1181               {
1182                 "%2 = OpConstantTrue %1",
1183                 "%3 = OpTypeFunction %1",
1184                 "%4 = OpFunction %3 None %1",
1185                 // "%8 = OpPhi %1 %8 %7 %2 %6",
1186               }
1187             },
1188             // {2, {"%8 = OpPhi %1 %8 %7 %2 %6"}},
1189             {3, {"%4 = OpFunction %3 None %1"}},
1190             // {6, {"%8 = OpPhi %1 %8 %7 %2 %6"}},
1191             {7,
1192               {
1193                 "OpBranch %7",
1194                 // "%8 = OpPhi %1 %8 %7 %2 %6",
1195                 "OpBranch %7",
1196               }
1197             },
1198             // {8, {"%8 = OpPhi %1 %8 %7 %2 %6"}},
1199           },
1200         },
1201       },
1202     })
1203 );
1204 // clang-format on
1205 
TEST(DefUseTest,OpSwitch)1206 TEST(DefUseTest, OpSwitch) {
1207   // Because disassembler has basic type check for OpSwitch's selector, we
1208   // cannot use the DisassembleInst() in the above. Thus, this special spotcheck
1209   // test case.
1210 
1211   const char original_text[] =
1212       // int64 f(int64 v) {
1213       //   switch (v) {
1214       //     case 1:                   break;
1215       //     case -4294967296:         break;
1216       //     case 9223372036854775807: break;
1217       //     default:                  break;
1218       //   }
1219       //   return v;
1220       // }
1221       " %1 = OpTypeInt 64 1 "
1222       " %3 = OpTypePointer Input %1 "
1223       " %2 = OpFunction %1 None %3 "  // %3 is int64(int64)*
1224       " %4 = OpFunctionParameter %1 "
1225       " %5 = OpLabel "
1226       " %6 = OpLoad %1 %4 "  // selector value
1227       "      OpSelectionMerge %7 None "
1228       "      OpSwitch %6 %8 "
1229       "                  1                    %9 "  // 1
1230       "                  -4294967296         %10 "  // -2^32
1231       "                  9223372036854775807 %11 "  // 2^63-1
1232       " %8 = OpLabel "                              // default
1233       "      OpBranch %7 "
1234       " %9 = OpLabel "
1235       "      OpBranch %7 "
1236       "%10 = OpLabel "
1237       "      OpBranch %7 "
1238       "%11 = OpLabel "
1239       "      OpBranch %7 "
1240       " %7 = OpLabel "
1241       "      OpReturnValue %6 "
1242       "      OpFunctionEnd";
1243 
1244   std::unique_ptr<IRContext> context =
1245       BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, original_text,
1246                   SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
1247   ASSERT_NE(nullptr, context);
1248 
1249   // Force a re-build of def-use manager.
1250   context->InvalidateAnalyses(IRContext::Analysis::kAnalysisDefUse);
1251   (void)context->get_def_use_mgr();
1252 
1253   // Do a bunch replacements.
1254   context->ReplaceAllUsesWith(11, 7);   // to existing id
1255   context->ReplaceAllUsesWith(10, 11);  // to existing id
1256   context->ReplaceAllUsesWith(9, 10);   // to existing id
1257 
1258   // clang-format off
1259   const char modified_text[] =
1260        "%1 = OpTypeInt 64 1\n"
1261        "%3 = OpTypePointer Input %1\n"
1262        "%2 = OpFunction %1 None %3\n" // %3 is int64(int64)*
1263        "%4 = OpFunctionParameter %1\n"
1264        "%5 = OpLabel\n"
1265        "%6 = OpLoad %1 %4\n" // selector value
1266             "OpSelectionMerge %7 None\n"
1267             "OpSwitch %6 %8 1 %10 -4294967296 %11 9223372036854775807 %7\n" // changed!
1268        "%8 = OpLabel\n"      // default
1269             "OpBranch %7\n"
1270        "%9 = OpLabel\n"
1271             "OpBranch %7\n"
1272       "%10 = OpLabel\n"
1273             "OpBranch %7\n"
1274       "%11 = OpLabel\n"
1275             "OpBranch %7\n"
1276        "%7 = OpLabel\n"
1277             "OpReturnValue %6\n"
1278             "OpFunctionEnd";
1279   // clang-format on
1280 
1281   EXPECT_EQ(modified_text, DisassembleModule(context->module()));
1282 
1283   InstDefUse def_uses = {};
1284   def_uses.defs = {
1285       {1, "%1 = OpTypeInt 64 1"},
1286       {2, "%2 = OpFunction %1 None %3"},
1287       {3, "%3 = OpTypePointer Input %1"},
1288       {4, "%4 = OpFunctionParameter %1"},
1289       {5, "%5 = OpLabel"},
1290       {6, "%6 = OpLoad %1 %4"},
1291       {7, "%7 = OpLabel"},
1292       {8, "%8 = OpLabel"},
1293       {9, "%9 = OpLabel"},
1294       {10, "%10 = OpLabel"},
1295       {11, "%11 = OpLabel"},
1296   };
1297   CheckDef(def_uses, context->get_def_use_mgr()->id_to_defs());
1298 
1299   {
1300     EXPECT_EQ(2u, NumUses(context, 6));
1301     std::vector<SpvOp> opcodes = GetUseOpcodes(context, 6u);
1302     EXPECT_THAT(opcodes, UnorderedElementsAre(SpvOpSwitch, SpvOpReturnValue));
1303   }
1304   {
1305     EXPECT_EQ(6u, NumUses(context, 7));
1306     std::vector<SpvOp> opcodes = GetUseOpcodes(context, 7u);
1307     // OpSwitch is now a user of %7.
1308     EXPECT_THAT(opcodes, UnorderedElementsAre(SpvOpSelectionMerge, SpvOpBranch,
1309                                               SpvOpBranch, SpvOpBranch,
1310                                               SpvOpBranch, SpvOpSwitch));
1311   }
1312   // Check all ids only used by OpSwitch after replacement.
1313   for (const auto id : {8u, 10u, 11u}) {
1314     EXPECT_EQ(1u, NumUses(context, id));
1315     EXPECT_EQ(SpvOpSwitch, GetUseOpcodes(context, id).back());
1316   }
1317 }
1318 
1319 // Test case for analyzing individual instructions.
1320 struct AnalyzeInstDefUseTestCase {
1321   const char* module_text;
1322   InstDefUse expected_define_use;
1323 };
1324 
1325 using AnalyzeInstDefUseTest =
1326     ::testing::TestWithParam<AnalyzeInstDefUseTestCase>;
1327 
1328 // Test the analyzing result for individual instructions.
TEST_P(AnalyzeInstDefUseTest,Case)1329 TEST_P(AnalyzeInstDefUseTest, Case) {
1330   auto tc = GetParam();
1331 
1332   // Build module.
1333   std::unique_ptr<IRContext> context =
1334       BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, tc.module_text);
1335   ASSERT_NE(nullptr, context);
1336 
1337   // Analyze the instructions.
1338   DefUseManager manager(context->module());
1339 
1340   CheckDef(tc.expected_define_use, manager.id_to_defs());
1341   CheckUse(tc.expected_define_use, &manager, context->module()->IdBound());
1342   // CheckUse(tc.expected_define_use, manager.id_to_uses());
1343 }
1344 
1345 // clang-format off
1346 INSTANTIATE_TEST_CASE_P(
1347     TestCase, AnalyzeInstDefUseTest,
1348     ::testing::ValuesIn(std::vector<AnalyzeInstDefUseTestCase>{
1349       { // A type declaring instruction.
1350         "%1 = OpTypeInt 32 1",
1351         {
1352           // defs
1353           {{1, "%1 = OpTypeInt 32 1"}},
1354           {}, // no uses
1355         },
1356       },
1357       { // A type declaring instruction and a constant value.
1358         "%1 = OpTypeBool "
1359         "%2 = OpConstantTrue %1",
1360         {
1361           { // defs
1362             {1, "%1 = OpTypeBool"},
1363             {2, "%2 = OpConstantTrue %1"},
1364           },
1365           { // uses
1366             {1, {"%2 = OpConstantTrue %1"}},
1367           },
1368         },
1369       },
1370       }));
1371 // clang-format on
1372 
1373 using AnalyzeInstDefUse = ::testing::Test;
1374 
TEST(AnalyzeInstDefUse,UseWithNoResultId)1375 TEST(AnalyzeInstDefUse, UseWithNoResultId) {
1376   IRContext context(SPV_ENV_UNIVERSAL_1_2, nullptr);
1377 
1378   // Analyze the instructions.
1379   DefUseManager manager(context.module());
1380 
1381   Instruction label(&context, SpvOpLabel, 0, 2, {});
1382   manager.AnalyzeInstDefUse(&label);
1383 
1384   Instruction branch(&context, SpvOpBranch, 0, 0, {{SPV_OPERAND_TYPE_ID, {2}}});
1385   manager.AnalyzeInstDefUse(&branch);
1386   context.module()->SetIdBound(3);
1387 
1388   InstDefUse expected = {
1389       // defs
1390       {
1391           {2, "%2 = OpLabel"},
1392       },
1393       // uses
1394       {{2, {"OpBranch %2"}}},
1395   };
1396 
1397   CheckDef(expected, manager.id_to_defs());
1398   CheckUse(expected, &manager, context.module()->IdBound());
1399 }
1400 
TEST(AnalyzeInstDefUse,AddNewInstruction)1401 TEST(AnalyzeInstDefUse, AddNewInstruction) {
1402   const std::string input = "%1 = OpTypeBool";
1403 
1404   // Build module.
1405   std::unique_ptr<IRContext> context =
1406       BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, input);
1407   ASSERT_NE(nullptr, context);
1408 
1409   // Analyze the instructions.
1410   DefUseManager manager(context->module());
1411 
1412   Instruction newInst(context.get(), SpvOpConstantTrue, 1, 2, {});
1413   manager.AnalyzeInstDefUse(&newInst);
1414 
1415   InstDefUse expected = {
1416       {
1417           // defs
1418           {1, "%1 = OpTypeBool"},
1419           {2, "%2 = OpConstantTrue %1"},
1420       },
1421       {
1422           // uses
1423           {1, {"%2 = OpConstantTrue %1"}},
1424       },
1425   };
1426 
1427   CheckDef(expected, manager.id_to_defs());
1428   CheckUse(expected, &manager, context->module()->IdBound());
1429 }
1430 
1431 struct KillInstTestCase {
1432   const char* before;
1433   std::unordered_set<uint32_t> indices_for_inst_to_kill;
1434   const char* after;
1435   InstDefUse expected_define_use;
1436 };
1437 
1438 using KillInstTest = ::testing::TestWithParam<KillInstTestCase>;
1439 
TEST_P(KillInstTest,Case)1440 TEST_P(KillInstTest, Case) {
1441   auto tc = GetParam();
1442 
1443   // Build module.
1444   std::unique_ptr<IRContext> context =
1445       BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, tc.before,
1446                   SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
1447   ASSERT_NE(nullptr, context);
1448 
1449   // Force a re-build of the def-use manager.
1450   context->InvalidateAnalyses(IRContext::Analysis::kAnalysisDefUse);
1451   (void)context->get_def_use_mgr();
1452 
1453   // KillInst
1454   context->module()->ForEachInst([&tc, &context](Instruction* inst) {
1455     if (tc.indices_for_inst_to_kill.count(inst->result_id())) {
1456       context->KillInst(inst);
1457     }
1458   });
1459 
1460   EXPECT_EQ(tc.after, DisassembleModule(context->module()));
1461   CheckDef(tc.expected_define_use, context->get_def_use_mgr()->id_to_defs());
1462   CheckUse(tc.expected_define_use, context->get_def_use_mgr(),
1463            context->module()->IdBound());
1464 }
1465 
1466 // clang-format off
1467 INSTANTIATE_TEST_CASE_P(
1468     TestCase, KillInstTest,
1469     ::testing::ValuesIn(std::vector<KillInstTestCase>{
1470       // Kill id defining instructions.
1471       {
1472         "%3 = OpTypeVoid "
1473         "%1 = OpTypeFunction %3 "
1474         "%2 = OpFunction %1 None %3 "
1475         "%4 = OpLabel "
1476         "     OpBranch %5 "
1477         "%5 = OpLabel "
1478         "     OpBranch %6 "
1479         "%6 = OpLabel "
1480         "     OpBranch %4 "
1481         "%7 = OpLabel "
1482         "     OpReturn "
1483         "     OpFunctionEnd",
1484         {3, 5, 7},
1485         "%1 = OpTypeFunction %3\n"
1486         "%2 = OpFunction %1 None %3\n"
1487         "%4 = OpLabel\n"
1488         "OpBranch %5\n"
1489         "OpNop\n"
1490         "OpBranch %6\n"
1491         "%6 = OpLabel\n"
1492         "OpBranch %4\n"
1493         "OpNop\n"
1494         "OpReturn\n"
1495         "OpFunctionEnd",
1496         {
1497           // defs
1498           {
1499             {1, "%1 = OpTypeFunction %3"},
1500             {2, "%2 = OpFunction %1 None %3"},
1501             {4, "%4 = OpLabel"},
1502             {6, "%6 = OpLabel"},
1503           },
1504           // uses
1505           {
1506             {1, {"%2 = OpFunction %1 None %3"}},
1507             {4, {"OpBranch %4"}},
1508             {6, {"OpBranch %6"}},
1509           }
1510         }
1511       },
1512       // Kill instructions that do not have result ids.
1513       {
1514         "%3 = OpTypeVoid "
1515         "%1 = OpTypeFunction %3 "
1516         "%2 = OpFunction %1 None %3 "
1517         "%4 = OpLabel "
1518         "     OpBranch %5 "
1519         "%5 = OpLabel "
1520         "     OpBranch %6 "
1521         "%6 = OpLabel "
1522         "     OpBranch %4 "
1523         "%7 = OpLabel "
1524         "     OpReturn "
1525         "     OpFunctionEnd",
1526         {2, 4},
1527         "%3 = OpTypeVoid\n"
1528         "%1 = OpTypeFunction %3\n"
1529              "OpNop\n"
1530              "OpNop\n"
1531              "OpBranch %5\n"
1532         "%5 = OpLabel\n"
1533              "OpBranch %6\n"
1534         "%6 = OpLabel\n"
1535              "OpBranch %4\n"
1536         "%7 = OpLabel\n"
1537              "OpReturn\n"
1538              "OpFunctionEnd",
1539         {
1540           // defs
1541           {
1542             {1, "%1 = OpTypeFunction %3"},
1543             {3, "%3 = OpTypeVoid"},
1544             {5, "%5 = OpLabel"},
1545             {6, "%6 = OpLabel"},
1546             {7, "%7 = OpLabel"},
1547           },
1548           // uses
1549           {
1550             {3, {"%1 = OpTypeFunction %3"}},
1551             {5, {"OpBranch %5"}},
1552             {6, {"OpBranch %6"}},
1553           }
1554         }
1555       },
1556       }));
1557 // clang-format on
1558 
1559 struct GetAnnotationsTestCase {
1560   const char* code;
1561   uint32_t id;
1562   std::vector<std::string> annotations;
1563 };
1564 
1565 using GetAnnotationsTest = ::testing::TestWithParam<GetAnnotationsTestCase>;
1566 
TEST_P(GetAnnotationsTest,Case)1567 TEST_P(GetAnnotationsTest, Case) {
1568   const GetAnnotationsTestCase& tc = GetParam();
1569 
1570   // Build module.
1571   std::unique_ptr<IRContext> context =
1572       BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, tc.code);
1573   ASSERT_NE(nullptr, context);
1574 
1575   // Get annotations
1576   DefUseManager manager(context->module());
1577   auto insts = manager.GetAnnotations(tc.id);
1578 
1579   // Check
1580   ASSERT_EQ(tc.annotations.size(), insts.size())
1581       << "wrong number of annotation instructions";
1582   auto inst_iter = insts.begin();
1583   for (const std::string& expected_anno_inst : tc.annotations) {
1584     EXPECT_EQ(expected_anno_inst, DisassembleInst(*inst_iter))
1585         << "annotation instruction mismatch";
1586     inst_iter++;
1587   }
1588 }
1589 
1590 // clang-format off
1591 INSTANTIATE_TEST_CASE_P(
1592     TestCase, GetAnnotationsTest,
1593     ::testing::ValuesIn(std::vector<GetAnnotationsTestCase>{
1594       // empty
1595       {"", 0, {}},
1596       // basic
1597       {
1598         // code
1599         "OpDecorate %1 Block "
1600         "OpDecorate %1 RelaxedPrecision "
1601         "%3 = OpTypeInt 32 0 "
1602         "%1 = OpTypeStruct %3",
1603         // id
1604         1,
1605         // annotations
1606         {
1607           "OpDecorate %1 Block",
1608           "OpDecorate %1 RelaxedPrecision",
1609         },
1610       },
1611       // with debug instructions
1612       {
1613         // code
1614         "OpName %1 \"struct_type\" "
1615         "OpName %3 \"int_type\" "
1616         "OpDecorate %1 Block "
1617         "OpDecorate %1 RelaxedPrecision "
1618         "%3 = OpTypeInt 32 0 "
1619         "%1 = OpTypeStruct %3",
1620         // id
1621         1,
1622         // annotations
1623         {
1624           "OpDecorate %1 Block",
1625           "OpDecorate %1 RelaxedPrecision",
1626         },
1627       },
1628       // no annotations
1629       {
1630         // code
1631         "OpName %1 \"struct_type\" "
1632         "OpName %3 \"int_type\" "
1633         "OpDecorate %1 Block "
1634         "OpDecorate %1 RelaxedPrecision "
1635         "%3 = OpTypeInt 32 0 "
1636         "%1 = OpTypeStruct %3",
1637         // id
1638         3,
1639         // annotations
1640         {},
1641       },
1642       // decoration group
1643       {
1644         // code
1645         "OpDecorate %1 Block "
1646         "OpDecorate %1 RelaxedPrecision "
1647         "%1 = OpDecorationGroup "
1648         "OpGroupDecorate %1 %2 %3 "
1649         "%4 = OpTypeInt 32 0 "
1650         "%2 = OpTypeStruct %4 "
1651         "%3 = OpTypeStruct %4 %4",
1652         // id
1653         3,
1654         // annotations
1655         {
1656           "OpGroupDecorate %1 %2 %3",
1657         },
1658       },
1659       // memeber decorate
1660       {
1661         // code
1662         "OpMemberDecorate %1 0 RelaxedPrecision "
1663         "%2 = OpTypeInt 32 0 "
1664         "%1 = OpTypeStruct %2 %2",
1665         // id
1666         1,
1667         // annotations
1668         {
1669           "OpMemberDecorate %1 0 RelaxedPrecision",
1670         },
1671       },
1672       }));
1673 
1674 using UpdateUsesTest = PassTest<::testing::Test>;
1675 
TEST_F(UpdateUsesTest,KeepOldUses)1676 TEST_F(UpdateUsesTest, KeepOldUses) {
1677   const std::vector<const char*> text = {
1678       // clang-format off
1679       "OpCapability Shader",
1680       "%1 = OpExtInstImport \"GLSL.std.450\"",
1681       "OpMemoryModel Logical GLSL450",
1682       "OpEntryPoint Vertex %main \"main\"",
1683       "OpName %main \"main\"",
1684       "%void = OpTypeVoid",
1685       "%4 = OpTypeFunction %void",
1686       "%uint = OpTypeInt 32 0",
1687       "%uint_5 = OpConstant %uint 5",
1688       "%25 = OpConstant %uint 25",
1689       "%main = OpFunction %void None %4",
1690       "%8 = OpLabel",
1691       "%9 = OpIMul %uint %uint_5 %uint_5",
1692       "%10 = OpIMul %uint %9 %uint_5",
1693       "OpReturn",
1694       "OpFunctionEnd"
1695       // clang-format on
1696   };
1697 
1698   std::unique_ptr<IRContext> context =
1699       BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, JoinAllInsts(text),
1700                   SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
1701   ASSERT_NE(nullptr, context);
1702 
1703   DefUseManager* def_use_mgr = context->get_def_use_mgr();
1704   Instruction* def = def_use_mgr->GetDef(9);
1705   Instruction* use = def_use_mgr->GetDef(10);
1706   def->SetOpcode(SpvOpCopyObject);
1707   def->SetInOperands({{SPV_OPERAND_TYPE_ID, {25}}});
1708   context->UpdateDefUse(def);
1709 
1710   auto users = def_use_mgr->id_to_users();
1711   UserEntry entry = {def, use};
1712   EXPECT_THAT(users, Contains(entry));
1713 }
1714 // clang-format on
1715 
1716 }  // namespace
1717 }  // namespace analysis
1718 }  // namespace opt
1719 }  // namespace spvtools
1720