/****************************************************************************
 * Copyright (C) 2014-2018 Intel Corporation. All Rights Reserved.
 *
 * Permission is hereby granted, free of charge, to any person obtaining a
 * copy of this software and associated documentation files (the "Software"),
 * to deal in the Software without restriction, including without limitation
 * the rights to use, copy, modify, merge, publish, distribute, sublicense,
 * and/or sell copies of the Software, and to permit persons to whom the
 * Software is furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice (including the next
 * paragraph) shall be included in all copies or substantial portions of the
 * Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
 * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
 * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
 * IN THE SOFTWARE.
 *
 * @file lower_x86.cpp
 *
 * @brief llvm pass to lower meta code to x86
 *
 * Notes:
 *
 ******************************************************************************/

#include "jit_pch.hpp"
#include "passes.h"
#include "JitManager.h"

#include "common/simdlib.hpp"

#include <unordered_map>

extern "C" void ScatterPS_256(uint8_t*, SIMD256::Integer, SIMD256::Float, uint8_t, uint32_t);

namespace llvm
{
    // forward declare the initializer
    void initializeLowerX86Pass(PassRegistry&);
} // namespace llvm

namespace SwrJit
{
    using namespace llvm;

    enum TargetArch
    {
        AVX    = 0,
        AVX2   = 1,
        AVX512 = 2
    };

    enum TargetWidth
    {
        W256       = 0,
        W512       = 1,
        NUM_WIDTHS = 2
    };

    struct LowerX86;

    typedef std::function<Instruction*(LowerX86*, TargetArch, TargetWidth, CallInst*)> EmuFunc;

    struct X86Intrinsic
    {
        IntrinsicID intrin[NUM_WIDTHS];
        EmuFunc     emuFunc;
    };
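
    // For orientation (illustrative example, not generated by this file): a meta intrinsic
    // arrives in the IR as an ordinary call to an undefined function, e.g.
    //     %r = call <8 x float> @meta.intrinsic.VRCPPS(<8 x float> %a)
    // This pass rewrites each such call either to a native x86 intrinsic for the selected
    // target architecture/width, or to an emulation sequence built from simpler operations.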

    // Map of intrinsics that haven't been moved to the new mechanism yet. If used, these get the
    // previous behavior of mapping directly to the AVX/AVX2 intrinsics.
    using intrinsicMap_t = std::map<std::string, IntrinsicID>;
    static intrinsicMap_t& getIntrinsicMap()
    {
        static std::map<std::string, IntrinsicID> intrinsicMap = {
            {"meta.intrinsic.BEXTR_32", Intrinsic::x86_bmi_bextr_32},
            {"meta.intrinsic.VPSHUFB", Intrinsic::x86_avx2_pshuf_b},
            {"meta.intrinsic.VCVTPS2PH", Intrinsic::x86_vcvtps2ph_256},
            {"meta.intrinsic.VPTESTC", Intrinsic::x86_avx_ptestc_256},
            {"meta.intrinsic.VPTESTZ", Intrinsic::x86_avx_ptestz_256},
            {"meta.intrinsic.VPHADDD", Intrinsic::x86_avx2_phadd_d},
            {"meta.intrinsic.PDEP32", Intrinsic::x86_bmi_pdep_32},
            {"meta.intrinsic.RDTSC", Intrinsic::x86_rdtsc}
        };
        return intrinsicMap;
    }
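
    // Sketch of the legacy path (see ProcessIntrinsic below): a call to one of the names above,
    // e.g. meta.intrinsic.VPSHUFB, is replaced by a call to the mapped LLVM intrinsic
    // (llvm.x86.avx2.pshuf.b) with the original argument list forwarded as-is; no per-width or
    // per-architecture selection is performed for these.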

    // Forward decls
    Instruction* NO_EMU(LowerX86* pThis, TargetArch arch, TargetWidth width, CallInst* pCallInst);
    Instruction*
    VPERM_EMU(LowerX86* pThis, TargetArch arch, TargetWidth width, CallInst* pCallInst);
    Instruction*
    VGATHER_EMU(LowerX86* pThis, TargetArch arch, TargetWidth width, CallInst* pCallInst);
    Instruction*
    VSCATTER_EMU(LowerX86* pThis, TargetArch arch, TargetWidth width, CallInst* pCallInst);
    Instruction*
    VROUND_EMU(LowerX86* pThis, TargetArch arch, TargetWidth width, CallInst* pCallInst);
    Instruction*
    VHSUB_EMU(LowerX86* pThis, TargetArch arch, TargetWidth width, CallInst* pCallInst);
    Instruction*
    VCONVERT_EMU(LowerX86* pThis, TargetArch arch, TargetWidth width, CallInst* pCallInst);

    Instruction* DOUBLE_EMU(LowerX86*     pThis,
                            TargetArch    arch,
                            TargetWidth   width,
                            CallInst*     pCallInst,
                            Intrinsic::ID intrin);

    static Intrinsic::ID DOUBLE = (Intrinsic::ID)-1;

    using intrinsicMapAdvanced_t = std::vector<std::map<std::string, X86Intrinsic>>;

    static intrinsicMapAdvanced_t& getIntrinsicMapAdvanced()
    {
        // clang-format off
        static intrinsicMapAdvanced_t intrinsicMapAdvanced = {
            //                            256 wide                            512 wide
            {
                // AVX
                {"meta.intrinsic.VRCPPS",     {{Intrinsic::x86_avx_rcp_ps_256,     DOUBLE},                   NO_EMU}},
                {"meta.intrinsic.VPERMPS",    {{Intrinsic::not_intrinsic,          Intrinsic::not_intrinsic}, VPERM_EMU}},
                {"meta.intrinsic.VPERMD",     {{Intrinsic::not_intrinsic,          Intrinsic::not_intrinsic}, VPERM_EMU}},
                {"meta.intrinsic.VGATHERPD",  {{Intrinsic::not_intrinsic,          Intrinsic::not_intrinsic}, VGATHER_EMU}},
                {"meta.intrinsic.VGATHERPS",  {{Intrinsic::not_intrinsic,          Intrinsic::not_intrinsic}, VGATHER_EMU}},
                {"meta.intrinsic.VGATHERDD",  {{Intrinsic::not_intrinsic,          Intrinsic::not_intrinsic}, VGATHER_EMU}},
                {"meta.intrinsic.VSCATTERPS", {{Intrinsic::not_intrinsic,          Intrinsic::not_intrinsic}, VSCATTER_EMU}},
                {"meta.intrinsic.VCVTPD2PS",  {{Intrinsic::x86_avx_cvt_pd2_ps_256, Intrinsic::not_intrinsic}, NO_EMU}},
                {"meta.intrinsic.VROUND",     {{Intrinsic::x86_avx_round_ps_256,   DOUBLE},                   NO_EMU}},
                {"meta.intrinsic.VHSUBPS",    {{Intrinsic::x86_avx_hsub_ps_256,    DOUBLE},                   NO_EMU}},
            },
            {
                // AVX2
                {"meta.intrinsic.VRCPPS",     {{Intrinsic::x86_avx_rcp_ps_256,     DOUBLE},                   NO_EMU}},
                {"meta.intrinsic.VPERMPS",    {{Intrinsic::x86_avx2_permps,        Intrinsic::not_intrinsic}, VPERM_EMU}},
                {"meta.intrinsic.VPERMD",     {{Intrinsic::x86_avx2_permd,         Intrinsic::not_intrinsic}, VPERM_EMU}},
                {"meta.intrinsic.VGATHERPD",  {{Intrinsic::not_intrinsic,          Intrinsic::not_intrinsic}, VGATHER_EMU}},
                {"meta.intrinsic.VGATHERPS",  {{Intrinsic::not_intrinsic,          Intrinsic::not_intrinsic}, VGATHER_EMU}},
                {"meta.intrinsic.VGATHERDD",  {{Intrinsic::not_intrinsic,          Intrinsic::not_intrinsic}, VGATHER_EMU}},
                {"meta.intrinsic.VSCATTERPS", {{Intrinsic::not_intrinsic,          Intrinsic::not_intrinsic}, VSCATTER_EMU}},
                {"meta.intrinsic.VCVTPD2PS",  {{Intrinsic::x86_avx_cvt_pd2_ps_256, DOUBLE},                   NO_EMU}},
                {"meta.intrinsic.VROUND",     {{Intrinsic::x86_avx_round_ps_256,   DOUBLE},                   NO_EMU}},
                {"meta.intrinsic.VHSUBPS",    {{Intrinsic::x86_avx_hsub_ps_256,    DOUBLE},                   NO_EMU}},
            },
            {
                // AVX512
                {"meta.intrinsic.VRCPPS",     {{Intrinsic::x86_avx512_rcp14_ps_256, Intrinsic::x86_avx512_rcp14_ps_512}, NO_EMU}},
#if LLVM_VERSION_MAJOR < 7
                {"meta.intrinsic.VPERMPS",    {{Intrinsic::x86_avx512_mask_permvar_sf_256, Intrinsic::x86_avx512_mask_permvar_sf_512}, NO_EMU}},
                {"meta.intrinsic.VPERMD",     {{Intrinsic::x86_avx512_mask_permvar_si_256, Intrinsic::x86_avx512_mask_permvar_si_512}, NO_EMU}},
#else
                {"meta.intrinsic.VPERMPS",    {{Intrinsic::not_intrinsic,          Intrinsic::not_intrinsic}, VPERM_EMU}},
                {"meta.intrinsic.VPERMD",     {{Intrinsic::not_intrinsic,          Intrinsic::not_intrinsic}, VPERM_EMU}},
#endif
                {"meta.intrinsic.VGATHERPD",  {{Intrinsic::not_intrinsic,          Intrinsic::not_intrinsic}, VGATHER_EMU}},
                {"meta.intrinsic.VGATHERPS",  {{Intrinsic::not_intrinsic,          Intrinsic::not_intrinsic}, VGATHER_EMU}},
                {"meta.intrinsic.VGATHERDD",  {{Intrinsic::not_intrinsic,          Intrinsic::not_intrinsic}, VGATHER_EMU}},
                {"meta.intrinsic.VSCATTERPS", {{Intrinsic::not_intrinsic,          Intrinsic::not_intrinsic}, VSCATTER_EMU}},
#if LLVM_VERSION_MAJOR < 7
                {"meta.intrinsic.VCVTPD2PS",  {{Intrinsic::x86_avx512_mask_cvtpd2ps_256, Intrinsic::x86_avx512_mask_cvtpd2ps_512}, NO_EMU}},
#else
                {"meta.intrinsic.VCVTPD2PS",  {{Intrinsic::not_intrinsic,          Intrinsic::not_intrinsic}, VCONVERT_EMU}},
#endif
                {"meta.intrinsic.VROUND",     {{Intrinsic::not_intrinsic,          Intrinsic::not_intrinsic}, VROUND_EMU}},
                {"meta.intrinsic.VHSUBPS",    {{Intrinsic::not_intrinsic,          Intrinsic::not_intrinsic}, VHSUB_EMU}}
            }};
        // clang-format on
        return intrinsicMapAdvanced;
    }

    static uint32_t getBitWidth(VectorType* pVTy)
    {
#if LLVM_VERSION_MAJOR >= 11
        return pVTy->getNumElements() * pVTy->getElementType()->getPrimitiveSizeInBits();
#else
        return pVTy->getBitWidth();
#endif
    }

    struct LowerX86 : public FunctionPass
    {
        LowerX86(Builder* b = nullptr) : FunctionPass(ID), B(b)
        {
            initializeLowerX86Pass(*PassRegistry::getPassRegistry());

            // Determine target arch
            if (JM()->mArch.AVX512F())
            {
                mTarget = AVX512;
            }
            else if (JM()->mArch.AVX2())
            {
                mTarget = AVX2;
            }
            else if (JM()->mArch.AVX())
            {
                mTarget = AVX;
            }
            else
            {
                SWR_ASSERT(false, "Unsupported AVX architecture.");
                mTarget = AVX;
            }

            // Setup scatter function for 256 wide
            uint32_t curWidth = B->mVWidth;
            B->SetTargetWidth(8);
            std::vector<Type*> args = {
                B->mInt8PtrTy,   // pBase
                B->mSimdInt32Ty, // vIndices
                B->mSimdFP32Ty,  // vSrc
                B->mInt8Ty,      // mask
                B->mInt32Ty      // scale
            };

            FunctionType* pfnScatterTy = FunctionType::get(B->mVoidTy, args, false);
            mPfnScatter256 = cast<Function>(
#if LLVM_VERSION_MAJOR >= 9
                B->JM()->mpCurrentModule->getOrInsertFunction("ScatterPS_256", pfnScatterTy).getCallee());
#else
                B->JM()->mpCurrentModule->getOrInsertFunction("ScatterPS_256", pfnScatterTy));
#endif
            if (sys::DynamicLibrary::SearchForAddressOfSymbol("ScatterPS_256") == nullptr)
            {
                sys::DynamicLibrary::AddSymbol("ScatterPS_256", (void*)&ScatterPS_256);
            }

            B->SetTargetWidth(curWidth);
        }

        // Try to decipher the vector type of the instruction. This does not work properly
        // across all intrinsics, and will have to be rethought. Probably need something
        // similar to llvm's getDeclaration() utility to map a set of inputs to a specific typed
        // intrinsic.
        void GetRequestedWidthAndType(CallInst*       pCallInst,
                                      const StringRef intrinName,
                                      TargetWidth*    pWidth,
                                      Type**          pTy)
        {
            assert(pCallInst);
            Type* pVecTy = pCallInst->getType();

            // Check for intrinsic specific types
            // VCVTPD2PS type comes from src, not dst
            if (intrinName.equals("meta.intrinsic.VCVTPD2PS"))
            {
                Value* pOp = pCallInst->getOperand(0);
                assert(pOp);
                pVecTy = pOp->getType();
            }

            if (!pVecTy->isVectorTy())
            {
                for (auto& op : pCallInst->arg_operands())
                {
                    if (op.get()->getType()->isVectorTy())
                    {
                        pVecTy = op.get()->getType();
                        break;
                    }
                }
            }
            SWR_ASSERT(pVecTy->isVectorTy(), "Couldn't determine vector size");

            uint32_t width = getBitWidth(cast<VectorType>(pVecTy));
            switch (width)
            {
            case 256:
                *pWidth = W256;
                break;
            case 512:
                *pWidth = W512;
                break;
            default:
                SWR_ASSERT(false, "Unhandled vector width %d", width);
                *pWidth = W256;
            }

            *pTy = pVecTy->getScalarType();
        }

        Value* GetZeroVec(TargetWidth width, Type* pTy)
        {
            uint32_t numElem = 0;
            switch (width)
            {
            case W256:
                numElem = 8;
                break;
            case W512:
                numElem = 16;
                break;
            default:
                SWR_ASSERT(false, "Unhandled vector width type %d\n", width);
            }

            return ConstantVector::getNullValue(getVectorType(pTy, numElem));
        }

        Value* GetMask(TargetWidth width)
        {
            Value* mask = nullptr;
            switch (width)
            {
            case W256:
                mask = B->C((uint8_t)-1);
                break;
            case W512:
                mask = B->C((uint16_t)-1);
                break;
            default:
                SWR_ASSERT(false, "Unhandled vector width type %d\n", width);
            }
            return mask;
        }

        // Convert <N x i1> mask to <N x i32> x86 mask
        Value* VectorMask(Value* vi1Mask)
        {
#if LLVM_VERSION_MAJOR >= 11
            uint32_t numElem = cast<VectorType>(vi1Mask->getType())->getNumElements();
#else
            uint32_t numElem = vi1Mask->getType()->getVectorNumElements();
#endif
            return B->S_EXT(vi1Mask, getVectorType(B->mInt32Ty, numElem));
        }
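
        // Example of the mask conversion above (illustrative): an active lane in an <8 x i1>
        // mask sign-extends to 0xFFFFFFFF in the resulting <8 x i32>, which matches the
        // "most significant bit of each element" convention the AVX2 gather intrinsics expect.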

        Instruction* ProcessIntrinsicAdvanced(CallInst* pCallInst)
        {
            Function* pFunc = pCallInst->getCalledFunction();
            assert(pFunc);

            auto&       intrinsic = getIntrinsicMapAdvanced()[mTarget][pFunc->getName().str()];
            TargetWidth vecWidth;
            Type*       pElemTy;
            GetRequestedWidthAndType(pCallInst, pFunc->getName(), &vecWidth, &pElemTy);

            // Check if there is a native intrinsic for this instruction
            IntrinsicID id = intrinsic.intrin[vecWidth];
            if (id == DOUBLE)
            {
                // Double pump the next smaller SIMD intrinsic
                SWR_ASSERT(vecWidth != 0, "Cannot double pump smallest SIMD width.");
                Intrinsic::ID id2 = intrinsic.intrin[vecWidth - 1];
                SWR_ASSERT(id2 != Intrinsic::not_intrinsic,
                           "Cannot find intrinsic to double pump.");
                return DOUBLE_EMU(this, mTarget, vecWidth, pCallInst, id2);
            }
            else if (id != Intrinsic::not_intrinsic)
            {
                Function* pIntrin = Intrinsic::getDeclaration(B->JM()->mpCurrentModule, id);
                SmallVector<Value*, 8> args;
                for (auto& arg : pCallInst->arg_operands())
                {
                    args.push_back(arg.get());
                }

                // If AVX512, all instructions add a src operand and mask. We'll pass in 0 src and
                // full mask for now, assuming the intrinsics are consistent and place the src
                // operand and mask last in the argument list.
                if (mTarget == AVX512)
                {
                    if (pFunc->getName().equals("meta.intrinsic.VCVTPD2PS"))
                    {
                        args.push_back(GetZeroVec(W256, pCallInst->getType()->getScalarType()));
                        args.push_back(GetMask(W256));
                        // for AVX512 VCVTPD2PS, we also have to add rounding mode
                        args.push_back(B->C(_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));
                    }
                    else
                    {
                        args.push_back(GetZeroVec(vecWidth, pElemTy));
                        args.push_back(GetMask(vecWidth));
                    }
                }

                return B->CALLA(pIntrin, args);
            }
            else
            {
                // No native intrinsic, call emulation function
                return intrinsic.emuFunc(this, mTarget, vecWidth, pCallInst);
            }

            SWR_ASSERT(false);
            return nullptr;
        }

        Instruction* ProcessIntrinsic(CallInst* pCallInst)
        {
            Function* pFunc = pCallInst->getCalledFunction();
            assert(pFunc);

            // Forward to the advanced support if found
            if (getIntrinsicMapAdvanced()[mTarget].find(pFunc->getName().str()) !=
                getIntrinsicMapAdvanced()[mTarget].end())
            {
                return ProcessIntrinsicAdvanced(pCallInst);
            }

            SWR_ASSERT(getIntrinsicMap().find(pFunc->getName().str()) != getIntrinsicMap().end(),
                       "Unimplemented intrinsic %s.",
                       pFunc->getName().str().c_str());

            Intrinsic::ID x86Intrinsic = getIntrinsicMap()[pFunc->getName().str()];
            Function*     pX86IntrinFunc =
                Intrinsic::getDeclaration(B->JM()->mpCurrentModule, x86Intrinsic);

            SmallVector<Value*, 8> args;
            for (auto& arg : pCallInst->arg_operands())
            {
                args.push_back(arg.get());
            }
            return B->CALLA(pX86IntrinFunc, args);
        }

        //////////////////////////////////////////////////////////////////////////
        /// @brief LLVM function pass run method.
        /// @param F - The function we're working on with this pass.
        virtual bool runOnFunction(Function& F)
        {
            std::vector<Instruction*> toRemove;
            std::vector<BasicBlock*>  bbs;

            // Make temp copy of the basic blocks and instructions, as the intrinsic
            // replacement code might invalidate the iterators
            for (auto& b : F.getBasicBlockList())
            {
                bbs.push_back(&b);
            }

            for (auto* BB : bbs)
            {
                std::vector<Instruction*> insts;
                for (auto& i : BB->getInstList())
                {
                    insts.push_back(&i);
                }

                for (auto* I : insts)
                {
                    if (CallInst* pCallInst = dyn_cast<CallInst>(I))
                    {
                        Function* pFunc = pCallInst->getCalledFunction();
                        if (pFunc)
                        {
                            if (pFunc->getName().startswith("meta.intrinsic"))
                            {
                                B->IRB()->SetInsertPoint(I);
                                Instruction* pReplace = ProcessIntrinsic(pCallInst);
                                toRemove.push_back(pCallInst);
                                if (pReplace)
                                {
                                    pCallInst->replaceAllUsesWith(pReplace);
                                }
                            }
                        }
                    }
                }
            }

            for (auto* pInst : toRemove)
            {
                pInst->eraseFromParent();
            }

            JitManager::DumpToFile(&F, "lowerx86");

            return true;
        }

        virtual void getAnalysisUsage(AnalysisUsage& AU) const {}

        JitManager* JM() { return B->JM(); }
        Builder*    B;
        TargetArch  mTarget;
        Function*   mPfnScatter256;

        static char ID; ///< Needed by LLVM to generate ID for FunctionPass.
    };

    char LowerX86::ID = 0; // LLVM uses address of ID as the actual ID.

    FunctionPass* createLowerX86Pass(Builder* b) { return new LowerX86(b); }

    Instruction* NO_EMU(LowerX86* pThis, TargetArch arch, TargetWidth width, CallInst* pCallInst)
    {
        SWR_ASSERT(false, "Unimplemented intrinsic emulation.");
        return nullptr;
    }

    Instruction* VPERM_EMU(LowerX86* pThis, TargetArch arch, TargetWidth width, CallInst* pCallInst)
    {
        // Only need vperm emulation for AVX
        SWR_ASSERT(arch == AVX);

        Builder* B = pThis->B;
        auto v32A      = pCallInst->getArgOperand(0);
        auto vi32Index = pCallInst->getArgOperand(1);

        Value* v32Result;
        if (isa<Constant>(vi32Index))
        {
            // Can use llvm shuffle vector directly with constant shuffle indices
            v32Result = B->VSHUFFLE(v32A, v32A, vi32Index);
        }
        else
        {
            v32Result = UndefValue::get(v32A->getType());
#if LLVM_VERSION_MAJOR >= 11
            uint32_t numElem = cast<VectorType>(v32A->getType())->getNumElements();
#else
            uint32_t numElem = v32A->getType()->getVectorNumElements();
#endif
            for (uint32_t l = 0; l < numElem; ++l)
            {
                auto i32Index = B->VEXTRACT(vi32Index, B->C(l));
                auto val      = B->VEXTRACT(v32A, i32Index);
                v32Result     = B->VINSERT(v32Result, val, B->C(l));
            }
        }
        return cast<Instruction>(v32Result);
    }

    Instruction*
    VGATHER_EMU(LowerX86* pThis, TargetArch arch, TargetWidth width, CallInst* pCallInst)
    {
        Builder* B   = pThis->B;
        auto vSrc        = pCallInst->getArgOperand(0);
        auto pBase       = pCallInst->getArgOperand(1);
        auto vi32Indices = pCallInst->getArgOperand(2);
        auto vi1Mask     = pCallInst->getArgOperand(3);
        auto i8Scale     = pCallInst->getArgOperand(4);

        pBase = B->POINTER_CAST(pBase, PointerType::get(B->mInt8Ty, 0));
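
        // Lane-wise gather semantics being implemented here (pseudocode, for reference):
        //     result[i] = mask[i] ? *(elemTy*)(pBase + indices[i] * scale) : vSrc[i]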
#if LLVM_VERSION_MAJOR >= 11
        VectorType* pVectorType = cast<VectorType>(vSrc->getType());
        uint32_t    numElem     = pVectorType->getNumElements();
        auto        srcTy       = pVectorType->getElementType();
#else
        uint32_t numElem = vSrc->getType()->getVectorNumElements();
        auto     srcTy   = vSrc->getType()->getVectorElementType();
#endif
        auto i32Scale = B->Z_EXT(i8Scale, B->mInt32Ty);

        Value* v32Gather = nullptr;
        if (arch == AVX)
        {
            // Full emulation for AVX
            // Store source on stack to provide a valid address to load from inactive lanes
            auto pStack = B->STACKSAVE();
            auto pTmp   = B->ALLOCA(vSrc->getType());
            B->STORE(vSrc, pTmp);

            v32Gather = UndefValue::get(vSrc->getType());
#if LLVM_VERSION_MAJOR <= 10
            auto vi32Scale = ConstantVector::getSplat(numElem, cast<ConstantInt>(i32Scale));
#elif LLVM_VERSION_MAJOR == 11
            auto vi32Scale = ConstantVector::getSplat(ElementCount(numElem, false), cast<ConstantInt>(i32Scale));
#else
            auto vi32Scale = ConstantVector::getSplat(ElementCount::get(numElem, false), cast<ConstantInt>(i32Scale));
#endif
            auto vi32Offsets = B->MUL(vi32Indices, vi32Scale);

            for (uint32_t i = 0; i < numElem; ++i)
            {
                auto i32Offset          = B->VEXTRACT(vi32Offsets, B->C(i));
                auto pLoadAddress       = B->GEP(pBase, i32Offset);
                pLoadAddress            = B->BITCAST(pLoadAddress, PointerType::get(srcTy, 0));
                auto pMaskedLoadAddress = B->GEP(pTmp, {0, i});
                auto i1Mask             = B->VEXTRACT(vi1Mask, B->C(i));
                auto pValidAddress      = B->SELECT(i1Mask, pLoadAddress, pMaskedLoadAddress);
                auto val                = B->LOAD(pValidAddress);
                v32Gather               = B->VINSERT(v32Gather, val, B->C(i));
            }

            B->STACKRESTORE(pStack);
        }
        else if (arch == AVX2 || (arch == AVX512 && width == W256))
        {
            Function* pX86IntrinFunc = nullptr;
            if (srcTy == B->mFP32Ty)
            {
                pX86IntrinFunc = Intrinsic::getDeclaration(B->JM()->mpCurrentModule,
                                                           Intrinsic::x86_avx2_gather_d_ps_256);
            }
            else if (srcTy == B->mInt32Ty)
            {
                pX86IntrinFunc = Intrinsic::getDeclaration(B->JM()->mpCurrentModule,
                                                           Intrinsic::x86_avx2_gather_d_d_256);
            }
            else if (srcTy == B->mDoubleTy)
            {
                pX86IntrinFunc = Intrinsic::getDeclaration(B->JM()->mpCurrentModule,
                                                           Intrinsic::x86_avx2_gather_d_q_256);
            }
            else
            {
                SWR_ASSERT(false, "Unsupported vector element type for gather.");
            }

            if (width == W256)
            {
                auto v32Mask = B->BITCAST(pThis->VectorMask(vi1Mask), vSrc->getType());
                v32Gather = B->CALL(pX86IntrinFunc, {vSrc, pBase, vi32Indices, v32Mask, i8Scale});
            }
            else if (width == W512)
            {
                // Double pump 4-wide for 64bit elements
#if LLVM_VERSION_MAJOR >= 11
                if (cast<VectorType>(vSrc->getType())->getElementType() == B->mDoubleTy)
#else
                if (vSrc->getType()->getVectorElementType() == B->mDoubleTy)
#endif
                {
                    auto v64Mask = pThis->VectorMask(vi1Mask);
#if LLVM_VERSION_MAJOR >= 11
                    uint32_t numElem = cast<VectorType>(v64Mask->getType())->getNumElements();
#else
                    uint32_t numElem = v64Mask->getType()->getVectorNumElements();
#endif
                    v64Mask = B->S_EXT(v64Mask, getVectorType(B->mInt64Ty, numElem));
                    v64Mask = B->BITCAST(v64Mask, vSrc->getType());

                    Value* src0 = B->VSHUFFLE(vSrc, vSrc, B->C({0, 1, 2, 3}));
                    Value* src1 = B->VSHUFFLE(vSrc, vSrc, B->C({4, 5, 6, 7}));

                    Value* indices0 = B->VSHUFFLE(vi32Indices, vi32Indices, B->C({0, 1, 2, 3}));
                    Value* indices1 = B->VSHUFFLE(vi32Indices, vi32Indices, B->C({4, 5, 6, 7}));
                    Value* mask0 = B->VSHUFFLE(v64Mask, v64Mask, B->C({0, 1, 2, 3}));
                    Value* mask1 = B->VSHUFFLE(v64Mask, v64Mask, B->C({4, 5, 6, 7}));

#if LLVM_VERSION_MAJOR >= 11
                    uint32_t numElemSrc0  = cast<VectorType>(src0->getType())->getNumElements();
                    uint32_t numElemMask0 = cast<VectorType>(mask0->getType())->getNumElements();
                    uint32_t numElemSrc1  = cast<VectorType>(src1->getType())->getNumElements();
                    uint32_t numElemMask1 = cast<VectorType>(mask1->getType())->getNumElements();
#else
                    uint32_t numElemSrc0  = src0->getType()->getVectorNumElements();
                    uint32_t numElemMask0 = mask0->getType()->getVectorNumElements();
                    uint32_t numElemSrc1  = src1->getType()->getVectorNumElements();
                    uint32_t numElemMask1 = mask1->getType()->getVectorNumElements();
#endif
                    src0  = B->BITCAST(src0, getVectorType(B->mInt64Ty, numElemSrc0));
                    mask0 = B->BITCAST(mask0, getVectorType(B->mInt64Ty, numElemMask0));
                    Value* gather0 =
                        B->CALL(pX86IntrinFunc, {src0, pBase, indices0, mask0, i8Scale});
                    src1  = B->BITCAST(src1, getVectorType(B->mInt64Ty, numElemSrc1));
                    mask1 = B->BITCAST(mask1, getVectorType(B->mInt64Ty, numElemMask1));
                    Value* gather1 =
                        B->CALL(pX86IntrinFunc, {src1, pBase, indices1, mask1, i8Scale});
                    v32Gather = B->VSHUFFLE(gather0, gather1, B->C({0, 1, 2, 3, 4, 5, 6, 7}));
                    v32Gather = B->BITCAST(v32Gather, vSrc->getType());
                }
                else
                {
                    // Double pump 8-wide for 32bit elements
                    auto v32Mask = pThis->VectorMask(vi1Mask);
                    v32Mask      = B->BITCAST(v32Mask, vSrc->getType());
                    Value* src0  = B->EXTRACT_16(vSrc, 0);
                    Value* src1  = B->EXTRACT_16(vSrc, 1);

                    Value* indices0 = B->EXTRACT_16(vi32Indices, 0);
                    Value* indices1 = B->EXTRACT_16(vi32Indices, 1);

                    Value* mask0 = B->EXTRACT_16(v32Mask, 0);
                    Value* mask1 = B->EXTRACT_16(v32Mask, 1);

                    Value* gather0 =
                        B->CALL(pX86IntrinFunc, {src0, pBase, indices0, mask0, i8Scale});
                    Value* gather1 =
                        B->CALL(pX86IntrinFunc, {src1, pBase, indices1, mask1, i8Scale});

                    v32Gather = B->JOIN_16(gather0, gather1);
                }
            }
        }
        else if (arch == AVX512)
        {
            Value*    iMask          = nullptr;
            Function* pX86IntrinFunc = nullptr;
            if (srcTy == B->mFP32Ty)
            {
                pX86IntrinFunc = Intrinsic::getDeclaration(B->JM()->mpCurrentModule,
                                                           Intrinsic::x86_avx512_gather_dps_512);
                iMask          = B->BITCAST(vi1Mask, B->mInt16Ty);
            }
            else if (srcTy == B->mInt32Ty)
            {
                pX86IntrinFunc = Intrinsic::getDeclaration(B->JM()->mpCurrentModule,
                                                           Intrinsic::x86_avx512_gather_dpi_512);
                iMask          = B->BITCAST(vi1Mask, B->mInt16Ty);
            }
            else if (srcTy == B->mDoubleTy)
            {
                pX86IntrinFunc = Intrinsic::getDeclaration(B->JM()->mpCurrentModule,
                                                           Intrinsic::x86_avx512_gather_dpd_512);
                iMask          = B->BITCAST(vi1Mask, B->mInt8Ty);
            }
            else
            {
                SWR_ASSERT(false, "Unsupported vector element type for gather.");
            }

            auto i32Scale = B->Z_EXT(i8Scale, B->mInt32Ty);
            v32Gather     = B->CALL(pX86IntrinFunc, {vSrc, pBase, vi32Indices, iMask, i32Scale});
        }

        return cast<Instruction>(v32Gather);
    }
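
    // Lane-wise scatter semantics implemented below (pseudocode, for reference):
    //     if (mask[i]) *(float*)(pBase + indices[i] * scale) = src[i]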
    Instruction*
    VSCATTER_EMU(LowerX86* pThis, TargetArch arch, TargetWidth width, CallInst* pCallInst)
    {
        Builder* B   = pThis->B;
        auto pBase       = pCallInst->getArgOperand(0);
        auto vi1Mask     = pCallInst->getArgOperand(1);
        auto vi32Indices = pCallInst->getArgOperand(2);
        auto v32Src      = pCallInst->getArgOperand(3);
        auto i32Scale    = pCallInst->getArgOperand(4);

        if (arch != AVX512)
        {
            // Call into C function to do the scatter. This has significantly better compile perf
            // compared to jitting scatter loops for every scatter
            if (width == W256)
            {
                auto mask = B->BITCAST(vi1Mask, B->mInt8Ty);
                B->CALL(pThis->mPfnScatter256, {pBase, vi32Indices, v32Src, mask, i32Scale});
            }
            else
            {
                // Need to break up 512 wide scatter to two 256 wide
                auto maskLo    = B->VSHUFFLE(vi1Mask, vi1Mask, B->C({0, 1, 2, 3, 4, 5, 6, 7}));
                auto indicesLo =
                    B->VSHUFFLE(vi32Indices, vi32Indices, B->C({0, 1, 2, 3, 4, 5, 6, 7}));
                auto srcLo = B->VSHUFFLE(v32Src, v32Src, B->C({0, 1, 2, 3, 4, 5, 6, 7}));

                auto mask = B->BITCAST(maskLo, B->mInt8Ty);
                B->CALL(pThis->mPfnScatter256, {pBase, indicesLo, srcLo, mask, i32Scale});

                auto maskHi    = B->VSHUFFLE(vi1Mask, vi1Mask, B->C({8, 9, 10, 11, 12, 13, 14, 15}));
                auto indicesHi =
                    B->VSHUFFLE(vi32Indices, vi32Indices, B->C({8, 9, 10, 11, 12, 13, 14, 15}));
                auto srcHi = B->VSHUFFLE(v32Src, v32Src, B->C({8, 9, 10, 11, 12, 13, 14, 15}));

                mask = B->BITCAST(maskHi, B->mInt8Ty);
                B->CALL(pThis->mPfnScatter256, {pBase, indicesHi, srcHi, mask, i32Scale});
            }
            return nullptr;
        }

        Value*    iMask;
        Function* pX86IntrinFunc;
        if (width == W256)
        {
            // No direct intrinsic supported in llvm to scatter 8 elem with 32bit indices, but we
            // can use the scatter of 8 elements with 64bit indices
            pX86IntrinFunc = Intrinsic::getDeclaration(B->JM()->mpCurrentModule,
                                                       Intrinsic::x86_avx512_scatter_qps_512);

            auto vi32IndicesExt = B->Z_EXT(vi32Indices, B->mSimdInt64Ty);
            iMask               = B->BITCAST(vi1Mask, B->mInt8Ty);
            B->CALL(pX86IntrinFunc, {pBase, iMask, vi32IndicesExt, v32Src, i32Scale});
        }
        else if (width == W512)
        {
            pX86IntrinFunc = Intrinsic::getDeclaration(B->JM()->mpCurrentModule,
                                                       Intrinsic::x86_avx512_scatter_dps_512);
            iMask          = B->BITCAST(vi1Mask, B->mInt16Ty);
            B->CALL(pX86IntrinFunc, {pBase, iMask, vi32Indices, v32Src, i32Scale});
        }
        return nullptr;
    }

    // No support for vroundps in avx512 (it is available in kncni), so emulate with avx
    // instructions
    Instruction*
    VROUND_EMU(LowerX86* pThis, TargetArch arch, TargetWidth width, CallInst* pCallInst)
    {
        SWR_ASSERT(arch == AVX512);

        auto B       = pThis->B;
        auto vf32Src = pCallInst->getOperand(0);
        assert(vf32Src);
        auto i8Round = pCallInst->getOperand(1);
        assert(i8Round);
        auto pfnFunc =
            Intrinsic::getDeclaration(B->JM()->mpCurrentModule, Intrinsic::x86_avx_round_ps_256);

        if (width == W256)
        {
            return cast<Instruction>(B->CALL2(pfnFunc, vf32Src, i8Round));
        }
        else if (width == W512)
        {
            auto v8f32SrcLo = B->EXTRACT_16(vf32Src, 0);
            auto v8f32SrcHi = B->EXTRACT_16(vf32Src, 1);

            auto v8f32ResLo = B->CALL2(pfnFunc, v8f32SrcLo, i8Round);
            auto v8f32ResHi = B->CALL2(pfnFunc, v8f32SrcHi, i8Round);

            return cast<Instruction>(B->JOIN_16(v8f32ResLo, v8f32ResHi));
        }
        else
        {
            SWR_ASSERT(false, "Unimplemented vector width.");
        }

        return nullptr;
    }

    Instruction*
    VCONVERT_EMU(LowerX86* pThis, TargetArch arch, TargetWidth width, CallInst* pCallInst)
    {
        SWR_ASSERT(arch == AVX512);

        auto B = pThis->B;
        auto vf32Src = pCallInst->getOperand(0);

        if (width == W256)
        {
            auto vf32SrcRound = Intrinsic::getDeclaration(B->JM()->mpCurrentModule,
                                                          Intrinsic::x86_avx_round_ps_256);
            return cast<Instruction>(B->FP_TRUNC(vf32SrcRound, B->mFP32Ty));
        }
        else if (width == W512)
        {
            // 512 can use intrinsic
            auto pfnFunc = Intrinsic::getDeclaration(B->JM()->mpCurrentModule,
                                                     Intrinsic::x86_avx512_mask_cvtpd2ps_512);
            return cast<Instruction>(B->CALL(pfnFunc, vf32Src));
        }
        else
        {
            SWR_ASSERT(false, "Unimplemented vector width.");
        }

        return nullptr;
    }

    // No support for hsub in AVX512
    Instruction* VHSUB_EMU(LowerX86* pThis, TargetArch arch, TargetWidth width, CallInst* pCallInst)
    {
        SWR_ASSERT(arch == AVX512);

        auto B    = pThis->B;
        auto src0 = pCallInst->getOperand(0);
        auto src1 = pCallInst->getOperand(1);

        // 256b hsub can just use avx intrinsic
        if (width == W256)
        {
            auto pX86IntrinFunc =
                Intrinsic::getDeclaration(B->JM()->mpCurrentModule, Intrinsic::x86_avx_hsub_ps_256);
            return cast<Instruction>(B->CALL2(pX86IntrinFunc, src0, src1));
        }
        else if (width == W512)
        {
            // 512b hsub can be accomplished with shuf/sub combo
            auto minuend    = B->VSHUFFLE(src0, src1, B->C({0, 2, 8, 10, 4, 6, 12, 14}));
            auto subtrahend = B->VSHUFFLE(src0, src1, B->C({1, 3, 9, 11, 5, 7, 13, 15}));
            return cast<Instruction>(B->SUB(minuend, subtrahend));
        }
        else
        {
            SWR_ASSERT(false, "Unimplemented vector width.");
            return nullptr;
        }
    }
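
    // For example (illustrative): a 512-wide meta.intrinsic.VROUND on AVX/AVX2 is marked DOUBLE
    // in the tables above, so the helper below lowers it as two llvm.x86.avx.round.ps.256 calls
    // whose results are stitched back together with a shuffle.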
    // Double pump the input using the given 256-wide intrinsic. This blindly extracts the lower
    // and upper 256 from each vector argument and calls the 256 wide intrinsic, then merges the
    // results to 512 wide
    Instruction* DOUBLE_EMU(LowerX86*     pThis,
                            TargetArch    arch,
                            TargetWidth   width,
                            CallInst*     pCallInst,
                            Intrinsic::ID intrin)
    {
        auto B = pThis->B;
        SWR_ASSERT(width == W512);
        Value*    result[2];
        Function* pX86IntrinFunc = Intrinsic::getDeclaration(B->JM()->mpCurrentModule, intrin);
        for (uint32_t i = 0; i < 2; ++i)
        {
            SmallVector<Value*, 8> args;
            for (auto& arg : pCallInst->arg_operands())
            {
                auto argType = arg.get()->getType();
                if (argType->isVectorTy())
                {
#if LLVM_VERSION_MAJOR >= 11
                    uint32_t vecWidth = cast<VectorType>(argType)->getNumElements();
                    auto     elemTy   = cast<VectorType>(argType)->getElementType();
#else
                    uint32_t vecWidth = argType->getVectorNumElements();
                    auto     elemTy   = argType->getVectorElementType();
#endif
                    Value* lanes     = B->CInc<int>(i * vecWidth / 2, vecWidth / 2);
                    Value* argToPush = B->VSHUFFLE(arg.get(), B->VUNDEF(elemTy, vecWidth), lanes);
                    args.push_back(argToPush);
                }
                else
                {
                    args.push_back(arg.get());
                }
            }
            result[i] = B->CALLA(pX86IntrinFunc, args);
        }
        uint32_t vecWidth;
        if (result[0]->getType()->isVectorTy())
        {
            assert(result[1]->getType()->isVectorTy());
#if LLVM_VERSION_MAJOR >= 11
            vecWidth = cast<VectorType>(result[0]->getType())->getNumElements() +
                       cast<VectorType>(result[1]->getType())->getNumElements();
#else
            vecWidth = result[0]->getType()->getVectorNumElements() +
                       result[1]->getType()->getVectorNumElements();
#endif
        }
        else
        {
            vecWidth = 2;
        }
        Value* lanes = B->CInc<int>(0, vecWidth);
        return cast<Instruction>(B->VSHUFFLE(result[0], result[1], lanes));
    }

} // namespace SwrJit

using namespace SwrJit;

INITIALIZE_PASS_BEGIN(LowerX86, "LowerX86", "LowerX86", false, false)
INITIALIZE_PASS_END(LowerX86, "LowerX86", "LowerX86", false, false)