/****************************************************************************
 * Copyright (C) 2014-2018 Intel Corporation.   All Rights Reserved.
 *
 * Permission is hereby granted, free of charge, to any person obtaining a
 * copy of this software and associated documentation files (the "Software"),
 * to deal in the Software without restriction, including without limitation
 * the rights to use, copy, modify, merge, publish, distribute, sublicense,
 * and/or sell copies of the Software, and to permit persons to whom the
 * Software is furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice (including the next
 * paragraph) shall be included in all copies or substantial portions of the
 * Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
 * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
 * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
 * IN THE SOFTWARE.
 *
 * @file lower_x86.cpp
 *
 * @brief llvm pass to lower meta code to x86
 *
 * Notes:
 *
 ******************************************************************************/

#include "jit_pch.hpp"
#include "passes.h"
#include "JitManager.h"

#include "common/simdlib.hpp"

#include <unordered_map>

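// Scalar scatter helper implemented on the C++ side; it is registered with LLVM's dynamic
// symbol table in the LowerX86 constructor below so that jitted code can call it directly.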
extern "C" void ScatterPS_256(uint8_t*, SIMD256::Integer, SIMD256::Float, uint8_t, uint32_t);

namespace llvm
{
    // forward declare the pass initializer
    void initializeLowerX86Pass(PassRegistry&);
} // namespace llvm

namespace SwrJit
{
    using namespace llvm;

    enum TargetArch
    {
        AVX    = 0,
        AVX2   = 1,
        AVX512 = 2
    };

    enum TargetWidth
    {
        W256       = 0,
        W512       = 1,
        NUM_WIDTHS = 2
    };

    struct LowerX86;

    typedef std::function<Instruction*(LowerX86*, TargetArch, TargetWidth, CallInst*)> EmuFunc;

    struct X86Intrinsic
    {
        IntrinsicID intrin[NUM_WIDTHS];
        EmuFunc     emuFunc;
    };

    // Map of intrinsics that haven't been moved to the new mechanism yet. If used, these get the
    // previous behavior of mapping directly to avx/avx2 intrinsics.
    using intrinsicMap_t = std::map<std::string, IntrinsicID>;
    static intrinsicMap_t& getIntrinsicMap() {
        static std::map<std::string, IntrinsicID> intrinsicMap = {
            {"meta.intrinsic.BEXTR_32", Intrinsic::x86_bmi_bextr_32},
            {"meta.intrinsic.VPSHUFB", Intrinsic::x86_avx2_pshuf_b},
            {"meta.intrinsic.VCVTPS2PH", Intrinsic::x86_vcvtps2ph_256},
            {"meta.intrinsic.VPTESTC", Intrinsic::x86_avx_ptestc_256},
            {"meta.intrinsic.VPTESTZ", Intrinsic::x86_avx_ptestz_256},
            {"meta.intrinsic.VPHADDD", Intrinsic::x86_avx2_phadd_d},
            {"meta.intrinsic.PDEP32", Intrinsic::x86_bmi_pdep_32},
            {"meta.intrinsic.RDTSC", Intrinsic::x86_rdtsc}
        };
        return intrinsicMap;
    }

    // Forward decls
    Instruction* NO_EMU(LowerX86* pThis, TargetArch arch, TargetWidth width, CallInst* pCallInst);
    Instruction*
    VPERM_EMU(LowerX86* pThis, TargetArch arch, TargetWidth width, CallInst* pCallInst);
    Instruction*
    VGATHER_EMU(LowerX86* pThis, TargetArch arch, TargetWidth width, CallInst* pCallInst);
    Instruction*
    VSCATTER_EMU(LowerX86* pThis, TargetArch arch, TargetWidth width, CallInst* pCallInst);
    Instruction*
    VROUND_EMU(LowerX86* pThis, TargetArch arch, TargetWidth width, CallInst* pCallInst);
    Instruction*
    VHSUB_EMU(LowerX86* pThis, TargetArch arch, TargetWidth width, CallInst* pCallInst);
    Instruction*
    VCONVERT_EMU(LowerX86* pThis, TargetArch arch, TargetWidth width, CallInst* pCallInst);

    Instruction* DOUBLE_EMU(LowerX86*     pThis,
                            TargetArch    arch,
                            TargetWidth   width,
                            CallInst*     pCallInst,
                            Intrinsic::ID intrin);

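    // Sentinel intrinsic ID marking table entries that are lowered by double pumping the
    // next-smaller SIMD width (see DOUBLE_EMU).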
    static Intrinsic::ID DOUBLE = (Intrinsic::ID)-1;

    using intrinsicMapAdvanced_t = std::vector<std::map<std::string, X86Intrinsic>>;

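    // Per-architecture lowering table: for each meta intrinsic, the native intrinsic to emit at
    // 256 and 512 wide (not_intrinsic if none) and the emulation routine used as a fallback.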
    static intrinsicMapAdvanced_t& getIntrinsicMapAdvanced()
    {
        // clang-format off
        static intrinsicMapAdvanced_t intrinsicMapAdvanced = {
            //                                256 wide                               512 wide
            {
                // AVX
                {"meta.intrinsic.VRCPPS",     {{Intrinsic::x86_avx_rcp_ps_256,       DOUBLE},                     NO_EMU}},
                {"meta.intrinsic.VPERMPS",    {{Intrinsic::not_intrinsic,            Intrinsic::not_intrinsic},   VPERM_EMU}},
                {"meta.intrinsic.VPERMD",     {{Intrinsic::not_intrinsic,            Intrinsic::not_intrinsic},   VPERM_EMU}},
                {"meta.intrinsic.VGATHERPD",  {{Intrinsic::not_intrinsic,            Intrinsic::not_intrinsic},   VGATHER_EMU}},
                {"meta.intrinsic.VGATHERPS",  {{Intrinsic::not_intrinsic,            Intrinsic::not_intrinsic},   VGATHER_EMU}},
                {"meta.intrinsic.VGATHERDD",  {{Intrinsic::not_intrinsic,            Intrinsic::not_intrinsic},   VGATHER_EMU}},
                {"meta.intrinsic.VSCATTERPS", {{Intrinsic::not_intrinsic,            Intrinsic::not_intrinsic},   VSCATTER_EMU}},
                {"meta.intrinsic.VCVTPD2PS",  {{Intrinsic::x86_avx_cvt_pd2_ps_256,   Intrinsic::not_intrinsic},   NO_EMU}},
                {"meta.intrinsic.VROUND",     {{Intrinsic::x86_avx_round_ps_256,     DOUBLE},                     NO_EMU}},
                {"meta.intrinsic.VHSUBPS",    {{Intrinsic::x86_avx_hsub_ps_256,      DOUBLE},                     NO_EMU}},
            },
            {
                // AVX2
                {"meta.intrinsic.VRCPPS",     {{Intrinsic::x86_avx_rcp_ps_256,       DOUBLE},                     NO_EMU}},
                {"meta.intrinsic.VPERMPS",    {{Intrinsic::x86_avx2_permps,          Intrinsic::not_intrinsic},   VPERM_EMU}},
                {"meta.intrinsic.VPERMD",     {{Intrinsic::x86_avx2_permd,           Intrinsic::not_intrinsic},   VPERM_EMU}},
                {"meta.intrinsic.VGATHERPD",  {{Intrinsic::not_intrinsic,            Intrinsic::not_intrinsic},   VGATHER_EMU}},
                {"meta.intrinsic.VGATHERPS",  {{Intrinsic::not_intrinsic,            Intrinsic::not_intrinsic},   VGATHER_EMU}},
                {"meta.intrinsic.VGATHERDD",  {{Intrinsic::not_intrinsic,            Intrinsic::not_intrinsic},   VGATHER_EMU}},
                {"meta.intrinsic.VSCATTERPS", {{Intrinsic::not_intrinsic,            Intrinsic::not_intrinsic},   VSCATTER_EMU}},
                {"meta.intrinsic.VCVTPD2PS",  {{Intrinsic::x86_avx_cvt_pd2_ps_256,   DOUBLE},                     NO_EMU}},
                {"meta.intrinsic.VROUND",     {{Intrinsic::x86_avx_round_ps_256,     DOUBLE},                     NO_EMU}},
                {"meta.intrinsic.VHSUBPS",    {{Intrinsic::x86_avx_hsub_ps_256,      DOUBLE},                     NO_EMU}},
            },
            {
                // AVX512
                {"meta.intrinsic.VRCPPS",     {{Intrinsic::x86_avx512_rcp14_ps_256,  Intrinsic::x86_avx512_rcp14_ps_512}, NO_EMU}},
    #if LLVM_VERSION_MAJOR < 7
                {"meta.intrinsic.VPERMPS",    {{Intrinsic::x86_avx512_mask_permvar_sf_256, Intrinsic::x86_avx512_mask_permvar_sf_512}, NO_EMU}},
                {"meta.intrinsic.VPERMD",     {{Intrinsic::x86_avx512_mask_permvar_si_256, Intrinsic::x86_avx512_mask_permvar_si_512}, NO_EMU}},
    #else
                {"meta.intrinsic.VPERMPS",    {{Intrinsic::not_intrinsic,            Intrinsic::not_intrinsic},   VPERM_EMU}},
                {"meta.intrinsic.VPERMD",     {{Intrinsic::not_intrinsic,            Intrinsic::not_intrinsic},   VPERM_EMU}},
    #endif
                {"meta.intrinsic.VGATHERPD",  {{Intrinsic::not_intrinsic,            Intrinsic::not_intrinsic},   VGATHER_EMU}},
                {"meta.intrinsic.VGATHERPS",  {{Intrinsic::not_intrinsic,            Intrinsic::not_intrinsic},   VGATHER_EMU}},
                {"meta.intrinsic.VGATHERDD",  {{Intrinsic::not_intrinsic,            Intrinsic::not_intrinsic},   VGATHER_EMU}},
                {"meta.intrinsic.VSCATTERPS", {{Intrinsic::not_intrinsic,            Intrinsic::not_intrinsic},   VSCATTER_EMU}},
    #if LLVM_VERSION_MAJOR < 7
                {"meta.intrinsic.VCVTPD2PS",  {{Intrinsic::x86_avx512_mask_cvtpd2ps_256, Intrinsic::x86_avx512_mask_cvtpd2ps_512}, NO_EMU}},
    #else
                {"meta.intrinsic.VCVTPD2PS",  {{Intrinsic::not_intrinsic,            Intrinsic::not_intrinsic},   VCONVERT_EMU}},
    #endif
                {"meta.intrinsic.VROUND",     {{Intrinsic::not_intrinsic,            Intrinsic::not_intrinsic},   VROUND_EMU}},
                {"meta.intrinsic.VHSUBPS",    {{Intrinsic::not_intrinsic,            Intrinsic::not_intrinsic},   VHSUB_EMU}}
            }};
        // clang-format on
        return intrinsicMapAdvanced;
    }

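    // Total bit width of a vector type, spanning the LLVM 11 API change.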
    static uint32_t getBitWidth(VectorType *pVTy)
    {
#if LLVM_VERSION_MAJOR >= 11
        return pVTy->getNumElements() * pVTy->getElementType()->getPrimitiveSizeInBits();
#else
        return pVTy->getBitWidth();
#endif
    }

    struct LowerX86 : public FunctionPass
    {
        LowerX86(Builder* b = nullptr) : FunctionPass(ID), B(b)
        {
            initializeLowerX86Pass(*PassRegistry::getPassRegistry());

            // Determine target arch
            if (JM()->mArch.AVX512F())
            {
                mTarget = AVX512;
            }
            else if (JM()->mArch.AVX2())
            {
                mTarget = AVX2;
            }
            else if (JM()->mArch.AVX())
            {
                mTarget = AVX;
            }
            else
            {
                SWR_ASSERT(false, "Unsupported AVX architecture.");
                mTarget = AVX;
            }

            // Set up the scatter function for 256 wide
            uint32_t curWidth = B->mVWidth;
            B->SetTargetWidth(8);
            std::vector<Type*> args = {
                B->mInt8PtrTy,   // pBase
                B->mSimdInt32Ty, // vIndices
                B->mSimdFP32Ty,  // vSrc
                B->mInt8Ty,      // mask
                B->mInt32Ty      // scale
            };

            FunctionType* pfnScatterTy = FunctionType::get(B->mVoidTy, args, false);
            mPfnScatter256             = cast<Function>(
#if LLVM_VERSION_MAJOR >= 9
                B->JM()->mpCurrentModule->getOrInsertFunction("ScatterPS_256", pfnScatterTy).getCallee());
#else
                B->JM()->mpCurrentModule->getOrInsertFunction("ScatterPS_256", pfnScatterTy));
#endif
            if (sys::DynamicLibrary::SearchForAddressOfSymbol("ScatterPS_256") == nullptr)
            {
                sys::DynamicLibrary::AddSymbol("ScatterPS_256", (void*)&ScatterPS_256);
            }

            B->SetTargetWidth(curWidth);
        }

        // Try to decipher the vector type of the instruction. This does not work properly
        // across all intrinsics, and will have to be rethought. Probably need something
        // similar to llvm's getDeclaration() utility to map a set of inputs to a specific typed
        // intrinsic.
        void GetRequestedWidthAndType(CallInst*       pCallInst,
                                      const StringRef intrinName,
                                      TargetWidth*    pWidth,
                                      Type**          pTy)
        {
            assert(pCallInst);
            Type* pVecTy = pCallInst->getType();

            // Check for intrinsic specific types
            // VCVTPD2PS type comes from src, not dst
            if (intrinName.equals("meta.intrinsic.VCVTPD2PS"))
            {
                Value* pOp = pCallInst->getOperand(0);
                assert(pOp);
                pVecTy = pOp->getType();
            }

            if (!pVecTy->isVectorTy())
            {
                for (auto& op : pCallInst->arg_operands())
                {
                    if (op.get()->getType()->isVectorTy())
                    {
                        pVecTy = op.get()->getType();
                        break;
                    }
                }
            }
            SWR_ASSERT(pVecTy->isVectorTy(), "Couldn't determine vector size");

            uint32_t width = getBitWidth(cast<VectorType>(pVecTy));
            switch (width)
            {
            case 256:
                *pWidth = W256;
                break;
            case 512:
                *pWidth = W512;
                break;
            default:
                SWR_ASSERT(false, "Unhandled vector width %d", width);
                *pWidth = W256;
            }

            *pTy = pVecTy->getScalarType();
        }

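        // Build a zero vector of the given element type, sized for the requested target width.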
        Value* GetZeroVec(TargetWidth width, Type* pTy)
        {
            uint32_t numElem = 0;
            switch (width)
            {
            case W256:
                numElem = 8;
                break;
            case W512:
                numElem = 16;
                break;
            default:
                SWR_ASSERT(false, "Unhandled vector width type %d\n", width);
            }

            return ConstantVector::getNullValue(getVectorType(pTy, numElem));
        }

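        // Build an all-ones execution mask constant (8 or 16 lanes) for the requested width.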
        Value* GetMask(TargetWidth width)
        {
            Value* mask;
            switch (width)
            {
            case W256:
                mask = B->C((uint8_t)-1);
                break;
            case W512:
                mask = B->C((uint16_t)-1);
                break;
            default:
                SWR_ASSERT(false, "Unhandled vector width type %d\n", width);
            }
            return mask;
        }

        // Convert <N x i1> mask to <N x i32> x86 mask
        Value* VectorMask(Value* vi1Mask)
        {
#if LLVM_VERSION_MAJOR >= 11
            uint32_t numElem = cast<VectorType>(vi1Mask->getType())->getNumElements();
#else
            uint32_t numElem = vi1Mask->getType()->getVectorNumElements();
#endif
            return B->S_EXT(vi1Mask, getVectorType(B->mInt32Ty, numElem));
        }

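        // Lower a meta intrinsic through the advanced table: emit the native intrinsic when one
        // exists, double pump when the entry is flagged DOUBLE, and otherwise call the
        // per-intrinsic emulation routine.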
        Instruction* ProcessIntrinsicAdvanced(CallInst* pCallInst)
        {
            Function*   pFunc = pCallInst->getCalledFunction();
            assert(pFunc);

            auto&       intrinsic = getIntrinsicMapAdvanced()[mTarget][pFunc->getName().str()];
            TargetWidth vecWidth;
            Type*       pElemTy;
            GetRequestedWidthAndType(pCallInst, pFunc->getName(), &vecWidth, &pElemTy);

            // Check if there is a native intrinsic for this instruction
            IntrinsicID id = intrinsic.intrin[vecWidth];
            if (id == DOUBLE)
            {
                // Double pump the next smaller SIMD intrinsic
                SWR_ASSERT(vecWidth != 0, "Cannot double pump smallest SIMD width.");
                Intrinsic::ID id2 = intrinsic.intrin[vecWidth - 1];
                SWR_ASSERT(id2 != Intrinsic::not_intrinsic,
                           "Cannot find intrinsic to double pump.");
                return DOUBLE_EMU(this, mTarget, vecWidth, pCallInst, id2);
            }
            else if (id != Intrinsic::not_intrinsic)
            {
                Function* pIntrin = Intrinsic::getDeclaration(B->JM()->mpCurrentModule, id);
                SmallVector<Value*, 8> args;
                for (auto& arg : pCallInst->arg_operands())
                {
                    args.push_back(arg.get());
                }

                // If AVX512, all instructions add a src operand and mask. We'll pass in a zero
                // src and a full mask for now, assuming the intrinsics are consistent and place
                // the src operand and mask last in the argument list.
                if (mTarget == AVX512)
                {
                    if (pFunc->getName().equals("meta.intrinsic.VCVTPD2PS"))
                    {
                        args.push_back(GetZeroVec(W256, pCallInst->getType()->getScalarType()));
                        args.push_back(GetMask(W256));
                        // for AVX512 VCVTPD2PS, we also have to add rounding mode
                        args.push_back(B->C(_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));
                    }
                    else
                    {
                        args.push_back(GetZeroVec(vecWidth, pElemTy));
                        args.push_back(GetMask(vecWidth));
                    }
                }

                return B->CALLA(pIntrin, args);
            }
            else
            {
                // No native intrinsic, call emulation function
                return intrinsic.emuFunc(this, mTarget, vecWidth, pCallInst);
            }

            SWR_ASSERT(false);
            return nullptr;
        }

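        // Lower a single meta intrinsic call, preferring the per-architecture advanced table and
        // falling back to the legacy direct AVX/AVX2 mapping.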
        Instruction* ProcessIntrinsic(CallInst* pCallInst)
        {
            Function* pFunc = pCallInst->getCalledFunction();
            assert(pFunc);

            // Forward to the advanced support if found
            if (getIntrinsicMapAdvanced()[mTarget].find(pFunc->getName().str()) != getIntrinsicMapAdvanced()[mTarget].end())
            {
                return ProcessIntrinsicAdvanced(pCallInst);
            }

            SWR_ASSERT(getIntrinsicMap().find(pFunc->getName().str()) != getIntrinsicMap().end(),
                       "Unimplemented intrinsic %s.",
                       pFunc->getName().str().c_str());

            Intrinsic::ID x86Intrinsic = getIntrinsicMap()[pFunc->getName().str()];
            Function*     pX86IntrinFunc =
                Intrinsic::getDeclaration(B->JM()->mpCurrentModule, x86Intrinsic);

            SmallVector<Value*, 8> args;
            for (auto& arg : pCallInst->arg_operands())
            {
                args.push_back(arg.get());
            }
            return B->CALLA(pX86IntrinFunc, args);
        }

        //////////////////////////////////////////////////////////////////////////
        /// @brief LLVM function pass run method.
        /// @param F - The function we're working on with this pass.
        virtual bool runOnFunction(Function& F)
        {
            std::vector<Instruction*> toRemove;
            std::vector<BasicBlock*>  bbs;

            // Make temp copy of the basic blocks and instructions, as the intrinsic
            // replacement code might invalidate the iterators
            for (auto& b : F.getBasicBlockList())
            {
                bbs.push_back(&b);
            }

            for (auto* BB : bbs)
            {
                std::vector<Instruction*> insts;
                for (auto& i : BB->getInstList())
                {
                    insts.push_back(&i);
                }

                for (auto* I : insts)
                {
                    if (CallInst* pCallInst = dyn_cast<CallInst>(I))
                    {
                        Function* pFunc = pCallInst->getCalledFunction();
                        if (pFunc)
                        {
                            if (pFunc->getName().startswith("meta.intrinsic"))
                            {
                                B->IRB()->SetInsertPoint(I);
                                Instruction* pReplace = ProcessIntrinsic(pCallInst);
                                toRemove.push_back(pCallInst);
                                if (pReplace)
                                {
                                    pCallInst->replaceAllUsesWith(pReplace);
                                }
                            }
                        }
                    }
                }
            }

            for (auto* pInst : toRemove)
            {
                pInst->eraseFromParent();
            }

            JitManager::DumpToFile(&F, "lowerx86");

            return true;
        }

        virtual void getAnalysisUsage(AnalysisUsage& AU) const {}

        JitManager* JM() { return B->JM(); }
        Builder*    B;
        TargetArch  mTarget;
        Function*   mPfnScatter256;

        static char ID; ///< Needed by LLVM to generate ID for FunctionPass.
    };

    char LowerX86::ID = 0; // LLVM uses address of ID as the actual ID.

    FunctionPass* createLowerX86Pass(Builder* b) { return new LowerX86(b); }

    Instruction* NO_EMU(LowerX86* pThis, TargetArch arch, TargetWidth width, CallInst* pCallInst)
    {
        SWR_ASSERT(false, "Unimplemented intrinsic emulation.");
        return nullptr;
    }

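    // Emulate VPERMPS/VPERMD. Constant indices lower to a single shufflevector; variable indices
    // fall back to a per-lane extract/insert loop.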
    Instruction* VPERM_EMU(LowerX86* pThis, TargetArch arch, TargetWidth width, CallInst* pCallInst)
    {
        // Only need vperm emulation for AVX
        SWR_ASSERT(arch == AVX);

        Builder* B         = pThis->B;
        auto     v32A      = pCallInst->getArgOperand(0);
        auto     vi32Index = pCallInst->getArgOperand(1);

        Value* v32Result;
        if (isa<Constant>(vi32Index))
        {
            // Can use llvm shuffle vector directly with constant shuffle indices
            v32Result = B->VSHUFFLE(v32A, v32A, vi32Index);
        }
        else
        {
            v32Result = UndefValue::get(v32A->getType());
#if LLVM_VERSION_MAJOR >= 11
            uint32_t numElem = cast<VectorType>(v32A->getType())->getNumElements();
#else
            uint32_t numElem = v32A->getType()->getVectorNumElements();
#endif
            for (uint32_t l = 0; l < numElem; ++l)
            {
                auto i32Index = B->VEXTRACT(vi32Index, B->C(l));
                auto val      = B->VEXTRACT(v32A, i32Index);
                v32Result     = B->VINSERT(v32Result, val, B->C(l));
            }
        }
        return cast<Instruction>(v32Result);
    }

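    // Emulate gathers: a scalar load loop on AVX, the AVX2 gather intrinsics (double pumped for
    // 512 wide) on AVX2 and 256-wide AVX512, and the native AVX512 gather intrinsics otherwise.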
    Instruction*
    VGATHER_EMU(LowerX86* pThis, TargetArch arch, TargetWidth width, CallInst* pCallInst)
    {
        Builder* B           = pThis->B;
        auto     vSrc        = pCallInst->getArgOperand(0);
        auto     pBase       = pCallInst->getArgOperand(1);
        auto     vi32Indices = pCallInst->getArgOperand(2);
        auto     vi1Mask     = pCallInst->getArgOperand(3);
        auto     i8Scale     = pCallInst->getArgOperand(4);

        pBase              = B->POINTER_CAST(pBase, PointerType::get(B->mInt8Ty, 0));
#if LLVM_VERSION_MAJOR >= 11
        VectorType* pVectorType = cast<VectorType>(vSrc->getType());
        uint32_t    numElem     = pVectorType->getNumElements();
        auto        srcTy       = pVectorType->getElementType();
#else
        uint32_t numElem   = vSrc->getType()->getVectorNumElements();
        auto     srcTy     = vSrc->getType()->getVectorElementType();
#endif
        auto     i32Scale  = B->Z_EXT(i8Scale, B->mInt32Ty);

        Value*   v32Gather = nullptr;
        if (arch == AVX)
        {
            // Full emulation for AVX
            // Store source on stack to provide a valid address to load from inactive lanes
            auto pStack = B->STACKSAVE();
            auto pTmp   = B->ALLOCA(vSrc->getType());
            B->STORE(vSrc, pTmp);

            v32Gather        = UndefValue::get(vSrc->getType());
#if LLVM_VERSION_MAJOR <= 10
            auto vi32Scale   = ConstantVector::getSplat(numElem, cast<ConstantInt>(i32Scale));
#elif LLVM_VERSION_MAJOR == 11
            auto vi32Scale   = ConstantVector::getSplat(ElementCount(numElem, false), cast<ConstantInt>(i32Scale));
#else
            auto vi32Scale   = ConstantVector::getSplat(ElementCount::get(numElem, false), cast<ConstantInt>(i32Scale));
#endif
            auto vi32Offsets = B->MUL(vi32Indices, vi32Scale);

            for (uint32_t i = 0; i < numElem; ++i)
            {
                auto i32Offset          = B->VEXTRACT(vi32Offsets, B->C(i));
                auto pLoadAddress       = B->GEP(pBase, i32Offset);
                pLoadAddress            = B->BITCAST(pLoadAddress, PointerType::get(srcTy, 0));
                auto pMaskedLoadAddress = B->GEP(pTmp, {0, i});
                auto i1Mask             = B->VEXTRACT(vi1Mask, B->C(i));
                auto pValidAddress      = B->SELECT(i1Mask, pLoadAddress, pMaskedLoadAddress);
                auto val                = B->LOAD(pValidAddress);
                v32Gather               = B->VINSERT(v32Gather, val, B->C(i));
            }

            B->STACKRESTORE(pStack);
        }
        else if (arch == AVX2 || (arch == AVX512 && width == W256))
        {
            Function* pX86IntrinFunc = nullptr;
            if (srcTy == B->mFP32Ty)
            {
                pX86IntrinFunc = Intrinsic::getDeclaration(B->JM()->mpCurrentModule,
                                                           Intrinsic::x86_avx2_gather_d_ps_256);
            }
            else if (srcTy == B->mInt32Ty)
            {
                pX86IntrinFunc = Intrinsic::getDeclaration(B->JM()->mpCurrentModule,
                                                           Intrinsic::x86_avx2_gather_d_d_256);
            }
            else if (srcTy == B->mDoubleTy)
            {
                pX86IntrinFunc = Intrinsic::getDeclaration(B->JM()->mpCurrentModule,
                                                           Intrinsic::x86_avx2_gather_d_q_256);
            }
            else
            {
                SWR_ASSERT(false, "Unsupported vector element type for gather.");
            }

            if (width == W256)
            {
                auto v32Mask = B->BITCAST(pThis->VectorMask(vi1Mask), vSrc->getType());
                v32Gather = B->CALL(pX86IntrinFunc, {vSrc, pBase, vi32Indices, v32Mask, i8Scale});
            }
            else if (width == W512)
            {
                // Double pump 4-wide for 64bit elements
#if LLVM_VERSION_MAJOR >= 11
                if (cast<VectorType>(vSrc->getType())->getElementType() == B->mDoubleTy)
#else
                if (vSrc->getType()->getVectorElementType() == B->mDoubleTy)
#endif
                {
                    auto v64Mask = pThis->VectorMask(vi1Mask);
#if LLVM_VERSION_MAJOR >= 11
                    uint32_t numElem = cast<VectorType>(v64Mask->getType())->getNumElements();
#else
                    uint32_t numElem = v64Mask->getType()->getVectorNumElements();
#endif
                    v64Mask = B->S_EXT(v64Mask, getVectorType(B->mInt64Ty, numElem));
                    v64Mask = B->BITCAST(v64Mask, vSrc->getType());

                    Value* src0 = B->VSHUFFLE(vSrc, vSrc, B->C({0, 1, 2, 3}));
                    Value* src1 = B->VSHUFFLE(vSrc, vSrc, B->C({4, 5, 6, 7}));

                    Value* indices0 = B->VSHUFFLE(vi32Indices, vi32Indices, B->C({0, 1, 2, 3}));
                    Value* indices1 = B->VSHUFFLE(vi32Indices, vi32Indices, B->C({4, 5, 6, 7}));

                    Value* mask0 = B->VSHUFFLE(v64Mask, v64Mask, B->C({0, 1, 2, 3}));
                    Value* mask1 = B->VSHUFFLE(v64Mask, v64Mask, B->C({4, 5, 6, 7}));

#if LLVM_VERSION_MAJOR >= 11
                    uint32_t numElemSrc0  = cast<VectorType>(src0->getType())->getNumElements();
                    uint32_t numElemMask0 = cast<VectorType>(mask0->getType())->getNumElements();
                    uint32_t numElemSrc1  = cast<VectorType>(src1->getType())->getNumElements();
                    uint32_t numElemMask1 = cast<VectorType>(mask1->getType())->getNumElements();
#else
                    uint32_t numElemSrc0  = src0->getType()->getVectorNumElements();
                    uint32_t numElemMask0 = mask0->getType()->getVectorNumElements();
                    uint32_t numElemSrc1  = src1->getType()->getVectorNumElements();
                    uint32_t numElemMask1 = mask1->getType()->getVectorNumElements();
#endif
                    src0 = B->BITCAST(src0, getVectorType(B->mInt64Ty, numElemSrc0));
                    mask0 = B->BITCAST(mask0, getVectorType(B->mInt64Ty, numElemMask0));
                    Value* gather0 =
                        B->CALL(pX86IntrinFunc, {src0, pBase, indices0, mask0, i8Scale});
                    src1 = B->BITCAST(src1, getVectorType(B->mInt64Ty, numElemSrc1));
                    mask1 = B->BITCAST(mask1, getVectorType(B->mInt64Ty, numElemMask1));
                    Value* gather1 =
                        B->CALL(pX86IntrinFunc, {src1, pBase, indices1, mask1, i8Scale});
                    v32Gather = B->VSHUFFLE(gather0, gather1, B->C({0, 1, 2, 3, 4, 5, 6, 7}));
                    v32Gather = B->BITCAST(v32Gather, vSrc->getType());
                }
                else
                {
                    // Double pump 8-wide for 32bit elements
                    auto v32Mask = pThis->VectorMask(vi1Mask);
                    v32Mask      = B->BITCAST(v32Mask, vSrc->getType());
                    Value* src0  = B->EXTRACT_16(vSrc, 0);
                    Value* src1  = B->EXTRACT_16(vSrc, 1);

                    Value* indices0 = B->EXTRACT_16(vi32Indices, 0);
                    Value* indices1 = B->EXTRACT_16(vi32Indices, 1);

                    Value* mask0 = B->EXTRACT_16(v32Mask, 0);
                    Value* mask1 = B->EXTRACT_16(v32Mask, 1);

                    Value* gather0 =
                        B->CALL(pX86IntrinFunc, {src0, pBase, indices0, mask0, i8Scale});
                    Value* gather1 =
                        B->CALL(pX86IntrinFunc, {src1, pBase, indices1, mask1, i8Scale});

                    v32Gather = B->JOIN_16(gather0, gather1);
                }
            }
        }
        else if (arch == AVX512)
        {
            Value*    iMask = nullptr;
            Function* pX86IntrinFunc = nullptr;
            if (srcTy == B->mFP32Ty)
            {
                pX86IntrinFunc = Intrinsic::getDeclaration(B->JM()->mpCurrentModule,
                                                           Intrinsic::x86_avx512_gather_dps_512);
                iMask          = B->BITCAST(vi1Mask, B->mInt16Ty);
            }
            else if (srcTy == B->mInt32Ty)
            {
                pX86IntrinFunc = Intrinsic::getDeclaration(B->JM()->mpCurrentModule,
                                                           Intrinsic::x86_avx512_gather_dpi_512);
                iMask          = B->BITCAST(vi1Mask, B->mInt16Ty);
            }
            else if (srcTy == B->mDoubleTy)
            {
                pX86IntrinFunc = Intrinsic::getDeclaration(B->JM()->mpCurrentModule,
                                                           Intrinsic::x86_avx512_gather_dpd_512);
                iMask          = B->BITCAST(vi1Mask, B->mInt8Ty);
            }
            else
            {
                SWR_ASSERT(false, "Unsupported vector element type for gather.");
            }

            auto i32Scale = B->Z_EXT(i8Scale, B->mInt32Ty);
            v32Gather     = B->CALL(pX86IntrinFunc, {vSrc, pBase, vi32Indices, iMask, i32Scale});
        }

        return cast<Instruction>(v32Gather);
    }
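
    // Emulate scatters: pre-AVX512 targets call out to the ScatterPS_256 helper (a 512 wide
    // scatter is split into two 256 wide calls), while AVX512 uses the native scatter intrinsics.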
    Instruction*
    VSCATTER_EMU(LowerX86* pThis, TargetArch arch, TargetWidth width, CallInst* pCallInst)
    {
        Builder* B           = pThis->B;
        auto     pBase       = pCallInst->getArgOperand(0);
        auto     vi1Mask     = pCallInst->getArgOperand(1);
        auto     vi32Indices = pCallInst->getArgOperand(2);
        auto     v32Src      = pCallInst->getArgOperand(3);
        auto     i32Scale    = pCallInst->getArgOperand(4);

        if (arch != AVX512)
        {
            // Call into C function to do the scatter. This has significantly better compile perf
            // compared to jitting scatter loops for every scatter
            if (width == W256)
            {
                auto mask = B->BITCAST(vi1Mask, B->mInt8Ty);
                B->CALL(pThis->mPfnScatter256, {pBase, vi32Indices, v32Src, mask, i32Scale});
            }
            else
            {
                // Need to break up 512 wide scatter to two 256 wide
                auto maskLo = B->VSHUFFLE(vi1Mask, vi1Mask, B->C({0, 1, 2, 3, 4, 5, 6, 7}));
                auto indicesLo =
                    B->VSHUFFLE(vi32Indices, vi32Indices, B->C({0, 1, 2, 3, 4, 5, 6, 7}));
                auto srcLo = B->VSHUFFLE(v32Src, v32Src, B->C({0, 1, 2, 3, 4, 5, 6, 7}));

                auto mask = B->BITCAST(maskLo, B->mInt8Ty);
                B->CALL(pThis->mPfnScatter256, {pBase, indicesLo, srcLo, mask, i32Scale});

                auto maskHi = B->VSHUFFLE(vi1Mask, vi1Mask, B->C({8, 9, 10, 11, 12, 13, 14, 15}));
                auto indicesHi =
                    B->VSHUFFLE(vi32Indices, vi32Indices, B->C({8, 9, 10, 11, 12, 13, 14, 15}));
                auto srcHi = B->VSHUFFLE(v32Src, v32Src, B->C({8, 9, 10, 11, 12, 13, 14, 15}));

                mask = B->BITCAST(maskHi, B->mInt8Ty);
                B->CALL(pThis->mPfnScatter256, {pBase, indicesHi, srcHi, mask, i32Scale});
            }
            return nullptr;
        }

        Value*    iMask;
        Function* pX86IntrinFunc;
        if (width == W256)
        {
            // No direct intrinsic supported in llvm to scatter 8 elem with 32bit indices, but we
            // can use the scatter of 8 elements with 64bit indices
            pX86IntrinFunc = Intrinsic::getDeclaration(B->JM()->mpCurrentModule,
                                                       Intrinsic::x86_avx512_scatter_qps_512);

            auto vi32IndicesExt = B->Z_EXT(vi32Indices, B->mSimdInt64Ty);
            iMask               = B->BITCAST(vi1Mask, B->mInt8Ty);
            B->CALL(pX86IntrinFunc, {pBase, iMask, vi32IndicesExt, v32Src, i32Scale});
        }
        else if (width == W512)
        {
            pX86IntrinFunc = Intrinsic::getDeclaration(B->JM()->mpCurrentModule,
                                                       Intrinsic::x86_avx512_scatter_dps_512);
            iMask          = B->BITCAST(vi1Mask, B->mInt16Ty);
            B->CALL(pX86IntrinFunc, {pBase, iMask, vi32Indices, v32Src, i32Scale});
        }
        return nullptr;
    }

    // No support for vroundps in avx512 (it is available in kncni), so emulate with avx
    // instructions
    Instruction*
    VROUND_EMU(LowerX86* pThis, TargetArch arch, TargetWidth width, CallInst* pCallInst)
    {
        SWR_ASSERT(arch == AVX512);

        auto B       = pThis->B;
        auto vf32Src = pCallInst->getOperand(0);
        assert(vf32Src);
        auto i8Round = pCallInst->getOperand(1);
        assert(i8Round);
        auto pfnFunc =
            Intrinsic::getDeclaration(B->JM()->mpCurrentModule, Intrinsic::x86_avx_round_ps_256);

        if (width == W256)
        {
            return cast<Instruction>(B->CALL2(pfnFunc, vf32Src, i8Round));
        }
        else if (width == W512)
        {
            auto v8f32SrcLo = B->EXTRACT_16(vf32Src, 0);
            auto v8f32SrcHi = B->EXTRACT_16(vf32Src, 1);

            auto v8f32ResLo = B->CALL2(pfnFunc, v8f32SrcLo, i8Round);
            auto v8f32ResHi = B->CALL2(pfnFunc, v8f32SrcHi, i8Round);

            return cast<Instruction>(B->JOIN_16(v8f32ResLo, v8f32ResHi));
        }
        else
        {
            SWR_ASSERT(false, "Unimplemented vector width.");
        }

        return nullptr;
    }

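    // Emulate VCVTPD2PS on AVX512 for LLVM versions where the table has no direct mapping.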
    Instruction*
    VCONVERT_EMU(LowerX86* pThis, TargetArch arch, TargetWidth width, CallInst* pCallInst)
    {
        SWR_ASSERT(arch == AVX512);

        auto B       = pThis->B;
        auto vf32Src = pCallInst->getOperand(0);

        if (width == W256)
        {
            auto vf32SrcRound = Intrinsic::getDeclaration(B->JM()->mpCurrentModule,
                                                          Intrinsic::x86_avx_round_ps_256);
            return cast<Instruction>(B->FP_TRUNC(vf32SrcRound, B->mFP32Ty));
        }
        else if (width == W512)
        {
            // 512 can use intrinsic
            auto pfnFunc = Intrinsic::getDeclaration(B->JM()->mpCurrentModule,
                                                     Intrinsic::x86_avx512_mask_cvtpd2ps_512);
            return cast<Instruction>(B->CALL(pfnFunc, vf32Src));
        }
        else
        {
            SWR_ASSERT(false, "Unimplemented vector width.");
        }

        return nullptr;
    }

    // No support for hsub in AVX512
    Instruction* VHSUB_EMU(LowerX86* pThis, TargetArch arch, TargetWidth width, CallInst* pCallInst)
    {
        SWR_ASSERT(arch == AVX512);

        auto B    = pThis->B;
        auto src0 = pCallInst->getOperand(0);
        auto src1 = pCallInst->getOperand(1);

        // 256b hsub can just use avx intrinsic
        if (width == W256)
        {
            auto pX86IntrinFunc =
                Intrinsic::getDeclaration(B->JM()->mpCurrentModule, Intrinsic::x86_avx_hsub_ps_256);
            return cast<Instruction>(B->CALL2(pX86IntrinFunc, src0, src1));
        }
        else if (width == W512)
        {
            // 512b hsub can be accomplished with shuf/sub combo
            auto minuend    = B->VSHUFFLE(src0, src1, B->C({0, 2, 8, 10, 4, 6, 12, 14}));
            auto subtrahend = B->VSHUFFLE(src0, src1, B->C({1, 3, 9, 11, 5, 7, 13, 15}));
            return cast<Instruction>(B->SUB(minuend, subtrahend));
        }
        else
        {
            SWR_ASSERT(false, "Unimplemented vector width.");
            return nullptr;
        }
    }

    // Double pump the input using the given 256-wide intrinsic. This blindly extracts the lower
    // and upper 256 bits of each vector argument, calls the 256-wide intrinsic on each half, then
    // merges the results back to 512 wide.
    Instruction* DOUBLE_EMU(LowerX86*     pThis,
                            TargetArch    arch,
                            TargetWidth   width,
                            CallInst*     pCallInst,
                            Intrinsic::ID intrin)
    {
        auto B = pThis->B;
        SWR_ASSERT(width == W512);
        Value*    result[2];
        Function* pX86IntrinFunc = Intrinsic::getDeclaration(B->JM()->mpCurrentModule, intrin);
        for (uint32_t i = 0; i < 2; ++i)
        {
            SmallVector<Value*, 8> args;
            for (auto& arg : pCallInst->arg_operands())
            {
                auto argType = arg.get()->getType();
                if (argType->isVectorTy())
                {
#if LLVM_VERSION_MAJOR >= 11
                    uint32_t vecWidth  = cast<VectorType>(argType)->getNumElements();
                    auto     elemTy    = cast<VectorType>(argType)->getElementType();
#else
                    uint32_t vecWidth  = argType->getVectorNumElements();
                    auto     elemTy    = argType->getVectorElementType();
#endif
                    Value*   lanes     = B->CInc<int>(i * vecWidth / 2, vecWidth / 2);
                    Value*   argToPush = B->VSHUFFLE(arg.get(), B->VUNDEF(elemTy, vecWidth), lanes);
                    args.push_back(argToPush);
                }
                else
                {
                    args.push_back(arg.get());
                }
            }
            result[i] = B->CALLA(pX86IntrinFunc, args);
        }
        uint32_t vecWidth;
        if (result[0]->getType()->isVectorTy())
        {
            assert(result[1]->getType()->isVectorTy());
#if LLVM_VERSION_MAJOR >= 11
            vecWidth = cast<VectorType>(result[0]->getType())->getNumElements() +
                       cast<VectorType>(result[1]->getType())->getNumElements();
#else
            vecWidth = result[0]->getType()->getVectorNumElements() +
                       result[1]->getType()->getVectorNumElements();
#endif
        }
        else
        {
            vecWidth = 2;
        }
        Value* lanes = B->CInc<int>(0, vecWidth);
        return cast<Instruction>(B->VSHUFFLE(result[0], result[1], lanes));
    }

} // namespace SwrJit

using namespace SwrJit;

INITIALIZE_PASS_BEGIN(LowerX86, "LowerX86", "LowerX86", false, false)
INITIALIZE_PASS_END(LowerX86, "LowerX86", "LowerX86", false, false)