1 //===- CodeExtractor.cpp - Unit tests for CodeExtractor -------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "llvm/Transforms/Utils/CodeExtractor.h"
10 #include "llvm/AsmParser/Parser.h"
11 #include "llvm/Analysis/AssumptionCache.h"
12 #include "llvm/IR/BasicBlock.h"
13 #include "llvm/IR/Dominators.h"
14 #include "llvm/IR/Instructions.h"
15 #include "llvm/IR/LLVMContext.h"
16 #include "llvm/IR/Module.h"
17 #include "llvm/IR/Verifier.h"
18 #include "llvm/IRReader/IRReader.h"
19 #include "llvm/Support/SourceMgr.h"
20 #include "gtest/gtest.h"
21 
22 using namespace llvm;
23 
24 namespace {
getBlockByName(Function * F,StringRef name)25 BasicBlock *getBlockByName(Function *F, StringRef name) {
26   for (auto &BB : *F)
27     if (BB.getName() == name)
28       return &BB;
29   return nullptr;
30 }
31 
/// Extract a diamond (header plus both arms) whose two arms feed a PHI in a
/// block outside the region. The exit PHI is expected to be split: the
/// outlined function gets a two-input PHI in "notExtracted.split", while the
/// original exit block keeps a single incoming value from the call site.
TEST(CodeExtractor, ExitStub) {
  LLVMContext Ctx;
  SMDiagnostic Err;
  std::unique_ptr<Module> M(parseAssemblyString(R"invalid(
    define i32 @foo(i32 %x, i32 %y, i32 %z) {
    header:
      %0 = icmp ugt i32 %x, %y
      br i1 %0, label %body1, label %body2

    body1:
      %1 = add i32 %z, 2
      br label %notExtracted

    body2:
      %2 = mul i32 %z, 7
      br label %notExtracted

    notExtracted:
      %3 = phi i32 [ %1, %body1 ], [ %2, %body2 ]
      %4 = add i32 %3, %x
      ret i32 %4
    }
  )invalid",
                                                Err, Ctx));
  // Fail this test (instead of dereferencing null) if the IR does not parse.
  ASSERT_TRUE(M) << "Could not parse module: " << Err.getMessage().str();

  Function *Func = M->getFunction("foo");
  ASSERT_TRUE(Func);
  SmallVector<BasicBlock *, 3> Candidates{ getBlockByName(Func, "header"),
                                           getBlockByName(Func, "body1"),
                                           getBlockByName(Func, "body2") };

  CodeExtractor CE(Candidates);
  EXPECT_TRUE(CE.isEligible());

  CodeExtractorAnalysisCache CEAC(*Func);
  Function *Outlined = CE.extractCodeRegion(CEAC);
  // Everything below dereferences Outlined, so a failed extraction is fatal.
  ASSERT_TRUE(Outlined);
  BasicBlock *Exit = getBlockByName(Func, "notExtracted");
  BasicBlock *ExitSplit = getBlockByName(Outlined, "notExtracted.split");

  // Ensure that the PHI in the exit block has only one incoming value (from
  // the code replacer block).
  ASSERT_TRUE(Exit);
  ASSERT_TRUE(isa<PHINode>(Exit->front()));
  EXPECT_EQ(cast<PHINode>(Exit->front()).getNumIncomingValues(), 1u);

  // Ensure that there is a PHI in the outlined function with 2 incoming
  // values.
  ASSERT_TRUE(ExitSplit);
  ASSERT_TRUE(isa<PHINode>(ExitSplit->front()));
  EXPECT_EQ(cast<PHINode>(ExitSplit->front()).getNumIncomingValues(), 2u);

  EXPECT_FALSE(verifyFunction(*Outlined, &errs()));
  EXPECT_FALSE(verifyFunction(*Func, &errs()));
}
79 
/// Extract a region where each exit PHI has exactly one incoming edge from
/// the region. Such PHIs must not be split, so each exit PHI keeps its two
/// incoming values (one from the extracted region's replacement, one from
/// %pred).
TEST(CodeExtractor, ExitPHIOnePredFromRegion) {
  LLVMContext Ctx;
  SMDiagnostic Err;
  std::unique_ptr<Module> M(parseAssemblyString(R"invalid(
    define i32 @foo() {
    header:
      br i1 undef, label %extracted1, label %pred

    pred:
      br i1 undef, label %exit1, label %exit2

    extracted1:
      br i1 undef, label %extracted2, label %exit1

    extracted2:
      br label %exit2

    exit1:
      %0 = phi i32 [ 1, %extracted1 ], [ 2, %pred ]
      ret i32 %0

    exit2:
      %1 = phi i32 [ 3, %extracted2 ], [ 4, %pred ]
      ret i32 %1
    }
  )invalid", Err, Ctx));
  // Fail this test (instead of dereferencing null) if the IR does not parse.
  ASSERT_TRUE(M) << "Could not parse module: " << Err.getMessage().str();

  Function *Func = M->getFunction("foo");
  ASSERT_TRUE(Func);
  SmallVector<BasicBlock *, 2> ExtractedBlocks{
    getBlockByName(Func, "extracted1"),
    getBlockByName(Func, "extracted2")
  };

  CodeExtractor CE(ExtractedBlocks);
  EXPECT_TRUE(CE.isEligible());

  CodeExtractorAnalysisCache CEAC(*Func);
  Function *Outlined = CE.extractCodeRegion(CEAC);
  // Everything below dereferences Outlined, so a failed extraction is fatal.
  ASSERT_TRUE(Outlined);
  BasicBlock *Exit1 = getBlockByName(Func, "exit1");
  BasicBlock *Exit2 = getBlockByName(Func, "exit2");

  // Ensure that the PHIs in the exits are not split (since they have only one
  // incoming value from the extracted region).
  ASSERT_TRUE(Exit1);
  ASSERT_TRUE(isa<PHINode>(Exit1->front()));
  EXPECT_EQ(cast<PHINode>(Exit1->front()).getNumIncomingValues(), 2u);
  ASSERT_TRUE(Exit2);
  ASSERT_TRUE(isa<PHINode>(Exit2->front()));
  EXPECT_EQ(cast<PHINode>(Exit2->front()).getNumIncomingValues(), 2u);

  EXPECT_FALSE(verifyFunction(*Outlined, &errs()));
  EXPECT_FALSE(verifyFunction(*Func, &errs()));
}
130 
/// Extract a region containing invokes whose shared landing pad (%lpad2)
/// starts with a PHI and feeds a PHI outside the region. Both the outlined
/// function and the remaining caller must pass the IR verifier.
TEST(CodeExtractor, StoreOutputInvokeResultAfterEHPad) {
  LLVMContext Ctx;
  SMDiagnostic Err;
  std::unique_ptr<Module> M(parseAssemblyString(R"invalid(
    declare i8 @hoge()

    define i32 @foo() personality i8* null {
      entry:
        %call = invoke i8 @hoge()
                to label %invoke.cont unwind label %lpad

      invoke.cont:                                      ; preds = %entry
        unreachable

      lpad:                                             ; preds = %entry
        %0 = landingpad { i8*, i32 }
                catch i8* null
        br i1 undef, label %catch, label %finally.catchall

      catch:                                            ; preds = %lpad
        %call2 = invoke i8 @hoge()
                to label %invoke.cont2 unwind label %lpad2

      invoke.cont2:                                    ; preds = %catch
        %call3 = invoke i8 @hoge()
                to label %invoke.cont3 unwind label %lpad2

      invoke.cont3:                                    ; preds = %invoke.cont2
        unreachable

      lpad2:                                           ; preds = %invoke.cont2, %catch
        %ex.1 = phi i8* [ undef, %invoke.cont2 ], [ null, %catch ]
        %1 = landingpad { i8*, i32 }
                catch i8* null
        br label %finally.catchall

      finally.catchall:                                 ; preds = %lpad33, %lpad
        %ex.2 = phi i8* [ %ex.1, %lpad2 ], [ null, %lpad ]
        unreachable
    }
  )invalid", Err, Ctx));

  // Report the parse error and fail only this test; calling exit() here would
  // terminate the whole unit-test binary and skip every remaining test.
  if (!M)
    Err.print("unit", errs());
  ASSERT_TRUE(M) << "Could not parse module";

  Function *Func = M->getFunction("foo");
  ASSERT_TRUE(Func);
  // Extracting from IR that is already invalid would make the checks below
  // meaningless, so treat a verifier failure here as fatal.
  ASSERT_FALSE(verifyFunction(*Func, &errs()));

  SmallVector<BasicBlock *, 4> ExtractedBlocks{
    getBlockByName(Func, "catch"),
    getBlockByName(Func, "invoke.cont2"),
    getBlockByName(Func, "invoke.cont3"),
    getBlockByName(Func, "lpad2")
  };

  CodeExtractor CE(ExtractedBlocks);
  EXPECT_TRUE(CE.isEligible());

  CodeExtractorAnalysisCache CEAC(*Func);
  Function *Outlined = CE.extractCodeRegion(CEAC);
  // verifyFunction dereferences Outlined, so a failed extraction is fatal.
  ASSERT_TRUE(Outlined);
  EXPECT_FALSE(verifyFunction(*Outlined, &errs()));
  EXPECT_FALSE(verifyFunction(*Func, &errs()));
}
197 
/// Extract the entry block (which ends in an invoke) together with its
/// landing pad, while the invoke's result is used outside the region in
/// %exit. Both resulting functions must pass the IR verifier.
TEST(CodeExtractor, StoreOutputInvokeResultInExitStub) {
  LLVMContext Ctx;
  SMDiagnostic Err;
  std::unique_ptr<Module> M(parseAssemblyString(R"invalid(
    declare i32 @bar()

    define i32 @foo() personality i8* null {
    entry:
      %0 = invoke i32 @bar() to label %exit unwind label %lpad

    exit:
      ret i32 %0

    lpad:
      %1 = landingpad { i8*, i32 }
              cleanup
      resume { i8*, i32 } %1
    }
  )invalid",
                                                Err, Ctx));
  // Fail this test (instead of dereferencing null) if the IR does not parse.
  ASSERT_TRUE(M) << "Could not parse module: " << Err.getMessage().str();

  Function *Func = M->getFunction("foo");
  ASSERT_TRUE(Func);
  SmallVector<BasicBlock *, 2> Blocks{ getBlockByName(Func, "entry"),
                                       getBlockByName(Func, "lpad") };

  CodeExtractor CE(Blocks);
  EXPECT_TRUE(CE.isEligible());

  CodeExtractorAnalysisCache CEAC(*Func);
  Function *Outlined = CE.extractCodeRegion(CEAC);
  // verifyFunction dereferences Outlined, so a failed extraction is fatal.
  ASSERT_TRUE(Outlined);
  EXPECT_FALSE(verifyFunction(*Outlined, &errs()));
  EXPECT_FALSE(verifyFunction(*Func, &errs()));
}
232 
/// Extract a block that contains an llvm.assume call while an AssumptionCache
/// for the parent function is supplied to the extractor, then check that the
/// cache is still consistent with both the caller and the outlined function.
TEST(CodeExtractor, ExtractAndInvalidateAssumptionCache) {
  LLVMContext Ctx;
  SMDiagnostic Err;
  std::unique_ptr<Module> M(parseAssemblyString(R"ir(
        target datalayout = "e-m:e-i8:8:32-i16:16:32-i64:64-i128:128-n32:64-S128"
        target triple = "aarch64"

        %b = type { i64 }
        declare void @g(i8*)

        declare void @llvm.assume(i1) #0

        define void @test() {
        entry:
          br label %label

        label:
          %0 = load %b*, %b** inttoptr (i64 8 to %b**), align 8
          %1 = getelementptr inbounds %b, %b* %0, i64 undef, i32 0
          %2 = load i64, i64* %1, align 8
          %3 = icmp ugt i64 %2, 1
          br i1 %3, label %if.then, label %if.else

        if.then:
          unreachable

        if.else:
          call void @g(i8* undef)
          store i64 undef, i64* null, align 536870912
          %4 = icmp eq i64 %2, 0
          call void @llvm.assume(i1 %4)
          unreachable
        }

        attributes #0 = { nounwind willreturn }
  )ir",
                                                Err, Ctx));

  // A plain assert() is compiled out in NDEBUG builds, which would leave a
  // null dereference below; a gtest assertion always fires.
  ASSERT_TRUE(M) << "Could not parse module: " << Err.getMessage().str();
  Function *Func = M->getFunction("test");
  ASSERT_TRUE(Func);
  SmallVector<BasicBlock *, 1> Blocks{ getBlockByName(Func, "if.else") };
  AssumptionCache AC(*Func);
  CodeExtractor CE(Blocks, nullptr, false, nullptr, nullptr, &AC);
  EXPECT_TRUE(CE.isEligible());

  CodeExtractorAnalysisCache CEAC(*Func);
  Function *Outlined = CE.extractCodeRegion(CEAC);
  // The checks below dereference Outlined, so a failed extraction is fatal.
  ASSERT_TRUE(Outlined);
  EXPECT_FALSE(verifyFunction(*Outlined, &errs()));
  EXPECT_FALSE(verifyFunction(*Func, &errs()));
  EXPECT_FALSE(CE.verifyAssumptionCache(*Func, *Outlined, &AC));
}
285 
/// Extract a block that defines a bitcast of an outer alloca and uses it in a
/// lifetime.start marker, while the matching lifetime.end outside the region
/// also uses that bitcast. The extractor is expected to report no outputs for
/// the region, and both resulting functions must pass the IR verifier.
TEST(CodeExtractor, RemoveBitcastUsesFromOuterLifetimeMarkers) {
  LLVMContext Ctx;
  SMDiagnostic Err;
  std::unique_ptr<Module> M(parseAssemblyString(R"ir(
    target datalayout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128"
    target triple = "x86_64-unknown-linux-gnu"

    declare void @use(i32*)
    declare void @llvm.lifetime.start.p0i8(i64, i8*)
    declare void @llvm.lifetime.end.p0i8(i64, i8*)

    define void @foo() {
    entry:
      %0 = alloca i32
      br label %extract

    extract:
      %1 = bitcast i32* %0 to i8*
      call void @llvm.lifetime.start.p0i8(i64 4, i8* %1)
      call void @use(i32* %0)
      br label %exit

    exit:
      call void @use(i32* %0)
      call void @llvm.lifetime.end.p0i8(i64 4, i8* %1)
      ret void
    }
  )ir",
                                                Err, Ctx));
  // Fail this test (instead of dereferencing null) if the IR does not parse.
  ASSERT_TRUE(M) << "Could not parse module: " << Err.getMessage().str();

  Function *Func = M->getFunction("foo");
  ASSERT_TRUE(Func);
  SmallVector<BasicBlock *, 1> Blocks{getBlockByName(Func, "extract")};

  CodeExtractor CE(Blocks);
  EXPECT_TRUE(CE.isEligible());

  CodeExtractorAnalysisCache CEAC(*Func);
  SetVector<Value *> Inputs, Outputs, SinkingCands, HoistingCands;
  BasicBlock *CommonExit = nullptr;
  CE.findAllocas(CEAC, SinkingCands, HoistingCands, CommonExit);
  CE.findInputsOutputs(Inputs, Outputs, SinkingCands);
  EXPECT_EQ(Outputs.size(), 0U);

  Function *Outlined = CE.extractCodeRegion(CEAC);
  // verifyFunction dereferences Outlined, so a failed extraction is fatal.
  ASSERT_TRUE(Outlined);
  EXPECT_FALSE(verifyFunction(*Outlined, &errs()));
  EXPECT_FALSE(verifyFunction(*Func, &errs()));
}
334 } // end anonymous namespace
335