//===------- VectorCombine.cpp - Optimize partial vector operations -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This pass optimizes scalar/vector interactions using target cost models. The
// transforms implemented here may not fit in traditional loop-based or SLP
// vectorization passes.
//
//===----------------------------------------------------------------------===//

#include "llvm/Transforms/Vectorize/VectorCombine.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/Analysis/BasicAliasAnalysis.h"
#include "llvm/Analysis/GlobalsModRef.h"
#include "llvm/Analysis/Loads.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/Analysis/VectorUtils.h"
#include "llvm/IR/Dominators.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/InitializePasses.h"
#include "llvm/Pass.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Transforms/Utils/Local.h"
#include "llvm/Transforms/Vectorize.h"

using namespace llvm;
using namespace llvm::PatternMatch;

#define DEBUG_TYPE "vector-combine"
STATISTIC(NumVecLoad, "Number of vector loads formed");
STATISTIC(NumVecCmp, "Number of vector compares formed");
STATISTIC(NumVecBO, "Number of vector binops formed");
STATISTIC(NumVecCmpBO, "Number of vector compare + binop formed");
STATISTIC(NumShufOfBitcast, "Number of shuffles moved after bitcast");
STATISTIC(NumScalarBO, "Number of scalar binops formed");
STATISTIC(NumScalarCmp, "Number of scalar compares formed");

static cl::opt<bool> DisableVectorCombine(
    "disable-vector-combine", cl::init(false), cl::Hidden,
    cl::desc("Disable all vector combine transforms"));

static cl::opt<bool> DisableBinopExtractShuffle(
    "disable-binop-extract-shuffle", cl::init(false), cl::Hidden,
    cl::desc("Disable binop extract to shuffle transforms"));

static const unsigned InvalidIndex = std::numeric_limits<unsigned>::max();

namespace {
class VectorCombine {
public:
  VectorCombine(Function &F, const TargetTransformInfo &TTI,
                const DominatorTree &DT)
      : F(F), Builder(F.getContext()), TTI(TTI), DT(DT) {}

  bool run();

private:
  Function &F;
  IRBuilder<> Builder;
  const TargetTransformInfo &TTI;
  const DominatorTree &DT;

  bool vectorizeLoadInsert(Instruction &I);
  ExtractElementInst *getShuffleExtract(ExtractElementInst *Ext0,
                                        ExtractElementInst *Ext1,
                                        unsigned PreferredExtractIndex) const;
  bool isExtractExtractCheap(ExtractElementInst *Ext0, ExtractElementInst *Ext1,
                             unsigned Opcode,
                             ExtractElementInst *&ConvertToShuffle,
                             unsigned PreferredExtractIndex);
  void foldExtExtCmp(ExtractElementInst *Ext0, ExtractElementInst *Ext1,
                     Instruction &I);
  void foldExtExtBinop(ExtractElementInst *Ext0, ExtractElementInst *Ext1,
                       Instruction &I);
  bool foldExtractExtract(Instruction &I);
  bool foldBitcastShuf(Instruction &I);
  bool scalarizeBinopOrCmp(Instruction &I);
  bool foldExtractedCmps(Instruction &I);
};
} // namespace

static void replaceValue(Value &Old, Value &New) {
  Old.replaceAllUsesWith(&New);
  New.takeName(&Old);
}

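/// Illustrative sketch of the transform below (value names and the <4 x i32>
/// width are hypothetical; the real width comes from
/// TTI.getMinVectorRegisterBitWidth()):
///   %s = load i32, i32* %p
///   %r = insertelement <4 x i32> undef, i32 %s, i32 0
/// becomes, when the cost model agrees:
///   %p.vec = bitcast i32* %p to <4 x i32>*
///   %r = load <4 x i32>, <4 x i32>* %p.vec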
bool VectorCombine::vectorizeLoadInsert(Instruction &I) {
  // Match insert into fixed vector of scalar value.
  auto *Ty = dyn_cast<FixedVectorType>(I.getType());
  Value *Scalar;
  if (!Ty || !match(&I, m_InsertElt(m_Undef(), m_Value(Scalar), m_ZeroInt())) ||
      !Scalar->hasOneUse())
    return false;

  // Optionally match an extract from another vector.
  Value *X;
  bool HasExtract = match(Scalar, m_ExtractElt(m_Value(X), m_ZeroInt()));
  if (!HasExtract)
    X = Scalar;

  // Match source value as load of scalar or vector.
  // Do not vectorize scalar load (widening) if atomic/volatile or under
  // asan/hwasan/memtag/tsan. The widened load may load data from dirty regions
  // or create data races non-existent in the source.
  auto *Load = dyn_cast<LoadInst>(X);
  if (!Load || !Load->isSimple() || !Load->hasOneUse() ||
      Load->getFunction()->hasFnAttribute(Attribute::SanitizeMemTag) ||
      mustSuppressSpeculation(*Load))
    return false;

  // TODO: Extend this to match GEP with constant offsets.
  Value *PtrOp = Load->getPointerOperand()->stripPointerCasts();
  assert(isa<PointerType>(PtrOp->getType()) && "Expected a pointer type");
  unsigned AS = Load->getPointerAddressSpace();

  // If original AS != Load's AS, we can't bitcast the original pointer and have
  // to use Load's operand instead. Ideally we would want to strip pointer casts
  // without changing AS, but there's no API to do that ATM.
  if (AS != PtrOp->getType()->getPointerAddressSpace())
    PtrOp = Load->getPointerOperand();

  Type *ScalarTy = Scalar->getType();
  uint64_t ScalarSize = ScalarTy->getPrimitiveSizeInBits();
  unsigned MinVectorSize = TTI.getMinVectorRegisterBitWidth();
  if (!ScalarSize || !MinVectorSize || MinVectorSize % ScalarSize != 0)
    return false;

  // Check safety of replacing the scalar load with a larger vector load.
  unsigned MinVecNumElts = MinVectorSize / ScalarSize;
  auto *MinVecTy = VectorType::get(ScalarTy, MinVecNumElts, false);
  Align Alignment = Load->getAlign();
  const DataLayout &DL = I.getModule()->getDataLayout();
  if (!isSafeToLoadUnconditionally(PtrOp, MinVecTy, Alignment, DL, Load, &DT))
    return false;

  // Original pattern: insertelt undef, load [free casts of] PtrOp, 0
  Type *LoadTy = Load->getType();
  int OldCost = TTI.getMemoryOpCost(Instruction::Load, LoadTy, Alignment, AS);
  APInt DemandedElts = APInt::getOneBitSet(MinVecNumElts, 0);
  OldCost += TTI.getScalarizationOverhead(MinVecTy, DemandedElts,
                                          /* Insert */ true, HasExtract);

  // New pattern: load VecPtr
  int NewCost = TTI.getMemoryOpCost(Instruction::Load, MinVecTy, Alignment, AS);

  // We can aggressively convert to the vector form because the backend can
  // invert this transform if it does not result in a performance win.
  if (OldCost < NewCost)
    return false;

  // It is safe and potentially profitable to load a vector directly:
  // inselt undef, load Scalar, 0 --> load VecPtr
  IRBuilder<> Builder(Load);
  Value *CastedPtr = Builder.CreateBitCast(PtrOp, MinVecTy->getPointerTo(AS));
  Value *VecLd = Builder.CreateAlignedLoad(MinVecTy, CastedPtr, Alignment);

  // If the insert type does not match the target's minimum vector type,
  // use an identity shuffle to shrink/grow the vector.
  if (Ty != MinVecTy) {
    unsigned OutputNumElts = Ty->getNumElements();
    SmallVector<int, 16> Mask(OutputNumElts, UndefMaskElem);
    for (unsigned i = 0; i < OutputNumElts && i < MinVecNumElts; ++i)
      Mask[i] = i;
    VecLd = Builder.CreateShuffleVector(VecLd, Mask);
  }
  replaceValue(I, *VecLd);
  ++NumVecLoad;
  return true;
}

/// Determine which, if any, of the inputs should be replaced by a shuffle
/// followed by extract from a different index.
ExtractElementInst *VectorCombine::getShuffleExtract(
    ExtractElementInst *Ext0, ExtractElementInst *Ext1,
    unsigned PreferredExtractIndex = InvalidIndex) const {
  assert(isa<ConstantInt>(Ext0->getIndexOperand()) &&
         isa<ConstantInt>(Ext1->getIndexOperand()) &&
         "Expected constant extract indexes");

  unsigned Index0 = cast<ConstantInt>(Ext0->getIndexOperand())->getZExtValue();
  unsigned Index1 = cast<ConstantInt>(Ext1->getIndexOperand())->getZExtValue();

  // If the extract indexes are identical, no shuffle is needed.
  if (Index0 == Index1)
    return nullptr;

  Type *VecTy = Ext0->getVectorOperand()->getType();
  assert(VecTy == Ext1->getVectorOperand()->getType() && "Need matching types");
  int Cost0 = TTI.getVectorInstrCost(Ext0->getOpcode(), VecTy, Index0);
  int Cost1 = TTI.getVectorInstrCost(Ext1->getOpcode(), VecTy, Index1);

  // We are extracting from 2 different indexes, so one operand must be shuffled
  // before performing a vector operation and/or extract. The more expensive
  // extract will be replaced by a shuffle.
  if (Cost0 > Cost1)
    return Ext0;
  if (Cost1 > Cost0)
    return Ext1;

  // If the costs are equal and there is a preferred extract index, shuffle the
  // opposite operand.
  if (PreferredExtractIndex == Index0)
    return Ext1;
  if (PreferredExtractIndex == Index1)
    return Ext0;

  // Otherwise, replace the extract with the higher index.
  return Index0 > Index1 ? Ext0 : Ext1;
}

/// Compare the relative costs of 2 extracts followed by scalar operation vs.
/// vector operation(s) followed by extract. Return true if the existing
/// instructions are cheaper than a vector alternative. Otherwise, return false
/// and if one of the extracts should be transformed to a shufflevector, set
/// \p ConvertToShuffle to that extract instruction.
bool VectorCombine::isExtractExtractCheap(ExtractElementInst *Ext0,
                                          ExtractElementInst *Ext1,
                                          unsigned Opcode,
                                          ExtractElementInst *&ConvertToShuffle,
                                          unsigned PreferredExtractIndex) {
  assert(isa<ConstantInt>(Ext0->getOperand(1)) &&
         isa<ConstantInt>(Ext1->getOperand(1)) &&
         "Expected constant extract indexes");
  Type *ScalarTy = Ext0->getType();
  auto *VecTy = cast<VectorType>(Ext0->getOperand(0)->getType());
  int ScalarOpCost, VectorOpCost;

  // Get cost estimates for scalar and vector versions of the operation.
  bool IsBinOp = Instruction::isBinaryOp(Opcode);
  if (IsBinOp) {
    ScalarOpCost = TTI.getArithmeticInstrCost(Opcode, ScalarTy);
    VectorOpCost = TTI.getArithmeticInstrCost(Opcode, VecTy);
  } else {
    assert((Opcode == Instruction::ICmp || Opcode == Instruction::FCmp) &&
           "Expected a compare");
    ScalarOpCost = TTI.getCmpSelInstrCost(Opcode, ScalarTy,
                                          CmpInst::makeCmpResultType(ScalarTy));
    VectorOpCost = TTI.getCmpSelInstrCost(Opcode, VecTy,
                                          CmpInst::makeCmpResultType(VecTy));
  }

  // Get cost estimates for the extract elements. These costs will factor into
  // both sequences.
  unsigned Ext0Index = cast<ConstantInt>(Ext0->getOperand(1))->getZExtValue();
  unsigned Ext1Index = cast<ConstantInt>(Ext1->getOperand(1))->getZExtValue();

  int Extract0Cost =
      TTI.getVectorInstrCost(Instruction::ExtractElement, VecTy, Ext0Index);
  int Extract1Cost =
      TTI.getVectorInstrCost(Instruction::ExtractElement, VecTy, Ext1Index);

  // A more expensive extract will always be replaced by a splat shuffle.
  // For example, if Ext0 is more expensive:
  // opcode (extelt V0, Ext0), (ext V1, Ext1) -->
  // extelt (opcode (splat V0, Ext0), V1), Ext1
  // TODO: Evaluate whether that always results in lowest cost. Alternatively,
  // check the cost of creating a broadcast shuffle and shuffling both
  // operands to element 0.
  int CheapExtractCost = std::min(Extract0Cost, Extract1Cost);

  // Extra uses of the extracts mean that we include those costs in the
  // vector total because those instructions will not be eliminated.
  int OldCost, NewCost;
  if (Ext0->getOperand(0) == Ext1->getOperand(0) && Ext0Index == Ext1Index) {
    // Handle a special case. If the 2 extracts are identical, adjust the
    // formulas to account for that. The extra use charge allows for either the
    // CSE'd pattern or an unoptimized form with identical values:
    // opcode (extelt V, C), (extelt V, C) --> extelt (opcode V, V), C
    bool HasUseTax = Ext0 == Ext1 ? !Ext0->hasNUses(2)
                                  : !Ext0->hasOneUse() || !Ext1->hasOneUse();
    OldCost = CheapExtractCost + ScalarOpCost;
    NewCost = VectorOpCost + CheapExtractCost + HasUseTax * CheapExtractCost;
  } else {
    // Handle the general case. Each extract is actually a different value:
    // opcode (extelt V0, C0), (extelt V1, C1) --> extelt (opcode V0, V1), C
    OldCost = Extract0Cost + Extract1Cost + ScalarOpCost;
    NewCost = VectorOpCost + CheapExtractCost +
              !Ext0->hasOneUse() * Extract0Cost +
              !Ext1->hasOneUse() * Extract1Cost;
  }

  ConvertToShuffle = getShuffleExtract(Ext0, Ext1, PreferredExtractIndex);
  if (ConvertToShuffle) {
    if (IsBinOp && DisableBinopExtractShuffle)
      return true;

    // If we are extracting from 2 different indexes, then one operand must be
    // shuffled before performing the vector operation. The shuffle mask is
    // undefined except for 1 lane that is being translated to the remaining
    // extraction lane. Therefore, it is a splat shuffle. Ex:
    // ShufMask = { undef, undef, 0, undef }
    // TODO: The cost model has an option for a "broadcast" shuffle
    // (splat-from-element-0), but no option for a more general splat.
    NewCost +=
        TTI.getShuffleCost(TargetTransformInfo::SK_PermuteSingleSrc, VecTy);
  }

  // Aggressively form a vector op if the cost is equal because the transform
  // may enable further optimization.
  // Codegen can reverse this transform (scalarize) if it was not profitable.
  return OldCost < NewCost;
}

/// Create a shuffle that translates (shifts) 1 element from the input vector
/// to a new element location.
static Value *createShiftShuffle(Value *Vec, unsigned OldIndex,
                                 unsigned NewIndex, IRBuilder<> &Builder) {
  // The shuffle mask is undefined except for 1 lane that is being translated
  // to the new element index. Example for OldIndex == 2 and NewIndex == 0:
  // ShufMask = { 2, undef, undef, undef }
  auto *VecTy = cast<FixedVectorType>(Vec->getType());
  SmallVector<int, 32> ShufMask(VecTy->getNumElements(), UndefMaskElem);
  ShufMask[NewIndex] = OldIndex;
  return Builder.CreateShuffleVector(Vec, ShufMask, "shift");
}

/// Given an extract element instruction with constant index operand, shuffle
/// the source vector (shift the scalar element) to a NewIndex for extraction.
/// Return null if the input can be constant folded, so that we are not creating
/// unnecessary instructions.
static ExtractElementInst *translateExtract(ExtractElementInst *ExtElt,
                                            unsigned NewIndex,
                                            IRBuilder<> &Builder) {
  // If the extract can be constant-folded, this code is unsimplified. Defer
  // to other passes to handle that.
  Value *X = ExtElt->getVectorOperand();
  Value *C = ExtElt->getIndexOperand();
  assert(isa<ConstantInt>(C) && "Expected a constant index operand");
  if (isa<Constant>(X))
    return nullptr;

  Value *Shuf = createShiftShuffle(X, cast<ConstantInt>(C)->getZExtValue(),
                                   NewIndex, Builder);
  return cast<ExtractElementInst>(Builder.CreateExtractElement(Shuf, NewIndex));
}

/// Try to reduce extract element costs by converting scalar compares to vector
/// compares followed by extract.
/// cmp (ext0 V0, C), (ext1 V1, C)
void VectorCombine::foldExtExtCmp(ExtractElementInst *Ext0,
                                  ExtractElementInst *Ext1, Instruction &I) {
  assert(isa<CmpInst>(&I) && "Expected a compare");
  assert(cast<ConstantInt>(Ext0->getIndexOperand())->getZExtValue() ==
             cast<ConstantInt>(Ext1->getIndexOperand())->getZExtValue() &&
         "Expected matching constant extract indexes");

  // cmp Pred (extelt V0, C), (extelt V1, C) --> extelt (cmp Pred V0, V1), C
  ++NumVecCmp;
  CmpInst::Predicate Pred = cast<CmpInst>(&I)->getPredicate();
  Value *V0 = Ext0->getVectorOperand(), *V1 = Ext1->getVectorOperand();
  Value *VecCmp = Builder.CreateCmp(Pred, V0, V1);
  Value *NewExt = Builder.CreateExtractElement(VecCmp, Ext0->getIndexOperand());
  replaceValue(I, *NewExt);
}

/// Try to reduce extract element costs by converting scalar binops to vector
/// binops followed by extract.
/// bo (ext0 V0, C), (ext1 V1, C)
void VectorCombine::foldExtExtBinop(ExtractElementInst *Ext0,
                                    ExtractElementInst *Ext1, Instruction &I) {
  assert(isa<BinaryOperator>(&I) && "Expected a binary operator");
  assert(cast<ConstantInt>(Ext0->getIndexOperand())->getZExtValue() ==
             cast<ConstantInt>(Ext1->getIndexOperand())->getZExtValue() &&
         "Expected matching constant extract indexes");

  // bo (extelt V0, C), (extelt V1, C) --> extelt (bo V0, V1), C
  ++NumVecBO;
  Value *V0 = Ext0->getVectorOperand(), *V1 = Ext1->getVectorOperand();
  Value *VecBO =
      Builder.CreateBinOp(cast<BinaryOperator>(&I)->getOpcode(), V0, V1);

  // All IR flags are safe to back-propagate because any potential poison
  // created in unused vector elements is discarded by the extract.
  if (auto *VecBOInst = dyn_cast<Instruction>(VecBO))
    VecBOInst->copyIRFlags(&I);

  Value *NewExt = Builder.CreateExtractElement(VecBO, Ext0->getIndexOperand());
  replaceValue(I, *NewExt);
}

/// Match an instruction with extracted vector operands.
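/// Illustrative sketch (value names, types, and the shared lane are
/// hypothetical): when both operands extract the same lane, e.g.
///   %e0 = extractelement <4 x float> %a, i32 2
///   %e1 = extractelement <4 x float> %b, i32 2
///   %r  = fadd float %e0, %e1
/// this becomes
///   %v = fadd <4 x float> %a, %b
///   %r = extractelement <4 x float> %v, i32 2
/// When the lanes differ, one source is first shuffled so both extracts use
/// the same lane (see translateExtract above).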
bool VectorCombine::foldExtractExtract(Instruction &I) {
  // It is not safe to transform things like div, urem, etc. because we may
  // create undefined behavior when executing those on unknown vector elements.
  if (!isSafeToSpeculativelyExecute(&I))
    return false;

  Instruction *I0, *I1;
  CmpInst::Predicate Pred = CmpInst::BAD_ICMP_PREDICATE;
  if (!match(&I, m_Cmp(Pred, m_Instruction(I0), m_Instruction(I1))) &&
      !match(&I, m_BinOp(m_Instruction(I0), m_Instruction(I1))))
    return false;

  Value *V0, *V1;
  uint64_t C0, C1;
  if (!match(I0, m_ExtractElt(m_Value(V0), m_ConstantInt(C0))) ||
      !match(I1, m_ExtractElt(m_Value(V1), m_ConstantInt(C1))) ||
      V0->getType() != V1->getType())
    return false;

  // If the scalar value 'I' is going to be re-inserted into a vector, then try
  // to create an extract to that same element. The extract/insert can be
  // reduced to a "select shuffle".
  // TODO: If we add a larger pattern match that starts from an insert, this
  // probably becomes unnecessary.
  auto *Ext0 = cast<ExtractElementInst>(I0);
  auto *Ext1 = cast<ExtractElementInst>(I1);
  uint64_t InsertIndex = InvalidIndex;
  if (I.hasOneUse())
    match(I.user_back(),
          m_InsertElt(m_Value(), m_Value(), m_ConstantInt(InsertIndex)));

  ExtractElementInst *ExtractToChange;
  if (isExtractExtractCheap(Ext0, Ext1, I.getOpcode(), ExtractToChange,
                            InsertIndex))
    return false;

  if (ExtractToChange) {
    unsigned CheapExtractIdx = ExtractToChange == Ext0 ? C1 : C0;
    ExtractElementInst *NewExtract =
        translateExtract(ExtractToChange, CheapExtractIdx, Builder);
    if (!NewExtract)
      return false;
    if (ExtractToChange == Ext0)
      Ext0 = NewExtract;
    else
      Ext1 = NewExtract;
  }

  if (Pred != CmpInst::BAD_ICMP_PREDICATE)
    foldExtExtCmp(Ext0, Ext1, I);
  else
    foldExtExtBinop(Ext0, Ext1, I);

  return true;
}

/// If this is a bitcast of a shuffle, try to bitcast the source vector to the
/// destination type followed by shuffle. This can enable further transforms by
/// moving bitcasts or shuffles together.
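/// Illustrative sketch (types, mask, and value names are example values only):
///   %s = shufflevector <2 x i64> %v, <2 x i64> undef, <2 x i32> <i32 1, i32 1>
///   %r = bitcast <2 x i64> %s to <4 x i32>
/// becomes, when the shuffle costs allow it:
///   %b = bitcast <2 x i64> %v to <4 x i32>
///   %r = shufflevector <4 x i32> %b, <4 x i32> undef,
///                      <4 x i32> <i32 2, i32 3, i32 2, i32 3>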
bool VectorCombine::foldBitcastShuf(Instruction &I) {
  Value *V;
  ArrayRef<int> Mask;
  if (!match(&I, m_BitCast(
                     m_OneUse(m_Shuffle(m_Value(V), m_Undef(), m_Mask(Mask))))))
    return false;

  // 1) Do not fold bitcast shuffle for scalable type. First, shuffle cost for
  // scalable type is unknown; Second, we cannot reason if the narrowed shuffle
  // mask for scalable type is a splat or not.
  // 2) Disallow non-vector casts and length-changing shuffles.
  // TODO: We could allow any shuffle.
  auto *DestTy = dyn_cast<FixedVectorType>(I.getType());
  auto *SrcTy = dyn_cast<FixedVectorType>(V->getType());
  if (!SrcTy || !DestTy || I.getOperand(0)->getType() != SrcTy)
    return false;

  // The new shuffle must not cost more than the old shuffle. The bitcast is
  // moved ahead of the shuffle, so assume that it has the same cost as before.
  if (TTI.getShuffleCost(TargetTransformInfo::SK_PermuteSingleSrc, DestTy) >
      TTI.getShuffleCost(TargetTransformInfo::SK_PermuteSingleSrc, SrcTy))
    return false;

  unsigned DestNumElts = DestTy->getNumElements();
  unsigned SrcNumElts = SrcTy->getNumElements();
  SmallVector<int, 16> NewMask;
  if (SrcNumElts <= DestNumElts) {
    // The bitcast is from wide to narrow/equal elements. The shuffle mask can
    // always be expanded to the equivalent form choosing narrower elements.
    assert(DestNumElts % SrcNumElts == 0 && "Unexpected shuffle mask");
    unsigned ScaleFactor = DestNumElts / SrcNumElts;
    narrowShuffleMaskElts(ScaleFactor, Mask, NewMask);
  } else {
    // The bitcast is from narrow elements to wide elements. The shuffle mask
    // must choose consecutive elements to allow casting first.
    assert(SrcNumElts % DestNumElts == 0 && "Unexpected shuffle mask");
    unsigned ScaleFactor = SrcNumElts / DestNumElts;
    if (!widenShuffleMaskElts(ScaleFactor, Mask, NewMask))
      return false;
  }
  // bitcast (shuf V, MaskC) --> shuf (bitcast V), MaskC'
  ++NumShufOfBitcast;
  Value *CastV = Builder.CreateBitCast(V, DestTy);
  Value *Shuf = Builder.CreateShuffleVector(CastV, NewMask);
  replaceValue(I, *Shuf);
  return true;
}

/// Match a vector binop or compare instruction with at least one inserted
/// scalar operand and convert to scalar binop/cmp followed by insertelement.
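/// Illustrative sketch (constants, types, and names are example values only):
///   %i = insertelement <2 x i32> <i32 1, i32 2>, i32 %x, i32 0
///   %r = add <2 x i32> %i, <i32 10, i32 20>
/// becomes, when the cost model agrees:
///   %s = add i32 %x, 10
///   %r = insertelement <2 x i32> <i32 11, i32 22>, i32 %s, i32 0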
bool VectorCombine::scalarizeBinopOrCmp(Instruction &I) {
  CmpInst::Predicate Pred = CmpInst::BAD_ICMP_PREDICATE;
  Value *Ins0, *Ins1;
  if (!match(&I, m_BinOp(m_Value(Ins0), m_Value(Ins1))) &&
      !match(&I, m_Cmp(Pred, m_Value(Ins0), m_Value(Ins1))))
    return false;

  // Do not convert the vector condition of a vector select into a scalar
  // condition. That may cause problems for codegen because of differences in
  // boolean formats and register-file transfers.
  // TODO: Can we account for that in the cost model?
  bool IsCmp = Pred != CmpInst::Predicate::BAD_ICMP_PREDICATE;
  if (IsCmp)
    for (User *U : I.users())
      if (match(U, m_Select(m_Specific(&I), m_Value(), m_Value())))
        return false;

  // Match against one or both scalar values being inserted into constant
  // vectors:
  // vec_op VecC0, (inselt VecC1, V1, Index)
  // vec_op (inselt VecC0, V0, Index), VecC1
  // vec_op (inselt VecC0, V0, Index), (inselt VecC1, V1, Index)
  // TODO: Deal with mismatched index constants and variable indexes?
  Constant *VecC0 = nullptr, *VecC1 = nullptr;
  Value *V0 = nullptr, *V1 = nullptr;
  uint64_t Index0 = 0, Index1 = 0;
  if (!match(Ins0, m_InsertElt(m_Constant(VecC0), m_Value(V0),
                               m_ConstantInt(Index0))) &&
      !match(Ins0, m_Constant(VecC0)))
    return false;
  if (!match(Ins1, m_InsertElt(m_Constant(VecC1), m_Value(V1),
                               m_ConstantInt(Index1))) &&
      !match(Ins1, m_Constant(VecC1)))
    return false;

  bool IsConst0 = !V0;
  bool IsConst1 = !V1;
  if (IsConst0 && IsConst1)
    return false;
  if (!IsConst0 && !IsConst1 && Index0 != Index1)
    return false;

  // Bail for single insertion if it is a load.
  // TODO: Handle this once getVectorInstrCost can cost for load/stores.
  auto *I0 = dyn_cast_or_null<Instruction>(V0);
  auto *I1 = dyn_cast_or_null<Instruction>(V1);
  if ((IsConst0 && I1 && I1->mayReadFromMemory()) ||
      (IsConst1 && I0 && I0->mayReadFromMemory()))
    return false;

  uint64_t Index = IsConst0 ? Index1 : Index0;
  Type *ScalarTy = IsConst0 ? V1->getType() : V0->getType();
  Type *VecTy = I.getType();
  assert(VecTy->isVectorTy() &&
         (IsConst0 || IsConst1 || V0->getType() == V1->getType()) &&
         (ScalarTy->isIntegerTy() || ScalarTy->isFloatingPointTy() ||
          ScalarTy->isPointerTy()) &&
         "Unexpected types for insert element into binop or cmp");

  unsigned Opcode = I.getOpcode();
  int ScalarOpCost, VectorOpCost;
  if (IsCmp) {
    ScalarOpCost = TTI.getCmpSelInstrCost(Opcode, ScalarTy);
    VectorOpCost = TTI.getCmpSelInstrCost(Opcode, VecTy);
  } else {
    ScalarOpCost = TTI.getArithmeticInstrCost(Opcode, ScalarTy);
    VectorOpCost = TTI.getArithmeticInstrCost(Opcode, VecTy);
  }

  // Get cost estimate for the insert element. This cost will factor into
  // both sequences.
  int InsertCost =
      TTI.getVectorInstrCost(Instruction::InsertElement, VecTy, Index);
  int OldCost = (IsConst0 ? 0 : InsertCost) + (IsConst1 ? 0 : InsertCost) +
                VectorOpCost;
  int NewCost = ScalarOpCost + InsertCost +
                (IsConst0 ? 0 : !Ins0->hasOneUse() * InsertCost) +
                (IsConst1 ? 0 : !Ins1->hasOneUse() * InsertCost);

  // We want to scalarize unless the vector variant actually has lower cost.
  if (OldCost < NewCost)
    return false;

  // vec_op (inselt VecC0, V0, Index), (inselt VecC1, V1, Index) -->
  // inselt NewVecC, (scalar_op V0, V1), Index
  if (IsCmp)
    ++NumScalarCmp;
  else
    ++NumScalarBO;

  // For constant cases, extract the scalar element, this should constant fold.
  if (IsConst0)
    V0 = ConstantExpr::getExtractElement(VecC0, Builder.getInt64(Index));
  if (IsConst1)
    V1 = ConstantExpr::getExtractElement(VecC1, Builder.getInt64(Index));

  Value *Scalar =
      IsCmp ? Builder.CreateCmp(Pred, V0, V1)
            : Builder.CreateBinOp((Instruction::BinaryOps)Opcode, V0, V1);

  Scalar->setName(I.getName() + ".scalar");

  // All IR flags are safe to back-propagate. There is no potential for extra
  // poison to be created by the scalar instruction.
  if (auto *ScalarInst = dyn_cast<Instruction>(Scalar))
    ScalarInst->copyIRFlags(&I);

  // Fold the vector constants in the original vectors into a new base vector.
  Constant *NewVecC = IsCmp ? ConstantExpr::getCompare(Pred, VecC0, VecC1)
                            : ConstantExpr::get(Opcode, VecC0, VecC1);
  Value *Insert = Builder.CreateInsertElement(NewVecC, Scalar, Index);
  replaceValue(I, *Insert);
  return true;
}

/// Try to combine a scalar binop + 2 scalar compares of extracted elements of
/// a vector into vector operations followed by extract. Note: The SLP pass
/// may miss this pattern because of implementation problems.
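/// Illustrative sketch (values, predicate, and indexes are hypothetical; it
/// assumes lane 3 is the more expensive extract, so it is shuffled to lane 0):
///   %e0 = extractelement <4 x i32> %x, i32 0
///   %e1 = extractelement <4 x i32> %x, i32 3
///   %c0 = icmp sgt i32 %e0, 42
///   %c1 = icmp sgt i32 %e1, -8
///   %r  = and i1 %c0, %c1
/// becomes, when the cost model agrees:
///   %vc = icmp sgt <4 x i32> %x, <i32 42, i32 undef, i32 undef, i32 -8>
///   %sh = shufflevector <4 x i1> %vc, <4 x i1> undef,
///                       <4 x i32> <i32 3, i32 undef, i32 undef, i32 undef>
///   %bo = and <4 x i1> %vc, %sh
///   %r  = extractelement <4 x i1> %bo, i32 0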
bool VectorCombine::foldExtractedCmps(Instruction &I) {
  // We are looking for a scalar binop of booleans.
  // binop i1 (cmp Pred I0, C0), (cmp Pred I1, C1)
  if (!I.isBinaryOp() || !I.getType()->isIntegerTy(1))
    return false;

  // The compare predicates should match, and each compare should have a
  // constant operand.
  // TODO: Relax the one-use constraints.
  Value *B0 = I.getOperand(0), *B1 = I.getOperand(1);
  Instruction *I0, *I1;
  Constant *C0, *C1;
  CmpInst::Predicate P0, P1;
  if (!match(B0, m_OneUse(m_Cmp(P0, m_Instruction(I0), m_Constant(C0)))) ||
      !match(B1, m_OneUse(m_Cmp(P1, m_Instruction(I1), m_Constant(C1)))) ||
      P0 != P1)
    return false;

  // The compare operands must be extracts of the same vector with constant
  // extract indexes.
  // TODO: Relax the one-use constraints.
  Value *X;
  uint64_t Index0, Index1;
  if (!match(I0, m_OneUse(m_ExtractElt(m_Value(X), m_ConstantInt(Index0)))) ||
      !match(I1, m_OneUse(m_ExtractElt(m_Specific(X), m_ConstantInt(Index1)))))
    return false;

  auto *Ext0 = cast<ExtractElementInst>(I0);
  auto *Ext1 = cast<ExtractElementInst>(I1);
  ExtractElementInst *ConvertToShuf = getShuffleExtract(Ext0, Ext1);
  if (!ConvertToShuf)
    return false;

  // The original scalar pattern is:
  // binop i1 (cmp Pred (ext X, Index0), C0), (cmp Pred (ext X, Index1), C1)
  CmpInst::Predicate Pred = P0;
  unsigned CmpOpcode = CmpInst::isFPPredicate(Pred) ? Instruction::FCmp
                                                    : Instruction::ICmp;
  auto *VecTy = dyn_cast<FixedVectorType>(X->getType());
  if (!VecTy)
    return false;

  int OldCost = TTI.getVectorInstrCost(Ext0->getOpcode(), VecTy, Index0);
  OldCost += TTI.getVectorInstrCost(Ext1->getOpcode(), VecTy, Index1);
  OldCost += TTI.getCmpSelInstrCost(CmpOpcode, I0->getType()) * 2;
  OldCost += TTI.getArithmeticInstrCost(I.getOpcode(), I.getType());

  // The proposed vector pattern is:
  // vcmp = cmp Pred X, VecC
  // ext (binop vNi1 vcmp, (shuffle vcmp, Index1)), Index0
  int CheapIndex = ConvertToShuf == Ext0 ? Index1 : Index0;
  int ExpensiveIndex = ConvertToShuf == Ext0 ? Index0 : Index1;
  auto *CmpTy = cast<FixedVectorType>(CmpInst::makeCmpResultType(X->getType()));
  int NewCost = TTI.getCmpSelInstrCost(CmpOpcode, X->getType());
  NewCost +=
      TTI.getShuffleCost(TargetTransformInfo::SK_PermuteSingleSrc, CmpTy);
  NewCost += TTI.getArithmeticInstrCost(I.getOpcode(), CmpTy);
  NewCost += TTI.getVectorInstrCost(Ext0->getOpcode(), CmpTy, CheapIndex);

  // Aggressively form vector ops if the cost is equal because the transform
  // may enable further optimization.
  // Codegen can reverse this transform (scalarize) if it was not profitable.
  if (OldCost < NewCost)
    return false;

  // Create a vector constant from the 2 scalar constants.
  SmallVector<Constant *, 32> CmpC(VecTy->getNumElements(),
                                   UndefValue::get(VecTy->getElementType()));
  CmpC[Index0] = C0;
  CmpC[Index1] = C1;
  Value *VCmp = Builder.CreateCmp(Pred, X, ConstantVector::get(CmpC));

  Value *Shuf = createShiftShuffle(VCmp, ExpensiveIndex, CheapIndex, Builder);
  Value *VecLogic = Builder.CreateBinOp(cast<BinaryOperator>(I).getOpcode(),
                                        VCmp, Shuf);
  Value *NewExt = Builder.CreateExtractElement(VecLogic, CheapIndex);
  replaceValue(I, *NewExt);
  ++NumVecCmpBO;
  return true;
}

/// This is the entry point for all transforms. Pass manager differences are
/// handled in the callers of this function.
bool VectorCombine::run() {
  if (DisableVectorCombine)
    return false;

  // Don't attempt vectorization if the target does not support vectors.
  if (!TTI.getNumberOfRegisters(TTI.getRegisterClassForType(/*Vector*/ true)))
    return false;

  bool MadeChange = false;
  for (BasicBlock &BB : F) {
    // Ignore unreachable basic blocks.
    if (!DT.isReachableFromEntry(&BB))
      continue;
    // Do not delete instructions under here and invalidate the iterator.
    // Walk the block forwards to enable simple iterative chains of transforms.
    // TODO: It could be more efficient to remove dead instructions
    // iteratively in this loop rather than waiting until the end.
    for (Instruction &I : BB) {
      if (isa<DbgInfoIntrinsic>(I))
        continue;
      Builder.SetInsertPoint(&I);
      MadeChange |= vectorizeLoadInsert(I);
      MadeChange |= foldExtractExtract(I);
      MadeChange |= foldBitcastShuf(I);
      MadeChange |= scalarizeBinopOrCmp(I);
      MadeChange |= foldExtractedCmps(I);
    }
  }

  // We're done with transforms, so remove dead instructions.
  if (MadeChange)
    for (BasicBlock &BB : F)
      SimplifyInstructionsInBlock(&BB);

  return MadeChange;
}

// Pass manager boilerplate below here.

namespace {
class VectorCombineLegacyPass : public FunctionPass {
public:
  static char ID;
  VectorCombineLegacyPass() : FunctionPass(ID) {
    initializeVectorCombineLegacyPassPass(*PassRegistry::getPassRegistry());
  }

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.addRequired<DominatorTreeWrapperPass>();
    AU.addRequired<TargetTransformInfoWrapperPass>();
    AU.setPreservesCFG();
    AU.addPreserved<DominatorTreeWrapperPass>();
    AU.addPreserved<GlobalsAAWrapperPass>();
    AU.addPreserved<AAResultsWrapperPass>();
    AU.addPreserved<BasicAAWrapperPass>();
    FunctionPass::getAnalysisUsage(AU);
  }

  bool runOnFunction(Function &F) override {
    if (skipFunction(F))
      return false;
    auto &TTI = getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
    auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree();
    VectorCombine Combiner(F, TTI, DT);
    return Combiner.run();
  }
};
} // namespace

char VectorCombineLegacyPass::ID = 0;
INITIALIZE_PASS_BEGIN(VectorCombineLegacyPass, "vector-combine",
                      "Optimize scalar/vector ops", false, false)
INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
INITIALIZE_PASS_END(VectorCombineLegacyPass, "vector-combine",
                    "Optimize scalar/vector ops", false, false)
Pass *llvm::createVectorCombinePass() {
  return new VectorCombineLegacyPass();
}

PreservedAnalyses VectorCombinePass::run(Function &F,
                                         FunctionAnalysisManager &FAM) {
  TargetTransformInfo &TTI = FAM.getResult<TargetIRAnalysis>(F);
  DominatorTree &DT = FAM.getResult<DominatorTreeAnalysis>(F);
  VectorCombine Combiner(F, TTI, DT);
  if (!Combiner.run())
    return PreservedAnalyses::all();
  PreservedAnalyses PA;
  PA.preserveSet<CFGAnalyses>();
  PA.preserve<GlobalsAA>();
  PA.preserve<AAManager>();
  PA.preserve<BasicAA>();
  return PA;
}