1 //=======- CaptureTrackingTest.cpp - Unit test for the Capture Tracking ---===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "llvm/Analysis/CaptureTracking.h"
10 #include "llvm/AsmParser/Parser.h"
11 #include "llvm/IR/Dominators.h"
12 #include "llvm/IR/Instructions.h"
13 #include "llvm/IR/LLVMContext.h"
14 #include "llvm/IR/Module.h"
15 #include "llvm/Support/SourceMgr.h"
16 #include "gtest/gtest.h"
17 
18 using namespace llvm;
19 
/// Verifies that the use-count limit passed to PointerMayBeCaptured and
/// PointerMayBeCapturedBefore changes the result: the same argument must be
/// reported as "not captured" under one limit and "may be captured" under a
/// smaller one.
TEST(CaptureTracking, MaxUsesToExplore) {
  // Every parameter of @doesnt_capture is marked nocapture, so each use of
  // %arg in the two test functions is a non-capturing call argument.
  StringRef Assembly = R"(
    ; Function Attrs: nounwind ssp uwtable
    declare void @doesnt_capture(i8* nocapture, i8* nocapture, i8* nocapture,
                                 i8* nocapture, i8* nocapture)

    ; %arg has 5 uses
    define void @test_few_uses(i8* %arg) {
      call void @doesnt_capture(i8* %arg, i8* %arg, i8* %arg, i8* %arg, i8* %arg)
      ret void
    }

    ; %arg has 50 uses
    define void @test_many_uses(i8* %arg) {
      call void @doesnt_capture(i8* %arg, i8* %arg, i8* %arg, i8* %arg, i8* %arg)
      call void @doesnt_capture(i8* %arg, i8* %arg, i8* %arg, i8* %arg, i8* %arg)
      call void @doesnt_capture(i8* %arg, i8* %arg, i8* %arg, i8* %arg, i8* %arg)
      call void @doesnt_capture(i8* %arg, i8* %arg, i8* %arg, i8* %arg, i8* %arg)
      call void @doesnt_capture(i8* %arg, i8* %arg, i8* %arg, i8* %arg, i8* %arg)
      call void @doesnt_capture(i8* %arg, i8* %arg, i8* %arg, i8* %arg, i8* %arg)
      call void @doesnt_capture(i8* %arg, i8* %arg, i8* %arg, i8* %arg, i8* %arg)
      call void @doesnt_capture(i8* %arg, i8* %arg, i8* %arg, i8* %arg, i8* %arg)
      call void @doesnt_capture(i8* %arg, i8* %arg, i8* %arg, i8* %arg, i8* %arg)
      call void @doesnt_capture(i8* %arg, i8* %arg, i8* %arg, i8* %arg, i8* %arg)
      ret void
    }
  )";

  LLVMContext Context;
  SMDiagnostic Error;
  auto M = parseAssemblyString(Assembly, Error, Context);
  ASSERT_TRUE(M) << "Bad assembly?";

  // Runs both capture queries on the first argument of function FName.
  //   FalseMaxUsesLimit - limit under which both queries must return false.
  //   TrueMaxUsesLimit  - limit under which both queries must return true.
  // The ASSERT_* macros only return from this lambda, not from the test body,
  // so a failure for one function does not prevent the next call from running.
  auto Test = [&M](const char *FName, unsigned FalseMaxUsesLimit,
                   unsigned TrueMaxUsesLimit) {
    // Tag every failure below with the function under test.
    SCOPED_TRACE(FName);

    Function *F = M->getFunction(FName);
    ASSERT_NE(F, nullptr);
    // Check that there is an argument before dereferencing arg_begin(); the
    // address of a dereferenced iterator can never compare equal to nullptr,
    // so a null check on the result would not guard anything.
    ASSERT_FALSE(F->arg_empty());
    Value *Arg = &*F->arg_begin();

    ASSERT_FALSE(PointerMayBeCaptured(Arg, true, true, FalseMaxUsesLimit));
    ASSERT_TRUE(PointerMayBeCaptured(Arg, true, true, TrueMaxUsesLimit));

    BasicBlock *EntryBB = &F->getEntryBlock();
    DominatorTree DT(*F);

    // Both test functions consist of a single block ending in `ret void`, so
    // querying "captured before the terminator" covers every call in it.
    Instruction *Ret = EntryBB->getTerminator();
    ASSERT_TRUE(isa<ReturnInst>(Ret));
    ASSERT_FALSE(PointerMayBeCapturedBefore(Arg, true, true, Ret, &DT, false,
                                            FalseMaxUsesLimit));
    ASSERT_TRUE(PointerMayBeCapturedBefore(Arg, true, true, Ret, &DT, false,
                                           TrueMaxUsesLimit));
  };

  // 5 uses: limit 6 must yield "not captured", limit 4 "may be captured".
  Test("test_few_uses", 6, 4);
  // 50 uses: limit 50 must yield "not captured", limit 30 "may be captured".
  Test("test_many_uses", 50, 30);
}
76 
77 struct CollectingCaptureTracker : public CaptureTracker {
78   SmallVector<const Use *, 4> Captures;
tooManyUsesCollectingCaptureTracker79   void tooManyUses() override { }
capturedCollectingCaptureTracker80   bool captured(const Use *U) override {
81     Captures.push_back(U);
82     return false;
83   }
84 };
85 
/// Verifies that when one instruction uses the tracked pointer through several
/// operands, each such operand is reported to the CaptureTracker as its own
/// Use (identified by user instruction and operand number).
TEST(CaptureTracking, MultipleUsesInSameInstruction) {
  // %arg appears as: three call arguments (the second one nocapture) plus an
  // operand-bundle input, both value operands of a cmpxchg, and both operands
  // of an icmp.
  StringRef Assembly = R"(
    declare void @call(i8*, i8*, i8*)

    define void @test(i8* %arg, i8** %ptr) {
      call void @call(i8* %arg, i8* nocapture %arg, i8* %arg) [ "bundle"(i8* %arg) ]
      cmpxchg i8** %ptr, i8* %arg, i8* %arg acq_rel monotonic
      icmp eq i8* %arg, %arg
      ret void
    }
  )";

  LLVMContext Context;
  SMDiagnostic Error;
  auto M = parseAssemblyString(Assembly, Error, Context);
  ASSERT_TRUE(M) << "Bad assembly?";

  Function *F = M->getFunction("test");
  // Fail cleanly instead of dereferencing a null function below.
  ASSERT_NE(F, nullptr);
  Value *Arg = &*F->arg_begin();
  BasicBlock *BB = &F->getEntryBlock();
  // The three instructions of interest are the first three in the entry
  // block, in the same order as in the assembly above.
  Instruction *Call = &*BB->begin();
  Instruction *CmpXChg = Call->getNextNode();
  Instruction *ICmp = CmpXChg->getNextNode();

  CollectingCaptureTracker CT;
  PointerMayBeCaptured(Arg, &CT);
  // This must be fatal: the checks below index Captures[0..6] unconditionally,
  // so continuing with a different element count would read out of bounds.
  ASSERT_EQ(7u, CT.Captures.size());
  // Call arg 1
  EXPECT_EQ(Call, CT.Captures[0]->getUser());
  EXPECT_EQ(0u, CT.Captures[0]->getOperandNo());
  // Call arg 3 (call arg 2, operand 1, is nocapture and is expected to be
  // absent from the list)
  EXPECT_EQ(Call, CT.Captures[1]->getUser());
  EXPECT_EQ(2u, CT.Captures[1]->getOperandNo());
  // Operand bundle arg
  EXPECT_EQ(Call, CT.Captures[2]->getUser());
  EXPECT_EQ(3u, CT.Captures[2]->getOperandNo());
  // Cmpxchg compare operand (operand 0 is %ptr, not the tracked value)
  EXPECT_EQ(CmpXChg, CT.Captures[3]->getUser());
  EXPECT_EQ(1u, CT.Captures[3]->getOperandNo());
  // Cmpxchg new value operand
  EXPECT_EQ(CmpXChg, CT.Captures[4]->getUser());
  EXPECT_EQ(2u, CT.Captures[4]->getOperandNo());
  // ICmp first operand
  EXPECT_EQ(ICmp, CT.Captures[5]->getUser());
  EXPECT_EQ(0u, CT.Captures[5]->getOperandNo());
  // ICmp second operand
  EXPECT_EQ(ICmp, CT.Captures[6]->getUser());
  EXPECT_EQ(1u, CT.Captures[6]->getOperandNo());
}
135