1 //===- SPIRVLowerOCLBlocks.cpp - Lower OpenCL blocks ------------*- C++ -*-===//
2 //
3 //                     The LLVM/SPIR-V Translator
4 //
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
7 //
8 // Copyright (c) 2014 Advanced Micro Devices, Inc. All rights reserved.
9 //
10 // Permission is hereby granted, free of charge, to any person obtaining a
11 // copy of this software and associated documentation files (the "Software"),
12 // to deal with the Software without restriction, including without limitation
13 // the rights to use, copy, modify, merge, publish, distribute, sublicense,
14 // and/or sell copies of the Software, and to permit persons to whom the
15 // Software is furnished to do so, subject to the following conditions:
16 //
17 // Redistributions of source code must retain the above copyright notice,
18 // this list of conditions and the following disclaimers.
19 // Redistributions in binary form must reproduce the above copyright notice,
20 // this list of conditions and the following disclaimers in the documentation
21 // and/or other materials provided with the distribution.
22 // Neither the names of Advanced Micro Devices, Inc., nor the names of its
23 // contributors may be used to endorse or promote products derived from this
24 // Software without specific prior written permission.
25 // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
26 // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
27 // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
28 // CONTRIBUTORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
29 // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
30 // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS WITH
31 // THE SOFTWARE.
32 //
33 //===----------------------------------------------------------------------===//
34 /// \file
35 ///
36 /// This file implements lowering of OpenCL blocks to functions.
37 ///
38 //===----------------------------------------------------------------------===//
39 
40 #ifndef OCLLOWERBLOCKS_H_
41 #define OCLLOWERBLOCKS_H_
42 
43 #include "SPIRVInternal.h"
44 #include "OCLUtil.h"
45 
46 #include "llvm/ADT/DenseMap.h"
47 #include "llvm/ADT/SetVector.h"
48 #include "llvm/ADT/StringSwitch.h"
49 #include "llvm/ADT/Triple.h"
50 #include "llvm/Analysis/AliasAnalysis.h"
51 #include "llvm/Analysis/AssumptionCache.h"
52 #include "llvm/Analysis/CallGraph.h"
53 #include "llvm/IR/Verifier.h"
54 #include "llvm/Bitcode/ReaderWriter.h"
55 #include "llvm/IR/Constants.h"
56 #include "llvm/IR/DerivedTypes.h"
57 #include "llvm/IR/Function.h"
58 #include "llvm/IR/InstrTypes.h"
59 #include "llvm/IR/Instructions.h"
60 #include "llvm/IR/Module.h"
61 #include "llvm/IR/Operator.h"
62 #include "llvm/Pass.h"
63 #include "llvm/PassSupport.h"
64 #include "llvm/Support/Casting.h"
65 #include "llvm/Support/Debug.h"
66 #include "llvm/Support/raw_ostream.h"
67 #include "llvm/Support/ToolOutputFile.h"
68 #include "llvm/Transforms/Utils/Cloning.h"
69 
70 #include <iostream>
71 #include <list>
72 #include <memory>
73 #include <set>
74 #include <sstream>
75 #include <vector>
76 
77 #define DEBUG_TYPE "spvblocks"
78 
79 using namespace llvm;
80 using namespace SPIRV;
81 using namespace OCLUtil;
82 
83 namespace SPIRV{
84 
85 /// Lower SPIR2 blocks to function calls.
86 ///
87 /// SPIR2 representation of blocks:
88 ///
89 /// block = spir_block_bind(bitcast(block_func), context_len, context_align,
90 ///   context)
91 /// block_func_ptr = bitcast(spir_get_block_invoke(block))
92 /// context_ptr = spir_get_block_context(block)
93 /// ret = block_func_ptr(context_ptr, args)
94 ///
95 /// Propagates block_func to each spir_get_block_invoke through def-use chain of
96 /// spir_block_bind, so that
97 /// ret = block_func(context, args)
98 class SPIRVLowerOCLBlocks: public ModulePass {
99 public:
SPIRVLowerOCLBlocks()100   SPIRVLowerOCLBlocks():ModulePass(ID), M(nullptr){
101     initializeSPIRVLowerOCLBlocksPass(*PassRegistry::getPassRegistry());
102   }
103 
getAnalysisUsage(AnalysisUsage & AU) const104   virtual void getAnalysisUsage(AnalysisUsage &AU) const {
105     AU.addRequired<CallGraphWrapperPass>();
106     //AU.addRequired<AliasAnalysis>();
107     AU.addRequired<AssumptionCacheTracker>();
108   }
109 
runOnModule(Module & Module)110   virtual bool runOnModule(Module &Module) {
111     M = &Module;
112     lowerBlockBind();
113     lowerGetBlockInvoke();
114     lowerGetBlockContext();
115     erase(M->getFunction(SPIR_INTRINSIC_GET_BLOCK_INVOKE));
116     erase(M->getFunction(SPIR_INTRINSIC_GET_BLOCK_CONTEXT));
117     erase(M->getFunction(SPIR_INTRINSIC_BLOCK_BIND));
118     DEBUG(dbgs() << "------- After OCLLowerBlocks ------------\n" <<
119                     *M << '\n');
120     return true;
121   }
122 
123   static char ID;
124 private:
125   const static int MaxIter = 1000;
126   Module *M;
127 
128   bool
lowerBlockBind()129   lowerBlockBind() {
130     auto F = M->getFunction(SPIR_INTRINSIC_BLOCK_BIND);
131     if (!F)
132       return false;
133     int Iter = MaxIter;
134     while(lowerBlockBind(F) && Iter > 0){
135       Iter--;
136       DEBUG(dbgs() << "-------------- after iteration " << MaxIter - Iter <<
137           " --------------\n" << *M << '\n');
138     }
139     assert(Iter > 0 && "Too many iterations");
140     return true;
141   }
142 
143   bool
eraseUselessFunctions()144   eraseUselessFunctions() {
145     bool changed = false;
146     for (auto I = M->begin(), E = M->end(); I != E;) {
147       Function *F = static_cast<Function*>(I++);
148       if (!GlobalValue::isInternalLinkage(F->getLinkage()) &&
149           !F->isDeclaration())
150         continue;
151 
152       dumpUsers(F, "[eraseUselessFunctions] ");
153       for (auto UI = F->user_begin(), UE = F->user_end(); UI != UE;) {
154         auto U = *UI++;
155         if (auto CE = dyn_cast<ConstantExpr>(U)){
156           if (CE->use_empty()) {
157             CE->dropAllReferences();
158             changed = true;
159           }
160         }
161       }
162       if (F->use_empty()) {
163         erase(F);
164         changed = true;
165       }
166     }
167     return changed;
168   }
169 
170   void
lowerGetBlockInvoke()171   lowerGetBlockInvoke() {
172     if (auto F = M->getFunction(SPIR_INTRINSIC_GET_BLOCK_INVOKE)) {
173       for (auto UI = F->user_begin(), UE = F->user_end(); UI != UE;) {
174         auto CI = dyn_cast<CallInst>(*UI++);
175         assert(CI && "Invalid usage of spir_get_block_invoke");
176         lowerGetBlockInvoke(CI);
177       }
178     }
179   }
180 
181   void
lowerGetBlockContext()182   lowerGetBlockContext() {
183     if (auto F = M->getFunction(SPIR_INTRINSIC_GET_BLOCK_CONTEXT)) {
184       for (auto UI = F->user_begin(), UE = F->user_end(); UI != UE;) {
185         auto CI = dyn_cast<CallInst>(*UI++);
186         assert(CI && "Invalid usage of spir_get_block_context");
187         lowerGetBlockContext(CI);
188       }
189     }
190   }
191   /// Lower calls of spir_block_bind.
192   /// Return true if the Module is changed.
193   bool
lowerBlockBind(Function * BlockBindFunc)194   lowerBlockBind(Function *BlockBindFunc) {
195     bool changed = false;
196     for (auto I = BlockBindFunc->user_begin(), E = BlockBindFunc->user_end();
197         I != E;) {
198       DEBUG(dbgs() << "[lowerBlockBind] " << **I << '\n');
199       // Handle spir_block_bind(bitcast(block_func), context_len,
200       // context_align, context)
201       auto CallBlkBind = cast<CallInst>(*I++);
202       Function *InvF = nullptr;
203       Value *Ctx = nullptr;
204       Value *CtxLen = nullptr;
205       Value *CtxAlign = nullptr;
206       getBlockInvokeFuncAndContext(CallBlkBind, &InvF, &Ctx, &CtxLen,
207           &CtxAlign);
208       for (auto II = CallBlkBind->user_begin(), EE = CallBlkBind->user_end();
209           II != EE;) {
210         auto BlkUser = *II++;
211         SPIRVDBG(dbgs() << "  Block user: " << *BlkUser << '\n');
212         if (auto Ret = dyn_cast<ReturnInst>(BlkUser)) {
213           bool Inlined = false;
214           changed |= lowerReturnBlock(Ret, CallBlkBind, Inlined);
215           if (Inlined)
216             return true;
217         } else if (auto CI = dyn_cast<CallInst>(BlkUser)){
218           auto CallBindF = CI->getCalledFunction();
219           auto Name = CallBindF->getName();
220           std::string DemangledName;
221           if (Name == SPIR_INTRINSIC_GET_BLOCK_INVOKE) {
222             assert(CI->getArgOperand(0) == CallBlkBind);
223             changed |= lowerGetBlockInvoke(CI, cast<Function>(InvF));
224           } else if (Name == SPIR_INTRINSIC_GET_BLOCK_CONTEXT) {
225             assert(CI->getArgOperand(0) == CallBlkBind);
226             // Handle context_ptr = spir_get_block_context(block)
227             lowerGetBlockContext(CI, Ctx);
228             changed = true;
229           } else if (oclIsBuiltin(Name, &DemangledName)) {
230             lowerBlockBuiltin(CI, InvF, Ctx, CtxLen, CtxAlign, DemangledName);
231             changed = true;
232           } else
233             llvm_unreachable("Invalid block user");
234         }
235       }
236       erase(CallBlkBind);
237     }
238     changed |= eraseUselessFunctions();
239     return changed;
240   }
241 
242   void
lowerGetBlockContext(CallInst * CallGetBlkCtx,Value * Ctx=nullptr)243   lowerGetBlockContext(CallInst *CallGetBlkCtx, Value *Ctx = nullptr) {
244     if (!Ctx)
245       getBlockInvokeFuncAndContext(CallGetBlkCtx->getArgOperand(0), nullptr,
246           &Ctx);
247     CallGetBlkCtx->replaceAllUsesWith(Ctx);
248     DEBUG(dbgs() << "  [lowerGetBlockContext] " << *CallGetBlkCtx << " => " <<
249         *Ctx << "\n\n");
250     erase(CallGetBlkCtx);
251   }
252 
253   bool
lowerGetBlockInvoke(CallInst * CallGetBlkInvoke,Function * InvokeF=nullptr)254   lowerGetBlockInvoke(CallInst *CallGetBlkInvoke,
255       Function *InvokeF = nullptr) {
256     bool changed = false;
257     for (auto UI = CallGetBlkInvoke->user_begin(),
258         UE = CallGetBlkInvoke->user_end();
259         UI != UE;) {
260       // Handle block_func_ptr = bitcast(spir_get_block_invoke(block))
261       auto CallInv = cast<Instruction>(*UI++);
262       auto Cast = dyn_cast<BitCastInst>(CallInv);
263       if (Cast)
264         CallInv = dyn_cast<Instruction>(*CallInv->user_begin());
265       DEBUG(dbgs() << "[lowerGetBlockInvoke]  " << *CallInv);
266       // Handle ret = block_func_ptr(context_ptr, args)
267       auto CI = cast<CallInst>(CallInv);
268       auto F = CI->getCalledValue();
269       if (InvokeF == nullptr) {
270         getBlockInvokeFuncAndContext(CallGetBlkInvoke->getArgOperand(0),
271             &InvokeF, nullptr);
272         assert(InvokeF);
273       }
274       assert(F->getType() == InvokeF->getType());
275       CI->replaceUsesOfWith(F, InvokeF);
276       DEBUG(dbgs() << " => " << *CI << "\n\n");
277       erase(Cast);
278       changed = true;
279     }
280     erase(CallGetBlkInvoke);
281     return changed;
282   }
283 
284   void
lowerBlockBuiltin(CallInst * CI,Function * InvF,Value * Ctx,Value * CtxLen,Value * CtxAlign,const std::string & DemangledName)285   lowerBlockBuiltin(CallInst *CI, Function *InvF, Value *Ctx, Value *CtxLen,
286       Value *CtxAlign, const std::string& DemangledName) {
287     mutateCallInstSPIRV (M, CI, [=](CallInst *CI, std::vector<Value *> &Args) {
288       size_t I = 0;
289       size_t E = Args.size();
290       for (; I != E; ++I) {
291         if (isPointerToOpaqueStructType(Args[I]->getType(),
292             SPIR_TYPE_NAME_BLOCK_T)) {
293           break;
294         }
295       }
296       assert (I < E);
297       Args[I] = castToVoidFuncPtr(InvF);
298       if (I + 1 == E) {
299         Args.push_back(Ctx);
300         Args.push_back(CtxLen);
301         Args.push_back(CtxAlign);
302       } else {
303         Args.insert(Args.begin() + I + 1, CtxAlign);
304         Args.insert(Args.begin() + I + 1, CtxLen);
305         Args.insert(Args.begin() + I + 1, Ctx);
306       }
307       if (DemangledName == kOCLBuiltinName::EnqueueKernel) {
308         // Insert event arguments if there are not.
309         if (!isa<IntegerType>(Args[3]->getType())) {
310           Args.insert(Args.begin() + 3, getInt32(M, 0));
311           Args.insert(Args.begin() + 4, getOCLNullClkEventPtr());
312         }
313         if (!isOCLClkEventPtrType(Args[5]->getType()))
314           Args.insert(Args.begin() + 5, getOCLNullClkEventPtr());
315       }
316       return getSPIRVFuncName(OCLSPIRVBuiltinMap::map(DemangledName));
317     });
318   }
319   /// Transform return of a block.
320   /// The function returning a block is inlined since the context cannot be
321   /// passed to another function.
322   /// Returns true of module is changed.
323   bool
lowerReturnBlock(ReturnInst * Ret,Value * CallBlkBind,bool & Inlined)324   lowerReturnBlock(ReturnInst *Ret, Value *CallBlkBind, bool &Inlined) {
325     auto F = Ret->getParent()->getParent();
326     auto changed = false;
327     for (auto UI = F->user_begin(), UE = F->user_end(); UI != UE;) {
328       auto U = *UI++;
329       dumpUsers(U);
330       auto Inst = dyn_cast<Instruction>(U);
331       if (Inst && Inst->use_empty()) {
332         erase(Inst);
333         changed = true;
334         continue;
335       }
336       auto CI = dyn_cast<CallInst>(U);
337       if(!CI || CI->getCalledFunction() != F)
338         continue;
339 
340       DEBUG(dbgs() << "[lowerReturnBlock] inline " << F->getName() << '\n');
341       auto CG = &getAnalysis<CallGraphWrapperPass>().getCallGraph();
342       auto ACT = &getAnalysis<AssumptionCacheTracker>();
343       //auto AA = &getAnalysis<AliasAnalysis>();
344       //InlineFunctionInfo IFI(CG, M->getDataLayout(), AA, ACT);
345       InlineFunctionInfo IFI(CG, ACT);
346       InlineFunction(CI, IFI);
347       Inlined = true;
348     }
349     return changed || Inlined;
350   }
351 
352   void
getBlockInvokeFuncAndContext(Value * Blk,Function ** PInvF,Value ** PCtx,Value ** PCtxLen=nullptr,Value ** PCtxAlign=nullptr)353   getBlockInvokeFuncAndContext(Value *Blk, Function **PInvF, Value **PCtx,
354       Value **PCtxLen = nullptr, Value **PCtxAlign = nullptr){
355     Function *InvF = nullptr;
356     Value *Ctx = nullptr;
357     Value *CtxLen = nullptr;
358     Value *CtxAlign = nullptr;
359     if (auto CallBlkBind = dyn_cast<CallInst>(Blk)) {
360       assert(CallBlkBind->getCalledFunction()->getName() ==
361           SPIR_INTRINSIC_BLOCK_BIND && "Invalid block");
362       InvF = dyn_cast<Function>(
363           CallBlkBind->getArgOperand(0)->stripPointerCasts());
364       CtxLen = CallBlkBind->getArgOperand(1);
365       CtxAlign = CallBlkBind->getArgOperand(2);
366       Ctx = CallBlkBind->getArgOperand(3);
367     } else if (auto F = dyn_cast<Function>(Blk->stripPointerCasts())) {
368       InvF = F;
369       Ctx = Constant::getNullValue(IntegerType::getInt8PtrTy(M->getContext()));
370     } else if (auto Load = dyn_cast<LoadInst>(Blk)) {
371       auto Op = Load->getPointerOperand();
372       if (auto GV = dyn_cast<GlobalVariable>(Op)) {
373         if (GV->isConstant()) {
374           InvF = cast<Function>(GV->getInitializer()->stripPointerCasts());
375           Ctx = Constant::getNullValue(IntegerType::getInt8PtrTy(M->getContext()));
376         } else {
377           llvm_unreachable("load non-constant block?");
378         }
379       } else {
380         llvm_unreachable("Loading block from non global?");
381       }
382     } else {
383       llvm_unreachable("Invalid block");
384     }
385     DEBUG(dbgs() << "  Block invocation func: " << InvF->getName() << '\n' <<
386         "  Block context: " << *Ctx << '\n');
387     assert(InvF && Ctx && "Invalid block");
388     if (PInvF)
389       *PInvF = InvF;
390     if (PCtx)
391       *PCtx = Ctx;
392     if (PCtxLen)
393       *PCtxLen = CtxLen;
394     if (PCtxAlign)
395       *PCtxAlign = CtxAlign;
396   }
397   void
erase(Instruction * I)398   erase(Instruction *I) {
399     if (!I)
400       return;
401     if (I->use_empty()) {
402       I->dropAllReferences();
403       I->eraseFromParent();
404     }
405     else
406       dumpUsers(I);
407   }
408   void
erase(ConstantExpr * I)409   erase(ConstantExpr *I) {
410     if (!I)
411       return;
412     if (I->use_empty()) {
413       I->dropAllReferences();
414       I->destroyConstant();
415     } else
416       dumpUsers(I);
417   }
418   void
erase(Function * F)419   erase(Function *F) {
420     if (!F)
421       return;
422     if (!F->use_empty()) {
423       dumpUsers(F);
424       return;
425     }
426     F->dropAllReferences();
427     auto &CG = getAnalysis<CallGraphWrapperPass>().getCallGraph();
428     CG.removeFunctionFromModule(new CallGraphNode(F));
429   }
430 
getOCLClkEventType()431   llvm::PointerType* getOCLClkEventType() {
432     return getOrCreateOpaquePtrType(M, SPIR_TYPE_NAME_CLK_EVENT_T,
433         SPIRAS_Global);
434   }
435 
getOCLClkEventPtrType()436   llvm::PointerType* getOCLClkEventPtrType() {
437     return PointerType::get(getOCLClkEventType(), SPIRAS_Generic);
438   }
439 
isOCLClkEventPtrType(Type * T)440   bool isOCLClkEventPtrType(Type *T) {
441     if (auto PT = dyn_cast<PointerType>(T))
442       return isPointerToOpaqueStructType(
443         PT->getElementType(), SPIR_TYPE_NAME_CLK_EVENT_T);
444     return false;
445   }
446 
getOCLNullClkEventPtr()447   llvm::Constant* getOCLNullClkEventPtr() {
448     return Constant::getNullValue(getOCLClkEventPtrType());
449   }
450 
dumpGetBlockInvokeUsers(StringRef Prompt)451   void dumpGetBlockInvokeUsers(StringRef Prompt) {
452     DEBUG(dbgs() << Prompt);
453     dumpUsers(M->getFunction(SPIR_INTRINSIC_GET_BLOCK_INVOKE));
454   }
455 };
456 
457 char SPIRVLowerOCLBlocks::ID = 0;
458 }
459 
460 INITIALIZE_PASS_BEGIN(SPIRVLowerOCLBlocks, "spvblocks",
461     "SPIR-V lower OCL blocks", false, false)
INITIALIZE_PASS_DEPENDENCY(CallGraphWrapperPass)462 INITIALIZE_PASS_DEPENDENCY(CallGraphWrapperPass)
463 INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker)
464 //INITIALIZE_AG_DEPENDENCY(AliasAnalysis)
465 INITIALIZE_PASS_END(SPIRVLowerOCLBlocks, "spvblocks",
466     "SPIR-V lower OCL blocks", false, false)
467 
468 ModulePass *llvm::createSPIRVLowerOCLBlocks() {
469   return new SPIRVLowerOCLBlocks();
470 }
471 
472 #endif /* OCLLOWERBLOCKS_H_ */
473