1 //===--- ExpandMemCmp.cpp - Expand memcmp() to load/stores ----------------===//
2 //
3 //                     The LLVM Compiler Infrastructure
4 //
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
7 //
8 //===----------------------------------------------------------------------===//
9 //
10 // This pass tries to expand memcmp() calls into optimally-sized loads and
11 // compares for the target.
12 //
13 //===----------------------------------------------------------------------===//
14 
#include "llvm/ADT/Statistic.h"
#include "llvm/Analysis/ConstantFolding.h"
#include "llvm/Analysis/TargetLibraryInfo.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/CodeGen/TargetLowering.h"
#include "llvm/CodeGen/TargetPassConfig.h"
#include "llvm/CodeGen/TargetSubtargetInfo.h"
#include "llvm/IR/IRBuilder.h"

#include <algorithm>
24 
25 using namespace llvm;
26 
27 #define DEBUG_TYPE "expandmemcmp"
28 
// Pass statistics counters.
STATISTIC(NumMemCmpCalls, "Number of memcmp calls");
STATISTIC(NumMemCmpNotConstant, "Number of memcmp calls without constant size");
STATISTIC(NumMemCmpGreaterThanMax,
          "Number of memcmp calls with size greater than max size");
STATISTIC(NumMemCmpInlined, "Number of inlined memcmp calls");

// Command-line override for the number of loads packed into each load/compare
// block when the memcmp result is only compared against zero. When the flag is
// not given explicitly, expandMemCmp() uses
// TargetLowering::getMemcmpEqZeroLoadsPerBlock() instead.
static cl::opt<unsigned> MemCmpEqZeroNumLoadsPerBlock(
    "memcmp-num-loads-per-block", cl::Hidden, cl::init(1),
    cl::desc("The number of loads per basic block for inline expansion of "
             "memcmp that is only being compared against zero."));
