1 //===- ScalarizeMaskedMemIntrin.cpp - Scalarize unsupported masked mem ----===//
//                                    intrinsics
3 //
4 //                     The LLVM Compiler Infrastructure
5 //
6 // This file is distributed under the University of Illinois Open Source
7 // License. See LICENSE.TXT for details.
8 //
9 //===----------------------------------------------------------------------===//
10 //
11 // This pass replaces masked memory intrinsics - when unsupported by the target
12 // - with a chain of basic blocks, that deal with the elements one-by-one if the
13 // appropriate mask bit is set.
14 //
15 //===----------------------------------------------------------------------===//
16 
17 #include "llvm/ADT/Twine.h"
18 #include "llvm/Analysis/TargetTransformInfo.h"
19 #include "llvm/CodeGen/TargetSubtargetInfo.h"
20 #include "llvm/IR/BasicBlock.h"
21 #include "llvm/IR/Constant.h"
22 #include "llvm/IR/Constants.h"
23 #include "llvm/IR/DerivedTypes.h"
24 #include "llvm/IR/Function.h"
25 #include "llvm/IR/IRBuilder.h"
26 #include "llvm/IR/InstrTypes.h"
27 #include "llvm/IR/Instruction.h"
28 #include "llvm/IR/Instructions.h"
29 #include "llvm/IR/IntrinsicInst.h"
30 #include "llvm/IR/Intrinsics.h"
31 #include "llvm/IR/Type.h"
32 #include "llvm/IR/Value.h"
33 #include "llvm/Pass.h"
34 #include "llvm/Support/Casting.h"
35 #include <algorithm>
36 #include <cassert>
37 
38 using namespace llvm;
39 
40 #define DEBUG_TYPE "scalarize-masked-mem-intrin"
41 
42 namespace {
43 
44 class ScalarizeMaskedMemIntrin : public FunctionPass {
45   const TargetTransformInfo *TTI = nullptr;
46 
47 public:
48   static char ID; // Pass identification, replacement for typeid
49 
ScalarizeMaskedMemIntrin()50   explicit ScalarizeMaskedMemIntrin() : FunctionPass(ID) {
51     initializeScalarizeMaskedMemIntrinPass(*PassRegistry::getPassRegistry());
52   }
53 
54   bool runOnFunction(Function &F) override;
55 
getPassName() const56   StringRef getPassName() const override {
57     return "Scalarize Masked Memory Intrinsics";
58   }
59 
getAnalysisUsage(AnalysisUsage & AU) const60   void getAnalysisUsage(AnalysisUsage &AU) const override {
61     AU.addRequired<TargetTransformInfoWrapperPass>();
62   }
63 
64 private:
65   bool optimizeBlock(BasicBlock &BB, bool &ModifiedDT);
66   bool optimizeCallInst(CallInst *CI, bool &ModifiedDT);
67 };
68 
69 } // end anonymous namespace
70 
71 char ScalarizeMaskedMemIntrin::ID = 0;
72 
73 INITIALIZE_PASS(ScalarizeMaskedMemIntrin, DEBUG_TYPE,
74                 "Scalarize unsupported masked memory intrinsics", false, false)
75 
createScalarizeMaskedMemIntrinPass()76 FunctionPass *llvm::createScalarizeMaskedMemIntrinPass() {
77   return new ScalarizeMaskedMemIntrin();
78 }
79 
80 // Translate a masked load intrinsic like
81 // <16 x i32 > @llvm.masked.load( <16 x i32>* %addr, i32 align,
82 //                               <16 x i1> %mask, <16 x i32> %passthru)
// to a chain of basic blocks that loads the elements one by one, each only
// when the corresponding mask bit is set
85 //
86 //  %1 = bitcast i8* %addr to i32*
87 //  %2 = extractelement <16 x i1> %mask, i32 0
88 //  %3 = icmp eq i1 %2, true
89 //  br i1 %3, label %cond.load, label %else
90 //
91 // cond.load:                                        ; preds = %0
92 //  %4 = getelementptr i32* %1, i32 0
93 //  %5 = load i32* %4
94 //  %6 = insertelement <16 x i32> undef, i32 %5, i32 0
95 //  br label %else
96 //
97 // else:                                             ; preds = %0, %cond.load
98 //  %res.phi.else = phi <16 x i32> [ %6, %cond.load ], [ undef, %0 ]
99 //  %7 = extractelement <16 x i1> %mask, i32 1
100 //  %8 = icmp eq i1 %7, true
101 //  br i1 %8, label %cond.load1, label %else2
102 //
103 // cond.load1:                                       ; preds = %else
104 //  %9 = getelementptr i32* %1, i32 1
105 //  %10 = load i32* %9
106 //  %11 = insertelement <16 x i32> %res.phi.else, i32 %10, i32 1
107 //  br label %else2
108 //
109 // else2:                                          ; preds = %else, %cond.load1
110 //  %res.phi.else3 = phi <16 x i32> [ %11, %cond.load1 ], [ %res.phi.else, %else ]
111 //  %12 = extractelement <16 x i1> %mask, i32 2
112 //  %13 = icmp eq i1 %12, true
113 //  br i1 %13, label %cond.load4, label %else5
114 //
scalarizeMaskedLoad(CallInst * CI)115 static void scalarizeMaskedLoad(CallInst *CI) {
116   Value *Ptr = CI->getArgOperand(0);
117   Value *Alignment = CI->getArgOperand(1);
118   Value *Mask = CI->getArgOperand(2);
119   Value *Src0 = CI->getArgOperand(3);
120 
121   unsigned AlignVal = cast<ConstantInt>(Alignment)->getZExtValue();
122   VectorType *VecType = dyn_cast<VectorType>(CI->getType());
123   assert(VecType && "Unexpected return type of masked load intrinsic");
124 
125   Type *EltTy = CI->getType()->getVectorElementType();
126 
127   IRBuilder<> Builder(CI->getContext());
128   Instruction *InsertPt = CI;
129   BasicBlock *IfBlock = CI->getParent();
130   BasicBlock *CondBlock = nullptr;
131   BasicBlock *PrevIfBlock = CI->getParent();
132 
133   Builder.SetInsertPoint(InsertPt);
134   Builder.SetCurrentDebugLocation(CI->getDebugLoc());
135 
136   // Short-cut if the mask is all-true.
137   bool IsAllOnesMask =
138       isa<Constant>(Mask) && cast<Constant>(Mask)->isAllOnesValue();
139 
140   if (IsAllOnesMask) {
141     Value *NewI = Builder.CreateAlignedLoad(Ptr, AlignVal);
142     CI->replaceAllUsesWith(NewI);
143     CI->eraseFromParent();
144     return;
145   }
146 
147   // Adjust alignment for the scalar instruction.
148   AlignVal = std::min(AlignVal, VecType->getScalarSizeInBits() / 8);
149   // Bitcast %addr fron i8* to EltTy*
150   Type *NewPtrType =
151       EltTy->getPointerTo(cast<PointerType>(Ptr->getType())->getAddressSpace());
152   Value *FirstEltPtr = Builder.CreateBitCast(Ptr, NewPtrType);
153   unsigned VectorWidth = VecType->getNumElements();
154 
155   Value *UndefVal = UndefValue::get(VecType);
156 
157   // The result vector
158   Value *VResult = UndefVal;
159 
160   if (isa<ConstantVector>(Mask)) {
161     for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
162       if (cast<ConstantVector>(Mask)->getOperand(Idx)->isNullValue())
163         continue;
164       Value *Gep =
165           Builder.CreateInBoundsGEP(EltTy, FirstEltPtr, Builder.getInt32(Idx));
166       LoadInst *Load = Builder.CreateAlignedLoad(Gep, AlignVal);
167       VResult =
168           Builder.CreateInsertElement(VResult, Load, Builder.getInt32(Idx));
169     }
170     Value *NewI = Builder.CreateSelect(Mask, VResult, Src0);
171     CI->replaceAllUsesWith(NewI);
172     CI->eraseFromParent();
173     return;
174   }
175 
176   PHINode *Phi = nullptr;
177   Value *PrevPhi = UndefVal;
178 
179   for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
180     // Fill the "else" block, created in the previous iteration
181     //
182     //  %res.phi.else3 = phi <16 x i32> [ %11, %cond.load1 ], [ %res.phi.else, %else ]
183     //  %mask_1 = extractelement <16 x i1> %mask, i32 Idx
184     //  %to_load = icmp eq i1 %mask_1, true
185     //  br i1 %to_load, label %cond.load, label %else
186     //
187     if (Idx > 0) {
188       Phi = Builder.CreatePHI(VecType, 2, "res.phi.else");
189       Phi->addIncoming(VResult, CondBlock);
190       Phi->addIncoming(PrevPhi, PrevIfBlock);
191       PrevPhi = Phi;
192       VResult = Phi;
193     }
194 
195     Value *Predicate =
196         Builder.CreateExtractElement(Mask, Builder.getInt32(Idx));
197     Value *Cmp = Builder.CreateICmp(ICmpInst::ICMP_EQ, Predicate,
198                                     ConstantInt::get(Predicate->getType(), 1));
199 
200     // Create "cond" block
201     //
202     //  %EltAddr = getelementptr i32* %1, i32 0
203     //  %Elt = load i32* %EltAddr
204     //  VResult = insertelement <16 x i32> VResult, i32 %Elt, i32 Idx
205     //
206     CondBlock = IfBlock->splitBasicBlock(InsertPt->getIterator(), "cond.load");
207     Builder.SetInsertPoint(InsertPt);
208 
209     Value *Gep =
210         Builder.CreateInBoundsGEP(EltTy, FirstEltPtr, Builder.getInt32(Idx));
211     LoadInst *Load = Builder.CreateAlignedLoad(Gep, AlignVal);
212     VResult = Builder.CreateInsertElement(VResult, Load, Builder.getInt32(Idx));
213 
214     // Create "else" block, fill it in the next iteration
215     BasicBlock *NewIfBlock =
216         CondBlock->splitBasicBlock(InsertPt->getIterator(), "else");
217     Builder.SetInsertPoint(InsertPt);
218     Instruction *OldBr = IfBlock->getTerminator();
219     BranchInst::Create(CondBlock, NewIfBlock, Cmp, OldBr);
220     OldBr->eraseFromParent();
221     PrevIfBlock = IfBlock;
222     IfBlock = NewIfBlock;
223   }
224 
225   Phi = Builder.CreatePHI(VecType, 2, "res.phi.select");
226   Phi->addIncoming(VResult, CondBlock);
227   Phi->addIncoming(PrevPhi, PrevIfBlock);
228   Value *NewI = Builder.CreateSelect(Mask, Phi, Src0);
229   CI->replaceAllUsesWith(NewI);
230   CI->eraseFromParent();
231 }
232 
233 // Translate a masked store intrinsic, like
// void @llvm.masked.store(<16 x i32> %val, <16 x i32>* %addr, i32 align,
235 //                               <16 x i1> %mask)
236 // to a chain of basic blocks, that stores element one-by-one if
237 // the appropriate mask bit is set
238 //
239 //   %1 = bitcast i8* %addr to i32*
240 //   %2 = extractelement <16 x i1> %mask, i32 0
241 //   %3 = icmp eq i1 %2, true
242 //   br i1 %3, label %cond.store, label %else
243 //
244 // cond.store:                                       ; preds = %0
245 //   %4 = extractelement <16 x i32> %val, i32 0
246 //   %5 = getelementptr i32* %1, i32 0
247 //   store i32 %4, i32* %5
248 //   br label %else
249 //
250 // else:                                             ; preds = %0, %cond.store
251 //   %6 = extractelement <16 x i1> %mask, i32 1
252 //   %7 = icmp eq i1 %6, true
253 //   br i1 %7, label %cond.store1, label %else2
254 //
255 // cond.store1:                                      ; preds = %else
256 //   %8 = extractelement <16 x i32> %val, i32 1
257 //   %9 = getelementptr i32* %1, i32 1
258 //   store i32 %8, i32* %9
259 //   br label %else2
260 //   . . .
scalarizeMaskedStore(CallInst * CI)261 static void scalarizeMaskedStore(CallInst *CI) {
262   Value *Src = CI->getArgOperand(0);
263   Value *Ptr = CI->getArgOperand(1);
264   Value *Alignment = CI->getArgOperand(2);
265   Value *Mask = CI->getArgOperand(3);
266 
267   unsigned AlignVal = cast<ConstantInt>(Alignment)->getZExtValue();
268   VectorType *VecType = dyn_cast<VectorType>(Src->getType());
269   assert(VecType && "Unexpected data type in masked store intrinsic");
270 
271   Type *EltTy = VecType->getElementType();
272 
273   IRBuilder<> Builder(CI->getContext());
274   Instruction *InsertPt = CI;
275   BasicBlock *IfBlock = CI->getParent();
276   Builder.SetInsertPoint(InsertPt);
277   Builder.SetCurrentDebugLocation(CI->getDebugLoc());
278 
279   // Short-cut if the mask is all-true.
280   bool IsAllOnesMask =
281       isa<Constant>(Mask) && cast<Constant>(Mask)->isAllOnesValue();
282 
283   if (IsAllOnesMask) {
284     Builder.CreateAlignedStore(Src, Ptr, AlignVal);
285     CI->eraseFromParent();
286     return;
287   }
288 
289   // Adjust alignment for the scalar instruction.
290   AlignVal = std::max(AlignVal, VecType->getScalarSizeInBits() / 8);
291   // Bitcast %addr fron i8* to EltTy*
292   Type *NewPtrType =
293       EltTy->getPointerTo(cast<PointerType>(Ptr->getType())->getAddressSpace());
294   Value *FirstEltPtr = Builder.CreateBitCast(Ptr, NewPtrType);
295   unsigned VectorWidth = VecType->getNumElements();
296 
297   if (isa<ConstantVector>(Mask)) {
298     for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
299       if (cast<ConstantVector>(Mask)->getOperand(Idx)->isNullValue())
300         continue;
301       Value *OneElt = Builder.CreateExtractElement(Src, Builder.getInt32(Idx));
302       Value *Gep =
303           Builder.CreateInBoundsGEP(EltTy, FirstEltPtr, Builder.getInt32(Idx));
304       Builder.CreateAlignedStore(OneElt, Gep, AlignVal);
305     }
306     CI->eraseFromParent();
307     return;
308   }
309 
310   for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
311     // Fill the "else" block, created in the previous iteration
312     //
313     //  %mask_1 = extractelement <16 x i1> %mask, i32 Idx
314     //  %to_store = icmp eq i1 %mask_1, true
315     //  br i1 %to_store, label %cond.store, label %else
316     //
317     Value *Predicate =
318         Builder.CreateExtractElement(Mask, Builder.getInt32(Idx));
319     Value *Cmp = Builder.CreateICmp(ICmpInst::ICMP_EQ, Predicate,
320                                     ConstantInt::get(Predicate->getType(), 1));
321 
322     // Create "cond" block
323     //
324     //  %OneElt = extractelement <16 x i32> %Src, i32 Idx
325     //  %EltAddr = getelementptr i32* %1, i32 0
326     //  %store i32 %OneElt, i32* %EltAddr
327     //
328     BasicBlock *CondBlock =
329         IfBlock->splitBasicBlock(InsertPt->getIterator(), "cond.store");
330     Builder.SetInsertPoint(InsertPt);
331 
332     Value *OneElt = Builder.CreateExtractElement(Src, Builder.getInt32(Idx));
333     Value *Gep =
334         Builder.CreateInBoundsGEP(EltTy, FirstEltPtr, Builder.getInt32(Idx));
335     Builder.CreateAlignedStore(OneElt, Gep, AlignVal);
336 
337     // Create "else" block, fill it in the next iteration
338     BasicBlock *NewIfBlock =
339         CondBlock->splitBasicBlock(InsertPt->getIterator(), "else");
340     Builder.SetInsertPoint(InsertPt);
341     Instruction *OldBr = IfBlock->getTerminator();
342     BranchInst::Create(CondBlock, NewIfBlock, Cmp, OldBr);
343     OldBr->eraseFromParent();
344     IfBlock = NewIfBlock;
345   }
346   CI->eraseFromParent();
347 }
348 
349 // Translate a masked gather intrinsic like
350 // <16 x i32 > @llvm.masked.gather.v16i32( <16 x i32*> %Ptrs, i32 4,
351 //                               <16 x i1> %Mask, <16 x i32> %Src)
// to a chain of basic blocks that loads the elements one by one, each only
// when the corresponding mask bit is set
354 //
355 // % Ptrs = getelementptr i32, i32* %base, <16 x i64> %ind
356 // % Mask0 = extractelement <16 x i1> %Mask, i32 0
357 // % ToLoad0 = icmp eq i1 % Mask0, true
358 // br i1 % ToLoad0, label %cond.load, label %else
359 //
360 // cond.load:
361 // % Ptr0 = extractelement <16 x i32*> %Ptrs, i32 0
362 // % Load0 = load i32, i32* % Ptr0, align 4
363 // % Res0 = insertelement <16 x i32> undef, i32 % Load0, i32 0
364 // br label %else
365 //
366 // else:
367 // %res.phi.else = phi <16 x i32>[% Res0, %cond.load], [undef, % 0]
368 // % Mask1 = extractelement <16 x i1> %Mask, i32 1
369 // % ToLoad1 = icmp eq i1 % Mask1, true
370 // br i1 % ToLoad1, label %cond.load1, label %else2
371 //
372 // cond.load1:
373 // % Ptr1 = extractelement <16 x i32*> %Ptrs, i32 1
374 // % Load1 = load i32, i32* % Ptr1, align 4
375 // % Res1 = insertelement <16 x i32> %res.phi.else, i32 % Load1, i32 1
376 // br label %else2
377 // . . .
378 // % Result = select <16 x i1> %Mask, <16 x i32> %res.phi.select, <16 x i32> %Src
379 // ret <16 x i32> %Result
scalarizeMaskedGather(CallInst * CI)380 static void scalarizeMaskedGather(CallInst *CI) {
381   Value *Ptrs = CI->getArgOperand(0);
382   Value *Alignment = CI->getArgOperand(1);
383   Value *Mask = CI->getArgOperand(2);
384   Value *Src0 = CI->getArgOperand(3);
385 
386   VectorType *VecType = dyn_cast<VectorType>(CI->getType());
387 
388   assert(VecType && "Unexpected return type of masked load intrinsic");
389 
390   IRBuilder<> Builder(CI->getContext());
391   Instruction *InsertPt = CI;
392   BasicBlock *IfBlock = CI->getParent();
393   BasicBlock *CondBlock = nullptr;
394   BasicBlock *PrevIfBlock = CI->getParent();
395   Builder.SetInsertPoint(InsertPt);
396   unsigned AlignVal = cast<ConstantInt>(Alignment)->getZExtValue();
397 
398   Builder.SetCurrentDebugLocation(CI->getDebugLoc());
399 
400   Value *UndefVal = UndefValue::get(VecType);
401 
402   // The result vector
403   Value *VResult = UndefVal;
404   unsigned VectorWidth = VecType->getNumElements();
405 
406   // Shorten the way if the mask is a vector of constants.
407   bool IsConstMask = isa<ConstantVector>(Mask);
408 
409   if (IsConstMask) {
410     for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
411       if (cast<ConstantVector>(Mask)->getOperand(Idx)->isNullValue())
412         continue;
413       Value *Ptr = Builder.CreateExtractElement(Ptrs, Builder.getInt32(Idx),
414                                                 "Ptr" + Twine(Idx));
415       LoadInst *Load =
416           Builder.CreateAlignedLoad(Ptr, AlignVal, "Load" + Twine(Idx));
417       VResult = Builder.CreateInsertElement(
418           VResult, Load, Builder.getInt32(Idx), "Res" + Twine(Idx));
419     }
420     Value *NewI = Builder.CreateSelect(Mask, VResult, Src0);
421     CI->replaceAllUsesWith(NewI);
422     CI->eraseFromParent();
423     return;
424   }
425 
426   PHINode *Phi = nullptr;
427   Value *PrevPhi = UndefVal;
428 
429   for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
430     // Fill the "else" block, created in the previous iteration
431     //
432     //  %Mask1 = extractelement <16 x i1> %Mask, i32 1
433     //  %ToLoad1 = icmp eq i1 %Mask1, true
434     //  br i1 %ToLoad1, label %cond.load, label %else
435     //
436     if (Idx > 0) {
437       Phi = Builder.CreatePHI(VecType, 2, "res.phi.else");
438       Phi->addIncoming(VResult, CondBlock);
439       Phi->addIncoming(PrevPhi, PrevIfBlock);
440       PrevPhi = Phi;
441       VResult = Phi;
442     }
443 
444     Value *Predicate = Builder.CreateExtractElement(Mask, Builder.getInt32(Idx),
445                                                     "Mask" + Twine(Idx));
446     Value *Cmp = Builder.CreateICmp(ICmpInst::ICMP_EQ, Predicate,
447                                     ConstantInt::get(Predicate->getType(), 1),
448                                     "ToLoad" + Twine(Idx));
449 
450     // Create "cond" block
451     //
452     //  %EltAddr = getelementptr i32* %1, i32 0
453     //  %Elt = load i32* %EltAddr
454     //  VResult = insertelement <16 x i32> VResult, i32 %Elt, i32 Idx
455     //
456     CondBlock = IfBlock->splitBasicBlock(InsertPt, "cond.load");
457     Builder.SetInsertPoint(InsertPt);
458 
459     Value *Ptr = Builder.CreateExtractElement(Ptrs, Builder.getInt32(Idx),
460                                               "Ptr" + Twine(Idx));
461     LoadInst *Load =
462         Builder.CreateAlignedLoad(Ptr, AlignVal, "Load" + Twine(Idx));
463     VResult = Builder.CreateInsertElement(VResult, Load, Builder.getInt32(Idx),
464                                           "Res" + Twine(Idx));
465 
466     // Create "else" block, fill it in the next iteration
467     BasicBlock *NewIfBlock = CondBlock->splitBasicBlock(InsertPt, "else");
468     Builder.SetInsertPoint(InsertPt);
469     Instruction *OldBr = IfBlock->getTerminator();
470     BranchInst::Create(CondBlock, NewIfBlock, Cmp, OldBr);
471     OldBr->eraseFromParent();
472     PrevIfBlock = IfBlock;
473     IfBlock = NewIfBlock;
474   }
475 
476   Phi = Builder.CreatePHI(VecType, 2, "res.phi.select");
477   Phi->addIncoming(VResult, CondBlock);
478   Phi->addIncoming(PrevPhi, PrevIfBlock);
479   Value *NewI = Builder.CreateSelect(Mask, Phi, Src0);
480   CI->replaceAllUsesWith(NewI);
481   CI->eraseFromParent();
482 }
483 
484 // Translate a masked scatter intrinsic, like
// void @llvm.masked.scatter.v16i32(<16 x i32> %Src, <16 x i32*> %Ptrs, i32 4,
486 //                                  <16 x i1> %Mask)
487 // to a chain of basic blocks, that stores element one-by-one if
488 // the appropriate mask bit is set.
489 //
490 // % Ptrs = getelementptr i32, i32* %ptr, <16 x i64> %ind
491 // % Mask0 = extractelement <16 x i1> % Mask, i32 0
492 // % ToStore0 = icmp eq i1 % Mask0, true
493 // br i1 %ToStore0, label %cond.store, label %else
494 //
495 // cond.store:
496 // % Elt0 = extractelement <16 x i32> %Src, i32 0
497 // % Ptr0 = extractelement <16 x i32*> %Ptrs, i32 0
498 // store i32 %Elt0, i32* % Ptr0, align 4
499 // br label %else
500 //
501 // else:
502 // % Mask1 = extractelement <16 x i1> % Mask, i32 1
503 // % ToStore1 = icmp eq i1 % Mask1, true
504 // br i1 % ToStore1, label %cond.store1, label %else2
505 //
506 // cond.store1:
507 // % Elt1 = extractelement <16 x i32> %Src, i32 1
508 // % Ptr1 = extractelement <16 x i32*> %Ptrs, i32 1
509 // store i32 % Elt1, i32* % Ptr1, align 4
510 // br label %else2
511 //   . . .
scalarizeMaskedScatter(CallInst * CI)512 static void scalarizeMaskedScatter(CallInst *CI) {
513   Value *Src = CI->getArgOperand(0);
514   Value *Ptrs = CI->getArgOperand(1);
515   Value *Alignment = CI->getArgOperand(2);
516   Value *Mask = CI->getArgOperand(3);
517 
518   assert(isa<VectorType>(Src->getType()) &&
519          "Unexpected data type in masked scatter intrinsic");
520   assert(isa<VectorType>(Ptrs->getType()) &&
521          isa<PointerType>(Ptrs->getType()->getVectorElementType()) &&
522          "Vector of pointers is expected in masked scatter intrinsic");
523 
524   IRBuilder<> Builder(CI->getContext());
525   Instruction *InsertPt = CI;
526   BasicBlock *IfBlock = CI->getParent();
527   Builder.SetInsertPoint(InsertPt);
528   Builder.SetCurrentDebugLocation(CI->getDebugLoc());
529 
530   unsigned AlignVal = cast<ConstantInt>(Alignment)->getZExtValue();
531   unsigned VectorWidth = Src->getType()->getVectorNumElements();
532 
533   // Shorten the way if the mask is a vector of constants.
534   bool IsConstMask = isa<ConstantVector>(Mask);
535 
536   if (IsConstMask) {
537     for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
538       if (cast<ConstantVector>(Mask)->getOperand(Idx)->isNullValue())
539         continue;
540       Value *OneElt = Builder.CreateExtractElement(Src, Builder.getInt32(Idx),
541                                                    "Elt" + Twine(Idx));
542       Value *Ptr = Builder.CreateExtractElement(Ptrs, Builder.getInt32(Idx),
543                                                 "Ptr" + Twine(Idx));
544       Builder.CreateAlignedStore(OneElt, Ptr, AlignVal);
545     }
546     CI->eraseFromParent();
547     return;
548   }
549   for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
550     // Fill the "else" block, created in the previous iteration
551     //
552     //  % Mask1 = extractelement <16 x i1> % Mask, i32 Idx
553     //  % ToStore = icmp eq i1 % Mask1, true
554     //  br i1 % ToStore, label %cond.store, label %else
555     //
556     Value *Predicate = Builder.CreateExtractElement(Mask, Builder.getInt32(Idx),
557                                                     "Mask" + Twine(Idx));
558     Value *Cmp = Builder.CreateICmp(ICmpInst::ICMP_EQ, Predicate,
559                                     ConstantInt::get(Predicate->getType(), 1),
560                                     "ToStore" + Twine(Idx));
561 
562     // Create "cond" block
563     //
564     //  % Elt1 = extractelement <16 x i32> %Src, i32 1
565     //  % Ptr1 = extractelement <16 x i32*> %Ptrs, i32 1
566     //  %store i32 % Elt1, i32* % Ptr1
567     //
568     BasicBlock *CondBlock = IfBlock->splitBasicBlock(InsertPt, "cond.store");
569     Builder.SetInsertPoint(InsertPt);
570 
571     Value *OneElt = Builder.CreateExtractElement(Src, Builder.getInt32(Idx),
572                                                  "Elt" + Twine(Idx));
573     Value *Ptr = Builder.CreateExtractElement(Ptrs, Builder.getInt32(Idx),
574                                               "Ptr" + Twine(Idx));
575     Builder.CreateAlignedStore(OneElt, Ptr, AlignVal);
576 
577     // Create "else" block, fill it in the next iteration
578     BasicBlock *NewIfBlock = CondBlock->splitBasicBlock(InsertPt, "else");
579     Builder.SetInsertPoint(InsertPt);
580     Instruction *OldBr = IfBlock->getTerminator();
581     BranchInst::Create(CondBlock, NewIfBlock, Cmp, OldBr);
582     OldBr->eraseFromParent();
583     IfBlock = NewIfBlock;
584   }
585   CI->eraseFromParent();
586 }
587 
runOnFunction(Function & F)588 bool ScalarizeMaskedMemIntrin::runOnFunction(Function &F) {
589   bool EverMadeChange = false;
590 
591   TTI = &getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
592 
593   bool MadeChange = true;
594   while (MadeChange) {
595     MadeChange = false;
596     for (Function::iterator I = F.begin(); I != F.end();) {
597       BasicBlock *BB = &*I++;
598       bool ModifiedDTOnIteration = false;
599       MadeChange |= optimizeBlock(*BB, ModifiedDTOnIteration);
600 
601       // Restart BB iteration if the dominator tree of the Function was changed
602       if (ModifiedDTOnIteration)
603         break;
604     }
605 
606     EverMadeChange |= MadeChange;
607   }
608 
609   return EverMadeChange;
610 }
611 
optimizeBlock(BasicBlock & BB,bool & ModifiedDT)612 bool ScalarizeMaskedMemIntrin::optimizeBlock(BasicBlock &BB, bool &ModifiedDT) {
613   bool MadeChange = false;
614 
615   BasicBlock::iterator CurInstIterator = BB.begin();
616   while (CurInstIterator != BB.end()) {
617     if (CallInst *CI = dyn_cast<CallInst>(&*CurInstIterator++))
618       MadeChange |= optimizeCallInst(CI, ModifiedDT);
619     if (ModifiedDT)
620       return true;
621   }
622 
623   return MadeChange;
624 }
625 
optimizeCallInst(CallInst * CI,bool & ModifiedDT)626 bool ScalarizeMaskedMemIntrin::optimizeCallInst(CallInst *CI,
627                                                 bool &ModifiedDT) {
628   IntrinsicInst *II = dyn_cast<IntrinsicInst>(CI);
629   if (II) {
630     switch (II->getIntrinsicID()) {
631     default:
632       break;
633     case Intrinsic::masked_load:
634       // Scalarize unsupported vector masked load
635       if (!TTI->isLegalMaskedLoad(CI->getType())) {
636         scalarizeMaskedLoad(CI);
637         ModifiedDT = true;
638         return true;
639       }
640       return false;
641     case Intrinsic::masked_store:
642       if (!TTI->isLegalMaskedStore(CI->getArgOperand(0)->getType())) {
643         scalarizeMaskedStore(CI);
644         ModifiedDT = true;
645         return true;
646       }
647       return false;
648     case Intrinsic::masked_gather:
649       if (!TTI->isLegalMaskedGather(CI->getType())) {
650         scalarizeMaskedGather(CI);
651         ModifiedDT = true;
652         return true;
653       }
654       return false;
655     case Intrinsic::masked_scatter:
656       if (!TTI->isLegalMaskedScatter(CI->getArgOperand(0)->getType())) {
657         scalarizeMaskedScatter(CI);
658         ModifiedDT = true;
659         return true;
660       }
661       return false;
662     }
663   }
664 
665   return false;
666 }
667