1 //===- LowerTypeTests.cpp - type metadata lowering pass -------------------===//
2 //
3 //                     The LLVM Compiler Infrastructure
4 //
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
7 //
8 //===----------------------------------------------------------------------===//
9 //
10 // This pass lowers type metadata and calls to the llvm.type.test intrinsic.
11 // It also ensures that globals are properly laid out for the
12 // llvm.icall.branch.funnel intrinsic.
13 // See http://llvm.org/docs/TypeMetadata.html for more information.
14 //
15 //===----------------------------------------------------------------------===//
16 
17 #include "llvm/Transforms/IPO/LowerTypeTests.h"
18 #include "llvm/ADT/APInt.h"
19 #include "llvm/ADT/ArrayRef.h"
20 #include "llvm/ADT/DenseMap.h"
21 #include "llvm/ADT/EquivalenceClasses.h"
22 #include "llvm/ADT/PointerUnion.h"
23 #include "llvm/ADT/SetVector.h"
24 #include "llvm/ADT/SmallVector.h"
25 #include "llvm/ADT/Statistic.h"
26 #include "llvm/ADT/StringRef.h"
27 #include "llvm/ADT/TinyPtrVector.h"
28 #include "llvm/ADT/Triple.h"
29 #include "llvm/Analysis/TypeMetadataUtils.h"
30 #include "llvm/Analysis/ValueTracking.h"
31 #include "llvm/IR/Attributes.h"
32 #include "llvm/IR/BasicBlock.h"
33 #include "llvm/IR/Constant.h"
34 #include "llvm/IR/Constants.h"
35 #include "llvm/IR/DataLayout.h"
36 #include "llvm/IR/DerivedTypes.h"
37 #include "llvm/IR/Function.h"
38 #include "llvm/IR/GlobalAlias.h"
39 #include "llvm/IR/GlobalObject.h"
40 #include "llvm/IR/GlobalValue.h"
41 #include "llvm/IR/GlobalVariable.h"
42 #include "llvm/IR/IRBuilder.h"
43 #include "llvm/IR/InlineAsm.h"
44 #include "llvm/IR/Instruction.h"
45 #include "llvm/IR/Instructions.h"
46 #include "llvm/IR/Intrinsics.h"
47 #include "llvm/IR/LLVMContext.h"
48 #include "llvm/IR/Metadata.h"
49 #include "llvm/IR/Module.h"
50 #include "llvm/IR/ModuleSummaryIndex.h"
51 #include "llvm/IR/ModuleSummaryIndexYAML.h"
52 #include "llvm/IR/Operator.h"
53 #include "llvm/IR/PassManager.h"
54 #include "llvm/IR/Type.h"
55 #include "llvm/IR/Use.h"
56 #include "llvm/IR/User.h"
57 #include "llvm/IR/Value.h"
58 #include "llvm/Pass.h"
59 #include "llvm/Support/Allocator.h"
60 #include "llvm/Support/Casting.h"
61 #include "llvm/Support/CommandLine.h"
62 #include "llvm/Support/Debug.h"
63 #include "llvm/Support/Error.h"
64 #include "llvm/Support/ErrorHandling.h"
65 #include "llvm/Support/FileSystem.h"
66 #include "llvm/Support/MathExtras.h"
67 #include "llvm/Support/MemoryBuffer.h"
68 #include "llvm/Support/TrailingObjects.h"
69 #include "llvm/Support/YAMLTraits.h"
70 #include "llvm/Support/raw_ostream.h"
71 #include "llvm/Transforms/IPO.h"
72 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
73 #include "llvm/Transforms/Utils/ModuleUtils.h"
74 #include <algorithm>
75 #include <cassert>
76 #include <cstdint>
77 #include <memory>
78 #include <set>
79 #include <string>
80 #include <system_error>
81 #include <utility>
82 #include <vector>
83 
84 using namespace llvm;
85 using namespace lowertypetests;
86 
87 #define DEBUG_TYPE "lowertypetests"
88 
// Statistics reported by this pass.
STATISTIC(ByteArraySizeBits, "Byte array size in bits");
STATISTIC(ByteArraySizeBytes, "Byte array size in bytes");
STATISTIC(NumByteArraysCreated, "Number of byte arrays created");
STATISTIC(NumTypeTestCallsLowered, "Number of type test calls lowered");
STATISTIC(NumTypeIdDisjointSets, "Number of disjoint sets of type identifiers");

// When true (the default), each use of a byte array in a lowered type test
// goes through its own private alias, making the backend less likely to reuse
// a previously computed byte array address. It is not applied when importing
// a summary, because the byte array is then an external symbol.
static cl::opt<bool> AvoidReuse(
    "lowertypetests-avoid-reuse",
    cl::desc("Try to avoid reuse of byte array addresses using aliases"),
    cl::Hidden, cl::init(true));

// The following hidden options select the summary action and the YAML summary
// files used when the pass is driven from the command line. That mode is
// documented as being for testing purposes only.
static cl::opt<PassSummaryAction> ClSummaryAction(
    "lowertypetests-summary-action",
    cl::desc("What to do with the summary when running this pass"),
    cl::values(clEnumValN(PassSummaryAction::None, "none", "Do nothing"),
               clEnumValN(PassSummaryAction::Import, "import",
                          "Import typeid resolutions from summary and globals"),
               clEnumValN(PassSummaryAction::Export, "export",
                          "Export typeid resolutions to summary and globals")),
    cl::Hidden);

// Optional YAML summary to read before running the pass.
static cl::opt<std::string> ClReadSummary(
    "lowertypetests-read-summary",
    cl::desc("Read summary from given YAML file before running pass"),
    cl::Hidden);

// Optional YAML file to write the summary to after running the pass.
static cl::opt<std::string> ClWriteSummary(
    "lowertypetests-write-summary",
    cl::desc("Write summary to given YAML file after running pass"),
    cl::Hidden);
