//===------- VectorCombine.cpp - Optimize partial vector operations -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This pass optimizes scalar/vector interactions using target cost models. The
// transforms implemented here may not fit in traditional loop-based or SLP
// vectorization passes.
//
//===----------------------------------------------------------------------===//

#include "llvm/Transforms/Vectorize/VectorCombine.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/Analysis/BasicAliasAnalysis.h"
#include "llvm/Analysis/GlobalsModRef.h"
#include "llvm/Analysis/Loads.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/Analysis/VectorUtils.h"
#include "llvm/IR/Dominators.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/InitializePasses.h"
#include "llvm/Pass.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Transforms/Utils/Local.h"
#include "llvm/Transforms/Vectorize.h"

using namespace llvm;
using namespace llvm::PatternMatch;

#define DEBUG_TYPE "vector-combine"
STATISTIC(NumVecLoad, "Number of vector loads formed");
STATISTIC(NumVecCmp, "Number of vector compares formed");
STATISTIC(NumVecBO, "Number of vector binops formed");
STATISTIC(NumVecCmpBO, "Number of vector compare + binop formed");
STATISTIC(NumShufOfBitcast, "Number of shuffles moved after bitcast");
STATISTIC(NumScalarBO, "Number of scalar binops formed");
STATISTIC(NumScalarCmp, "Number of scalar compares formed");

static cl::opt<bool> DisableVectorCombine(
    "disable-vector-combine", cl::init(false), cl::Hidden,
    cl::desc("Disable all vector combine transforms"));

static cl::opt<bool> DisableBinopExtractShuffle(
    "disable-binop-extract-shuffle", cl::init(false), cl::Hidden,
    cl::desc("Disable binop extract to shuffle transforms"));

static const unsigned InvalidIndex = std::numeric_limits<unsigned>::max();

namespace {
class VectorCombine {
public:
  VectorCombine(Function &F, const TargetTransformInfo &TTI,
                const DominatorTree &DT)
      : F(F), Builder(F.getContext()), TTI(TTI), DT(DT) {}

  bool run();

private:
  Function &F;
  IRBuilder<> Builder;
  const TargetTransformInfo &TTI;
  const DominatorTree &DT;

  bool vectorizeLoadInsert(Instruction &I);
  ExtractElementInst *getShuffleExtract(ExtractElementInst *Ext0,
                                        ExtractElementInst *Ext1,
                                        unsigned PreferredExtractIndex) const;
  bool isExtractExtractCheap(ExtractElementInst *Ext0, ExtractElementInst *Ext1,
                             unsigned Opcode,
                             ExtractElementInst *&ConvertToShuffle,
                             unsigned PreferredExtractIndex);
  void foldExtExtCmp(ExtractElementInst *Ext0, ExtractElementInst *Ext1,
                     Instruction &I);
  void foldExtExtBinop(ExtractElementInst *Ext0, ExtractElementInst *Ext1,
                       Instruction &I);
  bool foldExtractExtract(Instruction &I);
  bool foldBitcastShuf(Instruction &I);
  bool scalarizeBinopOrCmp(Instruction &I);
  bool foldExtractedCmps(Instruction &I);
};
} // namespace

static void replaceValue(Value &Old, Value &New) {
  Old.replaceAllUsesWith(&New);
  New.takeName(&Old);
}

bool VectorCombine::vectorizeLoadInsert(Instruction &I) {
  // Match insert into fixed vector of scalar value.
  auto *Ty = dyn_cast<FixedVectorType>(I.getType());
  Value *Scalar;
  if (!Ty || !match(&I, m_InsertElt(m_Undef(), m_Value(Scalar), m_ZeroInt())) ||
      !Scalar->hasOneUse())
    return false;

  // Optionally match an extract from another vector.
  Value *X;
  bool HasExtract = match(Scalar, m_ExtractElt(m_Value(X), m_ZeroInt()));
  if (!HasExtract)
    X = Scalar;

  // Match source value as load of scalar or vector.
  // Do not vectorize scalar load (widening) if atomic/volatile or under
  // asan/hwasan/memtag/tsan. The widened load may load data from dirty regions
  // or create data races non-existent in the source.
  auto *Load = dyn_cast<LoadInst>(X);
  if (!Load || !Load->isSimple() || !Load->hasOneUse() ||
      Load->getFunction()->hasFnAttribute(Attribute::SanitizeMemTag) ||
      mustSuppressSpeculation(*Load))
    return false;

  // TODO: Extend this to match GEP with constant offsets.
  Value *PtrOp = Load->getPointerOperand()->stripPointerCasts();
  assert(isa<PointerType>(PtrOp->getType()) && "Expected a pointer type");
  unsigned AS = Load->getPointerAddressSpace();

  // If original AS != Load's AS, we can't bitcast the original pointer and have
  // to use Load's operand instead. Ideally we would want to strip pointer casts
  // without changing AS, but there's no API to do that ATM.
  if (AS != PtrOp->getType()->getPointerAddressSpace())
    PtrOp = Load->getPointerOperand();

  Type *ScalarTy = Scalar->getType();
  uint64_t ScalarSize = ScalarTy->getPrimitiveSizeInBits();
  unsigned MinVectorSize = TTI.getMinVectorRegisterBitWidth();
  if (!ScalarSize || !MinVectorSize || MinVectorSize % ScalarSize != 0)
    return false;

  // Check safety of replacing the scalar load with a larger vector load.
  unsigned MinVecNumElts = MinVectorSize / ScalarSize;
  auto *MinVecTy = VectorType::get(ScalarTy, MinVecNumElts, false);
  Align Alignment = Load->getAlign();
  const DataLayout &DL = I.getModule()->getDataLayout();
  if (!isSafeToLoadUnconditionally(PtrOp, MinVecTy, Alignment, DL, Load, &DT))
    return false;

  // Original pattern: insertelt undef, load [free casts of] PtrOp, 0
  Type *LoadTy = Load->getType();
  int OldCost = TTI.getMemoryOpCost(Instruction::Load, LoadTy, Alignment, AS);
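  // Only element 0 is demanded: the original pattern inserts a single scalar
  // (and optionally extracts one), so the scalarization overhead below models
  // a single-lane insert plus the optional extract.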
  APInt DemandedElts = APInt::getOneBitSet(MinVecNumElts, 0);
  OldCost += TTI.getScalarizationOverhead(MinVecTy, DemandedElts,
                                          /* Insert */ true, HasExtract);

  // New pattern: load VecPtr
  int NewCost = TTI.getMemoryOpCost(Instruction::Load, MinVecTy, Alignment, AS);

  // We can aggressively convert to the vector form because the backend can
  // invert this transform if it does not result in a performance win.
  if (OldCost < NewCost)
    return false;

  // It is safe and potentially profitable to load a vector directly:
  // inselt undef, load Scalar, 0 --> load VecPtr
  IRBuilder<> Builder(Load);
  Value *CastedPtr = Builder.CreateBitCast(PtrOp, MinVecTy->getPointerTo(AS));
  Value *VecLd = Builder.CreateAlignedLoad(MinVecTy, CastedPtr, Alignment);

  // If the insert type does not match the target's minimum vector type,
  // use an identity shuffle to shrink/grow the vector.
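  // For example, if the insert type is <2 x float> and the minimum vector type
  // is <4 x float>, the mask below is {0, 1} (shrink); if the insert type is
  // <8 x float>, the mask is {0, 1, 2, 3, undef, undef, undef, undef} (grow).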
  if (Ty != MinVecTy) {
    unsigned OutputNumElts = Ty->getNumElements();
    SmallVector<int, 16> Mask(OutputNumElts, UndefMaskElem);
    for (unsigned i = 0; i < OutputNumElts && i < MinVecNumElts; ++i)
      Mask[i] = i;
    VecLd = Builder.CreateShuffleVector(VecLd, Mask);
  }
  replaceValue(I, *VecLd);
  ++NumVecLoad;
  return true;
}

/// Determine which, if any, of the inputs should be replaced by a shuffle
/// followed by extract from a different index.
ExtractElementInst *VectorCombine::getShuffleExtract(
    ExtractElementInst *Ext0, ExtractElementInst *Ext1,
    unsigned PreferredExtractIndex = InvalidIndex) const {
  assert(isa<ConstantInt>(Ext0->getIndexOperand()) &&
         isa<ConstantInt>(Ext1->getIndexOperand()) &&
         "Expected constant extract indexes");

  unsigned Index0 = cast<ConstantInt>(Ext0->getIndexOperand())->getZExtValue();
  unsigned Index1 = cast<ConstantInt>(Ext1->getIndexOperand())->getZExtValue();

  // If the extract indexes are identical, no shuffle is needed.
  if (Index0 == Index1)
    return nullptr;

  Type *VecTy = Ext0->getVectorOperand()->getType();
  assert(VecTy == Ext1->getVectorOperand()->getType() && "Need matching types");
  int Cost0 = TTI.getVectorInstrCost(Ext0->getOpcode(), VecTy, Index0);
  int Cost1 = TTI.getVectorInstrCost(Ext1->getOpcode(), VecTy, Index1);

  // We are extracting from 2 different indexes, so one operand must be shuffled
  // before performing a vector operation and/or extract. The more expensive
  // extract will be replaced by a shuffle.
  if (Cost0 > Cost1)
    return Ext0;
  if (Cost1 > Cost0)
    return Ext1;

  // If the costs are equal and there is a preferred extract index, shuffle the
  // opposite operand.
  if (PreferredExtractIndex == Index0)
    return Ext1;
  if (PreferredExtractIndex == Index1)
    return Ext0;

  // Otherwise, replace the extract with the higher index.
  return Index0 > Index1 ? Ext0 : Ext1;
}

/// Compare the relative costs of 2 extracts followed by scalar operation vs.
/// vector operation(s) followed by extract. Return true if the existing
/// instructions are cheaper than a vector alternative. Otherwise, return false
/// and if one of the extracts should be transformed to a shufflevector, set
/// \p ConvertToShuffle to that extract instruction.
bool VectorCombine::isExtractExtractCheap(ExtractElementInst *Ext0,
                                          ExtractElementInst *Ext1,
                                          unsigned Opcode,
                                          ExtractElementInst *&ConvertToShuffle,
                                          unsigned PreferredExtractIndex) {
  assert(isa<ConstantInt>(Ext0->getOperand(1)) &&
         isa<ConstantInt>(Ext1->getOperand(1)) &&
         "Expected constant extract indexes");
  Type *ScalarTy = Ext0->getType();
  auto *VecTy = cast<VectorType>(Ext0->getOperand(0)->getType());
  int ScalarOpCost, VectorOpCost;

  // Get cost estimates for scalar and vector versions of the operation.
  bool IsBinOp = Instruction::isBinaryOp(Opcode);
  if (IsBinOp) {
    ScalarOpCost = TTI.getArithmeticInstrCost(Opcode, ScalarTy);
    VectorOpCost = TTI.getArithmeticInstrCost(Opcode, VecTy);
  } else {
    assert((Opcode == Instruction::ICmp || Opcode == Instruction::FCmp) &&
           "Expected a compare");
    ScalarOpCost = TTI.getCmpSelInstrCost(Opcode, ScalarTy,
                                          CmpInst::makeCmpResultType(ScalarTy));
    VectorOpCost = TTI.getCmpSelInstrCost(Opcode, VecTy,
                                          CmpInst::makeCmpResultType(VecTy));
  }

  // Get cost estimates for the extract elements. These costs will factor into
  // both sequences.
  unsigned Ext0Index = cast<ConstantInt>(Ext0->getOperand(1))->getZExtValue();
  unsigned Ext1Index = cast<ConstantInt>(Ext1->getOperand(1))->getZExtValue();

  int Extract0Cost =
      TTI.getVectorInstrCost(Instruction::ExtractElement, VecTy, Ext0Index);
  int Extract1Cost =
      TTI.getVectorInstrCost(Instruction::ExtractElement, VecTy, Ext1Index);

  // A more expensive extract will always be replaced by a splat shuffle.
  // For example, if Ext0 is more expensive:
  // opcode (extelt V0, Ext0), (ext V1, Ext1) -->
  // extelt (opcode (splat V0, Ext0), V1), Ext1
  // TODO: Evaluate whether that always results in lowest cost. Alternatively,
  //       check the cost of creating a broadcast shuffle and shuffling both
  //       operands to element 0.
  int CheapExtractCost = std::min(Extract0Cost, Extract1Cost);

  // Extra uses of the extracts mean that we include those costs in the
  // vector total because those instructions will not be eliminated.
  int OldCost, NewCost;
  if (Ext0->getOperand(0) == Ext1->getOperand(0) && Ext0Index == Ext1Index) {
    // Handle a special case. If the 2 extracts are identical, adjust the
    // formulas to account for that. The extra use charge allows for either the
    // CSE'd pattern or an unoptimized form with identical values:
    // opcode (extelt V, C), (extelt V, C) --> extelt (opcode V, V), C
    bool HasUseTax = Ext0 == Ext1 ? !Ext0->hasNUses(2)
                                  : !Ext0->hasOneUse() || !Ext1->hasOneUse();
    OldCost = CheapExtractCost + ScalarOpCost;
    NewCost = VectorOpCost + CheapExtractCost + HasUseTax * CheapExtractCost;
  } else {
    // Handle the general case. Each extract is actually a different value:
    // opcode (extelt V0, C0), (extelt V1, C1) --> extelt (opcode V0, V1), C
    OldCost = Extract0Cost + Extract1Cost + ScalarOpCost;
    NewCost = VectorOpCost + CheapExtractCost +
              !Ext0->hasOneUse() * Extract0Cost +
              !Ext1->hasOneUse() * Extract1Cost;
  }

  ConvertToShuffle = getShuffleExtract(Ext0, Ext1, PreferredExtractIndex);
  if (ConvertToShuffle) {
    if (IsBinOp && DisableBinopExtractShuffle)
      return true;

    // If we are extracting from 2 different indexes, then one operand must be
    // shuffled before performing the vector operation. The shuffle mask is
    // undefined except for 1 lane that is being translated to the remaining
    // extraction lane. Therefore, it is a splat shuffle. Ex:
    // ShufMask = { undef, undef, 0, undef }
    // TODO: The cost model has an option for a "broadcast" shuffle
    //       (splat-from-element-0), but no option for a more general splat.
    NewCost +=
        TTI.getShuffleCost(TargetTransformInfo::SK_PermuteSingleSrc, VecTy);
  }

  // Aggressively form a vector op if the cost is equal because the transform
  // may enable further optimization.
  // Codegen can reverse this transform (scalarize) if it was not profitable.
  return OldCost < NewCost;
}

/// Create a shuffle that translates (shifts) 1 element from the input vector
/// to a new element location.
static Value *createShiftShuffle(Value *Vec, unsigned OldIndex,
                                 unsigned NewIndex, IRBuilder<> &Builder) {
  // The shuffle mask is undefined except for 1 lane that is being translated
  // to the new element index. Example for OldIndex == 2 and NewIndex == 0:
  // ShufMask = { 2, undef, undef, undef }
  auto *VecTy = cast<FixedVectorType>(Vec->getType());
  SmallVector<int, 32> ShufMask(VecTy->getNumElements(), UndefMaskElem);
  ShufMask[NewIndex] = OldIndex;
  return Builder.CreateShuffleVector(Vec, ShufMask, "shift");
}

/// Given an extract element instruction with constant index operand, shuffle
/// the source vector (shift the scalar element) to a NewIndex for extraction.
/// Return null if the input can be constant folded, so that we are not creating
/// unnecessary instructions.
static ExtractElementInst *translateExtract(ExtractElementInst *ExtElt,
                                            unsigned NewIndex,
                                            IRBuilder<> &Builder) {
  // If the extract can be constant-folded, this code is unsimplified. Defer
  // to other passes to handle that.
  Value *X = ExtElt->getVectorOperand();
  Value *C = ExtElt->getIndexOperand();
  assert(isa<ConstantInt>(C) && "Expected a constant index operand");
  if (isa<Constant>(X))
    return nullptr;

  Value *Shuf = createShiftShuffle(X, cast<ConstantInt>(C)->getZExtValue(),
                                   NewIndex, Builder);
  return cast<ExtractElementInst>(Builder.CreateExtractElement(Shuf, NewIndex));
}

/// Try to reduce extract element costs by converting scalar compares to vector
/// compares followed by extract.
/// cmp (ext0 V0, C), (ext1 V1, C)
void VectorCombine::foldExtExtCmp(ExtractElementInst *Ext0,
                                  ExtractElementInst *Ext1, Instruction &I) {
  assert(isa<CmpInst>(&I) && "Expected a compare");
  assert(cast<ConstantInt>(Ext0->getIndexOperand())->getZExtValue() ==
             cast<ConstantInt>(Ext1->getIndexOperand())->getZExtValue() &&
         "Expected matching constant extract indexes");

  // cmp Pred (extelt V0, C), (extelt V1, C) --> extelt (cmp Pred V0, V1), C
  ++NumVecCmp;
  CmpInst::Predicate Pred = cast<CmpInst>(&I)->getPredicate();
  Value *V0 = Ext0->getVectorOperand(), *V1 = Ext1->getVectorOperand();
  Value *VecCmp = Builder.CreateCmp(Pred, V0, V1);
  Value *NewExt = Builder.CreateExtractElement(VecCmp, Ext0->getIndexOperand());
  replaceValue(I, *NewExt);
}

/// Try to reduce extract element costs by converting scalar binops to vector
/// binops followed by extract.
/// bo (ext0 V0, C), (ext1 V1, C)
void VectorCombine::foldExtExtBinop(ExtractElementInst *Ext0,
                                    ExtractElementInst *Ext1, Instruction &I) {
  assert(isa<BinaryOperator>(&I) && "Expected a binary operator");
  assert(cast<ConstantInt>(Ext0->getIndexOperand())->getZExtValue() ==
             cast<ConstantInt>(Ext1->getIndexOperand())->getZExtValue() &&
         "Expected matching constant extract indexes");

  // bo (extelt V0, C), (extelt V1, C) --> extelt (bo V0, V1), C
  ++NumVecBO;
  Value *V0 = Ext0->getVectorOperand(), *V1 = Ext1->getVectorOperand();
  Value *VecBO =
      Builder.CreateBinOp(cast<BinaryOperator>(&I)->getOpcode(), V0, V1);

  // All IR flags are safe to back-propagate because any potential poison
  // created in unused vector elements is discarded by the extract.
  if (auto *VecBOInst = dyn_cast<Instruction>(VecBO))
    VecBOInst->copyIRFlags(&I);

  Value *NewExt = Builder.CreateExtractElement(VecBO, Ext0->getIndexOperand());
  replaceValue(I, *NewExt);
}

/// Match an instruction with extracted vector operands.
bool VectorCombine::foldExtractExtract(Instruction &I) {
  // It is not safe to transform things like div, urem, etc. because we may
  // create undefined behavior when executing those on unknown vector elements.
  if (!isSafeToSpeculativelyExecute(&I))
    return false;

  Instruction *I0, *I1;
  CmpInst::Predicate Pred = CmpInst::BAD_ICMP_PREDICATE;
  if (!match(&I, m_Cmp(Pred, m_Instruction(I0), m_Instruction(I1))) &&
      !match(&I, m_BinOp(m_Instruction(I0), m_Instruction(I1))))
    return false;

  Value *V0, *V1;
  uint64_t C0, C1;
  if (!match(I0, m_ExtractElt(m_Value(V0), m_ConstantInt(C0))) ||
      !match(I1, m_ExtractElt(m_Value(V1), m_ConstantInt(C1))) ||
      V0->getType() != V1->getType())
    return false;

  // If the scalar value 'I' is going to be re-inserted into a vector, then try
  // to create an extract to that same element. The extract/insert can be
  // reduced to a "select shuffle".
  // TODO: If we add a larger pattern match that starts from an insert, this
  //       probably becomes unnecessary.
  auto *Ext0 = cast<ExtractElementInst>(I0);
  auto *Ext1 = cast<ExtractElementInst>(I1);
  uint64_t InsertIndex = InvalidIndex;
  if (I.hasOneUse())
    match(I.user_back(),
          m_InsertElt(m_Value(), m_Value(), m_ConstantInt(InsertIndex)));

  ExtractElementInst *ExtractToChange;
  if (isExtractExtractCheap(Ext0, Ext1, I.getOpcode(), ExtractToChange,
                            InsertIndex))
    return false;

  if (ExtractToChange) {
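    // Shuffle the source of the more expensive extract so that its element
    // lands in the lane of the cheaper extract; both extracts then use the
    // same (cheaper) index.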
    unsigned CheapExtractIdx = ExtractToChange == Ext0 ? C1 : C0;
    ExtractElementInst *NewExtract =
        translateExtract(ExtractToChange, CheapExtractIdx, Builder);
    if (!NewExtract)
      return false;
    if (ExtractToChange == Ext0)
      Ext0 = NewExtract;
    else
      Ext1 = NewExtract;
  }

  if (Pred != CmpInst::BAD_ICMP_PREDICATE)
    foldExtExtCmp(Ext0, Ext1, I);
  else
    foldExtExtBinop(Ext0, Ext1, I);

  return true;
}

/// If this is a bitcast of a shuffle, try to bitcast the source vector to the
/// destination type followed by shuffle. This can enable further transforms by
/// moving bitcasts or shuffles together.
bool VectorCombine::foldBitcastShuf(Instruction &I) {
  Value *V;
  ArrayRef<int> Mask;
  if (!match(&I, m_BitCast(
                     m_OneUse(m_Shuffle(m_Value(V), m_Undef(), m_Mask(Mask))))))
    return false;

  // 1) Do not fold bitcast shuffle for scalable type. First, shuffle cost for
  // scalable type is unknown; Second, we cannot reason if the narrowed shuffle
  // mask for scalable type is a splat or not.
  // 2) Disallow non-vector casts and length-changing shuffles.
  // TODO: We could allow any shuffle.
  auto *DestTy = dyn_cast<FixedVectorType>(I.getType());
  auto *SrcTy = dyn_cast<FixedVectorType>(V->getType());
  if (!SrcTy || !DestTy || I.getOperand(0)->getType() != SrcTy)
    return false;

  // The new shuffle must not cost more than the old shuffle. The bitcast is
  // moved ahead of the shuffle, so assume that it has the same cost as before.
  if (TTI.getShuffleCost(TargetTransformInfo::SK_PermuteSingleSrc, DestTy) >
      TTI.getShuffleCost(TargetTransformInfo::SK_PermuteSingleSrc, SrcTy))
    return false;

  unsigned DestNumElts = DestTy->getNumElements();
  unsigned SrcNumElts = SrcTy->getNumElements();
  SmallVector<int, 16> NewMask;
  if (SrcNumElts <= DestNumElts) {
    // The bitcast is from wide to narrow/equal elements. The shuffle mask can
    // always be expanded to the equivalent form choosing narrower elements.
    assert(DestNumElts % SrcNumElts == 0 && "Unexpected shuffle mask");
    unsigned ScaleFactor = DestNumElts / SrcNumElts;
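    // For example, bitcasting a shuffled <2 x i64> to <4 x i32> with mask
    // {1, 0} scales the mask by 2 to give {2, 3, 0, 1} over i32 elements.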
    narrowShuffleMaskElts(ScaleFactor, Mask, NewMask);
  } else {
    // The bitcast is from narrow elements to wide elements. The shuffle mask
    // must choose consecutive elements to allow casting first.
    assert(SrcNumElts % DestNumElts == 0 && "Unexpected shuffle mask");
    unsigned ScaleFactor = SrcNumElts / DestNumElts;
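    // For example, bitcasting a shuffled <4 x i32> to <2 x i64> with mask
    // {2, 3, 0, 1} widens to {1, 0}; a mask such as {1, 2, 3, 0} does not
    // select aligned pairs, so widening fails and we bail out.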
    if (!widenShuffleMaskElts(ScaleFactor, Mask, NewMask))
      return false;
  }
  // bitcast (shuf V, MaskC) --> shuf (bitcast V), MaskC'
  ++NumShufOfBitcast;
  Value *CastV = Builder.CreateBitCast(V, DestTy);
  Value *Shuf = Builder.CreateShuffleVector(CastV, NewMask);
  replaceValue(I, *Shuf);
  return true;
}

/// Match a vector binop or compare instruction with at least one inserted
/// scalar operand and convert to scalar binop/cmp followed by insertelement.
bool VectorCombine::scalarizeBinopOrCmp(Instruction &I) {
  CmpInst::Predicate Pred = CmpInst::BAD_ICMP_PREDICATE;
  Value *Ins0, *Ins1;
  if (!match(&I, m_BinOp(m_Value(Ins0), m_Value(Ins1))) &&
      !match(&I, m_Cmp(Pred, m_Value(Ins0), m_Value(Ins1))))
    return false;

  // Do not convert the vector condition of a vector select into a scalar
  // condition. That may cause problems for codegen because of differences in
  // boolean formats and register-file transfers.
  // TODO: Can we account for that in the cost model?
  bool IsCmp = Pred != CmpInst::Predicate::BAD_ICMP_PREDICATE;
  if (IsCmp)
    for (User *U : I.users())
      if (match(U, m_Select(m_Specific(&I), m_Value(), m_Value())))
        return false;

  // Match against one or both scalar values being inserted into constant
  // vectors:
  // vec_op VecC0, (inselt VecC1, V1, Index)
  // vec_op (inselt VecC0, V0, Index), VecC1
  // vec_op (inselt VecC0, V0, Index), (inselt VecC1, V1, Index)
  // TODO: Deal with mismatched index constants and variable indexes?
  Constant *VecC0 = nullptr, *VecC1 = nullptr;
  Value *V0 = nullptr, *V1 = nullptr;
  uint64_t Index0 = 0, Index1 = 0;
  if (!match(Ins0, m_InsertElt(m_Constant(VecC0), m_Value(V0),
                               m_ConstantInt(Index0))) &&
      !match(Ins0, m_Constant(VecC0)))
    return false;
  if (!match(Ins1, m_InsertElt(m_Constant(VecC1), m_Value(V1),
                               m_ConstantInt(Index1))) &&
      !match(Ins1, m_Constant(VecC1)))
    return false;

  bool IsConst0 = !V0;
  bool IsConst1 = !V1;
  if (IsConst0 && IsConst1)
    return false;
  if (!IsConst0 && !IsConst1 && Index0 != Index1)
    return false;

  // Bail for single insertion if it is a load.
  // TODO: Handle this once getVectorInstrCost can cost for load/stores.
  auto *I0 = dyn_cast_or_null<Instruction>(V0);
  auto *I1 = dyn_cast_or_null<Instruction>(V1);
  if ((IsConst0 && I1 && I1->mayReadFromMemory()) ||
      (IsConst1 && I0 && I0->mayReadFromMemory()))
    return false;

  uint64_t Index = IsConst0 ? Index1 : Index0;
  Type *ScalarTy = IsConst0 ? V1->getType() : V0->getType();
  Type *VecTy = I.getType();
  assert(VecTy->isVectorTy() &&
         (IsConst0 || IsConst1 || V0->getType() == V1->getType()) &&
         (ScalarTy->isIntegerTy() || ScalarTy->isFloatingPointTy() ||
          ScalarTy->isPointerTy()) &&
         "Unexpected types for insert element into binop or cmp");

  unsigned Opcode = I.getOpcode();
  int ScalarOpCost, VectorOpCost;
  if (IsCmp) {
    ScalarOpCost = TTI.getCmpSelInstrCost(Opcode, ScalarTy);
    VectorOpCost = TTI.getCmpSelInstrCost(Opcode, VecTy);
  } else {
    ScalarOpCost = TTI.getArithmeticInstrCost(Opcode, ScalarTy);
    VectorOpCost = TTI.getArithmeticInstrCost(Opcode, VecTy);
  }

  // Get cost estimate for the insert element. This cost will factor into
  // both sequences.
  int InsertCost =
      TTI.getVectorInstrCost(Instruction::InsertElement, VecTy, Index);
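  // Old cost: the vector op plus each insert that feeds it (constant operands
  // need no insert). New cost: the scalar op plus one insert of its result,
  // plus any original insert that must stay because it has other uses.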
  int OldCost = (IsConst0 ? 0 : InsertCost) + (IsConst1 ? 0 : InsertCost) +
                VectorOpCost;
  int NewCost = ScalarOpCost + InsertCost +
                (IsConst0 ? 0 : !Ins0->hasOneUse() * InsertCost) +
                (IsConst1 ? 0 : !Ins1->hasOneUse() * InsertCost);

  // We want to scalarize unless the vector variant actually has lower cost.
  if (OldCost < NewCost)
    return false;

  // vec_op (inselt VecC0, V0, Index), (inselt VecC1, V1, Index) -->
  // inselt NewVecC, (scalar_op V0, V1), Index
  if (IsCmp)
    ++NumScalarCmp;
  else
    ++NumScalarBO;

  // For constant cases, extract the scalar element, this should constant fold.
  if (IsConst0)
    V0 = ConstantExpr::getExtractElement(VecC0, Builder.getInt64(Index));
  if (IsConst1)
    V1 = ConstantExpr::getExtractElement(VecC1, Builder.getInt64(Index));

  Value *Scalar =
      IsCmp ? Builder.CreateCmp(Pred, V0, V1)
            : Builder.CreateBinOp((Instruction::BinaryOps)Opcode, V0, V1);

  Scalar->setName(I.getName() + ".scalar");

  // All IR flags are safe to back-propagate. There is no potential for extra
  // poison to be created by the scalar instruction.
  if (auto *ScalarInst = dyn_cast<Instruction>(Scalar))
    ScalarInst->copyIRFlags(&I);

  // Fold the vector constants in the original vectors into a new base vector.
  Constant *NewVecC = IsCmp ? ConstantExpr::getCompare(Pred, VecC0, VecC1)
                            : ConstantExpr::get(Opcode, VecC0, VecC1);
  Value *Insert = Builder.CreateInsertElement(NewVecC, Scalar, Index);
  replaceValue(I, *Insert);
  return true;
}

/// Try to combine a scalar binop + 2 scalar compares of extracted elements of
/// a vector into vector operations followed by extract. Note: The SLP pass
/// may miss this pattern because of implementation problems.
bool VectorCombine::foldExtractedCmps(Instruction &I) {
  // We are looking for a scalar binop of booleans.
  // binop i1 (cmp Pred I0, C0), (cmp Pred I1, C1)
  if (!I.isBinaryOp() || !I.getType()->isIntegerTy(1))
    return false;

  // The compare predicates should match, and each compare should have a
  // constant operand.
  // TODO: Relax the one-use constraints.
  Value *B0 = I.getOperand(0), *B1 = I.getOperand(1);
  Instruction *I0, *I1;
  Constant *C0, *C1;
  CmpInst::Predicate P0, P1;
  if (!match(B0, m_OneUse(m_Cmp(P0, m_Instruction(I0), m_Constant(C0)))) ||
      !match(B1, m_OneUse(m_Cmp(P1, m_Instruction(I1), m_Constant(C1)))) ||
      P0 != P1)
    return false;

  // The compare operands must be extracts of the same vector with constant
  // extract indexes.
  // TODO: Relax the one-use constraints.
  Value *X;
  uint64_t Index0, Index1;
  if (!match(I0, m_OneUse(m_ExtractElt(m_Value(X), m_ConstantInt(Index0)))) ||
      !match(I1, m_OneUse(m_ExtractElt(m_Specific(X), m_ConstantInt(Index1)))))
    return false;

  auto *Ext0 = cast<ExtractElementInst>(I0);
  auto *Ext1 = cast<ExtractElementInst>(I1);
  ExtractElementInst *ConvertToShuf = getShuffleExtract(Ext0, Ext1);
  if (!ConvertToShuf)
    return false;

  // The original scalar pattern is:
  // binop i1 (cmp Pred (ext X, Index0), C0), (cmp Pred (ext X, Index1), C1)
  CmpInst::Predicate Pred = P0;
  unsigned CmpOpcode = CmpInst::isFPPredicate(Pred) ? Instruction::FCmp
                                                    : Instruction::ICmp;
  auto *VecTy = dyn_cast<FixedVectorType>(X->getType());
  if (!VecTy)
    return false;

  int OldCost = TTI.getVectorInstrCost(Ext0->getOpcode(), VecTy, Index0);
  OldCost += TTI.getVectorInstrCost(Ext1->getOpcode(), VecTy, Index1);
  OldCost += TTI.getCmpSelInstrCost(CmpOpcode, I0->getType()) * 2;
  OldCost += TTI.getArithmeticInstrCost(I.getOpcode(), I.getType());

  // The proposed vector pattern is:
  // vcmp = cmp Pred X, VecC
  // ext (binop vNi1 vcmp, (shuffle vcmp, Index1)), Index0
  int CheapIndex = ConvertToShuf == Ext0 ? Index1 : Index0;
  int ExpensiveIndex = ConvertToShuf == Ext0 ? Index0 : Index1;
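  // The compare result from the expensive lane is shifted into the cheap lane
  // by a shuffle, so the final extract only needs the cheaper index.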
  auto *CmpTy = cast<FixedVectorType>(CmpInst::makeCmpResultType(X->getType()));
  int NewCost = TTI.getCmpSelInstrCost(CmpOpcode, X->getType());
  NewCost +=
      TTI.getShuffleCost(TargetTransformInfo::SK_PermuteSingleSrc, CmpTy);
  NewCost += TTI.getArithmeticInstrCost(I.getOpcode(), CmpTy);
  NewCost += TTI.getVectorInstrCost(Ext0->getOpcode(), CmpTy, CheapIndex);

  // Aggressively form vector ops if the cost is equal because the transform
  // may enable further optimization.
  // Codegen can reverse this transform (scalarize) if it was not profitable.
  if (OldCost < NewCost)
    return false;

  // Create a vector constant from the 2 scalar constants.
  SmallVector<Constant *, 32> CmpC(VecTy->getNumElements(),
                                   UndefValue::get(VecTy->getElementType()));
  CmpC[Index0] = C0;
  CmpC[Index1] = C1;
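  // All other lanes of the vector compare are undef; their results are never
  // extracted, so they do not affect the final value.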
  Value *VCmp = Builder.CreateCmp(Pred, X, ConstantVector::get(CmpC));

  Value *Shuf = createShiftShuffle(VCmp, ExpensiveIndex, CheapIndex, Builder);
  Value *VecLogic = Builder.CreateBinOp(cast<BinaryOperator>(I).getOpcode(),
                                        VCmp, Shuf);
  Value *NewExt = Builder.CreateExtractElement(VecLogic, CheapIndex);
  replaceValue(I, *NewExt);
  ++NumVecCmpBO;
  return true;
}

/// This is the entry point for all transforms. Pass manager differences are
/// handled in the callers of this function.
bool VectorCombine::run() {
  if (DisableVectorCombine)
    return false;

  // Don't attempt vectorization if the target does not support vectors.
  if (!TTI.getNumberOfRegisters(TTI.getRegisterClassForType(/*Vector*/ true)))
    return false;

  bool MadeChange = false;
  for (BasicBlock &BB : F) {
    // Ignore unreachable basic blocks.
    if (!DT.isReachableFromEntry(&BB))
      continue;
    // Do not delete instructions under here and invalidate the iterator.
    // Walk the block forwards to enable simple iterative chains of transforms.
    // TODO: It could be more efficient to remove dead instructions
    //       iteratively in this loop rather than waiting until the end.
    for (Instruction &I : BB) {
      if (isa<DbgInfoIntrinsic>(I))
        continue;
      Builder.SetInsertPoint(&I);
      MadeChange |= vectorizeLoadInsert(I);
      MadeChange |= foldExtractExtract(I);
      MadeChange |= foldBitcastShuf(I);
      MadeChange |= scalarizeBinopOrCmp(I);
      MadeChange |= foldExtractedCmps(I);
    }
  }

  // We're done with transforms, so remove dead instructions.
  if (MadeChange)
    for (BasicBlock &BB : F)
      SimplifyInstructionsInBlock(&BB);

  return MadeChange;
}

// Pass manager boilerplate below here.

namespace {
class VectorCombineLegacyPass : public FunctionPass {
public:
  static char ID;
  VectorCombineLegacyPass() : FunctionPass(ID) {
    initializeVectorCombineLegacyPassPass(*PassRegistry::getPassRegistry());
  }

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.addRequired<DominatorTreeWrapperPass>();
    AU.addRequired<TargetTransformInfoWrapperPass>();
    AU.setPreservesCFG();
    AU.addPreserved<DominatorTreeWrapperPass>();
    AU.addPreserved<GlobalsAAWrapperPass>();
    AU.addPreserved<AAResultsWrapperPass>();
    AU.addPreserved<BasicAAWrapperPass>();
    FunctionPass::getAnalysisUsage(AU);
  }

  bool runOnFunction(Function &F) override {
    if (skipFunction(F))
      return false;
    auto &TTI = getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
    auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree();
    VectorCombine Combiner(F, TTI, DT);
    return Combiner.run();
  }
};
} // namespace

char VectorCombineLegacyPass::ID = 0;
INITIALIZE_PASS_BEGIN(VectorCombineLegacyPass, "vector-combine",
                      "Optimize scalar/vector ops", false, false)
INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
INITIALIZE_PASS_END(VectorCombineLegacyPass, "vector-combine",
                    "Optimize scalar/vector ops", false, false)
Pass *llvm::createVectorCombinePass() {
  return new VectorCombineLegacyPass();
}

PreservedAnalyses VectorCombinePass::run(Function &F,
                                         FunctionAnalysisManager &FAM) {
  TargetTransformInfo &TTI = FAM.getResult<TargetIRAnalysis>(F);
  DominatorTree &DT = FAM.getResult<DominatorTreeAnalysis>(F);
  VectorCombine Combiner(F, TTI, DT);
  if (!Combiner.run())
    return PreservedAnalyses::all();
  PreservedAnalyses PA;
  PA.preserveSet<CFGAnalyses>();
  PA.preserve<GlobalsAA>();
  PA.preserve<AAManager>();
  PA.preserve<BasicAA>();
  return PA;
}