39 
40 namespace {
41 
42 
// This class provides helper functions to expand a memcmp library call into an
// inline expansion.
//
// Usage (see expandMemCmp()): construct it for one memcmp call, which computes
// the load decomposition; check that getNumLoads() != 0; then call
// getMemCmpExpansion() to emit the IR and obtain the i32 replacement value.
class MemCmpExpansion {
  // The early-exit block reached when a load/compare block finds a difference.
  struct ResultBlock {
    BasicBlock *BB = nullptr;
    // PHIs collecting the values compared in each multi-byte load/compare
    // block. They are only created when the result is not just compared
    // against zero.
    PHINode *PhiSrc1 = nullptr;
    PHINode *PhiSrc2 = nullptr;

    ResultBlock() = default;
  };

  // The memcmp call being expanded.
  CallInst *const CI;
  ResultBlock ResBlock;
  // Number of bytes compared (the constant size argument of the call).
  const uint64_t Size;
  // Largest load size used by the decomposition, in bytes.
  unsigned MaxLoadSize;
  // Only used as the reserved-operand hint for the ResultBlock PHIs.
  uint64_t NumLoadsNonOneByte;
  // How many loads share one block when the result is only compared to zero.
  const uint64_t NumLoadsPerBlockForZeroCmp;
  // The load/compare blocks, in the order they are chained together.
  std::vector<BasicBlock *> LoadCmpBlocks;
  // Block split off at the call, where all paths rejoin. PhiRes is the i32 PHI
  // at its top that merges the memcmp result. Both are only set up for a
  // multi-block expansion.
  BasicBlock *EndBlock;
  PHINode *PhiRes;
  // True if the call's result is only used in a comparison against zero, in
  // which case only equality (not ordering) has to be computed.
  const bool IsUsedForZeroCmp;
  const DataLayout &DL;
  // Inserts all new instructions; initially positioned at the call.
  IRBuilder<> Builder;
  // Represents the decomposition in blocks of the expansion. Each entry is
  // {LoadSize, Offset}. For example, comparing 33 bytes on X86+sse can be done
  // with 2x16-byte loads and 1x1-byte load, which would be represented as
  // [{16, 0}, {16, 16}, {1, 32}].
  // TODO(courbet): Involve the target more in this computation. On X86, 7
  // bytes can be done more efficiently with two overlapping 4-byte loads than
  // covering the interval with [{4, 0}, {2, 4}, {1, 6}].
  struct LoadEntry {
    LoadEntry(unsigned LoadSize, uint64_t Offset)
        : LoadSize(LoadSize), Offset(Offset) {
      assert(Offset % LoadSize == 0 && "invalid load entry");
    }

    // Offset expressed in units of LoadSize (the element index for a GEP over
    // an iN* pointer with N = LoadSize * 8).
    uint64_t getGEPIndex() const { return Offset / LoadSize; }

    // The size of the load for this block, in bytes.
    const unsigned LoadSize;
    // The offset of this load WRT the base pointer, in bytes.
    const uint64_t Offset;
  };
  SmallVector<LoadEntry, 8> LoadSequence;

  void createLoadCmpBlocks();
  void createResultBlock();
  void setupResultBlockPHINodes();
  void setupEndBlockPHINodes();
  Value *getCompareLoadPairs(unsigned BlockIndex, unsigned &LoadIndex);
  void emitLoadCompareBlock(unsigned BlockIndex);
  void emitLoadCompareBlockMultipleLoads(unsigned BlockIndex,
                                         unsigned &LoadIndex);
  void emitLoadCompareByteBlock(unsigned BlockIndex, unsigned GEPIndex);
  void emitMemCmpResultBlock();
  Value *getMemCmpExpansionZeroCase();
  Value *getMemCmpEqZeroOneBlock();
  Value *getMemCmpOneBlock();

 public:
  MemCmpExpansion(CallInst *CI, uint64_t Size,
                  const TargetTransformInfo::MemCmpExpansionOptions &Options,
                  unsigned MaxNumLoads, const bool IsUsedForZeroCmp,
                  unsigned MaxLoadsPerBlockForZeroCmp, const DataLayout &TheDataLayout);

  // Number of load/compare blocks the expansion will create.
  unsigned getNumBlocks();
  // Number of loads (per source) in the decomposition; 0 means "do not expand".
  uint64_t getNumLoads() const { return LoadSequence.size(); }

  // Emits the expansion and returns the i32 value replacing the memcmp call.
  Value *getMemCmpExpansion();
};
112 
113 // Initialize the basic block structure required for expansion of memcmp call
114 // with given maximum load size and memcmp size parameter.
115 // This structure includes:
116 // 1. A list of load compare blocks - LoadCmpBlocks.
117 // 2. An EndBlock, split from original instruction point, which is the block to
118 // return from.
119 // 3. ResultBlock, block to branch to for early exit when a
120 // LoadCmpBlock finds a difference.
MemCmpExpansion(CallInst * const CI,uint64_t Size,const TargetTransformInfo::MemCmpExpansionOptions & Options,const unsigned MaxNumLoads,const bool IsUsedForZeroCmp,const unsigned MaxLoadsPerBlockForZeroCmp,const DataLayout & TheDataLayout)121 MemCmpExpansion::MemCmpExpansion(
122     CallInst *const CI, uint64_t Size,
123     const TargetTransformInfo::MemCmpExpansionOptions &Options,
124     const unsigned MaxNumLoads, const bool IsUsedForZeroCmp,
125     const unsigned MaxLoadsPerBlockForZeroCmp, const DataLayout &TheDataLayout)
126     : CI(CI),
127       Size(Size),
128       MaxLoadSize(0),
129       NumLoadsNonOneByte(0),
130       NumLoadsPerBlockForZeroCmp(MaxLoadsPerBlockForZeroCmp),
131       IsUsedForZeroCmp(IsUsedForZeroCmp),
132       DL(TheDataLayout),
133       Builder(CI) {
134   assert(Size > 0 && "zero blocks");
135   // Scale the max size down if the target can load more bytes than we need.
136   size_t LoadSizeIndex = 0;
137   while (LoadSizeIndex < Options.LoadSizes.size() &&
138          Options.LoadSizes[LoadSizeIndex] > Size) {
139     ++LoadSizeIndex;
140   }
141   this->MaxLoadSize = Options.LoadSizes[LoadSizeIndex];
142   // Compute the decomposition.
143   uint64_t CurSize = Size;
144   uint64_t Offset = 0;
145   while (CurSize && LoadSizeIndex < Options.LoadSizes.size()) {
146     const unsigned LoadSize = Options.LoadSizes[LoadSizeIndex];
147     assert(LoadSize > 0 && "zero load size");
148     const uint64_t NumLoadsForThisSize = CurSize / LoadSize;
149     if (LoadSequence.size() + NumLoadsForThisSize > MaxNumLoads) {
150       // Do not expand if the total number of loads is larger than what the
151       // target allows. Note that it's important that we exit before completing
152       // the expansion to avoid using a ton of memory to store the expansion for
153       // large sizes.
154       LoadSequence.clear();
155       return;
156     }
157     if (NumLoadsForThisSize > 0) {
158       for (uint64_t I = 0; I < NumLoadsForThisSize; ++I) {
159         LoadSequence.push_back({LoadSize, Offset});
160         Offset += LoadSize;
161       }
162       if (LoadSize > 1) {
163         ++NumLoadsNonOneByte;
164       }
165       CurSize = CurSize % LoadSize;
166     }
167     ++LoadSizeIndex;
168   }
169   assert(LoadSequence.size() <= MaxNumLoads && "broken invariant");
170 }
171 
getNumBlocks()172 unsigned MemCmpExpansion::getNumBlocks() {
173   if (IsUsedForZeroCmp)
174     return getNumLoads() / NumLoadsPerBlockForZeroCmp +
175            (getNumLoads() % NumLoadsPerBlockForZeroCmp != 0 ? 1 : 0);
176   return getNumLoads();
177 }
178 
createLoadCmpBlocks()179 void MemCmpExpansion::createLoadCmpBlocks() {
180   for (unsigned i = 0; i < getNumBlocks(); i++) {
181     BasicBlock *BB = BasicBlock::Create(CI->getContext(), "loadbb",
182                                         EndBlock->getParent(), EndBlock);
183     LoadCmpBlocks.push_back(BB);
184   }
185 }
186 
createResultBlock()187 void MemCmpExpansion::createResultBlock() {
188   ResBlock.BB = BasicBlock::Create(CI->getContext(), "res_block",
189                                    EndBlock->getParent(), EndBlock);
190 }
191 
192 // This function creates the IR instructions for loading and comparing 1 byte.
193 // It loads 1 byte from each source of the memcmp parameters with the given
194 // GEPIndex. It then subtracts the two loaded values and adds this result to the
195 // final phi node for selecting the memcmp result.
emitLoadCompareByteBlock(unsigned BlockIndex,unsigned GEPIndex)196 void MemCmpExpansion::emitLoadCompareByteBlock(unsigned BlockIndex,
197                                                unsigned GEPIndex) {
198   Value *Source1 = CI->getArgOperand(0);
199   Value *Source2 = CI->getArgOperand(1);
200 
201   Builder.SetInsertPoint(LoadCmpBlocks[BlockIndex]);
202   Type *LoadSizeType = Type::getInt8Ty(CI->getContext());
203   // Cast source to LoadSizeType*.
204   if (Source1->getType() != LoadSizeType)
205     Source1 = Builder.CreateBitCast(Source1, LoadSizeType->getPointerTo());
206   if (Source2->getType() != LoadSizeType)
207     Source2 = Builder.CreateBitCast(Source2, LoadSizeType->getPointerTo());
208 
209   // Get the base address using the GEPIndex.
210   if (GEPIndex != 0) {
211     Source1 = Builder.CreateGEP(LoadSizeType, Source1,
212                                 ConstantInt::get(LoadSizeType, GEPIndex));
213     Source2 = Builder.CreateGEP(LoadSizeType, Source2,
214                                 ConstantInt::get(LoadSizeType, GEPIndex));
215   }
216 
217   Value *LoadSrc1 = Builder.CreateLoad(LoadSizeType, Source1);
218   Value *LoadSrc2 = Builder.CreateLoad(LoadSizeType, Source2);
219 
220   LoadSrc1 = Builder.CreateZExt(LoadSrc1, Type::getInt32Ty(CI->getContext()));
221   LoadSrc2 = Builder.CreateZExt(LoadSrc2, Type::getInt32Ty(CI->getContext()));
222   Value *Diff = Builder.CreateSub(LoadSrc1, LoadSrc2);
223 
224   PhiRes->addIncoming(Diff, LoadCmpBlocks[BlockIndex]);
225 
226   if (BlockIndex < (LoadCmpBlocks.size() - 1)) {
227     // Early exit branch if difference found to EndBlock. Otherwise, continue to
228     // next LoadCmpBlock,
229     Value *Cmp = Builder.CreateICmp(ICmpInst::ICMP_NE, Diff,
230                                     ConstantInt::get(Diff->getType(), 0));
231     BranchInst *CmpBr =
232         BranchInst::Create(EndBlock, LoadCmpBlocks[BlockIndex + 1], Cmp);
233     Builder.Insert(CmpBr);
234   } else {
235     // The last block has an unconditional branch to EndBlock.
236     BranchInst *CmpBr = BranchInst::Create(EndBlock);
237     Builder.Insert(CmpBr);
238   }
239 }
240 
241 /// Generate an equality comparison for one or more pairs of loaded values.
242 /// This is used in the case where the memcmp() call is compared equal or not
243 /// equal to zero.
getCompareLoadPairs(unsigned BlockIndex,unsigned & LoadIndex)244 Value *MemCmpExpansion::getCompareLoadPairs(unsigned BlockIndex,
245                                             unsigned &LoadIndex) {
246   assert(LoadIndex < getNumLoads() &&
247          "getCompareLoadPairs() called with no remaining loads");
248   std::vector<Value *> XorList, OrList;
249   Value *Diff;
250 
251   const unsigned NumLoads =
252       std::min(getNumLoads() - LoadIndex, NumLoadsPerBlockForZeroCmp);
253 
254   // For a single-block expansion, start inserting before the memcmp call.
255   if (LoadCmpBlocks.empty())
256     Builder.SetInsertPoint(CI);
257   else
258     Builder.SetInsertPoint(LoadCmpBlocks[BlockIndex]);
259 
260   Value *Cmp = nullptr;
261   // If we have multiple loads per block, we need to generate a composite
262   // comparison using xor+or. The type for the combinations is the largest load
263   // type.
264   IntegerType *const MaxLoadType =
265       NumLoads == 1 ? nullptr
266                     : IntegerType::get(CI->getContext(), MaxLoadSize * 8);
267   for (unsigned i = 0; i < NumLoads; ++i, ++LoadIndex) {
268     const LoadEntry &CurLoadEntry = LoadSequence[LoadIndex];
269 
270     IntegerType *LoadSizeType =
271         IntegerType::get(CI->getContext(), CurLoadEntry.LoadSize * 8);
272 
273     Value *Source1 = CI->getArgOperand(0);
274     Value *Source2 = CI->getArgOperand(1);
275 
276     // Cast source to LoadSizeType*.
277     if (Source1->getType() != LoadSizeType)
278       Source1 = Builder.CreateBitCast(Source1, LoadSizeType->getPointerTo());
279     if (Source2->getType() != LoadSizeType)
280       Source2 = Builder.CreateBitCast(Source2, LoadSizeType->getPointerTo());
281 
282     // Get the base address using a GEP.
283     if (CurLoadEntry.Offset != 0) {
284       Source1 = Builder.CreateGEP(
285           LoadSizeType, Source1,
286           ConstantInt::get(LoadSizeType, CurLoadEntry.getGEPIndex()));
287       Source2 = Builder.CreateGEP(
288           LoadSizeType, Source2,
289           ConstantInt::get(LoadSizeType, CurLoadEntry.getGEPIndex()));
290     }
291 
292     // Get a constant or load a value for each source address.
293     Value *LoadSrc1 = nullptr;
294     if (auto *Source1C = dyn_cast<Constant>(Source1))
295       LoadSrc1 = ConstantFoldLoadFromConstPtr(Source1C, LoadSizeType, DL);
296     if (!LoadSrc1)
297       LoadSrc1 = Builder.CreateLoad(LoadSizeType, Source1);
298 
299     Value *LoadSrc2 = nullptr;
300     if (auto *Source2C = dyn_cast<Constant>(Source2))
301       LoadSrc2 = ConstantFoldLoadFromConstPtr(Source2C, LoadSizeType, DL);
302     if (!LoadSrc2)
303       LoadSrc2 = Builder.CreateLoad(LoadSizeType, Source2);
304 
305     if (NumLoads != 1) {
306       if (LoadSizeType != MaxLoadType) {
307         LoadSrc1 = Builder.CreateZExt(LoadSrc1, MaxLoadType);
308         LoadSrc2 = Builder.CreateZExt(LoadSrc2, MaxLoadType);
309       }
310       // If we have multiple loads per block, we need to generate a composite
311       // comparison using xor+or.
312       Diff = Builder.CreateXor(LoadSrc1, LoadSrc2);
313       Diff = Builder.CreateZExt(Diff, MaxLoadType);
314       XorList.push_back(Diff);
315     } else {
316       // If there's only one load per block, we just compare the loaded values.
317       Cmp = Builder.CreateICmpNE(LoadSrc1, LoadSrc2);
318     }
319   }
320 
321   auto pairWiseOr = [&](std::vector<Value *> &InList) -> std::vector<Value *> {
322     std::vector<Value *> OutList;
323     for (unsigned i = 0; i < InList.size() - 1; i = i + 2) {
324       Value *Or = Builder.CreateOr(InList[i], InList[i + 1]);
325       OutList.push_back(Or);
326     }
327     if (InList.size() % 2 != 0)
328       OutList.push_back(InList.back());
329     return OutList;
330   };
331 
332   if (!Cmp) {
333     // Pairwise OR the XOR results.
334     OrList = pairWiseOr(XorList);
335 
336     // Pairwise OR the OR results until one result left.
337     while (OrList.size() != 1) {
338       OrList = pairWiseOr(OrList);
339     }
340     Cmp = Builder.CreateICmpNE(OrList[0], ConstantInt::get(Diff->getType(), 0));
341   }
342 
343   return Cmp;
344 }
345 
emitLoadCompareBlockMultipleLoads(unsigned BlockIndex,unsigned & LoadIndex)346 void MemCmpExpansion::emitLoadCompareBlockMultipleLoads(unsigned BlockIndex,
347                                                         unsigned &LoadIndex) {
348   Value *Cmp = getCompareLoadPairs(BlockIndex, LoadIndex);
349 
350   BasicBlock *NextBB = (BlockIndex == (LoadCmpBlocks.size() - 1))
351                            ? EndBlock
352                            : LoadCmpBlocks[BlockIndex + 1];
353   // Early exit branch if difference found to ResultBlock. Otherwise,
354   // continue to next LoadCmpBlock or EndBlock.
355   BranchInst *CmpBr = BranchInst::Create(ResBlock.BB, NextBB, Cmp);
356   Builder.Insert(CmpBr);
357 
358   // Add a phi edge for the last LoadCmpBlock to Endblock with a value of 0
359   // since early exit to ResultBlock was not taken (no difference was found in
360   // any of the bytes).
361   if (BlockIndex == LoadCmpBlocks.size() - 1) {
362     Value *Zero = ConstantInt::get(Type::getInt32Ty(CI->getContext()), 0);
363     PhiRes->addIncoming(Zero, LoadCmpBlocks[BlockIndex]);
364   }
365 }
366 
367 // This function creates the IR intructions for loading and comparing using the
368 // given LoadSize. It loads the number of bytes specified by LoadSize from each
369 // source of the memcmp parameters. It then does a subtract to see if there was
370 // a difference in the loaded values. If a difference is found, it branches
371 // with an early exit to the ResultBlock for calculating which source was
372 // larger. Otherwise, it falls through to the either the next LoadCmpBlock or
373 // the EndBlock if this is the last LoadCmpBlock. Loading 1 byte is handled with
374 // a special case through emitLoadCompareByteBlock. The special handling can
375 // simply subtract the loaded values and add it to the result phi node.
emitLoadCompareBlock(unsigned BlockIndex)376 void MemCmpExpansion::emitLoadCompareBlock(unsigned BlockIndex) {
377   // There is one load per block in this case, BlockIndex == LoadIndex.
378   const LoadEntry &CurLoadEntry = LoadSequence[BlockIndex];
379 
380   if (CurLoadEntry.LoadSize == 1) {
381     MemCmpExpansion::emitLoadCompareByteBlock(BlockIndex,
382                                               CurLoadEntry.getGEPIndex());
383     return;
384   }
385 
386   Type *LoadSizeType =
387       IntegerType::get(CI->getContext(), CurLoadEntry.LoadSize * 8);
388   Type *MaxLoadType = IntegerType::get(CI->getContext(), MaxLoadSize * 8);
389   assert(CurLoadEntry.LoadSize <= MaxLoadSize && "Unexpected load type");
390 
391   Value *Source1 = CI->getArgOperand(0);
392   Value *Source2 = CI->getArgOperand(1);
393 
394   Builder.SetInsertPoint(LoadCmpBlocks[BlockIndex]);
395   // Cast source to LoadSizeType*.
396   if (Source1->getType() != LoadSizeType)
397     Source1 = Builder.CreateBitCast(Source1, LoadSizeType->getPointerTo());
398   if (Source2->getType() != LoadSizeType)
399     Source2 = Builder.CreateBitCast(Source2, LoadSizeType->getPointerTo());
400 
401   // Get the base address using a GEP.
402   if (CurLoadEntry.Offset != 0) {
403     Source1 = Builder.CreateGEP(
404         LoadSizeType, Source1,
405         ConstantInt::get(LoadSizeType, CurLoadEntry.getGEPIndex()));
406     Source2 = Builder.CreateGEP(
407         LoadSizeType, Source2,
408         ConstantInt::get(LoadSizeType, CurLoadEntry.getGEPIndex()));
409   }
410 
411   // Load LoadSizeType from the base address.
412   Value *LoadSrc1 = Builder.CreateLoad(LoadSizeType, Source1);
413   Value *LoadSrc2 = Builder.CreateLoad(LoadSizeType, Source2);
414 
415   if (DL.isLittleEndian()) {
416     Function *Bswap = Intrinsic::getDeclaration(CI->getModule(),
417                                                 Intrinsic::bswap, LoadSizeType);
418     LoadSrc1 = Builder.CreateCall(Bswap, LoadSrc1);
419     LoadSrc2 = Builder.CreateCall(Bswap, LoadSrc2);
420   }
421 
422   if (LoadSizeType != MaxLoadType) {
423     LoadSrc1 = Builder.CreateZExt(LoadSrc1, MaxLoadType);
424     LoadSrc2 = Builder.CreateZExt(LoadSrc2, MaxLoadType);
425   }
426 
427   // Add the loaded values to the phi nodes for calculating memcmp result only
428   // if result is not used in a zero equality.
429   if (!IsUsedForZeroCmp) {
430     ResBlock.PhiSrc1->addIncoming(LoadSrc1, LoadCmpBlocks[BlockIndex]);
431     ResBlock.PhiSrc2->addIncoming(LoadSrc2, LoadCmpBlocks[BlockIndex]);
432   }
433 
434   Value *Cmp = Builder.CreateICmp(ICmpInst::ICMP_EQ, LoadSrc1, LoadSrc2);
435   BasicBlock *NextBB = (BlockIndex == (LoadCmpBlocks.size() - 1))
436                            ? EndBlock
437                            : LoadCmpBlocks[BlockIndex + 1];
438   // Early exit branch if difference found to ResultBlock. Otherwise, continue
439   // to next LoadCmpBlock or EndBlock.
440   BranchInst *CmpBr = BranchInst::Create(NextBB, ResBlock.BB, Cmp);
441   Builder.Insert(CmpBr);
442 
443   // Add a phi edge for the last LoadCmpBlock to Endblock with a value of 0
444   // since early exit to ResultBlock was not taken (no difference was found in
445   // any of the bytes).
446   if (BlockIndex == LoadCmpBlocks.size() - 1) {
447     Value *Zero = ConstantInt::get(Type::getInt32Ty(CI->getContext()), 0);
448     PhiRes->addIncoming(Zero, LoadCmpBlocks[BlockIndex]);
449   }
450 }
451 
452 // This function populates the ResultBlock with a sequence to calculate the
453 // memcmp result. It compares the two loaded source values and returns -1 if
454 // src1 < src2 and 1 if src1 > src2.
emitMemCmpResultBlock()455 void MemCmpExpansion::emitMemCmpResultBlock() {
456   // Special case: if memcmp result is used in a zero equality, result does not
457   // need to be calculated and can simply return 1.
458   if (IsUsedForZeroCmp) {
459     BasicBlock::iterator InsertPt = ResBlock.BB->getFirstInsertionPt();
460     Builder.SetInsertPoint(ResBlock.BB, InsertPt);
461     Value *Res = ConstantInt::get(Type::getInt32Ty(CI->getContext()), 1);
462     PhiRes->addIncoming(Res, ResBlock.BB);
463     BranchInst *NewBr = BranchInst::Create(EndBlock);
464     Builder.Insert(NewBr);
465     return;
466   }
467   BasicBlock::iterator InsertPt = ResBlock.BB->getFirstInsertionPt();
468   Builder.SetInsertPoint(ResBlock.BB, InsertPt);
469 
470   Value *Cmp = Builder.CreateICmp(ICmpInst::ICMP_ULT, ResBlock.PhiSrc1,
471                                   ResBlock.PhiSrc2);
472 
473   Value *Res =
474       Builder.CreateSelect(Cmp, ConstantInt::get(Builder.getInt32Ty(), -1),
475                            ConstantInt::get(Builder.getInt32Ty(), 1));
476 
477   BranchInst *NewBr = BranchInst::Create(EndBlock);
478   Builder.Insert(NewBr);
479   PhiRes->addIncoming(Res, ResBlock.BB);
480 }
481 
setupResultBlockPHINodes()482 void MemCmpExpansion::setupResultBlockPHINodes() {
483   Type *MaxLoadType = IntegerType::get(CI->getContext(), MaxLoadSize * 8);
484   Builder.SetInsertPoint(ResBlock.BB);
485   // Note: this assumes one load per block.
486   ResBlock.PhiSrc1 =
487       Builder.CreatePHI(MaxLoadType, NumLoadsNonOneByte, "phi.src1");
488   ResBlock.PhiSrc2 =
489       Builder.CreatePHI(MaxLoadType, NumLoadsNonOneByte, "phi.src2");
490 }
491 
setupEndBlockPHINodes()492 void MemCmpExpansion::setupEndBlockPHINodes() {
493   Builder.SetInsertPoint(&EndBlock->front());
494   PhiRes = Builder.CreatePHI(Type::getInt32Ty(CI->getContext()), 2, "phi.res");
495 }
496 
getMemCmpExpansionZeroCase()497 Value *MemCmpExpansion::getMemCmpExpansionZeroCase() {
498   unsigned LoadIndex = 0;
499   // This loop populates each of the LoadCmpBlocks with the IR sequence to
500   // handle multiple loads per block.
501   for (unsigned I = 0; I < getNumBlocks(); ++I) {
502     emitLoadCompareBlockMultipleLoads(I, LoadIndex);
503   }
504 
505   emitMemCmpResultBlock();
506   return PhiRes;
507 }
508 
509 /// A memcmp expansion that compares equality with 0 and only has one block of
510 /// load and compare can bypass the compare, branch, and phi IR that is required
511 /// in the general case.
getMemCmpEqZeroOneBlock()512 Value *MemCmpExpansion::getMemCmpEqZeroOneBlock() {
513   unsigned LoadIndex = 0;
514   Value *Cmp = getCompareLoadPairs(0, LoadIndex);
515   assert(LoadIndex == getNumLoads() && "some entries were not consumed");
516   return Builder.CreateZExt(Cmp, Type::getInt32Ty(CI->getContext()));
517 }
518 
519 /// A memcmp expansion that only has one block of load and compare can bypass
520 /// the compare, branch, and phi IR that is required in the general case.
getMemCmpOneBlock()521 Value *MemCmpExpansion::getMemCmpOneBlock() {
522   Type *LoadSizeType = IntegerType::get(CI->getContext(), Size * 8);
523   Value *Source1 = CI->getArgOperand(0);
524   Value *Source2 = CI->getArgOperand(1);
525 
526   // Cast source to LoadSizeType*.
527   if (Source1->getType() != LoadSizeType)
528     Source1 = Builder.CreateBitCast(Source1, LoadSizeType->getPointerTo());
529   if (Source2->getType() != LoadSizeType)
530     Source2 = Builder.CreateBitCast(Source2, LoadSizeType->getPointerTo());
531 
532   // Load LoadSizeType from the base address.
533   Value *LoadSrc1 = Builder.CreateLoad(LoadSizeType, Source1);
534   Value *LoadSrc2 = Builder.CreateLoad(LoadSizeType, Source2);
535 
536   if (DL.isLittleEndian() && Size != 1) {
537     Function *Bswap = Intrinsic::getDeclaration(CI->getModule(),
538                                                 Intrinsic::bswap, LoadSizeType);
539     LoadSrc1 = Builder.CreateCall(Bswap, LoadSrc1);
540     LoadSrc2 = Builder.CreateCall(Bswap, LoadSrc2);
541   }
542 
543   if (Size < 4) {
544     // The i8 and i16 cases don't need compares. We zext the loaded values and
545     // subtract them to get the suitable negative, zero, or positive i32 result.
546     LoadSrc1 = Builder.CreateZExt(LoadSrc1, Builder.getInt32Ty());
547     LoadSrc2 = Builder.CreateZExt(LoadSrc2, Builder.getInt32Ty());
548     return Builder.CreateSub(LoadSrc1, LoadSrc2);
549   }
550 
551   // The result of memcmp is negative, zero, or positive, so produce that by
552   // subtracting 2 extended compare bits: sub (ugt, ult).
553   // If a target prefers to use selects to get -1/0/1, they should be able
554   // to transform this later. The inverse transform (going from selects to math)
555   // may not be possible in the DAG because the selects got converted into
556   // branches before we got there.
557   Value *CmpUGT = Builder.CreateICmpUGT(LoadSrc1, LoadSrc2);
558   Value *CmpULT = Builder.CreateICmpULT(LoadSrc1, LoadSrc2);
559   Value *ZextUGT = Builder.CreateZExt(CmpUGT, Builder.getInt32Ty());
560   Value *ZextULT = Builder.CreateZExt(CmpULT, Builder.getInt32Ty());
561   return Builder.CreateSub(ZextUGT, ZextULT);
562 }
563 
564 // This function expands the memcmp call into an inline expansion and returns
565 // the memcmp result.
getMemCmpExpansion()566 Value *MemCmpExpansion::getMemCmpExpansion() {
567   // Create the basic block framework for a multi-block expansion.
568   if (getNumBlocks() != 1) {
569     BasicBlock *StartBlock = CI->getParent();
570     EndBlock = StartBlock->splitBasicBlock(CI, "endblock");
571     setupEndBlockPHINodes();
572     createResultBlock();
573 
574     // If return value of memcmp is not used in a zero equality, we need to
575     // calculate which source was larger. The calculation requires the
576     // two loaded source values of each load compare block.
577     // These will be saved in the phi nodes created by setupResultBlockPHINodes.
578     if (!IsUsedForZeroCmp) setupResultBlockPHINodes();
579 
580     // Create the number of required load compare basic blocks.
581     createLoadCmpBlocks();
582 
583     // Update the terminator added by splitBasicBlock to branch to the first
584     // LoadCmpBlock.
585     StartBlock->getTerminator()->setSuccessor(0, LoadCmpBlocks[0]);
586   }
587 
588   Builder.SetCurrentDebugLocation(CI->getDebugLoc());
589 
590   if (IsUsedForZeroCmp)
591     return getNumBlocks() == 1 ? getMemCmpEqZeroOneBlock()
592                                : getMemCmpExpansionZeroCase();
593 
594   if (getNumBlocks() == 1)
595     return getMemCmpOneBlock();
596 
597   for (unsigned I = 0; I < getNumBlocks(); ++I) {
598     emitLoadCompareBlock(I);
599   }
600 
601   emitMemCmpResultBlock();
602   return PhiRes;
603 }
604 
605 // This function checks to see if an expansion of memcmp can be generated.
606 // It checks for constant compare size that is less than the max inline size.
607 // If an expansion cannot occur, returns false to leave as a library call.
608 // Otherwise, the library call is replaced with a new IR instruction sequence.
609 /// We want to transform:
610 /// %call = call signext i32 @memcmp(i8* %0, i8* %1, i64 15)
611 /// To:
612 /// loadbb:
613 ///  %0 = bitcast i32* %buffer2 to i8*
614 ///  %1 = bitcast i32* %buffer1 to i8*
615 ///  %2 = bitcast i8* %1 to i64*
616 ///  %3 = bitcast i8* %0 to i64*
617 ///  %4 = load i64, i64* %2
618 ///  %5 = load i64, i64* %3
619 ///  %6 = call i64 @llvm.bswap.i64(i64 %4)
620 ///  %7 = call i64 @llvm.bswap.i64(i64 %5)
621 ///  %8 = sub i64 %6, %7
622 ///  %9 = icmp ne i64 %8, 0
623 ///  br i1 %9, label %res_block, label %loadbb1
/// res_block:                        ; preds = %loadbb2, %loadbb1, %loadbb
626 ///  %phi.src1 = phi i64 [ %6, %loadbb ], [ %22, %loadbb1 ], [ %36, %loadbb2 ]
627 ///  %phi.src2 = phi i64 [ %7, %loadbb ], [ %23, %loadbb1 ], [ %37, %loadbb2 ]
628 ///  %10 = icmp ult i64 %phi.src1, %phi.src2
629 ///  %11 = select i1 %10, i32 -1, i32 1
630 ///  br label %endblock
631 /// loadbb1:                                          ; preds = %loadbb
632 ///  %12 = bitcast i32* %buffer2 to i8*
633 ///  %13 = bitcast i32* %buffer1 to i8*
634 ///  %14 = bitcast i8* %13 to i32*
635 ///  %15 = bitcast i8* %12 to i32*
636 ///  %16 = getelementptr i32, i32* %14, i32 2
637 ///  %17 = getelementptr i32, i32* %15, i32 2
638 ///  %18 = load i32, i32* %16
639 ///  %19 = load i32, i32* %17
640 ///  %20 = call i32 @llvm.bswap.i32(i32 %18)
641 ///  %21 = call i32 @llvm.bswap.i32(i32 %19)
642 ///  %22 = zext i32 %20 to i64
643 ///  %23 = zext i32 %21 to i64
644 ///  %24 = sub i64 %22, %23
645 ///  %25 = icmp ne i64 %24, 0
646 ///  br i1 %25, label %res_block, label %loadbb2
647 /// loadbb2:                                          ; preds = %loadbb1
648 ///  %26 = bitcast i32* %buffer2 to i8*
649 ///  %27 = bitcast i32* %buffer1 to i8*
650 ///  %28 = bitcast i8* %27 to i16*
651 ///  %29 = bitcast i8* %26 to i16*
652 ///  %30 = getelementptr i16, i16* %28, i16 6
653 ///  %31 = getelementptr i16, i16* %29, i16 6
654 ///  %32 = load i16, i16* %30
655 ///  %33 = load i16, i16* %31
656 ///  %34 = call i16 @llvm.bswap.i16(i16 %32)
657 ///  %35 = call i16 @llvm.bswap.i16(i16 %33)
658 ///  %36 = zext i16 %34 to i64
659 ///  %37 = zext i16 %35 to i64
660 ///  %38 = sub i64 %36, %37
661 ///  %39 = icmp ne i64 %38, 0
662 ///  br i1 %39, label %res_block, label %loadbb3
663 /// loadbb3:                                          ; preds = %loadbb2
664 ///  %40 = bitcast i32* %buffer2 to i8*
665 ///  %41 = bitcast i32* %buffer1 to i8*
666 ///  %42 = getelementptr i8, i8* %41, i8 14
667 ///  %43 = getelementptr i8, i8* %40, i8 14
668 ///  %44 = load i8, i8* %42
669 ///  %45 = load i8, i8* %43
670 ///  %46 = zext i8 %44 to i32
671 ///  %47 = zext i8 %45 to i32
672 ///  %48 = sub i32 %46, %47
673 ///  br label %endblock
/// endblock:                         ; preds = %res_block, %loadbb3
676 ///  %phi.res = phi i32 [ %48, %loadbb3 ], [ %11, %res_block ]
677 ///  ret i32 %phi.res
expandMemCmp(CallInst * CI,const TargetTransformInfo * TTI,const TargetLowering * TLI,const DataLayout * DL)678 static bool expandMemCmp(CallInst *CI, const TargetTransformInfo *TTI,
679                          const TargetLowering *TLI, const DataLayout *DL) {
680   NumMemCmpCalls++;
681 
682   // Early exit from expansion if -Oz.
683   if (CI->getFunction()->optForMinSize())
684     return false;
685 
686   // Early exit from expansion if size is not a constant.
687   ConstantInt *SizeCast = dyn_cast<ConstantInt>(CI->getArgOperand(2));
688   if (!SizeCast) {
689     NumMemCmpNotConstant++;
690     return false;
691   }
692   const uint64_t SizeVal = SizeCast->getZExtValue();
693 
694   if (SizeVal == 0) {
695     return false;
696   }
697 
698   // TTI call to check if target would like to expand memcmp. Also, get the
699   // available load sizes.
700   const bool IsUsedForZeroCmp = isOnlyUsedInZeroEqualityComparison(CI);
701   const auto *const Options = TTI->enableMemCmpExpansion(IsUsedForZeroCmp);
702   if (!Options) return false;
703 
704   const unsigned MaxNumLoads =
705       TLI->getMaxExpandSizeMemcmp(CI->getFunction()->optForSize());
706 
707   unsigned NumLoadsPerBlock = MemCmpEqZeroNumLoadsPerBlock.getNumOccurrences()
708                                   ? MemCmpEqZeroNumLoadsPerBlock
709                                   : TLI->getMemcmpEqZeroLoadsPerBlock();
710 
711   MemCmpExpansion Expansion(CI, SizeVal, *Options, MaxNumLoads,
712                             IsUsedForZeroCmp, NumLoadsPerBlock, *DL);
713 
714   // Don't expand if this will require more loads than desired by the target.
715   if (Expansion.getNumLoads() == 0) {
716     NumMemCmpGreaterThanMax++;
717     return false;
718   }
719 
720   NumMemCmpInlined++;
721 
722   Value *Res = Expansion.getMemCmpExpansion();
723 
724   // Replace call with result of expansion and erase call.
725   CI->replaceAllUsesWith(Res);
726   CI->eraseFromParent();
727 
728   return true;
729 }
730 
731 
732 
733 class ExpandMemCmpPass : public FunctionPass {
734 public:
735   static char ID;
736 
ExpandMemCmpPass()737   ExpandMemCmpPass() : FunctionPass(ID) {
738     initializeExpandMemCmpPassPass(*PassRegistry::getPassRegistry());
739   }
740 
runOnFunction(Function & F)741   bool runOnFunction(Function &F) override {
742     if (skipFunction(F)) return false;
743 
744     auto *TPC = getAnalysisIfAvailable<TargetPassConfig>();
745     if (!TPC) {
746       return false;
747     }
748     const TargetLowering* TL =
749         TPC->getTM<TargetMachine>().getSubtargetImpl(F)->getTargetLowering();
750 
751     const TargetLibraryInfo *TLI =
752         &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI();
753     const TargetTransformInfo *TTI =
754         &getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
755     auto PA = runImpl(F, TLI, TTI, TL);
756     return !PA.areAllPreserved();
757   }
758 
759 private:
getAnalysisUsage(AnalysisUsage & AU) const760   void getAnalysisUsage(AnalysisUsage &AU) const override {
761     AU.addRequired<TargetLibraryInfoWrapperPass>();
762     AU.addRequired<TargetTransformInfoWrapperPass>();
763     FunctionPass::getAnalysisUsage(AU);
764   }
765 
766   PreservedAnalyses runImpl(Function &F, const TargetLibraryInfo *TLI,
767                             const TargetTransformInfo *TTI,
768                             const TargetLowering* TL);
769   // Returns true if a change was made.
770   bool runOnBlock(BasicBlock &BB, const TargetLibraryInfo *TLI,
771                   const TargetTransformInfo *TTI, const TargetLowering* TL,
772                   const DataLayout& DL);
773 };
774 
runOnBlock(BasicBlock & BB,const TargetLibraryInfo * TLI,const TargetTransformInfo * TTI,const TargetLowering * TL,const DataLayout & DL)775 bool ExpandMemCmpPass::runOnBlock(
776     BasicBlock &BB, const TargetLibraryInfo *TLI,
777     const TargetTransformInfo *TTI, const TargetLowering* TL,
778     const DataLayout& DL) {
779   for (Instruction& I : BB) {
780     CallInst *CI = dyn_cast<CallInst>(&I);
781     if (!CI) {
782       continue;
783     }
784     LibFunc Func;
785     if (TLI->getLibFunc(ImmutableCallSite(CI), Func) &&
786         Func == LibFunc_memcmp && expandMemCmp(CI, TTI, TL, &DL)) {
787       return true;
788     }
789   }
790   return false;
791 }
792 
793 
runImpl(Function & F,const TargetLibraryInfo * TLI,const TargetTransformInfo * TTI,const TargetLowering * TL)794 PreservedAnalyses ExpandMemCmpPass::runImpl(
795     Function &F, const TargetLibraryInfo *TLI, const TargetTransformInfo *TTI,
796     const TargetLowering* TL) {
797   const DataLayout& DL = F.getParent()->getDataLayout();
798   bool MadeChanges = false;
799   for (auto BBIt = F.begin(); BBIt != F.end();) {
800     if (runOnBlock(*BBIt, TLI, TTI, TL, DL)) {
801       MadeChanges = true;
802       // If changes were made, restart the function from the beginning, since
803       // the structure of the function was changed.
804       BBIt = F.begin();
805     } else {
806       ++BBIt;
807     }
808   }
809   return MadeChanges ? PreservedAnalyses::none() : PreservedAnalyses::all();
810 }
811 
812 } // namespace
813 
// Identity of the pass for the legacy pass registry; only its address matters.
char ExpandMemCmpPass::ID = 0;
// Register the pass under the command-line name "expandmemcmp", declaring the
// two analyses that getAnalysisUsage() requires as dependencies.
INITIALIZE_PASS_BEGIN(ExpandMemCmpPass, "expandmemcmp",
                      "Expand memcmp() to load/stores", false, false)
INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass)
INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass)
INITIALIZE_PASS_END(ExpandMemCmpPass, "expandmemcmp",
                    "Expand memcmp() to load/stores", false, false)
821 
822 FunctionPass *llvm::createExpandMemCmpPass() {
823   return new ExpandMemCmpPass();
824 }
825