119 
containsGlobalOffset(uint64_t Offset) const120 bool BitSetInfo::containsGlobalOffset(uint64_t Offset) const {
121   if (Offset < ByteOffset)
122     return false;
123 
124   if ((Offset - ByteOffset) % (uint64_t(1) << AlignLog2) != 0)
125     return false;
126 
127   uint64_t BitOffset = (Offset - ByteOffset) >> AlignLog2;
128   if (BitOffset >= BitSize)
129     return false;
130 
131   return Bits.count(BitOffset);
132 }
133 
print(raw_ostream & OS) const134 void BitSetInfo::print(raw_ostream &OS) const {
135   OS << "offset " << ByteOffset << " size " << BitSize << " align "
136      << (1 << AlignLog2);
137 
138   if (isAllOnes()) {
139     OS << " all-ones\n";
140     return;
141   }
142 
143   OS << " { ";
144   for (uint64_t B : Bits)
145     OS << B << ' ';
146   OS << "}\n";
147 }
148 
build()149 BitSetInfo BitSetBuilder::build() {
150   if (Min > Max)
151     Min = 0;
152 
153   // Normalize each offset against the minimum observed offset, and compute
154   // the bitwise OR of each of the offsets. The number of trailing zeros
155   // in the mask gives us the log2 of the alignment of all offsets, which
156   // allows us to compress the bitset by only storing one bit per aligned
157   // address.
158   uint64_t Mask = 0;
159   for (uint64_t &Offset : Offsets) {
160     Offset -= Min;
161     Mask |= Offset;
162   }
163 
164   BitSetInfo BSI;
165   BSI.ByteOffset = Min;
166 
167   BSI.AlignLog2 = 0;
168   if (Mask != 0)
169     BSI.AlignLog2 = countTrailingZeros(Mask, ZB_Undefined);
170 
171   // Build the compressed bitset while normalizing the offsets against the
172   // computed alignment.
173   BSI.BitSize = ((Max - Min) >> BSI.AlignLog2) + 1;
174   for (uint64_t Offset : Offsets) {
175     Offset >>= BSI.AlignLog2;
176     BSI.Bits.insert(Offset);
177   }
178 
179   return BSI;
180 }
181 
addFragment(const std::set<uint64_t> & F)182 void GlobalLayoutBuilder::addFragment(const std::set<uint64_t> &F) {
183   // Create a new fragment to hold the layout for F.
184   Fragments.emplace_back();
185   std::vector<uint64_t> &Fragment = Fragments.back();
186   uint64_t FragmentIndex = Fragments.size() - 1;
187 
188   for (auto ObjIndex : F) {
189     uint64_t OldFragmentIndex = FragmentMap[ObjIndex];
190     if (OldFragmentIndex == 0) {
191       // We haven't seen this object index before, so just add it to the current
192       // fragment.
193       Fragment.push_back(ObjIndex);
194     } else {
195       // This index belongs to an existing fragment. Copy the elements of the
196       // old fragment into this one and clear the old fragment. We don't update
197       // the fragment map just yet, this ensures that any further references to
198       // indices from the old fragment in this fragment do not insert any more
199       // indices.
200       std::vector<uint64_t> &OldFragment = Fragments[OldFragmentIndex];
201       Fragment.insert(Fragment.end(), OldFragment.begin(), OldFragment.end());
202       OldFragment.clear();
203     }
204   }
205 
206   // Update the fragment map to point our object indices to this fragment.
207   for (uint64_t ObjIndex : Fragment)
208     FragmentMap[ObjIndex] = FragmentIndex;
209 }
210 
allocate(const std::set<uint64_t> & Bits,uint64_t BitSize,uint64_t & AllocByteOffset,uint8_t & AllocMask)211 void ByteArrayBuilder::allocate(const std::set<uint64_t> &Bits,
212                                 uint64_t BitSize, uint64_t &AllocByteOffset,
213                                 uint8_t &AllocMask) {
214   // Find the smallest current allocation.
215   unsigned Bit = 0;
216   for (unsigned I = 1; I != BitsPerByte; ++I)
217     if (BitAllocs[I] < BitAllocs[Bit])
218       Bit = I;
219 
220   AllocByteOffset = BitAllocs[Bit];
221 
222   // Add our size to it.
223   unsigned ReqSize = AllocByteOffset + BitSize;
224   BitAllocs[Bit] = ReqSize;
225   if (Bytes.size() < ReqSize)
226     Bytes.resize(ReqSize);
227 
228   // Set our bits.
229   AllocMask = 1 << Bit;
230   for (uint64_t B : Bits)
231     Bytes[AllocByteOffset + B] |= AllocMask;
232 }
233 
234 namespace {
235 
236 struct ByteArrayInfo {
237   std::set<uint64_t> Bits;
238   uint64_t BitSize;
239   GlobalVariable *ByteArray;
240   GlobalVariable *MaskGlobal;
241   uint8_t *MaskPtr = nullptr;
242 };
243 
244 /// A POD-like structure that we use to store a global reference together with
245 /// its metadata types. In this pass we frequently need to query the set of
246 /// metadata types referenced by a global, which at the IR level is an expensive
247 /// operation involving a map lookup; this data structure helps to reduce the
248 /// number of times we need to do this lookup.
249 class GlobalTypeMember final : TrailingObjects<GlobalTypeMember, MDNode *> {
250   friend TrailingObjects;
251 
252   GlobalObject *GO;
253   size_t NTypes;
254 
255   // For functions: true if this is a definition (either in the merged module or
256   // in one of the thinlto modules).
257   bool IsDefinition;
258 
259   // For functions: true if this function is either defined or used in a thinlto
260   // module and its jumptable entry needs to be exported to thinlto backends.
261   bool IsExported;
262 
numTrailingObjects(OverloadToken<MDNode * >) const263   size_t numTrailingObjects(OverloadToken<MDNode *>) const { return NTypes; }
264 
265 public:
create(BumpPtrAllocator & Alloc,GlobalObject * GO,bool IsDefinition,bool IsExported,ArrayRef<MDNode * > Types)266   static GlobalTypeMember *create(BumpPtrAllocator &Alloc, GlobalObject *GO,
267                                   bool IsDefinition, bool IsExported,
268                                   ArrayRef<MDNode *> Types) {
269     auto *GTM = static_cast<GlobalTypeMember *>(Alloc.Allocate(
270         totalSizeToAlloc<MDNode *>(Types.size()), alignof(GlobalTypeMember)));
271     GTM->GO = GO;
272     GTM->NTypes = Types.size();
273     GTM->IsDefinition = IsDefinition;
274     GTM->IsExported = IsExported;
275     std::uninitialized_copy(Types.begin(), Types.end(),
276                             GTM->getTrailingObjects<MDNode *>());
277     return GTM;
278   }
279 
getGlobal() const280   GlobalObject *getGlobal() const {
281     return GO;
282   }
283 
isDefinition() const284   bool isDefinition() const {
285     return IsDefinition;
286   }
287 
isExported() const288   bool isExported() const {
289     return IsExported;
290   }
291 
types() const292   ArrayRef<MDNode *> types() const {
293     return makeArrayRef(getTrailingObjects<MDNode *>(), NTypes);
294   }
295 };
296 
297 struct ICallBranchFunnel final
298     : TrailingObjects<ICallBranchFunnel, GlobalTypeMember *> {
create__anonf470a1fd0111::ICallBranchFunnel299   static ICallBranchFunnel *create(BumpPtrAllocator &Alloc, CallInst *CI,
300                                    ArrayRef<GlobalTypeMember *> Targets,
301                                    unsigned UniqueId) {
302     auto *Call = static_cast<ICallBranchFunnel *>(
303         Alloc.Allocate(totalSizeToAlloc<GlobalTypeMember *>(Targets.size()),
304                        alignof(ICallBranchFunnel)));
305     Call->CI = CI;
306     Call->UniqueId = UniqueId;
307     Call->NTargets = Targets.size();
308     std::uninitialized_copy(Targets.begin(), Targets.end(),
309                             Call->getTrailingObjects<GlobalTypeMember *>());
310     return Call;
311   }
312 
313   CallInst *CI;
targets__anonf470a1fd0111::ICallBranchFunnel314   ArrayRef<GlobalTypeMember *> targets() const {
315     return makeArrayRef(getTrailingObjects<GlobalTypeMember *>(), NTargets);
316   }
317 
318   unsigned UniqueId;
319 
320 private:
321   size_t NTargets;
322 };
323 
324 class LowerTypeTestsModule {
325   Module &M;
326 
327   ModuleSummaryIndex *ExportSummary;
328   const ModuleSummaryIndex *ImportSummary;
329 
330   Triple::ArchType Arch;
331   Triple::OSType OS;
332   Triple::ObjectFormatType ObjectFormat;
333 
334   IntegerType *Int1Ty = Type::getInt1Ty(M.getContext());
335   IntegerType *Int8Ty = Type::getInt8Ty(M.getContext());
336   PointerType *Int8PtrTy = Type::getInt8PtrTy(M.getContext());
337   ArrayType *Int8Arr0Ty = ArrayType::get(Type::getInt8Ty(M.getContext()), 0);
338   IntegerType *Int32Ty = Type::getInt32Ty(M.getContext());
339   PointerType *Int32PtrTy = PointerType::getUnqual(Int32Ty);
340   IntegerType *Int64Ty = Type::getInt64Ty(M.getContext());
341   IntegerType *IntPtrTy = M.getDataLayout().getIntPtrType(M.getContext(), 0);
342 
343   // Indirect function call index assignment counter for WebAssembly
344   uint64_t IndirectIndex = 1;
345 
346   // Mapping from type identifiers to the call sites that test them, as well as
347   // whether the type identifier needs to be exported to ThinLTO backends as
348   // part of the regular LTO phase of the ThinLTO pipeline (see exportTypeId).
349   struct TypeIdUserInfo {
350     std::vector<CallInst *> CallSites;
351     bool IsExported = false;
352   };
353   DenseMap<Metadata *, TypeIdUserInfo> TypeIdUsers;
354 
355   /// This structure describes how to lower type tests for a particular type
356   /// identifier. It is either built directly from the global analysis (during
357   /// regular LTO or the regular LTO phase of ThinLTO), or indirectly using type
358   /// identifier summaries and external symbol references (in ThinLTO backends).
359   struct TypeIdLowering {
360     TypeTestResolution::Kind TheKind = TypeTestResolution::Unsat;
361 
362     /// All except Unsat: the start address within the combined global.
363     Constant *OffsetedGlobal;
364 
365     /// ByteArray, Inline, AllOnes: log2 of the required global alignment
366     /// relative to the start address.
367     Constant *AlignLog2;
368 
369     /// ByteArray, Inline, AllOnes: one less than the size of the memory region
370     /// covering members of this type identifier as a multiple of 2^AlignLog2.
371     Constant *SizeM1;
372 
373     /// ByteArray: the byte array to test the address against.
374     Constant *TheByteArray;
375 
376     /// ByteArray: the bit mask to apply to bytes loaded from the byte array.
377     Constant *BitMask;
378 
379     /// Inline: the bit mask to test the address against.
380     Constant *InlineBits;
381   };
382 
383   std::vector<ByteArrayInfo> ByteArrayInfos;
384 
385   Function *WeakInitializerFn = nullptr;
386 
387   bool shouldExportConstantsAsAbsoluteSymbols();
388   uint8_t *exportTypeId(StringRef TypeId, const TypeIdLowering &TIL);
389   TypeIdLowering importTypeId(StringRef TypeId);
390   void importTypeTest(CallInst *CI);
391   void importFunction(Function *F, bool isDefinition);
392 
393   BitSetInfo
394   buildBitSet(Metadata *TypeId,
395               const DenseMap<GlobalTypeMember *, uint64_t> &GlobalLayout);
396   ByteArrayInfo *createByteArray(BitSetInfo &BSI);
397   void allocateByteArrays();
398   Value *createBitSetTest(IRBuilder<> &B, const TypeIdLowering &TIL,
399                           Value *BitOffset);
400   void lowerTypeTestCalls(
401       ArrayRef<Metadata *> TypeIds, Constant *CombinedGlobalAddr,
402       const DenseMap<GlobalTypeMember *, uint64_t> &GlobalLayout);
403   Value *lowerTypeTestCall(Metadata *TypeId, CallInst *CI,
404                            const TypeIdLowering &TIL);
405 
406   void buildBitSetsFromGlobalVariables(ArrayRef<Metadata *> TypeIds,
407                                        ArrayRef<GlobalTypeMember *> Globals);
408   unsigned getJumpTableEntrySize();
409   Type *getJumpTableEntryType();
410   void createJumpTableEntry(raw_ostream &AsmOS, raw_ostream &ConstraintOS,
411                             Triple::ArchType JumpTableArch,
412                             SmallVectorImpl<Value *> &AsmArgs, Function *Dest);
413   void verifyTypeMDNode(GlobalObject *GO, MDNode *Type);
414   void buildBitSetsFromFunctions(ArrayRef<Metadata *> TypeIds,
415                                  ArrayRef<GlobalTypeMember *> Functions);
416   void buildBitSetsFromFunctionsNative(ArrayRef<Metadata *> TypeIds,
417                                        ArrayRef<GlobalTypeMember *> Functions);
418   void buildBitSetsFromFunctionsWASM(ArrayRef<Metadata *> TypeIds,
419                                      ArrayRef<GlobalTypeMember *> Functions);
420   void
421   buildBitSetsFromDisjointSet(ArrayRef<Metadata *> TypeIds,
422                               ArrayRef<GlobalTypeMember *> Globals,
423                               ArrayRef<ICallBranchFunnel *> ICallBranchFunnels);
424 
425   void replaceWeakDeclarationWithJumpTablePtr(Function *F, Constant *JT, bool IsDefinition);
426   void moveInitializerToModuleConstructor(GlobalVariable *GV);
427   void findGlobalVariableUsersOf(Constant *C,
428                                  SmallSetVector<GlobalVariable *, 8> &Out);
429 
430   void createJumpTable(Function *F, ArrayRef<GlobalTypeMember *> Functions);
431 
432   /// replaceCfiUses - Go through the uses list for this definition
433   /// and make each use point to "V" instead of "this" when the use is outside
434   /// the block. 'This's use list is expected to have at least one element.
435   /// Unlike replaceAllUsesWith this function skips blockaddr and direct call
436   /// uses.
437   void replaceCfiUses(Function *Old, Value *New, bool IsDefinition);
438 
439   /// replaceDirectCalls - Go through the uses list for this definition and
440   /// replace each use, which is a direct function call.
441   void replaceDirectCalls(Value *Old, Value *New);
442 
443 public:
444   LowerTypeTestsModule(Module &M, ModuleSummaryIndex *ExportSummary,
445                        const ModuleSummaryIndex *ImportSummary);
446 
447   bool lower();
448 
449   // Lower the module using the action and summary passed as command line
450   // arguments. For testing purposes only.
451   static bool runForTesting(Module &M);
452 };
453 
454 struct LowerTypeTests : public ModulePass {
455   static char ID;
456 
457   bool UseCommandLine = false;
458 
459   ModuleSummaryIndex *ExportSummary;
460   const ModuleSummaryIndex *ImportSummary;
461 
LowerTypeTests__anonf470a1fd0111::LowerTypeTests462   LowerTypeTests() : ModulePass(ID), UseCommandLine(true) {
463     initializeLowerTypeTestsPass(*PassRegistry::getPassRegistry());
464   }
465 
LowerTypeTests__anonf470a1fd0111::LowerTypeTests466   LowerTypeTests(ModuleSummaryIndex *ExportSummary,
467                  const ModuleSummaryIndex *ImportSummary)
468       : ModulePass(ID), ExportSummary(ExportSummary),
469         ImportSummary(ImportSummary) {
470     initializeLowerTypeTestsPass(*PassRegistry::getPassRegistry());
471   }
472 
runOnModule__anonf470a1fd0111::LowerTypeTests473   bool runOnModule(Module &M) override {
474     if (UseCommandLine)
475       return LowerTypeTestsModule::runForTesting(M);
476     return LowerTypeTestsModule(M, ExportSummary, ImportSummary).lower();
477   }
478 };
479 
480 } // end anonymous namespace
481 
// Pass identifier for the legacy pass manager.
char LowerTypeTests::ID = 0;

// Register the legacy pass under the name "lowertypetests".
INITIALIZE_PASS(LowerTypeTests, "lowertypetests", "Lower type metadata", false,
                false)
486 
487 ModulePass *
createLowerTypeTestsPass(ModuleSummaryIndex * ExportSummary,const ModuleSummaryIndex * ImportSummary)488 llvm::createLowerTypeTestsPass(ModuleSummaryIndex *ExportSummary,
489                                const ModuleSummaryIndex *ImportSummary) {
490   return new LowerTypeTests(ExportSummary, ImportSummary);
491 }
492 
493 /// Build a bit set for TypeId using the object layouts in
494 /// GlobalLayout.
buildBitSet(Metadata * TypeId,const DenseMap<GlobalTypeMember *,uint64_t> & GlobalLayout)495 BitSetInfo LowerTypeTestsModule::buildBitSet(
496     Metadata *TypeId,
497     const DenseMap<GlobalTypeMember *, uint64_t> &GlobalLayout) {
498   BitSetBuilder BSB;
499 
500   // Compute the byte offset of each address associated with this type
501   // identifier.
502   for (auto &GlobalAndOffset : GlobalLayout) {
503     for (MDNode *Type : GlobalAndOffset.first->types()) {
504       if (Type->getOperand(1) != TypeId)
505         continue;
506       uint64_t Offset =
507           cast<ConstantInt>(
508               cast<ConstantAsMetadata>(Type->getOperand(0))->getValue())
509               ->getZExtValue();
510       BSB.addOffset(GlobalAndOffset.second + Offset);
511     }
512   }
513 
514   return BSB.build();
515 }
516 
517 /// Build a test that bit BitOffset mod sizeof(Bits)*8 is set in
518 /// Bits. This pattern matches to the bt instruction on x86.
createMaskedBitTest(IRBuilder<> & B,Value * Bits,Value * BitOffset)519 static Value *createMaskedBitTest(IRBuilder<> &B, Value *Bits,
520                                   Value *BitOffset) {
521   auto BitsType = cast<IntegerType>(Bits->getType());
522   unsigned BitWidth = BitsType->getBitWidth();
523 
524   BitOffset = B.CreateZExtOrTrunc(BitOffset, BitsType);
525   Value *BitIndex =
526       B.CreateAnd(BitOffset, ConstantInt::get(BitsType, BitWidth - 1));
527   Value *BitMask = B.CreateShl(ConstantInt::get(BitsType, 1), BitIndex);
528   Value *MaskedBits = B.CreateAnd(Bits, BitMask);
529   return B.CreateICmpNE(MaskedBits, ConstantInt::get(BitsType, 0));
530 }
531 
createByteArray(BitSetInfo & BSI)532 ByteArrayInfo *LowerTypeTestsModule::createByteArray(BitSetInfo &BSI) {
533   // Create globals to stand in for byte arrays and masks. These never actually
534   // get initialized, we RAUW and erase them later in allocateByteArrays() once
535   // we know the offset and mask to use.
536   auto ByteArrayGlobal = new GlobalVariable(
537       M, Int8Ty, /*isConstant=*/true, GlobalValue::PrivateLinkage, nullptr);
538   auto MaskGlobal = new GlobalVariable(M, Int8Ty, /*isConstant=*/true,
539                                        GlobalValue::PrivateLinkage, nullptr);
540 
541   ByteArrayInfos.emplace_back();
542   ByteArrayInfo *BAI = &ByteArrayInfos.back();
543 
544   BAI->Bits = BSI.Bits;
545   BAI->BitSize = BSI.BitSize;
546   BAI->ByteArray = ByteArrayGlobal;
547   BAI->MaskGlobal = MaskGlobal;
548   return BAI;
549 }
550 
allocateByteArrays()551 void LowerTypeTestsModule::allocateByteArrays() {
552   std::stable_sort(ByteArrayInfos.begin(), ByteArrayInfos.end(),
553                    [](const ByteArrayInfo &BAI1, const ByteArrayInfo &BAI2) {
554                      return BAI1.BitSize > BAI2.BitSize;
555                    });
556 
557   std::vector<uint64_t> ByteArrayOffsets(ByteArrayInfos.size());
558 
559   ByteArrayBuilder BAB;
560   for (unsigned I = 0; I != ByteArrayInfos.size(); ++I) {
561     ByteArrayInfo *BAI = &ByteArrayInfos[I];
562 
563     uint8_t Mask;
564     BAB.allocate(BAI->Bits, BAI->BitSize, ByteArrayOffsets[I], Mask);
565 
566     BAI->MaskGlobal->replaceAllUsesWith(
567         ConstantExpr::getIntToPtr(ConstantInt::get(Int8Ty, Mask), Int8PtrTy));
568     BAI->MaskGlobal->eraseFromParent();
569     if (BAI->MaskPtr)
570       *BAI->MaskPtr = Mask;
571   }
572 
573   Constant *ByteArrayConst = ConstantDataArray::get(M.getContext(), BAB.Bytes);
574   auto ByteArray =
575       new GlobalVariable(M, ByteArrayConst->getType(), /*isConstant=*/true,
576                          GlobalValue::PrivateLinkage, ByteArrayConst);
577 
578   for (unsigned I = 0; I != ByteArrayInfos.size(); ++I) {
579     ByteArrayInfo *BAI = &ByteArrayInfos[I];
580 
581     Constant *Idxs[] = {ConstantInt::get(IntPtrTy, 0),
582                         ConstantInt::get(IntPtrTy, ByteArrayOffsets[I])};
583     Constant *GEP = ConstantExpr::getInBoundsGetElementPtr(
584         ByteArrayConst->getType(), ByteArray, Idxs);
585 
586     // Create an alias instead of RAUW'ing the gep directly. On x86 this ensures
587     // that the pc-relative displacement is folded into the lea instead of the
588     // test instruction getting another displacement.
589     GlobalAlias *Alias = GlobalAlias::create(
590         Int8Ty, 0, GlobalValue::PrivateLinkage, "bits", GEP, &M);
591     BAI->ByteArray->replaceAllUsesWith(Alias);
592     BAI->ByteArray->eraseFromParent();
593   }
594 
595   ByteArraySizeBits = BAB.BitAllocs[0] + BAB.BitAllocs[1] + BAB.BitAllocs[2] +
596                       BAB.BitAllocs[3] + BAB.BitAllocs[4] + BAB.BitAllocs[5] +
597                       BAB.BitAllocs[6] + BAB.BitAllocs[7];
598   ByteArraySizeBytes = BAB.Bytes.size();
599 }
600 
601 /// Build a test that bit BitOffset is set in the type identifier that was
602 /// lowered to TIL, which must be either an Inline or a ByteArray.
createBitSetTest(IRBuilder<> & B,const TypeIdLowering & TIL,Value * BitOffset)603 Value *LowerTypeTestsModule::createBitSetTest(IRBuilder<> &B,
604                                               const TypeIdLowering &TIL,
605                                               Value *BitOffset) {
606   if (TIL.TheKind == TypeTestResolution::Inline) {
607     // If the bit set is sufficiently small, we can avoid a load by bit testing
608     // a constant.
609     return createMaskedBitTest(B, TIL.InlineBits, BitOffset);
610   } else {
611     Constant *ByteArray = TIL.TheByteArray;
612     if (AvoidReuse && !ImportSummary) {
613       // Each use of the byte array uses a different alias. This makes the
614       // backend less likely to reuse previously computed byte array addresses,
615       // improving the security of the CFI mechanism based on this pass.
616       // This won't work when importing because TheByteArray is external.
617       ByteArray = GlobalAlias::create(Int8Ty, 0, GlobalValue::PrivateLinkage,
618                                       "bits_use", ByteArray, &M);
619     }
620 
621     Value *ByteAddr = B.CreateGEP(Int8Ty, ByteArray, BitOffset);
622     Value *Byte = B.CreateLoad(ByteAddr);
623 
624     Value *ByteAndMask =
625         B.CreateAnd(Byte, ConstantExpr::getPtrToInt(TIL.BitMask, Int8Ty));
626     return B.CreateICmpNE(ByteAndMask, ConstantInt::get(Int8Ty, 0));
627   }
628 }
629 
isKnownTypeIdMember(Metadata * TypeId,const DataLayout & DL,Value * V,uint64_t COffset)630 static bool isKnownTypeIdMember(Metadata *TypeId, const DataLayout &DL,
631                                 Value *V, uint64_t COffset) {
632   if (auto GV = dyn_cast<GlobalObject>(V)) {
633     SmallVector<MDNode *, 2> Types;
634     GV->getMetadata(LLVMContext::MD_type, Types);
635     for (MDNode *Type : Types) {
636       if (Type->getOperand(1) != TypeId)
637         continue;
638       uint64_t Offset =
639           cast<ConstantInt>(
640               cast<ConstantAsMetadata>(Type->getOperand(0))->getValue())
641               ->getZExtValue();
642       if (COffset == Offset)
643         return true;
644     }
645     return false;
646   }
647 
648   if (auto GEP = dyn_cast<GEPOperator>(V)) {
649     APInt APOffset(DL.getPointerSizeInBits(0), 0);
650     bool Result = GEP->accumulateConstantOffset(DL, APOffset);
651     if (!Result)
652       return false;
653     COffset += APOffset.getZExtValue();
654     return isKnownTypeIdMember(TypeId, DL, GEP->getPointerOperand(), COffset);
655   }
656 
657   if (auto Op = dyn_cast<Operator>(V)) {
658     if (Op->getOpcode() == Instruction::BitCast)
659       return isKnownTypeIdMember(TypeId, DL, Op->getOperand(0), COffset);
660 
661     if (Op->getOpcode() == Instruction::Select)
662       return isKnownTypeIdMember(TypeId, DL, Op->getOperand(1), COffset) &&
663              isKnownTypeIdMember(TypeId, DL, Op->getOperand(2), COffset);
664   }
665 
666   return false;
667 }
668 
669 /// Lower a llvm.type.test call to its implementation. Returns the value to
670 /// replace the call with.
lowerTypeTestCall(Metadata * TypeId,CallInst * CI,const TypeIdLowering & TIL)671 Value *LowerTypeTestsModule::lowerTypeTestCall(Metadata *TypeId, CallInst *CI,
672                                                const TypeIdLowering &TIL) {
673   if (TIL.TheKind == TypeTestResolution::Unsat)
674     return ConstantInt::getFalse(M.getContext());
675 
676   Value *Ptr = CI->getArgOperand(0);
677   const DataLayout &DL = M.getDataLayout();
678   if (isKnownTypeIdMember(TypeId, DL, Ptr, 0))
679     return ConstantInt::getTrue(M.getContext());
680 
681   BasicBlock *InitialBB = CI->getParent();
682 
683   IRBuilder<> B(CI);
684 
685   Value *PtrAsInt = B.CreatePtrToInt(Ptr, IntPtrTy);
686 
687   Constant *OffsetedGlobalAsInt =
688       ConstantExpr::getPtrToInt(TIL.OffsetedGlobal, IntPtrTy);
689   if (TIL.TheKind == TypeTestResolution::Single)
690     return B.CreateICmpEQ(PtrAsInt, OffsetedGlobalAsInt);
691 
692   Value *PtrOffset = B.CreateSub(PtrAsInt, OffsetedGlobalAsInt);
693 
694   // We need to check that the offset both falls within our range and is
695   // suitably aligned. We can check both properties at the same time by
696   // performing a right rotate by log2(alignment) followed by an integer
697   // comparison against the bitset size. The rotate will move the lower
698   // order bits that need to be zero into the higher order bits of the
699   // result, causing the comparison to fail if they are nonzero. The rotate
700   // also conveniently gives us a bit offset to use during the load from
701   // the bitset.
702   Value *OffsetSHR =
703       B.CreateLShr(PtrOffset, ConstantExpr::getZExt(TIL.AlignLog2, IntPtrTy));
704   Value *OffsetSHL = B.CreateShl(
705       PtrOffset, ConstantExpr::getZExt(
706                      ConstantExpr::getSub(
707                          ConstantInt::get(Int8Ty, DL.getPointerSizeInBits(0)),
708                          TIL.AlignLog2),
709                      IntPtrTy));
710   Value *BitOffset = B.CreateOr(OffsetSHR, OffsetSHL);
711 
712   Value *OffsetInRange = B.CreateICmpULE(BitOffset, TIL.SizeM1);
713 
714   // If the bit set is all ones, testing against it is unnecessary.
715   if (TIL.TheKind == TypeTestResolution::AllOnes)
716     return OffsetInRange;
717 
718   // See if the intrinsic is used in the following common pattern:
719   //   br(llvm.type.test(...), thenbb, elsebb)
720   // where nothing happens between the type test and the br.
721   // If so, create slightly simpler IR.
722   if (CI->hasOneUse())
723     if (auto *Br = dyn_cast<BranchInst>(*CI->user_begin()))
724       if (CI->getNextNode() == Br) {
725         BasicBlock *Then = InitialBB->splitBasicBlock(CI->getIterator());
726         BasicBlock *Else = Br->getSuccessor(1);
727         BranchInst *NewBr = BranchInst::Create(Then, Else, OffsetInRange);
728         NewBr->setMetadata(LLVMContext::MD_prof,
729                            Br->getMetadata(LLVMContext::MD_prof));
730         ReplaceInstWithInst(InitialBB->getTerminator(), NewBr);
731 
732         // Update phis in Else resulting from InitialBB being split
733         for (auto &Phi : Else->phis())
734           Phi.addIncoming(Phi.getIncomingValueForBlock(Then), InitialBB);
735 
736         IRBuilder<> ThenB(CI);
737         return createBitSetTest(ThenB, TIL, BitOffset);
738       }
739 
740   IRBuilder<> ThenB(SplitBlockAndInsertIfThen(OffsetInRange, CI, false));
741 
742   // Now that we know that the offset is in range and aligned, load the
743   // appropriate bit from the bitset.
744   Value *Bit = createBitSetTest(ThenB, TIL, BitOffset);
745 
746   // The value we want is 0 if we came directly from the initial block
747   // (having failed the range or alignment checks), or the loaded bit if
748   // we came from the block in which we loaded it.
749   B.SetInsertPoint(CI);
750   PHINode *P = B.CreatePHI(Int1Ty, 2);
751   P->addIncoming(ConstantInt::get(Int1Ty, 0), InitialBB);
752   P->addIncoming(Bit, ThenB.GetInsertBlock());
753   return P;
754 }
755 
756 /// Given a disjoint set of type identifiers and globals, lay out the globals,
757 /// build the bit sets and lower the llvm.type.test calls.
buildBitSetsFromGlobalVariables(ArrayRef<Metadata * > TypeIds,ArrayRef<GlobalTypeMember * > Globals)758 void LowerTypeTestsModule::buildBitSetsFromGlobalVariables(
759     ArrayRef<Metadata *> TypeIds, ArrayRef<GlobalTypeMember *> Globals) {
760   // Build a new global with the combined contents of the referenced globals.
761   // This global is a struct whose even-indexed elements contain the original
762   // contents of the referenced globals and whose odd-indexed elements contain
763   // any padding required to align the next element to the next power of 2.
764   std::vector<Constant *> GlobalInits;
765   const DataLayout &DL = M.getDataLayout();
766   for (GlobalTypeMember *G : Globals) {
767     GlobalVariable *GV = cast<GlobalVariable>(G->getGlobal());
768     GlobalInits.push_back(GV->getInitializer());
769     uint64_t InitSize = DL.getTypeAllocSize(GV->getValueType());
770 
771     // Compute the amount of padding required.
772     uint64_t Padding = NextPowerOf2(InitSize - 1) - InitSize;
773 
774     // Experiments of different caps with Chromium on both x64 and ARM64
775     // have shown that the 32-byte cap generates the smallest binary on
776     // both platforms while different caps yield similar performance.
777     // (see https://lists.llvm.org/pipermail/llvm-dev/2018-July/124694.html)
778     if (Padding > 32)
779       Padding = alignTo(InitSize, 32) - InitSize;
780 
781     GlobalInits.push_back(
782         ConstantAggregateZero::get(ArrayType::get(Int8Ty, Padding)));
783   }
784   if (!GlobalInits.empty())
785     GlobalInits.pop_back();
786   Constant *NewInit = ConstantStruct::getAnon(M.getContext(), GlobalInits);
787   auto *CombinedGlobal =
788       new GlobalVariable(M, NewInit->getType(), /*isConstant=*/true,
789                          GlobalValue::PrivateLinkage, NewInit);
790 
791   StructType *NewTy = cast<StructType>(NewInit->getType());
792   const StructLayout *CombinedGlobalLayout = DL.getStructLayout(NewTy);
793 
794   // Compute the offsets of the original globals within the new global.
795   DenseMap<GlobalTypeMember *, uint64_t> GlobalLayout;
796   for (unsigned I = 0; I != Globals.size(); ++I)
797     // Multiply by 2 to account for padding elements.
798     GlobalLayout[Globals[I]] = CombinedGlobalLayout->getElementOffset(I * 2);
799 
800   lowerTypeTestCalls(TypeIds, CombinedGlobal, GlobalLayout);
801 
802   // Build aliases pointing to offsets into the combined global for each
803   // global from which we built the combined global, and replace references
804   // to the original globals with references to the aliases.
805   for (unsigned I = 0; I != Globals.size(); ++I) {
806     GlobalVariable *GV = cast<GlobalVariable>(Globals[I]->getGlobal());
807 
808     // Multiply by 2 to account for padding elements.
809     Constant *CombinedGlobalIdxs[] = {ConstantInt::get(Int32Ty, 0),
810                                       ConstantInt::get(Int32Ty, I * 2)};
811     Constant *CombinedGlobalElemPtr = ConstantExpr::getGetElementPtr(
812         NewInit->getType(), CombinedGlobal, CombinedGlobalIdxs);
813     assert(GV->getType()->getAddressSpace() == 0);
814     GlobalAlias *GAlias =
815         GlobalAlias::create(NewTy->getElementType(I * 2), 0, GV->getLinkage(),
816                             "", CombinedGlobalElemPtr, &M);
817     GAlias->setVisibility(GV->getVisibility());
818     GAlias->takeName(GV);
819     GV->replaceAllUsesWith(GAlias);
820     GV->eraseFromParent();
821   }
822 }
823 
shouldExportConstantsAsAbsoluteSymbols()824 bool LowerTypeTestsModule::shouldExportConstantsAsAbsoluteSymbols() {
825   return (Arch == Triple::x86 || Arch == Triple::x86_64) &&
826          ObjectFormat == Triple::ELF;
827 }
828 
829 /// Export the given type identifier so that ThinLTO backends may import it.
830 /// Type identifiers are exported by adding coarse-grained information about how
831 /// to test the type identifier to the summary, and creating symbols in the
832 /// object file (aliases and absolute symbols) containing fine-grained
833 /// information about the type identifier.
834 ///
835 /// Returns a pointer to the location in which to store the bitmask, if
836 /// applicable.
exportTypeId(StringRef TypeId,const TypeIdLowering & TIL)837 uint8_t *LowerTypeTestsModule::exportTypeId(StringRef TypeId,
838                                             const TypeIdLowering &TIL) {
839   TypeTestResolution &TTRes =
840       ExportSummary->getOrInsertTypeIdSummary(TypeId).TTRes;
841   TTRes.TheKind = TIL.TheKind;
842 
843   auto ExportGlobal = [&](StringRef Name, Constant *C) {
844     GlobalAlias *GA =
845         GlobalAlias::create(Int8Ty, 0, GlobalValue::ExternalLinkage,
846                             "__typeid_" + TypeId + "_" + Name, C, &M);
847     GA->setVisibility(GlobalValue::HiddenVisibility);
848   };
849 
850   auto ExportConstant = [&](StringRef Name, uint64_t &Storage, Constant *C) {
851     if (shouldExportConstantsAsAbsoluteSymbols())
852       ExportGlobal(Name, ConstantExpr::getIntToPtr(C, Int8PtrTy));
853     else
854       Storage = cast<ConstantInt>(C)->getZExtValue();
855   };
856 
857   if (TIL.TheKind != TypeTestResolution::Unsat)
858     ExportGlobal("global_addr", TIL.OffsetedGlobal);
859 
860   if (TIL.TheKind == TypeTestResolution::ByteArray ||
861       TIL.TheKind == TypeTestResolution::Inline ||
862       TIL.TheKind == TypeTestResolution::AllOnes) {
863     ExportConstant("align", TTRes.AlignLog2, TIL.AlignLog2);
864     ExportConstant("size_m1", TTRes.SizeM1, TIL.SizeM1);
865 
866     uint64_t BitSize = cast<ConstantInt>(TIL.SizeM1)->getZExtValue() + 1;
867     if (TIL.TheKind == TypeTestResolution::Inline)
868       TTRes.SizeM1BitWidth = (BitSize <= 32) ? 5 : 6;
869     else
870       TTRes.SizeM1BitWidth = (BitSize <= 128) ? 7 : 32;
871   }
872 
873   if (TIL.TheKind == TypeTestResolution::ByteArray) {
874     ExportGlobal("byte_array", TIL.TheByteArray);
875     if (shouldExportConstantsAsAbsoluteSymbols())
876       ExportGlobal("bit_mask", TIL.BitMask);
877     else
878       return &TTRes.BitMask;
879   }
880 
881   if (TIL.TheKind == TypeTestResolution::Inline)
882     ExportConstant("inline_bits", TTRes.InlineBits, TIL.InlineBits);
883 
884   return nullptr;
885 }
886 
887 LowerTypeTestsModule::TypeIdLowering
importTypeId(StringRef TypeId)888 LowerTypeTestsModule::importTypeId(StringRef TypeId) {
889   const TypeIdSummary *TidSummary = ImportSummary->getTypeIdSummary(TypeId);
890   if (!TidSummary)
891     return {}; // Unsat: no globals match this type id.
892   const TypeTestResolution &TTRes = TidSummary->TTRes;
893 
894   TypeIdLowering TIL;
895   TIL.TheKind = TTRes.TheKind;
896 
897   auto ImportGlobal = [&](StringRef Name) {
898     // Give the global a type of length 0 so that it is not assumed not to alias
899     // with any other global.
900     Constant *C = M.getOrInsertGlobal(("__typeid_" + TypeId + "_" + Name).str(),
901                                       Int8Arr0Ty);
902     if (auto *GV = dyn_cast<GlobalVariable>(C))
903       GV->setVisibility(GlobalValue::HiddenVisibility);
904     C = ConstantExpr::getBitCast(C, Int8PtrTy);
905     return C;
906   };
907 
908   auto ImportConstant = [&](StringRef Name, uint64_t Const, unsigned AbsWidth,
909                             Type *Ty) {
910     if (!shouldExportConstantsAsAbsoluteSymbols()) {
911       Constant *C =
912           ConstantInt::get(isa<IntegerType>(Ty) ? Ty : Int64Ty, Const);
913       if (!isa<IntegerType>(Ty))
914         C = ConstantExpr::getIntToPtr(C, Ty);
915       return C;
916     }
917 
918     Constant *C = ImportGlobal(Name);
919     auto *GV = cast<GlobalVariable>(C->stripPointerCasts());
920     if (isa<IntegerType>(Ty))
921       C = ConstantExpr::getPtrToInt(C, Ty);
922     if (GV->getMetadata(LLVMContext::MD_absolute_symbol))
923       return C;
924 
925     auto SetAbsRange = [&](uint64_t Min, uint64_t Max) {
926       auto *MinC = ConstantAsMetadata::get(ConstantInt::get(IntPtrTy, Min));
927       auto *MaxC = ConstantAsMetadata::get(ConstantInt::get(IntPtrTy, Max));
928       GV->setMetadata(LLVMContext::MD_absolute_symbol,
929                       MDNode::get(M.getContext(), {MinC, MaxC}));
930     };
931     if (AbsWidth == IntPtrTy->getBitWidth())
932       SetAbsRange(~0ull, ~0ull); // Full set.
933     else
934       SetAbsRange(0, 1ull << AbsWidth);
935     return C;
936   };
937 
938   if (TIL.TheKind != TypeTestResolution::Unsat)
939     TIL.OffsetedGlobal = ImportGlobal("global_addr");
940 
941   if (TIL.TheKind == TypeTestResolution::ByteArray ||
942       TIL.TheKind == TypeTestResolution::Inline ||
943       TIL.TheKind == TypeTestResolution::AllOnes) {
944     TIL.AlignLog2 = ImportConstant("align", TTRes.AlignLog2, 8, Int8Ty);
945     TIL.SizeM1 =
946         ImportConstant("size_m1", TTRes.SizeM1, TTRes.SizeM1BitWidth, IntPtrTy);
947   }
948 
949   if (TIL.TheKind == TypeTestResolution::ByteArray) {
950     TIL.TheByteArray = ImportGlobal("byte_array");
951     TIL.BitMask = ImportConstant("bit_mask", TTRes.BitMask, 8, Int8PtrTy);
952   }
953 
954   if (TIL.TheKind == TypeTestResolution::Inline)
955     TIL.InlineBits = ImportConstant(
956         "inline_bits", TTRes.InlineBits, 1 << TTRes.SizeM1BitWidth,
957         TTRes.SizeM1BitWidth <= 5 ? Int32Ty : Int64Ty);
958 
959   return TIL;
960 }
961 
importTypeTest(CallInst * CI)962 void LowerTypeTestsModule::importTypeTest(CallInst *CI) {
963   auto TypeIdMDVal = dyn_cast<MetadataAsValue>(CI->getArgOperand(1));
964   if (!TypeIdMDVal)
965     report_fatal_error("Second argument of llvm.type.test must be metadata");
966 
967   auto TypeIdStr = dyn_cast<MDString>(TypeIdMDVal->getMetadata());
968   if (!TypeIdStr)
969     report_fatal_error(
970         "Second argument of llvm.type.test must be a metadata string");
971 
972   TypeIdLowering TIL = importTypeId(TypeIdStr->getString());
973   Value *Lowered = lowerTypeTestCall(TypeIdStr, CI, TIL);
974   CI->replaceAllUsesWith(Lowered);
975   CI->eraseFromParent();
976 }
977 
978 // ThinLTO backend: the function F has a jump table entry; update this module
979 // accordingly. isDefinition describes the type of the jump table entry.
importFunction(Function * F,bool isDefinition)980 void LowerTypeTestsModule::importFunction(Function *F, bool isDefinition) {
981   assert(F->getType()->getAddressSpace() == 0);
982 
983   GlobalValue::VisibilityTypes Visibility = F->getVisibility();
984   std::string Name = F->getName();
985 
986   if (F->isDeclarationForLinker() && isDefinition) {
987     // Non-dso_local functions may be overriden at run time,
988     // don't short curcuit them
989     if (F->isDSOLocal()) {
990       Function *RealF = Function::Create(F->getFunctionType(),
991                                          GlobalValue::ExternalLinkage,
992                                          Name + ".cfi", &M);
993       RealF->setVisibility(GlobalVariable::HiddenVisibility);
994       replaceDirectCalls(F, RealF);
995     }
996     return;
997   }
998 
999   Function *FDecl;
1000   if (F->isDeclarationForLinker() && !isDefinition) {
1001     // Declaration of an external function.
1002     FDecl = Function::Create(F->getFunctionType(), GlobalValue::ExternalLinkage,
1003                              Name + ".cfi_jt", &M);
1004     FDecl->setVisibility(GlobalValue::HiddenVisibility);
1005   } else if (isDefinition) {
1006     F->setName(Name + ".cfi");
1007     F->setLinkage(GlobalValue::ExternalLinkage);
1008     FDecl = Function::Create(F->getFunctionType(), GlobalValue::ExternalLinkage,
1009                              Name, &M);
1010     FDecl->setVisibility(Visibility);
1011     Visibility = GlobalValue::HiddenVisibility;
1012 
1013     // Delete aliases pointing to this function, they'll be re-created in the
1014     // merged output
1015     SmallVector<GlobalAlias*, 4> ToErase;
1016     for (auto &U : F->uses()) {
1017       if (auto *A = dyn_cast<GlobalAlias>(U.getUser())) {
1018         Function *AliasDecl = Function::Create(
1019             F->getFunctionType(), GlobalValue::ExternalLinkage, "", &M);
1020         AliasDecl->takeName(A);
1021         A->replaceAllUsesWith(AliasDecl);
1022         ToErase.push_back(A);
1023       }
1024     }
1025     for (auto *A : ToErase)
1026       A->eraseFromParent();
1027   } else {
1028     // Function definition without type metadata, where some other translation
1029     // unit contained a declaration with type metadata. This normally happens
1030     // during mixed CFI + non-CFI compilation. We do nothing with the function
1031     // so that it is treated the same way as a function defined outside of the
1032     // LTO unit.
1033     return;
1034   }
1035 
1036   if (F->isWeakForLinker())
1037     replaceWeakDeclarationWithJumpTablePtr(F, FDecl, isDefinition);
1038   else
1039     replaceCfiUses(F, FDecl, isDefinition);
1040 
1041   // Set visibility late because it's used in replaceCfiUses() to determine
1042   // whether uses need to to be replaced.
1043   F->setVisibility(Visibility);
1044 }
1045 
lowerTypeTestCalls(ArrayRef<Metadata * > TypeIds,Constant * CombinedGlobalAddr,const DenseMap<GlobalTypeMember *,uint64_t> & GlobalLayout)1046 void LowerTypeTestsModule::lowerTypeTestCalls(
1047     ArrayRef<Metadata *> TypeIds, Constant *CombinedGlobalAddr,
1048     const DenseMap<GlobalTypeMember *, uint64_t> &GlobalLayout) {
1049   CombinedGlobalAddr = ConstantExpr::getBitCast(CombinedGlobalAddr, Int8PtrTy);
1050 
1051   // For each type identifier in this disjoint set...
1052   for (Metadata *TypeId : TypeIds) {
1053     // Build the bitset.
1054     BitSetInfo BSI = buildBitSet(TypeId, GlobalLayout);
1055     LLVM_DEBUG({
1056       if (auto MDS = dyn_cast<MDString>(TypeId))
1057         dbgs() << MDS->getString() << ": ";
1058       else
1059         dbgs() << "<unnamed>: ";
1060       BSI.print(dbgs());
1061     });
1062 
1063     ByteArrayInfo *BAI = nullptr;
1064     TypeIdLowering TIL;
1065     TIL.OffsetedGlobal = ConstantExpr::getGetElementPtr(
1066         Int8Ty, CombinedGlobalAddr, ConstantInt::get(IntPtrTy, BSI.ByteOffset)),
1067     TIL.AlignLog2 = ConstantInt::get(Int8Ty, BSI.AlignLog2);
1068     TIL.SizeM1 = ConstantInt::get(IntPtrTy, BSI.BitSize - 1);
1069     if (BSI.isAllOnes()) {
1070       TIL.TheKind = (BSI.BitSize == 1) ? TypeTestResolution::Single
1071                                        : TypeTestResolution::AllOnes;
1072     } else if (BSI.BitSize <= 64) {
1073       TIL.TheKind = TypeTestResolution::Inline;
1074       uint64_t InlineBits = 0;
1075       for (auto Bit : BSI.Bits)
1076         InlineBits |= uint64_t(1) << Bit;
1077       if (InlineBits == 0)
1078         TIL.TheKind = TypeTestResolution::Unsat;
1079       else
1080         TIL.InlineBits = ConstantInt::get(
1081             (BSI.BitSize <= 32) ? Int32Ty : Int64Ty, InlineBits);
1082     } else {
1083       TIL.TheKind = TypeTestResolution::ByteArray;
1084       ++NumByteArraysCreated;
1085       BAI = createByteArray(BSI);
1086       TIL.TheByteArray = BAI->ByteArray;
1087       TIL.BitMask = BAI->MaskGlobal;
1088     }
1089 
1090     TypeIdUserInfo &TIUI = TypeIdUsers[TypeId];
1091 
1092     if (TIUI.IsExported) {
1093       uint8_t *MaskPtr = exportTypeId(cast<MDString>(TypeId)->getString(), TIL);
1094       if (BAI)
1095         BAI->MaskPtr = MaskPtr;
1096     }
1097 
1098     // Lower each call to llvm.type.test for this type identifier.
1099     for (CallInst *CI : TIUI.CallSites) {
1100       ++NumTypeTestCallsLowered;
1101       Value *Lowered = lowerTypeTestCall(TypeId, CI, TIL);
1102       CI->replaceAllUsesWith(Lowered);
1103       CI->eraseFromParent();
1104     }
1105   }
1106 }
1107 
verifyTypeMDNode(GlobalObject * GO,MDNode * Type)1108 void LowerTypeTestsModule::verifyTypeMDNode(GlobalObject *GO, MDNode *Type) {
1109   if (Type->getNumOperands() != 2)
1110     report_fatal_error("All operands of type metadata must have 2 elements");
1111 
1112   if (GO->isThreadLocal())
1113     report_fatal_error("Bit set element may not be thread-local");
1114   if (isa<GlobalVariable>(GO) && GO->hasSection())
1115     report_fatal_error(
1116         "A member of a type identifier may not have an explicit section");
1117 
1118   // FIXME: We previously checked that global var member of a type identifier
1119   // must be a definition, but the IR linker may leave type metadata on
1120   // declarations. We should restore this check after fixing PR31759.
1121 
1122   auto OffsetConstMD = dyn_cast<ConstantAsMetadata>(Type->getOperand(0));
1123   if (!OffsetConstMD)
1124     report_fatal_error("Type offset must be a constant");
1125   auto OffsetInt = dyn_cast<ConstantInt>(OffsetConstMD->getValue());
1126   if (!OffsetInt)
1127     report_fatal_error("Type offset must be an integer constant");
1128 }
1129 
1130 static const unsigned kX86JumpTableEntrySize = 8;
1131 static const unsigned kARMJumpTableEntrySize = 4;
1132 
getJumpTableEntrySize()1133 unsigned LowerTypeTestsModule::getJumpTableEntrySize() {
1134   switch (Arch) {
1135     case Triple::x86:
1136     case Triple::x86_64:
1137       return kX86JumpTableEntrySize;
1138     case Triple::arm:
1139     case Triple::thumb:
1140     case Triple::aarch64:
1141       return kARMJumpTableEntrySize;
1142     default:
1143       report_fatal_error("Unsupported architecture for jump tables");
1144   }
1145 }
1146 
1147 // Create a jump table entry for the target. This consists of an instruction
1148 // sequence containing a relative branch to Dest. Appends inline asm text,
1149 // constraints and arguments to AsmOS, ConstraintOS and AsmArgs.
createJumpTableEntry(raw_ostream & AsmOS,raw_ostream & ConstraintOS,Triple::ArchType JumpTableArch,SmallVectorImpl<Value * > & AsmArgs,Function * Dest)1150 void LowerTypeTestsModule::createJumpTableEntry(
1151     raw_ostream &AsmOS, raw_ostream &ConstraintOS,
1152     Triple::ArchType JumpTableArch, SmallVectorImpl<Value *> &AsmArgs,
1153     Function *Dest) {
1154   unsigned ArgIndex = AsmArgs.size();
1155 
1156   if (JumpTableArch == Triple::x86 || JumpTableArch == Triple::x86_64) {
1157     AsmOS << "jmp ${" << ArgIndex << ":c}@plt\n";
1158     AsmOS << "int3\nint3\nint3\n";
1159   } else if (JumpTableArch == Triple::arm || JumpTableArch == Triple::aarch64) {
1160     AsmOS << "b $" << ArgIndex << "\n";
1161   } else if (JumpTableArch == Triple::thumb) {
1162     AsmOS << "b.w $" << ArgIndex << "\n";
1163   } else {
1164     report_fatal_error("Unsupported architecture for jump tables");
1165   }
1166 
1167   ConstraintOS << (ArgIndex > 0 ? ",s" : "s");
1168   AsmArgs.push_back(Dest);
1169 }
1170 
getJumpTableEntryType()1171 Type *LowerTypeTestsModule::getJumpTableEntryType() {
1172   return ArrayType::get(Int8Ty, getJumpTableEntrySize());
1173 }
1174 
1175 /// Given a disjoint set of type identifiers and functions, build the bit sets
1176 /// and lower the llvm.type.test calls, architecture dependently.
buildBitSetsFromFunctions(ArrayRef<Metadata * > TypeIds,ArrayRef<GlobalTypeMember * > Functions)1177 void LowerTypeTestsModule::buildBitSetsFromFunctions(
1178     ArrayRef<Metadata *> TypeIds, ArrayRef<GlobalTypeMember *> Functions) {
1179   if (Arch == Triple::x86 || Arch == Triple::x86_64 || Arch == Triple::arm ||
1180       Arch == Triple::thumb || Arch == Triple::aarch64)
1181     buildBitSetsFromFunctionsNative(TypeIds, Functions);
1182   else if (Arch == Triple::wasm32 || Arch == Triple::wasm64)
1183     buildBitSetsFromFunctionsWASM(TypeIds, Functions);
1184   else
1185     report_fatal_error("Unsupported architecture for jump tables");
1186 }
1187 
moveInitializerToModuleConstructor(GlobalVariable * GV)1188 void LowerTypeTestsModule::moveInitializerToModuleConstructor(
1189     GlobalVariable *GV) {
1190   if (WeakInitializerFn == nullptr) {
1191     WeakInitializerFn = Function::Create(
1192         FunctionType::get(Type::getVoidTy(M.getContext()),
1193                           /* IsVarArg */ false),
1194         GlobalValue::InternalLinkage, "__cfi_global_var_init", &M);
1195     BasicBlock *BB =
1196         BasicBlock::Create(M.getContext(), "entry", WeakInitializerFn);
1197     ReturnInst::Create(M.getContext(), BB);
1198     WeakInitializerFn->setSection(
1199         ObjectFormat == Triple::MachO
1200             ? "__TEXT,__StaticInit,regular,pure_instructions"
1201             : ".text.startup");
1202     // This code is equivalent to relocation application, and should run at the
1203     // earliest possible time (i.e. with the highest priority).
1204     appendToGlobalCtors(M, WeakInitializerFn, /* Priority */ 0);
1205   }
1206 
1207   IRBuilder<> IRB(WeakInitializerFn->getEntryBlock().getTerminator());
1208   GV->setConstant(false);
1209   IRB.CreateAlignedStore(GV->getInitializer(), GV, GV->getAlignment());
1210   GV->setInitializer(Constant::getNullValue(GV->getValueType()));
1211 }
1212 
findGlobalVariableUsersOf(Constant * C,SmallSetVector<GlobalVariable *,8> & Out)1213 void LowerTypeTestsModule::findGlobalVariableUsersOf(
1214     Constant *C, SmallSetVector<GlobalVariable *, 8> &Out) {
1215   for (auto *U : C->users()){
1216     if (auto *GV = dyn_cast<GlobalVariable>(U))
1217       Out.insert(GV);
1218     else if (auto *C2 = dyn_cast<Constant>(U))
1219       findGlobalVariableUsersOf(C2, Out);
1220   }
1221 }
1222 
1223 // Replace all uses of F with (F ? JT : 0).
replaceWeakDeclarationWithJumpTablePtr(Function * F,Constant * JT,bool IsDefinition)1224 void LowerTypeTestsModule::replaceWeakDeclarationWithJumpTablePtr(
1225     Function *F, Constant *JT, bool IsDefinition) {
1226   // The target expression can not appear in a constant initializer on most
1227   // (all?) targets. Switch to a runtime initializer.
1228   SmallSetVector<GlobalVariable *, 8> GlobalVarUsers;
1229   findGlobalVariableUsersOf(F, GlobalVarUsers);
1230   for (auto GV : GlobalVarUsers)
1231     moveInitializerToModuleConstructor(GV);
1232 
1233   // Can not RAUW F with an expression that uses F. Replace with a temporary
1234   // placeholder first.
1235   Function *PlaceholderFn =
1236       Function::Create(cast<FunctionType>(F->getValueType()),
1237                        GlobalValue::ExternalWeakLinkage, "", &M);
1238   replaceCfiUses(F, PlaceholderFn, IsDefinition);
1239 
1240   Constant *Target = ConstantExpr::getSelect(
1241       ConstantExpr::getICmp(CmpInst::ICMP_NE, F,
1242                             Constant::getNullValue(F->getType())),
1243       JT, Constant::getNullValue(F->getType()));
1244   PlaceholderFn->replaceAllUsesWith(Target);
1245   PlaceholderFn->eraseFromParent();
1246 }
1247 
isThumbFunction(Function * F,Triple::ArchType ModuleArch)1248 static bool isThumbFunction(Function *F, Triple::ArchType ModuleArch) {
1249   Attribute TFAttr = F->getFnAttribute("target-features");
1250   if (!TFAttr.hasAttribute(Attribute::None)) {
1251     SmallVector<StringRef, 6> Features;
1252     TFAttr.getValueAsString().split(Features, ',');
1253     for (StringRef Feature : Features) {
1254       if (Feature == "-thumb-mode")
1255         return false;
1256       else if (Feature == "+thumb-mode")
1257         return true;
1258     }
1259   }
1260 
1261   return ModuleArch == Triple::thumb;
1262 }
1263 
1264 // Each jump table must be either ARM or Thumb as a whole for the bit-test math
1265 // to work. Pick one that matches the majority of members to minimize interop
1266 // veneers inserted by the linker.
1267 static Triple::ArchType
selectJumpTableArmEncoding(ArrayRef<GlobalTypeMember * > Functions,Triple::ArchType ModuleArch)1268 selectJumpTableArmEncoding(ArrayRef<GlobalTypeMember *> Functions,
1269                            Triple::ArchType ModuleArch) {
1270   if (ModuleArch != Triple::arm && ModuleArch != Triple::thumb)
1271     return ModuleArch;
1272 
1273   unsigned ArmCount = 0, ThumbCount = 0;
1274   for (const auto GTM : Functions) {
1275     if (!GTM->isDefinition()) {
1276       // PLT stubs are always ARM.
1277       ++ArmCount;
1278       continue;
1279     }
1280 
1281     Function *F = cast<Function>(GTM->getGlobal());
1282     ++(isThumbFunction(F, ModuleArch) ? ThumbCount : ArmCount);
1283   }
1284 
1285   return ArmCount > ThumbCount ? Triple::arm : Triple::thumb;
1286 }
1287 
createJumpTable(Function * F,ArrayRef<GlobalTypeMember * > Functions)1288 void LowerTypeTestsModule::createJumpTable(
1289     Function *F, ArrayRef<GlobalTypeMember *> Functions) {
1290   std::string AsmStr, ConstraintStr;
1291   raw_string_ostream AsmOS(AsmStr), ConstraintOS(ConstraintStr);
1292   SmallVector<Value *, 16> AsmArgs;
1293   AsmArgs.reserve(Functions.size() * 2);
1294 
1295   Triple::ArchType JumpTableArch = selectJumpTableArmEncoding(Functions, Arch);
1296 
1297   for (unsigned I = 0; I != Functions.size(); ++I)
1298     createJumpTableEntry(AsmOS, ConstraintOS, JumpTableArch, AsmArgs,
1299                          cast<Function>(Functions[I]->getGlobal()));
1300 
1301   // Align the whole table by entry size.
1302   F->setAlignment(getJumpTableEntrySize());
1303   // Skip prologue.
1304   // Disabled on win32 due to https://llvm.org/bugs/show_bug.cgi?id=28641#c3.
1305   // Luckily, this function does not get any prologue even without the
1306   // attribute.
1307   if (OS != Triple::Win32)
1308     F->addFnAttr(Attribute::Naked);
1309   if (JumpTableArch == Triple::arm)
1310     F->addFnAttr("target-features", "-thumb-mode");
1311   if (JumpTableArch == Triple::thumb) {
1312     F->addFnAttr("target-features", "+thumb-mode");
1313     // Thumb jump table assembly needs Thumb2. The following attribute is added
1314     // by Clang for -march=armv7.
1315     F->addFnAttr("target-cpu", "cortex-a8");
1316   }
1317   // Make sure we don't emit .eh_frame for this function.
1318   F->addFnAttr(Attribute::NoUnwind);
1319 
1320   BasicBlock *BB = BasicBlock::Create(M.getContext(), "entry", F);
1321   IRBuilder<> IRB(BB);
1322 
1323   SmallVector<Type *, 16> ArgTypes;
1324   ArgTypes.reserve(AsmArgs.size());
1325   for (const auto &Arg : AsmArgs)
1326     ArgTypes.push_back(Arg->getType());
1327   InlineAsm *JumpTableAsm =
1328       InlineAsm::get(FunctionType::get(IRB.getVoidTy(), ArgTypes, false),
1329                      AsmOS.str(), ConstraintOS.str(),
1330                      /*hasSideEffects=*/true);
1331 
1332   IRB.CreateCall(JumpTableAsm, AsmArgs);
1333   IRB.CreateUnreachable();
1334 }
1335 
/// Given a disjoint set of type identifiers and functions, build a jump table
/// for the functions, build the bit sets and lower the llvm.type.test calls.
void LowerTypeTestsModule::buildBitSetsFromFunctionsNative(
    ArrayRef<Metadata *> TypeIds, ArrayRef<GlobalTypeMember *> Functions) {
  // Unlike the global bitset builder, the function bitset builder cannot
  // re-arrange functions in a particular order and base its calculations on the
  // layout of the functions' entry points, as we have no idea how large a
  // particular function will end up being (the size could even depend on what
  // this pass does!) Instead, we build a jump table, which is a block of code
  // consisting of one branch instruction for each of the functions in the bit
  // set that branches to the target function, and redirect any taken function
  // addresses to the corresponding jump table entry. In the object file's
  // symbol table, the symbols for the target functions also refer to the jump
  // table entries, so that addresses taken outside the module will pass any
  // verification done inside the module.
  //
  // In more concrete terms, suppose we have three functions f, g, h which are
  // of the same type, and a function foo that returns their addresses:
  //
  // f:
  // mov 0, %eax
  // ret
  //
  // g:
  // mov 1, %eax
  // ret
  //
  // h:
  // mov 2, %eax
  // ret
  //
  // foo:
  // mov f, %eax
  // mov g, %edx
  // mov h, %ecx
  // ret
  //
  // We emit the jump table as an inline asm string in the body of the private
  // function .cfi.jumptable (see createJumpTable). On x86, where each entry is
  // 8 bytes, the end result will (conceptually) look like this:
  //
  // f = .cfi.jumptable
  // g = .cfi.jumptable + 8
  // h = .cfi.jumptable + 16
  // .cfi.jumptable:
  // jmp f.cfi  ; 5 bytes
  // int3       ; 1 byte
  // int3       ; 1 byte
  // int3       ; 1 byte
  // jmp g.cfi  ; 5 bytes
  // int3       ; 1 byte
  // int3       ; 1 byte
  // int3       ; 1 byte
  // jmp h.cfi  ; 5 bytes
  // int3       ; 1 byte
  // int3       ; 1 byte
  // int3       ; 1 byte
  //
  // f.cfi:
  // mov 0, %eax
  // ret
  //
  // g.cfi:
  // mov 1, %eax
  // ret
  //
  // h.cfi:
  // mov 2, %eax
  // ret
  //
  // foo:
  // mov f, %eax
  // mov g, %edx
  // mov h, %ecx
  // ret
  //
  // Because the addresses of f, g, h are evenly spaced at a power of 2, in the
  // normal case the check can be carried out using the same kind of simple
  // arithmetic that we normally use for globals.

  // FIXME: find a better way to represent the jumptable in the IR.
  assert(!Functions.empty());

  // Build a simple layout based on the regular layout of jump tables: member I
  // starts I entries into the table.
  DenseMap<GlobalTypeMember *, uint64_t> GlobalLayout;
  unsigned EntrySize = getJumpTableEntrySize();
  for (unsigned I = 0; I != Functions.size(); ++I)
    GlobalLayout[Functions[I]] = I * EntrySize;

  // The jump table is a private void() function whose body is filled in by
  // createJumpTable() at the end. It is addressed as an array of entries.
  Function *JumpTableFn =
      Function::Create(FunctionType::get(Type::getVoidTy(M.getContext()),
                                         /* IsVarArg */ false),
                       GlobalValue::PrivateLinkage, ".cfi.jumptable", &M);
  ArrayType *JumpTableType =
      ArrayType::get(getJumpTableEntryType(), Functions.size());
  auto JumpTable =
      ConstantExpr::getPointerCast(JumpTableFn, JumpTableType->getPointerTo(0));

  lowerTypeTestCalls(TypeIds, JumpTable, GlobalLayout);

  // Build aliases pointing to offsets into the jump table, and replace
  // references to the original functions with references to the aliases.
  for (unsigned I = 0; I != Functions.size(); ++I) {
    Function *F = cast<Function>(Functions[I]->getGlobal());
    bool IsDefinition = Functions[I]->isDefinition();

    // Address of entry I of the jump table, cast to F's pointer type.
    Constant *CombinedGlobalElemPtr = ConstantExpr::getBitCast(
        ConstantExpr::getInBoundsGetElementPtr(
            JumpTableType, JumpTable,
            ArrayRef<Constant *>{ConstantInt::get(IntPtrTy, 0),
                                 ConstantInt::get(IntPtrTy, I)}),
        F->getType());
    // Record exported members in the summary. An exported declaration also
    // gets a hidden "<name>.cfi_jt" alias to its jump table entry.
    if (Functions[I]->isExported()) {
      if (IsDefinition) {
        ExportSummary->cfiFunctionDefs().insert(F->getName());
      } else {
        GlobalAlias *JtAlias = GlobalAlias::create(
            F->getValueType(), 0, GlobalValue::ExternalLinkage,
            F->getName() + ".cfi_jt", CombinedGlobalElemPtr, &M);
        JtAlias->setVisibility(GlobalValue::HiddenVisibility);
        ExportSummary->cfiFunctionDecls().insert(F->getName());
      }
    }
    if (!IsDefinition) {
      // Declarations: redirect CFI uses to the jump table entry; weak
      // declarations keep a null check (see
      // replaceWeakDeclarationWithJumpTablePtr).
      if (F->isWeakForLinker())
        replaceWeakDeclarationWithJumpTablePtr(F, CombinedGlobalElemPtr, IsDefinition);
      else
        replaceCfiUses(F, CombinedGlobalElemPtr, IsDefinition);
    } else {
      assert(F->getType()->getAddressSpace() == 0);

      // Definitions: an alias to the jump table entry takes over F's name,
      // linkage and visibility, and the body is renamed to "<name>.cfi".
      GlobalAlias *FAlias = GlobalAlias::create(
          F->getValueType(), 0, F->getLinkage(), "", CombinedGlobalElemPtr, &M);
      FAlias->setVisibility(F->getVisibility());
      FAlias->takeName(F);
      if (FAlias->hasName())
        F->setName(FAlias->getName() + ".cfi");
      replaceCfiUses(F, FAlias, IsDefinition);
      if (!F->hasLocalLinkage())
        F->setVisibility(GlobalVariable::HiddenVisibility);
    }
  }

  createJumpTable(JumpTableFn, Functions);
}
1480 
1481 /// Assign a dummy layout using an incrementing counter, tag each function
1482 /// with its index represented as metadata, and lower each type test to an
1483 /// integer range comparison. During generation of the indirect function call
1484 /// table in the backend, it will assign the given indexes.
1485 /// Note: Dynamic linking is not supported, as the WebAssembly ABI has not yet
1486 /// been finalized.
buildBitSetsFromFunctionsWASM(ArrayRef<Metadata * > TypeIds,ArrayRef<GlobalTypeMember * > Functions)1487 void LowerTypeTestsModule::buildBitSetsFromFunctionsWASM(
1488     ArrayRef<Metadata *> TypeIds, ArrayRef<GlobalTypeMember *> Functions) {
1489   assert(!Functions.empty());
1490 
1491   // Build consecutive monotonic integer ranges for each call target set
1492   DenseMap<GlobalTypeMember *, uint64_t> GlobalLayout;
1493 
1494   for (GlobalTypeMember *GTM : Functions) {
1495     Function *F = cast<Function>(GTM->getGlobal());
1496 
1497     // Skip functions that are not address taken, to avoid bloating the table
1498     if (!F->hasAddressTaken())
1499       continue;
1500 
1501     // Store metadata with the index for each function
1502     MDNode *MD = MDNode::get(F->getContext(),
1503                              ArrayRef<Metadata *>(ConstantAsMetadata::get(
1504                                  ConstantInt::get(Int64Ty, IndirectIndex))));
1505     F->setMetadata("wasm.index", MD);
1506 
1507     // Assign the counter value
1508     GlobalLayout[GTM] = IndirectIndex++;
1509   }
1510 
1511   // The indirect function table index space starts at zero, so pass a NULL
1512   // pointer as the subtracted "jump table" offset.
1513   lowerTypeTestCalls(TypeIds, ConstantPointerNull::get(Int32PtrTy),
1514                      GlobalLayout);
1515 }
1516 
buildBitSetsFromDisjointSet(ArrayRef<Metadata * > TypeIds,ArrayRef<GlobalTypeMember * > Globals,ArrayRef<ICallBranchFunnel * > ICallBranchFunnels)1517 void LowerTypeTestsModule::buildBitSetsFromDisjointSet(
1518     ArrayRef<Metadata *> TypeIds, ArrayRef<GlobalTypeMember *> Globals,
1519     ArrayRef<ICallBranchFunnel *> ICallBranchFunnels) {
1520   DenseMap<Metadata *, uint64_t> TypeIdIndices;
1521   for (unsigned I = 0; I != TypeIds.size(); ++I)
1522     TypeIdIndices[TypeIds[I]] = I;
1523 
1524   // For each type identifier, build a set of indices that refer to members of
1525   // the type identifier.
1526   std::vector<std::set<uint64_t>> TypeMembers(TypeIds.size());
1527   unsigned GlobalIndex = 0;
1528   DenseMap<GlobalTypeMember *, uint64_t> GlobalIndices;
1529   for (GlobalTypeMember *GTM : Globals) {
1530     for (MDNode *Type : GTM->types()) {
1531       // Type = { offset, type identifier }
1532       auto I = TypeIdIndices.find(Type->getOperand(1));
1533       if (I != TypeIdIndices.end())
1534         TypeMembers[I->second].insert(GlobalIndex);
1535     }
1536     GlobalIndices[GTM] = GlobalIndex;
1537     GlobalIndex++;
1538   }
1539 
1540   for (ICallBranchFunnel *JT : ICallBranchFunnels) {
1541     TypeMembers.emplace_back();
1542     std::set<uint64_t> &TMSet = TypeMembers.back();
1543     for (GlobalTypeMember *T : JT->targets())
1544       TMSet.insert(GlobalIndices[T]);
1545   }
1546 
1547   // Order the sets of indices by size. The GlobalLayoutBuilder works best
1548   // when given small index sets first.
1549   std::stable_sort(
1550       TypeMembers.begin(), TypeMembers.end(),
1551       [](const std::set<uint64_t> &O1, const std::set<uint64_t> &O2) {
1552         return O1.size() < O2.size();
1553       });
1554 
1555   // Create a GlobalLayoutBuilder and provide it with index sets as layout
1556   // fragments. The GlobalLayoutBuilder tries to lay out members of fragments as
1557   // close together as possible.
1558   GlobalLayoutBuilder GLB(Globals.size());
1559   for (auto &&MemSet : TypeMembers)
1560     GLB.addFragment(MemSet);
1561 
1562   // Build a vector of globals with the computed layout.
1563   bool IsGlobalSet =
1564       Globals.empty() || isa<GlobalVariable>(Globals[0]->getGlobal());
1565   std::vector<GlobalTypeMember *> OrderedGTMs(Globals.size());
1566   auto OGTMI = OrderedGTMs.begin();
1567   for (auto &&F : GLB.Fragments) {
1568     for (auto &&Offset : F) {
1569       if (IsGlobalSet != isa<GlobalVariable>(Globals[Offset]->getGlobal()))
1570         report_fatal_error("Type identifier may not contain both global "
1571                            "variables and functions");
1572       *OGTMI++ = Globals[Offset];
1573     }
1574   }
1575 
1576   // Build the bitsets from this disjoint set.
1577   if (IsGlobalSet)
1578     buildBitSetsFromGlobalVariables(TypeIds, OrderedGTMs);
1579   else
1580     buildBitSetsFromFunctions(TypeIds, OrderedGTMs);
1581 }
1582 
1583 /// Lower all type tests in this module.
LowerTypeTestsModule(Module & M,ModuleSummaryIndex * ExportSummary,const ModuleSummaryIndex * ImportSummary)1584 LowerTypeTestsModule::LowerTypeTestsModule(
1585     Module &M, ModuleSummaryIndex *ExportSummary,
1586     const ModuleSummaryIndex *ImportSummary)
1587     : M(M), ExportSummary(ExportSummary), ImportSummary(ImportSummary) {
1588   assert(!(ExportSummary && ImportSummary));
1589   Triple TargetTriple(M.getTargetTriple());
1590   Arch = TargetTriple.getArch();
1591   OS = TargetTriple.getOS();
1592   ObjectFormat = TargetTriple.getObjectFormat();
1593 }
1594 
runForTesting(Module & M)1595 bool LowerTypeTestsModule::runForTesting(Module &M) {
1596   ModuleSummaryIndex Summary(/*HaveGVs=*/false);
1597 
1598   // Handle the command-line summary arguments. This code is for testing
1599   // purposes only, so we handle errors directly.
1600   if (!ClReadSummary.empty()) {
1601     ExitOnError ExitOnErr("-lowertypetests-read-summary: " + ClReadSummary +
1602                           ": ");
1603     auto ReadSummaryFile =
1604         ExitOnErr(errorOrToExpected(MemoryBuffer::getFile(ClReadSummary)));
1605 
1606     yaml::Input In(ReadSummaryFile->getBuffer());
1607     In >> Summary;
1608     ExitOnErr(errorCodeToError(In.error()));
1609   }
1610 
1611   bool Changed =
1612       LowerTypeTestsModule(
1613           M, ClSummaryAction == PassSummaryAction::Export ? &Summary : nullptr,
1614           ClSummaryAction == PassSummaryAction::Import ? &Summary : nullptr)
1615           .lower();
1616 
1617   if (!ClWriteSummary.empty()) {
1618     ExitOnError ExitOnErr("-lowertypetests-write-summary: " + ClWriteSummary +
1619                           ": ");
1620     std::error_code EC;
1621     raw_fd_ostream OS(ClWriteSummary, EC, sys::fs::F_Text);
1622     ExitOnErr(errorCodeToError(EC));
1623 
1624     yaml::Output Out(OS);
1625     Out << Summary;
1626   }
1627 
1628   return Changed;
1629 }
1630 
isDirectCall(Use & U)1631 static bool isDirectCall(Use& U) {
1632   auto *Usr = dyn_cast<CallInst>(U.getUser());
1633   if (Usr) {
1634     CallSite CS(Usr);
1635     if (CS.isCallee(&U))
1636       return true;
1637   }
1638   return false;
1639 }
1640 
replaceCfiUses(Function * Old,Value * New,bool IsDefinition)1641 void LowerTypeTestsModule::replaceCfiUses(Function *Old, Value *New, bool IsDefinition) {
1642   SmallSetVector<Constant *, 4> Constants;
1643   auto UI = Old->use_begin(), E = Old->use_end();
1644   for (; UI != E;) {
1645     Use &U = *UI;
1646     ++UI;
1647 
1648     // Skip block addresses
1649     if (isa<BlockAddress>(U.getUser()))
1650       continue;
1651 
1652     // Skip direct calls to externally defined or non-dso_local functions
1653     if (isDirectCall(U) && (Old->isDSOLocal() || !IsDefinition))
1654       continue;
1655 
1656     // Must handle Constants specially, we cannot call replaceUsesOfWith on a
1657     // constant because they are uniqued.
1658     if (auto *C = dyn_cast<Constant>(U.getUser())) {
1659       if (!isa<GlobalValue>(C)) {
1660         // Save unique users to avoid processing operand replacement
1661         // more than once.
1662         Constants.insert(C);
1663         continue;
1664       }
1665     }
1666 
1667     U.set(New);
1668   }
1669 
1670   // Process operand replacement of saved constants.
1671   for (auto *C : Constants)
1672     C->handleOperandChange(Old, New);
1673 }
1674 
replaceDirectCalls(Value * Old,Value * New)1675 void LowerTypeTestsModule::replaceDirectCalls(Value *Old, Value *New) {
1676   auto UI = Old->use_begin(), E = Old->use_end();
1677   for (; UI != E;) {
1678     Use &U = *UI;
1679     ++UI;
1680 
1681     if (!isDirectCall(U))
1682       continue;
1683 
1684     U.set(New);
1685   }
1686 }
1687 
lower()1688 bool LowerTypeTestsModule::lower() {
1689   Function *TypeTestFunc =
1690       M.getFunction(Intrinsic::getName(Intrinsic::type_test));
1691   Function *ICallBranchFunnelFunc =
1692       M.getFunction(Intrinsic::getName(Intrinsic::icall_branch_funnel));
1693   if ((!TypeTestFunc || TypeTestFunc->use_empty()) &&
1694       (!ICallBranchFunnelFunc || ICallBranchFunnelFunc->use_empty()) &&
1695       !ExportSummary && !ImportSummary)
1696     return false;
1697 
1698   if (ImportSummary) {
1699     if (TypeTestFunc) {
1700       for (auto UI = TypeTestFunc->use_begin(), UE = TypeTestFunc->use_end();
1701            UI != UE;) {
1702         auto *CI = cast<CallInst>((*UI++).getUser());
1703         importTypeTest(CI);
1704       }
1705     }
1706 
1707     if (ICallBranchFunnelFunc && !ICallBranchFunnelFunc->use_empty())
1708       report_fatal_error(
1709           "unexpected call to llvm.icall.branch.funnel during import phase");
1710 
1711     SmallVector<Function *, 8> Defs;
1712     SmallVector<Function *, 8> Decls;
1713     for (auto &F : M) {
1714       // CFI functions are either external, or promoted. A local function may
1715       // have the same name, but it's not the one we are looking for.
1716       if (F.hasLocalLinkage())
1717         continue;
1718       if (ImportSummary->cfiFunctionDefs().count(F.getName()))
1719         Defs.push_back(&F);
1720       else if (ImportSummary->cfiFunctionDecls().count(F.getName()))
1721         Decls.push_back(&F);
1722     }
1723 
1724     for (auto F : Defs)
1725       importFunction(F, /*isDefinition*/ true);
1726     for (auto F : Decls)
1727       importFunction(F, /*isDefinition*/ false);
1728 
1729     return true;
1730   }
1731 
1732   // Equivalence class set containing type identifiers and the globals that
1733   // reference them. This is used to partition the set of type identifiers in
1734   // the module into disjoint sets.
1735   using GlobalClassesTy = EquivalenceClasses<
1736       PointerUnion3<GlobalTypeMember *, Metadata *, ICallBranchFunnel *>>;
1737   GlobalClassesTy GlobalClasses;
1738 
1739   // Verify the type metadata and build a few data structures to let us
1740   // efficiently enumerate the type identifiers associated with a global:
1741   // a list of GlobalTypeMembers (a GlobalObject stored alongside a vector
1742   // of associated type metadata) and a mapping from type identifiers to their
1743   // list of GlobalTypeMembers and last observed index in the list of globals.
1744   // The indices will be used later to deterministically order the list of type
1745   // identifiers.
1746   BumpPtrAllocator Alloc;
1747   struct TIInfo {
1748     unsigned UniqueId;
1749     std::vector<GlobalTypeMember *> RefGlobals;
1750   };
1751   DenseMap<Metadata *, TIInfo> TypeIdInfo;
1752   unsigned CurUniqueId = 0;
1753   SmallVector<MDNode *, 2> Types;
1754 
1755   // Cross-DSO CFI emits jumptable entries for exported functions as well as
1756   // address taken functions in case they are address taken in other modules.
1757   const bool CrossDsoCfi = M.getModuleFlag("Cross-DSO CFI") != nullptr;
1758 
1759   struct ExportedFunctionInfo {
1760     CfiFunctionLinkage Linkage;
1761     MDNode *FuncMD; // {name, linkage, type[, type...]}
1762   };
1763   DenseMap<StringRef, ExportedFunctionInfo> ExportedFunctions;
1764   if (ExportSummary) {
1765     // A set of all functions that are address taken by a live global object.
1766     DenseSet<GlobalValue::GUID> AddressTaken;
1767     for (auto &I : *ExportSummary)
1768       for (auto &GVS : I.second.SummaryList)
1769         if (GVS->isLive())
1770           for (auto &Ref : GVS->refs())
1771             AddressTaken.insert(Ref.getGUID());
1772 
1773     NamedMDNode *CfiFunctionsMD = M.getNamedMetadata("cfi.functions");
1774     if (CfiFunctionsMD) {
1775       for (auto FuncMD : CfiFunctionsMD->operands()) {
1776         assert(FuncMD->getNumOperands() >= 2);
1777         StringRef FunctionName =
1778             cast<MDString>(FuncMD->getOperand(0))->getString();
1779         CfiFunctionLinkage Linkage = static_cast<CfiFunctionLinkage>(
1780             cast<ConstantAsMetadata>(FuncMD->getOperand(1))
1781                 ->getValue()
1782                 ->getUniqueInteger()
1783                 .getZExtValue());
1784         const GlobalValue::GUID GUID = GlobalValue::getGUID(
1785                 GlobalValue::dropLLVMManglingEscape(FunctionName));
1786         // Do not emit jumptable entries for functions that are not-live and
1787         // have no live references (and are not exported with cross-DSO CFI.)
1788         if (!ExportSummary->isGUIDLive(GUID))
1789           continue;
1790         if (!AddressTaken.count(GUID)) {
1791           if (!CrossDsoCfi || Linkage != CFL_Definition)
1792             continue;
1793 
1794           bool Exported = false;
1795           if (auto VI = ExportSummary->getValueInfo(GUID))
1796             for (auto &GVS : VI.getSummaryList())
1797               if (GVS->isLive() && !GlobalValue::isLocalLinkage(GVS->linkage()))
1798                 Exported = true;
1799 
1800           if (!Exported)
1801             continue;
1802         }
1803         auto P = ExportedFunctions.insert({FunctionName, {Linkage, FuncMD}});
1804         if (!P.second && P.first->second.Linkage != CFL_Definition)
1805           P.first->second = {Linkage, FuncMD};
1806       }
1807 
1808       for (const auto &P : ExportedFunctions) {
1809         StringRef FunctionName = P.first;
1810         CfiFunctionLinkage Linkage = P.second.Linkage;
1811         MDNode *FuncMD = P.second.FuncMD;
1812         Function *F = M.getFunction(FunctionName);
1813         if (!F)
1814           F = Function::Create(
1815               FunctionType::get(Type::getVoidTy(M.getContext()), false),
1816               GlobalVariable::ExternalLinkage, FunctionName, &M);
1817 
1818         // If the function is available_externally, remove its definition so
1819         // that it is handled the same way as a declaration. Later we will try
1820         // to create an alias using this function's linkage, which will fail if
1821         // the linkage is available_externally. This will also result in us
1822         // following the code path below to replace the type metadata.
1823         if (F->hasAvailableExternallyLinkage()) {
1824           F->setLinkage(GlobalValue::ExternalLinkage);
1825           F->deleteBody();
1826           F->setComdat(nullptr);
1827           F->clearMetadata();
1828         }
1829 
1830         // Update the linkage for extern_weak declarations when a definition
1831         // exists.
1832         if (Linkage == CFL_Definition && F->hasExternalWeakLinkage())
1833           F->setLinkage(GlobalValue::ExternalLinkage);
1834 
1835         // If the function in the full LTO module is a declaration, replace its
1836         // type metadata with the type metadata we found in cfi.functions. That
1837         // metadata is presumed to be more accurate than the metadata attached
1838         // to the declaration.
1839         if (F->isDeclaration()) {
1840           if (Linkage == CFL_WeakDeclaration)
1841             F->setLinkage(GlobalValue::ExternalWeakLinkage);
1842 
1843           F->eraseMetadata(LLVMContext::MD_type);
1844           for (unsigned I = 2; I < FuncMD->getNumOperands(); ++I)
1845             F->addMetadata(LLVMContext::MD_type,
1846                            *cast<MDNode>(FuncMD->getOperand(I).get()));
1847         }
1848       }
1849     }
1850   }
1851 
1852   DenseMap<GlobalObject *, GlobalTypeMember *> GlobalTypeMembers;
1853   for (GlobalObject &GO : M.global_objects()) {
1854     if (isa<GlobalVariable>(GO) && GO.isDeclarationForLinker())
1855       continue;
1856 
1857     Types.clear();
1858     GO.getMetadata(LLVMContext::MD_type, Types);
1859 
1860     bool IsDefinition = !GO.isDeclarationForLinker();
1861     bool IsExported = false;
1862     if (Function *F = dyn_cast<Function>(&GO)) {
1863       if (ExportedFunctions.count(F->getName())) {
1864         IsDefinition |= ExportedFunctions[F->getName()].Linkage == CFL_Definition;
1865         IsExported = true;
1866       // TODO: The logic here checks only that the function is address taken,
1867       // not that the address takers are live. This can be updated to check
1868       // their liveness and emit fewer jumptable entries once monolithic LTO
1869       // builds also emit summaries.
1870       } else if (!F->hasAddressTaken()) {
1871         if (!CrossDsoCfi || !IsDefinition || F->hasLocalLinkage())
1872           continue;
1873       }
1874     }
1875 
1876     auto *GTM =
1877         GlobalTypeMember::create(Alloc, &GO, IsDefinition, IsExported, Types);
1878     GlobalTypeMembers[&GO] = GTM;
1879     for (MDNode *Type : Types) {
1880       verifyTypeMDNode(&GO, Type);
1881       auto &Info = TypeIdInfo[Type->getOperand(1)];
1882       Info.UniqueId = ++CurUniqueId;
1883       Info.RefGlobals.push_back(GTM);
1884     }
1885   }
1886 
1887   auto AddTypeIdUse = [&](Metadata *TypeId) -> TypeIdUserInfo & {
1888     // Add the call site to the list of call sites for this type identifier. We
1889     // also use TypeIdUsers to keep track of whether we have seen this type
1890     // identifier before. If we have, we don't need to re-add the referenced
1891     // globals to the equivalence class.
1892     auto Ins = TypeIdUsers.insert({TypeId, {}});
1893     if (Ins.second) {
1894       // Add the type identifier to the equivalence class.
1895       GlobalClassesTy::iterator GCI = GlobalClasses.insert(TypeId);
1896       GlobalClassesTy::member_iterator CurSet = GlobalClasses.findLeader(GCI);
1897 
1898       // Add the referenced globals to the type identifier's equivalence class.
1899       for (GlobalTypeMember *GTM : TypeIdInfo[TypeId].RefGlobals)
1900         CurSet = GlobalClasses.unionSets(
1901             CurSet, GlobalClasses.findLeader(GlobalClasses.insert(GTM)));
1902     }
1903 
1904     return Ins.first->second;
1905   };
1906 
1907   if (TypeTestFunc) {
1908     for (const Use &U : TypeTestFunc->uses()) {
1909       auto CI = cast<CallInst>(U.getUser());
1910 
1911       auto TypeIdMDVal = dyn_cast<MetadataAsValue>(CI->getArgOperand(1));
1912       if (!TypeIdMDVal)
1913         report_fatal_error("Second argument of llvm.type.test must be metadata");
1914       auto TypeId = TypeIdMDVal->getMetadata();
1915       AddTypeIdUse(TypeId).CallSites.push_back(CI);
1916     }
1917   }
1918 
1919   if (ICallBranchFunnelFunc) {
1920     for (const Use &U : ICallBranchFunnelFunc->uses()) {
1921       if (Arch != Triple::x86_64)
1922         report_fatal_error(
1923             "llvm.icall.branch.funnel not supported on this target");
1924 
1925       auto CI = cast<CallInst>(U.getUser());
1926 
1927       std::vector<GlobalTypeMember *> Targets;
1928       if (CI->getNumArgOperands() % 2 != 1)
1929         report_fatal_error("number of arguments should be odd");
1930 
1931       GlobalClassesTy::member_iterator CurSet;
1932       for (unsigned I = 1; I != CI->getNumArgOperands(); I += 2) {
1933         int64_t Offset;
1934         auto *Base = dyn_cast<GlobalObject>(GetPointerBaseWithConstantOffset(
1935             CI->getOperand(I), Offset, M.getDataLayout()));
1936         if (!Base)
1937           report_fatal_error(
1938               "Expected branch funnel operand to be global value");
1939 
1940         GlobalTypeMember *GTM = GlobalTypeMembers[Base];
1941         Targets.push_back(GTM);
1942         GlobalClassesTy::member_iterator NewSet =
1943             GlobalClasses.findLeader(GlobalClasses.insert(GTM));
1944         if (I == 1)
1945           CurSet = NewSet;
1946         else
1947           CurSet = GlobalClasses.unionSets(CurSet, NewSet);
1948       }
1949 
1950       GlobalClasses.unionSets(
1951           CurSet, GlobalClasses.findLeader(
1952                       GlobalClasses.insert(ICallBranchFunnel::create(
1953                           Alloc, CI, Targets, ++CurUniqueId))));
1954     }
1955   }
1956 
1957   if (ExportSummary) {
1958     DenseMap<GlobalValue::GUID, TinyPtrVector<Metadata *>> MetadataByGUID;
1959     for (auto &P : TypeIdInfo) {
1960       if (auto *TypeId = dyn_cast<MDString>(P.first))
1961         MetadataByGUID[GlobalValue::getGUID(TypeId->getString())].push_back(
1962             TypeId);
1963     }
1964 
1965     for (auto &P : *ExportSummary) {
1966       for (auto &S : P.second.SummaryList) {
1967         if (!ExportSummary->isGlobalValueLive(S.get()))
1968           continue;
1969         if (auto *FS = dyn_cast<FunctionSummary>(S->getBaseObject()))
1970           for (GlobalValue::GUID G : FS->type_tests())
1971             for (Metadata *MD : MetadataByGUID[G])
1972               AddTypeIdUse(MD).IsExported = true;
1973       }
1974     }
1975   }
1976 
1977   if (GlobalClasses.empty())
1978     return false;
1979 
1980   // Build a list of disjoint sets ordered by their maximum global index for
1981   // determinism.
1982   std::vector<std::pair<GlobalClassesTy::iterator, unsigned>> Sets;
1983   for (GlobalClassesTy::iterator I = GlobalClasses.begin(),
1984                                  E = GlobalClasses.end();
1985        I != E; ++I) {
1986     if (!I->isLeader())
1987       continue;
1988     ++NumTypeIdDisjointSets;
1989 
1990     unsigned MaxUniqueId = 0;
1991     for (GlobalClassesTy::member_iterator MI = GlobalClasses.member_begin(I);
1992          MI != GlobalClasses.member_end(); ++MI) {
1993       if (auto *MD = MI->dyn_cast<Metadata *>())
1994         MaxUniqueId = std::max(MaxUniqueId, TypeIdInfo[MD].UniqueId);
1995       else if (auto *BF = MI->dyn_cast<ICallBranchFunnel *>())
1996         MaxUniqueId = std::max(MaxUniqueId, BF->UniqueId);
1997     }
1998     Sets.emplace_back(I, MaxUniqueId);
1999   }
2000   llvm::sort(Sets.begin(), Sets.end(),
2001              [](const std::pair<GlobalClassesTy::iterator, unsigned> &S1,
2002                 const std::pair<GlobalClassesTy::iterator, unsigned> &S2) {
2003                return S1.second < S2.second;
2004              });
2005 
2006   // For each disjoint set we found...
2007   for (const auto &S : Sets) {
2008     // Build the list of type identifiers in this disjoint set.
2009     std::vector<Metadata *> TypeIds;
2010     std::vector<GlobalTypeMember *> Globals;
2011     std::vector<ICallBranchFunnel *> ICallBranchFunnels;
2012     for (GlobalClassesTy::member_iterator MI =
2013              GlobalClasses.member_begin(S.first);
2014          MI != GlobalClasses.member_end(); ++MI) {
2015       if (MI->is<Metadata *>())
2016         TypeIds.push_back(MI->get<Metadata *>());
2017       else if (MI->is<GlobalTypeMember *>())
2018         Globals.push_back(MI->get<GlobalTypeMember *>());
2019       else
2020         ICallBranchFunnels.push_back(MI->get<ICallBranchFunnel *>());
2021     }
2022 
2023     // Order type identifiers by unique ID for determinism. This ordering is
2024     // stable as there is a one-to-one mapping between metadata and unique IDs.
2025     llvm::sort(TypeIds.begin(), TypeIds.end(), [&](Metadata *M1, Metadata *M2) {
2026       return TypeIdInfo[M1].UniqueId < TypeIdInfo[M2].UniqueId;
2027     });
2028 
2029     // Same for the branch funnels.
2030     llvm::sort(ICallBranchFunnels.begin(), ICallBranchFunnels.end(),
2031                [&](ICallBranchFunnel *F1, ICallBranchFunnel *F2) {
2032                  return F1->UniqueId < F2->UniqueId;
2033                });
2034 
2035     // Build bitsets for this disjoint set.
2036     buildBitSetsFromDisjointSet(TypeIds, Globals, ICallBranchFunnels);
2037   }
2038 
2039   allocateByteArrays();
2040 
2041   // Parse alias data to replace stand-in function declarations for aliases
2042   // with an alias to the intended target.
2043   if (ExportSummary) {
2044     if (NamedMDNode *AliasesMD = M.getNamedMetadata("aliases")) {
2045       for (auto AliasMD : AliasesMD->operands()) {
2046         assert(AliasMD->getNumOperands() >= 4);
2047         StringRef AliasName =
2048             cast<MDString>(AliasMD->getOperand(0))->getString();
2049         StringRef Aliasee = cast<MDString>(AliasMD->getOperand(1))->getString();
2050 
2051         if (!ExportedFunctions.count(Aliasee) ||
2052             ExportedFunctions[Aliasee].Linkage != CFL_Definition ||
2053             !M.getNamedAlias(Aliasee))
2054           continue;
2055 
2056         GlobalValue::VisibilityTypes Visibility =
2057             static_cast<GlobalValue::VisibilityTypes>(
2058                 cast<ConstantAsMetadata>(AliasMD->getOperand(2))
2059                     ->getValue()
2060                     ->getUniqueInteger()
2061                     .getZExtValue());
2062         bool Weak =
2063             static_cast<bool>(cast<ConstantAsMetadata>(AliasMD->getOperand(3))
2064                                   ->getValue()
2065                                   ->getUniqueInteger()
2066                                   .getZExtValue());
2067 
2068         auto *Alias = GlobalAlias::create("", M.getNamedAlias(Aliasee));
2069         Alias->setVisibility(Visibility);
2070         if (Weak)
2071           Alias->setLinkage(GlobalValue::WeakAnyLinkage);
2072 
2073         if (auto *F = M.getFunction(AliasName)) {
2074           Alias->takeName(F);
2075           F->replaceAllUsesWith(Alias);
2076           F->eraseFromParent();
2077         } else {
2078           Alias->setName(AliasName);
2079         }
2080       }
2081     }
2082   }
2083 
2084   // Emit .symver directives for exported functions, if they exist.
2085   if (ExportSummary) {
2086     if (NamedMDNode *SymversMD = M.getNamedMetadata("symvers")) {
2087       for (auto Symver : SymversMD->operands()) {
2088         assert(Symver->getNumOperands() >= 2);
2089         StringRef SymbolName =
2090             cast<MDString>(Symver->getOperand(0))->getString();
2091         StringRef Alias = cast<MDString>(Symver->getOperand(1))->getString();
2092 
2093         if (!ExportedFunctions.count(SymbolName))
2094           continue;
2095 
2096         M.appendModuleInlineAsm(
2097             (llvm::Twine(".symver ") + SymbolName + ", " + Alias).str());
2098       }
2099     }
2100   }
2101 
2102   return true;
2103 }
2104 
run(Module & M,ModuleAnalysisManager & AM)2105 PreservedAnalyses LowerTypeTestsPass::run(Module &M,
2106                                           ModuleAnalysisManager &AM) {
2107   bool Changed = LowerTypeTestsModule(M, ExportSummary, ImportSummary).lower();
2108   if (!Changed)
2109     return PreservedAnalyses::all();
2110   return PreservedAnalyses::none();
2111 }
2112